gojay

high performance JSON encoder/decoder with stream API for Golang
git clone git://git.lair.cx/gojay
Log | Files | Refs | README | LICENSE

commit 660555c39f0f07bd7ab77ddeff977c6aa1729479
parent 688c5d008625b62011496858a3e55f852dccd40f
Author: francoispqt <francois@parquet.ninja>
Date:   Sat, 27 Oct 2018 15:44:37 +0800

add sql null types decoding methods, fix error message when invalid type sent to decode method

Diffstat:
Mdecode.go | 90+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++----
Mdecode_interface_test.go | 2--
Mdecode_sqlnull_test.go | 176+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Mdecode_string.go | 1+
Mdecode_test.go | 2+-
Mdecode_unsafe.go | 3+--
Mdecode_unsafe_test.go | 3+--
7 files changed, 266 insertions(+), 11 deletions(-)

diff --git a/decode.go b/decode.go @@ -1,9 +1,9 @@ package gojay import ( + "database/sql" "fmt" "io" - "reflect" "time" ) @@ -222,7 +222,7 @@ func Unmarshal(data []byte, v interface{}) error { copy(dec.data, data) err = dec.decodeInterface(vt) default: - return InvalidUnmarshalError(fmt.Sprintf(invalidUnmarshalErrorMsg, reflect.TypeOf(vt).String())) + return InvalidUnmarshalError(fmt.Sprintf(invalidUnmarshalErrorMsg, vt)) } defer dec.Release() if err != nil { @@ -327,7 +327,7 @@ func (dec *Decoder) Decode(v interface{}) error { case *interface{}: err = dec.decodeInterface(vt) default: - return InvalidUnmarshalError(fmt.Sprintf(invalidUnmarshalErrorMsg, reflect.TypeOf(vt).String())) + return InvalidUnmarshalError(fmt.Sprintf(invalidUnmarshalErrorMsg, vt)) } if err != nil { return err @@ -536,7 +536,7 @@ func (dec *Decoder) AddArray(v UnmarshalerJSONArray) error { return dec.Array(v) } -// AddArray decodes the next key to a UnmarshalerJSONArray. +// AddArrayNull decodes the next key to a UnmarshalerJSONArray. func (dec *Decoder) AddArrayNull(v UnmarshalerJSONArray) error { return dec.ArrayNull(v) } @@ -546,6 +546,88 @@ func (dec *Decoder) AddInterface(v *interface{}) error { return dec.Interface(v) } +// --- SQL types + +// AddSQLNullString decodes the next key to qn *sql.NullString +func (dec *Decoder) AddSQLNullString(v *sql.NullString) error { + return dec.SQLNullString(v) +} + +// SQLNullString decodes the next key to an *sql.NullString +func (dec *Decoder) SQLNullString(v *sql.NullString) error { + var b *string + if err := dec.StringNull(&b); err != nil { + return err + } + if b == nil { + v.Valid = false + } else { + v.String = *b + v.Valid = true + } + return nil +} + +// AddSQLNullInt64 decodes the next key to qn *sql.NullInt64 +func (dec *Decoder) AddSQLNullInt64(v *sql.NullInt64) error { + return dec.SQLNullInt64(v) +} + +// SQLNullInt64 decodes the next key to an *sql.NullInt64 +func (dec *Decoder) SQLNullInt64(v *sql.NullInt64) error { + var b *int64 + if err := dec.Int64Null(&b); err != nil { + return err + } + if b == nil { + v.Valid = false + } else { + v.Int64 = *b + v.Valid = true + } + return nil +} + +// AddSQLNullFloat64 decodes the next key to qn *sql.NullFloat64 +func (dec *Decoder) AddSQLNullFloat64(v *sql.NullFloat64) error { + return dec.SQLNullFloat64(v) +} + +// SQLNullFloat64 decodes the next key to an *sql.NullFloat64 +func (dec *Decoder) SQLNullFloat64(v *sql.NullFloat64) error { + var b *float64 + if err := dec.Float64Null(&b); err != nil { + return err + } + if b == nil { + v.Valid = false + } else { + v.Float64 = *b + v.Valid = true + } + return nil +} + +// AddSQLNullBool decodes the next key to an *sql.NullBool +func (dec *Decoder) AddSQLNullBool(v *sql.NullBool) error { + return dec.SQLNullBool(v) +} + +// SQLNullBool decodes the next key to an *sql.NullBool +func (dec *Decoder) SQLNullBool(v *sql.NullBool) error { + var b *bool + if err := dec.BoolNull(&b); err != nil { + return err + } + if b == nil { + v.Valid = false + } else { + v.Bool = *b + v.Valid = true + } + return nil +} + // Int decodes the next key to an *int. // If next key value overflows int, an InvalidUnmarshalError error will be returned. func (dec *Decoder) Int(v *int) error { diff --git a/decode_interface_test.go b/decode_interface_test.go @@ -2,7 +2,6 @@ package gojay import ( "encoding/json" - "log" "strings" "testing" @@ -128,7 +127,6 @@ func TestDecodeInterfaceBasic(t *testing.T) { for _, testCase := range testCases { t.Run("DecodeInterface()"+testCase.name, func(t *testing.T) { - log.Print(testCase.name) var i interface{} dec := BorrowDecoder(strings.NewReader(testCase.json)) defer dec.Release() diff --git a/decode_sqlnull_test.go b/decode_sqlnull_test.go @@ -6,6 +6,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestDecodeSQLNullString(t *testing.T) { @@ -199,3 +200,178 @@ func TestDecodeSQLNullBool(t *testing.T) { }, ) } + +type SQLDecodeObject struct { + S sql.NullString + F sql.NullFloat64 + I sql.NullInt64 + B sql.NullBool +} + +func (s *SQLDecodeObject) UnmarshalJSONObject(dec *Decoder, k string) error { + switch k { + case "s": + return dec.SQLNullString(&s.S) + case "f": + return dec.SQLNullFloat64(&s.F) + case "i": + return dec.SQLNullInt64(&s.I) + case "b": + return dec.SQLNullBool(&s.B) + } + return nil +} + +func (s *SQLDecodeObject) NKeys() int { + return 0 +} + +func TestDecodeSQLNullKeys(t *testing.T) { + var testCases = []struct { + name string + json string + expectedResult *SQLDecodeObject + }{ + { + name: "basic all valid", + json: `{ + "s": "foo", + "f": 0.3, + "i": 3, + "b": true + }`, + expectedResult: &SQLDecodeObject{ + S: sql.NullString{ + String: "foo", + Valid: true, + }, + F: sql.NullFloat64{ + Float64: 0.3, + Valid: true, + }, + I: sql.NullInt64{ + Int64: 3, + Valid: true, + }, + B: sql.NullBool{ + Bool: true, + Valid: true, + }, + }, + }, + { + name: "string not valid", + json: `{ + "s": null, + "f": 0.3, + "i": 3, + "b": true + }`, + expectedResult: &SQLDecodeObject{ + S: sql.NullString{ + Valid: false, + }, + F: sql.NullFloat64{ + Float64: 0.3, + Valid: true, + }, + I: sql.NullInt64{ + Int64: 3, + Valid: true, + }, + B: sql.NullBool{ + Bool: true, + Valid: true, + }, + }, + }, + { + name: "string not valid, int not valid", + json: `{ + "s": null, + "f": 0.3, + "i": null, + "b": true + }`, + expectedResult: &SQLDecodeObject{ + S: sql.NullString{ + Valid: false, + }, + F: sql.NullFloat64{ + Float64: 0.3, + Valid: true, + }, + I: sql.NullInt64{ + Valid: false, + }, + B: sql.NullBool{ + Bool: true, + Valid: true, + }, + }, + }, + { + name: "keys absent", + json: `{ + "f": 0.3, + "i": 3, + "b": true + }`, + expectedResult: &SQLDecodeObject{ + S: sql.NullString{ + Valid: false, + }, + F: sql.NullFloat64{ + Float64: 0.3, + Valid: true, + }, + I: sql.NullInt64{ + Valid: true, + Int64: 3, + }, + B: sql.NullBool{ + Bool: true, + Valid: true, + }, + }, + }, + { + name: "keys all null", + json: `{ + "s": null, + "f": null, + "i": null, + "b": null + }`, + expectedResult: &SQLDecodeObject{ + S: sql.NullString{ + Valid: false, + }, + F: sql.NullFloat64{ + Valid: false, + }, + I: sql.NullInt64{ + Valid: false, + }, + B: sql.NullBool{ + Valid: false, + }, + }, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + var o = &SQLDecodeObject{} + var dec = NewDecoder(strings.NewReader(testCase.json)) + var err = dec.Decode(o) + require.Nil(t, err) + require.Equal( + t, + testCase.expectedResult, + o, + ) + }) + } + +} diff --git a/decode_string.go b/decode_string.go @@ -59,6 +59,7 @@ func (dec *Decoder) decodeStringNull(v **string) error { case '"': dec.cursor++ start, end, err := dec.getString() + if err != nil { return err } diff --git a/decode_test.go b/decode_test.go @@ -445,7 +445,7 @@ func allTypesTestCases() []allTypeDecodeTestCase { expectations: func(err error, v interface{}, t *testing.T) { assert.NotNil(t, err, "err must not be nil") assert.IsType(t, InvalidUnmarshalError(""), err, "err must be of type InvalidUnmarshalError") - assert.Equal(t, fmt.Sprintf(invalidUnmarshalErrorMsg, reflect.TypeOf(v).String()), err.Error(), "err message should be equal to invalidUnmarshalErrorMsg") + assert.Equal(t, fmt.Sprintf(invalidUnmarshalErrorMsg, v), err.Error(), "err message should be equal to invalidUnmarshalErrorMsg") }, }, } diff --git a/decode_unsafe.go b/decode_unsafe.go @@ -2,7 +2,6 @@ package gojay import ( "fmt" - "reflect" ) // Unsafe is the structure holding the unsafe version of the API. @@ -111,7 +110,7 @@ func (u decUnsafe) Unmarshal(data []byte, v interface{}) error { dec.data = data _, err = dec.decodeArray(vt) default: - return InvalidUnmarshalError(fmt.Sprintf(invalidUnmarshalErrorMsg, reflect.TypeOf(vt).String())) + return InvalidUnmarshalError(fmt.Sprintf(invalidUnmarshalErrorMsg, vt)) } defer dec.Release() if err != nil { diff --git a/decode_unsafe_test.go b/decode_unsafe_test.go @@ -2,7 +2,6 @@ package gojay import ( "fmt" - "reflect" "testing" "github.com/stretchr/testify/assert" @@ -277,7 +276,7 @@ func TestUnmarshalUnsafeAllTypes(t *testing.T) { expectations: func(err error, v interface{}, t *testing.T) { assert.NotNil(t, err, "err must not be nil") assert.IsType(t, InvalidUnmarshalError(""), err, "err must be of type InvalidUnmarshalError") - assert.Equal(t, fmt.Sprintf(invalidUnmarshalErrorMsg, reflect.TypeOf(v).String()), err.Error(), "err message should be equal to invalidUnmarshalErrorMsg") + assert.Equal(t, fmt.Sprintf(invalidUnmarshalErrorMsg, v), err.Error(), "err message should be equal to invalidUnmarshalErrorMsg") }, }, {