资源经验分享目标检测之数据集增强(旋转)

目标检测之数据集增强(旋转)

2019-11-06 | |  118 |   0

原标题:目标检测之数据集增强(旋转)

原文来自:CSDN      原文链接:https://blog.csdn.net/weixin_42486139/article/details/102896267


项目背景:

最近在做国外身份证的项目,拍摄风格多样。需求:检测出身份证的位置并判断身份证的旋转角度,根据角度和位置校正身份证。

由此可以将一张标注样本旋转三个角度(90,180,270),做数据增广。相应地,xml文件的坐标和label要与旋转后的图像样本对应

1.png

xml文件如下,角度为0的label为TYPE:B_2D_ANGLE0, 相应地,若旋转90度,label为TYPE:B_2D_ANGLE90,180度与270度类似。

2.png

 

# -*- coding: utf-8 -*-
"""
Created on Fri Sep 27 13:53:47 2019
@author: mandy
"""
import os
import cv2
import time
import numpy as np
import xml.dom.minidom as xmldom
 
def parse_xml(fn):
    xml_file = xmldom.parse(fn)
    eles = xml_file.documentElement
    #print(eles.tagName)
    label = eles.getElementsByTagName("name")[0].firstChild.data
    xmin = eles.getElementsByTagName("xmin")[0].firstChild.data
    xmax = eles.getElementsByTagName("xmax")[0].firstChild.data
    ymin = eles.getElementsByTagName("ymin")[0].firstChild.data
    ymax = eles.getElementsByTagName("ymax")[0].firstChild.data
    #print(xmin, xmax, ymin, ymax)
    return label,xmin, ymin, xmax, ymax
 
def rotate_img(image,angle):   #自定义旋转函数,使用opencv自带的旋转函数旋转后会有黑边
    (h, w) = image.shape[:2]
    center = (w // 2, h // 2)
    M = cv2.getRotationMatrix2D(center, angle, 1.0) #获得旋转矩阵
   # print(M)
    cos = np.abs(M[0, 0])
    sin = np.abs(M[0, 1])
    nW = int((h * sin) + (w * cos))
    nH = int((h * cos) + (w * sin))
    M[0, 2] += (nW / 2) - center[0]
    M[1, 2] += (nH / 2) - center[1]
    rotated=cv2.warpAffine(image, M, (nW, nH))    
    return rotated
 
def convert_coo(j,loc_re):   #修改旋转后的坐标及label
    angle_list=[0,90,180,270,0,90,180,270]
    if int(initial_angle) in angle_list:
        idx=angle_list.index(int(initial_angle))
        #print(idx)
        add_idx=int(j/90)
        idx+=add_idx
    w_list=[new_label+str(angle_list[idx])]
   # print(w_list)
    if j==90:
        for l in loc_re:
            x_r,y_r=l[1],w_src-l[0]
            w_list.append(x_r)
            w_list.append(y_r)
    elif j==180:
        for l in loc_re:
            x_r,y_r=w_src-l[0],h_src-l[1]
            w_list.append(x_r)
            w_list.append(y_r)
    elif j==270:
        for l in loc_re:
            x_r,y_r=h_src-l[1],l[0]
            w_list.append(x_r)
            w_list.append(y_r) 
    return w_list
p='../all_img'
files=[os.path.splitext(i)[0] for i in os.listdir(p) if '.xml' in i]
angle_list=[90,180,270]
for i in range(len(files)):
    print('-'*50)
    print(files[i]+'.jpg')
    print(files[i]+'.xml')
    img=cv2.imread(os.path.join(p,files[i]+'.jpg'))
    w_src=img.shape[1]
    h_src=img.shape[0]
    label,xmin, ymin, xmax, ymax=parse_xml(os.path.join(p,files[i]+'.xml'))
    new_label=label[:15] 
    initial_angle=label[15:]
    print(new_label,initial_angle)   
        
    loc_re=np.array([float(xmin), float(ymin), float(xmax), float(ymax)]).reshape((2,2))
    for j in angle_list:
        rotated = rotate_img(img, j)
        w_list=convert_coo(j,loc_re)  #得到的只是旋转前的左上右下对应的旋转后的坐标,并不是旋转后图像的左上右下坐标
        #坐上、右下坐标
        res=[None]*5
        res[0]=w_list[0]
        res[1]=((w_list[1]+w_list[3])/2)-(abs(w_list[3]-w_list[1])/2)
        res[2]=((w_list[2]+w_list[4])/2)-(abs(w_list[4]-w_list[2])/2)
        res[3]=((w_list[1]+w_list[3])/2)+(abs(w_list[3]-w_list[1])/2)
        res[4]=((w_list[2]+w_list[4])/2)+(abs(w_list[4]-w_list[2])/2)
        res=[str(k) for k in res]
        
        cv2.imwrite(os.path.join('rotate_img',files[i]+'_'+str(j)+'.jpg'),rotated)    
        f=open(os.path.join('rotate_txt',files[i]+'_'+str(j)+'.txt'),'w')
        f.write(' '.join(res))
   # print(xmin)

旋转后的label和坐标信息我保存在txt文件,然后再将txt转为xml文件,当然也可以同时转换

txt转xml代码:

from xml.dom import minidom
import os
import cv2
jpg_list=os.listdir('./addRotate/rotate_img')
#txt_list=os.listdir(r'label_class/')
for filename0 in jpg_list :
    print(filename0) 
    xml_filename=os.path.splitext(filename0)[0]
    jpg_dirtory=os.path.join(r'./addRotate/rotate_img' ,filename0)   #jpg文件路径
    txt_dirtory=os.path.join(r'./addRotate/rotate_txt',xml_filename+'.txt')  #txt文件路径
    img_name=jpg_dirtory.split('/')[-1]
    floder=jpg_dirtory#.split('/') 
    im = cv2.imread(jpg_dirtory)
    w = im.shape[1]  
    h = im.shape[0]
    d = im.shape[2]
    doc = minidom.Document()   #创建DOM树对象
    annotation = doc.createElement('annotation')   #创建子节点
    doc.appendChild(annotation)                    #annotation作为doc树的子节点 
    folder = doc.createElement('folder')            
    folder.appendChild(doc.createTextNode(floder))  #文本节点作为floder的子节点
    annotation.appendChild(folder)                 #folder作为annotation的子节点
    filename = doc.createElement('filename')
    filename.appendChild(doc.createTextNode(img_name))
    annotation.appendChild(filename)
    source = doc.createElement('source')
    database = doc.createElement('database')
    database.appendChild(doc.createTextNode("Unknown"))
    source.appendChild(database)
    annotation.appendChild(source)    
    size = doc.createElement('size')
    width = doc.createElement('width')
    width.appendChild(doc.createTextNode("%d" % w))
    size.appendChild(width)
    height = doc.createElement('height')
    height.appendChild(doc.createTextNode("%d" % h))
    size.appendChild(height)
    depth = doc.createElement('depth')
    depth.appendChild(doc.createTextNode("%d" % d))
    annotation.appendChild(size)
    segmented = doc.createElement('segmented')
    segmented.appendChild(doc.createTextNode("0"))
    annotation.appendChild(segmented)
    txtLabel = open(txt_dirtory, 'r')
    boxes = txtLabel.read().splitlines()  #splitlines代替readlines去掉换行符
    for box in boxes:
        box = box.split(' ')
        object = doc.createElement('object')
        nm = doc.createElement('name')
        nm.appendChild(doc.createTextNode(box[0]))
        object.appendChild(nm)
        pose = doc.createElement('pose')
        pose.appendChild(doc.createTextNode("Unspecified"))
        object.appendChild(pose)
        truncated = doc.createElement('truncated') 
        truncated.appendChild(doc.createTextNode("0")) 
        object.appendChild(truncated) 
        difficult = doc.createElement('difficult')
        difficult.appendChild(doc.createTextNode("0"))
        object.appendChild(difficult) 
        bndbox = doc.createElement('bndbox')   
        xmin = doc.createElement('xmin')    
        xmin.appendChild(doc.createTextNode(box[1]))
        bndbox.appendChild(xmin)    
        ymin = doc.createElement('ymin')        
        ymin.appendChild(doc.createTextNode(box[2]))       
        bndbox.appendChild(ymin)    
        xmax = doc.createElement('xmax')       
        xmax.appendChild(doc.createTextNode(box[3]))       
        bndbox.appendChild(xmax)    
        ymax = doc.createElement('ymax')    
        ymax.appendChild(doc.createTextNode(box[4]))       
        bndbox.appendChild(ymax) 
        object.appendChild(bndbox) 
        annotation.appendChild(object)
       
        p=r'Annotations/'+xml_filename+'.xml'   #xml文件保存路径    
        savefile = open(p, 'w')    
        savefile.write(doc.toprettyxml())   
        savefile.close()

旋转90度后的样本:

3.png

txt文件:

xml文件:

4.png

 

免责声明:本文来自互联网新闻客户端自媒体,不代表本网的观点和立场。

合作及投稿邮箱:E-mail:editor@tusaishared.com

上一篇:python 程序员进阶之路:从新手到高手的100个模块

下一篇:LintCode 题目:斐波纳契数列简单

用户评价
全部评价

热门资源

  • Python 爬虫(二)...

    所谓爬虫就是模拟客户端发送网络请求,获取网络响...

  • TensorFlow从1到2...

    原文第四篇中,我们介绍了官方的入门案例MNIST,功...

  • TensorFlow从1到2...

    “回归”这个词,既是Regression算法的名称,也代表...

  • 机器学习中的熵、...

    熵 (entropy) 这一词最初来源于热力学。1948年,克...

  • TensorFlow2.0(10...

    前面的博客中我们说过,在加载数据和预处理数据时...