diff options
author | Anton Kling <anton@kling.gg> | 2024-06-11 13:33:01 +0200 |
---|---|---|
committer | Anton Kling <anton@kling.gg> | 2024-06-11 15:18:40 +0200 |
commit | abf9cf5bec2712465400417cc8232fee2d1cce28 (patch) | |
tree | 3e43f9cc8e194aa20de6e2e5a1916a10ced69f44 /kernel/network/tcp.c | |
parent | b118759096ef08bfcc003933059de46a7a964ad7 (diff) |
TCP stuff
Diffstat (limited to 'kernel/network/tcp.c')
-rw-r--r-- | kernel/network/tcp.c | 269 |
1 files changed, 195 insertions, 74 deletions
diff --git a/kernel/network/tcp.c b/kernel/network/tcp.c index aeabbf2..fc8d97c 100644 --- a/kernel/network/tcp.c +++ b/kernel/network/tcp.c @@ -1,10 +1,13 @@ #include <assert.h> #include <cpu/arch_inst.h> #include <drivers/pit.h> +#include <fs/vfs.h> +#include <math.h> #include <network/arp.h> #include <network/bytes.h> #include <network/ipv4.h> #include <network/udp.h> +#include <random.h> #define CWR (1 << 7) #define ECE (1 << 6) @@ -15,6 +18,8 @@ #define SYN (1 << 1) #define FIN (1 << 0) +#define MSS 536 + // FIXME: This should be dynamic #define WINDOW_SIZE 4096 @@ -73,28 +78,82 @@ u16 tcp_checksum(u16 *buffer, int size) { return (u16)(~cksum); } -void tcp_calculate_checksum(ipv4_t src_ip, u32 dst_ip, const u8 *payload, - u16 payload_length, struct TCP_HEADER *header) { - struct PSEUDO_TCP_HEADER ps; - memset(&ps, 0, sizeof(ps)); - memcpy(&ps.src_addr, &src_ip.d, sizeof(u32)); - memcpy(&ps.dst_addr, &dst_ip, sizeof(u32)); - ps.protocol = 6; - ps.tcp_length = htons(20 + payload_length); - ps.src_port = header->src_port; - ps.dst_port = header->dst_port; - ps.seq_num = header->seq_num; - ps.ack_num = header->ack_num; - ps.data_offset = header->data_offset; - ps.reserved = header->reserved; - ps.flags = header->flags; - ps.window_size = header->window_size; - ps.urgent_pointer = header->urgent_pointer; - int buffer_length = sizeof(ps) + payload_length; +u16 tcp_calculate_checksum(ipv4_t src_ip, u32 dst_ip, const u8 *payload, + u16 payload_length, const struct TCP_HEADER *header, + int total) { + if (total < header->data_offset + payload_length) { + return 0; + } + + int pseudo = sizeof(u32) * 2 + sizeof(u16) + sizeof(u8) * 2; + + int buffer_length = + pseudo + header->data_offset * sizeof(u32) + payload_length; u8 buffer[buffer_length]; - memcpy(buffer, &ps, sizeof(ps)); - memcpy(buffer + sizeof(ps), payload, payload_length); - header->checksum = tcp_checksum((u16 *)buffer, buffer_length); + u8 *ptr = buffer; + memcpy(ptr, &src_ip.d, sizeof(u32)); + ptr += sizeof(u32); + memcpy(ptr, &dst_ip, sizeof(u32)); + ptr += sizeof(u32); + *ptr = 0; + ptr += sizeof(u8); + *ptr = 6; + ptr += sizeof(u8); + *(u16 *)ptr = htons(header->data_offset * sizeof(u32) + payload_length); + ptr += sizeof(u16); + memcpy(ptr, header, header->data_offset * sizeof(u32)); + memset(ptr + 16, 0, sizeof(u16)); // set checksum to zero + ptr += header->data_offset * sizeof(u32); + memcpy(ptr, payload, payload_length); + + return tcp_checksum((u16 *)buffer, buffer_length); +} + +struct TcpPacket { + u32 time; + u32 seq_num; + u16 payload_length; + u8 *buffer; + u16 length; +}; + +static void tcp_send(struct TcpConnection *con, u8 *buffer, u16 length, + u32 seq_num, u32 payload_length) { + if (payload_length > 0) { + struct TcpPacket *packet = kmalloc(sizeof(struct TcpPacket)); + assert(packet); + packet->time = pit_num_ms(); + packet->seq_num = seq_num; + packet->buffer = buffer; + packet->length = length; + packet->payload_length = payload_length; + + assert(relist_add(&con->inflight, packet, NULL)); + } + send_ipv4_packet((ipv4_t){.d = con->outgoing_ip}, 6, buffer, length); +} + +void tcp_resend_packets(struct TcpConnection *con) { + if (con->inflight.num_entries > 0) { + for (u32 i = 0;; i++) { + struct TcpPacket *packet; + int end; + if (!relist_get(&con->inflight, i, (void *)&packet, &end)) { + if (end) { + break; + } + continue; + } + if (packet->time + 200 > pit_num_ms()) { + continue; + } + // resend the packet + relist_remove(&con->inflight, i); + tcp_send(con, packet->buffer, packet->length, packet->seq_num, + packet->payload_length); + kfree(packet); + } + } } void tcp_send_empty_payload(struct TcpConnection *con, u8 flags) { @@ -111,14 +170,15 @@ void tcp_send_empty_payload(struct TcpConnection *con, u8 flags) { u8 payload[0]; u16 payload_length = 0; - tcp_calculate_checksum(ip_address, con->outgoing_ip, (const u8 *)payload, - payload_length, &header); + header.checksum = tcp_calculate_checksum( + ip_address, con->outgoing_ip, (const u8 *)payload, payload_length, + &header, sizeof(struct TCP_HEADER) + payload_length); int send_len = sizeof(header) + payload_length; - u8 send_buffer[send_len]; + u8 *send_buffer = kmalloc(send_len); memcpy(send_buffer, &header, sizeof(header)); memcpy(send_buffer + sizeof(header), payload, payload_length); - send_ipv4_packet((ipv4_t){.d = con->outgoing_ip}, 6, send_buffer, send_len); + tcp_send(con, send_buffer, send_len, con->seq, 0); } void tcp_close_connection(struct TcpConnection *con) { @@ -134,12 +194,30 @@ void tcp_send_syn(struct TcpConnection *con) { con->seq++; } -void send_tcp_packet(struct TcpConnection *con, const u8 *payload, - u16 payload_length) { - if (payload_length > 1500 - 20 - sizeof(struct TCP_HEADER)) { - send_tcp_packet(con, payload, 1500 - 20 - sizeof(struct TCP_HEADER)); - payload_length -= 1500 - 20 - sizeof(struct TCP_HEADER); - payload += 1500 - 20 - sizeof(struct TCP_HEADER); +int tcp_can_send(struct TcpConnection *con, u16 payload_length) { + if (con->inflight.num_entries > 2) { + tcp_resend_packets(con); + return 0; + } + if (con->seq - con->seq_ack + payload_length > con->current_window_size) { + tcp_resend_packets(con); + return 0; + } + return 1; +} + +int send_tcp_packet(struct TcpConnection *con, const u8 *payload, + u16 payload_length) { + if (!tcp_can_send(con, payload_length)) { + return 0; + } + + if (payload_length > MSS) { + if (0 == send_tcp_packet(con, payload, MSS)) { + return 0; + } + payload_length -= MSS; + payload += MSS; return send_tcp_packet(con, payload, payload_length); } struct TCP_HEADER header = {0}; @@ -152,24 +230,37 @@ void send_tcp_packet(struct TcpConnection *con, const u8 *payload, header.flags = PSH | ACK; header.window_size = htons(WINDOW_SIZE); header.urgent_pointer = 0; - tcp_calculate_checksum(ip_address, con->outgoing_ip, (const u8 *)payload, - payload_length, &header); + header.checksum = tcp_calculate_checksum( + ip_address, con->outgoing_ip, (const u8 *)payload, payload_length, + &header, sizeof(struct TCP_HEADER) + payload_length); int send_len = sizeof(header) + payload_length; - u8 send_buffer[send_len]; + u8 *send_buffer = kmalloc(send_len); + assert(send_buffer); memcpy(send_buffer, &header, sizeof(header)); memcpy(send_buffer + sizeof(header), payload, payload_length); - send_ipv4_packet((ipv4_t){.d = con->outgoing_ip}, 6, send_buffer, send_len); + + tcp_send(con, send_buffer, send_len, con->seq, payload_length); con->seq += payload_length; + return 1; } -void handle_tcp(ipv4_t src_ip, const u8 *payload, u32 payload_length) { +void handle_tcp(ipv4_t src_ip, ipv4_t dst_ip, const u8 *payload, + u32 payload_length) { const struct TCP_HEADER *header = (const struct TCP_HEADER *)payload; - (void)header; + u16 tcp_payload_length = payload_length - header->data_offset * sizeof(u32); + const u8 *tcp_payload = payload + header->data_offset * sizeof(u32); + u16 checksum = + tcp_calculate_checksum(src_ip, dst_ip.d, tcp_payload, tcp_payload_length, + header, payload_length); + if (header->checksum != checksum) { + return; + } u16 n_src_port = header->src_port; u16 n_dst_port = header->dst_port; u32 n_seq_num = header->seq_num; u32 n_ack_num = header->ack_num; + u32 n_window_size = header->window_size; u8 flags = header->flags; @@ -177,14 +268,18 @@ void handle_tcp(ipv4_t src_ip, const u8 *payload, u32 payload_length) { u16 dst_port = htons(n_dst_port); u32 seq_num = htonl(n_seq_num); u32 ack_num = htonl(n_ack_num); + u16 window_size = htons(n_window_size); + (void)ack_num; if (SYN == flags) { struct TcpConnection *con = internal_tcp_incoming(src_ip.d, src_port, 0, dst_port); - if(!con) { + if (!con) { return; } + con->window_size = window_size; + con->current_window_size = MSS; con->ack = seq_num + 1; tcp_send_empty_payload(con, SYN | ACK); con->seq++; @@ -195,49 +290,75 @@ void handle_tcp(ipv4_t src_ip, const u8 *payload, u32 payload_length) { tcp_find_connection(src_ip, src_port, dst_port); if (!incoming_connection) { kprintf("unable to find open port for incoming connection\n"); + return; + } + incoming_connection->window_size = window_size; + incoming_connection->unhandled_packet = 1; + if (0 != (flags & RST)) { + klog("Requested port is closed", LOG_NOTE); + incoming_connection->dead = 1; + return; } - if (incoming_connection) { - incoming_connection->unhandled_packet = 1; - if (0 != (flags & RST)) { - klog("Requested port is closed", LOG_NOTE); - incoming_connection->dead = 1; + if (ACK == flags) { + if (0 == incoming_connection->handshake_state) { + // Then it is probably a response to the SYN|ACK we sent. + incoming_connection->handshake_state = 1; return; } - if (ACK == flags) { - if (0 == incoming_connection->handshake_state) { - // Then it is probably a response to the SYN|ACK we sent. - incoming_connection->handshake_state = 1; - return; + } + if ((SYN | ACK) == flags) { + assert(0 == incoming_connection->handshake_state); + incoming_connection->handshake_state = 1; + incoming_connection->ack = seq_num + 1; + tcp_send_ack(incoming_connection); + return; + } + if (ACK & flags) { + incoming_connection->seq_ack = ack_num; + if (incoming_connection->inflight.num_entries > 0) { + for (u32 i = 0;; i++) { + struct TcpPacket *packet; + int end; + if (!relist_get(&incoming_connection->inflight, i, (void *)&packet, + &end)) { + if (end) { + break; + } + continue; + } + if (packet->seq_num + packet->payload_length != ack_num) { + continue; + } + relist_remove(&incoming_connection->inflight, i); + kfree(packet->buffer); + kfree(packet); + break; } } - if ((SYN | ACK) == flags) { - assert(0 == incoming_connection->handshake_state); - incoming_connection->handshake_state = 1; - - incoming_connection->ack = seq_num + 1; - - tcp_send_ack(incoming_connection); - } - u16 tcp_payload_length = payload_length - header->data_offset * sizeof(u32); - if (tcp_payload_length > 0) { - const u8 *tcp_payload = payload + header->data_offset * sizeof(u32); - u32 len = ringbuffer_write(&incoming_connection->incoming_buffer, - tcp_payload, tcp_payload_length); - assert(len == tcp_payload_length); - incoming_connection->ack += len; - tcp_send_ack(incoming_connection); + tcp_resend_packets(incoming_connection); + if (0 == incoming_connection->inflight.num_entries) { + u32 rest = incoming_connection->window_size - + incoming_connection->current_window_size; + if (rest > 0) { + incoming_connection->current_window_size += min(rest, MSS); + } } - if (0 != (flags & FIN)) { - incoming_connection->ack++; + } + if (tcp_payload_length > 0) { + u32 len = ringbuffer_write(&incoming_connection->incoming_buffer, + tcp_payload, tcp_payload_length); + assert(len == tcp_payload_length); + incoming_connection->ack += len; + tcp_send_ack(incoming_connection); + } + if (0 != (flags & FIN)) { + incoming_connection->ack++; - tcp_send_empty_payload(incoming_connection, FIN | ACK); - incoming_connection->seq++; + tcp_send_empty_payload(incoming_connection, FIN | ACK); + incoming_connection->seq++; - incoming_connection->dead = 1; // FIXME: It should wait for a ACK - // of the FIN before the connection - // is closed. - } - } else { - return; + incoming_connection->dead = 1; // FIXME: It should wait for a ACK + // of the FIN before the connection + // is closed. } } |