diff --git a/README.md b/README.md index c97712c..226ae87 100644 --- a/README.md +++ b/README.md @@ -70,7 +70,7 @@ func main() { protect := http.HandlerFunc(protectMe) dontProtect := http.HandlerFunc(dontProtectMe) - http.Handle("/authenticate", j.GenerateToken()) + http.Handle("/authenticate", j.Authenticate()) http.Handle("/secure", j.Secure(protect, verifyClaims)) http.Handle("/insecure", dontProtect) http.ListenAndServe(":8080", nil) @@ -137,7 +137,7 @@ Once the middleware is instantiated, create a route for users to generate a JWT at. ```go -http.Handle("/authenticate", j.GenerateToken()) +http.Handle("/authenticate", j.Authenticate()) ``` The auth function takes two arguments (the identity, and the authorization diff --git a/jwt.go b/jwt.go index d37be84..00eb489 100644 --- a/jwt.go +++ b/jwt.go @@ -171,10 +171,10 @@ func (m *Middleware) Secure(h http.Handler, v VerifyClaimsFunc) http.Handler { return errorHandler(secureHandler) } -// GenerateToken returns a middleware that parsing an incoming request for a JWT, +// Authenticate returns a middleware that parsing an incoming request for a JWT, // calls the client-supplied auth function, and if successful, returns a JWT to // the requester. -func (m *Middleware) GenerateToken() http.Handler { +func (m *Middleware) Authenticate() http.Handler { generateHandler := func(w http.ResponseWriter, r *http.Request) *jwtError { if r.Method != "POST" { return &jwtError{ @@ -215,59 +215,14 @@ func (m *Middleware) GenerateToken() http.Handler { message: "performing authorization", } } - - // For now, the header will be static - header, err := encode(fmt.Sprintf(`{"typ":%q,"alg":%q}`, typ, alg)) - if err != nil { - return &jwtError{ - status: http.StatusInternalServerError, - err: ErrEncoding, - message: "encoding header", - } - } - - // Generate claims for user - claims, err := m.claims(b[m.identityField]) + response, err := m.CreateToken(b[m.identityField]) if err != nil { return &jwtError{ status: http.StatusInternalServerError, err: err, - message: "generating claims", + message: response, } } - - claimsJSON, err := json.Marshal(claims) - if err != nil { - return &jwtError{ - status: http.StatusInternalServerError, - err: ErrEncoding, - message: "marshalling claims", - } - } - - claimsSet, err := encode(claimsJSON) - if err != nil { - return &jwtError{ - status: http.StatusInternalServerError, - err: ErrEncoding, - message: "encoding claims", - } - } - - toSig := strings.Join([]string{header, claimsSet}, ".") - - h := hmac.New(sha256.New, []byte(m.secret)) - h.Write([]byte(toSig)) - sig, err := encode(h.Sum(nil)) - if err != nil { - return &jwtError{ - status: http.StatusInternalServerError, - err: ErrEncoding, - message: "encoding signature", - } - } - - response := strings.Join([]string{toSig, sig}, ".") w.Write([]byte(response)) return nil } @@ -275,6 +230,43 @@ func (m *Middleware) GenerateToken() http.Handler { return errorHandler(generateHandler) } +// CreateToken generates a token from a user's identity +func (m *Middleware) CreateToken(identity string) (string, error) { + // For now, the header will be static + header, err := encode(fmt.Sprintf(`{"typ":%q,"alg":%q}`, typ, alg)) + if err != nil { + return "encoding header", ErrEncoding + } + + // Generate claims for user + claims, err := m.claims(identity) + if err != nil { + return "generating claims", err + } + + claimsJSON, err := json.Marshal(claims) + if err != nil { + return "mashalling claims", ErrEncoding + } + + claimsSet, err := encode(claimsJSON) + if err != nil { + return "encoding claims", ErrEncoding + } + + toSig := strings.Join([]string{header, claimsSet}, ".") + + h := hmac.New(sha256.New, []byte(m.secret)) + h.Write([]byte(toSig)) + sig, err := encode(h.Sum(nil)) + if err != nil { + return "encoding signature", ErrEncoding + } + + response := strings.Join([]string{toSig, sig}, ".") + return response, nil +} + type jwtError struct { status int err error diff --git a/jwt_test.go b/jwt_test.go index c5579b8..44ec5d3 100644 --- a/jwt_test.go +++ b/jwt_test.go @@ -72,7 +72,7 @@ func newToken(t *testing.T) (string, *Middleware) { t.Error(err) } - ts := httptest.NewServer(middleware.GenerateToken()) + ts := httptest.NewServer(middleware.Authenticate()) defer ts.Close() resp, err := http.Post(ts.URL, "application/json", bytes.NewReader(body)) @@ -226,7 +226,7 @@ func TestGenerateTokenHandlerNotPOST(t *testing.T) { middleware := newMiddlewareOrFatal(t) resp := httptest.NewRecorder() req, _ := http.NewRequest("PUT", "http://example.com", nil) - middleware.GenerateToken().ServeHTTP(resp, req) + middleware.Authenticate().ServeHTTP(resp, req) body := strings.TrimSpace(resp.Body.String()) if body != ErrInvalidMethod.Error() { t.Errorf("wanted %q, got %q", ErrInvalidMethod.Error(), body)