# Day 53: LoRA and QLoRA
## Learning Objectives
- Understand how LoRA works
- Learn the QLoRA technique
- Understand parameter-efficient fine-tuning
- Compare PEFT methods
- See real-world applications
## LoRA Fundamentals
### What Is LoRA
LoRA (Low-Rank Adaptation) is a parameter-efficient fine-tuning method. Instead of updating a pretrained weight matrix directly, it adds a trainable low-rank decomposition on top of it, which drastically reduces the number of parameters that need training.

Core idea:

W → W + ΔW = W + BA

where:
- W: the pretrained weight matrix
- B: a low-rank matrix (d × r)
- A: a low-rank matrix (r × d)
- r: the rank (typically small, e.g., 8, 16, or 32)

Advantages (see the worked example below):
- Parameter efficiency: only a small fraction of parameters are trained (often <1%)
- Memory savings: GPU memory requirements drop substantially
- Training speed: training is significantly faster
- Adapter swapping: different LoRA adapters can be switched in and out easily
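To make the parameter-efficiency claim concrete, here is a quick back-of-envelope count for a single 4096 × 4096 projection (the sizes are illustrative):

```python
d = 4096  # illustrative hidden size
r = 8     # LoRA rank

full_update = d * d      # a dense ΔW would train 16,777,216 parameters
lora_update = 2 * d * r  # A (r × d) plus B (d × r): 65,536 parameters

print(f"LoRA trains {lora_update / full_update:.2%} of a full update")  # ≈ 0.39%
```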
### Implementing LoRA
```python
import math

import torch
import torch.nn as nn


class LoRALayer(nn.Module):
    """A standalone low-rank update: x -> (x @ A^T @ B^T) * (alpha / rank)."""

    def __init__(self, in_features: int, out_features: int,
                 rank: int = 8, alpha: float = 32.0):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.rank = rank
        self.alpha = alpha
        # B starts at zero so the LoRA branch is an exact no-op (BA = 0) at
        # step 0; A gets the usual Kaiming initialization.
        self.lora_A = nn.Parameter(torch.empty(rank, in_features))
        self.lora_B = nn.Parameter(torch.zeros(out_features, rank))
        self.scaling = self.alpha / self.rank
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (..., in) @ (in, rank) @ (rank, out), scaled by alpha / rank
        return (x @ self.lora_A.T @ self.lora_B.T) * self.scaling
class LoRALinear(nn.Module):
    """nn.Linear plus an optional parallel LoRA branch."""

    def __init__(self, in_features: int, out_features: int,
                 rank: int = 8, alpha: float = 32.0,
                 use_lora: bool = True):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.use_lora = use_lora
        self.linear = nn.Linear(in_features, out_features)
        if use_lora:
            self.lora = LoRALayer(in_features, out_features, rank, alpha)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        output = self.linear(x)
        if self.use_lora:
            output = output + self.lora(x)
        return output
```

### Applying LoRA
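Before injecting LoRA into a full model, a quick smoke test of `LoRALinear` is worth running. This sketch (shapes are arbitrary) freezes the base weights and confirms that only the LoRA branch remains trainable:

```python
layer = LoRALinear(in_features=128, out_features=64, rank=8)

# Freeze the base linear weights; only lora_A / lora_B stay trainable.
for p in layer.linear.parameters():
    p.requires_grad = False

x = torch.randn(4, 128)
print(layer(x).shape)  # torch.Size([4, 64])

trainable = [n for n, p in layer.named_parameters() if p.requires_grad]
print(trainable)  # ['lora.lora_A', 'lora.lora_B']
```

The tuner below automates this injection across every targeted layer of a model: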
```python
class LoRATuner:
    """Injects LoRA branches into the nn.Linear modules selected by config."""

    def __init__(self, model, config):
        self.model = model
        self.config = config
        self._replace_layers_with_lora()

    def _replace_layers_with_lora(self):
        for name, module in self.model.named_modules():
            if isinstance(module, nn.Linear) and self._should_apply_lora(name):
                self._apply_lora_to_module(module)

    def _should_apply_lora(self, module_name: str) -> bool:
        # Only touch modules whose name matches one of the configured targets.
        target_modules = self.config.get("target_modules", [])
        return any(target in module_name for target in target_modules)

    def _apply_lora_to_module(self, module: nn.Linear):
        rank = self.config.get("lora_rank", 8)
        alpha = self.config.get("lora_alpha", 32.0)
        lora_layer = LoRALayer(module.in_features, module.out_features,
                               rank, alpha)
        # Keep the LoRA weights on the same device as the wrapped layer.
        lora_layer.to(module.weight.device)
        original_forward = module.forward

        def new_forward(x):
            return original_forward(x) + lora_layer(x)

        module.forward = new_forward
        # Attaching the layer as a submodule makes its parameters show up in
        # named_parameters() under a name containing "lora".
        module.lora_layer = lora_layer

    def get_lora_parameters(self):
        return [p for n, p in self.model.named_parameters() if "lora" in n]

    def get_base_parameters(self):
        return [p for n, p in self.model.named_parameters() if "lora" not in n]
```

## QLoRA
### What Is QLoRA
QLoRA (Quantized LoRA) builds quantization on top of LoRA: the frozen pretrained weights are loaded in 4-bit precision, cutting memory requirements further and making it feasible to fine-tune large models on consumer GPUs.

Core techniques:
- 4-bit NormalFloat (NF4): a 4-bit quantization format optimized for normally distributed weights
- Double quantization: the quantization constants are themselves quantized a second time
- Paged optimizers: optimizer state is paged to CPU memory to absorb memory spikes

Advantages:
- Very low memory: memory requirements drop by more than 50%
- Preserved quality: fine-tuned performance stays close to full-precision fine-tuning
- Large-model reach: 65B+ models can be fine-tuned on a single GPU (see the estimate below)
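A rough weights-only estimate (illustrative numbers; activations, the LoRA weights, and optimizer state are ignored) shows why 4-bit storage is what puts 65B-parameter models within reach of a single 48 GB card:

```python
params = 65e9  # illustrative model size: 65B parameters

for bits in (16, 8, 4):
    gib = params * bits / 8 / 2**30  # bytes for the weights, in GiB
    print(f"{bits:>2}-bit weights: {gib:,.0f} GiB")

# 16-bit weights: 121 GiB
#  8-bit weights:  61 GiB
#  4-bit weights:  30 GiB  <- fits on a single 48 GB GPU with headroom
```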
### Implementing QLoRA
```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig


class QLoRATuner:
    def __init__(self, model_name: str, config):
        self.config = config
        self.model = self._load_quantized_model(model_name)
        self._apply_lora()

    def _load_quantized_model(self, model_name: str):
        # NF4 weights with double quantization; matmuls run in bfloat16.
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            quantization_config=bnb_config,
            device_map="auto",
        )
        return model
    def _apply_lora(self):
        from peft import LoraConfig, get_peft_model

        lora_config = LoraConfig(
            r=self.config.get("lora_rank", 8),
            lora_alpha=self.config.get("lora_alpha", 32),
            target_modules=self.config.get("target_modules", ["q_proj", "v_proj"]),
            lora_dropout=self.config.get("lora_dropout", 0.05),
            bias="none",
            task_type="CAUSAL_LM",
        )
        self.model = get_peft_model(self.model, lora_config)
        self.model.print_trainable_parameters()

    def fine_tune(self, train_data, val_data):
        from transformers import Trainer, TrainingArguments

        training_args = TrainingArguments(
            output_dir=self.config.get("output_dir", "./output"),
            num_train_epochs=self.config.get("num_epochs", 3),
            per_device_train_batch_size=self.config.get("batch_size", 4),
            per_device_eval_batch_size=self.config.get("batch_size", 4),
            gradient_accumulation_steps=self.config.get("gradient_accumulation_steps", 4),
            learning_rate=self.config.get("learning_rate", 2e-4),
            bf16=True,  # match the bfloat16 compute dtype used for quantization
            logging_steps=10,
            save_steps=100,
            evaluation_strategy="steps",
            eval_steps=100,
        )
        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=train_data,
            eval_dataset=val_data,
        )
        trainer.train()
        return self.model
```

### QLoRA Optimizations
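Several of these optimizations are bundled in a helper shipped with peft; a minimal sketch, assuming `model` is the 4-bit model loaded above:

```python
from peft import prepare_model_for_kbit_training

# Upcasts norm/output layers to fp32, enables gradient checkpointing, and
# makes inputs require grads -- the usual prep before attaching LoRA.
model = prepare_model_for_kbit_training(model)
```

The subclass below wires comparable optimizations up by hand: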
```python
class OptimizedQLoRATuner(QLoRATuner):
    """QLoRATuner variant adding a paged optimizer, gradient checkpointing,
    and flash attention."""

    def build_trainer(self, training_args, train_data, val_data):
        # Note: `optim` is a TrainingArguments field, not a Trainer argument,
        # so request the paged optimizer when constructing the args, e.g.
        # TrainingArguments(..., optim="paged_adamw_32bit").
        from transformers import Trainer

        self.trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=train_data,
            eval_dataset=val_data,
        )
        return self.trainer

    def _apply_gradient_checkpointing(self):
        # Recompute activations in the backward pass: compute for memory.
        self.model.gradient_checkpointing_enable()
        self.model.enable_input_require_grads()

    def _apply_flash_attention(self):
        # Opt attention modules into flash attention where they expose the flag.
        for name, module in self.model.named_modules():
            if "attn" in name and hasattr(module, "use_flash_attention"):
                module.use_flash_attention = True
```

## Parameter-Efficient Fine-Tuning
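LoRA is one member of a broader family of parameter-efficient fine-tuning (PEFT) techniques, all of which freeze the base model and train a small add-on. The dispatcher below sketches how several of them are configured through the peft library.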
### PEFT Methods
```python
from typing import Dict


class PEFTMethod:
    """Thin dispatcher that wraps a model with the named PEFT technique."""

    def __init__(self, name: str, config: Dict):
        self.name = name
        self.config = config

    def apply(self, model):
        if self.name == "lora":
            return self._apply_lora(model)
        elif self.name == "adalora":
            return self._apply_adalora(model)
        elif self.name == "prefix_tuning":
            return self._apply_prefix_tuning(model)
        elif self.name == "prompt_tuning":
            return self._apply_prompt_tuning(model)
        elif self.name == "p_tuning":
            return self._apply_p_tuning(model)
        else:
            raise ValueError(f"Unknown PEFT method: {self.name}")

    def _apply_lora(self, model):
        from peft import LoraConfig, get_peft_model

        config = LoraConfig(
            r=self.config.get("rank", 8),
            lora_alpha=self.config.get("alpha", 32),
            target_modules=self.config.get("target_modules", ["q_proj", "v_proj"]),
            lora_dropout=self.config.get("dropout", 0.05),
            bias="none",
            task_type="CAUSAL_LM",
        )
        return get_peft_model(model, config)

    def _apply_adalora(self, model):
        from peft import AdaLoraConfig, get_peft_model

        # AdaLoRA reallocates rank during training; in peft the budget is
        # controlled through target_r / init_r rather than a fixed r.
        config = AdaLoraConfig(
            target_r=self.config.get("rank", 8),
            lora_alpha=self.config.get("alpha", 32),
            target_modules=self.config.get("target_modules", ["q_proj", "v_proj"]),
            lora_dropout=self.config.get("dropout", 0.05),
            bias="none",
            task_type="CAUSAL_LM",
        )
        return get_peft_model(model, config)

    def _apply_prefix_tuning(self, model):
        from peft import PrefixTuningConfig, get_peft_model

        config = PrefixTuningConfig(
            task_type="CAUSAL_LM",
            num_virtual_tokens=self.config.get("num_virtual_tokens", 20),
            encoder_hidden_size=self.config.get("encoder_hidden_size", 768),
        )
        return get_peft_model(model, config)

    def _apply_prompt_tuning(self, model):
        from peft import PromptTuningConfig, get_peft_model

        config = PromptTuningConfig(
            task_type="CAUSAL_LM",
            num_virtual_tokens=self.config.get("num_virtual_tokens", 20),
            prompt_tuning_init=self.config.get("prompt_tuning_init", "TEXT"),
            prompt_tuning_init_text=self.config.get(
                "prompt_tuning_init_text",
                "Classify if the tweet is a complaint or not:"),
            # Required by peft when prompt_tuning_init="TEXT".
            tokenizer_name_or_path=self.config.get("tokenizer_name_or_path"),
        )
        return get_peft_model(model, config)

    def _apply_p_tuning(self, model):
        from peft import PromptEncoderConfig, get_peft_model

        config = PromptEncoderConfig(
            task_type="CAUSAL_LM",
            num_virtual_tokens=self.config.get("num_virtual_tokens", 20),
            encoder_hidden_size=self.config.get("encoder_hidden_size", 768),
        )
        return get_peft_model(model, config)
```

### PEFT Method Comparison
| Method | Trainable params | Memory | Training speed | Quality |
|---|---|---|---|---|
| LoRA | <1% | Low | Fast | High |
| AdaLoRA | <1% | Low | Fast | High |
| Prefix Tuning | <0.1% | Very low | Fast | Medium |
| Prompt Tuning | <0.1% | Very low | Fast | Medium |
| P-Tuning | <0.1% | Very low | Fast | Medium |
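The trainable-parameter column is easy to verify empirically. A minimal sketch using peft's built-in counter (the checkpoint and target module are placeholders chosen for GPT-2):

```python
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")
peft_model = get_peft_model(
    model,
    LoraConfig(r=8, lora_alpha=32, target_modules=["c_attn"]),
)

# Prints trainable vs. total parameters and the trainable percentage
# (well under 1% for this configuration).
peft_model.print_trainable_parameters()
```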
### Managing Multiple LoRA Adapters
```python
from typing import List


class MultiLoRAManager:
    """Stores named LoRA state dicts and swaps them into a base model."""

    def __init__(self, base_model):
        self.base_model = base_model
        self.lora_adapters = {}

    def add_adapter(self, name: str, adapter):
        # `adapter` is a dict mapping LoRA parameter names to tensors.
        self.lora_adapters[name] = adapter

    def activate_adapter(self, name: str):
        if name not in self.lora_adapters:
            raise ValueError(f"Adapter {name} not found")
        self._apply_adapter(self.lora_adapters[name])

    def _apply_adapter(self, adapter):
        for name, param in self.base_model.named_parameters():
            if "lora" in name:
                param.data.copy_(adapter[name])

    def merge_adapters(self, adapter_names: List[str], weights: List[float]):
        # Linearly combine several adapters, then load the blended weights.
        merged_adapter = {}
        for name, param in self.base_model.named_parameters():
            if "lora" in name:
                merged_param = torch.zeros_like(param)
                for adapter_name, weight in zip(adapter_names, weights):
                    merged_param += self.lora_adapters[adapter_name][name] * weight
                merged_adapter[name] = merged_param
        self._apply_adapter(merged_adapter)

    def save_adapter(self, name: str, path: str):
        if name not in self.lora_adapters:
            raise ValueError(f"Adapter {name} not found")
        torch.save(self.lora_adapters[name], path)

    def load_adapter(self, name: str, path: str):
        self.add_adapter(name, torch.load(path))
```

## Practical Applications
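In practice, saving and restoring adapters is usually delegated to peft rather than hand-rolled as above. A minimal sketch of the native path (the paths and checkpoint are placeholders, and `peft_model` is assumed to be a `PeftModel` such as the one returned by `QLoRATuner`):

```python
from peft import PeftModel
from transformers import AutoModelForCausalLM

# Save only the adapter weights (a few MB), not the full base model.
peft_model.save_pretrained("./adapters/my_domain")

# Later: reload the base model and attach the saved adapter to it.
base = AutoModelForCausalLM.from_pretrained("gpt2")
restored = PeftModel.from_pretrained(base, "./adapters/my_domain")
```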
### Domain Adaptation
```python
class DomainAdapter:
    """Maintains one QLoRA tuner per target domain (e.g., medical, legal)."""

    def __init__(self, base_model, config):
        self.base_model = base_model
        self.config = config
        self.domain_adapters = {}

    def create_domain_adapter(self, domain: str):
        tuner = QLoRATuner(self.base_model.config.name_or_path, self.config)
        self.domain_adapters[domain] = tuner
        return tuner

    def fine_tune_domain(self, domain: str, train_data, val_data):
        if domain not in self.domain_adapters:
            self.create_domain_adapter(domain)
        return self.domain_adapters[domain].fine_tune(train_data, val_data)

    def switch_domain(self, domain: str):
        if domain not in self.domain_adapters:
            raise ValueError(f"Domain adapter {domain} not found")
        self._apply_adapter(self.domain_adapters[domain])

    def _apply_adapter(self, adapter):
        # Placeholder: load the adapter's weights into the serving model,
        # e.g. via MultiLoRAManager or peft's set_adapter.
        pass
```

### Task Adaptation
```python
class TaskAdapter:
    """Maintains one QLoRA tuner per downstream task (e.g., QA, summarization)."""

    def __init__(self, base_model, config):
        self.base_model = base_model
        self.config = config
        self.task_adapters = {}

    def create_task_adapter(self, task: str):
        tuner = QLoRATuner(self.base_model.config.name_or_path, self.config)
        self.task_adapters[task] = tuner
        return tuner

    def fine_tune_task(self, task: str, train_data, val_data):
        if task not in self.task_adapters:
            self.create_task_adapter(task)
        return self.task_adapters[task].fine_tune(train_data, val_data)

    def switch_task(self, task: str):
        if task not in self.task_adapters:
            raise ValueError(f"Task adapter {task} not found")
        self._apply_adapter(self.task_adapters[task])

    def _apply_adapter(self, adapter):
        # Placeholder: same adapter-loading hook as in DomainAdapter.
        pass
```

## Exercises
### Exercise 1: Implement a Simple LoRA Layer
```python
class SimpleLoRALayer(nn.Module):
    def __init__(self, in_features, out_features, rank=8):
        super().__init__()
        self.rank = rank
        # Zero-initialized B keeps the branch a no-op at the start of training.
        self.lora_A = nn.Parameter(torch.empty(rank, in_features))
        self.lora_B = nn.Parameter(torch.zeros(out_features, rank))
        nn.init.kaiming_uniform_(self.lora_A)

    def forward(self, x):
        return x @ self.lora_A.T @ self.lora_B.T
```

### Exercise 2: Implement LoRA Fine-Tuning
```python
class SimpleLoRATuner:
    def __init__(self, model, rank=8, alpha=32):
        self.model = model
        self.rank = rank
        self.alpha = alpha
        self.scaling = alpha / rank

    def apply_lora(self):
        for name, module in self.model.named_modules():
            if isinstance(module, nn.Linear):
                self._add_lora_to_module(module)

    def _add_lora_to_module(self, module):
        lora_layer = SimpleLoRALayer(module.in_features, module.out_features,
                                     self.rank)
        original_forward = module.forward

        def new_forward(x):
            return original_forward(x) + lora_layer(x) * self.scaling

        module.forward = new_forward
        module.lora_layer = lora_layer
```

## Summary
In this lesson we covered LoRA and QLoRA:
- How LoRA works and how to implement it
- The QLoRA technique and its optimizations
- A comparison of parameter-efficient fine-tuning methods
- Managing multiple LoRA adapters
- Practical applications (domain and task adaptation)

LoRA and QLoRA are currently the most popular parameter-efficient fine-tuning methods: they preserve model quality while drastically reducing training cost.
