Skip to content

Commit

Permalink
fix post hoc plot utils (#132)
Browse files Browse the repository at this point in the history
  • Loading branch information
gautamjajoo authored Oct 29, 2024
1 parent 6deb34c commit 09cc4de
Showing 1 changed file with 33 additions and 13 deletions.
46 changes: 33 additions & 13 deletions src/utils/post_hoc_plot_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,12 @@ def aggregate_metrics_across_users(logs_dir: str, output_dir: Optional[str] = No

# Convert to DataFrame for easier processing
df_metrics = pd.DataFrame(all_metrics)

# Select only numeric columns before calculating mean and std
numeric_columns = df_metrics.select_dtypes(include=[np.number])

# Calculate average and standard deviation
avg_metrics = df_metrics.mean()
std_metrics = df_metrics.std()
avg_metrics = numeric_columns.mean()
std_metrics = numeric_columns.std()

# Save the DataFrame with per-user metrics
df_metrics.to_csv(os.path.join(output_dir, 'per_user_metrics.csv'), index=False)
Expand Down Expand Up @@ -141,12 +143,9 @@ def plot_metric_per_round(metric_df: pd.DataFrame, rounds: np.ndarray, metric_na
for col in metric_df.columns:
plt.plot(rounds, metric_df[col], alpha=0.6, label=f'User {col+1}')

# Select only numeric columns before calculating mean and std
numeric_columns = df_metrics.select_dtypes(include=[np.number])

# Calculate average and standard deviation
avg_metrics = numeric_columns.mean()
std_metrics = numeric_columns.std()
# Compute mean and std
mean_metric = metric_df.mean(axis=1)
std_metric = metric_df.std(axis=1)

# Save the mean and std
if not os.path.exists(output_dir):
Expand Down Expand Up @@ -189,8 +188,29 @@ def plot_all_metrics(logs_dir: str, metrics_map: Optional[Dict[str, str]] = None

print("Plots saved as PNG files.")

# Use if you a specific experiment folder
# if __name__ == "__main__":
# # Define the path where your experiment logs are saved
# logs_dir = '/mas/camera/Experiments/SONAR/abhi/cifar10_36users_1250_convergence_ringm3_seed2/logs/'
# avg_metrics, std_metrics, df_metrics = aggregate_metrics_across_users(logs_dir)
# plot_all_metrics(logs_dir)


# Use if you want to compute for multiple experiment folders
if __name__ == "__main__":
# Define the path where your experiment logs are saved
logs_dir = '/u/jyuan24/sonar/src/expt_dump/1_malicious_exp/cifar10_40users_1250_data_poison_8_malicious_seed1/logs/'
avg_metrics, std_metrics, df_metrics = aggregate_metrics_across_users(logs_dir)
plot_all_metrics(logs_dir)
# Define the base directory where your experiment logs are saved
base_logs_dir = '/mas/camera/Experiments/SONAR/abhi/'

# Iterate over each subdirectory in the base directory
for experiment_folder in os.listdir(base_logs_dir):
experiment_path = os.path.join(base_logs_dir, experiment_folder)
logs_dir = os.path.join(experiment_path, 'logs')

if os.path.isdir(logs_dir):
try:
print(f"Processing logs in: {logs_dir}")
avg_metrics, std_metrics, df_metrics = aggregate_metrics_across_users(logs_dir)
plot_all_metrics(logs_dir)
except Exception as e:
print(f"Error processing {logs_dir}: {e}")
continue

0 comments on commit 09cc4de

Please sign in to comment.