From 356fac075d265392d14a423abab52d7c71154134 Mon Sep 17 00:00:00 2001 From: Vincent Prouillet Date: Fri, 3 Dec 2021 20:42:05 +0100 Subject: [PATCH] Fix validation for issuers --- src/validation.rs | 90 +++++++++++++++++++++++++++++++++++++---------- 1 file changed, 72 insertions(+), 18 deletions(-) diff --git a/src/validation.rs b/src/validation.rs index 20f213b..9ab40d0 100644 --- a/src/validation.rs +++ b/src/validation.rs @@ -28,6 +28,10 @@ use crate::errors::{new_error, ErrorKind, Result}; /// ``` #[derive(Debug, Clone, PartialEq)] pub struct Validation { + /// Which claims are required to be present before starting the validation + /// + /// Defaults to `{"exp"}` + pub required_claims: HashSet, /// Add some leeway (in seconds) to the `exp`, `iat` and `nbf` validation to /// account for clock skew. /// @@ -90,6 +94,13 @@ impl Validation { self.iss = Some(items.iter().map(|x| x.to_string()).collect()) } + /// Which claims are required to be present for this JWT to be considered valid + /// This is not restricted to the claims from the JWT spec, you can add your own custom ones. + /// The simple usage is `set_required_claims(&["exp", "my_claim"])` + pub fn set_required_claims(&mut self, items: &[T]) { + self.required_claims = items.iter().map(|x| x.to_string()).collect(); + } + /// Whether to validate the JWT cryptographic signature /// Very insecure to turn that off, only do it if you know what you're doing. /// With this flag turned off, you should not trust any of the values of the claims. @@ -100,7 +111,11 @@ impl Validation { impl Default for Validation { fn default() -> Self { + let mut required_claims = HashSet::with_capacity(1); + required_claims.insert("exp".to_owned()); + Validation { + required_claims, algorithms: vec![Algorithm::HS256], leeway: 60, @@ -131,7 +146,7 @@ pub(crate) struct ClaimsForValidation<'a> { #[serde(borrow)] sub: TryParse>, #[serde(borrow)] - iss: TryParse>, + iss: TryParse>, #[serde(borrow)] aud: TryParse>, } @@ -164,6 +179,13 @@ enum Audience<'a> { Single(#[serde(borrow)] Cow<'a, str>), Multiple(#[serde(borrow)] HashSet>), } + +#[derive(Deserialize)] +#[serde(untagged)] +enum Issuer<'a> { + Single(#[serde(borrow)] Cow<'a, str>), + Multiple(#[serde(borrow)] HashSet>), +} /// Usually #[serde(borrow)] on `Cow` enables deserializing with no allocations where /// possible (no escapes in the original str) but it does not work on e.g. `HashSet>` /// We use this struct in this case. @@ -175,7 +197,20 @@ impl std::borrow::Borrow for BorrowedCowIfPossible<'_> { } } +fn is_subset(reference: &HashSet, given: &HashSet>) -> bool { + // Check that intersection is non-empty, favoring iterating on smallest + if reference.len() < given.len() { + reference.iter().any(|a| given.contains(&**a)) + } else { + given.iter().any(|a| reference.contains(&*a.0)) + } +} + pub(crate) fn validate(claims: ClaimsForValidation, options: &Validation) -> Result<()> { + // for required_claim in &options.required_claims { + // if matches!(claims) + // } + let now = get_current_timestamp(); if options.validate_exp @@ -197,24 +232,26 @@ pub(crate) fn validate(claims: ClaimsForValidation, options: &Validation) -> Res } if let Some(ref correct_iss) = options.iss { - if !matches!(claims.iss, TryParse::Parsed(iss) if correct_iss.contains(&*iss)) { + let is_valid = match claims.iss { + TryParse::Parsed(Issuer::Single(iss)) if correct_iss.contains(&*iss) => true, + TryParse::Parsed(Issuer::Multiple(iss)) => is_subset(correct_iss, &iss), + _ => false, + }; + + if !is_valid { return Err(new_error(ErrorKind::InvalidIssuer)); } } if let Some(ref correct_aud) = options.aud { - match claims.aud { - TryParse::Parsed(Audience::Single(aud)) if correct_aud.contains(&*aud) => {} - TryParse::Parsed(Audience::Multiple(aud)) - if { - // Check that intersection is non-empty, favoring iterating on smallest - if correct_aud.len() < aud.len() { - correct_aud.iter().any(|a| aud.contains(&**a)) - } else { - aud.iter().any(|a| correct_aud.contains(&*a.0)) - } - } => {} - _ => return Err(new_error(ErrorKind::InvalidAudience)), + let is_valid = match claims.aud { + TryParse::Parsed(Audience::Single(aud)) if correct_aud.contains(&*aud) => true, + TryParse::Parsed(Audience::Multiple(aud)) => is_subset(correct_aud, &aud), + _ => false, + }; + + if !is_valid { + return Err(new_error(ErrorKind::InvalidAudience)); } } @@ -378,13 +415,21 @@ mod tests { } #[test] - fn iss_ok() { - let claims = json!({"iss": "Keats"}); - + fn iss_string_ok() { + let claims = json!({"iss": ["Keats"]}); let mut validation = Validation::new(Algorithm::HS256); validation.validate_exp = false; validation.set_issuer(&["Keats"]); + let res = validate(deserialize_claims(&claims), &validation); + assert!(res.is_ok()); + } + #[test] + fn iss_array_of_string_ok() { + let claims = json!({"iss": ["UserA", "UserB"]}); + let mut validation = Validation::new(Algorithm::HS256); + validation.validate_exp = false; + validation.set_issuer(&["UserA", "UserB"]); let res = validate(deserialize_claims(&claims), &validation); assert!(res.is_ok()); } @@ -557,7 +602,16 @@ mod tests { validation.set_audience(&["my-googleclientid1234.apps.googleusercontent.com"]); let res = validate(deserialize_claims(&claims), &validation); - println!("{:?}", res); assert!(res.is_ok()); } + + #[test] + fn required_claims_complains_if_field_not_found() { + let claims = json!({"aud": "my-googleclientid1234.apps.googleusercontent.com"}); + let mut validation = Validation::new(Algorithm::HS256); + validation.validate_exp = false; + + let res = validate(deserialize_claims(&claims), &validation); + assert!(res.is_err()); + } }