Day 83: Enterprise Knowledge Base - Document Processing System

Learning Objectives

  • Implement document ingestion: upload, validation, and hash-based deduplication
  • Implement document parsing for PDF, Office, HTML, Markdown, and plain text
  • Implement document processing: cleaning, chunking, keyword extraction, and summarization
  • Implement text vectorization and vector storage
  • Implement document and chunk storage

Document Ingestion

Document Upload

python
from fastapi import UploadFile, HTTPException
from typing import Dict, Optional
import os
import hashlib
from datetime import datetime

class DocumentUploader:
    def __init__(self, upload_dir: str = "uploads"):
        self.upload_dir = upload_dir
        os.makedirs(upload_dir, exist_ok=True)
        self.allowed_extensions = {
            ".pdf", ".doc", ".docx", ".xls", ".xlsx",
            ".ppt", ".pptx", ".txt", ".html", ".md"
        }
    
    async def upload_document(
        self,
        file: UploadFile,
        user_id: str,
        metadata: Optional[Dict] = None
    ) -> Dict:
        if not self._is_allowed_file(file.filename):
            raise HTTPException(
                status_code=400,
                detail=f"不支持的文件格式: {file.filename}"
            )
        
        file_hash = await self._calculate_file_hash(file)
        file_path = self._get_file_path(file_hash, file.filename)
        
        try:
            await self._save_file(file, file_path)
            
            document_info = {
                "file_id": file_hash,
                "filename": file.filename,
                "file_path": file_path,
                "file_size": os.path.getsize(file_path),
                "upload_time": datetime.now().isoformat(),
                "user_id": user_id,
                "metadata": metadata or {}
            }
            
            return document_info
        
        except Exception as e:
            if os.path.exists(file_path):
                os.remove(file_path)
            
            raise HTTPException(
                status_code=500,
                detail=f"文件上传失败: {str(e)}"
            )
    
    def _is_allowed_file(self, filename: str) -> bool:
        return any(
            filename.lower().endswith(ext)
            for ext in self.allowed_extensions
        )
    
    async def _calculate_file_hash(self, file: UploadFile) -> str:
        hash_sha256 = hashlib.sha256()
        
        while chunk := await file.read(8192):
            hash_sha256.update(chunk)
        
        await file.seek(0)
        
        return hash_sha256.hexdigest()
    
    def _get_file_path(self, file_hash: str, filename: str) -> str:
        ext = os.path.splitext(filename)[1]
        return os.path.join(self.upload_dir, f"{file_hash}{ext}")
    
    async def _save_file(self, file: UploadFile, file_path: str):
        with open(file_path, "wb") as buffer:
            while chunk := await file.read(8192):
                buffer.write(chunk)

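A minimal sketch of exposing the uploader through a FastAPI route; the app instance, endpoint path, and default user_id below are illustrative, not part of the class above:

python
from fastapi import FastAPI, UploadFile

app = FastAPI()
uploader = DocumentUploader(upload_dir="uploads")

@app.post("/documents")  # hypothetical endpoint path
async def upload_endpoint(file: UploadFile, user_id: str = "demo-user"):
    # Returns the document_info dict, or raises HTTPException on failure
    return await uploader.upload_document(file, user_id=user_id)
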
Document Parsing

python
from typing import Dict
import os
import PyPDF2
import docx
import openpyxl
from pptx import Presentation
from bs4 import BeautifulSoup
import markdown

class DocumentParser:
    def __init__(self):
        self.parsers = {
            ".pdf": self._parse_pdf,
            ".doc": self._parse_doc,
            ".docx": self._parse_docx,
            ".xls": self._parse_xls,
            ".xlsx": self._parse_xlsx,
            ".ppt": self._parse_ppt,
            ".pptx": self._parse_pptx,
            ".txt": self._parse_txt,
            ".html": self._parse_html,
            ".md": self._parse_markdown
        }
    
    async def parse_document(
        self,
        file_path: str,
        filename: str
    ) -> Dict:
        ext = self._get_file_extension(filename)
        
        if ext not in self.parsers:
            raise ValueError(f"不支持的文件格式: {ext}")
        
        try:
            parser = self.parsers[ext]
            result = await parser(file_path)
            
            return {
                "success": True,
                "text": result["text"],
                "metadata": result.get("metadata", {}),
                "pages": result.get("pages", 1)
            }
        
        except Exception as e:
            return {
                "success": False,
                "error": str(e)
            }
    
    def _get_file_extension(self, filename: str) -> str:
        return os.path.splitext(filename)[1].lower()
    
    async def _parse_pdf(self, file_path: str) -> Dict:
        text = []
        
        with open(file_path, 'rb') as file:
            pdf_reader = PyPDF2.PdfReader(file)
            num_pages = len(pdf_reader.pages)
            
            # Extract text while the file handle is still open:
            # PyPDF2 reads page content lazily from the underlying stream
            for page in pdf_reader.pages:
                text.append(page.extract_text())
        
        return {
            "text": "\n\n".join(text),
            "metadata": {
                "pages": num_pages
            },
            "pages": num_pages
        }
    
    async def _parse_docx(self, file_path: str) -> Dict:
        doc = docx.Document(file_path)
        
        text = []
        for paragraph in doc.paragraphs:
            text.append(paragraph.text)
        
        return {
            "text": "\n".join(text),
            "metadata": {
                "author": doc.core_properties.author,
                "created": str(doc.core_properties.created),
                "modified": str(doc.core_properties.modified)
            }
        }
    
    async def _parse_xlsx(self, file_path: str) -> Dict:
        workbook = openpyxl.load_workbook(file_path)
        
        text = []
        for sheet_name in workbook.sheetnames:
            sheet = workbook[sheet_name]
            
            for row in sheet.iter_rows(values_only=True):
                row_text = "\t".join(
                    str(cell) if cell is not None else ""
                    for cell in row
                )
                text.append(row_text)
        
        return {
            "text": "\n".join(text),
            "metadata": {
                "sheets": len(workbook.sheetnames)
            }
        }
    
    async def _parse_pptx(self, file_path: str) -> Dict:
        prs = Presentation(file_path)
        
        text = []
        for slide in prs.slides:
            for shape in slide.shapes:
                if hasattr(shape, "text"):
                    text.append(shape.text)
        
        return {
            "text": "\n\n".join(text),
            "metadata": {
                "slides": len(prs.slides)
            }
        }
    
    async def _parse_txt(self, file_path: str) -> Dict:
        with open(file_path, 'r', encoding='utf-8') as file:
            text = file.read()
        
        return {
            "text": text,
            "metadata": {}
        }
    
    async def _parse_html(self, file_path: str) -> Dict:
        with open(file_path, 'r', encoding='utf-8') as file:
            html = file.read()
        
        soup = BeautifulSoup(html, 'html.parser')
        text = soup.get_text(separator='\n', strip=True)
        
        return {
            "text": text,
            "metadata": {
                "title": soup.title.string if soup.title else ""
            }
        }
    
    async def _parse_markdown(self, file_path: str) -> Dict:
        with open(file_path, 'r', encoding='utf-8') as file:
            md = file.read()
        
        html = markdown.markdown(md)
        soup = BeautifulSoup(html, 'html.parser')
        text = soup.get_text(separator='\n', strip=True)
        
        return {
            "text": text,
            "metadata": {}
        }
    
    async def _parse_doc(self, file_path: str) -> Dict:
        raise NotImplementedError("Legacy .doc files require special handling")
    
    async def _parse_xls(self, file_path: str) -> Dict:
        raise NotImplementedError("Legacy .xls files require special handling")
    
    async def _parse_ppt(self, file_path: str) -> Dict:
        raise NotImplementedError("Legacy .ppt files require special handling")

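A quick way to exercise the parser from a script; the file path is a placeholder:

python
import asyncio

async def demo_parse():
    parser = DocumentParser()
    result = await parser.parse_document("uploads/example.pdf", "example.pdf")
    if result["success"]:
        print(f"{result['pages']} page(s), {len(result['text'])} characters")
    else:
        print("Parse failed:", result["error"])

asyncio.run(demo_parse())
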
Document Processing

Text Cleaning

python
import re
from typing import List

class TextCleaner:
    def __init__(self):
        self.cleaning_rules = [
            self._remove_page_numbers,     # needs intact newlines, so it runs first
            self._remove_special_chars,
            self._normalize_quotes,
            self._normalize_dashes,
            self._remove_extra_whitespace  # collapses newlines, so it runs last
        ]
    
    async def clean_text(self, text: str) -> str:
        cleaned_text = text
        
        for rule in self.cleaning_rules:
            cleaned_text = rule(cleaned_text)
        
        return cleaned_text
    
    def _remove_extra_whitespace(self, text: str) -> str:
        return re.sub(r'\s+', ' ', text).strip()
    
    def _remove_special_chars(self, text: str) -> str:
        return re.sub(r'[\x00-\x08\x0b-\x0c\x0e-\x1f\x7f-\x9f]', '', text)
    
    def _normalize_quotes(self, text: str) -> str:
        # Replace curly quotes with their straight ASCII equivalents
        text = text.replace('\u201c', '"').replace('\u201d', '"')
        text = text.replace('\u2018', "'").replace('\u2019', "'")
        return text
    
    def _normalize_dashes(self, text: str) -> str:
        text = text.replace('–', '-').replace('—', '-')
        return text
    
    def _remove_page_numbers(self, text: str) -> str:
        return re.sub(r'\n\s*\d+\s*\n', '\n', text)

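For example, with the rule order above, a stray page number between newlines is stripped before the whitespace collapse; the input string here is made up:

python
import asyncio

async def demo_clean():
    cleaner = TextCleaner()
    raw = "First   page text.\n 12 \nSecond page \u2014 with curly \u201cquotes\u201d."
    print(await cleaner.clean_text(raw))
    # -> First page text. Second page - with curly "quotes".

asyncio.run(demo_clean())
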
Text Chunking

python
import re
from typing import Dict, List, Optional

class TextChunker:
    def __init__(
        self,
        chunk_size: int = 500,
        chunk_overlap: int = 50
    ):
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
    
    async def chunk_text(
        self,
        text: str,
        metadata: Optional[Dict] = None
    ) -> List[Dict]:
        chunks = []
        
        sentences = self._split_into_sentences(text)
        
        current_chunk = ""
        chunk_index = 0
        
        for sentence in sentences:
            if len(current_chunk) + len(sentence) <= self.chunk_size:
                current_chunk += sentence + " "
            else:
                if current_chunk:
                    chunks.append({
                        "text": current_chunk.strip(),
                        "chunk_index": chunk_index,
                        "metadata": metadata or {}
                    })
                    chunk_index += 1
                
                # Carry the tail of the previous chunk forward so adjacent
                # chunks share chunk_overlap characters of context
                overlap = current_chunk[-self.chunk_overlap:] if self.chunk_overlap > 0 else ""
                current_chunk = overlap + sentence + " "
        
        if current_chunk:
            chunks.append({
                "text": current_chunk.strip(),
                "chunk_index": chunk_index,
                "metadata": metadata or {}
            })
        
        return chunks
    
    def _split_into_sentences(self, text: str) -> List[str]:
        sentences = re.split(r'(?<=[.!?])\s+', text)
        return [s.strip() for s in sentences if s.strip()]
    
    async def chunk_by_paragraph(
        self,
        text: str,
        metadata: Optional[Dict] = None
    ) -> List[Dict]:
        paragraphs = text.split('\n\n')
        
        chunks = []
        chunk_index = 0
        
        for paragraph in paragraphs:
            paragraph = paragraph.strip()
            
            if not paragraph:
                continue
            
            if len(paragraph) > self.chunk_size:
                sub_chunks = await self.chunk_text(paragraph, metadata)
                chunks.extend(sub_chunks)
            else:
                chunks.append({
                    "text": paragraph,
                    "chunk_index": chunk_index,
                    "metadata": metadata or {}
                })
                chunk_index += 1
        
        return chunks

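A small sketch of sentence-based chunking; the sizes are deliberately tiny so the overlap is visible:

python
import asyncio

async def demo_chunk():
    chunker = TextChunker(chunk_size=60, chunk_overlap=15)
    text = "First sentence here. Second sentence here. Third one. Fourth one."
    for chunk in await chunker.chunk_text(text, metadata={"source": "demo"}):
        print(chunk["chunk_index"], repr(chunk["text"]))

asyncio.run(demo_chunk())
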
Keyword Extraction

python
import jieba
import jieba.analyse
from typing import Dict, List, Optional

class KeywordExtractor:
    def __init__(self, top_k: int = 10):
        self.top_k = top_k
        jieba.initialize()
    
    async def extract_keywords(
        self,
        text: str,
        top_k: Optional[int] = None
    ) -> List[str]:
        top_k = top_k or self.top_k
        
        keywords = jieba.analyse.extract_tags(
            text,
            topK=top_k,
            withWeight=True
        )
        
        return [keyword for keyword, weight in keywords]
    
    async def extract_keywords_with_scores(
        self,
        text: str,
        top_k: Optional[int] = None
    ) -> List[Dict]:
        top_k = top_k or self.top_k
        
        keywords = jieba.analyse.extract_tags(
            text,
            topK=top_k,
            withWeight=True
        )
        
        return [
            {
                "keyword": keyword,
                "score": weight
            }
            for keyword, weight in keywords
        ]

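jieba's TF-IDF extractor is designed for Chinese text, so a Chinese sample makes the clearest demo:

python
import asyncio

async def demo_keywords():
    extractor = KeywordExtractor(top_k=5)
    text = "企业知识库需要对上传的文档进行解析、清洗、分块和向量化处理"
    print(await extractor.extract_keywords(text))

asyncio.run(demo_keywords())
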
Summary Generation

python
import openai
from typing import Dict

class SummaryGenerator:
    def __init__(self, llm_client: openai.OpenAI):
        self.llm_client = llm_client
    
    async def generate_summary(
        self,
        text: str,
        max_length: int = 300
    ) -> Dict:
        prompt = f"""请为以下文本生成摘要,摘要长度不超过{max_length}字:

文本:
{text}

摘要:"""
        
        try:
            completion = self.llm_client.chat.completions.create(
                model="gpt-4o",
                messages=[
                    {"role": "system", "content": "你是一个专业的摘要生成器"},
                    {"role": "user", "content": prompt}
                ],
                temperature=0.3,
                max_tokens=max_length * 2
            )
            
            summary = completion.choices[0].message.content
            
            return {
                "success": True,
                "summary": summary
            }
        
        except Exception as e:
            return {
                "success": False,
                "error": str(e)
            }

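Calling the generator needs a configured OpenAI client; reading the key from the environment here is just one option:

python
import asyncio
import os
import openai

async def demo_summary():
    client = openai.OpenAI(api_key=os.environ["OPENAI_API_KEY"])
    generator = SummaryGenerator(client)
    result = await generator.generate_summary("...long document text...", max_length=100)
    print(result["summary"] if result["success"] else result["error"])

asyncio.run(demo_summary())
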
Vectorization

Text Vectorization

python
import openai
from typing import List, Dict
import numpy as np

class TextVectorizer:
    def __init__(
        self,
        llm_client: openai.OpenAI,
        model: str = "text-embedding-3-small"
    ):
        self.llm_client = llm_client
        self.model = model
    
    async def create_embedding(
        self,
        text: str
    ) -> List[float]:
        try:
            response = self.llm_client.embeddings.create(
                model=self.model,
                input=text
            )
            
            embedding = response.data[0].embedding
            
            return embedding
        
        except Exception as e:
            raise Exception(f"向量化失败: {str(e)}")
    
    async def create_embeddings(
        self,
        texts: List[str]
    ) -> List[List[float]]:
        embeddings = []
        
        for text in texts:
            embedding = await self.create_embedding(text)
            embeddings.append(embedding)
        
        return embeddings
    
    async def create_embeddings_batch(
        self,
        texts: List[str],
        batch_size: int = 100
    ) -> List[List[float]]:
        all_embeddings = []
        
        for i in range(0, len(texts), batch_size):
            batch = texts[i:i + batch_size]
            
            try:
                response = self.llm_client.embeddings.create(
                    model=self.model,
                    input=batch
                )
                
                batch_embeddings = [
                    item.embedding
                    for item in response.data
                ]
                
                all_embeddings.extend(batch_embeddings)
            
            except Exception as e:
                raise Exception(f"批量向量化失败: {str(e)}")
        
        return all_embeddings
    
    def calculate_similarity(
        self,
        embedding1: List[float],
        embedding2: List[float]
    ) -> float:
        vec1 = np.array(embedding1)
        vec2 = np.array(embedding2)
        
        dot_product = np.dot(vec1, vec2)
        norm1 = np.linalg.norm(vec1)
        norm2 = np.linalg.norm(vec2)
        
        similarity = dot_product / (norm1 * norm2)
        
        return float(similarity)

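A sketch that embeds two short strings and compares them with the cosine-similarity helper; it assumes OPENAI_API_KEY is set in the environment:

python
import asyncio
import os
import openai

async def demo_similarity():
    client = openai.OpenAI(api_key=os.environ["OPENAI_API_KEY"])
    vectorizer = TextVectorizer(client)
    emb1, emb2 = await vectorizer.create_embeddings(
        ["document ingestion pipeline", "uploading and parsing documents"]
    )
    print("cosine similarity:", vectorizer.calculate_similarity(emb1, emb2))

asyncio.run(demo_similarity())
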
Vector Storage

python
import chromadb
from typing import List, Dict, Optional

class VectorStore:
    def __init__(
        self,
        collection_name: str = "documents",
        persist_directory: Optional[str] = None
    ):
        # PersistentClient keeps the collection on disk; the in-memory
        # Client is used when no directory is given (chromadb >= 0.4 API)
        if persist_directory:
            self.client = chromadb.PersistentClient(path=persist_directory)
        else:
            self.client = chromadb.Client()
        
        self.collection = self.client.get_or_create_collection(
            name=collection_name
        )
    
    async def add_documents(
        self,
        documents: List[Dict]
    ) -> Dict:
        ids = []
        embeddings = []
        texts = []
        metadatas = []
        
        for doc in documents:
            ids.append(doc["id"])
            embeddings.append(doc["embedding"])
            texts.append(doc["text"])
            metadatas.append(doc.get("metadata", {}))
        
        try:
            self.collection.add(
                ids=ids,
                embeddings=embeddings,
                documents=texts,
                metadatas=metadatas
            )
            
            return {
                "success": True,
                "count": len(ids)
            }
        
        except Exception as e:
            return {
                "success": False,
                "error": str(e)
            }
    
    async def search(
        self,
        query_embedding: List[float],
        top_k: int = 5,
        filter_metadata: Optional[Dict] = None
    ) -> List[Dict]:
        try:
            results = self.collection.query(
                query_embeddings=[query_embedding],
                n_results=top_k,
                where=filter_metadata
            )
            
            documents = []
            
            for i in range(len(results["ids"][0])):
                documents.append({
                    "id": results["ids"][0][i],
                    "text": results["documents"][0][i],
                    "metadata": results["metadatas"][0][i],
                    "distance": results["distances"][0][i]
                })
            
            return documents
        
        except Exception as e:
            raise Exception(f"搜索失败: {str(e)}")
    
    async def delete_document(self, document_id: str) -> Dict:
        try:
            self.collection.delete(ids=[document_id])
            
            return {
                "success": True
            }
        
        except Exception as e:
            return {
                "success": False,
                "error": str(e)
            }
    
    async def update_document(
        self,
        document_id: str,
        embedding: List[float],
        text: str,
        metadata: Optional[Dict] = None
    ) -> Dict:
        try:
            self.collection.update(
                ids=[document_id],
                embeddings=[embedding],
                documents=[text],
                metadatas=[metadata or {}]
            )
            
            return {
                "success": True
            }
        
        except Exception as e:
            return {
                "success": False,
                "error": str(e)
            }

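A minimal round trip through the store; the three-dimensional vectors are toy values, and real embeddings have hundreds or thousands of dimensions:

python
import asyncio

async def demo_store():
    store = VectorStore(collection_name="demo")
    await store.add_documents([
        {"id": "doc-1", "embedding": [0.1, 0.2, 0.3], "text": "first chunk",
         "metadata": {"source": "demo"}},
        {"id": "doc-2", "embedding": [0.3, 0.2, 0.1], "text": "second chunk",
         "metadata": {"source": "demo"}},
    ])
    hits = await store.search(query_embedding=[0.1, 0.2, 0.3], top_k=1)
    print(hits[0]["id"], hits[0]["distance"])

asyncio.run(demo_store())
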
Document Storage

SQLite Storage

python
import sqlite3
import json
from typing import Dict, List, Optional

class DocumentStorage:
    def __init__(self, db_path: str = "documents.db"):
        self.db_path = db_path
        self._initialize_db()
    
    def _initialize_db(self):
        conn = sqlite3.connect(self.db_path)
        cursor = conn.cursor()
        
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS documents (
                id TEXT PRIMARY KEY,
                filename TEXT NOT NULL,
                file_path TEXT NOT NULL,
                file_size INTEGER,
                upload_time TEXT,
                user_id TEXT,
                metadata TEXT,
                summary TEXT,
                keywords TEXT,
                status TEXT DEFAULT 'processing'
            )
        """)
        
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS chunks (
                id TEXT PRIMARY KEY,
                document_id TEXT NOT NULL,
                chunk_index INTEGER,
                text TEXT NOT NULL,
                embedding BLOB,
                metadata TEXT,
                FOREIGN KEY (document_id) REFERENCES documents (id)
            )
        """)
        
        conn.commit()
        conn.close()
    
    async def add_document(
        self,
        document: Dict
    ) -> Dict:
        conn = sqlite3.connect(self.db_path)
        cursor = conn.cursor()
        
        try:
            cursor.execute("""
                INSERT INTO documents (
                    id, filename, file_path, file_size,
                    upload_time, user_id, metadata,
                    summary, keywords, status
                ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            """, (
                document["file_id"],
                document["filename"],
                document["file_path"],
                document["file_size"],
                document["upload_time"],
                document["user_id"],
                json.dumps(document.get("metadata", {})),
                document.get("summary", ""),
                json.dumps(document.get("keywords", [])),
                document.get("status", "processing")
            ))
            
            conn.commit()
            
            return {
                "success": True,
                "document_id": document["file_id"]
            }
        
        except Exception as e:
            conn.rollback()
            return {
                "success": False,
                "error": str(e)
            }
        
        finally:
            conn.close()
    
    async def get_document(
        self,
        document_id: str
    ) -> Optional[Dict]:
        conn = sqlite3.connect(self.db_path)
        cursor = conn.cursor()
        
        try:
            cursor.execute("""
                SELECT * FROM documents WHERE id = ?
            """, (document_id,))
            
            row = cursor.fetchone()
            
            if row:
                return {
                    "id": row[0],
                    "filename": row[1],
                    "file_path": row[2],
                    "file_size": row[3],
                    "upload_time": row[4],
                    "user_id": row[5],
                    "metadata": json.loads(row[6]),
                    "summary": row[7],
                    "keywords": json.loads(row[8]),
                    "status": row[9]
                }
            
            return None
        
        finally:
            conn.close()
    
    async def update_document_status(
        self,
        document_id: str,
        status: str
    ) -> Dict:
        conn = sqlite3.connect(self.db_path)
        cursor = conn.cursor()
        
        try:
            cursor.execute("""
                UPDATE documents SET status = ?
                WHERE id = ?
            """, (status, document_id))
            
            conn.commit()
            
            return {
                "success": True
            }
        
        except Exception as e:
            conn.rollback()
            return {
                "success": False,
                "error": str(e)
            }
        
        finally:
            conn.close()
    
    async def add_chunk(
        self,
        chunk: Dict
    ) -> Dict:
        conn = sqlite3.connect(self.db_path)
        cursor = conn.cursor()
        
        try:
            cursor.execute("""
                INSERT INTO chunks (
                    id, document_id, chunk_index, text,
                    embedding, metadata
                ) VALUES (?, ?, ?, ?, ?, ?)
            """, (
                chunk["id"],
                chunk["document_id"],
                chunk["chunk_index"],
                chunk["text"],
                json.dumps(chunk["embedding"]),
                json.dumps(chunk.get("metadata", {}))
            ))
            
            conn.commit()
            
            return {
                "success": True,
                "chunk_id": chunk["id"]
            }
        
        except Exception as e:
            conn.rollback()
            return {
                "success": False,
                "error": str(e)
            }
        
        finally:
            conn.close()

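To tie the pieces together, here is a sketch of one possible end-to-end ingestion flow. It assumes the classes above, an OPENAI_API_KEY in the environment, and the document_info dict returned by DocumentUploader.upload_document; the ingest_document function itself is illustrative, and error handling is kept minimal:

python
import os
import uuid
import openai

async def ingest_document(document_info: dict) -> dict:
    # Illustrative wiring: parse -> clean -> chunk -> embed -> persist
    parser = DocumentParser()
    cleaner = TextCleaner()
    chunker = TextChunker()
    vectorizer = TextVectorizer(openai.OpenAI(api_key=os.environ["OPENAI_API_KEY"]))
    vector_store = VectorStore()
    storage = DocumentStorage()

    await storage.add_document(document_info)

    parsed = await parser.parse_document(
        document_info["file_path"], document_info["filename"]
    )
    if not parsed["success"]:
        await storage.update_document_status(document_info["file_id"], "failed")
        return parsed

    text = await cleaner.clean_text(parsed["text"])
    chunks = await chunker.chunk_text(
        text, metadata={"document_id": document_info["file_id"]}
    )
    embeddings = await vectorizer.create_embeddings_batch(
        [chunk["text"] for chunk in chunks]
    )

    for chunk, embedding in zip(chunks, embeddings):
        chunk_id = str(uuid.uuid4())
        await storage.add_chunk({
            "id": chunk_id,
            "document_id": document_info["file_id"],
            "chunk_index": chunk["chunk_index"],
            "text": chunk["text"],
            "embedding": embedding,
            "metadata": chunk["metadata"],
        })
        await vector_store.add_documents([{
            "id": chunk_id,
            "embedding": embedding,
            "text": chunk["text"],
            "metadata": chunk["metadata"],
        }])

    await storage.update_document_status(document_info["file_id"], "completed")
    return {"success": True, "chunks": len(chunks)}
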
Hands-On Exercises

Exercise 1: Implement Document Ingestion

python
def implement_document_ingestion():
    uploader = DocumentUploader()
    parser = DocumentParser()
    
    return uploader, parser

Exercise 2: Implement Document Processing

python
def implement_document_processing():
    cleaner = TextCleaner()
    chunker = TextChunker()
    keyword_extractor = KeywordExtractor()
    
    return cleaner, chunker, keyword_extractor

Exercise 3: Implement Vectorization

python
def implement_vectorization():
    llm_client = openai.OpenAI(api_key="your-api-key")
    vectorizer = TextVectorizer(llm_client)
    vector_store = VectorStore()
    
    return vectorizer, vector_store

Summary

In this lesson we built the document processing system for an enterprise knowledge base:

  1. Document ingestion: upload, validation, and content-hash deduplication
  2. Document parsing for PDF, Office, HTML, Markdown, and plain text
  3. Document processing: cleaning, chunking, keyword extraction, and summarization
  4. Vectorization: embeddings, similarity, and a Chroma vector store
  5. Document and chunk storage in SQLite

Document processing is the foundation of a knowledge base: documents in many formats must be parsed efficiently and accurately before anything downstream can work.
