use std::collections::HashMap; use std::env; use std::net::SocketAddr; use anyhow::{anyhow, Result}; use async_tungstenite::tokio::TokioAdapter; use async_tungstenite::tungstenite::Message; use async_tungstenite::WebSocketStream; use e4mc_common::{ClientboundControlMessage, ServerboundControlMessage}; use futures::stream::{SplitSink, SplitStream}; use futures::{Sink, SinkExt, Stream, StreamExt}; use lazy_static::lazy_static; use log::{error, info, trace, warn}; use netty::ReadExtNetty; use rand::seq::SliceRandom; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::tcp::{ReadHalf, WriteHalf}; use tokio::net::{TcpListener, TcpStream}; use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender}; use tokio::sync::{Mutex, RwLock}; use tokio::task; use crate::netty::{NettyReadError, WriteExtNetty}; mod netty; mod wordlist; #[derive(Debug, Clone)] struct Handshake { protocol_version: i32, server_address: String, server_port: u16, next_state: HandshakeType, } #[derive(Debug, Clone, Copy)] #[repr(i32)] enum HandshakeType { Status = 1, Login = 2, } impl Handshake { async fn new(mut packet: &[u8]) -> Result { let packet_type = packet.read_varint().await?; if packet_type != 0 { Err(anyhow!("Not a Handshake packet")) } else { let protocol_version = packet.read_varint().await?; let server_address = packet.read_string().await?; let server_port = packet.read_u16().await?; let next_state = match packet.read_varint().await? { 1 => HandshakeType::Status, 2 => HandshakeType::Login, _ => return Err(anyhow!("Invalid next state")), }; Ok(Self { protocol_version, server_address, server_port, next_state, }) } } async fn send(&self, mut writer: impl AsyncWriteExt + Unpin + Send) -> tokio::io::Result<()> { let mut buf = vec![]; buf.write_varint(0).await?; buf.write_varint(self.protocol_version).await?; buf.write_string(&self.server_address).await?; buf.write_all(&self.server_port.to_be_bytes()).await?; buf.write_varint(self.next_state as i32).await?; writer.write_varint(buf.len() as i32).await?; writer.write_all(&buf).await?; Ok(()) } } #[derive(Debug)] enum WebsocketHandlerMessage { ChannelOpen { id_callback: tokio::sync::oneshot::Sender>, backchannel: UnboundedSender, addr: SocketAddr, }, ChannelClose(u8), Data(Vec), } #[derive(Debug, Clone, PartialEq, Eq)] enum MinecraftHandlerMessage { ChannelClose, Data(Vec), } lazy_static! { static ref ROUTING_MAP: RwLock>> = RwLock::new(HashMap::new()); static ref BASE_DOMAIN: String = env::var("BASE_DOMAIN").expect("BASE_DOMAIN missing"); } #[tokio::main] async fn main() -> Result<()> { let _ = env_logger::try_init(); let ws_bind_addr = env::var("WS_BIND_ADDR").unwrap_or_else(|_| "127.0.0.1:80".to_string()); let mc_bind_addr = env::var("MC_BIND_ADDR").unwrap_or_else(|_| "0.0.0.0:25565".to_string()); futures::try_join!( async { let listener = TcpListener::bind(&ws_bind_addr).await?; info!("WebSocket Listening on: {}", ws_bind_addr); while let Ok((stream, _)) = listener.accept().await { task::spawn(async { if let Err(e) = accept_websocket_connection(stream).await { error!("Error handling WebSocket connection: {}", e); } }); } Ok(()) as Result<()> }, async { let listener = TcpListener::bind(&mc_bind_addr).await?; info!("Minecraft Listening on: {}", mc_bind_addr); while let Ok((stream, _)) = listener.accept().await { task::spawn(async { if let Err(e) = accept_minecraft_connection(stream).await { error!("Error handling Minecraft connection: {}", e); } }); } Ok(()) as Result<()> } )?; Ok(()) } struct RoutingHandle { receiver: UnboundedReceiver, domain: String, } impl RoutingHandle { async fn new(domain: String) -> Self { let (sender, receiver) = tokio::sync::mpsc::unbounded_channel(); ROUTING_MAP.write().await.insert(domain.clone(), sender); Self { receiver, domain } } async fn recv(&mut self) -> Option { self.receiver.recv().await } } impl Drop for RoutingHandle { fn drop(&mut self) { let domain = self.domain.clone(); tokio::spawn(async move { ROUTING_MAP.write().await.remove(&domain); }); } } async fn accept_websocket_connection(stream: TcpStream) -> Result<()> { let addr = stream.peer_addr()?; info!("WebSocket Peer address: {}", addr); let mut ws_stream = async_tungstenite::tokio::accept_async(stream).await?; info!("New WebSocket connection: {}", addr); let domain = get_random_domain().await; ws_stream .send(ClientboundControlMessage::DomainAssigned(domain.clone()).into()) .await?; let handle = RoutingHandle::new(domain.clone()).await; let channel_table = Mutex::new(HashMap::new()); let (mut write, mut read) = ws_stream.split(); tokio::select! { res = handle_websocket_send(&mut write, &channel_table, handle) => { if let Err(e) = res { error!("Error on WebSocket send loop: {}", e); } }, res = handle_websocket_recv(&mut read, &channel_table) => { if let Err(e) = res { error!("Error on WebSocket recv loop: {}", e); } } } let _ = write.reunite(read)?.close(None).await; // if this errors, why bother? Ok(()) } async fn handle_websocket_send( write: &mut SplitSink>, Message>, channel_table: &Mutex>>, mut handle: RoutingHandle, ) -> Result<()> { while let Some(data) = handle.recv().await { match data { WebsocketHandlerMessage::ChannelOpen { id_callback, backchannel, addr, } => { let mut table = channel_table.lock().await; let channel = get_available_channel(table.keys().copied().collect()); if id_callback.send(channel).is_err() { warn!("ID callback died before we could send the assigned ID"); } else if let Some(channel) = channel { table.insert(channel, backchannel); write .send(ClientboundControlMessage::ChannelOpen(channel, addr).into()) .await?; } } WebsocketHandlerMessage::ChannelClose(channel) => { write .send(ClientboundControlMessage::ChannelClosed(channel).into()) .await?; channel_table.lock().await.remove(&channel); } WebsocketHandlerMessage::Data(buf) => { write.send(Message::Binary(buf)).await?; } } } Ok(()) } async fn handle_websocket_recv( read: &mut SplitStream>>, channel_table: &Mutex>>, ) -> Result<()> { while let Some(message) = read.next().await { if let Err(async_tungstenite::tungstenite::Error::ConnectionClosed) = message { info!("Connection closed normally"); return Ok(()); } let message = message?; match message { Message::Text(message) => { let message: ServerboundControlMessage = serde_json::from_str(&message)?; match message { ServerboundControlMessage::ChannelClosed(channel) => { if let Some(sender) = channel_table.lock().await.remove(&channel) { sender.send(MinecraftHandlerMessage::ChannelClose)?; } } } } Message::Binary(buf) => { let channel = buf[0]; if let Some(sender) = channel_table.lock().await.get(&channel) { sender.send(MinecraftHandlerMessage::Data(buf[1..].to_vec()))?; } } _ => {} } } Ok(()) } struct ChannelHandle { receiver: UnboundedReceiver, sender: UnboundedSender, channel: u8, } impl ChannelHandle { async fn new( sender: &UnboundedSender, addr: SocketAddr, ) -> Result { let (mc_sender, mc_receiver) = tokio::sync::mpsc::unbounded_channel(); let (id_sender, id_receiver) = tokio::sync::oneshot::channel(); sender.send(WebsocketHandlerMessage::ChannelOpen { id_callback: id_sender, backchannel: mc_sender, addr, })?; if let Some(channel) = id_receiver.await? { Ok(Self { receiver: mc_receiver, sender: sender.clone(), channel, }) } else { Err(anyhow!("Websocket handler couldn't give channel")) } } fn send(&self, mut buf: Vec) -> Result<()> { buf.insert(0, self.channel); self.sender.send(WebsocketHandlerMessage::Data(buf))?; Ok(()) } async fn recv(&mut self) -> Option { self.receiver.recv().await } fn split(&mut self) -> (ChannelHandleSend, ChannelHandleRecv) { ( ChannelHandleSend { sender: &mut self.sender, channel: self.channel, }, ChannelHandleRecv { receiver: &mut self.receiver, }, ) } } struct ChannelHandleSend<'a> { sender: &'a mut UnboundedSender, channel: u8, } impl ChannelHandleSend<'_> { fn send(&self, mut buf: Vec) -> Result<()> { buf.insert(0, self.channel); self.sender.send(WebsocketHandlerMessage::Data(buf))?; Ok(()) } } struct ChannelHandleRecv<'a> { receiver: &'a mut UnboundedReceiver, } impl ChannelHandleRecv<'_> { async fn recv(&mut self) -> Option { self.receiver.recv().await } } impl Drop for ChannelHandle { fn drop(&mut self) { self.sender .send(WebsocketHandlerMessage::ChannelClose(self.channel)) .unwrap(); } } async fn accept_minecraft_connection(mut stream: TcpStream) -> Result<()> { let addr = stream.peer_addr()?; info!("Minecraft Peer address: {}", addr); let packet = stream.read_packet().await; if let Err(NettyReadError::LegacyServerListPing) = packet { stream .write_all(include_bytes!("legacy_serverlistping_response.bin")) .await?; return Ok(()); } let handshake = Handshake::new(&packet?).await?; info!("Minecraft client {} is connecting to {}", addr, handshake.server_address); if let Some(sender) = ROUTING_MAP.read().await.get(&handshake.server_address) { let mut handle = ChannelHandle::new(sender, addr).await?; let mut buf = vec![]; handshake.send(&mut buf).await?; handle.send(buf)?; let (send, recv) = handle.split(); let (read, write) = stream.split(); tokio::select! { res = handle_minecraft_send(write, recv) => { if let Err(e) = res { error!("Error on Minecraft send loop: {}", e); } }, res = handle_minecraft_recv(read, send) => { if let Err(e) = res { error!("Error on Minecraft recv loop: {}", e); } } } stream.shutdown().await?; } else { match handshake.next_state { HandshakeType::Status => { let mut buf = vec![]; buf.write_varint(0).await?; buf.write_string(include_str!("./serverlistping_response.json")) .await?; stream.write_varint(buf.len() as i32).await?; stream.write_all(&buf).await?; } HandshakeType::Login => { let _ = stream.read_packet().await?; let mut buf = vec![]; buf.write_varint(0).await?; buf.write_string(include_str!("./disconnect_response.json")) .await?; stream.write_varint(buf.len() as i32).await?; stream.write_all(&buf).await?; } } } Ok(()) } async fn handle_minecraft_send( mut write: WriteHalf<'_>, mut recv: ChannelHandleRecv<'_>, ) -> Result<()> { while let Some(data) = recv.recv().await { match data { MinecraftHandlerMessage::ChannelClose => return Ok(()), MinecraftHandlerMessage::Data(buf) => write.write_all(&buf).await?, } } Ok(()) } async fn handle_minecraft_recv(mut read: ReadHalf<'_>, send: ChannelHandleSend<'_>) -> Result<()> { let mut buf = [0u8; 1024]; loop { let ready = read.ready(tokio::io::Interest::READABLE).await?; if ready.is_read_closed() { return Ok(()); } let len = read.read(&mut buf).await?; if len == 0 { continue; } let packet = Vec::from(&buf[..len]); trace!("send: {:?}", packet); send.send(packet)?; } } async fn get_random_domain() -> String { let mut domain = format!( "{}-{}.{}", wordlist::ID_WORDS.choose(&mut rand::thread_rng()).unwrap(), wordlist::ID_WORDS.choose(&mut rand::thread_rng()).unwrap(), BASE_DOMAIN.as_str() ); let map = ROUTING_MAP.read().await; while map.contains_key(&domain) { warn!("Randomly selected domain {} conflicts; trying again", domain); domain = format!( "{}-{}.{}", wordlist::ID_WORDS.choose(&mut rand::thread_rng()).unwrap(), wordlist::ID_WORDS.choose(&mut rand::thread_rng()).unwrap(), BASE_DOMAIN.as_str() ); } domain } fn get_available_channel(used: Vec) -> Option { (0u8..=255).find(|&i| !used.contains(&i)) }