1 package liblinear;
2
3 import java.io.BufferedReader;
4 import java.io.BufferedWriter;
5 import java.io.Closeable;
6 import java.io.EOFException;
7 import java.io.File;
8 import java.io.FileInputStream;
9 import java.io.FileOutputStream;
10 import java.io.IOException;
11 import java.io.InputStreamReader;
12 import java.io.OutputStreamWriter;
13 import java.io.PrintStream;
14 import java.io.Reader;
15 import java.io.Writer;
16 import java.nio.charset.Charset;
17 import java.util.Formatter;
18 import java.util.Locale;
19 import java.util.Random;
20 import java.util.regex.Pattern;
21
22
23
24
25
26
27
28
29
30
31
32
33 public class Linear {
34
35 static final Charset FILE_CHARSET = Charset.forName("ISO-8859-1");
36
37 static final Locale DEFAULT_LOCALE = Locale.ENGLISH;
38
39 private static Object OUTPUT_MUTEX = new Object();
40 private static PrintStream DEBUG_OUTPUT = System.out;
41
42
43 final static String NL = System.getProperty("line.separator");
44
45 private static final long DEFAULT_RANDOM_SEED = 0L;
46 static Random random = new Random(DEFAULT_RANDOM_SEED);
47
48
49
50
51 public static void crossValidation(Problem prob, Parameter param, int nr_fold, int[] target) {
52 int i;
53 int[] fold_start = new int[nr_fold + 1];
54 int l = prob.l;
55 int[] perm = new int[l];
56
57 for (i = 0; i < l; i++)
58 perm[i] = i;
59 for (i = 0; i < l; i++) {
60 int j = i + random.nextInt(l - i);
61 swap(perm, i, j);
62 }
63 for (i = 0; i <= nr_fold; i++)
64 fold_start[i] = i * l / nr_fold;
65
66 for (i = 0; i < nr_fold; i++) {
67 int begin = fold_start[i];
68 int end = fold_start[i + 1];
69 int j, k;
70 Problem subprob = new Problem();
71
72 subprob.bias = prob.bias;
73 subprob.n = prob.n;
74 subprob.l = l - (end - begin);
75 subprob.x = new FeatureNode[subprob.l][];
76 subprob.y = new int[subprob.l];
77
78 k = 0;
79 for (j = 0; j < begin; j++) {
80 subprob.x[k] = prob.x[perm[j]];
81 subprob.y[k] = prob.y[perm[j]];
82 ++k;
83 }
84 for (j = end; j < l; j++) {
85 subprob.x[k] = prob.x[perm[j]];
86 subprob.y[k] = prob.y[perm[j]];
87 ++k;
88 }
89 Model submodel = train(subprob, param);
90 for (j = begin; j < end; j++)
91 target[perm[j]] = predict(submodel, prob.x[perm[j]]);
92 }
93 }
94
95
96 private static class GroupClassesReturn {
97
98 final int[] count;
99 final int[] label;
100 final int nr_class;
101 final int[] start;
102
103 GroupClassesReturn( int nr_class, int[] label, int[] start, int[] count ) {
104 this.nr_class = nr_class;
105 this.label = label;
106 this.start = start;
107 this.count = count;
108 }
109 }
110
111 private static GroupClassesReturn groupClasses(Problem prob, int[] perm) {
112 int l = prob.l;
113 int max_nr_class = 16;
114 int nr_class = 0;
115
116 int[] label = new int[max_nr_class];
117 int[] count = new int[max_nr_class];
118 int[] data_label = new int[l];
119 int i;
120
121 for (i = 0; i < l; i++) {
122 int this_label = prob.y[i];
123 int j;
124 for (j = 0; j < nr_class; j++) {
125 if (this_label == label[j]) {
126 ++count[j];
127 break;
128 }
129 }
130 data_label[i] = j;
131 if (j == nr_class) {
132 if (nr_class == max_nr_class) {
133 max_nr_class *= 2;
134 label = copyOf(label, max_nr_class);
135 count = copyOf(count, max_nr_class);
136 }
137 label[nr_class] = this_label;
138 count[nr_class] = 1;
139 ++nr_class;
140 }
141 }
142
143 int[] start = new int[nr_class];
144 start[0] = 0;
145 for (i = 1; i < nr_class; i++)
146 start[i] = start[i - 1] + count[i - 1];
147 for (i = 0; i < l; i++) {
148 perm[start[data_label[i]]] = i;
149 ++start[data_label[i]];
150 }
151 start[0] = 0;
152 for (i = 1; i < nr_class; i++)
153 start[i] = start[i - 1] + count[i - 1];
154
155 return new GroupClassesReturn(nr_class, label, start, count);
156 }
157
158 static void info(String message) {
159 synchronized (OUTPUT_MUTEX) {
160 if (DEBUG_OUTPUT == null) return;
161 DEBUG_OUTPUT.print(message);
162 DEBUG_OUTPUT.flush();
163 }
164 }
165
166 static void info(String format, Object... args) {
167 synchronized (OUTPUT_MUTEX) {
168 if (DEBUG_OUTPUT == null) return;
169 DEBUG_OUTPUT.printf(format, args);
170 DEBUG_OUTPUT.flush();
171 }
172 }
173
174
175
176
177
178
179 static double atof(String s) {
180 if (s == null || s.length() < 1) throw new IllegalArgumentException("Can't convert empty string to integer");
181 double d = Double.parseDouble(s);
182 if (Double.isNaN(d) || Double.isInfinite(d)) {
183 throw new IllegalArgumentException("NaN or Infinity in input: " + s);
184 }
185 return (d);
186 }
187
188
189
190
191
192
193 static int atoi(String s) throws NumberFormatException {
194 if (s == null || s.length() < 1) throw new IllegalArgumentException("Can't convert empty string to integer");
195
196 if (s.charAt(0) == '+') s = s.substring(1);
197 return Integer.parseInt(s);
198 }
199
200
201
202
203 public static double[] copyOf(double[] original, int newLength) {
204 double[] copy = new double[newLength];
205 System.arraycopy(original, 0, copy, 0, Math.min(original.length, newLength));
206 return copy;
207 }
208
209
210
211
212 public static int[] copyOf(int[] original, int newLength) {
213 int[] copy = new int[newLength];
214 System.arraycopy(original, 0, copy, 0, Math.min(original.length, newLength));
215 return copy;
216 }
217
218
219
220
221
222
223
224 public static Model loadModel(Reader inputReader) throws IOException {
225 Model model = new Model();
226
227 model.label = null;
228
229 Pattern whitespace = Pattern.compile("\\s+");
230
231 BufferedReader reader = null;
232 if (inputReader instanceof BufferedReader) {
233 reader = (BufferedReader)inputReader;
234 } else {
235 reader = new BufferedReader(inputReader);
236 }
237
238 try {
239 String line = null;
240 while ((line = reader.readLine()) != null) {
241 String[] split = whitespace.split(line);
242 if (split[0].equals("solver_type")) {
243 SolverType solver = SolverType.valueOf(split[1]);
244 if (solver == null) {
245 throw new RuntimeException("unknown solver type");
246 }
247 model.solverType = solver;
248 } else if (split[0].equals("nr_class")) {
249 model.nr_class = atoi(split[1]);
250 Integer.parseInt(split[1]);
251 } else if (split[0].equals("nr_feature")) {
252 model.nr_feature = atoi(split[1]);
253 } else if (split[0].equals("bias")) {
254 model.bias = atof(split[1]);
255 } else if (split[0].equals("w")) {
256 break;
257 } else if (split[0].equals("label")) {
258 model.label = new int[model.nr_class];
259 for (int i = 0; i < model.nr_class; i++) {
260 model.label[i] = atoi(split[i + 1]);
261 }
262 } else {
263 throw new RuntimeException("unknown text in model file: [" + line + "]");
264 }
265 }
266
267 int w_size = model.nr_feature;
268 if (model.bias >= 0) w_size++;
269
270 int nr_w = model.nr_class;
271 if (model.nr_class == 2 && model.solverType != SolverType.MCSVM_CS) nr_w = 1;
272
273 model.w = new double[w_size * nr_w];
274 int[] buffer = new int[128];
275
276 for (int i = 0; i < w_size; i++) {
277 for (int j = 0; j < nr_w; j++) {
278 int b = 0;
279 while (true) {
280 int ch = reader.read();
281 if (ch == -1) {
282 throw new EOFException("unexpected EOF");
283 }
284 if (ch == ' ') {
285 model.w[i * nr_w + j] = atof(new String(buffer, 0, b));
286 break;
287 } else {
288 buffer[b++] = ch;
289 }
290 }
291 }
292 }
293 }
294 finally {
295 closeQuietly(reader);
296 }
297
298 return model;
299 }
300
301
302
303
304
305 public static Model loadModel(File modelFile) throws IOException {
306 BufferedReader inputReader = new BufferedReader(new InputStreamReader(new FileInputStream(modelFile), FILE_CHARSET));
307 return loadModel(inputReader);
308 }
309
310 static void closeQuietly(Closeable c) {
311 if (c == null) return;
312 try {
313 c.close();
314 } catch (Throwable t) {}
315 }
316
317 public static int predict(Model model, FeatureNode[] x) {
318 double[] dec_values = new double[model.nr_class];
319 return predictValues(model, x, dec_values);
320 }
321
322 public static int predictProbability(Model model, FeatureNode[] x, double[] prob_estimates) {
323 if (model.solverType == SolverType.L2R_LR) {
324 int nr_class = model.nr_class;
325 int nr_w;
326 if (nr_class == 2)
327 nr_w = 1;
328 else
329 nr_w = nr_class;
330
331 int label = predictValues(model, x, prob_estimates);
332 for (int i = 0; i < nr_w; i++)
333 prob_estimates[i] = 1 / (1 + Math.exp(-prob_estimates[i]));
334
335 if (nr_class == 2)
336 prob_estimates[1] = 1. - prob_estimates[0];
337 else {
338 double sum = 0;
339 for (int i = 0; i < nr_class; i++)
340 sum += prob_estimates[i];
341
342 for (int i = 0; i < nr_class; i++)
343 prob_estimates[i] = prob_estimates[i] / sum;
344 }
345
346 return label;
347 } else
348 return 0;
349 }
350
351 public static int predictValues(Model model, FeatureNode[] x, double[] dec_values) {
352 int n;
353 if (model.bias >= 0)
354 n = model.nr_feature + 1;
355 else
356 n = model.nr_feature;
357
358 double[] w = model.w;
359
360 int nr_w;
361 if (model.nr_class == 2 && model.solverType != SolverType.MCSVM_CS)
362 nr_w = 1;
363 else
364 nr_w = model.nr_class;
365
366 for (int i = 0; i < nr_w; i++)
367 dec_values[i] = 0;
368
369 for (FeatureNode lx : x) {
370 int idx = lx.index;
371
372 if (idx <= n) {
373 for (int i = 0; i < nr_w; i++) {
374 dec_values[i] += w[(idx - 1) * nr_w + i] * lx.value;
375 }
376 }
377 }
378
379 if (model.nr_class == 2)
380 return (dec_values[0] > 0) ? model.label[0] : model.label[1];
381 else {
382 int dec_max_idx = 0;
383 for (int i = 1; i < model.nr_class; i++) {
384 if (dec_values[i] > dec_values[dec_max_idx]) dec_max_idx = i;
385 }
386 return model.label[dec_max_idx];
387 }
388 }
389
390
391 static void printf(Formatter formatter, String format, Object... args) throws IOException {
392 formatter.format(format, args);
393 IOException ioException = formatter.ioException();
394 if (ioException != null) throw ioException;
395 }
396
397
398
399
400
401
402
403 public static void saveModel(Writer modelOutput, Model model) throws IOException {
404 int nr_feature = model.nr_feature;
405 int w_size = nr_feature;
406 if (model.bias >= 0) w_size++;
407
408 int nr_w = model.nr_class;
409 if (model.nr_class == 2 && model.solverType != SolverType.MCSVM_CS) nr_w = 1;
410
411 Formatter formatter = new Formatter(modelOutput, DEFAULT_LOCALE);
412 try {
413 printf(formatter, "solver_type %s\n", model.solverType.name());
414 printf(formatter, "nr_class %d\n", model.nr_class);
415
416 printf(formatter, "label");
417 for (int i = 0; i < model.nr_class; i++) {
418 printf(formatter, " %d", model.label[i]);
419 }
420 printf(formatter, "\n");
421
422 printf(formatter, "nr_feature %d\n", nr_feature);
423 printf(formatter, "bias %.16g\n", model.bias);
424
425 printf(formatter, "w\n");
426 for (int i = 0; i < w_size; i++) {
427 for (int j = 0; j < nr_w; j++) {
428 double value = model.w[i * nr_w + j];
429
430
431 if (value == 0.0) {
432 printf(formatter, "%d ", 0);
433 } else {
434 printf(formatter, "%.16g ", value);
435 }
436 }
437 printf(formatter, "\n");
438 }
439
440 formatter.flush();
441 IOException ioException = formatter.ioException();
442 if (ioException != null) throw ioException;
443 }
444 finally {
445 formatter.close();
446 }
447 }
448
449
450
451
452
453 public static void saveModel(File modelFile, Model model) throws IOException {
454 BufferedWriter modelOutput = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(modelFile), FILE_CHARSET));
455 saveModel(modelOutput, model);
456 }
457
458
459
460
461
462 private static int GETI(byte[] y, int i) {
463 return y[i] + 1;
464 }
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492 private static void solve_l2r_l1l2_svc(Problem prob, double[] w, double eps, double Cp, double Cn, SolverType solver_type) {
493 int l = prob.l;
494 int w_size = prob.n;
495 int i, s, iter = 0;
496 double C, d, G;
497 double[] QD = new double[l];
498 int max_iter = 1000;
499 int[] index = new int[l];
500 double[] alpha = new double[l];
501 byte[] y = new byte[l];
502 int active_size = l;
503
504
505 double PG;
506 double PGmax_old = Double.POSITIVE_INFINITY;
507 double PGmin_old = Double.NEGATIVE_INFINITY;
508 double PGmax_new, PGmin_new;
509
510
511 double diag[] = new double[] {0.5 / Cn, 0, 0.5 / Cp};
512 double upper_bound[] = new double[] {Double.POSITIVE_INFINITY, 0, Double.POSITIVE_INFINITY};
513 if (solver_type == SolverType.L2R_L1LOSS_SVC_DUAL) {
514 diag[0] = 0;
515 diag[2] = 0;
516 upper_bound[0] = Cn;
517 upper_bound[2] = Cp;
518 }
519
520 for (i = 0; i < w_size; i++)
521 w[i] = 0;
522 for (i = 0; i < l; i++) {
523 alpha[i] = 0;
524 if (prob.y[i] > 0) {
525 y[i] = +1;
526 } else {
527 y[i] = -1;
528 }
529 QD[i] = diag[GETI(y, i)];
530
531 for (FeatureNode xi : prob.x[i]) {
532 QD[i] += xi.value * xi.value;
533 }
534 index[i] = i;
535 }
536
537 while (iter < max_iter) {
538 PGmax_new = Double.NEGATIVE_INFINITY;
539 PGmin_new = Double.POSITIVE_INFINITY;
540
541 for (i = 0; i < active_size; i++) {
542 int j = i + random.nextInt(active_size - i);
543 swap(index, i, j);
544 }
545
546 for (s = 0; s < active_size; s++) {
547 i = index[s];
548 G = 0;
549 byte yi = y[i];
550
551 for (FeatureNode xi : prob.x[i]) {
552 G += w[xi.index - 1] * xi.value;
553 }
554 G = G * yi - 1;
555
556 C = upper_bound[GETI(y, i)];
557 G += alpha[i] * diag[GETI(y, i)];
558
559 PG = 0;
560 if (alpha[i] == 0) {
561 if (G > PGmax_old) {
562 active_size--;
563 swap(index, s, active_size);
564 s--;
565 continue;
566 } else if (G < 0) {
567 PG = G;
568 }
569 } else if (alpha[i] == C) {
570 if (G < PGmin_old) {
571 active_size--;
572 swap(index, s, active_size);
573 s--;
574 continue;
575 } else if (G > 0) {
576 PG = G;
577 }
578 } else {
579 PG = G;
580 }
581
582 PGmax_new = Math.max(PGmax_new, PG);
583 PGmin_new = Math.min(PGmin_new, PG);
584
585 if (Math.abs(PG) > 1.0e-12) {
586 double alpha_old = alpha[i];
587 alpha[i] = Math.min(Math.max(alpha[i] - G / QD[i], 0.0), C);
588 d = (alpha[i] - alpha_old) * yi;
589
590 for (FeatureNode xi : prob.x[i]) {
591 w[xi.index - 1] += d * xi.value;
592 }
593 }
594 }
595
596 iter++;
597 if (iter % 10 == 0) info(".");
598
599 if (PGmax_new - PGmin_new <= eps) {
600 if (active_size == l)
601 break;
602 else {
603 active_size = l;
604 info("*");
605 PGmax_old = Double.POSITIVE_INFINITY;
606 PGmin_old = Double.NEGATIVE_INFINITY;
607 continue;
608 }
609 }
610 PGmax_old = PGmax_new;
611 PGmin_old = PGmin_new;
612 if (PGmax_old <= 0) PGmax_old = Double.POSITIVE_INFINITY;
613 if (PGmin_old >= 0) PGmin_old = Double.NEGATIVE_INFINITY;
614 }
615
616 info(NL + "optimization finished, #iter = %d" + NL, iter);
617 if (iter >= max_iter) info("\nWARNING: reaching max number of iterations\nUsing -s 2 may be faster (also see FAQ)\n\n");
618
619
620
621 double v = 0;
622 int nSV = 0;
623 for (i = 0; i < w_size; i++)
624 v += w[i] * w[i];
625 for (i = 0; i < l; i++) {
626 v += alpha[i] * (alpha[i] * diag[GETI(y, i)] - 2);
627 if (alpha[i] > 0) ++nSV;
628 }
629 info("Objective value = %f" + NL, v / 2);
630 info("nSV = %d" + NL, nSV);
631 }
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647 private static void solve_l1r_l2_svc(Problem prob_col, double[] w, double eps, double Cp, double Cn) {
648 int l = prob_col.l;
649 int w_size = prob_col.n;
650 int j, s, iter = 0;
651 int max_iter = 1000;
652 int active_size = w_size;
653 int max_num_linesearch = 20;
654
655 double sigma = 0.01;
656 double d, G_loss, G, H;
657 double Gmax_old = Double.POSITIVE_INFINITY;
658 double Gmax_new;
659 double Gmax_init = 0;
660 double d_old, d_diff;
661 double loss_old = 0;
662 double loss_new;
663 double appxcond, cond;
664
665 int[] index = new int[w_size];
666 byte[] y = new byte[l];
667 double[] b = new double[l];
668 double[] xj_sq = new double[w_size];
669
670 double[] C = new double[] {Cn, 0, Cp};
671
672 for (j = 0; j < l; j++) {
673 b[j] = 1;
674 if (prob_col.y[j] > 0)
675 y[j] = 1;
676 else
677 y[j] = -1;
678 }
679 for (j = 0; j < w_size; j++) {
680 w[j] = 0;
681 index[j] = j;
682 xj_sq[j] = 0;
683 for (FeatureNode xi : prob_col.x[j]) {
684 int ind = xi.index - 1;
685 double val = xi.value;
686 xi.value *= y[ind];
687 xj_sq[j] += C[GETI(y, ind)] * val * val;
688 }
689 }
690
691 while (iter < max_iter) {
692 Gmax_new = 0;
693
694 for (j = 0; j < active_size; j++) {
695 int i = j + random.nextInt(active_size - j);
696 swap(index, i, j);
697 }
698
699 for (s = 0; s < active_size; s++) {
700 j = index[s];
701 G_loss = 0;
702 H = 0;
703
704 for (FeatureNode xi : prob_col.x[j]) {
705 int ind = xi.index - 1;
706 if (b[ind] > 0) {
707 double val = xi.value;
708 double tmp = C[GETI(y, ind)] * val;
709 G_loss -= tmp * b[ind];
710 H += tmp * val;
711 }
712 }
713 G_loss *= 2;
714
715 G = G_loss;
716 H *= 2;
717 H = Math.max(H, 1e-12);
718
719 double Gp = G + 1;
720 double Gn = G - 1;
721 double violation = 0;
722 if (w[j] == 0) {
723 if (Gp < 0)
724 violation = -Gp;
725 else if (Gn > 0)
726 violation = Gn;
727 else if (Gp > Gmax_old / l && Gn < -Gmax_old / l) {
728 active_size--;
729 swap(index, s, active_size);
730 s--;
731 continue;
732 }
733 } else if (w[j] > 0)
734 violation = Math.abs(Gp);
735 else
736 violation = Math.abs(Gn);
737
738 Gmax_new = Math.max(Gmax_new, violation);
739
740
741 if (Gp <= H * w[j])
742 d = -Gp / H;
743 else if (Gn >= H * w[j])
744 d = -Gn / H;
745 else
746 d = -w[j];
747
748 if (Math.abs(d) < 1.0e-12) continue;
749
750 double delta = Math.abs(w[j] + d) - Math.abs(w[j]) + G * d;
751 d_old = 0;
752 int num_linesearch;
753 for (num_linesearch = 0; num_linesearch < max_num_linesearch; num_linesearch++) {
754 d_diff = d_old - d;
755 cond = Math.abs(w[j] + d) - Math.abs(w[j]) - sigma * delta;
756
757 appxcond = xj_sq[j] * d * d + G_loss * d + cond;
758 if (appxcond <= 0) {
759 for (FeatureNode x : prob_col.x[j]) {
760 b[x.index - 1] += d_diff * x.value;
761 }
762 break;
763 }
764
765 if (num_linesearch == 0) {
766 loss_old = 0;
767 loss_new = 0;
768 for (FeatureNode x : prob_col.x[j]) {
769 int ind = x.index - 1;
770 if (b[ind] > 0) {
771 loss_old += C[GETI(y, ind)] * b[ind] * b[ind];
772 }
773 double b_new = b[ind] + d_diff * x.value;
774 b[ind] = b_new;
775 if (b_new > 0) {
776 loss_new += C[GETI(y, ind)] * b_new * b_new;
777 }
778 }
779 } else {
780 loss_new = 0;
781 for (FeatureNode x : prob_col.x[j]) {
782 int ind = x.index - 1;
783 double b_new = b[ind] + d_diff * x.value;
784 b[ind] = b_new;
785 if (b_new > 0) {
786 loss_new += C[GETI(y, ind)] * b_new * b_new;
787 }
788 }
789 }
790
791 cond = cond + loss_new - loss_old;
792 if (cond <= 0)
793 break;
794 else {
795 d_old = d;
796 d *= 0.5;
797 delta *= 0.5;
798 }
799 }
800
801 w[j] += d;
802
803
804 if (num_linesearch >= max_num_linesearch) {
805 info("#");
806 for (int i = 0; i < l; i++)
807 b[i] = 1;
808
809 for (int i = 0; i < w_size; i++) {
810 if (w[i] == 0) continue;
811 for (FeatureNode x : prob_col.x[i]) {
812 b[x.index - 1] -= w[i] * x.value;
813 }
814 }
815 }
816 }
817
818 if (iter == 0) Gmax_init = Gmax_new;
819 iter++;
820 if (iter % 10 == 0) info(".");
821
822 if (Gmax_new <= eps * Gmax_init) {
823 if (active_size == w_size)
824 break;
825 else {
826 active_size = w_size;
827 info("*");
828 Gmax_old = Double.POSITIVE_INFINITY;
829 continue;
830 }
831 }
832
833 Gmax_old = Gmax_new;
834 }
835
836 info("\noptimization finished, #iter = %d\n", iter);
837 if (iter >= max_iter) info("\nWARNING: reaching max number of iterations\n");
838
839
840
841 double v = 0;
842 int nnz = 0;
843 for (j = 0; j < w_size; j++) {
844 for (FeatureNode x : prob_col.x[j]) {
845 x.value *= prob_col.y[x.index - 1];
846 }
847 if (w[j] != 0) {
848 v += Math.abs(w[j]);
849 nnz++;
850 }
851 }
852 for (j = 0; j < l; j++)
853 if (b[j] > 0) v += C[GETI(y, j)] * b[j] * b[j];
854
855 info("Objective value = %f\n", v);
856 info("#nonzeros/#features = %d/%d\n", nnz, w_size);
857 }
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873 private static void solve_l1r_lr(Problem prob_col, double[] w, double eps, double Cp, double Cn) {
874 int l = prob_col.l;
875 int w_size = prob_col.n;
876 int j, s, iter = 0;
877 int max_iter = 1000;
878 int active_size = w_size;
879 int max_num_linesearch = 20;
880
881 double x_min = 0;
882 double sigma = 0.01;
883 double d, G, H;
884 double Gmax_old = Double.POSITIVE_INFINITY;
885 double Gmax_new;
886 double Gmax_init = 0;
887 double sum1, appxcond1;
888 double sum2, appxcond2;
889 double cond;
890
891 int[] index = new int[w_size];
892 byte[] y = new byte[l];
893 double[] exp_wTx = new double[l];
894 double[] exp_wTx_new = new double[l];
895 double[] xj_max = new double[w_size];
896 double[] C_sum = new double[w_size];
897 double[] xjneg_sum = new double[w_size];
898 double[] xjpos_sum = new double[w_size];
899
900 double[] C = new double[] {Cn, 0, Cp};
901
902 for (j = 0; j < l; j++) {
903 exp_wTx[j] = 1;
904 if (prob_col.y[j] > 0)
905 y[j] = 1;
906 else
907 y[j] = -1;
908 }
909 for (j = 0; j < w_size; j++) {
910 w[j] = 0;
911 index[j] = j;
912 xj_max[j] = 0;
913 C_sum[j] = 0;
914 xjneg_sum[j] = 0;
915 xjpos_sum[j] = 0;
916 for (FeatureNode x : prob_col.x[j]) {
917 int ind = x.index - 1;
918 double val = x.value;
919 x_min = Math.min(x_min, val);
920 xj_max[j] = Math.max(xj_max[j], val);
921 C_sum[j] += C[GETI(y, ind)];
922 if (y[ind] == -1)
923 xjneg_sum[j] += C[GETI(y, ind)] * val;
924 else
925 xjpos_sum[j] += C[GETI(y, ind)] * val;
926 }
927 }
928
929 while (iter < max_iter) {
930 Gmax_new = 0;
931
932 for (j = 0; j < active_size; j++) {
933 int i = j + random.nextInt(active_size) - j;
934 swap(index, i, j);
935 }
936
937 for (s = 0; s < active_size; s++) {
938 j = index[s];
939 sum1 = 0;
940 sum2 = 0;
941 H = 0;
942
943 for (FeatureNode x : prob_col.x[j]) {
944 int ind = x.index - 1;
945 double exp_wTxind = exp_wTx[ind];
946 double tmp1 = x.value / (1 + exp_wTxind);
947 double tmp2 = C[GETI(y, ind)] * tmp1;
948 double tmp3 = tmp2 * exp_wTxind;
949 sum2 += tmp2;
950 sum1 += tmp3;
951 H += tmp1 * tmp3;
952 }
953
954 G = -sum2 + xjneg_sum[j];
955
956 double Gp = G + 1;
957 double Gn = G - 1;
958 double violation = 0;
959 if (w[j] == 0) {
960 if (Gp < 0)
961 violation = -Gp;
962 else if (Gn > 0)
963 violation = Gn;
964 else if (Gp > Gmax_old / l && Gn < -Gmax_old / l) {
965 active_size--;
966 swap(index, s, active_size);
967 s--;
968 continue;
969 }
970 } else if (w[j] > 0)
971 violation = Math.abs(Gp);
972 else
973 violation = Math.abs(Gn);
974
975 Gmax_new = Math.max(Gmax_new, violation);
976
977
978 if (Gp <= H * w[j])
979 d = -Gp / H;
980 else if (Gn >= H * w[j])
981 d = -Gn / H;
982 else
983 d = -w[j];
984
985 if (Math.abs(d) < 1.0e-12) continue;
986
987 d = Math.min(Math.max(d, -10.0), 10.0);
988
989 double delta = Math.abs(w[j] + d) - Math.abs(w[j]) + G * d;
990 int num_linesearch;
991 for (num_linesearch = 0; num_linesearch < max_num_linesearch; num_linesearch++) {
992 cond = Math.abs(w[j] + d) - Math.abs(w[j]) - sigma * delta;
993
994 if (x_min >= 0) {
995 double tmp = Math.exp(d * xj_max[j]);
996 appxcond1 = Math.log(1 + sum1 * (tmp - 1) / xj_max[j] / C_sum[j]) * C_sum[j] + cond - d * xjpos_sum[j];
997 appxcond2 = Math.log(1 + sum2 * (1 / tmp - 1) / xj_max[j] / C_sum[j]) * C_sum[j] + cond + d * xjneg_sum[j];
998 if (Math.min(appxcond1, appxcond2) <= 0) {
999 for (FeatureNode x : prob_col.x[j]) {
1000 exp_wTx[x.index - 1] *= Math.exp(d * x.value);
1001 }
1002 break;
1003 }
1004 }
1005
1006 cond += d * xjneg_sum[j];
1007
1008 int i = 0;
1009 for (FeatureNode x : prob_col.x[j]) {
1010 int ind = x.index - 1;
1011 double exp_dx = Math.exp(d * x.value);
1012 exp_wTx_new[i] = exp_wTx[ind] * exp_dx;
1013 cond += C[GETI(y, ind)] * Math.log((1 + exp_wTx_new[i]) / (exp_dx + exp_wTx_new[i]));
1014 i++;
1015 }
1016
1017 if (cond <= 0) {
1018 i = 0;
1019 for (FeatureNode x : prob_col.x[j]) {
1020 int ind = x.index - 1;
1021 exp_wTx[ind] = exp_wTx_new[i];
1022 i++;
1023 }
1024 break;
1025 } else {
1026 d *= 0.5;
1027 delta *= 0.5;
1028 }
1029 }
1030
1031 w[j] += d;
1032
1033
1034 if (num_linesearch >= max_num_linesearch) {
1035 info("#");
1036 for (int i = 0; i < l; i++)
1037 exp_wTx[i] = 0;
1038
1039 for (int i = 0; i < w_size; i++) {
1040 if (w[i] == 0) continue;
1041 for (FeatureNode x : prob_col.x[i]) {
1042 exp_wTx[x.index - 1] += w[i] * x.value;
1043 }
1044 }
1045
1046 for (int i = 0; i < l; i++)
1047 exp_wTx[i] = Math.exp(exp_wTx[i]);
1048 }
1049 }
1050
1051 if (iter == 0) Gmax_init = Gmax_new;
1052 iter++;
1053 if (iter % 10 == 0) info(".");
1054
1055 if (Gmax_new <= eps * Gmax_init) {
1056 if (active_size == w_size)
1057 break;
1058 else {
1059 active_size = w_size;
1060 info("*");
1061 Gmax_old = Double.POSITIVE_INFINITY;
1062 continue;
1063 }
1064 }
1065
1066 Gmax_old = Gmax_new;
1067 }
1068
1069 info("\noptimization finished, #iter = %d\n", iter);
1070 if (iter >= max_iter) info("\nWARNING: reaching max number of iterations\n");
1071
1072
1073
1074 double v = 0;
1075 int nnz = 0;
1076 for (j = 0; j < w_size; j++)
1077 if (w[j] != 0) {
1078 v += Math.abs(w[j]);
1079 nnz++;
1080 }
1081 for (j = 0; j < l; j++)
1082 if (y[j] == 1)
1083 v += C[GETI(y, j)] * Math.log(1 + 1 / exp_wTx[j]);
1084 else
1085 v += C[GETI(y, j)] * Math.log(1 + exp_wTx[j]);
1086
1087 info("Objective value = %f\n", v);
1088 info("#nonzeros/#features = %d/%d\n", nnz, w_size);
1089 }
1090
1091
1092 static Problem transpose(Problem prob) {
1093 int l = prob.l;
1094 int n = prob.n;
1095 int[] col_ptr = new int[n + 1];
1096 Problem prob_col = new Problem();
1097 prob_col.l = l;
1098 prob_col.n = n;
1099 prob_col.y = new int[l];
1100 prob_col.x = new FeatureNode[n][];
1101
1102 for (int i = 0; i < l; i++)
1103 prob_col.y[i] = prob.y[i];
1104
1105 for (int i = 0; i < l; i++) {
1106 for (FeatureNode x : prob.x[i]) {
1107 col_ptr[x.index]++;
1108 }
1109 }
1110
1111 for (int i = 0; i < n; i++) {
1112 prob_col.x[i] = new FeatureNode[col_ptr[i + 1]];
1113 col_ptr[i] = 0;
1114 }
1115
1116 for (int i = 0; i < l; i++) {
1117 for (int j = 0; j < prob.x[i].length; j++) {
1118 FeatureNode x = prob.x[i][j];
1119 int index = x.index - 1;
1120 prob_col.x[index][col_ptr[index]] = new FeatureNode(i + 1, x.value);
1121 col_ptr[index]++;
1122 }
1123 }
1124
1125 return prob_col;
1126 }
1127
1128 static void swap(double[] array, int idxA, int idxB) {
1129 double temp = array[idxA];
1130 array[idxA] = array[idxB];
1131 array[idxB] = temp;
1132 }
1133
1134 static void swap(int[] array, int idxA, int idxB) {
1135 int temp = array[idxA];
1136 array[idxA] = array[idxB];
1137 array[idxB] = temp;
1138 }
1139
1140 static void swap(IntArrayPointer array, int idxA, int idxB) {
1141 int temp = array.get(idxA);
1142 array.set(idxA, array.get(idxB));
1143 array.set(idxB, temp);
1144 }
1145
1146
1147
1148
1149 public static Model train(Problem prob, Parameter param) {
1150
1151 if (prob == null) throw new IllegalArgumentException("problem must not be null");
1152 if (param == null) throw new IllegalArgumentException("parameter must not be null");
1153
1154 for (FeatureNode[] nodes : prob.x) {
1155 int indexBefore = 0;
1156 for (FeatureNode n : nodes) {
1157 if (n.index <= indexBefore) {
1158 throw new IllegalArgumentException("feature nodes must be sorted by index in ascending order");
1159 }
1160 indexBefore = n.index;
1161 }
1162 }
1163
1164 int i, j;
1165 int l = prob.l;
1166 int n = prob.n;
1167 int w_size = prob.n;
1168 Model model = new Model();
1169
1170 if (prob.bias >= 0)
1171 model.nr_feature = n - 1;
1172 else
1173 model.nr_feature = n;
1174 model.solverType = param.solverType;
1175 model.bias = prob.bias;
1176
1177 int[] perm = new int[l];
1178
1179 GroupClassesReturn rv = groupClasses(prob, perm);
1180 int nr_class = rv.nr_class;
1181 int[] label = rv.label;
1182 int[] start = rv.start;
1183 int[] count = rv.count;
1184
1185 model.nr_class = nr_class;
1186 model.label = new int[nr_class];
1187 for (i = 0; i < nr_class; i++)
1188 model.label[i] = label[i];
1189
1190
1191 double[] weighted_C = new double[nr_class];
1192 for (i = 0; i < nr_class; i++) {
1193 weighted_C[i] = param.C;
1194 }
1195
1196 for (i = 0; i < param.getNumWeights(); i++) {
1197 for (j = 0; j < nr_class; j++)
1198 if (param.weightLabel[i] == label[j]) break;
1199 if (j == nr_class) throw new IllegalArgumentException("class label " + param.weightLabel[i] + " specified in weight is not found");
1200
1201 weighted_C[j] *= param.weight[i];
1202 }
1203
1204
1205 FeatureNode[][] x = new FeatureNode[l][];
1206 for (i = 0; i < l; i++)
1207 x[i] = prob.x[perm[i]];
1208
1209 int k;
1210 Problem sub_prob = new Problem();
1211 sub_prob.l = l;
1212 sub_prob.n = n;
1213 sub_prob.x = new FeatureNode[sub_prob.l][];
1214 sub_prob.y = new int[sub_prob.l];
1215
1216 for (k = 0; k < sub_prob.l; k++)
1217 sub_prob.x[k] = x[k];
1218
1219
1220 if (param.solverType == SolverType.MCSVM_CS) {
1221 model.w = new double[n * nr_class];
1222 for (i = 0; i < nr_class; i++) {
1223 for (j = start[i]; j < start[i] + count[i]; j++) {
1224 sub_prob.y[j] = i;
1225 }
1226 }
1227
1228 SolverMCSVM_CS solver = new SolverMCSVM_CS(sub_prob, nr_class, weighted_C, param.eps);
1229 solver.solve(model.w);
1230 } else {
1231 if (nr_class == 2) {
1232 model.w = new double[w_size];
1233
1234 int e0 = start[0] + count[0];
1235 k = 0;
1236 for (; k < e0; k++)
1237 sub_prob.y[k] = +1;
1238 for (; k < sub_prob.l; k++)
1239 sub_prob.y[k] = -1;
1240
1241 train_one(sub_prob, param, model.w, weighted_C[0], weighted_C[1]);
1242 } else {
1243 model.w = new double[w_size * nr_class];
1244 double[] w = new double[w_size];
1245 for (i = 0; i < nr_class; i++) {
1246 int si = start[i];
1247 int ei = si + count[i];
1248
1249 k = 0;
1250 for (; k < si; k++)
1251 sub_prob.y[k] = -1;
1252 for (; k < ei; k++)
1253 sub_prob.y[k] = +1;
1254 for (; k < sub_prob.l; k++)
1255 sub_prob.y[k] = -1;
1256
1257 train_one(sub_prob, param, w, weighted_C[i], param.C);
1258
1259 for (j = 0; j < n; j++)
1260 model.w[j * nr_class + i] = w[j];
1261 }
1262 }
1263
1264 }
1265 return model;
1266 }
1267
1268 private static void train_one(Problem prob, Parameter param, double[] w, double Cp, double Cn) {
1269 double eps = param.eps;
1270 int pos = 0;
1271 for (int i = 0; i < prob.l; i++)
1272 if (prob.y[i] == +1) pos++;
1273 int neg = prob.l - pos;
1274
1275 Function fun_obj = null;
1276 switch (param.solverType) {
1277 case L2R_LR: {
1278 fun_obj = new L2R_LrFunction(prob, Cp, Cn);
1279 Tron tron_obj = new Tron(fun_obj, eps * Math.min(pos, neg) / prob.l);
1280 tron_obj.tron(w);
1281 break;
1282 }
1283 case L2R_L2LOSS_SVC: {
1284 fun_obj = new L2R_L2_SvcFunction(prob, Cp, Cn);
1285 Tron tron_obj = new Tron(fun_obj, eps * Math.min(pos, neg) / prob.l);
1286 tron_obj.tron(w);
1287 break;
1288 }
1289 case L2R_L2LOSS_SVC_DUAL:
1290 solve_l2r_l1l2_svc(prob, w, eps, Cp, Cn, SolverType.L2R_L2LOSS_SVC_DUAL);
1291 break;
1292 case L2R_L1LOSS_SVC_DUAL:
1293 solve_l2r_l1l2_svc(prob, w, eps, Cp, Cn, SolverType.L2R_L1LOSS_SVC_DUAL);
1294 break;
1295 case L1R_L2LOSS_SVC: {
1296 Problem prob_col = transpose(prob);
1297 solve_l1r_l2_svc(prob_col, w, eps * Math.min(pos, neg) / prob.l, Cp, Cn);
1298 break;
1299 }
1300 case L1R_LR: {
1301 Problem prob_col = transpose(prob);
1302 solve_l1r_lr(prob_col, w, eps * Math.min(pos, neg) / prob.l, Cp, Cn);
1303 break;
1304 }
1305 default:
1306 throw new IllegalStateException("unknown solver type: " + param.solverType);
1307 }
1308 }
1309
1310 public static void disableDebugOutput() {
1311 setDebugOutput(null);
1312 }
1313
1314 public static void enableDebugOutput() {
1315 setDebugOutput(System.out);
1316 }
1317
1318 public static void setDebugOutput(PrintStream debugOutput) {
1319 synchronized (OUTPUT_MUTEX) {
1320 DEBUG_OUTPUT = debugOutput;
1321 }
1322 }
1323
1324
1325
1326
1327
1328
1329 public static void resetRandom() {
1330 random = new Random(DEFAULT_RANDOM_SEED);
1331 }
1332 }