diff --git a/src/validation.rs b/src/validation.rs index 14a0726..b189dc4 100644 --- a/src/validation.rs +++ b/src/validation.rs @@ -1,8 +1,11 @@ use std::borrow::Cow; use std::collections::HashSet; +use std::fmt; +use std::marker::PhantomData; use std::time::{SystemTime, UNIX_EPOCH}; -use serde::Deserialize; +use serde::de::{self, Visitor}; +use serde::{Deserialize, Deserializer}; use crate::algorithms::Algorithm; use crate::errors::{new_error, ErrorKind, Result}; @@ -117,7 +120,9 @@ pub fn get_current_timestamp() -> u64 { #[derive(Deserialize)] pub(crate) struct ClaimsForValidation<'a> { + #[serde(deserialize_with = "numeric_type", default)] exp: TryParse, + #[serde(deserialize_with = "numeric_type", default)] nbf: TryParse, #[serde(borrow)] sub: TryParse>, @@ -126,6 +131,7 @@ pub(crate) struct ClaimsForValidation<'a> { #[serde(borrow)] aud: TryParse>, } +#[derive(Debug)] enum TryParse { Parsed(T), FailedToParse, @@ -142,6 +148,12 @@ impl<'de, T: Deserialize<'de>> Deserialize<'de> for TryParse { }) } } +impl Default for TryParse { + fn default() -> Self { + Self::NotPresent + } +} + #[derive(Deserialize)] #[serde(untagged)] enum Audience<'a> { @@ -205,6 +217,44 @@ pub(crate) fn validate(claims: ClaimsForValidation, options: &Validation) -> Res Ok(()) } +fn numeric_type<'de, D>(deserializer: D) -> std::result::Result, D::Error> +where + D: Deserializer<'de>, +{ + struct NumericType(PhantomData TryParse>); + + impl<'de> Visitor<'de> for NumericType { + type Value = TryParse; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("A NumericType that can be reasonably coerced into a u64") + } + + fn visit_f64(self, value: f64) -> std::result::Result + where + E: de::Error, + { + if value.is_finite() && value >= 0.0 && value < (u64::MAX as f64) { + Ok(TryParse::Parsed(value.round() as u64)) + } else { + Err(serde::de::Error::custom("NumericType must be representable as a u64")) + } + } + + fn visit_u64(self, value: u64) -> std::result::Result + where + E: de::Error, + { + Ok(TryParse::Parsed(value)) + } + } + + match deserializer.deserialize_any(NumericType(PhantomData)) { + Ok(ok) => Ok(ok), + Err(_) => Ok(TryParse::FailedToParse), + } +} + #[cfg(test)] mod tests { use serde_json::json; @@ -225,6 +275,13 @@ mod tests { assert!(res.is_ok()); } + #[test] + fn exp_float_in_future_ok() { + let claims = json!({ "exp": (get_current_timestamp() as f64) + 10000.123 }); + let res = validate(deserialize_claims(&claims), &Validation::new(Algorithm::HS256)); + assert!(res.is_ok()); + } + #[test] fn exp_in_past_fails() { let claims = json!({ "exp": get_current_timestamp() - 100000 }); @@ -237,6 +294,18 @@ mod tests { }; } + #[test] + fn exp_float_in_past_fails() { + let claims = json!({ "exp": (get_current_timestamp() as f64) - 100000.1234 }); + let res = validate(deserialize_claims(&claims), &Validation::new(Algorithm::HS256)); + assert!(res.is_err()); + + match res.unwrap_err().kind() { + ErrorKind::ExpiredSignature => (), + _ => unreachable!(), + }; + } + #[test] fn exp_in_past_but_in_leeway_ok() { let claims = json!({ "exp": get_current_timestamp() - 500 }); @@ -268,6 +337,16 @@ mod tests { assert!(res.is_ok()); } + #[test] + fn nbf_float_in_past_ok() { + let claims = json!({ "nbf": (get_current_timestamp() as f64) - 10000.1234 }); + let mut validation = Validation::new(Algorithm::HS256); + validation.validate_exp = false; + validation.validate_nbf = true; + let res = validate(deserialize_claims(&claims), &validation); + assert!(res.is_ok()); + } + #[test] fn nbf_in_future_fails() { let claims = json!({ "nbf": get_current_timestamp() + 100000 });