# Day 57: Fine-Tuning Module Summary and Project

## Learning Objectives

- Summarize the core knowledge of the fine-tuning module
- Understand the domain-specific model project
- Design the project architecture
- Implement the core functionality
- Deploy and optimize the system
## Module Knowledge Summary

### Core Fine-Tuning Concepts

Fine-tuning workflow:

Pretrained model → load → data preparation → fine-tuning → evaluation → deployment

Key techniques (a minimal code sketch of this workflow follows the list):

- Fine-tuning methods: full fine-tuning, partial fine-tuning, instruction tuning
- Parameter-efficient fine-tuning: LoRA, QLoRA, AdaLoRA
- Data processing: collection, cleaning, formatting, augmentation
- Training optimization: hyperparameter tuning, monitoring, evaluation
- Deployment optimization: quantization, inference optimization, monitoring
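The workflow above maps directly onto Transformers + PEFT calls. Below is a minimal, illustrative sketch; the base model name and LoRA hyperparameters are placeholders, not values prescribed by this lesson:

```python
# Minimal sketch of the load → adapt → train workflow with Transformers + PEFT.
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model

base = "facebook/opt-350m"  # placeholder base model
model = AutoModelForCausalLM.from_pretrained(base)
tokenizer = AutoTokenizer.from_pretrained(base)

# Attach a LoRA adapter: only the low-rank update matrices are trained.
lora = LoraConfig(r=8, lora_alpha=32, target_modules=["q_proj", "v_proj"],
                  lora_dropout=0.05, task_type="CAUSAL_LM")
model = get_peft_model(model, lora)
model.print_trainable_parameters()  # typically well under 1% of parameters
```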
### Comparison of Fine-Tuning Methods

| Method | Trainable Parameters | VRAM Usage | Training Speed | Quality | Typical Use Case |
|---|---|---|---|---|---|
| Full fine-tuning | 100% | High | Slow | Highest | Large-scale data |
| Partial fine-tuning | 10-30% | Medium | Medium | High | Tuning specific layers |
| LoRA | <1% | Low | Fast | High | Parameter-efficient tuning |
| QLoRA | <1% | Very low | Fast | High | Limited VRAM |
| AdaLoRA | <1% | Low | Fast | High | Adaptive rank allocation |
### Best Practices

Data preparation:

- Collect high-quality domain data
- Clean the data thoroughly
- Expand the dataset with data augmentation
- Assess data quality

Fine-tuning strategy:

- Start with small-scale experiments
- Use LoRA/QLoRA to reduce cost
- Monitor the training process
- Evaluate model performance thoroughly

Deployment optimization:

- Use quantization to reduce VRAM usage
- Add caching to improve response speed
- Monitor system performance
- Implement autoscaling
## Hands-On Project: Domain-Specific Model

### Project Overview

Project name: Domain-Specific Fine-tuned Model (DSFM)

Project description:

Build a fine-tuned large model for a specific domain (such as medical, legal, or financial) that performs well on tasks in that domain.

Tech stack:

- Fine-tuning framework: PEFT + Transformers
- Fine-tuning method: QLoRA
- Training tooling: DeepSpeed
- Evaluation tooling: Weights & Biases
- Serving: vLLM + FastAPI (a minimal vLLM sketch follows this list)
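The FastAPI server later in this lesson loads the model with plain Transformers; serving through vLLM, as listed in the stack, would look roughly like the sketch below. This assumes vLLM is installed and that the LoRA weights have been merged into the model stored at the illustrative project path:

```python
# Hedged vLLM offline-inference sketch; the model path and prompt are illustrative.
from vllm import LLM, SamplingParams

llm = LLM(model="models/quantized/medical-gpt")
params = SamplingParams(temperature=0.7, top_p=0.9, max_tokens=100)
outputs = llm.generate(["Provide medical information: What is hypertension?"], params)
print(outputs[0].outputs[0].text)
```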
### System Architecture

```
┌─────────────────────────────────────────────────┐
│ Frontend Layer                                   │
│  - Model testing UI                              │
│  - Performance monitoring dashboard              │
│  - Admin console                                 │
└───────────────────┬─────────────────────────────┘
                    │
                    ▼
┌─────────────────────────────────────────────────┐
│ API Layer                                        │
│  - Generation API                                │
│  - Evaluation API                                │
│  - Management API                                │
└───────────────────┬─────────────────────────────┘
                    │
                    ▼
┌─────────────────────────────────────────────────┐
│ Service Layer                                    │
│  - Model loader                                  │
│  - Generation engine                             │
│  - Evaluation engine                             │
│  - Cache manager                                 │
└───────────────────┬─────────────────────────────┘
                    │
                    ▼
┌─────────────────────────────────────────────────┐
│ Model Layer                                      │
│  - Pretrained model                              │
│  - LoRA adapters                                 │
│  - Quantized model                               │
└─────────────────────────────────────────────────┘
```

### Project Directory Structure
```
domain-specific-model/
├── data/
│   ├── raw/
│   ├── processed/
│   └── augmented/
├── models/
│   ├── base/
│   ├── lora/
│   └── quantized/
├── training/
│   ├── config/
│   ├── scripts/
│   └── notebooks/
├── evaluation/
│   ├── metrics/
│   └── results/
├── deployment/
│   ├── api/
│   ├── docker/
│   └── kubernetes/
├── monitoring/
│   ├── logs/
│   └── metrics/
└── tests/
    ├── unit/
    └── integration/
```

### Core Functionality
#### 1. Data Preparer

```python
from typing import List, Dict
from pathlib import Path

# DataCollector, TextCleaner, DataAugmenter and InstructionFormatter are the
# helper classes built in the earlier data-preparation lessons of this module.

class DomainDataPreparer:
    def __init__(self, domain: str, config: Dict):
        self.domain = domain
        self.config = config
        self.collector = DataCollector(config)
        self.cleaner = TextCleaner(config)
        self.augmenter = DataAugmenter(config)
        self.formatter = InstructionFormatter()

    def prepare_data(self, sources: List[str]) -> List[str]:
        # Collect → clean → augment → format into instruction-style samples
        collected = self._collect_data(sources)
        cleaned = self._clean_data(collected)
        augmented = self._augment_data(cleaned)
        formatted = self._format_data(augmented)
        return formatted

    def _collect_data(self, sources: List[str]) -> List[Dict]:
        collected = []
        for source in sources:
            if source.startswith("http"):
                data = self.collector.collect_from_api(source)
            else:
                data = self.collector.collect_from_files(source)
            collected.extend(data)
        return collected

    def _clean_data(self, data: List[Dict]) -> List[Dict]:
        cleaned = []
        for item in data:
            text = str(item.get("data", item))
            cleaned_text = self.cleaner.clean(text)
            # Drop samples that are too short after cleaning
            if len(cleaned_text.split()) >= self.config.get("min_length", 10):
                cleaned.append({
                    "data": {"content": cleaned_text},
                    "metadata": item.get("metadata", {})
                })
        return cleaned

    def _augment_data(self, data: List[Dict]) -> List[Dict]:
        augmented = []
        for item in data:
            text = item["data"]["content"]
            augmented_texts = self.augmenter.augment(
                text,
                n_augmentations=self.config.get("n_augmentations", 2)
            )
            for aug_text in augmented_texts:
                augmented.append({
                    "data": {"content": aug_text},
                    "metadata": item["metadata"]
                })
        return augmented

    def _format_data(self, data: List[Dict]) -> List[str]:
        formatted = []
        for item in data:
            content = item["data"]["content"]
            instruction = self._create_instruction(content)
            formatted_item = {
                "instruction": instruction,
                "input": "",
                "output": content
            }
            formatted_text = self.formatter.format(formatted_item)
            formatted.append(formatted_text)
        return formatted

    def _create_instruction(self, content: str) -> str:
        # A fixed instruction per domain; content is accepted for future
        # content-aware instructions but is not used here.
        instructions = {
            "medical": "Provide medical information:",
            "legal": "Provide legal information:",
            "financial": "Provide financial information:",
            "general": "Provide information:"
        }
        return instructions.get(self.domain, instructions["general"])
```
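A brief usage example (hypothetical values; it assumes the helper classes referenced above are importable from the earlier lessons):

```python
# Hypothetical invocation of the data preparer; sources and config keys are illustrative.
config = {"min_length": 10, "n_augmentations": 2}
preparer = DomainDataPreparer("medical", config)
samples = preparer.prepare_data(["data/raw/", "https://example.org/api/cases"])
print(f"Prepared {len(samples)} formatted training samples")
```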
#### 2. QLoRA Fine-Tuner

```python
from typing import Dict

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
    Trainer
)
from peft import LoraConfig, get_peft_model

class QLoRATuner:
    def __init__(self, model_name: str, domain: str, config: Dict):
        self.model_name = model_name
        self.domain = domain
        self.config = config
        self.model, self.tokenizer = self._load_model()
        self.model = self._apply_lora()

    def _load_model(self):
        # 4-bit NF4 quantization with double quantization (the QLoRA base setup)
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16
        )
        model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            quantization_config=bnb_config,
            device_map="auto"
        )
        tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        tokenizer.pad_token = tokenizer.eos_token
        return model, tokenizer

    def _apply_lora(self):
        lora_config = LoraConfig(
            r=self.config.get("lora_rank", 8),
            lora_alpha=self.config.get("lora_alpha", 32),
            target_modules=self.config.get("target_modules", ["q_proj", "v_proj"]),
            lora_dropout=self.config.get("lora_dropout", 0.05),
            bias="none",
            task_type="CAUSAL_LM"
        )
        model = get_peft_model(self.model, lora_config)
        model.print_trainable_parameters()
        return model

    def fine_tune(self, train_dataset, eval_dataset):
        training_args = TrainingArguments(
            output_dir=self.config.get("output_dir", f"./output/{self.domain}"),
            num_train_epochs=self.config.get("num_epochs", 3),
            per_device_train_batch_size=self.config.get("batch_size", 4),
            per_device_eval_batch_size=self.config.get("batch_size", 4),
            gradient_accumulation_steps=self.config.get("gradient_accumulation_steps", 4),
            learning_rate=self.config.get("learning_rate", 2e-4),
            weight_decay=self.config.get("weight_decay", 0.01),
            warmup_steps=self.config.get("warmup_steps", 500),
            logging_steps=self.config.get("logging_steps", 10),
            save_steps=self.config.get("save_steps", 500),
            evaluation_strategy="steps",  # renamed to eval_strategy in newer transformers releases
            eval_steps=self.config.get("eval_steps", 500),
            fp16=True,
            gradient_checkpointing=True,
            optim="paged_adamw_32bit"
        )
        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            tokenizer=self.tokenizer
        )
        trainer.train()
        return self.model

    def save_model(self, output_dir: str):
        # Saves the LoRA adapter weights and the tokenizer
        self.model.save_pretrained(output_dir)
        self.tokenizer.save_pretrained(output_dir)
```
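The preparer returns a list of formatted strings, while `fine_tune` expects tokenized datasets. The lesson does not fix a specific bridge between the two; one minimal sketch using the `datasets` library (field names and max length are illustrative):

```python
# Hedged sketch: turn formatted strings into a tokenized dataset for the Trainer.
from datasets import Dataset

def build_dataset(texts, tokenizer, max_length=512):
    ds = Dataset.from_dict({"text": texts})

    def tokenize(batch):
        enc = tokenizer(batch["text"], truncation=True,
                        max_length=max_length, padding="max_length")
        # Causal LM objective: labels mirror the input ids
        enc["labels"] = [ids.copy() for ids in enc["input_ids"]]
        return enc

    return ds.map(tokenize, batched=True, remove_columns=["text"])
```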
#### 3. Model Evaluator

```python
import math
from typing import Dict, List

import torch

class DomainModelEvaluator:
    def __init__(self, model, tokenizer, domain: str):
        self.model = model
        self.tokenizer = tokenizer
        self.domain = domain

    def evaluate(self, test_dataset) -> Dict:
        metrics = {}
        metrics["perplexity"] = self._calculate_perplexity(test_dataset)
        metrics["generation_quality"] = self._evaluate_generation(test_dataset)
        metrics["domain_accuracy"] = self._evaluate_domain_accuracy(test_dataset)
        return metrics

    def _calculate_perplexity(self, test_dataset) -> float:
        # Assumes each batch contains "labels" so that outputs.loss is populated
        self.model.eval()
        total_loss = 0
        total_tokens = 0
        with torch.no_grad():
            for batch in test_dataset:
                outputs = self.model(**batch)
                loss = outputs.loss
                total_loss += loss.item() * batch["input_ids"].numel()
                total_tokens += batch["input_ids"].numel()
        avg_loss = total_loss / total_tokens
        perplexity = math.exp(avg_loss)
        return perplexity

    def _evaluate_generation(self, test_dataset) -> Dict:
        generations = []
        for batch in test_dataset:
            input_ids = batch["input_ids"]
            with torch.no_grad():
                outputs = self.model.generate(
                    input_ids,
                    max_new_tokens=50,
                    do_sample=True,
                    temperature=0.7
                )
            generated_text = self.tokenizer.decode(
                outputs[0],
                skip_special_tokens=True
            )
            generations.append(generated_text)
        metrics = {
            "avg_length": self._calculate_avg_length(generations),
            "diversity": self._calculate_diversity(generations),
            "fluency": self._calculate_fluency(generations)
        }
        return metrics

    def _evaluate_domain_accuracy(self, test_dataset) -> Dict:
        # Simple token-level accuracy between argmax predictions and labels
        correct = 0
        total = 0
        for batch in test_dataset:
            input_ids = batch["input_ids"]
            labels = batch["labels"]
            with torch.no_grad():
                outputs = self.model(input_ids)
                predictions = torch.argmax(outputs.logits, dim=-1)
                correct += (predictions == labels).sum().item()
                total += labels.numel()
        accuracy = correct / total if total > 0 else 0
        return {"accuracy": accuracy}

    def _calculate_avg_length(self, generations: List[str]) -> float:
        lengths = [len(gen.split()) for gen in generations]
        return sum(lengths) / len(lengths) if lengths else 0

    def _calculate_diversity(self, generations: List[str]) -> float:
        # Type-token ratio across all generations
        all_words = set()
        for gen in generations:
            words = gen.split()
            all_words.update(words)
        total_words = sum(len(gen.split()) for gen in generations)
        return len(all_words) / total_words if total_words > 0 else 0

    def _calculate_fluency(self, generations: List[str]) -> float:
        # Heuristic fluency score based on length, punctuation and line structure
        fluency_scores = []
        for gen in generations:
            score = 0.0
            if len(gen.split()) >= 5:
                score += 0.4
            if any(char in gen for char in '.!?'):
                score += 0.3
            if len(gen.split('\n')) > 1:
                score += 0.3
            fluency_scores.append(score)
        return sum(fluency_scores) / len(fluency_scores) if fluency_scores else 0
```

#### 4. Model Deployment Server
```python
import time

import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

app = FastAPI()

class GenerationRequest(BaseModel):
    prompt: str
    max_length: int = 100
    temperature: float = 0.7
    top_p: float = 0.9

class GenerationResponse(BaseModel):
    generated_text: str
    generation_time: float

class DomainModelServer:
    def __init__(self, model_path: str, domain: str):
        self.domain = domain
        self.model, self.tokenizer = self._load_model(model_path)
        self.model.eval()
        # ResponseCache: a simple prompt -> generated-text cache (sketched below)
        self.cache = ResponseCache()

    def _load_model(self, model_path: str):
        from transformers import AutoModelForCausalLM, AutoTokenizer
        # For a quantized deployment, a quantization_config (e.g. BitsAndBytesConfig)
        # can be passed here, as in the fine-tuning step.
        model = AutoModelForCausalLM.from_pretrained(model_path)
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        return model, tokenizer

    def generate(self, request: GenerationRequest) -> GenerationResponse:
        # Serve from the cache when the same prompt has been seen before
        cached_response = self.cache.get(request.prompt)
        if cached_response:
            return GenerationResponse(
                generated_text=cached_response,
                generation_time=0.0
            )
        start_time = time.time()
        inputs = self.tokenizer(
            request.prompt,
            return_tensors="pt",
            padding=True,
            truncation=True
        )
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_length=request.max_length,
                temperature=request.temperature,
                top_p=request.top_p,
                do_sample=True
            )
        generated_text = self.tokenizer.decode(
            outputs[0],
            skip_special_tokens=True
        )
        generation_time = time.time() - start_time
        self.cache.set(request.prompt, generated_text)
        return GenerationResponse(
            generated_text=generated_text,
            generation_time=generation_time
        )

model_server = DomainModelServer(
    "models/quantized/medical-gpt",
    "medical"
)

@app.post("/generate", response_model=GenerationResponse)
async def generate(request: GenerationRequest):
    try:
        return model_server.generate(request)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/health")
async def health_check():
    return {"status": "healthy", "domain": model_server.domain}
```
### Deployment Configuration

#### Docker Compose

```yaml
version: '3.8'
services:
  model-server:
    build:
      context: ./deployment/docker
      dockerfile: Dockerfile
    ports:
      - "8000:8000"
    environment:
      - MODEL_PATH=models/quantized/medical-gpt
      - DOMAIN=medical
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: 1
              capabilities: [gpu]
```

#### Kubernetes
```yaml
apiVersion: apps/v1
kind: Deployment
metadata:
  name: domain-model-server
spec:
  replicas: 2
  selector:
    matchLabels:
      app: domain-model-server
  template:
    metadata:
      labels:
        app: domain-model-server
    spec:
      containers:
        - name: model-server
          image: domain-model-server:latest
          ports:
            - containerPort: 8000
          resources:
            requests:
              nvidia.com/gpu: 1
              memory: "8Gi"
            limits:
              nvidia.com/gpu: 1
              memory: "16Gi"
          env:
            - name: MODEL_PATH
              value: "models/quantized/medical-gpt"
            - name: DOMAIN
              value: "medical"
```

### Practice Exercises
#### Exercise 1: Implement the Complete Fine-Tuning Pipeline

```python
from typing import Dict, List

class CompleteFineTuningPipeline:
    def __init__(self, domain: str, config: Dict):
        self.domain = domain
        self.config = config
        self.data_preparer = DomainDataPreparer(domain, config)
        self.tuner = QLoRATuner(config["model_name"], domain, config)
        self.evaluator = DomainModelEvaluator(self.tuner.model, self.tuner.tokenizer, domain)

    def run(self, sources: List[str], train_data, eval_data):
        print("Preparing data...")
        # formatted_data would be tokenized into train_data / eval_data
        # (see the dataset-building sketch earlier in this lesson)
        formatted_data = self.data_preparer.prepare_data(sources)
        print("Fine-tuning model...")
        model = self.tuner.fine_tune(train_data, eval_data)
        print("Evaluating model...")
        metrics = self.evaluator.evaluate(eval_data)
        print("Saving model...")
        self.tuner.save_model(f"./models/{self.domain}")
        return metrics
```
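A hypothetical invocation of the pipeline; the model name, paths, and config values are placeholders, and `build_dataset` refers to the tokenization sketch shown earlier in this lesson:

```python
# Hypothetical end-to-end usage of the exercise pipeline; all values are illustrative.
config = {"model_name": "meta-llama/Llama-2-7b-hf", "lora_rank": 8, "num_epochs": 3}
pipeline = CompleteFineTuningPipeline("medical", config)

texts = pipeline.data_preparer.prepare_data(["data/raw/"])
split = int(0.9 * len(texts))
train_data = build_dataset(texts[:split], pipeline.tuner.tokenizer)
eval_data = build_dataset(texts[split:], pipeline.tuner.tokenizer)

metrics = pipeline.run(["data/raw/"], train_data, eval_data)
print(metrics)
```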
## Summary

In this lesson we wrapped up the fine-tuning module:

- Summarized the core knowledge of the fine-tuning module
- Walked through the domain-specific model project
- Designed the project architecture
- Implemented the core functionality (data preparation, QLoRA fine-tuning, evaluation, deployment)
- Provided deployment configurations

Through this project you now have the complete workflow for building a domain-specific fine-tuned model.
