diff --git a/jwt.go b/jwt.go index de26278..848a430 100644 --- a/jwt.go +++ b/jwt.go @@ -69,18 +69,33 @@ func NewMiddleware(c *Config) (*JWTMiddleware, error) { return m, nil } +type jwtError struct { + status int + err error + message string +} + +type errorHandler func(http.ResponseWriter, *http.Request) *jwtError + +func (e errorHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if err := e(w, r); err != nil { + if err.message != "" { + log.Printf("error (%v) while %s", err.err, err.message) + } + http.Error(w, err.err.Error(), err.status) + } +} + func (m *JWTMiddleware) Secure(h http.Handler, v VerifyClaimsFunc) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + secureHandler := func(w http.ResponseWriter, r *http.Request) *jwtError { authHeader := r.Header.Get("Authorization") if authHeader == "" { - http.Error(w, ErrMissingToken.Error(), http.StatusUnauthorized) - return + return &jwtError{status: http.StatusUnauthorized, err: ErrMissingToken} } token := strings.Split(authHeader, " ")[1] tokenParts := strings.Split(token, ".") if len(tokenParts) != 3 { - http.Error(w, ErrMalformedToken.Error(), http.StatusUnauthorized) - return + return &jwtError{status: http.StatusUnauthorized, err: ErrMalformedToken} } // First, verify JOSE header @@ -90,15 +105,19 @@ func (m *JWTMiddleware) Secure(h http.Handler, v VerifyClaimsFunc) http.Handler } header, err := decode(tokenParts[0]) if err != nil { - log.Printf("error (%v) while decoding header (%v)", err, tokenParts[0]) - http.Error(w, err.Error(), http.StatusInternalServerError) - return + return &jwtError{ + status: http.StatusInternalServerError, + err: err, + message: fmt.Sprintf("decoding header (%v)", tokenParts[0]), + } } err = json.Unmarshal(header, &t) if err != nil { - log.Printf("error (%v) while unmarshalling header (%s)", err, header) - http.Error(w, ErrMalformedToken.Error(), http.StatusInternalServerError) - return + return &jwtError{ + status: http.StatusInternalServerError, + err: ErrMalformedToken, + message: fmt.Sprintf("unmarshalling header (%s)", header), + } } // Then, verify signature @@ -107,32 +126,40 @@ func (m *JWTMiddleware) Secure(h http.Handler, v VerifyClaimsFunc) http.Handler mac.Write(message) expectedMac, err := encode(mac.Sum(nil)) if err != nil { - panic(err) - return + return &jwtError{status: http.StatusInternalServerError, err: err} } if !hmac.Equal([]byte(tokenParts[2]), []byte(expectedMac)) { - log.Printf("invalid signature: %v", tokenParts[2]) - http.Error(w, ErrInvalidSignature.Error(), http.StatusUnauthorized) - return + return &jwtError{ + status: http.StatusUnauthorized, + err: ErrInvalidSignature, + message: fmt.Sprintf("checking signature (%v)", tokenParts[2]), + } } // Finally, check claims claimSet, err := decode(tokenParts[1]) if err != nil { - log.Printf("error (%v) while decoding claims", err) - http.Error(w, ErrDecoding.Error(), http.StatusInternalServerError) - return + return &jwtError{ + status: http.StatusInternalServerError, + err: ErrDecoding, + message: "decoding claims", + } } err = v(claimSet) if err != nil { - log.Printf("claims handler error: %v", err) - http.Error(w, err.Error(), http.StatusUnauthorized) - return + return &jwtError{ + status: http.StatusUnauthorized, + err: err, + message: "handling claims callback", + } } // If we make it this far, process the downstream handler h.ServeHTTP(w, r) - }) + return nil + } + + return errorHandler(secureHandler) } func (m *JWTMiddleware) GenerateToken(w http.ResponseWriter, r *http.Request) {