Add required_spec_claims (#225)

This commit is contained in:
Vincent Prouillet 2021-12-15 19:54:54 +01:00
parent 356fac075d
commit 255c740e47
3 changed files with 60 additions and 25 deletions

View File

@ -13,6 +13,7 @@
- Allow float values for `exp` and `nbf`, yes it's in the spec... floats will be rounded and converted to u64 - Allow float values for `exp` and `nbf`, yes it's in the spec... floats will be rounded and converted to u64
- Error now implements Clone/Eq - Error now implements Clone/Eq
- Change default leeway from 0s to 60s - Change default leeway from 0s to 60s
- Add `Validation::require_spec_claims` to validate presence of the spec claims
## 7.2.0 (2020-06-30) ## 7.2.0 (2020-06-30)

View File

@ -51,6 +51,8 @@ pub enum ErrorKind {
InvalidKeyFormat, InvalidKeyFormat,
// Validation errors // Validation errors
/// When a claim required by the validation is not present
MissingRequiredClaim(String),
/// When a tokens `exp` claim indicates that it has expired /// When a tokens `exp` claim indicates that it has expired
ExpiredSignature, ExpiredSignature,
/// When a tokens `iss` claim does not match the expected issuer /// When a tokens `iss` claim does not match the expected issuer
@ -88,6 +90,7 @@ impl StdError for Error {
ErrorKind::InvalidRsaKey(_) => None, ErrorKind::InvalidRsaKey(_) => None,
ErrorKind::ExpiredSignature => None, ErrorKind::ExpiredSignature => None,
ErrorKind::MissingAlgorithm => None, ErrorKind::MissingAlgorithm => None,
ErrorKind::MissingRequiredClaim(_) => None,
ErrorKind::InvalidIssuer => None, ErrorKind::InvalidIssuer => None,
ErrorKind::InvalidAudience => None, ErrorKind::InvalidAudience => None,
ErrorKind::InvalidSubject => None, ErrorKind::InvalidSubject => None,
@ -119,6 +122,7 @@ impl fmt::Display for Error {
| ErrorKind::InvalidAlgorithm | ErrorKind::InvalidAlgorithm
| ErrorKind::InvalidKeyFormat | ErrorKind::InvalidKeyFormat
| ErrorKind::InvalidAlgorithmName => write!(f, "{:?}", self.0), | ErrorKind::InvalidAlgorithmName => write!(f, "{:?}", self.0),
ErrorKind::MissingRequiredClaim(ref c) => write!(f, "Missing required claim: {}", c),
ErrorKind::InvalidRsaKey(ref msg) => write!(f, "RSA key invalid: {}", msg), ErrorKind::InvalidRsaKey(ref msg) => write!(f, "RSA key invalid: {}", msg),
ErrorKind::Json(ref err) => write!(f, "JSON error: {}", err), ErrorKind::Json(ref err) => write!(f, "JSON error: {}", err),
ErrorKind::Utf8(ref err) => write!(f, "UTF-8 error: {}", err), ErrorKind::Utf8(ref err) => write!(f, "UTF-8 error: {}", err),

View File

@ -28,10 +28,13 @@ use crate::errors::{new_error, ErrorKind, Result};
/// ``` /// ```
#[derive(Debug, Clone, PartialEq)] #[derive(Debug, Clone, PartialEq)]
pub struct Validation { pub struct Validation {
/// Which claims are required to be present before starting the validation /// Which claims are required to be present before starting the validation.
/// This does not interact with the various `validate_*`. If you remove `exp` from that list, you still need
/// to set `validate_exp` to `false`.
/// The only value that will be used are "exp", "nbf", "aud", "iss", "sub". Anything else will be ignored.
/// ///
/// Defaults to `{"exp"}` /// Defaults to `{"exp"}`
pub required_claims: HashSet<String>, pub required_spec_claims: HashSet<String>,
/// Add some leeway (in seconds) to the `exp`, `iat` and `nbf` validation to /// Add some leeway (in seconds) to the `exp`, `iat` and `nbf` validation to
/// account for clock skew. /// account for clock skew.
/// ///
@ -94,11 +97,13 @@ impl Validation {
self.iss = Some(items.iter().map(|x| x.to_string()).collect()) self.iss = Some(items.iter().map(|x| x.to_string()).collect())
} }
/// Which claims are required to be present for this JWT to be considered valid /// Which claims are required to be present for this JWT to be considered valid.
/// This is not restricted to the claims from the JWT spec, you can add your own custom ones. /// The only values that will be considered are "exp", "nbf", "aud", "iss", "sub".
/// The simple usage is `set_required_claims(&["exp", "my_claim"])` /// The simple usage is `set_required_claims(&["exp", "nbf"])`.
pub fn set_required_claims<T: ToString>(&mut self, items: &[T]) { /// If you want to have an empty set, do not use this function - set an empty set on the struct
self.required_claims = items.iter().map(|x| x.to_string()).collect(); /// param directly.
pub fn set_required_spec_claims<T: ToString>(&mut self, items: &[T]) {
self.required_spec_claims = items.iter().map(|x| x.to_string()).collect();
} }
/// Whether to validate the JWT cryptographic signature /// Whether to validate the JWT cryptographic signature
@ -115,7 +120,7 @@ impl Default for Validation {
required_claims.insert("exp".to_owned()); required_claims.insert("exp".to_owned());
Validation { Validation {
required_claims, required_spec_claims: required_claims,
algorithms: vec![Algorithm::HS256], algorithms: vec![Algorithm::HS256],
leeway: 60, leeway: 60,
@ -186,6 +191,7 @@ enum Issuer<'a> {
Single(#[serde(borrow)] Cow<'a, str>), Single(#[serde(borrow)] Cow<'a, str>),
Multiple(#[serde(borrow)] HashSet<BorrowedCowIfPossible<'a>>), Multiple(#[serde(borrow)] HashSet<BorrowedCowIfPossible<'a>>),
} }
/// Usually #[serde(borrow)] on `Cow` enables deserializing with no allocations where /// Usually #[serde(borrow)] on `Cow` enables deserializing with no allocations where
/// possible (no escapes in the original str) but it does not work on e.g. `HashSet<Cow<str>>` /// possible (no escapes in the original str) but it does not work on e.g. `HashSet<Cow<str>>`
/// We use this struct in this case. /// We use this struct in this case.
@ -207,12 +213,23 @@ fn is_subset(reference: &HashSet<String>, given: &HashSet<BorrowedCowIfPossible<
} }
pub(crate) fn validate(claims: ClaimsForValidation, options: &Validation) -> Result<()> { pub(crate) fn validate(claims: ClaimsForValidation, options: &Validation) -> Result<()> {
// for required_claim in &options.required_claims {
// if matches!(claims)
// }
let now = get_current_timestamp(); let now = get_current_timestamp();
for required_claim in &options.required_spec_claims {
let present = match required_claim.as_str() {
"exp" => matches!(claims.exp, TryParse::Parsed(_)),
"sub" => matches!(claims.sub, TryParse::Parsed(_)),
"iss" => matches!(claims.iss, TryParse::Parsed(_)),
"aud" => matches!(claims.aud, TryParse::Parsed(_)),
"nbf" => matches!(claims.nbf, TryParse::Parsed(_)),
_ => continue,
};
if !present {
return Err(new_error(ErrorKind::MissingRequiredClaim(required_claim.clone())));
}
}
if options.validate_exp if options.validate_exp
&& !matches!(claims.exp, TryParse::Parsed(exp) if exp >= now-options.leeway) && !matches!(claims.exp, TryParse::Parsed(exp) if exp >= now-options.leeway)
{ {
@ -304,6 +321,7 @@ mod tests {
use crate::errors::ErrorKind; use crate::errors::ErrorKind;
use crate::Algorithm; use crate::Algorithm;
use std::collections::HashSet;
fn deserialize_claims(claims: &serde_json::Value) -> ClaimsForValidation { fn deserialize_claims(claims: &serde_json::Value) -> ClaimsForValidation {
serde::Deserialize::deserialize(claims).unwrap() serde::Deserialize::deserialize(claims).unwrap()
@ -360,18 +378,17 @@ mod tests {
#[test] #[test]
fn validation_called_even_if_field_is_empty() { fn validation_called_even_if_field_is_empty() {
let claims = json!({}); let claims = json!({});
let res = validate(deserialize_claims(&claims), &Validation::new(Algorithm::HS256)); let mut validation = Validation::new(Algorithm::HS256);
assert!(res.is_err()); validation.required_spec_claims = HashSet::new();
match res.unwrap_err().kind() { let res = validate(deserialize_claims(&claims), &validation).unwrap_err();
ErrorKind::ExpiredSignature => (), assert_eq!(res.kind(), &ErrorKind::ExpiredSignature);
_ => unreachable!(),
};
} }
#[test] #[test]
fn nbf_in_past_ok() { fn nbf_in_past_ok() {
let claims = json!({ "nbf": get_current_timestamp() - 10000 }); let claims = json!({ "nbf": get_current_timestamp() - 10000 });
let mut validation = Validation::new(Algorithm::HS256); let mut validation = Validation::new(Algorithm::HS256);
validation.required_spec_claims = HashSet::new();
validation.validate_exp = false; validation.validate_exp = false;
validation.validate_nbf = true; validation.validate_nbf = true;
let res = validate(deserialize_claims(&claims), &validation); let res = validate(deserialize_claims(&claims), &validation);
@ -382,6 +399,7 @@ mod tests {
fn nbf_float_in_past_ok() { fn nbf_float_in_past_ok() {
let claims = json!({ "nbf": (get_current_timestamp() as f64) - 10000.1234 }); let claims = json!({ "nbf": (get_current_timestamp() as f64) - 10000.1234 });
let mut validation = Validation::new(Algorithm::HS256); let mut validation = Validation::new(Algorithm::HS256);
validation.required_spec_claims = HashSet::new();
validation.validate_exp = false; validation.validate_exp = false;
validation.validate_nbf = true; validation.validate_nbf = true;
let res = validate(deserialize_claims(&claims), &validation); let res = validate(deserialize_claims(&claims), &validation);
@ -392,6 +410,7 @@ mod tests {
fn nbf_in_future_fails() { fn nbf_in_future_fails() {
let claims = json!({ "nbf": get_current_timestamp() + 100000 }); let claims = json!({ "nbf": get_current_timestamp() + 100000 });
let mut validation = Validation::new(Algorithm::HS256); let mut validation = Validation::new(Algorithm::HS256);
validation.required_spec_claims = HashSet::new();
validation.validate_exp = false; validation.validate_exp = false;
validation.validate_nbf = true; validation.validate_nbf = true;
let res = validate(deserialize_claims(&claims), &validation); let res = validate(deserialize_claims(&claims), &validation);
@ -407,6 +426,7 @@ mod tests {
fn nbf_in_future_but_in_leeway_ok() { fn nbf_in_future_but_in_leeway_ok() {
let claims = json!({ "nbf": get_current_timestamp() + 500 }); let claims = json!({ "nbf": get_current_timestamp() + 500 });
let mut validation = Validation::new(Algorithm::HS256); let mut validation = Validation::new(Algorithm::HS256);
validation.required_spec_claims = HashSet::new();
validation.validate_exp = false; validation.validate_exp = false;
validation.validate_nbf = true; validation.validate_nbf = true;
validation.leeway = 1000 * 60; validation.leeway = 1000 * 60;
@ -418,6 +438,7 @@ mod tests {
fn iss_string_ok() { fn iss_string_ok() {
let claims = json!({"iss": ["Keats"]}); let claims = json!({"iss": ["Keats"]});
let mut validation = Validation::new(Algorithm::HS256); let mut validation = Validation::new(Algorithm::HS256);
validation.required_spec_claims = HashSet::new();
validation.validate_exp = false; validation.validate_exp = false;
validation.set_issuer(&["Keats"]); validation.set_issuer(&["Keats"]);
let res = validate(deserialize_claims(&claims), &validation); let res = validate(deserialize_claims(&claims), &validation);
@ -428,6 +449,7 @@ mod tests {
fn iss_array_of_string_ok() { fn iss_array_of_string_ok() {
let claims = json!({"iss": ["UserA", "UserB"]}); let claims = json!({"iss": ["UserA", "UserB"]});
let mut validation = Validation::new(Algorithm::HS256); let mut validation = Validation::new(Algorithm::HS256);
validation.required_spec_claims = HashSet::new();
validation.validate_exp = false; validation.validate_exp = false;
validation.set_issuer(&["UserA", "UserB"]); validation.set_issuer(&["UserA", "UserB"]);
let res = validate(deserialize_claims(&claims), &validation); let res = validate(deserialize_claims(&claims), &validation);
@ -439,6 +461,7 @@ mod tests {
let claims = json!({"iss": "Hacked"}); let claims = json!({"iss": "Hacked"});
let mut validation = Validation::new(Algorithm::HS256); let mut validation = Validation::new(Algorithm::HS256);
validation.required_spec_claims = HashSet::new();
validation.validate_exp = false; validation.validate_exp = false;
validation.set_issuer(&["Keats"]); validation.set_issuer(&["Keats"]);
let res = validate(deserialize_claims(&claims), &validation); let res = validate(deserialize_claims(&claims), &validation);
@ -455,6 +478,7 @@ mod tests {
let claims = json!({}); let claims = json!({});
let mut validation = Validation::new(Algorithm::HS256); let mut validation = Validation::new(Algorithm::HS256);
validation.required_spec_claims = HashSet::new();
validation.validate_exp = false; validation.validate_exp = false;
validation.set_issuer(&["Keats"]); validation.set_issuer(&["Keats"]);
let res = validate(deserialize_claims(&claims), &validation); let res = validate(deserialize_claims(&claims), &validation);
@ -469,6 +493,7 @@ mod tests {
fn sub_ok() { fn sub_ok() {
let claims = json!({"sub": "Keats"}); let claims = json!({"sub": "Keats"});
let mut validation = Validation::new(Algorithm::HS256); let mut validation = Validation::new(Algorithm::HS256);
validation.required_spec_claims = HashSet::new();
validation.validate_exp = false; validation.validate_exp = false;
validation.sub = Some("Keats".to_owned()); validation.sub = Some("Keats".to_owned());
let res = validate(deserialize_claims(&claims), &validation); let res = validate(deserialize_claims(&claims), &validation);
@ -479,6 +504,7 @@ mod tests {
fn sub_not_matching_fails() { fn sub_not_matching_fails() {
let claims = json!({"sub": "Hacked"}); let claims = json!({"sub": "Hacked"});
let mut validation = Validation::new(Algorithm::HS256); let mut validation = Validation::new(Algorithm::HS256);
validation.required_spec_claims = HashSet::new();
validation.validate_exp = false; validation.validate_exp = false;
validation.sub = Some("Keats".to_owned()); validation.sub = Some("Keats".to_owned());
let res = validate(deserialize_claims(&claims), &validation); let res = validate(deserialize_claims(&claims), &validation);
@ -495,6 +521,7 @@ mod tests {
let claims = json!({}); let claims = json!({});
let mut validation = Validation::new(Algorithm::HS256); let mut validation = Validation::new(Algorithm::HS256);
validation.validate_exp = false; validation.validate_exp = false;
validation.required_spec_claims = HashSet::new();
validation.sub = Some("Keats".to_owned()); validation.sub = Some("Keats".to_owned());
let res = validate(deserialize_claims(&claims), &validation); let res = validate(deserialize_claims(&claims), &validation);
assert!(res.is_err()); assert!(res.is_err());
@ -510,6 +537,7 @@ mod tests {
let claims = json!({"aud": ["Everyone"]}); let claims = json!({"aud": ["Everyone"]});
let mut validation = Validation::new(Algorithm::HS256); let mut validation = Validation::new(Algorithm::HS256);
validation.validate_exp = false; validation.validate_exp = false;
validation.required_spec_claims = HashSet::new();
validation.set_audience(&["Everyone"]); validation.set_audience(&["Everyone"]);
let res = validate(deserialize_claims(&claims), &validation); let res = validate(deserialize_claims(&claims), &validation);
assert!(res.is_ok()); assert!(res.is_ok());
@ -520,6 +548,7 @@ mod tests {
let claims = json!({"aud": ["UserA", "UserB"]}); let claims = json!({"aud": ["UserA", "UserB"]});
let mut validation = Validation::new(Algorithm::HS256); let mut validation = Validation::new(Algorithm::HS256);
validation.validate_exp = false; validation.validate_exp = false;
validation.required_spec_claims = HashSet::new();
validation.set_audience(&["UserA", "UserB"]); validation.set_audience(&["UserA", "UserB"]);
let res = validate(deserialize_claims(&claims), &validation); let res = validate(deserialize_claims(&claims), &validation);
assert!(res.is_ok()); assert!(res.is_ok());
@ -530,6 +559,7 @@ mod tests {
let claims = json!({"aud": ["Everyone"]}); let claims = json!({"aud": ["Everyone"]});
let mut validation = Validation::new(Algorithm::HS256); let mut validation = Validation::new(Algorithm::HS256);
validation.validate_exp = false; validation.validate_exp = false;
validation.required_spec_claims = HashSet::new();
validation.set_audience(&["UserA", "UserB"]); validation.set_audience(&["UserA", "UserB"]);
let res = validate(deserialize_claims(&claims), &validation); let res = validate(deserialize_claims(&claims), &validation);
assert!(res.is_err()); assert!(res.is_err());
@ -545,6 +575,7 @@ mod tests {
let claims = json!({"aud": ["Everyone"]}); let claims = json!({"aud": ["Everyone"]});
let mut validation = Validation::new(Algorithm::HS256); let mut validation = Validation::new(Algorithm::HS256);
validation.validate_exp = false; validation.validate_exp = false;
validation.required_spec_claims = HashSet::new();
validation.set_audience(&["None"]); validation.set_audience(&["None"]);
let res = validate(deserialize_claims(&claims), &validation); let res = validate(deserialize_claims(&claims), &validation);
assert!(res.is_err()); assert!(res.is_err());
@ -560,6 +591,7 @@ mod tests {
let claims = json!({}); let claims = json!({});
let mut validation = Validation::new(Algorithm::HS256); let mut validation = Validation::new(Algorithm::HS256);
validation.validate_exp = false; validation.validate_exp = false;
validation.required_spec_claims = HashSet::new();
validation.set_audience(&["None"]); validation.set_audience(&["None"]);
let res = validate(deserialize_claims(&claims), &validation); let res = validate(deserialize_claims(&claims), &validation);
assert!(res.is_err()); assert!(res.is_err());
@ -599,6 +631,7 @@ mod tests {
aud_hashset.insert(aud); aud_hashset.insert(aud);
let mut validation = Validation::new(Algorithm::HS256); let mut validation = Validation::new(Algorithm::HS256);
validation.validate_exp = false; validation.validate_exp = false;
validation.required_spec_claims = HashSet::new();
validation.set_audience(&["my-googleclientid1234.apps.googleusercontent.com"]); validation.set_audience(&["my-googleclientid1234.apps.googleusercontent.com"]);
let res = validate(deserialize_claims(&claims), &validation); let res = validate(deserialize_claims(&claims), &validation);
@ -606,12 +639,9 @@ mod tests {
} }
#[test] #[test]
fn required_claims_complains_if_field_not_found() { fn errors_when_required_claim_is_missing() {
let claims = json!({"aud": "my-googleclientid1234.apps.googleusercontent.com"}); let claims = json!({});
let mut validation = Validation::new(Algorithm::HS256); let res = validate(deserialize_claims(&claims), &Validation::default()).unwrap_err();
validation.validate_exp = false; assert_eq!(res.kind(), &ErrorKind::MissingRequiredClaim("exp".to_owned()));
let res = validate(deserialize_claims(&claims), &validation);
assert!(res.is_err());
} }
} }