Refactor parser to use a for loop instead of recursion.
This makes the code shorter and less repetitive. TESTED: All unit tests still pass. PiperOrigin-RevId: 306161475
This commit is contained in:
parent
2020349468
commit
ef0b5584e5
|
@ -211,11 +211,7 @@ func (conn *TCPIPv4) RecvFrame(timeout time.Duration) Layers {
|
|||
if b == nil {
|
||||
break
|
||||
}
|
||||
layers, err := ParseEther(b)
|
||||
if err != nil {
|
||||
conn.t.Logf("debug: can't parse frame, ignoring: %s", err)
|
||||
continue // Ignore packets that can't be parsed.
|
||||
}
|
||||
layers := Parse(ParseEther, b)
|
||||
if !conn.incoming.match(layers) {
|
||||
continue // Ignore packets that don't match the expected incoming.
|
||||
}
|
||||
|
@ -418,11 +414,7 @@ func (conn *UDPIPv4) Recv(timeout time.Duration) *UDP {
|
|||
if b == nil {
|
||||
break
|
||||
}
|
||||
layers, err := ParseEther(b)
|
||||
if err != nil {
|
||||
conn.t.Logf("can't parse frame: %s", err)
|
||||
continue // Ignore packets that can't be parsed.
|
||||
}
|
||||
layers := Parse(ParseEther, b)
|
||||
if !conn.incoming.match(layers) {
|
||||
continue // Ignore packets that don't match the expected incoming.
|
||||
}
|
||||
|
|
|
@ -172,27 +172,46 @@ func NetworkProtocolNumber(v tcpip.NetworkProtocolNumber) *tcpip.NetworkProtocol
|
|||
return &v
|
||||
}
|
||||
|
||||
// LayerParser parses the input bytes and returns a Layer along with the next
|
||||
// LayerParser to run. If there is no more parsing to do, the returned
|
||||
// LayerParser is nil.
|
||||
type LayerParser func([]byte) (Layer, LayerParser)
|
||||
|
||||
// Parse parses bytes starting with the first LayerParser and using successive
|
||||
// LayerParsers until all the bytes are parsed.
|
||||
func Parse(parser LayerParser, b []byte) Layers {
|
||||
var layers Layers
|
||||
for {
|
||||
var layer Layer
|
||||
layer, parser = parser(b)
|
||||
layers = append(layers, layer)
|
||||
if parser == nil {
|
||||
break
|
||||
}
|
||||
b = b[layer.length():]
|
||||
}
|
||||
layers.linkLayers()
|
||||
return layers
|
||||
}
|
||||
|
||||
// ParseEther parses the bytes assuming that they start with an ethernet header
|
||||
// and continues parsing further encapsulations.
|
||||
func ParseEther(b []byte) (Layers, error) {
|
||||
func ParseEther(b []byte) (Layer, LayerParser) {
|
||||
h := header.Ethernet(b)
|
||||
ether := Ether{
|
||||
SrcAddr: LinkAddress(h.SourceAddress()),
|
||||
DstAddr: LinkAddress(h.DestinationAddress()),
|
||||
Type: NetworkProtocolNumber(h.Type()),
|
||||
}
|
||||
layers := Layers{ðer}
|
||||
var nextParser LayerParser
|
||||
switch h.Type() {
|
||||
case header.IPv4ProtocolNumber:
|
||||
moreLayers, err := ParseIPv4(b[ether.length():])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return append(layers, moreLayers...), nil
|
||||
nextParser = ParseIPv4
|
||||
default:
|
||||
// TODO(b/150301488): Support more protocols, like IPv6.
|
||||
return nil, fmt.Errorf("ethernet header's type field is unrecognized: %#04x", h.Type())
|
||||
// Assume that the rest is a payload.
|
||||
nextParser = ParsePayload
|
||||
}
|
||||
return ðer, nextParser
|
||||
}
|
||||
|
||||
func (l *Ether) match(other Layer) bool {
|
||||
|
@ -313,7 +332,7 @@ func Address(v tcpip.Address) *tcpip.Address {
|
|||
|
||||
// ParseIPv4 parses the bytes assuming that they start with an ipv4 header and
|
||||
// continues parsing further encapsulations.
|
||||
func ParseIPv4(b []byte) (Layers, error) {
|
||||
func ParseIPv4(b []byte) (Layer, LayerParser) {
|
||||
h := header.IPv4(b)
|
||||
tos, _ := h.TOS()
|
||||
ipv4 := IPv4{
|
||||
|
@ -329,22 +348,17 @@ func ParseIPv4(b []byte) (Layers, error) {
|
|||
SrcAddr: Address(h.SourceAddress()),
|
||||
DstAddr: Address(h.DestinationAddress()),
|
||||
}
|
||||
layers := Layers{&ipv4}
|
||||
var nextParser LayerParser
|
||||
switch h.TransportProtocol() {
|
||||
case header.TCPProtocolNumber:
|
||||
moreLayers, err := ParseTCP(b[ipv4.length():])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return append(layers, moreLayers...), nil
|
||||
nextParser = ParseTCP
|
||||
case header.UDPProtocolNumber:
|
||||
moreLayers, err := ParseUDP(b[ipv4.length():])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return append(layers, moreLayers...), nil
|
||||
nextParser = ParseUDP
|
||||
default:
|
||||
// Assume that the rest is a payload.
|
||||
nextParser = ParsePayload
|
||||
}
|
||||
return nil, fmt.Errorf("ipv4 header's protocol field is unrecognized: %#02x", h.Protocol())
|
||||
return &ipv4, nextParser
|
||||
}
|
||||
|
||||
func (l *IPv4) match(other Layer) bool {
|
||||
|
@ -470,7 +484,7 @@ func Uint32(v uint32) *uint32 {
|
|||
|
||||
// ParseTCP parses the bytes assuming that they start with a tcp header and
|
||||
// continues parsing further encapsulations.
|
||||
func ParseTCP(b []byte) (Layers, error) {
|
||||
func ParseTCP(b []byte) (Layer, LayerParser) {
|
||||
h := header.TCP(b)
|
||||
tcp := TCP{
|
||||
SrcPort: Uint16(h.SourcePort()),
|
||||
|
@ -483,12 +497,7 @@ func ParseTCP(b []byte) (Layers, error) {
|
|||
Checksum: Uint16(h.Checksum()),
|
||||
UrgentPointer: Uint16(h.UrgentPointer()),
|
||||
}
|
||||
layers := Layers{&tcp}
|
||||
moreLayers, err := ParsePayload(b[tcp.length():])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return append(layers, moreLayers...), nil
|
||||
return &tcp, ParsePayload
|
||||
}
|
||||
|
||||
func (l *TCP) match(other Layer) bool {
|
||||
|
@ -557,8 +566,8 @@ func setUDPChecksum(h *header.UDP, udp *UDP) error {
|
|||
}
|
||||
|
||||
// ParseUDP parses the bytes assuming that they start with a udp header and
|
||||
// continues parsing further encapsulations.
|
||||
func ParseUDP(b []byte) (Layers, error) {
|
||||
// returns the parsed layer and the next parser to use.
|
||||
func ParseUDP(b []byte) (Layer, LayerParser) {
|
||||
h := header.UDP(b)
|
||||
udp := UDP{
|
||||
SrcPort: Uint16(h.SourcePort()),
|
||||
|
@ -566,12 +575,7 @@ func ParseUDP(b []byte) (Layers, error) {
|
|||
Length: Uint16(h.Length()),
|
||||
Checksum: Uint16(h.Checksum()),
|
||||
}
|
||||
layers := Layers{&udp}
|
||||
moreLayers, err := ParsePayload(b[udp.length():])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return append(layers, moreLayers...), nil
|
||||
return &udp, ParsePayload
|
||||
}
|
||||
|
||||
func (l *UDP) match(other Layer) bool {
|
||||
|
@ -603,11 +607,11 @@ func (l *Payload) String() string {
|
|||
|
||||
// ParsePayload parses the bytes assuming that they start with a payload and
|
||||
// continue to the end. There can be no further encapsulations.
|
||||
func ParsePayload(b []byte) (Layers, error) {
|
||||
func ParsePayload(b []byte) (Layer, LayerParser) {
|
||||
payload := Payload{
|
||||
Bytes: b,
|
||||
}
|
||||
return Layers{&payload}, nil
|
||||
return &payload, nil
|
||||
}
|
||||
|
||||
func (l *Payload) toBytes() ([]byte, error) {
|
||||
|
@ -625,15 +629,24 @@ func (l *Payload) length() int {
|
|||
// Layers is an array of Layer and supports similar functions to Layer.
|
||||
type Layers []Layer
|
||||
|
||||
func (ls *Layers) toBytes() ([]byte, error) {
|
||||
// linkLayers sets the linked-list ponters in ls.
|
||||
func (ls *Layers) linkLayers() {
|
||||
for i, l := range *ls {
|
||||
if i > 0 {
|
||||
l.setPrev((*ls)[i-1])
|
||||
} else {
|
||||
l.setPrev(nil)
|
||||
}
|
||||
if i+1 < len(*ls) {
|
||||
l.setNext((*ls)[i+1])
|
||||
} else {
|
||||
l.setNext(nil)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (ls *Layers) toBytes() ([]byte, error) {
|
||||
ls.linkLayers()
|
||||
outBytes := []byte{}
|
||||
for _, l := range *ls {
|
||||
layerBytes, err := l.toBytes()
|
||||
|
|
Loading…
Reference in New Issue