Dataset与DataLoader的关系
- Dataset: 构建一个数据集,其中含有所有的数据样本
- DataLoader:将构建好的Dataset,通过shuffle、划分batch、多线程num_workers运行的方式,加载到可训练的迭代容器。
import torch from torch.utils.data import Dataset, DataLoader class MyDataset(Dataset): """创建自己的数据集""" def __init__(self): """初始化构建数据集所需要的参数""" pass def __getitem__(self, index): """来获取数据集中样本的索引""" pass def __len__(self): """获取数据集中的样本个数""" pass # 实例化自定义的数据集 dataset = MyDataset() # 将自定义的数据集加载到可训练的迭代容器 train_loader = DataLoader(dataset=dataset, # 自定义的数据集 batch_size=32, # 数据集中小批量的大小 shuffle=True, # 是否要打乱数据集中样本的次序 num_workers=2) # 是否要并行
实战1:CSV数据集(结构化数据集)
import torch import numpy as np from torch.utils.data import Dataset, DataLoader class MyDataset(Dataset): """创建自己的数据集""" def __init__(self, filepath): """初始化构建数据集所需要的参数""" xy = np.loadtxt(filepath, delimiter=',', dtype=np.float32) self.len = xy.shape[0] # 查看数据集中样本的个数 self.x_data = torch.from_numpy(xy[:, :-1]) self.y_data = torch.from_numpy(xy[:, [-1]]) print("数据已准备好......") def __getitem__(self, index): """为了支持下标操作, 即索引dataset[index]:来获取数据集中样本的索引""" return self.x_data[index], self.y_data[index] def __len__(self): """为了使用len(dataset):获取数据集中的样本个数""" return self.len file = "D:\BaiduNetdiskDownload\Dataset_Dataload\diabetes1.csv" """ 1.使用 MyDataset类 构建自己的dataset """ mydataset = MyDataset(file) """ 2.使用 DataLoader 构建train_loader """ train_loader = DataLoader(dataset=mydataset, batch_size=32, shuffle=True, num_workers=0) class MyModel(torch.nn.Module): """定义自己的模型""" def __init__(self): super().__init__() self.linear1 = torch.nn.Linear(8, 6) self.linear2 = torch.nn.Linear(6, 4) self.linear3 = torch.nn.Linear(4, 1) self.sigmooid = torch.nn.Sigmoid() def forward(self, x): x = self.sigmooid(self.linear1(x)) x = self.sigmooid(self.linear2(x)) x = self.sigmooid(self.linear3(x)) return x # 实例化模型 model = MyModel() # 定义损失函数 criterion = torch.nn.BCELoss(size_average=True) # 定义优化器 optimizer = torch.optim.SGD(model.parameters(), lr=0.1) if __name__ == "__main__": for epoch in range(10): for i, data in enumerate(train_loader, 0): # 1. 准备数据 inputs, labels = data # 2. 前向传播 y_pred= model(inputs) loss = criterion(y_pred, labels) print(epoch, i, loss.item()) # 3. 反向传播 optimizer.zero_grad() loss.backward() # 4. 梯度更新 optimizer.step()
实战2:图片数据集
├── flower_data
—├── flower_photos(解压的数据集文件夹,3670个样本)
—├── train(生成的训练集,3306个样本)
—└── val(生成的验证集,364个样本)
主函数文件main.py
import os import torch from torchvision import transforms from my_dataset import MyDataSet from utils import read_split_data, plot_data_loader_image root = "../data/flower_data/flower_photos" # 数据集所在根目录 def main(): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print("using {} device.".format(device)) train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(root) data_transform = { "train": transforms.Compose([transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]), "val": transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])} train_data_set = MyDataSet(images_path=train_images_path, images_class=train_images_label, transform=data_transform["train"]) batch_size = 8 nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8]) # number of workers print('Using {} dataloader workers'.format(nw)) train_loader = torch.utils.data.DataLoader(train_data_set, batch_size=batch_size, shuffle=True, num_workers=nw, collate_fn=train_data_set.collate_fn) # plot_data_loader_image(train_loader) for epoch in range(100): for step, data in enumerate(train_loader): images, labels = data # 然后在进行相应的训练操作即可 if __name__ == '__main__': main()
自定义数据集文件my_dataset.py
from PIL import Image import torch from torch.utils.data import Dataset class MyDataSet(Dataset): """自定义数据集""" def __init__(self, images_path: list, images_class: list, transform=None): self.images_path = images_path self.images_class = images_class self.transform = transform def __len__(self): return len(self.images_path) def __getitem__(self, item): img = Image.open(self.images_path[item]) # RGB为彩色图片,L为灰度图片 if img.mode != 'RGB': raise ValueError("image: {} isn't RGB mode.".format(self.images_path[item])) label = self.images_class[item] if self.transform is not None: img = self.transform(img) return img, label @staticmethod def collate_fn(batch): # 官方实现的default_collate可以参考 # https://github.com/pytorch/pytorch/blob/67b7e751e6b5931a9f45274653f4f653a4e6cdf6/torch/utils/data/_utils/collate.py images, labels = tuple(zip(*batch)) images = torch.stack(images, dim=0) labels = torch.as_tensor(labels) return images, labels
功能文件utils.py(训练集、验证集的划分与可视化)
import os import json import pickle import random import matplotlib.pyplot as plt def read_split_data(root: str, val_rate: float = 0.2): random.seed(0) # 保证随机结果可复现 assert os.path.exists(root), "dataset root: {} does not exist.".format(root) # 判断路径是否存在 # 遍历文件夹,一个文件夹对应一个类别 flower_class = [cla for cla in os.listdir(root) if os.path.isdir(os.path.join(root, cla))] # 排序,保证顺序一致 flower_class.sort() # 生成类别名称以及对应的数字索引: 字典{’花名‘:0,’花名‘:1,···} class_indices = dict((k, v) for v, k in enumerate(flower_class)) json_str = json.dumps(dict((val, key) for key, val in class_indices.items()), indent=4) # 将花名与对应的序号分行保存 with open('class_indices.json', 'w') as json_file: json_file.write(json_str) train_images_path = [] # 存储训练集的所有图片路径 train_images_label = [] # 存储训练集图片对应索引信息 val_images_path = [] # 存储验证集的所有图片路径 val_images_label = [] # 存储验证集图片对应索引信息 every_class_num = [] # 存储每个类别的样本总数 supported = [".jpg", ".JPG", ".png", ".PNG"] # 支持的文件后缀类型 # 遍历每个文件夹下的文件 for cla in flower_class: cla_path = os.path.join(root, cla) # 遍历获取supported支持的所有文件路径 images = [os.path.join(root, cla, i) for i in os.listdir(cla_path) if os.path.splitext(i)[-1] in supported] # 获取该类别对应的索引 image_class = class_indices[cla] # 记录该类别的样本数量 every_class_num.append(len(images)) # 按比例随机采样验证样本 val_path = random.sample(images, k=int(len(images) * val_rate)) for img_path in images: if img_path in val_path: # 如果该路径在采样的验证集样本中则存入验证集 val_images_path.append(img_path) val_images_label.append(image_class) else: # 否则存入训练集 train_images_path.append(img_path) train_images_label.append(image_class) print("{} images were found in the dataset.".format(sum(every_class_num))) print("{} images for training.".format(len(train_images_path))) print("{} images for validation.".format(len(val_images_path))) plot_image = True if plot_image: # 绘制每种类别个数柱状图 plt.bar(range(len(flower_class)), every_class_num, align='center') # 将横坐标0,1,2,3,4替换为相应的类别名称 plt.xticks(range(len(flower_class)), flower_class) # 在柱状图上添加数值标签 for i, v in enumerate(every_class_num): plt.text(x=i, y=v + 5, s=str(v), ha='center') # 设置x坐标 plt.xlabel('image class') # 设置y坐标 plt.ylabel('number of images') # 设置柱状图的标题 plt.title('flower class distribution') plt.show() return train_images_path, train_images_label, val_images_path, val_images_label def plot_data_loader_image(data_loader): batch_size = data_loader.batch_size plot_num = min(batch_size, 4) json_path = './class_indices.json' assert os.path.exists(json_path), json_path + " does not exist." json_file = open(json_path, 'r') class_indices = json.load(json_file) for data in data_loader: images, labels = data for i in range(plot_num): # [C, H, W] -> [H, W, C] img = images[i].numpy().transpose(1, 2, 0) # 反Normalize操作 img = (img * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406]) * 255 label = labels[i].item() plt.subplot(1, plot_num, i+1) plt.xlabel(class_indices[str(label)]) plt.xticks([]) # 去掉x轴的刻度 plt.yticks([]) # 去掉y轴的刻度 plt.imshow(img.astype('uint8')) plt.show() def write_pickle(list_info: list, file_name: str): with open(file_name, 'wb') as f: pickle.dump(list_info, f) def read_pickle(file_name: str) -> list: with open(file_name, 'rb') as f: info_list = pickle.load(f) return info_list