import logging
import os
import torch
from collections import OrderedDict
from espnet.asr.asr_utils import get_model_conf
from espnet.asr.asr_utils import torch_load
from espnet.nets.asr_interface import ASRInterface
from espnet.nets.mt_interface import MTInterface
from espnet.utils.dynamic_import import dynamic_import

def transfer_verification(model_state_dict, partial_state_dict, modules):
    """Verify that (key, shape) tuples of the specified modules match.

    Args:
        model_state_dict (odict): the initial model state_dict
        partial_state_dict (odict): the trained model state_dict
        modules (list): specified module list for transfer

    Return:
        (boolean): allow transfer

    """
    modules_model = []
    partial_modules = []

    for key_p, value_p in partial_state_dict.items():
        if any(key_p.startswith(m) for m in modules):
            partial_modules += [(key_p, value_p.shape)]

    for key_m, value_m in model_state_dict.items():
        if any(key_m.startswith(m) for m in modules):
            modules_model += [(key_m, value_m.shape)]

    len_match = (len(modules_model) == len(partial_modules))

    module_match = (sorted(modules_model, key=lambda x: (x[0], x[1])) ==
                    sorted(partial_modules, key=lambda x: (x[0], x[1])))

    return len_match and module_match
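

# A minimal sketch (not part of the original module) showing how
# transfer_verification compares (key, shape) pairs between two state_dicts;
# the "enc."/"dec." prefixes and tensor shapes below are hypothetical.
def _example_transfer_verification():
    model_state_dict = OrderedDict([("enc.weight", torch.zeros(4, 2)),
                                    ("dec.weight", torch.zeros(3, 2))])
    partial_state_dict = OrderedDict([("enc.weight", torch.ones(4, 2))])
    # Only "enc." is requested and its key/shape matches, so this returns True.
    return transfer_verification(model_state_dict, partial_state_dict, ["enc."])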


def get_partial_asr_mt_state_dict(model_state_dict, modules):
    """Create a state_dict containing only the specified modules.

    Args:
        model_state_dict (odict): trained model state_dict
        modules (list): specified module list for transfer

    Return:
        new_state_dict (odict): the updated state_dict

    """
    new_state_dict = OrderedDict()

    for key, value in model_state_dict.items():
        if any(key.startswith(m) for m in modules):
            new_state_dict[key] = value

    return new_state_dict


def get_partial_lm_state_dict(model_state_dict, modules):
    """Create a compatible ASR decoder state_dict from an LM state_dict.

    The keys of the specified modules are renamed to match ASR decoder module keys.

    Args:
        model_state_dict (odict): trained model state_dict
        modules (list): specified module list for transfer

    Return:
        new_state_dict (odict): the updated state_dict
        new_mods (list): the updated module list

    """
    new_state_dict = OrderedDict()
    new_modules = []

    for key, value in list(model_state_dict.items()):
        if key == "predictor.embed.weight" \
                and "predictor.embed." in modules:
            new_key = "dec.embed.weight"
            new_state_dict[new_key] = value
            new_modules += [new_key]
        elif "predictor.rnn." in key \
                and "predictor.rnn." in modules:
            new_key = "dec.decoder." + key.split("predictor.rnn.", 1)[1]
            new_state_dict[new_key] = value
            new_modules += [new_key]

    return new_state_dict, new_modules
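

# A minimal sketch (not part of the original module) of the LM-to-ASR key
# renaming done by get_partial_lm_state_dict; the tensor sizes are
# hypothetical, the key names follow the RNNLM/decoder convention above.
def _example_get_partial_lm_state_dict():
    lm_state_dict = OrderedDict([
        ("predictor.embed.weight", torch.zeros(10, 4)),
        ("predictor.rnn.0.weight_ih", torch.zeros(16, 4)),
    ])
    modules = ["predictor.embed.", "predictor.rnn."]
    new_state_dict, new_modules = get_partial_lm_state_dict(lm_state_dict, modules)
    # new_modules == ["dec.embed.weight", "dec.decoder.0.weight_ih"]
    return new_state_dict, new_modules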


def filter_modules(model_state_dict, modules):
    """Filter out specified modules that do not match any key in model_state_dict.

    Args:
        model_state_dict (odict): trained model state_dict
        modules (list): specified module list for transfer

    Return:
        new_mods (list): the updated module list

    """
    new_mods = []
    incorrect_mods = []

    mods_model = list(model_state_dict.keys())
    for mod in modules:
        if any(key.startswith(mod) for key in mods_model):
            new_mods += [mod]
        else:
            incorrect_mods += [mod]

    if incorrect_mods:
        logging.info("module(s) %s don't match or only partially match "
                     "the available modules in the model.", incorrect_mods)
        logging.info('for information, the existing modules in model are:')
        logging.info('%s', mods_model)

    return new_mods
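

# A minimal sketch (not part of the original module): filter_modules keeps
# only the requested prefixes that actually occur in the source state_dict.
# The "enc."/"ctc." prefixes here are hypothetical.
def _example_filter_modules():
    model_state_dict = OrderedDict([("enc.weight", torch.zeros(2, 2))])
    # "ctc." is dropped (and logged) because no key starts with it.
    return filter_modules(model_state_dict, ["enc.", "ctc."])  # -> ["enc."]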


def load_trained_model(model_path):
    """Load the trained model for recognition.

    Args:
        model_path (str): Path to model.***.best

    """
    idim, odim, train_args = get_model_conf(
        model_path, os.path.join(os.path.dirname(model_path), 'model.json'))

    logging.info('reading model parameters from ' + model_path)

    if hasattr(train_args, "model_module"):
        model_module = train_args.model_module
    else:
        model_module = "espnet.nets.pytorch_backend.e2e_asr:E2E"
    model_class = dynamic_import(model_module)
    model = model_class(idim, odim, train_args)
    torch_load(model_path, model)

    return model, train_args
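

# A minimal usage sketch (not part of the original module). The snapshot
# path below is hypothetical; it assumes a model.json sits next to the
# snapshot, as written by the ESPnet training scripts.
def _example_load_trained_model():
    model, train_args = load_trained_model("exp/train/results/model.acc.best")
    model.eval()  # switch to evaluation mode before recognition
    return model, train_args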


def get_trained_model_state_dict(model_path):
    """Extract the trained model state_dict for pre-initialization.

    Args:
        model_path (str): Path to model.***.best

    Return:
        model.state_dict() (odict): the loaded model state_dict
        (str): Type of model. Either ASR/MT or LM.

    """
    conf_path = os.path.join(os.path.dirname(model_path), 'model.json')
    if 'rnnlm' in model_path:
        logging.info('reading model parameters from %s', model_path)
        return torch.load(model_path), 'lm'

    idim, odim, args = get_model_conf(model_path, conf_path)

    logging.info('reading model parameters from ' + model_path)

    if hasattr(args, "model_module"):
        model_module = args.model_module
    else:
        model_module = "espnet.nets.pytorch_backend.e2e_asr:E2E"
    model_class = dynamic_import(model_module)
    model = model_class(idim, odim, args)
    torch_load(model_path, model)
    assert isinstance(model, (MTInterface, ASRInterface))

    return model.state_dict(), 'asr-mt'


def load_trained_modules(idim, odim, args):
    """Load the model encoder and/or decoder modules from ESPnet pre-trained model(s).

    Args:
        idim (int): initial input dimension.
        odim (int): initial output dimension.
        args (namespace): The initial model arguments.

    Return:
        model (torch.nn.Module): The model with pretrained modules.

    """
    enc_model_path = args.enc_init
    dec_model_path = args.dec_init
    enc_modules = args.enc_init_mods
    dec_modules = args.dec_init_mods

    model_class = dynamic_import(args.model_module)
    main_model = model_class(idim, odim, args)
    assert isinstance(main_model, ASRInterface)

    main_state_dict = main_model.state_dict()

    logging.info('model(s) found for pre-initialization')
    for model_path, modules in [(enc_model_path, enc_modules),
                                (dec_model_path, dec_modules)]:
        if model_path is not None:
            if os.path.isfile(model_path):
                model_state_dict, mode = get_trained_model_state_dict(model_path)

                modules = filter_modules(model_state_dict, modules)

                if mode == 'lm':
                    partial_state_dict, modules = get_partial_lm_state_dict(
                        model_state_dict, modules)
                else:
                    partial_state_dict = get_partial_asr_mt_state_dict(
                        model_state_dict, modules)

                if partial_state_dict:
                    if transfer_verification(main_state_dict,
                                             partial_state_dict, modules):
                        logging.info('loading %s from model: %s',
                                     modules, model_path)
                        main_state_dict.update(partial_state_dict)
                    else:
                        logging.info("modules %s in model %s don't match "
                                     "your training config",
                                     modules, model_path)
            else:
                logging.info('model was not found: %s', model_path)

    main_model.load_state_dict(main_state_dict)

    return main_model
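

# A minimal usage sketch (not part of the original module) of end-to-end
# transfer initialization. The argparse namespace is assumed to already carry
# model_module, idim/odim, and the usual training options; all paths and
# module lists below are hypothetical.
def _example_load_trained_modules(idim, odim, train_args):
    train_args.enc_init = "exp/pretrained_asr/results/model.acc.best"
    train_args.enc_init_mods = ["enc."]
    train_args.dec_init = "exp/pretrained_lm/rnnlm.model.best"
    train_args.dec_init_mods = ["predictor.embed.", "predictor.rnn."]
    return load_trained_modules(idim, odim, train_args)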