add prompts for imagenet
tanganke committed Oct 16, 2024
1 parent f062303 commit f34487c
Showing 10 changed files with 1,422 additions and 12 deletions.
2 changes: 1 addition & 1 deletion .flake8
@@ -1,2 +1,2 @@
[flake8]
-ignore = E501, W503
+ignore = E501, W503, E203
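Note: E203 ("whitespace before ':'") is the flake8 check that most often conflicts with auto-formatters such as Black, which is presumably why it is added to the ignore list here.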
24 changes: 24 additions & 0 deletions (new CLIP modelpool config; file path not shown)
@@ -0,0 +1,24 @@
defaults:
  - _self_
  - /model/clip-vit@models:
      - clip-vit-base-patch32
      - clip-vit-base-patch32_sun397
      - clip-vit-base-patch32_stanford-cars
      - clip-vit-base-patch32_resisc45
      - clip-vit-base-patch32_eurosat
      - clip-vit-base-patch32_svhn
      - clip-vit-base-patch32_gtsrb
      - clip-vit-base-patch32_mnist
      - clip-vit-base-patch32_dtd
  - /dataset/image_classification/train@train_datasets:
      - tiny-imagenet

_target_: fusion_bench.modelpool.CLIPVisionModelPool
_recursive_: false

models: ???
train_datasets: ???

processor:
  _target_: transformers.CLIPProcessor.from_pretrained
  pretrained_model_name_or_path: openai/clip-vit-base-patch32
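For orientation, a minimal sketch of how a composed modelpool config like the one above is typically turned into a pool object. Here `cfg` stands for an already-composed Hydra config, and `load_train_dataset` is an assumed helper name, so read this as an illustration rather than as the project's documented entry point:

from hydra.utils import instantiate

# `cfg.modelpool` is assumed to follow the YAML above; the Hydra `defaults`
# list fills in `models` and `train_datasets`, which are declared mandatory (`???`).
modelpool = instantiate(cfg.modelpool)

# With `_recursive_: false`, sub-configs stay unresolved, so each model or
# dataset is materialized only when requested by name.
mnist_expert = modelpool.load_model("clip-vit-base-patch32_mnist")
tiny_imagenet_train = modelpool.load_train_dataset("tiny-imagenet")  # helper name assumed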
8 changes: 4 additions & 4 deletions fusion_bench/method/smile_upscaling/smile_upscaling.py
@@ -424,8 +424,8 @@ def _average_experts(self, pretarined_model, finetuned_models, name: str):

    def _upscale_submodules(
        self,
-        pretrained_model,
-        finetuned_model,
+        pretrained_model: nn.Module,
+        finetuned_models: List[nn.Module],
        tqdm_desc: str = "Upscaling Linear Modules",
    ):
        """
@@ -446,9 +446,9 @@
            if isinstance(module, self._linear_layer_cls):
                self._upscale_linear_layer(
                    pretrained_model=pretrained_model,
-                    finetuned_models=finetuned_model,
+                    finetuned_models=finetuned_models,
                    name=name,
                )
            elif config.average_experts and len(tuple(module.named_modules())) == 1:
                # if the module is a leaf module, we perform a parameter average
-                self._average_experts(pretrained_model, finetuned_model, name)
+                self._average_experts(pretrained_model, finetuned_models, name)
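For context, a hedged sketch of the call pattern this signature change supports; `algorithm` and `modelpool` are stand-ins for an already-constructed SMILE upscaling algorithm and model pool, and calling the private helper directly is for illustration only:

# Walk the pretrained network and replace each nn.Linear with a SMILE MoE layer
# built from the corresponding fine-tuned linears
# (see fusion_bench/models/smile_moe/linear.py below).
pretrained_model = modelpool.load_pretrained_model()
finetuned_models = [modelpool.load_model(name) for name in modelpool.model_names]

algorithm._upscale_submodules(
    pretrained_model=pretrained_model,
    finetuned_models=finetuned_models,  # now typed as List[nn.Module]
)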
16 changes: 13 additions & 3 deletions fusion_bench/method/weighted_average/weighted_average.py
@@ -37,9 +37,17 @@ class WeightedAverageAlgorithm(BaseModelFusionAlgorithm, SimpleProfilerMixin):
        "weights": "weights",
    }

-    def __init__(self, normalize: bool, weights: List[float], **kwargs):
+    def __init__(
+        self,
+        normalize: bool,
+        weights: List[float],
+        verbose: bool = True,
+        **kwargs,
+    ):
        self.normalize = normalize
        self.weights = weights
+        self.verbose = verbose
+        log.disabled = not self.verbose
        super().__init__(**kwargs)

    @override
@@ -70,7 +78,8 @@ def run(self, modelpool: BaseModelPool):
            )
        if self.normalize:
            weights = weights / np.sum(weights)
-        print(f"weights: {weights}, normalized: {self.normalize}")
+        if self.verbose:
+            print(f"weights: {weights}, normalized: {self.normalize}")

        sd: Optional[StateDictType] = None
        forward_model = None
@@ -88,5 +97,6 @@ def run(self, modelpool: BaseModelPool):
                )

        forward_model.load_state_dict(sd)
-        self.print_profile_summary()
+        if self.verbose:
+            self.print_profile_summary()
        return forward_model
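A short usage sketch of the updated constructor; the import path follows the file shown above, and `modelpool` stands for any existing pool with one weight per model:

from fusion_bench.method.weighted_average.weighted_average import WeightedAverageAlgorithm

algorithm = WeightedAverageAlgorithm(
    normalize=True,      # rescale the weights to sum to 1 before merging
    weights=[0.3, 0.7],  # one coefficient per model in the pool
    verbose=False,       # new in this commit: silences the weight printout and the profile summary
)
merged_model = algorithm.run(modelpool)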
7 changes: 7 additions & 0 deletions fusion_bench/modelpool/base_pool.py
@@ -145,6 +145,13 @@ def load_model(
            )
        return model

+    def load_pretrained_model(self, *args, **kwargs):
+        assert (
+            self.has_pretrained
+        ), "No pretrained model available. Check `_pretrained_` is in the `models` key."
+        model = self.load_model("_pretrained_", *args, **kwargs)
+        return model
+
    def load_pretrained_or_first_model(self, *args, **kwargs):
        """
        Load the pretrained model if available, otherwise load the first available model.
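A minimal sketch of the new helper next to the existing fallback; `pool` stands for any BaseModelPool instance:

# Fails fast (AssertionError) when the pool has no "_pretrained_" entry.
base_model = pool.load_pretrained_model()

# Existing behaviour: use the pretrained model when present, otherwise the first model listed.
base_or_first = pool.load_pretrained_or_first_model()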
Empty file.
256 changes: 256 additions & 0 deletions fusion_bench/models/smile_moe/linear.py
@@ -0,0 +1,256 @@
import logging
from typing import Dict, List, Tuple  # noqa: F401

import torch
import torch.nn.functional as F
from torch import Tensor, nn

log = logging.getLogger(__name__)


class ExpertNotTrainedError(Exception):
    pass


def _is_all_zeros(tensor: Tensor | List[Tensor]) -> bool:
    if isinstance(tensor, Tensor):
        return torch.allclose(tensor, torch.zeros_like(tensor))
    else:
        return all(_is_all_zeros(t) for t in tensor)


def _svd(w: Tensor, full_matrices=True) -> Tuple[Tensor, Tensor, Tensor]:
    u, s, vh = torch.linalg.svd(
        w, full_matrices=full_matrices, driver="gesvd" if w.is_cuda else None
    )
    v = vh.T
    return u, s, v


def svd(
    w: Tensor, full_matrices=True, accelerator=None
) -> Tuple[Tensor, Tensor, Tensor]:
    if accelerator is None:
        return _svd(w, full_matrices=full_matrices)
    original_device = w.device
    w = w.to(accelerator)
    u, s, v = _svd(w)
    return u.to(original_device), s.to(original_device), v.to(original_device)


class SmileGate(nn.Module):
    def __init__(
        self,
        input_features: int,
        w_diff_list: List[Tensor],
        k: int,
        svd_list=None,  # cached `svd_list`, pass it to avoid recomputing
        upscaling_accelerator=None,
    ):
        super().__init__()
        self.input_features = input_features
        self.num_experts = len(w_diff_list)
        weights = []
        for i, w_diff in enumerate(w_diff_list):
            if svd_list is None:
                u, s, v = svd(w_diff, accelerator=upscaling_accelerator)
            else:
                u, s, v = svd_list[i]
            u = u[:, :k]
            s = s[:k]
            v = v[:, :k]

            # weights.append((s * v).T)
            weights.append(v.T)
        self.k = s.size(0)  # k is the actual k after truncation

        weights = (
            torch.stack(weights, dim=0)
            .reshape(self.num_experts * self.k, -1)
            .contiguous()
        )
        self.weights = nn.Parameter(
            weights
        )  # weights should be a tensor of shape (num_experts * k, n)

    def forward(self, x: Tensor):
        batch_size = x.size(0)
        if self.num_experts == 1:
            return torch.ones(batch_size, 1, device=x.device, dtype=x.dtype)

        routing_weights = F.linear(x, self.weights).view(
            batch_size, self.num_experts, self.k
        )
        routing_weights = routing_weights.norm(p=2, dim=2)
        return routing_weights


class SmileCompressedLinear(nn.Module):
    def __init__(self, model: nn.Linear, k: int, svd_cache=None):
        super().__init__()
        if svd_cache is None:
            u, s, v = svd(model.weight)
        else:
            u, s, v = svd_cache
        if k > 0:
            u = u[:, :k]
            s = s[:k]
            v = v[:, :k]

        self.u = nn.Parameter(u)
        self.svh = nn.Parameter((s * v).T)

        if model.bias is not None:
            self.bias = nn.Parameter(model.bias.data, requires_grad=True)
        else:
            self.register_parameter("bias", None)

    def forward(self, x):
        x = F.linear(x, self.svh)
        x = F.linear(x, self.u, self.bias)
        return x


class SmileMoELinear(nn.Module):
    @torch.no_grad()
    def __init__(
        self,
        pretrained_model: nn.Linear,
        finetuned_models: List[nn.Linear],
        gate_k: int,
        k: int,
        top_k: int = 1,
        full_matrices=True,
        upscaling_accelerator=None,
        routing_use_diff=True,
    ):
        super().__init__()
        self.num_experts = len(finetuned_models)
        self.top_k = top_k
        self.k = k
        self.gate_k = gate_k
        self.in_features = pretrained_model.in_features
        self.out_features = pretrained_model.out_features

        w_diff_list = [m.weight - pretrained_model.weight for m in finetuned_models]
        if _is_all_zeros(w_diff_list):
            # All fine-tuned models are identical to the pretrained model
            raise ExpertNotTrainedError()

        if routing_use_diff or k > 0:
            svd_cache_list = [
                svd(w, full_matrices=full_matrices, accelerator=upscaling_accelerator)
                for w in w_diff_list
            ]  # the svd cache list to avoid recomputing

        # construct the gate network
        if routing_use_diff:
            self.gate = SmileGate(
                input_features=self.in_features,
                w_diff_list=w_diff_list,
                k=gate_k,
                svd_list=svd_cache_list,
                upscaling_accelerator=upscaling_accelerator,
            )
        else:
            self.gate = SmileGate(
                input_features=self.in_features,
                w_diff_list=[m.weight for m in finetuned_models],
                k=gate_k,
                svd_list=None,
                upscaling_accelerator=upscaling_accelerator,
            )

        # construct experts
        for m, w_diff in zip(finetuned_models, w_diff_list):
            m.weight.data = w_diff
        if k > 0:
            experts = [
                SmileCompressedLinear(m, k, svd_cache=svd_cache)
                for m, svd_cache in zip(finetuned_models, svd_cache_list)
            ]
        else:
            # if k is not set (<0), we use the full fine-tuned model
            experts = finetuned_models
        self.experts = nn.ModuleList(experts)

        if pretrained_model.bias is not None:
            for m in experts:
                m.bias.data = m.bias.data - pretrained_model.bias
        # assign the pretrained model (the shared part)
        self.pretrained_model = pretrained_model

    def forward(self, hidden_states: Tensor):
        pretrained_out = self.pretrained_model(hidden_states)

        input_shape = hidden_states.size()
        hidden_states = hidden_states.view(-1, self.in_features)

        router_logits = self.gate(hidden_states)
        routing_weights = F.softmax(router_logits, dim=1)
        # sample the expert according to the routing weights
        routing_weights, selected_experts = torch.topk(
            routing_weights, self.top_k, dim=-1
        )
        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)

        final_hidden_states = torch.zeros(
            (hidden_states.size(0), self.out_features),
            dtype=hidden_states.dtype,
            device=hidden_states.device,
        )

        # One hot encode the selected experts to create an expert mask
        # this will be used to easily index which expert is going to be solicited
        expert_mask = torch.nn.functional.one_hot(
            selected_experts, num_classes=self.num_experts
        ).permute(2, 1, 0)

        # Loop over all available experts in the model and perform the computation on each expert
        for expert_idx in range(self.num_experts):
            expert_layer = self.experts[expert_idx]
            idx, top_x = torch.where(expert_mask[expert_idx])

            # Index the correct hidden states and compute the expert hidden state for
            # the current expert. We need to make sure to multiply the output hidden
            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
            current_state = hidden_states[None, top_x].reshape(-1, self.in_features)
            if current_state.numel() == 0:
                continue
            current_hidden_states = (
                expert_layer(current_state) * routing_weights[top_x, idx, None]
            )

            # However `index_add_` only supports torch tensors for indexing so we'll use
            # the `top_x` tensor here.
            final_hidden_states.index_add_(
                0, top_x, current_hidden_states.to(hidden_states.dtype)
            )
        final_hidden_states = final_hidden_states.reshape(
            *input_shape[:-1], self.out_features
        )
        final_hidden_states = pretrained_out + final_hidden_states
        return final_hidden_states

    @property
    def weight(self):
        """
        Mimic a plain linear layer. Because in some cases, the user might query the device
        (or dtype) of the layer's parameters via `linear_layer.weight.device`.
        """
        return self.pretrained_model.weight

    @property
    def bias(self):
        return self.pretrained_model.bias

    def __repr__(self):
        return (
            f"SingularMoELinear("
            f"in_features={self.pretrained_model.in_features}, "
            f"out_features={self.pretrained_model.out_features}, "
            f"num_experts={self.num_experts}, "
            f"top_k={self.top_k}, "
            f"gate_k={self.gate_k}, "
            f"k={self.k}"
            f")"
        )
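In short, each expert's routing logit is the L2 norm of the input projected onto the top `gate_k` right-singular directions of that expert's weight difference (fine-tuned weight minus pretrained weight), and each selected expert contributes a rank-`k` correction on top of the frozen pretrained layer. A minimal, self-contained sketch of wrapping a single linear layer (layer sizes and hyperparameters below are illustrative, not values used by the repository):

import torch
from torch import nn

from fusion_bench.models.smile_moe.linear import SmileMoELinear

pretrained = nn.Linear(768, 768)
finetuned = [nn.Linear(768, 768) for _ in range(4)]  # stand-ins for task-specific fine-tunes

moe_linear = SmileMoELinear(
    pretrained_model=pretrained,
    finetuned_models=finetuned,
    gate_k=16,  # per-expert rank of the SVD-based router
    k=32,       # rank of each compressed expert (k <= 0 keeps the full fine-tuned linears)
    top_k=1,    # number of experts mixed per token
)

x = torch.randn(8, 768)
y = moe_linear(x)  # drop-in replacement for nn.Linear
assert y.shape == (8, 768)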
25 changes: 25 additions & 0 deletions fusion_bench/models/utils.py
@@ -1,5 +1,7 @@
from typing import List

+from torch import nn
+

def del_attr(obj, names: List[str]):
    """
@@ -45,3 +47,26 @@ def get_attr(obj, names: List[str]):
        return getattr(obj, names[0])
    else:
        return get_attr(getattr(obj, names[0]), names[1:])
+
+
+def find_layers_with_type(
+    module: nn.Module,
+    layer_types=[nn.Linear],
+    prefix="",
+):
+    """
+    Recursively find the layers of a certain type in a module.
+
+    Args:
+        module (nn.Module): PyTorch module.
+        layer_types (list): List of layer types to find.
+        prefix (str): A prefix to add to the layer names.
+
+    Returns:
+        dict: Dictionary of layers of the given type(s) within the module.
+    """
+    res = {}
+    for name, submodule in module.named_modules(prefix=prefix):
+        if isinstance(submodule, tuple(layer_types)):
+            res[name] = submodule
+    return res
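A quick illustration of the new helper; the toy model is made up for the example:

import torch.nn as nn

from fusion_bench.models.utils import find_layers_with_type

toy = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
linears = find_layers_with_type(toy, layer_types=[nn.Linear])
# -> {'0': Linear(in_features=4, out_features=8, ...), '2': Linear(in_features=8, out_features=2, ...)}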