Validate key type with algo in encode/decode
This commit is contained in:
parent
4dd2f12c6d
commit
689cc6d32e
|
@ -2,6 +2,13 @@ use crate::errors::{Error, ErrorKind, Result};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::str::FromStr;
|
use std::str::FromStr;
|
||||||
|
|
||||||
|
#[derive(Debug, Eq, PartialEq, Copy, Clone, Serialize, Deserialize)]
|
||||||
|
pub(crate) enum AlgorithmFamily {
|
||||||
|
Hmac,
|
||||||
|
Rsa,
|
||||||
|
Ec,
|
||||||
|
}
|
||||||
|
|
||||||
/// The algorithms supported for signing/verifying JWTs
|
/// The algorithms supported for signing/verifying JWTs
|
||||||
#[derive(Debug, PartialEq, Copy, Clone, Serialize, Deserialize)]
|
#[derive(Debug, PartialEq, Copy, Clone, Serialize, Deserialize)]
|
||||||
pub enum Algorithm {
|
pub enum Algorithm {
|
||||||
|
@ -58,6 +65,21 @@ impl FromStr for Algorithm {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl Algorithm {
|
||||||
|
pub(crate) fn family(self) -> AlgorithmFamily {
|
||||||
|
match self {
|
||||||
|
Algorithm::HS256 | Algorithm::HS384 | Algorithm::HS512 => AlgorithmFamily::Hmac,
|
||||||
|
Algorithm::RS256
|
||||||
|
| Algorithm::RS384
|
||||||
|
| Algorithm::RS512
|
||||||
|
| Algorithm::PS256
|
||||||
|
| Algorithm::PS384
|
||||||
|
| Algorithm::PS512 => AlgorithmFamily::Rsa,
|
||||||
|
Algorithm::ES256 | Algorithm::ES384 => AlgorithmFamily::Ec,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
|
@ -2,6 +2,7 @@ use std::borrow::Cow;
|
||||||
|
|
||||||
use serde::de::DeserializeOwned;
|
use serde::de::DeserializeOwned;
|
||||||
|
|
||||||
|
use crate::algorithms::AlgorithmFamily;
|
||||||
use crate::crypto::verify;
|
use crate::crypto::verify;
|
||||||
use crate::errors::{new_error, ErrorKind, Result};
|
use crate::errors::{new_error, ErrorKind, Result};
|
||||||
use crate::header::Header;
|
use crate::header::Header;
|
||||||
|
@ -40,37 +41,70 @@ pub(crate) enum DecodingKeyKind<'a> {
|
||||||
/// This key can be re-used so make sure you only initialize it once if you can for better performance
|
/// This key can be re-used so make sure you only initialize it once if you can for better performance
|
||||||
#[derive(Debug, Clone, PartialEq)]
|
#[derive(Debug, Clone, PartialEq)]
|
||||||
pub struct DecodingKey<'a> {
|
pub struct DecodingKey<'a> {
|
||||||
|
pub(crate) family: AlgorithmFamily,
|
||||||
pub(crate) kind: DecodingKeyKind<'a>,
|
pub(crate) kind: DecodingKeyKind<'a>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a> DecodingKey<'a> {
|
impl<'a> DecodingKey<'a> {
|
||||||
/// If you're using HMAC, use this.
|
/// If you're using HMAC, use this.
|
||||||
pub fn from_secret(secret: &'a [u8]) -> Self {
|
pub fn from_secret(secret: &'a [u8]) -> Self {
|
||||||
DecodingKey { kind: DecodingKeyKind::SecretOrDer(Cow::Borrowed(secret)) }
|
DecodingKey {
|
||||||
|
family: AlgorithmFamily::Hmac,
|
||||||
|
kind: DecodingKeyKind::SecretOrDer(Cow::Borrowed(secret)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// If you're using HMAC with a base64 encoded, use this.
|
||||||
|
pub fn from_base64_secret(secret: &str) -> Result<Self> {
|
||||||
|
let out = base64::decode(&secret)?;
|
||||||
|
Ok(DecodingKey {
|
||||||
|
family: AlgorithmFamily::Hmac,
|
||||||
|
kind: DecodingKeyKind::SecretOrDer(Cow::Owned(out)),
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// If you are loading a public RSA key in a PEM format, use this.
|
/// If you are loading a public RSA key in a PEM format, use this.
|
||||||
pub fn from_rsa_pem(key: &'a [u8]) -> Result<Self> {
|
pub fn from_rsa_pem(key: &'a [u8]) -> Result<Self> {
|
||||||
let pem_key = PemEncodedKey::new(key)?;
|
let pem_key = PemEncodedKey::new(key)?;
|
||||||
let content = pem_key.as_rsa_key()?;
|
let content = pem_key.as_rsa_key()?;
|
||||||
Ok(DecodingKey { kind: DecodingKeyKind::SecretOrDer(Cow::Owned(content.to_vec())) })
|
Ok(DecodingKey {
|
||||||
|
family: AlgorithmFamily::Rsa,
|
||||||
|
kind: DecodingKeyKind::SecretOrDer(Cow::Owned(content.to_vec())),
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// If you have (n, e) RSA public key components, use this.
|
/// If you have (n, e) RSA public key components, use this.
|
||||||
pub fn from_rsa_components(modulus: &'a str, exponent: &'a str) -> Self {
|
pub fn from_rsa_components(modulus: &'a str, exponent: &'a str) -> Self {
|
||||||
DecodingKey { kind: DecodingKeyKind::RsaModulusExponent { n: modulus, e: exponent } }
|
DecodingKey {
|
||||||
|
family: AlgorithmFamily::Rsa,
|
||||||
|
kind: DecodingKeyKind::RsaModulusExponent { n: modulus, e: exponent },
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// If you have a ECDSA public key in PEM format, use this.
|
/// If you have a ECDSA public key in PEM format, use this.
|
||||||
pub fn from_ec_pem(key: &'a [u8]) -> Result<Self> {
|
pub fn from_ec_pem(key: &'a [u8]) -> Result<Self> {
|
||||||
let pem_key = PemEncodedKey::new(key)?;
|
let pem_key = PemEncodedKey::new(key)?;
|
||||||
let content = pem_key.as_ec_public_key()?;
|
let content = pem_key.as_ec_public_key()?;
|
||||||
Ok(DecodingKey { kind: DecodingKeyKind::SecretOrDer(Cow::Owned(content.to_vec())) })
|
Ok(DecodingKey {
|
||||||
|
family: AlgorithmFamily::Ec,
|
||||||
|
kind: DecodingKeyKind::SecretOrDer(Cow::Owned(content.to_vec())),
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// If you know what you're doing and have the DER encoded public key, use this.
|
/// If you know what you're doing and have a RSA DER encoded public key, use this.
|
||||||
pub fn from_der(der: &'a [u8]) -> Self {
|
pub fn from_rsa_der(der: &'a [u8]) -> Self {
|
||||||
DecodingKey { kind: DecodingKeyKind::SecretOrDer(Cow::Borrowed(der)) }
|
DecodingKey {
|
||||||
|
family: AlgorithmFamily::Rsa,
|
||||||
|
kind: DecodingKeyKind::SecretOrDer(Cow::Borrowed(der)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// If you know what you're doing and have a RSA EC encoded public key, use this.
|
||||||
|
pub fn from_ec_der(der: &'a [u8]) -> Self {
|
||||||
|
DecodingKey {
|
||||||
|
family: AlgorithmFamily::Ec,
|
||||||
|
kind: DecodingKeyKind::SecretOrDer(Cow::Borrowed(der)),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn as_bytes(&self) -> &[u8] {
|
pub(crate) fn as_bytes(&self) -> &[u8] {
|
||||||
|
@ -104,6 +138,12 @@ pub fn decode<T: DeserializeOwned>(
|
||||||
key: &DecodingKey,
|
key: &DecodingKey,
|
||||||
validation: &Validation,
|
validation: &Validation,
|
||||||
) -> Result<TokenData<T>> {
|
) -> Result<TokenData<T>> {
|
||||||
|
for alg in &validation.algorithms {
|
||||||
|
if key.family != alg.family() {
|
||||||
|
return Err(new_error(ErrorKind::InvalidAlgorithm));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
let (signature, message) = expect_two!(token.rsplitn(2, '.'));
|
let (signature, message) = expect_two!(token.rsplitn(2, '.'));
|
||||||
let (claims, header) = expect_two!(message.rsplitn(2, '.'));
|
let (claims, header) = expect_two!(message.rsplitn(2, '.'));
|
||||||
let header = Header::from_encoded(header)?;
|
let header = Header::from_encoded(header)?;
|
||||||
|
|
|
@ -1,7 +1,8 @@
|
||||||
use serde::ser::Serialize;
|
use serde::ser::Serialize;
|
||||||
|
|
||||||
|
use crate::algorithms::AlgorithmFamily;
|
||||||
use crate::crypto;
|
use crate::crypto;
|
||||||
use crate::errors::Result;
|
use crate::errors::{new_error, ErrorKind, Result};
|
||||||
use crate::header::Header;
|
use crate::header::Header;
|
||||||
use crate::pem::decoder::PemEncodedKey;
|
use crate::pem::decoder::PemEncodedKey;
|
||||||
use crate::serialization::b64_encode_part;
|
use crate::serialization::b64_encode_part;
|
||||||
|
@ -10,13 +11,20 @@ use crate::serialization::b64_encode_part;
|
||||||
/// This key can be re-used so make sure you only initialize it once if you can for better performance
|
/// This key can be re-used so make sure you only initialize it once if you can for better performance
|
||||||
#[derive(Debug, Clone, PartialEq)]
|
#[derive(Debug, Clone, PartialEq)]
|
||||||
pub struct EncodingKey {
|
pub struct EncodingKey {
|
||||||
|
pub(crate) family: AlgorithmFamily,
|
||||||
content: Vec<u8>,
|
content: Vec<u8>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl EncodingKey {
|
impl EncodingKey {
|
||||||
/// If you're using HMAC, use that.
|
/// If you're using a HMAC secret that is not base64, use that.
|
||||||
pub fn from_secret(secret: &[u8]) -> Self {
|
pub fn from_secret(secret: &[u8]) -> Self {
|
||||||
EncodingKey { content: secret.to_vec() }
|
EncodingKey { family: AlgorithmFamily::Hmac, content: secret.to_vec() }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// If you have a base64 HMAC secret, use that.
|
||||||
|
pub fn from_base64_secret(secret: &str) -> Result<Self> {
|
||||||
|
let out = base64::decode(&secret)?;
|
||||||
|
Ok(EncodingKey { family: AlgorithmFamily::Hmac, content: out })
|
||||||
}
|
}
|
||||||
|
|
||||||
/// If you are loading a RSA key from a .pem file.
|
/// If you are loading a RSA key from a .pem file.
|
||||||
|
@ -24,7 +32,7 @@ impl EncodingKey {
|
||||||
pub fn from_rsa_pem(key: &[u8]) -> Result<Self> {
|
pub fn from_rsa_pem(key: &[u8]) -> Result<Self> {
|
||||||
let pem_key = PemEncodedKey::new(key)?;
|
let pem_key = PemEncodedKey::new(key)?;
|
||||||
let content = pem_key.as_rsa_key()?;
|
let content = pem_key.as_rsa_key()?;
|
||||||
Ok(EncodingKey { content: content.to_vec() })
|
Ok(EncodingKey { family: AlgorithmFamily::Rsa, content: content.to_vec() })
|
||||||
}
|
}
|
||||||
|
|
||||||
/// If you are loading a ECDSA key from a .pem file
|
/// If you are loading a ECDSA key from a .pem file
|
||||||
|
@ -32,12 +40,17 @@ impl EncodingKey {
|
||||||
pub fn from_ec_pem(key: &[u8]) -> Result<Self> {
|
pub fn from_ec_pem(key: &[u8]) -> Result<Self> {
|
||||||
let pem_key = PemEncodedKey::new(key)?;
|
let pem_key = PemEncodedKey::new(key)?;
|
||||||
let content = pem_key.as_ec_private_key()?;
|
let content = pem_key.as_ec_private_key()?;
|
||||||
Ok(EncodingKey { content: content.to_vec() })
|
Ok(EncodingKey { family: AlgorithmFamily::Ec, content: content.to_vec() })
|
||||||
}
|
}
|
||||||
|
|
||||||
/// If you know what you're doing and have the DER-encoded key, for RSA or ECDSA
|
/// If you know what you're doing and have the DER-encoded key, for RSA only
|
||||||
pub fn from_der(der: &[u8]) -> Self {
|
pub fn from_rsa_der(der: &[u8]) -> Self {
|
||||||
EncodingKey { content: der.to_vec() }
|
EncodingKey { family: AlgorithmFamily::Rsa, content: der.to_vec() }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// If you know what you're doing and have the DER-encoded key, for ECDSA
|
||||||
|
pub fn from_ec_der(der: &[u8]) -> Self {
|
||||||
|
EncodingKey { family: AlgorithmFamily::Ec, content: der.to_vec() }
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn inner(&self) -> &[u8] {
|
pub(crate) fn inner(&self) -> &[u8] {
|
||||||
|
@ -68,6 +81,9 @@ impl EncodingKey {
|
||||||
/// let token = encode(&Header::default(), &my_claims, &EncodingKey::from_secret("secret".as_ref())).unwrap();
|
/// let token = encode(&Header::default(), &my_claims, &EncodingKey::from_secret("secret".as_ref())).unwrap();
|
||||||
/// ```
|
/// ```
|
||||||
pub fn encode<T: Serialize>(header: &Header, claims: &T, key: &EncodingKey) -> Result<String> {
|
pub fn encode<T: Serialize>(header: &Header, claims: &T, key: &EncodingKey) -> Result<String> {
|
||||||
|
if key.family != header.alg.family() {
|
||||||
|
return Err(new_error(ErrorKind::InvalidAlgorithm));
|
||||||
|
}
|
||||||
let encoded_header = b64_encode_part(&header)?;
|
let encoded_header = b64_encode_part(&header)?;
|
||||||
let encoded_claims = b64_encode_part(&claims)?;
|
let encoded_claims = b64_encode_part(&claims)?;
|
||||||
let message = [encoded_header.as_ref(), encoded_claims.as_ref()].join(".");
|
let message = [encoded_header.as_ref(), encoded_claims.as_ref()].join(".");
|
||||||
|
|
|
@ -56,7 +56,8 @@ pub enum ErrorKind {
|
||||||
InvalidSubject,
|
InvalidSubject,
|
||||||
/// When a token’s nbf claim represents a time in the future
|
/// When a token’s nbf claim represents a time in the future
|
||||||
ImmatureSignature,
|
ImmatureSignature,
|
||||||
/// When the algorithm in the header doesn't match the one passed to `decode`
|
/// When the algorithm in the header doesn't match the one passed to `decode` or the encoding/decoding key
|
||||||
|
/// used doesn't match the alg requested
|
||||||
InvalidAlgorithm,
|
InvalidAlgorithm,
|
||||||
|
|
||||||
// 3rd party errors
|
// 3rd party errors
|
||||||
|
|
|
@ -17,8 +17,11 @@ fn round_trip_sign_verification_pk8() {
|
||||||
let privkey = include_bytes!("private_ecdsa_key.pk8");
|
let privkey = include_bytes!("private_ecdsa_key.pk8");
|
||||||
let pubkey = include_bytes!("public_ecdsa_key.pk8");
|
let pubkey = include_bytes!("public_ecdsa_key.pk8");
|
||||||
|
|
||||||
let encrypted = sign("hello world", &EncodingKey::from_der(privkey), Algorithm::ES256).unwrap();
|
let encrypted =
|
||||||
let is_valid = verify(&encrypted, "hello world", &DecodingKey::from_der(pubkey), Algorithm::ES256).unwrap();
|
sign("hello world", &EncodingKey::from_ec_der(privkey), Algorithm::ES256).unwrap();
|
||||||
|
let is_valid =
|
||||||
|
verify(&encrypted, "hello world", &DecodingKey::from_ec_der(pubkey), Algorithm::ES256)
|
||||||
|
.unwrap();
|
||||||
assert!(is_valid);
|
assert!(is_valid);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -102,6 +102,18 @@ fn decode_token_wrong_algorithm() {
|
||||||
claims.unwrap();
|
claims.unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
#[should_panic(expected = "InvalidAlgorithm")]
|
||||||
|
fn encode_wrong_alg_family() {
|
||||||
|
let my_claims = Claims {
|
||||||
|
sub: "b@b.com".to_string(),
|
||||||
|
company: "ACME".to_string(),
|
||||||
|
exp: Utc::now().timestamp() + 10000,
|
||||||
|
};
|
||||||
|
let claims = encode(&Header::default(), &my_claims, &EncodingKey::from_rsa_der(b"secret"));
|
||||||
|
claims.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn decode_token_with_bytes_secret() {
|
fn decode_token_with_bytes_secret() {
|
||||||
let token = "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJzdWIiOiJiQGIuY29tIiwiY29tcGFueSI6IkFDTUUiLCJleHAiOjI1MzI1MjQ4OTF9.Hm0yvKH25TavFPz7J_coST9lZFYH1hQo0tvhvImmaks";
|
let token = "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJzdWIiOiJiQGIuY29tIiwiY29tcGFueSI6IkFDTUUiLCJleHAiOjI1MzI1MjQ4OTF9.Hm0yvKH25TavFPz7J_coST9lZFYH1hQo0tvhvImmaks";
|
||||||
|
|
|
@ -57,11 +57,9 @@ fn round_trip_sign_verification_der() {
|
||||||
let pubkey_der = include_bytes!("public_rsa_key.der");
|
let pubkey_der = include_bytes!("public_rsa_key.der");
|
||||||
|
|
||||||
for &alg in RSA_ALGORITHMS {
|
for &alg in RSA_ALGORITHMS {
|
||||||
let encrypted =
|
let encrypted = sign("hello world", &EncodingKey::from_rsa_der(privkey_der), alg).unwrap();
|
||||||
sign("hello world", &EncodingKey::from_der(privkey_der), alg).unwrap();
|
|
||||||
let is_valid =
|
let is_valid =
|
||||||
verify(&encrypted, "hello world", &DecodingKey::from_der(pubkey_der), alg)
|
verify(&encrypted, "hello world", &DecodingKey::from_rsa_der(pubkey_der), alg).unwrap();
|
||||||
.unwrap();
|
|
||||||
assert!(is_valid);
|
assert!(is_valid);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue