将cifar10数据转换为图片文件

1.下载cifar10数据集
2.获取文件名队列
3.从文件队列中读取文件
4.启动填充队列线程
5.将文件保存为图片

# copding:utf-8

import cifar10
import cifar10_input
# cifar10.py、cifar10_input.py 为官方文件
# 下载地址:https://github.com/tensorflow/models/tree/master/tutorials/image/cifar10
import tensorflow as tf
import os
import scipy.misc


def inputs_origin(data_dir):
    filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i) for i in range(1, 6)]
    #print(filenames)
    #['cifar10_data/cifar-10-batches-bin/data_batch_1.bin',
    # 'cifar10_data/cifar-10-batches-bin/data_batch_2.bin',
    # 'cifar10_data/cifar-10-batches-bin/data_batch_3.bin',
    # 'cifar10_data/cifar-10-batches-bin/data_batch_4.bin',
    # 'cifar10_data/cifar-10-batches-bin/data_batch_5.bin']

    for f in filenames:
        if not tf.gfile.Exists(f):
            raise ValueError('Failed to find file:' +f )
    # tf.train.string_input_producer tensorflow创建文件名队列
    filename_queue = tf.train.string_input_producer(filenames)
    # cifar10_input.read_cifar10 事先写好的程序 从队列queue读取文件的函数
    read_input = cifar10_input.read_cifar10(filename_queue)
    print(read_input)
    # 将图片转换为实数的形式
    reshape_image = tf.cast(read_input.uint8image, tf.float32)
    # print(reshape_image)
    # Tensor("Cast_1:0", shape=(32, 32, 3), dtype=float32)
    # 返回的reshape_image是一张图片的tensor,每次使用sess.run(reshape_image)会取出一张图片
    return reshape_image

if __name__ == '__main__':
    # 将数据集的默认地址进行修改
    flag = tf.app.flags.FLAGS
    flag.data_dir = 'cifar10_data'
    # 检测数据集有没有被下载,如果有则跳过,如果没有则下载数据
    cifar10.maybe_download_and_extract()

    if not os.path.exists('cifar10_pic/'):
        os.makedirs('cifar10_pic/')

    with tf.Session() as sess:
        reshape_image = inputs_origin('cifar10_data/cifar-10-batches-bin')
        # tf.train.start_queue_runners 启动填充队列的线程
        threads = tf.train.start_queue_runners(sess=sess)
        # 对变量进行初始化
        sess.run(tf.global_variables_initializer())
        for i in range(30):
            # 每次取出一张图片
            image_array = sess.run(reshape_image)
            scipy.misc.toimage(image_array).save('cifar10_pic/%d.jpg' % i)
0

Leave a Reply

Your email address will not be published.