PHPFixing

Monday, August 15, 2022

[FIXED] How to get the output gradient w.r.t input

Tags: gradient, input, mnist, output, pytorch

Issue

I am having trouble getting the gradient of the output with respect to the input. The model is a simple MNIST classifier.

for num,(sample_img, sample_label) in enumerate(mnist_test):
    if num == 1:
        break

    sample_img = sample_img.to(device)
    sample_img.requires_grad = True
    prediction = model(sample_img.unsqueeze(dim=0))
    cost = criterion(prediction, torch.tensor([sample_label]).to(device))
    optimizer.zero_grad()
    cost.backward()
    print(sample_label)
    print(sample_img.shape)

    plt.imshow(sample_img.detach().cpu().squeeze(),cmap='gray')
    plt.show()

print(sample_img.grad)

The result is that sample_img.grad is None.


Solution

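If you need the gradient with respect to the input, the input tensor itself must track gradients. Call sample_img.requires_grad_() or set sample_img.requires_grad = True before the forward pass, as suggested in your comments. After cost.backward(), the gradient is available in sample_img.grad.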

Here is a small example:

import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt


model = nn.Sequential(  # a dummy model
    nn.Conv2d(1, 1, 3),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Flatten()
)

sample_img = torch.rand(1, 5, 5)  # a dummy input
sample_label = 0

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=1e-3)
device = "cpu"

sample_img = sample_img.to(device)
sample_img.requires_grad = True

prediction = model(sample_img.unsqueeze(dim=0))
cost = criterion(prediction, torch.tensor([sample_label]).to(device))
optimizer.zero_grad()
cost.backward()
print(sample_label)
print(sample_img.shape)

plt.imshow(sample_img.detach().cpu().squeeze(), cmap='gray')
plt.show()

print(sample_img.grad.shape)
print(sample_img.grad)
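
One caveat about the dummy model above: it produces a single output logit, so the cross-entropy loss is constantly zero and the printed input gradient is all zeros. As a quick check that the same steps yield a non-trivial gradient, here is a minimal sketch that assumes a 10-class linear model on a 28x28 input. The helper name input_gradient and the model layout are hypothetical, chosen only for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical 10-class dummy model (an assumption, not the asker's network)
model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28 * 28, 10),
)
criterion = nn.CrossEntropyLoss()


def input_gradient(model, criterion, img, label):
    """Return d(loss)/d(img) for a single image of shape (C, H, W)."""
    # Work on a fresh leaf tensor that tracks gradients
    img = img.clone().detach().requires_grad_(True)
    prediction = model(img.unsqueeze(dim=0))  # add the batch dimension
    cost = criterion(prediction, torch.tensor([label], device=img.device))
    cost.backward()
    return img.grad


sample_img = torch.rand(1, 28, 28)  # a dummy MNIST-sized input
grad = input_gradient(model, criterion, sample_img, 3)

print(grad.shape)            # same shape as the input image
print(grad.abs().sum() > 0)  # the gradient is not identically zero
```

The helper clones the input, so the original sample_img is left untouched and its own .grad stays None.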

Additionally, if you don't need the gradients of the model parameters, you can turn off gradient tracking for them:

for param in model.parameters():
    param.requires_grad = False
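
To illustrate the effect, the following sketch freezes the parameters of a small dummy model (a hypothetical 3-class linear layer, assumed only for this example) and checks that the input still receives a gradient while the parameters do not.

```python
import torch
import torch.nn as nn

# Hypothetical dummy model, used only to illustrate parameter freezing
model = nn.Sequential(nn.Flatten(), nn.Linear(25, 3))

# Turn off gradient tracking for every model parameter
for param in model.parameters():
    param.requires_grad = False

sample_img = torch.rand(1, 5, 5, requires_grad=True)  # a dummy input
criterion = nn.CrossEntropyLoss()

prediction = model(sample_img.unsqueeze(dim=0))
cost = criterion(prediction, torch.tensor([0]))
cost.backward()

# The input has a gradient; the frozen parameters have none
param_grads = [p.grad for p in model.parameters()]
print(sample_img.grad.shape)
print(param_grads)
```

Because only the input requires gradients here, backward() skips the parameter gradients entirely, which saves some computation and memory.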


Answered By - aretor
Answer Checked By - Terry (PHPFixing Volunteer)