// Copyright 2018 The gVisor Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package urpc import ( "errors" "os" "testing" "gvisor.dev/gvisor/pkg/unet" ) type test struct { } type testArg struct { StringArg string IntArg int FilePayload } type testResult struct { StringResult string IntResult int FilePayload } func (t test) Func(a *testArg, r *testResult) error { r.StringResult = a.StringArg r.IntResult = a.IntArg return nil } func (t test) Err(a *testArg, r *testResult) error { return errors.New("test error") } func (t test) FailNoFile(a *testArg, r *testResult) error { if a.Files == nil { return errors.New("no file found") } return nil } func (t test) SendFile(a *testArg, r *testResult) error { r.Files = []*os.File{os.Stdin, os.Stdout, os.Stderr} return nil } func (t test) TooManyFiles(a *testArg, r *testResult) error { for i := 0; i <= maxFiles; i++ { r.Files = append(r.Files, os.Stdin) } return nil } func startServer(socket *unet.Socket) { s := NewServer() s.Register(test{}) s.StartHandling(socket) } func testClient() (*Client, error) { serverSock, clientSock, err := unet.SocketPair(false) if err != nil { return nil, err } startServer(serverSock) return NewClient(clientSock), nil } func TestCall(t *testing.T) { c, err := testClient() if err != nil { t.Fatalf("error creating test client: %v", err) } defer c.Close() var r testResult if err := c.Call("test.Func", &testArg{}, &r); err != nil { t.Errorf("basic call failed: %v", err) } else if r.StringResult != "" || r.IntResult != 0 { t.Errorf("unexpected result, got %v expected zero value", r) } if err := c.Call("test.Func", &testArg{StringArg: "hello"}, &r); err != nil { t.Errorf("basic call failed: %v", err) } else if r.StringResult != "hello" { t.Errorf("unexpected result, got %v expected hello", r.StringResult) } if err := c.Call("test.Func", &testArg{IntArg: 1}, &r); err != nil { t.Errorf("basic call failed: %v", err) } else if r.IntResult != 1 { t.Errorf("unexpected result, got %v expected 1", r.IntResult) } } func TestUnknownMethod(t *testing.T) { c, err := testClient() if err != nil { t.Fatalf("error creating test client: %v", err) } defer c.Close() var r testResult if err := c.Call("test.Unknown", &testArg{}, &r); err == nil { t.Errorf("expected non-nil err, got nil") } else if err.Error() != ErrUnknownMethod.Error() { t.Errorf("expected test error, got %v", err) } } func TestErr(t *testing.T) { c, err := testClient() if err != nil { t.Fatalf("error creating test client: %v", err) } defer c.Close() var r testResult if err := c.Call("test.Err", &testArg{}, &r); err == nil { t.Errorf("expected non-nil err, got nil") } else if err.Error() != "test error" { t.Errorf("expected test error, got %v", err) } } func TestSendFile(t *testing.T) { c, err := testClient() if err != nil { t.Fatalf("error creating test client: %v", err) } defer c.Close() var r testResult if err := c.Call("test.FailNoFile", &testArg{}, &r); err == nil { t.Errorf("expected non-nil err, got nil") } if err := c.Call("test.FailNoFile", &testArg{FilePayload: FilePayload{Files: []*os.File{os.Stdin, os.Stdout, os.Stdin}}}, &r); err != nil { t.Errorf("expected nil err, got %v", err) } } func TestRecvFile(t *testing.T) { c, err := testClient() if err != nil { t.Fatalf("error creating test client: %v", err) } defer c.Close() var r testResult if err := c.Call("test.SendFile", &testArg{}, &r); err != nil { t.Errorf("expected nil err, got %v", err) } if r.Files == nil { t.Errorf("expected file, got nil") } } func TestShutdown(t *testing.T) { serverSock, clientSock, err := unet.SocketPair(false) if err != nil { t.Fatalf("error creating test client: %v", err) } clientSock.Close() s := NewServer() if err := s.Handle(serverSock); err == nil { t.Errorf("expected non-nil err, got nil") } } func TestTooManyFiles(t *testing.T) { c, err := testClient() if err != nil { t.Fatalf("error creating test client: %v", err) } defer c.Close() var r testResult var a testArg for i := 0; i <= maxFiles; i++ { a.Files = append(a.Files, os.Stdin) } // Client-side error. if err := c.Call("test.Func", &a, &r); err != ErrTooManyFiles { t.Errorf("expected ErrTooManyFiles, got %v", err) } // Server-side error. if err := c.Call("test.TooManyFiles", &testArg{}, &r); err == nil { t.Errorf("expected non-nil err, got nil") } else if err.Error() != "too many files" { t.Errorf("expected too many files, got %v", err.Error()) } }