From 5b02db5b2b83f1f7a4bc66ae642da201b9a65be2 Mon Sep 17 00:00:00 2001 From: Kapulkin Stanislav Date: Wed, 9 Feb 2022 17:53:40 +0300 Subject: [PATCH] [+] save model state dict and load it --- src/train/loadModel.py | 46 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) create mode 100644 src/train/loadModel.py diff --git a/src/train/loadModel.py b/src/train/loadModel.py new file mode 100644 index 0000000..c3d2842 --- /dev/null +++ b/src/train/loadModel.py @@ -0,0 +1,46 @@ +import torch +import torch.nn as nn + +from src.models.modnet import MODNet + +def makeStateDict(modelPath): + modnet = MODNet(backbone_pretrained=False) + + torch.save(modnet.state_dict(), modelPath) + +def loadStateDict(modelPath): + modelState = torch.load(modelPath, map_location=torch.device('cpu')) + + state = {} + + prefix = "module." + for key in modelState: + stateKey = prefix + key + state[stateKey] = modelState[key] + return state + +def main(): + modelPath = "models/model.ckpt" + pretrainedModelPath = "pretrained/modnet_webcam_portrait_matting.ckpt" + + makeStateDict(modelPath) + + modnet = MODNet(backbone_pretrained=False) + modnet = nn.DataParallel(modnet) + + state = loadStateDict(modelPath) + stateKeys = list(state.keys()) + print(f"state keys {stateKeys[:5]}") + + modnet.load_state_dict(state) + + pretrainedState = torch.load(pretrainedModelPath, map_location=torch.device('cpu')) + pretrainedStateKeys = list(pretrainedState.keys()) + print(f"pretrainedState keys {pretrainedStateKeys[:5]}") + + modnet.load_state_dict(pretrainedState) + + print(f"state {len(stateKeys)}, preptrainedState {len(pretrainedStateKeys)}, intersection {len(set(stateKeys) & set(pretrainedStateKeys))}") + +if __name__ == "__main__": + main()