Custom dataset 만들기

https://www.youtube.com/watch?v=38hn-FpRaJs

<기본>
class my_dataset(torch.utils.data.Dataset):
   def __init__(self, [내가 필요한 것들]):
      데이터셋을 가져와 전처리 해준다.
   def __len__(self):
      데이터셋 길이를 반환한다.
   def __getitem(self, index):
      데이터셋에서 1개의 데이터를 반환한다.

<적용>

class TrainData(torch.utils.data.Dataset):
 
    def __init__(self):
        self.len = train_data2.shape[0]      # 학습데이터의 갯수 : 56320개
        self.x_data = train_data2.iloc[:, 0]  # df로서 csv 파일 내용의 이미지 번호
        self.y_data = train_data2.iloc[:, 2]  # df로서 csv 파일 내용의 이미지 추출 주소
         
    def __getitem__(self, index):
        ###### 수정 시작 ######## 1월 31일 수정
        try:
            # TrainData 클래스로 dataset이라는 생성자를 만들게 되면 init 메소드가 작동
            # 하면서 동시에 x_data가 만들어지게 되고, x_data의 길이 중에서
            # 인덱스를 1개를 추출하여 기계학습에 적합한 수치행렬을 만듦
            # 즉, no_2_numpy함수로 이미지를 기계학습에 적합한 수치행렬(numpy)로 변환

            a = no_2_numpy(self.x_data[index])
            self.xx_data = torch.tensor(a)                # x numpy.array --> torch.tensor
            self.y_data = torch.tensor(self.y_data)     # 목표값(y)은 수치이므로 그대로 변환
            return self.xx_data, self.y_data[index]      # x, y를 반환
     
        except:
            print("id {}번에서 error 발생하였습니다.".format(train_data2.iloc[index,0]))
        ####### 수정 끝 ##########

    def __len__(self):
        return self.len  # init메소드에서 정한 길이를 반환

class TestData(torch.utils.data.Dataset):
 
    def __init__(self):
        self.len = test_data2.shape[0]
        self.x_data = test_data2.iloc[:, 0]
        self.y_data = test_data2.iloc[:, 2]
         
    def __getitem__(self, index):
        ###### 수정 시작 ######## 1월 31일 수정
        try:
            a = no_2_numpy(self.x_data[index])
            self.xx_data = torch.tensor(a)
            self.y_data = torch.tensor(self.y_data)
            return self.xx_data, self.y_data[index]
        except:
            print("id {}번에서 error 발생하였습니다.".format(test_data2.iloc[index,0]))
        ####### 수정 끝 ##########

    def __len__(self):
        return self.len

(함수)------------------------------------------------------------------------------------
img_path = "../../no_touch/img_old2/"
img_list = os.listdir(img_path)
img_list.sort()  # img_list 폴더 안에 있는 파일명을 올림차순으로 소트

# 경치 이미지를 기계학습에 맞도록 전처리하는 함수
def no_2_numpy(no):
    img = Image.open(img_path + "old{:06d}.jpg".format(no)) # 이미지를 객체로 변환
    img2 = img.resize((256, 256))                           # 이미지 객체를 resize
    img3 = img2.crop((16, 16, 240, 240))                  # 이미지 객체를 crop
    np_img = np.array(img3)                                 # 이미지 객체를 수치행렬로 변환
    np_img2 = np.transpose(np_img, (2, 0, 1))           # 수치행렬을 기계학습에 맞게 전치
 
    # normalize
    np_img3 = np_img2 / 255                               # 기계학습이 잘 되도록 0~1로 변환

    # 기계학습이 잘 되도록 데이터를 색깔 채널별로  표준화
    arr1 = (np_img3[0, :, :] - 0.485) / 0.229             
    arr2 = (np_img3[1, :, :] - 0.456) / 0.224
    arr3 = (np_img3[2, :, :] - 0.406) / 0.225
    np_img4 = np.array((arr1, arr2, arr3), dtype="float")   # 색깔채널을 합침
 
    return np_img4 

댓글