前言
简单学习一下 tensorboard 在 pytorch 中的用法
安装
除了 Pytorch 以外还需要额外安装 tensorboard
1 | pip install tensorboard |
基本原理
tensorboard 会生成一个(或多个)记录文件,通过在记录文件中添加数据,我们可以在网页端看到可视化处理后的数据,非常方便。
基本用法
使用 tensorboard 的第一步是创建一个 SummaryWriter
对象
1 | import torch |
SummaryWriter
是 记录文件 (即
event
文件)的句柄,通过对 SummaryWriter
的方法调用,我们可以不断向记录文件中插入新的数据。
当不需要在进行修改时,使用 writer.close()
关闭文件。
除此之外,想要在网页端查看可视化数据,还需要在控制台窗口运行
1 | tensorboard --logdir=runs |
其中 runs
是默认的 记录文件的目录,
可以自行修改,随后在 http://localhost:6006/
就可以查看啦~
小技巧
- 分组:
tag
参数设为类似acc/train/bs
的以/
分割的形式会将第一部分相同的放在一组。 - 多线一图:事实上 tensorboard 是按 tag
绘图,所以一张图片中如果想分开绘制多个曲线的话可以创建多个
SummaryWriter
句柄,并使用相同的标签进行绘图,最后可视化结果就会显示在同一张图片上。 - 远程服务器:
- 使用 ssh 端口转发,在本地执行
ssh -L 6006:127.0.0.1:6006 username@server_ip
登录后即可使用 http://localhost:6006/ 查看 - 使用 Vscode 的 SSH 连接的话可以直接在 Vscode 中打开可视化界面,也可以自动转到浏览器界面使用 http://localhost:6006/ 查看(其实也是 Vscode 后台执行了端口转发,在端口控制台可以看到转发的端口)
- 使用 MobaXterm 的 Tunneling 可以配置端口转发,选择 local port forwarding 后填写转发表即可
- 使用 ssh 端口转发,在本地执行
- 拼图:可以通过
torchvision.utils.make_grid()
将 batch 中多张图片拼接为一张图片从而使用add_image
添加而不用add_images
- 上下文管理器:可以通过
with SummaryWriter() as writer
的方法进行调用,省去了writer.close()
writer.flush()
立刻刷新缓存写入磁盘writer.close()
关闭读写连接
SummaryWriter
打开文件句柄,如果文件不存在则自动创建
1 | SummaryWriter(log_dir=None, comment='', purge_step=None, max_queue=10, flush_secs=120, filename_suffix='') |
参数:
log_dir
: 记录文件的路径,建议合理进行分类方便管理,默认是runs/CURRENT_DATETIME_HOSTNAME
, 例如runs/Jul11_21-39-43_LabServer
\(\rightarrow\) 这个是文件夹不是文件comment
: 为log_dir
默认值添加的后缀,如果指定了log_dir
那么该参数无效purge_step
: 大意是如果程序崩溃了,重启后丢失的部分及其后面的部分不会显示,大概用不上max_queue
: 缓存队列的大小,队列满了就会调用flush
将缓存写入磁盘。flush_secs
: 刷新缓存的频率间隔,定时将刷新缓存,单位是秒filename_suffix
: 所有event
文件添加后缀
scalar
插入标量有两个方法,分别用于插入一个标量点数据和多个标量点数据
1 | add_scalar(tag, scalar_value, global_step=None, walltime=None, new_style=False, double_precision=False) |
参数:
tag(string)
: 数据标识符, 可以使用小技巧中提到的方法进行分组scalar_value (float or string/blobname)
: 需要添加的数据点的值global_step (int)
: 数据点的全局步值(即横坐标值)walltime (float)
: 记录数据点的时间(单位: 秒),默认是time.time()
new_style (boolean)
: 是否使用新样式(测试了下没看出来有什么区别...)
1 | add_scalars(main_tag, tag_scalar_dict, global_step=None, walltime=None) |
与上一个的区别是可以绘制多条曲线
参数:
main_tag(string)
: 数据标识符tag_scalar_dict (dict)
: 键值对,分别是数据的后缀即数据值大小global_step, walltime
: 见 scalar
histogram
直方图,不太懂
1 | add_histogram(tag, values, global_step=None, bins='tensorflow', walltime=None, max_bins=None) |
参数:
values (torch.Tensor, numpy.array, or string/blobname)
: 直方图的值bins (string)
: 可行项为{‘tensorflow’,’auto’, ‘fd’, …}
, 参考 https://docs.scipy.org/doc/numpy/reference/generated/numpy.histogram.html- 其他参数参考 scalar
Image
顾名思义,记录图片
1 | add_image(tag, img_tensor, global_step=None, walltime=None, dataformats='CHW') |
参数:
img_tensor (torch.Tensor, numpy.array, or string/blobname)
: 图片张量, 默认是 \((3, H, W)\) 的,配合dataformats
食用dataformats (string)
: 类似CHW, HWC, HW, WH
这样的格式,用于指定图片张量的各个维度的含义- 其他参数参考 scalar
1 | add_images(tag, img_tensor, global_step=None, walltime=None, dataformats='NCHW') |
参数与 add_image
类似,但张量形状为 \((N, 3, H, W)\) 或者根据
dataformats
进行定义,类似
NCHW, NHWC, CHW, HWC, HW, WH
figure
这里的 figure 特指 matplotlib 中绘制的图像。
1 | add_figure(tag, figure, global_step=None, close=True, walltime=None) |
参数:
figure (matplotlib.pyplot.figure)
: Figure 或者是包含一些 figures 的列表close (bool)
: 是否自动关闭图像- 其他参数参考 scalar
video
视频
1 | add_video(tag, vid_tensor, global_step=None, fps=4, walltime=None) |
参数:
vid_tensor (torch.Tensor)
: 维度为 \((N, T, C, H, W)\), 数据类型为 \([0, 1]\) 的 float 或 \([0,255]\) 的 uint8.fps (float or int)
: 帧频,即 frames per second.- 其他参数参考 scalar
audio
1 | add_audio(tag, snd_tensor, global_step=None, sample_rate=44100, walltime=None) |
snd_tensor 的维度为 \((1, L)\)
数值范围为 \([-1, 1]\),
sample_rate
是整数(int)表示的采样率。
text
1 | add_text(tag, text_string, global_step=None, walltime=None) |
与 scalar 类似,至于 text_string
应该不用介绍了吧~
graph
绘制可视化的模型结构图
1 | add_graph(model, input_to_model=None, verbose=False, use_strict_trace=True) |
参数:
model (torch.nn.Module)
: 目标模型input_to_model (torch.Tensor or list of torch.Tensor)
: 输入到模型的张量verbose (bool)
: 是否绘制到控制台use_strict_trace
: 是否严格传递关键字参数给torch.jit.trace
. (不太懂有什么用)
embedding
将高维特征投影到低维进行可视化
1 | add_embedding(mat, metadata=None, label_img=None, global_step=None, tag='default', metadata_header=None) |
参数:
mat (torch.Tensor or numpy.array)
: \((N, D)\), 指 \(N\) 个 \(D\) 维的特征,即每行都是一个 \(D\) 维的特征metadata (list[str])
: 标签列表label_img (torch.Tensor)
: 标签图片,形状是 \((N, C, H, W)\), 即每个点对应的图片表示
注意,需要在 PROJECTOR 界面查看。
注意:需要重启 tensorboard 才能查看
其他
用到再补充...
1 | add_pr_curve(tag, labels, predictions, global_step=None, num_thresholds=127, weights=None, walltime=None) |
1 | add_mesh(tag, vertices, colors=None, faces=None, config_dict=None, global_step=None, walltime=None) |
1 | add_hparams(hparam_dict, metric_dict, hparam_domain_discrete=None, run_name=None) |