
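LLM Visualization Architecture and Principles

Overall Architecture

LLM Visualization uses a layered architecture that keeps model inference, data processing, and visualization rendering separate, which makes the system easier to maintain and extend.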

```text
┌────────────────────────────────────────────────────────────────────────────────┐
│  User Interface Layer                                                          │
│  ┌──────────────────────┐  ┌──────────────────────┐  ┌──────────────────────┐  │
│  │ 3D Render View       │  │ Control Panel        │  │ Info Display         │  │
│  └──────────────────────┘  └──────────────────────┘  └──────────────────────┘  │
└────────────────────────────────────────────────────────────────────────────────┘

┌────────────────────────────────────────────────────────────────────────────────┐
│  Visualization Engine Layer                                                    │
│  ┌──────────────────────┐  ┌──────────────────────┐  ┌──────────────────────┐  │
│  │ Scene Manager        │  │ Animation Controller │  │ Interaction Handler  │  │
│  └──────────────────────┘  └──────────────────────┘  └──────────────────────┘  │
└────────────────────────────────────────────────────────────────────────────────┘

┌────────────────────────────────────────────────────────────────────────────────┐
│  Model Inference Layer                                                         │
│  ┌──────────────────────┐  ┌──────────────────────┐  ┌──────────────────────┐  │
│  │ ONNX Runtime         │  │ Attention Compute    │  │ State Capture        │  │
│  └──────────────────────┘  └──────────────────────┘  └──────────────────────┘  │
└────────────────────────────────────────────────────────────────────────────────┘

┌────────────────────────────────────────────────────────────────────────────────┐
│  Data Management Layer                                                         │
│  ┌──────────────────────┐  ┌──────────────────────┐  ┌──────────────────────┐  │
│  │ Model Loader         │  │ State Cache          │  │ Config Manager       │  │
│  └──────────────────────┘  └──────────────────────┘  └──────────────────────┘  │
└────────────────────────────────────────────────────────────────────────────────┘
```

Core Components in Detail

1. Model Inference Engine

ONNX Runtime Web

LLM Visualization runs model inference directly in the browser with ONNX Runtime Web:

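```typescript
import * as ort from 'onnxruntime-web';

// Loading a model and running inference
class ModelInference {
  private session: ort.InferenceSession | null = null;

  async loadModel(modelPath: string): Promise<void> {
    this.session = await ort.InferenceSession.create(modelPath, {
      executionProviders: ['wasm'],   // WebAssembly CPU backend
      graphOptimizationLevel: 'all'   // apply all graph optimizations
    });
  }

  async runInference(inputTokens: number[]) {
    if (!this.session) throw new Error('Model not loaded; call loadModel() first');
    // int64 tensors must be backed by a BigInt64Array, not a plain number[]
    const data = BigInt64Array.from(inputTokens, (t) => BigInt(t));
    const inputTensor = new ort.Tensor('int64', data, [1, inputTokens.length]);
    // Input names depend on how the model was exported; use the first declared input
    const feeds = { [this.session.inputNames[0]]: inputTensor };
    return this.session.run(feeds);
  }
}
```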
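The session returns the raw output tensors. For a causal language model exported with a logits output, that tensor usually has shape [1, seqLen, vocabSize], and the next token is read from the last position. A minimal greedy-decoding sketch under that assumption (the helper name `pickNextToken` is hypothetical, and the output layout depends on how the model was exported):

```typescript
// Greedy decoding: return the id of the highest-scoring token at the last position.
// Assumes `logits` is a row-major [1, seqLen, vocabSize] tensor flattened into a Float32Array.
function pickNextToken(logits: Float32Array, seqLen: number, vocabSize: number): number {
  if (logits.length !== seqLen * vocabSize) {
    throw new Error(`expected ${seqLen * vocabSize} logits, got ${logits.length}`);
  }
  const offset = (seqLen - 1) * vocabSize; // start of the last position's row
  let best = 0;
  for (let v = 1; v < vocabSize; v++) {
    if (logits[offset + v] > logits[offset + best]) best = v;
  }
  return best;
}
```

With ONNX Runtime Web, `logits` would be the `data` of the output tensor, and `seqLen`/`vocabSize` would come from its `dims`. Greedy argmax is the simplest choice; temperature or top-k sampling would replace only the selection loop.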

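Capturing Attention Computations

To visualize attention weights, the system has to capture the intermediate results of each attention layer (Q, K, V, the attention weights, and the layer output):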

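```typescript
// Intermediate tensors captured for one attention layer (flattened, row-major)
interface AttentionState {
  query: Float32Array;             // Q matrix
  key: Float32Array;               // K matrix
  value: Float32Array;             // V matrix
  attentionWeights: Float32Array;  // attention weights
  output: Float32Array;            // attention output
}

// How intermediate tensors are exposed depends on the model export,
// so the extraction steps are left abstract here
abstract class AttentionCapture {
  protected abstract extractQuery(layerIndex: number): Float32Array;
  protected abstract extractKey(layerIndex: number): Float32Array;
  protected abstract extractValue(layerIndex: number): Float32Array;
  protected abstract extractOutput(layerIndex: number): Float32Array;
  protected abstract computeAttentionWeights(layerIndex: number): Float32Array;

  // Capture the attention state of the given layer
  captureLayer(layerIndex: number): AttentionState {
    return {
      query: this.extractQuery(layerIndex),
      key: this.extractKey(layerIndex),
      value: this.extractValue(layerIndex),
      attentionWeights: this.computeAttentionWeights(layerIndex),
      output: this.extractOutput(layerIndex)
    };
  }
}
```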
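The `computeAttentionWeights` step can be cross-checked against captured Q and K, since in a decoder-only model the weights of one head are softmax(QKᵀ/√d_k) with future positions masked. A minimal single-head sketch, assuming Q and K are row-major [seqLen, headDim] arrays for one head; `computeCausalAttention` is a hypothetical name:

```typescript
// Causal scaled dot-product attention weights for a single head.
// q, k: row-major [seqLen, headDim]. Returns row-major [seqLen, seqLen] where row i
// is the distribution of query i over keys 0..i; masked (future) keys stay 0.
function computeCausalAttention(q: Float32Array, k: Float32Array, seqLen: number, headDim: number): Float32Array {
  const weights = new Float32Array(seqLen * seqLen);
  const scale = 1 / Math.sqrt(headDim);
  for (let i = 0; i < seqLen; i++) {
    // Scaled scores for keys 0..i, tracking the row maximum for a stable softmax
    let max = -Infinity;
    for (let j = 0; j <= i; j++) {
      let dot = 0;
      for (let d = 0; d < headDim; d++) dot += q[i * headDim + d] * k[j * headDim + d];
      const score = dot * scale;
      weights[i * seqLen + j] = score;
      if (score > max) max = score;
    }
    // Softmax over the unmasked entries of row i
    let sum = 0;
    for (let j = 0; j <= i; j++) {
      const e = Math.exp(weights[i * seqLen + j] - max);
      weights[i * seqLen + j] = e;
      sum += e;
    }
    for (let j = 0; j <= i; j++) weights[i * seqLen + j] /= sum;
  }
  return weights;
}
```

Subtracting the row maximum before `Math.exp` leaves the softmax unchanged but keeps large scores from overflowing.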

2. Visualization Engine

3D Scene Management

An interactive 3D scene is built with Three.js:

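```typescript
import * as THREE from 'three';
import { OrbitControls } from 'three/examples/jsm/controls/OrbitControls.js';

class SceneManager {
  private scene: THREE.Scene;
  private camera: THREE.PerspectiveCamera;
  private renderer: THREE.WebGLRenderer;
  private controls: OrbitControls;

  constructor(container: HTMLElement) {
    const width = container.clientWidth;
    const height = container.clientHeight;

    this.scene = new THREE.Scene();
    // 75° vertical field of view, near plane 0.1, far plane 1000
    this.camera = new THREE.PerspectiveCamera(75, width / height, 0.1, 1000);
    this.renderer = new THREE.WebGLRenderer({ antialias: true });
    this.renderer.setSize(width, height);
    container.appendChild(this.renderer.domElement);
    // Mouse-driven orbit, zoom and pan around a target point
    this.controls = new OrbitControls(this.camera, this.renderer.domElement);

    this.setupLighting();
    this.setupCamera();

    // Render every frame so camera movement is reflected immediately
    this.renderer.setAnimationLoop(() => this.renderer.render(this.scene, this.camera));
  }

  private setupLighting() {
    // Dim ambient fill plus one directional light for shading
    const ambientLight = new THREE.AmbientLight(0x404040, 0.5);
    const directionalLight = new THREE.DirectionalLight(0xffffff, 0.8);
    directionalLight.position.set(10, 10, 5);

    this.scene.add(ambientLight);
    this.scene.add(directionalLight);
  }

  // Illustrative starting viewpoint (assumed values)
  private setupCamera() {
    this.camera.position.set(0, 10, 30);
    this.controls.target.set(0, 0, 0);
    this.controls.update();
  }
}
```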

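Node Visualization

Elements of the network are rendered as 3D nodes: each token becomes a sphere, and attention between two tokens becomes a line whose opacity scales with the attention weight: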

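```typescript
import * as THREE from 'three';

class NodeVisualization {
  // A token becomes a semi-transparent sphere with a text label above it
  createTokenNode(token: string, position: THREE.Vector3): THREE.Mesh {
    const geometry = new THREE.SphereGeometry(0.5, 32, 32);
    const material = new THREE.MeshPhongMaterial({
      color: this.getTokenColor(token),
      transparent: true,
      opacity: 0.8
    });

    const node = new THREE.Mesh(geometry, material);
    node.position.copy(position);

    // Attach the label as a child so it moves with the node
    node.add(this.createLabel(token));

    return node;
  }

  // Attention edge between two tokens; opacity grows with the weight (expected in [0, 1])
  createAttentionConnection(from: THREE.Vector3, to: THREE.Vector3, weight: number): THREE.Line {
    const geometry = new THREE.BufferGeometry().setFromPoints([from, to]);
    const material = new THREE.LineBasicMaterial({
      color: 0xff6b6b,
      transparent: true,
      opacity: Math.min(1, Math.max(0, weight)) * 0.8
    });

    return new THREE.Line(geometry, material);
  }

  // Illustrative helper: derive a stable hue from the token text
  private getTokenColor(token: string): THREE.Color {
    let hash = 0;
    for (let i = 0; i < token.length; i++) hash = (hash * 31 + token.charCodeAt(i)) >>> 0;
    return new THREE.Color().setHSL((hash % 360) / 360, 0.6, 0.55);
  }

  // Illustrative helper: draw the token text onto a canvas-backed sprite
  private createLabel(text: string): THREE.Sprite {
    const canvas = document.createElement('canvas');
    canvas.width = 256;
    canvas.height = 64;
    const ctx = canvas.getContext('2d')!;
    ctx.font = '32px sans-serif';
    ctx.fillStyle = '#ffffff';
    ctx.textAlign = 'center';
    ctx.textBaseline = 'middle';
    ctx.fillText(text, canvas.width / 2, canvas.height / 2);

    const sprite = new THREE.Sprite(new THREE.SpriteMaterial({ map: new THREE.CanvasTexture(canvas) }));
    sprite.scale.set(2, 0.5, 1); // keep the 4:1 canvas aspect ratio
    sprite.position.y = 0.9;     // float just above the sphere (radius 0.5)
    return sprite;
  }
}
```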
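Drawing one line per token pair produces n² lines for n tokens, most of them nearly invisible. One option, sketched here as an assumption rather than the project's documented behavior, is to keep only the strongest connections per query token before calling `createAttentionConnection`; `selectTopConnections`, `topK`, and `minWeight` are hypothetical names:

```typescript
// One drawable edge: attention from query token `from` to key token `to`
interface AttentionEdge {
  from: number;
  to: number;
  weight: number;
}

// Keep at most `topK` edges per query row, dropping weights below `minWeight`.
// weights[i][j] is the attention of query token i to key token j.
function selectTopConnections(weights: number[][], topK: number, minWeight = 0): AttentionEdge[] {
  const edges: AttentionEdge[] = [];
  weights.forEach((row, i) => {
    row
      .map((weight, j) => ({ from: i, to: j, weight }))
      .filter((e) => e.weight >= minWeight)
      .sort((a, b) => b.weight - a.weight) // strongest first
      .slice(0, topK)
      .forEach((e) => edges.push(e));
  });
  return edges;
}
```

Each returned edge can then be drawn with the positions of tokens `from` and `to`, which bounds the line count at n × topK.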

3. Data Flow Management

State Management Architecture

Application state is managed with Redux:

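```typescript
// Minimal supporting type (fields are illustrative)
interface Token {
  id: number;
  text: string;
}

// State definition
interface AppState {
  model: {
    currentModel: string;
    loading: boolean;
    error: string | null;
  };
  inference: {
    inputText: string;
    tokens: Token[];
    currentStep: number;
    isRunning: boolean;
  };
  visualization: {
    selectedLayer: number;
    selectedHead: number;
    viewMode: '3d' | '2d';
    animationSpeed: number;
  };
}

// A visualization update may change any subset of the visualization settings
type VisualizationConfig = Partial<AppState['visualization']>;

// Action creators; `as const` keeps each `type` a string literal for type-safe reducers
const actions = {
  setInputText: (text: string) => ({
    type: 'SET_INPUT_TEXT' as const,
    payload: text
  }),

  runInference: () => ({
    type: 'RUN_INFERENCE' as const
  }),

  updateVisualization: (config: VisualizationConfig) => ({
    type: 'UPDATE_VISUALIZATION' as const,
    payload: config
  })
};
```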
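The action creators above only describe changes; a reducer applies them to the state. A minimal sketch of a plain Redux-style reducer, assuming a trimmed-down state and that `UPDATE_VISUALIZATION` carries a partial visualization config (the project's actual shapes may differ). Because running the model is asynchronous, `RUN_INFERENCE` here only sets `isRunning`; the inference itself would live in middleware such as a thunk. `rootReducer` and `initialState` are hypothetical names:

```typescript
// Trimmed-down slices of the state; field names follow the AppState above
interface InferenceState {
  inputText: string;
  currentStep: number;
  isRunning: boolean;
}

interface VisualizationState {
  selectedLayer: number;
  selectedHead: number;
  viewMode: '3d' | '2d';
  animationSpeed: number;
}

interface State {
  inference: InferenceState;
  visualization: VisualizationState;
}

type Action =
  | { type: 'SET_INPUT_TEXT'; payload: string }
  | { type: 'RUN_INFERENCE' }
  | { type: 'UPDATE_VISUALIZATION'; payload: Partial<VisualizationState> };

const initialState: State = {
  inference: { inputText: '', currentStep: 0, isRunning: false },
  visualization: { selectedLayer: 0, selectedHead: 0, viewMode: '3d', animationSpeed: 1 }
};

// Pure function: always returns a new state object and never mutates the old one
function rootReducer(state: State = initialState, action: Action): State {
  switch (action.type) {
    case 'SET_INPUT_TEXT':
      // New input text invalidates the previous run, so the step counter is reset
      return { ...state, inference: { ...state.inference, inputText: action.payload, currentStep: 0 } };
    case 'RUN_INFERENCE':
      return { ...state, inference: { ...state.inference, isRunning: true } };
    case 'UPDATE_VISUALIZATION':
      return { ...state, visualization: { ...state.visualization, ...action.payload } };
    default:
      return state;
  }
}
```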

Data Transformation Pipeline

Model outputs are converted into data structures the visualization can consume:

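```typescript
interface AttentionVisualizationData {
  tokens: string[];
  weights: number[][];  // weights[i][j]: attention from token i to token j
  maxWeight: number;
  minWeight: number;
}

interface EmbeddingVisualizationData {
  points: Float32Array;  // row-major [numTokens, 3]
  dimension: number;
  originalDimension: number;
}

abstract class DataTransformer {
  // The reduction method (PCA, t-SNE, ...) is left to the concrete implementation
  protected abstract reduceDimension(data: Float32Array, fromDim: number, toDim: number): Float32Array;

  // Reshape a flattened [n, n] attention matrix into rows and record its value range
  transformAttentionData(rawAttention: Float32Array, tokens: string[]): AttentionVisualizationData {
    const size = tokens.length;
    const weights: number[][] = [];
    let maxWeight = -Infinity;
    let minWeight = Infinity;

    for (let i = 0; i < size; i++) {
      const row = Array.from(rawAttention.subarray(i * size, (i + 1) * size));
      // A plain loop avoids Math.max(...array), which can overflow the call stack for long sequences
      for (const w of row) {
        if (w > maxWeight) maxWeight = w;
        if (w < minWeight) minWeight = w;
      }
      weights.push(row);
    }

    return { tokens, weights, maxWeight, minWeight };
  }

  transformEmbeddingData(embeddings: Float32Array, dimension: number): EmbeddingVisualizationData {
    // Reduce to 3D with PCA or t-SNE
    const reduced = this.reduceDimension(embeddings, dimension, 3);

    return {
      points: reduced,
      dimension: 3,
      originalDimension: dimension
    };
  }
}
```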
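`reduceDimension` is left unspecified. For PCA, one straightforward option is to center the data, build the covariance matrix, and extract the leading eigenvectors by power iteration with deflation. A minimal sketch, assuming row-major [numPoints, dim] input; `pcaReduce` and its `iterations` parameter are hypothetical, and for large embedding dimensions a dedicated linear-algebra library would be much faster:

```typescript
// Project row-major points [n, dim] onto their top `k` principal components.
// Returns a row-major [n, k] Float32Array.
function pcaReduce(data: Float32Array, dim: number, k: number, iterations = 200): Float32Array {
  const n = data.length / dim;
  if (!Number.isInteger(n) || k > dim) throw new Error('invalid shape');

  // 1. Center every feature on its mean
  const mean = new Float64Array(dim);
  for (let i = 0; i < n; i++) for (let d = 0; d < dim; d++) mean[d] += data[i * dim + d] / n;
  const x = new Float64Array(n * dim);
  for (let i = 0; i < n; i++) for (let d = 0; d < dim; d++) x[i * dim + d] = data[i * dim + d] - mean[d];

  // 2. Covariance matrix C = XᵀX / n, shape [dim, dim]
  const cov = new Float64Array(dim * dim);
  for (let i = 0; i < n; i++)
    for (let a = 0; a < dim; a++)
      for (let b = 0; b < dim; b++) cov[a * dim + b] += (x[i * dim + a] * x[i * dim + b]) / n;

  const out = new Float32Array(n * k);
  for (let c = 0; c < k; c++) {
    // 3. Power iteration: repeated multiplication by C converges to its dominant eigenvector
    let v = Float64Array.from({ length: dim }, (_, d) => 1 + d); // deterministic start vector
    let lambda = 0;
    for (let it = 0; it < iterations; it++) {
      const w = new Float64Array(dim);
      for (let a = 0; a < dim; a++) for (let b = 0; b < dim; b++) w[a] += cov[a * dim + b] * v[b];
      let norm = 0;
      for (let a = 0; a < dim; a++) norm += w[a] * w[a];
      norm = Math.sqrt(norm);
      if (norm < 1e-12) { lambda = 0; break; } // no variance left in the remaining directions
      lambda = norm;
      v = w.map((val) => val / norm);
    }
    if (lambda === 0) continue; // leave this output column at zero

    // 4. Deflate C <- C - lambda * v vᵀ so the next pass finds the next component
    for (let a = 0; a < dim; a++) for (let b = 0; b < dim; b++) cov[a * dim + b] -= lambda * v[a] * v[b];

    // 5. Project the centered points onto component c
    for (let i = 0; i < n; i++) {
      let dot = 0;
      for (let d = 0; d < dim; d++) dot += x[i * dim + d] * v[d];
      out[i * k + c] = dot;
    }
  }
  return out;
}
```

Eigenvectors are only defined up to sign, so the projected coordinates may come out mirrored; that does not change the relative layout of the points.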

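Visualization Principles

1. Attention Weight Visualization

Heatmap Rendering

The attention matrix is rendered as a color heatmap, one cell per (query token, key token) pair: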

```typescript
class AttentionHeatmap {
  private static readonly CELL_SIZE = 20; // pixels per matrix cell

  // weights[i][j]: attention from query token i (row) to key token j (column)
  createHeatmap(weights: number[][]): HTMLCanvasElement {
    const cell = AttentionHeatmap.CELL_SIZE;
    const canvas = document.createElement('canvas');
    const ctx = canvas.getContext('2d')!;
    const size = weights.length;

    canvas.width = size * cell;
    canvas.height = size * cell;

    for (let i = 0; i < size; i++) {
      for (let j = 0; j < size; j++) {
        ctx.fillStyle = this.weightToColor(weights[i][j]);
        ctx.fillRect(j * cell, i * cell, cell, cell);
      }
    }

    return canvas;
  }

  // Blue (weight 0) to red (weight 1) gradient; values outside [0, 1] are clamped
  private weightToColor(weight: number): string {
    const w = Math.min(1, Math.max(0, weight));
    const r = Math.floor(w * 255);
    const b = Math.floor((1 - w) * 255);
    return `rgb(${r}, 50, ${b})`;
  }
}
```
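`weightToColor` expects weights in [0, 1]. In a long sequence each softmax row is spread over many tokens, so most raw weights are small and the map looks almost uniformly blue. One common remedy, sketched here as an option rather than the project's method, is to rescale by the matrix's minimum and maximum (for example the `minWeight`/`maxWeight` computed by `DataTransformer`) before coloring; `normalizeWeight` and `weightToColorNormalized` are hypothetical helpers:

```typescript
// Linearly rescale a weight from [min, max] to [0, 1], clamping out-of-range values.
function normalizeWeight(weight: number, min: number, max: number): number {
  if (max <= min) return 0; // constant matrix: no contrast to show
  const t = (weight - min) / (max - min);
  return Math.min(1, Math.max(0, t));
}

// Same blue-to-red ramp as weightToColor, applied to the normalized value
function weightToColorNormalized(weight: number, min: number, max: number): string {
  const t = normalizeWeight(weight, min, max);
  const r = Math.floor(t * 255);
  const b = Math.floor((1 - t) * 255);
  return `rgb(${r}, 50, ${b})`;
}
```

Min/max scaling maximizes contrast within one matrix, but colors are then not comparable across heads or layers; a fixed shared range avoids that at the cost of contrast.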

2. Visualizing the Token Generation Process

Step-by-Step Animation

Token generation is shown one step at a time:

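```typescript
// Data recorded for one generation step (field types are illustrative)
interface GenerationStep {
  inputTokens: number[];
  attentionWeights: Float32Array;
  hiddenStates: Float32Array;
  probabilities: Float32Array;
  outputToken: number;
}

abstract class TokenGenerationAnimation {
  protected currentStep = 0;

  constructor(private readonly steps: GenerationStep[]) {}

  // Play the steps in order; each step finishes before the next one begins
  async animateGeneration(): Promise<void> {
    for (const step of this.steps) {
      await this.animateStep(step);
      this.currentStep++;
    }
  }

  private async animateStep(step: GenerationStep): Promise<void> {
    // 1. Highlight the current input tokens
    await this.highlightInputTokens(step.inputTokens);

    // 2. Show the attention computation
    await this.showAttentionComputation(step.attentionWeights);

    // 3. Show the feed-forward network computation
    await this.showFeedForward(step.hiddenStates);

    // 4. Show the output probability distribution
    await this.showProbabilityDistribution(step.probabilities);

    // 5. Emit the newly generated token
    await this.generateToken(step.outputToken);
  }

  // Scene-specific animations, each resolving when its animation completes
  protected abstract highlightInputTokens(tokens: number[]): Promise<void>;
  protected abstract showAttentionComputation(weights: Float32Array): Promise<void>;
  protected abstract showFeedForward(hiddenStates: Float32Array): Promise<void>;
  protected abstract showProbabilityDistribution(probabilities: Float32Array): Promise<void>;
  protected abstract generateToken(token: number): Promise<void>;
}
```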
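The animation replays `GenerationStep` records that already exist. A minimal sketch of how such records could be produced by an autoregressive greedy loop, assuming a `forward` callback that maps the current token ids to next-token probabilities (in the app this would wrap the ONNX session). `recordGeneration`, `ForwardFn`, and the trimmed `RecordedStep` type are hypothetical, and attention and hidden states are omitted for brevity:

```typescript
// Trimmed-down generation step: only what is needed to replay token selection
interface RecordedStep {
  inputTokens: number[];
  probabilities: Float32Array;
  outputToken: number;
}

// Maps the current token sequence to a probability for every vocabulary entry
type ForwardFn = (tokens: number[]) => Promise<Float32Array>;

// Generate up to `maxNewTokens` tokens greedily, recording every step for later replay.
async function recordGeneration(
  forward: ForwardFn,
  prompt: number[],
  maxNewTokens: number,
  eosToken?: number
): Promise<RecordedStep[]> {
  const steps: RecordedStep[] = [];
  const tokens = [...prompt];
  for (let n = 0; n < maxNewTokens; n++) {
    const probabilities = await forward(tokens);
    // Greedy choice: the most probable next token
    let outputToken = 0;
    for (let v = 1; v < probabilities.length; v++) {
      if (probabilities[v] > probabilities[outputToken]) outputToken = v;
    }
    steps.push({ inputTokens: [...tokens], probabilities, outputToken });
    tokens.push(outputToken); // the new token becomes part of the next step's input
    if (outputToken === eosToken) break;
  }
  return steps;
}
```

Copying `tokens` into each record keeps every step's input frozen, so replaying step 3 later shows exactly the context the model saw at that point.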

3. Network Layer Visualization

Layered Display

The Transformer's layers are displayed as a vertical 3D stack:

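```typescript
import * as THREE from 'three';

class LayerVisualization {
  private static readonly LAYER_SPACING = 3; // vertical distance between layers

  createLayerStack(numLayers: number): THREE.Group {
    const group = new THREE.Group();

    for (let i = 0; i < numLayers; i++) {
      const layer = this.createLayer(i);
      layer.position.y = i * LayerVisualization.LAYER_SPACING;
      group.add(layer);

      // Connect the previous layer to this one
      if (i > 0) {
        group.add(this.createLayerConnections(i - 1, i));
      }
    }

    return group;
  }

  private createLayer(layerIndex: number): THREE.Group {
    const layer = new THREE.Group();
    layer.name = `layer-${layerIndex}`;

    // Self-attention sub-layer at the bottom of the layer
    layer.add(this.createAttentionBlock(layerIndex));

    // Feed-forward sub-layer stacked 1.5 units above it
    const ffn = this.createFFNBlock(layerIndex);
    ffn.position.y = 1.5;
    layer.add(ffn);

    return layer;
  }

  // Placeholder geometry: each sub-layer is a flat box 1 unit tall
  private createAttentionBlock(_layerIndex: number): THREE.Mesh {
    return new THREE.Mesh(new THREE.BoxGeometry(4, 1, 2), new THREE.MeshPhongMaterial({ color: 0x4ecdc4 }));
  }

  private createFFNBlock(_layerIndex: number): THREE.Mesh {
    return new THREE.Mesh(new THREE.BoxGeometry(4, 1, 2), new THREE.MeshPhongMaterial({ color: 0xffa94d }));
  }

  // Vertical line from the top of layer `from` to the bottom of layer `to`
  private createLayerConnections(from: number, to: number): THREE.Line {
    const s = LayerVisualization.LAYER_SPACING;
    const points = [new THREE.Vector3(0, from * s + 2, 0), new THREE.Vector3(0, to * s - 0.5, 0)];
    return new THREE.Line(
      new THREE.BufferGeometry().setFromPoints(points),
      new THREE.LineBasicMaterial({ color: 0x888888 })
    );
  }
}
```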

Performance Optimization

1. Rendering Optimization

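```typescript
import * as THREE from 'three';

class RenderOptimizer {
  // Draw many identical spheres in a single draw call with InstancedMesh
  createInstancedNodes(positions: THREE.Vector3[]): THREE.InstancedMesh {
    const geometry = new THREE.SphereGeometry(0.5, 16, 16);
    const material = new THREE.MeshPhongMaterial({ color: 0x4ecdc4 });
    const mesh = new THREE.InstancedMesh(geometry, material, positions.length);

    // Each instance needs its own transform; without this all instances sit at the origin
    const matrix = new THREE.Matrix4();
    positions.forEach((p, i) => {
      matrix.makeTranslation(p.x, p.y, p.z);
      mesh.setMatrixAt(i, matrix);
    });
    mesh.instanceMatrix.needsUpdate = true;

    return mesh;
  }

  // LOD (level of detail): switch to cheaper meshes as the camera moves away
  setupLOD(
    highDetailModel: THREE.Object3D,
    mediumDetailModel: THREE.Object3D,
    lowDetailModel: THREE.Object3D,
    distances: [number, number]
  ): THREE.LOD {
    const lod = new THREE.LOD();

    lod.addLevel(highDetailModel, 0);               // closer than distances[0]
    lod.addLevel(mediumDetailModel, distances[0]);  // from distances[0] to distances[1]
    lod.addLevel(lowDetailModel, distances[1]);     // beyond distances[1]

    return lod;
  }
}
```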

2. Computation Optimization

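```typescript
class ComputationOptimizer {
  // Offload heavy computation to a Web Worker so the rendering thread stays responsive
  private worker: Worker;
  private nextId = 0;
  // Pending requests keyed by id; assumes the worker replies with { id, result }
  private pending = new Map<number, { resolve: (r: Float32Array) => void; reject: (e: unknown) => void }>();

  constructor() {
    this.worker = new Worker('inference-worker.js');
    // One shared handler routes each reply to the promise that requested it,
    // so concurrent calls cannot resolve each other's promises
    this.worker.onmessage = (e: MessageEvent<{ id: number; result: Float32Array }>) => {
      const entry = this.pending.get(e.data.id);
      if (entry) {
        this.pending.delete(e.data.id);
        entry.resolve(e.data.result);
      }
    };
    this.worker.onerror = (err) => {
      for (const { reject } of this.pending.values()) reject(err);
      this.pending.clear();
    };
  }

  computeInWorker(data: Float32Array): Promise<Float32Array> {
    return new Promise((resolve, reject) => {
      const id = this.nextId++;
      this.pending.set(id, { resolve, reject });
      this.worker.postMessage({ id, data });
    });
  }

  // Cache computed results by key
  private cache: Map<string, Float32Array> = new Map();

  getCachedResult(key: string): Float32Array | undefined {
    return this.cache.get(key);
  }

  setCachedResult(key: string, value: Float32Array): void {
    this.cache.set(key, value);
  }
}
```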
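The `Map` cache above grows without bound, and per-layer activations can be large. A common fix is a least-recently-used (LRU) cache with a fixed capacity. The sketch below relies on `Map` iterating in insertion order, so the first key is always the least recently used one; `LruCache` and its `capacity` parameter are hypothetical:

```typescript
// Fixed-capacity cache that evicts the least recently used entry.
// Map iterates in insertion order, so re-inserting a key on access moves it to the end.
class LruCache<K, V> {
  private readonly map = new Map<K, V>();

  constructor(private readonly capacity: number) {
    if (capacity < 1) throw new Error('capacity must be at least 1');
  }

  get(key: K): V | undefined {
    if (!this.map.has(key)) return undefined;
    const value = this.map.get(key)!;
    // Mark as most recently used
    this.map.delete(key);
    this.map.set(key, value);
    return value;
  }

  set(key: K, value: V): void {
    if (this.map.has(key)) this.map.delete(key);
    this.map.set(key, value);
    if (this.map.size > this.capacity) {
      // The first key in iteration order is the least recently used one
      const oldest = this.map.keys().next().value as K;
      this.map.delete(oldest);
    }
  }

  get size(): number {
    return this.map.size;
  }
}
```

In `ComputationOptimizer`, `cache` could then be created as `new LruCache<string, Float32Array>(64)`, with the capacity (64 here is arbitrary) tuned to the memory budget.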

Extension Architecture

Plugin System

Custom visualizations can be added through plugins:

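```typescript
// Placeholder types; their concrete shape is application-specific
type PluginContext = Record<string, unknown>;
type VisualizationRegistry = Record<string, unknown>;
type ModelOutput = Record<string, unknown>;
type VisualizationData = Record<string, unknown>;

interface VisualizationPlugin {
  name: string;
  version: string;

  // Initialize the plugin
  initialize(context: PluginContext): void;

  // Register visualization components
  registerVisualizations(registry: VisualizationRegistry): void;

  // Process model output
  processOutput(output: ModelOutput): VisualizationData;
}

class PluginManager {
  private plugins: Map<string, VisualizationPlugin> = new Map();

  // The shared context is supplied once and handed to every plugin
  constructor(private readonly context: PluginContext) {}

  registerPlugin(plugin: VisualizationPlugin): void {
    if (this.plugins.has(plugin.name)) {
      throw new Error(`Plugin "${plugin.name}" is already registered`);
    }
    this.plugins.set(plugin.name, plugin);
    plugin.initialize(this.context);
  }

  getPlugin(name: string): VisualizationPlugin | undefined {
    return this.plugins.get(name);
  }
}
```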

Next Steps

With the architecture and principles covered, continue with the detailed guide to LLM Visualization's core features.