Source code for rdf2vecgpu.embedders.word2vec

import torch
import torch.nn as nn
from torch.optim import Adam
import lightning as L
import numpy as np
from torch.nn import functional as F


[docs] class SkipGram(L.LightningModule): def __init__( self, vocab_size, embedding_dim, neg_samples, learning_rate=0.003, use_sparse=False, ): super().__init__() self.vocab_size = vocab_size self.learning_rate = learning_rate self.in_embeddings = nn.Embedding(vocab_size, embedding_dim, sparse=use_sparse) self.output_embeddings = nn.Embedding( vocab_size, embedding_dim, sparse=use_sparse ) bound = 1.0 / embedding_dim nn.init.uniform_(self.in_embeddings.weight, -bound, bound) nn.init.uniform_(self.output_embeddings.weight, -bound, bound) self.neg_samples = neg_samples self.loss_fn = nn.BCEWithLogitsLoss(reduction="none") self.use_sparse = use_sparse
[docs] def forward(self, center, context, negative): center_embed = self.in_embeddings(center) # (batch, embed) context_embed = self.output_embeddings(context) # (batch, embed) negative_embed = self.output_embeddings(negative) # (batch, neg, embed) # Positive score pos_logits = (center_embed * context_embed).sum(dim=1) pos_targets = torch.ones_like(pos_logits) # pos_score = torch.sum(torch.mul(center_embed, context_embed), dim=1) pos_loss = self.loss_fn(pos_logits, pos_targets) # Negative score neg_score = torch.bmm(negative_embed, center_embed.unsqueeze(-1)).squeeze(-1) neg_targets = torch.zeros_like(neg_score) neg_loss = self.loss_fn(neg_score, neg_targets).sum(dim=1) return torch.mean(pos_loss + neg_loss)
[docs] def training_step(self, batch, batch_idx): center, context = batch negative = torch.randint( 0, self.vocab_size, (center.size(0), self.neg_samples), device=self.device ) loss = self(center, context, negative) self.log("train_loss", loss, prog_bar=True) return loss
def configure_optimizers(self): opt_function = torch.optim.SparseAdam if self.use_sparse else torch.optim.Adam return opt_function(self.parameters(), lr=self.learning_rate)
[docs] class CBOW(L.LightningModule): def __init__( self, vocab_size, embedding_dim, learning_rate=0.003, use_sparse=False ): super().__init__() self.save_hyperparameters() self.vocab_size = vocab_size self.learning_rate = learning_rate self.use_sparse = use_sparse self.in_embeddings = nn.Embedding(vocab_size, embedding_dim, sparse=use_sparse) self.output_embeddings = nn.Linear(embedding_dim, vocab_size) bound = 1.0 / embedding_dim nn.init.uniform_(self.in_embeddings.weight, -bound, bound) nn.init.uniform_(self.output_embeddings.weight, -bound, bound) self.loss_fn = nn.CrossEntropyLoss()
[docs] def forward(self, context): mask = (context != -1).float() context_embeds = self.in_embeddings(context.clamp(min=0)) masked_embeds = context_embeds * mask.unsqueeze(-1) sum_embeds = masked_embeds.sum(dim=1) counts = mask.sum(dim=1, keepdim=True).clamp(min=1.0) context_mean = sum_embeds / counts logits = self.output_embeddings(context_mean) return logits
[docs] def training_step(self, batch, batch_idx): context, center = batch center = center.long() logits = self(context) loss = self.loss_fn(logits, center) self.log("train_loss", loss, prog_bar=True) return loss
def configure_optimizers(self): return torch.optim.Adam(self.parameters(), lr=self.learning_rate)
[docs] class OrderAwareSkipgram(L.LightningModule): def __init__( self, vocab_size, embedding_dim, neg_samples, max_distance=5, learning_rate=0.003, use_sparse=False, ): super().__init__() self.save_hyperparameters() self.vocab_size = vocab_size self.embedding_dim = embedding_dim self.neg_samples = neg_samples self.max_distance = max_distance self.learning_rate = learning_rate self.use_sparse = use_sparse # Center word embeddings self.in_embeddings = nn.Embedding(vocab_size, embedding_dim, sparse=use_sparse) # Context embeddings with position-specific weights # We create separate output embeddings for each possible distance self.output_embeddings = nn.ModuleList( [ nn.Embedding(vocab_size, embedding_dim, sparse=use_sparse) for _ in range(2 * max_distance + 1) # -max_distance to +max_distance ] ) # Position embeddings to encode relative distance self.position_embeddings = nn.Embedding(2 * max_distance + 1, embedding_dim) # Initialize embeddings bound = 1.0 / embedding_dim nn.init.uniform_(self.in_embeddings.weight, -bound, bound) nn.init.uniform_(self.position_embeddings.weight, -bound, bound) for output_emb in self.output_embeddings: nn.init.uniform_(output_emb.weight, -bound, bound) self.loss_fn = nn.BCEWithLogitsLoss(reduction="none") def _get_position_index(self, distance): """Convert relative distance to position index""" return distance + self.max_distance
[docs] def forward(self, center, context, distances, negative, neg_distances=None): """ Args: center: (batch_size,) center word indices context: (batch_size,) context word indices distances: (batch_size,) relative distances from center to context negative: (batch_size, neg_samples) negative sample indices neg_distances: (batch_size, neg_samples) distances for negative samples """ batch_size = center.size(0) # Get center embeddings center_embed = self.in_embeddings(center) # (batch_size, embed_dim) # Get position-aware context embeddings pos_indices = self._get_position_index(distances) # (batch_size,) # Compute positive scores with position awareness pos_logits = [] for i in range(batch_size): pos_idx = pos_indices[i].item() context_embed = self.output_embeddings[pos_idx]( context[i : i + 1] ) # (1, embed_dim) pos_embed = self.position_embeddings( pos_indices[i : i + 1] ) # (1, embed_dim) # Combine context and position embeddings combined_embed = context_embed + pos_embed # Compute similarity score logit = (center_embed[i : i + 1] * combined_embed).sum(dim=1) pos_logits.append(logit) pos_logits = torch.cat(pos_logits, dim=0) # (batch_size,) pos_targets = torch.ones_like(pos_logits) pos_loss = self.loss_fn(pos_logits, pos_targets) # Compute negative scores if neg_distances is None: # Sample random distances for negative samples neg_distances = torch.randint( -self.max_distance, self.max_distance + 1, (batch_size, self.neg_samples), device=center.device, ) neg_pos_indices = self._get_position_index( neg_distances ) # (batch_size, neg_samples) neg_logits = [] for i in range(batch_size): center_i = center_embed[i : i + 1] # (1, embed_dim) neg_logits_i = [] for j in range(self.neg_samples): pos_idx = neg_pos_indices[i, j].item() neg_embed = self.output_embeddings[pos_idx]( negative[i : i + 1, j : j + 1] ) # (1, 1, embed_dim) neg_pos_embed = self.position_embeddings( neg_pos_indices[i : i + 1, j : j + 1] ) # (1, 1, embed_dim) # Combine negative context and position embeddings combined_neg_embed = neg_embed.squeeze(1) + neg_pos_embed.squeeze( 1 ) # (1, embed_dim) # Compute similarity score logit = (center_i * combined_neg_embed).sum(dim=1) neg_logits_i.append(logit) neg_logits.append(torch.cat(neg_logits_i, dim=0)) neg_logits = torch.stack(neg_logits, dim=0) # (batch_size, neg_samples) neg_targets = torch.zeros_like(neg_logits) neg_loss = self.loss_fn(neg_logits, neg_targets).sum(dim=1) return torch.mean(pos_loss + neg_loss)
[docs] def training_step(self, batch, batch_idx): if len(batch) == 3: center, context, distances = batch neg_distances = None else: center, context, distances, neg_distances = batch batch_size = center.size(0) # Generate negative samples negative = torch.randint( 0, self.vocab_size, (batch_size, self.neg_samples), device=self.device ) # Generate negative distances if not provided if neg_distances is None: neg_distances = torch.randint( -self.max_distance, self.max_distance + 1, (batch_size, self.neg_samples), device=self.device, ) loss = self(center, context, distances, negative, neg_distances) self.log("train_loss", loss, prog_bar=True) return loss
def configure_optimizers(self): opt_function = torch.optim.SparseAdam if self.use_sparse else torch.optim.Adam return opt_function(self.parameters(), lr=self.learning_rate)
[docs] def get_embeddings(self): """Return the center word embeddings for downstream tasks""" return self.in_embeddings.weight.data
[docs] def get_context_embeddings(self, distance=0): """Get context embeddings for a specific distance""" pos_idx = self._get_position_index(torch.tensor(distance)) return self.output_embeddings[pos_idx.item()].weight.data
[docs] class OrderAwareCBOW(L.LightningModule): def __init__( self, vocab_size, embedding_dim, neg_samples, max_distance=5, learning_rate=0.003, use_sparse=False, ): super().__init__() self.save_hyperparameters() self.vocab_size = vocab_size self.embedding_dim = embedding_dim self.neg_samples = neg_samples self.max_distance = max_distance self.learning_rate = learning_rate self.use_sparse = use_sparse # Context word embeddings with position-specific weights self.in_embeddings = nn.ModuleList( [ nn.Embedding(vocab_size, embedding_dim, sparse=use_sparse) for _ in range(2 * max_distance + 1) # -max_distance to +max_distance ] ) # Output embeddings for center words self.output_embeddings = nn.Embedding( vocab_size, embedding_dim, sparse=use_sparse ) # Position embeddings to encode relative distance self.position_embeddings = nn.Embedding(2 * max_distance + 1, embedding_dim) # Initialize embeddings bound = 1.0 / embedding_dim for input_emb in self.in_embeddings: nn.init.uniform_(input_emb.weight, -bound, bound) nn.init.uniform_(self.output_embeddings.weight, -bound, bound) nn.init.uniform_(self.position_embeddings.weight, -bound, bound) self.loss_fn = nn.BCEWithLogitsLoss(reduction="none") def _get_position_index(self, distance): """Convert relative distance to position index""" return distance + self.max_distance
[docs] def forward( self, context_words, context_distances, center, negative, neg_distances=None ): """ Args: context_words: (batch_size, context_size) context word indices context_distances: (batch_size, context_size) relative distances from center to each context word center: (batch_size,) center word indices negative: (batch_size, neg_samples) negative sample indices neg_distances: (batch_size, neg_samples) distances for negative samples (not used in CBOW) """ batch_size, context_size = context_words.size() # Get position-aware context embeddings and combine them context_embeds = [] for i in range(batch_size): batch_context_embeds = [] for j in range(context_size): # Get position index for this context word distance = context_distances[i, j] pos_idx = self._get_position_index(distance).item() # Get context embedding from position-specific embedding matrix ctx_embed = self.in_embeddings[pos_idx]( context_words[i : i + 1, j : j + 1] ) # (1, 1, embed_dim) pos_embed = self.position_embeddings( self._get_position_index(distance).unsqueeze(0) ) # (1, embed_dim) # Combine context and position embeddings combined_embed = ctx_embed.squeeze(1) + pos_embed # (1, embed_dim) batch_context_embeds.append(combined_embed) # Average the context embeddings for this batch item avg_context = torch.mean( torch.cat(batch_context_embeds, dim=0), dim=0, keepdim=True ) # (1, embed_dim) context_embeds.append(avg_context) context_embeds = torch.cat(context_embeds, dim=0) # (batch_size, embed_dim) # Get center word embeddings center_embed = self.output_embeddings(center) # (batch_size, embed_dim) # Compute positive scores pos_logits = (context_embeds * center_embed).sum(dim=1) # (batch_size,) pos_targets = torch.ones_like(pos_logits) pos_loss = self.loss_fn(pos_logits, pos_targets) # Compute negative scores negative_embed = self.output_embeddings( negative ) # (batch_size, neg_samples, embed_dim) # Expand context embeddings to match negative samples shape context_embeds_expanded = context_embeds.unsqueeze( 1 ) # (batch_size, 1, embed_dim) # Compute negative logits neg_logits = (context_embeds_expanded * negative_embed).sum( dim=2 ) # (batch_size, neg_samples) neg_targets = torch.zeros_like(neg_logits) neg_loss = self.loss_fn(neg_logits, neg_targets).sum(dim=1) return torch.mean(pos_loss + neg_loss)
[docs] def training_step(self, batch, batch_idx): if len(batch) == 3: context_words, context_distances, center = batch else: context_words, context_distances, center, _ = ( batch # Ignore neg_distances if provided ) batch_size = center.size(0) # Generate negative samples negative = torch.randint( 0, self.vocab_size, (batch_size, self.neg_samples), device=self.device ) loss = self(context_words, context_distances, center, negative) self.log("train_loss", loss, prog_bar=True) return loss
def configure_optimizers(self): opt_function = torch.optim.SparseAdam if self.use_sparse else torch.optim.Adam return opt_function(self.parameters(), lr=self.learning_rate)
[docs] def get_embeddings(self): """Return the center word embeddings for downstream tasks""" return self.output_embeddings.weight.data
[docs] def get_context_embeddings(self, distance=0): """Get context embeddings for a specific distance""" pos_idx = self._get_position_index(torch.tensor(distance)) return self.in_embeddings[pos_idx.item()].weight.data
[docs] def predict_center(self, context_words, context_distances, top_k=5): """Predict the most likely center words given context""" self.eval() with torch.no_grad(): batch_size, context_size = context_words.size() # Get position-aware context embeddings context_embeds = [] for i in range(batch_size): batch_context_embeds = [] for j in range(context_size): distance = context_distances[i, j] pos_idx = self._get_position_index(distance).item() ctx_embed = self.in_embeddings[pos_idx]( context_words[i : i + 1, j : j + 1] ) pos_embed = self.position_embeddings( self._get_position_index(distance).unsqueeze(0) ) combined_embed = ctx_embed.squeeze(1) + pos_embed batch_context_embeds.append(combined_embed) avg_context = torch.mean( torch.cat(batch_context_embeds, dim=0), dim=0, keepdim=True ) context_embeds.append(avg_context) context_embeds = torch.cat(context_embeds, dim=0) # Compute similarity with all center words all_center_embeds = self.output_embeddings.weight # (vocab_size, embed_dim) scores = torch.mm( context_embeds, all_center_embeds.t() ) # (batch_size, vocab_size) # Get top-k predictions _, top_indices = torch.topk(scores, top_k, dim=1) return top_indices