
Day 51: RAG Module Summary and Project

Learning Objectives

  • Review the core concepts covered in the RAG module
  • Understand the enterprise knowledge base system project
  • Complete the project architecture design
  • Implement the core features
  • Deploy and optimize the system

Module Knowledge Summary

RAG Core Concepts

RAG Workflow

Document → Process → Chunk → Embed → Store

Query → Embed → Retrieve → Rerank → Generate
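
Read as code, the two pipelines are a pair of function compositions. A minimal sketch (every function name here is an illustrative placeholder, not a project API; the full implementations appear later in this lesson):

python
def ingest(path: str) -> None:
    # Document → process → chunk → embed → store
    document = preprocess(load(path))
    for chunk in split(document):
        store(embed(chunk), metadata_for(document, chunk))

def answer(query: str) -> str:
    # Query → embed → retrieve → rerank → generate
    candidates = retrieve(embed(query))
    context = rerank(query, candidates)
    return generate(query, context)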

Key Techniques

  1. Document processing: loading, preprocessing, chunking, metadata extraction
  2. Vector databases: storage, indexing, retrieval
  3. Retrieval optimization: hybrid search, reranking, query expansion
  4. Advanced techniques: multi-turn dialogue, knowledge graphs, multimodality

Vector Database Comparison

| Database | Strengths                     | Weaknesses              | Best suited for         |
|----------|-------------------------------|-------------------------|-------------------------|
| Pinecone | Managed, easy to use          | High cost               | Production              |
| Weaviate | Multimodal, open source       | Steep learning curve    | Research projects       |
| Qdrant   | High performance, open source | Sparser documentation   | Production              |
| Chroma   | Lightweight, easy to use      | Limited scale           | Prototyping             |
| Milvus   | Large scale, high performance | Operationally complex   | Enterprise applications |
| FAISS    | Extremely fast                | Library only, no server | Research projects       |
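
The project below standardizes on Qdrant. A minimal sketch of its Python client (assuming the `qdrant-client` 1.x package; the collection name and the dummy vectors are illustrative):

python
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, PointStruct, VectorParams

client = QdrantClient(url="http://localhost:6333")

# Collection sized for text-embedding-ada-002 vectors (1536 dimensions)
client.create_collection(
    collection_name="ekb_chunks",
    vectors_config=VectorParams(size=1536, distance=Distance.COSINE),
)

# Illustrative data; real embeddings come from the embedding model
chunks = [([0.0] * 1536, {"filename": "example.txt", "chunk_id": 0})]
query_embedding = [0.0] * 1536

# Batch upsert: one call for many points is far faster than point-by-point
client.upsert(
    collection_name="ekb_chunks",
    points=[
        PointStruct(id=i, vector=embedding, payload=metadata)
        for i, (embedding, metadata) in enumerate(chunks)
    ],
)

# Top-k similarity search against the stored vectors
hits = client.search(
    collection_name="ekb_chunks",
    query_vector=query_embedding,
    limit=10,
)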

RAG Best Practices

  1. Document processing

    • Choose an appropriate chunking strategy (see the sketch after this list)
    • Extract rich metadata
    • Preserve document context across chunk boundaries
  2. Retrieval optimization

    • Use hybrid search to improve accuracy
    • Apply reranking to boost relevance
    • Add query expansion to cover more phrasings
  3. Performance optimization

    • Implement a caching layer
    • Use batch processing
    • Parallelize retrieval
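
For point 1, chunking with overlap is the simplest strategy that still preserves context across boundaries. A minimal self-contained sketch (the project's `SemanticSplitter` below is assumed to be more sophisticated):

python
from typing import List

def chunk_text(text: str, chunk_size: int = 1000, overlap: int = 200) -> List[str]:
    """Sliding-window chunking: adjacent chunks share `overlap` characters."""
    if overlap >= chunk_size:
        raise ValueError("overlap must be smaller than chunk_size")
    chunks = []
    start = 0
    while start < len(text):
        chunks.append(text[start:start + chunk_size])
        start += chunk_size - overlap  # step forward, keeping shared context
    return chunks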

Hands-On Project: Enterprise Knowledge Base

Project Overview

Project name: Enterprise Knowledge Base (EKB)

Project Description

Build an intelligent enterprise knowledge base system that can:

  • Manage enterprise documents (upload, index, update)
  • Retrieve relevant documents intelligently
  • Answer questions in natural language
  • Support multi-turn conversations
  • Trace and cite the sources of every answer

Tech Stack

  • Backend: FastAPI + LangChain
  • Frontend: React + Ant Design
  • Vector database: Qdrant
  • LLM: OpenAI GPT-4
  • Deployment: Docker + Kubernetes

System Architecture

┌─────────────────────────────────────────────────┐
│                 Frontend Layer                  │
│  - Document management UI                       │
│  - Search UI                                    │
│  - Q&A UI                                       │
│  - Admin console                                │
└───────────────────┬─────────────────────────────┘
                    │
                    ▼
┌─────────────────────────────────────────────────┐
│                   API Layer                     │
│  - Document management API                      │
│  - Search API                                   │
│  - Q&A API                                      │
│  - Chat API                                     │
└───────────────────┬─────────────────────────────┘
                    │
                    ▼
┌─────────────────────────────────────────────────┐
│                  Service Layer                  │
│  - Document processor                           │
│  - Retriever                                    │
│  - RAG engine                                   │
│  - Chat manager                                 │
└───────────────────┬─────────────────────────────┘
                    │
                    ▼
┌─────────────────────────────────────────────────┐
│                   Data Layer                    │
│  - Qdrant (vector database)                     │
│  - PostgreSQL (document metadata)               │
│  - Redis (cache)                                │
│  - MinIO (object storage)                       │
└─────────────────────────────────────────────────┘

Project Directory Structure

enterprise-knowledge-base/
├── backend/
│   ├── app/
│   │   ├── api/
│   │   │   ├── documents.py
│   │   │   ├── search.py
│   │   │   ├── qa.py
│   │   │   └── chat.py
│   │   ├── core/
│   │   │   ├── config.py
│   │   │   ├── database.py
│   │   │   └── security.py
│   │   ├── services/
│   │   │   ├── document_processor.py
│   │   │   ├── retriever.py
│   │   │   ├── rag_engine.py
│   │   │   └── chat_manager.py
│   │   └── models/
│   │       ├── document.py
│   │       └── conversation.py
│   ├── tests/
│   └── main.py
├── frontend/
│   ├── src/
│   │   ├── components/
│   │   │   ├── DocumentUpload.tsx
│   │   │   ├── SearchBar.tsx
│   │   │   ├── QAInterface.tsx
│   │   │   └── ChatInterface.tsx
│   │   ├── pages/
│   │   │   ├── Documents.tsx
│   │   │   ├── Search.tsx
│   │   │   ├── QA.tsx
│   │   │   └── Chat.tsx
│   │   └── App.tsx
│   └── package.json
├── docker/
│   ├── Dockerfile.backend
│   ├── Dockerfile.frontend
│   └── docker-compose.yml
└── kubernetes/
    ├── deployment.yaml
    └── service.yaml

Core Feature Implementation

1. Document Processor

python
import time
from pathlib import Path
from typing import Dict, List

# DocumentLoader, TextPreprocessor, SemanticSplitter and EmbeddingGenerator
# are the project's own service classes (see app/services/), assumed importable.

class DocumentProcessor:
    def __init__(self, config: Dict):
        self.config = config
        self.loader = DocumentLoader()
        self.preprocessor = TextPreprocessor()
        self.splitter = SemanticSplitter(
            max_chunk_size=config.get("chunk_size", 1000)
        )
        self.embedding_generator = EmbeddingGenerator(
            model_name=config.get("embedding_model", "text-embedding-ada-002")
        )
    
    def process_document(self, file_path: str, 
                        metadata: Dict = None) -> Dict:
        # Load the raw document, then normalize its text
        document = self.loader.load(file_path)
        
        document["content"] = self.preprocessor.preprocess(
            document["content"]
        )
        
        # Merge caller-supplied metadata into the document's own
        if metadata:
            document["metadata"].update(metadata)
        
        chunks = self.splitter.split(document["content"])
        
        # Embed each chunk and attach per-chunk metadata
        processed_chunks = []
        for i, chunk in enumerate(chunks):
            chunk_metadata = {
                **document["metadata"],
                "chunk_id": i,
                "total_chunks": len(chunks)
            }
            
            embedding = self.embedding_generator.generate(chunk["content"])
            
            processed_chunks.append({
                "content": chunk["content"],
                "embedding": embedding,
                "metadata": chunk_metadata
            })
        
        return {
            "document_id": self._generate_document_id(file_path),
            "chunks": processed_chunks,
            "metadata": document["metadata"]
        }
    
    def _generate_document_id(self, file_path: str) -> str:
        # Timestamped ID keeps re-uploads of the same file distinct
        path = Path(file_path)
        return f"{path.stem}_{int(time.time())}"

2. Retriever

python
from typing import Dict, List

class HybridRetriever:
    def __init__(self, vector_db, keyword_index, 
                 reranker, config: Dict):
        self.vector_db = vector_db
        self.keyword_index = keyword_index
        self.reranker = reranker
        self.config = config
        # Reuse the same embedding model that indexed the documents
        self.embedding_generator = EmbeddingGenerator(
            model_name=config.get("embedding_model", "text-embedding-ada-002")
        )
    
    def retrieve(self, query: str, 
                 filters: Dict = None,
                 top_k: int = 10) -> List[Dict]:
        # Over-fetch from both channels so fusion and filtering
        # still leave enough candidates
        vector_results = self._vector_search(query, top_k * 2)
        keyword_results = self._keyword_search(query, top_k * 2)
        
        combined_results = self._combine_results(
            vector_results,
            keyword_results
        )
        
        if filters:
            combined_results = self._apply_filters(
                combined_results,
                filters
            )
        
        reranked_results = self.reranker.rerank(query, combined_results)
        
        return reranked_results[:top_k]
    
    def _vector_search(self, query: str, 
                        top_k: int) -> List[Dict]:
        query_embedding = self._generate_embedding(query)
        return self.vector_db.search(query_embedding, top_k)
    
    def _keyword_search(self, query: str, 
                         top_k: int) -> List[Dict]:
        return self.keyword_index.search(query, top_k)
    
    def _generate_embedding(self, query: str) -> List[float]:
        return self.embedding_generator.generate(query)
    
    def _combine_results(self, vector_results: List[Dict],
                         keyword_results: List[Dict]) -> List[Dict]:
        # Weighted rank fusion: vector hits weigh 0.6, keyword hits 0.4.
        # Keep the full result payload so filtering, reranking and context
        # building downstream still have content and metadata.
        combined: Dict[str, Dict] = {}
        
        for i, result in enumerate(vector_results):
            doc_id = result.get("id", f"vec_{i}")
            score = (1 - i / len(vector_results)) * 0.6
            entry = combined.setdefault(doc_id, {**result, "score": 0.0})
            entry["score"] += score
        
        for i, result in enumerate(keyword_results):
            doc_id = result.get("id", f"kw_{i}")
            score = (1 - i / len(keyword_results)) * 0.4
            entry = combined.setdefault(doc_id, {**result, "score": 0.0})
            entry["score"] += score
        
        return sorted(
            combined.values(),
            key=lambda x: x["score"],
            reverse=True
        )
    
    def _apply_filters(self, results: List[Dict],
                       filters: Dict) -> List[Dict]:
        return [r for r in results if self._matches_filters(r, filters)]
    
    def _matches_filters(self, result: Dict, filters: Dict) -> bool:
        # A result matches when every filter key equals its metadata value
        metadata = result.get("metadata", {})
        return all(metadata.get(key) == value for key, value in filters.items())

3. RAG Engine

python
from typing import Dict, List

class RAGEngine:
    def __init__(self, retriever, llm, 
                 prompt_template, config: Dict):
        self.retriever = retriever
        self.llm = llm
        self.prompt_template = prompt_template
        self.config = config
    
    def query(self, question: str, 
              filters: Dict = None,
              top_k: int = 5) -> Dict:
        retrieved_docs = self.retriever.retrieve(
            question,
            filters=filters,
            top_k=top_k
        )
        
        if not retrieved_docs:
            return {
                "question": question,
                "answer": "I couldn't find relevant information in the knowledge base.",
                "sources": [],
                "confidence": "low"
            }
        
        context = self._build_context(retrieved_docs)
        
        answer = self._generate_answer(question, context)
        
        confidence = self._calculate_confidence(retrieved_docs)
        
        return {
            "question": question,
            "answer": answer,
            "sources": [
                {
                    "id": doc.get("id"),
                    "content": doc.get("content", "")[:200],
                    "metadata": doc.get("metadata", {}),
                    "score": doc.get("score", 0)
                }
                for doc in retrieved_docs
            ],
            "confidence": confidence
        }
    
    def _build_context(self, documents: List[Dict]) -> str:
        context_parts = []
        
        for i, doc in enumerate(documents):
            content = doc.get("content", "")
            metadata = doc.get("metadata", {})
            
            context_parts.append(
                f"Document {i+1}:\n"
                f"Source: {metadata.get('filename', 'Unknown')}\n"
                f"Content: {content}\n"
            )
        
        return "\n\n".join(context_parts)
    
    def _generate_answer(self, question: str, 
                        context: str) -> str:
        prompt = self.prompt_template.format(
            question=question,
            context=context
        )
        
        response = self.llm.generate(prompt)
        
        return response
    
    def _calculate_confidence(self, 
                             documents: List[Dict]) -> str:
        if not documents:
            return "low"
        
        avg_score = sum(
            doc.get("score", 0) for doc in documents
        ) / len(documents)
        
        if avg_score > 0.8:
            return "high"
        elif avg_score > 0.5:
            return "medium"
        else:
            return "low"

4. Chat Manager

python
import time
import uuid
from typing import Dict, List

class ChatManager:
    def __init__(self, rag_engine, llm, 
                 max_history: int = 10):
        self.rag_engine = rag_engine
        self.llm = llm
        # In-memory store; production code would persist this (e.g. in Redis)
        self.conversations = {}
        self.max_history = max_history
    
    def create_conversation(self, user_id: str) -> str:
        conversation_id = self._generate_conversation_id()
        
        self.conversations[conversation_id] = {
            "user_id": user_id,
            "messages": [],
            "created_at": time.time()
        }
        
        return conversation_id
    
    def send_message(self, conversation_id: str, 
                    message: str) -> Dict:
        if conversation_id not in self.conversations:
            raise ValueError("Conversation not found")
        
        conversation = self.conversations[conversation_id]
        
        conversation["messages"].append({
            "role": "user",
            "content": message,
            "timestamp": time.time()
        })
        
        context = self._build_conversation_context(
            conversation["messages"]
        )
        
        enhanced_message = self._enhance_message(
            message,
            context
        )
        
        result = self.rag_engine.query(enhanced_message)
        
        answer = result["answer"]
        
        conversation["messages"].append({
            "role": "assistant",
            "content": answer,
            "timestamp": time.time()
        })
        
        if len(conversation["messages"]) > self.max_history * 2:
            conversation["messages"] = conversation["messages"][-self.max_history * 2:]
        
        return {
            "conversation_id": conversation_id,
            "answer": answer,
            "sources": result["sources"],
            "confidence": result["confidence"]
        }
    
    def _build_conversation_context(self, 
                                   messages: List[Dict]) -> str:
        recent_messages = messages[-self.max_history:]
        
        context_parts = []
        for msg in recent_messages:
            role = msg["role"]
            content = msg["content"]
            context_parts.append(f"{role}: {content}")
        
        return "\n".join(context_parts)
    
    def _enhance_message(self, message: str, 
                        context: str) -> str:
        if not context:
            return message
        
        prompt = f"""
        Previous conversation:
        {context}
        
        Current message: {message}
        
        Rewrite the message to include relevant context from the conversation.
        """
        
        enhanced_message = self.llm.generate(prompt)
        
        return enhanced_message
    
    def _generate_conversation_id(self) -> str:
        return f"conv_{int(time.time())}_{uuid.uuid4().hex[:8]}"

API Implementation

Document Management API

python
import json

from fastapi import FastAPI, UploadFile, File, HTTPException

app = FastAPI()

# config, vector_db and document_db are assumed to be initialized at
# application startup (see app/core/).

@app.post("/api/documents/upload")
async def upload_document(
    file: UploadFile = File(...),
    metadata: str = None
):
    try:
        # Optional metadata arrives as a JSON string
        meta_dict = json.loads(metadata) if metadata else {}
        
        # Persist the upload to a temp file for the processor
        temp_path = f"/tmp/{file.filename}"
        with open(temp_path, "wb") as buffer:
            content = await file.read()
            buffer.write(content)
        
        processor = DocumentProcessor(config)
        result = processor.process_document(temp_path, meta_dict)
        
        # Index every chunk; keep the chunk text in the payload so
        # retrieval can rebuild context later
        for chunk in result["chunks"]:
            vector_db.add(
                chunk["embedding"],
                {**chunk["metadata"], "content": chunk["content"]}
            )
        
        return {
            "status": "success",
            "document_id": result["document_id"],
            "chunks_processed": len(result["chunks"])
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/api/documents/{document_id}")
async def get_document(document_id: str):
    try:
        document = document_db.get(document_id)
        
        if not document:
            raise HTTPException(status_code=404, detail="Document not found")
        
        return document
    except HTTPException:
        raise  # Don't let the generic handler turn a 404 into a 500
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.delete("/api/documents/{document_id}")
async def delete_document(document_id: str):
    try:
        # Remove both the vector index entries and the metadata record
        vector_db.delete_by_document_id(document_id)
        document_db.delete(document_id)
        
        return {"status": "success"}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

Q&A API

python
from typing import Dict, Optional

from pydantic import BaseModel

# The request model must be defined before the route that references it.
class QARequest(BaseModel):
    question: str
    filters: Optional[Dict] = None
    top_k: int = 5

# retriever, llm and prompt_template are assumed module-level singletons.
@app.post("/api/qa/query")
async def query_qa(request: QARequest):
    try:
        rag_engine = RAGEngine(
            retriever=retriever,
            llm=llm,
            prompt_template=prompt_template,
            config=config
        )
        
        result = rag_engine.query(
            question=request.question,
            filters=request.filters,
            top_k=request.top_k
        )
        
        return result
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

Chat API

python
# A single ChatManager shared across requests: conversation state lives in
# memory on this instance, so creating a new manager per request would
# lose every existing conversation.
chat_manager = ChatManager(rag_engine, llm)

@app.post("/api/chat/conversations")
async def create_conversation(user_id: str):
    try:
        conversation_id = chat_manager.create_conversation(user_id)
        
        return {
            "conversation_id": conversation_id
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.post("/api/chat/{conversation_id}/messages")
async def send_message(conversation_id: str, message: str):
    try:
        result = chat_manager.send_message(conversation_id, message)
        
        return result
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
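
A quick end-to-end check of the chat endpoints using `requests` (assuming the backend runs on localhost:8000; the bare `user_id` and `message` parameters in the routes above are query parameters in FastAPI, and the sample question is illustrative):

python
import requests

BASE = "http://localhost:8000"

# Create a conversation
resp = requests.post(f"{BASE}/api/chat/conversations", params={"user_id": "alice"})
conversation_id = resp.json()["conversation_id"]

# Send a message and print the grounded answer with its sources
resp = requests.post(
    f"{BASE}/api/chat/{conversation_id}/messages",
    params={"message": "What is our remote work policy?"},
)
data = resp.json()
print(data["answer"])
for source in data["sources"]:
    print("-", source["metadata"].get("filename"), f"(score: {source['score']:.2f})")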

Deployment Configuration

Docker Compose

yaml
version: '3.8'

services:
  backend:
    build:
      # Paths are relative to this file (docker/docker-compose.yml);
      # the Dockerfiles live in docker/ per the directory structure above
      context: ..
      dockerfile: docker/Dockerfile.backend
    ports:
      - "8000:8000"
    environment:
      - DATABASE_URL=postgresql://user:password@postgres:5432/ekb
      - REDIS_URL=redis://redis:6379
      - QDRANT_URL=http://qdrant:6333
    depends_on:
      - postgres
      - redis
      - qdrant
  
  frontend:
    build:
      context: ..
      dockerfile: docker/Dockerfile.frontend
    ports:
      - "3000:3000"
    depends_on:
      - backend
  
  postgres:
    image: postgres:15
    environment:
      - POSTGRES_USER=user
      - POSTGRES_PASSWORD=password
      - POSTGRES_DB=ekb
    volumes:
      - postgres_data:/var/lib/postgresql/data
  
  redis:
    image: redis:7
    volumes:
      - redis_data:/data
  
  qdrant:
    image: qdrant/qdrant:latest
    ports:
      - "6333:6333"
    volumes:
      - qdrant_data:/qdrant/storage

volumes:
  postgres_data:
  redis_data:
  qdrant_data:

Performance Optimization

  1. Caching strategy

    • Cache frequent queries in Redis (see the sketch after this list)
    • Cache document embeddings
    • Cache conversation history
  2. Batch processing

    • Batch document uploads
    • Batch vector inserts
    • Batch query handling
  3. Parallel processing

    • Process documents in parallel
    • Run retrieval channels in parallel
    • Serve API responses asynchronously
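
Points 1 and 3 can be combined in a single retrieval path: check Redis first, and on a miss run the vector and keyword searches concurrently. A minimal sketch (assuming the `redis` package; `vector_search`, `keyword_search` and `combine_results` are hypothetical stand-ins for the retriever internals shown earlier):

python
import asyncio
import hashlib
import json

import redis

cache = redis.Redis(host="localhost", port=6379, decode_responses=True)

def _cache_key(query: str, top_k: int) -> str:
    # Stable key for a (query, top_k) pair
    digest = hashlib.sha256(f"{query}:{top_k}".encode()).hexdigest()
    return f"ekb:retrieve:{digest}"

async def cached_retrieve(query: str, top_k: int = 10) -> list:
    key = _cache_key(query, top_k)
    hit = cache.get(key)
    if hit is not None:
        return json.loads(hit)  # Serve repeated queries straight from Redis

    # Run both retrieval channels concurrently in worker threads
    vector_results, keyword_results = await asyncio.gather(
        asyncio.to_thread(vector_search, query, top_k),
        asyncio.to_thread(keyword_search, query, top_k),
    )

    results = combine_results(vector_results, keyword_results)[:top_k]
    cache.setex(key, 300, json.dumps(results))  # cache for five minutes
    return results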

Practice Exercise

Exercise 1: Implement a Complete RAG System

python
class CompleteRAGSystem:
    # vector_db, keyword_index, reranker, llm and prompt_template are
    # assumed to be module-level singletons, as in the API examples above.
    def __init__(self, config):
        self.document_processor = DocumentProcessor(config)
        self.retriever = HybridRetriever(
            vector_db,
            keyword_index,
            reranker,
            config
        )
        self.rag_engine = RAGEngine(
            self.retriever,
            llm,
            prompt_template,
            config
        )
        self.chat_manager = ChatManager(
            self.rag_engine,
            llm
        )
    
    def add_document(self, file_path, metadata=None):
        return self.document_processor.process_document(
            file_path,
            metadata
        )
    
    def query(self, question, filters=None, top_k=5):
        return self.rag_engine.query(question, filters, top_k)
    
    def chat(self, conversation_id, message):
        return self.chat_manager.send_message(conversation_id, message)
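
A hedged usage sketch, assuming those module-level dependencies have been initialized (the document path and questions are illustrative):

python
system = CompleteRAGSystem(config)

# Index a document, then ask a one-off question
system.add_document("docs/security_policy.pdf", {"department": "IT"})
result = system.query("How often must passwords be rotated?")
print(result["answer"], result["confidence"])

# Multi-turn chat against the same knowledge base
conversation_id = system.chat_manager.create_conversation("alice")
reply = system.chat(conversation_id, "Does that apply to contractors too?")
print(reply["answer"])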

Summary

In this lesson we wrapped up the RAG module:

  1. Reviewed the module's core knowledge
  2. Walked through the enterprise knowledge base system project
  3. Completed the project architecture design
  4. Implemented the core features (document processing, retrieval, RAG, chat)
  5. Covered deployment configuration and optimization

Through this project, you will master the complete workflow of building a production-grade RAG system.
