diff --git a/src/algos/DisPFL.py b/src/algos/DisPFL.py index a762418..6c5cc80 100644 --- a/src/algos/DisPFL.py +++ b/src/algos/DisPFL.py @@ -1,19 +1,24 @@ +""" +This module defines the DisPFLClient and DisPFLServer classes for distributed personalized federated learning. +""" + +import copy +import math +import random from collections import OrderedDict from typing import Any, Dict, List -from torch import Tensor, cat, zeros_like, numel, randperm, from_numpy + +import numpy as np import torch import torch.nn as nn -import random -import copy -import numpy as np -import math +from torch import Tensor, from_numpy, numel, randperm, zeros_like from algos.base_class import BaseClient, BaseServer -class CommProtocol(object): +class CommProtocol: """ - Communication protocol tags for the server and clients + Communication protocol tags for the server and clients. """ DONE = 0 # Used to signal the server that the client is done with local training @@ -26,13 +31,14 @@ class CommProtocol(object): class DisPFLClient(BaseClient): + """ + Client class for DisPFL (Distributed Personalized Federated Learning). + """ def __init__(self, config) -> None: super().__init__(config) self.config = config self.tag = CommProtocol - self.model_save_path = "{}/saved_models/node_{}.pt".format( - self.config["results_path"], self.node_id - ) + self.model_save_path = f"{self.config['results_path']}/saved_models/node_{self.node_id}.pt" self.dense_ratio = self.config["dense_ratio"] self.anneal_factor = self.config["anneal_factor"] self.dis_gradient_check = self.config["dis_gradient_check"] @@ -44,24 +50,16 @@ def __init__(self, config) -> None: def local_train(self): """ - Train the model locally - + Train the model locally. """ loss, acc = self.model_utils.train_mask( self.model, self.mask, self.optim, self.dloader, self.loss_fn, self.device ) - - print("Node{} train loss: {}, train acc: {}".format(self.node_id, loss, acc)) - # loss = self.model_utils.train_mask(self.model, self.mask,self.optim, - # self.dloader, self.loss_fn, - # self.device) - - # print("Client {} finished training with loss {}".format(self.node_id, avg_loss)) - # self.log_utils.logger.log_tb(f"train_loss/client{client_num}", avg_loss, epoch) + print(f"Node{self.node_id} train loss: {loss}, train acc: {acc}") def local_test(self, **kwargs): """ - Test the model locally, not to be used in the traditional FedAvg + Test the model locally, not to be used in the traditional FedAvg. """ test_loss, acc = self.model_utils.test( self.model, self._test_loader, self.loss_fn, self.device @@ -80,23 +78,25 @@ def get_trainable_params(self): def get_representation(self) -> OrderedDict[str, Tensor]: """ - Share the model weights + Share the model weights. """ return self.model.state_dict() def set_representation(self, representation: OrderedDict[str, Tensor]): """ - Set the model weights + Set the model weights. """ self.model.load_state_dict(representation) - def fire_mask(self, masks, round, total_round): + def fire_mask(self, masks, round_num, total_round): + """ + Fire mask method for model pruning. 
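+
+        The drop ratio follows the cosine annealing schedule implemented
+        below, so pruning is aggressive in early rounds and tapers off toward
+        the end. As a sketch (``anneal_factor`` comes from the algo config):
+
+            drop_ratio = anneal_factor / 2 * (1 + np.cos(round_num * np.pi / total_round))
+
+        For each layer, roughly ``drop_ratio * num_non_zeros`` of the
+        smallest-magnitude surviving weights have their mask entries reset to 0.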
+ """ weights = self.get_representation() drop_ratio = ( - self.anneal_factor / 2 * (1 + np.cos((round * np.pi) / total_round)) + self.anneal_factor / 2 * (1 + np.cos((round_num * np.pi) / total_round)) ) new_masks = copy.deepcopy(masks) - num_remove = {} for name in masks: num_non_zeros = torch.sum(masks[name]) @@ -111,6 +111,9 @@ def fire_mask(self, masks, round, total_round): return new_masks, num_remove def regrow_mask(self, masks, num_remove, gradient=None): + """ + Regrow mask method for model pruning. + """ new_masks = copy.deepcopy(masks) for name in masks: if not self.dis_gradient_check: @@ -135,15 +138,14 @@ def regrow_mask(self, masks, num_remove, gradient=None): new_masks[name].view(-1)[idx] = 1 return new_masks - def aggregate(self, nei_indexs, weights_lstrnd, masks_lstrnd): + def aggregate(self, nei_indexes, weights_lstrnd, masks_lstrnd): """ - Aggregate the model weights + Aggregate the model weights. """ - # print("len masks:",mask_list) count_mask = copy.deepcopy(masks_lstrnd[self.index]) for k in count_mask.keys(): count_mask[k] = count_mask[k] - count_mask[k] # zero out by pruning - for clnt in nei_indexs: + for clnt in nei_indexes: count_mask[k] += masks_lstrnd[clnt][k].to(self.device) # mask for k in count_mask.keys(): count_mask[k] = np.divide( @@ -171,20 +173,25 @@ def aggregate(self, nei_indexs, weights_lstrnd, masks_lstrnd): def send_representations(self, representation): """ - Set the model + Set the model. """ for client_node in self.clients: self.comm_utils.send_signal(client_node, representation, self.tag.UPDATES) - print("Node 1 sent average weight to {} nodes".format(len(self.clients))) + print(f"Node 1 sent average weight to {len(self.clients)} nodes") - def calculate_sparsities(self, params, tabu=[], distribution="ERK", sparse=0.5): - spasities = {} + def calculate_sparsities(self, params, tabu=None, distribution="ERK", sparse=0.5): + """ + Calculate sparsities for model pruning. + """ + if tabu is None: + tabu = [] + sparsities = {} if distribution == "uniform": for name in params: if name not in tabu: - spasities[name] = 1 - self.dense_ratio + sparsities[name] = 1 - self.dense_ratio else: - spasities[name] = 0 + sparsities[name] = 0 elif distribution == "ERK": print("initialize by ERK") total_params = 0 @@ -219,21 +226,23 @@ def calculate_sparsities(self, params, tabu=[], distribution="ERK", sparse=0.5): is_epsilon_valid = False for mask_name, mask_raw_prob in raw_probabilities.items(): if mask_raw_prob == max_prob: - (f"Sparsity of var:{mask_name} had to be set to 0.") + print(f"Sparsity of var:{mask_name} had to be set to 0.") dense_layers.add(mask_name) else: is_epsilon_valid = True - # With the valid epsilon, we can set sparsities of the remaning - # layers. + # With the valid epsilon, we can set sparsities of the remaining layers. for name in params: if name in dense_layers: - spasities[name] = 0 + sparsities[name] = 0 else: - spasities[name] = 1 - epsilon * raw_probabilities[name] - return spasities + sparsities[name] = 1 - epsilon * raw_probabilities[name] + return sparsities def init_masks(self, params, sparsities): + """ + Initialize masks for model pruning. + """ masks = OrderedDict() for name in params: masks[name] = zeros_like(params[name]) @@ -246,11 +255,12 @@ def init_masks(self, params, sparsities): return masks def screen_gradient(self): + """ + Screen gradient method for model pruning. 
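+
+        Samples a single batch from the local dataloader, runs one forward and
+        backward pass under cross-entropy loss, and returns a dict mapping
+        each parameter name to its gradient (moved to the CPU). The gradients
+        are only used to rank weights for mask regrowth; no optimizer step is
+        taken.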
+ """ model = self.model model.eval() - # # # train and update criterion = nn.CrossEntropyLoss().to(self.device) - # # sample one epoch of data model.zero_grad() (x, labels) = next(iter(self.dloader)) x, labels = x.to(self.device), labels.to(self.device) @@ -264,9 +274,11 @@ def screen_gradient(self): return gradient def hamming_distance(self, mask_a, mask_b): + """ + Calculate the Hamming distance between two masks. + """ dis = 0 total = 0 - for key in mask_a: dis += torch.sum( mask_a[key].int().to(self.device) ^ mask_b[key].int().to(self.device) @@ -285,16 +297,14 @@ def _benefit_choose( cs=False, active_ths_rnd=None, ): + """ + Benefit choose method for client selection. + """ if client_num_in_total == client_num_per_round: - # If one can communicate with all others and there is no bandwidth - # limit - client_indexes = [ - client_index for client_index in range(client_num_in_total) - ] + client_indexes = list(range(client_num_in_total)) return client_indexes if cs == "random": - # Random selection of available clients num_users = min(client_num_per_round, client_num_in_total) client_indexes = np.random.choice( range(client_num_in_total), num_users, replace=False @@ -305,13 +315,11 @@ def _benefit_choose( ) elif cs == "ring": - # Ring Topology in Decentralized setting left = (cur_clnt - 1 + client_num_in_total) % client_num_in_total right = (cur_clnt + 1) % client_num_in_total client_indexes = np.asarray([left, right]) elif cs == "full": - # Fully-connected Topology in Decentralized setting client_indexes = np.array(np.where(active_ths_rnd == 1)).squeeze() client_indexes = np.delete( client_indexes, int(np.where(client_indexes == cur_clnt)[0]) @@ -319,12 +327,18 @@ def _benefit_choose( return client_indexes def model_difference(self, model_a, model_b): - a = sum( + """ + Calculate the difference between two models. + """ + diff = sum( [torch.sum(torch.square(model_a[name] - model_b[name])) for name in model_a] ) - return a + return diff def run_protocol(self): + """ + Runs the entire training protocol. 
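+
+        Per round, the client: waits for the server's START signal and the
+        previous round's weights and masks, picks neighbors via
+        ``_benefit_choose``, aggregates the neighbors' masked weights, trains
+        locally, then (unless ``static`` is set) fires and regrows its mask
+        before reporting weights, mask, and test accuracy back to the server.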
+ """ start_epochs = self.config.get("start_epochs", 0) total_epochs = self.config["epochs"] self.params = self.get_trainable_params() @@ -341,10 +355,10 @@ def run_protocol(self): w_per_globals = [ copy.deepcopy(self.get_representation()) for i in range(self.num_users) ] - for round in range(start_epochs, total_epochs): + for epoch in range(start_epochs, total_epochs): # wait for signal to start round active_ths_rnd = self.comm_utils.wait_for_signal(src=0, tag=self.tag.START) - if round != 0: + if epoch != 0: [weights_lstrnd, masks_lstrnd] = self.comm_utils.wait_for_signal( src=0, tag=self.tag.LAST_ROUND ) @@ -353,16 +367,13 @@ def run_protocol(self): masks_lstrnd[self.index], self.mask ) print( - "Node{}: local mask change {}/{}".format( - self.node_id, dist_locals[self.index], total_dis - ) + f"Node{self.node_id}: local mask change {dist_locals[self.index]}/{total_dis}" ) - # share data with client 1 if active_ths_rnd[self.index] == 0: - nei_indexs = np.array([]) + nei_indexes = np.array([]) else: - nei_indexs = self._benefit_choose( - round, + nei_indexes = self._benefit_choose( + epoch, self.index, self.num_users, self.config["neighbors"], @@ -371,18 +382,13 @@ def run_protocol(self): self.config["cs"], active_ths_rnd, ) - # If not selected in full, the current clint is made up and the - # aggregation operation is performed if self.num_users != self.config["neighbors"]: - # when not active this round - nei_indexs = np.append(nei_indexs, self.index) + nei_indexes = np.append(nei_indexes, self.index) print( - "Node {}'s neighbors index:{}".format( - self.index, [i + 1 for i in nei_indexs] - ) + f"Node {self.index}'s neighbors index:{[i + 1 for i in nei_indexes]}" ) - for tmp_idx in nei_indexs: + for tmp_idx in nei_indexes: if tmp_idx != self.index: dist_locals[tmp_idx], _ = self.hamming_distance( self.mask, masks_lstrnd[tmp_idx] @@ -390,48 +396,41 @@ def run_protocol(self): if self.config["cs"] != "full": print( - "choose client_indexes: {}, accoring to {}".format( - str(nei_indexs), self.config["cs"] - ) + f"choose client_indexes: {str(nei_indexes)}, according to {self.config['cs']}" ) else: print( - "choose client_indexes: {}, accoring to {}".format( - str(nei_indexs), self.config["cs"] - ) + f"choose client_indexes: {str(nei_indexes)}, according to {self.config['cs']}" ) if active_ths_rnd[self.index] != 0: - nei_distances = [dist_locals[i] for i in nei_indexs] + nei_distances = [dist_locals[i] for i in nei_indexes] print("choose mask diff: " + str(nei_distances)) - # calculate new initial model if active_ths_rnd[self.index] == 1: new_repr, w_per_globals[self.index] = self.aggregate( - nei_indexs, weights_lstrnd, masks_lstrnd + nei_indexes, weights_lstrnd, masks_lstrnd ) else: new_repr = copy.deepcopy(weights_lstrnd[self.index]) w_per_globals[self.index] = copy.deepcopy(weights_lstrnd[self.index]) model_diff = self.model_difference(new_repr, self.repr) - print("Node {} model_diff{}".format(self.node_id, model_diff)) + print(f"Node {self.node_id} model_diff{model_diff}") self.comm_utils.send_signal( dest=0, data=copy.deepcopy(self.mask), tag=self.tag.SHARE_MASKS ) self.set_representation(new_repr) - # # locally train - print("Node {} local train".format(self.node_id)) + # locally train + print(f"Node {self.node_id} local train") self.local_train() loss, acc = self.local_test() - print("Node {} local test: {}".format(self.node_id, acc)) + print(f"Node {self.node_id} local test: {acc}") repr = self.get_representation() - # calculate new mask m_k,t+1 - gradient = None if not 
self.config["static"]:
                 if not self.dis_gradient_check:
                     gradient = self.screen_gradient()
+                else:
+                    # preserve the fallback the removed `gradient = None` provided,
+                    # so regrow_mask falls back to random regrowth
+                    gradient = None
-                self.mask, num_remove = self.fire_mask(self.mask, round, total_epochs)
+                self.mask, num_remove = self.fire_mask(self.mask, epoch, total_epochs)
                 self.mask = self.regrow_mask(self.mask, num_remove, gradient)
             self.comm_utils.send_signal(
                 dest=0, data=copy.deepcopy(repr), tag=self.tag.SHARE_WEIGHTS
@@ -440,70 +439,63 @@ def run_protocol(self):
             # test updated model
             self.set_representation(repr)
             loss, acc = self.local_test()
-            # print("Node {} test_loss: {} test_acc:{}".format(self.node_id, loss,acc))
             self.comm_utils.send_signal(dest=0, data=acc, tag=self.tag.FINISH)


 class DisPFLServer(BaseServer):
+    """
+    Server class for DisPFL (Distributed Personalized Federated Learning).
+    """
     def __init__(self, config) -> None:
         super().__init__(config)
-        # self.set_parameters()
         self.config = config
         self.set_model_parameters(config)
         self.tag = CommProtocol
-        self.model_save_path = "{}/saved_models/node_{}.pt".format(
-            self.config["results_path"], self.node_id
-        )
+        self.model_save_path = f"{self.config['results_path']}/saved_models/node_{self.node_id}.pt"
         self.dense_ratio = self.config["dense_ratio"]
         self.num_users = self.config["num_users"]

     def get_representation(self) -> OrderedDict[str, Tensor]:
         """
-        Share the model weights
+        Share the model weights.
         """
         return self.model.state_dict()

     def send_representations(self, representations):
         """
-        Set the model
+        Send the model representations to the clients.
         """
         for client_node in self.users:
             self.comm_utils.send_signal(client_node, representations, self.tag.UPDATES)
             self.log_utils.log_console(
-                "Server sent {} representations to node {}".format(
-                    len(representations), client_node
-                )
+                f"Server sent {len(representations)} representations to node {client_node}"
             )
-        # self.model.load_state_dict(representation)

     def test(self) -> float:
         """
-        Test the model on the server
+        Test the model on the server.
         """
         test_loss, acc = self.model_utils.test(
             self.model, self._test_loader, self.loss_fn, self.device
         )
-        # TODO save the model if the accuracy is better than the best accuracy
-        # so far
+        # TODO save the model if the accuracy is better than the best accuracy so far
         if acc > self.best_acc:
             self.best_acc = acc
             self.model_utils.save_model(self.model, self.model_save_path)
         return acc

-    def single_round(self, round, active_ths_rnd):
+    def single_round(self, epoch, active_ths_rnd):
         """
-        Runs the whole training procedure
+        Runs a single training round.
         """
         for client_node in self.users:
             self.log_utils.log_console(
-                "Server sending semaphore from {} to {}".format(
-                    self.node_id, client_node
-                )
+                f"Server sending semaphore from {self.node_id} to {client_node}"
             )
             self.comm_utils.send_signal(
                 dest=client_node, data=active_ths_rnd, tag=self.tag.START
             )
-            if round != 0:
+            if epoch != 0:
                 self.comm_utils.send_signal(
                     dest=client_node,
                     data=[self.reprs, self.masks],
@@ -524,21 +516,23 @@ def get_trainable_params(self):
         return param_dict

     def run_protocol(self):
+        """
+        Runs the entire training protocol.
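+
+        Each round the server draws an active-client indicator vector and
+        broadcasts it to all clients; a sketch of the sampling used below
+        (``active_rate`` comes from the config):
+
+            active_ths_rnd = np.random.choice(
+                [0, 1], size=num_users, p=[1 - active_rate, active_rate]
+            )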
+ """ self.log_utils.log_console("Starting iid clients federated averaging") start_epochs = self.config.get("start_epochs", 0) total_epochs = self.config["epochs"] - for round in range(start_epochs, total_epochs): - self.round = round + for epoch in range(start_epochs, total_epochs): + self.round = epoch active_ths_rnd = np.random.choice( [0, 1], size=self.num_users, p=[1.0 - self.config["active_rate"], self.config["active_rate"]], ) - self.log_utils.log_console("Starting round {}".format(round)) + self.log_utils.log_console(f"Starting round {epoch}") - # print("weight:",mask_pers_shared) - self.single_round(round, active_ths_rnd) + self.single_round(epoch, active_ths_rnd) accs = self.comm_utils.wait_for_all_clients(self.users, self.tag.FINISH) - self.log_utils.log_console("Round {} done; acc {}".format(round, accs)) + self.log_utils.log_console(f"Round {epoch} done; acc {accs}") diff --git a/src/algos/def_kt.py b/src/algos/def_kt.py index 12201d1..0dcfe47 100644 --- a/src/algos/def_kt.py +++ b/src/algos/def_kt.py @@ -1,14 +1,19 @@ -from collections import OrderedDict -from typing import Any, Dict, List -from torch import Tensor, cat +""" +This module defines the DefKTClient and DefKTServer classes for federated learning using a knowledge +transfer approach. +""" + import copy -import torch.nn as nn import random +from collections import OrderedDict +from typing import List +from torch import Tensor +import torch.nn as nn from algos.base_class import BaseClient, BaseServer -class CommProtocol(object): +class CommProtocol: """ Communication protocol tags for the server and clients """ @@ -20,14 +25,16 @@ class CommProtocol(object): class DefKTClient(BaseClient): + """ + Client class for DefKT (Deep Mutual Learning with Knowledge Transfer) + """ def __init__(self, config) -> None: super().__init__(config) self.config = config self.tag = CommProtocol - self.model_save_path = "{}/saved_models/node_{}.pt".format( - self.config["results_path"], self.node_id - ) + self.model_save_path = f"{self.config['results_path']}/saved_models/node_{self.node_id}.pt" self.server_node = 1 # leader node + self.best_acc = 0.0 # Initialize best accuracy attribute if self.node_id == 1: self.num_users = config["num_users"] self.clients = list(range(2, self.num_users + 1)) @@ -39,8 +46,6 @@ def local_train(self): avg_loss = self.model_utils.train( self.model, self.optim, self.dloader, self.loss_fn, self.device ) - # print("Client {} finished training with loss {}".format(self.node_id, avg_loss)) - # self.log_utils.logger.log_tb(f"train_loss/client{client_num}", avg_loss, epoch) def local_test(self, **kwargs): """ @@ -49,8 +54,6 @@ def local_test(self, **kwargs): test_loss, acc = self.model_utils.test( self.model, self._test_loader, self.loss_fn, self.device ) - # TODO save the model if the accuracy is better than the best accuracy - # so far if acc > self.best_acc: self.best_acc = acc self.model_utils.save_model(self.model, self.model_save_path) @@ -58,11 +61,11 @@ def local_test(self, **kwargs): def deep_mutual_train(self, teacher_repr): """ - Train the model locally + Train the model locally with deep mutual learning """ teacher_model = copy.deepcopy(self.model) teacher_model.load_state_dict(teacher_repr) - print("Deep mutual learning at student Node {}".format(self.node_id)) + print(f"Deep mutual learning at student Node {self.node_id}") avg_loss, acc = self.model_utils.deep_mutual_train( [self.model, teacher_model], self.optim, self.dloader, self.device ) @@ -80,9 +83,9 @@ def set_representation(self, 
representation: OrderedDict[str, Tensor]): self.model.load_state_dict(representation) def fed_avg(self, model_wts: List[OrderedDict[str, Tensor]]): - # All models are sampled currently at every round - # Each model is assumed to have equal amount of data and hence - # coeff is same for everyone + """ + Federated averaging of model weights + """ num_users = len(model_wts) coeff = 1 / num_users avgd_wts = OrderedDict() @@ -106,26 +109,28 @@ def aggregate(self, representation_list: List[OrderedDict[str, Tensor]]): def send_representations(self, representation): """ - Set the model + Send the model representations to the clients """ for client_node in self.clients: self.comm_utils.send_signal(client_node, representation, self.tag.UPDATES) - print("Node 1 sent average weight to {} nodes".format(len(self.clients))) + print(f"Node 1 sent average weight to {len(self.clients)} nodes") def single_round(self, self_repr): """ - Runs the whole training procedure + Runs a single training round """ print("Node 1 waiting for all clients to finish") reprs = self.comm_utils.wait_for_all_clients(self.clients, self.tag.DONE) reprs.append(self_repr) - print("Node 1 received {} clients' weights".format(len(reprs))) + print(f"Node 1 received {len(reprs)} clients' weights") avg_wts = self.aggregate(reprs) self.send_representations(avg_wts) return avg_wts - # wait for all clients to finish def assign_own_status(self, status): + """ + Assign the status (teacher/student) to the client + """ if self.node_id in status[0]: self.status = "teacher" index = status[0].index(self.node_id) @@ -137,81 +142,58 @@ def assign_own_status(self, status): else: self.status = None self.pair_id = None - print( - "Node {} is a {}, pair with {}".format( - self.node_id, self.status, self.pair_id - ) - ) + print(f"Node {self.node_id} is a {self.status}, pair with {self.pair_id}") def run_protocol(self): + """ + Runs the entire training protocol + """ start_epochs = self.config.get("start_epochs", 0) total_epochs = self.config["epochs"] - for round in range(start_epochs, total_epochs): - # self.log_utils.logging.info("Client waiting for semaphore from {}".format(self.server_node)) - # print("Client waiting for semaphore from {}".format(self.server_node)) + for epoch in range(start_epochs, total_epochs): status = self.comm_utils.wait_for_signal(src=0, tag=self.tag.START) self.assign_own_status(status) - # print("semaphore received, start local training") - # self.log_utils.logging.info("Client received semaphore from {}".format(self.server_node)) if self.status == "teacher": self.local_train() - # self.local_test() self_repr = self.get_representation() self.comm_utils.send_signal( dest=self.pair_id, data=self_repr, tag=self.tag.DONE ) - print( - "Node {} sent repr to student node {}".format( - self.node_id, self.pair_id - ) - ) - # self.log_utils.logging.info("Client {} sending done signal to {}".format(self.node_id, self.server_node)) - # print("sending signal to node {}".format(self.server_node)) + print(f"Node {self.node_id} sent repr to student node {self.pair_id}") elif self.status == "student": teacher_repr = self.comm_utils.wait_for_signal( src=self.pair_id, tag=self.tag.DONE ) - print( - "Node {} received repr from teacher node {}".format( - self.node_id, self.pair_id - ) - ) + print(f"Node {self.node_id} received repr from teacher node {self.pair_id}") self.deep_mutual_train(teacher_repr) else: - # self.comm_utils.send_signal(dest=self.server_node, data=self_repr, tag=self.tag.DONE) - print("Node {} do 
nothing".format(self.node_id)) - # repr = self.comm_utils.wait_for_signal(src=self.server_node, tag=self.tag.UPDATES) - - # self.set_representation(repr) - # test updated model + print(f"Node {self.node_id} do nothing") acc = self.local_test() - print("Node {} test_acc:{:.4f}".format(self.node_id, acc)) + print(f"Node {self.node_id} test_acc:{acc:.4f}") self.comm_utils.send_signal(dest=0, data=acc, tag=self.tag.FINISH) class DefKTServer(BaseServer): + """ + Server class for DefKT (Deep Mutual Learning with Knowledge Transfer) + """ def __init__(self, config) -> None: super().__init__(config) - # self.set_parameters() self.config = config self.set_model_parameters(config) self.tag = CommProtocol - self.model_save_path = "{}/saved_models/node_{}.pt".format( - self.config["results_path"], self.node_id - ) + self.model_save_path = f"{self.config['results_path']}/saved_models/node_{self.node_id}.pt" + self.best_acc = 0.0 # Initialize best accuracy attribute def send_representations(self, representations): """ - Set the model + Send the model representations to the clients """ for client_node in self.users: self.comm_utils.send_signal(client_node, representations, self.tag.UPDATES) self.log_utils.log_console( - "Server sent {} representations to node {}".format( - len(representations), client_node - ) + f"Server sent {len(representations)} representations to node {client_node}" ) - # self.model.load_state_dict(representation) def test(self) -> float: """ @@ -220,52 +202,49 @@ def test(self) -> float: test_loss, acc = self.model_utils.test( self.model, self._test_loader, self.loss_fn, self.device ) - # TODO save the model if the accuracy is better than the best accuracy - # so far if acc > self.best_acc: self.best_acc = acc self.model_utils.save_model(self.model, self.model_save_path) return acc def assigns_clients(self): + """ + Assigns clients as teachers and students + """ num_teachers = self.config["num_teachers"] clients = list(range(1, self.num_users + 1)) if 2 * num_teachers > self.num_users: return None # Not enough room to pick two non-overlapping subarrays - - # Pick the starting index of the first subarray selected_indices = random.sample(range(self.num_users), 2 * num_teachers) selected_elements = [clients[i] for i in selected_indices] - - # Divide the selected elements into two arrays of length num_teachers teachers = selected_elements[:num_teachers] students = selected_elements[num_teachers:] - return teachers, students def single_round(self): """ - Runs the whole training procedure + Runs a single training round """ teachers, students = self.assigns_clients() # type: ignore - self.log_utils.log_console("Teachers:{}".format(teachers)) - self.log_utils.log_console("Students:{}".format(students)) + self.log_utils.log_console(f"Teachers:{teachers}") + self.log_utils.log_console(f"Students:{students}") for client_node in self.users: self.log_utils.log_console( - "Server sending status from {} to {}".format(self.node_id, client_node) + f"Server sending status from {self.node_id} to {client_node}" ) self.comm_utils.send_signal( dest=client_node, data=[teachers, students], tag=self.tag.START ) def run_protocol(self): + """ + Runs the entire training protocol + """ self.log_utils.log_console("Starting iid clients federated averaging") start_epochs = self.config.get("start_epochs", 0) total_epochs = self.config["epochs"] - for round in range(start_epochs, total_epochs): - self.round = round - self.log_utils.log_console("Starting round {}".format(round)) + for epoch in range(start_epochs, 
total_epochs):
+            self.log_utils.log_console(f"Starting round {epoch}")
             self.single_round()
             accs = self.comm_utils.wait_for_all_clients(self.users, self.tag.FINISH)
-            self.log_utils.log_console("Round {} done; acc {}".format(round, accs))
+            self.log_utils.log_console(f"Round {epoch} done; acc {accs}")
diff --git a/src/algos/fedfomo.py b/src/algos/fedfomo.py
index 3de1984..bd106cf 100644
--- a/src/algos/fedfomo.py
+++ b/src/algos/fedfomo.py
@@ -1,6 +1,10 @@
+"""
+Module for FedFomo algorithm.
+"""
+
 from collections import OrderedDict
-from typing import Any, Dict, List
-from torch import Tensor, cat, zeros_like, numel, randperm, from_numpy
+from typing import List
+from torch import Tensor
 import torch
 import torch.nn as nn
 import random
@@ -11,7 +15,7 @@
 from algos.base_class import BaseClient, BaseServer


-class CommProtocol(object):
+class CommProtocol:
     """
     Communication protocol tags for the server and clients
     """
@@ -26,74 +30,72 @@ class CommProtocol(object):


 class FedFomoClient(BaseClient):
+    """
+    Client class for FedFomo algorithm.
+    """
+
     def __init__(self, config) -> None:
         super().__init__(config)
         self.config = config
         self.tag = CommProtocol
-        self.model_save_path = "{}/saved_models/node_{}.pt".format(
-            self.config["results_path"], self.node_id
-        )
+        self.model_save_path = f"{self.config['results_path']}/saved_models/node_{self.node_id}.pt"
         self.dense_ratio = self.config["dense_ratio"]
         self.anneal_factor = self.config["anneal_factor"]
         self.dis_gradient_check = self.config["dis_gradient_check"]
         self.server_node = 1  # leader node
         self.num_users = config["num_users"]
         self.neighbors = list(range(self.num_users))
+        self.mask = None
+        self.params = None
+        self.index = self.node_id - 1
+        self.repr = None
         if self.node_id == 1:
             self.clients = list(range(2, self.num_users + 1))

     def local_train(self):
         """
-        Train the model locally
-
+        Train the model locally.
         """
         loss, acc = self.model_utils.train_mask(
             self.model, self.mask, self.optim, self.dloader, self.loss_fn, self.device
         )
-
-        print("Node{} train loss: {}, train acc: {}".format(self.node_id, loss, acc))
-        # loss = self.model_utils.train_mask(self.model, self.mask,self.optim,
-        #                                    self.dloader, self.loss_fn,
-        #                                    self.device)
-
-        # print("Client {} finished training with loss {}".format(self.node_id, avg_loss))
-        # self.log_utils.logger.log_tb(f"train_loss/client{client_num}", avg_loss, epoch)
+        print(f"Node{self.node_id} train loss: {loss}, train acc: {acc}")

     def local_test(self, **kwargs):
         """
-        Test the model locally, not to be used in the traditional FedAvg
+        Test the model locally.
         """
         test_loss, acc = self.model_utils.test(
             self.model, self._test_loader, self.loss_fn, self.device
         )
-        # TODO save the model if the accuracy is better than the best accuracy so far
-        # if acc > self.best_acc:
-        #     self.best_acc = acc
-        #     self.model_utils.save_model(self.model, self.model_save_path)
         return test_loss, acc

     def get_trainable_params(self):
-        param_dict = {}
-        for name, param in self.model.named_parameters():
-            param_dict[name] = param
+        """
+        Get trainable parameters.
+        """
+        param_dict = {name: param for name, param in self.model.named_parameters()}
         return param_dict

     def get_representation(self) -> OrderedDict[str, Tensor]:
         """
-        Share the model weights
+        Share the model weights.
         """
         return self.model.state_dict()

     def set_representation(self, representation: OrderedDict[str, Tensor]):
         """
-        Set the model weights
+        Set the model weights.
""" self.model.load_state_dict(representation) - def fire_mask(self, masks, round, total_round): + def fire_mask(self, masks, epoch, total_epochs): + """ + Fire mask to prune the model. + """ weights = self.get_representation() drop_ratio = ( - self.anneal_factor / 2 * (1 + np.cos((round * np.pi) / total_round)) + self.anneal_factor / 2 * (1 + np.cos((epoch * np.pi) / total_epochs)) ) new_masks = copy.deepcopy(masks) @@ -106,11 +108,14 @@ def fire_mask(self, masks, round, total_round): torch.abs(weights[name]), 100000 * torch.ones_like(weights[name]), ) - x, idx = torch.sort(temp_weights.view(-1).to(self.device)) + _, idx = torch.sort(temp_weights.view(-1).to(self.device)) new_masks[name].view(-1)[idx[: num_remove[name]]] = 0 return new_masks, num_remove def regrow_mask(self, masks, num_remove, gradient=None): + """ + Regrow mask after pruning. + """ new_masks = copy.deepcopy(masks) for name in masks: if not self.dis_gradient_check: @@ -119,7 +124,7 @@ def regrow_mask(self, masks, num_remove, gradient=None): torch.abs(gradient[name]).to(self.device), -100000 * torch.ones_like(gradient[name]).to(self.device), ) - sort_temp, idx = torch.sort( + _, idx = torch.sort( temp.view(-1).to(self.device), descending=True ) new_masks[name].view(-1)[idx[: num_remove[name]]] = 1 @@ -146,7 +151,7 @@ def aggregate( w_local_mdl, ): """ - Aggregate the model weights + Aggregate the model weights. """ w_easy = copy.deepcopy(weights_local[nei_indexs]) w_easy = np.maximum(w_easy, 0) @@ -175,34 +180,34 @@ def aggregate( def send_representations(self, representation): """ - Set the model + Send the model representations to clients. """ for client_node in self.clients: self.comm_utils.send_signal(client_node, representation, self.tag.UPDATES) - print("Node 1 sent average weight to {} nodes".format(len(self.clients))) + print(f"Node 1 sent average weight to {len(self.clients)} nodes") def screen_gradient(self): + """ + Screen gradient method for model pruning. + """ model = self.model model.eval() - # # # train and update criterion = nn.CrossEntropyLoss().to(self.device) - # # sample one epoch of data model.zero_grad() (x, labels) = next(iter(self.dloader)) x, labels = x.to(self.device), labels.to(self.device) log_probs = model.forward(x) loss = criterion(log_probs, labels.long()) loss.backward() - gradient = {} - for name, param in self.model.named_parameters(): - gradient[name] = param.grad.to("cpu") - + gradient = {name: param.grad.to("cpu") for name, param in self.model.named_parameters()} return gradient def hamming_distance(self, mask_a, mask_b): + """ + Calculate the Hamming distance between two masks. + """ dis = 0 total = 0 - for key in mask_a: dis += torch.sum( mask_a[key].int().to(self.device) ^ mask_b[key].int().to(self.device) @@ -213,10 +218,11 @@ def hamming_distance(self, mask_a, mask_b): def benefit_choose( self, round_idx, cur_clnt, client_num_in_total, client_num_per_round, p_choose ): + """ + Benefit choose method for client selection. 
+ """ if client_num_in_total == client_num_per_round: - client_indexes = [ - client_index for client_index in range(client_num_in_total) - ] + client_indexes = list(range(client_num_in_total)) else: num_users = min(client_num_per_round, client_num_in_total) p_choose[cur_clnt] = 0 @@ -231,18 +237,23 @@ def benefit_choose( range(client_num_in_total), num_users, replace=False ) - # self.logger.info("client_indexes = %s" % str(client_indexes)) return client_indexes def model_difference(self, model_a, model_b): - a = sum( + """ + Calculate the difference between two models. + """ + diff = sum( [torch.sum(torch.square(model_a[name] - model_b[name])) for name in model_a] ) - return a + return diff def update_weight( self, curr_idx, nei_indexs, w_per_mdls_lstrd, weight_local, w_local ): + """ + Update the weights for the clients. + """ client = self.client_list[curr_idx] metrics = client.val_test( w_per_mdls_lstrd[curr_idx], self.val_data_local_dict[curr_idx] @@ -282,6 +293,9 @@ def update_weight( return weight_local def run_protocol(self): + """ + Run the entire protocol for FedFomoClient. + """ start_epochs = self.config.get("start_epochs", 0) total_epochs = self.config["epochs"] self.params = self.get_trainable_params() @@ -289,37 +303,30 @@ def run_protocol(self): weights_locals = np.full((self.num_users), 1.0 / self.num_users) p_choose_locals = np.ones(shape=(self.num_users)) reprs_lstrnd = [ - copy.deepcopy(self.get_representation()) for i in range(self.num_users) + copy.deepcopy(self.get_representation()) for _ in range(self.num_users) ] repr_per_global = [ - copy.deepcopy(self.get_representation()) for i in range(self.num_users) + copy.deepcopy(self.get_representation()) for _ in range(self.num_users) ] - for round in range(start_epochs, total_epochs): - # wait for signal to start round - if round != 0: - [reprs_lstrnd, masks_lstrnd] = self.comm_utils.wait_for_signal( + for epoch in range(start_epochs, total_epochs): + if epoch != 0: + [reprs_lstrnd, _] = self.comm_utils.wait_for_signal( src=0, tag=self.tag.LAST_ROUND ) self.local_train() self.repr = self.get_representation() - # share data with client 1 nei_indexs = self.benefit_choose( - round, + epoch, self.index, self.num_users, self.config["neighbors"], p_choose_locals[self.index], ) - # If not selected in full, the current clint is made up and the - # aggregation operation is performed if self.num_users != self.config["neighbors"]: - # when not active this round nei_indexs = np.append(nei_indexs, self.index) nei_indexs = np.sort(nei_indexs) print( - "Node {}'s neighbors index:{}".format( - self.index, [i + 1 for i in nei_indexs] - ) + f"Node {self.index}'s neighbors index: {[i + 1 for i in nei_indexs]}" ) weights_locals = self.update_weight( @@ -344,70 +351,65 @@ def run_protocol(self): self.set_representation(new_repr) loss, acc = self.local_test() - # print("Node {} test_loss: {} test_acc:{}".format(self.node_id, loss,acc)) self.comm_utils.send_signal(dest=0, data=acc, tag=self.tag.FINISH) class FedFomoServer(BaseServer): + """ + Server class for FedFomo algorithm. 
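+
+    The server is a coordinator rather than an aggregator: it broadcasts the
+    active-client vector each round, relays the previous round's weights and
+    masks, and collects per-round test accuracies from the clients.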
+ """ + def __init__(self, config) -> None: super().__init__(config) - # self.set_parameters() self.config = config self.set_model_parameters(config) self.tag = CommProtocol - self.model_save_path = "{}/saved_models/node_{}.pt".format( - self.config["results_path"], self.node_id - ) + self.model_save_path = f"{self.config['results_path']}/saved_models/node_{self.node_id}.pt" self.dense_ratio = self.config["dense_ratio"] self.num_users = self.config["num_users"] + self.reprs = None + self.masks = None def get_representation(self) -> OrderedDict[str, Tensor]: """ - Share the model weights + Share the model weights. """ return self.model.state_dict() def send_representations(self, representations): """ - Set the model + Set the model representations. """ for client_node in self.users: self.comm_utils.send_signal(client_node, representations, self.tag.UPDATES) self.log_utils.log_console( - "Server sent {} representations to node {}".format( - len(representations), client_node - ) + f"Server sent {len(representations)} representations to node {client_node}" ) - # self.model.load_state_dict(representation) def test(self) -> float: """ - Test the model on the server + Test the model on the server. """ test_loss, acc = self.model_utils.test( self.model, self._test_loader, self.loss_fn, self.device ) - # TODO save the model if the accuracy is better than the best accuracy - # so far if acc > self.best_acc: self.best_acc = acc self.model_utils.save_model(self.model, self.model_save_path) return acc - def single_round(self, round, active_ths_rnd): + def single_round(self, epoch, active_ths_rnd): """ - Runs the whole training procedure + Runs the whole training procedure for a single round. """ for client_node in self.users: self.log_utils.log_console( - "Server sending semaphore from {} to {}".format( - self.node_id, client_node - ) + f"Server sending semaphore from {self.node_id} to {client_node}" ) self.comm_utils.send_signal( dest=client_node, data=active_ths_rnd, tag=self.tag.START ) - if round != 0: + if epoch != 0: self.comm_utils.send_signal( dest=client_node, data=[self.reprs, self.masks], @@ -422,27 +424,30 @@ def single_round(self, round, active_ths_rnd): ) def get_trainable_params(self): - param_dict = {} - for name, param in self.model.named_parameters(): - param_dict[name] = param + """ + Get trainable parameters. + """ + param_dict = {name: param for name, param in self.model.named_parameters()} return param_dict def run_protocol(self): + """ + Run the entire protocol for FedFomoServer. 
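+
+        Per round: draw the active-client vector (Bernoulli with
+        ``active_rate``), run ``single_round`` to fan signals and state out to
+        the clients, then block on every client's FINISH accuracy.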
+ """ self.log_utils.log_console("Starting iid clients federated averaging") start_epochs = self.config.get("start_epochs", 0) total_epochs = self.config["epochs"] - for round in range(start_epochs, total_epochs): - self.round = round + for epoch in range(start_epochs, total_epochs): + self.round = epoch active_ths_rnd = np.random.choice( [0, 1], size=self.num_users, p=[1.0 - self.config["active_rate"], self.config["active_rate"]], ) - self.log_utils.log_console("Starting round {}".format(round)) + self.log_utils.log_console(f"Starting round {epoch}") - # print("weight:",mask_pers_shared) - self.single_round(round, active_ths_rnd) + self.single_round(epoch, active_ths_rnd) accs = self.comm_utils.wait_for_all_clients(self.users, self.tag.FINISH) - self.log_utils.log_console("Round {} done; acc {}".format(round, accs)) + self.log_utils.log_console(f"Round {epoch} done; acc {accs}") diff --git a/src/algos/fl_static.py b/src/algos/fl_static.py index 2e12d1c..15d7fe7 100644 --- a/src/algos/fl_static.py +++ b/src/algos/fl_static.py @@ -1,13 +1,14 @@ -from collections import OrderedDict +""" +Module for FedStaticClient and FedStaticServer in Federated Learning. +""" + +from collections import defaultdict from typing import Any, Dict, List +import numpy as np import torch import torch.nn as nn -import random -import numpy as np from algos.base_class import BaseFedAvgClient, BaseFedAvgServer - -from collections import defaultdict from utils.stats_utils import from_round_stats_per_round_per_client_to_dict_arrays from algos.fl_ring import RingTopology from algos.fl_grid import GridTopology @@ -16,107 +17,112 @@ class FedStaticClient(BaseFedAvgClient): - def __init__(self, config) -> None: + """ + Federated Static Client Class. + """ + def __init__(self, config: Dict[str, Any]) -> None: super().__init__(config) - def get_collaborator_weights(self, reprs_dict, round): + def get_collaborator_weights(self, reprs_dict: Dict[int, Any], rnd: int) -> Dict[int, float]: """ - Returns the weights of the collaborators for the current round + Returns the weights of the collaborators for the current round. 
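+
+        The within-community sampling probability may be decayed over rounds
+        (``_decay_within_sampling``), the node's own aggregation weight may
+        follow a warm-up strategy (``_apply_aggr_weight_strategy``), and the
+        remaining weight mass is split evenly across the other selected
+        collaborators.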
""" total_rounds = self.config["rounds"] within_community_sampling = self.config.get("within_community_sampling", 1) p_within_decay = self.config.get("p_within_decay", None) + if p_within_decay is not None: - if p_within_decay == "linear_inc": - within_community_sampling = within_community_sampling * ( - round / total_rounds - ) - elif p_within_decay == "linear_dec": - within_community_sampling = within_community_sampling * ( - 1 - round / total_rounds - ) - elif p_within_decay == "exp_inc": - # Alpha scaled so that it goes from p to (1-p) in R rounds - alpha = np.log( - (1 - within_community_sampling) / within_community_sampling - ) - within_community_sampling = within_community_sampling * np.exp( - alpha * round / total_rounds - ) - elif p_within_decay == "exp_dec": - # Alpha scaled so that it goes from p to (1-p) in R rounds - alpha = np.log( - within_community_sampling / (1 - within_community_sampling) - ) - within_community_sampling = within_community_sampling * np.exp( - -alpha * round / total_rounds - ) - elif p_within_decay == "log_inc": - alpha = np.exp(1 / within_community_sampling) - 1 - within_community_sampling = within_community_sampling * np.log2( - 1 + alpha * round / total_rounds - ) + within_community_sampling = self._decay_within_sampling( + p_within_decay, within_community_sampling, rnd, total_rounds + ) algo = self.config["algo"] + selected_ids = self._select_ids_based_on_algo(algo) + + collab_weights = defaultdict(lambda: 0.0) + for idx in selected_ids: + own_aggr_weight = self.config.get("own_aggr_weight", 1 / len(selected_ids)) + own_aggr_weight = self._apply_aggr_weight_strategy( + own_aggr_weight, rnd, total_rounds + ) + + collab_weights[idx] = self._calculate_collab_weight(idx, own_aggr_weight, selected_ids) + + return collab_weights + + def _decay_within_sampling(self, strategy: str, p: float, rnd: int, total_rounds: int) -> float: + """ + Applies the within-community sampling decay strategy. + """ + if strategy == "linear_inc": + p *= (rnd / total_rounds) + elif strategy == "linear_dec": + p *= (1 - rnd / total_rounds) + elif strategy == "exp_inc": + alpha = np.log((1 - p) / p) + p *= np.exp(alpha * rnd / total_rounds) + elif strategy == "exp_dec": + alpha = np.log(p / (1 - p)) + p *= np.exp(-alpha * rnd / total_rounds) + elif strategy == "log_inc": + alpha = np.exp(1 / p) - 1 + p *= np.log2(1 + alpha * rnd / total_rounds) + return p + + def _select_ids_based_on_algo(self, algo: str) -> List[int]: + """ + Selects IDs based on the specified algorithm. 
+ """ if algo == "random": topology = RandomTopology() - selected_ids = topology.get_selected_ids( - self.node_id, self.config, self.reprs_dict, self.communities - ) - elif algo == "ring": + return topology.get_selected_ids(self.node_id, self.config, self.reprs_dict, self.communities) + if algo == "ring": topology = RingTopology() - selected_ids = topology.get_selected_ids(self.node_id, self.config) - - elif algo == "grid": + return topology.get_selected_ids(self.node_id, self.config) + if algo == "grid": topology = GridTopology() - selected_ids = topology.get_selected_ids(self.node_id, self.config) - - elif algo == "torus": + return topology.get_selected_ids(self.node_id, self.config) + if algo == "torus": topology = TorusTopology() - selected_ids = topology.get_selected_ids(self.node_id, self.config) + return topology.get_selected_ids(self.node_id, self.config) + return [] - collab_weights = defaultdict(lambda: 0.0) - for idx in selected_ids: - own_aggr_weight = self.config.get("own_aggr_weight", 1 / len(selected_ids)) - - aggr_weight_strategy = self.config.get("aggr_weight_strategy", None) - if aggr_weight_strategy is not None: - init_weight = 0.1 - target_weight = 0.5 - if aggr_weight_strategy == "linear": - target_round = total_rounds // 2 - own_aggr_weight = 1 - ( - init_weight - + (target_weight - init_weight) * (min(1, round / target_round)) - ) - elif aggr_weight_strategy == "log": - alpha = 0.05 - own_aggr_weight = 1 - ( - init_weight - + (target_weight - init_weight) - * ( - np.log(alpha * (round / total_rounds) + 1) - / np.log(alpha + 1) - ) - ) - else: - raise ValueError( - f"Aggregation weight strategy {aggr_weight_strategy} not implemented" - ) - - if self.node_id == 1 and idx == 1: - print(f"Collaborator {idx} weight: {own_aggr_weight}") - if idx == self.node_id: - collab_weights[idx] = own_aggr_weight + def _apply_aggr_weight_strategy(self, weight: float, rnd: int, total_rounds: int) -> float: + """ + Applies the aggregation weight strategy. + """ + strategy = self.config.get("aggr_weight_strategy", None) + if strategy is not None: + init_weight = 0.1 + target_weight = 0.5 + if strategy == "linear": + target_round = total_rounds // 2 + weight = 1 - (init_weight + (target_weight - init_weight) * (min(1, rnd / target_round))) + elif strategy == "log": + alpha = 0.05 + weight = 1 - (init_weight + (target_weight - init_weight) * (np.log(alpha * (rnd / total_rounds) + 1) / np.log(alpha + 1))) else: - collab_weights[idx] = (1 - own_aggr_weight) / (len(selected_ids) - 1) + raise ValueError(f"Aggregation weight strategy {strategy} not implemented") + return weight - return collab_weights + def _calculate_collab_weight(self, idx: int, own_aggr_weight: float, selected_ids: List[int]) -> float: + """ + Calculates the collaborator weight. + """ + if idx == self.node_id: + return own_aggr_weight + return (1 - own_aggr_weight) / (len(selected_ids) - 1) - def get_representation(self): + def get_representation(self) -> Dict[str, torch.Tensor]: + """ + Returns the model weights as representation. + """ return self.get_model_weights() - def mask_last_layer(self): + def mask_last_layer(self) -> None: + """ + Masks the last layer of the model. 
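+
+        Zeroes the final classification layer's weight matrix except for the
+        rows in ``self.classes_of_interest``, then loads the masked tensor
+        back with ``strict=False`` so all other layers stay untouched.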
+ """ wts = self.get_model_weights() keys = self.model_utils.get_last_layer_keys(wts) key = [k for k in keys if "weight" in k][0] @@ -124,7 +130,10 @@ def mask_last_layer(self): weight[self.classes_of_interest] = wts[key][self.classes_of_interest] self.model.load_state_dict({key: weight.to(self.device)}, strict=False) - def freeze_model_except_last_layer(self): + def freeze_model_except_last_layer(self) -> None: + """ + Freezes the model parameters except for the last layer. + """ wts = self.get_model_weights() keys = self.model_utils.get_last_layer_keys(wts) @@ -132,96 +141,73 @@ def freeze_model_except_last_layer(self): if name not in keys: param.requires_grad = False - def unfreeze_model(self): + def unfreeze_model(self) -> None: + """ + Unfreezes all model parameters. + """ for param in self.model.parameters(): param.requires_grad = True - def flatten_repr(self, repr): - params = [] - - for key in repr.keys(): - params.append(repr[key].view(-1)) - - params = torch.cat(params) - - return params + def flatten_repr(self, repr_dict: Dict[str, torch.Tensor]) -> torch.Tensor: + """ + Flattens the representation dictionary into a single tensor. + """ + params = [repr_dict[key].view(-1) for key in repr_dict.keys()] + return torch.cat(params) - def compute_pseudo_grad_norm(self, prev_wts, new_wts): + def compute_pseudo_grad_norm(self, prev_wts: Dict[str, torch.Tensor], new_wts: Dict[str, torch.Tensor]) -> float: + """ + Computes the pseudo gradient norm. + """ return np.linalg.norm(self.flatten_repr(prev_wts) - self.flatten_repr(new_wts)) - def run_protocol(self): + def run_protocol(self) -> None: + """ + Runs the federated learning protocol for the client. + """ print(f"Client {self.node_id} ready to start training") start_round = self.config.get("start_round", 0) if start_round != 0: - raise NotImplementedError( - "Start round different from 0 not implemented yet" - ) + raise NotImplementedError("Start round different from 0 not implemented yet") total_rounds = self.config["rounds"] epochs_per_round = self.config["epochs_per_round"] - for round in range(start_round, total_rounds): + for rnd in range(start_round, total_rounds): stats = {} # Wait on server to start the round - self.comm_utils.wait_for_signal( - src=self.server_node, tag=self.tag.ROUND_START - ) + self.comm_utils.wait_for_signal(src=self.server_node, tag=self.tag.ROUND_START) if self.config.get("finetune_last_layer", False): self.freeze_model_except_last_layer() # Train locally and send the representation to the server if not self.config.get("local_train_after_aggr", False): - stats["train_loss"], stats["train_acc"] = self.local_train( - epochs_per_round - ) + stats["train_loss"], stats["train_acc"] = self.local_train(epochs_per_round) - repr = self.get_representation() - self.comm_utils.send_signal( - dest=self.server_node, data=repr, tag=self.tag.REPR_ADVERT - ) + repr_dict = self.get_representation() + self.comm_utils.send_signal(dest=self.server_node, data=repr_dict, tag=self.tag.REPR_ADVERT) # Collect the representations from all other nodes from the server - reprs = self.comm_utils.wait_for_signal( - src=self.server_node, tag=self.tag.REPRS_SHARE - ) - - # In the future this dict might be generated by the server to send - # only requested models + reprs = self.comm_utils.wait_for_signal(src=self.server_node, tag=self.tag.REPRS_SHARE) reprs_dict = {k: v for k, v in enumerate(reprs, 1)} - # Aggregate the representations based on the collab weights - collab_weights_dict = self.get_collaborator_weights(reprs_dict, round) - - # 
Since clients representations are also used to transmit knowledge - # There is no need to fetch the server for the selected clients' - # knowledge + # Aggregate the representations based on the collaborator weights + collab_weights_dict = self.get_collaborator_weights(reprs_dict, rnd) models_wts = reprs_dict layers_to_ignore = self.model_keys_to_ignore - active_collab = set([k for k, v in collab_weights_dict.items() if v > 0]) + active_collab = {k for k, v in collab_weights_dict.items() if v > 0} inter_commu_last_layer_to_aggr = self.config.get("inter_commu_layer", None) - # If partial merging is on and some client selected client is - # outside the community, ignore layers after specified layer if inter_commu_last_layer_to_aggr is not None and len( set(self.communities[self.node_id]).intersection(active_collab) ) != len(active_collab): - layer_idx = self.model_utils.models_layers_idx[self.config["model"]][ - inter_commu_last_layer_to_aggr - ] - layers_to_ignore = ( - self.model_keys_to_ignore - + list(list(models_wts.values())[0].keys())[layer_idx + 1 :] - ) - - avg_wts = self.weighted_aggregate( - models_wts, collab_weights_dict, keys_to_ignore=layers_to_ignore - ) + layer_idx = self.model_utils.models_layers_idx[self.config["model"]][inter_commu_last_layer_to_aggr] + layers_to_ignore = self.model_keys_to_ignore + list(list(models_wts.values())[0].keys())[layer_idx + 1:] - # Average whole model by default + avg_wts = self.weighted_aggregate(models_wts, collab_weights_dict, keys_to_ignore=layers_to_ignore) self.set_model_weights(avg_wts, layers_to_ignore) if self.config.get("train_only_fc", False): - self.mask_last_layer() self.freeze_model_except_last_layer() self.local_train(1) @@ -231,88 +217,56 @@ def run_protocol(self): # Train locally and send the representation to the server if self.config.get("local_train_after_aggr", False): - prev_wts = self.get_model_weights() - stats["train_loss"], stats["train_acc"] = self.local_train( - epochs_per_round - ) + stats["train_loss"], stats["train_acc"] = self.local_train(epochs_per_round) new_wts = self.get_model_weights() - - stats["pseudo grad norm"] = self.compute_pseudo_grad_norm( - prev_wts, new_wts - ) - - # Test updated model + stats["pseudo grad norm"] = self.compute_pseudo_grad_norm(prev_wts, new_wts) stats["test_acc_after_training"] = self.local_test() - # Include collab weights in the stats collab_weight = np.zeros(self.config["num_users"]) for k, v in collab_weights_dict.items(): collab_weight[k - 1] = v stats["Collaborator weights"] = collab_weight - self.comm_utils.send_signal( - dest=self.server_node, data=stats, tag=self.tag.ROUND_STATS - ) + self.comm_utils.send_signal(dest=self.server_node, data=stats, tag=self.tag.ROUND_STATS) class FedStaticServer(BaseFedAvgServer): - def __init__(self, config) -> None: + """ + Federated Static Server Class. + """ + def __init__(self, config: Dict[str, Any]) -> None: super().__init__(config) - # self.set_parameters() self.config = config self.set_model_parameters(config) - self.model_save_path = "{}/saved_models/node_{}.pt".format( - self.config["results_path"], self.node_id - ) + self.model_save_path = f"{self.config['results_path']}/saved_models/node_{self.node_id}.pt" def test(self) -> float: """ - Test the model on the server + Test the model on the server. 
""" - test_loss, acc = self.model_utils.test( - self.model, self._test_loader, self.loss_fn, self.device - ) - # TODO save the model if the accuracy is better than the best accuracy - # so far + _, acc = self.model_utils.test(self.model, self._test_loader, self.loss_fn, self.device) if acc > self.best_acc: self.best_acc = acc self.model_utils.save_model(self.model, self.model_save_path) return acc - def single_round(self): + def single_round(self) -> List[Dict[str, Any]]: """ - Runs the whole training procedure + Runs the whole training procedure for a single round. """ - - # Send signal to all clients to start local training for client_node in self.users: - self.comm_utils.send_signal( - dest=client_node, data=None, tag=self.tag.ROUND_START - ) - self.log_utils.log_console( - "Server waiting for all clients to finish local training" - ) + self.comm_utils.send_signal(dest=client_node, data=None, tag=self.tag.ROUND_START) + self.log_utils.log_console("Server waiting for all clients to finish local training") - # Collect models from all clients - models = self.comm_utils.wait_for_all_clients( - self.users, self.tag.REPR_ADVERT - ) + models = self.comm_utils.wait_for_all_clients(self.users, self.tag.REPR_ADVERT) self.log_utils.log_console("Server received all clients models") - # Broadcast the models to all clients self.send_representations(models) - - # Collect round stats from all clients - clients_round_stats = self.comm_utils.wait_for_all_clients( - self.users, self.tag.ROUND_STATS - ) + clients_round_stats = self.comm_utils.wait_for_all_clients(self.users, self.tag.ROUND_STATS) self.log_utils.log_console("Server received all clients stats") - # Log the round stats on tensorboard except the collab weights - self.log_utils.log_tb_round_stats( - clients_round_stats, ["Collaborator weights"], self.round - ) + self.log_utils.log_tb_round_stats(clients_round_stats, ["Collaborator weights"], self.round) self.log_utils.log_console( f"Round test acc before local training {[stats['test_acc_before_training'] for stats in clients_round_stats]}" @@ -323,17 +277,18 @@ def single_round(self): return clients_round_stats - def run_protocol(self): + def run_protocol(self) -> None: + """ + Runs the federated learning protocol for the server. + """ self.log_utils.log_console("Starting static ring P2P collaboration") start_round = self.config.get("start_round", 0) total_round = self.config["rounds"] - # List of list stats per round stats = [] - for round in range(start_round, total_round): - self.round = round - self.log_utils.log_console("Starting round {}".format(round)) - + for rnd in range(start_round, total_round): + self.round = rnd + self.log_utils.log_console(f"Starting round {rnd}") round_stats = self.single_round() stats.append(round_stats) diff --git a/src/main.py b/src/main.py index b00431d..f4952fc 100644 --- a/src/main.py +++ b/src/main.py @@ -1,29 +1,31 @@ +""" +This module runs collaborative learning experiments using the Scheduler class. 
+""" + import argparse -from scheduler import Scheduler -import gc -import torch -import copy import logging +from scheduler import Scheduler + logging.getLogger("PIL").setLevel(logging.INFO) -b_default = "./configs/algo_config.py" -s_default = "./configs/sys_config.py" +B_DEFAULT = "./configs/algo_config.py" +S_DEFAULT = "./configs/sys_config.py" parser = argparse.ArgumentParser(description="Run collaborative learning experiments") parser.add_argument( "-b", nargs="?", - default=b_default, + default=B_DEFAULT, type=str, - help="filepath for benchmark config, default: {}".format(b_default), + help=f"filepath for benchmark config, default: {B_DEFAULT}", ) parser.add_argument( "-s", nargs="?", - default=s_default, + default=S_DEFAULT, type=str, - help="filepath for system config, default: {}".format(s_default), + help=f"filepath for system config, default: {S_DEFAULT}", ) args = parser.parse_args() @@ -32,7 +34,6 @@ scheduler.assign_config_by_path(args.s, args.b) print("Config loaded") - scheduler.install_config() scheduler.initialize() -scheduler.run_job() \ No newline at end of file +scheduler.run_job() diff --git a/src/resnet.py b/src/resnet.py index 072d077..18641f0 100644 --- a/src/resnet.py +++ b/src/resnet.py @@ -1,22 +1,27 @@ +# resnet.py # 2019.07.24-Changed output of forward function # Huawei Technologies Co., Ltd. -# taken from https://github.com/huawei-noah/Data-Efficient-Model-Compression/blob/master/DAFL/resnet.py +# Taken from https://github.com/huawei-noah/Data-Efficient-Model-Compression/blob/master/DAFL/resnet.py # for comparison with DAFL +""" +This module implements ResNet models for image classification. +""" -import torch -import torch.nn as nn +from torch import nn import torch.nn.functional as F class BasicBlock(nn.Module): + """ + A basic block for ResNet. + """ expansion = 1 def __init__(self, in_planes, planes, stride=1): + super(BasicBlock, self).__init__() - self.conv1 = nn.Conv2d( - in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False - ) + self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) self.bn1 = nn.BatchNorm2d(planes) self.conv2 = nn.Conv2d( planes, planes, kernel_size=3, stride=1, padding=1, bias=False @@ -26,54 +31,49 @@ def __init__(self, in_planes, planes, stride=1): self.shortcut = nn.Sequential() if stride != 1 or in_planes != self.expansion * planes: self.shortcut = nn.Sequential( - nn.Conv2d( - in_planes, - self.expansion * planes, - kernel_size=1, - stride=stride, - bias=False, - ), + nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(self.expansion * planes), ) def forward(self, x): + """ + Forward pass for the BasicBlock. + """ out = F.relu(self.bn1(self.conv1(x))) out = self.bn2(self.conv2(out)) out += self.shortcut(x) out = F.relu(out) return out - class Bottleneck(nn.Module): + """ + A bottleneck block for ResNet. 
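+
+    Stacks 1x1, 3x3, and 1x1 convolutions; the final 1x1 expands the channel
+    count by ``expansion`` (4), and a projection shortcut is used whenever
+    the stride or channel count changes.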
+    """
     expansion = 4

     def __init__(self, in_planes, planes, stride=1):
-        super(Bottleneck, self).__init__()
+        super().__init__()
         self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
         self.bn1 = nn.BatchNorm2d(planes)
         self.conv2 = nn.Conv2d(
             planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
         )
         self.bn2 = nn.BatchNorm2d(planes)
-        self.conv3 = nn.Conv2d(
-            planes, self.expansion * planes, kernel_size=1, bias=False
-        )
-        self.bn3 = nn.BatchNorm2d(self.expansion * planes)
+        self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False)
+        self.bn3 = nn.BatchNorm2d(self.expansion * planes)

         self.shortcut = nn.Sequential()
         if stride != 1 or in_planes != self.expansion * planes:
             self.shortcut = nn.Sequential(
-                nn.Conv2d(
-                    in_planes,
-                    self.expansion * planes,
-                    kernel_size=1,
-                    stride=stride,
-                    bias=False,
-                ),
+                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                 nn.BatchNorm2d(self.expansion * planes),
             )

     def forward(self, x):
+        """
+        Forward pass for the Bottleneck block.
+        """
         out = F.relu(self.bn1(self.conv1(x)))
         out = F.relu(self.bn2(self.conv2(out)))
         out = self.bn3(self.conv3(out))
@@ -83,8 +83,11 @@ def forward(self, x):


 class ResNet(nn.Module):
+    """
+    A ResNet model.
+    """
     def __init__(self, block, num_blocks, num_classes=10, num_channels=3):
-        super(ResNet, self).__init__()
+        super().__init__()
         self.in_planes = 64
         self.conv1 = nn.Conv2d(
             num_channels, 64, kernel_size=3, stride=1, padding=1, bias=False
         )
@@ -105,10 +108,14 @@ def _make_layer(self, block, planes, num_blocks, stride):
         return nn.Sequential(*layers)

     def forward(self, x, position=0, out_feature=False):
+        """
+        Forward pass for the ResNet model.
+        """
         if position == 0:
             x = self.conv1(x)
             x = self.bn1(x)
             x = F.relu(x)
+
             # print(x.shape)  # [16, 64, 32, 32]
         if position <= 1:
             x = self.layer1(x)
@@ -122,41 +129,55 @@ def forward(self, x, position=0, out_feature=False):
         if position <= 4:
             x = self.layer4(x)
             # print(x.shape)  # [16, 512, 4, 4]
+
         if position <= 5:
             x = F.avg_pool2d(x, 4)
             feature = x.view(x.size(0), -1)
             x = self.linear(feature)
-            # print(x.shape) # [16, 10]
-        if not out_feature:
-            return x
-        else:
-            return x, feature
+        if out_feature:
+            return x, feature
+        return x


-def ResNet10(num_channels=3, num_classes=10):
+def resnet10(num_channels=3, num_classes=10):
+    """
+    Constructs a ResNet-10 model.
+    """
     return ResNet(BasicBlock, [1, 1, 1, 1], num_classes, num_channels)


-def ResNet18(num_channels=3, num_classes=10):
+def resnet18(num_channels=3, num_classes=10):
+    """
+    Constructs a ResNet-18 model.
+    """
     return ResNet(BasicBlock, [2, 2, 2, 2], num_classes, num_channels)


-def ResNet34(num_channels=3, num_classes=10):
+def resnet34(num_channels=3, num_classes=10):
+    """
+    Constructs a ResNet-34 model.
+    """
     return ResNet(BasicBlock, [3, 4, 6, 3], num_classes, num_channels)


-def ResNet50(num_channels=3, num_classes=10):
+def resnet50(num_channels=3, num_classes=10):
+    """
+    Constructs a ResNet-50 model.
+    """
     return ResNet(Bottleneck, [3, 4, 6, 3], num_classes, num_channels)


-def ResNet101(num_channels=3, num_classes=10):
+def resnet101(num_channels=3, num_classes=10):
+    """
+    Constructs a ResNet-101 model.
+    """
     return ResNet(Bottleneck, [3, 4, 23, 3], num_classes, num_channels)


-def ResNet152(num_channels=3, num_classes=10):
+def resnet152(num_channels=3, num_classes=10):
+    """
+    Constructs a ResNet-152 model.
+    """
     return ResNet(Bottleneck, [3, 8, 36, 3], num_classes, num_channels)


-# model=ResNet34()
-# img=torch.randn((1,3,32,32))
-# print(model.forward(img,0))
+# Smoke test, updated for the renamed factories (requires `import torch`):
+# model = resnet34()
+# img = torch.randn((1, 3, 32, 32))
+# print(model.forward(img, 0))
diff --git a/src/resnet_in.py b/src/resnet_in.py
index 7ba3efc..b160e3e 100644
--- a/src/resnet_in.py
+++ b/src/resnet_in.py
@@ -1,7 +1,11 @@
+# resnet_in.py
 # ResNet for ImageNet (224x224)

-import torch
-import torch.nn as nn
+"""
+This module implements ResNet models for ImageNet classification.
+"""
+
+import torch  # still required: _forward_impl uses torch.flatten
+from torch import nn
 from torch.hub import load_state_dict_from_url


@@ -52,28 +56,23 @@ def conv1x1(in_planes, out_planes, stride=1):


 class BasicBlock(nn.Module):
+    """
+    A basic block for ResNet.
+    """
     expansion = 1

-    def __init__(
-        self,
-        inplanes,
-        planes,
-        stride=1,
-        downsample=None,
-        groups=1,
-        base_width=64,
-        dilation=1,
-        norm_layer=None,
-    ):
+    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
+                 base_width=64, dilation=1, norm_layer=None):
         super(BasicBlock, self).__init__()
+
         if norm_layer is None:
             norm_layer = nn.BatchNorm2d
         if groups != 1 or base_width != 64:
             raise ValueError("BasicBlock only supports groups=1 and base_width=64")
         if dilation > 1:
             raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
-        # Both self.conv1 and self.downsample layers downsample the input when
-        # stride != 1
+
         self.conv1 = conv3x3(inplanes, planes, stride)
         self.bn1 = norm_layer(planes)
         self.relu = nn.ReLU(inplace=True)
@@ -83,50 +82,36 @@ def __init__(
         self.stride = stride

     def forward(self, x):
+        """
+        Forward pass for the BasicBlock.
+        """
         identity = x
-
         out = self.conv1(x)
         out = self.bn1(out)
         out = self.relu(out)
-
         out = self.conv2(out)
         out = self.bn2(out)
-
         if self.downsample is not None:
             identity = self.downsample(x)
-
         out += identity
         out = self.relu(out)
-
         return out


 class Bottleneck(nn.Module):
-    # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
-    # while original implementation places the stride at the first 1x1 convolution(self.conv1)
-    # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
-    # This variant is also known as ResNet V1.5 and improves accuracy according to
-    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
-
+    """
+    A bottleneck block for ResNet (the torchvision V1.5 variant, which places
+    the downsampling stride on the 3x3 convolution).
+    """
     expansion = 4

-    def __init__(
-        self,
-        inplanes,
-        planes,
-        stride=1,
-        downsample=None,
-        groups=1,
-        base_width=64,
-        dilation=1,
-        norm_layer=None,
-    ):
+    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
+                 base_width=64, dilation=1, norm_layer=None):
         super(Bottleneck, self).__init__()
         if norm_layer is None:
             norm_layer = nn.BatchNorm2d
         width = int(planes * (base_width / 64.0)) * groups
-        # Both self.conv2 and self.downsample layers downsample the input when
-        # stride != 1
+
         self.conv1 = conv1x1(inplanes, width)
         self.bn1 = norm_layer(width)
         self.conv2 = conv3x3(width, width, stride, groups, dilation)
@@ -138,51 +123,41 @@ def __init__(
         self.stride = stride

     def forward(self, x):
+        """
+        Forward pass for the Bottleneck block.
+        """
         identity = x
-
         out = self.conv1(x)
         out = self.bn1(out)
         out = self.relu(out)
-
         out = self.conv2(out)
         out = self.bn2(out)
         out = self.relu(out)
-
         out = self.conv3(out)
         out = self.bn3(out)
-
         if self.downsample is not None:
             identity = self.downsample(x)
-
         out += identity
         out = self.relu(out)
-
         return out


 class ResNet(nn.Module):
-    def __init__(
-        self,
-        block,
-        layers,
-        num_classes=1000,
-        zero_init_residual=False,
-        groups=1,
-        width_per_group=64,
-        replace_stride_with_dilation=None,
-        norm_layer=None,
-    ):
+    """
+    A ResNet model.
+    """
+    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
+                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
+                 norm_layer=None):
         super(ResNet, self).__init__()
+
         if norm_layer is None:
             norm_layer = nn.BatchNorm2d
         self._norm_layer = norm_layer
-
         self.inplanes = 64
         self.dilation = 1
         if replace_stride_with_dilation is None:
-            # each element in the tuple indicates if we should replace
-            # the 2x2 stride with a dilated convolution instead
             replace_stride_with_dilation = [False, False, False]
         if len(replace_stride_with_dilation) != 3:
             raise ValueError(
@@ -208,7 +183,6 @@ def __init__(
             block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2]
         )
         self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
         self.block = block  # kept: reset_fc below relies on self.block.expansion
         self.fc = nn.Linear(512 * block.expansion, num_classes)

         for m in self.modules():
@@ -218,10 +192,12 @@ def __init__(
                 nn.init.constant_(m.weight, 1)
                 nn.init.constant_(m.bias, 0)

         # Zero-initialize the last BN in each residual branch,
         # so that the residual branch starts with zeros, and each residual block behaves like an identity.
         # This improves the model by 0.2~0.3% according to
         # https://arxiv.org/abs/1706.02677
         if zero_init_residual:
             for m in self.modules():
                 if isinstance(m, Bottleneck):
@@ -230,6 +206,7 @@ def __init__(
                     nn.init.constant_(m.bn2.weight, 0)

     def reset_fc(self, num_classes):
+        """Resets the fully connected layer with the specified number of classes."""
         self.fc = nn.Linear(512 * self.block.expansion, num_classes)

     def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
@@ -244,7 +221,6 @@ def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
             conv1x1(self.inplanes, planes * block.expansion, stride),
             norm_layer(planes * block.expansion),
         )
-
         layers = []
         layers.append(
             block(
@@ -260,21 +236,14 @@ def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
         )
         self.inplanes = planes * block.expansion
         for _ in range(1, blocks):
-            layers.append(
-                block(
-                    self.inplanes,
-                    planes,
-                    groups=self.groups,
-                    base_width=self.base_width,
-                    dilation=self.dilation,
-                    norm_layer=norm_layer,
-                )
-            )
+            layers.append(block(self.inplanes, planes, groups=self.groups,
+                                base_width=self.base_width, dilation=self.dilation,
+                                norm_layer=norm_layer))
         return nn.Sequential(*layers)

     def _forward_impl(self, x, position, return_features):
-        # See note [TorchScript super()]
         if position == 0:
             x = self.conv1(x)
             x = self.bn1(x)
@@ -293,11 +262,12 @@ def _forward_impl(self, x, position, return_features):
         x = self.avgpool(x)
         feat = torch.flatten(x, 1)
         x = self.fc(feat)
         if return_features:
             return x, feat
         return x

     def forward(self, x, position, return_features=False):
+        """Forward pass for the ResNet model."""
         return self._forward_impl(x, position, return_features=return_features)
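Assuming `_forward_impl` here mirrors the position gating used in src/resnet.py (position 0 runs the stem, larger values skip earlier stages, and pooling plus the classifier always run last), a minimal usage sketch of this forward API follows; the shapes in the comments are illustrative assumptions, not outputs verified against this repo:

    import torch
    from resnet_in import Bottleneck, ResNet

    model = ResNet(Bottleneck, [3, 4, 6, 3])            # ResNet-50-sized backbone
    x = torch.randn(2, 3, 224, 224)
    logits = model(x, 0)                                # full forward pass
    logits, feats = model(x, 0, return_features=True)   # feats: pooled (2, 2048) features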


@@ -334,7 +304,9 @@ def resnet18(pretrained=False, progress=True, **kwargs):
     pretrained (bool): If True, returns a model pre-trained on ImageNet
     progress (bool): If True, displays a progress bar of the download to stderr
     """
     return _resnet("resnet18", BasicBlock, [2, 2, 2, 2], pretrained, progress, **kwargs)


 def resnet34(pretrained=False, progress=True, **kwargs):
@@ -344,7 +316,9 @@ def resnet34(pretrained=False, progress=True, **kwargs):
     pretrained (bool): If True, returns a model pre-trained on ImageNet
     progress (bool): If True, displays a progress bar of the download to stderr
     """
     return _resnet("resnet34", BasicBlock, [3, 4, 6, 3], pretrained, progress, **kwargs)


 def resnet50(pretrained=False, progress=True, **kwargs):
@@ -354,7 +328,9 @@ def resnet50(pretrained=False, progress=True, **kwargs):
     pretrained (bool): If True, returns a model pre-trained on ImageNet
     progress (bool): If True, displays a progress bar of the download to stderr
     """
     return _resnet("resnet50", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs)


 def resnet101(pretrained=False, progress=True, **kwargs):
@@ -364,9 +340,9 @@ def resnet101(pretrained=False, progress=True, **kwargs):
     pretrained (bool): If True, returns a model pre-trained on ImageNet
     progress (bool): If True, displays a progress bar of the download to stderr
     """
-    return _resnet(
-        "resnet101", Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs
-    )
+    return _resnet("resnet101", Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs)


 def resnet152(pretrained=False, progress=True, **kwargs):
@@ -376,9 +352,9 @@ def resnet152(pretrained=False, progress=True, **kwargs):
     pretrained (bool): If True, returns a model pre-trained on ImageNet
     progress (bool): If True, displays a progress bar of the download to stderr
     """
-    return _resnet(
-        "resnet152", Bottleneck, [3, 8, 36, 3], pretrained, progress, **kwargs
-    )
+    return _resnet("resnet152", Bottleneck, [3, 8, 36, 3], pretrained, progress, **kwargs)


 def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
@@ -388,11 +364,11 @@ def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
     pretrained (bool): If True, returns a model pre-trained on ImageNet
     progress (bool): If True, displays a progress bar of the download to stderr
     """
     kwargs["groups"] = 32
     kwargs["width_per_group"] = 4
-    return _resnet(
-        "resnext50_32x4d", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs
-    )
+    return _resnet("resnext50_32x4d", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs)


 def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
@@ -402,11 +378,11 @@ def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
     pretrained (bool): If True, returns a model pre-trained on ImageNet
     progress (bool): If True, displays a progress bar of the download to stderr
     """
     kwargs["groups"] = 32
     kwargs["width_per_group"] = 8
-    return _resnet(
-        "resnext101_32x8d", Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs
-    )
+    return _resnet("resnext101_32x8d", Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs)


 def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
@@ -420,10 +396,10 @@ def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
     pretrained (bool): If True, returns a model pre-trained on ImageNet
     progress (bool): If True, displays a progress bar of the download to stderr
     """
     kwargs["width_per_group"] = 64 * 2
-    return _resnet(
-        "wide_resnet50_2", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs
-    )
+    return _resnet("wide_resnet50_2", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs)


 def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
@@ -437,7 +413,7 @@ def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
     pretrained (bool): If True, returns a model pre-trained on ImageNet
     progress (bool): If True, displays a progress bar of the download to stderr
     """
     kwargs["width_per_group"] = 64 * 2
-    return _resnet(
-        "wide_resnet101_2", Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs
-    )
+    return _resnet("wide_resnet101_2", Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs)
diff --git a/src/scheduler.py b/src/scheduler.py
index e074f2c..94d481f 100644
--- a/src/scheduler.py
+++ b/src/scheduler.py
@@ -1,7 +1,19 @@
+# scheduler.py
+"""
+This module manages the orchestration of federated learning experiments.
+"""
+
+import os
+
 from mpi4py import MPI
 import torch
+
+import numpy as np
+
 import random
 import numpy
+
 from algos.base_class import BaseNode
 from algos.fl import FedAvgClient, FedAvgServer
 from algos.isolated import IsolatedServer
@@ -18,19 +30,29 @@ from algos.fl_central import CentralizedCLient, CentralizedServer
 from algos.fl_data_repr import FedDataRepClient, FedDataRepServer
 from algos.fl_val import FedValClient, FedValServer
+
 from utils.log_utils import copy_source_code, check_and_create_path
 from utils.config_utils import load_config, process_config, get_device_ids
-import os
-
-# should be used as: algo_map[algo_name][rank>0](config)
-# If rank is 0, then it returns the server class otherwise the client class
+
+# Maps an algorithm name to [server_class, client_class]; used as
+# algo_map[algo_name][rank > 0](config): rank 0 gets the server class,
+# every other rank gets the client class.
 algo_map = {
     "fedavg": [FedAvgServer, FedAvgClient],
     "isolated": [IsolatedServer],
+    "fedran": [FedRanServer, FedRanClient],
+    "fedgrid": [FedGridServer, FedGridClient],
+    "fedtorus": [FedTorusServer, FedTorusClient],
     "fedass": [FedAssServer, FedAssClient],
     "fediso": [FedIsoServer, FedIsoClient],
     "fedweight": [FedWeightServer, FedWeightClient],
+    "fedring": [FedRingServer, FedRingClient],
     "fedstatic": [FedStaticServer, FedStaticClient],
+    "swarm": [SWARMServer, SWARMClient],
     "dispfl": [DisPFLServer, DisPFLClient],
     "defkt": [DefKTServer, DefKTClient],
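The mapping is indexed exactly as the comment above describes; a minimal sketch of the selection logic (the real wiring lives in the Scheduler class, and the "algo" config key is an assumption about the config layout, not confirmed by this diff):

    from mpi4py import MPI

    rank = MPI.COMM_WORLD.Get_rank()
    config = {"algo": "dispfl"}                  # hypothetical; normally loaded from the config files
    node_class = algo_map[config["algo"]][int(rank > 0)]
    node = node_class(config)                    # server on rank 0, client on every other rank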
diff --git a/src/utils/data_utils.py b/src/utils/data_utils.py
index 0736159..ea8670b 100644
--- a/src/utils/data_utils.py
+++ b/src/utils/data_utils.py
@@ -1,129 +1,115 @@
-import pdb
+import os

 import numpy as np
 import torch
 import torchvision.transforms as T
-from torchvision.datasets.cifar import CIFAR10
-from torchvision.datasets import MNIST
 from torch.utils.data import Subset
-import os
+from torchvision.datasets import CIFAR10, MNIST
 from PIL import Image
-
 import medmnist
-
 import wilds
 from wilds.datasets.wilds_dataset import WILDSSubset


-class CIFAR10_DSET:
-    def __init__(self, dpath, rot_angle=0) -> None:
-        self.IMAGE_SIZE = 32
-        self.NUM_CLS = 10
+class CIFAR10Dataset:
+    """
+    CIFAR-10 Dataset Class.
+    """
+    def __init__(self, dpath: str, rot_angle: int = 0) -> None:
+        self.image_size = 32
+        self.num_cls = 10
         self.mean = np.array((0.4914, 0.4822, 0.4465))
         self.std = np.array((0.2023, 0.1994, 0.2010))
         self.num_channels = 3
-        self.gen_transform = T.Compose(
-            [
-                T.RandomCrop(32, padding=4),
-                T.RandomHorizontalFlip(),
-                T.Normalize(self.mean, self.std),
-            ]
-        )
-        train_transform = T.Compose(
-            [
-                T.RandomCrop(32, padding=4),
-                (
-                    T.RandomHorizontalFlip()
-                    if rot_angle % 180 == 0
-                    else T.RandomVerticalFlip()
-                ),
-                T.ToTensor(),
-                T.Normalize(self.mean, self.std),
-            ]
-        )
-        test_transform = T.Compose(
-            [
-                T.ToTensor(),
-                T.Normalize(self.mean, self.std),
-            ]
-        )
-        if rot_angle != 0:
-            tr_transform, te_transform = train_transform, test_transform
-
-            def train_transform(img):
-                return T.functional.rotate(tr_transform(img), angle=rot_angle)

-            def test_transform(img):
-                return T.functional.rotate(te_transform(img), angle=rot_angle)
+        self.train_transform = T.Compose([
+            T.RandomCrop(32, padding=4),
+            T.RandomHorizontalFlip(),
+            T.ToTensor(),
+            T.Normalize(self.mean, self.std),
+        ])
+        self.test_transform = T.Compose([
+            T.ToTensor(),
+            T.Normalize(self.mean, self.std),
+        ])
+
+        if rot_angle != 0:
+            if rot_angle % 180 != 0:
+                # 90/270-degree domains flip vertically instead of horizontally,
+                # matching the original transform selection.
+                self.train_transform.transforms[1] = T.RandomVerticalFlip()
+            self.train_transform.transforms.append(
+                T.Lambda(lambda img: T.functional.rotate(img, rot_angle))
+            )
+            self.test_transform.transforms.append(
+                T.Lambda(lambda img: T.functional.rotate(img, rot_angle))
+            )

-        self.train_dset = CIFAR10(
-            root=dpath, train=True, download=True, transform=train_transform
-        )
-        self.test_dset = CIFAR10(
-            root=dpath, train=False, download=True, transform=test_transform
-        )
-        self.IMAGE_BOUND_L = torch.tensor(
-            (-self.mean / self.std).reshape(1, -1, 1, 1)
-        ).float()
-        self.IMAGE_BOUND_U = torch.tensor(
-            ((1 - self.mean) / self.std).reshape(1, -1, 1, 1)
-        ).float()
+        self.train_dset = CIFAR10(root=dpath, train=True, download=True, transform=self.train_transform)
+        self.test_dset = CIFAR10(root=dpath, train=False, download=True, transform=self.test_transform)
+        self.image_bound_l = torch.tensor((-self.mean / self.std).reshape(1, -1, 1, 1)).float()
+        self.image_bound_u = torch.tensor(((1 - self.mean) / self.std).reshape(1, -1, 1, 1)).float()


-class CIFAR10_R90_DSET(CIFAR10_DSET):
-    def __init__(self, dpath) -> None:
+class CIFAR10R90Dataset(CIFAR10Dataset):
+    """
+    CIFAR-10 Dataset Class with 90 degrees rotation.
+    """
+    def __init__(self, dpath: str) -> None:
         super().__init__(dpath, rot_angle=90)


-class CIFAR10_R180_DSET(CIFAR10_DSET):
-    def __init__(self, dpath) -> None:
+class CIFAR10R180Dataset(CIFAR10Dataset):
+    """
+    CIFAR-10 Dataset Class with 180 degrees rotation.
+    """
+    def __init__(self, dpath: str) -> None:
         super().__init__(dpath, rot_angle=180)


-class CIFAR10_R270_DSET(CIFAR10_DSET):
-    def __init__(self, dpath) -> None:
+class CIFAR10R270Dataset(CIFAR10Dataset):
+    """
+    CIFAR-10 Dataset Class with 270 degrees rotation.
+    """
+    def __init__(self, dpath: str) -> None:
         super().__init__(dpath, rot_angle=270)
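Each rotated subclass acts as a synthetic domain over the same underlying images. A minimal sketch of how a 90-degree domain would be obtained through the `get_dataset` registry defined later in this file (the data path is hypothetical):

    from utils.data_utils import get_dataset

    dset = get_dataset("cifar10_r90", "./data")
    img, label = dset.train_dset[0]   # img is a normalized tensor, rotated by 90 degrees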


-class MNIST_DSET:
-    def __init__(self, dpath) -> None:
-        self.IMAGE_SIZE = 28
-        self.NUM_CLS = 10
+class MNISTDataset:
+    """
+    MNIST Dataset Class.
+    """
+    def __init__(self, dpath: str) -> None:
+        self.image_size = 28
+        self.num_cls = 10
         self.mean = 0.1307
         self.std = 0.3081
         self.num_channels = 1
-        self.gen_transform = T.Compose(
-            [
-                T.Normalize(self.mean, self.std),
-            ]
-        )
-        train_transform = T.Compose(
-            [
-                T.ToTensor(),
-                T.Normalize(self.mean, self.std),
-            ]
-        )
-        test_transform = T.Compose(
-            [
-                T.ToTensor(),
-                T.Normalize(self.mean, self.std),
-            ]
-        )
-        self.train_dset = MNIST(
-            root=dpath, train=True, download=True, transform=train_transform
-        )
-        self.test_dset = MNIST(
-            root=dpath, train=False, download=True, transform=test_transform
-        )
-

-class MEDMNIST_DSET:
-    def __init__(self, dpath, data_flag) -> None:
+        self.train_transform = T.Compose([
+            T.ToTensor(),
+            T.Normalize(self.mean, self.std),
+        ])
+        self.test_transform = T.Compose([
+            T.ToTensor(),
+            T.Normalize(self.mean, self.std),
+        ])
+        self.train_dset = MNIST(root=dpath, train=True, download=True, transform=self.train_transform)
+        self.test_dset = MNIST(root=dpath, train=False, download=True, transform=self.test_transform)
+
+
+class MEDMNISTDataset:
+    """
+    MEDMNIST Dataset Class.
+    """
+    def __init__(self, dpath: str, data_flag: str) -> None:
         self.mean = np.array([0.5])
         self.std = np.array([0.5])
         info = medmnist.INFO[data_flag]
         self.num_channels = info["n_channels"]
         self.data_class = getattr(medmnist, info["python_class"])
-        transform = T.Compose([T.ToTensor(), T.Normalize(mean=[0.5], std=[0.5])])
+
+        self.transform = T.Compose([
+            T.ToTensor(),
+            T.Normalize(self.mean, self.std)
+        ])
+
         if not os.path.exists(dpath):
             os.makedirs(dpath)
@@ -131,84 +117,92 @@ def target_transform(x):
             return x[0]
         self.train_dset = self.data_class(
-            root=dpath,
-            split="train",
-            transform=transform,
-            target_transform=target_transform,
-            download=True,
+            root=dpath, split="train", transform=self.transform,
+            target_transform=target_transform, download=True
         )
         self.test_dset = self.data_class(
-            root=dpath,
-            split="test",
-            transform=transform,
-            target_transform=target_transform,
-            download=True,
+            root=dpath, split="test", transform=self.transform,
+            target_transform=target_transform, download=True
         )

-class PathMNIST_DSET(MEDMNIST_DSET):
-    def __init__(self, dpath) -> None:
+class PathMNISTDataset(MEDMNISTDataset):
+    """
+    PathMNIST Dataset Class.
+    """
+    def __init__(self, dpath: str) -> None:
         super().__init__(dpath, "pathmnist")
+        self.image_size = 28
+        self.num_cls = 9

-        self.IMAGE_SIZE = 28
-        self.NUM_CLS = 9
-

-class DermaMNIST_DSET(MEDMNIST_DSET):
-    def __init__(self, dpath) -> None:
+class DermaMNISTDataset(MEDMNISTDataset):
+    """
+    DermaMNIST Dataset Class.
+    """
+    def __init__(self, dpath: str) -> None:
         super().__init__(dpath, "dermamnist")
-
-        self.IMAGE_SIZE = 28
-        self.NUM_CLS = 7
+        self.image_size = 28
+        self.num_cls = 7


-class BloodMNIST_DSET(MEDMNIST_DSET):
-    def __init__(self, dpath) -> None:
+class BloodMNISTDataset(MEDMNISTDataset):
+    """
+    BloodMNIST Dataset Class.
+    """
+    def __init__(self, dpath: str) -> None:
         super().__init__(dpath, "bloodmnist")
+        self.image_size = 28
+        self.num_cls = 8

-        self.IMAGE_SIZE = 28
-        self.NUM_CLS = 8
-

-class TissueMNIST_DSET(MEDMNIST_DSET):
-    def __init__(self, dpath) -> None:
+class TissueMNISTDataset(MEDMNISTDataset):
+    """
+    TissueMNIST Dataset Class.
+    """
+    def __init__(self, dpath: str) -> None:
         super().__init__(dpath, "tissuemnist")
-
-        self.IMAGE_SIZE = 28
-        self.NUM_CLS = 8
+        self.image_size = 28
+        self.num_cls = 8


-class OrganAMNIST_DSET(MEDMNIST_DSET):
-    def __init__(self, dpath) -> None:
+class OrganAMNISTDataset(MEDMNISTDataset):
+    """
+    OrganAMNIST Dataset Class.
+    """
+    def __init__(self, dpath: str) -> None:
         super().__init__(dpath, "organamnist")
+        self.image_size = 28
+        self.num_cls = 11

-        self.IMAGE_SIZE = 28
-        self.NUM_CLS = 11
-

-class OrganCMNIST_DSET(MEDMNIST_DSET):
-    def __init__(self, dpath) -> None:
+class OrganCMNISTDataset(MEDMNISTDataset):
+    """
+    OrganCMNIST Dataset Class.
+    """
+    def __init__(self, dpath: str) -> None:
         super().__init__(dpath, "organcmnist")
-
-        self.IMAGE_SIZE = 28
-        self.NUM_CLS = 11
+        self.image_size = 28
+        self.num_cls = 11


-class OrganSMNIST_DSET(MEDMNIST_DSET):
-    def __init__(self, dpath) -> None:
+class OrganSMNISTDataset(MEDMNISTDataset):
+    """
+    OrganSMNIST Dataset Class.
+    """
+    def __init__(self, dpath: str) -> None:
         super().__init__(dpath, "organsmnist")
-
-        self.IMAGE_SIZE = 28
-        self.NUM_CLS = 11
+        self.image_size = 28
+        self.num_cls = 11


 class CacheDataset:
+    """
+    Caches the entire dataset in memory.
+    """
     def __init__(self, dset):
-
-        if hasattr(dset, "targets"):
-            self.targets = dset.targets
-        self.data = []
+        self.targets = getattr(dset, "targets", None)
+        self.data = []
         for i in range(len(dset)):
             self.data.append(dset[i])
@@ -220,6 +214,9 @@ def __len__(self):


 class TransformDataset:
+    """
+    Applies a transformation to the dataset.
+    """
     def __init__(self, dset, transform):
         self.dset = dset
         self.transform = transform
@@ -233,16 +230,15 @@ def __len__(self):
         return len(self.dset)


-# https://github.com/FengHZ/KD3A/blob/master/datasets/DomainNet.py
-
-
-def read_domainnet_data(dataset_path, domain_name, split="train", labels_to_keep=None):
+# Adapted from https://github.com/FengHZ/KD3A/blob/master/datasets/DomainNet.py
+def read_domainnet_data(dataset_path: str, domain_name: str, split: str = "train", labels_to_keep=None):
+    """
+    Reads DomainNet data.
+    """
     data_paths = []
     data_labels = []
-    split_file = os.path.join(
-        dataset_path, "splits", "{}_{}.txt".format(domain_name, split)
-    )
-    with open(split_file, "r") as f:
+    split_file = os.path.join(dataset_path, "splits", f"{domain_name}_{split}.txt")
+
+    with open(split_file, "r", encoding="utf-8") as f:
         lines = f.readlines()
         for line in lines:
             line = line.strip()
@@ -256,31 +252,35 @@ def read_domainnet_data(dataset_path, domain_name, split="train", labels_to_keep
             label = int(label)
             data_paths.append(data_path)
             data_labels.append(label)
+
     return data_paths, data_labels


 class DomainNet:
+    """
+    DomainNet Dataset Class.
+ """ def __init__(self, data_paths, data_labels, transforms, domain_name, cache=False): self.data_paths = data_paths self.data_labels = data_labels self.transforms = transforms self.domain_name = domain_name self.cached_data = [] + if cache: for idx, _ in enumerate(data_paths): self.cached_data.append(self.__read_data__(idx)) def __read_data__(self, index): img = Image.open(self.data_paths[index]) - if not img.mode == "RGB": + if img.mode != "RGB": img = img.convert("RGB") label = self.data_labels[index] img = T.ToTensor()(img) - return img, label def __getitem__(self, index): - if len(self.cached_data) > 0: + if self.cached_data: img, label = self.cached_data[index] else: img, label = self.__read_data__(index) @@ -291,40 +291,28 @@ def __len__(self): return len(self.data_paths) -class DomainNet_DSET: - def __init__(self, dpath, domain_name): - # TODO Modify ResNet to support 64 x 64 images - self.IMAGE_SIZE = 32 - self.CROP_SCALE = 0.75 - self.IMAZE_RESIZE = int(np.ceil(self.IMAGE_SIZE * 1 / self.CROP_SCALE)) +class DomainNetDataset: + """ + DomainNet Dataset Class. + """ + def __init__(self, dpath: str, domain_name: str) -> None: + self.image_size = 32 + self.crop_scale = 0.75 + self.image_resize = int(np.ceil(self.image_size / self.crop_scale)) labels_to_keep = [ - "suitcase", - "teapot", - "pillow", - "streetlight", - "table", - "bathtub", - "wine_glass", - "vase", - "umbrella", - "bench", + "suitcase", "teapot", "pillow", "streetlight", "table", + "bathtub", "wine_glass", "vase", "umbrella", "bench" ] - self.NUM_CLS = len(labels_to_keep) + self.num_cls = len(labels_to_keep) self.num_channels = 3 - train_transform = T.Compose( - [ - T.Resize((self.IMAZE_RESIZE, self.IMAZE_RESIZE), antialias=True), - # T.ToTensor() - ] - ) - test_transform = T.Compose( - [ - T.Resize((self.IMAGE_SIZE, self.IMAGE_SIZE), antialias=True), - # T.ToTensor() - ] - ) + train_transform = T.Compose([ + T.Resize((self.image_resize, self.image_resize), antialias=True), + ]) + test_transform = T.Compose([ + T.Resize((self.image_size, self.image_size), antialias=True), + ]) train_data_paths, train_data_labels = read_domainnet_data( dpath, domain_name, split="train", labels_to_keep=labels_to_keep ) @@ -348,13 +336,16 @@ def __init__(self, dpath, domain_name): class WildsDset: + """ + WILDS Dataset Class. + """ def __init__(self, dset, transform=None): self.dset = dset self.transform = transform self.targets = [t.item() for t in list(dset.y_array)] def __getitem__(self, index): - img, label, meta_data = self.dset[index] + img, label, _ = self.dset[index] if self.transform is not None: img = self.transform(img) return img, label.item() @@ -363,61 +354,38 @@ def __len__(self): return len(self.dset) -class Wilds_DSET: - def __init__(self, dset_name, dpath, domain): +class WildsDataset: + """ + WILDS Dataset Class. 
+ """ + def __init__(self, dset_name: str, dpath: str, domain: int) -> None: dset = wilds.get_dataset(dset_name, download=False, root_dir=dpath) - self.NUM_CLS = len(list(np.unique(dset.y_array))) - - # print("Dataset: ", len(dset)) - # print("Number of classes: ",self.NUM_CLS) - # # print("Split arrays", np.unique(dset.split_array)) - # # print("Meta", np.unique(dset.metadata_array[:, dset.metadata_fields.index(WILDS_DOMAINS_DICT[dset_name])].numpy())) - - # print(dset.metadata_fields) - # print(np.unique(dset.metadata_array[:, dset.metadata_fields.index("region")].numpy())) - # print(np.unique(dset.metadata_array[:, dset.metadata_fields.index("year")].numpy())) + self.num_cls = len(list(np.unique(dset.y_array))) - # for i in range(51): - # idx, = np.where(np.logical_and(dset.metadata_array[:, dset.metadata_fields.index(WILDS_DOMAINS_DICT[dset_name])].numpy()==i, - # dset.split_array==0)) - # print("Domain: ", i, "Train samples: ", len(idx)) - - # Most wilds dset only have OOD data in the test set so we use the - # train set for both train and test + domain_key = WILDS_DOMAINS_DICT[dset_name] (idx,) = np.where( - np.logical_and( - dset.metadata_array[ - :, dset.metadata_fields.index(WILDS_DOMAINS_DICT[dset_name]) - ].numpy() - == domain, - dset.split_array == 0, - ) + (dset.metadata_array[:, dset.metadata_fields.index(domain_key)].numpy() == domain) & + (dset.split_array == 0) ) - # print("Dataset filter: ", len(idx)) self.mean = np.array((0.4914, 0.4822, 0.4465)) self.std = np.array((0.2023, 0.1994, 0.2010)) self.num_channels = 3 - train_transform = T.Compose( - [ - T.RandomResizedCrop(32), - T.RandomHorizontalFlip(), - T.ToTensor(), - T.Normalize(self.mean, self.std), - ] - ) - test_transform = T.Compose( - [ - T.Resize(32), - T.ToTensor(), - T.Normalize(self.mean, self.std), - ] - ) + train_transform = T.Compose([ + T.RandomResizedCrop(32), + T.RandomHorizontalFlip(), + T.ToTensor(), + T.Normalize(self.mean, self.std), + ]) + test_transform = T.Compose([ + T.Resize(32), + T.ToTensor(), + T.Normalize(self.mean, self.std), + ]) num_samples_domain = len(idx) - TRAIN_RATIO = 0.8 - train_samples = int(num_samples_domain * TRAIN_RATIO) + train_samples = int(num_samples_domain * 0.8) idx = np.random.permutation(idx) train_dset = WILDSSubset(dset, idx[:train_samples], transform=None) test_dset = WILDSSubset(dset, idx[train_samples:], transform=None) @@ -425,58 +393,57 @@ def __init__(self, dset_name, dpath, domain): self.test_dset = CacheDataset(WildsDset(test_dset, transform=test_transform)) -def get_dataset(dname, dpath): +def get_dataset(dname: str, dpath: str): + """ + Returns the appropriate dataset class based on the dataset name. 
+ """ dset_mapping = { - "cifar10": CIFAR10_DSET, - "cifar10_r0": CIFAR10_DSET, - "cifar10_r90": CIFAR10_R90_DSET, - "cifar10_r180": CIFAR10_R180_DSET, - "cifar10_r270": CIFAR10_R270_DSET, - "mnist": MNIST_DSET, - # "cifar100": CIFAR100_DSET, - "pathmnist": PathMNIST_DSET, - "dermamnist": DermaMNIST_DSET, - "bloodmnist": BloodMNIST_DSET, - "tissuemnist": BloodMNIST_DSET, - "organamnist": OrganAMNIST_DSET, - "organcmnist": OrganCMNIST_DSET, - "organsmnist": OrganSMNIST_DSET, + "cifar10": CIFAR10Dataset, + "cifar10_r0": CIFAR10Dataset, + "cifar10_r90": CIFAR10R90Dataset, + "cifar10_r180": CIFAR10R180Dataset, + "cifar10_r270": CIFAR10R270Dataset, + "mnist": MNISTDataset, + "pathmnist": PathMNISTDataset, + "dermamnist": DermaMNISTDataset, + "bloodmnist": BloodMNISTDataset, + "tissuemnist": TissueMNISTDataset, + "organamnist": OrganAMNISTDataset, + "organcmnist": OrganCMNISTDataset, + "organsmnist": OrganSMNISTDataset, } if dname.startswith("wilds"): - dname = dname.split("_") - return Wilds_DSET(dname[1], dpath, int(dname[2])) + dname_parts = dname.split("_") + return WildsDataset(dname_parts[1], dpath, int(dname_parts[2])) elif dname.startswith("domainnet"): - dname = dname.split("_") - return DomainNet_DSET(dpath, dname[1]) + dname_parts = dname.split("_") + return DomainNetDataset(dpath, dname_parts[1]) else: return dset_mapping[dname](dpath) -"""def get_noniid_dataset(dname, dpath, num_users, n_class, nsamples, rate_unbalance): - obj = get_dataset(dname, dpath) - # Chose euqal splits for every user - if dname == "cifar10": - obj.user_groups_train, obj.user_groups_test = cifar_extr_noniid(obj.train_dset, obj.test_dset, - num_users, n_class, nsamples, - rate_unbalance) - return obj""" - - def filter_by_class(dataset, classes): - indices = [i for i, (x, y) in enumerate(dataset) if y in classes] + """ + Filters the dataset by specified classes. + """ + indices = [i for i, (_, y) in enumerate(dataset) if y in classes] return Subset(dataset, indices), indices def random_samples(dataset, num_samples): + """ + Returns a random subset of samples from the dataset. + """ indices = torch.randperm(len(dataset))[:num_samples] return Subset(dataset, indices), indices def extr_noniid(train_dataset, samples_per_client, classes): - all_data = Subset( - train_dataset, [i for i, (x, y) in enumerate(train_dataset) if y in classes] - ) + """ + Extracts non-IID data from the training dataset. + """ + all_data = Subset(train_dataset, [i for i, (_, y) in enumerate(train_dataset) if y in classes]) perm = torch.randperm(len(all_data)) return Subset(all_data, perm[:samples_per_client]) @@ -484,40 +451,36 @@ def extr_noniid(train_dataset, samples_per_client, classes): def cifar_extr_noniid( train_dataset, test_dataset, num_users, n_class, num_samples, rate_unbalance ): - num_shards_train, num_imgs_train = int(50000 / num_samples), num_samples + """ + Extracts non-IID data for CIFAR-10 dataset. 
+ """ + num_shards_train = int(50000 / num_samples) + num_imgs_train = num_samples num_classes = 10 - num_imgs_perc_test, num_imgs_test_total = 1000, 10000 + num_imgs_perc_test = 1000 + num_imgs_test_total = 10000 + assert n_class * num_users <= num_shards_train assert n_class <= num_classes - idx_class = [i for i in range(num_classes)] - idx_shard = [i for i in range(num_shards_train)] + dict_users_train = {i: np.array([]) for i in range(num_users)} dict_users_test = {i: np.array([]) for i in range(num_users)} idxs = np.arange(num_shards_train * num_imgs_train) - # labels = dataset.train_labels.numpy() labels = np.array(train_dataset.targets) idxs_test = np.arange(num_imgs_test_total) labels_test = np.array(test_dataset.targets) - # labels_test_raw = np.array(test_dataset.targets) - # stores the image idxs with their corresponding labels - # array([[ 0, 1, 2, ..., 49997, 49998, 49999], - # [ 6, 9, 9, ..., 9, 1, 1]]) idxs_labels = np.vstack((idxs, labels)) - # sorts the whole thing based on labels - # array([[29513, 16836, 32316, ..., 36910, 21518, 25648], - # [ 0, 0, 0, ..., 9, 9, 9]]) idxs_labels = idxs_labels[:, idxs_labels[1, :].argsort()] idxs = idxs_labels[0, :] labels = idxs_labels[1, :] - # Same things as above except that it is test set now idxs_labels_test = np.vstack((idxs_test, labels_test)) idxs_labels_test = idxs_labels_test[:, idxs_labels_test[1, :].argsort()] idxs_test = idxs_labels_test[0, :] - # print(idxs_labels_test[1, :]) - # divide and assign + idx_shard = list(range(num_shards_train)) + for i in range(num_users): user_labels = np.array([]) rand_set = set(np.random.choice(idx_shard, n_class, replace=False)) @@ -528,14 +491,14 @@ def cifar_extr_noniid( dict_users_train[i] = np.concatenate( ( dict_users_train[i], - idxs[rand * num_imgs_train : (rand + 1) * num_imgs_train], + idxs[rand * num_imgs_train: (rand + 1) * num_imgs_train], ), axis=0, ) user_labels = np.concatenate( ( user_labels, - labels[rand * num_imgs_train : (rand + 1) * num_imgs_train], + labels[rand * num_imgs_train: (rand + 1) * num_imgs_train], ), axis=0, ) @@ -544,10 +507,7 @@ def cifar_extr_noniid( ( dict_users_train[i], idxs[ - rand - * num_imgs_train : int( - (rand + rate_unbalance) * num_imgs_train - ) + rand * num_imgs_train: int((rand + rate_unbalance) * num_imgs_train) ], ), axis=0, @@ -556,99 +516,79 @@ def cifar_extr_noniid( ( user_labels, labels[ - rand - * num_imgs_train : int( - (rand + rate_unbalance) * num_imgs_train - ) + rand * num_imgs_train: int((rand + rate_unbalance) * num_imgs_train) ], ), axis=0, ) unbalance_flag = 1 + user_labels_set = set(user_labels) - # print(user_labels_set) - # print(user_labels) for label in user_labels_set: dict_users_test[i] = np.concatenate( ( dict_users_test[i], idxs_test[ - int(label) - * num_imgs_perc_test : int(label + 1) - * num_imgs_perc_test + int(label) * num_imgs_perc_test: int(label + 1) * num_imgs_perc_test ], ), axis=0, ) - # print(set(labels_test_raw[dict_users_test[i].astype(int)])) return dict_users_train, dict_users_test def balanced_subset(dataset, num_samples): + """ + Returns a balanced subset of the dataset. + """ indices = [] targets = np.array(dataset.targets) classes = set(dataset.targets) for c in classes: indices += list((targets == c).nonzero()[0][:num_samples]) - - # Avoid samples from the same class being consecutive indices = np.random.permutation(indices) return Subset(dataset, indices), indices def random_balanced_subset(dataset, num_samples): + """ + Returns a random balanced subset of the dataset. 
+ """ indices = [] targets = np.array(dataset.targets) classes = set(dataset.targets) for c in classes: indices += list( - np.random.choice( - list((targets == c).nonzero()[0]), num_samples, replace=False - ) + np.random.choice(list((targets == c).nonzero()[0]), num_samples, replace=False) ) return Subset(dataset, indices), indices def non_iid_unbalanced_dataidx_map(dset_obj, n_parties, beta=0.4): + """ + Returns a non-IID unbalanced data index map. + """ train_dset = dset_obj.train_dset - n_classes = dset_obj.NUM_CLS + n_classes = dset_obj.num_cls N = len(train_dset) labels = np.array(train_dset.targets) - min_size = 0 # Tracks the minimum number of samples in a party + min_size = 0 min_require_size = 10 + while min_size < min_require_size: idx_batch = [[] for _ in range(n_parties)] for k in range(n_classes): - # Get indexes of class k idx_k = np.where(labels == k)[0] np.random.shuffle(idx_k) - - # Sample proportions from a dirichlet distribution proportions = np.random.dirichlet(np.repeat(beta, n_parties)) - - # Keep only proportions that lead to a samller number of samples - # than - proportions = np.array( - [ - p * (len(idx_j) < N / n_parties) - for p, idx_j in zip(proportions, idx_batch) - ] - ) + proportions = np.array([p * (len(idx_j) < N / n_parties) for p, idx_j in zip(proportions, idx_batch)]) proportions = proportions / proportions.sum() - - # Get range of split according to proportions proportions = (np.cumsum(proportions) * len(idx_k)).astype(int)[:-1] - - # Divide class k indexes according to proportions - idx_batch = [ - idx_j + idx.tolist() - for idx_j, idx in zip(idx_batch, np.split(idx_k, proportions)) - ] + idx_batch = [idx_j + idx.tolist() for idx_j, idx in zip(idx_batch, np.split(idx_k, proportions))] min_size = min([len(idx_j) for idx_j in idx_batch]) - # Convert list to map net_dataidx_map = {} for j in range(n_parties): np.random.shuffle(idx_batch[j]) @@ -656,59 +596,43 @@ def non_iid_unbalanced_dataidx_map(dset_obj, n_parties, beta=0.4): return net_dataidx_map -def non_iid_balanced( - dset_obj, n_client, n_data_per_clnt, alpha=0.4, cls_priors=None, is_train=True -): +def non_iid_balanced(dset_obj, n_client, n_data_per_clnt, alpha=0.4, cls_priors=None, is_train=True): + """ + Returns a non-IID balanced dataset. 
+    """
     if is_train:
         y = np.array(dset_obj.train_dset.targets)
     else:
         y = np.array(dset_obj.test_dset.targets)

+    n_cls = dset_obj.num_cls
+    clnt_data_list = (np.ones(n_client) * n_data_per_clnt).astype(int)
     if cls_priors is None:
         cls_priors = np.random.dirichlet(alpha=[alpha] * n_cls, size=n_client)
+
     prior_cumsum = np.cumsum(cls_priors, axis=1)
     idx_list = [np.where(y == i)[0] for i in range(n_cls)]
     cls_amount = np.array([len(idx_list[i]) for i in range(n_cls)])
+    clnt_y = [np.zeros((clnt_data_list[clnt__], 1)).astype(np.int64) for clnt__ in range(n_client)]
     clnt_idx = [[] for clnt__ in range(n_client)]
     clients = list(np.arange(n_client))
+
     while np.sum(clnt_data_list) != 0:
         curr_clnt = np.random.choice(clients)
        # If the current client is already full, drop it and resample.
         if clnt_data_list[curr_clnt] <= 0:
             clients.remove(curr_clnt)
             continue
         clnt_data_list[curr_clnt] -= 1
         curr_prior = prior_cumsum[curr_clnt]
         while True:
             cls_label = np.argmax((np.random.uniform() <= curr_prior) & (cls_amount > 0))
            # Redraw the class label if that class has run out of samples.
             if cls_amount[cls_label] <= 0:
                 continue
             cls_amount[cls_label] -= 1
             idx = idx_list[cls_label][cls_amount[cls_label]]
             clnt_y[curr_clnt][clnt_data_list[curr_clnt]] = y[idx]
             clnt_idx[curr_clnt].append(idx)
             break
+
+    clnt_y = np.asarray(clnt_y)
     return clnt_y, clnt_idx, cls_priors
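Both partitioners above follow the standard Dirichlet recipe: per-client class proportions are drawn from Dir(alpha) (or Dir(beta)), so smaller concentration values produce more skewed shards. A minimal sketch of how they might be driven, with a hypothetical data path:

    from utils.data_utils import (get_dataset, non_iid_balanced,
                                  non_iid_unbalanced_dataidx_map)

    dset_obj = get_dataset("cifar10", "./data")

    # Unbalanced shard sizes: one index list per party.
    dataidx_map = non_iid_unbalanced_dataidx_map(dset_obj, n_parties=4, beta=0.4)

    # Balanced shard sizes (500 samples each) with a skewed class mix.
    clnt_y, clnt_idx, cls_priors = non_iid_balanced(dset_obj, n_client=4,
                                                    n_data_per_clnt=500, alpha=0.4)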
diff --git a/src/utils/distrib_utils.py b/src/utils/distrib_utils.py
index fadd8a0..be8a056 100644
--- a/src/utils/distrib_utils.py
+++ b/src/utils/distrib_utils.py
@@ -1,26 +1,41 @@
+"""
+This module provides utility classes and functions for distributed learning.
+"""
+
 import numpy as np
 import torch
-import torch.nn as nn
+from torch import nn
 from torch.nn.parallel import DataParallel
+from torch.utils.data import Subset, DataLoader

-from resnet import ResNet34, ResNet18, ResNet50
-
-from torch.utils.data import Subset
-from torch.utils.data import DataLoader
-
+# Factory names follow the renamed lowercase API in src/resnet.py.
+from resnet import resnet18, resnet34, resnet50
 from utils.data_utils import extr_noniid

-
 def load_weights(model_dir: str, model: nn.Module, client_num: int):
+    """
+    Load weights for the given model and client number from the specified directory.
+
+    Args:
+        model_dir (str): Directory where the model weights are stored.
+        model (nn.Module): Model to load the weights into.
+        client_num (int): Client number to identify which weights to load.
+
+    Returns:
+        nn.Module: Model with loaded weights.
+    """
+    wts = torch.load(f"{model_dir}/saved_models/c{client_num}.pt")
     model.load_state_dict(wts)
     print(f"successfully loaded checkpoint for client {client_num}")
     return model

-
-class ServerObj():
+class ServerObj:
+    """
+    Server object for federated learning.
+    """
     def __init__(self, config, obj, rank) -> None:
-        self.num_users, self.samples_per_client = config["num_users"], config["samples_per_client"]
-        self.device, self.device_id = obj["device"], obj["device_id"]
+        self.num_users = config["num_users"]
+        self.samples_per_client = config["samples_per_client"]
+        self.device = obj["device"]
+        self.device_id = obj["device_id"]
         test_dataset = obj["dset_obj"].test_dset
         batch_size = config["batch_size"]
         num_channels = obj["dset_obj"].num_channels
@@ -29,84 +44,38 @@ def __init__(self, config, obj, rank) -> None:
         model_dict = {
-            "ResNet18": ResNet18(num_channels),
-            "ResNet34": ResNet34(num_channels),
-            "ResNet50": ResNet50(num_channels)}
+            "ResNet18": resnet18(num_channels),
+            "ResNet34": resnet34(num_channels),
+            "ResNet50": resnet50(num_channels)
+        }
         model = model_dict[config["model"]]
         self.model = model.to(self.device)

-
-class ClientObj():
+class ClientObj:
+    """
+    Client object for federated learning.
+    """
     def __init__(self, config, obj, rank) -> None:
-        self.num_users, self.samples_per_client = config["num_users"], config["samples_per_client"]
-        self.device, self.device_id = obj["device"], obj["device_id"]
-        train_dataset, test_dataset = obj["dset_obj"].train_dset, obj["dset_obj"].test_dset
-        batch_size, lr = config["batch_size"], config["model_lr"]
+        self.num_users = config["num_users"]
+        self.samples_per_client = config["samples_per_client"]
+        self.device = obj["device"]
+        self.device_id = obj["device_id"]
+        train_dataset = obj["dset_obj"].train_dset
+        test_dataset = obj["dset_obj"].test_dset
+        batch_size = config["batch_size"]
+        lr = config["model_lr"]
         self.test_loader = DataLoader(test_dataset, batch_size=batch_size)
         indices = np.random.permutation(len(train_dataset))

         optim = torch.optim.Adam
-        self.model = ResNet34()
-        self.model = DataParallel(
-            self.model.to(
-                self.device),
-            device_ids=self.device_id)
+        self.model = resnet34()
+        self.model = DataParallel(self.model.to(self.device), device_ids=self.device_id)
         self.optim = optim(self.model.parameters(), lr=lr)
         self.loss_fn = nn.CrossEntropyLoss()
-        if "non_iid" in self.config["exp_type"]:
+        if "non_iid" in config["exp_type"]:
             perm = torch.randperm(10)
             sp = [(0, 2), (2, 4)]
-            self.c_dset = extr_noniid(train_dataset,
-                                      config["samples_per_user"],
-                                      perm[sp[rank - 1][0]:sp[rank - 1][1]])  #TODO: Not clear if rank is the correct index
+            self.c_dset = extr_noniid(train_dataset, config["samples_per_user"], perm[sp[rank - 1][0]:sp[rank - 1][1]])
         else:
-            # rank-1 because rank 0 is the server
-            self.c_dset = Subset(train_dataset, indices[(
-                rank - 1) * self.samples_per_client:rank * self.samples_per_client])
+            # rank - 1 because rank 0 is the server
+            self.c_dset = Subset(train_dataset, indices[(rank - 1) * self.samples_per_client:rank * self.samples_per_client])
         self.c_dloader = DataLoader(self.c_dset, batch_size=batch_size)
-
-
-# class WebObj():
-#     def __init__(self, config, obj, rank) -> None:
-#         """ The purpose of this class is to bootstrap the objects for the whole distributed training
-#         setup
-#         """
-#         self.num_users, self.samples_per_client = config["num_users"], config["samples_per_client"]
-#         self.device, self.device_id = obj["device"], obj["device_id"]
-#         train_dataset, test_dataset = obj["dset_obj"].train_dset, obj["dset_obj"].test_dset
-#         batch_size, lr = config["batch_size"], config["model_lr"]

-#         # train_loader = DataLoader(train_dataset, batch_size=batch_size)
-#         self.test_loader = DataLoader(test_dataset, batch_size=batch_size)
-#         indices = np.random.permutation(len(train_dataset))

-#         optim = torch.optim.Adam
-#         self.c_models = []
-#         self.c_optims = []
-#         self.c_dsets = []
-#         self.c_dloaders = []

-#         for i in range(self.num_users):
-#             model = ResNet34()
-#             if config["load_existing"]:
-#                 model = load_weights(config["results_path"], model, i)
-#             c_model = nn.DataParallel(model.to(self.device), device_ids=self.device_ids)
-#             c_optim = optim(c_model.parameters(), lr=lr)
-#             if config["exp_type"].startswith("non_iid"):
-#                 if i == 0:
-#                     # only need to call this func once since it returns all user_groups
-#                     user_groups_train, user_groups_test = cifar_extr_noniid(train_dataset, test_dataset,
-#                                                                             config["num_users"], config["class_per_client"],
-#                                                                             config["samples_per_client"], rate_unbalance=1)
-#                 c_dset = Subset(train_dataset, user_groups_train[i].astype(int))
-#             else:
-#                 c_idx = indices[i*self.samples_per_client: (i+1)*self.samples_per_client]
-#                 c_dset = Subset(train_dataset, c_idx)

-#             c_dloader = DataLoader(c_dset, batch_size=batch_size*len(self.device_ids), shuffle=True)

-#             self.c_models.append(c_model)
-#             self.c_optims.append(c_optim)
-#             self.c_dsets.append(c_dset)
-#             self.c_dloaders.append(c_dloader)
-#             print(f"Client {i} initialized")
diff --git a/src/utils/log_utils.py b/src/utils/log_utils.py
index b9957fd..3957f9d 100644
--- a/src/utils/log_utils.py
+++ b/src/utils/log_utils.py
@@ -1,20 +1,32 @@
+"""
+This module provides utility functions and classes for handling logging,
+copying source code, and normalizing images in a distributed learning setting.
+"""
+
 import os
-import pickle
 import shutil
 import logging
+import sys
+from glob import glob
+from shutil import copytree, copy2
+from PIL import Image

 import torch
 import torchvision.transforms as T
 from torchvision.utils import make_grid, save_image
 from tensorboardX import SummaryWriter
-from shutil import copytree, copy2
-from glob import glob
-from PIL import Image
 import numpy as np

-# Normalize an image

 def deprocess(img):
+    """
+    Deprocesses an image tensor by normalizing it to the original range.
+
+    Args:
+        img (torch.Tensor): Image tensor to deprocess.
+
+    Returns:
+        torch.Tensor: Deprocessed image tensor.
+    """
     inv_normalize = T.Normalize(
         mean=[-0.485 / 0.229, -0.456 / 0.224, -0.406 / 0.225],
         std=[1 / 0.229, 1 / 0.224, 1 / 0.225],
@@ -25,85 +37,85 @@ def deprocess(img):


 def check_and_create_path(path):
+    """
+    Checks if the specified path exists and prompts the user for action if it does.
+    Creates the directory if it does not exist.
+
+    Args:
+        path (str): Path to check and create if necessary.
+    """
     if os.path.isdir(path):
-        print("Experiment in {} already present".format(path))
+        print(f"Experiment in {path} already present")
         done = False
         while not done:
             inp = input("Press e to exit, r to replace it: ")
             if inp == "e":
-                exit()
+                sys.exit()
             elif inp == "r":
                 done = True
                 shutil.rmtree(path)
                 os.makedirs(path)
             else:
                 print("Input not understood")
-                # exit()
     else:
         os.makedirs(path)


 def copy_source_code(config: dict) -> None:
-    """Copy source code to experiment folder
-    This happens only once at the start of the experiment
-    This is to ensure that the source code is snapshoted at the start of the experiment
-    for reproducibility purposes
+    """
+    Copy source code to experiment folder for reproducibility.

     Args:
-        config (dict): [description]
+        config (dict): Configuration dictionary with the results path.
""" path = config["results_path"] print("exp path:", path) if config["load_existing"]: print("Continue with loading checkpoint") return - else: - # throw a prompt - check_and_create_path(path) - # the last folder is the path where all the expts are stored - denylist = [ - "./__pycache__/", - "./.ipynb_checkpoints/", - "./expt_dump/", - "./helper_scripts/", - "./imgs/", - "./expt_dump_old/", - "./comparison_plots/", - "./toy_exp/", - "./toy_exp_ml/", - "./toy_exp.py", - "./toy_exp_ml.py" "/".join(path.split("/")[:-1]) + "/", - ] - folders = glob(r"./*/") - print(denylist, folders) - - # For copying python files - for file_ in glob(r"./*.py"): - copy2(file_, path) - - # For copying json files - for file_ in glob(r"./*.json"): - copy2(file_, path) - - for folder in folders: - if folder not in denylist: - # Remove first char which is . due to the glob - copytree(folder, path + folder[1:]) - - # For saving models in the future - os.mkdir(config["saved_models"]) - os.mkdir(config["log_path"]) - print("source code copied to exp_dump") + check_and_create_path(path) + denylist = [ + "./__pycache__/", + "./.ipynb_checkpoints/", + "./expt_dump/", + "./helper_scripts/", + "./imgs/", + "./expt_dump_old/", + "./comparison_plots/", + "./toy_exp/", + "./toy_exp_ml/", + "./toy_exp.py", + "./toy_exp_ml.py", + "/".join(path.split("/")[:-1]) + "/", + ] + folders = glob(r"./*/") + print(denylist, folders) + + for file_ in glob(r"./*.py"): + copy2(file_, path) + for file_ in glob(r"./*.json"): + copy2(file_, path) + for folder in folders: + if folder not in denylist: + copytree(folder, path + folder[1:]) + os.mkdir(config["saved_models"]) + os.mkdir(config["log_path"]) + print("source code copied to exp_dump") class LogUtils: + """ + Utility class for logging and saving experiment data. + """ def __init__(self, config) -> None: - log_dir, load_existing = config["log_path"], config["load_existing"] + log_dir = config["log_path"] + load_existing = config["load_existing"] log_format = ( "%(asctime)s::%(levelname)s::%(name)s::" "%(filename)s::%(lineno)d::%(message)s" ) logging.basicConfig( - filename="{log_path}/log_console.log".format(log_path=log_dir), + filename=f"{log_dir}/log_console.log", level="DEBUG", format=log_format, ) @@ -114,42 +126,94 @@ def __init__(self, config) -> None: self.init_summary() def init_summary(self): - # Open a txt file to write summary - self.summary_file = open(f"{self.log_dir}/summary.txt", "w") + """ + Initialize summary file for logging. + """ + self.summary_file = open(f"{self.log_dir}/summary.txt", "w", encoding="utf-8") def init_tb(self, load_existing): - tb_path = self.log_dir + "/tensorboard" - # if not os.path.exists(tb_path) or not os.path.isdir(tb_path): + """ + Initialize TensorBoard logging. + + Args: + load_existing (bool): Whether to load existing logs. + """ + tb_path = f"{self.log_dir}/tensorboard" if not load_existing: os.makedirs(tb_path) self.writer = SummaryWriter(tb_path) def init_npy(self): - npy_path = self.log_dir + "/npy" + """ + Initialize directory for saving numpy arrays. + """ + npy_path = f"{self.log_dir}/npy" if not os.path.exists(npy_path) or not os.path.isdir(npy_path): os.makedirs(npy_path) def log_image(self, imgs: torch.Tensor, key, iteration): - # imgs = deprocess(imgs.detach().cpu())[:64] + """ + Log image to file and TensorBoard. + + Args: + imgs (torch.Tensor): Tensor of images to log. + key (str): Key for the logged image. + iteration (int): Current iteration number. 
+ """ grid_img = make_grid(imgs.detach().cpu(), normalize=True, scale_each=True) - # Save the grid image using torchvision api save_image(grid_img, f"{self.log_dir}/{iteration}_{key}.png") - # Save the grid image using tensorboard api self.writer.add_image(key, grid_img.numpy(), iteration) def log_console(self, msg): + """ + Log a message to the console. + + Args: + msg (str): Message to log. + """ logging.info(msg) def log_tb(self, key, value, iteration): + """ + Log a scalar value to TensorBoard. + + Args: + key (str): Key for the logged value. + value (float): Value to log. + iteration (int): Current iteration number. + """ self.writer.add_scalar(key, value, iteration) def log_npy(self, key, value): + """ + Save a numpy array to file. + + Args: + key (str): Key for the saved array. + value (numpy.ndarray): Array to save. + """ np.save(f"{self.log_dir}/npy/{key}.npy", value) def log_max_stats_per_client(self, stats_per_client, round_step, metric): + """ + Log maximum statistics per client. + + Args: + stats_per_client (numpy.ndarray): Statistics for each client. + round_step (int): Step size for rounds. + metric (str): Metric being logged. + """ self.__log_stats_per_client__(stats_per_client, round_step, metric, is_max=True) def log_min_stats_per_client(self, stats_per_client, round_step, metric): + """ + Log minimum statistics per client. + + Args: + stats_per_client (numpy.ndarray): Statistics for each client. + round_step (int): Step size for rounds. + metric (str): Metric being logged. + """ self.__log_stats_per_client__( stats_per_client, round_step, metric, is_max=False ) @@ -157,6 +221,15 @@ def log_min_stats_per_client(self, stats_per_client, round_step, metric): def __log_stats_per_client__( self, stats_per_client, round_step, metric, is_max=False ): + """ + Internal method to log statistics per client. + + Args: + stats_per_client (numpy.ndarray): Statistics for each client. + round_step (int): Step size for rounds. + metric (str): Metric being logged. + is_max (bool): Whether to log maximum or minimum statistics. + """ if is_max: best_round_per_client = np.argmax(stats_per_client, axis=1) * round_step best_val_per_client = np.max(stats_per_client, axis=1) @@ -165,7 +238,6 @@ def __log_stats_per_client__( best_val_per_client = np.min(stats_per_client, axis=1) minmax = "max" if is_max else "min" - # Write to summary file self.summary_file.write( f"============== {minmax} {metric} per client ==============\n" ) @@ -173,25 +245,38 @@ def __log_stats_per_client__( zip(best_round_per_client, best_val_per_client) ): self.summary_file.write( - f"Client {client_idx+1} : {best_val} at round {best_round}\n" + f"Client {client_idx + 1} : {best_val} at round {best_round}\n" ) self.summary_file.write( f"Mean of {minmax} {metric} : {np.mean(best_val_per_client)}, quantiles: {np.quantile(best_val_per_client, [0.25, 0.75])}\n" ) - def log_tb_round_stats(self, round_stats, stats_to_exclude, round): + def log_tb_round_stats(self, round_stats, stats_to_exclude, current_round): + """ + Log round statistics to TensorBoard. + + Args: + round_stats (list): List of round statistics for each client. + stats_to_exclude (list): List of statistics keys to exclude from logging. + current_round (int): Current round number. 
+ """ stats_key = round_stats[0].keys() for key in stats_key: if key not in stats_to_exclude: average = 0 for client_id, stats in enumerate(round_stats, 1): - self.log_tb(f"{key}/client{client_id}", stats[key], round) + self.log_tb(f"{key}/client{client_id}", stats[key], current_round) average += stats[key] average /= len(round_stats) - self.log_tb(f"{key}/clients", average, round) + self.log_tb(f"{key}/clients", average, current_round) - def log_experiments_stats(self, gloabl_stats): + def log_experiments_stats(self, global_stats): + """ + Log experiment statistics. + Args: + global_stats (dict): Dictionary of global statistics. + """ basic_stats = { "train_loss": "min", "train_acc": "max", @@ -202,7 +287,7 @@ def log_experiments_stats(self, gloabl_stats): "validation_acc": "max", } - for key, stats in gloabl_stats.items(): + for key, stats in global_stats.items(): if key == "round_step": continue self.log_npy(key.lower().replace(" ", "_"), stats)
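
Taken together, LogUtils assumes the config provides "log_path" and "load_existing", and that the log directory already exists. A minimal sketch of the intended lifecycle, with hypothetical paths and values:

    import os
    import numpy as np
    from utils.log_utils import LogUtils

    config = {"log_path": "./expt_dump/demo/logs", "load_existing": False}
    os.makedirs(config["log_path"], exist_ok=True)

    logger = LogUtils(config)
    logger.log_console("starting round 0")
    logger.log_tb("train_loss/client1", 0.42, 0)
    logger.log_npy("round_stats", np.zeros(3))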