[確率・統計]ARIMA(自己回帰和分移動平均)モデルのJAVA実装

 気分日記で次週の気分推移の予測をするため、ARIMA(自己回帰和分移動平均)モデルでの予測機能を追加しました。

 ARIMAモデルは、周期的な特徴を持つ時系列データの予測をできるモデルとなっています。詳しい内容は、他のブログ記事に任せここでは、Javaで実装方法を記載します。

Pythonと異なりJavaは数値計算のオープンソースが少なく単純に検索した結果は以下のみが見つかりました。

timeseries-forecast
https://github.com/Workday/timeseries-forecast 

ライセンスはMIT Licenseでアプリへの転用は容易になっています。

以下のようにこのライブラリを使って予測データを作成できます。

import com.workday.insights.timeseries.arima.Arima;
import com.workday.insights.timeseries.arima.struct.ArimaParams;
import com.workday.insights.timeseries.arima.struct.ForecastResult;

class arima {
    public static void main(String[] args) {
		// Prepare input timeseries data.
		double[] dataArray = new double[] {2, 1, 2, 5, 2, 1, 2, 5, 2, 1, 2, 5, 2, 1, 2, 5};

		// Set ARIMA model parameters.
		int forecastSize = 10;

		// Ser Arima parameter
		// 以下を指定します。
		// p : 自己回帰パラメータ
		// d : 差分の階数
		// q : 移動平均パラメータ
		// 以下はSARIMAモデルに使用される
		// P : 季節性自己相関
		// D : 季節性導出
		// Q : 季節性移動平均
		// m : 各季節の期間の数
		ArimaParams arimaParams = new ArimaParams(3, 0, 2, 1, 1, 0, 0);

		// Obtain forecast result. The structure contains forecasted values and performance metric etc.
		ForecastResult forecastResult = Arima.forecast_arima(dataArray, forecastSize, arimaParams);

		// Read forecast values
		double[] forecastData = forecastResult.getForecast(); // in this example, it will return { 2 }

		// You can obtain upper- and lower-bounds of confidence intervals on forecast values.
		// By default, it computes at 95%-confidence level. This value can be adjusted in ForecastUtil.java
		double[] uppers = forecastResult.getForecastUpperConf();
		double[] lowers = forecastResult.getForecastLowerConf();

		// You can also obtain the root mean-square error as validation metric.
		double rmse = forecastResult.getRMSE();

		// It also provides the maximum normalized variance of the forecast values and their confidence interval.
		double maxNormalizedVariance = forecastResult.getMaxNormalizedVariance();

		// Finally you can read log messages.
		String log = forecastResult.getLog();

		//Output result
		System.out.print("forecastData\n");
			for(int i = 0; i < forecastData.length; ++i){
			System.out.print(String.format("%f \n", forecastData[i]));
		}
		System.out.print("\nuppers\n");
		for(int i = 0; i < uppers.length; ++i){
			System.out.print(String.format("%f \n", uppers[i]));
		}
		System.out.print("\nlowers\n");
		for(int i = 0; i < lowers.length; ++i){
			System.out.print(String.format("%f \n", lowers[i]));
		}
		System.out.print(String.format("rmse = %f \n", rmse));
		System.out.print(String.format("maxNormalizedVariance = %f \n", maxNormalizedVariance));
		System.out.print(log);
    }
}

0 件のコメント :

コメントを投稿