Misc cleanup (remove unneeded return vals)
This commit is contained in:
parent
273383cf89
commit
82980a6bac
3 changed files with 41 additions and 36 deletions
|
@ -1,6 +1,7 @@
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
@ -13,8 +14,12 @@ func protectMe(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
var authFunc = func(string, string) (bool, error) {
|
var authFunc = func(email string, password string) error {
|
||||||
return true, nil
|
// Hard-code a user
|
||||||
|
if email != "test" || password != "test" {
|
||||||
|
return errors.New("invalid credentials")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var claimsFunc = func(string) (map[string]interface{}, error) {
|
var claimsFunc = func(string) (map[string]interface{}, error) {
|
||||||
|
@ -25,8 +30,9 @@ func main() {
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var verifyClaimsFunc = func([]byte) (bool, error) {
|
var verifyClaimsFunc = func([]byte) error {
|
||||||
return true, nil
|
// We don't really care about the claims, just approve as-is
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
config := &jwt.Config{
|
config := &jwt.Config{
|
||||||
|
|
46
jwt.go
46
jwt.go
|
@ -12,15 +12,16 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ErrMissingConfig = errors.New("missing configuration")
|
ErrMissingConfig = errors.New("missing configuration")
|
||||||
ErrMissingSecret = errors.New("please provide a shared secret")
|
ErrMissingSecret = errors.New("please provide a shared secret")
|
||||||
ErrMissingAuthFunc = errors.New("please provide an auth function")
|
ErrMissingAuthFunc = errors.New("please provide an auth function")
|
||||||
ErrMissingClaimsFunc = errors.New("please provide a claims function")
|
ErrMissingClaimsFunc = errors.New("please provide a claims function")
|
||||||
ErrEncoding = errors.New("error encoding value")
|
ErrEncoding = errors.New("error encoding value")
|
||||||
ErrMissingToken = errors.New("please provide a token")
|
ErrMissingToken = errors.New("please provide a token")
|
||||||
ErrMalformedToken = errors.New("please provide a valid token")
|
ErrMalformedToken = errors.New("please provide a valid token")
|
||||||
ErrDecodingHeader = errors.New("could not decode JOSE header")
|
ErrDecodingHeader = errors.New("could not decode JOSE header")
|
||||||
ErrInvalidSignature = errors.New("signature could not be verified")
|
ErrInvalidSignature = errors.New("signature could not be verified")
|
||||||
|
ErrParsingCredentials = errors.New("error parsing credentials")
|
||||||
)
|
)
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
|
@ -29,11 +30,11 @@ type Config struct {
|
||||||
Claims ClaimsFunc
|
Claims ClaimsFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
type AuthFunc func(string, string) (bool, error)
|
type AuthFunc func(string, string) error
|
||||||
|
|
||||||
type ClaimsFunc func(string) (map[string]interface{}, error)
|
type ClaimsFunc func(string) (map[string]interface{}, error)
|
||||||
|
|
||||||
type VerifyClaimsFunc func([]byte) (bool, error)
|
type VerifyClaimsFunc func([]byte) error
|
||||||
|
|
||||||
type JWTMiddleware struct {
|
type JWTMiddleware struct {
|
||||||
secret string
|
secret string
|
||||||
|
@ -70,11 +71,11 @@ func (m *JWTMiddleware) Secure(h http.Handler, v VerifyClaimsFunc) http.Handler
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
token := strings.Split(authHeader, " ")[1]
|
token := strings.Split(authHeader, " ")[1]
|
||||||
if strings.LastIndex(token, ".") == -1 {
|
tokenParts := strings.Split(token, ".")
|
||||||
|
if len(tokenParts) != 3 {
|
||||||
http.Error(w, ErrMalformedToken.Error(), http.StatusUnauthorized)
|
http.Error(w, ErrMalformedToken.Error(), http.StatusUnauthorized)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
tokenParts := strings.Split(token, ".")
|
|
||||||
|
|
||||||
// First, verify JOSE header
|
// First, verify JOSE header
|
||||||
var t struct {
|
var t struct {
|
||||||
|
@ -115,9 +116,9 @@ func (m *JWTMiddleware) Secure(h http.Handler, v VerifyClaimsFunc) http.Handler
|
||||||
panic(err)
|
panic(err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
claimsTest, err := v(claimSet)
|
err = v(claimSet)
|
||||||
if !claimsTest {
|
if err != nil {
|
||||||
log.Printf("test: %v, error: %v", claimsTest, err)
|
log.Printf("claims error: %v", err)
|
||||||
http.Error(w, err.Error(), http.StatusUnauthorized)
|
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -131,14 +132,15 @@ func (m *JWTMiddleware) GenerateToken(w http.ResponseWriter, r *http.Request) {
|
||||||
var b map[string]string
|
var b map[string]string
|
||||||
err := json.NewDecoder(r.Body).Decode(&b)
|
err := json.NewDecoder(r.Body).Decode(&b)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
log.Printf("error (%v) while parsing authorization", err)
|
||||||
|
http.Error(w, ErrParsingCredentials.Error(), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
result, err := m.auth(b["email"], b["password"])
|
err = m.auth(b["email"], b["password"])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
log.Printf("error (%v) while performing authorization", err)
|
||||||
}
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
if !result {
|
return
|
||||||
panic("deal with this")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// For now, the header will be static
|
// For now, the header will be static
|
||||||
|
|
17
jwt_test.go
17
jwt_test.go
|
@ -20,8 +20,8 @@ var testHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request)
|
||||||
w.Write([]byte("test"))
|
w.Write([]byte("test"))
|
||||||
})
|
})
|
||||||
|
|
||||||
var authFunc = func(email, password string) (bool, error) {
|
var authFunc = func(email, password string) error {
|
||||||
return true, nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var claimsFunc = func(id string) (map[string]interface{}, error) {
|
var claimsFunc = func(id string) (map[string]interface{}, error) {
|
||||||
|
@ -32,7 +32,7 @@ var claimsFunc = func(id string) (map[string]interface{}, error) {
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var verifyClaimsFunc = func(claims []byte) (bool, error) {
|
var verifyClaimsFunc = func(claims []byte) error {
|
||||||
currentTime := time.Now()
|
currentTime := time.Now()
|
||||||
var c struct {
|
var c struct {
|
||||||
Exp int64
|
Exp int64
|
||||||
|
@ -40,12 +40,12 @@ var verifyClaimsFunc = func(claims []byte) (bool, error) {
|
||||||
}
|
}
|
||||||
err := json.Unmarshal(claims, &c)
|
err := json.Unmarshal(claims, &c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return err
|
||||||
}
|
}
|
||||||
if currentTime.After(time.Unix(c.Exp, 0)) {
|
if currentTime.After(time.Unix(c.Exp, 0)) {
|
||||||
return false, errors.New("expired")
|
return errors.New("expired")
|
||||||
}
|
}
|
||||||
return true, nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func newJWTMiddlewareOrFatal(t *testing.T) *JWTMiddleware {
|
func newJWTMiddlewareOrFatal(t *testing.T) *JWTMiddleware {
|
||||||
|
@ -89,13 +89,10 @@ func TestNewJWTMiddleware(t *testing.T) {
|
||||||
if middleware.secret != "password" {
|
if middleware.secret != "password" {
|
||||||
t.Errorf("wanted password, got %v", middleware.secret)
|
t.Errorf("wanted password, got %v", middleware.secret)
|
||||||
}
|
}
|
||||||
authVal, err := middleware.auth("", "")
|
err := middleware.auth("", "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if authVal != true {
|
|
||||||
t.Errorf("wanted true, got %v", authVal)
|
|
||||||
}
|
|
||||||
claimsVal, err := middleware.claims("1")
|
claimsVal, err := middleware.claims("1")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
|
|
Loading…
Add table
Reference in a new issue