Roughed in claims

This commit is contained in:
Matthew Dillon 2015-04-18 11:38:44 -08:00
parent c102229487
commit 5d9f1a3b5f
2 changed files with 108 additions and 15 deletions

55
jwt.go
View file

@ -5,24 +5,31 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"net/http" "net/http"
"strings"
) )
var ( var (
ErrMissingConfig = errors.New("missing configuration") ErrMissingConfig = errors.New("missing configuration")
ErrMissingSecret = errors.New("please provide a shared secret") ErrMissingSecret = errors.New("please provide a shared secret")
ErrMissingAuthFunc = errors.New("please provide an auth function") ErrMissingAuthFunc = errors.New("please provide an auth function")
ErrMissingClaimsFunc = errors.New("please provide a claims function")
ErrEncoding = errors.New("error encoding value")
) )
type Config struct { type Config struct {
Secret string Secret string
Auth AuthFunc Auth AuthFunc
Claims ClaimsFunc
} }
type AuthFunc func(string, string) (bool, error) type AuthFunc func(string, string) (bool, error)
type ClaimsFunc func(id string) (map[string]interface{}, error)
type JWTMiddleware struct { type JWTMiddleware struct {
secret string secret string
auth AuthFunc auth AuthFunc
claims ClaimsFunc
} }
func NewMiddleware(c *Config) (*JWTMiddleware, error) { func NewMiddleware(c *Config) (*JWTMiddleware, error) {
@ -35,9 +42,13 @@ func NewMiddleware(c *Config) (*JWTMiddleware, error) {
if c.Auth == nil { if c.Auth == nil {
return nil, ErrMissingAuthFunc return nil, ErrMissingAuthFunc
} }
if c.Claims == nil {
return nil, ErrMissingClaimsFunc
}
m := &JWTMiddleware{ m := &JWTMiddleware{
secret: c.Secret, secret: c.Secret,
auth: c.Auth, auth: c.Auth,
claims: c.Claims,
} }
return m, nil return m, nil
} }
@ -64,8 +75,40 @@ func (m *JWTMiddleware) GenerateToken(w http.ResponseWriter, r *http.Request) {
} }
// For now, the header will be static // For now, the header will be static
resp := `{"typ":"JWT","alg":"HS256"}` header, err := encode(`{"typ":"JWT","alg":"HS256"}`)
resp = base64.StdEncoding.EncodeToString([]byte(resp)) if err != nil {
panic(err)
}
w.Write([]byte(resp)) claims, err := m.claims(b["email"])
if err != nil {
panic(err)
}
claimsJson, err := json.Marshal(claims)
if err != nil {
panic(err)
}
claimsSet, err := encode(claimsJson)
if err != nil {
panic(err)
}
response := strings.Join([]string{header, claimsSet}, ".")
w.Write([]byte(response))
}
func encode(s interface{}) (string, error) {
var r []byte
switch v := s.(type) {
case string:
r = []byte(v)
case []byte:
r = v
default:
return "", ErrEncoding
}
return base64.StdEncoding.EncodeToString(r), nil
} }

View file

@ -2,11 +2,14 @@ package jwt
import ( import (
"bytes" "bytes"
"encoding/base64"
"encoding/json" "encoding/json"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"strings"
"testing" "testing"
"time"
) )
var testHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { var testHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@ -17,10 +20,19 @@ var authFunc = func(email, password string) (bool, error) {
return true, nil return true, nil
} }
var claimsFunc = func(id string) (map[string]interface{}, error) {
currentTime := time.Now()
return map[string]interface{}{
"iat": currentTime.Unix(),
"exp": currentTime.Add(time.Minute * 60 * 24).Unix(),
}, nil
}
func newJWTMiddlewareOrFatal(t *testing.T) *JWTMiddleware { func newJWTMiddlewareOrFatal(t *testing.T) *JWTMiddleware {
config := &Config{ config := &Config{
Secret: "password", Secret: "password",
Auth: authFunc, Auth: authFunc,
Claims: claimsFunc,
} }
middleware, err := NewMiddleware(config) middleware, err := NewMiddleware(config)
if err != nil { if err != nil {
@ -34,21 +46,38 @@ func TestNewJWTMiddleware(t *testing.T) {
if middleware.secret != "password" { if middleware.secret != "password" {
t.Errorf("wanted password, got %v", middleware.secret) t.Errorf("wanted password, got %v", middleware.secret)
} }
val, err := middleware.auth("", "") authVal, err := middleware.auth("", "")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if val != true { if authVal != true {
t.Errorf("wanted true, got %v", val) t.Errorf("wanted true, got %v", authVal)
}
claimsVal, err := middleware.claims("1")
if err != nil {
t.Fatal(err)
}
if _, ok := claimsVal["iat"]; !ok {
t.Errorf("wanted a claims set, got %v", claimsVal)
} }
} }
func TestNewJWTMiddlewareNoConfig(t *testing.T) { func TestNewJWTMiddlewareNoConfig(t *testing.T) {
cases := map[*Config]error{ cases := map[*Config]error{
nil: ErrMissingConfig, nil: ErrMissingConfig,
&Config{}: ErrMissingSecret, &Config{}: ErrMissingSecret,
&Config{Auth: authFunc}: ErrMissingSecret, &Config{
&Config{Secret: "secret"}: ErrMissingAuthFunc, Auth: authFunc,
Claims: claimsFunc,
}: ErrMissingSecret,
&Config{
Secret: "secret",
Claims: claimsFunc,
}: ErrMissingAuthFunc,
&Config{
Auth: authFunc,
Secret: "secret",
}: ErrMissingClaimsFunc,
} }
for config, jwtErr := range cases { for config, jwtErr := range cases {
_, err := NewMiddleware(config) _, err := NewMiddleware(config)
@ -78,15 +107,36 @@ func TestGenerateTokenHandler(t *testing.T) {
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
ts := httptest.NewServer(http.HandlerFunc(middleware.GenerateToken)) ts := httptest.NewServer(http.HandlerFunc(middleware.GenerateToken))
defer ts.Close() defer ts.Close()
resp, err := http.Post(ts.URL, "application/json", bytes.NewReader(body)) resp, err := http.Post(ts.URL, "application/json", bytes.NewReader(body))
respBody, err := ioutil.ReadAll(resp.Body) respBody, err := ioutil.ReadAll(resp.Body)
resp.Body.Close() resp.Body.Close()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
if string(respBody) != "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9" {
t.Errorf("wanted eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9, got %v", string(respBody)) j := strings.Split(string(respBody), ".")
header := base64.StdEncoding.EncodeToString([]byte(`{"typ":"JWT","alg":"HS256"}`))
if j[0] != header {
t.Errorf("wanted %v, got %v", header, j[0])
}
claims, err := base64.StdEncoding.DecodeString(j[1])
var c struct {
Exp int
Iat int
}
err = json.Unmarshal(claims, &c)
if err != nil {
t.Error(err)
}
duration := time.Duration(c.Exp-c.Iat) * time.Second
d := time.Minute * 60 * 24
if duration != d {
t.Errorf("wanted %v, got %v", d, duration)
} }
} }