
The Design Space of Tri-Modal Masked Diffusion Models

Abstract

Discrete diffusion models have emerged as strong alternatives to autoregressive language models, with recent work initializing and finetuning a base unimodal model for bi-modal generation. Diverging from previous approaches, we introduce the first tri-modal Masked Diffusion Models (MDM) pretrained from scratch on text, image-text, and audio-text data. We systematically analyze multimodal scaling laws, modality mixing ratios, noise schedules, and batch-size effects, and provide optimized inference sampling defaults. Our batch-size analysis yields a novel stochastic differential equation (SDE) based reparameterization, eliminating the need to tune the optimal batch size as reported in recent work. This reparameterization decouples the physical batch size, often chosen for compute efficiency (GPU saturation, FLOP efficiency, wall-clock time), from the logical batch size, chosen to balance the variance of gradients during stochastic optimization. Finally, we pretrain a preliminary model showcasing the capabilities of a unified design, achieving strong results at the 3B model scale (6.4T tokens) across text generation, text-to-image (T2I), and text-to-speech (T2S) tasks. Our work represents the largest-scale systematic open study of multimodal discrete diffusion models conducted to date, providing valuable insights into scaling behaviors across multiple modalities.

One-sentence Summary

This study introduces the first tri-modal Masked Diffusion Models (MDM) pretrained from scratch on text, image-text, and audio-text data, providing a systematic analysis of multimodal scaling laws and a novel SDE-based reparameterization that decouples physical and logical batch sizes to optimize cross-modal generation.

Key Contributions

  • This work introduces the first tri-modal Masked Diffusion Model (MDM) pretrained from scratch on a unified stream of text, image, and audio tokens using a single transformer backbone. The architecture enables flexible cross-modal tasks such as text-to-image and text-to-speech without requiring modality-specific heads or bespoke factorizations.
  • The paper presents a novel stochastic differential equation (SDE) based reparameterization that decouples the physical batch size used for hardware efficiency from the logical batch size used for gradient variance control. This method eliminates the need for expensive manual tuning of optimal batch sizes during large-scale training.
  • The study provides a systematic analysis of multimodal scaling laws, modality mixing ratios, and noise schedules, supported by a 3B parameter model trained on 6.4T tokens. Results demonstrate strong performance across text generation, text-to-image, and text-to-speech tasks, while also validating that interventions like anti-masking improve benchmark performance without increasing computational costs.

Introduction

While causal transformers dominate modern sequence modeling, they rely on a strict left-to-right factorization that may not be optimal for conditional generation tasks where evidence is scattered across different modalities. Discrete diffusion models offer a bidirectional alternative through iterative refinement, yet existing multimodal research often relies on adapting pretrained unimodal models or focuses only on bi-modal (text and image) setups. This limits the ability to create truly unified systems capable of handling diverse data streams like audio.

The authors leverage a tri-modal Masked Diffusion Model (MDM) pretrained from scratch on text, image, and audio data using a single transformer backbone and a unified discrete token space. They introduce an SDE-based reparameterization that decouples physical batch size from logical batch size, effectively eliminating the need to tune for an optimal batch size during training. Furthermore, the authors provide a systematic study of multimodal scaling laws and demonstrate that optimal inference parameters, such as noise schedules and guidance, must be tailored specifically to each modality.
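
The paper's SDE-based reparameterization itself is not reproduced here; the sketch below only illustrates the decoupling it is meant to provide, using plain gradient accumulation in PyTorch: the physical micro-batch size is picked for hardware throughput, while the logical batch size governs the variance of each optimizer update. All names (`model`, `loss_fn`, `data_iter`) are hypothetical placeholders.

```python
# Minimal sketch (not the paper's exact method): decouple the physical
# micro-batch size, chosen for GPU throughput, from the logical batch size,
# chosen for gradient-variance control, via gradient accumulation.
import torch

def train_step(model, optimizer, loss_fn, data_iter,
               logical_batch_size: int, physical_batch_size: int):
    assert logical_batch_size % physical_batch_size == 0
    accum_steps = logical_batch_size // physical_batch_size

    optimizer.zero_grad(set_to_none=True)
    for _ in range(accum_steps):
        batch = next(data_iter)                      # physical micro-batch
        loss = loss_fn(model, batch) / accum_steps   # average over the logical batch
        loss.backward()                              # gradients accumulate in .grad
    optimizer.step()                                 # one update per logical batch
```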

Method

The authors propose a unified modeling framework designed to handle multiple modalities, specifically text, audio, and image, within a single transformer-based architecture. To achieve this, they construct a shared vocabulary $\mathcal{V}$ through the disjoint union of modality-specific vocabularies: $\mathcal{V} = \mathcal{V}_{\text{text}} \sqcup \mathcal{V}_{\text{audio}} \sqcup \mathcal{V}_{\text{image}}$. This unified vocabulary is augmented with modality-specific special tokens, such as $\text{BOS}_m$, $\text{EOS}_m$, and $\text{MASK}_m$ for each modality $m \in \{\text{text}, \text{audio}, \text{image}\}$, as well as task-specific tokens $\mathcal{V}_{\text{task}}$ that signal the intended operation, such as $\text{TASK}_{\text{text}}$ or $\text{TASK}_{\text{audio-text}}$.
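
A minimal sketch of how such a unified token space could be assembled; the vocabulary sizes, task set, and token names below are illustrative assumptions, not values from the paper.

```python
# Minimal sketch: unified vocabulary as the disjoint union of per-modality
# vocabularies, augmented with modality-specific BOS/EOS/MASK tokens, a text
# padding token, and task tokens that signal the intended operation.
MODALITIES = ["text", "audio", "image"]
TASKS = ["text", "audio-text", "image-text"]          # assumed task set

def build_unified_vocab(vocab_sizes: dict) -> dict:
    """Map (modality, local_id) pairs and special-token names to global ids."""
    token_to_id, next_id = {}, 0
    for m in MODALITIES:
        for local_id in range(vocab_sizes[m]):        # disjoint union: offset local ids
            token_to_id[(m, local_id)] = next_id
            next_id += 1
        for special in ("BOS", "EOS", "MASK"):
            token_to_id[f"{special}_{m}"] = next_id   # e.g. "MASK_audio"
            next_id += 1
    token_to_id["PAD_text"] = next_id                 # right-padding token
    next_id += 1
    for task in TASKS:
        token_to_id[f"TASK_{task}"] = next_id         # e.g. "TASK_audio-text"
        next_id += 1
    return token_to_id

# Example usage with illustrative (not the paper's) vocabulary sizes:
vocab = build_unified_vocab({"text": 32000, "audio": 4096, "image": 8192})
```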

The training process relies on constructing sequences where modality tokens are wrapped with their respective boundary tokens. For instance, an audio-text sample is formatted as a sequence starting with a task token, followed by the audio segment (bounded by $\text{BOS}_{\text{audio}}$ and $\text{EOS}_{\text{audio}}$), and ending with the text segment (bounded by $\text{BOS}_{\text{text}}$ and $\text{EOS}_{\text{text}}$). To maintain a consistent sequence length $L^\star$ across a minibatch, the authors employ packing for text-only sequences and right-padding with $\text{PAD}_{\text{text}}$ for mixed-modality sequences that are shorter than the target length.

As shown in the accompanying figure, the training data is organized into minibatches in which each sequence follows these formatting and padding rules before being processed by the model.
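
A minimal sketch of this layout for an audio-text sample, reusing the hypothetical `vocab` mapping from the previous snippet and right-padding to a target length $L^\star$.

```python
# Minimal sketch of the sequence layout described above: an audio-text sample is
# [TASK_audio-text, BOS_audio, <audio tokens>, EOS_audio,
#  BOS_text, <text tokens>, EOS_text], right-padded with PAD_text to length L*.
def format_audio_text(audio_ids, text_ids, vocab, target_len):
    seq = [vocab["TASK_audio-text"],
           vocab["BOS_audio"], *[vocab[("audio", i)] for i in audio_ids], vocab["EOS_audio"],
           vocab["BOS_text"],  *[vocab[("text", i)]  for i in text_ids],  vocab["EOS_text"]]
    assert len(seq) <= target_len, "sequence longer than L*"
    return seq + [vocab["PAD_text"]] * (target_len - len(seq))
```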

The core of the method is a continuous-time forward masking process indexed by $t \in [0, 1]$. Each position in the sequence is independently corrupted according to a Bernoulli masking mechanism with probability $\beta_t$, where $\beta$ is a monotonic function. The corrupted token $s_t^i$ is either replaced by a modality-specific mask token $\text{MASK}_{m(i)}$ or remains equal to the previous state $s_{t-1}^i$. This process ensures that once a token is masked, it remains masked, smoothly interpolating from the original sequence at $t = 0$ to a fully masked sequence at $t = 1$.
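
A minimal sketch of this forward corruption at a sampled time $t$, using a simple linear schedule as a stand-in for the monotone $\beta$ studied in the paper; the tensor layout and argument names are assumptions.

```python
# Minimal sketch: at time t, each non-padding position is independently replaced
# by its modality-specific MASK token with probability beta(t), where beta is
# monotone with beta(0) = 0 and beta(1) = 1.
import torch

def beta(t: torch.Tensor) -> torch.Tensor:
    return t  # placeholder linear schedule; the paper studies alternatives

def forward_mask(tokens, mask_token_ids, pad_mask, t):
    """tokens: (B, L) int64; mask_token_ids: (B, L) MASK_{m(i)} per position;
    pad_mask: (B, L) bool, True where padding; t: (B,) times in [0, 1]."""
    p = beta(t)[:, None]                                         # (B, 1) masking probability
    masked = (torch.rand(tokens.shape, device=tokens.device) < p) & ~pad_mask
    return torch.where(masked, mask_token_ids, tokens), masked
```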

To perform the reverse process, the authors utilize a denoising model $f_\theta$ parameterized as a bidirectional transformer. This model predicts logits over the unified vocabulary for each position in the corrupted sequence $s_t$. The training objective is to minimize a per-token loss $\ell_i(\theta, s)$ averaged over the set of masked, non-padding positions $I_t$. To ensure an unbiased estimator of the Evidence Lower Bound (ELBO) under the Bernoulli masking scheme, the authors apply an importance weighting $w(t) = 1/t$. This weighting compensates for the fact that fewer tokens are masked at early time steps, ensuring that every token contributes equally to the loss in expectation across the entire diffusion process.
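
A minimal sketch of the resulting objective, building on the `forward_mask` sketch above; the model interface and the clamping of small $t$ are assumptions, not details from the paper.

```python
# Minimal sketch of the training loss: cross-entropy on masked, non-padding
# positions, importance-weighted by w(t) = 1/t so every token contributes
# equally in expectation over the diffusion process. `model` is assumed to
# return per-position logits over the unified vocabulary.
import torch
import torch.nn.functional as F

def mdm_loss(model, tokens, mask_token_ids, pad_mask):
    B = tokens.shape[0]
    t = torch.rand(B, device=tokens.device).clamp_min(1e-3)      # avoid 1/t blow-up near t = 0
    s_t, masked = forward_mask(tokens, mask_token_ids, pad_mask, t)
    logits = model(s_t)                                          # (B, L, |V|)
    per_tok = F.cross_entropy(logits.transpose(1, 2), tokens, reduction="none")  # (B, L)
    per_tok = per_tok * masked                                   # masked, non-padding positions only
    per_seq = per_tok.sum(dim=1) / masked.sum(dim=1).clamp_min(1)
    return ((1.0 / t) * per_seq).mean()                          # importance weight w(t) = 1/t
```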

Experiment

The experiments evaluate the scaling properties, architectural design, and training dynamics of tri-modal Masked Discrete Diffusion Models (MDM) under SDE reparameterization. Results demonstrate that the critical batch size is independent of model size but scales sub-linearly with the token horizon, and that a polynomial masking schedule provides superior generation quality across modalities. Furthermore, the study establishes robust scaling laws for tri-modal MDMs, revealing that these models become increasingly data-efficient per parameter as they grow, requiring significantly more tokens than traditional autoregressive language models to reach compute optimality.

The authors compare various image tokenizers by evaluating their reconstruction performance on the CC12M and ImageNet datasets. The results demonstrate that different tokenizer types, including continuous, FSQ, IBQ, MoVQ, LFQ, and MCQ, yield varying levels of reconstruction fidelity. MoVQ-based tokenizers achieve high reconstruction performance across both datasets. MCQ tokenizers show strong performance on ImageNet at higher resolutions. Continuous models and certain discrete tokenizers exhibit different levels of effectiveness depending on the dataset and resolution used.

The authors compare standard MDM training against an anti-masking strategy across different modalities. Results indicate that the anti-masking approach improves generation quality for both images and audio. Anti-masking leads to lower FID scores for image generation on both training and CC12M datasets. The anti-masking method improves audio generation quality as measured by FAD on both training and LibriSpeech data. The performance gains from anti-masking are observed across both unimodal text and multimodal settings.

The authors compare the performance of standard MDM training against an anti-masking strategy across various text-based benchmarks. Results indicate that the anti-masking approach generally improves model accuracy across most evaluated tasks. Anti-masking leads to performance gains on reasoning and knowledge benchmarks such as BBH, MMLU, and ARC-Challenge. The strategy shows consistent improvements on linguistic and commonsense tasks including Winogrande and HellaSwag. Most metrics demonstrate higher mean accuracy when using the anti-masking technique compared to the base model.

The authors present the results of a per-module hyperparameter search to optimize AdamW settings for the tri-modal MDM. The findings show that different components of the model, such as embedding weights and various transformer block parameters, require distinct learning rate multipliers, weight decay, and epsilon values to achieve optimal performance. Embedding and unembedding weights benefit from significantly larger effective learning rates compared to other modules. Attention projection and MLP gate weights are tuned more conservatively with specific adjustments to epsilon for numerical stability. The learned depth factors indicate that later blocks in the model benefit from smaller updates and increased stabilization.
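
A minimal sketch of how such per-module settings could be expressed as AdamW parameter groups; the module-name patterns, multipliers, depth factor, and numeric values below are illustrative assumptions, not the paper's tuned results.

```python
# Minimal sketch: per-module AdamW groups where embedding/unembedding weights
# get a larger effective learning rate, attention projections and MLP gates use
# a larger epsilon for stability, and later blocks are scaled by a depth factor.
import torch

def build_param_groups(model, base_lr=3e-4, n_layers=24):
    groups = []
    for name, p in model.named_parameters():
        if "embed" in name or "unembed" in name:
            groups.append({"params": [p], "lr": 4.0 * base_lr, "weight_decay": 0.0})
        elif "attn" in name or "mlp.gate" in name:
            groups.append({"params": [p], "lr": base_lr, "eps": 1e-6, "weight_decay": 0.1})
        else:
            # hypothetical depth factor: smaller updates for later blocks
            layer = int(name.split(".")[1]) if name.split(".")[0] == "blocks" else 0
            depth_factor = 1.0 - 0.5 * layer / max(n_layers - 1, 1)
            groups.append({"params": [p], "lr": depth_factor * base_lr, "weight_decay": 0.1})
    return groups

# optimizer = torch.optim.AdamW(build_param_groups(model), betas=(0.9, 0.95), eps=1e-8)
```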

The authors compare different audio tokenizers based on reconstruction and perceptual metrics. The results show that increasing the number of codebooks generally improves reconstruction performance, though it comes at the cost of a higher token rate. Higher codebook counts lead to better PESQ scores across different pretrained models. The Higgs pretrained tokenizer demonstrates strong performance in content enjoyment and usefulness. The DAC retrained model with fewer codebooks achieves high scores in content usefulness and production quality.

The researchers evaluate various image and audio tokenizers, training strategies, and hyperparameter configurations to optimize tri-modal model performance. The results show that MoVQ and MCQ tokenizers provide strong reconstruction fidelity, while an anti-masking strategy consistently enhances generation quality across image, audio, and text modalities. Furthermore, a per-module hyperparameter search reveals that different model components require specialized optimization settings, and increasing audio codebook counts generally improves perceptual quality.

