
This repo is deprecated. Please find the updated package here.

https://github.com/EdGENetworks/anuvada

Anuvada: Interpretable Models for NLP using PyTorch

One of the common criticisms of deep learning has been its black-box nature. To address this issue, researchers have developed many ways to visualise and explain model inference. Some examples are attention in the case of RNNs, and activation maps, guided backpropagation and occlusion in the case of CNNs. This library is an ongoing effort to provide high-level access to such models, built on PyTorch.
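As a rough illustration of the attention idea (this is not anuvada's own implementation; the class and parameter names below are hypothetical), a GRU classifier can expose one weight per token, which can then be plotted over the input text to see which tokens drove the prediction:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyAttentionRNN(nn.Module):
    """Minimal GRU classifier with additive attention over time steps (sketch only)."""

    def __init__(self, vocab_size, embed_size, hidden_size, n_classes):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.gru = nn.GRU(embed_size, hidden_size, batch_first=True)
        self.attn = nn.Linear(hidden_size, 1)        # one score per time step
        self.out = nn.Linear(hidden_size, n_classes)

    def forward(self, tokens):
        states, _ = self.gru(self.embedding(tokens))                 # (batch, seq, hidden)
        weights = F.softmax(self.attn(states).squeeze(-1), dim=1)    # (batch, seq)
        context = torch.bmm(weights.unsqueeze(1), states).squeeze(1) # weighted sum of states
        return self.out(context), weights   # weights can be visualised over the tokens

model = ToyAttentionRNN(vocab_size=1000, embed_size=25, hidden_size=25, n_classes=5)
logits, weights = model(torch.randint(0, 1000, (2, 12)))
print(weights.shape)  # torch.Size([2, 12]) -- one attention weight per token
```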

Installing

Clone this repo and add it to your Python library path, as shown in the sketch below.
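One way to do this is to append the clone location to `sys.path` before importing; the path below is only an assumption, adjust it to wherever you cloned the repo:

```python
import sys

# Assumed clone location -- replace with the actual path to your clone of anuvada
sys.path.append('/home/username/code/anuvada')

import anuvada  # should now import without errors
```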

Getting started

Importing libraries

import anuvada
import numpy as np
import torch
import pandas as pd
from anuvada.models.classification_attention_rnn import AttentionClassifier

Creating the dataset

from anuvada.datasets.data_loader import CreateDataset
from anuvada.datasets.data_loader import LoadData
data = CreateDataset()
df = pd.read_csv('MovieSummaries/movie_summary_filtered.csv')
# passing only the first 512 samples, I don't have a GPU!
y = list(df.Genre.values)[0:512]
x = list(df.summary.values)[0:512]
x, y = data.create_dataset(x, y, folder_path='test', max_doc_tokens=500)

Loading created dataset

l = LoadData()
x, y, token2id, label2id, lengths_mask = l.load_data_from_path('test')

Change into torch vectors

x = torch.from_numpy(x)
y = torch.from_numpy(y)

Create attention classifier

acf = AttentionClassifier(vocab_size=len(token2id), embed_size=25, gru_hidden=25, n_classes=len(label2id))
loss = acf.fit(x, y, lengths_mask, epochs=5)
Epoch 1 / 5
[========================================] 100% loss: 3.9904

Epoch 2 / 5
[========================================] 100% loss: 3.9851

Epoch 3 / 5
[========================================] 100% loss: 3.9783

Epoch 4 / 5
[========================================] 100% loss: 3.9739

Epoch 5 / 5
[========================================] 100% loss: 3.9650

To do list

  • [x] Implement Attention with RNN

  • [ ] Implement Attention Visualisation

  • [x] Implement working Fit Module

  • [x] Implement support for masking gradients in RNN (Working now!)

  • [x] Implement a generic data set loader

  • [ ] Implement CNN Classifier with feature map visualisation

Acknowledgments

  • https://github.com/henryre/pytorch-fitmodule

