From 348cebc7f82bbd3aa17eadd719a7d8910c1ae5e9 Mon Sep 17 00:00:00 2001 From: chuxij Date: Mon, 12 May 2025 08:09:26 +0000 Subject: [PATCH] add lora interface --- .gitignore | 3 ++- acestep/data_sampler.py | 18 +++++++++++++----- acestep/pipeline_ace_step.py | 11 +++++++++++ acestep/ui/components.py | 20 +++++++++++++++++++- requirements.txt | 3 ++- 5 files changed, 47 insertions(+), 8 deletions(-) diff --git a/.gitignore b/.gitignore index a12cef8..7adc6f4 100644 --- a/.gitignore +++ b/.gitignore @@ -201,4 +201,5 @@ app_demo.py ui/components_demo.py data_sampler_demo.py pipeline_ace_step_demo.py -*.wav \ No newline at end of file +*.wav +start.sh diff --git a/acestep/data_sampler.py b/acestep/data_sampler.py index 64b1636..5f7131f 100644 --- a/acestep/data_sampler.py +++ b/acestep/data_sampler.py @@ -3,19 +3,27 @@ from pathlib import Path import random -DEFAULT_ROOT_DIR = "examples/input_params" - +DEFAULT_ROOT_DIR = "examples/default/input_params" +ZH_RAP_LORA_ROOT_DIR = "examples/zh_rap_lora/input_params" class DataSampler: def __init__(self, root_dir=DEFAULT_ROOT_DIR): self.root_dir = root_dir self.input_params_files = list(Path(self.root_dir).glob("*.json")) + self.zh_rap_lora_input_params_files = list(Path(ZH_RAP_LORA_ROOT_DIR).glob("*.json")) def load_json(self, file_path): with open(file_path, "r", encoding="utf-8") as f: return json.load(f) - def sample(self): - json_path = random.choice(self.input_params_files) - json_data = self.load_json(json_path) + def sample(self, lora_name_or_path=None): + if lora_name_or_path is None or lora_name_or_path == "none": + json_path = random.choice(self.input_params_files) + json_data = self.load_json(json_path) + else: + json_path = random.choice(self.zh_rap_lora_input_params_files) + json_data = self.load_json(json_path) + # Update the lora_name in the json_data + json_data["lora_name_or_path"] = lora_name_or_path + return json_data diff --git a/acestep/pipeline_ace_step.py b/acestep/pipeline_ace_step.py index 2ee866e..fa510a0 100644 --- a/acestep/pipeline_ace_step.py +++ b/acestep/pipeline_ace_step.py @@ -113,6 +113,7 @@ class ACEStepPipeline: ensure_directory_exists(checkpoint_dir) self.checkpoint_dir = checkpoint_dir + self.lora_path = "none" device = ( torch.device(f"cuda:{device_id}") if torch.cuda.is_available() @@ -1568,6 +1569,7 @@ class ACEStepPipeline: audio2audio_enable: bool = False, ref_audio_strength: float = 0.5, ref_audio_input: str = None, + lora_name_or_path: str = "none", retake_seeds: list = None, retake_variance: float = 0.5, task: str = "text2music", @@ -1596,6 +1598,15 @@ class ACEStepPipeline: self.load_quantized_checkpoint(self.checkpoint_dir) else: self.load_checkpoint(self.checkpoint_dir) + # lora_path=lora_name_or_path + if lora_name_or_path != "none": + self.ace_step_transformer.load_lora_adapter(os.path.join(lora_name_or_path, "pytorch_lora_weights.safetensors"), adapter_name="zh_rap_lora", with_alpha=True) + logger.info(f"Loading lora weights from: {lora_name_or_path}") + self.lora_path = lora_name_or_path + elif self.lora_path != "none" and lora_name_or_path == "none": + logger.info("No lora weights to load.") + self.ace_step_transformer.unload_lora_weights() + load_model_cost = time.time() - start_time logger.info(f"Model loaded in {load_model_cost:.2f} seconds.") diff --git a/acestep/ui/components.py b/acestep/ui/components.py index 4760a17..05aff86 100644 --- a/acestep/ui/components.py +++ b/acestep/ui/components.py @@ -87,7 +87,14 @@ def create_text2music_ui( sample_bnt = gr.Button("Sample", variant="primary", scale=1) # audio2audio - audio2audio_enable = gr.Checkbox(label="Enable Audio2Audio", value=False, info="Check to enable Audio-to-Audio generation using a reference audio.", elem_id="audio2audio_checkbox") + with gr.Row(equal_height=True): + audio2audio_enable = gr.Checkbox(label="Enable Audio2Audio", value=False, info="Check to enable Audio-to-Audio generation using a reference audio.", elem_id="audio2audio_checkbox") + lora_name_or_path = gr.Dropdown( + label="Lora Name or Path", + choices=["ACE-Step/ACE-Step-v1-chinese-rap-LoRA", "none"], + value="none", + ) + ref_audio_input = gr.Audio(type="filepath", label="Reference Audio (for Audio2Audio)", visible=False, elem_id="ref_audio_input", show_download_button=True) ref_audio_strength = gr.Slider( label="Refer audio strength", @@ -290,6 +297,7 @@ def create_text2music_ui( retake_seeds=retake_seeds, retake_variance=retake_variance, task="retake", + lora_name_or_path=lora_name_or_path, ) retake_bnt.click( @@ -412,6 +420,7 @@ def create_text2music_ui( repaint_start=repaint_start, repaint_end=repaint_end, src_audio_path=src_audio_path, + lora_name_or_path=lora_name_or_path, ) repaint_bnt.click( @@ -585,6 +594,7 @@ def create_text2music_ui( edit_n_min=edit_n_min, edit_n_max=edit_n_max, retake_seeds=retake_seeds, + lora_name_or_path=lora_name_or_path, ) edit_bnt.click( @@ -729,6 +739,7 @@ def create_text2music_ui( repaint_start=repaint_start, repaint_end=repaint_end, src_audio_path=src_audio_path, + lora_name_or_path=lora_name_or_path, ) extend_bnt.click( @@ -806,10 +817,16 @@ def create_text2music_ui( if "ref_audio_input" in json_data else None ), + ( + json_data["lora_name_or_path"] + if "lora_name_or_path" in json_data + else "none" + ) ) sample_bnt.click( sample_data, + inputs=[lora_name_or_path], outputs=[ audio_duration, prompt, @@ -859,6 +876,7 @@ def create_text2music_ui( audio2audio_enable, ref_audio_strength, ref_audio_input, + lora_name_or_path, ], outputs=outputs + [input_params_json], ) diff --git a/requirements.txt b/requirements.txt index 0e7bbf3..f504969 100644 --- a/requirements.txt +++ b/requirements.txt @@ -20,4 +20,5 @@ spacy==3.8.4 accelerate==1.6.0 cutlet fugashi[unidic-lite] -click \ No newline at end of file +click +peft \ No newline at end of file