重排技术深度指南
摘要
重排技术是现代RAG系统中的关键环节,通过对初检结果进行精细化排序,显著提升检索质量。本指南深入解析Cross-Encoder、Bi-Encoder和ColBERT三种主流重排架构的技术原理、数学推导和工程实现,为构建高性能检索系统提供系统性指导。
关键词速查表
| 关键词 | 说明 |
|---|---|
| Cross-Encoder | 交叉编码器,将查询和文档同时输入产生联合表示 |
| Bi-Encoder | 双编码器,分别编码查询和文档再计算相似度 |
| ColBERT | 晚交互模型,token级编码后通过MaxSim计算相关性,兼顾效率与效果 |
| Late Interaction | 晚交互,在编码后进行交互以平衡效率和效果 |
| Pointwise Ranking | 点级排序,独立评估每个文档的相关性 |
| Pairwise Ranking | 对级排序,比较文档对的相对顺序 |
| Listwise Ranking | 列表级排序,考虑整个文档列表的排序 |
| 注意力机制 | Self-Attention和Cross-Attention的统称 |
| 蒸馏学习 | Knowledge Distillation,从大模型迁移知识 |
| 微调策略 | Fine-tuning策略,包括全量微调和PEFT |
一、重排技术的定位与价值
1.1 检索系统架构演进
现代检索系统通常采用两阶段甚至多阶段架构:粗排阶段使用轻量级模型快速筛选候选文档,精排阶段使用重排模型进行精细排序。这种分层设计在效果和效率之间取得平衡,使系统能够处理海量文档的同时保证排序质量。
两阶段架构的数学表达如下:设查询为 $q$,文档集合为 $D$,初排模型 $f_1$ 从 $D$ 中产生候选集 $C$($|C| \ll |D|$),重排模型 $f_2$ 对 $C$ 进行最终排序。理想情况下,$f_2$ 比 $f_1$ 更精确,但计算成本更高,因此只在较小的候选集上运行。
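下面用一段极简的 Python 代码示意两阶段的调用关系(其中 f1、f2 为假设的打分函数,corpus 为假设的文档列表,仅用于说明流程,不涉及具体模型):

# 两阶段检索流程示意:f1 为粗排打分函数,f2 为重排打分函数(均为假设对象)
def two_stage_search(query, corpus, f1, f2, k1=100, k2=10):
    # 阶段一:粗排模型在全量文档上打分,保留前 k1 个候选
    candidates = sorted(corpus, key=lambda d: f1(query, d), reverse=True)[:k1]
    # 阶段二:重排模型只在候选集上打分,输出最终前 k2 个结果
    return sorted(candidates, key=lambda d: f2(query, d), reverse=True)[:k2]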
1.2 重排的核心价值
重排技术解决了向量检索的固有局限。纯向量检索依赖全局语义匹配,可能遗漏精确的关键词匹配或忽略查询中的特定约束。重排模型通过更深层的语义理解,能够捕捉查询意图的细微差别,识别文档中的关键段落,并在全局视角下进行最优排序。
例如,当用户查询“Python如何处理并发”时,向量检索可能返回涵盖“Python异步编程”和“Python多线程”的文档。重排模型能够进一步分析这些文档中与“并发”相关的具体内容权重,从而给出更精准的排序。
1.3 重排技术分类
根据交互方式的不同,重排技术可分为三类:Cross-Encoder提供最完整的交互但计算成本最高;Bi-Encoder计算效率高但交互能力有限;ColBERT等晚交互模型则在两者之间取得平衡。理解这些技术的设计权衡对于正确选择和部署至关重要。
二、Bi-Encoder双编码器详解
2.1 架构原理
Bi-Encoder采用双塔结构,分别独立编码查询和文档。这种设计允许预先计算文档编码,实现高效的批量检索,但代价是查询和文档之间缺乏早期交互,可能丢失细粒度的语义关联。
import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer
from typing import List, Tuple, Optional, Dict
class BiEncoderModel(nn.Module):
"""
Bi-Encoder双编码器模型
查询和文档使用共享的编码器参数,但独立编码
"""
def __init__(self, model_name: str, embedding_dim: int = 768,
pooling_strategy: str = "cls"):
super().__init__()
self.encoder = AutoModel.from_pretrained(model_name)
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.embedding_dim = embedding_dim
self.pooling_strategy = pooling_strategy
# 投影层(可选)
self.projection = nn.Linear(
self.encoder.config.hidden_size,
embedding_dim
) if embedding_dim != self.encoder.config.hidden_size else nn.Identity()
def encode_query(self, query: str) -> torch.Tensor:
"""编码查询"""
inputs = self.tokenizer(
query,
padding=True,
truncation=True,
max_length=128,
return_tensors="pt"
)
inputs = {k: v.to(self.encoder.device) for k, v in inputs.items()}
outputs = self.encoder(**inputs)
return self._pooling(outputs.last_hidden_state, inputs["attention_mask"])
def encode_document(self, document: str, title: Optional[str] = None) -> torch.Tensor:
"""
编码文档
可选择性地包含标题信息
"""
if title:
text = f"{title}: {document}"
else:
text = document
inputs = self.tokenizer(
text,
padding=True,
truncation=True,
max_length=512,
return_tensors="pt"
)
inputs = {k: v.to(self.encoder.device) for k, v in inputs.items()}
outputs = self.encoder(**inputs)
return self._pooling(outputs.last_hidden_state, inputs["attention_mask"])
def _pooling(self, hidden_states: torch.Tensor,
attention_mask: torch.Tensor) -> torch.Tensor:
"""
池化策略
支持: cls, mean, max, weighted
"""
if self.pooling_strategy == "cls":
# 使用[CLS] token
return hidden_states[:, 0, :]
elif self.pooling_strategy == "mean":
# Mean Pooling
mask_expanded = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
sum_embeddings = torch.sum(hidden_states * mask_expanded, dim=1)
sum_mask = mask_expanded.sum(dim=1)
return sum_embeddings / sum_mask.clamp(min=1e-9)
elif self.pooling_strategy == "max":
# Max Pooling
mask_expanded = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
hidden_states = hidden_states.masked_fill(mask_expanded == 0, -1e9)
return torch.max(hidden_states, dim=1)[0]
elif self.pooling_strategy == "weighted":
# Attention-weighted Pooling(先屏蔽padding位置,再沿序列维做softmax)
mask_expanded = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
attention_weights = torch.softmax(hidden_states.masked_fill(mask_expanded == 0, -1e9), dim=1)
return torch.sum(hidden_states * attention_weights, dim=1)
raise ValueError(f"Unknown pooling strategy: {self.pooling_strategy}")
def compute_similarity(self, query_emb: torch.Tensor,
doc_emb: torch.Tensor) -> torch.Tensor:
"""计算查询-文档相似度"""
# 归一化后计算点积(等价于余弦相似度)
query_emb = torch.nn.functional.normalize(query_emb, p=2, dim=1)
doc_emb = torch.nn.functional.normalize(doc_emb, p=2, dim=1)
return torch.sum(query_emb * doc_emb, dim=-1)
2.2 对比学习训练
Bi-Encoder的训练通常采用对比学习方法,通过最大化正样本对相似度、最小化负样本对相似度来学习有效的表示空间。
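这一训练目标通常写成 InfoNCE 损失(下方 ContrastiveLoss 实现的即是该形式),其中 $s(\cdot,\cdot)$ 为归一化向量的点积相似度,$\tau$ 为温度系数,$d^+$ 为正样本,$d_i^-$ 为负样本:

$$\mathcal{L}_{\text{InfoNCE}} = -\log \frac{\exp\big(s(q, d^+)/\tau\big)}{\exp\big(s(q, d^+)/\tau\big) + \sum_{i=1}^{N}\exp\big(s(q, d_i^-)/\tau\big)}$$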
class ContrastiveLoss(nn.Module):
"""
对比损失函数
支持难负例挖掘和温度系数调节
"""
def __init__(self, temperature: float = 0.07,
use_hard_negatives: bool = True):
super().__init__()
self.temperature = temperature
self.use_hard_negatives = use_hard_negatives
def forward(self, query_emb: torch.Tensor,
positive_emb: torch.Tensor,
negative_embs: torch.Tensor) -> torch.Tensor:
"""
计算对比损失
Args:
query_emb: 查询嵌入 (batch_size, embedding_dim)
positive_emb: 正样本嵌入 (batch_size, embedding_dim)
negative_embs: 负样本嵌入 (batch_size, num_negatives, embedding_dim)
"""
# 归一化
query_emb = torch.nn.functional.normalize(query_emb, p=2, dim=1)
positive_emb = torch.nn.functional.normalize(positive_emb, p=2, dim=1)
negative_embs = torch.nn.functional.normalize(negative_embs, p=2, dim=-1)
# 计算正样本相似度
positive_sim = torch.sum(query_emb * positive_emb, dim=-1) / self.temperature
# 计算负样本相似度
batch_size = query_emb.size(0)
num_negatives = negative_embs.size(1)
# (batch_size, num_negatives)
negative_sim = torch.bmm(
negative_embs,
query_emb.unsqueeze(-1)
).squeeze(-1) / self.temperature
# InfoNCE损失
logits = torch.cat([positive_sim.unsqueeze(-1), negative_sim], dim=-1)
labels = torch.zeros(batch_size, dtype=torch.long, device=query_emb.device)
loss = nn.CrossEntropyLoss()(logits, labels)
return loss
class BiEncoderTrainer:
"""Bi-Encoder训练器"""
def __init__(self, model: BiEncoderModel,
learning_rate: float = 2e-5,
batch_size: int = 32,
num_negatives: int = 4):
self.model = model
self.optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
self.batch_size = batch_size
self.num_negatives = num_negatives
self.loss_fn = ContrastiveLoss(temperature=0.07)
def train_step(self, queries: List[str],
positives: List[str],
negatives: List[List[str]]) -> float:
"""单步训练"""
self.model.train()
# 编码
query_embs = []
positive_embs = []
negative_embs_list = []
for i in range(0, len(queries), self.batch_size):
batch_queries = queries[i:i+self.batch_size]
batch_positives = positives[i:i+self.batch_size]
batch_negatives = negatives[i:i+self.batch_size]
# 编码查询
q_embs = torch.cat([
self.model.encode_query(q).cpu()
for q in batch_queries
], dim=0)
query_embs.append(q_embs)
# 编码正样本
p_embs = torch.cat([
self.model.encode_document(p).cpu()
for p in batch_positives
], dim=0)
positive_embs.append(p_embs)
# 编码负样本
neg_embs = []
for neg_list in batch_negatives:
neg_batch = [
self.model.encode_document(n).cpu()
for n in neg_list[:self.num_negatives]
]
# Padding到统一长度
while len(neg_batch) < self.num_negatives:
neg_batch.append(torch.zeros_like(neg_batch[0]))
neg_embs.append(torch.stack(neg_batch))
negative_embs_list.append(torch.stack(neg_embs))
query_embs = torch.cat(query_embs, dim=0).to(self.model.encoder.device)
positive_embs = torch.cat(positive_embs, dim=0).to(self.model.encoder.device)
negative_embs = torch.cat(negative_embs_list, dim=0).to(self.model.encoder.device)
# 计算损失
loss = self.loss_fn(query_embs, positive_embs, negative_embs)
# 反向传播
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
return loss.item()
2.3 应用场景与限制
Bi-Encoder的优势在于预计算能力:文档编码可以离线完成并缓存,查询时只需编码查询向量即可计算相似度。这使其非常适合大规模文档库的检索场景。典型应用包括语义搜索、相似文档发现和聚类分析。
然而,Bi-Encoder的独立编码特性也带来了局限。由于查询和文档在编码阶段不交互,模型无法捕捉查询中特定词与文档中对应词的精确匹配关系。例如,“2024年Q3财报”这样的精确查询在Bi-Encoder中可能被当作普通语义匹配处理,而Cross-Encoder则能精确关联“2024”、“Q3”和“财报”这些关键词。
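下面给出基于上文 BiEncoderModel 的最小使用示意(模型名与文档内容均为示例假设):文档向量可以离线编码并缓存,在线阶段只需编码查询再计算相似度。

# BiEncoderModel 最小使用示意(模型名与数据均为示例假设)
model = BiEncoderModel("bert-base-chinese", pooling_strategy="mean")
model.eval()

docs = ["Python asyncio 异步编程指南", "Python 多线程与 GIL 详解"]
with torch.no_grad():
    # 离线:逐个编码文档并缓存(生产环境可落盘或写入向量库)
    doc_embs = torch.cat([model.encode_document(d) for d in docs], dim=0)
    # 在线:编码查询,归一化后计算余弦相似度
    query_emb = model.encode_query("Python如何处理并发")
    scores = model.compute_similarity(query_emb.expand_as(doc_embs), doc_embs)
print(scores)  # 分数越高表示语义越相关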
三、Cross-Encoder交叉编码器详解
3.1 架构原理
Cross-Encoder的核心创新是将查询和文档拼接后一同输入编码器,通过Self-Attention和Cross-Attention机制实现深层的交互学习。这种设计虽然失去了预计算能力,但获得了最完整的交互能力。
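其打分过程可以概括为:对拼接序列编码后取 [CLS] 位置的表示,再经分类头映射为标量相关性分数,即 $s(q, d) = \mathrm{MLP}\big(\mathrm{Encoder}([q;\,d])_{[\mathrm{CLS}]}\big)$,下方 CrossEncoderModel 即按此实现。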
class CrossEncoderModel(nn.Module):
"""
Cross-Encoder交叉编码器模型
查询和文档拼接后联合编码
"""
def __init__(self, model_name: str, num_labels: int = 1):
super().__init__()
self.encoder = AutoModel.from_pretrained(model_name)
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.num_labels = num_labels
# 分类头
self.classifier = nn.Sequential(
nn.Dropout(0.1),
nn.Linear(self.encoder.config.hidden_size, 256),
nn.GELU(),
nn.Dropout(0.1),
nn.Linear(256, num_labels)
)
def forward(self, query: str, document: str) -> torch.Tensor:
"""
前向传播
返回相关性分数(回归任务)或概率分布(分类任务)
"""
# 拼接查询和文档
inputs = self.tokenizer(
query,
document,
padding=True,
truncation=True,
max_length=512,
return_tensors="pt"
)
inputs = {k: v.to(self.encoder.device) for k, v in inputs.items()}
outputs = self.encoder(**inputs)
# 使用[CLS] token的表示
cls_output = outputs.last_hidden_state[:, 0, :]
# 分类
logits = self.classifier(cls_output)
return logits
def score(self, queries: List[str],
documents: List[str]) -> List[float]:
"""
批量计算相关性分数
注意:此方法效率较低,适合小规模重排场景
"""
self.eval()
scores = []
for query, doc in zip(queries, documents):
with torch.no_grad():
score = self.forward(query, doc)
scores.append(score.item())
return scores
class CrossEncoderWithInteraction(nn.Module):
"""
带显式交互层的Cross-Encoder
在标准Cross-Encoder基础上增加双向注意力交互
"""
def __init__(self, model_name: str):
super().__init__()
self.query_encoder = AutoModel.from_pretrained(model_name)
self.doc_encoder = AutoModel.from_pretrained(model_name)
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
hidden_size = self.query_encoder.config.hidden_size
# Cross-Attention层
self.cross_attention = nn.MultiheadAttention(
embed_dim=hidden_size,
num_heads=8,
dropout=0.1,
batch_first=True
)
# 输出层
self.output_layer = nn.Sequential(
nn.Linear(hidden_size * 2, hidden_size),
nn.GELU(),
nn.Dropout(0.1),
nn.Linear(hidden_size, 1)
)
def forward(self, query: str, document: str) -> torch.Tensor:
"""带显式交互的前向传播"""
# 独立编码
q_inputs = self.tokenizer(
query,
return_tensors="pt",
padding=True,
truncation=True,
max_length=128
)
q_inputs = {k: v.to(self.query_encoder.device) for k, v in q_inputs.items()}
q_outputs = self.query_encoder(**q_inputs)
q_hidden = q_outputs.last_hidden_state # (1, seq_len_q, hidden)
d_inputs = self.tokenizer(
document,
return_tensors="pt",
padding=True,
truncation=True,
max_length=384
)
d_inputs = {k: v.to(self.doc_encoder.device) for k, v in d_inputs.items()}
d_outputs = self.doc_encoder(**d_inputs)
d_hidden = d_outputs.last_hidden_state # (1, seq_len_d, hidden)
# Cross-Attention: Query关注Document
cross_attn_output, _ = self.cross_attention(
q_hidden, d_hidden, d_hidden
)
# 融合查询和交互表示
combined = torch.cat([
q_hidden[:, 0, :], # CLS token
cross_attn_output[:, 0, :] # 交互后的CLS
], dim=-1)
return self.output_layer(combined)
3.2 排序学习训练范式
Cross-Encoder的训练采用排序学习(Learning to Rank)范式,主要有点级、对级和列表级三种方法。
from typing import List, Tuple
class RankingLossFunctions:
"""排序损失函数集合"""
@staticmethod
def pairwise_hinge_loss(scores: torch.Tensor,
labels: torch.Tensor,
margin: float = 1.0) -> torch.Tensor:
"""
Pairwise Hinge Loss (RankSVM风格)
强制正确文档的分数比错误文档高至少margin
"""
batch_size = scores.size(0)
# 构建文档对
pairs = []
for i in range(batch_size):
for j in range(i + 1, batch_size):
if labels[i] > labels[j]: # i应该排在j前面
pairs.append((i, j))
elif labels[j] > labels[i]: # j应该排在i前面
pairs.append((j, i))
if not pairs:
return torch.tensor(0.0, device=scores.device)
# 计算hinge loss
loss = 0
for pos_idx, neg_idx in pairs:
diff = scores[neg_idx] - scores[pos_idx] + margin
if diff > 0:
loss += diff
return loss / len(pairs)
@staticmethod
def listwise_softmax_loss(scores: torch.Tensor,
labels: torch.Tensor) -> torch.Tensor:
"""
Listwise Softmax Loss (ListNet风格)
直接优化整个列表的排序概率分布
"""
# 将标签转换为概率分布
labels_exp = torch.exp(labels.float())
labels_prob = labels_exp / labels_exp.sum()
# 将分数转换为概率分布
scores_exp = torch.exp(scores.float())
scores_prob = scores_exp / scores_exp.sum()
# 交叉熵损失
loss = -torch.sum(labels_prob * torch.log(scores_prob + 1e-10))
return loss
@staticmethod
def lambdarank_loss(scores: torch.Tensor,
labels: torch.Tensor,
k: int = 10) -> torch.Tensor:
"""
LambdaRank损失(简化版)
用 |ΔNDCG| 对 pairwise logistic 损失加权,使梯度能够回传到 scores
"""
device = scores.device
labels = labels.float()
# 当前分数下每个文档的排名位置(0-based)
sorted_indices = torch.argsort(scores, descending=True)
ranks = torch.empty_like(sorted_indices)
ranks[sorted_indices] = torch.arange(len(scores), device=device)
# IDCG:理想排序下的DCG,用于归一化
ideal_sorted = torch.sort(labels, descending=True)[0][:k]
gains = 2 ** ideal_sorted - 1
discounts = torch.log2(torch.arange(len(ideal_sorted), device=device).float() + 2)
idcg = torch.sum(gains / discounts) + 1e-10
loss = torch.tensor(0.0, device=device)
num_pairs = 0
for i in range(len(scores)):
for j in range(len(scores)):
if labels[i] > labels[j]:
# 交换 i、j 的排名对 NDCG 的影响幅度(lambda权重)
delta_ndcg = torch.abs(
(2 ** labels[i] - 2 ** labels[j]) *
(1.0 / torch.log2(ranks[i].float() + 2) -
1.0 / torch.log2(ranks[j].float() + 2))
) / idcg
# pairwise logistic 损失,按 |ΔNDCG| 加权
loss = loss + delta_ndcg * torch.log1p(torch.exp(-(scores[i] - scores[j])))
num_pairs += 1
if num_pairs == 0:
return torch.tensor(0.0, device=device)
return loss / num_pairs
class CrossEncoderTrainer:
"""Cross-Encoder训练器"""
def __init__(self, model: CrossEncoderModel,
loss_type: str = "pairwise",
**loss_kwargs):
self.model = model
self.loss_type = loss_type
self.loss_kwargs = loss_kwargs
self.optimizer = torch.optim.AdamW(
model.parameters(),
lr=2e-5,
weight_decay=0.01
)
self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
self.optimizer, T_max=1000
)
def train_step(self, query: str,
documents: List[str],
labels: List[float]) -> float:
"""
单步训练
Args:
query: 查询文本
documents: 文档列表
labels: 相关性标签列表
"""
self.model.train()
# 编码
scores = []
for doc in documents:
score = self.model(query, doc)
scores.append(score.squeeze())
scores = torch.stack(scores)
labels = torch.tensor(labels, device=scores.device).float()
# 计算损失
if self.loss_type == "pairwise":
loss = RankingLossFunctions.pairwise_hinge_loss(scores, labels)
elif self.loss_type == "listwise":
loss = RankingLossFunctions.listwise_softmax_loss(scores, labels)
elif self.loss_type == "lambdarank":
loss = RankingLossFunctions.lambdarank_loss(scores, labels)
else:
raise ValueError(f"Unknown loss type: {self.loss_type}")
# 反向传播
self.optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
self.optimizer.step()
self.scheduler.step()
return loss.item()
3.3 效率优化策略
Cross-Encoder的主要瓶颈在于无法预计算文档编码,每个查询-文档对都需要重新编码。以下策略可有效优化效率:
class CrossEncoderOptimizer:
"""Cross-Encoder效率优化"""
@staticmethod
def batch_processing(model: CrossEncoderModel,
queries: List[str],
documents: List[str],
batch_size: int = 8) -> List[float]:
"""
批量处理以提高GPU利用率
将多个查询-文档对组成batch,成对tokenize后一次前向计算打分
"""
model.eval()
scores = []
for i in range(0, len(queries), batch_size):
batch_queries = queries[i:i+batch_size]
batch_docs = documents[i:i+batch_size]
# 查询-文档成对tokenize,整批送入编码器
inputs = model.tokenizer(
batch_queries,
batch_docs,
padding=True,
truncation=True,
max_length=512,
return_tensors="pt"
)
inputs = {k: v.to(model.encoder.device) for k, v in inputs.items()}
with torch.no_grad():
outputs = model.encoder(**inputs)
logits = model.classifier(outputs.last_hidden_state[:, 0, :])
scores.extend(logits.squeeze(-1).cpu().tolist())
return scores
@staticmethod
def cache_document_encodings(model: CrossEncoderModel,
documents: List[str],
cache_dir: str = "./cache"):
"""
缓存文档编码(部分计算)
适用于只需要cross-attention层计算的场景
"""
import os
os.makedirs(cache_dir, exist_ok=True)
# 保存文档的tokenized表示
tokenizer = model.tokenizer
doc_ids = tokenizer(
documents,
padding=True,
truncation=True,
max_length=512,
return_tensors="pt"
)
# 保存到缓存
torch.save(doc_ids, os.path.join(cache_dir, "doc_encodings.pt"))
@staticmethod
def precompute_token_types(model: CrossEncoderModel,
documents: List[str]):
"""
预计算文档的token序列
仅在查询时执行cross-attention计算
"""
tokenizer = model.tokenizer
doc_encoding = tokenizer(
documents,
return_tensors="pt",
padding=True,
truncation=True,
max_length=512
)
# 使用模型的encoder预计算文档表示
with torch.no_grad():
doc_outputs = model.encoder(**{
k: v.to(model.encoder.device) for k, v in doc_encoding.items()
})
doc_hidden = doc_outputs.last_hidden_state
return doc_hidden
四、ColBERT晚交互模型
4.1 核心设计理念
ColBERT(Contextualized Late Interaction over BERT)由斯坦福大学提出,是一种创新的检索模型。它既保留了Bi-Encoder的预计算能力,又通过延迟交互机制实现了接近Cross-Encoder的检索质量。
ColBERT的核心思想是:先独立编码查询和文档的每个token,然后在编码后执行轻量级的交互计算。这种“晚交互”策略使模型能够精确匹配查询中的每个token与文档中的相关片段,同时保持计算效率。
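用公式表示,ColBERT 对查询 $q$ 与文档 $d$ 的相关性分数是每个查询 token 向量与文档 token 向量的最大相似度之和(MaxSim):

$$S(q, d) = \sum_{i=1}^{|E_q|} \max_{j=1,\dots,|E_d|} E_{q_i} \cdot E_{d_j}$$

其中 $E_q$、$E_d$ 分别为查询和文档编码并归一化后的 token 向量集合,下文 late_interaction 方法实现的就是这一计算。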
4.2 架构详解
class ColBERTModel(nn.Module):
"""
ColBERT晚交互检索模型
支持预计算的文档编码和高效的查询-文档交互
"""
def __init__(self, model_name: str,
similarity_metric: str = "cosine"):
super().__init__()
self.encoder = AutoModel.from_pretrained(model_name)
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.similarity_metric = similarity_metric
# ColBERT使用特殊的向量维度(通常较小以提高效率)
self.embedding_dim = 128
# 线性投影层
self.projection = nn.Linear(
self.encoder.config.hidden_size,
self.embedding_dim
)
def encode_query(self, query: str) -> torch.Tensor:
"""
编码查询
返回所有token的向量表示
"""
inputs = self.tokenizer(
query,
padding=True,
truncation=True,
max_length=32, # ColBERT查询通常较短
return_tensors="pt"
)
inputs = {k: v.to(self.encoder.device) for k, v in inputs.items()}
outputs = self.encoder(**inputs)
# 获取token表示并投影
query_vectors = self.projection(outputs.last_hidden_state)
# Mask padding
mask = inputs["attention_mask"].unsqueeze(-1).float()
query_vectors = query_vectors * mask
return query_vectors # (1, seq_len, embedding_dim)
def encode_document(self, document: str) -> torch.Tensor:
"""
编码文档
返回所有token的向量表示
"""
inputs = self.tokenizer(
document,
padding=True,
truncation=True,
max_length=512, # ColBERT文档可以较长
return_tensors="pt"
)
inputs = {k: v.to(self.encoder.device) for k, v in inputs.items()}
outputs = self.encoder(**inputs)
# 获取token表示并投影
doc_vectors = self.projection(outputs.last_hidden_state)
return doc_vectors # (1, seq_len, embedding_dim)
def late_interaction(self, query_vectors: torch.Tensor,
doc_vectors: torch.Tensor) -> torch.Tensor:
"""
晚交互计算
核心操作:Max-Sim
对每个查询token,找到文档中最相似的token,再将这些最大相似度求和作为总分
"""
# 归一化
query_vectors = torch.nn.functional.normalize(
query_vectors, p=2, dim=-1
)
doc_vectors = torch.nn.functional.normalize(
doc_vectors, p=2, dim=-1
)
# 计算所有查询token与所有文档token的相似度
# (batch, query_len, doc_len)
similarity_matrix = torch.matmul(
query_vectors,
doc_vectors.transpose(-2, -1)
)
# 对每个查询token,取最大值
# (batch, query_len)
max_similarity = torch.max(similarity_matrix, dim=-1)[0]
# 求和作为最终分数
scores = torch.sum(max_similarity, dim=-1)
return scores # (batch,)
class ColBERTIndexer:
"""ColBERT文档索引器"""
def __init__(self, model: ColBERTModel,
index_path: str = "./colbert_index"):
self.model = model
self.index_path = index_path
self.document_vectors = []
self.doc_ids = []
def index_document(self, doc_id: str, document: str):
"""索引单个文档"""
self.model.eval()
with torch.no_grad():
doc_vectors = self.model.encode_document(document)
# 保存向量
self.document_vectors.append(doc_vectors.cpu())
self.doc_ids.append(doc_id)
def batch_index(self, documents: List[Tuple[str, str]]):
"""
批量索引文档
documents: [(doc_id, content), ...]
"""
self.model.eval()
for doc_id, content in documents:
self.index_document(doc_id, content)
def save_index(self):
"""保存索引到磁盘"""
import os
os.makedirs(self.index_path, exist_ok=True)
# 保存文档向量:不同文档的token数不同,先在序列维度补零到相同长度再拼接
max_len = max(v.size(1) for v in self.document_vectors)
padded = [
torch.nn.functional.pad(v, (0, 0, 0, max_len - v.size(1)))
for v in self.document_vectors
]
all_vectors = torch.cat(padded, dim=0)
torch.save({
'vectors': all_vectors,
'doc_ids': self.doc_ids
}, os.path.join(self.index_path, "doc_index.pt"))
def load_index(self):
"""从磁盘加载索引"""
import os
checkpoint = torch.load(
os.path.join(self.index_path, "doc_index.pt")
)
self.document_vectors = checkpoint['vectors']
self.doc_ids = checkpoint['doc_ids']
class ColBERTRetriever:
"""ColBERT检索器"""
def __init__(self, model: ColBERTModel,
indexer: ColBERTIndexer):
self.model = model
self.indexer = indexer
self.device = model.encoder.device
def retrieve(self, query: str, top_k: int = 10) -> List[Tuple[str, float]]:
"""
检索相关文档
Returns:
[(doc_id, score), ...]
"""
self.model.eval()
# 编码查询
with torch.no_grad():
query_vectors = self.model.encode_query(query)
# 获取文档向量(假设已通过 save_index/load_index 整理为单个张量)
doc_vectors = self.indexer.document_vectors.to(self.device)
# 计算晚交互分数
batch_size = 32
all_scores = []
for i in range(0, len(doc_vectors), batch_size):
batch_docs = doc_vectors[i:i+batch_size]
with torch.no_grad():
scores = self.model.late_interaction(query_vectors, batch_docs)
all_scores.extend(scores.cpu().tolist())
# 排序
doc_scores = list(zip(self.indexer.doc_ids, all_scores))
doc_scores.sort(key=lambda x: x[1], reverse=True)
return doc_scores[:top_k]
4.3 ColBERTv2改进
ColBERTv2在原始ColBERT基础上进行了多项改进,包括更强的蒸馏训练、更紧凑的向量表示和更高效的检索算法。
class ColBERTv2Model(ColBERTModel):
"""
ColBERTv2改进版
改进点:
1. 使用蒸馏学习提升质量
2. 向量量化减少存储
3. 改进的评分机制
"""
def __init__(self, model_name: str,
vector_bits: int = 128):
super().__init__(model_name)
self.vector_bits = vector_bits
self.embedding_dim = vector_bits # 使用位表示
def quantize_vector(self, vector: torch.Tensor) -> torch.Tensor:
"""
将向量量化到指定位数
使用PQ (Product Quantization) 的简化版本
"""
# 简化为二值量化
# 实际ColBERTv2使用更复杂的量化方法
return (vector > 0).float()
def encode_document(self, document: str) -> torch.Tensor:
"""编码文档(带量化)"""
doc_vectors = super().encode_document(document)
quantized = self.quantize_vector(doc_vectors)
return quantized
def late_interaction(self, query_vectors: torch.Tensor,
doc_vectors: torch.Tensor) -> torch.Tensor:
"""
晚交互计算(针对量化向量优化)
使用汉明距离代替余弦相似度
"""
# 查询向量同样需要先二值化,才能与量化后的文档向量比较
query_vectors = self.quantize_vector(query_vectors)
# 对二值向量,a · b = n - (a != b).sum(),点积可由汉明距离换算
# 广播为 (batch, query_len, doc_len, dim) 后逐位比较
query_expanded = query_vectors.unsqueeze(2)   # (batch, query_len, 1, dim)
doc_expanded = doc_vectors.unsqueeze(1)       # (batch, 1, doc_len, dim)
# 计算汉明距离
hamming_dist = (query_expanded != doc_expanded).float().sum(dim=-1)
# 转换为相似度分数
vector_dim = query_vectors.size(-1)
similarity = vector_dim - hamming_dist
# Max-Sum操作
max_sim = torch.max(similarity, dim=-1)[0]
scores = torch.sum(max_sim, dim=-1)
return scores
五、重排策略与集成
5.1 多阶段重排架构
在实际系统中,重排通常采用多阶段级联架构:向量检索产生候选集 → 轻量级重排筛选 → 重量级重排精细排序。
class MultiStageReranker:
"""多阶段重排系统"""
def __init__(self,
vector_retriever, # Bi-Encoder-based
light_reranker, # MiniLM-based Cross-Encoder
heavy_reranker): # BGE-large Cross-Encoder
self.vector_retriever = vector_retriever
self.light_reranker = light_reranker
self.heavy_reranker = heavy_reranker
# 阶段配置
self.stage1_top_k = 100 # 向量检索返回数
self.stage2_top_k = 20 # 轻量重排返回数
self.stage3_top_k = 10 # 重量重排返回数
def rerank(self, query: str) -> List[Dict]:
"""
执行多阶段重排
Returns:
[{"doc_id": str, "content": str, "score": float}, ...]
"""
# 阶段1: 向量检索
stage1_results = self.vector_retriever.search(
query, top_k=self.stage1_top_k
)
if not stage1_results:
return []
# 阶段2: 轻量级重排
stage2_docs = [r["content"] for r in stage1_results]
stage2_scores = self.light_reranker.score_batch(
[query] * len(stage2_docs),
stage2_docs
)
# 排序并截断
for i, r in enumerate(stage1_results):
r["light_score"] = stage2_scores[i]
stage2_results = sorted(
stage1_results,
key=lambda x: x["light_score"],
reverse=True
)[:self.stage2_top_k]
# 阶段3: 重量级重排
stage3_docs = [r["content"] for r in stage2_results]
stage3_scores = self.heavy_reranker.score_batch(
[query] * len(stage3_docs),
stage3_docs
)
# 最终排序
for i, r in enumerate(stage2_results):
r["heavy_score"] = stage3_scores[i]
r["final_score"] = stage3_scores[i]
final_results = sorted(
stage2_results,
key=lambda x: x["final_score"],
reverse=True
)[:self.stage3_top_k]
return final_results
class HybridReranker:
"""
混合重排器
结合多种重排信号:向量相似度、关键词匹配、文档质量等
"""
def __init__(self, weights: Optional[Dict[str, float]] = None):
# 默认权重
self.weights = weights or {
'vector': 0.3,
'keyword': 0.3,
'reranker': 0.4
}
def combine_scores(self,
vector_score: float,
keyword_score: float,
reranker_score: float) -> float:
"""综合多维度分数"""
combined = (
self.weights['vector'] * vector_score +
self.weights['keyword'] * keyword_score +
self.weights['reranker'] * reranker_score
)
return combined
def learn_weights(self, training_data: List[Dict],
relevance_labels: List[float]):
"""
从训练数据中学习最优权重
使用回归方法优化权重
"""
from sklearn.linear_model import Ridge
import numpy as np
# 准备特征
X = []
for item in training_data:
X.append([
item.get('vector_score', 0),
item.get('keyword_score', 0),
item.get('reranker_score', 0)
])
# 回归学习
model = Ridge(alpha=1.0)
model.fit(X, relevance_labels)
# 归一化权重
coefs = model.coef_
weights = np.abs(coefs) / np.abs(coefs).sum()
self.weights = {
'vector': float(weights[0]),
'keyword': float(weights[1]),
'reranker': float(weights[2])
}
return self.weights
5.2 训练数据构建
高质量的训练数据是重排模型性能的关键。以下方法可用于构建训练数据:
class TrainingDataBuilder:
"""重排模型训练数据构建"""
@staticmethod
def build_from_click_data(click_logs: List[Dict]) -> List[Tuple[str, str, int]]:
"""
从点击日志构建训练数据
使用点击数据作为隐式相关性标签
"""
training_data = []
for log in click_logs:
query = log["query"]
doc_id = log["doc_id"]
doc_content = log["content"]
clicks = log.get("clicks", 0)
impressions = log.get("impressions", 1)
# 使用CTR作为隐式标签
ctr = clicks / impressions if impressions > 0 else 0
# 二值化标签(可调整阈值)
label = 1 if ctr > 0.1 else 0
training_data.append((query, doc_content, label))
return training_data
@staticmethod
def build_from_explicit_labels(annotations: List[Dict]) -> List[Tuple[str, str, int]]:
"""
从人工标注构建训练数据
使用显式相关性标签
"""
training_data = []
for ann in annotations:
query = ann["query"]
doc_id = ann["doc_id"]
doc_content = ann["content"]
relevance = ann["relevance"] # 0, 1, 2, 3, 4 等级别
training_data.append((query, doc_content, relevance))
return training_data
@staticmethod
def mine_hard_negatives(positive_docs: List[str],
all_docs: List[str],
bi_encoder) -> List[List[str]]:
"""
挖掘难负例
难负例:与正样本相似但实际不相关的文档
"""
hard_negatives = []
for pos_doc in positive_docs:
# 编码正样本
pos_emb = bi_encoder.encode_document(pos_doc)
# 编码所有候选文档
candidates = [d for d in all_docs if d != pos_doc]
candidate_embs = torch.cat([
bi_encoder.encode_document(c).cpu()
for c in candidates
], dim=0)
# 计算相似度
pos_emb_norm = torch.nn.functional.normalize(pos_emb, p=2, dim=1)
candidate_embs_norm = torch.nn.functional.normalize(candidate_embs, p=2, dim=1)
similarities = torch.sum(pos_emb_norm * candidate_embs_norm, dim=-1)
# 选择相似度适中的作为难负例(避免选择太相似或太不同的)
threshold_high = 0.9
threshold_low = 0.5
hard_neg = []
for i, sim in enumerate(similarities):
if threshold_low < sim < threshold_high:
hard_neg.append(candidates[i])
if len(hard_neg) >= 5:
break
hard_negatives.append(hard_neg)
return hard_negatives
六、实战部署指南
6.1 模型选择建议
| 场景 | 推荐模型 | 说明 |
|---|---|---|
| 高频检索 | Bi-Encoder | 支持预计算,延迟最低 |
| 精确重排 | Cross-Encoder (BGE-reranker) | 精度最高,适合小规模候选集 |
| 平衡方案 | ColBERT | 质量和效率的折中 |
| 实时交互 | MiniLM-Cross-Encoder | 轻量级,适合延迟敏感场景 |
6.2 部署配置示例
# 模型加载配置
RERANKER_CONFIGS = {
# 轻量级(适合初次重排)
"light": {
"model_name": "cross-encoder/ms-marco-MiniLM-L-6-v2",
"max_length": 512,
"batch_size": 32
},
# 中量级(适合二次重排)
"medium": {
"model_name": "BAAI/bge-reranker-base",
"max_length": 512,
"batch_size": 16
},
# 重量级(适合最终重排)
"heavy": {
"model_name": "BAAI/bge-reranker-large",
"max_length": 512,
"batch_size": 4
}
}
def load_reranker(config_name: str = "medium"):
"""加载重排模型"""
from sentence_transformers import CrossEncoder
config = RERANKER_CONFIGS[config_name]
model = CrossEncoder(
config["model_name"],
max_length=config["max_length"],
device="cuda" # 或 "cpu"
)
return model
七、相关主题链接
- 主流重排模型对比 - 各类重排模型的详细对比
- Embedding模型选择 - 编码模型选型指南
- 向量数据库对比 - 底层存储选型
- 混合检索技术 - 混合搜索策略
- 评估体系 - 排序质量评估方法
- GraphRAG深度指南 - GraphRAG中的重排应用
更新日志
- 2026-04-18: 初始版本完成
- 包含Cross-Encoder、Bi-Encoder、ColBERT的详细解析
- 提供完整的训练和部署代码