diff --git a/examples/net-http.go b/examples/net-http.go index 45097e0..cf74d0a 100644 --- a/examples/net-http.go +++ b/examples/net-http.go @@ -1,6 +1,7 @@ package main import ( + "errors" "fmt" "net/http" "time" @@ -13,8 +14,12 @@ func protectMe(w http.ResponseWriter, r *http.Request) { } func main() { - var authFunc = func(string, string) (bool, error) { - return true, nil + var authFunc = func(email string, password string) error { + // Hard-code a user + if email != "test" || password != "test" { + return errors.New("invalid credentials") + } + return nil } var claimsFunc = func(string) (map[string]interface{}, error) { @@ -25,8 +30,9 @@ func main() { }, nil } - var verifyClaimsFunc = func([]byte) (bool, error) { - return true, nil + var verifyClaimsFunc = func([]byte) error { + // We don't really care about the claims, just approve as-is + return nil } config := &jwt.Config{ diff --git a/jwt.go b/jwt.go index ca155b3..cd75025 100644 --- a/jwt.go +++ b/jwt.go @@ -12,15 +12,16 @@ import ( ) var ( - ErrMissingConfig = errors.New("missing configuration") - ErrMissingSecret = errors.New("please provide a shared secret") - ErrMissingAuthFunc = errors.New("please provide an auth function") - ErrMissingClaimsFunc = errors.New("please provide a claims function") - ErrEncoding = errors.New("error encoding value") - ErrMissingToken = errors.New("please provide a token") - ErrMalformedToken = errors.New("please provide a valid token") - ErrDecodingHeader = errors.New("could not decode JOSE header") - ErrInvalidSignature = errors.New("signature could not be verified") + ErrMissingConfig = errors.New("missing configuration") + ErrMissingSecret = errors.New("please provide a shared secret") + ErrMissingAuthFunc = errors.New("please provide an auth function") + ErrMissingClaimsFunc = errors.New("please provide a claims function") + ErrEncoding = errors.New("error encoding value") + ErrMissingToken = errors.New("please provide a token") + ErrMalformedToken = errors.New("please provide a valid token") + ErrDecodingHeader = errors.New("could not decode JOSE header") + ErrInvalidSignature = errors.New("signature could not be verified") + ErrParsingCredentials = errors.New("error parsing credentials") ) type Config struct { @@ -29,11 +30,11 @@ type Config struct { Claims ClaimsFunc } -type AuthFunc func(string, string) (bool, error) +type AuthFunc func(string, string) error type ClaimsFunc func(string) (map[string]interface{}, error) -type VerifyClaimsFunc func([]byte) (bool, error) +type VerifyClaimsFunc func([]byte) error type JWTMiddleware struct { secret string @@ -70,11 +71,11 @@ func (m *JWTMiddleware) Secure(h http.Handler, v VerifyClaimsFunc) http.Handler return } token := strings.Split(authHeader, " ")[1] - if strings.LastIndex(token, ".") == -1 { + tokenParts := strings.Split(token, ".") + if len(tokenParts) != 3 { http.Error(w, ErrMalformedToken.Error(), http.StatusUnauthorized) return } - tokenParts := strings.Split(token, ".") // First, verify JOSE header var t struct { @@ -115,9 +116,9 @@ func (m *JWTMiddleware) Secure(h http.Handler, v VerifyClaimsFunc) http.Handler panic(err) return } - claimsTest, err := v(claimSet) - if !claimsTest { - log.Printf("test: %v, error: %v", claimsTest, err) + err = v(claimSet) + if err != nil { + log.Printf("claims error: %v", err) http.Error(w, err.Error(), http.StatusUnauthorized) return } @@ -131,14 +132,15 @@ func (m *JWTMiddleware) GenerateToken(w http.ResponseWriter, r *http.Request) { var b map[string]string err := json.NewDecoder(r.Body).Decode(&b) if err != nil { - panic(err) + log.Printf("error (%v) while parsing authorization", err) + http.Error(w, ErrParsingCredentials.Error(), http.StatusInternalServerError) + return } - result, err := m.auth(b["email"], b["password"]) + err = m.auth(b["email"], b["password"]) if err != nil { - panic(err) - } - if !result { - panic("deal with this") + log.Printf("error (%v) while performing authorization", err) + http.Error(w, err.Error(), http.StatusInternalServerError) + return } // For now, the header will be static diff --git a/jwt_test.go b/jwt_test.go index 3c3cc51..d71adf4 100644 --- a/jwt_test.go +++ b/jwt_test.go @@ -20,8 +20,8 @@ var testHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) w.Write([]byte("test")) }) -var authFunc = func(email, password string) (bool, error) { - return true, nil +var authFunc = func(email, password string) error { + return nil } var claimsFunc = func(id string) (map[string]interface{}, error) { @@ -32,7 +32,7 @@ var claimsFunc = func(id string) (map[string]interface{}, error) { }, nil } -var verifyClaimsFunc = func(claims []byte) (bool, error) { +var verifyClaimsFunc = func(claims []byte) error { currentTime := time.Now() var c struct { Exp int64 @@ -40,12 +40,12 @@ var verifyClaimsFunc = func(claims []byte) (bool, error) { } err := json.Unmarshal(claims, &c) if err != nil { - return false, err + return err } if currentTime.After(time.Unix(c.Exp, 0)) { - return false, errors.New("expired") + return errors.New("expired") } - return true, nil + return nil } func newJWTMiddlewareOrFatal(t *testing.T) *JWTMiddleware { @@ -89,13 +89,10 @@ func TestNewJWTMiddleware(t *testing.T) { if middleware.secret != "password" { t.Errorf("wanted password, got %v", middleware.secret) } - authVal, err := middleware.auth("", "") + err := middleware.auth("", "") if err != nil { t.Fatal(err) } - if authVal != true { - t.Errorf("wanted true, got %v", authVal) - } claimsVal, err := middleware.claims("1") if err != nil { t.Fatal(err)