View Javadoc

1   package de.bwaldvogel.liblinear;
2   
3   import static org.fest.assertions.Assertions.assertThat;
4   import static org.fest.assertions.Fail.fail;
5   import static org.mockito.Mockito.doThrow;
6   import static org.mockito.Mockito.times;
7   import static org.mockito.Mockito.verify;
8   
9   import java.io.File;
10  import java.io.IOException;
11  import java.io.Writer;
12  import java.util.ArrayList;
13  import java.util.Collections;
14  import java.util.List;
15  import java.util.Random;
16  import java.util.Set;
17  import java.util.TreeSet;
18  
19  import org.fest.assertions.Delta;
20  import org.junit.BeforeClass;
21  import org.junit.Test;
22  import org.powermock.api.mockito.PowerMockito;
23  
24  
25  public class LinearTest {
26  
27      private static Random random = new Random(0);
28  
29      @BeforeClass
30      public static void disableDebugOutput() {
31          Linear.disableDebugOutput();
32      }
33  
34      public static Model createRandomModel() {
35          Model model = new Model();
36          model.solverType = SolverType.L2R_LR;
37          model.bias = 2;
38          model.label = new int[] {1, Integer.MAX_VALUE, 2};
39          model.w = new double[model.label.length * 300];
40          for (int i = 0; i < model.w.length; i++) {
41              // precision should be at least 1e-4
42              model.w[i] = Math.round(random.nextDouble() * 100000.0) / 10000.0;
43          }
44  
45          // force at least one value to be zero
46          model.w[random.nextInt(model.w.length)] = 0.0;
47          model.w[random.nextInt(model.w.length)] = -0.0;
48  
49          model.nr_feature = model.w.length / model.label.length - 1;
50          model.nr_class = model.label.length;
51          return model;
52      }
53  
54      public static Problem createRandomProblem(int numClasses) {
55          Problem prob = new Problem();
56          prob.bias = -1;
57          prob.l = random.nextInt(100) + 1;
58          prob.n = random.nextInt(100) + 1;
59          prob.x = new FeatureNode[prob.l][];
60          prob.y = new int[prob.l];
61  
62          for (int i = 0; i < prob.l; i++) {
63  
64              prob.y[i] = random.nextInt(numClasses);
65  
66              Set<Integer> randomNumbers = new TreeSet<Integer>();
67              int num = random.nextInt(prob.n) + 1;
68              for (int j = 0; j < num; j++) {
69                  randomNumbers.add(random.nextInt(prob.n) + 1);
70              }
71              List<Integer> randomIndices = new ArrayList<Integer>(randomNumbers);
72              Collections.sort(randomIndices);
73  
74              prob.x[i] = new FeatureNode[randomIndices.size()];
75              for (int j = 0; j < randomIndices.size(); j++) {
76                  prob.x[i][j] = new FeatureNode(randomIndices.get(j), random.nextDouble());
77              }
78          }
79          return prob;
80      }
81  
82      @Test
83      public void testRealloc() {
84  
85          int[] f = new int[] {1, 2, 3};
86          f = Linear.copyOf(f, 5);
87          f[3] = 4;
88          f[4] = 5;
89          assertThat(f).isEqualTo(new int[] {1, 2, 3, 4, 5});
90      }
91  
92      @Test
93      public void testAtoi() {
94          assertThat(Linear.atoi("+25")).isEqualTo(25);
95          assertThat(Linear.atoi("-345345")).isEqualTo(-345345);
96          assertThat(Linear.atoi("+0")).isEqualTo(0);
97          assertThat(Linear.atoi("0")).isEqualTo(0);
98          assertThat(Linear.atoi("2147483647")).isEqualTo(Integer.MAX_VALUE);
99          assertThat(Linear.atoi("-2147483648")).isEqualTo(Integer.MIN_VALUE);
100     }
101 
102     @Test(expected = NumberFormatException.class)
103     public void testAtoiInvalidData() {
104         Linear.atoi("+");
105     }
106 
107     @Test(expected = NumberFormatException.class)
108     public void testAtoiInvalidData2() {
109         Linear.atoi("abc");
110     }
111 
112     @Test(expected = NumberFormatException.class)
113     public void testAtoiInvalidData3() {
114         Linear.atoi(" ");
115     }
116 
117     @Test
118     public void testAtof() {
119         assertThat(Linear.atof("+25")).isEqualTo(25);
120         assertThat(Linear.atof("-25.12345678")).isEqualTo(-25.12345678);
121         assertThat(Linear.atof("0.345345299")).isEqualTo(0.345345299);
122     }
123 
124     @Test(expected = NumberFormatException.class)
125     public void testAtofInvalidData() {
126         Linear.atof("0.5t");
127     }
128 
129     @Test
130     public void testLoadSaveModel() throws Exception {
131 
132         Model model = null;
133         for (SolverType solverType : SolverType.values()) {
134             model = createRandomModel();
135             model.solverType = solverType;
136 
137             File tempFile = File.createTempFile("liblinear", "modeltest");
138             tempFile.deleteOnExit();
139             Linear.saveModel(tempFile, model);
140 
141             Model loadedModel = Linear.loadModel(tempFile);
142             assertThat(loadedModel).isEqualTo(model);
143         }
144     }
145 
146     @Test
147     public void testCrossValidation() throws Exception {
148 
149         int numClasses = random.nextInt(10) + 1;
150 
151         Problem prob = createRandomProblem(numClasses);
152 
153         Parameter param = new Parameter(SolverType.L2R_LR, 10, 0.01);
154         int nr_fold = 10;
155         int[] target = new int[prob.l];
156         Linear.crossValidation(prob, param, nr_fold, target);
157 
158         for (int clazz : target) {
159             assertThat(clazz).isGreaterThanOrEqualTo(0).isLessThan(numClasses);
160         }
161     }
162 
163     @Test
164     public void testSaveModelWithIOException() throws Exception {
165         Model model = createRandomModel();
166 
167         Writer out = PowerMockito.mock(Writer.class);
168 
169         IOException ioException = new IOException("some reason");
170 
171         doThrow(ioException).when(out).flush();
172 
173         try {
174             Linear.saveModel(out, model);
175             fail("IOException expected");
176         } catch (IOException e) {
177             assertThat(e).isEqualTo(ioException);
178         }
179 
180         verify(out).flush();
181         verify(out, times(1)).close();
182     }
183 
184     /**
185      * compared input/output values with the C version (1.51)
186      *
187      * <pre>
188      * IN:
189      * res prob.l = 4
190      * res prob.n = 4
191      * 0: (2,1) (4,1)
192      * 1: (1,1)
193      * 2: (3,1)
194      * 3: (2,2) (3,1) (4,1)
195      *
196      * TRANSPOSED:
197      *
198      * res prob.l = 4
199      * res prob.n = 4
200      * 0: (2,1)
201      * 1: (1,1) (4,2)
202      * 2: (3,1) (4,1)
203      * 3: (1,1) (4,1)
204      * </pre>
205      */
206     @Test
207     public void testTranspose() throws Exception {
208         Problem prob = new Problem();
209         prob.bias = -1;
210         prob.l = 4;
211         prob.n = 4;
212         prob.x = new FeatureNode[4][];
213         prob.x[0] = new FeatureNode[2];
214         prob.x[1] = new FeatureNode[1];
215         prob.x[2] = new FeatureNode[1];
216         prob.x[3] = new FeatureNode[3];
217 
218         prob.x[0][0] = new FeatureNode(2, 1);
219         prob.x[0][1] = new FeatureNode(4, 1);
220 
221         prob.x[1][0] = new FeatureNode(1, 1);
222         prob.x[2][0] = new FeatureNode(3, 1);
223 
224         prob.x[3][0] = new FeatureNode(2, 2);
225         prob.x[3][1] = new FeatureNode(3, 1);
226         prob.x[3][2] = new FeatureNode(4, 1);
227 
228         prob.y = new int[4];
229         prob.y[0] = 0;
230         prob.y[1] = 1;
231         prob.y[2] = 1;
232         prob.y[3] = 0;
233 
234         Problem transposed = Linear.transpose(prob);
235 
236         assertThat(transposed.x[0].length).isEqualTo(1);
237         assertThat(transposed.x[1].length).isEqualTo(2);
238         assertThat(transposed.x[2].length).isEqualTo(2);
239         assertThat(transposed.x[3].length).isEqualTo(2);
240 
241         assertThat(transposed.x[0][0]).isEqualTo(new FeatureNode(2, 1));
242 
243         assertThat(transposed.x[1][0]).isEqualTo(new FeatureNode(1, 1));
244         assertThat(transposed.x[1][1]).isEqualTo(new FeatureNode(4, 2));
245 
246         assertThat(transposed.x[2][0]).isEqualTo(new FeatureNode(3, 1));
247         assertThat(transposed.x[2][1]).isEqualTo(new FeatureNode(4, 1));
248 
249         assertThat(transposed.x[3][0]).isEqualTo(new FeatureNode(1, 1));
250         assertThat(transposed.x[3][1]).isEqualTo(new FeatureNode(4, 1));
251 
252         assertThat(transposed.y).isEqualTo(prob.y);
253     }
254 
255     /**
256      *
257      * compared input/output values with the C version (1.51)
258      *
259      * <pre>
260      * IN:
261      * res prob.l = 5
262      * res prob.n = 10
263      * 0: (1,7) (3,3) (5,2)
264      * 1: (2,1) (4,5) (5,3) (7,4) (8,2)
265      * 2: (1,9) (3,1) (5,1) (10,7)
266      * 3: (1,2) (2,2) (3,9) (4,7) (5,8) (6,1) (7,5) (8,4)
267      * 4: (3,1) (10,3)
268      *
269      * TRANSPOSED:
270      *
271      * res prob.l = 5
272      * res prob.n = 10
273      * 0: (1,7) (3,9) (4,2)
274      * 1: (2,1) (4,2)
275      * 2: (1,3) (3,1) (4,9) (5,1)
276      * 3: (2,5) (4,7)
277      * 4: (1,2) (2,3) (3,1) (4,8)
278      * 5: (4,1)
279      * 6: (2,4) (4,5)
280      * 7: (2,2) (4,4)
281      * 8:
282      * 9: (3,7) (5,3)
283      * </pre>
284      */
285     @Test
286     public void testTranspose2() throws Exception {
287         Problem prob = new Problem();
288         prob.bias = -1;
289         prob.l = 5;
290         prob.n = 10;
291         prob.x = new FeatureNode[5][];
292         prob.x[0] = new FeatureNode[3];
293         prob.x[1] = new FeatureNode[5];
294         prob.x[2] = new FeatureNode[4];
295         prob.x[3] = new FeatureNode[8];
296         prob.x[4] = new FeatureNode[2];
297 
298         prob.x[0][0] = new FeatureNode(1, 7);
299         prob.x[0][1] = new FeatureNode(3, 3);
300         prob.x[0][2] = new FeatureNode(5, 2);
301 
302         prob.x[1][0] = new FeatureNode(2, 1);
303         prob.x[1][1] = new FeatureNode(4, 5);
304         prob.x[1][2] = new FeatureNode(5, 3);
305         prob.x[1][3] = new FeatureNode(7, 4);
306         prob.x[1][4] = new FeatureNode(8, 2);
307 
308         prob.x[2][0] = new FeatureNode(1, 9);
309         prob.x[2][1] = new FeatureNode(3, 1);
310         prob.x[2][2] = new FeatureNode(5, 1);
311         prob.x[2][3] = new FeatureNode(10, 7);
312 
313         prob.x[3][0] = new FeatureNode(1, 2);
314         prob.x[3][1] = new FeatureNode(2, 2);
315         prob.x[3][2] = new FeatureNode(3, 9);
316         prob.x[3][3] = new FeatureNode(4, 7);
317         prob.x[3][4] = new FeatureNode(5, 8);
318         prob.x[3][5] = new FeatureNode(6, 1);
319         prob.x[3][6] = new FeatureNode(7, 5);
320         prob.x[3][7] = new FeatureNode(8, 4);
321 
322         prob.x[4][0] = new FeatureNode(3, 1);
323         prob.x[4][1] = new FeatureNode(10, 3);
324 
325         prob.y = new int[5];
326         prob.y[0] = 0;
327         prob.y[1] = 1;
328         prob.y[2] = 1;
329         prob.y[3] = 0;
330         prob.y[4] = 1;
331 
332         Problem transposed = Linear.transpose(prob);
333 
334         assertThat(transposed.x[0]).hasSize(3);
335         assertThat(transposed.x[1]).hasSize(2);
336         assertThat(transposed.x[2]).hasSize(4);
337         assertThat(transposed.x[3]).hasSize(2);
338         assertThat(transposed.x[4]).hasSize(4);
339         assertThat(transposed.x[5]).hasSize(1);
340         assertThat(transposed.x[7]).hasSize(2);
341         assertThat(transposed.x[7]).hasSize(2);
342         assertThat(transposed.x[8]).hasSize(0);
343         assertThat(transposed.x[9]).hasSize(2);
344 
345         assertThat(transposed.x[0][0]).isEqualTo(new FeatureNode(1, 7));
346         assertThat(transposed.x[0][1]).isEqualTo(new FeatureNode(3, 9));
347         assertThat(transposed.x[0][2]).isEqualTo(new FeatureNode(4, 2));
348 
349         assertThat(transposed.x[1][0]).isEqualTo(new FeatureNode(2, 1));
350         assertThat(transposed.x[1][1]).isEqualTo(new FeatureNode(4, 2));
351 
352         assertThat(transposed.x[2][0]).isEqualTo(new FeatureNode(1, 3));
353         assertThat(transposed.x[2][1]).isEqualTo(new FeatureNode(3, 1));
354         assertThat(transposed.x[2][2]).isEqualTo(new FeatureNode(4, 9));
355         assertThat(transposed.x[2][3]).isEqualTo(new FeatureNode(5, 1));
356 
357         assertThat(transposed.x[3][0]).isEqualTo(new FeatureNode(2, 5));
358         assertThat(transposed.x[3][1]).isEqualTo(new FeatureNode(4, 7));
359 
360         assertThat(transposed.x[4][0]).isEqualTo(new FeatureNode(1, 2));
361         assertThat(transposed.x[4][1]).isEqualTo(new FeatureNode(2, 3));
362         assertThat(transposed.x[4][2]).isEqualTo(new FeatureNode(3, 1));
363         assertThat(transposed.x[4][3]).isEqualTo(new FeatureNode(4, 8));
364 
365         assertThat(transposed.x[5][0]).isEqualTo(new FeatureNode(4, 1));
366 
367         assertThat(transposed.x[6][0]).isEqualTo(new FeatureNode(2, 4));
368         assertThat(transposed.x[6][1]).isEqualTo(new FeatureNode(4, 5));
369 
370         assertThat(transposed.x[7][0]).isEqualTo(new FeatureNode(2, 2));
371         assertThat(transposed.x[7][1]).isEqualTo(new FeatureNode(4, 4));
372 
373         assertThat(transposed.x[9][0]).isEqualTo(new FeatureNode(3, 7));
374         assertThat(transposed.x[9][1]).isEqualTo(new FeatureNode(5, 3));
375 
376         assertThat(transposed.y).isEqualTo(prob.y);
377     }
378 
379     /**
380      * compared input/output values with the C version (1.51)
381      *
382      * IN:
383      * res prob.l = 3
384      * res prob.n = 4
385      * 0: (1,2) (3,1) (4,3)
386      * 1: (1,9) (2,7) (3,3) (4,3)
387      * 2: (2,1)
388      *
389      * TRANSPOSED:
390      *
391      * res prob.l = 3
392      *      * res prob.n = 4
393      * 0: (1,2) (2,9)
394      * 1: (2,7) (3,1)
395      * 2: (1,1) (2,3)
396      * 3: (1,3) (2,3)
397      *
398      */
399     @Test
400     public void testTranspose3() throws Exception {
401 
402         Problem prob = new Problem();
403         prob.l = 3;
404         prob.n = 4;
405         prob.y = new int[3];
406         prob.x = new FeatureNode[4][];
407         prob.x[0] = new FeatureNode[3];
408         prob.x[1] = new FeatureNode[4];
409         prob.x[2] = new FeatureNode[1];
410         prob.x[3] = new FeatureNode[1];
411 
412         prob.x[0][0] = new FeatureNode(1, 2);
413         prob.x[0][1] = new FeatureNode(3, 1);
414         prob.x[0][2] = new FeatureNode(4, 3);
415         prob.x[1][0] = new FeatureNode(1, 9);
416         prob.x[1][1] = new FeatureNode(2, 7);
417         prob.x[1][2] = new FeatureNode(3, 3);
418         prob.x[1][3] = new FeatureNode(4, 3);
419 
420         prob.x[2][0] = new FeatureNode(2, 1);
421 
422         prob.x[3][0] = new FeatureNode(3, 2);
423 
424         Problem transposed = Linear.transpose(prob);
425         assertThat(transposed.x).hasSize(4);
426         assertThat(transposed.x[0]).hasSize(2);
427         assertThat(transposed.x[1]).hasSize(2);
428         assertThat(transposed.x[2]).hasSize(2);
429         assertThat(transposed.x[3]).hasSize(2);
430 
431         assertThat(transposed.x[0][0]).isEqualTo(new FeatureNode(1, 2));
432         assertThat(transposed.x[0][1]).isEqualTo(new FeatureNode(2, 9));
433 
434         assertThat(transposed.x[1][0]).isEqualTo(new FeatureNode(2, 7));
435         assertThat(transposed.x[1][1]).isEqualTo(new FeatureNode(3, 1));
436 
437         assertThat(transposed.x[2][0]).isEqualTo(new FeatureNode(1, 1));
438         assertThat(transposed.x[2][1]).isEqualTo(new FeatureNode(2, 3));
439 
440         assertThat(transposed.x[3][0]).isEqualTo(new FeatureNode(1, 3));
441         assertThat(transposed.x[3][1]).isEqualTo(new FeatureNode(2, 3));
442     }
443 
444     /**
445      * create a very simple problem and check if the clearly separated examples are recognized as such
446      */
447     @Test
448     public void testTrainPredict() {
449         Problem prob = new Problem();
450         prob.bias = -1;
451         prob.l = 4;
452         prob.n = 4;
453         prob.x = new FeatureNode[4][];
454         prob.x[0] = new FeatureNode[2];
455         prob.x[1] = new FeatureNode[1];
456         prob.x[2] = new FeatureNode[1];
457         prob.x[3] = new FeatureNode[3];
458 
459         prob.x[0][0] = new FeatureNode(1, 1);
460         prob.x[0][1] = new FeatureNode(2, 1);
461 
462         prob.x[1][0] = new FeatureNode(3, 1);
463         prob.x[2][0] = new FeatureNode(3, 1);
464 
465         prob.x[3][0] = new FeatureNode(1, 2);
466         prob.x[3][1] = new FeatureNode(2, 1);
467         prob.x[3][2] = new FeatureNode(4, 1);
468 
469         prob.y = new int[4];
470         prob.y[0] = 0;
471         prob.y[1] = 1;
472         prob.y[2] = 1;
473         prob.y[3] = 0;
474 
475         for (SolverType solver : SolverType.values()) {
476             for (double C = 0.1; C <= 100.; C *= 1.2) {
477 
478                 // compared the behavior with the C version
479                 if (C < 0.2) if (solver == SolverType.L1R_L2LOSS_SVC) continue;
480                 if (C < 0.7) if (solver == SolverType.L1R_LR) continue;
481 
482                 Parameter param = new Parameter(solver, C, 0.1);
483                 Model model = Linear.train(prob, param);
484 
485                 double[] featureWeights = model.getFeatureWeights();
486                 if (solver == SolverType.MCSVM_CS) {
487                     assertThat(featureWeights.length).isEqualTo(8);
488                 } else {
489                     assertThat(featureWeights.length).isEqualTo(4);
490                 }
491 
492                 int i = 0;
493                 for (int value : prob.y) {
494                     int prediction = Linear.predict(model, prob.x[i]);
495                     assertThat(prediction).isEqualTo(value);
496                     if (model.isProbabilityModel()) {
497                         double[] estimates = new double[model.getNrClass()];
498                         int probabilityPrediction = Linear.predictProbability(model, prob.x[i], estimates);
499                         assertThat(probabilityPrediction).isEqualTo(prediction);
500                         assertThat(estimates[probabilityPrediction]).isGreaterThanOrEqualTo(1.0 / model.getNrClass());
501                         double estimationSum = 0;
502                         for (double estimate : estimates) {
503                             estimationSum += estimate;
504                         }
505                         assertThat(estimationSum).isEqualTo(1.0, Delta.delta(0.001));
506                     }
507                     i++;
508                 }
509             }
510         }
511     }
512 
513     @Test
514     public void testTrainUnsortedProblem() {
515         Problem prob = new Problem();
516         prob.bias = -1;
517         prob.l = 1;
518         prob.n = 2;
519         prob.x = new FeatureNode[4][];
520         prob.x[0] = new FeatureNode[2];
521 
522         prob.x[0][0] = new FeatureNode(2, 1);
523         prob.x[0][1] = new FeatureNode(1, 1);
524 
525         prob.y = new int[4];
526         prob.y[0] = 0;
527 
528         Parameter param = new Parameter(SolverType.L2R_LR, 10, 0.1);
529         try {
530             Linear.train(prob, param);
531             fail("IllegalArgumentException expected");
532         } catch (IllegalArgumentException e) {
533             assertThat(e.getMessage()).contains("nodes").contains("sorted").contains("ascending").contains("order");
534         }
535     }
536 }