-
Notifications
You must be signed in to change notification settings - Fork 2
/
utils.py
70 lines (54 loc) · 1.79 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import numpy as np
import torch
import torch.optim as optim
# K-fold cross validataion
# Developer: Alejandro Debus
# Email: aledebus@gmail.com
def partitions(number, k):
    '''
    Distribute `number` items as evenly as possible across `k` folds.

    Args:
        number: number of patients
        k: folds number

    Returns:
        Integer np.ndarray of length k whose entries sum to `number`;
        the first (number % k) folds receive one extra item.
    '''
    # Integer division + integer dtype fixes the float array the original
    # `np.ones(k) * int(number/k)` produced, which forced every downstream
    # consumer to cast slice bounds with int(...).
    n_partitions = np.full(k, number // k, dtype=int)
    # Spread the remainder over the leading folds, one item each.
    n_partitions[0:(number % k)] += 1
    return n_partitions
def get_indices(n_splits, subjects, frames=1):
    '''
    Yield the test-set index array for each fold.

    Args:
        n_splits: folds number
        subjects: number of patients
        frames: length of the sequence of each patient
    '''
    # Each patient contributes `frames` consecutive indices, so the
    # per-fold patient counts scale directly into index counts.
    fold_sizes = partitions(subjects, n_splits) * frames
    all_indices = np.arange(subjects * frames).astype(int)
    stop = 0
    for size in fold_sizes:
        start, stop = stop, stop + size
        # int(...) casts guard against float fold sizes from partitions().
        yield all_indices[int(start):int(stop)]
def k_folds(n_splits, subjects, frames=1):
    '''
    Generate (train_idx, test_idx) index pairs for cross validation.

    Args:
        n_splits: folds number
        subjects: number of patients
        frames: length of the sequence of each patient
    '''
    all_indices = np.arange(subjects * frames).astype(int)
    for test_idx in get_indices(n_splits, subjects, frames):
        # Training set is everything not held out for this fold.
        yield np.setdiff1d(all_indices, test_idx), test_idx
#######################################################################
def _create_optimizer(opt, model):
lr = opt.lr
optimizer = optim.Adam(model.parameters(), lr=lr)
return optimizer
def _create_scheduler(optimizer, milestones):
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=0.1)
# scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', patience=patience, verbose=True)
return scheduler