SnapViewer: 更快的PyTorch显存分配可视化

背景

在使用 PyTorch 训练模型时,内存不足(OOM)错误是很常见的,因此需要对 GPU 内存进行优化。当简单的方法(如减少batch size)不再有效时,就需要分析模型本身的内存占用情况。

此时,你可能会看到这份文档,它教你如何记录内存快照并在网站上进行可视化。

但是这里存在一个严重的问题:这个网站性能比较差。

  • 如果你的模型较小,快照只有几 MB,性能还可以接受。
  • 但是如果你的模型很大,快照达到几十甚至上百 MB,网站就会变得极慢,帧率可能低至每分钟 2-3 帧(非笔误)。

我研究了网站的 JavaScript 代码,其主要功能是:

  1. 手动加载 Python 的 pickle 文件;
  2. 每次视口发生变化时重新解析原始数据为图形表示,然后将其渲染到屏幕上。

这些解析逻辑用 JavaScript 编写,你可以想象一下每帧执行这些操作,处理上百 MB 数据要多久.

灵感

我当前的工作包括优化一个非LLM的深度学习模型。在处理数十亿参数的模型所导出的显存snapshot时,我遇到了这个问题。

为什么不用现有的 LLM 基础设施而选择手动优化?简单地说,这个模型是研究人员自定义设计的,其中包含许多与标准 LLM 完全不同的模块。现在似乎每个人都认为深度学习就只是关于 LLM——以至于一些技术负责人也认为 LLM 的基础设施可以很容易地适配到其他模型上……不过我有点跑题了。

最初,我写了一个简单的脚本来解析快照内容,希望能发现模型中的内存分配问题。但是在一个月的工作中,我发现我还是需要一个带有GUI的可视化器, 于是我开发了 SnapViewer.

简而言之:内存快照的图形数据被解析并呈现为一个巨大的三角形mesh,利用现有的渲染库来高效处理网格渲染。

下面是一个 100 MB 以上的快照在我的集成显卡上流畅运行的截图:

snapviewer

实现

参考实现

快照格式在 record_memory_history 函数的docstring中有部分记录。但这份文档并不完整, 可能是后续commit的人懒得更新docstring了.

实际将快照解析为字典的过程发生在这里

  1. 该脚本将分配器跟踪转换为内存时间线,然后传递给网页查看器的 JS 代码。
  2. JS 代码进一步将其转换为多边形(表示分配),用于可视化。每个多边形对应一个分配,存储大小和调用栈等细节。

实现:快照 (反)序列化

初始实现

我用 Python 实现这一部分,因为我需要处理 Python 原生数据结构。我只是简单地将字典转换为 JSON 文件。

优化

  1. 原始 JSON 文件太大 → 在写入前进行内存压缩(Python zipfile)。
  2. 在可视化过程中,从磁盘读取 ZIP 文件(Rust zip crate),并在内存中解压缩。
权衡
  • 这种方式在 JSON 解析过程中会导致短暂的内存峰值,但避免了持续的高内存使用。
  • 利用 Rust 的 serde-json(因为 Rust 的 serde-pickle 功能不全,不能处理递归结构)。

实现:渲染与交互

这部分用 Rust 实现。

渲染

  • 由于分配数据在可视化过程中是静态的,所有分配被合并成一个大的三角形mesh,并一次性发送到 GPU。

  • 使用的库:three-d

    • 提供良好的网格抽象。
    • 支持一次性上传到 GPU(无需每帧进行 CPU→GPU 传输)。
    • 处理鼠标/键盘事件。

窗口到世界坐标转换

  1. 步骤 1:将窗口坐标转换为世界坐标(缩放 + 窗口中心偏移)。
  2. 步骤 2:将世界坐标转换为内存位置(预定义的缩放)。

UI & 交互功能

内存刻度标记
  • 根据屏幕可见性动态调整标记的数量和精度。
  • 保持标记在屏幕上的固定位置,即使移动或缩放。
移动 & 缩放
  1. 跟踪原始缩放比例(1/zoom)。
  2. 更新到新的缩放级别,并计算新旧比例之间的比值。
  3. 根据鼠标不变的世界位置,调整屏幕中心位置。

实现:查询

在工作中使用这个工具一周后,我发现自己经常需要搜索内存快照,尤其是:

  • 找到特定时间戳内所有存活的分配
  • 找到调用栈中包含特定子字符串的所有分配
  • 最好按照分配大小降序排列分配

我最初的想法是构建一个简单的 REPL 和一个简单的命令解析器,将每个命令映射到特定的查询函数。

然而,在列出所有需要的功能后,我发现这其实是数据库查询的子集,尤其是 SQL。

因此我决定不再造轮子:我只是连接到一个内存中的 SQLite 数据库。用户交互非常简单:读取用户输入,让 SQLite 执行,并将输出格式化为人可读的格式。


如果你在使用 PyTorch 内存快照时遇到过困难,来看看吧!欢迎贡献和反馈。 ⭐