This commit is contained in:
Skye 2025-12-28 20:49:40 +09:00
parent bf24b4ecaa
commit b5b0676e64
Signed by: me
GPG key ID: 0104BC05F41B77B8
4 changed files with 109 additions and 16 deletions

View file

@ -5,6 +5,7 @@
use std::{convert::Infallible, net::SocketAddr, sync::Arc, time::Duration};
use axum::{
extract::Path,
http::StatusCode,
routing::{get, post},
};
@ -14,7 +15,7 @@ use netty::{Handshake, ReadError};
use quinn::{
crypto::rustls::QuicServerConfig,
rustls::pki_types::{CertificateDer, PrivateKeyDer},
ConnectionError, Endpoint, Incoming, ServerConfig, TransportConfig,
ConnectionError, Endpoint, Incoming, ServerConfig, TransportConfig, VarInt,
};
use routing::{RoutingError, RoutingTable};
use tokio::{
@ -23,8 +24,9 @@ use tokio::{
};
use crate::{
netty::{ReadExt, WriteExt},
netty::{read_varint, ReadExt, WriteExt},
proto::{ClientboundControlMessage, ServerboundControlMessage},
routing::RouterRequest,
};
mod netty;
@ -95,26 +97,50 @@ async fn try_handle_quic(connection: Incoming, routing_table: &RoutingTable) ->
);
let (mut send_control, mut recv_control) = connection.accept_bi().await?;
info!("Control channel open: {}", connection.remote_address());
let mut handle = loop {
let mut buf = vec![0u8; recv_control.read_u8().await? as _];
let mut dialtone_ticket = None;
let (mut handle, sender) = loop {
let len = read_varint(&mut recv_control).await?;
if !(0..=8192).contains(&len) {
connection.close(VarInt::from_u32(0), &[]);
return Ok(());
}
let mut buf = vec![0u8; len as usize];
recv_control.read_exact(&mut buf).await?;
if let Ok(parsed) = serde_json::from_slice(&buf) {
match parsed {
ServerboundControlMessage::ProbeCapabilities => {
let response =
serde_json::to_vec(&ClientboundControlMessage::HasCapabilities {
caps: vec!["dialtone_sidecar".to_string()],
})?;
send_control.write_all(&[response.len() as u8]).await?;
send_control.write_all(&response).await?;
continue;
}
ServerboundControlMessage::RequestDomainAssignment => {
let handle = routing_table.register();
info!(
"Domain assigned to {}: {}",
connection.remote_address(),
handle.domain()
handle.0.domain()
);
let response =
serde_json::to_vec(&ClientboundControlMessage::DomainAssignmentComplete {
domain: handle.domain().to_string(),
domain: handle.0.domain().to_string(),
})?;
send_control.write_all(&[response.len() as u8]).await?;
send_control.write_all(&response).await?;
break handle;
}
ServerboundControlMessage::DialtoneRegisterTicket { ticket } => {
dialtone_ticket = Some(ticket);
let response =
serde_json::to_vec(&ClientboundControlMessage::TicketRegistered)?;
send_control.write_all(&[response.len() as u8]).await?;
send_control.write_all(&response).await?;
}
}
}
let response = serde_json::to_vec(&ClientboundControlMessage::UnknownMessage)?;
@ -150,10 +176,43 @@ async fn try_handle_quic(connection: Incoming, routing_table: &RoutingTable) ->
})?;
send_control.write_all(&[response.len() as u8]).await?;
send_control.write_all(&response).await?;
},
routing::RouterRequest::ServerboundControlMessage(message) => {
match message {
ServerboundControlMessage::DialtoneRegisterTicket { ticket } => {
info!("registering ticket {ticket:?}");
dialtone_ticket = Some(ticket);
let response = serde_json::to_vec(&ClientboundControlMessage::TicketRegistered)?;
send_control.write_all(&[response.len() as u8]).await?;
send_control.write_all(&response).await?;
},
_ => {
let response = serde_json::to_vec(&ClientboundControlMessage::UnknownMessage)?;
send_control.write_all(&[response.len() as u8]).await?;
send_control.write_all(&response).await?;
}
}
},
routing::RouterRequest::TicketRequest(callback) => {
_ = callback.send(dialtone_ticket.clone());
}
}
}
Ok(())
} => r,
r = async {
loop {
let len = read_varint(&mut recv_control).await?;
if !(0..=8192).contains(&len) {
connection.close(VarInt::from_u32(0), &[]);
return Ok(());
}
let mut buf = vec![0u8; len as usize];
recv_control.read_exact(&mut buf).await?;
if let Ok(parsed) = serde_json::from_slice(&buf) {
sender.send(RouterRequest::ServerboundControlMessage(parsed))?;
}
}
} => r
}
}
@ -180,6 +239,18 @@ async fn listen_control(
routing_table: &'static RoutingTable,
) -> eyre::Result<Infallible> {
let app = axum::Router::new()
.route(
"/.well-known/dialtone_ticket/:domain",
get(async |Path(addr): Path<String>| {
let Some(addr) = unicode_madness::validate_and_normalize_domain(&addr) else {
return (StatusCode::NOT_FOUND, String::new());
};
match routing_table.check_ticket(&addr).await {
Some(ticket) => (StatusCode::OK, ticket),
None => (StatusCode::NOT_FOUND, String::new()),
}
}),
)
.route(
"/metrics",
get(|| async { format!("host_count {}", routing_table.size()) }),

View file

@ -122,7 +122,7 @@ pub async fn read_packet(
Ok(buf)
}
async fn read_varint(mut reader: impl AsyncReadExt + Unpin) -> Result<i32, ReadError> {
pub async fn read_varint(mut reader: impl AsyncReadExt + Unpin) -> Result<i32, ReadError> {
let mut res = 0i32;
for i in 0..5 {
let part = reader.read_u8().await?;

View file

@ -1,17 +1,21 @@
use serde::{Deserialize, Serialize};
#[derive(Serialize, Deserialize)]
#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "kind")]
#[serde(rename_all = "snake_case")]
pub enum ServerboundControlMessage {
ProbeCapabilities,
RequestDomainAssignment,
DialtoneRegisterTicket { ticket: String },
}
#[derive(Serialize, Deserialize)]
#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "kind")]
#[serde(rename_all = "snake_case")]
pub enum ClientboundControlMessage {
UnknownMessage,
HasCapabilities { caps: Vec<String> },
DomainAssignmentComplete { domain: String },
RequestMessageBroadcast { message: String },
TicketRegistered,
}

View file

@ -11,13 +11,18 @@ use std::net::IpAddr;
use tokio::sync::mpsc;
use tokio::sync::oneshot;
use crate::proto::ServerboundControlMessage;
#[derive(Debug)]
pub enum RouterRequest {
RouteRequest(RouterCallback),
BroadcastRequest(String),
ServerboundControlMessage(ServerboundControlMessage),
TicketRequest(TicketCallback),
}
type RouterCallback = oneshot::Sender<(SendStream, RecvStream)>;
type TicketCallback = oneshot::Sender<Option<String>>;
type RouteRequestReceiver = mpsc::UnboundedSender<RouterRequest>;
#[allow(clippy::module_name_repetitions)]
@ -82,7 +87,7 @@ impl RoutingTable {
)
}
pub fn register(&self) -> RoutingHandle {
pub fn register(&self) -> (RoutingHandle, RouteRequestReceiver) {
let mut lock = self.table.write();
let mut domain = self.random_domain();
while lock.contains_key(&domain) {
@ -95,12 +100,25 @@ impl RoutingTable {
domain = crate::unicode_madness::validate_and_normalize_domain(&domain)
.expect("Resulting domain is not valid");
let (send, recv) = mpsc::unbounded_channel();
lock.insert(domain.clone(), send);
RoutingHandle {
recv,
domain,
parent: self,
}
lock.insert(domain.clone(), send.clone());
(
RoutingHandle {
recv,
domain,
parent: self,
},
send,
)
}
pub async fn check_ticket(&self, domain: &str) -> Option<String> {
let (send, recv) = oneshot::channel();
self.table
.read()
.get(domain)?
.send(RouterRequest::TicketRequest(send))
.ok()?;
recv.await.ok()?
}
}