728x90
pytorch를 한단계 더 추상화한 framework이다.
1. Data Preparation
def fit(self):
if global_rank == 0:
# prepare data is called on GLOBAL_ZERO only
prepare_data() ## <-- prepare_data
configure_callbacks()
with parallel(devices):
# devices can be GPUs, TPUs, ...
train_on_device(model)
def train_on_device(model):
# called PER DEVICE
on_fit_start()
setup("fit") ## <-- setup
configure_optimizers()
# the sanity check runs here
on_train_start()
for epoch in epochs:
fit_loop()
on_train_end()
on_fit_end()
teardown("fit")
def prepare_data(self):
# download
...
process에서 한번만 호출되어 실행된다. 보통은 데이터를 다운받는 작업을 수행한다.
def setup(self, stage: Optional[str] = None):
# Assign train/val datasets for use in dataloaders
모든 process마다(gpu마다) 호출되어 실행된다. 보통 multi gpu training환경에서 사용된다.
class MNISTDataModule(L.LightningDataModule):
def train_dataloader(self):
return DataLoader(self.mnist_train, batch_size=64)
def val_dataloader(self):
return DataLoader(self.mnist_val, batch_size=64)
def test_dataloader(self):
return DataLoader(self.mnist_test, batch_size=64)
dataloader도 기존의 것을 잘 감싸서 추상화 시켜 놓았다.
2. Model Implementation
class LitAutoEncoder(L.LightningModule):
def __init__(self, encoder, decoder):
super().__init__()
self.encoder = encoder
self.decoder = decoder
def training_step(self, batch, batch_idx):
# training_step defines the train loop.
# it is independent of forward
x, y = batch
x = x.view(x.size(0), -1)
z = self.encoder(x)
x_hat = self.decoder(z)
loss = nn.functional.mse_loss(x_hat, x)
# Logging to TensorBoard (if installed) by default
self.log("train_loss", loss)
return loss
def configure_optimizers(self):
optimizer = optim.Adam(self.parameters(), lr=1e-3)
return optimizer
이런식으로 알아서 step들로써 자동으로 불러지도록 정리할 수 있다.
이 밖에도 logging, early stopping, checkpoint 생성, 시각화, TorchMetrics등 당연히 구현할 수는 있지만 상당히 귀찮은 작업들을 단순하게 만들어 놓았다.
multi-gpu환경으로 언제든 쉽게 옮겨갈 수 있도록 지원이 잘 되어있다. (심지어 gpu, tpu등 다른 device들도 지원)
torchmetrics를 활용하면 gpu based 연산이 가능하다. 대용량 데이터를 더 빠르게 test 가능하다는 장점이 있다.
728x90
반응형
'boostcamp AI tech > boostcamp AI' 카테고리의 다른 글
Representation Learning & Self Supervised Learning (0) | 2024.01.02 |
---|---|
GPT-3 and Latest Trend (0) | 2024.01.01 |
Introduction to NLP task (0) | 2023.12.12 |
Self-Supervised Pre-training Model (GPT-2, GPT-3, ALBERT, ELECTRA) (2) | 2023.12.08 |
Self-Supervised Pre-training Model (GPT-1, BERT) (0) | 2023.12.07 |