use async_std::io::ReadExt; use async_std::io::WriteExt; use async_std::net::{TcpListener, TcpStream}; use futures::stream::StreamExt; use std::error::Error; use std::fmt; /// This is a simple server that returns the responses given at creation time: [`self.raw_http_responses`] following a round-robin mechanism. pub struct SimpleServer { listener: TcpListener, port: u16, host: String, raw_http_responses: Vec, calls_counter: usize, } /// Request-Line = Method SP Request-URI SP HTTP-Version CRLF struct Request<'a> { method: &'a str, uri: &'a str, http_version: &'a str, } impl<'a> fmt::Display for Request<'a> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "{} {} {}\r\n", self.method, self.uri, self.http_version) } } impl SimpleServer { /// Creates an instance of a [`SimpleServer`] /// If [`port`] is None os Some(0), it gets randomly chosen between the available ones. pub async fn new( host: &str, port: Option, raw_http_responses: Vec, ) -> Result { let port = port.unwrap_or(0); let listener = TcpListener::bind(format!("{}:{}", host, port)).await?; let port = listener.local_addr()?.port(); Ok(Self { listener, port, host: host.to_string(), raw_http_responses, calls_counter: 0, }) } /// Returns the uri in which the server is listening to. pub fn uri(&self) -> String { format!("http://{}:{}", self.host, self.port) } /// Starts the TcpListener and handles the requests. pub async fn start(mut self) -> () { while let Some(stream) = self.listener.incoming().next().await { match stream { Ok(stream) => { match self.handle_connection(stream).await { Ok(_) => (), Err(e) => { println!("Error handling connection: {}", e); () } } self.calls_counter = self.calls_counter + 1; } Err(e) => { println!("Connection failed: {}", e); () } } } } /// Asyncrounously reads from the buffer and handle the request. /// It first checks that the format is correct, then returns the response. /// /// Returns a 400 if the request if formatted badly. async fn handle_connection(&self, mut stream: TcpStream) -> Result<(), Box> { let mut buffer = vec![0; 1024]; stream.read(&mut buffer).await.unwrap(); let request = String::from_utf8_lossy(&buffer[..]); let request_line = request.lines().next().unwrap(); let response = match Self::parse_request_line(&request_line) { Ok(request) => { println!("== Request == \n{}\n=============", request); self.get_response().clone() } Err(e) => { println!("++ Bad request: {} ++++++", e); self.get_bad_request_response() } }; println!("-- Response --\n{}\n--------------", response.clone()); stream.write(response.as_bytes()).await.unwrap(); stream.flush().await.unwrap(); Ok(()) } /// Parses the request line and checks that it contains the method, uri and http_version parts. /// It does not check if the content of the checked parts is correct. It just checks the format (it contains enough parts) of the request. fn parse_request_line(request: &str) -> Result> { let mut parts = request.split_whitespace(); let method = parts.next().ok_or("Method not specified")?; let uri = parts.next().ok_or("URI not specified")?; let http_version = parts.next().ok_or("HTTP version not specified")?; Ok(Request { method, uri, http_version, }) } /// Returns the response to use based on the calls counter. /// It uses a round-robin mechanism. fn get_response(&self) -> String { let index = if self.calls_counter >= self.raw_http_responses.len() { self.raw_http_responses.len() % self.calls_counter } else { self.calls_counter }; self.raw_http_responses[index].clone() } /// Returns the raw HTTP response in case of a 400 Bad Request. fn get_bad_request_response(&self) -> String { "HTTP/1.1 400 Bad Request\r\n\r\n".to_string() } }