diff --git a/README.md b/README.md index 269c1a4..597d7f3 100644 --- a/README.md +++ b/README.md @@ -115,6 +115,9 @@ This library currently supports the following: - RS256 - RS384 - RS512 +- PS256 +- PS384 +- PS512 - ES256 - ES384 diff --git a/src/crypto.rs b/src/crypto.rs index 76dccd2..2ada40d 100644 --- a/src/crypto.rs +++ b/src/crypto.rs @@ -30,6 +30,13 @@ pub enum Algorithm { RS384, /// RSASSA-PKCS1-v1_5 using SHA-512 RS512, + + /// RSASSA-PSS using SHA-256 + PS256, + /// RSASSA-PSS using SHA-384 + PS384, + /// RSASSA-PSS using SHA-512 + PS512, } impl Default for Algorithm { @@ -50,6 +57,9 @@ impl FromStr for Algorithm { "RS256" => Ok(Algorithm::RS256), "RS384" => Ok(Algorithm::RS384), "RS512" => Ok(Algorithm::RS512), + "PS256" => Ok(Algorithm::PS256), + "PS384" => Ok(Algorithm::PS384), + "PS512" => Ok(Algorithm::PS512), _ => Err(new_error(ErrorKind::InvalidAlgorithmName)), } } @@ -64,7 +74,11 @@ fn sign_hmac(alg: &'static digest::Algorithm, key: &[u8], signing_input: &str) - } /// The actual ECDSA signing + encoding -fn sign_ecdsa(alg: &'static signature::EcdsaSigningAlgorithm, key: &[u8], signing_input: &str) -> Result { +fn sign_ecdsa( + alg: &'static signature::EcdsaSigningAlgorithm, + key: &[u8], + signing_input: &str, +) -> Result { let signing_key = signature::EcdsaKeyPair::from_pkcs8(alg, untrusted::Input::from(key))?; let rng = rand::SystemRandom::new(); let sig = signing_key.sign(&rng, untrusted::Input::from(signing_input.as_bytes()))?; @@ -73,7 +87,11 @@ fn sign_ecdsa(alg: &'static signature::EcdsaSigningAlgorithm, key: &[u8], signin /// The actual RSA signing + encoding /// Taken from Ring doc https://briansmith.org/rustdoc/ring/signature/index.html -fn sign_rsa(alg: &'static signature::RsaEncoding, key: &[u8], signing_input: &str) -> Result { +fn sign_rsa( + alg: &'static dyn signature::RsaEncoding, + key: &[u8], + signing_input: &str, +) -> Result { let key_pair = Arc::new( signature::RsaKeyPair::from_der(untrusted::Input::from(key)) .map_err(|_| ErrorKind::InvalidRsaKey)?, @@ -97,12 +115,20 @@ pub fn sign(signing_input: &str, key: &[u8], algorithm: Algorithm) -> Result sign_hmac(&digest::SHA384, key, signing_input), Algorithm::HS512 => sign_hmac(&digest::SHA512, key, signing_input), - Algorithm::ES256 => sign_ecdsa(&signature::ECDSA_P256_SHA256_FIXED_SIGNING, key, signing_input), - Algorithm::ES384 => sign_ecdsa(&signature::ECDSA_P384_SHA384_FIXED_SIGNING, key, signing_input), + Algorithm::ES256 => { + sign_ecdsa(&signature::ECDSA_P256_SHA256_FIXED_SIGNING, key, signing_input) + } + Algorithm::ES384 => { + sign_ecdsa(&signature::ECDSA_P384_SHA384_FIXED_SIGNING, key, signing_input) + } Algorithm::RS256 => sign_rsa(&signature::RSA_PKCS1_SHA256, key, signing_input), Algorithm::RS384 => sign_rsa(&signature::RSA_PKCS1_SHA384, key, signing_input), Algorithm::RS512 => sign_rsa(&signature::RSA_PKCS1_SHA512, key, signing_input), + + Algorithm::PS256 => sign_rsa(&signature::RSA_PSS_SHA256, key, signing_input), + Algorithm::PS384 => sign_rsa(&signature::RSA_PSS_SHA384, key, signing_input), + Algorithm::PS512 => sign_rsa(&signature::RSA_PSS_SHA512, key, signing_input), } } @@ -158,5 +184,14 @@ pub fn verify( Algorithm::RS512 => { verify_ring(&signature::RSA_PKCS1_2048_8192_SHA512, signature, signing_input, key) } + Algorithm::PS256 => { + verify_ring(&signature::RSA_PSS_2048_8192_SHA256, signature, signing_input, key) + } + Algorithm::PS384 => { + verify_ring(&signature::RSA_PSS_2048_8192_SHA384, signature, signing_input, key) + } + Algorithm::PS512 => { + verify_ring(&signature::RSA_PSS_2048_8192_SHA512, signature, signing_input, key) + } } } diff --git a/src/errors.rs b/src/errors.rs index 8111596..052e1ee 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -97,7 +97,7 @@ impl StdError for Error { } } - fn cause(&self) -> Option<&StdError> { + fn cause(&self) -> Option<&dyn StdError> { match *self.0 { ErrorKind::InvalidToken => None, ErrorKind::InvalidSignature => None, diff --git a/tests/lib.rs b/tests/lib.rs index 129ee3d..33f3cd1 100644 --- a/tests/lib.rs +++ b/tests/lib.rs @@ -166,5 +166,8 @@ fn generate_algorithm_enum_from_str() { assert!(Algorithm::from_str("RS256").is_ok()); assert!(Algorithm::from_str("RS384").is_ok()); assert!(Algorithm::from_str("RS512").is_ok()); + assert!(Algorithm::from_str("PS256").is_ok()); + assert!(Algorithm::from_str("PS384").is_ok()); + assert!(Algorithm::from_str("PS512").is_ok()); assert!(Algorithm::from_str("").is_err()); } diff --git a/tests/rsa.rs b/tests/rsa.rs index def1f9f..0ee80d2 100644 --- a/tests/rsa.rs +++ b/tests/rsa.rs @@ -6,6 +6,15 @@ extern crate chrono; use chrono::Utc; use jsonwebtoken::{decode, encode, sign, verify, Algorithm, Header, Validation}; +const RSA_ALGORITHMS: &[Algorithm] = &[ + Algorithm::RS256, + Algorithm::RS384, + Algorithm::RS512, + Algorithm::PS256, + Algorithm::PS384, + Algorithm::PS512, +]; + #[derive(Debug, PartialEq, Clone, Serialize, Deserialize)] struct Claims { sub: String, @@ -15,12 +24,12 @@ struct Claims { #[test] fn round_trip_sign_verification() { - let encrypted = - sign("hello world", include_bytes!("private_rsa_key.der"), Algorithm::RS256).unwrap(); - let is_valid = - verify(&encrypted, "hello world", include_bytes!("public_rsa_key.der"), Algorithm::RS256) - .unwrap(); - assert!(is_valid); + for &alg in RSA_ALGORITHMS { + let encrypted = sign("hello world", include_bytes!("private_rsa_key.der"), alg).unwrap(); + let is_valid = + verify(&encrypted, "hello world", include_bytes!("public_rsa_key.der"), alg).unwrap(); + assert!(is_valid); + } } #[test] @@ -30,15 +39,13 @@ fn round_trip_claim() { company: "ACME".to_string(), exp: Utc::now().timestamp() + 10000, }; - let token = - encode(&Header::new(Algorithm::RS256), &my_claims, include_bytes!("private_rsa_key.der")) - .unwrap(); - let token_data = decode::( - &token, - include_bytes!("public_rsa_key.der"), - &Validation::new(Algorithm::RS256), - ) - .unwrap(); - assert_eq!(my_claims, token_data.claims); - assert!(token_data.header.kid.is_none()); + for &alg in RSA_ALGORITHMS { + let token = + encode(&Header::new(alg), &my_claims, include_bytes!("private_rsa_key.der")).unwrap(); + let token_data = + decode::(&token, include_bytes!("public_rsa_key.der"), &Validation::new(alg)) + .unwrap(); + assert_eq!(my_claims, token_data.claims); + assert!(token_data.header.kid.is_none()); + } }