173 lines
5.0 KiB
Go
173 lines
5.0 KiB
Go
// Copyright 2018 Google Inc.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package stack
|
|
|
|
import (
|
|
"sync"
|
|
|
|
"gvisor.googlesource.com/gvisor/pkg/tcpip"
|
|
"gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
|
|
)
|
|
|
|
// protocolIDs is the key transportDemuxer uses to select the endpoint table
// for one (network protocol, transport protocol) pair.
//
// Keys may be built with unkeyed (positional) composite literals such as
// protocolIDs{netProto, proto}, so the field order below is significant.
type protocolIDs struct {
	network   tcpip.NetworkProtocolNumber
	transport tcpip.TransportProtocolNumber
}
|
|
|
|
// transportEndpoints manages all endpoints of a given protocol. It has its own
// mutex so as to reduce interference between protocols.
type transportEndpoints struct {
	// mu protects endpoints: registration and removal take mu.Lock, while
	// packet-delivery lookups take mu.RLock.
	mu sync.RWMutex

	// endpoints maps an endpoint ID (possibly with some fields zeroed to act
	// as wildcards) to the endpoint registered under that exact ID.
	endpoints map[TransportEndpointID]TransportEndpoint
}
|
|
|
|
// transportDemuxer demultiplexes packets targeted at a transport endpoint
// (i.e., after they've been parsed by the network layer). It does two levels
// of demultiplexing: first based on the network and transport protocols, then
// based on endpoint IDs.
type transportDemuxer struct {
	// protocol maps each (network, transport) protocol pair to its endpoint
	// table. It is populated by newTransportDemuxer; the methods visible here
	// only read it afterwards, without locking. Pairs missing from this map
	// cannot have endpoints registered or packets delivered.
	protocol map[protocolIDs]*transportEndpoints
}
|
|
|
|
func newTransportDemuxer(stack *Stack) *transportDemuxer {
|
|
d := &transportDemuxer{protocol: make(map[protocolIDs]*transportEndpoints)}
|
|
|
|
// Add each network and transport pair to the demuxer.
|
|
for netProto := range stack.networkProtocols {
|
|
for proto := range stack.transportProtocols {
|
|
d.protocol[protocolIDs{netProto, proto}] = &transportEndpoints{endpoints: make(map[TransportEndpointID]TransportEndpoint)}
|
|
}
|
|
}
|
|
|
|
return d
|
|
}
|
|
|
|
// registerEndpoint registers the given endpoint with the dispatcher such that
|
|
// packets that match the endpoint ID are delivered to it.
|
|
func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) *tcpip.Error {
|
|
for i, n := range netProtos {
|
|
if err := d.singleRegisterEndpoint(n, protocol, id, ep); err != nil {
|
|
d.unregisterEndpoint(netProtos[:i], protocol, id)
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) *tcpip.Error {
|
|
eps, ok := d.protocol[protocolIDs{netProto, protocol}]
|
|
if !ok {
|
|
return nil
|
|
}
|
|
|
|
eps.mu.Lock()
|
|
defer eps.mu.Unlock()
|
|
|
|
if _, ok := eps.endpoints[id]; ok {
|
|
return tcpip.ErrPortInUse
|
|
}
|
|
|
|
eps.endpoints[id] = ep
|
|
|
|
return nil
|
|
}
|
|
|
|
// unregisterEndpoint unregisters the endpoint with the given id such that it
|
|
// won't receive any more packets.
|
|
func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID) {
|
|
for _, n := range netProtos {
|
|
if eps, ok := d.protocol[protocolIDs{n, protocol}]; ok {
|
|
eps.mu.Lock()
|
|
delete(eps.endpoints, id)
|
|
eps.mu.Unlock()
|
|
}
|
|
}
|
|
}
|
|
|
|
// deliverPacket attempts to deliver the given packet. Returns true if it found
|
|
// an endpoint, false otherwise.
|
|
func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, vv *buffer.VectorisedView, id TransportEndpointID) bool {
|
|
eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}]
|
|
if !ok {
|
|
return false
|
|
}
|
|
|
|
eps.mu.RLock()
|
|
ep := d.findEndpointLocked(eps, vv, id)
|
|
eps.mu.RUnlock()
|
|
|
|
// Fail if we didn't find one.
|
|
if ep == nil {
|
|
return false
|
|
}
|
|
|
|
// Deliver the packet.
|
|
ep.HandlePacket(r, id, vv)
|
|
|
|
return true
|
|
}
|
|
|
|
// deliverControlPacket attempts to deliver the given control packet. Returns
|
|
// true if it found an endpoint, false otherwise.
|
|
func (d *transportDemuxer) deliverControlPacket(net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv *buffer.VectorisedView, id TransportEndpointID) bool {
|
|
eps, ok := d.protocol[protocolIDs{net, trans}]
|
|
if !ok {
|
|
return false
|
|
}
|
|
|
|
// Try to find the endpoint.
|
|
eps.mu.RLock()
|
|
ep := d.findEndpointLocked(eps, vv, id)
|
|
eps.mu.RUnlock()
|
|
|
|
// Fail if we didn't find one.
|
|
if ep == nil {
|
|
return false
|
|
}
|
|
|
|
// Deliver the packet.
|
|
ep.HandleControlPacket(id, typ, extra, vv)
|
|
|
|
return true
|
|
}
|
|
|
|
func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv *buffer.VectorisedView, id TransportEndpointID) TransportEndpoint {
|
|
// Try to find a match with the id as provided.
|
|
if ep := eps.endpoints[id]; ep != nil {
|
|
return ep
|
|
}
|
|
|
|
// Try to find a match with the id minus the local address.
|
|
nid := id
|
|
|
|
nid.LocalAddress = ""
|
|
if ep := eps.endpoints[nid]; ep != nil {
|
|
return ep
|
|
}
|
|
|
|
// Try to find a match with the id minus the remote part.
|
|
nid.LocalAddress = id.LocalAddress
|
|
nid.RemoteAddress = ""
|
|
nid.RemotePort = 0
|
|
if ep := eps.endpoints[nid]; ep != nil {
|
|
return ep
|
|
}
|
|
|
|
// Try to find a match with only the local port.
|
|
nid.LocalAddress = ""
|
|
return eps.endpoints[nid]
|
|
}
|