Day 59: Inference Acceleration

Learning Objectives

  • Master batching optimization techniques
  • Learn the KV Cache technique
  • Understand Flash Attention
  • Master TensorRT optimization
  • Get familiar with ONNX Runtime

Batching Optimization

Basic Batching

python
import torch
from typing import List

class BatchProcessor:
    def __init__(self, model, tokenizer, 
                 max_batch_size: int = 8):
        self.model = model
        self.tokenizer = tokenizer
        self.max_batch_size = max_batch_size
    
    def process_batch(self, prompts: List[str]) -> List[str]:
        results = []
        
        for i in range(0, len(prompts), self.max_batch_size):
            batch = prompts[i:i + self.max_batch_size]
            batch_results = self._process_single_batch(batch)
            results.extend(batch_results)
        
        return results
    
    def _process_single_batch(self, batch: List[str]) -> List[str]:
        inputs = self.tokenizer(
            batch,
            padding=True,
            truncation=True,
            return_tensors="pt"
        ).to(self.model.device)  # move tensors to wherever the model lives
        
        with torch.no_grad():
            outputs = self.model.generate(**inputs)
        
        results = self.tokenizer.batch_decode(
            outputs,
            skip_special_tokens=True
        )
        
        return results
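
A minimal usage sketch, assuming a Hugging Face causal LM; the "gpt2" checkpoint here is only an example, and batched generation usually needs an explicit pad token and left padding:

python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "gpt2"  # example checkpoint, substitute your own
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# GPT-2 has no pad token; reuse EOS and pad on the left for generation
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

processor = BatchProcessor(model, tokenizer, max_batch_size=4)
outputs = processor.process_batch([
    "The capital of France is",
    "Machine learning is",
])
print(outputs)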

Dynamic Batching

python
class DynamicBatchProcessor:
    def __init__(self, model, tokenizer,
                 max_batch_size: int = 8,
                 max_batch_tokens: int = 4096):
        self.model = model
        self.tokenizer = tokenizer
        self.max_batch_size = max_batch_size
        self.max_batch_tokens = max_batch_tokens
    
    def process_dynamic_batch(self, prompts: List[str]) -> List[str]:
        # Sort by token length so prompts of similar length share a batch
        # and padding overhead stays small
        sorted_items = self._sort_by_length(prompts)
        batches = self._create_dynamic_batches(sorted_items)
        
        indexed_results = []
        for batch in batches:
            indices = [idx for idx, _ in batch]
            batch_prompts = [prompts[idx] for idx in indices]
            batch_results = self._process_batch(batch_prompts)
            indexed_results.extend(zip(indices, batch_results))
        
        return self._restore_order(indexed_results, len(prompts))
    
    def _sort_by_length(self, prompts: List[str]) -> List[tuple]:
        # (original index, token length) pairs, sorted by length
        prompt_lengths = [
            (i, len(self.tokenizer.encode(prompt)))
            for i, prompt in enumerate(prompts)
        ]
        return sorted(prompt_lengths, key=lambda x: x[1])
    
    def _create_dynamic_batches(self,
                                sorted_items: List[tuple]) -> List[List[tuple]]:
        batches = []
        current_batch = []
        current_tokens = 0
        
        for idx, length in sorted_items:
            # Close the current batch when either the size cap or the
            # per-batch token budget would be exceeded
            if current_batch and (
                len(current_batch) >= self.max_batch_size
                or current_tokens + length > self.max_batch_tokens
            ):
                batches.append(current_batch)
                current_batch = []
                current_tokens = 0
            
            current_batch.append((idx, length))
            current_tokens += length
        
        if current_batch:
            batches.append(current_batch)
        
        return batches
    
    def _process_batch(self, batch_prompts: List[str]) -> List[str]:
        inputs = self.tokenizer(
            batch_prompts,
            padding=True,
            truncation=True,
            return_tensors="pt"
        ).to(self.model.device)
        
        with torch.no_grad():
            outputs = self.model.generate(**inputs)
        
        return self.tokenizer.batch_decode(
            outputs,
            skip_special_tokens=True
        )
    
    def _restore_order(self, indexed_results: List[tuple],
                       total: int) -> List[str]:
        # Place each result back at the position of its original prompt
        ordered_results = [None] * total
        for idx, result in indexed_results:
            ordered_results[idx] = result
        return ordered_results
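
Reusing the model and tokenizer from the basic batching example, a short usage sketch; max_batch_tokens caps the total token budget per batch:

python
processor = DynamicBatchProcessor(model, tokenizer,
                                  max_batch_size=4,
                                  max_batch_tokens=512)

prompts = [
    "Hi",
    "Explain the difference between latency and throughput in one paragraph.",
    "Summarize why KV caching speeds up autoregressive decoding.",
    "Hello there",
]

results = processor.process_dynamic_batch(prompts)
for prompt, result in zip(prompts, results):
    print(repr(prompt), "->", result[:60])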

KV Cache

Basic KV Cache

python
class KVCache:
    def __init__(self, num_layers: int,
                 num_heads: int,
                 head_dim: int,
                 max_seq_len: int):
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.max_seq_len = max_seq_len
        
        # One (num_heads, max_seq_len, head_dim) buffer per layer
        self.key_cache = torch.zeros(
            num_layers,
            num_heads,
            max_seq_len,
            head_dim
        )
        self.value_cache = torch.zeros(
            num_layers,
            num_heads,
            max_seq_len,
            head_dim
        )
        # Every layer is updated once per step, so track lengths per layer
        self.current_length = [0] * num_layers
    
    def update(self, layer_idx: int,
               key: torch.Tensor,
               value: torch.Tensor):
        # key and value are expected to be (num_heads, seq_len, head_dim)
        seq_len = key.size(1)
        start = self.current_length[layer_idx]
        
        if start + seq_len > self.max_seq_len:
            raise ValueError("KV cache overflow")
        
        self.key_cache[layer_idx, :, start:start + seq_len, :] = key
        self.value_cache[layer_idx, :, start:start + seq_len, :] = value
        
        self.current_length[layer_idx] = start + seq_len
    
    def get(self, layer_idx: int) -> tuple:
        length = self.current_length[layer_idx]
        return (
            self.key_cache[layer_idx, :, :length, :],
            self.value_cache[layer_idx, :, :length, :]
        )
    
    def reset(self):
        self.current_length = [0] * self.num_layers
        self.key_cache.zero_()
        self.value_cache.zero_()
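
A small self-contained exercise of the cache with dummy tensors, mimicking a 5-token prefill followed by one decode step for a single layer:

python
cache = KVCache(num_layers=2, num_heads=4, head_dim=8, max_seq_len=16)

# Prefill: cache keys/values for a 5-token prompt in layer 0
cache.update(0, torch.randn(4, 5, 8), torch.randn(4, 5, 8))

# Decode step: append one new token's key/value
cache.update(0, torch.randn(4, 1, 8), torch.randn(4, 1, 8))

cached_k, cached_v = cache.get(0)
print(cached_k.shape)  # torch.Size([4, 6, 8])

cache.reset()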

KV Cache Inference

python
class KVCacheInference:
    """Greedy decoding that reuses the model's past_key_values between steps.

    This relies on a Hugging Face style ``use_cache`` interface; the KVCache
    class above illustrates what such a cache stores internally.
    """

    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer
    
    def generate_with_cache(self, prompt: str,
                            max_new_tokens: int = 100) -> str:
        input_ids = self.tokenizer.encode(prompt, return_tensors="pt")
        
        generated_ids = input_ids
        past_key_values = None
        next_input = input_ids
        
        for _ in range(max_new_tokens):
            with torch.no_grad():
                outputs = self.model(
                    next_input,
                    past_key_values=past_key_values,
                    use_cache=True
                )
            
            new_token = outputs.logits[:, -1:].argmax(dim=-1)
            generated_ids = torch.cat([generated_ids, new_token], dim=-1)
            
            # Keep the updated cache and feed only the new token next step
            past_key_values = outputs.past_key_values
            next_input = new_token
            
            if new_token.item() == self.tokenizer.eos_token_id:
                break
        
        return self.tokenizer.decode(
            generated_ids[0],
            skip_special_tokens=True
        )
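
The effect of the cache is easy to measure with a Hugging Face model by toggling use_cache in generate; a rough timing sketch (the "gpt2" checkpoint is only an example, and absolute numbers depend on hardware):

python
import time
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("The history of GPU computing", return_tensors="pt")

for use_cache in (True, False):
    start = time.time()
    model.generate(**inputs, max_new_tokens=64,
                   do_sample=False, use_cache=use_cache)
    print(f"use_cache={use_cache}: {time.time() - start:.2f}s")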

Flash Attention

Enabling Flash Attention

python
class FlashAttention:
    def __init__(self, model):
        self.model = model
        self._enable_flash_attention()
    
    def _enable_flash_attention(self):
        # Attribute names differ between model implementations; these toggles
        # only take effect on attention modules that actually expose them
        for name, module in self.model.named_modules():
            if "attn" in name.lower():
                if hasattr(module, "use_flash_attention"):
                    module.use_flash_attention = True
                elif hasattr(module, "attention_type"):
                    module.attention_type = "flash_attention_2"
    
    def __call__(self, *args, **kwargs):
        return self.model(*args, **kwargs)
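
The attribute toggles above are model-dependent. With recent versions of transformers and PyTorch there are two more reliable routes, sketched below under the assumption that you have a supported GPU, the flash-attn package, and PyTorch 2.3 or newer; the checkpoint name is only a placeholder:

python
import torch
from transformers import AutoModelForCausalLM

# Option 1: request FlashAttention-2 when loading the model
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",            # example (gated) checkpoint
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
).to("cuda")

# Option 2: use PyTorch's fused SDPA kernels directly
from torch.nn.attention import sdpa_kernel, SDPBackend

q = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)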

Flash Attention Optimization

python
class FlashAttentionOptimizer:
    def __init__(self, model):
        self.model = model
        self.flash_attn = FlashAttention(model)
    
    def optimize_inference(self, input_ids: torch.Tensor):
        with torch.no_grad():
            outputs = self.flash_attn(input_ids)
        
        return outputs
    
    def benchmark(self, input_ids: torch.Tensor,
                  n_runs: int = 100):
        import time
        
        # Synchronize around the timed region so pending GPU kernels are counted
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        start_time = time.time()
        
        for _ in range(n_runs):
            self.optimize_inference(input_ids)
        
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        avg_time = (time.time() - start_time) / n_runs
        
        return {
            "avg_time": avg_time,
            "throughput": 1.0 / avg_time
        }

TensorRT Optimization

TensorRT Conversion

python
class TensorRTConverter:
    def __init__(self, model, input_shape: tuple):
        self.model = model
        self.input_shape = input_shape
    
    def convert_to_tensorrt(self, output_path: str):
        try:
            import tensorrt as trt  # noqa: F401 -- availability check
            from torch2trt import torch2trt
        except ImportError:
            raise ImportError(
                "Install TensorRT (pip install tensorrt) and torch2trt "
                "(from the NVIDIA-AI-IOT/torch2trt repository)"
            )
        
        # torch2trt traces the model on the GPU, so the example input must be CUDA
        dummy_input = torch.randn(self.input_shape).cuda()
        
        model_trt = torch2trt(
            self.model,
            [dummy_input],
            fp16_mode=True,
            max_workspace_size=1 << 30
        )
        
        torch.save(model_trt.state_dict(), output_path)
        
        return model_trt
    
    def optimize_for_inference(self, model_trt):
        model_trt.eval()
        
        for param in model_trt.parameters():
            param.requires_grad = False
        
        return model_trt
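
A hedged usage sketch, assuming a CUDA machine with tensorrt and torch2trt installed; torch2trt is most commonly used with fixed-shape vision models, and the torchvision ResNet here is only a stand-in:

python
import torch
import torchvision

cnn = torchvision.models.resnet18(weights=None).cuda().eval()

converter = TensorRTConverter(cnn, input_shape=(1, 3, 224, 224))
model_trt = converter.convert_to_tensorrt("resnet18_trt.pth")
model_trt = converter.optimize_for_inference(model_trt)

x = torch.randn(1, 3, 224, 224).cuda()
with torch.no_grad():
    print(model_trt(x).shape)  # torch.Size([1, 1000])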

TensorRT Inference

python
class TensorRTInference:
    def __init__(self, engine_path: str):
        self.engine = self._load_engine(engine_path)
    
    def _load_engine(self, engine_path: str):
        try:
            import tensorrt as trt
        except ImportError:
            raise ImportError("Install tensorrt: pip install tensorrt")
        
        with open(engine_path, "rb") as f:
            runtime = trt.Runtime(trt.Logger(trt.Logger.WARNING))
            engine = runtime.deserialize_cuda_engine(f.read())
        
        return engine
    
    def infer(self, input_data: torch.Tensor):
        context = self.engine.create_execution_context()
        
        # Bindings are raw device pointers, so both tensors must live on the GPU
        # and stay referenced until execution finishes (this sketch also assumes
        # the engine's output dtype matches the input dtype)
        input_data = input_data.contiguous().cuda()
        output = torch.empty(
            tuple(self._get_output_shape()),
            dtype=input_data.dtype,
            device="cuda"
        )
        
        context.set_binding_shape(0, tuple(input_data.shape))
        context.execute_v2([input_data.data_ptr(), output.data_ptr()])
        
        return output
    
    def _get_output_shape(self):
        # Binding 0 is the input, binding 1 the output in this simple engine
        return self.engine.get_binding_shape(1)

ONNX Runtime

ONNX Conversion

python
class ONNXConverter:
    def __init__(self, model, input_shape: tuple):
        self.model = model
        self.input_shape = input_shape
    
    def convert_to_onnx(self, output_path: str):
        # torch.onnx ships with PyTorch, so no separate install is needed
        import torch.onnx
        
        self.model.eval()
        dummy_input = torch.randn(self.input_shape)
        
        torch.onnx.export(
            self.model,
            dummy_input,
            output_path,
            export_params=True,
            opset_version=14,
            do_constant_folding=True,
            input_names=['input'],
            output_names=['output'],
            dynamic_axes={
                'input': {0: 'batch_size'},
                'output': {0: 'batch_size'}
            }
        )
        
        return output_path

ONNX Inference

python
import numpy as np

class ONNXInference:
    def __init__(self, onnx_path: str):
        self.session = self._load_onnx_model(onnx_path)
    
    def _load_onnx_model(self, onnx_path: str):
        try:
            import onnxruntime as ort
        except ImportError:
            raise ImportError("Install onnxruntime: pip install onnxruntime")
        
        session = ort.InferenceSession(onnx_path)
        
        return session
    
    def infer(self, input_data: np.ndarray):
        outputs = self.session.run(
            None,
            {'input': input_data}
        )
        
        return outputs[0]
    
    def benchmark(self, input_data: np.ndarray, 
                  n_runs: int = 100):
        import time
        
        start_time = time.time()
        
        for _ in range(n_runs):
            self.infer(input_data)
        
        avg_time = (time.time() - start_time) / n_runs
        
        return {
            "avg_time": avg_time,
            "throughput": 1.0 / avg_time
        }
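
A usage sketch reusing the toy model exported above. Note that InferenceSession defaults to CPU execution; with the onnxruntime-gpu package installed you can pass providers=["CUDAExecutionProvider", "CPUExecutionProvider"] when creating the session:

python
import numpy as np

inference = ONNXInference("toy_model.onnx")

x = np.random.randn(1, 16).astype(np.float32)
print(inference.infer(x).shape)           # (1, 4)
print(inference.benchmark(x, n_runs=50))  # {'avg_time': ..., 'throughput': ...}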

Comparing Inference Optimizations

Performance Comparison

python
class InferenceOptimizerComparator:
    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer
    
    def compare_optimizers(self, test_prompts: List[str]):
        results = {}
        
        baseline_time = self._measure_baseline_inference(test_prompts)
        results["baseline"] = baseline_time
        
        batch_time = self._measure_batch_inference(test_prompts)
        results["batch"] = batch_time
        
        kv_cache_time = self._measure_kv_cache_inference(test_prompts)
        results["kv_cache"] = kv_cache_time
        
        return results
    
    def _measure_baseline_inference(self, prompts: List[str]) -> float:
        import time
        
        total_time = 0
        
        for prompt in prompts:
            start_time = time.time()
            
            inputs = self.tokenizer(prompt, return_tensors="pt")
            with torch.no_grad():
                _ = self.model.generate(**inputs)
            
            total_time += time.time() - start_time
        
        return total_time / len(prompts)
    
    def _measure_batch_inference(self, prompts: List[str]) -> float:
        import time
        
        processor = BatchProcessor(self.model, self.tokenizer)
        
        start_time = time.time()
        _ = processor.process_batch(prompts)
        total_time = time.time() - start_time
        
        return total_time / len(prompts)
    
    def _measure_kv_cache_inference(self, prompts: List[str]) -> float:
        import time
        
        inference = KVCacheInference(self.model, self.tokenizer)
        
        total_time = 0
        
        for prompt in prompts:
            start_time = time.time()
            _ = inference.generate_with_cache(prompt)
            total_time += time.time() - start_time
        
        return total_time / len(prompts)
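
A minimal driver for the comparison, assuming the GPT-2 style model and tokenizer from the batching example; absolute numbers vary widely with model size and hardware:

python
test_prompts = [
    "Explain KV caching in one sentence.",
    "What does batching improve?",
    "Name one benefit of FlashAttention.",
    "Why export a model to ONNX?",
]

comparator = InferenceOptimizerComparator(model, tokenizer)
timings = comparator.compare_optimizers(test_prompts)

for name, seconds_per_prompt in timings.items():
    print(f"{name:>10}: {seconds_per_prompt:.3f} s/prompt")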

Hands-on Exercises

Exercise 1: Implement Batched Inference

python
def batch_inference(model, tokenizer, prompts: List[str], 
                   batch_size: int = 8):
    results = []
    
    for i in range(0, len(prompts), batch_size):
        batch = prompts[i:i + batch_size]
        
        inputs = tokenizer(batch, padding=True,
                           truncation=True, return_tensors="pt")
        
        with torch.no_grad():
            outputs = model.generate(**inputs)
        
        batch_results = tokenizer.batch_decode(outputs,
                                               skip_special_tokens=True)
        results.extend(batch_results)
    
    return results

Exercise 2: Implement KV Cache Inference

python
def kv_cache_inference(model, tokenizer, prompt: str,
                       max_new_tokens: int = 100):
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    
    past_key_values = None
    generated_ids = input_ids
    next_input = input_ids
    
    for _ in range(max_new_tokens):
        with torch.no_grad():
            outputs = model(
                next_input,
                past_key_values=past_key_values,
                use_cache=True
            )
        
        new_token = outputs.logits[:, -1:].argmax(dim=-1)
        generated_ids = torch.cat([generated_ids, new_token], dim=-1)
        
        # Once the cache is warm, only the newly generated token is fed back in
        past_key_values = outputs.past_key_values
        next_input = new_token
        
        if new_token.item() == tokenizer.eos_token_id:
            break
    
    generated_text = tokenizer.decode(generated_ids[0],
                                      skip_special_tokens=True)
    
    return generated_text

Summary

In this lesson we covered inference acceleration:

  1. Batching optimization (basic and dynamic batching)
  2. The KV Cache technique
  3. Flash Attention
  4. TensorRT optimization
  5. ONNX Runtime

Inference acceleration is a key technique for improving the performance of AI applications.

References