diff --git a/Cargo.lock b/Cargo.lock index 6356fbd..929dee6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1107,6 +1107,7 @@ dependencies = [ "parking_lot", "pin-project-lite", "quinn", + "quinn-proto", "rand 0.8.5", "rcgen", "rmp", @@ -2227,14 +2228,14 @@ dependencies = [ "syn 2.0.18", ] -[[patch.unused]] -name = "chow" -version = "0.2.0" - [[patch.unused]] name = "how" version = "0.3.0" +[[patch.unused]] +name = "chow" +version = "0.2.0" + [[patch.unused]] name = "minify-html-onepass" version = "0.11.1" diff --git a/Cargo.toml b/Cargo.toml index 5b8fb40..9b0e9b6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,6 +24,7 @@ nix = "0.26.2" parking_lot = "0.12.1" pin-project-lite = "0.2.9" quinn = "0.10.1" +quinn-proto = { version = "0.10.1", default-features = false } rand = "0.8.5" rcgen = "0.10.0" rmp = "0.8.11" diff --git a/src/auth/mod.rs b/src/auth/mod.rs index 99d3318..0002f08 100644 --- a/src/auth/mod.rs +++ b/src/auth/mod.rs @@ -3,6 +3,8 @@ pub mod ssh_key; use std::ffi::CStr; +use crate::ser::CStrAsBytes; + #[derive(Debug, Serialize, Deserialize)] pub struct Hello<'a> { #[serde(borrow)] @@ -21,55 +23,6 @@ pub enum Method<'a> { }, } -mod cstr_as_bytes { - use std::ffi::CStr; - - use serde::{Serializer, Deserializer, Deserialize}; - - pub fn serialize(s: &CStr, ser: S) -> Result where S: Serializer { - ser.serialize_bytes(s.to_bytes_with_nul()) - } - - pub fn deserialize<'de, D>(de: D) -> Result<&'de CStr, D::Error> where D: Deserializer<'de> { - use serde::de::Error; - let b = <&'de [u8]>::deserialize(de)?; - CStr::from_bytes_with_nul(b) - .map_err(|_| D::Error::invalid_value(serde::de::Unexpected::Bytes(b), &"a sequence of bytes ending with NUL")) - } -} - -#[derive(Debug, Clone, Copy, Serialize, Deserialize)] -#[serde(transparent)] -struct ByteSlice<'a>(#[serde(serialize_with = "serialize_bytes")] &'a [u8]); - -pub fn serialize_bytes(b: &[u8], ser: S) -> Result where S: serde::Serializer { - ser.serialize_bytes(b) -} - -#[derive(Debug, Clone, Copy, Serialize, Deserialize)] -#[serde(transparent)] -pub struct CStrAsBytes<'a>(#[serde(serialize_with = "cstr_as_bytes::serialize", deserialize_with = "cstr_as_bytes::deserialize", borrow)] &'a CStr); - -impl<'a> std::ops::Deref for CStrAsBytes<'a> { - type Target = CStr; - - fn deref(&self) -> &Self::Target { - self.0 - } -} - -impl<'a> From<&'a CStr> for CStrAsBytes<'a> { - fn from(value: &'a CStr) -> Self { - Self(value) - } -} - -impl<'a> From> for &'a CStr { - fn from(value: CStrAsBytes<'a>) -> Self { - value.0 - } -} - #[derive(Debug, Serialize, Deserialize)] pub enum Question<'a> { Prompt { #[serde(borrow)] prompt: CStrAsBytes<'a>, echo: bool }, diff --git a/src/auth/ssh_key.rs b/src/auth/ssh_key.rs index 60557b5..98f755d 100644 --- a/src/auth/ssh_key.rs +++ b/src/auth/ssh_key.rs @@ -3,13 +3,13 @@ use nix::unistd::Uid; use quinn::{SendStream, RecvStream}; use zeroize::Zeroizing; -use crate::{read_msg, write_msg, ClientConfig, Message}; +use crate::{read_msg, write_msg, ClientConfig, Message, ser::ByteSlice}; use super::Question; #[derive(Debug, Serialize, Deserialize)] pub struct AuthRequest<'a> { - #[serde(serialize_with = "super::serialize_bytes")] - pub nonce: &'a [u8], + #[serde(borrow)] + pub nonce: ByteSlice<'a>, } impl<'a> AuthRequest<'a> { @@ -17,7 +17,7 @@ impl<'a> AuthRequest<'a> { pub fn new(nonce: &'a [u8]) -> AuthRequest<'a> { Self { - nonce, + nonce: nonce.into(), } } } diff --git a/src/main.rs b/src/main.rs index 76604df..58d692d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -12,6 +12,7 @@ mod auth; mod io_util; mod passwd; mod pty; +mod ser; mod terminfo; mod user_info; @@ -44,10 +45,12 @@ use tracing::Instrument; use user_info::UserInfo; use webpki::SubjectNameRef; -#[derive(Debug, Clone, Serialize, Deserialize)] -struct Term { - name: String, - info: Vec, +use ser::{ByteSlice, CStrAsBytes}; + +#[derive(Debug, Clone, Copy, Serialize, Deserialize)] +struct Term<'a> { + name: &'a str, + info: ByteSlice<'a>, } #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] @@ -145,11 +148,12 @@ impl FromStr for ForwardSpec { } #[derive(Clone, Serialize, Deserialize)] -enum Stream { +enum Stream<'a> { Exec, Shell { - env_term: Option, - command: Option<(CString, Vec)>, + #[serde(borrow)] + env_term: Option>, + command: Option<(CStrAsBytes<'a>, Vec>)>, }, Forward { addr: Ipv4Addr, @@ -159,7 +163,7 @@ enum Stream { // TODO: "backward" } -impl fmt::Debug for Stream { +impl<'a> fmt::Debug for Stream<'a> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Self::Exec => f.write_str("Exec"), @@ -176,22 +180,60 @@ impl fmt::Debug for Stream { } } +struct Args { + verbose: bool, + command: String, + rem: std::env::Args, +} + +impl Args { + pub fn parse() -> Result { + let mut args = std::env::args(); + _ = args.next(); + + let mut verbose = false; + let command = loop { + let arg = args.next().context("Expected a COMMAND")?; + match arg.as_bytes() { + b"-v" | b"--verbose" => verbose = true, + [b'-', ..] => bail!("Unrecognized option: {}", arg), + _ => break arg, + } + }; + + Ok(Self { verbose, command, rem: args }) + } +} + #[tokio::main] async fn main() -> Result<()> { + let args = Args::parse()?; + + let default_level = if args.verbose { + tracing_subscriber::filter::LevelFilter::INFO + } else { + tracing_subscriber::filter::LevelFilter::WARN + }; use tracing_subscriber::fmt::format::FmtSpan; tracing::subscriber::set_global_default( tracing_subscriber::FmtSubscriber::builder() .with_writer(std::io::stderr) - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .with_env_filter(tracing_subscriber::EnvFilter::builder() + .with_default_directive(default_level.into()) + .try_from_env()?) .with_span_events(FmtSpan::NEW | FmtSpan::CLOSE) .finish(), ) .unwrap(); - let mut args = std::env::args(); - _ = args.next(); - - let fut = run_cmd(args); + let fut = async move { + match args.command.as_str() { + #[cfg(feature = "server")] + "server" => run_server().await, + "client" => run_client(args.rem).await, + cmd => bail!("Unrecognized command: {}", cmd), + } + }; if std::env::var("NO_CTRLC").is_ok() { fut.await } else { @@ -207,16 +249,6 @@ async fn main() -> Result<()> { } } -async fn run_cmd(mut args: std::env::Args) -> Result<()> { - let cmd = args.next().expect("COMMAND"); - match cmd.as_str() { - #[cfg(feature = "server")] - "server" => run_server().await, - "client" => run_client(args).await, - _ => Err(anyhow!("Unrecognized command: {}", cmd)), - } -} - const ALPN_QUINOA: &str = "quinoa"; pub struct ServerConfig { @@ -833,7 +865,7 @@ async fn run_client(mut args: std::env::Args) -> Result<()> { async fn open_stream( conn: &quinn::Connection, - stream: &Stream, + stream: &Stream<'_>, ) -> Result<(SendStream, RecvStream)> { let (mut send, recv) = conn.open_bi().await?; write_msg(&mut send, stream).await?; @@ -946,13 +978,16 @@ async fn do_forward_to(conn: &quinn::Connection, spec: ForwardSpec) -> Result<() async fn do_shell(conn: &quinn::Connection) -> Result<()> { let env_term = std::env::var("TERM"); let env_terminfo = std::env::var("TERMINFO"); + let env_term = if let (Ok(name), Ok(path)) = (env_term, env_terminfo) { + let info = terminfo::read_terminfo(&path, &name).await?; + Some((name, info)) + } else { + None + }; let stream = &Stream::Shell { - env_term: if let (Ok(name), Ok(path)) = (env_term, env_terminfo) { - let info = terminfo::read_terminfo(&path, &name).await?; - Some(Term { name, info }) - } else { - None - }, + env_term: env_term.as_ref().map(|(name, info)| { + Term { name, info: info.as_slice().into() } + }), command: None, }; @@ -1038,7 +1073,7 @@ async fn greet_conn(cfg: &'static ServerConfig, conn: quinn::Connecting) -> Resu }; span.record("username", &user_info.user.name); - if let Err(e) = handle_conn(cfg, &conn, user_info, env).await { + if let Err(e) = handle_conn(cfg, &conn, user_info, env).instrument(span).await { error!("handler failed: {}", e.to_string()); conn.close(1u8.into(), b"handler error"); } @@ -1083,7 +1118,6 @@ async fn handle_conn( ) -> Result<()> { info!("established"); - let mut stream_buf = Vec::new(); loop { let stream = conn.accept_bi().await; let (send, mut recv) = match stream { @@ -1097,15 +1131,33 @@ async fn handle_conn( Ok(s) => s, }; - let stream = read_msg::(&mut recv, &mut stream_buf).await??; + let id = send.id(); let span = info_span!( "stream", - r#type = ?stream + dir = %match id.dir() { + quinn_proto::Dir::Bi => "bi", + quinn_proto::Dir::Uni => "uni", + }, + id = id.index() ); + let user_info = user_info.clone(); let env = env.clone(); tokio::task::spawn( async move { + let mut stream_buf = Vec::new(); + let r = match read_msg::(&mut recv, &mut stream_buf).await { + Ok(Ok(t)) => Ok(t), + Ok(Err(e)) => Err(e.into()), + Err(e) => Err(e), + }; + let stream = match r { + Ok(t) => t, + Err(e) => { + error!("Error in stream setup: {}", e); + return; + } + }; let r = match stream { Stream::Exec => { let span = info_span!("stream_exec"); @@ -1131,7 +1183,7 @@ async fn handle_conn( } }; if let Err(e) = r { - error!("Error in stream handler: {e}"); + error!("Error in stream handler: {}", e); } } .instrument(span), @@ -1216,8 +1268,8 @@ async fn handle_stream_shell( mut recv: RecvStream, user_info: &UserInfo, env: Vec<(CString, CString)>, - env_term: Option, - command: Option<(CString, Vec)>, + env_term: Option>, + command: Option<(CStrAsBytes<'_>, Vec>)>, ) -> Result<()> { let mut user_passwd_buf = Vec::new(); let user_passwd = passwd::Passwd::from_uid(user_info.user.id, &mut user_passwd_buf)?; diff --git a/src/ser.rs b/src/ser.rs new file mode 100644 index 0000000..70187f4 --- /dev/null +++ b/src/ser.rs @@ -0,0 +1,70 @@ + +use std::ffi::CStr; +mod cstr_as_bytes { + use std::ffi::CStr; + + use serde::{Serializer, Deserializer, Deserialize}; + + pub fn serialize(s: &CStr, ser: S) -> Result where S: Serializer { + ser.serialize_bytes(s.to_bytes_with_nul()) + } + + pub fn deserialize<'de, D>(de: D) -> Result<&'de CStr, D::Error> where D: Deserializer<'de> { + use serde::de::Error; + let b = <&'de [u8]>::deserialize(de)?; + CStr::from_bytes_with_nul(b) + .map_err(|_| D::Error::invalid_value(serde::de::Unexpected::Bytes(b), &"a sequence of bytes ending with NUL")) + } +} + +#[derive(Debug, Clone, Copy, Serialize, Deserialize)] +#[serde(transparent)] +pub struct ByteSlice<'a>(#[serde(serialize_with = "serialize_bytes")] &'a [u8]); + +pub fn serialize_bytes(b: &[u8], ser: S) -> Result where S: serde::Serializer { + ser.serialize_bytes(b) +} + +impl<'a> std::ops::Deref for ByteSlice<'a> { + type Target = [u8]; + + fn deref(&self) -> &Self::Target { + self.0 + } +} + +impl<'a> From<&'a [u8]> for ByteSlice<'a> { + fn from(value: &'a [u8]) -> Self { + Self(value) + } +} + +impl<'a> From> for &'a [u8] { + fn from(value: ByteSlice<'a>) -> Self { + value.0 + } +} + +#[derive(Debug, Clone, Copy, Serialize, Deserialize)] +#[serde(transparent)] +pub struct CStrAsBytes<'a>(#[serde(serialize_with = "cstr_as_bytes::serialize", deserialize_with = "cstr_as_bytes::deserialize", borrow)] &'a CStr); + +impl<'a> std::ops::Deref for CStrAsBytes<'a> { + type Target = CStr; + + fn deref(&self) -> &Self::Target { + self.0 + } +} + +impl<'a> From<&'a CStr> for CStrAsBytes<'a> { + fn from(value: &'a CStr) -> Self { + Self(value) + } +} + +impl<'a> From> for &'a CStr { + fn from(value: CStrAsBytes<'a>) -> Self { + value.0 + } +}