Update pipeline_ace_step.py

This commit is contained in:
mrfakename
2025-05-09 09:49:59 -07:00
committed by GitHub
parent 24f6e73013
commit 25b4c4bbcd

View File

@@ -44,6 +44,7 @@ from acestep.apg_guidance import (
cfg_double_condition_forward,
)
import torchaudio
from .cpu_offload import cpu_offload
torch.backends.cudnn.benchmark = False
@@ -96,6 +97,7 @@ class ACEStepPipeline:
text_encoder_checkpoint_path=None,
persistent_storage_path=None,
torch_compile=False,
cpu_offload=False,
**kwargs,
):
if not checkpoint_dir:
@@ -122,6 +124,7 @@ class ACEStepPipeline:
self.device = device
self.loaded = False
self.torch_compile = torch_compile
self.cpu_offload = cpu_offload
def load_checkpoint(self, checkpoint_dir=None):
device = self.device
@@ -133,7 +136,7 @@ class ACEStepPipeline:
vocoder_model_path = os.path.join(checkpoint_dir_models, "music_vocoder")
ace_step_model_path = os.path.join(checkpoint_dir_models, "ace_step_transformer")
text_encoder_model_path = os.path.join(checkpoint_dir_models, "umt5-base")
dcae_checkpoint_path = dcae_model_path
vocoder_checkpoint_path = vocoder_model_path
ace_step_checkpoint_path = ace_step_model_path
@@ -143,12 +146,20 @@ class ACEStepPipeline:
dcae_checkpoint_path=dcae_checkpoint_path,
vocoder_checkpoint_path=vocoder_checkpoint_path,
)
self.music_dcae.to(device).eval().to(self.dtype)
# self.music_dcae.to(device).eval().to(self.dtype)
if self.cpu_offload: # might be redundant
self.music_dcae = self.music_dcae.to("cpu").eval().to(self.dtype)
else:
self.music_dcae = self.music_dcae.to(device).eval().to(self.dtype)
self.ace_step_transformer = ACEStepTransformer2DModel.from_pretrained(
ace_step_checkpoint_path, torch_dtype=self.dtype
)
self.ace_step_transformer.to(device).eval().to(self.dtype)
# self.ace_step_transformer.to(device).eval().to(self.dtype)
if self.cpu_offload:
self.ace_step_transformer = self.ace_step_transformer.to("cpu").eval().to(self.dtype)
else:
self.ace_step_transformer = self.ace_step_transformer.to(device).eval().to(self.dtype)
lang_segment = LangSegment()
@@ -258,7 +269,11 @@ class ACEStepPipeline:
text_encoder_model = UMT5EncoderModel.from_pretrained(
text_encoder_checkpoint_path, torch_dtype=self.dtype
).eval()
text_encoder_model = text_encoder_model.to(device).to(self.dtype)
# text_encoder_model = text_encoder_model.to(device).to(self.dtype)
if self.cpu_offload:
text_encoder_model = text_encoder_model.to("cpu").eval().to(self.dtype)
else:
text_encoder_model = text_encoder_model.to(device).eval().to(self.dtype)
text_encoder_model.requires_grad_(False)
self.text_encoder_model = text_encoder_model
self.text_tokenizer = AutoTokenizer.from_pretrained(
@@ -272,6 +287,7 @@ class ACEStepPipeline:
self.ace_step_transformer = torch.compile(self.ace_step_transformer)
self.text_encoder_model = torch.compile(self.text_encoder_model)
@cpu_offload("text_encoder_model")
def get_text_embeddings(self, texts, device, text_max_length=256):
inputs = self.text_tokenizer(
texts,
@@ -289,6 +305,7 @@ class ACEStepPipeline:
attention_mask = inputs["attention_mask"]
return last_hidden_states, attention_mask
@cpu_offload("text_encoder_model")
def get_text_embeddings_null(
self, texts, device, text_max_length=256, tau=0.01, l_min=8, l_max=10
):
@@ -331,28 +348,37 @@ class ACEStepPipeline:
return last_hidden_states
def set_seeds(self, batch_size, manual_seeds=None):
seeds = None
processed_input_seeds = None
if manual_seeds is not None:
if isinstance(manual_seeds, str):
if "," in manual_seeds:
seeds = list(map(int, manual_seeds.split(",")))
processed_input_seeds = list(map(int, manual_seeds.split(",")))
elif manual_seeds.isdigit():
seeds = int(manual_seeds)
processed_input_seeds = int(manual_seeds)
elif isinstance(manual_seeds, list) and all(isinstance(s, int) for s in manual_seeds):
if len(manual_seeds) > 0:
processed_input_seeds = list(manual_seeds)
elif isinstance(manual_seeds, int):
processed_input_seeds = manual_seeds
random_generators = [
torch.Generator(device=self.device) for _ in range(batch_size)
]
actual_seeds = []
for i in range(batch_size):
seed = None
if seeds is None:
seed = torch.randint(0, 2**32, (1,)).item()
if isinstance(seeds, int):
seed = seeds
if isinstance(seeds, list):
seed = seeds[i]
random_generators[i].manual_seed(seed)
actual_seeds.append(seed)
current_seed_for_generator = None
if processed_input_seeds is None:
current_seed_for_generator = torch.randint(0, 2**32, (1,)).item()
elif isinstance(processed_input_seeds, int):
current_seed_for_generator = processed_input_seeds
elif isinstance(processed_input_seeds, list):
if i < len(processed_input_seeds):
current_seed_for_generator = processed_input_seeds[i]
else:
current_seed_for_generator = processed_input_seeds[-1]
if current_seed_for_generator is None:
current_seed_for_generator = torch.randint(0, 2**32, (1,)).item()
random_generators[i].manual_seed(current_seed_for_generator)
actual_seeds.append(current_seed_for_generator)
return random_generators, actual_seeds
def get_lang(self, text):
@@ -400,6 +426,7 @@ class ACEStepPipeline:
print("tokenize error", e, "for line", line, "major_language", lang)
return lyric_token_idx
@cpu_offload("ace_step_transformer")
def calc_v(
self,
zt_src,
@@ -688,6 +715,7 @@ class ACEStepPipeline:
target_latents = zt_edit if xt_tar is None else xt_tar
return target_latents
@cpu_offload("ace_step_transformer")
@torch.no_grad()
def text2music_diffusion_process(
self,
@@ -1204,6 +1232,7 @@ class ACEStepPipeline:
)
return target_latents
@cpu_offload("music_dcae")
def latents2audio(
self,
latents,
@@ -1211,29 +1240,20 @@ class ACEStepPipeline:
sample_rate=48000,
save_path=None,
format="wav",
do_save=True,
):
if do_save:
output_audio_paths = []
bs = latents.shape[0]
audio_lengths = [target_wav_duration_second * sample_rate] * bs
pred_latents = latents
with torch.no_grad():
_, pred_wavs = self.music_dcae.decode(pred_latents, sr=sample_rate)
pred_wavs = [pred_wav.cpu().float() for pred_wav in pred_wavs]
for i in tqdm(range(bs)):
output_audio_path = self.save_wav_file(
pred_wavs[i], i, sample_rate=sample_rate
)
output_audio_paths.append(output_audio_path)
return output_audio_paths
else:
bs = latents.shape[0]
pred_latents = latents
with torch.no_grad():
_, pred_wavs = self.music_dcae.decode(pred_latents, sr=sample_rate)
pred_wavs = [pred_wav.cpu().float() for pred_wav in pred_wavs]
return pred_wavs
output_audio_paths = []
bs = latents.shape[0]
audio_lengths = [target_wav_duration_second * sample_rate] * bs
pred_latents = latents
with torch.no_grad():
_, pred_wavs = self.music_dcae.decode(pred_latents, sr=sample_rate)
pred_wavs = [pred_wav.cpu().float() for pred_wav in pred_wavs]
for i in tqdm(range(bs)):
output_audio_path = self.save_wav_file(
pred_wavs[i], i, save_path=save_path, sample_rate=sample_rate, format=format
)
output_audio_paths.append(output_audio_path)
return output_audio_paths
def save_wav_file(
self, target_wav, idx, save_path=None, sample_rate=48000, format="wav"
@@ -1242,20 +1262,25 @@ class ACEStepPipeline:
logger.warning("save_path is None, using default path ./outputs/")
base_path = f"./outputs"
ensure_directory_exists(base_path)
output_path_wav = (
f"{base_path}/output_{time.strftime('%Y%m%d%H%M%S')}_{idx}.wav"
)
else:
base_path = save_path
ensure_directory_exists(base_path)
output_path_wav = (
f"{base_path}/output_{time.strftime('%Y%m%d%H%M%S')}_{idx}.wav"
)
ensure_directory_exists(os.path.dirname(save_path))
if os.path.isdir(save_path):
logger.info(f"Provided save_path '{save_path}' is a directory. Appending timestamped filename.")
output_path_wav = os.path.join(save_path, f"output_{time.strftime('%Y%m%d%H%M%S')}_{idx}.wav")
else:
output_path_wav = save_path
target_wav = target_wav.float()
print(target_wav)
logger.info(f"Saving audio to {output_path_wav}")
torchaudio.save(
output_path_wav, target_wav, sample_rate=sample_rate, format=format
)
return output_path_wav
@cpu_offload("music_dcae")
def infer_latents(self, input_audio_path):
if input_audio_path is None:
return None
@@ -1301,7 +1326,6 @@ class ACEStepPipeline:
format: str = "wav",
batch_size: int = 1,
debug: bool = False,
do_save: bool = True,
):
start_time = time.time()
@@ -1485,7 +1509,6 @@ class ACEStepPipeline:
target_wav_duration_second=audio_duration,
save_path=save_path,
format=format,
do_save=do_save,
)
end_time = time.time()
@@ -1529,13 +1552,12 @@ class ACEStepPipeline:
"edit_target_lyrics": edit_target_lyrics,
}
# save input_params_json
if do_save:
for output_audio_path in output_paths:
input_params_json_save_path = output_audio_path.replace(
f".{format}", "_input_params.json"
)
input_params_json["audio_path"] = output_audio_path
with open(input_params_json_save_path, "w", encoding="utf-8") as f:
json.dump(input_params_json, f, indent=4, ensure_ascii=False)
for output_audio_path in output_paths:
input_params_json_save_path = output_audio_path.replace(
f".{format}", "_input_params.json"
)
input_params_json["audio_path"] = output_audio_path
with open(input_params_json_save_path, "w", encoding="utf-8") as f:
json.dump(input_params_json, f, indent=4, ensure_ascii=False)
return output_paths + [input_params_json]