quinoa/src/auth/password.rs

221 lines
7.5 KiB
Rust

use std::{ffi::CString, mem::ManuallyDrop, os::fd::FromRawFd};
use anyhow::{Context, Result};
use quinn::{SendStream, RecvStream};
use tokio::io::AsyncWriteExt;
use zeroize::Zeroizing;
use crate::{write_msg, read_msg};
use super::{Question, Answer};
/// Authenticates `username` against the PAM service `"sshd"`, relaying the PAM conversation to
/// the remote peer.
///
/// PAM runs on a blocking thread (`tokio::task::spawn_blocking`). Every PAM conversation message
/// is written to the peer on `send` as a serialized [`Question`]. For `Question::Prompt` the
/// peer's [`Answer`] is read from `recv` and handed back to PAM. The PAM thread waits at most
/// `Conversation::TIMEOUT` (60 s) to hand over a question or to receive an answer.
///
/// # Errors
///
/// Returns an error if the PAM context cannot be created, if `authenticate` fails, or if
/// `acct_mgmt` fails (except on macOS/iOS, where that failure is only logged). It also fails if
/// writing a question to `send` or reading an answer from `recv` fails.
#[cfg(feature = "server")]
pub async fn server_authenticate(send: &mut SendStream, recv: &mut RecvStream, username: String) -> Result<()> {
    use pam_client::ConversationHandler;
    use crate::Message;
    use super::*;

    // Questions flow PAM thread -> async relay loop; answers flow back the other way.
    let (q_send, q_recv) = flume::bounded(1);
    let (a_send, a_recv) = flume::bounded(1);

    // PAM-thread side of the conversation. It forwards questions to the async task, waits for
    // answers, and records a human-readable reason whenever it has to fail.
    struct Conversation {
        /// Serialized questions for the async relay loop.
        send: flume::Sender<Message>,
        /// Answers to `Question::Prompt`, supplied by the relay loop.
        recv: flume::Receiver<CString>,
        /// Reason for the last conversation failure, attached as context to PAM errors.
        error: std::cell::Cell<Option<&'static str>>,
    }
    impl Conversation {
        const TIMEOUT: std::time::Duration = std::time::Duration::from_secs(60);

        /// Records `msg` as the failure reason and returns PAM's conversation-error code.
        fn error(&self, msg: &'static str) -> pam_client::ErrorCode {
            self.error.set(Some(msg));
            pam_client::ErrorCode::CONV_ERR
        }

        /// Serializes `question` and queues it for the relay loop.
        fn ask(&self, question: Question) -> Result<(), pam_client::ErrorCode> {
            self.send
                .send_timeout(Message::from_value(&question)
                    .map_err(|_| self.error("Serialization error"))?, Self::TIMEOUT)
                .map_err(|_| self.error("Ask question timed out"))
        }

        /// Sends a prompt and blocks until the peer's answer arrives (or the timeout expires).
        fn prompt(&self, prompt: &CStr, echo: bool) -> Result<CString, pam_client::ErrorCode> {
            self.ask(Question::Prompt {
                prompt: prompt.into(),
                echo,
            })?;
            self.recv
                .recv_timeout(Self::TIMEOUT)
                .map_err(|_| self.error("Wait for answer timed out"))
        }
    }
    impl ConversationHandler for Conversation {
        fn prompt_echo_on(
            &mut self,
            prompt: &CStr,
        ) -> std::result::Result<std::ffi::CString, pam_client::ErrorCode> {
            self.prompt(prompt, true)
        }
        fn prompt_echo_off(
            &mut self,
            prompt: &CStr,
        ) -> std::result::Result<std::ffi::CString, pam_client::ErrorCode> {
            self.prompt(prompt, false)
        }
        // Informational messages are best-effort: a delivery failure must not abort PAM.
        fn text_info(&mut self, msg: &CStr) {
            _ = self.ask(Question::TextInfo(msg.into()));
        }
        fn error_msg(&mut self, msg: &CStr) {
            _ = self.ask(Question::ErrorMsg(msg.into()));
        }
    }

    let hdl = tokio::task::spawn_blocking(move || {
        let mut ctx = pam_client::Context::new(
            "sshd",
            Some(&username),
            Conversation {
                send: q_send,
                recv: a_recv,
                error: Default::default(),
            },
        )?;
        info!("created context");
        ctx.authenticate(pam_client::Flag::NONE)
            .with_context(|| ctx.conversation_mut().error.take().unwrap_or("Unknown error"))
            .context("authenticate")?;
        info!("authenticated user");
        // Very odd behaviour: `acct_mgmt` must always be called after `authenticate`. When the
        // user is not currently logged in, it can fail and that error has to be ignored.
        // The workaround is only applied on macOS/iOS; elsewhere the failure is fatal.
        if let Err(e) = ctx.acct_mgmt(pam_client::Flag::NONE).context("acct_mgmt") {
            if cfg!(any(target_os = "macos", target_os = "ios")) {
                warn!("ignoring validation error due to macOS oddity: {}", e);
            } else {
                return Err(e).with_context(|| ctx.conversation_mut().error.take().unwrap_or("Unknown error"));
            }
        }
        info!("validated user");
        // Session handling is currently disabled; kept for reference.
        /*let sess = ctx
            .open_session(pam_client::Flag::NONE)
            .context("open_session")?;
        info!("opened session");
        let sess = sess.leak();*/
        // Replace both channel ends with dead ones. Dropping the original question sender is
        // what ends the relay loop below.
        let conv = ctx.conversation_mut();
        conv.send = flume::bounded(0).0;
        conv.recv = flume::bounded(0).1;
        //Result::<_>::Ok((ctx, sess))
        Result::<_>::Ok((ctx, ()))
    });

    // Holds the raw wire bytes of answers (i.e. passwords); wiped when dropped.
    let mut answer_buf = Zeroizing::new(Vec::<u8>::new());
    // Runs until the PAM thread drops its question sender (success or failure).
    while let Ok(question_msg) = q_recv.recv_async().await {
        // `Conversation::ask` serialized this message, so decoding it back is an invariant.
        let question: Question = question_msg
            .to_value()
            .expect("question serialized by Conversation::ask must deserialize");
        debug!("received question: {:?}", question);
        question_msg.write_ref(send).await?;
        if matches!(question, Question::Prompt { .. }) {
            let answer = read_msg(recv, &mut *answer_buf).await??;
            // Never log the answer itself: for password prompts it is the user's secret.
            trace!("received answer to prompt");
            let answer = match answer {
                Answer::Prompt(s) => (*s).to_owned(),
                //Answer::SshAuthResponse(_) => bail!("The received answer was unexpected: SshAuthResponse"),
            };
            if a_send.send_async(answer).await.is_err() {
                // The PAM thread stopped waiting (timeout or failure). Its own error, reported
                // by `hdl` below, is more useful than a closed-channel error.
                break;
            }
        }
    }
    let (_ctx, _sess) = hdl.await??;
    /*let sess = ctx.unleak_session(sess);
    let env = sess.envlist();
    let env = env
        .into_iter()
        .filter_map(|pair| {
            let element = pair.as_cstr().to_bytes_with_nul();
            let sep = element
                .iter()
                .position(|b| *b == b'=')
                .unwrap_or(element.len());
            let k = CString::new(&element[..sep]).ok()?;
            let v = CString::new(&element[sep + 1..]).ok()?;
            Some((k, v))
        })
        .collect::<Vec<_>>();*/
    Ok(())
}
pub async fn client_authenticate(
conn: &quinn::Connection,
send: &mut SendStream,
recv: &mut RecvStream,
username: &str,
) -> Result<()> {
write_msg(
send,
&super::Hello {
username,
auth_method: super::Method::Password,
},
)
.await?;
let mut stdout =
unsafe { ManuallyDrop::new(tokio::fs::File::from_raw_fd(libc::STDOUT_FILENO)) };
let mut msg_buf = Vec::new();
loop {
tokio::select! {
r = read_msg::<Question>(recv, &mut msg_buf) => {
match r?? {
Question::LoggedIn => {
return Ok(());
}
Question::Prompt {
prompt,
// TODO: implement this
echo: _,
} => {
stdout.write_all(prompt.to_bytes()).await?;
stdout.write_all(b" ").await?;
let answer = rpassword::read_password()?;
let answer = Zeroizing::new(CString::new(answer)?);
write_msg(send, &Answer::Prompt((*answer).as_ref().into())).await?;
},
//Question::SshAuthRequest(_) => bail!("Received an unexpected question from the server: SshAuthRequest"),
Question::TextInfo(s) => {
stdout.write_all(b"INFO ").await?;
stdout.write_all(s.to_bytes()).await?;
stdout.write_all(b"\n").await?;
},
Question::ErrorMsg(s) => {
stdout.write_all(b"ERRO ").await?;
stdout.write_all(s.to_bytes()).await?;
stdout.write_all(b"\n").await?;
},
}
}
r = send.stopped() => {
info!("Remote disconnected");
let code = r?.into_inner();
if code == 0 {
return Ok(());
} else {
return Err(anyhow!("Error code {}", code));
}
}
e = conn.closed() => {
info!("Remote disconnected: {}", e);
return Err(anyhow!("Remote connection closed"));
}
}
}
}