Compare commits

...

2 Commits

Author SHA1 Message Date
Michael Pfaff 413e0eaea5
Revert back to rmp_serde::to_vec 2023-06-11 00:15:45 -04:00
Michael Pfaff 7b50417ba7
WIP public key authentication
- Implemented public key authentication
    - TODO: figure out key selection (I refuse to resort to sending all
      public keys to the server)
- Refactoring
2023-06-11 00:10:07 -04:00
6 changed files with 1241 additions and 350 deletions

718
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -4,7 +4,14 @@ version = "0.1.0"
edition = "2021"
[features]
server = ["dep:pam-client", "dep:pam-client-macos"]
default = ["ed25519", "ecdsa", "rsa"]
#server = ["dep:pam-client", "dep:pam-client-macos"]
server = ["dep:pam-client"]
ed25519 = ["ssh-key/ed25519"]
ecdsa = ["ssh-key/ecdsa"]
rsa = ["ssh-key/rsa"]
[dependencies]
anyhow = "1.0.71"
@ -17,21 +24,27 @@ nix = "0.26.2"
parking_lot = "0.12.1"
pin-project-lite = "0.2.9"
quinn = "0.10.1"
rand = "0.8.5"
rcgen = "0.10.0"
rmp = "0.8.11"
rmp-serde = "1.1.1"
rmpv = "1.0.0"
rpassword = "7.2.0"
rustls = { version = "0.21.1", default-features = false, features = ["dangerous_configuration"] }
rustls-webpki = "0.100.1"
serde = { version = "1.0.163", features = ["derive"] }
serde_with = { version = "3.0.0", default-features = false, features = ["std"] }
sha2 = "0.10.6"
ssh-key = { version = "0.5.1", default-features = false, features = ["std", "encryption"] }
#termion = "2.0.1"
tokio = { version = "1.28.2", default-features = false, features = ["rt-multi-thread", "macros", "process", "io-util", "io-std", "time", "fs", "signal"] }
tracing = { version = "0.1.37", features = ["max_level_debug", "release_max_level_info"] }
tracing-subscriber = { version = "0.3.17", features = ["env-filter"] }
triggered = "0.1.2"
zeroize = { version = "1.6.0", features = ["std"] }
[target.'cfg(not(target_os = "macos"))'.dependencies]
pam-client = { version = "0.5.0", default-features = false, features = ["serde"], optional = true }
[target.'cfg(target_os = "macos")'.dependencies]
pam-client-macos = { package = "pam-client", version = "0.5.0", path = "../../../../../Users/michael/b/rust-pam-client", default-features = false, features = ["serde"], optional = true }
#pam-client-macos = { package = "pam-client", version = "0.5.0", path = "../../../../../Users/michael/b/rust-pam-client", default-features = false, features = ["serde"], optional = true }

86
src/auth/mod.rs Normal file
View File

@ -0,0 +1,86 @@
pub mod password;
pub mod ssh_key;
use std::ffi::CStr;
#[derive(Debug, Serialize, Deserialize)]
pub struct Hello<'a> {
#[serde(borrow)]
pub username: &'a str,
#[serde(borrow)]
pub auth_method: Method<'a>,
}
#[derive(Debug, Serialize, Deserialize)]
pub enum Method<'a> {
Password,
SshKey {
/// [Public key](::ssh_key::public::PublicKey).
#[serde(borrow)]
public_key: &'a str,
},
}
mod cstr_as_bytes {
use std::ffi::CStr;
use serde::{Serializer, Deserializer, Deserialize};
pub fn serialize<S>(s: &CStr, ser: S) -> Result<S::Ok, S::Error> where S: Serializer {
ser.serialize_bytes(s.to_bytes_with_nul())
}
pub fn deserialize<'de, D>(de: D) -> Result<&'de CStr, D::Error> where D: Deserializer<'de> {
use serde::de::Error;
let b = <&'de [u8]>::deserialize(de)?;
CStr::from_bytes_with_nul(b)
.map_err(|_| D::Error::invalid_value(serde::de::Unexpected::Bytes(b), &"a sequence of bytes ending with NUL"))
}
}
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
#[serde(transparent)]
struct ByteSlice<'a>(#[serde(serialize_with = "serialize_bytes")] &'a [u8]);
pub fn serialize_bytes<S>(b: &[u8], ser: S) -> Result<S::Ok, S::Error> where S: serde::Serializer {
ser.serialize_bytes(b)
}
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
#[serde(transparent)]
pub struct CStrAsBytes<'a>(#[serde(serialize_with = "cstr_as_bytes::serialize", deserialize_with = "cstr_as_bytes::deserialize", borrow)] &'a CStr);
impl<'a> std::ops::Deref for CStrAsBytes<'a> {
type Target = CStr;
fn deref(&self) -> &Self::Target {
self.0
}
}
impl<'a> From<&'a CStr> for CStrAsBytes<'a> {
fn from(value: &'a CStr) -> Self {
Self(value)
}
}
impl<'a> From<CStrAsBytes<'a>> for &'a CStr {
fn from(value: CStrAsBytes<'a>) -> Self {
value.0
}
}
#[derive(Debug, Serialize, Deserialize)]
pub enum Question<'a> {
Prompt { #[serde(borrow)] prompt: CStrAsBytes<'a>, echo: bool },
TextInfo(#[serde(borrow)] CStrAsBytes<'a>),
ErrorMsg(#[serde(borrow)] CStrAsBytes<'a>),
//SshAuthRequest(ssh_key::AuthRequest<'a>),
LoggedIn,
}
#[derive(Debug, Serialize, Deserialize)]
enum Answer<'a> {
Prompt(#[serde(borrow)] CStrAsBytes<'a>),
//SshAuthResponse(ssh_key::AuthResponse<'a>),
}

219
src/auth/password.rs Normal file
View File

@ -0,0 +1,219 @@
use std::{ffi::CString, mem::ManuallyDrop, os::fd::FromRawFd};
use anyhow::{Context, Result};
use quinn::{SendStream, RecvStream};
use tokio::io::AsyncWriteExt;
use zeroize::Zeroizing;
use crate::{write_msg, read_msg};
use super::{Question, Answer};
#[cfg(feature = "server")]
pub async fn server_authenticate(send: &mut SendStream, recv: &mut RecvStream, username: String) -> Result<()> {
use pam_client::ConversationHandler;
use crate::Message;
use super::*;
let (q_send, q_recv) = flume::bounded(1);
let (a_send, a_recv) = flume::bounded(1);
struct Conversation {
send: flume::Sender<Message>,
recv: flume::Receiver<CString>,
error: std::cell::Cell<Option<&'static str>>,
}
impl Conversation {
const TIMEOUT: std::time::Duration = std::time::Duration::from_secs(60);
fn error(&self, msg: &'static str) -> pam_client::ErrorCode {
self.error.set(Some(msg));
pam_client::ErrorCode::CONV_ERR
}
fn ask(&self, question: Question) -> Result<(), pam_client::ErrorCode> {
self.send
.send_timeout(Message::from_value(&question)
.map_err(|_| self.error("Serialization error"))?, Self::TIMEOUT)
.map_err(|_| self.error("Ask question timed out"))
}
fn prompt(&self, prompt: &CStr, echo: bool) -> Result<CString, pam_client::ErrorCode> {
self.ask(Question::Prompt {
prompt: prompt.into(),
echo,
})?;
self.recv
.recv_timeout(Self::TIMEOUT)
.map_err(|_| self.error("Wait for answer timed out"))
}
}
impl ConversationHandler for Conversation {
fn prompt_echo_on(
&mut self,
prompt: &CStr,
) -> std::result::Result<std::ffi::CString, pam_client::ErrorCode> {
self.prompt(prompt, true)
}
fn prompt_echo_off(
&mut self,
prompt: &CStr,
) -> std::result::Result<std::ffi::CString, pam_client::ErrorCode> {
self.prompt(prompt, false)
}
fn text_info(&mut self, msg: &CStr) {
_ = self.ask(Question::TextInfo(msg.into()));
}
fn error_msg(&mut self, msg: &CStr) {
_ = self.ask(Question::ErrorMsg(msg.into()));
}
}
let hdl = tokio::task::spawn_blocking(move || {
let mut ctx = pam_client::Context::new(
"sshd",
Some(&username),
Conversation {
send: q_send,
recv: a_recv,
error: Default::default(),
},
)?;
info!("created context");
ctx.authenticate(pam_client::Flag::NONE)
.with_context(|| ctx.conversation_mut().error.take().unwrap_or("Unknown error"))
.context("authenticate")?;
info!("authenticated user");
// very odd behaviour:
// if user is currently logged in, we must call acct_mgmt first.
// if user is not currently logged in, we must call acc_mgmt first and ignore the error.
if let Err(e) = ctx.acct_mgmt(pam_client::Flag::NONE).context("acct_mgmt") {
if cfg!(any(target_os = "macos", target_os = "ios")) {
warn!("ignoring validation error due to macOS oddity: {}", e);
} else {
return Err(e).with_context(|| ctx.conversation_mut().error.take().unwrap_or("Unknown error"));
}
}
info!("validated user");
/*let sess = ctx
.open_session(pam_client::Flag::NONE)
.context("open_session")?;
info!("opened session");
let sess = sess.leak();*/
let conv = ctx.conversation_mut();
conv.send = flume::bounded(0).0;
conv.recv = flume::bounded(0).1;
//Result::<_>::Ok((ctx, sess))
Result::<_>::Ok((ctx, ()))
});
let mut answer_buf = Vec::new();
while let Ok(question_msg) = q_recv.recv_async().await {
let question: Question = question_msg.to_value().unwrap();
debug!("received question: {:?}", question);
question_msg.write_ref(send).await?;
if matches!(question, Question::Prompt { .. }) {
let answer = read_msg(recv, &mut answer_buf).await??;
trace!("received answer: {:?}", answer);
let answer = match answer {
Answer::Prompt(s) => (*s).to_owned(),
//Answer::SshAuthResponse(_) => bail!("The received answer was unexpected: SshAuthResponse"),
};
a_send.send_async(answer).await?;
}
}
let (_ctx, _sess) = hdl.await??;
/*let sess = ctx.unleak_session(sess);
let env = sess.envlist();
let env = env
.into_iter()
.filter_map(|pair| {
let element = pair.as_cstr().to_bytes_with_nul();
let sep = element
.iter()
.position(|b| *b == b'=')
.unwrap_or(element.len());
let k = CString::new(&element[..sep]).ok()?;
let v = CString::new(&element[sep + 1..]).ok()?;
Some((k, v))
})
.collect::<Vec<_>>();*/
Ok(())
}
pub async fn client_authenticate(
conn: &quinn::Connection,
send: &mut SendStream,
recv: &mut RecvStream,
username: &str,
) -> Result<()> {
write_msg(
send,
&super::Hello {
username,
auth_method: super::Method::Password,
},
)
.await?;
let mut stdout =
unsafe { ManuallyDrop::new(tokio::fs::File::from_raw_fd(libc::STDOUT_FILENO)) };
let mut msg_buf = Vec::new();
loop {
tokio::select! {
r = read_msg::<Question>(recv, &mut msg_buf) => {
match r?? {
Question::LoggedIn => {
return Ok(());
}
Question::Prompt {
prompt,
echo,
} => {
stdout.write_all(prompt.to_bytes()).await?;
stdout.write_all(b" ").await?;
let answer = rpassword::read_password()?;
let answer = Zeroizing::new(CString::new(answer)?);
write_msg(send, &Answer::Prompt((*answer).as_ref().into())).await?;
},
//Question::SshAuthRequest(_) => bail!("Received an unexpected question from the server: SshAuthRequest"),
Question::TextInfo(s) => {
stdout.write_all(b"INFO ").await?;
stdout.write_all(s.to_bytes()).await?;
stdout.write_all(b"\n").await?;
},
Question::ErrorMsg(s) => {
stdout.write_all(b"ERRO ").await?;
stdout.write_all(s.to_bytes()).await?;
stdout.write_all(b"\n").await?;
},
}
}
r = send.stopped() => {
info!("Remote disconnected");
let code = r?.into_inner();
if code == 0 {
return Ok(());
} else {
return Err(anyhow!("Error code {}", code));
}
}
e = conn.closed() => {
info!("Remote disconnected: {}", e);
return Err(anyhow!("Remote connection closed"));
}
}
}
}

119
src/auth/ssh_key.rs Normal file
View File

@ -0,0 +1,119 @@
use anyhow::{Context, Result};
use nix::unistd::Uid;
use quinn::{SendStream, RecvStream};
use zeroize::Zeroizing;
use crate::{read_msg, write_msg, ClientConfig, Message};
use super::Question;
#[derive(Debug, Serialize, Deserialize)]
pub struct AuthRequest<'a> {
#[serde(serialize_with = "super::serialize_bytes")]
pub nonce: &'a [u8],
}
impl<'a> AuthRequest<'a> {
pub const NAMESPACE: &str = "QUINOA AUTHENTICATION";
pub fn new(nonce: &'a [u8]) -> AuthRequest<'a> {
Self {
nonce,
}
}
}
#[derive(Debug, Serialize, Deserialize)]
pub struct AuthResponse<'a> {
#[serde(borrow)]
signature: &'a str,
}
#[cfg(feature = "server")]
async fn server_authorize_key(user_id: Uid, public_key: &ssh_key::public::PublicKey) -> Result<()> {
const PATH: &str = "/.ssh/authorized_keys";
let mut user_passwd_buf = Vec::new();
let user_passwd = crate::passwd::Passwd::from_uid(user_id, &mut user_passwd_buf)?
.context("No passwd entry for user")?;
let home = user_passwd.dir.to_str()?;
let mut authorized_keys = String::with_capacity(home.len() + PATH.len());
authorized_keys.push_str(home);
authorized_keys.push_str(PATH);
let authorized_keys = tokio::fs::read_to_string(authorized_keys).await?;
let mut authorized_keys = ssh_key::authorized_keys::AuthorizedKeys::new(&authorized_keys);
while let Some(r) = authorized_keys.next() {
let entry = r?;
if entry.public_key().key_data() == public_key.key_data() {
return Ok(());
}
}
Err(anyhow!("Key provided by the client is not authorized to connect"))
}
#[cfg(feature = "server")]
pub async fn server_authenticate(send: &mut SendStream, recv: &mut RecvStream, user_id: Uid, public_key: &str) -> Result<()> {
use rand::RngCore;
let public_key = ssh_key::public::PublicKey::from_openssh(public_key)?;
server_authorize_key(user_id, &public_key).await?;
let mut nonce = vec![0; 256];
rand::thread_rng().try_fill_bytes(&mut nonce)?;
let request = AuthRequest::new(&nonce);
let request = Message::from_value(&request)?;
request.write_ref(send).await?;
let mut buf = Vec::new();
let response = read_msg::<AuthResponse>(recv, &mut buf).await??;
let sig = ssh_key::SshSig::from_pem(&response.signature)?;
public_key.verify(AuthRequest::NAMESPACE, &request.0, &sig)?;
Ok(())
}
pub async fn client_authenticate(cfg: &ClientConfig, send: &mut SendStream, recv: &mut RecvStream, username: &str) -> Result<()> {
let private_key = Zeroizing::new(tokio::fs::read_to_string(&cfg.ssh_key_file).await?);
let private_key = ssh_key::private::PrivateKey::from_openssh(private_key)
.context("Loading private key")?;
let public_key = private_key.public_key();
let public_key = public_key.to_openssh()
.context("Encoding public key")?;
info!("loaded private key: {:?}", private_key.algorithm());
write_msg(send, &super::Hello {
username,
auth_method: super::Method::SshKey {
public_key: &public_key,
},
}).await?;
info!("sent hello");
let mut buf = Vec::new();
_ = read_msg::<AuthRequest>(recv, &mut buf).await??;
info!("read auth request");
let private_key = if private_key.is_encrypted() {
let password = Zeroizing::new(rpassword::prompt_password("Enter the key's passphrase: ")?);
private_key.decrypt(&password).context("Incorrect passphrase")?
} else {
private_key
};
let sig = private_key.sign(AuthRequest::NAMESPACE, ssh_key::HashAlg::Sha512, &buf)?;
let sig = sig.to_pem(ssh_key::LineEnding::LF)?;
info!("signed auth request");
write_msg(send, &AuthResponse { signature: &sig }).await?;
info!("wrote auth request");
let Question::LoggedIn = read_msg(recv, &mut buf).await?? else {
bail!("Received an unexpected question")
};
Ok(())
}

View File

@ -8,6 +8,7 @@ extern crate serde;
#[macro_use]
extern crate tracing;
mod auth;
mod io_util;
mod passwd;
mod pty;
@ -35,8 +36,6 @@ use std::task::Poll;
use anyhow::{Context, Result};
use base64::Engine as _;
use nix::unistd::Uid;
#[cfg(feature = "server")]
use pam_client::ConversationHandler;
use quinn::{ReadExactError, RecvStream, SendStream};
use rustls::client::ServerCertVerifier;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
@ -220,13 +219,14 @@ async fn run_cmd(mut args: std::env::Args) -> Result<()> {
const ALPN_QUINOA: &str = "quinoa";
struct ServerConfig {
pub struct ServerConfig {
listen: SocketAddr,
}
struct ClientConfig {
pub struct ClientConfig {
known_hosts_file: String,
known_hosts: parking_lot::Mutex<Vec<KnownHost<'static>>>,
ssh_key_file: String,
}
#[cfg(feature = "server")]
@ -239,7 +239,8 @@ async fn run_server() -> Result<()> {
&*Box::leak(ServerConfig { listen: opt_listen }.into())
};
let subject_alt_names = vec!["localhost".to_string()];
//let subject_alt_names = vec!["localhost".to_string()];
let subject_alt_names = vec![];
let (cert, key) = if !std::path::Path::new("cert.der").exists()
|| !std::path::Path::new("key.der").exists()
@ -313,8 +314,10 @@ impl fmt::Display for FinishedEarly {
impl std::error::Error for FinishedEarly {}
async fn read_msg<T: serde::de::DeserializeOwned>(
/// Reads a message `T`. `buf` will be cleared before reading.
async fn read_msg<'a, T: serde::Deserialize<'a>>(
recv: &mut RecvStream,
buf: &'a mut Vec<u8>,
) -> Result<Result<T, FinishedEarly>> {
let mut size = [0u8; 2];
match recv.read_exact(&mut size).await {
@ -323,19 +326,59 @@ async fn read_msg<T: serde::de::DeserializeOwned>(
Err(ReadExactError::ReadError(e)) => return Err(e.into()),
}
let size = u16::from_le_bytes(size);
let mut buf = Vec::with_capacity(size.into());
recv.take(size.into()).read_to_end(&mut buf).await?;
Ok(Ok(rmp_serde::from_slice(&buf).with_context(|| {
format!("reading a {} byte message", size)
})?))
buf.clear();
buf.reserve(size.into());
recv.take(size.into()).read_to_end(buf).await?;
Message::raw_to_value(buf).map(Ok)
}
struct Message(Vec<u8>);
impl Message {
pub fn from_value<T: serde::Serialize>(value: &T) -> Result<Self> {
Ok(Self(rmp_serde::to_vec(value)?))
}
pub fn from_raw(data: Vec<u8>) -> Self {
Self(data)
}
async fn write_len(&self, send: &mut SendStream) -> Result<()> {
send.write_all(&u16::try_from(self.0.len())?.to_le_bytes())
.await
.map_err(|e| e.into())
}
pub async fn write_ref(&self, send: &mut SendStream) -> Result<()> {
self.write_len(send).await?;
send.write_all(&self.0).await?;
Ok(())
}
pub async fn write(self, send: &mut SendStream) -> Result<()> {
self.write_len(send).await?;
send.write_chunk(self.0.into()).await?;
Ok(())
}
fn raw_to_value<'de, T: serde::Deserialize<'de>>(data: &'de [u8]) -> Result<T> {
rmp_serde::from_slice(data)
.with_context(|| {
format!("reading a {} byte message", data.len())
})
.with_context(|| {
let mut data = data;
format!("{:?}", rmpv::decode::read_value_ref(&mut data))
})
}
pub fn to_value<'de, T: serde::Deserialize<'de>>(&'de self) -> Result<T> {
Self::raw_to_value(&self.0)
}
}
async fn write_msg<T: serde::Serialize>(send: &mut SendStream, value: &T) -> Result<()> {
let buf = rmp_serde::to_vec(value)?;
send.write_all(&u16::try_from(buf.len())?.to_le_bytes())
.await?;
send.write_all(&buf).await?;
Ok(())
Message::from_value(value)?.write(send).await
}
struct InformedServerCertVerifier {
@ -364,8 +407,14 @@ impl InformedServerCertVerifier {
let hash = Hash::new(&end_entity.0);
eprintln!(
"The authenticity of host {:?} can't be established.",
subject_name
"The authenticity of host {} can't be established.",
match subject_name {
SubjectNameRef::DnsName(name) => name.into(),
SubjectNameRef::IpAddress(ip) => std::str::from_utf8(match ip {
webpki::IpAddrRef::V4(b, _) => b,
webpki::IpAddrRef::V6(b, _) => b,
}).map_err(|e| CertificateError::Other(Arc::new(e)))?,
}
);
eprintln!("Certificate hash is {}", hash);
if let Some(known_as) = known_as.and_then(|i| known_hosts.get(i)) {
@ -479,8 +528,10 @@ impl ServerCertVerifier for InformedServerCertVerifier {
_ => return Err(CertificateError::NotValidForName.into()),
};
cert.verify_is_valid_for_subject_name(subject_name)
.map_err(pki_error)?;
// TODO: is expiry checked for us?
/*cert.verify_is_valid_for_subject_name(subject_name)
.map_err(pki_error)?;*/
let mut known_hosts = self.cfg.known_hosts.lock();
if let Some((h_index, h)) = Self::find_known_host(&known_hosts, subject_name) {
@ -648,8 +699,8 @@ fn write_known_hosts(file: &str, hosts: &[KnownHost<'_>]) -> Result<()> {
async fn run_client(mut args: std::env::Args) -> Result<()> {
info!("running client");
let mut cfg_dir = std::env::var("HOME")?;
cfg_dir.push_str("/.config/quinoa");
let home_dir = std::env::var("HOME")?;
let cfg_dir = format!("{}/.config/quinoa", home_dir);
tokio::fs::create_dir_all(&cfg_dir).await?;
let known_hosts_file = format!("{}/known_hosts", cfg_dir);
let known_hosts = if std::path::Path::new(&known_hosts_file).exists() {
@ -665,19 +716,26 @@ async fn run_client(mut args: std::env::Args) -> Result<()> {
} else {
Vec::new()
};
let ssh_key_file = format!("{}/.ssh/id_ed25519", home_dir);
//let ssh_key_file = format!("{}/.ssh/id_rsa", home_dir);
let cfg = &*Box::leak(
ClientConfig {
known_hosts_file,
known_hosts: known_hosts.into(),
ssh_key_file,
}
.into(),
);
let mut conn_str = None;
let mut forwards = Vec::new();
let mut use_key = true;
while let Some(arg) = args.next() {
if let Some(arg) = arg.strip_prefix('-') {
match arg {
"-no-key" => {
use_key = false;
}
"-forward" => {
let v = args.next().context("Expected a value of the form LOCAL_ADDR:LOCAL_PORT[->|<-]REMOTE_ADDR:REMOTE_PORT (and /[tcp|udp] on bind side)")?;
forwards.push(v.parse::<ForwardSpec>()?);
@ -695,6 +753,8 @@ async fn run_client(mut args: std::env::Args) -> Result<()> {
.as_ref()
.and_then(|s| s.split_once('@'))
.context("Expected an argument of the form USERNAME@HOST")?;
let (host_name, port) = host.split_once(':').unwrap_or((host, "8022"));
let port = port.parse::<u16>()?;
let mut client_crypto = rustls::ClientConfig::builder()
.with_safe_defaults()
@ -714,20 +774,17 @@ async fn run_client(mut args: std::env::Args) -> Result<()> {
info!("connecting");
let conn = endpoint.connect(host.parse()?, "localhost")?.await?;
let conn = endpoint.connect(host.parse()?, host_name)?.await?;
// authenticating client
{
let (mut send, mut recv) = conn.open_bi().await?;
write_msg(
&mut send,
&auth::Hello {
username: username.to_owned(),
},
)
.await?;
do_auth_prompt(&conn, &mut send, &mut recv).await?;
if use_key {
auth::ssh_key::client_authenticate(cfg, &mut send, &mut recv, username).await?;
} else {
auth::password::client_authenticate(&conn, &mut send, &mut recv, username).await?;
}
}
// authenticated client
@ -954,63 +1011,6 @@ async fn do_shell(conn: &quinn::Connection) -> Result<()> {
}
}
async fn do_auth_prompt(
conn: &quinn::Connection,
send: &mut SendStream,
recv: &mut RecvStream,
) -> Result<()> {
use auth::*;
let mut stdout =
unsafe { ManuallyDrop::new(tokio::fs::File::from_raw_fd(libc::STDOUT_FILENO)) };
loop {
tokio::select! {
r = read_msg::<Question>(recv) => {
match r?? {
Question::LoggedIn => {
return Ok(());
}
Question::Prompt {
prompt,
echo,
} => {
let mut prompt = prompt.into_bytes();
prompt.push(b' ');
stdout.write_all(&prompt).await?;
let answer = rpassword::read_password()?;
let answer = CString::new(answer)?;
write_msg(send, &Answer::Prompt(answer)).await?;
},
Question::TextInfo(s) => {
stdout.write_all(b"INFO ").await?;
stdout.write_all(s.as_bytes()).await?;
stdout.write_all(b"\n").await?;
},
Question::ErrorMsg(s) => {
stdout.write_all(b"ERRO ").await?;
stdout.write_all(s.as_bytes()).await?;
stdout.write_all(b"\n").await?;
},
}
}
r = send.stopped() => {
info!("Remote disconnected");
let code = r?.into_inner();
if code == 0 {
return Ok(());
} else {
return Err(anyhow!("Error code {}", code));
}
}
e = conn.closed() => {
info!("Remote disconnected: {}", e);
return Err(anyhow!("Remote connection closed"));
}
}
}
}
#[cfg(feature = "server")]
async fn greet_conn(cfg: &'static ServerConfig, conn: quinn::Connecting) -> Result<()> {
info!("greeting connection");
@ -1046,250 +1046,31 @@ async fn greet_conn(cfg: &'static ServerConfig, conn: quinn::Connecting) -> Resu
Ok(())
}
mod auth {
use std::ffi::CString;
#[derive(Debug, Serialize, Deserialize)]
pub struct Hello {
pub username: String,
}
#[derive(Debug, Serialize, Deserialize)]
pub enum Question {
Prompt { prompt: CString, echo: bool },
TextInfo(CString),
ErrorMsg(CString),
LoggedIn,
}
#[derive(Debug, Serialize, Deserialize)]
pub enum Answer {
Prompt(CString),
}
}
#[cfg(feature = "server")]
async fn authenticate_conn(
cfg: &'static ServerConfig,
conn: &quinn::Connection,
) -> Result<(UserInfo, Vec<(CString, CString)>)> {
use auth::*;
info!("authenticating connection");
let (mut send, mut recv) = conn.accept_bi().await?;
let hello = read_msg::<Hello>(&mut recv).await??;
let mut hello_buf = Vec::new();
let hello = read_msg::<auth::Hello>(&mut recv, &mut hello_buf).await??;
let (q_send, q_recv) = flume::bounded(1);
let (a_send, a_recv) = flume::bounded(1);
let user_info = user_info::get_user_info(&hello.username).await?;
struct Conversation {
send: flume::Sender<Question>,
recv: flume::Receiver<Answer>,
match hello.auth_method {
auth::Method::Password => auth::password::server_authenticate(&mut send, &mut recv, hello.username.to_owned()).await?,
auth::Method::SshKey { public_key } => auth::ssh_key::server_authenticate(&mut send, &mut recv, user_info.user.id, public_key).await?,
}
impl Conversation {
const TIMEOUT: std::time::Duration = std::time::Duration::from_secs(60);
fn ask(&self, question: Question) -> Result<(), pam_client::ErrorCode> {
self.send
.send_timeout(question, Self::TIMEOUT)
.map_err(|_| pam_client::ErrorCode::ABORT)
}
fn answer(&self) -> Result<Answer, pam_client::ErrorCode> {
self.recv
.recv_timeout(Self::TIMEOUT)
.map_err(|_| pam_client::ErrorCode::ABORT)
}
}
impl ConversationHandler for Conversation {
fn prompt_echo_on(
&mut self,
prompt: &CStr,
) -> std::result::Result<std::ffi::CString, pam_client::ErrorCode> {
self.ask(Question::Prompt {
prompt: prompt.to_owned(),
echo: true,
})?;
match self.answer()? {
Answer::Prompt(s) => Ok(s),
}
}
fn prompt_echo_off(
&mut self,
prompt: &CStr,
) -> std::result::Result<std::ffi::CString, pam_client::ErrorCode> {
self.ask(Question::Prompt {
prompt: prompt.to_owned(),
echo: false,
})?;
match self.answer()? {
Answer::Prompt(s) => Ok(s),
}
}
fn text_info(&mut self, msg: &CStr) {
_ = self.ask(Question::TextInfo(msg.to_owned()));
}
fn error_msg(&mut self, msg: &CStr) {
_ = self.ask(Question::ErrorMsg(msg.to_owned()));
}
}
let username = hello.username.clone();
let hdl = tokio::task::spawn_blocking(move || {
let mut ctx = pam_client::Context::new(
"sshd",
Some(&username),
Conversation {
send: q_send,
recv: a_recv,
},
)?;
info!("created context");
ctx.authenticate(pam_client::Flag::NONE)
.context("authenticate")?;
info!("authenticated user");
// very odd behaviour:
// if user is currently logged in, we must call acct_mgmt first.
// if user is not currently logged in, we must call acc_mgmt first and ignore the error.
if let Err(e) = ctx.acct_mgmt(pam_client::Flag::NONE).context("acct_mgmt") {
if cfg!(any(target_os = "macos", target_os = "ios")) {
warn!("ignoring validation error due to macOS oddity: {}", e);
} else {
return Err(e);
}
}
info!("validated user");
/*let sess = ctx
.open_session(pam_client::Flag::NONE)
.context("open_session")?;
info!("opened session");
let sess = sess.leak();*/
let conv = ctx.conversation_mut();
conv.send = flume::bounded(0).0;
conv.recv = flume::bounded(0).1;
//Result::<_>::Ok((ctx, sess))
Result::<_>::Ok((ctx, ()))
});
while let Ok(question) = q_recv.recv_async().await {
debug!("received question: {:?}", question);
write_msg(&mut send, &question).await?;
if matches!(question, Question::Prompt { .. }) {
let answer = read_msg(&mut recv).await??;
trace!("received answer: {:?}", answer);
a_send.send_async(answer).await?;
}
/*match question {
Question::Prompt { prompt, echo } => {
let r = async {
// FIXME: actually disable echo
send.write_all(prompt.as_bytes()).await?;
send.write_all(b" ").await?;
let erase = format!("\x1b[{}G\x1b[K", prompt.as_bytes().len() + 1 + 1);
let mut buf = Vec::new();
'prompt: loop {
let mut i = buf.len();
recv.read_buf(&mut buf).await?;
let mut j = i;
while j < buf.len() {
match buf[j] {
0x7f => {
buf.remove(j);
if j > 0 {
if j == i {
i -= 1;
}
buf.remove(j-1);
// erase in line, move cursor left 1 column
send.write_all(erase.as_bytes()).await?;
send.write_all(&buf).await?;
j -= 1;
}
}
0x3 => {
send.write_all(b"\r\n").await?;
return Err(anyhow!("Aborted by the user"));
}
b'\r' => {
buf.remove(j);
}
b'\n' => {
info!("found \\n");
// remove newline and trailing chars
buf.truncate(j);
break 'prompt;
}
_ => {
j += 1;
}
}
}
let seg = &buf[i..];
if echo {
send.write_all(&seg).await?;
} else {
send.write_all(&vec![b'*'; seg.len()]).await?;
}
info!("{:?} ({:x?})", std::str::from_utf8(&buf), buf);
}
let buf = CString::new(buf)?;
send.write_all(b"\n").await?;
Result::<_>::Ok(buf)
}.await;
a_send.send_async(Answer::Prompt(r.map_err(|e| {
error!("PAM error: {}", e);
pam_client::ErrorCode::ABORT
}))).await?;
}
Question::TextInfo(s) => {
send.write_all(b"INFO ").await?;
send.write_all(s.as_bytes()).await?;
send.write_all(b"\n").await?;
}
Question::ErrorMsg(s) => {
send.write_all(b"ERRO ").await?;
send.write_all(s.as_bytes()).await?;
send.write_all(b"\n").await?;
}
}*/
}
let (mut ctx, sess) = hdl.await??;
/*let sess = ctx.unleak_session(sess);
let env = sess.envlist();
let env = env
.into_iter()
.filter_map(|pair| {
let element = pair.as_cstr().to_bytes_with_nul();
let sep = element
.iter()
.position(|b| *b == b'=')
.unwrap_or(element.len());
let k = CString::new(&element[..sep]).ok()?;
let v = CString::new(&element[sep + 1..]).ok()?;
Some((k, v))
})
.collect::<Vec<_>>();*/
let env = Vec::new();
info!("logged in");
write_msg(&mut send, &Question::LoggedIn).await?;
write_msg(&mut send, &auth::Question::LoggedIn).await?;
send.finish().await?;
recv.stop(0u8.into())?;
let user_info = user_info::get_user_info(&hello.username).await?;
Ok((user_info, env))
}
@ -1302,6 +1083,7 @@ async fn handle_conn(
) -> Result<()> {
info!("established");
let mut stream_buf = Vec::new();
loop {
let stream = conn.accept_bi().await;
let (send, mut recv) = match stream {
@ -1315,7 +1097,7 @@ async fn handle_conn(
Ok(s) => s,
};
let stream = read_msg::<Stream>(&mut recv).await??;
let stream = read_msg::<Stream>(&mut recv, &mut stream_buf).await??;
let span = info_span!(
"stream",
r#type = ?stream
@ -1522,22 +1304,22 @@ async fn handle_stream_shell(
extern "C" {
#[cfg(any(target_os = "macos", target_os = "ios"))]
pub fn _NSGetEnviron() -> *mut *const *const std::os::raw::c_char;
pub fn _NSGetEnviron() -> *mut *const *const std::ffi::c_char;
#[cfg(not(any(target_os = "macos", target_os = "ios")))]
static mut environ: *const *const c_char;
static mut environ: *const *const std::ffi::c_char;
}
let mut keep = Vec::new();
const PRESERVE_ENV: &[&[u8]] = &[b"PATH\0", b"LANG\0"];
unsafe {
#[cfg(any(target_os = "macos", target_os = "ios"))]
let mut environ = *_NSGetEnviron();
let mut env = *_NSGetEnviron();
#[cfg(not(any(target_os = "macos", target_os = "ios")))]
let mut environ = environ;
if !environ.is_null() {
while !(*environ).is_null() {
let key_value = CStr::from_ptr(*environ).to_bytes_with_nul();
environ = environ.add(1);
let mut env = environ;
if !env.is_null() {
while !(*env).is_null() {
let key_value = CStr::from_ptr(*env).to_bytes_with_nul();
env = env.add(1);
let Some(i) = key_value.iter().position(|b| *b == b'=') else {
continue
};
@ -1583,7 +1365,7 @@ async fn handle_stream_shell(
target_os = "redox",
target_os = "haiku"
)))]
nix::unistd::setgroups(&user_info.groups.into_iter().map(|g| g.id).collect())
nix::unistd::setgroups(&user_info.groups.iter().map(|g| g.id).collect::<Vec<_>>())
.context("setting supplementary groups")?;
nix::unistd::setgid(user_info.group.id).context("setting primary group")?;
@ -1737,7 +1519,7 @@ async fn handle_stream_shell(
let r = pty.read_buf(buf).await;
//_ = waker_recv.try_recv();
if let Err(e) = r {
if e.raw_os_error() == Some(35) {
if e.raw_os_error() == Some(35) || e.raw_os_error() == Some(11) {
//debug!("not ready: {}", e);
//tokio::task::yield_now().await;
//tokio::time::sleep(std::time::Duration::from_millis(1)).await;