View Javadoc

1   package liblinear;
2   
3   import static liblinear.Linear.NL;
4   import static liblinear.Linear.atof;
5   import static liblinear.Linear.atoi;
6   import static liblinear.Linear.closeQuietly;
7   import static liblinear.Linear.printf;
8   
9   import java.io.BufferedReader;
10  import java.io.BufferedWriter;
11  import java.io.File;
12  import java.io.FileInputStream;
13  import java.io.FileOutputStream;
14  import java.io.IOException;
15  import java.io.InputStreamReader;
16  import java.io.OutputStreamWriter;
17  import java.io.Writer;
18  import java.util.ArrayList;
19  import java.util.Formatter;
20  import java.util.List;
21  import java.util.StringTokenizer;
22  import java.util.regex.Pattern;
23  
24  
25  public class Predict {
26  
27      private static boolean       flag_predict_probability = false;
28  
29      private static final Pattern COLON                    = Pattern.compile(":");
30  
31      /**
32       * <p><b>Note: The streams are NOT closed</b></p>
33       */
34      static void doPredict(BufferedReader reader, Writer writer, Model model) throws IOException {
35          int correct = 0;
36          int total = 0;
37  
38          int nr_class = model.getNrClass();
39          double[] prob_estimates = null;
40          int n;
41          int nr_feature = model.getNrFeature();
42          if (model.bias >= 0)
43              n = nr_feature + 1;
44          else
45              n = nr_feature;
46  
47          Formatter out = new Formatter(writer);
48  
49          if (flag_predict_probability) {
50              if (model.solverType != SolverType.L2R_LR) {
51                  throw new IllegalArgumentException("probability output is only supported for logistic regression");
52              }
53  
54              int[] labels = model.getLabels();
55              prob_estimates = new double[nr_class];
56  
57              printf(out, "labels");
58              for (int j = 0; j < nr_class; j++)
59                  printf(out, " %d", labels[j]);
60              printf(out, "\n");
61          }
62  
63  
64          String line = null;
65          while ((line = reader.readLine()) != null) {
66              List<FeatureNode> x = new ArrayList<FeatureNode>();
67              StringTokenizer st = new StringTokenizer(line, " \t");
68              String label = st.nextToken();
69              int target_label = atoi(label);
70  
71              while (st.hasMoreTokens()) {
72                  String[] split = COLON.split(st.nextToken(), 2);
73                  if (split == null || split.length < 2) exit_input_error(total + 1);
74  
75                  try {
76                      int idx = atoi(split[0]);
77                      double val = atof(split[1]);
78  
79                      // feature indices larger than those in training are not used
80                      if (idx <= nr_feature) {
81                          FeatureNode node = new FeatureNode(idx, val);
82                          x.add(node);
83                      }
84                  } catch (NumberFormatException e) {
85                      exit_input_error(total + 1, e);
86                  }
87              }
88  
89              if (model.bias >= 0) {
90                  FeatureNode node = new FeatureNode(n, model.bias);
91                  x.add(node);
92              }
93  
94              FeatureNode[] nodes = new FeatureNode[x.size()];
95              nodes = x.toArray(nodes);
96  
97              int predict_label;
98  
99              if (flag_predict_probability) {
100                 predict_label = Linear.predictProbability(model, nodes, prob_estimates);
101                 printf(out, "%d", predict_label);
102                 for (int j = 0; j < model.nr_class; j++)
103                     printf(out, " %g", prob_estimates[j]);
104                 printf(out, "\n");
105             } else {
106                 predict_label = Linear.predict(model, nodes);
107                 printf(out, "%d\n", predict_label);
108             }
109 
110             if (predict_label == target_label) {
111                 ++correct;
112             }
113             ++total;
114         }
115         System.out.printf("Accuracy = %g%% (%d/%d)" + NL, (double)correct / total * 100, correct, total);
116     }
117 
118     private static void exit_input_error(int line_num, Throwable cause) {
119         throw new RuntimeException("Wrong input format at line " + line_num, cause);
120     }
121 
122     private static void exit_input_error(int line_num) {
123         throw new RuntimeException("Wrong input format at line " + line_num);
124     }
125 
126     private static void exit_with_help() {
127         System.out.println("Usage: predict [options] test_file model_file output_file" + NL //
128             + "options:" + NL //
129             + "-b probability_estimates: whether to output probability estimates, 0 or 1 (default 0)" + NL //
130         );
131         System.exit(1);
132     }
133 
134     public static void main(String[] argv) throws IOException {
135         int i;
136 
137         // parse options
138         for (i = 0; i < argv.length; i++) {
139             if (argv[i].charAt(0) != '-') break;
140             ++i;
141             switch (argv[i - 1].charAt(1)) {
142                 case 'b':
143                     try {
144                         flag_predict_probability = (atoi(argv[i]) != 0);
145                     } catch (NumberFormatException e) {
146                         exit_with_help();
147                     }
148                     break;
149 
150                 default:
151                     System.err.println("unknown option: -" + argv[i - 1].charAt(1) + NL);
152                     exit_with_help();
153                     break;
154             }
155         }
156         if (i >= argv.length || argv.length <= i + 2) {
157             exit_with_help();
158         }
159 
160         BufferedReader reader = null;
161         Writer writer = null;
162         try {
163             reader = new BufferedReader(new InputStreamReader(new FileInputStream(argv[i]), Linear.FILE_CHARSET));
164             writer = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(argv[i + 2]), Linear.FILE_CHARSET));
165 
166             Model model = Linear.loadModel(new File(argv[i + 1]));
167             doPredict(reader, writer, model);
168         }
169         finally {
170             closeQuietly(reader);
171             closeQuietly(writer);
172         }
173     }
174 }