diff --git a/api/handler.go b/api/handler.go index 479bc27..fcd6d59 100644 --- a/api/handler.go +++ b/api/handler.go @@ -62,6 +62,7 @@ func Handler() *mux.Router { m.Get(router.UnitType).Handler(handler(serveUnitType)) m.Get(router.CreateUnitType).Handler(handler(serveCreateUnitType)) m.Get(router.UnitTypes).Handler(handler(serveUnitTypeList)) + m.Get(router.UpdateUnitType).Handler(handler(serveUpdateUnitType)) return m } diff --git a/api/unit_types.go b/api/unit_types.go index 34d6bb6..306be97 100644 --- a/api/unit_types.go +++ b/api/unit_types.go @@ -57,3 +57,22 @@ func serveUnitTypeList(w http.ResponseWriter, r *http.Request) error { return writeJSON(w, unit_types) } + +func serveUpdateUnitType(w http.ResponseWriter, r *http.Request) error { + id, _ := strconv.ParseInt(mux.Vars(r)["Id"], 10, 0) + var unit_type models.UnitType + err := json.NewDecoder(r.Body).Decode(&unit_type) + if err != nil { + return err + } + + updated, err := store.UnitTypes.Update(id, &unit_type) + if err != nil { + return err + } + if updated { + w.WriteHeader(http.StatusOK) + } + + return writeJSON(w, unit_type) +} diff --git a/api/unit_types_test.go b/api/unit_types_test.go index 385732f..09229e6 100644 --- a/api/unit_types_test.go +++ b/api/unit_types_test.go @@ -94,3 +94,33 @@ func TestUnitType_List(t *testing.T) { t.Errorf("got unit_types %+v but wanted unit_types %+v", unit_types, want) } } + +func TestUnitType_Update(t *testing.T) { + setup() + + want := newUnitType() + + calledPut := false + store.UnitTypes.(*models.MockUnitTypesService).Update_ = func(id int64, unit_type *models.UnitType) (bool, error) { + if id != want.Id { + t.Errorf("wanted request for unit_type %d but got %d", want.Id, id) + } + if !normalizeDeepEqual(want, unit_type) { + t.Errorf("wanted request for unit_type %d but got %d", want, unit_type) + } + calledPut = true + return true, nil + } + + success, err := apiClient.UnitTypes.Update(want.Id, want) + if err != nil { + t.Fatal(err) + } + + if !calledPut { + t.Error("!calledPut") + } + if !success { + t.Error("!success") + } +} diff --git a/datastore/unit_types.go b/datastore/unit_types.go index 621ba60..a53c816 100644 --- a/datastore/unit_types.go +++ b/datastore/unit_types.go @@ -46,3 +46,26 @@ func (s *unitTypesStore) List(opt *models.UnitTypeListOptions) ([]*models.UnitTy } return unit_types, nil } + +func (s *unitTypesStore) Update(id int64, unit_type *models.UnitType) (bool, error) { + _, err := s.Get(id) + if err != nil { + return false, err + } + + if id != unit_type.Id { + return false, models.ErrUnitTypeNotFound + } + + unit_type.UpdatedAt = time.Now() + changed, err := s.dbh.Update(unit_type) + if err != nil { + return false, err + } + + if changed == 0 { + return false, ErrNoRowsUpdated + } + + return true, nil +} diff --git a/datastore/unit_types_test.go b/datastore/unit_types_test.go index 7133a17..3675785 100644 --- a/datastore/unit_types_test.go +++ b/datastore/unit_types_test.go @@ -85,3 +85,23 @@ func TestUnitTypesStore_List_db(t *testing.T) { t.Errorf("got unit_types %+v, want %+v", unit_types, want) } } + +func TestUnitTypesStore_Update_db(t *testing.T) { + tx, _ := DB.Begin() + defer tx.Rollback() + + unit_type := insertUnitType(t, tx) + + d := NewDatastore(tx) + + // Tweak it + unit_type.Name = "Updated Unit Type" + updated, err := d.UnitTypes.Update(unit_type.Id, unit_type) + if err != nil { + t.Fatal(err) + } + + if !updated { + t.Error("!updated") + } +} diff --git a/models/unit_types.go b/models/unit_types.go index 30ca352..8fb3478 100644 --- a/models/unit_types.go +++ b/models/unit_types.go @@ -36,6 +36,9 @@ type UnitTypesService interface { // Create a unit type Create(unit_type *UnitType) (bool, error) + + // Update a unit type + Update(id int64, UnitType *UnitType) (bool, error) } var ( @@ -111,10 +114,32 @@ func (s *unitTypesService) List(opt *UnitTypeListOptions) ([]*UnitType, error) { return unit_types, nil } +func (s *unitTypesService) Update(id int64, unit_type *UnitType) (bool, error) { + strId := strconv.FormatInt(id, 10) + + url, err := s.client.url(router.UpdateUnitType, map[string]string{"Id": strId}, nil) + if err != nil { + return false, err + } + + req, err := s.client.NewRequest("PUT", url.String(), unit_type) + if err != nil { + return false, err + } + + resp, err := s.client.Do(req, &unit_type) + if err != nil { + return false, err + } + + return resp.StatusCode == http.StatusOK, nil +} + type MockUnitTypesService struct { Get_ func(id int64) (*UnitType, error) List_ func(opt *UnitTypeListOptions) ([]*UnitType, error) Create_ func(unit_type *UnitType) (bool, error) + Update_ func(id int64, unit_type *UnitType) (bool, error) } var _ UnitTypesService = &MockUnitTypesService{} @@ -139,3 +164,10 @@ func (s *MockUnitTypesService) List(opt *UnitTypeListOptions) ([]*UnitType, erro } return s.List_(opt) } + +func (s *MockUnitTypesService) Update(id int64, unit_type *UnitType) (bool, error) { + if s.Update_ == nil { + return false, nil + } + return s.Update_(id, unit_type) +} diff --git a/models/unit_types_test.go b/models/unit_types_test.go index 1bcd154..1e76a43 100644 --- a/models/unit_types_test.go +++ b/models/unit_types_test.go @@ -112,3 +112,34 @@ func TestUnitTypeService_List(t *testing.T) { t.Errorf("UnitTypes.List return %+v, want %+v", unit_types, want) } } + +func TestUnitTypeService_Update(t *testing.T) { + setup() + defer teardown() + + want := newUnitType() + + var called bool + mux.HandleFunc(urlPath(t, router.UpdateUnitType, map[string]string{"Id": "1"}), func(w http.ResponseWriter, r *http.Request) { + called = true + testMethod(t, r, "PUT") + testBody(t, r, `{"id":1,"name":"Test Unit Type Updated","symbol":"x","createdAt":"0001-01-01T00:00:00Z","updatedAt":"0001-01-01T00:00:00Z","deletedAt":{"Time":"0001-01-01T00:00:00Z","Valid":false}}`+"\n") + w.WriteHeader(http.StatusOK) + writeJSON(w, want) + }) + + unit_type := newUnitType() + unit_type.Name = "Test Unit Type Updated" + updated, err := client.UnitTypes.Update(unit_type.Id, unit_type) + if err != nil { + t.Errorf("UnitTypes.Update returned error: %v", err) + } + + if !updated { + t.Error("!updated") + } + + if !called { + t.Fatal("!called") + } +} diff --git a/router/api.go b/router/api.go index 78a4e41..9c9954f 100644 --- a/router/api.go +++ b/router/api.go @@ -56,6 +56,7 @@ func API() *mux.Router { m.Path("/unit_types/").Methods("GET").Name(UnitTypes) m.Path("/unit_types/").Methods("POST").Name(CreateUnitType) m.Path("/unit_types/{Id:.+}").Methods("GET").Name(UnitType) + m.Path("/unit_types/{Id:.+}").Methods("PUT").Name(UpdateUnitType) return m } diff --git a/router/routes.go b/router/routes.go index 031d42c..29e7317 100644 --- a/router/routes.go +++ b/router/routes.go @@ -44,4 +44,5 @@ const ( UnitType = "unit_type:get" CreateUnitType = "unit_type:create" UnitTypes = "unit_type:list" + UpdateUnitType = "unit_type:update" )