diff --git a/acestep/pipeline_ace_step.py b/acestep/pipeline_ace_step.py index a948e62..0059a28 100644 --- a/acestep/pipeline_ace_step.py +++ b/acestep/pipeline_ace_step.py @@ -17,7 +17,7 @@ from loguru import logger from tqdm import tqdm import json import math -from huggingface_hub import hf_hub_download +from huggingface_hub import hf_hub_download, snapshot_download # from diffusers.pipelines.pipeline_utils import DiffusionPipeline from acestep.schedulers.scheduling_flow_match_euler_discrete import ( @@ -44,7 +44,6 @@ from acestep.apg_guidance import ( cfg_double_condition_forward, ) import torchaudio -from .cpu_offload import cpu_offload torch.backends.cudnn.benchmark = False @@ -97,7 +96,6 @@ class ACEStepPipeline: text_encoder_checkpoint_path=None, persistent_storage_path=None, torch_compile=False, - cpu_offload=False, **kwargs, ): if not checkpoint_dir: @@ -124,149 +122,17 @@ class ACEStepPipeline: self.device = device self.loaded = False self.torch_compile = torch_compile - self.cpu_offload = cpu_offload def load_checkpoint(self, checkpoint_dir=None): device = self.device - - dcae_model_path = os.path.join(checkpoint_dir, "music_dcae_f8c8") - vocoder_model_path = os.path.join(checkpoint_dir, "music_vocoder") - ace_step_model_path = os.path.join(checkpoint_dir, "ace_step_transformer") - text_encoder_model_path = os.path.join(checkpoint_dir, "umt5-base") - - files_exist = ( - os.path.exists(os.path.join(dcae_model_path, "config.json")) - and os.path.exists( - os.path.join(dcae_model_path, "diffusion_pytorch_model.safetensors") - ) - and os.path.exists(os.path.join(vocoder_model_path, "config.json")) - and os.path.exists( - os.path.join(vocoder_model_path, "diffusion_pytorch_model.safetensors") - ) - and os.path.exists(os.path.join(ace_step_model_path, "config.json")) - and os.path.exists( - os.path.join(ace_step_model_path, "diffusion_pytorch_model.safetensors") - ) - and os.path.exists(os.path.join(text_encoder_model_path, "config.json")) - and os.path.exists( - os.path.join(text_encoder_model_path, "model.safetensors") - ) - and os.path.exists( - os.path.join(text_encoder_model_path, "special_tokens_map.json") - ) - ) - - if not files_exist: - logger.info( - f"Checkpoint directory {checkpoint_dir} is not complete, downloading from Hugging Face Hub" - ) - - # download music dcae model - hf_hub_download( - repo_id=REPO_ID, - subfolder="music_dcae_f8c8", - filename="config.json", - local_dir=checkpoint_dir, - local_dir_use_symlinks=False, - ) - hf_hub_download( - repo_id=REPO_ID, - subfolder="music_dcae_f8c8", - filename="diffusion_pytorch_model.safetensors", - local_dir=checkpoint_dir, - local_dir_use_symlinks=False, - ) - - # download vocoder model - hf_hub_download( - repo_id=REPO_ID, - subfolder="music_vocoder", - filename="config.json", - local_dir=checkpoint_dir, - local_dir_use_symlinks=False, - ) - hf_hub_download( - repo_id=REPO_ID, - subfolder="music_vocoder", - filename="diffusion_pytorch_model.safetensors", - local_dir=checkpoint_dir, - local_dir_use_symlinks=False, - ) - - # download ace_step transformer model - hf_hub_download( - repo_id=REPO_ID, - subfolder="ace_step_transformer", - filename="config.json", - local_dir=checkpoint_dir, - local_dir_use_symlinks=False, - ) - hf_hub_download( - repo_id=REPO_ID, - subfolder="ace_step_transformer", - filename="diffusion_pytorch_model.safetensors", - local_dir=checkpoint_dir, - local_dir_use_symlinks=False, - ) - - # download text encoder model - # os.makedirs(text_encoder_model_path, exist_ok=True) # hf_hub_download should create subdirectories - hf_hub_download( - repo_id=REPO_ID, - subfolder="umt5-base", - filename="config.json", - local_dir=checkpoint_dir, - local_dir_use_symlinks=False, - ) - hf_hub_download( - repo_id=REPO_ID, - subfolder="umt5-base", - filename="model.safetensors", - local_dir=checkpoint_dir, - local_dir_use_symlinks=False, - ) - hf_hub_download( - repo_id=REPO_ID, - subfolder="umt5-base", - filename="special_tokens_map.json", - local_dir=checkpoint_dir, - local_dir_use_symlinks=False, - ) - hf_hub_download( - repo_id=REPO_ID, - subfolder="umt5-base", - filename="tokenizer_config.json", - local_dir=checkpoint_dir, - local_dir_use_symlinks=False, - ) - hf_hub_download( - repo_id=REPO_ID, - subfolder="umt5-base", - filename="tokenizer.json", - local_dir=checkpoint_dir, - local_dir_use_symlinks=False, - ) - - # Verify files were downloaded correctly - if not all([ - os.path.exists(os.path.join(dcae_model_path, "config.json")), - os.path.exists(os.path.join(dcae_model_path, "diffusion_pytorch_model.safetensors")), - os.path.exists(os.path.join(vocoder_model_path, "config.json")), - os.path.exists(os.path.join(vocoder_model_path, "diffusion_pytorch_model.safetensors")), - os.path.exists(os.path.join(ace_step_model_path, "config.json")), - os.path.exists(os.path.join(ace_step_model_path, "diffusion_pytorch_model.safetensors")), - os.path.exists(os.path.join(text_encoder_model_path, "config.json")), - os.path.exists(os.path.join(text_encoder_model_path, "model.safetensors")), - os.path.exists(os.path.join(text_encoder_model_path, "special_tokens_map.json")), - ]): - logger.error("Failed to download all required model files. Please check your internet connection and try again.") - logger.info(f"DCAE model path: {dcae_model_path}, files exist: {os.path.exists(os.path.join(dcae_model_path, 'config.json'))}") - logger.info(f"Vocoder model path: {vocoder_model_path}, files exist: {os.path.exists(os.path.join(vocoder_model_path, 'config.json'))}") - logger.info(f"ACE-Step model path: {ace_step_model_path}, files exist: {os.path.exists(os.path.join(ace_step_model_path, 'config.json'))}") - logger.info(f"Text encoder model path: {text_encoder_model_path}, files exist: {os.path.exists(os.path.join(text_encoder_model_path, 'config.json'))}") - raise RuntimeError("Model download failed. See logs for details.") - - logger.info("Models downloaded successfully") + if checkpoint_dir is None: + checkpoint_dir_models = snapshot_download(REPO_ID) + else: + checkpoint_dir_models = snapshot_download(REPO_ID, cache_dir=checkpoint_dir) + dcae_model_path = os.path.join(checkpoint_dir_models, "music_dcae_f8c8") + vocoder_model_path = os.path.join(checkpoint_dir_models, "music_vocoder") + ace_step_model_path = os.path.join(checkpoint_dir_models, "ace_step_transformer") + text_encoder_model_path = os.path.join(checkpoint_dir_models, "umt5-base") dcae_checkpoint_path = dcae_model_path vocoder_checkpoint_path = vocoder_model_path @@ -277,20 +143,12 @@ class ACEStepPipeline: dcae_checkpoint_path=dcae_checkpoint_path, vocoder_checkpoint_path=vocoder_checkpoint_path, ) - # self.music_dcae.to(device).eval().to(self.dtype) - if self.cpu_offload: # might be redundant - self.music_dcae = self.music_dcae.to("cpu").eval().to(self.dtype) - else: - self.music_dcae = self.music_dcae.to(device).eval().to(self.dtype) + self.music_dcae.to(device).eval().to(self.dtype) self.ace_step_transformer = ACEStepTransformer2DModel.from_pretrained( ace_step_checkpoint_path, torch_dtype=self.dtype ) - # self.ace_step_transformer.to(device).eval().to(self.dtype) - if self.cpu_offload: - self.ace_step_transformer = self.ace_step_transformer.to("cpu").eval().to(self.dtype) - else: - self.ace_step_transformer = self.ace_step_transformer.to(device).eval().to(self.dtype) + self.ace_step_transformer.to(device).eval().to(self.dtype) lang_segment = LangSegment() @@ -400,11 +258,7 @@ class ACEStepPipeline: text_encoder_model = UMT5EncoderModel.from_pretrained( text_encoder_checkpoint_path, torch_dtype=self.dtype ).eval() - # text_encoder_model = text_encoder_model.to(device).to(self.dtype) - if self.cpu_offload: - text_encoder_model = text_encoder_model.to("cpu").eval().to(self.dtype) - else: - text_encoder_model = text_encoder_model.to(device).eval().to(self.dtype) + text_encoder_model = text_encoder_model.to(device).to(self.dtype) text_encoder_model.requires_grad_(False) self.text_encoder_model = text_encoder_model self.text_tokenizer = AutoTokenizer.from_pretrained( @@ -418,7 +272,6 @@ class ACEStepPipeline: self.ace_step_transformer = torch.compile(self.ace_step_transformer) self.text_encoder_model = torch.compile(self.text_encoder_model) - @cpu_offload("text_encoder_model") def get_text_embeddings(self, texts, device, text_max_length=256): inputs = self.text_tokenizer( texts, @@ -436,7 +289,6 @@ class ACEStepPipeline: attention_mask = inputs["attention_mask"] return last_hidden_states, attention_mask - @cpu_offload("text_encoder_model") def get_text_embeddings_null( self, texts, device, text_max_length=256, tau=0.01, l_min=8, l_max=10 ): @@ -479,37 +331,28 @@ class ACEStepPipeline: return last_hidden_states def set_seeds(self, batch_size, manual_seeds=None): - processed_input_seeds = None + seeds = None if manual_seeds is not None: if isinstance(manual_seeds, str): if "," in manual_seeds: - processed_input_seeds = list(map(int, manual_seeds.split(","))) + seeds = list(map(int, manual_seeds.split(","))) elif manual_seeds.isdigit(): - processed_input_seeds = int(manual_seeds) - elif isinstance(manual_seeds, list) and all(isinstance(s, int) for s in manual_seeds): - if len(manual_seeds) > 0: - processed_input_seeds = list(manual_seeds) - elif isinstance(manual_seeds, int): - processed_input_seeds = manual_seeds + seeds = int(manual_seeds) + random_generators = [ torch.Generator(device=self.device) for _ in range(batch_size) ] actual_seeds = [] for i in range(batch_size): - current_seed_for_generator = None - if processed_input_seeds is None: - current_seed_for_generator = torch.randint(0, 2**32, (1,)).item() - elif isinstance(processed_input_seeds, int): - current_seed_for_generator = processed_input_seeds - elif isinstance(processed_input_seeds, list): - if i < len(processed_input_seeds): - current_seed_for_generator = processed_input_seeds[i] - else: - current_seed_for_generator = processed_input_seeds[-1] - if current_seed_for_generator is None: - current_seed_for_generator = torch.randint(0, 2**32, (1,)).item() - random_generators[i].manual_seed(current_seed_for_generator) - actual_seeds.append(current_seed_for_generator) + seed = None + if seeds is None: + seed = torch.randint(0, 2**32, (1,)).item() + if isinstance(seeds, int): + seed = seeds + if isinstance(seeds, list): + seed = seeds[i] + random_generators[i].manual_seed(seed) + actual_seeds.append(seed) return random_generators, actual_seeds def get_lang(self, text): @@ -557,7 +400,6 @@ class ACEStepPipeline: print("tokenize error", e, "for line", line, "major_language", lang) return lyric_token_idx - @cpu_offload("ace_step_transformer") def calc_v( self, zt_src, @@ -846,7 +688,6 @@ class ACEStepPipeline: target_latents = zt_edit if xt_tar is None else xt_tar return target_latents - @cpu_offload("ace_step_transformer") @torch.no_grad() def text2music_diffusion_process( self, @@ -1363,7 +1204,6 @@ class ACEStepPipeline: ) return target_latents - @cpu_offload("music_dcae") def latents2audio( self, latents, @@ -1371,20 +1211,29 @@ class ACEStepPipeline: sample_rate=48000, save_path=None, format="wav", + do_save=True, ): - output_audio_paths = [] - bs = latents.shape[0] - audio_lengths = [target_wav_duration_second * sample_rate] * bs - pred_latents = latents - with torch.no_grad(): - _, pred_wavs = self.music_dcae.decode(pred_latents, sr=sample_rate) - pred_wavs = [pred_wav.cpu().float() for pred_wav in pred_wavs] - for i in tqdm(range(bs)): - output_audio_path = self.save_wav_file( - pred_wavs[i], i, save_path=save_path, sample_rate=sample_rate, format=format - ) - output_audio_paths.append(output_audio_path) - return output_audio_paths + if do_save: + output_audio_paths = [] + bs = latents.shape[0] + audio_lengths = [target_wav_duration_second * sample_rate] * bs + pred_latents = latents + with torch.no_grad(): + _, pred_wavs = self.music_dcae.decode(pred_latents, sr=sample_rate) + pred_wavs = [pred_wav.cpu().float() for pred_wav in pred_wavs] + for i in tqdm(range(bs)): + output_audio_path = self.save_wav_file( + pred_wavs[i], i, sample_rate=sample_rate + ) + output_audio_paths.append(output_audio_path) + return output_audio_paths + else: + bs = latents.shape[0] + pred_latents = latents + with torch.no_grad(): + _, pred_wavs = self.music_dcae.decode(pred_latents, sr=sample_rate) + pred_wavs = [pred_wav.cpu().float() for pred_wav in pred_wavs] + return pred_wavs def save_wav_file( self, target_wav, idx, save_path=None, sample_rate=48000, format="wav" @@ -1393,25 +1242,20 @@ class ACEStepPipeline: logger.warning("save_path is None, using default path ./outputs/") base_path = f"./outputs" ensure_directory_exists(base_path) - output_path_wav = ( - f"{base_path}/output_{time.strftime('%Y%m%d%H%M%S')}_{idx}.wav" - ) else: - ensure_directory_exists(os.path.dirname(save_path)) - if os.path.isdir(save_path): - logger.info(f"Provided save_path '{save_path}' is a directory. Appending timestamped filename.") - output_path_wav = os.path.join(save_path, f"output_{time.strftime('%Y%m%d%H%M%S')}_{idx}.wav") - else: - output_path_wav = save_path - + base_path = save_path + ensure_directory_exists(base_path) + + output_path_wav = ( + f"{base_path}/output_{time.strftime('%Y%m%d%H%M%S')}_{idx}.wav" + ) target_wav = target_wav.float() - logger.info(f"Saving audio to {output_path_wav}") + print(target_wav) torchaudio.save( output_path_wav, target_wav, sample_rate=sample_rate, format=format ) return output_path_wav - @cpu_offload("music_dcae") def infer_latents(self, input_audio_path): if input_audio_path is None: return None @@ -1457,6 +1301,7 @@ class ACEStepPipeline: format: str = "wav", batch_size: int = 1, debug: bool = False, + do_save: bool = True, ): start_time = time.time() @@ -1640,6 +1485,7 @@ class ACEStepPipeline: target_wav_duration_second=audio_duration, save_path=save_path, format=format, + do_save=do_save, ) end_time = time.time() @@ -1683,12 +1529,13 @@ class ACEStepPipeline: "edit_target_lyrics": edit_target_lyrics, } # save input_params_json - for output_audio_path in output_paths: - input_params_json_save_path = output_audio_path.replace( - f".{format}", "_input_params.json" - ) - input_params_json["audio_path"] = output_audio_path - with open(input_params_json_save_path, "w", encoding="utf-8") as f: - json.dump(input_params_json, f, indent=4, ensure_ascii=False) + if do_save: + for output_audio_path in output_paths: + input_params_json_save_path = output_audio_path.replace( + f".{format}", "_input_params.json" + ) + input_params_json["audio_path"] = output_audio_path + with open(input_params_json_save_path, "w", encoding="utf-8") as f: + json.dump(input_params_json, f, indent=4, ensure_ascii=False) return output_paths + [input_params_json]