diff --git a/Godeps/Godeps.json b/Godeps/Godeps.json index 0d0d486..1dcd5db 100644 --- a/Godeps/Godeps.json +++ b/Godeps/Godeps.json @@ -54,7 +54,7 @@ }, { "ImportPath": "github.com/thermokarst/jwt", - "Rev": "7752009bbb5cea39ab392a846c467eab4b98478f" + "Rev": "88ac9569ee8c8fc9083704a7219334fcc210c6a5" }, { "ImportPath": "golang.org/x/crypto/bcrypt", diff --git a/Godeps/_workspace/src/github.com/thermokarst/jwt/jwt.go b/Godeps/_workspace/src/github.com/thermokarst/jwt/jwt.go index a42742b..cbb2515 100644 --- a/Godeps/_workspace/src/github.com/thermokarst/jwt/jwt.go +++ b/Godeps/_workspace/src/github.com/thermokarst/jwt/jwt.go @@ -115,64 +115,12 @@ func (m *Middleware) Secure(h http.Handler, v VerifyClaimsFunc) http.Handler { } else { token = strings.Split(authHeader, " ")[1] } - tokenParts := strings.Split(token, ".") - if len(tokenParts) != 3 { - return &jwtError{status: http.StatusUnauthorized, err: ErrMalformedToken} - } - // First, verify JOSE header - header, err := decode(tokenParts[0]) - if err != nil { + if status, err, message := m.VerifyToken(token, v, r); err != nil { return &jwtError{ - status: http.StatusInternalServerError, + status: status, err: err, - message: fmt.Sprintf("decoding header (%v)", tokenParts[0]), - } - } - var t struct { - Typ string - Alg string - } - err = json.Unmarshal(header, &t) - if err != nil { - return &jwtError{ - status: http.StatusInternalServerError, - err: ErrMalformedToken, - message: fmt.Sprintf("unmarshalling header (%s)", header), - } - } - - // Then, verify signature - mac := hmac.New(sha256.New, []byte(m.secret)) - message := []byte(strings.Join([]string{tokenParts[0], tokenParts[1]}, ".")) - mac.Write(message) - expectedMac, err := encode(mac.Sum(nil)) - if err != nil { - return &jwtError{status: http.StatusInternalServerError, err: err} - } - if !hmac.Equal([]byte(tokenParts[2]), []byte(expectedMac)) { - return &jwtError{ - status: http.StatusUnauthorized, - err: ErrInvalidSignature, - message: fmt.Sprintf("checking signature (%v)", tokenParts[2]), - } - } - - // Finally, check claims - claimSet, err := decode(tokenParts[1]) - if err != nil { - return &jwtError{ - status: http.StatusInternalServerError, - err: ErrDecoding, - message: "decoding claims", - } - } - err = v(claimSet, r) - if err != nil { - return &jwtError{ - status: http.StatusUnauthorized, - err: err, - message: "handling claims callback", + message: message, } } @@ -279,6 +227,52 @@ func (m *Middleware) CreateToken(identity string) (string, error) { return response, nil } +// VerifyToken verifies a token +func (m *Middleware) VerifyToken(token string, v VerifyClaimsFunc, r *http.Request) (int, error, string) { + tokenParts := strings.Split(token, ".") + if len(tokenParts) != 3 { + return http.StatusUnauthorized, ErrMalformedToken, "" + } + + // First, verify JOSE header + header, err := decode(tokenParts[0]) + if err != nil { + return http.StatusInternalServerError, err, fmt.Sprintf("decoding header (%v)", tokenParts[0]) + } + var t struct { + Typ string + Alg string + } + err = json.Unmarshal(header, &t) + if err != nil { + return http.StatusInternalServerError, ErrMalformedToken, fmt.Sprintf("unmarshalling header (%s)", header) + } + + // Then, verify signature + mac := hmac.New(sha256.New, []byte(m.secret)) + message := []byte(strings.Join([]string{tokenParts[0], tokenParts[1]}, ".")) + mac.Write(message) + expectedMac, err := encode(mac.Sum(nil)) + if err != nil { + return http.StatusInternalServerError, err, "" + } + if !hmac.Equal([]byte(tokenParts[2]), []byte(expectedMac)) { + return http.StatusUnauthorized, ErrInvalidSignature, fmt.Sprintf("checking signature (%v)", tokenParts[2]) + } + + // Finally, check claims + claimSet, err := decode(tokenParts[1]) + if err != nil { + return http.StatusInternalServerError, ErrDecoding, "decoding claims" + } + err = v(claimSet, r) + if err != nil { + return http.StatusUnauthorized, err, "handling claims callback" + } + + return 200, nil, "" +} + type jwtError struct { status int err error diff --git a/handlers.go b/handlers.go index 73dfd21..258abc6 100644 --- a/handlers.go +++ b/handlers.go @@ -33,6 +33,20 @@ type Claims struct { Ref string } +func verifyClaims(claims []byte, r *http.Request) error { + currentTime := time.Now() + var c Claims + err := json.Unmarshal(claims, &c) + if err != nil { + return err + } + if currentTime.After(time.Unix(c.Exp, 0)) { + return errors.New("this token has expired") + } + context.Set(r, "claims", c) + return nil +} + func Handler() http.Handler { claimsFunc := func(email string) (map[string]interface{}, error) { currentTime := time.Now() @@ -46,25 +60,11 @@ func Handler() http.Handler { "sub": user.Id, "role": user.Role, "iat": currentTime.Unix(), - "exp": currentTime.Add(time.Minute * 60 * 24).Unix(), + "exp": currentTime.Add(time.Minute * 60).Unix(), "ref": "", }, nil } - verifyClaims := func(claims []byte, r *http.Request) error { - currentTime := time.Now() - var c Claims - err := json.Unmarshal(claims, &c) - if err != nil { - return err - } - if currentTime.After(time.Unix(c.Exp, 0)) { - return errors.New("this token has expired") - } - context.Set(r, "claims", c) - return nil - } - config = &jwt.Config{ Secret: os.Getenv("SECRET"), Auth: dbAuthenticate, @@ -85,6 +85,7 @@ func Handler() http.Handler { measurementService := MeasurementService{} m.Handle("/authenticate", tokenHandler(j.Authenticate())).Methods("POST") + m.Handle("/refresh", j.Secure(errorHandler(tokenRefresh(j)), verifyClaims)).Methods("POST") // Everything past here is lumped under a genus s := m.PathPrefix("/{genus}").Subrouter() @@ -307,3 +308,25 @@ func (fn errorHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { fmt.Fprintln(w, err.Error.Error()) } } + +func tokenRefresh(j *jwt.Middleware) errorHandler { + t := func(w http.ResponseWriter, r *http.Request) *appError { + claims := getClaims(r) + user, err := dbGetUserById(claims.Sub) + if err != nil { + return newJSONError(err, http.StatusInternalServerError) + } + token, err := j.CreateToken(user.Email) + if err != nil { + return newJSONError(err, http.StatusInternalServerError) + } + data, _ := json.Marshal(struct { + Token string `json:"token"` + }{ + Token: token, + }) + w.Write(data) + return nil + } + return t +} diff --git a/users.go b/users.go index d0c7701..3f1fd22 100644 --- a/users.go +++ b/users.go @@ -138,20 +138,11 @@ func (u UserService) list(val *url.Values, claims *Claims) (entity, *appError) { } func (u UserService) get(id int64, dummy string, claims *Claims) (entity, *appError) { - var user User - q := `SELECT id, email, 'password' AS password, name, role, - created_at, updated_at, deleted_at - FROM users - WHERE id=$1 - AND verified IS TRUE - AND deleted_at IS NULL;` - if err := DBH.SelectOne(&user, q, id); err != nil { - if err == sql.ErrNoRows { - return nil, ErrUserNotFoundJSON - } + user, err := dbGetUserById(id) + if err != nil { return nil, newJSONError(err, http.StatusInternalServerError) } - return &user, nil + return user, nil } func (u UserService) update(id int64, e *entity, dummy string, claims *Claims) *appError { @@ -245,6 +236,23 @@ func dbAuthenticate(email string, password string) error { return nil } +func dbGetUserById(id int64) (*User, error) { + var user User + q := `SELECT id, email, 'password' AS password, name, role, + created_at, updated_at, deleted_at + FROM users + WHERE id=$1 + AND verified IS TRUE + AND deleted_at IS NULL;` + if err := DBH.SelectOne(&user, q, id); err != nil { + if err == sql.ErrNoRows { + return nil, ErrUserNotFound + } + return nil, err + } + return &user, nil +} + // for thermokarst/jwt: setting user in claims bundle func dbGetUserByEmail(email string) (*User, error) { var user User