200 lines
6.2 KiB
Go
200 lines
6.2 KiB
Go
// Copyright 2018 The gVisor Authors.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package safemem
|
|
|
|
import (
|
|
"bytes"
|
|
"io"
|
|
"testing"
|
|
)
|
|
|
|
func makeBlocks(slices ...[]byte) []Block {
|
|
blocks := make([]Block, 0, len(slices))
|
|
for _, s := range slices {
|
|
blocks = append(blocks, BlockFromSafeSlice(s))
|
|
}
|
|
return blocks
|
|
}
|
|
|
|
func TestFromIOReaderFullRead(t *testing.T) {
|
|
r := FromIOReader{bytes.NewBufferString("foobar")}
|
|
dsts := makeBlocks(make([]byte, 3), make([]byte, 3))
|
|
n, err := r.ReadToBlocks(BlockSeqFromSlice(dsts))
|
|
if wantN := uint64(6); n != wantN || err != nil {
|
|
t.Errorf("ReadToBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN)
|
|
}
|
|
for i, want := range [][]byte{[]byte("foo"), []byte("bar")} {
|
|
if got := dsts[i].ToSlice(); !bytes.Equal(got, want) {
|
|
t.Errorf("dsts[%d]: got %q, wanted %q", i, got, want)
|
|
}
|
|
}
|
|
}
|
|
|
|
type eofHidingReader struct {
|
|
Reader io.Reader
|
|
}
|
|
|
|
func (r eofHidingReader) Read(dst []byte) (int, error) {
|
|
n, err := r.Reader.Read(dst)
|
|
if err == io.EOF {
|
|
return n, nil
|
|
}
|
|
return n, err
|
|
}
|
|
|
|
func TestFromIOReaderPartialRead(t *testing.T) {
|
|
r := FromIOReader{eofHidingReader{bytes.NewBufferString("foob")}}
|
|
dsts := makeBlocks(make([]byte, 3), make([]byte, 3))
|
|
n, err := r.ReadToBlocks(BlockSeqFromSlice(dsts))
|
|
// FromIOReader should stop after the eofHidingReader returns (1, nil)
|
|
// for a 3-byte read.
|
|
if wantN := uint64(4); n != wantN || err != nil {
|
|
t.Errorf("ReadToBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN)
|
|
}
|
|
for i, want := range [][]byte{[]byte("foo"), []byte("b\x00\x00")} {
|
|
if got := dsts[i].ToSlice(); !bytes.Equal(got, want) {
|
|
t.Errorf("dsts[%d]: got %q, wanted %q", i, got, want)
|
|
}
|
|
}
|
|
}
|
|
|
|
type singleByteReader struct {
|
|
Reader io.Reader
|
|
}
|
|
|
|
func (r singleByteReader) Read(dst []byte) (int, error) {
|
|
if len(dst) == 0 {
|
|
return r.Reader.Read(dst)
|
|
}
|
|
return r.Reader.Read(dst[:1])
|
|
}
|
|
|
|
func TestSingleByteReader(t *testing.T) {
|
|
r := FromIOReader{singleByteReader{bytes.NewBufferString("foobar")}}
|
|
dsts := makeBlocks(make([]byte, 3), make([]byte, 3))
|
|
n, err := r.ReadToBlocks(BlockSeqFromSlice(dsts))
|
|
// FromIOReader should stop after the singleByteReader returns (1, nil)
|
|
// for a 3-byte read.
|
|
if wantN := uint64(1); n != wantN || err != nil {
|
|
t.Errorf("ReadToBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN)
|
|
}
|
|
for i, want := range [][]byte{[]byte("f\x00\x00"), []byte("\x00\x00\x00")} {
|
|
if got := dsts[i].ToSlice(); !bytes.Equal(got, want) {
|
|
t.Errorf("dsts[%d]: got %q, wanted %q", i, got, want)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestReadFullToBlocks(t *testing.T) {
|
|
r := FromIOReader{singleByteReader{bytes.NewBufferString("foobar")}}
|
|
dsts := makeBlocks(make([]byte, 3), make([]byte, 3))
|
|
n, err := ReadFullToBlocks(r, BlockSeqFromSlice(dsts))
|
|
// ReadFullToBlocks should call into FromIOReader => singleByteReader
|
|
// repeatedly until dsts is exhausted.
|
|
if wantN := uint64(6); n != wantN || err != nil {
|
|
t.Errorf("ReadFullToBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN)
|
|
}
|
|
for i, want := range [][]byte{[]byte("foo"), []byte("bar")} {
|
|
if got := dsts[i].ToSlice(); !bytes.Equal(got, want) {
|
|
t.Errorf("dsts[%d]: got %q, wanted %q", i, got, want)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestFromIOWriterFullWrite(t *testing.T) {
|
|
srcs := makeBlocks([]byte("foo"), []byte("bar"))
|
|
var dst bytes.Buffer
|
|
w := FromIOWriter{&dst}
|
|
n, err := w.WriteFromBlocks(BlockSeqFromSlice(srcs))
|
|
if wantN := uint64(6); n != wantN || err != nil {
|
|
t.Errorf("WriteFromBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN)
|
|
}
|
|
if got, want := dst.Bytes(), []byte("foobar"); !bytes.Equal(got, want) {
|
|
t.Errorf("dst: got %q, wanted %q", got, want)
|
|
}
|
|
}
|
|
|
|
type limitedWriter struct {
|
|
Writer io.Writer
|
|
Done int
|
|
Limit int
|
|
}
|
|
|
|
func (w *limitedWriter) Write(src []byte) (int, error) {
|
|
count := len(src)
|
|
if count > (w.Limit - w.Done) {
|
|
count = w.Limit - w.Done
|
|
}
|
|
n, err := w.Writer.Write(src[:count])
|
|
w.Done += n
|
|
return n, err
|
|
}
|
|
|
|
func TestFromIOWriterPartialWrite(t *testing.T) {
|
|
srcs := makeBlocks([]byte("foo"), []byte("bar"))
|
|
var dst bytes.Buffer
|
|
w := FromIOWriter{&limitedWriter{&dst, 0, 4}}
|
|
n, err := w.WriteFromBlocks(BlockSeqFromSlice(srcs))
|
|
// FromIOWriter should stop after the limitedWriter returns (1, nil) for a
|
|
// 3-byte write.
|
|
if wantN := uint64(4); n != wantN || err != nil {
|
|
t.Errorf("WriteFromBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN)
|
|
}
|
|
if got, want := dst.Bytes(), []byte("foob"); !bytes.Equal(got, want) {
|
|
t.Errorf("dst: got %q, wanted %q", got, want)
|
|
}
|
|
}
|
|
|
|
type singleByteWriter struct {
|
|
Writer io.Writer
|
|
}
|
|
|
|
func (w singleByteWriter) Write(src []byte) (int, error) {
|
|
if len(src) == 0 {
|
|
return w.Writer.Write(src)
|
|
}
|
|
return w.Writer.Write(src[:1])
|
|
}
|
|
|
|
func TestSingleByteWriter(t *testing.T) {
|
|
srcs := makeBlocks([]byte("foo"), []byte("bar"))
|
|
var dst bytes.Buffer
|
|
w := FromIOWriter{singleByteWriter{&dst}}
|
|
n, err := w.WriteFromBlocks(BlockSeqFromSlice(srcs))
|
|
// FromIOWriter should stop after the singleByteWriter returns (1, nil)
|
|
// for a 3-byte write.
|
|
if wantN := uint64(1); n != wantN || err != nil {
|
|
t.Errorf("WriteFromBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN)
|
|
}
|
|
if got, want := dst.Bytes(), []byte("f"); !bytes.Equal(got, want) {
|
|
t.Errorf("dst: got %q, wanted %q", got, want)
|
|
}
|
|
}
|
|
|
|
func TestWriteFullToBlocks(t *testing.T) {
|
|
srcs := makeBlocks([]byte("foo"), []byte("bar"))
|
|
var dst bytes.Buffer
|
|
w := FromIOWriter{singleByteWriter{&dst}}
|
|
n, err := WriteFullFromBlocks(w, BlockSeqFromSlice(srcs))
|
|
// WriteFullToBlocks should call into FromIOWriter => singleByteWriter
|
|
// repeatedly until srcs is exhausted.
|
|
if wantN := uint64(6); n != wantN || err != nil {
|
|
t.Errorf("WriteFullFromBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN)
|
|
}
|
|
if got, want := dst.Bytes(), []byte("foobar"); !bytes.Equal(got, want) {
|
|
t.Errorf("dst: got %q, wanted %q", got, want)
|
|
}
|
|
}
|