#![deny(unreachable_code)] #![deny(unused_must_use)] #[macro_use] extern crate anyhow; #[macro_use] extern crate serde; #[macro_use] extern crate tracing; mod auth; mod io_util; mod passwd; mod pty; mod ser; mod terminfo; mod user_info; #[cfg(all(feature = "server", target_os = "macos"))] use pam_client_macos as pam_client; use tokio::net::TcpSocket; use std::ffi::{CStr, CString}; use std::fmt; use std::future::Future; use std::mem::ManuallyDrop; use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4}; use std::os::fd::FromRawFd; use std::os::unix::process::CommandExt; use std::pin::Pin; use std::process::Stdio; use std::ptr::NonNull; use std::str::FromStr; use std::sync::Arc; use std::task::Poll; use anyhow::{Context, Result}; use base64::Engine as _; use nix::unistd::Uid; use quinn::{ReadExactError, RecvStream, SendStream}; use rustls::client::ServerCertVerifier; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::process::Command; use tracing::Instrument; use user_info::UserInfo; use webpki::SubjectNameRef; use ser::{ByteSlice, CStrAsBytes}; #[derive(Debug, Clone, Copy, Serialize, Deserialize)] struct Term<'a> { name: &'a str, info: ByteSlice<'a>, } #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] enum ForwardDirection { LocalToRemote, RemoteToLocal, } #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] enum ForwardProtocol { Tcp, Udp, } impl ForwardProtocol { pub const fn name(self) -> &'static str { match self { ForwardProtocol::Tcp => "tcp", ForwardProtocol::Udp => "udp", } } } impl fmt::Display for ForwardProtocol { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.write_str(self.name()) } } impl FromStr for ForwardProtocol { type Err = anyhow::Error; fn from_str(s: &str) -> Result { const TCP: &str = ForwardProtocol::Tcp.name(); const UDP: &str = ForwardProtocol::Udp.name(); match s { TCP => Ok(Self::Tcp), UDP => Ok(Self::Udp), _ => Err(anyhow!("Expected one of tcp, udp: {:?}", s)), } } } #[derive(Debug, 
Clone)] struct ForwardSpec { direction: ForwardDirection, bind: SocketAddrV4, // TODO: allow DNS lookup here forward: (Ipv4Addr, u16), protocol: ForwardProtocol, } impl FromStr for ForwardSpec { type Err = anyhow::Error; fn from_str(s: &str) -> Result { let i = s .as_bytes() .iter() .position(|&b| b == b'-' || b == b'<') .with_context(|| format!("Expected a local-remote delimiter: {:?}", s))?; let direction = match s.as_bytes().get(i..i + 2) { Some(b"->") => ForwardDirection::LocalToRemote, Some(b"<-") => ForwardDirection::RemoteToLocal, _ => bail!("Expected a local-remote delimiter: {:?}", s), }; let local = &s[..i]; let remote = &s[i + 2..]; let (bind, forward) = match direction { ForwardDirection::LocalToRemote => (local, remote), ForwardDirection::RemoteToLocal => (remote, local), }; let (bind_sock, protocol) = bind .split_once('/') .context("Expected a port-protocol delimiter")?; let bind_sock = bind_sock.parse::()?; let protocol = protocol.parse()?; let (forward_addr, forward_port) = forward .split_once(':') .context("Expected an address-port delimiter")?; let forward_addr = forward_addr .parse::() .with_context(|| forward_addr.to_owned())?; let forward_port = forward_port .parse::() .with_context(|| forward_port.to_owned())?; Ok(Self { direction, bind: bind_sock, forward: (forward_addr, forward_port), protocol, }) } } #[derive(Clone, Serialize, Deserialize)] enum Stream<'a> { Exec, Shell { #[serde(borrow)] env_term: Option>, command: Option<(CStrAsBytes<'a>, Vec>)>, }, Forward { addr: Ipv4Addr, port: u16, protocol: ForwardProtocol, }, // TODO: "backward" } impl<'a> fmt::Debug for Stream<'a> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Self::Exec => f.write_str("Exec"), Self::Shell { env_term: _, command, } => f.debug_struct("Shell").field("command", command).finish(), Self::Forward { addr, port, protocol, } => write!(f, "-> {}:{}/{}", addr, port, protocol), } } } struct Args { verbose: bool, command: String, rem: 
std::env::Args, } impl Args { pub fn parse() -> Result { let mut args = std::env::args(); _ = args.next(); let mut verbose = false; let command = loop { let arg = args.next().context("Expected a COMMAND")?; match arg.as_bytes() { b"-v" | b"--verbose" => verbose = true, [b'-', ..] => bail!("Unrecognized option: {}", arg), _ => break arg, } }; Ok(Self { verbose, command, rem: args }) } } #[tokio::main] async fn main() -> Result<()> { let args = Args::parse()?; let default_level = if args.verbose { tracing_subscriber::filter::LevelFilter::INFO } else { tracing_subscriber::filter::LevelFilter::WARN }; use tracing_subscriber::fmt::format::FmtSpan; tracing::subscriber::set_global_default( tracing_subscriber::FmtSubscriber::builder() .with_writer(std::io::stderr) .with_env_filter(tracing_subscriber::EnvFilter::builder() .with_default_directive(default_level.into()) .try_from_env()?) .with_span_events(FmtSpan::NEW | FmtSpan::CLOSE) .finish(), ) .unwrap(); let fut = async move { match args.command.as_str() { #[cfg(feature = "server")] "server" => run_server().await, "client" => run_client(args.rem).await, cmd => bail!("Unrecognized command: {}", cmd), } }; if std::env::var("NO_CTRLC").is_ok() { fut.await } else { let ctrl_c = tokio::signal::ctrl_c(); tokio::select! 
{ _ = ctrl_c => { info!("Aborting"); Ok(()) } r = fut => r, } } } const ALPN_QUINOA: &str = "quinoa"; pub struct ServerConfig { listen: SocketAddr, } pub struct ClientConfig { known_hosts_file: String, known_hosts: parking_lot::Mutex>>, ssh_key_file: String, } #[cfg(feature = "server")] async fn run_server() -> Result<()> { let cfg = { let opt_listen = std::env::var("BIND_ADDR") .expect("BIND_ADDR not specified") .parse()?; &*Box::leak(ServerConfig { listen: opt_listen }.into()) }; //let subject_alt_names = vec!["localhost".to_string()]; let subject_alt_names = vec![]; let (cert, key) = if !std::path::Path::new("cert.der").exists() || !std::path::Path::new("key.der").exists() { let cert = rcgen::generate_simple_self_signed(subject_alt_names)?; let key = rustls::PrivateKey(cert.serialize_private_key_der()); let cert = rustls::Certificate(cert.serialize_der()?); std::fs::write("key.der", &key.0)?; std::fs::write("cert.der", &cert.0)?; (cert, key) } else { let cert = rustls::Certificate(std::fs::read("cert.der")?); let key = rustls::PrivateKey(std::fs::read("key.der")?); (cert, key) }; let mut server_crypto = rustls::ServerConfig::builder() .with_safe_defaults() .with_no_client_auth() .with_single_cert(vec![cert], key) .unwrap(); server_crypto.alpn_protocols = vec![ALPN_QUINOA.as_bytes().to_owned()]; let mut transport = transport_config(); transport.max_concurrent_uni_streams(0_u8.into()); let mut server_config = quinn::ServerConfig::with_crypto(Arc::new(server_crypto)); server_config.transport_config(transport.into()); server_config.use_retry(true); let endpoint = quinn::Endpoint::server(server_config, cfg.listen)?; info!("listening on {}", endpoint.local_addr()?); while let Some(conn) = endpoint.accept().await { info!("connection incoming"); tokio::spawn(async move { if let Err(e) = greet_conn(cfg, conn).await { error!("connection failed: {reason}", reason = e.to_string()); } }); } Ok(()) } fn is_broken_pipe(r: &Result) -> bool { if let Err(e) = r { e.kind() == 
std::io::ErrorKind::BrokenPipe } else { false } } fn transport_config() -> quinn::TransportConfig { let mut transport = quinn::TransportConfig::default(); transport.stream_receive_window((64u32 * 1024 * 1024).into()); transport.send_window(64 * 1024 * 1024); transport.receive_window((64u32 * 1024 * 1024).into()); transport } #[derive(Debug, Clone, Copy)] struct FinishedEarly; impl fmt::Display for FinishedEarly { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { ReadExactError::FinishedEarly.fmt(f) } } impl std::error::Error for FinishedEarly {} /// Reads a message `T`. `buf` will be cleared before reading. async fn read_msg<'a, T: serde::Deserialize<'a>>( recv: &mut RecvStream, buf: &'a mut Vec, ) -> Result> { let mut size = [0u8; 2]; match recv.read_exact(&mut size).await { Ok(()) => {} Err(ReadExactError::FinishedEarly) => return Ok(Err(FinishedEarly)), Err(ReadExactError::ReadError(e)) => return Err(e.into()), } let size = u16::from_le_bytes(size); buf.clear(); buf.reserve(size.into()); recv.take(size.into()).read_to_end(buf).await?; Message::raw_to_value(buf).map(Ok) } struct Message(Vec); impl Message { pub fn from_value(value: &T) -> Result { Ok(Self(rmp_serde::to_vec(value)?)) } pub fn from_raw(data: Vec) -> Self { Self(data) } async fn write_len(&self, send: &mut SendStream) -> Result<()> { send.write_all(&u16::try_from(self.0.len())?.to_le_bytes()) .await .map_err(|e| e.into()) } pub async fn write_ref(&self, send: &mut SendStream) -> Result<()> { self.write_len(send).await?; send.write_all(&self.0).await?; Ok(()) } pub async fn write(self, send: &mut SendStream) -> Result<()> { self.write_len(send).await?; send.write_chunk(self.0.into()).await?; Ok(()) } fn raw_to_value<'de, T: serde::Deserialize<'de>>(data: &'de [u8]) -> Result { rmp_serde::from_slice(data) .with_context(|| { format!("reading a {} byte message", data.len()) }) .with_context(|| { let mut data = data; format!("{:?}", rmpv::decode::read_value_ref(&mut data)) }) } pub fn 
to_value<'de, T: serde::Deserialize<'de>>(&'de self) -> Result { Self::raw_to_value(&self.0) } } async fn write_msg(send: &mut SendStream, value: &T) -> Result<()> { Message::from_value(value)?.write(send).await } struct InformedServerCertVerifier { cfg: &'static ClientConfig, } impl InformedServerCertVerifier { fn find_known_host<'a, 'b>( known_hosts: &'b [KnownHost<'a>], subject_name: webpki::SubjectNameRef<'_>, ) -> Option<(usize, &'b KnownHost<'a>)> { known_hosts .iter() .enumerate() .find(|(_, h)| h.host.as_ref() == subject_name.as_ref()) } fn inform( &self, known_hosts: &mut Vec>, end_entity: &rustls::Certificate, subject_name: webpki::SubjectNameRef<'_>, known_as: Option, ) -> Result<(), rustls::CertificateError> { use rustls::CertificateError; let hash = Hash::new(&end_entity.0); eprintln!( "The authenticity of host {} can't be established.", match subject_name { SubjectNameRef::DnsName(name) => name.into(), SubjectNameRef::IpAddress(ip) => std::str::from_utf8(match ip { webpki::IpAddrRef::V4(b, _) => b, webpki::IpAddrRef::V6(b, _) => b, }).map_err(|e| CertificateError::Other(Arc::new(e)))?, } ); eprintln!("Certificate hash is {}", hash); if let Some(known_as) = known_as.and_then(|i| known_hosts.get(i)) { eprintln!("Previously known as {}", known_as.hash); } eprintln!("This key is not known by any other names (TODO: check)"); loop { eprintln!("Are you sure you want to continue connecting (yes/no)?"); let mut s = String::new(); std::io::stdin() .read_line(&mut s) .map_err(|e| CertificateError::Other(Arc::new(e)))?; if s.ends_with('\n') { s.pop(); } let yes = if s == "yes" { true } else if s == "no" { false } else { continue; }; if yes { let subject_name = match subject_name { webpki::SubjectNameRef::DnsName(dns_name) => { let dns_name = Box::leak(dns_name.as_ref().to_owned().into_boxed_slice()); webpki::SubjectNameRef::DnsName( webpki::DnsNameRef::try_from_ascii(&*dns_name).unwrap(), ) } webpki::SubjectNameRef::IpAddress(ip_addr) => { let ip_addr: &'static 
webpki::IpAddr = Box::leak(Box::new(ip_addr.to_owned())); let ip_addr: &'static str = ip_addr.as_ref(); webpki::SubjectNameRef::IpAddress( webpki::IpAddrRef::try_from_ascii_str(ip_addr).unwrap(), ) } //_ => return Err(CertificateError::NotValidForName.into()), }; let known_host = KnownHost { host: subject_name, hash, }; if known_as.is_some() { while let Some((i, _)) = Self::find_known_host(&known_hosts, subject_name) { known_hosts.remove(i); } } known_hosts.push(known_host); if let Err(e) = write_known_hosts(&self.cfg.known_hosts_file, &known_hosts) { error!( "Couldn't persist the new known-host. Continuing anyway.\n{}", e ); } return Ok(()); } else { eprintln!("Understood. Aborting..."); return Err(CertificateError::NotValidForName.into()); } } } } impl ServerCertVerifier for InformedServerCertVerifier { fn verify_server_cert( &self, end_entity: &rustls::Certificate, _intermediates: &[rustls::Certificate], server_name: &rustls::ServerName, _scts: &mut dyn Iterator, _ocsp_response: &[u8], _now: std::time::SystemTime, ) -> std::result::Result { use rustls::client::ServerCertVerified; use rustls::CertificateError; use rustls::ServerName; info!("starting verification"); fn pki_error(error: webpki::Error) -> rustls::Error { use webpki::Error::*; match error { BadDer | BadDerTime => CertificateError::BadEncoding.into(), CertNotValidYet => CertificateError::NotValidYet.into(), CertExpired | InvalidCertValidity => CertificateError::Expired.into(), UnknownIssuer => CertificateError::UnknownIssuer.into(), CertNotValidForName => CertificateError::NotValidForName.into(), InvalidSignatureForPublicKey | UnsupportedSignatureAlgorithm | UnsupportedSignatureAlgorithmForPublicKey => { CertificateError::BadSignature.into() } _ => CertificateError::Other(Arc::new(error)).into(), } } let cert = webpki::EndEntityCert::try_from(end_entity.0.as_ref()).map_err(pki_error)?; let ip_addr_slot; let subject_name = match server_name { ServerName::DnsName(dns_name) => 
webpki::SubjectNameRef::DnsName( webpki::DnsNameRef::try_from_ascii_str(dns_name.as_ref()) .map_err(|_| CertificateError::NotValidForName)?, ), ServerName::IpAddress(ip_addr) => { ip_addr_slot = webpki::IpAddr::from(*ip_addr); webpki::SubjectNameRef::IpAddress(webpki::IpAddrRef::from(&ip_addr_slot)) } _ => return Err(CertificateError::NotValidForName.into()), }; // TODO: is expiry checked for us? /*cert.verify_is_valid_for_subject_name(subject_name) .map_err(pki_error)?;*/ let mut known_hosts = self.cfg.known_hosts.lock(); if let Some((h_index, h)) = Self::find_known_host(&known_hosts, subject_name) { if let Err(e) = h.hash.verify(&end_entity.0) { debug!("verification failed: {}", e); self.inform(&mut known_hosts, end_entity, subject_name, Some(h_index))?; } else { eprintln!("Host authenticity verified"); } } else { self.inform(&mut known_hosts, end_entity, subject_name, None)?; } Ok(ServerCertVerified::assertion()) } } #[derive(Debug, Clone)] struct KnownHost<'a> { host: SubjectNameRef<'a>, hash: Hash, } #[derive(Debug, Clone, PartialEq, Eq)] enum Hash { SHA2_512([u8; 64]), /// No hashing. Stores the certificate directly. 
RAW(Vec), } impl Hash { pub fn new(data: &[u8]) -> Self { Self::raw(data) } pub fn sha2_512(data: &[u8]) -> Self { use sha2::digest::FixedOutput; use sha2::Digest; let mut hsr = sha2::Sha512::new(); hsr.update(data); Self::SHA2_512(hsr.finalize_fixed().into()) } pub fn raw(data: &[u8]) -> Self { Self::RAW(data.to_owned()) } pub fn name(&self) -> &'static str { match self { Hash::SHA2_512(_) => "SHA2_512", Hash::RAW(_) => "RAW", } } pub fn bytes(&self) -> &[u8] { match self { Hash::SHA2_512(b) => b, Hash::RAW(b) => b, } } pub fn verify(&self, data: &[u8]) -> Result<()> { let matches = match self { Hash::SHA2_512(_) => { let hash = Self::sha2_512(data); hash == *self } Hash::RAW(b) => data == b, }; if matches { Ok(()) } else { Err(anyhow!( "Certificate hash does not match hash in known_hosts file" )) } } } const B64: base64::engine::general_purpose::GeneralPurpose = base64::engine::general_purpose::STANDARD_NO_PAD; impl fmt::Display for Hash { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { use fast_hex::Encode; match self { Hash::SHA2_512(b) => { write!( f, "{}:{}", self.name(), fast_hex::Encoder::::display_sized(b) ) } Hash::RAW(b) => { write!(f, "{}:{}", self.name(), B64.encode(b)) } } } } impl FromStr for Hash { type Err = anyhow::Error; fn from_str(s: &str) -> std::result::Result { let (algo, hash) = s .split_once(':') .context("Expected an algorithm and raw hash")?; let hash = match algo { "SHA2_512" => Self::SHA2_512( fast_hex::Decoder::decode_sized(hash.as_bytes().try_into()?) 
.ok_or_else(|| anyhow!("Invalid hexadecimal"))?, ), "RAW" => Self::RAW(B64.decode(hash)?), _ => bail!("Unrecognized algorithm: {}", algo), }; Ok(hash) } } impl KnownHost<'static> { fn from_str(s: &'static str) -> Result { let (host, s) = s.split_once(' ').context("Expected a host and hash")?; let host = webpki::SubjectNameRef::try_from_ascii_str(host).map_err(|e| anyhow!("{:?}", e))?; let hash = s .parse::() .context("While parsing a known-host hash")?; Ok(Self { host, hash }) } } /*fn append_known_host(file: &str, host: KnownHost<'_>) -> Result<()> { use std::io::Write; let mut file = std::fs::OpenOptions::new() .read(false) .append(true) .create(true) .open(file)?; file.write_all(b"\n")?; file.write_all(host.host.as_ref())?; file.write_all(b" ")?; file.write_all(host.hash.to_string().as_bytes())?; Ok(()) }*/ fn write_known_hosts(file: &str, hosts: &[KnownHost<'_>]) -> Result<()> { use std::io::Write; let mut file = std::fs::OpenOptions::new() .read(false) .write(true) .truncate(true) .create(true) .open(file)?; for host in hosts { file.write_all(host.host.as_ref())?; file.write_all(b" ")?; file.write_all(host.hash.to_string().as_bytes())?; file.write_all(b"\n")?; } Ok(()) } async fn run_client(mut args: std::env::Args) -> Result<()> { info!("running client"); let home_dir = std::env::var("HOME")?; let cfg_dir = format!("{}/.config/quinoa", home_dir); tokio::fs::create_dir_all(&cfg_dir).await?; let known_hosts_file = format!("{}/known_hosts", cfg_dir); let known_hosts = if std::path::Path::new(&known_hosts_file).exists() { let s = Box::leak( tokio::fs::read_to_string(&known_hosts_file) .await? .into_boxed_str(), ); s.lines() .filter(|s| !s.is_empty() && !s.starts_with('#')) .map(KnownHost::from_str) .collect::>>()? 
} else { Vec::new() }; let ssh_key_file = format!("{}/.ssh/id_ed25519", home_dir); //let ssh_key_file = format!("{}/.ssh/id_rsa", home_dir); let cfg = &*Box::leak( ClientConfig { known_hosts_file, known_hosts: known_hosts.into(), ssh_key_file, } .into(), ); let mut conn_str = None; let mut forwards = Vec::new(); let mut use_key = true; while let Some(arg) = args.next() { if let Some(arg) = arg.strip_prefix('-') { match arg { "-no-key" => { use_key = false; } "-forward" => { let v = args.next().context("Expected a value of the form LOCAL_ADDR:LOCAL_PORT[->|<-]REMOTE_ADDR:REMOTE_PORT (and /[tcp|udp] on bind side)")?; forwards.push(v.parse::()?); } _ => bail!("Unrecognized option: -{}", arg), } } else if conn_str.is_none() { conn_str = Some(arg); } else { bail!("Unexpected argument: {:?}", arg); } } let (username, host) = conn_str .as_ref() .and_then(|s| s.split_once('@')) .context("Expected an argument of the form USERNAME@HOST")?; let (host_name, port) = host.split_once(':').unwrap_or((host, "8022")); let port = port.parse::()?; let mut client_crypto = rustls::ClientConfig::builder() .with_safe_defaults() .with_custom_certificate_verifier(Arc::new(InformedServerCertVerifier { cfg }) as _) .with_no_client_auth(); client_crypto.alpn_protocols = vec![ALPN_QUINOA.as_bytes().to_owned()]; let mut transport = transport_config(); transport.keep_alive_interval(Some(std::time::Duration::from_secs(5))); let mut client_config = quinn::ClientConfig::new(Arc::new(client_crypto)); client_config.transport_config(transport.into()); //let mut endpoint = quinn::Endpoint::client("[::]:0".parse().unwrap())?; let mut endpoint = quinn::Endpoint::client("0.0.0.0:0".parse().unwrap())?; endpoint.set_default_client_config(client_config); info!("connecting"); let conn = endpoint.connect(host.parse()?, host_name)?.await?; // authenticating client { let (mut send, mut recv) = conn.open_bi().await?; if use_key { auth::ssh_key::client_authenticate(cfg, &mut send, &mut recv, username).await?; } 
else { auth::password::client_authenticate(&conn, &mut send, &mut recv, username).await?; } } // authenticated client for forward in forwards { match forward.direction { ForwardDirection::LocalToRemote => { let conn = conn.clone(); tokio::task::spawn(async move { if let Err(e) = do_forward_to(&conn, forward).await { error!("Error in local-to-remote forwarder: {}", e); } }); } ForwardDirection::RemoteToLocal => todo!(), } } let _reset = { use nix::sys::termios::*; struct Reset(Termios); impl Drop for Reset { fn drop(&mut self) { //_ = crossterm::terminal::disable_raw_mode(); _ = tcsetattr(libc::STDIN_FILENO, SetArg::TCSAFLUSH, &self.0); println!("termios reset!"); } } let mut termios = tcgetattr(libc::STDIN_FILENO)?; let reset = Reset(termios.clone()); termios.local_flags.remove(LocalFlags::ECHO); termios.local_flags.remove(LocalFlags::ICANON); termios.local_flags.remove(LocalFlags::ISIG); termios.local_flags.remove(LocalFlags::IEXTEN); termios.input_flags.remove(InputFlags::IXON); termios.input_flags.remove(InputFlags::ICRNL); termios.output_flags.remove(OutputFlags::OPOST); tcsetattr(libc::STDIN_FILENO, SetArg::TCSAFLUSH, &termios)?; reset }; do_shell(&conn).await } async fn open_stream( conn: &quinn::Connection, stream: &Stream<'_>, ) -> Result<(SendStream, RecvStream)> { let (mut send, recv) = conn.open_bi().await?; write_msg(&mut send, stream).await?; info!("connected"); Ok((send, recv)) } pin_project_lite::pin_project! 
{ struct SendRecvStream { #[pin] send: SendStream, #[pin] recv: RecvStream, } } impl tokio::io::AsyncWrite for SendRecvStream { fn poll_write( self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, buf: &[u8], ) -> Poll> { self.project().send.poll_write(cx, buf) } fn poll_flush( self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> Poll> { self.project().send.poll_flush(cx) } fn poll_shutdown( self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> Poll> { self.project().send.poll_shutdown(cx) } fn poll_write_vectored( self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, bufs: &[std::io::IoSlice<'_>], ) -> Poll> { self.project().send.poll_write_vectored(cx, bufs) } fn is_write_vectored(&self) -> bool { self.send.is_write_vectored() } } impl tokio::io::AsyncRead for SendRecvStream { fn poll_read( self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, buf: &mut tokio::io::ReadBuf<'_>, ) -> Poll> { self.project().recv.poll_read(cx, buf) } } async fn do_forward_to(conn: &quinn::Connection, spec: ForwardSpec) -> Result<()> { assert_eq!(spec.direction, ForwardDirection::LocalToRemote); match spec.protocol { ForwardProtocol::Tcp => { let socket = TcpSocket::new_v4()?; socket.bind(spec.bind.into())?; let list = socket.listen(1024)?; loop { let (mut local_stream, peer) = list.accept().await?; let (send, recv) = open_stream( conn, &Stream::Forward { addr: spec.forward.0, port: spec.forward.1, protocol: spec.protocol, }, ) .await?; let mut remote_stream = SendRecvStream { send, recv }; tokio::task::spawn(async move { tokio::io::copy_bidirectional(&mut local_stream, &mut remote_stream).await?; Result::<_>::Ok(()) }); } } ForwardProtocol::Udp => { // this will require more thought into how to keep track of the peer. 
todo!() /*let socket = UdpSocket::bind(spec.bind).await?; let mut buf = Vec::with_capacity(4096); let (send, recv) = open_stream(conn, &Stream::Forward { addr: spec.forward.0, port: spec.forward.1, protocol: spec.protocol, }).await?; loop { socket.recv_buf(&mut buf).await?; let mut remote_stream = SendRecvStream(send, recv); tokio::task::spawn(async move { tokio::io::copy_bidirectional(&mut socket, &mut remote_stream).await?; Result::<_>::Ok(()) }); }*/ } } } async fn do_shell(conn: &quinn::Connection) -> Result<()> { let env_term = std::env::var("TERM"); let env_terminfo = std::env::var("TERMINFO"); let env_term = if let (Ok(name), Ok(path)) = (env_term, env_terminfo) { let info = terminfo::read_terminfo(&path, &name).await?; Some((name, info)) } else { None }; let stream = &Stream::Shell { env_term: env_term.as_ref().map(|(name, info)| { Term { name, info: info.as_slice().into() } }), command: None, }; let (mut send, mut recv) = open_stream(&conn, &stream).await?; let mut stdin = unsafe { ManuallyDrop::new(tokio::fs::File::from_raw_fd(libc::STDIN_FILENO)) }; let mut stdout = unsafe { ManuallyDrop::new(tokio::fs::File::from_raw_fd(libc::STDOUT_FILENO)) }; let mut stdin_buf = Vec::with_capacity(4096); //let mut stdout_buf = Vec::with_capacity(4096); let mut stdout_buf = vec![bytes::Bytes::new(); 128]; let mut stdin_eof = false; loop { tokio::select! { /*r = tokio::io::copy(&mut stdin, &mut send) => { r?; info!("EOF on stdin"); }*/ r = stdin.read_buf(&mut stdin_buf), if !stdin_eof => { if r? == 0 { stdin_eof = true; } send.write_all(&stdin_buf).await?; stdin_buf.clear(); //info!("sent stdin"); } r = recv.read_chunks(&mut stdout_buf) => { if let Some(n) = r? { for chunk in &stdout_buf[..n] { stdout.write_all(&chunk).await?; } //info!("recv stdout"); } } /*r = recv.read_buf(&mut stdout_buf) => if r? 
> 0 { stdout.write_all(&stdout_buf).await?; stdout_buf.clear(); //info!("recv stdout"); }*/, r = send.stopped() => { info!("Remote disconnected"); let code = r?.into_inner(); if code == 0 { return Ok(()); } else { return Err(anyhow!("Error code {}", code)); } } e = conn.closed() => { info!("Remote disconnected: {}", e); return Err(anyhow!("Remote connection closed")); } } } } #[cfg(feature = "server")] async fn greet_conn(cfg: &'static ServerConfig, conn: quinn::Connecting) -> Result<()> { info!("greeting connection"); let conn = conn.await?; let span = info_span!( "connection", remote = %conn.remote_address(), protocol = %conn .handshake_data() .unwrap() .downcast::().unwrap() .protocol .map_or_else(|| "".into(), |x| String::from_utf8_lossy(&x).into_owned()), username = tracing::field::Empty, ); let (user_info, env) = match authenticate_conn(cfg, &conn).instrument(span.clone()).await { Ok(t) => t, Err(e) => { error!("authentication failed: {}", e.to_string()); conn.close(1u8.into(), b"authentication error"); return Ok(()); } }; span.record("username", &user_info.user.name); if let Err(e) = handle_conn(cfg, &conn, user_info, env).instrument(span).await { error!("handler failed: {}", e.to_string()); conn.close(1u8.into(), b"handler error"); } Ok(()) } #[cfg(feature = "server")] async fn authenticate_conn( cfg: &'static ServerConfig, conn: &quinn::Connection, ) -> Result<(UserInfo, Vec<(CString, CString)>)> { info!("authenticating connection"); let (mut send, mut recv) = conn.accept_bi().await?; let mut hello_buf = Vec::new(); let hello = read_msg::(&mut recv, &mut hello_buf).await??; let user_info = user_info::get_user_info(&hello.username).await?; match hello.auth_method { auth::Method::Password => auth::password::server_authenticate(&mut send, &mut recv, hello.username.to_owned()).await?, auth::Method::SshKey { public_key } => auth::ssh_key::server_authenticate(&mut send, &mut recv, user_info.user.id, public_key).await?, } let env = Vec::new(); info!("logged in"); 
write_msg(&mut send, &auth::Question::LoggedIn).await?; send.finish().await?; recv.stop(0u8.into())?; Ok((user_info, env)) } #[cfg(feature = "server")] async fn handle_conn( cfg: &'static ServerConfig, conn: &quinn::Connection, user_info: UserInfo, env: Vec<(CString, CString)>, ) -> Result<()> { info!("established"); loop { let stream = conn.accept_bi().await; let (send, mut recv) = match stream { Err(quinn::ConnectionError::ApplicationClosed { .. }) => { info!("connection closed"); return Ok(()); } Err(e) => { return Err(e.into()); } Ok(s) => s, }; let id = send.id(); let span = info_span!( "stream", dir = %match id.dir() { quinn_proto::Dir::Bi => "bi", quinn_proto::Dir::Uni => "uni", }, id = id.index() ); let user_info = user_info.clone(); let env = env.clone(); tokio::task::spawn( async move { let mut stream_buf = Vec::new(); let r = match read_msg::(&mut recv, &mut stream_buf).await { Ok(Ok(t)) => Ok(t), Ok(Err(e)) => Err(e.into()), Err(e) => Err(e), }; let stream = match r { Ok(t) => t, Err(e) => { error!("Error in stream setup: {}", e); return; } }; let r = match stream { Stream::Exec => { let span = info_span!("stream_exec"); handle_stream_exec(cfg, send, recv, &user_info) .instrument(span) .await } Stream::Shell { env_term, command } => { let span = info_span!("stream_shell", ?command); handle_stream_shell(cfg, send, recv, &user_info, env, env_term, command) .instrument(span) .await } Stream::Forward { addr, port, protocol, } => { let span = info_span!("stream_forward", %addr, %port, %protocol); handle_stream_forward(cfg, send, recv, addr, port, protocol) .instrument(span) .await } }; if let Err(e) = r { error!("Error in stream handler: {}", e); } } .instrument(span), ); } } async fn handle_stream_exec( cfg: &ServerConfig, mut send: SendStream, mut recv: RecvStream, user_info: &UserInfo, ) -> Result<()> { (|| todo!())(); let mut cmd = std::process::Command::new(""); cmd.stdout(Stdio::piped()) .stderr(Stdio::piped()) .stdin(Stdio::piped()); 
#[cfg(target_family = "unix")] cmd.process_group(0); info!("Running {:?}", cmd); let mut sh = Command::from(cmd).kill_on_drop(true).spawn()?; let mut stdout = sh.stdout.take().unwrap(); let mut stderr = sh.stderr.take().unwrap(); let mut stdin = sh.stdin.take().unwrap(); let mut stdout_buf = Vec::with_capacity(4096); let mut stderr_buf = Vec::with_capacity(4096); let mut stdin_buf = Vec::with_capacity(4096); let mut stdout_eof = false; let mut stderr_eof = false; loop { tokio::select! { r = sh.wait() => { let code = r?; send.finish().await?; if !code.success() { info!("Child exit: {}", code); recv.stop(1u8.into())?; return Ok(()); } else { info!("Child exit"); recv.stop(0u8.into())?; return Ok(()); } } r = stdout.read_buf(&mut stdout_buf), if !stdout_eof => { if is_broken_pipe(&r) || r? == 0 { stdout_eof = true; info!("stdout eof"); } else { send.write_all(&stdout_buf).await?; info!("sent stdout: {:x?}", stdout_buf); stdout_buf.clear(); } }, r = stderr.read_buf(&mut stderr_buf), if !stderr_eof => { if is_broken_pipe(&r) || r? == 0 { stderr_eof = true; info!("stderr eof"); } else { send.write_all(&stderr_buf).await?; stderr_buf.clear(); info!("sent stderr: {:x?}", stderr_buf); } }, r = recv.read_buf(&mut stdin_buf) => if r? 
> 0 { stdin.write_all(&stdin_buf).await?; stdin_buf.clear(); info!("recv stdin"); }, } } }

/// Runs a shell (or a client-supplied command) for `user_info` inside a new pty
/// and relays bytes between the pty and the QUIC stream pair until the child exits.
///
/// * `send` carries pty output to the client; it is finished when the child exits.
/// * `recv` carries client input, which is written to the pty.
/// * `env` is applied on top of a scrubbed environment in which only `PATH` and
///   `LANG` survive from this process's own environment; `USER` is always set.
/// * `env_term`, when present (and the user has a UTF-8 home directory), is
///   installed under `$HOME/.local/share/quinoa/terminfo`; `TERM` and `TERMINFO`
///   are exported accordingly.
/// * `command`, when present, is `(program, argv)` and replaces the user's passwd
///   shell (which itself falls back to `/bin/sh`).
///
/// # Errors
/// Fails if the passwd lookup, terminfo setup, pty/child creation, or any relay
/// I/O fails. Errors raised inside the pre-exec hook (environment setup,
/// privilege drop) are handled by `pty::create_pty`.
async fn handle_stream_shell(
    cfg: &ServerConfig,
    mut send: SendStream,
    mut recv: RecvStream,
    user_info: &UserInfo,
    env: Vec<(CString, CString)>,
    env_term: Option<Term<'_>>,
    command: Option<(CStrAsBytes<'_>, Vec<CStrAsBytes<'_>>)>,
) -> Result<()> {
    let mut user_passwd_buf = Vec::new();
    let user_passwd = passwd::Passwd::from_uid(user_info.user.id, &mut user_passwd_buf)?;
    let user_home = user_passwd.as_ref().map(|p| p.dir);

    // Program to exec: the client's command, else the passwd shell, else /bin/sh.
    let shell = command
        .as_ref()
        .map(|(command, _)| command.as_ref())
        .unwrap_or_else(|| {
            user_passwd
                .as_ref()
                .map(|p| p.shell)
                .unwrap_or(CStr::from_bytes_with_nul(b"/bin/sh\0").unwrap())
        });
    // Basename of `shell` (everything after the last '/'), still NUL-terminated.
    let shell_name = CStr::from_bytes_with_nul(
        shell
            .to_bytes_with_nul()
            .iter()
            .rposition(|b| *b == b'/')
            .map(|i| &shell.to_bytes_with_nul()[i + 1..])
            .unwrap_or(shell.to_bytes_with_nul()),
    )?;
    // Argument list handed to the pty: the client's list, or just the shell's basename.
    let args = command
        .as_ref()
        .map(|(_, args)| args.iter().map(|s| s.as_ref()).collect::<Vec<_>>())
        .unwrap_or_else(|| vec![&shell_name]);

    let c_user_name = CString::new(user_info.user.name.as_str())?;
    let user_home = user_home.as_ref().and_then(|s| s.to_str().ok());

    // Build `$HOME/.local/share/quinoa/terminfo` one component at a time. Each
    // directory is created via `io_util::create_dir_owned` with the user's
    // uid/gid, wrapped in `ignore_already_exists`.
    // NOTE(review): this runs with our (presumably root) privileges inside a
    // user-controlled tree; a symlinked component could redirect creation and the
    // terminfo install elsewhere. Confirm `create_dir_owned`/`install_terminfo`
    // guard against that.
    const TERMINFO_PATH: &str = "/.local/share/quinoa/terminfo";
    let terminfo_path = if let Some(user_home) = user_home {
        let mut terminfo_path = String::with_capacity(user_home.len() + TERMINFO_PATH.len());
        terminfo_path.push_str(user_home);
        for comp in TERMINFO_PATH.split_inclusive('/') {
            terminfo_path.push_str(comp);
            // The leading "/" only re-denotes $HOME itself; nothing to create.
            if comp != "/" {
                io_util::ignore_already_exists(io_util::create_dir_owned(
                    &terminfo_path,
                    user_info.user.id,
                    user_info.group.id,
                ))?;
            }
        }
        Some(terminfo_path)
    } else {
        None
    };

    if let (Some(env_term), Some(terminfo_path)) = (&env_term, &terminfo_path) {
        terminfo::install_terminfo(
            terminfo_path,
            &env_term.name,
            &env_term.info,
            (user_info.user.id, user_info.group.id),
        )
        .await?;
    }

    let c_env_term = env_term.map(|t| CString::new(t.name)).transpose()?;
    let c_env_terminfo = terminfo_path.map(CString::new).transpose()?;

    // Hook handed to `pty::PreExec`; presumably runs in the child just before
    // exec. It rebuilds the environment, then drops privileges irreversibly.
    // NOTE(review): this allocates (Vec/CString). Allocation between fork and
    // exec is not async-signal-safe in a multi-threaded parent; confirm how
    // `pty::create_pty` invokes it.
    let pre_exec = move || -> Result<()> {
        // `libc::setenv`, reporting failure through errno.
        fn setenv(name: &CStr, value: &CStr, overwrite: bool) -> Result<(), nix::Error> {
            // SAFETY: both pointers come from valid NUL-terminated `CStr`s.
            if unsafe { libc::setenv(name.as_ptr(), value.as_ptr(), overwrite as i32) } != 0 {
                Err(nix::Error::last())
            } else {
                Ok(())
            }
        }

        extern "C" {
            #[cfg(any(target_os = "macos", target_os = "ios"))]
            pub fn _NSGetEnviron() -> *mut *const *const std::ffi::c_char;
            #[cfg(not(any(target_os = "macos", target_os = "ios")))]
            static mut environ: *const *const std::ffi::c_char;
        }

        // Variables inherited from this process that survive the scrub.
        const PRESERVE_ENV: &[&[u8]] = &[b"PATH", b"LANG"];

        // Copy preserved variables into owned strings *before* `clearenv`, so
        // nothing below reads through pointers into storage it may release.
        let mut keep: Vec<(CString, CString)> = Vec::new();
        // SAFETY: walks the process environment array, which is NULL-terminated
        // and holds NUL-terminated "KEY=VALUE" strings. This assumes nothing else
        // mutates it concurrently (expected in a freshly forked child).
        unsafe {
            #[cfg(any(target_os = "macos", target_os = "ios"))]
            let mut cursor = *_NSGetEnviron();
            #[cfg(not(any(target_os = "macos", target_os = "ios")))]
            let mut cursor = environ;
            if !cursor.is_null() {
                while !(*cursor).is_null() {
                    let entry = CStr::from_ptr(*cursor).to_bytes();
                    cursor = cursor.add(1);
                    let Some(eq) = entry.iter().position(|b| *b == b'=') else { continue };
                    let (key, value) = (&entry[..eq], &entry[eq + 1..]);
                    if PRESERVE_ENV.contains(&key) {
                        keep.push((CString::new(key)?, CString::new(value)?));
                    }
                }
            }
        }
        // SAFETY: see above; no other thread is expected to touch the environment.
        unsafe { nix::env::clearenv() }?;

        // Preserved variables first, then the client's, so client values win.
        // (Values are passed as `CStr`s directly; re-parsing `as_bytes()` with
        // `from_bytes_with_nul` used to reject every client value.)
        for (key, value) in keep.iter().chain(env.iter()) {
            setenv(key, value, true)?;
        }
        setenv(
            CStr::from_bytes_with_nul(b"USER\0").unwrap(),
            &c_user_name,
            true,
        )?;
        if let Some(c_env_terminfo) = &c_env_terminfo {
            // Otherwise the system terminfo database is used.
            setenv(
                CStr::from_bytes_with_nul(b"TERMINFO\0").unwrap(),
                c_env_terminfo,
                true,
            )?;
        }
        if let Some(c_env_term) = &c_env_term {
            setenv(
                CStr::from_bytes_with_nul(b"TERM\0").unwrap(),
                c_env_term,
                true,
            )?;
        }

        // Drop privileges: supplementary groups, then primary group, then user
        // (group changes need privileges that setuid gives up).
        #[cfg(not(any(
            target_os = "macos",
            target_os = "ios",
            target_os = "redox",
            target_os = "haiku"
        )))]
        nix::unistd::setgroups(&user_info.groups.iter().map(|g| g.id).collect::<Vec<_>>())
            .context("setting supplementary groups")?;
        nix::unistd::setgid(user_info.group.id).context("setting primary group")?;
        nix::unistd::setuid(user_info.user.id).context("setting user")?;
        // The drop must be irreversible: regaining uid 0 by any route is fatal.
        if nix::unistd::seteuid(Uid::from_raw(0)).is_ok() {
            return Err(anyhow!("We got the privileges back via seteuid. Very bad!"));
        }
        if nix::unistd::setuid(Uid::from_raw(0)).is_ok() {
            return Err(anyhow!("We got the privileges back via setuid. Very bad!"));
        }
        std::env::set_current_dir("/")?;
        Ok(())
    };
    // SAFETY: the contract of `pty::PreExec::new` is defined in the `pty` module.
    // The hook above performs only environment and credential syscalls.
    let pre_exec = unsafe { pty::PreExec::new(pre_exec) };

    let mut sh = pty::create_pty(shell, &args, pre_exec)?;
    sh.set_nodelay()?;
    let mut pty = sh.pty;
    info!("created pty");

    // Copies pty output to `send` forever; returns only on a hard I/O error.
    // NOTE(review): `send.write_all` is not cancellation-safe. If another select!
    // branch wins mid-write, the partially sent `buf` is resent later.
    // NOTE(review): on Linux a pty master typically reports EIO once the child
    // side closes. That error is propagated here and can race with the
    // child-exit branch.
    async fn read_pty(
        pty: &mut tokio::fs::File,
        buf: &mut Vec<u8>,
        send: &mut SendStream,
    ) -> Result<()> {
        loop {
            match pty.read_buf(buf).await {
                // "No data yet" on a non-blocking fd; portable across EAGAIN values.
                Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {}
                Err(e) => return Err(e.into()),
                Ok(_) if buf.is_empty() => debug!("not ready: empty"),
                Ok(_) => {
                    send.write_all(buf).await?;
                    buf.clear();
                }
            }
        }
    }

    let mut pty_buf = Vec::with_capacity(4096);
    // Cleared once the client finishes its side; `read_chunk` would otherwise
    // resolve to `None` immediately forever and spin the loop.
    let mut stdin_open = true;
    // Resolves once the child has exited.
    // NOTE(review): this never registers a waker. It is only re-checked when
    // another branch wakes the task (pty or stdin activity).
    let mut child_exit = std::future::poll_fn(|_cx| match sh.proc.try_wait() {
        Ok(Some(code)) => Poll::Ready(Ok(code)),
        Ok(None) => Poll::Pending,
        Err(e) => Poll::Ready(Err(e)),
    });

    loop {
        tokio::select! {
            r = &mut child_exit => {
                let code = r?;
                send.finish().await?;
                if code != 0 {
                    info!("Child exit: {}", code);
                    recv.stop(1u8.into())?;
                } else {
                    info!("Child exit");
                    recv.stop(0u8.into())?;
                }
                return Ok(());
            }
            r = read_pty(&mut pty, &mut pty_buf, &mut send) => {
                r?;
            }
            // FIXME: figure out a maximum chunk size
            r = recv.read_chunk(usize::MAX, true), if stdin_open => {
                match r? {
                    Some(chunk) => {
                        pty.write_all(&chunk.bytes).await?;
                        info!("recv stdin");
                    }
                    None => stdin_open = false,
                }
            }
        }
    }
}

/// Serves one client forwarding stream by connecting to `addr:port` from this host.
///
/// For TCP, opens a connection and copies bytes in both directions between it
/// and the QUIC stream pair until both directions have reached EOF.
///
/// # Errors
/// Fails if the connection or the copy fails, or if `protocol` is UDP, which is
/// not implemented. The protocol comes from the client, so this is an error,
/// not a panic.
async fn handle_stream_forward(
    cfg: &ServerConfig,
    send: SendStream,
    recv: RecvStream,
    addr: Ipv4Addr,
    port: u16,
    protocol: ForwardProtocol,
) -> Result<()> {
    match protocol {
        ForwardProtocol::Tcp => {
            let socket = TcpSocket::new_v4()?;
            let mut local_stream = socket.connect(SocketAddrV4::new(addr, port).into()).await?;
            let mut remote_stream = SendRecvStream { send, recv };
            tokio::io::copy_bidirectional(&mut local_stream, &mut remote_stream).await?;
            Ok(())
        }
        ForwardProtocol::Udp => bail!("{} forwarding is not supported yet", protocol),
    }
}