[+] save model state dict and load it

pull/170/head
Kapulkin Stanislav 2022-02-09 17:53:40 +03:00
parent 4e2f7ddfbe
commit 5b02db5b2b
1 changed files with 46 additions and 0 deletions

46
src/train/loadModel.py Normal file
View File

@ -0,0 +1,46 @@
import torch
import torch.nn as nn
from src.models.modnet import MODNet
def makeStateDict(modelPath):
modnet = MODNet(backbone_pretrained=False)
torch.save(modnet.state_dict(), modelPath)
def loadStateDict(modelPath):
modelState = torch.load(modelPath, map_location=torch.device('cpu'))
state = {}
prefix = "module."
for key in modelState:
stateKey = prefix + key
state[stateKey] = modelState[key]
return state
def main():
modelPath = "models/model.ckpt"
pretrainedModelPath = "pretrained/modnet_webcam_portrait_matting.ckpt"
makeStateDict(modelPath)
modnet = MODNet(backbone_pretrained=False)
modnet = nn.DataParallel(modnet)
state = loadStateDict(modelPath)
stateKeys = list(state.keys())
print(f"state keys {stateKeys[:5]}")
modnet.load_state_dict(state)
pretrainedState = torch.load(pretrainedModelPath, map_location=torch.device('cpu'))
pretrainedStateKeys = list(pretrainedState.keys())
print(f"pretrainedState keys {pretrainedStateKeys[:5]}")
modnet.load_state_dict(pretrainedState)
print(f"state {len(stateKeys)}, preptrainedState {len(pretrainedStateKeys)}, intersection {len(set(stateKeys) & set(pretrainedStateKeys))}")
if __name__ == "__main__":
main()