#!/usr/bin/env python
# coding: utf-8

# # MindRecord格式转换
# 
# [![下载Notebook](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.9.0/resource/_static/logo_notebook.svg)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/r2.9.0/tutorials/zh_cn/dataset/mindspore_record.ipynb)&emsp;
# [![下载样例代码](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.9.0/resource/_static/logo_download_code.svg)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/r2.9.0/tutorials/zh_cn/dataset/mindspore_record.py)&emsp;
# [![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.9.0/resource/_static/logo_source.svg)](https://atomgit.com/mindspore/docs/blob/r2.9.0/tutorials/source_zh_cn/dataset/record.ipynb)
# 
# MindSpore可以将用于训练网络模型的数据集转换为特定的数据格式（MindSpore Record），便于数据的保存和加载。其目标是归一化用户数据集，并通过[mindspore.dataset.MindDataset](https://www.mindspore.cn/docs/zh-CN/r2.9.0/api_python/dataset/mindspore.dataset.MindDataset.html)接口实现数据的读取，用于训练过程。
# 
# ![conversion](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.9.0/tutorials/source_zh_cn/dataset/images/data_conversion_concept.png)
# 
# 此外，MindSpore还针对部分数据场景进行了性能优化。使用MindSpore Record数据格式可以减少磁盘IO和网络IO开销，从而获得更好的使用体验。
# 
# MindSpore Record数据格式具备以下特征：
# 
# 1. 实现数据统一存储和访问，使训练时数据读取更加简便。
# 2. 支持数据聚合存储和高效读取，便于数据管理和移动。
# 3. 提供高效的数据编解码操作，用户对数据操作无感知。
# 4. 可以灵活控制数据切分的分区大小，实现分布式数据处理。
# 
# ## Record文件结构
# 
# 如下图所示，MindSpore Record文件由数据文件和索引文件组成。
# 
# ![MindSpore Record](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.9.0/tutorials/source_zh_cn/dataset/images/mindrecord.png)
# 
# 其中，数据文件包含文件头、标量数据页和块数据页，用于存储用户归一化后的训练数据。具体用途如下：
# 
# - **文件头**：MindSpore Record文件的元信息。主要用于存储文件头大小、标量数据页大小、块数据页大小、Schema信息、索引字段、统计信息、文件分区信息、标量数据与块数据对应关系等。
# - **标量数据页**：主要用于存储整型、字符串、浮点型等标量类型数据，如图像的Label、文件名、长宽等信息。
# - **块数据页**：主要用于存储二进制串、NumPy数组等数据，如二进制图像文件本身、文本转换后的字典等。
# 
# 索引文件则包含基于标量数据（如图像Label、图像文件名等）生成的索引信息，便于检索和统计数据集信息。
# 
# > - 单个MindSpore Record文件建议小于20G。对于大数据集，建议分片存储为多个MindSpore Record文件。
# > - 数据文件和索引文件均暂不支持重命名操作。
# 
# ## 转换成Record格式
# 
# 下面主要介绍如何将CV类数据和NLP类数据转换为MindSpore Record文件格式，并通过`MindDataset`接口读取。
# 
# ### 转换CV类数据
# 
# 本示例以包含100条记录的CV数据集为例，介绍如何将其转换为MindSpore Record格式，并使用`MindDataset`接口读取。
# 
# 具体来说，需要创建一个包含100张图片的数据集并保存。每个样本包含`file_name`（字符串）、`label`（整型）、 `data`（二进制）三个字段，然后使用`MindDataset`接口读取该MindSpore Record文件。
# 
# 1. 生成100张图像，并转换成MindSpore Record文件格式。

# In[1]:


from PIL import Image
from io import BytesIO
from mindspore.mindrecord import FileWriter

file_name = "test_vision.mindrecord"
# 定义包含的字段
cv_schema = {"file_name": {"type": "string"},
             "label": {"type": "int32"},
             "data": {"type": "bytes"}}

# 声明MindSpore Record文件格式
writer = FileWriter(file_name, shard_num=1, overwrite=True)
writer.add_schema(cv_schema, "it is a cv dataset")
writer.add_index(["file_name", "label"])

# 创建数据集
data = []
for i in range(100):
    sample = {}
    white_io = BytesIO()
    Image.new('RGB', ((i+1)*10, (i+1)*10), (255, 255, 255)).save(white_io, 'JPEG')
    image_bytes = white_io.getvalue()
    sample['file_name'] = str(i+1) + ".jpg"
    sample['label'] = i+1
    sample['data'] = white_io.getvalue()

    data.append(sample)
    if i % 10 == 0:
        writer.write_raw_data(data)
        data = []

if data:
    writer.write_raw_data(data)

writer.commit()


# 若上述示例运行无报错，则说明数据集转换成功。
# 
# 2. 通过`MindDataset`接口读取MindSpore Record格式文件。

# In[2]:


from mindspore.dataset import MindDataset
from mindspore.dataset.vision import Decode

# 读取MindSpore Record格式文件
data_set = MindDataset(dataset_files=file_name)
decode_op = Decode()
data_set = data_set.map(operations=decode_op, input_columns=["data"], num_parallel_workers=2)

# 样本计数
print("Got {} samples".format(data_set.get_dataset_size()))


# ### 转换NLP类数据集
# 
# 本示例首先创建一个包含100条记录的文本数据，然后转换为MindSpore Record文件格式。每个样本包含八个字段，均为整型数组。最后，使用`MindDataset`接口读取该MindSpore Record文件。
# 
# > 为便于展示，此处略去了将文本转换成字典序的预处理过程。
# 
# 1. 生成100条文本数据，并转换成MindSpore Record文件格式。

# In[3]:


import numpy as np
from mindspore.mindrecord import FileWriter

# 输出的MindSpore Record文件完整路径
file_name = "test_text.mindrecord"

# 定义样本数据包含的字段
nlp_schema = {"source_sos_ids": {"type": "int64", "shape": [-1]},
              "source_sos_mask": {"type": "int64", "shape": [-1]},
              "source_eos_ids": {"type": "int64", "shape": [-1]},
              "source_eos_mask": {"type": "int64", "shape": [-1]},
              "target_sos_ids": {"type": "int64", "shape": [-1]},
              "target_sos_mask": {"type": "int64", "shape": [-1]},
              "target_eos_ids": {"type": "int64", "shape": [-1]},
              "target_eos_mask": {"type": "int64", "shape": [-1]}}

# 声明MindSpore Record文件格式
writer = FileWriter(file_name, shard_num=1, overwrite=True)
writer.add_schema(nlp_schema, "Preprocessed nlp dataset.")

# 创建虚拟数据集
data = []
for i in range(100):
    sample = {"source_sos_ids": np.array([i, i + 1, i + 2, i + 3, i + 4], dtype=np.int64),
              "source_sos_mask": np.array([i * 1, i * 2, i * 3, i * 4, i * 5, i * 6, i * 7], dtype=np.int64),
              "source_eos_ids": np.array([i + 5, i + 6, i + 7, i + 8, i + 9, i + 10], dtype=np.int64),
              "source_eos_mask": np.array([19, 20, 21, 22, 23, 24, 25, 26, 27], dtype=np.int64),
              "target_sos_ids": np.array([28, 29, 30, 31, 32], dtype=np.int64),
              "target_sos_mask": np.array([33, 34, 35, 36, 37, 38], dtype=np.int64),
              "target_eos_ids": np.array([39, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
              "target_eos_mask": np.array([48, 49, 50, 51], dtype=np.int64)}
    data.append(sample)

    if i % 10 == 0:
        writer.write_raw_data(data)
        data = []

if data:
    writer.write_raw_data(data)

writer.commit()


# 2. 通过`MindDataset`接口读取MindSpore Record格式文件。

# In[4]:


from mindspore.dataset import MindDataset

# 读取MindSpore Record格式文件
data_set = MindDataset(dataset_files=file_name, shuffle=False)

# 样本计数
print("Got {} samples".format(data_set.get_dataset_size()))

# 打印部分数据
count = 0
for item in data_set.create_dict_iterator(output_numpy=True):
    print("source_sos_ids:", item["source_sos_ids"])
    count += 1
    if count == 10:
        break


# ## Dataset转存MindRecord
# 
# MindSpore提供常用数据集的转换工具类，能够将常用的数据集转换为MindSpore Record文件格式。
# 
# > 更多数据集转换的详细说明参考[API文档](https://www.mindspore.cn/docs/zh-CN/r2.9.0/api_python/mindspore.mindrecord.html)。
# 
# ### 转存CIFAR-10数据集
# 
# 用户可以通过[mindspore.dataset.Dataset.save](https://www.mindspore.cn/docs/zh-CN/r2.9.0/api_python/dataset/dataset_method/operation/mindspore.dataset.Dataset.save.html)方法，将CIFAR-10原始数据转换为MindSpore Record，并使用`MindDataset`接口读取。
# 
# 1. 下载[CIFAR-10数据集](https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz)，并使用`Cifar10Dataset`加载。

# In[5]:


from download import download
from mindspore.dataset import Cifar10Dataset

url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/cifar-10-binary.tar.gz"

path = download(url, "./", kind="tar.gz", replace=True)
dataset = Cifar10Dataset("./cifar-10-batches-bin/")  # 加载数据


# 2. 调用`Dataset.save`接口，将CIFAR-10数据集转存为MindSpore Record文件格式。

# In[7]:


dataset.save("cifar10.mindrecord")


# 3. 通过`MindDataset`接口读取MindSpore Record格式文件。

# In[8]:


import os
from mindspore.dataset import MindDataset

# 读取MindSpore Record文件格式
data_set = MindDataset(dataset_files="cifar10.mindrecord")

# 样本计数
print("Got {} samples".format(data_set.get_dataset_size()))

if os.path.exists("cifar10.mindrecord") and os.path.exists("cifar10.mindrecord.db"):
    os.remove("cifar10.mindrecord")
    os.remove("cifar10.mindrecord.db")

