原标题:MXNET之GluonTS学习手册:第一章《DeepAR模型初步尝试》
原文来自:CSDN 原文链接:https://blog.csdn.net/qq_34461600/article/details/103005723
阅读本手册需要一定mxnet、gluon操作基础。
本文使用cpu训练代码。
本文github地址:https://github.com/zmkwjx/GluonTS-Learning-in-Action
GluonTS官网地址:https://gluon-ts.mxnet.io
1、环境以及安装
1.1 本文开发环境:ubuntu16.04TS、python3.7
1.2 快速安装
pip install matplotlib numpy pandas pathlib
pip install mxnet mxnet-mkl gluon gluonts
2、训练程序
#Third-party imports
import matplotlib.pyplot as plt
import pandas as pd
from pathlib import Path
from gluonts.model import deepar
from gluonts.dataset import common
from gluonts.dataset.util import to_pandas
from gluonts.model.predictor import Predictor
2.1 加载训练数据 Twitter_volume_AMZN.csv
url = "./data/Twitter_volume_AMZN.csv"
df = pd.read_csv(url, header=0, index_col=0)
data = common.ListDataset([{"start": df.index[0],
"target": df.value[:"2015-04-23 00:00:00"]}], freq="H")
2.2 解读 ListDataset
class gluonts.dataset.common.ListDataset(data_iter: Iterable[Dict[str, Any]], freq: str, one_dim_target: bool = True)
data_iter: 可迭代对象产生数据集中的所有项目。每个项目都应该是一个将字符串映射到值的字典。例如:{“start”: “2014-09-07”, “target”: [0.1, 0.2]}
freq: 时间序列中的观察频率。
one_dim_target: 是否仅接受单变量目标时间序列。
2.3 训练现有模型
GluonTS带有许多预先构建的模型。用户所需要做的就是配置一些超参数。现有模型专注于(但不限于)概率预测。概率预测是以概率分布的形式进行的预测,而不是简单的单点估计。
estimator = deepar.DeepAREstimator(freq="H", prediction_length=24)
predictor = estimator.train(training_data=data)
2.4 预览训练结果
for test_entry, forecast in zip(train_data, predictor.predict(train_data)):
to_pandas(test_entry)[-60:].plot(linewidth=2)
forecast.plot(color='g', prediction_intervals=[50.0, 90.0])
plt.grid(which='both')
plt.show()
预测结果
2.5 输出训练结果
prediction = next(predictor.predict(train_data))
print(prediction.mean)
prediction.plot(output_file='graph.png')
OUT
2.5 保存训练模型
predictor.serialize(Path("此处填入Model文件夹的绝对路径"))
2.6 使用训练模型
predictor = Predictor.deserialize(Path("此处填入Model文件夹的绝对路径"))
import pandas as pd
from pathlib import Path
from gluonts.dataset import common
from gluonts.dataset.util import to_pandas
from gluonts.model.predictor import Predictor
url = "./data/Twitter_volume_AMZN.csv"
df = pd.read_csv(url, header=0, index_col=0)
train_data = common.ListDataset([{"start": df.index[0],
"target": df.value[:"2015-04-23 00:00:00"]}],freq="H")
predictor = Predictor.deserialize(Path("此处填入Model文件夹的绝对路径"))
prediction = next(predictor.predict(train_data))
print(prediction.mean)
prediction.plot(output_file='graph.png')
免责声明:本文来自互联网新闻客户端自媒体,不代表本网的观点和立场。
合作及投稿邮箱:E-mail:editor@tusaishared.com