반응형
소개
딥러닝 모델 학습할 때에 여러가지 이유로 선언된 모델의 layer을 수정해야할 때가 있다. 만약 layer가 적다면 직접 변경해줄 수 있다. 그러나 모델의 layer가 많아 자동적으로 변경하고 싶거나, nn.Sequential이 사용되면 변경에 어려움을 느끼게 된다.
이번 포스트에서는 nn.Sequential이 존재하는 model의 layer 수정법을 작성하였다.
정리
모델의 멤버변수를 return 해주는 named_children()과 클래스의 모듈들을 return 해주는 _modules를 활용한다.
코드
class ToyNet(nn.Module):
def __init__(self):
super(ToyNet, self).__init__()
self.layer1 = nn.Sequential(
nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1, bias=False),
nn.BatchNorm2d(16),
nn.ReLU(inplace=True)
)
self.layer2 = nn.Sequential(
nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(32),
nn.ReLU(inplace=True)
)
self.layer3 = nn.Sequential(
nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True)
)
self.layer4=nn.Conv2d(4,4,4)
self.fc = nn.Sequential(
nn.Linear(8*8*64, 10)
)
for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2./n))
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
def forward(self, x):
out = self.layer1(x) #32x32
out = self.layer2(out) #16x16
out = self.layer3(out) #8x8
out = out.view(x.size(0), -1)
out = self.fc(out)
return out
위와 같은 모델이 있다고 하자.
모델의 layer에 직접 접근할 때에는 해당 class의 멤버 변수를 리턴해주는 self.named_children()과 모듈을 리스트 형태로 볼 수 있는 self._modules()를 활용한다.
여기선 예제로 nn.Conv2d만을 찾아 변경하는 코드를 작성하도록 하겠다. 이를 보고 다양하게 응용할 수 있을 것이다.
model=ToyNet()
print('before change ====')
print(model)
for name, child in model.named_children():
if isinstance(child, nn.Conv2d):
model._modules[name]=nn.Conv2d(3,3,3)
elif isinstance(child, nn.Sequential):
for sname, schild in child.named_children():
if isinstance(schild, nn.Conv2d):
print(name,sname)
model._modules[name]._modules[sname]=nn.Conv2d(3,3,3)
print(model)
반응형
'Pytorch' 카테고리의 다른 글
[Pytorch] 재생산성을 위한 랜덤설정 reproductibility, randomness control, seed (0) | 2022.03.23 |
---|---|
[Pytorch] network(model) parameter 얻는 방법 (1) | 2021.09.26 |