diff --git a/generate.go b/generate.go index 4e0cdd8..bc8c327 100644 --- a/generate.go +++ b/generate.go @@ -52,7 +52,7 @@ func genUnmarshalJSON(mtyp *marshalerType) Function { input = Name(m.scope.newIdent("input")) intertyp = m.intermediateType(m.scope.newIdent(m.mtyp.orig.Obj().Name())) dec = Name(m.scope.newIdent("dec")) - json = Name(m.scope.parent.packageName("encoding/json")) + json = Name("jsonIter") ) fn := Function{ Receiver: recv, @@ -80,7 +80,7 @@ func genMarshalJSON(mtyp *marshalerType) Function { recv = m.receiver() intertyp = m.intermediateType(m.scope.newIdent(m.mtyp.orig.Obj().Name())) enc = Name(m.scope.newIdent("enc")) - json = Name(m.scope.parent.packageName("encoding/json")) + json = Name("jsonIter") ) fn := Function{ Receiver: recv, diff --git a/main.go b/main.go index 795587f..cf446a9 100644 --- a/main.go +++ b/main.go @@ -179,6 +179,7 @@ import ( "io" "os" "reflect" + "slices" "strings" "github.com/garslo/gogen" @@ -298,6 +299,12 @@ func generate(mtyp *marshalerType, cfg *Config) ([]byte, error) { fmt.Fprintln(w) mtyp.scope.writeImportDecl(w) fmt.Fprintln(w) + + if slices.Contains(cfg.Formats, "json") { + fmt.Fprintln(w, "var jsonIter = jsoniter.ConfigCompatibleWithStandardLibrary") + fmt.Fprintln(w) + } + if mtyp.override != nil { writeUseOfOverride(w, mtyp.override, mtyp.scope.qualify) } @@ -360,7 +367,7 @@ func newMarshalerType(fs *token.FileSet, imp types.Importer, typ *types.Named) * mtyp.scope.addReferences(styp) // Add packages which are always needed. - mtyp.scope.addImport("encoding/json") + mtyp.scope.addImport("github.com/json-iterator/go") mtyp.scope.addImport("errors") for i := 0; i < styp.NumFields(); i++ { diff --git a/types_util.go b/types_util.go index 44ab1df..ae3bbaf 100644 --- a/types_util.go +++ b/types_util.go @@ -9,8 +9,11 @@ import ( "fmt" "go/types" "io" + "os" "sort" "strconv" + + "golang.org/x/tools/go/packages" ) // walkNamedTypes runs the callback for all named types contained in the given type. @@ -215,6 +218,20 @@ func (s *fileScope) writeImportDecl(w io.Writer) { // addImport loads a package and adds it to the import set. func (s *fileScope) addImport(path string) { pkg, err := s.imp.Import(path) + if err != nil { + // Fallback to module-aware importer via go/packages + cfg := &packages.Config{ + Mode: packages.NeedTypes | packages.NeedImports | packages.NeedDeps, + // Preserve build environment so that go/packages can resolve modules. + Env: os.Environ(), + Tests: false, + } + pkgs, perr := packages.Load(cfg, path) + if perr == nil && len(pkgs) > 0 && pkgs[0].Types != nil { + pkg = pkgs[0].Types + err = nil + } + } if err != nil { panic(fmt.Errorf("can't import %q: %v", path, err)) }