diff options
Diffstat (limited to 'kernel/network/tcp.c')
-rw-r--r-- | kernel/network/tcp.c | 166 |
1 files changed, 120 insertions, 46 deletions
diff --git a/kernel/network/tcp.c b/kernel/network/tcp.c index f45d1af..717c7db 100644 --- a/kernel/network/tcp.c +++ b/kernel/network/tcp.c @@ -3,7 +3,6 @@ #include <network/bytes.h> #include <network/ipv4.h> #include <network/udp.h> -#include <socket.h> extern u8 ip_address[4]; #define CWR (1 << 7) @@ -49,6 +48,16 @@ struct __attribute__((__packed__)) PSEUDO_TCP_HEADER { u16 urgent_pointer; }; +void tcp_wait_reply(struct TcpConnection *con) { + for (;;) { + if (con->unhandled_packet) { + return; + } + // TODO: Make the scheduler halt the process + switch_task(); + } +} + u16 tcp_checksum(u16 *buffer, int size) { unsigned long cksum = 0; while (size > 1) { @@ -64,11 +73,11 @@ u16 tcp_checksum(u16 *buffer, int size) { return (u16)(~cksum); } -void tcp_calculate_checksum(u8 src_ip[4], u8 dst_ip[4], const u8 *payload, +void tcp_calculate_checksum(u8 src_ip[4], u32 dst_ip, const u8 *payload, u16 payload_length, struct TCP_HEADER *header) { struct PSEUDO_TCP_HEADER ps = {0}; memcpy(&ps.src_addr, src_ip, sizeof(u32)); - memcpy(&ps.dst_addr, dst_ip, sizeof(u32)); + memcpy(&ps.dst_addr, &dst_ip, sizeof(u32)); ps.protocol = 6; ps.tcp_length = htons(20 + payload_length); ps.src_port = header->src_port; @@ -88,80 +97,145 @@ void tcp_calculate_checksum(u8 src_ip[4], u8 dst_ip[4], const u8 *payload, header->checksum = tcp_checksum((u16 *)buffer, buffer_length); } -void tcp_close_connection(struct INCOMING_TCP_CONNECTION *inc) { +void tcp_send_empty_payload(struct TcpConnection *con, u8 flags) { struct TCP_HEADER header = {0}; - header.src_port = htons(inc->dst_port); - header.dst_port = inc->n_port; - header.seq_num = htonl(inc->seq_num); - header.ack_num = htonl(inc->ack_num); + header.src_port = htons(con->incoming_port); + header.dst_port = htons(con->outgoing_port); + header.seq_num = htonl(con->seq); + header.ack_num = htonl(con->ack); header.data_offset = 5; header.reserved = 0; - header.flags = FIN | ACK; + header.flags = flags; header.window_size = htons(WINDOW_SIZE); header.urgent_pointer = 0; - u32 dst_ip; - memcpy(&dst_ip, inc->ip, sizeof(dst_ip)); + u8 payload[0]; u16 payload_length = 0; - tcp_calculate_checksum(ip_address, inc->ip, (const u8 *)payload, + tcp_calculate_checksum(ip_address, con->outgoing_ip, (const u8 *)payload, payload_length, &header); int send_len = sizeof(header) + payload_length; u8 send_buffer[send_len]; memcpy(send_buffer, &header, sizeof(header)); memcpy(send_buffer + sizeof(header), payload, payload_length); - send_ipv4_packet(dst_ip, 6, send_buffer, send_len); + + send_ipv4_packet(con->outgoing_ip, 6, send_buffer, send_len); +} + +void tcp_send_ack(struct TcpConnection *con) { + tcp_send_empty_payload(con, ACK); } -void send_tcp_packet(struct INCOMING_TCP_CONNECTION *inc, u8 *payload, +void tcp_send_syn(struct TcpConnection *con) { + tcp_send_empty_payload(con, SYN); + con->seq++; +} + +// void send_tcp_packet(struct INCOMING_TCP_CONNECTION *inc, const u8 *payload, +// u16 payload_length) { +void send_tcp_packet(struct TcpConnection *con, const u8 *payload, u16 payload_length) { if (payload_length > 1500 - 20 - sizeof(struct TCP_HEADER)) { - send_tcp_packet(inc, payload, 1500 - 20 - sizeof(struct TCP_HEADER)); + send_tcp_packet(con, payload, 1500 - 20 - sizeof(struct TCP_HEADER)); payload_length -= 1500 - 20 - sizeof(struct TCP_HEADER); payload += 1500 - 20 - sizeof(struct TCP_HEADER); - return send_tcp_packet(inc, payload, payload_length); + return send_tcp_packet(con, payload, payload_length); } struct TCP_HEADER header = {0}; - header.src_port = htons(inc->dst_port); - header.dst_port = inc->n_port; - header.seq_num = htonl(inc->seq_num); - header.ack_num = htonl(inc->ack_num); + header.src_port = htons(con->incoming_port); + header.dst_port = htons(con->outgoing_port); + header.seq_num = htonl(con->seq); + header.ack_num = htonl(con->ack); header.data_offset = 5; header.reserved = 0; header.flags = PSH | ACK; header.window_size = htons(WINDOW_SIZE); header.urgent_pointer = 0; - u32 dst_ip; - memcpy(&dst_ip, inc->ip, sizeof(dst_ip)); - tcp_calculate_checksum(ip_address, inc->ip, (const u8 *)payload, + tcp_calculate_checksum(ip_address, con->outgoing_ip, (const u8 *)payload, payload_length, &header); int send_len = sizeof(header) + payload_length; u8 send_buffer[send_len]; memcpy(send_buffer, &header, sizeof(header)); memcpy(send_buffer + sizeof(header), payload, payload_length); - send_ipv4_packet(dst_ip, 6, send_buffer, send_len); + send_ipv4_packet(con->outgoing_ip, 6, send_buffer, send_len); - inc->seq_num += payload_length; + con->seq += payload_length; } -void send_empty_tcp_message(struct INCOMING_TCP_CONNECTION *inc, u8 flags, - u32 inc_seq_num, u16 n_dst_port, u16 n_src_port) { - struct TCP_HEADER header = {0}; - header.src_port = n_dst_port; - header.dst_port = n_src_port; - header.seq_num = 0; - header.ack_num = htonl(inc_seq_num + 1); - header.data_offset = 5; - header.reserved = 0; - header.flags = flags; - header.window_size = htons(WINDOW_SIZE); - header.urgent_pointer = 0; - char payload[0]; - tcp_calculate_checksum(ip_address, inc->ip, (const u8 *)payload, 0, &header); - u32 dst_ip; - memcpy(&dst_ip, inc->ip, sizeof(dst_ip)); - send_ipv4_packet(dst_ip, 6, (const u8 *)&header, sizeof(header)); -} +void handle_tcp(u8 src_ip[4], const u8 *payload, u32 payload_length) { + const struct TCP_HEADER *header = (const struct TCP_HEADER *)payload; + (void)header; + u16 n_src_port = *(u16 *)(payload); + u16 n_dst_port = *(u16 *)(payload + 2); + u32 n_seq_num = *(u32 *)(payload + 4); + u32 n_ack_num = *(u32 *)(payload + 8); + + // u8 flags = *(payload + 13); + u8 flags = header->flags; + u16 src_port = htons(n_src_port); + (void)src_port; + u16 dst_port = htons(n_dst_port); + u32 seq_num = htonl(n_seq_num); + u32 ack_num = htonl(n_ack_num); + (void)ack_num; + + if (SYN == flags) { + u32 t; + memcpy(&t, src_ip, sizeof(u8[4])); + struct TcpConnection *con = internal_tcp_incoming(t, src_port, 0, dst_port); + assert(con); + con->ack = seq_num + 1; + tcp_send_empty_payload(con, SYN | ACK); + return; + } + + struct TcpConnection *incoming_connection = tcp_find_connection(dst_port); + if (incoming_connection) { + incoming_connection->unhandled_packet = 1; + if (0 != (flags & RST)) { + klog("Requested port is closed", LOG_NOTE); + incoming_connection->dead = 1; + return; + } + if (ACK == flags) { + if (0 == incoming_connection->handshake_state) { + // Then it is probably a response to the SYN|ACK we sent. + incoming_connection->handshake_state = 1; + return; + } + } + if ((SYN | ACK) == flags) { + assert(0 == incoming_connection->handshake_state); + incoming_connection->handshake_state = 1; + + incoming_connection->ack = seq_num + 1; + + tcp_send_ack(incoming_connection); + } + if (0 != (flags & PSH)) { + u16 tcp_payload_length = + payload_length - header->data_offset * sizeof(u32); + int len = fifo_object_write( + (u8 *)(payload + header->data_offset * sizeof(u32)), 0, + tcp_payload_length, incoming_connection->data_file); + assert(len >= 0); + incoming_connection->ack += len; + tcp_send_ack(incoming_connection); + } + if (0 != (flags & FIN)) { + incoming_connection->ack++; + + tcp_send_empty_payload(incoming_connection, FIN | ACK); + + incoming_connection->dead = 1; // FIXME: It should wait for a ACK + // of the FIN before the connection + // is closed. + } + } else { + assert(NULL); + } +} +/* void handle_tcp(u8 src_ip[4], const u8 *payload, u32 payload_length) { const struct TCP_HEADER *inc_header = (const struct TCP_HEADER *)payload; u16 n_src_port = *(u16 *)(payload); @@ -233,11 +307,11 @@ void handle_tcp(u8 src_ip[4], const u8 *payload, u32 payload_length) { header.flags = ACK; header.window_size = htons(WINDOW_SIZE); header.urgent_pointer = 0; - char payload[0]; - tcp_calculate_checksum(ip_address, src_ip, (const u8 *)payload, 0, &header); u32 dst_ip; memcpy(&dst_ip, src_ip, sizeof(dst_ip)); + char payload[0]; + tcp_calculate_checksum(ip_address, dst_ip, (const u8 *)payload, 0, &header); send_ipv4_packet(dst_ip, 6, (const u8 *)&header, sizeof(header)); return; } -} +}*/ |