Compare commits

..

No commits in common. 's2n' and 'kity' have entirely different histories.
s2n ... kity

818
Cargo.lock generated

File diff suppressed because it is too large Load Diff

@ -14,9 +14,10 @@ env_logger = "0.10.0"
idna = "0.4.0"
log = "0.4.19"
parking_lot = "0.12.1"
quinn = "0.10.1"
rand = "0.8.5"
rustls = "0.21.9"
rustls-pemfile = "1.0.2"
s2n-quic = { version = "1.46.0", default-features = false, features = ["provider-address-token-default", "provider-tls-rustls"] }
serde = { version = "1.0.164", features = ["derive"] }
serde_json = "1.0.97"
thiserror = "1.0.40"
@ -24,4 +25,4 @@ tokio = { version = "1.28.2", features = ["rt-multi-thread", "fs", "macros", "io
[profile.release]
lto = "fat"
debug = "full"
debug = "full"

@ -2,7 +2,7 @@
#![allow(clippy::cast_possible_truncation)]
#![allow(clippy::cast_possible_wrap)]
use std::{convert::Infallible, io::ErrorKind, net::SocketAddr, sync::Arc, time::Duration};
use std::{convert::Infallible, net::SocketAddr, sync::Arc, time::Duration};
use anyhow::{anyhow, Context};
use axum::{
@ -11,17 +11,17 @@ use axum::{
};
use log::{error, info};
use netty::{Handshake, ReadError};
use quinn::{Connecting, ConnectionError, Endpoint, ServerConfig, TransportConfig};
use routing::RoutingTable;
use s2n_quic::{connection::Error as ConnectionError, Connection, Server};
use rustls::{Certificate, PrivateKey};
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
net::TcpStream,
task::{JoinError, JoinSet},
};
use crate::{
netty::{ReadExt, WriteExt},
proto::{ClientboundControlMessage, ServerboundControlMessage}, routing::RouterRequest,
proto::{ClientboundControlMessage, ServerboundControlMessage},
};
mod netty;
@ -107,116 +107,80 @@ async fn main() -> anyhow::Result<()> {
}
async fn try_handle_quic(
connection: Connection,
connection: Connecting,
routing_table: &RoutingTable,
) -> anyhow::Result<()> {
let connection = connection.await?;
info!(
"QUIClime connection established to: {}",
connection.remote_addr()?
connection.remote_address()
);
let mut control = connection
.accept_bidirectional_stream()
.await?
.ok_or(anyhow!(
"Connection closed while waiting for control channel"
))?;
info!("Control channel open: {}", connection.remote_addr()?);
let (mut send_control, mut recv_control) = connection.accept_bi().await?;
info!("Control channel open: {}", connection.remote_address());
let mut handle = loop {
let mut buf = vec![0u8; control.read_u8().await? as _];
control.read_exact(&mut buf).await?;
let mut buf = vec![0u8; recv_control.read_u8().await? as _];
recv_control.read_exact(&mut buf).await?;
if let Ok(parsed) = serde_json::from_slice(&buf) {
match parsed {
ServerboundControlMessage::RequestDomainAssignment => {
let handle = routing_table.register();
info!(
"Domain assigned to {}: {}",
connection.remote_addr()?,
connection.remote_address(),
handle.domain()
);
let response =
serde_json::to_vec(&ClientboundControlMessage::DomainAssignmentComplete {
domain: handle.domain().to_string(),
})?;
control.write_all(&[response.len() as u8]).await?;
control.write_all(&response).await?;
send_control.write_all(&[response.len() as u8]).await?;
send_control.write_all(&response).await?;
break handle;
}
}
}
let response = serde_json::to_vec(&ClientboundControlMessage::UnknownMessage)?;
control.write_all(&[response.len() as u8]).await?;
control.write_all(&response).await?;
send_control.write_all(&[response.len() as u8]).await?;
send_control.write_all(&response).await?;
};
let mut set = JoinSet::new();
let (control_message_queue, mut control_message_queue_recv) = tokio::sync::mpsc::unbounded_channel();
let (mut control_recv, mut control_send) = control.split();
let send_task = tokio::spawn(async move {
while let Some(event) = control_message_queue_recv.recv().await {
let response = serde_json::to_vec(&event)?;
control_send.write_all(&[response.len() as u8]).await?;
control_send.write_all(&response).await?;
}
Ok::<_, tokio::io::Error>(())
});
let control_message_queue_ref = &control_message_queue;
set.spawn(async move {
loop {
let mut buf = vec![0u8; control_recv.read_u8().await? as _];
control_recv.read_exact(&mut buf).await?;
control_message_queue_ref.send(ClientboundControlMessage::UnknownMessage);
}
Ok::<_, tokio::io::Error>(())
});
enum Event {
RouterEvent(RouterRequest),
TaskSet(Result<Result<(), tokio::io::Error>, JoinError>)
}
while let Some(remote) = tokio::select! {
v = handle.next() => v.map(Event::RouterEvent),
v = set.join_next() => v.map(Event::TaskSet),
} {
match remote {
Event::RouterEvent(RouterRequest::RouteRequest((handshake, mut client_stream))) => {
let stream = connection.open_bidirectional_stream().await;
set.spawn(async move {
if let Err(
ConnectionError::Transport { .. }
| ConnectionError::Application { .. }
| ConnectionError::EndpointClosing { .. }
| ConnectionError::ImmediateClose { .. },
) = stream
{
Ok(())
} else {
let mut stream = stream?;
handshake.send(&mut stream).await?;
tokio::io::copy_bidirectional(&mut stream, &mut client_stream).await?;
Ok::<_, tokio::io::Error>(())
}
});
}
Event::RouterEvent(RouterRequest::BroadcastRequest(message)) => {
control_message_queue.send(ClientboundControlMessage::RequestMessageBroadcast {
message,
});
tokio::select! {
e = connection.closed() => {
match e {
ConnectionError::ConnectionClosed(_)
| ConnectionError::ApplicationClosed(_)
| ConnectionError::LocallyClosed => Ok(()),
e => Err(e.into()),
}
Event::TaskSet(Ok(Ok(()))) => {}
Event::TaskSet(Ok(Err(e))) => {
if e.kind() != ErrorKind::UnexpectedEof {
error!("Error in task: {e:?}")
},
r = async {
while let Some(remote) = handle.next().await {
match remote {
routing::RouterRequest::RouteRequest(remote) => {
let pair = connection.open_bi().await;
if let Err(ConnectionError::ApplicationClosed(_)) = pair {
break;
} else if let Err(ConnectionError::ConnectionClosed(_)) = pair {
break;
}
remote.send(pair?).map_err(|e| anyhow::anyhow!("{:?}", e))?;
}
routing::RouterRequest::BroadcastRequest(message) => {
let response =
serde_json::to_vec(&ClientboundControlMessage::RequestMessageBroadcast {
message,
})?;
send_control.write_all(&[response.len() as u8]).await?;
send_control.write_all(&response).await?;
}
}
}
Event::TaskSet(Err(e)) => {
error!("Error in task: {e:?}")
}
}
Ok(())
} => r
}
send_task.abort();
Ok(())
}
async fn handle_quic(connection: Connection, routing_table: &RoutingTable) {
async fn handle_quic(connection: Connecting, routing_table: &RoutingTable) {
if let Err(e) = try_handle_quic(connection, routing_table).await {
error!("Error handling QUIClime connection: {}", e);
};
@ -224,7 +188,7 @@ async fn handle_quic(connection: Connection, routing_table: &RoutingTable) {
}
async fn listen_quic(
mut endpoint: Server,
endpoint: &'static Endpoint,
routing_table: &'static RoutingTable,
) -> anyhow::Result<Infallible> {
while let Some(connection) = endpoint.accept().await {
@ -234,6 +198,7 @@ async fn listen_quic(
}
async fn listen_control(
endpoint: &'static Endpoint,
routing_table: &'static RoutingTable,
) -> anyhow::Result<Infallible> {
let app = axum::Router::new()

@ -1,21 +1,20 @@
use log::info;
use log::warn;
use parking_lot::RwLock;
use quinn::RecvStream;
use quinn::SendStream;
use rand::prelude::*;
use std::collections::HashMap;
use tokio::net::TcpStream;
use tokio::sync::mpsc;
use tokio::sync::oneshot;
use crate::netty::Handshake;
#[derive(Debug)]
pub enum RouterRequest {
RouteRequest(RouterConnection),
RouteRequest(RouterCallback),
BroadcastRequest(String),
}
type RouterConnection = (Handshake, TcpStream);
type RouterCallback = oneshot::Sender<(SendStream, RecvStream)>;
type RouteRequestReceiver = mpsc::UnboundedSender<RouterRequest>;
#[allow(clippy::module_name_repetitions)]
@ -45,12 +44,14 @@ impl RoutingTable {
}
}
pub fn route(&self, domain: &str, conn: RouterConnection) -> Option<()> {
pub async fn route(&self, domain: &str) -> Option<(SendStream, RecvStream)> {
let (send, recv) = oneshot::channel();
self.table
.read()
.get(domain)?
.send(RouterRequest::RouteRequest(conn))
.ok()
.send(RouterRequest::RouteRequest(send))
.ok()?;
recv.await.ok()
}
pub fn register(&self) -> RoutingHandle {

Loading…
Cancel
Save