wordpress 本地 域名绑定,兰州网站seo收费,电子商务营销手段有哪些,wordpress快速编辑添加多个标签1.背景介绍
在训练的模型的时候#xff0c;需要评价模型的好坏#xff0c;就涉及到混淆矩阵、准确率、查准率、查全率、DSC、IoU、敏感度的计算。
2、混淆矩阵的概念
所谓的混淆矩阵如下表所示#xff1a; TP:真正类#xff0c;真的正例被预测为正例
FN:假负类#xf…1.背景介绍
在训练的模型的时候需要评价模型的好坏就涉及到混淆矩阵、准确率、查准率、查全率、DSC、IoU、敏感度的计算。
2、混淆矩阵的概念
所谓的混淆矩阵如下表所示 TP:真正类真的正例被预测为正例
FN:假负类样本为正例被预测为负类
FP:假正类 原本实际为负但是被预测为正例
TN:真负类真的负样本被预测为负类。
从混淆矩阵当中可以得到更高级的分类指标Accuracy准确率Precision查准率Recall查全率Specificity特异性Sensitivity灵敏度。
3. 常用的分类指标
3.1 Accuracy准确率
不管是哪个类别只要预测正确其数量都放在分子上而分母是全部数据量。常用于表示模型的精度当数据类别不平衡时不能用于模型的评价。 3.2 Precision查准率
即所有预测为正的样本中预测正确的样本的所占的比重。 3.3 Recall查全率
真实的为正的样本被正确检测出来的比重。 3.4 Specificity特异性
特异性指标也称 负正类率False Positive Rate, FPR计算的是模型错识别为正类的负类样本占所有负类样本的比例一般越低越好。 3.5 DSCDice coefficient
Dice系数是一种相似性度量度量二进制图像分割的准确性。
如图所示红色的框的区域时Groudtruth而蓝色的框为预测值Prediction。 3.6 IoU交并比 3.7 Sensitivity灵敏度
反应的时预测正确的区域在Groundtruth中所占的比重。 4. 计算程序
ConfusionMatrix 这个类可以直接计算出混淆矩阵
from collections import defaultdict, deque
import datetime
import time
import torch
import torch.nn.functional as F
import torch.distributed as dist
import errno
import osclass SmoothedValue(object):Track a series of values and provide access to smoothed values over awindow or the global series average.def __init__(self, window_size20, fmtNone):if fmt is None:fmt {value:.4f} ({global_avg:.4f})self.deque deque(maxlenwindow_size)self.total 0.0self.count 0self.fmt fmtdef update(self, value, n1):self.deque.append(value)self.count nself.total value * ndef synchronize_between_processes(self):Warning: does not synchronize the deque!if not is_dist_avail_and_initialized():returnt torch.tensor([self.count, self.total], dtypetorch.float64, devicecuda)dist.barrier()dist.all_reduce(t)t t.tolist()self.count int(t[0])self.total t[1]propertydef median(self):d torch.tensor(list(self.deque))return d.median().item()propertydef avg(self):d torch.tensor(list(self.deque), dtypetorch.float32)return d.mean().item()propertydef global_avg(self):return self.total / self.countpropertydef max(self):return max(self.deque)propertydef value(self):return self.deque[-1]def __str__(self):return self.fmt.format(medianself.median,avgself.avg,global_avgself.global_avg,maxself.max,valueself.value)class ConfusionMatrix(object):def __init__(self, num_classes):self.num_classes num_classesself.mat Nonedef update(self, a, b):n self.num_classesif self.mat is None:# 创建混淆矩阵self.mat torch.zeros((n, n), dtypetorch.int64, devicea.device)with torch.no_grad():# 寻找GT中为目标的像素索引k (a 0) (a n)# 统计像素真实类别a[k]被预测成类别b[k]的个数(这里的做法很巧妙)inds n * a[k].to(torch.int64) b[k]self.mat torch.bincount(inds, minlengthn**2).reshape(n, n)def reset(self):if self.mat is not None:self.mat.zero_()def compute(self):h self.mat.float()# 计算全局预测准确率(混淆矩阵的对角线为预测正确的个数)acc_global torch.diag(h).sum() / h.sum()# 计算每个类别的准确率acc torch.diag(h) / h.sum(1)# 计算每个类别预测与真实目标的iouiu torch.diag(h) / (h.sum(1) h.sum(0) - torch.diag(h))return acc_global, acc, iudef reduce_from_all_processes(self):if not torch.distributed.is_available():returnif not torch.distributed.is_initialized():returntorch.distributed.barrier()torch.distributed.all_reduce(self.mat)def __str__(self):acc_global, acc, iu self.compute()return (global correct: {:.1f}\naverage row correct: {}\nIoU: {}\nmean IoU: {:.1f}).format(acc_global.item() * 100,[{:.1f}.format(i) for i in (acc * 100).tolist()],[{:.1f}.format(i) for i in (iu * 100).tolist()],iu.mean().item() * 100)class DiceCoefficient(object):def __init__(self, num_classes: int 2, ignore_index: int -100):self.cumulative_dice Noneself.num_classes num_classesself.ignore_index ignore_indexself.count Nonedef update(self, pred, target):if self.cumulative_dice is None:self.cumulative_dice torch.zeros(1, dtypepred.dtype, devicepred.device)if self.count is None:self.count torch.zeros(1, dtypepred.dtype, devicepred.device)# compute the Dice score, ignoring backgroundpred F.one_hot(pred.argmax(dim1), self.num_classes).permute(0, 3, 1, 2).float()dice_target build_target(target, self.num_classes, self.ignore_index)self.cumulative_dice multiclass_dice_coeff(pred[:, 1:], dice_target[:, 1:], ignore_indexself.ignore_index)self.count 1propertydef value(self):if self.count 0:return 0else:return self.cumulative_dice / self.countdef reset(self):if self.cumulative_dice is not None:self.cumulative_dice.zero_()if self.count is not None:self.count.zeros_()def reduce_from_all_processes(self):if not torch.distributed.is_available():returnif not torch.distributed.is_initialized():returntorch.distributed.barrier()torch.distributed.all_reduce(self.cumulative_dice)torch.distributed.all_reduce(self.count)分类指标的计算
import torch# SR : Segmentation Result
# GT : Ground Truthdef get_accuracy(SR,GT,threshold0.5):SR SR thresholdGT GT torch.max(GT)corr torch.sum(SRGT)tensor_size SR.size(0)*SR.size(1)*SR.size(2)*SR.size(3)acc float(corr)/float(tensor_size)return accdef get_sensitivity(SR,GT,threshold0.5):# Sensitivity RecallSR SR thresholdGT GT torch.max(GT)# TP : True Positive# FN : False NegativeTP ((SR1)(GT1))2FN ((SR0)(GT1))2SE float(torch.sum(TP))/(float(torch.sum(TPFN)) 1e-6) return SEdef get_specificity(SR,GT,threshold0.5):SR SR thresholdGT GT torch.max(GT)# TN : True Negative# FP : False PositiveTN ((SR0)(GT0))2FP ((SR1)(GT0))2SP float(torch.sum(TN))/(float(torch.sum(TNFP)) 1e-6)return SPdef get_precision(SR,GT,threshold0.5):SR SR thresholdGT GT torch.max(GT)# TP : True Positive# FP : False PositiveTP ((SR1)(GT1))2FP ((SR1)(GT0))2PC float(torch.sum(TP))/(float(torch.sum(TPFP)) 1e-6)return PCdef get_F1(SR,GT,threshold0.5):# Sensitivity RecallSE get_sensitivity(SR,GT,thresholdthreshold)PC get_precision(SR,GT,thresholdthreshold)F1 2*SE*PC/(SEPC 1e-6)return F1def get_JS(SR,GT,threshold0.5):# JS : Jaccard similaritySR SR thresholdGT GT torch.max(GT)Inter torch.sum((SRGT)2)Union torch.sum((SRGT)1)JS float(Inter)/(float(Union) 1e-6)return JSdef get_DC(SR,GT,threshold0.5):# DC : Dice CoefficientSR SR thresholdGT GT torch.max(GT)Inter torch.sum((SRGT)2)DC float(2*Inter)/(float(torch.sum(SR)torch.sum(GT)) 1e-6)return DC参考文献
混淆矩阵的概念-CSDN博客