Skip to content

Commit

Permalink
fix train and test methods
Browse files Browse the repository at this point in the history
  • Loading branch information
gautamjajoo committed Jun 12, 2024
1 parent 8de0b39 commit 9190bbd
Showing 1 changed file with 8 additions and 7 deletions.
15 changes: 8 additions & 7 deletions src/algos/swarm.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,11 @@ def local_train(self):
"""
Train the model locally
"""
avg_loss = self.model_utils.train(self.model, self.optim,
avg_loss, acc = self.model_utils.train(self.model, self.optim,
self.dloader, self.loss_fn,
self.device)
# print("Client {} finished training with loss {}".format(self.node_id, avg_loss))
print("Client {} finished training with loss {}".format(self.node_id, avg_loss))
return acc
# self.log_utils.logger.log_tb(f"train_loss/client{client_num}", avg_loss, epoch)

def local_test(self, **kwargs):
Expand Down Expand Up @@ -118,15 +119,17 @@ def single_round(self,self_repr):
def run_protocol(self):
start_epochs = self.config.get("start_epochs", 0)
total_epochs = self.config["epochs"]
test_accs = np.empty((self.num_clients, total_epochs)) # Transpose the shape
train_accs = np.empty((self.num_clients, total_epochs)) # Transpose the shape

for round in range(start_epochs, total_epochs):
#self.log_utils.logging.info("Client waiting for semaphore from {}".format(self.server_node))
#print("Client waiting for semaphore from {}".format(self.server_node))
self.comm_utils.wait_for_signal(src=0, tag=self.tag.START)
#print("semaphore received, start local training")
# self.log_utils.logging.info("Client received semaphore from {}".format(self.server_node))
self.local_train()
train_acc = self.local_train()
train_accs[self.node_id-1, round] = train_acc
np.save('./train_accs.npy', train_accs)
#self.local_test()
self_repr = self.get_representation()
# self.log_utils.logging.info("Client {} sending done signal to {}".format(self.node_id, self.server_node))
Expand All @@ -144,8 +147,6 @@ def run_protocol(self):
acc = self.local_test()
print("Node {} test_acc:{:.4f}".format(self.node_id, acc))
self.comm_utils.send_signal(dest=0, data=acc, tag=self.tag.FINISH)
test_accs[self.node_id-1, round] = acc
np.save('./test_accs.npy', test_accs)

class SWARMServer(BaseServer):
def __init__(self, config) -> None:
Expand Down Expand Up @@ -211,4 +212,4 @@ def run_protocol(self):
for i, acc in enumerate(accs):
train_accs[i, round] = acc

np.save('./train_accs.npy', train_accs)
np.save('./test_accs.npy', train_accs)

0 comments on commit 9190bbd

Please sign in to comment.