From e3d610e9ad3ac2dc12951bebe94c55e7dfa21dc9 Mon Sep 17 00:00:00 2001 From: Michael Pfaff Date: Tue, 6 Jun 2023 20:07:10 -0400 Subject: [PATCH] Authentication, better performance, better error handling --- Cargo.lock | 415 ++++++++++++++++++++++- Cargo.toml | 6 + README.md | 8 + src/main.rs | 929 ++++++++++++++++++++++++++++++++++++++++------------ src/pty.rs | 29 +- 5 files changed, 1150 insertions(+), 237 deletions(-) create mode 100644 README.md diff --git a/Cargo.lock b/Cargo.lock index 181d157..9cd3087 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,12 +2,41 @@ # It is not intended for manual editing. version = 3 +[[package]] +name = "aho-corasick" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43f6cb1bf222025340178f382c426f13757b2960e89779dfcb319c32542a5a41" +dependencies = [ + "memchr", +] + +[[package]] +name = "ansi_term" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d52a9bb7ec0cf484c551830a7ce27bd20d67eac647e1befb56b0be4ee39a55d2" +dependencies = [ + "winapi", +] + [[package]] name = "anyhow" version = "1.0.71" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c7d0618f0e0b7e8ff11427422b64564d5fb0be1940354bfe2e0529b18a9d9b8" +[[package]] +name = "atty" +version = "0.2.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9b39be18770d11421cdb1b9947a45dd3f37e93092cbf377614828a319d5fee8" +dependencies = [ + "hermit-abi 0.1.19", + "libc", + "winapi", +] + [[package]] name = "autocfg" version = "1.1.0" @@ -26,6 +55,29 @@ version = "0.21.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "604178f6c5c21f02dc555784810edfb88d34ac2c73b2eae109655649ee73ce3d" +[[package]] +name = "bindgen" +version = "0.59.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2bd2a9a458e8f4304c52c43ebb0cfbd520289f8379a52e329a38afda99bf8eb8" +dependencies = [ + "bitflags", + "cexpr", + "clang-sys", + "clap", + "env_logger", + "lazy_static", + "lazycell", + "log", + "peeking_take_while", + "proc-macro2", + "quote", + "regex", + "rustc-hash", + "shlex", + "which", +] + [[package]] name = "bitflags" version = "1.3.2" @@ -56,12 +108,47 @@ version = "1.0.79" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "50d30906286121d95be3d479533b458f87493b30a4b5f79a607db8f5d11aa91f" +[[package]] +name = "cexpr" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6fac387a98bb7c37292057cffc56d62ecb629900026402633ae9160df93a8766" +dependencies = [ + "nom", +] + [[package]] name = "cfg-if" version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "clang-sys" +version = "1.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c688fc74432808e3eb684cae8830a86be1d66a2bd58e1f248ed0960a590baf6f" +dependencies = [ + "glob", + "libc", + "libloading", +] + +[[package]] +name = "clap" +version = "2.34.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a0610544180c38b88101fecf2dd634b174a62eef6946f84dfc6a7127512b381c" +dependencies = [ + "ansi_term", + "atty", + "bitflags", + "strsim", + "textwrap", + "unicode-width", + "vec_map", +] + [[package]] name = "core-foundation" version = "0.9.3" @@ -78,6 +165,61 @@ version = "0.8.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e496a50fda8aacccc86d7529e2c1e0892dbd0f898a6b5645b5561b89c3210efa" +[[package]] +name = "either" +version = "1.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7fcaabb2fef8c910e7f4c7ce9f67a1283a1715879a7c230ca9d6d1ae31f16d91" + +[[package]] +name = "enum-repr" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bad30c9c0fa1aaf1ae5010dab11f1117b15d35faf62cda4bbbc53b9987950f18" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "env_logger" +version = "0.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a12e6657c4c97ebab115a42dcee77225f7f482cdd841cf7088c657a42e9e00e7" +dependencies = [ + "atty", + "humantime", + "log", + "regex", + "termcolor", +] + +[[package]] +name = "flume" +version = "0.10.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1657b4441c3403d9f7b3409e47575237dac27b1b5726df654a6ecbf92f0f7577" +dependencies = [ + "futures-core", + "futures-sink", + "nanorand", + "pin-project", + "spin 0.9.8", +] + +[[package]] +name = "futures-core" +version = "0.3.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4bca583b7e26f571124fe5b7561d49cb2868d79116cfa0eefce955557c6fee8c" + +[[package]] +name = "futures-sink" +version = "0.3.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f43be4fe21a13b9781a69afa4985b0f6ee0e1afab2c6f454a8cf30e2b2237b6e" + [[package]] name = "getrandom" version = "0.2.9" @@ -85,8 +227,25 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c85e1d9ab2eadba7e5040d4e09cbd6d072b76a557ad64e797c2cb9d4da21d7e4" dependencies = [ "cfg-if", + "js-sys", "libc", "wasi", + "wasm-bindgen", +] + +[[package]] +name = "glob" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" + +[[package]] +name = "hermit-abi" +version = "0.1.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "62b467343b94ba476dcb2500d242dadbb39557df889310ac77c5d99100aaac33" +dependencies = [ + "libc", ] [[package]] @@ -98,6 +257,12 @@ dependencies = [ "libc", ] +[[package]] +name = "humantime" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" + [[package]] name = "js-sys" version = "0.3.63" @@ -113,12 +278,38 @@ version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" +[[package]] +name = "lazycell" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55" + [[package]] name = "libc" version = "0.2.145" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fc86cde3ff845662b8f4ef6cb50ea0e20c524eb3d29ae048287e06a1b3fa6a81" +[[package]] +name = "libloading" +version = "0.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b67380fd3b2fbe7527a606e18729d21c6f3951633d0500574c4dc22d2d638b9f" +dependencies = [ + "cfg-if", + "winapi", +] + +[[package]] +name = "lock_api" +version = "0.4.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c1cc9717a20b1bb222f333e6a92fd32f7d8a18ddc5a3191a11af45dcbf4dcd16" +dependencies = [ + "autocfg", + "scopeguard", +] + [[package]] name = "log" version = "0.4.18" @@ -134,6 +325,12 @@ dependencies = [ "regex-automata", ] +[[package]] +name = "memchr" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2dffe52ecf27772e601905b7522cb4ef790d2cc203488bbd0e2fe85fcb74566d" + [[package]] name = "memoffset" version = "0.7.1" @@ -143,6 +340,12 @@ dependencies = [ "autocfg", ] +[[package]] +name = "minimal-lexical" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" + [[package]] name = "mio" version = "0.8.8" @@ -154,6 +357,15 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "nanorand" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a51313c5820b0b02bd422f4b44776fbf47961755c74ce64afc73bfad10226c3" +dependencies = [ + "getrandom", +] + [[package]] name = "nix" version = "0.26.2" @@ -168,6 +380,16 @@ dependencies = [ "static_assertions", ] +[[package]] +name = "nom" +version = "7.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d273983c5a657a70a3e8f2a01329822f3b8c8172b73826411a55751e404a0a4a" +dependencies = [ + "memchr", + "minimal-lexical", +] + [[package]] name = "nu-ansi-term" version = "0.46.0" @@ -193,7 +415,7 @@ version = "1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0fac9e2da13b5eb447a6ce3d392f23a29d8694bff781bf03a16cd9ac8697593b" dependencies = [ - "hermit-abi", + "hermit-abi 0.2.6", "libc", ] @@ -215,12 +437,40 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" +[[package]] +name = "pam-client" +version = "0.5.0" +dependencies = [ + "bitflags", + "enum-repr", + "libc", + "pam-sys", + "rustversion", + "serde", +] + +[[package]] +name = "pam-sys" +version = "1.0.0-alpha4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e9dfd42858f6a6bb1081079fd9dc259ca3e2aaece6cb689fd36b1058046c969" +dependencies = [ + "bindgen", + "libc", +] + [[package]] name = "paste" version = "1.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9f746c4065a8fa3fe23974dd82f15431cc8d40779821001404d10d2e79ca7d79" +[[package]] +name = "peeking_take_while" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "19b17cddbe7ec3f8bc800887bab5e717348c95ea2ca0b1bf0837fb964dc67099" + [[package]] name = "pem" version = "1.1.1" @@ -230,6 +480,26 @@ dependencies = [ "base64 0.13.1", ] +[[package]] +name = "pin-project" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c95a7476719eab1e366eaf73d0260af3021184f18177925b07f54b30089ceead" +dependencies = [ + "pin-project-internal", +] + +[[package]] +name = "pin-project-internal" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39407670928234ebc5e6e580247dd567ad73a3578460c5990f9503df207e8f07" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.18", +] + [[package]] name = "pin-project-lite" version = "0.2.9" @@ -262,16 +532,22 @@ name = "quic-shell" version = "0.1.0" dependencies = [ "anyhow", + "bytes", + "flume", "libc", "nix", + "pam-client", + "pin-project-lite", "quinn", "rcgen", "rmp-serde", + "rpassword", "rustls", "serde", "tokio", "tracing", "tracing-subscriber", + "triggered", ] [[package]] @@ -379,6 +655,8 @@ version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "af83e617f331cc6ae2da5443c602dfa5af81e517212d9d611a5b3ba1777b5370" dependencies = [ + "aho-corasick", + "memchr", "regex-syntax 0.7.1", ] @@ -412,7 +690,7 @@ dependencies = [ "cc", "libc", "once_cell", - "spin", + "spin 0.5.2", "untrusted", "web-sys", "winapi", @@ -440,6 +718,27 @@ dependencies = [ "serde", ] +[[package]] +name = "rpassword" +version = "7.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6678cf63ab3491898c0d021b493c94c9b221d91295294a2a5746eacbe5928322" +dependencies = [ + "libc", + "rtoolbox", + "winapi", +] + +[[package]] +name = "rtoolbox" +version = "0.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "034e22c514f5c0cb8a10ff341b9b048b5ceb21591f31c8f44c43b960f9b3524a" +dependencies = [ + "libc", + "winapi", +] + [[package]] name = "rustc-hash" version = "1.1.0" @@ -488,6 +787,12 @@ dependencies = [ "untrusted", ] +[[package]] +name = "rustversion" +version = "1.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4f3208ce4d8448b3f3e7d168a73f5e0c43a61e32930de3bceeccedb388b6bf06" + [[package]] name = "schannel" version = "0.1.21" @@ -497,6 +802,12 @@ dependencies = [ "windows-sys 0.42.0", ] +[[package]] +name = "scopeguard" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" + [[package]] name = "sct" version = "0.7.0" @@ -547,7 +858,7 @@ checksum = "8c805777e3930c8883389c602315a24224bcc738b63905ef87cd1420353ea93e" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.18", ] [[package]] @@ -559,6 +870,12 @@ dependencies = [ "lazy_static", ] +[[package]] +name = "shlex" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43b2853a4d09f215c24cc5489c992ce46052d359b5109343cbafbf26bc62f8a3" + [[package]] name = "signal-hook-registry" version = "1.4.1" @@ -609,12 +926,38 @@ version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d" +[[package]] +name = "spin" +version = "0.9.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" +dependencies = [ + "lock_api", +] + [[package]] name = "static_assertions" version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" +[[package]] +name = "strsim" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ea5119cdb4c55b55d432abb513a0429384878c15dde60cc77b1c99de1a95a6a" + +[[package]] +name = "syn" +version = "1.0.109" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + [[package]] name = "syn" version = "2.0.18" @@ -626,6 +969,24 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "termcolor" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be55cf8942feac5c765c2c993422806843c9a9a45d4d5c407ad6dd2ea95eb9b6" +dependencies = [ + "winapi-util", +] + +[[package]] +name = "textwrap" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d326610f408c7a4eb6f51c37c330e496b08506c9457c9d34287ecc38809fb060" +dependencies = [ + "unicode-width", +] + [[package]] name = "thiserror" version = "1.0.40" @@ -643,7 +1004,7 @@ checksum = "f9456a42c5b0d803c8cd86e73dd7cc9edd429499f37a3550d286d5e86720569f" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.18", ] [[package]] @@ -713,7 +1074,7 @@ checksum = "630bdcf245f78637c13ec01ffae6187cca34625e8c63150d424b59e55af2675e" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.18", ] [[package]] @@ -737,7 +1098,7 @@ checksum = "0f57e3ca2a01450b1a921183a9c9cbfda207fd822cef4ccb00a65402cbba7a74" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.18", ] [[package]] @@ -779,12 +1140,24 @@ dependencies = [ "tracing-log", ] +[[package]] +name = "triggered" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ce148eae0d1a376c1b94ae651fc3261d9cb8294788b962b7382066376503a2d1" + [[package]] name = "unicode-ident" version = "1.0.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b15811caf2415fb889178633e7724bad2509101cde276048e013b9def5e51fa0" +[[package]] +name = "unicode-width" +version = "0.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c0edd1e5b14653f783770bce4a4dabb4a5108a5370a5f5d8cfe8710c361f6c8b" + [[package]] name = "untrusted" version = "0.7.1" @@ -797,6 +1170,12 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d" +[[package]] +name = "vec_map" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1bddf1187be692e79c5ffeab891132dfb0f236ed36a43c7ed39f1165ee20191" + [[package]] name = "wasi" version = "0.11.0+wasi-snapshot-preview1" @@ -824,7 +1203,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn", + "syn 2.0.18", "wasm-bindgen-shared", ] @@ -846,7 +1225,7 @@ checksum = "e128beba882dd1eb6200e1dc92ae6c5dbaa4311aa7bb211ca035779e5efc39f8" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.18", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -867,6 +1246,17 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "which" +version = "4.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2441c784c52b289a054b7201fc93253e288f094e2f4be9058343127c4226a269" +dependencies = [ + "either", + "libc", + "once_cell", +] + [[package]] name = "winapi" version = "0.3.9" @@ -883,6 +1273,15 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" +[[package]] +name = "winapi-util" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "70ec6ce85bb158151cae5e5c87f95a8e97d2c0c4b001223f33a334e3ce5de178" +dependencies = [ + "winapi", +] + [[package]] name = "winapi-x86_64-pc-windows-gnu" version = "0.4.0" diff --git a/Cargo.toml b/Cargo.toml index 9835aae..9ec45f9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,14 +7,20 @@ edition = "2021" [dependencies] anyhow = "1.0.71" +bytes = "1.4.0" +flume = "0.10.14" libc = "0.2.145" nix = "0.26.2" +pam-client = { version = "0.5.0", path = "../../../../../Users/michael/b/rust-pam-client", default-features = false, features = ["serde"] } +pin-project-lite = "0.2.9" quinn = "0.10.1" rcgen = "0.10.0" rmp-serde = "1.1.1" +rpassword = "7.2.0" rustls = { version = "0.21.1", default-features = false } serde = { version = "1.0.163", features = ["derive"] } #termion = "2.0.1" tokio = { version = "1.28.2", default-features = false, features = ["rt-multi-thread", "macros", "process", "io-util", "io-std", "time", "fs", "signal"] } tracing = "0.1.37" tracing-subscriber = { version = "0.3.17", features = ["env-filter"] } +triggered = "0.1.2" diff --git a/README.md b/README.md new file mode 100644 index 0000000..7a88025 --- /dev/null +++ b/README.md @@ -0,0 +1,8 @@ +## Resources + +Some resources used in the development of the program. + +- https://www.linuxquestions.org/questions/programming-9/how-to-set-pseudo-terminal-non-blocking-690846/ +- https://meli.delivery/posts/2019-10-25-making-a-quick-and-dirty-terminal-emulator.html +- https://viewsourcecode.org/snaptoken/kilo/02.enteringRawMode.html +- https://www.digitalocean.com/community/tutorials/understanding-the-ssh-encryption-and-connection-process diff --git a/src/main.rs b/src/main.rs index 2d16f0a..07d194a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,3 +1,5 @@ +#[deny(unused_must_use)] + #[macro_use] extern crate anyhow; #[macro_use] @@ -7,22 +9,28 @@ extern crate tracing; mod pty; -use std::ffi::CStr; +use std::ffi::{CStr, CString}; +use std::future::Future; +use std::mem::ManuallyDrop; +use std::net::SocketAddr; use std::os::fd::FromRawFd; use std::os::unix::process::CommandExt; use std::process::Stdio; use std::sync::Arc; +use std::task::Poll; use anyhow::{Context, Result}; -use quinn::{RecvStream, SendStream}; +use pam_client::ConversationHandler; +use quinn::{RecvStream, SendStream, ReadExactError}; use tokio::io::{AsyncReadExt, AsyncWriteExt}; -use tokio::process::{Child, Command}; +use tokio::process::Command; use tracing::Instrument; #[derive(Debug, Clone, Serialize, Deserialize)] enum Stream { + Exec, Shell, - Heartbeat, + // TODO: port forwarding } #[tokio::main] @@ -42,7 +50,10 @@ async fn main() -> Result<()> { let fut = run_cmd(args); tokio::select! { - _ = ctrl_c => Ok(()), + _ = ctrl_c => { + info!("Aborting"); + Ok(()) + } r = fut => r, } } @@ -51,31 +62,47 @@ async fn run_cmd(mut args: std::env::Args) -> Result<()> { let cmd = args.next().expect("COMMAND"); match cmd.as_str() { "server" => run_server().await, - "client" => run_client().await, + "client" => run_client(args).await, _ => Err(anyhow!("Unrecognized command: {}", cmd)), } } const ALPN_QUIC_SHELL: &str = "quic-shell"; +struct ServerConfig { + shell: String, + listen: SocketAddr, +} + async fn run_server() -> Result<()> { - let opt_shell = &*Box::leak( - std::env::var("SHELL") - .context("SHELL not defined")? - .into_boxed_str(), - ); - let opt_listen = std::env::var("BIND_ADDR") - .unwrap_or_else(|_| "127.0.0.1:8022".to_owned()) - .parse()?; + let cfg = { + let opt_shell = std::env::var("SHELL") + .context("SHELL not defined")?; + let opt_listen = std::env::var("BIND_ADDR") + .unwrap_or_else(|_| "127.0.0.1:8022".to_owned()) + .parse()?; + + &*Box::leak(Box::new(ServerConfig { + shell: opt_shell, + listen: opt_listen, + })) + }; let subject_alt_names = vec!["localhost".to_string()]; - let cert = rcgen::generate_simple_self_signed(subject_alt_names)?; - let key = rustls::PrivateKey(cert.serialize_private_key_der()); - let cert = rustls::Certificate(cert.serialize_der()?); - //std::fs::write("key.der", &key.0)?; - std::fs::write("cert.der", &cert.0)?; + let (cert, key) = if !std::path::Path::new("cert.der").exists() || !std::path::Path::new("key.der").exists() { + let cert = rcgen::generate_simple_self_signed(subject_alt_names)?; + let key = rustls::PrivateKey(cert.serialize_private_key_der()); + let cert = rustls::Certificate(cert.serialize_der()?); + std::fs::write("key.der", &key.0)?; + std::fs::write("cert.der", &cert.0)?; + (cert, key) + } else { + let cert = rustls::Certificate(std::fs::read("cert.der")?); + let key = rustls::PrivateKey(std::fs::read("key.der")?); + (cert, key) + }; let mut server_crypto = rustls::ServerConfig::builder() .with_safe_defaults() @@ -84,19 +111,21 @@ async fn run_server() -> Result<()> { .unwrap(); server_crypto.alpn_protocols = vec![ALPN_QUIC_SHELL.as_bytes().to_owned()]; + let mut transport = transport_config(); + transport.max_concurrent_uni_streams(0_u8.into()); + let mut server_config = quinn::ServerConfig::with_crypto(Arc::new(server_crypto)); - let transport_config = Arc::get_mut(&mut server_config.transport).unwrap(); - transport_config.max_concurrent_uni_streams(0_u8.into()); + server_config.transport_config(transport.into()); server_config.use_retry(true); - let endpoint = quinn::Endpoint::server(server_config, opt_listen)?; - eprintln!("listening on {}", endpoint.local_addr()?); + let endpoint = quinn::Endpoint::server(server_config, cfg.listen)?; + info!("listening on {}", endpoint.local_addr()?); while let Some(conn) = endpoint.accept().await { info!("connection incoming"); tokio::spawn(async move { - if let Err(e) = handle_connection(opt_shell, conn).await { - error!("connection failed: {reason}", reason = e.to_string()) + if let Err(e) = greet_conn(cfg, conn).await { + error!("connection failed: {reason}", reason = e.to_string()); } }); } @@ -112,9 +141,53 @@ fn is_broken_pipe(r: &Result) -> bool { } } -async fn run_client() -> Result<()> { +fn transport_config() -> quinn::TransportConfig { + let mut transport = quinn::TransportConfig::default(); + transport.stream_receive_window((64u32 * 1024 * 1024).into()); + transport.send_window(64 * 1024 * 1024); + transport.receive_window((64u32 * 1024 * 1024).into()); + transport +} + +#[derive(Debug, Clone, Copy)] +struct FinishedEarly; + +impl std::fmt::Display for FinishedEarly { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + ReadExactError::FinishedEarly.fmt(f) + } +} + +impl std::error::Error for FinishedEarly { +} + +async fn read_msg(recv: &mut RecvStream) -> Result> { + let mut size = [0u8; 2]; + match recv.read_exact(&mut size).await { + Ok(()) => {} + Err(ReadExactError::FinishedEarly) => return Ok(Err(FinishedEarly)), + Err(ReadExactError::ReadError(e)) => return Err(e.into()), + } + let size = u16::from_le_bytes(size); + let mut buf = Vec::with_capacity(size.into()); + recv.take(size.into()).read_to_end(&mut buf).await?; + Ok(Ok(rmp_serde::from_slice(&buf).with_context(|| format!("reading a {} byte message", size))?)) +} + +async fn write_msg(send: &mut SendStream, value: &T) -> Result<()> { + let buf = rmp_serde::to_vec(value)?; + send.write_all(&u16::try_from(buf.len())?.to_le_bytes()) + .await?; + send.write_all(&buf).await?; + Ok(()) +} + +async fn run_client(mut args: std::env::Args) -> Result<()> { info!("running client"); + let conn_str = args.next().expect("USERNAME@HOST"); + let (username, host) = conn_str.split_once('@').expect("USERNAME@HOST"); + let mut roots = rustls::RootCertStore::empty(); match std::fs::read("cert.der") { Ok(cert) => { @@ -137,7 +210,7 @@ async fn run_client() -> Result<()> { client_crypto.alpn_protocols = vec![ALPN_QUIC_SHELL.as_bytes().to_owned()]; - let mut transport = quinn::TransportConfig::default(); + let mut transport = transport_config(); transport.keep_alive_interval(Some(std::time::Duration::from_secs(5))); let mut client_config = quinn::ClientConfig::new(Arc::new(client_crypto)); @@ -149,23 +222,32 @@ async fn run_client() -> Result<()> { info!("connecting"); let conn = endpoint - .connect("127.0.0.1:8022".parse()?, "localhost")? + .connect(host.parse()?, "localhost")? .await?; - let (mut send, mut recv) = conn.open_bi().await?; - write_header(&mut send, Stream::Shell).await?; - - info!("connected"); - - let mut stdin = unsafe { tokio::fs::File::from_raw_fd(libc::STDIN_FILENO) }; - let mut stdout = unsafe { tokio::fs::File::from_raw_fd(libc::STDOUT_FILENO) }; - let mut stdin_buf = Vec::with_capacity(4096); - let mut stdout_buf = Vec::with_capacity(4096); - - let mut stdin_eof = false; + // authenticating client { + let (mut send, mut recv) = conn.open_bi().await?; + write_msg(&mut send, &auth::Hello { + username: username.to_owned(), + }).await?; + do_auth_prompt(&conn, &mut send, &mut recv).await?; + } + + // authenticated client + + let _reset = { use nix::sys::termios::*; + + struct Reset(Termios); + impl Drop for Reset { + fn drop(&mut self) { + _ = tcsetattr(libc::STDIN_FILENO, SetArg::TCSAFLUSH, &self.0); + info!("termios reset!"); + } + } + let mut termios = tcgetattr(libc::STDIN_FILENO)?; termios.local_flags.remove(LocalFlags::ECHO); termios.local_flags.remove(LocalFlags::ICANON); @@ -174,19 +256,35 @@ async fn run_client() -> Result<()> { termios.input_flags.remove(InputFlags::IXON); termios.input_flags.remove(InputFlags::ICRNL); termios.output_flags.remove(OutputFlags::OPOST); + let reset = Reset(termios.clone()); tcsetattr(libc::STDIN_FILENO, SetArg::TCSAFLUSH, &termios)?; - } + reset + }; - /*let mut heartbeat = { - let (mut send, recv) = conn.open_bi().await?; + let (mut send, mut recv) = conn.open_bi().await?; - write_header(&mut send, Stream::Heartbeat).await?; - Box::pin(handle_stream_heartbeat(send, recv)) - };*/ + write_msg(&mut send, &Stream::Shell).await?; + + info!("connected"); + + do_shell(&conn, &mut send, &mut recv).await +} + +async fn do_shell(conn: &quinn::Connection, send: &mut SendStream, recv: &mut RecvStream) -> Result<()> { + let mut stdin = unsafe { ManuallyDrop::new(tokio::fs::File::from_raw_fd(libc::STDIN_FILENO)) }; + let mut stdout = unsafe { ManuallyDrop::new(tokio::fs::File::from_raw_fd(libc::STDOUT_FILENO)) }; + let mut stdin_buf = Vec::with_capacity(4096); + //let mut stdout_buf = Vec::with_capacity(4096); + let mut stdout_buf = vec![bytes::Bytes::new(); 128]; + + let mut stdin_eof = false; loop { tokio::select! { - //_ = &mut heartbeat => {} + /*r = tokio::io::copy(&mut stdin, &mut send) => { + r?; + info!("EOF on stdin"); + }*/ r = stdin.read_buf(&mut stdin_buf), if !stdin_eof => { if r? == 0 { stdin_eof = true; @@ -194,12 +292,20 @@ async fn run_client() -> Result<()> { send.write_all(&stdin_buf).await?; stdin_buf.clear(); //info!("sent stdin"); - }, - r = recv.read_buf(&mut stdout_buf) => if r? > 0 { + } + r = recv.read_chunks(&mut stdout_buf) => { + if let Some(n) = r? { + for chunk in &stdout_buf[..n] { + stdout.write_all(&chunk).await?; + } + //info!("recv stdout"); + } + } + /*r = recv.read_buf(&mut stdout_buf) => if r? > 0 { stdout.write_all(&stdout_buf).await?; stdout_buf.clear(); //info!("recv stdout"); - }, + }*/, r = send.stopped() => { info!("Remote disconnected"); let code = r?.into_inner(); @@ -211,30 +317,67 @@ async fn run_client() -> Result<()> { } e = conn.closed() => { info!("Remote disconnected: {}", e); - return Ok(()); + return Err(anyhow!("Remote connection closed")); } } } } -async fn write_header(send: &mut SendStream, header: Stream) -> Result<()> { - let buf = rmp_serde::to_vec(&header)?; - send.write_all(&u16::try_from(buf.len())?.to_le_bytes()) - .await?; - send.write_all(&buf).await?; - Ok(()) +async fn do_auth_prompt(conn: &quinn::Connection, send: &mut SendStream, recv: &mut RecvStream) -> Result<()> { + use auth::*; + + let mut stdout = unsafe { ManuallyDrop::new(tokio::fs::File::from_raw_fd(libc::STDOUT_FILENO)) }; + + loop { + tokio::select! { + r = read_msg::(recv) => { + match r?? { + Question::LoggedIn => { + return Ok(()); + } + Question::Prompt { + prompt, + echo, + } => { + let mut prompt = prompt.into_bytes(); + prompt.push(b' '); + stdout.write_all(&prompt).await?; + let answer = rpassword::read_password()?; + let answer = CString::new(answer)?; + write_msg(send, &Answer::Prompt(Ok(answer))).await?; + }, + Question::TextInfo(s) => { + stdout.write_all(b"INFO ").await?; + stdout.write_all(s.as_bytes()).await?; + stdout.write_all(b"\n").await?; + }, + Question::ErrorMsg(s) => { + stdout.write_all(b"ERRO ").await?; + stdout.write_all(s.as_bytes()).await?; + stdout.write_all(b"\n").await?; + }, + } + } + r = send.stopped() => { + info!("Remote disconnected"); + let code = r?.into_inner(); + if code == 0 { + return Ok(()); + } else { + return Err(anyhow!("Error code {}", code)); + } + } + e = conn.closed() => { + info!("Remote disconnected: {}", e); + return Err(anyhow!("Remote connection closed")); + } + } + } } -async fn read_header(recv: &mut RecvStream) -> Result { - let mut size = [0u8; 2]; - recv.read_exact(&mut size).await?; - let size = u16::from_le_bytes(size); - let mut buf = Vec::with_capacity(size.into()); - recv.take(size.into()).read_to_end(&mut buf).await?; - Ok(rmp_serde::from_slice(&buf)?) -} +async fn greet_conn(cfg: &'static ServerConfig, conn: quinn::Connecting) -> Result<()> { + info!("greeting connection"); -async fn handle_connection(opt_shell: &'static str, conn: quinn::Connecting) -> Result<()> { let conn = conn.await?; let span = info_span!( "connection", @@ -246,70 +389,503 @@ async fn handle_connection(opt_shell: &'static str, conn: quinn::Connecting) -> .protocol .map_or_else(|| "".into(), |x| String::from_utf8_lossy(&x).into_owned()) ); - async { - info!("established"); - loop { - let stream = conn.accept_bi().await; - let (send, mut recv) = match stream { - Err(quinn::ConnectionError::ApplicationClosed { .. }) => { - info!("connection closed"); - return Ok(()); - } - Err(e) => { - return Err(e.into()); - } - Ok(s) => s, - }; - - let stream = read_header(&mut recv).await?; - let span = info_span!( - "stream", - r#type = ?stream - ); - tokio::task::spawn( - async move { - let r = match stream { - Stream::Shell => handle_stream_shell(opt_shell, send, recv).await, - Stream::Heartbeat => handle_stream_heartbeat(send, recv).await, - }; - if let Err(e) = r { - error!("Error in stream handler: {e}"); - } - } - .instrument(span), - ); - } + if let Err(e) = authenticate_conn(cfg, &conn) + .instrument(span) + .await { + error!("handler failed: {reason}", reason = e.to_string()); + conn.close(1u8.into(), b"handler error"); } - .instrument(span) - .await + + Ok(()) } -async fn handle_stream_shell( - opt_shell: &str, +mod auth { + use std::ffi::CString; + + #[derive(Debug, Serialize, Deserialize)] + pub struct Hello { + pub username: String, + } + + #[derive(Debug, Serialize, Deserialize)] + pub enum Question { + Prompt { + prompt: CString, + echo: bool, + }, + TextInfo(CString), + ErrorMsg(CString), + LoggedIn, + } + + #[derive(Debug, Serialize, Deserialize)] + pub enum Answer { + Prompt(Result), + } +} + +async fn authenticate_conn(cfg: &'static ServerConfig, conn: &quinn::Connection) -> Result<()> { + use auth::*; + + info!("authenticating connection"); + + let (mut send, mut recv) = conn.accept_bi().await?; + + let hello = read_msg::(&mut recv).await??; + + let (q_send, q_recv) = flume::bounded(1); + let (a_send, a_recv) = flume::bounded(1); + + struct Conversation { + send: flume::Sender, + recv: flume::Receiver, + } + + impl Conversation { + const TIMEOUT: std::time::Duration = std::time::Duration::from_secs(60); + + fn ask(&self, question: Question) -> Result<(), pam_client::ErrorCode> { + self.send.send_timeout(question, Self::TIMEOUT).map_err(|_| pam_client::ErrorCode::ABORT) + } + + fn answer(&self) -> Result { + self.recv.recv_timeout(Self::TIMEOUT).map_err(|_| pam_client::ErrorCode::ABORT) + } + } + + impl ConversationHandler for Conversation { + fn prompt_echo_on(&mut self, prompt: &CStr) -> std::result::Result { + self.ask(Question::Prompt { + prompt: prompt.to_owned(), + echo: true, + })?; + match self.answer()? { + Answer::Prompt(r) => r, + } + } + + fn prompt_echo_off(&mut self, prompt: &CStr) -> std::result::Result { + self.ask(Question::Prompt { + prompt: prompt.to_owned(), + echo: false, + })?; + match self.answer()? { + Answer::Prompt(r) => r, + } + } + + fn text_info(&mut self, msg: &CStr) { + _ = self.ask(Question::TextInfo(msg.to_owned())); + } + + fn error_msg(&mut self, msg: &CStr) { + _ = self.ask(Question::ErrorMsg(msg.to_owned())); + } + } + + let hdl = tokio::task::spawn_blocking(move || { + let mut ctx = pam_client::Context::new("sshd", Some(&hello.username), Conversation { + send: q_send, + recv: a_recv, + })?; + info!("created context"); + + ctx.authenticate(pam_client::Flag::NONE)?; + info!("authenticated user"); + + ctx.acct_mgmt(pam_client::Flag::NONE)?; + info!("validated user"); + + let sess = ctx.open_session(pam_client::Flag::NONE)?; + info!("opened session"); + let sess = sess.leak(); + + let conv = ctx.conversation_mut(); + conv.send = flume::bounded(0).0; + conv.recv = flume::bounded(0).1; + Result::<_>::Ok((ctx, sess)) + }); + + while let Ok(question) = q_recv.recv_async().await { + debug!("received question: {:?}", question); + write_msg(&mut send, &question).await?; + if matches!(question, Question::Prompt { .. }) { + let answer = read_msg(&mut recv).await??; + trace!("received answer: {:?}", answer); + a_send.send_async(answer).await?; + } + /*match question { + Question::Prompt { prompt, echo } => { + let r = async { + // FIXME: actually disable echo + send.write_all(prompt.as_bytes()).await?; + send.write_all(b" ").await?; + let erase = format!("\x1b[{}G\x1b[K", prompt.as_bytes().len() + 1 + 1); + let mut buf = Vec::new(); + 'prompt: loop { + let mut i = buf.len(); + recv.read_buf(&mut buf).await?; + let mut j = i; + while j < buf.len() { + match buf[j] { + 0x7f => { + buf.remove(j); + if j > 0 { + if j == i { + i -= 1; + } + buf.remove(j-1); + // erase in line, move cursor left 1 column + send.write_all(erase.as_bytes()).await?; + send.write_all(&buf).await?; + j -= 1; + } + } + 0x3 => { + send.write_all(b"\r\n").await?; + return Err(anyhow!("Aborted by the user")); + } + b'\r' => { + buf.remove(j); + } + b'\n' => { + info!("found \\n"); + // remove newline and trailing chars + buf.truncate(j); + break 'prompt; + } + _ => { + j += 1; + } + } + } + let seg = &buf[i..]; + if echo { + send.write_all(&seg).await?; + } else { + send.write_all(&vec![b'*'; seg.len()]).await?; + } + info!("{:?} ({:x?})", std::str::from_utf8(&buf), buf); + } + let buf = CString::new(buf)?; + send.write_all(b"\n").await?; + Result::<_>::Ok(buf) + }.await; + a_send.send_async(Answer::Prompt(r.map_err(|e| { + error!("PAM error: {}", e); + pam_client::ErrorCode::ABORT + }))).await?; + } + Question::TextInfo(s) => { + send.write_all(b"INFO ").await?; + send.write_all(s.as_bytes()).await?; + send.write_all(b"\n").await?; + } + Question::ErrorMsg(s) => { + send.write_all(b"ERRO ").await?; + send.write_all(s.as_bytes()).await?; + send.write_all(b"\n").await?; + } + }*/ + } + + + let (mut ctx, sess) = hdl.await??; + let sess = ctx.unleak_session(sess); + info!("logged in: {}", sess.envlist()); + + write_msg(&mut send, &Question::LoggedIn).await?; + send.finish().await?; + recv.stop(0u8.into())?; + + handle_conn(cfg, conn).await +} + +async fn handle_conn(cfg: &'static ServerConfig, conn: &quinn::Connection) -> Result<()> { + info!("established"); + + loop { + let stream = conn.accept_bi().await; + let (send, mut recv) = match stream { + Err(quinn::ConnectionError::ApplicationClosed { .. }) => { + info!("connection closed"); + return Ok(()); + } + Err(e) => { + return Err(e.into()); + } + Ok(s) => s, + }; + + let stream = read_msg::(&mut recv).await??; + let span = info_span!( + "stream", + r#type = ?stream + ); + tokio::task::spawn( + async move { + let r = match stream { + Stream::Exec => handle_stream_exec(cfg, send, recv).await, + Stream::Shell => handle_stream_shell(cfg, send, recv).await, + }; + if let Err(e) = r { + error!("Error in stream handler: {e}"); + } + } + .instrument(span), + ); + } +} + +async fn handle_stream_exec( + cfg: &ServerConfig, mut send: SendStream, mut recv: RecvStream, ) -> Result<()> { - let use_pty = true; - if use_pty { - let args = if opt_shell == "bash" || opt_shell.ends_with("/bash") { - vec![CStr::from_bytes_with_nul(b"-i\0")?] - } else { - vec![] - }; - let mut opt_shell_with_nul = Vec::with_capacity(opt_shell.len() + 1); - opt_shell_with_nul.extend(opt_shell.as_bytes()); - opt_shell_with_nul.push(0); - let opt_shell = CStr::from_bytes_with_nul(&opt_shell_with_nul)?; - let mut sh = pty::create_pty(opt_shell, &args)?; - info!("Created pty"); + let mut cmd = std::process::Command::new(&cfg.shell); + if cfg.shell == "bash" || cfg.shell.ends_with("/bash") { + cmd.arg("-i"); + } + cmd.stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .stdin(Stdio::piped()); + #[cfg(target_family = "unix")] + cmd.process_group(0); + info!("Running {:?}", cmd); + let mut sh = Command::from(cmd).kill_on_drop(true).spawn()?; - let mut stdin_buf = Vec::with_capacity(4096); - let mut pty_buf = Vec::with_capacity(4096); - //let mut pty_buf = [0u8]; + let mut stdout = sh.stdout.take().unwrap(); + let mut stderr = sh.stderr.take().unwrap(); + let mut stdin = sh.stdin.take().unwrap(); + let mut stdout_buf = Vec::with_capacity(4096); + let mut stderr_buf = Vec::with_capacity(4096); + let mut stdin_buf = Vec::with_capacity(4096); + let mut stdout_eof = false; + let mut stderr_eof = false; + + loop { + tokio::select! { + r = sh.wait() => { + let code = r?; + send.finish().await?; + if !code.success() { + info!("Child exit: {}", code); + + recv.stop(1u8.into())?; + return Ok(()); + } else { + info!("Child exit"); + recv.stop(0u8.into())?; + return Ok(()); + } + } + r = stdout.read_buf(&mut stdout_buf), if !stdout_eof => { + if is_broken_pipe(&r) || r? == 0 { + stdout_eof = true; + info!("stdout eof"); + } else { + send.write_all(&stdout_buf).await?; + info!("sent stdout: {:x?}", stdout_buf); + stdout_buf.clear(); + } + }, + r = stderr.read_buf(&mut stderr_buf), if !stderr_eof => { + if is_broken_pipe(&r) || r? == 0 { + stderr_eof = true; + info!("stderr eof"); + } else { + send.write_all(&stderr_buf).await?; + stderr_buf.clear(); + info!("sent stderr: {:x?}", stderr_buf); + } + }, + r = recv.read_buf(&mut stdin_buf) => if r? > 0 { + stdin.write_all(&stdin_buf).await?; + stdin_buf.clear(); + info!("recv stdin"); + }, + } + } +} + +async fn handle_stream_shell( + cfg: &ServerConfig, + mut send: SendStream, + mut recv: RecvStream, +) -> Result<()> { + let args = if cfg.shell == "bash" || cfg.shell.ends_with("/bash") { + vec![CStr::from_bytes_with_nul(b"-i\0")?] + } else { + vec![] + }; + let mut opt_shell_with_nul = Vec::with_capacity(cfg.shell.len() + 1); + opt_shell_with_nul.extend(cfg.shell.as_bytes()); + opt_shell_with_nul.push(0); + let opt_shell = CStr::from_bytes_with_nul(&opt_shell_with_nul)?; + let mut sh = pty::create_pty(opt_shell, &args)?; + sh.set_nodelay()?; + let mut pty = sh.pty; + //let mut pty = tokio::io::unix::AsyncFd::with_interest(pty, tokio::io::Interest::READABLE)?; + info!("created pty"); + + //let mut stdin_buf = Vec::with_capacity(4096); + let mut pty_buf = Vec::with_capacity(4096); + //let mut pty_buf = [0u8]; + + let (_waker_send, mut waker_recv) = flume::bounded::<()>(1); + + /*let fd = pty.as_raw_fd(); + std::thread::spawn(move || { + let mut set = [nix::poll::PollFd::new(fd.as_raw_fd(), nix::poll::PollFlags::POLLIN)]; loop { - if let Some(code) = sh.try_wait()? { + if let Ok(n) = nix::poll::poll(&mut set, -1) { + if n != 0 { + if waker_send.send(()).is_err() { + break; + } + } + } + } + });*/ + + loop { + /*if let Some(code) = sh.proc.try_wait()? { + send.finish().await?; + if code != 0 { + info!("Child exit: {}", code); + + recv.stop(1u8.into())?; + return Ok(()); + } else { + info!("Child exit"); + recv.stop(0u8.into())?; + return Ok(()); + } + }*/ + + //let mut redraw = tokio::time::interval(std::time::Duration::from_millis(50)); + + struct Wait<'a> { + proc: &'a pty::Proc, + } + + impl<'a> Wait<'a> { + pub fn new(proc: &'a pty::Proc) -> Self { + Self { proc } + } + } + + impl<'a> Future for Wait<'a> { + type Output = std::io::Result; + + fn poll(self: std::pin::Pin<&mut Self>, _cx: &mut std::task::Context<'_>) -> Poll { + match self.proc.try_wait() { + Ok(Some(code)) => Poll::Ready(Ok(code)), + Ok(None) => Poll::Pending, + Err(e) => Poll::Ready(Err(e)), + } + } + } + + /*pin_project_lite::pin_project! { + struct SelectRead { + //#[allow(dead_code)] + fd: T, + #[pin] + list: triggered::Listener, + } + } + + impl SelectRead { + pub fn new(fd: T) -> Self { + let (trig, list) = triggered::trigger(); + Self { fd, list, set } + } + } + + impl Future for SelectRead { + type Output = nix::Result<()>; + + fn poll( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll { + //let mut this = self; + let this = self.project(); + match this.list.poll(cx) { + Poll::Ready(()) => { + let r = nix::poll::poll(this.set, 0); + match r { + Ok(0) => { + cx.waker().wake_by_ref(); + Poll::Pending + } + Ok(_) => Poll::Ready(Ok(())), + Err(e) => Poll::Ready(Err(e)), + } + }, + Poll::Pending => { + //cx.waker().wake_by_ref(); + Poll::Pending + }, + } + //let mut set = *this.set; + /*let r = nix::sys::select::select( + set.highest().unwrap() + 1, + &mut set, + None, + None, + &mut nix::sys::time::TimeVal::new(0, 0), + );*/ + } + }*/ + + async fn read_pty( + //pty: &mut tokio::io::unix::AsyncFd, + pty: &mut tokio::fs::File, + _waker_recv: &mut flume::Receiver<()>, + buf: &mut Vec, + send: &mut SendStream, + ) -> Result<()> { + loop { + //let mut pty = pty.readable_mut().await?; + //let pty = pty.get_inner_mut(); + //SelectRead::new(pty.as_fd()).await?; + //_ = waker_recv.recv_async().await; + let r = pty.read_buf(buf).await; + //_ = waker_recv.try_recv(); + if let Err(e) = r { + if e.raw_os_error() == Some(35) { + //info!("not ready: {}", e); + //tokio::task::yield_now().await; + //tokio::time::sleep(std::time::Duration::from_millis(1)).await; + } else { + return Err(e.into()); + } + } else if buf.len() == 0 { + info!("not ready: empty"); + //tokio::task::yield_now().await; + //tokio::time::sleep(std::time::Duration::from_millis(1)).await; + } else { + //return Ok(()); + send.write_all(&buf).await?; + buf.clear(); + //info!("sent pty"); + } + } + } + + tokio::select! { + /*_ = redraw.tick() => { + sh.pty.read_buf(&mut pty_buf).await?; + send.write_all(&pty_buf).await?; + pty_buf.clear(); + info!("redraw complete"); + }*/ + r = Wait::new(&sh.proc) => { + let code = r?; send.finish().await?; if code != 0 { info!("Child exit: {}", code); @@ -322,108 +898,25 @@ async fn handle_stream_shell( return Ok(()); } } - - //let mut redraw = tokio::time::interval(std::time::Duration::from_millis(50)); - - tokio::select! { - /*_ = redraw.tick() => { - sh.pty.read_buf(&mut pty_buf).await?; - send.write_all(&pty_buf).await?; - pty_buf.clear(); - info!("redraw complete"); - }*/ - r = sh.pty.read_buf(&mut pty_buf) => { - //r = sh.pty.read_exact(&mut pty_buf) => { - if let Err(e) = r { - if e.raw_os_error() != Some(35) { - return Err(e.into()); - } - } - if pty_buf.len() > 0 { - send.write_all(&pty_buf).await?; - pty_buf.clear(); - info!("sent pty"); - } - } - r = recv.read_buf(&mut stdin_buf) => if r? > 0 { - sh.pty.write_all(&stdin_buf).await?; - stdin_buf.clear(); - info!("recv stdin"); - }, + /*r = tokio::io::copy(&mut pty, &mut send) => { + r?; + }*/ + r = read_pty(&mut pty, &mut waker_recv, &mut pty_buf, &mut send) => { + r?; } - } - } else { - let mut cmd = std::process::Command::new(opt_shell); - if opt_shell == "bash" || opt_shell.ends_with("/bash") { - cmd.arg("-i"); - } - cmd.stdout(Stdio::piped()) - .stderr(Stdio::piped()) - .stdin(Stdio::piped()); - #[cfg(target_family = "unix")] - cmd.process_group(0); - info!("Running {:?}", cmd); - let mut sh = Command::from(cmd).kill_on_drop(true).spawn()?; - - let mut stdout = sh.stdout.take().unwrap(); - let mut stderr = sh.stderr.take().unwrap(); - let mut stdin = sh.stdin.take().unwrap(); - let mut stdout_buf = Vec::with_capacity(4096); - let mut stderr_buf = Vec::with_capacity(4096); - let mut stdin_buf = Vec::with_capacity(4096); - - let mut stdout_eof = false; - let mut stderr_eof = false; - - loop { - tokio::select! { - r = sh.wait() => { - let code = r?; - send.finish().await?; - if !code.success() { - info!("Child exit: {}", code); - - recv.stop(1u8.into())?; - return Ok(()); - } else { - info!("Child exit"); - recv.stop(0u8.into())?; - return Ok(()); - } - } - r = stdout.read_buf(&mut stdout_buf), if !stdout_eof => { - if is_broken_pipe(&r) || r? == 0 { - stdout_eof = true; - info!("stdout eof"); - } else { - send.write_all(&stdout_buf).await?; - info!("sent stdout: {:x?}", stdout_buf); - stdout_buf.clear(); - } - }, - r = stderr.read_buf(&mut stderr_buf), if !stderr_eof => { - if is_broken_pipe(&r) || r? == 0 { - stderr_eof = true; - info!("stderr eof"); - } else { - send.write_all(&stderr_buf).await?; - stderr_buf.clear(); - info!("sent stderr: {:x?}", stderr_buf); - } - }, - r = recv.read_buf(&mut stdin_buf) => if r? > 0 { - stdin.write_all(&stdin_buf).await?; - stdin_buf.clear(); + // FIXME: figure out a maximum chunk size + r = recv.read_chunk(usize::MAX, true) => { + if let Some(chunk) = r? { + pty.write_all(&chunk.bytes).await?; info!("recv stdin"); - }, + } } + /*r = recv.read_buf(&mut stdin_buf) => if r? > 0 { + //pty.get_mut().write_all(&stdin_buf).await?; + pty.write_all(&stdin_buf).await?; + stdin_buf.clear(); + info!("recv stdin"); + },*/ } } } - -async fn handle_stream_heartbeat(mut send: SendStream, _recv: RecvStream) -> Result<()> { - loop { - tokio::time::sleep(std::time::Duration::from_secs(5)).await; - send.write_all(&[0u8]).await?; - } -} diff --git a/src/pty.rs b/src/pty.rs index 34b5722..39e0b82 100644 --- a/src/pty.rs +++ b/src/pty.rs @@ -9,9 +9,8 @@ use libc::{STDERR_FILENO, STDIN_FILENO, STDOUT_FILENO, TIOCSCTTY}; use libc::TIOCSWINSZ; use std::ffi::CStr; use std::fs::File; -use std::os::unix::process::CommandExt; +use std::os::fd::AsRawFd; use std::path::Path; -use std::process::{Command, Stdio}; use std::os::unix::io::{FromRawFd, IntoRawFd}; @@ -23,15 +22,23 @@ ioctl_none_bad!(set_controlling_terminal, TIOCSCTTY); pub struct Child { pub pty: tokio::fs::File, - pub pid: Pid, + pub proc: Proc, } impl Child { + pub fn set_nodelay(&mut self) -> nix::Result<()> { + fcntl(self.pty.as_raw_fd(), FcntlArg::F_SETFL(OFlag::O_NDELAY)).map(|_|()) + } +} + +pub struct Proc(pub Pid); + +impl Proc { // copied from // https://doc.rust-lang.org/nightly/src/std/sys/unix/process/process_unix.rs.html#744-757 pub fn try_wait(&self) -> std::io::Result> { let mut status = 0; - let pid = cvt(unsafe { libc::waitpid(self.pid.as_raw(), &mut status, libc::WNOHANG) })?; + let pid = cvt(unsafe { libc::waitpid(self.0.as_raw(), &mut status, libc::WNOHANG) })?; if pid == 0 { Ok(None) } else { @@ -40,9 +47,9 @@ impl Child { } } -impl Drop for Child { +impl Drop for Proc { fn drop(&mut self) { - _ = nix::sys::signal::kill(self.pid, nix::sys::signal::Signal::SIGTERM); + _ = nix::sys::signal::kill(self.0, nix::sys::signal::Signal::SIGTERM); } } @@ -60,7 +67,7 @@ pub fn create_pty + std::fmt::Debug>(path: &CStr, argv: &[S]) -> /* Try to open the slave */ let _slave_fd = open(Path::new(&slave_name), OFlag::O_RDWR, Mode::empty())?; - info!("master opened the slave_fd!"); + trace!("master opened the slave_fd!"); /* Launch our child process. The main application loop can inspect and then pass the stdin data to it. */ @@ -82,11 +89,10 @@ pub fn create_pty + std::fmt::Debug>(path: &CStr, argv: &[S]) -> let master_fd = master_fd.into_raw_fd(); /* Tell the master the size of the terminal */ unsafe { set_window_size(master_fd, &winsize)? }; - fcntl(master_fd, FcntlArg::F_SETFL(OFlag::O_NDELAY)).unwrap(); let master_file = unsafe { File::from_raw_fd(master_fd) }; Ok(Child { pty: master_file.into(), - pid: child_pid, + proc: Proc(child_pid), }) } @@ -97,7 +103,8 @@ fn init_child + std::fmt::Debug>( ) -> anyhow::Result { /* Open slave end for pseudoterminal */ let slave_fd = open(Path::new(&slave_name), OFlag::O_RDWR, Mode::empty())?; - info!("child opened the slave_fd!"); + trace!("child opened the slave_fd!"); + debug!("we are going to execute: {:?} {:?}", path, argv); // assign stdin, stdout, stderr to the tty nix::unistd::dup2(slave_fd, STDIN_FILENO)?; @@ -107,7 +114,7 @@ fn init_child + std::fmt::Debug>( nix::unistd::setsid().unwrap(); unsafe { set_controlling_terminal(slave_fd) }.unwrap(); - info!("running exec: {:?} {:?}", path, argv); + trace!("running exec"); Ok(nix::unistd::execv(path, argv)?) }