halfway done s2n-quic rewrite

Skye 2024-09-26 15:51:21 +09:00
parent 0d4aae1d63
commit 6141fe71a5
4 changed files with 728 additions and 293 deletions

Cargo.lock (generated, 864 changed lines): file diff suppressed because it is too large.

@@ -14,10 +14,9 @@ env_logger = "0.10.0"
idna = "0.4.0"
log = "0.4.19"
parking_lot = "0.12.1"
quinn = "0.10.1"
rand = "0.8.5"
rustls = "0.21.9"
rustls-pemfile = "1.0.2"
s2n-quic = { version = "1.46.0", default-features = false, features = ["provider-address-token-default", "provider-tls-rustls"] }
serde = { version = "1.0.164", features = ["derive"] }
serde_json = "1.0.97"
thiserror = "1.0.40"
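
For context on the dependency swap above, here is a minimal sketch of bringing up an s2n-quic Server in place of a quinn Endpoint, following the builder pattern from the s2n-quic documentation. The certificate/key paths and bind address are placeholders, and the (cert, key) tuple form is an assumption about how this repo wires the provider-tls-rustls feature; the actual setup in main.rs may differ.

```rust
use s2n_quic::Server;
use std::error::Error;

// Placeholder certificate material; the real project presumably keeps loading its
// PEM files much like the quinn/rustls version did.
static CERT_PEM: &str = include_str!("cert.pem");
static KEY_PEM: &str = include_str!("key.pem");

#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
    let mut server = Server::builder()
        .with_tls((CERT_PEM, KEY_PEM))? // (cert, key) tuple form from the s2n-quic docs
        .with_io("0.0.0.0:4433")?       // placeholder bind address
        .start()?;

    // One task per accepted QUIC connection, mirroring listen_quic()/handle_quic()
    // in this diff.
    while let Some(connection) = server.accept().await {
        tokio::spawn(async move {
            let _ = connection; // hand off to per-connection handling here
        });
    }
    Ok(())
}
```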


@@ -2,7 +2,7 @@
#![allow(clippy::cast_possible_truncation)]
#![allow(clippy::cast_possible_wrap)]
use std::{convert::Infallible, net::SocketAddr, sync::Arc, time::Duration};
use std::{convert::Infallible, io::ErrorKind, net::SocketAddr, sync::Arc, time::Duration};
use anyhow::{anyhow, Context};
use axum::{
@@ -11,17 +11,17 @@ use axum::{
};
use log::{error, info};
use netty::{Handshake, ReadError};
use quinn::{Connecting, ConnectionError, Endpoint, ServerConfig, TransportConfig};
use routing::RoutingTable;
use rustls::{Certificate, PrivateKey};
use s2n_quic::{connection::Error as ConnectionError, Connection, Server};
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
net::TcpStream,
task::{JoinError, JoinSet},
};
use crate::{
netty::{ReadExt, WriteExt},
proto::{ClientboundControlMessage, ServerboundControlMessage},
proto::{ClientboundControlMessage, ServerboundControlMessage}, routing::RouterRequest,
};
mod netty;
@@ -107,80 +107,116 @@ async fn main() -> anyhow::Result<()> {
}
async fn try_handle_quic(
connection: Connecting,
connection: Connection,
routing_table: &RoutingTable,
) -> anyhow::Result<()> {
let connection = connection.await?;
info!(
"QUIClime connection established to: {}",
connection.remote_address()
connection.remote_addr()?
);
let (mut send_control, mut recv_control) = connection.accept_bi().await?;
info!("Control channel open: {}", connection.remote_address());
let mut control = connection
.accept_bidirectional_stream()
.await?
.ok_or(anyhow!(
"Connection closed while waiting for control channel"
))?;
info!("Control channel open: {}", connection.remote_addr()?);
let mut handle = loop {
let mut buf = vec![0u8; recv_control.read_u8().await? as _];
recv_control.read_exact(&mut buf).await?;
let mut buf = vec![0u8; control.read_u8().await? as _];
control.read_exact(&mut buf).await?;
if let Ok(parsed) = serde_json::from_slice(&buf) {
match parsed {
ServerboundControlMessage::RequestDomainAssignment => {
let handle = routing_table.register();
info!(
"Domain assigned to {}: {}",
connection.remote_address(),
connection.remote_addr()?,
handle.domain()
);
let response =
serde_json::to_vec(&ClientboundControlMessage::DomainAssignmentComplete {
domain: handle.domain().to_string(),
})?;
send_control.write_all(&[response.len() as u8]).await?;
send_control.write_all(&response).await?;
control.write_all(&[response.len() as u8]).await?;
control.write_all(&response).await?;
break handle;
}
}
}
let response = serde_json::to_vec(&ClientboundControlMessage::UnknownMessage)?;
send_control.write_all(&[response.len() as u8]).await?;
send_control.write_all(&response).await?;
control.write_all(&[response.len() as u8]).await?;
control.write_all(&response).await?;
};
tokio::select! {
e = connection.closed() => {
match e {
ConnectionError::ConnectionClosed(_)
| ConnectionError::ApplicationClosed(_)
| ConnectionError::LocallyClosed => Ok(()),
e => Err(e.into()),
let mut set = JoinSet::new();
let (control_message_queue, mut control_message_queue_recv) = tokio::sync::mpsc::unbounded_channel();
let (mut control_recv, mut control_send) = control.split();
let send_task = tokio::spawn(async move {
while let Some(event) = control_message_queue_recv.recv().await {
let response = serde_json::to_vec(&event)?;
control_send.write_all(&[response.len() as u8]).await?;
control_send.write_all(&response).await?;
}
Ok::<_, tokio::io::Error>(())
});
let control_message_queue_ref = &control_message_queue;
set.spawn(async move {
loop {
let mut buf = vec![0u8; control_recv.read_u8().await? as _];
control_recv.read_exact(&mut buf).await?;
control_message_queue_ref.send(ClientboundControlMessage::UnknownMessage);
}
Ok::<_, tokio::io::Error>(())
});
enum Event {
RouterEvent(RouterRequest),
TaskSet(Result<Result<(), tokio::io::Error>, JoinError>)
}
while let Some(remote) = tokio::select! {
v = handle.next() => v.map(Event::RouterEvent),
v = set.join_next() => v.map(Event::TaskSet),
} {
match remote {
Event::RouterEvent(RouterRequest::RouteRequest((handshake, mut client_stream))) => {
let stream = connection.open_bidirectional_stream().await;
set.spawn(async move {
if let Err(
ConnectionError::Transport { .. }
| ConnectionError::Application { .. }
| ConnectionError::EndpointClosing { .. }
| ConnectionError::ImmediateClose { .. },
) = stream
{
Ok(())
} else {
let mut stream = stream?;
handshake.send(&mut stream).await?;
tokio::io::copy_bidirectional(&mut stream, &mut client_stream).await?;
Ok::<_, tokio::io::Error>(())
}
});
}
},
r = async {
while let Some(remote) = handle.next().await {
match remote {
routing::RouterRequest::RouteRequest(remote) => {
let pair = connection.open_bi().await;
if let Err(ConnectionError::ApplicationClosed(_)) = pair {
break;
} else if let Err(ConnectionError::ConnectionClosed(_)) = pair {
break;
}
remote.send(pair?).map_err(|e| anyhow::anyhow!("{:?}", e))?;
}
routing::RouterRequest::BroadcastRequest(message) => {
let response =
serde_json::to_vec(&ClientboundControlMessage::RequestMessageBroadcast {
message,
})?;
send_control.write_all(&[response.len() as u8]).await?;
send_control.write_all(&response).await?;
}
Event::RouterEvent(RouterRequest::BroadcastRequest(message)) => {
control_message_queue.send(ClientboundControlMessage::RequestMessageBroadcast {
message,
});
}
Event::TaskSet(Ok(Ok(()))) => {}
Event::TaskSet(Ok(Err(e))) => {
if e.kind() != ErrorKind::UnexpectedEof {
error!("Error in task: {e:?}")
}
}
Ok(())
} => r
Event::TaskSet(Err(e)) => {
error!("Error in task: {e:?}")
}
}
}
send_task.abort();
Ok(())
}
async fn handle_quic(connection: Connecting, routing_table: &RoutingTable) {
async fn handle_quic(connection: Connection, routing_table: &RoutingTable) {
if let Err(e) = try_handle_quic(connection, routing_table).await {
error!("Error handling QUIClime connection: {}", e);
};
@@ -188,7 +224,7 @@ async fn handle_quic(connection: Connecting, routing_table: &RoutingTable) {
}
async fn listen_quic(
endpoint: &'static Endpoint,
mut endpoint: Server,
routing_table: &'static RoutingTable,
) -> anyhow::Result<Infallible> {
while let Some(connection) = endpoint.accept().await {
@@ -198,7 +234,6 @@ async fn listen_quic(
}
async fn listen_control(
endpoint: &'static Endpoint,
routing_table: &'static RoutingTable,
) -> anyhow::Result<Infallible> {
let app = axum::Router::new()
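
A self-contained illustration of the control-stream pattern the hunks above move to: a single s2n-quic bidirectional stream carrying length-prefixed JSON (one length byte, then the body), split into receive and send halves. The Clientbound enum is a stand-in for the crate's ClientboundControlMessage (proto.rs is not part of this diff), and the handler is deliberately reduced to always replying UnknownMessage; the real code also registers domains and multiplexes router events.

```rust
use s2n_quic::Connection;
use tokio::io::{AsyncReadExt, AsyncWriteExt};

// Stand-in for the crate's ClientboundControlMessage (not shown in this diff).
#[derive(serde::Serialize)]
enum Clientbound {
    UnknownMessage,
}

async fn control_channel(mut connection: Connection) -> anyhow::Result<()> {
    // The client opens the control stream; `None` means the connection closed first.
    let stream = connection
        .accept_bidirectional_stream()
        .await?
        .ok_or_else(|| anyhow::anyhow!("connection closed before control channel opened"))?;
    // An s2n-quic bidirectional stream splits into a receive half and a send half,
    // both of which implement tokio's AsyncRead/AsyncWrite.
    let (mut recv, mut send) = stream.split();

    loop {
        // One length byte, then a JSON body of exactly that length.
        let len = recv.read_u8().await? as usize;
        let mut buf = vec![0u8; len];
        recv.read_exact(&mut buf).await?;

        // The real handler parses a ServerboundControlMessage here and registers a
        // domain on RequestDomainAssignment; this sketch only answers UnknownMessage.
        let body = serde_json::to_vec(&Clientbound::UnknownMessage)?;
        send.write_all(&[body.len() as u8]).await?;
        send.write_all(&body).await?;
    }
}
```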


@@ -1,20 +1,21 @@
use log::info;
use log::warn;
use parking_lot::RwLock;
use quinn::RecvStream;
use quinn::SendStream;
use rand::prelude::*;
use std::collections::HashMap;
use tokio::net::TcpStream;
use tokio::sync::mpsc;
use tokio::sync::oneshot;
use crate::netty::Handshake;
#[derive(Debug)]
pub enum RouterRequest {
RouteRequest(RouterCallback),
RouteRequest(RouterConnection),
BroadcastRequest(String),
}
type RouterCallback = oneshot::Sender<(SendStream, RecvStream)>;
type RouterConnection = (Handshake, TcpStream);
type RouteRequestReceiver = mpsc::UnboundedSender<RouterRequest>;
#[allow(clippy::module_name_repetitions)]
@@ -44,14 +45,12 @@ impl RoutingTable {
}
}
pub async fn route(&self, domain: &str) -> Option<(SendStream, RecvStream)> {
let (send, recv) = oneshot::channel();
pub fn route(&self, domain: &str, conn: RouterConnection) -> Option<()> {
self.table
.read()
.get(domain)?
.send(RouterRequest::RouteRequest(send))
.ok()?;
recv.await.ok()
.send(RouterRequest::RouteRequest(conn))
.ok()
}
pub fn register(&self) -> RoutingHandle {
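
The hunk above changes route() from an async call that awaited a oneshot carrying quinn (SendStream, RecvStream) halves to a synchronous call that hands a (Handshake, TcpStream) pair to the registered connection's channel. Here is a hedged sketch of how the plain-TCP accept path might drive the new signature; Handshake's hostname field, read_handshake, and the stub RoutingTable are assumptions, since netty.rs and the rest of routing.rs are not in this diff.

```rust
use tokio::net::{TcpListener, TcpStream};

// Stand-ins for types defined elsewhere in the crate; field and method names here
// are assumptions, not the project's real API.
struct Handshake {
    hostname: String,
}
struct RoutingTable;
impl RoutingTable {
    // Mirrors the new signature: enqueue the handshake and the raw TcpStream for the
    // registered QUIC connection instead of awaiting quinn streams back.
    fn route(&self, _domain: &str, _conn: (Handshake, TcpStream)) -> Option<()> {
        Some(())
    }
}
async fn read_handshake(_stream: &mut TcpStream) -> anyhow::Result<Handshake> {
    unimplemented!("the real parser lives in netty.rs")
}

async fn accept_minecraft(listener: TcpListener, routing_table: &RoutingTable) -> anyhow::Result<()> {
    loop {
        let (mut stream, _) = listener.accept().await?;
        let handshake = read_handshake(&mut stream).await?;
        let domain = handshake.hostname.clone();
        // With the reworked route(), a None return simply means no tunnel is
        // registered for the requested hostname.
        if routing_table.route(&domain, (handshake, stream)).is_none() {
            log::warn!("no tunnel registered for {domain}");
        }
    }
}
```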