compressio: stop worker-pool reference / dependency loop.
PiperOrigin-RevId: 212732300 Change-Id: I9a0b9b7c28e7b7439d34656dd4f2f6114d173e22
This commit is contained in:
parent
2eff1fdd06
commit
9dec7a3db9
|
@ -127,9 +127,9 @@ type result struct {
|
|||
// The goroutine will exit when input is closed, and the goroutine will close
|
||||
// output.
|
||||
type worker struct {
|
||||
pool *pool
|
||||
input chan *chunk
|
||||
output chan result
|
||||
hashPool *hashPool
|
||||
input chan *chunk
|
||||
output chan result
|
||||
}
|
||||
|
||||
// work is the main work routine; see worker.
|
||||
|
@ -139,8 +139,8 @@ func (w *worker) work(compress bool, level int) {
|
|||
var h hash.Hash
|
||||
|
||||
for c := range w.input {
|
||||
if h == nil && w.pool.key != nil {
|
||||
h = w.pool.getHash()
|
||||
if h == nil && w.hashPool != nil {
|
||||
h = w.hashPool.getHash()
|
||||
}
|
||||
if compress {
|
||||
mw := io.Writer(c.compressed)
|
||||
|
@ -201,6 +201,42 @@ func (w *worker) work(compress bool, level int) {
|
|||
}
|
||||
}
|
||||
|
||||
type hashPool struct {
|
||||
// mu protexts the hash list.
|
||||
mu sync.Mutex
|
||||
|
||||
// key is the key used to create hash objects.
|
||||
key []byte
|
||||
|
||||
// hashes is the hash object free list. Note that this cannot be
|
||||
// globally shared across readers or writers, as it is key-specific.
|
||||
hashes []hash.Hash
|
||||
}
|
||||
|
||||
// getHash gets a hash object for the pool. It should only be called when the
|
||||
// pool key is non-nil.
|
||||
func (p *hashPool) getHash() hash.Hash {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
if len(p.hashes) == 0 {
|
||||
return hmac.New(sha256.New, p.key)
|
||||
}
|
||||
|
||||
h := p.hashes[len(p.hashes)-1]
|
||||
p.hashes = p.hashes[:len(p.hashes)-1]
|
||||
return h
|
||||
}
|
||||
|
||||
func (p *hashPool) putHash(h hash.Hash) {
|
||||
h.Reset()
|
||||
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
p.hashes = append(p.hashes, h)
|
||||
}
|
||||
|
||||
// pool is common functionality for reader/writers.
|
||||
type pool struct {
|
||||
// workers are the compression/decompression workers.
|
||||
|
@ -210,16 +246,6 @@ type pool struct {
|
|||
// stream and is shared across both the reader and writer.
|
||||
chunkSize uint32
|
||||
|
||||
// key is the key used to create hash objects.
|
||||
key []byte
|
||||
|
||||
// hashMu protexts the hash list.
|
||||
hashMu sync.Mutex
|
||||
|
||||
// hashes is the hash object free list. Note that this cannot be
|
||||
// globally shared across readers or writers, as it is key-specific.
|
||||
hashes []hash.Hash
|
||||
|
||||
// mu protects below; it is generally the responsibility of users to
|
||||
// acquire this mutex before calling any methods on the pool.
|
||||
mu sync.Mutex
|
||||
|
@ -236,19 +262,26 @@ type pool struct {
|
|||
|
||||
// lasSum records the hash of the last chunk processed.
|
||||
lastSum []byte
|
||||
|
||||
// hashPool is the hash object pool. It cannot be embedded into pool
|
||||
// itself as worker refers to it and that would stop pool from being
|
||||
// GCed.
|
||||
hashPool *hashPool
|
||||
}
|
||||
|
||||
// init initializes the worker pool.
|
||||
//
|
||||
// This should only be called once.
|
||||
func (p *pool) init(key []byte, workers int, compress bool, level int) {
|
||||
p.key = key
|
||||
if key != nil {
|
||||
p.hashPool = &hashPool{key: key}
|
||||
}
|
||||
p.workers = make([]worker, workers)
|
||||
for i := 0; i < len(p.workers); i++ {
|
||||
p.workers[i] = worker{
|
||||
pool: p,
|
||||
input: make(chan *chunk, 1),
|
||||
output: make(chan result, 1),
|
||||
hashPool: p.hashPool,
|
||||
input: make(chan *chunk, 1),
|
||||
output: make(chan result, 1),
|
||||
}
|
||||
go p.workers[i].work(compress, level) // S/R-SAFE: In save path only.
|
||||
}
|
||||
|
@ -261,30 +294,7 @@ func (p *pool) stop() {
|
|||
close(p.workers[i].input)
|
||||
}
|
||||
p.workers = nil
|
||||
}
|
||||
|
||||
// getHash gets a hash object for the pool. It should only be called when the
|
||||
// pool key is non-nil.
|
||||
func (p *pool) getHash() hash.Hash {
|
||||
p.hashMu.Lock()
|
||||
defer p.hashMu.Unlock()
|
||||
|
||||
if len(p.hashes) == 0 {
|
||||
return hmac.New(sha256.New, p.key)
|
||||
}
|
||||
|
||||
h := p.hashes[len(p.hashes)-1]
|
||||
p.hashes = p.hashes[:len(p.hashes)-1]
|
||||
return h
|
||||
}
|
||||
|
||||
func (p *pool) putHash(h hash.Hash) {
|
||||
h.Reset()
|
||||
|
||||
p.hashMu.Lock()
|
||||
defer p.hashMu.Unlock()
|
||||
|
||||
p.hashes = append(p.hashes, h)
|
||||
p.hashPool = nil
|
||||
}
|
||||
|
||||
// handleResult calls the callback.
|
||||
|
@ -361,11 +371,11 @@ func NewReader(in io.Reader, key []byte) (io.Reader, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
if r.key != nil {
|
||||
h := r.getHash()
|
||||
if r.hashPool != nil {
|
||||
h := r.hashPool.getHash()
|
||||
binary.WriteUint32(h, binary.BigEndian, r.chunkSize)
|
||||
r.lastSum = h.Sum(nil)
|
||||
r.putHash(h)
|
||||
r.hashPool.putHash(h)
|
||||
sum := make([]byte, len(r.lastSum))
|
||||
if _, err := io.ReadFull(r.in, sum); err != nil {
|
||||
return nil, err
|
||||
|
@ -477,7 +487,7 @@ func (r *reader) Read(p []byte) (int, error) {
|
|||
}
|
||||
|
||||
var sum []byte
|
||||
if r.key != nil {
|
||||
if r.hashPool != nil {
|
||||
sum = make([]byte, len(r.lastSum))
|
||||
if _, err := io.ReadFull(r.in, sum); err != nil {
|
||||
if err == io.EOF {
|
||||
|
@ -573,11 +583,11 @@ func NewWriter(out io.Writer, key []byte, chunkSize uint32, level int) (io.Write
|
|||
return nil, err
|
||||
}
|
||||
|
||||
if w.key != nil {
|
||||
h := w.getHash()
|
||||
if w.hashPool != nil {
|
||||
h := w.hashPool.getHash()
|
||||
binary.WriteUint32(h, binary.BigEndian, chunkSize)
|
||||
w.lastSum = h.Sum(nil)
|
||||
w.putHash(h)
|
||||
w.hashPool.putHash(h)
|
||||
if _, err := io.CopyN(w.out, bytes.NewReader(w.lastSum), int64(len(w.lastSum))); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -600,10 +610,10 @@ func (w *writer) flush(c *chunk) error {
|
|||
return err
|
||||
}
|
||||
|
||||
if w.key != nil {
|
||||
if w.hashPool != nil {
|
||||
io.CopyN(c.h, bytes.NewReader(w.lastSum), int64(len(w.lastSum)))
|
||||
sum := c.h.Sum(nil)
|
||||
w.putHash(c.h)
|
||||
w.hashPool.putHash(c.h)
|
||||
c.h = nil
|
||||
if _, err := io.CopyN(w.out, bytes.NewReader(sum), int64(len(sum))); err != nil {
|
||||
return err
|
||||
|
|
Loading…
Reference in New Issue