From bd62831eb0fa9d1efcda6382af74df6e811daebc Mon Sep 17 00:00:00 2001 From: hetailang <1095251224@qq.com> Date: Fri, 1 Nov 2024 20:54:42 +0800 Subject: [PATCH] fix calculate the average accuracy and loss in fusion_bench/fusion_bench/taskpool/clip_vision/taskpool.py --- .../method/adamerging/task_wise_adamerging.py | 1 + fusion_bench/taskpool/clip_vision/taskpool.py | 23 ++++++++++--------- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/fusion_bench/method/adamerging/task_wise_adamerging.py b/fusion_bench/method/adamerging/task_wise_adamerging.py index 7b2867d6..85696f35 100644 --- a/fusion_bench/method/adamerging/task_wise_adamerging.py +++ b/fusion_bench/method/adamerging/task_wise_adamerging.py @@ -5,6 +5,7 @@ import lightning as L import numpy as np import torch +import torch.nn as nn from omegaconf import DictConfig from torch import Tensor from torch.utils.data import DataLoader diff --git a/fusion_bench/taskpool/clip_vision/taskpool.py b/fusion_bench/taskpool/clip_vision/taskpool.py index dfc182e3..990f78a4 100644 --- a/fusion_bench/taskpool/clip_vision/taskpool.py +++ b/fusion_bench/taskpool/clip_vision/taskpool.py @@ -288,17 +288,18 @@ def evaluate(self, model: Union[CLIPVisionModel, CLIPVisionTransformer]): self.on_task_evaluation_end() # calculate the average accuracy and loss - report["average"] = {} - accuracies = [ - value["accuracy"] for key, value in report.items() if "accuracy" in value - ] - if len(accuracies) > 0: - average_accuracy = sum(accuracies) / len(accuracies) - report["accuracy"] = average_accuracy - losses = [value["loss"] for key, value in report.items() if "loss" in value] - if len(losses) > 0: - average_loss = sum(losses) / len(losses) - report["loss"] = average_loss + if "average" not in report: + report["average"] = {} + accuracies = [ + value["accuracy"] for key, value in report.items() if "accuracy" in value + ] + if len(accuracies) > 0: + average_accuracy = sum(accuracies) / len(accuracies) + report["average"]["accuracy"] = average_accuracy + losses = [value["loss"] for key, value in report.items() if "loss" in value] + if len(losses) > 0: + average_loss = sum(losses) / len(losses) + report["average"]["loss"] = average_loss log.info(f"Evaluation Result: {report}") if self.fabric.is_global_zero and len(self.fabric._loggers) > 0: