# Day 45: RAG Principles and Architecture
## Learning Objectives

- Understand the basic concept of RAG
- Grasp the differences between RAG and fine-tuning
- Learn the architecture of a RAG system
- Understand the strengths and limitations of RAG
- Understand typical RAG application scenarios
## Basic Concepts of RAG

### What is RAG?

RAG (Retrieval-Augmented Generation) is an AI technique that combines retrieval with generation. It retrieves relevant information from an external knowledge base, then uses a large language model to generate an answer grounded in the retrieved content.

Core idea:

User query → retrieve relevant documents → generate an answer grounded in them → return the result

Workflow (a minimal code sketch follows the list):

- Indexing stage: split documents into chunks, embed them, and store the vectors in a vector database
- Retrieval stage: retrieve the document chunks most relevant to the user query
- Generation stage: feed the query and the retrieved chunks to the LLM to generate the answer
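
As a rough illustration of these three stages, here is a self-contained toy sketch. The hashing-trick embedding, the in-memory list standing in for a vector database, and the prompt format are all assumptions for demonstration, not a real model or library API:

```python
import numpy as np

def embed(text: str) -> np.ndarray:
    # Stand-in embedding using a hashing trick; a real system would call an
    # embedding model here.
    vec = np.zeros(64)
    for word in text.lower().split():
        vec[hash(word) % 64] += 1.0
    norm = np.linalg.norm(vec)
    return vec / norm if norm else vec

store = []  # (vector, chunk) pairs; a real system would use a vector database

def index(docs, chunk_size=200):
    # Indexing stage: chunk, embed, store.
    for doc in docs:
        for i in range(0, len(doc), chunk_size):
            chunk = doc[i:i + chunk_size]
            store.append((embed(chunk), chunk))

def retrieve(question, top_k=3):
    # Retrieval stage: embed the query and rank chunks by cosine similarity
    # (the vectors are normalized, so the dot product is the cosine).
    q = embed(question)
    ranked = sorted(store, key=lambda item: float(np.dot(q, item[0])), reverse=True)
    return [chunk for _, chunk in ranked[:top_k]]

def build_prompt(question):
    # Generation stage: assemble a grounded prompt to send to an LLM.
    context = "\n\n".join(retrieve(question))
    return f"Context:\n{context}\n\nQuestion: {question}\n\nAnswer:"
```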
## RAG vs Fine-tuning

| Aspect | RAG | Fine-tuning |
|---|---|---|
| Knowledge updates | Real-time | Requires retraining |
| Data requirements | Small | Large |
| Training cost | Low | High |
| Inference speed | Slower | Faster |
| Knowledge accuracy | High | Medium |
| Hallucination risk | Low | Medium |
| Typical use | Dynamic knowledge | Specialized tasks |
Selection advice (summarized as a toy helper after the list):

- Use RAG when knowledge changes frequently, high factual accuracy matters, or labeled data is scarce
- Use fine-tuning for specialized tasks that need fast inference and have plenty of training data
- Combine the two: RAG + fine-tuning often gives the best of both
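
Very roughly, the advice above can be expressed as a small decision helper; the criteria names are invented for illustration:

```python
def choose_approach(knowledge_changes_often: bool,
                    needs_high_accuracy: bool,
                    has_large_dataset: bool,
                    needs_fast_inference: bool) -> str:
    # Toy heuristic mirroring the selection advice above.
    if knowledge_changes_often or needs_high_accuracy:
        return "RAG + Fine-tuning" if has_large_dataset else "RAG"
    if has_large_dataset and needs_fast_inference:
        return "Fine-tuning"
    return "RAG"
```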
## RAG System Architecture
### Basic architecture

```
┌─────────────────────────────────────────┐
│           Document collection           │
└───────────────────┬─────────────────────┘
                    │
                    ▼
┌─────────────────────────────────────────┐
│           Document processing           │
│  - Document loading                     │
│  - Text preprocessing                   │
│  - Document chunking                    │
└───────────────────┬─────────────────────┘
                    │
                    ▼
┌─────────────────────────────────────────┐
│              Vectorization              │
│  - Embedding generation                 │
│  - Vector indexing                      │
└───────────────────┬─────────────────────┘
                    │
                    ▼
┌─────────────────────────────────────────┐
│             Vector database             │
└───────────────────┬─────────────────────┘
                    │
                    ▼
┌─────────────────────────────────────────┐
│             Retrieval stage             │
│  - Query embedding                      │
│  - Similarity search                    │
│  - Result filtering                     │
└───────────────────┬─────────────────────┘
                    │
                    ▼
┌─────────────────────────────────────────┐
│             Generation stage            │
│  - Prompt construction                  │
│  - LLM generation                       │
│  - Answer post-processing               │
└─────────────────────────────────────────┘
```

### Advanced architecture
```python
from typing import List, Dict, Optional
import numpy as np

# DocumentProcessor, EmbeddingGenerator, VectorDatabase, Retriever, Reranker,
# LLM, and PromptTemplate are pluggable components assumed to be implemented
# elsewhere.

class RAGSystem:
    def __init__(self, config: Dict):
        self.document_processor = DocumentProcessor()
        self.embedding_generator = EmbeddingGenerator(config["embedding_model"])
        self.vector_database = VectorDatabase(config["vector_db"])
        self.retriever = Retriever(self.vector_database, self.embedding_generator)
        self.reranker = Reranker()
        self.llm = LLM(config["llm_model"])
        self.prompt_template = PromptTemplate()

    def index_documents(self, documents: List[str]) -> Dict:
        # Chunk and clean the raw documents, then embed and store each chunk.
        processed_docs = self.document_processor.process(documents)
        embeddings = []
        for doc in processed_docs:
            embedding = self.embedding_generator.generate(doc["content"])
            embeddings.append(embedding)
        self.vector_database.add_batch(embeddings, processed_docs)
        return {
            "status": "success",
            "indexed_count": len(processed_docs)
        }

    def query(self, question: str, top_k: int = 5) -> Dict:
        # Over-retrieve (top_k * 2), rerank, then keep the best top_k.
        query_embedding = self.embedding_generator.generate(question)
        retrieved_docs = self.retriever.retrieve(
            query_embedding,
            top_k=top_k * 2
        )
        reranked_docs = self.reranker.rerank(question, retrieved_docs)
        top_docs = reranked_docs[:top_k]
        answer = self._generate_answer(question, top_docs)
        return {
            "question": question,
            "answer": answer,
            "sources": [doc["metadata"] for doc in top_docs],
            "retrieved_count": len(retrieved_docs)
        }

    def _generate_answer(self, question: str,
                         documents: List[Dict]) -> str:
        context = self._build_context(documents)
        prompt = self.prompt_template.format(
            question=question,
            context=context
        )
        return self.llm.generate(prompt)

    def _build_context(self, documents: List[Dict]) -> str:
        return "\n\n".join([
            f"Document {i+1}:\n{doc['content']}"
            for i, doc in enumerate(documents)
        ])
```
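
A possible usage sketch for the class above; the config keys mirror what `RAGSystem.__init__` reads, and the values are placeholders, not recommendations:

```python
# Hypothetical configuration; only the keys matter here.
config = {
    "embedding_model": "some-embedding-model",          # placeholder name
    "vector_db": {"host": "localhost", "port": 6333},   # placeholder settings
    "llm_model": "some-llm",                            # placeholder name
}

rag = RAGSystem(config)
rag.index_documents([
    "RAG retrieves documents before generating an answer.",
    "Fine-tuning updates model weights on task-specific data.",
])
result = rag.query("How does RAG differ from fine-tuning?", top_k=3)
print(result["answer"])
print(result["sources"])
```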
## Advantages of RAG

### 1. Real-time knowledge updates
```python
class RealTimeRAG(RAGSystem):
    def add_document(self, document: str) -> Dict:
        # New knowledge becomes searchable immediately; no retraining needed.
        processed_doc = self.document_processor.process([document])[0]
        embedding = self.embedding_generator.generate(processed_doc["content"])
        self.vector_database.add(embedding, processed_doc)
        return {
            "status": "success",
            "message": "Document added successfully"
        }

    def update_document(self, doc_id: str,
                        new_content: str) -> Dict:
        # Replace an existing document: delete the old vector, index the new text.
        old_doc = self.vector_database.get_by_id(doc_id)
        if old_doc:
            self.vector_database.delete(doc_id)
            return self.add_document(new_content)
        return {"status": "error", "message": f"Document {doc_id} not found"}
```
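
A quick usage sketch (the document ID is a made-up value and `config` is the hypothetical configuration from earlier):

```python
rt_rag = RealTimeRAG(config)
rt_rag.add_document("The refund window was extended to 45 days in 2024.")
rt_rag.update_document("doc-42", "The refund window is now 60 days.")  # hypothetical ID
```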
### 2. Reduced hallucinations

```python
class HallucinationReducedRAG(RAGSystem):
    def query(self, question: str, top_k: int = 5) -> Dict:
        result = super().query(question, top_k)
        # Refuse to answer rather than hallucinate when nothing was retrieved.
        if not result["sources"]:
            return {
                "question": question,
                "answer": "I don't have enough information to answer this question.",
                "sources": [],
                "confidence": "low"
            }
        confidence = self._calculate_confidence(result)
        result["confidence"] = confidence
        if confidence < 0.5:
            result["answer"] += "\n\nNote: This answer is based on limited information."
        return result

    def _calculate_confidence(self, result: Dict) -> float:
        # Assumes the retriever attaches a "similarity" score to each source;
        # sources without one count as 0.0.
        if not result["sources"]:
            return 0.0
        return float(np.mean([
            doc.get("similarity", 0.0)
            for doc in result["sources"]
        ]))
```
### 3. Traceability

```python
import time

class TraceableRAG(RAGSystem):
    def query(self, question: str, top_k: int = 5) -> Dict:
        # Run the stages here (instead of delegating to super().query())
        # so that retrieval and generation can be timed separately.
        t0 = time.time()
        query_embedding = self.embedding_generator.generate(question)
        retrieved_docs = self.retriever.retrieve(query_embedding, top_k=top_k * 2)
        top_docs = self.reranker.rerank(question, retrieved_docs)[:top_k]
        retrieval_time = time.time() - t0

        t1 = time.time()
        answer = self._generate_answer(question, top_docs)
        generation_time = time.time() - t1

        sources = [doc["metadata"] for doc in top_docs]
        return {
            "question": question,
            "answer": answer,
            "sources": sources,
            "retrieved_count": len(retrieved_docs),
            "trace": {
                "retrieval_time": retrieval_time,
                "generation_time": generation_time,
                "documents_used": len(sources),
                "answer_length": len(answer)
            }
        }
```
## Limitations of RAG

### 1. Dependence on retrieval quality
```python
class QualityAwareRAG(RAGSystem):
    def query(self, question: str, top_k: int = 5) -> Dict:
        result = super().query(question, top_k)
        # RAG is only as good as its retrieval: flag low-similarity results.
        retrieval_quality = self._assess_retrieval_quality(result)
        if retrieval_quality < 0.3:
            result["warning"] = "Low quality retrieval detected. Results may be inaccurate."
        return result

    def _assess_retrieval_quality(self, result: Dict) -> float:
        # Same assumption as before: sources carry a "similarity" score.
        if not result["sources"]:
            return 0.0
        similarities = [
            doc.get("similarity", 0.0)
            for doc in result["sources"]
        ]
        return float(np.mean(similarities))
```
### 2. Context window limits

```python
class ContextWindowRAG(RAGSystem):
    def __init__(self, config: Dict):
        super().__init__(config)
        self.max_context_length = config.get("max_context_length", 4000)

    def _build_context(self, documents: List[Dict]) -> str:
        # Greedily pack documents until the budget is exhausted. Character
        # count is used here as a cheap proxy for token count.
        context_parts = []
        current_length = 0
        for doc in documents:
            doc_length = len(doc["content"])
            if current_length + doc_length > self.max_context_length:
                break
            context_parts.append(doc["content"])
            current_length += doc_length
        return "\n\n".join(context_parts)
```
### 3. Computational overhead

```python
class CostOptimizedRAG(RAGSystem):
    def query(self, question: str, top_k: int = 5) -> Dict:
        start_time = time.time()
        result = super().query(question, top_k)
        end_time = time.time()
        result["performance"] = {
            "total_time": end_time - start_time,
            "tokens_used": self._estimate_tokens(result),
            "cost": self._estimate_cost(result)
        }
        return result

    def _estimate_tokens(self, result: Dict) -> int:
        # Rough estimate via whitespace word counts; sources that carry only
        # metadata (no "content" field) contribute nothing.
        answer_tokens = len(result["answer"].split())
        source_tokens = sum(
            len(doc.get("content", "").split())
            for doc in result["sources"]
        )
        return answer_tokens + source_tokens

    def _estimate_cost(self, result: Dict) -> float:
        # Illustrative flat price per token; real pricing varies by model.
        tokens = self._estimate_tokens(result)
        return tokens * 0.00002
```
## RAG Application Scenarios

### 1. Enterprise knowledge base
```python
class EnterpriseKnowledgeBase(RAGSystem):
    def __init__(self, config: Dict):
        super().__init__(config)
        self.access_control = AccessControl()  # assumed permission service

    def query(self, question: str, user_id: str,
              top_k: int = 5) -> Dict:
        user_permissions = self.access_control.get_permissions(user_id)
        result = super().query(question, top_k)
        # Note: filtering after generation keeps this example simple, but a
        # production system should filter documents *before* generation so
        # the answer cannot leak restricted content.
        filtered_sources = [
            doc for doc in result["sources"]
            if self._has_access(doc, user_permissions)
        ]
        result["sources"] = filtered_sources
        if not filtered_sources:
            result["answer"] = "You don't have permission to access this information."
        return result

    def _has_access(self, doc: Dict,
                    permissions: List[str]) -> bool:
        # Sources are metadata dicts (see RAGSystem.query).
        doc_permissions = doc.get("permissions", [])
        return any(perm in doc_permissions for perm in permissions)
```
### 2. Customer service

```python
class CustomerServiceRAG(RAGSystem):
    def __init__(self, config: Dict):
        super().__init__(config)
        self.conversation_history = ConversationHistory()  # assumed store

    def query(self, question: str, session_id: str,
              top_k: int = 5) -> Dict:
        # Prepend recent conversation turns so retrieval and generation see
        # the dialogue context, then record both sides of the exchange.
        history = self.conversation_history.get_history(session_id)
        context = self._build_conversational_context(history)
        enhanced_question = f"{context}\n\nCurrent question: {question}"
        result = super().query(enhanced_question, top_k)
        self.conversation_history.add_message(
            session_id,
            {"role": "user", "content": question}
        )
        self.conversation_history.add_message(
            session_id,
            {"role": "assistant", "content": result["answer"]}
        )
        return result

    def _build_conversational_context(self,
                                      history: List[Dict]) -> str:
        # Keep only the last five turns to bound prompt length.
        recent_history = history[-5:]
        return "\n".join([
            f"{msg['role']}: {msg['content']}"
            for msg in recent_history
        ])
```
### 3. Legal document analysis

```python
class LegalDocumentRAG(RAGSystem):
    def __init__(self, config: Dict):
        super().__init__(config)
        self.citation_manager = CitationManager()  # assumed component

    def query(self, question: str, top_k: int = 5) -> Dict:
        result = super().query(question, top_k)
        citations = self._extract_citations(result)
        result["citations"] = citations
        result["answer"] = self._format_answer_with_citations(
            result["answer"],
            citations
        )
        return result

    def _extract_citations(self, result: Dict) -> List[Dict]:
        # Sources are metadata dicts (see RAGSystem.query).
        citations = []
        for doc in result["sources"]:
            citation = {
                "document_id": doc.get("id"),
                "title": doc.get("title"),
                "page": doc.get("page"),
                "relevance": doc.get("similarity", 0.0)
            }
            citations.append(citation)
        return citations

    def _format_answer_with_citations(self, answer: str,
                                      citations: List[Dict]) -> str:
        formatted_answer = answer
        # Append a numbered marker for any citation not already referenced.
        for i, citation in enumerate(citations):
            citation_text = f"[{i+1}]"
            if citation_text not in formatted_answer:
                formatted_answer += f" {citation_text}"
        formatted_answer += "\n\nReferences:\n"
        formatted_answer += "\n".join([
            f"{i+1}. {cit['title']}, Page {cit['page']}"
            for i, cit in enumerate(citations)
        ])
        return formatted_answer
```
## Hands-on Exercises

### Exercise 1: Implement a simple RAG system
```python
class SimpleRAG:
    def __init__(self):
        self.documents = []
        self.embeddings = []
        self.embedding_generator = EmbeddingGenerator()  # assumed component
        self.llm = LLM()  # assumed component

    def add_document(self, document: str):
        self.documents.append(document)
        embedding = self.embedding_generator.generate(document)
        self.embeddings.append(embedding)

    def query(self, question: str) -> str:
        query_embedding = self.embedding_generator.generate(question)
        similarities = [
            self._cosine_similarity(query_embedding, doc_embedding)
            for doc_embedding in self.embeddings
        ]
        # argsort is ascending, so take the last three and reverse them
        # to get best-first order.
        top_indices = np.argsort(similarities)[-3:][::-1]
        top_docs = [self.documents[i] for i in top_indices]
        context = "\n\n".join(top_docs)
        prompt = f"Context:\n{context}\n\nQuestion: {question}\n\nAnswer:"
        return self.llm.generate(prompt)

    def _cosine_similarity(self, vec1, vec2):
        return np.dot(vec1, vec2) / (
            np.linalg.norm(vec1) * np.linalg.norm(vec2)
        )
```
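
To smoke-test the exercise without real models, `EmbeddingGenerator` and `LLM` can be stubbed out; the letter-frequency embedding below is a toy stand-in for an embedding model:

```python
class EmbeddingGenerator:
    def generate(self, text: str) -> np.ndarray:
        # Toy embedding: normalized letter-frequency vector.
        vec = np.zeros(26)
        for ch in text.lower():
            if "a" <= ch <= "z":
                vec[ord(ch) - ord("a")] += 1.0
        norm = np.linalg.norm(vec)
        return vec / norm if norm else vec

class LLM:
    def generate(self, prompt: str) -> str:
        return f"(stub answer; prompt was {len(prompt)} characters)"

rag = SimpleRAG()
rag.add_document("RAG retrieves documents before generating an answer.")
rag.add_document("Fine-tuning updates model weights on task data.")
print(rag.query("How does RAG work?"))
```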
### Exercise 2: Implement RAG with reranking

```python
class RerankedRAG(SimpleRAG):
    def __init__(self):
        super().__init__()
        self.reranker = Reranker()  # assumed component

    def query(self, question: str) -> str:
        query_embedding = self.embedding_generator.generate(question)
        similarities = [
            self._cosine_similarity(query_embedding, doc_embedding)
            for doc_embedding in self.embeddings
        ]
        # First-stage retrieval casts a wide net (top 10); the reranker
        # then picks the best three for the prompt.
        top_indices = np.argsort(similarities)[-10:][::-1]
        top_docs = [self.documents[i] for i in top_indices]
        reranked_docs = self.reranker.rerank(question, top_docs)
        context = "\n\n".join(reranked_docs[:3])
        prompt = f"Context:\n{context}\n\nQuestion: {question}\n\nAnswer:"
        return self.llm.generate(prompt)
```
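
Exercise 2 assumes a `Reranker` component. As a stand-in, a toy reranker can score documents by query-term overlap (here documents are plain strings, matching the exercise; a real system would use a cross-encoder model instead):

```python
class Reranker:
    def rerank(self, question: str, documents: list) -> list:
        # Toy scoring: count terms shared between the query and each document.
        q_terms = set(question.lower().split())
        return sorted(
            documents,
            key=lambda doc: len(q_terms & set(doc.lower().split())),
            reverse=True,
        )
```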
## Summary

In this section we covered the principles and architecture of RAG:

- The basic concept of RAG and its workflow
- The differences between RAG and fine-tuning
- RAG system architecture (basic and advanced)
- Strengths of RAG (real-time updates, fewer hallucinations, traceability)
- Limitations of RAG (retrieval quality, context window, computational overhead)
- RAG application scenarios (enterprise knowledge bases, customer service, legal documents)

A solid grasp of these fundamentals lays the groundwork for the upcoming topics: vector databases, document processing, retrieval optimization, and more.
