Added private error handler middleware

This commit is contained in:
Matthew Dillon 2015-04-20 10:12:42 -08:00
parent 537b1ab886
commit 61507766fe

73
jwt.go
View file

@ -69,18 +69,33 @@ func NewMiddleware(c *Config) (*JWTMiddleware, error) {
return m, nil return m, nil
} }
type jwtError struct {
status int
err error
message string
}
type errorHandler func(http.ResponseWriter, *http.Request) *jwtError
func (e errorHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if err := e(w, r); err != nil {
if err.message != "" {
log.Printf("error (%v) while %s", err.err, err.message)
}
http.Error(w, err.err.Error(), err.status)
}
}
func (m *JWTMiddleware) Secure(h http.Handler, v VerifyClaimsFunc) http.Handler { func (m *JWTMiddleware) Secure(h http.Handler, v VerifyClaimsFunc) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { secureHandler := func(w http.ResponseWriter, r *http.Request) *jwtError {
authHeader := r.Header.Get("Authorization") authHeader := r.Header.Get("Authorization")
if authHeader == "" { if authHeader == "" {
http.Error(w, ErrMissingToken.Error(), http.StatusUnauthorized) return &jwtError{status: http.StatusUnauthorized, err: ErrMissingToken}
return
} }
token := strings.Split(authHeader, " ")[1] token := strings.Split(authHeader, " ")[1]
tokenParts := strings.Split(token, ".") tokenParts := strings.Split(token, ".")
if len(tokenParts) != 3 { if len(tokenParts) != 3 {
http.Error(w, ErrMalformedToken.Error(), http.StatusUnauthorized) return &jwtError{status: http.StatusUnauthorized, err: ErrMalformedToken}
return
} }
// First, verify JOSE header // First, verify JOSE header
@ -90,15 +105,19 @@ func (m *JWTMiddleware) Secure(h http.Handler, v VerifyClaimsFunc) http.Handler
} }
header, err := decode(tokenParts[0]) header, err := decode(tokenParts[0])
if err != nil { if err != nil {
log.Printf("error (%v) while decoding header (%v)", err, tokenParts[0]) return &jwtError{
http.Error(w, err.Error(), http.StatusInternalServerError) status: http.StatusInternalServerError,
return err: err,
message: fmt.Sprintf("decoding header (%v)", tokenParts[0]),
}
} }
err = json.Unmarshal(header, &t) err = json.Unmarshal(header, &t)
if err != nil { if err != nil {
log.Printf("error (%v) while unmarshalling header (%s)", err, header) return &jwtError{
http.Error(w, ErrMalformedToken.Error(), http.StatusInternalServerError) status: http.StatusInternalServerError,
return err: ErrMalformedToken,
message: fmt.Sprintf("unmarshalling header (%s)", header),
}
} }
// Then, verify signature // Then, verify signature
@ -107,32 +126,40 @@ func (m *JWTMiddleware) Secure(h http.Handler, v VerifyClaimsFunc) http.Handler
mac.Write(message) mac.Write(message)
expectedMac, err := encode(mac.Sum(nil)) expectedMac, err := encode(mac.Sum(nil))
if err != nil { if err != nil {
panic(err) return &jwtError{status: http.StatusInternalServerError, err: err}
return
} }
if !hmac.Equal([]byte(tokenParts[2]), []byte(expectedMac)) { if !hmac.Equal([]byte(tokenParts[2]), []byte(expectedMac)) {
log.Printf("invalid signature: %v", tokenParts[2]) return &jwtError{
http.Error(w, ErrInvalidSignature.Error(), http.StatusUnauthorized) status: http.StatusUnauthorized,
return err: ErrInvalidSignature,
message: fmt.Sprintf("checking signature (%v)", tokenParts[2]),
}
} }
// Finally, check claims // Finally, check claims
claimSet, err := decode(tokenParts[1]) claimSet, err := decode(tokenParts[1])
if err != nil { if err != nil {
log.Printf("error (%v) while decoding claims", err) return &jwtError{
http.Error(w, ErrDecoding.Error(), http.StatusInternalServerError) status: http.StatusInternalServerError,
return err: ErrDecoding,
message: "decoding claims",
}
} }
err = v(claimSet) err = v(claimSet)
if err != nil { if err != nil {
log.Printf("claims handler error: %v", err) return &jwtError{
http.Error(w, err.Error(), http.StatusUnauthorized) status: http.StatusUnauthorized,
return err: err,
message: "handling claims callback",
}
} }
// If we make it this far, process the downstream handler // If we make it this far, process the downstream handler
h.ServeHTTP(w, r) h.ServeHTTP(w, r)
}) return nil
}
return errorHandler(secureHandler)
} }
func (m *JWTMiddleware) GenerateToken(w http.ResponseWriter, r *http.Request) { func (m *JWTMiddleware) GenerateToken(w http.ResponseWriter, r *http.Request) {