ImageFolder可以使用的情况是,比如猫狗识别中,train一个文件夹下已经将猫和狗分为两个不同的文件夹了,那我们可以直接使用ImageFolder来包装成数据集。 但有时候遇到的数据集是train一个文件夹下包含了所有的猫狗图片,那我们就无法使用ImageFolder函数,此时我们可以自己构造一个dataset类。 核心是getitem函数,这个函数的主要功能是根据样本的索引,返回索引对应的一张图片的图像数据X与对应的标签Y,也就是返回一个对应的训练样本。 具体而言,getitem的实现思路比较简单,将索引idx转换为图片的路径,然后用PIL的Image包来读取图片数据,然后将数据用torchvision的transforms转换成tensor并且进行Resize来统一大小(给出的图片尺寸不一致)与归一化,这样一来就可以得到图像数据了。因为训练集中图片的文件名上面带有猫狗的标签,所以标签可以通过对图片文件名split后得到然后转成0,1编码。
class MyDataset(Dataset): def __init__(self,data_path:str,train=True,transform=None): self.data_path = data_path self.train_flag = train if transform is None:([ …… ]) else: self.transform = transform self.path_dir = os.listdir(data_path) def __getitem__(self, idx:int): img_path = self.path_dir[idx] if self.train_flag is True: if img_path.split('.')[0] =='dog': label = 1 else: label = 0 else: label = int(img_path.split('.')[0]) label = torch.as_tensor(label,dtype=torch.int64) img_path = os.path.join(self.data_path,img_path) image = Image.open(img_path) image = self.transform(image) return image,label def __len__(self)->int: return len(self.path_dir) ```