【深度之眼】【Pytorch打卡第3天】:DataLoader、DataSet、Transforms+划分数据集代码、构建Dataset、读取数据

    技术2024-05-29  83

    概括


    DataLoader与DataSet

    torch.utils.data.DataLoader:构建可迭代的数据装载器
    dataset:Dataset 类,决定数据从哪里读取以及如何读取;batch_size:批大小;num_workers:用于读取数据的子进程数量(0 表示在主进程中读取);shuffle:每个 epoch 是否打乱样本顺序;drop_last:当样本数不能被 batch_size 整除时,是否舍弃最后一个不完整的批次
    torch.utils.data.Dataset:Dataset 抽象类,所有自定义的 Dataset 都需要继承它,并且重写 __getitem__() 方法
    __getitem__():接收一个索引,返回一个样本(通常还需实现 __len__(),返回数据集的样本总数)


    Transforms

    torchvision.transforms:常用的图像预处理方法;torchvision.datasets:常用数据集的 Dataset 实现,如 MNIST、CIFAR-10、ImageNet 等;torchvision.models:常用的预训练模型,如 AlexNet、VGG、ResNet、GoogLeNet 等
    transforms

    torchvision.transforms : 常用的图像预处理方法 • 数据中心化 • 数据标准化 • 缩放 • 裁剪 • 旋转 • 翻转 • 填充 • 噪声添加 • 灰度变换 • 线性变换 • 仿射变换 • 亮度、饱和度及对比度变换

    transforms.Normalize:逐通道标准化(有助于加快模型收敛)
    # transforms.Normalize standardizes an image channel by channel:
#     output = (input - mean) / std
#   mean:    per-channel mean
#   std:     per-channel standard deviation
#   inplace: whether to modify the input tensor in place
import torchvision.transforms as transforms

# The original snippet used norm_mean / norm_std without defining them
# (NameError when run on its own). These are the same values the data-loading
# script in this post uses (commonly cited ImageNet RGB statistics).
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

# Training pipeline: resize, then augment with a random 32x32 crop using
# 4-pixel padding. Normalize operates on tensors, so it must follow ToTensor.
train_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

# Validation pipeline: same as training but without random augmentation,
# so evaluation is deterministic.
valid_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

    划分数据集

    # -*- coding: utf-8 -*-
"""
# @file name  : 1_split_dataset.py
# @author     : xinwenhu
# @date       : 2019-09-07 10:08:00
# @brief      : Split the dataset into training, validation and test sets.
"""
import os
import random
import shutil


def makedir(new_dir):
    """Create ``new_dir`` (including missing parents); do nothing if it exists.

    ``exist_ok=True`` replaces the original ``os.path.exists`` check followed
    by ``os.makedirs``, which could race if the directory appeared in between.
    """
    os.makedirs(new_dir, exist_ok=True)


if __name__ == '__main__':

    random.seed(1)

    dataset_dir = os.path.join("..", "..", "data", "RMB_data")
    split_dir = os.path.join("..", "..", "data", "rmb_split")
    train_dir = os.path.join(split_dir, "train")
    valid_dir = os.path.join(split_dir, "valid")
    test_dir = os.path.join(split_dir, "test")

    train_pct = 0.8
    valid_pct = 0.1
    test_pct = 0.1  # informational: the test set is whatever remains after train + valid

    for root, dirs, _files in os.walk(dataset_dir):
        for sub_dir in dirs:
            class_dir = os.path.join(root, sub_dir)
            # os.listdir returns names in arbitrary, filesystem-dependent order.
            # Sort first so that random.seed(1) produces the same split everywhere.
            imgs = sorted(name for name in os.listdir(class_dir) if name.endswith('.jpg'))
            random.shuffle(imgs)
            img_count = len(imgs)

            train_point = int(img_count * train_pct)
            valid_point = int(img_count * (train_pct + valid_pct))

            for i, img_name in enumerate(imgs):
                if i < train_point:
                    out_dir = os.path.join(train_dir, sub_dir)
                elif i < valid_point:
                    out_dir = os.path.join(valid_dir, sub_dir)
                else:
                    out_dir = os.path.join(test_dir, sub_dir)

                makedir(out_dir)

                target_path = os.path.join(out_dir, img_name)
                # Copy from the directory that was actually listed. The original
                # joined dataset_dir here, which is wrong for any nested folder.
                src_path = os.path.join(class_dir, img_name)

                shutil.copy(src_path, target_path)

            print('Class:{}, train:{}, valid:{}, test:{}'.format(sub_dir, train_point, valid_point-train_point, img_count-valid_point))

    构建Dataset

    # -*- coding: utf-8 -*-
"""
# @file name  : dataset.py
# @author     : yts3221@126.com
# @date       : 2019-08-21 10:08:00
# @brief      : Dataset definitions for the project's datasets.
"""
import os
import random
from PIL import Image
from torch.utils.data import Dataset

# Kept as-is: importing this module reseeds Python's global `random` RNG
# (nothing in this module itself draws from it).
random.seed(1)

# Default mapping: class sub-directory name -> integer label.
rmb_label = {"1": 0, "100": 1}


class RMBDataset(Dataset):
    def __init__(self, data_dir, transform=None, label_name=None):
        """
        Dataset for the RMB banknote denomination classification task.

        :param data_dir: str, directory whose class sub-directories hold .jpg images
        :param transform: callable applied to each RGB PIL image (e.g. a
            torchvision transform pipeline), or None to return the PIL image
        :param label_name: dict, sub-directory name -> int label; defaults to
            ``rmb_label`` ({"1": 0, "100": 1})
        """
        # Copy so callers mutating their dict (or rmb_label) don't affect this instance.
        self.label_name = dict(rmb_label if label_name is None else label_name)
        # data_info stores (image path, label) pairs; DataLoader reads samples from it by index.
        self.data_info = self.get_img_info(data_dir, self.label_name)
        self.transform = transform

    def __getitem__(self, index):
        """Return ``(image, label)`` for sample ``index``; image is transformed if a transform is set."""
        path_img, label = self.data_info[index]
        # The context manager releases the file handle deterministically.
        # convert() returns a new image (a copy when already RGB), so it stays
        # usable after the source file is closed.
        with Image.open(path_img) as src:
            img = src.convert('RGB')  # pixel values 0~255
        if self.transform is not None:
            img = self.transform(img)  # apply the transform here, e.g. convert to tensor
        return img, label

    def __len__(self):
        """Number of (path, label) samples collected from ``data_dir``."""
        return len(self.data_info)

    @staticmethod
    def get_img_info(data_dir, label_name=None):
        """
        Collect ``(image path, label)`` pairs under ``data_dir``.

        Walks ``data_dir`` recursively with ``os.walk``. For every sub-directory
        containing ``.jpg`` files, each file is labelled with
        ``label_name[<sub-directory name>]``.

        :param data_dir: str, root directory to scan
        :param label_name: dict, sub-directory name -> int label; defaults to ``rmb_label``
        :return: list of (str, int) tuples, ordered by sorted directory and file name
        :raises KeyError: if a sub-directory containing .jpg files is not a key of ``label_name``
        """
        if label_name is None:
            label_name = rmb_label
        data_info = []
        for root, dirs, _ in os.walk(data_dir):
            # Iterate over classes. Sorted listings make the sample order
            # independent of filesystem enumeration order.
            for sub_dir in sorted(dirs):
                class_dir = os.path.join(root, sub_dir)
                img_names = sorted(name for name in os.listdir(class_dir) if name.endswith('.jpg'))
                if not img_names:
                    # Matches the original: directories without images never hit the label lookup.
                    continue
                label = int(label_name[sub_dir])
                # Iterate over images.
                for img_name in img_names:
                    data_info.append((os.path.join(class_dir, img_name), label))
        return data_info

    数据读取

    # Data-loading part (step 1/5) of the RMB classification training script:
# seeds the RNGs, sets hyper-parameters, and builds the train/valid transforms,
# RMBDataset instances and DataLoaders.
import os
import random
import numpy as np
import torch
# NOTE(review): nn, optim, plt and LeNet are unused in this excerpt; presumably
# they are used by the later steps (model / loss / optimizer / training), which are not shown.
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torch.optim as optim
from matplotlib import pyplot as plt
from model.lenet import LeNet
from tools.my_dataset import RMBDataset


def set_seed(seed=1):
    """Seed Python `random`, NumPy and PyTorch (CPU and CUDA) with ``seed`` for reproducible runs."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)


set_seed()  # set the random seed
# Class sub-directory name -> label; same mapping as in tools.my_dataset.
rmb_label = {"1": 0, "100": 1}

# Hyper-parameters
MAX_EPOCH = 10
BATCH_SIZE = 16
LR = 0.01
log_interval = 10
val_interval = 1

# ============================ step 1/5 data ============================

# Output of the train/valid/test split script.
split_dir = os.path.join("..", "..", "data", "rmb_split")
train_dir = os.path.join(split_dir, "train")
valid_dir = os.path.join(split_dir, "valid")

# Per-channel normalization statistics; these look like the commonly used
# ImageNet RGB values rather than statistics computed on the RMB data — confirm if it matters.
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

# Training: resize, random 32x32 crop with 4-pixel padding (augmentation), to tensor, normalize.
train_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

# Validation: same preprocessing without the random crop.
valid_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

# Build the RMBDataset instances
train_data = RMBDataset(data_dir=train_dir, transform=train_transform)
valid_data = RMBDataset(data_dir=valid_dir, transform=valid_transform)

# Build the DataLoaders: training data is reshuffled every epoch;
# shuffle is not set for validation, so it is read in dataset order.
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=BATCH_SIZE)
    Processed: 0.010, SQL: 9