diff --git a/crates/kernel/examples/net.rs b/crates/kernel/examples/net.rs index 29ecdb55..44a8c4bd 100644 --- a/crates/kernel/examples/net.rs +++ b/crates/kernel/examples/net.rs @@ -4,17 +4,27 @@ extern crate alloc; extern crate kernel; -use core::net::Ipv4Addr; - -use kernel::{device::usb::device::net::get_dhcpd_mut, event::{task, thread}, networking::{iface::icmp, repr::{IcmpPacket, Ipv4Address}, socket::RawSocket}, ringbuffer}; +#[allow(unused_imports)] +use kernel::{ + device::usb::device::net::get_dhcpd_mut, + event::{task, thread}, + networking::{ + iface::icmp, + repr::{HttpMethod, HttpPacket, IcmpPacket, Ipv4Address}, + socket::RawSocket, + Result, + }, + ringbuffer, +}; #[allow(unused_imports)] use kernel::networking::socket::{ - bind, accept, listen, connect, recv_from, send_to, SocketAddr, TcpSocket, UdpSocket, + accept, bind, close, connect, listen, recv_from, send_to, SocketAddr, TcpSocket, UdpSocket, }; -use kernel::networking::Result; use kernel::*; +use alloc::string::String; + #[no_mangle] extern "Rust" fn kernel_main(_device_tree: device_tree::DeviceTree) { println!("| starting kernel_main"); @@ -28,22 +38,25 @@ async fn main() { println!("starting dhcpd"); let dhcpd = get_dhcpd_mut(); - dhcpd.start().await; + let _ = dhcpd.start().await; println!("out of dhcpd"); - // [udp send test] + // // [udp send test] // println!("udp send test"); // let s = UdpSocket::new(); // let saddr = SocketAddr { - // addr: Ipv4Address::new([11, 187, 10, 102]), + // addr: Ipv4Address::new([10, 0, 2, 2]), // port: 1337, // }; // for _i in 0..5 { - // let _ = send_to(s, "hello everynyan".as_bytes().to_vec(), saddr).await; + // let _ = send_to(s, "hello everynyan\n".as_bytes().to_vec(), saddr).await; // } // println!("end udp send test"); + // for _i in 0..5 { + // sync::spin_sleep(500_000); + // } // [udp listening test] // println!("udp listening test"); @@ -61,11 +74,10 @@ async fn main() { // // println!("end udp listening test"); - // [tcp send test] // println!("tcp send test"); // let saddr = SocketAddr { - // addr: Ipv4Address::new([11, 187, 10, 102]), + // addr: Ipv4Address::new([10, 0, 2, 2]), // port: 1337, // }; // @@ -75,35 +87,60 @@ async fn main() { // Err(_) => println!("couldn't connect"), // }; // - // for _i in 0..5 { - // let _ = send_to(s, "hello everynyan".as_bytes().to_vec(), saddr); + // for _i in 0..100 { + // let _ = send_to(s, "hello everynyan\n".as_bytes().to_vec(), saddr).await; // } + // + // close(s).await; // println!("tcp send test end"); - // [tcp recv test] - let s = TcpSocket::new(); + // let s = TcpSocket::new(); + // + // bind(s, 22); - bind(s, 22); - listen(s, 1).await; // has a timeout, we will wait for 5 seconds + // listen(s, 1).await; + // + // let clientfd = accept(s).await; + // + // let mut tot = 0; + // while let recv = recv_from(*clientfd.as_ref().unwrap()).await { + // if let Ok((payload, senderaddr)) = recv { + // println!("got message: {:x?}", payload); + // tot += payload.len() + // } else { + // println!("\t[!] got a fin, ended"); + // break; + // } + // } + // + // println!("got {} bytes", tot); - let clientfd = accept(s).await; + // [http request test] + println!("http send test"); + // let host = "http.badssl.com"; + // let host = "http-textarea.badssl.com"; + // let host = "httpforever.com"; + let host = "neverssl.com"; + let saddr = SocketAddr::resolve(host, 80).await; + + let s = TcpSocket::new(); + match connect(s, saddr).await { + Ok(_) => (), + Err(_) => println!("couldn't connect"), + }; + let path = "/"; + let http_req = HttpPacket::new(HttpMethod::Get, host, path); + let _ = send_to(s, http_req.serialize(), saddr).await; - while let recv = recv_from(*clientfd.as_ref().unwrap()).await { - if let Ok((payload, senderaddr)) = recv { - println!("got message: {:x?}", payload); - } else { - break; - } - } + let (resp, _) = recv_from(s).await.unwrap(); - // there is a delay when calling NetSend on a packet, this loop is to allow all the packets to - // drain out - for i in 0..32 { - sync::spin_sleep(500_000); - } + let _ = close(s).await; + println!("response:\n{:?}", resp); + println!("response:\n{:?}", String::from_utf8(resp)); + println!("http send test end"); shutdown(); } diff --git a/crates/kernel/src/device/usb/device/net.rs b/crates/kernel/src/device/usb/device/net.rs index 7a6a9984..aae2bb8f 100644 --- a/crates/kernel/src/device/usb/device/net.rs +++ b/crates/kernel/src/device/usb/device/net.rs @@ -150,7 +150,7 @@ pub fn NetAttach(device: &mut UsbDevice, interface_number: u32) -> ResultCode { // begin receieve series, this queues a receive to be ran which will eventually propogate back // to us through the rgistered `recv` function which then queues another receive - let buf = vec![0u8; 1500]; + let buf = vec![0u8; 1600]; unsafe { rndis_receive_packet(device, buf.into_boxed_slice(), 1500); // TODO: ask aaron if I need to use another function? } @@ -204,6 +204,12 @@ pub unsafe fn NetReceive(buffer: *mut u8, buffer_length: u32) { println!("| Net: No callback for receive."); } } + + let buf = vec![0u8; 1]; + unsafe { + let device = &mut *NET_DEVICE.device.unwrap(); + rndis_receive_packet(device, buf.into_boxed_slice(), 1600); + } } pub fn RegisterNetReceiveCallback(callback: unsafe fn(*mut u8, u32)) { diff --git a/crates/kernel/src/device/usb/device/rndis.rs b/crates/kernel/src/device/usb/device/rndis.rs index 2a814185..24063e89 100644 --- a/crates/kernel/src/device/usb/device/rndis.rs +++ b/crates/kernel/src/device/usb/device/rndis.rs @@ -29,7 +29,8 @@ pub fn rndis_initialize_msg(device: &mut UsbDevice) -> ResultCode { request_id: 0, major_version: 1, minor_version: 0, - max_transfer_size: 0x4000, + // max_transfer_size: 0x4000, + max_transfer_size: 1540, }; let mut buffer_req = [0u8; 52]; diff --git a/crates/kernel/src/device/usb/usbd/endpoint.rs b/crates/kernel/src/device/usb/usbd/endpoint.rs index 6820e204..381e1ae9 100644 --- a/crates/kernel/src/device/usb/usbd/endpoint.rs +++ b/crates/kernel/src/device/usb/usbd/endpoint.rs @@ -8,9 +8,13 @@ */ use crate::device::usb; +use crate::device::system_timer::micro_delay; +use crate::device::usb::hcd::dwc::dwc_otg; use crate::device::usb::hcd::dwc::dwc_otg::DwcActivateChannel; use crate::device::usb::hcd::dwc::dwc_otg::UpdateDwcOddFrame; +use crate::device::usb::hcd::dwc::dwc_otgreg::DOTG_HCINT; use crate::device::usb::hcd::dwc::dwc_otgreg::HCINT_FRMOVRUN; +use crate::device::usb::DwcDisableChannel; use crate::device::usb::UsbSendInterruptMessage; use usb::dwc_hub; use usb::hcd::dwc::dwc_otg::HcdUpdateTransferSize; @@ -19,18 +23,16 @@ use usb::types::*; use usb::usbd::device::*; use usb::usbd::pipe::*; use usb::PacketId; -use crate::device::usb::hcd::dwc::dwc_otg; -use crate::device::usb::hcd::dwc::dwc_otgreg::DOTG_HCINT; -use crate::device::usb::DwcDisableChannel; -use crate::device::system_timer::micro_delay; - use crate::event::task::spawn_async_rt; -use crate::shutdown; use crate::sync::time::{interval, MissedTicks}; use alloc::boxed::Box; +// static mut NET_BUFFER_CUR_LEN: u32 = 0; +static mut NET_BUFFER_LEN: u32 = 0; +static mut NET_BUFFER_ACTIVE: bool = false; + pub fn finish_bulk_endpoint_callback_in( endpoint: endpoint_descriptor, hcint: u32, @@ -39,7 +41,7 @@ pub fn finish_bulk_endpoint_callback_in( let device = unsafe { &mut *endpoint.device }; let transfer_size = HcdUpdateTransferSize(device, channel); - let last_transfer = endpoint.buffer_length - transfer_size; + let mut last_transfer = endpoint.buffer_length - transfer_size; let endpoint_device = device.driver_data.downcast::().unwrap(); if hcint & HCINT_NAK != 0 { @@ -55,8 +57,7 @@ pub fn finish_bulk_endpoint_callback_in( channel, hcint, last_transfer ); - if last_transfer > 0 && (hcint & HCINT_CHHLTD == 0) && (hcint & HCINT_XFERCOMPL == 0) - { + if last_transfer > 0 && (hcint & HCINT_CHHLTD == 0) && (hcint & HCINT_XFERCOMPL == 0) { // DwcActivateChannel(channel); return false; @@ -68,8 +69,8 @@ pub fn finish_bulk_endpoint_callback_in( // return true; } } - // return; // WARN: aaron said to comment this out - + // return; // WARN: aaron said to comment this out + if hcint & HCINT_CHHLTD == 0 { panic!( "| Endpoint {} in: HCINT_CHHLTD not set, aborting. hcint: {:x} last transfer: {}", @@ -91,6 +92,45 @@ pub fn finish_bulk_endpoint_callback_in( // core::ptr::copy_nonoverlapping(dma_addr as *const u8, buffer, buffer_length as usize); // } + //assume rndis net bulk in + unsafe { + if !NET_BUFFER_ACTIVE { + use alloc::slice; + // let slice: &[u8] = unsafe { slice::from_raw_parts(dma_addr as *const u8, 16 as usize) }; + let slice32: &[u32] = slice::from_raw_parts(dma_addr as *const u32, 4 as usize); + //print slice + // println!("| Net buffer: {:?}", slice); + // println!("| Net buffer 32: {:?}", slice32); + let _buffer32 = dma_addr as *const u32; + + let rndis_len = slice32[3]; + // let part1 = unsafe { buffer32.offset(0) } as u32; + // println!("| rndis 1 {}", part1); + // println!( + // "| Net buffer length: {} rndis_len: {}", + // last_transfer, rndis_len + // ); + if rndis_len > last_transfer - 44 { + NET_BUFFER_ACTIVE = true; + NET_BUFFER_LEN = rndis_len; + //reenable channel + DwcActivateChannel(channel); + return false; + } + // println!("| NEt continue"); + } else { + if last_transfer >= NET_BUFFER_LEN { + // println!("| NEt buffer finished length: {} NETBUFFER {}", last_transfer, NET_BUFFER_LEN); + NET_BUFFER_ACTIVE = false; + last_transfer = NET_BUFFER_LEN; + } else { + // println!("| Net buffer not yet active length: {} NETBUFFER {}", last_transfer, NET_BUFFER_LEN); + DwcActivateChannel(channel); + return false; + } + } + } + //TODO: Perhaps update this to pass the direct dma buffer address instead of copying // as it is likely that the callback will need to copy the data anyway // Also, we suffer issue from buffer_length not being known before the copy so the callback likely will have better information about the buffer @@ -115,16 +155,24 @@ pub fn finish_bulk_endpoint_callback_out( let transfer_size = HcdUpdateTransferSize(device, channel); let last_transfer = endpoint.buffer_length - transfer_size; - println!("Bulk out transfer hcint {:x} , last transfer: {} ", hcint, last_transfer); + println!( + "Bulk out transfer hcint {:x} , last transfer: {} ", + hcint, last_transfer + ); if hcint & HCINT_CHHLTD == 0 { - panic!("| Endpoint {}: HCINT_CHHLTD not set, aborting. bulk out hcint {:x}", channel, hcint); + panic!( + "| Endpoint {}: HCINT_CHHLTD not set, aborting. bulk out hcint {:x}", + channel, hcint + ); } if hcint & HCINT_XFERCOMPL == 0 { - panic!("| Endpoint {}: HCINT_XFERCOMPL not set, aborting. bulk out hcint {:x}", channel, hcint); + panic!( + "| Endpoint {}: HCINT_XFERCOMPL not set, aborting. bulk out hcint {:x}", + channel, hcint + ); } - //Most Likely not going to be called but could be useful for cases where precise timing of when message gets off the system is needed let endpoint_device = device.driver_data.downcast::().unwrap(); if let Some(callback) = endpoint_device.endpoints[endpoint.device_endpoint_number as usize] { @@ -178,7 +226,6 @@ pub fn finish_interrupt_endpoint_callback( // return true; } - hcint |= hcint_nochhltd; } diff --git a/crates/kernel/src/networking/iface/dhcp.rs b/crates/kernel/src/networking/iface/dhcp.rs index 12e62051..91c34ff7 100644 --- a/crates/kernel/src/networking/iface/dhcp.rs +++ b/crates/kernel/src/networking/iface/dhcp.rs @@ -1,6 +1,4 @@ use crate::device::system_timer; -use crate::ringbuffer::channel; -use crate::event::task::spawn_async; use crate::device::usb::device::net::get_interface_mut; use crate::networking::iface::Interface; @@ -43,7 +41,7 @@ pub struct Dhcpd { rebind_time: Option, subnet_mask: u32, router: Option, - dns_servers: Vec, + pub dns_servers: Vec, udp_socket: u16, } @@ -89,12 +87,12 @@ impl Dhcpd { self.state = DhcpState::Discovering; self.last_action_time = time; - send_dhcp_discover(interface, self.udp_socket, self.xid).await; + let _ = send_dhcp_discover(interface, self.udp_socket, self.xid).await; while self.state != DhcpState::Bound { let r = recv_from(self.udp_socket).await; let (payload, _) = r.unwrap(); - self.process_dhcp_packet(interface, payload).await; + let _ = self.process_dhcp_packet(interface, payload).await; } Ok(()) @@ -110,7 +108,8 @@ impl Dhcpd { if let (Some(server_id), Some(offered_ip)) = (self.server_identifier, self.offered_ip) { let result = - send_dhcp_release(interface, self.xid, offered_ip, server_id, self.udp_socket).await; + send_dhcp_release(interface, self.xid, offered_ip, server_id, self.udp_socket) + .await; self.state = DhcpState::Released; result } else { @@ -145,7 +144,8 @@ impl Dhcpd { self.last_action_time = system_timer::get_time(); self.retries = DEFAULT_LEASE_RETRY; - send_dhcp_request(interface, self.xid, offered_ip, server_id, self.udp_socket).await?; + send_dhcp_request(interface, self.xid, offered_ip, server_id, self.udp_socket) + .await?; // send_dhcp_packet_workaround(interface, self.xid, offered_ip, server_id, packet)?; } else { return Err(Error::Malformed); @@ -195,10 +195,11 @@ impl Dhcpd { self.last_action_time = system_timer::get_time(); println!( - "\t[+] DHCP: Bound to IP {} with lease time {} seconds on gateway {}", + "\t[+] DHCP: Bound to IP {} with lease time {} seconds on gateway {} with dns servers {:?}", interface.ipv4_addr, self.lease_time.unwrap_or(0), interface.default_gateway, + self.dns_servers ); } (DhcpState::Requesting, DhcpMessageType::Nak) @@ -404,7 +405,11 @@ pub async fn send_dhcp_release( send_dhcp_packet_unicast(interface, socketfd, &packet, server_id).await } -async fn send_dhcp_packet(interface: &mut Interface, socketfd: u16, packet: &DhcpPacket) -> Result<()> { +async fn send_dhcp_packet( + interface: &mut Interface, + socketfd: u16, + packet: &DhcpPacket, +) -> Result<()> { let data = packet.serialize(); let saddr = SocketAddr { addr: interface.ipv4_addr.broadcast(), diff --git a/crates/kernel/src/networking/iface/ethernet.rs b/crates/kernel/src/networking/iface/ethernet.rs index 6a85c295..80d011e6 100644 --- a/crates/kernel/src/networking/iface/ethernet.rs +++ b/crates/kernel/src/networking/iface/ethernet.rs @@ -31,9 +31,12 @@ pub fn send_ethernet_frame( Ok(()) } +// pub static mut FRAME: Vec = Vec::new(); +// pub static mut LEFT: u32 = 0; + // recv ethernet frame from interface: parsed -> fwd to socket -> propogated up stack pub fn recv_ethernet_frame(interface: &mut Interface, eth_buffer: &[u8], _len: u32) -> Result<()> { - println!("[!] received ethernet frame"); + // println!("[!] received ethernet frame"); // println!("\t{:x?}", ð_buffer[44..]); // we will truncate the first 44 bytes from the RNDIS protocol @@ -54,13 +57,13 @@ pub fn recv_ethernet_frame(interface: &mut Interface, eth_buffer: &[u8], _len: u }; // queue another recv to be run in the future - thread::thread(move || { - let buf = vec![0u8; 1500]; - unsafe { - let device = &mut *NET_DEVICE.device.unwrap(); - rndis_receive_packet(device, buf.into_boxed_slice(), 1500); - } - }); + // thread::thread(move || { + // let buf = vec![0u8; 1500]; + // unsafe { + // let device = &mut *NET_DEVICE.device.unwrap(); + + // } + // }); return result; } diff --git a/crates/kernel/src/networking/iface/ipv4.rs b/crates/kernel/src/networking/iface/ipv4.rs index 88e277fa..f23d5ce5 100644 --- a/crates/kernel/src/networking/iface/ipv4.rs +++ b/crates/kernel/src/networking/iface/ipv4.rs @@ -41,7 +41,7 @@ pub fn send_ipv4_packet( } pub fn recv_ip_packet(interface: &mut Interface, eth_frame: EthernetFrame) -> Result<()> { - println!("[!] received IP packet"); + // println!("[!] received IP packet"); let ipv4_packet = Ipv4Packet::deserialize(eth_frame.payload.as_slice())?; if !ipv4_packet.is_valid_checksum() { return Err(Error::Checksum); @@ -57,7 +57,6 @@ pub fn recv_ip_packet(interface: &mut Interface, eth_frame: EthernetFrame) -> Re return Err(Error::Ignored); } - // update arp cache for immediate ICMP echo replies, errors, etc. if eth_frame.src.is_unicast() { let mut arp_cache = interface.arp_cache.lock(); diff --git a/crates/kernel/src/networking/iface/mod.rs b/crates/kernel/src/networking/iface/mod.rs index a9f5011a..c7ecfb89 100644 --- a/crates/kernel/src/networking/iface/mod.rs +++ b/crates/kernel/src/networking/iface/mod.rs @@ -9,12 +9,10 @@ use crate::device::system_timer; use crate::sync::SpinLock; -use crate::networking::repr::{Device, EthernetAddress, Ipv4Address, Ipv4Cidr}; +use crate::networking::repr::{EthernetAddress, Ipv4Address, Ipv4Cidr}; use crate::networking::socket::TaggedSocket; use crate::networking::utils::arp_cache::ArpCache; -use alloc::boxed::Box; -use alloc::sync::Arc; use alloc::collections::btree_map::BTreeMap; pub mod arp; @@ -23,7 +21,6 @@ pub mod dhcp; pub mod ethernet; pub mod icmp; pub mod ipv4; -pub mod socket; pub mod tcp; pub mod udp; diff --git a/crates/kernel/src/networking/iface/socket.rs b/crates/kernel/src/networking/iface/socket.rs deleted file mode 100644 index a6e6e7d9..00000000 --- a/crates/kernel/src/networking/iface/socket.rs +++ /dev/null @@ -1,33 +0,0 @@ -use crate::device::usb::device::net::get_interface_mut; -use crate::event::thread; -use crate::sync; - -use crate::networking::socket::TaggedSocket; - -use alloc::vec::Vec; - -// pub fn socket_send_loop() { -// println!("socket_send_loop"); -// let interface = get_interface_mut(); -// println!("try send"); -// -// let to_send: Vec<_> = { -// let mut sockets = interface.sockets.lock(); -// sockets -// .iter_mut() -// .map(|(_, socket)| socket as *mut TaggedSocket) -// .collect() -// }; -// println!("try send2"); -// -// -// for &socket_ptr in &to_send { -// let socket: &mut TaggedSocket = unsafe { &mut *socket_ptr }; -// let _ = socket.send(interface); -// } -// -// thread::thread(move || { -// sync::spin_sleep(500_000); -// socket_send_loop(); -// }); -// } diff --git a/crates/kernel/src/networking/iface/tcp.rs b/crates/kernel/src/networking/iface/tcp.rs index 95104e81..685ddc9f 100644 --- a/crates/kernel/src/networking/iface/tcp.rs +++ b/crates/kernel/src/networking/iface/tcp.rs @@ -1,11 +1,11 @@ use crate::networking::iface::*; use crate::networking::repr::*; -use crate::networking::socket::SocketAddr; use crate::networking::socket::SockType; +use crate::networking::socket::SocketAddr; use crate::networking::{Error, Result}; -use alloc::vec::Vec; use crate::event::task; +use alloc::vec::Vec; pub fn send_tcp_packet( interface: &mut Interface, @@ -65,7 +65,7 @@ pub fn recv_tcp_packet(interface: &mut Interface, ipv4_packet: Ipv4Packet) -> Re if stype != SockType::TCP { return Err(Error::Unsupported); } - + let mut payload = tcp_packet.payload.clone(); payload.extend_from_slice(&tcp_packet.seq_number.to_le_bytes()); payload.extend_from_slice(&tcp_packet.ack_number.to_le_bytes()); @@ -73,7 +73,7 @@ pub fn recv_tcp_packet(interface: &mut Interface, ipv4_packet: Ipv4Packet) -> Re payload.extend_from_slice(&tcp_packet.window_size.to_le_bytes()); task::spawn_async(async move { - let _ = tx.send((payload, sender_socket_addr)).await; + let _ = tx.send((payload, sender_socket_addr)).await; }); } } diff --git a/crates/kernel/src/networking/iface/udp.rs b/crates/kernel/src/networking/iface/udp.rs index 5120c8ec..e1df1076 100644 --- a/crates/kernel/src/networking/iface/udp.rs +++ b/crates/kernel/src/networking/iface/udp.rs @@ -2,7 +2,7 @@ use crate::networking::iface::*; use crate::networking::repr::*; use crate::networking::socket::SockType; use crate::networking::socket::SocketAddr; -use crate::networking::{Result, Error}; +use crate::networking::{Error, Result}; use crate::event::task; @@ -31,7 +31,7 @@ pub fn send_udp_packet( } pub fn recv_udp_packet(interface: &mut Interface, ipv4_packet: Ipv4Packet) -> Result<()> { - println!("\t received udp packet"); + // println!("\t received udp packet"); let udp_packet = UdpPacket::deserialize(ipv4_packet.payload.as_slice())?; let local_socket_addr = SocketAddr { @@ -45,7 +45,7 @@ pub fn recv_udp_packet(interface: &mut Interface, ipv4_packet: Ipv4Packet) -> Re }; // let mut sockets = interface.sockets.lock(); - + for (_, socket) in interface.sockets.iter_mut() { if socket.binding_equals(local_socket_addr) { let (stype, mut tx) = socket.get_send_ref(); diff --git a/crates/kernel/src/networking/mod.rs b/crates/kernel/src/networking/mod.rs index 0c330aac..df4d04d5 100644 --- a/crates/kernel/src/networking/mod.rs +++ b/crates/kernel/src/networking/mod.rs @@ -46,6 +46,7 @@ pub enum Error { Checksum, Timeout, NotImplemented, + Closed, } pub type Result = CoreResult; diff --git a/crates/kernel/src/networking/repr/dns.rs b/crates/kernel/src/networking/repr/dns.rs new file mode 100644 index 00000000..ea7e30b0 --- /dev/null +++ b/crates/kernel/src/networking/repr/dns.rs @@ -0,0 +1,284 @@ +use crate::networking::repr::Ipv4Address; +use crate::networking::{Error, Result}; +use alloc::string::String; +use alloc::vec; +use alloc::vec::Vec; +use byteorder::{ByteOrder, NetworkEndian}; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct DnsHeader { + pub id: u16, + pub flags: u16, + pub qdcount: u16, + pub ancount: u16, + pub nscount: u16, + pub arcount: u16, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct DnsQuestion { + pub qname: String, + pub qtype: u16, + pub qclass: u16, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct DnsRecord { + pub name: String, + pub rtype: u16, + pub rclass: u16, + pub ttl: u32, + pub rdata: Vec, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Packet { + pub header: DnsHeader, + pub questions: Vec, + pub answers: Vec, + pub authorities: Vec, + pub additionals: Vec, +} + +impl Packet { + pub fn create_dns_query(domain: &str) -> Self { + let header = DnsHeader { + id: 0x1337, + flags: 0x0100, // Standard query (QR=0), recursion desired (RD=1) + qdcount: 1, // One question + ancount: 0, // No answers initially + nscount: 0, // No authorities initially + arcount: 0, // No additionals initially + }; + + let question = DnsQuestion { + qname: String::from(domain), + qtype: 1, // Type A (IPv4 address) + qclass: 1, // Class IN (Internet) + }; + + Packet { + header, + questions: vec![question], + answers: Vec::new(), + authorities: Vec::new(), + additionals: Vec::new(), + } + } + + pub fn extract_ip_address(&self) -> Option { + for record in &self.answers { + if record.rtype == 1 && record.rdata.len() == 4 { + let ip = Ipv4Address::new([ + record.rdata[0], + record.rdata[1], + record.rdata[2], + record.rdata[3], + ]); + return Some(ip); + } + } + + return None; + } + + pub fn deserialize(buffer: &[u8]) -> Result { + if buffer.len() < 12 { + return Err(Error::Malformed); + } + + let id = NetworkEndian::read_u16(&buffer[0..2]); + let flags = NetworkEndian::read_u16(&buffer[2..4]); + let qdcount = NetworkEndian::read_u16(&buffer[4..6]); + let ancount = NetworkEndian::read_u16(&buffer[6..8]); + let nscount = NetworkEndian::read_u16(&buffer[8..10]); + let arcount = NetworkEndian::read_u16(&buffer[10..12]); + + let mut offset = 12; + let mut questions = Vec::new(); + for _ in 0..qdcount { + let (qname, next_offset) = read_qname(buffer, offset)?; + offset = next_offset; + if offset + 4 > buffer.len() { + return Err(Error::Malformed); + } + let qtype = NetworkEndian::read_u16(&buffer[offset..offset + 2]); + let qclass = NetworkEndian::read_u16(&buffer[offset + 2..offset + 4]); + offset += 4; + questions.push(DnsQuestion { + qname, + qtype, + qclass, + }); + } + + let mut answers = Vec::new(); + for _ in 0..ancount { + println!("\t[!]ANSWER offset {}", offset); + let (record, next_offset) = read_record(buffer, offset)?; + answers.push(record); + offset = next_offset; + } + + let mut authorities = Vec::new(); + for _ in 0..nscount { + let (record, next_offset) = read_record(buffer, offset)?; + authorities.push(record); + offset = next_offset; + } + + let mut additionals = Vec::new(); + for _ in 0..arcount { + let (record, next_offset) = read_record(buffer, offset)?; + additionals.push(record); + offset = next_offset; + } + + Ok(Packet { + header: DnsHeader { + id, + flags, + qdcount, + ancount, + nscount, + arcount, + }, + questions, + answers, + authorities, + additionals, + }) + } + + pub fn serialize(&self) -> Vec { + let mut buffer = Vec::new(); + + let mut header = [0u8; 12]; + NetworkEndian::write_u16(&mut header[0..2], self.header.id); + NetworkEndian::write_u16(&mut header[2..4], self.header.flags); + NetworkEndian::write_u16(&mut header[4..6], self.questions.len() as u16); + NetworkEndian::write_u16(&mut header[6..8], self.answers.len() as u16); + NetworkEndian::write_u16(&mut header[8..10], self.authorities.len() as u16); + NetworkEndian::write_u16(&mut header[10..12], self.additionals.len() as u16); + buffer.extend_from_slice(&header); + + for question in &self.questions { + write_qname(&mut buffer, &question.qname); + let mut qinfo = [0u8; 4]; + NetworkEndian::write_u16(&mut qinfo[0..2], question.qtype); + NetworkEndian::write_u16(&mut qinfo[2..4], question.qclass); + buffer.extend_from_slice(&qinfo); + } + + for record in &self.answers { + write_record(&mut buffer, record); + } + for record in &self.authorities { + write_record(&mut buffer, record); + } + for record in &self.additionals { + write_record(&mut buffer, record); + } + + buffer + } +} + +fn read_qname(buffer: &[u8], mut offset: usize) -> Result<(String, usize)> { + let mut labels = Vec::new(); + let mut jumped = false; + let original_offset = offset; + + loop { + if offset >= buffer.len() { + return Err(Error::Malformed); + } + + let len = buffer[offset]; + if len & 0xC0 == 0xC0 { + // Pointer to another location + if offset + 1 >= buffer.len() { + return Err(Error::Malformed); + } + let pointer = (((len & 0x3F) as usize) << 8) | buffer[offset + 1] as usize; + if pointer >= buffer.len() { + return Err(Error::Malformed); + } + let (suffix, _) = read_qname(buffer, pointer)?; + labels.push(suffix); + offset += 2; + jumped = true; + break; + } else if len == 0 { + offset += 1; + break; + } else { + offset += 1; + if offset + len as usize > buffer.len() { + return Err(Error::Malformed); + } + let label = core::str::from_utf8(&buffer[offset..offset + len as usize]) + .map_err(|_| Error::Malformed)?; + labels.push(String::from(label)); + offset += len as usize; + } + } + + if !jumped { + Ok((labels.join("."), offset)) + } else { + Ok((labels.join("."), original_offset + 2)) // after pointer + } +} + +fn write_qname(buffer: &mut Vec, name: &str) { + for part in name.split('.') { + buffer.push(part.len() as u8); + buffer.extend_from_slice(part.as_bytes()); + } + buffer.push(0); +} + +fn read_record(buffer: &[u8], offset: usize) -> Result<(DnsRecord, usize)> { + let (name, mut pos) = read_qname(buffer, offset)?; + if pos + 10 > buffer.len() { + println!("malformed "); + return Err(Error::Malformed); + } + let rtype = NetworkEndian::read_u16(&buffer[pos..pos + 2]); + let rclass = NetworkEndian::read_u16(&buffer[pos + 2..pos + 4]); + let ttl = NetworkEndian::read_u32(&buffer[pos + 4..pos + 8]); + let rdlength = NetworkEndian::read_u16(&buffer[pos + 8..pos + 10]) as usize; + pos += 10; + + if pos + rdlength > buffer.len() { + println!("malformed "); + return Err(Error::Malformed); + } + let rdata = buffer[pos..pos + rdlength].to_vec(); + pos += rdlength; + + println!("made it"); + + Ok(( + DnsRecord { + name, + rtype, + rclass, + ttl, + rdata, + }, + pos, + )) +} + +fn write_record(buffer: &mut Vec, record: &DnsRecord) { + write_qname(buffer, &record.name); + let mut rinfo = [0u8; 10]; + NetworkEndian::write_u16(&mut rinfo[0..2], record.rtype); + NetworkEndian::write_u16(&mut rinfo[2..4], record.rclass); + NetworkEndian::write_u32(&mut rinfo[4..8], record.ttl); + NetworkEndian::write_u16(&mut rinfo[8..10], record.rdata.len() as u16); + buffer.extend_from_slice(&rinfo); + buffer.extend_from_slice(&record.rdata); +} diff --git a/crates/kernel/src/networking/repr/http.rs b/crates/kernel/src/networking/repr/http.rs new file mode 100644 index 00000000..1a012355 --- /dev/null +++ b/crates/kernel/src/networking/repr/http.rs @@ -0,0 +1,206 @@ +use crate::networking::{Error, Result}; +use alloc::string::{String, ToString}; +use alloc::vec::Vec; + +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum Method { + Get, + Post, + Put, + Delete, + Head, + Options, + Patch, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Header { + pub name: String, + pub value: String, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Packet { + pub method: Method, + pub path: String, + pub version: String, + pub headers: Vec
, + pub body: Vec, +} + +impl Packet { + pub fn new(method: Method, host: &str, path: &str) -> Self { + let mut headers = Vec::new(); + headers.push(Header { + name: "Host".to_string(), + value: host.to_string(), + }); + headers.push(Header { + name: "User-Agent".to_string(), + value: "curl/8.13.0".to_string(), + }); + headers.push(Header { + name: "Accept".to_string(), + value: "*/*".to_string(), + }); + Packet { + method, + path: path.to_string(), + version: "HTTP/1.1".to_string(), + headers, + body: Vec::new(), + } + } + + pub fn serialize(&self) -> Vec { + let mut buffer = Vec::new(); + + // Start-line + let method_str = match self.method { + Method::Get => "GET", + Method::Post => "POST", + Method::Put => "PUT", + Method::Delete => "DELETE", + Method::Head => "HEAD", + Method::Options => "OPTIONS", + Method::Patch => "PATCH", + }; + + buffer.extend_from_slice(method_str.as_bytes()); + buffer.push(b' '); + buffer.extend_from_slice(self.path.as_bytes()); + buffer.push(b' '); + buffer.extend_from_slice(self.version.as_bytes()); + buffer.extend_from_slice(b"\r\n"); + + // Headers + for header in &self.headers { + buffer.extend_from_slice(header.name.as_bytes()); + buffer.extend_from_slice(b": "); + buffer.extend_from_slice(header.value.as_bytes()); + buffer.extend_from_slice(b"\r\n"); + } + + // End of headers + buffer.extend_from_slice(b"\r\n"); + + // Body + buffer.extend_from_slice(&self.body); + + buffer + } + + pub fn deserialize(buffer: &[u8]) -> Result { + let mut headers = Vec::new(); + let mut pos = 0; + + // Parse request line + let request_line_end = find_crlf(buffer, pos).ok_or(Error::Malformed)?; + let request_line = &buffer[pos..request_line_end]; + let parts = split_ascii_whitespace(request_line); + + if parts.len() != 3 { + return Err(Error::Malformed); + } + + let method = match parts[0] { + b"GET" => Method::Get, + b"POST" => Method::Post, + b"PUT" => Method::Put, + b"DELETE" => Method::Delete, + b"HEAD" => Method::Head, + b"OPTIONS" => Method::Options, + b"PATCH" => Method::Patch, + _ => return Err(Error::Malformed), + }; + + let path = String::from_utf8(parts[1].to_vec()).map_err(|_| Error::Malformed)?; + let version = String::from_utf8(parts[2].to_vec()).map_err(|_| Error::Malformed)?; + + pos = request_line_end + 2; + + // Parse headers + loop { + if pos >= buffer.len() { + return Err(Error::Malformed); + } + + if buffer[pos..].starts_with(b"\r\n") { + pos += 2; + break; + } + + let header_end = find_crlf(buffer, pos).ok_or(Error::Malformed)?; + let header_line = &buffer[pos..header_end]; + + if let Some(colon_pos) = header_line.iter().position(|&b| b == b':') { + let name = String::from_utf8(header_line[..colon_pos].to_vec()) + .map_err(|_| Error::Malformed)?; + let value = String::from_utf8(header_line[colon_pos + 1..].to_vec()) + .map_err(|_| Error::Malformed)? + .trim() + .to_string(); + headers.push(Header { name, value }); + } else { + return Err(Error::Malformed); + } + + pos = header_end + 2; + } + + let body = buffer[pos..].to_vec(); + + Ok(Packet { + method, + path, + version, + headers, + body, + }) + } + + pub fn get_header(&self, name: &str) -> Option<&str> { + for header in &self.headers { + if header.name.eq_ignore_ascii_case(name) { + return Some(&header.value); + } + } + None + } + + pub fn content_length(&self) -> Option { + if let Some(value) = self.get_header("Content-Length") { + value.parse().ok() + } else { + None + } + } +} + +// Helper: Find \r\n (CRLF) sequence starting from position +fn find_crlf(buffer: &[u8], start: usize) -> Option { + buffer[start..] + .windows(2) + .position(|window| window == b"\r\n") + .map(|p| start + p) +} + +// Helper: Split a line into whitespace-separated parts +fn split_ascii_whitespace(line: &[u8]) -> Vec<&[u8]> { + let mut parts = Vec::new(); + let mut start = None; + for (i, &b) in line.iter().enumerate() { + if b.is_ascii_whitespace() { + if let Some(s) = start { + parts.push(&line[s..i]); + start = None; + } + } else if start.is_none() { + start = Some(i); + } + } + if let Some(s) = start { + parts.push(&line[s..]); + } + parts +} diff --git a/crates/kernel/src/networking/repr/mod.rs b/crates/kernel/src/networking/repr/mod.rs index 2767b896..03af254a 100644 --- a/crates/kernel/src/networking/repr/mod.rs +++ b/crates/kernel/src/networking/repr/mod.rs @@ -20,7 +20,9 @@ mod arp; pub mod dev; mod dhcp; +mod dns; mod ethernet; +mod http; mod icmp; mod ipv4; mod tcp; @@ -43,6 +45,10 @@ pub use self::icmp::{ pub use self::udp::Packet as UdpPacket; +pub use self::dns::Packet as DnsPacket; + +pub use self::http::{Method as HttpMethod, Packet as HttpPacket}; + pub use self::dhcp::{DhcpOption, DhcpParam, MessageType as DhcpMessageType, Packet as DhcpPacket}; pub use self::tcp::{Flags as TcpFlags, Packet as TcpPacket}; diff --git a/crates/kernel/src/networking/socket/bindings.rs b/crates/kernel/src/networking/socket/bindings.rs index 2e5f1d58..078be213 100644 --- a/crates/kernel/src/networking/socket/bindings.rs +++ b/crates/kernel/src/networking/socket/bindings.rs @@ -4,11 +4,12 @@ use core::sync::atomic::{AtomicU16, Ordering}; use alloc::vec::Vec; -use crate::networking::repr::Ipv4Address; +use crate::networking::repr::{DnsPacket, Ipv4Address}; use crate::networking::{Error, Result}; -use crate::device::usb::device::net::get_interface_mut; -use crate::event::task::spawn_async; +use crate::device::usb::device::net::{get_dhcpd_mut, get_interface_mut}; + +use super::UdpSocket; #[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] pub struct SocketAddr { @@ -23,6 +24,29 @@ impl SocketAddr { port: 0, } } + + pub async fn resolve(host: &str, port: u16) -> Self { + let dhcp = get_dhcpd_mut(); + let dns_socket = UdpSocket::new(); + let _ = bind(dns_socket, 53); + + let dns_req = DnsPacket::create_dns_query(host); + + let saddr = SocketAddr { + addr: dhcp.dns_servers[0], + port: 53, + }; + + let _ = send_to(dns_socket, dns_req.serialize(), saddr).await; + + let (resp, _) = recv_from(dns_socket).await.unwrap(); + let dhcp_resp = DnsPacket::deserialize(&resp).unwrap(); + + SocketAddr { + addr: dhcp_resp.extract_ip_address().unwrap(), + port, + } + } } impl Display for SocketAddr { @@ -35,7 +59,7 @@ impl Display for SocketAddr { pub enum SockType { UDP, TCP, - Raw + Raw, } // TODO: these technically runs out eventually lol need wrap around @@ -47,7 +71,8 @@ pub async fn send_to(socketfd: u16, payload: Vec, saddr: SocketAddr) -> Resu // let mut sockets = interface.sockets.lock(); // 1. check if socket fd is valid if not return error - let tagged_socket = interface.sockets + let tagged_socket = interface + .sockets .get_mut(&socketfd) .ok_or(Error::InvalidSocket(socketfd))?; @@ -66,7 +91,8 @@ pub async fn recv_from(socketfd: u16) -> Result<(Vec, SocketAddr)> { let interface = get_interface_mut(); // 1. check if a socketfd is valid if not return error - let tagged_socket = interface.sockets + let tagged_socket = interface + .sockets .get_mut(&socketfd) .ok_or(Error::InvalidSocket(socketfd))?; @@ -76,7 +102,7 @@ pub async fn recv_from(socketfd: u16) -> Result<(Vec, SocketAddr)> { } // 3. blocking recv from socket recv queue - + tagged_socket.recv().await } @@ -85,7 +111,8 @@ pub async fn connect(socketfd: u16, saddr: SocketAddr) -> Result<()> { // let mut sockets = interface.sockets.lock(); // 1. check if a socketfd is valid if not return error - let tagged_socket = interface.sockets + let tagged_socket = interface + .sockets .get_mut(&socketfd) .ok_or(Error::InvalidSocket(socketfd))?; @@ -97,7 +124,8 @@ pub async fn listen(socketfd: u16, num_requests: usize) -> Result<()> { // 1.check if binded, if not error // let mut sockets = interface.sockets.lock(); - let tagged_socket = interface.sockets + let tagged_socket = interface + .sockets .get_mut(&socketfd) .ok_or(Error::InvalidSocket(socketfd))?; @@ -114,15 +142,30 @@ pub async fn accept(socketfd: u16) -> Result { // 1. if listener not started, error // let mut sockets = interface.sockets.lock(); - let tagged_socket = interface.sockets + let tagged_socket = interface + .sockets .get_mut(&socketfd) .ok_or(Error::InvalidSocket(socketfd))?; // 2. accept 1 connection, error if no pending connections - tagged_socket.accept().await; + let _ = tagged_socket.accept().await; Ok(socketfd) } +pub async fn close(socketfd: u16) -> Result<()> { + let interface = get_interface_mut(); + // 1. if listener not started, error + // let mut sockets = interface.sockets.lock(); + + let tagged_socket = interface + .sockets + .get_mut(&socketfd) + .ok_or(Error::InvalidSocket(socketfd))?; + + // 2. accept 1 connection, error if no pending connections + tagged_socket.close().await +} + pub fn bind(socketfd: u16, port: u16) -> Result<()> { let interface = get_interface_mut(); // 1. check if binding is already in use by another socket @@ -138,7 +181,8 @@ pub fn bind(socketfd: u16, port: u16) -> Result<()> { } // 2. check if this is a valid socketfd - let tagged_socket = interface.sockets + let tagged_socket = interface + .sockets .get_mut(&socketfd) .ok_or(Error::InvalidSocket(socketfd))?; diff --git a/crates/kernel/src/networking/socket/mod.rs b/crates/kernel/src/networking/socket/mod.rs index e990e159..808487aa 100644 --- a/crates/kernel/src/networking/socket/mod.rs +++ b/crates/kernel/src/networking/socket/mod.rs @@ -4,7 +4,9 @@ pub mod tagged; pub mod tcp; pub mod udp; -pub use self::bindings::{bind, accept, listen, connect, recv_from, send_to, SocketAddr, SockType}; +pub use self::bindings::{ + accept, bind, close, connect, listen, recv_from, send_to, SockType, SocketAddr, +}; pub use self::tagged::TaggedSocket; diff --git a/crates/kernel/src/networking/socket/raw.rs b/crates/kernel/src/networking/socket/raw.rs index 4cda7182..abab0293 100644 --- a/crates/kernel/src/networking/socket/raw.rs +++ b/crates/kernel/src/networking/socket/raw.rs @@ -1,12 +1,11 @@ -use crate::networking::repr::Ipv4Protocol; -use crate::networking::utils::{ring::Ring, slice::Slice}; +use crate::device::usb::device::net::get_interface_mut; use crate::networking::iface::ipv4; use crate::networking::iface::Interface; -use crate::device::usb::device::net::get_interface_mut; -use crate::networking::{Result, Error}; -use crate::ringbuffer::{channel, Sender, Receiver}; -use crate::networking::socket::tagged::{TaggedSocket, BUFFER_LEN}; -use crate::networking::socket::{SocketAddr, SockType}; +use crate::networking::repr::Ipv4Protocol; +use crate::networking::socket::tagged::BUFFER_LEN; +use crate::networking::socket::{SockType, SocketAddr}; +use crate::networking::{Error, Result}; +use crate::ringbuffer::{channel, Receiver, Sender}; use alloc::vec::Vec; @@ -26,9 +25,7 @@ pub struct RawSocket { } impl RawSocket { - pub fn new( - raw_type: RawType, - ) -> RawSocket { + pub fn new(raw_type: RawType) -> RawSocket { let (recv_tx, recv_rx) = channel::, SocketAddr)>(); let interface = get_interface_mut(); RawSocket { @@ -60,7 +57,12 @@ impl RawSocket { self.binding = bind_addr; } - pub async fn send_enqueue(&mut self, payload: Vec, proto: Ipv4Protocol, dest: SocketAddr) -> Result<()> { + pub async fn send_enqueue( + &mut self, + payload: Vec, + proto: Ipv4Protocol, + dest: SocketAddr, + ) -> Result<()> { println!("enqueud send"); let interface = get_interface_mut(); diff --git a/crates/kernel/src/networking/socket/tagged.rs b/crates/kernel/src/networking/socket/tagged.rs index d5b91254..f3f477c3 100644 --- a/crates/kernel/src/networking/socket/tagged.rs +++ b/crates/kernel/src/networking/socket/tagged.rs @@ -1,5 +1,4 @@ -use crate::networking::iface::Interface; -use crate::networking::socket::{SocketAddr, TcpSocket, UdpSocket, SockType}; +use crate::networking::socket::{SockType, SocketAddr, TcpSocket, UdpSocket}; use crate::networking::{Error, Result}; use crate::device::usb::device::net::get_interface_mut; @@ -37,7 +36,7 @@ impl TaggedSocket { match self { // TaggedSocket::Raw(socket) => socket.queue_send(payload, saddr), TaggedSocket::Udp(socket) => socket.send_enqueue(payload, saddr).await, - TaggedSocket::Tcp(socket) => socket.send_enqueue(payload, saddr), + TaggedSocket::Tcp(socket) => socket.send_enqueue(payload, saddr).await, } } @@ -54,7 +53,9 @@ impl TaggedSocket { // TaggedSocket::Raw(socket) => socket.queue_recv(payload, saddr), TaggedSocket::Udp(socket) => socket.recv_enqueue(payload, saddr).await, TaggedSocket::Tcp(socket) => { - socket.recv_enqueue(seq_num, ack_num, flags, window_size, payload, saddr).await + socket + .recv_enqueue(seq_num, ack_num, flags, window_size, payload, saddr) + .await } } } @@ -109,6 +110,14 @@ impl TaggedSocket { } } + pub async fn close(&mut self) -> Result<()> { + match self { + // TaggedSocket::Raw(socket) => socket.recv(), + TaggedSocket::Udp(_socket) => Err(Error::Ignored), + TaggedSocket::Tcp(socket) => socket.close().await, + } + } + pub async fn accept(&mut self) -> Result { match self { // TaggedSocket::Raw(socket) => socket.recv(), diff --git a/crates/kernel/src/networking/socket/tcp.rs b/crates/kernel/src/networking/socket/tcp.rs index 3f6e83c7..2f4de1d1 100644 --- a/crates/kernel/src/networking/socket/tcp.rs +++ b/crates/kernel/src/networking/socket/tcp.rs @@ -1,21 +1,12 @@ use crate::device::usb::device::net::get_interface_mut; use crate::networking::iface::{tcp, Interface}; -use crate::networking::repr::TcpPacket; use crate::networking::socket::bindings::{NEXT_EPHEMERAL, NEXT_SOCKETFD}; use crate::networking::socket::tagged::{TaggedSocket, BUFFER_LEN}; -use crate::networking::socket::{SocketAddr, SockType}; -use crate::networking::utils::ring::Ring; +use crate::networking::socket::{SockType, SocketAddr}; use crate::networking::{Error, Result}; -use alloc::vec; +use crate::ringbuffer::{channel, Receiver, Sender}; use alloc::vec::Vec; use core::sync::atomic::Ordering; -use crate::ringbuffer::{channel, Sender, Receiver}; - -fn new_ring_packet_buffer(capacity: usize) -> Ring<(Vec, SocketAddr)> { - let default_entry = (Vec::new(), SocketAddr::default()); - let buffer = vec![default_entry; capacity]; - Ring::from(buffer) -} // flags pub const TCP_FLAG_FIN: u8 = 0x01; @@ -54,7 +45,6 @@ pub struct TcpSocket { recv_rx: Receiver, SocketAddr)>, // recvp_tx: Sender, // recvp_rx: Receiver, - state: TcpState, remote_addr: Option, seq_number: u32, @@ -82,8 +72,6 @@ impl TcpSocket { recv_rx, // recvp_tx, // recvp_rx, - - state: TcpState::Closed, remote_addr: None, seq_number: INITIAL_SEQ_NUMBER, @@ -94,12 +82,14 @@ impl TcpSocket { let socketfd = NEXT_SOCKETFD.fetch_add(1, Ordering::SeqCst); // let mut sockets = interface.sockets.lock(); - interface.sockets.insert(socketfd, TaggedSocket::Tcp(socket)); + interface + .sockets + .insert(socketfd, TaggedSocket::Tcp(socket)); socketfd } pub fn binding_equals(&self, saddr: SocketAddr) -> bool { - println!("binding port {} provided port {}", self.binding.port, saddr.port); + // println!("binding port {} provided port {}", self.binding.port, saddr.port); self.binding == saddr } @@ -124,7 +114,11 @@ impl TcpSocket { self.binding = bind_addr; } - pub async fn listen(&mut self, interface: &mut Interface, num_max_requests: usize) -> Result<()> { + pub async fn listen( + &mut self, + interface: &mut Interface, + num_max_requests: usize, + ) -> Result<()> { println!("in listen"); if !self.is_bound { // bind to ephemeral if not bound @@ -176,19 +170,19 @@ impl TcpSocket { self.ack_number += 1; println!("ack number {}", self.ack_number); - tcp::send_tcp_packet( + let _ = tcp::send_tcp_packet( interface, self.binding.port, addr.port, self.seq_number, - self.ack_number, + self.ack_number, TCP_FLAG_ACK | TCP_FLAG_SYN, self.window_size, addr.addr, - Vec::new(), + Vec::new(), ); - self.recv().await; + let _ = self.recv().await; self.state = TcpState::Established; @@ -212,7 +206,7 @@ impl TcpSocket { self.remote_addr = Some(saddr); let flags = TCP_FLAG_SYN; - tcp::send_tcp_packet( + let _ = tcp::send_tcp_packet( interface, self.binding.port, saddr.port, @@ -227,224 +221,81 @@ impl TcpSocket { self.state = TcpState::SynSent; println!("[!] sent syn"); - - let _ = self.recv().await; - - self.seq_number += 1; - - let flags = TCP_FLAG_SYN & TCP_FLAG_ACK; - tcp::send_tcp_packet( - interface, - self.binding.port, - saddr.port, - self.seq_number, - self.ack_number, - flags, - self.window_size, - saddr.addr, - Vec::new(), - ); - self.state = TcpState::Established; + let _ = self.recv().await?; Ok(()) } - pub fn send_enqueue(&mut self, payload: Vec, dest: SocketAddr) -> Result<()> { + pub async fn send_enqueue(&mut self, payload: Vec, dest: SocketAddr) -> Result<()> { if self.state != TcpState::Established { return Err(Error::NotConnected); } - // verify the destination matches the connected remote address - if let Some(remote) = self.remote_addr { - if remote != dest { - return Err(Error::NotConnected); - } - } else { - return Err(Error::NotConnected); - } + // if let Some(remote) = self.remote_addr { + // if remote != dest { + // return Err(Error::NotConnected); + // } + // } else { + // return Err(Error::NotConnected); + // } let interface = get_interface_mut(); - tcp::send_tcp_packet(interface, self.binding.port, dest.port, self.seq_number, self.ack_number, self.flags, self.window_size, dest.addr, payload) + let _ = tcp::send_tcp_packet( + interface, + self.binding.port, + dest.port, + self.seq_number, + self.ack_number, + // TCP_FLAG_PSH | TCP_FLAG_ACK, + TCP_FLAG_ACK, + self.window_size, + dest.addr, + payload, + ); + + let _ = self.recv().await; + + Ok(()) } pub async fn recv(&mut self) -> Result<(Vec, SocketAddr)> { + let interface = get_interface_mut(); + let (mut payload, addr) = self.recv_rx.recv().await; - payload.truncate(payload.len().saturating_sub(11)); // get rid of context bytes let last = &payload[payload.len() - 2..payload.len()]; self.window_size = u16::from_le_bytes([last[0], last[1]]); - self.flags = payload[payload.len() - 3]; - // let last = &payload[payload.len() - 7..payload.len() - 3]; - // self.ack_number = u32::from_le_bytes([last[0], last[1], last[2], last[3]]); + let flags = payload[payload.len() - 3]; + let last = &payload[payload.len() - 7..payload.len() - 3]; + let ack_number = u32::from_le_bytes([last[0], last[1], last[2], last[3]]); let last = &payload[payload.len() - 11..payload.len() - 7]; - self.ack_number = u32::from_le_bytes([last[0], last[1], last[2], last[3]]); - - if self.state == TcpState::Established { - let interface = get_interface_mut(); - self.ack_number += payload.len() as u32; - self.seq_number += 1; - println!("sending ack to a received packet"); - tcp::send_tcp_packet( - interface, - self.binding.port, - addr.port, - self.seq_number, - self.ack_number, - TCP_FLAG_ACK, - self.window_size, - addr.addr, - Vec::new() - ); - } + let seq_number = u32::from_le_bytes([last[0], last[1], last[2], last[3]]); - Ok((payload, addr)) - } - - pub async fn recv_with_context(&mut self) -> Result<(Vec, SocketAddr)> { - let (payload, addr) = self.recv_rx.recv().await; - Ok((payload, addr)) - } - - // Enqueues a packet for receiving and handles TCP state machine - pub async fn recv_enqueue( - &mut self, - seq_number: u32, - ack_number: u32, - flags: u8, - window_size: u16, - payload: Vec, - sender: SocketAddr, - ) -> Result<()> { - println!("got a recv_enqueue"); - let mut payload_with_context = payload.clone(); - payload_with_context.extend_from_slice(&seq_number.to_le_bytes()); - payload_with_context.extend_from_slice(&ack_number.to_le_bytes()); - payload_with_context.push(flags); - payload_with_context.extend_from_slice(&window_size.to_le_bytes()); - - // let packet = TcpPacket::new(sender.port, self.binding.port, seq_number, ack_number, flags, window_size, payload, sender.addr, self.binding.addr); - self.recv_tx.send((payload_with_context, sender)).await; - Ok(()) - - // if self.state == TcpState::SynSent { - // // Check if this is a valid remote endpoint response - // if let Some(remote) = self.remote_addr { - // if remote == sender { - // // This is a response to our SYN - // if (flags & (TCP_FLAG_SYN | TCP_FLAG_ACK)) == (TCP_FLAG_SYN | TCP_FLAG_ACK) { - // // Received SYN-ACK, update ACK number - // self.ack_number = seq_number + 1; - // - // // Send ACK to complete three-way handshake - // tcp::send_tcp_packet( - // interface, - // self.binding.port, - // sender.port, - // self.seq_number + 1, // SYN consumes one sequence number - // self.ack_number, - // TCP_FLAG_ACK, - // self.window_size, - // sender.addr, - // Vec::new(), - // )?; - // - // // Update state and sequence number - // self.state = TcpState::Established; - // self.connected = true; - // self.seq_number += 1; // SYN consumes one sequence number - // - // return Ok(()); - // } - // } - // } - // } else if self.state == TcpState::Closed { - // // Handle incoming SYN for passive open (if we're listening) - // if flags & TCP_FLAG_SYN != 0 && flags & TCP_FLAG_ACK == 0 { - // // This would be for a server socket - not handling passive open in this example - // // But this is where you would handle it - // } - // } - // - // // Process state transitions based on TCP flags - // self.process_tcp_state_transitions(interface, flags, seq_number, ack_number, sender)?; - // - // // Now that we've handled any state transitions, enqueue actual data for user - // // Only enqueue if we're in established state and there's actual data - // if self.state == TcpState::Established && (payload.len() > 0) { - // // Only enqueue if there's actual data - // self.recv_buffer.enqueue_maybe(|(buffer, addr)| { - // *buffer = payload.clone(); - // *addr = sender; - // Ok(()) - // })?; - // - // // Update ACK number and send ACK for the data - // self.ack_number = seq_number + payload.len() as u32; - // - // // Send ACK for received data - // if let Some(remote) = self.remote_addr { - // tcp::send_tcp_packet( - // interface, - // self.binding.port, - // remote.port, - // self.seq_number, - // self.ack_number, - // TCP_FLAG_ACK, - // self.window_size, - // remote.addr, - // Vec::new(), - // )?; - // } - // } - - // Ok(()) - } + payload.truncate(payload.len().saturating_sub(11)); // get rid of context bytes - // Helper function to process TCP state transitions based on packet flags - fn process_tcp_state_transitions( - &mut self, - interface: &mut Interface, - flags: u8, - seq_number: u32, - _ack_number: u32, - _sender: SocketAddr, - ) -> Result<()> { match self.state { TcpState::SynSent => { - if flags & TCP_FLAG_ACK != 0 { - } - }, - TcpState::Established => { - // Handle FIN from remote - if flags & TCP_FLAG_FIN != 0 { - self.ack_number = seq_number + 1; // FIN consumes a sequence number - - // Send ACK for FIN - if let Some(remote) = self.remote_addr { - tcp::send_tcp_packet( - interface, - self.binding.port, - remote.port, - self.seq_number, - self.ack_number, - TCP_FLAG_ACK, - self.window_size, - remote.addr, - Vec::new(), - )?; - } + if flags & TCP_FLAG_RST != 0 { + return Err(Error::Closed); + } else if flags & (TCP_FLAG_ACK | TCP_FLAG_SYN) != 0 { + self.seq_number += 1; + self.ack_number = seq_number + 1; - self.state = TcpState::CloseWait; - } + tcp::send_tcp_packet( + interface, + self.binding.port, + addr.port, + self.seq_number, + self.ack_number, + TCP_FLAG_ACK, + self.window_size, + addr.addr, + Vec::new(), + )?; - // Handle RST from remote - if flags & TCP_FLAG_RST != 0 { - self.state = TcpState::Closed; - self.connected = false; - self.remote_addr = None; + self.state = TcpState::Established; } } - TcpState::FinWait1 => { if flags & TCP_FLAG_ACK != 0 { // Our FIN was acknowledged @@ -474,12 +325,10 @@ impl TcpSocket { } } } - TcpState::FinWait2 => { if flags & TCP_FLAG_FIN != 0 { self.ack_number = seq_number + 1; - // Send ACK for their FIN if let Some(remote) = self.remote_addr { tcp::send_tcp_packet( interface, @@ -495,23 +344,86 @@ impl TcpSocket { } self.state = TcpState::TimeWait; - // In a real implementation, start the TIME_WAIT timer here } } - TcpState::LastAck => { if flags & TCP_FLAG_ACK != 0 { - // Final ACK received, connection fully closed self.state = TcpState::Closed; self.connected = false; self.remote_addr = None; } } + TcpState::Established => { + let interface = get_interface_mut(); + + if flags & TCP_FLAG_FIN != 0 { + self.state = TcpState::Closed; + + self.ack_number += 1; + let _ = tcp::send_tcp_packet( + interface, + self.binding.port, + addr.port, + self.seq_number, + self.ack_number, + TCP_FLAG_ACK, + self.window_size, + addr.addr, + Vec::new(), + ); + + return Err(Error::Closed); + } - // Handle other states as needed + if payload.len() > 0 { + self.ack_number += payload.len() as u32; + self.seq_number += 1; + let _ = tcp::send_tcp_packet( + interface, + self.binding.port, + addr.port, + self.seq_number, + self.ack_number, + TCP_FLAG_ACK, + self.window_size, + addr.addr, + Vec::new(), + ); + } else { + let old_seq = self.seq_number; + self.seq_number = ack_number; + self.ack_number = old_seq + payload.len() as u32; + } + } _ => {} } + Ok((payload, addr)) + } + + pub async fn recv_with_context(&mut self) -> Result<(Vec, SocketAddr)> { + let (payload, addr) = self.recv_rx.recv().await; + Ok((payload, addr)) + } + + // Enqueues a packet for receiving and handles TCP state machine + pub async fn recv_enqueue( + &mut self, + seq_number: u32, + ack_number: u32, + flags: u8, + window_size: u16, + payload: Vec, + sender: SocketAddr, + ) -> Result<()> { + println!("got a recv_enqueue"); + let mut payload_with_context = payload.clone(); + payload_with_context.extend_from_slice(&seq_number.to_le_bytes()); + payload_with_context.extend_from_slice(&ack_number.to_le_bytes()); + payload_with_context.push(flags); + payload_with_context.extend_from_slice(&window_size.to_le_bytes()); + + self.recv_tx.send((payload_with_context, sender)).await; Ok(()) } @@ -526,10 +438,12 @@ impl TcpSocket { } // Close the connection gracefully - pub fn close(&mut self, interface: &mut Interface) -> Result<()> { + pub async fn close(&mut self) -> Result<()> { + let interface = get_interface_mut(); match self.state { TcpState::Established => { // Send FIN packet + // println!("sending a close"); if let Some(remote) = self.remote_addr { tcp::send_tcp_packet( interface, @@ -546,6 +460,13 @@ impl TcpSocket { self.seq_number += 1; // FIN consumes a sequence number self.state = TcpState::FinWait1; } + + let _ = self.recv().await; + + if self.state == TcpState::FinWait2 { + let _ = self.recv().await; + } + Ok(()) } TcpState::CloseWait => { diff --git a/crates/kernel/src/networking/socket/udp.rs b/crates/kernel/src/networking/socket/udp.rs index 8e6688b4..39078728 100644 --- a/crates/kernel/src/networking/socket/udp.rs +++ b/crates/kernel/src/networking/socket/udp.rs @@ -3,9 +3,9 @@ use crate::networking::iface::udp; use crate::networking::iface::Interface; use crate::networking::socket::bindings::NEXT_SOCKETFD; use crate::networking::socket::tagged::{TaggedSocket, BUFFER_LEN}; -use crate::networking::socket::{SocketAddr, SockType}; -use crate::ringbuffer::{channel, Sender, Receiver}; +use crate::networking::socket::{SockType, SocketAddr}; use crate::networking::{Error, Result}; +use crate::ringbuffer::{channel, Receiver, Sender}; use alloc::vec::Vec; use core::sync::atomic::Ordering; @@ -23,7 +23,7 @@ pub struct UdpSocket { impl UdpSocket { pub fn new() -> u16 { let interface = get_interface_mut(); - + // let (send_tx, send_rx) = channel::, SocketAddr)>(); let (recv_tx, recv_rx) = channel::, SocketAddr)>(); @@ -41,13 +41,15 @@ impl UdpSocket { let socketfd = NEXT_SOCKETFD.fetch_add(1, Ordering::SeqCst); // let mut sockets = interface.sockets.lock(); - interface.sockets.insert(socketfd, TaggedSocket::Udp(socket)); + interface + .sockets + .insert(socketfd, TaggedSocket::Udp(socket)); socketfd } pub fn binding_equals(&self, saddr: SocketAddr) -> bool { - println!("binding port {} provided port {}", self.binding.port, saddr.port); + // println!("binding port {} provided port {}", self.binding.port, saddr.port); self.binding.port == saddr.port }