PyTorch Dataset

Dataset 클래스

  • 데이터 입력 형태를 정의하는 클래스.
  • 데이터를 입력하는 방식의 표준화.
  • Image, Text, Audio 등에 따른 다른 입력 정의.
class CustomDataset(Dataset):
    def __init__(self,text,labels):
        self.data = text
        self.labels = labels
    def __len__(self):
        return len(self.labels)
    def __getitem(self,idx):
        label = self.labels[idx]
        text = self.data[idx]
        smaple = {"Text":text, "Class":label}
        return sample
  • __init__은 데이터 생성 및 초기화.
  • __len__은 데이터 길이 반환.
  • __getitem__은 인덱스 따라 데이터 반환.
  • transform 있을 시 getitem으로 적용.
  • 초기화 할때 모든 데이터 적용 x (많은 시간 소요)

Dataset 클래스 생성시 유의점

  • 데이터 형태에 따라 각 함수를 다르게 정의함.
  • 모든 것을 데이터 생성 시점에 처리할 필요는 없음.
    • image의 Tensor 변화는 학습에 필요한 시점에 변환.
  • 데이터 셋에 대한 표준화된 처리방법 제공 필요.
    • 후속 연구자 또는 동료에게는 빛과 같은 존재.
  • 최근에는 HuggingFace등 표준화된 라이브러리 사용.

DataLoader 클래스

  • Data의 Batch를 생성해주는 클래스.
  • 학습직전(GPU feed전) 데이터의 변환을 책임.
  • Tensor로 변환 + Batch 처리가 메인 업무.
  • 병렬적인 데이터 전처리 코드의 고민 필요.
  • collate_fn은 튜플 형태의 데이터를 train과 label끼리 묶어주는 역할.

Categories:

Updated: