본문 바로가기

boostcamp AI tech/boostcamp AI

PyTorch Lightning Preview

728x90

pytorch를 한단계 더 추상화한 framework이다. 

1. Data Preparation

def fit(self):
    if global_rank == 0:
        # prepare data is called on GLOBAL_ZERO only
        prepare_data()                                 ## <-- prepare_data

    configure_callbacks()
    with parallel(devices):
        # devices can be GPUs, TPUs, ...
        train_on_device(model)


def train_on_device(model):
    # called PER DEVICE
    on_fit_start()
    setup("fit")                                       ## <-- setup
    configure_optimizers()

    # the sanity check runs here
    on_train_start()
    for epoch in epochs:
        fit_loop()
    on_train_end()

    on_fit_end()
    teardown("fit")
def prepare_data(self):
    # download
    ...

process에서 한번만 호출되어 실행된다. 보통은 데이터를 다운받는 작업을 수행한다.

def setup(self, stage: Optional[str] = None):
    # Assign train/val datasets for use in dataloaders

모든 process마다(gpu마다) 호출되어 실행된다. 보통 multi gpu training환경에서 사용된다.

class MNISTDataModule(L.LightningDataModule):
    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=64)
	
    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=64)
      
    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=64)

dataloader도 기존의 것을 잘 감싸서 추상화 시켜 놓았다.

 

2. Model Implementation

class LitAutoEncoder(L.LightningModule):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop.
        # it is independent of forward
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = nn.functional.mse_loss(x_hat, x)
        # Logging to TensorBoard (if installed) by default
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

이런식으로 알아서 step들로써 자동으로 불러지도록 정리할 수 있다.

 

이 밖에도 logging, early stopping, checkpoint 생성, 시각화, TorchMetrics등 당연히 구현할 수는 있지만 상당히 귀찮은 작업들을 단순하게 만들어 놓았다.

multi-gpu환경으로 언제든 쉽게 옮겨갈 수 있도록 지원이 잘 되어있다. (심지어 gpu, tpu등 다른 device들도 지원)

torchmetrics를 활용하면 gpu based 연산이 가능하다. 대용량 데이터를 더 빠르게 test 가능하다는 장점이 있다.

 

728x90
반응형