diff --git a/jwt.go b/jwt.go index 62b77d3..52a066b 100644 --- a/jwt.go +++ b/jwt.go @@ -5,24 +5,31 @@ import ( "encoding/json" "errors" "net/http" + "strings" ) var ( - ErrMissingConfig = errors.New("missing configuration") - ErrMissingSecret = errors.New("please provide a shared secret") - ErrMissingAuthFunc = errors.New("please provide an auth function") + ErrMissingConfig = errors.New("missing configuration") + ErrMissingSecret = errors.New("please provide a shared secret") + ErrMissingAuthFunc = errors.New("please provide an auth function") + ErrMissingClaimsFunc = errors.New("please provide a claims function") + ErrEncoding = errors.New("error encoding value") ) type Config struct { Secret string Auth AuthFunc + Claims ClaimsFunc } type AuthFunc func(string, string) (bool, error) +type ClaimsFunc func(id string) (map[string]interface{}, error) + type JWTMiddleware struct { secret string auth AuthFunc + claims ClaimsFunc } func NewMiddleware(c *Config) (*JWTMiddleware, error) { @@ -35,9 +42,13 @@ func NewMiddleware(c *Config) (*JWTMiddleware, error) { if c.Auth == nil { return nil, ErrMissingAuthFunc } + if c.Claims == nil { + return nil, ErrMissingClaimsFunc + } m := &JWTMiddleware{ secret: c.Secret, auth: c.Auth, + claims: c.Claims, } return m, nil } @@ -64,8 +75,40 @@ func (m *JWTMiddleware) GenerateToken(w http.ResponseWriter, r *http.Request) { } // For now, the header will be static - resp := `{"typ":"JWT","alg":"HS256"}` - resp = base64.StdEncoding.EncodeToString([]byte(resp)) + header, err := encode(`{"typ":"JWT","alg":"HS256"}`) + if err != nil { + panic(err) + } - w.Write([]byte(resp)) + claims, err := m.claims(b["email"]) + if err != nil { + panic(err) + } + + claimsJson, err := json.Marshal(claims) + if err != nil { + panic(err) + } + + claimsSet, err := encode(claimsJson) + if err != nil { + panic(err) + } + + response := strings.Join([]string{header, claimsSet}, ".") + + w.Write([]byte(response)) +} + +func encode(s interface{}) (string, error) { + var r []byte + switch v := s.(type) { + case string: + r = []byte(v) + case []byte: + r = v + default: + return "", ErrEncoding + } + return base64.StdEncoding.EncodeToString(r), nil } diff --git a/jwt_test.go b/jwt_test.go index 6381ffe..8b14044 100644 --- a/jwt_test.go +++ b/jwt_test.go @@ -2,11 +2,14 @@ package jwt import ( "bytes" + "encoding/base64" "encoding/json" "io/ioutil" "net/http" "net/http/httptest" + "strings" "testing" + "time" ) var testHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -17,10 +20,19 @@ var authFunc = func(email, password string) (bool, error) { return true, nil } +var claimsFunc = func(id string) (map[string]interface{}, error) { + currentTime := time.Now() + return map[string]interface{}{ + "iat": currentTime.Unix(), + "exp": currentTime.Add(time.Minute * 60 * 24).Unix(), + }, nil +} + func newJWTMiddlewareOrFatal(t *testing.T) *JWTMiddleware { config := &Config{ Secret: "password", Auth: authFunc, + Claims: claimsFunc, } middleware, err := NewMiddleware(config) if err != nil { @@ -34,21 +46,38 @@ func TestNewJWTMiddleware(t *testing.T) { if middleware.secret != "password" { t.Errorf("wanted password, got %v", middleware.secret) } - val, err := middleware.auth("", "") + authVal, err := middleware.auth("", "") if err != nil { t.Fatal(err) } - if val != true { - t.Errorf("wanted true, got %v", val) + if authVal != true { + t.Errorf("wanted true, got %v", authVal) + } + claimsVal, err := middleware.claims("1") + if err != nil { + t.Fatal(err) + } + if _, ok := claimsVal["iat"]; !ok { + t.Errorf("wanted a claims set, got %v", claimsVal) } } func TestNewJWTMiddlewareNoConfig(t *testing.T) { cases := map[*Config]error{ - nil: ErrMissingConfig, - &Config{}: ErrMissingSecret, - &Config{Auth: authFunc}: ErrMissingSecret, - &Config{Secret: "secret"}: ErrMissingAuthFunc, + nil: ErrMissingConfig, + &Config{}: ErrMissingSecret, + &Config{ + Auth: authFunc, + Claims: claimsFunc, + }: ErrMissingSecret, + &Config{ + Secret: "secret", + Claims: claimsFunc, + }: ErrMissingAuthFunc, + &Config{ + Auth: authFunc, + Secret: "secret", + }: ErrMissingClaimsFunc, } for config, jwtErr := range cases { _, err := NewMiddleware(config) @@ -78,15 +107,36 @@ func TestGenerateTokenHandler(t *testing.T) { if err != nil { t.Error(err) } + ts := httptest.NewServer(http.HandlerFunc(middleware.GenerateToken)) defer ts.Close() + resp, err := http.Post(ts.URL, "application/json", bytes.NewReader(body)) respBody, err := ioutil.ReadAll(resp.Body) resp.Body.Close() if err != nil { t.Error(err) } - if string(respBody) != "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9" { - t.Errorf("wanted eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9, got %v", string(respBody)) + + j := strings.Split(string(respBody), ".") + + header := base64.StdEncoding.EncodeToString([]byte(`{"typ":"JWT","alg":"HS256"}`)) + if j[0] != header { + t.Errorf("wanted %v, got %v", header, j[0]) + } + + claims, err := base64.StdEncoding.DecodeString(j[1]) + var c struct { + Exp int + Iat int + } + err = json.Unmarshal(claims, &c) + if err != nil { + t.Error(err) + } + duration := time.Duration(c.Exp-c.Iat) * time.Second + d := time.Minute * 60 * 24 + if duration != d { + t.Errorf("wanted %v, got %v", d, duration) } }