diff --git a/examples/net-http.go b/examples/net-http.go index cf74d0a..796630d 100644 --- a/examples/net-http.go +++ b/examples/net-http.go @@ -45,7 +45,7 @@ func main() { panic(err) } protect := http.HandlerFunc(protectMe) - http.HandleFunc("/authenticate", j.GenerateToken) + http.Handle("/authenticate", j.GenerateToken()) http.Handle("/secure", j.Secure(protect, verifyClaimsFunc)) http.ListenAndServe(":8080", nil) } diff --git a/jwt.go b/jwt.go index 848a430..5410cf2 100644 --- a/jwt.go +++ b/jwt.go @@ -162,64 +162,83 @@ func (m *JWTMiddleware) Secure(h http.Handler, v VerifyClaimsFunc) http.Handler return errorHandler(secureHandler) } -func (m *JWTMiddleware) GenerateToken(w http.ResponseWriter, r *http.Request) { - var b map[string]string - err := json.NewDecoder(r.Body).Decode(&b) - if err != nil { - log.Printf("error (%v) while parsing authorization", err) - http.Error(w, ErrParsingCredentials.Error(), http.StatusInternalServerError) - return - } - err = m.auth(b["email"], b["password"]) - if err != nil { - log.Printf("error (%v) while performing authorization", err) - http.Error(w, err.Error(), http.StatusInternalServerError) - return +func (m *JWTMiddleware) GenerateToken() http.Handler { + generateHandler := func(w http.ResponseWriter, r *http.Request) *jwtError { + var b map[string]string + err := json.NewDecoder(r.Body).Decode(&b) + if err != nil { + return &jwtError{ + status: http.StatusInternalServerError, + err: ErrParsingCredentials, + message: "parsing authorization", + } + } + err = m.auth(b["email"], b["password"]) + if err != nil { + return &jwtError{ + status: http.StatusInternalServerError, + err: err, + message: "performing authorization", + } + } + + // For now, the header will be static + header, err := encode(fmt.Sprintf(`{"typ":%q,"alg":%q}`, typ, alg)) + if err != nil { + return &jwtError{ + status: http.StatusInternalServerError, + err: ErrEncoding, + message: "encoding header", + } + } + + // Generate claims for user + claims, err := m.claims(b["email"]) + if err != nil { + return &jwtError{ + status: http.StatusInternalServerError, + err: err, + message: "generating claims", + } + } + + claimsJson, err := json.Marshal(claims) + if err != nil { + return &jwtError{ + status: http.StatusInternalServerError, + err: ErrEncoding, + message: "marshalling claims", + } + } + + claimsSet, err := encode(claimsJson) + if err != nil { + return &jwtError{ + status: http.StatusInternalServerError, + err: ErrEncoding, + message: "encoding claims", + } + } + + toSig := strings.Join([]string{header, claimsSet}, ".") + + h := hmac.New(sha256.New, []byte(m.secret)) + h.Write([]byte(toSig)) + sig, err := encode(h.Sum(nil)) + if err != nil { + return &jwtError{ + status: http.StatusInternalServerError, + err: ErrEncoding, + message: "encoding signature", + } + } + + response := strings.Join([]string{toSig, sig}, ".") + w.Write([]byte(response)) + return nil } - // For now, the header will be static - header, err := encode(fmt.Sprintf(`{"typ":%q,"alg":%q}`, typ, alg)) - if err != nil { - log.Printf("error (%v) while encoding header", err) - http.Error(w, ErrEncoding.Error(), http.StatusInternalServerError) - return - } - - // Generate claims for user - claims, err := m.claims(b["email"]) - if err != nil { - log.Printf("error (%v) while generating claims", err) - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - - claimsJson, err := json.Marshal(claims) - if err != nil { - log.Printf("error (%v) while marshalling claims") - http.Error(w, ErrEncoding.Error(), http.StatusInternalServerError) - return - } - - claimsSet, err := encode(claimsJson) - if err != nil { - log.Printf("error (%v) while encoding claims") - http.Error(w, ErrEncoding.Error(), http.StatusInternalServerError) - return - } - - toSig := strings.Join([]string{header, claimsSet}, ".") - - h := hmac.New(sha256.New, []byte(m.secret)) - h.Write([]byte(toSig)) - sig, err := encode(h.Sum(nil)) - if err != nil { - log.Printf("error (%v) while encoding signature") - http.Error(w, ErrEncoding.Error(), http.StatusInternalServerError) - return - } - - response := strings.Join([]string{toSig, sig}, ".") - w.Write([]byte(response)) + return errorHandler(generateHandler) } func encode(s interface{}) (string, error) { diff --git a/jwt_test.go b/jwt_test.go index d71adf4..dafa1c1 100644 --- a/jwt_test.go +++ b/jwt_test.go @@ -72,7 +72,7 @@ func newToken(t *testing.T) (string, *JWTMiddleware) { t.Error(err) } - ts := httptest.NewServer(http.HandlerFunc(middleware.GenerateToken)) + ts := httptest.NewServer(middleware.GenerateToken()) defer ts.Close() resp, err := http.Post(ts.URL, "application/json", bytes.NewReader(body))