summaryrefslogtreecommitdiff
path: root/kernel/network
diff options
context:
space:
mode:
Diffstat (limited to 'kernel/network')
-rw-r--r--kernel/network/tcp.c77
1 files changed, 43 insertions, 34 deletions
diff --git a/kernel/network/tcp.c b/kernel/network/tcp.c
index e6febba..5d11a93 100644
--- a/kernel/network/tcp.c
+++ b/kernel/network/tcp.c
@@ -13,7 +13,7 @@
#define MSS 536
-struct __attribute__((__packed__)) TCP_HEADER {
+struct __attribute__((__packed__)) __attribute__((aligned(2))) TCP_HEADER {
u16 src_port;
u16 dst_port;
u32 seq_num;
@@ -26,19 +26,30 @@ struct __attribute__((__packed__)) TCP_HEADER {
u16 urgent_pointer;
};
-u16 tcp_checksum(u16 *buffer, u32 size) {
- u32 cksum = 0;
- for (; size > 1;) {
- cksum += *buffer++;
- size -= sizeof(u16);
+static inline u32 tcp_checksum_even_update(u32 cksum, const u16 *buffer,
+ u32 size) {
+ assert(0 == (size % sizeof(u16)));
+ for (; size > 0; size -= 2) {
+ cksum += ntohs(*buffer);
+ cksum = (cksum >> 16) + (cksum & 0xFFFF);
+ buffer++;
}
- if (size) {
- cksum += *(u8 *)buffer;
+ return cksum;
+}
+
+static inline u16 tcp_checksum_final(u32 cksum, const u8 *buffer, u32 size) {
+ for (; size > 1; size -= 2) {
+ cksum += ntohs(*(u16 *)buffer);
+ cksum = (cksum >> 16) + (cksum & 0xFFFF);
+ buffer += 2;
+ }
+
+ if (size > 0) {
+ cksum += *buffer << 8;
+ cksum = (cksum >> 16) + (cksum & 0xFFFF);
}
- cksum = (cksum >> 16) + (cksum & 0xffff);
- cksum += (cksum >> 16);
- return (u16)(~cksum);
+ return ~(cksum & 0xFFFF) & 0xFFFF;
}
u16 tcp_calculate_checksum(ipv4_t src_ip, u32 dst_ip, const u8 *payload,
@@ -48,29 +59,27 @@ u16 tcp_calculate_checksum(ipv4_t src_ip, u32 dst_ip, const u8 *payload,
return 0;
}
- int pseudo = sizeof(u32) * 2 + sizeof(u16) + sizeof(u8) * 2;
-
- int buffer_length =
- pseudo + header->data_offset * sizeof(u32) + payload_length;
- u8 buffer[buffer_length];
-
- u8 *ptr = buffer;
- memcpy(ptr, &src_ip.d, sizeof(u32));
- ptr += sizeof(u32);
- memcpy(ptr, &dst_ip, sizeof(u32));
- ptr += sizeof(u32);
- *ptr = 0;
- ptr += sizeof(u8);
- *ptr = 6;
- ptr += sizeof(u8);
- *(u16 *)ptr = htons(header->data_offset * sizeof(u32) + payload_length);
- ptr += sizeof(u16);
- memcpy(ptr, header, header->data_offset * sizeof(u32));
- memset(ptr + 16, 0, sizeof(u16)); // set checksum to zero
- ptr += header->data_offset * sizeof(u32);
- memcpy(ptr, payload, payload_length);
-
- return tcp_checksum((u16 *)buffer, buffer_length);
+ u32 tmp = 0;
+ tmp = tcp_checksum_even_update(tmp, (u16 *)&src_ip.d, sizeof(u32));
+ tmp = tcp_checksum_even_update(tmp, (u16 *)&dst_ip, sizeof(u32));
+ u8 a[2] = {0, 6};
+ tmp = tcp_checksum_even_update(tmp, (u16 *)a, sizeof(a));
+
+ u16 p = htons(header->data_offset * sizeof(u32) + payload_length);
+ tmp = tcp_checksum_even_update(tmp, &p, sizeof(u16));
+
+ // This just skips including the preset header checksum in the
+ // calculation.
+ uintptr_t checksum_location =
+ (uintptr_t)&header->checksum - (uintptr_t)header;
+ uintptr_t rest =
+ header->data_offset * sizeof(u32) - (checksum_location + sizeof(u16));
+
+ tmp = tcp_checksum_even_update(tmp, (u16 *)header, checksum_location);
+ tmp = tcp_checksum_even_update(
+ tmp, (u16 *)((u8 *)header + checksum_location + sizeof(u16)), rest);
+
+ return htons(tcp_checksum_final(tmp, payload, payload_length));
}
static void tcp_send(struct TcpConnection *con, u8 *buffer, u16 length,