From e6c8f7298b40757a410d9df6319824c4f0d70351 Mon Sep 17 00:00:00 2001 From: Anton Kling Date: Wed, 28 Feb 2024 21:47:49 +0100 Subject: TCP/UDP: Start rewrite of network sockets Having sockets be file descriptors seems like a bad idea so I trying to make UDP and TCP sockets be more independent and not be abstracted away as much. --- kernel/network/tcp.c | 166 +++++++++++++++++++++++++++++++++++++-------------- kernel/network/tcp.h | 7 ++- kernel/network/udp.c | 42 +------------ 3 files changed, 127 insertions(+), 88 deletions(-) (limited to 'kernel/network') diff --git a/kernel/network/tcp.c b/kernel/network/tcp.c index f45d1af..717c7db 100644 --- a/kernel/network/tcp.c +++ b/kernel/network/tcp.c @@ -3,7 +3,6 @@ #include #include #include -#include extern u8 ip_address[4]; #define CWR (1 << 7) @@ -49,6 +48,16 @@ struct __attribute__((__packed__)) PSEUDO_TCP_HEADER { u16 urgent_pointer; }; +void tcp_wait_reply(struct TcpConnection *con) { + for (;;) { + if (con->unhandled_packet) { + return; + } + // TODO: Make the scheduler halt the process + switch_task(); + } +} + u16 tcp_checksum(u16 *buffer, int size) { unsigned long cksum = 0; while (size > 1) { @@ -64,11 +73,11 @@ u16 tcp_checksum(u16 *buffer, int size) { return (u16)(~cksum); } -void tcp_calculate_checksum(u8 src_ip[4], u8 dst_ip[4], const u8 *payload, +void tcp_calculate_checksum(u8 src_ip[4], u32 dst_ip, const u8 *payload, u16 payload_length, struct TCP_HEADER *header) { struct PSEUDO_TCP_HEADER ps = {0}; memcpy(&ps.src_addr, src_ip, sizeof(u32)); - memcpy(&ps.dst_addr, dst_ip, sizeof(u32)); + memcpy(&ps.dst_addr, &dst_ip, sizeof(u32)); ps.protocol = 6; ps.tcp_length = htons(20 + payload_length); ps.src_port = header->src_port; @@ -88,80 +97,145 @@ void tcp_calculate_checksum(u8 src_ip[4], u8 dst_ip[4], const u8 *payload, header->checksum = tcp_checksum((u16 *)buffer, buffer_length); } -void tcp_close_connection(struct INCOMING_TCP_CONNECTION *inc) { +void tcp_send_empty_payload(struct TcpConnection *con, u8 flags) { struct TCP_HEADER header = {0}; - header.src_port = htons(inc->dst_port); - header.dst_port = inc->n_port; - header.seq_num = htonl(inc->seq_num); - header.ack_num = htonl(inc->ack_num); + header.src_port = htons(con->incoming_port); + header.dst_port = htons(con->outgoing_port); + header.seq_num = htonl(con->seq); + header.ack_num = htonl(con->ack); header.data_offset = 5; header.reserved = 0; - header.flags = FIN | ACK; + header.flags = flags; header.window_size = htons(WINDOW_SIZE); header.urgent_pointer = 0; - u32 dst_ip; - memcpy(&dst_ip, inc->ip, sizeof(dst_ip)); + u8 payload[0]; u16 payload_length = 0; - tcp_calculate_checksum(ip_address, inc->ip, (const u8 *)payload, + tcp_calculate_checksum(ip_address, con->outgoing_ip, (const u8 *)payload, payload_length, &header); int send_len = sizeof(header) + payload_length; u8 send_buffer[send_len]; memcpy(send_buffer, &header, sizeof(header)); memcpy(send_buffer + sizeof(header), payload, payload_length); - send_ipv4_packet(dst_ip, 6, send_buffer, send_len); + + send_ipv4_packet(con->outgoing_ip, 6, send_buffer, send_len); +} + +void tcp_send_ack(struct TcpConnection *con) { + tcp_send_empty_payload(con, ACK); } -void send_tcp_packet(struct INCOMING_TCP_CONNECTION *inc, u8 *payload, +void tcp_send_syn(struct TcpConnection *con) { + tcp_send_empty_payload(con, SYN); + con->seq++; +} + +// void send_tcp_packet(struct INCOMING_TCP_CONNECTION *inc, const u8 *payload, +// u16 payload_length) { +void send_tcp_packet(struct TcpConnection *con, const u8 *payload, u16 payload_length) { if (payload_length > 1500 - 20 - sizeof(struct TCP_HEADER)) { - send_tcp_packet(inc, payload, 1500 - 20 - sizeof(struct TCP_HEADER)); + send_tcp_packet(con, payload, 1500 - 20 - sizeof(struct TCP_HEADER)); payload_length -= 1500 - 20 - sizeof(struct TCP_HEADER); payload += 1500 - 20 - sizeof(struct TCP_HEADER); - return send_tcp_packet(inc, payload, payload_length); + return send_tcp_packet(con, payload, payload_length); } struct TCP_HEADER header = {0}; - header.src_port = htons(inc->dst_port); - header.dst_port = inc->n_port; - header.seq_num = htonl(inc->seq_num); - header.ack_num = htonl(inc->ack_num); + header.src_port = htons(con->incoming_port); + header.dst_port = htons(con->outgoing_port); + header.seq_num = htonl(con->seq); + header.ack_num = htonl(con->ack); header.data_offset = 5; header.reserved = 0; header.flags = PSH | ACK; header.window_size = htons(WINDOW_SIZE); header.urgent_pointer = 0; - u32 dst_ip; - memcpy(&dst_ip, inc->ip, sizeof(dst_ip)); - tcp_calculate_checksum(ip_address, inc->ip, (const u8 *)payload, + tcp_calculate_checksum(ip_address, con->outgoing_ip, (const u8 *)payload, payload_length, &header); int send_len = sizeof(header) + payload_length; u8 send_buffer[send_len]; memcpy(send_buffer, &header, sizeof(header)); memcpy(send_buffer + sizeof(header), payload, payload_length); - send_ipv4_packet(dst_ip, 6, send_buffer, send_len); + send_ipv4_packet(con->outgoing_ip, 6, send_buffer, send_len); - inc->seq_num += payload_length; + con->seq += payload_length; } -void send_empty_tcp_message(struct INCOMING_TCP_CONNECTION *inc, u8 flags, - u32 inc_seq_num, u16 n_dst_port, u16 n_src_port) { - struct TCP_HEADER header = {0}; - header.src_port = n_dst_port; - header.dst_port = n_src_port; - header.seq_num = 0; - header.ack_num = htonl(inc_seq_num + 1); - header.data_offset = 5; - header.reserved = 0; - header.flags = flags; - header.window_size = htons(WINDOW_SIZE); - header.urgent_pointer = 0; - char payload[0]; - tcp_calculate_checksum(ip_address, inc->ip, (const u8 *)payload, 0, &header); - u32 dst_ip; - memcpy(&dst_ip, inc->ip, sizeof(dst_ip)); - send_ipv4_packet(dst_ip, 6, (const u8 *)&header, sizeof(header)); -} +void handle_tcp(u8 src_ip[4], const u8 *payload, u32 payload_length) { + const struct TCP_HEADER *header = (const struct TCP_HEADER *)payload; + (void)header; + u16 n_src_port = *(u16 *)(payload); + u16 n_dst_port = *(u16 *)(payload + 2); + u32 n_seq_num = *(u32 *)(payload + 4); + u32 n_ack_num = *(u32 *)(payload + 8); + + // u8 flags = *(payload + 13); + u8 flags = header->flags; + u16 src_port = htons(n_src_port); + (void)src_port; + u16 dst_port = htons(n_dst_port); + u32 seq_num = htonl(n_seq_num); + u32 ack_num = htonl(n_ack_num); + (void)ack_num; + + if (SYN == flags) { + u32 t; + memcpy(&t, src_ip, sizeof(u8[4])); + struct TcpConnection *con = internal_tcp_incoming(t, src_port, 0, dst_port); + assert(con); + con->ack = seq_num + 1; + tcp_send_empty_payload(con, SYN | ACK); + return; + } + + struct TcpConnection *incoming_connection = tcp_find_connection(dst_port); + if (incoming_connection) { + incoming_connection->unhandled_packet = 1; + if (0 != (flags & RST)) { + klog("Requested port is closed", LOG_NOTE); + incoming_connection->dead = 1; + return; + } + if (ACK == flags) { + if (0 == incoming_connection->handshake_state) { + // Then it is probably a response to the SYN|ACK we sent. + incoming_connection->handshake_state = 1; + return; + } + } + if ((SYN | ACK) == flags) { + assert(0 == incoming_connection->handshake_state); + incoming_connection->handshake_state = 1; + + incoming_connection->ack = seq_num + 1; + + tcp_send_ack(incoming_connection); + } + if (0 != (flags & PSH)) { + u16 tcp_payload_length = + payload_length - header->data_offset * sizeof(u32); + int len = fifo_object_write( + (u8 *)(payload + header->data_offset * sizeof(u32)), 0, + tcp_payload_length, incoming_connection->data_file); + assert(len >= 0); + incoming_connection->ack += len; + tcp_send_ack(incoming_connection); + } + if (0 != (flags & FIN)) { + incoming_connection->ack++; + + tcp_send_empty_payload(incoming_connection, FIN | ACK); + + incoming_connection->dead = 1; // FIXME: It should wait for a ACK + // of the FIN before the connection + // is closed. + } + } else { + assert(NULL); + } +} +/* void handle_tcp(u8 src_ip[4], const u8 *payload, u32 payload_length) { const struct TCP_HEADER *inc_header = (const struct TCP_HEADER *)payload; u16 n_src_port = *(u16 *)(payload); @@ -233,11 +307,11 @@ void handle_tcp(u8 src_ip[4], const u8 *payload, u32 payload_length) { header.flags = ACK; header.window_size = htons(WINDOW_SIZE); header.urgent_pointer = 0; - char payload[0]; - tcp_calculate_checksum(ip_address, src_ip, (const u8 *)payload, 0, &header); u32 dst_ip; memcpy(&dst_ip, src_ip, sizeof(dst_ip)); + char payload[0]; + tcp_calculate_checksum(ip_address, dst_ip, (const u8 *)payload, 0, &header); send_ipv4_packet(dst_ip, 6, (const u8 *)&header, sizeof(header)); return; } -} +}*/ diff --git a/kernel/network/tcp.h b/kernel/network/tcp.h index 2a836a4..0f9e818 100644 --- a/kernel/network/tcp.h +++ b/kernel/network/tcp.h @@ -1,4 +1,7 @@ +#include +void tcp_send_syn(struct TcpConnection *con); +void tcp_wait_reply(struct TcpConnection *con); void handle_tcp(u8 src_ip[4], const u8 *payload, u32 payload_length); -void send_tcp_packet(struct INCOMING_TCP_CONNECTION *inc, u8 *payload, +void send_tcp_packet(struct TcpConnection *con, const u8 *payload, u16 payload_length); -void tcp_close_connection(struct INCOMING_TCP_CONNECTION *s); + void tcp_close_connection(struct INCOMING_TCP_CONNECTION * s); diff --git a/kernel/network/udp.c b/kernel/network/udp.c index 5aae050..4f3848a 100644 --- a/kernel/network/udp.c +++ b/kernel/network/udp.c @@ -20,44 +20,6 @@ void send_udp_packet(struct sockaddr_in *src, const struct sockaddr_in *dst, } void handle_udp(u8 src_ip[4], const u8 *payload, u32 packet_length) { - assert(packet_length >= 8); - // n_.* means network format(big endian) - // h_.* means host format((probably) little endian) - u16 n_source_port = *(u16 *)payload; - u16 h_source_port = ntohs(n_source_port); - (void)h_source_port; - u16 h_dst_port = ntohs(*(u16 *)(payload + 2)); - u16 h_length = ntohs(*(u16 *)(payload + 4)); - assert(h_length == packet_length); - u16 data_length = h_length - 8; - const u8 *data = payload + 8; - - // Find the open port - OPEN_INET_SOCKET *in_s = find_open_udp_port(htons(h_dst_port)); - assert(in_s); - SOCKET *s = in_s->s; - vfs_fd_t *fifo_file = s->ptr_socket_fd; - - // Write the sockaddr struct such that it can later be - // given to userland if asked. - struct sockaddr_in /*{ - sa_family_t sin_family; - union { - u32 s_addr; - } sin_addr; - u16 sin_port; - }*/ in; - in.sin_family = AF_INET; - memcpy(&in.sin_addr.s_addr, src_ip, sizeof(u32)); - in.sin_port = n_source_port; - socklen_t sock_length = sizeof(struct sockaddr_in); - - raw_vfs_pwrite(fifo_file, &sock_length, sizeof(sock_length), 0); - raw_vfs_pwrite(fifo_file, &in, sizeof(in), 0); - - // Write the UDP payload length(not including header) - raw_vfs_pwrite(fifo_file, &data_length, sizeof(u16), 0); - - // Write the UDP payload - raw_vfs_pwrite(fifo_file, (char *)data, data_length, 0); + // TODO: Reimplement + assert(NULL); } -- cgit v1.2.3