001package org.opengion.penguin.math.statistics;
002
003import java.util.Arrays;
004
005import org.apache.commons.math3.stat.regression.OLSMultipleLinearRegression;
006
007/**
008 * apache.commons.mathを利用したOLS重回帰計算のクラスです。
009 * y = c0 + x1c1 + x2c2 + x3c3 ...の係数を求めます。
010 * c0の切片を考慮するかどうかはnoInterceptで決めます。
011 * 
012 */
013//public class HybsMultiRegression {
014public class HybsMultiRegression implements HybsRegression {
015        private double cnst[];                  // 各係数(xの種類+1になる?)
016        private double rsquare;         // 決定係数
017        private boolean noIntercept; //切片を利用するかどうか
018                
019        /**
020         * コンストラクタ。
021         * 与えた二次元データを元に重回帰を計算します。
022         * xデータとして二次元配列を与えます。
023         * noInterceptで切片有り無しを選択します。
024         * @param in_x 説明変数
025         * @param in_y 目的変数
026         * @param noIntercept 切片利用有無(trueで利用しない)
027         */
028        public HybsMultiRegression(final double[][] in_x, final double[] in_y, final boolean noIntercept){
029                train( in_x, in_y, noIntercept );
030                
031//              this.noIntercept = noIntercept;
032//              
033//              // ここで重回帰計算
034//              OLSMultipleLinearRegression regression = new OLSMultipleLinearRegression();
035//              regression.setNoIntercept(noIntercept);
036//       regression.newSampleData(in_y, in_x);
037//              
038//              cnst = regression.estimateRegressionParameters();
039//              rsquare = regression.calculateRSquared();
040        }
041        
042        /**
043         * コンストラクタ。
044         * 係数配列を与えられるようにしておきます。
045         * (以前に計算したものを利用)
046         * @param in_c 係数配列
047         * @param noIntercept 切片利用有無(trueで利用しない)
048         * 
049         */
050        public HybsMultiRegression( final double[] in_c, final boolean noIntercept){
051                this.cnst = in_c;
052                this.noIntercept = noIntercept;
053        }
054        
055        /**
056         * 与えた二次元データを元に重回帰を計算します。
057         * xデータとして二次元配列を与えます。
058         * noInterceptで切片有り無しを選択します。
059         * 
060         * @param in_x 説明変数
061         * @param in_y 目的変数
062         * @param noIntercept 切片利用有無(trueで利用しない)
063         */
064        private void train( final double[][] in_x, final double[] in_y, final boolean noIntercept ) {
065                this.noIntercept = noIntercept;
066
067                // ここで重回帰計算
068                final OLSMultipleLinearRegression regression = new OLSMultipleLinearRegression();
069                regression.setNoIntercept(noIntercept);
070        regression.newSampleData(in_y, in_x);
071
072                cnst    = regression.estimateRegressionParameters();
073                rsquare = regression.calculateRSquared();
074        }
075        
076//      /**
077//       * 係数の取得。
078//       * @return 係数配列
079//       */
080//      public double[] getParam(){
081//              return cnst;
082//      }
083        
084        /**
085         * 係数をセットした配列を返します。
086         *
087         * @return 各係数の配列
088         */
089        @Override
090        public double[] getCoefficient() {
091                return Arrays.copyOf( cnst,cnst.length );
092        }
093        
094        /**
095         * 配列の内容を係数としてセットします。
096         * 
097         * @param in_c 係数配列
098         */
099        public void setCoefficient(final double[] in_c){
100                cnst = in_c;
101        }
102        
103        /**
104         * 決定係数の取得。
105         * @return 決定係数
106         */
107        public double getRSquare(){
108                return rsquare;
109        }
110        
111        /**
112         * 計算( c0 + c1x1...)を行う。
113         * noInterceptによってc0の利用を決める。
114         * xの大きさが足りない場合は0を返す。
115         * 
116         * @param in_x 必要な大きさの変数配列
117         * @return 計算結果
118         */
119        public double predict(final double... in_x){
120                double rtn = 0;
121                int itr = noIntercept ? 0 : 1;
122                if( in_x.length < cnst.length-itr ){
123                        return 0;
124                }
125                
126                for( int i=0; i < in_x.length; i++ ){
127                        rtn = rtn + in_x[i] * cnst[i+itr];
128                }
129                if( !noIntercept ){ rtn = rtn + cnst[0]; }
130                
131                return rtn;
132        }
133
134        /*** ここまでが本体 ***/
135        /*** ここからテスト用mainメソッド ***/
136        /**
137         * @param args *****************************************/
138        public static void main(final String [] args) {
139                // データはhttp://mjin.doshisha.ac.jp/R/14.htmlより
140                double[] y = new double[] { 50, 60, 65, 65, 70, 75, 80, 85, 90, 95 };
141                double[][] x = new double[10][];
142                x[0] = new double[] { 165, 65 };
143                x[1] = new double[] { 170, 68 };
144                x[2] = new double[] { 172, 70 };
145                x[3] = new double[] { 175, 65 };
146                x[4] = new double[] { 170, 80 };
147                x[5] = new double[] { 172, 85 };
148                x[6] = new double[] { 183, 78 };
149                x[7] = new double[] { 187, 79 };
150                x[8] = new double[] { 180, 95 };
151                x[9] = new double[] { 185, 97 };
152                
153                
154                HybsMultiRegression mr = new HybsMultiRegression(x,y,true);
155                
156                System.out.println( mr.getRSquare() );
157                System.out.println( Arrays.toString( mr.getCoefficient()) );
158                
159                System.out.println( mr.predict( new double[] { 169,85 } ));
160        }
161}
162