View Javadoc

1   package de.bwaldvogel.liblinear;
2   
3   import static de.bwaldvogel.liblinear.Linear.atof;
4   import static de.bwaldvogel.liblinear.Linear.atoi;
5   
6   import java.io.BufferedReader;
7   import java.io.File;
8   import java.io.FileReader;
9   import java.io.IOException;
10  import java.util.ArrayList;
11  import java.util.List;
12  import java.util.NoSuchElementException;
13  import java.util.StringTokenizer;
14  
15  
16  public class Train {
17  
18      public static void main(String[] args) throws IOException, InvalidInputDataException {
19          new Train().run(args);
20      }
21  
22      private double    bias             = 1;
23      private boolean   cross_validation = false;
24      private String    inputFilename;
25      private String    modelFilename;
26      private int       nr_fold;
27      private Parameter param            = null;
28      private Problem   prob             = null;
29  
30      private void do_cross_validation() {
31          int[] target = new int[prob.l];
32  
33          long start, stop;
34          start = System.currentTimeMillis();
35          Linear.crossValidation(prob, param, nr_fold, target);
36          stop = System.currentTimeMillis();
37          System.out.println("time: " + (stop - start) + " ms");
38  
39          int total_correct = 0;
40          for (int i = 0; i < prob.l; i++)
41              if (target[i] == prob.y[i]) ++total_correct;
42  
43          System.out.printf("correct: %d%n", total_correct);
44          System.out.printf("Cross Validation Accuracy = %g%%%n", 100.0 * total_correct / prob.l);
45      }
46  
47      private void exit_with_help() {
48          System.out.printf("Usage: train [options] training_set_file [model_file]%n" //
49              + "options:%n"
50              + "-s type : set type of solver (default 1)%n"
51              + "   0 -- L2-regularized logistic regression (primal)%n"
52              + "   1 -- L2-regularized L2-loss support vector classification (dual)%n"
53              + "   2 -- L2-regularized L2-loss support vector classification (primal)%n"
54              + "   3 -- L2-regularized L1-loss support vector classification (dual)%n"
55              + "   4 -- multi-class support vector classification by Crammer and Singer%n"
56              + "   5 -- L1-regularized L2-loss support vector classification%n"
57              + "   6 -- L1-regularized logistic regression%n"
58              + "   7 -- L2-regularized logistic regression (dual)%n"
59              + "-c cost : set the parameter C (default 1)%n"
60              + "-e epsilon : set tolerance of termination criterion%n"
61              + "   -s 0 and 2%n"
62              + "       |f'(w)|_2 <= eps*min(pos,neg)/l*|f'(w0)|_2,%n"
63              + "       where f is the primal function and pos/neg are # of%n"
64              + "       positive/negative data (default 0.01)%n"
65              + "   -s 1, 3, 4 and 7%n"
66              + "       Dual maximal violation <= eps; similar to libsvm (default 0.1)%n"
67              + "   -s 5 and 6%n"
68              + "       |f'(w)|_1 <= eps*min(pos,neg)/l*|f'(w0)|_1,%n"
69              + "       where f is the primal function (default 0.01)%n"
70              + "-B bias : if bias >= 0, instance x becomes [x; bias]; if < 0, no bias term added (default -1)%n"
71              + "-wi weight: weights adjust the parameter C of different classes (see README for details)%n"
72              + "-v n: n-fold cross validation mode%n"
73              + "-q : quiet mode (no outputs)%n");
74          System.exit(1);
75      }
76  
77  
78      Problem getProblem() {
79          return prob;
80      }
81  
82      double getBias() {
83          return bias;
84      }
85  
86      Parameter getParameter() {
87          return param;
88      }
89  
90      void parse_command_line(String argv[]) {
91          int i;
92  
93          // eps: see setting below
94          param = new Parameter(SolverType.L2R_L2LOSS_SVC_DUAL, 1, Double.POSITIVE_INFINITY);
95          // default values
96          bias = -1;
97          cross_validation = false;
98  
99          int nr_weight = 0;
100 
101         // parse options
102         for (i = 0; i < argv.length; i++) {
103             if (argv[i].charAt(0) != '-') break;
104             if (++i >= argv.length) exit_with_help();
105             switch (argv[i - 1].charAt(1)) {
106                 case 's':
107                     param.solverType = SolverType.values()[atoi(argv[i])];
108                     break;
109                 case 'c':
110                     param.setC(atof(argv[i]));
111                     break;
112                 case 'e':
113                     param.setEps(atof(argv[i]));
114                     break;
115                 case 'B':
116                     bias = atof(argv[i]);
117                     break;
118                 case 'w':
119                     ++nr_weight;
120                     {
121                         int[] old = param.weightLabel;
122                         param.weightLabel = new int[nr_weight];
123                         System.arraycopy(old, 0, param.weightLabel, 0, nr_weight - 1);
124                     }
125 
126                     {
127                         double[] old = param.weight;
128                         param.weight = new double[nr_weight];
129                         System.arraycopy(old, 0, param.weight, 0, nr_weight - 1);
130                     }
131 
132                     param.weightLabel[nr_weight - 1] = atoi(argv[i - 1].substring(2));
133                     param.weight[nr_weight - 1] = atof(argv[i]);
134                     break;
135                 case 'v':
136                     cross_validation = true;
137                     nr_fold = atoi(argv[i]);
138                     if (nr_fold < 2) {
139                         System.err.println("n-fold cross validation: n must >= 2");
140                         exit_with_help();
141                     }
142                     break;
143                 case 'q':
144                     Linear.disableDebugOutput();
145                     break;
146                 default:
147                     System.err.println("unknown option");
148                     exit_with_help();
149             }
150         }
151 
152         // determine filenames
153 
154         if (i >= argv.length) exit_with_help();
155 
156         inputFilename = argv[i];
157 
158         if (i < argv.length - 1)
159             modelFilename = argv[i + 1];
160         else {
161             int p = argv[i].lastIndexOf('/');
162             ++p; // whew...
163             modelFilename = argv[i].substring(p) + ".model";
164         }
165 
166         if (param.eps == Double.POSITIVE_INFINITY) {
167             if (param.solverType == SolverType.L2R_LR || param.solverType == SolverType.L2R_L2LOSS_SVC) {
168                 param.setEps(0.01);
169             } else if (param.solverType == SolverType.L2R_L2LOSS_SVC_DUAL || param.solverType == SolverType.L2R_L1LOSS_SVC_DUAL
170                 || param.solverType == SolverType.MCSVM_CS || param.solverType == SolverType.L2R_LR_DUAL) {
171                 param.setEps(0.1);
172             } else if (param.solverType == SolverType.L1R_L2LOSS_SVC || param.solverType == SolverType.L1R_LR) {
173                 param.setEps(0.01);
174             }
175         }
176     }
177 
178     /**
179      * reads a problem from LibSVM format
180      * @param file the SVM file
181      * @throws IOException obviously in case of any I/O exception ;)
182      * @throws InvalidInputDataException if the input file is not correctly formatted
183      */
184     public static Problem readProblem(File file, double bias) throws IOException, InvalidInputDataException {
185         BufferedReader fp = new BufferedReader(new FileReader(file));
186         List<Integer> vy = new ArrayList<Integer>();
187         List<FeatureNode[]> vx = new ArrayList<FeatureNode[]>();
188         int max_index = 0;
189 
190         int lineNr = 0;
191 
192         try {
193             while (true) {
194                 String line = fp.readLine();
195                 if (line == null) break;
196                 lineNr++;
197 
198                 StringTokenizer st = new StringTokenizer(line, " \t\n\r\f:");
199                 String token;
200                 try {
201                     token = st.nextToken();
202                 } catch (NoSuchElementException e) {
203                     throw new InvalidInputDataException("empty line", file, lineNr, e);
204                 }
205 
206                 try {
207                     vy.add(atoi(token));
208                 } catch (NumberFormatException e) {
209                     throw new InvalidInputDataException("invalid label: " + token, file, lineNr, e);
210                 }
211 
212                 int m = st.countTokens() / 2;
213                 FeatureNode[] x;
214                 if (bias >= 0) {
215                     x = new FeatureNode[m + 1];
216                 } else {
217                     x = new FeatureNode[m];
218                 }
219                 int indexBefore = 0;
220                 for (int j = 0; j < m; j++) {
221 
222                     token = st.nextToken();
223                     int index;
224                     try {
225                         index = atoi(token);
226                     } catch (NumberFormatException e) {
227                         throw new InvalidInputDataException("invalid index: " + token, file, lineNr, e);
228                     }
229 
230                     // assert that indices are valid and sorted
231                     if (index < 0) throw new InvalidInputDataException("invalid index: " + index, file, lineNr);
232                     if (index <= indexBefore) throw new InvalidInputDataException("indices must be sorted in ascending order", file, lineNr);
233                     indexBefore = index;
234 
235                     token = st.nextToken();
236                     try {
237                         double value = atof(token);
238                         x[j] = new FeatureNode(index, value);
239                     } catch (NumberFormatException e) {
240                         throw new InvalidInputDataException("invalid value: " + token, file, lineNr);
241                     }
242                 }
243                 if (m > 0) {
244                     max_index = Math.max(max_index, x[m - 1].index);
245                 }
246 
247                 vx.add(x);
248             }
249 
250             return constructProblem(vy, vx, max_index, bias);
251         }
252         finally {
253             fp.close();
254         }
255     }
256 
257     void readProblem(String filename) throws IOException, InvalidInputDataException {
258         prob = Train.readProblem(new File(filename), bias);
259     }
260 
261     private static Problem constructProblem(List<Integer> vy, List<FeatureNode[]> vx, int max_index, double bias) {
262         Problem prob = new Problem();
263         prob.bias = bias;
264         prob.l = vy.size();
265         prob.n = max_index;
266         if (bias >= 0) {
267             prob.n++;
268         }
269         prob.x = new FeatureNode[prob.l][];
270         for (int i = 0; i < prob.l; i++) {
271             prob.x[i] = vx.get(i);
272 
273             if (bias >= 0) {
274                 assert prob.x[i][prob.x[i].length - 1] == null;
275                 prob.x[i][prob.x[i].length - 1] = new FeatureNode(max_index + 1, bias);
276             }
277         }
278 
279         prob.y = new int[prob.l];
280         for (int i = 0; i < prob.l; i++)
281             prob.y[i] = vy.get(i);
282 
283         return prob;
284     }
285 
286     private void run(String[] args) throws IOException, InvalidInputDataException {
287         parse_command_line(args);
288         readProblem(inputFilename);
289         if (cross_validation)
290             do_cross_validation();
291         else {
292             Model model = Linear.train(prob, param);
293             Linear.saveModel(new File(modelFilename), model);
294         }
295     }
296 }