Verify claims

This commit is contained in:
Matthew Dillon 2015-04-18 14:36:22 -08:00
parent 16c379b2c9
commit 9e2ad61d1c
2 changed files with 43 additions and 9 deletions

25
jwt.go
View file

@ -31,7 +31,9 @@ type Config struct {
type AuthFunc func(string, string) (bool, error) type AuthFunc func(string, string) (bool, error)
type ClaimsFunc func(id string) (map[string]interface{}, error) type ClaimsFunc func(string) (map[string]interface{}, error)
type VerifyClaimsFunc func([]byte) (bool, error)
type JWTMiddleware struct { type JWTMiddleware struct {
secret string secret string
@ -60,7 +62,7 @@ func NewMiddleware(c *Config) (*JWTMiddleware, error) {
return m, nil return m, nil
} }
func (m *JWTMiddleware) Secure(h http.Handler) http.Handler { func (m *JWTMiddleware) Secure(h http.Handler, v VerifyClaimsFunc) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
authHeader := r.Header.Get("Authorization") authHeader := r.Header.Get("Authorization")
if authHeader == "" { if authHeader == "" {
@ -74,7 +76,7 @@ func (m *JWTMiddleware) Secure(h http.Handler) http.Handler {
} }
tokenParts := strings.Split(token, ".") tokenParts := strings.Split(token, ".")
// Verify JOSE header // First, verify JOSE header
var t struct { var t struct {
Typ string Typ string
Alg string Alg string
@ -92,7 +94,7 @@ func (m *JWTMiddleware) Secure(h http.Handler) http.Handler {
return return
} }
// Verify signature // Then, verify signature
mac := hmac.New(sha256.New, []byte(m.secret)) mac := hmac.New(sha256.New, []byte(m.secret))
message := []byte(strings.Join([]string{tokenParts[0], tokenParts[1]}, ".")) message := []byte(strings.Join([]string{tokenParts[0], tokenParts[1]}, "."))
mac.Write(message) mac.Write(message)
@ -106,6 +108,21 @@ func (m *JWTMiddleware) Secure(h http.Handler) http.Handler {
http.Error(w, ErrInvalidSignature.Error(), http.StatusUnauthorized) http.Error(w, ErrInvalidSignature.Error(), http.StatusUnauthorized)
return return
} }
// Finally, check claims
claimSet, err := decode(tokenParts[1])
if err != nil {
panic(err)
return
}
claimsTest, err := v(claimSet)
if !claimsTest {
log.Printf("test: %v, error: %v", claimsTest, err)
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
// If we make it this far, process the downstream handler
h.ServeHTTP(w, r) h.ServeHTTP(w, r)
}) })
} }

View file

@ -6,6 +6,7 @@ import (
"crypto/sha256" "crypto/sha256"
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
@ -31,6 +32,22 @@ var claimsFunc = func(id string) (map[string]interface{}, error) {
}, nil }, nil
} }
var verifyClaimsFunc = func(claims []byte) (bool, error) {
currentTime := time.Now()
var c struct {
Exp int64
Iat int64
}
err := json.Unmarshal(claims, &c)
if err != nil {
return false, err
}
if currentTime.After(time.Unix(c.Exp, 0)) {
return false, errors.New("expired")
}
return true, nil
}
func newJWTMiddlewareOrFatal(t *testing.T) *JWTMiddleware { func newJWTMiddlewareOrFatal(t *testing.T) *JWTMiddleware {
config := &Config{ config := &Config{
Secret: "password", Secret: "password",
@ -148,7 +165,7 @@ func TestSecureHandlerNoToken(t *testing.T) {
middleware := newJWTMiddlewareOrFatal(t) middleware := newJWTMiddlewareOrFatal(t)
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "http://example.com", nil) req, _ := http.NewRequest("GET", "http://example.com", nil)
middleware.Secure(testHandler).ServeHTTP(resp, req) middleware.Secure(testHandler, verifyClaimsFunc).ServeHTTP(resp, req)
body := strings.TrimSpace(resp.Body.String()) body := strings.TrimSpace(resp.Body.String())
if body != ErrMissingToken.Error() { if body != ErrMissingToken.Error() {
t.Errorf("wanted %q, got %q", ErrMissingToken.Error(), body) t.Errorf("wanted %q, got %q", ErrMissingToken.Error(), body)
@ -160,7 +177,7 @@ func TestSecureHandlerBadToken(t *testing.T) {
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "http://example.com", nil) req, _ := http.NewRequest("GET", "http://example.com", nil)
req.Header.Set("Authorization", "Bearer abcdefg") req.Header.Set("Authorization", "Bearer abcdefg")
middleware.Secure(testHandler).ServeHTTP(resp, req) middleware.Secure(testHandler, verifyClaimsFunc).ServeHTTP(resp, req)
body := strings.TrimSpace(resp.Body.String()) body := strings.TrimSpace(resp.Body.String())
if body != ErrMalformedToken.Error() { if body != ErrMalformedToken.Error() {
t.Errorf("wanted %q, got %q", ErrMalformedToken.Error(), body) t.Errorf("wanted %q, got %q", ErrMalformedToken.Error(), body)
@ -169,7 +186,7 @@ func TestSecureHandlerBadToken(t *testing.T) {
resp = httptest.NewRecorder() resp = httptest.NewRecorder()
req, _ = http.NewRequest("GET", "http://example.com", nil) req, _ = http.NewRequest("GET", "http://example.com", nil)
req.Header.Set("Authorization", "Bearer abcd.abcd.abcd") req.Header.Set("Authorization", "Bearer abcd.abcd.abcd")
middleware.Secure(testHandler).ServeHTTP(resp, req) middleware.Secure(testHandler, verifyClaimsFunc).ServeHTTP(resp, req)
body = strings.TrimSpace(resp.Body.String()) body = strings.TrimSpace(resp.Body.String())
if body != ErrMalformedToken.Error() { if body != ErrMalformedToken.Error() {
t.Errorf("wanted %q, got %q", ErrMalformedToken.Error(), body) t.Errorf("wanted %q, got %q", ErrMalformedToken.Error(), body)
@ -183,7 +200,7 @@ func TestSecureHandlerBadSignature(t *testing.T) {
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "http://example.com", nil) req, _ := http.NewRequest("GET", "http://example.com", nil)
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
middleware.Secure(testHandler).ServeHTTP(resp, req) middleware.Secure(testHandler, verifyClaimsFunc).ServeHTTP(resp, req)
body := strings.TrimSpace(resp.Body.String()) body := strings.TrimSpace(resp.Body.String())
if body != ErrInvalidSignature.Error() { if body != ErrInvalidSignature.Error() {
t.Errorf("wanted %s, got %s", ErrInvalidSignature.Error(), body) t.Errorf("wanted %s, got %s", ErrInvalidSignature.Error(), body)
@ -195,7 +212,7 @@ func TestSecureHandlerGoodToken(t *testing.T) {
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "http://example.com", nil) req, _ := http.NewRequest("GET", "http://example.com", nil)
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
middleware.Secure(testHandler).ServeHTTP(resp, req) middleware.Secure(testHandler, verifyClaimsFunc).ServeHTTP(resp, req)
body := strings.TrimSpace(resp.Body.String()) body := strings.TrimSpace(resp.Body.String())
if body != "test" { if body != "test" {
t.Errorf("wanted %s, got %s", "test", body) t.Errorf("wanted %s, got %s", "test", body)