halfway done s2n-quic rewrite

This commit is contained in:
Skye 2024-09-26 15:51:21 +09:00
parent 0d4aae1d63
commit 6141fe71a5
4 changed files with 728 additions and 293 deletions

864
Cargo.lock generated

File diff suppressed because it is too large Load diff

View file

@ -14,10 +14,9 @@ env_logger = "0.10.0"
idna = "0.4.0" idna = "0.4.0"
log = "0.4.19" log = "0.4.19"
parking_lot = "0.12.1" parking_lot = "0.12.1"
quinn = "0.10.1"
rand = "0.8.5" rand = "0.8.5"
rustls = "0.21.9"
rustls-pemfile = "1.0.2" rustls-pemfile = "1.0.2"
s2n-quic = { version = "1.46.0", default-features = false, features = ["provider-address-token-default", "provider-tls-rustls"] }
serde = { version = "1.0.164", features = ["derive"] } serde = { version = "1.0.164", features = ["derive"] }
serde_json = "1.0.97" serde_json = "1.0.97"
thiserror = "1.0.40" thiserror = "1.0.40"
@ -25,4 +24,4 @@ tokio = { version = "1.28.2", features = ["rt-multi-thread", "fs", "macros", "io
[profile.release] [profile.release]
lto = "fat" lto = "fat"
debug = "full" debug = "full"

View file

@ -2,7 +2,7 @@
#![allow(clippy::cast_possible_truncation)] #![allow(clippy::cast_possible_truncation)]
#![allow(clippy::cast_possible_wrap)] #![allow(clippy::cast_possible_wrap)]
use std::{convert::Infallible, net::SocketAddr, sync::Arc, time::Duration}; use std::{convert::Infallible, io::ErrorKind, net::SocketAddr, sync::Arc, time::Duration};
use anyhow::{anyhow, Context}; use anyhow::{anyhow, Context};
use axum::{ use axum::{
@ -11,17 +11,17 @@ use axum::{
}; };
use log::{error, info}; use log::{error, info};
use netty::{Handshake, ReadError}; use netty::{Handshake, ReadError};
use quinn::{Connecting, ConnectionError, Endpoint, ServerConfig, TransportConfig};
use routing::RoutingTable; use routing::RoutingTable;
use rustls::{Certificate, PrivateKey}; use s2n_quic::{connection::Error as ConnectionError, Connection, Server};
use tokio::{ use tokio::{
io::{AsyncReadExt, AsyncWriteExt}, io::{AsyncReadExt, AsyncWriteExt},
net::TcpStream, net::TcpStream,
task::{JoinError, JoinSet},
}; };
use crate::{ use crate::{
netty::{ReadExt, WriteExt}, netty::{ReadExt, WriteExt},
proto::{ClientboundControlMessage, ServerboundControlMessage}, proto::{ClientboundControlMessage, ServerboundControlMessage}, routing::RouterRequest,
}; };
mod netty; mod netty;
@ -107,80 +107,116 @@ async fn main() -> anyhow::Result<()> {
} }
async fn try_handle_quic( async fn try_handle_quic(
connection: Connecting, connection: Connection,
routing_table: &RoutingTable, routing_table: &RoutingTable,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let connection = connection.await?;
info!( info!(
"QUIClime connection established to: {}", "QUIClime connection established to: {}",
connection.remote_address() connection.remote_addr()?
); );
let (mut send_control, mut recv_control) = connection.accept_bi().await?; let mut control = connection
info!("Control channel open: {}", connection.remote_address()); .accept_bidirectional_stream()
.await?
.ok_or(anyhow!(
"Connection closed while waiting for control channel"
))?;
info!("Control channel open: {}", connection.remote_addr()?);
let mut handle = loop { let mut handle = loop {
let mut buf = vec![0u8; recv_control.read_u8().await? as _]; let mut buf = vec![0u8; control.read_u8().await? as _];
recv_control.read_exact(&mut buf).await?; control.read_exact(&mut buf).await?;
if let Ok(parsed) = serde_json::from_slice(&buf) { if let Ok(parsed) = serde_json::from_slice(&buf) {
match parsed { match parsed {
ServerboundControlMessage::RequestDomainAssignment => { ServerboundControlMessage::RequestDomainAssignment => {
let handle = routing_table.register(); let handle = routing_table.register();
info!( info!(
"Domain assigned to {}: {}", "Domain assigned to {}: {}",
connection.remote_address(), connection.remote_addr()?,
handle.domain() handle.domain()
); );
let response = let response =
serde_json::to_vec(&ClientboundControlMessage::DomainAssignmentComplete { serde_json::to_vec(&ClientboundControlMessage::DomainAssignmentComplete {
domain: handle.domain().to_string(), domain: handle.domain().to_string(),
})?; })?;
send_control.write_all(&[response.len() as u8]).await?; control.write_all(&[response.len() as u8]).await?;
send_control.write_all(&response).await?; control.write_all(&response).await?;
break handle; break handle;
} }
} }
} }
let response = serde_json::to_vec(&ClientboundControlMessage::UnknownMessage)?; let response = serde_json::to_vec(&ClientboundControlMessage::UnknownMessage)?;
send_control.write_all(&[response.len() as u8]).await?; control.write_all(&[response.len() as u8]).await?;
send_control.write_all(&response).await?; control.write_all(&response).await?;
}; };
let mut set = JoinSet::new();
tokio::select! { let (control_message_queue, mut control_message_queue_recv) = tokio::sync::mpsc::unbounded_channel();
e = connection.closed() => { let (mut control_recv, mut control_send) = control.split();
match e { let send_task = tokio::spawn(async move {
ConnectionError::ConnectionClosed(_) while let Some(event) = control_message_queue_recv.recv().await {
| ConnectionError::ApplicationClosed(_) let response = serde_json::to_vec(&event)?;
| ConnectionError::LocallyClosed => Ok(()), control_send.write_all(&[response.len() as u8]).await?;
e => Err(e.into()), control_send.write_all(&response).await?;
}
Ok::<_, tokio::io::Error>(())
});
let control_message_queue_ref = &control_message_queue;
set.spawn(async move {
loop {
let mut buf = vec![0u8; control_recv.read_u8().await? as _];
control_recv.read_exact(&mut buf).await?;
control_message_queue_ref.send(ClientboundControlMessage::UnknownMessage);
}
Ok::<_, tokio::io::Error>(())
});
enum Event {
RouterEvent(RouterRequest),
TaskSet(Result<Result<(), tokio::io::Error>, JoinError>)
}
while let Some(remote) = tokio::select! {
v = handle.next() => v.map(Event::RouterEvent),
v = set.join_next() => v.map(Event::TaskSet),
} {
match remote {
Event::RouterEvent(RouterRequest::RouteRequest((handshake, mut client_stream))) => {
let stream = connection.open_bidirectional_stream().await;
set.spawn(async move {
if let Err(
ConnectionError::Transport { .. }
| ConnectionError::Application { .. }
| ConnectionError::EndpointClosing { .. }
| ConnectionError::ImmediateClose { .. },
) = stream
{
Ok(())
} else {
let mut stream = stream?;
handshake.send(&mut stream).await?;
tokio::io::copy_bidirectional(&mut stream, &mut client_stream).await?;
Ok::<_, tokio::io::Error>(())
}
});
} }
}, Event::RouterEvent(RouterRequest::BroadcastRequest(message)) => {
r = async { control_message_queue.send(ClientboundControlMessage::RequestMessageBroadcast {
while let Some(remote) = handle.next().await { message,
match remote { });
routing::RouterRequest::RouteRequest(remote) => { }
let pair = connection.open_bi().await; Event::TaskSet(Ok(Ok(()))) => {}
if let Err(ConnectionError::ApplicationClosed(_)) = pair { Event::TaskSet(Ok(Err(e))) => {
break; if e.kind() != ErrorKind::UnexpectedEof {
} else if let Err(ConnectionError::ConnectionClosed(_)) = pair { error!("Error in task: {e:?}")
break;
}
remote.send(pair?).map_err(|e| anyhow::anyhow!("{:?}", e))?;
}
routing::RouterRequest::BroadcastRequest(message) => {
let response =
serde_json::to_vec(&ClientboundControlMessage::RequestMessageBroadcast {
message,
})?;
send_control.write_all(&[response.len() as u8]).await?;
send_control.write_all(&response).await?;
}
} }
} }
Ok(()) Event::TaskSet(Err(e)) => {
} => r error!("Error in task: {e:?}")
}
}
} }
send_task.abort();
Ok(())
} }
async fn handle_quic(connection: Connecting, routing_table: &RoutingTable) { async fn handle_quic(connection: Connection, routing_table: &RoutingTable) {
if let Err(e) = try_handle_quic(connection, routing_table).await { if let Err(e) = try_handle_quic(connection, routing_table).await {
error!("Error handling QUIClime connection: {}", e); error!("Error handling QUIClime connection: {}", e);
}; };
@ -188,7 +224,7 @@ async fn handle_quic(connection: Connecting, routing_table: &RoutingTable) {
} }
async fn listen_quic( async fn listen_quic(
endpoint: &'static Endpoint, mut endpoint: Server,
routing_table: &'static RoutingTable, routing_table: &'static RoutingTable,
) -> anyhow::Result<Infallible> { ) -> anyhow::Result<Infallible> {
while let Some(connection) = endpoint.accept().await { while let Some(connection) = endpoint.accept().await {
@ -198,7 +234,6 @@ async fn listen_quic(
} }
async fn listen_control( async fn listen_control(
endpoint: &'static Endpoint,
routing_table: &'static RoutingTable, routing_table: &'static RoutingTable,
) -> anyhow::Result<Infallible> { ) -> anyhow::Result<Infallible> {
let app = axum::Router::new() let app = axum::Router::new()

View file

@ -1,20 +1,21 @@
use log::info; use log::info;
use log::warn; use log::warn;
use parking_lot::RwLock; use parking_lot::RwLock;
use quinn::RecvStream;
use quinn::SendStream;
use rand::prelude::*; use rand::prelude::*;
use std::collections::HashMap; use std::collections::HashMap;
use tokio::net::TcpStream;
use tokio::sync::mpsc; use tokio::sync::mpsc;
use tokio::sync::oneshot; use tokio::sync::oneshot;
use crate::netty::Handshake;
#[derive(Debug)] #[derive(Debug)]
pub enum RouterRequest { pub enum RouterRequest {
RouteRequest(RouterCallback), RouteRequest(RouterConnection),
BroadcastRequest(String), BroadcastRequest(String),
} }
type RouterCallback = oneshot::Sender<(SendStream, RecvStream)>; type RouterConnection = (Handshake, TcpStream);
type RouteRequestReceiver = mpsc::UnboundedSender<RouterRequest>; type RouteRequestReceiver = mpsc::UnboundedSender<RouterRequest>;
#[allow(clippy::module_name_repetitions)] #[allow(clippy::module_name_repetitions)]
@ -44,14 +45,12 @@ impl RoutingTable {
} }
} }
pub async fn route(&self, domain: &str) -> Option<(SendStream, RecvStream)> { pub fn route(&self, domain: &str, conn: RouterConnection) -> Option<()> {
let (send, recv) = oneshot::channel();
self.table self.table
.read() .read()
.get(domain)? .get(domain)?
.send(RouterRequest::RouteRequest(send)) .send(RouterRequest::RouteRequest(conn))
.ok()?; .ok()
recv.await.ok()
} }
pub fn register(&self) -> RoutingHandle { pub fn register(&self) -> RoutingHandle {