mirror of https://github.com/ZHKKKe/MODNet.git
[+] save model state dict and load it
parent
4e2f7ddfbe
commit
5b02db5b2b
|
|
@ -0,0 +1,46 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from src.models.modnet import MODNet
|
||||
|
||||
def makeStateDict(modelPath):
|
||||
modnet = MODNet(backbone_pretrained=False)
|
||||
|
||||
torch.save(modnet.state_dict(), modelPath)
|
||||
|
||||
def loadStateDict(modelPath):
|
||||
modelState = torch.load(modelPath, map_location=torch.device('cpu'))
|
||||
|
||||
state = {}
|
||||
|
||||
prefix = "module."
|
||||
for key in modelState:
|
||||
stateKey = prefix + key
|
||||
state[stateKey] = modelState[key]
|
||||
return state
|
||||
|
||||
def main():
|
||||
modelPath = "models/model.ckpt"
|
||||
pretrainedModelPath = "pretrained/modnet_webcam_portrait_matting.ckpt"
|
||||
|
||||
makeStateDict(modelPath)
|
||||
|
||||
modnet = MODNet(backbone_pretrained=False)
|
||||
modnet = nn.DataParallel(modnet)
|
||||
|
||||
state = loadStateDict(modelPath)
|
||||
stateKeys = list(state.keys())
|
||||
print(f"state keys {stateKeys[:5]}")
|
||||
|
||||
modnet.load_state_dict(state)
|
||||
|
||||
pretrainedState = torch.load(pretrainedModelPath, map_location=torch.device('cpu'))
|
||||
pretrainedStateKeys = list(pretrainedState.keys())
|
||||
print(f"pretrainedState keys {pretrainedStateKeys[:5]}")
|
||||
|
||||
modnet.load_state_dict(pretrainedState)
|
||||
|
||||
print(f"state {len(stateKeys)}, preptrainedState {len(pretrainedStateKeys)}, intersection {len(set(stateKeys) & set(pretrainedStateKeys))}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Loading…
Reference in New Issue