430 lines
14 KiB
Rust
430 lines
14 KiB
Rust
#[macro_use]
|
|
extern crate anyhow;
|
|
#[macro_use]
|
|
extern crate serde;
|
|
#[macro_use]
|
|
extern crate tracing;
|
|
|
|
mod pty;
|
|
|
|
use std::ffi::CStr;
|
|
use std::os::fd::FromRawFd;
|
|
use std::os::unix::process::CommandExt;
|
|
use std::process::Stdio;
|
|
use std::sync::Arc;
|
|
|
|
use anyhow::{Context, Result};
|
|
use quinn::{RecvStream, SendStream};
|
|
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
|
use tokio::process::{Child, Command};
|
|
use tracing::Instrument;
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
enum Stream {
|
|
Shell,
|
|
Heartbeat,
|
|
}
|
|
|
|
#[tokio::main]
|
|
async fn main() -> Result<()> {
|
|
tracing::subscriber::set_global_default(
|
|
tracing_subscriber::FmtSubscriber::builder()
|
|
.with_writer(std::io::stderr)
|
|
.with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
|
|
.finish(),
|
|
)
|
|
.unwrap();
|
|
|
|
let mut args = std::env::args();
|
|
_ = args.next();
|
|
|
|
let ctrl_c = tokio::signal::ctrl_c();
|
|
let fut = run_cmd(args);
|
|
|
|
tokio::select! {
|
|
_ = ctrl_c => Ok(()),
|
|
r = fut => r,
|
|
}
|
|
}
|
|
|
|
async fn run_cmd(mut args: std::env::Args) -> Result<()> {
|
|
let cmd = args.next().expect("COMMAND");
|
|
match cmd.as_str() {
|
|
"server" => run_server().await,
|
|
"client" => run_client().await,
|
|
_ => Err(anyhow!("Unrecognized command: {}", cmd)),
|
|
}
|
|
}
|
|
|
|
/// ALPN protocol identifier set on both the server (`run_server`) and the client
/// (`run_client`) rustls configurations.
const ALPN_QUIC_SHELL: &str = "quic-shell";
|
|
|
|
async fn run_server() -> Result<()> {
|
|
let opt_shell = &*Box::leak(
|
|
std::env::var("SHELL")
|
|
.context("SHELL not defined")?
|
|
.into_boxed_str(),
|
|
);
|
|
let opt_listen = std::env::var("BIND_ADDR")
|
|
.unwrap_or_else(|_| "127.0.0.1:8022".to_owned())
|
|
.parse()?;
|
|
|
|
let subject_alt_names = vec!["localhost".to_string()];
|
|
|
|
let cert = rcgen::generate_simple_self_signed(subject_alt_names)?;
|
|
let key = rustls::PrivateKey(cert.serialize_private_key_der());
|
|
let cert = rustls::Certificate(cert.serialize_der()?);
|
|
|
|
//std::fs::write("key.der", &key.0)?;
|
|
std::fs::write("cert.der", &cert.0)?;
|
|
|
|
let mut server_crypto = rustls::ServerConfig::builder()
|
|
.with_safe_defaults()
|
|
.with_no_client_auth()
|
|
.with_single_cert(vec![cert], key)
|
|
.unwrap();
|
|
server_crypto.alpn_protocols = vec![ALPN_QUIC_SHELL.as_bytes().to_owned()];
|
|
|
|
let mut server_config = quinn::ServerConfig::with_crypto(Arc::new(server_crypto));
|
|
let transport_config = Arc::get_mut(&mut server_config.transport).unwrap();
|
|
transport_config.max_concurrent_uni_streams(0_u8.into());
|
|
server_config.use_retry(true);
|
|
|
|
let endpoint = quinn::Endpoint::server(server_config, opt_listen)?;
|
|
eprintln!("listening on {}", endpoint.local_addr()?);
|
|
|
|
while let Some(conn) = endpoint.accept().await {
|
|
info!("connection incoming");
|
|
tokio::spawn(async move {
|
|
if let Err(e) = handle_connection(opt_shell, conn).await {
|
|
error!("connection failed: {reason}", reason = e.to_string())
|
|
}
|
|
});
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
/// Reports whether `r` holds an I/O error of kind `BrokenPipe`.
///
/// `Ok` values and errors of any other kind yield `false`.
fn is_broken_pipe<T>(r: &Result<T, std::io::Error>) -> bool {
    matches!(r, Err(err) if err.kind() == std::io::ErrorKind::BrokenPipe)
}
|
|
|
|
async fn run_client() -> Result<()> {
|
|
info!("running client");
|
|
|
|
let mut roots = rustls::RootCertStore::empty();
|
|
match std::fs::read("cert.der") {
|
|
Ok(cert) => {
|
|
roots.add(&rustls::Certificate(cert))?;
|
|
}
|
|
Err(ref e) if e.kind() == std::io::ErrorKind::NotFound => {
|
|
info!("local server certificate not found");
|
|
}
|
|
Err(e) => {
|
|
error!("failed to open local server certificate: {}", e);
|
|
}
|
|
}
|
|
|
|
info!("read roots");
|
|
|
|
let mut client_crypto = rustls::ClientConfig::builder()
|
|
.with_safe_defaults()
|
|
.with_root_certificates(roots)
|
|
.with_no_client_auth();
|
|
|
|
client_crypto.alpn_protocols = vec![ALPN_QUIC_SHELL.as_bytes().to_owned()];
|
|
|
|
let mut transport = quinn::TransportConfig::default();
|
|
transport.keep_alive_interval(Some(std::time::Duration::from_secs(5)));
|
|
|
|
let mut client_config = quinn::ClientConfig::new(Arc::new(client_crypto));
|
|
client_config.transport_config(transport.into());
|
|
//let mut endpoint = quinn::Endpoint::client("[::]:0".parse().unwrap())?;
|
|
let mut endpoint = quinn::Endpoint::client("0.0.0.0:0".parse().unwrap())?;
|
|
endpoint.set_default_client_config(client_config);
|
|
|
|
info!("connecting");
|
|
|
|
let conn = endpoint
|
|
.connect("127.0.0.1:8022".parse()?, "localhost")?
|
|
.await?;
|
|
let (mut send, mut recv) = conn.open_bi().await?;
|
|
|
|
write_header(&mut send, Stream::Shell).await?;
|
|
|
|
info!("connected");
|
|
|
|
let mut stdin = unsafe { tokio::fs::File::from_raw_fd(libc::STDIN_FILENO) };
|
|
let mut stdout = unsafe { tokio::fs::File::from_raw_fd(libc::STDOUT_FILENO) };
|
|
let mut stdin_buf = Vec::with_capacity(4096);
|
|
let mut stdout_buf = Vec::with_capacity(4096);
|
|
|
|
let mut stdin_eof = false;
|
|
|
|
{
|
|
use nix::sys::termios::*;
|
|
let mut termios = tcgetattr(libc::STDIN_FILENO)?;
|
|
termios.local_flags.remove(LocalFlags::ECHO);
|
|
termios.local_flags.remove(LocalFlags::ICANON);
|
|
termios.local_flags.remove(LocalFlags::ISIG);
|
|
termios.local_flags.remove(LocalFlags::IEXTEN);
|
|
termios.input_flags.remove(InputFlags::IXON);
|
|
termios.input_flags.remove(InputFlags::ICRNL);
|
|
termios.output_flags.remove(OutputFlags::OPOST);
|
|
tcsetattr(libc::STDIN_FILENO, SetArg::TCSAFLUSH, &termios)?;
|
|
}
|
|
|
|
/*let mut heartbeat = {
|
|
let (mut send, recv) = conn.open_bi().await?;
|
|
|
|
write_header(&mut send, Stream::Heartbeat).await?;
|
|
Box::pin(handle_stream_heartbeat(send, recv))
|
|
};*/
|
|
|
|
loop {
|
|
tokio::select! {
|
|
//_ = &mut heartbeat => {}
|
|
r = stdin.read_buf(&mut stdin_buf), if !stdin_eof => {
|
|
if r? == 0 {
|
|
stdin_eof = true;
|
|
}
|
|
send.write_all(&stdin_buf).await?;
|
|
stdin_buf.clear();
|
|
//info!("sent stdin");
|
|
},
|
|
r = recv.read_buf(&mut stdout_buf) => if r? > 0 {
|
|
stdout.write_all(&stdout_buf).await?;
|
|
stdout_buf.clear();
|
|
//info!("recv stdout");
|
|
},
|
|
r = send.stopped() => {
|
|
info!("Remote disconnected");
|
|
let code = r?.into_inner();
|
|
if code == 0 {
|
|
return Ok(());
|
|
} else {
|
|
return Err(anyhow!("Error code {}", code));
|
|
}
|
|
}
|
|
e = conn.closed() => {
|
|
info!("Remote disconnected: {}", e);
|
|
return Ok(());
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
async fn write_header(send: &mut SendStream, header: Stream) -> Result<()> {
|
|
let buf = rmp_serde::to_vec(&header)?;
|
|
send.write_all(&u16::try_from(buf.len())?.to_le_bytes())
|
|
.await?;
|
|
send.write_all(&buf).await?;
|
|
Ok(())
|
|
}
|
|
|
|
async fn read_header(recv: &mut RecvStream) -> Result<Stream> {
|
|
let mut size = [0u8; 2];
|
|
recv.read_exact(&mut size).await?;
|
|
let size = u16::from_le_bytes(size);
|
|
let mut buf = Vec::with_capacity(size.into());
|
|
recv.take(size.into()).read_to_end(&mut buf).await?;
|
|
Ok(rmp_serde::from_slice(&buf)?)
|
|
}
|
|
|
|
async fn handle_connection(opt_shell: &'static str, conn: quinn::Connecting) -> Result<()> {
|
|
let conn = conn.await?;
|
|
let span = info_span!(
|
|
"connection",
|
|
remote = %conn.remote_address(),
|
|
protocol = %conn
|
|
.handshake_data()
|
|
.unwrap()
|
|
.downcast::<quinn::crypto::rustls::HandshakeData>().unwrap()
|
|
.protocol
|
|
.map_or_else(|| "<none>".into(), |x| String::from_utf8_lossy(&x).into_owned())
|
|
);
|
|
async {
|
|
info!("established");
|
|
|
|
loop {
|
|
let stream = conn.accept_bi().await;
|
|
let (send, mut recv) = match stream {
|
|
Err(quinn::ConnectionError::ApplicationClosed { .. }) => {
|
|
info!("connection closed");
|
|
return Ok(());
|
|
}
|
|
Err(e) => {
|
|
return Err(e.into());
|
|
}
|
|
Ok(s) => s,
|
|
};
|
|
|
|
let stream = read_header(&mut recv).await?;
|
|
let span = info_span!(
|
|
"stream",
|
|
r#type = ?stream
|
|
);
|
|
tokio::task::spawn(
|
|
async move {
|
|
let r = match stream {
|
|
Stream::Shell => handle_stream_shell(opt_shell, send, recv).await,
|
|
Stream::Heartbeat => handle_stream_heartbeat(send, recv).await,
|
|
};
|
|
if let Err(e) = r {
|
|
error!("Error in stream handler: {e}");
|
|
}
|
|
}
|
|
.instrument(span),
|
|
);
|
|
}
|
|
}
|
|
.instrument(span)
|
|
.await
|
|
}
|
|
|
|
async fn handle_stream_shell(
|
|
opt_shell: &str,
|
|
mut send: SendStream,
|
|
mut recv: RecvStream,
|
|
) -> Result<()> {
|
|
let use_pty = true;
|
|
if use_pty {
|
|
let args = if opt_shell == "bash" || opt_shell.ends_with("/bash") {
|
|
vec![CStr::from_bytes_with_nul(b"-i\0")?]
|
|
} else {
|
|
vec![]
|
|
};
|
|
let mut opt_shell_with_nul = Vec::with_capacity(opt_shell.len() + 1);
|
|
opt_shell_with_nul.extend(opt_shell.as_bytes());
|
|
opt_shell_with_nul.push(0);
|
|
let opt_shell = CStr::from_bytes_with_nul(&opt_shell_with_nul)?;
|
|
let mut sh = pty::create_pty(opt_shell, &args)?;
|
|
info!("Created pty");
|
|
|
|
let mut stdin_buf = Vec::with_capacity(4096);
|
|
let mut pty_buf = Vec::with_capacity(4096);
|
|
//let mut pty_buf = [0u8];
|
|
|
|
loop {
|
|
if let Some(code) = sh.try_wait()? {
|
|
send.finish().await?;
|
|
if code != 0 {
|
|
info!("Child exit: {}", code);
|
|
|
|
recv.stop(1u8.into())?;
|
|
return Ok(());
|
|
} else {
|
|
info!("Child exit");
|
|
recv.stop(0u8.into())?;
|
|
return Ok(());
|
|
}
|
|
}
|
|
|
|
//let mut redraw = tokio::time::interval(std::time::Duration::from_millis(50));
|
|
|
|
tokio::select! {
|
|
/*_ = redraw.tick() => {
|
|
sh.pty.read_buf(&mut pty_buf).await?;
|
|
send.write_all(&pty_buf).await?;
|
|
pty_buf.clear();
|
|
info!("redraw complete");
|
|
}*/
|
|
r = sh.pty.read_buf(&mut pty_buf) => {
|
|
//r = sh.pty.read_exact(&mut pty_buf) => {
|
|
if let Err(e) = r {
|
|
if e.raw_os_error() != Some(35) {
|
|
return Err(e.into());
|
|
}
|
|
}
|
|
if pty_buf.len() > 0 {
|
|
send.write_all(&pty_buf).await?;
|
|
pty_buf.clear();
|
|
info!("sent pty");
|
|
}
|
|
}
|
|
r = recv.read_buf(&mut stdin_buf) => if r? > 0 {
|
|
sh.pty.write_all(&stdin_buf).await?;
|
|
stdin_buf.clear();
|
|
info!("recv stdin");
|
|
},
|
|
}
|
|
}
|
|
} else {
|
|
let mut cmd = std::process::Command::new(opt_shell);
|
|
if opt_shell == "bash" || opt_shell.ends_with("/bash") {
|
|
cmd.arg("-i");
|
|
}
|
|
cmd.stdout(Stdio::piped())
|
|
.stderr(Stdio::piped())
|
|
.stdin(Stdio::piped());
|
|
#[cfg(target_family = "unix")]
|
|
cmd.process_group(0);
|
|
info!("Running {:?}", cmd);
|
|
let mut sh = Command::from(cmd).kill_on_drop(true).spawn()?;
|
|
|
|
let mut stdout = sh.stdout.take().unwrap();
|
|
let mut stderr = sh.stderr.take().unwrap();
|
|
let mut stdin = sh.stdin.take().unwrap();
|
|
let mut stdout_buf = Vec::with_capacity(4096);
|
|
let mut stderr_buf = Vec::with_capacity(4096);
|
|
let mut stdin_buf = Vec::with_capacity(4096);
|
|
|
|
let mut stdout_eof = false;
|
|
let mut stderr_eof = false;
|
|
|
|
loop {
|
|
tokio::select! {
|
|
r = sh.wait() => {
|
|
let code = r?;
|
|
send.finish().await?;
|
|
if !code.success() {
|
|
info!("Child exit: {}", code);
|
|
|
|
recv.stop(1u8.into())?;
|
|
return Ok(());
|
|
} else {
|
|
info!("Child exit");
|
|
recv.stop(0u8.into())?;
|
|
return Ok(());
|
|
}
|
|
}
|
|
r = stdout.read_buf(&mut stdout_buf), if !stdout_eof => {
|
|
if is_broken_pipe(&r) || r? == 0 {
|
|
stdout_eof = true;
|
|
info!("stdout eof");
|
|
} else {
|
|
send.write_all(&stdout_buf).await?;
|
|
info!("sent stdout: {:x?}", stdout_buf);
|
|
stdout_buf.clear();
|
|
}
|
|
},
|
|
r = stderr.read_buf(&mut stderr_buf), if !stderr_eof => {
|
|
if is_broken_pipe(&r) || r? == 0 {
|
|
stderr_eof = true;
|
|
info!("stderr eof");
|
|
} else {
|
|
send.write_all(&stderr_buf).await?;
|
|
stderr_buf.clear();
|
|
info!("sent stderr: {:x?}", stderr_buf);
|
|
}
|
|
},
|
|
r = recv.read_buf(&mut stdin_buf) => if r? > 0 {
|
|
stdin.write_all(&stdin_buf).await?;
|
|
stdin_buf.clear();
|
|
info!("recv stdin");
|
|
},
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
async fn handle_stream_heartbeat(mut send: SendStream, _recv: RecvStream) -> Result<()> {
|
|
loop {
|
|
tokio::time::sleep(std::time::Duration::from_secs(5)).await;
|
|
send.write_all(&[0u8]).await?;
|
|
}
|
|
}
|