相亲网站认识的可以做朋友,wordpress建站服务器,专业的东莞网站排名,济南软件网站建设博主一直一来做的都是基于Transformer的目标检测领域#xff0c;相较于基于卷积的目标检测方法#xff0c;如YOLO等#xff0c;其检测速度一直为人诟病。 终于#xff0c;RT-DETR横空出世#xff0c;在取得高精度的同时#xff0c;检测速度也大幅提升。
那么RT-DETR是如…博主一直一来做的都是基于Transformer的目标检测领域相较于基于卷积的目标检测方法如YOLO等其检测速度一直为人诟病。 终于RT-DETR横空出世在取得高精度的同时检测速度也大幅提升。
那么RT-DETR是如何做到的呢
在研究RT-DETR的改进前我们先来了解下DETR类目标检测方法的发展历程吧
首先是DETR该方法作为Transformer在目标检测领域的开山之作一经推出便引发了极大的轰动该方法巧妙的利用Transformer进行特征提取与解码同时通过匈牙利匹配方法完成预测框与真实框的匹配避免了NMS等后处理过程。随后DAB-DETR引入了动态锚框作为查询向量从而对DETR中的100个查询向量进行了解释。Deformable-DETR针对Transformer中自注意力计算复杂度高的问题提出可变形注意力计算即通过可学习的选取少量向量进行注意力计算大幅的降低了计算量。DN-DETR认为匈牙利匹配的二义性是导致DETR训练收敛慢的原因因此提出查询降噪机制即利用先前DAB-DETR中将查询向量解释为锚框的原理给查询向量添加一些噪声来辅助模型收敛最终大幅提升了模型的训练速度。DINO则是在DAB-DETR与DN-DETR的基础上进行进一步的融合与改进。H-DETR为使模型获取更多的正样本特征从而提升检测精度因此提出混合匹配方法在训练阶段包含原始的匈牙利匹配分支与一个一对多的辅助匹配分支而在推理阶段则只有一个匈牙利匹配分支。
然而上述方法尽管已经大幅提升了检测精度降低了计算复杂度但其受Transformer本身高计算复杂度的制约DETR类目标检测方法的实时性始终令人难以满意尤其是相较于YOLO等单阶段目标检测方法其检测速度的确差别巨大。
为了解决这个问题百度提出了RT-DETR该方法依旧是在DETR的基础上改进生成的从论文中给出的实验结果来看该方法无论在检测速度还是检测精度方法都已经超过了YOLOv8实现了真正的实时性。 创新点1高效混合编码器RT-DETR使用了一种高效的混合编码器通过解耦尺度内交互和跨尺度融合来处理多尺度特征。这种独特的基于视觉Transformer的设计降低了计算成本并允许实时物体检测。创新点2IoU感知查询选择RT-DETR通过利用IoU感知的查询选择改进了目标查询初始化。这使得模型能够聚焦于场景中最相关的目标从而提高了检测精度。创新点3自适应推理速度RT-DETR支持通过使用不同的解码器层来灵活调整推理速度而无需重新训练。这种适应性便于在各种实时目标检测场景中的实际应用。
RT-DETR的代码有两个一个是官方提供的代码但该代码功能比较单一只有训练与验证另一个则是集成在YOLOv8中该代码的设计就比较全面了
环境部署
conda create -n rtdetr python3.8
conda activate rtdetr
conda install pytorch2.0.1 torchvision0.15.2 torchaudio2.0.2 pytorch-cuda11.7 -c pytorch -c nvidia
cd RT-DETR-main/rtdetr_pytorch //这个路径根据你自己的改
pip install -r requirement.txt该算法的环境为pytorch2.0.1注意尽量要用pytorch2以上的版本否则可能会报错
AttributeError: module torchvision has no attribute disable_beta_transforms_warning官方模型训练
参数配置
该算法的配置封装较好我们只需要修改配置即可train.py指定要使用的骨干网络。
parser.add_argument(--config, -c, default/rtdetr_pytorch\configs/rtdetr/rtdetr_r18vd_6x_coco.yml,typestr, )修改数据集配置文件RT-DETR-main\rtdetr_pytorch\configs\dataset\coco_detection.yml 修改训练集与测试集路径同时修改类别数。 随后便可以开启训练该文件中指定 epochs
RT-DETR-main\rtdetr_pytorch\configs\rtdetr\include\optimizer.yml首次训练需要下载骨干网络的预训练模型 在这里博主使用ResNet18作为骨干特征提取网络
训练结果
开始运行查看GPU使用情况此时的batch-size8可以看到显存占用4.5G左右相较于博主先前提出的方法或者DINO其显存占用少了许多DINO的batch-size2时的显存占用将近16G. 训练了24轮的结果。 训练的结果会保存在output文件夹内 官方模型推理
在进行模型推理前需要先导出模型在官方代码的tools文件夹下有个export_onnx.py文件只需要指定配置文件与训练好的模型文件
parser.add_argument(--config, -c, default/rtdetr_pytorch\configs/rtdetr/rtdetr_r18vd_6x_coco.yml,typestr, )
parser.add_argument(--resume, -r, defaultrtdetr_pytorch/tools\output/rtdetr_r18vd_6x_coco\checkpoint0024.pth,typestr, )导出的文件是onnx格式 ONNXOpen Neural Network Exchange是一种开放式的文件格式用于存储和交换训练好的机器学习模型。它使得不同的人工智能框架如PyTorch、TensorFlow可以共享模型促进了模型在不同平台之间的迁移和复用。ONNX文件采用Protobuf序列化技术进行存储具有高效、紧凑的特点。 随后开始推理代码如下
import torch
import onnxruntime as ort
from PIL import Image, ImageDraw
from torchvision.transforms import ToTensor
if __name__ __main__:##################classes [car,truck,bus]################### print(onnx.helper.printable_graph(mm.graph))#############img_path 1.jpg#############im Image.open(img_path).convert(RGB)im im.resize((640, 640))im_data ToTensor()(im)[None]print(im_data.shape)size torch.tensor([[640, 640]])sess ort.InferenceSession(model.onnx)import timestart time.time()output sess.run(# output_names[labels, boxes, scores],output_namesNone,input_feed{images: im_data.data.numpy(), orig_target_sizes: size.data.numpy()})end time.time()fps 1.0 / (end - start)print(fps)labels, boxes, scores outputdraw ImageDraw.Draw(im)thrh 0.6for i in range(im_data.shape[0]):scr scores[i]lab labels[i][scr thrh]box boxes[i][scr thrh]print(i, sum(scr thrh))#print(lab)print(fbox:{box})for l, b in zip(lab, box):draw.rectangle(list(b), outlinered,)print(l.item())draw.text((b[0], b[1] - 10), textstr(classes[l.item()]), fillblue, )#############im.save(2.jpg)#############YOLOv8集成RT-DETR训练
在YOLOv8中给出了YOLO先前的诸多版本此外还包含RT-DETR 其运行环境与官方的相同这里就不再赘述了另外如果想要了解YOLO及其集成算法的更多功能可以查看
https://docs.ultralytics.com/ultralytics集成了多种算法已有将YOLO目标检测算法大一统的趋势涵盖语义分割、目标检测、姿势估计、分类、跟踪等多个任务。
数据集配置
YOLO版本的RT-DETR的数据集支持的数据集格式有多种这里博主选用的是YOLO格式的
cocoimagestrain2017val2017lablestrain2017val2017开始训练
随后在根目录下新建一个run.py文件文件中写入如下代码
from ultralytics.models import RTDETR
if __name__ __main__:model RTDETR(modelultralytics/cfg/models/rt-detr/rtdetr-l.yaml)#model.load(rtdetr-l.pt) # 不使用预训练权重可注释掉此行model.train(pretrainedTrue, dataultralytics\cfg\datasets\cocomine.yaml, epochs200, batch16, device0, imgsz320, workers2,cacheFalse,)运行报错
OMP: Error #15: Initializing libiomp5md.dll, but found libiomp5md.dll already initialized.解决方法这是由于Anconda的torch中的某个文件与环境中的某个文件冲突导致的找到环境中的文件 环境路径
D:\softwares\Anconda\envs\detr\Library\bin将下面的文件给重命名即可。 随后便开始训练了如下 至此RT-DETR的训练过程便完成了。博主设置训练200个epoch但考虑到接下来的任务因此训练到一半就停止了生成的文件存放在run文件中如下 YOLOv8集成RT-DETR推理
在YOLOv8集成的RT-DETR中其设计就非常完备了我们只需要新建一个predict.py里面的内容如下 这里的images即为一个文件夹里面可以放入多张图像save代表保存
modelRTDETR(runs\detect/train\weights/best.pt)
model.predict(sourceimages,saveTrue)推理结果、保存路径与推理速度都会显示在下面 当然我们还可以指定conf参数即置信度可以帮我们筛选一下结果设置置信度为0.6此时原本的汽车就不再框选了。 视频推理
视频推理也很简单只需要将原来的图像换为视频即可
modelRTDETR(runs\detect/train\weights/best.pt)
model.predict(sourceimages/1.mp4,saveTrue,conf0.6)目标跟踪
在先前的目标跟踪中都是通过先检测后跟踪的方式如采用YOLOv5DeepSort的方式进行目标跟踪而在YOLOv8中他将该功能集成到里面我们可以直接采用执行跟踪任务的方式完成目标跟踪。
from ultralytics.models import RTDETR
modelRTDETR(runs\detect/train\weights/best.pt)
results model.track(sourceimages/1.mp4, conf0.3, iou0.5,saveTrue)RT-DETR目标跟踪视频 轨迹绘制
from collections import defaultdictimport cv2
import numpy as np
from ultralytics import RTDETR# Load the YOLOv8 model
modelRTDETR(D:\graduate\programs\yolo8/ultralytics-main/runs\detect/train\weights/best.pt)# Open the video file
video_path images/1.mp4
cap cv2.VideoCapture(video_path)# Store the track history
track_history defaultdict(lambda: [])# Loop through the video frames
while cap.isOpened():# Read a frame from the videosuccess, frame cap.read()if success:# Run YOLOv8 tracking on the frame, persisting tracks between framesresults model.track(frame, persistTrue)# Get the boxes and track IDsboxes results[0].boxes.xywh.cpu()track_ids results[0].boxes.id.int().cpu().tolist()# Visualize the results on the frameannotated_frame results[0].plot()# Plot the tracksfor box, track_id in zip(boxes, track_ids):x, y, w, h boxtrack track_history[track_id]track.append((float(x), float(y))) # x, y center pointif len(track) 30: # retain 90 tracks for 90 framestrack.pop(0)# Draw the tracking linespoints np.hstack(track).astype(np.int32).reshape((-1, 1, 2))cv2.polylines(annotated_frame, [points], isClosedFalse, color(230, 230, 230), thickness10)# Display the annotated framecv2.imshow(YOLOv8 Tracking, annotated_frame)# Break the loop if q is pressedif cv2.waitKey(1) 0xFF ord(q):breakelse:# Break the loop if the end of the video is reachedbreak# Release the video capture object and close the display window
cap.release()
cv2.destroyAllWindows()