Validate key type with algo in encode/decode

This commit is contained in:
Vincent Prouillet 2020-01-13 19:38:33 +01:00
parent 4dd2f12c6d
commit 689cc6d32e
7 changed files with 114 additions and 22 deletions

View File

@ -2,6 +2,13 @@ use crate::errors::{Error, ErrorKind, Result};
use serde::{Deserialize, Serialize};
use std::str::FromStr;
#[derive(Debug, Eq, PartialEq, Copy, Clone, Serialize, Deserialize)]
pub(crate) enum AlgorithmFamily {
Hmac,
Rsa,
Ec,
}
/// The algorithms supported for signing/verifying JWTs
#[derive(Debug, PartialEq, Copy, Clone, Serialize, Deserialize)]
pub enum Algorithm {
@ -58,6 +65,21 @@ impl FromStr for Algorithm {
}
}
impl Algorithm {
pub(crate) fn family(self) -> AlgorithmFamily {
match self {
Algorithm::HS256 | Algorithm::HS384 | Algorithm::HS512 => AlgorithmFamily::Hmac,
Algorithm::RS256
| Algorithm::RS384
| Algorithm::RS512
| Algorithm::PS256
| Algorithm::PS384
| Algorithm::PS512 => AlgorithmFamily::Rsa,
Algorithm::ES256 | Algorithm::ES384 => AlgorithmFamily::Ec,
}
}
}
#[cfg(test)]
mod tests {
use super::*;

View File

@ -2,6 +2,7 @@ use std::borrow::Cow;
use serde::de::DeserializeOwned;
use crate::algorithms::AlgorithmFamily;
use crate::crypto::verify;
use crate::errors::{new_error, ErrorKind, Result};
use crate::header::Header;
@ -40,37 +41,70 @@ pub(crate) enum DecodingKeyKind<'a> {
/// This key can be re-used so make sure you only initialize it once if you can for better performance
#[derive(Debug, Clone, PartialEq)]
pub struct DecodingKey<'a> {
pub(crate) family: AlgorithmFamily,
pub(crate) kind: DecodingKeyKind<'a>,
}
impl<'a> DecodingKey<'a> {
/// If you're using HMAC, use this.
pub fn from_secret(secret: &'a [u8]) -> Self {
DecodingKey { kind: DecodingKeyKind::SecretOrDer(Cow::Borrowed(secret)) }
DecodingKey {
family: AlgorithmFamily::Hmac,
kind: DecodingKeyKind::SecretOrDer(Cow::Borrowed(secret)),
}
}
/// If you're using HMAC with a base64 encoded, use this.
pub fn from_base64_secret(secret: &str) -> Result<Self> {
let out = base64::decode(&secret)?;
Ok(DecodingKey {
family: AlgorithmFamily::Hmac,
kind: DecodingKeyKind::SecretOrDer(Cow::Owned(out)),
})
}
/// If you are loading a public RSA key in a PEM format, use this.
pub fn from_rsa_pem(key: &'a [u8]) -> Result<Self> {
let pem_key = PemEncodedKey::new(key)?;
let content = pem_key.as_rsa_key()?;
Ok(DecodingKey { kind: DecodingKeyKind::SecretOrDer(Cow::Owned(content.to_vec())) })
Ok(DecodingKey {
family: AlgorithmFamily::Rsa,
kind: DecodingKeyKind::SecretOrDer(Cow::Owned(content.to_vec())),
})
}
/// If you have (n, e) RSA public key components, use this.
pub fn from_rsa_components(modulus: &'a str, exponent: &'a str) -> Self {
DecodingKey { kind: DecodingKeyKind::RsaModulusExponent { n: modulus, e: exponent } }
DecodingKey {
family: AlgorithmFamily::Rsa,
kind: DecodingKeyKind::RsaModulusExponent { n: modulus, e: exponent },
}
}
/// If you have a ECDSA public key in PEM format, use this.
pub fn from_ec_pem(key: &'a [u8]) -> Result<Self> {
let pem_key = PemEncodedKey::new(key)?;
let content = pem_key.as_ec_public_key()?;
Ok(DecodingKey { kind: DecodingKeyKind::SecretOrDer(Cow::Owned(content.to_vec())) })
Ok(DecodingKey {
family: AlgorithmFamily::Ec,
kind: DecodingKeyKind::SecretOrDer(Cow::Owned(content.to_vec())),
})
}
/// If you know what you're doing and have the DER encoded public key, use this.
pub fn from_der(der: &'a [u8]) -> Self {
DecodingKey { kind: DecodingKeyKind::SecretOrDer(Cow::Borrowed(der)) }
/// If you know what you're doing and have a RSA DER encoded public key, use this.
pub fn from_rsa_der(der: &'a [u8]) -> Self {
DecodingKey {
family: AlgorithmFamily::Rsa,
kind: DecodingKeyKind::SecretOrDer(Cow::Borrowed(der)),
}
}
/// If you know what you're doing and have a RSA EC encoded public key, use this.
pub fn from_ec_der(der: &'a [u8]) -> Self {
DecodingKey {
family: AlgorithmFamily::Ec,
kind: DecodingKeyKind::SecretOrDer(Cow::Borrowed(der)),
}
}
pub(crate) fn as_bytes(&self) -> &[u8] {
@ -104,6 +138,12 @@ pub fn decode<T: DeserializeOwned>(
key: &DecodingKey,
validation: &Validation,
) -> Result<TokenData<T>> {
for alg in &validation.algorithms {
if key.family != alg.family() {
return Err(new_error(ErrorKind::InvalidAlgorithm));
}
}
let (signature, message) = expect_two!(token.rsplitn(2, '.'));
let (claims, header) = expect_two!(message.rsplitn(2, '.'));
let header = Header::from_encoded(header)?;

View File

@ -1,7 +1,8 @@
use serde::ser::Serialize;
use crate::algorithms::AlgorithmFamily;
use crate::crypto;
use crate::errors::Result;
use crate::errors::{new_error, ErrorKind, Result};
use crate::header::Header;
use crate::pem::decoder::PemEncodedKey;
use crate::serialization::b64_encode_part;
@ -10,13 +11,20 @@ use crate::serialization::b64_encode_part;
/// This key can be re-used so make sure you only initialize it once if you can for better performance
#[derive(Debug, Clone, PartialEq)]
pub struct EncodingKey {
pub(crate) family: AlgorithmFamily,
content: Vec<u8>,
}
impl EncodingKey {
/// If you're using HMAC, use that.
/// If you're using a HMAC secret that is not base64, use that.
pub fn from_secret(secret: &[u8]) -> Self {
EncodingKey { content: secret.to_vec() }
EncodingKey { family: AlgorithmFamily::Hmac, content: secret.to_vec() }
}
/// If you have a base64 HMAC secret, use that.
pub fn from_base64_secret(secret: &str) -> Result<Self> {
let out = base64::decode(&secret)?;
Ok(EncodingKey { family: AlgorithmFamily::Hmac, content: out })
}
/// If you are loading a RSA key from a .pem file.
@ -24,7 +32,7 @@ impl EncodingKey {
pub fn from_rsa_pem(key: &[u8]) -> Result<Self> {
let pem_key = PemEncodedKey::new(key)?;
let content = pem_key.as_rsa_key()?;
Ok(EncodingKey { content: content.to_vec() })
Ok(EncodingKey { family: AlgorithmFamily::Rsa, content: content.to_vec() })
}
/// If you are loading a ECDSA key from a .pem file
@ -32,12 +40,17 @@ impl EncodingKey {
pub fn from_ec_pem(key: &[u8]) -> Result<Self> {
let pem_key = PemEncodedKey::new(key)?;
let content = pem_key.as_ec_private_key()?;
Ok(EncodingKey { content: content.to_vec() })
Ok(EncodingKey { family: AlgorithmFamily::Ec, content: content.to_vec() })
}
/// If you know what you're doing and have the DER-encoded key, for RSA or ECDSA
pub fn from_der(der: &[u8]) -> Self {
EncodingKey { content: der.to_vec() }
/// If you know what you're doing and have the DER-encoded key, for RSA only
pub fn from_rsa_der(der: &[u8]) -> Self {
EncodingKey { family: AlgorithmFamily::Rsa, content: der.to_vec() }
}
/// If you know what you're doing and have the DER-encoded key, for ECDSA
pub fn from_ec_der(der: &[u8]) -> Self {
EncodingKey { family: AlgorithmFamily::Ec, content: der.to_vec() }
}
pub(crate) fn inner(&self) -> &[u8] {
@ -68,6 +81,9 @@ impl EncodingKey {
/// let token = encode(&Header::default(), &my_claims, &EncodingKey::from_secret("secret".as_ref())).unwrap();
/// ```
pub fn encode<T: Serialize>(header: &Header, claims: &T, key: &EncodingKey) -> Result<String> {
if key.family != header.alg.family() {
return Err(new_error(ErrorKind::InvalidAlgorithm));
}
let encoded_header = b64_encode_part(&header)?;
let encoded_claims = b64_encode_part(&claims)?;
let message = [encoded_header.as_ref(), encoded_claims.as_ref()].join(".");

View File

@ -56,7 +56,8 @@ pub enum ErrorKind {
InvalidSubject,
/// When a tokens nbf claim represents a time in the future
ImmatureSignature,
/// When the algorithm in the header doesn't match the one passed to `decode`
/// When the algorithm in the header doesn't match the one passed to `decode` or the encoding/decoding key
/// used doesn't match the alg requested
InvalidAlgorithm,
// 3rd party errors

View File

@ -17,8 +17,11 @@ fn round_trip_sign_verification_pk8() {
let privkey = include_bytes!("private_ecdsa_key.pk8");
let pubkey = include_bytes!("public_ecdsa_key.pk8");
let encrypted = sign("hello world", &EncodingKey::from_der(privkey), Algorithm::ES256).unwrap();
let is_valid = verify(&encrypted, "hello world", &DecodingKey::from_der(pubkey), Algorithm::ES256).unwrap();
let encrypted =
sign("hello world", &EncodingKey::from_ec_der(privkey), Algorithm::ES256).unwrap();
let is_valid =
verify(&encrypted, "hello world", &DecodingKey::from_ec_der(pubkey), Algorithm::ES256)
.unwrap();
assert!(is_valid);
}

View File

@ -102,6 +102,18 @@ fn decode_token_wrong_algorithm() {
claims.unwrap();
}
#[test]
#[should_panic(expected = "InvalidAlgorithm")]
fn encode_wrong_alg_family() {
let my_claims = Claims {
sub: "b@b.com".to_string(),
company: "ACME".to_string(),
exp: Utc::now().timestamp() + 10000,
};
let claims = encode(&Header::default(), &my_claims, &EncodingKey::from_rsa_der(b"secret"));
claims.unwrap();
}
#[test]
fn decode_token_with_bytes_secret() {
let token = "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJzdWIiOiJiQGIuY29tIiwiY29tcGFueSI6IkFDTUUiLCJleHAiOjI1MzI1MjQ4OTF9.Hm0yvKH25TavFPz7J_coST9lZFYH1hQo0tvhvImmaks";

View File

@ -57,11 +57,9 @@ fn round_trip_sign_verification_der() {
let pubkey_der = include_bytes!("public_rsa_key.der");
for &alg in RSA_ALGORITHMS {
let encrypted =
sign("hello world", &EncodingKey::from_der(privkey_der), alg).unwrap();
let encrypted = sign("hello world", &EncodingKey::from_rsa_der(privkey_der), alg).unwrap();
let is_valid =
verify(&encrypted, "hello world", &DecodingKey::from_der(pubkey_der), alg)
.unwrap();
verify(&encrypted, "hello world", &DecodingKey::from_rsa_der(pubkey_der), alg).unwrap();
assert!(is_valid);
}
}