diff --git a/src/modules/aws.rs b/src/modules/aws.rs index 7abc79715..97bbbac4a 100644 --- a/src/modules/aws.rs +++ b/src/modules/aws.rs @@ -1,6 +1,4 @@ use std::collections::HashMap; -use std::fs::File; -use std::io::{BufRead, BufReader}; use std::path::PathBuf; use std::str::FromStr; @@ -10,7 +8,7 @@ use super::{Context, Module, RootModuleConfig}; use crate::configs::aws::AwsConfig; use crate::formatter::StringFormatter; -use crate::utils::render_time; +use crate::utils::{read_file, render_time}; type Profile = String; type Region = String; @@ -40,19 +38,19 @@ fn get_config_file_path(context: &Context) -> Option { fn get_aws_region_from_config(context: &Context, aws_profile: Option<&str>) -> Option { let config_location = get_config_file_path(context)?; - let file = File::open(&config_location).ok()?; - let reader = BufReader::new(file); - let lines = reader.lines().filter_map(Result::ok); + let contents = read_file(&config_location).ok()?; let region_line = if let Some(aws_profile) = aws_profile { - lines + contents + .lines() .skip_while(|line| line != &format!("[profile {}]", &aws_profile)) .skip(1) .take_while(|line| !line.starts_with('[')) .find(|line| line.starts_with("region")) } else { - lines - .skip_while(|line| line != "[default]") + contents + .lines() + .skip_while(|&line| line != "[default]") .skip(1) .take_while(|line| !line.starts_with('[')) .find(|line| line.starts_with("region")) @@ -91,11 +89,7 @@ fn get_credentials_duration(context: &Context, aws_profile: Option<&Profile>) -> { chrono::DateTime::parse_from_rfc3339(&expiration_date).ok() } else { - let credentials_location = get_credentials_file_path(context)?; - - let file = File::open(&credentials_location).ok()?; - let reader = BufReader::new(file); - let lines = reader.lines().filter_map(Result::ok); + let contents = read_file(get_credentials_file_path(context)?).ok()?; let profile_line = if let Some(aws_profile) = aws_profile { format!("[{}]", aws_profile) @@ -103,7 +97,8 @@ fn get_credentials_duration(context: &Context, aws_profile: Option<&Profile>) -> "[default]".to_string() }; - let expiration_date_line = lines + let expiration_date_line = contents + .lines() .skip_while(|line| line != &profile_line) .skip(1) .take_while(|line| !line.starts_with('[')) @@ -182,7 +177,7 @@ pub fn module<'a>(context: &'a Context) -> Option> { mod tests { use crate::test::ModuleRenderer; use ansi_term::Color; - use std::fs::File; + use std::fs::{create_dir, File}; use std::io::{self, Write}; #[test] @@ -306,6 +301,37 @@ mod tests { assert_eq!(expected, actual); } + #[test] + fn credentials_file_is_ignored_when_is_directory() -> io::Result<()> { + let dir = tempfile::tempdir()?; + let config_path = dir.path().join("credentials"); + create_dir(&config_path)?; + + assert!(ModuleRenderer::new("aws") + .env( + "AWS_CREDENTIALS_FILE", + config_path.to_string_lossy().as_ref(), + ) + .collect() + .is_none()); + + dir.close() + } + + #[test] + fn config_file_path_is_ignored_when_is_directory() -> io::Result<()> { + let dir = tempfile::tempdir()?; + let config_path = dir.path().join("config"); + create_dir(&config_path)?; + + assert!(ModuleRenderer::new("aws") + .env("AWS_CONFIG_FILE", config_path.to_string_lossy().as_ref()) + .collect() + .is_none()); + + dir.close() + } + #[test] fn default_profile_set() -> io::Result<()> { let dir = tempfile::tempdir()?;