From 2cc95b9f371915a6bf2f9503ebc3ae3ca2ff9a3c Mon Sep 17 00:00:00 2001 From: emeryc Date: Mon, 1 Nov 2021 06:25:59 -0700 Subject: [PATCH] Numeric type (#214) * exp & nbf as float In order to properly align with JWT NumericType wire protocol allow for type on wire to either be u64 or f64. In either case we convert in the most lossless way possible to a u64, so that nobody needs to know that the spec is overly permissive. * minimal cleanup --- src/validation.rs | 81 ++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 80 insertions(+), 1 deletion(-) diff --git a/src/validation.rs b/src/validation.rs index 14a0726..b189dc4 100644 --- a/src/validation.rs +++ b/src/validation.rs @@ -1,8 +1,11 @@ use std::borrow::Cow; use std::collections::HashSet; +use std::fmt; +use std::marker::PhantomData; use std::time::{SystemTime, UNIX_EPOCH}; -use serde::Deserialize; +use serde::de::{self, Visitor}; +use serde::{Deserialize, Deserializer}; use crate::algorithms::Algorithm; use crate::errors::{new_error, ErrorKind, Result}; @@ -117,7 +120,9 @@ pub fn get_current_timestamp() -> u64 { #[derive(Deserialize)] pub(crate) struct ClaimsForValidation<'a> { + #[serde(deserialize_with = "numeric_type", default)] exp: TryParse, + #[serde(deserialize_with = "numeric_type", default)] nbf: TryParse, #[serde(borrow)] sub: TryParse>, @@ -126,6 +131,7 @@ pub(crate) struct ClaimsForValidation<'a> { #[serde(borrow)] aud: TryParse>, } +#[derive(Debug)] enum TryParse { Parsed(T), FailedToParse, @@ -142,6 +148,12 @@ impl<'de, T: Deserialize<'de>> Deserialize<'de> for TryParse { }) } } +impl Default for TryParse { + fn default() -> Self { + Self::NotPresent + } +} + #[derive(Deserialize)] #[serde(untagged)] enum Audience<'a> { @@ -205,6 +217,44 @@ pub(crate) fn validate(claims: ClaimsForValidation, options: &Validation) -> Res Ok(()) } +fn numeric_type<'de, D>(deserializer: D) -> std::result::Result, D::Error> +where + D: Deserializer<'de>, +{ + struct NumericType(PhantomData TryParse>); + + impl<'de> Visitor<'de> for NumericType { + type Value = TryParse; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("A NumericType that can be reasonably coerced into a u64") + } + + fn visit_f64(self, value: f64) -> std::result::Result + where + E: de::Error, + { + if value.is_finite() && value >= 0.0 && value < (u64::MAX as f64) { + Ok(TryParse::Parsed(value.round() as u64)) + } else { + Err(serde::de::Error::custom("NumericType must be representable as a u64")) + } + } + + fn visit_u64(self, value: u64) -> std::result::Result + where + E: de::Error, + { + Ok(TryParse::Parsed(value)) + } + } + + match deserializer.deserialize_any(NumericType(PhantomData)) { + Ok(ok) => Ok(ok), + Err(_) => Ok(TryParse::FailedToParse), + } +} + #[cfg(test)] mod tests { use serde_json::json; @@ -225,6 +275,13 @@ mod tests { assert!(res.is_ok()); } + #[test] + fn exp_float_in_future_ok() { + let claims = json!({ "exp": (get_current_timestamp() as f64) + 10000.123 }); + let res = validate(deserialize_claims(&claims), &Validation::new(Algorithm::HS256)); + assert!(res.is_ok()); + } + #[test] fn exp_in_past_fails() { let claims = json!({ "exp": get_current_timestamp() - 100000 }); @@ -237,6 +294,18 @@ mod tests { }; } + #[test] + fn exp_float_in_past_fails() { + let claims = json!({ "exp": (get_current_timestamp() as f64) - 100000.1234 }); + let res = validate(deserialize_claims(&claims), &Validation::new(Algorithm::HS256)); + assert!(res.is_err()); + + match res.unwrap_err().kind() { + ErrorKind::ExpiredSignature => (), + _ => unreachable!(), + }; + } + #[test] fn exp_in_past_but_in_leeway_ok() { let claims = json!({ "exp": get_current_timestamp() - 500 }); @@ -268,6 +337,16 @@ mod tests { assert!(res.is_ok()); } + #[test] + fn nbf_float_in_past_ok() { + let claims = json!({ "nbf": (get_current_timestamp() as f64) - 10000.1234 }); + let mut validation = Validation::new(Algorithm::HS256); + validation.validate_exp = false; + validation.validate_nbf = true; + let res = validate(deserialize_claims(&claims), &validation); + assert!(res.is_ok()); + } + #[test] fn nbf_in_future_fails() { let claims = json!({ "nbf": get_current_timestamp() + 100000 });