728x90
1. Save and load Model parameters
for param_tensor in model.state_dict():
print(param_tensor, "\t", model.state_dict()[param_tensor].size())
torch.save(model.state_dict(), os.path.join(MODEL_PATH, "model.pt"))
new_model = TheModelClass()
new_model.load_state_dict(torch.load(os.path.join(MODEL_PATH, "model.pt"))
- 모델의 parameter만 model.pt로 저장해두고 불러 올 수 있다. (모델 class는 가지고 있어야한다.)
from torchsummary import summary
summary(model, (3, 224, 224))
- 참고로 torchsummary라이브러리를 사용하면 더 깔끔한 모델의 설명을 볼 수 있다.
2. Save and load Model
torch.save(model. os.path.join(PATH, "model.pt"))
model = torch.load(os.path.join(PATH, "model.pt"))
model.eval()
- 통째로 모델까지 저장하고 불러올 수 있다.
3. Checkpoint
for e in epochs:
......
# forward, loss, backward, step
......
......
torch.save({
'epoch' : e,
'model_state_dict' : model.state_dict(),
'optimizer_state_dict' : optimizer.state_dict(),
'loss' : epoch_loss,
}
f"cp/checkpoint_model_{e}_{epoch_loss/len(dataloader)}_{epoch_acc/len(dataloader)}.pt")
- for e in epochs: 안에서 epoch마다 모델을 저장해서 checkpoint를 만들 수 있다.
cp = torch.load(PATH)
model.load_state_dict(cp['model_state_dict'])
optimizer.load_state_dict(cp['optimizer_state_dict'])
epoch = cp['epoch']
loss = cp['loss']
- 저장된 checkpoint를 이렇게 다시 불러올 수 있다. checkpoint를 가지고 학습을 이어서 할 수 있다.
4. Transfer learning
vgg = models.vgg99(pretrained=True).to(device)
- 공개된 모델 불러오기
class MyNewVgg(nn.Module):
def __init__(self):
super(MyNewVgg, self).__init__()
self.vgg99 = models.vgg99(pretrained=True)
self.linear_layers = nn.Linear(1000,1)
def forward(self, x):
x = self.vgg99(x)
return self.linear_layer(x)
my_vgg = MyNewVgg()
my_vgg = my_vgg.to(device)
## freeze vgg part
for param in my_vgg.parameters():
param.requires_grad = False
## train only the linear_layer
for param in my_vgg.linear_layers.parameters():
param.requires_grad = True
- pretrained model을 가져와서 나의 task에 맞게 내가 가지고 있는 데이터로 한번 더 학습 (fine tuning)
- 특정 layer만 frozen 시키면 다시 학습시킬 때 frozen layer를 제외하고 backpropagation이 수행 됨
728x90
반응형
'boostcamp AI tech > boostcamp AI' 카테고리의 다른 글
Multi gpu training, Hyper parameter Search, etc (0) | 2023.11.16 |
---|---|
Monitoring tools (0) | 2023.11.15 |
PyTorch Project (0) | 2023.11.13 |
PyTorch Basics (0) | 2023.11.13 |
Maximum Likelihood Estimation (0) | 2023.11.12 |