diff --git a/jwt.go b/jwt.go index 8d4ffbf..4a0f2d3 100644 --- a/jwt.go +++ b/jwt.go @@ -20,6 +20,7 @@ var ( ErrMissingToken = errors.New("please provide a token") ErrMalformedToken = errors.New("please provide a valid token") ErrDecodingHeader = errors.New("could not decode JOSE header") + ErrInvalidSignature = errors.New("signature could not be verified") ) type Config struct { @@ -71,12 +72,13 @@ func (m *JWTMiddleware) Secure(h http.Handler) http.Handler { http.Error(w, ErrMalformedToken.Error(), http.StatusUnauthorized) return } + tokenParts := strings.Split(token, ".") + // Verify JOSE header var t struct { Typ string Alg string } - tokenParts := strings.Split(token, ".") header, err := decode(tokenParts[0]) if err != nil { log.Printf("error (%v) while decoding header (%v)", err, tokenParts[0]) @@ -89,6 +91,21 @@ func (m *JWTMiddleware) Secure(h http.Handler) http.Handler { http.Error(w, ErrMalformedToken.Error(), http.StatusInternalServerError) return } + + // Verify signature + mac := hmac.New(sha256.New, []byte(m.secret)) + message := []byte(strings.Join([]string{tokenParts[0], tokenParts[1]}, ".")) + mac.Write(message) + expectedMac, err := encode(mac.Sum(nil)) + if err != nil { + panic(err) + return + } + if !hmac.Equal([]byte(tokenParts[2]), []byte(expectedMac)) { + log.Printf("invalid signature: %v", tokenParts[2]) + http.Error(w, ErrInvalidSignature.Error(), http.StatusUnauthorized) + return + } h.ServeHTTP(w, r) }) } diff --git a/jwt_test.go b/jwt_test.go index dca513f..205cbdd 100644 --- a/jwt_test.go +++ b/jwt_test.go @@ -176,6 +176,20 @@ func TestSecureHandlerBadToken(t *testing.T) { } } +func TestSecureHandlerBadSignature(t *testing.T) { + token, middleware := newToken(t) + parts := strings.Split(token, ".") + token = strings.Join([]string{parts[0], parts[1], "abcd"}, ".") + resp := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "http://example.com", nil) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) + middleware.Secure(testHandler).ServeHTTP(resp, req) + body := strings.TrimSpace(resp.Body.String()) + if body != ErrInvalidSignature.Error() { + t.Errorf("wanted %s, got %s", ErrInvalidSignature.Error(), body) + } +} + func TestSecureHandlerGoodToken(t *testing.T) { token, middleware := newToken(t) resp := httptest.NewRecorder()