diff --git a/api/handler.go b/api/handler.go index 928f165..63005cf 100644 --- a/api/handler.go +++ b/api/handler.go @@ -38,6 +38,7 @@ func Handler() *mux.Router { m.Get(router.Strain).Handler(handler(serveStrain)) m.Get(router.CreateStrain).Handler(handler(serveCreateStrain)) m.Get(router.Strains).Handler(handler(serveStrainList)) + m.Get(router.UpdateStrain).Handler(handler(serveUpdateStrain)) return m } diff --git a/api/strains.go b/api/strains.go index 159796b..a3bf4e6 100644 --- a/api/strains.go +++ b/api/strains.go @@ -57,3 +57,22 @@ func serveStrainList(w http.ResponseWriter, r *http.Request) error { return writeJSON(w, strains) } + +func serveUpdateStrain(w http.ResponseWriter, r *http.Request) error { + id, _ := strconv.ParseInt(mux.Vars(r)["Id"], 10, 0) + var strain models.Strain + err := json.NewDecoder(r.Body).Decode(&strain) + if err != nil { + return err + } + + updated, err := store.Strains.Update(id, &strain) + if err != nil { + return err + } + if updated { + w.WriteHeader(http.StatusOK) + } + + return writeJSON(w, strain) +} diff --git a/api/strains_test.go b/api/strains_test.go index d0dd5cb..6b931be 100644 --- a/api/strains_test.go +++ b/api/strains_test.go @@ -96,3 +96,33 @@ func TestStrain_List(t *testing.T) { t.Errorf("got strains %+v but wanted strains %+v", strains, want) } } + +func TestStrain_Update(t *testing.T) { + setup() + + want := newStrain() + + calledPut := false + store.Strains.(*models.MockStrainsService).Update_ = func(id int64, strain *models.Strain) (bool, error) { + if id != want.Id { + t.Errorf("wanted request for strain %d but got %d", want.Id, id) + } + if !normalizeDeepEqual(want, strain) { + t.Errorf("wanted request for strain %d but got %d", want, strain) + } + calledPut = true + return true, nil + } + + success, err := apiClient.Strains.Update(1, want) + if err != nil { + t.Fatal(err) + } + + if !calledPut { + t.Error("!calledPut") + } + if !success { + t.Error("!success") + } +} diff --git a/datastore/strains.go b/datastore/strains.go index 25c77ba..d2450e7 100644 --- a/datastore/strains.go +++ b/datastore/strains.go @@ -39,3 +39,25 @@ func (s *strainsStore) List(opt *models.StrainListOptions) ([]*models.Strain, er } return strains, nil } + +func (s *strainsStore) Update(id int64, strain *models.Strain) (bool, error) { + _, err := s.Get(id) + if err != nil { + return false, err + } + + if id != strain.Id { + return false, models.ErrStrainNotFound + } + + changed, err := s.dbh.Update(strain) + if err != nil { + return false, err + } + + if changed == 0 { + return false, ErrNoRowsUpdated + } + + return true, nil +} diff --git a/datastore/strains_test.go b/datastore/strains_test.go index bc2c6bc..2d66773 100644 --- a/datastore/strains_test.go +++ b/datastore/strains_test.go @@ -87,3 +87,23 @@ func TestStrainsStore_List_db(t *testing.T) { t.Errorf("got strains %+v, want %+v", strains, want) } } + +func TestStrainsStore_Update_db(t *testing.T) { + tx, _ := DB.Begin() + defer tx.Rollback() + + strain := insertStrain(t, tx) + + d := NewDatastore(tx) + + // Tweak it + strain.StrainName = "Updated Strain" + updated, err := d.Strains.Update(strain.Id, strain) + if err != nil { + t.Fatal(err) + } + + if !updated { + t.Error("!updated") + } +} diff --git a/models/strains.go b/models/strains.go index d6b044e..c57a56e 100644 --- a/models/strains.go +++ b/models/strains.go @@ -17,7 +17,7 @@ type Strain struct { StrainType string `db:"strain_type" json:"strain_type"` Etymology string `db:"etymology" json:"etymology"` AccessionBanks string `db:"accession_banks" json:"accession_banks"` - GenbankEmblDdb string `db:"genbank_embl_ddb" json:"genbank_eml_ddb"` + GenbankEmblDdb string `db:"genbank_embl_ddb" json:"genbank_embl_ddb"` CreatedAt time.Time `db:"created_at" json:"created_at"` UpdatedAt time.Time `db:"updated_at" json:"updated_at"` DeletedAt time.Time `db:"deleted_at" json:"deleted_at"` @@ -37,6 +37,9 @@ type StrainsService interface { // Create a strain record Create(strain *Strain) (bool, error) + + // Update an existing strain + Update(id int64, strain *Strain) (updated bool, err error) } var ( @@ -112,10 +115,32 @@ func (s *strainsService) List(opt *StrainListOptions) ([]*Strain, error) { return strains, nil } +func (s *strainsService) Update(id int64, strain *Strain) (bool, error) { + strId := strconv.FormatInt(id, 10) + + url, err := s.client.url(router.UpdateStrain, map[string]string{"Id": strId}, nil) + if err != nil { + return false, err + } + + req, err := s.client.NewRequest("PUT", url.String(), strain) + if err != nil { + return false, err + } + + resp, err := s.client.Do(req, &strain) + if err != nil { + return false, err + } + + return resp.StatusCode == http.StatusOK, nil +} + type MockStrainsService struct { Get_ func(id int64) (*Strain, error) List_ func(opt *StrainListOptions) ([]*Strain, error) Create_ func(strain *Strain) (bool, error) + Update_ func(id int64, strain *Strain) (bool, error) } var _ StrainsService = &MockStrainsService{} @@ -140,3 +165,10 @@ func (s *MockStrainsService) List(opt *StrainListOptions) ([]*Strain, error) { } return s.List_(opt) } + +func (s *MockStrainsService) Update(id int64, strain *Strain) (bool, error) { + if s.Update_ == nil { + return false, nil + } + return s.Update_(id, strain) +} diff --git a/models/strains_test.go b/models/strains_test.go index d0ee0e6..6531a05 100644 --- a/models/strains_test.go +++ b/models/strains_test.go @@ -55,7 +55,7 @@ func TestStrainService_Create(t *testing.T) { mux.HandleFunc(urlPath(t, router.CreateStrain, nil), func(w http.ResponseWriter, r *http.Request) { called = true testMethod(t, r, "POST") - testBody(t, r, `{"id":1,"species_id":1,"strain_name":"Test Strain","strain_type":"Test Type","etymology":"Test Etymology","accession_banks":"Test Accession","genbank_eml_ddb":"Test Genbank","created_at":"0001-01-01T00:00:00Z","updated_at":"0001-01-01T00:00:00Z","deleted_at":"0001-01-01T00:00:00Z"}`+"\n") + testBody(t, r, `{"id":1,"species_id":1,"strain_name":"Test Strain","strain_type":"Test Type","etymology":"Test Etymology","accession_banks":"Test Accession","genbank_embl_ddb":"Test Genbank","created_at":"0001-01-01T00:00:00Z","updated_at":"0001-01-01T00:00:00Z","deleted_at":"0001-01-01T00:00:00Z"}`+"\n") w.WriteHeader(http.StatusCreated) writeJSON(w, want) @@ -113,3 +113,38 @@ func TestStrainService_List(t *testing.T) { t.Errorf("Strains.List return %+v, want %+v", strains, want) } } + +func TestStrainService_Update(t *testing.T) { + setup() + defer teardown() + + want := newStrain() + + var called bool + mux.HandleFunc(urlPath(t, router.UpdateStrain, map[string]string{"Id": "1"}), func(w http.ResponseWriter, r *http.Request) { + called = true + testMethod(t, r, "PUT") + testBody(t, r, `{"id":1,"species_id":1,"strain_name":"Test Strain Updated","strain_type":"Test Type Updated","etymology":"Test Etymology Updated","accession_banks":"Test Accession Updated","genbank_embl_ddb":"Test Genbank Updated","created_at":"0001-01-01T00:00:00Z","updated_at":"0001-01-01T00:00:00Z","deleted_at":"0001-01-01T00:00:00Z"}`+"\n") + w.WriteHeader(http.StatusOK) + writeJSON(w, want) + }) + + strain := newStrain() + strain.StrainName = "Test Strain Updated" + strain.StrainType = "Test Type Updated" + strain.Etymology = "Test Etymology Updated" + strain.AccessionBanks = "Test Accession Updated" + strain.GenbankEmblDdb = "Test Genbank Updated" + updated, err := client.Strains.Update(strain.Id, strain) + if err != nil { + t.Errorf("Strains.Update returned error: %v", err) + } + + if !updated { + t.Error("!updated") + } + + if !called { + t.Fatal("!called") + } +} diff --git a/router/api.go b/router/api.go index 7678abc..2f27ae3 100644 --- a/router/api.go +++ b/router/api.go @@ -28,6 +28,7 @@ func API() *mux.Router { m.Path("/strains").Methods("GET").Name(Strains) m.Path("/strains").Methods("POST").Name(CreateStrain) m.Path("/strains/{Id:.+}").Methods("GET").Name(Strain) + m.Path("/strains/{Id:.+}").Methods("PUT").Name(UpdateStrain) return m } diff --git a/router/routes.go b/router/routes.go index 068b179..8658634 100644 --- a/router/routes.go +++ b/router/routes.go @@ -20,4 +20,5 @@ const ( Strain = "strain:get" CreateStrain = "strain:create" Strains = "strain:list" + UpdateStrain = "strain:update" )