View Javadoc

1   package liblinear;
2   
3   import java.io.BufferedReader;
4   import java.io.BufferedWriter;
5   import java.io.Closeable;
6   import java.io.EOFException;
7   import java.io.File;
8   import java.io.FileInputStream;
9   import java.io.FileOutputStream;
10  import java.io.IOException;
11  import java.io.InputStreamReader;
12  import java.io.OutputStreamWriter;
13  import java.io.PrintStream;
14  import java.io.Reader;
15  import java.io.Writer;
16  import java.nio.charset.Charset;
17  import java.util.Formatter;
18  import java.util.Locale;
19  import java.util.Random;
20  import java.util.regex.Pattern;
21  
22  
23  /**
24   * <h2>Java port of <a href="http://www.csie.ntu.edu.tw/~cjlin/liblinear/">liblinear</a> 1.51</h2>
25   *
26   * <p>The usage should be pretty similar to the C version of <tt>liblinear</tt>.</p>
27   * <p>Please consider reading the <tt>README</tt> file of <tt>liblinear</tt>.</p>
28   *
29   * <p><em>The port was done by Benedikt Waldvogel (mail at bwaldvogel.de)</em></p>
30   *
31   * @version 1.51
32   */
33  public class Linear {
34  
        // Charset used when model files are read from or written to disk
        // (see loadModel(File) and saveModel(File, Model)).
35      static final Charset       FILE_CHARSET        = Charset.forName("ISO-8859-1");
36  
        // Locale handed to the Formatter in saveModel so numbers are always written the same way.
37      static final Locale        DEFAULT_LOCALE      = Locale.ENGLISH;
38  
        // Lock held by the info(...) methods while they read and write DEBUG_OUTPUT.
39      private static Object      OUTPUT_MUTEX        = new Object();
        // Destination of the info(...) messages; info(...) prints nothing while this is null.
40      private static PrintStream DEBUG_OUTPUT        = System.out;
41  
42      /** line separator of the current platform (value of the {@code line.separator} system property) */
43      final static String        NL                  = System.getProperty("line.separator");
44  
        // The shared generator is created with a fixed seed; it drives the shuffling in
        // crossValidation and in the solvers.
45      private static final long  DEFAULT_RANDOM_SEED = 0L;
46      static Random              random              = new Random(DEFAULT_RANDOM_SEED);
47  
48      /**
49       * @param target predicted classes
50       */
51      public static void crossValidation(Problem prob, Parameter param, int nr_fold, int[] target) {
52          int i;
53          int[] fold_start = new int[nr_fold + 1];
54          int l = prob.l;
55          int[] perm = new int[l];
56  
57          for (i = 0; i < l; i++)
58              perm[i] = i;
59          for (i = 0; i < l; i++) {
60              int j = i + random.nextInt(l - i);
61              swap(perm, i, j);
62          }
63          for (i = 0; i <= nr_fold; i++)
64              fold_start[i] = i * l / nr_fold;
65  
66          for (i = 0; i < nr_fold; i++) {
67              int begin = fold_start[i];
68              int end = fold_start[i + 1];
69              int j, k;
70              Problem subprob = new Problem();
71  
72              subprob.bias = prob.bias;
73              subprob.n = prob.n;
74              subprob.l = l - (end - begin);
75              subprob.x = new FeatureNode[subprob.l][];
76              subprob.y = new int[subprob.l];
77  
78              k = 0;
79              for (j = 0; j < begin; j++) {
80                  subprob.x[k] = prob.x[perm[j]];
81                  subprob.y[k] = prob.y[perm[j]];
82                  ++k;
83              }
84              for (j = end; j < l; j++) {
85                  subprob.x[k] = prob.x[perm[j]];
86                  subprob.y[k] = prob.y[perm[j]];
87                  ++k;
88              }
89              Model submodel = train(subprob, param);
90              for (j = begin; j < end; j++)
91                  target[perm[j]] = predict(submodel, prob.x[perm[j]]);
92          }
93      }
94  
95      /** used as complex return type */
96      private static class GroupClassesReturn {
97  
98          final int[] count;
99          final int[] label;
100         final int   nr_class;
101         final int[] start;
102 
103         GroupClassesReturn( int nr_class, int[] label, int[] start, int[] count ) {
104             this.nr_class = nr_class;
105             this.label = label;
106             this.start = start;
107             this.count = count;
108         }
109     }
110 
111     private static GroupClassesReturn groupClasses(Problem prob, int[] perm) {
112         int l = prob.l;
113         int max_nr_class = 16;
114         int nr_class = 0;
115 
116         int[] label = new int[max_nr_class];
117         int[] count = new int[max_nr_class];
118         int[] data_label = new int[l];
119         int i;
120 
121         for (i = 0; i < l; i++) {
122             int this_label = prob.y[i];
123             int j;
124             for (j = 0; j < nr_class; j++) {
125                 if (this_label == label[j]) {
126                     ++count[j];
127                     break;
128                 }
129             }
130             data_label[i] = j;
131             if (j == nr_class) {
132                 if (nr_class == max_nr_class) {
133                     max_nr_class *= 2;
134                     label = copyOf(label, max_nr_class);
135                     count = copyOf(count, max_nr_class);
136                 }
137                 label[nr_class] = this_label;
138                 count[nr_class] = 1;
139                 ++nr_class;
140             }
141         }
142 
143         int[] start = new int[nr_class];
144         start[0] = 0;
145         for (i = 1; i < nr_class; i++)
146             start[i] = start[i - 1] + count[i - 1];
147         for (i = 0; i < l; i++) {
148             perm[start[data_label[i]]] = i;
149             ++start[data_label[i]];
150         }
151         start[0] = 0;
152         for (i = 1; i < nr_class; i++)
153             start[i] = start[i - 1] + count[i - 1];
154 
155         return new GroupClassesReturn(nr_class, label, start, count);
156     }
157 
158     static void info(String message) {
159         synchronized (OUTPUT_MUTEX) {
160             if (DEBUG_OUTPUT == null) return;
161             DEBUG_OUTPUT.print(message);
162             DEBUG_OUTPUT.flush();
163         }
164     }
165 
166     static void info(String format, Object... args) {
167         synchronized (OUTPUT_MUTEX) {
168             if (DEBUG_OUTPUT == null) return;
169             DEBUG_OUTPUT.printf(format, args);
170             DEBUG_OUTPUT.flush();
171         }
172     }
173 
174     /**
175      * @param s the string to parse for the double value
176      * @throws IllegalArgumentException if s is empty or represents NaN or Infinity
177      * @throws NumberFormatException see {@link Double#parseDouble(String)}
178      */
179     static double atof(String s) {
180         if (s == null || s.length() < 1) throw new IllegalArgumentException("Can't convert empty string to integer");
181         double d = Double.parseDouble(s);
182         if (Double.isNaN(d) || Double.isInfinite(d)) {
183             throw new IllegalArgumentException("NaN or Infinity in input: " + s);
184         }
185         return (d);
186     }
187 
188     /**
189      * @param s the string to parse for the integer value
190      * @throws IllegalArgumentException if s is empty
191      * @throws NumberFormatException see {@link Integer#parseInt(String)}
192      */
193     static int atoi(String s) throws NumberFormatException {
194         if (s == null || s.length() < 1) throw new IllegalArgumentException("Can't convert empty string to integer");
195         // Integer.parseInt doesn't accept '+' prefixed strings
196         if (s.charAt(0) == '+') s = s.substring(1);
197         return Integer.parseInt(s);
198     }
199 
200     /**
201      * Java5 'backport' of Arrays.copyOf
202      */
203     public static double[] copyOf(double[] original, int newLength) {
204         double[] copy = new double[newLength];
205         System.arraycopy(original, 0, copy, 0, Math.min(original.length, newLength));
206         return copy;
207     }
208 
209     /**
210      * Java5 'backport' of Arrays.copyOf
211      */
212     public static int[] copyOf(int[] original, int newLength) {
213         int[] copy = new int[newLength];
214         System.arraycopy(original, 0, copy, 0, Math.min(original.length, newLength));
215         return copy;
216     }
217 
218     /**
219      * Loads the model from inputReader.
220      * It uses {@link Locale.ENGLISH} for number formatting.
221      *
222      * <p><b>Note: The inputReader is closed after reading or in case of an exception.</b></p>
223      */
224     public static Model loadModel(Reader inputReader) throws IOException {
225         Model model = new Model();
226 
227         model.label = null;
228 
229         Pattern whitespace = Pattern.compile("\\s+");
230 
231         BufferedReader reader = null;
232         if (inputReader instanceof BufferedReader) {
233             reader = (BufferedReader)inputReader;
234         } else {
235             reader = new BufferedReader(inputReader);
236         }
237 
238         try {
239             String line = null;
240             while ((line = reader.readLine()) != null) {
241                 String[] split = whitespace.split(line);
242                 if (split[0].equals("solver_type")) {
243                     SolverType solver = SolverType.valueOf(split[1]);
244                     if (solver == null) {
245                         throw new RuntimeException("unknown solver type");
246                     }
247                     model.solverType = solver;
248                 } else if (split[0].equals("nr_class")) {
249                     model.nr_class = atoi(split[1]);
250                     Integer.parseInt(split[1]);
251                 } else if (split[0].equals("nr_feature")) {
252                     model.nr_feature = atoi(split[1]);
253                 } else if (split[0].equals("bias")) {
254                     model.bias = atof(split[1]);
255                 } else if (split[0].equals("w")) {
256                     break;
257                 } else if (split[0].equals("label")) {
258                     model.label = new int[model.nr_class];
259                     for (int i = 0; i < model.nr_class; i++) {
260                         model.label[i] = atoi(split[i + 1]);
261                     }
262                 } else {
263                     throw new RuntimeException("unknown text in model file: [" + line + "]");
264                 }
265             }
266 
267             int w_size = model.nr_feature;
268             if (model.bias >= 0) w_size++;
269 
270             int nr_w = model.nr_class;
271             if (model.nr_class == 2 && model.solverType != SolverType.MCSVM_CS) nr_w = 1;
272 
273             model.w = new double[w_size * nr_w];
274             int[] buffer = new int[128];
275 
276             for (int i = 0; i < w_size; i++) {
277                 for (int j = 0; j < nr_w; j++) {
278                     int b = 0;
279                     while (true) {
280                         int ch = reader.read();
281                         if (ch == -1) {
282                             throw new EOFException("unexpected EOF");
283                         }
284                         if (ch == ' ') {
285                             model.w[i * nr_w + j] = atof(new String(buffer, 0, b));
286                             break;
287                         } else {
288                             buffer[b++] = ch;
289                         }
290                     }
291                 }
292             }
293         }
294         finally {
295             closeQuietly(reader);
296         }
297 
298         return model;
299     }
300 
301     /**
302      * Loads the model from the file with ISO-8859-1 charset.
303      * It uses {@link Locale.ENGLISH} for number formatting.
304      */
305     public static Model loadModel(File modelFile) throws IOException {
306         BufferedReader inputReader = new BufferedReader(new InputStreamReader(new FileInputStream(modelFile), FILE_CHARSET));
307         return loadModel(inputReader);
308     }
309 
310     static void closeQuietly(Closeable c) {
311         if (c == null) return;
312         try {
313             c.close();
314         } catch (Throwable t) {}
315     }
316 
317     public static int predict(Model model, FeatureNode[] x) {
318         double[] dec_values = new double[model.nr_class];
319         return predictValues(model, x, dec_values);
320     }
321 
322     public static int predictProbability(Model model, FeatureNode[] x, double[] prob_estimates) {
323         if (model.solverType == SolverType.L2R_LR) {
324             int nr_class = model.nr_class;
325             int nr_w;
326             if (nr_class == 2)
327                 nr_w = 1;
328             else
329                 nr_w = nr_class;
330 
331             int label = predictValues(model, x, prob_estimates);
332             for (int i = 0; i < nr_w; i++)
333                 prob_estimates[i] = 1 / (1 + Math.exp(-prob_estimates[i]));
334 
335             if (nr_class == 2) // for binary classification
336                 prob_estimates[1] = 1. - prob_estimates[0];
337             else {
338                 double sum = 0;
339                 for (int i = 0; i < nr_class; i++)
340                     sum += prob_estimates[i];
341 
342                 for (int i = 0; i < nr_class; i++)
343                     prob_estimates[i] = prob_estimates[i] / sum;
344             }
345 
346             return label;
347         } else
348             return 0;
349     }
350 
351     public static int predictValues(Model model, FeatureNode[] x, double[] dec_values) {
352         int n;
353         if (model.bias >= 0)
354             n = model.nr_feature + 1;
355         else
356             n = model.nr_feature;
357 
358         double[] w = model.w;
359 
360         int nr_w;
361         if (model.nr_class == 2 && model.solverType != SolverType.MCSVM_CS)
362             nr_w = 1;
363         else
364             nr_w = model.nr_class;
365 
366         for (int i = 0; i < nr_w; i++)
367             dec_values[i] = 0;
368 
369         for (FeatureNode lx : x) {
370             int idx = lx.index;
371             // the dimension of testing data may exceed that of training
372             if (idx <= n) {
373                 for (int i = 0; i < nr_w; i++) {
374                     dec_values[i] += w[(idx - 1) * nr_w + i] * lx.value;
375                 }
376             }
377         }
378 
379         if (model.nr_class == 2)
380             return (dec_values[0] > 0) ? model.label[0] : model.label[1];
381         else {
382             int dec_max_idx = 0;
383             for (int i = 1; i < model.nr_class; i++) {
384                 if (dec_values[i] > dec_values[dec_max_idx]) dec_max_idx = i;
385             }
386             return model.label[dec_max_idx];
387         }
388     }
389 
390 
391     static void printf(Formatter formatter, String format, Object... args) throws IOException {
392         formatter.format(format, args);
393         IOException ioException = formatter.ioException();
394         if (ioException != null) throw ioException;
395     }
396 
397     /**
398      * Writes the model to the modelOutput.
399      * It uses {@link Locale.ENGLISH} for number formatting.
400      *
401      * <p><b>Note: The modelOutput is closed after reading or in case of an exception.</b></p>
402      */
403     public static void saveModel(Writer modelOutput, Model model) throws IOException {
404         int nr_feature = model.nr_feature;
405         int w_size = nr_feature;
406         if (model.bias >= 0) w_size++;
407 
408         int nr_w = model.nr_class;
409         if (model.nr_class == 2 && model.solverType != SolverType.MCSVM_CS) nr_w = 1;
410 
411         Formatter formatter = new Formatter(modelOutput, DEFAULT_LOCALE);
412         try {
413             printf(formatter, "solver_type %s\n", model.solverType.name());
414             printf(formatter, "nr_class %d\n", model.nr_class);
415 
416             printf(formatter, "label");
417             for (int i = 0; i < model.nr_class; i++) {
418                 printf(formatter, " %d", model.label[i]);
419             }
420             printf(formatter, "\n");
421 
422             printf(formatter, "nr_feature %d\n", nr_feature);
423             printf(formatter, "bias %.16g\n", model.bias);
424 
425             printf(formatter, "w\n");
426             for (int i = 0; i < w_size; i++) {
427                 for (int j = 0; j < nr_w; j++) {
428                     double value = model.w[i * nr_w + j];
429 
430                     /** this optimization is the reason for {@link Model#equals(double[], double[])} */
431                     if (value == 0.0) {
432                         printf(formatter, "%d ", 0);
433                     } else {
434                         printf(formatter, "%.16g ", value);
435                     }
436                 }
437                 printf(formatter, "\n");
438             }
439 
440             formatter.flush();
441             IOException ioException = formatter.ioException();
442             if (ioException != null) throw ioException;
443         }
444         finally {
445             formatter.close();
446         }
447     }
448 
449     /**
450      * Writes the model to the file with ISO-8859-1 charset.
451      * It uses {@link Locale.ENGLISH} for number formatting.
452      */
453     public static void saveModel(File modelFile, Model model) throws IOException {
454         BufferedWriter modelOutput = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(modelFile), FILE_CHARSET));
455         saveModel(modelOutput, model);
456     }
457 
458     /*
459      * this method corresponds to the following define in the C version:
460      * #define GETI(i) (y[i]+1)
461      */
462     private static int GETI(byte[] y, int i) {
463         return y[i] + 1;
464     }
465 
466     /**
467      * A coordinate descent algorithm for
468      * L1-loss and L2-loss SVM dual problems
469      *<pre>
470      *  min_\alpha  0.5(\alpha^T (Q + D)\alpha) - e^T \alpha,
471      *    s.t.      0 <= alpha_i <= upper_bound_i,
472      *
473      *  where Qij = yi yj xi^T xj and
474      *  D is a diagonal matrix
475      *
476      * In L1-SVM case:
477      *     upper_bound_i = Cp if y_i = 1
478      *      upper_bound_i = Cn if y_i = -1
479      *      D_ii = 0
480      * In L2-SVM case:
481      *      upper_bound_i = INF
482      *      D_ii = 1/(2*Cp) if y_i = 1
483      *      D_ii = 1/(2*Cn) if y_i = -1
484      *
485      * Given:
486      * x, y, Cp, Cn
487      * eps is the stopping tolerance
488      *
489      * solution will be put in w
490      *</pre>
491      */
492     private static void solve_l2r_l1l2_svc(Problem prob, double[] w, double eps, double Cp, double Cn, SolverType solver_type) {
493         int l = prob.l;
494         int w_size = prob.n;
495         int i, s, iter = 0;
496         double C, d, G;
497         double[] QD = new double[l];
498         int max_iter = 1000;
499         int[] index = new int[l];
500         double[] alpha = new double[l];
501         byte[] y = new byte[l];
            // instances index[0 .. active_size) are visited in a pass; the rest has been shrunk away
502         int active_size = l;
503 
504         // PG: projected gradient, for shrinking and stopping
505         double PG;
506         double PGmax_old = Double.POSITIVE_INFINITY;
507         double PGmin_old = Double.NEGATIVE_INFINITY;
508         double PGmax_new, PGmin_new;
509 
            // Both tables are indexed with GETI(y, i) = y[i] + 1: slot 0 is used for y = -1,
            // slot 2 for y = +1, slot 1 is never read.
510         // default solver_type: L2R_L2LOSS_SVC_DUAL
511         double diag[] = new double[] {0.5 / Cn, 0, 0.5 / Cp};
512         double upper_bound[] = new double[] {Double.POSITIVE_INFINITY, 0, Double.POSITIVE_INFINITY};
513         if (solver_type == SolverType.L2R_L1LOSS_SVC_DUAL) {
514             diag[0] = 0;
515             diag[2] = 0;
516             upper_bound[0] = Cn;
517             upper_bound[2] = Cp;
518         }
519 
            // start from w = 0 and alpha = 0, map the labels to +1/-1 and
            // precompute QD[i] = D_ii + x_i^T x_i
520         for (i = 0; i < w_size; i++)
521             w[i] = 0;
522         for (i = 0; i < l; i++) {
523             alpha[i] = 0;
524             if (prob.y[i] > 0) {
525                 y[i] = +1;
526             } else {
527                 y[i] = -1;
528             }
529             QD[i] = diag[GETI(y, i)];
530 
531             for (FeatureNode xi : prob.x[i]) {
532                 QD[i] += xi.value * xi.value;
533             }
534             index[i] = i;
535         }
536 
537         while (iter < max_iter) {
538             PGmax_new = Double.NEGATIVE_INFINITY;
539             PGmin_new = Double.POSITIVE_INFINITY;
540 
                // visit the active instances in a new random order in every pass
541             for (i = 0; i < active_size; i++) {
542                 int j = i + random.nextInt(active_size - i);
543                 swap(index, i, j);
544             }
545 
546             for (s = 0; s < active_size; s++) {
547                 i = index[s];
548                 G = 0;
549                 byte yi = y[i];
550 
                    // G = y_i * w^T x_i - 1 + D_ii * alpha_i
551                 for (FeatureNode xi : prob.x[i]) {
552                     G += w[xi.index - 1] * xi.value;
553                 }
554                 G = G * yi - 1;
555 
556                 C = upper_bound[GETI(y, i)];
557                 G += alpha[i] * diag[GETI(y, i)];
558 
                    // Project G onto the box [0, C]. An instance sitting at a bound whose G lies
                    // beyond the extreme of the previous pass is removed from the active set;
                    // s is decremented so the element swapped into slot s is examined next.
559                 PG = 0;
560                 if (alpha[i] == 0) {
561                     if (G > PGmax_old) {
562                         active_size--;
563                         swap(index, s, active_size);
564                         s--;
565                         continue;
566                     } else if (G < 0) {
567                         PG = G;
568                     }
569                 } else if (alpha[i] == C) {
570                     if (G < PGmin_old) {
571                         active_size--;
572                         swap(index, s, active_size);
573                         s--;
574                         continue;
575                     } else if (G > 0) {
576                         PG = G;
577                     }
578                 } else {
579                     PG = G;
580                 }
581 
582                 PGmax_new = Math.max(PGmax_new, PG);
583                 PGmin_new = Math.min(PGmin_new, PG);
584 
                    // step alpha_i by -G / QD[i], clipped to [0, C], and apply the change to w
585                 if (Math.abs(PG) > 1.0e-12) {
586                     double alpha_old = alpha[i];
587                     alpha[i] = Math.min(Math.max(alpha[i] - G / QD[i], 0.0), C);
588                     d = (alpha[i] - alpha_old) * yi;
589 
590                     for (FeatureNode xi : prob.x[i]) {
591                         w[xi.index - 1] += d * xi.value;
592                     }
593                 }
594             }
595 
596             iter++;
597             if (iter % 10 == 0) info(".");
598 
                // The stopping condition holds on the active set. Stop only if nothing is
                // shrunk; otherwise reactivate all instances and check again in another pass.
599             if (PGmax_new - PGmin_new <= eps) {
600                 if (active_size == l)
601                     break;
602                 else {
603                     active_size = l;
604                     info("*");
605                     PGmax_old = Double.POSITIVE_INFINITY;
606                     PGmin_old = Double.NEGATIVE_INFINITY;
607                     continue;
608                 }
609             }
                // shrinking thresholds for the next pass; a non-positive maximum or a
                // non-negative minimum is replaced by an infinite threshold
610             PGmax_old = PGmax_new;
611             PGmin_old = PGmin_new;
612             if (PGmax_old <= 0) PGmax_old = Double.POSITIVE_INFINITY;
613             if (PGmin_old >= 0) PGmin_old = Double.NEGATIVE_INFINITY;
614         }
615 
616         info(NL + "optimization finished, #iter = %d" + NL, iter);
617         if (iter >= max_iter) info("\nWARNING: reaching max number of iterations\nUsing -s 2 may be faster (also see FAQ)\n\n");
618 
619         // calculate objective value
620 
            // v / 2 = 0.5 * w^T w + sum_i alpha_i * (0.5 * D_ii * alpha_i - 1);
            // nSV counts the instances with alpha_i > 0
621         double v = 0;
622         int nSV = 0;
623         for (i = 0; i < w_size; i++)
624             v += w[i] * w[i];
625         for (i = 0; i < l; i++) {
626             v += alpha[i] * (alpha[i] * diag[GETI(y, i)] - 2);
627             if (alpha[i] > 0) ++nSV;
628         }
629         info("Objective value = %f" + NL, v / 2);
630         info("nSV = %d" + NL, nSV);
631     }
632 
633     /**
634      * A coordinate descent algorithm for
635      * L1-regularized L2-loss support vector classification
636      *
637      *<pre>
638      *  min_w \sum |wj| + C \sum max(0, 1-yi w^T xi)^2,
639      *
640      * Given:
641      * x, y, Cp, Cn
642      * eps is the stopping tolerance
643      *
644      * solution will be put in w
645      *</pre>
646      */
647     private static void solve_l1r_l2_svc(Problem prob_col, double[] w, double eps, double Cp, double Cn) {
648         int l = prob_col.l;
649         int w_size = prob_col.n;
650         int j, s, iter = 0;
651         int max_iter = 1000;
652         int active_size = w_size;
653         int max_num_linesearch = 20;
654 
655         double sigma = 0.01;
656         double d, G_loss, G, H;
657         double Gmax_old = Double.POSITIVE_INFINITY;
658         double Gmax_new;
659         double Gmax_init = 0; // eclipse moans this variable might not be initialized
660         double d_old, d_diff;
661         double loss_old = 0; // eclipse moans this variable might not be initialized
662         double loss_new;
663         double appxcond, cond;
664 
665         int[] index = new int[w_size];
666         byte[] y = new byte[l];
667         double[] b = new double[l]; // b = 1-ywTx
668         double[] xj_sq = new double[w_size];
669 
670         double[] C = new double[] {Cn, 0, Cp};
671 
672         for (j = 0; j < l; j++) {
673             b[j] = 1;
674             if (prob_col.y[j] > 0)
675                 y[j] = 1;
676             else
677                 y[j] = -1;
678         }
679         for (j = 0; j < w_size; j++) {
680             w[j] = 0;
681             index[j] = j;
682             xj_sq[j] = 0;
683             for (FeatureNode xi : prob_col.x[j]) {
684                 int ind = xi.index - 1;
685                 double val = xi.value;
686                 xi.value *= y[ind]; // x->value stores yi*xij
687                 xj_sq[j] += C[GETI(y, ind)] * val * val;
688             }
689         }
690 
691         while (iter < max_iter) {
692             Gmax_new = 0;
693 
694             for (j = 0; j < active_size; j++) {
695                 int i = j + random.nextInt(active_size - j);
696                 swap(index, i, j);
697             }
698 
699             for (s = 0; s < active_size; s++) {
700                 j = index[s];
701                 G_loss = 0;
702                 H = 0;
703 
704                 for (FeatureNode xi : prob_col.x[j]) {
705                     int ind = xi.index - 1;
706                     if (b[ind] > 0) {
707                         double val = xi.value;
708                         double tmp = C[GETI(y, ind)] * val;
709                         G_loss -= tmp * b[ind];
710                         H += tmp * val;
711                     }
712                 }
713                 G_loss *= 2;
714 
715                 G = G_loss;
716                 H *= 2;
717                 H = Math.max(H, 1e-12);
718 
719                 double Gp = G + 1;
720                 double Gn = G - 1;
721                 double violation = 0;
722                 if (w[j] == 0) {
723                     if (Gp < 0)
724                         violation = -Gp;
725                     else if (Gn > 0)
726                         violation = Gn;
727                     else if (Gp > Gmax_old / l && Gn < -Gmax_old / l) {
728                         active_size--;
729                         swap(index, s, active_size);
730                         s--;
731                         continue;
732                     }
733                 } else if (w[j] > 0)
734                     violation = Math.abs(Gp);
735                 else
736                     violation = Math.abs(Gn);
737 
738                 Gmax_new = Math.max(Gmax_new, violation);
739 
740                 // obtain Newton direction d
741                 if (Gp <= H * w[j])
742                     d = -Gp / H;
743                 else if (Gn >= H * w[j])
744                     d = -Gn / H;
745                 else
746                     d = -w[j];
747 
748                 if (Math.abs(d) < 1.0e-12) continue;
749 
750                 double delta = Math.abs(w[j] + d) - Math.abs(w[j]) + G * d;
751                 d_old = 0;
752                 int num_linesearch;
753                 for (num_linesearch = 0; num_linesearch < max_num_linesearch; num_linesearch++) {
754                     d_diff = d_old - d;
755                     cond = Math.abs(w[j] + d) - Math.abs(w[j]) - sigma * delta;
756 
757                     appxcond = xj_sq[j] * d * d + G_loss * d + cond;
758                     if (appxcond <= 0) {
759                         for (FeatureNode x : prob_col.x[j]) {
760                             b[x.index - 1] += d_diff * x.value;
761                         }
762                         break;
763                     }
764 
765                     if (num_linesearch == 0) {
766                         loss_old = 0;
767                         loss_new = 0;
768                         for (FeatureNode x : prob_col.x[j]) {
769                             int ind = x.index - 1;
770                             if (b[ind] > 0) {
771                                 loss_old += C[GETI(y, ind)] * b[ind] * b[ind];
772                             }
773                             double b_new = b[ind] + d_diff * x.value;
774                             b[ind] = b_new;
775                             if (b_new > 0) {
776                                 loss_new += C[GETI(y, ind)] * b_new * b_new;
777                             }
778                         }
779                     } else {
780                         loss_new = 0;
781                         for (FeatureNode x : prob_col.x[j]) {
782                             int ind = x.index - 1;
783                             double b_new = b[ind] + d_diff * x.value;
784                             b[ind] = b_new;
785                             if (b_new > 0) {
786                                 loss_new += C[GETI(y, ind)] * b_new * b_new;
787                             }
788                         }
789                     }
790 
791                     cond = cond + loss_new - loss_old;
792                     if (cond <= 0)
793                         break;
794                     else {
795                         d_old = d;
796                         d *= 0.5;
797                         delta *= 0.5;
798                     }
799                 }
800 
801                 w[j] += d;
802 
803                 // recompute b[] if line search takes too many steps
804                 if (num_linesearch >= max_num_linesearch) {
805                     info("#");
806                     for (int i = 0; i < l; i++)
807                         b[i] = 1;
808 
809                     for (int i = 0; i < w_size; i++) {
810                         if (w[i] == 0) continue;
811                         for (FeatureNode x : prob_col.x[i]) {
812                             b[x.index - 1] -= w[i] * x.value;
813                         }
814                     }
815                 }
816             }
817 
818             if (iter == 0) Gmax_init = Gmax_new;
819             iter++;
820             if (iter % 10 == 0) info(".");
821 
822             if (Gmax_new <= eps * Gmax_init) {
823                 if (active_size == w_size)
824                     break;
825                 else {
826                     active_size = w_size;
827                     info("*");
828                     Gmax_old = Double.POSITIVE_INFINITY;
829                     continue;
830                 }
831             }
832 
833             Gmax_old = Gmax_new;
834         }
835 
836         info("\noptimization finished, #iter = %d\n", iter);
837         if (iter >= max_iter) info("\nWARNING: reaching max number of iterations\n");
838 
839         // calculate objective value
840 
841         double v = 0;
842         int nnz = 0;
843         for (j = 0; j < w_size; j++) {
844             for (FeatureNode x : prob_col.x[j]) {
845                 x.value *= prob_col.y[x.index - 1]; // restore x->value
846             }
847             if (w[j] != 0) {
848                 v += Math.abs(w[j]);
849                 nnz++;
850             }
851         }
852         for (j = 0; j < l; j++)
853             if (b[j] > 0) v += C[GETI(y, j)] * b[j] * b[j];
854 
855         info("Objective value = %f\n", v);
856         info("#nonzeros/#features = %d/%d\n", nnz, w_size);
857     }
858 
859     /**
860      * A coordinate descent algorithm for
861      * L1-regularized logistic regression problems
862      *
863      *<pre>
864      *  min_w \sum |wj| + C \sum log(1+exp(-yi w^T xi)),
865      *
866      * Given:
867      * x, y, Cp, Cn
868      * eps is the stopping tolerance
869      *
870      * solution will be put in w
871      *</pre>
872      */
873     private static void solve_l1r_lr(Problem prob_col, double[] w, double eps, double Cp, double Cn) {
874         int l = prob_col.l;
875         int w_size = prob_col.n;
876         int j, s, iter = 0;
877         int max_iter = 1000;
878         int active_size = w_size;
879         int max_num_linesearch = 20;
880 
881         double x_min = 0;
882         double sigma = 0.01;
883         double d, G, H;
884         double Gmax_old = Double.POSITIVE_INFINITY;
885         double Gmax_new;
886         double Gmax_init = 0; // eclipse moans this variable might not be initialized
887         double sum1, appxcond1;
888         double sum2, appxcond2;
889         double cond;
890 
891         int[] index = new int[w_size];
892         byte[] y = new byte[l];
893         double[] exp_wTx = new double[l];
894         double[] exp_wTx_new = new double[l];
895         double[] xj_max = new double[w_size];
896         double[] C_sum = new double[w_size];
897         double[] xjneg_sum = new double[w_size];
898         double[] xjpos_sum = new double[w_size];
899 
900         double[] C = new double[] {Cn, 0, Cp};
901 
902         for (j = 0; j < l; j++) {
903             exp_wTx[j] = 1;
904             if (prob_col.y[j] > 0)
905                 y[j] = 1;
906             else
907                 y[j] = -1;
908         }
909         for (j = 0; j < w_size; j++) {
910             w[j] = 0;
911             index[j] = j;
912             xj_max[j] = 0;
913             C_sum[j] = 0;
914             xjneg_sum[j] = 0;
915             xjpos_sum[j] = 0;
916             for (FeatureNode x : prob_col.x[j]) {
917                 int ind = x.index - 1;
918                 double val = x.value;
919                 x_min = Math.min(x_min, val);
920                 xj_max[j] = Math.max(xj_max[j], val);
921                 C_sum[j] += C[GETI(y, ind)];
922                 if (y[ind] == -1)
923                     xjneg_sum[j] += C[GETI(y, ind)] * val;
924                 else
925                     xjpos_sum[j] += C[GETI(y, ind)] * val;
926             }
927         }
928 
929         while (iter < max_iter) {
930             Gmax_new = 0;
931 
932             for (j = 0; j < active_size; j++) {
933                 int i = j + random.nextInt(active_size) - j;
934                 swap(index, i, j);
935             }
936 
937             for (s = 0; s < active_size; s++) {
938                 j = index[s];
939                 sum1 = 0;
940                 sum2 = 0;
941                 H = 0;
942 
943                 for (FeatureNode x : prob_col.x[j]) {
944                     int ind = x.index - 1;
945                     double exp_wTxind = exp_wTx[ind];
946                     double tmp1 = x.value / (1 + exp_wTxind);
947                     double tmp2 = C[GETI(y, ind)] * tmp1;
948                     double tmp3 = tmp2 * exp_wTxind;
949                     sum2 += tmp2;
950                     sum1 += tmp3;
951                     H += tmp1 * tmp3;
952                 }
953 
954                 G = -sum2 + xjneg_sum[j];
955 
956                 double Gp = G + 1;
957                 double Gn = G - 1;
958                 double violation = 0;
959                 if (w[j] == 0) {
960                     if (Gp < 0)
961                         violation = -Gp;
962                     else if (Gn > 0)
963                         violation = Gn;
964                     else if (Gp > Gmax_old / l && Gn < -Gmax_old / l) {
965                         active_size--;
966                         swap(index, s, active_size);
967                         s--;
968                         continue;
969                     }
970                 } else if (w[j] > 0)
971                     violation = Math.abs(Gp);
972                 else
973                     violation = Math.abs(Gn);
974 
975                 Gmax_new = Math.max(Gmax_new, violation);
976 
977                 // obtain Newton direction d
978                 if (Gp <= H * w[j])
979                     d = -Gp / H;
980                 else if (Gn >= H * w[j])
981                     d = -Gn / H;
982                 else
983                     d = -w[j];
984 
985                 if (Math.abs(d) < 1.0e-12) continue;
986 
987                 d = Math.min(Math.max(d, -10.0), 10.0);
988 
989                 double delta = Math.abs(w[j] + d) - Math.abs(w[j]) + G * d;
990                 int num_linesearch;
991                 for (num_linesearch = 0; num_linesearch < max_num_linesearch; num_linesearch++) {
992                     cond = Math.abs(w[j] + d) - Math.abs(w[j]) - sigma * delta;
993 
994                     if (x_min >= 0) {
995                         double tmp = Math.exp(d * xj_max[j]);
996                         appxcond1 = Math.log(1 + sum1 * (tmp - 1) / xj_max[j] / C_sum[j]) * C_sum[j] + cond - d * xjpos_sum[j];
997                         appxcond2 = Math.log(1 + sum2 * (1 / tmp - 1) / xj_max[j] / C_sum[j]) * C_sum[j] + cond + d * xjneg_sum[j];
998                         if (Math.min(appxcond1, appxcond2) <= 0) {
999                             for (FeatureNode x : prob_col.x[j]) {
1000                                 exp_wTx[x.index - 1] *= Math.exp(d * x.value);
1001                             }
1002                             break;
1003                         }
1004                     }
1005 
1006                     cond += d * xjneg_sum[j];
1007 
1008                     int i = 0;
1009                     for (FeatureNode x : prob_col.x[j]) {
1010                         int ind = x.index - 1;
1011                         double exp_dx = Math.exp(d * x.value);
1012                         exp_wTx_new[i] = exp_wTx[ind] * exp_dx;
1013                         cond += C[GETI(y, ind)] * Math.log((1 + exp_wTx_new[i]) / (exp_dx + exp_wTx_new[i]));
1014                         i++;
1015                     }
1016 
1017                     if (cond <= 0) {
1018                         i = 0;
1019                         for (FeatureNode x : prob_col.x[j]) {
1020                             int ind = x.index - 1;
1021                             exp_wTx[ind] = exp_wTx_new[i];
1022                             i++;
1023                         }
1024                         break;
1025                     } else {
1026                         d *= 0.5;
1027                         delta *= 0.5;
1028                     }
1029                 }
1030 
1031                 w[j] += d;
1032 
1033                 // recompute exp_wTx[] if line search takes too many steps
1034                 if (num_linesearch >= max_num_linesearch) {
1035                     info("#");
1036                     for (int i = 0; i < l; i++)
1037                         exp_wTx[i] = 0;
1038 
1039                     for (int i = 0; i < w_size; i++) {
1040                         if (w[i] == 0) continue;
1041                         for (FeatureNode x : prob_col.x[i]) {
1042                             exp_wTx[x.index - 1] += w[i] * x.value;
1043                         }
1044                     }
1045 
1046                     for (int i = 0; i < l; i++)
1047                         exp_wTx[i] = Math.exp(exp_wTx[i]);
1048                 }
1049             }
1050 
1051             if (iter == 0) Gmax_init = Gmax_new;
1052             iter++;
1053             if (iter % 10 == 0) info(".");
1054 
1055             if (Gmax_new <= eps * Gmax_init) {
1056                 if (active_size == w_size)
1057                     break;
1058                 else {
1059                     active_size = w_size;
1060                     info("*");
1061                     Gmax_old = Double.POSITIVE_INFINITY;
1062                     continue;
1063                 }
1064             }
1065 
1066             Gmax_old = Gmax_new;
1067         }
1068 
1069         info("\noptimization finished, #iter = %d\n", iter);
1070         if (iter >= max_iter) info("\nWARNING: reaching max number of iterations\n");
1071 
1072         // calculate objective value
1073 
1074         double v = 0;
1075         int nnz = 0;
1076         for (j = 0; j < w_size; j++)
1077             if (w[j] != 0) {
1078                 v += Math.abs(w[j]);
1079                 nnz++;
1080             }
1081         for (j = 0; j < l; j++)
1082             if (y[j] == 1)
1083                 v += C[GETI(y, j)] * Math.log(1 + 1 / exp_wTx[j]);
1084             else
1085                 v += C[GETI(y, j)] * Math.log(1 + exp_wTx[j]);
1086 
1087         info("Objective value = %f\n", v);
1088         info("#nonzeros/#features = %d/%d\n", nnz, w_size);
1089     }
1090 
1091     // transpose matrix X from row format to column format
1092     static Problem transpose(Problem prob) {
1093         int l = prob.l;
1094         int n = prob.n;
1095         int[] col_ptr = new int[n + 1];
1096         Problem prob_col = new Problem();
1097         prob_col.l = l;
1098         prob_col.n = n;
1099         prob_col.y = new int[l];
1100         prob_col.x = new FeatureNode[n][];
1101 
1102         for (int i = 0; i < l; i++)
1103             prob_col.y[i] = prob.y[i];
1104 
1105         for (int i = 0; i < l; i++) {
1106             for (FeatureNode x : prob.x[i]) {
1107                 col_ptr[x.index]++;
1108             }
1109         }
1110 
1111         for (int i = 0; i < n; i++) {
1112             prob_col.x[i] = new FeatureNode[col_ptr[i + 1]];
1113             col_ptr[i] = 0; // reuse the array to count the nr of elements
1114         }
1115 
1116         for (int i = 0; i < l; i++) {
1117             for (int j = 0; j < prob.x[i].length; j++) {
1118                 FeatureNode x = prob.x[i][j];
1119                 int index = x.index - 1;
1120                 prob_col.x[index][col_ptr[index]] = new FeatureNode(i + 1, x.value);
1121                 col_ptr[index]++;
1122             }
1123         }
1124 
1125         return prob_col;
1126     }
1127 
1128     static void swap(double[] array, int idxA, int idxB) {
1129         double temp = array[idxA];
1130         array[idxA] = array[idxB];
1131         array[idxB] = temp;
1132     }
1133 
1134     static void swap(int[] array, int idxA, int idxB) {
1135         int temp = array[idxA];
1136         array[idxA] = array[idxB];
1137         array[idxB] = temp;
1138     }
1139 
1140     static void swap(IntArrayPointer array, int idxA, int idxB) {
1141         int temp = array.get(idxA);
1142         array.set(idxA, array.get(idxB));
1143         array.set(idxB, temp);
1144     }
1145 
1146     /**
1147      * @throws IllegalArgumentException if the feature nodes of prob are not sorted in ascending order
1148      */
1149     public static Model train(Problem prob, Parameter param) {
1150 
1151         if (prob == null) throw new IllegalArgumentException("problem must not be null");
1152         if (param == null) throw new IllegalArgumentException("parameter must not be null");
1153 
1154         for (FeatureNode[] nodes : prob.x) {
1155             int indexBefore = 0;
1156             for (FeatureNode n : nodes) {
1157                 if (n.index <= indexBefore) {
1158                     throw new IllegalArgumentException("feature nodes must be sorted by index in ascending order");
1159                 }
1160                 indexBefore = n.index;
1161             }
1162         }
1163 
1164         int i, j;
1165         int l = prob.l;
1166         int n = prob.n;
1167         int w_size = prob.n;
1168         Model model = new Model();
1169 
1170         if (prob.bias >= 0)
1171             model.nr_feature = n - 1;
1172         else
1173             model.nr_feature = n;
1174         model.solverType = param.solverType;
1175         model.bias = prob.bias;
1176 
1177         int[] perm = new int[l];
1178         // group training data of the same class
1179         GroupClassesReturn rv = groupClasses(prob, perm);
1180         int nr_class = rv.nr_class;
1181         int[] label = rv.label;
1182         int[] start = rv.start;
1183         int[] count = rv.count;
1184 
1185         model.nr_class = nr_class;
1186         model.label = new int[nr_class];
1187         for (i = 0; i < nr_class; i++)
1188             model.label[i] = label[i];
1189 
1190         // calculate weighted C
1191         double[] weighted_C = new double[nr_class];
1192         for (i = 0; i < nr_class; i++) {
1193             weighted_C[i] = param.C;
1194         }
1195 
1196         for (i = 0; i < param.getNumWeights(); i++) {
1197             for (j = 0; j < nr_class; j++)
1198                 if (param.weightLabel[i] == label[j]) break;
1199             if (j == nr_class) throw new IllegalArgumentException("class label " + param.weightLabel[i] + " specified in weight is not found");
1200 
1201             weighted_C[j] *= param.weight[i];
1202         }
1203 
1204         // constructing the subproblem
1205         FeatureNode[][] x = new FeatureNode[l][];
1206         for (i = 0; i < l; i++)
1207             x[i] = prob.x[perm[i]];
1208 
1209         int k;
1210         Problem sub_prob = new Problem();
1211         sub_prob.l = l;
1212         sub_prob.n = n;
1213         sub_prob.x = new FeatureNode[sub_prob.l][];
1214         sub_prob.y = new int[sub_prob.l];
1215 
1216         for (k = 0; k < sub_prob.l; k++)
1217             sub_prob.x[k] = x[k];
1218 
1219         // multi-class svm by Crammer and Singer
1220         if (param.solverType == SolverType.MCSVM_CS) {
1221             model.w = new double[n * nr_class];
1222             for (i = 0; i < nr_class; i++) {
1223                 for (j = start[i]; j < start[i] + count[i]; j++) {
1224                     sub_prob.y[j] = i;
1225                 }
1226             }
1227 
1228             SolverMCSVM_CS solver = new SolverMCSVM_CS(sub_prob, nr_class, weighted_C, param.eps);
1229             solver.solve(model.w);
1230         } else {
1231             if (nr_class == 2) {
1232                 model.w = new double[w_size];
1233 
1234                 int e0 = start[0] + count[0];
1235                 k = 0;
1236                 for (; k < e0; k++)
1237                     sub_prob.y[k] = +1;
1238                 for (; k < sub_prob.l; k++)
1239                     sub_prob.y[k] = -1;
1240 
1241                 train_one(sub_prob, param, model.w, weighted_C[0], weighted_C[1]);
1242             } else {
1243                 model.w = new double[w_size * nr_class];
1244                 double[] w = new double[w_size];
1245                 for (i = 0; i < nr_class; i++) {
1246                     int si = start[i];
1247                     int ei = si + count[i];
1248 
1249                     k = 0;
1250                     for (; k < si; k++)
1251                         sub_prob.y[k] = -1;
1252                     for (; k < ei; k++)
1253                         sub_prob.y[k] = +1;
1254                     for (; k < sub_prob.l; k++)
1255                         sub_prob.y[k] = -1;
1256 
1257                     train_one(sub_prob, param, w, weighted_C[i], param.C);
1258 
1259                     for (j = 0; j < n; j++)
1260                         model.w[j * nr_class + i] = w[j];
1261                 }
1262             }
1263 
1264         }
1265         return model;
1266     }
1267 
1268     private static void train_one(Problem prob, Parameter param, double[] w, double Cp, double Cn) {
1269         double eps = param.eps;
1270         int pos = 0;
1271         for (int i = 0; i < prob.l; i++)
1272             if (prob.y[i] == +1) pos++;
1273         int neg = prob.l - pos;
1274 
1275         Function fun_obj = null;
1276         switch (param.solverType) {
1277             case L2R_LR: {
1278                 fun_obj = new L2R_LrFunction(prob, Cp, Cn);
1279                 Tron tron_obj = new Tron(fun_obj, eps * Math.min(pos, neg) / prob.l);
1280                 tron_obj.tron(w);
1281                 break;
1282             }
1283             case L2R_L2LOSS_SVC: {
1284                 fun_obj = new L2R_L2_SvcFunction(prob, Cp, Cn);
1285                 Tron tron_obj = new Tron(fun_obj, eps * Math.min(pos, neg) / prob.l);
1286                 tron_obj.tron(w);
1287                 break;
1288             }
1289             case L2R_L2LOSS_SVC_DUAL:
1290                 solve_l2r_l1l2_svc(prob, w, eps, Cp, Cn, SolverType.L2R_L2LOSS_SVC_DUAL);
1291                 break;
1292             case L2R_L1LOSS_SVC_DUAL:
1293                 solve_l2r_l1l2_svc(prob, w, eps, Cp, Cn, SolverType.L2R_L1LOSS_SVC_DUAL);
1294                 break;
1295             case L1R_L2LOSS_SVC: {
1296                 Problem prob_col = transpose(prob);
1297                 solve_l1r_l2_svc(prob_col, w, eps * Math.min(pos, neg) / prob.l, Cp, Cn);
1298                 break;
1299             }
1300             case L1R_LR: {
1301                 Problem prob_col = transpose(prob);
1302                 solve_l1r_lr(prob_col, w, eps * Math.min(pos, neg) / prob.l, Cp, Cn);
1303                 break;
1304             }
1305             default:
1306                 throw new IllegalStateException("unknown solver type: " + param.solverType);
1307         }
1308     }
1309 
1310     public static void disableDebugOutput() {
1311         setDebugOutput(null);
1312     }
1313 
1314     public static void enableDebugOutput() {
1315         setDebugOutput(System.out);
1316     }
1317 
1318     public static void setDebugOutput(PrintStream debugOutput) {
1319         synchronized (OUTPUT_MUTEX) {
1320             DEBUG_OUTPUT = debugOutput;
1321         }
1322     }
1323 
1324     /**
1325      * resets the PRNG
1326      *
1327      * this is i.a. needed for regression testing (eg. the Weka wrapper)
1328      */
1329     public static void resetRandom() {
1330         random = new Random(DEFAULT_RANDOM_SEED);
1331     }
1332 }