221 lines
7.5 KiB
Rust
221 lines
7.5 KiB
Rust
use std::{ffi::CString, mem::ManuallyDrop, os::fd::FromRawFd};
|
|
|
|
use anyhow::{Context, Result};
|
|
use quinn::{SendStream, RecvStream};
|
|
use tokio::io::AsyncWriteExt;
|
|
use zeroize::Zeroizing;
|
|
|
|
use crate::{write_msg, read_msg};
|
|
use super::{Question, Answer};
|
|
|
|
/// Authenticates `username` on the server by running a PAM conversation
/// (PAM service name `"sshd"`) and relaying it to the client over `send`/`recv`.
///
/// The PAM context is created and driven on a `tokio::task::spawn_blocking`
/// thread. Its conversation callbacks serialize every [`Question`] into a
/// channel; this task writes each one to the client and, for
/// `Question::Prompt`, reads the client's [`Answer`] and hands it back to the
/// PAM thread through a second channel.
///
/// # Errors
///
/// Returns an error if the PAM context cannot be created, if `authenticate`
/// fails, if `acct_mgmt` fails (except on macOS/iOS, where that failure is only
/// logged), or if writing questions to / reading answers from the client fails.
#[cfg(feature = "server")]
pub async fn server_authenticate(send: &mut SendStream, recv: &mut RecvStream, username: String) -> Result<()> {
    use pam_client::ConversationHandler;

    use crate::Message;

    use super::*;

    // Questions flow PAM thread -> this task; answers flow this task -> PAM thread.
    let (q_send, q_recv) = flume::bounded(1);
    let (a_send, a_recv) = flume::bounded(1);

    /// PAM conversation handler that forwards PAM messages to the async side
    /// and blocks (bounded by `TIMEOUT`) waiting for answers to prompts.
    struct Conversation {
        /// Serialized questions destined for the client.
        send: flume::Sender<Message>,
        /// Answers to `Question::Prompt`, as received from the client.
        recv: flume::Receiver<CString>,
        /// Reason for the most recent conversation failure. The `ErrorCode`
        /// handed back to PAM carries no message, so this text is attached to
        /// the error returned by `server_authenticate` instead.
        error: std::cell::Cell<Option<&'static str>>,
    }

    impl Conversation {
        /// Maximum wait when queueing a question or awaiting an answer.
        const TIMEOUT: std::time::Duration = std::time::Duration::from_secs(60);

        /// Records `msg` as the failure reason and returns `CONV_ERR`.
        fn error(&self, msg: &'static str) -> pam_client::ErrorCode {
            self.error.set(Some(msg));
            pam_client::ErrorCode::CONV_ERR
        }

        /// Serializes `question` and queues it for delivery to the client.
        fn ask(&self, question: Question) -> Result<(), pam_client::ErrorCode> {
            self.send
                .send_timeout(
                    Message::from_value(&question).map_err(|_| self.error("Serialization error"))?,
                    Self::TIMEOUT,
                )
                .map_err(|_| self.error("Ask question timed out"))
        }

        /// Sends `prompt` to the client and blocks until its answer arrives.
        fn prompt(&self, prompt: &CStr, echo: bool) -> Result<CString, pam_client::ErrorCode> {
            self.ask(Question::Prompt {
                prompt: prompt.into(),
                echo,
            })?;
            self.recv
                .recv_timeout(Self::TIMEOUT)
                .map_err(|_| self.error("Wait for answer timed out"))
        }
    }

    impl ConversationHandler for Conversation {
        fn prompt_echo_on(
            &mut self,
            prompt: &CStr,
        ) -> std::result::Result<std::ffi::CString, pam_client::ErrorCode> {
            self.prompt(prompt, true)
        }

        fn prompt_echo_off(
            &mut self,
            prompt: &CStr,
        ) -> std::result::Result<std::ffi::CString, pam_client::ErrorCode> {
            self.prompt(prompt, false)
        }

        // Informational messages are best-effort: a failure to forward them
        // must not abort the PAM transaction.
        fn text_info(&mut self, msg: &CStr) {
            _ = self.ask(Question::TextInfo(msg.into()));
        }

        fn error_msg(&mut self, msg: &CStr) {
            _ = self.ask(Question::ErrorMsg(msg.into()));
        }
    }

    // Any early `?` return inside this closure drops `ctx` and with it both
    // channel ends, which ends the relay loop below.
    let hdl = tokio::task::spawn_blocking(move || {
        let mut ctx = pam_client::Context::new(
            "sshd",
            Some(&username),
            Conversation {
                send: q_send,
                recv: a_recv,
                error: Default::default(),
            },
        )?;
        info!("created context");

        ctx.authenticate(pam_client::Flag::NONE)
            .with_context(|| ctx.conversation_mut().error.take().unwrap_or("Unknown error"))
            .context("authenticate")?;
        info!("authenticated user");

        // Odd acct_mgmt behaviour: it has to be called whether or not the user
        // is currently logged in, but when the user is NOT logged in it can
        // fail spuriously. On macOS/iOS that failure is logged and ignored;
        // on every other platform it is fatal.
        if let Err(e) = ctx.acct_mgmt(pam_client::Flag::NONE).context("acct_mgmt") {
            if cfg!(any(target_os = "macos", target_os = "ios")) {
                warn!("ignoring validation error due to macOS oddity: {}", e);
            } else {
                return Err(e).with_context(|| ctx.conversation_mut().error.take().unwrap_or("Unknown error"));
            }
        }
        info!("validated user");

        /*let sess = ctx
            .open_session(pam_client::Flag::NONE)
            .context("open_session")?;
        info!("opened session");
        let sess = sess.leak();*/

        // Swap both channel ends for dead ones so `q_recv` disconnects and the
        // relay loop terminates, while `ctx` itself is kept alive and returned.
        let conv = ctx.conversation_mut();
        conv.send = flume::bounded(0).0;
        conv.recv = flume::bounded(0).1;
        //Result::<_>::Ok((ctx, sess))
        Result::<_>::Ok((ctx, ()))
    });

    // Relay questions until the PAM thread drops its question sender.
    let mut answer_buf = Vec::new();
    while let Ok(question_msg) = q_recv.recv_async().await {
        let question: Question = question_msg
            .to_value()
            .expect("question was serialized by Conversation::ask and must decode");
        debug!("received question: {:?}", question);
        question_msg.write_ref(send).await?;

        if matches!(question, Question::Prompt { .. }) {
            let answer = read_msg(recv, &mut answer_buf).await??;
            // Never log the answer itself: it is typically the user's password.
            trace!("received answer to prompt");
            let answer = match answer {
                Answer::Prompt(s) => (*s).to_owned(),
                //Answer::SshAuthResponse(_) => bail!("The received answer was unexpected: SshAuthResponse"),
            };
            if a_send.send_async(answer).await.is_err() {
                // The PAM thread stopped listening (e.g. it timed out or failed).
                // Its result, awaited below, explains why far better than a
                // bare "channel closed" error would.
                break;
            }
        }
    }

    let (_ctx, _sess) = hdl.await??;
    /*let sess = ctx.unleak_session(sess);
    let env = sess.envlist();
    let env = env
        .into_iter()
        .filter_map(|pair| {
            let element = pair.as_cstr().to_bytes_with_nul();
            let sep = element
                .iter()
                .position(|b| *b == b'=')
                .unwrap_or(element.len());
            let k = CString::new(&element[..sep]).ok()?;
            let v = CString::new(&element[sep + 1..]).ok()?;
            Some((k, v))
        })
        .collect::<Vec<_>>();*/
    Ok(())
}
|
|
|
|
pub async fn client_authenticate(
|
|
conn: &quinn::Connection,
|
|
send: &mut SendStream,
|
|
recv: &mut RecvStream,
|
|
username: &str,
|
|
) -> Result<()> {
|
|
write_msg(
|
|
send,
|
|
&super::Hello {
|
|
username,
|
|
auth_method: super::Method::Password,
|
|
},
|
|
)
|
|
.await?;
|
|
|
|
let mut stdout =
|
|
unsafe { ManuallyDrop::new(tokio::fs::File::from_raw_fd(libc::STDOUT_FILENO)) };
|
|
|
|
let mut msg_buf = Vec::new();
|
|
loop {
|
|
tokio::select! {
|
|
r = read_msg::<Question>(recv, &mut msg_buf) => {
|
|
match r?? {
|
|
Question::LoggedIn => {
|
|
return Ok(());
|
|
}
|
|
Question::Prompt {
|
|
prompt,
|
|
// TODO: implement this
|
|
echo: _,
|
|
} => {
|
|
stdout.write_all(prompt.to_bytes()).await?;
|
|
stdout.write_all(b" ").await?;
|
|
let answer = rpassword::read_password()?;
|
|
let answer = Zeroizing::new(CString::new(answer)?);
|
|
write_msg(send, &Answer::Prompt((*answer).as_ref().into())).await?;
|
|
},
|
|
//Question::SshAuthRequest(_) => bail!("Received an unexpected question from the server: SshAuthRequest"),
|
|
Question::TextInfo(s) => {
|
|
stdout.write_all(b"INFO ").await?;
|
|
stdout.write_all(s.to_bytes()).await?;
|
|
stdout.write_all(b"\n").await?;
|
|
},
|
|
Question::ErrorMsg(s) => {
|
|
stdout.write_all(b"ERRO ").await?;
|
|
stdout.write_all(s.to_bytes()).await?;
|
|
stdout.write_all(b"\n").await?;
|
|
},
|
|
}
|
|
}
|
|
r = send.stopped() => {
|
|
info!("Remote disconnected");
|
|
let code = r?.into_inner();
|
|
if code == 0 {
|
|
return Ok(());
|
|
} else {
|
|
return Err(anyhow!("Error code {}", code));
|
|
}
|
|
}
|
|
e = conn.closed() => {
|
|
info!("Remote disconnected: {}", e);
|
|
return Err(anyhow!("Remote connection closed"));
|
|
}
|
|
}
|
|
}
|
|
}
|