1 package de.bwaldvogel.liblinear;
2
3 import static org.fest.assertions.Assertions.assertThat;
4 import static org.fest.assertions.Fail.fail;
5 import static org.mockito.Mockito.doThrow;
6 import static org.mockito.Mockito.times;
7 import static org.mockito.Mockito.verify;
8
9 import java.io.File;
10 import java.io.IOException;
11 import java.io.Writer;
12 import java.util.ArrayList;
13 import java.util.Collections;
14 import java.util.List;
15 import java.util.Random;
16 import java.util.Set;
17 import java.util.TreeSet;
18
19 import org.fest.assertions.Delta;
20 import org.junit.BeforeClass;
21 import org.junit.Test;
22 import org.powermock.api.mockito.PowerMockito;
23
24
25 public class LinearTest {
26
27 private static Random random = new Random(0);
28
29 @BeforeClass
30 public static void disableDebugOutput() {
31 Linear.disableDebugOutput();
32 }
33
34 public static Model createRandomModel() {
35 Model model = new Model();
36 model.solverType = SolverType.L2R_LR;
37 model.bias = 2;
38 model.label = new int[] {1, Integer.MAX_VALUE, 2};
39 model.w = new double[model.label.length * 300];
40 for (int i = 0; i < model.w.length; i++) {
41
42 model.w[i] = Math.round(random.nextDouble() * 100000.0) / 10000.0;
43 }
44
45
46 model.w[random.nextInt(model.w.length)] = 0.0;
47 model.w[random.nextInt(model.w.length)] = -0.0;
48
49 model.nr_feature = model.w.length / model.label.length - 1;
50 model.nr_class = model.label.length;
51 return model;
52 }
53
54 public static Problem createRandomProblem(int numClasses) {
55 Problem prob = new Problem();
56 prob.bias = -1;
57 prob.l = random.nextInt(100) + 1;
58 prob.n = random.nextInt(100) + 1;
59 prob.x = new FeatureNode[prob.l][];
60 prob.y = new int[prob.l];
61
62 for (int i = 0; i < prob.l; i++) {
63
64 prob.y[i] = random.nextInt(numClasses);
65
66 Set<Integer> randomNumbers = new TreeSet<Integer>();
67 int num = random.nextInt(prob.n) + 1;
68 for (int j = 0; j < num; j++) {
69 randomNumbers.add(random.nextInt(prob.n) + 1);
70 }
71 List<Integer> randomIndices = new ArrayList<Integer>(randomNumbers);
72 Collections.sort(randomIndices);
73
74 prob.x[i] = new FeatureNode[randomIndices.size()];
75 for (int j = 0; j < randomIndices.size(); j++) {
76 prob.x[i][j] = new FeatureNode(randomIndices.get(j), random.nextDouble());
77 }
78 }
79 return prob;
80 }
81
82 @Test
83 public void testRealloc() {
84
85 int[] f = new int[] {1, 2, 3};
86 f = Linear.copyOf(f, 5);
87 f[3] = 4;
88 f[4] = 5;
89 assertThat(f).isEqualTo(new int[] {1, 2, 3, 4, 5});
90 }
91
92 @Test
93 public void testAtoi() {
94 assertThat(Linear.atoi("+25")).isEqualTo(25);
95 assertThat(Linear.atoi("-345345")).isEqualTo(-345345);
96 assertThat(Linear.atoi("+0")).isEqualTo(0);
97 assertThat(Linear.atoi("0")).isEqualTo(0);
98 assertThat(Linear.atoi("2147483647")).isEqualTo(Integer.MAX_VALUE);
99 assertThat(Linear.atoi("-2147483648")).isEqualTo(Integer.MIN_VALUE);
100 }
101
102 @Test(expected = NumberFormatException.class)
103 public void testAtoiInvalidData() {
104 Linear.atoi("+");
105 }
106
107 @Test(expected = NumberFormatException.class)
108 public void testAtoiInvalidData2() {
109 Linear.atoi("abc");
110 }
111
112 @Test(expected = NumberFormatException.class)
113 public void testAtoiInvalidData3() {
114 Linear.atoi(" ");
115 }
116
117 @Test
118 public void testAtof() {
119 assertThat(Linear.atof("+25")).isEqualTo(25);
120 assertThat(Linear.atof("-25.12345678")).isEqualTo(-25.12345678);
121 assertThat(Linear.atof("0.345345299")).isEqualTo(0.345345299);
122 }
123
124 @Test(expected = NumberFormatException.class)
125 public void testAtofInvalidData() {
126 Linear.atof("0.5t");
127 }
128
129 @Test
130 public void testLoadSaveModel() throws Exception {
131
132 Model model = null;
133 for (SolverType solverType : SolverType.values()) {
134 model = createRandomModel();
135 model.solverType = solverType;
136
137 File tempFile = File.createTempFile("liblinear", "modeltest");
138 tempFile.deleteOnExit();
139 Linear.saveModel(tempFile, model);
140
141 Model loadedModel = Linear.loadModel(tempFile);
142 assertThat(loadedModel).isEqualTo(model);
143 }
144 }
145
146 @Test
147 public void testCrossValidation() throws Exception {
148
149 int numClasses = random.nextInt(10) + 1;
150
151 Problem prob = createRandomProblem(numClasses);
152
153 Parameter param = new Parameter(SolverType.L2R_LR, 10, 0.01);
154 int nr_fold = 10;
155 int[] target = new int[prob.l];
156 Linear.crossValidation(prob, param, nr_fold, target);
157
158 for (int clazz : target) {
159 assertThat(clazz).isGreaterThanOrEqualTo(0).isLessThan(numClasses);
160 }
161 }
162
163 @Test
164 public void testSaveModelWithIOException() throws Exception {
165 Model model = createRandomModel();
166
167 Writer out = PowerMockito.mock(Writer.class);
168
169 IOException ioException = new IOException("some reason");
170
171 doThrow(ioException).when(out).flush();
172
173 try {
174 Linear.saveModel(out, model);
175 fail("IOException expected");
176 } catch (IOException e) {
177 assertThat(e).isEqualTo(ioException);
178 }
179
180 verify(out).flush();
181 verify(out, times(1)).close();
182 }
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206 @Test
207 public void testTranspose() throws Exception {
208 Problem prob = new Problem();
209 prob.bias = -1;
210 prob.l = 4;
211 prob.n = 4;
212 prob.x = new FeatureNode[4][];
213 prob.x[0] = new FeatureNode[2];
214 prob.x[1] = new FeatureNode[1];
215 prob.x[2] = new FeatureNode[1];
216 prob.x[3] = new FeatureNode[3];
217
218 prob.x[0][0] = new FeatureNode(2, 1);
219 prob.x[0][1] = new FeatureNode(4, 1);
220
221 prob.x[1][0] = new FeatureNode(1, 1);
222 prob.x[2][0] = new FeatureNode(3, 1);
223
224 prob.x[3][0] = new FeatureNode(2, 2);
225 prob.x[3][1] = new FeatureNode(3, 1);
226 prob.x[3][2] = new FeatureNode(4, 1);
227
228 prob.y = new int[4];
229 prob.y[0] = 0;
230 prob.y[1] = 1;
231 prob.y[2] = 1;
232 prob.y[3] = 0;
233
234 Problem transposed = Linear.transpose(prob);
235
236 assertThat(transposed.x[0].length).isEqualTo(1);
237 assertThat(transposed.x[1].length).isEqualTo(2);
238 assertThat(transposed.x[2].length).isEqualTo(2);
239 assertThat(transposed.x[3].length).isEqualTo(2);
240
241 assertThat(transposed.x[0][0]).isEqualTo(new FeatureNode(2, 1));
242
243 assertThat(transposed.x[1][0]).isEqualTo(new FeatureNode(1, 1));
244 assertThat(transposed.x[1][1]).isEqualTo(new FeatureNode(4, 2));
245
246 assertThat(transposed.x[2][0]).isEqualTo(new FeatureNode(3, 1));
247 assertThat(transposed.x[2][1]).isEqualTo(new FeatureNode(4, 1));
248
249 assertThat(transposed.x[3][0]).isEqualTo(new FeatureNode(1, 1));
250 assertThat(transposed.x[3][1]).isEqualTo(new FeatureNode(4, 1));
251
252 assertThat(transposed.y).isEqualTo(prob.y);
253 }
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285 @Test
286 public void testTranspose2() throws Exception {
287 Problem prob = new Problem();
288 prob.bias = -1;
289 prob.l = 5;
290 prob.n = 10;
291 prob.x = new FeatureNode[5][];
292 prob.x[0] = new FeatureNode[3];
293 prob.x[1] = new FeatureNode[5];
294 prob.x[2] = new FeatureNode[4];
295 prob.x[3] = new FeatureNode[8];
296 prob.x[4] = new FeatureNode[2];
297
298 prob.x[0][0] = new FeatureNode(1, 7);
299 prob.x[0][1] = new FeatureNode(3, 3);
300 prob.x[0][2] = new FeatureNode(5, 2);
301
302 prob.x[1][0] = new FeatureNode(2, 1);
303 prob.x[1][1] = new FeatureNode(4, 5);
304 prob.x[1][2] = new FeatureNode(5, 3);
305 prob.x[1][3] = new FeatureNode(7, 4);
306 prob.x[1][4] = new FeatureNode(8, 2);
307
308 prob.x[2][0] = new FeatureNode(1, 9);
309 prob.x[2][1] = new FeatureNode(3, 1);
310 prob.x[2][2] = new FeatureNode(5, 1);
311 prob.x[2][3] = new FeatureNode(10, 7);
312
313 prob.x[3][0] = new FeatureNode(1, 2);
314 prob.x[3][1] = new FeatureNode(2, 2);
315 prob.x[3][2] = new FeatureNode(3, 9);
316 prob.x[3][3] = new FeatureNode(4, 7);
317 prob.x[3][4] = new FeatureNode(5, 8);
318 prob.x[3][5] = new FeatureNode(6, 1);
319 prob.x[3][6] = new FeatureNode(7, 5);
320 prob.x[3][7] = new FeatureNode(8, 4);
321
322 prob.x[4][0] = new FeatureNode(3, 1);
323 prob.x[4][1] = new FeatureNode(10, 3);
324
325 prob.y = new int[5];
326 prob.y[0] = 0;
327 prob.y[1] = 1;
328 prob.y[2] = 1;
329 prob.y[3] = 0;
330 prob.y[4] = 1;
331
332 Problem transposed = Linear.transpose(prob);
333
334 assertThat(transposed.x[0]).hasSize(3);
335 assertThat(transposed.x[1]).hasSize(2);
336 assertThat(transposed.x[2]).hasSize(4);
337 assertThat(transposed.x[3]).hasSize(2);
338 assertThat(transposed.x[4]).hasSize(4);
339 assertThat(transposed.x[5]).hasSize(1);
340 assertThat(transposed.x[7]).hasSize(2);
341 assertThat(transposed.x[7]).hasSize(2);
342 assertThat(transposed.x[8]).hasSize(0);
343 assertThat(transposed.x[9]).hasSize(2);
344
345 assertThat(transposed.x[0][0]).isEqualTo(new FeatureNode(1, 7));
346 assertThat(transposed.x[0][1]).isEqualTo(new FeatureNode(3, 9));
347 assertThat(transposed.x[0][2]).isEqualTo(new FeatureNode(4, 2));
348
349 assertThat(transposed.x[1][0]).isEqualTo(new FeatureNode(2, 1));
350 assertThat(transposed.x[1][1]).isEqualTo(new FeatureNode(4, 2));
351
352 assertThat(transposed.x[2][0]).isEqualTo(new FeatureNode(1, 3));
353 assertThat(transposed.x[2][1]).isEqualTo(new FeatureNode(3, 1));
354 assertThat(transposed.x[2][2]).isEqualTo(new FeatureNode(4, 9));
355 assertThat(transposed.x[2][3]).isEqualTo(new FeatureNode(5, 1));
356
357 assertThat(transposed.x[3][0]).isEqualTo(new FeatureNode(2, 5));
358 assertThat(transposed.x[3][1]).isEqualTo(new FeatureNode(4, 7));
359
360 assertThat(transposed.x[4][0]).isEqualTo(new FeatureNode(1, 2));
361 assertThat(transposed.x[4][1]).isEqualTo(new FeatureNode(2, 3));
362 assertThat(transposed.x[4][2]).isEqualTo(new FeatureNode(3, 1));
363 assertThat(transposed.x[4][3]).isEqualTo(new FeatureNode(4, 8));
364
365 assertThat(transposed.x[5][0]).isEqualTo(new FeatureNode(4, 1));
366
367 assertThat(transposed.x[6][0]).isEqualTo(new FeatureNode(2, 4));
368 assertThat(transposed.x[6][1]).isEqualTo(new FeatureNode(4, 5));
369
370 assertThat(transposed.x[7][0]).isEqualTo(new FeatureNode(2, 2));
371 assertThat(transposed.x[7][1]).isEqualTo(new FeatureNode(4, 4));
372
373 assertThat(transposed.x[9][0]).isEqualTo(new FeatureNode(3, 7));
374 assertThat(transposed.x[9][1]).isEqualTo(new FeatureNode(5, 3));
375
376 assertThat(transposed.y).isEqualTo(prob.y);
377 }
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399 @Test
400 public void testTranspose3() throws Exception {
401
402 Problem prob = new Problem();
403 prob.l = 3;
404 prob.n = 4;
405 prob.y = new int[3];
406 prob.x = new FeatureNode[4][];
407 prob.x[0] = new FeatureNode[3];
408 prob.x[1] = new FeatureNode[4];
409 prob.x[2] = new FeatureNode[1];
410 prob.x[3] = new FeatureNode[1];
411
412 prob.x[0][0] = new FeatureNode(1, 2);
413 prob.x[0][1] = new FeatureNode(3, 1);
414 prob.x[0][2] = new FeatureNode(4, 3);
415 prob.x[1][0] = new FeatureNode(1, 9);
416 prob.x[1][1] = new FeatureNode(2, 7);
417 prob.x[1][2] = new FeatureNode(3, 3);
418 prob.x[1][3] = new FeatureNode(4, 3);
419
420 prob.x[2][0] = new FeatureNode(2, 1);
421
422 prob.x[3][0] = new FeatureNode(3, 2);
423
424 Problem transposed = Linear.transpose(prob);
425 assertThat(transposed.x).hasSize(4);
426 assertThat(transposed.x[0]).hasSize(2);
427 assertThat(transposed.x[1]).hasSize(2);
428 assertThat(transposed.x[2]).hasSize(2);
429 assertThat(transposed.x[3]).hasSize(2);
430
431 assertThat(transposed.x[0][0]).isEqualTo(new FeatureNode(1, 2));
432 assertThat(transposed.x[0][1]).isEqualTo(new FeatureNode(2, 9));
433
434 assertThat(transposed.x[1][0]).isEqualTo(new FeatureNode(2, 7));
435 assertThat(transposed.x[1][1]).isEqualTo(new FeatureNode(3, 1));
436
437 assertThat(transposed.x[2][0]).isEqualTo(new FeatureNode(1, 1));
438 assertThat(transposed.x[2][1]).isEqualTo(new FeatureNode(2, 3));
439
440 assertThat(transposed.x[3][0]).isEqualTo(new FeatureNode(1, 3));
441 assertThat(transposed.x[3][1]).isEqualTo(new FeatureNode(2, 3));
442 }
443
444
445
446
447 @Test
448 public void testTrainPredict() {
449 Problem prob = new Problem();
450 prob.bias = -1;
451 prob.l = 4;
452 prob.n = 4;
453 prob.x = new FeatureNode[4][];
454 prob.x[0] = new FeatureNode[2];
455 prob.x[1] = new FeatureNode[1];
456 prob.x[2] = new FeatureNode[1];
457 prob.x[3] = new FeatureNode[3];
458
459 prob.x[0][0] = new FeatureNode(1, 1);
460 prob.x[0][1] = new FeatureNode(2, 1);
461
462 prob.x[1][0] = new FeatureNode(3, 1);
463 prob.x[2][0] = new FeatureNode(3, 1);
464
465 prob.x[3][0] = new FeatureNode(1, 2);
466 prob.x[3][1] = new FeatureNode(2, 1);
467 prob.x[3][2] = new FeatureNode(4, 1);
468
469 prob.y = new int[4];
470 prob.y[0] = 0;
471 prob.y[1] = 1;
472 prob.y[2] = 1;
473 prob.y[3] = 0;
474
475 for (SolverType solver : SolverType.values()) {
476 for (double C = 0.1; C <= 100.; C *= 1.2) {
477
478
479 if (C < 0.2) if (solver == SolverType.L1R_L2LOSS_SVC) continue;
480 if (C < 0.7) if (solver == SolverType.L1R_LR) continue;
481
482 Parameter param = new Parameter(solver, C, 0.1);
483 Model model = Linear.train(prob, param);
484
485 double[] featureWeights = model.getFeatureWeights();
486 if (solver == SolverType.MCSVM_CS) {
487 assertThat(featureWeights.length).isEqualTo(8);
488 } else {
489 assertThat(featureWeights.length).isEqualTo(4);
490 }
491
492 int i = 0;
493 for (int value : prob.y) {
494 int prediction = Linear.predict(model, prob.x[i]);
495 assertThat(prediction).isEqualTo(value);
496 if (model.isProbabilityModel()) {
497 double[] estimates = new double[model.getNrClass()];
498 int probabilityPrediction = Linear.predictProbability(model, prob.x[i], estimates);
499 assertThat(probabilityPrediction).isEqualTo(prediction);
500 assertThat(estimates[probabilityPrediction]).isGreaterThanOrEqualTo(1.0 / model.getNrClass());
501 double estimationSum = 0;
502 for (double estimate : estimates) {
503 estimationSum += estimate;
504 }
505 assertThat(estimationSum).isEqualTo(1.0, Delta.delta(0.001));
506 }
507 i++;
508 }
509 }
510 }
511 }
512
513 @Test
514 public void testTrainUnsortedProblem() {
515 Problem prob = new Problem();
516 prob.bias = -1;
517 prob.l = 1;
518 prob.n = 2;
519 prob.x = new FeatureNode[4][];
520 prob.x[0] = new FeatureNode[2];
521
522 prob.x[0][0] = new FeatureNode(2, 1);
523 prob.x[0][1] = new FeatureNode(1, 1);
524
525 prob.y = new int[4];
526 prob.y[0] = 0;
527
528 Parameter param = new Parameter(SolverType.L2R_LR, 10, 0.1);
529 try {
530 Linear.train(prob, param);
531 fail("IllegalArgumentException expected");
532 } catch (IllegalArgumentException e) {
533 assertThat(e.getMessage()).contains("nodes").contains("sorted").contains("ascending").contains("order");
534 }
535 }
536 }