Verify header in auth handler

This commit is contained in:
Matthew Dillon 2015-04-18 13:33:26 -08:00
parent f9557a80a3
commit e0241b074f
2 changed files with 107 additions and 34 deletions

37
jwt.go
View file

@ -6,6 +6,7 @@ import (
"encoding/base64"
"encoding/json"
"errors"
"log"
"net/http"
"strings"
)
@ -16,6 +17,9 @@ var (
ErrMissingAuthFunc = errors.New("please provide an auth function")
ErrMissingClaimsFunc = errors.New("please provide a claims function")
ErrEncoding = errors.New("error encoding value")
ErrMissingToken = errors.New("please provide a token")
ErrMalformedToken = errors.New("please provide a valid token")
ErrDecodingHeader = errors.New("could not decode JOSE header")
)
type Config struct {
@ -56,8 +60,35 @@ func NewMiddleware(c *Config) (*JWTMiddleware, error) {
}
func (m *JWTMiddleware) Secure(h http.Handler) http.Handler {
// This is just a placeholder for now
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
http.Error(w, ErrMissingToken.Error(), http.StatusUnauthorized)
return
}
token := strings.Split(authHeader, " ")[1]
if strings.LastIndex(token, ".") == -1 {
http.Error(w, ErrMalformedToken.Error(), http.StatusUnauthorized)
return
}
// Verify JOSE header
var t struct {
Typ string
Alg string
}
tokenParts := strings.Split(token, ".")
header, err := decode(tokenParts[0])
if err != nil {
log.Printf("error (%v) while decoding header (%v)", err, tokenParts[0])
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
err = json.Unmarshal(header, &t)
if err != nil {
log.Printf("error (%v) while unmarshalling header (%s)", err, header)
http.Error(w, ErrMalformedToken.Error(), http.StatusInternalServerError)
return
}
h.ServeHTTP(w, r)
})
}
@ -122,3 +153,7 @@ func encode(s interface{}) (string, error) {
}
return base64.StdEncoding.EncodeToString(r), nil
}
func decode(s string) ([]byte, error) {
return base64.StdEncoding.DecodeString(s)
}

View file

@ -6,6 +6,7 @@ import (
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
@ -43,6 +44,29 @@ func newJWTMiddlewareOrFatal(t *testing.T) *JWTMiddleware {
return middleware
}
func newToken(t *testing.T) (string, *JWTMiddleware) {
middleware := newJWTMiddlewareOrFatal(t)
authBody := map[string]interface{}{
"email": "user@example.com",
"password": "password",
}
body, err := json.Marshal(authBody)
if err != nil {
t.Error(err)
}
ts := httptest.NewServer(http.HandlerFunc(middleware.GenerateToken))
defer ts.Close()
resp, err := http.Post(ts.URL, "application/json", bytes.NewReader(body))
respBody, err := ioutil.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
t.Error(err)
}
return string(respBody), middleware
}
func TestNewJWTMiddleware(t *testing.T) {
middleware := newJWTMiddlewareOrFatal(t)
if middleware.secret != "password" {
@ -88,39 +112,9 @@ func TestNewJWTMiddlewareNoConfig(t *testing.T) {
}
}
}
func TestSecureHandler(t *testing.T) {
middleware := newJWTMiddlewareOrFatal(t)
resp := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "http://example.com", nil)
middleware.Secure(testHandler).ServeHTTP(resp, req)
if resp.Body.String() != "test" {
t.Errorf("wanted test, got %v", resp.Body.String())
}
}
func TestGenerateTokenHandler(t *testing.T) {
middleware := newJWTMiddlewareOrFatal(t)
authBody := map[string]interface{}{
"email": "user@example.com",
"password": "password",
}
body, err := json.Marshal(authBody)
if err != nil {
t.Error(err)
}
ts := httptest.NewServer(http.HandlerFunc(middleware.GenerateToken))
defer ts.Close()
resp, err := http.Post(ts.URL, "application/json", bytes.NewReader(body))
respBody, err := ioutil.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
t.Error(err)
}
j := strings.Split(string(respBody), ".")
token, m := newToken(t)
j := strings.Split(token, ".")
header := base64.StdEncoding.EncodeToString([]byte(`{"typ":"JWT","alg":"HS256"}`))
if j[0] != header {
@ -141,7 +135,7 @@ func TestGenerateTokenHandler(t *testing.T) {
if duration != d {
t.Errorf("wanted %v, got %v", d, duration)
}
mac := hmac.New(sha256.New, []byte(middleware.secret))
mac := hmac.New(sha256.New, []byte(m.secret))
message := []byte(strings.Join([]string{j[0], j[1]}, "."))
mac.Write(message)
expectedMac := base64.StdEncoding.EncodeToString(mac.Sum(nil))
@ -149,3 +143,47 @@ func TestGenerateTokenHandler(t *testing.T) {
t.Errorf("wanted %v, got %v", expectedMac, j[2])
}
}
func TestSecureHandlerNoToken(t *testing.T) {
middleware := newJWTMiddlewareOrFatal(t)
resp := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "http://example.com", nil)
middleware.Secure(testHandler).ServeHTTP(resp, req)
body := strings.TrimSpace(resp.Body.String())
if body != ErrMissingToken.Error() {
t.Errorf("wanted %q, got %q", ErrMissingToken.Error(), body)
}
}
func TestSecureHandlerBadToken(t *testing.T) {
middleware := newJWTMiddlewareOrFatal(t)
resp := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "http://example.com", nil)
req.Header.Set("Authorization", "Bearer abcdefg")
middleware.Secure(testHandler).ServeHTTP(resp, req)
body := strings.TrimSpace(resp.Body.String())
if body != ErrMalformedToken.Error() {
t.Errorf("wanted %q, got %q", ErrMalformedToken.Error(), body)
}
resp = httptest.NewRecorder()
req, _ = http.NewRequest("GET", "http://example.com", nil)
req.Header.Set("Authorization", "Bearer abcd.abcd.abcd")
middleware.Secure(testHandler).ServeHTTP(resp, req)
body = strings.TrimSpace(resp.Body.String())
if body != ErrMalformedToken.Error() {
t.Errorf("wanted %q, got %q", ErrMalformedToken.Error(), body)
}
}
func TestSecureHandlerGoodToken(t *testing.T) {
token, middleware := newToken(t)
resp := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "http://example.com", nil)
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
middleware.Secure(testHandler).ServeHTTP(resp, req)
body := strings.TrimSpace(resp.Body.String())
if body != "test" {
t.Errorf("wanted %s, got %s", "test", body)
}
}