go-marshal: Allow array lens to be consts and simple expressions.
Previously, go-marshal only allowed literals for array lengths. However, it's very common for ABI structs to have a fix-sized array whose length is defined by a constant; for example PATH_MAX. Having to convert all such arrays to have literal lengths is too awkward. PiperOrigin-RevId: 304289345
This commit is contained in:
parent
aecd3a25a9
commit
1561ae3037
|
@ -356,7 +356,7 @@ func (g *Generator) generateOne(t marshallableType, fset *token.FileSet) *interf
|
|||
case *ast.ArrayType:
|
||||
i.validateArrayNewtype(t.spec.Name, ty)
|
||||
// After validate, we can safely call arrayLen.
|
||||
i.emitMarshallableForArrayNewtype(t.spec.Name, ty.Elt.(*ast.Ident), arrayLen(ty))
|
||||
i.emitMarshallableForArrayNewtype(t.spec.Name, ty, ty.Elt.(*ast.Ident))
|
||||
if t.slice != nil {
|
||||
abortAt(fset.Position(t.slice.comment.Slash), fmt.Sprintf("Array type marked as '+marshal slice:...', but this is not supported. Perhaps fold one of the dimensions?"))
|
||||
}
|
||||
|
|
|
@ -15,8 +15,10 @@
|
|||
package gomarshal
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"go/ast"
|
||||
"go/token"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// interfaceGenerator generates marshalling interfaces for a single type.
|
||||
|
@ -224,3 +226,51 @@ func (g *interfaceGenerator) emitKeepAlive(ptrVar string) {
|
|||
g.emit("// must live until the use above.\n")
|
||||
g.emit("runtime.KeepAlive(%s)\n", ptrVar)
|
||||
}
|
||||
|
||||
func (g *interfaceGenerator) expandBinaryExpr(b *strings.Builder, e *ast.BinaryExpr) {
|
||||
switch x := e.X.(type) {
|
||||
case *ast.BinaryExpr:
|
||||
// Recursively expand sub-expression.
|
||||
g.expandBinaryExpr(b, x)
|
||||
case *ast.Ident:
|
||||
fmt.Fprintf(b, "%s", x.Name)
|
||||
case *ast.BasicLit:
|
||||
fmt.Fprintf(b, "%s", x.Value)
|
||||
default:
|
||||
g.abortAt(e.Pos(), "Cannot convert binary expression to output code. Go-marshal currently only handles simple expressions of literals, constants and basic identifiers")
|
||||
}
|
||||
|
||||
fmt.Fprintf(b, "%s", e.Op)
|
||||
|
||||
switch y := e.Y.(type) {
|
||||
case *ast.BinaryExpr:
|
||||
// Recursively expand sub-expression.
|
||||
g.expandBinaryExpr(b, y)
|
||||
case *ast.Ident:
|
||||
fmt.Fprintf(b, "%s", y.Name)
|
||||
case *ast.BasicLit:
|
||||
fmt.Fprintf(b, "%s", y.Value)
|
||||
default:
|
||||
g.abortAt(e.Pos(), "Cannot convert binary expression to output code. Go-marshal currently only handles simple expressions of literals, constants and basic identifiers")
|
||||
}
|
||||
}
|
||||
|
||||
// arrayLenExpr returns a string containing a valid golang expression
|
||||
// representing the length of array a. The returned expression should be treated
|
||||
// as a single value, and will be already parenthesized as required.
|
||||
func (g *interfaceGenerator) arrayLenExpr(a *ast.ArrayType) string {
|
||||
var b strings.Builder
|
||||
|
||||
switch l := a.Len.(type) {
|
||||
case *ast.Ident:
|
||||
fmt.Fprintf(&b, "%s", l.Name)
|
||||
case *ast.BasicLit:
|
||||
fmt.Fprintf(&b, "%s", l.Value)
|
||||
case *ast.BinaryExpr:
|
||||
g.expandBinaryExpr(&b, l)
|
||||
return fmt.Sprintf("(%s)", b.String())
|
||||
default:
|
||||
g.abortAt(l.Pos(), "Cannot convert this array len expression to output code. Go-marshal currently only handles simple expressions of literals, constants and basic identifiers")
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
|
|
@ -27,20 +27,12 @@ func (g *interfaceGenerator) validateArrayNewtype(n *ast.Ident, a *ast.ArrayType
|
|||
g.abortAt(a.Pos(), fmt.Sprintf("Dynamically sized slice '%s' cannot be marshalled, arrays must be statically sized", n.Name))
|
||||
}
|
||||
|
||||
if _, ok := a.Len.(*ast.BasicLit); !ok {
|
||||
g.abortAt(a.Len.Pos(), fmt.Sprintf("Array size must be a literal, don't use consts or expressions"))
|
||||
}
|
||||
|
||||
if _, ok := a.Elt.(*ast.Ident); !ok {
|
||||
g.abortAt(a.Elt.Pos(), fmt.Sprintf("Marshalling not supported for arrays with %s elements, array elements must be primitive types", kindString(a.Elt)))
|
||||
}
|
||||
|
||||
if arrayLen(a) <= 0 {
|
||||
g.abortAt(a.Len.Pos(), fmt.Sprintf("Marshalling not supported for zero length arrays, why does an ABI struct have one?"))
|
||||
}
|
||||
}
|
||||
|
||||
func (g *interfaceGenerator) emitMarshallableForArrayNewtype(n, elt *ast.Ident, len int) {
|
||||
func (g *interfaceGenerator) emitMarshallableForArrayNewtype(n *ast.Ident, a *ast.ArrayType, elt *ast.Ident) {
|
||||
g.recordUsedImport("io")
|
||||
g.recordUsedImport("marshal")
|
||||
g.recordUsedImport("reflect")
|
||||
|
@ -49,13 +41,15 @@ func (g *interfaceGenerator) emitMarshallableForArrayNewtype(n, elt *ast.Ident,
|
|||
g.recordUsedImport("unsafe")
|
||||
g.recordUsedImport("usermem")
|
||||
|
||||
lenExpr := g.arrayLenExpr(a)
|
||||
|
||||
g.emit("// SizeBytes implements marshal.Marshallable.SizeBytes.\n")
|
||||
g.emit("func (%s *%s) SizeBytes() int {\n", g.r, g.typeName())
|
||||
g.inIndent(func() {
|
||||
if size, dynamic := g.scalarSize(elt); !dynamic {
|
||||
g.emit("return %d\n", size*len)
|
||||
g.emit("return %d * %s\n", size, lenExpr)
|
||||
} else {
|
||||
g.emit("return (*%s)(nil).SizeBytes() * %d\n", n.Name, len)
|
||||
g.emit("return (*%s)(nil).SizeBytes() * %s\n", n.Name, lenExpr)
|
||||
}
|
||||
})
|
||||
g.emit("}\n\n")
|
||||
|
@ -63,7 +57,7 @@ func (g *interfaceGenerator) emitMarshallableForArrayNewtype(n, elt *ast.Ident,
|
|||
g.emit("// MarshalBytes implements marshal.Marshallable.MarshalBytes.\n")
|
||||
g.emit("func (%s *%s) MarshalBytes(dst []byte) {\n", g.r, g.typeName())
|
||||
g.inIndent(func() {
|
||||
g.emit("for idx := 0; idx < %d; idx++ {\n", len)
|
||||
g.emit("for idx := 0; idx < %s; idx++ {\n", lenExpr)
|
||||
g.inIndent(func() {
|
||||
g.marshalScalar(fmt.Sprintf("%s[idx]", g.r), elt.Name, "dst")
|
||||
})
|
||||
|
@ -74,7 +68,7 @@ func (g *interfaceGenerator) emitMarshallableForArrayNewtype(n, elt *ast.Ident,
|
|||
g.emit("// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.\n")
|
||||
g.emit("func (%s *%s) UnmarshalBytes(src []byte) {\n", g.r, g.typeName())
|
||||
g.inIndent(func() {
|
||||
g.emit("for idx := 0; idx < %d; idx++ {\n", len)
|
||||
g.emit("for idx := 0; idx < %s; idx++ {\n", lenExpr)
|
||||
g.inIndent(func() {
|
||||
g.unmarshalScalar(fmt.Sprintf("%s[idx]", g.r), elt.Name, "src")
|
||||
})
|
||||
|
|
|
@ -62,8 +62,8 @@ func (g *interfaceGenerator) validateStruct(ts *ast.TypeSpec, st *ast.StructType
|
|||
// No validation to perform on selector fields. However this
|
||||
// callback must still be provided.
|
||||
},
|
||||
array: func(n, _ *ast.Ident, len int) {
|
||||
g.validateArrayNewtype(n, f.Type.(*ast.ArrayType))
|
||||
array: func(n *ast.Ident, a *ast.ArrayType, _ *ast.Ident) {
|
||||
g.validateArrayNewtype(n, a)
|
||||
},
|
||||
unhandled: func(_ *ast.Ident) {
|
||||
g.abortAt(f.Pos(), fmt.Sprintf("Marshalling not supported for %s fields", kindString(f.Type)))
|
||||
|
@ -112,16 +112,13 @@ func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) {
|
|||
g.recordUsedMarshallable(tName)
|
||||
dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()", tName))
|
||||
},
|
||||
array: func(n, t *ast.Ident, len int) {
|
||||
if len < 1 {
|
||||
// Zero-length arrays should've been rejected by validate().
|
||||
panic("unreachable")
|
||||
}
|
||||
array: func(n *ast.Ident, a *ast.ArrayType, t *ast.Ident) {
|
||||
lenExpr := g.arrayLenExpr(a)
|
||||
if size, dynamic := g.scalarSize(t); !dynamic {
|
||||
primitiveSize += size * len
|
||||
dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("%d*%s", size, lenExpr))
|
||||
} else {
|
||||
g.recordUsedMarshallable(t.Name)
|
||||
dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()*%d", t.Name, len))
|
||||
dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()*%s", t.Name, lenExpr))
|
||||
}
|
||||
},
|
||||
}.dispatch)
|
||||
|
@ -169,22 +166,23 @@ func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) {
|
|||
}
|
||||
g.marshalScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "dst")
|
||||
},
|
||||
array: func(n, t *ast.Ident, size int) {
|
||||
array: func(n *ast.Ident, a *ast.ArrayType, t *ast.Ident) {
|
||||
lenExpr := g.arrayLenExpr(a)
|
||||
if n.Name == "_" {
|
||||
g.emit("// Padding: dst[:sizeof(%s)*%d] ~= [%d]%s{0}\n", t.Name, size, size, t.Name)
|
||||
if len, dynamic := g.scalarSize(t); !dynamic {
|
||||
g.shift("dst", len*size)
|
||||
g.emit("// Padding: dst[:sizeof(%s)*%s] ~= [%s]%s{0}\n", t.Name, lenExpr, lenExpr, t.Name)
|
||||
if size, dynamic := g.scalarSize(t); !dynamic {
|
||||
g.emit("dst = dst[%d*(%s):]\n", size, lenExpr)
|
||||
} else {
|
||||
// We can't use shiftDynamic here because we don't have
|
||||
// an instance of the dynamic type we can reference here
|
||||
// (since the version in this struct is anonymous). Use
|
||||
// a typed nil pointer to call SizeBytes() instead.
|
||||
g.emit("dst = dst[(*%s)(nil).SizeBytes()*%d:]\n", t.Name, size)
|
||||
g.emit("dst = dst[(*%s)(nil).SizeBytes()*(%s):]\n", t.Name, lenExpr)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
g.emit("for idx := 0; idx < %d; idx++ {\n", size)
|
||||
g.emit("for idx := 0; idx < %s; idx++ {\n", lenExpr)
|
||||
g.inIndent(func() {
|
||||
g.marshalScalar(fmt.Sprintf("%s[idx]", g.fieldAccessor(n)), t.Name, "dst")
|
||||
})
|
||||
|
@ -224,22 +222,23 @@ func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) {
|
|||
}
|
||||
g.unmarshalScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "src")
|
||||
},
|
||||
array: func(n, t *ast.Ident, size int) {
|
||||
array: func(n *ast.Ident, a *ast.ArrayType, t *ast.Ident) {
|
||||
lenExpr := g.arrayLenExpr(a)
|
||||
if n.Name == "_" {
|
||||
g.emit("// Padding: ~ copy([%d]%s(%s), src[:sizeof(%s)*%d])\n", size, t.Name, g.fieldAccessor(n), t.Name, size)
|
||||
if len, dynamic := g.scalarSize(t); !dynamic {
|
||||
g.shift("src", len*size)
|
||||
g.emit("// Padding: ~ copy([%s]%s(%s), src[:sizeof(%s)*%s])\n", lenExpr, t.Name, g.fieldAccessor(n), t.Name, lenExpr)
|
||||
if size, dynamic := g.scalarSize(t); !dynamic {
|
||||
g.emit("src = src[%d*(%s):]\n", size, lenExpr)
|
||||
} else {
|
||||
// We can't use shiftDynamic here because we don't have
|
||||
// an instance of the dynamic type we can referece here
|
||||
// (since the version in this struct is anonymous). Use
|
||||
// a typed nil pointer to call SizeBytes() instead.
|
||||
g.emit("src = src[(*%s)(nil).SizeBytes()*%d:]\n", t.Name, size)
|
||||
g.emit("src = src[(*%s)(nil).SizeBytes()*(%s):]\n", t.Name, lenExpr)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
g.emit("for idx := 0; idx < %d; idx++ {\n", size)
|
||||
g.emit("for idx := 0; idx < %s; idx++ {\n", lenExpr)
|
||||
g.inIndent(func() {
|
||||
g.unmarshalScalar(fmt.Sprintf("%s[idx]", g.fieldAccessor(n)), t.Name, "src")
|
||||
})
|
||||
|
|
|
@ -25,7 +25,6 @@ import (
|
|||
"path"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
|
@ -75,29 +74,10 @@ func forEachStructField(st *ast.StructType, fn func(f *ast.Field)) {
|
|||
type fieldDispatcher struct {
|
||||
primitive func(n, t *ast.Ident)
|
||||
selector func(n, tX, tSel *ast.Ident)
|
||||
array func(n, t *ast.Ident, size int)
|
||||
array func(n *ast.Ident, a *ast.ArrayType, t *ast.Ident)
|
||||
unhandled func(n *ast.Ident)
|
||||
}
|
||||
|
||||
// Precondition: a must have a literal for the array length. Consts and
|
||||
// expressions are not allowed as array lengths, and should be rejected by the
|
||||
// caller.
|
||||
func arrayLen(a *ast.ArrayType) int {
|
||||
if a.Len == nil {
|
||||
// Probably a slice? Must be handled by caller.
|
||||
panic("Nil array length in array type")
|
||||
}
|
||||
lenLit, ok := a.Len.(*ast.BasicLit)
|
||||
if !ok {
|
||||
panic("Array has non-literal for length")
|
||||
}
|
||||
len, err := strconv.Atoi(lenLit.Value)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Failed to parse array length '%s' as number: %v", lenLit.Value, err))
|
||||
}
|
||||
return len
|
||||
}
|
||||
|
||||
// Precondition: All dispatch callbacks that will be invoked must be
|
||||
// provided. Embedded fields are not allowed, len(f.Names) >= 1.
|
||||
func (fd fieldDispatcher) dispatch(f *ast.Field) {
|
||||
|
@ -123,7 +103,7 @@ func (fd fieldDispatcher) dispatch(f *ast.Field) {
|
|||
case *ast.ArrayType:
|
||||
switch t := v.Elt.(type) {
|
||||
case *ast.Ident:
|
||||
fd.array(name, t, arrayLen(v))
|
||||
fd.array(name, v, t)
|
||||
default:
|
||||
// Should be handled with a better error message during validate.
|
||||
panic(fmt.Sprintf("Array element type is of unsupported kind. Expected *ast.Ident, got %v", t))
|
||||
|
|
|
@ -146,3 +146,31 @@ type SignalSet uint64
|
|||
//
|
||||
// +marshal slice:SignalSetAliasSlice
|
||||
type SignalSetAlias SignalSet
|
||||
|
||||
const sizeA = 64
|
||||
const sizeB = 8
|
||||
|
||||
// TestArray is a test data structure on an array with a constant length.
|
||||
//
|
||||
// +marshal
|
||||
type TestArray [sizeA]int32
|
||||
|
||||
// TestArray2 is a newtype on an array with a simple arithmetic expression of
|
||||
// constants for the array length.
|
||||
//
|
||||
// +marshal
|
||||
type TestArray2 [sizeA * sizeB]int32
|
||||
|
||||
// TestArray2 is a newtype on an array with a simple arithmetic expression of
|
||||
// mixed constants and literals for the array length.
|
||||
//
|
||||
// +marshal
|
||||
type TestArray3 [sizeA*sizeB + 12]int32
|
||||
|
||||
// Type9 is a test data type containing an array with a non-literal length.
|
||||
//
|
||||
// +marshal
|
||||
type Type9 struct {
|
||||
x int64
|
||||
y [sizeA]int32
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue