Skip to content

Commit

Permalink
add ties-merging
Browse files Browse the repository at this point in the history
  • Loading branch information
tanganke committed May 16, 2024
1 parent 2682ce1 commit d890bbf
Show file tree
Hide file tree
Showing 8 changed files with 326 additions and 3 deletions.
6 changes: 6 additions & 0 deletions config/method/ties_merging.yaml
Original file line number Diff line number Diff line change
@@ -1,2 +1,8 @@
name: ties_merging
# Scaling factor $\lambda$
scaling_factor: 0.5
threshold: 0.5
# List of keys to remove from the state dict, default is empty
remove_keys: []
# Function to merge the models, default is sum. Options are 'sum', 'mean', and 'max'
merge_func: sum
47 changes: 47 additions & 0 deletions docs/algorithms/ties_merging.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Ties Merging

Ties-Merging[^1] represents a novel and structured approach to consolidating multiple task-specific models into a single, efficient multi-task model. This method employs a sequence of deliberate steps to systematically merge task vectors, ensuring that the final model effectively integrates the strengths of each individual task-specific model and resolves potential conflicts between them.

The Ties-Merging algorithm operates through three primary steps:

1. Trim: This initial step involves refining the task-specific models by trimming unnecessary parameters, focusing the model on essential elements for each task.
2. Elect Sign of Parameters: In this step, the algorithm selects the appropriate signs for the parameters, ensuring that the integrated model parameters are optimally oriented for multi-task learning.
3. Disjoint Merge: Finally, the method performs a disjoint merge to combine the task-specific parameters into a single cohesive task vector, denoted as $\tau$.

Given the final merged task vector $\tau$, the ultimate model is determined similarly to the method used in task arithmetic. The formulation is expressed as:

$$
\theta = \theta_0 + \lambda \tau
$$

where $\lambda$ is a hyperparameter chosen based on the validation set to ensure the best-performing model.

By following these structured steps, Ties-Merging effectively integrates multiple task-specific models into a unified multi-task model, balancing the contributions of each task to enhance overall performance. The process ensures that the final model retains the benefits of the pre-trained model while optimally incorporating the diverse knowledge contained within the individual task-specific models.

## Code Integration

Configuration template for the Ties-Merging algorithm:

```yaml title="config/method/ties_merging.yaml"
name: ties_merging
# Scaling factor $\lambda$
scaling_factor: 0.5
threshold: 0.5
# List of keys to remove from the state dict, default is empty
remove_keys: []
# Function to merge the models, default is sum. Options are 'sum', 'mean', and 'max'
merge_func: sum
```
Use the following command to run the Ties-Merging algorithm:
```bash
fusion_bench method=ties_merging ...
```

::: fusion_bench.method.TiesMergingAlgorithm
options:
members: true


[^1]: (NIPS 2023) Resolving Interference When Merging Models. http://arxiv.org/abs/2306.01708
3 changes: 3 additions & 0 deletions fusion_bench/method/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from .simple_average import SimpleAverageAlgorithm
from .weighted_average import WeightedAverageAlgorithm
from .task_arithmetic import TaskArithmeticAlgorithm
from .ties_merging.ties_merging import TiesMergingAlgorithm


def load_algorithm_from_config(method_config: DictConfig):
Expand All @@ -15,5 +16,7 @@ def load_algorithm_from_config(method_config: DictConfig):
return WeightedAverageAlgorithm(method_config)
elif method_config.name == "task_arithmetic":
return TaskArithmeticAlgorithm(method_config)
elif method_config.name == "ties_merging":
return TiesMergingAlgorithm(method_config)
else:
raise ValueError(f"Unknown algorithm: {method_config.name}")
1 change: 0 additions & 1 deletion fusion_bench/method/task_arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

from ..utils.state_dict_arithmetic import (
state_dict_add,
state_dict_avg,
state_dict_mul,
state_dict_sub,
)
Expand Down
Empty file.
57 changes: 57 additions & 0 deletions fusion_bench/method/ties_merging/ties_merging.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import logging
from copy import deepcopy
from typing import List, Mapping, Union

import torch
from torch import Tensor, nn

from ...modelpool import ModelPool
from ...utils.type import _StateDict
from ..base_algorithm import ModelFusionAlgorithm
from .ties_merging_utils import state_dict_to_vector, vector_to_state_dict, ties_merging

log = logging.getLogger(__name__)


class TiesMergingAlgorithm(ModelFusionAlgorithm):

@torch.no_grad()
def fuse(self, modelpool: ModelPool):
log.info("Fusing models using ties merging.")
remove_keys = self.config.get("remove_keys", [])
merge_func = self.config.get("merge_func", "sum")
scaling_factor = self.config.scaling_factor
threshold = self.config.threshold

pretrained_model = modelpool.load_model("_pretrained_")

# load the state dicts of the models
ft_checks: List[_StateDict] = [
modelpool.load_model(model_name).state_dict(keep_vars=True)
for model_name in modelpool.model_names
]
ptm_check: _StateDict = pretrained_model.state_dict(keep_vars=True)

# compute the task vectors
flat_ft = torch.vstack(
[state_dict_to_vector(check, remove_keys) for check in ft_checks]
)
flat_ptm = state_dict_to_vector(ptm_check, remove_keys)
tv_flat_checks = flat_ft - flat_ptm

# Ties Merging
merged_tv = ties_merging(
tv_flat_checks,
reset_thresh=threshold,
merge_func=merge_func,
)
merged_check = flat_ptm + scaling_factor * merged_tv
merged_state_dict = vector_to_state_dict(
merged_check, ptm_check, remove_keys=remove_keys
)
merged_state_dict = vector_to_state_dict(
merged_check, ptm_check, remove_keys=remove_keys
)

pretrained_model.load_state_dict(merged_state_dict)
return pretrained_model
211 changes: 211 additions & 0 deletions fusion_bench/method/ties_merging/ties_merging_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
"""
This is copy from https://github.com/EnnengYang/AdaMerging/blob/main/src/ties_merging_utils.py
"""

import sys
import os, copy
import torch
import matplotlib.pyplot as plt
import numpy as np
import re
from collections import OrderedDict
import torch.nn.functional as F

# from transformers import AutoTokenizer, AutoModelForSeq2SeqLM


## Model conversion utils
def state_dict_to_vector(state_dict, remove_keys=[]):
shared_state_dict = copy.deepcopy(state_dict)
for key in remove_keys:
if key in shared_state_dict:
del shared_state_dict[key]
sorted_shared_state_dict = OrderedDict(sorted(shared_state_dict.items()))
return torch.nn.utils.parameters_to_vector(
[value.reshape(-1) for key, value in sorted_shared_state_dict.items()]
)


def vector_to_state_dict(vector, state_dict, remove_keys=[]):
# create a reference dict to define the order of the vector
reference_dict = copy.deepcopy(state_dict)
for key in remove_keys:
if key in reference_dict:
del reference_dict[key]
sorted_reference_dict = OrderedDict(sorted(reference_dict.items()))

# create a shared state dict using the refence dict
torch.nn.utils.vector_to_parameters(vector, sorted_reference_dict.values())

# add back the encoder and decoder embedding weights.
if "transformer.shared.weight" in sorted_reference_dict:
for key in remove_keys:
sorted_reference_dict[key] = sorted_reference_dict[
"transformer.shared.weight"
]
return sorted_reference_dict


def add_ptm_to_tv(tv_dict, ptm_dict):
assert set(tv_dict.keys()) == set(
ptm_dict.keys()
), "Differing parameter names in models."
final_dict = copy.deepcopy(tv_dict)
for k, v in ptm_dict.items():
final_dict[k] = tv_dict[k] + v
return final_dict


def check_parameterNamesMatch(checkpoints):
parameter_names = set(checkpoints[0].keys())

if len(checkpoints) >= 2:
# raise ValueError("Number of models is less than 2.")
for checkpoint in checkpoints[1:]:
current_parameterNames = set(checkpoint.keys())
if current_parameterNames != parameter_names:
raise ValueError(
"Differing parameter names in models. "
f"The different parameters are {parameter_names.symmetric_difference(current_parameterNames)}"
)


def check_state_dicts_equal(state_dict1, state_dict2):
if set(state_dict1.keys()) != set(state_dict2.keys()):
return False

for key in state_dict1.keys():
if not torch.equal(state_dict1[key], state_dict2[key]):
return False

return True


## TIES MERGING UTILS


def topk_values_mask(M, K=0.7, return_mask=False):
if K > 1:
K /= 100

original_shape = M.shape
if M.dim() == 1:
M = M.unsqueeze(0)

n, d = M.shape
k = int(d * K)
k = d - k # Keep top k elements instead of bottom k elements

# Find the k-th smallest element by magnitude for each row
kth_values, _ = M.abs().kthvalue(k, dim=1, keepdim=True)
# Create a mask tensor with True for the top k elements in each row
mask = M.abs() >= kth_values
final_mask = mask.squeeze() if original_shape == M.squeeze().shape else mask

if return_mask:
return M * final_mask, final_mask.float().mean(dim=1), final_mask
return M * final_mask, final_mask.float().mean(dim=1)


def resolve_zero_signs(sign_to_mult, method="majority"):
majority_sign = torch.sign(sign_to_mult.sum())

if method == "majority":
sign_to_mult[sign_to_mult == 0] = majority_sign
elif method == "minority":
sign_to_mult[sign_to_mult == 0] = -1 * majority_sign
return sign_to_mult


def resolve_sign(Tensor):
sign_to_mult = torch.sign(Tensor.sum(dim=0))
sign_to_mult = resolve_zero_signs(sign_to_mult, "majority")
return sign_to_mult


def disjoint_merge(Tensor, merge_func, sign_to_mult):
merge_func = merge_func.split("-")[-1]

# If sign is provided then we select the corresponding entries and aggregate.
if sign_to_mult is not None:
rows_to_keep = torch.where(
sign_to_mult.unsqueeze(0) > 0, Tensor > 0, Tensor < 0
)
selected_entries = Tensor * rows_to_keep
# Else we select all non-zero entries and aggregate.
else:
rows_to_keep = Tensor != 0
selected_entries = Tensor * rows_to_keep

if merge_func == "mean":
non_zero_counts = (selected_entries != 0).sum(dim=0).float()
disjoint_aggs = torch.sum(selected_entries, dim=0) / torch.clamp(
non_zero_counts, min=1
)
elif merge_func == "sum":
disjoint_aggs = torch.sum(selected_entries, dim=0)
elif merge_func == "max":
disjoint_aggs = selected_entries.abs().max(dim=0)[0]
disjoint_aggs *= sign_to_mult
else:
raise ValueError(f"Merge method {merge_func} is not defined.")

return disjoint_aggs


def ties_merging(
flat_task_checks,
reset_thresh=None,
merge_func="",
):
all_checks = flat_task_checks.clone()
updated_checks, *_ = topk_values_mask(all_checks, K=reset_thresh, return_mask=False)
print(f"RESOLVING SIGN")
final_signs = resolve_sign(updated_checks)
assert final_signs is not None

print(f"Disjoint AGGREGATION: {merge_func}")
merged_tv = disjoint_merge(updated_checks, merge_func, final_signs)

return merged_tv


def disjoint_merge_split(Tensor, merge_func, sign_to_mult):
merge_func = merge_func.split("-")[-1]

# If sign is provided then we select the corresponding entries and aggregate.
if sign_to_mult is not None:
rows_to_keep = torch.where(
sign_to_mult.unsqueeze(0) > 0, Tensor > 0, Tensor < 0
)
selected_entries = Tensor * rows_to_keep
# Else we select all non-zero entries and aggregate.
else:
rows_to_keep = Tensor != 0
selected_entries = Tensor * rows_to_keep

if merge_func == "sum":
disjoint_aggs = torch.sum(selected_entries, dim=0)
else:
raise ValueError(f"Merge method {merge_func} is not defined.")

return selected_entries, disjoint_aggs


def ties_merging_split(
flat_task_checks,
reset_thresh=None,
merge_func="",
):
all_checks = flat_task_checks.clone()
updated_checks, *_ = topk_values_mask(all_checks, K=reset_thresh, return_mask=False)
print(f"RESOLVING SIGN")
final_signs = resolve_sign(updated_checks)
assert final_signs is not None

print(f"Disjoint AGGREGATION: {merge_func}")
selected_entries, merged_tv = disjoint_merge_split(
updated_checks, merge_func, final_signs
)

return selected_entries, merged_tv
4 changes: 2 additions & 2 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ nav:
- Introduction: algorithms/index.md
- Dummy: algorithms/dummy.md
- Simple Averaging: algorithms/simple_averaging.md
- Weighted Averaging: algorithms/weighted_averaging
- Weighted Averaging: algorithms/weighted_averaging.md
- Task Arithmetic: algorithms/task_arithmetic.md
- Ties-Merging: algorithms/ties_merging.md
- AdaMerging: algorithms/adamerging.md
Expand All @@ -16,7 +16,7 @@ nav:
- Task Pool:
- Introduction: taskpool/index.md
- Dummy: taskpool/dummy.md
- Classification Tasks for CLIP: taskpool/clip_vit_classification
- Classification Tasks for CLIP: taskpool/clip_vit_classification.md
- Command Line Interface:
- fusion_bench: cli/fusion_bench.md

Expand Down

0 comments on commit d890bbf

Please sign in to comment.