From 01b88a7bf9fb4c78bd632bfccb06f3d320a21fd5 Mon Sep 17 00:00:00 2001 From: Anton Kling Date: Sat, 22 Jun 2024 14:34:21 +0200 Subject: Kernel stuff --- kernel/network/ethernet.c | 4 +- kernel/network/tcp.c | 337 ++++++++++++++++++++-------------------------- kernel/network/tcp.h | 28 +++- kernel/network/udp.c | 3 +- 4 files changed, 173 insertions(+), 199 deletions(-) (limited to 'kernel/network') diff --git a/kernel/network/ethernet.c b/kernel/network/ethernet.c index 1bda07f..a4f2f85 100644 --- a/kernel/network/ethernet.c +++ b/kernel/network/ethernet.c @@ -1,12 +1,13 @@ #include #include +#include #include -#include #include #include #include #include #include +#include struct ETHERNET_HEADER { u8 mac_dst[6]; @@ -69,7 +70,6 @@ void handle_ethernet(const u8 *packet, u64 packet_length) { void send_ethernet_packet(u8 mac_dst[6], u16 type, u8 *payload, u64 payload_length) { - assert(payload_length <= 1500); // FIXME: Janky allocation, do this better u64 buffer_size = sizeof(struct ETHERNET_HEADER) + payload_length + sizeof(u32); diff --git a/kernel/network/tcp.c b/kernel/network/tcp.c index d171722..2dc7180 100644 --- a/kernel/network/tcp.c +++ b/kernel/network/tcp.c @@ -6,23 +6,12 @@ #include #include #include +#include #include #include -#define CWR (1 << 7) -#define ECE (1 << 6) -#define URG (1 << 5) -#define ACK (1 << 4) -#define PSH (1 << 3) -#define RST (1 << 2) -#define SYN (1 << 1) -#define FIN (1 << 0) - #define MSS 536 -// FIXME: This should be dynamic -#define WINDOW_SIZE 4096 - struct __attribute__((__packed__)) TCP_HEADER { u16 src_port; u16 dst_port; @@ -36,36 +25,9 @@ struct __attribute__((__packed__)) TCP_HEADER { u16 urgent_pointer; }; -struct __attribute__((__packed__)) PSEUDO_TCP_HEADER { - u32 src_addr; - u32 dst_addr; - u8 zero; - u8 protocol; - u16 tcp_length; - u16 src_port; - u16 dst_port; - u32 seq_num; - u32 ack_num; - u8 reserved : 4; - u8 data_offset : 4; - u8 flags; - u16 window_size; - u16 checksum_zero; // ????? - u16 urgent_pointer; -}; - -void tcp_wait_reply(struct TcpConnection *con) { - for (;;) { - if (con->unhandled_packet) { - return; - } - switch_task(); - } -} - -u16 tcp_checksum(u16 *buffer, int size) { - unsigned long cksum = 0; - while (size > 1) { +u16 tcp_checksum(u16 *buffer, u32 size) { + u32 cksum = 0; + for (; size > 1;) { cksum += *buffer++; size -= sizeof(u16); } @@ -90,6 +52,7 @@ u16 tcp_calculate_checksum(ipv4_t src_ip, u32 dst_ip, const u8 *payload, int buffer_length = pseudo + header->data_offset * sizeof(u32) + payload_length; u8 buffer[buffer_length]; + u8 *ptr = buffer; memcpy(ptr, &src_ip.d, sizeof(u32)); ptr += sizeof(u32); @@ -109,70 +72,25 @@ u16 tcp_calculate_checksum(ipv4_t src_ip, u32 dst_ip, const u8 *payload, return tcp_checksum((u16 *)buffer, buffer_length); } -struct TcpPacket { - u32 time; - u32 seq_num; - u16 payload_length; - u8 *buffer; - u16 length; -}; - static void tcp_send(struct TcpConnection *con, u8 *buffer, u16 length, u32 seq_num, u32 payload_length) { - if (payload_length > 0) { - struct TcpPacket *packet = kmalloc(sizeof(struct TcpPacket)); - assert(packet); - packet->time = pit_num_ms(); - packet->seq_num = seq_num; - packet->buffer = buffer; - packet->length = length; - packet->payload_length = payload_length; - - assert(relist_add(&con->inflight, packet, NULL)); - } send_ipv4_packet((ipv4_t){.d = con->outgoing_ip}, 6, buffer, length); } -void tcp_resend_packets(struct TcpConnection *con) { - if (0 == con->inflight.num_entries) { - return; - } - int did_resend = 0; - for (u32 i = 0;; i++) { - struct TcpPacket *packet; - int end; - if (!relist_get(&con->inflight, i, (void *)&packet, &end)) { - if (end) { - break; - } - continue; - } - if (packet->time + 200 > pit_num_ms()) { - continue; - } - // resend the packet - did_resend = 1; - relist_remove(&con->inflight, i); - tcp_send(con, packet->buffer, packet->length, packet->seq_num, - packet->payload_length); - kfree(packet); - } - if (did_resend) { - con->max_inflight = 1; - con->window_size = MSS; - } -} - void tcp_send_empty_payload(struct TcpConnection *con, u8 flags) { struct TCP_HEADER header = {0}; header.src_port = htons(con->incoming_port); header.dst_port = htons(con->outgoing_port); - header.seq_num = htonl(con->seq); - header.ack_num = htonl(con->ack); + header.seq_num = htonl(con->snd_nxt); + if (flags & ACK) { + header.ack_num = htonl(con->sent_ack); + } else { + header.ack_num = 0; + } header.data_offset = 5; header.reserved = 0; header.flags = flags; - header.window_size = htons(WINDOW_SIZE); + header.window_size = htons(con->rcv_wnd); header.urgent_pointer = 0; u8 payload[0]; @@ -185,28 +103,41 @@ void tcp_send_empty_payload(struct TcpConnection *con, u8 flags) { memcpy(send_buffer, &header, sizeof(header)); memcpy(send_buffer + sizeof(header), payload, payload_length); - tcp_send(con, send_buffer, send_len, con->seq, 0); -} - -void tcp_close_connection(struct TcpConnection *con) { - tcp_send_empty_payload(con, FIN | ACK); -} + tcp_send(con, send_buffer, send_len, con->snd_nxt, 0); -void tcp_send_ack(struct TcpConnection *con) { - tcp_send_empty_payload(con, ACK); + con->snd_nxt += (flags & SYN) ? 1 : 0; + con->snd_nxt += (flags & FIN) ? 1 : 0; + con->snd_max = max(con->snd_nxt, con->snd_max); } -void tcp_send_syn(struct TcpConnection *con) { - tcp_send_empty_payload(con, SYN); - con->seq++; +void tcp_close_connection(struct TcpConnection *con) { + if (TCP_STATE_CLOSE_WAIT == con->state) { + tcp_send_empty_payload(con, FIN); + con->state = TCP_STATE_LAST_ACK; + return; + } + if (TCP_STATE_ESTABLISHED == con->state) { + tcp_send_empty_payload(con, FIN); + con->state = TCP_STATE_FIN_WAIT1; + return; + } + if (TCP_STATE_SYN_RECIEVED == con->state) { + tcp_send_empty_payload(con, FIN); + con->state = TCP_STATE_FIN_WAIT1; + return; + } + if (TCP_STATE_SYN_SENT == con->state) { + con->state = TCP_STATE_CLOSED; + // TODO: Cleanup + return; + } } u16 tcp_can_send(struct TcpConnection *con) { - if (con->inflight.num_entries > con->max_inflight) { - tcp_resend_packets(con); + if (TCP_STATE_CLOSED == con->state) { return 0; } - return con->current_window_size; + return con->snd_una + con->snd_wnd - con->snd_max; } int send_tcp_packet(struct TcpConnection *con, const u8 *payload, @@ -215,23 +146,24 @@ int send_tcp_packet(struct TcpConnection *con, const u8 *payload, return 0; } - if (payload_length > MSS) { - if (0 == send_tcp_packet(con, payload, MSS)) { + u16 len = min(1000, payload_length); + if (payload_length > len) { + if (0 == send_tcp_packet(con, payload, len)) { return 0; } - payload_length -= MSS; - payload += MSS; + payload_length -= len; + payload += len; return send_tcp_packet(con, payload, payload_length); } struct TCP_HEADER header = {0}; header.src_port = htons(con->incoming_port); header.dst_port = htons(con->outgoing_port); - header.seq_num = htonl(con->seq); - header.ack_num = htonl(con->ack); + header.seq_num = htonl(con->snd_nxt); + header.ack_num = htonl(con->sent_ack); header.data_offset = 5; header.reserved = 0; header.flags = PSH | ACK; - header.window_size = htons(WINDOW_SIZE); + header.window_size = htons(con->rcv_wnd); header.urgent_pointer = 0; header.checksum = tcp_calculate_checksum( ip_address, con->outgoing_ip, (const u8 *)payload, payload_length, @@ -242,9 +174,10 @@ int send_tcp_packet(struct TcpConnection *con, const u8 *payload, memcpy(send_buffer, &header, sizeof(header)); memcpy(send_buffer + sizeof(header), payload, payload_length); - tcp_send(con, send_buffer, send_len, con->seq, payload_length); + tcp_send(con, send_buffer, send_len, con->snd_nxt, payload_length); - con->seq += payload_length; + con->snd_nxt += payload_length; + con->snd_max = max(con->snd_nxt, con->snd_max); return 1; } @@ -272,97 +205,115 @@ void handle_tcp(ipv4_t src_ip, ipv4_t dst_ip, const u8 *payload, u32 seq_num = htonl(n_seq_num); u32 ack_num = htonl(n_ack_num); u16 window_size = htons(n_window_size); + struct TcpConnection *con = + tcp_find_connection(src_ip, src_port, dst_ip, dst_port); - (void)ack_num; + if (con->state == TCP_STATE_LISTEN || con->state == TCP_STATE_SYN_SENT) { + con->rcv_nxt = seq_num; + } - if (SYN == flags) { - struct TcpConnection *con = - internal_tcp_incoming(src_ip.d, src_port, 0, dst_port); - if (!con) { - return; - } - con->window_size = window_size; - con->current_window_size = MSS; - con->ack = seq_num + 1; - tcp_send_empty_payload(con, SYN | ACK); - con->seq++; + u32 seq_num_end = seq_num + tcp_payload_length - 1; + if (!((con->rcv_nxt <= seq_num) && seq_num < con->rcv_nxt + con->rcv_wnd) && + !((con->rcv_nxt <= seq_num_end) && + seq_num_end < con->rcv_nxt + con->rcv_wnd)) { + kprintf("seq_num: %d\n", seq_num); + kprintf("seq_num_end: %d\n", seq_num_end); + kprintf("con->rcv_nxt: %d\n", con->rcv_nxt); + kprintf("con->rcv_wnd: %d\n", con->rcv_wnd); + // Invalid segment + kprintf("invalid segment\n"); return; } - struct TcpConnection *incoming_connection = - tcp_find_connection(src_ip, src_port, dst_port); - if (!incoming_connection) { - kprintf("unable to find open port for incoming connection\n"); + if (ack_num > con->snd_max) { + // TODO: Odd ACK number, what should be done? + kprintf("odd ACK\n"); return; } - incoming_connection->window_size = window_size; - incoming_connection->unhandled_packet = 1; - if (0 != (flags & RST)) { - klog("Requested port is closed", LOG_NOTE); - incoming_connection->dead = 1; + + if (ack_num < con->snd_una) { + // TODO duplicate ACK + kprintf("duplicate ACK\n"); return; } - if (ACK == flags) { - if (0 == incoming_connection->handshake_state) { - // Then it is probably a response to the SYN|ACK we sent. - incoming_connection->handshake_state = 1; - return; + + con->snd_wnd = window_size; + con->snd_una = ack_num; + + con->sent_ack = max(con->sent_ack, seq_num + tcp_payload_length); + if (FIN & flags) { + con->sent_ack++; + } + + switch (con->state) { + case TCP_STATE_LISTEN: { + if (SYN & flags) { + tcp_send_empty_payload(con, SYN | ACK); + con->state = TCP_STATE_SYN_RECIEVED; + con->rcv_nxt++; + break; } + break; } - if ((SYN | ACK) == flags) { - assert(0 == incoming_connection->handshake_state); - incoming_connection->handshake_state = 1; - incoming_connection->ack = seq_num + 1; - tcp_send_ack(incoming_connection); - return; + case TCP_STATE_SYN_RECIEVED: { + if (ACK & flags) { + con->state = TCP_STATE_ESTABLISHED; + break; + } + break; + } + case TCP_STATE_SYN_SENT: { + if ((ACK & flags) && (SYN & flags)) { + tcp_send_empty_payload(con, ACK); + con->state = TCP_STATE_ESTABLISHED; + con->rcv_nxt++; + break; + } + break; } - if (ACK & flags) { - incoming_connection->seq_ack = ack_num; - if (incoming_connection->inflight.num_entries > 0) { - for (u32 i = 0;; i++) { - struct TcpPacket *packet; - int end; - if (!relist_get(&incoming_connection->inflight, i, (void *)&packet, - &end)) { - if (end) { - break; - } - continue; - } - if (packet->seq_num + packet->payload_length > ack_num) { - continue; - } - relist_remove(&incoming_connection->inflight, i); - kfree(packet->buffer); - kfree(packet); - break; - } + case TCP_STATE_ESTABLISHED: { + if (FIN & flags) { + tcp_send_empty_payload(con, ACK); + con->state = TCP_STATE_CLOSE_WAIT; + break; } - tcp_resend_packets(incoming_connection); - if (0 == incoming_connection->inflight.num_entries) { - incoming_connection->max_inflight++; - u32 rest = incoming_connection->window_size - - incoming_connection->current_window_size; - if (rest > 0) { - incoming_connection->current_window_size += min(rest, MSS); - } + if (tcp_payload_length > 0) { + int rc = ringbuffer_write(&con->incoming_buffer, tcp_payload, + tcp_payload_length); + con->rcv_nxt += rc; + con->rcv_wnd = ringbuffer_unused(&con->incoming_buffer); + tcp_send_empty_payload(con, ACK); } + break; } - if (tcp_payload_length > 0) { - u32 len = ringbuffer_write(&incoming_connection->incoming_buffer, - tcp_payload, tcp_payload_length); - assert(len == tcp_payload_length); - incoming_connection->ack += len; - tcp_send_ack(incoming_connection); + case TCP_STATE_FIN_WAIT1: { + if ((ACK & flags) && (FIN & flags)) { + tcp_send_empty_payload(con, ACK); + con->state = TCP_STATE_TIME_WAIT; + break; + } + if (ACK & flags) { + con->state = TCP_STATE_FIN_WAIT2; + break; + } + if (FIN & flags) { + tcp_send_empty_payload(con, ACK); + con->state = TCP_STATE_CLOSING; + break; + } + break; } - if (0 != (flags & FIN)) { - incoming_connection->ack++; - - tcp_send_empty_payload(incoming_connection, FIN | ACK); - incoming_connection->seq++; - - incoming_connection->dead = 1; // FIXME: It should wait for a ACK - // of the FIN before the connection - // is closed. + case TCP_STATE_LAST_ACK: { + if (ACK & flags) { + // TODO: cleanup + con->state = TCP_STATE_CLOSED; + break; + } + break; + } + default: { + klog(LOG_WARN, "TCP state not handled %d", con->state); + break; } + }; } diff --git a/kernel/network/tcp.h b/kernel/network/tcp.h index ec253bd..90975e4 100644 --- a/kernel/network/tcp.h +++ b/kernel/network/tcp.h @@ -1,9 +1,31 @@ #include #include -void tcp_send_syn(struct TcpConnection *con); -void tcp_wait_reply(struct TcpConnection *con); -void handle_tcp(ipv4_t src_ip, ipv4_t dst_ip, const u8 *payload, u32 payload_length); + +#define CWR (1 << 7) +#define ECE (1 << 6) +#define URG (1 << 5) +#define ACK (1 << 4) +#define PSH (1 << 3) +#define RST (1 << 2) +#define SYN (1 << 1) +#define FIN (1 << 0) + +#define TCP_STATE_CLOSED 0 +#define TCP_STATE_LISTEN 1 +#define TCP_STATE_SYN_SENT 2 +#define TCP_STATE_SYN_RECIEVED 3 +#define TCP_STATE_ESTABLISHED 4 +#define TCP_STATE_CLOSE_WAIT 5 +#define TCP_STATE_FIN_WAIT1 6 +#define TCP_STATE_CLOSING 7 +#define TCP_STATE_LAST_ACK 8 +#define TCP_STATE_FIN_WAIT2 9 +#define TCP_STATE_TIME_WAIT 10 + +void handle_tcp(ipv4_t src_ip, ipv4_t dst_ip, const u8 *payload, + u32 payload_length); int send_tcp_packet(struct TcpConnection *con, const u8 *payload, u16 payload_length); void tcp_close_connection(struct TcpConnection *con); u16 tcp_can_send(struct TcpConnection *con); +void tcp_send_empty_payload(struct TcpConnection *con, u8 flags); diff --git a/kernel/network/udp.c b/kernel/network/udp.c index 0c8d6e9..c0a5237 100644 --- a/kernel/network/udp.c +++ b/kernel/network/udp.c @@ -24,7 +24,8 @@ void send_udp_packet(struct sockaddr_in *src, const struct sockaddr_in *dst, kfree(packet); } -void handle_udp(ipv4_t src_ip, ipv4_t dst_ip, const u8 *payload, u32 packet_length) { +void handle_udp(ipv4_t src_ip, ipv4_t dst_ip, const u8 *payload, + u32 packet_length) { (void)dst_ip; if (packet_length < sizeof(u16[4])) { return; -- cgit v1.2.3