京东网站推广方式,设计网站公司的口号,商融交通建设工程有限公司网站,企业培训课程ppt0、基础提示
1、FLOPS是用来衡量硬件算力的指标#xff0c;FLOPs用来衡量模型复杂度。 2、MAC 一般为 FLOPs的2倍 3、并非FLOPs越小在硬件上就一定运行更快#xff0c;还与模型占用的内存#xff0c;带宽#xff0c;等有关
1、FLOPs计算
神经网络参数量。用于衡量模型大…0、基础提示
1、FLOPS是用来衡量硬件算力的指标FLOPs用来衡量模型复杂度。 2、MAC 一般为 FLOPs的2倍 3、并非FLOPs越小在硬件上就一定运行更快还与模型占用的内存带宽等有关
1、FLOPs计算
神经网络参数量。用于衡量模型大小。一般卷积计算方式为 F L O P s 2 ∗ H W ( K h ∗ K w ∗ C i n 1 ) C o u t FLOPs 2*HW(Kh*Kw*Cin1)Cout FLOPs2∗HW(Kh∗Kw∗Cin1)Cout 其中 HW表示该层卷积的高和宽 KhKw表示卷积核的高和宽 2 表示一次乘操作 一次加操作 1 表示bias操作
2、统计工具-THOP
源代码链接
2.1 安装
pip install thop或
pip install --upgrade githttps://github.com/Lyken17/pytorch-OpCounter.git2.2 基础使用
from torchvision.models import resnet50
from thop import profile
model resnet50()
input torch.randn(1, 3, 224, 224)
macs, params profile(model, inputs(input, ))2.3 定义自己的规则
class YourModule(nn.Module):# your definitiondef count_your_model(model, x, y):# your rule hereinput torch.randn(1, 3, 224, 224)
macs, params profile(model, inputs(input, ), custom_ops{YourModule: count_your_model})2.4 模型包含多个输入
修改input就好
from torchvision.models import resnet50
from thop import profile
model resnet50()
input1 input2 torch.randn(1, 3, 224, 224)
macs, params profile(model, inputs(input1, input2,))3、 统计工具-torchstat
这个是我更中意的因为他统计信息更加丰富包含params,memory, Madd, FLOPs等。缺点在于已经不更新了且不支持多输入好在我们可以修改代码支持。 源代码链接
3.1 安装
pip install torchstat3.2 基础使用
from torchstat import stat
import torchvision.models as models
model models.resnet18()
stat(model, (3, 224, 224))3.3 输入多个Input
将torchstat 库安装目录下的 torchstat/statistics.py 中按如下修改
class ModelStat(object):def __init__(self, model, input_size, query_granularity1):assert isinstance(model, nn.Module)# 删除输入长度为3的限制# assert isinstance(input_size, (tuple, list)) and len(input_size) 3assert isinstance(input_size, (tuple, list))self._model modelself._input_size input_sizeself._query_granularity query_granularity将torchstat 库安装目录下的 torchstat/model_hook.py 中按如下修改
class ModelHook(object):def __init__(self, model, input_size):assert isinstance(model, nn.Module)assert isinstance(input_size, (list, tuple))self._model model# 原始是通过单个输入的尺寸再构建输入tensor我们可以修改为在网络外构建输入tensor后直接送入网络# self._input_size input_sizeself._origin_call dict() # sub module call hookself._hook_model()# x torch.rand(1, *self._input_size) # add module duration timeself._model.eval()# self._model(x)self._model(*self._input_size)使用时候测试代码
from torchstat import stat
import torchvision.models as models
model models.resnet18()
input1, input2 torch.rand(1, 3, 224, 224), torch.rand(1, 3, 224, 224)
stat(model, (input1, input2))大致改动就是这样了还有什么bug可以自己稍微修改一下哈。另外找修改地方可以看报错提示torchstat安装路径修改。
4、fvcore
stat有个很麻烦的问题是他不支持transformer因此包含transformer的网络可以使用fvcore他是Facebook开源的一个轻量级的核心库。
4.1、 安装
pip install fvcore4.2、 基础使用
from fvcore.nn import FlopCountAnalysis, parameter_count_table
# 创建网络
model MobileViTBlock(in_channels32, transformer_dim64, ffn_dim256)# 创建输入网络的tensor
tensor (torch.rand(1, 32, 64, 64),)# 分析FLOPs
flops FlopCountAnalysis(model, tensor)
print(FLOPs: , flops.total())# 分析parameters
print(parameter_count_table(model))参考来自https://zhuanlan.zhihu.com/p/583106030
欢迎交流补充