summaryrefslogtreecommitdiff
path: root/kernel/network
diff options
context:
space:
mode:
authorAnton Kling <anton@kling.gg>2024-02-28 21:47:49 +0100
committerAnton Kling <anton@kling.gg>2024-02-28 21:47:49 +0100
commite6c8f7298b40757a410d9df6319824c4f0d70351 (patch)
treeb90ee0eba9a45c7551d9f23b6e66620ff0ea5b66 /kernel/network
parent4536dc81b4be9a62328826455664cd6d696df8fb (diff)
TCP/UDP: Start rewrite of network sockets
Having sockets be file descriptors seems like a bad idea so I trying to make UDP and TCP sockets be more independent and not be abstracted away as much.
Diffstat (limited to 'kernel/network')
-rw-r--r--kernel/network/tcp.c166
-rw-r--r--kernel/network/tcp.h7
-rw-r--r--kernel/network/udp.c42
3 files changed, 127 insertions, 88 deletions
diff --git a/kernel/network/tcp.c b/kernel/network/tcp.c
index f45d1af..717c7db 100644
--- a/kernel/network/tcp.c
+++ b/kernel/network/tcp.c
@@ -3,7 +3,6 @@
#include <network/bytes.h>
#include <network/ipv4.h>
#include <network/udp.h>
-#include <socket.h>
extern u8 ip_address[4];
#define CWR (1 << 7)
@@ -49,6 +48,16 @@ struct __attribute__((__packed__)) PSEUDO_TCP_HEADER {
u16 urgent_pointer;
};
+void tcp_wait_reply(struct TcpConnection *con) {
+ for (;;) {
+ if (con->unhandled_packet) {
+ return;
+ }
+ // TODO: Make the scheduler halt the process
+ switch_task();
+ }
+}
+
u16 tcp_checksum(u16 *buffer, int size) {
unsigned long cksum = 0;
while (size > 1) {
@@ -64,11 +73,11 @@ u16 tcp_checksum(u16 *buffer, int size) {
return (u16)(~cksum);
}
-void tcp_calculate_checksum(u8 src_ip[4], u8 dst_ip[4], const u8 *payload,
+void tcp_calculate_checksum(u8 src_ip[4], u32 dst_ip, const u8 *payload,
u16 payload_length, struct TCP_HEADER *header) {
struct PSEUDO_TCP_HEADER ps = {0};
memcpy(&ps.src_addr, src_ip, sizeof(u32));
- memcpy(&ps.dst_addr, dst_ip, sizeof(u32));
+ memcpy(&ps.dst_addr, &dst_ip, sizeof(u32));
ps.protocol = 6;
ps.tcp_length = htons(20 + payload_length);
ps.src_port = header->src_port;
@@ -88,80 +97,145 @@ void tcp_calculate_checksum(u8 src_ip[4], u8 dst_ip[4], const u8 *payload,
header->checksum = tcp_checksum((u16 *)buffer, buffer_length);
}
-void tcp_close_connection(struct INCOMING_TCP_CONNECTION *inc) {
+void tcp_send_empty_payload(struct TcpConnection *con, u8 flags) {
struct TCP_HEADER header = {0};
- header.src_port = htons(inc->dst_port);
- header.dst_port = inc->n_port;
- header.seq_num = htonl(inc->seq_num);
- header.ack_num = htonl(inc->ack_num);
+ header.src_port = htons(con->incoming_port);
+ header.dst_port = htons(con->outgoing_port);
+ header.seq_num = htonl(con->seq);
+ header.ack_num = htonl(con->ack);
header.data_offset = 5;
header.reserved = 0;
- header.flags = FIN | ACK;
+ header.flags = flags;
header.window_size = htons(WINDOW_SIZE);
header.urgent_pointer = 0;
- u32 dst_ip;
- memcpy(&dst_ip, inc->ip, sizeof(dst_ip));
+
u8 payload[0];
u16 payload_length = 0;
- tcp_calculate_checksum(ip_address, inc->ip, (const u8 *)payload,
+ tcp_calculate_checksum(ip_address, con->outgoing_ip, (const u8 *)payload,
payload_length, &header);
int send_len = sizeof(header) + payload_length;
u8 send_buffer[send_len];
memcpy(send_buffer, &header, sizeof(header));
memcpy(send_buffer + sizeof(header), payload, payload_length);
- send_ipv4_packet(dst_ip, 6, send_buffer, send_len);
+
+ send_ipv4_packet(con->outgoing_ip, 6, send_buffer, send_len);
+}
+
+void tcp_send_ack(struct TcpConnection *con) {
+ tcp_send_empty_payload(con, ACK);
}
-void send_tcp_packet(struct INCOMING_TCP_CONNECTION *inc, u8 *payload,
+void tcp_send_syn(struct TcpConnection *con) {
+ tcp_send_empty_payload(con, SYN);
+ con->seq++;
+}
+
+// void send_tcp_packet(struct INCOMING_TCP_CONNECTION *inc, const u8 *payload,
+// u16 payload_length) {
+void send_tcp_packet(struct TcpConnection *con, const u8 *payload,
u16 payload_length) {
if (payload_length > 1500 - 20 - sizeof(struct TCP_HEADER)) {
- send_tcp_packet(inc, payload, 1500 - 20 - sizeof(struct TCP_HEADER));
+ send_tcp_packet(con, payload, 1500 - 20 - sizeof(struct TCP_HEADER));
payload_length -= 1500 - 20 - sizeof(struct TCP_HEADER);
payload += 1500 - 20 - sizeof(struct TCP_HEADER);
- return send_tcp_packet(inc, payload, payload_length);
+ return send_tcp_packet(con, payload, payload_length);
}
struct TCP_HEADER header = {0};
- header.src_port = htons(inc->dst_port);
- header.dst_port = inc->n_port;
- header.seq_num = htonl(inc->seq_num);
- header.ack_num = htonl(inc->ack_num);
+ header.src_port = htons(con->incoming_port);
+ header.dst_port = htons(con->outgoing_port);
+ header.seq_num = htonl(con->seq);
+ header.ack_num = htonl(con->ack);
header.data_offset = 5;
header.reserved = 0;
header.flags = PSH | ACK;
header.window_size = htons(WINDOW_SIZE);
header.urgent_pointer = 0;
- u32 dst_ip;
- memcpy(&dst_ip, inc->ip, sizeof(dst_ip));
- tcp_calculate_checksum(ip_address, inc->ip, (const u8 *)payload,
+ tcp_calculate_checksum(ip_address, con->outgoing_ip, (const u8 *)payload,
payload_length, &header);
int send_len = sizeof(header) + payload_length;
u8 send_buffer[send_len];
memcpy(send_buffer, &header, sizeof(header));
memcpy(send_buffer + sizeof(header), payload, payload_length);
- send_ipv4_packet(dst_ip, 6, send_buffer, send_len);
+ send_ipv4_packet(con->outgoing_ip, 6, send_buffer, send_len);
- inc->seq_num += payload_length;
+ con->seq += payload_length;
}
-void send_empty_tcp_message(struct INCOMING_TCP_CONNECTION *inc, u8 flags,
- u32 inc_seq_num, u16 n_dst_port, u16 n_src_port) {
- struct TCP_HEADER header = {0};
- header.src_port = n_dst_port;
- header.dst_port = n_src_port;
- header.seq_num = 0;
- header.ack_num = htonl(inc_seq_num + 1);
- header.data_offset = 5;
- header.reserved = 0;
- header.flags = flags;
- header.window_size = htons(WINDOW_SIZE);
- header.urgent_pointer = 0;
- char payload[0];
- tcp_calculate_checksum(ip_address, inc->ip, (const u8 *)payload, 0, &header);
- u32 dst_ip;
- memcpy(&dst_ip, inc->ip, sizeof(dst_ip));
- send_ipv4_packet(dst_ip, 6, (const u8 *)&header, sizeof(header));
-}
+void handle_tcp(u8 src_ip[4], const u8 *payload, u32 payload_length) {
+ const struct TCP_HEADER *header = (const struct TCP_HEADER *)payload;
+ (void)header;
+ u16 n_src_port = *(u16 *)(payload);
+ u16 n_dst_port = *(u16 *)(payload + 2);
+ u32 n_seq_num = *(u32 *)(payload + 4);
+ u32 n_ack_num = *(u32 *)(payload + 8);
+
+ // u8 flags = *(payload + 13);
+ u8 flags = header->flags;
+ u16 src_port = htons(n_src_port);
+ (void)src_port;
+ u16 dst_port = htons(n_dst_port);
+ u32 seq_num = htonl(n_seq_num);
+ u32 ack_num = htonl(n_ack_num);
+ (void)ack_num;
+
+ if (SYN == flags) {
+ u32 t;
+ memcpy(&t, src_ip, sizeof(u8[4]));
+ struct TcpConnection *con = internal_tcp_incoming(t, src_port, 0, dst_port);
+ assert(con);
+ con->ack = seq_num + 1;
+ tcp_send_empty_payload(con, SYN | ACK);
+ return;
+ }
+
+ struct TcpConnection *incoming_connection = tcp_find_connection(dst_port);
+ if (incoming_connection) {
+ incoming_connection->unhandled_packet = 1;
+ if (0 != (flags & RST)) {
+ klog("Requested port is closed", LOG_NOTE);
+ incoming_connection->dead = 1;
+ return;
+ }
+ if (ACK == flags) {
+ if (0 == incoming_connection->handshake_state) {
+ // Then it is probably a response to the SYN|ACK we sent.
+ incoming_connection->handshake_state = 1;
+ return;
+ }
+ }
+ if ((SYN | ACK) == flags) {
+ assert(0 == incoming_connection->handshake_state);
+ incoming_connection->handshake_state = 1;
+
+ incoming_connection->ack = seq_num + 1;
+
+ tcp_send_ack(incoming_connection);
+ }
+ if (0 != (flags & PSH)) {
+ u16 tcp_payload_length =
+ payload_length - header->data_offset * sizeof(u32);
+ int len = fifo_object_write(
+ (u8 *)(payload + header->data_offset * sizeof(u32)), 0,
+ tcp_payload_length, incoming_connection->data_file);
+ assert(len >= 0);
+ incoming_connection->ack += len;
+ tcp_send_ack(incoming_connection);
+ }
+ if (0 != (flags & FIN)) {
+ incoming_connection->ack++;
+
+ tcp_send_empty_payload(incoming_connection, FIN | ACK);
+
+ incoming_connection->dead = 1; // FIXME: It should wait for a ACK
+ // of the FIN before the connection
+ // is closed.
+ }
+ } else {
+ assert(NULL);
+ }
+}
+/*
void handle_tcp(u8 src_ip[4], const u8 *payload, u32 payload_length) {
const struct TCP_HEADER *inc_header = (const struct TCP_HEADER *)payload;
u16 n_src_port = *(u16 *)(payload);
@@ -233,11 +307,11 @@ void handle_tcp(u8 src_ip[4], const u8 *payload, u32 payload_length) {
header.flags = ACK;
header.window_size = htons(WINDOW_SIZE);
header.urgent_pointer = 0;
- char payload[0];
- tcp_calculate_checksum(ip_address, src_ip, (const u8 *)payload, 0, &header);
u32 dst_ip;
memcpy(&dst_ip, src_ip, sizeof(dst_ip));
+ char payload[0];
+ tcp_calculate_checksum(ip_address, dst_ip, (const u8 *)payload, 0, &header);
send_ipv4_packet(dst_ip, 6, (const u8 *)&header, sizeof(header));
return;
}
-}
+}*/
diff --git a/kernel/network/tcp.h b/kernel/network/tcp.h
index 2a836a4..0f9e818 100644
--- a/kernel/network/tcp.h
+++ b/kernel/network/tcp.h
@@ -1,4 +1,7 @@
+#include <socket.h>
+void tcp_send_syn(struct TcpConnection *con);
+void tcp_wait_reply(struct TcpConnection *con);
void handle_tcp(u8 src_ip[4], const u8 *payload, u32 payload_length);
-void send_tcp_packet(struct INCOMING_TCP_CONNECTION *inc, u8 *payload,
+void send_tcp_packet(struct TcpConnection *con, const u8 *payload,
u16 payload_length);
-void tcp_close_connection(struct INCOMING_TCP_CONNECTION *s);
+ void tcp_close_connection(struct INCOMING_TCP_CONNECTION * s);
diff --git a/kernel/network/udp.c b/kernel/network/udp.c
index 5aae050..4f3848a 100644
--- a/kernel/network/udp.c
+++ b/kernel/network/udp.c
@@ -20,44 +20,6 @@ void send_udp_packet(struct sockaddr_in *src, const struct sockaddr_in *dst,
}
void handle_udp(u8 src_ip[4], const u8 *payload, u32 packet_length) {
- assert(packet_length >= 8);
- // n_.* means network format(big endian)
- // h_.* means host format((probably) little endian)
- u16 n_source_port = *(u16 *)payload;
- u16 h_source_port = ntohs(n_source_port);
- (void)h_source_port;
- u16 h_dst_port = ntohs(*(u16 *)(payload + 2));
- u16 h_length = ntohs(*(u16 *)(payload + 4));
- assert(h_length == packet_length);
- u16 data_length = h_length - 8;
- const u8 *data = payload + 8;
-
- // Find the open port
- OPEN_INET_SOCKET *in_s = find_open_udp_port(htons(h_dst_port));
- assert(in_s);
- SOCKET *s = in_s->s;
- vfs_fd_t *fifo_file = s->ptr_socket_fd;
-
- // Write the sockaddr struct such that it can later be
- // given to userland if asked.
- struct sockaddr_in /*{
- sa_family_t sin_family;
- union {
- u32 s_addr;
- } sin_addr;
- u16 sin_port;
- }*/ in;
- in.sin_family = AF_INET;
- memcpy(&in.sin_addr.s_addr, src_ip, sizeof(u32));
- in.sin_port = n_source_port;
- socklen_t sock_length = sizeof(struct sockaddr_in);
-
- raw_vfs_pwrite(fifo_file, &sock_length, sizeof(sock_length), 0);
- raw_vfs_pwrite(fifo_file, &in, sizeof(in), 0);
-
- // Write the UDP payload length(not including header)
- raw_vfs_pwrite(fifo_file, &data_length, sizeof(u16), 0);
-
- // Write the UDP payload
- raw_vfs_pwrite(fifo_file, (char *)data, data_length, 0);
+ // TODO: Reimplement
+ assert(NULL);
}