Data Augmentation and Sampling for PyTorch
v0.1.3 JUST RELEASED - contains significant improvements, bug fixes, and additional support. Get it from the releases, or pull the master branch.
This package provides a few things:
- A high-level module for Keras-like training with callbacks, constraints, and regularizers.
- Comprehensive data augmentation, transforms, sampling, and loading.
- Utility tensor and variable functions so you don't need numpy as often.

Have any feature requests? Submit an issue and I'll make it happen, especially for data augmentation, data loading, or sampling functions.
Want to contribute? Check the issues page for those tagged with [contributions welcome].
The ModuleTrainer class provides a high-level training interface which abstracts away the training loop while providing callbacks, constraints, initializers, regularizers, and more.
Example:
import torch.nn as nn
import torch.nn.functional as F
from torchsample.modules import ModuleTrainer

# Define your model EXACTLY as normal
class Network(nn.Module):
    def __init__(self):
        super(Network, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.fc1 = nn.Linear(1600, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = x.view(-1, 1600)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x)

model = Network()
trainer = ModuleTrainer(model)

trainer.compile(loss='nll_loss', optimizer='adadelta')

trainer.fit(x_train, y_train,
            val_data=(x_test, y_test),
            num_epoch=20,
            batch_size=128,
            verbose=1)
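The 1600 passed to fc1 is the flattened size of the second pooled feature map. A minimal sketch of that arithmetic follows, assuming 28x28 single-channel inputs (for example MNIST, which this README does not state), unpadded stride-1 convolutions, and non-overlapping 2x2 max pooling with floor rounding. The helper names are hypothetical and only illustrate the calculation.

```python
def conv_out(size, kernel_size, stride=1, padding=0):
    """Spatial size after a convolution (defaults: stride 1, no padding)."""
    return (size + 2 * padding - kernel_size) // stride + 1


def pool_out(size, kernel_size):
    """Spatial size after max pooling with stride == kernel_size, floor rounding."""
    return size // kernel_size


def flattened_features(input_size=28, channels=64):
    # ASSUMPTION: 28x28 inputs; the README does not state the input size.
    size = conv_out(input_size, 3)   # conv1: 28 -> 26
    size = pool_out(size, 2)         # pool:  26 -> 13
    size = conv_out(size, 3)         # conv2: 13 -> 11
    size = pool_out(size, 2)         # pool:  11 -> 5
    return channels * size * size    # 64 channels * 5 * 5


print(flattened_features())  # 1600
```

If your inputs have a different spatial size, the same calculation gives the value to use in both nn.Linear and x.view.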
You also have access to the standard evaluation and prediction functions:
loss = trainer.evaluate(x_train, y_train)
y_pred = trainer.predict(x_train)
Torchsample provides a wide range of callbacks, generally mimicking the interface found in Keras:
- EarlyStopping
- ModelCheckpoint
- LearningRateScheduler
- ReduceLROnPlateau
- CSVLogger
from torchsample.callbacks import EarlyStopping

callbacks = [EarlyStopping(monitor='val_loss', patience=5)]
trainer.set_callbacks(callbacks)
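To make the monitor and patience arguments concrete, here is a minimal sketch of patience-based early stopping in plain Python. It is not torchsample's implementation; the PatienceTracker class, its min_delta parameter, and the epochs_run helper are hypothetical names used only for illustration.

```python
class PatienceTracker:
    """Signals a stop when the monitored loss has not improved for `patience` epochs."""

    def __init__(self, patience=5, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta  # hypothetical: smallest decrease counted as improvement
        self.best = float("inf")
        self.wait = 0

    def should_stop(self, current_loss):
        if current_loss < self.best - self.min_delta:
            # Improvement: remember the new best value and reset the counter.
            self.best = current_loss
            self.wait = 0
            return False
        self.wait += 1
        return self.wait >= self.patience


def epochs_run(val_losses, patience):
    """Return how many epochs would run before stopping on the given loss history."""
    tracker = PatienceTracker(patience=patience)
    for epoch, loss in enumerate(val_losses, start=1):
        if tracker.should_stop(loss):
            return epoch
    return len(val_losses)


history = [0.9, 0.7, 0.6, 0.61, 0.62, 0.60, 0.65, 0.5]
print(epochs_run(history, patience=3))  # 6
```

With patience=3 the run stops at epoch 6 and never reaches the later improvement to 0.5, which is the trade-off a small patience value makes.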
Torchsample also provides regularizers:
- L1Regularizer
- L2Regularizer
- L1L2Regularizer

and constraints:
- UnitNorm
- MaxNorm
- NonNeg
Both regularizers and constraints can be selectively applied on layers using regular expressions and the module_filter argument. Constraints can be explicit (hard) constraints applied at an arbitrary batch or epoch frequency, or they can be implicit (soft) constraints similar to regularizers, where the constraint deviation is added as a penalty to the total model loss.
from torchsample.constraints import MaxNorm, NonNeg
from torchsample.regularizers import L1Regularizer

# hard constraint applied every 5 batches
hard_constraint = MaxNorm(value=2., frequency=5, unit='batch', module_filter='*fc*')
# implicit constraint added as a penalty term to model loss
soft_constraint = NonNeg(lagrangian=True, scale=1e-3, module_filter='*fc*')
constraints = [hard_constraint, soft_constraint]
trainer.set_constraints(constraints)

regularizers = [L1Regularizer(scale=1e-4, module_filter='*conv*')]
trainer.set_regularizers(regularizers)
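The difference between a hard and a soft constraint can be sketched on plain lists of weights. This is an illustration under stated assumptions, not torchsample's code: the filter strings '*fc*' and '*conv*' read like shell-style wildcards, so the sketch assumes that interpretation and uses fnmatch; the norm is taken over a whole layer for simplicity, whereas the library may compute it along a particular axis; all function names are hypothetical.

```python
import math
from fnmatch import fnmatch


def matches_filter(module_name, module_filter):
    # ASSUMPTION: filters such as '*fc*' are treated as shell-style wildcards.
    return fnmatch(module_name, module_filter)


def max_norm_project(weights, value):
    """Hard constraint: rescale the vector so its L2 norm is at most `value`."""
    norm = math.sqrt(sum(w * w for w in weights))
    if norm <= value:
        return list(weights)
    return [w * value / norm for w in weights]


def non_neg_penalty(weights, scale):
    """Soft constraint: penalty proportional to the total negative mass."""
    return scale * sum(-w for w in weights if w < 0)


layers = {"conv1": [0.5, -0.2], "fc1": [3.0, -4.0], "fc2": [0.3, 0.4]}

# Hard constraint: overwrite the weights of matching layers.
for name in layers:
    if matches_filter(name, "*fc*"):
        layers[name] = max_norm_project(layers[name], value=2.0)

# Soft constraint: leave the weights alone and compute a term to add to the loss.
penalty = sum(non_neg_penalty(w, scale=1e-3)
              for name, w in layers.items() if matches_filter(name, "*fc*"))

print(layers["fc1"], penalty)
```

The hard constraint guarantees the bound after every application, while the soft one only discourages violations, so its strength depends on the scale you choose.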
You can also fit directly on a torch.utils.data.DataLoader, optionally with a validation loader as well:
from torchsample import TensorDataset
from torch.utils.data import DataLoader

train_dataset = TensorDataset(x_train, y_train)
train_loader = DataLoader(train_dataset, batch_size=32)

val_dataset = TensorDataset(x_val, y_val)
val_loader = DataLoader(val_dataset, batch_size=32)

trainer.fit_loader(train_loader, val_loader=val_loader, num_epoch=100)
Finally, torchsample provides a few utility functions not commonly found:
- th_iterproduct (mimics itertools.product)
- th_gather_nd (N-dimensional version of torch.gather)
- th_random_choice (mimics np.random.choice)
- th_pearsonr (mimics scipy.stats.pearsonr)
- th_corrcoef (mimics np.corrcoef)
- th_affine2d and th_affine3d (affine transforms on torch.Tensors)
- F_affine2d and F_affine3d
- F_map_coordinates2d and F_map_coordinates3d
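As a reminder of what two of the mimicked reference functions compute, here is a small standard-library sketch. It shows the behaviour of itertools.product and the textbook Pearson correlation formula only; it does not call torchsample, and the pearson_r helper is a hypothetical name.

```python
import itertools
import math


def pearson_r(x, y):
    """Pearson correlation coefficient of two equal-length sequences."""
    n = len(x)
    mean_x = sum(x) / n
    mean_y = sum(y) / n
    cov = sum((a - mean_x) * (b - mean_y) for a, b in zip(x, y))
    var_x = sum((a - mean_x) ** 2 for a in x)
    var_y = sum((b - mean_y) ** 2 for b in y)
    return cov / math.sqrt(var_x * var_y)


# Cartesian product of two index ranges, as itertools.product defines it.
pairs = list(itertools.product(range(2), range(3)))
print(pairs)  # [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2)]

# A perfectly linear relationship has correlation 1 (up to rounding).
r = pearson_r([1.0, 2.0, 3.0, 4.0], [2.0, 4.0, 6.0, 8.0])
print(round(r, 6))
```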
The torchsample package provides a wide range of data augmentation and transformation tools which can be applied during data loading. The package also provides the flexible TensorDataset and FolderDataset classes to handle most dataset needs.
These transforms work directly on torch tensors:
- Compose()
- AddChannel()
- SwapDims()
- RangeNormalize()
- StdNormalize()
- Slice2D()
- RandomCrop()
- SpecialCrop()
- Pad()
- RandomFlip()
- ToTensor()
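The general pattern behind a Compose-style transform is to call each transform in order and pass the result along. The sketch below shows that pattern on flat Python lists instead of tensors. ComposeSketch, range_rescale, and flip are hypothetical stand-ins written for this example; the argument names and exact semantics of torchsample's own Compose(), RangeNormalize(), and RandomFlip() may differ.

```python
class ComposeSketch:
    """Apply a list of callables in order, feeding each output to the next."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, x):
        for transform in self.transforms:
            x = transform(x)
        return x


def range_rescale(lo, hi):
    """Return a transform that linearly maps the values of a flat list onto [lo, hi]."""
    def apply(values):
        v_min, v_max = min(values), max(values)
        span = v_max - v_min
        if span == 0:
            # Constant input: map everything to the lower bound.
            return [lo for _ in values]
        return [lo + (v - v_min) * (hi - lo) / span for v in values]
    return apply


def flip(values):
    """Reverse the order of the values (a deterministic stand-in for a random flip)."""
    return values[::-1]


pipeline = ComposeSketch([range_rescale(0.0, 1.0), flip])
print(pipeline([10.0, 20.0, 30.0]))  # [1.0, 0.5, 0.0]
```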
The following transforms perform affine (or affine-like) transforms on torch tensors:
- Rotate()
- Translate()
- Shear()
- Zoom()
We also provide classes for stringing multiple affine transformations together so that only one interpolation takes place:
- Affine()
- AffineCompose()
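The reason a single interpolation is enough is that affine transforms compose by matrix multiplication: the individual matrices can be multiplied into one before any image is resampled. The sketch below demonstrates this on a 2D point with 3x3 homogeneous matrices in plain Python. It is an illustration of the underlying math, not torchsample's internals, and every function name is hypothetical.

```python
import math


def matmul3(a, b):
    """Multiply two 3x3 matrices given as nested lists."""
    return [[sum(a[i][k] * b[k][j] for k in range(3)) for j in range(3)]
            for i in range(3)]


def rotation(degrees):
    t = math.radians(degrees)
    return [[math.cos(t), -math.sin(t), 0.0],
            [math.sin(t), math.cos(t), 0.0],
            [0.0, 0.0, 1.0]]


def translation(dx, dy):
    return [[1.0, 0.0, dx], [0.0, 1.0, dy], [0.0, 0.0, 1.0]]


def zoom(factor):
    return [[factor, 0.0, 0.0], [0.0, factor, 0.0], [0.0, 0.0, 1.0]]


def apply(matrix, point):
    """Apply a 3x3 homogeneous matrix to a 2D point."""
    x, y = point
    return (matrix[0][0] * x + matrix[0][1] * y + matrix[0][2],
            matrix[1][0] * x + matrix[1][1] * y + matrix[1][2])


steps = [rotation(90), translation(2.0, 0.0), zoom(2.0)]

# Compose once: "rotate, then translate, then zoom" is the product Z @ T @ R.
combined = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
for step in steps:
    combined = matmul3(step, combined)

point = (1.0, 0.0)

# Apply the steps one after another ...
sequential = point
for step in steps:
    sequential = apply(step, sequential)

# ... and compare with a single application of the combined matrix.
one_shot = apply(combined, point)
print(sequential, one_shot)
```

For images, each separate application would resample (and slightly blur) the pixels, so using the combined matrix once avoids accumulating interpolation error.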
The following dataset classes supply the general structure and iterators for sampling from, and applying transforms to, in-memory or out-of-memory data:
- TensorDataset()
- FolderDataset()
Thank you to the following people and contributors:
- All Keras contributors
- @deallynomore
- @recastrodiaz