Welcome to Vigges Developer Community - Open, Learning, Share
Welcome To Ask or Share your Answers For Others

0 votes
426 views
in Technique by (71.8m points)

I tried to divide ResNet into two parts using PyTorch's children(), but it doesn't work

Here is a simple example. I tried to divide a network (ResNet50) into two parts, head and tail, using children(). Conceptually this should work, but it doesn't. Why is that?

import torch
import torch.nn as nn
from torchvision.models import resnet50

resnet = resnet50()
head = nn.Sequential(*list(resnet.children())[:-2])
tail = nn.Sequential(*list(resnet.children())[-2:])
x = torch.zeros(1, 3, 160, 160)

resnet(x).shape      # torch.Size([1, 1000])
head(x).shape        # torch.Size([1, 2048, 5, 5])
tail(head(x)).shape  # RuntimeError: size mismatch, m1: [2048 x 1], m2: [2048 x 1000] at /pytorch/aten/src/TH/generic/THTensorMath.cpp:136

For information, the tail is nothing but:

Sequential(
  (0): AdaptiveAvgPool2d(output_size=(1, 1))
  (1): Linear(in_features=2048, out_features=1000, bias=True)
)

So I actually know that I can do it like this instead. But then, why is the reshaping step (view) not among the children?

pool = resnet._modules['avgpool']
fc = resnet._modules['fc']
fc(pool(head(x)).view(1, -1))


1 Answer

0 votes
by (71.8m points)

What you are looking to do is separate the feature extractor from the classifier.

  • What I should point out straight away is that ResNet is not a sequential model (as the name implies, a residual network has skip connections)!

    Therefore compiling it down to an nn.Sequential will not be accurate. There's a difference between the model definition (the layers that appear, in order, via .children()) and the actual underlying implementation of that model's forward function.
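    This gap between .children() order and forward can be seen with a toy module (a hypothetical minimal sketch, not from torchvision) that has a skip connection:

    ```python
    import torch
    import torch.nn as nn

    class Skip(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc1 = nn.Linear(4, 4)
            self.fc2 = nn.Linear(4, 4)

        def forward(self, x):
            # The "+ x" residual lives only in forward; it is not a child module
            return self.fc2(self.fc1(x)) + x

    m = Skip()
    seq = nn.Sequential(*m.children())  # chains fc1 then fc2, dropping the skip
    x = torch.ones(1, 4)
    print(torch.allclose(m(x), seq(x)))  # False: seq(x) is missing the "+ x" term
    ```

    The two outputs differ by exactly x, because nn.Sequential can only replay the registered children in order; any extra arithmetic in forward is silently lost.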


  • The flattening you performed using view(1, -1) is not registered as a layer in the torchvision.models.resnet* models. Instead, it is performed directly in the forward definition:

    x = torch.flatten(x, 1)
    

    They could have registered it as a layer in __init__ as self.flatten = nn.Flatten(), and used it in forward as x = self.flatten(x).

    Even so, fc(pool(head(x)).view(1, -1)) is completely different from resnet(x) (cf. the first point).
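    One way to make the two-part split actually run, assuming a torchvision resnet50, is to rebuild the tail with an explicit nn.Flatten between the pooling and fc children (a sketch; for ResNet this happens to reproduce forward, but as noted above that is not guaranteed for arbitrary models):

    ```python
    import torch
    import torch.nn as nn
    from torchvision.models import resnet50

    resnet = resnet50()
    resnet.eval()  # freeze batchnorm statistics so both paths compute the same thing

    head = nn.Sequential(*list(resnet.children())[:-2])
    # Insert the missing flatten step between avgpool and fc
    tail = nn.Sequential(resnet.avgpool, nn.Flatten(1), resnet.fc)

    x = torch.zeros(1, 3, 160, 160)
    with torch.no_grad():
        out = tail(head(x))
    print(out.shape)  # torch.Size([1, 1000])
    ```

    Because head and tail reference the very same child modules as resnet, in eval mode tail(head(x)) matches resnet(x) here; the original tail failed only because fc received the unflattened [1, 2048, 1, 1] pooled tensor.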

