Julia Evans

Day 26: Trying to describe the TCP state machine in a readable way. Failing.

Today I made a bunch of progress (I can now be a TCP server, kinda!), but I want to talk about problems instead.

The main function right now in this TCP stack is called handle(), and it’s responsible for moving from one part of the state machine to another.

In particular, it has to

  • Drop inappropriate packets
  • Increment the current ACK number
  • Make state transitions
  • Send out ACKs and SYNs and FINs and FIN-ACKs when appropriate

It is kind of a mess and I am finding it pretty hard to reason about and test.

Yesterday it looked like this:

def handle(self, packet):
    # Update our state to indicate that we've received the packet
    self.ack = max(self.next_seq(packet), self.ack)
    if hasattr(packet, 'load'):
        self.recv_buffer += packet.load

    recv_flags = packet.sprintf("%TCP.flags%")
    send_flags = ""

    # Handle all the cases for self.state explicitly
    if self.state == "ESTABLISHED" and 'F' in recv_flags:
        send_flags = "F"
        self.state = "TIME-WAIT"
    elif self.state == "ESTABLISHED":
        pass
    elif self.state == "SYN-SENT":
        self.seq += 1
        self.state = "ESTABLISHED"
    elif self.state == "FIN-WAIT-1" and 'F' in recv_flags:
        self.seq += 1
        self.state = "TIME-WAIT"
    else:
        raise BadPacketError("Oh no!")

    self._send_ack(flags=send_flags)

In particular, I thought this _send_ack() call at the end was a great idea, because you always want to send an ACK! Except when you don’t!

In fact, I have just remembered that my primary insight from yesterday was that I always wanted to send an ACK. But it turns out that actually there are a couple of of cases where you don’t want to send an ACK:

  • The packet you’re receiving was itself an ACK
  • You’ve just received a RST packet and are closing the connection

and then I found the _send_ack() call at the bottom confusing, because it wasn’t clear under what conditions the code actually go there.

So now I have, after some suggestions from Allison:

def handle(self, packet):
    if self.last_ack_sent and self.last_ack_sent != packet.seq:
        # We're not in a place to receive this packet. Drop it.
        return

    self.last_ack_sent = max(self.next_seq(packet), self.last_ack_sent)

    recv_flags = packet.sprintf("%TCP.flags%")

    # Handle all the cases for self.state explicitly
    if self._has_load(packet):
        self.recv_buffer += packet.load
        self._send_ack()
    elif "R" in recv_flags:
        self._close()
    elif "S" in recv_flags:
        if self.state == "LISTEN":
            self.state = "SYN-RECEIVED"
            self._set_dest(packet.payload.src, packet.sport)
            self._send_ack(flags="S")
        elif self.state == "SYN-SENT":
            self.seq += 1
            self.state = "ESTABLISHED"
            self._send_ack()
    elif "F" in recv_flags:
        if self.state == "ESTABLISHED":
            self.seq += 1
            self.state = "LAST-ACK"
            self._send_ack(flags="F")
        elif self.state == "FIN-WAIT-1":
            self.seq += 1
            self._send_ack()
            self._close()
    elif "A" in recv_flags:
        if self.state == "SYN-RECEIVED":
            self.state = "ESTABLISHED"
        elif self.state == "LAST-ACK":
            self._close()
    else:
        raise BadPacketError("Oh no!")

This solves a bunch more problems than the first function. In particular, it

  • Ignores packets with the wrong sequence number
  • Updates the last_ack_sent with the next sequence number that we’re expecting

Good things:

  • If I’m in “ESTABLISHED” and get a “FA”, it’s fairly easy to see what’s going to happen in the giant if statement

Bad things:

  • The cases aren’t all mutually exclusive, so the order matters. :[

  • Each one is a seemingly random combination of self.seq += 1, self._send_ack(), self._close(), and a change to self.state. It’s hard to make sure the whole thing is right without a ton of unit testing.

My confusion about this function largely reflects my confusion about TCP at this point, I think.