Roughed in claims
This commit is contained in:
parent
c102229487
commit
5d9f1a3b5f
2 changed files with 108 additions and 15 deletions
55
jwt.go
55
jwt.go
|
@ -5,24 +5,31 @@ import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ErrMissingConfig = errors.New("missing configuration")
|
ErrMissingConfig = errors.New("missing configuration")
|
||||||
ErrMissingSecret = errors.New("please provide a shared secret")
|
ErrMissingSecret = errors.New("please provide a shared secret")
|
||||||
ErrMissingAuthFunc = errors.New("please provide an auth function")
|
ErrMissingAuthFunc = errors.New("please provide an auth function")
|
||||||
|
ErrMissingClaimsFunc = errors.New("please provide a claims function")
|
||||||
|
ErrEncoding = errors.New("error encoding value")
|
||||||
)
|
)
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
Secret string
|
Secret string
|
||||||
Auth AuthFunc
|
Auth AuthFunc
|
||||||
|
Claims ClaimsFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
type AuthFunc func(string, string) (bool, error)
|
type AuthFunc func(string, string) (bool, error)
|
||||||
|
|
||||||
|
type ClaimsFunc func(id string) (map[string]interface{}, error)
|
||||||
|
|
||||||
type JWTMiddleware struct {
|
type JWTMiddleware struct {
|
||||||
secret string
|
secret string
|
||||||
auth AuthFunc
|
auth AuthFunc
|
||||||
|
claims ClaimsFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewMiddleware(c *Config) (*JWTMiddleware, error) {
|
func NewMiddleware(c *Config) (*JWTMiddleware, error) {
|
||||||
|
@ -35,9 +42,13 @@ func NewMiddleware(c *Config) (*JWTMiddleware, error) {
|
||||||
if c.Auth == nil {
|
if c.Auth == nil {
|
||||||
return nil, ErrMissingAuthFunc
|
return nil, ErrMissingAuthFunc
|
||||||
}
|
}
|
||||||
|
if c.Claims == nil {
|
||||||
|
return nil, ErrMissingClaimsFunc
|
||||||
|
}
|
||||||
m := &JWTMiddleware{
|
m := &JWTMiddleware{
|
||||||
secret: c.Secret,
|
secret: c.Secret,
|
||||||
auth: c.Auth,
|
auth: c.Auth,
|
||||||
|
claims: c.Claims,
|
||||||
}
|
}
|
||||||
return m, nil
|
return m, nil
|
||||||
}
|
}
|
||||||
|
@ -64,8 +75,40 @@ func (m *JWTMiddleware) GenerateToken(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// For now, the header will be static
|
// For now, the header will be static
|
||||||
resp := `{"typ":"JWT","alg":"HS256"}`
|
header, err := encode(`{"typ":"JWT","alg":"HS256"}`)
|
||||||
resp = base64.StdEncoding.EncodeToString([]byte(resp))
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
w.Write([]byte(resp))
|
claims, err := m.claims(b["email"])
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
claimsJson, err := json.Marshal(claims)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
claimsSet, err := encode(claimsJson)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
response := strings.Join([]string{header, claimsSet}, ".")
|
||||||
|
|
||||||
|
w.Write([]byte(response))
|
||||||
|
}
|
||||||
|
|
||||||
|
func encode(s interface{}) (string, error) {
|
||||||
|
var r []byte
|
||||||
|
switch v := s.(type) {
|
||||||
|
case string:
|
||||||
|
r = []byte(v)
|
||||||
|
case []byte:
|
||||||
|
r = v
|
||||||
|
default:
|
||||||
|
return "", ErrEncoding
|
||||||
|
}
|
||||||
|
return base64.StdEncoding.EncodeToString(r), nil
|
||||||
}
|
}
|
||||||
|
|
68
jwt_test.go
68
jwt_test.go
|
@ -2,11 +2,14 @@ package jwt
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
var testHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
var testHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
@ -17,10 +20,19 @@ var authFunc = func(email, password string) (bool, error) {
|
||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var claimsFunc = func(id string) (map[string]interface{}, error) {
|
||||||
|
currentTime := time.Now()
|
||||||
|
return map[string]interface{}{
|
||||||
|
"iat": currentTime.Unix(),
|
||||||
|
"exp": currentTime.Add(time.Minute * 60 * 24).Unix(),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
func newJWTMiddlewareOrFatal(t *testing.T) *JWTMiddleware {
|
func newJWTMiddlewareOrFatal(t *testing.T) *JWTMiddleware {
|
||||||
config := &Config{
|
config := &Config{
|
||||||
Secret: "password",
|
Secret: "password",
|
||||||
Auth: authFunc,
|
Auth: authFunc,
|
||||||
|
Claims: claimsFunc,
|
||||||
}
|
}
|
||||||
middleware, err := NewMiddleware(config)
|
middleware, err := NewMiddleware(config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -34,21 +46,38 @@ func TestNewJWTMiddleware(t *testing.T) {
|
||||||
if middleware.secret != "password" {
|
if middleware.secret != "password" {
|
||||||
t.Errorf("wanted password, got %v", middleware.secret)
|
t.Errorf("wanted password, got %v", middleware.secret)
|
||||||
}
|
}
|
||||||
val, err := middleware.auth("", "")
|
authVal, err := middleware.auth("", "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if val != true {
|
if authVal != true {
|
||||||
t.Errorf("wanted true, got %v", val)
|
t.Errorf("wanted true, got %v", authVal)
|
||||||
|
}
|
||||||
|
claimsVal, err := middleware.claims("1")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if _, ok := claimsVal["iat"]; !ok {
|
||||||
|
t.Errorf("wanted a claims set, got %v", claimsVal)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNewJWTMiddlewareNoConfig(t *testing.T) {
|
func TestNewJWTMiddlewareNoConfig(t *testing.T) {
|
||||||
cases := map[*Config]error{
|
cases := map[*Config]error{
|
||||||
nil: ErrMissingConfig,
|
nil: ErrMissingConfig,
|
||||||
&Config{}: ErrMissingSecret,
|
&Config{}: ErrMissingSecret,
|
||||||
&Config{Auth: authFunc}: ErrMissingSecret,
|
&Config{
|
||||||
&Config{Secret: "secret"}: ErrMissingAuthFunc,
|
Auth: authFunc,
|
||||||
|
Claims: claimsFunc,
|
||||||
|
}: ErrMissingSecret,
|
||||||
|
&Config{
|
||||||
|
Secret: "secret",
|
||||||
|
Claims: claimsFunc,
|
||||||
|
}: ErrMissingAuthFunc,
|
||||||
|
&Config{
|
||||||
|
Auth: authFunc,
|
||||||
|
Secret: "secret",
|
||||||
|
}: ErrMissingClaimsFunc,
|
||||||
}
|
}
|
||||||
for config, jwtErr := range cases {
|
for config, jwtErr := range cases {
|
||||||
_, err := NewMiddleware(config)
|
_, err := NewMiddleware(config)
|
||||||
|
@ -78,15 +107,36 @@ func TestGenerateTokenHandler(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
ts := httptest.NewServer(http.HandlerFunc(middleware.GenerateToken))
|
ts := httptest.NewServer(http.HandlerFunc(middleware.GenerateToken))
|
||||||
defer ts.Close()
|
defer ts.Close()
|
||||||
|
|
||||||
resp, err := http.Post(ts.URL, "application/json", bytes.NewReader(body))
|
resp, err := http.Post(ts.URL, "application/json", bytes.NewReader(body))
|
||||||
respBody, err := ioutil.ReadAll(resp.Body)
|
respBody, err := ioutil.ReadAll(resp.Body)
|
||||||
resp.Body.Close()
|
resp.Body.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
if string(respBody) != "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9" {
|
|
||||||
t.Errorf("wanted eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9, got %v", string(respBody))
|
j := strings.Split(string(respBody), ".")
|
||||||
|
|
||||||
|
header := base64.StdEncoding.EncodeToString([]byte(`{"typ":"JWT","alg":"HS256"}`))
|
||||||
|
if j[0] != header {
|
||||||
|
t.Errorf("wanted %v, got %v", header, j[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
claims, err := base64.StdEncoding.DecodeString(j[1])
|
||||||
|
var c struct {
|
||||||
|
Exp int
|
||||||
|
Iat int
|
||||||
|
}
|
||||||
|
err = json.Unmarshal(claims, &c)
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
duration := time.Duration(c.Exp-c.Iat) * time.Second
|
||||||
|
d := time.Minute * 60 * 24
|
||||||
|
if duration != d {
|
||||||
|
t.Errorf("wanted %v, got %v", d, duration)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Reference in a new issue