# Day 59: Inference Acceleration

## Learning Objectives

- Master batch processing optimization techniques
- Learn the KV Cache technique
- Understand Flash Attention
- Master TensorRT optimization
- Get familiar with ONNX Runtime

## Batch Processing Optimization

### Basic Batching
```python
import torch
from typing import List


class BatchProcessor:
    def __init__(self, model, tokenizer, max_batch_size: int = 8):
        self.model = model
        self.tokenizer = tokenizer
        self.max_batch_size = max_batch_size

    def process_batch(self, prompts: List[str]) -> List[str]:
        # Split the prompt list into fixed-size batches
        results = []
        for i in range(0, len(prompts), self.max_batch_size):
            batch = prompts[i:i + self.max_batch_size]
            results.extend(self._process_single_batch(batch))
        return results

    def _process_single_batch(self, batch: List[str]) -> List[str]:
        # Pad to the longest prompt in the batch so tensors are rectangular
        inputs = self.tokenizer(
            batch,
            padding=True,
            truncation=True,
            return_tensors="pt"
        )
        with torch.no_grad():
            outputs = self.model.generate(**inputs)
        return self.tokenizer.batch_decode(
            outputs,
            skip_special_tokens=True
        )
```
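A minimal usage sketch, assuming a Hugging Face causal LM; the `gpt2` checkpoint below is only a placeholder:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder checkpoint; substitute your own model
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
# GPT-2 ships without a pad token, so reuse EOS for padding
tokenizer.pad_token = tokenizer.eos_token

processor = BatchProcessor(model, tokenizer, max_batch_size=4)
print(processor.process_batch(["Hello,", "Once upon a time", "The weather today"]))
```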
### Dynamic Batching

```python
class DynamicBatchProcessor:
    def __init__(self, model, tokenizer, max_batch_size: int = 8):
        self.model = model
        self.tokenizer = tokenizer
        self.max_batch_size = max_batch_size

    def process_dynamic_batch(self, prompts: List[str]) -> List[str]:
        # Sort by token length so prompts in a batch pad to similar
        # lengths, which minimizes wasted computation on padding
        sorted_entries = self._sort_by_length(prompts)
        batches = self._create_dynamic_batches(sorted_entries)
        indexed_results = []
        for batch_indices in batches:
            batch_prompts = [prompts[i] for i in batch_indices]
            batch_results = self._process_batch(batch_prompts)
            indexed_results.extend(zip(batch_indices, batch_results))
        return self._restore_order(indexed_results, len(prompts))

    def _sort_by_length(self, prompts: List[str]) -> List[tuple]:
        # Pair each prompt's original index with its token count
        prompt_lengths = [
            (i, len(self.tokenizer.encode(prompt)))
            for i, prompt in enumerate(prompts)
        ]
        return sorted(prompt_lengths, key=lambda x: x[1])

    def _create_dynamic_batches(self,
                                sorted_entries: List[tuple]) -> List[List[int]]:
        # Group adjacent (similar-length) prompts into batches of indices
        batches = []
        current_batch = []
        for idx, _length in sorted_entries:
            if len(current_batch) >= self.max_batch_size:
                batches.append(current_batch)
                current_batch = []
            current_batch.append(idx)
        if current_batch:
            batches.append(current_batch)
        return batches

    def _process_batch(self, batch_prompts: List[str]) -> List[str]:
        inputs = self.tokenizer(
            batch_prompts,
            padding=True,
            truncation=True,
            return_tensors="pt"
        )
        with torch.no_grad():
            outputs = self.model.generate(**inputs)
        return self.tokenizer.batch_decode(outputs, skip_special_tokens=True)

    def _restore_order(self, indexed_results: List[tuple],
                       n: int) -> List[str]:
        # Map each result back to its original position in the input list
        ordered_results = [None] * n
        for idx, result in indexed_results:
            ordered_results[idx] = result
        return ordered_results
```
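A quick usage sketch; note that the return value follows the original prompt order even though batches are formed from the length-sorted list:

```python
processor = DynamicBatchProcessor(model, tokenizer, max_batch_size=4)
prompts = ["Hi", "Summarize the following article in one sentence.", "Translate: bonjour"]
outputs = processor.process_dynamic_batch(prompts)
# outputs[i] corresponds to prompts[i], regardless of batch grouping
```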
## KV Cache

### Basic KV Cache

```python
class KVCache:
    """A simplified per-sequence KV cache (no batch dimension)."""

    def __init__(self, num_layers: int, num_heads: int,
                 head_dim: int, max_seq_len: int):
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.max_seq_len = max_seq_len
        # Pre-allocate [layers, heads, seq, head_dim] buffers
        self.key_cache = torch.zeros(
            num_layers, num_heads, max_seq_len, head_dim
        )
        self.value_cache = torch.zeros(
            num_layers, num_heads, max_seq_len, head_dim
        )
        # Track the filled length per layer, since update() is
        # called once per layer at every decoding step
        self.lengths = [0] * num_layers

    def update(self, layer_idx: int, key: torch.Tensor, value: torch.Tensor):
        # key/value: [num_heads, seq_len, head_dim] for the new tokens
        seq_len = key.size(1)
        start = self.lengths[layer_idx]
        if start + seq_len > self.max_seq_len:
            raise ValueError("KV cache overflow")
        self.key_cache[layer_idx, :, start:start + seq_len, :] = key
        self.value_cache[layer_idx, :, start:start + seq_len, :] = value
        self.lengths[layer_idx] = start + seq_len

    def get(self, layer_idx: int) -> tuple:
        # Return only the filled portion of the cache
        length = self.lengths[layer_idx]
        return (
            self.key_cache[layer_idx, :, :length, :],
            self.value_cache[layer_idx, :, :length, :],
        )

    def reset(self):
        self.lengths = [0] * self.num_layers
        self.key_cache.zero_()
        self.value_cache.zero_()
```
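A small sanity check of the cache's contract (the shapes here are arbitrary):

```python
cache = KVCache(num_layers=2, num_heads=4, head_dim=8, max_seq_len=128)
# Append keys/values for 5 tokens at layer 0
cache.update(0, torch.randn(4, 5, 8), torch.randn(4, 5, 8))
keys, values = cache.get(0)
print(keys.shape)  # torch.Size([4, 5, 8])
```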
### KV Cache Inference

```python
class KVCacheInference:
    """Greedy decoding that reuses the model's KV cache.

    Hugging Face models return ``past_key_values`` when called with
    ``use_cache=True`` and manage the cache internally; the KVCache
    class above illustrates the underlying data structure. Feeding
    the cache back means each step only processes the single new
    token instead of re-encoding the whole sequence.
    """

    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer

    def generate_with_cache(self, prompt: str,
                            max_new_tokens: int = 100) -> str:
        input_ids = self.tokenizer.encode(prompt, return_tensors="pt")
        generated_ids = input_ids
        next_input = input_ids
        past_key_values = None
        for _ in range(max_new_tokens):
            with torch.no_grad():
                outputs = self.model(
                    next_input,
                    past_key_values=past_key_values,
                    use_cache=True
                )
            # Greedy decoding: take the most likely next token
            new_token = outputs.logits[:, -1:].argmax(dim=-1)
            generated_ids = torch.cat([generated_ids, new_token], dim=-1)
            # With the cache populated, only the new token is fed next step
            past_key_values = outputs.past_key_values
            next_input = new_token
            if new_token.item() == self.tokenizer.eos_token_id:
                break
        return self.tokenizer.decode(
            generated_ids[0],
            skip_special_tokens=True
        )
```
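The speedup is paid for in memory. A rough sizing sketch; the 7B-class numbers below are illustrative, not tied to any specific checkpoint:

```python
def kv_cache_bytes(num_layers, num_heads, head_dim, seq_len,
                   batch_size=1, bytes_per_elem=2):
    # 2x for keys and values; fp16 = 2 bytes per element
    return (2 * num_layers * num_heads * head_dim
            * seq_len * batch_size * bytes_per_elem)

# Illustrative 7B-class config: 32 layers, 32 heads, head_dim 128
print(kv_cache_bytes(32, 32, 128, seq_len=2048) / 1024**3)  # ~1.0 GiB
```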
## Flash Attention

### Flash Attention Implementation

```python
class FlashAttention:
    """Best-effort wrapper that toggles flash-attention flags on
    attention modules, when the implementation exposes them."""

    def __init__(self, model):
        self.model = model
        self._enable_flash_attention()

    def _enable_flash_attention(self):
        for name, module in self.model.named_modules():
            if "attn" in name.lower():
                # Attribute names vary across implementations
                if hasattr(module, "use_flash_attention"):
                    module.use_flash_attention = True
                elif hasattr(module, "attention_type"):
                    module.attention_type = "flash_attention_2"

    def __call__(self, *args, **kwargs):
        return self.model(*args, **kwargs)
```
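Flag-poking like the above is fragile. With Hugging Face transformers, the supported route is to request the attention backend at load time; this requires the flash-attn package and a supported GPU, and the checkpoint name below is a placeholder:

```python
import torch
from transformers import AutoModelForCausalLM

# Placeholder checkpoint; flash_attention_2 needs the flash-attn package
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",
)
```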
### Flash Attention Optimization

```python
class FlashAttentionOptimizer:
    def __init__(self, model):
        self.model = model
        self.flash_attn = FlashAttention(model)

    def optimize_inference(self, input_ids: torch.Tensor):
        with torch.no_grad():
            return self.flash_attn(input_ids)

    def benchmark(self, input_ids: torch.Tensor, n_runs: int = 100):
        import time

        # Warm up once so one-time setup doesn't skew the timing
        self.optimize_inference(input_ids)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        start_time = time.time()
        for _ in range(n_runs):
            self.optimize_inference(input_ids)
        if torch.cuda.is_available():
            # Wait for queued GPU work before reading the clock
            torch.cuda.synchronize()
        avg_time = (time.time() - start_time) / n_runs
        return {
            "avg_time": avg_time,
            "throughput": 1.0 / avg_time
        }
```
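If you only need the kernel rather than a model wrapper, PyTorch 2.x ships a fused attention op, `torch.nn.functional.scaled_dot_product_attention`, which dispatches to a FlashAttention-style kernel when the inputs allow it:

```python
import torch
import torch.nn.functional as F

# [batch, heads, seq_len, head_dim], half precision on GPU
q = torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.float16)
k = torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.float16)
v = torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.float16)

# PyTorch picks the fastest available backend (flash, memory-efficient, or math)
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```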
## TensorRT Optimization

### TensorRT Conversion

```python
class TensorRTConverter:
    def __init__(self, model, input_shape: tuple):
        self.model = model
        self.input_shape = input_shape

    def convert_to_tensorrt(self, output_path: str):
        try:
            import tensorrt as trt  # noqa: F401
            from torch2trt import torch2trt
        except ImportError:
            raise ImportError(
                "Requires tensorrt (pip install tensorrt) and torch2trt "
                "(from NVIDIA's torch2trt repository)"
            )
        # torch2trt traces on the GPU: model and input must be CUDA
        dummy_input = torch.randn(self.input_shape).cuda()
        model_trt = torch2trt(
            self.model,
            [dummy_input],
            fp16_mode=True,              # half precision for speed
            max_workspace_size=1 << 30   # 1 GiB of builder workspace
        )
        torch.save(model_trt.state_dict(), output_path)
        return model_trt

    def optimize_for_inference(self, model_trt):
        model_trt.eval()
        for param in model_trt.parameters():
            param.requires_grad = False
        return model_trt
```
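A usage sketch; torch2trt targets models with fixed tensor inputs (vision-style backbones are the typical case), so the torchvision model and 224x224 shape below are illustrative:

```python
import torchvision.models as models

# Any fixed-shape CUDA model works; resnet18 is just an example
model = models.resnet18(weights=None).eval().cuda()
converter = TensorRTConverter(model, input_shape=(1, 3, 224, 224))
model_trt = converter.convert_to_tensorrt("resnet18_trt.pth")
```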
### TensorRT Inference

```python
class TensorRTInference:
    def __init__(self, engine_path: str):
        self.engine = self._load_engine(engine_path)

    def _load_engine(self, engine_path: str):
        try:
            import tensorrt as trt
        except ImportError:
            raise ImportError("Install tensorrt: pip install tensorrt")
        with open(engine_path, "rb") as f:
            runtime = trt.Runtime(trt.Logger(trt.Logger.WARNING))
            engine = runtime.deserialize_cuda_engine(f.read())
        return engine

    def infer(self, input_data: torch.Tensor):
        # Assumes a single-input, single-output engine, a CUDA input
        # tensor, and the TensorRT 8.x binding API
        context = self.engine.create_execution_context()
        context.set_binding_shape(0, tuple(input_data.shape))
        # Allocate the output on the GPU using the resolved binding shape
        output = torch.empty(
            tuple(context.get_binding_shape(1)),
            device="cuda",
            dtype=torch.float32,
        )
        # execute_v2 takes raw device pointers, one per binding
        context.execute_v2([input_data.data_ptr(), output.data_ptr()])
        return output
```
## ONNX Runtime

### ONNX Conversion

```python
class ONNXConverter:
    def __init__(self, model, input_shape: tuple):
        self.model = model
        self.input_shape = input_shape

    def convert_to_onnx(self, output_path: str):
        # torch.onnx ships with PyTorch, so no extra install is needed
        dummy_input = torch.randn(self.input_shape)
        torch.onnx.export(
            self.model,
            dummy_input,
            output_path,
            export_params=True,        # store trained weights in the file
            opset_version=14,
            do_constant_folding=True,  # fold constant subgraphs at export time
            input_names=['input'],
            output_names=['output'],
            dynamic_axes={
                'input': {0: 'batch_size'},   # allow variable batch size
                'output': {0: 'batch_size'}
            }
        )
        return output_path
```
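After exporting, it is worth validating the graph with the onnx package before handing it to a runtime; the path below is a placeholder:

```python
import onnx

onnx_model = onnx.load("model.onnx")
# Raises if the graph is structurally invalid
onnx.checker.check_model(onnx_model)
print(onnx.helper.printable_graph(onnx_model.graph))
```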
### ONNX Inference

```python
import numpy as np


class ONNXInference:
    def __init__(self, onnx_path: str):
        self.session = self._load_onnx_model(onnx_path)

    def _load_onnx_model(self, onnx_path: str):
        try:
            import onnxruntime as ort
        except ImportError:
            raise ImportError("Install onnxruntime: pip install onnxruntime")
        return ort.InferenceSession(onnx_path)

    def infer(self, input_data: np.ndarray):
        # First argument None means "return all outputs"
        outputs = self.session.run(None, {'input': input_data})
        return outputs[0]

    def benchmark(self, input_data: np.ndarray, n_runs: int = 100):
        import time

        # Warm-up run so session initialization doesn't skew timing
        self.infer(input_data)
        start_time = time.time()
        for _ in range(n_runs):
            self.infer(input_data)
        avg_time = (time.time() - start_time) / n_runs
        return {
            "avg_time": avg_time,
            "throughput": 1.0 / avg_time
        }
```
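A usage sketch; by default `InferenceSession` runs on CPU, and GPU execution is requested through `providers` (which needs the onnxruntime-gpu package). The model path and input shape are placeholders:

```python
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession(
    "model.onnx",
    providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
)
dummy = np.random.randn(1, 3, 224, 224).astype(np.float32)
outputs = session.run(None, {"input": dummy})
```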
## Inference Optimization Comparison

### Performance Comparison

```python
import time


class InferenceOptimizerComparator:
    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer

    def compare_optimizers(self, test_prompts: List[str]):
        # Average per-prompt latency for each strategy
        return {
            "baseline": self._measure_baseline_inference(test_prompts),
            "batch": self._measure_batch_inference(test_prompts),
            "kv_cache": self._measure_kv_cache_inference(test_prompts),
        }

    def _measure_baseline_inference(self, prompts: List[str]) -> float:
        total_time = 0.0
        for prompt in prompts:
            start_time = time.time()
            inputs = self.tokenizer(prompt, return_tensors="pt")
            with torch.no_grad():
                _ = self.model.generate(**inputs)
            total_time += time.time() - start_time
        return total_time / len(prompts)

    def _measure_batch_inference(self, prompts: List[str]) -> float:
        processor = BatchProcessor(self.model, self.tokenizer)
        start_time = time.time()
        _ = processor.process_batch(prompts)
        return (time.time() - start_time) / len(prompts)

    def _measure_kv_cache_inference(self, prompts: List[str]) -> float:
        inference = KVCacheInference(self.model, self.tokenizer)
        total_time = 0.0
        for prompt in prompts:
            start_time = time.time()
            _ = inference.generate_with_cache(prompt)
            total_time += time.time() - start_time
        return total_time / len(prompts)
```
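Running the comparison is straightforward once a model and tokenizer are loaded; absolute numbers depend entirely on hardware and model size:

```python
comparator = InferenceOptimizerComparator(model, tokenizer)
results = comparator.compare_optimizers(["Hello!", "Explain KV caching.", "Tell me a joke."])
for name, seconds in results.items():
    print(f"{name}: {seconds:.3f}s per prompt")
```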
## Hands-On Exercises

### Exercise 1: Implement Batched Inference

```python
def batch_inference(model, tokenizer, prompts: List[str],
                    batch_size: int = 8):
    results = []
    for i in range(0, len(prompts), batch_size):
        batch = prompts[i:i + batch_size]
        # Pad within each batch so inputs form a rectangular tensor
        inputs = tokenizer(batch, padding=True,
                           truncation=True, return_tensors="pt")
        with torch.no_grad():
            outputs = model.generate(**inputs)
        results.extend(tokenizer.batch_decode(outputs,
                                              skip_special_tokens=True))
    return results
```
### Exercise 2: Implement KV Cache Inference

```python
def kv_cache_inference(model, tokenizer, prompt: str,
                       max_new_tokens: int = 100):
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    generated_ids = input_ids
    next_input = input_ids
    past_key_values = None
    for _ in range(max_new_tokens):
        with torch.no_grad():
            outputs = model(
                next_input,
                past_key_values=past_key_values,
                use_cache=True
            )
        new_token = outputs.logits[:, -1:].argmax(dim=-1)
        generated_ids = torch.cat([generated_ids, new_token], dim=-1)
        # With the cache populated, only the new token is fed next step
        past_key_values = outputs.past_key_values
        next_input = new_token
        if new_token.item() == tokenizer.eos_token_id:
            break
    return tokenizer.decode(generated_ids[0],
                            skip_special_tokens=True)
```
## Summary

In this section we covered inference acceleration:

- Batch processing optimization (basic and dynamic batching)
- The KV Cache technique
- Flash Attention optimization
- TensorRT optimization
- ONNX Runtime

Inference acceleration is a key technique for improving the performance of AI applications.
