// Copyright 2018 Google Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package dhcp import ( "context" "fmt" "log" "sync" "time" "gvisor.googlesource.com/gvisor/pkg/tcpip" "gvisor.googlesource.com/gvisor/pkg/tcpip/network/ipv4" "gvisor.googlesource.com/gvisor/pkg/tcpip/stack" "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/udp" "gvisor.googlesource.com/gvisor/pkg/waiter" ) // Server is a DHCP server. type Server struct { stack *stack.Stack broadcast tcpip.FullAddress wq waiter.Queue ep tcpip.Endpoint addrs []tcpip.Address // TODO: use a tcpip.AddressMask or range structure cfg Config cfgopts []option // cfg to send to client handlers []chan header mu sync.Mutex leases map[tcpip.LinkAddress]serverLease } // NewServer creates a new DHCP server and begins serving. // The server continues serving until ctx is done. func NewServer(ctx context.Context, stack *stack.Stack, addrs []tcpip.Address, cfg Config) (*Server, error) { s := &Server{ stack: stack, addrs: addrs, cfg: cfg, cfgopts: cfg.encode(), broadcast: tcpip.FullAddress{ Addr: "\xff\xff\xff\xff", Port: clientPort, }, handlers: make([]chan header, 8), leases: make(map[tcpip.LinkAddress]serverLease), } var err *tcpip.Error s.ep, err = s.stack.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &s.wq) if err != nil { return nil, fmt.Errorf("dhcp: server endpoint: %v", err) } serverBroadcast := tcpip.FullAddress{ Addr: "", Port: serverPort, } if err := s.ep.Bind(serverBroadcast, nil); err != nil { return nil, fmt.Errorf("dhcp: server bind: %v", err) } for i := 0; i < len(s.handlers); i++ { ch := make(chan header, 8) s.handlers[i] = ch go s.handler(ctx, ch) } go s.expirer(ctx) go s.reader(ctx) return s, nil } func (s *Server) expirer(ctx context.Context) { t := time.NewTicker(1 * time.Minute) defer t.Stop() for { select { case <-t.C: s.mu.Lock() for linkAddr, lease := range s.leases { if time.Since(lease.start) > s.cfg.LeaseLength { lease.state = leaseExpired s.leases[linkAddr] = lease } } s.mu.Unlock() case <-ctx.Done(): return } } } // reader listens for all incoming DHCP packets and fans them out to // handling goroutines based on XID as session identifiers. func (s *Server) reader(ctx context.Context) { we, ch := waiter.NewChannelEntry(nil) s.wq.EventRegister(&we, waiter.EventIn) defer s.wq.EventUnregister(&we) for { var addr tcpip.FullAddress v, _, err := s.ep.Read(&addr) if err == tcpip.ErrWouldBlock { select { case <-ch: continue case <-ctx.Done(): return } } h := header(v) if !h.isValid() || h.op() != opRequest { continue } xid := h.xid() // Fan out the packet to a handler goroutine. // // Use a consistent handler for a given xid, so that // packets from a particular client are processed // in order. ch := s.handlers[int(xid)%len(s.handlers)] select { case <-ctx.Done(): return case ch <- h: default: // drop the packet } } } func (s *Server) handler(ctx context.Context, ch chan header) { for { select { case h := <-ch: if h == nil { return } opts, err := h.options() if err != nil { continue } // TODO: Handle DHCPRELEASE and DHCPDECLINE. msgtype, err := opts.dhcpMsgType() if err != nil { continue } switch msgtype { case dhcpDISCOVER: s.handleDiscover(h, opts) case dhcpREQUEST: s.handleRequest(h, opts) } case <-ctx.Done(): return } } } func (s *Server) handleDiscover(hreq header, opts options) { linkAddr := tcpip.LinkAddress(hreq.chaddr()[:6]) xid := hreq.xid() s.mu.Lock() lease := s.leases[linkAddr] switch lease.state { case leaseNew: if len(s.leases) < len(s.addrs) { // Find an unused address. // TODO: avoid building this state on each request. alloced := make(map[tcpip.Address]bool) for _, lease := range s.leases { alloced[lease.addr] = true } for _, addr := range s.addrs { if !alloced[addr] { lease = serverLease{ start: time.Now(), addr: addr, xid: xid, state: leaseOffer, } s.leases[linkAddr] = lease break } } } else { // No more addresses, take an expired address. for k, oldLease := range s.leases { if oldLease.state == leaseExpired { delete(s.leases, k) lease = serverLease{ start: time.Now(), addr: lease.addr, xid: xid, state: leaseOffer, } s.leases[linkAddr] = lease break } } log.Printf("server has no more addresses") s.mu.Unlock() return } case leaseOffer, leaseAck, leaseExpired: lease = serverLease{ start: time.Now(), addr: s.leases[linkAddr].addr, xid: xid, state: leaseOffer, } s.leases[linkAddr] = lease } s.mu.Unlock() // DHCPOFFER opts = options{{optDHCPMsgType, []byte{byte(dhcpOFFER)}}} opts = append(opts, s.cfgopts...) h := make(header, headerBaseSize+opts.len()) h.init() h.setOp(opReply) copy(h.xidbytes(), hreq.xidbytes()) copy(h.yiaddr(), lease.addr) copy(h.siaddr(), s.cfg.ServerAddress) copy(h.chaddr(), hreq.chaddr()) h.setOptions(opts) s.ep.Write(tcpip.SlicePayload(h), tcpip.WriteOptions{To: &s.broadcast}) } func (s *Server) handleRequest(hreq header, opts options) { linkAddr := tcpip.LinkAddress(hreq.chaddr()[:6]) xid := hreq.xid() s.mu.Lock() lease := s.leases[linkAddr] switch lease.state { case leaseOffer, leaseAck, leaseExpired: lease = serverLease{ start: time.Now(), addr: s.leases[linkAddr].addr, xid: xid, state: leaseAck, } s.leases[linkAddr] = lease } s.mu.Unlock() if lease.state == leaseNew { // TODO: NACK or accept request return } // DHCPACK opts = []option{{optDHCPMsgType, []byte{byte(dhcpACK)}}} opts = append(opts, s.cfgopts...) h := make(header, headerBaseSize+opts.len()) h.init() h.setOp(opReply) copy(h.xidbytes(), hreq.xidbytes()) copy(h.yiaddr(), lease.addr) copy(h.siaddr(), s.cfg.ServerAddress) copy(h.chaddr(), hreq.chaddr()) h.setOptions(opts) s.ep.Write(tcpip.SlicePayload(h), tcpip.WriteOptions{To: &s.broadcast}) } type leaseState int const ( leaseNew leaseState = iota leaseOffer leaseAck leaseExpired ) type serverLease struct { start time.Time addr tcpip.Address xid xid state leaseState }