diff --git a/fairseq/dataclass/configs.py b/fairseq/dataclass/configs.py
index af957fec64..920a23db9c 100644
--- a/fairseq/dataclass/configs.py
+++ b/fairseq/dataclass/configs.py
@@ -1126,16 +1126,16 @@ class EMAConfig(FairseqDataclass):
 
 @dataclass
 class FairseqConfig(FairseqDataclass):
-    common: CommonConfig = CommonConfig()
-    common_eval: CommonEvalConfig = CommonEvalConfig()
-    distributed_training: DistributedTrainingConfig = DistributedTrainingConfig()
-    dataset: DatasetConfig = DatasetConfig()
-    optimization: OptimizationConfig = OptimizationConfig()
-    checkpoint: CheckpointConfig = CheckpointConfig()
-    bmuf: FairseqBMUFConfig = FairseqBMUFConfig()
-    generation: GenerationConfig = GenerationConfig()
-    eval_lm: EvalLMConfig = EvalLMConfig()
-    interactive: InteractiveConfig = InteractiveConfig()
+    common: CommonConfig = field(default_factory=CommonConfig)
+    common_eval: CommonEvalConfig = field(default_factory=CommonEvalConfig)
+    distributed_training: DistributedTrainingConfig = field(default_factory=DistributedTrainingConfig)
+    dataset: DatasetConfig = field(default_factory=DatasetConfig)
+    optimization: OptimizationConfig = field(default_factory=OptimizationConfig)
+    checkpoint: CheckpointConfig = field(default_factory=CheckpointConfig)
+    bmuf: FairseqBMUFConfig = field(default_factory=FairseqBMUFConfig)
+    generation: GenerationConfig = field(default_factory=GenerationConfig)
+    eval_lm: EvalLMConfig = field(default_factory=EvalLMConfig)
+    interactive: InteractiveConfig = field(default_factory=InteractiveConfig)
     model: Any = MISSING
     task: Any = None
     criterion: Any = None
@@ -1144,4 +1144,4 @@ class FairseqConfig(FairseqDataclass):
     scoring: Any = None
     bpe: Any = None
     tokenizer: Any = None
-    ema: EMAConfig = EMAConfig()
+    ema: EMAConfig = field(default_factory=EMAConfig)
diff --git a/fairseq/models/hubert/hubert.py b/fairseq/models/hubert/hubert.py
index cc3b777efd..8ea684d71e 100644
--- a/fairseq/models/hubert/hubert.py
+++ b/fairseq/models/hubert/hubert.py
@@ -354,6 +354,7 @@ def apply_mask(self, x, padding_mask, target_list):
                 min_masks=2,
                 no_overlap=self.no_mask_overlap,
                 min_space=self.mask_min_space,
+                require_same_masks=False,
             )
             mask_indices = torch.from_numpy(mask_indices).to(x.device)
             x[mask_indices] = self.mask_emb