diff --git a/jwt.go b/jwt.go index 0575a44..cd5f2ec 100644 --- a/jwt.go +++ b/jwt.go @@ -6,6 +6,12 @@ import ( "net/http" ) +var ( + ErrMissingConfig = errors.New("missing configuration") + ErrMissingSecret = errors.New("please provide a shared secret") + ErrMissingAuthFunc = errors.New("please provide an auth function") +) + type Config struct { Secret string Auth AuthFunc @@ -14,14 +20,24 @@ type Config struct { type AuthFunc func(string, string) (bool, error) type JWTMiddleware struct { - config Config + secret string + auth AuthFunc } func NewMiddleware(c *Config) (*JWTMiddleware, error) { if c == nil { - return nil, errors.New("missing configuration") + return nil, ErrMissingConfig + } + if c.Secret == "" { + return nil, ErrMissingSecret + } + if c.Auth == nil { + return nil, ErrMissingAuthFunc + } + m := &JWTMiddleware{ + secret: c.Secret, + auth: c.Auth, } - m := &JWTMiddleware{config: *c} return m, nil } @@ -37,7 +53,7 @@ func (m *JWTMiddleware) GenerateToken(w http.ResponseWriter, r *http.Request) { if err != nil { panic(err) } - result, err := m.config.Auth(b["email"], b["password"]) + result, err := m.auth(b["email"], b["password"]) if err != nil { panic(err) } diff --git a/jwt_test.go b/jwt_test.go index 29facdd..be1aafc 100644 --- a/jwt_test.go +++ b/jwt_test.go @@ -31,16 +31,24 @@ func newJWTMiddlewareOrFatal(t *testing.T) *JWTMiddleware { func TestNewJWTMiddleware(t *testing.T) { middleware := newJWTMiddlewareOrFatal(t) - if middleware.config.Secret != "password" { - t.Errorf("expected 'password', got %v", middleware.config.Secret) + if middleware.secret != "password" { + t.Errorf("expected 'password', got %v", middleware.secret) } // TODO: test auth func init } func TestNewJWTMiddlewareNoConfig(t *testing.T) { - _, err := NewMiddleware(nil) - if err == nil { - t.Error("expected configuration error, received none") + cases := map[*Config]error{ + nil: ErrMissingConfig, + &Config{}: ErrMissingSecret, + &Config{Auth: authFunc}: ErrMissingSecret, + &Config{Secret: "secret"}: ErrMissingAuthFunc, + } + for config, jwtErr := range cases { + _, err := NewMiddleware(config) + if err != jwtErr { + t.Errorf("wanted error: %v, got error: %v using config: %v", jwtErr, err, config) + } } }