diff --git a/jwt.go b/jwt.go index e300664..8d4ffbf 100644 --- a/jwt.go +++ b/jwt.go @@ -6,6 +6,7 @@ import ( "encoding/base64" "encoding/json" "errors" + "log" "net/http" "strings" ) @@ -16,6 +17,9 @@ var ( ErrMissingAuthFunc = errors.New("please provide an auth function") ErrMissingClaimsFunc = errors.New("please provide a claims function") ErrEncoding = errors.New("error encoding value") + ErrMissingToken = errors.New("please provide a token") + ErrMalformedToken = errors.New("please provide a valid token") + ErrDecodingHeader = errors.New("could not decode JOSE header") ) type Config struct { @@ -56,8 +60,35 @@ func NewMiddleware(c *Config) (*JWTMiddleware, error) { } func (m *JWTMiddleware) Secure(h http.Handler) http.Handler { - // This is just a placeholder for now return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + authHeader := r.Header.Get("Authorization") + if authHeader == "" { + http.Error(w, ErrMissingToken.Error(), http.StatusUnauthorized) + return + } + token := strings.Split(authHeader, " ")[1] + if strings.LastIndex(token, ".") == -1 { + http.Error(w, ErrMalformedToken.Error(), http.StatusUnauthorized) + return + } + // Verify JOSE header + var t struct { + Typ string + Alg string + } + tokenParts := strings.Split(token, ".") + header, err := decode(tokenParts[0]) + if err != nil { + log.Printf("error (%v) while decoding header (%v)", err, tokenParts[0]) + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + err = json.Unmarshal(header, &t) + if err != nil { + log.Printf("error (%v) while unmarshalling header (%s)", err, header) + http.Error(w, ErrMalformedToken.Error(), http.StatusInternalServerError) + return + } h.ServeHTTP(w, r) }) } @@ -122,3 +153,7 @@ func encode(s interface{}) (string, error) { } return base64.StdEncoding.EncodeToString(r), nil } + +func decode(s string) ([]byte, error) { + return base64.StdEncoding.DecodeString(s) +} diff --git a/jwt_test.go b/jwt_test.go index 0f97934..dca513f 100644 --- a/jwt_test.go +++ b/jwt_test.go @@ -6,6 +6,7 @@ import ( "crypto/sha256" "encoding/base64" "encoding/json" + "fmt" "io/ioutil" "net/http" "net/http/httptest" @@ -43,6 +44,29 @@ func newJWTMiddlewareOrFatal(t *testing.T) *JWTMiddleware { return middleware } +func newToken(t *testing.T) (string, *JWTMiddleware) { + middleware := newJWTMiddlewareOrFatal(t) + authBody := map[string]interface{}{ + "email": "user@example.com", + "password": "password", + } + body, err := json.Marshal(authBody) + if err != nil { + t.Error(err) + } + + ts := httptest.NewServer(http.HandlerFunc(middleware.GenerateToken)) + defer ts.Close() + + resp, err := http.Post(ts.URL, "application/json", bytes.NewReader(body)) + respBody, err := ioutil.ReadAll(resp.Body) + resp.Body.Close() + if err != nil { + t.Error(err) + } + return string(respBody), middleware +} + func TestNewJWTMiddleware(t *testing.T) { middleware := newJWTMiddlewareOrFatal(t) if middleware.secret != "password" { @@ -88,39 +112,9 @@ func TestNewJWTMiddlewareNoConfig(t *testing.T) { } } } - -func TestSecureHandler(t *testing.T) { - middleware := newJWTMiddlewareOrFatal(t) - resp := httptest.NewRecorder() - req, _ := http.NewRequest("GET", "http://example.com", nil) - middleware.Secure(testHandler).ServeHTTP(resp, req) - if resp.Body.String() != "test" { - t.Errorf("wanted test, got %v", resp.Body.String()) - } -} - func TestGenerateTokenHandler(t *testing.T) { - middleware := newJWTMiddlewareOrFatal(t) - authBody := map[string]interface{}{ - "email": "user@example.com", - "password": "password", - } - body, err := json.Marshal(authBody) - if err != nil { - t.Error(err) - } - - ts := httptest.NewServer(http.HandlerFunc(middleware.GenerateToken)) - defer ts.Close() - - resp, err := http.Post(ts.URL, "application/json", bytes.NewReader(body)) - respBody, err := ioutil.ReadAll(resp.Body) - resp.Body.Close() - if err != nil { - t.Error(err) - } - - j := strings.Split(string(respBody), ".") + token, m := newToken(t) + j := strings.Split(token, ".") header := base64.StdEncoding.EncodeToString([]byte(`{"typ":"JWT","alg":"HS256"}`)) if j[0] != header { @@ -141,7 +135,7 @@ func TestGenerateTokenHandler(t *testing.T) { if duration != d { t.Errorf("wanted %v, got %v", d, duration) } - mac := hmac.New(sha256.New, []byte(middleware.secret)) + mac := hmac.New(sha256.New, []byte(m.secret)) message := []byte(strings.Join([]string{j[0], j[1]}, ".")) mac.Write(message) expectedMac := base64.StdEncoding.EncodeToString(mac.Sum(nil)) @@ -149,3 +143,47 @@ func TestGenerateTokenHandler(t *testing.T) { t.Errorf("wanted %v, got %v", expectedMac, j[2]) } } + +func TestSecureHandlerNoToken(t *testing.T) { + middleware := newJWTMiddlewareOrFatal(t) + resp := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "http://example.com", nil) + middleware.Secure(testHandler).ServeHTTP(resp, req) + body := strings.TrimSpace(resp.Body.String()) + if body != ErrMissingToken.Error() { + t.Errorf("wanted %q, got %q", ErrMissingToken.Error(), body) + } +} + +func TestSecureHandlerBadToken(t *testing.T) { + middleware := newJWTMiddlewareOrFatal(t) + resp := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "http://example.com", nil) + req.Header.Set("Authorization", "Bearer abcdefg") + middleware.Secure(testHandler).ServeHTTP(resp, req) + body := strings.TrimSpace(resp.Body.String()) + if body != ErrMalformedToken.Error() { + t.Errorf("wanted %q, got %q", ErrMalformedToken.Error(), body) + } + + resp = httptest.NewRecorder() + req, _ = http.NewRequest("GET", "http://example.com", nil) + req.Header.Set("Authorization", "Bearer abcd.abcd.abcd") + middleware.Secure(testHandler).ServeHTTP(resp, req) + body = strings.TrimSpace(resp.Body.String()) + if body != ErrMalformedToken.Error() { + t.Errorf("wanted %q, got %q", ErrMalformedToken.Error(), body) + } +} + +func TestSecureHandlerGoodToken(t *testing.T) { + token, middleware := newToken(t) + resp := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "http://example.com", nil) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) + middleware.Secure(testHandler).ServeHTTP(resp, req) + body := strings.TrimSpace(resp.Body.String()) + if body != "test" { + t.Errorf("wanted %s, got %s", "test", body) + } +}