View Javadoc

1   package de.bwaldvogel.liblinear;
2   
3   import java.io.BufferedReader;
4   import java.io.BufferedWriter;
5   import java.io.Closeable;
6   import java.io.EOFException;
7   import java.io.File;
8   import java.io.FileInputStream;
9   import java.io.FileOutputStream;
10  import java.io.IOException;
11  import java.io.InputStreamReader;
12  import java.io.OutputStreamWriter;
13  import java.io.PrintStream;
14  import java.io.Reader;
15  import java.io.Writer;
16  import java.nio.charset.Charset;
17  import java.util.Formatter;
18  import java.util.Locale;
19  import java.util.Random;
20  import java.util.regex.Pattern;
21  
22  
23  /**
24   * <h2>Java port of <a href="http://www.csie.ntu.edu.tw/~cjlin/liblinear/">liblinear</a></h2>
25   *
26   * <p>The usage should be pretty similar to the C version of <tt>liblinear</tt>.</p>
27   * <p>Please consider reading the <tt>README</tt> file of <tt>liblinear</tt>.</p>
28   *
29   * <p><em>The port was done by Benedikt Waldvogel (mail at bwaldvogel.de)</em></p>
30   *
31   * @version 1.8
32   */
33  public class Linear {
34  
35      static final Charset       FILE_CHARSET        = Charset.forName("ISO-8859-1");
36  
37      static final Locale        DEFAULT_LOCALE      = Locale.ENGLISH;
38  
39      private static Object      OUTPUT_MUTEX        = new Object();
40      private static PrintStream DEBUG_OUTPUT        = System.out;
41  
42      /** platform-independent new-line string */
43      final static String        NL                  = System.getProperty("line.separator");
44  
45      private static final long  DEFAULT_RANDOM_SEED = 0L;
46      static Random              random              = new Random(DEFAULT_RANDOM_SEED);
47  
48      /**
49       * @param target predicted classes
50       */
51      public static void crossValidation(Problem prob, Parameter param, int nr_fold, int[] target) {
52          int i;
53          int[] fold_start = new int[nr_fold + 1];
54          int l = prob.l;
55          int[] perm = new int[l];
56  
57          for (i = 0; i < l; i++)
58              perm[i] = i;
59          for (i = 0; i < l; i++) {
60              int j = i + random.nextInt(l - i);
61              swap(perm, i, j);
62          }
63          for (i = 0; i <= nr_fold; i++)
64              fold_start[i] = i * l / nr_fold;
65  
66          for (i = 0; i < nr_fold; i++) {
67              int begin = fold_start[i];
68              int end = fold_start[i + 1];
69              int j, k;
70              Problem subprob = new Problem();
71  
72              subprob.bias = prob.bias;
73              subprob.n = prob.n;
74              subprob.l = l - (end - begin);
75              subprob.x = new FeatureNode[subprob.l][];
76              subprob.y = new int[subprob.l];
77  
78              k = 0;
79              for (j = 0; j < begin; j++) {
80                  subprob.x[k] = prob.x[perm[j]];
81                  subprob.y[k] = prob.y[perm[j]];
82                  ++k;
83              }
84              for (j = end; j < l; j++) {
85                  subprob.x[k] = prob.x[perm[j]];
86                  subprob.y[k] = prob.y[perm[j]];
87                  ++k;
88              }
89              Model submodel = train(subprob, param);
90              for (j = begin; j < end; j++)
91                  target[perm[j]] = predict(submodel, prob.x[perm[j]]);
92          }
93      }
94  
95      /** used as complex return type */
96      private static class GroupClassesReturn {
97  
98          final int[] count;
99          final int[] label;
100         final int   nr_class;
101         final int[] start;
102 
103         GroupClassesReturn( int nr_class, int[] label, int[] start, int[] count ) {
104             this.nr_class = nr_class;
105             this.label = label;
106             this.start = start;
107             this.count = count;
108         }
109     }
110 
111     private static GroupClassesReturn groupClasses(Problem prob, int[] perm) {
112         int l = prob.l;
113         int max_nr_class = 16;
114         int nr_class = 0;
115 
116         int[] label = new int[max_nr_class];
117         int[] count = new int[max_nr_class];
118         int[] data_label = new int[l];
119         int i;
120 
121         for (i = 0; i < l; i++) {
122             int this_label = prob.y[i];
123             int j;
124             for (j = 0; j < nr_class; j++) {
125                 if (this_label == label[j]) {
126                     ++count[j];
127                     break;
128                 }
129             }
130             data_label[i] = j;
131             if (j == nr_class) {
132                 if (nr_class == max_nr_class) {
133                     max_nr_class *= 2;
134                     label = copyOf(label, max_nr_class);
135                     count = copyOf(count, max_nr_class);
136                 }
137                 label[nr_class] = this_label;
138                 count[nr_class] = 1;
139                 ++nr_class;
140             }
141         }
142 
143         int[] start = new int[nr_class];
144         start[0] = 0;
145         for (i = 1; i < nr_class; i++)
146             start[i] = start[i - 1] + count[i - 1];
147         for (i = 0; i < l; i++) {
148             perm[start[data_label[i]]] = i;
149             ++start[data_label[i]];
150         }
151         start[0] = 0;
152         for (i = 1; i < nr_class; i++)
153             start[i] = start[i - 1] + count[i - 1];
154 
155         return new GroupClassesReturn(nr_class, label, start, count);
156     }
157 
158     static void info(String message) {
159         synchronized (OUTPUT_MUTEX) {
160             if (DEBUG_OUTPUT == null) return;
161             DEBUG_OUTPUT.printf(message);
162             DEBUG_OUTPUT.flush();
163         }
164     }
165 
166     static void info(String format, Object... args) {
167         synchronized (OUTPUT_MUTEX) {
168             if (DEBUG_OUTPUT == null) return;
169             DEBUG_OUTPUT.printf(format, args);
170             DEBUG_OUTPUT.flush();
171         }
172     }
173 
174     /**
175      * @param s the string to parse for the double value
176      * @throws IllegalArgumentException if s is empty or represents NaN or Infinity
177      * @throws NumberFormatException see {@link Double#parseDouble(String)}
178      */
179     static double atof(String s) {
180         if (s == null || s.length() < 1) throw new IllegalArgumentException("Can't convert empty string to integer");
181         double d = Double.parseDouble(s);
182         if (Double.isNaN(d) || Double.isInfinite(d)) {
183             throw new IllegalArgumentException("NaN or Infinity in input: " + s);
184         }
185         return (d);
186     }
187 
188     /**
189      * @param s the string to parse for the integer value
190      * @throws IllegalArgumentException if s is empty
191      * @throws NumberFormatException see {@link Integer#parseInt(String)}
192      */
193     static int atoi(String s) throws NumberFormatException {
194         if (s == null || s.length() < 1) throw new IllegalArgumentException("Can't convert empty string to integer");
195         // Integer.parseInt doesn't accept '+' prefixed strings
196         if (s.charAt(0) == '+') s = s.substring(1);
197         return Integer.parseInt(s);
198     }
199 
200     /**
201      * Java5 'backport' of Arrays.copyOf
202      */
203     public static double[] copyOf(double[] original, int newLength) {
204         double[] copy = new double[newLength];
205         System.arraycopy(original, 0, copy, 0, Math.min(original.length, newLength));
206         return copy;
207     }
208 
209     /**
210      * Java5 'backport' of Arrays.copyOf
211      */
212     public static int[] copyOf(int[] original, int newLength) {
213         int[] copy = new int[newLength];
214         System.arraycopy(original, 0, copy, 0, Math.min(original.length, newLength));
215         return copy;
216     }
217 
218     /**
219      * Loads the model from inputReader.
220      * It uses {@link java.util.Locale#ENGLISH} for number formatting.
221      *
222      * <p><b>Note: The inputReader is closed after reading or in case of an exception.</b></p>
223      */
224     public static Model loadModel(Reader inputReader) throws IOException {
225         Model model = new Model();
226 
227         model.label = null;
228 
229         Pattern whitespace = Pattern.compile("\\s+");
230 
231         BufferedReader reader = null;
232         if (inputReader instanceof BufferedReader) {
233             reader = (BufferedReader)inputReader;
234         } else {
235             reader = new BufferedReader(inputReader);
236         }
237 
238         try {
239             String line = null;
240             while ((line = reader.readLine()) != null) {
241                 String[] split = whitespace.split(line);
242                 if (split[0].equals("solver_type")) {
243                     SolverType solver = SolverType.valueOf(split[1]);
244                     if (solver == null) {
245                         throw new RuntimeException("unknown solver type");
246                     }
247                     model.solverType = solver;
248                 } else if (split[0].equals("nr_class")) {
249                     model.nr_class = atoi(split[1]);
250                     Integer.parseInt(split[1]);
251                 } else if (split[0].equals("nr_feature")) {
252                     model.nr_feature = atoi(split[1]);
253                 } else if (split[0].equals("bias")) {
254                     model.bias = atof(split[1]);
255                 } else if (split[0].equals("w")) {
256                     break;
257                 } else if (split[0].equals("label")) {
258                     model.label = new int[model.nr_class];
259                     for (int i = 0; i < model.nr_class; i++) {
260                         model.label[i] = atoi(split[i + 1]);
261                     }
262                 } else {
263                     throw new RuntimeException("unknown text in model file: [" + line + "]");
264                 }
265             }
266 
267             int w_size = model.nr_feature;
268             if (model.bias >= 0) w_size++;
269 
270             int nr_w = model.nr_class;
271             if (model.nr_class == 2 && model.solverType != SolverType.MCSVM_CS) nr_w = 1;
272 
273             model.w = new double[w_size * nr_w];
274             int[] buffer = new int[128];
275 
276             for (int i = 0; i < w_size; i++) {
277                 for (int j = 0; j < nr_w; j++) {
278                     int b = 0;
279                     while (true) {
280                         int ch = reader.read();
281                         if (ch == -1) {
282                             throw new EOFException("unexpected EOF");
283                         }
284                         if (ch == ' ') {
285                             model.w[i * nr_w + j] = atof(new String(buffer, 0, b));
286                             break;
287                         } else {
288                             buffer[b++] = ch;
289                         }
290                     }
291                 }
292             }
293         }
294         finally {
295             closeQuietly(reader);
296         }
297 
298         return model;
299     }
300 
301     /**
302      * Loads the model from the file with ISO-8859-1 charset.
303      * It uses {@link java.util.Locale#ENGLISH} for number formatting.
304      */
305     public static Model loadModel(File modelFile) throws IOException {
306         BufferedReader inputReader = new BufferedReader(new InputStreamReader(new FileInputStream(modelFile), FILE_CHARSET));
307         return loadModel(inputReader);
308     }
309 
310     static void closeQuietly(Closeable c) {
311         if (c == null) return;
312         try {
313             c.close();
314         } catch (Throwable t) {}
315     }
316 
317     public static int predict(Model model, FeatureNode[] x) {
318         double[] dec_values = new double[model.nr_class];
319         return predictValues(model, x, dec_values);
320     }
321 
322     /**
323      * @throws IllegalArgumentException if model is not probabilistic (see {@link Model#isProbabilityModel()})
324      */
325     public static int predictProbability(Model model, FeatureNode[] x, double[] prob_estimates) throws IllegalArgumentException {
326         if (!model.isProbabilityModel()) {
327             throw new IllegalArgumentException("probability output is only supported for logistic regression");
328         }
329         int nr_class = model.nr_class;
330         int nr_w;
331         if (nr_class == 2)
332             nr_w = 1;
333         else
334             nr_w = nr_class;
335 
336         int label = predictValues(model, x, prob_estimates);
337         for (int i = 0; i < nr_w; i++)
338             prob_estimates[i] = 1 / (1 + Math.exp(-prob_estimates[i]));
339 
340         if (nr_class == 2) // for binary classification
341             prob_estimates[1] = 1. - prob_estimates[0];
342         else {
343             double sum = 0;
344             for (int i = 0; i < nr_class; i++)
345                 sum += prob_estimates[i];
346 
347             for (int i = 0; i < nr_class; i++)
348                 prob_estimates[i] = prob_estimates[i] / sum;
349         }
350 
351         return label;
352     }
353 
354     public static int predictValues(Model model, FeatureNode[] x, double[] dec_values) {
355         int n;
356         if (model.bias >= 0)
357             n = model.nr_feature + 1;
358         else
359             n = model.nr_feature;
360 
361         double[] w = model.w;
362 
363         int nr_w;
364         if (model.nr_class == 2 && model.solverType != SolverType.MCSVM_CS)
365             nr_w = 1;
366         else
367             nr_w = model.nr_class;
368 
369         for (int i = 0; i < nr_w; i++)
370             dec_values[i] = 0;
371 
372         for (FeatureNode lx : x) {
373             int idx = lx.index;
374             // the dimension of testing data may exceed that of training
375             if (idx <= n) {
376                 for (int i = 0; i < nr_w; i++) {
377                     dec_values[i] += w[(idx - 1) * nr_w + i] * lx.value;
378                 }
379             }
380         }
381 
382         if (model.nr_class == 2)
383             return (dec_values[0] > 0) ? model.label[0] : model.label[1];
384         else {
385             int dec_max_idx = 0;
386             for (int i = 1; i < model.nr_class; i++) {
387                 if (dec_values[i] > dec_values[dec_max_idx]) dec_max_idx = i;
388             }
389             return model.label[dec_max_idx];
390         }
391     }
392 
393 
394     static void printf(Formatter formatter, String format, Object... args) throws IOException {
395         formatter.format(format, args);
396         IOException ioException = formatter.ioException();
397         if (ioException != null) throw ioException;
398     }
399 
400     /**
401      * Writes the model to the modelOutput.
402      * It uses {@link java.util.Locale#ENGLISH} for number formatting.
403      *
404      * <p><b>Note: The modelOutput is closed after reading or in case of an exception.</b></p>
405      */
406     public static void saveModel(Writer modelOutput, Model model) throws IOException {
407         int nr_feature = model.nr_feature;
408         int w_size = nr_feature;
409         if (model.bias >= 0) w_size++;
410 
411         int nr_w = model.nr_class;
412         if (model.nr_class == 2 && model.solverType != SolverType.MCSVM_CS) nr_w = 1;
413 
414         Formatter formatter = new Formatter(modelOutput, DEFAULT_LOCALE);
415         try {
416             printf(formatter, "solver_type %s\n", model.solverType.name());
417             printf(formatter, "nr_class %d\n", model.nr_class);
418 
419             printf(formatter, "label");
420             for (int i = 0; i < model.nr_class; i++) {
421                 printf(formatter, " %d", model.label[i]);
422             }
423             printf(formatter, "\n");
424 
425             printf(formatter, "nr_feature %d\n", nr_feature);
426             printf(formatter, "bias %.16g\n", model.bias);
427 
428             printf(formatter, "w\n");
429             for (int i = 0; i < w_size; i++) {
430                 for (int j = 0; j < nr_w; j++) {
431                     double value = model.w[i * nr_w + j];
432 
433                     /** this optimization is the reason for {@link Model#equals(double[], double[])} */
434                     if (value == 0.0) {
435                         printf(formatter, "%d ", 0);
436                     } else {
437                         printf(formatter, "%.16g ", value);
438                     }
439                 }
440                 printf(formatter, "\n");
441             }
442 
443             formatter.flush();
444             IOException ioException = formatter.ioException();
445             if (ioException != null) throw ioException;
446         }
447         finally {
448             formatter.close();
449         }
450     }
451 
452     /**
453      * Writes the model to the file with ISO-8859-1 charset.
454      * It uses {@link java.util.Locale#ENGLISH} for number formatting.
455      */
456     public static void saveModel(File modelFile, Model model) throws IOException {
457         BufferedWriter modelOutput = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(modelFile), FILE_CHARSET));
458         saveModel(modelOutput, model);
459     }
460 
461     /*
462      * this method corresponds to the following define in the C version:
463      * #define GETI(i) (y[i]+1)
464      */
465     private static int GETI(byte[] y, int i) {
466         return y[i] + 1;
467     }
468 
469     /**
470      * A coordinate descent algorithm for
471      * L1-loss and L2-loss SVM dual problems
472      *<pre>
473      *  min_\alpha  0.5(\alpha^T (Q + D)\alpha) - e^T \alpha,
474      *    s.t.      0 <= alpha_i <= upper_bound_i,
475      *
476      *  where Qij = yi yj xi^T xj and
477      *  D is a diagonal matrix
478      *
479      * In L1-SVM case:
480      *     upper_bound_i = Cp if y_i = 1
481      *      upper_bound_i = Cn if y_i = -1
482      *      D_ii = 0
483      * In L2-SVM case:
484      *      upper_bound_i = INF
485      *      D_ii = 1/(2*Cp) if y_i = 1
486      *      D_ii = 1/(2*Cn) if y_i = -1
487      *
488      * Given:
489      * x, y, Cp, Cn
490      * eps is the stopping tolerance
491      *
492      * solution will be put in w
493      *
494      * See Algorithm 3 of Hsieh et al., ICML 2008
495      *</pre>
496      */
    private static void solve_l2r_l1l2_svc(Problem prob, double[] w, double eps, double Cp, double Cn, SolverType solver_type) {
        int l = prob.l;
        int w_size = prob.n;
        int i, s, iter = 0;
        double C, d, G;
        // QD[i]: diagonal term of instance i, i.e. diag entry plus x_i^T x_i (filled below)
        double[] QD = new double[l];
        int max_iter = 1000;
        // index[0..active_size) holds the instances that are still being optimized
        int[] index = new int[l];
        double[] alpha = new double[l];
        // labels mapped to -1 / +1
        byte[] y = new byte[l];
        int active_size = l;

        // PG: projected gradient, for shrinking and stopping
        double PG;
        double PGmax_old = Double.POSITIVE_INFINITY;
        double PGmin_old = Double.NEGATIVE_INFINITY;
        double PGmax_new, PGmin_new;

        // default solver_type: L2R_L2LOSS_SVC_DUAL
        // both tables are indexed with GETI(y, i) = y[i] + 1: slot 0 is used for y = -1,
        // slot 2 for y = +1, slot 1 is never read
        double diag[] = new double[] {0.5 / Cn, 0, 0.5 / Cp};
        double upper_bound[] = new double[] {Double.POSITIVE_INFINITY, 0, Double.POSITIVE_INFINITY};
        if (solver_type == SolverType.L2R_L1LOSS_SVC_DUAL) {
            diag[0] = 0;
            diag[2] = 0;
            upper_bound[0] = Cn;
            upper_bound[2] = Cp;
        }

        // start from w = 0 and alpha = 0
        for (i = 0; i < w_size; i++)
            w[i] = 0;
        for (i = 0; i < l; i++) {
            alpha[i] = 0;
            // every label > 0 counts as the positive class, everything else as negative
            if (prob.y[i] > 0) {
                y[i] = +1;
            } else {
                y[i] = -1;
            }
            QD[i] = diag[GETI(y, i)];

            for (FeatureNode xi : prob.x[i]) {
                QD[i] += xi.value * xi.value;
            }
            index[i] = i;
        }

        while (iter < max_iter) {
            PGmax_new = Double.NEGATIVE_INFINITY;
            PGmin_new = Double.POSITIVE_INFINITY;

            // visit the active instances in a fresh random order each outer iteration
            for (i = 0; i < active_size; i++) {
                int j = i + random.nextInt(active_size - i);
                swap(index, i, j);
            }

            for (s = 0; s < active_size; s++) {
                i = index[s];
                G = 0;
                byte yi = y[i];

                // G = y_i * w^T x_i - 1 + alpha_i * D_ii
                for (FeatureNode xi : prob.x[i]) {
                    G += w[xi.index - 1] * xi.value;
                }
                G = G * yi - 1;

                C = upper_bound[GETI(y, i)];
                G += alpha[i] * diag[GETI(y, i)];

                PG = 0;
                if (alpha[i] == 0) {
                    if (G > PGmax_old) {
                        // shrink: take the instance out of the active range and
                        // re-examine position s, which now holds a different instance
                        active_size--;
                        swap(index, s, active_size);
                        s--;
                        continue;
                    } else if (G < 0) {
                        PG = G;
                    }
                } else if (alpha[i] == C) {
                    if (G < PGmin_old) {
                        // shrink at the upper bound, same bookkeeping as above
                        active_size--;
                        swap(index, s, active_size);
                        s--;
                        continue;
                    } else if (G > 0) {
                        PG = G;
                    }
                } else {
                    PG = G;
                }

                PGmax_new = Math.max(PGmax_new, PG);
                PGmin_new = Math.min(PGmin_new, PG);

                if (Math.abs(PG) > 1.0e-12) {
                    // one-variable update of alpha_i, clipped to [0, C]; w is kept in sync
                    double alpha_old = alpha[i];
                    alpha[i] = Math.min(Math.max(alpha[i] - G / QD[i], 0.0), C);
                    d = (alpha[i] - alpha_old) * yi;

                    for (FeatureNode xi : prob.x[i]) {
                        w[xi.index - 1] += d * xi.value;
                    }
                }
            }

            iter++;
            if (iter % 10 == 0) info(".");

            if (PGmax_new - PGmin_new <= eps) {
                // converged on the full set: done; converged on a shrunken set:
                // reactivate all instances and run again with the thresholds reset
                if (active_size == l)
                    break;
                else {
                    active_size = l;
                    info("*");
                    PGmax_old = Double.POSITIVE_INFINITY;
                    PGmin_old = Double.NEGATIVE_INFINITY;
                    continue;
                }
            }
            // shrinking thresholds for the next pass; a non-positive max or a
            // non-negative min disables shrinking on that side
            PGmax_old = PGmax_new;
            PGmin_old = PGmin_new;
            if (PGmax_old <= 0) PGmax_old = Double.POSITIVE_INFINITY;
            if (PGmin_old >= 0) PGmin_old = Double.NEGATIVE_INFINITY;
        }

        info(NL + "optimization finished, #iter = %d" + NL, iter);
        if (iter >= max_iter) info("%nWARNING: reaching max number of iterations%nUsing -s 2 may be faster (also see FAQ)%n%n");

        // calculate objective value

        // v = w^T w + sum_i alpha_i * (alpha_i * D_ii - 2); the reported objective is v / 2
        double v = 0;
        // nSV counts the instances with alpha_i > 0
        int nSV = 0;
        for (i = 0; i < w_size; i++)
            v += w[i] * w[i];
        for (i = 0; i < l; i++) {
            v += alpha[i] * (alpha[i] * diag[GETI(y, i)] - 2);
            if (alpha[i] > 0) ++nSV;
        }
        info("Objective value = %f" + NL, v / 2);
        info("nSV = %d" + NL, nSV);
    }
637 
638     /**
639      * A coordinate descent algorithm for
640      * the dual of L2-regularized logistic regression problems
641      *<pre>
642      *  min_\alpha  0.5(\alpha^T Q \alpha) + \sum \alpha_i log (\alpha_i) + (upper_bound_i - alpha_i) log (upper_bound_i - alpha_i) ,
643      *     s.t.      0 <= alpha_i <= upper_bound_i,
644      *
645      *  where Qij = yi yj xi^T xj and
646      *  upper_bound_i = Cp if y_i = 1
647      *  upper_bound_i = Cn if y_i = -1
648      *
649      * Given:
650      * x, y, Cp, Cn
651      * eps is the stopping tolerance
652      *
653      * solution will be put in w
654      *
655      * See Algorithm 5 of Yu et al., MLJ 2010
656      *</pre>
657      *
658      * @since 1.7
659      */
    private static void solve_l2r_lr_dual(Problem prob, double w[], double eps, double Cp, double Cn) {
        int l = prob.l;
        int w_size = prob.n;
        int i, s, iter = 0;
        // xTx[i] = x_i^T x_i (filled below)
        double xTx[] = new double[l];
        int max_iter = 1000;
        int index[] = new int[l];
        double alpha[] = new double[2 * l]; // store alpha and C - alpha
        // labels mapped to -1 / +1
        byte y[] = new byte[l];
        int max_inner_iter = 100; // for inner Newton
        // tolerance of the inner Newton loop; tightened during the run, never below innereps_min
        double innereps = 1e-2;
        double innereps_min = Math.min(1e-8, eps);
        // indexed with GETI(y, i) = y[i] + 1: slot 0 for y = -1, slot 2 for y = +1
        double upper_bound[] = new double[] {Cn, 0, Cp};

        for (i = 0; i < w_size; i++)
            w[i] = 0;
        for (i = 0; i < l; i++) {
            // every label > 0 counts as the positive class, everything else as negative
            if (prob.y[i] > 0) {
                y[i] = +1;
            } else {
                y[i] = -1;
            }
            // small positive starting value for alpha_i; the second slot holds C - alpha_i
            alpha[2 * i] = Math.min(0.001 * upper_bound[GETI(y, i)], 1e-8);
            alpha[2 * i + 1] = upper_bound[GETI(y, i)] - alpha[2 * i];

            xTx[i] = 0;
            for (FeatureNode xi : prob.x[i]) {
                xTx[i] += (xi.value) * (xi.value);
                // w accumulates sum_i y_i * alpha_i * x_i for the starting alpha
                w[xi.index - 1] += y[i] * alpha[2 * i] * xi.value;
            }
            index[i] = i;
        }

        while (iter < max_iter) {
            // visit all instances in a fresh random order each outer iteration
            for (i = 0; i < l; i++) {
                int j = i + random.nextInt(l - i);
                swap(index, i, j);
            }
            // newton_iter: Newton steps summed over all instances of this pass
            int newton_iter = 0;
            // Gmax: largest absolute starting gradient seen in this pass
            double Gmax = 0;
            for (s = 0; s < l; s++) {
                i = index[s];
                byte yi = y[i];
                double C = upper_bound[GETI(y, i)];
                double ywTx = 0, xisq = xTx[i];
                for (FeatureNode xi : prob.x[i]) {
                    ywTx += w[xi.index - 1] * xi.value;
                }
                ywTx *= y[i];
                double a = xisq, b = ywTx;

                // Decide to minimize g_1(z) or g_2(z)
                // ind1 is the slot that gets optimized, ind2 its complement (C - z)
                int ind1 = 2 * i, ind2 = 2 * i + 1, sign = 1;
                if (0.5 * a * (alpha[ind2] - alpha[ind1]) + b < 0) {
                    ind1 = 2 * i + 1;
                    ind2 = 2 * i;
                    sign = -1;
                }

                //  g_t(z) = z*log(z) + (C-z)*log(C-z) + 0.5a(z-alpha_old)^2 + sign*b(z-alpha_old)
                double alpha_old = alpha[ind1];
                double z = alpha_old;
                // when z lies in the upper half of (0, C), start from a tenth of it instead
                if (C - z < 0.5 * C) z = 0.1 * z;
                // gp: first derivative of g_t at z
                double gp = a * (z - alpha_old) + sign * b + Math.log(z / (C - z));
                Gmax = Math.max(Gmax, Math.abs(gp));

                // Newton method on the sub-problem
                final double eta = 0.1; // xi in the paper
                int inner_iter = 0;
                while (inner_iter <= max_inner_iter) {
                    if (Math.abs(gp) < innereps) break;
                    // gpp: second derivative of g_t at z
                    double gpp = a + C / (C - z) / z;
                    double tmpz = z - gp / gpp;
                    // a step that would leave the positive range is replaced by shrinking z
                    if (tmpz <= 0)
                        z *= eta;
                    else
                        // tmpz in (0, C)
                        z = tmpz;
                    gp = a * (z - alpha_old) + sign * b + Math.log(z / (C - z));
                    newton_iter++;
                    inner_iter++;
                }

                if (inner_iter > 0) // update w
                {
                    alpha[ind1] = z;
                    alpha[ind2] = C - z;
                    for (FeatureNode xi : prob.x[i]) {
                        w[xi.index - 1] += sign * (z - alpha_old) * yi * xi.value;
                    }
                }
            }

            iter++;
            if (iter % 10 == 0) info(".");

            if (Gmax < eps) break;

            // few Newton steps were needed in this pass: tighten the inner tolerance
            if (newton_iter <= l / 10) {
                innereps = Math.max(innereps_min, 0.1 * innereps);
            }

        }

        info("%noptimization finished, #iter = %d%n", iter);
        if (iter >= max_iter) info("%nWARNING: reaching max number of iterations%nUsing -s 0 may be faster (also see FAQ)%n%n");

        // calculate objective value

        // v = 0.5 * w^T w + sum_i [ alpha_i log(alpha_i) + (C_i - alpha_i) log(C_i - alpha_i) - C_i log(C_i) ]
        double v = 0;
        for (i = 0; i < w_size; i++)
            v += w[i] * w[i];
        v *= 0.5;
        for (i = 0; i < l; i++)
            v += alpha[2 * i] * Math.log(alpha[2 * i]) + alpha[2 * i + 1] * Math.log(alpha[2 * i + 1]) - upper_bound[GETI(y, i)]
                * Math.log(upper_bound[GETI(y, i)]);
        info("Objective value = %f%n", v);
    }
778 
779     /**
780      * A coordinate descent algorithm for
781      * L1-regularized L2-loss support vector classification
782      *
783      *<pre>
784      *  min_w \sum |wj| + C \sum max(0, 1-yi w^T xi)^2,
785      *
786      * Given:
787      * x, y, Cp, Cn
788      * eps is the stopping tolerance
789      *
790      * solution will be put in w
791      *
792      * See Yuan et al. (2010) and appendix of LIBLINEAR paper, Fan et al. (2008)
793      *</pre>
794      *
795      * @since 1.5
796      */
797     private static void solve_l1r_l2_svc(Problem prob_col, double[] w, double eps, double Cp, double Cn) {
798         int l = prob_col.l;
799         int w_size = prob_col.n;
800         int j, s, iter = 0;
801         int max_iter = 1000;
802         int active_size = w_size;
803         int max_num_linesearch = 20;
804 
805         double sigma = 0.01;
806         double d, G_loss, G, H;
807         double Gmax_old = Double.POSITIVE_INFINITY;
808         double Gmax_new, Gnorm1_new;
809         double Gnorm1_init = 0; // eclipse moans this variable might not be initialized
810         double d_old, d_diff;
811         double loss_old = 0; // eclipse moans this variable might not be initialized
812         double loss_new;
813         double appxcond, cond;
814 
815         int[] index = new int[w_size];
816         byte[] y = new byte[l];
817         double[] b = new double[l]; // b = 1-ywTx
818         double[] xj_sq = new double[w_size];
819 
820         double[] C = new double[] {Cn, 0, Cp};
821 
822         for (j = 0; j < l; j++) {
823             b[j] = 1;
824             if (prob_col.y[j] > 0)
825                 y[j] = 1;
826             else
827                 y[j] = -1;
828         }
829         for (j = 0; j < w_size; j++) {
830             w[j] = 0;
831             index[j] = j;
832             xj_sq[j] = 0;
833             for (FeatureNode xi : prob_col.x[j]) {
834                 int ind = xi.index - 1;
835                 double val = xi.value;
836                 xi.value *= y[ind]; // x->value stores yi*xij
837                 xj_sq[j] += C[GETI(y, ind)] * val * val;
838             }
839         }
840 
841         while (iter < max_iter) {
842             Gmax_new = 0;
843             Gnorm1_new = 0;
844 
845             for (j = 0; j < active_size; j++) {
846                 int i = j + random.nextInt(active_size - j);
847                 swap(index, i, j);
848             }
849 
850             for (s = 0; s < active_size; s++) {
851                 j = index[s];
852                 G_loss = 0;
853                 H = 0;
854 
855                 for (FeatureNode xi : prob_col.x[j]) {
856                     int ind = xi.index - 1;
857                     if (b[ind] > 0) {
858                         double val = xi.value;
859                         double tmp = C[GETI(y, ind)] * val;
860                         G_loss -= tmp * b[ind];
861                         H += tmp * val;
862                     }
863                 }
864                 G_loss *= 2;
865 
866                 G = G_loss;
867                 H *= 2;
868                 H = Math.max(H, 1e-12);
869 
870                 double Gp = G + 1;
871                 double Gn = G - 1;
872                 double violation = 0;
873                 if (w[j] == 0) {
874                     if (Gp < 0)
875                         violation = -Gp;
876                     else if (Gn > 0)
877                         violation = Gn;
878                     else if (Gp > Gmax_old / l && Gn < -Gmax_old / l) {
879                         active_size--;
880                         swap(index, s, active_size);
881                         s--;
882                         continue;
883                     }
884                 } else if (w[j] > 0)
885                     violation = Math.abs(Gp);
886                 else
887                     violation = Math.abs(Gn);
888 
889                 Gmax_new = Math.max(Gmax_new, violation);
890                 Gnorm1_new += violation;
891 
892                 // obtain Newton direction d
893                 if (Gp <= H * w[j])
894                     d = -Gp / H;
895                 else if (Gn >= H * w[j])
896                     d = -Gn / H;
897                 else
898                     d = -w[j];
899 
900                 if (Math.abs(d) < 1.0e-12) continue;
901 
902                 double delta = Math.abs(w[j] + d) - Math.abs(w[j]) + G * d;
903                 d_old = 0;
904                 int num_linesearch;
905                 for (num_linesearch = 0; num_linesearch < max_num_linesearch; num_linesearch++) {
906                     d_diff = d_old - d;
907                     cond = Math.abs(w[j] + d) - Math.abs(w[j]) - sigma * delta;
908 
909                     appxcond = xj_sq[j] * d * d + G_loss * d + cond;
910                     if (appxcond <= 0) {
911                         for (FeatureNode x : prob_col.x[j]) {
912                             b[x.index - 1] += d_diff * x.value;
913                         }
914                         break;
915                     }
916 
917                     if (num_linesearch == 0) {
918                         loss_old = 0;
919                         loss_new = 0;
920                         for (FeatureNode x : prob_col.x[j]) {
921                             int ind = x.index - 1;
922                             if (b[ind] > 0) {
923                                 loss_old += C[GETI(y, ind)] * b[ind] * b[ind];
924                             }
925                             double b_new = b[ind] + d_diff * x.value;
926                             b[ind] = b_new;
927                             if (b_new > 0) {
928                                 loss_new += C[GETI(y, ind)] * b_new * b_new;
929                             }
930                         }
931                     } else {
932                         loss_new = 0;
933                         for (FeatureNode x : prob_col.x[j]) {
934                             int ind = x.index - 1;
935                             double b_new = b[ind] + d_diff * x.value;
936                             b[ind] = b_new;
937                             if (b_new > 0) {
938                                 loss_new += C[GETI(y, ind)] * b_new * b_new;
939                             }
940                         }
941                     }
942 
943                     cond = cond + loss_new - loss_old;
944                     if (cond <= 0)
945                         break;
946                     else {
947                         d_old = d;
948                         d *= 0.5;
949                         delta *= 0.5;
950                     }
951                 }
952 
953                 w[j] += d;
954 
955                 // recompute b[] if line search takes too many steps
956                 if (num_linesearch >= max_num_linesearch) {
957                     info("#");
958                     for (int i = 0; i < l; i++)
959                         b[i] = 1;
960 
961                     for (int i = 0; i < w_size; i++) {
962                         if (w[i] == 0) continue;
963                         for (FeatureNode x : prob_col.x[i]) {
964                             b[x.index - 1] -= w[i] * x.value;
965                         }
966                     }
967                 }
968             }
969 
970             if (iter == 0) {
971                 Gnorm1_init = Gnorm1_new;
972             }
973             iter++;
974             if (iter % 10 == 0) info(".");
975 
976             if (Gmax_new <= eps * Gnorm1_init) {
977                 if (active_size == w_size)
978                     break;
979                 else {
980                     active_size = w_size;
981                     info("*");
982                     Gmax_old = Double.POSITIVE_INFINITY;
983                     continue;
984                 }
985             }
986 
987             Gmax_old = Gmax_new;
988         }
989 
990         info("%noptimization finished, #iter = %d%n", iter);
991         if (iter >= max_iter) info("%nWARNING: reaching max number of iterations%n");
992 
993         // calculate objective value
994 
995         double v = 0;
996         int nnz = 0;
997         for (j = 0; j < w_size; j++) {
998             for (FeatureNode x : prob_col.x[j]) {
999                 x.value *= prob_col.y[x.index - 1]; // restore x->value
1000             }
1001             if (w[j] != 0) {
1002                 v += Math.abs(w[j]);
1003                 nnz++;
1004             }
1005         }
1006         for (j = 0; j < l; j++)
1007             if (b[j] > 0) v += C[GETI(y, j)] * b[j] * b[j];
1008 
1009         info("Objective value = %f%n", v);
1010         info("#nonzeros/#features = %d/%d%n", nnz, w_size);
1011     }
1012 
1013     /**
1014      * A coordinate descent algorithm for
1015      * L1-regularized logistic regression problems
1016      *
1017      *<pre>
1018      *  min_w \sum |wj| + C \sum log(1+exp(-yi w^T xi)),
1019      *
1020      * Given:
1021      * x, y, Cp, Cn
1022      * eps is the stopping tolerance
1023      *
1024      * solution will be put in w
1025      *
1026      * See Yuan et al. (2011) and appendix of LIBLINEAR paper, Fan et al. (2008)
1027      *</pre>
1028      *
1029      * @since 1.5
1030      */
1031     private static void solve_l1r_lr(Problem prob_col, double[] w, double eps, double Cp, double Cn) {
1032         int l = prob_col.l;
1033         int w_size = prob_col.n;
1034         int j, s, newton_iter = 0, iter = 0;
1035         int max_newton_iter = 100;
1036         int max_iter = 1000;
1037         int max_num_linesearch = 20;
1038         int active_size;
1039         int QP_active_size;
1040 
1041         double nu = 1e-12;
1042         double inner_eps = 1;
1043         double sigma = 0.01;
1044         double w_norm = 0, w_norm_new;
1045         double z, G, H;
1046         double Gnorm1_init = 0; // eclipse moans this variable might not be initialized
1047         double Gmax_old = Double.POSITIVE_INFINITY;
1048         double Gmax_new, Gnorm1_new;
1049         double QP_Gmax_old = Double.POSITIVE_INFINITY;
1050         double QP_Gmax_new, QP_Gnorm1_new;
1051         double delta, negsum_xTd, cond;
1052 
1053         int[] index = new int[w_size];
1054         byte[] y = new byte[l];
1055         double[] Hdiag = new double[w_size];
1056         double[] Grad = new double[w_size];
1057         double[] wpd = new double[w_size];
1058         double[] xjneg_sum = new double[w_size];
1059         double[] xTd = new double[l];
1060         double[] exp_wTx = new double[l];
1061         double[] exp_wTx_new = new double[l];
1062         double[] tau = new double[l];
1063         double[] D = new double[l];
1064 
1065         double[] C = {Cn, 0, Cp};
1066 
1067         for (j = 0; j < l; j++) {
1068             if (prob_col.y[j] > 0)
1069                 y[j] = 1;
1070             else
1071                 y[j] = -1;
1072 
1073             // assume initial w is 0
1074             exp_wTx[j] = 1;
1075             tau[j] = C[GETI(y, j)] * 0.5;
1076             D[j] = C[GETI(y, j)] * 0.25;
1077         }
1078         for (j = 0; j < w_size; j++) {
1079             w[j] = 0;
1080             wpd[j] = w[j];
1081             index[j] = j;
1082             xjneg_sum[j] = 0;
1083             for (FeatureNode x : prob_col.x[j]) {
1084                 int ind = x.index - 1;
1085                 if (y[ind] == -1) xjneg_sum[j] += C[GETI(y, ind)] * x.value;
1086             }
1087         }
1088 
1089         while (newton_iter < max_newton_iter) {
1090             Gmax_new = 0;
1091             Gnorm1_new = 0;
1092             active_size = w_size;
1093 
1094             for (s = 0; s < active_size; s++) {
1095                 j = index[s];
1096                 Hdiag[j] = nu;
1097                 Grad[j] = 0;
1098 
1099                 double tmp = 0;
1100                 for (FeatureNode x : prob_col.x[j]) {
1101                     int ind = x.index - 1;
1102                     Hdiag[j] += x.value * x.value * D[ind];
1103                     tmp += x.value * tau[ind];
1104                 }
1105                 Grad[j] = -tmp + xjneg_sum[j];
1106 
1107                 double Gp = Grad[j] + 1;
1108                 double Gn = Grad[j] - 1;
1109                 double violation = 0;
1110                 if (w[j] == 0) {
1111                     if (Gp < 0)
1112                         violation = -Gp;
1113                     else if (Gn > 0)
1114                         violation = Gn;
1115                     //outer-level shrinking
1116                     else if (Gp > Gmax_old / l && Gn < -Gmax_old / l) {
1117                         active_size--;
1118                         swap(index, s, active_size);
1119                         s--;
1120                         continue;
1121                     }
1122                 } else if (w[j] > 0)
1123                     violation = Math.abs(Gp);
1124                 else
1125                     violation = Math.abs(Gn);
1126 
1127                 Gmax_new = Math.max(Gmax_new, violation);
1128                 Gnorm1_new += violation;
1129             }
1130 
1131             if (newton_iter == 0) Gnorm1_init = Gnorm1_new;
1132 
1133             if (Gnorm1_new <= eps * Gnorm1_init) break;
1134 
1135             iter = 0;
1136             QP_Gmax_old = Double.POSITIVE_INFINITY;
1137             QP_active_size = active_size;
1138 
1139             for (int i = 0; i < l; i++)
1140                 xTd[i] = 0;
1141 
1142             // optimize QP over wpd
1143             while (iter < max_iter) {
1144                 QP_Gmax_new = 0;
1145                 QP_Gnorm1_new = 0;
1146 
1147                 for (j = 0; j < QP_active_size; j++) {
1148                     int i = random.nextInt(QP_active_size - j);
1149                     swap(index, i, j);
1150                 }
1151 
1152                 for (s = 0; s < QP_active_size; s++) {
1153                     j = index[s];
1154                     H = Hdiag[j];
1155 
1156                     G = Grad[j] + (wpd[j] - w[j]) * nu;
1157                     for (FeatureNode x : prob_col.x[j]) {
1158                         int ind = x.index - 1;
1159                         G += x.value * D[ind] * xTd[ind];
1160                     }
1161 
1162                     double Gp = G + 1;
1163                     double Gn = G - 1;
1164                     double violation = 0;
1165                     if (wpd[j] == 0) {
1166                         if (Gp < 0)
1167                             violation = -Gp;
1168                         else if (Gn > 0)
1169                             violation = Gn;
1170                         //inner-level shrinking
1171                         else if (Gp > QP_Gmax_old / l && Gn < -QP_Gmax_old / l) {
1172                             QP_active_size--;
1173                             swap(index, s, QP_active_size);
1174                             s--;
1175                             continue;
1176                         }
1177                     } else if (wpd[j] > 0)
1178                         violation = Math.abs(Gp);
1179                     else
1180                         violation = Math.abs(Gn);
1181 
1182                     QP_Gmax_new = Math.max(QP_Gmax_new, violation);
1183                     QP_Gnorm1_new += violation;
1184 
1185                     // obtain solution of one-variable problem
1186                     if (Gp <= H * wpd[j])
1187                         z = -Gp / H;
1188                     else if (Gn >= H * wpd[j])
1189                         z = -Gn / H;
1190                     else
1191                         z = -wpd[j];
1192 
1193                     if (Math.abs(z) < 1.0e-12) continue;
1194                     z = Math.min(Math.max(z, -10.0), 10.0);
1195 
1196                     wpd[j] += z;
1197 
1198                     for (FeatureNode x : prob_col.x[j]) {
1199                         int ind = x.index - 1;
1200                         xTd[ind] += x.value * z;
1201                     }
1202                 }
1203 
1204                 iter++;
1205 
1206                 if (QP_Gnorm1_new <= inner_eps * Gnorm1_init) {
1207                     //inner stopping
1208                     if (QP_active_size == active_size)
1209                         break;
1210                     //active set reactivation
1211                     else {
1212                         QP_active_size = active_size;
1213                         QP_Gmax_old = Double.POSITIVE_INFINITY;
1214                         continue;
1215                     }
1216                 }
1217 
1218                 QP_Gmax_old = QP_Gmax_new;
1219             }
1220 
1221             if (iter >= max_iter) info("WARNING: reaching max number of inner iterations\n");
1222 
1223             delta = 0;
1224             w_norm_new = 0;
1225             for (j = 0; j < w_size; j++) {
1226                 delta += Grad[j] * (wpd[j] - w[j]);
1227                 if (wpd[j] != 0) w_norm_new += Math.abs(wpd[j]);
1228             }
1229             delta += (w_norm_new - w_norm);
1230 
1231             negsum_xTd = 0;
1232             for (int i = 0; i < l; i++)
1233                 if (y[i] == -1) negsum_xTd += C[GETI(y, i)] * xTd[i];
1234 
1235             int num_linesearch;
1236             for (num_linesearch = 0; num_linesearch < max_num_linesearch; num_linesearch++) {
1237                 cond = w_norm_new - w_norm + negsum_xTd - sigma * delta;
1238 
1239                 for (int i = 0; i < l; i++) {
1240                     double exp_xTd = Math.exp(xTd[i]);
1241                     exp_wTx_new[i] = exp_wTx[i] * exp_xTd;
1242                     cond += C[GETI(y, i)] * Math.log((1 + exp_wTx_new[i]) / (exp_xTd + exp_wTx_new[i]));
1243                 }
1244 
1245                 if (cond <= 0) {
1246                     w_norm = w_norm_new;
1247                     for (j = 0; j < w_size; j++)
1248                         w[j] = wpd[j];
1249                     for (int i = 0; i < l; i++) {
1250                         exp_wTx[i] = exp_wTx_new[i];
1251                         double tau_tmp = 1 / (1 + exp_wTx[i]);
1252                         tau[i] = C[GETI(y, i)] * tau_tmp;
1253                         D[i] = C[GETI(y, i)] * exp_wTx[i] * tau_tmp * tau_tmp;
1254                     }
1255                     break;
1256                 } else {
1257                     w_norm_new = 0;
1258                     for (j = 0; j < w_size; j++) {
1259                         wpd[j] = (w[j] + wpd[j]) * 0.5;
1260                         if (wpd[j] != 0) w_norm_new += Math.abs(wpd[j]);
1261                     }
1262                     delta *= 0.5;
1263                     negsum_xTd *= 0.5;
1264                     for (int i = 0; i < l; i++)
1265                         xTd[i] *= 0.5;
1266                 }
1267             }
1268 
1269             // Recompute some info due to too many line search steps
1270             if (num_linesearch >= max_num_linesearch) {
1271                 for (int i = 0; i < l; i++)
1272                     exp_wTx[i] = 0;
1273 
1274                 for (int i = 0; i < w_size; i++) {
1275                     if (w[i] == 0) continue;
1276                     for (FeatureNode x : prob_col.x[i]) {
1277                         exp_wTx[x.index - 1] += w[i] * x.value;
1278                     }
1279                 }
1280 
1281                 for (int i = 0; i < l; i++)
1282                     exp_wTx[i] = Math.exp(exp_wTx[i]);
1283             }
1284 
1285             if (iter == 1) inner_eps *= 0.25;
1286 
1287             newton_iter++;
1288             Gmax_old = Gmax_new;
1289 
1290             info("iter %3d  #CD cycles %d%n", newton_iter, iter);
1291         }
1292 
1293         info("=========================%n");
1294         info("optimization finished, #iter = %d%n", newton_iter);
1295         if (newton_iter >= max_newton_iter) info("WARNING: reaching max number of iterations%n");
1296 
1297         // calculate objective value
1298 
1299         double v = 0;
1300         int nnz = 0;
1301         for (j = 0; j < w_size; j++)
1302             if (w[j] != 0) {
1303                 v += Math.abs(w[j]);
1304                 nnz++;
1305             }
1306         for (j = 0; j < l; j++)
1307             if (y[j] == 1)
1308                 v += C[GETI(y, j)] * Math.log(1 + 1 / exp_wTx[j]);
1309             else
1310                 v += C[GETI(y, j)] * Math.log(1 + exp_wTx[j]);
1311 
1312         info("Objective value = %f%n", v);
1313         info("#nonzeros/#features = %d/%d%n", nnz, w_size);
1314     }
1315 
1316     // transpose matrix X from row format to column format
1317     static Problem transpose(Problem prob) {
1318         int l = prob.l;
1319         int n = prob.n;
1320         int[] col_ptr = new int[n + 1];
1321         Problem prob_col = new Problem();
1322         prob_col.l = l;
1323         prob_col.n = n;
1324         prob_col.y = new int[l];
1325         prob_col.x = new FeatureNode[n][];
1326 
1327         for (int i = 0; i < l; i++)
1328             prob_col.y[i] = prob.y[i];
1329 
1330         for (int i = 0; i < l; i++) {
1331             for (FeatureNode x : prob.x[i]) {
1332                 col_ptr[x.index]++;
1333             }
1334         }
1335 
1336         for (int i = 0; i < n; i++) {
1337             prob_col.x[i] = new FeatureNode[col_ptr[i + 1]];
1338             col_ptr[i] = 0; // reuse the array to count the nr of elements
1339         }
1340 
1341         for (int i = 0; i < l; i++) {
1342             for (int j = 0; j < prob.x[i].length; j++) {
1343                 FeatureNode x = prob.x[i][j];
1344                 int index = x.index - 1;
1345                 prob_col.x[index][col_ptr[index]] = new FeatureNode(i + 1, x.value);
1346                 col_ptr[index]++;
1347             }
1348         }
1349 
1350         return prob_col;
1351     }
1352 
1353     static void swap(double[] array, int idxA, int idxB) {
1354         double temp = array[idxA];
1355         array[idxA] = array[idxB];
1356         array[idxB] = temp;
1357     }
1358 
1359     static void swap(int[] array, int idxA, int idxB) {
1360         int temp = array[idxA];
1361         array[idxA] = array[idxB];
1362         array[idxB] = temp;
1363     }
1364 
1365     static void swap(IntArrayPointer array, int idxA, int idxB) {
1366         int temp = array.get(idxA);
1367         array.set(idxA, array.get(idxB));
1368         array.set(idxB, temp);
1369     }
1370 
1371     /**
1372      * @throws IllegalArgumentException if the feature nodes of prob are not sorted in ascending order
1373      */
1374     public static Model train(Problem prob, Parameter param) {
1375 
1376         if (prob == null) throw new IllegalArgumentException("problem must not be null");
1377         if (param == null) throw new IllegalArgumentException("parameter must not be null");
1378 
1379         for (FeatureNode[] nodes : prob.x) {
1380             int indexBefore = 0;
1381             for (FeatureNode n : nodes) {
1382                 if (n.index <= indexBefore) {
1383                     throw new IllegalArgumentException("feature nodes must be sorted by index in ascending order");
1384                 }
1385                 indexBefore = n.index;
1386             }
1387         }
1388 
1389         int i, j;
1390         int l = prob.l;
1391         int n = prob.n;
1392         int w_size = prob.n;
1393         Model model = new Model();
1394 
1395         if (prob.bias >= 0)
1396             model.nr_feature = n - 1;
1397         else
1398             model.nr_feature = n;
1399         model.solverType = param.solverType;
1400         model.bias = prob.bias;
1401 
1402         int[] perm = new int[l];
1403         // group training data of the same class
1404         GroupClassesReturn rv = groupClasses(prob, perm);
1405         int nr_class = rv.nr_class;
1406         int[] label = rv.label;
1407         int[] start = rv.start;
1408         int[] count = rv.count;
1409 
1410         model.nr_class = nr_class;
1411         model.label = new int[nr_class];
1412         for (i = 0; i < nr_class; i++)
1413             model.label[i] = label[i];
1414 
1415         // calculate weighted C
1416         double[] weighted_C = new double[nr_class];
1417         for (i = 0; i < nr_class; i++) {
1418             weighted_C[i] = param.C;
1419         }
1420 
1421         for (i = 0; i < param.getNumWeights(); i++) {
1422             for (j = 0; j < nr_class; j++)
1423                 if (param.weightLabel[i] == label[j]) break;
1424             if (j == nr_class) throw new IllegalArgumentException("class label " + param.weightLabel[i] + " specified in weight is not found");
1425 
1426             weighted_C[j] *= param.weight[i];
1427         }
1428 
1429         // constructing the subproblem
1430         FeatureNode[][] x = new FeatureNode[l][];
1431         for (i = 0; i < l; i++)
1432             x[i] = prob.x[perm[i]];
1433 
1434         int k;
1435         Problem sub_prob = new Problem();
1436         sub_prob.l = l;
1437         sub_prob.n = n;
1438         sub_prob.x = new FeatureNode[sub_prob.l][];
1439         sub_prob.y = new int[sub_prob.l];
1440 
1441         for (k = 0; k < sub_prob.l; k++)
1442             sub_prob.x[k] = x[k];
1443 
1444         // multi-class svm by Crammer and Singer
1445         if (param.solverType == SolverType.MCSVM_CS) {
1446             model.w = new double[n * nr_class];
1447             for (i = 0; i < nr_class; i++) {
1448                 for (j = start[i]; j < start[i] + count[i]; j++) {
1449                     sub_prob.y[j] = i;
1450                 }
1451             }
1452 
1453             SolverMCSVM_CS solver = new SolverMCSVM_CS(sub_prob, nr_class, weighted_C, param.eps);
1454             solver.solve(model.w);
1455         } else {
1456             if (nr_class == 2) {
1457                 model.w = new double[w_size];
1458 
1459                 int e0 = start[0] + count[0];
1460                 k = 0;
1461                 for (; k < e0; k++)
1462                     sub_prob.y[k] = +1;
1463                 for (; k < sub_prob.l; k++)
1464                     sub_prob.y[k] = -1;
1465 
1466                 train_one(sub_prob, param, model.w, weighted_C[0], weighted_C[1]);
1467             } else {
1468                 model.w = new double[w_size * nr_class];
1469                 double[] w = new double[w_size];
1470                 for (i = 0; i < nr_class; i++) {
1471                     int si = start[i];
1472                     int ei = si + count[i];
1473 
1474                     k = 0;
1475                     for (; k < si; k++)
1476                         sub_prob.y[k] = -1;
1477                     for (; k < ei; k++)
1478                         sub_prob.y[k] = +1;
1479                     for (; k < sub_prob.l; k++)
1480                         sub_prob.y[k] = -1;
1481 
1482                     train_one(sub_prob, param, w, weighted_C[i], param.C);
1483 
1484                     for (j = 0; j < n; j++)
1485                         model.w[j * nr_class + i] = w[j];
1486                 }
1487             }
1488 
1489         }
1490         return model;
1491     }
1492 
1493     private static void train_one(Problem prob, Parameter param, double[] w, double Cp, double Cn) {
1494         double eps = param.eps;
1495         int pos = 0;
1496         for (int i = 0; i < prob.l; i++)
1497             if (prob.y[i] == +1) pos++;
1498         int neg = prob.l - pos;
1499 
1500         Function fun_obj = null;
1501         switch (param.solverType) {
1502             case L2R_LR: {
1503                 fun_obj = new L2R_LrFunction(prob, Cp, Cn);
1504                 Tron tron_obj = new Tron(fun_obj, eps * Math.min(pos, neg) / prob.l);
1505                 tron_obj.tron(w);
1506                 break;
1507             }
1508             case L2R_L2LOSS_SVC: {
1509                 fun_obj = new L2R_L2_SvcFunction(prob, Cp, Cn);
1510                 Tron tron_obj = new Tron(fun_obj, eps * Math.min(pos, neg) / prob.l);
1511                 tron_obj.tron(w);
1512                 break;
1513             }
1514             case L2R_L2LOSS_SVC_DUAL:
1515                 solve_l2r_l1l2_svc(prob, w, eps, Cp, Cn, SolverType.L2R_L2LOSS_SVC_DUAL);
1516                 break;
1517             case L2R_L1LOSS_SVC_DUAL:
1518                 solve_l2r_l1l2_svc(prob, w, eps, Cp, Cn, SolverType.L2R_L1LOSS_SVC_DUAL);
1519                 break;
1520             case L1R_L2LOSS_SVC: {
1521                 Problem prob_col = transpose(prob);
1522                 solve_l1r_l2_svc(prob_col, w, eps * Math.min(pos, neg) / prob.l, Cp, Cn);
1523                 break;
1524             }
1525             case L1R_LR: {
1526                 Problem prob_col = transpose(prob);
1527                 solve_l1r_lr(prob_col, w, eps * Math.min(pos, neg) / prob.l, Cp, Cn);
1528                 break;
1529             }
1530             case L2R_LR_DUAL:
1531                 solve_l2r_lr_dual(prob, w, eps, Cp, Cn);
1532                 break;
1533             default:
1534                 throw new IllegalStateException("unknown solver type: " + param.solverType);
1535         }
1536     }
1537 
1538     public static void disableDebugOutput() {
1539         setDebugOutput(null);
1540     }
1541 
1542     public static void enableDebugOutput() {
1543         setDebugOutput(System.out);
1544     }
1545 
1546     public static void setDebugOutput(PrintStream debugOutput) {
1547         synchronized (OUTPUT_MUTEX) {
1548             DEBUG_OUTPUT = debugOutput;
1549         }
1550     }
1551 
1552     /**
1553      * resets the PRNG
1554      *
1555      * this is i.a. needed for regression testing (eg. the Weka wrapper)
1556      */
1557     public static void resetRandom() {
1558         random = new Random(DEFAULT_RANDOM_SEED);
1559     }
1560 }