diff --git a/constants.go b/constants.go index 23288d3..d5cd0ea 100644 --- a/constants.go +++ b/constants.go @@ -10,6 +10,7 @@ const ( annotationOmitEmpty = "omitempty" annotationISO8601 = "iso8601" annotationSeperator = "," + annotationIgnore = "-" iso8601TimeFormat = "2006-01-02T15:04:05Z" diff --git a/node.go b/node.go index a58488c..73b7d59 100644 --- a/node.go +++ b/node.go @@ -44,6 +44,38 @@ type Node struct { Meta *Meta `json:"meta,omitempty"` } +func (n *Node) merge(node *Node) { + if node.Type != "" { + n.Type = node.Type + } + + if node.ID != "" { + n.ID = node.ID + } + + if node.ClientID != "" { + n.ClientID = node.ClientID + } + + if n.Attributes == nil && node.Attributes != nil { + n.Attributes = make(map[string]interface{}) + } + for k, v := range node.Attributes { + n.Attributes[k] = v + } + + if n.Relationships == nil && node.Relationships != nil { + n.Relationships = make(map[string]interface{}) + } + for k, v := range node.Relationships { + n.Relationships[k] = v + } + + if node.Links != nil { + n.Links = node.Links + } +} + // RelationshipOneNode is used to represent a generic has one JSON API relation type RelationshipOneNode struct { Data *Node `json:"data"` @@ -119,3 +151,35 @@ type RelationshipMetable interface { // JSONRelationshipMeta will be invoked for each relationship with the corresponding relation name (e.g. `comments`) JSONAPIRelationshipMeta(relation string) *Meta } + +// derefs the arg, and clones the map-type attributes +// note: maps are reference types, so they need an explicit copy. +func deepCopyNode(n *Node) *Node { + if n == nil { + return n + } + + copyMap := func(m map[string]interface{}) map[string]interface{} { + if m == nil { + return m + } + cp := make(map[string]interface{}) + for k, v := range m { + cp[k] = v + } + return cp + } + + copy := *n + copy.Attributes = copyMap(copy.Attributes) + copy.Relationships = copyMap(copy.Relationships) + if copy.Links != nil { + tmp := Links(copyMap(map[string]interface{}(*copy.Links))) + copy.Links = &tmp + } + if copy.Meta != nil { + tmp := Meta(copyMap(map[string]interface{}(*copy.Meta))) + copy.Meta = &tmp + } + return © +} diff --git a/request.go b/request.go index fe29706..9e0eb1a 100644 --- a/request.go +++ b/request.go @@ -117,6 +117,11 @@ func UnmarshalManyPayload(in io.Reader, t reflect.Type) ([]interface{}, error) { return models, nil } +// unmarshalNode handles embedded struct models from top to down. +// it loops through the struct fields, handles attributes/relations at that level first +// the handling the embedded structs are done last, so that you get the expected composition behavior +// data (*Node) attributes are cleared on each success. +// relations/sideloaded models use deeply copied Nodes (since those sideloaded models can be referenced in multiple relations) func unmarshalNode(data *Node, model reflect.Value, included *map[string]*Node) (err error) { defer func() { if r := recover(); r != nil { @@ -127,416 +132,518 @@ func unmarshalNode(data *Node, model reflect.Value, included *map[string]*Node) modelValue := model.Elem() modelType := model.Type().Elem() - var er error + type embedded struct { + structField, model reflect.Value + } + embeddeds := []*embedded{} for i := 0; i < modelValue.NumField(); i++ { fieldType := modelType.Field(i) - tag := fieldType.Tag.Get("jsonapi") + fieldValue := modelValue.Field(i) + tag := fieldType.Tag.Get(annotationJSONAPI) + + // handle explicit ignore annotation + if shouldIgnoreField(tag) { + continue + } + + // handles embedded structs + if isEmbeddedStruct(fieldType) { + embeddeds = append(embeddeds, + &embedded{ + model: reflect.ValueOf(fieldValue.Addr().Interface()), + structField: fieldValue, + }, + ) + continue + } + + // handles pointers to embedded structs + if isEmbeddedStructPtr(fieldType) { + embeddeds = append(embeddeds, + &embedded{ + model: reflect.ValueOf(fieldValue.Interface()), + structField: fieldValue, + }, + ) + continue + } + + // handle tagless; after handling embedded structs (which could be tagless) if tag == "" { continue } - fieldValue := modelValue.Field(i) - - args := strings.Split(tag, ",") - + args := strings.Split(tag, annotationSeperator) + // require atleast 1 if len(args) < 1 { - er = ErrBadJSONAPIStructTag - break + return ErrBadJSONAPIStructTag } - annotation := args[0] - - if (annotation == annotationClientID && len(args) != 1) || - (annotation != annotationClientID && len(args) < 2) { - er = ErrBadJSONAPIStructTag - break - } - - if annotation == annotationPrimary { - if data.ID == "" { - continue + // args[0] == annotation + switch args[0] { + case annotationClientID: + if err := handleClientIDUnmarshal(data, args, fieldValue); err != nil { + return err } - - // Check the JSON API Type - if data.Type != args[1] { - er = fmt.Errorf( - "Trying to Unmarshal an object of type %#v, but %#v does not match", - data.Type, - args[1], - ) - break + case annotationPrimary: + if err := handlePrimaryUnmarshal(data, args, fieldType, fieldValue); err != nil { + return err } - - // ID will have to be transmitted as astring per the JSON API spec - v := reflect.ValueOf(data.ID) - - // Deal with PTRS - var kind reflect.Kind - if fieldValue.Kind() == reflect.Ptr { - kind = fieldType.Type.Elem().Kind() - } else { - kind = fieldType.Type.Kind() + case annotationAttribute: + if err := handleAttributeUnmarshal(data, args, fieldType, fieldValue); err != nil { + return err } - - // Handle String case - if kind == reflect.String { - assign(fieldValue, v) - continue + case annotationRelation: + if err := handleRelationUnmarshal(data, args, fieldValue, included); err != nil { + return err } - - // Value was not a string... only other supported type was a numeric, - // which would have been sent as a float value. - floatValue, err := strconv.ParseFloat(data.ID, 64) - if err != nil { - // Could not convert the value in the "id" attr to a float - er = ErrBadJSONAPIID - break - } - - // Convert the numeric float to one of the supported ID numeric types - // (int[8,16,32,64] or uint[8,16,32,64]) - var idValue reflect.Value - switch kind { - case reflect.Int: - n := int(floatValue) - idValue = reflect.ValueOf(&n) - case reflect.Int8: - n := int8(floatValue) - idValue = reflect.ValueOf(&n) - case reflect.Int16: - n := int16(floatValue) - idValue = reflect.ValueOf(&n) - case reflect.Int32: - n := int32(floatValue) - idValue = reflect.ValueOf(&n) - case reflect.Int64: - n := int64(floatValue) - idValue = reflect.ValueOf(&n) - case reflect.Uint: - n := uint(floatValue) - idValue = reflect.ValueOf(&n) - case reflect.Uint8: - n := uint8(floatValue) - idValue = reflect.ValueOf(&n) - case reflect.Uint16: - n := uint16(floatValue) - idValue = reflect.ValueOf(&n) - case reflect.Uint32: - n := uint32(floatValue) - idValue = reflect.ValueOf(&n) - case reflect.Uint64: - n := uint64(floatValue) - idValue = reflect.ValueOf(&n) - default: - // We had a JSON float (numeric), but our field was not one of the - // allowed numeric types - er = ErrBadJSONAPIID - break - } - - assign(fieldValue, idValue) - } else if annotation == annotationClientID { - if data.ClientID == "" { - continue - } - - fieldValue.Set(reflect.ValueOf(data.ClientID)) - } else if annotation == annotationAttribute { - attributes := data.Attributes - if attributes == nil || len(data.Attributes) == 0 { - continue - } - - var iso8601 bool - - if len(args) > 2 { - for _, arg := range args[2:] { - if arg == annotationISO8601 { - iso8601 = true - } - } - } - - val := attributes[args[1]] - - // continue if the attribute was not included in the request - if val == nil { - continue - } - - v := reflect.ValueOf(val) - - // Handle field of type time.Time - if fieldValue.Type() == reflect.TypeOf(time.Time{}) { - if iso8601 { - var tm string - if v.Kind() == reflect.String { - tm = v.Interface().(string) - } else { - er = ErrInvalidISO8601 - break - } - - t, err := time.Parse(iso8601TimeFormat, tm) - if err != nil { - er = ErrInvalidISO8601 - break - } - - fieldValue.Set(reflect.ValueOf(t)) - - continue - } - - var at int64 - - if v.Kind() == reflect.Float64 { - at = int64(v.Interface().(float64)) - } else if v.Kind() == reflect.Int { - at = v.Int() - } else { - return ErrInvalidTime - } - - t := time.Unix(at, 0) - - fieldValue.Set(reflect.ValueOf(t)) - - continue - } - - if fieldValue.Type() == reflect.TypeOf([]string{}) { - values := make([]string, v.Len()) - for i := 0; i < v.Len(); i++ { - values[i] = v.Index(i).Interface().(string) - } - - fieldValue.Set(reflect.ValueOf(values)) - - continue - } - - if fieldValue.Type() == reflect.TypeOf(new(time.Time)) { - if iso8601 { - var tm string - if v.Kind() == reflect.String { - tm = v.Interface().(string) - } else { - er = ErrInvalidISO8601 - break - } - - v, err := time.Parse(iso8601TimeFormat, tm) - if err != nil { - er = ErrInvalidISO8601 - break - } - - t := &v - - fieldValue.Set(reflect.ValueOf(t)) - - continue - } - - var at int64 - - if v.Kind() == reflect.Float64 { - at = int64(v.Interface().(float64)) - } else if v.Kind() == reflect.Int { - at = v.Int() - } else { - return ErrInvalidTime - } - - v := time.Unix(at, 0) - t := &v - - fieldValue.Set(reflect.ValueOf(t)) - - continue - } - - // JSON value was a float (numeric) - if v.Kind() == reflect.Float64 { - floatValue := v.Interface().(float64) - - // The field may or may not be a pointer to a numeric; the kind var - // will not contain a pointer type - var kind reflect.Kind - if fieldValue.Kind() == reflect.Ptr { - kind = fieldType.Type.Elem().Kind() - } else { - kind = fieldType.Type.Kind() - } - - var numericValue reflect.Value - - switch kind { - case reflect.Int: - n := int(floatValue) - numericValue = reflect.ValueOf(&n) - case reflect.Int8: - n := int8(floatValue) - numericValue = reflect.ValueOf(&n) - case reflect.Int16: - n := int16(floatValue) - numericValue = reflect.ValueOf(&n) - case reflect.Int32: - n := int32(floatValue) - numericValue = reflect.ValueOf(&n) - case reflect.Int64: - n := int64(floatValue) - numericValue = reflect.ValueOf(&n) - case reflect.Uint: - n := uint(floatValue) - numericValue = reflect.ValueOf(&n) - case reflect.Uint8: - n := uint8(floatValue) - numericValue = reflect.ValueOf(&n) - case reflect.Uint16: - n := uint16(floatValue) - numericValue = reflect.ValueOf(&n) - case reflect.Uint32: - n := uint32(floatValue) - numericValue = reflect.ValueOf(&n) - case reflect.Uint64: - n := uint64(floatValue) - numericValue = reflect.ValueOf(&n) - case reflect.Float32: - n := float32(floatValue) - numericValue = reflect.ValueOf(&n) - case reflect.Float64: - n := floatValue - numericValue = reflect.ValueOf(&n) - default: - return ErrUnknownFieldNumberType - } - - assign(fieldValue, numericValue) - continue - } - - // Field was a Pointer type - if fieldValue.Kind() == reflect.Ptr { - var concreteVal reflect.Value - - switch cVal := val.(type) { - case string: - concreteVal = reflect.ValueOf(&cVal) - case bool: - concreteVal = reflect.ValueOf(&cVal) - case complex64: - concreteVal = reflect.ValueOf(&cVal) - case complex128: - concreteVal = reflect.ValueOf(&cVal) - case uintptr: - concreteVal = reflect.ValueOf(&cVal) - default: - return ErrUnsupportedPtrType - } - - if fieldValue.Type() != concreteVal.Type() { - return ErrUnsupportedPtrType - } - - fieldValue.Set(concreteVal) - continue - } - - // As a final catch-all, ensure types line up to avoid a runtime panic. - if fieldValue.Kind() != v.Kind() { - return ErrInvalidType - } - fieldValue.Set(reflect.ValueOf(val)) - - } else if annotation == annotationRelation { - isSlice := fieldValue.Type().Kind() == reflect.Slice - - if data.Relationships == nil || data.Relationships[args[1]] == nil { - continue - } - - if isSlice { - // to-many relationship - relationship := new(RelationshipManyNode) - - buf := bytes.NewBuffer(nil) - - json.NewEncoder(buf).Encode(data.Relationships[args[1]]) - json.NewDecoder(buf).Decode(relationship) - - data := relationship.Data - models := reflect.New(fieldValue.Type()).Elem() - - for _, n := range data { - m := reflect.New(fieldValue.Type().Elem().Elem()) - - if err := unmarshalNode( - fullNode(n, included), - m, - included, - ); err != nil { - er = err - break - } - - models = reflect.Append(models, m) - } - - fieldValue.Set(models) - } else { - // to-one relationships - relationship := new(RelationshipOneNode) - - buf := bytes.NewBuffer(nil) - - json.NewEncoder(buf).Encode( - data.Relationships[args[1]], - ) - json.NewDecoder(buf).Decode(relationship) - - /* - http://jsonapi.org/format/#document-resource-object-relationships - http://jsonapi.org/format/#document-resource-object-linkage - relationship can have a data node set to null (e.g. to disassociate the relationship) - so unmarshal and set fieldValue only if data obj is not null - */ - if relationship.Data == nil { - continue - } - - m := reflect.New(fieldValue.Type().Elem()) - if err := unmarshalNode( - fullNode(relationship.Data, included), - m, - included, - ); err != nil { - er = err - break - } - - fieldValue.Set(m) - - } - - } else { - er = fmt.Errorf(unsuportedStructTagMsg, annotation) + default: + return fmt.Errorf(unsuportedStructTagMsg, args[0]) } } - return er + // handle embedded last + for _, em := range embeddeds { + // if nil, need to construct and rollback accordingly + if em.model.IsNil() { + copy := deepCopyNode(data) + tmp := reflect.New(em.model.Type().Elem()) + if err := unmarshalNode(copy, tmp, included); err != nil { + return err + } + + // had changes; assign value to struct field, replace orig node (data) w/ mutated copy + if !reflect.DeepEqual(copy, data) { + assign(em.structField, tmp) + data = copy + } + return nil + } + // handle non-nil scenarios + return unmarshalNode(data, em.model, included) + } + + return nil +} + +func handleClientIDUnmarshal(data *Node, args []string, fieldValue reflect.Value) error { + if len(args) != 1 { + return ErrBadJSONAPIStructTag + } + + if data.ClientID == "" { + return nil + } + + // set value and clear clientID to denote it's already been processed + fieldValue.Set(reflect.ValueOf(data.ClientID)) + data.ClientID = "" + + return nil +} + +func handlePrimaryUnmarshal(data *Node, args []string, fieldType reflect.StructField, fieldValue reflect.Value) error { + if len(args) < 2 { + return ErrBadJSONAPIStructTag + } + + if data.ID == "" { + return nil + } + + // Check the JSON API Type + if data.Type != args[1] { + return fmt.Errorf( + "Trying to Unmarshal an object of type %#v, but %#v does not match", + data.Type, + args[1], + ) + } + + // Deal with PTRS + var kind reflect.Kind + if fieldValue.Kind() == reflect.Ptr { + kind = fieldType.Type.Elem().Kind() + } else { + kind = fieldType.Type.Kind() + } + + var idValue reflect.Value + + // Handle String case + if kind == reflect.String { + // ID will have to be transmitted as a string per the JSON API spec + idValue = reflect.ValueOf(data.ID) + } else { + // Value was not a string... only other supported type was a numeric, + // which would have been sent as a float value. + floatValue, err := strconv.ParseFloat(data.ID, 64) + if err != nil { + // Could not convert the value in the "id" attr to a float + return ErrBadJSONAPIID + } + + // Convert the numeric float to one of the supported ID numeric types + // (int[8,16,32,64] or uint[8,16,32,64]) + switch kind { + case reflect.Int: + n := int(floatValue) + idValue = reflect.ValueOf(&n) + case reflect.Int8: + n := int8(floatValue) + idValue = reflect.ValueOf(&n) + case reflect.Int16: + n := int16(floatValue) + idValue = reflect.ValueOf(&n) + case reflect.Int32: + n := int32(floatValue) + idValue = reflect.ValueOf(&n) + case reflect.Int64: + n := int64(floatValue) + idValue = reflect.ValueOf(&n) + case reflect.Uint: + n := uint(floatValue) + idValue = reflect.ValueOf(&n) + case reflect.Uint8: + n := uint8(floatValue) + idValue = reflect.ValueOf(&n) + case reflect.Uint16: + n := uint16(floatValue) + idValue = reflect.ValueOf(&n) + case reflect.Uint32: + n := uint32(floatValue) + idValue = reflect.ValueOf(&n) + case reflect.Uint64: + n := uint64(floatValue) + idValue = reflect.ValueOf(&n) + default: + // We had a JSON float (numeric), but our field was not one of the + // allowed numeric types + return ErrBadJSONAPIID + } + } + + // set value and clear ID to denote it's already been processed + assign(fieldValue, idValue) + data.ID = "" + + return nil +} + +func handleRelationUnmarshal(data *Node, args []string, fieldValue reflect.Value, included *map[string]*Node) error { + if len(args) < 2 { + return ErrBadJSONAPIStructTag + } + + if data.Relationships == nil || data.Relationships[args[1]] == nil { + return nil + } + + // to-one relationships + handler := handleToOneRelationUnmarshal + isSlice := fieldValue.Type().Kind() == reflect.Slice + if isSlice { + // to-many relationship + handler = handleToManyRelationUnmarshal + } + + v, err := handler(data.Relationships[args[1]], fieldValue.Type(), included) + if err != nil { + return err + } + // set only if there is a val since val can be null (e.g. to disassociate the relationship) + if v != nil { + fieldValue.Set(*v) + } + delete(data.Relationships, args[1]) + return nil +} + +// to-one relationships +func handleToOneRelationUnmarshal(relationData interface{}, fieldType reflect.Type, included *map[string]*Node) (*reflect.Value, error) { + relationship := new(RelationshipOneNode) + + buf := bytes.NewBuffer(nil) + json.NewEncoder(buf).Encode(relationData) + json.NewDecoder(buf).Decode(relationship) + + m := reflect.New(fieldType.Elem()) + /* + http://jsonapi.org/format/#document-resource-object-relationships + http://jsonapi.org/format/#document-resource-object-linkage + relationship can have a data node set to null (e.g. to disassociate the relationship) + so unmarshal and set fieldValue only if data obj is not null + */ + if relationship.Data == nil { + return nil, nil + } + + if err := unmarshalNode( + fullNode(relationship.Data, included), + m, + included, + ); err != nil { + return nil, err + } + + return &m, nil +} + +// to-many relationship +func handleToManyRelationUnmarshal(relationData interface{}, fieldType reflect.Type, included *map[string]*Node) (*reflect.Value, error) { + relationship := new(RelationshipManyNode) + + buf := bytes.NewBuffer(nil) + json.NewEncoder(buf).Encode(relationData) + json.NewDecoder(buf).Decode(relationship) + + models := reflect.New(fieldType).Elem() + + rData := relationship.Data + for _, n := range rData { + m := reflect.New(fieldType.Elem().Elem()) + + if err := unmarshalNode( + fullNode(n, included), + m, + included, + ); err != nil { + return nil, err + } + + models = reflect.Append(models, m) + } + + return &models, nil +} + +// TODO: break this out into smaller funcs +func handleAttributeUnmarshal(data *Node, args []string, fieldType reflect.StructField, fieldValue reflect.Value) error { + if len(args) < 2 { + return ErrBadJSONAPIStructTag + } + attributes := data.Attributes + if attributes == nil || len(data.Attributes) == 0 { + return nil + } + + var iso8601 bool + + if len(args) > 2 { + for _, arg := range args[2:] { + if arg == annotationISO8601 { + iso8601 = true + } + } + } + + val := attributes[args[1]] + + // continue if the attribute was not included in the request + if val == nil { + return nil + } + + v := reflect.ValueOf(val) + + // Handle field of type time.Time + if fieldValue.Type() == reflect.TypeOf(time.Time{}) { + if iso8601 { + var tm string + if v.Kind() == reflect.String { + tm = v.Interface().(string) + } else { + return ErrInvalidISO8601 + } + + t, err := time.Parse(iso8601TimeFormat, tm) + if err != nil { + return ErrInvalidISO8601 + } + + fieldValue.Set(reflect.ValueOf(t)) + delete(data.Attributes, args[1]) + return nil + } + + var at int64 + + if v.Kind() == reflect.Float64 { + at = int64(v.Interface().(float64)) + } else if v.Kind() == reflect.Int { + at = v.Int() + } else { + return ErrInvalidTime + } + + t := time.Unix(at, 0) + + fieldValue.Set(reflect.ValueOf(t)) + delete(data.Attributes, args[1]) + return nil + } + + if fieldValue.Type() == reflect.TypeOf([]string{}) { + values := make([]string, v.Len()) + for i := 0; i < v.Len(); i++ { + values[i] = v.Index(i).Interface().(string) + } + + fieldValue.Set(reflect.ValueOf(values)) + delete(data.Attributes, args[1]) + return nil + } + + if fieldValue.Type() == reflect.TypeOf(new(time.Time)) { + if iso8601 { + var tm string + if v.Kind() == reflect.String { + tm = v.Interface().(string) + } else { + return ErrInvalidISO8601 + + } + + v, err := time.Parse(iso8601TimeFormat, tm) + if err != nil { + return ErrInvalidISO8601 + } + + t := &v + + fieldValue.Set(reflect.ValueOf(t)) + delete(data.Attributes, args[1]) + return nil + } + + var at int64 + + if v.Kind() == reflect.Float64 { + at = int64(v.Interface().(float64)) + } else if v.Kind() == reflect.Int { + at = v.Int() + } else { + return ErrInvalidTime + } + + v := time.Unix(at, 0) + t := &v + + fieldValue.Set(reflect.ValueOf(t)) + delete(data.Attributes, args[1]) + return nil + } + + // JSON value was a float (numeric) + if v.Kind() == reflect.Float64 { + floatValue := v.Interface().(float64) + + // The field may or may not be a pointer to a numeric; the kind var + // will not contain a pointer type + var kind reflect.Kind + if fieldValue.Kind() == reflect.Ptr { + kind = fieldType.Type.Elem().Kind() + } else { + kind = fieldType.Type.Kind() + } + + var numericValue reflect.Value + + switch kind { + case reflect.Int: + n := int(floatValue) + numericValue = reflect.ValueOf(&n) + case reflect.Int8: + n := int8(floatValue) + numericValue = reflect.ValueOf(&n) + case reflect.Int16: + n := int16(floatValue) + numericValue = reflect.ValueOf(&n) + case reflect.Int32: + n := int32(floatValue) + numericValue = reflect.ValueOf(&n) + case reflect.Int64: + n := int64(floatValue) + numericValue = reflect.ValueOf(&n) + case reflect.Uint: + n := uint(floatValue) + numericValue = reflect.ValueOf(&n) + case reflect.Uint8: + n := uint8(floatValue) + numericValue = reflect.ValueOf(&n) + case reflect.Uint16: + n := uint16(floatValue) + numericValue = reflect.ValueOf(&n) + case reflect.Uint32: + n := uint32(floatValue) + numericValue = reflect.ValueOf(&n) + case reflect.Uint64: + n := uint64(floatValue) + numericValue = reflect.ValueOf(&n) + case reflect.Float32: + n := float32(floatValue) + numericValue = reflect.ValueOf(&n) + case reflect.Float64: + n := floatValue + numericValue = reflect.ValueOf(&n) + default: + return ErrUnknownFieldNumberType + } + + assign(fieldValue, numericValue) + delete(data.Attributes, args[1]) + return nil + } + + // Field was a Pointer type + if fieldValue.Kind() == reflect.Ptr { + var concreteVal reflect.Value + + switch cVal := val.(type) { + case string: + concreteVal = reflect.ValueOf(&cVal) + case bool: + concreteVal = reflect.ValueOf(&cVal) + case complex64: + concreteVal = reflect.ValueOf(&cVal) + case complex128: + concreteVal = reflect.ValueOf(&cVal) + case uintptr: + concreteVal = reflect.ValueOf(&cVal) + default: + return ErrUnsupportedPtrType + } + + if fieldValue.Type() != concreteVal.Type() { + return ErrUnsupportedPtrType + } + + fieldValue.Set(concreteVal) + delete(data.Attributes, args[1]) + return nil + } + + // As a final catch-all, ensure types line up to avoid a runtime panic. + // Ignore interfaces since interfaces are poly + if fieldValue.Kind() != reflect.Interface && fieldValue.Kind() != v.Kind() { + return ErrInvalidType + } + + // set val and clear attribute key so its not processed again + fieldValue.Set(reflect.ValueOf(val)) + delete(data.Attributes, args[1]) + return nil } func fullNode(n *Node, included *map[string]*Node) *Node { includedKey := fmt.Sprintf("%s,%s", n.Type, n.ID) if included != nil && (*included)[includedKey] != nil { - return (*included)[includedKey] + return deepCopyNode((*included)[includedKey]) } - return n + return deepCopyNode(n) } // assign will take the value specified and assign it to the field; if diff --git a/response.go b/response.go index 76c3a3d..2e9acd7 100644 --- a/response.go +++ b/response.go @@ -212,14 +212,39 @@ func visitModelNode(model interface{}, included *map[string]*Node, modelType := reflect.ValueOf(model).Type().Elem() for i := 0; i < modelValue.NumField(); i++ { - structField := modelValue.Type().Field(i) - tag := structField.Tag.Get(annotationJSONAPI) - if tag == "" { + fieldValue := modelValue.Field(i) + fieldType := modelType.Field(i) + + tag := fieldType.Tag.Get(annotationJSONAPI) + + if shouldIgnoreField(tag) { continue } - fieldValue := modelValue.Field(i) - fieldType := modelType.Field(i) + // handles embedded structs and pointers to embedded structs + if isEmbeddedStruct(fieldType) || isEmbeddedStructPtr(fieldType) { + var embModel interface{} + if fieldType.Type.Kind() == reflect.Ptr { + if fieldValue.IsNil() { + continue + } + embModel = fieldValue.Interface() + } else { + embModel = fieldValue.Addr().Interface() + } + + embNode, err := visitModelNode(embModel, included, sideload) + if err != nil { + er = err + break + } + node.merge(embNode) + continue + } + + if tag == "" { + continue + } args := strings.Split(tag, annotationSeperator) @@ -533,3 +558,15 @@ func convertToSliceInterface(i *interface{}) ([]interface{}, error) { } return response, nil } + +func isEmbeddedStruct(sField reflect.StructField) bool { + return sField.Anonymous && sField.Type.Kind() == reflect.Struct +} + +func isEmbeddedStructPtr(sField reflect.StructField) bool { + return sField.Anonymous && sField.Type.Kind() == reflect.Ptr && sField.Type.Elem().Kind() == reflect.Struct +} + +func shouldIgnoreField(japiTag string) bool { + return strings.HasPrefix(japiTag, annotationIgnore) +} diff --git a/response_test.go b/response_test.go index 71589dc..8c96cfb 100644 --- a/response_test.go +++ b/response_test.go @@ -817,6 +817,630 @@ func TestMarshal_InvalidIntefaceArgument(t *testing.T) { } } +func TestMergeNode(t *testing.T) { + parent := &Node{ + Type: "Good", + ID: "99", + Attributes: map[string]interface{}{"fizz": "buzz"}, + } + + child := &Node{ + Type: "Better", + ClientID: "1111", + Attributes: map[string]interface{}{"timbuk": 2}, + } + + expected := &Node{ + Type: "Better", + ID: "99", + ClientID: "1111", + Attributes: map[string]interface{}{"fizz": "buzz", "timbuk": 2}, + } + + parent.merge(child) + + if !reflect.DeepEqual(expected, parent) { + t.Errorf("Got %+v Expected %+v", parent, expected) + } +} + +func TestIsEmbeddedStruct(t *testing.T) { + type foo struct{} + + structType := reflect.TypeOf(foo{}) + stringType := reflect.TypeOf("") + if structType.Kind() != reflect.Struct { + t.Fatal("structType.Kind() is not a struct.") + } + if stringType.Kind() != reflect.String { + t.Fatal("stringType.Kind() is not a string.") + } + + type test struct { + scenario string + input reflect.StructField + expectedRes bool + } + + tests := []test{ + test{ + scenario: "success", + input: reflect.StructField{Anonymous: true, Type: structType}, + expectedRes: true, + }, + test{ + scenario: "wrong type", + input: reflect.StructField{Anonymous: true, Type: stringType}, + expectedRes: false, + }, + test{ + scenario: "not embedded", + input: reflect.StructField{Type: structType}, + expectedRes: false, + }, + } + + for _, test := range tests { + res := isEmbeddedStruct(test.input) + if res != test.expectedRes { + t.Errorf("Scenario -> %s\nGot -> %v\nExpected -> %v\n", test.scenario, res, test.expectedRes) + } + } +} + +func TestShouldIgnoreField(t *testing.T) { + type test struct { + scenario string + input string + expectedRes bool + } + + tests := []test{ + test{ + scenario: "opt-out", + input: annotationIgnore, + expectedRes: true, + }, + test{ + scenario: "no tag", + input: "", + expectedRes: false, + }, + test{ + scenario: "wrong tag", + input: "wrong,tag", + expectedRes: false, + }, + } + + for _, test := range tests { + res := shouldIgnoreField(test.input) + if res != test.expectedRes { + t.Errorf("Scenario -> %s\nGot -> %v\nExpected -> %v\n", test.scenario, res, test.expectedRes) + } + } +} + +func TestIsValidEmbeddedStruct(t *testing.T) { + type foo struct{} + + structType := reflect.TypeOf(foo{}) + stringType := reflect.TypeOf("") + if structType.Kind() != reflect.Struct { + t.Fatal("structType.Kind() is not a struct.") + } + if stringType.Kind() != reflect.String { + t.Fatal("stringType.Kind() is not a string.") + } + + type test struct { + scenario string + input reflect.StructField + expectedRes bool + } + + tests := []test{ + test{ + scenario: "success", + input: reflect.StructField{Anonymous: true, Type: structType}, + expectedRes: true, + }, + test{ + scenario: "opt-out", + input: reflect.StructField{Anonymous: true, Tag: "jsonapi:\"-\"", Type: structType}, + expectedRes: false, + }, + test{ + scenario: "wrong type", + input: reflect.StructField{Anonymous: true, Type: stringType}, + expectedRes: false, + }, + test{ + scenario: "not embedded", + input: reflect.StructField{Type: structType}, + expectedRes: false, + }, + } + + for _, test := range tests { + res := (isEmbeddedStruct(test.input) && !shouldIgnoreField(test.input.Tag.Get(annotationJSONAPI))) + if res != test.expectedRes { + t.Errorf("Scenario -> %s\nGot -> %v\nExpected -> %v\n", test.scenario, res, test.expectedRes) + } + } +} + +// TestEmbeddedUnmarshalOrder tests the behavior of the marshaler/unmarshaler of embedded structs +// when a struct has an embedded struct w/ competing attributes, the top-level attributes take precedence +// it compares the behavior against the standard json package +func TestEmbeddedUnmarshalOrder(t *testing.T) { + type Bar struct { + Name int `jsonapi:"attr,Name"` + } + + type Foo struct { + Bar + ID string `jsonapi:"primary,foos"` + Name string `jsonapi:"attr,Name"` + } + + f := &Foo{ + ID: "1", + Name: "foo", + Bar: Bar{ + Name: 5, + }, + } + + // marshal f (Foo) using jsonapi marshaler + jsonAPIData := bytes.NewBuffer(nil) + if err := MarshalPayload(jsonAPIData, f); err != nil { + t.Fatal(err) + } + + // marshal f (Foo) using json marshaler + jsonData, err := json.Marshal(f) + + // convert bytes to map[string]interface{} so that we can do a semantic JSON comparison + var jsonAPIVal, jsonVal map[string]interface{} + if err := json.Unmarshal(jsonAPIData.Bytes(), &jsonAPIVal); err != nil { + t.Fatal(err) + } + if err = json.Unmarshal(jsonData, &jsonVal); err != nil { + t.Fatal(err) + } + + // get to the jsonapi attribute map + jAttrMap := jsonAPIVal["data"].(map[string]interface{})["attributes"].(map[string]interface{}) + + // compare + if !reflect.DeepEqual(jAttrMap["Name"], jsonVal["Name"]) { + t.Errorf("Got\n%s\nExpected\n%s\n", jAttrMap["Name"], jsonVal["Name"]) + } +} + +// TestEmbeddedMarshalOrder tests the behavior of the marshaler/unmarshaler of embedded structs +// when a struct has an embedded struct w/ competing attributes, the top-level attributes take precedence +// it compares the behavior against the standard json package +func TestEmbeddedMarshalOrder(t *testing.T) { + type Bar struct { + Name int `jsonapi:"attr,Name"` + } + + type Foo struct { + Bar + ID string `jsonapi:"primary,foos"` + Name string `jsonapi:"attr,Name"` + } + + // get a jsonapi payload w/ Name attribute of an int type + payloadWithInt, err := json.Marshal(&OnePayload{ + Data: &Node{ + Type: "foos", + ID: "1", + Attributes: map[string]interface{}{ + "Name": 5, + }, + }, + }) + if err != nil { + t.Fatal(err) + } + + // get a jsonapi payload w/ Name attribute of an string type + payloadWithString, err := json.Marshal(&OnePayload{ + Data: &Node{ + Type: "foos", + ID: "1", + Attributes: map[string]interface{}{ + "Name": "foo", + }, + }, + }) + if err != nil { + t.Fatal(err) + } + + // unmarshal payloadWithInt to f (Foo) using jsonapi unmarshaler; expecting an error + f := &Foo{} + if err := UnmarshalPayload(bytes.NewReader(payloadWithInt), f); err == nil { + t.Errorf("expected an error: int value of 5 should attempt to map to Foo.Name (string) and error") + } + + // unmarshal payloadWithString to f (Foo) using jsonapi unmarshaler; expecting no error + f = &Foo{} + if err := UnmarshalPayload(bytes.NewReader(payloadWithString), f); err != nil { + t.Error(err) + } + if f.Name != "foo" { + t.Errorf("Got\n%s\nExpected\n%s\n", "foo", f.Name) + } + + // get a json payload w/ Name attribute of an int type + bWithInt, err := json.Marshal(map[string]interface{}{ + "Name": 5, + }) + if err != nil { + t.Fatal(err) + } + + // get a json payload w/ Name attribute of an string type + bWithString, err := json.Marshal(map[string]interface{}{ + "Name": "foo", + }) + if err != nil { + t.Fatal(err) + } + + // unmarshal bWithInt to f (Foo) using json unmarshaler; expecting an error + f = &Foo{} + if err := json.Unmarshal(bWithInt, f); err == nil { + t.Errorf("expected an error: int value of 5 should attempt to map to Foo.Name (string) and error") + } + // unmarshal bWithString to f (Foo) using json unmarshaler; expecting no error + f = &Foo{} + if err := json.Unmarshal(bWithString, f); err != nil { + t.Error(err) + } + if f.Name != "foo" { + t.Errorf("Got\n%s\nExpected\n%s\n", "foo", f.Name) + } +} + +func TestMarshalUnmarshalCompositeStruct(t *testing.T) { + type Thing struct { + ID int `jsonapi:"primary,things"` + Fizz string `jsonapi:"attr,fizz"` + Buzz int `jsonapi:"attr,buzz"` + } + + type Model struct { + Thing + Foo string `jsonapi:"attr,foo"` + Bar string `jsonapi:"attr,bar"` + Bat string `jsonapi:"attr,bat"` + } + + type test struct { + name string + payload *OnePayload + dst, expected interface{} + } + + scenarios := []test{} + + scenarios = append(scenarios, test{ + name: "Model embeds Thing, models have no annotation overlaps", + dst: &Model{}, + payload: &OnePayload{ + Data: &Node{ + Type: "things", + ID: "1", + Attributes: map[string]interface{}{ + "bar": "barry", + "bat": "batty", + "buzz": 99, + "fizz": "fizzy", + "foo": "fooey", + }, + }, + }, + expected: &Model{ + Foo: "fooey", + Bar: "barry", + Bat: "batty", + Thing: Thing{ + ID: 1, + Fizz: "fizzy", + Buzz: 99, + }, + }, + }) + + { + type Model struct { + Thing + Foo string `jsonapi:"attr,foo"` + Bar string `jsonapi:"attr,bar"` + Bat string `jsonapi:"attr,bat"` + Buzz int `jsonapi:"attr,buzz"` // overrides Thing.Buzz + } + + scenarios = append(scenarios, test{ + name: "Model embeds Thing, overlap Buzz attribute", + dst: &Model{}, + payload: &OnePayload{ + Data: &Node{ + Type: "things", + ID: "1", + Attributes: map[string]interface{}{ + "bar": "barry", + "bat": "batty", + "buzz": 99, + "fizz": "fizzy", + "foo": "fooey", + }, + }, + }, + expected: &Model{ + Foo: "fooey", + Bar: "barry", + Bat: "batty", + Buzz: 99, + Thing: Thing{ + ID: 1, + Fizz: "fizzy", + }, + }, + }) + } + + { + type Model struct { + Thing + ModelID int `jsonapi:"primary,models"` //overrides Thing.ID due to primary annotation + Foo string `jsonapi:"attr,foo"` + Bar string `jsonapi:"attr,bar"` + Bat string `jsonapi:"attr,bat"` + Buzz int `jsonapi:"attr,buzz"` // overrides Thing.Buzz + } + + scenarios = append(scenarios, test{ + name: "Model embeds Thing, attribute, and primary annotation overlap", + dst: &Model{}, + payload: &OnePayload{ + Data: &Node{ + Type: "models", + ID: "1", + Attributes: map[string]interface{}{ + "bar": "barry", + "bat": "batty", + "buzz": 99, + "fizz": "fizzy", + "foo": "fooey", + }, + }, + }, + expected: &Model{ + ModelID: 1, + Foo: "fooey", + Bar: "barry", + Bat: "batty", + Buzz: 99, + Thing: Thing{ + Fizz: "fizzy", + }, + }, + }) + } + + { + type Model struct { + Thing `jsonapi:"-"` + ModelID int `jsonapi:"primary,models"` + Foo string `jsonapi:"attr,foo"` + Bar string `jsonapi:"attr,bar"` + Bat string `jsonapi:"attr,bat"` + Buzz int `jsonapi:"attr,buzz"` + } + + scenarios = append(scenarios, test{ + name: "Model embeds Thing, but is annotated w/ ignore", + dst: &Model{}, + payload: &OnePayload{ + Data: &Node{ + Type: "models", + ID: "1", + Attributes: map[string]interface{}{ + "bar": "barry", + "bat": "batty", + "buzz": 99, + "foo": "fooey", + }, + }, + }, + expected: &Model{ + ModelID: 1, + Foo: "fooey", + Bar: "barry", + Bat: "batty", + Buzz: 99, + }, + }) + } + { + type Model struct { + *Thing + ModelID int `jsonapi:"primary,models"` + Foo string `jsonapi:"attr,foo"` + Bar string `jsonapi:"attr,bar"` + Bat string `jsonapi:"attr,bat"` + } + + scenarios = append(scenarios, test{ + name: "Model embeds pointer of Thing; Thing is initialized in advance", + dst: &Model{Thing: &Thing{}}, + payload: &OnePayload{ + Data: &Node{ + Type: "models", + ID: "1", + Attributes: map[string]interface{}{ + "bar": "barry", + "bat": "batty", + "foo": "fooey", + "buzz": 99, + "fizz": "fizzy", + }, + }, + }, + expected: &Model{ + Thing: &Thing{ + Fizz: "fizzy", + Buzz: 99, + }, + ModelID: 1, + Foo: "fooey", + Bar: "barry", + Bat: "batty", + }, + }) + } + { + type Model struct { + *Thing + ModelID int `jsonapi:"primary,models"` + Foo string `jsonapi:"attr,foo"` + Bar string `jsonapi:"attr,bar"` + Bat string `jsonapi:"attr,bat"` + } + + scenarios = append(scenarios, test{ + name: "Model embeds pointer of Thing; Thing is initialized w/ Unmarshal", + dst: &Model{}, + payload: &OnePayload{ + Data: &Node{ + Type: "models", + ID: "1", + Attributes: map[string]interface{}{ + "bar": "barry", + "bat": "batty", + "foo": "fooey", + "buzz": 99, + "fizz": "fizzy", + }, + }, + }, + expected: &Model{ + Thing: &Thing{ + Fizz: "fizzy", + Buzz: 99, + }, + ModelID: 1, + Foo: "fooey", + Bar: "barry", + Bat: "batty", + }, + }) + } + { + type Model struct { + *Thing + ModelID int `jsonapi:"primary,models"` + Foo string `jsonapi:"attr,foo"` + Bar string `jsonapi:"attr,bar"` + Bat string `jsonapi:"attr,bat"` + } + + scenarios = append(scenarios, test{ + name: "Model embeds pointer of Thing; jsonapi model doesn't assign anything to Thing; *Thing is nil", + dst: &Model{}, + payload: &OnePayload{ + Data: &Node{ + Type: "models", + ID: "1", + Attributes: map[string]interface{}{ + "bar": "barry", + "bat": "batty", + "foo": "fooey", + }, + }, + }, + expected: &Model{ + ModelID: 1, + Foo: "fooey", + Bar: "barry", + Bat: "batty", + }, + }) + } + + { + type Model struct { + *Thing + ModelID int `jsonapi:"primary,models"` + Foo string `jsonapi:"attr,foo"` + Bar string `jsonapi:"attr,bar"` + Bat string `jsonapi:"attr,bat"` + } + + scenarios = append(scenarios, test{ + name: "Model embeds pointer of Thing; *Thing is nil", + dst: &Model{}, + payload: &OnePayload{ + Data: &Node{ + Type: "models", + ID: "1", + Attributes: map[string]interface{}{ + "bar": "barry", + "bat": "batty", + "foo": "fooey", + }, + }, + }, + expected: &Model{ + ModelID: 1, + Foo: "fooey", + Bar: "barry", + Bat: "batty", + }, + }) + } + for _, scenario := range scenarios { + t.Logf("running scenario: %s\n", scenario.name) + + // get the expected model and marshal to jsonapi + buf := bytes.NewBuffer(nil) + if err := MarshalPayload(buf, scenario.expected); err != nil { + t.Fatal(err) + } + + // get the node model representation and marshal to jsonapi + payload, err := json.Marshal(scenario.payload) + if err != nil { + t.Fatal(err) + } + + // assert that we're starting w/ the same payload + isJSONEqual, err := isJSONEqual(payload, buf.Bytes()) + if err != nil { + t.Fatal(err) + } + if !isJSONEqual { + t.Errorf("Got\n%s\nExpected\n%s\n", buf.Bytes(), payload) + } + + // run jsonapi unmarshal + if err := UnmarshalPayload(bytes.NewReader(payload), scenario.dst); err != nil { + t.Fatal(err) + } + + // assert decoded and expected models are equal + if !reflect.DeepEqual(scenario.expected, scenario.dst) { + t.Errorf("Got\n%#v\nExpected\n%#v\n", scenario.dst, scenario.expected) + } + } +} + func testBlog() *Blog { return &Blog{ ID: 5, @@ -883,3 +1507,17 @@ func testBlog() *Blog { }, } } + +func isJSONEqual(b1, b2 []byte) (bool, error) { + var i1, i2 interface{} + var result bool + var err error + if err = json.Unmarshal(b1, &i1); err != nil { + return result, err + } + if err = json.Unmarshal(b2, &i2); err != nil { + return result, err + } + result = reflect.DeepEqual(i1, i2) + return result, err +}