Skip to content

第50天:RAG高级技术

学习目标

  • 掌握多轮对话RAG
  • 学习知识图谱RAG
  • 理解多模态RAG
  • 掌握RAG与Agent结合
  • 了解RAG性能优化

多轮对话RAG

对话历史管理

python
from typing import List, Dict
from datetime import datetime

class ConversationHistory:
    """Rolling store of chat messages, capped at ``max_history`` entries."""

    def __init__(self, max_history: int = 10):
        # Oldest messages are dropped once this cap is exceeded.
        self.history: List[Dict] = []
        self.max_history = max_history

    def add_message(self, role: str, content: str):
        """Append a timestamped message, trimming to the newest entries."""
        self.history.append({
            "role": role,
            "content": content,
            "timestamp": datetime.now().isoformat(),
        })

        overflow = len(self.history) - self.max_history
        if overflow > 0:
            self.history = self.history[overflow:]

    def get_history(self, n: int = None) -> List[Dict]:
        """Return the last ``n`` messages, or all of them when ``n`` is None."""
        return self.history if n is None else self.history[-n:]

    def clear(self):
        """Forget the entire conversation."""
        self.history = []

    def get_context(self) -> str:
        """Render the history as one ``role: content`` line per message."""
        return "\n".join(
            f"{msg['role']}: {msg['content']}" for msg in self.history
        )

上下文压缩

python
class ContextCompressor:
    """Trims a message history to fit an approximate token budget."""

    def __init__(self, max_tokens: int = 2000):
        self.max_tokens = max_tokens

    def compress(self, history: List[Dict]) -> str:
        """Keep the newest messages that fit the budget, oldest first.

        Walks the history newest-to-oldest and stops at the first message
        that would overflow the budget; earlier messages are dropped.
        """
        kept: List[str] = []
        remaining = self.max_tokens

        for msg in reversed(history):
            text = msg["content"]
            cost = self._estimate_tokens(text)
            if cost > remaining:
                break
            kept.append(text)
            remaining -= cost

        kept.reverse()  # restore chronological order
        return "\n".join(kept)

    def _estimate_tokens(self, text: str) -> int:
        # Crude approximation: one whitespace-separated word == one token.
        return len(text.split())

多轮RAG检索

python
class ConversationalRAG:
    """RAG pipeline that threads conversation history through each query."""

    def __init__(self, retriever, llm):
        self.retriever = retriever
        self.llm = llm
        self.conversation_history = ConversationHistory()
        self.context_compressor = ContextCompressor()

    def query(self, question: str) -> Dict:
        """Answer one user turn, updating the stored conversation."""
        self.conversation_history.add_message("user", question)

        history_context = self._build_context()
        search_query = self._enhance_query(question, history_context)
        docs = self.retriever.retrieve(search_query, top_k=5)
        answer = self._generate_answer(question, docs, history_context)

        self.conversation_history.add_message("assistant", answer)

        return {
            "question": question,
            "answer": answer,
            "sources": [doc["metadata"] for doc in docs],
        }

    def _build_context(self) -> str:
        """Compress the full stored history into a token-budgeted string."""
        return self.context_compressor.compress(
            self.conversation_history.get_history()
        )

    def _enhance_query(self, question: str,
                       context: str) -> str:
        """Have the LLM rewrite the question to fold in chat context."""
        if not context:
            return question

        prompt = f"""
        Previous conversation:
        {context}
        
        Current question: {question}
        
        Rewrite the question to include relevant context from the conversation.
        """

        return self.llm.generate(prompt)

    def _generate_answer(self, question: str,
                         retrieved_docs: List[Dict],
                         context: str) -> str:
        """Generate the final answer from documents plus chat context."""
        docs_context = "\n\n".join(
            doc["content"] for doc in retrieved_docs
        )

        prompt = f"""
        Context from previous conversation:
        {context}
        
        Retrieved documents:
        {docs_context}
        
        Question: {question}
        
        Answer the question based on the retrieved documents and conversation context.
        """

        return self.llm.generate(prompt)

知识图谱RAG

知识图谱构建

python
from typing import List, Dict, Set

class KnowledgeGraph:
    """In-memory entity/relation store with simple neighborhood queries."""

    def __init__(self):
        self.entities: Dict[str, Dict] = {}  # id -> {"type", "attributes"}
        self.relations: List[Dict] = []      # edge list of {source, relation, target}

    def add_entity(self, entity_id: str, entity_type: str,
                   attributes: Dict):
        """Register (or overwrite) an entity with its type and attributes."""
        self.entities[entity_id] = {
            "type": entity_type,
            "attributes": attributes,
        }

    def add_relation(self, source: str, relation: str,
                    target: str):
        """Record a labeled edge between two entity ids."""
        self.relations.append({
            "source": source,
            "relation": relation,
            "target": target,
        })

    def get_neighbors(self, entity_id: str) -> List[Dict]:
        """Return neighbors in either direction (edges treated as undirected)."""
        adjacent: List[Dict] = []

        for edge in self.relations:
            if entity_id == edge["source"]:
                adjacent.append({
                    "entity": edge["target"],
                    "relation": edge["relation"],
                })
            elif entity_id == edge["target"]:
                adjacent.append({
                    "entity": edge["source"],
                    "relation": edge["relation"],
                })

        return adjacent

    def get_entity(self, entity_id: str) -> Dict:
        """Look up an entity record; empty dict for unknown ids."""
        return self.entities.get(entity_id, {})

    def find_entities_by_type(self, entity_type: str) -> List[str]:
        """All entity ids whose registered type equals ``entity_type``."""
        return [
            eid for eid, record in self.entities.items()
            if record["type"] == entity_type
        ]

知识图谱检索

python
class GraphRetriever:
    """Retrieves entity neighborhoods and paths from a knowledge graph."""

    def __init__(self, knowledge_graph: KnowledgeGraph):
        self.kg = knowledge_graph

    def retrieve(self, query: str, top_k: int = 10) -> List[Dict]:
        """Return info + neighbors for each entity mentioned in the query."""
        matched = self._extract_entities(query)

        hits = [
            {
                "entity": name,
                "info": self.kg.get_entity(name),
                "neighbors": self.kg.get_neighbors(name),
            }
            for name in matched
        ]

        return hits[:top_k]

    def _extract_entities(self, query: str) -> List[str]:
        """Naive entity linking: case-insensitive substring match on ids."""
        lowered = query.lower()
        return [eid for eid in self.kg.entities if eid.lower() in lowered]

    def retrieve_with_path(self, start_entity: str,
                           end_entity: str,
                           max_depth: int = 3) -> List[List[str]]:
        """Collect all simple paths (up to ``max_depth`` nodes) between entities."""
        found: List[List[str]] = []
        self._find_paths(start_entity, end_entity, [], max_depth, found)
        return found

    def _find_paths(self, current: str, target: str,
                   path: List[str], depth: int,
                   paths: List[List[str]]):
        """DFS helper: extend ``path`` with ``current`` and recurse to neighbors."""
        if depth == 0:
            return

        # Copy-on-extend keeps each recursion branch's path independent.
        extended = path + [current]

        if current == target:
            paths.append(extended)
            return

        for step in self.kg.get_neighbors(current):
            nxt = step["entity"]
            # Skip nodes already on the path so paths stay simple (no cycles).
            if nxt not in extended:
                self._find_paths(nxt, target, extended, depth - 1, paths)

知识图谱RAG

python
class KnowledgeGraphRAG:
    """Hybrid RAG that merges knowledge-graph hits with vector-search hits."""

    def __init__(self, graph_retriever, vector_retriever, llm):
        self.graph_retriever = graph_retriever
        self.vector_retriever = vector_retriever
        self.llm = llm

    def query(self, question: str) -> Dict:
        """Answer using both retrieval channels and return all sources."""
        from_graph = self.graph_retriever.retrieve(question, top_k=5)
        from_vectors = self.vector_retriever.retrieve(question, top_k=5)

        merged_context = self._combine_contexts(from_graph, from_vectors)
        answer = self._generate_answer(question, merged_context)

        return {
            "question": question,
            "answer": answer,
            "graph_sources": from_graph,
            "vector_sources": from_vectors,
        }

    def _combine_contexts(self, graph_results: List[Dict],
                          vector_results: List[Dict]) -> str:
        """Render both result sets into a single labeled text block."""
        lines = ["Knowledge Graph:"]
        lines.extend(
            f"- {hit['entity']}: {hit['info']}" for hit in graph_results
        )

        lines.append("\nVector Search:")
        # Clip each document to its first 200 characters.
        lines.extend(
            f"- {hit['content'][:200]}" for hit in vector_results
        )

        return "\n".join(lines)

    def _generate_answer(self, question: str,
                        context: str) -> str:
        """Prompt the LLM with the merged context."""
        prompt = f"""
        Context:
        {context}
        
        Question: {question}
        
        Answer the question based on the context.
        """

        return self.llm.generate(prompt)

多模态RAG

图像Embedding

python
import numpy as np

class ImageEmbeddingGenerator:
    """Generates CLIP image embeddings; heavy dependencies load lazily.

    The torch / PIL / clip imports happen inside ``generate`` so the class
    can be constructed without the packages installed.
    """

    def __init__(self, model_name: str = "clip-vit-base-patch32"):
        self.model_name = model_name
        # Embedding size of the default CLIP ViT-B/32 model.
        self.dimension = 512

    def generate(self, image_path: str) -> np.ndarray:
        """Return a 1-D embedding for the image at ``image_path``.

        Raises:
            ImportError: if torch, PIL, or clip are not installed.
        """
        try:
            # BUG FIX: ``torch`` was used below without ever being imported,
            # so with clip/PIL installed but torch referenced it raised
            # NameError instead of the intended ImportError.
            import torch
            from PIL import Image
            import clip

            device = "cuda" if torch.cuda.is_available() else "cpu"
            # NOTE(review): loading the model per call is slow — consider
            # caching (model, preprocess) on the instance.
            model, preprocess = clip.load(self.model_name, device=device)

            # Normalize to RGB so palette/alpha/grayscale images don't
            # break the 3-channel preprocess pipeline.
            image = Image.open(image_path).convert("RGB")
            image_input = preprocess(image).unsqueeze(0).to(device)

            with torch.no_grad():
                image_features = model.encode_image(image_input)

            return image_features.cpu().numpy()[0]
        except ImportError:
            raise ImportError("Install required packages: pip install torch torchvision clip-by-openai")

    def generate_batch(self, image_paths: List[str]) -> np.ndarray:
        """Embed each path in order; returns an array of shape (n, dim)."""
        return np.array([self.generate(path) for path in image_paths])

多模态检索

python
class MultiModalRetriever:
    """Searches text and image stores using a shared text-query embedding."""

    def __init__(self, text_db, image_db,
                 text_embedding_gen, image_embedding_gen):
        self.text_db = text_db
        self.image_db = image_db
        self.text_embedding_gen = text_embedding_gen
        self.image_embedding_gen = image_embedding_gen

    def retrieve_text(self, query: str, top_k: int = 10) -> List[Dict]:
        """Vector search over the text store."""
        embedding = self.text_embedding_gen.generate(query)
        return self.text_db.search(embedding, top_k)

    def retrieve_image(self, query: str, top_k: int = 10) -> List[Dict]:
        """Cross-modal search: text-query embedding against the image store."""
        embedding = self.text_embedding_gen.generate(query)
        return self.image_db.search(embedding, top_k)

    def retrieve_multi_modal(self, query: str,
                           top_k: int = 10) -> List[Dict]:
        """Merge text and image hits into one ranked list of length <= top_k."""
        merged = self._combine_results(
            self.retrieve_text(query, top_k),
            self.retrieve_image(query, top_k),
        )
        return merged[:top_k]

    def _combine_results(self, text_results: List[Dict],
                         image_results: List[Dict]) -> List[Dict]:
        """Apply a linear rank-decay to each list's scores, then sort desc."""
        merged: List[Dict] = []

        text_total = len(text_results)
        for rank, hit in enumerate(text_results):
            merged.append({
                "type": "text",
                "content": hit["content"],
                # Earlier ranks keep more of their original score.
                "score": hit["score"] * (1 - rank / text_total),
            })

        image_total = len(image_results)
        for rank, hit in enumerate(image_results):
            merged.append({
                "type": "image",
                "content": hit["metadata"]["path"],
                "score": hit["score"] * (1 - rank / image_total),
            })

        # sorted() is stable, so equal scores keep text-before-image order.
        return sorted(merged, key=lambda item: item["score"], reverse=True)

多模态RAG

python
class MultiModalRAG:
    """RAG over mixed text/image retrieval results."""

    def __init__(self, retriever, llm):
        self.retriever = retriever
        self.llm = llm

    def query(self, question: str) -> Dict:
        """Retrieve across modalities, then generate an answer."""
        hits = self.retriever.retrieve_multi_modal(question, top_k=10)
        answer = self._generate_answer(question, self._build_context(hits))

        return {
            "question": question,
            "answer": answer,
            "sources": hits,
        }

    def _build_context(self, results: List[Dict]) -> str:
        """Render hits as labeled lines; unknown types are skipped."""
        lines = []

        for hit in results:
            kind = hit["type"]
            if kind == "text":
                # Clip long documents to their first 200 characters.
                lines.append(f"Text: {hit['content'][:200]}")
            elif kind == "image":
                lines.append(f"Image: {hit['content']}")

        return "\n".join(lines)

    def _generate_answer(self, question: str,
                        context: str) -> str:
        """Prompt the LLM with the multimodal context."""
        prompt = f"""
        Context:
        {context}
        
        Question: {question}
        
        Answer the question based on the context.
        """

        return self.llm.generate(prompt)

RAG与Agent结合

RAG Agent

python
from typing import List, Dict, Optional

class RAGAgent:
    """Plans a task with an LLM, executes the steps, and summarizes.

    Flow: record the task in memory, ask the LLM for a step plan,
    dispatch each step to a built-in action ("retrieve" / "generate")
    or a registered tool, then ask the LLM for a final answer.
    """

    def __init__(self, retriever, llm, tools: List):
        self.retriever = retriever
        self.llm = llm
        # Tools are dispatched by their ``.name`` attribute.
        self.tools = {tool.name: tool for tool in tools}
        self.memory = []

    def execute(self, task: str) -> Dict:
        """Run the plan/execute/summarize loop for one task."""
        self.memory.append({"role": "user", "content": task})

        plan = self._plan(task)

        results = []
        for step in plan:
            step_result = self._execute_step(step)
            results.append(step_result)

        answer = self._generate_answer(task, results)

        self.memory.append({"role": "assistant", "content": answer})

        return {
            "task": task,
            "answer": answer,
            "steps": results
        }

    def _plan(self, task: str) -> List[Dict]:
        """Ask the LLM for a list of tool-invocation steps."""
        prompt = f"""
        Task: {task}
        
        Available tools: {list(self.tools.keys())}
        
        Plan the steps to complete this task.
        Return a list of steps, each with a tool name and parameters.
        """

        response = self.llm.generate(prompt)

        return self._parse_plan(response)

    def _execute_step(self, step: Dict) -> Dict:
        """Dispatch one plan step; unknown tools yield an error dict."""
        tool_name = step.get("tool")
        parameters = step.get("parameters", {})

        if tool_name == "retrieve":
            return self._retrieve(parameters)
        elif tool_name == "generate":
            return self._generate(parameters)
        elif tool_name in self.tools:
            tool = self.tools[tool_name]
            return tool.execute(parameters)
        else:
            return {"error": f"Unknown tool: {tool_name}"}

    def _retrieve(self, parameters: Dict) -> Dict:
        """Built-in retrieval action (defaults: empty query, top_k=5)."""
        query = parameters.get("query", "")
        top_k = parameters.get("top_k", 5)

        results = self.retriever.retrieve(query, top_k)

        return {
            "action": "retrieve",
            "query": query,
            "results": results
        }

    def _generate(self, parameters: Dict) -> Dict:
        """Built-in generation action over a supplied context/question."""
        context = parameters.get("context", "")
        question = parameters.get("question", "")

        prompt = f"""
        Context:
        {context}
        
        Question: {question}
        
        Answer the question based on the context.
        """

        answer = self.llm.generate(prompt)

        return {
            "action": "generate",
            "question": question,
            "answer": answer
        }

    def _generate_answer(self, task: str,
                         results: List[Dict]) -> str:
        """Summarize all step results into a final answer via the LLM."""
        context = "\n".join([
            f"Step {i+1}: {result}"
            for i, result in enumerate(results)
        ])

        prompt = f"""
        Task: {task}
        
        Execution steps:
        {context}
        
        Provide a final answer based on the execution steps.
        """

        return self.llm.generate(prompt)

    def _parse_plan(self, response: str) -> List[Dict]:
        """Parse the LLM's plan; fall back to a single retrieval step.

        BUG FIX: the original used a bare ``except:`` (which also swallows
        SystemExit/KeyboardInterrupt), and accepted any JSON value — a JSON
        object or string here would later crash ``execute`` when iterated
        as a list of steps. Now only JSON lists are accepted as plans.
        """
        import json

        try:
            parsed = json.loads(response)
        except (json.JSONDecodeError, TypeError):
            parsed = None

        if not isinstance(parsed, list):
            # Unparseable (or non-list) plan: retrieve with the raw text.
            return [{"tool": "retrieve", "parameters": {"query": response}}]

        return parsed

自适应RAG Agent

python
class AdaptiveRAGAgent(RAGAgent):
    """RAG agent that records per-task performance and tunes its strategy.

    NOTE(review): ``PerformanceTracker`` is not defined anywhere in this
    file — confirm it is provided elsewhere before instantiating, and
    that it exposes ``start_task``/``end_task``/``get_performance``
    (the only methods called below).
    """
    def __init__(self, retriever, llm, tools: List):
        super().__init__(retriever, llm, tools)
        self.performance_tracker = PerformanceTracker()
    
    def execute(self, task: str) -> Dict:
        """Run the base agent loop, bracketed by performance tracking."""
        self.performance_tracker.start_task(task)
        
        result = super().execute(task)
        
        self.performance_tracker.end_task(task, result)
        
        # Adjust strategy after every task based on accumulated metrics.
        self._adjust_strategy()
        
        return result
    
    def _adjust_strategy(self):
        """Tune retrieval/generation knobs when measured quality drops.

        Assumes ``get_performance()`` returns a dict with
        ``retrieval_accuracy`` and ``generation_quality`` in [0, 1]
        — TODO confirm against the tracker implementation.
        """
        performance = self.performance_tracker.get_performance()
        
        if performance["retrieval_accuracy"] < 0.7:
            self._increase_retrieval_depth()
        
        if performance["generation_quality"] < 0.7:
            self._increase_context_length()
    
    def _increase_retrieval_depth(self):
        # Intentionally a stub: subclasses/deployments supply a real policy.
        pass
    
    def _increase_context_length(self):
        # Intentionally a stub: subclasses/deployments supply a real policy.
        pass

RAG性能优化

查询缓存

python
from functools import lru_cache
import hashlib

class QueryCache:
    """Bounded LRU cache of query -> result dicts, keyed by MD5 of the query.

    BUG FIX vs. the original: eviction was effectively FIFO because
    ``get`` never refreshed recency, and ``set`` on an already-cached
    query still evicted an unrelated entry when the cache was full.
    Both are corrected; the public interface is unchanged.
    """

    def __init__(self, max_size: int = 1000):
        # Insertion order of this dict doubles as LRU order
        # (least recently used first) — dicts preserve insertion order.
        self.cache = {}
        self.max_size = max_size

    def get(self, query: str) -> Optional[Dict]:
        """Return the cached result for ``query``, or None on a miss.

        NOTE: a stored result of None is indistinguishable from a miss.
        """
        cache_key = self._generate_key(query)

        if cache_key not in self.cache:
            return None

        # Pop + re-insert moves the key to the end of insertion order,
        # marking it most-recently-used.
        result = self.cache.pop(cache_key)
        self.cache[cache_key] = result
        return result

    def set(self, query: str, result: Dict):
        """Store a result, evicting the least-recently-used entry if full."""
        cache_key = self._generate_key(query)

        if cache_key in self.cache:
            # Overwriting an existing entry never requires an eviction;
            # pop first so the re-insert refreshes its recency.
            self.cache.pop(cache_key)
        elif len(self.cache) >= self.max_size:
            self._evict()

        self.cache[cache_key] = result

    def _generate_key(self, query: str) -> str:
        # MD5 is fine here — fixed-length bucketing, not security-sensitive.
        return hashlib.md5(query.encode()).hexdigest()

    def _evict(self):
        """Drop the least-recently-used entry (first key in insertion order)."""
        oldest_key = next(iter(self.cache))
        del self.cache[oldest_key]

批量处理

python
class BatchRAG:
    """Minimal retrieve-then-generate RAG with a sequential batch helper."""

    def __init__(self, retriever, llm):
        self.retriever = retriever
        self.llm = llm

    def query_batch(self, queries: List[str]) -> List[Dict]:
        """Answer each query independently, preserving input order."""
        return [self.query(q) for q in queries]

    def query(self, query: str) -> Dict:
        """Retrieve top documents for the query and generate an answer."""
        docs = self.retriever.retrieve(query, top_k=5)

        return {
            "question": query,
            "answer": self._generate_answer(query, docs),
            "sources": docs,
        }

    def _generate_answer(self, question: str,
                        retrieved_docs: List[Dict]) -> str:
        """Prompt the LLM with the concatenated document contents."""
        context = "\n\n".join(
            doc["content"] for doc in retrieved_docs
        )

        prompt = f"""
        Context:
        {context}
        
        Question: {question}
        
        Answer the question based on the context.
        """

        return self.llm.generate(prompt)

实践练习

练习1:实现多轮对话RAG

python
class SimpleConversationalRAG:
    """Bare-bones conversational RAG keeping the last few turns as context."""

    def __init__(self, retriever, llm):
        self.retriever = retriever
        self.llm = llm
        self.history = []

    def query(self, question: str):
        """Answer one turn, recording both sides of the exchange."""
        self.history.append({"role": "user", "content": question})

        chat_context = self._build_context()
        docs = self.retriever.retrieve(question, top_k=5)
        answer = self._generate_answer(question, docs, chat_context)

        self.history.append({"role": "assistant", "content": answer})

        return {
            "question": question,
            "answer": answer,
            "sources": docs,
        }

    def _build_context(self):
        """Render the five most recent messages as 'role: content' lines."""
        recent = self.history[-5:]
        return "\n".join(f"{m['role']}: {m['content']}" for m in recent)

    def _generate_answer(self, question, retrieved_docs, context):
        """Prompt the LLM with chat context plus retrieved documents."""
        docs_context = "\n\n".join(d["content"] for d in retrieved_docs)

        prompt = f"""
        Context:
        {context}
        
        Documents:
        {docs_context}
        
        Question: {question}
        
        Answer:
        """

        return self.llm.generate(prompt)

练习2:实现知识图谱RAG

python
class SimpleGraphRAG:
    """Minimal hybrid RAG over a graph retriever and a vector retriever."""

    def __init__(self, graph_retriever, vector_retriever, llm):
        self.graph_retriever = graph_retriever
        self.vector_retriever = vector_retriever
        self.llm = llm

    def query(self, question: str):
        """Retrieve from both channels, then generate one answer."""
        from_graph = self.graph_retriever.retrieve(question, top_k=5)
        from_vectors = self.vector_retriever.retrieve(question, top_k=5)

        merged = self._combine_contexts(from_graph, from_vectors)
        answer = self._generate_answer(question, merged)

        return {
            "question": question,
            "answer": answer,
            "graph_sources": from_graph,
            "vector_sources": from_vectors,
        }

    def _combine_contexts(self, graph_results, vector_results):
        """Render both hit lists into a single labeled text block."""
        lines = ["Knowledge Graph:"]
        lines.extend(f"- {hit['entity']}" for hit in graph_results)

        lines.append("\nVector Search:")
        # Clip each document to its first 200 characters.
        lines.extend(f"- {hit['content'][:200]}" for hit in vector_results)

        return "\n".join(lines)

    def _generate_answer(self, question, context):
        """Prompt the LLM with the merged context."""
        prompt = f"""
        Context:
        {context}
        
        Question: {question}
        
        Answer:
        """

        return self.llm.generate(prompt)

总结

本节我们学习了RAG高级技术:

  1. 多轮对话RAG(对话历史、上下文压缩)
  2. 知识图谱RAG(图谱构建、检索、RAG)
  3. 多模态RAG(图像Embedding、多模态检索)
  4. RAG与Agent结合(RAG Agent、自适应Agent)
  5. RAG性能优化(查询缓存、批量处理)

这些高级技术可以显著提升RAG系统的能力和性能。

参考资源