Merge pull request #6 from dasdebanna/linting
Updated
tremblerz authored Jul 7, 2024
2 parents f8b4a2a + b56329b commit c7db5eb
Showing 11 changed files with 1,053 additions and 1,122 deletions.
222 changes: 108 additions & 114 deletions src/algos/DisPFL.py

Large diffs are not rendered by default.

133 changes: 56 additions & 77 deletions src/algos/def_kt.py
@@ -1,14 +1,19 @@
-from collections import OrderedDict
-from typing import Any, Dict, List
-from torch import Tensor, cat
+"""
+This module defines the DefKTClient and DefKTServer classes for federated learning using a knowledge
+transfer approach.
+"""
+
 import copy
-import torch.nn as nn
 import random
+from collections import OrderedDict
+from typing import List
+from torch import Tensor
+import torch.nn as nn
 
 from algos.base_class import BaseClient, BaseServer
 
 
-class CommProtocol(object):
+class CommProtocol:
     """
     Communication protocol tags for the server and clients
     """
@@ -20,14 +25,16 @@ class CommProtocol(object):
 
 
 class DefKTClient(BaseClient):
+    """
+    Client class for DefKT (Deep Mutual Learning with Knowledge Transfer)
+    """
     def __init__(self, config) -> None:
         super().__init__(config)
         self.config = config
         self.tag = CommProtocol
-        self.model_save_path = "{}/saved_models/node_{}.pt".format(
-            self.config["results_path"], self.node_id
-        )
+        self.model_save_path = f"{self.config['results_path']}/saved_models/node_{self.node_id}.pt"
         self.server_node = 1  # leader node
+        self.best_acc = 0.0  # Initialize best accuracy attribute
         if self.node_id == 1:
             self.num_users = config["num_users"]
             self.clients = list(range(2, self.num_users + 1))
@@ -39,8 +46,6 @@ def local_train(self):
         avg_loss = self.model_utils.train(
             self.model, self.optim, self.dloader, self.loss_fn, self.device
         )
-        # print("Client {} finished training with loss {}".format(self.node_id, avg_loss))
-        # self.log_utils.logger.log_tb(f"train_loss/client{client_num}", avg_loss, epoch)
 
     def local_test(self, **kwargs):
         """
@@ -49,20 +54,18 @@ def local_test(self, **kwargs):
         test_loss, acc = self.model_utils.test(
             self.model, self._test_loader, self.loss_fn, self.device
         )
-        # TODO save the model if the accuracy is better than the best accuracy
-        # so far
         if acc > self.best_acc:
             self.best_acc = acc
             self.model_utils.save_model(self.model, self.model_save_path)
         return acc
 
     def deep_mutual_train(self, teacher_repr):
         """
-        Train the model locally
+        Train the model locally with deep mutual learning
         """
         teacher_model = copy.deepcopy(self.model)
         teacher_model.load_state_dict(teacher_repr)
-        print("Deep mutual learning at student Node {}".format(self.node_id))
+        print(f"Deep mutual learning at student Node {self.node_id}")
         avg_loss, acc = self.model_utils.deep_mutual_train(
             [self.model, teacher_model], self.optim, self.dloader, self.device
         )
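
model_utils.deep_mutual_train is called here but its body is not part of this diff. As a hedged sketch, a deep mutual learning step in the standard formulation trains each model with cross-entropy plus a KL term toward its peer's predictions; the helper below is hypothetical, not this repository's implementation:

import torch.nn.functional as F

def deep_mutual_step(student, teacher, optimizer, x, y):
    # One student update of deep mutual learning (sketch): supervised
    # cross-entropy plus KL divergence toward the teacher's predictions.
    s_logits = student(x)
    t_logits = teacher(x)
    loss = F.cross_entropy(s_logits, y) + F.kl_div(
        F.log_softmax(s_logits, dim=1),
        F.softmax(t_logits.detach(), dim=1),
        reduction="batchmean",
    )
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
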
@@ -80,9 +83,9 @@ def set_representation(self, representation: OrderedDict[str, Tensor]):
         self.model.load_state_dict(representation)
 
     def fed_avg(self, model_wts: List[OrderedDict[str, Tensor]]):
-        # All models are sampled currently at every round
-        # Each model is assumed to have equal amount of data and hence
-        # coeff is same for everyone
+        """
+        Federated averaging of model weights
+        """
         num_users = len(model_wts)
         coeff = 1 / num_users
         avgd_wts = OrderedDict()
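
The averaging loop itself falls outside this hunk; given that the replaced comments state every client gets the same coefficient, the rest of fed_avg presumably resembles this sketch (an assumption, not rendered code):

from collections import OrderedDict

def fed_avg(model_wts):
    # Uniform federated averaging over a list of model state dicts (sketch).
    num_users = len(model_wts)
    coeff = 1 / num_users
    avgd_wts = OrderedDict()
    for key in model_wts[0].keys():
        avgd_wts[key] = sum(coeff * wts[key] for wts in model_wts)
    return avgd_wts
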
@@ -106,26 +109,28 @@ def aggregate(self, representation_list: List[OrderedDict[str, Tensor]]):
 
     def send_representations(self, representation):
         """
-        Set the model
+        Send the model representations to the clients
         """
         for client_node in self.clients:
             self.comm_utils.send_signal(client_node, representation, self.tag.UPDATES)
-        print("Node 1 sent average weight to {} nodes".format(len(self.clients)))
+        print(f"Node 1 sent average weight to {len(self.clients)} nodes")
 
     def single_round(self, self_repr):
         """
-        Runs the whole training procedure
+        Runs a single training round
        """
         print("Node 1 waiting for all clients to finish")
         reprs = self.comm_utils.wait_for_all_clients(self.clients, self.tag.DONE)
         reprs.append(self_repr)
-        print("Node 1 received {} clients' weights".format(len(reprs)))
+        print(f"Node 1 received {len(reprs)} clients' weights")
         avg_wts = self.aggregate(reprs)
         self.send_representations(avg_wts)
         return avg_wts
-        # wait for all clients to finish
 
     def assign_own_status(self, status):
+        """
+        Assign the status (teacher/student) to the client
+        """
         if self.node_id in status[0]:
             self.status = "teacher"
             index = status[0].index(self.node_id)
@@ -137,81 +142,58 @@ def assign_own_status(self, status):
         else:
             self.status = None
             self.pair_id = None
-        print(
-            "Node {} is a {}, pair with {}".format(
-                self.node_id, self.status, self.pair_id
-            )
-        )
+        print(f"Node {self.node_id} is a {self.status}, pair with {self.pair_id}")
 
     def run_protocol(self):
+        """
+        Runs the entire training protocol
+        """
         start_epochs = self.config.get("start_epochs", 0)
         total_epochs = self.config["epochs"]
-        for round in range(start_epochs, total_epochs):
-            # self.log_utils.logging.info("Client waiting for semaphore from {}".format(self.server_node))
-            # print("Client waiting for semaphore from {}".format(self.server_node))
+        for epoch in range(start_epochs, total_epochs):
            status = self.comm_utils.wait_for_signal(src=0, tag=self.tag.START)
             self.assign_own_status(status)
-            # print("semaphore received, start local training")
-            # self.log_utils.logging.info("Client received semaphore from {}".format(self.server_node))
             if self.status == "teacher":
                 self.local_train()
-                # self.local_test()
                 self_repr = self.get_representation()
                 self.comm_utils.send_signal(
                     dest=self.pair_id, data=self_repr, tag=self.tag.DONE
                 )
-                print(
-                    "Node {} sent repr to student node {}".format(
-                        self.node_id, self.pair_id
-                    )
-                )
-                # self.log_utils.logging.info("Client {} sending done signal to {}".format(self.node_id, self.server_node))
-                # print("sending signal to node {}".format(self.server_node))
+                print(f"Node {self.node_id} sent repr to student node {self.pair_id}")
             elif self.status == "student":
                 teacher_repr = self.comm_utils.wait_for_signal(
                     src=self.pair_id, tag=self.tag.DONE
                 )
-                print(
-                    "Node {} received repr from teacher node {}".format(
-                        self.node_id, self.pair_id
-                    )
-                )
+                print(f"Node {self.node_id} received repr from teacher node {self.pair_id}")
                 self.deep_mutual_train(teacher_repr)
             else:
-                # self.comm_utils.send_signal(dest=self.server_node, data=self_repr, tag=self.tag.DONE)
-                print("Node {} do nothing".format(self.node_id))
-                # repr = self.comm_utils.wait_for_signal(src=self.server_node, tag=self.tag.UPDATES)
-
-                # self.set_representation(repr)
-                # test updated model
+                print(f"Node {self.node_id} do nothing")
             acc = self.local_test()
-            print("Node {} test_acc:{:.4f}".format(self.node_id, acc))
+            print(f"Node {self.node_id} test_acc:{acc:.4f}")
             self.comm_utils.send_signal(dest=0, data=acc, tag=self.tag.FINISH)
 
 
 class DefKTServer(BaseServer):
+    """
+    Server class for DefKT (Deep Mutual Learning with Knowledge Transfer)
+    """
     def __init__(self, config) -> None:
         super().__init__(config)
-        # self.set_parameters()
         self.config = config
         self.set_model_parameters(config)
         self.tag = CommProtocol
-        self.model_save_path = "{}/saved_models/node_{}.pt".format(
-            self.config["results_path"], self.node_id
-        )
+        self.model_save_path = f"{self.config['results_path']}/saved_models/node_{self.node_id}.pt"
+        self.best_acc = 0.0  # Initialize best accuracy attribute
 
     def send_representations(self, representations):
         """
-        Set the model
+        Send the model representations to the clients
         """
         for client_node in self.users:
             self.comm_utils.send_signal(client_node, representations, self.tag.UPDATES)
         self.log_utils.log_console(
-            "Server sent {} representations to node {}".format(
-                len(representations), client_node
-            )
+            f"Server sent {len(representations)} representations to node {client_node}"
         )
-        # self.model.load_state_dict(representation)
 
     def test(self) -> float:
         """
@@ -220,52 +202,49 @@ def test(self) -> float:
         test_loss, acc = self.model_utils.test(
             self.model, self._test_loader, self.loss_fn, self.device
         )
-        # TODO save the model if the accuracy is better than the best accuracy
-        # so far
         if acc > self.best_acc:
             self.best_acc = acc
             self.model_utils.save_model(self.model, self.model_save_path)
         return acc
 
     def assigns_clients(self):
+        """
+        Assigns clients as teachers and students
+        """
         num_teachers = self.config["num_teachers"]
         clients = list(range(1, self.num_users + 1))
         if 2 * num_teachers > self.num_users:
             return None  # Not enough room to pick two non-overlapping subarrays
 
         # Pick the starting index of the first subarray
         selected_indices = random.sample(range(self.num_users), 2 * num_teachers)
         selected_elements = [clients[i] for i in selected_indices]
 
         # Divide the selected elements into two arrays of length num_teachers
         teachers = selected_elements[:num_teachers]
         students = selected_elements[num_teachers:]
 
         return teachers, students
 
     def single_round(self):
         """
-        Runs the whole training procedure
+        Runs a single training round
         """
         teachers, students = self.assigns_clients()  # type: ignore
-        self.log_utils.log_console("Teachers:{}".format(teachers))
-        self.log_utils.log_console("Students:{}".format(students))
+        self.log_utils.log_console(f"Teachers:{teachers}")
+        self.log_utils.log_console(f"Students:{students}")
         for client_node in self.users:
             self.log_utils.log_console(
-                "Server sending status from {} to {}".format(self.node_id, client_node)
+                f"Server sending status from {self.node_id} to {client_node}"
             )
             self.comm_utils.send_signal(
                 dest=client_node, data=[teachers, students], tag=self.tag.START
             )
 
     def run_protocol(self):
+        """
+        Runs the entire training protocol
+        """
         self.log_utils.log_console("Starting iid clients federated averaging")
         start_epochs = self.config.get("start_epochs", 0)
         total_epochs = self.config["epochs"]
-        for round in range(start_epochs, total_epochs):
-            self.round = round
-            self.log_utils.log_console("Starting round {}".format(round))
+        for epoch in range(start_epochs, total_epochs):
+            self.log_utils.log_console(f"Starting round {epoch}")
             self.single_round()
 
             accs = self.comm_utils.wait_for_all_clients(self.users, self.tag.FINISH)
-            self.log_utils.log_console("Round {} done; acc {}".format(round, accs))
+            self.log_utils.log_console(f"Round {epoch} done; acc {accs}")
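
Note that assigns_clients is pure selection logic, so it can be sanity-checked outside the server; a minimal standalone harness mirroring the code above (the function name and example output are hypothetical):

import random

def split_teachers_students(num_users, num_teachers):
    # Mirrors DefKTServer.assigns_clients: sample 2*num_teachers distinct
    # client ids, then split the sample into teacher and student halves.
    clients = list(range(1, num_users + 1))
    if 2 * num_teachers > num_users:
        return None  # not enough clients for disjoint teacher/student sets
    selected_indices = random.sample(range(num_users), 2 * num_teachers)
    selected_elements = [clients[i] for i in selected_indices]
    return selected_elements[:num_teachers], selected_elements[num_teachers:]

print(split_teachers_students(8, 3))  # e.g. ([5, 2, 7], [1, 8, 4])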