commit 75d85cba87f6ffbc02c9c3f359a1d376b0ced575
parent 106fcfa8bda0bc284d48674332298029d9c74ebc
Author: Arseny Balobanov <verytable29@gmail.com>
Date: Wed, 6 Feb 2019 01:27:16 +0300
add Index() method to Decoder
Index() returns the index of an array element being decoded and 0 in other cases
Diffstat:
4 files changed, 114 insertions(+), 9 deletions(-)
diff --git a/README.md b/README.md
@@ -286,6 +286,20 @@ func (c ChannelArray) UnmarshalJSONArray(dec *gojay.Decoder) error {
}
```
+Example of implementation with an array:
+```go
+type testArray [3]string
+// implement UnmarshalerJSONArray
+func (a *testArray) UnmarshalJSONArray(dec *Decoder) error {
+ var str string
+ if err := dec.String(&str); err != nil {
+ return err
+ }
+ a[dec.Index()] = str
+ return nil
+}
+```
+
### Other types
To decode other types (string, int, int32, int64, uint32, uint64, float, booleans), you don't need to implement any interface.
diff --git a/decode.go b/decode.go
@@ -235,15 +235,16 @@ type UnmarshalerJSONArray interface {
// A Decoder reads and decodes JSON values from an input stream.
type Decoder struct {
- r io.Reader
- data []byte
- err error
- isPooled byte
- called byte
- child byte
- cursor int
- length int
- keysDone int
+ r io.Reader
+ data []byte
+ err error
+ isPooled byte
+ called byte
+ child byte
+ cursor int
+ length int
+ keysDone int
+ arrayIndex int
}
// Decode reads the next JSON-encoded value from the decoder's input (io.Reader) and stores it in the value pointed to by v.
diff --git a/decode_array.go b/decode_array.go
@@ -16,6 +16,12 @@ func (dec *Decoder) DecodeArray(v UnmarshalerJSONArray) error {
return err
}
func (dec *Decoder) decodeArray(arr UnmarshalerJSONArray) (int, error) {
+ // remember last array index in case of nested arrays
+ lastArrayIndex := dec.arrayIndex
+ dec.arrayIndex = 0
+ defer func() {
+ dec.arrayIndex = lastArrayIndex
+ }()
for ; dec.cursor < dec.length || dec.read(); dec.cursor++ {
switch dec.data[dec.cursor] {
case ' ', '\n', '\t', '\r', ',':
@@ -34,6 +40,7 @@ func (dec *Decoder) decodeArray(arr UnmarshalerJSONArray) (int, error) {
if err != nil {
return 0, err
}
+ dec.arrayIndex++
}
return 0, dec.raiseInvalidJSONErr(dec.cursor)
case 'n':
@@ -60,6 +67,12 @@ func (dec *Decoder) decodeArray(arr UnmarshalerJSONArray) (int, error) {
return 0, dec.raiseInvalidJSONErr(dec.cursor)
}
func (dec *Decoder) decodeArrayNull(v interface{}) (int, error) {
+ // remember last array index in case of nested arrays
+ lastArrayIndex := dec.arrayIndex
+ dec.arrayIndex = 0
+ defer func() {
+ dec.arrayIndex = lastArrayIndex
+ }()
vv := reflect.ValueOf(v)
vvt := vv.Type()
if vvt.Kind() != reflect.Ptr || vvt.Elem().Kind() != reflect.Ptr {
@@ -96,6 +109,7 @@ func (dec *Decoder) decodeArrayNull(v interface{}) (int, error) {
if err != nil {
return 0, err
}
+ dec.arrayIndex++
}
return 0, dec.raiseInvalidJSONErr(dec.cursor)
case 'n':
@@ -226,3 +240,8 @@ func (dec *Decoder) ArrayNull(v interface{}) error {
dec.called |= 1
return nil
}
+
+// Index returns the index of an array being decoded.
+func (dec *Decoder) Index() int {
+ return dec.arrayIndex
+}
diff --git a/decode_array_test.go b/decode_array_test.go
@@ -627,3 +627,74 @@ func TestDecoderArrayFunc(t *testing.T) {
var f DecodeArrayFunc
assert.True(t, f.IsNil())
}
+
+type testArrayStrings [3]string
+
+func (a *testArrayStrings) UnmarshalJSONArray(dec *Decoder) error {
+ var str string
+ if err := dec.String(&str); err != nil {
+ return err
+ }
+ a[dec.Index()] = str
+ return nil
+}
+
+func TestArrayStrings(t *testing.T) {
+ data := []byte(`["a", "b", "c"]`)
+ arr := testArrayStrings{}
+ err := Unmarshal(data, &arr)
+ assert.Nil(t, err, "err must be nil")
+ assert.Equal(t, "a", arr[0], "arr[0] must be equal to 'a'")
+ assert.Equal(t, "b", arr[1], "arr[1] must be equal to 'b'")
+ assert.Equal(t, "c", arr[2], "arr[2] must be equal to 'c'")
+}
+
+type testSliceArraysStrings struct {
+ arrays []testArrayStrings
+ t *testing.T
+}
+
+func (s *testSliceArraysStrings) UnmarshalJSONArray(dec *Decoder) error {
+ var a testArrayStrings
+ assert.Equal(s.t, len(s.arrays), dec.Index(), "decoded array index must be equal to current slice len")
+ if err := dec.AddArray(&a); err != nil {
+ return err
+ }
+ assert.Equal(s.t, len(s.arrays), dec.Index(), "decoded array index must be equal to current slice len")
+ s.arrays = append(s.arrays, a)
+ return nil
+}
+
+func TestIndex(t *testing.T) {
+ testCases := []struct {
+ name string
+ json string
+ expectedResult []testArrayStrings
+ }{
+ {
+ name: "basic-test",
+ json: `[["a","b","c"],["1","2","3"],["x","y","z"]]`,
+ expectedResult: []testArrayStrings{{"a", "b", "c"}, {"1", "2", "3"}, {"x", "y", "z"}},
+ },
+ {
+ name: "basic-test-null",
+ json: `[["a","b","c"],null,["x","y","z"]]`,
+ expectedResult: []testArrayStrings{{"a", "b", "c"}, {"", "", ""}, {"x", "y", "z"}},
+ },
+ }
+
+ for _, testCase := range testCases {
+ t.Run(testCase.name, func(t *testing.T) {
+ s := make([]testArrayStrings, 0)
+ dec := BorrowDecoder(strings.NewReader(testCase.json))
+ defer dec.Release()
+ a := testSliceArraysStrings{arrays: s, t: t}
+ err := dec.Decode(&a)
+ assert.Nil(t, err, "err should be nil")
+ assert.Zero(t, dec.Index(), "Index() must return zero after decoding")
+ for k, v := range testCase.expectedResult {
+ assert.Equal(t, v, a.arrays[k], "value at given index should be the same as expected results")
+ }
+ })
+ }
+}