Mirror of https://github.com/outbackdingo/ACE-Step.git, synced 2026-03-21 08:45:58 +00:00.
Commit: add lora interface.
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@@ -201,4 +201,5 @@ app_demo.py
|
||||
ui/components_demo.py
|
||||
data_sampler_demo.py
|
||||
pipeline_ace_step_demo.py
|
||||
*.wav
|
||||
*.wav
|
||||
start.sh
|
||||
|
||||
@@ -3,19 +3,27 @@ from pathlib import Path
|
||||
import random
|
||||
|
||||
|
||||
# Directories holding example input-parameter JSON files.
DEFAULT_ROOT_DIR = "examples/default/input_params"
ZH_RAP_LORA_ROOT_DIR = "examples/zh_rap_lora/input_params"


class DataSampler:
    """Sample example generation parameters from on-disk JSON files.

    Two pools of examples are kept: the default pool under *root_dir* and a
    dedicated pool for the Chinese-rap LoRA under ``ZH_RAP_LORA_ROOT_DIR``.
    """

    def __init__(self, root_dir=DEFAULT_ROOT_DIR):
        self.root_dir = root_dir
        # Candidate JSON files for each pool; either list may be empty if
        # the directory does not exist or contains no *.json files, in
        # which case sample() will raise IndexError from random.choice.
        self.input_params_files = list(Path(self.root_dir).glob("*.json"))
        self.zh_rap_lora_input_params_files = list(
            Path(ZH_RAP_LORA_ROOT_DIR).glob("*.json")
        )

    def load_json(self, file_path):
        """Load and return the JSON document at *file_path*."""
        with open(file_path, "r", encoding="utf-8") as f:
            return json.load(f)

    def sample(self, lora_name_or_path=None):
        """Return one randomly chosen example-parameter dict.

        Args:
            lora_name_or_path: When ``None`` or ``"none"``, sample from the
                default pool. Otherwise sample from the LoRA-specific pool
                and record the requested LoRA in the returned dict.

        Returns:
            dict: The parsed JSON example, with ``"lora_name_or_path"``
            overridden when a LoRA was requested.
        """
        if lora_name_or_path is None or lora_name_or_path == "none":
            json_path = random.choice(self.input_params_files)
            json_data = self.load_json(json_path)
        else:
            json_path = random.choice(self.zh_rap_lora_input_params_files)
            json_data = self.load_json(json_path)
            # Record which LoRA the caller asked for so downstream code
            # uses it instead of whatever the example file contained.
            json_data["lora_name_or_path"] = lora_name_or_path
        return json_data
|
||||
|
||||
@@ -113,6 +113,7 @@ class ACEStepPipeline:
|
||||
ensure_directory_exists(checkpoint_dir)
|
||||
|
||||
self.checkpoint_dir = checkpoint_dir
|
||||
self.lora_path = "none"
|
||||
device = (
|
||||
torch.device(f"cuda:{device_id}")
|
||||
if torch.cuda.is_available()
|
||||
@@ -1568,6 +1569,7 @@ class ACEStepPipeline:
|
||||
audio2audio_enable: bool = False,
|
||||
ref_audio_strength: float = 0.5,
|
||||
ref_audio_input: str = None,
|
||||
lora_name_or_path: str = "none",
|
||||
retake_seeds: list = None,
|
||||
retake_variance: float = 0.5,
|
||||
task: str = "text2music",
|
||||
@@ -1596,6 +1598,15 @@ class ACEStepPipeline:
|
||||
self.load_quantized_checkpoint(self.checkpoint_dir)
|
||||
else:
|
||||
self.load_checkpoint(self.checkpoint_dir)
|
||||
# lora_path=lora_name_or_path
|
||||
if lora_name_or_path != "none":
|
||||
self.ace_step_transformer.load_lora_adapter(os.path.join(lora_name_or_path, "pytorch_lora_weights.safetensors"), adapter_name="zh_rap_lora", with_alpha=True)
|
||||
logger.info(f"Loading lora weights from: {lora_name_or_path}")
|
||||
self.lora_path = lora_name_or_path
|
||||
elif self.lora_path != "none" and lora_name_or_path == "none":
|
||||
logger.info("No lora weights to load.")
|
||||
self.ace_step_transformer.unload_lora_weights()
|
||||
|
||||
load_model_cost = time.time() - start_time
|
||||
logger.info(f"Model loaded in {load_model_cost:.2f} seconds.")
|
||||
|
||||
|
||||
@@ -87,7 +87,14 @@ def create_text2music_ui(
|
||||
sample_bnt = gr.Button("Sample", variant="primary", scale=1)
|
||||
|
||||
# audio2audio
|
||||
audio2audio_enable = gr.Checkbox(label="Enable Audio2Audio", value=False, info="Check to enable Audio-to-Audio generation using a reference audio.", elem_id="audio2audio_checkbox")
|
||||
with gr.Row(equal_height=True):
|
||||
audio2audio_enable = gr.Checkbox(label="Enable Audio2Audio", value=False, info="Check to enable Audio-to-Audio generation using a reference audio.", elem_id="audio2audio_checkbox")
|
||||
lora_name_or_path = gr.Dropdown(
|
||||
label="Lora Name or Path",
|
||||
choices=["ACE-Step/ACE-Step-v1-chinese-rap-LoRA", "none"],
|
||||
value="none",
|
||||
)
|
||||
|
||||
ref_audio_input = gr.Audio(type="filepath", label="Reference Audio (for Audio2Audio)", visible=False, elem_id="ref_audio_input", show_download_button=True)
|
||||
ref_audio_strength = gr.Slider(
|
||||
label="Refer audio strength",
|
||||
@@ -290,6 +297,7 @@ def create_text2music_ui(
|
||||
retake_seeds=retake_seeds,
|
||||
retake_variance=retake_variance,
|
||||
task="retake",
|
||||
lora_name_or_path=lora_name_or_path,
|
||||
)
|
||||
|
||||
retake_bnt.click(
|
||||
@@ -412,6 +420,7 @@ def create_text2music_ui(
|
||||
repaint_start=repaint_start,
|
||||
repaint_end=repaint_end,
|
||||
src_audio_path=src_audio_path,
|
||||
lora_name_or_path=lora_name_or_path,
|
||||
)
|
||||
|
||||
repaint_bnt.click(
|
||||
@@ -585,6 +594,7 @@ def create_text2music_ui(
|
||||
edit_n_min=edit_n_min,
|
||||
edit_n_max=edit_n_max,
|
||||
retake_seeds=retake_seeds,
|
||||
lora_name_or_path=lora_name_or_path,
|
||||
)
|
||||
|
||||
edit_bnt.click(
|
||||
@@ -729,6 +739,7 @@ def create_text2music_ui(
|
||||
repaint_start=repaint_start,
|
||||
repaint_end=repaint_end,
|
||||
src_audio_path=src_audio_path,
|
||||
lora_name_or_path=lora_name_or_path,
|
||||
)
|
||||
|
||||
extend_bnt.click(
|
||||
@@ -806,10 +817,16 @@ def create_text2music_ui(
|
||||
if "ref_audio_input" in json_data
|
||||
else None
|
||||
),
|
||||
(
|
||||
json_data["lora_name_or_path"]
|
||||
if "lora_name_or_path" in json_data
|
||||
else "none"
|
||||
)
|
||||
)
|
||||
|
||||
sample_bnt.click(
|
||||
sample_data,
|
||||
inputs=[lora_name_or_path],
|
||||
outputs=[
|
||||
audio_duration,
|
||||
prompt,
|
||||
@@ -859,6 +876,7 @@ def create_text2music_ui(
|
||||
audio2audio_enable,
|
||||
ref_audio_strength,
|
||||
ref_audio_input,
|
||||
lora_name_or_path,
|
||||
],
|
||||
outputs=outputs + [input_params_json],
|
||||
)
|
||||
|
||||
@@ -20,4 +20,5 @@ spacy==3.8.4
|
||||
accelerate==1.6.0
|
||||
cutlet
|
||||
fugashi[unidic-lite]
|
||||
click
|
||||
click
|
||||
peft
|
||||
Reference in New Issue
Block a user