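Loading Data in PyTorch
PyTorch offers a broad set of building blocks for neural networks and a simple, intuitive, and stable API. It also includes packages for preparing and loading common datasets for your model.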
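Introduction
At the heart of data loading in PyTorch is the torch.utils.data.DataLoader class, which represents a Python iterable over a dataset. The PyTorch domain libraries provide built-in, high-quality datasets that can be used through torch.utils.data.Dataset. These datasets are available in torchvision, torchaudio, and torchtext.
In this recipe we use the Yesno dataset from torchaudio.datasets. We will demonstrate how to load data efficiently from a PyTorch Dataset into a PyTorch DataLoader.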
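To make the relationship between a Dataset and a DataLoader concrete, here is a minimal sketch that does not need any download. The class name ToyAudioDataset, its parameters, and all of its values are hypothetical and synthetic; it only imitates the (waveform, sample_rate, label) layout and is not how torchaudio implements its datasets.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class ToyAudioDataset(Dataset):
    """Hypothetical stand-in for an audio dataset; all values are synthetic."""

    def __init__(self, num_items=4, num_samples=16, sample_rate=8000):
        generator = torch.Generator().manual_seed(0)
        # One mono waveform of shape (1, num_samples) per item
        self.waveforms = [
            torch.randn(1, num_samples, generator=generator)
            for _ in range(num_items)
        ]
        self.sample_rate = sample_rate
        # Alternating 0 / 1 label per item
        self.labels = [i % 2 for i in range(num_items)]

    def __len__(self):
        # Number of items the DataLoader can draw from
        return len(self.waveforms)

    def __getitem__(self, index):
        # Return one sample as a tuple, like the dataset used in this recipe
        return self.waveforms[index], self.sample_rate, self.labels[index]


toy_dataset = ToyAudioDataset()
toy_loader = DataLoader(toy_dataset, batch_size=2, shuffle=False)

# Each iteration yields one batch assembled from two samples
for waveforms, sample_rates, labels in toy_loader:
    print(waveforms.shape, sample_rates, labels)
```

Any object that implements __len__ and __getitem__ in this way can be handed to a DataLoader, which then takes care of batching and, if requested, shuffling.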
Setup
pip install torchaudio
Steps
1. Import the libraries needed to load our data
2. Access the data in the dataset
3. Load the data
4. Iterate over the data
5. Visualize the data (optional)
1. Import necessary libraries for loading our data
import torch
import torchaudio
2. Access the data in the dataset
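# Signature as shown in this recipe (root is a placeholder for your data directory).
# Newer torchaudio releases may no longer accept transform and target_transform.
torchaudio.datasets.YESNO(
    root,
    url='http://www.openslr.org/resources/1/waves_yesno.tar.gz',
    folder_in_archive='waves_yesno',
    download=False,
    transform=None,
    target_transform=None)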
# * ``download``: If True, downloads the dataset from the internet and puts it in the root directory. If the dataset is already downloaded, it is not downloaded again.
# * ``transform``: Transforms take your data from its source state and turn it into data that is joined together, de-normalized, and ready for training. Each library in PyTorch supports a growing list of transformations.
# * ``target_transform``: A function/transform that takes in the target and transforms it.
#
# Let’s access our Yesno data:
#
# A data point in Yesno is a tuple (waveform, sample_rate, labels) where labels
# is a list of integers with 1 for yes and 0 for no.
yesno_data = torchaudio.datasets.YESNO('./', download=True)
# Pick data point number 3 to see an example of the yesno_data:
n = 3
waveform, sample_rate, labels = yesno_data[n]
print("Waveform: {}\nSample rate: {}\nLabels: {}".format(waveform, sample_rate, labels))
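The transform and target_transform parameters described in this step can be pictured as callables applied to each sample when it is accessed. The following is a rough sketch of that idea using a hypothetical wrapper class, TransformedDataset, and a hypothetical peak_normalize function on synthetic data; it illustrates the pattern and is not torchaudio's implementation.

```python
import torch
from torch.utils.data import Dataset


class TransformedDataset(Dataset):
    """Hypothetical wrapper: applies optional callables to each (input, target) pair."""

    def __init__(self, samples, transform=None, target_transform=None):
        self.samples = samples
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        waveform, labels = self.samples[index]
        # Apply the input transform, if one was given
        if self.transform is not None:
            waveform = self.transform(waveform)
        # Apply the target transform, if one was given
        if self.target_transform is not None:
            labels = self.target_transform(labels)
        return waveform, labels


def peak_normalize(waveform):
    # Scale so the largest absolute amplitude is 1 (guarding against silence)
    return waveform / waveform.abs().max().clamp(min=1e-8)


# Synthetic (waveform, labels) pairs, not real Yesno recordings
raw_samples = [
    (torch.tensor([[0.5, -2.0, 1.0]]), [1, 0, 1]),
    (torch.tensor([[4.0, 0.0, -1.0]]), [0, 0, 1]),
]

transformed_data = TransformedDataset(
    raw_samples, transform=peak_normalize, target_transform=torch.tensor
)
normalized_waveform, label_tensor = transformed_data[0]
print(normalized_waveform)
print(label_tensor)
```

Because the callables run inside __getitem__, the stored data stays untouched and different transforms can be swapped in without copying the dataset.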
3. Loading the data
data_loader = torch.utils.data.DataLoader(yesno_data,
                                          batch_size=1,
                                          shuffle=True)
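With batch_size=1, every batch holds a single recording, so differing waveform lengths cause no trouble. For larger batches, the default collation stacks tensors and therefore needs them to have equal shapes. One possible workaround, sketched here on synthetic samples, is a custom collate_fn that pads each waveform to the longest one in the batch. The function name pad_collate and the sample values are assumptions for illustration only.

```python
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader


def pad_collate(batch):
    """Hypothetical collate function for (waveform, sample_rate, labels) samples."""
    waveforms, sample_rates, labels = zip(*batch)
    # Remember the true length of each waveform before padding
    lengths = torch.tensor([w.shape[-1] for w in waveforms])
    # pad_sequence pads along the first dimension, so move time first: (1, T) -> (T, 1)
    padded = pad_sequence([w.t() for w in waveforms], batch_first=True)  # (B, T_max, 1)
    padded = padded.transpose(1, 2)  # back to (B, 1, T_max)
    return padded, lengths, torch.tensor(sample_rates), torch.tensor(labels)


# Synthetic samples with different lengths; values are not real audio
samples = [
    (torch.ones(1, 5), 8000, [1, 0]),
    (torch.ones(1, 3), 8000, [0, 1]),
    (torch.ones(1, 4), 8000, [1, 1]),
]

padded_loader = DataLoader(samples, batch_size=3, shuffle=False, collate_fn=pad_collate)
waveforms, lengths, sample_rates, labels = next(iter(padded_loader))
print(waveforms.shape)
print(lengths)
```

Returning the original lengths alongside the padded tensor lets later code ignore the zero padding if it needs to.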
4. Iterate over the data
for data in data_loader:
    print("Data: ", data)
    print("Waveform: {}\nSample rate: {}\nLabels: {}".format(data[0], data[1], data[2]))
    break
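It can help to see what the default collation does to a sample shaped like (waveform, sample_rate, labels) when it is wrapped into a batch. The sketch below uses a single synthetic sample rather than the downloaded dataset, so the values are placeholders; only the resulting structure is of interest.

```python
import torch
from torch.utils.data import DataLoader

# One synthetic sample shaped like a Yesno data point: (waveform, sample_rate, labels)
samples = [(torch.zeros(1, 6), 8000, [0, 1, 1])]

loader = DataLoader(samples, batch_size=1, shuffle=False)
batch = next(iter(loader))
batched_waveform, batched_rate, batched_labels = batch

print(batched_waveform.shape)  # the waveform gains a leading batch dimension
print(batched_rate)            # the integer sample rate becomes a 1-element tensor
print(batched_labels)          # each label position becomes its own 1-element tensor

# Recover a plain list of Python integers from the batched labels
flat_labels = [int(label) for label in batched_labels]
print(flat_labels)
```

This is why the loop in this step indexes data[0], data[1], and data[2]: each position of the original tuple is batched separately.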
5. [Optional] Visualize the data
import matplotlib.pyplot as plt
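# Print the waveform from the last batch of the loop as a NumPy array
print(data[0][0].numpy())
plt.figure()
# Plot the waveform of data point n from step 2, transposed to (time, channels)
plt.plot(waveform.t().numpy())
# Display the figure when running as a script
plt.show()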
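A plot is often easier to read when the horizontal axis is in seconds rather than sample indices. The sketch below shows one way to build such an axis from the sample rate. It uses a synthetic sine wave and an assumed sample rate of 8000 Hz instead of a real recording, and it selects the non-interactive Agg backend so that it also runs without a display.

```python
import math

import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import torch

sample_rate = 8000  # assumed sampling rate in Hz
num_samples = 80    # 0.01 seconds of audio at the assumed rate

# Time stamp in seconds for every sample
time_axis = torch.arange(num_samples) / sample_rate
# Synthetic mono waveform of shape (1, num_samples): a 440 Hz sine
synthetic_waveform = torch.sin(2 * math.pi * 440 * time_axis).unsqueeze(0)

fig, ax = plt.subplots()
# Transpose to (time, channels) so each channel is drawn as one line
ax.plot(time_axis.numpy(), synthetic_waveform.t().numpy())
ax.set_xlabel("Time (s)")
ax.set_ylabel("Amplitude")
ax.set_title("Synthetic waveform")
```

In an interactive session, the backend line can be dropped and plt.show() called to display the figure.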