diff --git a/sentence_transformers/__init__.py b/sentence_transformers/__init__.py index f13da3d0b..ed783d9ba 100644 --- a/sentence_transformers/__init__.py +++ b/sentence_transformers/__init__.py @@ -21,6 +21,7 @@ from sentence_transformers.datasets import ParallelSentencesDataset, SentencesDataset from sentence_transformers.LoggingHandler import LoggingHandler from sentence_transformers.model_card import SentenceTransformerModelCardData +from sentence_transformers.multi_vec_encoder import MultiVectorEncoder from sentence_transformers.quantization import quantize_embeddings from sentence_transformers.readers import InputExample from sentence_transformers.sampler import DefaultBatchSampler, MultiDatasetDefaultBatchSampler @@ -62,6 +63,7 @@ "SparseEncoderTrainer", "SparseEncoderTrainingArguments", "SparseEncoderModelCardData", + "MultiVectorEncoder", "quantize_embeddings", "export_optimized_onnx_model", "export_dynamic_quantized_onnx_model", diff --git a/sentence_transformers/multi_vec_encoder/LateInteractionPooling.py b/sentence_transformers/multi_vec_encoder/LateInteractionPooling.py new file mode 100644 index 000000000..ec866630a --- /dev/null +++ b/sentence_transformers/multi_vec_encoder/LateInteractionPooling.py @@ -0,0 +1,200 @@ +from __future__ import annotations + +import os +from typing import Any + +import torch +import torch.nn.functional as F +from torch import Tensor, nn + +from sentence_transformers.models.Module import Module + + +class LateInteractionPooling(Module): + """ + Pooling layer that preserves token-level embeddings for multi-vector encoder models. 
+
+    Unlike standard Pooling which collapses token embeddings into a single sentence embedding,
+    LateInteractionPooling keeps all token embeddings but optionally:
+    - Projects them to a lower dimension (e.g., 768 → 128)
+    - Masks out special tokens ([CLS], [SEP])
+    - Applies L2 normalization per token
+
+    This is used for multi-vector encoder models where the similarity between
+    a query and document is computed via MaxSim over token embeddings.
+
+    Note:
+        The special token masking (skip_cls_token, skip_sep_token) assumes BERT-style tokenization
+        with right-padding, where [CLS] is at position 0 and [SEP] is the last non-padding token.
+        This covers most encoder models (BERT, RoBERTa, DistilBERT, etc.).
+
+    Args:
+        word_embedding_dimension: Dimension of the input word embeddings (e.g., 768 for BERT-based models).
+        output_dimension: Dimension of the output token embeddings. If None, uses word_embedding_dimension.
+            Common values are 128 or the original embedding dimension.
+        normalize: Whether to L2-normalize each token embedding. Default: True.
+        skip_cls_token: Whether to exclude the [CLS] token from the output. Default: False.
+            Assumes [CLS] is at position 0.
+        skip_sep_token: Whether to exclude the [SEP] token from the output. Default: False.
+            Assumes [SEP] is the last non-padding token (right-padding).
+    """
+
+    config_keys = [
+        "word_embedding_dimension",
+        "output_dimension",
+        "normalize",
+        "skip_cls_token",
+        "skip_sep_token",
+    ]
+
+    def __init__(
+        self,
+        word_embedding_dimension: int,
+        output_dimension: int | None = None,
+        normalize: bool = True,
+        skip_cls_token: bool = False,
+        skip_sep_token: bool = False,
+    ) -> None:
+        super().__init__()
+
+        self.word_embedding_dimension = word_embedding_dimension
+        self.output_dimension = output_dimension if output_dimension is not None else word_embedding_dimension
+        self.normalize = normalize
+        self.skip_cls_token = skip_cls_token
+        self.skip_sep_token = skip_sep_token
+
+        # Linear projection layer if dimensions differ
+        if self.output_dimension != self.word_embedding_dimension:
+            self.linear = nn.Linear(self.word_embedding_dimension, self.output_dimension)
+        else:
+            self.linear = None
+
+    def __repr__(self) -> str:
+        return f"LateInteractionPooling({self.get_config_dict()})"
+
+    def forward(self, features: dict[str, Tensor]) -> dict[str, Tensor]:
+        """
+        Forward pass that preserves token embeddings with optional projection and normalization.
+ + Args: + features: Dictionary containing: + - token_embeddings: [batch, seq_len, hidden_dim] + - attention_mask: [batch, seq_len] + + Returns: + Dictionary with updated: + - token_embeddings: [batch, seq_len, output_dim] (projected and optionally normalized) + - attention_mask: [batch, seq_len] (potentially modified if skipping tokens) + """ + token_embeddings = features["token_embeddings"] + attention_mask = features.get( + "attention_mask", + torch.ones(token_embeddings.shape[:-1], device=token_embeddings.device, dtype=torch.long), + ) + + # Linear projection + if self.linear is not None: + token_embeddings = self.linear(token_embeddings) + + # Skip special tokens if configured + if self.skip_cls_token or self.skip_sep_token: + seq_lengths = attention_mask.sum(dim=1) # [batch] + attention_mask = attention_mask.clone() + + if self.skip_cls_token: + # Mask out the first token (CLS) + attention_mask[:, 0] = 0 + + if self.skip_sep_token: + # Mask out the last non-padding token (SEP) for each sequence at position (seq_length - 1) + batch_size = attention_mask.shape[0] + batch_indices = torch.arange(batch_size, device=attention_mask.device) + sep_positions = (seq_lengths - 1).clamp(min=0) + attention_mask[batch_indices, sep_positions] = 0 + + # Apply L2 normalization per token if configured + if self.normalize: + token_embeddings = F.normalize(token_embeddings, p=2, dim=-1) + + features["token_embeddings"] = token_embeddings + features["attention_mask"] = attention_mask + + return features + + def get_output_dimension(self) -> int: + """Returns the output dimension of each token embedding.""" + return self.output_dimension + + def get_sentence_embedding_dimension(self) -> int | None: + """ + Returns None since this module produces token-level embeddings, not a single sentence embedding. + + For multi-vector encoder models, embeddings are multi-vector (one per token), not single-vector. 
+ """ + return None + + def get_config_dict(self) -> dict[str, Any]: + return { + "word_embedding_dimension": self.word_embedding_dimension, + "output_dimension": self.output_dimension, + "normalize": self.normalize, + "skip_cls_token": self.skip_cls_token, + "skip_sep_token": self.skip_sep_token, + } + + @classmethod + def load( + cls, + model_name_or_path: str, + subfolder: str = "", + token: bool | str | None = None, + cache_folder: str | None = None, + revision: str | None = None, + local_files_only: bool = False, + **kwargs, + ) -> LateInteractionPooling: + """Load the LateInteractionPooling module from a checkpoint.""" + config = cls.load_config( + model_name_or_path, + subfolder=subfolder, + token=token, + cache_folder=cache_folder, + revision=revision, + local_files_only=local_files_only, + ) + model = cls(**config) + + # Load weights if there's a linear projection layer + if model.linear is not None: + try: + cls.load_torch_weights( + model_name_or_path, + subfolder=subfolder, + token=token, + cache_folder=cache_folder, + revision=revision, + local_files_only=local_files_only, + model=model, + ) + except ValueError: + # No weights file found + pass + + return model + + def save(self, output_path: str, *args, safe_serialization: bool = True, **kwargs) -> None: + """ + Save the LateInteractionPooling module to disk. + + Args: + output_path: Directory to save the module. + safe_serialization: Whether to use safetensors format. 
+ """ + os.makedirs(output_path, exist_ok=True) + + # Save config + self.save_config(output_path) + + # Save linear layer weights if present + if self.linear is not None: + self.save_torch_weights(output_path, safe_serialization=safe_serialization) diff --git a/sentence_transformers/multi_vec_encoder/MultiVectorEncoder.py b/sentence_transformers/multi_vec_encoder/MultiVectorEncoder.py new file mode 100644 index 000000000..ac94a6090 --- /dev/null +++ b/sentence_transformers/multi_vec_encoder/MultiVectorEncoder.py @@ -0,0 +1,548 @@ +from __future__ import annotations + +import logging +from collections.abc import Iterable +from typing import Any, Literal + +import numpy as np +import torch +from torch import Tensor, nn +from torch.nn.utils.rnn import pad_sequence + +from sentence_transformers.models import Transformer +from sentence_transformers.multi_vec_encoder import LateInteractionPooling +from sentence_transformers.multi_vec_encoder.similarity import maxsim, maxsim_pairwise +from sentence_transformers.SentenceTransformer import SentenceTransformer + +logger = logging.getLogger(__name__) + + +class MultiVectorEncoder(SentenceTransformer): + """ + Multi-Vector Encoder for multi-vector encoding. + + Unlike standard SentenceTransformer which produces a single embedding per text, + MultiVectorEncoder produces multiple embeddings (one per token) and computes + similarity via MaxSim (maximum similarity) between token embeddings. + + Args: + model_name_or_path: If it is a filepath on disk, it loads the model from that path. + If it is not a path, it first tries to download a pre-trained MultiVectorEncoder. + If that fails, tries to construct a model from the Hugging Face Hub with that name. + modules: A list of torch Modules that should be called sequentially. + Can be used to create custom MultiVectorEncoder models from scratch. + device: Device (like "cuda", "cpu", "mps") that should be used for computation. + If None, checks if a GPU can be used. 
+ prompts: A dictionary with prompts for the model. The key is the prompt name, + the value is the prompt text. + default_prompt_name: The name of the prompt that should be used by default. + cache_folder: Path to store models. + trust_remote_code: Whether or not to allow for custom models defined on the Hub. + revision: The specific model version to use. + local_files_only: Whether or not to only look at local files. + token: Hugging Face authentication token to download private models. + model_kwargs: Additional model configuration parameters. + tokenizer_kwargs: Additional tokenizer configuration parameters. + config_kwargs: Additional model configuration parameters. + backend: The backend to use for inference ("torch", "onnx", or "openvino"). + """ + + def __init__( + self, + model_name_or_path: str | None = None, + modules: Iterable[nn.Module] | None = None, + device: str | None = None, + prompts: dict[str, str] | None = None, + default_prompt_name: str | None = None, + cache_folder: str | None = None, + trust_remote_code: bool = False, + revision: str | None = None, + local_files_only: bool = False, + token: bool | str | None = None, + model_kwargs: dict[str, Any] | None = None, + tokenizer_kwargs: dict[str, Any] | None = None, + config_kwargs: dict[str, Any] | None = None, + backend: Literal["torch", "onnx", "openvino"] = "torch", + ) -> None: + super().__init__( + model_name_or_path=model_name_or_path, + modules=modules, + device=device, + prompts=prompts, + default_prompt_name=default_prompt_name, + cache_folder=cache_folder, + trust_remote_code=trust_remote_code, + revision=revision, + local_files_only=local_files_only, + token=token, + model_kwargs=model_kwargs, + tokenizer_kwargs=tokenizer_kwargs, + config_kwargs=config_kwargs, + backend=backend, + ) + + def encode_query( + self, + sentences: str | list[str], + prompt_name: str | None = None, + prompt: str | None = None, + batch_size: int = 32, + show_progress_bar: bool | None = None, + 
convert_to_numpy: bool = True, + convert_to_tensor: bool = False, + device: str | None = None, + normalize_embeddings: bool = True, + **kwargs: Any, + ) -> list[np.ndarray] | list[Tensor]: + """ + Encode queries into multi-vector token embeddings. + + This method is a specialized version of :meth:`encode` that: + 1. Uses a predefined "query" prompt if available + 2. Sets the task to "query" for Router-based models + + Args: + sentences: The sentences to embed. + prompt_name: The name of the prompt to use. Defaults to "query" if available. + prompt: The prompt text to prepend. + batch_size: The batch size for encoding. + show_progress_bar: Whether to show a progress bar. + convert_to_numpy: Whether to convert outputs to numpy arrays. + convert_to_tensor: Whether to convert outputs to tensors. + device: Device to use for computation. + normalize_embeddings: Whether to L2-normalize each token embedding. + **kwargs: Additional arguments passed to encode. + + Returns: + List of embeddings, each with shape [num_tokens, dim]. + """ + if prompt_name is None and "query" in self.prompts and prompt is None: + prompt_name = "query" + + return self.encode( + sentences=sentences, + prompt_name=prompt_name, + prompt=prompt, + batch_size=batch_size, + show_progress_bar=show_progress_bar, + convert_to_numpy=convert_to_numpy, + convert_to_tensor=convert_to_tensor, + device=device, + normalize_embeddings=normalize_embeddings, + task="query", + **kwargs, + ) + + def encode_document( + self, + sentences: str | list[str], + prompt_name: str | None = None, + prompt: str | None = None, + batch_size: int = 32, + show_progress_bar: bool | None = None, + convert_to_numpy: bool = True, + convert_to_tensor: bool = False, + device: str | None = None, + normalize_embeddings: bool = True, + **kwargs: Any, + ) -> list[np.ndarray] | list[Tensor]: + """ + Encode documents into multi-vector token embeddings. + + This method is a specialized version of :meth:`encode` that: + 1. 
Uses a predefined "document" prompt if available + 2. Sets the task to "document" for Router-based models + + Args: + sentences: The sentences to embed. + prompt_name: The name of the prompt to use. Defaults to "document"/"passage"/"corpus" if available. + prompt: The prompt text to prepend. + batch_size: The batch size for encoding. + show_progress_bar: Whether to show a progress bar. + convert_to_numpy: Whether to convert outputs to numpy arrays. + convert_to_tensor: Whether to convert outputs to tensors. + device: Device to use for computation. + normalize_embeddings: Whether to L2-normalize each token embedding. + **kwargs: Additional arguments passed to encode. + + Returns: + List of embeddings, each with shape [num_tokens, dim]. + """ + if prompt_name is None and prompt is None: + for candidate_prompt_name in ["document", "passage", "corpus"]: + if candidate_prompt_name in self.prompts: + prompt_name = candidate_prompt_name + break + + return self.encode( + sentences=sentences, + prompt_name=prompt_name, + prompt=prompt, + batch_size=batch_size, + show_progress_bar=show_progress_bar, + convert_to_numpy=convert_to_numpy, + convert_to_tensor=convert_to_tensor, + device=device, + normalize_embeddings=normalize_embeddings, + task="document", + **kwargs, + ) + + def encode( + self, + sentences: str | list[str], + prompt_name: str | None = None, + prompt: str | None = None, + batch_size: int = 32, + show_progress_bar: bool | None = None, + convert_to_numpy: bool = True, + convert_to_tensor: bool = False, + device: str | None = None, + normalize_embeddings: bool = True, + **kwargs: Any, + ) -> list[np.ndarray] | list[Tensor]: + """ + Encode sentences into multi-vector token embeddings. + + Unlike standard SentenceTransformer.encode() which returns a single embedding per sentence, + this returns a list of embeddings (one per sentence), where each embedding has shape + [num_tokens, dim] representing the token-level embeddings. 
+ + Note: + Token normalization is handled by LateInteractionPooling (normalize=True by default). + The normalize_embeddings parameter is passed to the parent but does not affect + token embeddings output - use the pooling layer's normalize setting instead. + + Args: + sentences: The sentences to embed. + prompt_name: The name of the prompt to use. + prompt: The prompt text to prepend. + batch_size: The batch size for encoding. + show_progress_bar: Whether to show a progress bar. + convert_to_numpy: Whether to convert outputs to numpy arrays. + convert_to_tensor: Whether to convert outputs to tensors. + device: Device to use for computation. + normalize_embeddings: Passed to parent (normalization handled by pooling layer). + **kwargs: Additional arguments. + + Returns: + List of embeddings, each with shape [num_tokens, dim]. + The number of tokens varies per sentence based on tokenization. + """ + return super().encode( + sentences=sentences, + prompt_name=prompt_name, + prompt=prompt, + batch_size=batch_size, + show_progress_bar=show_progress_bar, + output_value="token_embeddings", + convert_to_numpy=convert_to_numpy, + convert_to_tensor=convert_to_tensor, + device=device, + normalize_embeddings=normalize_embeddings, + **kwargs, + ) + + def similarity( + self, + query_embeddings: list[np.ndarray] | list[Tensor] | Tensor, + document_embeddings: list[np.ndarray] | list[Tensor] | Tensor, + ) -> Tensor: + """ + Compute MaxSim similarity between query and document embeddings. + + Args: + query_embeddings: Query token embeddings. Either: + - List of arrays/tensors with shape [num_query_tokens, dim] each + - Padded tensor with shape [batch_q, max_query_tokens, dim] + document_embeddings: Document token embeddings. Either: + - List of arrays/tensors with shape [num_doc_tokens, dim] each + - Padded tensor with shape [batch_d, max_doc_tokens, dim] + + Returns: + Similarity scores with shape [batch_q, batch_d]. 
+ """ + query_embs, query_mask = self._prepare_embeddings_for_similarity(query_embeddings) + doc_embs, doc_mask = self._prepare_embeddings_for_similarity(document_embeddings) + + return maxsim(query_embs, doc_embs, query_mask, doc_mask) + + def similarity_pairwise( + self, + query_embeddings: list[np.ndarray] | list[Tensor] | Tensor, + document_embeddings: list[np.ndarray] | list[Tensor] | Tensor, + ) -> Tensor: + """ + Compute pairwise MaxSim similarity between corresponding query-document pairs. + + Args: + query_embeddings: Query token embeddings with batch size N. + document_embeddings: Document token embeddings with batch size N. + + Returns: + Similarity scores with shape [N]. + """ + query_embs, query_mask = self._prepare_embeddings_for_similarity(query_embeddings) + doc_embs, doc_mask = self._prepare_embeddings_for_similarity(document_embeddings) + + return maxsim_pairwise(query_embs, doc_embs, query_mask, doc_mask) + + def _prepare_embeddings_for_similarity( + self, + embeddings: list[np.ndarray] | list[Tensor] | Tensor, + ) -> tuple[Tensor, Tensor]: + """ + Convert embeddings to padded tensor format with attention mask. + + Args: + embeddings: Either a list of variable-length embeddings with shape [num_tokens, dim] each, + or a pre-padded tensor with shape [batch, max_tokens, dim]. + + Returns: + Tuple of (padded_embeddings, attention_mask). + + Raises: + ValueError: If embeddings is empty or has inconsistent dimensions. 
+ """ + # Handle pre-padded tensor input + if isinstance(embeddings, Tensor): + mask = torch.ones(embeddings.shape[:-1], device=embeddings.device, dtype=torch.long) + return embeddings, mask + + # Validate non-empty list + if len(embeddings) == 0: + raise ValueError("embeddings list cannot be empty") + + # Convert numpy arrays to tensors + if isinstance(embeddings[0], np.ndarray): + embeddings = [torch.from_numpy(e) for e in embeddings] + + # Validate all elements are tensors and have consistent dimensions, collect lengths + dim = embeddings[0].shape[-1] + lengths = [] + for i, emb in enumerate(embeddings): + if not isinstance(emb, Tensor): + raise ValueError(f"Expected Tensor at index {i}, got {type(emb).__name__}") + if emb.ndim != 2: + raise ValueError(f"Expected 2D tensor [num_tokens, dim] at index {i}, got shape {emb.shape}") + if emb.shape[-1] != dim: + raise ValueError(f"Inconsistent embedding dimension at index {i}: expected {dim}, got {emb.shape[-1]}") + lengths.append(emb.shape[0]) + + # Padding + padded = pad_sequence(embeddings, batch_first=True, padding_value=0.0) + + # Mask creation + device = embeddings[0].device + lengths_tensor = torch.tensor(lengths, device=device).unsqueeze(1) # [batch_size, 1] + positions = torch.arange(padded.shape[1], device=device).unsqueeze(0) # [1, max_len] + mask = (positions < lengths_tensor).long() + + return padded, mask + + def rank( + self, + query: str, + documents: list[str], + top_k: int | None = None, + return_documents: bool = False, + batch_size: int = 32, + show_progress_bar: bool = False, + **kwargs, + ) -> list[dict[str, Any]]: + """ + Rank documents by relevance to a query using MaxSim. + + Args: + query: The query string. + documents: List of document strings to rank. + top_k: Number of top documents to return. If None, returns all. + return_documents: Whether to include document text in results. + batch_size: Batch size for encoding documents. 
+ show_progress_bar: Whether to show a progress bar during encoding. + **kwargs: Additional arguments passed to encode methods. + + Returns: + List of dicts with keys: + - "corpus_id": Index of the document + - "score": MaxSim similarity score + - "text": Document text (if return_documents=True) + """ + # Encode query and documents + query_embedding = self.encode_query( + [query], + batch_size=1, + show_progress_bar=False, + convert_to_tensor=True, + **kwargs, + ) + document_embeddings = self.encode_document( + documents, + batch_size=batch_size, + show_progress_bar=show_progress_bar, + convert_to_tensor=True, + **kwargs, + ) + + # Compute similarities + scores = self.similarity(query_embedding, document_embeddings) + scores = scores[0] # Remove query batch dimension + + # Sort by score + sorted_indices = torch.argsort(scores, descending=True) + + if top_k is not None: + sorted_indices = sorted_indices[:top_k] + + # Build results + results = [] + for idx in sorted_indices: + idx = int(idx) + result = { + "corpus_id": idx, + "score": float(scores[idx]), + } + if return_documents: + result["text"] = documents[idx] + results.append(result) + + return results + + def get_token_embedding_dimension(self) -> int | None: + """ + Returns the dimension of each token embedding. + + Returns: + The token embedding dimension, or None if unknown. + """ + # Check for LateInteractionPooling module + for module in self._modules.values(): + if isinstance(module, LateInteractionPooling): + return module.get_output_dimension() + + # Fall back to transformer dimension + for module in self._modules.values(): + if hasattr(module, "get_word_embedding_dimension"): + return module.get_word_embedding_dimension() + + return None + + def get_sentence_embedding_dimension(self) -> int | None: + """ + Returns None since multi-vector encoder models produce multi-vector embeddings, + not single sentence embeddings. 
+ """ + return None + + def _load_auto_model( + self, + model_name_or_path: str, + token: bool | str | None, + cache_folder: str | None, + revision: str | None = None, + trust_remote_code: bool = False, + local_files_only: bool = False, + model_kwargs: dict[str, Any] | None = None, + tokenizer_kwargs: dict[str, Any] | None = None, + config_kwargs: dict[str, Any] | None = None, + has_modules: bool = False, + ) -> list[nn.Module]: + """ + Creates a multi-vector encoder model from a transformer and returns the modules. + + For models without existing multi-vector encoder configuration, creates: + - Transformer module + - LateInteractionPooling module (with projection to 128 dimensions) + """ + logger.warning( + f"No multi-vector encoder model found with name {model_name_or_path}. " + "Creating a new model with default multi-vector encoder configuration." + ) + + shared_kwargs = { + "token": token, + "trust_remote_code": trust_remote_code, + "revision": revision, + "local_files_only": local_files_only, + } + model_kwargs = shared_kwargs if model_kwargs is None else {**shared_kwargs, **model_kwargs} + tokenizer_kwargs = shared_kwargs if tokenizer_kwargs is None else {**shared_kwargs, **tokenizer_kwargs} + config_kwargs = shared_kwargs if config_kwargs is None else {**shared_kwargs, **config_kwargs} + + # Create transformer module + transformer_model = Transformer( + model_name_or_path, + cache_dir=cache_folder, + model_args=model_kwargs, + tokenizer_args=tokenizer_kwargs, + config_args=config_kwargs, + backend=self.backend, + ) + + # Create multi-vector pooling (project to 128 dimensions, normalize) + word_embedding_dimension = transformer_model.get_word_embedding_dimension() + pooling_model = LateInteractionPooling( + word_embedding_dimension=word_embedding_dimension, + output_dimension=128, + normalize=True, + ) + + modules = [transformer_model, pooling_model] + + return modules + + def save( + self, + path: str, + model_name: str | None = None, + create_model_card: 
bool = True, + train_datasets: list[str] | None = None, + safe_serialization: bool = True, + ) -> None: + """ + Save the model and its configuration files to a directory. + + Args: + path: Path on disk where the model will be saved. + model_name: Optional model name. + create_model_card: If True, create a README.md with model information. + train_datasets: Optional list of dataset names used to train the model. + safe_serialization: If True, save using safetensors format. + """ + return super().save( + path=path, + model_name=model_name, + create_model_card=create_model_card, + train_datasets=train_datasets, + safe_serialization=safe_serialization, + ) + + def save_pretrained( + self, + path: str, + model_name: str | None = None, + create_model_card: bool = True, + train_datasets: list[str] | None = None, + safe_serialization: bool = True, + ) -> None: + """ + Save the model and its configuration files to a directory. + Alias for :meth:`save`. + """ + return super().save_pretrained( + path=path, + model_name=model_name, + create_model_card=create_model_card, + train_datasets=train_datasets, + safe_serialization=safe_serialization, + ) + + @property + def max_seq_length(self) -> int: + """Returns the maximum input sequence length for the model.""" + return super().max_seq_length + + @max_seq_length.setter + def max_seq_length(self, value) -> None: + """Sets the maximum input sequence length for the model.""" + self._first_module().max_seq_length = value diff --git a/sentence_transformers/multi_vec_encoder/__init__.py b/sentence_transformers/multi_vec_encoder/__init__.py new file mode 100644 index 000000000..bdc168a4f --- /dev/null +++ b/sentence_transformers/multi_vec_encoder/__init__.py @@ -0,0 +1,11 @@ +from __future__ import annotations + +# LateInteractionPooling must be imported first to avoid circular import +# (MultiVectorEncoder imports LateInteractionPooling from this package) +from sentence_transformers.multi_vec_encoder.LateInteractionPooling import 
LateInteractionPooling +from sentence_transformers.multi_vec_encoder.MultiVectorEncoder import MultiVectorEncoder + +__all__ = [ + "MultiVectorEncoder", + "LateInteractionPooling", +] diff --git a/sentence_transformers/multi_vec_encoder/similarity.py b/sentence_transformers/multi_vec_encoder/similarity.py new file mode 100644 index 000000000..df3d3b759 --- /dev/null +++ b/sentence_transformers/multi_vec_encoder/similarity.py @@ -0,0 +1,113 @@ +from __future__ import annotations + +import torch +from torch import Tensor + + +def maxsim( + query_embeddings: Tensor, + document_embeddings: Tensor, + query_mask: Tensor | None = None, + document_mask: Tensor | None = None, +) -> Tensor: + """ + Compute MaxSim similarity between queries and documents. + + MaxSim computes, for each query token, the maximum cosine similarity to any document token, + then sums these maximum similarities across all query tokens: + + score(q, d) = Σᵢ maxⱼ(qᵢ · dⱼ) + + where qᵢ are query token embeddings and dⱼ are document token embeddings. + + This is the core similarity function used in late interaction models. + + Args: + query_embeddings: Query token embeddings of shape [batch_q, num_query_tokens, dim]. + Should be L2-normalized for cosine similarity. + document_embeddings: Document token embeddings of shape [batch_d, num_doc_tokens, dim]. + Should be L2-normalized for cosine similarity. + query_mask: Optional attention mask for queries of shape [batch_q, num_query_tokens]. + 1 for valid tokens, 0 for padding. If None, all tokens are considered valid. + document_mask: Optional attention mask for documents of shape [batch_d, num_doc_tokens]. + 1 for valid tokens, 0 for padding. If None, all tokens are considered valid. + + Returns: + Similarity scores of shape [batch_q, batch_d]. 
+ """ + batch_q, num_query_tokens, dim = query_embeddings.shape + batch_d, num_doc_tokens, _ = document_embeddings.shape + + # Compute token-level similarity: [batch_q, batch_d, num_query_tokens, num_doc_tokens] + token_similarities = torch.einsum("aik,bjk->abij", query_embeddings, document_embeddings) + # Reshape to [batch_q, num_query_tokens, batch_d, num_doc_tokens] + token_similarities = token_similarities.permute(0, 2, 1, 3) + + # Apply document mask: set similarities with padding tokens to -inf + if document_mask is not None: + doc_mask_expanded = document_mask.unsqueeze(0).unsqueeze(0) + token_similarities = token_similarities.masked_fill(doc_mask_expanded == 0, float("-inf")) + + # For each query token, find max similarity to any document token + max_similarities = token_similarities.max(dim=-1).values + + # Apply query mask: set masked query tokens to 0 before summing + if query_mask is not None: + query_mask_expanded = query_mask.unsqueeze(-1) + max_similarities = max_similarities * query_mask_expanded + + # Sum over query tokens + scores = max_similarities.sum(dim=1) + + return scores + + +def maxsim_pairwise( + query_embeddings: Tensor, + document_embeddings: Tensor, + query_mask: Tensor | None = None, + document_mask: Tensor | None = None, +) -> Tensor: + """ + Compute pairwise MaxSim similarity between corresponding query-document pairs. + + Unlike :func:`maxsim` which computes all pairs, this function computes similarity + only for corresponding pairs: score[i] = maxsim(query[i], document[i]). + + Args: + query_embeddings: Query token embeddings of shape [batch, num_query_tokens, dim]. + document_embeddings: Document token embeddings of shape [batch, num_doc_tokens, dim]. + query_mask: Optional attention mask for queries of shape [batch, num_query_tokens]. + document_mask: Optional attention mask for documents of shape [batch, num_doc_tokens]. + + Returns: + Similarity scores of shape [batch]. 
+ """ + batch_size, num_query_tokens, dim = query_embeddings.shape + batch_d, num_doc_tokens, _ = document_embeddings.shape + + assert batch_size == batch_d, ( + f"Batch sizes must match for pairwise computation. " + f"Got query batch size {batch_size} and document batch size {batch_d}. " + f"Fallback to maxsim() for computing all query-document pairs." + ) + + # Compute token-level similarity for each pair: [batch, num_query_tokens, num_doc_tokens] + token_similarities = torch.bmm(query_embeddings, document_embeddings.transpose(1, 2)) + + # Apply document mask: set similarities with padding tokens to -inf + if document_mask is not None: + doc_mask_expanded = document_mask.unsqueeze(1) + token_similarities = token_similarities.masked_fill(doc_mask_expanded == 0, float("-inf")) + + # For each query token, find max similarity to any document token + max_similarities = token_similarities.max(dim=-1).values + + # Apply query mask: set masked query tokens to 0 before summing + if query_mask is not None: + max_similarities = max_similarities * query_mask + + # Sum over query tokens + scores = max_similarities.sum(dim=1) + + return scores diff --git a/tests/multi_vec_encoder/__init__.py b/tests/multi_vec_encoder/__init__.py new file mode 100644 index 000000000..9d48db4f9 --- /dev/null +++ b/tests/multi_vec_encoder/__init__.py @@ -0,0 +1 @@ +from __future__ import annotations diff --git a/tests/multi_vec_encoder/test_multi_vec_encoder.py b/tests/multi_vec_encoder/test_multi_vec_encoder.py new file mode 100644 index 000000000..0fc4ef974 --- /dev/null +++ b/tests/multi_vec_encoder/test_multi_vec_encoder.py @@ -0,0 +1,179 @@ +from __future__ import annotations + +import pytest +import torch + +from sentence_transformers.multi_vec_encoder import LateInteractionPooling +from sentence_transformers.multi_vec_encoder.similarity import maxsim, maxsim_pairwise + + +class TestLateInteractionPooling: + """Tests for LateInteractionPooling module.""" + + @pytest.mark.parametrize( + 
("word_dim", "output_dim", "expected_output_dim"), + [ + (768, None, 768), + (768, 768, 768), + (768, 128, 128), + (128, 64, 64), + ], + ) + def test_dimensions(self, word_dim: int, output_dim: int | None, expected_output_dim: int) -> None: + """Test that dimension projection works correctly.""" + pooling = LateInteractionPooling( + word_embedding_dimension=word_dim, + output_dimension=output_dim, + normalize=False, + ) + features = { + "token_embeddings": torch.randn(2, 5, word_dim), + "attention_mask": torch.ones(2, 5, dtype=torch.long), + } + output = pooling(features) + + assert output["token_embeddings"].shape == (2, 5, expected_output_dim) + assert pooling.get_output_dimension() == expected_output_dim + + @pytest.mark.parametrize("normalize", [True, False]) + def test_normalization(self, normalize: bool) -> None: + """Test that L2 normalization produces unit vectors when enabled.""" + pooling = LateInteractionPooling(word_embedding_dimension=64, normalize=normalize) + features = { + "token_embeddings": torch.randn(2, 5, 64) * 10, + "attention_mask": torch.ones(2, 5, dtype=torch.long), + } + output = pooling(features) + norms = torch.norm(output["token_embeddings"], dim=-1) + + if normalize: + assert torch.allclose(norms, torch.ones_like(norms), atol=1e-5) + else: + assert not torch.allclose(norms, torch.ones_like(norms), atol=1e-5) + + @pytest.mark.parametrize( + ("skip_cls", "skip_sep", "attention_mask", "expected_mask"), + [ + (True, False, [[1, 1, 1, 0]], [[0, 1, 1, 0]]), # CLS at 0 masked + (False, True, [[1, 1, 1, 0]], [[1, 1, 0, 0]]), # SEP at 2 masked (last non-padding) + (True, True, [[1, 1, 1, 0]], [[0, 1, 0, 0]]), # CLS at 0, SEP at 2 masked + (False, False, [[1, 1, 1, 0]], [[1, 1, 1, 0]]), # No masking + (True, True, [[1, 1, 1, 1]], [[0, 1, 1, 0]]), # CLS at 0, SEP at 3 masked + ], + ) + def test_skip_tokens(self, skip_cls: bool, skip_sep: bool, attention_mask: list, expected_mask: list) -> None: + """Test that CLS/SEP token skipping modifies 
attention mask correctly.""" + pooling = LateInteractionPooling( + word_embedding_dimension=32, skip_cls_token=skip_cls, skip_sep_token=skip_sep, normalize=False + ) + features = { + "token_embeddings": torch.randn(1, len(attention_mask[0]), 32), + "attention_mask": torch.tensor(attention_mask, dtype=torch.long), + } + output = pooling(features) + + assert torch.equal(output["attention_mask"], torch.tensor(expected_mask, dtype=torch.long)) + + def test_config_roundtrip(self, tmp_path) -> None: + """Test that config save/load preserves all settings.""" + pooling = LateInteractionPooling( + word_embedding_dimension=768, + output_dimension=128, + normalize=True, + skip_cls_token=True, + skip_sep_token=False, + ) + pooling.save(str(tmp_path)) + loaded = LateInteractionPooling.load(str(tmp_path)) + + assert (loaded.word_embedding_dimension, loaded.output_dimension) == (768, 128) + assert (loaded.normalize, loaded.skip_cls_token, loaded.skip_sep_token) == (True, True, False) + + def test_get_sentence_embedding_dimension_returns_none(self) -> None: + """Test that get_sentence_embedding_dimension returns None for multi-vector models.""" + assert LateInteractionPooling(word_embedding_dimension=768).get_sentence_embedding_dimension() is None + + +class TestMaxSimSimilarity: + """Tests for MaxSim similarity functions.""" + + @pytest.mark.parametrize( + ("batch_q", "batch_d", "num_q_tokens", "num_d_tokens", "dim"), + [(1, 1, 3, 5, 64), (1, 2, 4, 6, 128), (2, 3, 5, 10, 64), (3, 3, 8, 8, 32)], + ) + def test_output_shape(self, batch_q: int, batch_d: int, num_q_tokens: int, num_d_tokens: int, dim: int) -> None: + """Test that maxsim returns correct output shape [batch_q, batch_d].""" + scores = maxsim(torch.randn(batch_q, num_q_tokens, dim), torch.randn(batch_d, num_d_tokens, dim)) + assert scores.shape == (batch_q, batch_d) + + @pytest.mark.parametrize( + ("query_mask", "doc_mask", "expected_score"), + [ + (None, None, None), # No masks, just check no NaN + ([[1, 1, 0]], None, 
None), + (None, [[1, 1, 1, 0, 0]], None), + ([[1, 1, 0]], [[1, 1, 1, 0, 0]], None), + ([[0, 0, 0]], None, 0.0), # All query tokens masked -> score = 0 + ], + ) + def test_with_masks(self, query_mask: list | None, doc_mask: list | None, expected_score: float | None) -> None: + """Test that maxsim handles attention masks correctly.""" + query = torch.randn(1, 3, 64) + doc = torch.randn(1, 5, 64) + q_mask = torch.tensor(query_mask, dtype=torch.long) if query_mask else None + d_mask = torch.tensor(doc_mask, dtype=torch.long) if doc_mask else None + + scores = maxsim(query, doc, q_mask, d_mask) + + assert scores.shape == (1, 1) and not torch.isnan(scores).any() + if expected_score is not None: + assert scores.item() == expected_score + + def test_masked_tokens_ignored(self) -> None: + """Test that masked document tokens don't contribute to similarity.""" + query = torch.tensor([[[1.0, 0.0, 0.0, 0.0]]]) + doc = torch.tensor([[[1.0, 0.0, 0.0, 0.0], [-1.0, 0.0, 0.0, 0.0]]]) + + scores_unmasked = maxsim(query, doc) + scores_masked = maxsim(query, doc, document_mask=torch.tensor([[1, 0]])) + + assert torch.allclose(scores_unmasked, scores_masked, atol=1e-5) + + @pytest.mark.parametrize( + ("batch_size", "num_q_tokens", "num_d_tokens", "dim"), + [(1, 3, 5, 64), (2, 4, 6, 128), (5, 8, 10, 32)], + ) + def test_pairwise_output_shape(self, batch_size: int, num_q_tokens: int, num_d_tokens: int, dim: int) -> None: + """Test that maxsim_pairwise returns correct output shape [batch].""" + scores = maxsim_pairwise( + torch.randn(batch_size, num_q_tokens, dim), torch.randn(batch_size, num_d_tokens, dim) + ) + assert scores.shape == (batch_size,) + + def test_pairwise_batch_mismatch_raises(self) -> None: + """Test that mismatched batch sizes raise an AssertionError.""" + with pytest.raises(AssertionError, match="Batch sizes must match"): + maxsim_pairwise(torch.randn(2, 4, 64), torch.randn(3, 5, 64)) + + def test_pairwise_consistency_with_maxsim(self) -> None: + """Test that pairwise 
scores match diagonal of full maxsim matrix.""" + query, doc = torch.randn(3, 4, 64), torch.randn(3, 5, 64) + assert torch.allclose(maxsim_pairwise(query, doc), torch.diagonal(maxsim(query, doc)), atol=1e-5) + + @pytest.mark.parametrize( + ("query", "doc", "expected"), + [ + (torch.tensor([[[1.0, 0.0]]]), torch.tensor([[[1.0, 0.0]]]), 1.0), # Identical + (torch.tensor([[[1.0, 0.0]]]), torch.tensor([[[0.0, 1.0]]]), 0.0), # Orthogonal + (torch.tensor([[[1.0, 0.0]]]), torch.tensor([[[-1.0, 0.0]]]), -1.0), # Opposite + (torch.zeros(1, 1, 2), torch.zeros(1, 1, 2), 0.0), # All zeros + ( + torch.tensor([[[1.0, 0.0], [0.0, 1.0]]]), + torch.tensor([[[1.0, 0.0], [0.0, 1.0]]]), + 2.0, + ), # Multi-token sum + ], + ) + def test_edge_cases(self, query: torch.Tensor, doc: torch.Tensor, expected: float) -> None: + """Test maxsim with edge cases: identical, orthogonal, opposite vectors, zeros, multi-token.""" + assert torch.allclose(maxsim(query, doc), torch.tensor([[expected]]), atol=1e-5)