Numeric type (#214)

* exp & nbf as float

In order to properly align with JWT NumericType wire protocol
allow for type on wire to either be u64 or f64. In either case we
convert in the most lossless way possible to a u64, so that nobody
needs to know that the spec is overly permissive.

* minimal cleanup
This commit is contained in:
emeryc 2021-11-01 06:25:59 -07:00 committed by Vincent Prouillet
parent 733d29aa87
commit 2cc95b9f37
1 changed files with 80 additions and 1 deletions

View File

@ -1,8 +1,11 @@
use std::borrow::Cow;
use std::collections::HashSet;
use std::fmt;
use std::marker::PhantomData;
use std::time::{SystemTime, UNIX_EPOCH};
use serde::Deserialize;
use serde::de::{self, Visitor};
use serde::{Deserialize, Deserializer};
use crate::algorithms::Algorithm;
use crate::errors::{new_error, ErrorKind, Result};
@ -117,7 +120,9 @@ pub fn get_current_timestamp() -> u64 {
#[derive(Deserialize)]
pub(crate) struct ClaimsForValidation<'a> {
#[serde(deserialize_with = "numeric_type", default)]
exp: TryParse<u64>,
#[serde(deserialize_with = "numeric_type", default)]
nbf: TryParse<u64>,
#[serde(borrow)]
sub: TryParse<Cow<'a, str>>,
@ -126,6 +131,7 @@ pub(crate) struct ClaimsForValidation<'a> {
#[serde(borrow)]
aud: TryParse<Audience<'a>>,
}
#[derive(Debug)]
enum TryParse<T> {
Parsed(T),
FailedToParse,
@ -142,6 +148,12 @@ impl<'de, T: Deserialize<'de>> Deserialize<'de> for TryParse<T> {
})
}
}
impl<T> Default for TryParse<T> {
fn default() -> Self {
Self::NotPresent
}
}
#[derive(Deserialize)]
#[serde(untagged)]
enum Audience<'a> {
@ -205,6 +217,44 @@ pub(crate) fn validate(claims: ClaimsForValidation, options: &Validation) -> Res
Ok(())
}
fn numeric_type<'de, D>(deserializer: D) -> std::result::Result<TryParse<u64>, D::Error>
where
D: Deserializer<'de>,
{
struct NumericType(PhantomData<fn() -> TryParse<u64>>);
impl<'de> Visitor<'de> for NumericType {
type Value = TryParse<u64>;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("A NumericType that can be reasonably coerced into a u64")
}
fn visit_f64<E>(self, value: f64) -> std::result::Result<Self::Value, E>
where
E: de::Error,
{
if value.is_finite() && value >= 0.0 && value < (u64::MAX as f64) {
Ok(TryParse::Parsed(value.round() as u64))
} else {
Err(serde::de::Error::custom("NumericType must be representable as a u64"))
}
}
fn visit_u64<E>(self, value: u64) -> std::result::Result<Self::Value, E>
where
E: de::Error,
{
Ok(TryParse::Parsed(value))
}
}
match deserializer.deserialize_any(NumericType(PhantomData)) {
Ok(ok) => Ok(ok),
Err(_) => Ok(TryParse::FailedToParse),
}
}
#[cfg(test)]
mod tests {
use serde_json::json;
@ -225,6 +275,13 @@ mod tests {
assert!(res.is_ok());
}
#[test]
fn exp_float_in_future_ok() {
let claims = json!({ "exp": (get_current_timestamp() as f64) + 10000.123 });
let res = validate(deserialize_claims(&claims), &Validation::new(Algorithm::HS256));
assert!(res.is_ok());
}
#[test]
fn exp_in_past_fails() {
let claims = json!({ "exp": get_current_timestamp() - 100000 });
@ -237,6 +294,18 @@ mod tests {
};
}
#[test]
fn exp_float_in_past_fails() {
let claims = json!({ "exp": (get_current_timestamp() as f64) - 100000.1234 });
let res = validate(deserialize_claims(&claims), &Validation::new(Algorithm::HS256));
assert!(res.is_err());
match res.unwrap_err().kind() {
ErrorKind::ExpiredSignature => (),
_ => unreachable!(),
};
}
#[test]
fn exp_in_past_but_in_leeway_ok() {
let claims = json!({ "exp": get_current_timestamp() - 500 });
@ -268,6 +337,16 @@ mod tests {
assert!(res.is_ok());
}
#[test]
fn nbf_float_in_past_ok() {
let claims = json!({ "nbf": (get_current_timestamp() as f64) - 10000.1234 });
let mut validation = Validation::new(Algorithm::HS256);
validation.validate_exp = false;
validation.validate_nbf = true;
let res = validate(deserialize_claims(&claims), &validation);
assert!(res.is_ok());
}
#[test]
fn nbf_in_future_fails() {
let claims = json!({ "nbf": get_current_timestamp() + 100000 });