diff --git a/jwt.go b/jwt.go index 4a0f2d3..ca155b3 100644 --- a/jwt.go +++ b/jwt.go @@ -31,7 +31,9 @@ type Config struct { type AuthFunc func(string, string) (bool, error) -type ClaimsFunc func(id string) (map[string]interface{}, error) +type ClaimsFunc func(string) (map[string]interface{}, error) + +type VerifyClaimsFunc func([]byte) (bool, error) type JWTMiddleware struct { secret string @@ -60,7 +62,7 @@ func NewMiddleware(c *Config) (*JWTMiddleware, error) { return m, nil } -func (m *JWTMiddleware) Secure(h http.Handler) http.Handler { +func (m *JWTMiddleware) Secure(h http.Handler, v VerifyClaimsFunc) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { authHeader := r.Header.Get("Authorization") if authHeader == "" { @@ -74,7 +76,7 @@ func (m *JWTMiddleware) Secure(h http.Handler) http.Handler { } tokenParts := strings.Split(token, ".") - // Verify JOSE header + // First, verify JOSE header var t struct { Typ string Alg string @@ -92,7 +94,7 @@ func (m *JWTMiddleware) Secure(h http.Handler) http.Handler { return } - // Verify signature + // Then, verify signature mac := hmac.New(sha256.New, []byte(m.secret)) message := []byte(strings.Join([]string{tokenParts[0], tokenParts[1]}, ".")) mac.Write(message) @@ -106,6 +108,21 @@ func (m *JWTMiddleware) Secure(h http.Handler) http.Handler { http.Error(w, ErrInvalidSignature.Error(), http.StatusUnauthorized) return } + + // Finally, check claims + claimSet, err := decode(tokenParts[1]) + if err != nil { + panic(err) + return + } + claimsTest, err := v(claimSet) + if !claimsTest { + log.Printf("test: %v, error: %v", claimsTest, err) + http.Error(w, err.Error(), http.StatusUnauthorized) + return + } + + // If we make it this far, process the downstream handler h.ServeHTTP(w, r) }) } diff --git a/jwt_test.go b/jwt_test.go index 205cbdd..3c3cc51 100644 --- a/jwt_test.go +++ b/jwt_test.go @@ -6,6 +6,7 @@ import ( "crypto/sha256" "encoding/base64" "encoding/json" + "errors" "fmt" "io/ioutil" "net/http" @@ -31,6 +32,22 @@ var claimsFunc = func(id string) (map[string]interface{}, error) { }, nil } +var verifyClaimsFunc = func(claims []byte) (bool, error) { + currentTime := time.Now() + var c struct { + Exp int64 + Iat int64 + } + err := json.Unmarshal(claims, &c) + if err != nil { + return false, err + } + if currentTime.After(time.Unix(c.Exp, 0)) { + return false, errors.New("expired") + } + return true, nil +} + func newJWTMiddlewareOrFatal(t *testing.T) *JWTMiddleware { config := &Config{ Secret: "password", @@ -148,7 +165,7 @@ func TestSecureHandlerNoToken(t *testing.T) { middleware := newJWTMiddlewareOrFatal(t) resp := httptest.NewRecorder() req, _ := http.NewRequest("GET", "http://example.com", nil) - middleware.Secure(testHandler).ServeHTTP(resp, req) + middleware.Secure(testHandler, verifyClaimsFunc).ServeHTTP(resp, req) body := strings.TrimSpace(resp.Body.String()) if body != ErrMissingToken.Error() { t.Errorf("wanted %q, got %q", ErrMissingToken.Error(), body) @@ -160,7 +177,7 @@ func TestSecureHandlerBadToken(t *testing.T) { resp := httptest.NewRecorder() req, _ := http.NewRequest("GET", "http://example.com", nil) req.Header.Set("Authorization", "Bearer abcdefg") - middleware.Secure(testHandler).ServeHTTP(resp, req) + middleware.Secure(testHandler, verifyClaimsFunc).ServeHTTP(resp, req) body := strings.TrimSpace(resp.Body.String()) if body != ErrMalformedToken.Error() { t.Errorf("wanted %q, got %q", ErrMalformedToken.Error(), body) @@ -169,7 +186,7 @@ func TestSecureHandlerBadToken(t *testing.T) { resp = httptest.NewRecorder() req, _ = http.NewRequest("GET", "http://example.com", nil) req.Header.Set("Authorization", "Bearer abcd.abcd.abcd") - middleware.Secure(testHandler).ServeHTTP(resp, req) + middleware.Secure(testHandler, verifyClaimsFunc).ServeHTTP(resp, req) body = strings.TrimSpace(resp.Body.String()) if body != ErrMalformedToken.Error() { t.Errorf("wanted %q, got %q", ErrMalformedToken.Error(), body) @@ -183,7 +200,7 @@ func TestSecureHandlerBadSignature(t *testing.T) { resp := httptest.NewRecorder() req, _ := http.NewRequest("GET", "http://example.com", nil) req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) - middleware.Secure(testHandler).ServeHTTP(resp, req) + middleware.Secure(testHandler, verifyClaimsFunc).ServeHTTP(resp, req) body := strings.TrimSpace(resp.Body.String()) if body != ErrInvalidSignature.Error() { t.Errorf("wanted %s, got %s", ErrInvalidSignature.Error(), body) @@ -195,7 +212,7 @@ func TestSecureHandlerGoodToken(t *testing.T) { resp := httptest.NewRecorder() req, _ := http.NewRequest("GET", "http://example.com", nil) req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) - middleware.Secure(testHandler).ServeHTTP(resp, req) + middleware.Secure(testHandler, verifyClaimsFunc).ServeHTTP(resp, req) body := strings.TrimSpace(resp.Body.String()) if body != "test" { t.Errorf("wanted %s, got %s", "test", body)