summaryrefslogtreecommitdiff
path: root/kernel/network/tcp.c
diff options
context:
space:
mode:
Diffstat (limited to 'kernel/network/tcp.c')
-rw-r--r--kernel/network/tcp.c166
1 files changed, 120 insertions, 46 deletions
diff --git a/kernel/network/tcp.c b/kernel/network/tcp.c
index f45d1af..717c7db 100644
--- a/kernel/network/tcp.c
+++ b/kernel/network/tcp.c
@@ -3,7 +3,6 @@
#include <network/bytes.h>
#include <network/ipv4.h>
#include <network/udp.h>
-#include <socket.h>
extern u8 ip_address[4];
#define CWR (1 << 7)
@@ -49,6 +48,16 @@ struct __attribute__((__packed__)) PSEUDO_TCP_HEADER {
u16 urgent_pointer;
};
+void tcp_wait_reply(struct TcpConnection *con) {
+ for (;;) {
+ if (con->unhandled_packet) {
+ return;
+ }
+ // TODO: Make the scheduler halt the process
+ switch_task();
+ }
+}
+
u16 tcp_checksum(u16 *buffer, int size) {
unsigned long cksum = 0;
while (size > 1) {
@@ -64,11 +73,11 @@ u16 tcp_checksum(u16 *buffer, int size) {
return (u16)(~cksum);
}
-void tcp_calculate_checksum(u8 src_ip[4], u8 dst_ip[4], const u8 *payload,
+void tcp_calculate_checksum(u8 src_ip[4], u32 dst_ip, const u8 *payload,
u16 payload_length, struct TCP_HEADER *header) {
struct PSEUDO_TCP_HEADER ps = {0};
memcpy(&ps.src_addr, src_ip, sizeof(u32));
- memcpy(&ps.dst_addr, dst_ip, sizeof(u32));
+ memcpy(&ps.dst_addr, &dst_ip, sizeof(u32));
ps.protocol = 6;
ps.tcp_length = htons(20 + payload_length);
ps.src_port = header->src_port;
@@ -88,80 +97,145 @@ void tcp_calculate_checksum(u8 src_ip[4], u8 dst_ip[4], const u8 *payload,
header->checksum = tcp_checksum((u16 *)buffer, buffer_length);
}
-void tcp_close_connection(struct INCOMING_TCP_CONNECTION *inc) {
+void tcp_send_empty_payload(struct TcpConnection *con, u8 flags) {
struct TCP_HEADER header = {0};
- header.src_port = htons(inc->dst_port);
- header.dst_port = inc->n_port;
- header.seq_num = htonl(inc->seq_num);
- header.ack_num = htonl(inc->ack_num);
+ header.src_port = htons(con->incoming_port);
+ header.dst_port = htons(con->outgoing_port);
+ header.seq_num = htonl(con->seq);
+ header.ack_num = htonl(con->ack);
header.data_offset = 5;
header.reserved = 0;
- header.flags = FIN | ACK;
+ header.flags = flags;
header.window_size = htons(WINDOW_SIZE);
header.urgent_pointer = 0;
- u32 dst_ip;
- memcpy(&dst_ip, inc->ip, sizeof(dst_ip));
+
u8 payload[0];
u16 payload_length = 0;
- tcp_calculate_checksum(ip_address, inc->ip, (const u8 *)payload,
+ tcp_calculate_checksum(ip_address, con->outgoing_ip, (const u8 *)payload,
payload_length, &header);
int send_len = sizeof(header) + payload_length;
u8 send_buffer[send_len];
memcpy(send_buffer, &header, sizeof(header));
memcpy(send_buffer + sizeof(header), payload, payload_length);
- send_ipv4_packet(dst_ip, 6, send_buffer, send_len);
+
+ send_ipv4_packet(con->outgoing_ip, 6, send_buffer, send_len);
+}
+
+void tcp_send_ack(struct TcpConnection *con) {
+ tcp_send_empty_payload(con, ACK);
}
-void send_tcp_packet(struct INCOMING_TCP_CONNECTION *inc, u8 *payload,
+void tcp_send_syn(struct TcpConnection *con) {
+ tcp_send_empty_payload(con, SYN);
+ con->seq++;
+}
+
+// void send_tcp_packet(struct INCOMING_TCP_CONNECTION *inc, const u8 *payload,
+// u16 payload_length) {
+void send_tcp_packet(struct TcpConnection *con, const u8 *payload,
u16 payload_length) {
if (payload_length > 1500 - 20 - sizeof(struct TCP_HEADER)) {
- send_tcp_packet(inc, payload, 1500 - 20 - sizeof(struct TCP_HEADER));
+ send_tcp_packet(con, payload, 1500 - 20 - sizeof(struct TCP_HEADER));
payload_length -= 1500 - 20 - sizeof(struct TCP_HEADER);
payload += 1500 - 20 - sizeof(struct TCP_HEADER);
- return send_tcp_packet(inc, payload, payload_length);
+ return send_tcp_packet(con, payload, payload_length);
}
struct TCP_HEADER header = {0};
- header.src_port = htons(inc->dst_port);
- header.dst_port = inc->n_port;
- header.seq_num = htonl(inc->seq_num);
- header.ack_num = htonl(inc->ack_num);
+ header.src_port = htons(con->incoming_port);
+ header.dst_port = htons(con->outgoing_port);
+ header.seq_num = htonl(con->seq);
+ header.ack_num = htonl(con->ack);
header.data_offset = 5;
header.reserved = 0;
header.flags = PSH | ACK;
header.window_size = htons(WINDOW_SIZE);
header.urgent_pointer = 0;
- u32 dst_ip;
- memcpy(&dst_ip, inc->ip, sizeof(dst_ip));
- tcp_calculate_checksum(ip_address, inc->ip, (const u8 *)payload,
+ tcp_calculate_checksum(ip_address, con->outgoing_ip, (const u8 *)payload,
payload_length, &header);
int send_len = sizeof(header) + payload_length;
u8 send_buffer[send_len];
memcpy(send_buffer, &header, sizeof(header));
memcpy(send_buffer + sizeof(header), payload, payload_length);
- send_ipv4_packet(dst_ip, 6, send_buffer, send_len);
+ send_ipv4_packet(con->outgoing_ip, 6, send_buffer, send_len);
- inc->seq_num += payload_length;
+ con->seq += payload_length;
}
-void send_empty_tcp_message(struct INCOMING_TCP_CONNECTION *inc, u8 flags,
- u32 inc_seq_num, u16 n_dst_port, u16 n_src_port) {
- struct TCP_HEADER header = {0};
- header.src_port = n_dst_port;
- header.dst_port = n_src_port;
- header.seq_num = 0;
- header.ack_num = htonl(inc_seq_num + 1);
- header.data_offset = 5;
- header.reserved = 0;
- header.flags = flags;
- header.window_size = htons(WINDOW_SIZE);
- header.urgent_pointer = 0;
- char payload[0];
- tcp_calculate_checksum(ip_address, inc->ip, (const u8 *)payload, 0, &header);
- u32 dst_ip;
- memcpy(&dst_ip, inc->ip, sizeof(dst_ip));
- send_ipv4_packet(dst_ip, 6, (const u8 *)&header, sizeof(header));
-}
+void handle_tcp(u8 src_ip[4], const u8 *payload, u32 payload_length) {
+ const struct TCP_HEADER *header = (const struct TCP_HEADER *)payload;
+ (void)header;
+ u16 n_src_port = *(u16 *)(payload);
+ u16 n_dst_port = *(u16 *)(payload + 2);
+ u32 n_seq_num = *(u32 *)(payload + 4);
+ u32 n_ack_num = *(u32 *)(payload + 8);
+
+ // u8 flags = *(payload + 13);
+ u8 flags = header->flags;
+ u16 src_port = htons(n_src_port);
+ (void)src_port;
+ u16 dst_port = htons(n_dst_port);
+ u32 seq_num = htonl(n_seq_num);
+ u32 ack_num = htonl(n_ack_num);
+ (void)ack_num;
+
+ if (SYN == flags) {
+ u32 t;
+ memcpy(&t, src_ip, sizeof(u8[4]));
+ struct TcpConnection *con = internal_tcp_incoming(t, src_port, 0, dst_port);
+ assert(con);
+ con->ack = seq_num + 1;
+ tcp_send_empty_payload(con, SYN | ACK);
+ return;
+ }
+
+ struct TcpConnection *incoming_connection = tcp_find_connection(dst_port);
+ if (incoming_connection) {
+ incoming_connection->unhandled_packet = 1;
+ if (0 != (flags & RST)) {
+ klog("Requested port is closed", LOG_NOTE);
+ incoming_connection->dead = 1;
+ return;
+ }
+ if (ACK == flags) {
+ if (0 == incoming_connection->handshake_state) {
+ // Then it is probably a response to the SYN|ACK we sent.
+ incoming_connection->handshake_state = 1;
+ return;
+ }
+ }
+ if ((SYN | ACK) == flags) {
+ assert(0 == incoming_connection->handshake_state);
+ incoming_connection->handshake_state = 1;
+
+ incoming_connection->ack = seq_num + 1;
+
+ tcp_send_ack(incoming_connection);
+ }
+ if (0 != (flags & PSH)) {
+ u16 tcp_payload_length =
+ payload_length - header->data_offset * sizeof(u32);
+ int len = fifo_object_write(
+ (u8 *)(payload + header->data_offset * sizeof(u32)), 0,
+ tcp_payload_length, incoming_connection->data_file);
+ assert(len >= 0);
+ incoming_connection->ack += len;
+ tcp_send_ack(incoming_connection);
+ }
+ if (0 != (flags & FIN)) {
+ incoming_connection->ack++;
+
+ tcp_send_empty_payload(incoming_connection, FIN | ACK);
+
+ incoming_connection->dead = 1; // FIXME: It should wait for a ACK
+ // of the FIN before the connection
+ // is closed.
+ }
+ } else {
+ assert(NULL);
+ }
+}
+/*
void handle_tcp(u8 src_ip[4], const u8 *payload, u32 payload_length) {
const struct TCP_HEADER *inc_header = (const struct TCP_HEADER *)payload;
u16 n_src_port = *(u16 *)(payload);
@@ -233,11 +307,11 @@ void handle_tcp(u8 src_ip[4], const u8 *payload, u32 payload_length) {
header.flags = ACK;
header.window_size = htons(WINDOW_SIZE);
header.urgent_pointer = 0;
- char payload[0];
- tcp_calculate_checksum(ip_address, src_ip, (const u8 *)payload, 0, &header);
u32 dst_ip;
memcpy(&dst_ip, src_ip, sizeof(dst_ip));
+ char payload[0];
+ tcp_calculate_checksum(ip_address, dst_ip, (const u8 *)payload, 0, &header);
send_ipv4_packet(dst_ip, 6, (const u8 *)&header, sizeof(header));
return;
}
-}
+}*/