Reinforcement learning models using the Gym library and the PyTorch framework
Sample tests of trained agents (DQN on Breakout, A3C on Pong, DoubleDQN on CartPole, continuous A3C on InvertedPendulum (MuJoCo)):
Sample online plotting while training an A3C agent on Pong (with 16 learner processes):
Sample logging while training a DQN agent on CartPole (we currently use WARNING as the logging level to get rid of the INFO printouts from visdom):
[WARNING ] (MainProcess) <===================================>
[WARNING ] (MainProcess) bash$: python -m visdom.server
[WARNING ] (MainProcess) http://localhost:8097/env/daim_17040900
[WARNING ] (MainProcess) <===================================> DQN
[WARNING ] (MainProcess) <-----------------------------------> Env
[WARNING ] (MainProcess) Creating {gym | CartPole-v0} w/ Seed: 123
[INFO ] (MainProcess) Making new env: CartPole-v0
[WARNING ] (MainProcess) Action Space: [0, 1]
[WARNING ] (MainProcess) State Space: 4
[WARNING ] (MainProcess) <-----------------------------------> Model
[WARNING ] (MainProcess) MlpModel (
  (fc1): Linear (4 -> 16)
  (rl1): ReLU ()
  (fc2): Linear (16 -> 16)
  (rl2): ReLU ()
  (fc3): Linear (16 -> 16)
  (rl3): ReLU ()
  (fc4): Linear (16 -> 2)
)
[WARNING ] (MainProcess) No Pretrained Model. Will Train From Scratch.
[WARNING ] (MainProcess) <===================================> Training ...
[WARNING ] (MainProcess) Validation Data @ Step: 501
[WARNING ] (MainProcess) Start Training @ Step: 501
[WARNING ] (MainProcess) Reporting @ Step: 2500 | Elapsed Time: 5.32397913933
[WARNING ] (MainProcess) Training Stats: epsilon: 0.972
[WARNING ] (MainProcess) Training Stats: total_reward: 2500.0
[WARNING ] (MainProcess) Training Stats: avg_reward: 21.7391304348
[WARNING ] (MainProcess) Training Stats: nepisodes: 115
[WARNING ] (MainProcess) Training Stats: nepisodes_solved: 114
[WARNING ] (MainProcess) Training Stats: repisodes_solved: 0.991304347826
[WARNING ] (MainProcess) Evaluating @ Step: 2500
[WARNING ] (MainProcess) Iteration: 2500; v_avg: 1.73136949539
[WARNING ] (MainProcess) Iteration: 2500; tderr_avg: 0.0964358523488
[WARNING ] (MainProcess) Iteration: 2500; steps_avg: 9.34579439252
[WARNING ] (MainProcess) Iteration: 2500; steps_std: 0.798395631184
[WARNING ] (MainProcess) Iteration: 2500; reward_avg: 9.34579439252
[WARNING ] (MainProcess) Iteration: 2500; reward_std: 0.798395631184
[WARNING ] (MainProcess) Iteration: 2500; nepisodes: 107
[WARNING ] (MainProcess) Iteration: 2500; nepisodes_solved: 106
[WARNING ] (MainProcess) Iteration: 2500; repisodes_solved: 0.990654205607
[WARNING ] (MainProcess) Saving Model @ Step: 2500: /home/zhang/ws/17_ws/pytorch-rl/models/daim_17040900.pth ...
[WARNING ] (MainProcess) Saved Model @ Step: 2500: /home/zhang/ws/17_ws/pytorch-rl/models/daim_17040900.pth.
[WARNING ] (MainProcess) Resume Training @ Step: 2500...
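The effect of logging at the WARNING level can be illustrated with Python's standard logging module. This is a minimal sketch, not the repo's own logger setup: the helper name make_logger and the format string are assumptions, with the format chosen only to resemble the sample output above.

```python
import io
import logging


def make_logger(name, stream, level=logging.WARNING):
    """Build a logger that drops every record below `level` (hypothetical helper)."""
    logger = logging.getLogger(name)
    logger.setLevel(level)
    logger.propagate = False  # keep the demo isolated from the root logger
    handler = logging.StreamHandler(stream)
    # Assumed format, picked to resemble "[WARNING ] (MainProcess) ..."
    handler.setFormatter(
        logging.Formatter("[%(levelname)-8s] (%(processName)s) %(message)s"))
    logger.handlers = [handler]
    return logger


buffer = io.StringIO()
logger = make_logger("demo", buffer)
logger.info("chatty library message")      # below WARNING, filtered out
logger.warning("Reporting @ Step: 2500")   # at WARNING, kept
print(buffer.getvalue(), end="")           # only the WARNING record is printed
```

Raising the threshold this way hides INFO-level chatter, which is why a training script that wants its own messages to remain visible has to emit them at WARNING or above.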
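This repo currently contains the following agents:
- Double DQN [3]
- Dueling network DQN (Dueling DQN) [4]
- Asynchronous Advantage Actor-Critic (A3C) (with both discrete and continuous action space support) [5], [6]
- Sample Efficient Actor-Critic with Experience Replay (ACER) (currently with discrete action space support (Truncated Importance Sampling, 1st Order TRPO)) [7], [8]
Work in progress:
- Testing ACER
Future plans:
- Deep Deterministic Policy Gradient (DDPG) [9], [10]
- Continuous DQN (CDQN or NAF) [11]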
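NOTE: we follow the exact code structure of pytorch-dnc so that the code is easy to port.
* ./utils/factory.py
We suggest that users refer to ./utils/factory.py, where we list all the integrated Env, Model, Memory, and Agent classes in Dicts. All four of these core classes are implemented in ./core/. The factory pattern in ./utils/factory.py keeps the code clean: no matter what type of Agent you want to train, or which type of Env you want to train on, all you need to do is modify some parameters in ./utils/options.py, and ./main.py will do the rest (NOTE: ./main.py never needs to be modified).
* namings
To make the code cleaner and more readable, we name variables using the following pattern (mainly in the inherited Agents):
  * *_vb: torch.autograd.Variable objects, or a list of such objects
  * *_ts: torch.Tensor objects, or a list of such objects
  * otherwise: normal Python datatypes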
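The factory pattern described above can be sketched as follows. This is an illustration under stated assumptions rather than the contents of ./utils/factory.py: the classes GymEnv, MlpModel, and DQNAgent, the dictionary names, the option strings, and build_agent are all hypothetical stand-ins, and Memory is omitted for brevity.

```python
# Hypothetical stand-ins for classes that would live in ./core/.
class GymEnv:
    def __init__(self, game):
        self.game = game


class MlpModel:
    def __init__(self, state_dim, action_dim):
        self.state_dim = state_dim
        self.action_dim = action_dim


class DQNAgent:
    def __init__(self, env, model):
        self.env = env
        self.model = model


# Registries mapping option strings to classes (names are assumptions).
EnvDict = {"gym": GymEnv}
ModelDict = {"mlp": MlpModel}
AgentDict = {"dqn": DQNAgent}


def build_agent(agent_type, env_type, game, model_type):
    """Look each class up by key, so the caller never names a class directly."""
    env = EnvDict[env_type](game)
    # State and action sizes taken from the CartPole sample log (4 and 2).
    model = ModelDict[model_type](state_dim=4, action_dim=2)
    return AgentDict[agent_type](env, model)


agent = build_agent("dqn", "gym", "CartPole-v0", "mlp")
print(type(agent).__name__, agent.env.game)
```

With this layout, supporting a new agent or environment means adding one dictionary entry, while the entry-point script that calls the builder stays unchanged.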
You only need to modify some parameters in ./utils/options.py to train a new configuration.
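* Configure your training in ./utils/options.py:
  * line 14: add an entry to CONFIGS to define your training (agent_type, env_type, game, model_type, memory_type)
  * line 33: choose the entry you just added
  * line 29-30: fill in your machine/cluster ID (MACHINE) and timestamp (TIMESTAMP) to define your training signature (MACHINE_TIMESTAMP). The corresponding model file and log file of this training will be saved under this signature (./models/MACHINE_TIMESTAMP.pth and ./logs/MACHINE_TIMESTAMP.log respectively). The visdom visualization will also be displayed under this signature (first start the visdom server by typing in bash: python -m visdom.server &, then open this address in your browser: http://localhost:8097/env/MACHINE_TIMESTAMP)
  * line 32: to train a model, set mode=1 (training visualization will be under http://localhost:8097/env/MACHINE_TIMESTAMP); to test the model of the current training, all you need to do is set mode=2 (testing visualization will be under http://localhost:8097/env/MACHINE_TIMESTAMP_test).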
* Run:
python main.py
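The configuration steps can be sketched in plain Python to show how the training signature determines the file names and the visdom address. This is only an illustration: the entry values ("mlp", "sequential"), the variable names CONFIG and MODE, and the helper signature_paths are assumptions and are not copied from ./utils/options.py.

```python
# Illustrative entry; the field order follows the description of CONFIGS.
CONFIGS = [
    # agent_type, env_type, game,          model_type, memory_type
    ("dqn",       "gym",    "CartPole-v0", "mlp",      "sequential"),
]

CONFIG = 0               # index of the entry to train (assumed name)
MACHINE = "machine1"     # machine/cluster ID
TIMESTAMP = "17080801"   # timestamp
MODE = 1                 # 1: train, 2: test (assumed name)


def signature_paths(machine, timestamp, mode):
    """Derive file names and the visdom URL from the training signature."""
    signature = "{}_{}".format(machine, timestamp)
    # Testing is displayed in a separate visdom environment with a _test suffix.
    visdom_env = signature if mode == 1 else signature + "_test"
    return {
        "model": "./models/{}.pth".format(signature),
        "log": "./logs/{}.log".format(signature),
        "visdom": "http://localhost:8097/env/{}".format(visdom_env),
    }


agent_type, env_type, game, model_type, memory_type = CONFIGS[CONFIG]
paths = signature_paths(MACHINE, TIMESTAMP, MODE)
print(agent_type, game)
print(paths["model"])
print(paths["visdom"])
```

Because the model and log files are keyed by the signature alone, switching from mode=1 to mode=2 without changing MACHINE or TIMESTAMP points testing at the model saved by the preceding training run.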
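We also provide 2 additional scripts for quickly evaluating your results after training. (Dependency: lmj-plot)
* plot.sh (e.g., plot from the log file logs/machine1_17080801.log)
  ./plot.sh machine1 17080801
  The generated figures will be saved to figs/machine1_17080801/
* plot_compare.sh (e.g., compare the log files logs/machine1_17080801.log and logs/machine2_17080802.log)
  ./plot_compare.sh 00 machine1 17080801 machine2 17080802
  The generated figures will be saved to figs/compare_00/
  The color coding will be in the following order: red green blue magenta yellow cyan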
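To make the argument layout of the comparison command concrete, the sketch below maps its arguments to log files, colors, and an output directory. It is not the script's implementation: the function parse_compare_args is hypothetical, and rejecting more than six logs is an assumption, since the text lists only six colors.

```python
COLORS = ["red", "green", "blue", "magenta", "yellow", "cyan"]


def parse_compare_args(args):
    """Split [compare_id, machine, timestamp, machine, timestamp, ...] into a plan."""
    compare_id, rest = args[0], args[1:]
    if len(rest) % 2 != 0:
        raise ValueError("expected machine/timestamp pairs")
    pairs = list(zip(rest[0::2], rest[1::2]))
    if len(pairs) > len(COLORS):
        # Assumption: one distinct color per log, so at most six logs.
        raise ValueError("more logs than available colors")
    logs = ["logs/{}_{}.log".format(machine, ts) for machine, ts in pairs]
    return {
        "out_dir": "figs/compare_{}/".format(compare_id),
        "curves": list(zip(logs, COLORS)),  # the i-th log gets the i-th color
    }


plan = parse_compare_args(["00", "machine1", "17080801", "machine2", "17080802"])
for log_file, color in plan["curves"]:
    print(color, log_file)
print(plan["out_dir"])
```

Under this reading, the first log listed is drawn in red and the second in green, which is how the color order in the text would be matched to the compared runs.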
We also referred to a private implementation of A3C from @stokasto.
If you find this library useful and would like to cite it, the following would be appropriate:
@misc{pytorch-rl,
  author = {Zhang, Jingwei and Tai, Lei},
  title = {jingweiz/pytorch-rl},
  url = {https://github.com/jingweiz/pytorch-rl},
  year = {2017}
}