diff --git a/helpers.go b/helpers.go index 2be7d1e..c96f8e8 100644 --- a/helpers.go +++ b/helpers.go @@ -1,6 +1,7 @@ package main import ( + "fmt" "time" "github.com/lib/pq" @@ -8,24 +9,24 @@ import ( // ListOptions specifies general pagination options for fetching a list of results type ListOptions struct { - PerPage int `url:",omitempty" json:",omitempty"` - Page int `url:",omitempty" json:",omitempty"` - Ids []int `url:",omitempty" json:",omitempty" schema:"ids[]"` + PerPage int64 `url:",omitempty" json:",omitempty"` + Page int64 `url:",omitempty" json:",omitempty"` + Ids []int64 `url:",omitempty" json:",omitempty" schema:"ids[]"` Genus string } -func (o ListOptions) PageOrDefault() int { +func (o ListOptions) PageOrDefault() int64 { if o.Page <= 0 { return 1 } return o.Page } -func (o ListOptions) Offset() int { +func (o ListOptions) Offset() int64 { return (o.PageOrDefault() - 1) * o.PerPageOrDefault() } -func (o ListOptions) PerPageOrDefault() int { +func (o ListOptions) PerPageOrDefault() int64 { if o.PerPage <= 0 { return DefaultPerPage } @@ -43,3 +44,18 @@ func currentTime() NullTime { }, } } + +func valsIn(attribute string, values []int64, vals *[]interface{}, counter *int64) string { + if len(values) == 1 { + return fmt.Sprintf("%v=%v", attribute, values[0]) + } + + m := fmt.Sprintf("%v IN (", attribute) + for _, id := range values { + m = m + fmt.Sprintf("$%v,", *counter) + *vals = append(*vals, id) + *counter++ + } + m = m[:len(m)-1] + ")" + return m +} diff --git a/measurements.go b/measurements.go index 3232a1f..7c29205 100644 --- a/measurements.go +++ b/measurements.go @@ -3,9 +3,7 @@ package main import ( "encoding/json" "errors" - "fmt" "net/url" - "strings" "time" ) @@ -70,8 +68,8 @@ func (m MeasurementService) list(val *url.Values) (entity, error) { } var opt struct { ListOptions - Strain *int64 - Characteristic *int64 + Strains []int64 `schema:"strain[]"` + Characteristics []int64 `schema:"characteristic[]"` } if err := schemaDecoder.Decode(&opt, *val); err != nil { return nil, err @@ -90,49 +88,38 @@ func (m MeasurementService) list(val *url.Values) (entity, error) { LEFT OUTER JOIN test_methods te ON te.id=m.test_method_id` vals = append(vals, opt.Genus) - strainId := opt.Strain != nil - charId := opt.Characteristic != nil + strainIds := len(opt.Strains) != 0 + charIds := len(opt.Characteristics) != 0 ids := len(opt.Ids) != 0 - if strainId || charId || ids { - paramsCounter := 2 + if strainIds || charIds || ids { + var paramsCounter int64 = 2 sql += "\nWHERE (" - // Filter by strain - if strainId { - sql += fmt.Sprintf("st.id=$%v", paramsCounter) - vals = append(vals, *opt.Strain) - paramsCounter++ + // Filter by strains + if strainIds { + sStr := valsIn("st.id", opt.Strains, &vals, ¶msCounter) + sql += sStr } - if strainId && (charId || ids) { + if strainIds && (charIds || ids) { sql += " AND " } - // Filter by characteristic - if charId { - sql += fmt.Sprintf("c.id=$%v", paramsCounter) - vals = append(vals, *opt.Characteristic) - paramsCounter++ + // Filter by characteristics + if charIds { + sChar := valsIn("c.id", opt.Characteristics, &vals, ¶msCounter) + sql += sChar } - if charId && ids { + if charIds && ids { sql += " AND " } // Get specific records if ids { - var conds []string - - m := "m.id IN (" - for _, id := range opt.Ids { - m = m + fmt.Sprintf("$%v,", paramsCounter) - vals = append(vals, id) - paramsCounter++ - } - m = m[:len(m)-1] + ")" - conds = append(conds, m) - sql += strings.Join(conds, ") AND (") + sId := valsIn("m.id", opt.Ids, &vals, ¶msCounter) + sql += sId } sql += ")" }