Initial commit

This commit is contained in:
Michael Pfaff 2023-06-06 00:32:07 -04:00
commit 96b1a94a4a
Signed by: michael
GPG Key ID: CF402C4A012AA9D4
5 changed files with 1597 additions and 0 deletions

1
.gitignore vendored Normal file
View File

@ -0,0 +1 @@
/target

1026
Cargo.lock generated Normal file

File diff suppressed because it is too large Load Diff

20
Cargo.toml Normal file
View File

@ -0,0 +1,20 @@
[package]
name = "quic-shell"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
anyhow = "1.0.71"
libc = "0.2.145"
nix = "0.26.2"
quinn = "0.10.1"
rcgen = "0.10.0"
rmp-serde = "1.1.1"
rustls = { version = "0.21.1", default-features = false }
serde = { version = "1.0.163", features = ["derive"] }
#termion = "2.0.1"
tokio = { version = "1.28.2", default-features = false, features = ["rt-multi-thread", "macros", "process", "io-util", "io-std", "time", "fs", "signal"] }
tracing = "0.1.37"
tracing-subscriber = { version = "0.3.17", features = ["env-filter"] }

429
src/main.rs Normal file
View File

@ -0,0 +1,429 @@
#[macro_use]
extern crate anyhow;
#[macro_use]
extern crate serde;
#[macro_use]
extern crate tracing;
mod pty;
use std::ffi::CStr;
use std::os::fd::FromRawFd;
use std::os::unix::process::CommandExt;
use std::process::Stdio;
use std::sync::Arc;
use anyhow::{Context, Result};
use quinn::{RecvStream, SendStream};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::process::{Child, Command};
use tracing::Instrument;
#[derive(Debug, Clone, Serialize, Deserialize)]
enum Stream {
Shell,
Heartbeat,
}
#[tokio::main]
async fn main() -> Result<()> {
tracing::subscriber::set_global_default(
tracing_subscriber::FmtSubscriber::builder()
.with_writer(std::io::stderr)
.with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
.finish(),
)
.unwrap();
let mut args = std::env::args();
_ = args.next();
let ctrl_c = tokio::signal::ctrl_c();
let fut = run_cmd(args);
tokio::select! {
_ = ctrl_c => Ok(()),
r = fut => r,
}
}
async fn run_cmd(mut args: std::env::Args) -> Result<()> {
let cmd = args.next().expect("COMMAND");
match cmd.as_str() {
"server" => run_server().await,
"client" => run_client().await,
_ => Err(anyhow!("Unrecognized command: {}", cmd)),
}
}
const ALPN_QUIC_SHELL: &str = "quic-shell";
async fn run_server() -> Result<()> {
let opt_shell = &*Box::leak(
std::env::var("SHELL")
.context("SHELL not defined")?
.into_boxed_str(),
);
let opt_listen = std::env::var("BIND_ADDR")
.unwrap_or_else(|_| "127.0.0.1:8022".to_owned())
.parse()?;
let subject_alt_names = vec!["localhost".to_string()];
let cert = rcgen::generate_simple_self_signed(subject_alt_names)?;
let key = rustls::PrivateKey(cert.serialize_private_key_der());
let cert = rustls::Certificate(cert.serialize_der()?);
//std::fs::write("key.der", &key.0)?;
std::fs::write("cert.der", &cert.0)?;
let mut server_crypto = rustls::ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
.with_single_cert(vec![cert], key)
.unwrap();
server_crypto.alpn_protocols = vec![ALPN_QUIC_SHELL.as_bytes().to_owned()];
let mut server_config = quinn::ServerConfig::with_crypto(Arc::new(server_crypto));
let transport_config = Arc::get_mut(&mut server_config.transport).unwrap();
transport_config.max_concurrent_uni_streams(0_u8.into());
server_config.use_retry(true);
let endpoint = quinn::Endpoint::server(server_config, opt_listen)?;
eprintln!("listening on {}", endpoint.local_addr()?);
while let Some(conn) = endpoint.accept().await {
info!("connection incoming");
tokio::spawn(async move {
if let Err(e) = handle_connection(opt_shell, conn).await {
error!("connection failed: {reason}", reason = e.to_string())
}
});
}
Ok(())
}
fn is_broken_pipe<T>(r: &Result<T, std::io::Error>) -> bool {
if let Err(e) = r {
e.kind() == std::io::ErrorKind::BrokenPipe
} else {
false
}
}
async fn run_client() -> Result<()> {
info!("running client");
let mut roots = rustls::RootCertStore::empty();
match std::fs::read("cert.der") {
Ok(cert) => {
roots.add(&rustls::Certificate(cert))?;
}
Err(ref e) if e.kind() == std::io::ErrorKind::NotFound => {
info!("local server certificate not found");
}
Err(e) => {
error!("failed to open local server certificate: {}", e);
}
}
info!("read roots");
let mut client_crypto = rustls::ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(roots)
.with_no_client_auth();
client_crypto.alpn_protocols = vec![ALPN_QUIC_SHELL.as_bytes().to_owned()];
let mut transport = quinn::TransportConfig::default();
transport.keep_alive_interval(Some(std::time::Duration::from_secs(5)));
let mut client_config = quinn::ClientConfig::new(Arc::new(client_crypto));
client_config.transport_config(transport.into());
//let mut endpoint = quinn::Endpoint::client("[::]:0".parse().unwrap())?;
let mut endpoint = quinn::Endpoint::client("0.0.0.0:0".parse().unwrap())?;
endpoint.set_default_client_config(client_config);
info!("connecting");
let conn = endpoint
.connect("127.0.0.1:8022".parse()?, "localhost")?
.await?;
let (mut send, mut recv) = conn.open_bi().await?;
write_header(&mut send, Stream::Shell).await?;
info!("connected");
let mut stdin = unsafe { tokio::fs::File::from_raw_fd(libc::STDIN_FILENO) };
let mut stdout = unsafe { tokio::fs::File::from_raw_fd(libc::STDOUT_FILENO) };
let mut stdin_buf = Vec::with_capacity(4096);
let mut stdout_buf = Vec::with_capacity(4096);
let mut stdin_eof = false;
{
use nix::sys::termios::*;
let mut termios = tcgetattr(libc::STDIN_FILENO)?;
termios.local_flags.remove(LocalFlags::ECHO);
termios.local_flags.remove(LocalFlags::ICANON);
termios.local_flags.remove(LocalFlags::ISIG);
termios.local_flags.remove(LocalFlags::IEXTEN);
termios.input_flags.remove(InputFlags::IXON);
termios.input_flags.remove(InputFlags::ICRNL);
termios.output_flags.remove(OutputFlags::OPOST);
tcsetattr(libc::STDIN_FILENO, SetArg::TCSAFLUSH, &termios)?;
}
/*let mut heartbeat = {
let (mut send, recv) = conn.open_bi().await?;
write_header(&mut send, Stream::Heartbeat).await?;
Box::pin(handle_stream_heartbeat(send, recv))
};*/
loop {
tokio::select! {
//_ = &mut heartbeat => {}
r = stdin.read_buf(&mut stdin_buf), if !stdin_eof => {
if r? == 0 {
stdin_eof = true;
}
send.write_all(&stdin_buf).await?;
stdin_buf.clear();
//info!("sent stdin");
},
r = recv.read_buf(&mut stdout_buf) => if r? > 0 {
stdout.write_all(&stdout_buf).await?;
stdout_buf.clear();
//info!("recv stdout");
},
r = send.stopped() => {
info!("Remote disconnected");
let code = r?.into_inner();
if code == 0 {
return Ok(());
} else {
return Err(anyhow!("Error code {}", code));
}
}
e = conn.closed() => {
info!("Remote disconnected: {}", e);
return Ok(());
}
}
}
}
async fn write_header(send: &mut SendStream, header: Stream) -> Result<()> {
let buf = rmp_serde::to_vec(&header)?;
send.write_all(&u16::try_from(buf.len())?.to_le_bytes())
.await?;
send.write_all(&buf).await?;
Ok(())
}
async fn read_header(recv: &mut RecvStream) -> Result<Stream> {
let mut size = [0u8; 2];
recv.read_exact(&mut size).await?;
let size = u16::from_le_bytes(size);
let mut buf = Vec::with_capacity(size.into());
recv.take(size.into()).read_to_end(&mut buf).await?;
Ok(rmp_serde::from_slice(&buf)?)
}
async fn handle_connection(opt_shell: &'static str, conn: quinn::Connecting) -> Result<()> {
let conn = conn.await?;
let span = info_span!(
"connection",
remote = %conn.remote_address(),
protocol = %conn
.handshake_data()
.unwrap()
.downcast::<quinn::crypto::rustls::HandshakeData>().unwrap()
.protocol
.map_or_else(|| "<none>".into(), |x| String::from_utf8_lossy(&x).into_owned())
);
async {
info!("established");
loop {
let stream = conn.accept_bi().await;
let (send, mut recv) = match stream {
Err(quinn::ConnectionError::ApplicationClosed { .. }) => {
info!("connection closed");
return Ok(());
}
Err(e) => {
return Err(e.into());
}
Ok(s) => s,
};
let stream = read_header(&mut recv).await?;
let span = info_span!(
"stream",
r#type = ?stream
);
tokio::task::spawn(
async move {
let r = match stream {
Stream::Shell => handle_stream_shell(opt_shell, send, recv).await,
Stream::Heartbeat => handle_stream_heartbeat(send, recv).await,
};
if let Err(e) = r {
error!("Error in stream handler: {e}");
}
}
.instrument(span),
);
}
}
.instrument(span)
.await
}
async fn handle_stream_shell(
opt_shell: &str,
mut send: SendStream,
mut recv: RecvStream,
) -> Result<()> {
let use_pty = true;
if use_pty {
let args = if opt_shell == "bash" || opt_shell.ends_with("/bash") {
vec![CStr::from_bytes_with_nul(b"-i\0")?]
} else {
vec![]
};
let mut opt_shell_with_nul = Vec::with_capacity(opt_shell.len() + 1);
opt_shell_with_nul.extend(opt_shell.as_bytes());
opt_shell_with_nul.push(0);
let opt_shell = CStr::from_bytes_with_nul(&opt_shell_with_nul)?;
let mut sh = pty::create_pty(opt_shell, &args)?;
info!("Created pty");
let mut stdin_buf = Vec::with_capacity(4096);
let mut pty_buf = Vec::with_capacity(4096);
//let mut pty_buf = [0u8];
loop {
if let Some(code) = sh.try_wait()? {
send.finish().await?;
if code != 0 {
info!("Child exit: {}", code);
recv.stop(1u8.into())?;
return Ok(());
} else {
info!("Child exit");
recv.stop(0u8.into())?;
return Ok(());
}
}
//let mut redraw = tokio::time::interval(std::time::Duration::from_millis(50));
tokio::select! {
/*_ = redraw.tick() => {
sh.pty.read_buf(&mut pty_buf).await?;
send.write_all(&pty_buf).await?;
pty_buf.clear();
info!("redraw complete");
}*/
r = sh.pty.read_buf(&mut pty_buf) => {
//r = sh.pty.read_exact(&mut pty_buf) => {
if let Err(e) = r {
if e.raw_os_error() != Some(35) {
return Err(e.into());
}
}
if pty_buf.len() > 0 {
send.write_all(&pty_buf).await?;
pty_buf.clear();
info!("sent pty");
}
}
r = recv.read_buf(&mut stdin_buf) => if r? > 0 {
sh.pty.write_all(&stdin_buf).await?;
stdin_buf.clear();
info!("recv stdin");
},
}
}
} else {
let mut cmd = std::process::Command::new(opt_shell);
if opt_shell == "bash" || opt_shell.ends_with("/bash") {
cmd.arg("-i");
}
cmd.stdout(Stdio::piped())
.stderr(Stdio::piped())
.stdin(Stdio::piped());
#[cfg(target_family = "unix")]
cmd.process_group(0);
info!("Running {:?}", cmd);
let mut sh = Command::from(cmd).kill_on_drop(true).spawn()?;
let mut stdout = sh.stdout.take().unwrap();
let mut stderr = sh.stderr.take().unwrap();
let mut stdin = sh.stdin.take().unwrap();
let mut stdout_buf = Vec::with_capacity(4096);
let mut stderr_buf = Vec::with_capacity(4096);
let mut stdin_buf = Vec::with_capacity(4096);
let mut stdout_eof = false;
let mut stderr_eof = false;
loop {
tokio::select! {
r = sh.wait() => {
let code = r?;
send.finish().await?;
if !code.success() {
info!("Child exit: {}", code);
recv.stop(1u8.into())?;
return Ok(());
} else {
info!("Child exit");
recv.stop(0u8.into())?;
return Ok(());
}
}
r = stdout.read_buf(&mut stdout_buf), if !stdout_eof => {
if is_broken_pipe(&r) || r? == 0 {
stdout_eof = true;
info!("stdout eof");
} else {
send.write_all(&stdout_buf).await?;
info!("sent stdout: {:x?}", stdout_buf);
stdout_buf.clear();
}
},
r = stderr.read_buf(&mut stderr_buf), if !stderr_eof => {
if is_broken_pipe(&r) || r? == 0 {
stderr_eof = true;
info!("stderr eof");
} else {
send.write_all(&stderr_buf).await?;
stderr_buf.clear();
info!("sent stderr: {:x?}", stderr_buf);
}
},
r = recv.read_buf(&mut stdin_buf) => if r? > 0 {
stdin.write_all(&stdin_buf).await?;
stdin_buf.clear();
info!("recv stdin");
},
}
}
}
}
async fn handle_stream_heartbeat(mut send: SendStream, _recv: RecvStream) -> Result<()> {
loop {
tokio::time::sleep(std::time::Duration::from_secs(5)).await;
send.write_all(&[0u8]).await?;
}
}

121
src/pty.rs Normal file
View File

@ -0,0 +1,121 @@
use nix::fcntl::{fcntl, open, FcntlArg, OFlag};
use nix::pty::{grantpt, posix_openpt, ptsname, unlockpt, Winsize};
use nix::sys::stat::Mode;
use nix::unistd::{ForkResult, Pid};
use nix::{ioctl_none_bad, ioctl_write_ptr_bad};
use libc::{STDERR_FILENO, STDIN_FILENO, STDOUT_FILENO, TIOCSCTTY};
// ioctl request code to set window size of pty:
use libc::TIOCSWINSZ;
use std::ffi::CStr;
use std::fs::File;
use std::os::unix::process::CommandExt;
use std::path::Path;
use std::process::{Command, Stdio};
use std::os::unix::io::{FromRawFd, IntoRawFd};
// nix macro that generates an ioctl call to set window size of pty:
ioctl_write_ptr_bad!(set_window_size, TIOCSWINSZ, Winsize);
// request to "Make the given terminal the controlling terminal of the calling process"
ioctl_none_bad!(set_controlling_terminal, TIOCSCTTY);
pub struct Child {
pub pty: tokio::fs::File,
pub pid: Pid,
}
impl Child {
// copied from
// https://doc.rust-lang.org/nightly/src/std/sys/unix/process/process_unix.rs.html#744-757
pub fn try_wait(&self) -> std::io::Result<Option<i32>> {
let mut status = 0;
let pid = cvt(unsafe { libc::waitpid(self.pid.as_raw(), &mut status, libc::WNOHANG) })?;
if pid == 0 {
Ok(None)
} else {
Ok(Some(status))
}
}
}
impl Drop for Child {
fn drop(&mut self) {
_ = nix::sys::signal::kill(self.pid, nix::sys::signal::Signal::SIGTERM);
}
}
pub fn create_pty<S: AsRef<CStr> + std::fmt::Debug>(path: &CStr, argv: &[S]) -> nix::Result<Child> {
/* Create a new master */
let master_fd = posix_openpt(OFlag::O_RDWR)?;
/* For some reason, you have to give permission to the master to have a
* pty. What is it good for otherwise? */
grantpt(&master_fd)?;
unlockpt(&master_fd)?;
/* Get the path of the slave */
let slave_name = unsafe { ptsname(&master_fd) }?;
/* Try to open the slave */
let _slave_fd = open(Path::new(&slave_name), OFlag::O_RDWR, Mode::empty())?;
info!("master opened the slave_fd!");
/* Launch our child process. The main application loop can inspect and then
pass the stdin data to it. */
let child_pid = match unsafe { nix::unistd::fork() } {
Ok(ForkResult::Child) => {
init_child(path, argv, slave_name).unwrap();
unreachable!()
}
Ok(ForkResult::Parent { child }) => child,
Err(e) => panic!("{}", e),
};
let winsize = Winsize {
ws_row: 25,
ws_col: 80,
ws_xpixel: 0,
ws_ypixel: 0,
};
let master_fd = master_fd.into_raw_fd();
/* Tell the master the size of the terminal */
unsafe { set_window_size(master_fd, &winsize)? };
fcntl(master_fd, FcntlArg::F_SETFL(OFlag::O_NDELAY)).unwrap();
let master_file = unsafe { File::from_raw_fd(master_fd) };
Ok(Child {
pty: master_file.into(),
pid: child_pid,
})
}
fn init_child<S: AsRef<CStr> + std::fmt::Debug>(
path: &CStr,
argv: &[S],
slave_name: String,
) -> anyhow::Result<std::convert::Infallible> {
/* Open slave end for pseudoterminal */
let slave_fd = open(Path::new(&slave_name), OFlag::O_RDWR, Mode::empty())?;
info!("child opened the slave_fd!");
// assign stdin, stdout, stderr to the tty
nix::unistd::dup2(slave_fd, STDIN_FILENO)?;
nix::unistd::dup2(slave_fd, STDOUT_FILENO)?;
nix::unistd::dup2(slave_fd, STDERR_FILENO)?;
nix::unistd::setsid().unwrap();
unsafe { set_controlling_terminal(slave_fd) }.unwrap();
info!("running exec: {:?} {:?}", path, argv);
Ok(nix::unistd::execv(path, argv)?)
}
// copied from... somewhere in std
pub fn cvt(t: i32) -> std::io::Result<i32> {
if t == -1 {
Err(std::io::Error::last_os_error())
} else {
Ok(t)
}
}