轻松学Pytorch-使用ResNet50实现图像分类

发布时间:2025-05-23 10:26:23 作者:益华网络 来源:undefined 浏览量(1) 点赞(1)
摘要:磐创AI分享 来源 | OpenCV学堂 作者 | gloomyfish Hello大家好,这篇文章给大家详细介绍一下pytorch中最重要的组件torchvision,它包含了常见的数据集、模型架构与预训练模

  磐创AI分享

来源 | OpenCV学堂

作者 | gloomyfish

Hello大家好,这篇文章给大家详细介绍一下pytorch中最重要的组件torchvision,它包含了常见的数据集、模型架构与预训练模型权重文件、常见图像变换、计算机视觉任务训练。可以是说是pytorch中非常有用的模型迁移学习神器。本文将会介绍如何使用torchvison的预训练模型ResNet50实现图像分类。

模型

Torchvision.models包里面包含了常见的各种基础模型架构,主要包括:

AlexNet

VGG

ResNet

SqueezeNet

DenseNet

Inception v3

GoogLeNet

ShuffleNet v2

MobileNet v2

ResNeXt

Wide ResNet

MNASNet

这里我选择了ResNet50,基于ImageNet训练的基础网络来实现图像分类, 网络模型下载与加载如下:

model = torchvision.models.resnet50(pretrained=True).eval().cuda()

tf = transforms.Compose([

            transforms.Resize(256),

            transforms.CenterCrop(224),

            transforms.ToTensor(),

            transforms.Normalize(

            mean=[0.485, 0.456, 0.406],

            std=[0.229, 0.224, 0.225]

        )])

使用模型实现图像分类

这里首先需要加载ImageNet的分类标签,目的是最后显示分类的文本标签时候使用。然后对输入图像完成预处理,使用ResNet50模型实现分类预测,对预测结果解析之后,显示标签文本,完整的代码演示如下:

1with open(imagenet_classes.txtas

 f:

2    labels = [line.strip() for line in

 f.readlines()]

3 4src = cv.imread("D:/images/space_shuttle.jpg"# aeroplane.jpg 5image = cv.resize(src, (224224

))

6image = np.float32(image) / 255.0 7image[:,:,] -= (np.float32(0.485), np.float32(0.456), np.float32(0.406

))

8image[:,:,] /= (np.float32(0.229), np.float32(0.224), np.float32(0.225

))

9image = image.transpose((201

))

10input_x = torch.from_numpy(image).unsqueeze(0

)

11

print(input_x.size())

12

pred = model(input_x.cuda())

13pred_index = torch.argmax(pred, 1

).cpu().detach().numpy()

14

print(pred_index)

15print("current predict class name : %s"%labels[pred_index[0

]])

16cv.putText(src, labels[pred_index[0]], (5050), cv.FONT_HERSHEY_SIMPLEX, 1.0, (00255), 2

)

17cv.imshow("input"

, src)

18cv.waitKey(0

)

19cv.destroyAllWindows()

运行结果如下:

转ONNX支持

在torchvision中的模型基本上都可以转换为ONNX格式,而且被OpenCV DNN模块所支持,所以,很方便的可以对torchvision自带的模型转为ONNX,实现OpenCV DNN的调用,首先转为ONNX模型,直接使用torch.onnx.export即可转换(还不知道怎么转,快点看前面的例子)。转换之后使用OpenCV DNN调用的代码如下:

1with open(imagenet_classes.txtas

 f:

2    labels = [line.strip() for line in

 f.readlines()]

3net = cv.dnn.readNetFromONNX("resnet.onnx"

)

4src = cv.imread("D:/images/messi.jpg")  # aeroplane.jpg 5image = cv.resize(src, (224224

))

6image = np.float32(image) / 255.0 7image[:, :, ] -= (np.float32(0.485), np.float32(0.456), np.float32(0.406

))

8image[:, :, ] /= (np.float32(0.229), np.float32(0.224), np.float32(0.225

))

9blob = cv.dnn.blobFromImage(image, 1.0, (224224), (000), False

)

10

net.setInput(blob)

11

probs = net.forward()

12

index = np.argmax(probs)

13cv.putText(src, labels[index], (5050), cv.FONT_HERSHEY_SIMPLEX, 1.0, (00255), 2

)

14cv.imshow("input"

, src)

15cv.waitKey(0

)

16cv.destroyAllWindows()

 运行结果见上图,这里就不再贴了。

二维码

扫一扫,关注我们

声明:本文由【益华网络】编辑上传发布,转载此文章须经作者同意,并请附上出处【益华网络】及本页链接。如内容、图片有任何版权问题,请联系我们进行处理。

感兴趣吗?

欢迎联系我们,我们愿意为您解答任何有关网站疑难问题!

您身边的【网站建设专家】

搜索千万次不如咨询1次

主营项目:网站建设,手机网站,响应式网站,SEO优化,小程序开发,公众号系统,软件开发等

立即咨询 15368564009
在线客服
嘿,我来帮您!