use std::borrow::Cow; use std::path::{Path, PathBuf}; use miette::{IntoDiagnostic, Result, WrapErr}; use tokio::io::AsyncWriteExt; use crate::Context; pub fn get_username() -> Result { const CAPACITY: u32 = 32; let mut len: u32 = CAPACITY; const LAYOUT: std::alloc::Layout = unsafe { std::alloc::Layout::from_size_align_unchecked(CAPACITY as usize, 1) }; let ptr = unsafe { std::alloc::alloc(LAYOUT) }; ensure!(!ptr.is_null(), "Buffer allocation failed"); let success = unsafe { windows_sys::Win32::System::WindowsProgramming::GetUserNameA(ptr, &mut len as *mut u32) == 1 }; ensure!(success, "GetUserNameA failed"); assert!(len <= CAPACITY, "Buffer overflow caught"); String::from_utf8(unsafe { Vec::from_raw_parts(ptr, len as usize, CAPACITY as usize) }) .into_diagnostic() } /// A sequential pipeline of [`Step`]s. #[derive(Debug, Clone, Copy)] pub struct Pipeline<'a> { pub name: Option<&'a str>, pub steps: &'a [Step<'a>], } impl<'a> Pipeline<'a> { #[inline(always)] pub fn new(name: &'a str, steps: &'a [Step<'a>]) -> Self { Self::of(steps).named(name) } #[inline(always)] pub fn of(steps: &'a [Step<'a>]) -> Self { Self { name: None, steps } } #[inline(always)] pub fn named(mut self, name: &'a str) -> Self { self.name = Some(name); self } pub async fn invoke(&self, ctx: &Context) -> Result<()> { if let Some(name) = self.name { println!("Invoking {name}..."); } for step in self.steps.into_iter() { step.invoke(ctx).await?; } Ok(()) } } #[derive(Debug)] pub enum Step<'a> { DownloadFile { /// Remote resourcee to download from. res: RemoteResource<'a>, /// Path to save the downloaded file to. file: Cow<'a, Path>, }, /// Extracts the file using the `tar` binary that ships with Windows. ExtractFile { file: Cow<'a, Path>, dest: Cow<'a, Path>, }, ExecuteCommand { /// Path to the executable file. file: Cow<'a, Path>, args: &'a [&'a str], }, InstallMsi { file: Cow<'a, Path>, props: Cow<'a, [Cow<'a, str>]>, }, CreateDirectory { target: Cow<'a, Path>, parents: bool, }, CreateShortcut { /// Target of the shortcut (i.e. what is points to). target: ShortcutTarget<'a>, /// Path of the created shortcut file. file: Cow<'a, Path>, }, /// Nukes a publicly accessible directory and locks it from further modification. Nuke { target: Cow<'a, Path> }, /// Executes the steps concurrently. Concurrent(&'a [Pipeline<'a>]), /// Appends the path to the user-wide PATH environment variable. AppendPath(Cow<'a, Path>), } impl<'a> Clone for Step<'a> { #[inline] fn clone(&self) -> Self { match self { Self::Concurrent(pipelines) => Self::Concurrent(pipelines.clone()), Self::DownloadFile { res, file } => Self::DownloadFile { res: res.clone(), file: file.clone(), }, Self::ExtractFile { file, dest } => Self::ExtractFile { file: file.clone(), dest: dest.clone(), }, Self::ExecuteCommand { file, args } => Self::ExecuteCommand { file: file.clone(), args: args.clone(), }, Self::InstallMsi { file, props } => Self::InstallMsi { file: file.clone(), props: props.clone(), }, Self::CreateDirectory { target, parents } => Self::CreateDirectory { target: target.clone(), parents: *parents, }, Self::CreateShortcut { target, file } => Self::CreateShortcut { target: target.clone(), file: file.clone(), }, Self::Nuke { target } => Self::Nuke { target: target.clone(), }, Self::AppendPath(path) => Self::AppendPath(path.clone()), } } } impl<'a> Step<'a> { #[inline] pub async fn invoke(&self, ctx: &Context) -> Result<()> { match self { Self::Concurrent(sequences) => { println!("Executing concurrent steps..."); if let Err(e) = invoke_parallel(ctx, sequences).await { return Err(e); } } Self::DownloadFile { res, file } => { if file.exists() { println!( "File {file} already downloaded.", file = file.to_str().unwrap_or("") ); } else { println!( "Downloading file {file} from {res:?}...", file = file.to_str().unwrap_or("") ); const FETCH_FILE_ERROR_MSG: &'static str = "Fetching the remote resource failed."; const WRITE_FILE_ERROR_MSG: &'static str = "Writing the remote resource to disk failed."; let url = match res { RemoteResource::Url(url) => url::Url::parse(url) .into_diagnostic() .wrap_err("Invalid url for download step.")?, RemoteResource::GitHubArtifact { repo, pattern } => { let mut release = fetch_latest_release(&ctx.reqwest, repo).await?; let pattern = ramhorns::Template::new(*pattern) .into_diagnostic() .wrap_err( "Invalid pattern for artifact matching in download step.", )?; release.meta.tag_name_strip_prefix = release .meta .tag_name .strip_prefix('v') .unwrap_or(&release.meta.tag_name); let asset_name = pattern.render(&release.meta); let artifact = release.assets.into_iter().filter(move |asset| asset.name == asset_name).next().ok_or_else(|| miette!("No artifact of the latest release matched the pattern in download step."))?; url::Url::parse(&artifact.browser_download_url) .into_diagnostic() .wrap_err( "Invalid url returned by GitHub for latest release artifact.", )? } }; let mut resp = ctx .reqwest .get(url) .send() .await .into_diagnostic() .wrap_err(FETCH_FILE_ERROR_MSG)?; let _content_length = resp.content_length(); mkdir_all(file.parent().ok_or_else(|| { miette!("Destination file for download step has no parent.") })?) .await?; let mut writer = tokio::io::BufWriter::new( tokio::fs::File::create(file.as_os_str()) .await .into_diagnostic() .wrap_err(WRITE_FILE_ERROR_MSG)?, ); while let Some(mut chunk) = resp .chunk() .await .into_diagnostic() .wrap_err(FETCH_FILE_ERROR_MSG)? { writer .write_all_buf(&mut chunk) .await .into_diagnostic() .wrap_err(WRITE_FILE_ERROR_MSG)?; } writer .flush() .await .into_diagnostic() .wrap_err(WRITE_FILE_ERROR_MSG)?; } } Self::ExtractFile { file, dest } => { println!( "Extracting {file} to {dest}...", file = file.to_str().unwrap_or(""), dest = dest.to_str().unwrap_or("") ); const EXTRACT_FILE_ERROR_MSG: &'static str = "Extracting file failed."; mkdir_all(&dest).await.wrap_err(EXTRACT_FILE_ERROR_MSG)?; let dest = tokio::fs::canonicalize(&dest) .await .into_diagnostic() .wrap_err(EXTRACT_FILE_ERROR_MSG)?; let status = tokio::process::Command::new("tar") .arg("-xf") .arg(file.as_os_str()) .current_dir(dest) .stdout(std::process::Stdio::inherit()) .stderr(std::process::Stdio::inherit()) .status() .await .into_diagnostic() .wrap_err(EXTRACT_FILE_ERROR_MSG)?; ensure!(status.success(), EXTRACT_FILE_ERROR_MSG); } Self::ExecuteCommand { file, args } => { println!( "Executing command `{file} {args}`...", file = file.to_str().unwrap_or(""), args = args.into_iter().map(|s| *s).collect::(), ); const EXECUTE_COMMAND_ERROR_MSG: &'static str = "Executing command failed."; let status = tokio::process::Command::new(file.as_os_str()) .args(*args) .stdout(std::process::Stdio::inherit()) .stderr(std::process::Stdio::inherit()) .status() .await .into_diagnostic() .wrap_err(EXECUTE_COMMAND_ERROR_MSG)?; ensure!(status.success(), EXECUTE_COMMAND_ERROR_MSG); } Self::InstallMsi { file, props } => { println!( "Installing MSI {file} with props {props:?}`...", file = file.to_str().unwrap_or(""), props = props.iter().map(|s| s.as_ref()).collect::>(), ); const ERROR_MSG: &'static str = "Installing MSI failed."; let status = tokio::process::Command::new("msiexec.exe") .args(["/qn", "/i"]) .arg(file.as_os_str()) .args(props.into_iter().map(|s| s.as_ref())) .stdout(std::process::Stdio::inherit()) .stderr(std::process::Stdio::inherit()) .status() .await .into_diagnostic() .wrap_err(ERROR_MSG)?; ensure!(status.success(), ERROR_MSG); } Self::CreateDirectory { target, parents } => { if target.is_dir() { println!("Directory {target:?} already created."); } else { if *parents { std::fs::create_dir_all(target) } else { std::fs::create_dir(target) } .into_diagnostic() .wrap_err("Create directory failed.")?; } } Self::CreateShortcut { target, file } => { println!( "Creating shortcut to {target:?} at {file}...", file = file.to_str().unwrap_or("") ); const CREATE_SHORTCUT_ERROR_MSG: &'static str = "Creating shortcut failed."; mkdir_all( file.parent().ok_or_else(|| { miette!("Destination file for shortcut step has no parent.") })?, ) .await?; let status = match target { ShortcutTarget::Path { path } => { tokio::process::Command::new("powershell") .arg("-Command") .arg(format!(r#"$shell = New-Object -ComObject WScript.Shell; $shortcut = $shell.CreateShortcut({file:?}); $shortcut.TargetPath = {path:?}; $shortcut.Save()"#)) .stdout(std::process::Stdio::inherit()) .stderr(std::process::Stdio::inherit()) .status() .await .into_diagnostic() .wrap_err(CREATE_SHORTCUT_ERROR_MSG)? } ShortcutTarget::Executable { file: exec_file, args } => { tokio::process::Command::new("powershell") .arg("-Command") .arg(format!(r#"$shell = New-Object -ComObject WScript.Shell; $shortcut = $shell.CreateShortcut({file:?}); $shortcut.TargetPath = {exec_file:?}; $shortcut.Arguments = {args:?}; $shortcut.Save()"#)) .stdout(std::process::Stdio::inherit()) .stderr(std::process::Stdio::inherit()) .status() .await .into_diagnostic() .wrap_err(CREATE_SHORTCUT_ERROR_MSG)? } }; ensure!(status.success(), CREATE_SHORTCUT_ERROR_MSG); } Self::Nuke { target } => { println!( "Nuking {target}...", target = target.to_str().unwrap_or(""), ); // first delete if target.is_dir() { std::fs::remove_dir_all(target) } else { std::fs::remove_file(target) } .into_diagnostic() .wrap_err("Nuke failed: Could not remove target")?; // then make new std::fs::create_dir_all(target) .into_diagnostic() .wrap_err("Nuke failed: Could not create directory")?; let mut grant = get_username()?; grant.push_str(":F"); let status = tokio::process::Command::new("cacls") .arg(target.as_ref()) .arg("/T") .arg("/G") .arg(grant) .stdout(std::process::Stdio::inherit()) .stderr(std::process::Stdio::inherit()) .status() .await .into_diagnostic() .wrap_err("Nuke failed: could not set permissions")?; ensure!(status.success(), "Nuke failed: could not set permissions"); } Self::AppendPath(path) => { static LOCK: once_cell::sync::Lazy> = once_cell::sync::Lazy::new(|| tokio::sync::Mutex::new(())); const HKEY: windows_sys::Win32::System::Registry::HKEY = windows_sys::Win32::System::Registry::HKEY_CURRENT_USER; const SUBKEY: &str = "Environment"; const VALUE: &str = "PATH"; const TYPE: windows_sys::Win32::System::Registry::RRF_RT = windows_sys::Win32::System::Registry::RRF_RT_REG_SZ; const CAPACITY: usize = 1024; println!( "Appending {path} to the PATH environment variable...", path = path.to_str().unwrap_or(""), ); let mut buffer: [std::mem::MaybeUninit; CAPACITY] = std::mem::MaybeUninit::uninit_array(); let mut len: u32 = CAPACITY as u32; // this lock will be held until the end of the scope. We do this to prevent // concurrent access to the registry from interfering with (i.e. overwriting) // eachother. let _lock = LOCK.lock().await; let err = unsafe { windows_sys::Win32::System::Registry::RegGetValueA( HKEY, SUBKEY.as_ptr(), VALUE.as_ptr(), TYPE, std::ptr::null_mut(), &mut buffer as *mut _ as *mut _, &mut len as *mut _, ) }; ensure!( err == windows_sys::Win32::Foundation::ERROR_SUCCESS, "RegGetValueA failed" ); assert!(len <= CAPACITY as u32, "Buffer overflow caught"); let buffer: &mut [u8] = unsafe { std::mem::MaybeUninit::slice_assume_init_mut(&mut buffer[..len as usize]) }; let path = path .as_os_str() .to_str() .ok_or_else(|| miette!("Path is not ASCII"))?; ensure!(path.is_ascii(), "Path is not ASCII"); let path_b = path.as_bytes(); let contains = buffer.split(|b| *b == ';' as u8).any(|item| item == path_b); if contains { println!("Not adding {path} because it is already in the PATH"); } else { let mut buffer = Vec::from(buffer); if buffer.is_empty() { buffer.push(';' as u8); } buffer.extend(path_b); let mut hkey = std::mem::MaybeUninit::uninit(); let err = unsafe { windows_sys::Win32::System::Registry::RegOpenKeyExA( HKEY, SUBKEY.as_ptr(), 0, windows_sys::Win32::System::Registry::KEY_SET_VALUE, hkey.as_mut_ptr(), ) }; ensure!( err == windows_sys::Win32::Foundation::ERROR_SUCCESS, "RegOpenKeyExA failed" ); // SAFETY: we just opened the key (which sets the handle) and checked for errors. let hkey = unsafe { hkey.assume_init() }; ensure!(buffer.len() < u32::MAX as usize, "Buffer is too large"); let err = unsafe { windows_sys::Win32::System::Registry::RegSetValueExA( hkey, VALUE.as_ptr(), 0, TYPE, buffer.as_ptr(), buffer.len() as u32, ) }; ensure!( err == windows_sys::Win32::Foundation::ERROR_SUCCESS, "RegSetValueExA failed" ); let err = unsafe { windows_sys::Win32::System::Registry::RegCloseKey(hkey) }; ensure!( err == windows_sys::Win32::Foundation::ERROR_SUCCESS, "RegCloseKey failed" ); } } } println!("-> Done."); Ok(()) } } // this part needed to be isolated to shut work around an async lowering recursion bug. #[inline] unsafe fn spawn_parallel<'a, 'b, 'c>( ctx: &'b Context, sequences: &'b [Pipeline<'c>], ) -> async_scoped::Scope<'a, Result<()>, async_scoped::Tokio> where 'b: 'a, 'c: 'a, { async_scoped::Scope::scope(|scope| { for seq in sequences { scope.spawn(seq.invoke(ctx)); } }) .0 } #[inline] async fn invoke_parallel<'a>(ctx: &Context, sequences: &[Pipeline<'a>]) -> Result<()> { // SAFETY: we immediately collect and block on it with no possibility of failure. let mut results = unsafe { spawn_parallel(ctx, sequences) }; let results = tokio::task::block_in_place(|| { tokio::runtime::Builder::new_current_thread() .build() .unwrap() .block_on(results.collect()) }); let results = results.into_iter().map(|result| { match result.into_diagnostic().wrap_err("Pipeline fork error.") { Ok(Ok(t)) => Ok(t), Ok(Err(e)) => Err(e), Err(e) => Err(e), } }); results.collect::>() } #[derive(Debug, Clone, Copy)] pub enum RemoteResource<'a> { Url(&'a str), GitHubArtifact { /// In the form `org/repo` repo: &'a str, /// Artifact name pattern. /// /// Any of the fields in [`GitHubReleaseMeta`] may be injected via handlebar syntax: /// `{{field_name}}`. /// with `v` prefix stripped if present. pattern: &'a str, }, } #[derive(Debug, Clone)] pub enum ShortcutTarget<'a> { /// An executable shortcut. Executable { /// The executable the shortcut should open. file: Cow<'a, Path>, /// Arguments to the executable. args: &'a str, }, /// A file or folder shortcut. Please use [`Self::Executable`] for shortcuts to binaries. Path { path: Cow<'a, Path> }, } #[derive(Debug, Clone, Serialize, Deserialize)] struct GitHubRelease<'a> { #[serde(flatten)] meta: GitHubReleaseMeta<'a>, //reactions: todo!(), assets: Vec, } #[derive(Debug, Clone, Serialize, Deserialize, ramhorns::Content)] struct GitHubReleaseMeta<'a> { url: String, assets_url: String, upload_url: String, html_url: String, id: u64, //author: todo!(), node_id: String, tag_name: String, #[serde(skip_deserializing, default = "empty_str")] tag_name_strip_prefix: &'a str, target_commitish: String, name: String, draft: bool, prerelease: bool, created_at: String, published_at: String, tarball_url: String, zipball_url: String, body: String, } #[derive(Debug, Clone, Serialize, Deserialize)] struct GitHubReleaseAsset { url: String, id: u64, node_id: String, name: String, label: Option, //uploader: todo!(), content_type: String, state: String, size: u64, download_count: u64, created_at: String, updated_at: String, browser_download_url: String, } #[inline(always)] const fn empty_str<'a>() -> &'a str { "" } async fn mkdir_all(path: impl AsRef) -> Result<()> { tokio::fs::DirBuilder::new() .recursive(true) .create(path) .await .into_diagnostic() .wrap_err("Creating directory and any missing parents failed.") } async fn fetch_latest_release<'a, 'b>( reqwest: &'b reqwest::Client, repo: &'b str, ) -> Result> { const FETCH_META_ERROR_MSG: &'static str = "Fetching the latest release metadata from GitHub failed."; let url = url::Url::parse(&format!( "https://api.github.com/repos/{repo}/releases/latest" )) .into_diagnostic() .wrap_err("Invalid GitHub repo for download step.")?; let resp = reqwest .get(url) .send() .await .into_diagnostic() .wrap_err(FETCH_META_ERROR_MSG)?; let body = resp .text() .await .into_diagnostic() .wrap_err(FETCH_META_ERROR_MSG)?; let release: GitHubRelease = serde_json::from_str(&body) .into_diagnostic() .wrap_err_with(|| format!("{}: {}", FETCH_META_ERROR_MSG, body))?; Ok(release) } #[cfg(test)] mod tests { #[test] fn decode_latest_release() { tokio::runtime::Builder::new_current_thread() .enable_io() .build() .unwrap() .block_on(async move { let body = super::fetch_latest_release( &reqwest::Client::new(), "notepad-plus-plus/notepad-plus-plus", ) .await .unwrap(); assert_eq!(body.meta.draft, false); }); } }