Skip to content

Commit

Permalink
Fix several bugs; remove num_clients; separate sys and algo config
Browse files Browse the repository at this point in the history
  • Loading branch information
tremblerz committed Jul 3, 2024
1 parent fc5a076 commit 5d5adbc
Show file tree
Hide file tree
Showing 26 changed files with 277 additions and 282 deletions.
38 changes: 19 additions & 19 deletions src/algos/DisPFL.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,10 @@ def __init__(self, config) -> None:
self.anneal_factor = self.config["anneal_factor"]
self.dis_gradient_check = self.config["dis_gradient_check"]
self.server_node = 1 # leader node
self.num_clients = config["num_clients"]
self.neighbors = list(range(self.num_clients))
self.num_users = config["num_users"]
self.neighbors = list(range(self.num_users))
if self.node_id == 1:
self.clients = list(range(2, self.num_clients + 1))
self.clients = list(range(2, self.num_users + 1))

def local_train(self):
"""
Expand Down Expand Up @@ -295,13 +295,13 @@ def _benefit_choose(

if cs == "random":
# Random selection of available clients
num_clients = min(client_num_per_round, client_num_in_total)
num_users = min(client_num_per_round, client_num_in_total)
client_indexes = np.random.choice(
range(client_num_in_total), num_clients, replace=False
range(client_num_in_total), num_users, replace=False
)
while cur_clnt in client_indexes:
client_indexes = np.random.choice(
range(client_num_in_total), num_clients, replace=False
range(client_num_in_total), num_users, replace=False
)

elif cs == "ring":
Expand Down Expand Up @@ -332,14 +332,14 @@ def run_protocol(self):
self.params, sparse=self.dense_ratio
) # calculate sparsity to create masks
self.mask = self.init_masks(self.params, sparsities) # mask_per_local
dist_locals = np.zeros(shape=(self.num_clients))
dist_locals = np.zeros(shape=(self.num_users))
self.index = self.node_id - 1
masks_lstrnd = [self.mask for i in range(self.num_clients)]
masks_lstrnd = [self.mask for i in range(self.num_users)]
weights_lstrnd = [
copy.deepcopy(self.get_representation()) for i in range(self.num_clients)
copy.deepcopy(self.get_representation()) for i in range(self.num_users)
]
w_per_globals = [
copy.deepcopy(self.get_representation()) for i in range(self.num_clients)
copy.deepcopy(self.get_representation()) for i in range(self.num_users)
]
for round in range(start_epochs, total_epochs):
# wait for signal to start round
Expand All @@ -364,7 +364,7 @@ def run_protocol(self):
nei_indexs = self._benefit_choose(
round,
self.index,
self.num_clients,
self.num_users,
self.config["neighbors"],
dist_locals,
total_dis,
Expand All @@ -373,7 +373,7 @@ def run_protocol(self):
)
# If not selected in full, the current clint is made up and the
# aggregation operation is performed
if self.num_clients != self.config["neighbors"]:
if self.num_users != self.config["neighbors"]:
# when not active this round
nei_indexs = np.append(nei_indexs, self.index)
print(
Expand Down Expand Up @@ -455,7 +455,7 @@ def __init__(self, config) -> None:
self.config["results_path"], self.node_id
)
self.dense_ratio = self.config["dense_ratio"]
self.num_clients = self.config["num_clients"]
self.num_users = self.config["num_users"]

def get_representation(self) -> OrderedDict[str, Tensor]:
"""
Expand All @@ -467,7 +467,7 @@ def send_representations(self, representations):
"""
Set the model
"""
for client_node in self.clients:
for client_node in self.users:
self.comm_utils.send_signal(client_node, representations, self.tag.UPDATES)
self.log_utils.log_console(
"Server sent {} representations to node {}".format(
Expand All @@ -494,7 +494,7 @@ def single_round(self, round, active_ths_rnd):
"""
Runs the whole training procedure
"""
for client_node in self.clients:
for client_node in self.users:
self.log_utils.log_console(
"Server sending semaphore from {} to {}".format(
self.node_id, client_node
Expand All @@ -511,10 +511,10 @@ def single_round(self, round, active_ths_rnd):
)

self.masks = self.comm_utils.wait_for_all_clients(
self.clients, self.tag.SHARE_MASKS
self.users, self.tag.SHARE_MASKS
)
self.reprs = self.comm_utils.wait_for_all_clients(
self.clients, self.tag.SHARE_WEIGHTS
self.users, self.tag.SHARE_WEIGHTS
)

def get_trainable_params(self):
Expand All @@ -532,13 +532,13 @@ def run_protocol(self):
self.round = round
active_ths_rnd = np.random.choice(
[0, 1],
size=self.num_clients,
size=self.num_users,
p=[1.0 - self.config["active_rate"], self.config["active_rate"]],
)
self.log_utils.log_console("Starting round {}".format(round))

# print("weight:",mask_pers_shared)
self.single_round(round, active_ths_rnd)

accs = self.comm_utils.wait_for_all_clients(self.clients, self.tag.FINISH)
accs = self.comm_utils.wait_for_all_clients(self.users, self.tag.FINISH)
self.log_utils.log_console("Round {} done; acc {}".format(round, accs))
14 changes: 7 additions & 7 deletions src/algos/L2C.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ def __init__(self, config) -> None:
self.sharing_mode = self.config["sharing"]

def init_collab_weights(self):
n = self.config["num_clients"]
# Neighbors id = [1, ..., num_clients]
# Neighbors idx = [0, ..., num_clients - 1]
n = self.config["num_users"]
# Neighbors id = [1, ..., num_users]
# Neighbors idx = [0, ..., num_users - 1]
self.neighbors_id_to_idx = {idx + 1: idx for idx in range(n)}

# TODO Init not specified in the paper
Expand Down Expand Up @@ -206,7 +206,7 @@ def run_protocol(self):

# Collab weights compted in previous round are used in current
# round
cw = np.zeros(self.config["num_clients"])
cw = np.zeros(self.config["num_users"])
for id, idx in self.neighbors_id_to_idx.items():
cw[id - 1] = self.collab_weights[idx]
round_stats = {
Expand Down Expand Up @@ -300,7 +300,7 @@ def single_round(self):
"""

# Send signal to all clients to start local training
for client_node in self.clients:
for client_node in self.users:
self.comm_utils.send_signal(
dest=client_node, data=None, tag=self.tag.ROUND_START
)
Expand All @@ -309,15 +309,15 @@ def single_round(self):
)

# Collect representations (from all clients
reprs = self.comm_utils.wait_for_all_clients(self.clients, self.tag.REPR_ADVERT)
reprs = self.comm_utils.wait_for_all_clients(self.users, self.tag.REPR_ADVERT)
self.log_utils.log_console("Server received all clients models")

# Broadcast the representations to all clients
self.send_representations(reprs)

# Collect round stats from all clients
round_stats = self.comm_utils.wait_for_all_clients(
self.clients, self.tag.ROUND_STATS
self.users, self.tag.ROUND_STATS
)
self.log_utils.log_console("Server received all clients stats")

Expand Down
10 changes: 5 additions & 5 deletions src/algos/MetaL2C.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def __init__(self, config) -> None:
self.model_keys_to_ignore.extend(keys)

self.sharing_mode = self.config["sharing"]
self.neighbors_ids = list(range(1, self.config["num_clients"] + 1))
self.neighbors_ids = list(range(1, self.config["num_users"] + 1))

def get_representation(self):
repr = self.model_utils.substract_model_weights(
Expand Down Expand Up @@ -279,7 +279,7 @@ def run_protocol(self):
# if self.node_id == 1:
# print(list(self.encoder.parameters())[0])

cws = np.zeros(self.config["num_clients"])
cws = np.zeros(self.config["num_users"])
for id, cw in collab_weights_dict.items():
cws[id - 1] = cw
round_stats["collab_weights"] = cws
Expand Down Expand Up @@ -320,7 +320,7 @@ def single_round(self, avg_alpha):
"""

# Send signal to all clients to start local training
for client_node in self.clients:
for client_node in self.users:
self.comm_utils.send_signal(
dest=client_node, data=avg_alpha, tag=self.tag.ROUND_START
)
Expand All @@ -329,15 +329,15 @@ def single_round(self, avg_alpha):
)

# Collect representations (from all clients
reprs = self.comm_utils.wait_for_all_clients(self.clients, self.tag.REPR_ADVERT)
reprs = self.comm_utils.wait_for_all_clients(self.users, self.tag.REPR_ADVERT)
self.log_utils.log_console("Server received all clients models")

# Broadcast the representations to all clients
self.send_representations(reprs)

# Collect round stats from all clients
round_stats_and_alphas = self.comm_utils.wait_for_all_clients(
self.clients, self.tag.ROUND_STATS
self.users, self.tag.ROUND_STATS
)
alphas = [alpha for _, alpha in round_stats_and_alphas]
round_stats = [stats for stats, _ in round_stats_and_alphas]
Expand Down
Loading

0 comments on commit 5d5adbc

Please sign in to comment.