compressio: stop worker-pool reference / dependency loop.

PiperOrigin-RevId: 212732300
Change-Id: I9a0b9b7c28e7b7439d34656dd4f2f6114d173e22
This commit is contained in:
Zhaozhong Ni 2018-09-12 17:23:56 -07:00 committed by Shentubot
parent 2eff1fdd06
commit 9dec7a3db9
1 changed files with 62 additions and 52 deletions

View File

@ -127,9 +127,9 @@ type result struct {
// The goroutine will exit when input is closed, and the goroutine will close
// output.
type worker struct {
pool *pool
input chan *chunk
output chan result
hashPool *hashPool
input chan *chunk
output chan result
}
// work is the main work routine; see worker.
@ -139,8 +139,8 @@ func (w *worker) work(compress bool, level int) {
var h hash.Hash
for c := range w.input {
if h == nil && w.pool.key != nil {
h = w.pool.getHash()
if h == nil && w.hashPool != nil {
h = w.hashPool.getHash()
}
if compress {
mw := io.Writer(c.compressed)
@ -201,6 +201,42 @@ func (w *worker) work(compress bool, level int) {
}
}
type hashPool struct {
// mu protexts the hash list.
mu sync.Mutex
// key is the key used to create hash objects.
key []byte
// hashes is the hash object free list. Note that this cannot be
// globally shared across readers or writers, as it is key-specific.
hashes []hash.Hash
}
// getHash gets a hash object for the pool. It should only be called when the
// pool key is non-nil.
func (p *hashPool) getHash() hash.Hash {
p.mu.Lock()
defer p.mu.Unlock()
if len(p.hashes) == 0 {
return hmac.New(sha256.New, p.key)
}
h := p.hashes[len(p.hashes)-1]
p.hashes = p.hashes[:len(p.hashes)-1]
return h
}
func (p *hashPool) putHash(h hash.Hash) {
h.Reset()
p.mu.Lock()
defer p.mu.Unlock()
p.hashes = append(p.hashes, h)
}
// pool is common functionality for reader/writers.
type pool struct {
// workers are the compression/decompression workers.
@ -210,16 +246,6 @@ type pool struct {
// stream and is shared across both the reader and writer.
chunkSize uint32
// key is the key used to create hash objects.
key []byte
// hashMu protexts the hash list.
hashMu sync.Mutex
// hashes is the hash object free list. Note that this cannot be
// globally shared across readers or writers, as it is key-specific.
hashes []hash.Hash
// mu protects below; it is generally the responsibility of users to
// acquire this mutex before calling any methods on the pool.
mu sync.Mutex
@ -236,19 +262,26 @@ type pool struct {
// lasSum records the hash of the last chunk processed.
lastSum []byte
// hashPool is the hash object pool. It cannot be embedded into pool
// itself as worker refers to it and that would stop pool from being
// GCed.
hashPool *hashPool
}
// init initializes the worker pool.
//
// This should only be called once.
func (p *pool) init(key []byte, workers int, compress bool, level int) {
p.key = key
if key != nil {
p.hashPool = &hashPool{key: key}
}
p.workers = make([]worker, workers)
for i := 0; i < len(p.workers); i++ {
p.workers[i] = worker{
pool: p,
input: make(chan *chunk, 1),
output: make(chan result, 1),
hashPool: p.hashPool,
input: make(chan *chunk, 1),
output: make(chan result, 1),
}
go p.workers[i].work(compress, level) // S/R-SAFE: In save path only.
}
@ -261,30 +294,7 @@ func (p *pool) stop() {
close(p.workers[i].input)
}
p.workers = nil
}
// getHash gets a hash object for the pool. It should only be called when the
// pool key is non-nil.
func (p *pool) getHash() hash.Hash {
p.hashMu.Lock()
defer p.hashMu.Unlock()
if len(p.hashes) == 0 {
return hmac.New(sha256.New, p.key)
}
h := p.hashes[len(p.hashes)-1]
p.hashes = p.hashes[:len(p.hashes)-1]
return h
}
func (p *pool) putHash(h hash.Hash) {
h.Reset()
p.hashMu.Lock()
defer p.hashMu.Unlock()
p.hashes = append(p.hashes, h)
p.hashPool = nil
}
// handleResult calls the callback.
@ -361,11 +371,11 @@ func NewReader(in io.Reader, key []byte) (io.Reader, error) {
return nil, err
}
if r.key != nil {
h := r.getHash()
if r.hashPool != nil {
h := r.hashPool.getHash()
binary.WriteUint32(h, binary.BigEndian, r.chunkSize)
r.lastSum = h.Sum(nil)
r.putHash(h)
r.hashPool.putHash(h)
sum := make([]byte, len(r.lastSum))
if _, err := io.ReadFull(r.in, sum); err != nil {
return nil, err
@ -477,7 +487,7 @@ func (r *reader) Read(p []byte) (int, error) {
}
var sum []byte
if r.key != nil {
if r.hashPool != nil {
sum = make([]byte, len(r.lastSum))
if _, err := io.ReadFull(r.in, sum); err != nil {
if err == io.EOF {
@ -573,11 +583,11 @@ func NewWriter(out io.Writer, key []byte, chunkSize uint32, level int) (io.Write
return nil, err
}
if w.key != nil {
h := w.getHash()
if w.hashPool != nil {
h := w.hashPool.getHash()
binary.WriteUint32(h, binary.BigEndian, chunkSize)
w.lastSum = h.Sum(nil)
w.putHash(h)
w.hashPool.putHash(h)
if _, err := io.CopyN(w.out, bytes.NewReader(w.lastSum), int64(len(w.lastSum))); err != nil {
return nil, err
}
@ -600,10 +610,10 @@ func (w *writer) flush(c *chunk) error {
return err
}
if w.key != nil {
if w.hashPool != nil {
io.CopyN(c.h, bytes.NewReader(w.lastSum), int64(len(w.lastSum)))
sum := c.h.Sum(nil)
w.putHash(c.h)
w.hashPool.putHash(c.h)
c.h = nil
if _, err := io.CopyN(w.out, bytes.NewReader(sum), int64(len(sum))); err != nil {
return err