资源算法vsepp_tensorflow

vsepp_tensorflow

2020-02-13 | |  33 |   0 |   0

Improving Visual-Semantic Embeddings with Hard Negatives

Tensorflow implementation of Improving Visual-Semantic Embeddings with Hard Negatives based on the original implementation in PyTorch.

Download the data

The data should be downloaded in the /data directory. To download the Flickr8k dataset, run:

!wget https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_Dataset.zip
!wget https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_text.zip
!unzip -qq Flickr8k_Dataset.zip
!unzip -qq Flickr8k_text.zip
!rm Flickr8k_Dataset.zip
!rm Flickr8k_text.zip

or to download the Flickr30k dataset run:

!wget https://raw.githubusercontent.com/BryanPlummer/flickr30k_entities/master/train.txt
!wget https://raw.githubusercontent.com/BryanPlummer/flickr30k_entities/master/val.txt
!wget https://raw.githubusercontent.com/BryanPlummer/flickr30k_entities/master/test.txt
!wget http://shannon.cs.illinois.edu/DenotationGraph/data/flickr30k-images.tar
!wget http://shannon.cs.illinois.edu/DenotationGraph/data/flickr30k.tar.gz
!tar -xf flickr30k-images.tar
!tar -xzf flickr30k.tar.gz

Training a model (Flickr8k example)

To train a model, the train.py script needs to be run in the following way:

python src/train.py --batch_size 128
                    --prefetch_size 5
                    --images_path "data/Flicker8k_Dataset/"
                    --texts_path "data/Flickr8k.token.txt"
                    --train_imgs_file_path "data/Flickr_8k.trainImages.txt"
                    --val_imgs_file_path "data/Flickr_8k.devImages.txt"
                    --save_model_path "models/name_of_the_model"

Where the default values are taken for the other command line arguments.

Performing inference (Flickr8k example)

If there is a trained model instance, inference on the test set can be run in the following way:

python src/inference.py --batch_size 128
                        --prefetch_size 5
                        --images_path "data/Flicker8k_Dataset/"
                        --texts_path "data/Flickr8k.token.txt"
                        --test_imgs_file_path "data/Flickr_8k.trainImages.txt"
                        --checkpointh_path "models/name_of_the_model"
                        --joint_space 1024 # Assuming that the joint space of the trained instance is 1024
                        --num_layers 1 # Assuming that the number of layers of the trained instance is 1


上一篇:erlang-groupy

下一篇:vsepp-project

用户评价
全部评价

热门资源

  • TensorFlow-Course

    This repository aims to provide simple and read...

  • seetafaceJNI

    项目介绍 基于中科院seetaface2进行封装的JAVA...

  • mxnet_VanillaCNN

    This is a mxnet implementation of the Vanilla C...

  • DuReader_QANet_BiDAF

    Machine Reading Comprehension on DuReader Usin...

  • Klukshu-Sockeye-...

    KLUKSHU SOCKEYE PROJECTS 2016 This repositor...