0%

TensorBoard

前言

简单学习一下 tensorboard 在 pytorch 中的用法


安装

除了 Pytorch 以外还需要额外安装 tensorboard

1
pip install tensorboard

基本原理

tensorboard 会生成一个(或多个)记录文件,通过在记录文件中添加数据,我们可以在网页端看到可视化处理后的数据,非常方便。


基本用法

使用 tensorboard 的第一步是创建一个 SummaryWriter 对象

1
2
3
import torch
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()

SummaryWriter记录文件 (即 event 文件)的句柄,通过对 SummaryWriter 的方法调用,我们可以不断向记录文件中插入新的数据。

当不需要在进行修改时,使用 writer.close() 关闭文件。

除此之外,想要在网页端查看可视化数据,还需要在控制台窗口运行

1
tensorboard --logdir=runs

其中 runs 是默认的 记录文件的目录, 可以自行修改,随后在 http://localhost:6006/ 就可以查看啦~

小技巧

  • 分组:tag 参数设为类似 acc/train/bs 的以 / 分割的形式会将第一部分相同的放在一组。
  • 多线一图:事实上 tensorboard 是按 tag 绘图,所以一张图片中如果想分开绘制多个曲线的话可以创建多个 SummaryWriter 句柄,并使用相同的标签进行绘图,最后可视化结果就会显示在同一张图片上。
  • 远程服务器:
    • 使用 ssh 端口转发,在本地执行 ssh -L 6006:127.0.0.1:6006 username@server_ip 登录后即可使用 http://localhost:6006/ 查看
    • 使用 Vscode 的 SSH 连接的话可以直接在 Vscode 中打开可视化界面,也可以自动转到浏览器界面使用 http://localhost:6006/ 查看(其实也是 Vscode 后台执行了端口转发,在端口控制台可以看到转发的端口)
    • 使用 MobaXterm 的 Tunneling 可以配置端口转发,选择 local port forwarding 后填写转发表即可
  • 拼图:可以通过 torchvision.utils.make_grid() 将 batch 中多张图片拼接为一张图片从而使用 add_image 添加而不用 add_images
  • 上下文管理器:可以通过 with SummaryWriter() as writer 的方法进行调用,省去了 writer.close()
  • writer.flush() 立刻刷新缓存写入磁盘
  • writer.close() 关闭读写连接

SummaryWriter

打开文件句柄,如果文件不存在则自动创建

1
SummaryWriter(log_dir=None, comment='', purge_step=None, max_queue=10, flush_secs=120, filename_suffix='')

参数:

  • log_dir: 记录文件的路径,建议合理进行分类方便管理,默认是 runs/CURRENT_DATETIME_HOSTNAME, 例如 runs/Jul11_21-39-43_LabServer \(\rightarrow\) 这个是文件夹不是文件
  • comment: 为 log_dir 默认值添加的后缀,如果指定了 log_dir 那么该参数无效
  • purge_step: 大意是如果程序崩溃了,重启后丢失的部分及其后面的部分不会显示,大概用不上
  • max_queue: 缓存队列的大小,队列满了就会调用 flush 将缓存写入磁盘。
  • flush_secs: 刷新缓存的频率间隔,定时将刷新缓存,单位是秒
  • filename_suffix: 所有 event 文件添加后缀

scalar

插入标量有两个方法,分别用于插入一个标量点数据和多个标量点数据

1
add_scalar(tag, scalar_value, global_step=None, walltime=None, new_style=False, double_precision=False)

参数:

  • tag(string): 数据标识符, 可以使用小技巧中提到的方法进行分组
  • scalar_value (float or string/blobname): 需要添加的数据点的值
  • global_step (int): 数据点的全局步值(即横坐标值)
  • walltime (float): 记录数据点的时间(单位: 秒),默认是 time.time()
  • new_style (boolean): 是否使用新样式(测试了下没看出来有什么区别...)
1
add_scalars(main_tag, tag_scalar_dict, global_step=None, walltime=None)

与上一个的区别是可以绘制多条曲线

参数:

  • main_tag(string): 数据标识符
  • tag_scalar_dict (dict): 键值对,分别是数据的后缀即数据值大小
  • global_step, walltime: 见 scalar

histogram

直方图,不太懂

1
add_histogram(tag, values, global_step=None, bins='tensorflow', walltime=None, max_bins=None)

参数:

Image

顾名思义,记录图片

1
add_image(tag, img_tensor, global_step=None, walltime=None, dataformats='CHW')

参数:

  • img_tensor (torch.Tensor, numpy.array, or string/blobname): 图片张量, 默认是 \((3, H, W)\) 的,配合 dataformats 食用
  • dataformats (string): 类似 CHW, HWC, HW, WH 这样的格式,用于指定图片张量的各个维度的含义
  • 其他参数参考 scalar
1
add_images(tag, img_tensor, global_step=None, walltime=None, dataformats='NCHW')

参数与 add_image 类似,但张量形状为 \((N, 3, H, W)\) 或者根据 dataformats 进行定义,类似 NCHW, NHWC, CHW, HWC, HW, WH

figure

这里的 figure 特指 matplotlib 中绘制的图像。

1
add_figure(tag, figure, global_step=None, close=True, walltime=None)

参数:

  • figure (matplotlib.pyplot.figure): Figure 或者是包含一些 figures 的列表
  • close (bool): 是否自动关闭图像
  • 其他参数参考 scalar

video

视频

1
add_video(tag, vid_tensor, global_step=None, fps=4, walltime=None)

参数:

  • vid_tensor (torch.Tensor): 维度为 \((N, T, C, H, W)\), 数据类型为 \([0, 1]\)float\([0,255]\)uint8.
  • fps (float or int): 帧频,即 frames per second.
  • 其他参数参考 scalar

audio

1
add_audio(tag, snd_tensor, global_step=None, sample_rate=44100, walltime=None)

snd_tensor 的维度为 \((1, L)\) 数值范围为 \([-1, 1]\), sample_rate 是整数(int)表示的采样率。

text

1
add_text(tag, text_string, global_step=None, walltime=None)

scalar 类似,至于 text_string 应该不用介绍了吧~

graph

绘制可视化的模型结构图

1
add_graph(model, input_to_model=None, verbose=False, use_strict_trace=True)

参数:

  • model (torch.nn.Module): 目标模型
  • input_to_model (torch.Tensor or list of torch.Tensor): 输入到模型的张量
  • verbose (bool): 是否绘制到控制台
  • use_strict_trace: 是否严格传递关键字参数给 torch.jit.trace. (不太懂有什么用)

embedding

将高维特征投影到低维进行可视化

1
add_embedding(mat, metadata=None, label_img=None, global_step=None, tag='default', metadata_header=None)

参数:

  • mat (torch.Tensor or numpy.array): \((N, D)\), 指 \(N\)\(D\) 维的特征,即每行都是一个 \(D\) 维的特征
  • metadata (list[str]): 标签列表
  • label_img (torch.Tensor): 标签图片,形状是 \((N, C, H, W)\), 即每个点对应的图片表示

注意,需要在 PROJECTOR 界面查看。

注意:需要重启 tensorboard 才能查看

其他

用到再补充...

1
add_pr_curve(tag, labels, predictions, global_step=None, num_thresholds=127, weights=None, walltime=None)
1
add_mesh(tag, vertices, colors=None, faces=None, config_dict=None, global_step=None, walltime=None)
1
add_hparams(hparam_dict, metric_dict, hparam_domain_discrete=None, run_name=None)

参考资料

--- ♥ end ♥ ---

欢迎关注我呀~