diff options
Diffstat (limited to 'kernel/network')
-rw-r--r-- | kernel/network/arp.c | 2 | ||||
-rw-r--r-- | kernel/network/arp.h | 1 | ||||
-rw-r--r-- | kernel/network/ipv4.c | 10 | ||||
-rw-r--r-- | kernel/network/tcp.c | 236 | ||||
-rw-r--r-- | kernel/network/tcp.h | 5 |
5 files changed, 251 insertions, 3 deletions
diff --git a/kernel/network/arp.c b/kernel/network/arp.c index de6d898..179f6c1 100644 --- a/kernel/network/arp.c +++ b/kernel/network/arp.c @@ -91,7 +91,6 @@ int get_mac_from_ip(const uint8_t ip[4], uint8_t mac[6]) { return 1; } klog("ARP cache miss", LOG_NOTE); - asm("sti"); send_arp_request(ip); // TODO: Maybe wait a bit? for (int i = 0; i < 10; i++) { @@ -100,7 +99,6 @@ int get_mac_from_ip(const uint8_t ip[4], uint8_t mac[6]) { memcpy(mac, arp_table[i].mac, sizeof(uint8_t[6])); return 1; } - assert(0); return 0; } diff --git a/kernel/network/arp.h b/kernel/network/arp.h index c2beb94..1e42dff 100644 --- a/kernel/network/arp.h +++ b/kernel/network/arp.h @@ -1,4 +1,5 @@ #include <stdint.h> int get_mac_from_ip(const uint8_t ip[4], uint8_t mac[6]); +void send_arp_request(const uint8_t ip[4]); void handle_arp(const uint8_t *payload); diff --git a/kernel/network/ipv4.c b/kernel/network/ipv4.c index 099aa0d..7f480a3 100644 --- a/kernel/network/ipv4.c +++ b/kernel/network/ipv4.c @@ -1,9 +1,11 @@ #include <assert.h> +#include <drivers/pit.h> #include <kmalloc.h> #include <network/arp.h> #include <network/bytes.h> #include <network/ethernet.h> #include <network/ipv4.h> +#include <network/tcp.h> #include <network/udp.h> #include <string.h> @@ -59,8 +61,11 @@ void send_ipv4_packet(uint32_t ip, uint8_t protocol, const uint8_t *payload, uint8_t mac[6]; uint8_t ip_copy[4]; // TODO: Do I need to do this? memcpy(ip_copy, &ip, sizeof(uint8_t[4])); - get_mac_from_ip(ip_copy, mac); + for (; !get_mac_from_ip(ip_copy, mac);) + ; + kprintf("pre send_ethernet: %x\n", pit_num_ms()); send_ethernet_packet(mac, 0x0800, packet, packet_length); + kprintf("after send_ethernet: %x\n", pit_num_ms()); kfree(packet); } @@ -88,6 +93,9 @@ void handle_ipv4(const uint8_t *payload, uint32_t packet_length) { uint8_t protocol = *(payload + 9); switch (protocol) { + case 0x6: + handle_tcp(src_ip, payload + 20, ipv4_total_length - 20); + break; case 0x11: handle_udp(src_ip, payload + 20, ipv4_total_length - 20); break; diff --git a/kernel/network/tcp.c b/kernel/network/tcp.c new file mode 100644 index 0000000..c20f046 --- /dev/null +++ b/kernel/network/tcp.c @@ -0,0 +1,236 @@ +#include <assert.h> +#include <drivers/pit.h> +#include <network/bytes.h> +#include <network/ipv4.h> +#include <network/udp.h> +#include <socket.h> +extern uint8_t ip_address[4]; + +#define CWR (1 << 7) +#define ECE (1 << 6) +#define URG (1 << 5) +#define ACK (1 << 4) +#define PSH (1 << 3) +#define RST (1 << 2) +#define SYN (1 << 1) +#define FIN (1 << 0) + +struct __attribute__((__packed__)) TCP_HEADER { + uint16_t src_port; + uint16_t dst_port; + uint32_t seq_num; + uint32_t ack_num; + uint8_t reserved : 4; + uint8_t data_offset : 4; + uint8_t flags; + uint16_t window_size; + uint16_t checksum; + uint16_t urgent_pointer; +}; + +struct __attribute__((__packed__)) PSEUDO_TCP_HEADER { + uint32_t src_addr; + uint32_t dst_addr; + uint8_t zero; + uint8_t protocol; + uint16_t tcp_length; + uint16_t src_port; + uint16_t dst_port; + uint32_t seq_num; + uint32_t ack_num; + uint8_t reserved : 4; + uint8_t data_offset : 4; + uint8_t flags; + uint16_t window_size; + uint16_t checksum_zero; // ????? + uint16_t urgent_pointer; +}; + +uint16_t tcp_checksum(uint16_t *buffer, int size) { + unsigned long cksum = 0; + while (size > 1) { + cksum += *buffer++; + size -= sizeof(uint16_t); + } + if (size) + cksum += *(uint8_t *)buffer; + + cksum = (cksum >> 16) + (cksum & 0xffff); + cksum += (cksum >> 16); + return (uint16_t)(~cksum); +} + +void tcp_calculate_checksum(uint8_t src_ip[4], uint8_t dst_ip[4], + const uint8_t *payload, uint16_t payload_length, + struct TCP_HEADER *header) { + struct PSEUDO_TCP_HEADER ps = {0}; + memcpy(&ps.src_addr, src_ip, sizeof(uint32_t)); + memcpy(&ps.dst_addr, dst_ip, sizeof(uint32_t)); + ps.protocol = 6; + ps.tcp_length = htons(20 + payload_length); + ps.src_port = header->src_port; + ps.dst_port = header->dst_port; + ps.seq_num = header->seq_num; + ps.ack_num = header->ack_num; + ps.data_offset = header->data_offset; + ps.reserved = header->reserved; + ps.flags = header->flags; + ps.window_size = header->window_size; + ps.urgent_pointer = header->urgent_pointer; + // ps.options = 0; + int buffer_length = sizeof(ps) + payload_length; + uint8_t buffer[buffer_length]; + memcpy(buffer, &ps, sizeof(ps)); + memcpy(buffer + sizeof(ps), payload, payload_length); + header->checksum = tcp_checksum((uint16_t *)buffer, buffer_length); +} + +void tcp_close_connection(struct INCOMING_TCP_CONNECTION *inc) { + struct TCP_HEADER header = {0}; + header.src_port = htons(inc->dst_port); + header.dst_port = inc->n_port; + header.seq_num = htonl(inc->seq_num); + header.ack_num = htonl(inc->ack_num); + header.data_offset = 5; + header.reserved = 0; + header.flags = FIN | ACK; + header.window_size = htons(512); + header.urgent_pointer = 0; + uint32_t dst_ip; + memcpy(&dst_ip, inc->ip, sizeof(dst_ip)); + uint8_t payload[0]; + uint16_t payload_length = 0; + tcp_calculate_checksum(ip_address, inc->ip, (const uint8_t *)payload, + payload_length, &header); + int send_len = sizeof(header) + payload_length; + uint8_t send_buffer[send_len]; + memcpy(send_buffer, &header, sizeof(header)); + memcpy(send_buffer + sizeof(header), payload, payload_length); + send_ipv4_packet(dst_ip, 6, send_buffer, send_len); +} + +void send_tcp_packet(struct INCOMING_TCP_CONNECTION *inc, uint8_t *payload, + uint16_t payload_length) { + struct TCP_HEADER header = {0}; + header.src_port = htons(inc->dst_port); + header.dst_port = inc->n_port; + header.seq_num = htonl(inc->seq_num); + header.ack_num = htonl(inc->ack_num); + header.data_offset = 5; + header.reserved = 0; + header.flags = PSH | ACK; + header.window_size = htons(512); + header.urgent_pointer = 0; + uint32_t dst_ip; + memcpy(&dst_ip, inc->ip, sizeof(dst_ip)); + tcp_calculate_checksum(ip_address, inc->ip, (const uint8_t *)payload, + payload_length, &header); + int send_len = sizeof(header) + payload_length; + uint8_t send_buffer[send_len]; + memcpy(send_buffer, &header, sizeof(header)); + memcpy(send_buffer + sizeof(header), payload, payload_length); + send_ipv4_packet(dst_ip, 6, send_buffer, send_len); + + inc->seq_num += payload_length; +} + +void send_empty_tcp_message(struct INCOMING_TCP_CONNECTION *inc, uint8_t flags, + uint32_t inc_seq_num, uint16_t n_dst_port, + uint16_t n_src_port) { + struct TCP_HEADER header = {0}; + header.src_port = n_dst_port; + header.dst_port = n_src_port; + header.seq_num = 0; + header.ack_num = htonl(inc_seq_num + 1); + header.data_offset = 5; + header.reserved = 0; + header.flags = flags; + header.window_size = htons(512); // TODO: What does this actually mean? + header.urgent_pointer = 0; + char payload[0]; + tcp_calculate_checksum(ip_address, inc->ip, (const uint8_t *)payload, 0, + &header); + uint32_t dst_ip; + memcpy(&dst_ip, inc->ip, sizeof(dst_ip)); + send_ipv4_packet(dst_ip, 6, (const uint8_t *)&header, sizeof(header)); +} + +void handle_tcp(uint8_t src_ip[4], const uint8_t *payload, + uint32_t payload_length) { + const struct TCP_HEADER *inc_header = (const struct TCP_HEADER *)payload; + uint16_t n_src_port = *(uint16_t *)(payload); + uint16_t n_dst_port = *(uint16_t *)(payload + 2); + uint32_t n_seq_num = *(uint32_t *)(payload + 4); + uint32_t n_ack_num = *(uint32_t *)(payload + 8); + + uint8_t flags = *(payload + 13); + + uint16_t src_port = htons(n_src_port); + (void)src_port; + uint16_t dst_port = htons(n_dst_port); + uint32_t seq_num = htonl(n_seq_num); + uint32_t ack_num = htonl(n_ack_num); + (void)ack_num; + + if (SYN == flags) { + kprintf("GOT SYN UPTIME: %d\n", pit_num_ms()); + struct INCOMING_TCP_CONNECTION *inc; + if (!(inc = handle_incoming_tcp_connection(src_ip, n_src_port, dst_port))) + return; + memcpy(inc->ip, src_ip, sizeof(uint8_t[4])); + inc->seq_num = 0; + inc->ack_num = seq_num + 1; + inc->connection_closed = 0; + inc->requesting_connection_close = 0; + send_empty_tcp_message(inc, SYN | ACK, seq_num, n_dst_port, n_src_port); + inc->seq_num++; + return; + } + struct INCOMING_TCP_CONNECTION *inc = + get_incoming_tcp_connection(src_ip, n_src_port); + if (!inc) + return; + if (flags == (FIN | ACK)) { + if (inc->requesting_connection_close) { + send_empty_tcp_message(inc, ACK, seq_num, n_dst_port, n_src_port); + inc->connection_closed = 1; + } else { + send_empty_tcp_message(inc, FIN | ACK, seq_num, n_dst_port, n_src_port); + } + inc->seq_num++; + inc->connection_closed = 1; + return; + } + if (flags & ACK) { + // inc->seq_num = ack_num; + } + if (flags & PSH) { + kprintf("send ipv4 packet: %x\n", pit_num_ms()); + uint16_t tcp_payload_length = + payload_length - inc_header->data_offset * sizeof(uint32_t); + fifo_object_write( + (uint8_t *)(payload + inc_header->data_offset * sizeof(uint32_t)), 0, + tcp_payload_length, inc->data_file); + + // Send back a ACK + struct TCP_HEADER header = {0}; + header.src_port = n_dst_port; + header.dst_port = n_src_port; + header.seq_num = htonl(inc->seq_num); + inc->ack_num = seq_num + tcp_payload_length; + header.ack_num = htonl(seq_num + tcp_payload_length); + header.data_offset = 5; + header.reserved = 0; + header.flags = ACK; + header.window_size = htons(512); // TODO: What does this actually mean? + header.urgent_pointer = 0; + char payload[0]; + tcp_calculate_checksum(ip_address, src_ip, (const uint8_t *)payload, 0, + &header); + uint32_t dst_ip; + memcpy(&dst_ip, src_ip, sizeof(dst_ip)); + send_ipv4_packet(dst_ip, 6, (const uint8_t *)&header, sizeof(header)); + kprintf("after ipv4 packet: %x\n", pit_num_ms()); + return; + } +} diff --git a/kernel/network/tcp.h b/kernel/network/tcp.h new file mode 100644 index 0000000..fbd6065 --- /dev/null +++ b/kernel/network/tcp.h @@ -0,0 +1,5 @@ +void handle_tcp(uint8_t src_ip[4], const uint8_t *payload, + uint32_t payload_length); +void send_tcp_packet(struct INCOMING_TCP_CONNECTION *inc, uint8_t *payload, + uint16_t payload_length); +void tcp_close_connection(struct INCOMING_TCP_CONNECTION *s); |