diff --git a/Cargo.toml b/Cargo.toml index 54ac2c5..63b0c7d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,7 +26,7 @@ serde = { version = "1.0.163", features = ["derive"] } sha2 = "0.10.6" #termion = "2.0.1" tokio = { version = "1.28.2", default-features = false, features = ["rt-multi-thread", "macros", "process", "io-util", "io-std", "time", "fs", "signal"] } -tracing = "0.1.37" +tracing = { version = "0.1.37", features = ["max_level_debug", "release_max_level_info"] } tracing-subscriber = { version = "0.3.17", features = ["env-filter"] } triggered = "0.1.2" diff --git a/src/main.rs b/src/main.rs index 45f0806..a3b9605 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,4 +1,6 @@ -#[deny(unused_must_use)] +#![deny(unreachable_code)] +#![deny(unused_must_use)] + #[macro_use] extern crate anyhow; #[macro_use] @@ -14,14 +16,16 @@ mod user_info; #[cfg(all(feature = "server", target_os = "macos"))] use pam_client_macos as pam_client; +use tokio::net::TcpSocket; use std::ffi::{CStr, CString}; use std::fmt; use std::future::Future; use std::mem::ManuallyDrop; -use std::net::SocketAddr; +use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4}; use std::os::fd::FromRawFd; use std::os::unix::process::CommandExt; +use std::pin::Pin; use std::process::Stdio; use std::ptr::NonNull; use std::str::FromStr; @@ -30,7 +34,7 @@ use std::task::Poll; use anyhow::{Context, Result}; use base64::Engine as _; -use nix::unistd::{Gid, Uid}; +use nix::unistd::Uid; #[cfg(feature = "server")] use pam_client::ConversationHandler; use quinn::{ReadExactError, RecvStream, SendStream}; @@ -41,14 +45,106 @@ use tracing::Instrument; use user_info::UserInfo; use webpki::SubjectNameRef; -use crate::user_info::get_user_info; - #[derive(Debug, Clone, Serialize, Deserialize)] struct Term { name: String, info: Vec, } +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +enum ForwardDirection { + LocalToRemote, + RemoteToLocal, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +enum ForwardProtocol { + Tcp, + Udp, +} + +impl ForwardProtocol { + pub const fn name(self) -> &'static str { + match self { + ForwardProtocol::Tcp => "tcp", + ForwardProtocol::Udp => "udp", + } + } +} + +impl fmt::Display for ForwardProtocol { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(self.name()) + } +} + +impl FromStr for ForwardProtocol { + type Err = anyhow::Error; + + fn from_str(s: &str) -> Result { + const TCP: &str = ForwardProtocol::Tcp.name(); + const UDP: &str = ForwardProtocol::Udp.name(); + match s { + TCP => Ok(Self::Tcp), + UDP => Ok(Self::Udp), + _ => Err(anyhow!("Expected one of tcp, udp: {:?}", s)), + } + } +} + +#[derive(Debug, Clone)] +struct ForwardSpec { + direction: ForwardDirection, + bind: SocketAddrV4, + // TODO: allow DNS lookup here + forward: (Ipv4Addr, u16), + protocol: ForwardProtocol, +} + +impl FromStr for ForwardSpec { + type Err = anyhow::Error; + + fn from_str(s: &str) -> Result { + let i = s + .as_bytes() + .iter() + .position(|&b| b == b'-' || b == b'<') + .with_context(|| format!("Expected a local-remote delimiter: {:?}", s))?; + let direction = match s.as_bytes().get(i..i + 2) { + Some(b"->") => ForwardDirection::LocalToRemote, + Some(b"<-") => ForwardDirection::RemoteToLocal, + _ => bail!("Expected a local-remote delimiter: {:?}", s), + }; + let local = &s[..i]; + let remote = &s[i + 2..]; + let (bind, forward) = match direction { + ForwardDirection::LocalToRemote => (local, remote), + ForwardDirection::RemoteToLocal => (remote, local), + }; + let (bind_sock, protocol) = bind + .split_once('/') + .context("Expected a port-protocol delimiter")?; + let bind_sock = bind_sock.parse::()?; + let protocol = protocol.parse()?; + let (forward_addr, forward_port) = forward + .split_once(':') + .context("Expected an address-port delimiter")?; + let forward_addr = forward_addr + .parse::() + .with_context(|| forward_addr.to_owned())?; + let forward_port = forward_port + .parse::() + .with_context(|| forward_port.to_owned())?; + + Ok(Self { + direction, + bind: bind_sock, + forward: (forward_addr, forward_port), + protocol, + }) + } +} + #[derive(Clone, Serialize, Deserialize)] enum Stream { Exec, @@ -56,7 +152,12 @@ enum Stream { env_term: Option, command: Option<(CString, Vec)>, }, - // TODO: port forwarding + Forward { + addr: Ipv4Addr, + port: u16, + protocol: ForwardProtocol, + }, + // TODO: "backward" } impl fmt::Debug for Stream { @@ -67,16 +168,23 @@ impl fmt::Debug for Stream { env_term: _, command, } => f.debug_struct("Shell").field("command", command).finish(), + Self::Forward { + addr, + port, + protocol, + } => write!(f, "-> {}:{}/{}", addr, port, protocol), } } } #[tokio::main] async fn main() -> Result<()> { + use tracing_subscriber::fmt::format::FmtSpan; tracing::subscriber::set_global_default( tracing_subscriber::FmtSubscriber::builder() .with_writer(std::io::stderr) .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .with_span_events(FmtSpan::NEW | FmtSpan::CLOSE) .finish(), ) .unwrap(); @@ -565,8 +673,28 @@ async fn run_client(mut args: std::env::Args) -> Result<()> { .into(), ); - let conn_str = args.next().expect("USERNAME@HOST"); - let (username, host) = conn_str.split_once('@').expect("USERNAME@HOST"); + let mut conn_str = None; + let mut forwards = Vec::new(); + while let Some(arg) = args.next() { + if let Some(arg) = arg.strip_prefix('-') { + match arg { + "-forward" => { + let v = args.next().context("Expected a value of the form LOCAL_ADDR:LOCAL_PORT[->|<-]REMOTE_ADDR:REMOTE_PORT (and /[tcp|udp] on bind side)")?; + forwards.push(v.parse::()?); + } + _ => bail!("Unrecognized option: -{}", arg), + } + } else if conn_str.is_none() { + conn_str = Some(arg); + } else { + bail!("Unexpected argument: {:?}", arg); + } + } + + let (username, host) = conn_str + .as_ref() + .and_then(|s| s.split_once('@')) + .context("Expected an argument of the form USERNAME@HOST")?; let mut client_crypto = rustls::ClientConfig::builder() .with_safe_defaults() @@ -604,6 +732,20 @@ async fn run_client(mut args: std::env::Args) -> Result<()> { // authenticated client + for forward in forwards { + match forward.direction { + ForwardDirection::LocalToRemote => { + let conn = conn.clone(); + tokio::task::spawn(async move { + if let Err(e) = do_forward_to(&conn, forward).await { + error!("Error in local-to-remote forwarder: {}", e); + } + }); + } + ForwardDirection::RemoteToLocal => todo!(), + } + } + let _reset = { use nix::sys::termios::*; @@ -629,35 +771,136 @@ async fn run_client(mut args: std::env::Args) -> Result<()> { reset }; - let (mut send, mut recv) = conn.open_bi().await?; - - let env_term = std::env::var("TERM"); - let env_terminfo = std::env::var("TERMINFO"); - - write_msg( - &mut send, - &Stream::Shell { - env_term: if let (Ok(name), Ok(path)) = (env_term, env_terminfo) { - let info = terminfo::read_terminfo(&path, &name).await?; - Some(Term { name, info }) - } else { - None - }, - command: None, - }, - ) - .await?; - - info!("connected"); - - do_shell(&conn, &mut send, &mut recv).await + do_shell(&conn).await } -async fn do_shell( +async fn open_stream( conn: &quinn::Connection, - send: &mut SendStream, - recv: &mut RecvStream, -) -> Result<()> { + stream: &Stream, +) -> Result<(SendStream, RecvStream)> { + let (mut send, recv) = conn.open_bi().await?; + write_msg(&mut send, stream).await?; + info!("connected"); + Ok((send, recv)) +} + +pin_project_lite::pin_project! { + struct SendRecvStream { + #[pin] send: SendStream, + #[pin] recv: RecvStream, + } +} + +impl tokio::io::AsyncWrite for SendRecvStream { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> Poll> { + self.project().send.poll_write(cx, buf) + } + + fn poll_flush( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + self.project().send.poll_flush(cx) + } + + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + self.project().send.poll_shutdown(cx) + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + bufs: &[std::io::IoSlice<'_>], + ) -> Poll> { + self.project().send.poll_write_vectored(cx, bufs) + } + + fn is_write_vectored(&self) -> bool { + self.send.is_write_vectored() + } +} + +impl tokio::io::AsyncRead for SendRecvStream { + fn poll_read( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> Poll> { + self.project().recv.poll_read(cx, buf) + } +} + +async fn do_forward_to(conn: &quinn::Connection, spec: ForwardSpec) -> Result<()> { + assert_eq!(spec.direction, ForwardDirection::LocalToRemote); + + match spec.protocol { + ForwardProtocol::Tcp => { + let socket = TcpSocket::new_v4()?; + socket.bind(spec.bind.into())?; + let list = socket.listen(1024)?; + loop { + let (mut local_stream, peer) = list.accept().await?; + let (send, recv) = open_stream( + conn, + &Stream::Forward { + addr: spec.forward.0, + port: spec.forward.1, + protocol: spec.protocol, + }, + ) + .await?; + let mut remote_stream = SendRecvStream { send, recv }; + tokio::task::spawn(async move { + tokio::io::copy_bidirectional(&mut local_stream, &mut remote_stream).await?; + Result::<_>::Ok(()) + }); + } + } + ForwardProtocol::Udp => { + // this will require more thought into how to keep track of the peer. + todo!() + /*let socket = UdpSocket::bind(spec.bind).await?; + let mut buf = Vec::with_capacity(4096); + let (send, recv) = open_stream(conn, &Stream::Forward { + addr: spec.forward.0, + port: spec.forward.1, + protocol: spec.protocol, + }).await?; + loop { + socket.recv_buf(&mut buf).await?; + + let mut remote_stream = SendRecvStream(send, recv); + tokio::task::spawn(async move { + tokio::io::copy_bidirectional(&mut socket, &mut remote_stream).await?; + Result::<_>::Ok(()) + }); + }*/ + } + } +} + +async fn do_shell(conn: &quinn::Connection) -> Result<()> { + let env_term = std::env::var("TERM"); + let env_terminfo = std::env::var("TERMINFO"); + let stream = &Stream::Shell { + env_term: if let (Ok(name), Ok(path)) = (env_term, env_terminfo) { + let info = terminfo::read_terminfo(&path, &name).await?; + Some(Term { name, info }) + } else { + None + }, + command: None, + }; + + let (mut send, mut recv) = open_stream(&conn, &stream).await?; + let mut stdin = unsafe { ManuallyDrop::new(tokio::fs::File::from_raw_fd(libc::STDIN_FILENO)) }; let mut stdout = unsafe { ManuallyDrop::new(tokio::fs::File::from_raw_fd(libc::STDOUT_FILENO)) }; @@ -781,11 +1024,22 @@ async fn greet_conn(cfg: &'static ServerConfig, conn: quinn::Connecting) -> Resu .unwrap() .downcast::().unwrap() .protocol - .map_or_else(|| "".into(), |x| String::from_utf8_lossy(&x).into_owned()) + .map_or_else(|| "".into(), |x| String::from_utf8_lossy(&x).into_owned()), + username = tracing::field::Empty, ); - if let Err(e) = authenticate_conn(cfg, &conn).instrument(span).await { - error!("handler failed: {reason}", reason = e.to_string()); + let (user_info, env) = match authenticate_conn(cfg, &conn).instrument(span.clone()).await { + Ok(t) => t, + Err(e) => { + error!("authentication failed: {}", e.to_string()); + conn.close(1u8.into(), b"authentication error"); + return Ok(()); + } + }; + span.record("username", &user_info.user.name); + + if let Err(e) = handle_conn(cfg, &conn, user_info, env).await { + error!("handler failed: {}", e.to_string()); conn.close(1u8.into(), b"handler error"); } @@ -815,7 +1069,10 @@ mod auth { } #[cfg(feature = "server")] -async fn authenticate_conn(cfg: &'static ServerConfig, conn: &quinn::Connection) -> Result<()> { +async fn authenticate_conn( + cfg: &'static ServerConfig, + conn: &quinn::Connection, +) -> Result<(UserInfo, Vec<(CString, CString)>)> { use auth::*; info!("authenticating connection"); @@ -903,24 +1160,26 @@ async fn authenticate_conn(cfg: &'static ServerConfig, conn: &quinn::Connection) // very odd behaviour: // if user is currently logged in, we must call acct_mgmt first. // if user is not currently logged in, we must call acc_mgmt first and ignore the error. - let r = ctx.acct_mgmt(pam_client::Flag::NONE).context("acct_mgmt"); - if cfg!(not(any(target_os = "macos", target_os = "ios"))) { - r?; - info!("validated user"); - } else { - info!("not validating user due to macOS oddity"); + if let Err(e) = ctx.acct_mgmt(pam_client::Flag::NONE).context("acct_mgmt") { + if cfg!(any(target_os = "macos", target_os = "ios")) { + warn!("ignoring validation error due to macOS oddity: {}", e); + } else { + return Err(e); + } } + info!("validated user"); - let sess = ctx + /*let sess = ctx .open_session(pam_client::Flag::NONE) .context("open_session")?; info!("opened session"); - let sess = sess.leak(); + let sess = sess.leak();*/ let conv = ctx.conversation_mut(); conv.send = flume::bounded(0).0; conv.recv = flume::bounded(0).1; - Result::<_>::Ok((ctx, sess)) + //Result::<_>::Ok((ctx, sess)) + Result::<_>::Ok((ctx, ())) }); while let Ok(question) = q_recv.recv_async().await { @@ -1007,33 +1266,8 @@ async fn authenticate_conn(cfg: &'static ServerConfig, conn: &quinn::Connection) } let (mut ctx, sess) = hdl.await??; - let sess = ctx.unleak_session(sess); + /*let sess = ctx.unleak_session(sess); let env = sess.envlist(); - info!("logged in: {}", env); - - write_msg(&mut send, &Question::LoggedIn).await?; - send.finish().await?; - recv.stop(0u8.into())?; - - let user_info = get_user_info(&hello.username).await?; - - let span = info_span!("logged_in", username = user_info.user.name,); - - // TODO: move out of authenticate_conn, return necessary info - handle_conn(cfg, conn, user_info, env) - .instrument(span) - .await -} - -#[cfg(feature = "server")] -async fn handle_conn( - cfg: &'static ServerConfig, - conn: &quinn::Connection, - user_info: UserInfo, - env: pam_client::env_list::EnvList, -) -> Result<()> { - info!("established"); - let env = env .into_iter() .filter_map(|pair| { @@ -1046,7 +1280,27 @@ async fn handle_conn( let v = CString::new(&element[sep + 1..]).ok()?; Some((k, v)) }) - .collect::>(); + .collect::>();*/ + let env = Vec::new(); + info!("logged in"); + + write_msg(&mut send, &Question::LoggedIn).await?; + send.finish().await?; + recv.stop(0u8.into())?; + + let user_info = user_info::get_user_info(&hello.username).await?; + + Ok((user_info, env)) +} + +#[cfg(feature = "server")] +async fn handle_conn( + cfg: &'static ServerConfig, + conn: &quinn::Connection, + user_info: UserInfo, + env: Vec<(CString, CString)>, +) -> Result<()> { + info!("established"); loop { let stream = conn.accept_bi().await; @@ -1071,9 +1325,26 @@ async fn handle_conn( tokio::task::spawn( async move { let r = match stream { - Stream::Exec => handle_stream_exec(cfg, send, recv, &user_info).await, + Stream::Exec => { + let span = info_span!("stream_exec"); + handle_stream_exec(cfg, send, recv, &user_info) + .instrument(span) + .await + } Stream::Shell { env_term, command } => { + let span = info_span!("stream_shell", ?command); handle_stream_shell(cfg, send, recv, &user_info, env, env_term, command) + .instrument(span) + .await + } + Stream::Forward { + addr, + port, + protocol, + } => { + let span = info_span!("stream_forward", %addr, %port, %protocol); + handle_stream_forward(cfg, send, recv, addr, port, protocol) + .instrument(span) .await } }; @@ -1225,112 +1496,111 @@ async fn handle_stream_shell( let c_env_term = env_term.map(|t| CString::new(t.name)).transpose()?; let c_env_terminfo = terminfo_path.map(CString::new).transpose()?; - let pre_exec = unsafe { - pty::PreExec::new(move || { - fn unsetenv(name: &CStr) -> Result<(), nix::Error> { - if unsafe { libc::unsetenv(name.as_ptr()) } != 0 { - Err(nix::Error::last().into()) - } else { - Ok(()) - } + let pre_exec = move || { + fn unsetenv(name: &CStr) -> Result<(), nix::Error> { + if unsafe { libc::unsetenv(name.as_ptr()) } != 0 { + Err(nix::Error::last().into()) + } else { + Ok(()) } + } - fn setenv(name: &CStr, value: &CStr, overwrite: bool) -> Result<(), nix::Error> { - if unsafe { libc::setenv(name.as_ptr(), value.as_ptr(), overwrite as i32) } != 0 { - Err(nix::Error::last().into()) - } else { - Ok(()) - } + fn setenv(name: &CStr, value: &CStr, overwrite: bool) -> Result<(), nix::Error> { + if unsafe { libc::setenv(name.as_ptr(), value.as_ptr(), overwrite as i32) } != 0 { + Err(nix::Error::last().into()) + } else { + Ok(()) } + } - fn getenv(name: &CStr) -> Option> { - unsafe { - let ptr = libc::getenv(name.as_ptr()); - NonNull::new(CStr::from_ptr(ptr) as *const _ as *mut _) - } - } - - extern "C" { - #[cfg(any(target_os = "macos", target_os = "ios"))] - pub fn _NSGetEnviron() -> *mut *const *const std::os::raw::c_char; - #[cfg(not(any(target_os = "macos", target_os = "ios")))] - static mut environ: *const *const c_char; - } - - let mut keep = Vec::new(); - const PRESERVE_ENV: &[&[u8]] = &[b"PATH\0", b"LANG\0"]; + fn getenv(name: &CStr) -> Option> { unsafe { - #[cfg(any(target_os = "macos", target_os = "ios"))] - let mut environ = *_NSGetEnviron(); - #[cfg(not(any(target_os = "macos", target_os = "ios")))] - let mut environ = environ; - if !environ.is_null() { - while !(*environ).is_null() { - let key_value = CStr::from_ptr(*environ).to_bytes_with_nul(); - environ = environ.add(1); - let Some(i) = key_value.iter().position(|b| *b == b'=') else { - continue - }; - let key = &key_value[..i]; - let value = &key_value[i + 1..]; - if PRESERVE_ENV.iter().any(|s| &s[..s.len() - 1] == key) { - keep.push((key, value)); - } + let ptr = libc::getenv(name.as_ptr()); + NonNull::new(CStr::from_ptr(ptr) as *const _ as *mut _) + } + } + + extern "C" { + #[cfg(any(target_os = "macos", target_os = "ios"))] + pub fn _NSGetEnviron() -> *mut *const *const std::os::raw::c_char; + #[cfg(not(any(target_os = "macos", target_os = "ios")))] + static mut environ: *const *const c_char; + } + + let mut keep = Vec::new(); + const PRESERVE_ENV: &[&[u8]] = &[b"PATH\0", b"LANG\0"]; + unsafe { + #[cfg(any(target_os = "macos", target_os = "ios"))] + let mut environ = *_NSGetEnviron(); + #[cfg(not(any(target_os = "macos", target_os = "ios")))] + let mut environ = environ; + if !environ.is_null() { + while !(*environ).is_null() { + let key_value = CStr::from_ptr(*environ).to_bytes_with_nul(); + environ = environ.add(1); + let Some(i) = key_value.iter().position(|b| *b == b'=') else { + continue + }; + let key = &key_value[..i]; + let value = &key_value[i + 1..]; + if PRESERVE_ENV.iter().any(|s| &s[..s.len() - 1] == key) { + keep.push((key, value)); } } } - nix::env::clearenv()?; - for (k, v) in keep - .into_iter() - .chain(env.iter().map(|(k, v)| (k.as_bytes(), v.as_bytes()))) - { - setenv(&CString::new(k)?, CStr::from_bytes_with_nul(v)?, true)?; - } + } + unsafe { nix::env::clearenv() }?; + for (k, v) in keep + .into_iter() + .chain(env.iter().map(|(k, v)| (k.as_bytes(), v.as_bytes()))) + { + setenv(&CString::new(k)?, CStr::from_bytes_with_nul(v)?, true)?; + } + setenv( + CStr::from_bytes_with_nul(b"USER\0").unwrap(), + &c_user_name, + true, + )?; + if let Some(c_env_terminfo) = &c_env_terminfo { + // otherwise system terminfo will be used setenv( - CStr::from_bytes_with_nul(b"USER\0").unwrap(), - &c_user_name, + CStr::from_bytes_with_nul(b"TERMINFO\0").unwrap(), + &c_env_terminfo, true, )?; - if let Some(c_env_terminfo) = &c_env_terminfo { - // otherwise system terminfo will be used - setenv( - CStr::from_bytes_with_nul(b"TERMINFO\0").unwrap(), - &c_env_terminfo, - true, - )?; - } - if let Some(c_env_term) = &c_env_term { - setenv( - CStr::from_bytes_with_nul(b"TERM\0").unwrap(), - c_env_term, - true, - )?; - } + } + if let Some(c_env_term) = &c_env_term { + setenv( + CStr::from_bytes_with_nul(b"TERM\0").unwrap(), + c_env_term, + true, + )?; + } - #[cfg(not(any( - target_os = "macos", - target_os = "ios", - target_os = "redox", - target_os = "haiku" - )))] - nix::unistd::setgroups(&user_info.groups.into_iter().map(|g| g.id).collect()) - .context("setting supplementary groups")?; + #[cfg(not(any( + target_os = "macos", + target_os = "ios", + target_os = "redox", + target_os = "haiku" + )))] + nix::unistd::setgroups(&user_info.groups.into_iter().map(|g| g.id).collect()) + .context("setting supplementary groups")?; - nix::unistd::setgid(user_info.group.id).context("setting primary group")?; - nix::unistd::setuid(user_info.user.id).context("setting user")?; + nix::unistd::setgid(user_info.group.id).context("setting primary group")?; + nix::unistd::setuid(user_info.user.id).context("setting user")?; - if nix::unistd::seteuid(Uid::from_raw(0)).is_ok() { - return Err(anyhow!("We got the privileges back via seteuid. Very bad!")); - } - if nix::unistd::setuid(Uid::from_raw(0)).is_ok() { - return Err(anyhow!("We got the privileges back via setuid. Very bad!")); - } + if nix::unistd::seteuid(Uid::from_raw(0)).is_ok() { + return Err(anyhow!("We got the privileges back via seteuid. Very bad!")); + } + if nix::unistd::setuid(Uid::from_raw(0)).is_ok() { + return Err(anyhow!("We got the privileges back via setuid. Very bad!")); + } - std::env::set_current_dir("/")?; + std::env::set_current_dir("/")?; - Ok(()) - }) + Ok(()) }; + let pre_exec = unsafe { pty::PreExec::new(pre_exec) }; let mut sh = pty::create_pty(opt_shell, &args, pre_exec)?; sh.set_nodelay()?; let mut pty = sh.pty; @@ -1530,3 +1800,23 @@ async fn handle_stream_shell( } } } + +async fn handle_stream_forward( + cfg: &ServerConfig, + send: SendStream, + recv: RecvStream, + addr: Ipv4Addr, + port: u16, + protocol: ForwardProtocol, +) -> Result<()> { + match protocol { + ForwardProtocol::Tcp => { + let socket = TcpSocket::new_v4()?; + let mut local_stream = socket.connect(SocketAddrV4::new(addr, port).into()).await?; + let mut remote_stream = SendRecvStream { send, recv }; + tokio::io::copy_bidirectional(&mut local_stream, &mut remote_stream).await?; + Result::<_>::Ok(()) + } + ForwardProtocol::Udp => todo!(), + } +} diff --git a/src/passwd.rs b/src/passwd.rs index 0354ce7..de20d03 100644 --- a/src/passwd.rs +++ b/src/passwd.rs @@ -77,7 +77,7 @@ impl<'a> Passwd<'a> { /// /// This is a shortcut for `Passwd::from_uid(libc::getuid())`. pub fn current_user(buf: &'a mut Vec) -> Result> { - Self::from_uid(unsafe { nix::unistd::getuid() }, buf) + Self::from_uid(nix::unistd::getuid(), buf) } unsafe fn from_c_struct(passwd: libc::passwd) -> Self { diff --git a/src/pty.rs b/src/pty.rs index 0bdd62e..96b7208 100644 --- a/src/pty.rs +++ b/src/pty.rs @@ -84,7 +84,11 @@ where /* Launch our child process. The main application loop can inspect and then pass the stdin data to it. */ let child_pid = match unsafe { nix::unistd::fork() } { - Ok(ForkResult::Child) => match init_child(path, argv, pre_exec, &slave_name).unwrap() {}, + Ok(ForkResult::Child) => { + let t = init_child(path, argv, pre_exec, &slave_name).unwrap(); + #[allow(unreachable_code)] + match t {} + }, Ok(ForkResult::Parent { child }) => child, Err(e) => panic!("{}", e), };