diff --git a/jwt.go b/jwt.go index 8366378..80381e3 100644 --- a/jwt.go +++ b/jwt.go @@ -33,13 +33,6 @@ var ( ErrParsingCredentials = errors.New("error parsing credentials") ) -// Config is a container for setting up the JWT middleware. -type Config struct { - Secret string - Auth AuthFunc - Claims ClaimsFunc -} - // AuthFunc is a type for delegating user authentication to the client-code. type AuthFunc func(string, string) error @@ -50,12 +43,23 @@ type ClaimsFunc func(string) (map[string]interface{}, error) // or more route's in the client-code. type VerifyClaimsFunc func([]byte) error +// Config is a container for setting up the JWT middleware. +type Config struct { + Secret string + Auth AuthFunc + Claims ClaimsFunc + IdentityField string + VerifyField string +} + // Middleware is where we store all the specifics related to the client's // JWT needs. type Middleware struct { - secret string - auth AuthFunc - claims ClaimsFunc + secret string + auth AuthFunc + claims ClaimsFunc + identityField string + verifyField string } // New creates a new Middleware from a user-specified configuration. @@ -72,10 +76,18 @@ func New(c *Config) (*Middleware, error) { if c.Claims == nil { return nil, ErrMissingClaimsFunc } + if c.IdentityField == "" { + c.IdentityField = "email" + } + if c.VerifyField == "" { + c.VerifyField = "password" + } m := &Middleware{ - secret: c.Secret, - auth: c.Auth, - claims: c.Claims, + secret: c.Secret, + auth: c.Auth, + claims: c.Claims, + identityField: c.IdentityField, + verifyField: c.VerifyField, } return m, nil } @@ -172,7 +184,22 @@ func (m *Middleware) GenerateToken() http.Handler { message: "parsing authorization", } } - err = m.auth(b["email"], b["password"]) + // Check if required fields are in the body + if _, ok := b[m.identityField]; !ok { + return &jwtError{ + status: http.StatusBadRequest, + err: ErrParsingCredentials, + message: "parsing credentials, missing identity field", + } + } + if _, ok := b[m.verifyField]; !ok { + return &jwtError{ + status: http.StatusBadRequest, + err: ErrParsingCredentials, + message: "parsing credentials, missing verify field", + } + } + err = m.auth(b[m.identityField], b[m.verifyField]) if err != nil { return &jwtError{ status: http.StatusInternalServerError, @@ -192,7 +219,7 @@ func (m *Middleware) GenerateToken() http.Handler { } // Generate claims for user - claims, err := m.claims(b["email"]) + claims, err := m.claims(m.identityField) if err != nil { return &jwtError{ status: http.StatusInternalServerError, diff --git a/jwt_test.go b/jwt_test.go index c2ad4ba..ecdd39c 100644 --- a/jwt_test.go +++ b/jwt_test.go @@ -100,6 +100,12 @@ func TestNewJWTMiddleware(t *testing.T) { if _, ok := claimsVal["iat"]; !ok { t.Errorf("wanted a claims set, got %v", claimsVal) } + if middleware.identityField != "email" { + t.Errorf("wanted email, got %v", middleware.identityField) + } + if middleware.verifyField != "password" { + t.Errorf("wanted password, got %v", middleware.verifyField) + } } func TestNewJWTMiddlewareNoConfig(t *testing.T) {