본문 바로가기

Pytorch

[Pytorch] model layer, Sequential 변경하는 법

반응형

소개

딥러닝 모델 학습할 때에 여러가지 이유로 선언된 모델의 layer을 수정해야할 때가 있다. 만약 layer가 적다면 직접 변경해줄 수 있다. 그러나 모델의 layer가 많아 자동적으로 변경하고 싶거나, nn.Sequential이 사용되면 변경에 어려움을 느끼게 된다. 

이번 포스트에서는 nn.Sequential이 존재하는 model의 layer 수정법을 작성하였다.

 

정리

모델의 멤버변수를 return 해주는 named_children()과 클래스의 모듈들을 return 해주는 _modules를 활용한다.

 

코드

class ToyNet(nn.Module):
  def __init__(self):
    super(ToyNet, self).__init__()
    self.layer1 = nn.Sequential(
        nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1, bias=False),
        nn.BatchNorm2d(16),
        nn.ReLU(inplace=True)
    )
    self.layer2 = nn.Sequential(
        nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=2, padding=1, bias=False),
        nn.BatchNorm2d(32),
        nn.ReLU(inplace=True)
    )
    self.layer3 = nn.Sequential(
        nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=2, padding=1, bias=False),
        nn.BatchNorm2d(64),
        nn.ReLU(inplace=True)
    )
    self.layer4=nn.Conv2d(4,4,4)
    self.fc = nn.Sequential(
        nn.Linear(8*8*64, 10)
    )
    for m in self.modules():
      if isinstance(m, nn.Conv2d):
        n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        m.weight.data.normal_(0, math.sqrt(2./n))
      elif isinstance(m, nn.BatchNorm2d):
        m.weight.data.fill_(1)
        m.bias.data.zero_()

  def forward(self, x):
    out = self.layer1(x) #32x32
    out = self.layer2(out) #16x16
    out = self.layer3(out) #8x8

    out = out.view(x.size(0), -1)
    out = self.fc(out)
    
    return out

 

위와 같은 모델이 있다고 하자.

 

모델의 layer에 직접 접근할 때에는 해당 class의 멤버 변수를 리턴해주는 self.named_children()과 모듈을 리스트 형태로 볼 수 있는 self._modules()를 활용한다.

 

여기선 예제로 nn.Conv2d만을 찾아 변경하는 코드를 작성하도록 하겠다. 이를 보고 다양하게 응용할 수 있을 것이다.

model=ToyNet()
print('before change ====')
print(model)
for name, child in model.named_children():
    if isinstance(child, nn.Conv2d):
        model._modules[name]=nn.Conv2d(3,3,3)
    elif isinstance(child, nn.Sequential):
        for sname, schild in child.named_children():
            if isinstance(schild, nn.Conv2d):
                print(name,sname)
                model._modules[name]._modules[sname]=nn.Conv2d(3,3,3)

print(model)

 

변경 이전
변경 이후

 

반응형