1 package de.bwaldvogel.liblinear;
2
import static de.bwaldvogel.liblinear.Linear.NL;
import static org.fest.assertions.Assertions.assertThat;

import java.io.BufferedWriter;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;

import org.junit.Test;
13
14
15 public class TrainTest {
16
17 @Test
18 public void testParseCommandLine() {
19 Train train = new Train();
20
21 for (SolverType solver : SolverType.values()) {
22 train.parse_command_line(new String[] {"-B", "5.3", "-s", "" + solver.ordinal(), "model-filename"});
23 Parameter param = train.getParameter();
24 assertThat(param.solverType).isEqualTo(solver);
25
26 if (solver.ordinal() == 0 || solver.ordinal() == 2
27 || solver.ordinal() == 5 || solver.ordinal() == 6) {
28 assertThat(param.eps).isEqualTo(0.01);
29 } else {
30 assertThat(param.eps).isEqualTo(0.1);
31 }
32
33 assertThat(train.getBias()).isEqualTo(5.3);
34 }
35 }
36
37 @Test
38 public void testReadProblem() throws Exception {
39
40 File file = File.createTempFile("svm", "test");
41 file.deleteOnExit();
42
43 Collection<String> lines = new ArrayList<String>();
44 lines.add("1 1:1 3:1 4:1 6:1");
45 lines.add("2 2:1 3:1 5:1 7:1");
46 lines.add("1 3:1 5:1");
47 lines.add("1 1:1 4:1 7:1");
48 lines.add("2 4:1 5:1 7:1");
49 BufferedWriter writer = new BufferedWriter(new FileWriter(file));
50 try {
51 for (String line : lines)
52 writer.append(line).append(NL);
53 }
54 finally {
55 writer.close();
56 }
57
58 Train train = new Train();
59 train.readProblem(file.getAbsolutePath());
60
61 Problem prob = train.getProblem();
62 assertThat(prob.bias).isEqualTo(1);
63 assertThat(prob.y).hasSize(lines.size());
64 assertThat(prob.y).isEqualTo(new int[] {1, 2, 1, 1, 2});
65 assertThat(prob.n).isEqualTo(8);
66 assertThat(prob.l).isEqualTo(prob.y.length);
67 assertThat(prob.x).hasSize(prob.y.length);
68
69 for (FeatureNode[] nodes : prob.x) {
70
71 assertThat(nodes.length).isLessThanOrEqualTo(prob.n);
72 for (FeatureNode node : nodes) {
73
74 if (prob.bias >= 0 && nodes[nodes.length - 1] == node) {
75 assertThat(node.index).isEqualTo(prob.n);
76 assertThat(node.value).isEqualTo(prob.bias);
77 } else {
78 assertThat(node.index).isLessThan(prob.n);
79 }
80 }
81 }
82 }
83
84
85
86
87 @Test
88 public void testReadProblemEmptyLine() throws Exception {
89
90 File file = File.createTempFile("svm", "test");
91 file.deleteOnExit();
92
93 Collection<String> lines = new ArrayList<String>();
94 lines.add("1 1:1 3:1 4:1 6:1");
95 lines.add("2 ");
96 BufferedWriter writer = new BufferedWriter(new FileWriter(file));
97 try {
98 for (String line : lines)
99 writer.append(line).append(NL);
100 }
101 finally {
102 writer.close();
103 }
104
105 Problem prob = Train.readProblem(file, -1.0);
106 assertThat(prob.bias).isEqualTo(-1);
107 assertThat(prob.y).hasSize(lines.size());
108 assertThat(prob.y).isEqualTo(new int[] {1, 2});
109 assertThat(prob.n).isEqualTo(6);
110 assertThat(prob.l).isEqualTo(prob.y.length);
111 assertThat(prob.x).hasSize(prob.y.length);
112
113 assertThat(prob.x[0]).hasSize(4);
114 assertThat(prob.x[1]).hasSize(0);
115 }
116
117 @Test(expected = InvalidInputDataException.class)
118 public void testReadUnsortedProblem() throws Exception {
119 File file = File.createTempFile("svm", "test");
120 file.deleteOnExit();
121
122 Collection<String> lines = new ArrayList<String>();
123 lines.add("1 1:1 3:1 4:1 6:1");
124 lines.add("2 2:1 3:1 5:1 7:1");
125 lines.add("1 3:1 5:1 4:1");
126
127 BufferedWriter writer = new BufferedWriter(new FileWriter(file));
128 try {
129 for (String line : lines)
130 writer.append(line).append(NL);
131 }
132 finally {
133 writer.close();
134 }
135
136 Train train = new Train();
137 train.readProblem(file.getAbsolutePath());
138 }
139
140
141 @Test(expected = InvalidInputDataException.class)
142 public void testReadProblemWithInvalidIndex() throws Exception {
143 File file = File.createTempFile("svm", "test");
144 file.deleteOnExit();
145
146 Collection<String> lines = new ArrayList<String>();
147 lines.add("1 1:1 3:1 4:1 6:1");
148 lines.add("2 2:1 3:1 5:1 -4:1");
149
150 BufferedWriter writer = new BufferedWriter(new FileWriter(file));
151 try {
152 for (String line : lines)
153 writer.append(line).append(NL);
154 }
155 finally {
156 writer.close();
157 }
158
159 Train train = new Train();
160 try {
161 train.readProblem(file.getAbsolutePath());
162 } catch (InvalidInputDataException e) {
163 throw e;
164 }
165 }
166
167 @Test(expected = InvalidInputDataException.class)
168 public void testReadWrongProblem() throws Exception {
169 File file = File.createTempFile("svm", "test");
170 file.deleteOnExit();
171
172 Collection<String> lines = new ArrayList<String>();
173 lines.add("1 1:1 3:1 4:1 6:1");
174 lines.add("2 2:1 3:1 5:1 7:1");
175 lines.add("1 3:1 5:a");
176
177 BufferedWriter writer = new BufferedWriter(new FileWriter(file));
178 try {
179 for (String line : lines)
180 writer.append(line).append(NL);
181 }
182 finally {
183 writer.close();
184 }
185
186 Train train = new Train();
187 try {
188 train.readProblem(file.getAbsolutePath());
189 } catch (InvalidInputDataException e) {
190 throw e;
191 }
192 }
193 }