diff --git a/src/main.rs b/src/main.rs index e5d506a..4715fd1 100644 --- a/src/main.rs +++ b/src/main.rs @@ -5,6 +5,7 @@ use std::{convert::Infallible, net::SocketAddr, sync::Arc, time::Duration}; use axum::{ + extract::Path, http::StatusCode, routing::{get, post}, }; @@ -14,7 +15,7 @@ use netty::{Handshake, ReadError}; use quinn::{ crypto::rustls::QuicServerConfig, rustls::pki_types::{CertificateDer, PrivateKeyDer}, - ConnectionError, Endpoint, Incoming, ServerConfig, TransportConfig, + ConnectionError, Endpoint, Incoming, ServerConfig, TransportConfig, VarInt, }; use routing::{RoutingError, RoutingTable}; use tokio::{ @@ -23,8 +24,9 @@ use tokio::{ }; use crate::{ - netty::{ReadExt, WriteExt}, + netty::{read_varint, ReadExt, WriteExt}, proto::{ClientboundControlMessage, ServerboundControlMessage}, + routing::RouterRequest, }; mod netty; @@ -95,26 +97,50 @@ async fn try_handle_quic(connection: Incoming, routing_table: &RoutingTable) -> ); let (mut send_control, mut recv_control) = connection.accept_bi().await?; info!("Control channel open: {}", connection.remote_address()); - let mut handle = loop { - let mut buf = vec![0u8; recv_control.read_u8().await? as _]; + + let mut dialtone_ticket = None; + + let (mut handle, sender) = loop { + let len = read_varint(&mut recv_control).await?; + if !(0..=8192).contains(&len) { + connection.close(VarInt::from_u32(0), &[]); + return Ok(()); + } + let mut buf = vec![0u8; len as usize]; recv_control.read_exact(&mut buf).await?; if let Ok(parsed) = serde_json::from_slice(&buf) { match parsed { + ServerboundControlMessage::ProbeCapabilities => { + let response = + serde_json::to_vec(&ClientboundControlMessage::HasCapabilities { + caps: vec!["dialtone_sidecar".to_string()], + })?; + send_control.write_all(&[response.len() as u8]).await?; + send_control.write_all(&response).await?; + continue; + } ServerboundControlMessage::RequestDomainAssignment => { let handle = routing_table.register(); info!( "Domain assigned to {}: {}", connection.remote_address(), - handle.domain() + handle.0.domain() ); let response = serde_json::to_vec(&ClientboundControlMessage::DomainAssignmentComplete { - domain: handle.domain().to_string(), + domain: handle.0.domain().to_string(), })?; send_control.write_all(&[response.len() as u8]).await?; send_control.write_all(&response).await?; break handle; } + ServerboundControlMessage::DialtoneRegisterTicket { ticket } => { + dialtone_ticket = Some(ticket); + let response = + serde_json::to_vec(&ClientboundControlMessage::TicketRegistered)?; + send_control.write_all(&[response.len() as u8]).await?; + send_control.write_all(&response).await?; + } } } let response = serde_json::to_vec(&ClientboundControlMessage::UnknownMessage)?; @@ -150,10 +176,43 @@ async fn try_handle_quic(connection: Incoming, routing_table: &RoutingTable) -> })?; send_control.write_all(&[response.len() as u8]).await?; send_control.write_all(&response).await?; + }, + routing::RouterRequest::ServerboundControlMessage(message) => { + match message { + ServerboundControlMessage::DialtoneRegisterTicket { ticket } => { + info!("registering ticket {ticket:?}"); + dialtone_ticket = Some(ticket); + let response = serde_json::to_vec(&ClientboundControlMessage::TicketRegistered)?; + send_control.write_all(&[response.len() as u8]).await?; + send_control.write_all(&response).await?; + }, + _ => { + let response = serde_json::to_vec(&ClientboundControlMessage::UnknownMessage)?; + send_control.write_all(&[response.len() as u8]).await?; + send_control.write_all(&response).await?; + } + } + }, + routing::RouterRequest::TicketRequest(callback) => { + _ = callback.send(dialtone_ticket.clone()); } } } Ok(()) + } => r, + r = async { + loop { + let len = read_varint(&mut recv_control).await?; + if !(0..=8192).contains(&len) { + connection.close(VarInt::from_u32(0), &[]); + return Ok(()); + } + let mut buf = vec![0u8; len as usize]; + recv_control.read_exact(&mut buf).await?; + if let Ok(parsed) = serde_json::from_slice(&buf) { + sender.send(RouterRequest::ServerboundControlMessage(parsed))?; + } + } } => r } } @@ -180,6 +239,18 @@ async fn listen_control( routing_table: &'static RoutingTable, ) -> eyre::Result { let app = axum::Router::new() + .route( + "/.well-known/dialtone_ticket/:domain", + get(async |Path(addr): Path| { + let Some(addr) = unicode_madness::validate_and_normalize_domain(&addr) else { + return (StatusCode::NOT_FOUND, String::new()); + }; + match routing_table.check_ticket(&addr).await { + Some(ticket) => (StatusCode::OK, ticket), + None => (StatusCode::NOT_FOUND, String::new()), + } + }), + ) .route( "/metrics", get(|| async { format!("host_count {}", routing_table.size()) }), diff --git a/src/netty.rs b/src/netty.rs index f15bf90..fb301d8 100644 --- a/src/netty.rs +++ b/src/netty.rs @@ -122,7 +122,7 @@ pub async fn read_packet( Ok(buf) } -async fn read_varint(mut reader: impl AsyncReadExt + Unpin) -> Result { +pub async fn read_varint(mut reader: impl AsyncReadExt + Unpin) -> Result { let mut res = 0i32; for i in 0..5 { let part = reader.read_u8().await?; diff --git a/src/proto.rs b/src/proto.rs index c48d59f..33b10e7 100644 --- a/src/proto.rs +++ b/src/proto.rs @@ -1,17 +1,21 @@ use serde::{Deserialize, Serialize}; -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, Debug)] #[serde(tag = "kind")] #[serde(rename_all = "snake_case")] pub enum ServerboundControlMessage { + ProbeCapabilities, RequestDomainAssignment, + DialtoneRegisterTicket { ticket: String }, } -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, Debug)] #[serde(tag = "kind")] #[serde(rename_all = "snake_case")] pub enum ClientboundControlMessage { UnknownMessage, + HasCapabilities { caps: Vec }, DomainAssignmentComplete { domain: String }, RequestMessageBroadcast { message: String }, + TicketRegistered, } diff --git a/src/routing.rs b/src/routing.rs index 7faeb67..9f98caf 100644 --- a/src/routing.rs +++ b/src/routing.rs @@ -11,13 +11,18 @@ use std::net::IpAddr; use tokio::sync::mpsc; use tokio::sync::oneshot; +use crate::proto::ServerboundControlMessage; + #[derive(Debug)] pub enum RouterRequest { RouteRequest(RouterCallback), BroadcastRequest(String), + ServerboundControlMessage(ServerboundControlMessage), + TicketRequest(TicketCallback), } type RouterCallback = oneshot::Sender<(SendStream, RecvStream)>; +type TicketCallback = oneshot::Sender>; type RouteRequestReceiver = mpsc::UnboundedSender; #[allow(clippy::module_name_repetitions)] @@ -82,7 +87,7 @@ impl RoutingTable { ) } - pub fn register(&self) -> RoutingHandle { + pub fn register(&self) -> (RoutingHandle, RouteRequestReceiver) { let mut lock = self.table.write(); let mut domain = self.random_domain(); while lock.contains_key(&domain) { @@ -95,12 +100,25 @@ impl RoutingTable { domain = crate::unicode_madness::validate_and_normalize_domain(&domain) .expect("Resulting domain is not valid"); let (send, recv) = mpsc::unbounded_channel(); - lock.insert(domain.clone(), send); - RoutingHandle { - recv, - domain, - parent: self, - } + lock.insert(domain.clone(), send.clone()); + ( + RoutingHandle { + recv, + domain, + parent: self, + }, + send, + ) + } + + pub async fn check_ticket(&self, domain: &str) -> Option { + let (send, recv) = oneshot::channel(); + self.table + .read() + .get(domain)? + .send(RouterRequest::TicketRequest(send)) + .ok()?; + recv.await.ok()? } }