diff --git a/src/utils/post_hoc_plot_utils.py b/src/utils/post_hoc_plot_utils.py index 25b6c29..3efcda3 100644 --- a/src/utils/post_hoc_plot_utils.py +++ b/src/utils/post_hoc_plot_utils.py @@ -67,10 +67,12 @@ def aggregate_metrics_across_users(logs_dir: str, output_dir: Optional[str] = No # Convert to DataFrame for easier processing df_metrics = pd.DataFrame(all_metrics) - + # Select only numeric columns before calculating mean and std + numeric_columns = df_metrics.select_dtypes(include=[np.number]) + # Calculate average and standard deviation - avg_metrics = df_metrics.mean() - std_metrics = df_metrics.std() + avg_metrics = numeric_columns.mean() + std_metrics = numeric_columns.std() # Save the DataFrame with per-user metrics df_metrics.to_csv(os.path.join(output_dir, 'per_user_metrics.csv'), index=False) @@ -141,12 +143,9 @@ def plot_metric_per_round(metric_df: pd.DataFrame, rounds: np.ndarray, metric_na for col in metric_df.columns: plt.plot(rounds, metric_df[col], alpha=0.6, label=f'User {col+1}') - # Select only numeric columns before calculating mean and std - numeric_columns = df_metrics.select_dtypes(include=[np.number]) - - # Calculate average and standard deviation - avg_metrics = numeric_columns.mean() - std_metrics = numeric_columns.std() + # Compute mean and std + mean_metric = metric_df.mean(axis=1) + std_metric = metric_df.std(axis=1) # Save the mean and std if not os.path.exists(output_dir): @@ -189,8 +188,29 @@ def plot_all_metrics(logs_dir: str, metrics_map: Optional[Dict[str, str]] = None print("Plots saved as PNG files.") +# Use if you a specific experiment folder +# if __name__ == "__main__": +# # Define the path where your experiment logs are saved +# logs_dir = '/mas/camera/Experiments/SONAR/abhi/cifar10_36users_1250_convergence_ringm3_seed2/logs/' +# avg_metrics, std_metrics, df_metrics = aggregate_metrics_across_users(logs_dir) +# plot_all_metrics(logs_dir) + + +# Use if you want to compute for multiple experiment folders if __name__ == "__main__": - # Define the path where your experiment logs are saved - logs_dir = '/u/jyuan24/sonar/src/expt_dump/1_malicious_exp/cifar10_40users_1250_data_poison_8_malicious_seed1/logs/' - avg_metrics, std_metrics, df_metrics = aggregate_metrics_across_users(logs_dir) - plot_all_metrics(logs_dir) \ No newline at end of file + # Define the base directory where your experiment logs are saved + base_logs_dir = '/mas/camera/Experiments/SONAR/abhi/' + + # Iterate over each subdirectory in the base directory + for experiment_folder in os.listdir(base_logs_dir): + experiment_path = os.path.join(base_logs_dir, experiment_folder) + logs_dir = os.path.join(experiment_path, 'logs') + + if os.path.isdir(logs_dir): + try: + print(f"Processing logs in: {logs_dir}") + avg_metrics, std_metrics, df_metrics = aggregate_metrics_across_users(logs_dir) + plot_all_metrics(logs_dir) + except Exception as e: + print(f"Error processing {logs_dir}: {e}") + continue \ No newline at end of file