1 package de.bwaldvogel.liblinear;
2
3 import static de.bwaldvogel.liblinear.Linear.atof;
4 import static de.bwaldvogel.liblinear.Linear.atoi;
5
6 import java.io.BufferedReader;
7 import java.io.File;
8 import java.io.FileReader;
9 import java.io.IOException;
10 import java.util.ArrayList;
11 import java.util.List;
12 import java.util.NoSuchElementException;
13 import java.util.StringTokenizer;
14
15
16 public class Train {
17
18 public static void main(String[] args) throws IOException, InvalidInputDataException {
19 new Train().run(args);
20 }
21
22 private double bias = 1;
23 private boolean cross_validation = false;
24 private String inputFilename;
25 private String modelFilename;
26 private int nr_fold;
27 private Parameter param = null;
28 private Problem prob = null;
29
30 private void do_cross_validation() {
31 int[] target = new int[prob.l];
32
33 long start, stop;
34 start = System.currentTimeMillis();
35 Linear.crossValidation(prob, param, nr_fold, target);
36 stop = System.currentTimeMillis();
37 System.out.println("time: " + (stop - start) + " ms");
38
39 int total_correct = 0;
40 for (int i = 0; i < prob.l; i++)
41 if (target[i] == prob.y[i]) ++total_correct;
42
43 System.out.printf("correct: %d%n", total_correct);
44 System.out.printf("Cross Validation Accuracy = %g%%%n", 100.0 * total_correct / prob.l);
45 }
46
47 private void exit_with_help() {
48 System.out.printf("Usage: train [options] training_set_file [model_file]%n"
49 + "options:%n"
50 + "-s type : set type of solver (default 1)%n"
51 + " 0 -- L2-regularized logistic regression (primal)%n"
52 + " 1 -- L2-regularized L2-loss support vector classification (dual)%n"
53 + " 2 -- L2-regularized L2-loss support vector classification (primal)%n"
54 + " 3 -- L2-regularized L1-loss support vector classification (dual)%n"
55 + " 4 -- multi-class support vector classification by Crammer and Singer%n"
56 + " 5 -- L1-regularized L2-loss support vector classification%n"
57 + " 6 -- L1-regularized logistic regression%n"
58 + " 7 -- L2-regularized logistic regression (dual)%n"
59 + "-c cost : set the parameter C (default 1)%n"
60 + "-e epsilon : set tolerance of termination criterion%n"
61 + " -s 0 and 2%n"
62 + " |f'(w)|_2 <= eps*min(pos,neg)/l*|f'(w0)|_2,%n"
63 + " where f is the primal function and pos/neg are # of%n"
64 + " positive/negative data (default 0.01)%n"
65 + " -s 1, 3, 4 and 7%n"
66 + " Dual maximal violation <= eps; similar to libsvm (default 0.1)%n"
67 + " -s 5 and 6%n"
68 + " |f'(w)|_1 <= eps*min(pos,neg)/l*|f'(w0)|_1,%n"
69 + " where f is the primal function (default 0.01)%n"
70 + "-B bias : if bias >= 0, instance x becomes [x; bias]; if < 0, no bias term added (default -1)%n"
71 + "-wi weight: weights adjust the parameter C of different classes (see README for details)%n"
72 + "-v n: n-fold cross validation mode%n"
73 + "-q : quiet mode (no outputs)%n");
74 System.exit(1);
75 }
76
77
78 Problem getProblem() {
79 return prob;
80 }
81
82 double getBias() {
83 return bias;
84 }
85
86 Parameter getParameter() {
87 return param;
88 }
89
90 void parse_command_line(String argv[]) {
91 int i;
92
93
94 param = new Parameter(SolverType.L2R_L2LOSS_SVC_DUAL, 1, Double.POSITIVE_INFINITY);
95
96 bias = -1;
97 cross_validation = false;
98
99 int nr_weight = 0;
100
101
102 for (i = 0; i < argv.length; i++) {
103 if (argv[i].charAt(0) != '-') break;
104 if (++i >= argv.length) exit_with_help();
105 switch (argv[i - 1].charAt(1)) {
106 case 's':
107 param.solverType = SolverType.values()[atoi(argv[i])];
108 break;
109 case 'c':
110 param.setC(atof(argv[i]));
111 break;
112 case 'e':
113 param.setEps(atof(argv[i]));
114 break;
115 case 'B':
116 bias = atof(argv[i]);
117 break;
118 case 'w':
119 ++nr_weight;
120 {
121 int[] old = param.weightLabel;
122 param.weightLabel = new int[nr_weight];
123 System.arraycopy(old, 0, param.weightLabel, 0, nr_weight - 1);
124 }
125
126 {
127 double[] old = param.weight;
128 param.weight = new double[nr_weight];
129 System.arraycopy(old, 0, param.weight, 0, nr_weight - 1);
130 }
131
132 param.weightLabel[nr_weight - 1] = atoi(argv[i - 1].substring(2));
133 param.weight[nr_weight - 1] = atof(argv[i]);
134 break;
135 case 'v':
136 cross_validation = true;
137 nr_fold = atoi(argv[i]);
138 if (nr_fold < 2) {
139 System.err.println("n-fold cross validation: n must >= 2");
140 exit_with_help();
141 }
142 break;
143 case 'q':
144 Linear.disableDebugOutput();
145 break;
146 default:
147 System.err.println("unknown option");
148 exit_with_help();
149 }
150 }
151
152
153
154 if (i >= argv.length) exit_with_help();
155
156 inputFilename = argv[i];
157
158 if (i < argv.length - 1)
159 modelFilename = argv[i + 1];
160 else {
161 int p = argv[i].lastIndexOf('/');
162 ++p;
163 modelFilename = argv[i].substring(p) + ".model";
164 }
165
166 if (param.eps == Double.POSITIVE_INFINITY) {
167 if (param.solverType == SolverType.L2R_LR || param.solverType == SolverType.L2R_L2LOSS_SVC) {
168 param.setEps(0.01);
169 } else if (param.solverType == SolverType.L2R_L2LOSS_SVC_DUAL || param.solverType == SolverType.L2R_L1LOSS_SVC_DUAL
170 || param.solverType == SolverType.MCSVM_CS || param.solverType == SolverType.L2R_LR_DUAL) {
171 param.setEps(0.1);
172 } else if (param.solverType == SolverType.L1R_L2LOSS_SVC || param.solverType == SolverType.L1R_LR) {
173 param.setEps(0.01);
174 }
175 }
176 }
177
178
179
180
181
182
183
184 public static Problem readProblem(File file, double bias) throws IOException, InvalidInputDataException {
185 BufferedReader fp = new BufferedReader(new FileReader(file));
186 List<Integer> vy = new ArrayList<Integer>();
187 List<FeatureNode[]> vx = new ArrayList<FeatureNode[]>();
188 int max_index = 0;
189
190 int lineNr = 0;
191
192 try {
193 while (true) {
194 String line = fp.readLine();
195 if (line == null) break;
196 lineNr++;
197
198 StringTokenizer st = new StringTokenizer(line, " \t\n\r\f:");
199 String token;
200 try {
201 token = st.nextToken();
202 } catch (NoSuchElementException e) {
203 throw new InvalidInputDataException("empty line", file, lineNr, e);
204 }
205
206 try {
207 vy.add(atoi(token));
208 } catch (NumberFormatException e) {
209 throw new InvalidInputDataException("invalid label: " + token, file, lineNr, e);
210 }
211
212 int m = st.countTokens() / 2;
213 FeatureNode[] x;
214 if (bias >= 0) {
215 x = new FeatureNode[m + 1];
216 } else {
217 x = new FeatureNode[m];
218 }
219 int indexBefore = 0;
220 for (int j = 0; j < m; j++) {
221
222 token = st.nextToken();
223 int index;
224 try {
225 index = atoi(token);
226 } catch (NumberFormatException e) {
227 throw new InvalidInputDataException("invalid index: " + token, file, lineNr, e);
228 }
229
230
231 if (index < 0) throw new InvalidInputDataException("invalid index: " + index, file, lineNr);
232 if (index <= indexBefore) throw new InvalidInputDataException("indices must be sorted in ascending order", file, lineNr);
233 indexBefore = index;
234
235 token = st.nextToken();
236 try {
237 double value = atof(token);
238 x[j] = new FeatureNode(index, value);
239 } catch (NumberFormatException e) {
240 throw new InvalidInputDataException("invalid value: " + token, file, lineNr);
241 }
242 }
243 if (m > 0) {
244 max_index = Math.max(max_index, x[m - 1].index);
245 }
246
247 vx.add(x);
248 }
249
250 return constructProblem(vy, vx, max_index, bias);
251 }
252 finally {
253 fp.close();
254 }
255 }
256
257 void readProblem(String filename) throws IOException, InvalidInputDataException {
258 prob = Train.readProblem(new File(filename), bias);
259 }
260
261 private static Problem constructProblem(List<Integer> vy, List<FeatureNode[]> vx, int max_index, double bias) {
262 Problem prob = new Problem();
263 prob.bias = bias;
264 prob.l = vy.size();
265 prob.n = max_index;
266 if (bias >= 0) {
267 prob.n++;
268 }
269 prob.x = new FeatureNode[prob.l][];
270 for (int i = 0; i < prob.l; i++) {
271 prob.x[i] = vx.get(i);
272
273 if (bias >= 0) {
274 assert prob.x[i][prob.x[i].length - 1] == null;
275 prob.x[i][prob.x[i].length - 1] = new FeatureNode(max_index + 1, bias);
276 }
277 }
278
279 prob.y = new int[prob.l];
280 for (int i = 0; i < prob.l; i++)
281 prob.y[i] = vy.get(i);
282
283 return prob;
284 }
285
286 private void run(String[] args) throws IOException, InvalidInputDataException {
287 parse_command_line(args);
288 readProblem(inputFilename);
289 if (cross_validation)
290 do_cross_validation();
291 else {
292 Model model = Linear.train(prob, param);
293 Linear.saveModel(new File(modelFilename), model);
294 }
295 }
296 }