From c9817699903980bea4eeb265349bf108704ed1d1 Mon Sep 17 00:00:00 2001 From: tanganke Date: Thu, 17 Oct 2024 09:42:58 +0800 Subject: [PATCH] fix compatibility issue --- .../compat/modelpool/AutoModelForSeq2SeqLM.py | 2 -- fusion_bench/method/we_moe/clip_we_moe.py | 12 +++--------- fusion_bench/mixins/__init__.py | 1 + fusion_bench/modelpool/PeftModelForSeq2SeqLM.py | 3 +-- fusion_bench/modelpool/huggingface_automodel.py | 2 +- fusion_bench/modelpool/nyuv2_modelpool.py | 6 ++---- fusion_bench/taskpool/__init__.py | 4 +--- fusion_bench/taskpool/base_pool.py | 7 +------ fusion_bench/taskpool/nyuv2_taskpool.py | 13 ++++--------- 9 files changed, 14 insertions(+), 36 deletions(-) diff --git a/fusion_bench/compat/modelpool/AutoModelForSeq2SeqLM.py b/fusion_bench/compat/modelpool/AutoModelForSeq2SeqLM.py index 3ba01015..91fab5fa 100644 --- a/fusion_bench/compat/modelpool/AutoModelForSeq2SeqLM.py +++ b/fusion_bench/compat/modelpool/AutoModelForSeq2SeqLM.py @@ -1,8 +1,6 @@ import logging -import peft from omegaconf import DictConfig -from peft import PeftModel from transformers import AutoModelForSeq2SeqLM from fusion_bench.utils import timeit_context diff --git a/fusion_bench/method/we_moe/clip_we_moe.py b/fusion_bench/method/we_moe/clip_we_moe.py index 56b388cb..9bc78a62 100644 --- a/fusion_bench/method/we_moe/clip_we_moe.py +++ b/fusion_bench/method/we_moe/clip_we_moe.py @@ -1,20 +1,14 @@ import functools import logging import os -from abc import abstractmethod from copy import deepcopy -from typing import List -import lightning as L import torch -from omegaconf import DictConfig -from torch import Tensor, nn +from torch import Tensor from torch.utils.data import DataLoader -from transformers import CLIPModel, CLIPProcessor, CLIPVisionModel -from transformers.models.clip.modeling_clip import CLIPEncoder, CLIPEncoderLayer +from transformers import CLIPModel, CLIPProcessor +from transformers.models.clip.modeling_clip import CLIPEncoder -from fusion_bench.compat.method.base_algorithm import ModelFusionAlgorithm -from fusion_bench.compat.modelpool import ModelPool from fusion_bench.compat.modelpool.huggingface_clip_vision import ( HuggingFaceClipVisionPool, ) diff --git a/fusion_bench/mixins/__init__.py b/fusion_bench/mixins/__init__.py index a4b734ee..a508256d 100644 --- a/fusion_bench/mixins/__init__.py +++ b/fusion_bench/mixins/__init__.py @@ -1,3 +1,4 @@ +# flake8: noqa F401 import sys from typing_extensions import TYPE_CHECKING diff --git a/fusion_bench/modelpool/PeftModelForSeq2SeqLM.py b/fusion_bench/modelpool/PeftModelForSeq2SeqLM.py index bae9e6d2..7409df7d 100644 --- a/fusion_bench/modelpool/PeftModelForSeq2SeqLM.py +++ b/fusion_bench/modelpool/PeftModelForSeq2SeqLM.py @@ -1,9 +1,8 @@ import logging -import peft from omegaconf import DictConfig from peft import PeftModel -from transformers import AutoModelForSeq2SeqLM, AutoTokenizer +from transformers import AutoModelForSeq2SeqLM from fusion_bench.compat.modelpool.base_pool import ModelPool from fusion_bench.utils import timeit_context diff --git a/fusion_bench/modelpool/huggingface_automodel.py b/fusion_bench/modelpool/huggingface_automodel.py index 7dd5c16a..b2f02ae5 100644 --- a/fusion_bench/modelpool/huggingface_automodel.py +++ b/fusion_bench/modelpool/huggingface_automodel.py @@ -4,7 +4,7 @@ from torch.nn.modules import Module from transformers import AutoModel -from .base_pool import ModelPool +from fusion_bench.compat.modelpool import ModelPool log = logging.getLogger(__name__) diff --git a/fusion_bench/modelpool/nyuv2_modelpool.py b/fusion_bench/modelpool/nyuv2_modelpool.py index 1a8062ad..24ae14fd 100644 --- a/fusion_bench/modelpool/nyuv2_modelpool.py +++ b/fusion_bench/modelpool/nyuv2_modelpool.py @@ -1,17 +1,15 @@ import logging -from typing import Dict import torch from omegaconf import DictConfig -from torch import Tensor, nn -from torch.nn.modules import Module +from torch import nn from fusion_bench.dataset.nyuv2 import NYUv2 from fusion_bench.models.nyuv2.aspp import DeepLabHead from fusion_bench.models.nyuv2.lightning_module import NYUv2Model from fusion_bench.models.nyuv2.resnet_dilated import ResnetDilated, resnet_dilated -from .base_pool import ModelPool +from fusion_bench.compat.modelpool.base_pool import ModelPool log = logging.getLogger(__name__) diff --git a/fusion_bench/taskpool/__init__.py b/fusion_bench/taskpool/__init__.py index 6f5b65ed..b148730c 100644 --- a/fusion_bench/taskpool/__init__.py +++ b/fusion_bench/taskpool/__init__.py @@ -1,6 +1,6 @@ +# flake8: noqa F401 import sys -from omegaconf import DictConfig from typing_extensions import TYPE_CHECKING from fusion_bench.utils.lazy_imports import LazyImporter @@ -10,7 +10,6 @@ "clip_vision": ["CLIPVisionModelTaskPool"], "dummy": ["DummyTaskPool"], "gpt2_text_classification": ["GPT2TextClassificationTaskPool"], - "flan_t5_glue_text_generation": ["FlanT5GLUETextGenerationTaskPool"], "nyuv2_taskpool": ["NYUv2TaskPool"], } @@ -19,7 +18,6 @@ from .base_pool import BaseTaskPool from .clip_vision import CLIPVisionModelTaskPool from .dummy import DummyTaskPool - from .flan_t5_glue_text_generation import FlanT5GLUETextGenerationTaskPool from .gpt2_text_classification import GPT2TextClassificationTaskPool from .nyuv2_taskpool import NYUv2TaskPool diff --git a/fusion_bench/taskpool/base_pool.py b/fusion_bench/taskpool/base_pool.py index a780be59..e7704e3f 100644 --- a/fusion_bench/taskpool/base_pool.py +++ b/fusion_bench/taskpool/base_pool.py @@ -1,9 +1,4 @@ -import logging -from abc import ABC, abstractmethod -from typing import Union - -from omegaconf import DictConfig -from tqdm.autonotebook import tqdm +from abc import abstractmethod from fusion_bench.mixins import BaseYAMLSerializableModel diff --git a/fusion_bench/taskpool/nyuv2_taskpool.py b/fusion_bench/taskpool/nyuv2_taskpool.py index 03ac5f8a..10177460 100644 --- a/fusion_bench/taskpool/nyuv2_taskpool.py +++ b/fusion_bench/taskpool/nyuv2_taskpool.py @@ -1,20 +1,15 @@ import logging from pathlib import Path -from typing import Dict, cast import lightning as L -import torch from omegaconf import DictConfig -from torch import Tensor, nn -from torch.nn.modules import Module +from torch import nn from torch.utils.data import DataLoader +from fusion_bench.compat.taskpool.base_pool import TaskPool from fusion_bench.dataset.nyuv2 import NYUv2 -from fusion_bench.models.nyuv2.aspp import DeepLabHead -from fusion_bench.models.nyuv2.lightning_module import NYUv2Model, NYUv2MTLModule -from fusion_bench.models.nyuv2.resnet_dilated import ResnetDilated, resnet_dilated - -from .base_pool import TaskPool +from fusion_bench.models.nyuv2.lightning_module import NYUv2MTLModule +from fusion_bench.models.nyuv2.resnet_dilated import ResnetDilated log = logging.getLogger(__name__)