From 410499e6b6c6559039d62229d1827798e53e2bf6 Mon Sep 17 00:00:00 2001 From: Vincent Prouillet Date: Tue, 11 Apr 2017 14:41:44 +0900 Subject: [PATCH] Add validation --- Cargo.toml | 3 +- benches/jwt.rs | 4 +- examples/claims.rs | 4 +- examples/custom_header.rs | 4 +- src/crypto.rs | 35 ++--- src/errors.rs | 7 + src/lib.rs | 5 +- src/serialization.rs | 20 ++- src/validation.rs | 313 ++++++++++++++++++++++++++++++++++++++ tests/lib.rs | 18 +-- tests/rsa.rs | 4 +- 11 files changed, 368 insertions(+), 49 deletions(-) create mode 100644 src/validation.rs diff --git a/Cargo.toml b/Cargo.toml index 77b6376..28f1517 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,10 +11,11 @@ keywords = ["jwt", "web", "api", "token", "json"] [dependencies] rustc-serialize = "^0.3" -error-chain = "0.9" +error-chain = "0.10" serde_json = "0.9" serde_derive = "0.9" serde = "0.9" ring = { version = "0.7", features = ["rsa_signing", "dev_urandom_fallback"] } base64 = "0.4" untrusted = "0.3" +chrono = "0.3" diff --git a/benches/jwt.rs b/benches/jwt.rs index 0744efe..bafe8de 100644 --- a/benches/jwt.rs +++ b/benches/jwt.rs @@ -4,7 +4,7 @@ extern crate jsonwebtoken as jwt; #[macro_use] extern crate serde_derive; -use jwt::{encode, decode, Algorithm, Header}; +use jwt::{encode, decode, Algorithm, Header, Validation}; #[derive(Debug, PartialEq, Clone, Serialize, Deserialize)] struct Claims { @@ -25,5 +25,5 @@ fn bench_encode(b: &mut test::Bencher) { #[bench] fn bench_decode(b: &mut test::Bencher) { let token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ"; - b.iter(|| decode::(token, "secret".as_ref(), Algorithm::HS256)); + b.iter(|| decode::(token, "secret".as_ref(), Algorithm::HS256, Validation::default())); } diff --git a/examples/claims.rs b/examples/claims.rs index d24c056..4f75626 100644 --- a/examples/claims.rs +++ b/examples/claims.rs @@ -1,7 +1,7 @@ extern crate jsonwebtoken as jwt; #[macro_use] extern crate serde_derive; -use jwt::{encode, decode, Header, Algorithm}; +use jwt::{encode, decode, Header, Algorithm, Validation}; use jwt::errors::{ErrorKind}; @@ -36,7 +36,7 @@ fn main() { println!("{:?}", token); - let token_data = match decode::(&token, key.as_ref(), Algorithm::HS256) { + let token_data = match decode::(&token, key.as_ref(), Algorithm::HS256, Validation::default()) { Ok(c) => c, Err(err) => match *err.kind() { ErrorKind::InvalidToken => panic!(), // Example on how to handle a specific error diff --git a/examples/custom_header.rs b/examples/custom_header.rs index c93bccc..62aad32 100644 --- a/examples/custom_header.rs +++ b/examples/custom_header.rs @@ -2,7 +2,7 @@ extern crate jsonwebtoken as jwt; #[macro_use] extern crate serde_derive; -use jwt::{encode, decode, Header, Algorithm}; +use jwt::{encode, decode, Header, Algorithm, Validation}; use jwt::errors::{ErrorKind}; @@ -28,7 +28,7 @@ fn main() { Err(_) => panic!() // in practice you would return the error }; - let token_data = match decode::(&token, key.as_ref(), Algorithm::HS512) { + let token_data = match decode::(&token, key.as_ref(), Algorithm::HS512, Validation::default()) { Ok(c) => c, Err(err) => match *err.kind() { ErrorKind::InvalidToken => panic!(), // Example on how to handle a specific error diff --git a/src/crypto.rs b/src/crypto.rs index f93e653..cd141ed 100644 --- a/src/crypto.rs +++ b/src/crypto.rs @@ -10,7 +10,8 @@ use untrusted; use errors::{Result, ErrorKind}; use header::Header; -use serialization::{from_jwt_part, to_jwt_part, TokenData}; +use serialization::{from_jwt_part, to_jwt_part, from_jwt_part_claims, TokenData}; +use validation::{Validation, validate}; /// The algorithms supported for signing/verifying @@ -112,7 +113,6 @@ pub fn verify(signature: &str, signing_input: &str, key: &[u8], algorithm: Algor message, expected_signature, ); - println!("{:?}", res); Ok(res.is_ok()) }, @@ -131,14 +131,14 @@ macro_rules! expect_two { }} } -/// Decode fn used internally by `decode` and `decode_without_verifying` -fn internal_decode(token: &str, key: &[u8], algorithm: Algorithm, do_verification: bool) -> Result> { +/// Decode a token into a struct containing Claims and Header +/// +/// If the token or its signature is invalid, it will return an error +pub fn decode(token: &str, key: &[u8], algorithm: Algorithm, validation: Validation) -> Result> { let (signature, signing_input) = expect_two!(token.rsplitn(2, '.')); - if do_verification { - if !verify(signature, signing_input, key, algorithm)? { - return Err(ErrorKind::InvalidSignature.into()); - } + if validation.validate_signature && !verify(signature, signing_input, key, algorithm)? { + return Err(ErrorKind::InvalidSignature.into()); } let (claims, header) = expect_two!(signing_input.rsplitn(2, '.')); @@ -147,22 +147,9 @@ fn internal_decode(token: &str, key: &[u8], algorithm: Algorithm if header.alg != algorithm { return Err(ErrorKind::WrongAlgorithmHeader.into()); } - let decoded_claims: T = from_jwt_part(claims)?; + let (decoded_claims, claims_map): (T, _) = from_jwt_part_claims(claims)?; + + validate(&claims_map, &validation)?; Ok(TokenData { header: header, claims: decoded_claims }) } - -/// Decode a token into a struct containing Claims and Header -/// -/// If the token or its signature is invalid, it will return an error -pub fn decode(token: &str, key: &[u8], algorithm: Algorithm) -> Result> { - internal_decode(token, key, algorithm, true) -} - -/// Decode a token into a struct containing Claims and Header -/// WARNING: this will not do any verification so only use that at your own risk -/// -/// If the token is invalid, it will return an error -pub fn decode_without_verification(token: &str, key: &[u8], algorithm: Algorithm) -> Result> { - internal_decode(token, key, algorithm, false) -} diff --git a/src/errors.rs b/src/errors.rs index f7e27e4..c243760 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -25,6 +25,8 @@ error_chain! { display("Invalid Key") } + // Validation error + /// When a token’s `exp` claim indicates that it has expired ExpiredSignature { description("expired signature") @@ -40,6 +42,11 @@ error_chain! { description("invalid audience") display("Invalid Audience") } + /// When a token’s `aud` claim does not match one of the expected audience values + InvalidSubject { + description("invalid subject") + display("Invalid Subject") + } /// When a token’s `iat` claim is in the future InvalidIssuedAt { description("invalid issued at") diff --git a/src/lib.rs b/src/lib.rs index bbbc1cb..3544903 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -12,11 +12,13 @@ extern crate serde; extern crate base64; extern crate ring; extern crate untrusted; +extern crate chrono; pub mod errors; mod header; mod crypto; mod serialization; +mod validation; pub use header::{Header}; pub use crypto::{ @@ -25,6 +27,5 @@ pub use crypto::{ verify, encode, decode, - decode_without_verification, }; - +pub use validation::Validation; diff --git a/src/serialization.rs b/src/serialization.rs index be4d736..41df835 100644 --- a/src/serialization.rs +++ b/src/serialization.rs @@ -1,8 +1,8 @@ use base64; use serde::de::Deserialize; use serde::ser::Serialize; -use serde_json; - +use serde_json::{from_str, to_string, Value}; +use serde_json::map::Map; use errors::{Result}; use header::Header; @@ -17,14 +17,24 @@ pub struct TokenData { /// Serializes to JSON and encodes to base64 pub fn to_jwt_part(input: &T) -> Result { - let encoded = serde_json::to_string(input)?; + let encoded = to_string(input)?; Ok(base64::encode_config(encoded.as_bytes(), base64::URL_SAFE_NO_PAD)) } -/// Decodes from base64 and deserializes from JSON +/// Decodes from base64 and deserializes from JSON to a struct pub fn from_jwt_part, T: Deserialize>(encoded: B) -> Result { let decoded = base64::decode_config(encoded.as_ref(), base64::URL_SAFE_NO_PAD)?; let s = String::from_utf8(decoded)?; - Ok(serde_json::from_str(&s)?) + Ok(from_str(&s)?) +} + +/// Decodes from base64 and deserializes from JSON to a struct AND a hashmap +pub fn from_jwt_part_claims, T: Deserialize>(encoded: B) -> Result<(T, Map)> { + let decoded = base64::decode_config(encoded.as_ref(), base64::URL_SAFE_NO_PAD)?; + let s = String::from_utf8(decoded)?; + + let claims: T = from_str(&s)?; + let map: Map<_,_> = from_str(&s)?; + Ok((claims, map)) } diff --git a/src/validation.rs b/src/validation.rs new file mode 100644 index 0000000..5a8a67c --- /dev/null +++ b/src/validation.rs @@ -0,0 +1,313 @@ +use chrono::UTC; +use serde::ser::Serialize; +use serde_json::{Value, from_value, to_value}; +use serde_json::map::Map; + +use errors::{Result, ErrorKind}; + + +#[derive(Debug, Clone, PartialEq)] +pub struct Validation { + pub leeway: i64, + pub validate_signature: bool, + pub validate_exp: bool, + pub validate_iat: bool, + pub validate_nbf: bool, + + pub aud: Option, + pub iss: Option, + pub sub: Option, +} + +impl Validation { + pub fn set_audience(&mut self, audience: &T) { + self.aud = Some(to_value(audience).unwrap()); + } +} + +impl Default for Validation { + fn default() -> Validation { + Validation { + leeway: 0, + + validate_signature: true, + + validate_exp: true, + validate_iat: true, + validate_nbf: true, + + iss: None, + sub: None, + aud: None, + } + } +} + + + +pub fn validate(claims: &Map, options: &Validation) -> Result<()> { + let now = UTC::now().timestamp(); + + if let Some(iat) = claims.get("iat") { + if options.validate_iat && from_value::(iat.clone())? > now + options.leeway { + return Err(ErrorKind::InvalidIssuedAt.into()); + } + } + + if let Some(exp) = claims.get("exp") { + if options.validate_exp && from_value::(exp.clone())? < now - options.leeway { + return Err(ErrorKind::ExpiredSignature.into()); + } + } + + if let Some(nbf) = claims.get("nbf") { + if options.validate_nbf && from_value::(nbf.clone())? > now + options.leeway { + return Err(ErrorKind::ImmatureSignature.into()); + } + } + + if let Some(iss) = claims.get("iss") { + if let Some(ref correct_iss) = options.iss { + if from_value::(iss.clone())? != *correct_iss { + return Err(ErrorKind::InvalidIssuer.into()); + } + } + } + + if let Some(sub) = claims.get("sub") { + if let Some(ref correct_sub) = options.sub { + if from_value::(sub.clone())? != *correct_sub { + return Err(ErrorKind::InvalidSubject.into()); + } + } + } + + if let Some(aud) = claims.get("aud") { + if let Some(ref correct_aud) = options.aud { + if aud != correct_aud { + return Err(ErrorKind::InvalidAudience.into()); + } + } + } + + Ok(()) +} + + +#[cfg(test)] +mod tests { + use serde_json::{to_value}; + use serde_json::map::Map; + use chrono::UTC; + + use super::{validate, Validation}; + + use errors::ErrorKind; + + #[test] + fn iat_in_past_ok() { + let mut claims = Map::new(); + claims.insert("iat".to_string(), to_value(UTC::now().timestamp() - 10000).unwrap()); + let res = validate(&claims, &Validation::default()); + assert!(res.is_ok()); + } + + #[test] + fn iat_in_future_fails() { + let mut claims = Map::new(); + claims.insert("iat".to_string(), to_value(UTC::now().timestamp() + 100000).unwrap()); + let res = validate(&claims, &Validation::default()); + assert!(res.is_err()); + + match res.unwrap_err().kind() { + &ErrorKind::InvalidIssuedAt => (), + _ => assert!(false), + }; + } + + #[test] + fn iat_in_future_but_in_leeway_ok() { + let mut claims = Map::new(); + claims.insert("iat".to_string(), to_value(UTC::now().timestamp() + 50).unwrap()); + let validation = Validation { + leeway: 1000 * 60, + ..Default::default() + }; + let res = validate(&claims, &validation); + assert!(res.is_ok()); + } + + #[test] + fn exp_in_future_ok() { + let mut claims = Map::new(); + claims.insert("exp".to_string(), to_value(UTC::now().timestamp() + 10000).unwrap()); + let res = validate(&claims, &Validation::default()); + assert!(res.is_ok()); + } + + #[test] + fn exp_in_past_fails() { + let mut claims = Map::new(); + claims.insert("exp".to_string(), to_value(UTC::now().timestamp() - 100000).unwrap()); + let res = validate(&claims, &Validation::default()); + assert!(res.is_err()); + + match res.unwrap_err().kind() { + &ErrorKind::ExpiredSignature => (), + _ => assert!(false), + }; + } + + #[test] + fn exp_in_past_but_in_leeway_ok() { + let mut claims = Map::new(); + claims.insert("exp".to_string(), to_value(UTC::now().timestamp() - 500).unwrap()); + let validation = Validation { + leeway: 1000 * 60, + ..Default::default() + }; + let res = validate(&claims, &validation); + assert!(res.is_ok()); + } + + #[test] + fn nbf_in_past_ok() { + let mut claims = Map::new(); + claims.insert("nbf".to_string(), to_value(UTC::now().timestamp() - 10000).unwrap()); + let res = validate(&claims, &Validation::default()); + assert!(res.is_ok()); + } + + #[test] + fn nbf_in_future_fails() { + let mut claims = Map::new(); + claims.insert("nbf".to_string(), to_value(UTC::now().timestamp() + 100000).unwrap()); + let res = validate(&claims, &Validation::default()); + assert!(res.is_err()); + + match res.unwrap_err().kind() { + &ErrorKind::ImmatureSignature => (), + _ => assert!(false), + }; + } + + #[test] + fn nbf_in_future_but_in_leeway_ok() { + let mut claims = Map::new(); + claims.insert("nbf".to_string(), to_value(UTC::now().timestamp() + 500).unwrap()); + let validation = Validation { + leeway: 1000 * 60, + ..Default::default() + }; + let res = validate(&claims, &validation); + assert!(res.is_ok()); + } + + #[test] + fn iss_ok() { + let mut claims = Map::new(); + claims.insert("iss".to_string(), to_value("Keats").unwrap()); + let validation = Validation { + iss: Some("Keats".to_string()), + ..Default::default() + }; + let res = validate(&claims, &validation); + assert!(res.is_ok()); + } + + #[test] + fn iss_not_matching_fails() { + let mut claims = Map::new(); + claims.insert("iss".to_string(), to_value("Hacked").unwrap()); + let validation = Validation { + iss: Some("Keats".to_string()), + ..Default::default() + }; + let res = validate(&claims, &validation); + assert!(res.is_err()); + + match res.unwrap_err().kind() { + &ErrorKind::InvalidIssuer => (), + _ => assert!(false), + }; + } + + #[test] + fn sub_ok() { + let mut claims = Map::new(); + claims.insert("sub".to_string(), to_value("Keats").unwrap()); + let validation = Validation { + sub: Some("Keats".to_string()), + ..Default::default() + }; + let res = validate(&claims, &validation); + assert!(res.is_ok()); + } + + #[test] + fn sub_not_matching_fails() { + let mut claims = Map::new(); + claims.insert("sub".to_string(), to_value("Hacked").unwrap()); + let validation = Validation { + sub: Some("Keats".to_string()), + ..Default::default() + }; + let res = validate(&claims, &validation); + assert!(res.is_err()); + + match res.unwrap_err().kind() { + &ErrorKind::InvalidSubject => (), + _ => assert!(false), + }; + } + + #[test] + fn aud_string_ok() { + let mut claims = Map::new(); + claims.insert("aud".to_string(), to_value("Everyone").unwrap()); + let mut validation = Validation::default(); + validation.set_audience(&"Everyone"); + let res = validate(&claims, &validation); + assert!(res.is_ok()); + } + + #[test] + fn aud_array_of_string_ok() { + let mut claims = Map::new(); + claims.insert("aud".to_string(), to_value(["UserA", "UserB"]).unwrap()); + let mut validation = Validation::default(); + validation.set_audience(&["UserA", "UserB"]); + let res = validate(&claims, &validation); + assert!(res.is_ok()); + } + + #[test] + fn aud_type_mismatch_fails() { + let mut claims = Map::new(); + claims.insert("aud".to_string(), to_value("Everyone").unwrap()); + let mut validation = Validation::default(); + validation.set_audience(&["UserA", "UserB"]); + let res = validate(&claims, &validation); + assert!(res.is_err()); + + match res.unwrap_err().kind() { + &ErrorKind::InvalidAudience => (), + _ => assert!(false), + }; + } + + #[test] + fn aud_correct_type_not_matching_fails() { + let mut claims = Map::new(); + claims.insert("aud".to_string(), to_value("Everyone").unwrap()); + let mut validation = Validation::default(); + validation.set_audience(&"None"); + let res = validate(&claims, &validation); + assert!(res.is_err()); + + match res.unwrap_err().kind() { + &ErrorKind::InvalidAudience => (), + _ => assert!(false), + }; + } +} diff --git a/tests/lib.rs b/tests/lib.rs index abca7f1..bb84cf1 100644 --- a/tests/lib.rs +++ b/tests/lib.rs @@ -2,7 +2,7 @@ extern crate jsonwebtoken; #[macro_use] extern crate serde_derive; -use jsonwebtoken::{encode, decode, Algorithm, Header, sign, verify}; +use jsonwebtoken::{encode, decode, Algorithm, Header, sign, verify, Validation}; #[derive(Debug, PartialEq, Clone, Serialize, Deserialize)] @@ -34,7 +34,7 @@ fn encode_with_custom_header() { let mut header = Header::default(); header.kid = Some("kid".to_string()); let token = encode(&header, &my_claims, "secret".as_ref()).unwrap(); - let token_data = decode::(&token, "secret".as_ref(), Algorithm::HS256).unwrap(); + let token_data = decode::(&token, "secret".as_ref(), Algorithm::HS256, Validation::default()).unwrap(); assert_eq!(my_claims, token_data.claims); assert_eq!("kid", token_data.header.kid.unwrap()); } @@ -46,7 +46,7 @@ fn round_trip_claim() { company: "ACME".to_string() }; let token = encode(&Header::default(), &my_claims, "secret".as_ref()).unwrap(); - let token_data = decode::(&token, "secret".as_ref(), Algorithm::HS256).unwrap(); + let token_data = decode::(&token, "secret".as_ref(), Algorithm::HS256, Validation::default()).unwrap(); assert_eq!(my_claims, token_data.claims); assert!(token_data.header.kid.is_none()); } @@ -54,7 +54,7 @@ fn round_trip_claim() { #[test] fn decode_token() { let token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJiQGIuY29tIiwiY29tcGFueSI6IkFDTUUifQ.I1BvFoHe94AFf09O6tDbcSB8-jp8w6xZqmyHIwPeSdY"; - let claims = decode::(token, "secret".as_ref(), Algorithm::HS256); + let claims = decode::(token, "secret".as_ref(), Algorithm::HS256, Validation::default()); claims.unwrap(); } @@ -62,7 +62,7 @@ fn decode_token() { #[should_panic(expected = "InvalidToken")] fn decode_token_missing_parts() { let token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9"; - let claims = decode::(token, "secret".as_ref(), Algorithm::HS256); + let claims = decode::(token, "secret".as_ref(), Algorithm::HS256, Validation::default()); claims.unwrap(); } @@ -70,7 +70,7 @@ fn decode_token_missing_parts() { #[should_panic(expected = "InvalidSignature")] fn decode_token_invalid_signature() { let token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJiQGIuY29tIiwiY29tcGFueSI6IkFDTUUifQ.wrong"; - let claims = decode::(token, "secret".as_ref(), Algorithm::HS256); + let claims = decode::(token, "secret".as_ref(), Algorithm::HS256, Validation::default()); claims.unwrap(); } @@ -78,20 +78,20 @@ fn decode_token_invalid_signature() { #[should_panic(expected = "WrongAlgorithmHeader")] fn decode_token_wrong_algorithm() { let token = "eyJhbGciOiJIUzUxMiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJiQGIuY29tIiwiY29tcGFueSI6IkFDTUUifQ.pKscJVk7-aHxfmQKlaZxh5uhuKhGMAa-1F5IX5mfUwI"; - let claims = decode::(token, "secret".as_ref(), Algorithm::HS256); + let claims = decode::(token, "secret".as_ref(), Algorithm::HS256, Validation::default()); claims.unwrap(); } #[test] fn decode_token_with_bytes_secret() { let token = "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwiY29tcGFueSI6Ikdvb2dvbCJ9.27QxgG96vpX4akKNpD1YdRGHE3_u2X35wR3EHA2eCrs"; - let claims = decode::(token, b"\x01\x02\x03", Algorithm::HS256); + let claims = decode::(token, b"\x01\x02\x03", Algorithm::HS256, Validation::default()); assert!(claims.is_ok()); } #[test] fn decode_token_with_shuffled_header_fields() { let token = "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJjb21wYW55IjoiMTIzNDU2Nzg5MCIsInN1YiI6IkpvaG4gRG9lIn0.SEIZ4Jg46VGhquuwPYDLY5qHF8AkQczF14aXM3a2c28"; - let claims = decode::(token, "secret".as_ref(), Algorithm::HS256); + let claims = decode::(token, "secret".as_ref(), Algorithm::HS256, Validation::default()); assert!(claims.is_ok()); } diff --git a/tests/rsa.rs b/tests/rsa.rs index 6d678ce..8ed2904 100644 --- a/tests/rsa.rs +++ b/tests/rsa.rs @@ -2,7 +2,7 @@ extern crate jsonwebtoken; #[macro_use] extern crate serde_derive; -use jsonwebtoken::{encode, decode, Algorithm, Header, sign, verify}; +use jsonwebtoken::{encode, decode, Algorithm, Header, sign, verify, Validation}; #[derive(Debug, PartialEq, Clone, Serialize, Deserialize)] @@ -26,7 +26,7 @@ fn round_trip_claim() { company: "ACME".to_string() }; let token = encode(&Header::new(Algorithm::RS256), &my_claims, include_bytes!("private_rsa_key.der")).unwrap(); - let token_data = decode::(&token, include_bytes!("public_rsa_key.der"), Algorithm::RS256).unwrap(); + let token_data = decode::(&token, include_bytes!("public_rsa_key.der"), Algorithm::RS256, Validation::default()).unwrap(); assert_eq!(my_claims, token_data.claims); assert!(token_data.header.kid.is_none()); }