Skip to content

Commit

Permalink
fix: working predict.py for a single model
Browse files Browse the repository at this point in the history
  • Loading branch information
CompRhys committed Aug 5, 2024
1 parent 71f1c62 commit e467338
Showing 1 changed file with 9 additions and 8 deletions.
17 changes: 9 additions & 8 deletions aviary/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,9 @@ def make_ensemble_predictions(
else:
df[pred_col] = preds

if len(checkpoint_paths) > 1:
df_preds = df.filter(regex=r"_pred_\d")
df_preds = df.filter(regex=r"_pred_\d")

if len(checkpoint_paths) > 1:
pred_ens_col = f"{target_col}_pred_ens" if target_col else "pred_ens"
df[pred_ens_col] = ensemble_preds = df_preds.mean(axis=1)

Expand All @@ -134,7 +134,7 @@ def make_ensemble_predictions(
).mean(axis=1)
df[pred_tot_std_ens] = (epistemic_std**2 + aleatoric_std**2) ** 0.5

if target_col:
if target_col is not None:
targets = df[target_col]
all_model_metrics = [
get_metrics(targets, df_preds[col], task_type) for col in df_preds
Expand All @@ -145,11 +145,12 @@ def make_ensemble_predictions(
print("\nSingle model performance:")
print(df_metrics.describe().round(4).loc[["mean", "std"]])

ensemble_metrics = get_metrics(targets, ensemble_preds, task_type)
if len(checkpoint_paths) > 1:
ensemble_metrics = get_metrics(targets, ensemble_preds, task_type)

print("\nEnsemble performance:")
for key, val in ensemble_metrics.items():
print(f"{key:<8} {val:.3}")
print("\nEnsemble performance:")
for key, val in ensemble_metrics.items():
print(f"{key:<8} {val:.3}")
return df, df_metrics

return df
Expand Down Expand Up @@ -208,7 +209,7 @@ def predict_from_wandb_checkpoints(
if not os.path.isfile(checkpoint_path):
run.file(f"{checkpoint_filename}").download(root=out_dir)

if target_col in kwargs:
if target_col is not None:
df, ensemble_metrics = make_ensemble_predictions(checkpoint_paths, **kwargs)
# round to save disk space and speed up cloud storage uploads
return df.round(6), ensemble_metrics
Expand Down

0 comments on commit e467338

Please sign in to comment.