Python Statsmodels: Using SARIMAX with exogenous r

Posted 2019-06-17 18:29

Question:

I'm using statsmodels.tsa.SARIMAX() to train a model with exogenous variables. Is there an equivalent of get_prediction() when a model is trained with exogenous variables so that the object returned contains the predicted mean and confidence interval rather than just an array of predicted mean results? The predict() and forecast() methods take exogenous variables, but only return the predicted mean value.

SARIMA_model = sm.tsa.SARIMAX(endog=y_train.astype('float64'),
                          exog=ExogenousFeature_train.values.astype('float64'), 
                          order=(1,0,0),
                          seasonal_order=(2,1,0,7), 
                          simple_differencing=False)

model_results = SARIMA_model.fit()

pred = model_results.predict(start=train_end_date,
                               end=test_end_date,
                               exog=ExogenousFeature_test.values.astype('float64').reshape(343,1),
                               dynamic=False)

Here pred is an array of predicted values, not the object containing predicted means and confidence intervals that get_prediction() would return. Note that get_prediction() does not take exogenous variables.

My version of statsmodels is 0.8

Answer 1:

Because of backward-compatibility issues, predict() and forecast() do not expose the full results (prediction intervals and so on).

To get what you want now, use the get_prediction and get_forecast methods of the fitted results object with the parameters shown below:

    import matplotlib.pyplot as plt

    # model_results is the fitted results object returned by SARIMA_model.fit()
    pred_res = model_results.get_prediction(exog=ExogenousFeature_train.values.astype('float64'), full_results=True, alpha=0.05)
    pred_means = pred_res.predicted_mean
    # Set the interval width with alpha; alpha=0.05 gives a 95% interval
    pred_cis = pred_res.conf_int(alpha=0.05)

    # Plot the observed data, the predicted mean, and the interval band
    fig = plt.figure(figsize=(12, 8))
    ax = fig.add_subplot(1, 1, 1)
    # Actual data
    ax.plot(y_train.astype('float64'), '--', color="blue", label='data')
    # Predicted means
    ax.plot(pred_means, lw=1, color="black", alpha=0.5, label='SARIMAX')
    ax.fill_between(pred_means.index, pred_cis.iloc[:, 0], pred_cis.iloc[:, 1], alpha=0.05)
    ax.legend(loc='upper right')
    plt.draw()

For more info, go to:

  • https://github.com/statsmodels/statsmodels/issues/2823
  • Solution by the author: http://www.statsmodels.org/dev/examples/notebooks/generated/statespace_local_linear_trend.html