// Package jwt implements a simple, opinionated net/http-compatible middleware for
// integrating JSON Web Tokens (JWT).
//
// Tokens are issued by the handler returned from Middleware.Authenticate and
// checked by handlers wrapped with Middleware.Secure. Every token carries a
// static JOSE header and is signed with HMAC-SHA256 using a shared secret.
package jwt

import (
	"crypto/hmac"
	"crypto/sha256"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"log"
	"net/http"
	"strings"
)

// Values written into the JOSE header of every token created by CreateToken.
const (
	typ = "JWT"
	alg = "HS256"
)

// Errors introduced by this package.
//
// The ErrMissingConfig, ErrMissingSecret, ErrMissingAuthFunc and
// ErrMissingClaimsFunc values are returned by New when the configuration is
// incomplete. The remaining values are reported by the HTTP handlers and by
// CreateToken / VerifyToken while parsing requests and tokens.
var (
	ErrMissingConfig      = errors.New("missing configuration")
	ErrMissingSecret      = errors.New("please provide a shared secret")
	ErrMissingAuthFunc    = errors.New("please provide an auth function")
	ErrMissingClaimsFunc  = errors.New("please provide a claims function")
	ErrEncoding           = errors.New("error encoding value")
	ErrDecoding           = errors.New("error decoding value")
	ErrMissingToken       = errors.New("please provide a token")
	ErrMalformedToken     = errors.New("please provide a valid token")
	ErrInvalidSignature   = errors.New("signature could not be verified")
	ErrParsingCredentials = errors.New("error parsing credentials")
	ErrInvalidMethod      = errors.New("invalid request method")
)

// AuthFunc is a type for delegating user authentication to the client-code.
// The Authenticate handler calls it with the request's identity value and
// verify value, in that order; a non-nil error rejects the request.
type AuthFunc func(string, string) error

// ClaimsFunc is a type for delegating claims generation to the client-code.
// CreateToken calls it with the user's identity and JSON-encodes the returned
// map as the token's claims set.
type ClaimsFunc func(string) (map[string]interface{}, error)

// VerifyClaimsFunc is a type for processing and validating JWT claims on one
// or more routes in the client-code. It receives the base64-decoded claims
// segment of the token (the raw bytes, not a parsed structure) together with
// the request being served; a non-nil error rejects the request.
type VerifyClaimsFunc func([]byte, *http.Request) error

// Config is a container for setting up the JWT middleware.
type Config struct {
	// Secret is the shared secret used as the HMAC-SHA256 key. Required.
	Secret string
	// Auth checks a user's credentials. Required.
	Auth AuthFunc
	// Claims produces the claims stored in a newly issued token. Required.
	Claims ClaimsFunc
	// IdentityField names the request field holding the user's identity.
	// New uses "email" when it is empty.
	IdentityField string
	// VerifyField names the request field holding the value that proves the
	// identity. New uses "password" when it is empty.
	VerifyField string
}

// Middleware is where we store all the specifics related to the client's
// JWT needs. Its fields mirror those of Config; construct one with New.
type Middleware struct {
	secret        string
	auth          AuthFunc
	claims        ClaimsFunc
	identityField string
	verifyField   string
}

// New creates a new Middleware from a user-specified configuration.
func New(c *Config) (*Middleware, error) { if c == nil { return nil, ErrMissingConfig } if c.Secret == "" { return nil, ErrMissingSecret } if c.Auth == nil { return nil, ErrMissingAuthFunc } if c.Claims == nil { return nil, ErrMissingClaimsFunc } if c.IdentityField == "" { c.IdentityField = "email" } if c.VerifyField == "" { c.VerifyField = "password" } m := &Middleware{ secret: c.Secret, auth: c.Auth, claims: c.Claims, identityField: c.IdentityField, verifyField: c.VerifyField, } return m, nil } // Secure wraps a client-specified http.Handler with a verification function, // as well as-built in parsing of the request's JWT. This allows each handler // to have it's own verification/validation protocol. func (m *Middleware) Secure(h http.Handler, v VerifyClaimsFunc) http.Handler { secureHandler := func(w http.ResponseWriter, r *http.Request) *jwtError { var token string authHeader := r.Header.Get("Authorization") if authHeader == "" { token = r.FormValue("token") if token == "" { return &jwtError{status: http.StatusUnauthorized, err: ErrMissingToken} } } else { tokenParts := strings.Split(authHeader, " ") if len(tokenParts) != 2 { return &jwtError{status: http.StatusUnauthorized, err: ErrMalformedToken} } token = tokenParts[1] } if status, message, err := m.VerifyToken(token, v, r); err != nil { return &jwtError{ status: status, message: message, err: err, } } // If we make it this far, process the downstream handler h.ServeHTTP(w, r) return nil } return errorHandler(secureHandler) } // Authenticate returns a middleware that parsing an incoming request for a JWT, // calls the client-supplied auth function, and if successful, returns a JWT to // the requester. 
func (m *Middleware) Authenticate() http.Handler { generateHandler := func(w http.ResponseWriter, r *http.Request) *jwtError { if r.Method != "POST" { return &jwtError{ status: http.StatusBadRequest, err: ErrInvalidMethod, message: "receiving request", } } b := make(map[string]string, 0) contentType := r.Header.Get("content-type") switch contentType { case "application/x-www-form-urlencoded", "application/x-www-form-urlencoded; charset=UTF-8": identity, verify := r.FormValue(m.identityField), r.FormValue(m.verifyField) if identity == "" || verify == "" { return &jwtError{ status: http.StatusInternalServerError, err: ErrParsingCredentials, message: "parsing authorization", } } b[m.identityField], b[m.verifyField] = identity, verify default: err := json.NewDecoder(r.Body).Decode(&b) if err != nil { return &jwtError{ status: http.StatusInternalServerError, err: ErrParsingCredentials, message: "parsing authorization", } } } // Check if required fields are in the body if _, ok := b[m.identityField]; !ok { return &jwtError{ status: http.StatusBadRequest, err: ErrParsingCredentials, message: "parsing credentials, missing identity field", } } if _, ok := b[m.verifyField]; !ok { return &jwtError{ status: http.StatusBadRequest, err: ErrParsingCredentials, message: "parsing credentials, missing verify field", } } err := m.auth(b[m.identityField], b[m.verifyField]) if err != nil { return &jwtError{ status: http.StatusInternalServerError, err: err, message: "performing authorization", } } response, err := m.CreateToken(b[m.identityField]) if err != nil { return &jwtError{ status: http.StatusInternalServerError, err: err, message: response, } } w.Write([]byte(response)) return nil } return errorHandler(generateHandler) } // CreateToken generates a token from a user's identity func (m *Middleware) CreateToken(identity string) (string, error) { // For now, the header will be static header, err := encode(fmt.Sprintf(`{"typ":%q,"alg":%q}`, typ, alg)) if err != nil { return "encoding 
header", ErrEncoding } // Generate claims for user claims, err := m.claims(identity) if err != nil { return "generating claims", err } claimsJSON, err := json.Marshal(claims) if err != nil { return "mashalling claims", ErrEncoding } claimsSet, err := encode(claimsJSON) if err != nil { return "encoding claims", ErrEncoding } toSig := strings.Join([]string{header, claimsSet}, ".") h := hmac.New(sha256.New, []byte(m.secret)) h.Write([]byte(toSig)) sig, err := encode(h.Sum(nil)) if err != nil { return "encoding signature", ErrEncoding } response := strings.Join([]string{toSig, sig}, ".") return response, nil } // VerifyToken verifies a token func (m *Middleware) VerifyToken(token string, v VerifyClaimsFunc, r *http.Request) (int, string, error) { tokenParts := strings.Split(token, ".") if len(tokenParts) != 3 { return http.StatusUnauthorized, "", ErrMalformedToken } // First, verify JOSE header header, err := decode(tokenParts[0]) if err != nil { return http.StatusInternalServerError, fmt.Sprintf("decoding header (%v)", tokenParts[0]), err } var t struct { Typ string Alg string } err = json.Unmarshal(header, &t) if err != nil { return http.StatusInternalServerError, fmt.Sprintf("unmarshalling header (%s)", header), ErrMalformedToken } // Then, verify signature mac := hmac.New(sha256.New, []byte(m.secret)) message := []byte(strings.Join([]string{tokenParts[0], tokenParts[1]}, ".")) mac.Write(message) expectedMac, err := encode(mac.Sum(nil)) if err != nil { return http.StatusInternalServerError, "", err } if !hmac.Equal([]byte(tokenParts[2]), []byte(expectedMac)) { return http.StatusUnauthorized, fmt.Sprintf("checking signature (%v)", tokenParts[2]), ErrInvalidSignature } // Finally, check claims claimSet, err := decode(tokenParts[1]) if err != nil { return http.StatusInternalServerError, "decoding claims", ErrDecoding } err = v(claimSet, r) if err != nil { return http.StatusUnauthorized, "handling claims callback", err } return 200, "", nil } type jwtError struct { 
status int message string err error } type errorHandler func(http.ResponseWriter, *http.Request) *jwtError func (e errorHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { if err := e(w, r); err != nil { if err.message != "" { log.Printf("error (%v) while %s", err.err, err.message) } http.Error(w, err.err.Error(), err.status) } } func encode(s interface{}) (string, error) { var r []byte switch v := s.(type) { case string: r = []byte(v) case []byte: r = v default: return "", ErrEncoding } return base64.RawURLEncoding.EncodeToString(r), nil } func decode(s string) ([]byte, error) { return base64.RawURLEncoding.DecodeString(s) }