pytorch-rl
Sample tests of trained agents (DQN on Breakout, A3C on Pong, Double DQN on CartPole, continuous A3C on InvertedPendulum (MuJoCo)):
Sample online plotting while training an A3C agent on Pong (with 16 learner processes):
Sample logging output while training a DQN agent on CartPole (we currently use WARNING as the logging level to suppress the INFO printouts from visdom):
[WARNING ] (MainProcess) <===================================>
[WARNING ] (MainProcess) bash$: python -m visdom.server
[WARNING ] (MainProcess) http://localhost:8097/env/daim_17040900
[WARNING ] (MainProcess) <===================================> DQN
[WARNING ] (MainProcess) <-----------------------------------> Env
[WARNING ] (MainProcess) Creating {gym | CartPole-v0} w/ Seed: 123
[INFO ] (MainProcess) Making new env: CartPole-v0
[WARNING ] (MainProcess) Action Space: [0, 1]
[WARNING ] (MainProcess) State Space: 4
[WARNING ] (MainProcess) <-----------------------------------> Model
[WARNING ] (MainProcess) MlpModel (
  (fc1): Linear (4 -> 16)
  (rl1): ReLU ()
  (fc2): Linear (16 -> 16)
  (rl2): ReLU ()
  (fc3): Linear (16 -> 16)
  (rl3): ReLU ()
  (fc4): Linear (16 -> 2)
)
[WARNING ] (MainProcess) No Pretrained Model. Will Train From Scratch.
[WARNING ] (MainProcess) <===================================> Training ...
[WARNING ] (MainProcess) Validation Data @ Step: 501
[WARNING ] (MainProcess) Start Training @ Step: 501
[WARNING ] (MainProcess) Reporting @ Step: 2500 | Elapsed Time: 5.32397913933
[WARNING ] (MainProcess) Training Stats: epsilon: 0.972
[WARNING ] (MainProcess) Training Stats: total_reward: 2500.0
[WARNING ] (MainProcess) Training Stats: avg_reward: 21.7391304348
[WARNING ] (MainProcess) Training Stats: nepisodes: 115
[WARNING ] (MainProcess) Training Stats: nepisodes_solved: 114
[WARNING ] (MainProcess) Training Stats: repisodes_solved: 0.991304347826
[WARNING ] (MainProcess) Evaluating @ Step: 2500
[WARNING ] (MainProcess) Iteration: 2500; v_avg: 1.73136949539
[WARNING ] (MainProcess) Iteration: 2500; tderr_avg: 0.0964358523488
[WARNING ] (MainProcess) Iteration: 2500; steps_avg: 9.34579439252
[WARNING ] (MainProcess) Iteration: 2500; steps_std: 0.798395631184
[WARNING ] (MainProcess) Iteration: 2500; reward_avg: 9.34579439252
[WARNING ] (MainProcess) Iteration: 2500; reward_std: 0.798395631184
[WARNING ] (MainProcess) Iteration: 2500; nepisodes: 107
[WARNING ] (MainProcess) Iteration: 2500; nepisodes_solved: 106
[WARNING ] (MainProcess) Iteration: 2500; repisodes_solved: 0.990654205607
[WARNING ] (MainProcess) Saving Model @ Step: 2500: /home/zhang/ws/17_ws/pytorch-rl/models/daim_17040900.pth ...
[WARNING ] (MainProcess) Saved Model @ Step: 2500: /home/zhang/ws/17_ws/pytorch-rl/models/daim_17040900.pth.
[WARNING ] (MainProcess) Resume Training @ Step: 2500 ...
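Running at the WARNING level, as described above, can be reproduced with Python's standard logging module. The sketch below is an assumption, not the repo's code: the format string is reverse-engineered from the sample log, and some_third_party_lib is a hypothetical stand-in for a chatty library logger (it does not claim to be visdom's actual logger name).

```python
import io
import logging

# Format string guessed from the sample log above; not taken from the repo.
LOG_FORMAT = "[%(levelname)-8s] (%(processName)s) %(message)s"

def setup_logging(stream, level=logging.WARNING):
    """Route records through one root handler and drop anything below `level`."""
    root = logging.getLogger()
    root.setLevel(level)  # child loggers with no level of their own inherit this
    handler = logging.StreamHandler(stream)
    handler.setFormatter(logging.Formatter(LOG_FORMAT))
    root.addHandler(handler)
    return handler

buf = io.StringIO()
setup_logging(buf)

# Hypothetical chatty third-party logger: its INFO record is filtered out.
logging.getLogger("some_third_party_lib").info("connection established")
# Messages the agent wants to show are logged at WARNING, so they pass.
logging.getLogger("agent").warning("State Space: 4")

# Only the WARNING line reaches buf.
print(buf.getvalue(), end="")
```

When run as a plain script, the single line written to buf has the same shape as the sample log lines, e.g. [WARNING ] (MainProcess) State Space: 4.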
This repo currently contains the following agents:
Double DQN [3]
Dueling network DQN (Dueling DQN) [4]
Asynchronous Advantage Actor-Critic (A3C) (w/ both discrete/continuous action space support) [5], [6]
Sample Efficient Actor-Critic with Experience Replay (ACER) (currently w/ discrete action space support only; uses Truncated Importance Sampling and 1st Order TRPO) [7], [8]
Work in progress:
- Testing ACER

Future Plans:
- Deep Deterministic Policy Gradient (DDPG) [9], [10]
- Continuous DQN (CDQN or NAF) [11]
NOTE: we follow exactly the same code structure as pytorch-dnc, so that the code is easily transplantable.
* ./utils/factory.py
We suggest that users refer to ./utils/factory.py, where all the integrated Env, Model, Memory and Agent classes are listed in Dicts. All four of these core classes are implemented in ./core/. The factory pattern in ./utils/factory.py keeps the code clean: no matter which type of Agent you want to train, or which type of Env you want to train on, all you need to do is modify some parameters in ./utils/options.py, and ./main.py will do the rest (NOTE: ./main.py itself never needs to be modified).
* namings
To make the code cleaner and more readable, we name the variables (mainly in the inherited Agents) using the following pattern:
  * *_vb: torch.autograd.Variables or lists of such objects
  * *_ts: torch.Tensors or lists of such objects
  * otherwise: normal Python datatypes
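The factory pattern described above can be sketched as one registry dict per core class plus a builder that only reads option strings. This is a minimal illustration under stated assumptions, not the repo's factory.py: the key strings ("gym", "mlp", "sequential", "dqn"), the Options fields, and the names GymEnv, SequentialMemory, DQNAgent and build_agent are hypothetical; only the four Dict categories and MlpModel come from the text and log above.

```python
from dataclasses import dataclass

# Hypothetical stand-ins for the core classes that live in ./core/.
class Env:
    def __init__(self, game):
        self.game = game

class Model:
    pass

class Memory:
    pass

class Agent:
    def __init__(self, env, model, memory):
        self.env, self.model, self.memory = env, model, memory

class GymEnv(Env): pass
class MlpModel(Model): pass
class SequentialMemory(Memory): pass
class DQNAgent(Agent): pass

# One registry per core class: config string -> concrete class.
EnvDict = {"gym": GymEnv}
ModelDict = {"mlp": MlpModel}
MemoryDict = {"sequential": SequentialMemory}
AgentDict = {"dqn": DQNAgent}

@dataclass
class Options:
    agent_type: str
    env_type: str
    game: str
    model_type: str
    memory_type: str

def build_agent(opt):
    """main.py-style wiring: only the option strings change between runs."""
    env = EnvDict[opt.env_type](opt.game)
    model = ModelDict[opt.model_type]()
    memory = MemoryDict[opt.memory_type]()
    return AgentDict[opt.agent_type](env, model, memory)

agent = build_agent(Options("dqn", "gym", "CartPole-v0", "mlp", "sequential"))
print(type(agent).__name__, agent.env.game)  # DQNAgent CartPole-v0
```

Adding a new agent or environment in this style means registering one more class in the relevant dict; the builder itself stays untouched, which is the property the text attributes to ./main.py.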
You only need to modify some parameters in ./utils/options.py to train a new configuration.
Configure your training in ./utils/options.py:
line 14: add an entry into CONFIGS to define your training (agent_type, env_type, game, model_type, memory_type)
line 33: choose the entry you just added
lines 29-30: fill in your machine/cluster ID (MACHINE) and timestamp (TIMESTAMP) to define your training signature (MACHINE_TIMESTAMP). The model file and the log file of this training will be saved under this signature (./models/MACHINE_TIMESTAMP.pth and ./logs/MACHINE_TIMESTAMP.log, respectively). The visdom visualization will also be displayed under this signature: first start the visdom server by typing python -m visdom.server & in bash, then open http://localhost:8097/env/MACHINE_TIMESTAMP in your browser.
line 32: to train a model, set mode=1 (training visualization will be under http://localhost:8097/env/MACHINE_TIMESTAMP); to test the model from this training, simply set mode=2 (testing visualization will be under http://localhost:8097/env/MACHINE_TIMESTAMP_test).
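The options described above boil down to choosing a CONFIGS entry and deriving every output location from MACHINE_TIMESTAMP and the mode. The sketch below is a hypothetical mirror of those settings, not the actual options.py: the variable names CONFIG and MODE, the CONFIGS entry values, and run_signature are assumptions; the daim/17040900 values come from the sample log.

```python
# Hypothetical mirror of the knobs in ./utils/options.py; the exact variable
# names and CONFIGS entries in the repo may differ.
CONFIGS = [
    # agent_type, env_type, game,          model_type, memory_type
    ["dqn",       "gym",    "CartPole-v0", "mlp",      "sequential"],
]
CONFIG = 0             # index of the CONFIGS entry to train
MACHINE = "daim"       # machine/cluster ID
TIMESTAMP = "17040900"
MODE = 1               # 1: train, 2: test

VISDOM_ENV_URL = "http://localhost:8097/env/"

def run_signature(machine, timestamp, mode):
    """Derive the model/log paths and visdom env URL from MACHINE_TIMESTAMP."""
    if mode not in (1, 2):
        raise ValueError("mode must be 1 (train) or 2 (test)")
    refs = f"{machine}_{timestamp}"
    env_name = refs if mode == 1 else f"{refs}_test"
    return {
        "model_file": f"./models/{refs}.pth",
        "log_file": f"./logs/{refs}.log",
        "visdom_url": VISDOM_ENV_URL + env_name,
    }

agent_type, env_type, game, model_type, memory_type = CONFIGS[CONFIG]
print(agent_type, game)  # dqn CartPole-v0
print(run_signature(MACHINE, TIMESTAMP, MODE)["visdom_url"])
# -> http://localhost:8097/env/daim_17040900
```

Because training and testing share the same signature and differ only in the _test suffix of the visdom env, switching between them only requires flipping the mode.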
Run:
python main.py
We also provide 2 additional scripts for quickly evaluating your results after training (dependencies: lmj-plot):
* plot.sh (e.g., plot from the log file logs/machine1_17080801.log):
./plot.sh machine1 17080801
The generated figures will be saved into figs/machine1_17080801/
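Plotting from a log file starts with pulling the per-iteration evaluation metrics out of lines like those in the sample log above. The parser below is a hypothetical sketch of that step, not what plot.sh actually does (the script's internals are not shown here); the regex and the function name parse_eval_log are assumptions based only on the sample log format.

```python
import re
from collections import defaultdict

# Evaluation lines look like:
# "[WARNING ] (MainProcess) Iteration: 2500; reward_avg: 9.34579439252"
EVAL_RE = re.compile(r"Iteration:\s*(\d+);\s*(\w+):\s*([-+0-9.eE]+)")

def parse_eval_log(lines):
    """Return {metric: [(step, value), ...]} for every 'Iteration: N; name: value' line."""
    series = defaultdict(list)
    for line in lines:
        match = EVAL_RE.search(line)
        if match:
            step, name, value = match.groups()
            series[name].append((int(step), float(value)))
    return dict(series)

sample = [
    "[WARNING ] (MainProcess) Iteration: 2500; reward_avg: 9.34579439252",
    "[WARNING ] (MainProcess) Iteration: 2500; nepisodes: 107",
    "[WARNING ] (MainProcess) Evaluating @ Step: 2500",  # no metric, skipped
]
print(parse_eval_log(sample))
```

On a real run the same function can be fed a file object, e.g. parse_eval_log(open("logs/machine1_17080801.log")), and each resulting series plotted against its step values.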
* plot_compare.sh (e.g., compare the log files logs/machine1_17080801.log and logs/machine2_17080802.log):
./plot_compare.sh 00 machine1 17080801 machine2 17080802
The generated figures will be saved into figs/compare_00/
The color coding will be in the order: red, green, blue, magenta, yellow, cyan
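The comparison script takes a comparison ID followed by MACHINE TIMESTAMP pairs, and assigns colors to the runs in the stated order. The helper below is a hypothetical sketch of that argument-to-output mapping, not the script's implementation: compare_plan is an invented name, and rejecting more than six runs is an assumption (the script may handle extra runs differently).

```python
# Color order stated in the README; everything else here is a hypothetical sketch.
COLORS = ["red", "green", "blue", "magenta", "yellow", "cyan"]

def compare_plan(compare_id, *runs):
    """Map an ID plus MACHINE TIMESTAMP pairs to an output dir and (log file, color) pairs."""
    if len(runs) % 2 != 0:
        raise ValueError("expected MACHINE TIMESTAMP pairs")
    pairs = list(zip(runs[0::2], runs[1::2]))
    if len(pairs) > len(COLORS):
        # Assumption: this sketch refuses to reuse colors.
        raise ValueError(f"at most {len(COLORS)} runs can be color-coded")
    plan = [(f"logs/{machine}_{stamp}.log", COLORS[i])
            for i, (machine, stamp) in enumerate(pairs)]
    return f"figs/compare_{compare_id}/", plan

out_dir, plan = compare_plan("00", "machine1", "17080801", "machine2", "17080802")
print(out_dir)
for log_file, color in plan:
    print(color, log_file)
```

With the example arguments, machine1's log is drawn in red and machine2's in green, and both figures land in figs/compare_00/.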
And a private implementation of A3C from @stokasto
If you find this library useful and would like to cite it, the following would be appropriate:
@misc{pytorch-rl,
  author = {Zhang, Jingwei and Tai, Lei},
  title = {jingweiz/pytorch-rl},
  url = {https://github.com/jingweiz/pytorch-rl},
  year = {2017}
}