revised set_audience, cleaned up validation, and cleared compiler warnings

This commit is contained in:
dowwie 2019-10-28 11:49:02 -04:00
parent fe10accb6e
commit 68d6c84c8c
1 changed files with 18 additions and 43 deletions

View File

@ -1,11 +1,10 @@
use std::collections::HashSet; use std::collections::HashSet;
use chrono::Utc; use chrono::Utc;
use serde::ser::Serialize;
use serde_json::map::Map; use serde_json::map::Map;
use serde_json::{from_value, to_value, Value}; use serde_json::{from_value, Value};
use crate::algorithms::Algorithm; use crypto::Algorithm;
use crate::errors::{new_error, ErrorKind, Result}; use errors::{new_error, ErrorKind, Result};
/// Contains the various validations that are applied after decoding a token. /// Contains the various validations that are applied after decoding a token.
/// ///
@ -22,7 +21,7 @@ use crate::errors::{new_error, ErrorKind, Result};
/// ///
/// // Setting audience /// // Setting audience
/// let mut validation = Validation::default(); /// let mut validation = Validation::default();
/// validation.set_audience(&"Me"); // string /// validation.set_audience(&["Me"]); // a single string
/// validation.set_audience(&["Me", "You"]); // array of strings /// validation.set_audience(&["Me", "You"]); // array of strings
/// ``` /// ```
#[derive(Debug, Clone, PartialEq)] #[derive(Debug, Clone, PartialEq)]
@ -44,10 +43,8 @@ pub struct Validation {
/// ///
/// Defaults to `false`. /// Defaults to `false`.
pub validate_nbf: bool, pub validate_nbf: bool,
/// If it contains a value, the validation will check that the `aud` field is the same as the /// If it contains a value, the validation will check that the `aud` field is a member of the
/// one provided and will error otherwise. /// audience provided and will error otherwise.
/// Since `aud` can be either a String or a Vec<String> in the JWT spec, you will need to use
/// the [set_audience](struct.Validation.html#method.set_audience) method to set it.
/// ///
/// Defaults to `None`. /// Defaults to `None`.
pub aud: Option<HashSet<String>>, pub aud: Option<HashSet<String>>,
@ -76,31 +73,9 @@ impl Validation {
validation validation
} }
/// Since `aud` can be either a String or an array of String in the JWT spec, this method will take /// `aud` is a collection of one or more acceptable audience members
/// care of serializing the value. pub fn set_audience<T: ToString>(&mut self, items: &[T]) {
pub fn set_audience<T: Serialize>(&mut self, audience: &T) { self.aud = Some(items.iter().map(|x| x.to_string()).collect())
let aud = to_value(audience)
.unwrap_or_else(|_| panic!("Failed to_value within set_audience)"));
let aud = Validation::convert_aud(&aud)
.unwrap_or_else(|_| panic!("Failed convert_aud within set_audience"));
self.aud = Some(aud);
}
/// Converts a Value, representing a String or collection of Strings, to a
/// HashSet<String>, required for audience membership testing
fn convert_aud(aud: &Value) -> Result<HashSet<String>> {
let aud_from_claim: Vec<String> = match aud.is_array() {
true => from_value(aud.clone()).unwrap(),
false => {
let aud_str: String = match from_value(aud.clone()) {
Ok(val) => val,
Err(_) => return Err(new_error(ErrorKind::InvalidAudience)),
};
vec![aud_str]
}
};
Ok(aud_from_claim.into_iter().collect())
} }
} }
@ -166,8 +141,8 @@ pub fn validate(claims: &Map<String, Value>, options: &Validation) -> Result<()>
if let Some(ref correct_aud) = options.aud { if let Some(ref correct_aud) = options.aud {
if let Some(aud) = claims.get("aud") { if let Some(aud) = claims.get("aud") {
let converted_aud = Validation::convert_aud(aud)?; let provided_aud: HashSet<String> = from_value(aud.clone())?;
if converted_aud.intersection(correct_aud).count() == 0 { if provided_aud.intersection(correct_aud).count() == 0 {
return Err(new_error(ErrorKind::InvalidAudience)); return Err(new_error(ErrorKind::InvalidAudience));
} }
} else { } else {
@ -186,7 +161,7 @@ mod tests {
use super::{validate, Validation}; use super::{validate, Validation};
use crate::errors::ErrorKind; use errors::ErrorKind;
#[test] #[test]
fn exp_in_future_ok() { fn exp_in_future_ok() {
@ -368,9 +343,9 @@ mod tests {
#[test] #[test]
fn aud_string_ok() { fn aud_string_ok() {
let mut claims = Map::new(); let mut claims = Map::new();
claims.insert("aud".to_string(), to_value("Everyone").unwrap()); claims.insert("aud".to_string(), to_value(["Everyone"]).unwrap());
let mut validation = Validation { validate_exp: false, ..Validation::default() }; let mut validation = Validation { validate_exp: false, ..Validation::default() };
validation.set_audience(&"Everyone"); validation.set_audience(&["Everyone"]);
let res = validate(&claims, &validation); let res = validate(&claims, &validation);
assert!(res.is_ok()); assert!(res.is_ok());
} }
@ -388,7 +363,7 @@ mod tests {
#[test] #[test]
fn aud_type_mismatch_fails() { fn aud_type_mismatch_fails() {
let mut claims = Map::new(); let mut claims = Map::new();
claims.insert("aud".to_string(), to_value("Everyone").unwrap()); claims.insert("aud".to_string(), to_value(["Everyone"]).unwrap());
let mut validation = Validation { validate_exp: false, ..Validation::default() }; let mut validation = Validation { validate_exp: false, ..Validation::default() };
validation.set_audience(&["UserA", "UserB"]); validation.set_audience(&["UserA", "UserB"]);
let res = validate(&claims, &validation); let res = validate(&claims, &validation);
@ -403,9 +378,9 @@ mod tests {
#[test] #[test]
fn aud_correct_type_not_matching_fails() { fn aud_correct_type_not_matching_fails() {
let mut claims = Map::new(); let mut claims = Map::new();
claims.insert("aud".to_string(), to_value("Everyone").unwrap()); claims.insert("aud".to_string(), to_value(["Everyone"]).unwrap());
let mut validation = Validation { validate_exp: false, ..Validation::default() }; let mut validation = Validation { validate_exp: false, ..Validation::default() };
validation.set_audience(&"None"); validation.set_audience(&["None"]);
let res = validate(&claims, &validation); let res = validate(&claims, &validation);
assert!(res.is_err()); assert!(res.is_err());
@ -419,7 +394,7 @@ mod tests {
fn aud_missing_fails() { fn aud_missing_fails() {
let claims = Map::new(); let claims = Map::new();
let mut validation = Validation { validate_exp: false, ..Validation::default() }; let mut validation = Validation { validate_exp: false, ..Validation::default() };
validation.set_audience(&"None"); validation.set_audience(&["None"]);
let res = validate(&claims, &validation); let res = validate(&claims, &validation);
assert!(res.is_err()); assert!(res.is_err());