commit ed06be23fa1713f994603eb28bfe7b59f9ae0f2b
parent b8dc210cd231638e2cdf5b330e4050ceab664d0f
Author: francoispqt <francois@parquet.ninja>
Date: Sun, 27 May 2018 23:13:33 +0800
first commit for code generator
Diffstat:
12 files changed, 838 insertions(+), 0 deletions(-)
diff --git a/gojay/Makefile b/gojay/Makefile
@@ -0,0 +1,3 @@
+.PHONY: build
+build:
+ go build ./
+\ No newline at end of file
diff --git a/gojay/gen.go b/gojay/gen.go
@@ -0,0 +1,81 @@
+package main
+
+import (
+ "go/ast"
+ "html/template"
+ "log"
+ "os"
+
+ "github.com/davecgh/go-spew/spew"
+)
+
+const genFileSuffix = "_gojay.go"
+
+var pkgTpl *template.Template
+var gojayImport = []byte("import \"github.com/francoispqt/gojay\"\n")
+
+func init() {
+ t, err := template.New("pkgDef").
+ Parse("package {{.PkgName}} \n\n")
+ if err != nil {
+ log.Fatal(err)
+ }
+ pkgTpl = t
+}
+
+func (v *vis) gen() error {
+ // open the file
+ f, err := os.OpenFile(v.genFileName(), os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0777)
+ if err != nil {
+ return err
+ }
+ defer f.Close()
+ // write package
+ err = v.writePkg(f)
+ if err != nil {
+ return err
+ }
+ // write import of gojay
+ err = v.writeGojayImport(f)
+ if err != nil {
+ return err
+ }
+ // range over specs
+ // generate interfaces implementations based on type
+ for _, s := range v.specs {
+ switch t := s.Type.(type) {
+ case *ast.StructType:
+ err = v.genStruct(f, s.Name.String(), t)
+ if err != nil {
+ return err
+ }
+ case *ast.ArrayType:
+ spew.Println(t, "arr")
+ }
+ }
+ return nil
+}
+
+func (v *vis) genFileName() string {
+ return v.file[:len(v.file)-3] + genFileSuffix
+}
+
+func (v *vis) writePkg(f *os.File) error {
+ err := pkgTpl.Execute(f, struct {
+ PkgName string
+ }{
+ PkgName: v.pkg,
+ })
+ if err != nil {
+ return err
+ }
+ return nil
+}
+
+func (v *vis) writeGojayImport(f *os.File) error {
+ _, err := f.Write(gojayImport)
+ if err != nil {
+ return err
+ }
+ return nil
+}
diff --git a/gojay/gen_map.go b/gojay/gen_map.go
@@ -0,0 +1 @@
+package main
diff --git a/gojay/gen_struct_marshal.go b/gojay/gen_struct_marshal.go
@@ -0,0 +1,220 @@
+package main
+
+import (
+ "go/ast"
+ "html/template"
+ "log"
+ "os"
+ "strings"
+)
+
+var structMarshalDefTpl *template.Template
+var structMarshalStringTpl *template.Template
+var structMarshalIntTpl *template.Template
+var structMarshalUintTpl *template.Template
+var structMarshalBoolTpl *template.Template
+
+var structIsNilTpl *template.Template
+
+var isNilMethod = `
+// IsNil returns wether the structure is nil value or not
+func (v *{{.StructName}}) IsNil() bool { return v == nil }
+`
+
+func init() {
+ t, err := template.New("structUnmarshalDef").
+ Parse("\n// MarshalJSONObject implements gojay's MarshalerJSONObject" +
+ "\nfunc (v *{{.StructName}}) MarshalJSONOject(enc *gojay.Encoder) {\n",
+ )
+ if err != nil {
+ log.Fatal(err)
+ }
+ structMarshalDefTpl = t
+
+ t, err = template.New("structMarshalCaseString").
+ Parse("\tenc.StringKey(\"{{.Key}}\", v.{{.Field}})\n")
+ if err != nil {
+ log.Fatal(err)
+ }
+ structMarshalStringTpl = t
+
+ t, err = template.New("structMarshalCaseInt").
+ Parse("\tenc.Int{{.IntLen}}Key(\"{{.Key}}\", v.{{.Field}})\n")
+ if err != nil {
+ log.Fatal(err)
+ }
+ structMarshalIntTpl = t
+
+ t, err = template.New("structMarshalCaseUint").
+ Parse("\tenc.Uint{{.IntLen}}Key(\"{{.Key}}\", v.{{.Field}})\n")
+ if err != nil {
+ log.Fatal(err)
+ }
+ structMarshalUintTpl = t
+
+ t, err = template.New("structMarshalCaseBool").
+ Parse("\tenc.BoolKey(\"{{.Key}}\", v.{{.Field}})\n")
+ if err != nil {
+ log.Fatal(err)
+ }
+ structMarshalBoolTpl = t
+
+ t, err = template.New("structMarhalIsNil").
+ Parse(isNilMethod)
+ if err != nil {
+ log.Fatal(err)
+ }
+ structIsNilTpl = t
+}
+
+func (v *vis) structGenIsNil(f *os.File, n string) error {
+ err := structIsNilTpl.Execute(f, struct {
+ StructName string
+ }{
+ StructName: n,
+ })
+ return err
+}
+
+func (v *vis) structGenMarshalObj(f *os.File, n string, s *ast.StructType) (int, error) {
+ err := structMarshalDefTpl.Execute(f, struct {
+ StructName string
+ }{
+ StructName: n,
+ })
+ if err != nil {
+ return 0, err
+ }
+ keys := 0
+ if len(s.Fields.List) > 0 {
+ // TODO: check tags
+ for _, field := range s.Fields.List {
+ switch t := field.Type.(type) {
+ case *ast.Ident:
+ switch t.String() {
+ case "string":
+ err = v.structMarshalString(f, field)
+ if err != nil {
+ return 0, err
+ }
+ keys++
+ case "bool":
+ err = v.structMarshalBool(f, field)
+ if err != nil {
+ return 0, err
+ }
+ keys++
+ case "int":
+ err = v.structMarshalInt(f, field, "")
+ if err != nil {
+ return 0, err
+ }
+ keys++
+ case "int64":
+ err = v.structMarshalInt(f, field, "64")
+ if err != nil {
+ return 0, err
+ }
+ keys++
+ case "int32":
+ err = v.structMarshalInt(f, field, "32")
+ if err != nil {
+ return 0, err
+ }
+ keys++
+ case "int16":
+ err = v.structMarshalInt(f, field, "16")
+ if err != nil {
+ return 0, err
+ }
+ keys++
+ case "int8":
+ err = v.structMarshalInt(f, field, "8")
+ if err != nil {
+ return 0, err
+ }
+ keys++
+ case "uint64":
+ err = v.structMarshalUint(f, field, "64")
+ if err != nil {
+ return 0, err
+ }
+ keys++
+ case "uint32":
+ err = v.structMarshalUint(f, field, "32")
+ if err != nil {
+ return 0, err
+ }
+ keys++
+ case "uint16":
+ err = v.structMarshalUint(f, field, "16")
+ if err != nil {
+ return 0, err
+ }
+ keys++
+ case "uint8":
+ err = v.structMarshalUint(f, field, "8")
+ if err != nil {
+ return 0, err
+ }
+ keys++
+ }
+ }
+ }
+ }
+ _, err = f.Write([]byte("}\n"))
+ if err != nil {
+ return 0, err
+ }
+ return keys, nil
+}
+
+func (v *vis) structMarshalString(f *os.File, field *ast.Field) error {
+ key := field.Names[0].String()
+ err := structMarshalStringTpl.Execute(f, struct {
+ Field string
+ Key string
+ }{key, strings.ToLower(key)})
+ if err != nil {
+ return err
+ }
+ return nil
+}
+
+func (v *vis) structMarshalBool(f *os.File, field *ast.Field) error {
+ key := field.Names[0].String()
+ err := structMarshalBoolTpl.Execute(f, struct {
+ Field string
+ Key string
+ }{key, strings.ToLower(key)})
+ if err != nil {
+ return err
+ }
+ return nil
+}
+
+func (v *vis) structMarshalInt(f *os.File, field *ast.Field, intLen string) error {
+ key := field.Names[0].String()
+ err := structMarshalIntTpl.Execute(f, struct {
+ Field string
+ IntLen string
+ Key string
+ }{key, intLen, strings.ToLower(key)})
+ if err != nil {
+ return err
+ }
+ return nil
+}
+
+func (v *vis) structMarshalUint(f *os.File, field *ast.Field, intLen string) error {
+ key := field.Names[0].String()
+ err := structMarshalUintTpl.Execute(f, struct {
+ Field string
+ IntLen string
+ Key string
+ }{key, intLen, strings.ToLower(key)})
+ if err != nil {
+ return err
+ }
+ return nil
+}
diff --git a/gojay/gen_struct_unmarshal.go b/gojay/gen_struct_unmarshal.go
@@ -0,0 +1,258 @@
+package main
+
+import (
+ "go/ast"
+ "html/template"
+ "log"
+ "os"
+ "strings"
+)
+
+var structUnmarshalDefTpl *template.Template
+var structUnmarshalCaseTpl *template.Template
+var structUnmarshalStringTpl *template.Template
+var structUnmarshalIntTpl *template.Template
+var structUnmarshalUintTpl *template.Template
+var structUnmarshalBoolTpl *template.Template
+
+var structNKeysTpl *template.Template
+
+var nKeysMethod = `
+// NKeys returns the number of keys to unmarshal
+func (v *{{.StructName}}) NKeys() int { return {{.NKeys}} }
+`
+
+var structUnmarshalSwitchOpen = []byte("\tswitch k {\n")
+var structUnmarshalClose = []byte("\treturn nil\n}\n")
+
+func init() {
+ t, err := template.New("structUnmarshalDef").
+ Parse("\n// UnmarshalJSONOject implements gojay's UnmarshalerJSONObject" +
+ "\nfunc (v *{{.StructName}}) UnmarshalJSONOject(dec *gojay.Decoder, k string) error {\n",
+ )
+ if err != nil {
+ log.Fatal(err)
+ }
+ structUnmarshalDefTpl = t
+
+ t, err = template.New("structUnmarshalCase").
+ Parse("\tcase \"{{.Key}}\":\n")
+ if err != nil {
+ log.Fatal(err)
+ }
+ structUnmarshalCaseTpl = t
+
+ t, err = template.New("structUnmarshalCaseString").
+ Parse("\t\treturn dec.String(&v.{{.Field}})\n")
+ if err != nil {
+ log.Fatal(err)
+ }
+ structUnmarshalStringTpl = t
+
+ t, err = template.New("structUnmarshalCaseString").
+ Parse("\t\treturn dec.Int{{.IntLen}}(&v.{{.Field}})\n")
+ if err != nil {
+ log.Fatal(err)
+ }
+ structUnmarshalIntTpl = t
+
+ t, err = template.New("structUnmarshalCaseString").
+ Parse("\t\treturn dec.Uint{{.IntLen}}(&v.{{.Field}})\n")
+ if err != nil {
+ log.Fatal(err)
+ }
+ structUnmarshalUintTpl = t
+
+ t, err = template.New("structUnmarshalCaseString").
+ Parse("\t\treturn dec.Bool(&v.{{.Field}})\n")
+ if err != nil {
+ log.Fatal(err)
+ }
+ structUnmarshalBoolTpl = t
+
+ t, err = template.New("structUnmarshalNKeys").
+ Parse(nKeysMethod)
+ if err != nil {
+ log.Fatal(err)
+ }
+ structNKeysTpl = t
+}
+
+func (v *vis) structGenNKeys(f *os.File, n string, count int) error {
+ err := structNKeysTpl.Execute(f, struct {
+ NKeys int
+ StructName string
+ }{
+ NKeys: count,
+ StructName: n,
+ })
+ return err
+}
+
+func (v *vis) structGenUnmarshalObj(f *os.File, n string, s *ast.StructType) (int, error) {
+ err := structUnmarshalDefTpl.Execute(f, struct {
+ StructName string
+ }{
+ StructName: n,
+ })
+ if err != nil {
+ return 0, err
+ }
+ keys := 0
+ if len(s.Fields.List) > 0 {
+ // open switch statement
+ f.Write(structUnmarshalSwitchOpen)
+
+ // TODO: check tags
+ for _, field := range s.Fields.List {
+ switch t := field.Type.(type) {
+ case *ast.Ident:
+ switch t.String() {
+ case "string":
+ err = v.structUnmarshalString(f, field)
+ if err != nil {
+ return 0, err
+ }
+ keys++
+ case "bool":
+ err = v.structUnmarshalBool(f, field)
+ if err != nil {
+ return 0, err
+ }
+ keys++
+ case "int":
+ err = v.structUnmarshalInt(f, field, "")
+ if err != nil {
+ return 0, err
+ }
+ keys++
+ case "int64":
+ err = v.structUnmarshalInt(f, field, "64")
+ if err != nil {
+ return 0, err
+ }
+ keys++
+ case "int32":
+ err = v.structUnmarshalInt(f, field, "32")
+ if err != nil {
+ return 0, err
+ }
+ keys++
+ case "int16":
+ err = v.structUnmarshalInt(f, field, "16")
+ if err != nil {
+ return 0, err
+ }
+ keys++
+ case "int8":
+ err = v.structUnmarshalInt(f, field, "8")
+ if err != nil {
+ return 0, err
+ }
+ keys++
+ case "uint64":
+ err = v.structUnmarshalUint(f, field, "64")
+ if err != nil {
+ return 0, err
+ }
+ keys++
+ case "uint32":
+ err = v.structUnmarshalUint(f, field, "32")
+ if err != nil {
+ return 0, err
+ }
+ keys++
+ case "uint16":
+ err = v.structUnmarshalUint(f, field, "16")
+ if err != nil {
+ return 0, err
+ }
+ keys++
+ case "uint8":
+ err = v.structUnmarshalUint(f, field, "8")
+ if err != nil {
+ return 0, err
+ }
+ keys++
+ }
+ }
+ }
+ // close switch statement
+ f.Write([]byte("\t}\n"))
+ }
+ _, err = f.Write(structUnmarshalClose)
+ if err != nil {
+ return 0, err
+ }
+ return keys, nil
+}
+
+func (v *vis) structUnmarshalString(f *os.File, field *ast.Field) error {
+ key := field.Names[0].String()
+ err := structUnmarshalCaseTpl.Execute(f, struct {
+ Key string
+ }{strings.ToLower(key)})
+ if err != nil {
+ return err
+ }
+ err = structUnmarshalStringTpl.Execute(f, struct {
+ Field string
+ }{key})
+ if err != nil {
+ return err
+ }
+ return nil
+}
+
+func (v *vis) structUnmarshalBool(f *os.File, field *ast.Field) error {
+ key := field.Names[0].String()
+ err := structUnmarshalCaseTpl.Execute(f, struct {
+ Key string
+ }{strings.ToLower(key)})
+ if err != nil {
+ return err
+ }
+ err = structUnmarshalBoolTpl.Execute(f, struct {
+ Field string
+ }{key})
+ if err != nil {
+ return err
+ }
+ return nil
+}
+
+func (v *vis) structUnmarshalInt(f *os.File, field *ast.Field, intLen string) error {
+ key := field.Names[0].String()
+ err := structUnmarshalCaseTpl.Execute(f, struct {
+ Key string
+ }{strings.ToLower(key)})
+ if err != nil {
+ return err
+ }
+ err = structUnmarshalIntTpl.Execute(f, struct {
+ Field string
+ IntLen string
+ }{key, intLen})
+ if err != nil {
+ return err
+ }
+ return nil
+}
+
+func (v *vis) structUnmarshalUint(f *os.File, field *ast.Field, intLen string) error {
+ key := field.Names[0].String()
+ err := structUnmarshalCaseTpl.Execute(f, struct {
+ Key string
+ }{strings.ToLower(key)})
+ if err != nil {
+ return err
+ }
+ err = structUnmarshalUintTpl.Execute(f, struct {
+ Field string
+ IntLen string
+ }{key, intLen})
+ if err != nil {
+ return err
+ }
+ return nil
+}
diff --git a/gojay/gen_stuct.go b/gojay/gen_stuct.go
@@ -0,0 +1,21 @@
+package main
+
+import (
+ "go/ast"
+ "os"
+)
+
+func (v *vis) genStruct(f *os.File, n string, s *ast.StructType) error {
+ keys, err := v.structGenUnmarshalObj(f, n, s)
+ if err != nil {
+ return err
+ }
+ err = v.structGenNKeys(f, n, keys)
+
+ keys, err = v.structGenMarshalObj(f, n, s)
+ if err != nil {
+ return err
+ }
+ err = v.structGenIsNil(f, n)
+ return nil
+}
diff --git a/gojay/gojay b/gojay/gojay
Binary files differ.
diff --git a/gojay/main.go b/gojay/main.go
@@ -0,0 +1,70 @@
+package main
+
+import (
+ "errors"
+ "go/ast"
+ "go/parser"
+ "go/token"
+ "io/ioutil"
+ "log"
+ "os"
+ "path/filepath"
+ "strings"
+)
+
+const gojayAnnotation = "//gojay:json"
+
+func hasAnnotation(fP string) bool {
+ b, err := ioutil.ReadFile(fP)
+ if err != nil {
+ log.Fatal(err)
+ }
+ return strings.Contains(string(b), gojayAnnotation)
+}
+
+func getPath() (string, error) {
+ p := os.Args[1]
+ return filepath.Abs(p)
+}
+
+func getFiles() ([]string, error) {
+ if len(os.Args) < 2 {
+ return nil, errors.New("Gojay generator takes one argument, 0 given")
+ }
+ p, err := getPath()
+ if err != nil {
+ return nil, err
+ }
+ files, err := ioutil.ReadDir(p)
+ if err != nil {
+ return nil, err
+ }
+ r := make([]string, 0)
+ for _, f := range files {
+ fP := filepath.Join(p, f.Name())
+ if !f.IsDir() && strings.HasSuffix(f.Name(), ".go") && hasAnnotation(fP) {
+ r = append(r, fP)
+ }
+ }
+ return r, nil
+}
+
+func main() {
+ files, err := getFiles()
+ if err != nil {
+ log.Fatal(err)
+ }
+ for _, f := range files {
+ fset := token.NewFileSet()
+ node, err := parser.ParseFile(fset, f, nil, parser.ParseComments)
+ if err != nil {
+ log.Fatal(err)
+ }
+ v := &vis{pkg: node.Name.String(), specs: make([]*ast.TypeSpec, 0), file: f}
+ ast.Walk(v, node)
+ err = v.gen()
+ if err != nil {
+ log.Fatal(err)
+ }
+ }
+}
diff --git a/gojay/tests/basic_structs.go b/gojay/tests/basic_structs.go
@@ -0,0 +1,31 @@
+package tests
+
+//gojay:json
+type A struct {
+ Str string
+ Bool bool
+ Int int
+ Int64 int64
+ Int32 int32
+ Int16 int16
+ Int8 int8
+ Uint64 uint64
+ Uint32 uint32
+ Uint16 uint16
+ Uint8 uint8
+}
+
+//gojay:json
+type B struct {
+ Str string
+ Bool bool
+ Int int
+ Int64 int64
+ Int32 int32
+ Int16 int16
+ Int8 int8
+ Uint64 uint64
+ Uint32 uint32
+ Uint16 uint16
+ Uint8 uint8
+}
diff --git a/gojay/tests/basic_structs_gojay.go b/gojay/tests/basic_structs_gojay.go
@@ -0,0 +1,103 @@
+package tests
+
+import "github.com/francoispqt/gojay"
+
+// UnmarshalJSONOject implements gojay's UnmarshalerJSONObject
+func (v *A) UnmarshalJSONOject(dec *gojay.Decoder, k string) error {
+ switch k {
+ case "str":
+ return dec.String(&v.Str)
+ case "bool":
+ return dec.Bool(&v.Bool)
+ case "int":
+ return dec.Int(&v.Int)
+ case "int64":
+ return dec.Int64(&v.Int64)
+ case "int32":
+ return dec.Int32(&v.Int32)
+ case "int16":
+ return dec.Int16(&v.Int16)
+ case "int8":
+ return dec.Int8(&v.Int8)
+ case "uint64":
+ return dec.Uint64(&v.Uint64)
+ case "uint32":
+ return dec.Uint32(&v.Uint32)
+ case "uint16":
+ return dec.Uint16(&v.Uint16)
+ case "uint8":
+ return dec.Uint8(&v.Uint8)
+ }
+ return nil
+}
+
+// NKeys returns the number of keys to unmarshal
+func (v *A) NKeys() int { return 11 }
+
+// MarshalJSONObject implements gojay's MarshalerJSONObject
+func (v *A) MarshalJSONOject(enc *gojay.Encoder) {
+ enc.StringKey("str", v.Str)
+ enc.BoolKey("bool", v.Bool)
+ enc.IntKey("int", v.Int)
+ enc.Int64Key("int64", v.Int64)
+ enc.Int32Key("int32", v.Int32)
+ enc.Int16Key("int16", v.Int16)
+ enc.Int8Key("int8", v.Int8)
+ enc.Uint64Key("uint64", v.Uint64)
+ enc.Uint32Key("uint32", v.Uint32)
+ enc.Uint16Key("uint16", v.Uint16)
+ enc.Uint8Key("uint8", v.Uint8)
+}
+
+// IsNil returns wether the structure is nil value or not
+func (v *A) IsNil() bool { return v == nil }
+
+// UnmarshalJSONOject implements gojay's UnmarshalerJSONObject
+func (v *B) UnmarshalJSONOject(dec *gojay.Decoder, k string) error {
+ switch k {
+ case "str":
+ return dec.String(&v.Str)
+ case "bool":
+ return dec.Bool(&v.Bool)
+ case "int":
+ return dec.Int(&v.Int)
+ case "int64":
+ return dec.Int64(&v.Int64)
+ case "int32":
+ return dec.Int32(&v.Int32)
+ case "int16":
+ return dec.Int16(&v.Int16)
+ case "int8":
+ return dec.Int8(&v.Int8)
+ case "uint64":
+ return dec.Uint64(&v.Uint64)
+ case "uint32":
+ return dec.Uint32(&v.Uint32)
+ case "uint16":
+ return dec.Uint16(&v.Uint16)
+ case "uint8":
+ return dec.Uint8(&v.Uint8)
+ }
+ return nil
+}
+
+// NKeys returns the number of keys to unmarshal
+func (v *B) NKeys() int { return 11 }
+
+// MarshalJSONObject implements gojay's MarshalerJSONObject
+func (v *B) MarshalJSONOject(enc *gojay.Encoder) {
+ enc.StringKey("str", v.Str)
+ enc.BoolKey("bool", v.Bool)
+ enc.IntKey("int", v.Int)
+ enc.Int64Key("int64", v.Int64)
+ enc.Int32Key("int32", v.Int32)
+ enc.Int16Key("int16", v.Int16)
+ enc.Int8Key("int8", v.Int8)
+ enc.Uint64Key("uint64", v.Uint64)
+ enc.Uint32Key("uint32", v.Uint32)
+ enc.Uint16Key("uint16", v.Uint16)
+ enc.Uint8Key("uint8", v.Uint8)
+}
+
+// IsNil returns wether the structure is nil value or not
+func (v *B) IsNil() bool { return v == nil }
diff --git a/gojay/tests/complex_structs.go b/gojay/tests/complex_structs.go
@@ -0,0 +1 @@
+package tests
diff --git a/gojay/visitor.go b/gojay/visitor.go
@@ -0,0 +1,48 @@
+package main
+
+import (
+ "go/ast"
+ "strings"
+)
+
+func docContains(n *ast.CommentGroup, s string) bool {
+ for _, d := range n.List {
+ if strings.Contains(d.Text, s) {
+ return true
+ }
+ }
+ return false
+}
+
+type vis struct {
+ pkg string
+ specs []*ast.TypeSpec
+ file string
+ commentFound bool
+}
+
+func (v *vis) Visit(n ast.Node) (w ast.Visitor) {
+ switch n := n.(type) {
+ case *ast.Package:
+ v.commentFound = false
+ return v
+ case *ast.File:
+ v.commentFound = false
+ return v
+ case *ast.GenDecl:
+ if n.Doc != nil {
+ v.commentFound = docContains(n.Doc, gojayAnnotation)
+ }
+ return v
+ case *ast.TypeSpec:
+ if v.commentFound {
+ v.specs = append(v.specs, n)
+ }
+ v.commentFound = false
+ return v
+ case *ast.StructType:
+ v.commentFound = false
+ return nil
+ }
+ return nil
+}