본문 바로가기

인공지능/Python

LightningDataModule을 for loop에서 이용하는 방법

Problem

LightningDataModule을 만들어서 Trainer에 직접 넣는 것은 해봤지만, predict할 때도 이용을 해보고 싶어서 이거저거 찾던 와중에 대부분의 코드는 다음처럼 batch를 불러오는 것을 확인했다.

dl = AutoEncoderDataLoader(root = data_path, batchsize=128)
dl.setup()
next(iter(dl))

근데 난 for loop으로 돌리고 싶어졌다. 단순히 그 이유였다.

 

Solution

for batch in dl.train_dataloader():
    #do something

간단하다.