commit 2896c17ffd3c0033fc2279f66759c9224d7d6edb
parent 10124f98bd7fabf1c775ac9485c53d032ddee637
Author: Francois Parquet <francois.parquet@gmail.com>
Date: Sat, 28 Apr 2018 21:19:09 +0800
Merge pull request #4 from francoispqt/update/make-unmarshal-api-safer-for-strings
make unmarshal api safer by copying initial buffer and adding an Unsa…
Diffstat:
7 files changed, 587 insertions(+), 4 deletions(-)
diff --git a/benchmarks/decoder/decoder_bench_large_test.go b/benchmarks/decoder/decoder_bench_large_test.go
@@ -49,3 +49,11 @@ func BenchmarkGoJayDecodeObjLarge(b *testing.B) {
gojay.UnmarshalObject(benchmarks.LargeFixture, &result)
}
}
+
+func BenchmarkGoJayUnsafeDecodeObjLarge(b *testing.B) {
+ b.ReportAllocs()
+ for i := 0; i < b.N; i++ {
+ result := benchmarks.LargePayload{}
+ gojay.Unsafe.UnmarshalObject(benchmarks.LargeFixture, &result)
+ }
+}
diff --git a/benchmarks/decoder/decoder_bench_medium_test.go b/benchmarks/decoder/decoder_bench_medium_test.go
@@ -60,3 +60,13 @@ func BenchmarkGoJayDecodeObjMedium(b *testing.B) {
}
}
}
+func BenchmarkGoJayUnsafeDecodeObjMedium(b *testing.B) {
+ b.ReportAllocs()
+ for i := 0; i < b.N; i++ {
+ result := benchmarks.MediumPayload{}
+ err := gojay.Unsafe.UnmarshalObject(benchmarks.MediumFixture, &result)
+ if err != nil {
+ b.Error(err)
+ }
+ }
+}
diff --git a/benchmarks/decoder/decoder_bench_small_test.go b/benchmarks/decoder/decoder_bench_small_test.go
@@ -59,3 +59,11 @@ func BenchmarkGoJayDecodeObjSmall(b *testing.B) {
gojay.UnmarshalObject(benchmarks.SmallFixture, &result)
}
}
+
+func BenchmarkGoJayUnsafeDecodeObjSmall(b *testing.B) {
+ b.ReportAllocs()
+ for i := 0; i < b.N; i++ {
+ result := benchmarks.SmallPayload{}
+ gojay.Unsafe.UnmarshalObject(benchmarks.SmallFixture, &result)
+ }
+}
diff --git a/decode.go b/decode.go
@@ -14,7 +14,8 @@ import (
// overflows the target type, UnmarshalArray skips that field and completes the unmarshaling as best it can.
func UnmarshalArray(data []byte, v UnmarshalerArray) error {
dec := newDecoder(nil, 0)
- dec.data = data
+ dec.data = make([]byte, len(data))
+ copy(dec.data, data)
dec.length = len(data)
_, err := dec.DecodeArray(v)
dec.addToPool()
@@ -35,7 +36,8 @@ func UnmarshalArray(data []byte, v UnmarshalerArray) error {
// overflows the target type, UnmarshalObject skips that field and completes the unmarshaling as best it can.
func UnmarshalObject(data []byte, v UnmarshalerObject) error {
dec := newDecoder(nil, 0)
- dec.data = data
+ dec.data = make([]byte, len(data))
+ copy(dec.data, data)
dec.length = len(data)
_, err := dec.DecodeObject(v)
dec.addToPool()
@@ -113,12 +115,14 @@ func Unmarshal(data []byte, v interface{}) error {
case UnmarshalerObject:
dec = newDecoder(nil, 0)
dec.length = len(data)
- dec.data = data
+ dec.data = make([]byte, len(data))
+ copy(dec.data, data)
_, err = dec.DecodeObject(vt)
case UnmarshalerArray:
dec = newDecoder(nil, 0)
dec.length = len(data)
- dec.data = data
+ dec.data = make([]byte, len(data))
+ copy(dec.data, data)
_, err = dec.DecodeArray(vt)
default:
return InvalidUnmarshalError(fmt.Sprintf(invalidUnmarshalErrorMsg, reflect.TypeOf(vt).String()))
diff --git a/decode_test.go b/decode_test.go
@@ -496,3 +496,110 @@ func TestDecodeAllTypes(t *testing.T) {
})
}
}
+
+func TestUnmarshalObjects(t *testing.T) {
+ testCases := []struct {
+ name string
+ v UnmarshalerObject
+ d []byte
+ expectations func(err error, v interface{}, t *testing.T)
+ }{
+ {
+ v: new(testDecodeObj),
+ d: []byte(`{"test":"test"}`),
+ name: "test decode object",
+ expectations: func(err error, v interface{}, t *testing.T) {
+ vt := v.(*testDecodeObj)
+ assert.Nil(t, err, "err must be nil")
+ assert.Equal(t, "test", vt.test, "v.test must be equal to 'test'")
+ },
+ },
+ {
+ v: new(testDecodeObj),
+ d: []byte(`{"test":null}`),
+ name: "test decode object null key",
+ expectations: func(err error, v interface{}, t *testing.T) {
+ vt := v.(*testDecodeObj)
+ assert.Nil(t, err, "err must be nil")
+ assert.Equal(t, "", vt.test, "v.test must be equal to 'test'")
+ },
+ },
+ {
+ v: new(testDecodeObj),
+ d: []byte(`null`),
+ name: "test decode object null",
+ expectations: func(err error, v interface{}, t *testing.T) {
+ vt := v.(*testDecodeObj)
+ assert.Nil(t, err, "err must be nil")
+ assert.Equal(t, "", vt.test, "v.test must be equal to 'test'")
+ },
+ },
+ {
+ v: new(testDecodeObj),
+ d: []byte(`invalid json`),
+ name: "test decode object null",
+ expectations: func(err error, v interface{}, t *testing.T) {
+ assert.NotNil(t, err, "err must not be nil")
+ assert.IsType(t, InvalidJSONError(""), err, "err must be of type InvalidJSONError")
+ },
+ },
+ }
+ for _, testCase := range testCases {
+ testCase := testCase
+ t.Run(testCase.name, func(*testing.T) {
+ err := UnmarshalObject(testCase.d, testCase.v)
+ testCase.expectations(err, testCase.v, t)
+ })
+ }
+}
+
+func TestUnmarshalArrays(t *testing.T) {
+ testCases := []struct {
+ name string
+ v UnmarshalerArray
+ d []byte
+ expectations func(err error, v interface{}, t *testing.T)
+ }{
+ {
+ v: new(testDecodeSlice),
+ d: []byte(`[{"test":"test"}]`),
+ name: "test decode slice",
+ expectations: func(err error, v interface{}, t *testing.T) {
+ vtPtr := v.(*testDecodeSlice)
+ vt := *vtPtr
+ assert.Nil(t, err, "err must be nil")
+ assert.Len(t, vt, 1, "len of vt must be 1")
+ assert.Equal(t, "test", vt[0].test, "vt[0].test must be equal to 'test'")
+ },
+ },
+ {
+ v: new(testDecodeSlice),
+ d: []byte(`[{"test":"test"},{"test":"test2"}]`),
+ name: "test decode slice",
+ expectations: func(err error, v interface{}, t *testing.T) {
+ vtPtr := v.(*testDecodeSlice)
+ vt := *vtPtr
+ assert.Nil(t, err, "err must be nil")
+ assert.Len(t, vt, 2, "len of vt must be 2")
+ assert.Equal(t, "test", vt[0].test, "vt[0].test must be equal to 'test'")
+ assert.Equal(t, "test2", vt[1].test, "vt[1].test must be equal to 'test2'")
+ },
+ },
+ {
+ v: new(testDecodeSlice),
+ d: []byte(`invalid json`),
+ name: "test decode object null",
+ expectations: func(err error, v interface{}, t *testing.T) {
+ assert.NotNil(t, err, "err must not be nil")
+ assert.IsType(t, InvalidJSONError(""), err, "err must be of type InvalidJSONError")
+ },
+ },
+ }
+ for _, testCase := range testCases {
+ testCase := testCase
+ t.Run(testCase.name, func(*testing.T) {
+ err := UnmarshalArray(testCase.d, testCase.v)
+ testCase.expectations(err, testCase.v, t)
+ })
+ }
+}
diff --git a/decode_unsafe.go b/decode_unsafe.go
@@ -0,0 +1,108 @@
+package gojay
+
+import (
+ "fmt"
+ "reflect"
+)
+
+// Unsafe is the structure holding the unsafe version of the API.
+// The difference between unsafe api and regular api is that the regular API
+// copies the buffer passed to Unmarshal functions to a new internal buffer.
+// Making it safer because internally GoJay uses unsafe.Pointer to transform slice of bytes into a string.
+var Unsafe = decUnsafe{}
+
+type decUnsafe struct{}
+
+func (u decUnsafe) UnmarshalArray(data []byte, v UnmarshalerArray) error {
+ dec := newDecoder(nil, 0)
+ dec.data = data
+ dec.length = len(data)
+ _, err := dec.DecodeArray(v)
+ dec.addToPool()
+ if err != nil {
+ return err
+ }
+ if dec.err != nil {
+ return dec.err
+ }
+ return nil
+}
+
+func (u decUnsafe) UnmarshalObject(data []byte, v UnmarshalerObject) error {
+ dec := newDecoder(nil, 0)
+ dec.data = data
+ dec.length = len(data)
+ _, err := dec.DecodeObject(v)
+ dec.addToPool()
+ if err != nil {
+ return err
+ }
+ if dec.err != nil {
+ return dec.err
+ }
+ return nil
+}
+
+func (u decUnsafe) Unmarshal(data []byte, v interface{}) error {
+ var err error
+ var dec *Decoder
+ switch vt := v.(type) {
+ case *string:
+ dec = newDecoder(nil, 0)
+ dec.length = len(data)
+ dec.data = data
+ err = dec.DecodeString(vt)
+ case *int:
+ dec = newDecoder(nil, 0)
+ dec.length = len(data)
+ dec.data = data
+ err = dec.DecodeInt(vt)
+ case *int32:
+ dec = newDecoder(nil, 0)
+ dec.length = len(data)
+ dec.data = data
+ err = dec.DecodeInt32(vt)
+ case *uint32:
+ dec = newDecoder(nil, 0)
+ dec.length = len(data)
+ dec.data = data
+ err = dec.DecodeUint32(vt)
+ case *int64:
+ dec = newDecoder(nil, 0)
+ dec.length = len(data)
+ dec.data = data
+ err = dec.DecodeInt64(vt)
+ case *uint64:
+ dec = newDecoder(nil, 0)
+ dec.length = len(data)
+ dec.data = data
+ err = dec.DecodeUint64(vt)
+ case *float64:
+ dec = newDecoder(nil, 0)
+ dec.length = len(data)
+ dec.data = data
+ err = dec.DecodeFloat64(vt)
+ case *bool:
+ dec = newDecoder(nil, 0)
+ dec.length = len(data)
+ dec.data = data
+ err = dec.DecodeBool(vt)
+ case UnmarshalerObject:
+ dec = newDecoder(nil, 0)
+ dec.length = len(data)
+ dec.data = data
+ _, err = dec.DecodeObject(vt)
+ case UnmarshalerArray:
+ dec = newDecoder(nil, 0)
+ dec.length = len(data)
+ dec.data = data
+ _, err = dec.DecodeArray(vt)
+ default:
+ return InvalidUnmarshalError(fmt.Sprintf(invalidUnmarshalErrorMsg, reflect.TypeOf(vt).String()))
+ }
+ defer dec.addToPool()
+ if err != nil {
+ return err
+ }
+ return dec.err
+}
diff --git a/decode_unsafe_test.go b/decode_unsafe_test.go
@@ -0,0 +1,338 @@
+package gojay
+
+import (
+ "fmt"
+ "reflect"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+func TestUnmarshalUnsafeAllTypes(t *testing.T) {
+ testCases := []struct {
+ name string
+ v interface{}
+ d []byte
+ expectations func(err error, v interface{}, t *testing.T)
+ }{
+ {
+ v: new(string),
+ d: []byte(`"test string"`),
+ name: "test decode string",
+ expectations: func(err error, v interface{}, t *testing.T) {
+ vt := v.(*string)
+ assert.Nil(t, err, "err must be nil")
+ assert.Equal(t, "test string", *vt, "v must be equal to 1")
+ },
+ },
+ {
+ v: new(string),
+ d: []byte(`null`),
+ name: "test decode string null",
+ expectations: func(err error, v interface{}, t *testing.T) {
+ vt := v.(*string)
+ assert.Nil(t, err, "err must be nil")
+ assert.Equal(t, "", *vt, "v must be equal to 1")
+ },
+ },
+ {
+ v: new(int),
+ d: []byte(`1`),
+ name: "test decode int",
+ expectations: func(err error, v interface{}, t *testing.T) {
+ vt := v.(*int)
+ assert.Nil(t, err, "err must be nil")
+ assert.Equal(t, 1, *vt, "v must be equal to 1")
+ },
+ },
+ {
+ v: new(int64),
+ d: []byte(`1`),
+ name: "test decode int64",
+ expectations: func(err error, v interface{}, t *testing.T) {
+ vt := v.(*int64)
+ assert.Nil(t, err, "err must be nil")
+ assert.Equal(t, int64(1), *vt, "v must be equal to 1")
+ },
+ },
+ {
+ v: new(uint64),
+ d: []byte(`1`),
+ name: "test decode uint64",
+ expectations: func(err error, v interface{}, t *testing.T) {
+ vt := v.(*uint64)
+ assert.Nil(t, err, "err must be nil")
+ assert.Equal(t, uint64(1), *vt, "v must be equal to 1")
+ },
+ },
+ {
+ v: new(uint64),
+ d: []byte(`-1`),
+ name: "test decode uint64 negative",
+ expectations: func(err error, v interface{}, t *testing.T) {
+ vt := v.(*uint64)
+ assert.Nil(t, err, "err must be nil")
+ assert.Equal(t, uint64(1), *vt, "v must be equal to 1")
+ },
+ },
+ {
+ v: new(int32),
+ d: []byte(`1`),
+ name: "test decode int32",
+ expectations: func(err error, v interface{}, t *testing.T) {
+ vt := v.(*int32)
+ assert.Nil(t, err, "err must be nil")
+ assert.Equal(t, int32(1), *vt, "v must be equal to 1")
+ },
+ },
+ {
+ v: new(uint32),
+ d: []byte(`1`),
+ name: "test decode uint32",
+ expectations: func(err error, v interface{}, t *testing.T) {
+ vt := v.(*uint32)
+ assert.Nil(t, err, "err must be nil")
+ assert.Equal(t, uint32(1), *vt, "v must be equal to 1")
+ },
+ },
+ {
+ v: new(uint32),
+ d: []byte(`-1`),
+ name: "test decode uint32 negative",
+ expectations: func(err error, v interface{}, t *testing.T) {
+ vt := v.(*uint32)
+ assert.Nil(t, err, "err must be nil")
+ assert.Equal(t, uint32(1), *vt, "v must be equal to 1")
+ },
+ },
+ {
+ v: new(float64),
+ d: []byte(`1.15`),
+ name: "test decode float64",
+ expectations: func(err error, v interface{}, t *testing.T) {
+ vt := v.(*float64)
+ assert.Nil(t, err, "err must be nil")
+ assert.Equal(t, float64(1.15), *vt, "v must be equal to 1")
+ },
+ },
+ {
+ v: new(float64),
+ d: []byte(`null`),
+ name: "test decode float64 null",
+ expectations: func(err error, v interface{}, t *testing.T) {
+ vt := v.(*float64)
+ assert.Nil(t, err, "err must be nil")
+ assert.Equal(t, float64(0), *vt, "v must be equal to 1")
+ },
+ },
+ {
+ v: new(bool),
+ d: []byte(`true`),
+ name: "test decode bool true",
+ expectations: func(err error, v interface{}, t *testing.T) {
+ vt := v.(*bool)
+ assert.Nil(t, err, "err must be nil")
+ assert.Equal(t, true, *vt, "v must be equal to 1")
+ },
+ },
+ {
+ v: new(bool),
+ d: []byte(`false`),
+ name: "test decode bool false",
+ expectations: func(err error, v interface{}, t *testing.T) {
+ vt := v.(*bool)
+ assert.Nil(t, err, "err must be nil")
+ assert.Equal(t, false, *vt, "v must be equal to 1")
+ },
+ },
+ {
+ v: new(bool),
+ d: []byte(`null`),
+ name: "test decode bool null",
+ expectations: func(err error, v interface{}, t *testing.T) {
+ vt := v.(*bool)
+ assert.Nil(t, err, "err must be nil")
+ assert.Equal(t, false, *vt, "v must be equal to 1")
+ },
+ },
+ {
+ v: new(testDecodeObj),
+ d: []byte(`{"test":"test"}`),
+ name: "test decode object",
+ expectations: func(err error, v interface{}, t *testing.T) {
+ vt := v.(*testDecodeObj)
+ assert.Nil(t, err, "err must be nil")
+ assert.Equal(t, "test", vt.test, "v.test must be equal to 'test'")
+ },
+ },
+ {
+ v: new(testDecodeObj),
+ d: []byte(`{"test":null}`),
+ name: "test decode object null key",
+ expectations: func(err error, v interface{}, t *testing.T) {
+ vt := v.(*testDecodeObj)
+ assert.Nil(t, err, "err must be nil")
+ assert.Equal(t, "", vt.test, "v.test must be equal to 'test'")
+ },
+ },
+ {
+ v: new(testDecodeObj),
+ d: []byte(`null`),
+ name: "test decode object null",
+ expectations: func(err error, v interface{}, t *testing.T) {
+ vt := v.(*testDecodeObj)
+ assert.Nil(t, err, "err must be nil")
+ assert.Equal(t, "", vt.test, "v.test must be equal to 'test'")
+ },
+ },
+ {
+ v: new(testDecodeSlice),
+ d: []byte(`[{"test":"test"}]`),
+ name: "test decode slice",
+ expectations: func(err error, v interface{}, t *testing.T) {
+ vtPtr := v.(*testDecodeSlice)
+ vt := *vtPtr
+ assert.Nil(t, err, "err must be nil")
+ assert.Len(t, vt, 1, "len of vt must be 1")
+ assert.Equal(t, "test", vt[0].test, "vt[0].test must be equal to 'test'")
+ },
+ },
+ {
+ v: new(testDecodeSlice),
+ d: []byte(`[{"test":"test"},{"test":"test2"}]`),
+ name: "test decode slice",
+ expectations: func(err error, v interface{}, t *testing.T) {
+ vtPtr := v.(*testDecodeSlice)
+ vt := *vtPtr
+ assert.Nil(t, err, "err must be nil")
+ assert.Len(t, vt, 2, "len of vt must be 2")
+ assert.Equal(t, "test", vt[0].test, "vt[0].test must be equal to 'test'")
+ assert.Equal(t, "test2", vt[1].test, "vt[1].test must be equal to 'test2'")
+ },
+ },
+ {
+ v: new(struct{}),
+ d: []byte(`{"test":"test"}`),
+ name: "test decode invalid type",
+ expectations: func(err error, v interface{}, t *testing.T) {
+ assert.NotNil(t, err, "err must not be nil")
+ assert.IsType(t, InvalidUnmarshalError(""), err, "err must be of type InvalidUnmarshalError")
+ assert.Equal(t, fmt.Sprintf(invalidUnmarshalErrorMsg, reflect.TypeOf(v).String()), err.Error(), "err message should be equal to invalidUnmarshalErrorMsg")
+ },
+ },
+ }
+ for _, testCase := range testCases {
+ testCase := testCase
+ t.Run(testCase.name, func(*testing.T) {
+ err := Unsafe.Unmarshal(testCase.d, testCase.v)
+ testCase.expectations(err, testCase.v, t)
+ })
+ }
+}
+
+func TestUnmarshalUnsafeObjects(t *testing.T) {
+ testCases := []struct {
+ name string
+ v UnmarshalerObject
+ d []byte
+ expectations func(err error, v interface{}, t *testing.T)
+ }{
+ {
+ v: new(testDecodeObj),
+ d: []byte(`{"test":"test"}`),
+ name: "test decode object",
+ expectations: func(err error, v interface{}, t *testing.T) {
+ vt := v.(*testDecodeObj)
+ assert.Nil(t, err, "err must be nil")
+ assert.Equal(t, "test", vt.test, "v.test must be equal to 'test'")
+ },
+ },
+ {
+ v: new(testDecodeObj),
+ d: []byte(`{"test":null}`),
+ name: "test decode object null key",
+ expectations: func(err error, v interface{}, t *testing.T) {
+ vt := v.(*testDecodeObj)
+ assert.Nil(t, err, "err must be nil")
+ assert.Equal(t, "", vt.test, "v.test must be equal to 'test'")
+ },
+ },
+ {
+ v: new(testDecodeObj),
+ d: []byte(`null`),
+ name: "test decode object null",
+ expectations: func(err error, v interface{}, t *testing.T) {
+ vt := v.(*testDecodeObj)
+ assert.Nil(t, err, "err must be nil")
+ assert.Equal(t, "", vt.test, "v.test must be equal to 'test'")
+ },
+ },
+ {
+ v: new(testDecodeObj),
+ d: []byte(`invalid json`),
+ name: "test decode object null",
+ expectations: func(err error, v interface{}, t *testing.T) {
+ assert.NotNil(t, err, "err must not be nil")
+ assert.IsType(t, InvalidJSONError(""), err, "err must be of type InvalidJSONError")
+ },
+ },
+ }
+ for _, testCase := range testCases {
+ testCase := testCase
+ t.Run(testCase.name, func(*testing.T) {
+ err := Unsafe.UnmarshalObject(testCase.d, testCase.v)
+ testCase.expectations(err, testCase.v, t)
+ })
+ }
+}
+
+func TestUnmarshalUnsafeArrays(t *testing.T) {
+ testCases := []struct {
+ name string
+ v UnmarshalerArray
+ d []byte
+ expectations func(err error, v interface{}, t *testing.T)
+ }{
+ {
+ v: new(testDecodeSlice),
+ d: []byte(`[{"test":"test"}]`),
+ name: "test decode slice",
+ expectations: func(err error, v interface{}, t *testing.T) {
+ vtPtr := v.(*testDecodeSlice)
+ vt := *vtPtr
+ assert.Nil(t, err, "err must be nil")
+ assert.Len(t, vt, 1, "len of vt must be 1")
+ assert.Equal(t, "test", vt[0].test, "vt[0].test must be equal to 'test'")
+ },
+ },
+ {
+ v: new(testDecodeSlice),
+ d: []byte(`[{"test":"test"},{"test":"test2"}]`),
+ name: "test decode slice",
+ expectations: func(err error, v interface{}, t *testing.T) {
+ vtPtr := v.(*testDecodeSlice)
+ vt := *vtPtr
+ assert.Nil(t, err, "err must be nil")
+ assert.Len(t, vt, 2, "len of vt must be 2")
+ assert.Equal(t, "test", vt[0].test, "vt[0].test must be equal to 'test'")
+ assert.Equal(t, "test2", vt[1].test, "vt[1].test must be equal to 'test2'")
+ },
+ },
+ {
+ v: new(testDecodeSlice),
+ d: []byte(`invalid json`),
+ name: "test decode object null",
+ expectations: func(err error, v interface{}, t *testing.T) {
+ assert.NotNil(t, err, "err must not be nil")
+ assert.IsType(t, InvalidJSONError(""), err, "err must be of type InvalidJSONError")
+ },
+ },
+ }
+ for _, testCase := range testCases {
+ testCase := testCase
+ t.Run(testCase.name, func(*testing.T) {
+ err := Unsafe.UnmarshalArray(testCase.d, testCase.v)
+ testCase.expectations(err, testCase.v, t)
+ })
+ }
+}