diff --git a/models_test.go b/models_test.go index a53dd61..d443378 100644 --- a/models_test.go +++ b/models_test.go @@ -155,3 +155,24 @@ func (bc *BadComment) JSONAPILinks() *Links { "self": []string{"invalid", "should error"}, } } + +type Company struct { + ID string `jsonapi:"primary,companies"` + Name string `jsonapi:"attr,name"` + Boss Employee `jsonapi:"attr,boss"` + Teams []Team `jsonapi:"attr,teams"` + FoundedAt time.Time `jsonapi:"attr,founded-at,iso8601"` +} + +type Team struct { + Name string `jsonapi:"attr,name"` + Leader *Employee `jsonapi:"attr,leader"` + Members []Employee `jsonapi:"attr,members"` +} + +type Employee struct { + Firstname string `jsonapi:"attr,firstname"` + Surname string `jsonapi:"attr,surname"` + Age int `jsonapi:"attr,age"` + HiredAt *time.Time `jsonapi:"attr,hired-at,iso8601"` +} diff --git a/request.go b/request.go index 104eb78..46863e9 100644 --- a/request.go +++ b/request.go @@ -117,13 +117,6 @@ func UnmarshalManyPayload(in io.Reader, t reflect.Type) ([]interface{}, error) { return models, nil } -type unmarshal struct { - attribute interface{} - args []string - fieldType reflect.StructField - fieldValue reflect.Value -} - func unmarshalNode(data *Node, model reflect.Value, included *map[string]*Node) (err error) { defer func() { @@ -147,7 +140,6 @@ func unmarshalNode(data *Node, model reflect.Value, included *map[string]*Node) fieldValue := modelValue.Field(i) args := strings.Split(tag, ",") - if len(args) < 1 { er = ErrBadJSONAPIStructTag break @@ -264,14 +256,7 @@ func unmarshalNode(data *Node, model reflect.Value, included *map[string]*Node) continue } - data := unmarshal{ - attribute, - args, - fieldType, - fieldValue, - } - - value, err := unmarshalAttribute(data) + value, err := unmarshalAttribute(attribute, args, fieldType.Type, fieldValue) if err != nil { er = err break @@ -378,36 +363,54 @@ func assign(field, value reflect.Value) { } } -func unmarshalAttribute(data unmarshal) (value reflect.Value, err error) { +func unmarshalAttribute(attribute interface{}, args []string, fieldType reflect.Type, fieldValue reflect.Value) (value reflect.Value, err error) { - value = reflect.ValueOf(data.attribute) + value = reflect.ValueOf(attribute) // Handle field of type []string - if data.fieldValue.Type() == reflect.TypeOf([]string{}) { - value, err = handleStringSlice(data) + if fieldValue.Type() == reflect.TypeOf([]string{}) { + value, err = handleStringSlice(attribute, args, fieldType, fieldValue) return } // Handle field of type time.Time - if data.fieldValue.Type() == reflect.TypeOf(time.Time{}) || data.fieldValue.Type() == reflect.TypeOf(new(time.Time)) { - value, err = handleTime(data) + if fieldValue.Type() == reflect.TypeOf(time.Time{}) || fieldValue.Type() == reflect.TypeOf(new(time.Time)) { + value, err = handleTime(attribute, args, fieldType, fieldValue) + return + } + + // Handle field of type struct + if fieldValue.Type().Kind() == reflect.Struct { + value, err = handleStruct(attribute, args, fieldType, fieldValue) + return + } + + // Handle field of type struct + if fieldValue.Type().Kind() == reflect.Struct { + value, err = handleStruct(attribute, args, fieldType, fieldValue) + return + } + + // Handle field containing slice of structs + if fieldValue.Type().Kind() == reflect.Slice && reflect.TypeOf(fieldValue.Interface()).Elem().Kind() == reflect.Struct { + value, err = handleStructSlice(attribute, args, fieldType, fieldValue) return } // JSON value was a float (numeric) if value.Kind() == reflect.Float64 { - value, err = handleNumeric(data) + value, err = handleNumeric(attribute, args, fieldType, fieldValue) return } // Field was a Pointer type - if data.fieldValue.Kind() == reflect.Ptr { - value, err = handlePointer(data) + if fieldValue.Kind() == reflect.Ptr { + value, err = handlePointer(attribute, args, fieldType, fieldValue) return } // As a final catch-all, ensure types line up to avoid a runtime panic. - if data.fieldValue.Kind() != value.Kind() { + if fieldValue.Kind() != value.Kind() { err = ErrInvalidType return } @@ -415,8 +418,8 @@ func unmarshalAttribute(data unmarshal) (value reflect.Value, err error) { return } -func handleStringSlice(data unmarshal) (reflect.Value, error) { - v := reflect.ValueOf(data.attribute) +func handleStringSlice(attribute interface{}, args []string, fieldType reflect.Type, fieldValue reflect.Value) (reflect.Value, error) { + v := reflect.ValueOf(attribute) values := make([]string, v.Len()) for i := 0; i < v.Len(); i++ { values[i] = v.Index(i).Interface().(string) @@ -425,13 +428,13 @@ func handleStringSlice(data unmarshal) (reflect.Value, error) { return reflect.ValueOf(values), nil } -func handleTime(data unmarshal) (reflect.Value, error) { +func handleTime(attribute interface{}, args []string, fieldType reflect.Type, fieldValue reflect.Value) (reflect.Value, error) { var isIso8601 bool - v := reflect.ValueOf(data.attribute) + v := reflect.ValueOf(attribute) - if len(data.args) > 2 { - for _, arg := range data.args[2:] { + if len(args) > 2 { + for _, arg := range args[2:] { if arg == annotationISO8601 { isIso8601 = true } @@ -451,7 +454,7 @@ func handleTime(data unmarshal) (reflect.Value, error) { return reflect.ValueOf(time.Now()), ErrInvalidISO8601 } - if data.fieldValue.Kind() == reflect.Ptr { + if fieldValue.Kind() == reflect.Ptr { return reflect.ValueOf(&t), nil } @@ -473,15 +476,15 @@ func handleTime(data unmarshal) (reflect.Value, error) { return reflect.ValueOf(t), nil } -func handleNumeric(data unmarshal) (reflect.Value, error) { - v := reflect.ValueOf(data.attribute) +func handleNumeric(attribute interface{}, args []string, fieldType reflect.Type, fieldValue reflect.Value) (reflect.Value, error) { + v := reflect.ValueOf(attribute) floatValue := v.Interface().(float64) var kind reflect.Kind - if data.fieldValue.Kind() == reflect.Ptr { - kind = data.fieldType.Type.Elem().Kind() + if fieldValue.Kind() == reflect.Ptr { + kind = fieldType.Elem().Kind() } else { - kind = data.fieldType.Type.Kind() + kind = fieldType.Kind() } var numericValue reflect.Value @@ -530,11 +533,11 @@ func handleNumeric(data unmarshal) (reflect.Value, error) { return numericValue, nil } -func handlePointer(data unmarshal) (reflect.Value, error) { - t := data.fieldValue.Type() +func handlePointer(attribute interface{}, args []string, fieldType reflect.Type, fieldValue reflect.Value) (reflect.Value, error) { + t := fieldValue.Type() var concreteVal reflect.Value - switch cVal := data.attribute.(type) { + switch cVal := attribute.(type) { case string: concreteVal = reflect.ValueOf(&cVal) case bool: @@ -555,3 +558,68 @@ func handlePointer(data unmarshal) (reflect.Value, error) { return concreteVal, nil } + +func handleStruct(attribute interface{}, args []string, fieldType reflect.Type, fieldValue reflect.Value) (reflect.Value, error) { + model := reflect.New(fieldValue.Type()) + + modelValue := model.Elem() + modelType := model.Type().Elem() + + var er error + + for i := 0; i < modelValue.NumField(); i++ { + fieldType := modelType.Field(i) + tag := fieldType.Tag.Get("jsonapi") + if tag == "" { + continue + } + + fieldValue := modelValue.Field(i) + + args := strings.Split(tag, ",") + + if len(args) < 1 { + er = ErrBadJSONAPIStructTag + break + } + + if reflect.TypeOf(attribute).Kind() != reflect.Map { + return model, nil + } + + attributes := reflect.ValueOf(attribute).Interface().(map[string]interface{}) + attribute := attributes[args[1]] + + if attribute == nil { + continue + } + + value, err := unmarshalAttribute(attribute, args, fieldType.Type, fieldValue) + if err != nil { + return model, nil + } + + assign(fieldValue, value) + } + + return model, er +} + +func handleStructSlice(attribute interface{}, args []string, fieldType reflect.Type, fieldValue reflect.Value) (reflect.Value, error) { + models := reflect.New(fieldValue.Type()).Elem() + dataMap := reflect.ValueOf(attribute).Interface().([]interface{}) + for _, data := range dataMap { + model := reflect.New(fieldValue.Type().Elem()).Elem() + modelType := model.Type() + + value, err := handleStruct(data, []string{}, modelType, model) + + if err != nil { + continue + } + + models = reflect.Append(models, reflect.Indirect(value)) + } + + return models, nil +} diff --git a/request_test.go b/request_test.go index 6b47fd7..616adbf 100644 --- a/request_test.go +++ b/request_test.go @@ -945,3 +945,109 @@ func sampleSerializedEmbeddedTestModel() *Blog { return blog } + +func TestUnmarshalNestedStruct(t *testing.T) { + + boss := map[string]interface{}{ + "firstname": "Hubert", + "surname": "Farnsworth", + "age": 176, + "hired-at": "2016-08-17T08:27:12Z", + } + + sample := map[string]interface{}{ + "data": map[string]interface{}{ + "type": "companies", + "id": "123", + "attributes": map[string]interface{}{ + "name": "Planet Express", + "boss": boss, + "founded-at": "2016-08-17T08:27:12Z", + }, + }, + } + + data, err := json.Marshal(sample) + if err != nil { + t.Fatal(err) + } + in := bytes.NewReader(data) + out := new(Company) + + if err := UnmarshalPayload(in, out); err != nil { + t.Fatal(err) + } + + if out.Boss.Firstname != "Hubert" { + t.Fatalf("Nested struct was not unmarshalled") + } + + if out.Boss.Age != 176 { + t.Fatalf("Nested struct was not unmarshalled") + } + + if out.Boss.HiredAt.IsZero() { + t.Fatalf("Nested struct was not unmarshalled") + } +} + +func TestUnmarshalNestedStructSlice(t *testing.T) { + + fry := map[string]interface{}{ + "firstname": "Philip J.", + "surname": "Fry", + "age": 25, + "hired-at": "2016-08-17T08:27:12Z", + } + + bender := map[string]interface{}{ + "firstname": "Bender Bending", + "surname": "Rodriguez", + "age": 19, + "hired-at": "2016-08-17T08:27:12Z", + } + + deliveryCrew := map[string]interface{}{ + "name": "Delivery Crew", + "members": []interface{}{ + fry, + bender, + }, + } + + sample := map[string]interface{}{ + "data": map[string]interface{}{ + "type": "companies", + "id": "123", + "attributes": map[string]interface{}{ + "name": "Planet Express", + "teams": []interface{}{ + deliveryCrew, + }, + }, + }, + } + + data, err := json.Marshal(sample) + if err != nil { + t.Fatal(err) + } + in := bytes.NewReader(data) + out := new(Company) + + if err := UnmarshalPayload(in, out); err != nil { + t.Fatal(err) + } + + if out.Teams[0].Name != "Delivery Crew" { + t.Fatalf("Nested struct Team was not unmarshalled") + } + + if len(out.Teams[0].Members) != 2 { + t.Fatalf("Nested struct Members were not unmarshalled") + } + + if out.Teams[0].Members[0].Firstname != "Philip J." { + t.Fatalf("Nested struct member was not unmarshalled") + } +}