Use closure to avoid manual unlocking

Also increase refcount of raw.endpoint.route while in use.

Avoid allocating an array of size zero.

PiperOrigin-RevId: 359797788
This commit is contained in:
Tamir Duberstein 2021-02-26 11:16:14 -08:00 committed by gVisor bot
parent 8e78d0eda6
commit da2505df94
2 changed files with 64 additions and 68 deletions

View File

@ -583,9 +583,14 @@ func (d *transportDemuxer) deliverRawPacket(protocol tcpip.TransportProtocolNumb
// raw endpoint first. If there are multiple raw endpoints, they all
// receive the packet.
eps.mu.RLock()
// Copy the list of raw endpoints so we can release eps.mu earlier.
rawEPs := make([]RawTransportEndpoint, len(eps.rawEndpoints))
copy(rawEPs, eps.rawEndpoints)
// Copy the list of raw endpoints to avoid packet handling under lock.
var rawEPs []RawTransportEndpoint
if n := len(eps.rawEndpoints); n != 0 {
rawEPs = make([]RawTransportEndpoint, n)
if m := copy(rawEPs, eps.rawEndpoints); m != n {
panic(fmt.Sprintf("unexpected copy = %d, want %d", m, n))
}
}
eps.mu.RUnlock()
for _, rawEP := range rawEPs {
// Each endpoint gets its own copy of the packet for the sake
@ -593,7 +598,7 @@ func (d *transportDemuxer) deliverRawPacket(protocol tcpip.TransportProtocolNumb
rawEP.HandlePacket(pkt.Clone())
}
return len(rawEPs) > 0
return len(rawEPs) != 0
}
// deliverError attempts to deliver the given error to the appropriate transport

View File

@ -271,18 +271,17 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp
if opts.More {
return 0, &tcpip.ErrInvalidOptionValue{}
}
payloadBytes, route, owner, err := func() ([]byte, *stack.Route, tcpip.PacketOwner, tcpip.Error) {
e.mu.RLock()
defer e.mu.RUnlock()
if e.closed {
e.mu.RUnlock()
return 0, &tcpip.ErrInvalidEndpointState{}
return nil, nil, nil, &tcpip.ErrInvalidEndpointState{}
}
payloadBytes := make([]byte, p.Len())
if _, err := io.ReadFull(p, payloadBytes); err != nil {
e.mu.RUnlock()
return 0, &tcpip.ErrBadBuffer{}
return nil, nil, nil, &tcpip.ErrBadBuffer{}
}
// If this is an unassociated socket and callee provided a nonzero
@ -290,8 +289,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp
if e.ops.GetHeaderIncluded() {
ip := header.IPv4(payloadBytes)
if !ip.IsValid(len(payloadBytes)) {
e.mu.RUnlock()
return 0, &tcpip.ErrInvalidOptionValue{}
return nil, nil, nil, &tcpip.ErrInvalidOptionValue{}
}
dstAddr := ip.DestinationAddress()
// Update dstAddr with the address in the IP header, unless
@ -312,42 +310,35 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp
// If the user doesn't specify a destination, they should have
// connected to another address.
if !e.connected {
e.mu.RUnlock()
return 0, &tcpip.ErrDestinationRequired{}
return nil, nil, nil, &tcpip.ErrDestinationRequired{}
}
owner := e.owner
route := e.route
e.mu.RUnlock()
return e.finishWrite(payloadBytes, route, owner)
e.route.Acquire()
return payloadBytes, e.route, e.owner, nil
}
// The caller provided a destination. Reject destination address if it
// goes through a different NIC than the endpoint was bound to.
nic := opts.To.NIC
if e.bound && nic != 0 && nic != e.BindNICID {
e.mu.RUnlock()
return 0, &tcpip.ErrNoRoute{}
return nil, nil, nil, &tcpip.ErrNoRoute{}
}
// Find the route to the destination. If BindAddress is 0,
// FindRoute will choose an appropriate source address.
route, err := e.stack.FindRoute(nic, e.BindAddr, opts.To.Addr, e.NetProto, false)
if err != nil {
e.mu.RUnlock()
return 0, err
return nil, nil, nil, err
}
owner := e.owner
e.mu.RUnlock()
n, err := e.finishWrite(payloadBytes, route, owner)
route.Release()
return n, err
}
return payloadBytes, route, e.owner, nil
}()
if err != nil {
return 0, err
}
defer route.Release()
// finishWrite writes the payload to a route. It resolves the route if
// necessary.
func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route, owner tcpip.PacketOwner) (int64, tcpip.Error) {
if e.ops.GetHeaderIncluded() {
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: buffer.View(payloadBytes).ToVectorisedView(),