Appearance
LLM Visualization 架构与原理
整体架构
LLM Visualization采用分层架构设计,将模型推理、数据处理和可视化渲染分离,确保系统的可维护性和扩展性。
┌─────────────────────────────────────────────────────────────┐
│ 用户界面层 (UI Layer) │
│ ┌──────────────┐ ┌──────────────┐ ┌──────────────┐ │
│ │ 3D渲染视图 │ │ 控制面板 │ │ 信息展示 │ │
│ └──────────────┘ └──────────────┘ └──────────────┘ │
└─────────────────────────────────────────────────────────────┘
│
┌─────────────────────────────────────────────────────────────┐
│ 可视化引擎层 (Visualization) │
│ ┌──────────────┐ ┌──────────────┐ ┌──────────────┐ │
│ │ 场景管理器 │ │ 动画控制器 │ │ 交互处理器 │ │
│ └──────────────┘ └──────────────┘ └──────────────┘ │
└─────────────────────────────────────────────────────────────┘
│
┌─────────────────────────────────────────────────────────────┐
│ 模型推理层 (Model Inference) │
│ ┌──────────────┐ ┌──────────────┐ ┌──────────────┐ │
│ │ ONNX Runtime│ │ 注意力计算 │ │ 状态捕获 │ │
│ └──────────────┘ └──────────────┘ └──────────────┘ │
└─────────────────────────────────────────────────────────────┘
│
┌─────────────────────────────────────────────────────────────┐
│ 数据管理层 (Data Management) │
│ ┌──────────────┐ ┌──────────────┐ ┌──────────────┐ │
│ │ 模型加载器 │ │ 状态缓存 │ │ 配置管理 │ │
│ └──────────────┘ └──────────────┘ └──────────────┘ │
└─────────────────────────────────────────────────────────────┘
核心组件详解
1. 模型推理引擎
ONNX Runtime Web
LLM Visualization使用ONNX Runtime Web在浏览器中运行模型推理:
typescript
// Model loading and inference example
/**
 * Thin wrapper around an ONNX Runtime Web inference session.
 *
 * `loadModel` must complete before `runInference` is called.
 */
class ModelInference {
  // Remains undefined until loadModel() has resolved.
  private session: ort.InferenceSession | undefined;

  /**
   * Creates the inference session for the model at `modelPath`, using the
   * 'wasm' execution provider with graph optimization level 'all'.
   *
   * @param modelPath location of the model passed to InferenceSession.create
   */
  async loadModel(modelPath: string): Promise<void> {
    this.session = await ort.InferenceSession.create(modelPath, {
      executionProviders: ['wasm'],
      graphOptimizationLevel: 'all',
    });
  }

  /**
   * Runs the loaded model on one token sequence.
   *
   * @param inputTokens token ids; fed as an int64 tensor of shape
   *   [1, inputTokens.length] under the feed name `input`
   * @returns the output map produced by `session.run`
   * @throws Error when called before `loadModel` has completed
   */
  async runInference(inputTokens: number[]) {
    const session = this.session;
    if (session === undefined) {
      // Fail with an actionable message instead of "cannot read properties of undefined".
      throw new Error('Model is not loaded: call loadModel() before runInference().');
    }
    const inputTensor = new ort.Tensor('int64', inputTokens, [1, inputTokens.length]);
    return session.run({ input: inputTensor });
  }
}
注意力计算捕获
为了可视化注意力权重,系统需要捕获中间计算结果:
typescript
/** Attention tensors of one layer, each held as a flat Float32Array. */
interface AttentionState {
query: Float32Array; // Q (query) matrix
key: Float32Array; // K (key) matrix
value: Float32Array; // V (value) matrix
attentionWeights: Float32Array; // attention weights
output: Float32Array; // output
}
/**
 * Bundles the attention tensors of a single layer into one AttentionState.
 */
class AttentionCapture {
  /**
   * Captures the attention state of the layer at `layerIndex`.
   *
   * The helpers are invoked in the order query, key, value, attention
   * weights, output; each receives the same layer index.
   *
   * @param layerIndex index of the layer to capture
   * @returns the five captured tensors
   */
  captureLayer(layerIndex: number): AttentionState {
    const query = this.extractQuery(layerIndex);
    const key = this.extractKey(layerIndex);
    const value = this.extractValue(layerIndex);
    const attentionWeights = this.computeAttentionWeights(layerIndex);
    const output = this.extractOutput(layerIndex);
    return { query, key, value, attentionWeights, output };
  }
}
2. 可视化引擎
3D场景管理
使用Three.js构建交互式3D场景:
typescript
/**
 * Owns the Three.js scene, camera, renderer and orbit controls for one
 * container element.
 */
class SceneManager {
  private scene: THREE.Scene;
  private camera: THREE.PerspectiveCamera;
  private renderer: THREE.WebGLRenderer;
  private controls: OrbitControls;

  /**
   * Builds the scene and attaches the renderer's canvas to `container`.
   *
   * @param container element that hosts the canvas; its client size
   *   determines the camera aspect ratio and the renderer size
   */
  constructor(container: HTMLElement) {
    const width = container.clientWidth;
    const height = container.clientHeight;
    // A collapsed container would give NaN/Infinity; fall back to a square aspect.
    const aspect = height > 0 ? width / height : 1;

    this.scene = new THREE.Scene();
    this.camera = new THREE.PerspectiveCamera(75, aspect, 0.1, 1000);
    this.renderer = new THREE.WebGLRenderer({ antialias: true });
    this.renderer.setSize(width, height);
    // The renderer is private, so this is the only place the canvas can be mounted.
    container.appendChild(this.renderer.domElement);
    this.controls = new OrbitControls(this.camera, this.renderer.domElement);
    this.setupLighting();
    this.setupCamera();
  }

  /** Adds a dim ambient light and one white directional light at (10, 10, 5). */
  private setupLighting() {
    const ambientLight = new THREE.AmbientLight(0x404040, 0.5);
    const directionalLight = new THREE.DirectionalLight(0xffffff, 0.8);
    directionalLight.position.set(10, 10, 5);
    this.scene.add(ambientLight);
    this.scene.add(directionalLight);
  }

  /**
   * Moves the camera off the origin so it does not coincide with the orbit
   * target, then lets the controls pick up the new transform.
   * NOTE(review): the start position is a placeholder; tune it to the scene.
   */
  private setupCamera() {
    this.camera.position.set(0, 0, 10);
    this.controls.update();
  }
}
节点可视化
将神经网络层可视化为3D节点:
typescript
/**
 * Builds the Three.js objects that represent tokens and attention links.
 */
class NodeVisualization {
  /**
   * Creates a semi-transparent sphere for a token and attaches its label.
   *
   * @param token token text; selects the colour and the label content
   * @param position where the sphere is placed (copied, not retained)
   * @returns the sphere mesh with the label added as a child
   */
  createTokenNode(token: string, position: THREE.Vector3): THREE.Mesh {
    const geometry = new THREE.SphereGeometry(0.5, 32, 32);
    const material = new THREE.MeshPhongMaterial({
      color: this.getTokenColor(token),
      transparent: true,
      opacity: 0.8,
    });
    const node = new THREE.Mesh(geometry, material);
    node.position.copy(position);
    // Attach the label
    const label = this.createLabel(token);
    node.add(label);
    return node;
  }

  /**
   * Creates a line between two points whose opacity scales with `weight`.
   *
   * @param from start point
   * @param to end point
   * @param weight attention weight; clamped to [0, 1], non-finite values
   *   are treated as 0, so the opacity always lies in [0, 0.8]
   * @returns the line object
   */
  createAttentionConnection(
    from: THREE.Vector3,
    to: THREE.Vector3,
    weight: number
  ): THREE.Line {
    const clampedWeight = Number.isFinite(weight) ? Math.min(1, Math.max(0, weight)) : 0;
    const geometry = new THREE.BufferGeometry().setFromPoints([from, to]);
    const material = new THREE.LineBasicMaterial({
      color: 0xff6b6b,
      transparent: true,
      opacity: clampedWeight * 0.8,
    });
    return new THREE.Line(geometry, material);
  }
}
3. 数据流管理
状态管理架构
使用Redux管理应用状态:
typescript
// State definition
/** Shape of the application state, split into model, inference and visualization slices. */
interface AppState {
// Model selection and loading status.
model: {
currentModel: string;
loading: boolean;
error: string | null; // NOTE(review): presumably null when there is no error — confirm
};
// Inference input and progress.
inference: {
inputText: string;
tokens: Token[];
currentStep: number;
isRunning: boolean;
};
// Visualization settings.
visualization: {
selectedLayer: number;
selectedHead: number;
viewMode: '3d' | '2d';
animationSpeed: number;
};
}
// Action definitions
/**
 * Action creators for the application store.
 *
 * Each `type` is marked `as const` so it keeps its string-literal type;
 * without it the type widens to `string` and a reducer cannot narrow a
 * union of these actions by `action.type`.
 */
const actions = {
  /** Builds the action that stores the input text as `payload`. */
  setInputText: (text: string) => ({
    type: 'SET_INPUT_TEXT' as const,
    payload: text,
  }),
  /** Builds the payload-free action that requests an inference run. */
  runInference: () => ({
    type: 'RUN_INFERENCE' as const,
  }),
  /** Builds the action that carries a visualization config as `payload`. */
  updateVisualization: (config: VisualizationConfig) => ({
    type: 'UPDATE_VISUALIZATION' as const,
    payload: config,
  }),
};
数据转换管道
将模型输出转换为可视化数据:
typescript
/**
 * Converts raw model output into structures the visualizations consume.
 */
class DataTransformer {
  /**
   * Reshapes a flat, row-major attention buffer into a square matrix and
   * records its value range.
   *
   * @param rawAttention flat buffer; the first `tokens.length ** 2` values
   *   are read, row by row
   * @param tokens tokens labelling both matrix axes
   * @returns the tokens, the `tokens.length` x `tokens.length` weight matrix,
   *   and its maximum and minimum (both 0 when there are no tokens)
   * @throws RangeError when `rawAttention` holds fewer than
   *   `tokens.length ** 2` values
   */
  transformAttentionData(
    rawAttention: Float32Array,
    tokens: string[]
  ): AttentionVisualizationData {
    const size = tokens.length;
    const required = size * size;
    if (rawAttention.length < required) {
      throw new RangeError(
        `rawAttention has ${rawAttention.length} values but ${size} tokens require ${required}`
      );
    }

    const weights: number[][] = [];
    let maxWeight = -Infinity;
    let minWeight = Infinity;
    for (let i = 0; i < size; i++) {
      const row = Array.from(rawAttention.subarray(i * size, (i + 1) * size));
      // Track the range while copying: spreading the flattened matrix into
      // Math.max/Math.min exceeds the engine's argument limit for long sequences.
      for (const weight of row) {
        maxWeight = Math.max(maxWeight, weight);
        minWeight = Math.min(minWeight, weight);
      }
      weights.push(row);
    }

    if (size === 0) {
      // An empty matrix has no range; report 0 rather than -Infinity/Infinity.
      maxWeight = 0;
      minWeight = 0;
    }

    return { tokens, weights, maxWeight, minWeight };
  }

  /**
   * Reduces embeddings to three dimensions for plotting.
   *
   * @param embeddings flat embedding buffer handed to `reduceDimension`
   * @param dimension dimensionality of the original embeddings
   * @returns the reduced points together with the target (3) and original
   *   dimensionality
   */
  transformEmbeddingData(
    embeddings: Float32Array,
    dimension: number
  ): EmbeddingVisualizationData {
    // Original note: reduce to 3D with PCA or t-SNE. The algorithm itself
    // lives in reduceDimension, which is not shown here.
    const reduced = this.reduceDimension(embeddings, dimension, 3);
    return {
      points: reduced,
      dimension: 3,
      originalDimension: dimension,
    };
  }
}
可视化原理
1. 注意力权重可视化
热力图渲染
将注意力矩阵渲染为彩色热力图:
typescript
/**
 * Renders an attention matrix as a coloured grid on a canvas.
 */
class AttentionHeatmap {
  /**
   * Draws one square cell per matrix entry; row `i`, column `j` is painted
   * at (j * cellSize, i * cellSize).
   *
   * @param weights square attention matrix; entries missing from a short
   *   row are drawn as weight 0
   * @param cellSize edge length of one cell in pixels (default 20)
   * @returns a new canvas of `weights.length * cellSize` pixels per side
   * @throws Error when the 2D canvas context cannot be obtained
   */
  createHeatmap(weights: number[][], cellSize = 20): HTMLCanvasElement {
    const canvas = document.createElement('canvas');
    const ctx = canvas.getContext('2d');
    if (ctx === null) {
      throw new Error('2D canvas context is not available');
    }
    const size = weights.length;
    canvas.width = size * cellSize;
    canvas.height = size * cellSize;
    for (let i = 0; i < size; i++) {
      const row = weights[i] ?? [];
      for (let j = 0; j < size; j++) {
        ctx.fillStyle = this.weightToColor(row[j] ?? 0);
        ctx.fillRect(j * cellSize, i * cellSize, cellSize, cellSize);
      }
    }
    return canvas;
  }

  /**
   * Maps a weight onto a blue-to-red gradient with a fixed green channel of 50.
   * The weight is clamped to [0, 1] and non-finite values count as 0, so the
   * result is always a valid `rgb()` string.
   */
  private weightToColor(weight: number): string {
    const t = Number.isFinite(weight) ? Math.min(1, Math.max(0, weight)) : 0;
    const r = Math.floor(t * 255);
    const b = Math.floor((1 - t) * 255);
    return `rgb(${r}, 50, ${b})`;
  }
}
2. Token生成过程可视化
逐步动画
展示token生成的逐步过程:
typescript
/**
 * Plays back a recorded token-generation run one step at a time.
 */
class TokenGenerationAnimation {
  private currentStep = 0;
  private steps: GenerationStep[];

  /**
   * @param steps generation steps to animate, in playback order; defaults to
   *   an empty list so the instance is always safe to run
   */
  constructor(steps: GenerationStep[] = []) {
    this.steps = steps;
  }

  /**
   * Animates every step in order, awaiting each one before starting the next.
   * `currentStep` restarts at 0 on every call and is incremented after each
   * completed step.
   */
  async animateGeneration(): Promise<void> {
    this.currentStep = 0;
    for (const step of this.steps) {
      await this.animateStep(step);
      this.currentStep++;
    }
  }

  /** Runs the five animation phases of a single step sequentially. */
  private async animateStep(step: GenerationStep): Promise<void> {
    // 1. Highlight the current input tokens
    await this.highlightInputTokens(step.inputTokens);
    // 2. Show the attention computation
    await this.showAttentionComputation(step.attentionWeights);
    // 3. Show the feed-forward network computation
    await this.showFeedForward(step.hiddenStates);
    // 4. Show the output probability distribution
    await this.showProbabilityDistribution(step.probabilities);
    // 5. Generate the new token
    await this.generateToken(step.outputToken);
  }
}
3. 网络层可视化
分层展示
将Transformer的各层以3D堆叠方式展示:
typescript
/**
 * Lays out the layers of a Transformer as a vertical 3D stack.
 */
class LayerVisualization {
  /**
   * Builds a group holding `numLayers` layer groups spaced 3 units apart on
   * the y axis. After every layer except the first, the connections from the
   * previous layer are added to the same group.
   *
   * @param numLayers number of layers to create
   * @returns the group containing layers and inter-layer connections
   */
  createLayerStack(numLayers: number): THREE.Group {
    const stack = new THREE.Group();
    const spacing = 3; // vertical distance between consecutive layers
    let index = 0;
    while (index < numLayers) {
      const block = this.createLayer(index);
      block.position.y = index * spacing;
      stack.add(block);
      if (index > 0) {
        // Link this layer to the one below it.
        stack.add(this.createLayerConnections(index - 1, index));
      }
      index += 1;
    }
    return stack;
  }

  /**
   * Builds one layer: the self-attention block, then the feed-forward block
   * raised 1.5 units above it.
   */
  private createLayer(layerIndex: number): THREE.Group {
    const layerGroup = new THREE.Group();

    // Self-attention sub-layer
    const attentionBlock = this.createAttentionBlock(layerIndex);
    layerGroup.add(attentionBlock);

    // Feed-forward sub-layer
    const feedForwardBlock = this.createFFNBlock(layerIndex);
    feedForwardBlock.position.y = 1.5;
    layerGroup.add(feedForwardBlock);

    return layerGroup;
  }
}
性能优化
1. 渲染优化
typescript
/**
 * Rendering helpers: instanced drawing and level-of-detail setup.
 */
class RenderOptimizer {
  /**
   * Creates an InstancedMesh so that `count` identical spheres share one
   * geometry and one material.
   *
   * @param count number of instances
   */
  createInstancedNodes(count: number): THREE.InstancedMesh {
    const sphere = new THREE.SphereGeometry(0.5, 16, 16);
    const phong = new THREE.MeshPhongMaterial({ color: 0x4ecdc4 });
    const instances = new THREE.InstancedMesh(sphere, phong, count);
    return instances;
  }

  /**
   * Creates a LOD (level of detail) object with three levels: high detail
   * from distance 0, medium from `distances[0]`, low from `distances[1]`.
   *
   * NOTE(review): highDetailModel, mediumDetailModel and lowDetailModel are
   * not declared in this class and `object` is never used — confirm where
   * the three models are meant to come from.
   *
   * @param object currently unused
   * @param distances switch distances for the medium and low levels
   * @returns the configured LOD object
   */
  setupLOD(object: THREE.Object3D, distances: number[]) {
    const levelOfDetail = new THREE.LOD();
    // Register the models from most to least detailed.
    levelOfDetail.addLevel(highDetailModel, 0);
    const mediumFrom = distances[0];
    levelOfDetail.addLevel(mediumDetailModel, mediumFrom);
    const lowFrom = distances[1];
    levelOfDetail.addLevel(lowDetailModel, lowFrom);
    return levelOfDetail;
  }
}
2. 计算优化
typescript
/**
 * Offloads computation to a Web Worker and keeps a keyed result cache.
 */
class ComputationOptimizer {
  // Web Worker used for off-main-thread computation
  private worker: Worker;
  // Cached computation results
  private cache: Map<string, Float32Array> = new Map();
  // Tail of the request chain; it never rejects, so one failure cannot block later requests.
  private queue: Promise<unknown> = Promise.resolve();

  constructor() {
    this.worker = new Worker('inference-worker.js');
  }

  /**
   * Posts `data` to the worker and resolves with the `data` of its next
   * message.
   *
   * Requests are serialized: a request is posted only after the previous one
   * has settled. Assigning `onmessage` per call without this would let
   * overlapping calls overwrite each other's handler, leaving one promise
   * pending forever and handing another the wrong reply.
   *
   * @param data buffer sent to the worker
   * @returns the worker's reply for this request
   * @throws Error (as a rejection) when the worker reports an error
   */
  async computeInWorker(data: Float32Array): Promise<Float32Array> {
    const roundTrip = (): Promise<Float32Array> =>
      new Promise<Float32Array>((resolve, reject) => {
        this.worker.onmessage = (e: MessageEvent<Float32Array>) => resolve(e.data);
        this.worker.onerror = (e: ErrorEvent) =>
          reject(new Error(e.message || 'inference worker failed'));
        this.worker.postMessage(data);
      });

    const result = this.queue.then(roundTrip);
    this.queue = result.catch(() => undefined);
    return result;
  }

  /** Returns the cached result for `key`, or undefined when none is stored. */
  getCachedResult(key: string): Float32Array | undefined {
    return this.cache.get(key);
  }

  /** Stores `value` under `key`, replacing any previous entry. */
  setCachedResult(key: string, value: Float32Array) {
    this.cache.set(key, value);
  }
}
扩展架构
插件系统
支持自定义可视化插件:
typescript
/** Contract a custom visualization plugin has to implement. */
interface VisualizationPlugin {
name: string;
version: string;
// Initialize the plugin
initialize(context: PluginContext): void;
// Register visualization components
registerVisualizations(registry: VisualizationRegistry): void;
// Process model output
processOutput(output: ModelOutput): VisualizationData;
}
/**
 * Registry of visualization plugins, keyed by plugin name.
 */
class PluginManager {
  private plugins: Map<string, VisualizationPlugin> = new Map();

  /**
   * @param context context object passed to every plugin's `initialize`
   */
  constructor(private readonly context: PluginContext) {}

  /**
   * Initializes `plugin` with the manager's context and stores it under
   * `plugin.name`, replacing any plugin already registered under that name.
   *
   * Initialization runs first, so a plugin whose `initialize` throws is not
   * left registered.
   */
  registerPlugin(plugin: VisualizationPlugin): void {
    plugin.initialize(this.context);
    this.plugins.set(plugin.name, plugin);
  }

  /** Returns the plugin registered under `name`, or undefined when there is none. */
  getPlugin(name: string): VisualizationPlugin | undefined {
    return this.plugins.get(name);
  }
}
下一步
了解了架构和原理后,让我们继续探索LLM Visualization的核心功能详解。
