Validate key type with algo in encode/decode
This commit is contained in:
parent
4dd2f12c6d
commit
689cc6d32e
|
@ -2,6 +2,13 @@ use crate::errors::{Error, ErrorKind, Result};
|
|||
use serde::{Deserialize, Serialize};
|
||||
use std::str::FromStr;
|
||||
|
||||
#[derive(Debug, Eq, PartialEq, Copy, Clone, Serialize, Deserialize)]
|
||||
pub(crate) enum AlgorithmFamily {
|
||||
Hmac,
|
||||
Rsa,
|
||||
Ec,
|
||||
}
|
||||
|
||||
/// The algorithms supported for signing/verifying JWTs
|
||||
#[derive(Debug, PartialEq, Copy, Clone, Serialize, Deserialize)]
|
||||
pub enum Algorithm {
|
||||
|
@ -58,6 +65,21 @@ impl FromStr for Algorithm {
|
|||
}
|
||||
}
|
||||
|
||||
impl Algorithm {
|
||||
pub(crate) fn family(self) -> AlgorithmFamily {
|
||||
match self {
|
||||
Algorithm::HS256 | Algorithm::HS384 | Algorithm::HS512 => AlgorithmFamily::Hmac,
|
||||
Algorithm::RS256
|
||||
| Algorithm::RS384
|
||||
| Algorithm::RS512
|
||||
| Algorithm::PS256
|
||||
| Algorithm::PS384
|
||||
| Algorithm::PS512 => AlgorithmFamily::Rsa,
|
||||
Algorithm::ES256 | Algorithm::ES384 => AlgorithmFamily::Ec,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
|
|
@ -2,6 +2,7 @@ use std::borrow::Cow;
|
|||
|
||||
use serde::de::DeserializeOwned;
|
||||
|
||||
use crate::algorithms::AlgorithmFamily;
|
||||
use crate::crypto::verify;
|
||||
use crate::errors::{new_error, ErrorKind, Result};
|
||||
use crate::header::Header;
|
||||
|
@ -40,37 +41,70 @@ pub(crate) enum DecodingKeyKind<'a> {
|
|||
/// This key can be re-used so make sure you only initialize it once if you can for better performance
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub struct DecodingKey<'a> {
|
||||
pub(crate) family: AlgorithmFamily,
|
||||
pub(crate) kind: DecodingKeyKind<'a>,
|
||||
}
|
||||
|
||||
impl<'a> DecodingKey<'a> {
|
||||
/// If you're using HMAC, use this.
|
||||
pub fn from_secret(secret: &'a [u8]) -> Self {
|
||||
DecodingKey { kind: DecodingKeyKind::SecretOrDer(Cow::Borrowed(secret)) }
|
||||
DecodingKey {
|
||||
family: AlgorithmFamily::Hmac,
|
||||
kind: DecodingKeyKind::SecretOrDer(Cow::Borrowed(secret)),
|
||||
}
|
||||
}
|
||||
|
||||
/// If you're using HMAC with a base64 encoded, use this.
|
||||
pub fn from_base64_secret(secret: &str) -> Result<Self> {
|
||||
let out = base64::decode(&secret)?;
|
||||
Ok(DecodingKey {
|
||||
family: AlgorithmFamily::Hmac,
|
||||
kind: DecodingKeyKind::SecretOrDer(Cow::Owned(out)),
|
||||
})
|
||||
}
|
||||
|
||||
/// If you are loading a public RSA key in a PEM format, use this.
|
||||
pub fn from_rsa_pem(key: &'a [u8]) -> Result<Self> {
|
||||
let pem_key = PemEncodedKey::new(key)?;
|
||||
let content = pem_key.as_rsa_key()?;
|
||||
Ok(DecodingKey { kind: DecodingKeyKind::SecretOrDer(Cow::Owned(content.to_vec())) })
|
||||
Ok(DecodingKey {
|
||||
family: AlgorithmFamily::Rsa,
|
||||
kind: DecodingKeyKind::SecretOrDer(Cow::Owned(content.to_vec())),
|
||||
})
|
||||
}
|
||||
|
||||
/// If you have (n, e) RSA public key components, use this.
|
||||
pub fn from_rsa_components(modulus: &'a str, exponent: &'a str) -> Self {
|
||||
DecodingKey { kind: DecodingKeyKind::RsaModulusExponent { n: modulus, e: exponent } }
|
||||
DecodingKey {
|
||||
family: AlgorithmFamily::Rsa,
|
||||
kind: DecodingKeyKind::RsaModulusExponent { n: modulus, e: exponent },
|
||||
}
|
||||
}
|
||||
|
||||
/// If you have a ECDSA public key in PEM format, use this.
|
||||
pub fn from_ec_pem(key: &'a [u8]) -> Result<Self> {
|
||||
let pem_key = PemEncodedKey::new(key)?;
|
||||
let content = pem_key.as_ec_public_key()?;
|
||||
Ok(DecodingKey { kind: DecodingKeyKind::SecretOrDer(Cow::Owned(content.to_vec())) })
|
||||
Ok(DecodingKey {
|
||||
family: AlgorithmFamily::Ec,
|
||||
kind: DecodingKeyKind::SecretOrDer(Cow::Owned(content.to_vec())),
|
||||
})
|
||||
}
|
||||
|
||||
/// If you know what you're doing and have the DER encoded public key, use this.
|
||||
pub fn from_der(der: &'a [u8]) -> Self {
|
||||
DecodingKey { kind: DecodingKeyKind::SecretOrDer(Cow::Borrowed(der)) }
|
||||
/// If you know what you're doing and have a RSA DER encoded public key, use this.
|
||||
pub fn from_rsa_der(der: &'a [u8]) -> Self {
|
||||
DecodingKey {
|
||||
family: AlgorithmFamily::Rsa,
|
||||
kind: DecodingKeyKind::SecretOrDer(Cow::Borrowed(der)),
|
||||
}
|
||||
}
|
||||
|
||||
/// If you know what you're doing and have a RSA EC encoded public key, use this.
|
||||
pub fn from_ec_der(der: &'a [u8]) -> Self {
|
||||
DecodingKey {
|
||||
family: AlgorithmFamily::Ec,
|
||||
kind: DecodingKeyKind::SecretOrDer(Cow::Borrowed(der)),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn as_bytes(&self) -> &[u8] {
|
||||
|
@ -104,6 +138,12 @@ pub fn decode<T: DeserializeOwned>(
|
|||
key: &DecodingKey,
|
||||
validation: &Validation,
|
||||
) -> Result<TokenData<T>> {
|
||||
for alg in &validation.algorithms {
|
||||
if key.family != alg.family() {
|
||||
return Err(new_error(ErrorKind::InvalidAlgorithm));
|
||||
}
|
||||
}
|
||||
|
||||
let (signature, message) = expect_two!(token.rsplitn(2, '.'));
|
||||
let (claims, header) = expect_two!(message.rsplitn(2, '.'));
|
||||
let header = Header::from_encoded(header)?;
|
||||
|
|
|
@ -1,7 +1,8 @@
|
|||
use serde::ser::Serialize;
|
||||
|
||||
use crate::algorithms::AlgorithmFamily;
|
||||
use crate::crypto;
|
||||
use crate::errors::Result;
|
||||
use crate::errors::{new_error, ErrorKind, Result};
|
||||
use crate::header::Header;
|
||||
use crate::pem::decoder::PemEncodedKey;
|
||||
use crate::serialization::b64_encode_part;
|
||||
|
@ -10,13 +11,20 @@ use crate::serialization::b64_encode_part;
|
|||
/// This key can be re-used so make sure you only initialize it once if you can for better performance
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub struct EncodingKey {
|
||||
pub(crate) family: AlgorithmFamily,
|
||||
content: Vec<u8>,
|
||||
}
|
||||
|
||||
impl EncodingKey {
|
||||
/// If you're using HMAC, use that.
|
||||
/// If you're using a HMAC secret that is not base64, use that.
|
||||
pub fn from_secret(secret: &[u8]) -> Self {
|
||||
EncodingKey { content: secret.to_vec() }
|
||||
EncodingKey { family: AlgorithmFamily::Hmac, content: secret.to_vec() }
|
||||
}
|
||||
|
||||
/// If you have a base64 HMAC secret, use that.
|
||||
pub fn from_base64_secret(secret: &str) -> Result<Self> {
|
||||
let out = base64::decode(&secret)?;
|
||||
Ok(EncodingKey { family: AlgorithmFamily::Hmac, content: out })
|
||||
}
|
||||
|
||||
/// If you are loading a RSA key from a .pem file.
|
||||
|
@ -24,7 +32,7 @@ impl EncodingKey {
|
|||
pub fn from_rsa_pem(key: &[u8]) -> Result<Self> {
|
||||
let pem_key = PemEncodedKey::new(key)?;
|
||||
let content = pem_key.as_rsa_key()?;
|
||||
Ok(EncodingKey { content: content.to_vec() })
|
||||
Ok(EncodingKey { family: AlgorithmFamily::Rsa, content: content.to_vec() })
|
||||
}
|
||||
|
||||
/// If you are loading a ECDSA key from a .pem file
|
||||
|
@ -32,12 +40,17 @@ impl EncodingKey {
|
|||
pub fn from_ec_pem(key: &[u8]) -> Result<Self> {
|
||||
let pem_key = PemEncodedKey::new(key)?;
|
||||
let content = pem_key.as_ec_private_key()?;
|
||||
Ok(EncodingKey { content: content.to_vec() })
|
||||
Ok(EncodingKey { family: AlgorithmFamily::Ec, content: content.to_vec() })
|
||||
}
|
||||
|
||||
/// If you know what you're doing and have the DER-encoded key, for RSA or ECDSA
|
||||
pub fn from_der(der: &[u8]) -> Self {
|
||||
EncodingKey { content: der.to_vec() }
|
||||
/// If you know what you're doing and have the DER-encoded key, for RSA only
|
||||
pub fn from_rsa_der(der: &[u8]) -> Self {
|
||||
EncodingKey { family: AlgorithmFamily::Rsa, content: der.to_vec() }
|
||||
}
|
||||
|
||||
/// If you know what you're doing and have the DER-encoded key, for ECDSA
|
||||
pub fn from_ec_der(der: &[u8]) -> Self {
|
||||
EncodingKey { family: AlgorithmFamily::Ec, content: der.to_vec() }
|
||||
}
|
||||
|
||||
pub(crate) fn inner(&self) -> &[u8] {
|
||||
|
@ -68,6 +81,9 @@ impl EncodingKey {
|
|||
/// let token = encode(&Header::default(), &my_claims, &EncodingKey::from_secret("secret".as_ref())).unwrap();
|
||||
/// ```
|
||||
pub fn encode<T: Serialize>(header: &Header, claims: &T, key: &EncodingKey) -> Result<String> {
|
||||
if key.family != header.alg.family() {
|
||||
return Err(new_error(ErrorKind::InvalidAlgorithm));
|
||||
}
|
||||
let encoded_header = b64_encode_part(&header)?;
|
||||
let encoded_claims = b64_encode_part(&claims)?;
|
||||
let message = [encoded_header.as_ref(), encoded_claims.as_ref()].join(".");
|
||||
|
|
|
@ -56,7 +56,8 @@ pub enum ErrorKind {
|
|||
InvalidSubject,
|
||||
/// When a token’s nbf claim represents a time in the future
|
||||
ImmatureSignature,
|
||||
/// When the algorithm in the header doesn't match the one passed to `decode`
|
||||
/// When the algorithm in the header doesn't match the one passed to `decode` or the encoding/decoding key
|
||||
/// used doesn't match the alg requested
|
||||
InvalidAlgorithm,
|
||||
|
||||
// 3rd party errors
|
||||
|
|
|
@ -17,8 +17,11 @@ fn round_trip_sign_verification_pk8() {
|
|||
let privkey = include_bytes!("private_ecdsa_key.pk8");
|
||||
let pubkey = include_bytes!("public_ecdsa_key.pk8");
|
||||
|
||||
let encrypted = sign("hello world", &EncodingKey::from_der(privkey), Algorithm::ES256).unwrap();
|
||||
let is_valid = verify(&encrypted, "hello world", &DecodingKey::from_der(pubkey), Algorithm::ES256).unwrap();
|
||||
let encrypted =
|
||||
sign("hello world", &EncodingKey::from_ec_der(privkey), Algorithm::ES256).unwrap();
|
||||
let is_valid =
|
||||
verify(&encrypted, "hello world", &DecodingKey::from_ec_der(pubkey), Algorithm::ES256)
|
||||
.unwrap();
|
||||
assert!(is_valid);
|
||||
}
|
||||
|
||||
|
|
|
@ -102,6 +102,18 @@ fn decode_token_wrong_algorithm() {
|
|||
claims.unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic(expected = "InvalidAlgorithm")]
|
||||
fn encode_wrong_alg_family() {
|
||||
let my_claims = Claims {
|
||||
sub: "b@b.com".to_string(),
|
||||
company: "ACME".to_string(),
|
||||
exp: Utc::now().timestamp() + 10000,
|
||||
};
|
||||
let claims = encode(&Header::default(), &my_claims, &EncodingKey::from_rsa_der(b"secret"));
|
||||
claims.unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn decode_token_with_bytes_secret() {
|
||||
let token = "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJzdWIiOiJiQGIuY29tIiwiY29tcGFueSI6IkFDTUUiLCJleHAiOjI1MzI1MjQ4OTF9.Hm0yvKH25TavFPz7J_coST9lZFYH1hQo0tvhvImmaks";
|
||||
|
|
|
@ -57,11 +57,9 @@ fn round_trip_sign_verification_der() {
|
|||
let pubkey_der = include_bytes!("public_rsa_key.der");
|
||||
|
||||
for &alg in RSA_ALGORITHMS {
|
||||
let encrypted =
|
||||
sign("hello world", &EncodingKey::from_der(privkey_der), alg).unwrap();
|
||||
let encrypted = sign("hello world", &EncodingKey::from_rsa_der(privkey_der), alg).unwrap();
|
||||
let is_valid =
|
||||
verify(&encrypted, "hello world", &DecodingKey::from_der(pubkey_der), alg)
|
||||
.unwrap();
|
||||
verify(&encrypted, "hello world", &DecodingKey::from_rsa_der(pubkey_der), alg).unwrap();
|
||||
assert!(is_valid);
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue