
LLM Visualization: Core Features in Detail

1. Visualizing the Token Generation Process

1.1 Visualizing Input Processing

Shows how text is converted into tokens the model can process:

Input text: "Hello world"

Tokenization (GPT-2 BPE token IDs): [15496, 995]

Embedding layer (illustrative values): [[0.23, -0.45, 0.67, ...], [0.12, 0.89, -0.34, ...]]

Add positional encoding: [[0.23, -0.45, 0.67, ...] + pos_enc(0),
                          [0.12, 0.89, -0.34, ...] + pos_enc(1)]
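The embedding step above amounts to a table lookup followed by an element-wise sum. A minimal sketch, assuming GPT-2-style learned position embeddings (one `positionTable` row per position, added to the token's row in `tokenTable`); `embedTokens` and both table names are hypothetical, and real tables have hundreds of columns per row:

```typescript
// Hypothetical helper: look up each token's embedding and add the embedding
// for its position (GPT-2-style learned position embeddings).
function embedTokens(
  tokenIds: number[],
  tokenTable: number[][],    // [vocabSize][dModel]
  positionTable: number[][], // [maxLen][dModel]
): number[][] {
  if (tokenIds.length > positionTable.length) {
    throw new Error(`sequence length ${tokenIds.length} exceeds maxLen ${positionTable.length}`);
  }
  return tokenIds.map((id, pos) => {
    const tokenVec = tokenTable[id];
    if (tokenVec === undefined) throw new Error(`unknown token id ${id}`);
    // Element-wise sum: token meaning + position information
    return tokenVec.map((v, d) => v + positionTable[pos][d]);
  });
}
```

With full-size GPT-2 tables, `embedTokens([15496, 995], tokenTable, positionTable)` would return two rows, one per token, each already carrying its position.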

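Visual effects:

  • Each token is rendered as a colored sphere
  • Color encodes semantic similarity between tokens
  • Embedding vectors are laid out in 3D space to show their distribution
  • Positional encoding is overlaid as a ripple effect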
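The color-encoding bullet does not say how similarity becomes a color. One simple option, sketched here as an assumption, is to color each token by its cosine similarity to a chosen reference token and map the range -1..1 onto a hue range; `cosineSimilarity` and `similarityToColor` are hypothetical helpers:

```typescript
// Cosine similarity between two embedding vectors, in [-1, 1].
function cosineSimilarity(a: ArrayLike<number>, b: ArrayLike<number>): number {
  if (a.length !== b.length) throw new Error('dimension mismatch');
  let dot = 0;
  let normA = 0;
  let normB = 0;
  for (let i = 0; i < a.length; i++) {
    dot += a[i] * b[i];
    normA += a[i] * a[i];
    normB += b[i] * b[i];
  }
  if (normA === 0 || normB === 0) return 0; // treat zero vectors as unrelated
  return dot / Math.sqrt(normA * normB);
}

// Map similarity -1..1 to an HSL color: hue 240 (blue, dissimilar) .. 0 (red, similar).
function similarityToColor(similarity: number): string {
  const clamped = Math.max(-1, Math.min(1, similarity));
  const hue = Math.round(240 * (1 - (clamped + 1) / 2));
  return `hsl(${hue}, 70%, 55%)`;
}
```

Accepting `ArrayLike<number>` lets the same function take either `number[]` or `Float32Array` embeddings.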

1.2 Step-by-Step Generation Animation

Shows, in real time, how the model generates tokens one at a time:

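```typescript
interface GenerationStep {
  stepNumber: number;
  inputTokens: number[];
  attentionWeights: number[][];
  hiddenStates: Float32Array;
  logits: Float32Array;
  probabilities: Float32Array;
  selectedToken: number;
  outputToken: string;
}

// Model access and rendering are left to a concrete subclass.
abstract class TokenGenerationVisualizer {
  protected abstract tokenize(text: string): number[];
  protected abstract runGenerationStep(tokens: number[]): Promise<GenerationStep>;
  protected abstract visualizeAttention(weights: number[][]): void;
  protected abstract visualizeProbabilityDistribution(probs: Float32Array): void;
  protected abstract highlightSelectedToken(tokenId: number): void;
  protected abstract appendOutputToken(token: string): void;

  // maxTokens was undefined in the original; the default of 20 is illustrative
  async visualizeGeneration(inputText: string, maxTokens: number = 20): Promise<void> {
    const tokens = this.tokenize(inputText);

    for (let i = 0; i < maxTokens; i++) {
      // Run one forward pass over the current sequence
      const step = await this.runGenerationStep(tokens);

      // 1. Show the attention computation
      this.visualizeAttention(step.attentionWeights);

      // 2. Show the probability distribution
      this.visualizeProbabilityDistribution(step.probabilities);

      // 3. Highlight the selected token
      this.highlightSelectedToken(step.selectedToken);

      // 4. Update the output
      this.appendOutputToken(step.outputToken);

      // Feed the new token back in: generation is autoregressive
      tokens.push(step.selectedToken);
    }
  }
}
```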

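Animation flow:

  1. Input: visualize the current input sequence
  2. Attention focus: show which tokens the model attends to
  3. Computation: show data flowing through the matrix operations
  4. Probability distribution: show a bar chart of candidate-token probabilities
  5. Sampling: show the effect of temperature and top-p sampling
  6. Output: append the new token to the output sequence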
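Step 5 mentions temperature and top-p. A minimal sketch of how they are commonly applied to a logits vector: divide by the temperature, take the softmax, keep the smallest set of most probable tokens whose cumulative probability reaches `topP` (nucleus sampling), renormalize, and sample. The function names are hypothetical, and `rand` is injectable so results can be reproduced:

```typescript
// Softmax with temperature: T < 1 sharpens the distribution, T > 1 flattens it.
function softmaxWithTemperature(logits: number[], temperature: number): number[] {
  if (temperature <= 0) throw new Error('temperature must be > 0');
  const scaled = logits.map((l) => l / temperature);
  // Subtract the max for numerical stability (reduce avoids spreading a huge vocab)
  const max = scaled.reduce((m, s) => Math.max(m, s), -Infinity);
  const exps = scaled.map((s) => Math.exp(s - max));
  const sum = exps.reduce((a, b) => a + b, 0);
  return exps.map((e) => e / sum);
}

// Nucleus (top-p) filtering: keep the most probable tokens until their
// cumulative probability reaches topP, zero out the rest, and renormalize.
function topPFilter(probs: number[], topP: number): number[] {
  const order = probs.map((_, i) => i).sort((a, b) => probs[b] - probs[a]);
  const keep = new Set<number>();
  let cumulative = 0;
  for (const i of order) {
    keep.add(i);
    cumulative += probs[i];
    if (cumulative >= topP) break;
  }
  const kept = probs.map((p, i) => (keep.has(i) ? p : 0));
  const total = kept.reduce((a, b) => a + b, 0);
  return kept.map((p) => p / total);
}

// Draw one token id from a probability vector.
function sampleToken(probs: number[], rand: () => number = Math.random): number {
  const r = rand();
  let cumulative = 0;
  let lastNonZero = -1;
  for (let i = 0; i < probs.length; i++) {
    if (probs[i] > 0) lastNonZero = i;
    cumulative += probs[i];
    if (r < cumulative) return i;
  }
  return lastNonZero; // guard against floating-point round-off
}
```

A full pipeline would call `sampleToken(topPFilter(softmaxWithTemperature(logits, temperature), topP))`; recomputing only the first two steps when a slider moves is enough to redraw the bar chart.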

1.3 Probability Distribution Visualization

Shows the probability distribution the model predicts at each position:

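```typescript
interface TokenProbability {
  token: string;
  probability: number;
}

abstract class ProbabilityVisualizer {
  // decode maps a token id to its text (supplied by the tokenizer)
  constructor(private readonly decode: (id: number) => string) {}

  // Slider UIs are left to a concrete subclass
  protected abstract addTemperatureControl(): void;
  protected abstract addTopPControl(): void;

  createProbabilityChart(probabilities: number[], topK: number = 10) {
    // Take the topK most probable tokens
    const topTokens = this.getTopKTokens(probabilities, topK);

    // Bar-chart config (Chart.js-style); the most probable token is highlighted
    const chart = {
      type: 'bar' as const,
      data: {
        labels: topTokens.map(t => t.token),
        datasets: [{
          label: 'Probability',
          data: topTokens.map(t => t.probability),
          backgroundColor: topTokens.map((_, i) =>
            i === 0 ? '#4ecdc4' : '#95a5a6'
          )
        }]
      }
    };

    // Add the temperature slider
    this.addTemperatureControl();

    // Add the top-p slider
    this.addTopPControl();

    // Return the config so the caller can render it (it was discarded before)
    return chart;
  }

  // Sort ids by probability first and decode only the k survivors
  private getTopKTokens(probabilities: number[], k: number): TokenProbability[] {
    return probabilities
      .map((probability, id) => ({ id, probability }))
      .sort((a, b) => b.probability - a.probability)
      .slice(0, k)
      .map(({ id, probability }) => ({ token: this.decode(id), probability }));
  }
}
```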

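Interactive features:

  • Adjust temperature and watch the distribution change
  • Adjust top-p to see the effect of nucleus sampling
  • Hover to see exact probability values
  • Click a token to force its selection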

2. Attention Mechanism Visualization

2.1 Single-Head Attention

Shows the weight distribution of a single attention head:

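```typescript
interface AttentionVisualization {
  tokens: string[];
  weights: number[][];  // [target, source]; each row sums to 1
  query: Float32Array;
  key: Float32Array;
  value: Float32Array;
}

abstract class SingleHeadVisualizer {
  // Label and connection rendering are left to a concrete subclass
  protected abstract addTokenLabels(tokens: string[]): void;
  protected abstract addConnection(target: number, source: number, weight: number): void;

  visualizeAttention(data: AttentionVisualization): HTMLCanvasElement {
    const size = data.tokens.length;

    // Draw the heatmap
    const heatmap = this.createHeatmap(data.weights);

    // Add token labels
    this.addTokenLabels(data.tokens);

    // Draw a connection line for every weight above 0.1
    for (let i = 0; i < size; i++) {
      for (let j = 0; j < size; j++) {
        if (data.weights[i][j] > 0.1) {
          this.addConnection(i, j, data.weights[i][j]);
        }
      }
    }

    return heatmap;
  }

  // Map a weight in [0, 1] to a color from white (0) to blue (1)
  private getHeatmapColor(intensity: number): string {
    const v = Math.round(255 * (1 - Math.min(Math.max(intensity, 0), 1)));
    return `rgb(${v}, ${v}, 255)`;
  }

  private createHeatmap(weights: number[][]): HTMLCanvasElement {
    const canvas = document.createElement('canvas');
    const ctx = canvas.getContext('2d')!;
    const cellSize = 30;

    canvas.width = weights.length * cellSize;
    canvas.height = weights.length * cellSize;

    // Rows are target tokens (y), columns are source tokens (x)
    for (let i = 0; i < weights.length; i++) {
      for (let j = 0; j < weights.length; j++) {
        const intensity = weights[i][j];
        ctx.fillStyle = this.getHeatmapColor(intensity);
        ctx.fillRect(j * cellSize, i * cellSize, cellSize, cellSize);

        // Print the value in cells whose weight exceeds 0.3
        if (intensity > 0.3) {
          ctx.fillStyle = intensity > 0.6 ? 'white' : 'black';
          ctx.font = '10px Arial';
          ctx.textAlign = 'center';
          ctx.textBaseline = 'middle';
          ctx.fillText(
            intensity.toFixed(2),
            j * cellSize + cellSize / 2,
            i * cellSize + cellSize / 2
          );
        }
      }
    }

    return canvas;
  }
}
```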
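The heatmap consumes a ready-made `weights` matrix. As a sketch of where one head's matrix comes from, here is the standard scaled dot-product computation softmax(QKᵀ / √d_k), with an optional causal mask as used in decoder-only (GPT-style) models. `attentionWeights` is a hypothetical helper, and Q and K are assumed to be plain per-token arrays rather than the flat `Float32Array` fields of `AttentionVisualization`:

```typescript
// Attention weights for one head: weights[target][source] is the softmax over
// sources of (q_target · k_source) / sqrt(d_k). With causal = true, a token
// may only attend to itself and earlier positions.
function attentionWeights(q: number[][], k: number[][], causal = true): number[][] {
  const dK = q[0].length;
  const scale = 1 / Math.sqrt(dK);
  return q.map((qRow, target) => {
    const scores = k.map((kRow, source) => {
      if (causal && source > target) return -Infinity; // masked out
      let dot = 0;
      for (let d = 0; d < dK; d++) dot += qRow[d] * kRow[d];
      return dot * scale;
    });
    const max = scores.reduce((m, s) => Math.max(m, s), -Infinity);
    const exps = scores.map((s) => Math.exp(s - max)); // exp(-Infinity) is 0
    const sum = exps.reduce((a, b) => a + b, 0);
    return exps.map((e) => e / sum);
  });
}
```

Each row corresponds to one query (target) token, matching the `[target, source]` layout the heatmap expects, so the result can be passed as `weights` directly.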

2.2 Comparing Multiple Heads

Displays several attention heads side by side to highlight how their patterns differ:

```typescript
abstract class MultiHeadVisualizer {
  // Renders one head's attention matrix (e.g. as a heatmap)
  protected abstract createHeadVisualization(
    headData: { tokens: string[]; weights: number[][] }
  ): HTMLElement;

  visualizeAllHeads(
    tokens: string[],
    allHeadWeights: number[][][][],  // [layer][head][target][source]
    numLayers: number,
    numHeads: number
  ): HTMLDivElement {
    const container = document.createElement('div');
    container.className = 'multi-head-grid';

    for (let layer = 0; layer < numLayers; layer++) {
      for (let head = 0; head < numHeads; head++) {
        const headData = {
          tokens,
          weights: allHeadWeights[layer][head]
        };

        const headViz = this.createHeadVisualization(headData);
        headViz.className = `head-viz layer-${layer} head-${head}`;

        // Add a title (1-based for display)
        const title = document.createElement('div');
        title.textContent = `Layer ${layer + 1}, Head ${head + 1}`;
        headViz.appendChild(title);

        container.appendChild(headViz);
      }
    }

    return container;
  }
}
```

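Visual effects:

  • A grid layout shows every attention head
  • Different colors distinguish different patterns
  • Heads can be filtered by layer
  • Heads can be clustered by attention pattern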
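The list mentions clustering heads by pattern without saying what is measured. One commonly used summary statistic, offered here as an assumption rather than this tool's method, is the mean entropy of a head's attention rows: near 0 for heads that lock onto a single token, near log(n) for heads that spread attention evenly over n tokens. `meanAttentionEntropy` is a hypothetical helper:

```typescript
// Mean Shannon entropy (in nats) of the rows of an attention matrix.
// Each row is a probability distribution over source tokens.
function meanAttentionEntropy(weights: number[][]): number {
  let total = 0;
  for (const row of weights) {
    let h = 0;
    for (const p of row) {
      if (p > 0) h -= p * Math.log(p); // 0 * log 0 is taken as 0
    }
    total += h;
  }
  return weights.length === 0 ? 0 : total / weights.length;
}
```

Sorting or binning heads by this value gives a simple one-dimensional grouping. In a causal model, row i can spread over at most i + 1 tokens, so early rows cap the attainable entropy.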

2.3 Attention Pattern Analysis

Identifies and labels different attention patterns:

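```typescript
enum AttentionPattern {
  DIAGONAL = 'diagonal',  // attends to neighboring tokens
  VERTICAL = 'vertical',  // attends to one specific position
  BLOCK = 'block',        // attends to a span of tokens
  SPARSE = 'sparse',      // attention scattered over a few positions
  GLOBAL = 'global'       // attends to all tokens
}

class AttentionPatternAnalyzer {
  analyzePattern(weights: number[][]): AttentionPattern {
    // Share of attention mass near the diagonal
    const diagonalScore = this.calculateDiagonalScore(weights);

    // Largest share of mass received by a single source column
    const verticalScore = this.calculateVerticalScore(weights);

    // Share of near-zero weights
    const sparsity = this.calculateSparsity(weights);

    // Classify with heuristic thresholds
    if (diagonalScore > 0.7) return AttentionPattern.DIAGONAL;
    if (verticalScore > 0.6) return AttentionPattern.VERTICAL;
    if (sparsity > 0.8) return AttentionPattern.SPARSE;
    if (sparsity < 0.3) return AttentionPattern.GLOBAL;

    return AttentionPattern.BLOCK;
  }

  private calculateDiagonalScore(weights: number[][]): number {
    let diagonalSum = 0;
    let totalSum = 0;

    for (let i = 0; i < weights.length; i++) {
      for (let j = 0; j < weights[i].length; j++) {
        totalSum += weights[i][j];
        // Within 2 positions of the diagonal
        if (Math.abs(i - j) <= 2) {
          diagonalSum += weights[i][j];
        }
      }
    }

    return totalSum > 0 ? diagonalSum / totalSum : 0;
  }

  // Assumed definition: largest column sum divided by total mass
  private calculateVerticalScore(weights: number[][]): number {
    const columnSums = new Array<number>(weights[0]?.length ?? 0).fill(0);
    let totalSum = 0;
    for (const row of weights) {
      row.forEach((w, j) => { columnSums[j] += w; totalSum += w; });
    }
    return totalSum > 0 ? Math.max(...columnSums) / totalSum : 0;
  }

  // Assumed definition: share of weights below `threshold`. For causal
  // (masked) attention about half the entries are 0, so consider counting
  // only the lower triangle, or GLOBAL will never be detected.
  private calculateSparsity(weights: number[][], threshold = 0.01): number {
    let small = 0;
    let count = 0;
    for (const row of weights) {
      for (const w of row) {
        if (w < threshold) small++;
        count++;
      }
    }
    return count > 0 ? small / count : 0;
  }
}
```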

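Common attention patterns:

| Pattern | Description | Typical role |
|---|---|---|
| Diagonal | Attends mainly to nearby tokens | Local syntactic structure |
| Vertical | Attends to a specific position (e.g. [CLS]) | Sentence-level representation |
| Block | Attends to a contiguous span of tokens | Phrases or entities |
| Sparse | Spreads attention over several scattered positions | Long-range dependencies |
| Global | Attends roughly uniformly to all tokens | Global context |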

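3. Layer State Visualization

3.1 Residual Connections

Shows how residual connections carry the input around each sublayer, which also gives gradients a direct path backward: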

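```typescript
type Vec3 = { x: number; y: number; z: number };

abstract class ResidualConnectionVisualizer {
  // Scene primitives are left to a concrete subclass (e.g. built on three.js)
  protected abstract createNode(label: string, values: Float32Array, color: number): { position: Vec3 };
  protected abstract createAnimatedLine(from: Vec3, to: Vec3, color: number, style: 'dash' | 'solid'): unknown;
  protected abstract showAdditionAnimation(input: Float32Array, sublayerOutput: Float32Array, finalOutput: Float32Array): void;

  visualizeResidualFlow(
    input: Float32Array,
    sublayerOutput: Float32Array,
    finalOutput: Float32Array
  ) {
    // Create the three nodes (colors are 0xRRGGBB hex literals)
    const inputNode = this.createNode('Input', input, 0xff6b6b);
    const sublayerNode = this.createNode('Sublayer', sublayerOutput, 0x4ecdc4);
    const outputNode = this.createNode('Output', finalOutput, 0x45b7d1);

    // Residual (skip) connection straight from input to output
    const residualLine = this.createAnimatedLine(
      inputNode.position,
      outputNode.position,
      0xff6b6b,
      'dash'
    );

    // Sublayer output feeding into the sum
    const sublayerLine = this.createAnimatedLine(
      sublayerNode.position,
      outputNode.position,
      0x4ecdc4,
      'solid'
    );

    // Animate the element-wise addition: output = input + sublayer(input)
    this.showAdditionAnimation(input, sublayerOutput, finalOutput);

    return { residualLine, sublayerLine };
  }
}
```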
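The addition animated by `showAdditionAnimation` is an element-wise sum, output = x + Sublayer(x). A minimal sketch with a hypothetical `residualAdd` helper; in pre-LN Transformers such as GPT-2 the sum is used as-is, whereas the original post-LN Transformer applies LayerNorm after it:

```typescript
// Residual (skip) connection: add the sublayer's output back onto its input.
function residualAdd(input: Float32Array, sublayerOutput: Float32Array): Float32Array {
  if (input.length !== sublayerOutput.length) {
    throw new Error('residual branches must have the same width (d_model)');
  }
  const out = new Float32Array(input.length);
  for (let i = 0; i < input.length; i++) {
    out[i] = input[i] + sublayerOutput[i];
  }
  return out;
}
```

Because the identity path passes `input` through unchanged, the derivative of the output with respect to the input contains an identity term on top of the sublayer's own Jacobian, which is why the skip path is drawn as a direct line.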

3.2 Layer Normalization

Shows how Layer Norm re-centers and rescales activations, which helps stabilize training:

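```typescript
abstract class LayerNormVisualizer {
  // Chart rendering is left to a concrete subclass
  protected abstract showDistribution(values: Float32Array, title: string, phase: 'before' | 'after'): void;
  protected abstract showStatistics(mean: number, variance: number): void;
  protected abstract showScaleShift(gamma: Float32Array, beta: Float32Array): void;

  visualizeNormalization(
    input: Float32Array,
    normalized: Float32Array,
    gamma: Float32Array,
    beta: Float32Array
  ) {
    // Show the input distribution
    this.showDistribution(input, 'Input Distribution', 'before');

    // Compute and show the statistics over the feature dimension
    const mean = this.calculateMean(input);
    const variance = this.calculateVariance(input, mean);

    this.showStatistics(mean, variance);

    // Show the normalized distribution
    this.showDistribution(normalized, 'Normalized Distribution', 'after');

    // Show the effect of the learned scale (gamma) and shift (beta)
    this.showScaleShift(gamma, beta);
  }

  private calculateMean(x: Float32Array): number {
    return x.reduce((sum, v) => sum + v, 0) / x.length;
  }

  // Biased (population) variance, which is what LayerNorm uses
  private calculateVariance(x: Float32Array, mean: number): number {
    return x.reduce((sum, v) => sum + (v - mean) ** 2, 0) / x.length;
  }
}
```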
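For reference, LayerNorm normalizes each vector over its feature dimension and then applies the learned scale and shift: y_i = gamma_i * (x_i - mu) / sqrt(sigma^2 + eps) + beta_i, where mu and sigma^2 are the mean and biased variance of x. A minimal sketch; `layerNorm` is a hypothetical name, and eps = 1e-5 is a common default rather than a value taken from this tool:

```typescript
function layerNorm(
  x: Float32Array,
  gamma: Float32Array,
  beta: Float32Array,
  eps = 1e-5,
): Float32Array {
  const n = x.length;

  // Mean over the feature dimension
  let mean = 0;
  for (let i = 0; i < n; i++) mean += x[i];
  mean /= n;

  // Biased (population) variance, as LayerNorm uses
  let variance = 0;
  for (let i = 0; i < n; i++) variance += (x[i] - mean) ** 2;
  variance /= n;

  // Normalize, then apply the learned per-feature scale and shift
  const invStd = 1 / Math.sqrt(variance + eps);
  const y = new Float32Array(n);
  for (let i = 0; i < n; i++) {
    y[i] = gamma[i] * (x[i] - mean) * invStd + beta[i];
  }
  return y;
}
```

With gamma = 1 and beta = 0 the output has mean 0 and variance just under 1 (eps keeps it from being exactly 1); a trained gamma and beta then move the distribution away from that, which is what `showScaleShift` visualizes.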

3.3 Feed-Forward Network

Shows the transformations inside the FFN:

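```typescript
abstract class FFNVisualizer {
  // Rendering is left to a concrete subclass
  protected abstract showMatrixMultiplication(x: Float32Array, w: Float32Array, result: Float32Array): void;
  protected abstract showActivation(values: Float32Array, activation: 'gelu' | 'relu'): void;
  protected abstract showDimensionChange(inputDim: number, hiddenDim: number, outputDim: number): void;

  visualizeFFN(
    input: Float32Array,
    hidden: Float32Array,
    output: Float32Array,
    weights1: Float32Array,
    weights2: Float32Array
  ) {
    // First projection (d_model -> d_ff), then the activation
    this.showMatrixMultiplication(input, weights1, hidden);
    this.showActivation(hidden, 'gelu');

    // Second projection (d_ff -> d_model)
    this.showMatrixMultiplication(hidden, weights2, output);

    // Show the dimension change (typically d_ff = 4 * d_model)
    this.showDimensionChange(
      input.length,
      hidden.length,
      output.length
    );
  }
}
```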
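The visualizer above only animates the steps. A sketch of the arithmetic behind them, assuming row-major flat weight matrices (`w1` of shape d_model x d_ff, `w2` of shape d_ff x d_model) plus bias vectors, which the original signature omits, and the tanh approximation of GELU that GPT-2 uses; all three function names are hypothetical:

```typescript
// GELU, tanh approximation: 0.5x(1 + tanh(sqrt(2/pi)(x + 0.044715x^3)))
function gelu(x: number): number {
  return 0.5 * x * (1 + Math.tanh(Math.sqrt(2 / Math.PI) * (x + 0.044715 * x ** 3)));
}

// y = x W + b, with W stored row-major as [inDim][outDim] in a flat array.
function matVec(x: Float32Array, w: Float32Array, b: Float32Array, outDim: number): Float32Array {
  const inDim = x.length;
  if (w.length !== inDim * outDim) throw new Error('weight shape mismatch');
  const y = new Float32Array(outDim);
  for (let j = 0; j < outDim; j++) {
    let sum = b[j];
    for (let i = 0; i < inDim; i++) sum += x[i] * w[i * outDim + j];
    y[j] = sum;
  }
  return y;
}

// Position-wise FFN: d_model -> d_ff (GELU) -> d_model.
function feedForward(
  x: Float32Array,
  w1: Float32Array, b1: Float32Array,
  w2: Float32Array, b2: Float32Array,
): { hidden: Float32Array; output: Float32Array } {
  const dFF = b1.length;
  const hidden = matVec(x, w1, b1, dFF).map(gelu);
  const output = matVec(hidden, w2, b2, x.length);
  return { hidden, output };
}
```

Returning `hidden` alongside `output` gives the visualizer both tensors it needs for its `hidden` and `output` arguments.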

4. Embedding Space Visualization

4.1 Token Embeddings

Uses dimensionality reduction to display the high-dimensional embedding space:

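```typescript
type ReductionMethod = 'pca' | 'tsne' | 'umap';

abstract class EmbeddingVisualizer {
  // Reduction algorithms and scene helpers are left to a concrete subclass
  protected abstract pca(data: Float32Array[], targetDim: number): Promise<Float32Array[]>;
  protected abstract tsne(data: Float32Array[], targetDim: number): Promise<Float32Array[]>;
  protected abstract umap(data: Float32Array[], targetDim: number): Promise<Float32Array[]>;
  protected abstract create3DScatter(points: Float32Array[], tokens: string[]): unknown;
  protected abstract clusterEmbeddings(points: Float32Array[]): number[];
  protected abstract colorByCluster(scatter: unknown, clusters: number[]): void;
  protected abstract addHoverTooltip(scatter: unknown, tokens: string[]): void;
  protected abstract addClickSelection(scatter: unknown): void;

  async visualizeEmbeddings(
    embeddings: Float32Array[],
    tokens: string[],
    method: ReductionMethod = 'tsne'
  ) {
    // Reduce to 3D
    const reduced = await this.reduceDimension(embeddings, 3, method);

    // Build the 3D scatter plot
    const scatter3D = this.create3DScatter(reduced, tokens);

    // Color points by semantic cluster
    const clusters = this.clusterEmbeddings(reduced);
    this.colorByCluster(scatter3D, clusters);

    // Add interaction
    this.addHoverTooltip(scatter3D, tokens);
    this.addClickSelection(scatter3D);
  }

  private async reduceDimension(
    data: Float32Array[],
    targetDim: number,
    method: ReductionMethod
  ): Promise<Float32Array[]> {
    switch (method) {
      case 'pca':
        return this.pca(data, targetDim);
      case 'tsne':
        return this.tsne(data, targetDim);
      case 'umap':
        return this.umap(data, targetDim);
      default:
        return this.pca(data, targetDim);
    }
  }
}
```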
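The `pca`, `tsne` and `umap` methods are not implemented in the class above. For the PCA branch, a dependency-free sketch: center the data, build the covariance matrix, and extract the top components by power iteration with deflation. This is an illustrative stand-in assuming plain `number[][]` input, not what the tool ships; at real embedding widths (hundreds of features) a numerical library or an SVD would be faster and more robust. t-SNE and UMAP are iterative non-linear methods and are not shown.

```typescript
// Project the rows of `data` onto their top `targetDim` principal components.
function pca(data: number[][], targetDim: number, iterations = 200): number[][] {
  const n = data.length;
  const d = data[0].length;
  if (targetDim > d) throw new Error('targetDim exceeds input dimension');

  // 1. Center each feature.
  const mean = new Array<number>(d).fill(0);
  for (const row of data) for (let j = 0; j < d; j++) mean[j] += row[j] / n;
  const centered = data.map((row) => row.map((v, j) => v - mean[j]));

  // 2. Sample covariance matrix (d x d).
  const cov = Array.from({ length: d }, () => new Array<number>(d).fill(0));
  const denom = Math.max(n - 1, 1);
  for (const row of centered) {
    for (let a = 0; a < d; a++) {
      for (let b = 0; b < d; b++) cov[a][b] += (row[a] * row[b]) / denom;
    }
  }

  const dot = (u: number[], v: number[]) => u.reduce((s, x, i) => s + x * v[i], 0);
  const components: number[][] = [];

  // 3. Power iteration finds the dominant eigenvector; deflation then removes
  //    it so the next pass finds the next-largest one.
  for (let c = 0; c < targetDim; c++) {
    // Deterministic start vector, unlikely to be orthogonal to the answer
    let v = Array.from({ length: d }, (_, j) => (j === c ? 1 : 0) + 0.01 * (j + 1));
    const norm0 = Math.sqrt(dot(v, v));
    v = v.map((x) => x / norm0);
    for (let it = 0; it < iterations; it++) {
      const w = cov.map((row) => dot(row, v));
      const norm = Math.sqrt(dot(w, w));
      if (norm < 1e-12) break; // remaining variance is ~0
      v = w.map((x) => x / norm);
    }
    const lambda = dot(v, cov.map((row) => dot(row, v)));
    for (let a = 0; a < d; a++) {
      for (let b = 0; b < d; b++) cov[a][b] -= lambda * v[a] * v[b];
    }
    components.push(v);
  }

  // 4. Coordinates of each centered row in the new basis.
  return centered.map((row) => components.map((comp) => dot(row, comp)));
}
```

Each component is only defined up to sign, so different implementations can return mirrored layouts of the same data; the relative positions of the points are what carry meaning.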

4.2 Positional Encoding

Shows how positional encoding injects position information:

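```typescript
abstract class PositionalEncodingVisualizer {
  // Plotting helpers are left to a concrete subclass
  protected abstract createHeatmap(matrix: number[][]): HTMLCanvasElement;
  protected abstract plotWaveform(wave: number[], label: string): void;
  protected abstract showRelativePositionSimilarity(posEnc: number[][]): void;

  visualizePositionalEncoding(maxLen: number, dModel: number) {
    // Compute the positional-encoding matrix [maxLen][dModel]
    const posEnc = this.calculatePositionalEncoding(maxLen, dModel);

    // Show it as a heatmap
    const heatmap = this.createHeatmap(posEnc);

    // Plot the sine/cosine waveforms of the first 8 dimensions
    for (let dim = 0; dim < Math.min(dModel, 8); dim++) {
      const wave = this.extractWaveform(posEnc, dim);
      this.plotWaveform(wave, `Dimension ${dim}`);
    }

    // Show how similarity depends on relative position
    this.showRelativePositionSimilarity(posEnc);

    return heatmap;
  }

  // Column `dim` of the matrix: that dimension's value at every position
  private extractWaveform(posEnc: number[][], dim: number): number[] {
    return posEnc.map(row => row[dim]);
  }

  private calculatePositionalEncoding(maxLen: number, dModel: number): number[][] {
    const posEnc: number[][] = [];

    for (let pos = 0; pos < maxLen; pos++) {
      const row: number[] = [];
      for (let i = 0; i < dModel; i++) {
        // Dimensions 2k and 2k+1 share the frequency 1 / 10000^(2k / dModel)
        const k = Math.floor(i / 2);
        const angle = pos / Math.pow(10000, (2 * k) / dModel);
        row.push(i % 2 === 0 ? Math.sin(angle) : Math.cos(angle));
      }
      posEnc.push(row);
    }

    return posEnc;
  }
}
```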
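The standard sinusoidal encoding from the original Transformer sets PE(pos, 2k) = sin(pos / 10000^(2k/d_model)) and PE(pos, 2k+1) = cos(pos / 10000^(2k/d_model)). That pairing is what `showRelativePositionSimilarity` can exploit: each sin/cos pair contributes sin(a)sin(b) + cos(a)cos(b) = cos(a - b), so PE(p) · PE(p + offset) depends only on the offset, never on p. A small sketch of that similarity profile, assuming an even `dModel`; `relativeSimilarity` and `similarityProfile` are hypothetical helpers:

```typescript
// Dot product PE(p) · PE(p + offset) for the sinusoidal encoding, computed
// in closed form as the sum over frequency pairs of cos(offset * omega_k).
function relativeSimilarity(offset: number, dModel: number): number {
  if (dModel % 2 !== 0) throw new Error('dModel is assumed to be even');
  let sum = 0;
  for (let k = 0; k < dModel / 2; k++) {
    const omega = 1 / Math.pow(10000, (2 * k) / dModel);
    sum += Math.cos(offset * omega);
  }
  return sum;
}

// Similarity for offsets 0..maxOffset; it is largest at offset 0, where it equals dModel / 2.
function similarityProfile(maxOffset: number, dModel: number): number[] {
  return Array.from({ length: maxOffset + 1 }, (_, offset) => relativeSimilarity(offset, dModel));
}
```

Since cosine is even, the profile is also symmetric: looking 4 positions back scores the same as looking 4 positions ahead.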

5. Interactive Features

5.1 View Controls

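```typescript
import * as THREE from 'three';
import { OrbitControls } from 'three/examples/jsm/controls/OrbitControls.js';

type Vec3Tuple = [number, number, number];

class ViewController {
  private controls: OrbitControls;

  constructor(private camera: THREE.PerspectiveCamera, domElement: HTMLElement) {
    this.controls = new OrbitControls(camera, domElement);
  }

  setupControls() {
    // Rotation
    this.controls.enableRotate = true;
    this.controls.rotateSpeed = 0.5;

    // Zoom
    this.controls.enableZoom = true;
    this.controls.zoomSpeed = 1.0;

    // Panning
    this.controls.enablePan = true;
    this.controls.panSpeed = 0.8;

    // Damping (inertia); requires controls.update() on every rendered frame
    this.controls.enableDamping = true;
    this.controls.dampingFactor = 0.05;
  }

  // Preset viewpoints
  setPresetView(view: 'front' | 'side' | 'top' | 'isometric') {
    const positions: Record<typeof view, Vec3Tuple> = {
      front: [0, 0, 50],
      side: [50, 0, 0],
      top: [0, 50, 0],
      isometric: [30, 30, 30]
    };

    this.animateCameraTo(positions[view]);
  }

  // Move the camera linearly to the given position (duration is arbitrary)
  private animateCameraTo([x, y, z]: Vec3Tuple, durationMs = 600) {
    const start = this.camera.position.clone();
    const end = new THREE.Vector3(x, y, z);
    const t0 = performance.now();

    const step = (now: number) => {
      const t = Math.min(Math.max((now - t0) / durationMs, 0), 1);
      this.camera.position.lerpVectors(start, end, t);
      this.controls.update(); // keeps the camera aimed at controls.target
      if (t < 1) requestAnimationFrame(step);
    };
    requestAnimationFrame(step);
  }
}
```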

5.2 Layer Filtering

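```typescript
import * as THREE from 'three';

class LayerFilter {
  private visibleLayers: Set<number>;
  private visibleHeads: Set<number>;

  // Start with every layer and head visible (empty sets would hide everything)
  constructor(private scene: THREE.Scene, numLayers: number, numHeads: number) {
    this.visibleLayers = new Set(Array.from({ length: numLayers }, (_, i) => i));
    this.visibleHeads = new Set(Array.from({ length: numHeads }, (_, i) => i));
  }

  toggleLayer(layerIndex: number) {
    if (this.visibleLayers.has(layerIndex)) {
      this.visibleLayers.delete(layerIndex);
    } else {
      this.visibleLayers.add(layerIndex);
    }
    this.updateVisibility();
  }

  toggleHead(headIndex: number) {
    if (this.visibleHeads.has(headIndex)) {
      this.visibleHeads.delete(headIndex);
    } else {
      this.visibleHeads.add(headIndex);
    }
    this.updateVisibility();
  }

  private updateVisibility() {
    // Update the visibility of tagged objects in the 3D scene
    this.scene.traverse((object) => {
      const { layerIndex, headIndex } = object.userData;
      if (layerIndex === undefined && headIndex === undefined) return;

      // Show an object only if both its layer and its head (when tagged) are
      // visible; previously the head check silently overwrote the layer check
      const layerVisible = layerIndex === undefined || this.visibleLayers.has(layerIndex);
      const headVisible = headIndex === undefined || this.visibleHeads.has(headIndex);
      object.visible = layerVisible && headVisible;
    });
  }
}
```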

5.3 Animation Controls

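```typescript
abstract class AnimationController {
  private animationSpeed: number = 1.0;
  private isPlaying: boolean = false;
  private currentFrame: number = 0;
  private rafId: number | null = null;

  // Draws the scene for a given frame; supplied by a concrete subclass
  protected abstract renderFrame(frame: number): void;

  play() {
    if (this.isPlaying) return; // avoid starting a second animation loop
    this.isPlaying = true;
    this.animate();
  }

  pause() {
    this.isPlaying = false;
    if (this.rafId !== null) {
      cancelAnimationFrame(this.rafId);
      this.rafId = null;
    }
  }

  setSpeed(speed: number) {
    this.animationSpeed = speed;
  }

  seek(frame: number) {
    this.currentFrame = frame;
    this.renderFrame(frame);
  }

  private animate() {
    if (!this.isPlaying) return;

    // Fractional speeds accumulate; render the whole frame reached so far
    this.currentFrame += this.animationSpeed;
    this.renderFrame(Math.floor(this.currentFrame));

    this.rafId = requestAnimationFrame(() => this.animate());
  }
}
```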
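One design note on `animate()`: advancing `currentFrame` by `animationSpeed` on every `requestAnimationFrame` callback ties playback speed to the display's refresh rate, so a 120 Hz screen plays twice as fast as a 60 Hz one. A common alternative, sketched here with a hypothetical `advanceFrame` helper and an assumed base rate of 30 animation frames per second, is to scale by elapsed time:

```typescript
// Advance the frame counter by elapsed wall-clock time, so playback speed is
// independent of how often requestAnimationFrame fires.
function advanceFrame(
  currentFrame: number,
  speed: number,      // 1.0 = normal speed
  elapsedMs: number,  // time since the previous callback
  baseFps = 30,       // assumed animation frames per second at speed 1.0
): number {
  return currentFrame + (elapsedMs / 1000) * baseFps * speed;
}
```

Inside the loop, keep the timestamp that `requestAnimationFrame` passes to its callback and use the difference from the previous one as `elapsedMs`.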

6. Export and Sharing

6.1 Screenshot Export

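```typescript
import * as THREE from 'three';

class ScreenshotExporter {
  constructor(
    private renderer: THREE.WebGLRenderer,
    private scene: THREE.Scene,
    private camera: THREE.Camera
  ) {}

  exportScreenshot(format: 'png' | 'jpg' = 'png', quality: number = 1.0) {
    const canvas = this.renderer.domElement;

    // Render a fresh frame right before reading it back; otherwise the WebGL
    // buffer may already be cleared (unless preserveDrawingBuffer is true)
    this.renderer.render(this.scene, this.camera);

    // 'image/jpg' is not a valid MIME type (browsers fall back to PNG);
    // quality (0 to 1) only affects lossy formats such as JPEG
    const mimeType = format === 'jpg' ? 'image/jpeg' : 'image/png';
    const dataURL = canvas.toDataURL(mimeType, quality);

    // Trigger a download through a temporary link
    const link = document.createElement('a');
    link.download = `llm-viz-${Date.now()}.${format}`;
    link.href = dataURL;
    link.click();
  }
}
```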

6.2 Sharing State

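```typescript
interface SharedState {
  inputText: string;
  selectedLayer: number | null;
  selectedHead: number | null;
  viewPosition: [number, number, number];
  filters: Record<string, unknown>;
}

// UTF-8-safe base64: btoa() alone throws on characters outside Latin-1 (e.g. Chinese input)
const toBase64 = (s: string) =>
  btoa(Array.from(new TextEncoder().encode(s), (b) => String.fromCharCode(b)).join(''));
const fromBase64 = (b64: string) =>
  new TextDecoder().decode(Uint8Array.from(atob(b64), (c) => c.charCodeAt(0)));

abstract class StateSharer {
  // App-state accessors are left to a concrete subclass
  protected abstract getInputText(): string;
  protected abstract getSelectedLayer(): number | null;
  protected abstract getSelectedHead(): number | null;
  protected abstract getCameraPosition(): [number, number, number];
  protected abstract getActiveFilters(): Record<string, unknown>;
  protected abstract setInputText(text: string): void;
  protected abstract setSelectedLayer(layer: number | null): void;
  protected abstract setSelectedHead(head: number | null): void;
  protected abstract setCameraPosition(pos: [number, number, number]): void;
  protected abstract setActiveFilters(filters: Record<string, unknown>): void;

  async shareCurrentState(): Promise<string> {
    const state: SharedState = {
      inputText: this.getInputText(),
      selectedLayer: this.getSelectedLayer(),
      selectedHead: this.getSelectedHead(),
      viewPosition: this.getCameraPosition(),
      filters: this.getActiveFilters()
    };

    // Serialize into a URL parameter; base64 may contain '+', '/' and '=', so URL-encode it
    const params = encodeURIComponent(toBase64(JSON.stringify(state)));
    const shareURL = `${window.location.origin}${window.location.pathname}?state=${params}`;

    // Copy to the clipboard (returns a Promise and requires a secure context)
    await navigator.clipboard.writeText(shareURL);

    return shareURL;
  }

  // `params` is the decoded value, e.g. new URLSearchParams(location.search).get('state')
  loadSharedState(params: string) {
    const state = JSON.parse(fromBase64(params)) as SharedState;

    this.setInputText(state.inputText);
    this.setSelectedLayer(state.selectedLayer);
    this.setSelectedHead(state.selectedHead);
    this.setCameraPosition(state.viewPosition);
    this.setActiveFilters(state.filters);
  }
}
```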

Next Steps

Now that we have covered the core features, let's look at how to apply LLM Visualization in a real project.