1 package liblinear;
2
3 import static liblinear.Linear.NL;
4 import static liblinear.Linear.atof;
5 import static liblinear.Linear.atoi;
6 import static liblinear.Linear.closeQuietly;
7 import static liblinear.Linear.printf;
8
9 import java.io.BufferedReader;
10 import java.io.BufferedWriter;
11 import java.io.File;
12 import java.io.FileInputStream;
13 import java.io.FileOutputStream;
14 import java.io.IOException;
15 import java.io.InputStreamReader;
16 import java.io.OutputStreamWriter;
17 import java.io.Writer;
18 import java.util.ArrayList;
19 import java.util.Formatter;
20 import java.util.List;
21 import java.util.StringTokenizer;
22 import java.util.regex.Pattern;
23
24
25 public class Predict {
26
27 private static boolean flag_predict_probability = false;
28
29 private static final Pattern COLON = Pattern.compile(":");
30
31
32
33
34 static void doPredict(BufferedReader reader, Writer writer, Model model) throws IOException {
35 int correct = 0;
36 int total = 0;
37
38 int nr_class = model.getNrClass();
39 double[] prob_estimates = null;
40 int n;
41 int nr_feature = model.getNrFeature();
42 if (model.bias >= 0)
43 n = nr_feature + 1;
44 else
45 n = nr_feature;
46
47 Formatter out = new Formatter(writer);
48
49 if (flag_predict_probability) {
50 if (model.solverType != SolverType.L2R_LR) {
51 throw new IllegalArgumentException("probability output is only supported for logistic regression");
52 }
53
54 int[] labels = model.getLabels();
55 prob_estimates = new double[nr_class];
56
57 printf(out, "labels");
58 for (int j = 0; j < nr_class; j++)
59 printf(out, " %d", labels[j]);
60 printf(out, "\n");
61 }
62
63
64 String line = null;
65 while ((line = reader.readLine()) != null) {
66 List<FeatureNode> x = new ArrayList<FeatureNode>();
67 StringTokenizer st = new StringTokenizer(line, " \t");
68 String label = st.nextToken();
69 int target_label = atoi(label);
70
71 while (st.hasMoreTokens()) {
72 String[] split = COLON.split(st.nextToken(), 2);
73 if (split == null || split.length < 2) exit_input_error(total + 1);
74
75 try {
76 int idx = atoi(split[0]);
77 double val = atof(split[1]);
78
79
80 if (idx <= nr_feature) {
81 FeatureNode node = new FeatureNode(idx, val);
82 x.add(node);
83 }
84 } catch (NumberFormatException e) {
85 exit_input_error(total + 1, e);
86 }
87 }
88
89 if (model.bias >= 0) {
90 FeatureNode node = new FeatureNode(n, model.bias);
91 x.add(node);
92 }
93
94 FeatureNode[] nodes = new FeatureNode[x.size()];
95 nodes = x.toArray(nodes);
96
97 int predict_label;
98
99 if (flag_predict_probability) {
100 predict_label = Linear.predictProbability(model, nodes, prob_estimates);
101 printf(out, "%d", predict_label);
102 for (int j = 0; j < model.nr_class; j++)
103 printf(out, " %g", prob_estimates[j]);
104 printf(out, "\n");
105 } else {
106 predict_label = Linear.predict(model, nodes);
107 printf(out, "%d\n", predict_label);
108 }
109
110 if (predict_label == target_label) {
111 ++correct;
112 }
113 ++total;
114 }
115 System.out.printf("Accuracy = %g%% (%d/%d)" + NL, (double)correct / total * 100, correct, total);
116 }
117
118 private static void exit_input_error(int line_num, Throwable cause) {
119 throw new RuntimeException("Wrong input format at line " + line_num, cause);
120 }
121
122 private static void exit_input_error(int line_num) {
123 throw new RuntimeException("Wrong input format at line " + line_num);
124 }
125
126 private static void exit_with_help() {
127 System.out.println("Usage: predict [options] test_file model_file output_file" + NL
128 + "options:" + NL
129 + "-b probability_estimates: whether to output probability estimates, 0 or 1 (default 0)" + NL
130 );
131 System.exit(1);
132 }
133
134 public static void main(String[] argv) throws IOException {
135 int i;
136
137
138 for (i = 0; i < argv.length; i++) {
139 if (argv[i].charAt(0) != '-') break;
140 ++i;
141 switch (argv[i - 1].charAt(1)) {
142 case 'b':
143 try {
144 flag_predict_probability = (atoi(argv[i]) != 0);
145 } catch (NumberFormatException e) {
146 exit_with_help();
147 }
148 break;
149
150 default:
151 System.err.println("unknown option: -" + argv[i - 1].charAt(1) + NL);
152 exit_with_help();
153 break;
154 }
155 }
156 if (i >= argv.length || argv.length <= i + 2) {
157 exit_with_help();
158 }
159
160 BufferedReader reader = null;
161 Writer writer = null;
162 try {
163 reader = new BufferedReader(new InputStreamReader(new FileInputStream(argv[i]), Linear.FILE_CHARSET));
164 writer = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(argv[i + 2]), Linear.FILE_CHARSET));
165
166 Model model = Linear.loadModel(new File(argv[i + 1]));
167 doPredict(reader, writer, model);
168 }
169 finally {
170 closeQuietly(reader);
171 closeQuietly(writer);
172 }
173 }
174 }