-
Notifications
You must be signed in to change notification settings - Fork 0
/
LabelPredictor.java
111 lines (90 loc) · 3.26 KB
/
LabelPredictor.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import java.io.*;
import java.lang.*;
import java.util.*;
import weka.classifiers.*;
import weka.classifiers.lazy.*;
import weka.classifiers.functions.*;
import weka.core.*;
import java.util.*;
/**
 * Predicts labels for unlabeled rows using a multiple-classifier system ({@code MCSClassifier}).
 *
 * <p>Feature sets are supplied one at a time through {@link #addFeature}. In every feature
 * matrix the last column is the label; a row whose label equals {@code -1} is treated as
 * unlabeled ("test") and every other row as labeled ("train"). {@link #getLabels} trains the
 * MCS on the labeled rows and returns one label per row, in the original row order.
 */
public class LabelPredictor {
    /** Label value that marks a row as unlabeled. */
    private static final double UNLABELED = -1.0;

    private MultipleFeatureInstances train, test;
    /** Input-row positions of the unlabeled rows; identical for every feature set added. */
    private ArrayList<Integer> testIndex;
    private MCSClassifier mcs;

    /**
     * Builds the multiple-classifier system: a kNN base classifier, five diversity measures,
     * an average-accuracy metric and an SVM.
     *
     * @throws Exception if any of the underlying components fails to initialise
     */
    public LabelPredictor() throws Exception {
        train = null;
        test = null;
        testIndex = null;

        ArrayList<AbstractClassifier> classifiers = new ArrayList<AbstractClassifier>();
        classifiers.add(new IBk()); // kNN

        ArrayList<AbstractDiversityMeasure> dm = new ArrayList<AbstractDiversityMeasure>();
        dm.add(new QStatistic());
        dm.add(new DoubleFaultMeasure());
        dm.add(new DisagreementMeasure());
        dm.add(new CorrelationCoefficient());
        dm.add(new InterraterAgreement());

        AverageAccuracyMean metrics = new AverageAccuracyMean();
        // NOTE(review): meaning of 100 and 6 is defined by Concensus — confirm against its docs.
        Concensus concensus = new Concensus(dm, 100, 6, metrics);
        SMO svm = new SMO(); // SVM
        // Assign the FIELD (the original declared a shadowing local, leaving the field null).
        // NOTE(review): meaning of 75 is defined by MCSClassifier — confirm.
        mcs = new MCSClassifier(classifiers, concensus, 75, svm);
    }

    /**
     * Adds one feature set (one "view") of the data.
     *
     * <p>Each row of {@code array} is one sample; all columns but the last are numeric
     * features and the last column is the label ({@code -1} = unlabeled). Every call must
     * describe the same samples in the same order, so the unlabeled rows must be at the same
     * positions in every call.
     *
     * @param array feature matrix, one row per sample, label in the last column
     * @throws IllegalArgumentException if {@code array} is null/empty, rows have differing
     *     widths, or the unlabeled rows differ from those of the first feature set
     */
    public void addFeature(ArrayList<ArrayList<Double>> array) {
        if (array == null || array.isEmpty()) {
            throw new IllegalArgumentException("feature matrix must contain at least one row");
        }
        int width = array.get(0).size();
        if (width < 1) {
            throw new IllegalArgumentException("rows must contain at least the label column");
        }

        ArrayList<ArrayList<Double>> trainRows = new ArrayList<ArrayList<Double>>();
        ArrayList<ArrayList<Double>> testRows = new ArrayList<ArrayList<Double>>();
        ArrayList<Integer> rowTestIndex = new ArrayList<Integer>();
        for (int i = 0; i < array.size(); i++) {
            ArrayList<Double> row = array.get(i);
            if (row.size() != width) {
                throw new IllegalArgumentException(
                        "row " + i + " has " + row.size() + " columns, expected " + width);
            }
            if (row.get(width - 1) == UNLABELED) {
                testRows.add(row);
                rowTestIndex.add(i);
            } else {
                trainRows.add(row);
            }
        }
        // The train/test views are aligned by position, so every feature set must split the
        // samples identically; otherwise getLabels would silently mix up rows.
        if (testIndex != null && !testIndex.equals(rowTestIndex)) {
            throw new IllegalArgumentException(
                    "unlabeled rows differ from those of the first feature set");
        }

        // Schema comes from the first input row, so it works even with no labeled rows.
        ArrayList<Attribute> attributes = new ArrayList<Attribute>(width);
        for (int a = 0; a < width - 1; a++) {
            attributes.add(new Attribute("attribute" + a));
        }
        // NOTE(review): this declares a NUMERIC class attribute; Weka's SMO only supports
        // nominal classes — confirm MCSClassifier accounts for this.
        attributes.add(new Attribute("class"));

        Instances trainInstances = toInstances("train dataset", attributes, trainRows);
        Instances testInstances = toInstances("test dataset", attributes, testRows);

        if (train == null) {
            train = new MultipleFeatureInstances(new ArrayList<Instances>(Arrays.asList(trainInstances)));
            test = new MultipleFeatureInstances(new ArrayList<Instances>(Arrays.asList(testInstances)));
        } else {
            train.addFeature(trainInstances);
            test.addFeature(testInstances);
        }
        // Recorded only after all state updates succeeded.
        testIndex = rowTestIndex;
    }

    /**
     * Trains the MCS on the labeled rows and labels every row.
     *
     * @return one label per input row, in input order: labeled rows keep their own label,
     *     unlabeled rows receive the MCS prediction
     * @throws IllegalStateException if {@link #addFeature} has not been called
     * @throws Exception if training or classification fails
     */
    public ArrayList<Double> getLabels() throws Exception {
        if (train == null || testIndex == null) {
            throw new IllegalStateException("addFeature must be called before getLabels");
        }
        mcs.buildClassifier(train);

        int total = train.size() + test.size();
        ArrayList<Double> labels = new ArrayList<Double>(total);
        int trainPos = 0; // next labeled row inside `train`
        int testPos = 0;  // next unlabeled row inside `test`
        for (int row = 0; row < total; row++) {
            if (testPos < testIndex.size() && row == testIndex.get(testPos)) {
                labels.add(mcs.classifyInstance(test.instance(testPos)));
                testPos++;
            } else {
                labels.add(train.instance(trainPos).classValue());
                trainPos++;
            }
        }
        return labels;
    }

    /** Converts numeric rows into a Weka dataset whose last attribute is the class. */
    private static Instances toInstances(String relationName, ArrayList<Attribute> attributes,
                                         ArrayList<ArrayList<Double>> rows) {
        Instances data = new Instances(relationName, attributes, rows.size());
        // Weka classifiers and classValue() require the class attribute to be designated.
        data.setClassIndex(attributes.size() - 1);
        for (ArrayList<Double> row : rows) {
            double[] values = new double[attributes.size()];
            for (int k = 0; k < values.length; k++) {
                values[k] = row.get(k);
            }
            data.add(new DenseInstance(1.0, values));
        }
        return data;
    }
}