halfway done s2n-quic rewrite
This commit is contained in:
parent
0d4aae1d63
commit
6141fe71a5
4 changed files with 728 additions and 293 deletions
864
Cargo.lock
generated
864
Cargo.lock
generated
File diff suppressed because it is too large
Load diff
|
@ -14,10 +14,9 @@ env_logger = "0.10.0"
|
|||
idna = "0.4.0"
|
||||
log = "0.4.19"
|
||||
parking_lot = "0.12.1"
|
||||
quinn = "0.10.1"
|
||||
rand = "0.8.5"
|
||||
rustls = "0.21.9"
|
||||
rustls-pemfile = "1.0.2"
|
||||
s2n-quic = { version = "1.46.0", default-features = false, features = ["provider-address-token-default", "provider-tls-rustls"] }
|
||||
serde = { version = "1.0.164", features = ["derive"] }
|
||||
serde_json = "1.0.97"
|
||||
thiserror = "1.0.40"
|
||||
|
|
135
src/main.rs
135
src/main.rs
|
@ -2,7 +2,7 @@
|
|||
#![allow(clippy::cast_possible_truncation)]
|
||||
#![allow(clippy::cast_possible_wrap)]
|
||||
|
||||
use std::{convert::Infallible, net::SocketAddr, sync::Arc, time::Duration};
|
||||
use std::{convert::Infallible, io::ErrorKind, net::SocketAddr, sync::Arc, time::Duration};
|
||||
|
||||
use anyhow::{anyhow, Context};
|
||||
use axum::{
|
||||
|
@ -11,17 +11,17 @@ use axum::{
|
|||
};
|
||||
use log::{error, info};
|
||||
use netty::{Handshake, ReadError};
|
||||
use quinn::{Connecting, ConnectionError, Endpoint, ServerConfig, TransportConfig};
|
||||
use routing::RoutingTable;
|
||||
use rustls::{Certificate, PrivateKey};
|
||||
use s2n_quic::{connection::Error as ConnectionError, Connection, Server};
|
||||
use tokio::{
|
||||
io::{AsyncReadExt, AsyncWriteExt},
|
||||
net::TcpStream,
|
||||
task::{JoinError, JoinSet},
|
||||
};
|
||||
|
||||
use crate::{
|
||||
netty::{ReadExt, WriteExt},
|
||||
proto::{ClientboundControlMessage, ServerboundControlMessage},
|
||||
proto::{ClientboundControlMessage, ServerboundControlMessage}, routing::RouterRequest,
|
||||
};
|
||||
|
||||
mod netty;
|
||||
|
@ -107,80 +107,116 @@ async fn main() -> anyhow::Result<()> {
|
|||
}
|
||||
|
||||
async fn try_handle_quic(
|
||||
connection: Connecting,
|
||||
connection: Connection,
|
||||
routing_table: &RoutingTable,
|
||||
) -> anyhow::Result<()> {
|
||||
let connection = connection.await?;
|
||||
info!(
|
||||
"QUIClime connection established to: {}",
|
||||
connection.remote_address()
|
||||
connection.remote_addr()?
|
||||
);
|
||||
let (mut send_control, mut recv_control) = connection.accept_bi().await?;
|
||||
info!("Control channel open: {}", connection.remote_address());
|
||||
let mut control = connection
|
||||
.accept_bidirectional_stream()
|
||||
.await?
|
||||
.ok_or(anyhow!(
|
||||
"Connection closed while waiting for control channel"
|
||||
))?;
|
||||
info!("Control channel open: {}", connection.remote_addr()?);
|
||||
let mut handle = loop {
|
||||
let mut buf = vec![0u8; recv_control.read_u8().await? as _];
|
||||
recv_control.read_exact(&mut buf).await?;
|
||||
let mut buf = vec![0u8; control.read_u8().await? as _];
|
||||
control.read_exact(&mut buf).await?;
|
||||
if let Ok(parsed) = serde_json::from_slice(&buf) {
|
||||
match parsed {
|
||||
ServerboundControlMessage::RequestDomainAssignment => {
|
||||
let handle = routing_table.register();
|
||||
info!(
|
||||
"Domain assigned to {}: {}",
|
||||
connection.remote_address(),
|
||||
connection.remote_addr()?,
|
||||
handle.domain()
|
||||
);
|
||||
let response =
|
||||
serde_json::to_vec(&ClientboundControlMessage::DomainAssignmentComplete {
|
||||
domain: handle.domain().to_string(),
|
||||
})?;
|
||||
send_control.write_all(&[response.len() as u8]).await?;
|
||||
send_control.write_all(&response).await?;
|
||||
control.write_all(&[response.len() as u8]).await?;
|
||||
control.write_all(&response).await?;
|
||||
break handle;
|
||||
}
|
||||
}
|
||||
}
|
||||
let response = serde_json::to_vec(&ClientboundControlMessage::UnknownMessage)?;
|
||||
send_control.write_all(&[response.len() as u8]).await?;
|
||||
send_control.write_all(&response).await?;
|
||||
control.write_all(&[response.len() as u8]).await?;
|
||||
control.write_all(&response).await?;
|
||||
};
|
||||
|
||||
tokio::select! {
|
||||
e = connection.closed() => {
|
||||
match e {
|
||||
ConnectionError::ConnectionClosed(_)
|
||||
| ConnectionError::ApplicationClosed(_)
|
||||
| ConnectionError::LocallyClosed => Ok(()),
|
||||
e => Err(e.into()),
|
||||
let mut set = JoinSet::new();
|
||||
let (control_message_queue, mut control_message_queue_recv) = tokio::sync::mpsc::unbounded_channel();
|
||||
let (mut control_recv, mut control_send) = control.split();
|
||||
let send_task = tokio::spawn(async move {
|
||||
while let Some(event) = control_message_queue_recv.recv().await {
|
||||
let response = serde_json::to_vec(&event)?;
|
||||
control_send.write_all(&[response.len() as u8]).await?;
|
||||
control_send.write_all(&response).await?;
|
||||
}
|
||||
Ok::<_, tokio::io::Error>(())
|
||||
});
|
||||
let control_message_queue_ref = &control_message_queue;
|
||||
set.spawn(async move {
|
||||
loop {
|
||||
let mut buf = vec![0u8; control_recv.read_u8().await? as _];
|
||||
control_recv.read_exact(&mut buf).await?;
|
||||
control_message_queue_ref.send(ClientboundControlMessage::UnknownMessage);
|
||||
}
|
||||
Ok::<_, tokio::io::Error>(())
|
||||
});
|
||||
enum Event {
|
||||
RouterEvent(RouterRequest),
|
||||
TaskSet(Result<Result<(), tokio::io::Error>, JoinError>)
|
||||
}
|
||||
while let Some(remote) = tokio::select! {
|
||||
v = handle.next() => v.map(Event::RouterEvent),
|
||||
v = set.join_next() => v.map(Event::TaskSet),
|
||||
} {
|
||||
match remote {
|
||||
Event::RouterEvent(RouterRequest::RouteRequest((handshake, mut client_stream))) => {
|
||||
let stream = connection.open_bidirectional_stream().await;
|
||||
set.spawn(async move {
|
||||
if let Err(
|
||||
ConnectionError::Transport { .. }
|
||||
| ConnectionError::Application { .. }
|
||||
| ConnectionError::EndpointClosing { .. }
|
||||
| ConnectionError::ImmediateClose { .. },
|
||||
) = stream
|
||||
{
|
||||
Ok(())
|
||||
} else {
|
||||
let mut stream = stream?;
|
||||
handshake.send(&mut stream).await?;
|
||||
tokio::io::copy_bidirectional(&mut stream, &mut client_stream).await?;
|
||||
Ok::<_, tokio::io::Error>(())
|
||||
}
|
||||
});
|
||||
}
|
||||
},
|
||||
r = async {
|
||||
while let Some(remote) = handle.next().await {
|
||||
match remote {
|
||||
routing::RouterRequest::RouteRequest(remote) => {
|
||||
let pair = connection.open_bi().await;
|
||||
if let Err(ConnectionError::ApplicationClosed(_)) = pair {
|
||||
break;
|
||||
} else if let Err(ConnectionError::ConnectionClosed(_)) = pair {
|
||||
break;
|
||||
}
|
||||
remote.send(pair?).map_err(|e| anyhow::anyhow!("{:?}", e))?;
|
||||
}
|
||||
routing::RouterRequest::BroadcastRequest(message) => {
|
||||
let response =
|
||||
serde_json::to_vec(&ClientboundControlMessage::RequestMessageBroadcast {
|
||||
message,
|
||||
})?;
|
||||
send_control.write_all(&[response.len() as u8]).await?;
|
||||
send_control.write_all(&response).await?;
|
||||
}
|
||||
Event::RouterEvent(RouterRequest::BroadcastRequest(message)) => {
|
||||
control_message_queue.send(ClientboundControlMessage::RequestMessageBroadcast {
|
||||
message,
|
||||
});
|
||||
}
|
||||
Event::TaskSet(Ok(Ok(()))) => {}
|
||||
Event::TaskSet(Ok(Err(e))) => {
|
||||
if e.kind() != ErrorKind::UnexpectedEof {
|
||||
error!("Error in task: {e:?}")
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
} => r
|
||||
Event::TaskSet(Err(e)) => {
|
||||
error!("Error in task: {e:?}")
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
send_task.abort();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn handle_quic(connection: Connecting, routing_table: &RoutingTable) {
|
||||
async fn handle_quic(connection: Connection, routing_table: &RoutingTable) {
|
||||
if let Err(e) = try_handle_quic(connection, routing_table).await {
|
||||
error!("Error handling QUIClime connection: {}", e);
|
||||
};
|
||||
|
@ -188,7 +224,7 @@ async fn handle_quic(connection: Connecting, routing_table: &RoutingTable) {
|
|||
}
|
||||
|
||||
async fn listen_quic(
|
||||
endpoint: &'static Endpoint,
|
||||
mut endpoint: Server,
|
||||
routing_table: &'static RoutingTable,
|
||||
) -> anyhow::Result<Infallible> {
|
||||
while let Some(connection) = endpoint.accept().await {
|
||||
|
@ -198,7 +234,6 @@ async fn listen_quic(
|
|||
}
|
||||
|
||||
async fn listen_control(
|
||||
endpoint: &'static Endpoint,
|
||||
routing_table: &'static RoutingTable,
|
||||
) -> anyhow::Result<Infallible> {
|
||||
let app = axum::Router::new()
|
||||
|
|
|
@ -1,20 +1,21 @@
|
|||
use log::info;
|
||||
use log::warn;
|
||||
use parking_lot::RwLock;
|
||||
use quinn::RecvStream;
|
||||
use quinn::SendStream;
|
||||
use rand::prelude::*;
|
||||
use std::collections::HashMap;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::sync::oneshot;
|
||||
|
||||
use crate::netty::Handshake;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum RouterRequest {
|
||||
RouteRequest(RouterCallback),
|
||||
RouteRequest(RouterConnection),
|
||||
BroadcastRequest(String),
|
||||
}
|
||||
|
||||
type RouterCallback = oneshot::Sender<(SendStream, RecvStream)>;
|
||||
type RouterConnection = (Handshake, TcpStream);
|
||||
type RouteRequestReceiver = mpsc::UnboundedSender<RouterRequest>;
|
||||
|
||||
#[allow(clippy::module_name_repetitions)]
|
||||
|
@ -44,14 +45,12 @@ impl RoutingTable {
|
|||
}
|
||||
}
|
||||
|
||||
pub async fn route(&self, domain: &str) -> Option<(SendStream, RecvStream)> {
|
||||
let (send, recv) = oneshot::channel();
|
||||
pub fn route(&self, domain: &str, conn: RouterConnection) -> Option<()> {
|
||||
self.table
|
||||
.read()
|
||||
.get(domain)?
|
||||
.send(RouterRequest::RouteRequest(send))
|
||||
.ok()?;
|
||||
recv.await.ok()
|
||||
.send(RouterRequest::RouteRequest(conn))
|
||||
.ok()
|
||||
}
|
||||
|
||||
pub fn register(&self) -> RoutingHandle {
|
||||
|
|
Loading…
Reference in a new issue