VisualSearch_MXNet
In this tutorial we will create a visual search engine for browsing 1M Amazon product images.
The first step is to index the image dataset by computing image embeddings, using a pre-trained network as a featurizer.
The second step is to query the index with an efficient k-NN search algorithm. Here we use Hierarchical Navigable Small World graphs (HNSW).
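The second step is easier to follow next to exact search. The sketch below implements exact k-NN in plain NumPy, which is the result HNSW approximates. It assumes the embeddings are stored as the rows of a float array. It uses squared L2 distance, the metric hnswlib computes for space='l2'. The function name knn_brute_force is hypothetical and not part of the tutorial.

```python
import numpy as np

def knn_brute_force(index_vectors, queries, k):
    """Exact k-NN under squared L2 distance (hypothetical helper).

    Returns (labels, distances), each of shape (num_queries, k), sorted by
    increasing distance, in the same (labels, distances) order as hnswlib's knn_query.
    """
    index_vectors = np.asarray(index_vectors, dtype=np.float32)
    queries = np.atleast_2d(np.asarray(queries, dtype=np.float32))
    # ||q - x||^2 = ||q||^2 - 2 q.x + ||x||^2, computed for all pairs at once
    sq_dists = (
        (queries ** 2).sum(axis=1, keepdims=True)
        - 2.0 * queries @ index_vectors.T
        + (index_vectors ** 2).sum(axis=1)
    )
    sq_dists = np.maximum(sq_dists, 0.0)  # clip tiny negatives caused by rounding
    labels = np.argsort(sq_dists, axis=1)[:, :k]
    distances = np.take_along_axis(sq_dists, labels, axis=1)
    return labels, distances
```

An exhaustive scan like this costs O(N·d) per query. That is why an approximate index pays off at the scale of 1M images. When a dataset vector is used as the query, it comes back as its own first neighbor with a distance close to zero.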
Prerequisites:
MXNet: pip install --pre mxnet-cu91
hnswlib (follow the guide here: https://github.com/nmslib/hnsw)
Gluon-CV: pip install gluoncv --upgrade
import mxnet as mx
from mxnet import gluon, nd
from mxnet.gluon.model_zoo import vision
import multiprocessing
from mxnet.gluon.data.vision.datasets import ImageFolderDataset
from mxnet.gluon.data import DataLoader
import numpy as np
import imghdr  # note: removed from the standard library in Python 3.13
import json
import pickle
import hnswlib
import glob, os, time
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import urllib.parse
import urllib.request  # needed for urllib.request.urlretrieve
import gzip
%matplotlib inline
The data originally comes from http://jmcauley.ucsd.edu/data/amazon/ and is described in:
Image-based recommendations on styles and substitutes. J. McAuley, C. Targett, J. Shi, A. van den Hengel. SIGIR, 2015.
We only use a subset of the images, here 1M of them (downloading this data takes about 40 minutes on an EC2 instance).
subset_num = 1000000
Beware: using the full dataset downloads 300GB of images, so make sure you have the appropriate hardware and network connection! Alternatively, just set images_path to a directory containing images named in the format ID.jpg.
data_path = 'metadata.json'
images_path = '/data/amazon_images_subset'
Download the metadata.json file that contains the URLs of the images, and create the image directory:
if not os.path.isfile(data_path):
    # Downloading the metadata, 3.1GB, unzipped 9GB
    !wget -nv https://s3.us-east-2.amazonaws.com/mxnet-public/stanford_amazon/metadata.json.gz
    !gzip -d metadata.json.gz
os.makedirs(images_path, exist_ok=True)
Then count the metadata lines and check that the requested subset is not larger than the dataset:
with open(data_path) as f:
    num_lines = sum(1 for line in f)
assert num_lines >= subset_num, "Subset needs to be smaller than or equal to the total number of examples"
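import ast

def parse(path, num_cpu, modulo):
    # Yield the records from the last subset_num lines that belong to this worker
    # (line index % num_cpu == modulo). ast.literal_eval only evaluates literals,
    # so it is a safer replacement for eval on the dict-literal lines.
    with open(path, 'r') as g:
        for i, l in enumerate(g):
            if i >= num_lines - subset_num and i % num_cpu == modulo:
                yield ast.literal_eval(l)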
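parse() assigns line i to worker i % num_cpu and keeps only the last subset_num lines. Each worker therefore gets a disjoint slice of the subset. The original code used eval, which suggests each metadata line is a Python dict literal rather than strict JSON. The sketch below assumes that format. It reproduces the sharding on an in-memory list so the partition can be checked without the 9GB file. The name shard_lines is hypothetical.

```python
import ast

def shard_lines(lines, num_workers, worker_id, subset_num):
    """Yield parsed records from the last subset_num lines assigned to worker_id.

    Hypothetical helper mirroring parse(): line i belongs to worker i % num_workers.
    Lines are assumed to be Python dict literals (e.g. single-quoted keys),
    which ast.literal_eval accepts.
    """
    first = len(lines) - subset_num
    for i, line in enumerate(lines):
        if i >= first and i % num_workers == worker_id:
            yield ast.literal_eval(line)
```

For example, with 10 lines, 3 workers and subset_num=6, only lines 4 to 9 are used. Worker 0 receives lines 6 and 9, worker 1 lines 4 and 7, and worker 2 lines 5 and 8.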
def download_files(modulo):
    for data in parse(data_path, NUM_CPU, modulo):
        # Keep only products that have categories and a JPEG image URL
        if ('imUrl' in data and data['imUrl'] is not None
                and 'categories' in data
                and data['imUrl'].split('.')[-1] == 'jpg'):
            url = data['imUrl']
            try:
                path = os.path.join(images_path, data['asin'] + '.jpg')
                if not os.path.isfile(path):
                    urllib.request.urlretrieve(url, path)
            except Exception:
                print("Error downloading {}".format(url))
We download the images using 10 times more processes than CPU cores, because the work is dominated by network I/O rather than computation.
NUM_CPU = multiprocessing.cpu_count()*10
pool = multiprocessing.Pool(processes=NUM_CPU)
results = pool.map(download_files, list(range(NUM_CPU)))
# Removing all the fake JPEGs (files saved as .jpg that are actually another format)
list_files = glob.glob(os.path.join(images_path, '*.jpg'))
for file in list_files:
    kind = imghdr.what(file)
    if kind != 'jpeg':
        print('Removed {} it is a {}'.format(file, kind))
        os.remove(file)
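imghdr was deprecated in Python 3.11 and removed in 3.13, so the cleanup above fails on recent interpreters. Every JPEG stream starts with the bytes FF D8 FF, so a magic-byte check is one possible replacement. This is a simpler test than imghdr's own, not an exact reimplementation. The function names are hypothetical.

```python
import glob
import os

def looks_like_jpeg(path):
    """True if the file starts with the JPEG SOI marker (FF D8) followed by another marker (FF)."""
    with open(path, 'rb') as f:
        return f.read(3) == b'\xff\xd8\xff'

def remove_fake_jpegs(images_path):
    """Delete *.jpg files in images_path whose content is not JPEG; return the removed paths."""
    removed = []
    for file in glob.glob(os.path.join(images_path, '*.jpg')):
        if not looks_like_jpeg(file):
            os.remove(file)
            removed.append(file)
    return removed
```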
BATCH_SIZE = 256
EMBEDDING_SIZE = 512
SIZE = (224, 224)
MEAN_IMAGE = mx.nd.array([0.485, 0.456, 0.406])
STD_IMAGE = mx.nd.array([0.229, 0.224, 0.225])
We use a pre-trained model from the model zoo
ctx = mx.gpu()
Networks from the model zoo follow the convention that the feature extractor is exposed as the .features property and the classification head as the .output property. This makes it very easy to turn any pre-trained network into a featurizer.
net = vision.resnet18_v2(pretrained=True, ctx=ctx)
net = net.features
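EMBEDDING_SIZE = 512 comes from the shape of the last ResNet-18 stage. On a 224x224 input it produces 512 channels on a 7x7 grid (32x downsampling). In the Gluon ResNet v2 implementation, .features ends with global average pooling followed by a flatten. That collapses each channel map to one number and yields one 512-d vector per image. The NumPy sketch below shows only that pooling step. The function name is hypothetical.

```python
import numpy as np

def global_average_pool(feature_maps):
    """(N, C, H, W) convolutional feature maps -> (N, C) embeddings by averaging over H and W."""
    return feature_maps.mean(axis=(2, 3))

# Shape check with a dummy ResNet-18-like output for a batch of 4 images
dummy = np.random.rand(4, 512, 7, 7).astype(np.float32)
embeddings = global_average_pool(dummy)  # shape (4, 512)
```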
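We define a transform function that converts the images to a shape usable by the network (resize, center crop, normalize, and transpose to channel-first):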
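def transform(image, label):
    # Resize the shorter side to 224 px, then crop the central 224x224 region
    resized = mx.image.resize_short(image, SIZE[0]).astype('float32')
    cropped, crop_info = mx.image.center_crop(resized, SIZE)
    # Scale to [0, 1] and normalize with the ImageNet per-channel mean and std
    cropped /= 255.
    normalized = mx.image.color_normalize(cropped, mean=MEAN_IMAGE, std=STD_IMAGE)
    # HWC -> CHW, the layout expected by the network
    transposed = nd.transpose(normalized, (2, 0, 1))
    return transposed, label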
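The arithmetic inside transform can be checked without MXNet. Below is a NumPy sketch of the steps after resizing: center crop, scaling to [0, 1], per-channel normalization, and the HWC to CHW transpose. It assumes the image has already been resized so both sides are at least the crop size. The resize step is omitted, and the helper names are hypothetical.

```python
import numpy as np

# ImageNet statistics, the same values as MEAN_IMAGE / STD_IMAGE above
MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

def center_crop(image, size):
    """Crop the central (height, width) = size region of an HWC image."""
    h, w = size
    top = (image.shape[0] - h) // 2
    left = (image.shape[1] - w) // 2
    return image[top:top + h, left:left + w]

def preprocess(image, size=(224, 224)):
    """HWC uint8 image (already resized) -> CHW float32 array ready for the network."""
    cropped = center_crop(image.astype(np.float32), size)
    scaled = cropped / 255.0                     # [0, 255] -> [0, 1]
    normalized = (scaled - MEAN) / STD           # per channel, broadcast over H and W
    return np.transpose(normalized, (2, 0, 1))  # HWC -> CHW
```

Normalizing per channel after scaling keeps the inputs on the same scale as the ImageNet data the network was pre-trained on.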
import os, tempfile, glob
# Create an empty ImageFolderDataset; its items are filled in manually below
empty_folder = tempfile.mkdtemp()
dataset = ImageFolderDataset(root=empty_folder, transform=transform)
list_files = glob.glob(os.path.join(images_path, '*.jpg'))
Because of invalid URLs and the removal of files that are not real JPEGs, our actual subset is smaller than the one requested.
dataset.items = list(zip(list_files, [0]*len(list_files)))
We load the dataset into a DataLoader with as many workers as CPU cores.
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, last_batch='keep', shuffle=False, num_workers=multiprocessing.cpu_count())
features = np.zeros((len(dataset), EMBEDDING_SIZE), dtype=np.float32)
%%time
tick = time.time()
n_print = 100
for i, (data, label) in enumerate(dataloader):
    data = data.as_in_context(ctx)
    if i % n_print == 0 and i > 0:
        print("{0} batches, {1} images, {2:.3f} img/sec".format(
            i, i*BATCH_SIZE, BATCH_SIZE*n_print/(time.time()-tick)))
        tick = time.time()
    output = net(data)
    # The last batch may be smaller than BATCH_SIZE, so size the slice by len(output)
    start = i*BATCH_SIZE
    features[start:start+len(output), :] = output.asnumpy().reshape(len(output), -1)
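Writing batch i at offset i*BATCH_SIZE is correct only because every batch except the last one is full. With last_batch='keep', the final batch is shorter and must fill exactly the remaining rows. The sketch below tracks a running offset instead, which works for any sequence of batch sizes. It uses NumPy arrays in place of network outputs, and the function name fill_features is hypothetical.

```python
import numpy as np

def fill_features(batches, total, dim):
    """Write variable-size batches of embeddings into a preallocated (total, dim) float32 array."""
    features = np.zeros((total, dim), dtype=np.float32)
    offset = 0
    for batch in batches:
        n = len(batch)
        # Flatten any trailing singleton dimensions, e.g. (n, dim, 1, 1) -> (n, dim)
        features[offset:offset + n, :] = batch.reshape(n, -1)
        offset += n
    if offset != total:
        raise ValueError("number of embeddings does not match the dataset size")
    return features
```

For example, 10 items with batches of 4 give batch sizes 4, 4 and 2. The last two embeddings land in rows 8 and 9.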
100 batches, 25600 images, 1333.611 img/sec
200 batches, 51200 images, 2097.873 img/sec
300 batches, 76800 images, 2108.257 img/sec
400 batches, 102400 images, 2119.740 img/sec
500 batches, 128000 images, 2007.043 img/sec
600 batches, 153600 images, 2104.296 img/sec
700 batches, 179200 images, 2155.201 img/sec
800 batches, 204800 images, 2105.456 img/sec
900 batches, 230400 images, 2106.616 img/sec
1000 batches, 256000 images, 2128.810 img/sec
1100 batches, 281600 images, 2125.134 img/sec
1200 batches, 307200 images, 2141.244 img/sec
1300 batches, 332800 images, 2103.341 img/sec
1400 batches, 358400 images, 2116.504 img/sec
1500 batches, 384000 images, 2090.445 img/sec
1600 batches, 409600 images, 2138.420 img/sec
1700 batches, 435200 images, 2088.554 img/sec
1800 batches, 460800 images, 2127.671 img/sec
1900 batches, 486400 images, 2118.631 img/sec
2000 batches, 512000 images, 2084.014 img/sec
2100 batches, 537600 images, 2111.905 img/sec
2200 batches, 563200 images, 2125.523 img/sec
2300 batches, 588800 images, 2106.901 img/sec
2400 batches, 614400 images, 2123.917 img/sec
2500 batches, 640000 images, 2064.876 img/sec
2600 batches, 665600 images, 2117.610 img/sec
2700 batches, 691200 images, 2112.028 img/sec
2800 batches, 716800 images, 2066.120 img/sec
2900 batches, 742400 images, 2068.632 img/sec
3000 batches, 768000 images, 2095.919 img/sec
3100 batches, 793600 images, 2104.414 img/sec
3200 batches, 819200 images, 2090.150 img/sec
3300 batches, 844800 images, 2068.915 img/sec
3400 batches, 870400 images, 2113.243 img/sec
3500 batches, 896000 images, 2105.340 img/sec
3600 batches, 921600 images, 2127.197 img/sec
3700 batches, 947200 images, 2123.200 img/sec
CPU times: user 4min 43s, sys: 3min 22s, total: 8min 5s
Wall time: 7min 42s
# Number of elements in the index
num_elements = len(features)
labels_index = np.arange(num_elements)
%%time
# Declaring index
p = hnswlib.Index(space='l2', dim=EMBEDDING_SIZE)  # possible options are l2, cosine or ip
# Initializing index - the maximum number of elements should be known beforehand
p.init_index(max_elements=num_elements, ef_construction=100, M=16)
# Element insertion (can be called several times)
p.add_items(features, labels_index)
# Controlling the recall by setting ef (higher ef: more accurate but slower queries)
p.set_ef(100)  # ef should always be > k
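The ef value passed to p.set_ef trades query speed for recall. One way to choose it is to measure recall@k on a sample of queries. You compare the labels returned by p.knn_query with the labels from an exact brute-force scan over features. The helper below only computes the metric from two label arrays. Its name, recall_at_k, is hypothetical, and it assumes both arrays have shape (num_queries, k).

```python
import numpy as np

def recall_at_k(approx_labels, exact_labels):
    """Fraction of the exact k nearest neighbors that the approximate search also returned."""
    approx_labels = np.asarray(approx_labels)
    exact_labels = np.asarray(exact_labels)
    hits = sum(
        len(set(a.tolist()) & set(e.tolist()))
        for a, e in zip(approx_labels, exact_labels)
    )
    return hits / exact_labels.size
```

Recomputing this for a few ef values on the same sample shows where increasing ef stops improving recall enough to justify slower queries.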
CPU times: user 31min 34s, sys: 16.4 s, total: 31min 51s
Wall time: 1min
p.save_index('index.idx')
We test the results by sampling random images from the dataset and searching for their k nearest neighbors.
def plot_predictions(images):
    # 3x3 grid: the query image first (red frame), then its neighbors
    gs = gridspec.GridSpec(3, 3)
    fig = plt.figure(figsize=(15, 15))
    gs.update(hspace=0.1, wspace=0.1)
    for i, (gg, image) in enumerate(zip(gs, images)):
        gg2 = gridspec.GridSpecFromSubplotSpec(10, 10, subplot_spec=gg)
        ax = fig.add_subplot(gg2[:, :])
        ax.imshow(image, cmap='Greys_r')
        # Hide ticks and tick labels (booleans, as expected by current matplotlib)
        ax.tick_params(axis='both', which='both', bottom=False, top=False,
                       left=False, right=False, labelleft=False, labelbottom=False)
        ax.set_title("result [{}]".format(i))
        if i == 0:
            plt.setp(ax.spines.values(), color='red')
            ax.set_title("SEARCH")
def search(N, k):
    # Query dataset, k - number of closest elements (returns 2 numpy arrays: labels, distances)
    q_labels, q_distances = p.knn_query([features[N]], k=k)
    images = [plt.imread(dataset.items[label][0]) for label in q_labels[0]]
    plot_predictions(images)
%%time
index = np.random.randint(0, len(features))
k = 6
search(index, k)
CPU times: user 292 ms, sys: 0 ns, total: 292 ms
Wall time: 287 ms
We can also query the index with an arbitrary image, here a local file named dress.jpg:
path = 'dress.jpg'
p.set_ef(300)  # ef should always be > k
image = plt.imread(path)[:, :, :3]  # keep the RGB channels only
image_t, _ = transform(nd.array(image), 1)
output = net(image_t.expand_dims(axis=0).as_in_context(ctx))
labels, distances = p.knn_query([output.asnumpy().reshape(-1,)], k=5)
images = [image]
images += [plt.imread(dataset.items[label][0]) for label in labels[0]]
plot_predictions(images)