diff options
-rw-r--r-- | include/sys/socket.h | 6 | ||||
-rw-r--r-- | kernel/cpu/syscall.c | 2 | ||||
-rw-r--r-- | kernel/drivers/pit.c | 4 | ||||
-rw-r--r-- | kernel/drivers/rtl8139.c | 85 | ||||
-rw-r--r-- | kernel/network/arp.c | 2 | ||||
-rw-r--r-- | kernel/network/ethernet.c | 2 | ||||
-rw-r--r-- | kernel/network/ipv4.c | 1 | ||||
-rw-r--r-- | kernel/network/tcp.c | 57 | ||||
-rw-r--r-- | kernel/socket.c | 96 | ||||
-rw-r--r-- | kernel/socket.h | 4 | ||||
-rw-r--r-- | userland/libc/include/sys/mman.h | 10 | ||||
-rw-r--r-- | userland/libc/netdb/getaddrinfo.c | 83 |
12 files changed, 220 insertions, 132 deletions
diff --git a/include/sys/socket.h b/include/sys/socket.h index 5af84fc..128abaa 100644 --- a/include/sys/socket.h +++ b/include/sys/socket.h @@ -7,8 +7,10 @@ #define AF_INET 1 #define AF_LOCAL AF_UNIX -#define SOCK_DGRAM 0 -#define SOCK_STREAM 1 +#define SOCK_DGRAM (1 << 0) +#define SOCK_STREAM (1 << 1) +#define SOCK_NONBLOCK (1 << 2) + #define MSG_WAITALL 1 #define INADDR_ANY 0 diff --git a/kernel/cpu/syscall.c b/kernel/cpu/syscall.c index b981d5b..3fe0766 100644 --- a/kernel/cpu/syscall.c +++ b/kernel/cpu/syscall.c @@ -9,11 +9,11 @@ #include <interrupts.h> #include <kmalloc.h> #include <network/ethernet.h> +#include <network/tcp.h> #include <socket.h> #include <string.h> #include <syscalls.h> #include <typedefs.h> -#include <network/tcp.h> #pragma GCC diagnostic ignored "-Wpedantic" diff --git a/kernel/drivers/pit.c b/kernel/drivers/pit.c index b938de4..f8d2c42 100644 --- a/kernel/drivers/pit.c +++ b/kernel/drivers/pit.c @@ -43,12 +43,16 @@ void set_pit_count(u16 _hertz) { outb(PIT_IO_CHANNEL_0, (divisor & 0xFF00) >> 8); } +extern int is_switching_tasks; void int_clock(reg_t *regs) { clock_num_ms_ticks += 5; switch_counter++; if (switch_counter >= hertz) { EOI(0x20); switch_counter = 0; + if(is_switching_tasks) { + return; + } switch_task(); } else { EOI(0x20); diff --git a/kernel/drivers/rtl8139.c b/kernel/drivers/rtl8139.c index 74d7c05..610ca47 100644 --- a/kernel/drivers/rtl8139.c +++ b/kernel/drivers/rtl8139.c @@ -12,7 +12,7 @@ #define CMD 0x37 #define IMR 0x3C -#define RTL8139_RXBUFFER_SIZE (8192 + 16) +#define RTL8139_RXBUFFER_SIZE (8192 << 3) #define TSD0 0x10 // transmit status #define TSAD0 0x20 // transmit start address @@ -34,73 +34,72 @@ struct _INT_PACKET_HEADER { u8 BAR : 1; u8 PAM : 1; u8 MAR : 1; -}; +} __attribute__((packed)); struct PACKET_HEADER { union { u16 raw; struct _INT_PACKET_HEADER data; }; -}; +} __attribute__((packed)); -u32 current_packet_read = 0; +unsigned short current_packet_read = 0; void handle_packet(void) { assert(sizeof(struct _INT_PACKET_HEADER) == sizeof(u16)); + int had_error = 0; + for (; 0 == (inb(rtl8139.gen.base_mem_io + 0x37) & 1);) { - u16 *buf = (u16 *)(device_buffer + current_packet_read); + u32 ring_offset = current_packet_read % RTL8139_RXBUFFER_SIZE; + + u16 rx_size = *(u16 *)(device_buffer + ring_offset + sizeof(u16)); + struct PACKET_HEADER packet_header; - packet_header.raw = *buf; - if (packet_header.data.FAE) { - break; - } - if (packet_header.data.CRC) { - break; - } - if (!packet_header.data.ROK) { - break; - } - u16 packet_length = *(buf + 1); - assert(packet_length <= 2048); - - u8 packet_buffer[RTL8139_RXBUFFER_SIZE]; - if (current_packet_read + packet_length >= RTL8139_RXBUFFER_SIZE) { - u32 end = RTL8139_RXBUFFER_SIZE - current_packet_read; - memcpy(packet_buffer, buf, end); - u32 rest = packet_length - end; - memcpy(packet_buffer + end, device_buffer, rest); + packet_header.raw = device_buffer[ring_offset + 0]; + + int error = (packet_header.data.FAE) || (packet_header.data.CRC) || + (!packet_header.data.ROK); + + if (error) { + current_packet_read = 0; + outb(rtl8139.gen.base_mem_io + 0x37, 0x4); + outb(rtl8139.gen.base_mem_io + 0x37, 0x4 | 0x8); + had_error = 1; } else { - memcpy(packet_buffer, buf, packet_length); + int packet_length = rx_size - 4; + assert(packet_length <= 2048); + + u8 packet_buffer[packet_length]; + if (ring_offset + rx_size > RTL8139_RXBUFFER_SIZE) { + int end = RTL8139_RXBUFFER_SIZE - ring_offset - 4; + memcpy(packet_buffer, &device_buffer[ring_offset + 4], end); + memcpy(packet_buffer + end, device_buffer, packet_length - end); + } else { + memcpy(packet_buffer, &device_buffer[ring_offset + 4], packet_length); + } + + handle_ethernet(packet_buffer, packet_length); } - // I have no documentation backing this implementation of updating - // the CBR. It is just a (somewhat)uneducated guess. But it does - // seem to work. - u32 old = current_packet_read; - current_packet_read = (current_packet_read + packet_length + 4 + 3) & (~3); - current_packet_read %= RTL8139_RXBUFFER_SIZE; + current_packet_read = (current_packet_read + rx_size + 4 + 3) & (~3); outw(rtl8139.gen.base_mem_io + 0x38, current_packet_read - 0x10); - current_packet_read = inw(rtl8139.gen.base_mem_io + 0x3A); - outw(rtl8139.gen.base_mem_io + 0x38, current_packet_read - 0x10); - - assert(current_packet_read != old); - - handle_ethernet((u8 *)packet_buffer + 4, packet_length); + } + if (had_error) { + current_packet_read = 0; } } void rtl8139_handler(void *regs) { - disable_interrupts(); (void)regs; u16 status = inw(rtl8139.gen.base_mem_io + 0x3e); + outw(rtl8139.gen.base_mem_io + 0x3E, 0x5); if (status & (1 << 2)) { } if (status & (1 << 0)) { handle_packet(); } - outw(rtl8139.gen.base_mem_io + 0x3E, 0x5); EOI(0xB); } @@ -170,8 +169,8 @@ void rtl8139_init(void) { outb(base_address + CMD, 0x10); for (; 0 != (inb(base_address + CMD) & 0x10);) ; - device_buffer = ksbrk(RTL8139_RXBUFFER_SIZE); - memset(device_buffer, 0, RTL8139_RXBUFFER_SIZE); + device_buffer = ksbrk(RTL8139_RXBUFFER_SIZE + 16); + memset(device_buffer, 0, RTL8139_RXBUFFER_SIZE + 16); // Setupt the recieve buffer u32 rx_buffer = (u32)virtual_to_physical(device_buffer, NULL); outl(base_address + RBSTART, rx_buffer); @@ -183,9 +182,11 @@ void rtl8139_init(void) { // Set transmit and reciever enable outb(base_address + 0x37, (1 << 2) | (1 << 3)); + int buffer_length = 3; // 0b11 for 64K + 16 byte + // Configure the recieve buffer outl(base_address + 0x44, - 0xf); // 0xf is AB+AM+APM+AAP + 0xf | (buffer_length << 11)); // 0xf is AB+AM+APM+AAP install_handler((interrupt_handler)rtl8139_handler, INT_32_INTERRUPT_GATE(0x3), 0x20 + interrupt_line); diff --git a/kernel/network/arp.c b/kernel/network/arp.c index 33aaaec..44dfd44 100644 --- a/kernel/network/arp.c +++ b/kernel/network/arp.c @@ -113,8 +113,8 @@ int get_mac_from_ip(const ipv4_t ip, u8 mac[6]) { return 1; } klog(LOG_NOTE, "ARP cache miss"); - enable_interrupts(); send_arp_request(ip); + enable_interrupts(); // TODO: Maybe wait a bit? for (int i = 0; i < 10; i++) { if (0 != memcmp(arp_table[i].ip, &ip, sizeof(u8[4]))) { diff --git a/kernel/network/ethernet.c b/kernel/network/ethernet.c index a4f2f85..deb942b 100644 --- a/kernel/network/ethernet.c +++ b/kernel/network/ethernet.c @@ -60,7 +60,7 @@ void handle_ethernet(const u8 *packet, u64 packet_length) { handle_arp(payload); break; case 0x0800: - handle_ipv4(payload, packet_length - sizeof(struct ETHERNET_HEADER) - 4); + handle_ipv4(payload, packet_length - sizeof(struct ETHERNET_HEADER)); break; default: kprintf("Can't handle ethernet type 0x%x\n", type); diff --git a/kernel/network/ipv4.c b/kernel/network/ipv4.c index cceee8e..8fb983a 100644 --- a/kernel/network/ipv4.c +++ b/kernel/network/ipv4.c @@ -71,6 +71,7 @@ void handle_ipv4(const u8 *payload, u32 packet_length) { u16 calc_checksum = ip_checksum((u8 *)payload, 20); *(u16 *)(payload + 10) = saved_checksum; if (calc_checksum != saved_checksum) { + klog(LOG_WARN, "Invalid ipv4 checksum"); return; } diff --git a/kernel/network/tcp.c b/kernel/network/tcp.c index e42482f..d8538cc 100644 --- a/kernel/network/tcp.c +++ b/kernel/network/tcp.c @@ -2,6 +2,7 @@ #include <cpu/arch_inst.h> #include <drivers/pit.h> #include <fs/vfs.h> +#include <interrupts.h> #include <math.h> #include <network/arp.h> #include <network/bytes.h> @@ -83,7 +84,7 @@ void tcp_send_empty_payload(struct TcpConnection *con, u8 flags) { header.dst_port = htons(con->outgoing_port); header.seq_num = htonl(con->snd_nxt); if (flags & ACK) { - header.ack_num = htonl(con->sent_ack); + header.ack_num = htonl(con->rcv_nxt); } else { header.ack_num = 0; } @@ -110,6 +111,13 @@ void tcp_send_empty_payload(struct TcpConnection *con, u8 flags) { con->snd_max = max(con->snd_nxt, con->snd_max); } +// When both the client and the server have closed the connection it can +// be "destroyed". +void tcp_destroy_connection(struct TcpConnection *con) { + con->state = TCP_STATE_CLOSED; + tcp_remove_connection(con); +} + void tcp_close_connection(struct TcpConnection *con) { if (TCP_STATE_CLOSE_WAIT == con->state) { tcp_send_empty_payload(con, FIN); @@ -127,8 +135,7 @@ void tcp_close_connection(struct TcpConnection *con) { return; } if (TCP_STATE_SYN_SENT == con->state) { - con->state = TCP_STATE_CLOSED; - // TODO: Cleanup + tcp_destroy_connection(con); return; } } @@ -159,7 +166,7 @@ int send_tcp_packet(struct TcpConnection *con, const u8 *payload, header.src_port = htons(con->incoming_port); header.dst_port = htons(con->outgoing_port); header.seq_num = htonl(con->snd_nxt); - header.ack_num = htonl(con->sent_ack); + header.ack_num = htonl(con->rcv_nxt); header.data_offset = 5; header.reserved = 0; header.flags = PSH | ACK; @@ -184,12 +191,13 @@ int send_tcp_packet(struct TcpConnection *con, const u8 *payload, void handle_tcp(ipv4_t src_ip, ipv4_t dst_ip, const u8 *payload, u32 payload_length) { const struct TCP_HEADER *header = (const struct TCP_HEADER *)payload; - u16 tcp_payload_length = payload_length - header->data_offset * sizeof(u32); + u32 tcp_payload_length = payload_length - header->data_offset * sizeof(u32); const u8 *tcp_payload = payload + header->data_offset * sizeof(u32); u16 checksum = tcp_calculate_checksum(src_ip, dst_ip.d, tcp_payload, tcp_payload_length, header, payload_length); if (header->checksum != checksum) { + klog(LOG_WARN, "Bad TCP checksum"); return; } u16 n_src_port = header->src_port; @@ -207,15 +215,22 @@ void handle_tcp(ipv4_t src_ip, ipv4_t dst_ip, const u8 *payload, u16 window_size = htons(n_window_size); struct TcpConnection *con = tcp_find_connection(src_ip, src_port, dst_ip, dst_port); + if (!con) { + return; + } if (con->state == TCP_STATE_LISTEN || con->state == TCP_STATE_SYN_SENT) { con->rcv_nxt = seq_num; } + /* u32 seq_num_end = seq_num + tcp_payload_length - 1; - if (!((con->rcv_nxt <= seq_num) && seq_num < con->rcv_nxt + con->rcv_wnd) && - !((con->rcv_nxt <= seq_num_end) && - seq_num_end < con->rcv_nxt + con->rcv_wnd)) { + int case1 = + (con->rcv_nxt <= seq_num) && (seq_num < (con->rcv_nxt + con->rcv_wnd)); + int case2 = (con->rcv_nxt <= seq_num_end) && + (seq_num_end < (con->rcv_nxt + con->rcv_wnd)); + + if (!case1 && !case2) { kprintf("seq_num: %d\n", seq_num); kprintf("seq_num_end: %d\n", seq_num_end); kprintf("con->rcv_nxt: %d\n", con->rcv_nxt); @@ -224,6 +239,7 @@ void handle_tcp(ipv4_t src_ip, ipv4_t dst_ip, const u8 *payload, kprintf("invalid segment\n"); return; } + */ if (ack_num > con->snd_max) { // TODO: Odd ACK number, what should be done? @@ -237,19 +253,28 @@ void handle_tcp(ipv4_t src_ip, ipv4_t dst_ip, const u8 *payload, return; } + if (con->rcv_nxt != seq_num) { + return; + } + + // kprintf("seq_num: %d rcv_nxt %d\n", seq_num, con->rcv_nxt); + // kprintf("tcp_payload_length: %d\n", tcp_payload_length); + con->snd_wnd = window_size; con->snd_una = ack_num; - con->sent_ack = max(con->sent_ack, seq_num + tcp_payload_length); - con->sent_ack += (FIN & flags)?1:0; - con->sent_ack += (SYN & flags)?1:0; + // con->sent_ack = + // max(con->sent_ack, seq_num + + // min(ringbuffer_unused(&con->incoming_buffer), + // tcp_payload_length)); + con->rcv_nxt += (FIN & flags) ? 1 : 0; + con->rcv_nxt += (SYN & flags) ? 1 : 0; switch (con->state) { case TCP_STATE_LISTEN: { if (SYN & flags) { tcp_send_empty_payload(con, SYN | ACK); con->state = TCP_STATE_SYN_RECIEVED; - con->rcv_nxt++; break; } break; @@ -265,7 +290,6 @@ void handle_tcp(ipv4_t src_ip, ipv4_t dst_ip, const u8 *payload, if ((ACK & flags) && (SYN & flags)) { tcp_send_empty_payload(con, ACK); con->state = TCP_STATE_ESTABLISHED; - con->rcv_nxt++; break; } break; @@ -277,10 +301,12 @@ void handle_tcp(ipv4_t src_ip, ipv4_t dst_ip, const u8 *payload, break; } if (tcp_payload_length > 0) { + if (tcp_payload_length > ringbuffer_unused(&con->incoming_buffer)) { + return; + } int rc = ringbuffer_write(&con->incoming_buffer, tcp_payload, tcp_payload_length); con->rcv_nxt += rc; - con->rcv_wnd = ringbuffer_unused(&con->incoming_buffer); tcp_send_empty_payload(con, ACK); } break; @@ -304,8 +330,7 @@ void handle_tcp(ipv4_t src_ip, ipv4_t dst_ip, const u8 *payload, } case TCP_STATE_LAST_ACK: { if (ACK & flags) { - // TODO: cleanup - con->state = TCP_STATE_CLOSED; + tcp_destroy_connection(con); break; } break; diff --git a/kernel/socket.c b/kernel/socket.c index 5aeeb93..ad774e2 100644 --- a/kernel/socket.c +++ b/kernel/socket.c @@ -15,11 +15,11 @@ #include <sys/socket.h> struct list open_udp_connections; -struct list open_tcp_connections; +struct relist open_tcp_connections; void global_socket_init(void) { list_init(&open_udp_connections); - list_init(&open_tcp_connections); + relist_init(&open_tcp_connections); } OPEN_UNIX_SOCKET *un_sockets[100] = {0}; @@ -31,12 +31,35 @@ void gen_ipv4(ipv4_t *ip, u8 i1, u8 i2, u8 i3, u8 i4) { ip->a[3] = i4; } +void tcp_remove_connection(struct TcpConnection *con) { + // TODO: It should also be freed but I am unsure if a inode might + // still have a pointer(that should not be the case) + for (int i = 0;; i++) { + struct TcpConnection *c; + int end; + if (!relist_get(&open_tcp_connections, i, (void **)&c, &end)) { + if (end) { + break; + } + continue; + } + if (c == con) { + relist_remove(&open_tcp_connections, i); + break; + } + } +} + struct TcpConnection *tcp_find_connection(ipv4_t src_ip, u16 src_port, ipv4_t dst_ip, u16 dst_port) { for (int i = 0;; i++) { struct TcpConnection *c; - if (!list_get(&open_tcp_connections, i, (void **)&c)) { - break; + int end; + if (!relist_get(&open_tcp_connections, i, (void **)&c, &end)) { + if (end) { + break; + } + continue; } if (TCP_STATE_CLOSED == c->state) { continue; @@ -48,6 +71,25 @@ struct TcpConnection *tcp_find_connection(ipv4_t src_ip, u16 src_port, return NULL; } +u16 tcp_find_free_port(u16 suggestion) { + for (int i = 0;; i++) { + struct TcpConnection *c; + int end; + if (!relist_get(&open_tcp_connections, i, (void **)&c, &end)) { + if (end) { + break; + } + continue; + } + if (c->incoming_port == suggestion) { + suggestion++; + // FIXME: Recursion bad + return tcp_find_free_port(suggestion); + } + } + return suggestion; +} + struct UdpConnection *udp_find_connection(ipv4_t src_ip, u16 src_port, u16 dst_port) { (void)src_ip; @@ -106,7 +148,7 @@ int tcp_read(u8 *buffer, u64 offset, u64 len, vfs_fd_t *fd) { u32 rc = ringbuffer_read(&con->incoming_buffer, buffer, len); if (0 == rc && len > 0) { if (TCP_STATE_ESTABLISHED != con->state) { - return 0; + return -ENOTCONN; } return -EWOULDBLOCK; } @@ -239,6 +281,7 @@ void udp_close(vfs_fd_t *fd) { } int connect(int sockfd, const struct sockaddr *addr, socklen_t addrlen) { + kprintf("CONNECT\n"); vfs_fd_t *fd = get_vfs_fd(sockfd, NULL); if (!fd) { return -EBADF; @@ -278,13 +321,13 @@ int tcp_connect(vfs_fd_t *fd, const struct sockaddr *addr, socklen_t addrlen) { const struct sockaddr_in *in_addr = (const struct sockaddr_in *)addr; con->state = TCP_STATE_LISTEN; - con->incoming_port = 1337; // TODO + con->incoming_port = tcp_find_free_port(1337); con->outgoing_ip = in_addr->sin_addr.s_addr; con->outgoing_port = ntohs(in_addr->sin_port); con->max_seg = 1; - con->rcv_wnd = 4096; + con->rcv_wnd = 65535; con->rcv_nxt = 0; con->rcv_adv = 0; @@ -294,24 +337,26 @@ int tcp_connect(vfs_fd_t *fd, const struct sockaddr *addr, socklen_t addrlen) { con->snd_wnd = 0; con->sent_ack = 0; - ringbuffer_init(&con->incoming_buffer, 8192); + ringbuffer_init(&con->incoming_buffer, con->rcv_wnd * 4); ringbuffer_init(&con->outgoing_buffer, 8192); - con->snd_wnd = ringbuffer_unused(&con->incoming_buffer); - list_add(&open_tcp_connections, con, NULL); + con->snd_wnd = ringbuffer_unused(&con->outgoing_buffer); + relist_add(&open_tcp_connections, con, NULL); con->state = TCP_STATE_SYN_SENT; tcp_send_empty_payload(con, SYN); - for (;;) { - switch_task(); - if (TCP_STATE_CLOSED == con->state) { - return -ECONNREFUSED; - } - if (TCP_STATE_ESTABLISHED == con->state) { - break; + if (!(O_NONBLOCK & fd->flags)) { + for (;;) { + switch_task(); + if (TCP_STATE_CLOSED == con->state) { + return -ECONNREFUSED; + } + if (TCP_STATE_ESTABLISHED == con->state) { + break; + } + assert(TCP_STATE_SYN_SENT == con->state || + TCP_STATE_ESTABLISHED == con->state); } - assert(TCP_STATE_SYN_SENT == con->state || - TCP_STATE_ESTABLISHED == con->state); } fd->inode->_has_data = tcp_has_data; @@ -528,7 +573,7 @@ void socket_close(vfs_fd_t *fd) { fd->inode->is_open = 0; } -int tcp_create_fd() { +int tcp_create_fd(int is_nonblock) { struct TcpConnection *con = kmalloc(sizeof(struct TcpConnection)); if (!con) { return -ENOMEM; @@ -544,7 +589,8 @@ int tcp_create_fd() { kfree(con); return -ENOMEM; } - int fd = vfs_create_fd(O_RDWR, 0, 0 /*is_tty*/, inode, NULL); + int fd = vfs_create_fd(O_RDWR | (is_nonblock ? O_NONBLOCK : 0), 0, + 0 /*is_tty*/, inode, NULL); if (fd < 0) { kfree(con); kfree(inode); @@ -598,9 +644,10 @@ int socket(int domain, int type, int protocol) { return n; } if (AF_INET == domain) { - int is_udp = (SOCK_DGRAM == type); + int is_udp = (SOCK_DGRAM & type); + int is_nonblock = (SOCK_NONBLOCK & type); if (!is_udp) { - return tcp_create_fd(); + return tcp_create_fd(is_nonblock); } vfs_inode_t *inode = vfs_create_inode( @@ -612,7 +659,8 @@ int socket(int domain, int type, int protocol) { rc = -ENOMEM; goto socket_error; } - int n = vfs_create_fd(O_RDWR, 0, 0 /*is_tty*/, inode, NULL); + int n = vfs_create_fd(O_RDWR | (is_nonblock ? O_NONBLOCK : 0), 0, + 0 /*is_tty*/, inode, NULL); if (n < 0) { rc = n; goto socket_error; diff --git a/kernel/socket.h b/kernel/socket.h index 0c419fe..454992c 100644 --- a/kernel/socket.h +++ b/kernel/socket.h @@ -15,9 +15,6 @@ typedef int socklen_t; #define AF_INET 1 #define AF_LOCAL AF_UNIX -#define SOCK_DGRAM 0 -#define SOCK_STREAM 1 - #define INADDR_ANY 0 #define MSG_WAITALL 1 @@ -144,4 +141,5 @@ void global_socket_init(void); u16 tcp_get_free_port(void); int setsockopt(int socket, int level, int option_name, const void *option_value, socklen_t option_len); +void tcp_remove_connection(struct TcpConnection *con); #endif diff --git a/userland/libc/include/sys/mman.h b/userland/libc/include/sys/mman.h index 7cba68f..d60997d 100644 --- a/userland/libc/include/sys/mman.h +++ b/userland/libc/include/sys/mman.h @@ -1,14 +1,16 @@ #ifndef MMAP_H #define MMAP_H -#include <stdint.h> #include <stddef.h> +#include <stdint.h> + +#define MAP_FAILED ((void *)-1) #define PROT_READ (1 << 0) #define PROT_WRITE (1 << 1) -#define MAP_PRIVATE (1 << 0) -#define MAP_ANONYMOUS (1<< 1) -#define MAP_SHARED (1<< 2) +#define MAP_PRIVATE (1 << 0) +#define MAP_ANONYMOUS (1 << 1) +#define MAP_SHARED (1 << 2) void *mmap(void *addr, size_t length, int prot, int flags, int fd, size_t offset); diff --git a/userland/libc/netdb/getaddrinfo.c b/userland/libc/netdb/getaddrinfo.c index d550b39..f7e89c0 100644 --- a/userland/libc/netdb/getaddrinfo.c +++ b/userland/libc/netdb/getaddrinfo.c @@ -1,14 +1,12 @@ -#include <assert.h> -#include <stdint.h> -#include <stdlib.h> -#include <tb/sv.h> #include <arpa/inet.h> #include <assert.h> #include <ctype.h> +#include <stdint.h> #include <stdio.h> #include <stdlib.h> #include <string.h> #include <sys/socket.h> +#include <tb/sv.h> #include <unistd.h> #define PORT 53 @@ -137,6 +135,7 @@ int getaddrinfo(const char *restrict node, const char *restrict service, close(sockfd); return EAI_FAIL; } + close(sockfd); if (0 != memcmp(buffer, &id, sizeof(u16))) { close(sockfd); @@ -163,7 +162,7 @@ int getaddrinfo(const char *restrict node, const char *restrict service, } u16 ancount = ntohs(*(u16 *)(buffer + sizeof(u16) * 3)); - if (1 != ancount) { + if (0 == ancount) { close(sockfd); return EAI_NONAME; // TODO: Check if this is correct } @@ -192,44 +191,52 @@ int getaddrinfo(const char *restrict node, const char *restrict service, } */ - // type - u16 type = ntohs(*(u16 *)(answer + sizeof(u16))); - u16 class = ntohs(*(u16 *)(answer + sizeof(u16) * 2)); - if (1 != type) { - close(sockfd); - return EAI_FAIL; - } - if (1 != class) { - close(sockfd); - return EAI_FAIL; - } - - u16 rdlength = ntohs(*(u16 *)(answer + sizeof(u16) * 3 + sizeof(u32))); + for (u16 i = 0; i < answer_length;) { + // type + u16 type = ntohs(*(u16 *)(answer + sizeof(u16))); + u16 class = ntohs(*(u16 *)(answer + sizeof(u16) * 2)); + int ignore = 0; + u16 inc = 0; + if (1 != type) { + ignore = 1; + } + if (1 != class) { + ignore = 1; + } + inc += sizeof(u16) * 3; - if (4 != rdlength) { - close(sockfd); - return EAI_FAIL; - } - u32 ip = *(u32 *)(answer + sizeof(u16) * 4 + sizeof(u32)); + u16 rdlength = ntohs(*(u16 *)(answer + sizeof(u16) * 3 + sizeof(u32))); + inc += sizeof(u16); + inc += rdlength; + inc += sizeof(u32); - if (res) { - *res = calloc(1, sizeof(struct addrinfo)); - if (!(*res)) { - close(sockfd); - return EAI_MEMORY; + if (4 != rdlength) { + ignore = 1; } - struct sockaddr_in *sa = calloc(1, sizeof(struct sockaddr_in)); - if (!sa) { - free(*res); - close(sockfd); - return EAI_MEMORY; + if (!ignore) { + u32 ip = *(u32 *)(answer + sizeof(u16) * 4 + sizeof(u32)); + + if (res) { + *res = calloc(1, sizeof(struct addrinfo)); + if (!(*res)) { + close(sockfd); + return EAI_MEMORY; + } + struct sockaddr_in *sa = calloc(1, sizeof(struct sockaddr_in)); + if (!sa) { + free(*res); + close(sockfd); + return EAI_MEMORY; + } + (*res)->ai_addr = (struct sockaddr *)sa; + sa->sin_addr.s_addr = ip; + sa->sin_port = service_to_port(service); + sa->sin_family = AF_INET; + } } - (*res)->ai_addr = (struct sockaddr *)sa; - sa->sin_addr.s_addr = ip; - sa->sin_port = service_to_port(service); - sa->sin_family = AF_INET; + i += inc; + answer += inc; } - close(sockfd); return 0; } |