Validate key type with algo in encode/decode

This commit is contained in:
Vincent Prouillet 2020-01-13 19:38:33 +01:00
parent 4dd2f12c6d
commit 689cc6d32e
7 changed files with 114 additions and 22 deletions

View File

@ -2,6 +2,13 @@ use crate::errors::{Error, ErrorKind, Result};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::str::FromStr; use std::str::FromStr;
#[derive(Debug, Eq, PartialEq, Copy, Clone, Serialize, Deserialize)]
pub(crate) enum AlgorithmFamily {
Hmac,
Rsa,
Ec,
}
/// The algorithms supported for signing/verifying JWTs /// The algorithms supported for signing/verifying JWTs
#[derive(Debug, PartialEq, Copy, Clone, Serialize, Deserialize)] #[derive(Debug, PartialEq, Copy, Clone, Serialize, Deserialize)]
pub enum Algorithm { pub enum Algorithm {
@ -58,6 +65,21 @@ impl FromStr for Algorithm {
} }
} }
impl Algorithm {
pub(crate) fn family(self) -> AlgorithmFamily {
match self {
Algorithm::HS256 | Algorithm::HS384 | Algorithm::HS512 => AlgorithmFamily::Hmac,
Algorithm::RS256
| Algorithm::RS384
| Algorithm::RS512
| Algorithm::PS256
| Algorithm::PS384
| Algorithm::PS512 => AlgorithmFamily::Rsa,
Algorithm::ES256 | Algorithm::ES384 => AlgorithmFamily::Ec,
}
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;

View File

@ -2,6 +2,7 @@ use std::borrow::Cow;
use serde::de::DeserializeOwned; use serde::de::DeserializeOwned;
use crate::algorithms::AlgorithmFamily;
use crate::crypto::verify; use crate::crypto::verify;
use crate::errors::{new_error, ErrorKind, Result}; use crate::errors::{new_error, ErrorKind, Result};
use crate::header::Header; use crate::header::Header;
@ -40,37 +41,70 @@ pub(crate) enum DecodingKeyKind<'a> {
/// This key can be re-used so make sure you only initialize it once if you can for better performance /// This key can be re-used so make sure you only initialize it once if you can for better performance
#[derive(Debug, Clone, PartialEq)] #[derive(Debug, Clone, PartialEq)]
pub struct DecodingKey<'a> { pub struct DecodingKey<'a> {
pub(crate) family: AlgorithmFamily,
pub(crate) kind: DecodingKeyKind<'a>, pub(crate) kind: DecodingKeyKind<'a>,
} }
impl<'a> DecodingKey<'a> { impl<'a> DecodingKey<'a> {
/// If you're using HMAC, use this. /// If you're using HMAC, use this.
pub fn from_secret(secret: &'a [u8]) -> Self { pub fn from_secret(secret: &'a [u8]) -> Self {
DecodingKey { kind: DecodingKeyKind::SecretOrDer(Cow::Borrowed(secret)) } DecodingKey {
family: AlgorithmFamily::Hmac,
kind: DecodingKeyKind::SecretOrDer(Cow::Borrowed(secret)),
}
}
/// If you're using HMAC with a base64 encoded, use this.
pub fn from_base64_secret(secret: &str) -> Result<Self> {
let out = base64::decode(&secret)?;
Ok(DecodingKey {
family: AlgorithmFamily::Hmac,
kind: DecodingKeyKind::SecretOrDer(Cow::Owned(out)),
})
} }
/// If you are loading a public RSA key in a PEM format, use this. /// If you are loading a public RSA key in a PEM format, use this.
pub fn from_rsa_pem(key: &'a [u8]) -> Result<Self> { pub fn from_rsa_pem(key: &'a [u8]) -> Result<Self> {
let pem_key = PemEncodedKey::new(key)?; let pem_key = PemEncodedKey::new(key)?;
let content = pem_key.as_rsa_key()?; let content = pem_key.as_rsa_key()?;
Ok(DecodingKey { kind: DecodingKeyKind::SecretOrDer(Cow::Owned(content.to_vec())) }) Ok(DecodingKey {
family: AlgorithmFamily::Rsa,
kind: DecodingKeyKind::SecretOrDer(Cow::Owned(content.to_vec())),
})
} }
/// If you have (n, e) RSA public key components, use this. /// If you have (n, e) RSA public key components, use this.
pub fn from_rsa_components(modulus: &'a str, exponent: &'a str) -> Self { pub fn from_rsa_components(modulus: &'a str, exponent: &'a str) -> Self {
DecodingKey { kind: DecodingKeyKind::RsaModulusExponent { n: modulus, e: exponent } } DecodingKey {
family: AlgorithmFamily::Rsa,
kind: DecodingKeyKind::RsaModulusExponent { n: modulus, e: exponent },
}
} }
/// If you have a ECDSA public key in PEM format, use this. /// If you have a ECDSA public key in PEM format, use this.
pub fn from_ec_pem(key: &'a [u8]) -> Result<Self> { pub fn from_ec_pem(key: &'a [u8]) -> Result<Self> {
let pem_key = PemEncodedKey::new(key)?; let pem_key = PemEncodedKey::new(key)?;
let content = pem_key.as_ec_public_key()?; let content = pem_key.as_ec_public_key()?;
Ok(DecodingKey { kind: DecodingKeyKind::SecretOrDer(Cow::Owned(content.to_vec())) }) Ok(DecodingKey {
family: AlgorithmFamily::Ec,
kind: DecodingKeyKind::SecretOrDer(Cow::Owned(content.to_vec())),
})
} }
/// If you know what you're doing and have the DER encoded public key, use this. /// If you know what you're doing and have a RSA DER encoded public key, use this.
pub fn from_der(der: &'a [u8]) -> Self { pub fn from_rsa_der(der: &'a [u8]) -> Self {
DecodingKey { kind: DecodingKeyKind::SecretOrDer(Cow::Borrowed(der)) } DecodingKey {
family: AlgorithmFamily::Rsa,
kind: DecodingKeyKind::SecretOrDer(Cow::Borrowed(der)),
}
}
/// If you know what you're doing and have a RSA EC encoded public key, use this.
pub fn from_ec_der(der: &'a [u8]) -> Self {
DecodingKey {
family: AlgorithmFamily::Ec,
kind: DecodingKeyKind::SecretOrDer(Cow::Borrowed(der)),
}
} }
pub(crate) fn as_bytes(&self) -> &[u8] { pub(crate) fn as_bytes(&self) -> &[u8] {
@ -104,6 +138,12 @@ pub fn decode<T: DeserializeOwned>(
key: &DecodingKey, key: &DecodingKey,
validation: &Validation, validation: &Validation,
) -> Result<TokenData<T>> { ) -> Result<TokenData<T>> {
for alg in &validation.algorithms {
if key.family != alg.family() {
return Err(new_error(ErrorKind::InvalidAlgorithm));
}
}
let (signature, message) = expect_two!(token.rsplitn(2, '.')); let (signature, message) = expect_two!(token.rsplitn(2, '.'));
let (claims, header) = expect_two!(message.rsplitn(2, '.')); let (claims, header) = expect_two!(message.rsplitn(2, '.'));
let header = Header::from_encoded(header)?; let header = Header::from_encoded(header)?;

View File

@ -1,7 +1,8 @@
use serde::ser::Serialize; use serde::ser::Serialize;
use crate::algorithms::AlgorithmFamily;
use crate::crypto; use crate::crypto;
use crate::errors::Result; use crate::errors::{new_error, ErrorKind, Result};
use crate::header::Header; use crate::header::Header;
use crate::pem::decoder::PemEncodedKey; use crate::pem::decoder::PemEncodedKey;
use crate::serialization::b64_encode_part; use crate::serialization::b64_encode_part;
@ -10,13 +11,20 @@ use crate::serialization::b64_encode_part;
/// This key can be re-used so make sure you only initialize it once if you can for better performance /// This key can be re-used so make sure you only initialize it once if you can for better performance
#[derive(Debug, Clone, PartialEq)] #[derive(Debug, Clone, PartialEq)]
pub struct EncodingKey { pub struct EncodingKey {
pub(crate) family: AlgorithmFamily,
content: Vec<u8>, content: Vec<u8>,
} }
impl EncodingKey { impl EncodingKey {
/// If you're using HMAC, use that. /// If you're using a HMAC secret that is not base64, use that.
pub fn from_secret(secret: &[u8]) -> Self { pub fn from_secret(secret: &[u8]) -> Self {
EncodingKey { content: secret.to_vec() } EncodingKey { family: AlgorithmFamily::Hmac, content: secret.to_vec() }
}
/// If you have a base64 HMAC secret, use that.
pub fn from_base64_secret(secret: &str) -> Result<Self> {
let out = base64::decode(&secret)?;
Ok(EncodingKey { family: AlgorithmFamily::Hmac, content: out })
} }
/// If you are loading a RSA key from a .pem file. /// If you are loading a RSA key from a .pem file.
@ -24,7 +32,7 @@ impl EncodingKey {
pub fn from_rsa_pem(key: &[u8]) -> Result<Self> { pub fn from_rsa_pem(key: &[u8]) -> Result<Self> {
let pem_key = PemEncodedKey::new(key)?; let pem_key = PemEncodedKey::new(key)?;
let content = pem_key.as_rsa_key()?; let content = pem_key.as_rsa_key()?;
Ok(EncodingKey { content: content.to_vec() }) Ok(EncodingKey { family: AlgorithmFamily::Rsa, content: content.to_vec() })
} }
/// If you are loading a ECDSA key from a .pem file /// If you are loading a ECDSA key from a .pem file
@ -32,12 +40,17 @@ impl EncodingKey {
pub fn from_ec_pem(key: &[u8]) -> Result<Self> { pub fn from_ec_pem(key: &[u8]) -> Result<Self> {
let pem_key = PemEncodedKey::new(key)?; let pem_key = PemEncodedKey::new(key)?;
let content = pem_key.as_ec_private_key()?; let content = pem_key.as_ec_private_key()?;
Ok(EncodingKey { content: content.to_vec() }) Ok(EncodingKey { family: AlgorithmFamily::Ec, content: content.to_vec() })
} }
/// If you know what you're doing and have the DER-encoded key, for RSA or ECDSA /// If you know what you're doing and have the DER-encoded key, for RSA only
pub fn from_der(der: &[u8]) -> Self { pub fn from_rsa_der(der: &[u8]) -> Self {
EncodingKey { content: der.to_vec() } EncodingKey { family: AlgorithmFamily::Rsa, content: der.to_vec() }
}
/// If you know what you're doing and have the DER-encoded key, for ECDSA
pub fn from_ec_der(der: &[u8]) -> Self {
EncodingKey { family: AlgorithmFamily::Ec, content: der.to_vec() }
} }
pub(crate) fn inner(&self) -> &[u8] { pub(crate) fn inner(&self) -> &[u8] {
@ -68,6 +81,9 @@ impl EncodingKey {
/// let token = encode(&Header::default(), &my_claims, &EncodingKey::from_secret("secret".as_ref())).unwrap(); /// let token = encode(&Header::default(), &my_claims, &EncodingKey::from_secret("secret".as_ref())).unwrap();
/// ``` /// ```
pub fn encode<T: Serialize>(header: &Header, claims: &T, key: &EncodingKey) -> Result<String> { pub fn encode<T: Serialize>(header: &Header, claims: &T, key: &EncodingKey) -> Result<String> {
if key.family != header.alg.family() {
return Err(new_error(ErrorKind::InvalidAlgorithm));
}
let encoded_header = b64_encode_part(&header)?; let encoded_header = b64_encode_part(&header)?;
let encoded_claims = b64_encode_part(&claims)?; let encoded_claims = b64_encode_part(&claims)?;
let message = [encoded_header.as_ref(), encoded_claims.as_ref()].join("."); let message = [encoded_header.as_ref(), encoded_claims.as_ref()].join(".");

View File

@ -56,7 +56,8 @@ pub enum ErrorKind {
InvalidSubject, InvalidSubject,
/// When a tokens nbf claim represents a time in the future /// When a tokens nbf claim represents a time in the future
ImmatureSignature, ImmatureSignature,
/// When the algorithm in the header doesn't match the one passed to `decode` /// When the algorithm in the header doesn't match the one passed to `decode` or the encoding/decoding key
/// used doesn't match the alg requested
InvalidAlgorithm, InvalidAlgorithm,
// 3rd party errors // 3rd party errors

View File

@ -17,8 +17,11 @@ fn round_trip_sign_verification_pk8() {
let privkey = include_bytes!("private_ecdsa_key.pk8"); let privkey = include_bytes!("private_ecdsa_key.pk8");
let pubkey = include_bytes!("public_ecdsa_key.pk8"); let pubkey = include_bytes!("public_ecdsa_key.pk8");
let encrypted = sign("hello world", &EncodingKey::from_der(privkey), Algorithm::ES256).unwrap(); let encrypted =
let is_valid = verify(&encrypted, "hello world", &DecodingKey::from_der(pubkey), Algorithm::ES256).unwrap(); sign("hello world", &EncodingKey::from_ec_der(privkey), Algorithm::ES256).unwrap();
let is_valid =
verify(&encrypted, "hello world", &DecodingKey::from_ec_der(pubkey), Algorithm::ES256)
.unwrap();
assert!(is_valid); assert!(is_valid);
} }

View File

@ -102,6 +102,18 @@ fn decode_token_wrong_algorithm() {
claims.unwrap(); claims.unwrap();
} }
#[test]
#[should_panic(expected = "InvalidAlgorithm")]
fn encode_wrong_alg_family() {
let my_claims = Claims {
sub: "b@b.com".to_string(),
company: "ACME".to_string(),
exp: Utc::now().timestamp() + 10000,
};
let claims = encode(&Header::default(), &my_claims, &EncodingKey::from_rsa_der(b"secret"));
claims.unwrap();
}
#[test] #[test]
fn decode_token_with_bytes_secret() { fn decode_token_with_bytes_secret() {
let token = "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJzdWIiOiJiQGIuY29tIiwiY29tcGFueSI6IkFDTUUiLCJleHAiOjI1MzI1MjQ4OTF9.Hm0yvKH25TavFPz7J_coST9lZFYH1hQo0tvhvImmaks"; let token = "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJzdWIiOiJiQGIuY29tIiwiY29tcGFueSI6IkFDTUUiLCJleHAiOjI1MzI1MjQ4OTF9.Hm0yvKH25TavFPz7J_coST9lZFYH1hQo0tvhvImmaks";

View File

@ -57,11 +57,9 @@ fn round_trip_sign_verification_der() {
let pubkey_der = include_bytes!("public_rsa_key.der"); let pubkey_der = include_bytes!("public_rsa_key.der");
for &alg in RSA_ALGORITHMS { for &alg in RSA_ALGORITHMS {
let encrypted = let encrypted = sign("hello world", &EncodingKey::from_rsa_der(privkey_der), alg).unwrap();
sign("hello world", &EncodingKey::from_der(privkey_der), alg).unwrap();
let is_valid = let is_valid =
verify(&encrypted, "hello world", &DecodingKey::from_der(pubkey_der), alg) verify(&encrypted, "hello world", &DecodingKey::from_rsa_der(pubkey_der), alg).unwrap();
.unwrap();
assert!(is_valid); assert!(is_valid);
} }
} }