pytorch-fitmodule

2019-10-09

A super simple fit method for PyTorch Modules

Ever wanted a pretty, Keras-like fit method for your PyTorch Modules? Here's one. It lacks some of the more advanced functionality of Keras's fit, but it's easy to use:

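```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch_fitmodule import FitModule

# Placeholder: torch has no get_me_some_data(); load your own feature
# matrix X, integer class labels Y, and the number of classes here.
X, Y, n_classes = torch.get_me_some_data()


class MLP(FitModule):
    def __init__(self, n_feats, n_classes, hidden_size=50):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(n_feats, hidden_size)
        self.fc2 = nn.Linear(hidden_size, n_classes)

    def forward(self, x):
        # Log-probabilities over classes; passing dim=1 avoids the
        # deprecated implicit-dimension behaviour of log_softmax
        return F.log_softmax(self.fc2(F.relu(self.fc1(x))), dim=1)


f = MLP(X.size()[1], n_classes)


def n_correct(y_true, y_pred):
    # Count rows whose arg-max prediction equals the true label
    return (y_true == torch.max(y_pred, 1)[1]).sum()


f.fit(X, Y, epochs=5, validation_split=0.3, metrics=[n_correct])
```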
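To make the Keras-style interface more concrete, here is a minimal sketch of the kind of loop a fit method like this runs: split off validation data, shuffle and mini-batch the training rows each epoch, then report the loss and every metric for both splits. It is not FitModule's actual implementation. The function name `simple_fit`, the SGD optimizer and learning rate, the use of NLL loss (which suits a model that returns log-softmax outputs, like the MLP above), and holding out the *last* `validation_split` fraction of rows (Keras's convention) are all assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def simple_fit(model, X, Y, epochs=1, batch_size=32, validation_split=0.0,
               metrics=(), lr=1e-2):
    """Hypothetical Keras-style training loop; not FitModule's own code.

    Assumes `model` returns log-probabilities, so NLL loss is used.
    """
    # Hold out the last `validation_split` fraction of rows (assumed Keras convention)
    n_train = int(len(X) * (1.0 - validation_split))
    X_train, Y_train = X[:n_train], Y[:n_train]
    X_val, Y_val = X[n_train:], Y[n_train:]
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    history = []
    for epoch in range(epochs):
        model.train()
        perm = torch.randperm(n_train)  # reshuffle training rows every epoch
        for start in range(0, n_train, batch_size):
            idx = perm[start:start + batch_size]
            optimizer.zero_grad()
            loss = F.nll_loss(model(X_train[idx]), Y_train[idx])
            loss.backward()
            optimizer.step()
        # After each epoch, evaluate loss and metrics on the training and validation sets
        model.eval()
        logs = {}
        with torch.no_grad():
            for prefix, (Xs, Ys) in (("", (X_train, Y_train)),
                                     ("val_", (X_val, Y_val))):
                if len(Xs) == 0:
                    continue  # no validation rows when validation_split == 0
                out = model(Xs)
                logs[prefix + "loss"] = F.nll_loss(out, Ys).item()
                for m in metrics:
                    logs[prefix + m.__name__] = float(m(Ys, out))
        history.append(logs)
    return history


# Tiny synthetic demo: the label is whether the first feature is positive
torch.manual_seed(0)
X = torch.randn(100, 4)
Y = (X[:, 0] > 0).long()
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2),
                      nn.LogSoftmax(dim=1))


def n_correct(y_true, y_pred):
    return (y_true == torch.max(y_pred, 1)[1]).sum()


hist = simple_fit(model, X, Y, epochs=2, validation_split=0.3,
                  metrics=[n_correct])
```

Naming each logged value after the metric function (`m.__name__`) and prefixing validation values with `val_` mirrors the column names a Keras-style log shows.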

Installation

Just clone this repo and add it to your Python path. You'll need PyTorch along with the repo's other dependencies, all of which are available via Anaconda.
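One way to put the clone on your Python path from inside a script is to prepend it to `sys.path` before importing; setting the `PYTHONPATH` environment variable to the same directory has the same effect. This is a minimal sketch: the directory below is a hypothetical location, so point it at wherever you cloned the repo.

```python
import os
import sys

# Hypothetical clone location; change it to where you ran `git clone`
REPO_DIR = os.path.expanduser("~/src/pytorch-fitmodule")

# Prepend the clone so `import pytorch_fitmodule` resolves to it
if REPO_DIR not in sys.path:
    sys.path.insert(0, REPO_DIR)

# from pytorch_fitmodule import FitModule  # works once REPO_DIR is a real clone
```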

Example

Try out a simple example with the included script:

```
python run_example.py
Epoch 1 / 10
[========================================] 100%	loss: 1.3285    accuracy: 0.5676    val_loss: 1.0450    val_accuracy: 0.5693

Epoch 2 / 10
[========================================] 100%	loss: 0.8004    accuracy: 0.8900    val_loss: 0.5804    val_accuracy: 0.8900

Epoch 3 / 10
[========================================] 100%	loss: 0.4638    accuracy: 0.8981    val_loss: 0.3845    val_accuracy: 0.8983

Epoch 4 / 10
[========================================] 100%	loss: 0.3357    accuracy: 0.9033    val_loss: 0.2998    val_accuracy: 0.9043

Epoch 5 / 10
[========================================] 100%	loss: 0.2684    accuracy: 0.9196    val_loss: 0.2462    val_accuracy: 0.9213

Epoch 6 / 10
[========================================] 100%	loss: 0.2215    accuracy: 0.9374    val_loss: 0.2061    val_accuracy: 0.9423

Epoch 7 / 10
[========================================] 100%	loss: 0.1841    accuracy: 0.9586    val_loss: 0.1738    val_accuracy: 0.9590

Epoch 8 / 10
[========================================] 100%	loss: 0.1543    accuracy: 0.9704    val_loss: 0.1478    val_accuracy: 0.9673

Epoch 9 / 10
[========================================] 100%	loss: 0.1298    accuracy: 0.9806    val_loss: 0.1266    val_accuracy: 0.9747

Epoch 10 / 10
[========================================] 100%	loss: 0.1099    accuracy: 0.9861    val_loss: 0.1094    val_accuracy: 0.9800
```
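Each epoch prints a progress bar followed by the loss and each metric, first for the training data and then for the validation split (the `val_` columns). The sketch below reproduces the layout of one such line. It is not the library's own printing code; the function name `format_epoch_line`, the 40-character bar width, and the four-space separator are assumptions read off the output above.

```python
def format_epoch_line(progress, logs, width=40):
    """Render a Keras-style progress line (hypothetical helper, not the library's)."""
    filled = int(round(progress * width))
    bar = "[" + "=" * filled + " " * (width - filled) + "]"
    # Four decimal places per value, columns separated by four spaces
    metrics = "    ".join(f"{name}: {value:.4f}" for name, value in logs.items())
    return f"{bar} {int(progress * 100)}%\t{metrics}"


line = format_epoch_line(1.0, {"loss": 1.3285, "accuracy": 0.5676,
                               "val_loss": 1.0450, "val_accuracy": 0.5693})
half = format_epoch_line(0.5, {"loss": 0.8004})
```

A partially finished epoch (`progress=0.5`) shows a half-filled bar padded with spaces, which is how the bar can advance in place while batches are processed.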
