PyTorch Negative Sampling Loss
Negative sampling loss, as used to train word2vec-style skip-gram word embeddings, implemented in PyTorch.
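For reference, the standard skip-gram negative sampling objective works like this. Take a center word with input vector v_c, its true context word with output vector u_o, and k sampled noise words with output vectors u_1 … u_k. The loss for that pair is -log σ(u_o·v_c) - Σ_i log σ(-u_i·v_c). The sketch below illustrates that formula for a single pair in plain Python, so it runs without PyTorch. The function names are hypothetical, and this is not this repository's NEG_loss implementation. NEG_loss works on batched tensors and, judging by its num_sample argument, draws the noise words itself.

```python
import math


def log_sigmoid(x: float) -> float:
    """Numerically stable log(sigmoid(x))."""
    # Split on sign so exp() never receives a large positive argument.
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def dot(a, b):
    """Dot product of two equal-length vectors given as lists of floats."""
    return sum(x * y for x, y in zip(a, b))


def neg_sampling_loss(center, context, negatives):
    """Negative sampling loss for one (center, context) pair (illustrative sketch).

    center:    input embedding of the center word (v_c)
    context:   output embedding of the true context word (u_o)
    negatives: list of output embeddings of the k sampled noise words (u_i)
    """
    # Reward a high score for the true pair: -log sigmoid(u_o . v_c)
    loss = -log_sigmoid(dot(context, center))
    # Reward a low score for each noise word: -log sigmoid(-u_i . v_c)
    for neg in negatives:
        loss -= log_sigmoid(-dot(neg, center))
    return loss
```

When all dot products are zero, every term contributes log 2, so one positive and one negative give 2·log 2. Aligning the center with its context and away from the noise word drives the loss toward 0.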
Usage
```python
from torch.optim import SGD

neg_loss = NEG_loss(num_classes, embedding_size)  # NEG_loss from this repository
optimizer = SGD(neg_loss.parameters(), lr=0.1)
for i in range(num_iterations):
    # input: [batch_size] LongTensor; target: [batch_size, window_size] LongTensor
    input, target = next_batch(batch_size)  # next_batch: your own data loader
    loss = neg_loss(input, target, num_sample)  # num_sample: number of negative samples
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
word_embeddings = neg_loss.input_embeddings()  # learned input (word) embeddings
```
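The usage above leaves next_batch to the caller. The sketch below shows one possible way to build such batches. It assumes that each input is a center word id and that its target row holds the window_size surrounding context ids. That pairing is an assumption, so check it against how the repository actually interprets input and target. The helper name make_skipgram_batch is hypothetical, and the sketch uses plain Python lists.

```python
import random


def make_skipgram_batch(token_ids, batch_size, window_size, rng=random):
    """Build one skip-gram batch from a sequence of token ids (hypothetical helper).

    Returns (inputs, targets): inputs is a list of batch_size center ids and
    targets is a list of batch_size rows, each holding window_size context ids,
    i.e. shapes [batch_size] and [batch_size, window_size].
    """
    left = window_size // 2      # context tokens taken before the center
    right = window_size - left   # context tokens taken after the center
    if len(token_ids) < window_size + 1:
        raise ValueError("sequence is shorter than one full window")
    inputs, targets = [], []
    for _ in range(batch_size):
        # Pick a center position that has a full window on both sides.
        i = rng.randrange(left, len(token_ids) - right)
        inputs.append(token_ids[i])
        # Context = `left` tokens before the center plus `right` tokens after it.
        targets.append(token_ids[i - left:i] + token_ids[i + 1:i + 1 + right])
    return inputs, targets
```

To feed the result to neg_loss, convert both lists to Long tensors, e.g. `torch.tensor(inputs, dtype=torch.long)`.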