资源算法使用Gym库和Pytorch框架的强化学习模型

使用Gym库和Pytorch框架的强化学习模型

2019-08-19 | |  232 |   0 |   0

深强化学习

pytorchvisdom


  • 受过训练的代理的样本测试(突破时的DQN,Pong上的A3C,CartPole上的DoubleDQN,InvertedPendulum上的连续A3C(MuJoCo)):

  • 在乒乓球训练A3C代理的同时进行在线绘图示例(有16个学习者流程): a3c_pong_plot

  • 在CartPole上训练DQN代理时采样WARNING记录(我们目前用作记录级别来消除INFO来自visdom 打印输出):

[WARNING ] (MainProcess) <===================================>[WARNING ] (MainProcess) bash$: python -m visdom.server[WARNING ] (MainProcess) http://localhost:8097/env/daim_17040900[WARNING ] (MainProcess) <===================================> DQN[WARNING ] (MainProcess) <-----------------------------------> Env[WARNING ] (MainProcess) Creating {gym | CartPole-v0} w/ Seed: 123[INFO    ] (MainProcess) Making new env: CartPole-v0[WARNING ] (MainProcess) Action Space: [0, 1][WARNING ] (MainProcess) State  Space: 4[WARNING ] (MainProcess) <-----------------------------------> Model[WARNING ] (MainProcess) MlpModel (
  (fc1): Linear (4 -> 16)
  (rl1): ReLU ()
  (fc2): Linear (16 -> 16)
  (rl2): ReLU ()
  (fc3): Linear (16 -> 16)
  (rl3): ReLU ()
  (fc4): Linear (16 -> 2))[WARNING ] (MainProcess) No Pretrained Model. Will Train From Scratch.[WARNING ] (MainProcess) <===================================> Training ...[WARNING ] (MainProcess) Validation Data @ Step: 501[WARNING ] (MainProcess) Start  Training @ Step: 501[WARNING ] (MainProcess) Reporting       @ Step: 2500 | Elapsed Time: 5.32397913933[WARNING ] (MainProcess) Training Stats:   epsilon:          0.972[WARNING ] (MainProcess) Training Stats:   total_reward:     2500.0[WARNING ] (MainProcess) Training Stats:   avg_reward:       21.7391304348[WARNING ] (MainProcess) Training Stats:   nepisodes:        115[WARNING ] (MainProcess) Training Stats:   nepisodes_solved: 114[WARNING ] (MainProcess) Training Stats:   repisodes_solved: 0.991304347826[WARNING ] (MainProcess) Evaluating      @ Step: 2500[WARNING ] (MainProcess) Iteration: 2500; v_avg: 1.73136949539[WARNING ] (MainProcess) Iteration: 2500; tderr_avg: 0.0964358523488[WARNING ] (MainProcess) Iteration: 2500; steps_avg: 9.34579439252[WARNING ] (MainProcess) Iteration: 2500; steps_std: 0.798395631184[WARNING ] (MainProcess) Iteration: 2500; reward_avg: 9.34579439252[WARNING ] (MainProcess) Iteration: 2500; reward_std: 0.798395631184[WARNING ] (MainProcess) Iteration: 2500; nepisodes: 107[WARNING ] (MainProcess) Iteration: 2500; nepisodes_solved: 106[WARNING ] (MainProcess) Iteration: 2500; repisodes_solved: 0.990654205607[WARNING ] (MainProcess) Saving Model    @ Step: 2500: /home/zhang/ws/17_ws/pytorch-rl/models/daim_17040900.pth ...[WARNING ] (MainProcess) Saved  Model    @ Step: 2500: /home/zhang/ws/17_ws/pytorch-rl/models/daim_17040900.pth.[WARNING ] (MainProcess) Resume Training @ Step: 2500
...

包括什么?

此仓库目前包含以下代理:

  • 深度学习(DQN)[1][2]

  • 双DQN [3]

  • 决斗网DQN(决斗DQN)[4]

  • Asynchronous Advantage Actor-Critic(A3C)(具有离散/连续动作空间支持)[5][6]

  • 具有经验重放(ACER)的示例高效演员 - 评论家(目前具有离散动作空间支持(截断重要性采样,一阶TRPO))[7][8]

正在进行的工作: - 测试ACER

未来计划: - 深度确定性政策梯度(DDPG)[9][10] - 连续DQN(CDQN或NAF)[11]

代码结构和命名约定:

注意:我们遵循pytorch-dnc的确切代码结构,以便使代码易于移植。*./utils/factory.py

我们建议用户参考./utils/factory.py,在这里我们列出所有的集成EnvModel, MemoryAgentDict的。所有这四个核心类都在实现./core/工厂模式./utils/factory.py使代码超级干净,无论Agent您想要训练什么类型,或者您想要训练哪种类型Env,您只需要修改一些参数./utils/options.py,然后就./main.py可以完成所有操作(注意:此./main.py文件永远不需要修改)。* namings为了使代码更加干净和可读,我们使用以下模式命名变量(主要在继承Agent模式中):* *_vbtorch.autograd.Variable或这些对象的列表* *_tstorch.Tensor或者是这样的对象列表*否则:普通的python数据类型

依赖


怎么运行:

您只需要修改一些参数./utils/options.py来训练新配置。

  • 配置您的培训./utils/options.py

    • line 14:一个条目添加到CONFIGS定义你的训练(agent_typeenv_typegamemodel_typememory_type

    • line 33:选择刚刚添加的条目

    • line 29-30:填写您的机/集群ID( MACHINE和时间戳(TIMESTAMP)来定义你的训练特征(MACHINE_TIMESTAMP),相应的模型文件,本次培训的日志文件将这个签名(下保存./models/MACHINE_TIMESTAMP.pth./logs/MACHINE_TIMESTAMP.log分别)。另外,visdom可视化将这个签名下显示(首先在bash型激活visdom服务器:python -m visdom.server &,然后在浏览器中打开此地址:http://localhost:8097/env/MACHINE_TIMESTAMP

    • line 32:培训模型,设置mode=1(培训可视化将在下http://localhost:8097/env/MACHINE_TIMESTAMP); 要测试当前培训的模型,您需要做的就是设置mode=2(测试可视化将在下面http://localhost:8097/env/MACHINE_TIMESTAMP_test)。

  • 跑:

    python main.py


奖金脚本:)

我们还提供了2个附加脚本,用于在培训后快速评估您的结果。(依赖条件:LMJ积)* plot.sh(例如,从日志文件中的情节:logs/machine1_17080801.log

  • ./plot.sh machine1 17080801

  • 生成的数字将保存到 figs/machine1_17080801/

  • plot_compare.sh(例如,比较日志文件:logs/machine1_17080801.loglogs/machine2_17080802.log./plot.sh 00 machine1 17080801 machine2 17080802

  • 生成的数字将保存到 figs/compare_00/

  • 颜色编码将按以下顺序排列: red green blue magenta yellow cyan


我们在开发这个回购时提到的回购:


引文

如果您发现此库很有用并想引用它,则以下内容是合适的:

@misc{pytorch-rl,
  author = {Zhang, Jingwei and Tai, Lei},
  title = {jingweiz/pytorch-rl},
  url = {https://github.com/jingweiz/pytorch-rl},
  year = {2017}
}


上一篇:R-FCN

下一篇:Ultrasound nerve segmentation

用户评价
全部评价

热门资源

  • Keras-ResNeXt

    Keras ResNeXt Implementation of ResNeXt models...

  • seetafaceJNI

    项目介绍 基于中科院seetaface2进行封装的JAVA...

  • spark-corenlp

    This package wraps Stanford CoreNLP annotators ...

  • capsnet-with-caps...

    CapsNet with capsule-wise convolution Project ...

  • inferno-boilerplate

    This is a very basic boilerplate example for pe...