
Day 57: Fine-tuning Module Summary and Project

Learning Objectives

  • Review the core knowledge of the fine-tuning module
  • Understand the domain-specific model project
  • Complete the project architecture design
  • Implement the core features
  • Deploy and optimize the system

Module Knowledge Summary

Core Fine-tuning Concepts

Fine-tuning Workflow

Pre-trained model → Load → Prepare data → Fine-tune → Evaluate → Deploy

Key Techniques

  1. Fine-tuning methods: full fine-tuning, partial fine-tuning, instruction tuning
  2. Efficient fine-tuning: LoRA, QLoRA, AdaLoRA
  3. Data processing: collection, cleaning, formatting, augmentation
  4. Training optimization: hyperparameter tuning, monitoring, evaluation
  5. Deployment optimization: quantization, inference optimization, monitoring

Comparison of Fine-tuning Methods

| Method | Trainable Params | Memory Footprint | Typical Use Case |
| --- | --- | --- | --- |
| Full fine-tuning | 100% | Highest | Large-scale data |
| Partial fine-tuning | 10-30% | — | Tuning specific layers |
| LoRA | <1% | — | Parameter-efficient tuning |
| QLoRA | <1% | Very low | Memory-constrained setups |
| AdaLoRA | <1% | — | Adaptive rank allocation |
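The "<1%" figures for the LoRA family follow directly from the adapter construction: a frozen d×k weight matrix gets a low-rank update B·A with only r·(d+k) trainable parameters. A quick back-of-the-envelope check (the 4096×4096 shape is illustrative, not taken from a specific model):

```python
def lora_trainable_fraction(d: int, k: int, r: int) -> float:
    """Fraction of a single d x k weight matrix that LoRA trains at rank r."""
    full = d * k          # frozen base weight
    lora = r * (d + k)    # A is r x k, B is d x r
    return lora / full

# A 4096 x 4096 attention projection at rank 8:
frac = lora_trainable_fraction(4096, 4096, 8)
print(f"{frac:.2%}")  # → 0.39%
```

Doubling the rank doubles the trainable fraction, which is why rank is the main cost/quality dial.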

Best Practices

  1. Data preparation

    • Collect high-quality domain data
    • Clean the data thoroughly
    • Expand the dataset with augmentation
    • Assess data quality
  2. Fine-tuning strategy

    • Start with small-scale experiments
    • Use LoRA/QLoRA to cut costs
    • Monitor the training process
    • Evaluate model performance thoroughly
  3. Deployment optimization

    • Quantize to reduce memory usage
    • Add caching to improve response speed
    • Monitor system performance
    • Implement autoscaling

Hands-on Project: Domain-Specific Model

Project Overview

Project name: Domain-Specific Fine-tuned Model (DSFM)

Project Description

Build a large model fine-tuned for a specific domain (e.g., medical, legal, or financial) so that it performs strongly on in-domain tasks.

Tech Stack

  • Fine-tuning framework: PEFT + Transformers
  • Fine-tuning method: QLoRA
  • Training tooling: DeepSpeed
  • Evaluation and tracking: Weights & Biases
  • Deployment: vLLM + FastAPI

System Architecture

┌─────────────────────────────────────────────────┐
│                 Frontend Layer                  │
│  - Model test UI                                │
│  - Performance monitoring dashboard             │
│  - Admin console                                │
└───────────────────┬─────────────────────────────┘
                    ↓
┌─────────────────────────────────────────────────┐
│                    API Layer                    │
│  - Generation API                               │
│  - Evaluation API                               │
│  - Management API                               │
└───────────────────┬─────────────────────────────┘
                    ↓
┌─────────────────────────────────────────────────┐
│                  Service Layer                  │
│  - Model loader                                 │
│  - Generation engine                            │
│  - Evaluation engine                            │
│  - Cache manager                                │
└───────────────────┬─────────────────────────────┘
                    ↓
┌─────────────────────────────────────────────────┐
│                   Model Layer                   │
│  - Pre-trained model                            │
│  - LoRA adapters                                │
│  - Quantized model                              │
└─────────────────────────────────────────────────┘

Project Directory Structure

domain-specific-model/
├── data/
│   ├── raw/
│   ├── processed/
│   └── augmented/
├── models/
│   ├── base/
│   ├── lora/
│   └── quantized/
├── training/
│   ├── config/
│   ├── scripts/
│   └── notebooks/
├── evaluation/
│   ├── metrics/
│   └── results/
├── deployment/
│   ├── api/
│   ├── docker/
│   └── kubernetes/
├── monitoring/
│   ├── logs/
│   └── metrics/
└── tests/
    ├── unit/
    └── integration/

Core Feature Implementation

1. Data Preparer

python
from typing import Dict, List

# DataCollector, TextCleaner, DataAugmenter, and InstructionFormatter are the
# helper classes built in earlier lessons of this module.

class DomainDataPreparer:
    def __init__(self, domain: str, config: Dict):
        self.domain = domain
        self.config = config
        self.collector = DataCollector(config)
        self.cleaner = TextCleaner(config)
        self.augmenter = DataAugmenter(config)
        self.formatter = InstructionFormatter()
    
    def prepare_data(self, sources: List[str]) -> List[str]:
        collected = self._collect_data(sources)
        cleaned = self._clean_data(collected)
        augmented = self._augment_data(cleaned)
        formatted = self._format_data(augmented)
        
        return formatted
    
    def _collect_data(self, sources: List[str]) -> List[Dict]:
        collected = []
        
        for source in sources:
            if source.startswith("http"):
                data = self.collector.collect_from_api(source)
            else:
                data = self.collector.collect_from_files(source)
            
            collected.extend(data)
        
        return collected
    
    def _clean_data(self, data: List[Dict]) -> List[Dict]:
        cleaned = []
        
        for item in data:
            text = str(item.get("data", item))
            cleaned_text = self.cleaner.clean(text)
            
            if len(cleaned_text.split()) >= self.config.get("min_length", 10):
                cleaned.append({
                    "data": {"content": cleaned_text},
                    "metadata": item.get("metadata", {})
                })
        
        return cleaned
    
    def _augment_data(self, data: List[Dict]) -> List[Dict]:
        augmented = []
        
        for item in data:
            text = item["data"]["content"]
            augmented_texts = self.augmenter.augment(
                text,
                n_augmentations=self.config.get("n_augmentations", 2)
            )
            
            for aug_text in augmented_texts:
                augmented.append({
                    "data": {"content": aug_text},
                    "metadata": item["metadata"]
                })
        
        return augmented
    
    def _format_data(self, data: List[Dict]) -> List[str]:
        formatted = []
        
        for item in data:
            content = item["data"]["content"]
            
            instruction = self._create_instruction(content)
            formatted_item = {
                "instruction": instruction,
                "input": "",
                "output": content
            }
            
            formatted_text = self.formatter.format(formatted_item)
            formatted.append(formatted_text)
        
        return formatted
    
    def _create_instruction(self, content: str) -> str:
        instructions = {
            "medical": "Provide medical information:",
            "legal": "Provide legal information:",
            "financial": "Provide financial information:",
            "general": "Provide information:"
        }
        
        return instructions.get(self.domain, instructions["general"])

2. QLoRA Fine-tuner

python
import torch
from typing import Dict
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
    Trainer
)
from peft import LoraConfig, get_peft_model

class QLoRATuner:
    def __init__(self, model_name: str, domain: str, config: Dict):
        self.model_name = model_name
        self.domain = domain
        self.config = config
        self.model, self.tokenizer = self._load_model()
        self.model = self._apply_lora()
    
    def _load_model(self):
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16
        )
        
        model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            quantization_config=bnb_config,
            device_map="auto"
        )
        
        tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        tokenizer.pad_token = tokenizer.eos_token
        
        return model, tokenizer
    
    def _apply_lora(self):
        lora_config = LoraConfig(
            r=self.config.get("lora_rank", 8),
            lora_alpha=self.config.get("lora_alpha", 32),
            target_modules=self.config.get("target_modules", ["q_proj", "v_proj"]),
            lora_dropout=self.config.get("lora_dropout", 0.05),
            bias="none",
            task_type="CAUSAL_LM"
        )
        
        model = get_peft_model(self.model, lora_config)
        model.print_trainable_parameters()
        
        return model
    
    def fine_tune(self, train_dataset, eval_dataset):
        training_args = TrainingArguments(
            output_dir=self.config.get("output_dir", f"./output/{self.domain}"),
            num_train_epochs=self.config.get("num_epochs", 3),
            per_device_train_batch_size=self.config.get("batch_size", 4),
            per_device_eval_batch_size=self.config.get("batch_size", 4),
            gradient_accumulation_steps=self.config.get("gradient_accumulation_steps", 4),
            learning_rate=self.config.get("learning_rate", 2e-4),
            weight_decay=self.config.get("weight_decay", 0.01),
            warmup_steps=self.config.get("warmup_steps", 500),
            logging_steps=self.config.get("logging_steps", 10),
            save_steps=self.config.get("save_steps", 500),
            evaluation_strategy="steps",
            eval_steps=self.config.get("eval_steps", 500),
            bf16=True,  # match bnb_4bit_compute_dtype above
            gradient_checkpointing=True,
            optim="paged_adamw_32bit"
        )
        
        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            tokenizer=self.tokenizer
        )
        
        trainer.train()
        
        return self.model
    
    def save_model(self, output_dir: str):
        self.model.save_pretrained(output_dir)
        self.tokenizer.save_pretrained(output_dir)

3. Model Evaluator

python
import math
import torch
from typing import Dict, List

class DomainModelEvaluator:
    def __init__(self, model, tokenizer, domain: str):
        self.model = model
        self.tokenizer = tokenizer
        self.domain = domain
    
    def evaluate(self, test_dataset) -> Dict:
        metrics = {}
        
        metrics["perplexity"] = self._calculate_perplexity(test_dataset)
        metrics["generation_quality"] = self._evaluate_generation(test_dataset)
        metrics["domain_accuracy"] = self._evaluate_domain_accuracy(test_dataset)
        
        return metrics
    
    def _calculate_perplexity(self, test_dataset) -> float:
        self.model.eval()
        
        total_loss = 0
        total_tokens = 0
        
        with torch.no_grad():
            for batch in test_dataset:
                outputs = self.model(**batch)
                loss = outputs.loss
                total_loss += loss.item() * batch["input_ids"].numel()
                total_tokens += batch["input_ids"].numel()
        
        avg_loss = total_loss / total_tokens
        perplexity = math.exp(avg_loss)
        
        return perplexity
    
    def _evaluate_generation(self, test_dataset) -> Dict:
        generations = []
        
        for batch in test_dataset:
            input_ids = batch["input_ids"]
            
            with torch.no_grad():
                outputs = self.model.generate(
                    input_ids,
                    max_new_tokens=50,
                    do_sample=True,
                    temperature=0.7
                )
            
            generated_text = self.tokenizer.decode(
                outputs[0],
                skip_special_tokens=True
            )
            
            generations.append(generated_text)
        
        metrics = {
            "avg_length": self._calculate_avg_length(generations),
            "diversity": self._calculate_diversity(generations),
            "fluency": self._calculate_fluency(generations)
        }
        
        return metrics
    
    def _evaluate_domain_accuracy(self, test_dataset) -> Dict:
        correct = 0
        total = 0
        
        for batch in test_dataset:
            input_ids = batch["input_ids"]
            labels = batch["labels"]
            
            with torch.no_grad():
                outputs = self.model(input_ids)
                predictions = torch.argmax(outputs.logits, dim=-1)
            
            correct += (predictions == labels).sum().item()
            total += labels.numel()
        
        accuracy = correct / total if total > 0 else 0
        
        return {"accuracy": accuracy}
    
    def _calculate_avg_length(self, generations: List[str]) -> float:
        lengths = [len(gen.split()) for gen in generations]
        return sum(lengths) / len(lengths) if lengths else 0
    
    def _calculate_diversity(self, generations: List[str]) -> float:
        all_words = set()
        
        for gen in generations:
            words = gen.split()
            all_words.update(words)
        
        total_words = sum(len(gen.split()) for gen in generations)
        
        return len(all_words) / total_words if total_words > 0 else 0
    
    def _calculate_fluency(self, generations: List[str]) -> float:
        fluency_scores = []
        
        for gen in generations:
            score = 0.0
            
            if len(gen.split()) >= 5:
                score += 0.4
            
            if any(char in gen for char in '.!?'):
                score += 0.3
            
            if len(gen.split('\n')) > 1:
                score += 0.3
            
            fluency_scores.append(score)
        
        return sum(fluency_scores) / len(fluency_scores) if fluency_scores else 0
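Note that `_calculate_perplexity` token-weights the loss before exponentiating, which matters when batches have different lengths. The relationship is easy to sanity-check on toy numbers (the loss values below are made up for illustration):

```python
import math

# (per-token cross-entropy loss, token count) per batch: illustrative values only
batches = [(2.0, 100), (3.0, 300)]

total_loss = sum(loss * n for loss, n in batches)
total_tokens = sum(n for _, n in batches)
avg_loss = total_loss / total_tokens   # 2.75, not the naive batch mean of 2.5
perplexity = math.exp(avg_loss)
print(round(avg_loss, 2), round(perplexity, 2))
```

Averaging the batch losses directly (2.5 here) would understate perplexity whenever longer batches have higher loss.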

4. Model Deployer

python
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import torch

app = FastAPI()

class GenerationRequest(BaseModel):
    prompt: str
    max_length: int = 100
    temperature: float = 0.7
    top_p: float = 0.9

class GenerationResponse(BaseModel):
    generated_text: str
    generation_time: float

class ResponseCache:
    # Minimal in-memory stand-in for the cache helper assumed here;
    # use an LRU or Redis in production.
    def __init__(self):
        self._store = {}
    def get(self, key: str):
        return self._store.get(key)
    def set(self, key: str, value: str):
        self._store[key] = value

class DomainModelServer:
    def __init__(self, model_path: str, domain: str):
        self.domain = domain
        self.model, self.tokenizer = self._load_model(model_path)
        self.model.eval()
        self.cache = ResponseCache()
    
    def _load_model(self, model_path: str):
        from transformers import AutoModelForCausalLM, AutoTokenizer
        
        model = AutoModelForCausalLM.from_pretrained(model_path)
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        
        return model, tokenizer
    
    def generate(self, request: GenerationRequest) -> GenerationResponse:
        cached_response = self.cache.get(request.prompt)
        
        if cached_response:
            return GenerationResponse(
                generated_text=cached_response,
                generation_time=0.0
            )
        
        import time
        start_time = time.time()
        
        inputs = self.tokenizer(
            request.prompt,
            return_tensors="pt",
            padding=True,
            truncation=True
        )
        
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_length=request.max_length,
                temperature=request.temperature,
                top_p=request.top_p,
                do_sample=True
            )
        
        generated_text = self.tokenizer.decode(
            outputs[0],
            skip_special_tokens=True
        )
        
        generation_time = time.time() - start_time
        
        self.cache.set(request.prompt, generated_text)
        
        return GenerationResponse(
            generated_text=generated_text,
            generation_time=generation_time
        )

model_server = DomainModelServer(
    "models/quantized/medical-gpt",
    "medical"
)

@app.post("/generate", response_model=GenerationResponse)
async def generate(request: GenerationRequest):
    try:
        return model_server.generate(request)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/health")
async def health_check():
    return {"status": "healthy", "domain": model_server.domain}

Deployment Configuration

Docker Compose

yaml
version: '3.8'

services:
  model-server:
    build:
      context: ./deployment/docker
      dockerfile: Dockerfile
    ports:
      - "8000:8000"
    environment:
      - MODEL_PATH=models/quantized/medical-gpt
      - DOMAIN=medical
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: 1
              capabilities: [gpu]
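The compose file points at `deployment/docker/Dockerfile`; a minimal sketch might look like the following (the base image tag and the uvicorn module path are assumptions; point `CMD` at wherever your FastAPI app actually lives):

```dockerfile
FROM pytorch/pytorch:2.1.0-cuda12.1-cudnn8-runtime

WORKDIR /app
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

COPY . .
EXPOSE 8000
# Module path is an assumption for this sketch
CMD ["uvicorn", "deployment.api.server:app", "--host", "0.0.0.0", "--port", "8000"]
```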

Kubernetes

yaml
apiVersion: apps/v1
kind: Deployment
metadata:
  name: domain-model-server
spec:
  replicas: 2
  selector:
    matchLabels:
      app: domain-model-server
  template:
    metadata:
      labels:
        app: domain-model-server
    spec:
      containers:
      - name: model-server
        image: domain-model-server:latest
        ports:
        - containerPort: 8000
        resources:
          requests:
            nvidia.com/gpu: 1
            memory: "8Gi"
          limits:
            nvidia.com/gpu: 1
            memory: "16Gi"
        env:
        - name: MODEL_PATH
          value: "models/quantized/medical-gpt"
        - name: DOMAIN
          value: "medical"
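The Deployment above needs a Service in front of it before other cluster workloads (or an Ingress) can reach the replicas; a minimal sketch:

```yaml
apiVersion: v1
kind: Service
metadata:
  name: domain-model-server
spec:
  selector:
    app: domain-model-server
  ports:
  - port: 80
    targetPort: 8000
  type: ClusterIP
```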

Practice Exercises

Exercise 1: Implement the Complete Fine-tuning Pipeline

python
from typing import Dict, List

class CompleteFineTuningPipeline:
    def __init__(self, domain: str, config: Dict):
        self.domain = domain
        self.config = config
        self.data_preparer = DomainDataPreparer(domain, config)
        self.tuner = QLoRATuner(config["model_name"], domain, config)
        self.evaluator = DomainModelEvaluator(self.tuner.model, self.tuner.tokenizer, domain)
    
    def run(self, sources: List[str], train_data, eval_data):
        print("Preparing data...")
        # In a full pipeline the formatted data would be tokenized and split
        # into the train/eval datasets passed in below.
        formatted_data = self.data_preparer.prepare_data(sources)
        
        print("Fine-tuning model...")
        model = self.tuner.fine_tune(train_data, eval_data)
        
        print("Evaluating model...")
        metrics = self.evaluator.evaluate(eval_data)
        
        print("Saving model...")
        self.tuner.save_model(f"./models/{self.domain}")
        
        return metrics
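The `config` dict threads through every component above. Collecting the keys the classes actually read gives a reasonable starting point (the values are illustrative defaults taken from the code above, and `model_name` is a placeholder, not a recommendation):

```python
config = {
    # DomainDataPreparer
    "min_length": 10,
    "n_augmentations": 2,
    # QLoRATuner / LoraConfig
    "model_name": "meta-llama/Llama-2-7b-hf",  # placeholder: any causal LM works
    "lora_rank": 8,
    "lora_alpha": 32,
    "target_modules": ["q_proj", "v_proj"],
    "lora_dropout": 0.05,
    # TrainingArguments
    "num_epochs": 3,
    "batch_size": 4,
    "gradient_accumulation_steps": 4,
    "learning_rate": 2e-4,
    "output_dir": "./output/medical",
}

# pipeline = CompleteFineTuningPipeline("medical", config)  # needs a GPU + model download
```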

Summary

In this lesson we wrapped up the fine-tuning module:

  1. Reviewed the module's core knowledge
  2. Walked through the domain-specific model project
  3. Completed the project architecture design
  4. Implemented the core features (data preparation, QLoRA fine-tuning, evaluation, deployment)
  5. Provided deployment configurations

Through this project you will master the end-to-end workflow for building a domain-specific fine-tuned model.
