Add ACE-Step pipeline for text-to-music generation #13095
base: main
Conversation
## What does this PR do?

This PR adds support for the ACE-Step pipeline, a text-to-music generation model that generates high-quality music with lyrics from text prompts. ACE-Step generates variable-length stereo music at 48kHz from text prompts and optional lyrics.

The implementation includes:

- **AceStepDiTModel**: A Diffusion Transformer (DiT) model that operates in the latent space using flow matching
- **AceStepPipeline**: The main pipeline for text-to-music generation with support for lyrics conditioning
- **AceStepConditionEncoder**: Condition encoder that combines text, lyric, and timbre embeddings
- **Conversion script**: Script to convert ACE-Step checkpoint weights to Diffusers format
- **Comprehensive tests**: Full test coverage for the pipeline and models
- **Documentation**: API documentation for the pipeline and transformer model

## Key Features

- Text-to-music generation with optional lyrics support
- Multi-language lyrics support (English, Chinese, Japanese, Korean, and more)
- Flow matching with custom timestep schedules
- Turbo model variant optimized for 8 inference steps
- Variable-length audio generation (configurable duration)

## Technical Details

ACE-Step comprises three main components:

1. **Oobleck autoencoder (VAE)**: Compresses waveforms into 25Hz latent representations
2. **Qwen3-based text encoder**: Encodes text prompts and lyrics for conditioning
3. **Diffusion Transformer (DiT)**: Operates in the latent space using flow matching

The pipeline supports multiple shift parameters (1.0, 2.0, 3.0) for different timestep schedules, with the turbo model designed for 8 inference steps using `shift=3.0`.

## Testing

All tests pass successfully:

- Model forward pass tests
- Pipeline basic functionality tests
- Batch processing tests
- Latent output tests
- Return dict tests

Run tests with:

```bash
pytest tests/pipelines/ace_step/test_ace_step.py -v
```

## Code Quality

- Code formatted with `make style`
- Quality checks passed with `make quality`
- All tests passing

## References

- Original codebase: [ACE-Step/ACE-Step](https://github.com/ACE-Step/ACE-Step)
- Paper: [ACE-Step: A Step Towards Music Generation Foundation Model](https://github.com/ACE-Step/ACE-Step)
- Add gradient checkpointing test for AceStepDiTModel
- Add save/load config test for AceStepConditionEncoder
- Enhance pipeline tests with PipelineTesterMixin
- Update documentation to reflect ACE-Step 1.5
- Add comprehensive transformer model tests
- Improve test coverage and code quality
- Add support for multiple task types: text2music, repaint, cover, extract, lego, complete
- Add audio normalization and preprocessing utilities
- Add tiled encode/decode for handling long audio sequences
- Add reference audio support for timbre transfer in cover task
- Add repaint functionality for regenerating audio sections
- Add metadata handling (BPM, keyscale, timesignature)
- Add audio code parsing and chunk mask building utilities
- Improve documentation with multi-task usage examples
Hi @ChuxiJ, thanks for the PR! As a preliminary comment, I tried the test script given above but got an error, which I think is due to the fact that the [...]

If I convert the checkpoint locally from a local snapshot of the checkpoint repo with

```bash
python scripts/convert_ace_step_to_diffusers.py \
    --checkpoint_dir /path/to/acestep-v15-repo \
    --dit_config acestep-v15-turbo \
    --output_dir /path/to/acestep-v15-diffusers \
    --dtype bf16
```

and then test it using the following script:

```python
import torch
import soundfile as sf

from diffusers import AceStepPipeline

OUTPUT_SAMPLE_RATE = 48000

model_id = "/path/to/acestep-v15-diffusers"
device = "cuda"
dtype = torch.bfloat16
seed = 42

pipe = AceStepPipeline.from_pretrained(model_id, torch_dtype=dtype)
pipe = pipe.to(device)

generator = torch.Generator(device=device).manual_seed(seed)

# Text-to-music generation
audio = pipe(
    prompt="A beautiful piano piece with soft melodies",
    lyrics="[verse]\nSoft notes in the morning light\n[chorus]\nMusic fills the air tonight",
    audio_duration=30.0,
    num_inference_steps=8,
    bpm=120,
    keyscale="C major",
    generator=generator,
).audios

sf.write("acestep_t2m.wav", audio[0, 0].cpu().numpy(), OUTPUT_SAMPLE_RATE)
```

I get the following sample: [generated audio sample]

The sample quality is lower than expected, so there is probably a bug. Could you look into it?
| def _pack_sequences( |

As `_pack_sequences` is not used in the DiT code but is used in the condition encoder code, could it be moved to `modeling_ace_step.py`?
| class AceStepRMSNorm(nn.Module): |

Would it be possible to use `diffusers.models.normalization.RMSNorm` in place of `AceStepRMSNorm`?

| class RMSNorm(nn.Module): |

I believe the implementations are essentially the same (including the FP32 upcasting).
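A minimal sketch of the swap (the size and epsilon values below are placeholders, not the actual ACE-Step config):

```python
import torch

from diffusers.models.normalization import RMSNorm

hidden_size, rms_norm_eps = 2048, 1e-6  # placeholder values
norm = RMSNorm(hidden_size, eps=rms_norm_eps, elementwise_affine=True)

hidden_states = torch.randn(1, 16, hidden_size)
out = norm(hidden_states)  # same shape; the variance is computed in fp32, as in AceStepRMSNorm
```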
| class AceStepRotaryEmbedding(nn.Module): |

Would it be possible to use `get_1d_rotary_pos_embed` in place of `AceStepRotaryEmbedding`?

diffusers/src/diffusers/models/embeddings.py, line 1120 in 20efb79:

| def get_1d_rotary_pos_embed( |

I believe something like

```python
position_embeddings = get_1d_rotary_pos_embed(
    self.config.head_dim,
    position_ids,
    theta=self.config.rope_theta,
    use_real=True,
    freqs_dtype=torch.float32,
)
```

should be equivalent.
| class AceStepRotaryEmbedding(nn.Module): |
| """Rotary Position Embedding (RoPE) for ACE-Step attention layers.""" |
| def __init__(self, dim: int, max_position_embeddings: int = 32768, base: float = 1000000.0): |

Is `max_position_embeddings` used anywhere? If not, could it be removed?
| class AceStepTimestepEmbedding(nn.Module): |

I think the logic here is already implemented in `Timesteps` for the sinusoidal embedding:

diffusers/src/diffusers/models/embeddings.py, line 1310 in 20efb79:

| class Timesteps(nn.Module): |

and in `TimestepEmbedding` for the MLP:

diffusers/src/diffusers/models/embeddings.py, line 1262 in 20efb79:

| class TimestepEmbedding(nn.Module): |

Could we refactor `AceStepTimestepEmbedding` into `Timesteps` + `TimestepEmbedding` + a custom AdaLayerNormZero implementation (e.g. `AceStepAdaLayerNormZero`)? (I believe none of the existing AdaLN implementations match the one used here.)
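A rough sketch of the suggested composition (the class name, dimensions, and the 6-way modulation projection are assumptions; the real module would have to reproduce `AceStepTimestepEmbedding`'s `temb`/`timestep_proj` outputs exactly):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

from diffusers.models.embeddings import TimestepEmbedding, Timesteps


class AceStepTimestepEmbeddingSketch(nn.Module):
    """Hypothetical recomposition: sinusoidal Timesteps -> TimestepEmbedding MLP -> modulation projection."""

    def __init__(self, embedding_dim: int, inner_dim: int, num_freq_channels: int = 256):
        super().__init__()
        self.time_proj = Timesteps(num_freq_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = TimestepEmbedding(in_channels=num_freq_channels, time_embed_dim=embedding_dim)
        # A custom AceStepAdaLayerNormZero would own a projection like this one
        # (shift/scale/gate chunks); the 6x factor here is only an assumption.
        self.proj = nn.Linear(embedding_dim, 6 * inner_dim)

    def forward(self, timestep: torch.Tensor):
        temb = self.timestep_embedder(self.time_proj(timestep))
        timestep_proj = self.proj(F.silu(temb))
        return temb, timestep_proj


temb, timestep_proj = AceStepTimestepEmbeddingSketch(512, 512)(torch.tensor([0.2, 0.8]))
```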
| class AceStepAttention(nn.Module): |

Could we refactor `AceStepAttention` into an `Attention` + `AttnProcessor` design? For example, Flux 2 implements a `Flux2Attention` class which holds the attention state (e.g. Q, K, V projections):

| class Flux2Attention(torch.nn.Module, AttentionModuleMixin): |

and a `Flux2AttnProcessor` class which defines the attention logic:

| class Flux2AttnProcessor: |

This makes it easier to support attention backends such as Flash Attention and operations like fusing/unfusing QKV projections.
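A minimal skeleton of that split (all names here are hypothetical sketches; GQA, RoPE, sliding-window masking, and the `AttentionModuleMixin` plumbing are omitted):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class AceStepAttnProcessorSketch:
    """Attention *logic* lives in the processor (this sketch just wraps SDPA)."""

    def __call__(self, attn: "AceStepAttentionSketch", hidden_states, attention_mask=None):
        batch_size, seq_len, _ = hidden_states.shape
        query = attn.to_q(hidden_states).view(batch_size, seq_len, attn.heads, -1).transpose(1, 2)
        key = attn.to_k(hidden_states).view(batch_size, seq_len, attn.heads, -1).transpose(1, 2)
        value = attn.to_v(hidden_states).view(batch_size, seq_len, attn.heads, -1).transpose(1, 2)
        # RoPE, GQA (repeat_interleave of K/V), and sliding-window masks from the
        # existing AceStepAttention code would be applied here.
        out = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask)
        out = out.transpose(1, 2).reshape(batch_size, seq_len, -1)
        return attn.to_out(out)


class AceStepAttentionSketch(nn.Module):
    """Attention *state* (projections, head config) lives in the module."""

    def __init__(self, dim: int, heads: int, processor=None):
        super().__init__()
        self.heads = heads
        self.to_q = nn.Linear(dim, dim)
        self.to_k = nn.Linear(dim, dim)
        self.to_v = nn.Linear(dim, dim)
        self.to_out = nn.Linear(dim, dim)
        self.processor = processor if processor is not None else AceStepAttnProcessorSketch()

    def forward(self, hidden_states, **kwargs):
        return self.processor(self, hidden_states, **kwargs)


out = AceStepAttentionSketch(dim=64, heads=4)(torch.randn(1, 10, 64))
```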
| key_states = key_states.repeat_interleave(self.num_key_value_groups, dim=-3) |
| value_states = value_states.repeat_interleave(self.num_key_value_groups, dim=-3) |
| attn_output = F.scaled_dot_product_attention( |

With the refactoring suggested in #13095 (comment), you can support attention backends such as Flash Attention by using the `dispatch_attention_fn` function:

| def dispatch_attention_fn( |

A usage example from Flux 2 is as follows:

diffusers/src/diffusers/models/transformers/transformer_flux2.py, lines 158 to 165 in 20efb79:

```python
hidden_states = dispatch_attention_fn(
    query,
    key,
    value,
    attn_mask=attention_mask,
    backend=self._attention_backend,
    parallel_config=self._parallel_config,
)
```

You can look at the attention backend docs for more info.
| class AceStepEncoderLayer(nn.Module): |

Since `AceStepEncoderLayer` isn't used by the DiT model, can it be moved to `modeling_ace_step.py`?
| class AceStepDiTLayer(nn.Module): |

nit: rename `AceStepDiTLayer` to `AceStepTransformerBlock`, following the usual diffusers naming convention.
| class AceStepDiTModel(ModelMixin, ConfigMixin): |

nit: rename `AceStepDiTModel` to `AceStepTransformer1DModel`, following the usual diffusers naming convention.
| class AceStepDiTModel(ModelMixin, ConfigMixin): |

Suggestion: having `AceStepDiTModel` inherit from `AttentionMixin` will help support model-wide attention operations like fusing QKV projections.

In addition, if the ACE-Step model is compatible with caching techniques like MagCache, you can also consider inheriting from `CacheMixin`:

| class CacheMixin: |
| attention_bias: bool = False, |
| attention_dropout: float = 0.0, |
| rms_norm_eps: float = 1e-6, |
| use_sliding_window: bool = True, |

Is there ever a case where we want `use_sliding_window=False`? If not, perhaps we can remove this argument?
| if torch.is_grad_enabled() and self.gradient_checkpointing: |
| hidden_states = self._gradient_checkpointing_func( |
| layer_module.__call__, |

Suggested change:

```diff
-                layer_module.__call__,
+                layer_module,
```

nit: the usual idiom here is to supply just `layer_module`, such as in Flux 2:

diffusers/src/diffusers/models/transformers/transformer_flux2.py, lines 863 to 872 in 20efb79:

```python
if torch.is_grad_enabled() and self.gradient_checkpointing:
    encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
        block,
        hidden_states,
        encoder_hidden_states,
        double_stream_mod_img,
        double_stream_mod_txt,
        concat_rotary_emb,
        joint_attention_kwargs,
    )
```
| logger = logging.get_logger(__name__)  # pylint: disable=invalid-name |
| class AceStepLyricEncoder(ModelMixin, ConfigMixin): |

If I understand correctly, the original code supports gradient checkpointing for `AceStepLyricEncoder` and `AceStepTimbreEncoder`, so I think we can support it here as well, in the same manner as `AceStepDiTModel`.
| TASK_TYPES = ["text2music", "repaint", "cover", "extract", "lego", "complete"] |
| # Sample rate used by ACE-Step |
| SAMPLE_RATE = 48000 |

Instead of hardcoding the sample rate, could we read it from the VAE config (`self.vae.config.sampling_rate`)?
| latents = latents.squeeze(0) |
| return latents |
| def _tiled_encode( |

Would it be possible to move the VAE tiled encoding/decoding logic to the VAE (`AutoencoderOobleck`)? Ideally, the VAE defines the tiling logic and then users can enable it if desired with `pipe.vae.enable_tiling()`.
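A rough, simplified sketch of time-axis tiled decoding on the VAE side (the chunk/overlap handling here is hypothetical; the pipeline's existing `_tiled_encode`/`_tiled_decode` logic would be the real reference):

```python
import torch


def tiled_decode_sketch(latents: torch.Tensor, decode_fn, tile_length: int = 512, overlap: int = 64) -> torch.Tensor:
    """Decode long latents chunk by chunk along the time axis and stitch the results.

    `decode_fn` stands in for the VAE decoder (e.g. `vae.decode(...).sample`); the tile and
    overlap sizes are placeholders, not the values used by ACE-Step.
    """
    _, _, length = latents.shape  # (batch, channels, time)
    stride = tile_length - overlap
    chunks = []
    for start in range(0, length, stride):
        chunk = latents[:, :, start : start + tile_length]
        decoded = decode_fn(chunk)
        if start > 0:
            # Drop the overlapping region already produced by the previous tile.
            upsample = decoded.shape[-1] // chunk.shape[-1]
            decoded = decoded[:, :, overlap * upsample :]
        chunks.append(decoded)
        if start + tile_length >= length:
            break
    return torch.cat(chunks, dim=-1)


# Toy usage with a stand-in decoder that upsamples 4x along the time axis:
decoded = tiled_decode_sketch(torch.randn(1, 64, 1000), lambda z: z.repeat_interleave(4, dim=-1))
```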
| model_cpu_offload_seq = "text_encoder->condition_encoder->transformer->vae" |
| def __init__( |

I think including the audio tokenizer and detokenizer (I believe `AceStepAudioTokenizer` and `AudioTokenDetokenizer` in the original code) as registered modules would make the pipeline easier to use and more self-contained. If I understand correctly, this would allow users to supply raw audio waveforms instead of `audio_codes` to `__call__`, and the user would not have to manually tokenize the audio into an audio code string first.
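A hedged sketch of what that registration could look like (the `__init__` signature and the `audio_tokenizer`/`audio_detokenizer` argument names are assumptions):

```python
from diffusers import DiffusionPipeline


class AceStepPipelineSketch(DiffusionPipeline):
    # Only the registration call is sketched here; the real __init__ signature is an assumption.
    def __init__(self, vae, text_encoder, tokenizer, condition_encoder, transformer,
                 audio_tokenizer=None, audio_detokenizer=None):
        super().__init__()
        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            condition_encoder=condition_encoder,
            transformer=transformer,
            audio_tokenizer=audio_tokenizer,
            audio_detokenizer=audio_detokenizer,
        )
```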
| # 0. Default values and input validation |

Arguments to `__call__` should be validated in a `check_inputs` method. Here is an example from the Flux 2 pipeline:

| def check_inputs( |
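A minimal sketch of such a method for this pipeline (the specific checks are illustrative, not taken from the PR):

```python
# Sketch of a check_inputs method on AceStepPipeline; the individual checks are examples only.
def check_inputs(self, prompt, audio_duration, num_inference_steps, callback_steps=None):
    if prompt is not None and not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
    if audio_duration is not None and audio_duration <= 0:
        raise ValueError(f"`audio_duration` has to be positive but is {audio_duration}")
    if num_inference_steps is not None and num_inference_steps <= 0:
        raise ValueError(f"`num_inference_steps` has to be positive but is {num_inference_steps}")
    if callback_steps is not None and callback_steps <= 0:
        raise ValueError(f"`callback_steps` has to be a positive integer but is {callback_steps}")
```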
| # 2. Prepare source latents and latent length |
| latent_length = int(audio_duration * 25) |

Would it be possible to calculate the latents per second instead of hardcoding it to be 25? I believe this can be done from the VAE config values as

```python
latents_per_second = self.vae.config.sampling_rate / math.prod(self.vae.config.downsampling_ratios)
```

| latent_length = src_latent_length |
| else: |
| src_latents = torch.zeros(batch_size, latent_length, acoustic_dim, device=device, dtype=dtype) |

Could the `src_audio` / `audio_codes` to `src_latents` logic be refactored into a method (for example `encode_audio`)? I think this would make the code more readable.
| # 8. Prepare null condition for CFG (if guidance_scale > 1) |
| do_cfg = guidance_scale > 1.0 |

Could this be refactored into a `do_classifier_free_guidance` property? See for example `LTX2Pipeline`:

diffusers/src/diffusers/pipelines/ltx2/pipeline_ltx2.py, lines 760 to 762 in 20efb79:

```python
@property
def do_classifier_free_guidance(self):
    return self._guidance_scale > 1.0
```
| model_cpu_offload_seq = "text_encoder->condition_encoder->transformer->vae" |
| def __init__( |

We should use a scheduler such as `FlowMatchEulerDiscreteScheduler` to handle the timestep schedule and sampling logic.
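For reference, a hedged sketch of constructing such a scheduler (`shift=3.0` mirrors the turbo setting described in this PR; whether ACE-Step's schedule maps exactly onto this scheduler still needs to be verified):

```python
from diffusers import FlowMatchEulerDiscreteScheduler

# shift=3.0 matches the turbo setting (8 steps) described in the PR; the base models
# reportedly also support shift=1.0 and shift=2.0.
scheduler = FlowMatchEulerDiscreteScheduler(shift=3.0)
scheduler.set_timesteps(num_inference_steps=8)
print(scheduler.timesteps)
```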
| # 9. Get timestep schedule |
| t_schedule = self._get_timestep_schedule( |

See #13095 (comment). In particular, `self.scheduler.set_timesteps` should be called here. If we want to use a custom sigma schedule, some schedulers (such as `FlowMatchEulerDiscreteScheduler`) accept it through a `sigmas` argument to `set_timesteps`.
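A sketch of what that could look like (the sigma values below are placeholders, not ACE-Step's actual shift-based schedule):

```python
import numpy as np

from diffusers import FlowMatchEulerDiscreteScheduler

scheduler = FlowMatchEulerDiscreteScheduler(shift=3.0)

# Placeholder schedule: the pipeline's custom sigma schedule would be computed here
# and handed to the scheduler instead of being applied manually in the loop.
num_inference_steps = 8
sigmas = np.linspace(1.0, 1.0 / num_inference_steps, num_inference_steps).tolist()
scheduler.set_timesteps(sigmas=sigmas)
timesteps = scheduler.timesteps
```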
| timestep_ratio = 1.0 - current_timestep  # t=1 -> ratio=0, t=0 -> ratio=1 |
| apply_cfg = do_cfg and (cfg_interval_start <= timestep_ratio <= cfg_interval_end) |
| if apply_cfg: |

Would refactoring the CFG calculation to be batched make sense here? See for example `LTX2Pipeline`:

diffusers/src/diffusers/pipelines/ltx2/pipeline_ltx2.py, lines 1103 to 1104 in 20efb79:

```python
with self.transformer.cache_context("cond_uncond"):
    noise_pred_video, noise_pred_audio = self.transformer(
```
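A generic sketch of the batched CFG pattern (the transformer call signature and tensor shapes here are illustrative, not the actual ACE-Step interfaces):

```python
import torch


def batched_cfg_step(transformer, latents, cond_embeds, uncond_embeds, timestep, guidance_scale):
    # Run the conditional and unconditional branches as one batch, then split
    # the prediction and apply the usual CFG combination.
    latent_input = torch.cat([latents, latents], dim=0)
    context = torch.cat([cond_embeds, uncond_embeds], dim=0)
    t = torch.cat([timestep, timestep], dim=0)
    noise_pred = transformer(latent_input, context, t)
    noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2, dim=0)
    return noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)


# Toy usage with a stand-in "transformer" that just echoes its latent input:
pred = batched_cfg_step(
    lambda x, ctx, t: x,
    torch.randn(1, 750, 64), torch.randn(1, 77, 512), torch.zeros(1, 77, 512),
    torch.tensor([0.5]), guidance_scale=4.5,
)
```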
| # Euler ODE step: x_{t-1} = x_t - v_t * dt |
| next_timestep = t_schedule[step_idx + 1].item() |
| dt = current_timestep - next_timestep |
| dt_tensor = dt * torch.ones((batch_size,), device=device, dtype=dtype).unsqueeze(-1).unsqueeze(-1) |
| xt = xt - vt * dt_tensor |

See #13095 (comment). In particular, the scheduler's `step` method should be called here instead to get the next latent in the denoising loop.
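A toy sketch of the loop body using the scheduler's `step` method (shapes and the zero "prediction" are placeholders):

```python
import torch

from diffusers import FlowMatchEulerDiscreteScheduler

scheduler = FlowMatchEulerDiscreteScheduler(shift=3.0)
scheduler.set_timesteps(num_inference_steps=8)

latents = torch.randn(1, 750, 64)  # toy shape, not the actual ACE-Step latent layout
for t in scheduler.timesteps:
    noise_pred = torch.zeros_like(latents)  # stand-in for the transformer's velocity prediction
    # The scheduler replaces the manual Euler update x_{t-1} = x_t - v_t * dt.
    latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0]
```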
| if callback is not None and step_idx % callback_steps == 0: |
| callback(step_idx, t_curr_tensor, xt) |

Could we support `callback_on_step_end` here? See for example `Flux2Pipeline`:

diffusers/src/diffusers/pipelines/flux2/pipeline_flux2.py, lines 993 to 1004 in 20efb79:

```python
if callback_on_step_end is not None:
    callback_kwargs = {}
    for k in callback_on_step_end_tensor_inputs:
        callback_kwargs[k] = locals()[k]
    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
    latents = callback_outputs.pop("latents", latents)
    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)

# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
    progress_bar.update()
```
| if __name__ == "__main__": |
| parser = argparse.ArgumentParser(description="Convert ACE-Step model weights to Diffusers pipeline format") |
| parser.add_argument( |
| "--checkpoint_dir", |

Could the conversion script support supplying a Hugging Face Hub repo id (such as ACE-Step/Ace-Step1.5) for `--checkpoint_dir`? I think this would make it easier to use (since you don't need to download a local copy of the repo first).
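One hedged way to support that in the script is to fall back to `huggingface_hub.snapshot_download` when the argument is not a local directory, e.g.:

```python
import os

from huggingface_hub import snapshot_download


def resolve_checkpoint_dir(checkpoint_dir: str) -> str:
    # If the argument is not an existing local path, treat it as a Hub repo id
    # (e.g. "ACE-Step/Ace-Step1.5") and download a snapshot of it.
    if os.path.isdir(checkpoint_dir):
        return checkpoint_dir
    return snapshot_download(repo_id=checkpoint_dir)
```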
dg845 left a comment

Hi @ChuxiJ, thanks for the PR! I have left an initial design review. Happy to help with anything :).
What does this PR do?
This PR adds the ACE-Step 1.5 pipeline to Diffusers — a text-to-music generation model that produces high-quality stereo music with lyrics at 48kHz from text prompts.
New Components
- **AceStepDiTModel** (`src/diffusers/models/transformers/ace_step_transformer.py`): A Diffusion Transformer (DiT) model with RoPE, GQA, sliding window attention, and flow matching for denoising audio latents. Includes custom components: `AceStepRMSNorm`, `AceStepRotaryEmbedding`, `AceStepMLP`, `AceStepTimestepEmbedding`, `AceStepAttention`, `AceStepEncoderLayer`, and `AceStepDiTLayer`.
- **AceStepConditionEncoder** (`src/diffusers/pipelines/ace_step/modeling_ace_step.py`): Condition encoder that fuses text, lyric, and timbre embeddings into a unified cross-attention conditioning signal. Includes `AceStepLyricEncoder` and `AceStepTimbreEncoder` sub-modules.
- **AceStepPipeline** (`src/diffusers/pipelines/ace_step/pipeline_ace_step.py`): The main pipeline supporting 6 task types:
  - `text2music` — generate music from text and lyrics
  - `cover` — generate from audio semantic codes or with timbre transfer via reference audio
  - `repaint` — regenerate a time region within existing audio
  - `extract` — extract a specific track (vocals, drums, etc.) from audio
  - `lego` — generate a specific track given audio context
  - `complete` — complete audio with additional tracks
- **Conversion script** (`scripts/convert_ace_step_to_diffusers.py`): Converts original ACE-Step 1.5 checkpoint weights to Diffusers format.

Key Features
- Task instructions via `_get_task_instruction`
- `bpm`, `keyscale`, `timesignature` parameters formatted into the SFT prompt template
- Source audio (`src_audio`) and reference audio (`reference_audio`) inputs with VAE encoding
- Tiled encoding (`_tiled_encode`) and decoding (`_tiled_decode`) for long audio
- CFG with `guidance_scale`, `cfg_interval_start`, and `cfg_interval_end` (primarily for base/SFT models; turbo models have guidance distilled into weights)
- `audio_cover_strength` control for cover tasks
- `_parse_audio_code_string` extracts semantic codes from `<|audio_code_N|>` tokens for cover tasks
- `_build_chunk_mask` creates time-region masks for repaint/lego tasks
- Custom `timesteps` support; the turbo model is designed for 8 inference steps with `shift=3.0`

Architecture
ACE-Step 1.5 comprises three main components:

- Oobleck autoencoder (VAE): compresses waveforms into 25Hz latent representations
- Qwen3-based text encoder: encodes text prompts and lyrics for conditioning
- Diffusion Transformer (DiT): operates in the latent space using flow matching
Tests
Pipeline tests (`tests/pipelines/ace_step/test_ace_step.py`):

- `AceStepDiTModelTests` — forward shape, return dict, gradient checkpointing
- `AceStepConditionEncoderTests` — forward shape, save/load config
- `AceStepPipelineFastTests` (extends `PipelineTesterMixin`) — 39 tests covering basic generation, batch processing, latent output, save/load, float16 inference, CPU/model offloading, encode_prompt, prepare_latents, timestep_schedule, format_prompt, and more

Transformer model tests (`tests/models/transformers/test_models_transformer_ace_step.py`):

- `TestAceStepDiTModel` (extends `ModelTesterMixin`) — forward pass, dtype inference, save/load, determinism
- `TestAceStepDiTModelMemory` (extends `MemoryTesterMixin`) — layerwise casting, group offloading
- `TestAceStepDiTModelTraining` (extends `TrainingTesterMixin`) — training, EMA, gradient checkpointing, mixed precision

All 70 tests pass (39 pipeline + 31 model).
Documentation
- `docs/source/en/api/pipelines/ace_step.md` — Pipeline API documentation with usage examples
- `docs/source/en/api/models/ace_step_transformer.md` — Transformer model documentation

Usage
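A minimal text-to-music example, adapted from the test script earlier in this thread (the checkpoint path is a placeholder):

```python
import soundfile as sf
import torch

from diffusers import AceStepPipeline

pipe = AceStepPipeline.from_pretrained("/path/to/acestep-v15-diffusers", torch_dtype=torch.bfloat16).to("cuda")

audio = pipe(
    prompt="A beautiful piano piece with soft melodies",
    lyrics="[verse]\nSoft notes in the morning light\n[chorus]\nMusic fills the air tonight",
    audio_duration=30.0,
    num_inference_steps=8,
    bpm=120,
    keyscale="C major",
    generator=torch.Generator(device="cuda").manual_seed(42),
).audios

sf.write("acestep_t2m.wav", audio[0, 0].cpu().numpy(), 48000)
```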
Before submitting
Who can review?
References