formatted

This commit is contained in:
bahdotsh
2025-08-14 23:30:26 +05:30
parent 250a88ba94
commit db1d4bcf48
10 changed files with 106 additions and 60 deletions

View File

@@ -14,9 +14,7 @@ fn bench_basic_masking(c: &mut Criterion) {
let text = "The password is password123 and the API key is api_key_abcdef123456. Also super_secret_value_that_should_be_masked is here."; let text = "The password is password123 and the API key is api_key_abcdef123456. Also super_secret_value_that_should_be_masked is here.";
c.bench_function("basic_masking", |b| { c.bench_function("basic_masking", |b| b.iter(|| masker.mask(black_box(text))));
b.iter(|| masker.mask(black_box(text)))
});
} }
fn bench_pattern_masking(c: &mut Criterion) { fn bench_pattern_masking(c: &mut Criterion) {
@@ -50,7 +48,7 @@ fn bench_large_text_masking(c: &mut Criterion) {
fn bench_many_secrets(c: &mut Criterion) { fn bench_many_secrets(c: &mut Criterion) {
let mut masker = SecretMasker::new(); let mut masker = SecretMasker::new();
// Add many secrets // Add many secrets
for i in 0..100 { for i in 0..100 {
masker.add_secret(format!("secret_{}", i)); masker.add_secret(format!("secret_{}", i));
@@ -58,9 +56,7 @@ fn bench_many_secrets(c: &mut Criterion) {
let text = "This text contains secret_50 and secret_75 but not others."; let text = "This text contains secret_50 and secret_75 but not others.";
c.bench_function("many_secrets", |b| { c.bench_function("many_secrets", |b| b.iter(|| masker.mask(black_box(text))));
b.iter(|| masker.mask(black_box(text)))
});
} }
fn bench_contains_secrets(c: &mut Criterion) { fn bench_contains_secrets(c: &mut Criterion) {

View File

@@ -123,7 +123,7 @@
//! //!
//! ```rust //! ```rust
//! use wrkflw_secrets::{SecretProviderConfig, SecretManager, SecretConfig}; //! use wrkflw_secrets::{SecretProviderConfig, SecretManager, SecretConfig};
//! //!
//! // With prefix for better security //! // With prefix for better security
//! let provider = SecretProviderConfig::Environment { //! let provider = SecretProviderConfig::Environment {
//! prefix: Some("MYAPP_".to_string()) //! prefix: Some("MYAPP_".to_string())
@@ -190,7 +190,10 @@ mod tests {
.expect("Failed to create manager"); .expect("Failed to create manager");
// Use a unique test secret name to avoid conflicts // Use a unique test secret name to avoid conflicts
let test_secret_name = format!("TEST_SECRET_{}", uuid::Uuid::new_v4().to_string().replace('-', "_")); let test_secret_name = format!(
"TEST_SECRET_{}",
uuid::Uuid::new_v4().to_string().replace('-', "_")
);
std::env::set_var(&test_secret_name, "secret_value"); std::env::set_var(&test_secret_name, "secret_value");
let result = manager.get_secret(&test_secret_name).await; let result = manager.get_secret(&test_secret_name).await;
@@ -210,7 +213,10 @@ mod tests {
.expect("Failed to create manager"); .expect("Failed to create manager");
// Use a unique test secret name to avoid conflicts // Use a unique test secret name to avoid conflicts
let test_secret_name = format!("GITHUB_TOKEN_{}", uuid::Uuid::new_v4().to_string().replace('-', "_")); let test_secret_name = format!(
"GITHUB_TOKEN_{}",
uuid::Uuid::new_v4().to_string().replace('-', "_")
);
std::env::set_var(&test_secret_name, "ghp_test_token"); std::env::set_var(&test_secret_name, "ghp_test_token");
let mut substitution = SecretSubstitution::new(&manager); let mut substitution = SecretSubstitution::new(&manager);

View File

@@ -33,7 +33,7 @@ impl SecretManager {
for (name, provider_config) in &config.providers { for (name, provider_config) in &config.providers {
// Validate provider name // Validate provider name
validate_provider_name(name)?; validate_provider_name(name)?;
let provider: Box<dyn SecretProvider> = match provider_config { let provider: Box<dyn SecretProvider> = match provider_config {
SecretProviderConfig::Environment { prefix } => { SecretProviderConfig::Environment { prefix } => {
Box::new(EnvironmentProvider::new(prefix.clone())) Box::new(EnvironmentProvider::new(prefix.clone()))
@@ -54,7 +54,7 @@ impl SecretManager {
} }
let rate_limiter = RateLimiter::new(config.rate_limit.clone()); let rate_limiter = RateLimiter::new(config.rate_limit.clone());
Ok(Self { Ok(Self {
config, config,
providers, providers,

View File

@@ -137,16 +137,28 @@ impl SecretMasker {
let mut result = text.to_string(); let mut result = text.to_string();
// GitHub Personal Access Tokens // GitHub Personal Access Tokens
result = patterns.github_pat.replace_all(&result, "ghp_***").to_string(); result = patterns
.github_pat
.replace_all(&result, "ghp_***")
.to_string();
// GitHub App tokens // GitHub App tokens
result = patterns.github_app.replace_all(&result, "ghs_***").to_string(); result = patterns
.github_app
.replace_all(&result, "ghs_***")
.to_string();
// GitHub OAuth tokens // GitHub OAuth tokens
result = patterns.github_oauth.replace_all(&result, "gho_***").to_string(); result = patterns
.github_oauth
.replace_all(&result, "gho_***")
.to_string();
// AWS Access Key IDs // AWS Access Key IDs
result = patterns.aws_access_key.replace_all(&result, "AKIA***").to_string(); result = patterns
.aws_access_key
.replace_all(&result, "AKIA***")
.to_string();
// AWS Secret Access Keys (basic pattern) // AWS Secret Access Keys (basic pattern)
// Only mask if it's clearly in a secret context (basic heuristic) // Only mask if it's clearly in a secret context (basic heuristic)
@@ -155,10 +167,16 @@ impl SecretMasker {
} }
// JWT tokens (basic pattern) // JWT tokens (basic pattern)
result = patterns.jwt.replace_all(&result, "eyJ***.eyJ***.***").to_string(); result = patterns
.jwt
.replace_all(&result, "eyJ***.eyJ***.***")
.to_string();
// API keys with common prefixes // API keys with common prefixes
result = patterns.api_key.replace_all(&result, "${1}=***").to_string(); result = patterns
.api_key
.replace_all(&result, "${1}=***")
.to_string();
result result
} }
@@ -178,12 +196,12 @@ impl SecretMasker {
/// Check if text contains common secret patterns /// Check if text contains common secret patterns
fn has_secret_patterns(&self, text: &str) -> bool { fn has_secret_patterns(&self, text: &str) -> bool {
let patterns = PATTERNS.get_or_init(CompiledPatterns::new); let patterns = PATTERNS.get_or_init(CompiledPatterns::new);
patterns.github_pat.is_match(text) || patterns.github_pat.is_match(text)
patterns.github_app.is_match(text) || || patterns.github_app.is_match(text)
patterns.github_oauth.is_match(text) || || patterns.github_oauth.is_match(text)
patterns.aws_access_key.is_match(text) || || patterns.aws_access_key.is_match(text)
patterns.jwt.is_match(text) || patterns.jwt.is_match(text)
} }
/// Get the number of secrets being tracked /// Get the number of secrets being tracked

View File

@@ -1,4 +1,6 @@
use crate::{validation::validate_secret_value, SecretError, SecretProvider, SecretResult, SecretValue}; use crate::{
validation::validate_secret_value, SecretError, SecretProvider, SecretResult, SecretValue,
};
use async_trait::async_trait; use async_trait::async_trait;
use std::collections::HashMap; use std::collections::HashMap;
@@ -21,7 +23,6 @@ impl Default for EnvironmentProvider {
} }
impl EnvironmentProvider { impl EnvironmentProvider {
/// Get the full environment variable name /// Get the full environment variable name
fn get_env_name(&self, name: &str) -> String { fn get_env_name(&self, name: &str) -> String {
match &self.prefix { match &self.prefix {
@@ -40,7 +41,7 @@ impl SecretProvider for EnvironmentProvider {
Ok(value) => { Ok(value) => {
// Validate the secret value // Validate the secret value
validate_secret_value(&value)?; validate_secret_value(&value)?;
let mut metadata = HashMap::new(); let mut metadata = HashMap::new();
metadata.insert("source".to_string(), "environment".to_string()); metadata.insert("source".to_string(), "environment".to_string());
metadata.insert("env_var".to_string(), env_name); metadata.insert("env_var".to_string(), env_name);

View File

@@ -1,4 +1,6 @@
use crate::{validation::validate_secret_value, SecretError, SecretProvider, SecretResult, SecretValue}; use crate::{
validation::validate_secret_value, SecretError, SecretProvider, SecretResult, SecretValue,
};
use async_trait::async_trait; use async_trait::async_trait;
use serde_json::Value; use serde_json::Value;
use std::collections::HashMap; use std::collections::HashMap;
@@ -154,7 +156,7 @@ impl SecretProvider for FileProvider {
if let Some(value) = secrets.get(name) { if let Some(value) = secrets.get(name) {
// Validate the secret value // Validate the secret value
validate_secret_value(value)?; validate_secret_value(value)?;
let mut metadata = HashMap::new(); let mut metadata = HashMap::new();
metadata.insert("source".to_string(), "file".to_string()); metadata.insert("source".to_string(), "file".to_string());
metadata.insert("file_path".to_string(), self.expand_path()); metadata.insert("file_path".to_string(), self.expand_path());

View File

@@ -56,7 +56,7 @@ impl RequestTracker {
fn cleanup_old_requests(&mut self, window_duration: Duration, now: Instant) { fn cleanup_old_requests(&mut self, window_duration: Duration, now: Instant) {
let cutoff = now - window_duration; let cutoff = now - window_duration;
self.requests.retain(|&req_time| req_time > cutoff); self.requests.retain(|&req_time| req_time > cutoff);
if let Some(&first) = self.requests.first() { if let Some(&first) = self.requests.first() {
self.first_request = first; self.first_request = first;
} }
@@ -82,8 +82,6 @@ impl RateLimiter {
} }
} }
/// Check if a request should be allowed for the given key /// Check if a request should be allowed for the given key
pub async fn check_rate_limit(&self, key: &str) -> SecretResult<()> { pub async fn check_rate_limit(&self, key: &str) -> SecretResult<()> {
if !self.config.enabled { if !self.config.enabled {
@@ -92,11 +90,11 @@ impl RateLimiter {
let now = Instant::now(); let now = Instant::now();
let mut trackers = self.trackers.write().await; let mut trackers = self.trackers.write().await;
// Clean up old requests for existing tracker // Clean up old requests for existing tracker
if let Some(tracker) = trackers.get_mut(key) { if let Some(tracker) = trackers.get_mut(key) {
tracker.cleanup_old_requests(self.config.window_duration, now); tracker.cleanup_old_requests(self.config.window_duration, now);
// Check if we're over the limit // Check if we're over the limit
if tracker.request_count() >= self.config.max_requests as usize { if tracker.request_count() >= self.config.max_requests as usize {
let time_until_reset = self.config.window_duration - (now - tracker.first_request); let time_until_reset = self.config.window_duration - (now - tracker.first_request);
@@ -105,7 +103,7 @@ impl RateLimiter {
time_until_reset.as_secs() time_until_reset.as_secs()
))); )));
} }
// Add the current request // Add the current request
tracker.add_request(now); tracker.add_request(now);
} else { } else {
@@ -114,7 +112,7 @@ impl RateLimiter {
tracker.add_request(now); tracker.add_request(now);
trackers.insert(key.to_string(), tracker); trackers.insert(key.to_string(), tracker);
} }
Ok(()) Ok(())
} }

View File

@@ -169,14 +169,17 @@ mod tests {
// Use unique secret names to avoid test conflicts // Use unique secret names to avoid test conflicts
let github_token_name = format!("GITHUB_TOKEN_{}", std::process::id()); let github_token_name = format!("GITHUB_TOKEN_{}", std::process::id());
let api_key_name = format!("API_KEY_{}", std::process::id()); let api_key_name = format!("API_KEY_{}", std::process::id());
std::env::set_var(&github_token_name, "ghp_test_token"); std::env::set_var(&github_token_name, "ghp_test_token");
std::env::set_var(&api_key_name, "secret_api_key"); std::env::set_var(&api_key_name, "secret_api_key");
let manager = SecretManager::default().await.unwrap(); let manager = SecretManager::default().await.unwrap();
let mut substitution = SecretSubstitution::new(&manager); let mut substitution = SecretSubstitution::new(&manager);
let input = format!("Token: ${{{{ secrets.{} }}}}, API: ${{{{ secrets.{} }}}}", github_token_name, api_key_name); let input = format!(
"Token: ${{{{ secrets.{} }}}}, API: ${{{{ secrets.{} }}}}",
github_token_name, api_key_name
);
let result = substitution.substitute(&input).await.unwrap(); let result = substitution.substitute(&input).await.unwrap();
assert_eq!(result, "Token: ghp_test_token, API: secret_api_key"); assert_eq!(result, "Token: ghp_test_token, API: secret_api_key");

View File

@@ -37,7 +37,8 @@ pub fn validate_secret_name(name: &str) -> SecretResult<()> {
if !SECRET_NAME_PATTERN.is_match(name) { if !SECRET_NAME_PATTERN.is_match(name) {
return Err(SecretError::InvalidSecretName { return Err(SecretError::InvalidSecretName {
reason: "Secret name can only contain letters, numbers, underscores, hyphens, and dots".to_string(), reason: "Secret name can only contain letters, numbers, underscores, hyphens, and dots"
.to_string(),
}); });
} }
@@ -56,8 +57,8 @@ pub fn validate_secret_name(name: &str) -> SecretResult<()> {
// Reserved names // Reserved names
let reserved_names = [ let reserved_names = [
"CON", "PRN", "AUX", "NUL", "COM1", "COM2", "COM3", "COM4", "COM5", "COM6", "COM7", "COM8", "COM9", "CON", "PRN", "AUX", "NUL", "COM1", "COM2", "COM3", "COM4", "COM5", "COM6", "COM7", "COM8",
"LPT1", "LPT2", "LPT3", "LPT4", "LPT5", "LPT6", "LPT7", "LPT8", "LPT9", "COM9", "LPT1", "LPT2", "LPT3", "LPT4", "LPT5", "LPT6", "LPT7", "LPT8", "LPT9",
]; ];
if reserved_names.contains(&name.to_uppercase().as_str()) { if reserved_names.contains(&name.to_uppercase().as_str()) {
@@ -72,7 +73,7 @@ pub fn validate_secret_name(name: &str) -> SecretResult<()> {
/// Validate a secret value /// Validate a secret value
pub fn validate_secret_value(value: &str) -> SecretResult<()> { pub fn validate_secret_value(value: &str) -> SecretResult<()> {
let size = value.len(); let size = value.len();
if size > MAX_SECRET_SIZE { if size > MAX_SECRET_SIZE {
return Err(SecretError::SecretTooLarge { return Err(SecretError::SecretTooLarge {
size, size,
@@ -99,12 +100,16 @@ pub fn validate_provider_name(name: &str) -> SecretResult<()> {
} }
if name.len() > 64 { if name.len() > 64 {
return Err(SecretError::InvalidConfig( return Err(SecretError::InvalidConfig(format!(
format!("Provider name too long: {} characters (max: 64)", name.len()), "Provider name too long: {} characters (max: 64)",
)); name.len()
)));
} }
if !name.chars().all(|c| c.is_alphanumeric() || c == '_' || c == '-') { if !name
.chars()
.all(|c| c.is_alphanumeric() || c == '_' || c == '-')
{
return Err(SecretError::InvalidConfig( return Err(SecretError::InvalidConfig(
"Provider name can only contain letters, numbers, underscores, and hyphens".to_string(), "Provider name can only contain letters, numbers, underscores, and hyphens".to_string(),
)); ));
@@ -134,19 +139,19 @@ pub fn looks_like_secret(value: &str) -> bool {
// Check for high entropy (random-looking strings) // Check for high entropy (random-looking strings)
let unique_chars: std::collections::HashSet<char> = value.chars().collect(); let unique_chars: std::collections::HashSet<char> = value.chars().collect();
let entropy_ratio = unique_chars.len() as f64 / value.len() as f64; let entropy_ratio = unique_chars.len() as f64 / value.len() as f64;
if entropy_ratio > 0.6 && value.len() > 16 { if entropy_ratio > 0.6 && value.len() > 16 {
return true; return true;
} }
// Check for common secret patterns // Check for common secret patterns
let secret_patterns = [ let secret_patterns = [
r"^[A-Za-z0-9+/=]{40,}$", // Base64-like r"^[A-Za-z0-9+/=]{40,}$", // Base64-like
r"^[a-fA-F0-9]{32,}$", // Hex strings r"^[a-fA-F0-9]{32,}$", // Hex strings
r"^[A-Z0-9]{20,}$", // All caps alphanumeric r"^[A-Z0-9]{20,}$", // All caps alphanumeric
r"^sk_[a-zA-Z0-9_-]+$", // Stripe-like keys r"^sk_[a-zA-Z0-9_-]+$", // Stripe-like keys
r"^pk_[a-zA-Z0-9_-]+$", // Public keys r"^pk_[a-zA-Z0-9_-]+$", // Public keys
r"^rk_[a-zA-Z0-9_-]+$", // Restricted keys r"^rk_[a-zA-Z0-9_-]+$", // Restricted keys
]; ];
for pattern in &secret_patterns { for pattern in &secret_patterns {
@@ -224,7 +229,9 @@ mod tests {
assert!(looks_like_secret("sk_test_abcdefghijklmnop1234567890")); assert!(looks_like_secret("sk_test_abcdefghijklmnop1234567890"));
assert!(looks_like_secret("abcdefghijklmnopqrstuvwxyz123456")); assert!(looks_like_secret("abcdefghijklmnopqrstuvwxyz123456"));
assert!(looks_like_secret("ABCDEF1234567890ABCDEF1234567890")); assert!(looks_like_secret("ABCDEF1234567890ABCDEF1234567890"));
assert!(looks_like_secret("YWJjZGVmZ2hpamtsbW5vcHFyc3R1dnd4eXoxMjM0NTY3ODkw")); assert!(looks_like_secret(
"YWJjZGVmZ2hpamtsbW5vcHFyc3R1dnd4eXoxMjM0NTY3ODkw"
));
// Should not detect as secrets // Should not detect as secrets
assert!(!looks_like_secret("short")); assert!(!looks_like_secret("short"));

View File

@@ -60,8 +60,14 @@ async fn test_end_to_end_secret_workflow() {
// Test 1: Get secret from environment provider // Test 1: Get secret from environment provider
let env_secret = manager.get_secret(&env_secret_name).await.unwrap(); let env_secret = manager.get_secret(&env_secret_name).await.unwrap();
assert_eq!(env_secret.value(), "ghp_1234567890abcdefghijklmnopqrstuvwxyz"); assert_eq!(
assert_eq!(env_secret.metadata.get("source"), Some(&"environment".to_string())); env_secret.value(),
"ghp_1234567890abcdefghijklmnopqrstuvwxyz"
);
assert_eq!(
env_secret.metadata.get("source"),
Some(&"environment".to_string())
);
// Test 2: Get secret from file provider // Test 2: Get secret from file provider
let file_secret = manager let file_secret = manager
@@ -69,7 +75,10 @@ async fn test_end_to_end_secret_workflow() {
.await .await
.unwrap(); .unwrap();
assert_eq!(file_secret.value(), "super_secret_db_pass_123"); assert_eq!(file_secret.value(), "super_secret_db_pass_123");
assert_eq!(file_secret.metadata.get("source"), Some(&"file".to_string())); assert_eq!(
file_secret.metadata.get("source"),
Some(&"file".to_string())
);
// Test 3: List secrets from file provider // Test 3: List secrets from file provider
let all_secrets = manager.list_all_secrets().await.unwrap(); let all_secrets = manager.list_all_secrets().await.unwrap();
@@ -152,8 +161,8 @@ async fn test_error_handling() {
/// Test rate limiting functionality /// Test rate limiting functionality
#[tokio::test] #[tokio::test]
async fn test_rate_limiting() { async fn test_rate_limiting() {
use wrkflw_secrets::rate_limit::RateLimitConfig;
use std::time::Duration; use std::time::Duration;
use wrkflw_secrets::rate_limit::RateLimitConfig;
// Create config with very low rate limit // Create config with very low rate limit
let mut config = SecretConfig::default(); let mut config = SecretConfig::default();
@@ -179,7 +188,10 @@ async fn test_rate_limiting() {
// Third request should fail due to rate limiting // Third request should fail due to rate limiting
let result3 = manager.get_secret(&test_secret_name).await; let result3 = manager.get_secret(&test_secret_name).await;
assert!(result3.is_err()); assert!(result3.is_err());
assert!(result3.unwrap_err().to_string().contains("Rate limit exceeded")); assert!(result3
.unwrap_err()
.to_string()
.contains("Rate limit exceeded"));
// Cleanup // Cleanup
std::env::remove_var(&test_secret_name); std::env::remove_var(&test_secret_name);
@@ -308,7 +320,10 @@ async fn test_comprehensive_masking() {
for pattern in should_not_contain { for pattern in should_not_contain {
if pattern != "***" { if pattern != "***" {
assert!( assert!(
!masked.contains(pattern) || pattern == "ghp_" || pattern == "AKIA" || pattern == "eyJ", !masked.contains(pattern)
|| pattern == "ghp_"
|| pattern == "AKIA"
|| pattern == "eyJ",
"Masked text '{}' should not contain '{}' (or only partial patterns)", "Masked text '{}' should not contain '{}' (or only partial patterns)",
masked, masked,
pattern pattern