summaryrefslogtreecommitdiff
path: root/kernel/network/tcp.c
diff options
context:
space:
mode:
authorAnton Kling <anton@kling.gg>2024-06-22 14:34:21 +0200
committerAnton Kling <anton@kling.gg>2024-06-22 14:34:21 +0200
commit01b88a7bf9fb4c78bd632bfccb06f3d320a21fd5 (patch)
tree20d9a6dcc155e7c8b6e067c6ba6d7b42df4365fd /kernel/network/tcp.c
parentaf313dec6b7698b6f948b97669aa7be91717a451 (diff)
Kernel stuff
Diffstat (limited to 'kernel/network/tcp.c')
-rw-r--r--kernel/network/tcp.c337
1 files changed, 144 insertions, 193 deletions
diff --git a/kernel/network/tcp.c b/kernel/network/tcp.c
index d171722..2dc7180 100644
--- a/kernel/network/tcp.c
+++ b/kernel/network/tcp.c
@@ -6,23 +6,12 @@
#include <network/arp.h>
#include <network/bytes.h>
#include <network/ipv4.h>
+#include <network/tcp.h>
#include <network/udp.h>
#include <random.h>
-#define CWR (1 << 7)
-#define ECE (1 << 6)
-#define URG (1 << 5)
-#define ACK (1 << 4)
-#define PSH (1 << 3)
-#define RST (1 << 2)
-#define SYN (1 << 1)
-#define FIN (1 << 0)
-
#define MSS 536
-// FIXME: This should be dynamic
-#define WINDOW_SIZE 4096
-
struct __attribute__((__packed__)) TCP_HEADER {
u16 src_port;
u16 dst_port;
@@ -36,36 +25,9 @@ struct __attribute__((__packed__)) TCP_HEADER {
u16 urgent_pointer;
};
-struct __attribute__((__packed__)) PSEUDO_TCP_HEADER {
- u32 src_addr;
- u32 dst_addr;
- u8 zero;
- u8 protocol;
- u16 tcp_length;
- u16 src_port;
- u16 dst_port;
- u32 seq_num;
- u32 ack_num;
- u8 reserved : 4;
- u8 data_offset : 4;
- u8 flags;
- u16 window_size;
- u16 checksum_zero; // ?????
- u16 urgent_pointer;
-};
-
-void tcp_wait_reply(struct TcpConnection *con) {
- for (;;) {
- if (con->unhandled_packet) {
- return;
- }
- switch_task();
- }
-}
-
-u16 tcp_checksum(u16 *buffer, int size) {
- unsigned long cksum = 0;
- while (size > 1) {
+u16 tcp_checksum(u16 *buffer, u32 size) {
+ u32 cksum = 0;
+ for (; size > 1;) {
cksum += *buffer++;
size -= sizeof(u16);
}
@@ -90,6 +52,7 @@ u16 tcp_calculate_checksum(ipv4_t src_ip, u32 dst_ip, const u8 *payload,
int buffer_length =
pseudo + header->data_offset * sizeof(u32) + payload_length;
u8 buffer[buffer_length];
+
u8 *ptr = buffer;
memcpy(ptr, &src_ip.d, sizeof(u32));
ptr += sizeof(u32);
@@ -109,70 +72,25 @@ u16 tcp_calculate_checksum(ipv4_t src_ip, u32 dst_ip, const u8 *payload,
return tcp_checksum((u16 *)buffer, buffer_length);
}
-struct TcpPacket {
- u32 time;
- u32 seq_num;
- u16 payload_length;
- u8 *buffer;
- u16 length;
-};
-
static void tcp_send(struct TcpConnection *con, u8 *buffer, u16 length,
u32 seq_num, u32 payload_length) {
- if (payload_length > 0) {
- struct TcpPacket *packet = kmalloc(sizeof(struct TcpPacket));
- assert(packet);
- packet->time = pit_num_ms();
- packet->seq_num = seq_num;
- packet->buffer = buffer;
- packet->length = length;
- packet->payload_length = payload_length;
-
- assert(relist_add(&con->inflight, packet, NULL));
- }
send_ipv4_packet((ipv4_t){.d = con->outgoing_ip}, 6, buffer, length);
}
-void tcp_resend_packets(struct TcpConnection *con) {
- if (0 == con->inflight.num_entries) {
- return;
- }
- int did_resend = 0;
- for (u32 i = 0;; i++) {
- struct TcpPacket *packet;
- int end;
- if (!relist_get(&con->inflight, i, (void *)&packet, &end)) {
- if (end) {
- break;
- }
- continue;
- }
- if (packet->time + 200 > pit_num_ms()) {
- continue;
- }
- // resend the packet
- did_resend = 1;
- relist_remove(&con->inflight, i);
- tcp_send(con, packet->buffer, packet->length, packet->seq_num,
- packet->payload_length);
- kfree(packet);
- }
- if (did_resend) {
- con->max_inflight = 1;
- con->window_size = MSS;
- }
-}
-
void tcp_send_empty_payload(struct TcpConnection *con, u8 flags) {
struct TCP_HEADER header = {0};
header.src_port = htons(con->incoming_port);
header.dst_port = htons(con->outgoing_port);
- header.seq_num = htonl(con->seq);
- header.ack_num = htonl(con->ack);
+ header.seq_num = htonl(con->snd_nxt);
+ if (flags & ACK) {
+ header.ack_num = htonl(con->sent_ack);
+ } else {
+ header.ack_num = 0;
+ }
header.data_offset = 5;
header.reserved = 0;
header.flags = flags;
- header.window_size = htons(WINDOW_SIZE);
+ header.window_size = htons(con->rcv_wnd);
header.urgent_pointer = 0;
u8 payload[0];
@@ -185,28 +103,41 @@ void tcp_send_empty_payload(struct TcpConnection *con, u8 flags) {
memcpy(send_buffer, &header, sizeof(header));
memcpy(send_buffer + sizeof(header), payload, payload_length);
- tcp_send(con, send_buffer, send_len, con->seq, 0);
-}
-
-void tcp_close_connection(struct TcpConnection *con) {
- tcp_send_empty_payload(con, FIN | ACK);
-}
+ tcp_send(con, send_buffer, send_len, con->snd_nxt, 0);
-void tcp_send_ack(struct TcpConnection *con) {
- tcp_send_empty_payload(con, ACK);
+ con->snd_nxt += (flags & SYN) ? 1 : 0;
+ con->snd_nxt += (flags & FIN) ? 1 : 0;
+ con->snd_max = max(con->snd_nxt, con->snd_max);
}
-void tcp_send_syn(struct TcpConnection *con) {
- tcp_send_empty_payload(con, SYN);
- con->seq++;
+void tcp_close_connection(struct TcpConnection *con) {
+ if (TCP_STATE_CLOSE_WAIT == con->state) {
+ tcp_send_empty_payload(con, FIN);
+ con->state = TCP_STATE_LAST_ACK;
+ return;
+ }
+ if (TCP_STATE_ESTABLISHED == con->state) {
+ tcp_send_empty_payload(con, FIN);
+ con->state = TCP_STATE_FIN_WAIT1;
+ return;
+ }
+ if (TCP_STATE_SYN_RECIEVED == con->state) {
+ tcp_send_empty_payload(con, FIN);
+ con->state = TCP_STATE_FIN_WAIT1;
+ return;
+ }
+ if (TCP_STATE_SYN_SENT == con->state) {
+ con->state = TCP_STATE_CLOSED;
+ // TODO: Cleanup
+ return;
+ }
}
u16 tcp_can_send(struct TcpConnection *con) {
- if (con->inflight.num_entries > con->max_inflight) {
- tcp_resend_packets(con);
+ if (TCP_STATE_CLOSED == con->state) {
return 0;
}
- return con->current_window_size;
+ return con->snd_una + con->snd_wnd - con->snd_max;
}
int send_tcp_packet(struct TcpConnection *con, const u8 *payload,
@@ -215,23 +146,24 @@ int send_tcp_packet(struct TcpConnection *con, const u8 *payload,
return 0;
}
- if (payload_length > MSS) {
- if (0 == send_tcp_packet(con, payload, MSS)) {
+ u16 len = min(1000, payload_length);
+ if (payload_length > len) {
+ if (0 == send_tcp_packet(con, payload, len)) {
return 0;
}
- payload_length -= MSS;
- payload += MSS;
+ payload_length -= len;
+ payload += len;
return send_tcp_packet(con, payload, payload_length);
}
struct TCP_HEADER header = {0};
header.src_port = htons(con->incoming_port);
header.dst_port = htons(con->outgoing_port);
- header.seq_num = htonl(con->seq);
- header.ack_num = htonl(con->ack);
+ header.seq_num = htonl(con->snd_nxt);
+ header.ack_num = htonl(con->sent_ack);
header.data_offset = 5;
header.reserved = 0;
header.flags = PSH | ACK;
- header.window_size = htons(WINDOW_SIZE);
+ header.window_size = htons(con->rcv_wnd);
header.urgent_pointer = 0;
header.checksum = tcp_calculate_checksum(
ip_address, con->outgoing_ip, (const u8 *)payload, payload_length,
@@ -242,9 +174,10 @@ int send_tcp_packet(struct TcpConnection *con, const u8 *payload,
memcpy(send_buffer, &header, sizeof(header));
memcpy(send_buffer + sizeof(header), payload, payload_length);
- tcp_send(con, send_buffer, send_len, con->seq, payload_length);
+ tcp_send(con, send_buffer, send_len, con->snd_nxt, payload_length);
- con->seq += payload_length;
+ con->snd_nxt += payload_length;
+ con->snd_max = max(con->snd_nxt, con->snd_max);
return 1;
}
@@ -272,97 +205,115 @@ void handle_tcp(ipv4_t src_ip, ipv4_t dst_ip, const u8 *payload,
u32 seq_num = htonl(n_seq_num);
u32 ack_num = htonl(n_ack_num);
u16 window_size = htons(n_window_size);
+ struct TcpConnection *con =
+ tcp_find_connection(src_ip, src_port, dst_ip, dst_port);
- (void)ack_num;
+ if (con->state == TCP_STATE_LISTEN || con->state == TCP_STATE_SYN_SENT) {
+ con->rcv_nxt = seq_num;
+ }
- if (SYN == flags) {
- struct TcpConnection *con =
- internal_tcp_incoming(src_ip.d, src_port, 0, dst_port);
- if (!con) {
- return;
- }
- con->window_size = window_size;
- con->current_window_size = MSS;
- con->ack = seq_num + 1;
- tcp_send_empty_payload(con, SYN | ACK);
- con->seq++;
+ u32 seq_num_end = seq_num + tcp_payload_length - 1;
+ if (!((con->rcv_nxt <= seq_num) && seq_num < con->rcv_nxt + con->rcv_wnd) &&
+ !((con->rcv_nxt <= seq_num_end) &&
+ seq_num_end < con->rcv_nxt + con->rcv_wnd)) {
+ kprintf("seq_num: %d\n", seq_num);
+ kprintf("seq_num_end: %d\n", seq_num_end);
+ kprintf("con->rcv_nxt: %d\n", con->rcv_nxt);
+ kprintf("con->rcv_wnd: %d\n", con->rcv_wnd);
+ // Invalid segment
+ kprintf("invalid segment\n");
return;
}
- struct TcpConnection *incoming_connection =
- tcp_find_connection(src_ip, src_port, dst_port);
- if (!incoming_connection) {
- kprintf("unable to find open port for incoming connection\n");
+ if (ack_num > con->snd_max) {
+ // TODO: Odd ACK number, what should be done?
+ kprintf("odd ACK\n");
return;
}
- incoming_connection->window_size = window_size;
- incoming_connection->unhandled_packet = 1;
- if (0 != (flags & RST)) {
- klog("Requested port is closed", LOG_NOTE);
- incoming_connection->dead = 1;
+
+ if (ack_num < con->snd_una) {
+ // TODO duplicate ACK
+ kprintf("duplicate ACK\n");
return;
}
- if (ACK == flags) {
- if (0 == incoming_connection->handshake_state) {
- // Then it is probably a response to the SYN|ACK we sent.
- incoming_connection->handshake_state = 1;
- return;
+
+ con->snd_wnd = window_size;
+ con->snd_una = ack_num;
+
+ con->sent_ack = max(con->sent_ack, seq_num + tcp_payload_length);
+ if (FIN & flags) {
+ con->sent_ack++;
+ }
+
+ switch (con->state) {
+ case TCP_STATE_LISTEN: {
+ if (SYN & flags) {
+ tcp_send_empty_payload(con, SYN | ACK);
+ con->state = TCP_STATE_SYN_RECIEVED;
+ con->rcv_nxt++;
+ break;
}
+ break;
}
- if ((SYN | ACK) == flags) {
- assert(0 == incoming_connection->handshake_state);
- incoming_connection->handshake_state = 1;
- incoming_connection->ack = seq_num + 1;
- tcp_send_ack(incoming_connection);
- return;
+ case TCP_STATE_SYN_RECIEVED: {
+ if (ACK & flags) {
+ con->state = TCP_STATE_ESTABLISHED;
+ break;
+ }
+ break;
+ }
+ case TCP_STATE_SYN_SENT: {
+ if ((ACK & flags) && (SYN & flags)) {
+ tcp_send_empty_payload(con, ACK);
+ con->state = TCP_STATE_ESTABLISHED;
+ con->rcv_nxt++;
+ break;
+ }
+ break;
}
- if (ACK & flags) {
- incoming_connection->seq_ack = ack_num;
- if (incoming_connection->inflight.num_entries > 0) {
- for (u32 i = 0;; i++) {
- struct TcpPacket *packet;
- int end;
- if (!relist_get(&incoming_connection->inflight, i, (void *)&packet,
- &end)) {
- if (end) {
- break;
- }
- continue;
- }
- if (packet->seq_num + packet->payload_length > ack_num) {
- continue;
- }
- relist_remove(&incoming_connection->inflight, i);
- kfree(packet->buffer);
- kfree(packet);
- break;
- }
+ case TCP_STATE_ESTABLISHED: {
+ if (FIN & flags) {
+ tcp_send_empty_payload(con, ACK);
+ con->state = TCP_STATE_CLOSE_WAIT;
+ break;
}
- tcp_resend_packets(incoming_connection);
- if (0 == incoming_connection->inflight.num_entries) {
- incoming_connection->max_inflight++;
- u32 rest = incoming_connection->window_size -
- incoming_connection->current_window_size;
- if (rest > 0) {
- incoming_connection->current_window_size += min(rest, MSS);
- }
+ if (tcp_payload_length > 0) {
+ int rc = ringbuffer_write(&con->incoming_buffer, tcp_payload,
+ tcp_payload_length);
+ con->rcv_nxt += rc;
+ con->rcv_wnd = ringbuffer_unused(&con->incoming_buffer);
+ tcp_send_empty_payload(con, ACK);
}
+ break;
}
- if (tcp_payload_length > 0) {
- u32 len = ringbuffer_write(&incoming_connection->incoming_buffer,
- tcp_payload, tcp_payload_length);
- assert(len == tcp_payload_length);
- incoming_connection->ack += len;
- tcp_send_ack(incoming_connection);
+ case TCP_STATE_FIN_WAIT1: {
+ if ((ACK & flags) && (FIN & flags)) {
+ tcp_send_empty_payload(con, ACK);
+ con->state = TCP_STATE_TIME_WAIT;
+ break;
+ }
+ if (ACK & flags) {
+ con->state = TCP_STATE_FIN_WAIT2;
+ break;
+ }
+ if (FIN & flags) {
+ tcp_send_empty_payload(con, ACK);
+ con->state = TCP_STATE_CLOSING;
+ break;
+ }
+ break;
}
- if (0 != (flags & FIN)) {
- incoming_connection->ack++;
-
- tcp_send_empty_payload(incoming_connection, FIN | ACK);
- incoming_connection->seq++;
-
- incoming_connection->dead = 1; // FIXME: It should wait for a ACK
- // of the FIN before the connection
- // is closed.
+ case TCP_STATE_LAST_ACK: {
+ if (ACK & flags) {
+ // TODO: cleanup
+ con->state = TCP_STATE_CLOSED;
+ break;
+ }
+ break;
+ }
+ default: {
+ klog(LOG_WARN, "TCP state not handled %d", con->state);
+ break;
}
+ };
}