View Javadoc

1   package de.bwaldvogel.liblinear;
2   
3   import static de.bwaldvogel.liblinear.Linear.copyOf;
4   
5   import java.io.File;
6   import java.io.IOException;
7   import java.io.Reader;
8   import java.io.Serializable;
9   import java.io.Writer;
10  import java.util.Arrays;
11  
12  
13  /**
14   * <p>Model stores the model obtained from the training procedure</p>
15   *
16   * <p>use {@link Linear#loadModel(File)} and {@link Linear#saveModel(File, Model)} to load/save it</p>
17   */
18  public final class Model implements Serializable {
19  
20      private static final long serialVersionUID = -6456047576741854834L;
21  
22      double                    bias;
23  
24      /** label of each class */
25      int[]                     label;
26  
27      int                       nr_class;
28  
29      int                       nr_feature;
30  
31      SolverType                solverType;
32  
33      /** feature weight array */
34      double[]                  w;
35  
36      /**
37       * @return number of classes
38       */
39      public int getNrClass() {
40          return nr_class;
41      }
42  
43      /**
44       * @return number of features
45       */
46      public int getNrFeature() {
47          return nr_feature;
48      }
49  
50      public int[] getLabels() {
51          return copyOf(label, nr_class);
52      }
53  
54      /**
55       * The nr_feature*nr_class array w gives feature weights. We use one
56       * against the rest for multi-class classification, so each feature
57       * index corresponds to nr_class weight values. Weights are
58       * organized in the following way
59       *
60       * <pre>
61       * +------------------+------------------+------------+
62       * | nr_class weights | nr_class weights |  ...
63       * | for 1st feature  | for 2nd feature  |
64       * +------------------+------------------+------------+
65       * </pre>
66       *
67       * If bias &gt;= 0, x becomes [x; bias]. The number of features is
68       * increased by one, so w is a (nr_feature+1)*nr_class array. The
69       * value of bias is stored in the variable bias.
70       * @see #getBias()
71       * @return a <b>copy of</b> the feature weight array as described
72       */
73      public double[] getFeatureWeights() {
74          return Linear.copyOf(w, w.length);
75      }
76  
77      /**
78       * @return true for logistic regression solvers
79       */
80      public boolean isProbabilityModel() {
81          return (solverType == SolverType.L2R_LR || solverType == SolverType.L2R_LR_DUAL || solverType == SolverType.L1R_LR);
82      }
83  
84      /**
85       * @see #getFeatureWeights()
86       */
87      public double getBias() {
88          return bias;
89      }
90  
91      @Override
92      public String toString() {
93          StringBuilder sb = new StringBuilder("Model");
94          sb.append(" bias=").append(bias);
95          sb.append(" nr_class=").append(nr_class);
96          sb.append(" nr_feature=").append(nr_feature);
97          sb.append(" solverType=").append(solverType);
98          return sb.toString();
99      }
100 
101     @Override
102     public int hashCode() {
103         final int prime = 31;
104         int result = 1;
105         long temp;
106         temp = Double.doubleToLongBits(bias);
107         result = prime * result + (int)(temp ^ (temp >>> 32));
108         result = prime * result + Arrays.hashCode(label);
109         result = prime * result + nr_class;
110         result = prime * result + nr_feature;
111         result = prime * result + ((solverType == null) ? 0 : solverType.hashCode());
112         result = prime * result + Arrays.hashCode(w);
113         return result;
114     }
115 
116     @Override
117     public boolean equals(Object obj) {
118         if (this == obj) return true;
119         if (obj == null) return false;
120         if (getClass() != obj.getClass()) return false;
121         Model other = (Model)obj;
122         if (Double.doubleToLongBits(bias) != Double.doubleToLongBits(other.bias)) return false;
123         if (!Arrays.equals(label, other.label)) return false;
124         if (nr_class != other.nr_class) return false;
125         if (nr_feature != other.nr_feature) return false;
126         if (solverType == null) {
127             if (other.solverType != null) return false;
128         } else if (!solverType.equals(other.solverType)) return false;
129         if (!equals(w, other.w)) return false;
130         return true;
131     }
132 
133     /**
134      * don't use {@link Arrays#equals(double[], double[])} here, cause 0.0 and -0.0 should be handled the same
135      *
136      * @see Linear#saveModel(java.io.Writer, Model)
137      */
138     protected static boolean equals(double[] a, double[] a2) {
139         if (a == a2) return true;
140         if (a == null || a2 == null) return false;
141 
142         int length = a.length;
143         if (a2.length != length) return false;
144 
145         for (int i = 0; i < length; i++)
146             if (a[i] != a2[i]) return false;
147 
148         return true;
149     }
150 
151     /**
152      * see {@link Linear#saveModel(java.io.File, Model)}
153      */
154     public void save(File file) throws IOException {
155         Linear.saveModel(file, this);
156     }
157 
158     /**
159      * see {@link Linear#saveModel(Writer, Model)}
160      */
161     public void save(Writer writer) throws IOException {
162         Linear.saveModel(writer, this);
163     }
164 
165     /**
166      * see {@link Linear#loadModel(File)}
167      */
168     public static Model load(File file) throws IOException {
169         return Linear.loadModel(file);
170     }
171 
172     /**
173      * see {@link Linear#loadModel(Reader)}
174      */
175     public static Model load(Reader inputReader) throws IOException {
176         return Linear.loadModel(inputReader);
177     }
178 }