Allow use of SeqAtomic with pointer-containing types.

Per runtime.memmove, pointers are always copied atomically, as this is required
by the GC. (Also, the init() safety check doesn't work because it gets renamed
to <prefix>init() by template instantiation.)

PiperOrigin-RevId: 345800302
This commit is contained in:
Jamie Liu 2020-12-04 19:05:08 -08:00 committed by gVisor bot
parent 7a1de8583d
commit 8a45c81616
3 changed files with 0 additions and 104 deletions

View File

@ -8,20 +8,12 @@
package template
import (
"fmt"
"reflect"
"strings"
"unsafe"
"gvisor.dev/gvisor/pkg/sync"
)
// Value is a required type parameter.
//
// Value must not contain any pointers, including interface objects, function
// objects, slices, maps, channels, unsafe.Pointer, and arrays or structs
// containing any of the above. An init() function will panic if this property
// does not hold.
type Value struct{}
// SeqAtomicLoad returns a copy of *ptr, ensuring that the read does not race
@ -55,12 +47,3 @@ func SeqAtomicTryLoad(seq *sync.SeqCount, epoch sync.SeqCountEpoch, ptr *Value)
ok = seq.ReadOk(epoch)
return
}
func init() {
var val Value
typ := reflect.TypeOf(val)
name := typ.Name()
if ptrs := sync.PointersInType(typ, name); len(ptrs) != 0 {
panic(fmt.Sprintf("SeqAtomicLoad<%s> is invalid since values %s of type %s contain pointers:\n%s", typ, name, typ, strings.Join(ptrs, "\n")))
}
}

View File

@ -6,8 +6,6 @@
package sync
import (
"fmt"
"reflect"
"sync/atomic"
)
@ -27,9 +25,6 @@ import (
// - SeqCount may be more flexible: correct use of SeqCount.ReadOk allows other
// operations to be made atomic with reads of SeqCount-protected data.
//
// - SeqCount may be less flexible: as of this writing, SeqCount-protected data
// cannot include pointers.
//
// - SeqCount is more cumbersome to use; atomic reads of SeqCount-protected
// data require instantiating function templates using go_generics (see
// seqatomic.go).
@ -128,32 +123,3 @@ func (s *SeqCount) EndWrite() {
panic("SeqCount.EndWrite outside writer critical section")
}
}
// PointersInType returns a list of pointers reachable from values named
// valName of the given type.
//
// PointersInType is not exhaustive, but it is guaranteed that if typ contains
// at least one pointer, then PointersInTypeOf returns a non-empty list.
func PointersInType(typ reflect.Type, valName string) []string {
switch kind := typ.Kind(); kind {
case reflect.Bool, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128:
return nil
case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice, reflect.String, reflect.UnsafePointer:
return []string{valName}
case reflect.Array:
return PointersInType(typ.Elem(), valName+"[]")
case reflect.Struct:
var ptrs []string
for i, n := 0, typ.NumField(); i < n; i++ {
field := typ.Field(i)
ptrs = append(ptrs, PointersInType(field.Type, fmt.Sprintf("%s.%s", valName, field.Name))...)
}
return ptrs
default:
return []string{fmt.Sprintf("%s (of type %s with unknown kind %s)", valName, typ, kind)}
}
}

View File

@ -6,7 +6,6 @@
package sync
import (
"reflect"
"testing"
"time"
)
@ -99,55 +98,3 @@ func BenchmarkSeqCountReadUncontended(b *testing.B) {
}
})
}
func TestPointersInType(t *testing.T) {
for _, test := range []struct {
name string // used for both test and value name
val interface{}
ptrs []string
}{
{
name: "EmptyStruct",
val: struct{}{},
},
{
name: "Int",
val: int(0),
},
{
name: "MixedStruct",
val: struct {
b bool
I int
ExportedPtr *struct{}
unexportedPtr *struct{}
arr [2]int
ptrArr [2]*int
nestedStruct struct {
nestedNonptr int
nestedPtr *int
}
structArr [1]struct {
nonptr int
ptr *int
}
}{},
ptrs: []string{
"MixedStruct.ExportedPtr",
"MixedStruct.unexportedPtr",
"MixedStruct.ptrArr[]",
"MixedStruct.nestedStruct.nestedPtr",
"MixedStruct.structArr[].ptr",
},
},
} {
t.Run(test.name, func(t *testing.T) {
typ := reflect.TypeOf(test.val)
ptrs := PointersInType(typ, test.name)
t.Logf("Found pointers: %v", ptrs)
if (len(ptrs) != 0 || len(test.ptrs) != 0) && !reflect.DeepEqual(ptrs, test.ptrs) {
t.Errorf("Got %v, wanted %v", ptrs, test.ptrs)
}
})
}
}