e4mc/server/src/main.rs
2023-05-14 01:12:25 +09:00

464 lines
15 KiB
Rust

use std::collections::HashMap;
use std::env;
use std::net::SocketAddr;
use anyhow::{anyhow, Result};
use async_tungstenite::tokio::TokioAdapter;
use async_tungstenite::tungstenite::Message;
use async_tungstenite::WebSocketStream;
use e4mc_common::{ClientboundControlMessage, ServerboundControlMessage};
use futures::stream::{SplitSink, SplitStream};
use futures::{Sink, SinkExt, Stream, StreamExt};
use lazy_static::lazy_static;
use log::{error, info, trace, warn};
use netty::ReadExtNetty;
use rand::seq::SliceRandom;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::tcp::{ReadHalf, WriteHalf};
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
use tokio::sync::{Mutex, RwLock};
use tokio::task;
use crate::netty::{NettyReadError, WriteExtNetty};
mod netty;
mod wordlist;
/// Decoded form of the packet that `accept_minecraft_connection` reads first
/// from every Minecraft client.
///
/// Parsed by `Handshake::new` and re-encoded by `Handshake::send`.
#[derive(Debug, Clone)]
struct Handshake {
/// Protocol version, read and written as a varint.
protocol_version: i32,
/// Host name the client asked for; used as the lookup key into `ROUTING_MAP`.
server_address: String,
/// Port the client asked for; `send` writes it as two big-endian bytes.
server_port: u16,
/// Whether the client wants a status response or to log in.
next_state: HandshakeType,
}
/// The "next state" requested by a `Handshake`.
///
/// The explicit discriminants double as the wire values: `Handshake::new`
/// maps the varints 1 and 2 onto these variants, and `Handshake::send`
/// writes `self.next_state as i32` back out.
#[derive(Debug, Clone, Copy)]
#[repr(i32)]
enum HandshakeType {
/// Wire value 1.
Status = 1,
/// Wire value 2.
Login = 2,
}
impl Handshake {
    /// Decodes a handshake from a packet body that starts with the packet id.
    ///
    /// # Errors
    /// Fails when a field cannot be read, when the packet id is not `0`, or
    /// when the trailing varint is neither `1` (status) nor `2` (login).
    async fn new(mut packet: &[u8]) -> Result<Self> {
        // Guard clause: only packet id 0 is a handshake.
        let packet_id = packet.read_varint().await?;
        if packet_id != 0 {
            return Err(anyhow!("Not a Handshake packet"));
        }

        let protocol_version = packet.read_varint().await?;
        let server_address = packet.read_string().await?;
        let server_port = packet.read_u16().await?;
        let requested_state = packet.read_varint().await?;
        let next_state = match requested_state {
            1 => HandshakeType::Status,
            2 => HandshakeType::Login,
            _ => return Err(anyhow!("Invalid next state")),
        };

        Ok(Self {
            protocol_version,
            server_address,
            server_port,
            next_state,
        })
    }

    /// Re-encodes this handshake and writes it to `writer`.
    ///
    /// The body (packet id `0` followed by the four fields) is assembled in
    /// memory first so its length can be written as a varint prefix before
    /// the body itself.
    async fn send(&self, mut writer: impl AsyncWriteExt + Unpin + Send) -> tokio::io::Result<()> {
        let mut body: Vec<u8> = Vec::new();
        body.write_varint(0).await?;
        body.write_varint(self.protocol_version).await?;
        body.write_string(&self.server_address).await?;
        let port_bytes = self.server_port.to_be_bytes();
        body.write_all(&port_bytes).await?;
        body.write_varint(self.next_state as i32).await?;

        // Length prefix first, then the body it describes.
        writer.write_varint(body.len() as i32).await?;
        writer.write_all(&body).await
    }
}
/// Messages delivered to a WebSocket connection's send loop
/// (`handle_websocket_send`) through the sender stored in `ROUTING_MAP`.
#[derive(Debug)]
enum WebsocketHandlerMessage {
/// Request for a new channel on behalf of a Minecraft client.
ChannelOpen {
/// Receives the assigned channel id, or `None` when no id is free.
id_callback: tokio::sync::oneshot::Sender<Option<u8>>,
/// Sender stored in the channel table so the WebSocket receive loop can
/// reach the requesting Minecraft handler.
backchannel: UnboundedSender<MinecraftHandlerMessage>,
/// Peer address of the Minecraft client, forwarded to the WebSocket
/// client in the `ChannelOpen` control message.
addr: SocketAddr,
},
/// The Minecraft side is finished with this channel id.
ChannelClose(u8),
/// Bytes forwarded as one binary WebSocket message; `ChannelHandle::send`
/// puts the channel id in the first byte.
Data(Vec<u8>),
}
/// Messages delivered from the WebSocket receive loop to the task serving one
/// Minecraft connection.
#[derive(Debug, Clone, PartialEq, Eq)]
enum MinecraftHandlerMessage {
/// The WebSocket client reported the channel closed; `handle_minecraft_send`
/// returns when it sees this.
ChannelClose,
/// Payload of a binary WebSocket message with the leading channel byte
/// removed; written to the Minecraft socket as-is.
Data(Vec<u8>),
}
lazy_static! {
// Maps an assigned domain to the sender feeding that domain's WebSocket send
// loop. Entries are inserted by `RoutingHandle::new` and removed by its `Drop`.
static ref ROUTING_MAP: RwLock<HashMap<String, UnboundedSender<WebsocketHandlerMessage>>> =
RwLock::new(HashMap::new());
// Suffix appended to every generated domain. Read from the `BASE_DOMAIN`
// environment variable on first access; panics there if it is unset.
static ref BASE_DOMAIN: String = env::var("BASE_DOMAIN").expect("BASE_DOMAIN missing");
}
#[tokio::main]
async fn main() -> Result<()> {
let _ = env_logger::try_init();
let ws_bind_addr = env::var("WS_BIND_ADDR").unwrap_or_else(|_| "127.0.0.1:80".to_string());
let mc_bind_addr = env::var("MC_BIND_ADDR").unwrap_or_else(|_| "0.0.0.0:25565".to_string());
futures::try_join!(
async {
let listener = TcpListener::bind(&ws_bind_addr).await?;
info!("WebSocket Listening on: {}", ws_bind_addr);
while let Ok((stream, _)) = listener.accept().await {
task::spawn(async {
if let Err(e) = accept_websocket_connection(stream).await {
error!("Error handling WebSocket connection: {}", e);
}
});
}
Ok(()) as Result<()>
},
async {
let listener = TcpListener::bind(&mc_bind_addr).await?;
info!("Minecraft Listening on: {}", mc_bind_addr);
while let Ok((stream, _)) = listener.accept().await {
task::spawn(async {
if let Err(e) = accept_minecraft_connection(stream).await {
error!("Error handling Minecraft connection: {}", e);
}
});
}
Ok(()) as Result<()>
}
)?;
Ok(())
}
/// Registration of one domain in `ROUTING_MAP`, together with the receiving
/// end of the channel whose sender was stored there.
struct RoutingHandle {
    /// Receives the messages other tasks send through the registered sender.
    receiver: UnboundedReceiver<WebsocketHandlerMessage>,
    /// Key under which the sender was inserted; needed to unregister on drop.
    domain: String,
}

impl RoutingHandle {
    /// Creates a fresh unbounded channel and stores its sender in
    /// `ROUTING_MAP` under `domain`.
    async fn new(domain: String) -> Self {
        let (route_tx, route_rx) = tokio::sync::mpsc::unbounded_channel();
        let mut routes = ROUTING_MAP.write().await;
        routes.insert(domain.clone(), route_tx);
        // Release the write lock before handing the handle back.
        drop(routes);
        Self {
            receiver: route_rx,
            domain,
        }
    }

    /// Waits for the next message; `None` once the channel is closed.
    async fn recv(&mut self) -> Option<WebsocketHandlerMessage> {
        let next = self.receiver.recv().await;
        next
    }
}
impl Drop for RoutingHandle {
    /// Unregisters the domain from `ROUTING_MAP`.
    ///
    /// `drop` cannot await the write lock, so the removal runs on a spawned
    /// task.
    fn drop(&mut self) {
        // The handle is going away, so its domain string can be moved out
        // instead of copied.
        let key = std::mem::take(&mut self.domain);
        tokio::spawn(async move {
            let mut routes = ROUTING_MAP.write().await;
            routes.remove(&key);
        });
    }
}
/// Serves one WebSocket client from handshake to close.
///
/// Performs the WebSocket handshake, picks a random domain, announces it to
/// the client, registers it in the routing table, then runs the send and
/// receive loops until either one finishes. Loop errors are logged, not
/// returned; the socket is closed on a best-effort basis afterwards.
async fn accept_websocket_connection(stream: TcpStream) -> Result<()> {
    let peer = stream.peer_addr()?;
    info!("WebSocket Peer address: {}", peer);
    let mut ws_stream = async_tungstenite::tokio::accept_async(stream).await?;
    info!("New WebSocket connection: {}", peer);

    let domain = get_random_domain().await;
    let announcement = ClientboundControlMessage::DomainAssigned(domain.clone());
    ws_stream.send(announcement.into()).await?;
    let routing = RoutingHandle::new(domain).await;

    // Channel id -> sender towards the Minecraft task using that channel.
    let channels = Mutex::new(HashMap::new());
    let (mut sink, mut source) = ws_stream.split();

    // Whichever loop ends first cancels the other.
    tokio::select! {
        outcome = handle_websocket_send(&mut sink, &channels, routing) => {
            if let Err(err) = outcome {
                error!("Error on WebSocket send loop: {}", err);
            }
        },
        outcome = handle_websocket_recv(&mut source, &channels) => {
            if let Err(err) = outcome {
                error!("Error on WebSocket recv loop: {}", err);
            }
        }
    }

    // A failed close is ignored on purpose: the connection is finished anyway.
    let whole = sink.reunite(source)?;
    let _ = { whole }.close(None).await;
    Ok(())
}
/// Drains the routing handle and turns each message into WebSocket traffic.
///
/// * `ChannelOpen` - picks a free channel id, reports it through the
///   callback, and (if an id was available and the requester is still
///   listening) records the backchannel and notifies the client.
/// * `ChannelClose` - notifies the client, then forgets the channel.
/// * `Data` - forwarded verbatim as a binary message.
///
/// Returns `Ok(())` once the handle yields no more messages, or the first
/// WebSocket send error.
async fn handle_websocket_send(
    write: &mut SplitSink<WebSocketStream<TokioAdapter<TcpStream>>, Message>,
    channel_table: &Mutex<HashMap<u8, UnboundedSender<MinecraftHandlerMessage>>>,
    mut handle: RoutingHandle,
) -> Result<()> {
    while let Some(event) = handle.recv().await {
        match event {
            WebsocketHandlerMessage::Data(payload) => {
                write.send(Message::Binary(payload)).await?;
            }
            WebsocketHandlerMessage::ChannelClose(id) => {
                let notice = ClientboundControlMessage::ChannelClosed(id);
                write.send(notice.into()).await?;
                channel_table.lock().await.remove(&id);
            }
            WebsocketHandlerMessage::ChannelOpen {
                id_callback,
                backchannel,
                addr,
            } => {
                // The table stays locked for the whole arm so the id chosen
                // here cannot be handed out twice.
                let mut open_channels = channel_table.lock().await;
                let ids_in_use: Vec<u8> = open_channels.keys().copied().collect();
                let assigned = get_available_channel(ids_in_use);
                if id_callback.send(assigned).is_ok() {
                    if let Some(id) = assigned {
                        open_channels.insert(id, backchannel);
                        let notice = ClientboundControlMessage::ChannelOpen(id, addr);
                        write.send(notice.into()).await?;
                    }
                } else {
                    warn!("ID callback died before we could send the assigned ID");
                }
            }
        }
    }
    Ok(())
}
async fn handle_websocket_recv(
read: &mut SplitStream<WebSocketStream<TokioAdapter<TcpStream>>>,
channel_table: &Mutex<HashMap<u8, UnboundedSender<MinecraftHandlerMessage>>>,
) -> Result<()> {
while let Some(message) = read.next().await {
if let Err(async_tungstenite::tungstenite::Error::ConnectionClosed) = message {
info!("Connection closed normally");
return Ok(());
}
let message = message?;
match message {
Message::Text(message) => {
let message: ServerboundControlMessage = serde_json::from_str(&message)?;
match message {
ServerboundControlMessage::ChannelClosed(channel) => {
if let Some(sender) = channel_table.lock().await.remove(&channel) {
sender.send(MinecraftHandlerMessage::ChannelClose)?;
}
}
}
}
Message::Binary(buf) => {
let channel = buf[0];
if let Some(sender) = channel_table.lock().await.get(&channel) {
sender.send(MinecraftHandlerMessage::Data(buf[1..].to_vec()))?;
}
}
_ => {}
}
}
Ok(())
}
/// One Minecraft connection's channel on a WebSocket tunnel.
///
/// Dropping the handle sends `ChannelClose` for its channel id.
struct ChannelHandle {
    /// Messages coming back from the WebSocket side for this channel.
    receiver: UnboundedReceiver<MinecraftHandlerMessage>,
    /// Sender towards the WebSocket connection's send loop.
    sender: UnboundedSender<WebsocketHandlerMessage>,
    /// Channel id assigned by the WebSocket send loop.
    channel: u8,
}

impl ChannelHandle {
    /// Asks the WebSocket send loop behind `sender` for a channel and waits
    /// for the assigned id.
    ///
    /// # Errors
    /// Fails if the request cannot be sent, if the id callback is dropped
    /// without an answer, or if the answer is `None` (no id available).
    async fn new(
        sender: &UnboundedSender<WebsocketHandlerMessage>,
        addr: SocketAddr,
    ) -> Result<Self> {
        let (backchannel, receiver) = tokio::sync::mpsc::unbounded_channel();
        let (id_callback, id_reply) = tokio::sync::oneshot::channel();
        let request = WebsocketHandlerMessage::ChannelOpen {
            id_callback,
            backchannel,
            addr,
        };
        sender.send(request)?;

        match id_reply.await? {
            Some(channel) => Ok(Self {
                receiver,
                sender: sender.clone(),
                channel,
            }),
            None => Err(anyhow!("Websocket handler couldn't give channel")),
        }
    }

    /// Prefixes `buf` with this channel's id and queues it for the WebSocket.
    fn send(&self, mut buf: Vec<u8>) -> Result<()> {
        let id = self.channel;
        buf.insert(0, id);
        Ok(self.sender.send(WebsocketHandlerMessage::Data(buf))?)
    }

    /// Waits for the next message for this channel; `None` once the sending
    /// side is gone.
    async fn recv(&mut self) -> Option<MinecraftHandlerMessage> {
        let next = self.receiver.recv().await;
        next
    }

    /// Borrows the handle as independent send and receive halves so both can
    /// be driven concurrently.
    fn split(&mut self) -> (ChannelHandleSend, ChannelHandleRecv) {
        let send_half = ChannelHandleSend {
            sender: &mut self.sender,
            channel: self.channel,
        };
        let recv_half = ChannelHandleRecv {
            receiver: &mut self.receiver,
        };
        (send_half, recv_half)
    }
}
/// Borrowed sending half of a `ChannelHandle`, produced by `ChannelHandle::split`.
struct ChannelHandleSend<'a> {
    /// Sender towards the WebSocket connection's send loop.
    sender: &'a mut UnboundedSender<WebsocketHandlerMessage>,
    /// Channel id placed in front of every payload.
    channel: u8,
}

impl ChannelHandleSend<'_> {
    /// Prefixes `buf` with the channel id and queues it for the WebSocket.
    fn send(&self, mut buf: Vec<u8>) -> Result<()> {
        let id = self.channel;
        buf.insert(0, id);
        Ok(self.sender.send(WebsocketHandlerMessage::Data(buf))?)
    }
}
/// Borrowed receiving half of a `ChannelHandle`, produced by `ChannelHandle::split`.
struct ChannelHandleRecv<'a> {
    /// Messages coming back from the WebSocket side for this channel.
    receiver: &'a mut UnboundedReceiver<MinecraftHandlerMessage>,
}

impl ChannelHandleRecv<'_> {
    /// Waits for the next message; `None` once the sending side is gone.
    async fn recv(&mut self) -> Option<MinecraftHandlerMessage> {
        let next = self.receiver.recv().await;
        next
    }
}
impl Drop for ChannelHandle {
fn drop(&mut self) {
self.sender
.send(WebsocketHandlerMessage::ChannelClose(self.channel))
.unwrap();
}
}
/// Serves one inbound Minecraft connection.
///
/// Reads the first packet and then:
/// * answers a legacy server list ping with a canned response;
/// * if the requested host name is a registered domain, opens a channel on
///   that domain's WebSocket tunnel, replays the handshake into it and
///   relays bytes both ways until either direction ends;
/// * otherwise answers with a canned status or disconnect response,
///   depending on the handshake's next state.
///
/// # Errors
/// I/O and parse errors from the handshake, failure to obtain a channel, and
/// write errors while answering. Errors from the two relay loops are logged
/// instead of returned.
async fn accept_minecraft_connection(mut stream: TcpStream) -> Result<()> {
    let addr = stream.peer_addr()?;
    info!("Minecraft Peer address: {}", addr);
    let packet = stream.read_packet().await;
    if let Err(NettyReadError::LegacyServerListPing) = packet {
        stream
            .write_all(include_bytes!("legacy_serverlistping_response.bin"))
            .await?;
        return Ok(());
    }
    let handshake = Handshake::new(&packet?).await?;
    info!("Minecraft client {} is connecting to {}", addr, handshake.server_address);

    // Clone the sender out in its own statement so the read guard is dropped
    // right here. Looking it up in the `if let` scrutinee kept the guard
    // alive for the whole proxied connection, which blocked every writer of
    // `ROUTING_MAP` (new registrations and removals) while a player was
    // connected.
    let route = ROUTING_MAP
        .read()
        .await
        .get(&handshake.server_address)
        .cloned();

    if let Some(sender) = route {
        let mut handle = ChannelHandle::new(&sender, addr).await?;
        // The handshake was consumed above, so re-encode it for the far end.
        let mut buf = vec![];
        handshake.send(&mut buf).await?;
        handle.send(buf)?;
        let (send, recv) = handle.split();
        let (read, write) = stream.split();
        // Whichever direction ends first cancels the other.
        tokio::select! {
            res = handle_minecraft_send(write, recv) => {
                if let Err(e) = res {
                    error!("Error on Minecraft send loop: {}", e);
                }
            },
            res = handle_minecraft_recv(read, send) => {
                if let Err(e) = res {
                    error!("Error on Minecraft recv loop: {}", e);
                }
            }
        }
        stream.shutdown().await?;
    } else {
        match handshake.next_state {
            HandshakeType::Status => {
                let mut buf = vec![];
                buf.write_varint(0).await?;
                buf.write_string(include_str!("./serverlistping_response.json"))
                    .await?;
                stream.write_varint(buf.len() as i32).await?;
                stream.write_all(&buf).await?;
            }
            HandshakeType::Login => {
                // Read (and discard) the client's next packet before replying.
                let _ = stream.read_packet().await?;
                let mut buf = vec![];
                buf.write_varint(0).await?;
                buf.write_string(include_str!("./disconnect_response.json"))
                    .await?;
                stream.write_varint(buf.len() as i32).await?;
                stream.write_all(&buf).await?;
            }
        }
    }
    Ok(())
}
/// Copies data arriving on the channel to the Minecraft socket.
///
/// Stops with `Ok(())` on the first `ChannelClose` or when the channel is
/// exhausted; a socket write error is returned.
async fn handle_minecraft_send(
    mut write: WriteHalf<'_>,
    mut recv: ChannelHandleRecv<'_>,
) -> Result<()> {
    // The pattern fails to match on both `None` and `ChannelClose`, and
    // either one ends the loop.
    while let Some(MinecraftHandlerMessage::Data(bytes)) = recv.recv().await {
        write.write_all(&bytes).await?;
    }
    Ok(())
}
/// Copies bytes read from the Minecraft socket into the channel, in chunks
/// of up to 1024 bytes.
///
/// Returns `Ok(())` at end of stream.
///
/// A zero-length read is the end-of-stream signal. The earlier version
/// instead polled readiness and returned as soon as the read side was
/// reported closed, which can happen while bytes are still buffered and so
/// could drop a client's final data. It also retried on zero-length reads,
/// which can spin if the closed state is never reported.
///
/// # Errors
/// Socket read errors, or failure to queue data because the WebSocket side
/// is gone.
async fn handle_minecraft_recv(mut read: ReadHalf<'_>, send: ChannelHandleSend<'_>) -> Result<()> {
    let mut buf = [0u8; 1024];
    loop {
        let len = read.read(&mut buf).await?;
        if len == 0 {
            // EOF: the client closed its sending side and everything it
            // sent has been read.
            return Ok(());
        }
        let packet = buf[..len].to_vec();
        trace!("send: {:?}", packet);
        send.send(packet)?;
    }
}
/// Picks a `<word>-<word>.<BASE_DOMAIN>` name that is not currently a key in
/// `ROUTING_MAP`, retrying on collisions.
///
/// The name is only checked, not reserved: the read lock is released when
/// this function returns.
async fn get_random_domain() -> String {
    // One place that knows how a candidate name is assembled.
    let pick_candidate = || {
        let first = wordlist::ID_WORDS.choose(&mut rand::thread_rng()).unwrap();
        let second = wordlist::ID_WORDS.choose(&mut rand::thread_rng()).unwrap();
        format!("{}-{}.{}", first, second, BASE_DOMAIN.as_str())
    };

    let mut candidate = pick_candidate();
    let routes = ROUTING_MAP.read().await;
    loop {
        if !routes.contains_key(&candidate) {
            return candidate;
        }
        warn!("Randomly selected domain {} conflicts; trying again", candidate);
        candidate = pick_candidate();
    }
}
/// Returns the lowest channel id in `0..=255` that does not occur in `used`,
/// or `None` when all 256 ids are taken.
fn get_available_channel(used: Vec<u8>) -> Option<u8> {
    // Mark every id that is in use, then scan upwards for the first gap.
    let mut taken = [false; 256];
    for id in used {
        taken[usize::from(id)] = true;
    }
    (0u8..=255).find(|&id| !taken[usize::from(id)])
}