PyTorch Hooks by Example

Mansoor Aldosari
Apr 19, 2021


Photo by Tatiana Rodriguez on Unsplash

Hooks let us look at the data flowing through a model during the forward and backward passes.
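This post focuses on forward hooks. For the backward pass, PyTorch provides `register_full_backward_hook()` (available since PyTorch 1.8), whose hook receives `(module, grad_input, grad_output)`. Here is a minimal sketch, assuming a toy `nn.Linear(4, 2)` layer and random input. The names `layer`, `grads` and `save_grad_output` are illustrative only:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

layer = nn.Linear(4, 2)
grads = []  # collects the gradient of the loss w.r.t. the layer's output

def save_grad_output(module, grad_input, grad_output):
    # grad_output is a tuple with one gradient per output of the module
    grads.append(grad_output[0].detach())

handle = layer.register_full_backward_hook(save_grad_output)

x = torch.randn(3, 4, requires_grad=True)
loss = layer(x).sum()
loss.backward()  # the hook fires during this call

print(grads[0].shape)  # torch.Size([3, 2])
handle.remove()
```

Because the loss is a plain sum, the gradient with respect to every output element is 1, so `grads[0]` is a tensor of ones.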
register_forward_hook() attaches a function that runs every time a module finishes its forward pass, and that function receives the module's input and output. So if we have a model that looks like this:

>>> model
CustomLayer(
  (linear1): Linear(in_features=1000, out_features=100, bias=True)
  (linear2): Linear(in_features=100, out_features=10, bias=True)
)

and we want to know the output of linear1, we write a hook that takes three arguments (module, input, output), where input is a tuple of the positional arguments passed to the module's forward():

>>> dummylist = []
>>> def dummyhook(module, inp, outp): dummylist.append(outp)
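To make the two data arguments concrete, the sketch below records the shapes inside `inp` (always a tuple) and stores a detached copy of `outp`. It assumes a standalone `nn.Linear(1000, 100)` and a random batch of 8; `captured` and `shape_hook` are hypothetical names:

```python
import torch
import torch.nn as nn

captured = {}  # hypothetical storage for what the hook sees

def shape_hook(module, inp, outp):
    # inp is a tuple of the positional arguments passed to forward();
    # outp is whatever forward() returned (a single tensor for nn.Linear)
    captured["in_shapes"] = [t.shape for t in inp]
    # detach() so the stored tensor does not keep the autograd graph alive
    captured["out"] = outp.detach()

lin = nn.Linear(1000, 100)
lin.register_forward_hook(shape_hook)
lin(torch.randn(8, 1000))

print(captured["in_shapes"])  # [torch.Size([8, 1000])]
print(captured["out"].shape)  # torch.Size([8, 100])
```

`dummyhook` appends `outp` as is, which is why the tensor printed later still carries a `grad_fn`. Detaching is a reasonable choice when you store activations from many batches and do not need gradients through them.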

Then we register the hook on the module we want to observe:

>>> handle = model.linear1.register_forward_hook(dummyhook)
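`register_forward_hook()` returns a handle (`torch.utils.hooks.RemovableHandle`), and calling `.remove()` on it detaches the hook again. A small sketch, assuming a toy `nn.Linear(10, 5)`:

```python
import torch
import torch.nn as nn

calls = []
lin = nn.Linear(10, 5)

# list.append returns None, so this hook leaves the module's output unchanged
handle = lin.register_forward_hook(lambda m, inp, outp: calls.append(outp.shape))

x = torch.randn(2, 10)
lin(x)            # hook fires: calls now has one entry
handle.remove()   # detach the hook from the module
lin(x)            # hook no longer fires

print(len(calls))  # 1
```

Keeping the handle matters because hooks stay attached for the lifetime of the module; forgetting to remove one means it keeps appending on every later forward pass.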

Finally, run the model on an input batch x and inspect dummylist:

>>> y_pred = model(x)
>>> dummylist[0]
tensor([[-0.0116, -0.4900, -1.0557, ..., 0.7317, 2.4810, -0.0821],
        [ 0.3536, -0.0983, -0.7922, ..., -0.8643, 1.5925, -0.1461],
        [ 0.3246, -0.1753, -0.1105, ..., -0.4116, -0.5870, 0.9387],
        ...,
        [ 0.3176, -1.0613, -0.1252, ..., -0.5476, -0.4361, -0.2262],
        [ 2.0090, 2.9247, -0.5890, ..., -0.1938, 0.6834, -0.0979],
        [-0.5889, 2.3918, 0.2750, ..., 1.5354, -0.0317, 0.9574]],
       grad_fn=<AddmmBackward>)
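The post does not show how CustomLayer or x are defined, so here is a self-contained version of the whole walkthrough under stated assumptions: the ReLU between the two layers and the batch of 64 random samples are guesses. A functional `torch.relu` does not appear in the module's printed repr, so this definition prints the same two `Linear` children shown above.

```python
import torch
import torch.nn as nn

class CustomLayer(nn.Module):
    # Assumed architecture: only the two Linear layers are known from the post,
    # the ReLU between them is a guess
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(1000, 100)
        self.linear2 = nn.Linear(100, 10)

    def forward(self, x):
        return self.linear2(torch.relu(self.linear1(x)))

model = CustomLayer()
dummylist = []

def dummyhook(module, inp, outp):
    dummylist.append(outp)  # store linear1's output for later inspection

handle = model.linear1.register_forward_hook(dummyhook)
x = torch.randn(64, 1000)  # assumed batch of 64 random samples
y_pred = model(x)

print(dummylist[0].shape)  # torch.Size([64, 100]), the output of linear1
print(y_pred.shape)        # torch.Size([64, 10]), the model's final output
handle.remove()
```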

Now, you might ask: where can I use this? One use case is to peek inside a convolutional neural network by generating a heat map of each layer's activations.
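One way to collect those activations is to register a forward hook on every convolutional layer, then reduce each layer's feature maps to a 2-D map. The sketch below assumes a hypothetical two-layer CNN and a random 32x32 "image", and uses the channel-wise mean as the heat map. That reduction is just one simple choice, not necessarily the author's, and the resulting tensors could be plotted with any image-display tool.

```python
import torch
import torch.nn as nn

# A hypothetical small CNN; any model with Conv2d layers works the same way
cnn = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(8, 16, kernel_size=3, padding=1),
    nn.ReLU(),
)

activations = {}  # layer name -> detached feature maps

def make_hook(name):
    # a closure so each hook knows which layer it belongs to
    def hook(module, inp, outp):
        activations[name] = outp.detach()
    return hook

# register one hook per Conv2d layer and keep the handles for cleanup
handles = [
    module.register_forward_hook(make_hook(name))
    for name, module in cnn.named_modules()
    if isinstance(module, nn.Conv2d)
]

image = torch.randn(1, 3, 32, 32)  # stand-in for a real image batch of size 1
cnn(image)

# One simple heat map per layer: average the first sample's maps over channels
heatmaps = {name: act[0].mean(dim=0) for name, act in activations.items()}
for name, hm in heatmaps.items():
    print(name, tuple(hm.shape))  # prints "0 (32, 32)" then "2 (32, 32)"

for h in handles:
    h.remove()
```

In an `nn.Sequential`, `named_modules()` names children by their index, so the two convolutions show up as `"0"` and `"2"`.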
