Verify claims
This commit is contained in:
parent
16c379b2c9
commit
9e2ad61d1c
2 changed files with 43 additions and 9 deletions
25
jwt.go
25
jwt.go
|
@ -31,7 +31,9 @@ type Config struct {
|
||||||
|
|
||||||
type AuthFunc func(string, string) (bool, error)
|
type AuthFunc func(string, string) (bool, error)
|
||||||
|
|
||||||
type ClaimsFunc func(id string) (map[string]interface{}, error)
|
type ClaimsFunc func(string) (map[string]interface{}, error)
|
||||||
|
|
||||||
|
type VerifyClaimsFunc func([]byte) (bool, error)
|
||||||
|
|
||||||
type JWTMiddleware struct {
|
type JWTMiddleware struct {
|
||||||
secret string
|
secret string
|
||||||
|
@ -60,7 +62,7 @@ func NewMiddleware(c *Config) (*JWTMiddleware, error) {
|
||||||
return m, nil
|
return m, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *JWTMiddleware) Secure(h http.Handler) http.Handler {
|
func (m *JWTMiddleware) Secure(h http.Handler, v VerifyClaimsFunc) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
authHeader := r.Header.Get("Authorization")
|
authHeader := r.Header.Get("Authorization")
|
||||||
if authHeader == "" {
|
if authHeader == "" {
|
||||||
|
@ -74,7 +76,7 @@ func (m *JWTMiddleware) Secure(h http.Handler) http.Handler {
|
||||||
}
|
}
|
||||||
tokenParts := strings.Split(token, ".")
|
tokenParts := strings.Split(token, ".")
|
||||||
|
|
||||||
// Verify JOSE header
|
// First, verify JOSE header
|
||||||
var t struct {
|
var t struct {
|
||||||
Typ string
|
Typ string
|
||||||
Alg string
|
Alg string
|
||||||
|
@ -92,7 +94,7 @@ func (m *JWTMiddleware) Secure(h http.Handler) http.Handler {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify signature
|
// Then, verify signature
|
||||||
mac := hmac.New(sha256.New, []byte(m.secret))
|
mac := hmac.New(sha256.New, []byte(m.secret))
|
||||||
message := []byte(strings.Join([]string{tokenParts[0], tokenParts[1]}, "."))
|
message := []byte(strings.Join([]string{tokenParts[0], tokenParts[1]}, "."))
|
||||||
mac.Write(message)
|
mac.Write(message)
|
||||||
|
@ -106,6 +108,21 @@ func (m *JWTMiddleware) Secure(h http.Handler) http.Handler {
|
||||||
http.Error(w, ErrInvalidSignature.Error(), http.StatusUnauthorized)
|
http.Error(w, ErrInvalidSignature.Error(), http.StatusUnauthorized)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Finally, check claims
|
||||||
|
claimSet, err := decode(tokenParts[1])
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
claimsTest, err := v(claimSet)
|
||||||
|
if !claimsTest {
|
||||||
|
log.Printf("test: %v, error: %v", claimsTest, err)
|
||||||
|
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we make it this far, process the downstream handler
|
||||||
h.ServeHTTP(w, r)
|
h.ServeHTTP(w, r)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
27
jwt_test.go
27
jwt_test.go
|
@ -6,6 +6,7 @@ import (
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
@ -31,6 +32,22 @@ var claimsFunc = func(id string) (map[string]interface{}, error) {
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var verifyClaimsFunc = func(claims []byte) (bool, error) {
|
||||||
|
currentTime := time.Now()
|
||||||
|
var c struct {
|
||||||
|
Exp int64
|
||||||
|
Iat int64
|
||||||
|
}
|
||||||
|
err := json.Unmarshal(claims, &c)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
if currentTime.After(time.Unix(c.Exp, 0)) {
|
||||||
|
return false, errors.New("expired")
|
||||||
|
}
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
func newJWTMiddlewareOrFatal(t *testing.T) *JWTMiddleware {
|
func newJWTMiddlewareOrFatal(t *testing.T) *JWTMiddleware {
|
||||||
config := &Config{
|
config := &Config{
|
||||||
Secret: "password",
|
Secret: "password",
|
||||||
|
@ -148,7 +165,7 @@ func TestSecureHandlerNoToken(t *testing.T) {
|
||||||
middleware := newJWTMiddlewareOrFatal(t)
|
middleware := newJWTMiddlewareOrFatal(t)
|
||||||
resp := httptest.NewRecorder()
|
resp := httptest.NewRecorder()
|
||||||
req, _ := http.NewRequest("GET", "http://example.com", nil)
|
req, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||||
middleware.Secure(testHandler).ServeHTTP(resp, req)
|
middleware.Secure(testHandler, verifyClaimsFunc).ServeHTTP(resp, req)
|
||||||
body := strings.TrimSpace(resp.Body.String())
|
body := strings.TrimSpace(resp.Body.String())
|
||||||
if body != ErrMissingToken.Error() {
|
if body != ErrMissingToken.Error() {
|
||||||
t.Errorf("wanted %q, got %q", ErrMissingToken.Error(), body)
|
t.Errorf("wanted %q, got %q", ErrMissingToken.Error(), body)
|
||||||
|
@ -160,7 +177,7 @@ func TestSecureHandlerBadToken(t *testing.T) {
|
||||||
resp := httptest.NewRecorder()
|
resp := httptest.NewRecorder()
|
||||||
req, _ := http.NewRequest("GET", "http://example.com", nil)
|
req, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||||
req.Header.Set("Authorization", "Bearer abcdefg")
|
req.Header.Set("Authorization", "Bearer abcdefg")
|
||||||
middleware.Secure(testHandler).ServeHTTP(resp, req)
|
middleware.Secure(testHandler, verifyClaimsFunc).ServeHTTP(resp, req)
|
||||||
body := strings.TrimSpace(resp.Body.String())
|
body := strings.TrimSpace(resp.Body.String())
|
||||||
if body != ErrMalformedToken.Error() {
|
if body != ErrMalformedToken.Error() {
|
||||||
t.Errorf("wanted %q, got %q", ErrMalformedToken.Error(), body)
|
t.Errorf("wanted %q, got %q", ErrMalformedToken.Error(), body)
|
||||||
|
@ -169,7 +186,7 @@ func TestSecureHandlerBadToken(t *testing.T) {
|
||||||
resp = httptest.NewRecorder()
|
resp = httptest.NewRecorder()
|
||||||
req, _ = http.NewRequest("GET", "http://example.com", nil)
|
req, _ = http.NewRequest("GET", "http://example.com", nil)
|
||||||
req.Header.Set("Authorization", "Bearer abcd.abcd.abcd")
|
req.Header.Set("Authorization", "Bearer abcd.abcd.abcd")
|
||||||
middleware.Secure(testHandler).ServeHTTP(resp, req)
|
middleware.Secure(testHandler, verifyClaimsFunc).ServeHTTP(resp, req)
|
||||||
body = strings.TrimSpace(resp.Body.String())
|
body = strings.TrimSpace(resp.Body.String())
|
||||||
if body != ErrMalformedToken.Error() {
|
if body != ErrMalformedToken.Error() {
|
||||||
t.Errorf("wanted %q, got %q", ErrMalformedToken.Error(), body)
|
t.Errorf("wanted %q, got %q", ErrMalformedToken.Error(), body)
|
||||||
|
@ -183,7 +200,7 @@ func TestSecureHandlerBadSignature(t *testing.T) {
|
||||||
resp := httptest.NewRecorder()
|
resp := httptest.NewRecorder()
|
||||||
req, _ := http.NewRequest("GET", "http://example.com", nil)
|
req, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
|
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
|
||||||
middleware.Secure(testHandler).ServeHTTP(resp, req)
|
middleware.Secure(testHandler, verifyClaimsFunc).ServeHTTP(resp, req)
|
||||||
body := strings.TrimSpace(resp.Body.String())
|
body := strings.TrimSpace(resp.Body.String())
|
||||||
if body != ErrInvalidSignature.Error() {
|
if body != ErrInvalidSignature.Error() {
|
||||||
t.Errorf("wanted %s, got %s", ErrInvalidSignature.Error(), body)
|
t.Errorf("wanted %s, got %s", ErrInvalidSignature.Error(), body)
|
||||||
|
@ -195,7 +212,7 @@ func TestSecureHandlerGoodToken(t *testing.T) {
|
||||||
resp := httptest.NewRecorder()
|
resp := httptest.NewRecorder()
|
||||||
req, _ := http.NewRequest("GET", "http://example.com", nil)
|
req, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
|
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
|
||||||
middleware.Secure(testHandler).ServeHTTP(resp, req)
|
middleware.Secure(testHandler, verifyClaimsFunc).ServeHTTP(resp, req)
|
||||||
body := strings.TrimSpace(resp.Body.String())
|
body := strings.TrimSpace(resp.Body.String())
|
||||||
if body != "test" {
|
if body != "test" {
|
||||||
t.Errorf("wanted %s, got %s", "test", body)
|
t.Errorf("wanted %s, got %s", "test", body)
|
||||||
|
|
Loading…
Add table
Reference in a new issue