diff --git a/jwt.go b/jwt.go index 80c0707..a79d0f5 100644 --- a/jwt.go +++ b/jwt.go @@ -107,7 +107,11 @@ func (m *Middleware) Secure(h http.Handler, v VerifyClaimsFunc) http.Handler { return &jwtError{status: http.StatusUnauthorized, err: ErrMissingToken} } } else { - token = strings.Split(authHeader, " ")[1] + token_parts := strings.Split(authHeader, " ") + if len(token_parts) != 2 { + return &jwtError{status: http.StatusUnauthorized, err: ErrMalformedToken} + } + token = token_parts[1] } if status, err, message := m.VerifyToken(token, v, r); err != nil { diff --git a/jwt_test.go b/jwt_test.go index 44ec5d3..01db635 100644 --- a/jwt_test.go +++ b/jwt_test.go @@ -232,3 +232,16 @@ func TestGenerateTokenHandlerNotPOST(t *testing.T) { t.Errorf("wanted %q, got %q", ErrInvalidMethod.Error(), body) } } + +func TestMalformedAuthorizationHeader(t *testing.T) { + _, middleware := newToken(t) + token := "hello!" + resp := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "http://example.com", nil) + req.Header.Set("Authorization", token) // No "Bearer " portion of header + middleware.Secure(testHandler, verifyClaimsFunc).ServeHTTP(resp, req) + body := strings.TrimSpace(resp.Body.String()) + if body != ErrMalformedToken.Error() { + t.Errorf("wanted %q, got %q", ErrMalformedToken.Error(), body) + } +}