Infer receiver name for stateify.
PiperOrigin-RevId: 336340035
This commit is contained in:
parent
257703c050
commit
743327817f
|
@ -39,7 +39,7 @@ var (
|
||||||
)
|
)
|
||||||
|
|
||||||
// resolveTypeName returns a qualified type name.
|
// resolveTypeName returns a qualified type name.
|
||||||
func resolveTypeName(name string, typ ast.Expr) (field string, qualified string) {
|
func resolveTypeName(typ ast.Expr) (field string, qualified string) {
|
||||||
for done := false; !done; {
|
for done := false; !done; {
|
||||||
// Resolve star expressions.
|
// Resolve star expressions.
|
||||||
switch rs := typ.(type) {
|
switch rs := typ.(type) {
|
||||||
|
@ -69,11 +69,7 @@ func resolveTypeName(name string, typ ast.Expr) (field string, qualified string)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Figure out actual type name.
|
// Figure out actual type name.
|
||||||
ident, ok := typ.(*ast.Ident)
|
field = typ.(*ast.Ident).Name
|
||||||
if !ok {
|
|
||||||
panic(fmt.Sprintf("type not supported: %s (involves anonymous types?)", name))
|
|
||||||
}
|
|
||||||
field = ident.Name
|
|
||||||
qualified = qualified + field
|
qualified = qualified + field
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -119,7 +115,7 @@ func scanFields(ss *ast.StructType, prefix string, fn scanFunctions) {
|
||||||
} else {
|
} else {
|
||||||
// Anonymous types can't be embedded, so we don't need
|
// Anonymous types can't be embedded, so we don't need
|
||||||
// to worry about providing a useful name here.
|
// to worry about providing a useful name here.
|
||||||
name, _ = resolveTypeName("", field.Type)
|
name, _ = resolveTypeName(field.Type)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Skip _ fields.
|
// Skip _ fields.
|
||||||
|
@ -262,52 +258,39 @@ func main() {
|
||||||
}
|
}
|
||||||
|
|
||||||
type method struct {
|
type method struct {
|
||||||
receiver string
|
typeName string
|
||||||
name string
|
methodName string
|
||||||
}
|
}
|
||||||
|
|
||||||
// Search for and add all methods with a pointer receiver and no other
|
// Search for and add all method to a set. We auto-detecting several
|
||||||
// arguments to a set. We support auto-detecting the existence of
|
// different methods (and insert them if we don't find them, in order
|
||||||
// several different methods with this signature.
|
// to ensure that expectations match reality).
|
||||||
simpleMethods := map[method]struct{}{}
|
//
|
||||||
|
// While we do this, figure out the right receiver name. If there are
|
||||||
|
// multiple distinct receivers, then we will just pick the last one.
|
||||||
|
simpleMethods := make(map[method]struct{})
|
||||||
|
receiverNames := make(map[string]string)
|
||||||
for _, f := range files {
|
for _, f := range files {
|
||||||
|
|
||||||
// Go over all functions.
|
// Go over all functions.
|
||||||
for _, decl := range f.Decls {
|
for _, decl := range f.Decls {
|
||||||
d, ok := decl.(*ast.FuncDecl)
|
d, ok := decl.(*ast.FuncDecl)
|
||||||
if !ok {
|
if !ok {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if d.Name == nil || d.Recv == nil || d.Type == nil {
|
if d.Recv == nil || len(d.Recv.List) != 1 {
|
||||||
// Not a named method.
|
// Not a named method.
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if len(d.Recv.List) != 1 {
|
|
||||||
// Wrong number of receivers?
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if d.Type.Params != nil && len(d.Type.Params.List) != 0 {
|
|
||||||
// Has argument(s).
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if d.Type.Results != nil && len(d.Type.Results.List) != 0 {
|
|
||||||
// Has return(s).
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
pt, ok := d.Recv.List[0].Type.(*ast.StarExpr)
|
// Save the method and the receiver.
|
||||||
if !ok {
|
name, _ := resolveTypeName(d.Recv.List[0].Type)
|
||||||
// Not a pointer receiver.
|
simpleMethods[method{
|
||||||
continue
|
typeName: name,
|
||||||
|
methodName: d.Name.Name,
|
||||||
|
}] = struct{}{}
|
||||||
|
if len(d.Recv.List[0].Names) > 0 {
|
||||||
|
receiverNames[name] = d.Recv.List[0].Names[0].Name
|
||||||
}
|
}
|
||||||
|
|
||||||
t, ok := pt.X.(*ast.Ident)
|
|
||||||
if !ok {
|
|
||||||
// This shouldn't happen with valid Go.
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
simpleMethods[method{t.Name, d.Name.Name}] = struct{}{}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -346,7 +329,11 @@ func main() {
|
||||||
|
|
||||||
for _, gs := range d.Specs {
|
for _, gs := range d.Specs {
|
||||||
ts := gs.(*ast.TypeSpec)
|
ts := gs.(*ast.TypeSpec)
|
||||||
letter := strings.ToLower(ts.Name.Name[:1])
|
recv, ok := receiverNames[ts.Name.Name]
|
||||||
|
if !ok {
|
||||||
|
// Maybe no methods were defined?
|
||||||
|
recv = strings.ToLower(ts.Name.Name[:1])
|
||||||
|
}
|
||||||
switch x := ts.Type.(type) {
|
switch x := ts.Type.(type) {
|
||||||
case *ast.StructType:
|
case *ast.StructType:
|
||||||
maybeEmitImports()
|
maybeEmitImports()
|
||||||
|
@ -363,32 +350,32 @@ func main() {
|
||||||
emitField(name)
|
emitField(name)
|
||||||
}
|
}
|
||||||
emitLoadValue := func(name, typName string) {
|
emitLoadValue := func(name, typName string) {
|
||||||
fmt.Fprintf(outputFile, " stateSourceObject.LoadValue(%d, new(%s), func(y interface{}) { %s.load%s(y.(%s)) })\n", fields[name], typName, letter, camelCased(name), typName)
|
fmt.Fprintf(outputFile, " stateSourceObject.LoadValue(%d, new(%s), func(y interface{}) { %s.load%s(y.(%s)) })\n", fields[name], typName, recv, camelCased(name), typName)
|
||||||
}
|
}
|
||||||
emitLoad := func(name string) {
|
emitLoad := func(name string) {
|
||||||
fmt.Fprintf(outputFile, " stateSourceObject.Load(%d, &%s.%s)\n", fields[name], letter, name)
|
fmt.Fprintf(outputFile, " stateSourceObject.Load(%d, &%s.%s)\n", fields[name], recv, name)
|
||||||
}
|
}
|
||||||
emitLoadWait := func(name string) {
|
emitLoadWait := func(name string) {
|
||||||
fmt.Fprintf(outputFile, " stateSourceObject.LoadWait(%d, &%s.%s)\n", fields[name], letter, name)
|
fmt.Fprintf(outputFile, " stateSourceObject.LoadWait(%d, &%s.%s)\n", fields[name], recv, name)
|
||||||
}
|
}
|
||||||
emitSaveValue := func(name, typName string) {
|
emitSaveValue := func(name, typName string) {
|
||||||
fmt.Fprintf(outputFile, " var %sValue %s = %s.save%s()\n", name, typName, letter, camelCased(name))
|
fmt.Fprintf(outputFile, " var %sValue %s = %s.save%s()\n", name, typName, recv, camelCased(name))
|
||||||
fmt.Fprintf(outputFile, " stateSinkObject.SaveValue(%d, %sValue)\n", fields[name], name)
|
fmt.Fprintf(outputFile, " stateSinkObject.SaveValue(%d, %sValue)\n", fields[name], name)
|
||||||
}
|
}
|
||||||
emitSave := func(name string) {
|
emitSave := func(name string) {
|
||||||
fmt.Fprintf(outputFile, " stateSinkObject.Save(%d, &%s.%s)\n", fields[name], letter, name)
|
fmt.Fprintf(outputFile, " stateSinkObject.Save(%d, &%s.%s)\n", fields[name], recv, name)
|
||||||
}
|
}
|
||||||
emitZeroCheck := func(name string) {
|
emitZeroCheck := func(name string) {
|
||||||
fmt.Fprintf(outputFile, " if !%sIsZeroValue(&%s.%s) { %sFailf(\"%s is %%#v, expected zero\", &%s.%s) }\n", statePrefix, letter, name, statePrefix, name, letter, name)
|
fmt.Fprintf(outputFile, " if !%sIsZeroValue(&%s.%s) { %sFailf(\"%s is %%#v, expected zero\", &%s.%s) }\n", statePrefix, recv, name, statePrefix, name, recv, name)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate the type name method.
|
// Generate the type name method.
|
||||||
fmt.Fprintf(outputFile, "func (%s *%s) StateTypeName() string {\n", letter, ts.Name.Name)
|
fmt.Fprintf(outputFile, "func (%s *%s) StateTypeName() string {\n", recv, ts.Name.Name)
|
||||||
fmt.Fprintf(outputFile, " return \"%s.%s\"\n", *fullPkg, ts.Name.Name)
|
fmt.Fprintf(outputFile, " return \"%s.%s\"\n", *fullPkg, ts.Name.Name)
|
||||||
fmt.Fprintf(outputFile, "}\n\n")
|
fmt.Fprintf(outputFile, "}\n\n")
|
||||||
|
|
||||||
// Generate the fields method.
|
// Generate the fields method.
|
||||||
fmt.Fprintf(outputFile, "func (%s *%s) StateFields() []string {\n", letter, ts.Name.Name)
|
fmt.Fprintf(outputFile, "func (%s *%s) StateFields() []string {\n", recv, ts.Name.Name)
|
||||||
fmt.Fprintf(outputFile, " return []string{\n")
|
fmt.Fprintf(outputFile, " return []string{\n")
|
||||||
scanFields(x, "", scanFunctions{
|
scanFields(x, "", scanFunctions{
|
||||||
normal: emitField,
|
normal: emitField,
|
||||||
|
@ -402,8 +389,11 @@ func main() {
|
||||||
// the code from compiling if a custom beforeSave was defined in a
|
// the code from compiling if a custom beforeSave was defined in a
|
||||||
// file not provided to this binary and prevents inherited methods
|
// file not provided to this binary and prevents inherited methods
|
||||||
// from being called multiple times by overriding them.
|
// from being called multiple times by overriding them.
|
||||||
if _, ok := simpleMethods[method{ts.Name.Name, "beforeSave"}]; !ok && generateSaverLoader {
|
if _, ok := simpleMethods[method{
|
||||||
fmt.Fprintf(outputFile, "func (%s *%s) beforeSave() {}\n\n", letter, ts.Name.Name)
|
typeName: ts.Name.Name,
|
||||||
|
methodName: "beforeSave",
|
||||||
|
}]; !ok && generateSaverLoader {
|
||||||
|
fmt.Fprintf(outputFile, "func (%s *%s) beforeSave() {}\n\n", recv, ts.Name.Name)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate the save method.
|
// Generate the save method.
|
||||||
|
@ -413,8 +403,8 @@ func main() {
|
||||||
// on this specific behavior, but the ability to specify slots
|
// on this specific behavior, but the ability to specify slots
|
||||||
// allows a manual implementation to be order-dependent.
|
// allows a manual implementation to be order-dependent.
|
||||||
if generateSaverLoader {
|
if generateSaverLoader {
|
||||||
fmt.Fprintf(outputFile, "func (%s *%s) StateSave(stateSinkObject %sSink) {\n", letter, ts.Name.Name, statePrefix)
|
fmt.Fprintf(outputFile, "func (%s *%s) StateSave(stateSinkObject %sSink) {\n", recv, ts.Name.Name, statePrefix)
|
||||||
fmt.Fprintf(outputFile, " %s.beforeSave()\n", letter)
|
fmt.Fprintf(outputFile, " %s.beforeSave()\n", recv)
|
||||||
scanFields(x, "", scanFunctions{zerovalue: emitZeroCheck})
|
scanFields(x, "", scanFunctions{zerovalue: emitZeroCheck})
|
||||||
scanFields(x, "", scanFunctions{value: emitSaveValue})
|
scanFields(x, "", scanFunctions{value: emitSaveValue})
|
||||||
scanFields(x, "", scanFunctions{normal: emitSave, wait: emitSave})
|
scanFields(x, "", scanFunctions{normal: emitSave, wait: emitSave})
|
||||||
|
@ -423,16 +413,19 @@ func main() {
|
||||||
|
|
||||||
// Define afterLoad if a definition was not found. We do this for
|
// Define afterLoad if a definition was not found. We do this for
|
||||||
// the same reason that we do it for beforeSave.
|
// the same reason that we do it for beforeSave.
|
||||||
_, hasAfterLoad := simpleMethods[method{ts.Name.Name, "afterLoad"}]
|
_, hasAfterLoad := simpleMethods[method{
|
||||||
|
typeName: ts.Name.Name,
|
||||||
|
methodName: "afterLoad",
|
||||||
|
}]
|
||||||
if !hasAfterLoad && generateSaverLoader {
|
if !hasAfterLoad && generateSaverLoader {
|
||||||
fmt.Fprintf(outputFile, "func (%s *%s) afterLoad() {}\n\n", letter, ts.Name.Name)
|
fmt.Fprintf(outputFile, "func (%s *%s) afterLoad() {}\n\n", recv, ts.Name.Name)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate the load method.
|
// Generate the load method.
|
||||||
//
|
//
|
||||||
// N.B. See the comment above for the save method.
|
// N.B. See the comment above for the save method.
|
||||||
if generateSaverLoader {
|
if generateSaverLoader {
|
||||||
fmt.Fprintf(outputFile, "func (%s *%s) StateLoad(stateSourceObject %sSource) {\n", letter, ts.Name.Name, statePrefix)
|
fmt.Fprintf(outputFile, "func (%s *%s) StateLoad(stateSourceObject %sSource) {\n", recv, ts.Name.Name, statePrefix)
|
||||||
scanFields(x, "", scanFunctions{normal: emitLoad, wait: emitLoadWait})
|
scanFields(x, "", scanFunctions{normal: emitLoad, wait: emitLoadWait})
|
||||||
scanFields(x, "", scanFunctions{value: emitLoadValue})
|
scanFields(x, "", scanFunctions{value: emitLoadValue})
|
||||||
if hasAfterLoad {
|
if hasAfterLoad {
|
||||||
|
@ -440,7 +433,7 @@ func main() {
|
||||||
// AfterLoad is called, the object encodes a dependency on
|
// AfterLoad is called, the object encodes a dependency on
|
||||||
// referred objects (i.e. fields). This means that afterLoad
|
// referred objects (i.e. fields). This means that afterLoad
|
||||||
// will not be called until the other afterLoads are called.
|
// will not be called until the other afterLoads are called.
|
||||||
fmt.Fprintf(outputFile, " stateSourceObject.AfterLoad(%s.afterLoad)\n", letter)
|
fmt.Fprintf(outputFile, " stateSourceObject.AfterLoad(%s.afterLoad)\n", recv)
|
||||||
}
|
}
|
||||||
fmt.Fprintf(outputFile, "}\n\n")
|
fmt.Fprintf(outputFile, "}\n\n")
|
||||||
}
|
}
|
||||||
|
@ -452,10 +445,10 @@ func main() {
|
||||||
maybeEmitImports()
|
maybeEmitImports()
|
||||||
|
|
||||||
// Generate the info methods.
|
// Generate the info methods.
|
||||||
fmt.Fprintf(outputFile, "func (%s *%s) StateTypeName() string {\n", letter, ts.Name.Name)
|
fmt.Fprintf(outputFile, "func (%s *%s) StateTypeName() string {\n", recv, ts.Name.Name)
|
||||||
fmt.Fprintf(outputFile, " return \"%s.%s\"\n", *fullPkg, ts.Name.Name)
|
fmt.Fprintf(outputFile, " return \"%s.%s\"\n", *fullPkg, ts.Name.Name)
|
||||||
fmt.Fprintf(outputFile, "}\n\n")
|
fmt.Fprintf(outputFile, "}\n\n")
|
||||||
fmt.Fprintf(outputFile, "func (%s *%s) StateFields() []string {\n", letter, ts.Name.Name)
|
fmt.Fprintf(outputFile, "func (%s *%s) StateFields() []string {\n", recv, ts.Name.Name)
|
||||||
fmt.Fprintf(outputFile, " return nil\n")
|
fmt.Fprintf(outputFile, " return nil\n")
|
||||||
fmt.Fprintf(outputFile, "}\n\n")
|
fmt.Fprintf(outputFile, "}\n\n")
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue