600 lines
16 KiB
Go
600 lines
16 KiB
Go
// Copyright 2018 Google Inc.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
// Package urpc provides a minimal RPC package based on unet.
|
|
//
|
|
// RPC requests are _not_ concurrent and methods must be explicitly
|
|
// registered. However, files may be sent as part of the payload.
|
|
package urpc
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"os"
|
|
"reflect"
|
|
"runtime"
|
|
"sync"
|
|
|
|
"gvisor.googlesource.com/gvisor/pkg/log"
|
|
"gvisor.googlesource.com/gvisor/pkg/unet"
|
|
)
|
|
|
|
// maxFiles is the largest number of files that may accompany a single
// request or response.
const maxFiles = 16

var (
	// ErrTooManyFiles is returned when a payload carries more than maxFiles
	// files.
	ErrTooManyFiles = errors.New("too many files")

	// ErrUnknownMethod is returned when the requested method has not been
	// registered.
	ErrUnknownMethod = errors.New("unknown method")

	// errStopped is an internal error indicating the server has been
	// stopped.
	errStopped = errors.New("stopped")
)
|
|
|
|
// RemoteError reports a failure of the remotely invoked function itself.
//
// Receiving a RemoteError means the RPC transport worked: the call reached
// the server and a response came back, but the called function returned an
// error.
type RemoteError struct {
	// Message holds the text produced by Error() on the remote side.
	Message string
}

// Error implements the error interface by returning the remote error text.
func (e RemoteError) Error() string {
	return e.Message
}
|
|
|
|
// FilePayload may be _embedded_ in another type in order to send or receive
// files as part of an RPC. The files are never serialized into the JSON
// body (note the "-" tag); they travel in an accompanying SCM_RIGHTS message
// (plumbed through the unet package).
type FilePayload struct {
	Files []*os.File `json:"-"`
}

// filePayload returns the attached files. The result may be nil.
func (p *FilePayload) filePayload() []*os.File {
	return p.Files
}

// setFilePayload replaces the attached files with fs.
func (p *FilePayload) setFilePayload(fs []*os.File) {
	p.Files = fs
}
|
|
|
|
// closeAll closes every file in files. Errors returned by Close are
// ignored; a nil or empty slice is a no-op.
func closeAll(files []*os.File) {
	for i := range files {
		files[i].Close()
	}
}
|
|
|
|
// filePayloader is the internal view of a value that carries files.
//
// It is implemented by FilePayload and, implicitly, by every type that
// embeds FilePayload. Because both methods are unexported, code outside
// this package cannot satisfy the interface in any other way.
type filePayloader interface {
	// filePayload returns the files to transmit; it may return nil.
	filePayload() []*os.File

	// setFilePayload stores the files that were received.
	setFilePayload([]*os.File)
}
|
|
|
|
// clientCall is the request message as built by the client. Arg is kept as
// an arbitrary value and encoded together with the method name.
type clientCall struct {
	Method string      `json:"method"`
	Arg    interface{} `json:"arg"`
}
|
|
|
|
// serverCall is the request message as seen by the server. Arg is left as
// raw JSON so it can be decoded once the method, and therefore the argument
// type, is known.
type serverCall struct {
	Method string          `json:"method"`
	Arg    json.RawMessage `json:"arg"`
}
|
|
|
|
// callResult is the response message sent from the server to the client.
type callResult struct {
	// Success reports whether the call completed without error.
	Success bool `json:"success"`

	// Err carries the error text for a call that did not succeed.
	Err string `json:"err"`

	// Result is the value produced by a successful call.
	Result interface{} `json:"result"`
}
|
|
|
|
// registeredMethod describes one method that has been registered with the
// server, with everything needed to invoke it through reflection.
type registeredMethod struct {
	// fn is the underlying function.
	fn reflect.Value

	// rcvr is the receiver value passed as the first call argument.
	rcvr reflect.Value

	// argType is the (pointer) type of the method's argument parameter.
	argType reflect.Type

	// resultType is the (pointer) type of the method's result parameter.
	resultType reflect.Type
}
|
|
|
|
// clientState tracks where a client connection is in its lifecycle.
//
// Valid states:
//
//	idle           - not processing any requests, no close request.
//	processing     - actively processing, no close request.
//	closeRequested - actively processing, pending close.
//	closed         - client connection has been closed.
//
// Possible transitions:
//
//	idle           -> processing, closed
//	processing     -> idle, closeRequested
//	closeRequested -> closed
type clientState int

// The states a client may be in; see clientState. idle is the zero value.
const (
	// idle: no request in flight and no close pending.
	idle clientState = iota

	// processing: a request is in flight.
	processing

	// closeRequested: a request is in flight and the connection is to be
	// closed once it completes.
	closeRequested

	// closed: the connection has been closed.
	closed
)
|
|
|
|
// Server is an RPC server.
|
|
type Server struct {
|
|
// mu protects all fields, except wg.
|
|
mu sync.Mutex
|
|
|
|
// methods is the set of server methods.
|
|
methods map[string]registeredMethod
|
|
|
|
// clients is a map of clients.
|
|
clients map[*unet.Socket]clientState
|
|
|
|
// wg is a wait group for all outstanding clients.
|
|
wg sync.WaitGroup
|
|
}
|
|
|
|
// NewServer returns a new server.
|
|
func NewServer() *Server {
|
|
return &Server{
|
|
methods: make(map[string]registeredMethod),
|
|
clients: make(map[*unet.Socket]clientState),
|
|
}
|
|
}
|
|
|
|
// Register registers the given object as an RPC receiver.
|
|
//
|
|
// This functions is the same way as the built-in RPC package, but it does not
|
|
// tolerate any object with non-conforming methods. Any non-confirming methods
|
|
// will lead to an immediate panic, instead of being skipped or an error.
|
|
// Panics will also be generated by anonymous objects and duplicate entries.
|
|
func (s *Server) Register(obj interface{}) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
typ := reflect.TypeOf(obj)
|
|
|
|
// If we got a pointer, deref it to the underlying object. We need this to
|
|
// obtain the name of the underlying type.
|
|
typDeref := typ
|
|
if typ.Kind() == reflect.Ptr {
|
|
typDeref = typ.Elem()
|
|
}
|
|
|
|
for m := 0; m < typ.NumMethod(); m++ {
|
|
method := typ.Method(m)
|
|
|
|
if typDeref.Name() == "" {
|
|
// Can't be anonymous.
|
|
panic("type not named.")
|
|
}
|
|
|
|
prettyName := typDeref.Name() + "." + method.Name
|
|
if _, ok := s.methods[prettyName]; ok {
|
|
// Duplicate entry.
|
|
panic(fmt.Sprintf("method %s is duplicated.", prettyName))
|
|
}
|
|
|
|
if method.PkgPath != "" {
|
|
// Must be exported.
|
|
panic(fmt.Sprintf("method %s is not exported.", prettyName))
|
|
}
|
|
mtype := method.Type
|
|
if mtype.NumIn() != 3 {
|
|
// Need exactly two arguments (+ receiver).
|
|
panic(fmt.Sprintf("method %s has wrong number of arguments.", prettyName))
|
|
}
|
|
argType := mtype.In(1)
|
|
if argType.Kind() != reflect.Ptr {
|
|
// Need arg pointer.
|
|
panic(fmt.Sprintf("method %s has non-pointer first argument.", prettyName))
|
|
}
|
|
resultType := mtype.In(2)
|
|
if resultType.Kind() != reflect.Ptr {
|
|
// Need result pointer.
|
|
panic(fmt.Sprintf("method %s has non-pointer second argument.", prettyName))
|
|
}
|
|
if mtype.NumOut() != 1 {
|
|
// Need single return.
|
|
panic(fmt.Sprintf("method %s has wrong number of returns.", prettyName))
|
|
}
|
|
if returnType := mtype.Out(0); returnType != reflect.TypeOf((*error)(nil)).Elem() {
|
|
// Need error return.
|
|
panic(fmt.Sprintf("method %s has non-error return value.", prettyName))
|
|
}
|
|
|
|
// Register the method.
|
|
s.methods[prettyName] = registeredMethod{
|
|
fn: method.Func,
|
|
rcvr: reflect.ValueOf(obj),
|
|
argType: argType,
|
|
resultType: resultType,
|
|
}
|
|
}
|
|
}
|
|
|
|
// lookup looks up the given method.
|
|
func (s *Server) lookup(method string) (registeredMethod, bool) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
rm, ok := s.methods[method]
|
|
return rm, ok
|
|
}
|
|
|
|
// handleOne handles a single call.
|
|
func (s *Server) handleOne(client *unet.Socket) error {
|
|
// Unmarshal the call.
|
|
var c serverCall
|
|
newFs, err := unmarshal(client, &c)
|
|
if err != nil {
|
|
// Client is dead.
|
|
return err
|
|
}
|
|
|
|
// Start the request.
|
|
if !s.clientBeginRequest(client) {
|
|
// Client is dead; don't process this call.
|
|
return errStopped
|
|
}
|
|
defer s.clientEndRequest(client)
|
|
|
|
// Lookup the method.
|
|
rm, ok := s.lookup(c.Method)
|
|
if !ok {
|
|
// Try to serialize the error.
|
|
return marshal(client, &callResult{Err: ErrUnknownMethod.Error()}, nil)
|
|
}
|
|
|
|
// Unmarshal the arguments now that we know the type.
|
|
na := reflect.New(rm.argType.Elem())
|
|
if err := json.Unmarshal(c.Arg, na.Interface()); err != nil {
|
|
return marshal(client, &callResult{Err: err.Error()}, nil)
|
|
}
|
|
|
|
// Set the file payload as an argument.
|
|
if fp, ok := na.Interface().(filePayloader); ok {
|
|
fp.setFilePayload(newFs)
|
|
}
|
|
|
|
// Call the method.
|
|
re := reflect.New(rm.resultType.Elem())
|
|
rValues := rm.fn.Call([]reflect.Value{rm.rcvr, na, re})
|
|
if errVal := rValues[0].Interface(); errVal != nil {
|
|
return marshal(client, &callResult{Err: errVal.(error).Error()}, nil)
|
|
}
|
|
|
|
// Set the resulting payload.
|
|
var fs []*os.File
|
|
if fp, ok := re.Interface().(filePayloader); ok {
|
|
fs = fp.filePayload()
|
|
if len(fs) > maxFiles {
|
|
// Ugh. Send an error to the client, despite success.
|
|
return marshal(client, &callResult{Err: ErrTooManyFiles.Error()}, nil)
|
|
}
|
|
}
|
|
|
|
// Marshal the result.
|
|
return marshal(client, &callResult{Success: true, Result: re.Interface()}, fs)
|
|
}
|
|
|
|
// clientBeginRequest begins a request.
|
|
//
|
|
// If true is returned, the request may be processed. If false is returned,
|
|
// then the server has been stopped and the request should be skipped.
|
|
func (s *Server) clientBeginRequest(client *unet.Socket) bool {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
switch state := s.clients[client]; state {
|
|
case idle:
|
|
// Mark as processing.
|
|
s.clients[client] = processing
|
|
return true
|
|
case closed:
|
|
// Whoops, how did this happen? Must have closed immediately
|
|
// following the deserialization. Don't let the RPC actually go
|
|
// through, since we won't be able to serialize a proper
|
|
// response.
|
|
return false
|
|
default:
|
|
// Should not happen.
|
|
panic(fmt.Sprintf("expected idle or closed, got %d", state))
|
|
}
|
|
}
|
|
|
|
// clientEndRequest ends a request.
|
|
func (s *Server) clientEndRequest(client *unet.Socket) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
switch state := s.clients[client]; state {
|
|
case processing:
|
|
// Return to idle.
|
|
s.clients[client] = idle
|
|
case closeRequested:
|
|
// Close the connection.
|
|
client.Close()
|
|
s.clients[client] = closed
|
|
default:
|
|
// Should not happen.
|
|
panic(fmt.Sprintf("expected processing or requestClose, got %d", state))
|
|
}
|
|
}
|
|
|
|
// clientRegister registers a connection.
|
|
//
|
|
// See Stop for more context.
|
|
func (s *Server) clientRegister(client *unet.Socket) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
s.clients[client] = idle
|
|
s.wg.Add(1)
|
|
}
|
|
|
|
// clientUnregister unregisters and closes a connection if necessary.
|
|
//
|
|
// See Stop for more context.
|
|
func (s *Server) clientUnregister(client *unet.Socket) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
switch state := s.clients[client]; state {
|
|
case idle:
|
|
// Close the connection.
|
|
client.Close()
|
|
case closed:
|
|
// Already done.
|
|
default:
|
|
// Should not happen.
|
|
panic(fmt.Sprintf("expected idle or closed, got %d", state))
|
|
}
|
|
delete(s.clients, client)
|
|
s.wg.Done()
|
|
}
|
|
|
|
// handleRegistered handles calls from a registered client.
|
|
func (s *Server) handleRegistered(client *unet.Socket) error {
|
|
for {
|
|
// Handle one call.
|
|
if err := s.handleOne(client); err != nil {
|
|
// Client is dead.
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
|
|
// Handle synchronously handles a single client over a connection.
|
|
func (s *Server) Handle(client *unet.Socket) error {
|
|
s.clientRegister(client)
|
|
defer s.clientUnregister(client)
|
|
return s.handleRegistered(client)
|
|
}
|
|
|
|
// StartHandling creates a goroutine that handles a single client over a
|
|
// connection.
|
|
func (s *Server) StartHandling(client *unet.Socket) {
|
|
s.clientRegister(client)
|
|
go func() { // S/R-SAFE: out of scope
|
|
defer s.clientUnregister(client)
|
|
s.handleRegistered(client)
|
|
}()
|
|
}
|
|
|
|
// Stop safely terminates outstanding clients.
|
|
//
|
|
// No new requests should be initiated after calling Stop. Existing clients
|
|
// will be closed after completing any pending RPCs. This method will block
|
|
// until all clients have disconnected.
|
|
func (s *Server) Stop() {
|
|
// Wait for all outstanding requests.
|
|
defer s.wg.Wait()
|
|
|
|
// Close all known clients.
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
for client, state := range s.clients {
|
|
switch state {
|
|
case idle:
|
|
// Close connection now.
|
|
client.Close()
|
|
s.clients[client] = closed
|
|
case processing:
|
|
// Request close when done.
|
|
s.clients[client] = closeRequested
|
|
}
|
|
}
|
|
}
|
|
|
|
// Client is a urpc client.
|
|
type Client struct {
|
|
// mu protects all members.
|
|
//
|
|
// It also enforces single-call semantics.
|
|
mu sync.Mutex
|
|
|
|
// Socket is the underlying socket for this client.
|
|
//
|
|
// This _must_ be provided and must be closed manually by calling
|
|
// Close.
|
|
Socket *unet.Socket
|
|
}
|
|
|
|
// NewClient returns a new client.
|
|
func NewClient(socket *unet.Socket) *Client {
|
|
return &Client{
|
|
Socket: socket,
|
|
}
|
|
}
|
|
|
|
// marshal sends the given FD and json struct.
|
|
func marshal(s *unet.Socket, v interface{}, fs []*os.File) error {
|
|
// Marshal to a buffer.
|
|
data, err := json.Marshal(v)
|
|
if err != nil {
|
|
log.Warningf("urpc: error marshalling %s: %s", fmt.Sprintf("%v", v), err.Error())
|
|
return err
|
|
}
|
|
|
|
// Write to the socket.
|
|
w := s.Writer(true)
|
|
if fs != nil {
|
|
var fds []int
|
|
for _, f := range fs {
|
|
fds = append(fds, int(f.Fd()))
|
|
}
|
|
w.PackFDs(fds...)
|
|
}
|
|
|
|
// Send.
|
|
for n := 0; n < len(data); {
|
|
cur, err := w.WriteVec([][]byte{data[n:]})
|
|
if n == 0 && cur < len(data) {
|
|
// Don't send FDs anymore. This call is only made on
|
|
// the first successful call to WriteVec, assuming cur
|
|
// is not sufficient to fill the entire buffer.
|
|
w.PackFDs()
|
|
}
|
|
n += cur
|
|
if err != nil {
|
|
log.Warningf("urpc: error writing %v: %s", data[n:], err.Error())
|
|
return err
|
|
}
|
|
}
|
|
|
|
// We're done sending the fds to the client. Explicitly prevent fs from
|
|
// being GCed until here. Urpc rpcs often unlink the file to send, relying
|
|
// on the kernel to automatically delete it once the last reference is
|
|
// dropped. Until we successfully call sendmsg(2), fs may contain the last
|
|
// references to these files. Without this explicit reference to fs here,
|
|
// the go runtime is free to assume we're done with fs after the fd
|
|
// collection loop above, since it just sees us copying ints.
|
|
runtime.KeepAlive(fs)
|
|
|
|
log.Debugf("urpc: successfully marshalled %d bytes.", len(data))
|
|
return nil
|
|
}
|
|
|
|
// unmarhsal receives an FD (optional) and unmarshals the given struct.
|
|
func unmarshal(s *unet.Socket, v interface{}) ([]*os.File, error) {
|
|
// Receive a single byte.
|
|
r := s.Reader(true)
|
|
r.EnableFDs(maxFiles)
|
|
firstByte := make([]byte, 1)
|
|
|
|
// Extract any FDs that may be there.
|
|
if _, err := r.ReadVec([][]byte{firstByte}); err != nil {
|
|
return nil, err
|
|
}
|
|
fds, err := r.ExtractFDs()
|
|
if err != nil {
|
|
log.Warningf("urpc: error extracting fds: %s", err.Error())
|
|
return nil, err
|
|
}
|
|
var fs []*os.File
|
|
for _, fd := range fds {
|
|
fs = append(fs, os.NewFile(uintptr(fd), "urpc"))
|
|
}
|
|
|
|
// Read the rest.
|
|
d := json.NewDecoder(io.MultiReader(bytes.NewBuffer(firstByte), s))
|
|
// urpc internally decodes / re-encodes the data with interface{} as the
|
|
// intermediate type. We have to unmarshal integers to json.Number type
|
|
// instead of the default float type for those intermediate values, such
|
|
// that when they get re-encoded, their values are not printed out in
|
|
// floating-point formats such as 1e9, which could not be decoded to
|
|
// explicitly typed intergers later.
|
|
d.UseNumber()
|
|
if err := d.Decode(v); err != nil {
|
|
log.Warningf("urpc: error decoding: %s", err.Error())
|
|
for _, f := range fs {
|
|
f.Close()
|
|
}
|
|
return nil, err
|
|
}
|
|
|
|
// All set.
|
|
log.Debugf("urpc: unmarshal success.")
|
|
return fs, nil
|
|
}
|
|
|
|
// Call calls a function.
|
|
func (c *Client) Call(method string, arg interface{}, result interface{}) error {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
|
|
// Are there files to send?
|
|
var fs []*os.File
|
|
if fp, ok := arg.(filePayloader); ok {
|
|
fs = fp.filePayload()
|
|
if len(fs) > maxFiles {
|
|
return ErrTooManyFiles
|
|
}
|
|
}
|
|
|
|
// Marshal the data.
|
|
if err := marshal(c.Socket, &clientCall{Method: method, Arg: arg}, fs); err != nil {
|
|
return err
|
|
}
|
|
|
|
// Wait for the response.
|
|
callR := callResult{Result: result}
|
|
newFs, err := unmarshal(c.Socket, &callR)
|
|
if err != nil {
|
|
return fmt.Errorf("urpc method %q failed: %v", method, err)
|
|
}
|
|
|
|
// Set the file payload.
|
|
if fp, ok := result.(filePayloader); ok {
|
|
fp.setFilePayload(newFs)
|
|
} else {
|
|
closeAll(newFs)
|
|
}
|
|
|
|
// Did an error occur?
|
|
if !callR.Success {
|
|
return RemoteError{Message: callR.Err}
|
|
}
|
|
|
|
// All set.
|
|
return nil
|
|
}
|
|
|
|
// Close closes the underlying socket.
|
|
//
|
|
// Further calls to the client may result in undefined behavior.
|
|
func (c *Client) Close() error {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
return c.Socket.Close()
|
|
}
|