View Javadoc

1   package liblinear;
2   
3   import static liblinear.Linear.NL;
4   import static liblinear.Linear.atof;
5   import static liblinear.Linear.atoi;
6   
7   import java.io.BufferedReader;
8   import java.io.File;
9   import java.io.FileReader;
10  import java.io.IOException;
11  import java.util.ArrayList;
12  import java.util.List;
13  import java.util.StringTokenizer;
14  
15  
16  public class Train {
17  
18      public static void main(String[] args) throws IOException, InvalidInputDataException {
19          new Train().run(args);
20      }
21  
22      private double    bias             = 1;
23      private boolean   cross_validation = false;
24      private String    inputFilename;
25      private String    modelFilename;
26      private int       nr_fold;
27      private Parameter param            = null;
28      private Problem   prob             = null;
29  
30      private void do_cross_validation() {
31          int[] target = new int[prob.l];
32  
33          long start, stop;
34          start = System.currentTimeMillis();
35          Linear.crossValidation(prob, param, nr_fold, target);
36          stop = System.currentTimeMillis();
37          System.out.println("time: " + (stop - start) + " ms");
38  
39          int total_correct = 0;
40          for (int i = 0; i < prob.l; i++)
41              if (target[i] == prob.y[i]) ++total_correct;
42  
43          System.out.printf("correct: %d" + NL, total_correct);
44          System.out.printf("Cross Validation Accuracy = %g%%\n", 100.0 * total_correct / prob.l);
45      }
46  
47      private void exit_with_help() {
48          System.out.println("Usage: train [options] training_set_file [model_file]" + NL //
49              + "options:" + NL//
50              + "-s type : set type of solver (default 1)" + NL//
51              + "   0 -- L2-regularized logistic regression" + NL//
52              + "   1 -- L2-regularized L2-loss support vector classification (dual)" + NL//
53              + "   2 -- L2-regularized L2-loss support vector classification (primal)" + NL//
54              + "   3 -- L2-regularized L1-loss support vector classification (dual)" + NL//
55              + "   4 -- multi-class support vector classification by Crammer and Singer" + NL//
56              + "   5 -- L1-regularized L2-loss support vector classification" + NL//
57              + "   6 -- L1-regularized logistic regression" + NL//
58              + "-c cost : set the parameter C (default 1)" + NL//
59              + "-e epsilon : set tolerance of termination criterion" + NL//
60              + "   -s 0 and 2" + NL//
61              + "       |f'(w)|_2 <= eps*min(pos,neg)/l*|f'(w0)|_2," + NL//
62              + "       where f is the primal function and pos/neg are # of" + NL//
63              + "       positive/negative data (default 0.01)" + NL//
64              + "   -s 1, 3, and 4" + NL//
65              + "       Dual maximal violation <= eps; similar to libsvm (default 0.1)" + NL//
66              + "   -s 5 and 6" + NL//
67              + "       |f'(w)|_inf <= eps*min(pos,neg)/l*|f'(w0)|_inf," + NL//
68              + "       where f is the primal function (default 0.01)" + NL//
69              + "-B bias : if bias >= 0, instance x becomes [x; bias]; if < 0, no bias term added (default -1)" + NL//
70              + "-wi weight: weights adjust the parameter C of different classes (see README for details)" + NL//
71              + "-v n: n-fold cross validation mode" + NL//
72              + "-q : quiet mode (no outputs)" + NL//
73          );
74          System.exit(1);
75      }
76  
77  
78      Problem getProblem() {
79          return prob;
80      }
81  
82      double getBias() {
83          return bias;
84      }
85  
86      Parameter getParameter() {
87          return param;
88      }
89  
90      void parse_command_line(String argv[]) {
91          int i;
92  
93          // eps: see setting below
94          param = new Parameter(SolverType.L2R_L2LOSS_SVC_DUAL, 1, Double.POSITIVE_INFINITY);
95          // default values
96          bias = -1;
97          cross_validation = false;
98  
99          int nr_weight = 0;
100 
101         // parse options
102         for (i = 0; i < argv.length; i++) {
103             if (argv[i].charAt(0) != '-') break;
104             if (++i >= argv.length) exit_with_help();
105             switch (argv[i - 1].charAt(1)) {
106                 case 's':
107                     param.solverType = SolverType.values()[atoi(argv[i])];
108                     break;
109                 case 'c':
110                     param.setC(atof(argv[i]));
111                     break;
112                 case 'e':
113                     param.setEps(atof(argv[i]));
114                     break;
115                 case 'B':
116                     bias = atof(argv[i]);
117                     break;
118                 case 'w':
119                     ++nr_weight;
120                     {
121                         int[] old = param.weightLabel;
122                         param.weightLabel = new int[nr_weight];
123                         System.arraycopy(old, 0, param.weightLabel, 0, nr_weight - 1);
124                     }
125 
126                     {
127                         double[] old = param.weight;
128                         param.weight = new double[nr_weight];
129                         System.arraycopy(old, 0, param.weight, 0, nr_weight - 1);
130                     }
131 
132                     param.weightLabel[nr_weight - 1] = atoi(argv[i - 1].substring(2));
133                     param.weight[nr_weight - 1] = atof(argv[i]);
134                     break;
135                 case 'v':
136                     cross_validation = true;
137                     nr_fold = atoi(argv[i]);
138                     if (nr_fold < 2) {
139                         System.err.print("n-fold cross validation: n must >= 2\n");
140                         exit_with_help();
141                     }
142                     break;
143                 case 'q':
144                     Linear.disableDebugOutput();
145                     break;
146                 default:
147                     System.err.println("unknown option");
148                     exit_with_help();
149             }
150         }
151 
152         // determine filenames
153 
154         if (i >= argv.length) exit_with_help();
155 
156         inputFilename = argv[i];
157 
158         if (i < argv.length - 1)
159             modelFilename = argv[i + 1];
160         else {
161             int p = argv[i].lastIndexOf('/');
162             ++p; // whew...
163             modelFilename = argv[i].substring(p) + ".model";
164         }
165 
166         if (param.eps == Double.POSITIVE_INFINITY) {
167             if (param.solverType == SolverType.L2R_LR || param.solverType == SolverType.L2R_L2LOSS_SVC) {
168                 param.setEps(0.01);
169             } else if (param.solverType == SolverType.L2R_L2LOSS_SVC_DUAL || param.solverType == SolverType.L2R_L1LOSS_SVC_DUAL
170                 || param.solverType == SolverType.MCSVM_CS) {
171                 param.setEps(0.1);
172             } else if (param.solverType == SolverType.L1R_L2LOSS_SVC || param.solverType == SolverType.L1R_LR) {
173                 param.setEps(0.01);
174             }
175         }
176     }
177 
178     /**
179      * reads a problem from LibSVM format
180      * @param filename the name of the svm file
181      * @throws IOException obviously in case of any I/O exception ;)
182      * @throws InvalidInputDataException if the input file is not correctly formatted
183      */
184     public static Problem readProblem(File file, double bias) throws IOException, InvalidInputDataException {
185         BufferedReader fp = new BufferedReader(new FileReader(file));
186         List<Integer> vy = new ArrayList<Integer>();
187         List<FeatureNode[]> vx = new ArrayList<FeatureNode[]>();
188         int max_index = 0;
189 
190         int lineNr = 0;
191 
192         try {
193             while (true) {
194                 String line = fp.readLine();
195                 if (line == null) break;
196                 lineNr++;
197 
198                 StringTokenizer st = new StringTokenizer(line, " \t\n\r\f:");
199                 String token = st.nextToken();
200 
201                 try {
202                     vy.add(atoi(token));
203                 } catch (NumberFormatException e) {
204                     throw new InvalidInputDataException("invalid label: " + token, file, lineNr, e);
205                 }
206 
207                 int m = st.countTokens() / 2;
208                 FeatureNode[] x;
209                 if (bias >= 0) {
210                     x = new FeatureNode[m + 1];
211                 } else {
212                     x = new FeatureNode[m];
213                 }
214                 int indexBefore = 0;
215                 for (int j = 0; j < m; j++) {
216 
217                     token = st.nextToken();
218                     int index;
219                     try {
220                         index = atoi(token);
221                     } catch (NumberFormatException e) {
222                         throw new InvalidInputDataException("invalid index: " + token, file, lineNr, e);
223                     }
224 
225                     // assert that indices are valid and sorted
226                     if (index < 0) throw new InvalidInputDataException("invalid index: " + index, file, lineNr);
227                     if (index <= indexBefore) throw new InvalidInputDataException("indices must be sorted in ascending order", file, lineNr);
228                     indexBefore = index;
229 
230                     token = st.nextToken();
231                     try {
232                         double value = atof(token);
233                         x[j] = new FeatureNode(index, value);
234                     } catch (NumberFormatException e) {
235                         throw new InvalidInputDataException("invalid value: " + token, file, lineNr);
236                     }
237                 }
238                 if (m > 0) {
239                     max_index = Math.max(max_index, x[m - 1].index);
240                 }
241 
242                 vx.add(x);
243             }
244 
245             return constructProblem(vy, vx, max_index, bias);
246         }
247         finally {
248             fp.close();
249         }
250     }
251 
252     void readProblem(String filename) throws IOException, InvalidInputDataException {
253         prob = Train.readProblem(new File(filename), bias);
254     }
255 
256     private static Problem constructProblem(List<Integer> vy, List<FeatureNode[]> vx, int max_index, double bias) {
257         Problem prob = new Problem();
258         prob.bias = bias;
259         prob.l = vy.size();
260         prob.n = max_index;
261         if (bias >= 0) {
262             prob.n++;
263         }
264         prob.x = new FeatureNode[prob.l][];
265         for (int i = 0; i < prob.l; i++) {
266             prob.x[i] = vx.get(i);
267 
268             if (bias >= 0) {
269                 assert prob.x[i][prob.x[i].length - 1] == null;
270                 prob.x[i][prob.x[i].length - 1] = new FeatureNode(max_index + 1, bias);
271             } else {
272                 assert prob.x[i][prob.x[i].length - 1] != null;
273             }
274         }
275 
276         prob.y = new int[prob.l];
277         for (int i = 0; i < prob.l; i++)
278             prob.y[i] = vy.get(i);
279 
280         return prob;
281     }
282 
283     private void run(String[] args) throws IOException, InvalidInputDataException {
284         parse_command_line(args);
285         readProblem(inputFilename);
286         if (cross_validation)
287             do_cross_validation();
288         else {
289             Model model = Linear.train(prob, param);
290             Linear.saveModel(new File(modelFilename), model);
291         }
292     }
293 }