add lora interface

This commit is contained in:
chuxij
2025-05-12 08:09:26 +00:00
parent 933f65fcbc
commit 348cebc7f8
5 changed files with 47 additions and 8 deletions

3
.gitignore vendored
View File

@@ -201,4 +201,5 @@ app_demo.py
ui/components_demo.py
data_sampler_demo.py
pipeline_ace_step_demo.py
*.wav
*.wav
start.sh

View File

@@ -3,19 +3,27 @@ from pathlib import Path
import random
# Default pool of example generation parameters, and a separate pool of
# examples tailored to the Chinese-rap LoRA adapter.
DEFAULT_ROOT_DIR = "examples/default/input_params"
ZH_RAP_LORA_ROOT_DIR = "examples/zh_rap_lora/input_params"


class DataSampler:
    """Randomly sample saved input-parameter JSON files for the demo UI.

    Two pools of example files are indexed at construction time:
    the generic pool under ``root_dir`` and the zh-rap LoRA pool under
    ``ZH_RAP_LORA_ROOT_DIR``.  ``Path.glob`` on a missing directory simply
    yields nothing, so a pool directory may be absent without raising here.
    """

    def __init__(self, root_dir=DEFAULT_ROOT_DIR):
        self.root_dir = root_dir
        self.input_params_files = list(Path(self.root_dir).glob("*.json"))
        self.zh_rap_lora_input_params_files = list(
            Path(ZH_RAP_LORA_ROOT_DIR).glob("*.json")
        )

    def load_json(self, file_path):
        """Load and return the JSON document at ``file_path``."""
        with open(file_path, "r", encoding="utf-8") as f:
            return json.load(f)

    def sample(self, lora_name_or_path=None):
        """Return the parameters of one randomly chosen example.

        Args:
            lora_name_or_path: LoRA selection from the UI.  ``None`` or the
                sentinel string ``"none"`` samples from the default pool;
                any other value samples from the zh-rap LoRA pool and is
                recorded in the returned dict under ``"lora_name_or_path"``.

        Returns:
            dict: the parsed JSON parameters of the chosen example.

        Raises:
            IndexError: if the selected pool directory contained no JSON
                files at construction time (``random.choice`` on an empty
                list).
        """
        if lora_name_or_path is None or lora_name_or_path == "none":
            json_data = self.load_json(random.choice(self.input_params_files))
        else:
            json_data = self.load_json(
                random.choice(self.zh_rap_lora_input_params_files)
            )
            # Stamp the chosen adapter into the sampled params so the UI
            # round-trips the LoRA selection.
            json_data["lora_name_or_path"] = lora_name_or_path
        return json_data

View File

@@ -113,6 +113,7 @@ class ACEStepPipeline:
ensure_directory_exists(checkpoint_dir)
self.checkpoint_dir = checkpoint_dir
self.lora_path = "none"
device = (
torch.device(f"cuda:{device_id}")
if torch.cuda.is_available()
@@ -1568,6 +1569,7 @@ class ACEStepPipeline:
audio2audio_enable: bool = False,
ref_audio_strength: float = 0.5,
ref_audio_input: str = None,
lora_name_or_path: str = "none",
retake_seeds: list = None,
retake_variance: float = 0.5,
task: str = "text2music",
@@ -1596,6 +1598,15 @@ class ACEStepPipeline:
self.load_quantized_checkpoint(self.checkpoint_dir)
else:
self.load_checkpoint(self.checkpoint_dir)
# lora_path=lora_name_or_path
if lora_name_or_path != "none":
self.ace_step_transformer.load_lora_adapter(os.path.join(lora_name_or_path, "pytorch_lora_weights.safetensors"), adapter_name="zh_rap_lora", with_alpha=True)
logger.info(f"Loading lora weights from: {lora_name_or_path}")
self.lora_path = lora_name_or_path
elif self.lora_path != "none" and lora_name_or_path == "none":
logger.info("No lora weights to load.")
self.ace_step_transformer.unload_lora_weights()
load_model_cost = time.time() - start_time
logger.info(f"Model loaded in {load_model_cost:.2f} seconds.")

View File

@@ -87,7 +87,14 @@ def create_text2music_ui(
sample_bnt = gr.Button("Sample", variant="primary", scale=1)
# audio2audio
audio2audio_enable = gr.Checkbox(label="Enable Audio2Audio", value=False, info="Check to enable Audio-to-Audio generation using a reference audio.", elem_id="audio2audio_checkbox")
with gr.Row(equal_height=True):
audio2audio_enable = gr.Checkbox(label="Enable Audio2Audio", value=False, info="Check to enable Audio-to-Audio generation using a reference audio.", elem_id="audio2audio_checkbox")
lora_name_or_path = gr.Dropdown(
label="Lora Name or Path",
choices=["ACE-Step/ACE-Step-v1-chinese-rap-LoRA", "none"],
value="none",
)
ref_audio_input = gr.Audio(type="filepath", label="Reference Audio (for Audio2Audio)", visible=False, elem_id="ref_audio_input", show_download_button=True)
ref_audio_strength = gr.Slider(
label="Refer audio strength",
@@ -290,6 +297,7 @@ def create_text2music_ui(
retake_seeds=retake_seeds,
retake_variance=retake_variance,
task="retake",
lora_name_or_path=lora_name_or_path,
)
retake_bnt.click(
@@ -412,6 +420,7 @@ def create_text2music_ui(
repaint_start=repaint_start,
repaint_end=repaint_end,
src_audio_path=src_audio_path,
lora_name_or_path=lora_name_or_path,
)
repaint_bnt.click(
@@ -585,6 +594,7 @@ def create_text2music_ui(
edit_n_min=edit_n_min,
edit_n_max=edit_n_max,
retake_seeds=retake_seeds,
lora_name_or_path=lora_name_or_path,
)
edit_bnt.click(
@@ -729,6 +739,7 @@ def create_text2music_ui(
repaint_start=repaint_start,
repaint_end=repaint_end,
src_audio_path=src_audio_path,
lora_name_or_path=lora_name_or_path,
)
extend_bnt.click(
@@ -806,10 +817,16 @@ def create_text2music_ui(
if "ref_audio_input" in json_data
else None
),
(
json_data["lora_name_or_path"]
if "lora_name_or_path" in json_data
else "none"
)
)
sample_bnt.click(
sample_data,
inputs=[lora_name_or_path],
outputs=[
audio_duration,
prompt,
@@ -859,6 +876,7 @@ def create_text2music_ui(
audio2audio_enable,
ref_audio_strength,
ref_audio_input,
lora_name_or_path,
],
outputs=outputs + [input_params_json],
)

View File

@@ -20,4 +20,5 @@ spacy==3.8.4
accelerate==1.6.0
cutlet
fugashi[unidic-lite]
click
click
peft