commit e463341caf8850e339dff20d8ebab3db8962ad67
parent d096375da20dd060cb672ea511e0c993d6e115a2
Author: francoispqt <francois@parquet.ninja>
Date: Thu, 10 May 2018 23:23:47 +0800
add boolean an null value validation for compliance with rfc, basic skip will be added in unsafe API
Diffstat:
7 files changed, 182 insertions(+), 19 deletions(-)
diff --git a/decode_bool.go b/decode_bool.go
@@ -17,16 +17,29 @@ func (dec *Decoder) decodeBool(v *bool) error {
case ' ', '\n', '\t', '\r', ',':
continue
case 't':
- dec.cursor = dec.cursor + 4
+ dec.cursor++
+ err := dec.assertTrue()
+ if err != nil {
+ return err
+ }
*v = true
return nil
case 'f':
- dec.cursor = dec.cursor + 5
+ dec.cursor++
+ err := dec.assertFalse()
+ if err != nil {
+ return err
+ }
*v = false
return nil
case 'n':
- dec.cursor = dec.cursor + 4
+ dec.cursor++
+ err := dec.assertNull()
+ if err != nil {
+ return err
+ }
*v = false
+ dec.cursor++
return nil
default:
dec.err = InvalidTypeError(
@@ -45,3 +58,82 @@ func (dec *Decoder) decodeBool(v *bool) error {
}
return nil
}
+
+func (dec *Decoder) assertTrue() error {
+ i := 0
+ for ; dec.cursor < dec.length || dec.read(); dec.cursor++ {
+ switch i {
+ case 0:
+ if dec.data[dec.cursor] != 'r' {
+ return InvalidJSONError(fmt.Sprintf(invalidJSONCharErrorMsg, dec.data[dec.cursor], dec.cursor))
+ }
+ case 1:
+ if dec.data[dec.cursor] != 'u' {
+ return InvalidJSONError(fmt.Sprintf(invalidJSONCharErrorMsg, dec.data[dec.cursor], dec.cursor))
+ }
+ case 2:
+ if dec.data[dec.cursor] != 'e' {
+ return InvalidJSONError(fmt.Sprintf(invalidJSONCharErrorMsg, dec.data[dec.cursor], dec.cursor))
+ }
+ return nil
+ default:
+ return InvalidJSONError(fmt.Sprintf(invalidJSONCharErrorMsg, dec.data[dec.cursor], dec.cursor))
+ }
+ i++
+ }
+ return InvalidJSONError("Invalid JSO")
+}
+
+func (dec *Decoder) assertNull() error {
+ i := 0
+ for ; dec.cursor < dec.length || dec.read(); dec.cursor++ {
+ switch i {
+ case 0:
+ if dec.data[dec.cursor] != 'u' {
+ return InvalidJSONError(fmt.Sprintf(invalidJSONCharErrorMsg, dec.data[dec.cursor], dec.cursor))
+ }
+ case 1:
+ if dec.data[dec.cursor] != 'l' {
+ return InvalidJSONError(fmt.Sprintf(invalidJSONCharErrorMsg, dec.data[dec.cursor], dec.cursor))
+ }
+ case 2:
+ if dec.data[dec.cursor] != 'l' {
+ return InvalidJSONError(fmt.Sprintf(invalidJSONCharErrorMsg, dec.data[dec.cursor], dec.cursor))
+ }
+ return nil
+ default:
+ return InvalidJSONError(fmt.Sprintf(invalidJSONCharErrorMsg, dec.data[dec.cursor], dec.cursor))
+ }
+ i++
+ }
+ return InvalidJSONError("Invalid JSON")
+}
+
+func (dec *Decoder) assertFalse() error {
+ i := 0
+ for ; dec.cursor < dec.length || dec.read(); dec.cursor++ {
+ switch i {
+ case 0:
+ if dec.data[dec.cursor] != 'a' {
+ return InvalidJSONError(fmt.Sprintf(invalidJSONCharErrorMsg, dec.data[dec.cursor], dec.cursor))
+ }
+ case 1:
+ if dec.data[dec.cursor] != 'l' {
+ return InvalidJSONError(fmt.Sprintf(invalidJSONCharErrorMsg, dec.data[dec.cursor], dec.cursor))
+ }
+ case 2:
+ if dec.data[dec.cursor] != 's' {
+ return InvalidJSONError(fmt.Sprintf(invalidJSONCharErrorMsg, dec.data[dec.cursor], dec.cursor))
+ }
+ case 3:
+ if dec.data[dec.cursor] != 'e' {
+ return InvalidJSONError(fmt.Sprintf(invalidJSONCharErrorMsg, dec.data[dec.cursor], dec.cursor))
+ }
+ return nil
+ default:
+ return InvalidJSONError(fmt.Sprintf(invalidJSONCharErrorMsg, dec.data[dec.cursor], dec.cursor))
+ }
+ i++
+ }
+ return InvalidJSONError("Invalid JSON")
+}
diff --git a/decode_embedded_json.go b/decode_embedded_json.go
@@ -1,5 +1,7 @@
package gojay
+import "log"
+
// EmbeddedJSON is a raw encoded JSON value.
// It can be used to delay JSON decoding or precompute a JSON encoding.
type EmbeddedJSON []byte
@@ -15,13 +17,32 @@ func (dec *Decoder) decodeEmbeddedJSON(ej *EmbeddedJSON) error {
case ' ', '\n', '\t', '\r', ',':
continue
// is null
- case 'n', 't':
+ case 'n':
+ beginOfEmbeddedJSON = dec.cursor
+ dec.cursor++
+ err := dec.assertNull()
+ if err != nil {
+ return err
+ }
+ dec.cursor++
+ case 't':
beginOfEmbeddedJSON = dec.cursor
- dec.cursor = dec.cursor + 4
+ dec.cursor++
+ err := dec.assertTrue()
+ if err != nil {
+ return err
+ }
+ dec.cursor++
// is false
case 'f':
beginOfEmbeddedJSON = dec.cursor
- dec.cursor = dec.cursor + 5
+ dec.cursor++
+ err := dec.assertFalse()
+ if err != nil {
+ return err
+ }
+ dec.cursor++
+ log.Print(string(dec.data[:dec.cursor]))
// is an object
case '{':
beginOfEmbeddedJSON = dec.cursor
diff --git a/decode_number.go b/decode_number.go
@@ -83,7 +83,12 @@ func (dec *Decoder) decodeInt(v *int) error {
*v = -int(val)
return nil
case 'n':
- dec.cursor = dec.cursor + 4
+ dec.cursor++
+ err := dec.assertNull()
+ if err != nil {
+ return err
+ }
+ dec.cursor++
return nil
default:
dec.err = InvalidTypeError(
@@ -133,7 +138,11 @@ func (dec *Decoder) decodeInt32(v *int32) error {
*v = -val
return nil
case 'n':
- dec.cursor = dec.cursor + 4
+ dec.cursor++
+ err := dec.assertNull()
+ if err != nil {
+ return err
+ }
return nil
default:
dec.err = InvalidTypeError(
@@ -185,7 +194,11 @@ func (dec *Decoder) decodeUint32(v *uint32) error {
*v = val
return nil
case 'n':
- dec.cursor = dec.cursor + 4
+ dec.cursor++
+ err := dec.assertNull()
+ if err != nil {
+ return err
+ }
return nil
default:
dec.err = InvalidTypeError(
@@ -236,7 +249,11 @@ func (dec *Decoder) decodeInt64(v *int64) error {
*v = -val
return nil
case 'n':
- dec.cursor = dec.cursor + 4
+ dec.cursor++
+ err := dec.assertNull()
+ if err != nil {
+ return err
+ }
return nil
default:
dec.err = InvalidTypeError(
@@ -287,7 +304,11 @@ func (dec *Decoder) decodeUint64(v *uint64) error {
*v = val
return nil
case 'n':
- dec.cursor = dec.cursor + 4
+ dec.cursor++
+ err := dec.assertNull()
+ if err != nil {
+ return err
+ }
return nil
default:
dec.err = InvalidTypeError(
@@ -337,7 +358,11 @@ func (dec *Decoder) decodeFloat64(v *float64) error {
*v = -val
return nil
case 'n':
- dec.cursor = dec.cursor + 4
+ dec.cursor++
+ err := dec.assertNull()
+ if err != nil {
+ return err
+ }
return nil
default:
dec.err = InvalidTypeError(
diff --git a/decode_object.go b/decode_object.go
@@ -79,8 +79,12 @@ func (dec *Decoder) decodeObject(j UnmarshalerObject) (int, error) {
}
return dec.cursor, nil
case 'n':
- // is null
- dec.cursor = dec.cursor + 4
+ dec.cursor++
+ err := dec.assertNull()
+ if err != nil {
+ return 0, err
+ }
+ dec.cursor++
return dec.cursor, nil
default:
// can't unmarshall to struct
@@ -184,12 +188,27 @@ func (dec *Decoder) skipData() error {
case ' ', '\n', '\t', '\r', ',':
continue
// is null
- case 'n', 't':
- dec.cursor = dec.cursor + 4
+ case 'n':
+ dec.cursor++
+ err := dec.assertNull()
+ if err != nil {
+ return err
+ }
+ return nil
+ case 't':
+ dec.cursor++
+ err := dec.assertTrue()
+ if err != nil {
+ return err
+ }
return nil
// is false
case 'f':
- dec.cursor = dec.cursor + 5
+ dec.cursor++
+ err := dec.assertFalse()
+ if err != nil {
+ return err
+ }
return nil
// is an object
case '{':
diff --git a/decode_stream_test.go b/decode_stream_test.go
@@ -141,7 +141,6 @@ func TestStreamDecodingObjectsParallel(t *testing.T) {
},
expectations: func(err error, result []*TestObj, t *testing.T) {
assert.Nil(t, err, "err should be nil")
-
assert.Equal(t, 0, result[0].test, "result[0].test should be equal to 0 as input is null")
assert.Equal(t, 0, result[0].test2, "result[0].test2 should be equal to 0 as input is null")
assert.Equal(t, "", result[0].test3, "result[0].test3 should be equal to \"\" as input is null")
diff --git a/decode_string.go b/decode_string.go
@@ -33,7 +33,12 @@ func (dec *Decoder) decodeString(v *string) error {
return nil
// is nil
case 'n':
- dec.cursor = dec.cursor + 4
+ dec.cursor++
+ err := dec.assertNull()
+ if err != nil {
+ return err
+ }
+ dec.cursor++
return nil
default:
dec.err = InvalidTypeError(
diff --git a/errors.go b/errors.go
@@ -1,5 +1,7 @@
package gojay
+const invalidJSONCharErrorMsg = "Invalid JSON character %c found at position %d"
+
// InvalidJSONError is a type representing an error returned when
// Decoding encounters invalid JSON.
type InvalidJSONError string