use core::fmt::Display; use core::future::Future; use core::pin::Pin; use tide::Next; use tide::Request; #[derive(Clone, Debug, PartialEq, Eq)] pub enum Directive<'a> { ChildSrc(Vec<&'a str>), ConnectSrc(Vec<&'a str>), DefaultSrc(Vec<&'a str>), FontSrc(Vec<&'a str>), FrameSrc(Vec<&'a str>), ImgSrc(Vec<&'a str>), ManifestSrc(Vec<&'a str>), MediaSrc(Vec<&'a str>), ObjectSrc(Vec<&'a str>), PrefetchSrc(Vec<&'a str>), ScriptSrc(Vec<&'a str>), ScriptSrcElem(Vec<&'a str>), ScriptSrcAttr(Vec<&'a str>), StyleSrc(Vec<&'a str>), StyleSrcElem(Vec<&'a str>), StyleSrcAttr(Vec<&'a str>), WorkerSrc(Vec<&'a str>), BaseUri(Vec<&'a str>), PluginTypes(Vec<&'a str>), Sandbox(Sandbox), FormAction(Vec<&'a str>), FrameAncestors(Vec<&'a str>), NavigateTo(Vec<&'a str>), ReportTo(String), RequireTrustedTypesFor(TrustedTypesTarget), TrustedTypes(TrustedTypes<'a>), UpgradeInsecureRequests, } impl<'a> Display for Directive<'a> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::result::Result<(), std::fmt::Error> { match self { Self::ChildSrc(sources) => write!(f, "{}", sources.join(" ")), Self::ConnectSrc(sources) => write!(f, "{}", sources.join(" ")), Self::DefaultSrc(sources) => write!(f, "{}", sources.join(" ")), Self::FontSrc(sources) => write!(f, "{}", sources.join(" ")), Self::FrameSrc(sources) => write!(f, "{}", sources.join(" ")), Self::ImgSrc(sources) => write!(f, "{}", sources.join(" ")), Self::ManifestSrc(sources) => write!(f, "{}", sources.join(" ")), Self::MediaSrc(sources) => write!(f, "{}", sources.join(" ")), Self::ObjectSrc(sources) => write!(f, "{}", sources.join(" ")), Self::PrefetchSrc(sources) => write!(f, "{}", sources.join(" ")), Self::ScriptSrc(sources) => write!(f, "{}", sources.join(" ")), Self::ScriptSrcElem(sources) => write!(f, "{}", sources.join(" ")), Self::ScriptSrcAttr(sources) => write!(f, "{}", sources.join(" ")), Self::StyleSrc(sources) => write!(f, "{}", sources.join(" ")), Self::StyleSrcElem(sources) => write!(f, "{}", sources.join(" ")), Self::StyleSrcAttr(sources) => write!(f, "{}", sources.join(" ")), Self::WorkerSrc(sources) => write!(f, "{}", sources.join(" ")), Self::BaseUri(base_uris) => write!(f, "{}", base_uris.join(" ")), Self::PluginTypes(plugin_types) => write!(f, "{}", plugin_types.join(" ")), Self::Sandbox(sandbox) => write!(f, "{}", sandbox), Self::FormAction(form_actions) => write!(f, "{}", form_actions.join(" ")), Self::FrameAncestors(frame_ancestors) => write!(f, "{}", frame_ancestors.join(" ")), Self::NavigateTo(targets) => write!(f, "{}", targets.join(" ")), Self::ReportTo(report_url) => write!(f, "report-to {}", report_url), Self::RequireTrustedTypesFor(target) => { write!(f, "require-trusted-types-for {}", target) } Self::TrustedTypes(trusted_types) => write!(f, "trusted-types {}", trusted_types), Self::UpgradeInsecureRequests => write!(f, "upgrade-insecure-requests"), } } } #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum Sandbox { AllowDownloadsWithoutUserActivation, AllowForms, AllowModals, AllowOrienationLock, AllowPointerLock, AllowPopups, AllowPopupsToEscapeSandbox, AllowPresentation, AllowSameOrigin, AllowScripts, AllowStorageAccessByUserActivation, AllowTopNavigation, AllowTopNavigationByUserActivation, } impl Display for Sandbox { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::result::Result<(), std::fmt::Error> { write!( f, "{}", match self { Self::AllowDownloadsWithoutUserActivation => "allow-downloads-without-user-activation", Self::AllowForms => "allow-forms", Self::AllowModals => "allow-modals", Self::AllowOrienationLock => "allow-orienation-lock", Self::AllowPointerLock => "allow-pointer-lock", Self::AllowPopups => "allow-popups", Self::AllowPopupsToEscapeSandbox => "allow-popups-to-escape-sandbox", Self::AllowPresentation => "allow-presentation", Self::AllowSameOrigin => "allow-same-origin", Self::AllowScripts => "allow-scripts", Self::AllowStorageAccessByUserActivation => "allow-storage-access-by-user-activation", Self::AllowTopNavigation => "allow-top-navigation", Self::AllowTopNavigationByUserActivation => "allow-top-navigation-by-user-activation", } ) } } #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum TrustedTypesTarget { Script, } impl Display for TrustedTypesTarget { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::result::Result<(), std::fmt::Error> { write!(f, "'script'") } } #[derive(Clone, Debug, PartialEq, Eq)] pub enum TrustedTypes<'a> { None, Values { policy_names: Vec<&'a str>, allow_duplicates: bool, }, } impl<'a> Display for TrustedTypes<'a> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::result::Result<(), std::fmt::Error> { match self { Self::None => write!(f, "'none'"), Self::Values { policy_names, allow_duplicates: true, } => write!(f, "{} 'allow-duplicates'", policy_names.join(" ")), Self::Values { policy_names, allow_duplicates: false, } => write!(f, "{}", policy_names.join(" ")), } } } pub fn csp_middleware<'a, State: 'static + Clone + Send + Sync>( directives: &'a [Directive<'a>], ) -> impl Fn(Request, Next<'a, State>) -> Pin + Send + 'a>> { move |req: Request, next: Next<'a, State>| { Box::pin(async move { let mut res = next.run(req).await; res.insert_header( "Content-Security-Policy", format!( "{}", directives .iter() .fold(String::new(), |a, b| a + &b.to_string() + ";") ), ); Ok(res) }) } }