diff options
-rw-r--r-- | kernel/network/tcp.c | 77 |
1 files changed, 43 insertions, 34 deletions
diff --git a/kernel/network/tcp.c b/kernel/network/tcp.c index e6febba..5d11a93 100644 --- a/kernel/network/tcp.c +++ b/kernel/network/tcp.c @@ -13,7 +13,7 @@ #define MSS 536 -struct __attribute__((__packed__)) TCP_HEADER { +struct __attribute__((__packed__)) __attribute__((aligned(2))) TCP_HEADER { u16 src_port; u16 dst_port; u32 seq_num; @@ -26,19 +26,30 @@ struct __attribute__((__packed__)) TCP_HEADER { u16 urgent_pointer; }; -u16 tcp_checksum(u16 *buffer, u32 size) { - u32 cksum = 0; - for (; size > 1;) { - cksum += *buffer++; - size -= sizeof(u16); +static inline u32 tcp_checksum_even_update(u32 cksum, const u16 *buffer, + u32 size) { + assert(0 == (size % sizeof(u16))); + for (; size > 0; size -= 2) { + cksum += ntohs(*buffer); + cksum = (cksum >> 16) + (cksum & 0xFFFF); + buffer++; } - if (size) { - cksum += *(u8 *)buffer; + return cksum; +} + +static inline u16 tcp_checksum_final(u32 cksum, const u8 *buffer, u32 size) { + for (; size > 1; size -= 2) { + cksum += ntohs(*(u16 *)buffer); + cksum = (cksum >> 16) + (cksum & 0xFFFF); + buffer += 2; + } + + if (size > 0) { + cksum += *buffer << 8; + cksum = (cksum >> 16) + (cksum & 0xFFFF); } - cksum = (cksum >> 16) + (cksum & 0xffff); - cksum += (cksum >> 16); - return (u16)(~cksum); + return ~(cksum & 0xFFFF) & 0xFFFF; } u16 tcp_calculate_checksum(ipv4_t src_ip, u32 dst_ip, const u8 *payload, @@ -48,29 +59,27 @@ u16 tcp_calculate_checksum(ipv4_t src_ip, u32 dst_ip, const u8 *payload, return 0; } - int pseudo = sizeof(u32) * 2 + sizeof(u16) + sizeof(u8) * 2; - - int buffer_length = - pseudo + header->data_offset * sizeof(u32) + payload_length; - u8 buffer[buffer_length]; - - u8 *ptr = buffer; - memcpy(ptr, &src_ip.d, sizeof(u32)); - ptr += sizeof(u32); - memcpy(ptr, &dst_ip, sizeof(u32)); - ptr += sizeof(u32); - *ptr = 0; - ptr += sizeof(u8); - *ptr = 6; - ptr += sizeof(u8); - *(u16 *)ptr = htons(header->data_offset * sizeof(u32) + payload_length); - ptr += sizeof(u16); - memcpy(ptr, header, header->data_offset * sizeof(u32)); - memset(ptr + 16, 0, sizeof(u16)); // set checksum to zero - ptr += header->data_offset * sizeof(u32); - memcpy(ptr, payload, payload_length); - - return tcp_checksum((u16 *)buffer, buffer_length); + u32 tmp = 0; + tmp = tcp_checksum_even_update(tmp, (u16 *)&src_ip.d, sizeof(u32)); + tmp = tcp_checksum_even_update(tmp, (u16 *)&dst_ip, sizeof(u32)); + u8 a[2] = {0, 6}; + tmp = tcp_checksum_even_update(tmp, (u16 *)a, sizeof(a)); + + u16 p = htons(header->data_offset * sizeof(u32) + payload_length); + tmp = tcp_checksum_even_update(tmp, &p, sizeof(u16)); + + // This just skips including the preset header checksum in the + // calculation. + uintptr_t checksum_location = + (uintptr_t)&header->checksum - (uintptr_t)header; + uintptr_t rest = + header->data_offset * sizeof(u32) - (checksum_location + sizeof(u16)); + + tmp = tcp_checksum_even_update(tmp, (u16 *)header, checksum_location); + tmp = tcp_checksum_even_update( + tmp, (u16 *)((u8 *)header + checksum_location + sizeof(u16)), rest); + + return htons(tcp_checksum_final(tmp, payload, payload_length)); } static void tcp_send(struct TcpConnection *con, u8 *buffer, u16 length, |