diff --git a/pkg/compressio/compressio.go b/pkg/compressio/compressio.go index 591b37130..b4c1c70d9 100644 --- a/pkg/compressio/compressio.go +++ b/pkg/compressio/compressio.go @@ -127,9 +127,9 @@ type result struct { // The goroutine will exit when input is closed, and the goroutine will close // output. type worker struct { - pool *pool - input chan *chunk - output chan result + hashPool *hashPool + input chan *chunk + output chan result } // work is the main work routine; see worker. @@ -139,8 +139,8 @@ func (w *worker) work(compress bool, level int) { var h hash.Hash for c := range w.input { - if h == nil && w.pool.key != nil { - h = w.pool.getHash() + if h == nil && w.hashPool != nil { + h = w.hashPool.getHash() } if compress { mw := io.Writer(c.compressed) @@ -201,6 +201,42 @@ func (w *worker) work(compress bool, level int) { } } +type hashPool struct { + // mu protexts the hash list. + mu sync.Mutex + + // key is the key used to create hash objects. + key []byte + + // hashes is the hash object free list. Note that this cannot be + // globally shared across readers or writers, as it is key-specific. + hashes []hash.Hash +} + +// getHash gets a hash object for the pool. It should only be called when the +// pool key is non-nil. +func (p *hashPool) getHash() hash.Hash { + p.mu.Lock() + defer p.mu.Unlock() + + if len(p.hashes) == 0 { + return hmac.New(sha256.New, p.key) + } + + h := p.hashes[len(p.hashes)-1] + p.hashes = p.hashes[:len(p.hashes)-1] + return h +} + +func (p *hashPool) putHash(h hash.Hash) { + h.Reset() + + p.mu.Lock() + defer p.mu.Unlock() + + p.hashes = append(p.hashes, h) +} + // pool is common functionality for reader/writers. type pool struct { // workers are the compression/decompression workers. @@ -210,16 +246,6 @@ type pool struct { // stream and is shared across both the reader and writer. chunkSize uint32 - // key is the key used to create hash objects. - key []byte - - // hashMu protexts the hash list. - hashMu sync.Mutex - - // hashes is the hash object free list. Note that this cannot be - // globally shared across readers or writers, as it is key-specific. - hashes []hash.Hash - // mu protects below; it is generally the responsibility of users to // acquire this mutex before calling any methods on the pool. mu sync.Mutex @@ -236,19 +262,26 @@ type pool struct { // lasSum records the hash of the last chunk processed. lastSum []byte + + // hashPool is the hash object pool. It cannot be embedded into pool + // itself as worker refers to it and that would stop pool from being + // GCed. + hashPool *hashPool } // init initializes the worker pool. // // This should only be called once. func (p *pool) init(key []byte, workers int, compress bool, level int) { - p.key = key + if key != nil { + p.hashPool = &hashPool{key: key} + } p.workers = make([]worker, workers) for i := 0; i < len(p.workers); i++ { p.workers[i] = worker{ - pool: p, - input: make(chan *chunk, 1), - output: make(chan result, 1), + hashPool: p.hashPool, + input: make(chan *chunk, 1), + output: make(chan result, 1), } go p.workers[i].work(compress, level) // S/R-SAFE: In save path only. } @@ -261,30 +294,7 @@ func (p *pool) stop() { close(p.workers[i].input) } p.workers = nil -} - -// getHash gets a hash object for the pool. It should only be called when the -// pool key is non-nil. -func (p *pool) getHash() hash.Hash { - p.hashMu.Lock() - defer p.hashMu.Unlock() - - if len(p.hashes) == 0 { - return hmac.New(sha256.New, p.key) - } - - h := p.hashes[len(p.hashes)-1] - p.hashes = p.hashes[:len(p.hashes)-1] - return h -} - -func (p *pool) putHash(h hash.Hash) { - h.Reset() - - p.hashMu.Lock() - defer p.hashMu.Unlock() - - p.hashes = append(p.hashes, h) + p.hashPool = nil } // handleResult calls the callback. @@ -361,11 +371,11 @@ func NewReader(in io.Reader, key []byte) (io.Reader, error) { return nil, err } - if r.key != nil { - h := r.getHash() + if r.hashPool != nil { + h := r.hashPool.getHash() binary.WriteUint32(h, binary.BigEndian, r.chunkSize) r.lastSum = h.Sum(nil) - r.putHash(h) + r.hashPool.putHash(h) sum := make([]byte, len(r.lastSum)) if _, err := io.ReadFull(r.in, sum); err != nil { return nil, err @@ -477,7 +487,7 @@ func (r *reader) Read(p []byte) (int, error) { } var sum []byte - if r.key != nil { + if r.hashPool != nil { sum = make([]byte, len(r.lastSum)) if _, err := io.ReadFull(r.in, sum); err != nil { if err == io.EOF { @@ -573,11 +583,11 @@ func NewWriter(out io.Writer, key []byte, chunkSize uint32, level int) (io.Write return nil, err } - if w.key != nil { - h := w.getHash() + if w.hashPool != nil { + h := w.hashPool.getHash() binary.WriteUint32(h, binary.BigEndian, chunkSize) w.lastSum = h.Sum(nil) - w.putHash(h) + w.hashPool.putHash(h) if _, err := io.CopyN(w.out, bytes.NewReader(w.lastSum), int64(len(w.lastSum))); err != nil { return nil, err } @@ -600,10 +610,10 @@ func (w *writer) flush(c *chunk) error { return err } - if w.key != nil { + if w.hashPool != nil { io.CopyN(c.h, bytes.NewReader(w.lastSum), int64(len(w.lastSum))) sum := c.h.Sum(nil) - w.putHash(c.h) + w.hashPool.putHash(c.h) c.h = nil if _, err := io.CopyN(w.out, bytes.NewReader(sum), int64(len(sum))); err != nil { return err