Skip to content

Commit

Permalink
fix calculate the average accuracy and loss in fusion_bench/fusion_be…
Browse files Browse the repository at this point in the history
…nch/taskpool/clip_vision/taskpool.py
  • Loading branch information
hetailang authored and tanganke committed Nov 1, 2024
1 parent e35649a commit bd62831
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 11 deletions.
1 change: 1 addition & 0 deletions fusion_bench/method/adamerging/task_wise_adamerging.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import lightning as L
import numpy as np
import torch
import torch.nn as nn
from omegaconf import DictConfig
from torch import Tensor
from torch.utils.data import DataLoader
Expand Down
23 changes: 12 additions & 11 deletions fusion_bench/taskpool/clip_vision/taskpool.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,17 +288,18 @@ def evaluate(self, model: Union[CLIPVisionModel, CLIPVisionTransformer]):
self.on_task_evaluation_end()

# calculate the average accuracy and loss
report["average"] = {}
accuracies = [
value["accuracy"] for key, value in report.items() if "accuracy" in value
]
if len(accuracies) > 0:
average_accuracy = sum(accuracies) / len(accuracies)
report["accuracy"] = average_accuracy
losses = [value["loss"] for key, value in report.items() if "loss" in value]
if len(losses) > 0:
average_loss = sum(losses) / len(losses)
report["loss"] = average_loss
if "average" not in report:
report["average"] = {}
accuracies = [
value["accuracy"] for key, value in report.items() if "accuracy" in value
]
if len(accuracies) > 0:
average_accuracy = sum(accuracies) / len(accuracies)
report["average"]["accuracy"] = average_accuracy
losses = [value["loss"] for key, value in report.items() if "loss" in value]
if len(losses) > 0:
average_loss = sum(losses) / len(losses)
report["average"]["loss"] = average_loss

log.info(f"Evaluation Result: {report}")
if self.fabric.is_global_zero and len(self.fabric._loggers) > 0:
Expand Down

0 comments on commit bd62831

Please sign in to comment.