main_OFCL.py
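"""Main training script for OFCL (federated continual learning).

For each run, every client trains online on its private stream of tasks,
optionally with a replay memory. At regular intervals the models of the
eligible clients are averaged into a global model (FedAvg or a weighted
variant), which is then pushed back to those clients. Per-task accuracy,
forgetting and balanced accuracy are logged per client and saved at the end.
"""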
from datetime import datetime

import numpy as np
import torch

from configuration import config_FCL
from utils.data_loader import get_loader_all_clients
from utils.train_utils import (FedAvg, get_free_gpu_idx, get_logger,
                               initialize_clients, save_results,
                               test_global_model, weightedFedAvg)
args = config_FCL.base_parser()
logger = get_logger(args)
if torch.cuda.is_available():
    gpu_idx = get_free_gpu_idx()
    args.cuda = True
    args.device = f'cuda:{gpu_idx}'
else:
    args.device = 'cpu'
print(args)
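
# Flags read from config_FCL.base_parser() that drive the loop below:
# n_runs, n_clients and n_tasks size the experiment; with_memory toggles
# replay-buffer training; burnin and jump control when a client joins a
# communication round; fl_update selects the aggregation rule ('w_' prefix
# for the weighted variant); device/cuda were set just above.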
for run in range(args.n_runs):
    loader_clients, cls_assignment_list, global_test_loader = get_loader_all_clients(args, run)
    clients = initialize_clients(args, loader_clients, cls_assignment_list, run)
    start_time = datetime.now()
    global_model = None  # set at the first communication round
    while not all(client.train_completed for client in clients):
        for client in clients:
            if not client.train_completed:
                samples, labels = client.get_next_batch()
                if samples is not None:
                    if args.with_memory:
                        if client.task_id == 0:
                            client.train_with_update(samples, labels)
                        else:
                            client.train_with_memory(samples, labels)
                    else:
                        client.train(samples, labels)
                else:
                    # end of the current task's stream: evaluate, then advance
                    print(f'Run {run} - Client {client.client_id} - Task {client.task_id} completed - {client.get_current_task()}')
                    # compute training loss
                    logger = client.compute_loss(logger, run)
                    print(f'Run {run} - Client {client.client_id} - Test time - Task {client.task_id}')
                    logger = client.test(logger, run)
                    logger = client.validation(logger, run)
                    logger = client.forgetting(logger, run)
                    if client.task_id + 1 >= args.n_tasks:
                        client.train_completed = True
                        print(f'Run {run} - Client {client.client_id} - Train completed')
                        logger = client.balanced_accuracy(logger, run)
                    else:
                        client.task_id += 1
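        # Note on the training variants above (the client methods are defined
        # elsewhere): with a replay memory, train_with_update presumably trains on
        # the incoming batch while populating the buffer during the first task, and
        # train_with_memory presumably rehearses buffered samples alongside the
        # stream on later tasks; plain train() is the memory-free baseline.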
        # COMMUNICATION ROUND
        # a client takes part once it is past burn-in and then every `jump` mini-batches
        selected_clients = [client.client_id for client in clients
                            if client.num_batches >= args.burnin
                            and client.num_batches % args.jump == 0
                            and not client.train_completed]
        if len(selected_clients) > 1:
            if args.fl_update.startswith('w_'):
                global_model = weightedFedAvg(args, selected_clients, clients)
            else:
                global_model = FedAvg(args, selected_clients, clients)
            global_parameters = global_model.state_dict()
            # replace each selected client's local parameters with the averaged ones
            for client_id in selected_clients:
                clients[client_id].save_last_local_model()
                clients[client_id].update_parameters(global_parameters)
                clients[client_id].save_last_global_model(global_model)
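        # A rough sketch of the aggregation rule (FedAvg / weightedFedAvg live in
        # utils.train_utils and are not shown here): plain FedAvg presumably sets
        #     theta_global = (1 / |S|) * sum_{i in S} theta_i
        # over the selected clients S, while the 'w_'-prefixed variant weights each
        # client's parameters non-uniformly (e.g. by how much data it contributed).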
    end_time = datetime.now()
    print(f'Duration: {end_time - start_time}')
    # global model accuracy once all clients have finished all tasks (as in FedCIL, ICLR 2023)
    if global_model is not None:
        logger = test_global_model(args, global_test_loader, global_model, logger, run)
    final_accs = []
    final_fors = []
    for client_id in range(args.n_clients):
        print(f'Client {client_id}: {clients[client_id].task_list}')
        print(np.mean(logger['test']['acc'][client_id], 0))
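        # logger['test']['acc'][client_id] appears to average (over axis 0) into a
        # (training task x evaluated task) matrix, so row n_tasks-1 holds per-task
        # accuracy after the last task and its mean is the final average accuracy.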
        final_acc = np.mean(np.mean(logger["test"]["acc"][client_id], 0)[args.n_tasks - 1, :], 0)
        final_for = np.mean(logger["test"]["forget"][client_id])
        final_accs.append(final_acc)
        final_fors.append(final_for)
        print(f'Final client accuracy: {final_acc}')
        print(f'Final client forgetting: {final_for}')
        print(f'Final client balanced accuracy: {np.mean(logger["test"]["bal_acc"][client_id])}')
        print()
    print(f'Final average accuracy: {np.mean(final_accs):0.4f} (+-) {np.std(final_accs):0.4f}')
    print(f'Final average forgetting: {np.mean(final_fors):0.4f} (+-) {np.std(final_fors):0.4f}')
    print()

# save training results
save_results(args, logger)