Verify header in auth handler
This commit is contained in:
parent
f9557a80a3
commit
e0241b074f
2 changed files with 107 additions and 34 deletions
37
jwt.go
37
jwt.go
|
@ -6,6 +6,7 @@ import (
|
|||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
@ -16,6 +17,9 @@ var (
|
|||
ErrMissingAuthFunc = errors.New("please provide an auth function")
|
||||
ErrMissingClaimsFunc = errors.New("please provide a claims function")
|
||||
ErrEncoding = errors.New("error encoding value")
|
||||
ErrMissingToken = errors.New("please provide a token")
|
||||
ErrMalformedToken = errors.New("please provide a valid token")
|
||||
ErrDecodingHeader = errors.New("could not decode JOSE header")
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
|
@ -56,8 +60,35 @@ func NewMiddleware(c *Config) (*JWTMiddleware, error) {
|
|||
}
|
||||
|
||||
func (m *JWTMiddleware) Secure(h http.Handler) http.Handler {
|
||||
// This is just a placeholder for now
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if authHeader == "" {
|
||||
http.Error(w, ErrMissingToken.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
token := strings.Split(authHeader, " ")[1]
|
||||
if strings.LastIndex(token, ".") == -1 {
|
||||
http.Error(w, ErrMalformedToken.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
// Verify JOSE header
|
||||
var t struct {
|
||||
Typ string
|
||||
Alg string
|
||||
}
|
||||
tokenParts := strings.Split(token, ".")
|
||||
header, err := decode(tokenParts[0])
|
||||
if err != nil {
|
||||
log.Printf("error (%v) while decoding header (%v)", err, tokenParts[0])
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
err = json.Unmarshal(header, &t)
|
||||
if err != nil {
|
||||
log.Printf("error (%v) while unmarshalling header (%s)", err, header)
|
||||
http.Error(w, ErrMalformedToken.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
h.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
@ -122,3 +153,7 @@ func encode(s interface{}) (string, error) {
|
|||
}
|
||||
return base64.StdEncoding.EncodeToString(r), nil
|
||||
}
|
||||
|
||||
func decode(s string) ([]byte, error) {
|
||||
return base64.StdEncoding.DecodeString(s)
|
||||
}
|
||||
|
|
104
jwt_test.go
104
jwt_test.go
|
@ -6,6 +6,7 @@ import (
|
|||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
|
@ -43,6 +44,29 @@ func newJWTMiddlewareOrFatal(t *testing.T) *JWTMiddleware {
|
|||
return middleware
|
||||
}
|
||||
|
||||
func newToken(t *testing.T) (string, *JWTMiddleware) {
|
||||
middleware := newJWTMiddlewareOrFatal(t)
|
||||
authBody := map[string]interface{}{
|
||||
"email": "user@example.com",
|
||||
"password": "password",
|
||||
}
|
||||
body, err := json.Marshal(authBody)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
ts := httptest.NewServer(http.HandlerFunc(middleware.GenerateToken))
|
||||
defer ts.Close()
|
||||
|
||||
resp, err := http.Post(ts.URL, "application/json", bytes.NewReader(body))
|
||||
respBody, err := ioutil.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
return string(respBody), middleware
|
||||
}
|
||||
|
||||
func TestNewJWTMiddleware(t *testing.T) {
|
||||
middleware := newJWTMiddlewareOrFatal(t)
|
||||
if middleware.secret != "password" {
|
||||
|
@ -88,39 +112,9 @@ func TestNewJWTMiddlewareNoConfig(t *testing.T) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecureHandler(t *testing.T) {
|
||||
middleware := newJWTMiddlewareOrFatal(t)
|
||||
resp := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||
middleware.Secure(testHandler).ServeHTTP(resp, req)
|
||||
if resp.Body.String() != "test" {
|
||||
t.Errorf("wanted test, got %v", resp.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateTokenHandler(t *testing.T) {
|
||||
middleware := newJWTMiddlewareOrFatal(t)
|
||||
authBody := map[string]interface{}{
|
||||
"email": "user@example.com",
|
||||
"password": "password",
|
||||
}
|
||||
body, err := json.Marshal(authBody)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
ts := httptest.NewServer(http.HandlerFunc(middleware.GenerateToken))
|
||||
defer ts.Close()
|
||||
|
||||
resp, err := http.Post(ts.URL, "application/json", bytes.NewReader(body))
|
||||
respBody, err := ioutil.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
j := strings.Split(string(respBody), ".")
|
||||
token, m := newToken(t)
|
||||
j := strings.Split(token, ".")
|
||||
|
||||
header := base64.StdEncoding.EncodeToString([]byte(`{"typ":"JWT","alg":"HS256"}`))
|
||||
if j[0] != header {
|
||||
|
@ -141,7 +135,7 @@ func TestGenerateTokenHandler(t *testing.T) {
|
|||
if duration != d {
|
||||
t.Errorf("wanted %v, got %v", d, duration)
|
||||
}
|
||||
mac := hmac.New(sha256.New, []byte(middleware.secret))
|
||||
mac := hmac.New(sha256.New, []byte(m.secret))
|
||||
message := []byte(strings.Join([]string{j[0], j[1]}, "."))
|
||||
mac.Write(message)
|
||||
expectedMac := base64.StdEncoding.EncodeToString(mac.Sum(nil))
|
||||
|
@ -149,3 +143,47 @@ func TestGenerateTokenHandler(t *testing.T) {
|
|||
t.Errorf("wanted %v, got %v", expectedMac, j[2])
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecureHandlerNoToken(t *testing.T) {
|
||||
middleware := newJWTMiddlewareOrFatal(t)
|
||||
resp := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||
middleware.Secure(testHandler).ServeHTTP(resp, req)
|
||||
body := strings.TrimSpace(resp.Body.String())
|
||||
if body != ErrMissingToken.Error() {
|
||||
t.Errorf("wanted %q, got %q", ErrMissingToken.Error(), body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecureHandlerBadToken(t *testing.T) {
|
||||
middleware := newJWTMiddlewareOrFatal(t)
|
||||
resp := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||
req.Header.Set("Authorization", "Bearer abcdefg")
|
||||
middleware.Secure(testHandler).ServeHTTP(resp, req)
|
||||
body := strings.TrimSpace(resp.Body.String())
|
||||
if body != ErrMalformedToken.Error() {
|
||||
t.Errorf("wanted %q, got %q", ErrMalformedToken.Error(), body)
|
||||
}
|
||||
|
||||
resp = httptest.NewRecorder()
|
||||
req, _ = http.NewRequest("GET", "http://example.com", nil)
|
||||
req.Header.Set("Authorization", "Bearer abcd.abcd.abcd")
|
||||
middleware.Secure(testHandler).ServeHTTP(resp, req)
|
||||
body = strings.TrimSpace(resp.Body.String())
|
||||
if body != ErrMalformedToken.Error() {
|
||||
t.Errorf("wanted %q, got %q", ErrMalformedToken.Error(), body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecureHandlerGoodToken(t *testing.T) {
|
||||
token, middleware := newToken(t)
|
||||
resp := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
|
||||
middleware.Secure(testHandler).ServeHTTP(resp, req)
|
||||
body := strings.TrimSpace(resp.Body.String())
|
||||
if body != "test" {
|
||||
t.Errorf("wanted %s, got %s", "test", body)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue