濮阳杆衣贸易有限公司

主頁 > 知識庫 > pytorch從csv加載自定義數(shù)據(jù)模板的操作

pytorch從csv加載自定義數(shù)據(jù)模板的操作

熱門標簽:云南地圖標注 賓館能在百度地圖標注嗎 電銷機器人 金倫通信 汕頭電商外呼系統(tǒng)供應(yīng)商 crm電銷機器人 南京crm外呼系統(tǒng)排名 北京外呼電銷機器人招商 400電話 申請 條件 鄭州智能外呼系統(tǒng)中心

整理了一套模板,全注釋了,這個難點終于克服了

from PIL import Image
import pandas as pd
import numpy as np
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import os
#放文件的路徑
dir_path= './97/train/'
csv_path='./97/train.csv'
class Mydataset(Dataset):
 #傳遞數(shù)據(jù)路徑,csv路徑 ,數(shù)據(jù)增強方法
 def __init__(self, dir_path,csv, transform=None, target_transform=None):
  super(Mydataset, self).__init__()
  #一個個往列表里面加絕對路徑
  self.path = []
  #讀取csv
  self.data = pd.read_csv(csv)
  #對標簽進行硬編碼,例如0 1 2 3 4,把字母變成這個
  colorMap = {elem: index + 1 for index, elem in enumerate(set(self.data["label"]))}
  self.data['label'] = self.data['label'].map(colorMap)
  #創(chuàng)造空的label準備存放標簽
  self.num = int(self.data.shape[0]) # 一共多少照片
  self.label = np.zeros(self.num, dtype=np.int32)
  #迭代得到數(shù)據(jù)路徑和標簽一一對應(yīng)
  for index, row in self.data.iterrows():
   self.path.append(os.path.join(dir_path,row['filename']))
   self.label[index] = row['label'] # 將數(shù)據(jù)全部讀取出來
  #訓(xùn)練數(shù)據(jù)增強
  self.transform = transform
  #驗證數(shù)據(jù)增強在這里沒用
  self.target_transform = target_transform
 #最關(guān)鍵的部分,在這里使用前面的方法
 def __getitem__(self, index):
  img =Image.open(self.path[index]).convert('RGB')
  labels = self.label[index]
  #在這里做數(shù)據(jù)增強
  if self.transform is not None:
   img = self.transform(img) # 轉(zhuǎn)化tensor類型
  return img, labels
 def __len__(self):
  return len(self.data)
#數(shù)據(jù)增強的具體內(nèi)容
transform = transforms.Compose(
 [transforms.ToTensor(),
  transforms.Resize(150),
  transforms.CenterCrop(150),
  transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)
#加載數(shù)據(jù)
train_data = Mydataset(dir_path=dir_path,csv=csv_path, transform=transform)
trainloader = DataLoader(train_data, batch_size=16, shuffle=True, num_workers=0)
#迭代訓(xùn)練
for i_batch,batch_data in enumerate(trainloader):
 image,label=batch_data

補充:pytorch—定義自己的數(shù)據(jù)集及加載訓(xùn)練

筆記:pytorch Conv2d 的寬高公式理解,pytorch 使用自己的數(shù)據(jù)集并且加載訓(xùn)練

一、pypi 鏡像使用幫助

pypi 鏡像每 5 分鐘同步一次。

臨時使用

pip install -i https://pypi.tuna.tsinghua.edu.cn/simple some-package

注意,simple 不能少, 是 https 而不是 http

設(shè)為默認

修改 ~/.config/pip/pip.conf (Linux), %APPDATA%\pip\pip.ini (Windows 10)$HOME/Library/Application Support/pip/pip.conf (macOS) (沒有就創(chuàng)建一個), 修改 index-urltuna,例如

[global]
index-url = https://pypi.tuna.tsinghua.edu.cn/simple

pip 和 pip3 并存時,只需修改 ~/.pip/pip.conf。

二、pytorch Conv2d 的寬高公式理解

三、pytorch 使用自己的數(shù)據(jù)集并且加載訓(xùn)練

import os
import sys
import numpy as np
import cv2
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
import time
import random
import csv
from PIL import Image
def createImgIndex(dataPath, ratio):
 '''
 讀取目錄下面的圖片制作包含圖片信息、圖片label的train.txt和val.txt
 dataPath: 圖片目錄路徑
 ratio: val占比
 return:label列表
 '''
 fileList = os.listdir(dataPath)
 random.shuffle(fileList)
 classList = [] # label列表
 # val 數(shù)據(jù)集制作
 with open('data/val_section1015.csv', 'w') as f:
  writer = csv.writer(f)
  for i in range(int(len(fileList)*ratio)):
   row = []
   if '.jpg' in fileList[i]:
    fileInfo = fileList[i].split('_')
    sectionName = fileInfo[0] + '_' + fileInfo[1] # 切面名+標準與否
    row.append(os.path.join(dataPath, fileList[i])) # 圖片路徑
    if sectionName not in classList:
     classList.append(sectionName)
    row.append(classList.index(sectionName))
    writer.writerow(row)
  f.close()
 # train 數(shù)據(jù)集制作
 with open('data/train_section1015.csv', 'w') as f:
  writer = csv.writer(f)
  for i in range(int(len(fileList) * ratio)+1, len(fileList)):
   row = []
   if '.jpg' in fileList[i]:
    fileInfo = fileList[i].split('_')
    sectionName = fileInfo[0] + '_' + fileInfo[1] # 切面名+標準與否
    row.append(os.path.join(dataPath, fileList[i])) # 圖片路徑
    if sectionName not in classList:
     classList.append(sectionName)
    row.append(classList.index(sectionName))
    writer.writerow(row)
  f.close()
 print(classList, len(classList))
 return classList
def default_loader(path):
 '''定義讀取文件的格式'''
 return Image.open(path).resize((128, 128),Image.ANTIALIAS).convert('RGB')
class MyDataset(Dataset):
 '''Dataset類是讀入數(shù)據(jù)集數(shù)據(jù)并且對讀入的數(shù)據(jù)進行索引'''
 def __init__(self, txt, transform=None, target_transform=None, loader=default_loader):
  super(MyDataset, self).__init__() #對繼承自父類的屬性進行初始化
  fh = open(txt, 'r') #按照傳入的路徑和txt文本參數(shù),以只讀的方式打開這個文本
  reader = csv.reader(fh)
  imgs = []
  for row in reader:
   imgs.append((row[0], int(row[1]))) # (圖片信息,lable)
  self.imgs = imgs
  self.transform = transform
  self.target_transform = target_transform
  self.loader = loader
 
 def __getitem__(self, index):
  '''用于按照索引讀取每個元素的具體內(nèi)容'''
  # fn是圖片path #fn和label分別獲得imgs[index]也即是剛才每行中row[0]和row[1]的信息
  fn, label = self.imgs[index]
  img = self.loader(fn)
  if self.transform is not None:
   img = self.transform(img) #數(shù)據(jù)標簽轉(zhuǎn)換為Tensor
  return img, label
 
 def __len__(self):
  '''返回數(shù)據(jù)集的長度'''
  return len(self.imgs)
class Model(nn.Module):
 def __init__(self, classNum=31):
  super(Model, self).__init__()
  # torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
  # torch.nn.MaxPool2d(kernel_size, stride, padding)
  # input 維度 [3, 128, 128]
  self.cnn = nn.Sequential(
   nn.Conv2d(3, 64, 3, 1, 1), # [64, 128, 128]
   nn.BatchNorm2d(64),
   nn.ReLU(),
   nn.MaxPool2d(2, 2, 0), # [64, 64, 64]
   nn.Conv2d(64, 128, 3, 1, 1), # [128, 64, 64]
   nn.BatchNorm2d(128),
   nn.ReLU(),
   nn.MaxPool2d(2, 2, 0), # [128, 32, 32]
   nn.Conv2d(128, 256, 3, 1, 1), # [256, 32, 32]
   nn.BatchNorm2d(256),
   nn.ReLU(),
   nn.MaxPool2d(2, 2, 0), # [256, 16, 16]
   nn.Conv2d(256, 512, 3, 1, 1), # [512, 16, 16]
   nn.BatchNorm2d(512),
   nn.ReLU(),
   nn.MaxPool2d(2, 2, 0), # [512, 8, 8]
   nn.Conv2d(512, 512, 3, 1, 1), # [512, 8, 8]
   nn.BatchNorm2d(512),
   nn.ReLU(),
   nn.MaxPool2d(2, 2, 0), # [512, 4, 4]
  )
  self.fc = nn.Sequential(
   nn.Linear(512 * 4 * 4, 1024),
   nn.ReLU(),
   nn.Linear(1024, 512),
   nn.ReLU(),
   nn.Linear(512, classNum)
  )
 def forward(self, x):
  out = self.cnn(x)
  out = out.view(out.size()[0], -1)
  return self.fc(out)
def train(train_set, train_loader, val_set, val_loader):
 model = Model()
 loss = nn.CrossEntropyLoss() # 因為是分類任務(wù),所以loss function使用 CrossEntropyLoss
 optimizer = torch.optim.Adam(model.parameters(), lr=0.001) # optimizer 使用 Adam
 num_epoch = 10
 # 開始訓(xùn)練
 for epoch in range(num_epoch):
  epoch_start_time = time.time()
  train_acc = 0.0
  train_loss = 0.0
  val_acc = 0.0
  val_loss = 0.0
  model.train() # train model會開放Dropout和BN
  for i, data in enumerate(train_loader):
   optimizer.zero_grad() # 用 optimizer 將 model 參數(shù)的 gradient 歸零
   train_pred = model(data[0]) # 利用 model 的 forward 函數(shù)返回預(yù)測結(jié)果
   batch_loss = loss(train_pred, data[1]) # 計算 loss
   batch_loss.backward() # tensor(item, grad_fn=NllLossBackward>)
   optimizer.step() # 以 optimizer 用 gradient 更新參數(shù)
   train_acc += np.sum(np.argmax(train_pred.data.numpy(), axis=1) == data[1].numpy())
   train_loss += batch_loss.item()
  model.eval()
  with torch.no_grad(): # 不跟蹤梯度
   for i, data in enumerate(val_loader):
    # data = [imgData, labelList]
    val_pred = model(data[0])
    batch_loss = loss(val_pred, data[1])
    val_acc += np.sum(np.argmax(val_pred.data.numpy(), axis=1) == data[1].numpy())
    val_loss += batch_loss.item()
   # 打印結(jié)果
   print('[%03d/%03d] %2.2f sec(s) Train Acc: %3.6f Loss: %3.6f | Val Acc: %3.6f loss: %3.6f' % \

     (epoch + 1, num_epoch, time.time() - epoch_start_time, \

     train_acc / train_set.__len__(), train_loss / train_set.__len__(), val_acc / val_set.__len__(),
     val_loss / val_set.__len__()))
if __name__ == '__main__':
 dirPath = '/data/Matt/QC_images/test0916' # 圖片文件目錄
 createImgIndex(dirPath, 0.2)    # 創(chuàng)建train.txt, val.txt
 root = os.getcwd() + '/data/'
 train_data = MyDataset(txt=root+'train_section1015.csv', transform=transforms.ToTensor())
 val_data = MyDataset(txt=root+'val_section1015.csv', transform=transforms.ToTensor())
 train_loader = DataLoader(dataset=train_data, batch_size=6, shuffle=True, num_workers = 4)
 val_loader = DataLoader(dataset=val_data, batch_size=6, shuffle=False, num_workers = 4)
 # 開始訓(xùn)練模型
 train(train_data, train_loader, val_data, val_loader)

以上為個人經(jīng)驗,希望能給大家一個參考,也希望大家多多支持腳本之家。如有錯誤或未考慮完全的地方,望不吝賜教。

您可能感興趣的文章:
  • pytorch 數(shù)據(jù)加載性能對比分析
  • pytorch加載語音類自定義數(shù)據(jù)集的方法教程
  • pytorch加載自己的圖像數(shù)據(jù)集實例
  • PyTorch加載自己的數(shù)據(jù)集實例詳解
  • Pytorch自己加載單通道圖片用作數(shù)據(jù)集訓(xùn)練的實例
  • Pytorch 數(shù)據(jù)加載與數(shù)據(jù)預(yù)處理方式
  • pytorch 自定義數(shù)據(jù)集加載方法

標簽:石家莊 懷化 浙江 西寧 梅州 錫林郭勒盟 昆明 文山

巨人網(wǎng)絡(luò)通訊聲明:本文標題《pytorch從csv加載自定義數(shù)據(jù)模板的操作》,本文關(guān)鍵詞  pytorch,從,csv,加載,自定義,;如發(fā)現(xiàn)本文內(nèi)容存在版權(quán)問題,煩請?zhí)峁┫嚓P(guān)信息告之我們,我們將及時溝通與處理。本站內(nèi)容系統(tǒng)采集于網(wǎng)絡(luò),涉及言論、版權(quán)與本站無關(guān)。
  • 相關(guān)文章
  • 下面列出與本文章《pytorch從csv加載自定義數(shù)據(jù)模板的操作》相關(guān)的同類信息!
  • 本頁收集關(guān)于pytorch從csv加載自定義數(shù)據(jù)模板的操作的相關(guān)信息資訊供網(wǎng)民參考!
  • 推薦文章
    梅河口市| 自贡市| 榆树市| 奇台县| 富顺县| 内乡县| 芮城县| 大同市| 四子王旗| 西城区| 博客| 施秉县| 龙胜| 浦江县| 长沙县| 蓬莱市| 乐昌市| 浠水县| 调兵山市| 彩票| 临潭县| 昌都县| 林周县| 时尚| 曲周县| 确山县| 拉萨市| 红桥区| 南平市| 沅江市| 鹰潭市| 灵寿县| 苏州市| 鹿泉市| 鄄城县| 绥中县| 望谟县| 汾西县| 石狮市| 乐安县| 长子县|