diff --git a/src/algorithms.rs b/src/algorithms.rs index c9ff848..20edbf1 100644 --- a/src/algorithms.rs +++ b/src/algorithms.rs @@ -2,6 +2,13 @@ use crate::errors::{Error, ErrorKind, Result}; use serde::{Deserialize, Serialize}; use std::str::FromStr; +#[derive(Debug, Eq, PartialEq, Copy, Clone, Serialize, Deserialize)] +pub(crate) enum AlgorithmFamily { + Hmac, + Rsa, + Ec, +} + /// The algorithms supported for signing/verifying JWTs #[derive(Debug, PartialEq, Copy, Clone, Serialize, Deserialize)] pub enum Algorithm { @@ -58,6 +65,21 @@ impl FromStr for Algorithm { } } +impl Algorithm { + pub(crate) fn family(self) -> AlgorithmFamily { + match self { + Algorithm::HS256 | Algorithm::HS384 | Algorithm::HS512 => AlgorithmFamily::Hmac, + Algorithm::RS256 + | Algorithm::RS384 + | Algorithm::RS512 + | Algorithm::PS256 + | Algorithm::PS384 + | Algorithm::PS512 => AlgorithmFamily::Rsa, + Algorithm::ES256 | Algorithm::ES384 => AlgorithmFamily::Ec, + } + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/decoding.rs b/src/decoding.rs index bbd5705..36837f0 100644 --- a/src/decoding.rs +++ b/src/decoding.rs @@ -2,6 +2,7 @@ use std::borrow::Cow; use serde::de::DeserializeOwned; +use crate::algorithms::AlgorithmFamily; use crate::crypto::verify; use crate::errors::{new_error, ErrorKind, Result}; use crate::header::Header; @@ -40,37 +41,70 @@ pub(crate) enum DecodingKeyKind<'a> { /// This key can be re-used so make sure you only initialize it once if you can for better performance #[derive(Debug, Clone, PartialEq)] pub struct DecodingKey<'a> { + pub(crate) family: AlgorithmFamily, pub(crate) kind: DecodingKeyKind<'a>, } impl<'a> DecodingKey<'a> { /// If you're using HMAC, use this. pub fn from_secret(secret: &'a [u8]) -> Self { - DecodingKey { kind: DecodingKeyKind::SecretOrDer(Cow::Borrowed(secret)) } + DecodingKey { + family: AlgorithmFamily::Hmac, + kind: DecodingKeyKind::SecretOrDer(Cow::Borrowed(secret)), + } + } + + /// If you're using HMAC with a base64 encoded, use this. + pub fn from_base64_secret(secret: &str) -> Result { + let out = base64::decode(&secret)?; + Ok(DecodingKey { + family: AlgorithmFamily::Hmac, + kind: DecodingKeyKind::SecretOrDer(Cow::Owned(out)), + }) } /// If you are loading a public RSA key in a PEM format, use this. pub fn from_rsa_pem(key: &'a [u8]) -> Result { let pem_key = PemEncodedKey::new(key)?; let content = pem_key.as_rsa_key()?; - Ok(DecodingKey { kind: DecodingKeyKind::SecretOrDer(Cow::Owned(content.to_vec())) }) + Ok(DecodingKey { + family: AlgorithmFamily::Rsa, + kind: DecodingKeyKind::SecretOrDer(Cow::Owned(content.to_vec())), + }) } /// If you have (n, e) RSA public key components, use this. pub fn from_rsa_components(modulus: &'a str, exponent: &'a str) -> Self { - DecodingKey { kind: DecodingKeyKind::RsaModulusExponent { n: modulus, e: exponent } } + DecodingKey { + family: AlgorithmFamily::Rsa, + kind: DecodingKeyKind::RsaModulusExponent { n: modulus, e: exponent }, + } } /// If you have a ECDSA public key in PEM format, use this. pub fn from_ec_pem(key: &'a [u8]) -> Result { let pem_key = PemEncodedKey::new(key)?; let content = pem_key.as_ec_public_key()?; - Ok(DecodingKey { kind: DecodingKeyKind::SecretOrDer(Cow::Owned(content.to_vec())) }) + Ok(DecodingKey { + family: AlgorithmFamily::Ec, + kind: DecodingKeyKind::SecretOrDer(Cow::Owned(content.to_vec())), + }) } - /// If you know what you're doing and have the DER encoded public key, use this. - pub fn from_der(der: &'a [u8]) -> Self { - DecodingKey { kind: DecodingKeyKind::SecretOrDer(Cow::Borrowed(der)) } + /// If you know what you're doing and have a RSA DER encoded public key, use this. + pub fn from_rsa_der(der: &'a [u8]) -> Self { + DecodingKey { + family: AlgorithmFamily::Rsa, + kind: DecodingKeyKind::SecretOrDer(Cow::Borrowed(der)), + } + } + + /// If you know what you're doing and have a RSA EC encoded public key, use this. + pub fn from_ec_der(der: &'a [u8]) -> Self { + DecodingKey { + family: AlgorithmFamily::Ec, + kind: DecodingKeyKind::SecretOrDer(Cow::Borrowed(der)), + } } pub(crate) fn as_bytes(&self) -> &[u8] { @@ -104,6 +138,12 @@ pub fn decode( key: &DecodingKey, validation: &Validation, ) -> Result> { + for alg in &validation.algorithms { + if key.family != alg.family() { + return Err(new_error(ErrorKind::InvalidAlgorithm)); + } + } + let (signature, message) = expect_two!(token.rsplitn(2, '.')); let (claims, header) = expect_two!(message.rsplitn(2, '.')); let header = Header::from_encoded(header)?; diff --git a/src/encoding.rs b/src/encoding.rs index 175e081..3f7a690 100644 --- a/src/encoding.rs +++ b/src/encoding.rs @@ -1,7 +1,8 @@ use serde::ser::Serialize; +use crate::algorithms::AlgorithmFamily; use crate::crypto; -use crate::errors::Result; +use crate::errors::{new_error, ErrorKind, Result}; use crate::header::Header; use crate::pem::decoder::PemEncodedKey; use crate::serialization::b64_encode_part; @@ -10,13 +11,20 @@ use crate::serialization::b64_encode_part; /// This key can be re-used so make sure you only initialize it once if you can for better performance #[derive(Debug, Clone, PartialEq)] pub struct EncodingKey { + pub(crate) family: AlgorithmFamily, content: Vec, } impl EncodingKey { - /// If you're using HMAC, use that. + /// If you're using a HMAC secret that is not base64, use that. pub fn from_secret(secret: &[u8]) -> Self { - EncodingKey { content: secret.to_vec() } + EncodingKey { family: AlgorithmFamily::Hmac, content: secret.to_vec() } + } + + /// If you have a base64 HMAC secret, use that. + pub fn from_base64_secret(secret: &str) -> Result { + let out = base64::decode(&secret)?; + Ok(EncodingKey { family: AlgorithmFamily::Hmac, content: out }) } /// If you are loading a RSA key from a .pem file. @@ -24,7 +32,7 @@ impl EncodingKey { pub fn from_rsa_pem(key: &[u8]) -> Result { let pem_key = PemEncodedKey::new(key)?; let content = pem_key.as_rsa_key()?; - Ok(EncodingKey { content: content.to_vec() }) + Ok(EncodingKey { family: AlgorithmFamily::Rsa, content: content.to_vec() }) } /// If you are loading a ECDSA key from a .pem file @@ -32,12 +40,17 @@ impl EncodingKey { pub fn from_ec_pem(key: &[u8]) -> Result { let pem_key = PemEncodedKey::new(key)?; let content = pem_key.as_ec_private_key()?; - Ok(EncodingKey { content: content.to_vec() }) + Ok(EncodingKey { family: AlgorithmFamily::Ec, content: content.to_vec() }) } - /// If you know what you're doing and have the DER-encoded key, for RSA or ECDSA - pub fn from_der(der: &[u8]) -> Self { - EncodingKey { content: der.to_vec() } + /// If you know what you're doing and have the DER-encoded key, for RSA only + pub fn from_rsa_der(der: &[u8]) -> Self { + EncodingKey { family: AlgorithmFamily::Rsa, content: der.to_vec() } + } + + /// If you know what you're doing and have the DER-encoded key, for ECDSA + pub fn from_ec_der(der: &[u8]) -> Self { + EncodingKey { family: AlgorithmFamily::Ec, content: der.to_vec() } } pub(crate) fn inner(&self) -> &[u8] { @@ -68,6 +81,9 @@ impl EncodingKey { /// let token = encode(&Header::default(), &my_claims, &EncodingKey::from_secret("secret".as_ref())).unwrap(); /// ``` pub fn encode(header: &Header, claims: &T, key: &EncodingKey) -> Result { + if key.family != header.alg.family() { + return Err(new_error(ErrorKind::InvalidAlgorithm)); + } let encoded_header = b64_encode_part(&header)?; let encoded_claims = b64_encode_part(&claims)?; let message = [encoded_header.as_ref(), encoded_claims.as_ref()].join("."); diff --git a/src/errors.rs b/src/errors.rs index 87231ed..580e8b3 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -56,7 +56,8 @@ pub enum ErrorKind { InvalidSubject, /// When a token’s nbf claim represents a time in the future ImmatureSignature, - /// When the algorithm in the header doesn't match the one passed to `decode` + /// When the algorithm in the header doesn't match the one passed to `decode` or the encoding/decoding key + /// used doesn't match the alg requested InvalidAlgorithm, // 3rd party errors diff --git a/tests/ecdsa/mod.rs b/tests/ecdsa/mod.rs index 8e39d78..2362409 100644 --- a/tests/ecdsa/mod.rs +++ b/tests/ecdsa/mod.rs @@ -17,8 +17,11 @@ fn round_trip_sign_verification_pk8() { let privkey = include_bytes!("private_ecdsa_key.pk8"); let pubkey = include_bytes!("public_ecdsa_key.pk8"); - let encrypted = sign("hello world", &EncodingKey::from_der(privkey), Algorithm::ES256).unwrap(); - let is_valid = verify(&encrypted, "hello world", &DecodingKey::from_der(pubkey), Algorithm::ES256).unwrap(); + let encrypted = + sign("hello world", &EncodingKey::from_ec_der(privkey), Algorithm::ES256).unwrap(); + let is_valid = + verify(&encrypted, "hello world", &DecodingKey::from_ec_der(pubkey), Algorithm::ES256) + .unwrap(); assert!(is_valid); } diff --git a/tests/hmac.rs b/tests/hmac.rs index 94ee664..dcd9c59 100644 --- a/tests/hmac.rs +++ b/tests/hmac.rs @@ -102,6 +102,18 @@ fn decode_token_wrong_algorithm() { claims.unwrap(); } +#[test] +#[should_panic(expected = "InvalidAlgorithm")] +fn encode_wrong_alg_family() { + let my_claims = Claims { + sub: "b@b.com".to_string(), + company: "ACME".to_string(), + exp: Utc::now().timestamp() + 10000, + }; + let claims = encode(&Header::default(), &my_claims, &EncodingKey::from_rsa_der(b"secret")); + claims.unwrap(); +} + #[test] fn decode_token_with_bytes_secret() { let token = "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJzdWIiOiJiQGIuY29tIiwiY29tcGFueSI6IkFDTUUiLCJleHAiOjI1MzI1MjQ4OTF9.Hm0yvKH25TavFPz7J_coST9lZFYH1hQo0tvhvImmaks"; diff --git a/tests/rsa/mod.rs b/tests/rsa/mod.rs index 08c3f8a..60a2b71 100644 --- a/tests/rsa/mod.rs +++ b/tests/rsa/mod.rs @@ -57,11 +57,9 @@ fn round_trip_sign_verification_der() { let pubkey_der = include_bytes!("public_rsa_key.der"); for &alg in RSA_ALGORITHMS { - let encrypted = - sign("hello world", &EncodingKey::from_der(privkey_der), alg).unwrap(); + let encrypted = sign("hello world", &EncodingKey::from_rsa_der(privkey_der), alg).unwrap(); let is_valid = - verify(&encrypted, "hello world", &DecodingKey::from_der(pubkey_der), alg) - .unwrap(); + verify(&encrypted, "hello world", &DecodingKey::from_rsa_der(pubkey_der), alg).unwrap(); assert!(is_valid); } }