summaryrefslogtreecommitdiff
path: root/kernel/network/tcp.c
diff options
context:
space:
mode:
authorAnton Kling <anton@kling.gg>2024-06-11 13:33:01 +0200
committerAnton Kling <anton@kling.gg>2024-06-11 15:18:40 +0200
commitabf9cf5bec2712465400417cc8232fee2d1cce28 (patch)
tree3e43f9cc8e194aa20de6e2e5a1916a10ced69f44 /kernel/network/tcp.c
parentb118759096ef08bfcc003933059de46a7a964ad7 (diff)
TCP stuff
Diffstat (limited to 'kernel/network/tcp.c')
-rw-r--r--kernel/network/tcp.c269
1 files changed, 195 insertions, 74 deletions
diff --git a/kernel/network/tcp.c b/kernel/network/tcp.c
index aeabbf2..fc8d97c 100644
--- a/kernel/network/tcp.c
+++ b/kernel/network/tcp.c
@@ -1,10 +1,13 @@
#include <assert.h>
#include <cpu/arch_inst.h>
#include <drivers/pit.h>
+#include <fs/vfs.h>
+#include <math.h>
#include <network/arp.h>
#include <network/bytes.h>
#include <network/ipv4.h>
#include <network/udp.h>
+#include <random.h>
#define CWR (1 << 7)
#define ECE (1 << 6)
@@ -15,6 +18,8 @@
#define SYN (1 << 1)
#define FIN (1 << 0)
+#define MSS 536
+
// FIXME: This should be dynamic
#define WINDOW_SIZE 4096
@@ -73,28 +78,82 @@ u16 tcp_checksum(u16 *buffer, int size) {
return (u16)(~cksum);
}
-void tcp_calculate_checksum(ipv4_t src_ip, u32 dst_ip, const u8 *payload,
- u16 payload_length, struct TCP_HEADER *header) {
- struct PSEUDO_TCP_HEADER ps;
- memset(&ps, 0, sizeof(ps));
- memcpy(&ps.src_addr, &src_ip.d, sizeof(u32));
- memcpy(&ps.dst_addr, &dst_ip, sizeof(u32));
- ps.protocol = 6;
- ps.tcp_length = htons(20 + payload_length);
- ps.src_port = header->src_port;
- ps.dst_port = header->dst_port;
- ps.seq_num = header->seq_num;
- ps.ack_num = header->ack_num;
- ps.data_offset = header->data_offset;
- ps.reserved = header->reserved;
- ps.flags = header->flags;
- ps.window_size = header->window_size;
- ps.urgent_pointer = header->urgent_pointer;
- int buffer_length = sizeof(ps) + payload_length;
+u16 tcp_calculate_checksum(ipv4_t src_ip, u32 dst_ip, const u8 *payload,
+ u16 payload_length, const struct TCP_HEADER *header,
+ int total) {
+ if (total < header->data_offset + payload_length) {
+ return 0;
+ }
+
+ int pseudo = sizeof(u32) * 2 + sizeof(u16) + sizeof(u8) * 2;
+
+ int buffer_length =
+ pseudo + header->data_offset * sizeof(u32) + payload_length;
u8 buffer[buffer_length];
- memcpy(buffer, &ps, sizeof(ps));
- memcpy(buffer + sizeof(ps), payload, payload_length);
- header->checksum = tcp_checksum((u16 *)buffer, buffer_length);
+ u8 *ptr = buffer;
+ memcpy(ptr, &src_ip.d, sizeof(u32));
+ ptr += sizeof(u32);
+ memcpy(ptr, &dst_ip, sizeof(u32));
+ ptr += sizeof(u32);
+ *ptr = 0;
+ ptr += sizeof(u8);
+ *ptr = 6;
+ ptr += sizeof(u8);
+ *(u16 *)ptr = htons(header->data_offset * sizeof(u32) + payload_length);
+ ptr += sizeof(u16);
+ memcpy(ptr, header, header->data_offset * sizeof(u32));
+ memset(ptr + 16, 0, sizeof(u16)); // set checksum to zero
+ ptr += header->data_offset * sizeof(u32);
+ memcpy(ptr, payload, payload_length);
+
+ return tcp_checksum((u16 *)buffer, buffer_length);
+}
+
+struct TcpPacket {
+ u32 time;
+ u32 seq_num;
+ u16 payload_length;
+ u8 *buffer;
+ u16 length;
+};
+
+static void tcp_send(struct TcpConnection *con, u8 *buffer, u16 length,
+ u32 seq_num, u32 payload_length) {
+ if (payload_length > 0) {
+ struct TcpPacket *packet = kmalloc(sizeof(struct TcpPacket));
+ assert(packet);
+ packet->time = pit_num_ms();
+ packet->seq_num = seq_num;
+ packet->buffer = buffer;
+ packet->length = length;
+ packet->payload_length = payload_length;
+
+ assert(relist_add(&con->inflight, packet, NULL));
+ }
+ send_ipv4_packet((ipv4_t){.d = con->outgoing_ip}, 6, buffer, length);
+}
+
+void tcp_resend_packets(struct TcpConnection *con) {
+ if (con->inflight.num_entries > 0) {
+ for (u32 i = 0;; i++) {
+ struct TcpPacket *packet;
+ int end;
+ if (!relist_get(&con->inflight, i, (void *)&packet, &end)) {
+ if (end) {
+ break;
+ }
+ continue;
+ }
+ if (packet->time + 200 > pit_num_ms()) {
+ continue;
+ }
+ // resend the packet
+ relist_remove(&con->inflight, i);
+ tcp_send(con, packet->buffer, packet->length, packet->seq_num,
+ packet->payload_length);
+ kfree(packet);
+ }
+ }
}
void tcp_send_empty_payload(struct TcpConnection *con, u8 flags) {
@@ -111,14 +170,15 @@ void tcp_send_empty_payload(struct TcpConnection *con, u8 flags) {
u8 payload[0];
u16 payload_length = 0;
- tcp_calculate_checksum(ip_address, con->outgoing_ip, (const u8 *)payload,
- payload_length, &header);
+ header.checksum = tcp_calculate_checksum(
+ ip_address, con->outgoing_ip, (const u8 *)payload, payload_length,
+ &header, sizeof(struct TCP_HEADER) + payload_length);
int send_len = sizeof(header) + payload_length;
- u8 send_buffer[send_len];
+ u8 *send_buffer = kmalloc(send_len);
memcpy(send_buffer, &header, sizeof(header));
memcpy(send_buffer + sizeof(header), payload, payload_length);
- send_ipv4_packet((ipv4_t){.d = con->outgoing_ip}, 6, send_buffer, send_len);
+ tcp_send(con, send_buffer, send_len, con->seq, 0);
}
void tcp_close_connection(struct TcpConnection *con) {
@@ -134,12 +194,30 @@ void tcp_send_syn(struct TcpConnection *con) {
con->seq++;
}
-void send_tcp_packet(struct TcpConnection *con, const u8 *payload,
- u16 payload_length) {
- if (payload_length > 1500 - 20 - sizeof(struct TCP_HEADER)) {
- send_tcp_packet(con, payload, 1500 - 20 - sizeof(struct TCP_HEADER));
- payload_length -= 1500 - 20 - sizeof(struct TCP_HEADER);
- payload += 1500 - 20 - sizeof(struct TCP_HEADER);
+int tcp_can_send(struct TcpConnection *con, u16 payload_length) {
+ if (con->inflight.num_entries > 2) {
+ tcp_resend_packets(con);
+ return 0;
+ }
+ if (con->seq - con->seq_ack + payload_length > con->current_window_size) {
+ tcp_resend_packets(con);
+ return 0;
+ }
+ return 1;
+}
+
+int send_tcp_packet(struct TcpConnection *con, const u8 *payload,
+ u16 payload_length) {
+ if (!tcp_can_send(con, payload_length)) {
+ return 0;
+ }
+
+ if (payload_length > MSS) {
+ if (0 == send_tcp_packet(con, payload, MSS)) {
+ return 0;
+ }
+ payload_length -= MSS;
+ payload += MSS;
return send_tcp_packet(con, payload, payload_length);
}
struct TCP_HEADER header = {0};
@@ -152,24 +230,37 @@ void send_tcp_packet(struct TcpConnection *con, const u8 *payload,
header.flags = PSH | ACK;
header.window_size = htons(WINDOW_SIZE);
header.urgent_pointer = 0;
- tcp_calculate_checksum(ip_address, con->outgoing_ip, (const u8 *)payload,
- payload_length, &header);
+ header.checksum = tcp_calculate_checksum(
+ ip_address, con->outgoing_ip, (const u8 *)payload, payload_length,
+ &header, sizeof(struct TCP_HEADER) + payload_length);
int send_len = sizeof(header) + payload_length;
- u8 send_buffer[send_len];
+ u8 *send_buffer = kmalloc(send_len);
+ assert(send_buffer);
memcpy(send_buffer, &header, sizeof(header));
memcpy(send_buffer + sizeof(header), payload, payload_length);
- send_ipv4_packet((ipv4_t){.d = con->outgoing_ip}, 6, send_buffer, send_len);
+
+ tcp_send(con, send_buffer, send_len, con->seq, payload_length);
con->seq += payload_length;
+ return 1;
}
-void handle_tcp(ipv4_t src_ip, const u8 *payload, u32 payload_length) {
+void handle_tcp(ipv4_t src_ip, ipv4_t dst_ip, const u8 *payload,
+ u32 payload_length) {
const struct TCP_HEADER *header = (const struct TCP_HEADER *)payload;
- (void)header;
+ u16 tcp_payload_length = payload_length - header->data_offset * sizeof(u32);
+ const u8 *tcp_payload = payload + header->data_offset * sizeof(u32);
+ u16 checksum =
+ tcp_calculate_checksum(src_ip, dst_ip.d, tcp_payload, tcp_payload_length,
+ header, payload_length);
+ if (header->checksum != checksum) {
+ return;
+ }
u16 n_src_port = header->src_port;
u16 n_dst_port = header->dst_port;
u32 n_seq_num = header->seq_num;
u32 n_ack_num = header->ack_num;
+ u32 n_window_size = header->window_size;
u8 flags = header->flags;
@@ -177,14 +268,18 @@ void handle_tcp(ipv4_t src_ip, const u8 *payload, u32 payload_length) {
u16 dst_port = htons(n_dst_port);
u32 seq_num = htonl(n_seq_num);
u32 ack_num = htonl(n_ack_num);
+ u16 window_size = htons(n_window_size);
+
(void)ack_num;
if (SYN == flags) {
struct TcpConnection *con =
internal_tcp_incoming(src_ip.d, src_port, 0, dst_port);
- if(!con) {
+ if (!con) {
return;
}
+ con->window_size = window_size;
+ con->current_window_size = MSS;
con->ack = seq_num + 1;
tcp_send_empty_payload(con, SYN | ACK);
con->seq++;
@@ -195,49 +290,75 @@ void handle_tcp(ipv4_t src_ip, const u8 *payload, u32 payload_length) {
tcp_find_connection(src_ip, src_port, dst_port);
if (!incoming_connection) {
kprintf("unable to find open port for incoming connection\n");
+ return;
+ }
+ incoming_connection->window_size = window_size;
+ incoming_connection->unhandled_packet = 1;
+ if (0 != (flags & RST)) {
+ klog("Requested port is closed", LOG_NOTE);
+ incoming_connection->dead = 1;
+ return;
}
- if (incoming_connection) {
- incoming_connection->unhandled_packet = 1;
- if (0 != (flags & RST)) {
- klog("Requested port is closed", LOG_NOTE);
- incoming_connection->dead = 1;
+ if (ACK == flags) {
+ if (0 == incoming_connection->handshake_state) {
+ // Then it is probably a response to the SYN|ACK we sent.
+ incoming_connection->handshake_state = 1;
return;
}
- if (ACK == flags) {
- if (0 == incoming_connection->handshake_state) {
- // Then it is probably a response to the SYN|ACK we sent.
- incoming_connection->handshake_state = 1;
- return;
+ }
+ if ((SYN | ACK) == flags) {
+ assert(0 == incoming_connection->handshake_state);
+ incoming_connection->handshake_state = 1;
+ incoming_connection->ack = seq_num + 1;
+ tcp_send_ack(incoming_connection);
+ return;
+ }
+ if (ACK & flags) {
+ incoming_connection->seq_ack = ack_num;
+ if (incoming_connection->inflight.num_entries > 0) {
+ for (u32 i = 0;; i++) {
+ struct TcpPacket *packet;
+ int end;
+ if (!relist_get(&incoming_connection->inflight, i, (void *)&packet,
+ &end)) {
+ if (end) {
+ break;
+ }
+ continue;
+ }
+ if (packet->seq_num + packet->payload_length != ack_num) {
+ continue;
+ }
+ relist_remove(&incoming_connection->inflight, i);
+ kfree(packet->buffer);
+ kfree(packet);
+ break;
}
}
- if ((SYN | ACK) == flags) {
- assert(0 == incoming_connection->handshake_state);
- incoming_connection->handshake_state = 1;
-
- incoming_connection->ack = seq_num + 1;
-
- tcp_send_ack(incoming_connection);
- }
- u16 tcp_payload_length = payload_length - header->data_offset * sizeof(u32);
- if (tcp_payload_length > 0) {
- const u8 *tcp_payload = payload + header->data_offset * sizeof(u32);
- u32 len = ringbuffer_write(&incoming_connection->incoming_buffer,
- tcp_payload, tcp_payload_length);
- assert(len == tcp_payload_length);
- incoming_connection->ack += len;
- tcp_send_ack(incoming_connection);
+ tcp_resend_packets(incoming_connection);
+ if (0 == incoming_connection->inflight.num_entries) {
+ u32 rest = incoming_connection->window_size -
+ incoming_connection->current_window_size;
+ if (rest > 0) {
+ incoming_connection->current_window_size += min(rest, MSS);
+ }
}
- if (0 != (flags & FIN)) {
- incoming_connection->ack++;
+ }
+ if (tcp_payload_length > 0) {
+ u32 len = ringbuffer_write(&incoming_connection->incoming_buffer,
+ tcp_payload, tcp_payload_length);
+ assert(len == tcp_payload_length);
+ incoming_connection->ack += len;
+ tcp_send_ack(incoming_connection);
+ }
+ if (0 != (flags & FIN)) {
+ incoming_connection->ack++;
- tcp_send_empty_payload(incoming_connection, FIN | ACK);
- incoming_connection->seq++;
+ tcp_send_empty_payload(incoming_connection, FIN | ACK);
+ incoming_connection->seq++;
- incoming_connection->dead = 1; // FIXME: It should wait for a ACK
- // of the FIN before the connection
- // is closed.
- }
- } else {
- return;
+ incoming_connection->dead = 1; // FIXME: It should wait for a ACK
+ // of the FIN before the connection
+ // is closed.
}
}