Tensorflow tfrecord 数据集制作及读取

导入必要的库

#!/usr/bin/env python
# -*- coding:utf-8 -*-

import os
import random
import tensorflow as tf
from scipy import misc
import numpy as np
import PIL.Image as Image
import matplotlib.pyplot as plt

定义数据类型转换函数

def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

生成获取图片地址 + 标签的txt文件(顺序)

def mk_txt():
    data_dir = 'C:/Users/ahxie/PycharmProjects/Cats_vs_Dogs/data/train/'  #图片路径
    #data_dir = 'C:/Users/ahxie/PycharmProjects/Cats_vs_Dogs/data/test/'

    # 输出txt文件路径及文件名
    output_path = 'C:/Users/ahxie/PycharmProjects/Cats_vs_Dogs/list_train.txt'
    #output_path = 'C:/Users/ahxie/PycharmProjects/Cats_vs_Dogs/list_test.txt'
    fd = open(output_path, 'w')
    images_list = os.listdir(data_dir)
    for image_name in images_list:
        image_name = image_name.strip()
        labels = image_name.split('.')[0]
        if labels == 'cat':
            fd.write('{}/{} {}\n'.format(data_dir, image_name, '0'))
        elif labels == 'dog':
            fd.write('{}/{} {}\n'.format(data_dir, image_name, '1'))
    fd.close()

把训练集随机分为训练集和验证集(乱序)

def train2val():
    list_path = 'list.txt'  # 全部训练集的txt文本路径
    train_list_path = 'list_train.txt'  # 生成的训练集txt文件名(或路径+文件名)
    val_list_path = 'list_val.txt'  # 生成的验证集txt文件名(或路径+文件名)
    val_per = 0.1  # 验证集占比
    RANDOM_SEED = 0

    fd = open(list_path)
    lines = fd.readlines()
    fd.close()
    num_lines = len(lines)
    NUM_VALIDATION = int(num_lines * val_per)
    random.seed(RANDOM_SEED)
    random.shuffle(lines)
    fd = open(train_list_path, 'w')
    for line in lines[NUM_VALIDATION:]:
        fd.write(line)
    fd.close()
    fd = open(val_list_path, 'w')
    for line in lines[:NUM_VALIDATION]:
        fd.write(line)
    fd.close()

通过读取txt文件,生成tfrecord文件(乱序)

def mk_tfrecord():

    # 读取的txt文件的路径及文件名
    # list_path = '/home/xieqi/traffic_class/dir/retrain/data/list_train.txt'
    list_path = 'C:/Users/ahxie/PycharmProjects/Cats_vs_Dogs/list_train.txt'

    # 生成的tfrecord文件路径及文件名
    # record_name = '/home/xieqi/traffic_class/dir/retrain/data/dir_train.tfrecords'
    record_name = 'C:/Users/ahxie/PycharmProjects/Cats_vs_Dogs/train.tfrecords'

    fd = open(list_path)
    lines = fd.readlines()
    fd.close()
    random.shuffle(lines)
    writer = tf.python_io.TFRecordWriter(record_name)  # 创建一个tfrecord文件
    for item_path in lines:
        item_path = item_path.strip()  # 去掉首尾的换行符和空格
        img_path, img_label = item_path.split(' ') # 拆分图片绝对路径及标签

        img_label = int(img_label)  # 将标签由字符型转化为整型

        # imagel类型为Jpeg.File, 此处也可用cv2.imread()直接读取返回类型为numpy.ndarray,dtype=uint8
        img = Image.open(img_path)

        img = np.asarray(img, np.uint8) # 将Jpeg.File格式转化为ndarray
        img_height, img_width, img_channel = img.shape  # 图片大小
        img_raw = img.tobytes()  # 将图片转化为字符串格式(二进制文件)

        # example对象对label和image数据进行封装
        example = tf.train.Example(features=tf.train.Features(feature={
            "label": _int64_feature(img_label),  # 标签的数据形式为int64
            'img_raw': _bytes_feature(img_raw),  # 图片的数据形式为Bytes
            "img_height": _int64_feature(img_height),
            "img_width": _int64_feature(img_width)}))
        writer.write(example.SerializeToString())  # 序列化为字符串
    writer.close()

读取解码数据, 返回的是Tensor

def read_and_decode(filename_queue, image_W, image_H, batch_size,min_after_dequeue):
    """
    filename_queue:文件名字符串队列
    image_W, image_H:图片的宽和高
    batch_size:批次大小
    min_after_dequeue:队列中剩余图片数目
    """

    reader = tf.TFRecordReader()  # tfrecord文件阅读器--类

    # 从文件名队列中读数据,返回下一个记录(键,值)对
    _, serialized_example = reader.read(filename_queue)

    # 解析读取的样例。
    features = tf.parse_single_example(serialized_example,
        features={'label': tf.FixedLenFeature([], tf.int64),
                'img_raw': tf.FixedLenFeature([], tf.string),
                'img_height': tf.FixedLenFeature([], tf.int64),
                'img_width': tf.FixedLenFeature([], tf.int64)})

    label = tf.cast(features['label'], tf.int32)
    height = tf.cast(features['img_height'], tf.int32) #此处必须为int32,int64无法显示图片
    width = tf.cast(features['img_width'], tf.int32)

    image = tf.decode_raw(features['img_raw'], tf.uint8)  # 将字符串解析成图像对应的像素数组
    # tf.decode_raw 转换成字符串之前是什么类型的数据,此处就要转换成对应的类型
    channel = 3
    image = tf.reshape(image, [height, width, channel])  # 向量---三维矩阵
    image = tf.cast(image, tf.float32)

    # 统一图片大小--缩放或裁剪,生成batch时图片必须为相同大小
    image = tf.image.resize_image_with_crop_or_pad(image, image_W, image_H)

    # 标准化
    #image = tf.image.per_image_standardization(image)

    # 随机选取样本组成batch
    # capacity: 队列大小
    # num_threads: 线程数目
    capacity = min_after_dequeue + 3 * batch_size
    image_batch, label_batch = tf.train.shuffle_batch([image, label], batch_size=batch_size,
                    capacity=capacity, num_threads=64, min_after_dequeue=min_after_dequeue)

    #one_hot
    """
    label_batch = tf.one_hot(label_batch, depth= 2)
    label_batch = tf.cast(label_batch, dtype=tf.int32)
    label_batch = tf.reshape(label_batch, [batch_size, 2])
    """

    return image_batch, label_batch

测试

def test_run():

    # tfrecord存放路径及名称
    tfrecord_filename = 'C:/Users/ahxie/PycharmProjects/Cats_vs_Dogs/train.tfrecords'
    filename_queue = tf.train.string_input_producer([tfrecord_filename], num_epochs=1)
    # 该处得到的为tensor,需要sess.run才能得到实际的数据
    image, label = read_and_decode(filename_queue, 208, 208, 16,min_after_dequeue=1000)
    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
    with tf.Session() as sess:
        sess.run(init_op)
        coord = tf.train.Coordinator()  # 从队列中取数据需要先建立一个Coordinator()
        # 并建立线程开始从队列中读取数据
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        for i in range(1):
            example, lab = sess.run([image, label])  # 取出image和label
            print(type(example))
            print(type(lab))
            for j in range(10):
                print(lab[j])
                img = np.uint8(example[j])
                plt.imshow(img)
                plt.show()

        coord.request_stop()
        coord.join(threads)
        sess.close()

主函数

if __name__ == '__main__':
    #mk_txt() # 生成获取图片地址 + 标签的txt文件
    #mk_tfrecord() # 读取txt文件,生成tfrecord文件
    test_run()
Table of Contents