Hooks in PyTorch allow you to execute custom functions at specific points during the forward or backward passes of your neural network.
1. Understanding the Gradient Flow in PyTorch
When you call loss.backward(), the gradients are computed and stored in the grad attribute of each parameter.
Forward Pass: You pass your input through the network to get the output.
- input tensor X
- model parameter tensor W
- output tensor O
Loss Calculation: You calculate the loss based on the output.
- loss value
Backward Pass: When you call loss.backward(), PyTorch computes the gradients of the loss with respect to each parameter and stores them in parameter.grad.
- X.grad = d loss / d X
- W.grad = d loss / d W
- O.grad = d loss / d O (O is a non-leaf tensor, so this is stored only if you call O.retain_grad() before backward())
If you need to modify these gradients before the optimizer updates the weights, you can do so by accessing and modifying the grad attributes of the parameters directly.
loss.backward()
for p in model.parameters():
    p.grad = C  # you are modifying the gradient of W; C must be a tensor with the same shape as p
optimizer.step()
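As a minimal, self-contained sketch of this pattern (the tiny model, optimizer, and the factor of 2 below are placeholders chosen only for illustration), you can rewrite each parameter's gradient between backward() and step():

import torch
import torch.nn as nn

model = nn.Linear(2, 1)                              # stand-in model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

x = torch.tensor([[1., 3.]])
loss = model(x).sum()

loss.backward()                      # gradients are now stored in p.grad
for p in model.parameters():
    p.grad = p.grad * 2              # e.g. scale each gradient by 2 before the update
optimizer.step()                     # the update uses the modified gradients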
2. Gradient Modification Using Hooks
2.1 Tensor Hooks
When you register a hook on a tensor using tensor.register_hook(), the hook function receives the gradient of that tensor (d Loss / d tensor) as its argument. If the hook returns a value, that value replaces the gradient; if it returns None, the original gradient is kept.
def gradient_hook(grad):
    return grad * 2  # Modify the gradient; the returned tensor replaces grad
tensor.register_hook(gradient_hook)
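A minimal runnable sketch of the same idea (the tensor name w is just an example):

import torch

w = torch.tensor([1., 2., 3.], requires_grad=True)
w.register_hook(lambda grad: grad * 2)  # double the gradient during backward

loss = (w * 3).sum()   # d loss / d w = 3 for each element
loss.backward()
print(w.grad)          # tensor([6., 6., 6.]) because the hook doubled the gradient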
2.2 Module Hooks
When using module hooks, you can access and modify the gradients flowing through a module: register_full_backward_hook() gives you the gradients with respect to the module's inputs and outputs (not its parameters directly).
def backward_hook(module, grad_input, grad_output):
    # Modify the gradients with respect to the module's inputs here
    return tuple(grad * 2 for grad in grad_input)
module.register_full_backward_hook(backward_hook)
2.2.1 Backward Hook Function Signature
- grad_input: This is the gradient of the loss with respect to the input of the module. In other words, it is the gradient that is being backpropagated from the output of the module to its input.
- d Loss / d inputOfModule
- grad_output: This is the gradient of the loss with respect to the output of the module. It represents the gradient that is being passed from the next layer (or the final loss) to the current module's output.
- d Loss / d outputOfModule
input -> Linear1 -> ReLU -> Linear2 -> Output
If you register a backward hook on Linear1,
grad_input would be the gradient of the loss with respect to the input of Linear1 (which, in this network, is also the network's input).
grad_output would be the gradient of the loss with respect to the output of Linear1, i.e. the gradient flowing back from ReLU.
A module may have several inputs, so grad_input is a tuple with one gradient per input, and the hook must return a tuple of the same length.
If the hook returns a new tuple, those values are used in place of grad_input for the rest of backpropagation; if it returns None, the original gradients are kept.
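To make this concrete, here is a small runnable sketch of that exact network (the layer and variable names are mine, chosen to mirror the diagram above); the hook only prints the shapes and returns None, so nothing is modified:

import torch
import torch.nn as nn

linear1 = nn.Linear(3, 4)
relu = nn.ReLU()
linear2 = nn.Linear(4, 1)

def print_grads(module, grad_input, grad_output):
    print("grad_input shapes: ", [g.shape for g in grad_input if g is not None])
    print("grad_output shapes:", [g.shape for g in grad_output])
    # returning None keeps the original gradients unchanged

linear1.register_full_backward_hook(print_grads)

x = torch.randn(1, 3, requires_grad=True)
loss = linear2(relu(linear1(x))).sum()
loss.backward()
# grad_input corresponds to d loss / d x (shape [1, 3]),
# grad_output to d loss / d (output of Linear1) (shape [1, 4])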
3. Practice - Modify the Input Gradient (Module Hook)
import torch
import torch.nn as nn
import torch.nn.functional as F

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.linear = nn.Linear(2, 1)

    def forward(self, x):
        x = self.linear(x)
        return x
Let's say we have a model here.
model = MyModel()
for name, param in model.named_parameters():
    if param.requires_grad:
        print(name, ", ", param.data)
        print(name, " grad, ", param.grad)
print("\n")
input = torch.tensor([[1.,3.]], requires_grad=True)
print("input, ", input)
print("input grad, ", input.grad)  # None before backward; input is a leaf tensor, so no retain_grad() is needed
print("input shape, ", input.shape)
print("\n")
output = model(input)
print("output, ", output)
output.retain_grad()  # output is a non-leaf tensor; retain_grad() makes output.grad available after backward
print("output grad, ", output.grad)  # None before backward
print("output shape, ", output.shape)
print("\n")
loss = output.sum()
loss.backward()
print("loss, ", loss)
print("\n")
print("input grad after backprop(d loss/d input = w)", input.grad)
print("output grad after backprop(d loss/d output = 1)", output.grad)
for name, param in model.named_parameters():
    if param.requires_grad:
        print(name, "grad(d loss/d w = x, d loss/d b = 1)", param.grad)
Let's multiply the input gradient by 2 using a backward hook.
def modify_gradients(module, grad_input, grad_output):
    print("--------------inside hook--------------")
    print("grad_output, ", grad_output)
    print("grad_input, ", grad_input)
    print("---------------------------------------")
    # Modify the gradients: multiply by 2.
    return tuple(grad * 2 for grad in grad_input)

model2 = MyModel()
# Register the hook on the linear layer
hook_handle = model2.linear.register_full_backward_hook(modify_gradients)
# Remove the hook later if needed with hook_handle.remove()
for name, param in model2.named_parameters():
    if param.requires_grad:
        print(name, ", ", param.data)
        print(name, " grad, ", param.grad)
print("\n")
input = torch.tensor([[1.,3.]], requires_grad=True)
print("input, ", input)
print("input grad, ", input.grad)  # None before backward
print("input shape, ", input.shape)
print("\n")
output = model2(input)
print("output, ", output)
output.retain_grad()  # needed so output.grad is populated after backward
print("output grad, ", output.grad)  # None before backward
print("output shape, ", output.shape)
print("\n")
loss = output.sum()
loss.backward()
print("loss, ", loss)
print("\n")
print("input grad after backprop(d loss/d input = w, but we *2 here)", input.grad)
print("output grad after backprop(d loss/d output = 1)", output.grad)
for name, param in model2.named_parameters():
    if param.requires_grad:
        print(name, "grad(d loss/d w = x, d loss/d b = 1)", param.grad)
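A short note on what to expect from this run: the hook rescales grad_input, so only input.grad is doubled; the parameter gradients of model2.linear are computed from the original grad_output and should match the unhooked run. When you are done experimenting, the hook can be detached:

hook_handle.remove()  # subsequent backward passes use the original gradients again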
4. Practice - Modify the Model Parameters' Gradients (Tensor Hook)
import torch
import torch.nn as nn
import torch.nn.functional as F

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.linear = nn.Linear(2, 1)

    def set_hook(self):
        def _hook(grad):
            print("--------------inside hook--------------")
            print("original grad, ", grad)
            print("---------------------------------------")
            return grad * 0 + 1.2345  # Modify the gradient: replace every element with 1.2345

        for name, param in self.linear.named_parameters():
            if param.requires_grad:
                param.register_hook(_hook)

    def forward(self, x):
        x = self.linear(x)
        return x
Let's say we have the model above. We are going to register hooks on the self.linear layer's parameters (which are tensors) so that all of their gradients become 1.2345.
Before,
model = MyModel()
for name, param in model.named_parameters():
    if param.requires_grad:
        print(name, ", ", param.data)
        print(name, " grad, ", param.grad)
print("\n")
input = torch.tensor([[1.,3.]], requires_grad=True)
print("input, ", input)
print("\n")
output = model(input)
output.retain_grad()
loss = output.sum()
loss.backward()
print("input grad after backprop(d loss/d input = w)", input.grad)
print("output grad after backprop(d loss/d output = 1)", output.grad)
for name, param in model.named_parameters():
    if param.requires_grad:
        print(name, "grad(d loss/d w = x, d loss/d b = 1)", param.grad)
After,
model = MyModel()
model.set_hook()
for name, param in model.named_parameters():
    if param.requires_grad:
        print(name, ", ", param.data)
        print(name, " grad, ", param.grad)
print("\n")
input = torch.tensor([[1.,3.]], requires_grad=True)
print("input, ", input)
print("\n")
output = model(input)
output.retain_grad()  # so output.grad is populated after backward, as in the "Before" run
loss = output.sum()
loss.backward()
print("input grad after backprop(d loss/d input = w)", input.grad)
print("output grad after backprop(d loss/d output = 1)", output.grad)
for name, param in model.named_parameters():
    if param.requires_grad:
        print(name, "grad(d loss/d w = x, d loss/d b = 1)", param.grad)
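One caveat worth adding: register_hook() returns a handle, and set_hook() above discards it, so the hooks stay attached for every later backward pass. A minimal variation (my own sketch, written outside the class for brevity; the name hook_handles is mine) keeps the handles so the hooks can be removed:

model = MyModel()

# register the same constant-gradient hook, but keep the handles
hook_handles = [
    p.register_hook(lambda grad: grad * 0 + 1.2345)
    for p in model.linear.parameters()
    if p.requires_grad
]

input = torch.tensor([[1., 3.]], requires_grad=True)
model(input).sum().backward()
print([p.grad for p in model.linear.parameters()])  # every element is 1.2345

for handle in hook_handles:
    handle.remove()  # later backward passes use the original gradients again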
5. Further Study