1 package liblinear;
2
3 import static liblinear.Linear.NL;
4 import static liblinear.Linear.atof;
5 import static liblinear.Linear.atoi;
6
7 import java.io.BufferedReader;
8 import java.io.File;
9 import java.io.FileReader;
10 import java.io.IOException;
11 import java.util.ArrayList;
12 import java.util.List;
13 import java.util.StringTokenizer;
14
15
16 public class Train {
17
18 public static void main(String[] args) throws IOException, InvalidInputDataException {
19 new Train().run(args);
20 }
21
22 private double bias = 1;
23 private boolean cross_validation = false;
24 private String inputFilename;
25 private String modelFilename;
26 private int nr_fold;
27 private Parameter param = null;
28 private Problem prob = null;
29
30 private void do_cross_validation() {
31 int[] target = new int[prob.l];
32
33 long start, stop;
34 start = System.currentTimeMillis();
35 Linear.crossValidation(prob, param, nr_fold, target);
36 stop = System.currentTimeMillis();
37 System.out.println("time: " + (stop - start) + " ms");
38
39 int total_correct = 0;
40 for (int i = 0; i < prob.l; i++)
41 if (target[i] == prob.y[i]) ++total_correct;
42
43 System.out.printf("correct: %d" + NL, total_correct);
44 System.out.printf("Cross Validation Accuracy = %g%%\n", 100.0 * total_correct / prob.l);
45 }
46
47 private void exit_with_help() {
48 System.out.println("Usage: train [options] training_set_file [model_file]" + NL
49 + "options:" + NL
50 + "-s type : set type of solver (default 1)" + NL
51 + " 0 -- L2-regularized logistic regression" + NL
52 + " 1 -- L2-regularized L2-loss support vector classification (dual)" + NL
53 + " 2 -- L2-regularized L2-loss support vector classification (primal)" + NL
54 + " 3 -- L2-regularized L1-loss support vector classification (dual)" + NL
55 + " 4 -- multi-class support vector classification by Crammer and Singer" + NL
56 + " 5 -- L1-regularized L2-loss support vector classification" + NL
57 + " 6 -- L1-regularized logistic regression" + NL
58 + "-c cost : set the parameter C (default 1)" + NL
59 + "-e epsilon : set tolerance of termination criterion" + NL
60 + " -s 0 and 2" + NL
61 + " |f'(w)|_2 <= eps*min(pos,neg)/l*|f'(w0)|_2," + NL
62 + " where f is the primal function and pos/neg are # of" + NL
63 + " positive/negative data (default 0.01)" + NL
64 + " -s 1, 3, and 4" + NL
65 + " Dual maximal violation <= eps; similar to libsvm (default 0.1)" + NL
66 + " -s 5 and 6" + NL
67 + " |f'(w)|_inf <= eps*min(pos,neg)/l*|f'(w0)|_inf," + NL
68 + " where f is the primal function (default 0.01)" + NL
69 + "-B bias : if bias >= 0, instance x becomes [x; bias]; if < 0, no bias term added (default -1)" + NL
70 + "-wi weight: weights adjust the parameter C of different classes (see README for details)" + NL
71 + "-v n: n-fold cross validation mode" + NL
72 + "-q : quiet mode (no outputs)" + NL
73 );
74 System.exit(1);
75 }
76
77
78 Problem getProblem() {
79 return prob;
80 }
81
82 double getBias() {
83 return bias;
84 }
85
86 Parameter getParameter() {
87 return param;
88 }
89
90 void parse_command_line(String argv[]) {
91 int i;
92
93
94 param = new Parameter(SolverType.L2R_L2LOSS_SVC_DUAL, 1, Double.POSITIVE_INFINITY);
95
96 bias = -1;
97 cross_validation = false;
98
99 int nr_weight = 0;
100
101
102 for (i = 0; i < argv.length; i++) {
103 if (argv[i].charAt(0) != '-') break;
104 if (++i >= argv.length) exit_with_help();
105 switch (argv[i - 1].charAt(1)) {
106 case 's':
107 param.solverType = SolverType.values()[atoi(argv[i])];
108 break;
109 case 'c':
110 param.setC(atof(argv[i]));
111 break;
112 case 'e':
113 param.setEps(atof(argv[i]));
114 break;
115 case 'B':
116 bias = atof(argv[i]);
117 break;
118 case 'w':
119 ++nr_weight;
120 {
121 int[] old = param.weightLabel;
122 param.weightLabel = new int[nr_weight];
123 System.arraycopy(old, 0, param.weightLabel, 0, nr_weight - 1);
124 }
125
126 {
127 double[] old = param.weight;
128 param.weight = new double[nr_weight];
129 System.arraycopy(old, 0, param.weight, 0, nr_weight - 1);
130 }
131
132 param.weightLabel[nr_weight - 1] = atoi(argv[i - 1].substring(2));
133 param.weight[nr_weight - 1] = atof(argv[i]);
134 break;
135 case 'v':
136 cross_validation = true;
137 nr_fold = atoi(argv[i]);
138 if (nr_fold < 2) {
139 System.err.print("n-fold cross validation: n must >= 2\n");
140 exit_with_help();
141 }
142 break;
143 case 'q':
144 Linear.disableDebugOutput();
145 break;
146 default:
147 System.err.println("unknown option");
148 exit_with_help();
149 }
150 }
151
152
153
154 if (i >= argv.length) exit_with_help();
155
156 inputFilename = argv[i];
157
158 if (i < argv.length - 1)
159 modelFilename = argv[i + 1];
160 else {
161 int p = argv[i].lastIndexOf('/');
162 ++p;
163 modelFilename = argv[i].substring(p) + ".model";
164 }
165
166 if (param.eps == Double.POSITIVE_INFINITY) {
167 if (param.solverType == SolverType.L2R_LR || param.solverType == SolverType.L2R_L2LOSS_SVC) {
168 param.setEps(0.01);
169 } else if (param.solverType == SolverType.L2R_L2LOSS_SVC_DUAL || param.solverType == SolverType.L2R_L1LOSS_SVC_DUAL
170 || param.solverType == SolverType.MCSVM_CS) {
171 param.setEps(0.1);
172 } else if (param.solverType == SolverType.L1R_L2LOSS_SVC || param.solverType == SolverType.L1R_LR) {
173 param.setEps(0.01);
174 }
175 }
176 }
177
178
179
180
181
182
183
184 public static Problem readProblem(File file, double bias) throws IOException, InvalidInputDataException {
185 BufferedReader fp = new BufferedReader(new FileReader(file));
186 List<Integer> vy = new ArrayList<Integer>();
187 List<FeatureNode[]> vx = new ArrayList<FeatureNode[]>();
188 int max_index = 0;
189
190 int lineNr = 0;
191
192 try {
193 while (true) {
194 String line = fp.readLine();
195 if (line == null) break;
196 lineNr++;
197
198 StringTokenizer st = new StringTokenizer(line, " \t\n\r\f:");
199 String token = st.nextToken();
200
201 try {
202 vy.add(atoi(token));
203 } catch (NumberFormatException e) {
204 throw new InvalidInputDataException("invalid label: " + token, file, lineNr, e);
205 }
206
207 int m = st.countTokens() / 2;
208 FeatureNode[] x;
209 if (bias >= 0) {
210 x = new FeatureNode[m + 1];
211 } else {
212 x = new FeatureNode[m];
213 }
214 int indexBefore = 0;
215 for (int j = 0; j < m; j++) {
216
217 token = st.nextToken();
218 int index;
219 try {
220 index = atoi(token);
221 } catch (NumberFormatException e) {
222 throw new InvalidInputDataException("invalid index: " + token, file, lineNr, e);
223 }
224
225
226 if (index < 0) throw new InvalidInputDataException("invalid index: " + index, file, lineNr);
227 if (index <= indexBefore) throw new InvalidInputDataException("indices must be sorted in ascending order", file, lineNr);
228 indexBefore = index;
229
230 token = st.nextToken();
231 try {
232 double value = atof(token);
233 x[j] = new FeatureNode(index, value);
234 } catch (NumberFormatException e) {
235 throw new InvalidInputDataException("invalid value: " + token, file, lineNr);
236 }
237 }
238 if (m > 0) {
239 max_index = Math.max(max_index, x[m - 1].index);
240 }
241
242 vx.add(x);
243 }
244
245 return constructProblem(vy, vx, max_index, bias);
246 }
247 finally {
248 fp.close();
249 }
250 }
251
252 void readProblem(String filename) throws IOException, InvalidInputDataException {
253 prob = Train.readProblem(new File(filename), bias);
254 }
255
256 private static Problem constructProblem(List<Integer> vy, List<FeatureNode[]> vx, int max_index, double bias) {
257 Problem prob = new Problem();
258 prob.bias = bias;
259 prob.l = vy.size();
260 prob.n = max_index;
261 if (bias >= 0) {
262 prob.n++;
263 }
264 prob.x = new FeatureNode[prob.l][];
265 for (int i = 0; i < prob.l; i++) {
266 prob.x[i] = vx.get(i);
267
268 if (bias >= 0) {
269 assert prob.x[i][prob.x[i].length - 1] == null;
270 prob.x[i][prob.x[i].length - 1] = new FeatureNode(max_index + 1, bias);
271 } else {
272 assert prob.x[i][prob.x[i].length - 1] != null;
273 }
274 }
275
276 prob.y = new int[prob.l];
277 for (int i = 0; i < prob.l; i++)
278 prob.y[i] = vy.get(i);
279
280 return prob;
281 }
282
283 private void run(String[] args) throws IOException, InvalidInputDataException {
284 parse_command_line(args);
285 readProblem(inputFilename);
286 if (cross_validation)
287 do_cross_validation();
288 else {
289 Model model = Linear.train(prob, param);
290 Linear.saveModel(new File(modelFilename), model);
291 }
292 }
293 }