
StdEncoding is used for standard base64 encoding. The base64 alphabet includes '+', '/' and '=' characters which are not URL safe and not JWT compatible. Furthermore padding '=' characters are added, but the JWT defenition want it to be compact. Using RawURLEncoding instead of StdEncoding solve this issues.
322 lines
8.6 KiB
Go
322 lines
8.6 KiB
Go
// Package jwt implements a simple, opinionated net/http-compatible middleware for
|
|
// integrating JSON Web Tokens (JWT).
|
|
package jwt
|
|
|
|
import (
|
|
"crypto/hmac"
|
|
"crypto/sha256"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"log"
|
|
"net/http"
|
|
"strings"
|
|
)
|
|
|
|
const (
|
|
typ = "JWT"
|
|
alg = "HS256"
|
|
)
|
|
|
|
// Errors introduced by this package.
|
|
var (
|
|
ErrMissingConfig = errors.New("missing configuration")
|
|
ErrMissingSecret = errors.New("please provide a shared secret")
|
|
ErrMissingAuthFunc = errors.New("please provide an auth function")
|
|
ErrMissingClaimsFunc = errors.New("please provide a claims function")
|
|
ErrEncoding = errors.New("error encoding value")
|
|
ErrDecoding = errors.New("error decoding value")
|
|
ErrMissingToken = errors.New("please provide a token")
|
|
ErrMalformedToken = errors.New("please provide a valid token")
|
|
ErrInvalidSignature = errors.New("signature could not be verified")
|
|
ErrParsingCredentials = errors.New("error parsing credentials")
|
|
ErrInvalidMethod = errors.New("invalid request method")
|
|
)
|
|
|
|
// AuthFunc is a type for delegating user authentication to the client-code.
|
|
type AuthFunc func(string, string) error
|
|
|
|
// ClaimsFunc is a type for delegating claims generation to the client-code.
|
|
type ClaimsFunc func(string) (map[string]interface{}, error)
|
|
|
|
// VerifyClaimsFunc is a type for processing and validating JWT claims on one
|
|
// or more routes in the client-code.
|
|
type VerifyClaimsFunc func([]byte, *http.Request) error
|
|
|
|
// Config is a container for setting up the JWT middleware.
|
|
type Config struct {
|
|
Secret string
|
|
Auth AuthFunc
|
|
Claims ClaimsFunc
|
|
IdentityField string
|
|
VerifyField string
|
|
}
|
|
|
|
// Middleware is where we store all the specifics related to the client's
|
|
// JWT needs.
|
|
type Middleware struct {
|
|
secret string
|
|
auth AuthFunc
|
|
claims ClaimsFunc
|
|
identityField string
|
|
verifyField string
|
|
}
|
|
|
|
// New creates a new Middleware from a user-specified configuration.
|
|
func New(c *Config) (*Middleware, error) {
|
|
if c == nil {
|
|
return nil, ErrMissingConfig
|
|
}
|
|
if c.Secret == "" {
|
|
return nil, ErrMissingSecret
|
|
}
|
|
if c.Auth == nil {
|
|
return nil, ErrMissingAuthFunc
|
|
}
|
|
if c.Claims == nil {
|
|
return nil, ErrMissingClaimsFunc
|
|
}
|
|
if c.IdentityField == "" {
|
|
c.IdentityField = "email"
|
|
}
|
|
if c.VerifyField == "" {
|
|
c.VerifyField = "password"
|
|
}
|
|
m := &Middleware{
|
|
secret: c.Secret,
|
|
auth: c.Auth,
|
|
claims: c.Claims,
|
|
identityField: c.IdentityField,
|
|
verifyField: c.VerifyField,
|
|
}
|
|
return m, nil
|
|
}
|
|
|
|
// Secure wraps a client-specified http.Handler with a verification function,
|
|
// as well as-built in parsing of the request's JWT. This allows each handler
|
|
// to have it's own verification/validation protocol.
|
|
func (m *Middleware) Secure(h http.Handler, v VerifyClaimsFunc) http.Handler {
|
|
secureHandler := func(w http.ResponseWriter, r *http.Request) *jwtError {
|
|
var token string
|
|
|
|
authHeader := r.Header.Get("Authorization")
|
|
if authHeader == "" {
|
|
token = r.FormValue("token")
|
|
if token == "" {
|
|
return &jwtError{status: http.StatusUnauthorized, err: ErrMissingToken}
|
|
}
|
|
} else {
|
|
tokenParts := strings.Split(authHeader, " ")
|
|
if len(tokenParts) != 2 {
|
|
return &jwtError{status: http.StatusUnauthorized, err: ErrMalformedToken}
|
|
}
|
|
token = tokenParts[1]
|
|
}
|
|
|
|
if status, message, err := m.VerifyToken(token, v, r); err != nil {
|
|
return &jwtError{
|
|
status: status,
|
|
message: message,
|
|
err: err,
|
|
}
|
|
}
|
|
|
|
// If we make it this far, process the downstream handler
|
|
h.ServeHTTP(w, r)
|
|
return nil
|
|
}
|
|
return errorHandler(secureHandler)
|
|
}
|
|
|
|
// Authenticate returns a middleware that parsing an incoming request for a JWT,
|
|
// calls the client-supplied auth function, and if successful, returns a JWT to
|
|
// the requester.
|
|
func (m *Middleware) Authenticate() http.Handler {
|
|
generateHandler := func(w http.ResponseWriter, r *http.Request) *jwtError {
|
|
if r.Method != "POST" {
|
|
return &jwtError{
|
|
status: http.StatusBadRequest,
|
|
err: ErrInvalidMethod,
|
|
message: "receiving request",
|
|
}
|
|
}
|
|
|
|
b := make(map[string]string, 0)
|
|
contentType := r.Header.Get("content-type")
|
|
switch contentType {
|
|
case "application/x-www-form-urlencoded", "application/x-www-form-urlencoded; charset=UTF-8":
|
|
identity, verify := r.FormValue(m.identityField), r.FormValue(m.verifyField)
|
|
if identity == "" || verify == "" {
|
|
return &jwtError{
|
|
status: http.StatusInternalServerError,
|
|
err: ErrParsingCredentials,
|
|
message: "parsing authorization",
|
|
}
|
|
}
|
|
b[m.identityField], b[m.verifyField] = identity, verify
|
|
default:
|
|
err := json.NewDecoder(r.Body).Decode(&b)
|
|
if err != nil {
|
|
return &jwtError{
|
|
status: http.StatusInternalServerError,
|
|
err: ErrParsingCredentials,
|
|
message: "parsing authorization",
|
|
}
|
|
}
|
|
}
|
|
|
|
// Check if required fields are in the body
|
|
if _, ok := b[m.identityField]; !ok {
|
|
return &jwtError{
|
|
status: http.StatusBadRequest,
|
|
err: ErrParsingCredentials,
|
|
message: "parsing credentials, missing identity field",
|
|
}
|
|
}
|
|
if _, ok := b[m.verifyField]; !ok {
|
|
return &jwtError{
|
|
status: http.StatusBadRequest,
|
|
err: ErrParsingCredentials,
|
|
message: "parsing credentials, missing verify field",
|
|
}
|
|
}
|
|
err := m.auth(b[m.identityField], b[m.verifyField])
|
|
if err != nil {
|
|
return &jwtError{
|
|
status: http.StatusInternalServerError,
|
|
err: err,
|
|
message: "performing authorization",
|
|
}
|
|
}
|
|
response, err := m.CreateToken(b[m.identityField])
|
|
if err != nil {
|
|
return &jwtError{
|
|
status: http.StatusInternalServerError,
|
|
err: err,
|
|
message: response,
|
|
}
|
|
}
|
|
w.Write([]byte(response))
|
|
return nil
|
|
}
|
|
|
|
return errorHandler(generateHandler)
|
|
}
|
|
|
|
// CreateToken generates a token from a user's identity
|
|
func (m *Middleware) CreateToken(identity string) (string, error) {
|
|
// For now, the header will be static
|
|
header, err := encode(fmt.Sprintf(`{"typ":%q,"alg":%q}`, typ, alg))
|
|
if err != nil {
|
|
return "encoding header", ErrEncoding
|
|
}
|
|
|
|
// Generate claims for user
|
|
claims, err := m.claims(identity)
|
|
if err != nil {
|
|
return "generating claims", err
|
|
}
|
|
|
|
claimsJSON, err := json.Marshal(claims)
|
|
if err != nil {
|
|
return "mashalling claims", ErrEncoding
|
|
}
|
|
|
|
claimsSet, err := encode(claimsJSON)
|
|
if err != nil {
|
|
return "encoding claims", ErrEncoding
|
|
}
|
|
|
|
toSig := strings.Join([]string{header, claimsSet}, ".")
|
|
|
|
h := hmac.New(sha256.New, []byte(m.secret))
|
|
h.Write([]byte(toSig))
|
|
sig, err := encode(h.Sum(nil))
|
|
if err != nil {
|
|
return "encoding signature", ErrEncoding
|
|
}
|
|
|
|
response := strings.Join([]string{toSig, sig}, ".")
|
|
return response, nil
|
|
}
|
|
|
|
// VerifyToken verifies a token
|
|
func (m *Middleware) VerifyToken(token string, v VerifyClaimsFunc, r *http.Request) (int, string, error) {
|
|
tokenParts := strings.Split(token, ".")
|
|
if len(tokenParts) != 3 {
|
|
return http.StatusUnauthorized, "", ErrMalformedToken
|
|
}
|
|
|
|
// First, verify JOSE header
|
|
header, err := decode(tokenParts[0])
|
|
if err != nil {
|
|
return http.StatusInternalServerError, fmt.Sprintf("decoding header (%v)", tokenParts[0]), err
|
|
}
|
|
var t struct {
|
|
Typ string
|
|
Alg string
|
|
}
|
|
err = json.Unmarshal(header, &t)
|
|
if err != nil {
|
|
return http.StatusInternalServerError, fmt.Sprintf("unmarshalling header (%s)", header), ErrMalformedToken
|
|
}
|
|
|
|
// Then, verify signature
|
|
mac := hmac.New(sha256.New, []byte(m.secret))
|
|
message := []byte(strings.Join([]string{tokenParts[0], tokenParts[1]}, "."))
|
|
mac.Write(message)
|
|
expectedMac, err := encode(mac.Sum(nil))
|
|
if err != nil {
|
|
return http.StatusInternalServerError, "", err
|
|
}
|
|
if !hmac.Equal([]byte(tokenParts[2]), []byte(expectedMac)) {
|
|
return http.StatusUnauthorized, fmt.Sprintf("checking signature (%v)", tokenParts[2]), ErrInvalidSignature
|
|
}
|
|
|
|
// Finally, check claims
|
|
claimSet, err := decode(tokenParts[1])
|
|
if err != nil {
|
|
return http.StatusInternalServerError, "decoding claims", ErrDecoding
|
|
}
|
|
err = v(claimSet, r)
|
|
if err != nil {
|
|
return http.StatusUnauthorized, "handling claims callback", err
|
|
}
|
|
|
|
return 200, "", nil
|
|
}
|
|
|
|
type jwtError struct {
|
|
status int
|
|
message string
|
|
err error
|
|
}
|
|
|
|
type errorHandler func(http.ResponseWriter, *http.Request) *jwtError
|
|
|
|
func (e errorHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|
if err := e(w, r); err != nil {
|
|
if err.message != "" {
|
|
log.Printf("error (%v) while %s", err.err, err.message)
|
|
}
|
|
http.Error(w, err.err.Error(), err.status)
|
|
}
|
|
}
|
|
|
|
func encode(s interface{}) (string, error) {
|
|
var r []byte
|
|
switch v := s.(type) {
|
|
case string:
|
|
r = []byte(v)
|
|
case []byte:
|
|
r = v
|
|
default:
|
|
return "", ErrEncoding
|
|
}
|
|
return base64.RawURLEncoding.EncodeToString(r), nil
|
|
}
|
|
|
|
func decode(s string) ([]byte, error) {
|
|
return base64.RawURLEncoding.DecodeString(s)
|
|
}
|