Skip to content

第49天:检索优化

学习目标

  • 理解混合检索原理
  • 掌握重排序技术
  • 学习查询扩展方法
  • 了解过滤和筛选
  • 掌握检索性能优化

混合检索

向量+关键词检索

python
from typing import List, Dict, Tuple
import numpy as np

class HybridRetriever:
    """Fuse vector-similarity search with keyword search via rank-based scores.

    Each channel contributes ``weight * (1 - rank / n)`` per result, where the
    keyword channel is weighted by ``alpha`` and the vector channel by
    ``1 - alpha`` (matching the original weighting scheme).
    """

    def __init__(self, vector_db, keyword_index,
                 alpha: float = 0.5):
        # alpha weighs the keyword channel; (1 - alpha) weighs vectors.
        self.vector_db = vector_db
        self.keyword_index = keyword_index
        self.alpha = alpha

    def retrieve(self, query: str, top_k: int = 10) -> List[Dict]:
        """Return the top_k fused results as {content, score, metadata} dicts."""
        # Over-fetch from both channels so the fusion has enough candidates.
        vector_results = self._vector_search(query, top_k * 2)
        keyword_results = self._keyword_search(query, top_k * 2)

        combined = self._combine_scores(vector_results, keyword_results)

        # BUG FIX: the original sorted a {doc_id: score} dict and then
        # subscripted the doc_id *string* as if it were the result dict
        # (result["content"] raised TypeError). _combine_scores now keeps
        # the full result dict alongside its fused score.
        ranked = sorted(
            combined.values(),
            key=lambda entry: entry["score"],
            reverse=True,
        )

        return [
            {
                "content": entry["result"].get("content"),
                "score": entry["score"],
                "metadata": entry["result"].get("metadata"),
            }
            for entry in ranked[:top_k]
        ]

    def _vector_search(self, query: str,
                        top_k: int) -> List[Dict]:
        """Embed the query and search the vector store."""
        query_embedding = self._generate_embedding(query)
        return self.vector_db.search(query_embedding, top_k)

    def _keyword_search(self, query: str,
                         top_k: int) -> List[Dict]:
        """Search the keyword (e.g. BM25/inverted) index directly."""
        return self.keyword_index.search(query, top_k)

    def _generate_embedding(self, query: str):
        """Produce an embedding for the query; override in a subclass.

        BUG FIX: the original called this method without ever defining it,
        so every retrieval raised AttributeError. Raising NotImplementedError
        makes the extension point explicit.
        """
        raise NotImplementedError("override _generate_embedding with a real embedder")

    def _combine_scores(self, vector_results: List[Dict],
                       keyword_results: List[Dict]) -> Dict:
        """Fuse both lists into {doc_id: {"result": dict, "score": float}}."""
        combined: Dict[str, Dict] = {}

        def accumulate(results: List[Dict], weight: float, prefix: str) -> None:
            # Rank-based score: first hit gets full weight, later hits decay.
            n = len(results)
            for rank, result in enumerate(results):
                doc_id = result.get("id", f"{prefix}_{rank}")
                entry = combined.setdefault(doc_id, {"result": result, "score": 0.0})
                entry["score"] += weight * (1 - rank / n)

        if vector_results:
            accumulate(vector_results, 1 - self.alpha, "vec")
        if keyword_results:
            accumulate(keyword_results, self.alpha, "kw")

        return combined

多路召回

python
class MultiPathRetriever:
    """Fan a query out to several retrievers, merge, dedupe, and rerank."""

    def __init__(self, retrievers: List):
        self.retrievers = retrievers

    def retrieve(self, query: str, top_k: int = 10) -> List[Dict]:
        """Collect candidates from every retriever; return top_k after reranking."""
        merged: List[Dict] = []
        for source in self.retrievers:
            merged.extend(source.retrieve(query, top_k))

        unique = self._deduplicate(merged)

        return self._rerank(query, unique)[:top_k]

    def _deduplicate(self, results: List[Dict]) -> List[Dict]:
        """Drop repeats, keyed by "id" (falling back to "content")."""
        kept: List[Dict] = []
        seen = set()
        for item in results:
            key = item.get("id", item.get("content", ""))
            if key in seen:
                continue
            seen.add(key)
            kept.append(item)
        return kept

    def _rerank(self, query: str,
                  results: List[Dict]) -> List[Dict]:
        """Annotate each result with "rerank_score" and sort descending."""
        for item in results:
            item["rerank_score"] = self._calculate_rerank_score(query, item)
        return sorted(results, key=lambda item: item["rerank_score"], reverse=True)

    def _calculate_rerank_score(self, query: str,
                               result: Dict) -> float:
        """Blend the original score (70%) with query/content Jaccard overlap (30%)."""
        query_tokens = set(query.lower().split())
        doc_tokens = set(result.get("content", "").lower().split())

        overlap = query_tokens & doc_tokens
        combined = query_tokens | doc_tokens
        jaccard = len(overlap) / len(combined) if combined else 0

        base = result.get("score", 0.5)
        return 0.7 * base + 0.3 * jaccard

重排序

Cross-Encoder重排序

python
class CrossEncoderReranker:
    """Rerank documents with a cross-encoder model scoring (query, doc) pairs."""

    def __init__(self, model):
        self.model = model

    def rerank(self, query: str,
               documents: List[Dict]) -> List[Dict]:
        """Attach a model-derived "rerank_score" to each document, sort descending."""
        if not documents:
            return []

        for doc in documents:
            doc["rerank_score"] = self._calculate_score(query, doc)

        return sorted(documents, key=lambda d: d["rerank_score"], reverse=True)

    def _calculate_score(self, query: str,
                       document: Dict) -> float:
        """Score one document: model prediction on the joined pair text."""
        pair_text = f"[CLS] {query} [SEP] {document.get('content', '')}"
        return self.model.predict(pair_text)

LLM重排序

python
class LLMReranker:
    """Rerank documents by asking an LLM to score them via a JSON reply."""

    def __init__(self, llm):
        self.llm = llm

    def rerank(self, query: str,
               documents: List[Dict]) -> List[Dict]:
        """Attach LLM-provided scores as "rerank_score" and sort descending.

        Documents missing from the ranking (or an unparseable/non-object
        response) get score 0, so their relative order is preserved by the
        stable sort.
        """
        if not documents:
            return []

        prompt = self._build_rerank_prompt(query, documents)
        response = self.llm.generate(prompt)
        rankings = self._parse_rankings(response)

        for doc in documents:
            doc_id = doc.get("id", "")
            doc["rerank_score"] = rankings.get(doc_id, 0)

        return sorted(
            documents,
            key=lambda x: x["rerank_score"],
            reverse=True,
        )

    def _build_rerank_prompt(self, query: str,
                           documents: List[Dict]) -> str:
        """List each document (content truncated to 500 chars) under the query."""
        prompt = f"Query: {query}\n\n"
        prompt += "Rank the following documents by relevance to the query.\n"
        prompt += "Return the ranking as a JSON object with document IDs as keys and scores (1-10) as values.\n\n"

        for i, doc in enumerate(documents):
            doc_id = doc.get("id", f"doc_{i}")
            content = doc.get("content", "")[:500]
            prompt += f"{i+1}. {doc_id}: {content}\n"

        return prompt

    def _parse_rankings(self, response: str) -> Dict:
        """Parse the LLM's JSON reply; return {} when it isn't a JSON object.

        BUG FIX: the original used a bare ``except:``, which also swallowed
        SystemExit/KeyboardInterrupt. json.loads raises ValueError
        (JSONDecodeError) on bad syntax and TypeError on non-string input,
        so only those are caught.
        """
        import json
        try:
            parsed = json.loads(response)
        except (ValueError, TypeError):
            return {}
        # A syntactically valid but non-object reply (e.g. a JSON list)
        # cannot serve as an id -> score mapping.
        return parsed if isinstance(parsed, dict) else {}

多样性重排序

python
class DiversityReranker:
    """Rerank by blending relevance with dissimilarity to earlier documents."""

    def __init__(self, lambda_param: float = 0.5):
        # lambda_param weighs diversity; (1 - lambda_param) weighs relevance.
        self.lambda_param = lambda_param

    def rerank(self, query: str,
               documents: List[Dict]) -> List[Dict]:
        """Attach a blended "rerank_score" to each document, sort descending.

        Diversity for document i is measured against the documents that
        precede it in the *input* order (not a greedy MMR selection).
        """
        if not documents:
            return []

        for position, doc in enumerate(documents):
            relevance = doc.get("score", 0.5)
            novelty = self._calculate_diversity(doc, documents[:position])
            doc["rerank_score"] = (
                (1 - self.lambda_param) * relevance
                + self.lambda_param * novelty
            )

        return sorted(documents, key=lambda d: d["rerank_score"], reverse=True)

    def _calculate_diversity(self, doc: Dict,
                           previous_docs: List[Dict]) -> float:
        """1 minus the mean Jaccard similarity to the preceding documents."""
        if not previous_docs:
            return 1.0

        text = doc.get("content", "")
        sims = [
            self._calculate_similarity(text, earlier.get("content", ""))
            for earlier in previous_docs
        ]
        return 1 - sum(sims) / len(sims)

    def _calculate_similarity(self, text1: str,
                          text2: str) -> float:
        """Jaccard similarity between the word sets of the two texts."""
        left = set(text1.lower().split())
        right = set(text2.lower().split())

        shared = left & right
        total = left | right

        return len(shared) / len(total) if total else 0

查询扩展

同义词扩展

python
class SynonymExpander:
    """Expand a query by substituting whole tokens with dictionary synonyms."""

    def __init__(self):
        self.synonym_dict = self._build_synonym_dict()

    def _build_synonym_dict(self) -> Dict:
        """Static lowercase-abbreviation -> synonyms mapping (demo data)."""
        return {
            "ai": ["artificial intelligence", "machine intelligence"],
            "ml": ["machine learning", "ml"],
            "dl": ["deep learning", "neural networks"],
            "nlp": ["natural language processing", "text processing"],
            "cv": ["computer vision", "image processing"]
        }

    def expand(self, query: str) -> List[str]:
        """Return [query] plus one variant per (token, synonym) substitution.

        BUG FIX: the original used str.replace on the raw query, which
        matched *substrings* (e.g. the "ai" inside "email") instead of whole
        tokens, and self-synonyms ("ml" -> "ml") produced duplicate queries.
        Substitution is now token-wise; self-synonyms and duplicates are
        skipped. NOTE: runs of whitespace are normalized to single spaces in
        the variants (the original query itself is returned untouched).
        """
        expanded_queries = [query]
        seen = {query}
        tokens = query.split()

        for idx, token in enumerate(tokens):
            for synonym in self.synonym_dict.get(token.lower(), []):
                if synonym == token.lower():
                    continue  # skip self-synonyms, e.g. "ml" -> "ml"
                variant_tokens = tokens.copy()
                variant_tokens[idx] = synonym
                variant = " ".join(variant_tokens)
                if variant not in seen:
                    seen.add(variant)
                    expanded_queries.append(variant)

        return expanded_queries

LLM查询扩展

python
class LLMQueryExpander:
    """Use an LLM to produce alternative phrasings of a query."""

    def __init__(self, llm):
        self.llm = llm

    def expand(self, query: str, n_expansions: int = 3) -> List[str]:
        """Return the original query followed by LLM-generated paraphrases."""
        prompt = f"""
        Generate {n_expansions} alternative queries for the following question:
        
        Question: {query}
        
        The alternative queries should:
        1. Have the same meaning
        2. Use different wording
        3. Be natural and clear
        
        Return only the queries, one per line.
        """

        reply = self.llm.generate(prompt)
        alternatives = self._parse_expansions(reply)

        return [query] + alternatives

    def _parse_expansions(self, response: str) -> List[str]:
        """Split the reply into stripped, non-empty lines."""
        return [line.strip() for line in response.split('\n') if line.strip()]

历史查询扩展

python
class HistoryQueryExpander:
    """Expand a query with similar queries drawn from recent history."""

    def __init__(self, max_history: int = 10):
        # Recent queries, oldest first; bounded to max_history entries.
        self.query_history: List[str] = []
        self.max_history = max_history

    def add_query(self, query: str):
        """Record a query, discarding the oldest entries beyond max_history."""
        self.query_history.append(query)

        if len(self.query_history) > self.max_history:
            self.query_history = self.query_history[-self.max_history:]

    def expand(self, query: str) -> List[str]:
        """Return [query] plus distinct similar queries from history."""
        expanded_queries = [query]
        seen = {query}

        for candidate in self._find_similar_queries(query):
            # Skip the query itself and history duplicates.
            if candidate not in seen:
                seen.add(candidate)
                expanded_queries.append(candidate)

        return expanded_queries

    def _find_similar_queries(self, query: str) -> List[str]:
        """History entries whose token Jaccard similarity to query exceeds 0.5."""
        query_words = set(query.lower().split())
        similar_queries = []

        for history_query in self.query_history:
            history_words = set(history_query.lower().split())
            union = query_words | history_words

            # BUG FIX: the original divided by len(union) unguarded, raising
            # ZeroDivisionError when both strings were empty/whitespace.
            if not union:
                continue

            similarity = len(query_words & history_words) / len(union)
            if similarity > 0.5:
                similar_queries.append(history_query)

        return similar_queries

过滤和筛选

元数据过滤

python
class MetadataFilter:
    """Keep only documents whose metadata satisfies every filter condition."""

    def __init__(self):
        pass

    def filter(self, documents: List[Dict],
               filters: Dict) -> List[Dict]:
        """Return the documents matching all entries in `filters`."""
        return [doc for doc in documents if self._matches_filters(doc, filters)]

    def _matches_filters(self, doc: Dict,
                        filters: Dict) -> bool:
        """True when every filter key exists in metadata and matches.

        A list-valued filter means "metadata value must be one of these";
        any other value is compared by equality. A missing key fails.
        """
        metadata = doc.get("metadata", {})

        for field, expected in filters.items():
            if field not in metadata:
                return False
            actual = metadata[field]
            allowed = expected if isinstance(expected, list) else [expected]
            if actual not in allowed:
                return False

        return True

时间范围过滤

python
class TimeRangeFilter:
    """Keep documents whose metadata "created" timestamp falls in a range."""

    def __init__(self):
        pass

    def filter(self, documents: List[Dict],
               start_time: str, end_time: str) -> List[Dict]:
        """Return documents with an ISO-8601 "created" within the inclusive range."""
        filtered = []

        for doc in documents:
            metadata = doc.get("metadata", {})
            doc_time = metadata.get("created", "")

            if self._in_time_range(doc_time, start_time, end_time):
                filtered.append(doc)

        return filtered

    def _in_time_range(self, doc_time: str,
                        start_time: str, end_time: str) -> bool:
        """True when doc_time parses and lies within [start_time, end_time].

        Unparseable or missing timestamps are excluded rather than raising.
        """
        from datetime import datetime

        try:
            doc_dt = datetime.fromisoformat(doc_time)
            start_dt = datetime.fromisoformat(start_time)
            end_dt = datetime.fromisoformat(end_time)
        except (ValueError, TypeError):
            # BUG FIX: was a bare `except:` that also swallowed
            # SystemExit/KeyboardInterrupt. fromisoformat raises ValueError
            # for bad formats and TypeError for non-string input.
            return False

        return start_dt <= doc_dt <= end_dt

权限过滤

python
class PermissionFilter:
    """Keep only the documents the user is permitted to see."""

    def __init__(self):
        pass

    def filter(self, documents: List[Dict],
               user_permissions: List[str]) -> List[Dict]:
        """Return documents sharing at least one permission with the user.

        Note: a document whose "permissions" list is empty matches nothing
        and is therefore excluded.
        """
        return [
            doc for doc in documents
            if self._has_permission(
                doc.get("metadata", {}).get("permissions", []),
                user_permissions,
            )
        ]

    def _has_permission(self, doc_permissions: List[str],
                         user_permissions: List[str]) -> bool:
        """True when any document permission appears in the user's permissions."""
        return any(perm in user_permissions for perm in doc_permissions)

检索性能优化

缓存机制

python
from functools import lru_cache
import hashlib

class CachedRetriever:
    """Wrap a retriever with an in-memory LRU cache keyed by (query, top_k)."""

    def __init__(self, retriever, cache_size: int = 1000):
        self.retriever = retriever
        # Plain dict preserves insertion order (Python 3.7+); hits are
        # re-inserted so the first key is always the least recently used.
        self.cache = {}
        self.cache_size = cache_size

    def retrieve(self, query: str, top_k: int = 10) -> List[Dict]:
        """Return cached results when available, else delegate and cache.

        A shallow copy of the cached list is returned so callers mutating
        the returned list cannot corrupt the cache.
        """
        cache_key = self._generate_cache_key(query, top_k)

        if cache_key in self.cache:
            # BUG FIX: refresh recency on hit so eviction is LRU; the
            # original evicted in pure insertion (FIFO) order.
            results = self.cache.pop(cache_key)
            self.cache[cache_key] = results
            return list(results)

        results = self.retriever.retrieve(query, top_k)

        self._update_cache(cache_key, results)

        return results

    def _generate_cache_key(self, query: str, top_k: int) -> str:
        """Stable key: md5 of "query:top_k" (hashing, not security-sensitive)."""
        key_str = f"{query}:{top_k}"
        return hashlib.md5(key_str.encode()).hexdigest()

    def _update_cache(self, cache_key: str, results: List[Dict]):
        """Insert results, evicting the least recently used entry when full."""
        if len(self.cache) >= self.cache_size:
            self._evict_cache()

        # BUG FIX: store a copy so later caller-side mutation of the
        # returned list cannot alter what the cache serves next time.
        self.cache[cache_key] = list(results)

    def _evict_cache(self):
        """Remove the least recently used entry (first in insertion order)."""
        oldest_key = next(iter(self.cache))
        del self.cache[oldest_key]

批量检索

python
class BatchRetriever:
    """Run a retriever over many queries sequentially."""

    def __init__(self, retriever):
        self.retriever = retriever

    def retrieve_batch(self, queries: List[str],
                     top_k: int = 10) -> List[List[Dict]]:
        """Return one result list per query, preserving query order."""
        return [self.retriever.retrieve(query, top_k) for query in queries]

并行检索

python
import asyncio
from typing import List, Dict

class ParallelRetriever:
    """Query several retrievers concurrently and merge their results."""

    def __init__(self, retrievers: List):
        self.retrievers = retrievers

    async def retrieve_parallel(self, query: str, 
                            top_k: int = 10) -> List[Dict]:
        """Fan out to all retrievers, deduplicate, and return top_k results."""
        tasks = [
            self._retrieve_from_retriever(retriever, query, top_k)
            for retriever in self.retrievers
        ]

        per_retriever = await asyncio.gather(*tasks)

        merged: List[Dict] = []
        for chunk in per_retriever:
            merged.extend(chunk)

        return self._deduplicate(merged)[:top_k]

    async def _retrieve_from_retriever(self, retriever, 
                                       query: str, top_k: int) -> List[Dict]:
        """Run one synchronous retriever without blocking the event loop.

        BUG FIX: the original called retriever.retrieve directly inside the
        coroutine, blocking the loop so the gathered tasks ran serially.
        asyncio.to_thread (3.9+) executes each call in a worker thread,
        making the fan-out genuinely concurrent.
        """
        return await asyncio.to_thread(retriever.retrieve, query, top_k)

    def _deduplicate(self, results: List[Dict]) -> List[Dict]:
        """Drop duplicates keyed by "id" (falling back to "content")."""
        seen = set()
        deduplicated = []

        for result in results:
            result_id = result.get("id", result.get("content", ""))

            if result_id not in seen:
                seen.add(result_id)
                deduplicated.append(result)

        return deduplicated

实践练习

练习1:实现混合检索系统

python
class SimpleHybridRetriever:
    """Exercise 1: fuse vector and keyword hits with equal rank-based weights."""

    def __init__(self, vector_db, keyword_index):
        self.vector_db = vector_db
        self.keyword_index = keyword_index

    def retrieve(self, query: str, top_k: int = 10):
        """Return up to top_k fused {"id", "score"} entries, best first."""
        from_vectors = self.vector_db.search(query, top_k)
        from_keywords = self.keyword_index.search(query, top_k)

        fused = self._combine_results(from_vectors, from_keywords)

        return fused[:top_k]

    def _combine_results(self, vector_results, keyword_results):
        """Score each doc id by 0.5 * (1 - rank/n) per channel, sort descending."""
        scores = {}

        def accumulate(hits, prefix):
            total = len(hits)
            for rank, hit in enumerate(hits):
                key = hit.get("id", f"{prefix}_{rank}")
                scores[key] = scores.get(key, 0) + (1 - rank / total) * 0.5

        accumulate(vector_results, "vec")
        accumulate(keyword_results, "kw")

        ranked = sorted(scores.items(), key=lambda kv: kv[1], reverse=True)

        return [{"id": key, "score": value} for key, value in ranked]

练习2:实现带重排序的检索系统

python
class RerankedRetriever:
    """Exercise 2: over-fetch candidates, rerank them, return the best."""

    def __init__(self, retriever, reranker):
        self.retriever = retriever
        self.reranker = reranker

    def retrieve(self, query: str, top_k: int = 10):
        """Fetch 2*top_k candidates, rerank, and keep the top_k."""
        candidates = self.retriever.retrieve(query, top_k * 2)
        ordered = self.reranker.rerank(query, candidates)
        return ordered[:top_k]

总结

本节我们学习了检索优化:

  1. 混合检索(向量+关键词、多路召回)
  2. 重排序技术(Cross-Encoder、LLM、多样性)
  3. 查询扩展方法(同义词、LLM、历史)
  4. 过滤和筛选(元数据、时间范围、权限)
  5. 检索性能优化(缓存、批量、并行)

检索优化是提高RAG系统准确性和性能的关键技术。

参考资源