go-marshal: Allow array lens to be consts and simple expressions.

Previously, go-marshal only allowed literals for array
lengths. However, it's very common for ABI structs to have a fix-sized
array whose length is defined by a constant; for example PATH_MAX.
Having to convert all such arrays to have literal lengths is too
awkward.

PiperOrigin-RevId: 304289345
This commit is contained in:
Rahat Mahmood 2020-04-01 16:50:16 -07:00 committed by gVisor bot
parent aecd3a25a9
commit 1561ae3037
6 changed files with 108 additions and 57 deletions

View File

@ -356,7 +356,7 @@ func (g *Generator) generateOne(t marshallableType, fset *token.FileSet) *interf
case *ast.ArrayType:
i.validateArrayNewtype(t.spec.Name, ty)
// After validate, we can safely call arrayLen.
i.emitMarshallableForArrayNewtype(t.spec.Name, ty.Elt.(*ast.Ident), arrayLen(ty))
i.emitMarshallableForArrayNewtype(t.spec.Name, ty, ty.Elt.(*ast.Ident))
if t.slice != nil {
abortAt(fset.Position(t.slice.comment.Slash), fmt.Sprintf("Array type marked as '+marshal slice:...', but this is not supported. Perhaps fold one of the dimensions?"))
}

View File

@ -15,8 +15,10 @@
package gomarshal
import (
"fmt"
"go/ast"
"go/token"
"strings"
)
// interfaceGenerator generates marshalling interfaces for a single type.
@ -224,3 +226,51 @@ func (g *interfaceGenerator) emitKeepAlive(ptrVar string) {
g.emit("// must live until the use above.\n")
g.emit("runtime.KeepAlive(%s)\n", ptrVar)
}
func (g *interfaceGenerator) expandBinaryExpr(b *strings.Builder, e *ast.BinaryExpr) {
switch x := e.X.(type) {
case *ast.BinaryExpr:
// Recursively expand sub-expression.
g.expandBinaryExpr(b, x)
case *ast.Ident:
fmt.Fprintf(b, "%s", x.Name)
case *ast.BasicLit:
fmt.Fprintf(b, "%s", x.Value)
default:
g.abortAt(e.Pos(), "Cannot convert binary expression to output code. Go-marshal currently only handles simple expressions of literals, constants and basic identifiers")
}
fmt.Fprintf(b, "%s", e.Op)
switch y := e.Y.(type) {
case *ast.BinaryExpr:
// Recursively expand sub-expression.
g.expandBinaryExpr(b, y)
case *ast.Ident:
fmt.Fprintf(b, "%s", y.Name)
case *ast.BasicLit:
fmt.Fprintf(b, "%s", y.Value)
default:
g.abortAt(e.Pos(), "Cannot convert binary expression to output code. Go-marshal currently only handles simple expressions of literals, constants and basic identifiers")
}
}
// arrayLenExpr returns a string containing a valid golang expression
// representing the length of array a. The returned expression should be treated
// as a single value, and will be already parenthesized as required.
func (g *interfaceGenerator) arrayLenExpr(a *ast.ArrayType) string {
var b strings.Builder
switch l := a.Len.(type) {
case *ast.Ident:
fmt.Fprintf(&b, "%s", l.Name)
case *ast.BasicLit:
fmt.Fprintf(&b, "%s", l.Value)
case *ast.BinaryExpr:
g.expandBinaryExpr(&b, l)
return fmt.Sprintf("(%s)", b.String())
default:
g.abortAt(l.Pos(), "Cannot convert this array len expression to output code. Go-marshal currently only handles simple expressions of literals, constants and basic identifiers")
}
return b.String()
}

View File

@ -27,20 +27,12 @@ func (g *interfaceGenerator) validateArrayNewtype(n *ast.Ident, a *ast.ArrayType
g.abortAt(a.Pos(), fmt.Sprintf("Dynamically sized slice '%s' cannot be marshalled, arrays must be statically sized", n.Name))
}
if _, ok := a.Len.(*ast.BasicLit); !ok {
g.abortAt(a.Len.Pos(), fmt.Sprintf("Array size must be a literal, don't use consts or expressions"))
}
if _, ok := a.Elt.(*ast.Ident); !ok {
g.abortAt(a.Elt.Pos(), fmt.Sprintf("Marshalling not supported for arrays with %s elements, array elements must be primitive types", kindString(a.Elt)))
}
if arrayLen(a) <= 0 {
g.abortAt(a.Len.Pos(), fmt.Sprintf("Marshalling not supported for zero length arrays, why does an ABI struct have one?"))
}
}
func (g *interfaceGenerator) emitMarshallableForArrayNewtype(n, elt *ast.Ident, len int) {
func (g *interfaceGenerator) emitMarshallableForArrayNewtype(n *ast.Ident, a *ast.ArrayType, elt *ast.Ident) {
g.recordUsedImport("io")
g.recordUsedImport("marshal")
g.recordUsedImport("reflect")
@ -49,13 +41,15 @@ func (g *interfaceGenerator) emitMarshallableForArrayNewtype(n, elt *ast.Ident,
g.recordUsedImport("unsafe")
g.recordUsedImport("usermem")
lenExpr := g.arrayLenExpr(a)
g.emit("// SizeBytes implements marshal.Marshallable.SizeBytes.\n")
g.emit("func (%s *%s) SizeBytes() int {\n", g.r, g.typeName())
g.inIndent(func() {
if size, dynamic := g.scalarSize(elt); !dynamic {
g.emit("return %d\n", size*len)
g.emit("return %d * %s\n", size, lenExpr)
} else {
g.emit("return (*%s)(nil).SizeBytes() * %d\n", n.Name, len)
g.emit("return (*%s)(nil).SizeBytes() * %s\n", n.Name, lenExpr)
}
})
g.emit("}\n\n")
@ -63,7 +57,7 @@ func (g *interfaceGenerator) emitMarshallableForArrayNewtype(n, elt *ast.Ident,
g.emit("// MarshalBytes implements marshal.Marshallable.MarshalBytes.\n")
g.emit("func (%s *%s) MarshalBytes(dst []byte) {\n", g.r, g.typeName())
g.inIndent(func() {
g.emit("for idx := 0; idx < %d; idx++ {\n", len)
g.emit("for idx := 0; idx < %s; idx++ {\n", lenExpr)
g.inIndent(func() {
g.marshalScalar(fmt.Sprintf("%s[idx]", g.r), elt.Name, "dst")
})
@ -74,7 +68,7 @@ func (g *interfaceGenerator) emitMarshallableForArrayNewtype(n, elt *ast.Ident,
g.emit("// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.\n")
g.emit("func (%s *%s) UnmarshalBytes(src []byte) {\n", g.r, g.typeName())
g.inIndent(func() {
g.emit("for idx := 0; idx < %d; idx++ {\n", len)
g.emit("for idx := 0; idx < %s; idx++ {\n", lenExpr)
g.inIndent(func() {
g.unmarshalScalar(fmt.Sprintf("%s[idx]", g.r), elt.Name, "src")
})

View File

@ -62,8 +62,8 @@ func (g *interfaceGenerator) validateStruct(ts *ast.TypeSpec, st *ast.StructType
// No validation to perform on selector fields. However this
// callback must still be provided.
},
array: func(n, _ *ast.Ident, len int) {
g.validateArrayNewtype(n, f.Type.(*ast.ArrayType))
array: func(n *ast.Ident, a *ast.ArrayType, _ *ast.Ident) {
g.validateArrayNewtype(n, a)
},
unhandled: func(_ *ast.Ident) {
g.abortAt(f.Pos(), fmt.Sprintf("Marshalling not supported for %s fields", kindString(f.Type)))
@ -112,16 +112,13 @@ func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) {
g.recordUsedMarshallable(tName)
dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()", tName))
},
array: func(n, t *ast.Ident, len int) {
if len < 1 {
// Zero-length arrays should've been rejected by validate().
panic("unreachable")
}
array: func(n *ast.Ident, a *ast.ArrayType, t *ast.Ident) {
lenExpr := g.arrayLenExpr(a)
if size, dynamic := g.scalarSize(t); !dynamic {
primitiveSize += size * len
dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("%d*%s", size, lenExpr))
} else {
g.recordUsedMarshallable(t.Name)
dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()*%d", t.Name, len))
dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()*%s", t.Name, lenExpr))
}
},
}.dispatch)
@ -169,22 +166,23 @@ func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) {
}
g.marshalScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "dst")
},
array: func(n, t *ast.Ident, size int) {
array: func(n *ast.Ident, a *ast.ArrayType, t *ast.Ident) {
lenExpr := g.arrayLenExpr(a)
if n.Name == "_" {
g.emit("// Padding: dst[:sizeof(%s)*%d] ~= [%d]%s{0}\n", t.Name, size, size, t.Name)
if len, dynamic := g.scalarSize(t); !dynamic {
g.shift("dst", len*size)
g.emit("// Padding: dst[:sizeof(%s)*%s] ~= [%s]%s{0}\n", t.Name, lenExpr, lenExpr, t.Name)
if size, dynamic := g.scalarSize(t); !dynamic {
g.emit("dst = dst[%d*(%s):]\n", size, lenExpr)
} else {
// We can't use shiftDynamic here because we don't have
// an instance of the dynamic type we can reference here
// (since the version in this struct is anonymous). Use
// a typed nil pointer to call SizeBytes() instead.
g.emit("dst = dst[(*%s)(nil).SizeBytes()*%d:]\n", t.Name, size)
g.emit("dst = dst[(*%s)(nil).SizeBytes()*(%s):]\n", t.Name, lenExpr)
}
return
}
g.emit("for idx := 0; idx < %d; idx++ {\n", size)
g.emit("for idx := 0; idx < %s; idx++ {\n", lenExpr)
g.inIndent(func() {
g.marshalScalar(fmt.Sprintf("%s[idx]", g.fieldAccessor(n)), t.Name, "dst")
})
@ -224,22 +222,23 @@ func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) {
}
g.unmarshalScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "src")
},
array: func(n, t *ast.Ident, size int) {
array: func(n *ast.Ident, a *ast.ArrayType, t *ast.Ident) {
lenExpr := g.arrayLenExpr(a)
if n.Name == "_" {
g.emit("// Padding: ~ copy([%d]%s(%s), src[:sizeof(%s)*%d])\n", size, t.Name, g.fieldAccessor(n), t.Name, size)
if len, dynamic := g.scalarSize(t); !dynamic {
g.shift("src", len*size)
g.emit("// Padding: ~ copy([%s]%s(%s), src[:sizeof(%s)*%s])\n", lenExpr, t.Name, g.fieldAccessor(n), t.Name, lenExpr)
if size, dynamic := g.scalarSize(t); !dynamic {
g.emit("src = src[%d*(%s):]\n", size, lenExpr)
} else {
// We can't use shiftDynamic here because we don't have
// an instance of the dynamic type we can referece here
// (since the version in this struct is anonymous). Use
// a typed nil pointer to call SizeBytes() instead.
g.emit("src = src[(*%s)(nil).SizeBytes()*%d:]\n", t.Name, size)
g.emit("src = src[(*%s)(nil).SizeBytes()*(%s):]\n", t.Name, lenExpr)
}
return
}
g.emit("for idx := 0; idx < %d; idx++ {\n", size)
g.emit("for idx := 0; idx < %s; idx++ {\n", lenExpr)
g.inIndent(func() {
g.unmarshalScalar(fmt.Sprintf("%s[idx]", g.fieldAccessor(n)), t.Name, "src")
})

View File

@ -25,7 +25,6 @@ import (
"path"
"reflect"
"sort"
"strconv"
"strings"
)
@ -75,29 +74,10 @@ func forEachStructField(st *ast.StructType, fn func(f *ast.Field)) {
type fieldDispatcher struct {
primitive func(n, t *ast.Ident)
selector func(n, tX, tSel *ast.Ident)
array func(n, t *ast.Ident, size int)
array func(n *ast.Ident, a *ast.ArrayType, t *ast.Ident)
unhandled func(n *ast.Ident)
}
// Precondition: a must have a literal for the array length. Consts and
// expressions are not allowed as array lengths, and should be rejected by the
// caller.
func arrayLen(a *ast.ArrayType) int {
if a.Len == nil {
// Probably a slice? Must be handled by caller.
panic("Nil array length in array type")
}
lenLit, ok := a.Len.(*ast.BasicLit)
if !ok {
panic("Array has non-literal for length")
}
len, err := strconv.Atoi(lenLit.Value)
if err != nil {
panic(fmt.Sprintf("Failed to parse array length '%s' as number: %v", lenLit.Value, err))
}
return len
}
// Precondition: All dispatch callbacks that will be invoked must be
// provided. Embedded fields are not allowed, len(f.Names) >= 1.
func (fd fieldDispatcher) dispatch(f *ast.Field) {
@ -123,7 +103,7 @@ func (fd fieldDispatcher) dispatch(f *ast.Field) {
case *ast.ArrayType:
switch t := v.Elt.(type) {
case *ast.Ident:
fd.array(name, t, arrayLen(v))
fd.array(name, v, t)
default:
// Should be handled with a better error message during validate.
panic(fmt.Sprintf("Array element type is of unsupported kind. Expected *ast.Ident, got %v", t))

View File

@ -146,3 +146,31 @@ type SignalSet uint64
//
// +marshal slice:SignalSetAliasSlice
type SignalSetAlias SignalSet
const sizeA = 64
const sizeB = 8
// TestArray is a test data structure on an array with a constant length.
//
// +marshal
type TestArray [sizeA]int32
// TestArray2 is a newtype on an array with a simple arithmetic expression of
// constants for the array length.
//
// +marshal
type TestArray2 [sizeA * sizeB]int32
// TestArray2 is a newtype on an array with a simple arithmetic expression of
// mixed constants and literals for the array length.
//
// +marshal
type TestArray3 [sizeA*sizeB + 12]int32
// Type9 is a test data type containing an array with a non-literal length.
//
// +marshal
type Type9 struct {
x int64
y [sizeA]int32
}