gvisor/pkg/sentry/safemem/io_test.go

200 lines
6.2 KiB
Go

// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package safemem
import (
"bytes"
"io"
"testing"
)
// makeBlocks converts each byte slice in slices to a Block via
// BlockFromSafeSlice and returns the Blocks in the same order.
func makeBlocks(slices ...[]byte) []Block {
	out := make([]Block, len(slices))
	for i := range slices {
		out[i] = BlockFromSafeSlice(slices[i])
	}
	return out
}
func TestFromIOReaderFullRead(t *testing.T) {
r := FromIOReader{bytes.NewBufferString("foobar")}
dsts := makeBlocks(make([]byte, 3), make([]byte, 3))
n, err := r.ReadToBlocks(BlockSeqFromSlice(dsts))
if wantN := uint64(6); n != wantN || err != nil {
t.Errorf("ReadToBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN)
}
for i, want := range [][]byte{[]byte("foo"), []byte("bar")} {
if got := dsts[i].ToSlice(); !bytes.Equal(got, want) {
t.Errorf("dsts[%d]: got %q, wanted %q", i, got, want)
}
}
}
// eofHidingReader is an io.Reader wrapper that suppresses io.EOF from the
// wrapped Reader.
type eofHidingReader struct {
	Reader io.Reader
}

// Read forwards to the wrapped Reader. An io.EOF result is replaced with a
// nil error (the byte count is kept); every other result is returned as is.
func (r eofHidingReader) Read(dst []byte) (int, error) {
	switch n, err := r.Reader.Read(dst); err {
	case io.EOF:
		return n, nil
	default:
		return n, err
	}
}
func TestFromIOReaderPartialRead(t *testing.T) {
r := FromIOReader{eofHidingReader{bytes.NewBufferString("foob")}}
dsts := makeBlocks(make([]byte, 3), make([]byte, 3))
n, err := r.ReadToBlocks(BlockSeqFromSlice(dsts))
// FromIOReader should stop after the eofHidingReader returns (1, nil)
// for a 3-byte read.
if wantN := uint64(4); n != wantN || err != nil {
t.Errorf("ReadToBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN)
}
for i, want := range [][]byte{[]byte("foo"), []byte("b\x00\x00")} {
if got := dsts[i].ToSlice(); !bytes.Equal(got, want) {
t.Errorf("dsts[%d]: got %q, wanted %q", i, got, want)
}
}
}
// singleByteReader is an io.Reader wrapper that asks the wrapped Reader for
// at most one byte per call.
type singleByteReader struct {
	Reader io.Reader
}

// Read truncates dst to a single byte when it is longer than that, then
// forwards the call to the wrapped Reader. Empty and one-byte buffers are
// passed through unchanged.
func (r singleByteReader) Read(dst []byte) (int, error) {
	if len(dst) > 1 {
		dst = dst[:1]
	}
	return r.Reader.Read(dst)
}
func TestSingleByteReader(t *testing.T) {
r := FromIOReader{singleByteReader{bytes.NewBufferString("foobar")}}
dsts := makeBlocks(make([]byte, 3), make([]byte, 3))
n, err := r.ReadToBlocks(BlockSeqFromSlice(dsts))
// FromIOReader should stop after the singleByteReader returns (1, nil)
// for a 3-byte read.
if wantN := uint64(1); n != wantN || err != nil {
t.Errorf("ReadToBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN)
}
for i, want := range [][]byte{[]byte("f\x00\x00"), []byte("\x00\x00\x00")} {
if got := dsts[i].ToSlice(); !bytes.Equal(got, want) {
t.Errorf("dsts[%d]: got %q, wanted %q", i, got, want)
}
}
}
func TestReadFullToBlocks(t *testing.T) {
r := FromIOReader{singleByteReader{bytes.NewBufferString("foobar")}}
dsts := makeBlocks(make([]byte, 3), make([]byte, 3))
n, err := ReadFullToBlocks(r, BlockSeqFromSlice(dsts))
// ReadFullToBlocks should call into FromIOReader => singleByteReader
// repeatedly until dsts is exhausted.
if wantN := uint64(6); n != wantN || err != nil {
t.Errorf("ReadFullToBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN)
}
for i, want := range [][]byte{[]byte("foo"), []byte("bar")} {
if got := dsts[i].ToSlice(); !bytes.Equal(got, want) {
t.Errorf("dsts[%d]: got %q, wanted %q", i, got, want)
}
}
}
func TestFromIOWriterFullWrite(t *testing.T) {
srcs := makeBlocks([]byte("foo"), []byte("bar"))
var dst bytes.Buffer
w := FromIOWriter{&dst}
n, err := w.WriteFromBlocks(BlockSeqFromSlice(srcs))
if wantN := uint64(6); n != wantN || err != nil {
t.Errorf("WriteFromBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN)
}
if got, want := dst.Bytes(), []byte("foobar"); !bytes.Equal(got, want) {
t.Errorf("dst: got %q, wanted %q", got, want)
}
}
// limitedWriter is an io.Writer wrapper that forwards at most Limit bytes in
// total to the wrapped Writer.
type limitedWriter struct {
	// Writer receives the bytes that fit within the limit.
	Writer io.Writer
	// Done is the number of bytes the wrapped Writer has accepted so far; it
	// is advanced by Write.
	Done int
	// Limit is the maximum total number of bytes to forward.
	Limit int
}

// Write forwards as much of src as still fits under Limit and adds the
// wrapped Writer's byte count to Done. A truncated write returns the short
// count together with the wrapped Writer's error, which is nil on success.
// Once the limit is reached or exceeded, a zero-length write is forwarded.
func (w *limitedWriter) Write(src []byte) (int, error) {
	// Clamp the remaining budget at zero: with Done > Limit the difference is
	// negative and slicing src with it would panic.
	remaining := w.Limit - w.Done
	if remaining < 0 {
		remaining = 0
	}
	count := len(src)
	if count > remaining {
		count = remaining
	}
	n, err := w.Writer.Write(src[:count])
	w.Done += n
	return n, err
}
func TestFromIOWriterPartialWrite(t *testing.T) {
srcs := makeBlocks([]byte("foo"), []byte("bar"))
var dst bytes.Buffer
w := FromIOWriter{&limitedWriter{&dst, 0, 4}}
n, err := w.WriteFromBlocks(BlockSeqFromSlice(srcs))
// FromIOWriter should stop after the limitedWriter returns (1, nil) for a
// 3-byte write.
if wantN := uint64(4); n != wantN || err != nil {
t.Errorf("WriteFromBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN)
}
if got, want := dst.Bytes(), []byte("foob"); !bytes.Equal(got, want) {
t.Errorf("dst: got %q, wanted %q", got, want)
}
}
// singleByteWriter is an io.Writer wrapper that hands the wrapped Writer at
// most one byte per call.
type singleByteWriter struct {
	Writer io.Writer
}

// Write truncates src to its first byte when it is longer than that, then
// forwards the call to the wrapped Writer. Empty and one-byte inputs are
// passed through unchanged.
func (w singleByteWriter) Write(src []byte) (int, error) {
	if len(src) > 1 {
		src = src[:1]
	}
	return w.Writer.Write(src)
}
func TestSingleByteWriter(t *testing.T) {
srcs := makeBlocks([]byte("foo"), []byte("bar"))
var dst bytes.Buffer
w := FromIOWriter{singleByteWriter{&dst}}
n, err := w.WriteFromBlocks(BlockSeqFromSlice(srcs))
// FromIOWriter should stop after the singleByteWriter returns (1, nil)
// for a 3-byte write.
if wantN := uint64(1); n != wantN || err != nil {
t.Errorf("WriteFromBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN)
}
if got, want := dst.Bytes(), []byte("f"); !bytes.Equal(got, want) {
t.Errorf("dst: got %q, wanted %q", got, want)
}
}
func TestWriteFullToBlocks(t *testing.T) {
srcs := makeBlocks([]byte("foo"), []byte("bar"))
var dst bytes.Buffer
w := FromIOWriter{singleByteWriter{&dst}}
n, err := WriteFullFromBlocks(w, BlockSeqFromSlice(srcs))
// WriteFullToBlocks should call into FromIOWriter => singleByteWriter
// repeatedly until srcs is exhausted.
if wantN := uint64(6); n != wantN || err != nil {
t.Errorf("WriteFullFromBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN)
}
if got, want := dst.Bytes(), []byte("foobar"); !bytes.Equal(got, want) {
t.Errorf("dst: got %q, wanted %q", got, want)
}
}