quinoa/src/main.rs

1583 lines
51 KiB
Rust

#![deny(unreachable_code)]
#![deny(unused_must_use)]
#[macro_use]
extern crate anyhow;
#[macro_use]
extern crate serde;
#[macro_use]
extern crate tracing;
mod auth;
mod io_util;
mod passwd;
mod pty;
mod ser;
mod terminfo;
mod user_info;
#[cfg(all(feature = "server", target_os = "macos"))]
use pam_client_macos as pam_client;
use tokio::net::TcpSocket;
use std::ffi::{CStr, CString};
use std::fmt;
use std::future::Future;
use std::mem::ManuallyDrop;
use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
use std::os::fd::FromRawFd;
use std::pin::Pin;
use std::ptr::NonNull;
use std::str::FromStr;
use std::sync::Arc;
use std::task::Poll;
use anyhow::{Context, Result};
use base64::Engine as _;
use nix::unistd::Uid;
use quinn::{ReadExactError, RecvStream, SendStream};
use rustls::client::ServerCertVerifier;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tracing::Instrument;
use user_info::UserInfo;
use webpki::SubjectNameRef;
use ser::{ByteSlice, CStrAsBytes};
/// Terminal description sent by the client with a shell request: the terminal name
/// (taken from the client's `TERM`) and the terminfo entry bytes read for it.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
struct Term<'a> {
/// Terminal name, e.g. the value of `TERM`.
name: &'a str,
/// Raw terminfo entry for `name`, as returned by `terminfo::read_terminfo`.
info: ByteSlice<'a>,
}
/// Which way a port forward goes, chosen by the arrow in a `--forward` spec.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
enum ForwardDirection {
/// Written `BIND/PROTO->ADDR:PORT`; the client listens on BIND and forwards
/// connections through the server to ADDR:PORT.
LocalToRemote,
/// Written `ADDR:PORT<-BIND/PROTO`; presumably the server side listens — the client
/// does not implement this direction yet.
RemoteToLocal,
}
/// Transport protocol of a port forward; textual form is `tcp` / `udp`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
enum ForwardProtocol {
Tcp,
Udp,
}
impl ForwardProtocol {
    /// Lower-case protocol name, used both for parsing and for `Display`.
    pub const fn name(self) -> &'static str {
        match self {
            Self::Tcp => "tcp",
            Self::Udp => "udp",
        }
    }
}
impl fmt::Display for ForwardProtocol {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = self.name();
        f.write_str(label)
    }
}
impl FromStr for ForwardProtocol {
    type Err = anyhow::Error;
    /// Accepts exactly (case-sensitively) the strings produced by [`ForwardProtocol::name`].
    fn from_str(s: &str) -> Result<Self> {
        [Self::Tcp, Self::Udp]
            .iter()
            .copied()
            .find(|candidate| candidate.name() == s)
            .ok_or_else(|| anyhow!("Expected one of tcp, udp: {:?}", s))
    }
}
/// A parsed `--forward` command-line argument.
#[derive(Debug, Clone)]
struct ForwardSpec {
    direction: ForwardDirection,
    /// Listening address: the side of the spec that carries the `/tcp` or `/udp` suffix.
    bind: SocketAddrV4,
    /// Target address that forwarded connections are made to.
    // TODO: allow DNS lookup here
    forward: (Ipv4Addr, u16),
    protocol: ForwardProtocol,
}
impl FromStr for ForwardSpec {
    type Err = anyhow::Error;
    /// Parses `BIND_ADDR:BIND_PORT/PROTO->ADDR:PORT` (local-to-remote) or
    /// `ADDR:PORT<-BIND_ADDR:BIND_PORT/PROTO` (remote-to-local).
    ///
    /// The direction arrow is looked for at the first `-` or `<` byte of `s`.
    fn from_str(s: &str) -> Result<Self> {
        let raw = s.as_bytes();
        let at = raw
            .iter()
            .position(|&b| matches!(b, b'-' | b'<'))
            .with_context(|| format!("Expected a local-remote delimiter: {:?}", s))?;
        let arrow = &raw[at..];
        let direction = if arrow.starts_with(b"->") {
            ForwardDirection::LocalToRemote
        } else if arrow.starts_with(b"<-") {
            ForwardDirection::RemoteToLocal
        } else {
            bail!("Expected a local-remote delimiter: {:?}", s)
        };
        // Both arrow bytes are ASCII, so these slice indices are on char boundaries.
        let (left, right) = (&s[..at], &s[at + 2..]);
        let (bind_side, forward_side) = match direction {
            ForwardDirection::LocalToRemote => (left, right),
            ForwardDirection::RemoteToLocal => (right, left),
        };
        let (bind_addr, proto) = bind_side
            .split_once('/')
            .context("Expected a port-protocol delimiter")?;
        let bind = bind_addr.parse::<SocketAddrV4>()?;
        let protocol = proto.parse::<ForwardProtocol>()?;
        let (addr, port) = forward_side
            .split_once(':')
            .context("Expected an address-port delimiter")?;
        let addr = addr.parse::<Ipv4Addr>().with_context(|| addr.to_owned())?;
        let port = port.parse::<u16>().with_context(|| port.to_owned())?;
        Ok(Self {
            direction,
            bind,
            forward: (addr, port),
            protocol,
        })
    }
}
/// Header written as the first message on every stream opened with `open_stream`; the
/// server reads it in `handle_conn` to decide which handler serves the stream.
#[derive(Clone, Serialize, Deserialize)]
enum Stream<'a> {
    /// Command execution; `handle_stream_exec` does not implement it yet.
    Exec,
    /// Interactive shell. `command`, when present, is the program path and its argument
    /// list to run instead of the user's login shell.
    Shell {
        #[serde(borrow)]
        env_term: Option<Term<'a>>,
        command: Option<(CStrAsBytes<'a>, Vec<CStrAsBytes<'a>>)>,
    },
    /// Port-forward request for `addr:port` over `protocol`, served by
    /// `handle_stream_forward`.
    Forward {
        addr: Ipv4Addr,
        port: u16,
        protocol: ForwardProtocol,
    },
    // TODO: "backward"
}
impl<'a> fmt::Debug for Stream<'a> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match *self {
            Stream::Exec => f.write_str("Exec"),
            // The terminfo payload is deliberately omitted; only the command is shown.
            Stream::Shell { ref command, .. } => {
                let mut shell = f.debug_struct("Shell");
                shell.field("command", command);
                shell.finish()
            }
            Stream::Forward {
                addr,
                port,
                protocol,
            } => write!(f, "-> {}:{}/{}", addr, port, protocol),
        }
    }
}
/// Top-level command line: global flags, then a COMMAND, then the COMMAND's own arguments.
struct Args {
    /// Set by `-v` / `--verbose`; raises the default log level.
    verbose: bool,
    /// The command name, e.g. `client` or `server`.
    command: String,
    /// Everything after COMMAND, left unparsed for the command itself.
    rem: std::env::Args,
}
impl Args {
    /// Parses global flags from the process arguments until the first non-flag word,
    /// which becomes the command. Unknown flags before the command are an error.
    pub fn parse() -> Result<Self> {
        let mut rem = std::env::args();
        // Skip the program name.
        let _ = rem.next();
        let mut verbose = false;
        let command = loop {
            let word = rem.next().context("Expected a COMMAND")?;
            if word == "-v" || word == "--verbose" {
                verbose = true;
            } else if word.starts_with('-') {
                bail!("Unrecognized option: {}", word);
            } else {
                break word;
            }
        };
        Ok(Self {
            verbose,
            command,
            rem,
        })
    }
}
/// Program entry point: `quinoa [-v|--verbose] COMMAND [ARGS...]`.
///
/// Installs a stderr `tracing` subscriber whose default level is INFO with `--verbose`
/// and WARN otherwise (used when `RUST_LOG` is unset or empty), then runs the `server`
/// (only with the `server` feature) or `client` command. Unless `NO_CTRLC` is set,
/// Ctrl-C aborts the running command and `main` returns `Ok(())`.
#[tokio::main]
async fn main() -> Result<()> {
    let args = Args::parse()?;
    let default_level = if args.verbose {
        tracing_subscriber::filter::LevelFilter::INFO
    } else {
        tracing_subscriber::filter::LevelFilter::WARN
    };
    use tracing_subscriber::fmt::format::FmtSpan;
    // `from_env` treats an unset `RUST_LOG` as empty so the default directive applies;
    // `try_from_env` instead fails when the variable is unset, which made the program
    // unusable without `RUST_LOG`. Invalid directives are still reported as errors.
    let env_filter = tracing_subscriber::EnvFilter::builder()
        .with_default_directive(default_level.into())
        .from_env()?;
    tracing::subscriber::set_global_default(
        tracing_subscriber::FmtSubscriber::builder()
            .with_writer(std::io::stderr)
            .with_env_filter(env_filter)
            .with_span_events(FmtSpan::NEW | FmtSpan::CLOSE)
            .finish(),
    )
    .unwrap();
    let fut = async move {
        match args.command.as_str() {
            #[cfg(feature = "server")]
            "server" => run_server().await,
            "client" => run_client(args.rem).await,
            cmd => bail!("Unrecognized command: {}", cmd),
        }
    };
    if std::env::var("NO_CTRLC").is_ok() {
        fut.await
    } else {
        let ctrl_c = tokio::signal::ctrl_c();
        tokio::select! {
            _ = ctrl_c => {
                info!("Aborting");
                Ok(())
            }
            r = fut => r,
        }
    }
}
/// ALPN protocol identifier configured on both the client and the server TLS configs.
const ALPN_QUINOA: &str = "quinoa";
/// Server settings; leaked to `&'static` in `run_server` and shared by connection tasks.
pub struct ServerConfig {
/// Address the QUIC endpoint listens on (read from `BIND_ADDR`).
listen: SocketAddr,
}
/// Client settings; leaked to `&'static` in `run_client` and shared with the
/// certificate verifier.
pub struct ClientConfig {
/// Path of the known-hosts file, rewritten when the user accepts a new certificate.
known_hosts_file: String,
/// In-memory known hosts; locked while a server certificate is being verified.
known_hosts: parking_lot::Mutex<Vec<KnownHost<'static>>>,
/// Path of the SSH private key (presumably read by `auth::ssh_key` — confirm).
ssh_key_file: String,
}
/// Server entry point.
///
/// Listens on `BIND_ADDR` (required, parsed as a `SocketAddr`). The TLS identity is read
/// from `cert.der` / `key.der` in the working directory; if either file is missing, a new
/// self-signed certificate is generated and both files are (re)written, the private key
/// with mode 0600. Every incoming connection is served on its own task by `greet_conn`;
/// its failures are logged, not returned.
#[cfg(feature = "server")]
async fn run_server() -> Result<()> {
    let cfg = {
        let opt_listen: SocketAddr = std::env::var("BIND_ADDR")
            .context("BIND_ADDR not specified")?
            .parse()
            .context("BIND_ADDR is not a valid socket address")?;
        // Leaked on purpose: the config lives for the whole process and every
        // connection task borrows it as `&'static`.
        &*Box::leak(ServerConfig { listen: opt_listen }.into())
    };
    let subject_alt_names = vec![];
    let (cert, key) = if !std::path::Path::new("cert.der").exists()
        || !std::path::Path::new("key.der").exists()
    {
        let generated = rcgen::generate_simple_self_signed(subject_alt_names)?;
        let key = rustls::PrivateKey(generated.serialize_private_key_der());
        let cert = rustls::Certificate(generated.serialize_der()?);
        {
            use std::io::Write;
            use std::os::unix::fs::{OpenOptionsExt, PermissionsExt};
            // The private key must not be readable by other users. `mode` only applies
            // when the file is newly created, so tighten a pre-existing file as well
            // before any key material is written into it.
            let mut key_file = std::fs::OpenOptions::new()
                .write(true)
                .create(true)
                .truncate(true)
                .mode(0o600)
                .open("key.der")?;
            key_file.set_permissions(std::fs::Permissions::from_mode(0o600))?;
            key_file.write_all(&key.0)?;
        }
        std::fs::write("cert.der", &cert.0)?;
        (cert, key)
    } else {
        let cert = rustls::Certificate(std::fs::read("cert.der")?);
        let key = rustls::PrivateKey(std::fs::read("key.der")?);
        (cert, key)
    };
    let mut server_crypto = rustls::ServerConfig::builder()
        .with_safe_defaults()
        .with_no_client_auth()
        .with_single_cert(vec![cert], key)
        .context("configuring the server certificate")?;
    server_crypto.alpn_protocols = vec![ALPN_QUINOA.as_bytes().to_owned()];
    let mut transport = transport_config();
    // Clients only ever open bidirectional streams.
    transport.max_concurrent_uni_streams(0_u8.into());
    let mut server_config = quinn::ServerConfig::with_crypto(Arc::new(server_crypto));
    server_config.transport_config(transport.into());
    server_config.use_retry(true);
    let endpoint = quinn::Endpoint::server(server_config, cfg.listen)?;
    info!("listening on {}", endpoint.local_addr()?);
    while let Some(conn) = endpoint.accept().await {
        info!("connection incoming");
        tokio::spawn(async move {
            if let Err(e) = greet_conn(cfg, conn).await {
                error!("connection failed: {reason}", reason = e.to_string());
            }
        });
    }
    Ok(())
}
/// QUIC transport settings shared by client and server: 64 MiB per-stream receive,
/// connection receive and send windows; everything else is quinn's default.
fn transport_config() -> quinn::TransportConfig {
    const WINDOW: u32 = 64 * 1024 * 1024;
    let mut config = quinn::TransportConfig::default();
    config.stream_receive_window(WINDOW.into());
    config.send_window(u64::from(WINDOW));
    config.receive_window(WINDOW.into());
    config
}
/// Marker error returned (inside `Ok`) by `read_msg` when the stream finishes before a
/// message's length prefix could be read.
#[derive(Debug, Clone, Copy)]
struct FinishedEarly;
impl fmt::Display for FinishedEarly {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
// Reuses quinn's formatting of `ReadExactError::FinishedEarly`.
// NOTE(review): which fmt trait this method call resolves to (presumably `Display`)
// depends on what is in scope — confirm, or spell out `fmt::Display::fmt`.
ReadExactError::FinishedEarly.fmt(f)
}
}
impl std::error::Error for FinishedEarly {}
/// Reads a message `T`. `buf` will be cleared before reading.
async fn read_msg<'a, T: serde::Deserialize<'a>>(
recv: &mut RecvStream,
buf: &'a mut Vec<u8>,
) -> Result<Result<T, FinishedEarly>> {
let mut size = [0u8; 2];
match recv.read_exact(&mut size).await {
Ok(()) => {}
Err(ReadExactError::FinishedEarly) => return Ok(Err(FinishedEarly)),
Err(ReadExactError::ReadError(e)) => return Err(e.into()),
}
let size = u16::from_le_bytes(size);
buf.clear();
buf.reserve(size.into());
recv.take(size.into()).read_to_end(buf).await?;
Message::raw_to_value(buf).map(Ok)
}
/// An already-serialized (MessagePack, via `rmp_serde`) message body.
///
/// On the wire a message is a little-endian `u16` length followed by the body, so a body
/// longer than `u16::MAX` bytes cannot be written.
struct Message(Vec<u8>);
impl Message {
    /// Serializes `value` with `rmp_serde`.
    pub fn from_value<T: serde::Serialize>(value: &T) -> Result<Self> {
        let body = rmp_serde::to_vec(value)?;
        Ok(Message(body))
    }
    /// Wraps bytes that are already serialized.
    pub fn from_raw(data: Vec<u8>) -> Self {
        Message(data)
    }
    /// Writes the length prefix; fails if the body length does not fit in a `u16`.
    async fn write_len(&self, send: &mut SendStream) -> Result<()> {
        let prefix = u16::try_from(self.0.len())?.to_le_bytes();
        send.write_all(&prefix).await?;
        Ok(())
    }
    /// Writes the length prefix and body without consuming the message.
    pub async fn write_ref(&self, send: &mut SendStream) -> Result<()> {
        self.write_len(send).await?;
        send.write_all(&self.0).await?;
        Ok(())
    }
    /// Writes the length prefix, then hands the body to quinn as a single chunk.
    pub async fn write(self, send: &mut SendStream) -> Result<()> {
        self.write_len(send).await?;
        let Message(body) = self;
        send.write_chunk(body.into()).await?;
        Ok(())
    }
    /// Decodes `data` as MessagePack. On failure the error carries the message size and
    /// a best-effort `rmpv` dump of the raw value for diagnostics.
    fn raw_to_value<'de, T: serde::Deserialize<'de>>(data: &'de [u8]) -> Result<T> {
        let decoded = rmp_serde::from_slice::<T>(data);
        decoded
            .with_context(|| format!("reading a {} byte message", data.len()))
            .with_context(|| {
                let mut cursor = data;
                format!("{:?}", rmpv::decode::read_value_ref(&mut cursor))
            })
    }
    /// Decodes this message's body; the result may borrow from `self`.
    pub fn to_value<'de, T: serde::Deserialize<'de>>(&'de self) -> Result<T> {
        Message::raw_to_value(&self.0)
    }
}
/// Serializes `value` and writes it to `send` as one length-prefixed message.
async fn write_msg<T: serde::Serialize>(send: &mut SendStream, value: &T) -> Result<()> {
    let message = Message::from_value(value)?;
    message.write(send).await
}
/// Server-certificate verifier that pins certificates in the client's known-hosts list
/// and asks the user on stdin/stderr whenever a certificate is unknown or has changed.
struct InformedServerCertVerifier {
    cfg: &'static ClientConfig,
}
impl InformedServerCertVerifier {
    /// Returns the index and entry of the first known host whose name equals
    /// `subject_name`.
    fn find_known_host<'a, 'b>(
        known_hosts: &'b [KnownHost<'a>],
        subject_name: webpki::SubjectNameRef<'_>,
    ) -> Option<(usize, &'b KnownHost<'a>)> {
        known_hosts
            .iter()
            .enumerate()
            .find(|(_, h)| h.host.as_ref() == subject_name.as_ref())
    }
    /// Shows the new (or changed) certificate for `subject_name` and asks the user whether
    /// to trust it.
    ///
    /// On "yes": if `known_as` is `Some`, existing entries for the host are removed; the
    /// new certificate is appended to `known_hosts` and the list is written to the
    /// known-hosts file (a write failure is only logged). On "no", or when stdin reaches
    /// EOF before an answer is given, returns `CertificateError::NotValidForName`.
    fn inform(
        &self,
        known_hosts: &mut Vec<KnownHost<'static>>,
        end_entity: &rustls::Certificate,
        subject_name: webpki::SubjectNameRef<'_>,
        known_as: Option<usize>,
    ) -> Result<(), rustls::CertificateError> {
        use rustls::CertificateError;
        let hash = Hash::new(&end_entity.0);
        eprintln!(
            "The authenticity of host {} can't be established.",
            match subject_name {
                SubjectNameRef::DnsName(name) => name.into(),
                SubjectNameRef::IpAddress(ip) => std::str::from_utf8(match ip {
                    webpki::IpAddrRef::V4(b, _) => b,
                    webpki::IpAddrRef::V6(b, _) => b,
                })
                .map_err(|e| CertificateError::Other(Arc::new(e)))?,
            }
        );
        eprintln!("Certificate hash is {}", hash);
        if let Some(known_as) = known_as.and_then(|i| known_hosts.get(i)) {
            eprintln!("Previously known as {}", known_as.hash);
        }
        eprintln!("This key is not known by any other names (TODO: check)");
        loop {
            eprintln!("Are you sure you want to continue connecting (yes/no)?");
            let mut answer = String::new();
            let read = std::io::stdin()
                .read_line(&mut answer)
                .map_err(|e| CertificateError::Other(Arc::new(e)))?;
            if read == 0 {
                // EOF: no answer can ever arrive, so re-prompting would loop forever.
                eprintln!("No answer on stdin. Aborting...");
                return Err(CertificateError::NotValidForName);
            }
            // Accept both "\n" and "\r\n" line endings.
            match answer.trim_end_matches(&['\r', '\n'][..]) {
                "yes" => break,
                "no" => {
                    eprintln!("Understood. Aborting...");
                    return Err(CertificateError::NotValidForName);
                }
                _ => {}
            }
        }
        // Copy the name into leaked storage so it can live in the `'static` host list.
        let subject_name = match subject_name {
            webpki::SubjectNameRef::DnsName(dns_name) => {
                let dns_name = Box::leak(dns_name.as_ref().to_owned().into_boxed_slice());
                webpki::SubjectNameRef::DnsName(
                    webpki::DnsNameRef::try_from_ascii(&*dns_name).unwrap(),
                )
            }
            webpki::SubjectNameRef::IpAddress(ip_addr) => {
                let ip_addr: &'static webpki::IpAddr = Box::leak(Box::new(ip_addr.to_owned()));
                let ip_addr: &'static str = ip_addr.as_ref();
                webpki::SubjectNameRef::IpAddress(
                    webpki::IpAddrRef::try_from_ascii_str(ip_addr).unwrap(),
                )
            }
        };
        let known_host = KnownHost {
            host: subject_name,
            hash,
        };
        if known_as.is_some() {
            while let Some((i, _)) = Self::find_known_host(&known_hosts, subject_name) {
                known_hosts.remove(i);
            }
        }
        known_hosts.push(known_host);
        if let Err(e) = write_known_hosts(&self.cfg.known_hosts_file, &known_hosts) {
            error!(
                "Couldn't persist the new known-host. Continuing anyway.\n{}",
                e
            );
        }
        Ok(())
    }
}
/// Pin-based verification: the server certificate is accepted if it matches the
/// known-hosts entry for the server name, or if the user accepts it interactively.
impl ServerCertVerifier for InformedServerCertVerifier {
/// Checks `end_entity` against the known-hosts entry for `server_name`.
///
/// This function only checks that the certificate parses as an end-entity certificate.
/// It does not check the chain (`_intermediates`), time validity (`_now`) or name
/// validity (see the commented-out call below); trust comes from the pinned hash or the
/// user's answer in `inform`.
///
/// The known-hosts mutex is held for the whole check, including any interactive prompt.
fn verify_server_cert(
&self,
end_entity: &rustls::Certificate,
_intermediates: &[rustls::Certificate],
server_name: &rustls::ServerName,
_scts: &mut dyn Iterator<Item = &[u8]>,
_ocsp_response: &[u8],
_now: std::time::SystemTime,
) -> std::result::Result<rustls::client::ServerCertVerified, rustls::Error> {
use rustls::client::ServerCertVerified;
use rustls::CertificateError;
use rustls::ServerName;
info!("starting verification");
// Maps webpki parse/verification errors onto the closest rustls certificate error.
fn pki_error(error: webpki::Error) -> rustls::Error {
use webpki::Error::*;
match error {
BadDer | BadDerTime => CertificateError::BadEncoding.into(),
CertNotValidYet => CertificateError::NotValidYet.into(),
CertExpired | InvalidCertValidity => CertificateError::Expired.into(),
UnknownIssuer => CertificateError::UnknownIssuer.into(),
CertNotValidForName => CertificateError::NotValidForName.into(),
InvalidSignatureForPublicKey
| UnsupportedSignatureAlgorithm
| UnsupportedSignatureAlgorithmForPublicKey => {
CertificateError::BadSignature.into()
}
_ => CertificateError::Other(Arc::new(error)).into(),
}
}
// Parsed only to reject malformed certificates; the parsed value is not used further.
let _cert = webpki::EndEntityCert::try_from(end_entity.0.as_ref()).map_err(pki_error)?;
// Owns the converted IP address so `subject_name` can borrow it below.
let ip_addr_slot;
let subject_name = match server_name {
ServerName::DnsName(dns_name) => webpki::SubjectNameRef::DnsName(
webpki::DnsNameRef::try_from_ascii_str(dns_name.as_ref())
.map_err(|_| CertificateError::NotValidForName)?,
),
ServerName::IpAddress(ip_addr) => {
ip_addr_slot = webpki::IpAddr::from(*ip_addr);
webpki::SubjectNameRef::IpAddress(webpki::IpAddrRef::from(&ip_addr_slot))
}
_ => return Err(CertificateError::NotValidForName.into()),
};
// TODO: is expiry checked for us?
/*cert.verify_is_valid_for_subject_name(subject_name)
.map_err(pki_error)?;*/
let mut known_hosts = self.cfg.known_hosts.lock();
if let Some((h_index, h)) = Self::find_known_host(&known_hosts, subject_name) {
// Known host: a hash mismatch falls back to asking the user.
if let Err(e) = h.hash.verify(&end_entity.0) {
debug!("verification failed: {}", e);
self.inform(&mut known_hosts, end_entity, subject_name, Some(h_index))?;
} else {
eprintln!("Host authenticity verified");
}
} else {
// Unknown host: ask the user.
self.inform(&mut known_hosts, end_entity, subject_name, None)?;
}
Ok(ServerCertVerified::assertion())
}
}
/// One known-hosts entry: a host name or IP address and the certificate pinned for it.
/// Stored in the known-hosts file as `HOST HASH`.
#[derive(Debug, Clone)]
struct KnownHost<'a> {
host: SubjectNameRef<'a>,
hash: Hash,
}
/// A pinned server certificate, written as `ALGORITHM:ENCODED` in the known-hosts file.
#[derive(Debug, Clone, PartialEq, Eq)]
enum Hash {
/// SHA-512 digest of the certificate bytes.
SHA2_512([u8; 64]),
/// No hashing. Stores the certificate directly.
RAW(Vec<u8>),
}
impl Hash {
    /// The hash used for newly pinned certificates. Currently [`Hash::raw`], so the
    /// whole certificate is stored rather than a digest of it.
    pub fn new(data: &[u8]) -> Self {
        Hash::raw(data)
    }
    /// SHA-512 digest of `data`.
    pub fn sha2_512(data: &[u8]) -> Self {
        use sha2::digest::FixedOutput;
        use sha2::Digest;
        let mut hasher = sha2::Sha512::new();
        hasher.update(data);
        let digest = hasher.finalize_fixed();
        Hash::SHA2_512(digest.into())
    }
    /// Stores `data` verbatim.
    pub fn raw(data: &[u8]) -> Self {
        Hash::RAW(data.to_vec())
    }
    /// Algorithm tag used in the textual `ALGORITHM:ENCODED` form.
    pub fn name(&self) -> &'static str {
        match *self {
            Self::SHA2_512(_) => "SHA2_512",
            Self::RAW(_) => "RAW",
        }
    }
    /// The stored digest or raw certificate bytes.
    pub fn bytes(&self) -> &[u8] {
        match self {
            Self::SHA2_512(digest) => &digest[..],
            Self::RAW(cert) => cert.as_slice(),
        }
    }
    /// Checks that `data` (a certificate) matches this pinned hash, hashing `data` with
    /// the same algorithm when needed.
    pub fn verify(&self, data: &[u8]) -> Result<()> {
        let matches = match self {
            Self::SHA2_512(_) => Self::sha2_512(data) == *self,
            Self::RAW(expected) => expected.as_slice() == data,
        };
        if !matches {
            return Err(anyhow!(
                "Certificate hash does not match hash in known_hosts file"
            ));
        }
        Ok(())
    }
}
const B64: base64::engine::general_purpose::GeneralPurpose =
base64::engine::general_purpose::STANDARD_NO_PAD;
impl fmt::Display for Hash {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
use fast_hex::Encode;
match self {
Hash::SHA2_512(b) => {
write!(
f,
"{}:{}",
self.name(),
fast_hex::Encoder::<false>::display_sized(b)
)
}
Hash::RAW(b) => {
write!(f, "{}:{}", self.name(), B64.encode(b))
}
}
}
}
impl FromStr for Hash {
type Err = anyhow::Error;
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
let (algo, hash) = s
.split_once(':')
.context("Expected an algorithm and raw hash")?;
let hash = match algo {
"SHA2_512" => Self::SHA2_512(
fast_hex::Decoder::decode_sized(hash.as_bytes().try_into()?)
.ok_or_else(|| anyhow!("Invalid hexadecimal"))?,
),
"RAW" => Self::RAW(B64.decode(hash)?),
_ => bail!("Unrecognized algorithm: {}", algo),
};
Ok(hash)
}
}
impl KnownHost<'static> {
fn from_str(s: &'static str) -> Result<Self> {
let (host, s) = s.split_once(' ').context("Expected a host and hash")?;
let host =
webpki::SubjectNameRef::try_from_ascii_str(host).map_err(|e| anyhow!("{:?}", e))?;
let hash = s
.parse::<Hash>()
.context("While parsing a known-host hash")?;
Ok(Self { host, hash })
}
}
/*fn append_known_host(file: &str, host: KnownHost<'_>) -> Result<()> {
use std::io::Write;
let mut file = std::fs::OpenOptions::new()
.read(false)
.append(true)
.create(true)
.open(file)?;
file.write_all(b"\n")?;
file.write_all(host.host.as_ref())?;
file.write_all(b" ")?;
file.write_all(host.hash.to_string().as_bytes())?;
Ok(())
}*/
/// Replaces the contents of the known-hosts file `file` with `hosts`, one `HOST HASH`
/// line per entry (creating the file if needed).
///
/// The whole file is assembled in memory and written with a single `write_all`,
/// instead of four small unbuffered writes per host.
fn write_known_hosts(file: &str, hosts: &[KnownHost<'_>]) -> Result<()> {
    use std::io::Write;
    let mut contents = Vec::new();
    for host in hosts {
        contents.extend_from_slice(host.host.as_ref());
        contents.push(b' ');
        contents.extend_from_slice(host.hash.to_string().as_bytes());
        contents.push(b'\n');
    }
    let mut file = std::fs::OpenOptions::new()
        .read(false)
        .write(true)
        .truncate(true)
        .create(true)
        .open(file)?;
    file.write_all(&contents)?;
    Ok(())
}
async fn run_client(mut args: std::env::Args) -> Result<()> {
info!("running client");
let home_dir = std::env::var("HOME")?;
let cfg_dir = format!("{}/.config/quinoa", home_dir);
tokio::fs::create_dir_all(&cfg_dir).await?;
let known_hosts_file = format!("{}/known_hosts", cfg_dir);
let known_hosts = if std::path::Path::new(&known_hosts_file).exists() {
let s = Box::leak(
tokio::fs::read_to_string(&known_hosts_file)
.await?
.into_boxed_str(),
);
s.lines()
.filter(|s| !s.is_empty() && !s.starts_with('#'))
.map(KnownHost::from_str)
.collect::<Result<Vec<_>>>()?
} else {
Vec::new()
};
let ssh_key_file = format!("{}/.ssh/id_ed25519", home_dir);
//let ssh_key_file = format!("{}/.ssh/id_rsa", home_dir);
let cfg = &*Box::leak(
ClientConfig {
known_hosts_file,
known_hosts: known_hosts.into(),
ssh_key_file,
}
.into(),
);
let mut conn_str = None;
let mut forwards = Vec::new();
let mut use_key = true;
while let Some(arg) = args.next() {
if let Some(arg) = arg.strip_prefix('-') {
match arg {
"-no-key" => {
use_key = false;
}
"-forward" => {
let v = args.next().context("Expected a value of the form LOCAL_ADDR:LOCAL_PORT[->|<-]REMOTE_ADDR:REMOTE_PORT (and /[tcp|udp] on bind side)")?;
forwards.push(v.parse::<ForwardSpec>()?);
}
_ => bail!("Unrecognized option: -{}", arg),
}
} else if conn_str.is_none() {
conn_str = Some(arg);
} else {
bail!("Unexpected argument: {:?}", arg);
}
}
let (username, host) = conn_str
.as_ref()
.and_then(|s| s.split_once('@'))
.context("Expected an argument of the form USERNAME@HOST")?;
let (host_name, _port) = host.split_once(':').unwrap_or((host, "8022"));
let mut client_crypto = rustls::ClientConfig::builder()
.with_safe_defaults()
.with_custom_certificate_verifier(Arc::new(InformedServerCertVerifier { cfg }) as _)
.with_no_client_auth();
client_crypto.alpn_protocols = vec![ALPN_QUINOA.as_bytes().to_owned()];
let mut transport = transport_config();
transport.keep_alive_interval(Some(std::time::Duration::from_secs(5)));
let mut client_config = quinn::ClientConfig::new(Arc::new(client_crypto));
client_config.transport_config(transport.into());
//let mut endpoint = quinn::Endpoint::client("[::]:0".parse().unwrap())?;
let mut endpoint = quinn::Endpoint::client("0.0.0.0:0".parse().unwrap())?;
endpoint.set_default_client_config(client_config);
info!("connecting");
let conn = endpoint.connect(host.parse()?, host_name)?.await?;
// authenticating client
{
let (mut send, mut recv) = conn.open_bi().await?;
if use_key {
auth::ssh_key::client_authenticate(cfg, &mut send, &mut recv, username).await?;
} else {
auth::password::client_authenticate(&conn, &mut send, &mut recv, username).await?;
}
}
// authenticated client
for forward in forwards {
match forward.direction {
ForwardDirection::LocalToRemote => {
let conn = conn.clone();
tokio::task::spawn(async move {
if let Err(e) = do_forward_to(&conn, forward).await {
error!("Error in local-to-remote forwarder: {}", e);
}
});
}
ForwardDirection::RemoteToLocal => todo!(),
}
}
let _reset = {
use nix::sys::termios::*;
struct Reset(Termios);
impl Drop for Reset {
fn drop(&mut self) {
//_ = crossterm::terminal::disable_raw_mode();
_ = tcsetattr(libc::STDIN_FILENO, SetArg::TCSAFLUSH, &self.0);
println!("termios reset!");
}
}
let mut termios = tcgetattr(libc::STDIN_FILENO)?;
let reset = Reset(termios.clone());
termios.local_flags.remove(LocalFlags::ECHO);
termios.local_flags.remove(LocalFlags::ICANON);
termios.local_flags.remove(LocalFlags::ISIG);
termios.local_flags.remove(LocalFlags::IEXTEN);
termios.input_flags.remove(InputFlags::IXON);
termios.input_flags.remove(InputFlags::ICRNL);
termios.output_flags.remove(OutputFlags::OPOST);
tcsetattr(libc::STDIN_FILENO, SetArg::TCSAFLUSH, &termios)?;
reset
};
do_shell(&conn).await
}
/// Opens a new bidirectional stream on `conn` and sends `stream` as its header message,
/// telling the server what to attach to it.
async fn open_stream(
    conn: &quinn::Connection,
    stream: &Stream<'_>,
) -> Result<(SendStream, RecvStream)> {
    let (mut tx, rx) = conn.open_bi().await?;
    write_msg(&mut tx, stream).await?;
    info!("connected");
    Ok((tx, rx))
}
// Both halves of a QUIC bidirectional stream in one value. The `AsyncWrite` and
// `AsyncRead` impls below forward to `send` and `recv` respectively, so the pair can be
// passed to tokio helpers such as `tokio::io::copy_bidirectional`.
// (Plain `//` comments are used because the macro's field syntax only expects `#[pin]`.)
pin_project_lite::pin_project! {
struct SendRecvStream {
#[pin] send: SendStream,
#[pin] recv: RecvStream,
}
}
/// Every write operation goes to the QUIC send half.
impl tokio::io::AsyncWrite for SendRecvStream {
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &[u8],
    ) -> Poll<std::result::Result<usize, std::io::Error>> {
        let this = self.project();
        this.send.poll_write(cx, buf)
    }
    fn poll_flush(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<std::result::Result<(), std::io::Error>> {
        let this = self.project();
        this.send.poll_flush(cx)
    }
    fn poll_shutdown(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<std::result::Result<(), std::io::Error>> {
        let this = self.project();
        this.send.poll_shutdown(cx)
    }
    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        bufs: &[std::io::IoSlice<'_>],
    ) -> Poll<std::result::Result<usize, std::io::Error>> {
        let this = self.project();
        this.send.poll_write_vectored(cx, bufs)
    }
    fn is_write_vectored(&self) -> bool {
        let send = &self.send;
        send.is_write_vectored()
    }
}
/// Every read operation comes from the QUIC receive half.
impl tokio::io::AsyncRead for SendRecvStream {
    fn poll_read(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        let this = self.project();
        this.recv.poll_read(cx, buf)
    }
}
async fn do_forward_to(conn: &quinn::Connection, spec: ForwardSpec) -> Result<()> {
assert_eq!(spec.direction, ForwardDirection::LocalToRemote);
match spec.protocol {
ForwardProtocol::Tcp => {
let socket = TcpSocket::new_v4()?;
socket.bind(spec.bind.into())?;
let list = socket.listen(1024)?;
loop {
let (mut local_stream, _peer) = list.accept().await?;
let (send, recv) = open_stream(
conn,
&Stream::Forward {
addr: spec.forward.0,
port: spec.forward.1,
protocol: spec.protocol,
},
)
.await?;
let mut remote_stream = SendRecvStream { send, recv };
tokio::task::spawn(async move {
tokio::io::copy_bidirectional(&mut local_stream, &mut remote_stream).await?;
Result::<_>::Ok(())
});
}
}
ForwardProtocol::Udp => {
// this will require more thought into how to keep track of the peer.
todo!()
/*let socket = UdpSocket::bind(spec.bind).await?;
let mut buf = Vec::with_capacity(4096);
let (send, recv) = open_stream(conn, &Stream::Forward {
addr: spec.forward.0,
port: spec.forward.1,
protocol: spec.protocol,
}).await?;
loop {
socket.recv_buf(&mut buf).await?;
let mut remote_stream = SendRecvStream(send, recv);
tokio::task::spawn(async move {
tokio::io::copy_bidirectional(&mut socket, &mut remote_stream).await?;
Result::<_>::Ok(())
});
}*/
}
}
}
/// Runs an interactive shell on the server over a new `Stream::Shell` stream.
///
/// When both `TERM` and `TERMINFO` are set locally, the terminfo entry for `TERM` is read
/// from `TERMINFO` and sent with the request. After that, local stdin is forwarded to the
/// stream and stream data is copied to local stdout. This continues until:
/// - the server stops the send half: exit code 0 → `Ok`, any other code → error; or
/// - the connection closes: always an error.
async fn do_shell(conn: &quinn::Connection) -> Result<()> {
    let env_term = std::env::var("TERM");
    let env_terminfo = std::env::var("TERMINFO");
    let env_term = if let (Ok(name), Ok(path)) = (env_term, env_terminfo) {
        let info = terminfo::read_terminfo(&path, &name).await?;
        Some((name, info))
    } else {
        None
    };
    let stream = &Stream::Shell {
        env_term: env_term.as_ref().map(|(name, info)| Term {
            name,
            info: info.as_slice().into(),
        }),
        command: None,
    };
    let (mut send, mut recv) = open_stream(conn, stream).await?;
    // SAFETY: the process's stdin/stdout descriptors stay open for the whole run, and
    // `ManuallyDrop` keeps these `File`s from ever closing them.
    let mut stdin = unsafe { ManuallyDrop::new(tokio::fs::File::from_raw_fd(libc::STDIN_FILENO)) };
    let mut stdout =
        unsafe { ManuallyDrop::new(tokio::fs::File::from_raw_fd(libc::STDOUT_FILENO)) };
    let mut stdin_buf = Vec::with_capacity(4096);
    let mut stdout_bufs = vec![bytes::Bytes::new(); 128];
    let mut stdin_eof = false;
    // Set once the server finishes its send half. Further reads would only report
    // end-of-stream again, so the branch is disabled instead of spinning the loop.
    let mut stdout_eof = false;
    loop {
        tokio::select! {
            r = stdin.read_buf(&mut stdin_buf), if !stdin_eof => {
                if r? == 0 {
                    stdin_eof = true;
                }
                send.write_all(&stdin_buf).await?;
                stdin_buf.clear();
            }
            r = recv.read_chunks(&mut stdout_bufs), if !stdout_eof => {
                match r? {
                    Some(n) => {
                        for chunk in &stdout_bufs[..n] {
                            stdout.write_all(chunk).await?;
                        }
                        // `tokio::fs::File` completes writes in the background. Flush so
                        // output reaches the terminal before we wait again or return and
                        // the terminal mode is restored.
                        stdout.flush().await?;
                    }
                    None => stdout_eof = true,
                }
            }
            r = send.stopped() => {
                info!("Remote disconnected");
                let code = r?.into_inner();
                if code == 0 {
                    return Ok(());
                } else {
                    return Err(anyhow!("Error code {}", code));
                }
            }
            e = conn.closed() => {
                info!("Remote disconnected: {}", e);
                return Err(anyhow!("Remote connection closed"));
            }
        }
    }
}
/// Completes the handshake of an incoming connection, authenticates it, then serves it
/// with `handle_conn`. Both steps run inside a `connection` tracing span.
///
/// Authentication or handler failures are logged, and the connection is closed with
/// error code 1; they are not returned. Only a failed handshake (`conn.await`) is
/// returned as an error.
#[cfg(feature = "server")]
async fn greet_conn(cfg: &'static ServerConfig, conn: quinn::Connecting) -> Result<()> {
info!("greeting connection");
let conn = conn.await?;
// The `unwrap`s assume rustls handshake data is present once the handshake completed.
let span = info_span!(
"connection",
remote = %conn.remote_address(),
protocol = %conn
.handshake_data()
.unwrap()
.downcast::<quinn::crypto::rustls::HandshakeData>().unwrap()
.protocol
.map_or_else(|| "<none>".into(), |x| String::from_utf8_lossy(&x).into_owned()),
// Filled in below once authentication succeeds.
username = tracing::field::Empty,
);
let (user_info, env) = match authenticate_conn(cfg, &conn).instrument(span.clone()).await {
Ok(t) => t,
Err(e) => {
error!("authentication failed: {}", e.to_string());
conn.close(1u8.into(), b"authentication error");
return Ok(());
}
};
span.record("username", &user_info.user.name);
if let Err(e) = handle_conn(cfg, &conn, user_info, env).instrument(span).await {
error!("handler failed: {}", e.to_string());
conn.close(1u8.into(), b"handler error");
}
Ok(())
}
/// Runs the authentication exchange on the first bidirectional stream of `conn`.
///
/// Reads an `auth::Hello` and looks up the named user. It then runs password or SSH-key
/// authentication as requested, answers `auth::Question::LoggedIn`, and closes both
/// halves of the stream. Returns the user's info and the (currently always empty) extra
/// environment for their sessions.
#[cfg(feature = "server")]
async fn authenticate_conn(
    _cfg: &'static ServerConfig,
    conn: &quinn::Connection,
) -> Result<(UserInfo, Vec<(CString, CString)>)> {
    info!("authenticating connection");
    let (mut send, mut recv) = conn.accept_bi().await?;
    let mut hello_buf = Vec::new();
    let hello = read_msg::<auth::Hello>(&mut recv, &mut hello_buf).await??;
    let user_info = user_info::get_user_info(&hello.username).await?;
    match hello.auth_method {
        auth::Method::Password => {
            let username = hello.username.to_owned();
            auth::password::server_authenticate(&mut send, &mut recv, username).await?;
        }
        auth::Method::SshKey { public_key } => {
            let uid = user_info.user.id;
            auth::ssh_key::server_authenticate(&mut send, &mut recv, uid, public_key).await?;
        }
    }
    info!("logged in");
    write_msg(&mut send, &auth::Question::LoggedIn).await?;
    send.finish().await?;
    recv.stop(0u8.into())?;
    let env = Vec::new();
    Ok((user_info, env))
}
/// Serves an authenticated connection.
///
/// Accepts bidirectional streams until the peer closes the connection. For each stream,
/// on its own task, it reads a `Stream` header and dispatches to the matching handler;
/// errors there are logged, not propagated.
///
/// Returns `Ok(())` when the peer closes with an application close, and the connection
/// error otherwise.
#[cfg(feature = "server")]
async fn handle_conn(
    cfg: &'static ServerConfig,
    conn: &quinn::Connection,
    user_info: UserInfo,
    env: Vec<(CString, CString)>,
) -> Result<()> {
    info!("established");
    loop {
        let (send, mut recv) = match conn.accept_bi().await {
            Ok(halves) => halves,
            Err(quinn::ConnectionError::ApplicationClosed { .. }) => {
                info!("connection closed");
                return Ok(());
            }
            Err(other) => return Err(other.into()),
        };
        let stream_id = send.id();
        let span = info_span!(
            "stream",
            dir = %match stream_id.dir() {
                quinn_proto::Dir::Bi => "bi",
                quinn_proto::Dir::Uni => "uni",
            },
            id = stream_id.index()
        );
        let task_user_info = user_info.clone();
        let task_env = env.clone();
        let task = async move {
            let mut header_buf = Vec::new();
            let header = read_msg::<Stream>(&mut recv, &mut header_buf)
                .await
                .and_then(|inner| inner.map_err(anyhow::Error::from));
            let stream = match header {
                Ok(stream) => stream,
                Err(e) => {
                    error!("Error in stream setup: {}", e);
                    return;
                }
            };
            let outcome = match stream {
                Stream::Exec => {
                    let span = info_span!("stream_exec");
                    handle_stream_exec(cfg, send, recv, &task_user_info)
                        .instrument(span)
                        .await
                }
                Stream::Shell { env_term, command } => {
                    let span = info_span!("stream_shell", ?command);
                    handle_stream_shell(
                        cfg,
                        send,
                        recv,
                        &task_user_info,
                        task_env,
                        env_term,
                        command,
                    )
                    .instrument(span)
                    .await
                }
                Stream::Forward {
                    addr,
                    port,
                    protocol,
                } => {
                    let span = info_span!("stream_forward", %addr, %port, %protocol);
                    handle_stream_forward(cfg, send, recv, addr, port, protocol)
                        .instrument(span)
                        .await
                }
            };
            if let Err(e) = outcome {
                error!("Error in stream handler: {}", e);
            }
        };
        tokio::task::spawn(task.instrument(span));
    }
}
async fn handle_stream_exec(
_cfg: &ServerConfig,
_send: SendStream,
_recv: RecvStream,
_user_info: &UserInfo,
) -> Result<()> {
todo!()
}
async fn handle_stream_shell(
_cfg: &ServerConfig,
mut send: SendStream,
mut recv: RecvStream,
user_info: &UserInfo,
env: Vec<(CString, CString)>,
env_term: Option<Term<'_>>,
command: Option<(CStrAsBytes<'_>, Vec<CStrAsBytes<'_>>)>,
) -> Result<()> {
let mut user_passwd_buf = Vec::new();
let user_passwd = passwd::Passwd::from_uid(user_info.user.id, &mut user_passwd_buf)?;
let user_home = user_passwd.as_ref().map(|p| p.dir);
let shell = command
.as_ref()
.map(|(command, _)| command.as_ref())
.unwrap_or_else(|| {
user_passwd
.as_ref()
.map(|p| p.shell)
.unwrap_or(CStr::from_bytes_with_nul(b"/bin/sh\0").unwrap())
});
let shell_name = CStr::from_bytes_with_nul(
shell
.to_bytes_with_nul()
.iter()
.rposition(|b| *b == b'/')
.map(|i| &shell.to_bytes_with_nul()[i + 1..])
.unwrap_or(shell.to_bytes_with_nul()),
)?;
let args = command
.as_ref()
.map(|(_, args)| args.iter().map(|s| s.as_ref()).collect::<Vec<_>>())
.unwrap_or_else(|| vec![&shell_name]);
let opt_shell = shell;
let c_user_name = CString::new(user_info.user.name.as_str())?;
let user_home = user_home.as_ref().and_then(|s| s.to_str().ok());
const TERMINFO_PATH: &str = "/.local/share/quinoa/terminfo";
let terminfo_path = if let Some(user_home) = user_home {
let mut terminfo_path = String::with_capacity(user_home.len() + TERMINFO_PATH.len());
terminfo_path.push_str(user_home);
for comp in TERMINFO_PATH.split_inclusive('/') {
terminfo_path.push_str(comp);
if comp != "/" {
io_util::ignore_already_exists(io_util::create_dir_owned(
&terminfo_path,
user_info.user.id,
user_info.group.id,
))?;
}
}
Some(terminfo_path)
} else {
None
};
if let (Some(env_term), Some(terminfo_path)) = (&env_term, &terminfo_path) {
terminfo::install_terminfo(
terminfo_path,
&env_term.name,
&env_term.info,
(user_info.user.id, user_info.group.id),
)
.await?;
}
let c_env_term = env_term.map(|t| CString::new(t.name)).transpose()?;
let c_env_terminfo = terminfo_path.map(CString::new).transpose()?;
let pre_exec = move || {
fn unsetenv(name: &CStr) -> Result<(), nix::Error> {
if unsafe { libc::unsetenv(name.as_ptr()) } != 0 {
Err(nix::Error::last().into())
} else {
Ok(())
}
}
fn setenv(name: &CStr, value: &CStr, overwrite: bool) -> Result<(), nix::Error> {
if unsafe { libc::setenv(name.as_ptr(), value.as_ptr(), overwrite as i32) } != 0 {
Err(nix::Error::last().into())
} else {
Ok(())
}
}
fn getenv(name: &CStr) -> Option<NonNull<CStr>> {
unsafe {
let ptr = libc::getenv(name.as_ptr());
NonNull::new(CStr::from_ptr(ptr) as *const _ as *mut _)
}
}
extern "C" {
#[cfg(any(target_os = "macos", target_os = "ios"))]
pub fn _NSGetEnviron() -> *mut *const *const std::ffi::c_char;
#[cfg(not(any(target_os = "macos", target_os = "ios")))]
static mut environ: *const *const std::ffi::c_char;
}
let mut keep = Vec::new();
const PRESERVE_ENV: &[&[u8]] = &[b"PATH\0", b"LANG\0"];
unsafe {
#[cfg(any(target_os = "macos", target_os = "ios"))]
let mut env = *_NSGetEnviron();
#[cfg(not(any(target_os = "macos", target_os = "ios")))]
let mut env = environ;
if !env.is_null() {
while !(*env).is_null() {
let key_value = CStr::from_ptr(*env).to_bytes_with_nul();
env = env.add(1);
let Some(i) = key_value.iter().position(|b| *b == b'=') else {
continue
};
let key = &key_value[..i];
let value = &key_value[i + 1..];
if PRESERVE_ENV.iter().any(|s| &s[..s.len() - 1] == key) {
keep.push((key, value));
}
}
}
}
unsafe { nix::env::clearenv() }?;
for (k, v) in keep
.into_iter()
.chain(env.iter().map(|(k, v)| (k.as_bytes(), v.as_bytes())))
{
setenv(&CString::new(k)?, CStr::from_bytes_with_nul(v)?, true)?;
}
setenv(
CStr::from_bytes_with_nul(b"USER\0").unwrap(),
&c_user_name,
true,
)?;
if let Some(c_env_terminfo) = &c_env_terminfo {
// otherwise system terminfo will be used
setenv(
CStr::from_bytes_with_nul(b"TERMINFO\0").unwrap(),
&c_env_terminfo,
true,
)?;
}
if let Some(c_env_term) = &c_env_term {
setenv(
CStr::from_bytes_with_nul(b"TERM\0").unwrap(),
c_env_term,
true,
)?;
}
#[cfg(not(any(
target_os = "macos",
target_os = "ios",
target_os = "redox",
target_os = "haiku"
)))]
nix::unistd::setgroups(&user_info.groups.iter().map(|g| g.id).collect::<Vec<_>>())
.context("setting supplementary groups")?;
nix::unistd::setgid(user_info.group.id).context("setting primary group")?;
nix::unistd::setuid(user_info.user.id).context("setting user")?;
if nix::unistd::seteuid(Uid::from_raw(0)).is_ok() {
return Err(anyhow!("We got the privileges back via seteuid. Very bad!"));
}
if nix::unistd::setuid(Uid::from_raw(0)).is_ok() {
return Err(anyhow!("We got the privileges back via setuid. Very bad!"));
}
std::env::set_current_dir("/")?;
Ok(())
};
let pre_exec = unsafe { pty::PreExec::new(pre_exec) };
let mut sh = pty::create_pty(opt_shell, &args, pre_exec)?;
sh.set_nodelay()?;
let mut pty = sh.pty;
//let mut pty = tokio::io::unix::AsyncFd::with_interest(pty, tokio::io::Interest::READABLE)?;
info!("created pty");
//let mut stdin_buf = Vec::with_capacity(4096);
let mut pty_buf = Vec::with_capacity(4096);
//let mut pty_buf = [0u8];
let (_waker_send, mut waker_recv) = flume::bounded::<()>(1);
/*let fd = pty.as_raw_fd();
std::thread::spawn(move || {
let mut set = [nix::poll::PollFd::new(fd.as_raw_fd(), nix::poll::PollFlags::POLLIN)];
loop {
if let Ok(n) = nix::poll::poll(&mut set, -1) {
if n != 0 {
if waker_send.send(()).is_err() {
break;
}
}
}
}
});*/
loop {
/*if let Some(code) = sh.proc.try_wait()? {
send.finish().await?;
if code != 0 {
info!("Child exit: {}", code);
recv.stop(1u8.into())?;
return Ok(());
} else {
info!("Child exit");
recv.stop(0u8.into())?;
return Ok(());
}
}*/
//let mut redraw = tokio::time::interval(std::time::Duration::from_millis(50));
struct Wait<'a> {
proc: &'a pty::Proc,
}
impl<'a> Wait<'a> {
pub fn new(proc: &'a pty::Proc) -> Self {
Self { proc }
}
}
impl<'a> Future for Wait<'a> {
type Output = std::io::Result<i32>;
fn poll(
self: std::pin::Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
) -> Poll<Self::Output> {
match self.proc.try_wait() {
Ok(Some(code)) => Poll::Ready(Ok(code)),
Ok(None) => Poll::Pending,
Err(e) => Poll::Ready(Err(e)),
}
}
}
/*pin_project_lite::pin_project! {
struct SelectRead<T: AsRawFd> {
//#[allow(dead_code)]
fd: T,
#[pin]
list: triggered::Listener,
}
}
impl<T: AsRawFd> SelectRead<T> {
pub fn new(fd: T) -> Self {
let (trig, list) = triggered::trigger();
Self { fd, list, set }
}
}
impl<T: AsRawFd> Future for SelectRead<T> {
type Output = nix::Result<()>;
fn poll(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Self::Output> {
//let mut this = self;
let this = self.project();
match this.list.poll(cx) {
Poll::Ready(()) => {
let r = nix::poll::poll(this.set, 0);
match r {
Ok(0) => {
cx.waker().wake_by_ref();
Poll::Pending
}
Ok(_) => Poll::Ready(Ok(())),
Err(e) => Poll::Ready(Err(e)),
}
},
Poll::Pending => {
//cx.waker().wake_by_ref();
Poll::Pending
},
}
//let mut set = *this.set;
/*let r = nix::sys::select::select(
set.highest().unwrap() + 1,
&mut set,
None,
None,
&mut nix::sys::time::TimeVal::new(0, 0),
);*/
}
}*/
/// Copies output from the pty to `send` for as long as reads keep succeeding.
///
/// `buf` is the caller's scratch buffer. Bytes already in it (for example when
/// this future was dropped by `select!` after a read but before the write) are
/// sent along with the next successful read.
///
/// This function never returns `Ok`; it only returns when reading the pty or
/// writing to `send` fails.
async fn read_pty(
    pty: &mut tokio::fs::File,
    _waker_recv: &mut flume::Receiver<()>,
    buf: &mut Vec<u8>,
    send: &mut SendStream,
) -> Result<()> {
    loop {
        match pty.read_buf(buf).await {
            // "No data yet" on a non-blocking fd. Match on the error kind rather
            // than raw errno numbers: EAGAIN is 11 on Linux but 35 on macOS, and
            // each of those numbers is EDEADLK on the other platform.
            Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {}
            Err(e) => return Err(e.into()),
            Ok(_) if buf.is_empty() => {
                // NOTE(review): a zero-length read means end-of-file. This just
                // retries and relies on another `select!` branch (child exit)
                // to end the session.
                debug!("not ready: empty");
            }
            Ok(_) => {
                send.write_all(&buf[..]).await?;
                buf.clear();
            }
        }
    }
}
tokio::select! {
/*_ = redraw.tick() => {
sh.pty.read_buf(&mut pty_buf).await?;
send.write_all(&pty_buf).await?;
pty_buf.clear();
info!("redraw complete");
}*/
r = Wait::new(&sh.proc) => {
let code = r?;
send.finish().await?;
if code != 0 {
info!("Child exit: {}", code);
recv.stop(1u8.into())?;
return Ok(());
} else {
info!("Child exit");
recv.stop(0u8.into())?;
return Ok(());
}
}
/*r = tokio::io::copy(&mut pty, &mut send) => {
r?;
}*/
r = read_pty(&mut pty, &mut waker_recv, &mut pty_buf, &mut send) => {
r?;
}
// FIXME: figure out a maximum chunk size
r = recv.read_chunk(usize::MAX, true) => {
if let Some(chunk) = r? {
pty.write_all(&chunk.bytes).await?;
info!("recv stdin");
}
}
/*r = recv.read_buf(&mut stdin_buf) => if r? > 0 {
//pty.get_mut().write_all(&stdin_buf).await?;
pty.write_all(&stdin_buf).await?;
stdin_buf.clear();
info!("recv stdin");
},*/
}
}
}
async fn handle_stream_forward(
_cfg: &ServerConfig,
send: SendStream,
recv: RecvStream,
addr: Ipv4Addr,
port: u16,
protocol: ForwardProtocol,
) -> Result<()> {
match protocol {
ForwardProtocol::Tcp => {
let socket = TcpSocket::new_v4()?;
let mut local_stream = socket.connect(SocketAddrV4::new(addr, port).into()).await?;
let mut remote_stream = SendRecvStream { send, recv };
tokio::io::copy_bidirectional(&mut local_stream, &mut remote_stream).await?;
Result::<_>::Ok(())
}
ForwardProtocol::Udp => todo!(),
}
}