View Javadoc

1   package liblinear;
2   
3   import static liblinear.Linear.NL;
4   import static liblinear.Linear.info;
5   import static org.netlib.blas.DAXPY.DAXPY;
6   import static org.netlib.blas.DDOT.DDOT;
7   import static org.netlib.blas.DNRM2.DNRM2;
8   import static org.netlib.blas.DSCAL.DSCAL;
9   
10  
11  class Tron {
12  
13      private final Function fun_obj;
14  
15      private final double   eps;
16  
17      private final int      max_iter;
18  
19      public Tron( final Function fun_obj ) {
20          this(fun_obj, 0.1);
21      }
22  
23      public Tron( final Function fun_obj, double eps ) {
24          this(fun_obj, eps, 1000);
25      }
26  
27      public Tron( final Function fun_obj, double eps, int max_iter ) {
28          this.fun_obj = fun_obj;
29          this.eps = eps;
30          this.max_iter = max_iter;
31      }
32  
33      // void tron(double *w)
34      void tron(double[] w) {
35          // Parameters for updating the iterates.
36          double eta0 = 1e-4, eta1 = 0.25, eta2 = 0.75;
37  
38          // Parameters for updating the trust region size delta.
39          double sigma1 = 0.25, sigma2 = 0.5, sigma3 = 4;
40  
41          int n = fun_obj.get_nr_variable();
42          int i, cg_iter;
43          double delta, snorm, one = 1.0;
44          double alpha, f, fnew, prered, actred, gs;
45          int search = 1, iter = 1, inc = 1;
46          double[] s = new double[n];
47          double[] r = new double[n];
48          double[] w_new = new double[n];
49          double[] g = new double[n];
50  
51          for (i = 0; i < n; i++)
52              w[i] = 0;
53  
54          f = fun_obj.fun(w);
55          fun_obj.grad(w, g);
56          delta = DNRM2(n, g, inc);
57          // delta = dnrm2_(&n, g, &inc);
58          double gnorm1 = delta;
59          double gnorm = gnorm1;
60  
61          if (gnorm <= eps * gnorm1) search = 0;
62  
63          iter = 1;
64  
65          while (iter <= max_iter && search != 0) {
66              cg_iter = trcg(delta, g, s, r);
67  
68              // memcpy(w_new, w, sizeof(double)*n);
69              System.arraycopy(w, 0, w_new, 0, n);
70              DAXPY(n, one, s, inc, w_new, inc);
71  
72              gs = DDOT(n, g, inc, s, inc);
73              // gs = ddot_(&n, g, &inc, s, &inc);
74              prered = -0.5 * (gs - DDOT(n, s, inc, r, inc));
75              fnew = fun_obj.fun(w_new);
76  
77              // Compute the actual reduction.
78              actred = f - fnew;
79  
80              // On the first iteration, adjust the initial step bound.
81              snorm = DNRM2(n, s, inc);
82              // snorm = dnrm2_(&n, s, &inc);
83              if (iter == 1) delta = Math.min(delta, snorm);
84  
85              // Compute prediction alpha*snorm of the step.
86              if (fnew - f - gs <= 0)
87                  alpha = sigma3;
88              else
89                  alpha = Math.max(sigma1, -0.5 * (gs / (fnew - f - gs)));
90  
91              // Update the trust region bound according to the ratio of actual to
92              // predicted reduction.
93              if (actred < eta0 * prered)
94                  delta = Math.min(Math.max(alpha, sigma1) * snorm, sigma2 * delta);
95              else if (actred < eta1 * prered)
96                  delta = Math.max(sigma1 * delta, Math.min(alpha * snorm, sigma2 * delta));
97              else if (actred < eta2 * prered)
98                  delta = Math.max(sigma1 * delta, Math.min(alpha * snorm, sigma3 * delta));
99              else
100                 delta = Math.max(delta, Math.min(alpha * snorm, sigma3 * delta));
101 
102             info("iter %2d act %5.3e pre %5.3e delta %5.3e f %5.3e |g| %5.3e CG %3d" + NL, iter, actred, prered, delta, f, gnorm, cg_iter);
103 
104             if (actred > eta0 * prered) {
105                 iter++;
106                 // memcpy(w, w_new, sizeof(double)*n);
107                 System.arraycopy(w_new, 0, w, 0, n);
108                 f = fnew;
109                 fun_obj.grad(w, g);
110 
111                 gnorm = DNRM2(n, g, inc);
112                 // gnorm = dnrm2_(&n, g, &inc);
113                 if (gnorm <= eps * gnorm1) break;
114             }
115             if (f < -1.0e+32) {
116                 info("warning: f < -1.0e+32" + NL);
117                 break;
118             }
119             if (Math.abs(actred) <= 0 && prered <= 0) {
120                 info("warning: actred and prered <= 0" + NL);
121                 break;
122             }
123             if (Math.abs(actred) <= 1.0e-12 * Math.abs(f) && Math.abs(prered) <= 1.0e-12 * Math.abs(f)) {
124                 info("warning: actred and prered too small" + NL);
125                 break;
126             }
127         }
128     }
129 
130     // int TRON::trcg(double delta, double *g, double *s, double *r)
131     int trcg(double delta, double[] g, double[] s, double[] r) {
132         int i, inc = 1;
133         int n = fun_obj.get_nr_variable();
134         double one = 1;
135         double[] d = new double[n];
136         double[] Hd = new double[n];
137         double rTr, rnewTrnew, cgtol;
138 
139         for (i = 0; i < n; i++) {
140             s[i] = 0;
141             r[i] = -g[i];
142             d[i] = r[i];
143         }
144         cgtol = 0.1 * DNRM2(n, g, inc);
145 
146         int cg_iter = 0;
147         // rTr = ddot_(&n, r, &inc, r, &inc);
148         rTr = DDOT(n, r, inc, r, inc);
149 
150         while (true) {
151             if (DNRM2(n, r, inc) <= cgtol) break;
152             cg_iter++;
153             fun_obj.Hv(d, Hd);
154 
155             double alpha = rTr / DDOT(n, d, inc, Hd, inc);
156             DAXPY(n, alpha, d, inc, s, inc);
157             // daxpy_(&n, &alpha, d, &inc, s, &inc);
158             // if (dnrm2_(&n, s, &inc) > delta)
159             if (DNRM2(n, s, inc) > delta) {
160                 info("cg reaches trust region boundary\n");
161                 alpha = -alpha;
162                 // daxpy_(&n, &alpha, d, &inc, s, &inc);
163                 DAXPY(n, alpha, d, inc, s, inc);
164 
165                 double std = DDOT(n, s, inc, d, inc);
166                 double sts = DDOT(n, s, inc, s, inc);
167                 double dtd = DDOT(n, d, inc, d, inc);
168                 double dsq = delta * delta;
169                 double rad = Math.sqrt(std * std + dtd * (dsq - sts));
170                 if (std >= 0)
171                     alpha = (dsq - sts) / (std + rad);
172                 else
173                     alpha = (rad - std) / dtd;
174                 DAXPY(n, alpha, d, inc, s, inc);
175                 alpha = -alpha;
176                 DAXPY(n, alpha, Hd, inc, r, inc);
177                 break;
178             }
179             alpha = -alpha;
180             DAXPY(n, alpha, Hd, inc, r, inc);
181             rnewTrnew = DDOT(n, r, inc, r, inc);
182             double beta = rnewTrnew / rTr;
183             DSCAL(n, beta, d, inc);
184             DAXPY(n, one, r, inc, d, inc);
185             rTr = rnewTrnew;
186         }
187 
188         return (cg_iter);
189     }
190 
191     double norm_inf(int n, double[] x) {
192         double dmax = Math.abs(x[0]);
193         for (int i = 1; i < n; i++)
194             if (Math.abs(x[i]) >= dmax) dmax = Math.abs(x[i]);
195         return (dmax);
196     }
197 }