View Javadoc

1   package de.bwaldvogel.liblinear;
2   
3   import static de.bwaldvogel.liblinear.Linear.NL;
4   import static org.fest.assertions.Assertions.assertThat;
5   
6   import java.io.BufferedWriter;
7   import java.io.File;
8   import java.io.FileWriter;
9   import java.util.ArrayList;
10  import java.util.Collection;
11  
12  import org.junit.Test;
13  
14  
15  public class TrainTest {
16  
17      @Test
18      public void testParseCommandLine() {
19          Train train = new Train();
20  
21          for (SolverType solver : SolverType.values()) {
22              train.parse_command_line(new String[] {"-B", "5.3", "-s", "" + solver.ordinal(), "model-filename"});
23              Parameter param = train.getParameter();
24              assertThat(param.solverType).isEqualTo(solver);
25              // check default eps
26              if (solver.ordinal() == 0 || solver.ordinal() == 2 //
27                  || solver.ordinal() == 5 || solver.ordinal() == 6) {
28                  assertThat(param.eps).isEqualTo(0.01);
29              } else {
30                  assertThat(param.eps).isEqualTo(0.1);
31              }
32              // check if bias is set
33              assertThat(train.getBias()).isEqualTo(5.3);
34          }
35      }
36  
37      @Test
38      public void testReadProblem() throws Exception {
39  
40          File file = File.createTempFile("svm", "test");
41          file.deleteOnExit();
42  
43          Collection<String> lines = new ArrayList<String>();
44          lines.add("1 1:1  3:1  4:1   6:1");
45          lines.add("2 2:1  3:1  5:1   7:1");
46          lines.add("1 3:1  5:1");
47          lines.add("1 1:1  4:1  7:1");
48          lines.add("2 4:1  5:1  7:1");
49          BufferedWriter writer = new BufferedWriter(new FileWriter(file));
50          try {
51              for (String line : lines)
52                  writer.append(line).append(NL);
53          }
54          finally {
55              writer.close();
56          }
57  
58          Train train = new Train();
59          train.readProblem(file.getAbsolutePath());
60  
61          Problem prob = train.getProblem();
62          assertThat(prob.bias).isEqualTo(1);
63          assertThat(prob.y).hasSize(lines.size());
64          assertThat(prob.y).isEqualTo(new int[] {1, 2, 1, 1, 2});
65          assertThat(prob.n).isEqualTo(8);
66          assertThat(prob.l).isEqualTo(prob.y.length);
67          assertThat(prob.x).hasSize(prob.y.length);
68  
69          for (FeatureNode[] nodes : prob.x) {
70  
71              assertThat(nodes.length).isLessThanOrEqualTo(prob.n);
72              for (FeatureNode node : nodes) {
73                  // bias term
74                  if (prob.bias >= 0 && nodes[nodes.length - 1] == node) {
75                      assertThat(node.index).isEqualTo(prob.n);
76                      assertThat(node.value).isEqualTo(prob.bias);
77                  } else {
78                      assertThat(node.index).isLessThan(prob.n);
79                  }
80              }
81          }
82      }
83  
84      /**
85       * unit-test for Issue #1 (http://github.com/bwaldvogel/liblinear-java/issues#issue/1)
86       */
87      @Test
88      public void testReadProblemEmptyLine() throws Exception {
89  
90          File file = File.createTempFile("svm", "test");
91          file.deleteOnExit();
92  
93          Collection<String> lines = new ArrayList<String>();
94          lines.add("1 1:1  3:1  4:1   6:1");
95          lines.add("2 ");
96          BufferedWriter writer = new BufferedWriter(new FileWriter(file));
97          try {
98              for (String line : lines)
99                  writer.append(line).append(NL);
100         }
101         finally {
102             writer.close();
103         }
104 
105         Problem prob = Train.readProblem(file, -1.0);
106         assertThat(prob.bias).isEqualTo(-1);
107         assertThat(prob.y).hasSize(lines.size());
108         assertThat(prob.y).isEqualTo(new int[] {1, 2});
109         assertThat(prob.n).isEqualTo(6);
110         assertThat(prob.l).isEqualTo(prob.y.length);
111         assertThat(prob.x).hasSize(prob.y.length);
112 
113         assertThat(prob.x[0]).hasSize(4);
114         assertThat(prob.x[1]).hasSize(0);
115     }
116 
117     @Test(expected = InvalidInputDataException.class)
118     public void testReadUnsortedProblem() throws Exception {
119         File file = File.createTempFile("svm", "test");
120         file.deleteOnExit();
121 
122         Collection<String> lines = new ArrayList<String>();
123         lines.add("1 1:1  3:1  4:1   6:1");
124         lines.add("2 2:1  3:1  5:1   7:1");
125         lines.add("1 3:1  5:1  4:1"); // here's the mistake: not correctly sorted
126 
127         BufferedWriter writer = new BufferedWriter(new FileWriter(file));
128         try {
129             for (String line : lines)
130                 writer.append(line).append(NL);
131         }
132         finally {
133             writer.close();
134         }
135 
136         Train train = new Train();
137         train.readProblem(file.getAbsolutePath());
138     }
139 
140 
141     @Test(expected = InvalidInputDataException.class)
142     public void testReadProblemWithInvalidIndex() throws Exception {
143         File file = File.createTempFile("svm", "test");
144         file.deleteOnExit();
145 
146         Collection<String> lines = new ArrayList<String>();
147         lines.add("1 1:1  3:1  4:1   6:1");
148         lines.add("2 2:1  3:1  5:1  -4:1");
149 
150         BufferedWriter writer = new BufferedWriter(new FileWriter(file));
151         try {
152             for (String line : lines)
153                 writer.append(line).append(NL);
154         }
155         finally {
156             writer.close();
157         }
158 
159         Train train = new Train();
160         try {
161             train.readProblem(file.getAbsolutePath());
162         } catch (InvalidInputDataException e) {
163             throw e;
164         }
165     }
166 
167     @Test(expected = InvalidInputDataException.class)
168     public void testReadWrongProblem() throws Exception {
169         File file = File.createTempFile("svm", "test");
170         file.deleteOnExit();
171 
172         Collection<String> lines = new ArrayList<String>();
173         lines.add("1 1:1  3:1  4:1   6:1");
174         lines.add("2 2:1  3:1  5:1   7:1");
175         lines.add("1 3:1  5:a"); // here's the mistake: incomplete line
176 
177         BufferedWriter writer = new BufferedWriter(new FileWriter(file));
178         try {
179             for (String line : lines)
180                 writer.append(line).append(NL);
181         }
182         finally {
183             writer.close();
184         }
185 
186         Train train = new Train();
187         try {
188             train.readProblem(file.getAbsolutePath());
189         } catch (InvalidInputDataException e) {
190             throw e;
191         }
192     }
193 }