
StdEncoding produces standard base64, whose alphabet includes the '+' and '/' characters; these are not URL-safe and are not JWT-compatible. StdEncoding also appends '=' padding characters, whereas the JWT definition requires the compact, unpadded base64url form. Using RawURLEncoding instead of StdEncoding solves these issues.
247 lines
6.7 KiB
Go
247 lines
6.7 KiB
Go
package jwt
|
|
|
|
import (
|
|
"bytes"
|
|
"crypto/hmac"
|
|
"crypto/sha256"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io/ioutil"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
var testHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Write([]byte("test"))
|
|
})
|
|
|
|
// authFunc is the authenticator wired into the test Config. It accepts
// every email/password pair by returning a nil error.
var authFunc = func(email, password string) error {
	// No credential check: these tests exercise token plumbing, not login.
	return nil
}
|
|
|
|
var claimsFunc = func(id string) (map[string]interface{}, error) {
|
|
currentTime := time.Now()
|
|
return map[string]interface{}{
|
|
"iat": currentTime.Unix(),
|
|
"exp": currentTime.Add(time.Minute * 60 * 24).Unix(),
|
|
}, nil
|
|
}
|
|
|
|
var verifyClaimsFunc = func(claims []byte, r *http.Request) error {
|
|
currentTime := time.Now()
|
|
var c struct {
|
|
Exp int64
|
|
Iat int64
|
|
}
|
|
err := json.Unmarshal(claims, &c)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if currentTime.After(time.Unix(c.Exp, 0)) {
|
|
return errors.New("expired")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// newMiddlewareOrFatal builds a Middleware from a fixed test Config (secret
// "password", authFunc, claimsFunc). It aborts the calling test if New
// rejects that config.
func newMiddlewareOrFatal(t *testing.T) *Middleware {
	// Attribute a fatal failure to the caller's line, not this helper.
	t.Helper()

	config := &Config{
		Secret: "password",
		Auth:   authFunc,
		Claims: claimsFunc,
	}
	middleware, err := New(config)
	if err != nil {
		t.Fatalf("new middleware: %v", err)
	}
	return middleware
}
|
|
|
|
// newToken starts an httptest server around middleware.Authenticate() and
// POSTs a JSON {"email", "password"} body to it. It returns the raw response
// body, presumably the compact token when authentication succeeds, together
// with the middleware that served the request. A failure to marshal the
// body, send the request, or read the response aborts the calling test.
func newToken(t *testing.T) (string, *Middleware) {
	t.Helper()
	middleware := newMiddlewareOrFatal(t)

	body, err := json.Marshal(map[string]interface{}{
		"email":    "user@example.com",
		"password": "password",
	})
	if err != nil {
		t.Fatalf("marshal auth body: %v", err)
	}

	ts := httptest.NewServer(middleware.Authenticate())
	defer ts.Close()

	// Check the Post error before touching resp: on failure resp is nil and
	// reading resp.Body would panic.
	resp, err := http.Post(ts.URL, "application/json", bytes.NewReader(body))
	if err != nil {
		t.Fatalf("post credentials: %v", err)
	}
	// Deferred after ts.Close, so it runs first: the body is closed before
	// the server shuts down.
	defer resp.Body.Close()

	respBody, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		t.Fatalf("read token response: %v", err)
	}
	return string(respBody), middleware
}
|
|
|
|
// TestNewJWTMiddleware checks what New stores from the test config:
//   - the secret;
//   - working auth and claims callbacks;
//   - identityField "email" and verifyField "password", which the test
//     config never sets explicitly.
func TestNewJWTMiddleware(t *testing.T) {
	m := newMiddlewareOrFatal(t)

	if got := m.secret; got != "password" {
		t.Errorf("wanted password, got %v", got)
	}

	if err := m.auth("", ""); err != nil {
		t.Fatal(err)
	}

	claims, err := m.claims("1")
	if err != nil {
		t.Fatal(err)
	}
	if _, found := claims["iat"]; !found {
		t.Errorf("wanted a claims set, got %v", claims)
	}

	if got := m.identityField; got != "email" {
		t.Errorf("wanted email, got %v", got)
	}
	if got := m.verifyField; got != "password" {
		t.Errorf("wanted password, got %v", got)
	}
}
|
|
|
|
// TestNewJWTMiddlewareNoConfig feeds New incomplete configurations and
// checks that each is rejected with the matching sentinel error. The cases
// are an ordered, named table (not a map keyed by *Config), so they run in
// a fixed order and each failure identifies its case.
func TestNewJWTMiddlewareNoConfig(t *testing.T) {
	cases := []struct {
		name   string
		config *Config
		want   error
	}{
		{"nil config", nil, ErrMissingConfig},
		{"empty config", &Config{}, ErrMissingSecret},
		{"missing secret", &Config{Auth: authFunc, Claims: claimsFunc}, ErrMissingSecret},
		{"missing auth", &Config{Secret: "secret", Claims: claimsFunc}, ErrMissingAuthFunc},
		{"missing claims", &Config{Auth: authFunc, Secret: "secret"}, ErrMissingClaimsFunc},
	}

	for _, tc := range cases {
		tc := tc // pin per-iteration value for pre-Go 1.22 toolchains
		t.Run(tc.name, func(t *testing.T) {
			// errors.Is keeps matching if New ever wraps its sentinels.
			if _, err := New(tc.config); !errors.Is(err, tc.want) {
				t.Errorf("wanted error: %v, got error: %v using config: %v", tc.want, err, tc.config)
			}
		})
	}
}
|
|
// TestGenerateTokenHandler obtains a token from Authenticate and checks its
// three dot-separated segments:
//   - the header is the unpadded base64url encoding of the HS256 JWT header;
//   - the claims decode and span exactly 24 hours from iat to exp;
//   - the signature is the unpadded base64url HMAC-SHA256, keyed with the
//     middleware secret, over "header.claims".
func TestGenerateTokenHandler(t *testing.T) {
	token, m := newToken(t)
	j := strings.Split(token, ".")
	// Fail cleanly instead of panicking on j[1]/j[2] if the response was not
	// a compact JWT.
	if len(j) != 3 {
		t.Fatalf("wanted a three-part token, got %q", token)
	}

	header := base64.RawURLEncoding.EncodeToString([]byte(`{"typ":"JWT","alg":"HS256"}`))
	if j[0] != header {
		t.Errorf("wanted %v, got %v", header, j[0])
	}

	claims, err := base64.RawURLEncoding.DecodeString(j[1])
	if err != nil {
		t.Fatalf("decode claims segment %q: %v", j[1], err)
	}
	var c struct {
		Exp int64
		Iat int64
	}
	if err := json.Unmarshal(claims, &c); err != nil {
		t.Fatalf("unmarshal claims %q: %v", claims, err)
	}

	duration := time.Duration(c.Exp-c.Iat) * time.Second
	d := 24 * time.Hour
	if duration != d {
		t.Errorf("wanted %v, got %v", d, duration)
	}

	mac := hmac.New(sha256.New, []byte(m.secret))
	mac.Write([]byte(j[0] + "." + j[1]))
	expectedMac := base64.RawURLEncoding.EncodeToString(mac.Sum(nil))
	if !hmac.Equal([]byte(j[2]), []byte(expectedMac)) {
		t.Errorf("wanted %v, got %v", expectedMac, j[2])
	}
}
|
|
|
|
// TestSecureHandlerNoToken sends a request with no Authorization header
// through Secure and expects the ErrMissingToken message as the body.
func TestSecureHandlerNoToken(t *testing.T) {
	m := newMiddlewareOrFatal(t)

	// Constant, well-formed method and URL, so the error is ignored.
	req, _ := http.NewRequest("GET", "http://example.com", nil)
	rec := httptest.NewRecorder()
	m.Secure(testHandler, verifyClaimsFunc).ServeHTTP(rec, req)

	want := ErrMissingToken.Error()
	if got := strings.TrimSpace(rec.Body.String()); got != want {
		t.Errorf("wanted %q, got %q", want, got)
	}
}
|
|
|
|
// TestSecureHandlerBadToken sends Bearer values that are not valid tokens
// through Secure: a single opaque string, and three segments that are not
// base64-encoded JSON. Each is expected to produce the ErrMalformedToken
// message as the body.
func TestSecureHandlerBadToken(t *testing.T) {
	m := newMiddlewareOrFatal(t)
	want := ErrMalformedToken.Error()

	for _, authorization := range []string{
		"Bearer abcdefg",
		"Bearer abcd.abcd.abcd",
	} {
		rec := httptest.NewRecorder()
		// Constant, well-formed method and URL, so the error is ignored.
		req, _ := http.NewRequest("GET", "http://example.com", nil)
		req.Header.Set("Authorization", authorization)
		m.Secure(testHandler, verifyClaimsFunc).ServeHTTP(rec, req)

		if got := strings.TrimSpace(rec.Body.String()); got != want {
			t.Errorf("wanted %q, got %q", want, got)
		}
	}
}
|
|
|
|
// TestSecureHandlerBadSignature takes a token issued by Authenticate,
// replaces its signature segment with "abcd", and expects Secure to reject
// it with the ErrInvalidSignature message.
func TestSecureHandlerBadSignature(t *testing.T) {
	token, middleware := newToken(t)
	parts := strings.Split(token, ".")
	// Guard the indexing below. If issuance failed, the response body is
	// presumably not header.claims.signature, and parts[1] would panic.
	if len(parts) != 3 {
		t.Fatalf("wanted a three-part token, got %q", token)
	}
	forged := strings.Join([]string{parts[0], parts[1], "abcd"}, ".")

	resp := httptest.NewRecorder()
	// Constant, well-formed method and URL, so the error is ignored.
	req, _ := http.NewRequest("GET", "http://example.com", nil)
	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", forged))
	middleware.Secure(testHandler, verifyClaimsFunc).ServeHTTP(resp, req)

	body := strings.TrimSpace(resp.Body.String())
	if body != ErrInvalidSignature.Error() {
		t.Errorf("wanted %q, got %q", ErrInvalidSignature.Error(), body)
	}
}
|
|
|
|
func TestSecureHandlerGoodToken(t *testing.T) {
|
|
token, middleware := newToken(t)
|
|
resp := httptest.NewRecorder()
|
|
req, _ := http.NewRequest("GET", "http://example.com", nil)
|
|
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
|
|
middleware.Secure(testHandler, verifyClaimsFunc).ServeHTTP(resp, req)
|
|
body := strings.TrimSpace(resp.Body.String())
|
|
if body != "test" {
|
|
t.Errorf("wanted %s, got %s", "test", body)
|
|
}
|
|
}
|
|
|
|
// TestGenerateTokenHandlerNotPOST sends a PUT to the Authenticate handler
// and expects the ErrInvalidMethod message as the body.
func TestGenerateTokenHandlerNotPOST(t *testing.T) {
	m := newMiddlewareOrFatal(t)
	rec := httptest.NewRecorder()
	// Constant, well-formed method and URL, so the error is ignored.
	req, _ := http.NewRequest("PUT", "http://example.com", nil)

	handler := m.Authenticate()
	handler.ServeHTTP(rec, req)

	want := ErrInvalidMethod.Error()
	if got := strings.TrimSpace(rec.Body.String()); got != want {
		t.Errorf("wanted %q, got %q", want, got)
	}
}
|
|
|
|
// TestMalformedAuthorizationHeader sets an Authorization header that lacks
// the "Bearer " prefix and expects Secure to answer with the
// ErrMalformedToken message. No issued token is needed, so the middleware
// is built directly rather than via newToken, which would start a test
// server and make a real HTTP round-trip only to discard the token.
func TestMalformedAuthorizationHeader(t *testing.T) {
	middleware := newMiddlewareOrFatal(t)

	resp := httptest.NewRecorder()
	// Constant, well-formed method and URL, so the error is ignored.
	req, _ := http.NewRequest("GET", "http://example.com", nil)
	req.Header.Set("Authorization", "hello!") // No "Bearer " portion of header
	middleware.Secure(testHandler, verifyClaimsFunc).ServeHTTP(resp, req)

	body := strings.TrimSpace(resp.Body.String())
	if body != ErrMalformedToken.Error() {
		t.Errorf("wanted %q, got %q", ErrMalformedToken.Error(), body)
	}
}
|