1 package liblinear;
2
3 import static liblinear.Linear.NL;
4 import static liblinear.Linear.info;
5 import static org.netlib.blas.DAXPY.DAXPY;
6 import static org.netlib.blas.DDOT.DDOT;
7 import static org.netlib.blas.DNRM2.DNRM2;
8 import static org.netlib.blas.DSCAL.DSCAL;
9
10
11 class Tron {
12
13 private final Function fun_obj;
14
15 private final double eps;
16
17 private final int max_iter;
18
19 public Tron( final Function fun_obj ) {
20 this(fun_obj, 0.1);
21 }
22
23 public Tron( final Function fun_obj, double eps ) {
24 this(fun_obj, eps, 1000);
25 }
26
27 public Tron( final Function fun_obj, double eps, int max_iter ) {
28 this.fun_obj = fun_obj;
29 this.eps = eps;
30 this.max_iter = max_iter;
31 }
32
33
34 void tron(double[] w) {
35
36 double eta0 = 1e-4, eta1 = 0.25, eta2 = 0.75;
37
38
39 double sigma1 = 0.25, sigma2 = 0.5, sigma3 = 4;
40
41 int n = fun_obj.get_nr_variable();
42 int i, cg_iter;
43 double delta, snorm, one = 1.0;
44 double alpha, f, fnew, prered, actred, gs;
45 int search = 1, iter = 1, inc = 1;
46 double[] s = new double[n];
47 double[] r = new double[n];
48 double[] w_new = new double[n];
49 double[] g = new double[n];
50
51 for (i = 0; i < n; i++)
52 w[i] = 0;
53
54 f = fun_obj.fun(w);
55 fun_obj.grad(w, g);
56 delta = DNRM2(n, g, inc);
57
58 double gnorm1 = delta;
59 double gnorm = gnorm1;
60
61 if (gnorm <= eps * gnorm1) search = 0;
62
63 iter = 1;
64
65 while (iter <= max_iter && search != 0) {
66 cg_iter = trcg(delta, g, s, r);
67
68
69 System.arraycopy(w, 0, w_new, 0, n);
70 DAXPY(n, one, s, inc, w_new, inc);
71
72 gs = DDOT(n, g, inc, s, inc);
73
74 prered = -0.5 * (gs - DDOT(n, s, inc, r, inc));
75 fnew = fun_obj.fun(w_new);
76
77
78 actred = f - fnew;
79
80
81 snorm = DNRM2(n, s, inc);
82
83 if (iter == 1) delta = Math.min(delta, snorm);
84
85
86 if (fnew - f - gs <= 0)
87 alpha = sigma3;
88 else
89 alpha = Math.max(sigma1, -0.5 * (gs / (fnew - f - gs)));
90
91
92
93 if (actred < eta0 * prered)
94 delta = Math.min(Math.max(alpha, sigma1) * snorm, sigma2 * delta);
95 else if (actred < eta1 * prered)
96 delta = Math.max(sigma1 * delta, Math.min(alpha * snorm, sigma2 * delta));
97 else if (actred < eta2 * prered)
98 delta = Math.max(sigma1 * delta, Math.min(alpha * snorm, sigma3 * delta));
99 else
100 delta = Math.max(delta, Math.min(alpha * snorm, sigma3 * delta));
101
102 info("iter %2d act %5.3e pre %5.3e delta %5.3e f %5.3e |g| %5.3e CG %3d" + NL, iter, actred, prered, delta, f, gnorm, cg_iter);
103
104 if (actred > eta0 * prered) {
105 iter++;
106
107 System.arraycopy(w_new, 0, w, 0, n);
108 f = fnew;
109 fun_obj.grad(w, g);
110
111 gnorm = DNRM2(n, g, inc);
112
113 if (gnorm <= eps * gnorm1) break;
114 }
115 if (f < -1.0e+32) {
116 info("warning: f < -1.0e+32" + NL);
117 break;
118 }
119 if (Math.abs(actred) <= 0 && prered <= 0) {
120 info("warning: actred and prered <= 0" + NL);
121 break;
122 }
123 if (Math.abs(actred) <= 1.0e-12 * Math.abs(f) && Math.abs(prered) <= 1.0e-12 * Math.abs(f)) {
124 info("warning: actred and prered too small" + NL);
125 break;
126 }
127 }
128 }
129
130
131 int trcg(double delta, double[] g, double[] s, double[] r) {
132 int i, inc = 1;
133 int n = fun_obj.get_nr_variable();
134 double one = 1;
135 double[] d = new double[n];
136 double[] Hd = new double[n];
137 double rTr, rnewTrnew, cgtol;
138
139 for (i = 0; i < n; i++) {
140 s[i] = 0;
141 r[i] = -g[i];
142 d[i] = r[i];
143 }
144 cgtol = 0.1 * DNRM2(n, g, inc);
145
146 int cg_iter = 0;
147
148 rTr = DDOT(n, r, inc, r, inc);
149
150 while (true) {
151 if (DNRM2(n, r, inc) <= cgtol) break;
152 cg_iter++;
153 fun_obj.Hv(d, Hd);
154
155 double alpha = rTr / DDOT(n, d, inc, Hd, inc);
156 DAXPY(n, alpha, d, inc, s, inc);
157
158
159 if (DNRM2(n, s, inc) > delta) {
160 info("cg reaches trust region boundary\n");
161 alpha = -alpha;
162
163 DAXPY(n, alpha, d, inc, s, inc);
164
165 double std = DDOT(n, s, inc, d, inc);
166 double sts = DDOT(n, s, inc, s, inc);
167 double dtd = DDOT(n, d, inc, d, inc);
168 double dsq = delta * delta;
169 double rad = Math.sqrt(std * std + dtd * (dsq - sts));
170 if (std >= 0)
171 alpha = (dsq - sts) / (std + rad);
172 else
173 alpha = (rad - std) / dtd;
174 DAXPY(n, alpha, d, inc, s, inc);
175 alpha = -alpha;
176 DAXPY(n, alpha, Hd, inc, r, inc);
177 break;
178 }
179 alpha = -alpha;
180 DAXPY(n, alpha, Hd, inc, r, inc);
181 rnewTrnew = DDOT(n, r, inc, r, inc);
182 double beta = rnewTrnew / rTr;
183 DSCAL(n, beta, d, inc);
184 DAXPY(n, one, r, inc, d, inc);
185 rTr = rnewTrnew;
186 }
187
188 return (cg_iter);
189 }
190
191 double norm_inf(int n, double[] x) {
192 double dmax = Math.abs(x[0]);
193 for (int i = 1; i < n; i++)
194 if (Math.abs(x[i]) >= dmax) dmax = Math.abs(x[i]);
195 return (dmax);
196 }
197 }