Showing 8 changed files with 326 additions and 3 deletions.

```yaml title="config/method/ties_merging.yaml"
name: ties_merging
# Scaling factor $\lambda$
scaling_factor: 0.5
# Trim threshold: fraction of the largest-magnitude parameter values to keep
threshold: 0.5
# List of keys to remove from the state dict, default is empty
remove_keys: []
# Function to merge the models, default is sum. Options are 'sum', 'mean', and 'max'
merge_func: sum
```

# Ties Merging

Ties-Merging[^1] is a structured approach to consolidating multiple task-specific models into a single, efficient multi-task model. It merges task vectors through a sequence of deliberate steps, ensuring that the final model integrates the strengths of each individual task-specific model while resolving potential conflicts between them.

The Ties-Merging algorithm operates through three primary steps (illustrated with a toy example below):

1. **Trim**: Each task vector is trimmed by keeping only its largest-magnitude parameters, focusing each task-specific model on its essential updates.
2. **Elect Sign of Parameters**: For each parameter, an aggregate sign is elected across the trimmed task vectors, ensuring that the merged parameters are consistently oriented for multi-task learning.
3. **Disjoint Merge**: Finally, only the parameter values whose signs agree with the elected sign are aggregated, producing a single merged task vector, denoted $\tau$.

Given the merged task vector $\tau$, the final model is obtained in the same way as in task arithmetic:

$$
\theta = \theta_0 + \lambda \tau
$$

where $\lambda$ is a scaling hyperparameter chosen on a validation set to give the best-performing model.
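
To make these steps concrete, here is a small self-contained sketch of the procedure on toy task vectors, written in plain PyTorch. The tensor values, the 50% trim ratio, and the scaling factor are all made up for illustration; the repository's actual implementation is shown under Code Integration below.

```python
import torch

# Two toy task vectors (fine-tuned weights minus pre-trained weights),
# four parameters each. All values are made up for illustration.
task_vectors = torch.tensor(
    [
        [0.9, -0.1, 0.4, -0.80],
        [0.7, 0.2, -0.5, -0.05],
    ]
)

# 1. Trim: keep the top 50% of entries by magnitude in each task vector.
k = task_vectors.shape[1] // 2
topk_idx = task_vectors.abs().topk(k, dim=1).indices
mask = torch.zeros_like(task_vectors, dtype=torch.bool).scatter_(1, topk_idx, True)
trimmed = task_vectors * mask  # [[0.9, 0, 0, -0.8], [0.7, 0, -0.5, 0]]

# 2. Elect sign: per parameter, the sign of the summed trimmed values.
elected_sign = torch.sign(trimmed.sum(dim=0))  # [1., 0., -1., -1.]

# 3. Disjoint merge: aggregate only entries that agree with the elected sign.
agrees = torch.where(elected_sign.unsqueeze(0) > 0, trimmed > 0, trimmed < 0)
tau = (trimmed * agrees).sum(dim=0)  # [1.6, 0.0, -0.5, -0.8]

# Final model: theta = theta_0 + lambda * tau, with lambda = 0.5.
theta_0 = torch.zeros(4)  # stand-in for the pre-trained weights
theta = theta_0 + 0.5 * tau  # [0.8, 0.0, -0.25, -0.4]
```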

By following these steps, Ties-Merging integrates multiple task-specific models into a unified multi-task model, balancing the contribution of each task to enhance overall performance. The final model retains the benefits of the pre-trained model while incorporating the diverse knowledge contained in the individual task-specific models.

## Code Integration

Configuration template for the Ties-Merging algorithm:

```yaml title="config/method/ties_merging.yaml"
name: ties_merging
# Scaling factor $\lambda$
scaling_factor: 0.5
# Trim threshold: fraction of the largest-magnitude parameter values to keep
threshold: 0.5
# List of keys to remove from the state dict, default is empty
remove_keys: []
# Function to merge the models, default is sum. Options are 'sum', 'mean', and 'max'
merge_func: sum
```

Use the following command to run the Ties-Merging algorithm:

```bash
fusion_bench method=ties_merging ...
```
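
Individual options can be overridden on the command line. Assuming the usual Hydra-style `method.<option>=<value>` override syntax implied by the `config/method/ties_merging.yaml` layout (an assumption; check the fusion_bench CLI documentation), a run with different hyperparameters would look like:

```bash
fusion_bench method=ties_merging \
  method.scaling_factor=0.3 \
  method.threshold=0.7 ...
```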

::: fusion_bench.method.TiesMergingAlgorithm
    options:
      members: true

[^1]: (NeurIPS 2023) TIES-Merging: Resolving Interference When Merging Models. http://arxiv.org/abs/2306.01708

The implementation of `fusion_bench.method.TiesMergingAlgorithm`:
```python
import logging
from typing import List

import torch

from ...modelpool import ModelPool
from ...utils.type import _StateDict
from ..base_algorithm import ModelFusionAlgorithm
from .ties_merging_utils import state_dict_to_vector, vector_to_state_dict, ties_merging

log = logging.getLogger(__name__)


class TiesMergingAlgorithm(ModelFusionAlgorithm):

    @torch.no_grad()
    def fuse(self, modelpool: ModelPool):
        log.info("Fusing models using ties merging.")
        remove_keys = self.config.get("remove_keys", [])
        merge_func = self.config.get("merge_func", "sum")
        scaling_factor = self.config.scaling_factor
        threshold = self.config.threshold

        pretrained_model = modelpool.load_model("_pretrained_")

        # load the state dicts of the fine-tuned models and the pre-trained model
        ft_checks: List[_StateDict] = [
            modelpool.load_model(model_name).state_dict(keep_vars=True)
            for model_name in modelpool.model_names
        ]
        ptm_check: _StateDict = pretrained_model.state_dict(keep_vars=True)

        # compute the task vectors (fine-tuned minus pre-trained weights)
        flat_ft = torch.vstack(
            [state_dict_to_vector(check, remove_keys) for check in ft_checks]
        )
        flat_ptm = state_dict_to_vector(ptm_check, remove_keys)
        tv_flat_checks = flat_ft - flat_ptm

        # trim, elect signs, and disjoint-merge the task vectors
        merged_tv = ties_merging(
            tv_flat_checks,
            reset_thresh=threshold,
            merge_func=merge_func,
        )
        # theta = theta_0 + lambda * tau
        merged_check = flat_ptm + scaling_factor * merged_tv
        merged_state_dict = vector_to_state_dict(
            merged_check, ptm_check, remove_keys=remove_keys
        )

        pretrained_model.load_state_dict(merged_state_dict)
        return pretrained_model
```
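
For programmatic use, the following is a minimal sketch, not taken from the repository: it assumes `TiesMergingAlgorithm` can be constructed directly from an OmegaConf `DictConfig` (consistent with the `self.config.scaling_factor` attribute access above, though the actual constructor lives in `ModelFusionAlgorithm`) and that `modelpool` is an already-constructed `ModelPool` holding a `_pretrained_` model and the fine-tuned models.

```python
from omegaconf import DictConfig

# Hypothetical construction mirroring config/method/ties_merging.yaml.
method_config = DictConfig(
    {
        "name": "ties_merging",
        "scaling_factor": 0.5,
        "threshold": 0.5,
        "remove_keys": [],
        "merge_func": "sum",
    }
)
algorithm = TiesMergingAlgorithm(method_config)
merged_model = algorithm.fuse(modelpool)  # the pre-trained model with merged weights loaded
```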

The merging utilities in `ties_merging_utils.py`:
""" | ||
This is copy from https://github.com/EnnengYang/AdaMerging/blob/main/src/ties_merging_utils.py | ||
""" | ||
|
||
import sys | ||
import os, copy | ||
import torch | ||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
import re | ||
from collections import OrderedDict | ||
import torch.nn.functional as F | ||
|
||
# from transformers import AutoTokenizer, AutoModelForSeq2SeqLM | ||
|
||
|
||


## Model conversion utils
def state_dict_to_vector(state_dict, remove_keys=[]):
    # drop the keys that should not be merged, then flatten the remaining
    # parameters (in sorted key order) into a single 1-D vector
    shared_state_dict = copy.deepcopy(state_dict)
    for key in remove_keys:
        if key in shared_state_dict:
            del shared_state_dict[key]
    sorted_shared_state_dict = OrderedDict(sorted(shared_state_dict.items()))
    return torch.nn.utils.parameters_to_vector(
        [value.reshape(-1) for key, value in sorted_shared_state_dict.items()]
    )


def vector_to_state_dict(vector, state_dict, remove_keys=[]):
    # create a reference dict to define the order of the vector
    reference_dict = copy.deepcopy(state_dict)
    for key in remove_keys:
        if key in reference_dict:
            del reference_dict[key]
    sorted_reference_dict = OrderedDict(sorted(reference_dict.items()))

    # create a shared state dict using the reference dict
    torch.nn.utils.vector_to_parameters(vector, sorted_reference_dict.values())

    # add back the encoder and decoder embedding weights.
    if "transformer.shared.weight" in sorted_reference_dict:
        for key in remove_keys:
            sorted_reference_dict[key] = sorted_reference_dict[
                "transformer.shared.weight"
            ]
    return sorted_reference_dict


def add_ptm_to_tv(tv_dict, ptm_dict):
    # add the pre-trained weights back onto a task-vector state dict
    assert set(tv_dict.keys()) == set(
        ptm_dict.keys()
    ), "Differing parameter names in models."
    final_dict = copy.deepcopy(tv_dict)
    for k, v in ptm_dict.items():
        final_dict[k] = tv_dict[k] + v
    return final_dict


def check_parameterNamesMatch(checkpoints):
    parameter_names = set(checkpoints[0].keys())

    if len(checkpoints) >= 2:
        for checkpoint in checkpoints[1:]:
            current_parameterNames = set(checkpoint.keys())
            if current_parameterNames != parameter_names:
                raise ValueError(
                    "Differing parameter names in models. "
                    f"The different parameters are {parameter_names.symmetric_difference(current_parameterNames)}"
                )


def check_state_dicts_equal(state_dict1, state_dict2):
    if set(state_dict1.keys()) != set(state_dict2.keys()):
        return False

    for key in state_dict1.keys():
        if not torch.equal(state_dict1[key], state_dict2[key]):
            return False

    return True


## TIES MERGING UTILS


def topk_values_mask(M, K=0.7, return_mask=False):
    if K > 1:
        K /= 100  # allow K to be given as a percentage

    original_shape = M.shape
    if M.dim() == 1:
        M = M.unsqueeze(0)

    n, d = M.shape
    k = int(d * K)  # number of elements to keep per row
    k = d - k  # rank of the smallest magnitude that is still kept

    # Find the k-th smallest element by magnitude for each row
    kth_values, _ = M.abs().kthvalue(k, dim=1, keepdim=True)
    # Create a mask tensor with True for the top elements in each row
    mask = M.abs() >= kth_values
    final_mask = mask.squeeze() if original_shape == M.squeeze().shape else mask

    if return_mask:
        return M * final_mask, final_mask.float().mean(dim=1), final_mask
    return M * final_mask, final_mask.float().mean(dim=1)


def resolve_zero_signs(sign_to_mult, method="majority"):
    # replace zero signs with the majority (or minority) sign overall
    majority_sign = torch.sign(sign_to_mult.sum())

    if method == "majority":
        sign_to_mult[sign_to_mult == 0] = majority_sign
    elif method == "minority":
        sign_to_mult[sign_to_mult == 0] = -1 * majority_sign
    return sign_to_mult


def resolve_sign(Tensor):
    # elect, per parameter, the sign with the larger total mass across models
    sign_to_mult = torch.sign(Tensor.sum(dim=0))
    sign_to_mult = resolve_zero_signs(sign_to_mult, "majority")
    return sign_to_mult


def disjoint_merge(Tensor, merge_func, sign_to_mult):
    merge_func = merge_func.split("-")[-1]

    # If sign is provided then we select the corresponding entries and aggregate.
    if sign_to_mult is not None:
        rows_to_keep = torch.where(
            sign_to_mult.unsqueeze(0) > 0, Tensor > 0, Tensor < 0
        )
        selected_entries = Tensor * rows_to_keep
    # Else we select all non-zero entries and aggregate.
    else:
        rows_to_keep = Tensor != 0
        selected_entries = Tensor * rows_to_keep

    if merge_func == "mean":
        # average over the models that contribute a non-zero value
        non_zero_counts = (selected_entries != 0).sum(dim=0).float()
        disjoint_aggs = torch.sum(selected_entries, dim=0) / torch.clamp(
            non_zero_counts, min=1
        )
    elif merge_func == "sum":
        disjoint_aggs = torch.sum(selected_entries, dim=0)
    elif merge_func == "max":
        disjoint_aggs = selected_entries.abs().max(dim=0)[0]
        disjoint_aggs *= sign_to_mult
    else:
        raise ValueError(f"Merge method {merge_func} is not defined.")

    return disjoint_aggs


def ties_merging(
    flat_task_checks,
    reset_thresh=None,
    merge_func="",
):
    all_checks = flat_task_checks.clone()
    # 1. Trim: keep only the top `reset_thresh` fraction of values by magnitude
    updated_checks, *_ = topk_values_mask(all_checks, K=reset_thresh, return_mask=False)
    # 2. Elect the aggregate sign for each parameter
    print("RESOLVING SIGN")
    final_signs = resolve_sign(updated_checks)
    assert final_signs is not None

    # 3. Disjoint merge of the sign-consistent entries
    print(f"Disjoint AGGREGATION: {merge_func}")
    merged_tv = disjoint_merge(updated_checks, merge_func, final_signs)

    return merged_tv


def disjoint_merge_split(Tensor, merge_func, sign_to_mult):
    # variant of disjoint_merge that also returns the selected entries
    merge_func = merge_func.split("-")[-1]

    # If sign is provided then we select the corresponding entries and aggregate.
    if sign_to_mult is not None:
        rows_to_keep = torch.where(
            sign_to_mult.unsqueeze(0) > 0, Tensor > 0, Tensor < 0
        )
        selected_entries = Tensor * rows_to_keep
    # Else we select all non-zero entries and aggregate.
    else:
        rows_to_keep = Tensor != 0
        selected_entries = Tensor * rows_to_keep

    if merge_func == "sum":
        disjoint_aggs = torch.sum(selected_entries, dim=0)
    else:
        raise ValueError(f"Merge method {merge_func} is not defined.")

    return selected_entries, disjoint_aggs


def ties_merging_split(
    flat_task_checks,
    reset_thresh=None,
    merge_func="",
):
    all_checks = flat_task_checks.clone()
    updated_checks, *_ = topk_values_mask(all_checks, K=reset_thresh, return_mask=False)
    print("RESOLVING SIGN")
    final_signs = resolve_sign(updated_checks)
    assert final_signs is not None

    print(f"Disjoint AGGREGATION: {merge_func}")
    selected_entries, merged_tv = disjoint_merge_split(
        updated_checks, merge_func, final_signs
    )

    return selected_entries, merged_tv
```
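
As a usage sketch, the utilities above can be exercised end to end on toy `nn.Linear` models. Everything below is self-contained apart from the functions defined in this file; the model sizes, noise scale, and hyperparameter values are made up for illustration.

```python
import torch
from torch import nn

torch.manual_seed(0)

# A toy "pre-trained" model and two "fine-tuned" variants of it.
pretrained = nn.Linear(4, 2)
finetuned = [nn.Linear(4, 2) for _ in range(2)]
for model in finetuned:
    model.load_state_dict(
        {k: v + 0.1 * torch.randn_like(v) for k, v in pretrained.state_dict().items()}
    )

# Flatten the state dicts and compute the task vectors.
ptm_check = pretrained.state_dict()
flat_ptm = state_dict_to_vector(ptm_check)
flat_ft = torch.vstack([state_dict_to_vector(m.state_dict()) for m in finetuned])
task_vectors = flat_ft - flat_ptm

# Trim, elect signs, and disjoint-merge, then scale back onto the pre-trained weights.
merged_tv = ties_merging(task_vectors, reset_thresh=0.5, merge_func="sum")
merged_flat = flat_ptm + 0.5 * merged_tv
pretrained.load_state_dict(vector_to_state_dict(merged_flat, ptm_check))
```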