# Copyright 2021 Tianzi Wang
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

# Thanks to Abdelrahman Mohamed and Wei-Ning Hsu for their help with this
# implementation. Their original Hubert work is in:
#     Paper: https://arxiv.org/pdf/2106.07447.pdf
#     Code in Fairseq: https://github.com/pytorch/fairseq/tree/master/examples/hubert
"""Encoder definition."""
import contextlib
import copy
import logging
import os
import torch
import yaml
from filelock import FileLock
from pathlib import Path
from typeguard import check_argument_types
from typing import Optional
from typing import Tuple
from espnet.nets.pytorch_backend.nets_utils import make_pad_mask
from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm
from espnet2.asr.encoder.abs_encoder import AbsEncoder


class FairseqHubertEncoder(AbsEncoder):
    """FairSeq Hubert encoder module, used for loading pretrained weights and fine-tuning.

    Args:
        input_size: input dim
        hubert_url: url to Hubert pretrained model
        hubert_dir_path: directory to download the Hubert pretrained model.
        output_size: dimension of attention
        normalize_before: whether to use layer_norm before the first block
        freeze_finetune_updates: steps that freeze all layers except the output layer
            before tuning the whole model (necessary to prevent overfitting).
        dropout_rate: dropout rate
        activation_dropout: dropout rate in activation function
        attention_dropout: dropout rate in attention
    Hubert specific Args:
        Please refer to:
        https://github.com/pytorch/fairseq/blob/master/fairseq/models/hubert/hubert.py
    """

    def __init__(
        self,
        input_size: int,
        hubert_url: str = "./",
        hubert_dir_path: str = "./",
        output_size: int = 256,
        normalize_before: bool = False,
        freeze_finetune_updates: int = 0,
        dropout_rate: float = 0.0,
        activation_dropout: float = 0.1,
        attention_dropout: float = 0.0,
        mask_length: int = 10,
        mask_prob: float = 0.75,
        mask_selection: str = "static",
        mask_other: int = 0,
        apply_mask: bool = True,
        mask_channel_length: int = 64,
        mask_channel_prob: float = 0.5,
        mask_channel_other: int = 0,
        mask_channel_selection: str = "static",
        layerdrop: float = 0.1,
        feature_grad_mult: float = 0.0,
    ):
        assert check_argument_types()
        super().__init__()
        self.apply_mask = apply_mask

        try:
            import fairseq
            from fairseq.models.hubert.hubert import HubertModel
        except Exception as e:
            print("Error: FairSeq is not properly installed.")
            print("Please install FairSeq: cd ${MAIN_ROOT}/tools && make fairseq.done")
            raise e
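
        # These options override the corresponding fields stored in the
        # pretrained checkpoint's config when fairseq loads the model below.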
        arg_overrides = {
            "dropout": dropout_rate,
            "activation_dropout": activation_dropout,
            "attention_dropout": attention_dropout,
            "mask_length": mask_length,
            "mask_prob": mask_prob,
            "mask_selection": mask_selection,
            "mask_other": mask_other,
            "mask_channel_length": mask_channel_length,
            "mask_channel_prob": mask_channel_prob,
            "mask_channel_selection": mask_channel_selection,
            "mask_channel_other": mask_channel_other,
            "encoder_layerdrop": layerdrop,
            "feature_grad_mult": feature_grad_mult,
            "data": hubert_dir_path,
        }
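
        # Two loading paths: hubert_url == "espnet" loads a checkpoint produced
        # by the ESPnet pretraining recipe (hubert_dir_path is then the path to
        # that checkpoint file); any other value is treated as a URL to a
        # fairseq checkpoint, which is downloaded first.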
if hubert_url == "espnet":
self.hubert_model_path = hubert_dir_path
s = torch.load(
self.hubert_model_path,
map_location=torch.device("cpu"),
)
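
            # ESPnet checkpoints store the Hubert weights under an
            # "encoder.encoder." prefix; strip it and drop the label
            # embeddings, which are only needed during pretraining.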
if all("encoder.encoder" in k for k in s):
try:
state = {
k.replace("encoder.encoder.", ""): v
for k, v in s.items()
if "label_embs_concat" not in k
}
except Exception as e:
raise e
config_file = os.path.join(
"/".join(self.hubert_model_path.split("/")[:-1]),
"config.yaml",
)
config_file = Path(config_file)
with config_file.open("r", encoding="utf-8") as f:
self.pretrained_cfg = yaml.safe_load(f)
model = FairseqHubertPretrainEncoder(
input_size=self.pretrained_cfg["input_size"],
hubert_dict=self.pretrained_cfg["hubert_dict"],
**self.pretrained_cfg["encoder_conf"],
)
model = model.encoder
d = self.pretrained_cfg["encoder_conf"]["output_size"]
self.pretrained_params = copy.deepcopy(state)
        else:
            self.hubert_model_path = download_hubert(hubert_url, hubert_dir_path)

            (
                models,
                self.pretrained_cfg,
                task,
            ) = fairseq.checkpoint_utils.load_model_ensemble_and_task(
                [self.hubert_model_path],
                arg_overrides=arg_overrides,
                strict=False,
            )
            model = models[0]
            d = self.pretrained_cfg.model.encoder_embed_dim
            self.pretrained_params = copy.deepcopy(model.state_dict())
        self._output_size = output_size

        if not isinstance(model, HubertModel):
            try:
                model = model.hubert_encoder.hubert_model
            except Exception as e:
                print(
                    "Error: pretrained models should be within: "
                    "'HubertModel, Hubertctc' classes, etc."
                )
                raise e

        self.encoders = model

        self.normalize_before = normalize_before
        if self.normalize_before:
            self.after_norm = LayerNorm(output_size)
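
        # Add a linear projection when the requested output_size differs from
        # the pretrained encoder's embedding dim d; otherwise pass through.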
        if output_size and output_size != d:
            self.output_layer = torch.nn.Sequential(
                torch.nn.Linear(d, output_size),
            )
        else:
            self.output_layer = None

        self.freeze_finetune_updates = freeze_finetune_updates
        self.register_buffer("num_updates", torch.LongTensor([0]))

    def output_size(self) -> int:
        return self._output_size

    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        prev_states: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Forward Hubert ASR Encoder.

        Args:
            xs_pad: input tensor (B, L, D)
            ilens: input length (B)
            prev_states: Not to be used now.
        Returns:
            position embedded tensor and mask
        """
        masks = make_pad_mask(ilens).to(xs_pad.device)
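
        # Freeze the pretrained weights for the first freeze_finetune_updates
        # forward passes; num_updates is a registered buffer, so the counter
        # survives checkpointing.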
        ft = self.freeze_finetune_updates <= self.num_updates
        if self.num_updates <= self.freeze_finetune_updates:
            self.num_updates += 1
        elif ft and self.num_updates == self.freeze_finetune_updates + 1:
            self.num_updates += 1
            logging.info("Start fine-tuning hubert parameters!")
        else:
            self.num_updates += 1
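
        # While still frozen, run the encoder under no_grad so that no
        # gradients flow into the pretrained parameters.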
        with torch.no_grad() if not ft else contextlib.nullcontext():
            enc_outputs = self.encoders(
                xs_pad,
                padding_mask=masks,
                mask=self.apply_mask and self.training,
                features_only=True,
                output_layer=None,
            )

        xs_pad = enc_outputs["x"]  # (B, T, C)
        masks = enc_outputs["padding_mask"]  # (B, T)

        # save gpu memory
        del enc_outputs

        olens = (~masks).sum(dim=1)

        if self.output_layer is not None:
            xs_pad = self.output_layer(xs_pad)

        if self.normalize_before:
            xs_pad = self.after_norm(xs_pad)

        return xs_pad, olens, None

    def reload_pretrained_parameters(self):
        self.encoders.load_state_dict(self.pretrained_params, strict=False)
        logging.info("Pretrained Hubert model parameters reloaded!")


class FairseqHubertPretrainEncoder(AbsEncoder):
    """FairSeq Hubert pretrain encoder module, only used for the pretraining stage.

    Args:
        input_size: input dim
        output_size: dimension of attention
        linear_units: dimension of feedforward layers
        attention_heads: the number of heads of multi head attention
        num_blocks: the number of encoder blocks
        dropout_rate: dropout rate
        attention_dropout_rate: dropout rate in attention
        activation_dropout_rate: dropout rate in activation function
        hubert_dict: target dictionary for Hubert pretraining
        label_rate: label frame rate. -1 for sequence label
        checkpoint_activations: whether to use activation checkpointing
        sample_rate: target sample rate.
        use_amp: whether to use automatic mixed precision
        normalize_before: whether to use layer_norm before the first block
    """

    def __init__(
        self,
        input_size: int = 1,
        output_size: int = 1024,
        linear_units: int = 1024,
        attention_heads: int = 12,
        num_blocks: int = 12,
        dropout_rate: float = 0.0,
        attention_dropout_rate: float = 0.0,
        activation_dropout_rate: float = 0.0,
        hubert_dict: str = "./dict.txt",
        label_rate: int = 100,
        checkpoint_activations: bool = False,
        sample_rate: int = 16000,
        use_amp: bool = False,
        **kwargs,
    ):
        assert check_argument_types()
        super().__init__()
        self._output_size = output_size
        self.use_amp = use_amp

        try:
            from fairseq.data.dictionary import Dictionary
            from fairseq.models.hubert.hubert import (
                HubertConfig,  # noqa: H301
                HubertModel,  # noqa: H301
                HubertPretrainingConfig,  # noqa: H301
            )
        except Exception as e:
            print("Error: FairSeq is not properly installed.")
            print("Please install FairSeq: cd ${MAIN_ROOT}/tools && make fairseq.done")
            raise e

        cfg_overrides = {
            "encoder_embed_dim": output_size,
            "encoder_ffn_embed_dim": linear_units,
            "encoder_attention_heads": attention_heads,
            "encoder_layers": num_blocks,
            "final_dim": output_size,
            "dropout": dropout_rate,
            "attention_dropout": attention_dropout_rate,
            "label_rate": label_rate,
            "checkpoint_activations": checkpoint_activations,
        }
        cfg_overrides = {**cfg_overrides, **kwargs}
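
        # Copy only fields that actually exist on fairseq's HubertConfig;
        # unknown keys (e.g. stray entries in kwargs) are silently ignored.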
        self.cfg = HubertConfig()
        for key, value in cfg_overrides.items():
            if hasattr(self.cfg, key):
                setattr(self.cfg, key, value)

        hubert_task_cfg = HubertPretrainingConfig()
        hubert_task_cfg_overrides = {
            "label_rate": label_rate,
            "sample_rate": sample_rate,
        }
        for key, value in hubert_task_cfg_overrides.items():
            if hasattr(hubert_task_cfg, key):
                setattr(hubert_task_cfg, key, value)
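
        # Build the pretraining target dictionary: load it from hubert_dict if
        # the file exists, otherwise fall back to a dummy one-symbol dictionary.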
        d = Dictionary()
        self._build_dictionary(d, hubert_dict)
        self.encoder = HubertModel(self.cfg, hubert_task_cfg, self.dictionaries)

    def _build_dictionary(self, dictionary, hubert_dict_path):
        if os.path.exists(f"{hubert_dict_path}"):
            setattr(dictionary, "symbols", [])
            setattr(dictionary, "count", [])
            setattr(dictionary, "indices", {})
            dictionary.add_from_file(f"{hubert_dict_path}")
        else:
            dictionary.add_symbol("0")

        self.dictionaries = [dictionary]

    def output_size(self) -> int:
        return self._output_size

    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_length: torch.Tensor,
        prev_states: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Forward Hubert Pretrain Encoder.

        Args:
            xs_pad: input tensor (B, L, D)
            ilens: input length (B)
            ys_pad: pseudo-label target tensor (B, T)
            ys_pad_length: pseudo-label target length (B)
            prev_states: Not to be used now.
        Returns:
            position embedded tensor and mask
        """
        self.cast_mask_emb()
        masks = make_pad_mask(ilens).to(xs_pad.device)
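
        # The fairseq Hubert model expects equal-length target sequences
        # within a batch, so truncate every target to the shortest one.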
        ys_pad = ys_pad[:, : min(ys_pad_length)]
        enc_outputs = self.encoder(
            xs_pad,
            padding_mask=masks,
            mask=True,
            target_list=[ys_pad],
            features_only=False,
        )
        return enc_outputs

    def cast_mask_emb(self):
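        # Under AMP the extracted features are half precision, so the mask
        # embedding must be cast to match before the fairseq forward uses it.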
        # Compare against the dtype torch.half: the original comparison with
        # the tensor class torch.cuda.HalfTensor is never equal to a dtype,
        # so the cast would re-run on every call.
        if self.use_amp and self.encoder.mask_emb.dtype != torch.half:
            self.encoder.mask_emb = torch.nn.Parameter(self.encoder.mask_emb.half())

    def reload_pretrained_parameters(self):
        self.encoder.mask_emb = torch.nn.Parameter(
            torch.HalfTensor(self.cfg.encoder_embed_dim).uniform_()
        )
        logging.info(
            "Hubert mask embedding re-initialized!, "
            f"{self.encoder.mask_emb.dtype}, {self.use_amp}"
        )
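

# A minimal pretraining-forward sketch (hedged: the shapes and the all-zero
# pseudo labels below are illustrative assumptions; a real recipe supplies
# k-means cluster labels at label_rate frames per second):
#
#     encoder = FairseqHubertPretrainEncoder(
#         input_size=1,
#         output_size=256,
#         attention_heads=4,
#         num_blocks=2,
#         hubert_dict="./dict.txt",  # falls back to a dummy dict if missing
#     )
#     xs_pad = torch.randn(2, 16000)                  # raw waveform (B, L)
#     ilens = torch.LongTensor([16000, 16000])
#     ys_pad = torch.zeros(2, 100, dtype=torch.long)  # pseudo labels (B, T)
#     ys_pad_length = torch.LongTensor([100, 100])
#     enc_outputs = encoder(xs_pad, ilens, ys_pad, ys_pad_length)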


def download_hubert(model_url, dir_path):
    os.makedirs(dir_path, exist_ok=True)
    model_name = model_url.split("/")[-1]
    model_path = os.path.join(dir_path, model_name)
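    # Serialize the download across workers: when several distributed
    # processes start together, only one fetches the file while the others
    # block on the lock and then find the checkpoint already present.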
    with FileLock(model_path + ".lock"):
        if not os.path.exists(model_path):
            torch.hub.download_url_to_file(model_url, model_path)
            logging.info(f"Hubert model downloaded {model_path}")
        else:
            logging.info(f"Hubert model {model_path} already exists.")

    return model_path