mirror of
https://github.com/outbackdingo/ACE-Step.git
synced 2026-03-21 00:45:57 +00:00
Fix downloads
This commit is contained in:
@@ -17,7 +17,7 @@ from loguru import logger
|
||||
from tqdm import tqdm
|
||||
import json
|
||||
import math
|
||||
from huggingface_hub import hf_hub_download
|
||||
from huggingface_hub import hf_hub_download, snapshot_download
|
||||
|
||||
# from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
||||
from acestep.schedulers.scheduling_flow_match_euler_discrete import (
|
||||
@@ -44,7 +44,6 @@ from acestep.apg_guidance import (
|
||||
cfg_double_condition_forward,
|
||||
)
|
||||
import torchaudio
|
||||
from .cpu_offload import cpu_offload
|
||||
|
||||
|
||||
torch.backends.cudnn.benchmark = False
|
||||
@@ -97,7 +96,6 @@ class ACEStepPipeline:
|
||||
text_encoder_checkpoint_path=None,
|
||||
persistent_storage_path=None,
|
||||
torch_compile=False,
|
||||
cpu_offload=False,
|
||||
**kwargs,
|
||||
):
|
||||
if not checkpoint_dir:
|
||||
@@ -124,149 +122,17 @@ class ACEStepPipeline:
|
||||
self.device = device
|
||||
self.loaded = False
|
||||
self.torch_compile = torch_compile
|
||||
self.cpu_offload = cpu_offload
|
||||
|
||||
def load_checkpoint(self, checkpoint_dir=None):
|
||||
device = self.device
|
||||
|
||||
dcae_model_path = os.path.join(checkpoint_dir, "music_dcae_f8c8")
|
||||
vocoder_model_path = os.path.join(checkpoint_dir, "music_vocoder")
|
||||
ace_step_model_path = os.path.join(checkpoint_dir, "ace_step_transformer")
|
||||
text_encoder_model_path = os.path.join(checkpoint_dir, "umt5-base")
|
||||
|
||||
files_exist = (
|
||||
os.path.exists(os.path.join(dcae_model_path, "config.json"))
|
||||
and os.path.exists(
|
||||
os.path.join(dcae_model_path, "diffusion_pytorch_model.safetensors")
|
||||
)
|
||||
and os.path.exists(os.path.join(vocoder_model_path, "config.json"))
|
||||
and os.path.exists(
|
||||
os.path.join(vocoder_model_path, "diffusion_pytorch_model.safetensors")
|
||||
)
|
||||
and os.path.exists(os.path.join(ace_step_model_path, "config.json"))
|
||||
and os.path.exists(
|
||||
os.path.join(ace_step_model_path, "diffusion_pytorch_model.safetensors")
|
||||
)
|
||||
and os.path.exists(os.path.join(text_encoder_model_path, "config.json"))
|
||||
and os.path.exists(
|
||||
os.path.join(text_encoder_model_path, "model.safetensors")
|
||||
)
|
||||
and os.path.exists(
|
||||
os.path.join(text_encoder_model_path, "special_tokens_map.json")
|
||||
)
|
||||
)
|
||||
|
||||
if not files_exist:
|
||||
logger.info(
|
||||
f"Checkpoint directory {checkpoint_dir} is not complete, downloading from Hugging Face Hub"
|
||||
)
|
||||
|
||||
# download music dcae model
|
||||
hf_hub_download(
|
||||
repo_id=REPO_ID,
|
||||
subfolder="music_dcae_f8c8",
|
||||
filename="config.json",
|
||||
local_dir=checkpoint_dir,
|
||||
local_dir_use_symlinks=False,
|
||||
)
|
||||
hf_hub_download(
|
||||
repo_id=REPO_ID,
|
||||
subfolder="music_dcae_f8c8",
|
||||
filename="diffusion_pytorch_model.safetensors",
|
||||
local_dir=checkpoint_dir,
|
||||
local_dir_use_symlinks=False,
|
||||
)
|
||||
|
||||
# download vocoder model
|
||||
hf_hub_download(
|
||||
repo_id=REPO_ID,
|
||||
subfolder="music_vocoder",
|
||||
filename="config.json",
|
||||
local_dir=checkpoint_dir,
|
||||
local_dir_use_symlinks=False,
|
||||
)
|
||||
hf_hub_download(
|
||||
repo_id=REPO_ID,
|
||||
subfolder="music_vocoder",
|
||||
filename="diffusion_pytorch_model.safetensors",
|
||||
local_dir=checkpoint_dir,
|
||||
local_dir_use_symlinks=False,
|
||||
)
|
||||
|
||||
# download ace_step transformer model
|
||||
hf_hub_download(
|
||||
repo_id=REPO_ID,
|
||||
subfolder="ace_step_transformer",
|
||||
filename="config.json",
|
||||
local_dir=checkpoint_dir,
|
||||
local_dir_use_symlinks=False,
|
||||
)
|
||||
hf_hub_download(
|
||||
repo_id=REPO_ID,
|
||||
subfolder="ace_step_transformer",
|
||||
filename="diffusion_pytorch_model.safetensors",
|
||||
local_dir=checkpoint_dir,
|
||||
local_dir_use_symlinks=False,
|
||||
)
|
||||
|
||||
# download text encoder model
|
||||
# os.makedirs(text_encoder_model_path, exist_ok=True) # hf_hub_download should create subdirectories
|
||||
hf_hub_download(
|
||||
repo_id=REPO_ID,
|
||||
subfolder="umt5-base",
|
||||
filename="config.json",
|
||||
local_dir=checkpoint_dir,
|
||||
local_dir_use_symlinks=False,
|
||||
)
|
||||
hf_hub_download(
|
||||
repo_id=REPO_ID,
|
||||
subfolder="umt5-base",
|
||||
filename="model.safetensors",
|
||||
local_dir=checkpoint_dir,
|
||||
local_dir_use_symlinks=False,
|
||||
)
|
||||
hf_hub_download(
|
||||
repo_id=REPO_ID,
|
||||
subfolder="umt5-base",
|
||||
filename="special_tokens_map.json",
|
||||
local_dir=checkpoint_dir,
|
||||
local_dir_use_symlinks=False,
|
||||
)
|
||||
hf_hub_download(
|
||||
repo_id=REPO_ID,
|
||||
subfolder="umt5-base",
|
||||
filename="tokenizer_config.json",
|
||||
local_dir=checkpoint_dir,
|
||||
local_dir_use_symlinks=False,
|
||||
)
|
||||
hf_hub_download(
|
||||
repo_id=REPO_ID,
|
||||
subfolder="umt5-base",
|
||||
filename="tokenizer.json",
|
||||
local_dir=checkpoint_dir,
|
||||
local_dir_use_symlinks=False,
|
||||
)
|
||||
|
||||
# Verify files were downloaded correctly
|
||||
if not all([
|
||||
os.path.exists(os.path.join(dcae_model_path, "config.json")),
|
||||
os.path.exists(os.path.join(dcae_model_path, "diffusion_pytorch_model.safetensors")),
|
||||
os.path.exists(os.path.join(vocoder_model_path, "config.json")),
|
||||
os.path.exists(os.path.join(vocoder_model_path, "diffusion_pytorch_model.safetensors")),
|
||||
os.path.exists(os.path.join(ace_step_model_path, "config.json")),
|
||||
os.path.exists(os.path.join(ace_step_model_path, "diffusion_pytorch_model.safetensors")),
|
||||
os.path.exists(os.path.join(text_encoder_model_path, "config.json")),
|
||||
os.path.exists(os.path.join(text_encoder_model_path, "model.safetensors")),
|
||||
os.path.exists(os.path.join(text_encoder_model_path, "special_tokens_map.json")),
|
||||
]):
|
||||
logger.error("Failed to download all required model files. Please check your internet connection and try again.")
|
||||
logger.info(f"DCAE model path: {dcae_model_path}, files exist: {os.path.exists(os.path.join(dcae_model_path, 'config.json'))}")
|
||||
logger.info(f"Vocoder model path: {vocoder_model_path}, files exist: {os.path.exists(os.path.join(vocoder_model_path, 'config.json'))}")
|
||||
logger.info(f"ACE-Step model path: {ace_step_model_path}, files exist: {os.path.exists(os.path.join(ace_step_model_path, 'config.json'))}")
|
||||
logger.info(f"Text encoder model path: {text_encoder_model_path}, files exist: {os.path.exists(os.path.join(text_encoder_model_path, 'config.json'))}")
|
||||
raise RuntimeError("Model download failed. See logs for details.")
|
||||
|
||||
logger.info("Models downloaded successfully")
|
||||
if checkpoint_dir is None:
|
||||
checkpoint_dir_models = snapshot_download(REPO_ID)
|
||||
else:
|
||||
checkpoint_dir_models = snapshot_download(REPO_ID, cache_dir=checkpoint_dir)
|
||||
dcae_model_path = os.path.join(checkpoint_dir_models, "music_dcae_f8c8")
|
||||
vocoder_model_path = os.path.join(checkpoint_dir_models, "music_vocoder")
|
||||
ace_step_model_path = os.path.join(checkpoint_dir_models, "ace_step_transformer")
|
||||
text_encoder_model_path = os.path.join(checkpoint_dir_models, "umt5-base")
|
||||
|
||||
dcae_checkpoint_path = dcae_model_path
|
||||
vocoder_checkpoint_path = vocoder_model_path
|
||||
@@ -277,20 +143,12 @@ class ACEStepPipeline:
|
||||
dcae_checkpoint_path=dcae_checkpoint_path,
|
||||
vocoder_checkpoint_path=vocoder_checkpoint_path,
|
||||
)
|
||||
# self.music_dcae.to(device).eval().to(self.dtype)
|
||||
if self.cpu_offload: # might be redundant
|
||||
self.music_dcae = self.music_dcae.to("cpu").eval().to(self.dtype)
|
||||
else:
|
||||
self.music_dcae = self.music_dcae.to(device).eval().to(self.dtype)
|
||||
self.music_dcae.to(device).eval().to(self.dtype)
|
||||
|
||||
self.ace_step_transformer = ACEStepTransformer2DModel.from_pretrained(
|
||||
ace_step_checkpoint_path, torch_dtype=self.dtype
|
||||
)
|
||||
# self.ace_step_transformer.to(device).eval().to(self.dtype)
|
||||
if self.cpu_offload:
|
||||
self.ace_step_transformer = self.ace_step_transformer.to("cpu").eval().to(self.dtype)
|
||||
else:
|
||||
self.ace_step_transformer = self.ace_step_transformer.to(device).eval().to(self.dtype)
|
||||
self.ace_step_transformer.to(device).eval().to(self.dtype)
|
||||
|
||||
lang_segment = LangSegment()
|
||||
|
||||
@@ -400,11 +258,7 @@ class ACEStepPipeline:
|
||||
text_encoder_model = UMT5EncoderModel.from_pretrained(
|
||||
text_encoder_checkpoint_path, torch_dtype=self.dtype
|
||||
).eval()
|
||||
# text_encoder_model = text_encoder_model.to(device).to(self.dtype)
|
||||
if self.cpu_offload:
|
||||
text_encoder_model = text_encoder_model.to("cpu").eval().to(self.dtype)
|
||||
else:
|
||||
text_encoder_model = text_encoder_model.to(device).eval().to(self.dtype)
|
||||
text_encoder_model = text_encoder_model.to(device).to(self.dtype)
|
||||
text_encoder_model.requires_grad_(False)
|
||||
self.text_encoder_model = text_encoder_model
|
||||
self.text_tokenizer = AutoTokenizer.from_pretrained(
|
||||
@@ -418,7 +272,6 @@ class ACEStepPipeline:
|
||||
self.ace_step_transformer = torch.compile(self.ace_step_transformer)
|
||||
self.text_encoder_model = torch.compile(self.text_encoder_model)
|
||||
|
||||
@cpu_offload("text_encoder_model")
|
||||
def get_text_embeddings(self, texts, device, text_max_length=256):
|
||||
inputs = self.text_tokenizer(
|
||||
texts,
|
||||
@@ -436,7 +289,6 @@ class ACEStepPipeline:
|
||||
attention_mask = inputs["attention_mask"]
|
||||
return last_hidden_states, attention_mask
|
||||
|
||||
@cpu_offload("text_encoder_model")
|
||||
def get_text_embeddings_null(
|
||||
self, texts, device, text_max_length=256, tau=0.01, l_min=8, l_max=10
|
||||
):
|
||||
@@ -479,37 +331,28 @@ class ACEStepPipeline:
|
||||
return last_hidden_states
|
||||
|
||||
def set_seeds(self, batch_size, manual_seeds=None):
|
||||
processed_input_seeds = None
|
||||
seeds = None
|
||||
if manual_seeds is not None:
|
||||
if isinstance(manual_seeds, str):
|
||||
if "," in manual_seeds:
|
||||
processed_input_seeds = list(map(int, manual_seeds.split(",")))
|
||||
seeds = list(map(int, manual_seeds.split(",")))
|
||||
elif manual_seeds.isdigit():
|
||||
processed_input_seeds = int(manual_seeds)
|
||||
elif isinstance(manual_seeds, list) and all(isinstance(s, int) for s in manual_seeds):
|
||||
if len(manual_seeds) > 0:
|
||||
processed_input_seeds = list(manual_seeds)
|
||||
elif isinstance(manual_seeds, int):
|
||||
processed_input_seeds = manual_seeds
|
||||
seeds = int(manual_seeds)
|
||||
|
||||
random_generators = [
|
||||
torch.Generator(device=self.device) for _ in range(batch_size)
|
||||
]
|
||||
actual_seeds = []
|
||||
for i in range(batch_size):
|
||||
current_seed_for_generator = None
|
||||
if processed_input_seeds is None:
|
||||
current_seed_for_generator = torch.randint(0, 2**32, (1,)).item()
|
||||
elif isinstance(processed_input_seeds, int):
|
||||
current_seed_for_generator = processed_input_seeds
|
||||
elif isinstance(processed_input_seeds, list):
|
||||
if i < len(processed_input_seeds):
|
||||
current_seed_for_generator = processed_input_seeds[i]
|
||||
else:
|
||||
current_seed_for_generator = processed_input_seeds[-1]
|
||||
if current_seed_for_generator is None:
|
||||
current_seed_for_generator = torch.randint(0, 2**32, (1,)).item()
|
||||
random_generators[i].manual_seed(current_seed_for_generator)
|
||||
actual_seeds.append(current_seed_for_generator)
|
||||
seed = None
|
||||
if seeds is None:
|
||||
seed = torch.randint(0, 2**32, (1,)).item()
|
||||
if isinstance(seeds, int):
|
||||
seed = seeds
|
||||
if isinstance(seeds, list):
|
||||
seed = seeds[i]
|
||||
random_generators[i].manual_seed(seed)
|
||||
actual_seeds.append(seed)
|
||||
return random_generators, actual_seeds
|
||||
|
||||
def get_lang(self, text):
|
||||
@@ -557,7 +400,6 @@ class ACEStepPipeline:
|
||||
print("tokenize error", e, "for line", line, "major_language", lang)
|
||||
return lyric_token_idx
|
||||
|
||||
@cpu_offload("ace_step_transformer")
|
||||
def calc_v(
|
||||
self,
|
||||
zt_src,
|
||||
@@ -846,7 +688,6 @@ class ACEStepPipeline:
|
||||
target_latents = zt_edit if xt_tar is None else xt_tar
|
||||
return target_latents
|
||||
|
||||
@cpu_offload("ace_step_transformer")
|
||||
@torch.no_grad()
|
||||
def text2music_diffusion_process(
|
||||
self,
|
||||
@@ -1363,7 +1204,6 @@ class ACEStepPipeline:
|
||||
)
|
||||
return target_latents
|
||||
|
||||
@cpu_offload("music_dcae")
|
||||
def latents2audio(
|
||||
self,
|
||||
latents,
|
||||
@@ -1371,20 +1211,29 @@ class ACEStepPipeline:
|
||||
sample_rate=48000,
|
||||
save_path=None,
|
||||
format="wav",
|
||||
do_save=True,
|
||||
):
|
||||
output_audio_paths = []
|
||||
bs = latents.shape[0]
|
||||
audio_lengths = [target_wav_duration_second * sample_rate] * bs
|
||||
pred_latents = latents
|
||||
with torch.no_grad():
|
||||
_, pred_wavs = self.music_dcae.decode(pred_latents, sr=sample_rate)
|
||||
pred_wavs = [pred_wav.cpu().float() for pred_wav in pred_wavs]
|
||||
for i in tqdm(range(bs)):
|
||||
output_audio_path = self.save_wav_file(
|
||||
pred_wavs[i], i, save_path=save_path, sample_rate=sample_rate, format=format
|
||||
)
|
||||
output_audio_paths.append(output_audio_path)
|
||||
return output_audio_paths
|
||||
if do_save:
|
||||
output_audio_paths = []
|
||||
bs = latents.shape[0]
|
||||
audio_lengths = [target_wav_duration_second * sample_rate] * bs
|
||||
pred_latents = latents
|
||||
with torch.no_grad():
|
||||
_, pred_wavs = self.music_dcae.decode(pred_latents, sr=sample_rate)
|
||||
pred_wavs = [pred_wav.cpu().float() for pred_wav in pred_wavs]
|
||||
for i in tqdm(range(bs)):
|
||||
output_audio_path = self.save_wav_file(
|
||||
pred_wavs[i], i, sample_rate=sample_rate
|
||||
)
|
||||
output_audio_paths.append(output_audio_path)
|
||||
return output_audio_paths
|
||||
else:
|
||||
bs = latents.shape[0]
|
||||
pred_latents = latents
|
||||
with torch.no_grad():
|
||||
_, pred_wavs = self.music_dcae.decode(pred_latents, sr=sample_rate)
|
||||
pred_wavs = [pred_wav.cpu().float() for pred_wav in pred_wavs]
|
||||
return pred_wavs
|
||||
|
||||
def save_wav_file(
|
||||
self, target_wav, idx, save_path=None, sample_rate=48000, format="wav"
|
||||
@@ -1393,25 +1242,20 @@ class ACEStepPipeline:
|
||||
logger.warning("save_path is None, using default path ./outputs/")
|
||||
base_path = f"./outputs"
|
||||
ensure_directory_exists(base_path)
|
||||
output_path_wav = (
|
||||
f"{base_path}/output_{time.strftime('%Y%m%d%H%M%S')}_{idx}.wav"
|
||||
)
|
||||
else:
|
||||
ensure_directory_exists(os.path.dirname(save_path))
|
||||
if os.path.isdir(save_path):
|
||||
logger.info(f"Provided save_path '{save_path}' is a directory. Appending timestamped filename.")
|
||||
output_path_wav = os.path.join(save_path, f"output_{time.strftime('%Y%m%d%H%M%S')}_{idx}.wav")
|
||||
else:
|
||||
output_path_wav = save_path
|
||||
|
||||
base_path = save_path
|
||||
ensure_directory_exists(base_path)
|
||||
|
||||
output_path_wav = (
|
||||
f"{base_path}/output_{time.strftime('%Y%m%d%H%M%S')}_{idx}.wav"
|
||||
)
|
||||
target_wav = target_wav.float()
|
||||
logger.info(f"Saving audio to {output_path_wav}")
|
||||
print(target_wav)
|
||||
torchaudio.save(
|
||||
output_path_wav, target_wav, sample_rate=sample_rate, format=format
|
||||
)
|
||||
return output_path_wav
|
||||
|
||||
@cpu_offload("music_dcae")
|
||||
def infer_latents(self, input_audio_path):
|
||||
if input_audio_path is None:
|
||||
return None
|
||||
@@ -1457,6 +1301,7 @@ class ACEStepPipeline:
|
||||
format: str = "wav",
|
||||
batch_size: int = 1,
|
||||
debug: bool = False,
|
||||
do_save: bool = True,
|
||||
):
|
||||
|
||||
start_time = time.time()
|
||||
@@ -1640,6 +1485,7 @@ class ACEStepPipeline:
|
||||
target_wav_duration_second=audio_duration,
|
||||
save_path=save_path,
|
||||
format=format,
|
||||
do_save=do_save,
|
||||
)
|
||||
|
||||
end_time = time.time()
|
||||
@@ -1683,12 +1529,13 @@ class ACEStepPipeline:
|
||||
"edit_target_lyrics": edit_target_lyrics,
|
||||
}
|
||||
# save input_params_json
|
||||
for output_audio_path in output_paths:
|
||||
input_params_json_save_path = output_audio_path.replace(
|
||||
f".{format}", "_input_params.json"
|
||||
)
|
||||
input_params_json["audio_path"] = output_audio_path
|
||||
with open(input_params_json_save_path, "w", encoding="utf-8") as f:
|
||||
json.dump(input_params_json, f, indent=4, ensure_ascii=False)
|
||||
if do_save:
|
||||
for output_audio_path in output_paths:
|
||||
input_params_json_save_path = output_audio_path.replace(
|
||||
f".{format}", "_input_params.json"
|
||||
)
|
||||
input_params_json["audio_path"] = output_audio_path
|
||||
with open(input_params_json_save_path, "w", encoding="utf-8") as f:
|
||||
json.dump(input_params_json, f, indent=4, ensure_ascii=False)
|
||||
|
||||
return output_paths + [input_params_json]
|
||||
|
||||
Reference in New Issue
Block a user