Fix validation for issuers

This commit is contained in:
Vincent Prouillet 2021-12-03 20:42:05 +01:00
parent 5ed8af440c
commit 356fac075d
1 changed files with 72 additions and 18 deletions

View File

@ -28,6 +28,10 @@ use crate::errors::{new_error, ErrorKind, Result};
/// ```
#[derive(Debug, Clone, PartialEq)]
pub struct Validation {
/// Which claims are required to be present before starting the validation
///
/// Defaults to `{"exp"}`
pub required_claims: HashSet<String>,
/// Add some leeway (in seconds) to the `exp`, `iat` and `nbf` validation to
/// account for clock skew.
///
@ -90,6 +94,13 @@ impl Validation {
self.iss = Some(items.iter().map(|x| x.to_string()).collect())
}
/// Which claims are required to be present for this JWT to be considered valid
/// This is not restricted to the claims from the JWT spec, you can add your own custom ones.
/// The simple usage is `set_required_claims(&["exp", "my_claim"])`
pub fn set_required_claims<T: ToString>(&mut self, items: &[T]) {
self.required_claims = items.iter().map(|x| x.to_string()).collect();
}
/// Whether to validate the JWT cryptographic signature
/// Very insecure to turn that off, only do it if you know what you're doing.
/// With this flag turned off, you should not trust any of the values of the claims.
@ -100,7 +111,11 @@ impl Validation {
impl Default for Validation {
fn default() -> Self {
let mut required_claims = HashSet::with_capacity(1);
required_claims.insert("exp".to_owned());
Validation {
required_claims,
algorithms: vec![Algorithm::HS256],
leeway: 60,
@ -131,7 +146,7 @@ pub(crate) struct ClaimsForValidation<'a> {
#[serde(borrow)]
sub: TryParse<Cow<'a, str>>,
#[serde(borrow)]
iss: TryParse<Cow<'a, str>>,
iss: TryParse<Issuer<'a>>,
#[serde(borrow)]
aud: TryParse<Audience<'a>>,
}
@ -164,6 +179,13 @@ enum Audience<'a> {
Single(#[serde(borrow)] Cow<'a, str>),
Multiple(#[serde(borrow)] HashSet<BorrowedCowIfPossible<'a>>),
}
#[derive(Deserialize)]
#[serde(untagged)]
enum Issuer<'a> {
Single(#[serde(borrow)] Cow<'a, str>),
Multiple(#[serde(borrow)] HashSet<BorrowedCowIfPossible<'a>>),
}
/// Usually #[serde(borrow)] on `Cow` enables deserializing with no allocations where
/// possible (no escapes in the original str) but it does not work on e.g. `HashSet<Cow<str>>`
/// We use this struct in this case.
@ -175,7 +197,20 @@ impl std::borrow::Borrow<str> for BorrowedCowIfPossible<'_> {
}
}
fn is_subset(reference: &HashSet<String>, given: &HashSet<BorrowedCowIfPossible<'_>>) -> bool {
// Check that intersection is non-empty, favoring iterating on smallest
if reference.len() < given.len() {
reference.iter().any(|a| given.contains(&**a))
} else {
given.iter().any(|a| reference.contains(&*a.0))
}
}
pub(crate) fn validate(claims: ClaimsForValidation, options: &Validation) -> Result<()> {
// for required_claim in &options.required_claims {
// if matches!(claims)
// }
let now = get_current_timestamp();
if options.validate_exp
@ -197,24 +232,26 @@ pub(crate) fn validate(claims: ClaimsForValidation, options: &Validation) -> Res
}
if let Some(ref correct_iss) = options.iss {
if !matches!(claims.iss, TryParse::Parsed(iss) if correct_iss.contains(&*iss)) {
let is_valid = match claims.iss {
TryParse::Parsed(Issuer::Single(iss)) if correct_iss.contains(&*iss) => true,
TryParse::Parsed(Issuer::Multiple(iss)) => is_subset(correct_iss, &iss),
_ => false,
};
if !is_valid {
return Err(new_error(ErrorKind::InvalidIssuer));
}
}
if let Some(ref correct_aud) = options.aud {
match claims.aud {
TryParse::Parsed(Audience::Single(aud)) if correct_aud.contains(&*aud) => {}
TryParse::Parsed(Audience::Multiple(aud))
if {
// Check that intersection is non-empty, favoring iterating on smallest
if correct_aud.len() < aud.len() {
correct_aud.iter().any(|a| aud.contains(&**a))
} else {
aud.iter().any(|a| correct_aud.contains(&*a.0))
}
} => {}
_ => return Err(new_error(ErrorKind::InvalidAudience)),
let is_valid = match claims.aud {
TryParse::Parsed(Audience::Single(aud)) if correct_aud.contains(&*aud) => true,
TryParse::Parsed(Audience::Multiple(aud)) => is_subset(correct_aud, &aud),
_ => false,
};
if !is_valid {
return Err(new_error(ErrorKind::InvalidAudience));
}
}
@ -378,13 +415,21 @@ mod tests {
}
#[test]
fn iss_ok() {
let claims = json!({"iss": "Keats"});
fn iss_string_ok() {
let claims = json!({"iss": ["Keats"]});
let mut validation = Validation::new(Algorithm::HS256);
validation.validate_exp = false;
validation.set_issuer(&["Keats"]);
let res = validate(deserialize_claims(&claims), &validation);
assert!(res.is_ok());
}
#[test]
fn iss_array_of_string_ok() {
let claims = json!({"iss": ["UserA", "UserB"]});
let mut validation = Validation::new(Algorithm::HS256);
validation.validate_exp = false;
validation.set_issuer(&["UserA", "UserB"]);
let res = validate(deserialize_claims(&claims), &validation);
assert!(res.is_ok());
}
@ -557,7 +602,16 @@ mod tests {
validation.set_audience(&["my-googleclientid1234.apps.googleusercontent.com"]);
let res = validate(deserialize_claims(&claims), &validation);
println!("{:?}", res);
assert!(res.is_ok());
}
#[test]
fn required_claims_complains_if_field_not_found() {
let claims = json!({"aud": "my-googleclientid1234.apps.googleusercontent.com"});
let mut validation = Validation::new(Algorithm::HS256);
validation.validate_exp = false;
let res = validate(deserialize_claims(&claims), &validation);
assert!(res.is_err());
}
}