diff options
Diffstat (limited to 'kernel/network')
| -rw-r--r-- | kernel/network/tcp.c | 166 | ||||
| -rw-r--r-- | kernel/network/tcp.h | 7 | ||||
| -rw-r--r-- | kernel/network/udp.c | 42 | 
3 files changed, 127 insertions, 88 deletions
| diff --git a/kernel/network/tcp.c b/kernel/network/tcp.c index f45d1af..717c7db 100644 --- a/kernel/network/tcp.c +++ b/kernel/network/tcp.c @@ -3,7 +3,6 @@  #include <network/bytes.h>  #include <network/ipv4.h>  #include <network/udp.h> -#include <socket.h>  extern u8 ip_address[4];  #define CWR (1 << 7) @@ -49,6 +48,16 @@ struct __attribute__((__packed__)) PSEUDO_TCP_HEADER {    u16 urgent_pointer;  }; +void tcp_wait_reply(struct TcpConnection *con) { +  for (;;) { +    if (con->unhandled_packet) { +      return; +    } +    // TODO: Make the scheduler halt the process +    switch_task(); +  } +} +  u16 tcp_checksum(u16 *buffer, int size) {    unsigned long cksum = 0;    while (size > 1) { @@ -64,11 +73,11 @@ u16 tcp_checksum(u16 *buffer, int size) {    return (u16)(~cksum);  } -void tcp_calculate_checksum(u8 src_ip[4], u8 dst_ip[4], const u8 *payload, +void tcp_calculate_checksum(u8 src_ip[4], u32 dst_ip, const u8 *payload,                              u16 payload_length, struct TCP_HEADER *header) {    struct PSEUDO_TCP_HEADER ps = {0};    memcpy(&ps.src_addr, src_ip, sizeof(u32)); -  memcpy(&ps.dst_addr, dst_ip, sizeof(u32)); +  memcpy(&ps.dst_addr, &dst_ip, sizeof(u32));    ps.protocol = 6;    ps.tcp_length = htons(20 + payload_length);    ps.src_port = header->src_port; @@ -88,80 +97,145 @@ void tcp_calculate_checksum(u8 src_ip[4], u8 dst_ip[4], const u8 *payload,    header->checksum = tcp_checksum((u16 *)buffer, buffer_length);  } -void tcp_close_connection(struct INCOMING_TCP_CONNECTION *inc) { +void tcp_send_empty_payload(struct TcpConnection *con, u8 flags) {    struct TCP_HEADER header = {0}; -  header.src_port = htons(inc->dst_port); -  header.dst_port = inc->n_port; -  header.seq_num = htonl(inc->seq_num); -  header.ack_num = htonl(inc->ack_num); +  header.src_port = htons(con->incoming_port); +  header.dst_port = htons(con->outgoing_port); +  header.seq_num = htonl(con->seq); +  header.ack_num = htonl(con->ack);    header.data_offset = 5;    header.reserved = 0; -  header.flags = FIN | ACK; +  header.flags = flags;    header.window_size = htons(WINDOW_SIZE);    header.urgent_pointer = 0; -  u32 dst_ip; -  memcpy(&dst_ip, inc->ip, sizeof(dst_ip)); +    u8 payload[0];    u16 payload_length = 0; -  tcp_calculate_checksum(ip_address, inc->ip, (const u8 *)payload, +  tcp_calculate_checksum(ip_address, con->outgoing_ip, (const u8 *)payload,                           payload_length, &header);    int send_len = sizeof(header) + payload_length;    u8 send_buffer[send_len];    memcpy(send_buffer, &header, sizeof(header));    memcpy(send_buffer + sizeof(header), payload, payload_length); -  send_ipv4_packet(dst_ip, 6, send_buffer, send_len); + +  send_ipv4_packet(con->outgoing_ip, 6, send_buffer, send_len); +} + +void tcp_send_ack(struct TcpConnection *con) { +  tcp_send_empty_payload(con, ACK);  } -void send_tcp_packet(struct INCOMING_TCP_CONNECTION *inc, u8 *payload, +void tcp_send_syn(struct TcpConnection *con) { +  tcp_send_empty_payload(con, SYN); +  con->seq++; +} + +// void send_tcp_packet(struct INCOMING_TCP_CONNECTION *inc, const u8 *payload, +//                      u16 payload_length) { +void send_tcp_packet(struct TcpConnection *con, const u8 *payload,                       u16 payload_length) {    if (payload_length > 1500 - 20 - sizeof(struct TCP_HEADER)) { -    send_tcp_packet(inc, payload, 1500 - 20 - sizeof(struct TCP_HEADER)); +    send_tcp_packet(con, payload, 1500 - 20 - sizeof(struct TCP_HEADER));      payload_length -= 1500 - 20 - sizeof(struct TCP_HEADER);      payload += 1500 - 20 - sizeof(struct TCP_HEADER); -    return send_tcp_packet(inc, payload, payload_length); +    return send_tcp_packet(con, payload, payload_length);    }    struct TCP_HEADER header = {0}; -  header.src_port = htons(inc->dst_port); -  header.dst_port = inc->n_port; -  header.seq_num = htonl(inc->seq_num); -  header.ack_num = htonl(inc->ack_num); +  header.src_port = htons(con->incoming_port); +  header.dst_port = htons(con->outgoing_port); +  header.seq_num = htonl(con->seq); +  header.ack_num = htonl(con->ack);    header.data_offset = 5;    header.reserved = 0;    header.flags = PSH | ACK;    header.window_size = htons(WINDOW_SIZE);    header.urgent_pointer = 0; -  u32 dst_ip; -  memcpy(&dst_ip, inc->ip, sizeof(dst_ip)); -  tcp_calculate_checksum(ip_address, inc->ip, (const u8 *)payload, +  tcp_calculate_checksum(ip_address, con->outgoing_ip, (const u8 *)payload,                           payload_length, &header);    int send_len = sizeof(header) + payload_length;    u8 send_buffer[send_len];    memcpy(send_buffer, &header, sizeof(header));    memcpy(send_buffer + sizeof(header), payload, payload_length); -  send_ipv4_packet(dst_ip, 6, send_buffer, send_len); +  send_ipv4_packet(con->outgoing_ip, 6, send_buffer, send_len); -  inc->seq_num += payload_length; +  con->seq += payload_length;  } -void send_empty_tcp_message(struct INCOMING_TCP_CONNECTION *inc, u8 flags, -                            u32 inc_seq_num, u16 n_dst_port, u16 n_src_port) { -  struct TCP_HEADER header = {0}; -  header.src_port = n_dst_port; -  header.dst_port = n_src_port; -  header.seq_num = 0; -  header.ack_num = htonl(inc_seq_num + 1); -  header.data_offset = 5; -  header.reserved = 0; -  header.flags = flags; -  header.window_size = htons(WINDOW_SIZE); -  header.urgent_pointer = 0; -  char payload[0]; -  tcp_calculate_checksum(ip_address, inc->ip, (const u8 *)payload, 0, &header); -  u32 dst_ip; -  memcpy(&dst_ip, inc->ip, sizeof(dst_ip)); -  send_ipv4_packet(dst_ip, 6, (const u8 *)&header, sizeof(header)); -} +void handle_tcp(u8 src_ip[4], const u8 *payload, u32 payload_length) { +  const struct TCP_HEADER *header = (const struct TCP_HEADER *)payload; +  (void)header; +  u16 n_src_port = *(u16 *)(payload); +  u16 n_dst_port = *(u16 *)(payload + 2); +  u32 n_seq_num = *(u32 *)(payload + 4); +  u32 n_ack_num = *(u32 *)(payload + 8); + +  //  u8 flags = *(payload + 13); +  u8 flags = header->flags; +  u16 src_port = htons(n_src_port); +  (void)src_port; +  u16 dst_port = htons(n_dst_port); +  u32 seq_num = htonl(n_seq_num); +  u32 ack_num = htonl(n_ack_num); +  (void)ack_num; + +  if (SYN == flags) { +    u32 t; +    memcpy(&t, src_ip, sizeof(u8[4])); +    struct TcpConnection *con = internal_tcp_incoming(t, src_port, 0, dst_port); +    assert(con); +    con->ack = seq_num + 1; +    tcp_send_empty_payload(con, SYN | ACK); +    return; +  } + +  struct TcpConnection *incoming_connection = tcp_find_connection(dst_port); +  if (incoming_connection) { +    incoming_connection->unhandled_packet = 1; +    if (0 != (flags & RST)) { +      klog("Requested port is closed", LOG_NOTE); +      incoming_connection->dead = 1; +      return; +    } +    if (ACK == flags) { +      if (0 == incoming_connection->handshake_state) { +        // Then it is probably a response to the SYN|ACK we sent. +        incoming_connection->handshake_state = 1; +        return; +      } +    } +    if ((SYN | ACK) == flags) { +      assert(0 == incoming_connection->handshake_state); +      incoming_connection->handshake_state = 1; + +      incoming_connection->ack = seq_num + 1; + +      tcp_send_ack(incoming_connection); +    } +    if (0 != (flags & PSH)) { +      u16 tcp_payload_length = +          payload_length - header->data_offset * sizeof(u32); +      int len = fifo_object_write( +          (u8 *)(payload + header->data_offset * sizeof(u32)), 0, +          tcp_payload_length, incoming_connection->data_file); +      assert(len >= 0); +      incoming_connection->ack += len; +      tcp_send_ack(incoming_connection); +    } +    if (0 != (flags & FIN)) { +      incoming_connection->ack++; + +      tcp_send_empty_payload(incoming_connection, FIN | ACK); + +      incoming_connection->dead = 1; // FIXME: It should wait for a ACK +                                     // of the FIN before the connection +                                     // is closed. +    } +  } else { +    assert(NULL); +  } +} +/*  void handle_tcp(u8 src_ip[4], const u8 *payload, u32 payload_length) {    const struct TCP_HEADER *inc_header = (const struct TCP_HEADER *)payload;    u16 n_src_port = *(u16 *)(payload); @@ -233,11 +307,11 @@ void handle_tcp(u8 src_ip[4], const u8 *payload, u32 payload_length) {      header.flags = ACK;      header.window_size = htons(WINDOW_SIZE);      header.urgent_pointer = 0; -    char payload[0]; -    tcp_calculate_checksum(ip_address, src_ip, (const u8 *)payload, 0, &header);      u32 dst_ip;      memcpy(&dst_ip, src_ip, sizeof(dst_ip)); +    char payload[0]; +    tcp_calculate_checksum(ip_address, dst_ip, (const u8 *)payload, 0, &header);      send_ipv4_packet(dst_ip, 6, (const u8 *)&header, sizeof(header));      return;    } -} +}*/ diff --git a/kernel/network/tcp.h b/kernel/network/tcp.h index 2a836a4..0f9e818 100644 --- a/kernel/network/tcp.h +++ b/kernel/network/tcp.h @@ -1,4 +1,7 @@ +#include <socket.h> +void tcp_send_syn(struct TcpConnection *con); +void tcp_wait_reply(struct TcpConnection *con);  void handle_tcp(u8 src_ip[4], const u8 *payload, u32 payload_length); -void send_tcp_packet(struct INCOMING_TCP_CONNECTION *inc, u8 *payload, +void send_tcp_packet(struct TcpConnection *con, const u8 *payload,                       u16 payload_length); -void tcp_close_connection(struct INCOMING_TCP_CONNECTION *s); +  void tcp_close_connection(struct INCOMING_TCP_CONNECTION * s); diff --git a/kernel/network/udp.c b/kernel/network/udp.c index 5aae050..4f3848a 100644 --- a/kernel/network/udp.c +++ b/kernel/network/udp.c @@ -20,44 +20,6 @@ void send_udp_packet(struct sockaddr_in *src, const struct sockaddr_in *dst,  }  void handle_udp(u8 src_ip[4], const u8 *payload, u32 packet_length) { -  assert(packet_length >= 8); -  // n_.* means network format(big endian) -  // h_.* means host format((probably) little endian) -  u16 n_source_port = *(u16 *)payload; -  u16 h_source_port = ntohs(n_source_port); -  (void)h_source_port; -  u16 h_dst_port = ntohs(*(u16 *)(payload + 2)); -  u16 h_length = ntohs(*(u16 *)(payload + 4)); -  assert(h_length == packet_length); -  u16 data_length = h_length - 8; -  const u8 *data = payload + 8; - -  // Find the open port -  OPEN_INET_SOCKET *in_s = find_open_udp_port(htons(h_dst_port)); -  assert(in_s); -  SOCKET *s = in_s->s; -  vfs_fd_t *fifo_file = s->ptr_socket_fd; - -  // Write the sockaddr struct such that it can later be -  // given to userland if asked. -  struct sockaddr_in /*{ -    sa_family_t sin_family; -    union { -      u32 s_addr; -    } sin_addr; -    u16 sin_port; -  }*/ in; -  in.sin_family = AF_INET; -  memcpy(&in.sin_addr.s_addr, src_ip, sizeof(u32)); -  in.sin_port = n_source_port; -  socklen_t sock_length = sizeof(struct sockaddr_in); - -  raw_vfs_pwrite(fifo_file, &sock_length, sizeof(sock_length), 0); -  raw_vfs_pwrite(fifo_file, &in, sizeof(in), 0); - -  // Write the UDP payload length(not including header) -  raw_vfs_pwrite(fifo_file, &data_length, sizeof(u16), 0); - -  // Write the UDP payload -  raw_vfs_pwrite(fifo_file, (char *)data, data_length, 0); +  // TODO: Reimplement +  assert(NULL);  } |