Verbose flag, improved log output, optimized (de)serialization

This commit is contained in:
Michael Pfaff 2023-06-11 22:56:31 -04:00
parent 413e0eaea5
commit 0e48e930ba
Signed by: michael
GPG Key ID: CF402C4A012AA9D4
6 changed files with 171 additions and 94 deletions

9
Cargo.lock generated
View File

@ -1107,6 +1107,7 @@ dependencies = [
"parking_lot", "parking_lot",
"pin-project-lite", "pin-project-lite",
"quinn", "quinn",
"quinn-proto",
"rand 0.8.5", "rand 0.8.5",
"rcgen", "rcgen",
"rmp", "rmp",
@ -2227,14 +2228,14 @@ dependencies = [
"syn 2.0.18", "syn 2.0.18",
] ]
[[patch.unused]]
name = "chow"
version = "0.2.0"
[[patch.unused]] [[patch.unused]]
name = "how" name = "how"
version = "0.3.0" version = "0.3.0"
[[patch.unused]]
name = "chow"
version = "0.2.0"
[[patch.unused]] [[patch.unused]]
name = "minify-html-onepass" name = "minify-html-onepass"
version = "0.11.1" version = "0.11.1"

View File

@ -24,6 +24,7 @@ nix = "0.26.2"
parking_lot = "0.12.1" parking_lot = "0.12.1"
pin-project-lite = "0.2.9" pin-project-lite = "0.2.9"
quinn = "0.10.1" quinn = "0.10.1"
quinn-proto = { version = "0.10.1", default-features = false }
rand = "0.8.5" rand = "0.8.5"
rcgen = "0.10.0" rcgen = "0.10.0"
rmp = "0.8.11" rmp = "0.8.11"

View File

@ -3,6 +3,8 @@ pub mod ssh_key;
use std::ffi::CStr; use std::ffi::CStr;
use crate::ser::CStrAsBytes;
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
pub struct Hello<'a> { pub struct Hello<'a> {
#[serde(borrow)] #[serde(borrow)]
@ -21,55 +23,6 @@ pub enum Method<'a> {
}, },
} }
mod cstr_as_bytes {
use std::ffi::CStr;
use serde::{Serializer, Deserializer, Deserialize};
pub fn serialize<S>(s: &CStr, ser: S) -> Result<S::Ok, S::Error> where S: Serializer {
ser.serialize_bytes(s.to_bytes_with_nul())
}
pub fn deserialize<'de, D>(de: D) -> Result<&'de CStr, D::Error> where D: Deserializer<'de> {
use serde::de::Error;
let b = <&'de [u8]>::deserialize(de)?;
CStr::from_bytes_with_nul(b)
.map_err(|_| D::Error::invalid_value(serde::de::Unexpected::Bytes(b), &"a sequence of bytes ending with NUL"))
}
}
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
#[serde(transparent)]
struct ByteSlice<'a>(#[serde(serialize_with = "serialize_bytes")] &'a [u8]);
pub fn serialize_bytes<S>(b: &[u8], ser: S) -> Result<S::Ok, S::Error> where S: serde::Serializer {
ser.serialize_bytes(b)
}
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
#[serde(transparent)]
pub struct CStrAsBytes<'a>(#[serde(serialize_with = "cstr_as_bytes::serialize", deserialize_with = "cstr_as_bytes::deserialize", borrow)] &'a CStr);
impl<'a> std::ops::Deref for CStrAsBytes<'a> {
type Target = CStr;
fn deref(&self) -> &Self::Target {
self.0
}
}
impl<'a> From<&'a CStr> for CStrAsBytes<'a> {
fn from(value: &'a CStr) -> Self {
Self(value)
}
}
impl<'a> From<CStrAsBytes<'a>> for &'a CStr {
fn from(value: CStrAsBytes<'a>) -> Self {
value.0
}
}
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
pub enum Question<'a> { pub enum Question<'a> {
Prompt { #[serde(borrow)] prompt: CStrAsBytes<'a>, echo: bool }, Prompt { #[serde(borrow)] prompt: CStrAsBytes<'a>, echo: bool },

View File

@ -3,13 +3,13 @@ use nix::unistd::Uid;
use quinn::{SendStream, RecvStream}; use quinn::{SendStream, RecvStream};
use zeroize::Zeroizing; use zeroize::Zeroizing;
use crate::{read_msg, write_msg, ClientConfig, Message}; use crate::{read_msg, write_msg, ClientConfig, Message, ser::ByteSlice};
use super::Question; use super::Question;
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
pub struct AuthRequest<'a> { pub struct AuthRequest<'a> {
#[serde(serialize_with = "super::serialize_bytes")] #[serde(borrow)]
pub nonce: &'a [u8], pub nonce: ByteSlice<'a>,
} }
impl<'a> AuthRequest<'a> { impl<'a> AuthRequest<'a> {
@ -17,7 +17,7 @@ impl<'a> AuthRequest<'a> {
pub fn new(nonce: &'a [u8]) -> AuthRequest<'a> { pub fn new(nonce: &'a [u8]) -> AuthRequest<'a> {
Self { Self {
nonce, nonce: nonce.into(),
} }
} }
} }

View File

@ -12,6 +12,7 @@ mod auth;
mod io_util; mod io_util;
mod passwd; mod passwd;
mod pty; mod pty;
mod ser;
mod terminfo; mod terminfo;
mod user_info; mod user_info;
@ -44,10 +45,12 @@ use tracing::Instrument;
use user_info::UserInfo; use user_info::UserInfo;
use webpki::SubjectNameRef; use webpki::SubjectNameRef;
#[derive(Debug, Clone, Serialize, Deserialize)] use ser::{ByteSlice, CStrAsBytes};
struct Term {
name: String, #[derive(Debug, Clone, Copy, Serialize, Deserialize)]
info: Vec<u8>, struct Term<'a> {
name: &'a str,
info: ByteSlice<'a>,
} }
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
@ -145,11 +148,12 @@ impl FromStr for ForwardSpec {
} }
#[derive(Clone, Serialize, Deserialize)] #[derive(Clone, Serialize, Deserialize)]
enum Stream { enum Stream<'a> {
Exec, Exec,
Shell { Shell {
env_term: Option<Term>, #[serde(borrow)]
command: Option<(CString, Vec<CString>)>, env_term: Option<Term<'a>>,
command: Option<(CStrAsBytes<'a>, Vec<CStrAsBytes<'a>>)>,
}, },
Forward { Forward {
addr: Ipv4Addr, addr: Ipv4Addr,
@ -159,7 +163,7 @@ enum Stream {
// TODO: "backward" // TODO: "backward"
} }
impl fmt::Debug for Stream { impl<'a> fmt::Debug for Stream<'a> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self { match self {
Self::Exec => f.write_str("Exec"), Self::Exec => f.write_str("Exec"),
@ -176,22 +180,60 @@ impl fmt::Debug for Stream {
} }
} }
struct Args {
verbose: bool,
command: String,
rem: std::env::Args,
}
impl Args {
pub fn parse() -> Result<Self> {
let mut args = std::env::args();
_ = args.next();
let mut verbose = false;
let command = loop {
let arg = args.next().context("Expected a COMMAND")?;
match arg.as_bytes() {
b"-v" | b"--verbose" => verbose = true,
[b'-', ..] => bail!("Unrecognized option: {}", arg),
_ => break arg,
}
};
Ok(Self { verbose, command, rem: args })
}
}
#[tokio::main] #[tokio::main]
async fn main() -> Result<()> { async fn main() -> Result<()> {
let args = Args::parse()?;
let default_level = if args.verbose {
tracing_subscriber::filter::LevelFilter::INFO
} else {
tracing_subscriber::filter::LevelFilter::WARN
};
use tracing_subscriber::fmt::format::FmtSpan; use tracing_subscriber::fmt::format::FmtSpan;
tracing::subscriber::set_global_default( tracing::subscriber::set_global_default(
tracing_subscriber::FmtSubscriber::builder() tracing_subscriber::FmtSubscriber::builder()
.with_writer(std::io::stderr) .with_writer(std::io::stderr)
.with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) .with_env_filter(tracing_subscriber::EnvFilter::builder()
.with_default_directive(default_level.into())
.try_from_env()?)
.with_span_events(FmtSpan::NEW | FmtSpan::CLOSE) .with_span_events(FmtSpan::NEW | FmtSpan::CLOSE)
.finish(), .finish(),
) )
.unwrap(); .unwrap();
let mut args = std::env::args(); let fut = async move {
_ = args.next(); match args.command.as_str() {
#[cfg(feature = "server")]
let fut = run_cmd(args); "server" => run_server().await,
"client" => run_client(args.rem).await,
cmd => bail!("Unrecognized command: {}", cmd),
}
};
if std::env::var("NO_CTRLC").is_ok() { if std::env::var("NO_CTRLC").is_ok() {
fut.await fut.await
} else { } else {
@ -207,16 +249,6 @@ async fn main() -> Result<()> {
} }
} }
async fn run_cmd(mut args: std::env::Args) -> Result<()> {
let cmd = args.next().expect("COMMAND");
match cmd.as_str() {
#[cfg(feature = "server")]
"server" => run_server().await,
"client" => run_client(args).await,
_ => Err(anyhow!("Unrecognized command: {}", cmd)),
}
}
const ALPN_QUINOA: &str = "quinoa"; const ALPN_QUINOA: &str = "quinoa";
pub struct ServerConfig { pub struct ServerConfig {
@ -833,7 +865,7 @@ async fn run_client(mut args: std::env::Args) -> Result<()> {
async fn open_stream( async fn open_stream(
conn: &quinn::Connection, conn: &quinn::Connection,
stream: &Stream, stream: &Stream<'_>,
) -> Result<(SendStream, RecvStream)> { ) -> Result<(SendStream, RecvStream)> {
let (mut send, recv) = conn.open_bi().await?; let (mut send, recv) = conn.open_bi().await?;
write_msg(&mut send, stream).await?; write_msg(&mut send, stream).await?;
@ -946,13 +978,16 @@ async fn do_forward_to(conn: &quinn::Connection, spec: ForwardSpec) -> Result<()
async fn do_shell(conn: &quinn::Connection) -> Result<()> { async fn do_shell(conn: &quinn::Connection) -> Result<()> {
let env_term = std::env::var("TERM"); let env_term = std::env::var("TERM");
let env_terminfo = std::env::var("TERMINFO"); let env_terminfo = std::env::var("TERMINFO");
let env_term = if let (Ok(name), Ok(path)) = (env_term, env_terminfo) {
let info = terminfo::read_terminfo(&path, &name).await?;
Some((name, info))
} else {
None
};
let stream = &Stream::Shell { let stream = &Stream::Shell {
env_term: if let (Ok(name), Ok(path)) = (env_term, env_terminfo) { env_term: env_term.as_ref().map(|(name, info)| {
let info = terminfo::read_terminfo(&path, &name).await?; Term { name, info: info.as_slice().into() }
Some(Term { name, info }) }),
} else {
None
},
command: None, command: None,
}; };
@ -1038,7 +1073,7 @@ async fn greet_conn(cfg: &'static ServerConfig, conn: quinn::Connecting) -> Resu
}; };
span.record("username", &user_info.user.name); span.record("username", &user_info.user.name);
if let Err(e) = handle_conn(cfg, &conn, user_info, env).await { if let Err(e) = handle_conn(cfg, &conn, user_info, env).instrument(span).await {
error!("handler failed: {}", e.to_string()); error!("handler failed: {}", e.to_string());
conn.close(1u8.into(), b"handler error"); conn.close(1u8.into(), b"handler error");
} }
@ -1083,7 +1118,6 @@ async fn handle_conn(
) -> Result<()> { ) -> Result<()> {
info!("established"); info!("established");
let mut stream_buf = Vec::new();
loop { loop {
let stream = conn.accept_bi().await; let stream = conn.accept_bi().await;
let (send, mut recv) = match stream { let (send, mut recv) = match stream {
@ -1097,15 +1131,33 @@ async fn handle_conn(
Ok(s) => s, Ok(s) => s,
}; };
let stream = read_msg::<Stream>(&mut recv, &mut stream_buf).await??; let id = send.id();
let span = info_span!( let span = info_span!(
"stream", "stream",
r#type = ?stream dir = %match id.dir() {
quinn_proto::Dir::Bi => "bi",
quinn_proto::Dir::Uni => "uni",
},
id = id.index()
); );
let user_info = user_info.clone(); let user_info = user_info.clone();
let env = env.clone(); let env = env.clone();
tokio::task::spawn( tokio::task::spawn(
async move { async move {
let mut stream_buf = Vec::new();
let r = match read_msg::<Stream>(&mut recv, &mut stream_buf).await {
Ok(Ok(t)) => Ok(t),
Ok(Err(e)) => Err(e.into()),
Err(e) => Err(e),
};
let stream = match r {
Ok(t) => t,
Err(e) => {
error!("Error in stream setup: {}", e);
return;
}
};
let r = match stream { let r = match stream {
Stream::Exec => { Stream::Exec => {
let span = info_span!("stream_exec"); let span = info_span!("stream_exec");
@ -1131,7 +1183,7 @@ async fn handle_conn(
} }
}; };
if let Err(e) = r { if let Err(e) = r {
error!("Error in stream handler: {e}"); error!("Error in stream handler: {}", e);
} }
} }
.instrument(span), .instrument(span),
@ -1216,8 +1268,8 @@ async fn handle_stream_shell(
mut recv: RecvStream, mut recv: RecvStream,
user_info: &UserInfo, user_info: &UserInfo,
env: Vec<(CString, CString)>, env: Vec<(CString, CString)>,
env_term: Option<Term>, env_term: Option<Term<'_>>,
command: Option<(CString, Vec<CString>)>, command: Option<(CStrAsBytes<'_>, Vec<CStrAsBytes<'_>>)>,
) -> Result<()> { ) -> Result<()> {
let mut user_passwd_buf = Vec::new(); let mut user_passwd_buf = Vec::new();
let user_passwd = passwd::Passwd::from_uid(user_info.user.id, &mut user_passwd_buf)?; let user_passwd = passwd::Passwd::from_uid(user_info.user.id, &mut user_passwd_buf)?;

70
src/ser.rs Normal file
View File

@ -0,0 +1,70 @@
use std::ffi::CStr;
mod cstr_as_bytes {
use std::ffi::CStr;
use serde::{Serializer, Deserializer, Deserialize};
pub fn serialize<S>(s: &CStr, ser: S) -> Result<S::Ok, S::Error> where S: Serializer {
ser.serialize_bytes(s.to_bytes_with_nul())
}
pub fn deserialize<'de, D>(de: D) -> Result<&'de CStr, D::Error> where D: Deserializer<'de> {
use serde::de::Error;
let b = <&'de [u8]>::deserialize(de)?;
CStr::from_bytes_with_nul(b)
.map_err(|_| D::Error::invalid_value(serde::de::Unexpected::Bytes(b), &"a sequence of bytes ending with NUL"))
}
}
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
#[serde(transparent)]
pub struct ByteSlice<'a>(#[serde(serialize_with = "serialize_bytes")] &'a [u8]);
pub fn serialize_bytes<S>(b: &[u8], ser: S) -> Result<S::Ok, S::Error> where S: serde::Serializer {
ser.serialize_bytes(b)
}
impl<'a> std::ops::Deref for ByteSlice<'a> {
type Target = [u8];
fn deref(&self) -> &Self::Target {
self.0
}
}
impl<'a> From<&'a [u8]> for ByteSlice<'a> {
fn from(value: &'a [u8]) -> Self {
Self(value)
}
}
impl<'a> From<ByteSlice<'a>> for &'a [u8] {
fn from(value: ByteSlice<'a>) -> Self {
value.0
}
}
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
#[serde(transparent)]
pub struct CStrAsBytes<'a>(#[serde(serialize_with = "cstr_as_bytes::serialize", deserialize_with = "cstr_as_bytes::deserialize", borrow)] &'a CStr);
impl<'a> std::ops::Deref for CStrAsBytes<'a> {
type Target = CStr;
fn deref(&self) -> &Self::Target {
self.0
}
}
impl<'a> From<&'a CStr> for CStrAsBytes<'a> {
fn from(value: &'a CStr) -> Self {
Self(value)
}
}
impl<'a> From<CStrAsBytes<'a>> for &'a CStr {
fn from(value: CStrAsBytes<'a>) -> Self {
value.0
}
}