halfway done s2n-quic rewrite
This commit is contained in:
		
							parent
							
								
									0d4aae1d63
								
							
						
					
					
						commit
						6141fe71a5
					
				
					 4 changed files with 728 additions and 293 deletions
				
			
		
							
								
								
									
										864
									
								
								Cargo.lock
									
									
									
										generated
									
									
									
								
							
							
						
						
									
										864
									
								
								Cargo.lock
									
									
									
										generated
									
									
									
								
							
										
											
												File diff suppressed because it is too large
												Load diff
											
										
									
								
							|  | @ -14,10 +14,9 @@ env_logger = "0.10.0" | ||||||
| idna = "0.4.0" | idna = "0.4.0" | ||||||
| log = "0.4.19" | log = "0.4.19" | ||||||
| parking_lot = "0.12.1" | parking_lot = "0.12.1" | ||||||
| quinn = "0.10.1" |  | ||||||
| rand = "0.8.5" | rand = "0.8.5" | ||||||
| rustls = "0.21.9" |  | ||||||
| rustls-pemfile = "1.0.2" | rustls-pemfile = "1.0.2" | ||||||
|  | s2n-quic = { version = "1.46.0", default-features = false, features = ["provider-address-token-default", "provider-tls-rustls"] } | ||||||
| serde = { version = "1.0.164", features = ["derive"] } | serde = { version = "1.0.164", features = ["derive"] } | ||||||
| serde_json = "1.0.97" | serde_json = "1.0.97" | ||||||
| thiserror = "1.0.40" | thiserror = "1.0.40" | ||||||
|  |  | ||||||
							
								
								
									
										135
									
								
								src/main.rs
									
									
									
									
									
								
							
							
						
						
									
										135
									
								
								src/main.rs
									
									
									
									
									
								
							|  | @ -2,7 +2,7 @@ | ||||||
| #![allow(clippy::cast_possible_truncation)] | #![allow(clippy::cast_possible_truncation)] | ||||||
| #![allow(clippy::cast_possible_wrap)] | #![allow(clippy::cast_possible_wrap)] | ||||||
| 
 | 
 | ||||||
| use std::{convert::Infallible, net::SocketAddr, sync::Arc, time::Duration}; | use std::{convert::Infallible, io::ErrorKind, net::SocketAddr, sync::Arc, time::Duration}; | ||||||
| 
 | 
 | ||||||
| use anyhow::{anyhow, Context}; | use anyhow::{anyhow, Context}; | ||||||
| use axum::{ | use axum::{ | ||||||
|  | @ -11,17 +11,17 @@ use axum::{ | ||||||
| }; | }; | ||||||
| use log::{error, info}; | use log::{error, info}; | ||||||
| use netty::{Handshake, ReadError}; | use netty::{Handshake, ReadError}; | ||||||
| use quinn::{Connecting, ConnectionError, Endpoint, ServerConfig, TransportConfig}; |  | ||||||
| use routing::RoutingTable; | use routing::RoutingTable; | ||||||
| use rustls::{Certificate, PrivateKey}; | use s2n_quic::{connection::Error as ConnectionError, Connection, Server}; | ||||||
| use tokio::{ | use tokio::{ | ||||||
|     io::{AsyncReadExt, AsyncWriteExt}, |     io::{AsyncReadExt, AsyncWriteExt}, | ||||||
|     net::TcpStream, |     net::TcpStream, | ||||||
|  |     task::{JoinError, JoinSet}, | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| use crate::{ | use crate::{ | ||||||
|     netty::{ReadExt, WriteExt}, |     netty::{ReadExt, WriteExt}, | ||||||
|     proto::{ClientboundControlMessage, ServerboundControlMessage}, |     proto::{ClientboundControlMessage, ServerboundControlMessage}, routing::RouterRequest, | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| mod netty; | mod netty; | ||||||
|  | @ -107,80 +107,116 @@ async fn main() -> anyhow::Result<()> { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| async fn try_handle_quic( | async fn try_handle_quic( | ||||||
|     connection: Connecting, |     connection: Connection, | ||||||
|     routing_table: &RoutingTable, |     routing_table: &RoutingTable, | ||||||
| ) -> anyhow::Result<()> { | ) -> anyhow::Result<()> { | ||||||
|     let connection = connection.await?; |  | ||||||
|     info!( |     info!( | ||||||
|         "QUIClime connection established to: {}", |         "QUIClime connection established to: {}", | ||||||
|         connection.remote_address() |         connection.remote_addr()? | ||||||
|     ); |     ); | ||||||
|     let (mut send_control, mut recv_control) = connection.accept_bi().await?; |     let mut control = connection | ||||||
|     info!("Control channel open: {}", connection.remote_address()); |         .accept_bidirectional_stream() | ||||||
|  |         .await? | ||||||
|  |         .ok_or(anyhow!( | ||||||
|  |             "Connection closed while waiting for control channel" | ||||||
|  |         ))?; | ||||||
|  |     info!("Control channel open: {}", connection.remote_addr()?); | ||||||
|     let mut handle = loop { |     let mut handle = loop { | ||||||
|         let mut buf = vec![0u8; recv_control.read_u8().await? as _]; |         let mut buf = vec![0u8; control.read_u8().await? as _]; | ||||||
|         recv_control.read_exact(&mut buf).await?; |         control.read_exact(&mut buf).await?; | ||||||
|         if let Ok(parsed) = serde_json::from_slice(&buf) { |         if let Ok(parsed) = serde_json::from_slice(&buf) { | ||||||
|             match parsed { |             match parsed { | ||||||
|                 ServerboundControlMessage::RequestDomainAssignment => { |                 ServerboundControlMessage::RequestDomainAssignment => { | ||||||
|                     let handle = routing_table.register(); |                     let handle = routing_table.register(); | ||||||
|                     info!( |                     info!( | ||||||
|                         "Domain assigned to {}: {}", |                         "Domain assigned to {}: {}", | ||||||
|                         connection.remote_address(), |                         connection.remote_addr()?, | ||||||
|                         handle.domain() |                         handle.domain() | ||||||
|                     ); |                     ); | ||||||
|                     let response = |                     let response = | ||||||
|                         serde_json::to_vec(&ClientboundControlMessage::DomainAssignmentComplete { |                         serde_json::to_vec(&ClientboundControlMessage::DomainAssignmentComplete { | ||||||
|                             domain: handle.domain().to_string(), |                             domain: handle.domain().to_string(), | ||||||
|                         })?; |                         })?; | ||||||
|                     send_control.write_all(&[response.len() as u8]).await?; |                     control.write_all(&[response.len() as u8]).await?; | ||||||
|                     send_control.write_all(&response).await?; |                     control.write_all(&response).await?; | ||||||
|                     break handle; |                     break handle; | ||||||
|                 } |                 } | ||||||
|             } |             } | ||||||
|         } |         } | ||||||
|         let response = serde_json::to_vec(&ClientboundControlMessage::UnknownMessage)?; |         let response = serde_json::to_vec(&ClientboundControlMessage::UnknownMessage)?; | ||||||
|         send_control.write_all(&[response.len() as u8]).await?; |         control.write_all(&[response.len() as u8]).await?; | ||||||
|         send_control.write_all(&response).await?; |         control.write_all(&response).await?; | ||||||
|     }; |     }; | ||||||
| 
 |     let mut set = JoinSet::new(); | ||||||
|     tokio::select! { |     let (control_message_queue, mut control_message_queue_recv) = tokio::sync::mpsc::unbounded_channel(); | ||||||
|         e = connection.closed() => { |     let (mut control_recv, mut control_send) = control.split(); | ||||||
|             match e { |     let send_task = tokio::spawn(async move { | ||||||
|                 ConnectionError::ConnectionClosed(_) |         while let Some(event) = control_message_queue_recv.recv().await { | ||||||
|                 | ConnectionError::ApplicationClosed(_) |             let response = serde_json::to_vec(&event)?; | ||||||
|                 | ConnectionError::LocallyClosed => Ok(()), |             control_send.write_all(&[response.len() as u8]).await?; | ||||||
|                 e => Err(e.into()), |             control_send.write_all(&response).await?; | ||||||
|         } |         } | ||||||
|         }, |         Ok::<_, tokio::io::Error>(()) | ||||||
|         r = async { |     }); | ||||||
|             while let Some(remote) = handle.next().await { |     let control_message_queue_ref = &control_message_queue; | ||||||
|  |     set.spawn(async move { | ||||||
|  |         loop { | ||||||
|  |             let mut buf = vec![0u8; control_recv.read_u8().await? as _]; | ||||||
|  |             control_recv.read_exact(&mut buf).await?; | ||||||
|  |             control_message_queue_ref.send(ClientboundControlMessage::UnknownMessage); | ||||||
|  |         } | ||||||
|  |         Ok::<_, tokio::io::Error>(()) | ||||||
|  |     }); | ||||||
|  |     enum Event { | ||||||
|  |         RouterEvent(RouterRequest), | ||||||
|  |         TaskSet(Result<Result<(), tokio::io::Error>, JoinError>) | ||||||
|  |     } | ||||||
|  |     while let Some(remote) = tokio::select! { | ||||||
|  |         v = handle.next() => v.map(Event::RouterEvent), | ||||||
|  |         v = set.join_next() => v.map(Event::TaskSet), | ||||||
|  |     } { | ||||||
|         match remote { |         match remote { | ||||||
|                     routing::RouterRequest::RouteRequest(remote) => { |             Event::RouterEvent(RouterRequest::RouteRequest((handshake, mut client_stream))) => { | ||||||
|                         let pair = connection.open_bi().await; |                 let stream = connection.open_bidirectional_stream().await; | ||||||
|                         if let Err(ConnectionError::ApplicationClosed(_)) = pair { |                 set.spawn(async move { | ||||||
|                             break; |                     if let Err( | ||||||
|                         } else if let Err(ConnectionError::ConnectionClosed(_)) = pair { |                         ConnectionError::Transport { .. } | ||||||
|                             break; |                         | ConnectionError::Application { .. } | ||||||
|                         } |                         | ConnectionError::EndpointClosing { .. } | ||||||
|                         remote.send(pair?).map_err(|e| anyhow::anyhow!("{:?}", e))?; |                         | ConnectionError::ImmediateClose { .. }, | ||||||
|                     } |                     ) = stream | ||||||
|                     routing::RouterRequest::BroadcastRequest(message) => { |                     { | ||||||
|                         let response = |  | ||||||
|                             serde_json::to_vec(&ClientboundControlMessage::RequestMessageBroadcast { |  | ||||||
|                                 message, |  | ||||||
|                             })?; |  | ||||||
|                         send_control.write_all(&[response.len() as u8]).await?; |  | ||||||
|                         send_control.write_all(&response).await?; |  | ||||||
|                     } |  | ||||||
|                 } |  | ||||||
|             } |  | ||||||
|                         Ok(()) |                         Ok(()) | ||||||
|         } => r |                     } else { | ||||||
|  |                         let mut stream = stream?; | ||||||
|  |                         handshake.send(&mut stream).await?; | ||||||
|  |                         tokio::io::copy_bidirectional(&mut stream, &mut client_stream).await?; | ||||||
|  |                         Ok::<_, tokio::io::Error>(()) | ||||||
|                     } |                     } | ||||||
|  |                 }); | ||||||
|  |             } | ||||||
|  |             Event::RouterEvent(RouterRequest::BroadcastRequest(message)) => { | ||||||
|  |                 control_message_queue.send(ClientboundControlMessage::RequestMessageBroadcast { | ||||||
|  |                     message, | ||||||
|  |                 }); | ||||||
|  |             } | ||||||
|  |             Event::TaskSet(Ok(Ok(()))) => {} | ||||||
|  |             Event::TaskSet(Ok(Err(e))) => { | ||||||
|  |                 if e.kind() != ErrorKind::UnexpectedEof { | ||||||
|  |                     error!("Error in task: {e:?}") | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |             Event::TaskSet(Err(e)) => { | ||||||
|  |                 error!("Error in task: {e:?}") | ||||||
|             } |             } | ||||||
| 
 | 
 | ||||||
| async fn handle_quic(connection: Connecting, routing_table: &RoutingTable) { |         } | ||||||
|  |     } | ||||||
|  |     send_task.abort(); | ||||||
|  |     Ok(()) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | async fn handle_quic(connection: Connection, routing_table: &RoutingTable) { | ||||||
|     if let Err(e) = try_handle_quic(connection, routing_table).await { |     if let Err(e) = try_handle_quic(connection, routing_table).await { | ||||||
|         error!("Error handling QUIClime connection: {}", e); |         error!("Error handling QUIClime connection: {}", e); | ||||||
|     }; |     }; | ||||||
|  | @ -188,7 +224,7 @@ async fn handle_quic(connection: Connecting, routing_table: &RoutingTable) { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| async fn listen_quic( | async fn listen_quic( | ||||||
|     endpoint: &'static Endpoint, |     mut endpoint: Server, | ||||||
|     routing_table: &'static RoutingTable, |     routing_table: &'static RoutingTable, | ||||||
| ) -> anyhow::Result<Infallible> { | ) -> anyhow::Result<Infallible> { | ||||||
|     while let Some(connection) = endpoint.accept().await { |     while let Some(connection) = endpoint.accept().await { | ||||||
|  | @ -198,7 +234,6 @@ async fn listen_quic( | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| async fn listen_control( | async fn listen_control( | ||||||
|     endpoint: &'static Endpoint, |  | ||||||
|     routing_table: &'static RoutingTable, |     routing_table: &'static RoutingTable, | ||||||
| ) -> anyhow::Result<Infallible> { | ) -> anyhow::Result<Infallible> { | ||||||
|     let app = axum::Router::new() |     let app = axum::Router::new() | ||||||
|  |  | ||||||
|  | @ -1,20 +1,21 @@ | ||||||
| use log::info; | use log::info; | ||||||
| use log::warn; | use log::warn; | ||||||
| use parking_lot::RwLock; | use parking_lot::RwLock; | ||||||
| use quinn::RecvStream; |  | ||||||
| use quinn::SendStream; |  | ||||||
| use rand::prelude::*; | use rand::prelude::*; | ||||||
| use std::collections::HashMap; | use std::collections::HashMap; | ||||||
|  | use tokio::net::TcpStream; | ||||||
| use tokio::sync::mpsc; | use tokio::sync::mpsc; | ||||||
| use tokio::sync::oneshot; | use tokio::sync::oneshot; | ||||||
| 
 | 
 | ||||||
|  | use crate::netty::Handshake; | ||||||
|  | 
 | ||||||
| #[derive(Debug)] | #[derive(Debug)] | ||||||
| pub enum RouterRequest { | pub enum RouterRequest { | ||||||
|     RouteRequest(RouterCallback), |     RouteRequest(RouterConnection), | ||||||
|     BroadcastRequest(String), |     BroadcastRequest(String), | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| type RouterCallback = oneshot::Sender<(SendStream, RecvStream)>; | type RouterConnection = (Handshake, TcpStream); | ||||||
| type RouteRequestReceiver = mpsc::UnboundedSender<RouterRequest>; | type RouteRequestReceiver = mpsc::UnboundedSender<RouterRequest>; | ||||||
| 
 | 
 | ||||||
| #[allow(clippy::module_name_repetitions)] | #[allow(clippy::module_name_repetitions)] | ||||||
|  | @ -44,14 +45,12 @@ impl RoutingTable { | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     pub async fn route(&self, domain: &str) -> Option<(SendStream, RecvStream)> { |     pub fn route(&self, domain: &str, conn: RouterConnection) -> Option<()> { | ||||||
|         let (send, recv) = oneshot::channel(); |  | ||||||
|         self.table |         self.table | ||||||
|             .read() |             .read() | ||||||
|             .get(domain)? |             .get(domain)? | ||||||
|             .send(RouterRequest::RouteRequest(send)) |             .send(RouterRequest::RouteRequest(conn)) | ||||||
|             .ok()?; |             .ok() | ||||||
|         recv.await.ok() |  | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     pub fn register(&self) -> RoutingHandle { |     pub fn register(&self) -> RoutingHandle { | ||||||
|  |  | ||||||
		Loading…
	
		Reference in a new issue