pywick
Pywick is a high-level PyTorch training framework that aims to get you up and running quickly with state-of-the-art neural networks. Does the world need another PyTorch framework? Probably not. But we started this project when no good frameworks were available, and it just kept growing. So here we are.
Pywick tries to stay on the bleeding edge of research into neural networks. If you just wish to run a vanilla CNN, this is probably going to be overkill. However, if you want to get lost in the world of neural networks, fine-tuning and hyperparameter optimization for months on end, then this is probably the right place for you :)
Among other things, Pywick includes:
State-of-the-art normalization, activation, loss functions and optimizers not included in the standard PyTorch library.
A high-level module for training with callbacks, constraints, metrics, conditions and regularizers.
Dozens of popular object classification and semantic segmentation models.
Comprehensive data loading, augmentation, transforms, and sampling capability.
Utility tensor functions.
Useful meters.
Basic GridSearch (exhaustive and random).
Hey, check this out, we now have docs! They're still a work in progress, though, so apologies for anything that's broken.
Aug. 1, 2019
New segmentation NNs: BiSeNet, DANet, DenseASPP, DUNet, OCNet, PSANet
New Loss Functions: Focal Tversky Loss, OHEM CrossEntropy Loss, various combination losses
Major restructuring and standardization of NN models and loading functionality
General bug fixes and code improvements
pip install pywick
or a specific version from git:
pip install git+https://github.com/achaiah/pywick.git@v0.5.3
The ModuleTrainer class provides a high-level training interface which abstracts away the training loop while providing callbacks, constraints, initializers, regularizers, and more.
Example:
from pywick.modules import ModuleTrainer
from pywick.initializers import XavierUniform
from pywick.metrics import CategoricalAccuracySingleInput
import torch.nn as nn
import torch.nn.functional as F

# Define your model EXACTLY as normal
class Network(nn.Module):
    def __init__(self):
        super(Network, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.fc1 = nn.Linear(1600, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = x.view(-1, 1600)  # 64 channels * 5 * 5 for 28x28 inputs
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

model = Network()
trainer = ModuleTrainer(model)  # optionally supply cuda_devices as a parameter

initializers = [XavierUniform(bias=False, module_filter='fc*')]

# initialize metrics with top1 and top5
metrics = [CategoricalAccuracySingleInput(top_k=1), CategoricalAccuracySingleInput(top_k=5)]

trainer.compile(loss='cross_entropy',
                # callbacks=callbacks,        # define your callbacks here (e.g. model saver, LR scheduler)
                # regularizers=regularizers,  # define regularizers
                # constraints=constraints,    # define constraints
                optimizer='sgd',
                initializers=initializers,
                metrics=metrics)

# train_dataset_loader and val_dataset_loader are your own DataLoaders (not defined here)
trainer.fit_loader(train_dataset_loader, val_loader=val_dataset_loader, num_epoch=20, verbose=1)
You also have access to the standard evaluation and prediction functions:
loss = trainer.evaluate(x_train, y_train)
y_pred = trainer.predict(x_train)
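To make concrete what the trainer abstracts away, here is a framework-free sketch of a training loop with callback hooks. It fits a one-parameter linear model by gradient descent in plain Python. `Callback`, `HistoryCallback` and `fit_linear` are hypothetical names used only for illustration; they are not pywick's API, and ModuleTrainer additionally wires in metrics, initializers, regularizers and constraints.

```python
# Hypothetical sketch of a callback-driven training loop (not pywick code).

class Callback:
    """Base class: subclasses override only the hooks they need."""
    def on_epoch_end(self, epoch, logs):
        pass


class HistoryCallback(Callback):
    """Records the loss reported at the end of every epoch."""
    def __init__(self):
        self.losses = []

    def on_epoch_end(self, epoch, logs):
        self.losses.append(logs["loss"])


def fit_linear(xs, ys, num_epoch=50, lr=0.01, callbacks=()):
    """Fit y = w * x by full-batch gradient descent on mean squared error."""
    w = 0.0
    n = len(xs)
    for epoch in range(num_epoch):
        # gradient of mean((w*x - y)^2) with respect to w
        grad = sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / n
        w -= lr * grad
        loss = sum((w * x - y) ** 2 for x, y in zip(xs, ys)) / n
        # the "abstracted" part: every callback sees the epoch's logs
        for cb in callbacks:
            cb.on_epoch_end(epoch, {"loss": loss})
    return w


history = HistoryCallback()
w = fit_linear([1, 2, 3, 4], [2, 4, 6, 8], num_epoch=200, lr=0.01, callbacks=[history])
print(round(w, 4), len(history.losses))
```

The user-facing code only supplies data, a model and a list of callbacks; the loop itself stays inside the trainer.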
PyWick provides a wide range of callbacks, generally mimicking the interface found in Keras:
CSVLogger - Logs epoch-level metrics to a CSV file
CyclicLRScheduler - Cycles through min-max learning rate
EarlyStopping - Provides ability to stop training early based on supplied criteria
History - Keeps history of metrics etc. during the learning process
LambdaCallback - Allows you to implement your own callbacks on the fly
LRScheduler - Simple learning rate scheduler based on function or supplied schedule
ModelCheckpoint - Comprehensive model saver
ReduceLROnPlateau - Reduces learning rate (LR) when a plateau has been reached
SimpleModelCheckpoint - Simple model saver
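CyclicLRScheduler cycles the learning rate between a minimum and a maximum. A common schedule for this is the triangular policy from Leslie Smith's cyclical learning rate work; the sketch below implements that formula in plain Python as an illustration of the general idea, not as pywick's code. `triangular_lr` and its default values are hypothetical.

```python
import math


def triangular_lr(iteration, base_lr=0.001, max_lr=0.006, step_size=2000):
    """Triangular cyclical learning rate (hypothetical helper).

    The rate climbs linearly from base_lr to max_lr over step_size
    iterations, then falls back to base_lr over the next step_size.
    """
    cycle = math.floor(1 + iteration / (2 * step_size))
    x = abs(iteration / step_size - 2 * cycle + 1)
    return base_lr + (max_lr - base_lr) * max(0.0, 1 - x)


# One full cycle spans 2 * step_size iterations.
for it in (0, 1000, 2000, 3000, 4000):
    print(it, round(triangular_lr(it), 6))
```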
Additionally, a TensorboardLogger is incredibly easy to implement via TensorboardX (now part of the PyTorch 1.1 release!).
from pywick.callbacks import EarlyStopping

callbacks = [EarlyStopping(monitor='val_loss', patience=5)]
trainer.set_callbacks(callbacks)
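The EarlyStopping example above monitors `val_loss` with `patience=5`. The bookkeeping such a callback typically does is sketched below in plain Python. `EarlyStopper` and its `min_delta` argument are hypothetical, and pywick's own EarlyStopping may differ in details such as how ties are counted.

```python
class EarlyStopper:
    """Hypothetical sketch: signal a stop once the monitored value has
    failed to improve by more than min_delta for `patience` epochs in a row."""

    def __init__(self, patience=5, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")  # lower is better (e.g. a loss)
        self.wait = 0             # epochs since the last improvement

    def step(self, value):
        """Record one epoch's value; return True if training should stop."""
        if value < self.best - self.min_delta:
            self.best = value
            self.wait = 0
        else:
            self.wait += 1
        return self.wait >= self.patience


val_losses = [1.0, 0.8, 0.7, 0.71, 0.72, 0.70, 0.75]
stopper = EarlyStopper(patience=3)
for epoch, loss in enumerate(val_losses):
    if stopper.step(loss):
        print(f"stopping after epoch {epoch}")
        break
```

Note that a value equal to the best so far (0.70 here) does not count as an improvement in this sketch.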
PyWick also provides regularizers:
L1Regularizer
L2Regularizer
L1L2Regularizer
and constraints:
UnitNorm
MaxNorm
NonNeg
Both regularizers and constraints can be selectively applied to layers using wildcard patterns (e.g. '*fc*') passed as the module_filter argument. Constraints can be explicit (hard) constraints applied at an arbitrary batch or epoch frequency, or they can be implicit (soft) constraints similar to regularizers, where the constraint deviation is added as a penalty to the total model loss.
from pywick.constraints import MaxNorm, NonNeg
from pywick.regularizers import L1Regularizer

# hard constraint applied every 5 batches
hard_constraint = MaxNorm(value=2., frequency=5, unit='batch', module_filter='*fc*')
# implicit constraint added as a penalty term to model loss
soft_constraint = NonNeg(lagrangian=True, scale=1e-3, module_filter='*fc*')
constraints = [hard_constraint, soft_constraint]
trainer.set_constraints(constraints)

regularizers = [L1Regularizer(scale=1e-4, module_filter='*conv*')]
trainer.set_regularizers(regularizers)
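Regularizers and constraints both reduce to simple arithmetic on the weights. The sketch below works in plain Python on lists of weights and shows an L1/L2 penalty, a hard max-norm projection and one possible form of a soft non-negativity penalty. The function names, the per-row norm used by `max_norm` and the exact formula in `non_neg_penalty` are assumptions for illustration; pywick's classes operate on module weight tensors and may be defined differently.

```python
def l1_penalty(weights, scale):
    """L1 regularization term added to the loss: scale * sum(|w|)."""
    return scale * sum(abs(w) for w in weights)


def l2_penalty(weights, scale):
    """L2 regularization term added to the loss: scale * sum(w^2)."""
    return scale * sum(w * w for w in weights)


def max_norm(rows, value):
    """Hard constraint (assumed per-row): rescale any row whose L2 norm
    exceeds `value` so that its norm becomes exactly `value`."""
    out = []
    for row in rows:
        norm = sum(w * w for w in row) ** 0.5
        factor = value / norm if norm > value else 1.0
        out.append([w * factor for w in row])
    return out


def non_neg_penalty(weights, scale):
    """Soft constraint (assumed form): penalize the magnitude of negative weights."""
    return scale * sum(-w for w in weights if w < 0)


# Soft terms are added to the data loss; hard constraints rewrite the weights.
weights = [0.5, -1.5, 2.0]
data_loss = 0.3  # illustrative value only
total_loss = data_loss + l1_penalty(weights, 1e-4) + non_neg_penalty(weights, 1e-3)
print(total_loss, max_norm([[3.0, 4.0], [0.3, 0.4]], 2.0))
```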
You can also fit directly on a torch.utils.data.DataLoader and have a validation set as well:
from pywick import TensorDataset
from torch.utils.data import DataLoader

train_dataset = TensorDataset(x_train, y_train)
train_loader = DataLoader(train_dataset, batch_size=32)

val_dataset = TensorDataset(x_val, y_val)
val_loader = DataLoader(val_dataset, batch_size=32)

trainer.fit_loader(train_loader, val_loader=val_loader, num_epoch=100)
All standard models from Pytorch:
Resnet + Swish
SE Inception
BiSeNet (Bilateral Segmentation Network for Real-time Semantic Segmentation)
Deeplab v2 (DeepLab: Semantic Image Segmentation with Deep Convolutional Nets, Atrous Convolution, and Fully Connected CRFs)
Deeplab v3 (Rethinking Atrous Convolution for Semantic Image Segmentation)
DenseASPP (DenseASPP for Semantic Segmentation in Street Scenes)
DRNNet (Dilated Residual Networks)
DUC, HDC (Understanding Convolution for Semantic Segmentation)
ENet (ENet: A Deep Neural Network Architecture for Real-Time Semantic Segmentation)
Vanilla FCN: FCN32, FCN16, FCN8, in the versions of VGG, ResNet and DenseNet respectively (Fully convolutional networks for semantic segmentation)
FRRN (Full Resolution Residual Networks for Semantic Segmentation in Street Scenes)
FusionNet (FusionNet in Tensorflow by Hyungjoo Andrew Cho)
GCN (Large Kernel Matters)
LinkNet (Link-Net)
PSPNet (Pyramid scene parsing network)
RefineNet (RefineNet)
SegNet (Segnet: A deep convolutional encoder-decoder architecture for image segmentation)
Tiramisu (The One Hundred Layers Tiramisu: Fully Convolutional DenseNets for Semantic Segmentation)
U-Net (U-net: Convolutional networks for biomedical image segmentation)
Additional variations of many of the above
Read the docs for useful details! Then dive in:
# use the `get_model` utility
from pywick.models.model_utils import get_model, ModelType

model = get_model(model_type=ModelType.CLASSIFICATION, model_name='resnet18', num_classes=1000, pretrained=True)
For a complete list of models (including many experimental ones), you can call the get_supported_models method, e.g. pywick.models.model_utils.get_supported_models(ModelType.SEGMENTATION).
The PyWick package provides a wide variety of data augmentation and transformation tools which can be applied during data loading. The package also provides the flexible TensorDataset, FolderDataset and MultiFolderDataset classes to handle most dataset needs.
AddChannel
ChannelsFirst
ChannelsLast
Compose
ExpandAxis
Pad
PadNumpy
RandomChoiceCompose
RandomCrop
RandomFlip
RandomOrder
RangeNormalize
Slice2D
SpecialCrop
StdNormalize
ToFile
ToNumpyType
ToTensor
Transpose
TypeCast
Brightness
Contrast
Gamma
Grayscale
RandomBrightness
RandomChoiceBrightness
RandomChoiceContrast
RandomChoiceGamma
RandomChoiceSaturation
RandomContrast
RandomGamma
RandomGrayscale
RandomSaturation
Saturation
RandomAffine
RandomChoiceRotate
RandomChoiceShear
RandomChoiceTranslate
RandomChoiceZoom
RandomRotate
RandomShear
RandomSquareZoom
RandomTranslate
RandomZoom
Rotate
Shear
Translate
Zoom
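Transforms like these are ordinary callables, and Compose and RandomChoiceCompose differ only in how they chain them: apply all in order, or pick one at random. A simplified sketch of the two patterns, using hypothetical function names `compose` and `random_choice_compose` that stand in for (but are not) pywick's classes:

```python
import random


def compose(*transforms):
    """Return a callable that applies every transform in order."""
    def apply(x):
        for t in transforms:
            x = t(x)
        return x
    return apply


def random_choice_compose(transforms, rng=random):
    """Return a callable that applies one transform chosen uniformly at random."""
    def apply(x):
        return rng.choice(transforms)(x)
    return apply


def flip(row):
    return row[::-1]


def double(row):
    return [2 * p for p in row]


pipeline = compose(flip, double)
print(pipeline([1, 2, 3]))  # [6, 4, 2]

augment = random_choice_compose([flip, double], rng=random.Random(0))
print(augment([1, 2, 3]))
```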
We also provide a class for stringing multiple affine transformations together so that only one interpolation takes place:
Affine
AffineCompose
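Combining affine transforms saves interpolations because affine maps compose into a single matrix: applying the product once is equivalent to applying each transform in turn, so the image only has to be resampled once. A plain-Python sketch with 3x3 homogeneous matrices; all helper names are hypothetical, and pywick's Affine/AffineCompose operate on image tensors rather than single points.

```python
import math


def matmul3(a, b):
    """Multiply two 3x3 matrices given as nested lists."""
    return [[sum(a[i][k] * b[k][j] for k in range(3)) for j in range(3)]
            for i in range(3)]


def rotation(deg):
    r = math.radians(deg)
    return [[math.cos(r), -math.sin(r), 0.0],
            [math.sin(r), math.cos(r), 0.0],
            [0.0, 0.0, 1.0]]


def translation(tx, ty):
    return [[1.0, 0.0, tx], [0.0, 1.0, ty], [0.0, 0.0, 1.0]]


def zoom(s):
    return [[s, 0.0, 0.0], [0.0, s, 0.0], [0.0, 0.0, 1.0]]


def apply_affine(m, x, y):
    """Map the point (x, y) through the homogeneous matrix m."""
    return (m[0][0] * x + m[0][1] * y + m[0][2],
            m[1][0] * x + m[1][1] * y + m[1][2])


# Zoom, then rotate, then translate: matrices multiply right-to-left,
# giving one matrix (and, for images, one interpolation pass).
combined = matmul3(translation(5, 0), matmul3(rotation(90), zoom(2)))
x, y = apply_affine(combined, 1, 0)
print(round(x, 6), round(y, 6))
```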
Blur
RandomChoiceBlur
RandomChoiceScramble
Scramble
We provide the following datasets, which offer a general structure and iterators for sampling from, and applying transforms to, in-memory or out-of-memory data. In particular, the FolderDataset has been designed to fit most of your dataset needs. It has extensive options for data filtering and manipulation. It supports loading images for classification, segmentation and even arbitrary source/target mapping. Take a good look at its documentation for more info.
ClonedDataset
CSVDataset
FolderDataset
MultiFolderDataset
TensorDataset
tnt.BatchDataset
tnt.ConcatDataset
tnt.ListDataset
tnt.MultiPartitionDataset
tnt.ResampleDataset
tnt.ShuffleDataset
tnt.TensorDataset
tnt.TransformDataset
In many scenarios it is important to ensure that your training set is properly balanced; however, it may not be practical in real life to obtain such a perfect dataset. In these cases you can use the ImbalancedDatasetSampler as a drop-in replacement for the basic sampler provided by the DataLoader. More information can be found here
import torch
from pywick.samplers import ImbalancedDatasetSampler

# train_dataset, args.batch_size and **kwargs stand in for your own settings
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           sampler=ImbalancedDatasetSampler(train_dataset),
                                           batch_size=args.batch_size,
                                           **kwargs)
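A common way such a sampler balances classes is to weight every sample by the inverse of its class frequency and then draw indices with replacement in proportion to those weights, so each class is drawn about equally often. The sketch below assumes that weighting scheme, which may not match ImbalancedDatasetSampler exactly; the helper names are hypothetical. In plain PyTorch, the same per-sample weights could be passed to `torch.utils.data.WeightedRandomSampler`.

```python
import random
from collections import Counter


def inverse_frequency_weights(labels):
    """Weight each sample by 1 / (number of samples in its class)."""
    counts = Counter(labels)
    return [1.0 / counts[y] for y in labels]


def sample_indices(labels, num_samples, seed=0):
    """Draw indices with replacement, proportionally to the weights."""
    weights = inverse_frequency_weights(labels)
    rng = random.Random(seed)
    return rng.choices(range(len(labels)), weights=weights, k=num_samples)


# 90% class 0, 10% class 1: each class ends up with total weight 1,
# so roughly half of the drawn indices come from each class.
labels = [0] * 90 + [1] * 10
idx = sample_indices(labels, 10000)
print(sum(labels[i] for i in idx) / len(idx))
```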
PyWick provides a few utility functions not commonly found:
th_iterproduct (mimics itertools.product)
th_gather_nd (N-dimensional version of torch.gather)
th_random_choice (mimics np.random.choice)
th_pearsonr (mimics scipy.stats.pearsonr)
th_corrcoef (mimics np.corrcoef)
th_affine2d and th_affine3d (affine transforms on torch.Tensors)
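For reference, the quantity th_pearsonr mimics is the Pearson correlation coefficient r: the covariance of the two inputs divided by the product of their standard deviations. A plain-Python sketch of r only; scipy.stats.pearsonr also returns a p-value, and whether th_pearsonr does is not covered here. `pearsonr` below is a hypothetical helper, not pywick's function.

```python
import math


def pearsonr(xs, ys):
    """Pearson correlation coefficient of two equal-length sequences.

    Undefined (division by zero) if either input is constant.
    """
    n = len(xs)
    mx = sum(xs) / n
    my = sum(ys) / n
    cov = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    sx = math.sqrt(sum((x - mx) ** 2 for x in xs))
    sy = math.sqrt(sum((y - my) ** 2 for y in ys))
    return cov / (sx * sy)


print(pearsonr([1, 2, 3, 4], [2, 4, 6, 8]))  # perfectly correlated, ~1.0
print(pearsonr([1, 2, 3], [3, 2, 1]))        # perfectly anti-correlated, ~-1.0
```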
We stand on the shoulders of (github?) giants and couldn't have done this without the rich github ecosystem and community. This framework is based in part on the excellent Torchsample framework originally published by @ncullen93. Additionally, many models have been gently borrowed/modified from @Cadene's pretrained models repo as well as @Tramac's segmentation repo.
@ncullen93
@cadene
@deallynomore
@recastrodiaz
@zijundeng
@Tramac
And many others! (attributions listed in the codebase as they occur)
Thangs are broken matey! Arrr!!!
We're working on this project as time permits, so you might discover bugs here and there. Feel free to report them, or better yet, submit a pull request!