diff --git a/models_test.go b/models_test.go index d443378..2d4aae4 100644 --- a/models_test.go +++ b/models_test.go @@ -176,3 +176,18 @@ type Employee struct { Age int `jsonapi:"attr,age"` HiredAt *time.Time `jsonapi:"attr,hired-at,iso8601"` } + +type CustomIntType int +type CustomFloatType float64 +type CustomStringType string + +type CustomAttributeTypes struct { + ID string `jsonapi:"primary,customtypes"` + + Int CustomIntType `jsonapi:"attr,int"` + IntPtr *CustomIntType `jsonapi:"attr,intptr"` + IntPtrNull *CustomIntType `jsonapi:"attr,intptrnull"` + + Float CustomFloatType `jsonapi:"attr,float"` + String CustomStringType `jsonapi:"attr,string"` +} diff --git a/request.go b/request.go index b9883f2..a830bb1 100644 --- a/request.go +++ b/request.go @@ -253,9 +253,11 @@ func unmarshalNode(data *Node, model reflect.Value, included *map[string]*Node) break } - assign(fieldValue, value) - continue - + // As a final catch-all, ensure types line up to avoid a runtime panic. + if fieldValue.Kind() != v.Kind() { + return ErrInvalidType + } + assignValue(fieldValue, reflect.ValueOf(val)) } else if annotation == annotationRelation { isSlice := fieldValue.Type().Kind() == reflect.Slice @@ -347,10 +349,36 @@ func fullNode(n *Node, included *map[string]*Node) *Node { // assign will take the value specified and assign it to the field; if // field is expecting a ptr assign will assign a ptr. func assign(field, value reflect.Value) { + value = reflect.Indirect(value) + if field.Kind() == reflect.Ptr { - field.Set(value) + // initialize pointer so it's value + // can be set by assignValue + field.Set(reflect.New(field.Type().Elem())) + assignValue(field.Elem(), value) } else { - field.Set(reflect.Indirect(value)) + assignValue(field, value) + } +} + +// assign assigns the specified value to the field, +// expecting both values not to be pointer types. +func assignValue(field, value reflect.Value) { + switch field.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, + reflect.Int32, reflect.Int64: + field.SetInt(value.Int()) + case reflect.Uint, reflect.Uint8, reflect.Uint16, + reflect.Uint32, reflect.Uint64, reflect.Uintptr: + field.SetUint(value.Uint()) + case reflect.Float32, reflect.Float64: + field.SetFloat(value.Float()) + case reflect.String: + field.SetString(value.String()) + case reflect.Bool: + field.SetBool(value.Bool()) + default: + field.Set(value) } } @@ -588,7 +616,6 @@ func handleStruct( return reflect.Value{}, err } - return model, nil } diff --git a/request_test.go b/request_test.go index 111b5fb..2a8e48b 100644 --- a/request_test.go +++ b/request_test.go @@ -768,6 +768,54 @@ func TestManyPayload_withLinks(t *testing.T) { } } +func TestUnmarshalCustomTypeAttributes(t *testing.T) { + customInt := CustomIntType(5) + customFloat := CustomFloatType(1.5) + customString := CustomStringType("Test") + + data := map[string]interface{}{ + "data": map[string]interface{}{ + "type": "customtypes", + "id": "1", + "attributes": map[string]interface{}{ + "int": customInt, + "intptr": &customInt, + "intptrnull": nil, + + "float": customFloat, + "string": customString, + }, + }, + } + payload, err := payload(data) + if err != nil { + t.Fatal(err) + } + + // Parse JSON API payload + customAttributeTypes := new(CustomAttributeTypes) + if err := UnmarshalPayload(bytes.NewReader(payload), customAttributeTypes); err != nil { + t.Fatal(err) + } + + if expected, actual := customInt, customAttributeTypes.Int; expected != actual { + t.Fatalf("Was expecting custom int to be `%s`, got `%s`", expected, actual) + } + if expected, actual := customInt, *customAttributeTypes.IntPtr; expected != actual { + t.Fatalf("Was expecting custom int pointer to be `%s`, got `%s`", expected, actual) + } + if customAttributeTypes.IntPtrNull != nil { + t.Fatalf("Was expecting custom int pointer to be , got `%s`", customAttributeTypes.IntPtrNull) + } + + if expected, actual := customFloat, customAttributeTypes.Float; expected != actual { + t.Fatalf("Was expecting custom float to be `%s`, got `%s`", expected, actual) + } + if expected, actual := customString, customAttributeTypes.String; expected != actual { + t.Fatalf("Was expecting custom string to be `%s`, got `%s`", expected, actual) + } +} + func samplePayloadWithoutIncluded() map[string]interface{} { return map[string]interface{}{ "data": map[string]interface{}{