
[PyTorch] Modify Gradients During Backpropagation Using Hooks


Hooks in PyTorch allow you to execute custom functions at specific points during the forward or backward passes of your neural network.

 

1. Understanding the Gradient Flow in PyTorch

When you call loss.backward(), gradients are computed and stored in the grad attribute of each tensor that requires them. A typical training step goes through three stages:

Forward Pass: You pass your input through the network to get the output.

  • input tensor X
  • model parameter tensor W
  • output tensor O

Loss Calculation: You calculate the loss based on the output.

  • loss value

Backward Pass: When you call loss.backward(), PyTorch computes the gradients of the loss with respect to each parameter and stores them in parameter.grad.

  • X.grad = d loss / d X
  • W.grad = d loss / d W
  • O.grad = d loss / d O
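
For example, with a single linear layer O = X·Wᵀ + b and loss = O.sum(), the chain rule gives d loss / d O = 1, d loss / d W = X, d loss / d b = 1, and d loss / d X = W. This is exactly what the practice code in sections 3 and 4 prints after backward().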

If you need to modify these gradients before the optimizer updates the weights, you can do so by accessing and modifying the grad attributes of the parameters directly.

loss.backward()

for p in model.parameters():
    p.grad = p.grad * 2  # overwrite the stored gradient of W (here: scale it by 2) before the update

optimizer.step()  # the optimizer now uses the modified gradients

2. Gradient Modification Using Hooks

2.1 Tensor Hooks

When you register a hook on a tensor using tensor.register_hook(), the hook function receives the gradient of the loss with respect to that tensor (d Loss / d tensor) as its input.

If the hook returns a value, that value replaces the gradient; if it returns None (or nothing), the original gradient is kept.

def gradient_hook(grad):
    return grad * 2  # Modify the gradient

tensor.register_hook(gradient_hook)
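
For instance, here is a minimal, self-contained sketch (the tensor values and the squared-sum loss are made up for illustration) showing that the doubled gradient is what ends up in the grad attribute:

import torch

def gradient_hook(grad):
    return grad * 2  # double the gradient flowing into this tensor

x = torch.tensor([1., 2., 3.], requires_grad=True)
handle = x.register_hook(gradient_hook)

loss = (x ** 2).sum()  # d loss / d x = 2x = [2, 4, 6]
loss.backward()
print(x.grad)          # tensor([ 4.,  8., 12.]), twice the original gradient

handle.remove()        # detach the hook when it is no longer needed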

 

2.2 Module Hooks

When using module hooks (registered with module.register_full_backward_hook()), you can inspect the gradients flowing into and out of the module and modify the gradient that is propagated back to the module's inputs. Note that this does not change the grad attribute of the module's parameters.

def backward_hook(module, grad_input, grad_output):
    # grad_input / grad_output are tuples; an entry can be None if that tensor does not require grad
    return tuple(grad * 2 if grad is not None else None for grad in grad_input)

module.register_full_backward_hook(backward_hook)

 

2.2.1 Backward Hook Function Signature

  • grad_input: This is the gradient of the loss with respect to the input of the module. In other words, it is the gradient that is being backpropagated from the output of the module to its input.
    • d Loss / d inputOfModule
  • grad_output: This is the gradient of the loss with respect to the output of the module. It represents the gradient that is being passed from the next layer (or the final loss) to the current module's output.
    • d Loss / d outputOfModule
input -> Linear1 -> ReLU -> Linear2 -> Output

If you register a backward hook on Linear1,

grad_input would be the gradient of the loss with respect to the input of Linear1, which in this example is the network's input.

grad_output would be the gradient of the loss with respect to the output of Linear1.

 

A module may have several inputs, so the hook receives grad_input as a tuple with one entry per input and must return a tuple of the same length.

If the hook returns new values, they replace the original grad_input and are propagated further back through the graph; if it returns None, the original gradients are used.
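
A small sketch (the layer sizes are arbitrary) that registers a full backward hook on Linear1 and prints both tuples makes the two arguments concrete:

import torch
import torch.nn as nn

net = nn.Sequential(
    nn.Linear(4, 3),  # Linear1
    nn.ReLU(),
    nn.Linear(3, 1),  # Linear2
)

def inspect_hook(module, grad_input, grad_output):
    # grad_input : d Loss / d (input of Linear1), one entry per forward input
    # grad_output: d Loss / d (output of Linear1)
    print("grad_input :", [g.shape if g is not None else None for g in grad_input])
    print("grad_output:", [g.shape for g in grad_output])
    # returning None keeps the gradients unchanged

net[0].register_full_backward_hook(inspect_hook)

x = torch.randn(1, 4, requires_grad=True)
net(x).sum().backward()
# grad_input : [torch.Size([1, 4])]
# grad_output: [torch.Size([1, 3])]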

3. Practice - Modify the Input Gradient (Module Hook)

import torch
import torch.nn as nn
import torch.nn.functional as F

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.linear = nn.Linear(2, 1)

    def forward(self, x):
        x = self.linear(x)
        return x

Let's say we have a model here.

model = MyModel()
for name, param in model.named_parameters():
    if param.requires_grad:
        print(name,", ", param.data)
        print(name," grad, ", param.grad)
print("\n")

input = torch.tensor([[1.,3.]], requires_grad=True)
input.retain_grad()  # a leaf tensor keeps its grad anyway; this call is harmless
print("input, ", input)
print("input grad, ", input.grad)  # None before backward
print("input shape, ", input.shape)
print("\n")

output = model(input)
output.retain_grad()  # needed so the non-leaf output keeps its grad after backward
print("output, ", output)
print("output grad, ", output.grad)  # None before backward
print("output shape, ", output.shape)
print("\n")

loss = output.sum()
loss.backward()
print("loss, ", loss)
print("\n")

print("input grad after backprop(d loss/d input = w)", input.grad)
print("output grad after backprop(d loss/d output = 1)", output.grad)
for name, param in model.named_parameters():
    if param.requires_grad:
        print(name,"grad(d loss/d w = x, d loss/d b = 1)", param.grad)

 

Let's multiply the input gradient by 2 using a full backward hook.

def modify_gradients(module, grad_input, grad_output):
    print("--------------inside hook--------------")
    print("grad_output, ", grad_output)
    print("grad_input, ", grad_input)
    print("---------------------------------------")
    # Modify the gradients, multiply by 2.
    return tuple(grad * 2 for grad in grad_input)

model2 = MyModel()

# Register the hook on the linear layer
hook_handle = model2.linear.register_full_backward_hook(modify_gradients)
# The hook can be removed later with hook_handle.remove()

for name, param in model2.named_parameters():
    if param.requires_grad:
        print(name,", ", param.data)
        print(name," grad, ", param.grad)
print("\n")

input = torch.tensor([[1.,3.]], requires_grad=True)
input.retain_grad()
print("input, ", input)
print("input grad, ", input.grad)  # None before backward
print("input shape, ", input.shape)
print("\n")

output = model2(input)
output.retain_grad()  # keep the non-leaf output's grad after backward
print("output, ", output)
print("output grad, ", output.grad)  # None before backward
print("output shape, ", output.shape)
print("\n")

loss = output.sum()
loss.backward()
print("loss, ", loss)
print("\n")

print("input grad after backprop(d loss/d input = w, but we *2 here)", input.grad)
print("output grad after backprop(d loss/d output = 1)", output.grad)
for name, param in model2.named_parameters():
    if param.requires_grad:
        print(name,"grad(d loss/d w = x, d loss/d b = 1)", param.grad)

 

4. Practice - Modify Model Parameters' Gradients (Tensor Hook)

import torch
import torch.nn as nn
import torch.nn.functional as F

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.linear = nn.Linear(2, 1)

    def set_hook(self):
        def _hook(grad):
            print("--------------inside hook--------------")
            print("original grad, ", grad)
            print("---------------------------------------")
            return grad * 0 + 1.2345  # Modify the gradient
        
        for name, param in self.linear.named_parameters():
            if param.requires_grad:
                param.register_hook(_hook)

    def forward(self, x):
        x = self.linear(x)
        return x

Let's say we have a model here. We are going to register a hook on each of the self.linear layer's parameters (tensors) so that all of their gradients become 1.2345.

 

Before, 

model = MyModel()

for name, param in model.named_parameters():
    if param.requires_grad:
        print(name,", ", param.data)
        print(name," grad, ", param.grad)
print("\n")

input = torch.tensor([[1.,3.]], requires_grad=True)
print("input, ", input)
print("\n")

output = model(input)
output.retain_grad()
loss = output.sum()
loss.backward()

print("input grad after backprop(d loss/d input = w)", input.grad)
print("output grad after backprop(d loss/d output = 1)", output.grad)
for name, param in model.named_parameters():
    if param.requires_grad:
        print(name,"grad(d loss/d w = x, d loss/d b = 1)", param.grad)

 

After,

model = MyModel()
model.set_hook()

for name, param in model.named_parameters():
    if param.requires_grad:
        print(name,", ", param.data)
        print(name," grad, ", param.grad)
print("\n")

input = torch.tensor([[1.,3.]], requires_grad=True)
print("input, ", input)
print("\n")

output = model(input)
output.retain_grad()  # keep output.grad after backward, as in the "Before" run
loss = output.sum()
loss.backward()

print("input grad after backprop(d loss/d input = w)", input.grad)
print("output grad after backprop(d loss/d output = 1)", output.grad)
for name, param in model.named_parameters():
    if param.requires_grad:
        print(name,"grad(d loss/d w = x, d loss/d b = 1)", param.grad)

 

5. Further Study

 

PyTorch 101: Understanding Hooks | DigitalOcean


www.digitalocean.com

 
