본문 바로가기
Artificial Intelligence/Trouble shooting

(PyTorch) Missing keys & unexpected keys in state_dict when loading self trained model

by sohyunwriter 2022. 4. 23.

(Trouble)

Missing keys & unexpected keys in state_dict when loading self trained model

 

에러 예시1)

RuntimeError: Error(s) in loading state_dict for VGG:
        Missing key(s) in state_dict: "features.0.weight", "features.0.bias", "features.2.weight", "features.2.bias", "features.5.weight", "features.5.bias", "features.7.weight", "features.7.bias", "features.10.weight", "features.10.bias", "features.12.weight", "features.12.bias", "features.14.weight", "features.14.bias", "features.17.weight", "features.17.bias", "features.19.weight", "features.19.bias", "features.21.weight", "features.21.bias", "features.24.weight", "features.24.bias", "features.26.weight", "features.26.bias", "features.28.weight", "features.28.bias".
        Unexpected key(s) in state_dict: "features.module.0.weight", "features.module.0.bias", "features.module.2.weight", "features.module.2.bias", "features.module.5.weight", "features.module.5.bias", "features.module.7.weight", "features.module.7.bias", "features.module.10.weight", "features.module.10.bias", "features.module.12.weight", "features.module.12.bias", "features.module.14.weight", "features.module.14.bias", "features.module.17.weight", "features.module.17.bias", "features.module.19.weight", "features.module.19.bias", "features.module.21.weight", "features.module.21.bias", "features.module.24.weight", "features.module.24.bias", "features.module.26.weight", "features.module.26.bias", "features.module.28.weight", "features.module.28.bias".

 

에러 예시2)

 

...
Missing key(s) in state_dict: "1.weight", "1.bias", "4.weight", "4.bias", "7.weight", "7.bias".
Unexpected key(s) in state_dict: "model.1.weight", "model.1.bias", "model.4.weight", "model.4.bias", "model.7.weight", "model.7.bias".

 

 

내가 싫어하는(?) 에러 중에 하나가 모델 load state dict 단계에서 에러나는 거다.

처음에 해커톤에 나갔을 때 파이프라인을 다 만들고 모델까지 저장하고

이제 저장한 모델을 불러와서 인퍼런싱만 하면 된다는 기대감에 부풀어 있었을 때

아래와 같은 에러를 발견했을 때가 있었다.

 

 


 

(해결방법)

 

1) 모델 객체를 생성한다.

model = Model(**config.model)

 

2) checkpoint가 load되는지 확인한다.

checkpoint = torch.load(weights_path, map_location=self.device)['model_state_dict']

만약, model_state_dict가 key로 없다면 torch.load까지만 해보고 key가 뭐가 있는지 확인한다.

 

3-1) model에 checkpoint를 load한다.

self.model.load_state_dict(checkpoint)

 

3-2) 3-1)이 안된다면, 다음과 같이 model.을 ''으로 바꾼 후 checkpoint를 load한다.

for key in list(checkpoint.keys()):
    if 'model.' in key:
        checkpoint[key.replace('model.', '')] = checkpoint[key]
        del checkpoint[key]
self.model.load_state_dict(checkpoint)

 

3-3) model.을 ''으로 바꿨는데 안 되면 model. 부분을 module.으로 바꾸고 진행한다.

 

 

checkpoint 부분에 keys를 까보면 알겠지만 model.이나 module. 이 덧붙여져서 존재하지 않는 key로 인식해 못 찾아오는 경우도 있다. 이 경우 model. 이나 module.을 없애주면 제대로 key를 인식한다.

 

module.이 붙는 이유는 DataParallel을 사용해서 module.이 붙는 것 같다.

그러니까 DataParallel을 사용할 경우 model 저장할 때 다음과 같이 해야 올바르다.

torch.save(model.module.state_dict(), 'file_name.pt')

 

그런데 model.은 왜 붙는지 모르겠다. 그냥 load할 때 에러 뜨면 model. 그때마다 지워줘야할 듯.