pytorch-lightning
The lightweight PyTorch wrapper for ML researchers. Scale your models. Write less boilerplate.
Simple installation from PyPI
pip install pytorch-lightning
Lightning is a very lightweight wrapper on PyTorch. This means you don't have to learn a new library. To use Lightning, simply refactor your research code into the LightningModule format and Lightning will automate the rest. Lightning guarantees tested, correct, modern best practices for the automated parts.
Use our seed-project aimed at reproducibility!
Every research project starts the same, a model, a training loop, validation loop, etc. As your research advances, you're likely to need distributed training, 16-bit precision, checkpointing, gradient accumulation, etc.
Lightning sets up all the boilerplate state-of-the-art training for you so you can focus on the research.
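For instance, these features are typically just Trainer flags. A minimal sketch: the gpus and nb_gpu_nodes flags appear in the examples below, while use_amp and accumulate_grad_batches are assumed from the 0.5.x-era Trainer API.

from pytorch_lightning import Trainer

# hedged sketch of turning on these features via Trainer flags
trainer = Trainer(
    gpus=8, nb_gpu_nodes=4,      # distributed training across 4 nodes
    use_amp=True,                # 16-bit precision (assumed flag, via NVIDIA apex)
    accumulate_grad_batches=4,   # gradient accumulation (assumed flag)
)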
Think about Lightning as refactoring your research code instead of using a new framework. The research code goes into a LightningModule which you fit using a Trainer.
The LightningModule defines a system such as seq-2-seq, GAN, etc. It can ALSO define a simple classifier, such as the example below.
To use Lightning, do 2 things: define a LightningModule (below), then fit it with a Trainer.
WARNING: This syntax is for version 0.5.0+ where abbreviations were removed.
import os

import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
import torchvision.transforms as transforms

import pytorch_lightning as pl


class CoolSystem(pl.LightningModule):

    def __init__(self):
        super(CoolSystem, self).__init__()
        # not the best model...
        self.l1 = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_nb):
        # REQUIRED
        x, y = batch
        y_hat = self.forward(x)
        loss = F.cross_entropy(y_hat, y)
        tensorboard_logs = {'train_loss': loss}
        return {'loss': loss, 'log': tensorboard_logs}

    def validation_step(self, batch, batch_nb):
        # OPTIONAL
        x, y = batch
        y_hat = self.forward(x)
        return {'val_loss': F.cross_entropy(y_hat, y)}

    def validation_end(self, outputs):
        # OPTIONAL
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        tensorboard_logs = {'val_loss': avg_loss}
        return {'avg_val_loss': avg_loss, 'log': tensorboard_logs}

    def configure_optimizers(self):
        # REQUIRED
        # can return multiple optimizers and learning-rate schedulers
        # (LBFGS is automatically supported, no need for a closure function)
        return torch.optim.Adam(self.parameters(), lr=0.02)

    @pl.data_loader
    def train_dataloader(self):
        # REQUIRED
        return DataLoader(MNIST(os.getcwd(), train=True, download=True,
                                transform=transforms.ToTensor()), batch_size=32)

    @pl.data_loader
    def val_dataloader(self):
        # OPTIONAL
        return DataLoader(MNIST(os.getcwd(), train=True, download=True,
                                transform=transforms.ToTensor()), batch_size=32)

    @pl.data_loader
    def test_dataloader(self):
        # OPTIONAL
        return DataLoader(MNIST(os.getcwd(), train=True, download=True,
                                transform=transforms.ToTensor()), batch_size=32)
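As the comment in configure_optimizers notes, you can return multiple optimizers and learning-rate schedulers. A minimal sketch of that return format, assuming a GAN-style module with hypothetical generator and discriminator attributes and the 0.5.x-era two-list convention:

def configure_optimizers(self):
    # one optimizer per sub-network, plus an LR scheduler for the generator
    opt_g = torch.optim.Adam(self.generator.parameters(), lr=0.0002)
    opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=0.0002)
    sched_g = torch.optim.lr_scheduler.StepLR(opt_g, step_size=10)
    return [opt_g, opt_d], [sched_g]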
Fit with a trainer
from pytorch_lightning import Trainer

model = CoolSystem()

# most basic trainer, uses good defaults
trainer = Trainer()
trainer.fit(model)
The Trainer sets up a TensorBoard logger, early stopping, and checkpointing by default (you can modify all of them or use something other than TensorBoard).
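A hedged sketch of overriding those defaults; the EarlyStopping/ModelCheckpoint callbacks and the early_stop_callback/checkpoint_callback Trainer arguments are assumed from the 0.5.x-era API:

import os
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

# stop when val_loss stops improving, and keep the best checkpoint
early_stop = EarlyStopping(monitor='val_loss', patience=3, mode='min')
checkpoint = ModelCheckpoint(filepath=os.getcwd(), monitor='val_loss', mode='min')

trainer = Trainer(early_stop_callback=early_stop, checkpoint_callback=checkpoint)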
Here are more advanced examples
import os

# train on cpu using only 10% of the data (for demo purposes)
trainer = Trainer(max_nb_epochs=1, train_percent_check=0.1)

# train on 4 gpus (lightning chooses GPUs for you)
# trainer = Trainer(max_nb_epochs=1, gpus=4)

# train on 4 gpus (you choose GPUs)
# trainer = Trainer(max_nb_epochs=1, gpus=[0, 1, 3, 7])

# train on 32 gpus across 4 nodes (make sure to submit appropriate SLURM job)
# trainer = Trainer(max_nb_epochs=1, gpus=8, nb_gpu_nodes=4)

# train (1 epoch only here for demo)
trainer.fit(model)

# view tensorboard logs
print('View tensorboard logs by running\ntensorboard --logdir %s' % os.getcwd())
print('and going to http://localhost:6006 on your browser')
When you're all done you can even run the test set separately.
trainer.test()
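For the test loop to do anything, define the test hooks in your LightningModule; they mirror the validation hooks. A minimal sketch, assuming the 0.5.x-era test_step/test_end hook names:

def test_step(self, batch, batch_nb):
    # OPTIONAL: mirrors validation_step
    x, y = batch
    y_hat = self.forward(x)
    return {'test_loss': F.cross_entropy(y_hat, y)}

def test_end(self, outputs):
    # OPTIONAL: mirrors validation_end
    avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
    return {'avg_test_loss': avg_loss, 'log': {'test_loss': avg_loss}}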
Lightning handles everything else in the training loop for you! You define the research-specific parts using the LightningModule interface:
# what to do in the training loop
def training_step(self, batch, batch_nb):

# what to do in the validation loop
def validation_step(self, batch, batch_nb):

# how to aggregate validation_step outputs
def validation_end(self, outputs):

# and your dataloaders
def train_dataloader(self):
def val_dataloader(self):
def test_dataloader(self):
Could be as complex as seq-2-seq + attention
# define what happens for training here
def training_step(self, batch, batch_nb):
    x, y = batch

    # define your own forward and loss calculation
    hidden_states = self.encoder(x)

    # even as complex as a seq-2-seq + attn model
    # (this is just a toy, non-working example to illustrate)
    start_token = '<SOS>'
    last_hidden = torch.zeros(...)
    loss = 0
    for step in range(max_seq_len):
        attn_context = self.attention_nn(hidden_states, start_token)
        pred = self.decoder(start_token, attn_context, last_hidden)
        last_hidden = pred
        pred = self.predict_nn(pred)
        loss += self.loss(last_hidden, y[step])  # toy example as well

    loss = loss / max_seq_len
    return {'loss': loss}
Or as basic as CNN image classification
# define what happens for validation here
def validation_step(self, batch, batch_nb):
    x, y = batch

    # or as basic as a CNN classification
    out = self.forward(x)
    loss = my_loss(out, y)
    return {'loss': loss}
And you also decide how to collate the output of all validation steps
def validation_end(self, outputs):
    """
    Called at the end of validation to aggregate outputs
    :param outputs: list of individual outputs of each validation step
    :return:
    """
    val_loss_mean = 0
    val_acc_mean = 0
    for output in outputs:
        val_loss_mean += output['val_loss']
        val_acc_mean += output['val_acc']

    val_loss_mean /= len(outputs)
    val_acc_mean /= len(outputs)
    logs = {'val_loss': val_loss_mean.item(), 'val_acc': val_acc_mean.item()}
    result = {'log': logs}
    return result
Lightning is fully integrated with TensorBoard and MLflow, and supports any other logging module.
Lightning also adds a text column with all the hyperparameters for this experiment.
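A hedged sketch of swapping in MLflow; the MLFlowLogger import path, its constructor arguments, and the Trainer logger argument are assumed from the 0.5.x-era API:

from pytorch_lightning import Trainer
from pytorch_lightning.logging import MLFlowLogger

# log to a local mlruns/ directory instead of tensorboard
mlf_logger = MLFlowLogger(experiment_name='default', tracking_uri='file:./mlruns')
trainer = Trainer(logger=mlf_logger)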
Welcome to the Lightning community!
If you have any questions, feel free to:
Ask on Stack Overflow with the tag pytorch-lightning.
If no one replies quickly enough, feel free to post the Stack Overflow link in our Gitter chat!
To chat with the rest of us, visit our Gitter channel!
How do I use Lightning for rapid research?
Here's a walk-through
Why was Lightning created?
Lightning was designed with 3 goals in mind:
Maximal flexibility while abstracting out the common boilerplate across research projects.
Reproducibility. If all projects use the LightningModule template, it will be much much easier to understand what's going on and where to look! It will also mean every implementation follows a standard format.
Democratizing PyTorch power-user features. Distributed training? 16-bit precision? Know you need them but don't want to take the time to implement them? All good... these come built into Lightning.
How does Lightning compare with Ignite and fast.ai?
Here's a thorough comparison.
Is this another library I have to learn?
Nope! We use pure PyTorch everywhere and don't add unnecessary abstractions!
Are there plans to support Python 2?
Nope.
Are there plans to support virtualenv?
Nope. Please use Anaconda or Miniconda.
Which PyTorch versions do you support?
PyTorch 1.1.0
# install pytorch 1.1.0 using the official instructions

# install test-tube 0.6.7.6 which supports 1.1.0
pip install test-tube==0.6.7.6

# install latest Lightning version without upgrading deps
pip install -U --no-deps pytorch-lightning
PyTorch 1.2.0
Install via pip as normal.
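That is, the same command as in the installation section above:

pip install pytorch-lightning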
If you can't wait for the next release, install the most up to date code with:
using git (locally clones the whole repo with full history)
pip install git+https://github.com/williamFalcon/pytorch-lightning.git@master --upgrade
using a zip archive (the latest state of the repo, without git history)
pip install https://github.com/williamFalcon/pytorch-lightning/archive/master.zip --upgrade
You can also install any past release from this repository:
pip install https://github.com/williamFalcon/pytorch-lightning/archive/0.4.4.zip --upgrade