diff --git a/benches/jwt.rs b/benches/jwt.rs index 640b53c..d2fee79 100644 --- a/benches/jwt.rs +++ b/benches/jwt.rs @@ -1,5 +1,5 @@ use criterion::{black_box, criterion_group, criterion_main, Criterion}; -use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, Validation}; +use jsonwebtoken::{decode, encode, Algorithm, DecodingKey, EncodingKey, Header, Validation}; use serde::{Deserialize, Serialize}; #[derive(Debug, PartialEq, Clone, Serialize, Deserialize)] @@ -23,7 +23,11 @@ fn bench_decode(c: &mut Criterion) { c.bench_function("bench_decode", |b| { b.iter(|| { - decode::(black_box(token), black_box(&key), black_box(&Validation::default())) + decode::( + black_box(token), + black_box(&key), + black_box(&Validation::new(Algorithm::HS256)), + ) }) }); } diff --git a/src/decoding.rs b/src/decoding.rs index 90645b3..0f77931 100644 --- a/src/decoding.rs +++ b/src/decoding.rs @@ -5,7 +5,7 @@ use crate::crypto::verify; use crate::errors::{new_error, ErrorKind, Result}; use crate::header::Header; use crate::pem::decoder::PemEncodedKey; -use crate::serialization::{b64_decode, from_jwt_part_claims}; +use crate::serialization::{b64_decode, DecodedJwtPartClaims}; use crate::validation::{validate, Validation}; /// The return type of a successful call to [decode](fn.decode.html). @@ -198,10 +198,11 @@ pub fn decode( match verify_signature(token, key, validation) { Err(e) => Err(e), Ok((header, claims)) => { - let (decoded_claims, claims_map): (T, _) = from_jwt_part_claims(claims)?; - validate(&claims_map, validation)?; + let decoded_claims = DecodedJwtPartClaims::from_jwt_part_claims(claims)?; + let claims = decoded_claims.deserialize()?; + validate(decoded_claims.deserialize()?, validation)?; - Ok(TokenData { header, claims: decoded_claims }) + Ok(TokenData { header, claims }) } } } diff --git a/src/serialization.rs b/src/serialization.rs index 5d09312..cf219dc 100644 --- a/src/serialization.rs +++ b/src/serialization.rs @@ -1,7 +1,4 @@ -use serde::de::DeserializeOwned; -use serde::ser::Serialize; -use serde_json::map::Map; -use serde_json::{from_slice, to_vec, Value}; +use serde::{Deserialize, Serialize}; use crate::errors::Result; @@ -15,18 +12,23 @@ pub(crate) fn b64_decode>(input: T) -> Result> { /// Serializes a struct to JSON and encodes it in base64 pub(crate) fn b64_encode_part(input: &T) -> Result { - let json = to_vec(input)?; + let json = serde_json::to_vec(input)?; Ok(b64_encode(json)) } -/// Decodes from base64 and deserializes from JSON to a struct AND a hashmap of Value so we can -/// run validation on it -pub(crate) fn from_jwt_part_claims, T: DeserializeOwned>( - encoded: B, -) -> Result<(T, Map)> { - let s = b64_decode(encoded)?; - - let claims: T = from_slice(&s)?; - let validation_map: Map<_, _> = from_slice(&s)?; - Ok((claims, validation_map)) +/// This is used to decode from base64 then deserialize from JSON to several structs: +/// - The user-provided struct +/// - The ClaimsForValidation struct from this crate to run validation on +pub(crate) struct DecodedJwtPartClaims { + b64_decoded: Vec, +} + +impl DecodedJwtPartClaims { + pub fn from_jwt_part_claims(encoded_jwt_part_claims: impl AsRef<[u8]>) -> Result { + Ok(Self { b64_decoded: b64_decode(encoded_jwt_part_claims)? }) + } + + pub fn deserialize<'a, T: Deserialize<'a>>(&'a self) -> Result { + Ok(serde_json::from_slice(&self.b64_decoded)?) + } } diff --git a/src/validation.rs b/src/validation.rs index aa8d3c3..9714250 100644 --- a/src/validation.rs +++ b/src/validation.rs @@ -1,8 +1,8 @@ +use std::borrow::Cow; use std::collections::HashSet; use std::time::{SystemTime, UNIX_EPOCH}; -use serde_json::map::Map; -use serde_json::Value; +use serde::Deserialize; use crate::algorithms::Algorithm; use crate::errors::{new_error, ErrorKind, Result}; @@ -109,76 +109,90 @@ pub fn get_current_timestamp() -> u64 { start.duration_since(UNIX_EPOCH).expect("Time went backwards").as_secs() } -pub fn validate(claims: &Map, options: &Validation) -> Result<()> { +#[derive(Deserialize)] +pub(crate) struct ClaimsForValidation<'a> { + exp: TryParse, + nbf: TryParse, + #[serde(borrow)] + sub: TryParse>, + #[serde(borrow)] + iss: TryParse>, + #[serde(borrow)] + aud: TryParse>, +} +enum TryParse { + Parsed(T), + FailedToParse, + NotPresent, +} +impl<'de, T: Deserialize<'de>> Deserialize<'de> for TryParse { + fn deserialize>( + deserializer: D, + ) -> std::result::Result { + Ok(match Option::::deserialize(deserializer) { + Ok(Some(value)) => TryParse::Parsed(value), + Ok(None) => TryParse::NotPresent, + Err(_) => TryParse::FailedToParse, + }) + } +} +#[derive(Deserialize)] +#[serde(untagged)] +enum Audience<'a> { + Single(#[serde(borrow)] Cow<'a, str>), + Multiple(#[serde(borrow)] HashSet>), +} +/// Usually #[serde(borrow)] on `Cow` enables deserializing with no allocations where +/// possible (no escapes in the original str) but it does not work on e.g. `HashSet>` +/// We use this struct in this case. +#[derive(Deserialize, PartialEq, Eq, Hash)] +struct BorrowedCowIfPossible<'a>(#[serde(borrow)] Cow<'a, str>); +impl std::borrow::Borrow for BorrowedCowIfPossible<'_> { + fn borrow(&self) -> &str { + &*self.0 + } +} + +pub(crate) fn validate(claims: ClaimsForValidation, options: &Validation) -> Result<()> { let now = get_current_timestamp(); - if options.validate_exp { - if let Some(exp) = claims.get("exp") { - if let Some(exp) = exp.as_u64() { - if exp < now - options.leeway { - return Err(new_error(ErrorKind::ExpiredSignature)); - } - } else { - return Err(new_error(ErrorKind::ExpiredSignature)); - } - } else { - return Err(new_error(ErrorKind::ExpiredSignature)); - } + if options.validate_exp + && !matches!(claims.exp, TryParse::Parsed(exp) if exp >= now-options.leeway) + { + return Err(new_error(ErrorKind::ExpiredSignature)); } - if options.validate_nbf { - if let Some(nbf) = claims.get("nbf") { - if let Some(nbf) = nbf.as_u64() { - if nbf > now + options.leeway { - return Err(new_error(ErrorKind::ImmatureSignature)); - } - } else { - return Err(new_error(ErrorKind::ImmatureSignature)); - } - } else { - return Err(new_error(ErrorKind::ImmatureSignature)); - } + if options.validate_nbf + && !matches!(claims.nbf, TryParse::Parsed(nbf) if nbf <= now + options.leeway) + { + return Err(new_error(ErrorKind::ImmatureSignature)); } - if let Some(ref correct_sub) = options.sub { - if let Some(Value::String(sub)) = claims.get("sub") { - if sub != correct_sub { - return Err(new_error(ErrorKind::InvalidSubject)); - } - } else { + if let Some(correct_sub) = options.sub.as_deref() { + if !matches!(claims.sub, TryParse::Parsed(sub) if sub == correct_sub) { return Err(new_error(ErrorKind::InvalidSubject)); } } if let Some(ref correct_iss) = options.iss { - if let Some(Value::String(iss)) = claims.get("iss") { - if !correct_iss.contains(iss) { - return Err(new_error(ErrorKind::InvalidIssuer)); - } - } else { + if !matches!(claims.iss, TryParse::Parsed(iss) if correct_iss.contains(&*iss)) { return Err(new_error(ErrorKind::InvalidIssuer)); } } if let Some(ref correct_aud) = options.aud { - if let Some(aud) = claims.get("aud") { - match aud { - Value::String(aud) => { - if !correct_aud.contains(aud) { - return Err(new_error(ErrorKind::InvalidAudience)); + match claims.aud { + TryParse::Parsed(Audience::Single(aud)) if correct_aud.contains(&*aud) => {} + TryParse::Parsed(Audience::Multiple(aud)) + if { + // Check that intersection is non-empty, favoring iterating on smallest + if correct_aud.len() < aud.len() { + correct_aud.iter().any(|a| aud.contains(&**a)) + } else { + aud.iter().any(|a| correct_aud.contains(&*a.0)) } - } - Value::Array(_) => { - use serde::Deserialize; - let aud = HashSet::::deserialize(aud)?; - if aud.intersection(correct_aud).next().is_none() { - return Err(new_error(ErrorKind::InvalidAudience)); - } - } - _ => return Err(new_error(ErrorKind::InvalidAudience)), - }; - } else { - return Err(new_error(ErrorKind::InvalidAudience)); + } => {} + _ => return Err(new_error(ErrorKind::InvalidAudience)), } } @@ -187,27 +201,28 @@ pub fn validate(claims: &Map, options: &Validation) -> Result<()> #[cfg(test)] mod tests { - use serde_json::map::Map; - use serde_json::to_value; + use serde_json::json; - use super::{get_current_timestamp, validate, Validation}; + use super::{get_current_timestamp, validate, ClaimsForValidation, Validation}; use crate::errors::ErrorKind; use crate::Algorithm; + fn deserialize_claims(claims: &serde_json::Value) -> ClaimsForValidation { + serde::Deserialize::deserialize(claims).unwrap() + } + #[test] fn exp_in_future_ok() { - let mut claims = Map::new(); - claims.insert("exp".to_string(), to_value(get_current_timestamp() + 10000).unwrap()); - let res = validate(&claims, &Validation::new(Algorithm::HS256)); + let claims = json!({ "exp": get_current_timestamp() + 10000 }); + let res = validate(deserialize_claims(&claims), &Validation::new(Algorithm::HS256)); assert!(res.is_ok()); } #[test] fn exp_in_past_fails() { - let mut claims = Map::new(); - claims.insert("exp".to_string(), to_value(get_current_timestamp() - 100000).unwrap()); - let res = validate(&claims, &Validation::new(Algorithm::HS256)); + let claims = json!({ "exp": get_current_timestamp() - 100000 }); + let res = validate(deserialize_claims(&claims), &Validation::new(Algorithm::HS256)); assert!(res.is_err()); match res.unwrap_err().kind() { @@ -218,19 +233,18 @@ mod tests { #[test] fn exp_in_past_but_in_leeway_ok() { - let mut claims = Map::new(); - claims.insert("exp".to_string(), to_value(get_current_timestamp() - 500).unwrap()); + let claims = json!({ "exp": get_current_timestamp() - 500 }); let mut validation = Validation::new(Algorithm::HS256); validation.leeway = 1000 * 60; - let res = validate(&claims, &validation); + let res = validate(deserialize_claims(&claims), &validation); assert!(res.is_ok()); } // https://github.com/Keats/jsonwebtoken/issues/51 #[test] fn validation_called_even_if_field_is_empty() { - let claims = Map::new(); - let res = validate(&claims, &Validation::new(Algorithm::HS256)); + let claims = json!({}); + let res = validate(deserialize_claims(&claims), &Validation::new(Algorithm::HS256)); assert!(res.is_err()); match res.unwrap_err().kind() { ErrorKind::ExpiredSignature => (), @@ -240,23 +254,21 @@ mod tests { #[test] fn nbf_in_past_ok() { - let mut claims = Map::new(); - claims.insert("nbf".to_string(), to_value(get_current_timestamp() - 10000).unwrap()); + let claims = json!({ "nbf": get_current_timestamp() - 10000 }); let mut validation = Validation::new(Algorithm::HS256); validation.validate_exp = false; validation.validate_nbf = true; - let res = validate(&claims, &validation); + let res = validate(deserialize_claims(&claims), &validation); assert!(res.is_ok()); } #[test] fn nbf_in_future_fails() { - let mut claims = Map::new(); - claims.insert("nbf".to_string(), to_value(get_current_timestamp() + 100000).unwrap()); + let claims = json!({ "nbf": get_current_timestamp() + 100000 }); let mut validation = Validation::new(Algorithm::HS256); validation.validate_exp = false; validation.validate_nbf = true; - let res = validate(&claims, &validation); + let res = validate(deserialize_claims(&claims), &validation); assert!(res.is_err()); match res.unwrap_err().kind() { @@ -267,37 +279,35 @@ mod tests { #[test] fn nbf_in_future_but_in_leeway_ok() { - let mut claims = Map::new(); - claims.insert("nbf".to_string(), to_value(get_current_timestamp() + 500).unwrap()); + let claims = json!({ "nbf": get_current_timestamp() + 500 }); let mut validation = Validation::new(Algorithm::HS256); validation.validate_exp = false; validation.validate_nbf = true; validation.leeway = 1000 * 60; - let res = validate(&claims, &validation); + let res = validate(deserialize_claims(&claims), &validation); assert!(res.is_ok()); } #[test] fn iss_ok() { - let mut claims = Map::new(); - claims.insert("iss".to_string(), to_value("Keats").unwrap()); + let claims = json!({"iss": "Keats"}); let mut validation = Validation::new(Algorithm::HS256); validation.validate_exp = false; validation.set_iss(&["Keats"]); - let res = validate(&claims, &validation); + + let res = validate(deserialize_claims(&claims), &validation); assert!(res.is_ok()); } #[test] fn iss_not_matching_fails() { - let mut claims = Map::new(); - claims.insert("iss".to_string(), to_value("Hacked").unwrap()); + let claims = json!({"iss": "Hacked"}); let mut validation = Validation::new(Algorithm::HS256); validation.validate_exp = false; validation.set_iss(&["Keats"]); - let res = validate(&claims, &validation); + let res = validate(deserialize_claims(&claims), &validation); assert!(res.is_err()); match res.unwrap_err().kind() { @@ -308,13 +318,12 @@ mod tests { #[test] fn iss_missing_fails() { - let claims = Map::new(); + let claims = json!({}); let mut validation = Validation::new(Algorithm::HS256); validation.validate_exp = false; validation.set_iss(&["Keats"]); - let res = validate(&claims, &validation); - assert!(res.is_err()); + let res = validate(deserialize_claims(&claims), &validation); match res.unwrap_err().kind() { ErrorKind::InvalidIssuer => (), @@ -324,25 +333,21 @@ mod tests { #[test] fn sub_ok() { - let mut claims = Map::new(); - claims.insert("sub".to_string(), to_value("Keats").unwrap()); + let claims = json!({"sub": "Keats"}); let mut validation = Validation::new(Algorithm::HS256); validation.validate_exp = false; validation.sub = Some("Keats".to_owned()); - - let res = validate(&claims, &validation); + let res = validate(deserialize_claims(&claims), &validation); assert!(res.is_ok()); } #[test] fn sub_not_matching_fails() { - let mut claims = Map::new(); - claims.insert("sub".to_string(), to_value("Hacked").unwrap()); + let claims = json!({"sub": "Hacked"}); let mut validation = Validation::new(Algorithm::HS256); validation.validate_exp = false; validation.sub = Some("Keats".to_owned()); - - let res = validate(&claims, &validation); + let res = validate(deserialize_claims(&claims), &validation); assert!(res.is_err()); match res.unwrap_err().kind() { @@ -353,12 +358,11 @@ mod tests { #[test] fn sub_missing_fails() { - let claims = Map::new(); + let claims = json!({}); let mut validation = Validation::new(Algorithm::HS256); validation.validate_exp = false; validation.sub = Some("Keats".to_owned()); - - let res = validate(&claims, &validation); + let res = validate(deserialize_claims(&claims), &validation); assert!(res.is_err()); match res.unwrap_err().kind() { @@ -369,34 +373,31 @@ mod tests { #[test] fn aud_string_ok() { - let mut claims = Map::new(); - claims.insert("aud".to_string(), to_value(["Everyone"]).unwrap()); + let claims = json!({"aud": ["Everyone"]}); let mut validation = Validation::new(Algorithm::HS256); validation.validate_exp = false; validation.set_audience(&["Everyone"]); - let res = validate(&claims, &validation); + let res = validate(deserialize_claims(&claims), &validation); assert!(res.is_ok()); } #[test] fn aud_array_of_string_ok() { - let mut claims = Map::new(); - claims.insert("aud".to_string(), to_value(["UserA", "UserB"]).unwrap()); + let claims = json!({"aud": ["UserA", "UserB"]}); let mut validation = Validation::new(Algorithm::HS256); validation.validate_exp = false; validation.set_audience(&["UserA", "UserB"]); - let res = validate(&claims, &validation); + let res = validate(deserialize_claims(&claims), &validation); assert!(res.is_ok()); } #[test] fn aud_type_mismatch_fails() { - let mut claims = Map::new(); - claims.insert("aud".to_string(), to_value(["Everyone"]).unwrap()); + let claims = json!({"aud": ["Everyone"]}); let mut validation = Validation::new(Algorithm::HS256); validation.validate_exp = false; validation.set_audience(&["UserA", "UserB"]); - let res = validate(&claims, &validation); + let res = validate(deserialize_claims(&claims), &validation); assert!(res.is_err()); match res.unwrap_err().kind() { @@ -407,12 +408,11 @@ mod tests { #[test] fn aud_correct_type_not_matching_fails() { - let mut claims = Map::new(); - claims.insert("aud".to_string(), to_value(["Everyone"]).unwrap()); + let claims = json!({"aud": ["Everyone"]}); let mut validation = Validation::new(Algorithm::HS256); validation.validate_exp = false; validation.set_audience(&["None"]); - let res = validate(&claims, &validation); + let res = validate(deserialize_claims(&claims), &validation); assert!(res.is_err()); match res.unwrap_err().kind() { @@ -423,11 +423,11 @@ mod tests { #[test] fn aud_missing_fails() { - let claims = Map::new(); + let claims = json!({}); let mut validation = Validation::new(Algorithm::HS256); validation.validate_exp = false; validation.set_audience(&["None"]); - let res = validate(&claims, &validation); + let res = validate(deserialize_claims(&claims), &validation); assert!(res.is_err()); match res.unwrap_err().kind() { @@ -439,16 +439,14 @@ mod tests { // https://github.com/Keats/jsonwebtoken/issues/51 #[test] fn does_validation_in_right_order() { - let mut claims = Map::new(); - claims.insert("exp".to_string(), to_value(get_current_timestamp() + 10000).unwrap()); + let claims = json!({ "exp": get_current_timestamp() + 10000 }); let mut validation = Validation::new(Algorithm::HS256); validation.leeway = 5; validation.set_iss(&["iss no check"]); validation.set_audience(&["iss no check"]); - let res = validate(&claims, &validation); - + let res = validate(deserialize_claims(&claims), &validation); // It errors because it needs to validate iss/sub which are missing assert!(res.is_err()); match res.unwrap_err().kind() { @@ -460,11 +458,7 @@ mod tests { // https://github.com/Keats/jsonwebtoken/issues/110 #[test] fn aud_use_validation_struct() { - let mut claims = Map::new(); - claims.insert( - "aud".to_string(), - to_value("my-googleclientid1234.apps.googleusercontent.com").unwrap(), - ); + let claims = json!({"aud": "my-googleclientid1234.apps.googleusercontent.com"}); let aud = "my-googleclientid1234.apps.googleusercontent.com".to_string(); let mut aud_hashset = std::collections::HashSet::new(); @@ -473,7 +467,7 @@ mod tests { validation.validate_exp = false; validation.set_audience(&["my-googleclientid1234.apps.googleusercontent.com"]); - let res = validate(&claims, &validation); + let res = validate(deserialize_claims(&claims), &validation); println!("{:?}", res); assert!(res.is_ok()); }