重排技术深度指南

摘要

重排技术是现代RAG系统中的关键环节,通过对初检结果进行精细化排序,显著提升检索质量。本指南深入解析Cross-Encoder、Bi-Encoder和ColBERT三种主流重排架构的技术原理、数学推导和工程实现,为构建高性能检索系统提供系统性指导。

关键词速查表

关键词 | 说明
Cross-Encoder | 交叉编码器,将查询和文档同时输入,产生联合表示
Bi-Encoder | 双编码器,分别编码查询和文档,再计算相似度
ColBERT | 晚交互模型,先独立编码token向量,再通过MaxSim做轻量级交互
Late Interaction | 晚交互,在编码后进行交互,以平衡效率和效果
Pointwise Ranking | 点级排序,独立评估每个文档的相关性
Pairwise Ranking | 对级排序,比较文档对的相对顺序
Listwise Ranking | 列表级排序,考虑整个文档列表的排序
注意力机制 | Self-Attention和Cross-Attention的统称
蒸馏学习 | Knowledge Distillation,从大模型迁移知识
微调策略 | Fine-tuning策略,包括全量微调和PEFT

一、重排技术的定位与价值

1.1 检索系统架构演进

现代检索系统通常采用两阶段甚至多阶段架构:粗排阶段使用轻量级模型快速筛选候选文档,精排阶段使用重排模型进行精细排序。这种分层设计在效果和效率之间取得平衡,使系统能够处理海量文档的同时保证排序质量。

两阶段架构的数学表达如下:设查询为 $q$,文档集合为 $D = \{d_1, \dots, d_N\}$,初排模型 $f_1$ 产生候选集 $C = \operatorname{Top-}k\big(f_1(q, D)\big)$(通常 $k \ll N$),重排模型 $f_2$ 对 $C$ 进行最终排序。理想情况下,$f_2$ 比 $f_1$ 更精确,但计算成本更高,因此只在较小的候选集上运行。
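
从成本角度看,这种分层的收益可以用一个量级估算说明(下式中 $c_1$、$c_2$ 表示单文档的粗排与精排开销,为本文为说明引入的记号):

$$\text{Cost} \approx N \cdot c_1 + k \cdot c_2, \qquad k \ll N,\; c_2 \gg c_1$$

例如 $N = 10^6$、$k = 100$ 时,即便 $c_2 = 1000\,c_1$,精排总开销也只有 $10^5\,c_1$,仅为粗排遍历全库开销 $10^6\,c_1$ 的十分之一。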

1.2 重排的核心价值

重排技术解决了向量检索的固有局限。纯向量检索依赖全局语义匹配,可能遗漏精确的关键词匹配或忽略查询中的特定约束。重排模型通过更深层的语义理解,能够捕捉查询意图的细微差别,识别文档中的关键段落,并在全局视角下进行最优排序。

例如,当用户查询“Python如何处理并发”时,向量检索可能返回涵盖“Python异步编程”和“Python多线程”的文档。重排模型能够进一步分析这些文档中与“并发”直接相关的具体内容,从而给出更精准的排序。

1.3 重排技术分类

根据交互方式的不同,重排技术可分为三类:Cross-Encoder提供最完整的交互但计算成本最高;Bi-Encoder计算效率高但交互能力有限;ColBERT等晚交互模型则在两者之间取得平衡。粗略地看,对 $k$ 个候选文档,Cross-Encoder 需要 $k$ 次完整的联合前向计算;Bi-Encoder 只需 1 次查询编码加 $k$ 次向量点积(文档向量可预计算);ColBERT 则是 1 次查询编码加 $k$ 次 token 级 MaxSim 计算。理解这些设计权衡对于正确选择和部署至关重要。

二、Bi-Encoder双编码器详解

2.1 架构原理

Bi-Encoder采用双塔结构,分别独立编码查询和文档。这种设计允许预先计算文档编码,实现高效的批量检索,但代价是查询和文档之间缺乏早期交互,可能丢失细粒度的语义关联。

import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer
from typing import Dict, List, Optional, Tuple
 
class BiEncoderModel(nn.Module):
    """
    Bi-Encoder双编码器模型
    
    查询和文档使用共享的编码器参数,但独立编码
    """
    
    def __init__(self, model_name: str, embedding_dim: int = 768,
                 pooling_strategy: str = "cls"):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.embedding_dim = embedding_dim
        self.pooling_strategy = pooling_strategy
        
        # 投影层(可选)
        self.projection = nn.Linear(
            self.encoder.config.hidden_size,
            embedding_dim
        ) if embedding_dim != self.encoder.config.hidden_size else nn.Identity()
    
    def encode_query(self, query: str) -> torch.Tensor:
        """编码查询"""
        inputs = self.tokenizer(
            query,
            padding=True,
            truncation=True,
            max_length=128,
            return_tensors="pt"
        )
        
        inputs = {k: v.to(self.encoder.device) for k, v in inputs.items()}
        outputs = self.encoder(**inputs)
        
        # 池化后经投影层映射到目标维度
        return self.projection(
            self._pooling(outputs.last_hidden_state, inputs["attention_mask"])
        )
    
    def encode_document(self, document: str, title: Optional[str] = None) -> torch.Tensor:
        """
        编码文档
        
        可选择性地包含标题信息
        """
        if title:
            text = f"{title}: {document}"
        else:
            text = document
        
        inputs = self.tokenizer(
            text,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors="pt"
        )
        
        inputs = {k: v.to(self.encoder.device) for k, v in inputs.items()}
        outputs = self.encoder(**inputs)
        
        # 池化后经投影层映射到目标维度
        return self.projection(
            self._pooling(outputs.last_hidden_state, inputs["attention_mask"])
        )
    
    def _pooling(self, hidden_states: torch.Tensor, 
                 attention_mask: torch.Tensor) -> torch.Tensor:
        """
        池化策略
        
        支持: cls, mean, max, weighted
        """
        if self.pooling_strategy == "cls":
            # 使用[CLS] token
            return hidden_states[:, 0, :]
        
        elif self.pooling_strategy == "mean":
            # Mean Pooling
            mask_expanded = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
            sum_embeddings = torch.sum(hidden_states * mask_expanded, dim=1)
            sum_mask = mask_expanded.sum(dim=1)
            return sum_embeddings / sum_mask.clamp(min=1e-9)
        
        elif self.pooling_strategy == "max":
            # Max Pooling
            mask_expanded = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
            hidden_states[mask_expanded == 0] = -1e9
            return torch.max(hidden_states, dim=1)[0]
        
        elif self.pooling_strategy == "weighted":
            # Attention-weighted Pooling
            attention_weights = torch.softmax(hidden_states, dim=1)
            return torch.sum(hidden_states * attention_weights, dim=1)
        
        raise ValueError(f"Unknown pooling strategy: {self.pooling_strategy}")
    
    def compute_similarity(self, query_emb: torch.Tensor, 
                          doc_emb: torch.Tensor) -> torch.Tensor:
        """计算查询-文档相似度"""
        # 归一化后计算点积(等价于余弦相似度)
        query_emb = torch.nn.functional.normalize(query_emb, p=2, dim=1)
        doc_emb = torch.nn.functional.normalize(doc_emb, p=2, dim=1)
        return torch.sum(query_emb * doc_emb, dim=-1)

2.2 对比学习训练

Bi-Encoder的训练通常采用对比学习方法,通过最大化正样本对相似度、最小化负样本对相似度来学习有效的表示空间。
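
其核心是InfoNCE损失(下式中 $\tau$ 为温度系数,$s(\cdot,\cdot)$ 为归一化向量的点积相似度,与下文代码一一对应):

$$\mathcal{L} = -\log \frac{\exp\big(s(q, d^{+})/\tau\big)}{\exp\big(s(q, d^{+})/\tau\big) + \sum_{j=1}^{m} \exp\big(s(q, d_{j}^{-})/\tau\big)}$$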

class ContrastiveLoss(nn.Module):
    """
    对比损失函数
    
    支持难负例挖掘和温度系数调节
    """
    
    def __init__(self, temperature: float = 0.07, 
                 use_hard_negatives: bool = True):
        super().__init__()
        self.temperature = temperature
        self.use_hard_negatives = use_hard_negatives
    
    def forward(self, query_emb: torch.Tensor, 
                positive_emb: torch.Tensor,
                negative_embs: torch.Tensor) -> torch.Tensor:
        """
        计算对比损失
        
        Args:
            query_emb: 查询嵌入 (batch_size, embedding_dim)
            positive_emb: 正样本嵌入 (batch_size, embedding_dim)
            negative_embs: 负样本嵌入 (batch_size, num_negatives, embedding_dim)
        """
        # 归一化
        query_emb = torch.nn.functional.normalize(query_emb, p=2, dim=1)
        positive_emb = torch.nn.functional.normalize(positive_emb, p=2, dim=1)
        negative_embs = torch.nn.functional.normalize(negative_embs, p=2, dim=-1)
        
        # 计算正样本相似度
        positive_sim = torch.sum(query_emb * positive_emb, dim=-1) / self.temperature
        
        # 计算负样本相似度
        batch_size = query_emb.size(0)
        num_negatives = negative_embs.size(1)
        
        # (batch_size, num_negatives)
        negative_sim = torch.bmm(
            negative_embs, 
            query_emb.unsqueeze(-1)
        ).squeeze(-1) / self.temperature
        
        # InfoNCE损失
        logits = torch.cat([positive_sim.unsqueeze(-1), negative_sim], dim=-1)
        labels = torch.zeros(batch_size, dtype=torch.long, device=query_emb.device)
        
        loss = nn.CrossEntropyLoss()(logits, labels)
        
        return loss
 
class BiEncoderTrainer:
    """Bi-Encoder训练器"""
    
    def __init__(self, model: BiEncoderModel, 
                 learning_rate: float = 2e-5,
                 batch_size: int = 32,
                 num_negatives: int = 4):
        self.model = model
        self.optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
        self.batch_size = batch_size
        self.num_negatives = num_negatives
        self.loss_fn = ContrastiveLoss(temperature=0.07)
    
    def train_step(self, queries: List[str], 
                   positives: List[str],
                   negatives: List[List[str]]) -> float:
        """单步训练"""
        self.model.train()
        
        # 编码
        query_embs = []
        positive_embs = []
        negative_embs_list = []
        
        for i in range(0, len(queries), self.batch_size):
            batch_queries = queries[i:i+self.batch_size]
            batch_positives = positives[i:i+self.batch_size]
            batch_negatives = negatives[i:i+self.batch_size]
            
            # 编码查询
            q_embs = torch.cat([
                self.model.encode_query(q).cpu() 
                for q in batch_queries
            ], dim=0)
            query_embs.append(q_embs)
            
            # 编码正样本
            p_embs = torch.cat([
                self.model.encode_document(p).cpu() 
                for p in batch_positives
            ], dim=0)
            positive_embs.append(p_embs)
            
            # 编码负样本
            neg_embs = []
            for neg_list in batch_negatives:
                # squeeze(0)去掉batch维,使每个负样本向量形状为(dim,)
                neg_batch = [
                    self.model.encode_document(n).squeeze(0).cpu()
                    for n in neg_list[:self.num_negatives]
                ]
                # Padding到统一长度(零向量归一化后相似度为0,不影响损失)
                while len(neg_batch) < self.num_negatives:
                    neg_batch.append(torch.zeros_like(neg_batch[0]))
                neg_embs.append(torch.stack(neg_batch))  # (num_negatives, dim)
            
            negative_embs_list.append(torch.stack(neg_embs))  # (batch, num_negatives, dim)
        
        query_embs = torch.cat(query_embs, dim=0).to(self.model.encoder.device)
        positive_embs = torch.cat(positive_embs, dim=0).to(self.model.encoder.device)
        negative_embs = torch.cat(negative_embs_list, dim=0).to(self.model.encoder.device)
        
        # 计算损失
        loss = self.loss_fn(query_embs, positive_embs, negative_embs)
        
        # 反向传播
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        
        return loss.item()

2.3 应用场景与限制

Bi-Encoder的优势在于预计算能力:文档编码可以离线完成并缓存,查询时只需编码查询向量即可计算相似度。这使其非常适合大规模文档库的检索场景。典型应用包括语义搜索、相似文档发现和聚类分析。

然而,Bi-Encoder的独立编码特性也带来了局限。由于查询和文档在编码阶段不交互,模型无法捕捉查询中特定词与文档中对应词的精确匹配关系。例如,“2024年Q3财报”这样的精确查询在Bi-Encoder中可能被当作普通语义匹配处理,而Cross-Encoder则能精确关联“2024”、“Q3”和“财报”这些关键词。
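
作为补充,下面给出一个“离线预计算 + 在线查询”的最小示意(假设 model 为上文的 BiEncoderModel 实例,documents 为待索引的文档列表,均为示例记号):

# 离线阶段:批量编码文档并归一化,结果可持久化到磁盘
doc_embs = torch.cat([
    model.encode_document(d).detach().cpu() for d in documents
], dim=0)
doc_embs = torch.nn.functional.normalize(doc_embs, p=2, dim=1)

def search(query: str, top_k: int = 5):
    """在线阶段:仅编码查询,一次矩阵乘法得到全部相似度"""
    q_emb = torch.nn.functional.normalize(
        model.encode_query(query).detach().cpu(), p=2, dim=1
    )
    scores = (q_emb @ doc_embs.T).squeeze(0)
    top = torch.topk(scores, k=min(top_k, len(documents)))
    return [
        (documents[i], float(s))
        for i, s in zip(top.indices.tolist(), top.values.tolist())
    ]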

三、Cross-Encoder交叉编码器详解

3.1 架构原理

Cross-Encoder的核心设计是将查询和文档拼接后一同输入编码器,编码器内部的Self-Attention让查询token与文档token在每一层充分交互,实现深层的交互学习。这种设计虽然失去了预计算能力,但获得了最完整的交互能力。
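
其打分过程可以形式化为(记号为常见写法,对应下文代码中“[CLS]表示 + 分类头”的实现):

$$s(q, d) = \mathrm{MLP}\Big(\mathrm{Encoder}\big([\mathrm{CLS}]\, q \,[\mathrm{SEP}]\, d \,[\mathrm{SEP}]\big)_{[\mathrm{CLS}]}\Big)$$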

class CrossEncoderModel(nn.Module):
    """
    Cross-Encoder交叉编码器模型
    
    查询和文档拼接后联合编码
    """
    
    def __init__(self, model_name: str, num_labels: int = 1):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.num_labels = num_labels
        
        # 分类头
        self.classifier = nn.Sequential(
            nn.Dropout(0.1),
            nn.Linear(self.encoder.config.hidden_size, 256),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(256, num_labels)
        )
    
    def forward(self, query: str, document: str) -> torch.Tensor:
        """
        前向传播
        
        返回相关性分数(回归任务)或概率分布(分类任务)
        """
        # 拼接查询和文档
        inputs = self.tokenizer(
            query,
            document,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors="pt"
        )
        
        inputs = {k: v.to(self.encoder.device) for k, v in inputs.items()}
        outputs = self.encoder(**inputs)
        
        # 使用[CLS] token的表示
        cls_output = outputs.last_hidden_state[:, 0, :]
        
        # 分类
        logits = self.classifier(cls_output)
        
        return logits
    
    def score(self, queries: List[str], 
              documents: List[str]) -> List[float]:
        """
        批量计算相关性分数
        
        注意:此方法效率较低,适合小规模重排场景
        """
        self.eval()
        
        scores = []
        for query, doc in zip(queries, documents):
            with torch.no_grad():
                score = self.forward(query, doc)
                scores.append(score.item())
        
        return scores
 
class CrossEncoderWithInteraction(nn.Module):
    """
    带显式交互层的Cross-Encoder
    
    在标准Cross-Encoder基础上增加双向注意力交互
    """
    
    def __init__(self, model_name: str):
        super().__init__()
        self.query_encoder = AutoModel.from_pretrained(model_name)
        self.doc_encoder = AutoModel.from_pretrained(model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        
        hidden_size = self.query_encoder.config.hidden_size
        
        # Cross-Attention层
        self.cross_attention = nn.MultiheadAttention(
            embed_dim=hidden_size,
            num_heads=8,
            dropout=0.1,
            batch_first=True
        )
        
        # 输出层
        self.output_layer = nn.Sequential(
            nn.Linear(hidden_size * 2, hidden_size),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, 1)
        )
    
    def forward(self, query: str, document: str) -> torch.Tensor:
        """带显式交互的前向传播"""
        # 独立编码
        q_inputs = self.tokenizer(
            query,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=128
        )
        q_inputs = {k: v.to(self.query_encoder.device) for k, v in q_inputs.items()}
        q_outputs = self.query_encoder(**q_inputs)
        q_hidden = q_outputs.last_hidden_state  # (1, seq_len_q, hidden)
        
        d_inputs = self.tokenizer(
            document,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=384
        )
        d_inputs = {k: v.to(self.doc_encoder.device) for k, v in d_inputs.items()}
        d_outputs = self.doc_encoder(**d_inputs)
        d_hidden = d_outputs.last_hidden_state  # (1, seq_len_d, hidden)
        
        # Cross-Attention: Query关注Document
        cross_attn_output, _ = self.cross_attention(
            q_hidden, d_hidden, d_hidden
        )
        
        # 融合查询和交互表示
        combined = torch.cat([
            q_hidden[:, 0, :],  # CLS token
            cross_attn_output[:, 0, :]  # 交互后的CLS
        ], dim=-1)
        
        return self.output_layer(combined)

3.2 排序学习训练范式

Cross-Encoder的训练采用排序学习(Learning to Rank)范式,主要有点级、对级和列表级三种方法。
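
三种范式的典型损失形式如下(记号为常见写法:$s_i$ 为模型对文档 $i$ 的打分,$y_i$ 为相关性标签,$m$ 为间隔,$\sigma$ 为sigmoid函数):

  • 点级:$\mathcal{L}_{\text{point}} = -\sum_i \big[\,y_i \log \sigma(s_i) + (1-y_i)\log(1-\sigma(s_i))\,\big]$
  • 对级:$\mathcal{L}_{\text{pair}} = \sum_{y_i > y_j} \max(0,\; m - s_i + s_j)$
  • 列表级:$\mathcal{L}_{\text{list}} = -\sum_i P_y(i)\log P_s(i)$,其中 $P_y$、$P_s$ 分别为标签与分数经softmax得到的分布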

from typing import List, Tuple
 
class RankingLossFunctions:
    """排序损失函数集合"""
    
    @staticmethod
    def pairwise_hinge_loss(scores: torch.Tensor, 
                           labels: torch.Tensor,
                           margin: float = 1.0) -> torch.Tensor:
        """
        Pairwise Hinge Loss (RankSVM风格)
        
        强制正确文档的分数比错误文档高至少margin
        """
        batch_size = scores.size(0)
        
        # 构建文档对
        pairs = []
        for i in range(batch_size):
            for j in range(i + 1, batch_size):
                if labels[i] > labels[j]:  # i应该排在j前面
                    pairs.append((i, j))
                elif labels[j] > labels[i]:  # j应该排在i前面
                    pairs.append((j, i))
        
        if not pairs:
            return torch.tensor(0.0, device=scores.device)
        
        # 计算hinge loss(clamp保证始终返回张量,便于backward)
        loss = torch.zeros((), device=scores.device)
        for pos_idx, neg_idx in pairs:
            diff = scores[neg_idx] - scores[pos_idx] + margin
            loss = loss + torch.clamp(diff, min=0)
        
        return loss / len(pairs)
    
    @staticmethod
    def listwise_softmax_loss(scores: torch.Tensor,
                              labels: torch.Tensor) -> torch.Tensor:
        """
        Listwise Softmax Loss (ListNet风格)
        
        直接优化整个列表的排序概率分布
        """
        # 将标签转换为概率分布
        labels_prob = torch.softmax(labels.float(), dim=-1)
        
        # 将分数转换为对数概率分布(log_softmax数值更稳定)
        scores_log_prob = torch.log_softmax(scores.float(), dim=-1)
        
        # 交叉熵损失
        loss = -torch.sum(labels_prob * scores_log_prob)
        
        return loss
    
    @staticmethod
    def lambdarank_loss(scores: torch.Tensor,
                       labels: torch.Tensor,
                       k: int = 10) -> torch.Tensor:
        """
        LambdaRank损失(简化实现)
        
        pairwise logistic损失,按交换文档对位置带来的|ΔNDCG|加权;
        直接以1-NDCG为损失不可导(排序操作无梯度),故采用此形式
        """
        device = scores.device
        n = scores.size(0)
        labels = labels.float()
        
        # 按当前分数计算每个文档的排名位置
        sorted_indices = torch.argsort(scores, descending=True)
        ranks = torch.empty(n, dtype=torch.long, device=device)
        ranks[sorted_indices] = torch.arange(n, device=device)
        
        gains = 2 ** labels - 1
        discounts = 1.0 / torch.log2(ranks.float() + 2)
        
        # IDCG用于归一化
        ideal_gains = torch.sort(gains, descending=True)[0][:k]
        ideal_discounts = 1.0 / torch.log2(
            torch.arange(len(ideal_gains), device=device).float() + 2
        )
        idcg = torch.sum(ideal_gains * ideal_discounts) + 1e-10
        
        loss = torch.zeros((), device=device)
        num_pairs = 0
        for i in range(n):
            for j in range(n):
                if labels[i] > labels[j]:
                    # 交换i、j排名位置引起的NDCG变化量(lambda权重,视为常数)
                    delta_ndcg = torch.abs(
                        (gains[i] - gains[j]) * (discounts[i] - discounts[j])
                    ) / idcg
                    # log(1 + exp(-(s_i - s_j))) = softplus(s_j - s_i)
                    loss = loss + delta_ndcg * torch.nn.functional.softplus(
                        scores[j] - scores[i]
                    )
                    num_pairs += 1
        
        return loss / max(num_pairs, 1)
 
class CrossEncoderTrainer:
    """Cross-Encoder训练器"""
    
    def __init__(self, model: CrossEncoderModel,
                 loss_type: str = "pairwise",
                 **loss_kwargs):
        self.model = model
        self.loss_type = loss_type
        self.loss_kwargs = loss_kwargs
        
        self.optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=2e-5,
            weight_decay=0.01
        )
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer, T_max=1000
        )
    
    def train_step(self, query: str, 
                   documents: List[str],
                   labels: List[float]) -> float:
        """
        单步训练
        
        Args:
            query: 查询文本
            documents: 文档列表
            labels: 相关性标签列表
        """
        self.model.train()
        
        # 编码
        scores = []
        for doc in documents:
            score = self.model(query, doc)
            scores.append(score.squeeze())
        
        scores = torch.stack(scores)
        labels = torch.tensor(labels, device=scores.device).float()
        
        # 计算损失
        if self.loss_type == "pairwise":
            loss = RankingLossFunctions.pairwise_hinge_loss(scores, labels)
        elif self.loss_type == "listwise":
            loss = RankingLossFunctions.listwise_softmax_loss(scores, labels)
        elif self.loss_type == "lambdarank":
            loss = RankingLossFunctions.lambdarank_loss(scores, labels)
        else:
            raise ValueError(f"Unknown loss type: {self.loss_type}")
        
        # 反向传播
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
        self.optimizer.step()
        self.scheduler.step()
        
        return loss.item()

3.3 效率优化策略

Cross-Encoder的主要瓶颈在于无法预计算文档编码,每个查询-文档对都需要重新编码。以下策略可有效优化效率:

class CrossEncoderOptimizer:
    """Cross-Encoder效率优化"""
    
    @staticmethod
    def batch_processing(model: CrossEncoderModel,
                        queries: List[str],
                        documents: List[str],
                        batch_size: int = 8) -> List[float]:
        """
        批量处理以提高GPU利用率
        
        将多个查询-文档对组成batch并行编码打分
        """
        scores = []
        
        for i in range(0, len(queries), batch_size):
            batch_queries = queries[i:i+batch_size]
            batch_docs = documents[i:i+batch_size]
            
            # 成对tokenize,自动拼接为 [CLS] query [SEP] document [SEP]
            inputs = model.tokenizer(
                batch_queries,
                batch_docs,
                padding=True,
                truncation=True,
                max_length=512,
                return_tensors="pt"
            )
            inputs = {k: v.to(model.encoder.device) for k, v in inputs.items()}
            
            with torch.no_grad():
                outputs = model.encoder(**inputs)
                logits = model.classifier(outputs.last_hidden_state[:, 0, :])
            
            scores.extend(logits.squeeze(-1).cpu().tolist())
        
        return scores
    
    @staticmethod
    def cache_document_encodings(model: CrossEncoderModel,
                                documents: List[str],
                                cache_dir: str = "./cache"):
        """
        缓存文档编码(部分计算)
        
        适用于只需要cross-attention层计算的场景
        """
        import os
        os.makedirs(cache_dir, exist_ok=True)
        
        # 保存文档的tokenized表示
        tokenizer = model.tokenizer
        doc_ids = tokenizer(
            documents,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors="pt"
        )
        
        # 保存到缓存
        torch.save(doc_ids, os.path.join(cache_dir, "doc_encodings.pt"))
    
    @staticmethod
    def precompute_token_types(model: CrossEncoderModel,
                              documents: List[str]):
        """
        预计算文档的token序列
        
        仅在查询时执行cross-attention计算
        """
        tokenizer = model.tokenizer
        doc_encoding = tokenizer(
            documents,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512
        )
        
        # 使用模型的encoder预计算文档表示(迁移到encoder所在设备)
        with torch.no_grad():
            doc_outputs = model.encoder(**{
                k: v.to(model.encoder.device) for k, v in doc_encoding.items()
            })
            doc_hidden = doc_outputs.last_hidden_state
        
        return doc_hidden

四、ColBERT晚交互模型

4.1 核心设计理念

ColBERT(Contextualized Late Interaction over BERT)由斯坦福大学提出,是一种创新的检索模型。它既保留了Bi-Encoder的预计算能力,又通过延迟交互机制实现了接近Cross-Encoder的检索质量。

ColBERT的核心思想是:先独立编码查询和文档的每个token,然后在编码后执行轻量级的交互计算。这种“晚交互”策略使模型能够精确匹配查询中的每个token与文档中的相关片段,同时保持计算效率。
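
这一过程可以写成MaxSim求和(其中 $E_{q_i}$、$E_{d_j}$ 分别为查询第 $i$ 个和文档第 $j$ 个token的向量):

$$s(q, d) = \sum_{i=1}^{|q|} \max_{1 \le j \le |d|} E_{q_i} \cdot E_{d_j}$$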

4.2 架构详解

class ColBERTModel(nn.Module):
    """
    ColBERT晚交互检索模型
    
    支持预计算的文档编码和高效的查询-文档交互
    """
    
    def __init__(self, model_name: str, 
                 similarity_metric: str = "cosine"):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.similarity_metric = similarity_metric
        
        # ColBERT使用特殊的向量维度(通常较小以提高效率)
        self.embedding_dim = 128
        
        # 线性投影层
        self.projection = nn.Linear(
            self.encoder.config.hidden_size,
            self.embedding_dim
        )
    
    def encode_query(self, query: str) -> torch.Tensor:
        """
        编码查询
        
        返回所有token的向量表示
        """
        inputs = self.tokenizer(
            query,
            padding=True,
            truncation=True,
            max_length=32,  # ColBERT查询通常较短
            return_tensors="pt"
        )
        
        inputs = {k: v.to(self.encoder.device) for k, v in inputs.items()}
        outputs = self.encoder(**inputs)
        
        # 获取token表示并投影
        query_vectors = self.projection(outputs.last_hidden_state)
        
        # Mask padding
        mask = inputs["attention_mask"].unsqueeze(-1).float()
        query_vectors = query_vectors * mask
        
        return query_vectors  # (1, seq_len, embedding_dim)
    
    def encode_document(self, document: str) -> torch.Tensor:
        """
        编码文档
        
        返回所有token的向量表示
        """
        inputs = self.tokenizer(
            document,
            padding=True,
            truncation=True,
            max_length=512,  # ColBERT文档可以较长
            return_tensors="pt"
        )
        
        inputs = {k: v.to(self.encoder.device) for k, v in inputs.items()}
        outputs = self.encoder(**inputs)
        
        # 获取token表示并投影
        doc_vectors = self.projection(outputs.last_hidden_state)
        
        return doc_vectors  # (1, seq_len, embedding_dim)
    
    def late_interaction(self, query_vectors: torch.Tensor,
                        doc_vectors: torch.Tensor) -> torch.Tensor:
        """
        晚交互计算
        
        核心操作:MaxSim
        对每个查询token,取它与文档各token相似度的最大值,再对所有查询token求和
        """
        # 归一化
        query_vectors = torch.nn.functional.normalize(
            query_vectors, p=2, dim=-1
        )
        doc_vectors = torch.nn.functional.normalize(
            doc_vectors, p=2, dim=-1
        )
        
        # 计算所有查询token与所有文档token的相似度
        # (batch, query_len, doc_len)
        similarity_matrix = torch.matmul(
            query_vectors,
            doc_vectors.transpose(-2, -1)
        )
        
        # 对每个查询token,取最大值
        # (batch, query_len)
        max_similarity = torch.max(similarity_matrix, dim=-1)[0]
        
        # 求和作为最终分数
        scores = torch.sum(max_similarity, dim=-1)
        
        return scores  # (batch,)
 
class ColBERTIndexer:
    """ColBERT文档索引器"""
    
    def __init__(self, model: ColBERTModel, 
                 index_path: str = "./colbert_index"):
        self.model = model
        self.index_path = index_path
        self.document_vectors = []
        self.doc_ids = []
    
    def index_document(self, doc_id: str, document: str):
        """索引单个文档"""
        self.model.eval()
        
        with torch.no_grad():
            doc_vectors = self.model.encode_document(document)
        
        # 保存向量
        self.document_vectors.append(doc_vectors.cpu())
        self.doc_ids.append(doc_id)
    
    def batch_index(self, documents: List[Tuple[str, str]]):
        """
        批量索引文档
        
        documents: [(doc_id, content), ...]
        """
        self.model.eval()
        
        for doc_id, content in documents:
            self.index_document(doc_id, content)
    
    def save_index(self):
        """保存索引到磁盘"""
        import os
        os.makedirs(self.index_path, exist_ok=True)
        
        # 各文档的token数不同,先padding到相同长度再拼接
        # (补零向量归一化后相似度为0,对MaxSim的影响很小)
        max_len = max(v.size(1) for v in self.document_vectors)
        padded = [
            torch.nn.functional.pad(v, (0, 0, 0, max_len - v.size(1)))
            for v in self.document_vectors
        ]
        all_vectors = torch.cat(padded, dim=0)
        torch.save({
            'vectors': all_vectors,
            'doc_ids': self.doc_ids
        }, os.path.join(self.index_path, "doc_index.pt"))
    
    def load_index(self):
        """从磁盘加载索引"""
        import os
        checkpoint = torch.load(
            os.path.join(self.index_path, "doc_index.pt")
        )
        self.document_vectors = checkpoint['vectors']
        self.doc_ids = checkpoint['doc_ids']
 
class ColBERTRetriever:
    """ColBERT检索器"""
    
    def __init__(self, model: ColBERTModel,
                 indexer: ColBERTIndexer):
        self.model = model
        self.indexer = indexer
        self.device = model.encoder.device
    
    def retrieve(self, query: str, top_k: int = 10) -> List[Tuple[str, float]]:
        """
        检索相关文档
        
        Returns:
            [(doc_id, score), ...]
        """
        self.model.eval()
        
        # 编码查询
        with torch.no_grad():
            query_vectors = self.model.encode_query(query)
        
        # 获取文档向量(load_index后为张量;未落盘时仍是列表)
        doc_vectors = self.indexer.document_vectors
        if isinstance(doc_vectors, list):
            # 假设各文档向量长度一致;长度不一时需先按save_index中的方式padding
            doc_vectors = torch.cat(doc_vectors, dim=0)
        doc_vectors = doc_vectors.to(self.device)
        
        # 计算晚交互分数
        batch_size = 32
        all_scores = []
        
        for i in range(0, len(doc_vectors), batch_size):
            batch_docs = doc_vectors[i:i+batch_size]
            
            with torch.no_grad():
                scores = self.model.late_interaction(query_vectors, batch_docs)
            
            all_scores.extend(scores.cpu().tolist())
        
        # 排序
        doc_scores = list(zip(self.indexer.doc_ids, all_scores))
        doc_scores.sort(key=lambda x: x[1], reverse=True)
        
        return doc_scores[:top_k]

4.3 ColBERTv2改进

ColBERTv2在原始ColBERT基础上进行了多项改进,包括更强的蒸馏训练、更紧凑的向量表示和更高效的检索算法。

class ColBERTv2Model(ColBERTModel):
    """
    ColBERTv2改进版
    
    改进点:
    1. 使用蒸馏学习提升质量
    2. 向量量化减少存储
    3. 改进的评分机制
    """
    
    def __init__(self, model_name: str, 
                 vector_bits: int = 128):
        super().__init__(model_name)
        self.vector_bits = vector_bits
        # 使用位表示;注意父类projection的输出维度固定为128,vector_bits需与之一致
        self.embedding_dim = vector_bits
    
    def quantize_vector(self, vector: torch.Tensor) -> torch.Tensor:
        """
        将向量量化到指定位数
        
        使用PQ (Product Quantization) 的简化版本
        """
        # 简化为二值量化
        # 实际ColBERTv2使用更复杂的量化方法
        return (vector > 0).float()
    
    def encode_document(self, document: str) -> torch.Tensor:
        """编码文档(带量化)"""
        doc_vectors = super().encode_document(document)
        quantized = self.quantize_vector(doc_vectors)
        return quantized
    
    def late_interaction(self, query_vectors: torch.Tensor,
                        doc_vectors: torch.Tensor) -> torch.Tensor:
        """
        晚交互计算(针对量化向量优化)
        
        使用汉明距离代替余弦相似度
        """
        # 对二值向量,汉明距离 = (a != b).sum()
        # 一致位数 = 向量维度 - 汉明距离,可直接作为相似度
        
        # 展开以便逐对比较:query (B, qlen, 1, dim) vs doc (B, 1, dlen, dim)
        query_expanded = query_vectors.unsqueeze(2)
        doc_expanded = doc_vectors.unsqueeze(1)
        
        # 计算汉明距离 (B, qlen, dlen)
        hamming_dist = (query_expanded != doc_expanded).float().sum(dim=-1)
        
        # 转换为相似度分数
        vector_dim = query_vectors.size(-1)
        similarity = vector_dim - hamming_dist
        
        # Max-Sum操作
        max_sim = torch.max(similarity, dim=-1)[0]
        scores = torch.sum(max_sim, dim=-1)
        
        return scores
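
以存储开销粗略估算量化的收益:一个128维float32的token向量占 $128 \times 4 = 512$ 字节,二值化后只需 $128 / 8 = 16$ 字节,约32倍压缩。ColBERTv2实际采用的基于质心的残差压缩介于两者之间,但节省的量级是类似的。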

五、重排策略与集成

5.1 多阶段重排架构

在实际系统中,重排通常采用多阶段级联架构:向量检索产生候选集 → 轻量级重排筛选 → 重量级重排精细排序。

class MultiStageReranker:
    """多阶段重排系统"""
    
    def __init__(self, 
                 vector_retriever,  # Bi-Encoder-based
                 light_reranker,     # MiniLM-based Cross-Encoder
                 heavy_reranker):    # BGE-large Cross-Encoder
        self.vector_retriever = vector_retriever
        self.light_reranker = light_reranker
        self.heavy_reranker = heavy_reranker
        
        # 阶段配置
        self.stage1_top_k = 100   # 向量检索返回数
        self.stage2_top_k = 20    # 轻量重排返回数
        self.stage3_top_k = 10    # 重量重排返回数
    
    def rerank(self, query: str) -> List[Dict]:
        """
        执行多阶段重排
        
        Returns:
            [{"doc_id": str, "content": str, "score": float}, ...]
        """
        # 阶段1: 向量检索
        stage1_results = self.vector_retriever.search(
            query, top_k=self.stage1_top_k
        )
        
        if not stage1_results:
            return []
        
        # 阶段2: 轻量级重排
        stage2_docs = [r["content"] for r in stage1_results]
        stage2_scores = self.light_reranker.score_batch(
            [query] * len(stage2_docs),
            stage2_docs
        )
        
        # 排序并截断
        for i, r in enumerate(stage1_results):
            r["light_score"] = stage2_scores[i]
        
        stage2_results = sorted(
            stage1_results,
            key=lambda x: x["light_score"],
            reverse=True
        )[:self.stage2_top_k]
        
        # 阶段3: 重量级重排
        stage3_docs = [r["content"] for r in stage2_results]
        stage3_scores = self.heavy_reranker.score_batch(
            [query] * len(stage3_docs),
            stage3_docs
        )
        
        # 最终排序
        for i, r in enumerate(stage2_results):
            r["heavy_score"] = stage3_scores[i]
            r["final_score"] = stage3_scores[i]
        
        final_results = sorted(
            stage2_results,
            key=lambda x: x["final_score"],
            reverse=True
        )[:self.stage3_top_k]
        
        return final_results
 
class HybridReranker:
    """
    混合重排器
    
    结合多种重排信号:向量相似度、关键词匹配、文档质量等
    """
    
    def __init__(self, weights: Optional[Dict[str, float]] = None):
        # 默认权重
        self.weights = weights or {
            'vector': 0.3,
            'keyword': 0.3,
            'reranker': 0.4
        }
    
    def combine_scores(self, 
                      vector_score: float,
                      keyword_score: float,
                      reranker_score: float) -> float:
        """综合多维度分数"""
        combined = (
            self.weights['vector'] * vector_score +
            self.weights['keyword'] * keyword_score +
            self.weights['reranker'] * reranker_score
        )
        return combined
    
    def learn_weights(self, training_data: List[Dict], 
                     relevance_labels: List[float]):
        """
        从训练数据中学习最优权重
        
        使用回归方法优化权重
        """
        import numpy as np
        from sklearn.linear_model import Ridge
        
        # 准备特征
        X = []
        for item in training_data:
            X.append([
                item.get('vector_score', 0),
                item.get('keyword_score', 0),
                item.get('reranker_score', 0)
            ])
        
        # 回归学习
        model = Ridge(alpha=1.0)
        model.fit(X, relevance_labels)
        
        # 归一化权重
        coefs = model.coef_
        weights = np.abs(coefs) / np.abs(coefs).sum()
        
        self.weights = {
            'vector': float(weights[0]),
            'keyword': float(weights[1]),
            'reranker': float(weights[2])
        }
        
        return self.weights
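
一个简单的使用示意(分数为虚构示例):

reranker = HybridReranker()
final = reranker.combine_scores(
    vector_score=0.82,
    keyword_score=0.60,
    reranker_score=0.91
)
# 0.3*0.82 + 0.3*0.60 + 0.4*0.91 = 0.79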

5.2 训练数据构建

高质量的训练数据是重排模型性能的关键。以下方法可用于构建训练数据:

class TrainingDataBuilder:
    """重排模型训练数据构建"""
    
    @staticmethod
    def build_from_click_data(click_logs: List[Dict]) -> List[Tuple[str, str, int]]:
        """
        从点击日志构建训练数据
        
        使用点击数据作为隐式相关性标签
        """
        training_data = []
        
        for log in click_logs:
            query = log["query"]
            doc_id = log["doc_id"]
            doc_content = log["content"]
            clicks = log.get("clicks", 0)
            impressions = log.get("impressions", 1)
            
            # 使用CTR作为隐式标签
            ctr = clicks / impressions if impressions > 0 else 0
            
            # 二值化标签(可调整阈值)
            label = 1 if ctr > 0.1 else 0
            
            training_data.append((query, doc_content, label))
        
        return training_data
    
    @staticmethod
    def build_from_explicit_labels(annotations: List[Dict]) -> List[Tuple[str, str, int]]:
        """
        从人工标注构建训练数据
        
        使用显式相关性标签
        """
        training_data = []
        
        for ann in annotations:
            query = ann["query"]
            doc_id = ann["doc_id"]
            doc_content = ann["content"]
            relevance = ann["relevance"]  # 0, 1, 2, 3, 4 等级别
            
            training_data.append((query, doc_content, relevance))
        
        return training_data
    
    @staticmethod
    def mine_hard_negatives(positive_docs: List[str],
                           all_docs: List[str],
                           bi_encoder) -> List[List[str]]:
        """
        挖掘难负例
        
        难负例:与正样本相似但实际不相关的文档
        """
        hard_negatives = []
        
        for pos_doc in positive_docs:
            # 编码正样本
            pos_emb = bi_encoder.encode_document(pos_doc)
            
            # 编码所有候选文档
            candidates = [d for d in all_docs if d != pos_doc]
            candidate_embs = torch.cat([
                bi_encoder.encode_document(c).cpu() 
                for c in candidates
            ], dim=0)
            
            # 计算相似度(将pos_emb移到CPU,与candidate_embs保持同一设备)
            pos_emb_norm = torch.nn.functional.normalize(pos_emb.cpu(), p=2, dim=1)
            candidate_embs_norm = torch.nn.functional.normalize(candidate_embs, p=2, dim=1)
            
            similarities = torch.sum(pos_emb_norm * candidate_embs_norm, dim=-1)
            
            # 选择相似度适中的作为难负例(避免选择太相似或太不同的)
            threshold_high = 0.9
            threshold_low = 0.5
            
            hard_neg = []
            for i, sim in enumerate(similarities):
                if threshold_low < sim < threshold_high:
                    hard_neg.append(candidates[i])
                    if len(hard_neg) >= 5:
                        break
            
            hard_negatives.append(hard_neg)
        
        return hard_negatives

六、实战部署指南

6.1 模型选择建议

场景 | 推荐模型 | 说明
高频检索 | Bi-Encoder | 支持预计算,延迟最低
精确重排 | Cross-Encoder(BGE-reranker) | 精度最高,适合小规模候选集
平衡方案 | ColBERT | 质量和效率的折中
实时交互 | MiniLM Cross-Encoder | 轻量级,适合延迟敏感场景

6.2 部署配置示例

# 模型加载配置
RERANKER_CONFIGS = {
    # 轻量级(适合初次重排)
    "light": {
        "model_name": "cross-encoder/ms-marco-MiniLM-L-6-v2",
        "max_length": 512,
        "batch_size": 32
    },
    
    # 中量级(适合二次重排)
    "medium": {
        "model_name": "BAAI/bge-reranker-base",
        "max_length": 512,
        "batch_size": 16
    },
    
    # 重量级(适合最终重排)
    "heavy": {
        "model_name": "BAAI/bge-reranker-large",
        "max_length": 512,
        "batch_size": 4
    }
}
 
def load_reranker(config_name: str = "medium"):
    """加载重排模型"""
    from sentence_transformers import CrossEncoder
    
    config = RERANKER_CONFIGS[config_name]
    
    model = CrossEncoder(
        config["model_name"],
        max_length=config["max_length"],
        device="cuda"  # 或 "cpu"
    )
    
    return model
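
加载后的调用方式如下(candidate_docs、query 为示例变量;CrossEncoder.predict 接受 [查询, 文档] 对组成的列表):

reranker = load_reranker("medium")
pairs = [[query, doc] for doc in candidate_docs]
scores = reranker.predict(pairs)
ranked = sorted(zip(candidate_docs, scores), key=lambda x: x[1], reverse=True)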

更新日志

  • 2026-04-18: 初始版本完成
  • 包含Cross-Encoder、Bi-Encoder、ColBERT的详细解析
  • 提供完整的训练和部署代码