summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--include/typedefs.h1
-rw-r--r--kernel/network/ipv4.c44
-rw-r--r--kernel/socket.c5
3 files changed, 20 insertions, 30 deletions
diff --git a/include/typedefs.h b/include/typedefs.h
index c10482c..e221978 100644
--- a/include/typedefs.h
+++ b/include/typedefs.h
@@ -13,7 +13,6 @@ typedef int32_t i32;
typedef int64_t i64;
typedef union {
- u8 a[4];
u32 d;
} ipv4_t;
diff --git a/kernel/network/ipv4.c b/kernel/network/ipv4.c
index 8fb983a..69ceef7 100644
--- a/kernel/network/ipv4.c
+++ b/kernel/network/ipv4.c
@@ -9,27 +9,13 @@
#include <network/udp.h>
#include <string.h>
-u16 ip_checksum(void *vdata, size_t length) {
- // Cast the data pointer to one that can be indexed.
- char *data = (char *)vdata;
-
+static u16 ip_checksum(const u16 *data, u16 length) {
// Initialise the accumulator.
u32 acc = 0xffff;
// Handle complete 16-bit blocks.
- for (size_t i = 0; i + 1 < length; i += 2) {
- u16 word;
- memcpy(&word, data + i, 2);
- acc += ntohs(word);
- if (acc > 0xffff) {
- acc -= 0xffff;
- }
- }
-
- // Handle any partial block at the end of the data.
- if (length & 1) {
- u16 word = 0;
- memcpy(&word, data + length - 1, 1);
+ for (size_t i = 0; i < length / 2; i++) {
+ u16 word = *(data + i);
acc += ntohs(word);
if (acc > 0xffff) {
acc -= 0xffff;
@@ -41,16 +27,24 @@ u16 ip_checksum(void *vdata, size_t length) {
}
void send_ipv4_packet(ipv4_t ip, u8 protocol, const u8 *payload, u16 length) {
- u8 header[20] = {0};
+ u16 header[10];
header[0] = (4 /*version*/ << 4) | (5 /*IHL*/);
- *((u16 *)(header + 2)) = htons(length + 20);
- header[8 /*TTL*/] = 0xF8;
- header[9] = protocol;
- memcpy(header + 12 /*src_ip*/, &ip_address, sizeof(ipv4_t));
- memcpy(header + 16, &ip, sizeof(ipv4_t));
+ header[1] = htons(length + 20);
+
+ header[2] = 0;
+
+ header[3] = 0;
+ header[4] = (protocol << 8) | 0xF8;
+
+ header[5] = 0;
+ header[6] = (ip_address.d >> 0) & 0xFFFF;
+ header[7] = (ip_address.d >> 16) & 0xFFFF;
+
+ header[8] = (ip.d >> 0) & 0xFFFF;
+ header[9] = (ip.d >> 16) & 0xFFFF;
- *((u16 *)(header + 10 /*checksum*/)) = ip_checksum(header, 20);
+ header[5] = ip_checksum(header, 20);
u16 packet_length = length + 20;
u8 *packet = kmalloc(packet_length);
memcpy(packet, header, 20);
@@ -68,7 +62,7 @@ void handle_ipv4(const u8 *payload, u32 packet_length) {
u16 saved_checksum = *(u16 *)(payload + 10);
*(u16 *)(payload + 10) = 0;
- u16 calc_checksum = ip_checksum((u8 *)payload, 20);
+ u16 calc_checksum = ip_checksum((const u16 *)payload, 20);
*(u16 *)(payload + 10) = saved_checksum;
if (calc_checksum != saved_checksum) {
klog(LOG_WARN, "Invalid ipv4 checksum");
diff --git a/kernel/socket.c b/kernel/socket.c
index d4e86a9..689e2ac 100644
--- a/kernel/socket.c
+++ b/kernel/socket.c
@@ -26,10 +26,7 @@ void global_socket_init(void) {
OPEN_UNIX_SOCKET *un_sockets[100] = {0};
void gen_ipv4(ipv4_t *ip, u8 i1, u8 i2, u8 i3, u8 i4) {
- ip->a[0] = i1;
- ip->a[1] = i2;
- ip->a[2] = i3;
- ip->a[3] = i4;
+ ip->d = (i1 << (8 * 0)) | (i2 << (8 * 1)) | (i3 << (8 * 2)) | (i4 << (8 * 3));
}
void tcp_remove_connection(struct TcpConnection *con) {