Day 50: Advanced RAG Techniques
Learning Objectives
- Master multi-turn conversational RAG
- Learn knowledge graph RAG
- Understand multimodal RAG
- Master combining RAG with agents
- Get familiar with RAG performance optimization
Multi-Turn Conversational RAG
Conversation History Management
```python
from typing import List, Dict
from datetime import datetime

class ConversationHistory:
    def __init__(self, max_history: int = 10):
        self.history = []
        self.max_history = max_history

    def add_message(self, role: str, content: str):
        message = {
            "role": role,
            "content": content,
            "timestamp": datetime.now().isoformat()
        }
        self.history.append(message)
        if len(self.history) > self.max_history:
            self.history = self.history[-self.max_history:]

    def get_history(self, n: int = None) -> List[Dict]:
        if n is None:
            return self.history
        return self.history[-n:]

    def clear(self):
        self.history = []

    def get_context(self) -> str:
        context_parts = []
        for msg in self.history:
            role = msg["role"]
            content = msg["content"]
            context_parts.append(f"{role}: {content}")
        return "\n".join(context_parts)
```
Context Compression

```python
class ContextCompressor:
    def __init__(self, max_tokens: int = 2000):
        self.max_tokens = max_tokens

    def compress(self, history: List[Dict]) -> str:
        # Walk the history from newest to oldest, keeping as many of the
        # most recent messages as fit within the token budget.
        compressed_parts = []
        current_tokens = 0
        for msg in reversed(history):
            content = msg["content"]
            tokens = self._estimate_tokens(content)
            if current_tokens + tokens > self.max_tokens:
                break
            compressed_parts.insert(0, content)
            current_tokens += tokens
        return "\n".join(compressed_parts)

    def _estimate_tokens(self, text: str) -> int:
        # Rough whitespace-based estimate; see the note below for a
        # tokenizer-backed alternative.
        return len(text.split())
```
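The whitespace-based `_estimate_tokens` is only a rough proxy: it undercounts for subword tokenizers and breaks down entirely for languages written without spaces. If the `tiktoken` package is available, a tokenizer-backed estimator can be swapped in; this is an optional sketch, not part of the original design:

```python
def estimate_tokens_tiktoken(text: str, encoding_name: str = "cl100k_base") -> int:
    """Token count via tiktoken, falling back to a whitespace split
    if the package is not installed."""
    try:
        import tiktoken
        encoding = tiktoken.get_encoding(encoding_name)
        return len(encoding.encode(text))
    except ImportError:
        return len(text.split())
```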
Multi-Turn RAG Retrieval

```python
class ConversationalRAG:
    def __init__(self, retriever, llm):
        self.retriever = retriever
        self.llm = llm
        self.conversation_history = ConversationHistory()
        self.context_compressor = ContextCompressor()

    def query(self, question: str) -> Dict:
        self.conversation_history.add_message("user", question)
        context = self._build_context()
        enhanced_query = self._enhance_query(question, context)
        retrieved_docs = self.retriever.retrieve(enhanced_query, top_k=5)
        answer = self._generate_answer(question, retrieved_docs, context)
        self.conversation_history.add_message("assistant", answer)
        return {
            "question": question,
            "answer": answer,
            "sources": [doc["metadata"] for doc in retrieved_docs]
        }

    def _build_context(self) -> str:
        history = self.conversation_history.get_history()
        return self.context_compressor.compress(history)

    def _enhance_query(self, question: str, context: str) -> str:
        if not context:
            return question
        prompt = f"""
Previous conversation:
{context}

Current question: {question}

Rewrite the question to include relevant context from the conversation.
"""
        enhanced_query = self.llm.generate(prompt)
        return enhanced_query

    def _generate_answer(self, question: str,
                         retrieved_docs: List[Dict],
                         context: str) -> str:
        docs_context = "\n\n".join([
            doc["content"] for doc in retrieved_docs
        ])
        prompt = f"""
Context from previous conversation:
{context}

Retrieved documents:
{docs_context}

Question: {question}

Answer the question based on the retrieved documents and conversation context.
"""
        answer = self.llm.generate(prompt)
        return answer
```
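To trace `ConversationalRAG` end to end without a real vector store or LLM, minimal stand-ins are enough. The stub classes below are hypothetical and exist only so the example runs:

```python
class EchoLLM:
    """Stub LLM: returns a canned string so the pipeline can be traced."""
    def generate(self, prompt: str) -> str:
        return f"(answer derived from a {len(prompt)}-char prompt)"

class StaticRetriever:
    """Stub retriever: always returns the same document."""
    def retrieve(self, query: str, top_k: int = 5):
        return [{"content": "RAG combines retrieval with generation.",
                 "metadata": {"source": "notes.md"}}]

rag = ConversationalRAG(StaticRetriever(), EchoLLM())
result = rag.query("What is RAG?")
print(result["answer"])   # stub answer
print(result["sources"])  # [{'source': 'notes.md'}]
```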
Knowledge Graph RAG
Knowledge Graph Construction
```python
from typing import List, Dict

class KnowledgeGraph:
    def __init__(self):
        self.entities = {}
        self.relations = []

    def add_entity(self, entity_id: str, entity_type: str,
                   attributes: Dict):
        self.entities[entity_id] = {
            "type": entity_type,
            "attributes": attributes
        }

    def add_relation(self, source: str, relation: str, target: str):
        self.relations.append({
            "source": source,
            "relation": relation,
            "target": target
        })

    def get_neighbors(self, entity_id: str) -> List[Dict]:
        neighbors = []
        for rel in self.relations:
            if rel["source"] == entity_id:
                neighbors.append({
                    "entity": rel["target"],
                    "relation": rel["relation"]
                })
            elif rel["target"] == entity_id:
                neighbors.append({
                    "entity": rel["source"],
                    "relation": rel["relation"]
                })
        return neighbors

    def get_entity(self, entity_id: str) -> Dict:
        return self.entities.get(entity_id, {})

    def find_entities_by_type(self, entity_type: str) -> List[str]:
        return [
            entity_id
            for entity_id, entity in self.entities.items()
            if entity["type"] == entity_type
        ]
```
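A toy graph makes the API concrete; the entity names here are illustrative:

```python
kg = KnowledgeGraph()
kg.add_entity("Python", "Language", {"paradigm": "multi-paradigm"})
kg.add_entity("Django", "Framework", {"language": "Python"})
kg.add_relation("Django", "written_in", "Python")

print(kg.get_neighbors("Python"))
# [{'entity': 'Django', 'relation': 'written_in'}]
print(kg.find_entities_by_type("Framework"))
# ['Django']
```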
Knowledge Graph Retrieval

```python
class GraphRetriever:
    def __init__(self, knowledge_graph: KnowledgeGraph):
        self.kg = knowledge_graph

    def retrieve(self, query: str, top_k: int = 10) -> List[Dict]:
        entities = self._extract_entities(query)
        results = []
        for entity in entities:
            entity_info = self.kg.get_entity(entity)
            neighbors = self.kg.get_neighbors(entity)
            results.append({
                "entity": entity,
                "info": entity_info,
                "neighbors": neighbors
            })
        return results[:top_k]

    def _extract_entities(self, query: str) -> List[str]:
        # Naive entity linking: case-insensitive substring match against
        # known entity ids. Replace with proper NER/linking in practice.
        entities = []
        for entity_id in self.kg.entities:
            if entity_id.lower() in query.lower():
                entities.append(entity_id)
        return entities

    def retrieve_with_path(self, start_entity: str,
                           end_entity: str,
                           max_depth: int = 3) -> List[List[str]]:
        paths = []
        self._find_paths(start_entity, end_entity, [], max_depth, paths)
        return paths

    def _find_paths(self, current: str, target: str,
                    path: List[str], depth: int,
                    paths: List[List[str]]):
        # Depth-limited DFS that avoids revisiting entities on the current path.
        if depth == 0:
            return
        path = path + [current]
        if current == target:
            paths.append(path)
            return
        for neighbor in self.kg.get_neighbors(current):
            entity = neighbor["entity"]
            if entity not in path:
                self._find_paths(entity, target, path, depth - 1, paths)
```
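`retrieve_with_path` is a depth-limited DFS, and because `get_neighbors` treats relations as undirected, paths can traverse edges in either direction. A small check:

```python
# Three entities in a chain: A -> B -> C.
kg = KnowledgeGraph()
for name in ["A", "B", "C"]:
    kg.add_entity(name, "Node", {})
kg.add_relation("A", "links_to", "B")
kg.add_relation("B", "links_to", "C")

retriever = GraphRetriever(kg)
print(retriever.retrieve_with_path("A", "C", max_depth=3))
# [['A', 'B', 'C']]
```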
Knowledge Graph RAG

```python
class KnowledgeGraphRAG:
    def __init__(self, graph_retriever, vector_retriever, llm):
        self.graph_retriever = graph_retriever
        self.vector_retriever = vector_retriever
        self.llm = llm

    def query(self, question: str) -> Dict:
        graph_results = self.graph_retriever.retrieve(question, top_k=5)
        vector_results = self.vector_retriever.retrieve(question, top_k=5)
        combined_context = self._combine_contexts(graph_results, vector_results)
        answer = self._generate_answer(question, combined_context)
        return {
            "question": question,
            "answer": answer,
            "graph_sources": graph_results,
            "vector_sources": vector_results
        }

    def _combine_contexts(self, graph_results: List[Dict],
                          vector_results: List[Dict]) -> str:
        context_parts = []
        context_parts.append("Knowledge Graph:")
        for result in graph_results:
            context_parts.append(f"- {result['entity']}: {result['info']}")
        context_parts.append("\nVector Search:")
        for result in vector_results:
            context_parts.append(f"- {result['content'][:200]}")
        return "\n".join(context_parts)

    def _generate_answer(self, question: str, context: str) -> str:
        prompt = f"""
Context:
{context}

Question: {question}

Answer the question based on the context.
"""
        answer = self.llm.generate(prompt)
        return answer
```
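Wiring the hybrid pipeline together with a toy graph and stub components (both stub classes are hypothetical, for illustration only):

```python
class TinyVectorRetriever:
    """Stub vector retriever returning one fixed passage."""
    def retrieve(self, query: str, top_k: int = 5):
        return [{"content": "Django is a Python web framework.", "score": 0.9}]

class TinyLLM:
    """Stub LLM with a canned answer."""
    def generate(self, prompt: str) -> str:
        return "Django is written in Python."

kg = KnowledgeGraph()
kg.add_entity("Django", "Framework", {"language": "Python"})

graph_rag = KnowledgeGraphRAG(GraphRetriever(kg), TinyVectorRetriever(), TinyLLM())
print(graph_rag.query("What language is Django written in?")["answer"])
```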
Multimodal RAG
Image Embeddings
```python
import numpy as np

class ImageEmbeddingGenerator:
    # Note: the OpenAI CLIP package expects model names like "ViT-B/32";
    # "clip-vit-base-patch32" is the Hugging Face naming for the same model.
    def __init__(self, model_name: str = "ViT-B/32"):
        self.model_name = model_name
        self.dimension = 512

    def generate(self, image_path: str) -> np.ndarray:
        try:
            import torch
            import clip
            from PIL import Image
            device = "cuda" if torch.cuda.is_available() else "cpu"
            # For real workloads, load the model once in __init__ instead
            # of on every call.
            model, preprocess = clip.load(self.model_name, device=device)
            image = Image.open(image_path)
            image_input = preprocess(image).unsqueeze(0).to(device)
            with torch.no_grad():
                image_features = model.encode_image(image_input)
            return image_features.cpu().numpy()[0]
        except ImportError:
            raise ImportError(
                "Install required packages: pip install torch torchvision clip-by-openai"
            )

    def generate_batch(self, image_paths: List[str]) -> np.ndarray:
        embeddings = []
        for image_path in image_paths:
            embedding = self.generate(image_path)
            embeddings.append(embedding)
        return np.array(embeddings)
```
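Once image embeddings exist, retrieval reduces to nearest-neighbor search in CLIP's shared text-image space. Below is a minimal in-memory index using cosine similarity in pure NumPy, with a random-vector toy check; the helper name is ours, not a library API:

```python
import numpy as np

def cosine_top_k(query_vec: np.ndarray, matrix: np.ndarray, k: int = 5) -> list:
    """Indices of the k rows of `matrix` most similar to `query_vec`."""
    q = query_vec / (np.linalg.norm(query_vec) + 1e-12)
    m = matrix / (np.linalg.norm(matrix, axis=1, keepdims=True) + 1e-12)
    scores = m @ q  # cosine similarity per row
    return np.argsort(-scores)[:k].tolist()

rng = np.random.default_rng(0)
index = rng.normal(size=(100, 512))              # stand-ins for CLIP embeddings
query = index[42] + 0.01 * rng.normal(size=512)  # a query near row 42
print(cosine_top_k(query, index, k=3))           # row 42 should rank first
```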
Multimodal Retrieval

```python
class MultiModalRetriever:
    def __init__(self, text_db, image_db,
                 text_embedding_gen, image_embedding_gen):
        self.text_db = text_db
        self.image_db = image_db
        self.text_embedding_gen = text_embedding_gen
        self.image_embedding_gen = image_embedding_gen

    def retrieve_text(self, query: str, top_k: int = 10) -> List[Dict]:
        query_embedding = self.text_embedding_gen.generate(query)
        return self.text_db.search(query_embedding, top_k)

    def retrieve_image(self, query: str, top_k: int = 10) -> List[Dict]:
        # Text-to-image search works because CLIP places text and images
        # in a shared embedding space.
        query_embedding = self.text_embedding_gen.generate(query)
        return self.image_db.search(query_embedding, top_k)

    def retrieve_multi_modal(self, query: str, top_k: int = 10) -> List[Dict]:
        text_results = self.retrieve_text(query, top_k)
        image_results = self.retrieve_image(query, top_k)
        combined = self._combine_results(text_results, image_results)
        return combined[:top_k]

    def _combine_results(self, text_results: List[Dict],
                         image_results: List[Dict]) -> List[Dict]:
        # Discount each score by its rank so higher-ranked results win ties.
        combined = []
        for i, result in enumerate(text_results):
            combined.append({
                "type": "text",
                "content": result["content"],
                "score": result["score"] * (1 - i / len(text_results))
            })
        for i, result in enumerate(image_results):
            combined.append({
                "type": "image",
                "content": result["metadata"]["path"],
                "score": result["score"] * (1 - i / len(image_results))
            })
        combined.sort(key=lambda x: x["score"], reverse=True)
        return combined
```
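The rank-discounted weighting in `_combine_results` is ad hoc and sensitive to score scales across modalities. A common scale-free alternative for merging ranked lists is Reciprocal Rank Fusion (RRF); here is a sketch over the same result shape, keyed on content for simplicity:

```python
from typing import Dict, List

def rrf_merge(ranked_lists: List[List[Dict]], k: int = 60) -> List[Dict]:
    """Reciprocal Rank Fusion: score(doc) = sum over lists of 1 / (k + rank)."""
    scores: Dict[str, float] = {}
    by_key: Dict[str, Dict] = {}
    for results in ranked_lists:
        for rank, item in enumerate(results, start=1):
            key = str(item.get("content"))  # use a stable doc id in real systems
            scores[key] = scores.get(key, 0.0) + 1.0 / (k + rank)
            by_key[key] = item
    return [by_key[key] for key in sorted(scores, key=scores.get, reverse=True)]
```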
Multimodal RAG

```python
class MultiModalRAG:
    def __init__(self, retriever, llm):
        self.retriever = retriever
        self.llm = llm

    def query(self, question: str) -> Dict:
        results = self.retriever.retrieve_multi_modal(question, top_k=10)
        context = self._build_context(results)
        answer = self._generate_answer(question, context)
        return {
            "question": question,
            "answer": answer,
            "sources": results
        }

    def _build_context(self, results: List[Dict]) -> str:
        context_parts = []
        for result in results:
            if result["type"] == "text":
                context_parts.append(f"Text: {result['content'][:200]}")
            elif result["type"] == "image":
                context_parts.append(f"Image: {result['content']}")
        return "\n".join(context_parts)

    def _generate_answer(self, question: str, context: str) -> str:
        prompt = f"""
Context:
{context}

Question: {question}

Answer the question based on the context.
"""
        answer = self.llm.generate(prompt)
        return answer
```
Combining RAG with Agents
RAG Agent
```python
import json
from typing import List, Dict

class RAGAgent:
    def __init__(self, retriever, llm, tools: List):
        self.retriever = retriever
        self.llm = llm
        self.tools = {tool.name: tool for tool in tools}
        self.memory = []

    def execute(self, task: str) -> Dict:
        self.memory.append({"role": "user", "content": task})
        plan = self._plan(task)
        results = []
        for step in plan:
            step_result = self._execute_step(step)
            results.append(step_result)
        answer = self._generate_answer(task, results)
        self.memory.append({"role": "assistant", "content": answer})
        return {
            "task": task,
            "answer": answer,
            "steps": results
        }

    def _plan(self, task: str) -> List[Dict]:
        prompt = f"""
Task: {task}
Available tools: {list(self.tools.keys())}

Plan the steps to complete this task.
Return a list of steps, each with a tool name and parameters.
"""
        response = self.llm.generate(prompt)
        return self._parse_plan(response)

    def _execute_step(self, step: Dict) -> Dict:
        tool_name = step.get("tool")
        parameters = step.get("parameters", {})
        if tool_name == "retrieve":
            return self._retrieve(parameters)
        elif tool_name == "generate":
            return self._generate(parameters)
        elif tool_name in self.tools:
            tool = self.tools[tool_name]
            return tool.execute(parameters)
        else:
            return {"error": f"Unknown tool: {tool_name}"}

    def _retrieve(self, parameters: Dict) -> Dict:
        query = parameters.get("query", "")
        top_k = parameters.get("top_k", 5)
        results = self.retriever.retrieve(query, top_k)
        return {
            "action": "retrieve",
            "query": query,
            "results": results
        }

    def _generate(self, parameters: Dict) -> Dict:
        context = parameters.get("context", "")
        question = parameters.get("question", "")
        prompt = f"""
Context:
{context}

Question: {question}

Answer the question based on the context.
"""
        answer = self.llm.generate(prompt)
        return {
            "action": "generate",
            "question": question,
            "answer": answer
        }

    def _generate_answer(self, task: str, results: List[Dict]) -> str:
        context = "\n".join([
            f"Step {i+1}: {result}"
            for i, result in enumerate(results)
        ])
        prompt = f"""
Task: {task}
Execution steps:
{context}

Provide a final answer based on the execution steps.
"""
        answer = self.llm.generate(prompt)
        return answer

    def _parse_plan(self, response: str) -> List[Dict]:
        # Expect a JSON array of steps; fall back to a single retrieval
        # step if the LLM output is not valid JSON.
        try:
            steps = json.loads(response)
        except (json.JSONDecodeError, TypeError):
            steps = [{"tool": "retrieve", "parameters": {"query": response}}]
        return steps
```
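`_parse_plan` expects the LLM to return a JSON array of steps. A concrete example of that shape, using the built-in `retrieve` and `generate` tools handled in `_execute_step`:

```python
import json

# The JSON shape _parse_plan expects from the planning prompt.
plan_text = json.dumps([
    {"tool": "retrieve", "parameters": {"query": "RAG evaluation metrics", "top_k": 3}},
    {"tool": "generate", "parameters": {"question": "Summarize the metrics",
                                        "context": "<retrieved documents go here>"}},
])

for step in json.loads(plan_text):
    print(step["tool"], "->", step["parameters"])
```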
Adaptive RAG Agent

```python
class AdaptiveRAGAgent(RAGAgent):
    def __init__(self, retriever, llm, tools: List):
        super().__init__(retriever, llm, tools)
        # PerformanceTracker is assumed to exist; a minimal sketch follows below.
        self.performance_tracker = PerformanceTracker()

    def execute(self, task: str) -> Dict:
        self.performance_tracker.start_task(task)
        result = super().execute(task)
        self.performance_tracker.end_task(task, result)
        self._adjust_strategy()
        return result

    def _adjust_strategy(self):
        performance = self.performance_tracker.get_performance()
        if performance["retrieval_accuracy"] < 0.7:
            self._increase_retrieval_depth()
        if performance["generation_quality"] < 0.7:
            self._increase_context_length()

    def _increase_retrieval_depth(self):
        # Placeholder: e.g. raise top_k for subsequent retrieval calls.
        pass

    def _increase_context_length(self):
        # Placeholder: e.g. raise the context compressor's token budget.
        pass
```
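`AdaptiveRAGAgent` relies on a `PerformanceTracker` that the lesson never defines. A minimal sketch with placeholder metrics so the class above can run; real systems would plug in actual retrieval and generation evaluation:

```python
import time
from typing import Dict

class PerformanceTracker:
    """Records wall-clock time per task; quality scores are placeholders."""
    def __init__(self):
        self.records = []
        self._start = None

    def start_task(self, task: str):
        self._start = time.time()

    def end_task(self, task: str, result: Dict):
        self.records.append({"task": task,
                             "elapsed_s": time.time() - self._start})

    def get_performance(self) -> Dict:
        # Placeholder scores; replace with real evaluation (e.g. labeled QA pairs).
        return {"retrieval_accuracy": 1.0, "generation_quality": 1.0}
```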
RAG Performance Optimization
Query Caching
```python
import hashlib
from typing import Dict, Optional

class QueryCache:
    def __init__(self, max_size: int = 1000):
        self.cache = {}
        self.max_size = max_size

    def get(self, query: str) -> Optional[Dict]:
        cache_key = self._generate_key(query)
        if cache_key in self.cache:
            return self.cache[cache_key]
        return None

    def set(self, query: str, result: Dict):
        cache_key = self._generate_key(query)
        if len(self.cache) >= self.max_size:
            self._evict()
        self.cache[cache_key] = result

    def _generate_key(self, query: str) -> str:
        return hashlib.md5(query.encode()).hexdigest()

    def _evict(self):
        # FIFO eviction: dicts preserve insertion order, so the first key
        # is the oldest entry.
        oldest_key = next(iter(self.cache))
        del self.cache[oldest_key]
```
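Putting the cache in front of a pipeline is a read-through pattern: answer from the cache on a hit, otherwise compute and store. A sketch assuming any object with the `query(question) -> Dict` interface used throughout this lesson:

```python
from typing import Dict

class CachedRAG:
    """Read-through cache wrapper around a RAG pipeline."""
    def __init__(self, rag, cache: QueryCache):
        self.rag = rag
        self.cache = cache

    def query(self, question: str) -> Dict:
        cached = self.cache.get(question)
        if cached is not None:
            return cached
        result = self.rag.query(question)
        self.cache.set(question, result)
        return result
```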
Batch Processing

```python
class BatchRAG:
    def __init__(self, retriever, llm):
        self.retriever = retriever
        self.llm = llm

    def query_batch(self, queries: List[str]) -> List[Dict]:
        # Sequential baseline; see the concurrent sketch below.
        results = []
        for query in queries:
            result = self.query(query)
            results.append(result)
        return results

    def query(self, query: str) -> Dict:
        retrieved_docs = self.retriever.retrieve(query, top_k=5)
        answer = self._generate_answer(query, retrieved_docs)
        return {
            "question": query,
            "answer": answer,
            "sources": retrieved_docs
        }

    def _generate_answer(self, question: str,
                         retrieved_docs: List[Dict]) -> str:
        context = "\n\n".join([
            doc["content"] for doc in retrieved_docs
        ])
        prompt = f"""
Context:
{context}

Question: {question}

Answer the question based on the context.
"""
        answer = self.llm.generate(prompt)
        return answer
```
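`query_batch` above is a sequential baseline. When retrieval and LLM calls are I/O-bound and the underlying clients are thread-safe, a thread pool can overlap them using only the standard library; whether this helps depends on where the time is actually spent:

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List

def query_batch_concurrent(rag, queries: List[str],
                           max_workers: int = 4) -> List[Dict]:
    """Run rag.query over queries concurrently, preserving input order."""
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        return list(pool.map(rag.query, queries))
```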
Hands-On Exercises
Exercise 1: Implement Multi-Turn Conversational RAG
```python
class SimpleConversationalRAG:
    def __init__(self, retriever, llm):
        self.retriever = retriever
        self.llm = llm
        self.history = []

    def query(self, question: str):
        self.history.append({"role": "user", "content": question})
        context = self._build_context()
        retrieved_docs = self.retriever.retrieve(question, top_k=5)
        answer = self._generate_answer(question, retrieved_docs, context)
        self.history.append({"role": "assistant", "content": answer})
        return {
            "question": question,
            "answer": answer,
            "sources": retrieved_docs
        }

    def _build_context(self):
        return "\n".join([
            f"{msg['role']}: {msg['content']}"
            for msg in self.history[-5:]
        ])

    def _generate_answer(self, question, retrieved_docs, context):
        docs_context = "\n\n".join([doc["content"] for doc in retrieved_docs])
        prompt = f"""
Context:
{context}

Documents:
{docs_context}

Question: {question}

Answer:
"""
        return self.llm.generate(prompt)
```
Exercise 2: Implement Knowledge Graph RAG

```python
class SimpleGraphRAG:
    def __init__(self, graph_retriever, vector_retriever, llm):
        self.graph_retriever = graph_retriever
        self.vector_retriever = vector_retriever
        self.llm = llm

    def query(self, question: str):
        graph_results = self.graph_retriever.retrieve(question, top_k=5)
        vector_results = self.vector_retriever.retrieve(question, top_k=5)
        context = self._combine_contexts(graph_results, vector_results)
        answer = self._generate_answer(question, context)
        return {
            "question": question,
            "answer": answer,
            "graph_sources": graph_results,
            "vector_sources": vector_results
        }

    def _combine_contexts(self, graph_results, vector_results):
        context_parts = []
        context_parts.append("Knowledge Graph:")
        for result in graph_results:
            context_parts.append(f"- {result['entity']}")
        context_parts.append("\nVector Search:")
        for result in vector_results:
            context_parts.append(f"- {result['content'][:200]}")
        return "\n".join(context_parts)

    def _generate_answer(self, question, context):
        prompt = f"""
Context:
{context}

Question: {question}

Answer:
"""
        return self.llm.generate(prompt)
```
Summary

In this lesson we covered advanced RAG techniques:
- Multi-turn conversational RAG (conversation history, context compression)
- Knowledge graph RAG (graph construction, graph retrieval, graph-augmented generation)
- Multimodal RAG (image embeddings, multimodal retrieval)
- Combining RAG with agents (RAG agents, adaptive agents)
- RAG performance optimization (query caching, batch processing)

These advanced techniques can significantly improve both the capability and the performance of a RAG system.
