Appearance
第49天:检索优化
学习目标
- 理解混合检索原理
- 掌握重排序技术
- 学习查询扩展方法
- 了解过滤和筛选
- 掌握检索性能优化
混合检索
向量+关键词检索
python
from typing import List, Dict, Tuple
import numpy as np
class HybridRetriever:
    """Fuse dense (vector) and sparse (keyword) retrieval results.

    Each backend contributes a rank-based score: ``1 - alpha`` weights the
    vector side, ``alpha`` the keyword side, and a document found by both
    accumulates both contributions.
    """

    def __init__(self, vector_db, keyword_index,
                 alpha: float = 0.5, embedder=None):
        # embedder: optional callable mapping a query string to an embedding
        # vector. The original code called self._generate_embedding() without
        # ever defining it, so vector search always raised AttributeError.
        self.vector_db = vector_db
        self.keyword_index = keyword_index
        self.alpha = alpha
        self.embedder = embedder

    def retrieve(self, query: str, top_k: int = 10) -> List[Dict]:
        """Return the top_k fused results as {content, score, metadata} dicts."""
        # Over-fetch from both backends so the fusion step has candidates.
        vector_results = self._vector_search(query, top_k * 2)
        keyword_results = self._keyword_search(query, top_k * 2)
        combined = self._combine_scores(vector_results, keyword_results)
        # Sort fused entries by score, descending. (The previous version
        # sorted (doc_id, score) tuples and then indexed doc_id["content"],
        # which raised TypeError — the documents were never kept.)
        ranked = sorted(
            combined.values(),
            key=lambda entry: entry["score"],
            reverse=True
        )
        return [
            {
                "content": entry["doc"].get("content"),
                "score": entry["score"],
                "metadata": entry["doc"].get("metadata")
            }
            for entry in ranked[:top_k]
        ]

    def _vector_search(self, query: str,
                       top_k: int) -> List[Dict]:
        """Embed the query and search the vector store."""
        query_embedding = self._generate_embedding(query)
        return self.vector_db.search(query_embedding, top_k)

    def _keyword_search(self, query: str,
                        top_k: int) -> List[Dict]:
        """Search the sparse/keyword index with the raw query string."""
        return self.keyword_index.search(query, top_k)

    def _generate_embedding(self, query: str):
        """Delegate embedding to the injected embedder callable."""
        if self.embedder is None:
            raise ValueError(
                "HybridRetriever needs an 'embedder' callable to run vector search"
            )
        return self.embedder(query)

    def _combine_scores(self, vector_results: List[Dict],
                        keyword_results: List[Dict]) -> Dict:
        """Map doc_id -> {"doc": document, "score": fused rank-based score}."""
        combined: Dict[str, Dict] = {}
        for i, doc in enumerate(vector_results):
            doc_id = doc.get("id", f"vec_{i}")
            score = (1 - self.alpha) * (1 - i / len(vector_results))
            entry = combined.setdefault(doc_id, {"doc": doc, "score": 0.0})
            entry["score"] += score
        for i, doc in enumerate(keyword_results):
            doc_id = doc.get("id", f"kw_{i}")
            score = self.alpha * (1 - i / len(keyword_results))
            entry = combined.setdefault(doc_id, {"doc": doc, "score": 0.0})
            entry["score"] += score
        return combined
多路召回
python
class MultiPathRetriever:
    """Query several retrievers, merge their hits, and rerank the union."""

    def __init__(self, retrievers: List):
        self.retrievers = retrievers

    def retrieve(self, query: str, top_k: int = 10) -> List[Dict]:
        """Collect results from every retriever, dedupe, rerank, truncate."""
        gathered = []
        for source in self.retrievers:
            gathered += source.retrieve(query, top_k)
        unique = self._deduplicate(gathered)
        return self._rerank(query, unique)[:top_k]

    def _deduplicate(self, results: List[Dict]) -> List[Dict]:
        """Keep the first occurrence of each id (falling back to content)."""
        unique, seen_keys = [], set()
        for item in results:
            key = item.get("id", item.get("content", ""))
            if key in seen_keys:
                continue
            seen_keys.add(key)
            unique.append(item)
        return unique

    def _rerank(self, query: str,
                results: List[Dict]) -> List[Dict]:
        """Annotate every result with a rerank score and sort descending."""
        for item in results:
            item["rerank_score"] = self._calculate_rerank_score(query, item)
        return sorted(results, key=lambda item: item["rerank_score"], reverse=True)

    def _calculate_rerank_score(self, query: str,
                                result: Dict) -> float:
        """Blend the original score (70%) with query/content Jaccard overlap (30%)."""
        text = result.get("content", "")
        query_tokens = set(query.lower().split())
        doc_tokens = set(text.lower().split())
        union = query_tokens | doc_tokens
        jaccard = len(query_tokens & doc_tokens) / len(union) if union else 0
        return 0.7 * result.get("score", 0.5) + 0.3 * jaccard
重排序
Cross-Encoder重排序
python
class CrossEncoderReranker:
    """Rerank documents with a cross-encoder that scores (query, doc) pairs."""

    def __init__(self, model):
        self.model = model

    def rerank(self, query: str,
               documents: List[Dict]) -> List[Dict]:
        """Score each document against the query and sort by score, descending."""
        if not documents:
            return []
        for doc in documents:
            doc["rerank_score"] = self._calculate_score(query, doc)
        return sorted(documents, key=lambda d: d["rerank_score"], reverse=True)

    def _calculate_score(self, query: str,
                         document: Dict) -> float:
        """Build the BERT-style paired input and let the model score it."""
        content = document.get("content", "")
        pair_text = f"[CLS] {query} [SEP] {content}"
        return self.model.predict(pair_text)
LLM重排序
python
class LLMReranker:
    """Rerank retrieved documents by asking an LLM to score their relevance."""

    def __init__(self, llm):
        self.llm = llm

    def rerank(self, query: str,
               documents: List[Dict]) -> List[Dict]:
        """Score documents with the LLM and return them sorted by score (desc)."""
        if not documents:
            return []
        prompt = self._build_rerank_prompt(query, documents)
        response = self.llm.generate(prompt)
        rankings = self._parse_rankings(response)
        for doc in documents:
            doc_id = doc.get("id", "")
            # Documents the LLM did not score (or unparseable output) get 0.
            doc["rerank_score"] = rankings.get(doc_id, 0)
        reranked = sorted(
            documents,
            key=lambda x: x["rerank_score"],
            reverse=True
        )
        return reranked

    def _build_rerank_prompt(self, query: str,
                             documents: List[Dict]) -> str:
        """Build a scoring prompt listing each document id with a 500-char preview."""
        prompt = f"Query: {query}\n\n"
        prompt += "Rank the following documents by relevance to the query.\n"
        prompt += "Return the ranking as a JSON object with document IDs as keys and scores (1-10) as values.\n\n"
        for i, doc in enumerate(documents):
            doc_id = doc.get("id", f"doc_{i}")
            content = doc.get("content", "")[:500]
            prompt += f"{i+1}. {doc_id}: {content}\n"
        return prompt

    def _parse_rankings(self, response: str) -> Dict:
        """Parse the LLM's JSON ranking; malformed output degrades to {}.

        Only JSON decode / type errors are swallowed — the previous bare
        ``except`` also hid KeyboardInterrupt and genuine bugs.
        """
        import json
        try:
            return json.loads(response)
        except (json.JSONDecodeError, TypeError):
            return {}
多样性重排序
python
class DiversityReranker:
    """Rerank so that documents similar to earlier-listed ones are demoted.

    ``lambda_param`` trades off relevance (weight ``1 - lambda``) against
    novelty relative to the documents that precede each one in the input
    order (weight ``lambda``).
    """

    def __init__(self, lambda_param: float = 0.5):
        self.lambda_param = lambda_param

    def rerank(self, query: str,
               documents: List[Dict]) -> List[Dict]:
        """Score each document against the prefix before it, then sort descending."""
        if not documents:
            return []
        for idx, doc in enumerate(documents):
            relevance = doc.get("score", 0.5)
            novelty = self._calculate_diversity(doc, documents[:idx])
            doc["rerank_score"] = (
                (1 - self.lambda_param) * relevance
                + self.lambda_param * novelty
            )
        return sorted(documents, key=lambda d: d["rerank_score"], reverse=True)

    def _calculate_diversity(self, doc: Dict,
                             previous_docs: List[Dict]) -> float:
        """1 minus the mean lexical similarity to the preceding documents."""
        if not previous_docs:
            return 1.0  # the first document is maximally novel by definition
        text = doc.get("content", "")
        total = sum(
            self._calculate_similarity(text, prev.get("content", ""))
            for prev in previous_docs
        )
        return 1 - total / len(previous_docs)

    def _calculate_similarity(self, text1: str,
                              text2: str) -> float:
        """Jaccard similarity over lowercase whitespace tokens."""
        tokens1 = set(text1.lower().split())
        tokens2 = set(text2.lower().split())
        union = tokens1 | tokens2
        if not union:
            return 0
        return len(tokens1 & tokens2) / len(union)
查询扩展
同义词扩展
python
class SynonymExpander:
    """Expand a query by substituting known synonyms for individual tokens."""

    def __init__(self):
        self.synonym_dict = self._build_synonym_dict()

    def _build_synonym_dict(self) -> Dict:
        """Static demo vocabulary: lowercase term -> list of synonyms."""
        return {
            "ai": ["artificial intelligence", "machine intelligence"],
            "ml": ["machine learning", "ml"],
            "dl": ["deep learning", "neural networks"],
            "nlp": ["natural language processing", "text processing"],
            "cv": ["computer vision", "image processing"]
        }

    def expand(self, query: str) -> List[str]:
        """Return the original query plus variants with one token replaced.

        Substitution is done token-by-token, which fixes two bugs in the
        previous substring-based ``str.replace`` approach: it no longer
        corrupts words that merely contain a key (e.g. "ai" inside "chair"),
        and it matches tokens case-insensitively instead of silently
        appending the query unchanged.
        """
        expanded_queries = [query]
        tokens = query.split()
        for i, token in enumerate(tokens):
            for synonym in self.synonym_dict.get(token.lower(), []):
                if synonym == token.lower():
                    continue  # entries like "ml" -> "ml" would only add duplicates
                variant_tokens = list(tokens)
                variant_tokens[i] = synonym
                variant = " ".join(variant_tokens)
                if variant not in expanded_queries:
                    expanded_queries.append(variant)
        return expanded_queries
LLM查询扩展
python
class LLMQueryExpander:
    """Use an LLM to paraphrase a query into several alternative phrasings."""

    def __init__(self, llm):
        self.llm = llm

    def expand(self, query: str, n_expansions: int = 3) -> List[str]:
        """Return the original query followed by LLM-generated paraphrases."""
        prompt = f"""
Generate {n_expansions} alternative queries for the following question:
Question: {query}
The alternative queries should:
1. Have the same meaning
2. Use different wording
3. Be natural and clear
Return only the queries, one per line.
"""
        response = self.llm.generate(prompt)
        # Original query stays first so downstream rankers can prefer it.
        return [query] + self._parse_expansions(response)

    def _parse_expansions(self, response: str) -> List[str]:
        """One candidate per non-empty response line, whitespace-trimmed."""
        return [line.strip() for line in response.split('\n') if line.strip()]
历史查询扩展
python
class HistoryQueryExpander:
    """Expand a query with similar queries drawn from recent search history."""

    def __init__(self, max_history: int = 10):
        self.query_history: List[str] = []   # most recent queries, oldest first
        self.max_history = max_history       # number of past queries retained

    def add_query(self, query: str):
        """Record a query, keeping only the most recent ``max_history`` entries."""
        self.query_history.append(query)
        if len(self.query_history) > self.max_history:
            self.query_history = self.query_history[-self.max_history:]

    def expand(self, query: str) -> List[str]:
        """Return the query followed by sufficiently similar past queries."""
        expanded_queries = [query]
        for similar_query in self._find_similar_queries(query):
            if similar_query != query:
                expanded_queries.append(similar_query)
        return expanded_queries

    def _find_similar_queries(self, query: str) -> List[str]:
        """History entries whose token Jaccard similarity to query exceeds 0.5."""
        query_words = set(query.lower().split())
        similar_queries = []
        for history_query in self.query_history:
            history_words = set(history_query.lower().split())
            union = query_words | history_words
            if not union:
                # Both strings tokenize to nothing (empty/whitespace queries);
                # treat as dissimilar instead of dividing by zero.
                continue
            similarity = len(query_words & history_words) / len(union)
            if similarity > 0.5:
                similar_queries.append(history_query)
        return similar_queries
过滤和筛选
元数据过滤
python
class MetadataFilter:
    """Filter documents by exact or set-membership metadata constraints."""

    def __init__(self):
        pass

    def filter(self, documents: List[Dict],
               filters: Dict) -> List[Dict]:
        """Return only the documents whose metadata satisfies every filter."""
        return [doc for doc in documents if self._matches_filters(doc, filters)]

    def _matches_filters(self, doc: Dict,
                         filters: Dict) -> bool:
        """True when every filter key is present and its value matches."""
        metadata = doc.get("metadata", {})
        return all(
            self._field_matches(metadata, key, expected)
            for key, expected in filters.items()
        )

    def _field_matches(self, metadata: Dict, key, expected) -> bool:
        """A list filter means 'value in list'; anything else is equality."""
        if key not in metadata:
            return False
        actual = metadata[key]
        if isinstance(expected, list):
            return actual in expected
        return actual == expected
时间范围过滤
python
class TimeRangeFilter:
    """Keep documents whose metadata 'created' ISO timestamp lies in a range."""

    def __init__(self):
        pass

    def filter(self, documents: List[Dict],
               start_time: str, end_time: str) -> List[Dict]:
        """Return documents created within [start_time, end_time] inclusive."""
        filtered = []
        for doc in documents:
            doc_time = doc.get("metadata", {}).get("created", "")
            if self._in_time_range(doc_time, start_time, end_time):
                filtered.append(doc)
        return filtered

    def _in_time_range(self, doc_time: str,
                       start_time: str, end_time: str) -> bool:
        """True if all three timestamps parse and doc_time falls in the range.

        Catches only ValueError/TypeError (bad/missing ISO strings, or a
        naive-vs-aware comparison) — the previous bare ``except`` also
        swallowed KeyboardInterrupt and real bugs.
        """
        from datetime import datetime
        try:
            doc_dt = datetime.fromisoformat(doc_time)
            start_dt = datetime.fromisoformat(start_time)
            end_dt = datetime.fromisoformat(end_time)
            return start_dt <= doc_dt <= end_dt
        except (ValueError, TypeError):
            # Unparseable timestamps exclude the document rather than
            # aborting the whole filter pass.
            return False
权限过滤
python
class PermissionFilter:
    """Keep only documents the user is allowed to see.

    A document with no ``permissions`` entry in its metadata matches no
    user permission and is therefore filtered out.
    """

    def __init__(self):
        pass

    def filter(self, documents: List[Dict],
               user_permissions: List[str]) -> List[Dict]:
        """Return documents sharing at least one permission with the user."""
        visible = []
        for doc in documents:
            doc_permissions = doc.get("metadata", {}).get("permissions", [])
            if self._has_permission(doc_permissions, user_permissions):
                visible.append(doc)
        return visible

    def _has_permission(self, doc_permissions: List[str],
                        user_permissions: List[str]) -> bool:
        """True when any document permission appears in the user's set."""
        return any(perm in user_permissions for perm in doc_permissions)
检索性能优化
缓存机制
python
from functools import lru_cache
import hashlib
class CachedRetriever:
    """LRU cache around a retriever, keyed on (query, top_k).

    Entries are promoted on every hit, so eviction removes the least
    recently *used* key — the previous version evicted by insertion order
    (FIFO), which throws away hot queries.
    """

    def __init__(self, retriever, cache_size: int = 1000):
        self.retriever = retriever
        self.cache: Dict[str, List[Dict]] = {}
        self.cache_size = cache_size

    def retrieve(self, query: str, top_k: int = 10) -> List[Dict]:
        """Serve from cache when possible; otherwise fetch and cache."""
        cache_key = self._generate_cache_key(query, top_k)
        if cache_key in self.cache:
            # Re-insert to mark as most recently used (dicts keep insertion order).
            results = self.cache.pop(cache_key)
            self.cache[cache_key] = results
            # NOTE(review): the cached list object itself is returned; callers
            # that mutate result dicts (e.g. rerankers) will mutate the cache.
            return results
        results = self.retriever.retrieve(query, top_k)
        self._update_cache(cache_key, results)
        return results

    def _generate_cache_key(self, query: str, top_k: int) -> str:
        # MD5 is fine here: the key is a cache bucket name, not a security token.
        key_str = f"{query}:{top_k}"
        return hashlib.md5(key_str.encode()).hexdigest()

    def _update_cache(self, cache_key: str, results: List[Dict]):
        """Insert, evicting the least recently used entry when full."""
        if len(self.cache) >= self.cache_size:
            self._evict_cache()
        self.cache[cache_key] = results

    def _evict_cache(self):
        # Dicts iterate in insertion order; with the promotion in retrieve(),
        # the first key is the least recently used.
        lru_key = next(iter(self.cache))
        del self.cache[lru_key]
批量检索
python
class BatchRetriever:
    """Run one retriever over a batch of queries, preserving query order."""

    def __init__(self, retriever):
        self.retriever = retriever

    def retrieve_batch(self, queries: List[str],
                       top_k: int = 10) -> List[List[Dict]]:
        """Return one result list per query, in the same order as the input."""
        return [self.retriever.retrieve(query, top_k) for query in queries]
并行检索
python
import asyncio
from typing import List, Dict
class ParallelRetriever:
    """Fan a query out to several retrievers concurrently and merge results."""

    def __init__(self, retrievers: List):
        self.retrievers = retrievers

    async def retrieve_parallel(self, query: str,
                                top_k: int = 10) -> List[Dict]:
        """Query every retriever concurrently; return deduplicated top_k."""
        tasks = [
            self._retrieve_from_retriever(retriever, query, top_k)
            for retriever in self.retrievers
        ]
        # gather preserves the order of self.retrievers in its results, so
        # deduplication keeps the first retriever's copy of a document.
        per_retriever = await asyncio.gather(*tasks)
        merged: List[Dict] = []
        for result_list in per_retriever:
            merged.extend(result_list)
        return self._deduplicate(merged)[:top_k]

    async def _retrieve_from_retriever(self, retriever,
                                       query: str, top_k: int) -> List[Dict]:
        # The retrievers are synchronous; run each call in a worker thread so
        # they overlap. The previous version called retriever.retrieve()
        # directly, which blocked the event loop and ran everything serially.
        return await asyncio.to_thread(retriever.retrieve, query, top_k)

    def _deduplicate(self, results: List[Dict]) -> List[Dict]:
        """Keep the first occurrence of each id (falling back to content)."""
        seen = set()
        deduplicated = []
        for result in results:
            result_id = result.get("id", result.get("content", ""))
            if result_id not in seen:
                seen.add(result_id)
                deduplicated.append(result)
        return deduplicated
实践练习
练习1:实现混合检索系统
python
class SimpleHybridRetriever:
    """Exercise 1: minimal hybrid retriever fusing two backends 50/50."""

    def __init__(self, vector_db, keyword_index):
        self.vector_db = vector_db
        self.keyword_index = keyword_index

    def retrieve(self, query: str, top_k: int = 10):
        """Fetch from both backends, fuse by rank, return the top_k ids."""
        dense_hits = self.vector_db.search(query, top_k)
        sparse_hits = self.keyword_index.search(query, top_k)
        return self._combine_results(dense_hits, sparse_hits)[:top_k]

    def _combine_results(self, vector_results, keyword_results):
        """Fuse both lists with equal weight; higher rank → higher score."""
        fused = {}
        for rank, item in enumerate(vector_results):
            key = item.get("id", f"vec_{rank}")
            fused[key] = fused.get(key, 0) + (1 - rank / len(vector_results)) * 0.5
        for rank, item in enumerate(keyword_results):
            key = item.get("id", f"kw_{rank}")
            fused[key] = fused.get(key, 0) + (1 - rank / len(keyword_results)) * 0.5
        ranked = sorted(fused.items(), key=lambda kv: kv[1], reverse=True)
        return [{"id": key, "score": value} for key, value in ranked]
练习2:实现带重排序的检索系统
python
class RerankedRetriever:
    """Exercise 2: retrieve a wide candidate pool, then rerank and truncate."""

    def __init__(self, retriever, reranker):
        self.retriever = retriever
        self.reranker = reranker

    def retrieve(self, query: str, top_k: int = 10):
        """Over-fetch 2x candidates so the reranker has room to reorder."""
        candidates = self.retriever.retrieve(query, top_k * 2)
        ordered = self.reranker.rerank(query, candidates)
        return ordered[:top_k]
总结
本节我们学习了检索优化:
- 混合检索(向量+关键词、多路召回)
- 重排序技术(Cross-Encoder、LLM、多样性)
- 查询扩展方法(同义词、LLM、历史)
- 过滤和筛选(元数据、时间范围、权限)
- 检索性能优化(缓存、批量、并行)
检索优化是提高RAG系统准确性和性能的关键技术。
