use std::{ffi::CString, mem::ManuallyDrop, os::fd::FromRawFd}; use anyhow::{Context, Result}; use quinn::{SendStream, RecvStream}; use tokio::io::AsyncWriteExt; use zeroize::Zeroizing; use crate::{write_msg, read_msg}; use super::{Question, Answer}; #[cfg(feature = "server")] pub async fn server_authenticate(send: &mut SendStream, recv: &mut RecvStream, username: String) -> Result<()> { use pam_client::ConversationHandler; use crate::Message; use super::*; let (q_send, q_recv) = flume::bounded(1); let (a_send, a_recv) = flume::bounded(1); struct Conversation { send: flume::Sender, recv: flume::Receiver, error: std::cell::Cell>, } impl Conversation { const TIMEOUT: std::time::Duration = std::time::Duration::from_secs(60); fn error(&self, msg: &'static str) -> pam_client::ErrorCode { self.error.set(Some(msg)); pam_client::ErrorCode::CONV_ERR } fn ask(&self, question: Question) -> Result<(), pam_client::ErrorCode> { self.send .send_timeout(Message::from_value(&question) .map_err(|_| self.error("Serialization error"))?, Self::TIMEOUT) .map_err(|_| self.error("Ask question timed out")) } fn prompt(&self, prompt: &CStr, echo: bool) -> Result { self.ask(Question::Prompt { prompt: prompt.into(), echo, })?; self.recv .recv_timeout(Self::TIMEOUT) .map_err(|_| self.error("Wait for answer timed out")) } } impl ConversationHandler for Conversation { fn prompt_echo_on( &mut self, prompt: &CStr, ) -> std::result::Result { self.prompt(prompt, true) } fn prompt_echo_off( &mut self, prompt: &CStr, ) -> std::result::Result { self.prompt(prompt, false) } fn text_info(&mut self, msg: &CStr) { _ = self.ask(Question::TextInfo(msg.into())); } fn error_msg(&mut self, msg: &CStr) { _ = self.ask(Question::ErrorMsg(msg.into())); } } let hdl = tokio::task::spawn_blocking(move || { let mut ctx = pam_client::Context::new( "sshd", Some(&username), Conversation { send: q_send, recv: a_recv, error: Default::default(), }, )?; info!("created context"); ctx.authenticate(pam_client::Flag::NONE) .with_context(|| ctx.conversation_mut().error.take().unwrap_or("Unknown error")) .context("authenticate")?; info!("authenticated user"); // very odd behaviour: // if user is currently logged in, we must call acct_mgmt first. // if user is not currently logged in, we must call acc_mgmt first and ignore the error. if let Err(e) = ctx.acct_mgmt(pam_client::Flag::NONE).context("acct_mgmt") { if cfg!(any(target_os = "macos", target_os = "ios")) { warn!("ignoring validation error due to macOS oddity: {}", e); } else { return Err(e).with_context(|| ctx.conversation_mut().error.take().unwrap_or("Unknown error")); } } info!("validated user"); /*let sess = ctx .open_session(pam_client::Flag::NONE) .context("open_session")?; info!("opened session"); let sess = sess.leak();*/ let conv = ctx.conversation_mut(); conv.send = flume::bounded(0).0; conv.recv = flume::bounded(0).1; //Result::<_>::Ok((ctx, sess)) Result::<_>::Ok((ctx, ())) }); let mut answer_buf = Vec::new(); while let Ok(question_msg) = q_recv.recv_async().await { let question: Question = question_msg.to_value().unwrap(); debug!("received question: {:?}", question); question_msg.write_ref(send).await?; if matches!(question, Question::Prompt { .. }) { let answer = read_msg(recv, &mut answer_buf).await??; trace!("received answer: {:?}", answer); let answer = match answer { Answer::Prompt(s) => (*s).to_owned(), //Answer::SshAuthResponse(_) => bail!("The received answer was unexpected: SshAuthResponse"), }; a_send.send_async(answer).await?; } } let (_ctx, _sess) = hdl.await??; /*let sess = ctx.unleak_session(sess); let env = sess.envlist(); let env = env .into_iter() .filter_map(|pair| { let element = pair.as_cstr().to_bytes_with_nul(); let sep = element .iter() .position(|b| *b == b'=') .unwrap_or(element.len()); let k = CString::new(&element[..sep]).ok()?; let v = CString::new(&element[sep + 1..]).ok()?; Some((k, v)) }) .collect::>();*/ Ok(()) } pub async fn client_authenticate( conn: &quinn::Connection, send: &mut SendStream, recv: &mut RecvStream, username: &str, ) -> Result<()> { write_msg( send, &super::Hello { username, auth_method: super::Method::Password, }, ) .await?; let mut stdout = unsafe { ManuallyDrop::new(tokio::fs::File::from_raw_fd(libc::STDOUT_FILENO)) }; let mut msg_buf = Vec::new(); loop { tokio::select! { r = read_msg::(recv, &mut msg_buf) => { match r?? { Question::LoggedIn => { return Ok(()); } Question::Prompt { prompt, // TODO: implement this echo: _, } => { stdout.write_all(prompt.to_bytes()).await?; stdout.write_all(b" ").await?; let answer = rpassword::read_password()?; let answer = Zeroizing::new(CString::new(answer)?); write_msg(send, &Answer::Prompt((*answer).as_ref().into())).await?; }, //Question::SshAuthRequest(_) => bail!("Received an unexpected question from the server: SshAuthRequest"), Question::TextInfo(s) => { stdout.write_all(b"INFO ").await?; stdout.write_all(s.to_bytes()).await?; stdout.write_all(b"\n").await?; }, Question::ErrorMsg(s) => { stdout.write_all(b"ERRO ").await?; stdout.write_all(s.to_bytes()).await?; stdout.write_all(b"\n").await?; }, } } r = send.stopped() => { info!("Remote disconnected"); let code = r?.into_inner(); if code == 0 { return Ok(()); } else { return Err(anyhow!("Error code {}", code)); } } e = conn.closed() => { info!("Remote disconnected: {}", e); return Err(anyhow!("Remote connection closed")); } } } }