Skip to content

Commit

Permalink
Merge branch 'ada-svd' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
tanganke committed Oct 17, 2024
2 parents 5237f36 + c981769 commit beb5f25
Show file tree
Hide file tree
Showing 9 changed files with 14 additions and 36 deletions.
2 changes: 0 additions & 2 deletions fusion_bench/compat/modelpool/AutoModelForSeq2SeqLM.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import logging

import peft
from omegaconf import DictConfig
from peft import PeftModel
from transformers import AutoModelForSeq2SeqLM

from fusion_bench.utils import timeit_context
Expand Down
12 changes: 3 additions & 9 deletions fusion_bench/method/we_moe/clip_we_moe.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,14 @@
import functools
import logging
import os
from abc import abstractmethod
from copy import deepcopy
from typing import List

import lightning as L
import torch
from omegaconf import DictConfig
from torch import Tensor, nn
from torch import Tensor
from torch.utils.data import DataLoader
from transformers import CLIPModel, CLIPProcessor, CLIPVisionModel
from transformers.models.clip.modeling_clip import CLIPEncoder, CLIPEncoderLayer
from transformers import CLIPModel, CLIPProcessor
from transformers.models.clip.modeling_clip import CLIPEncoder

from fusion_bench.compat.method.base_algorithm import ModelFusionAlgorithm
from fusion_bench.compat.modelpool import ModelPool
from fusion_bench.compat.modelpool.huggingface_clip_vision import (
HuggingFaceClipVisionPool,
)
Expand Down
1 change: 1 addition & 0 deletions fusion_bench/mixins/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# flake8: noqa F401
import sys

from typing_extensions import TYPE_CHECKING
Expand Down
3 changes: 1 addition & 2 deletions fusion_bench/modelpool/PeftModelForSeq2SeqLM.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import logging

import peft
from omegaconf import DictConfig
from peft import PeftModel
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from transformers import AutoModelForSeq2SeqLM

from fusion_bench.compat.modelpool.base_pool import ModelPool
from fusion_bench.utils import timeit_context
Expand Down
2 changes: 1 addition & 1 deletion fusion_bench/modelpool/huggingface_automodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from torch.nn.modules import Module
from transformers import AutoModel

from .base_pool import ModelPool
from fusion_bench.compat.modelpool import ModelPool

log = logging.getLogger(__name__)

Expand Down
6 changes: 2 additions & 4 deletions fusion_bench/modelpool/nyuv2_modelpool.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
import logging
from typing import Dict

import torch
from omegaconf import DictConfig
from torch import Tensor, nn
from torch.nn.modules import Module
from torch import nn

from fusion_bench.dataset.nyuv2 import NYUv2
from fusion_bench.models.nyuv2.aspp import DeepLabHead
from fusion_bench.models.nyuv2.lightning_module import NYUv2Model
from fusion_bench.models.nyuv2.resnet_dilated import ResnetDilated, resnet_dilated

from .base_pool import ModelPool
from fusion_bench.compat.modelpool.base_pool import ModelPool

log = logging.getLogger(__name__)

Expand Down
4 changes: 1 addition & 3 deletions fusion_bench/taskpool/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# flake8: noqa F401
import sys

from omegaconf import DictConfig
from typing_extensions import TYPE_CHECKING

from fusion_bench.utils.lazy_imports import LazyImporter
Expand All @@ -10,7 +10,6 @@
"clip_vision": ["CLIPVisionModelTaskPool"],
"dummy": ["DummyTaskPool"],
"gpt2_text_classification": ["GPT2TextClassificationTaskPool"],
"flan_t5_glue_text_generation": ["FlanT5GLUETextGenerationTaskPool"],
"nyuv2_taskpool": ["NYUv2TaskPool"],
}

Expand All @@ -19,7 +18,6 @@
from .base_pool import BaseTaskPool
from .clip_vision import CLIPVisionModelTaskPool
from .dummy import DummyTaskPool
from .flan_t5_glue_text_generation import FlanT5GLUETextGenerationTaskPool
from .gpt2_text_classification import GPT2TextClassificationTaskPool
from .nyuv2_taskpool import NYUv2TaskPool

Expand Down
7 changes: 1 addition & 6 deletions fusion_bench/taskpool/base_pool.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,4 @@
import logging
from abc import ABC, abstractmethod
from typing import Union

from omegaconf import DictConfig
from tqdm.autonotebook import tqdm
from abc import abstractmethod

from fusion_bench.mixins import BaseYAMLSerializableModel

Expand Down
13 changes: 4 additions & 9 deletions fusion_bench/taskpool/nyuv2_taskpool.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,15 @@
import logging
from pathlib import Path
from typing import Dict, cast

import lightning as L
import torch
from omegaconf import DictConfig
from torch import Tensor, nn
from torch.nn.modules import Module
from torch import nn
from torch.utils.data import DataLoader

from fusion_bench.compat.taskpool.base_pool import TaskPool
from fusion_bench.dataset.nyuv2 import NYUv2
from fusion_bench.models.nyuv2.aspp import DeepLabHead
from fusion_bench.models.nyuv2.lightning_module import NYUv2Model, NYUv2MTLModule
from fusion_bench.models.nyuv2.resnet_dilated import ResnetDilated, resnet_dilated

from .base_pool import TaskPool
from fusion_bench.models.nyuv2.lightning_module import NYUv2MTLModule
from fusion_bench.models.nyuv2.resnet_dilated import ResnetDilated

log = logging.getLogger(__name__)

Expand Down

0 comments on commit beb5f25

Please sign in to comment.