commit 660555c39f0f07bd7ab77ddeff977c6aa1729479
parent 688c5d008625b62011496858a3e55f852dccd40f
Author: francoispqt <francois@parquet.ninja>
Date: Sat, 27 Oct 2018 15:44:37 +0800
add sql null types decoding methods, fix error message when invalid type sent to decode method
Diffstat:
7 files changed, 266 insertions(+), 11 deletions(-)
diff --git a/decode.go b/decode.go
@@ -1,9 +1,9 @@
package gojay
import (
+ "database/sql"
"fmt"
"io"
- "reflect"
"time"
)
@@ -222,7 +222,7 @@ func Unmarshal(data []byte, v interface{}) error {
copy(dec.data, data)
err = dec.decodeInterface(vt)
default:
- return InvalidUnmarshalError(fmt.Sprintf(invalidUnmarshalErrorMsg, reflect.TypeOf(vt).String()))
+ return InvalidUnmarshalError(fmt.Sprintf(invalidUnmarshalErrorMsg, vt))
}
defer dec.Release()
if err != nil {
@@ -327,7 +327,7 @@ func (dec *Decoder) Decode(v interface{}) error {
case *interface{}:
err = dec.decodeInterface(vt)
default:
- return InvalidUnmarshalError(fmt.Sprintf(invalidUnmarshalErrorMsg, reflect.TypeOf(vt).String()))
+ return InvalidUnmarshalError(fmt.Sprintf(invalidUnmarshalErrorMsg, vt))
}
if err != nil {
return err
@@ -536,7 +536,7 @@ func (dec *Decoder) AddArray(v UnmarshalerJSONArray) error {
return dec.Array(v)
}
-// AddArray decodes the next key to a UnmarshalerJSONArray.
+// AddArrayNull decodes the next key to a UnmarshalerJSONArray.
func (dec *Decoder) AddArrayNull(v UnmarshalerJSONArray) error {
return dec.ArrayNull(v)
}
@@ -546,6 +546,88 @@ func (dec *Decoder) AddInterface(v *interface{}) error {
return dec.Interface(v)
}
+// --- SQL types
+
+// AddSQLNullString decodes the next key to qn *sql.NullString
+func (dec *Decoder) AddSQLNullString(v *sql.NullString) error {
+ return dec.SQLNullString(v)
+}
+
+// SQLNullString decodes the next key to an *sql.NullString
+func (dec *Decoder) SQLNullString(v *sql.NullString) error {
+ var b *string
+ if err := dec.StringNull(&b); err != nil {
+ return err
+ }
+ if b == nil {
+ v.Valid = false
+ } else {
+ v.String = *b
+ v.Valid = true
+ }
+ return nil
+}
+
+// AddSQLNullInt64 decodes the next key to qn *sql.NullInt64
+func (dec *Decoder) AddSQLNullInt64(v *sql.NullInt64) error {
+ return dec.SQLNullInt64(v)
+}
+
+// SQLNullInt64 decodes the next key to an *sql.NullInt64
+func (dec *Decoder) SQLNullInt64(v *sql.NullInt64) error {
+ var b *int64
+ if err := dec.Int64Null(&b); err != nil {
+ return err
+ }
+ if b == nil {
+ v.Valid = false
+ } else {
+ v.Int64 = *b
+ v.Valid = true
+ }
+ return nil
+}
+
+// AddSQLNullFloat64 decodes the next key to qn *sql.NullFloat64
+func (dec *Decoder) AddSQLNullFloat64(v *sql.NullFloat64) error {
+ return dec.SQLNullFloat64(v)
+}
+
+// SQLNullFloat64 decodes the next key to an *sql.NullFloat64
+func (dec *Decoder) SQLNullFloat64(v *sql.NullFloat64) error {
+ var b *float64
+ if err := dec.Float64Null(&b); err != nil {
+ return err
+ }
+ if b == nil {
+ v.Valid = false
+ } else {
+ v.Float64 = *b
+ v.Valid = true
+ }
+ return nil
+}
+
+// AddSQLNullBool decodes the next key to an *sql.NullBool
+func (dec *Decoder) AddSQLNullBool(v *sql.NullBool) error {
+ return dec.SQLNullBool(v)
+}
+
+// SQLNullBool decodes the next key to an *sql.NullBool
+func (dec *Decoder) SQLNullBool(v *sql.NullBool) error {
+ var b *bool
+ if err := dec.BoolNull(&b); err != nil {
+ return err
+ }
+ if b == nil {
+ v.Valid = false
+ } else {
+ v.Bool = *b
+ v.Valid = true
+ }
+ return nil
+}
+
// Int decodes the next key to an *int.
// If next key value overflows int, an InvalidUnmarshalError error will be returned.
func (dec *Decoder) Int(v *int) error {
diff --git a/decode_interface_test.go b/decode_interface_test.go
@@ -2,7 +2,6 @@ package gojay
import (
"encoding/json"
- "log"
"strings"
"testing"
@@ -128,7 +127,6 @@ func TestDecodeInterfaceBasic(t *testing.T) {
for _, testCase := range testCases {
t.Run("DecodeInterface()"+testCase.name, func(t *testing.T) {
- log.Print(testCase.name)
var i interface{}
dec := BorrowDecoder(strings.NewReader(testCase.json))
defer dec.Release()
diff --git a/decode_sqlnull_test.go b/decode_sqlnull_test.go
@@ -6,6 +6,7 @@ import (
"testing"
"github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
)
func TestDecodeSQLNullString(t *testing.T) {
@@ -199,3 +200,178 @@ func TestDecodeSQLNullBool(t *testing.T) {
},
)
}
+
+type SQLDecodeObject struct {
+ S sql.NullString
+ F sql.NullFloat64
+ I sql.NullInt64
+ B sql.NullBool
+}
+
+func (s *SQLDecodeObject) UnmarshalJSONObject(dec *Decoder, k string) error {
+ switch k {
+ case "s":
+ return dec.SQLNullString(&s.S)
+ case "f":
+ return dec.SQLNullFloat64(&s.F)
+ case "i":
+ return dec.SQLNullInt64(&s.I)
+ case "b":
+ return dec.SQLNullBool(&s.B)
+ }
+ return nil
+}
+
+func (s *SQLDecodeObject) NKeys() int {
+ return 0
+}
+
+func TestDecodeSQLNullKeys(t *testing.T) {
+ var testCases = []struct {
+ name string
+ json string
+ expectedResult *SQLDecodeObject
+ }{
+ {
+ name: "basic all valid",
+ json: `{
+ "s": "foo",
+ "f": 0.3,
+ "i": 3,
+ "b": true
+ }`,
+ expectedResult: &SQLDecodeObject{
+ S: sql.NullString{
+ String: "foo",
+ Valid: true,
+ },
+ F: sql.NullFloat64{
+ Float64: 0.3,
+ Valid: true,
+ },
+ I: sql.NullInt64{
+ Int64: 3,
+ Valid: true,
+ },
+ B: sql.NullBool{
+ Bool: true,
+ Valid: true,
+ },
+ },
+ },
+ {
+ name: "string not valid",
+ json: `{
+ "s": null,
+ "f": 0.3,
+ "i": 3,
+ "b": true
+ }`,
+ expectedResult: &SQLDecodeObject{
+ S: sql.NullString{
+ Valid: false,
+ },
+ F: sql.NullFloat64{
+ Float64: 0.3,
+ Valid: true,
+ },
+ I: sql.NullInt64{
+ Int64: 3,
+ Valid: true,
+ },
+ B: sql.NullBool{
+ Bool: true,
+ Valid: true,
+ },
+ },
+ },
+ {
+ name: "string not valid, int not valid",
+ json: `{
+ "s": null,
+ "f": 0.3,
+ "i": null,
+ "b": true
+ }`,
+ expectedResult: &SQLDecodeObject{
+ S: sql.NullString{
+ Valid: false,
+ },
+ F: sql.NullFloat64{
+ Float64: 0.3,
+ Valid: true,
+ },
+ I: sql.NullInt64{
+ Valid: false,
+ },
+ B: sql.NullBool{
+ Bool: true,
+ Valid: true,
+ },
+ },
+ },
+ {
+ name: "keys absent",
+ json: `{
+ "f": 0.3,
+ "i": 3,
+ "b": true
+ }`,
+ expectedResult: &SQLDecodeObject{
+ S: sql.NullString{
+ Valid: false,
+ },
+ F: sql.NullFloat64{
+ Float64: 0.3,
+ Valid: true,
+ },
+ I: sql.NullInt64{
+ Valid: true,
+ Int64: 3,
+ },
+ B: sql.NullBool{
+ Bool: true,
+ Valid: true,
+ },
+ },
+ },
+ {
+ name: "keys all null",
+ json: `{
+ "s": null,
+ "f": null,
+ "i": null,
+ "b": null
+ }`,
+ expectedResult: &SQLDecodeObject{
+ S: sql.NullString{
+ Valid: false,
+ },
+ F: sql.NullFloat64{
+ Valid: false,
+ },
+ I: sql.NullInt64{
+ Valid: false,
+ },
+ B: sql.NullBool{
+ Valid: false,
+ },
+ },
+ },
+ }
+
+ for _, testCase := range testCases {
+ t.Run(testCase.name, func(t *testing.T) {
+ var o = &SQLDecodeObject{}
+ var dec = NewDecoder(strings.NewReader(testCase.json))
+ var err = dec.Decode(o)
+ require.Nil(t, err)
+ require.Equal(
+ t,
+ testCase.expectedResult,
+ o,
+ )
+ })
+ }
+
+}
diff --git a/decode_string.go b/decode_string.go
@@ -59,6 +59,7 @@ func (dec *Decoder) decodeStringNull(v **string) error {
case '"':
dec.cursor++
start, end, err := dec.getString()
+
if err != nil {
return err
}
diff --git a/decode_test.go b/decode_test.go
@@ -445,7 +445,7 @@ func allTypesTestCases() []allTypeDecodeTestCase {
expectations: func(err error, v interface{}, t *testing.T) {
assert.NotNil(t, err, "err must not be nil")
assert.IsType(t, InvalidUnmarshalError(""), err, "err must be of type InvalidUnmarshalError")
- assert.Equal(t, fmt.Sprintf(invalidUnmarshalErrorMsg, reflect.TypeOf(v).String()), err.Error(), "err message should be equal to invalidUnmarshalErrorMsg")
+ assert.Equal(t, fmt.Sprintf(invalidUnmarshalErrorMsg, v), err.Error(), "err message should be equal to invalidUnmarshalErrorMsg")
},
},
}
diff --git a/decode_unsafe.go b/decode_unsafe.go
@@ -2,7 +2,6 @@ package gojay
import (
"fmt"
- "reflect"
)
// Unsafe is the structure holding the unsafe version of the API.
@@ -111,7 +110,7 @@ func (u decUnsafe) Unmarshal(data []byte, v interface{}) error {
dec.data = data
_, err = dec.decodeArray(vt)
default:
- return InvalidUnmarshalError(fmt.Sprintf(invalidUnmarshalErrorMsg, reflect.TypeOf(vt).String()))
+ return InvalidUnmarshalError(fmt.Sprintf(invalidUnmarshalErrorMsg, vt))
}
defer dec.Release()
if err != nil {
diff --git a/decode_unsafe_test.go b/decode_unsafe_test.go
@@ -2,7 +2,6 @@ package gojay
import (
"fmt"
- "reflect"
"testing"
"github.com/stretchr/testify/assert"
@@ -277,7 +276,7 @@ func TestUnmarshalUnsafeAllTypes(t *testing.T) {
expectations: func(err error, v interface{}, t *testing.T) {
assert.NotNil(t, err, "err must not be nil")
assert.IsType(t, InvalidUnmarshalError(""), err, "err must be of type InvalidUnmarshalError")
- assert.Equal(t, fmt.Sprintf(invalidUnmarshalErrorMsg, reflect.TypeOf(v).String()), err.Error(), "err message should be equal to invalidUnmarshalErrorMsg")
+ assert.Equal(t, fmt.Sprintf(invalidUnmarshalErrorMsg, v), err.Error(), "err message should be equal to invalidUnmarshalErrorMsg")
},
},
{