基于飞桨复现语义分割网络HRNet,实现瓷砖缺陷检测
- 2020-10-30 10:28:00
- 刘大牛 转自文章
- 231
内容简介
PaddleSeg介绍 HRNet网络分析 基于PaddleSeg使用HRNet网络进行瓷砖缺陷检测
PaddleSeg介绍
HRNet网络分析
始终保持高分辨率表征
残差单元
类似全连接的阶段性特征融合
简单明了的解码过程
#获取各阶段的通道数{18,36,72,144} channels_2 = cfg.MODEL.HRNET.STAGE2.NUM_CHANNELS channels_3 = cfg.MODEL.HRNET.STAGE3.NUM_CHANNELS channels_4 = cfg.MODEL.HRNET.STAGE4.NUM_CHANNELS #获取各阶段残差单元的循环次数{1,4,3} num_modules_2 = cfg.MODEL.HRNET.STAGE2.NUM_MODULES num_modules_3 = cfg.MODEL.HRNET.STAGE3.NUM_MODULES num_modules_4 = cfg.MODEL.HRNET.STAGE4.NUM_MODULES #步长为2的跨步卷积 f=3*3 x = conv_bn_layer(input=input,filter_size=3,num_filters=64,stride=2,if_act=True,name=layer1_1) #步长为2的跨步卷积 f=3*3 x = conv_bn_layer(input=x,filter_size=3,num_filters=64,stride=2,if_act=True,name=layer1_2) #执行1个残差单元 la1 = layer1(x, name=layer2) #根据输入中最低分辨率特征图生成 低分辨率特征图,并规范特征图的通道数 tr1 = transition_layer([la1], [256], channels_2, name=tr1) #执行4次残差卷积,并在每次残差单元结束时进行特征融合 st2 = stage(tr1, num_modules_2, channels_2, name=st2) #根据输入中最低分辨率特征图生成 低分辨率特征图,并规范特征图的通道数 tr2 = transition_layer(st2, channels_2, channels_3, name=tr2) #执行3次残差卷积,并在每次残差单元结束时进行特征融合 st3 = stage(tr2, num_modules_3, channels_3, name=st3) #根据输入中最低分辨率特征图生成 低分辨率特征图,并规范特征图的通道数 tr3 = transition_layer(st3, channels_3, channels_4, name=tr3) #执行1次残差卷积,并在每次残差单元结束时进行特征融合 st4 = stage(tr3, num_modules_4, channels_4, name=st4) shape = st4[0].shape ##获取st4[0]宽高,并进行双线性插值 height, width = shape[-2], shape[-1] st4[1] = fluid.layers.resize_bilinear(st4[1], out_shape=[height, width]) st4[2] = fluid.layers.resize_bilinear(st4[2], out_shape=[height, width]) st4[3] = fluid.layers.resize_bilinear(st4[3], out_shape=[height, width]) #特征通道合并 out = fluid.layers.concat(st4, axis=1) #求总通道数 last_channels = sum(channels_4) #使用1*1卷积进行跨通道的特征融合 out = conv_bn_layer(input=out,filter_size=1,num_filters=last_channels,stride=1,if_act=True,name=conv-2) #使用1*1卷积进行最后的像素分类 out = fluid.layers.conv2d(input=out,num_filters=num_classes,filter_size=1,stride=1,padding=0,act=None, param_attr=ParamAttr(initializer=MSRA(), name=conv-1_weights),bias_attr=False) #恢复至网络输入的大小 out = fluid.layers.resize_bilinear(out, input.shape[2:])
基于PaddleSeg使用
HRNet 进行瓷砖缺陷检测
1. 数据准备
表面缺陷检测是筛选不合格产品的核心过程,但该过程很少能自动完成。 据记载,在世界上最大的瓷砖生产基地浙江省的瓷砖厂,有近3/4的工人在检查产品质量。 为了减轻人类的劳动强度,已经提出了许多图像处理 技术来尝试这样的检查任务。 瓷砖的自动损伤检测存在纹理复杂、缺陷形状多样、瓷砖光照条件随机性等几个瓶颈问题。 目标缺陷如气孔、裂纹、断裂、磨损如图所示。
2. 环境搭建
pip install -U paddlepaddle-gpu
git clone https://github.com/PaddlePaddle/PaddleSeg
cd PaddleSeg pip install -r requirements.txt
3. 标签数据
PaddleSeg采用单通道的标注图片,每一种像素值代表一种类别,像素标注类别需要从0开始递增,例如0,1,2,3表示有4种类别。3. 标签数据
标注图像请使用PNG无损压缩格式的图片,标注类别最多为256类。 PaddleSeg支持灰度标注同时也支持伪彩色标注。 PaddleSeg支持灰度标注转换为伪彩色标注,如需转换成伪彩色标注图,可使用PaddleSeg自带的的转换工具
4. 模型选择 参数 配置
模型选择 :根据自己的需求选择合适的模型进行训练。本文选择HRNet-W18作为训练模型。 预训练模型: pretrained_model/download_model.py中提供了相应的预训练模型下载地址,可以根据自己的需求在其中寻找相应的预训练模型,如不存在,可以按照同样的格式添加对应的模型名称与下载地址。 参数 配置:参数 由config.py和hrnet_Magnetic.yaml共同决定,.yaml文件的优先级高于config.py 。
DATASET:关于数据集的相关配置,如类别数、训练数据列表、测试数据列表 MODEL:模型配置: MODEL_NAME: "hrnet" 模型名称 HRNET:配置各个stage中不同分辨率特征图的通道数
MULTI_LOSS_WEIGHT:模型输出权重 配置 TRAIN_CROP_SIZE:训练时输入数据大小 EVAL_CROP_SIZE:测试时输入数据大小 BATCH_SIZE:输入网络中的BATCH_SIZE,需要适配显存 SNAPSHOT_EPOCH: 阶段性保存EPOCH NUM_EPOCHS:总的训练轮数 LOSS:损失函数 类别 LR:学习率
5. 参数 校验
python pdseg/check.py --cfg ./configs/hrnet_Magnetic.yaml
PaddleSeg/saved_model/unet_optic/best_model
python pdseg/train.py --use_gpu --cfg ./configs/hrnet_Magnetic.yaml --do_eval
7. 模型评估
python pdseg/train.py --use_gpu --cfg ./configs/hrnet_Magnetic.yaml --do_eval [EVAL]#image=81 acc=0.9853 IoU=0.8434 [EVAL]Category IoU: [0.9842 0.7891 0.8468 0.7010 0.9258 0.8136] [EVAL]Category Acc: [0.9927 0.8871 0.9407 0.9106 0.9597 0.8829] [EVAL]Kappa:0.9037
8. 结果可视化
python pdseg/vis.py --use_gpu --cfg ./configs/hrnet_Magnetic.yaml
得到可视化结果之后,可以使用如下代码展示可视化结果:
import matplotlib.pyplot as plt import os import cv2 # 定义显示函数 def display(img_name): image_dir = os.path.join("./dataset/Magnetic/images", img_name.split(".")[0]+".jpg") label_dir = os.path.join("./dataset/Magnetic/color",img_name) mask_dir = os.path.join("./visual", img_name) img_dir = [image_dir, label_dir, mask_dir] plt.figure(figsize=(15, 15)) title = [Image, label, Predict] for i in range(len(title)): plt.subplot(1, len(title), i+1) plt.title(title[i]) if i==0: img_rgb = cv2.imread(img_dir[i]) else: img = cv2.imread(img_dir[i]) b,g,r = cv2.split(img) img_rgb = cv2.merge([r,g,b]) plt.imshow(img_rgb) plt.axis(off) plt.show() # 注:第一次运行可能无法显示,再运行一次即可。 img_list=os.listdir("./visual") for img_name in img_list: display(img_name)
心得体会
官网地址:https://www.paddlepaddle.org.cn
发表评论
文章分类
联系我们
联系人: | 透明七彩巨人 |
---|---|
Email: | weok168@gmail.com |