Skip to content

第45天:RAG原理与架构

学习目标

  • 理解RAG的基本概念
  • 掌握RAG vs Fine-tuning的区别
  • 学习RAG系统架构
  • 了解RAG的优势和局限
  • 理解RAG应用场景

RAG基本概念

什么是RAG

RAG(Retrieval-Augmented Generation,检索增强生成)是一种结合检索和生成的AI技术。它通过从外部知识库中检索相关信息,然后使用大语言模型基于检索到的内容生成答案。

核心思想

用户查询 → 检索相关文档 → 结合文档生成答案 → 返回结果

工作流程

  1. 索引阶段:将文档切分、向量化并存储到向量数据库
  2. 检索阶段:根据用户查询检索相关文档片段
  3. 生成阶段:将查询和检索到的文档输入LLM生成答案

RAG vs Fine-tuning

| 特性 | RAG | Fine-tuning |
| --- | --- | --- |
| 知识更新 | 实时更新 | 需要重新训练 |
| 数据需求 | 少量数据 | 大量数据 |
| 训练成本 | 低 | 高 |
| 推理速度 | 较慢 | 较快 |
| 知识准确性 | 高(基于检索内容) | 中(依赖训练数据) |
| 幻觉问题 | 较少 | 较多 |
| 适用场景 | 动态知识 | 专用任务 |

选择建议

  • 使用RAG:知识频繁更新、需要高准确性、数据量少
  • 使用Fine-tuning:专用任务、需要快速推理、数据量大
  • 结合使用:RAG + Fine-tuning获得最佳效果

RAG系统架构

基础架构

┌─────────────────────────────────────────┐
│         文档集合                   │
└───────────────────┬─────────────────────┘


┌─────────────────────────────────────────┐
│         文档处理                   │
│  - 文档加载                        │
│  - 文本预处理                      │
│  - 文档切片                        │
└───────────────────┬─────────────────────┘


┌─────────────────────────────────────────┐
│         向量化                      │
│  - Embedding生成                    │
│  - 向量索引                        │
└───────────────────┬─────────────────────┘


┌─────────────────────────────────────────┐
│         向量数据库                   │
└───────────────────┬─────────────────────┘


┌─────────────────────────────────────────┐
│         检索阶段                   │
│  - 查询向量化                      │
│  - 相似度搜索                      │
│  - 结果过滤                        │
└───────────────────┬─────────────────────┘


┌─────────────────────────────────────────┐
│         生成阶段                   │
│  - Prompt构建                       │
│  - LLM生成                         │
│  - 答案后处理                      │
└─────────────────────────────────────────┘

高级架构

python
import time
from typing import List, Dict, Optional

import numpy as np

class RAGSystem:
    """Retrieval-augmented generation pipeline.

    Index documents once, then answer queries by embedding the question,
    retrieving candidate chunks, reranking them, and prompting an LLM with
    the best ones as context.
    """

    def __init__(self, config: Dict):
        self.document_processor = DocumentProcessor()
        self.embedding_generator = EmbeddingGenerator(config["embedding_model"])
        self.vector_database = VectorDatabase(config["vector_db"])
        self.retriever = Retriever(self.vector_database, self.embedding_generator)
        self.reranker = Reranker()
        self.llm = LLM(config["llm_model"])
        self.prompt_template = PromptTemplate()

    def index_documents(self, documents: List[str]) -> Dict:
        """Process, embed, and store *documents*; returns a status summary."""
        processed_docs = self.document_processor.process(documents)
        embeddings = [
            self.embedding_generator.generate(doc["content"])
            for doc in processed_docs
        ]
        self.vector_database.add_batch(embeddings, processed_docs)
        return {
            "status": "success",
            "indexed_count": len(processed_docs),
        }

    def query(self, question: str, top_k: int = 5) -> Dict:
        """Answer *question* from the index, returning answer plus sources."""
        query_embedding = self.embedding_generator.generate(question)

        # Over-fetch so the reranker has extra candidates to discard.
        candidates = self.retriever.retrieve(query_embedding, top_k=top_k * 2)
        ranked = self.reranker.rerank(question, candidates)
        best = ranked[:top_k]

        answer = self._generate_answer(question, best)
        return {
            "question": question,
            "answer": answer,
            "sources": [doc["metadata"] for doc in best],
            "retrieved_count": len(candidates),
        }

    def _generate_answer(self, question: str,
                         documents: List[Dict]) -> str:
        """Build the grounded prompt from *documents* and ask the LLM."""
        prompt = self.prompt_template.format(
            question=question,
            context=self._build_context(documents),
        )
        return self.llm.generate(prompt)

    def _build_context(self, documents: List[Dict]) -> str:
        """Join document contents, each labeled "Document N:"."""
        labeled = [
            f"Document {index}:\n{doc['content']}"
            for index, doc in enumerate(documents, start=1)
        ]
        return "\n\n".join(labeled)

RAG的优势

1. 知识实时更新

python
class RealTimeRAG(RAGSystem):
    """RAG variant supporting incremental single-document updates."""

    def add_document(self, document: str) -> Dict:
        """Embed and index one document immediately."""
        processed = self.document_processor.process([document])[0]
        vector = self.embedding_generator.generate(processed["content"])
        self.vector_database.add(vector, processed)
        return {
            "status": "success",
            "message": "Document added successfully",
        }

    def update_document(self, doc_id: str,
                        new_content: str) -> Dict:
        """Replace the stored document *doc_id* with *new_content*.

        NOTE(review): the re-indexed document is added as a fresh entry and
        does not appear to keep the original *doc_id*; if *doc_id* is
        unknown, the content is still added. Confirm this is intended.
        """
        if self.vector_database.get_by_id(doc_id):
            self.vector_database.delete(doc_id)
        return self.add_document(new_content)

2. 减少幻觉

python
class HallucinationReducedRAG(RAGSystem):
    """Attaches a confidence score and declines to answer without sources."""

    def query(self, question: str, top_k: int = 5) -> Dict:
        result = super().query(question, top_k)

        # No supporting documents: refuse rather than let the LLM guess.
        if not result["sources"]:
            return {
                "question": question,
                "answer": "I don't have enough information to answer this question.",
                "sources": [],
                "confidence": "low",
            }

        score = self._calculate_confidence(result)
        result["confidence"] = score
        if score < 0.5:
            result["answer"] += "\n\nNote: This answer is based on limited information."
        return result

    def _calculate_confidence(self, result: Dict) -> float:
        """Mean retrieval similarity over sources; 0.0 when there are none.

        NOTE(review): in the base class, sources contain only document
        metadata, so a "similarity" key may always be absent (score 0.0) —
        confirm what the retriever actually attaches.
        """
        sources = result["sources"]
        if not sources:
            return 0.0
        return np.mean([src.get("similarity", 0.0) for src in sources])

3. 可追溯性

python
class TraceableRAG(RAGSystem):
    """RAG variant that records per-phase timings in a "trace" entry.

    Bug fix vs. the original: `_measure_retrieval_time` and
    `_measure_generation_time` started and stopped a timer back-to-back
    and therefore always reported ~0. The base pipeline is re-run step by
    step here (mirroring `RAGSystem.query`) so each phase can actually be
    timed around the work it performs.
    """

    def query(self, question: str, top_k: int = 5) -> Dict:
        """Answer *question* and attach timing/usage metadata under "trace"."""
        # --- retrieval phase: embed query, vector search, rerank ---
        start = time.perf_counter()
        query_embedding = self.embedding_generator.generate(question)
        retrieved_docs = self.retriever.retrieve(query_embedding, top_k=top_k * 2)
        reranked_docs = self.reranker.rerank(question, retrieved_docs)
        top_docs = reranked_docs[:top_k]
        retrieval_time = time.perf_counter() - start

        # --- generation phase: prompt construction + LLM call ---
        start = time.perf_counter()
        answer = self._generate_answer(question, top_docs)
        generation_time = time.perf_counter() - start

        return {
            "question": question,
            "answer": answer,
            "sources": [doc["metadata"] for doc in top_docs],
            "retrieved_count": len(retrieved_docs),
            "trace": {
                "retrieval_time": retrieval_time,
                "generation_time": generation_time,
                "documents_used": len(top_docs),
                "answer_length": len(answer),
            },
        }

RAG的局限

1. 检索质量依赖

python
class QualityAwareRAG(RAGSystem):
    """Warns the caller when the retrieved evidence looks weak."""

    def query(self, question: str, top_k: int = 5) -> Dict:
        result = super().query(question, top_k)
        quality = self._assess_retrieval_quality(result)
        if quality < 0.3:
            result["warning"] = "Low quality retrieval detected. Results may be inaccurate."
        return result

    def _assess_retrieval_quality(self, result: Dict) -> float:
        """Mean similarity of the returned sources (0.0 if none).

        NOTE(review): like the base class, this reads a "similarity" key
        from source metadata dicts — verify the retriever sets it.
        """
        sources = result["sources"]
        if not sources:
            return 0.0
        return np.mean([src.get("similarity", 0.0) for src in sources])

2. 上下文窗口限制

python
class ContextWindowRAG(RAGSystem):
    """RAG variant that caps prompt context at a character budget."""

    def __init__(self, config: Dict):
        super().__init__(config)
        # Character budget for retrieved context (default 4000).
        self.max_context_length = config.get("max_context_length", 4000)

    def _build_context(self, documents: List[Dict]) -> str:
        """Concatenate document contents without exceeding the budget.

        Stops before the first document that would overflow. Bug fix vs.
        the original: the two-character "\\n\\n" separators now count
        toward the budget, so the returned string never exceeds
        ``max_context_length``.
        """
        separator = "\n\n"
        parts: List[str] = []
        current_length = 0

        for doc in documents:
            content = doc["content"]
            # First part needs no separator in front of it.
            extra = len(content) + (len(separator) if parts else 0)
            if current_length + extra > self.max_context_length:
                break
            parts.append(content)
            current_length += extra

        return separator.join(parts)

3. 计算开销

python
class CostOptimizedRAG(RAGSystem):
    """Annotates each query result with rough latency/token/cost figures."""

    # Approximate USD price per (whitespace-split) token.
    _COST_PER_TOKEN = 0.00002

    def query(self, question: str, top_k: int = 5) -> Dict:
        """Answer *question* and attach a "performance" summary.

        Bug fix vs. the original: `time` was used without being imported;
        `perf_counter` (a monotonic clock, immune to wall-clock jumps) is
        used for the measurement.
        """
        start = time.perf_counter()
        result = super().query(question, top_k)
        elapsed = time.perf_counter() - start

        result["performance"] = {
            "total_time": elapsed,
            "tokens_used": self._estimate_tokens(result),
            "cost": self._estimate_cost(result),
        }
        return result

    def _estimate_tokens(self, result: Dict) -> int:
        """Crude token estimate: whitespace-split words in answer + sources.

        The base class returns metadata-only dicts as sources, which may
        lack a "content" key — treat a missing key as zero tokens instead
        of raising KeyError (fix vs. the original `doc["content"]`).
        """
        answer_tokens = len(result["answer"].split())
        source_tokens = sum(
            len(doc.get("content", "").split()) for doc in result["sources"]
        )
        return answer_tokens + source_tokens

    def _estimate_cost(self, result: Dict) -> float:
        """Estimated USD cost of the query at a flat per-token rate."""
        return self._estimate_tokens(result) * self._COST_PER_TOKEN

RAG应用场景

1. 企业知识库

python
class EnterpriseKnowledgeBase(RAGSystem):
    """RAG over an internal knowledge base with per-user source filtering.

    NOTE(review): permissions are applied *after* the answer is generated,
    so the answer text may already incorporate restricted documents; only
    the returned source list is filtered (and the answer is replaced only
    when *every* source is filtered out). Confirm whether access control
    should instead happen before retrieval/generation.
    """

    def __init__(self, config: Dict):
        super().__init__(config)
        self.access_control = AccessControl()

    def query(self, question: str, user_id: str,
              top_k: int = 5) -> Dict:
        """Answer *question*, restricting cited sources to *user_id*'s rights."""
        user_permissions = self.access_control.get_permissions(user_id)
        result = super().query(question, top_k)

        allowed = [
            src for src in result["sources"]
            if self._has_access(src, user_permissions)
        ]
        result["sources"] = allowed
        if not allowed:
            result["answer"] = "You don't have permission to access this information."
        return result

    def _has_access(self, doc: Dict,
                    permissions: List[str]) -> bool:
        """True when the user holds any permission listed on the document."""
        required = doc.get("metadata", {}).get("permissions", [])
        return any(perm in required for perm in permissions)

2. 客户服务

python
class CustomerServiceRAG(RAGSystem):
    """Session-aware RAG: folds recent chat history into each query."""

    def __init__(self, config: Dict):
        super().__init__(config)
        self.conversation_history = ConversationHistory()

    def query(self, question: str, session_id: str,
              top_k: int = 5) -> Dict:
        """Answer *question* with the session's recent turns as extra context."""
        past = self.conversation_history.get_history(session_id)
        context = self._build_conversational_context(past)
        enhanced_question = f"{context}\n\nCurrent question: {question}"

        result = super().query(enhanced_question, top_k)

        # Record both turns so follow-up questions see this exchange.
        for message in (
            {"role": "user", "content": question},
            {"role": "assistant", "content": result["answer"]},
        ):
            self.conversation_history.add_message(session_id, message)

        return result

    def _build_conversational_context(self,
                                      history: List[Dict]) -> str:
        """Render the last five turns as "role: content" lines."""
        return "\n".join(
            f"{turn['role']}: {turn['content']}" for turn in history[-5:]
        )

3. 法律文档分析

python
class LegalDocumentRAG(RAGSystem):
    """RAG for legal research: every answer carries numbered citations."""

    def __init__(self, config: Dict):
        super().__init__(config)
        self.citation_manager = CitationManager()

    def query(self, question: str, top_k: int = 5) -> Dict:
        """Answer *question* and append [n]-style citations and references."""
        result = super().query(question, top_k)

        citations = self._extract_citations(result)
        result["citations"] = citations
        result["answer"] = self._format_answer_with_citations(
            result["answer"], citations
        )
        return result

    def _extract_citations(self, result: Dict) -> List[Dict]:
        """One citation record per retrieved source document."""
        return [
            {
                "document_id": src["metadata"].get("id"),
                "title": src["metadata"].get("title"),
                "page": src["metadata"].get("page"),
                "relevance": src.get("similarity", 0.0),
            }
            for src in result["sources"]
        ]

    def _format_answer_with_citations(self, answer: str,
                                      citations: List[Dict]) -> str:
        """Append any missing [n] markers, then a numbered reference list."""
        formatted = answer
        for number in range(1, len(citations) + 1):
            marker = f"[{number}]"
            # Only add the marker if the LLM did not already cite it inline.
            if marker not in formatted:
                formatted += f" {marker}"

        formatted += "\n\nReferences:\n"
        formatted += "\n".join(
            f"{number}. {cite['title']}, Page {cite['page']}"
            for number, cite in enumerate(citations, start=1)
        )
        return formatted

实践练习

练习1:实现简单的RAG系统

python
class SimpleRAG:
    """Minimal in-memory RAG: brute-force cosine search over all documents.

    Suitable for demos only — each query scores every stored embedding.
    """

    def __init__(self):
        self.documents: List[str] = []
        # Parallel list: embeddings[i] is the vector for documents[i].
        self.embeddings = []
        self.embedding_generator = EmbeddingGenerator()
        self.llm = LLM()

    def add_document(self, document: str) -> None:
        """Store *document* and its embedding for later retrieval."""
        self.documents.append(document)
        self.embeddings.append(self.embedding_generator.generate(document))

    def query(self, question: str) -> str:
        """Answer *question* using the 3 most similar stored documents."""
        query_embedding = self.embedding_generator.generate(question)
        similarities = [
            self._cosine_similarity(query_embedding, doc_embedding)
            for doc_embedding in self.embeddings
        ]

        # np.argsort is ascending; take the tail and reverse so the most
        # relevant document leads the context (fix: the original left the
        # best match last).
        top_indices = np.argsort(similarities)[-3:][::-1]
        top_docs = [self.documents[i] for i in top_indices]

        context = "\n\n".join(top_docs)
        prompt = f"Context:\n{context}\n\nQuestion: {question}\n\nAnswer:"
        return self.llm.generate(prompt)

    def _cosine_similarity(self, vec1, vec2) -> float:
        """Cosine similarity in [-1, 1]; 0.0 when either vector has zero norm.

        Fix vs. the original: a zero-norm vector no longer produces a
        division-by-zero / NaN.
        """
        denom = np.linalg.norm(vec1) * np.linalg.norm(vec2)
        if denom == 0.0:
            return 0.0
        return float(np.dot(vec1, vec2) / denom)

练习2:实现带重排序的RAG

python
class RerankedRAG(SimpleRAG):
    """Two-stage retrieval: coarse cosine top-10, then reranker picks top-3."""

    def __init__(self):
        super().__init__()
        self.reranker = Reranker()

    def query(self, question: str) -> str:
        """Answer *question* using reranked candidates as context."""
        query_embedding = self.embedding_generator.generate(question)
        scores = [
            self._cosine_similarity(query_embedding, embedding)
            for embedding in self.embeddings
        ]

        # Coarse pass: 10 candidates by raw cosine similarity.
        candidate_docs = [self.documents[i] for i in np.argsort(scores)[-10:]]

        # Fine pass: reranker orders the candidates; keep the best 3.
        reranked = self.reranker.rerank(question, candidate_docs)
        context = "\n\n".join(reranked[:3])

        prompt = f"Context:\n{context}\n\nQuestion: {question}\n\nAnswer:"
        return self.llm.generate(prompt)

总结

本节我们学习了RAG的原理与架构:

  1. RAG的基本概念和工作流程
  2. RAG vs Fine-tuning的区别
  3. RAG系统架构(基础和高级)
  4. RAG的优势(实时更新、减少幻觉、可追溯性)
  5. RAG的局限(检索质量、上下文窗口、计算开销)
  6. RAG应用场景(企业知识库、客户服务、法律文档)

理解这些基础概念为后续深入学习向量数据库、文档处理、检索优化等技术打下了坚实的基础。

参考资源