-
Notifications
You must be signed in to change notification settings - Fork 10
/
calculate_log.py
94 lines (90 loc) · 3.71 KB
/
calculate_log.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
from __future__ import print_function
import numpy as np
import numpy as np
def get_curve(dir_name, stypes=('Baseline', 'Gaussian_LDA')):
    """Build cumulative true/false-positive count curves from score files.

    For each score type ``stype`` two text files are read from ``dir_name``:
    ``confidence_<stype>_In.txt`` (scores of the "known" samples) and
    ``confidence_<stype>_Out.txt`` (scores of the "novel" samples), one score
    per line. Both score sets are sorted ascending and merged; after each
    merged score the number of known samples (tp) and novel samples (fp)
    still remaining above it is recorded.

    Args:
        dir_name: Directory containing the confidence files.
        stypes: Iterable of score-type names used to build the file names.
            It is only iterated, never modified.

    Returns:
        A tuple ``(tp, fp, tnr_at_tpr95)`` of dicts keyed by score type.
        ``tp[stype]`` and ``fp[stype]`` are int arrays of length
        ``num_known + num_novel + 1`` that start at the sample counts and
        end at 0. ``tnr_at_tpr95[stype]`` is ``1 - fp / num_novel`` taken at
        the first curve position whose ``tp / num_known`` is closest to 0.95.
    """
    tp, fp = dict(), dict()
    tnr_at_tpr95 = dict()
    for stype in stypes:
        # No explicit delimiter: NumPy >= 1.23 raises a TypeError for a
        # newline delimiter, and the default whitespace splitting already
        # reads one score per line. ndmin=1 keeps a single-score file from
        # collapsing into a 0-d array, which has no shape[0].
        known = np.loadtxt('{}/confidence_{}_In.txt'.format(dir_name, stype), ndmin=1)
        novel = np.loadtxt('{}/confidence_{}_Out.txt'.format(dir_name, stype), ndmin=1)
        known.sort()
        novel.sort()
        num_k = known.shape[0]
        num_n = novel.shape[0]
        total = num_k + num_n

        # -1 marks "not yet filled"; every slot is overwritten below.
        tp_curve = np.full(total + 1, -1, dtype=int)
        fp_curve = np.full(total + 1, -1, dtype=int)
        # Position 0: nothing has been consumed, so all samples are counted.
        tp_curve[0], fp_curve[0] = num_k, num_n

        k, n = 0, 0  # how many known / novel scores have been consumed
        for step in range(total):
            if k == num_k:
                # Known scores exhausted: tp stays put, fp counts down to 0.
                tp_curve[step + 1:] = tp_curve[step]
                fp_curve[step + 1:] = np.arange(fp_curve[step] - 1, -1, -1)
                break
            if n == num_n:
                # Novel scores exhausted: fp stays put, tp counts down to 0.
                tp_curve[step + 1:] = np.arange(tp_curve[step] - 1, -1, -1)
                fp_curve[step + 1:] = fp_curve[step]
                break
            if novel[n] < known[k]:
                # The smallest remaining score is a novel one.
                n += 1
                tp_curve[step + 1] = tp_curve[step]
                fp_curve[step + 1] = fp_curve[step] - 1
            else:
                # The smallest remaining score is a known one (ties land here).
                k += 1
                tp_curve[step + 1] = tp_curve[step] - 1
                fp_curve[step + 1] = fp_curve[step]

        tp[stype], fp[stype] = tp_curve, fp_curve
        # float() forces true division even when `/` on integers floors
        # (Python 2 without `from __future__ import division`).
        tpr95_pos = np.abs(tp_curve / float(num_k) - .95).argmin()
        tnr_at_tpr95[stype] = 1. - fp_curve[tpr95_pos] / float(num_n)
    return tp, fp, tnr_at_tpr95
def metric(dir_name, stypes=('Bas', 'Gau'), verbose=False):
    """Compute detection metrics for each score type from its count curves.

    The curves come from ``get_curve(dir_name, stypes)``. With ``tpr`` and
    ``fpr`` being the tp/fp counts divided by their first entry (padded with
    1 at the front and 0 at the back), the metrics per score type are:

    * ``TNR``   - the ``tnr_at_tpr95`` value returned by ``get_curve``.
    * ``AUROC`` - trapezoidal integral of ``1 - fpr`` over ``tpr``.
    * ``DTACC`` - maximum over the curve of ``0.5 * (tpr + 1 - fpr)``.
    * ``AUIN``  - trapezoidal integral of ``tp / (tp + fp)`` over ``tpr``.
    * ``AUOUT`` - trapezoidal integral of the fraction of consumed samples
      that are novel, over ``1 - fpr``.

    Args:
        dir_name: Directory handed to ``get_curve``.
        stypes: Iterable of score-type names; only iterated, never modified.
        verbose: When true, print a header row and one row per score type
            with every metric scaled by 100.

    Returns:
        Dict mapping each score type to a dict of metric name -> value
        (unscaled fractions).
    """
    tp, fp, tnr_at_tpr95 = get_curve(dir_name, stypes)
    results = dict()
    mtypes = ['TNR', 'AUROC', 'DTACC', 'AUIN', 'AUOUT']

    # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid;
    # use whichever the installed NumPy provides.
    trapezoid = getattr(np, 'trapezoid', None) or np.trapz

    if verbose:
        # Pad the header with the same width as the row label below so the
        # metric names line up with their value columns.
        print('{stype:5s} '.format(stype=''), end='')
        for mtype in mtypes:
            print(' {mtype:6s}'.format(mtype=mtype), end='')
        print('')

    for stype in stypes:
        # Work on float copies so every ratio below is a true division even
        # where `/` on integers floors (Python 2 without the division import).
        tp_counts = tp[stype].astype(float)
        fp_counts = fp[stype].astype(float)
        num_known, num_novel = tp_counts[0], fp_counts[0]

        tpr = np.concatenate([[1.], tp_counts / num_known, [0.]])
        fpr = np.concatenate([[1.], fp_counts / num_novel, [0.]])

        scores = dict()
        results[stype] = scores

        # TNR
        scores['TNR'] = tnr_at_tpr95[stype]

        # AUROC: tpr is non-increasing along the curve, hence the sign flip.
        scores['AUROC'] = -trapezoid(1. - fpr, tpr)

        # DTACC
        scores['DTACC'] = .5 * (tp_counts / num_known + 1. - fp_counts / num_novel).max()

        # AUIN: positions with nothing remaining (tp + fp == 0) get a
        # negative placeholder denominator and are masked out of the integral.
        denom = tp_counts + fp_counts
        denom[denom == 0.] = -1.
        pin_ind = np.concatenate([[True], denom > 0., [True]])
        pin = np.concatenate([[.5], tp_counts / denom, [0.]])
        scores['AUIN'] = -trapezoid(pin[pin_ind], tpr[pin_ind])

        # AUOUT: same masking trick for positions where nothing has been
        # consumed yet.
        denom = num_known - tp_counts + num_novel - fp_counts
        denom[denom == 0.] = -1.
        pout_ind = np.concatenate([[True], denom > 0., [True]])
        pout = np.concatenate([[0.], (num_novel - fp_counts) / denom, [.5]])
        scores['AUOUT'] = trapezoid(pout[pout_ind], 1. - fpr[pout_ind])

        if verbose:
            print('{stype:5s} '.format(stype=stype), end='')
            for mtype in mtypes:
                print(' {val:6.3f}'.format(val=100. * scores[mtype]), end='')
            print('')

    return results