Compare commits

...

1 Commits
kity ... s2n

Author SHA1 Message Date
Skye 6141fe71a5 halfway done s2n-quic rewrite
3 weeks ago

818
Cargo.lock generated

File diff suppressed because it is too large Load Diff

@ -14,10 +14,9 @@ env_logger = "0.10.0"
idna = "0.4.0" idna = "0.4.0"
log = "0.4.19" log = "0.4.19"
parking_lot = "0.12.1" parking_lot = "0.12.1"
quinn = "0.10.1"
rand = "0.8.5" rand = "0.8.5"
rustls = "0.21.9"
rustls-pemfile = "1.0.2" rustls-pemfile = "1.0.2"
s2n-quic = { version = "1.46.0", default-features = false, features = ["provider-address-token-default", "provider-tls-rustls"] }
serde = { version = "1.0.164", features = ["derive"] } serde = { version = "1.0.164", features = ["derive"] }
serde_json = "1.0.97" serde_json = "1.0.97"
thiserror = "1.0.40" thiserror = "1.0.40"

@ -2,7 +2,7 @@
#![allow(clippy::cast_possible_truncation)] #![allow(clippy::cast_possible_truncation)]
#![allow(clippy::cast_possible_wrap)] #![allow(clippy::cast_possible_wrap)]
use std::{convert::Infallible, net::SocketAddr, sync::Arc, time::Duration}; use std::{convert::Infallible, io::ErrorKind, net::SocketAddr, sync::Arc, time::Duration};
use anyhow::{anyhow, Context}; use anyhow::{anyhow, Context};
use axum::{ use axum::{
@ -11,17 +11,17 @@ use axum::{
}; };
use log::{error, info}; use log::{error, info};
use netty::{Handshake, ReadError}; use netty::{Handshake, ReadError};
use quinn::{Connecting, ConnectionError, Endpoint, ServerConfig, TransportConfig};
use routing::RoutingTable; use routing::RoutingTable;
use rustls::{Certificate, PrivateKey}; use s2n_quic::{connection::Error as ConnectionError, Connection, Server};
use tokio::{ use tokio::{
io::{AsyncReadExt, AsyncWriteExt}, io::{AsyncReadExt, AsyncWriteExt},
net::TcpStream, net::TcpStream,
task::{JoinError, JoinSet},
}; };
use crate::{ use crate::{
netty::{ReadExt, WriteExt}, netty::{ReadExt, WriteExt},
proto::{ClientboundControlMessage, ServerboundControlMessage}, proto::{ClientboundControlMessage, ServerboundControlMessage}, routing::RouterRequest,
}; };
mod netty; mod netty;
@ -107,80 +107,116 @@ async fn main() -> anyhow::Result<()> {
} }
async fn try_handle_quic( async fn try_handle_quic(
connection: Connecting, connection: Connection,
routing_table: &RoutingTable, routing_table: &RoutingTable,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let connection = connection.await?;
info!( info!(
"QUIClime connection established to: {}", "QUIClime connection established to: {}",
connection.remote_address() connection.remote_addr()?
); );
let (mut send_control, mut recv_control) = connection.accept_bi().await?; let mut control = connection
info!("Control channel open: {}", connection.remote_address()); .accept_bidirectional_stream()
.await?
.ok_or(anyhow!(
"Connection closed while waiting for control channel"
))?;
info!("Control channel open: {}", connection.remote_addr()?);
let mut handle = loop { let mut handle = loop {
let mut buf = vec![0u8; recv_control.read_u8().await? as _]; let mut buf = vec![0u8; control.read_u8().await? as _];
recv_control.read_exact(&mut buf).await?; control.read_exact(&mut buf).await?;
if let Ok(parsed) = serde_json::from_slice(&buf) { if let Ok(parsed) = serde_json::from_slice(&buf) {
match parsed { match parsed {
ServerboundControlMessage::RequestDomainAssignment => { ServerboundControlMessage::RequestDomainAssignment => {
let handle = routing_table.register(); let handle = routing_table.register();
info!( info!(
"Domain assigned to {}: {}", "Domain assigned to {}: {}",
connection.remote_address(), connection.remote_addr()?,
handle.domain() handle.domain()
); );
let response = let response =
serde_json::to_vec(&ClientboundControlMessage::DomainAssignmentComplete { serde_json::to_vec(&ClientboundControlMessage::DomainAssignmentComplete {
domain: handle.domain().to_string(), domain: handle.domain().to_string(),
})?; })?;
send_control.write_all(&[response.len() as u8]).await?; control.write_all(&[response.len() as u8]).await?;
send_control.write_all(&response).await?; control.write_all(&response).await?;
break handle; break handle;
} }
} }
} }
let response = serde_json::to_vec(&ClientboundControlMessage::UnknownMessage)?; let response = serde_json::to_vec(&ClientboundControlMessage::UnknownMessage)?;
send_control.write_all(&[response.len() as u8]).await?; control.write_all(&[response.len() as u8]).await?;
send_control.write_all(&response).await?; control.write_all(&response).await?;
}; };
let mut set = JoinSet::new();
tokio::select! { let (control_message_queue, mut control_message_queue_recv) = tokio::sync::mpsc::unbounded_channel();
e = connection.closed() => { let (mut control_recv, mut control_send) = control.split();
match e { let send_task = tokio::spawn(async move {
ConnectionError::ConnectionClosed(_) while let Some(event) = control_message_queue_recv.recv().await {
| ConnectionError::ApplicationClosed(_) let response = serde_json::to_vec(&event)?;
| ConnectionError::LocallyClosed => Ok(()), control_send.write_all(&[response.len() as u8]).await?;
e => Err(e.into()), control_send.write_all(&response).await?;
} }
}, Ok::<_, tokio::io::Error>(())
r = async { });
while let Some(remote) = handle.next().await { let control_message_queue_ref = &control_message_queue;
match remote { set.spawn(async move {
routing::RouterRequest::RouteRequest(remote) => { loop {
let pair = connection.open_bi().await; let mut buf = vec![0u8; control_recv.read_u8().await? as _];
if let Err(ConnectionError::ApplicationClosed(_)) = pair { control_recv.read_exact(&mut buf).await?;
break; control_message_queue_ref.send(ClientboundControlMessage::UnknownMessage);
} else if let Err(ConnectionError::ConnectionClosed(_)) = pair { }
break; Ok::<_, tokio::io::Error>(())
} });
remote.send(pair?).map_err(|e| anyhow::anyhow!("{:?}", e))?; enum Event {
} RouterEvent(RouterRequest),
routing::RouterRequest::BroadcastRequest(message) => { TaskSet(Result<Result<(), tokio::io::Error>, JoinError>)
let response = }
serde_json::to_vec(&ClientboundControlMessage::RequestMessageBroadcast { while let Some(remote) = tokio::select! {
message, v = handle.next() => v.map(Event::RouterEvent),
})?; v = set.join_next() => v.map(Event::TaskSet),
send_control.write_all(&[response.len() as u8]).await?; } {
send_control.write_all(&response).await?; match remote {
Event::RouterEvent(RouterRequest::RouteRequest((handshake, mut client_stream))) => {
let stream = connection.open_bidirectional_stream().await;
set.spawn(async move {
if let Err(
ConnectionError::Transport { .. }
| ConnectionError::Application { .. }
| ConnectionError::EndpointClosing { .. }
| ConnectionError::ImmediateClose { .. },
) = stream
{
Ok(())
} else {
let mut stream = stream?;
handshake.send(&mut stream).await?;
tokio::io::copy_bidirectional(&mut stream, &mut client_stream).await?;
Ok::<_, tokio::io::Error>(())
} }
});
}
Event::RouterEvent(RouterRequest::BroadcastRequest(message)) => {
control_message_queue.send(ClientboundControlMessage::RequestMessageBroadcast {
message,
});
}
Event::TaskSet(Ok(Ok(()))) => {}
Event::TaskSet(Ok(Err(e))) => {
if e.kind() != ErrorKind::UnexpectedEof {
error!("Error in task: {e:?}")
} }
} }
Ok(()) Event::TaskSet(Err(e)) => {
} => r error!("Error in task: {e:?}")
}
}
} }
send_task.abort();
Ok(())
} }
async fn handle_quic(connection: Connecting, routing_table: &RoutingTable) { async fn handle_quic(connection: Connection, routing_table: &RoutingTable) {
if let Err(e) = try_handle_quic(connection, routing_table).await { if let Err(e) = try_handle_quic(connection, routing_table).await {
error!("Error handling QUIClime connection: {}", e); error!("Error handling QUIClime connection: {}", e);
}; };
@ -188,7 +224,7 @@ async fn handle_quic(connection: Connecting, routing_table: &RoutingTable) {
} }
async fn listen_quic( async fn listen_quic(
endpoint: &'static Endpoint, mut endpoint: Server,
routing_table: &'static RoutingTable, routing_table: &'static RoutingTable,
) -> anyhow::Result<Infallible> { ) -> anyhow::Result<Infallible> {
while let Some(connection) = endpoint.accept().await { while let Some(connection) = endpoint.accept().await {
@ -198,7 +234,6 @@ async fn listen_quic(
} }
async fn listen_control( async fn listen_control(
endpoint: &'static Endpoint,
routing_table: &'static RoutingTable, routing_table: &'static RoutingTable,
) -> anyhow::Result<Infallible> { ) -> anyhow::Result<Infallible> {
let app = axum::Router::new() let app = axum::Router::new()

@ -1,20 +1,21 @@
use log::info; use log::info;
use log::warn; use log::warn;
use parking_lot::RwLock; use parking_lot::RwLock;
use quinn::RecvStream;
use quinn::SendStream;
use rand::prelude::*; use rand::prelude::*;
use std::collections::HashMap; use std::collections::HashMap;
use tokio::net::TcpStream;
use tokio::sync::mpsc; use tokio::sync::mpsc;
use tokio::sync::oneshot; use tokio::sync::oneshot;
use crate::netty::Handshake;
#[derive(Debug)] #[derive(Debug)]
pub enum RouterRequest { pub enum RouterRequest {
RouteRequest(RouterCallback), RouteRequest(RouterConnection),
BroadcastRequest(String), BroadcastRequest(String),
} }
type RouterCallback = oneshot::Sender<(SendStream, RecvStream)>; type RouterConnection = (Handshake, TcpStream);
type RouteRequestReceiver = mpsc::UnboundedSender<RouterRequest>; type RouteRequestReceiver = mpsc::UnboundedSender<RouterRequest>;
#[allow(clippy::module_name_repetitions)] #[allow(clippy::module_name_repetitions)]
@ -44,14 +45,12 @@ impl RoutingTable {
} }
} }
pub async fn route(&self, domain: &str) -> Option<(SendStream, RecvStream)> { pub fn route(&self, domain: &str, conn: RouterConnection) -> Option<()> {
let (send, recv) = oneshot::channel();
self.table self.table
.read() .read()
.get(domain)? .get(domain)?
.send(RouterRequest::RouteRequest(send)) .send(RouterRequest::RouteRequest(conn))
.ok()?; .ok()
recv.await.ok()
} }
pub fn register(&self) -> RoutingHandle { pub fn register(&self) -> RoutingHandle {

Loading…
Cancel
Save