mirror of
https://github.com/outbackdingo/ACE-Step.git
synced 2026-03-20 19:45:37 +00:00
all inference code
This commit is contained in:
65
apg_guidance.py
Normal file
65
apg_guidance.py
Normal file
@@ -0,0 +1,65 @@
|
||||
import torch
|
||||
|
||||
|
||||
class MomentumBuffer:
    """Running buffer for the cond/uncond guidance difference used by APG.

    With the default negative momentum (-0.75) the buffer dampens
    oscillations between consecutive guidance updates instead of
    accumulating them.
    """

    def __init__(self, momentum: float = -0.75):
        self.momentum = momentum
        # Starts as the scalar 0 so the first update works for any tensor shape.
        self.running_average = 0

    def update(self, update_value: torch.Tensor):
        """Fold a new difference tensor into the running average."""
        decayed_history = self.momentum * self.running_average
        self.running_average = update_value + decayed_history
|
||||
|
||||
|
||||
def project(
    v0: torch.Tensor,  # [B, C, H, W]
    v1: torch.Tensor,  # [B, C, H, W]
    dims=(-1, -2),
):
    """Decompose ``v0`` into components parallel and orthogonal to ``v1``.

    Both tensors are promoted to float64 so the dot product stays
    numerically stable, then the results are cast back to the input dtype.

    Args:
        v0: Tensor to decompose, e.g. [B, C, H, W].
        v1: Direction tensor with the same shape as ``v0``.
        dims: Dimensions treated as the vector axes. Changed from a mutable
            list default to a tuple (same values, safer default semantics).

    Returns:
        Tuple ``(v0_parallel, v0_orthogonal)`` in ``v0``'s original dtype,
        satisfying ``v0_parallel + v0_orthogonal == v0``.
    """
    dtype = v0.dtype
    v0, v1 = v0.double(), v1.double()
    # Unit direction over the reduction dims.
    v1 = torch.nn.functional.normalize(v1, dim=dims)
    # Scalar projection per slice, broadcast back across `dims`.
    v0_parallel = (v0 * v1).sum(dim=dims, keepdim=True) * v1
    v0_orthogonal = v0 - v0_parallel
    return v0_parallel.to(dtype), v0_orthogonal.to(dtype)


def apg_forward(
    pred_cond: torch.Tensor,  # [B, C, H, W]
    pred_uncond: torch.Tensor,  # [B, C, H, W]
    guidance_scale: float,
    momentum_buffer: "MomentumBuffer" = None,
    eta: float = 0.0,
    norm_threshold: float = 2.5,
    dims=(-1, -2),
):
    """Adaptive Projected Guidance (APG): combine cond/uncond predictions.

    Instead of classic CFG (``uncond + scale * diff``), the guidance
    difference is (1) optionally smoothed through ``momentum_buffer``,
    (2) norm-clipped at ``norm_threshold`` to suppress spikes, and
    (3) projected onto/against ``pred_cond`` so only ``eta`` of the
    parallel component is applied.

    Args:
        pred_cond: Conditional model prediction.
        pred_uncond: Unconditional model prediction (same shape).
        guidance_scale: Guidance strength; 1.0 returns ``pred_cond`` unchanged.
        momentum_buffer: Optional MomentumBuffer shared across diffusion steps.
        eta: Fraction of the parallel component kept (0 = orthogonal only).
        norm_threshold: Clip diff L2-norm over ``dims``; <= 0 disables clipping.
        dims: Reduction dims for norms/projections (default: last two).

    Returns:
        The guided prediction, same shape and dtype as ``pred_cond``.
    """
    diff = pred_cond - pred_uncond
    if momentum_buffer is not None:
        momentum_buffer.update(diff)
        diff = momentum_buffer.running_average

    if norm_threshold > 0:
        ones = torch.ones_like(diff)
        diff_norm = diff.norm(p=2, dim=dims, keepdim=True)
        # Rescale only when the norm exceeds the threshold (clips "pops"/spikes);
        # below-threshold diffs pass through unchanged.
        scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
        diff = diff * scale_factor

    diff_parallel, diff_orthogonal = project(diff, pred_cond, dims)
    normalized_update = diff_orthogonal + eta * diff_parallel
    pred_guided = pred_cond + (guidance_scale - 1) * normalized_update
    return pred_guided
|
||||
|
||||
|
||||
def cfg_forward(cond_output, uncond_output, cfg_strength):
    """Classic classifier-free guidance: move from uncond toward cond by cfg_strength."""
    guidance_delta = cond_output - uncond_output
    return uncond_output + cfg_strength * guidance_delta
|
||||
390
demo_infer_pipeline_text2music_v3.py
Normal file
390
demo_infer_pipeline_text2music_v3.py
Normal file
@@ -0,0 +1,390 @@
|
||||
# ---------------------------------------------------------------------------
# Command-line arguments. Parsed at import time, before the heavy torch /
# model imports below, so --device_id can steer CUDA_VISIBLE_DEVICES.
# ---------------------------------------------------------------------------
import argparse

parser = argparse.ArgumentParser()
# Checkpoint loaded by default when the UI does not pick another one.
parser.add_argument("--checkpoint_path", type=str, default="checkpoints/epoch=22-step=460k_pretrained_ft_80k.ckpt")
# Port for the demo web server.
parser.add_argument("--port", type=int, default=7862)
# GPU index; exported through CUDA_VISIBLE_DEVICES below.
parser.add_argument("--device_id", type=int, default=0)
# Create a publicly shared demo link.
parser.add_argument("--share", action='store_true', default=False)
# NOTE(review): store_true with default=True means bf16 is always on and the
# flag cannot disable it from the CLI — confirm whether a --no-bf16 option
# (argparse.BooleanOptionalAction) was intended.
parser.add_argument("--bf16", action='store_true', default=True)
parser.add_argument("--hide_dataset_sampler", action='store_true', default=False)

args = parser.parse_args()

import os

# Must be set before torch initializes CUDA, hence before `import torch` below.
os.environ["CUDA_VISIBLE_DEVICES"] = str(args.device_id)
|
||||
|
||||
import torch
|
||||
import torchaudio
|
||||
import torch.nn.functional as F
|
||||
from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import retrieve_timesteps
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
from pathlib import Path
|
||||
import time
|
||||
from tqdm import tqdm
|
||||
from loguru import logger
|
||||
import json
|
||||
from ui.auth import same_auth
|
||||
from ui.text2music_large_lyric_components_v3 import create_main_demo_ui
|
||||
|
||||
from models.lyrics_utils.lyric_tokenizer import VoiceBpeTokenizer
|
||||
from schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
|
||||
from schedulers.scheduling_flow_match_heun_discrete import FlowMatchHeunDiscreteScheduler
|
||||
from apg_guidance import apg_forward, MomentumBuffer, cfg_forward
|
||||
from language_segmentation import LangSegment
|
||||
import random
|
||||
import re
|
||||
|
||||
|
||||
logger.add("demo_v3.log", level="INFO")
|
||||
|
||||
|
||||
def ensure_directory_exists(directory):
    """Create *directory* (including parents) if it does not exist.

    Accepts either a str or a pathlib.Path.

    Uses ``exist_ok=True`` instead of the original check-then-create pair,
    which could raise FileExistsError if another process created the
    directory between the ``os.path.exists`` check and ``os.makedirs``
    (TOCTOU race).
    """
    os.makedirs(str(directory), exist_ok=True)
|
||||
|
||||
# Recognized song-structure keywords for lyric lines like "[Chorus]" / "[Verse 1]".
VALID_STRUCTURE_PATTERN = ["hook", "break", "pre-chorus", "solo", "inst", "end", "outro", "bridge", "chorus", "verse", "intro", "start"]

# Compiled once at module load; matches a line that STARTS with a bracketed tag.
_STRUCTURE_TAG_RE = re.compile(r"\[.*\]")


def is_structure_tag(lin):
    """Return True if *lin* is a structure tag line such as "[Chorus]".

    A line qualifies when it begins with a bracketed span AND contains one
    of the known structure keywords (case-insensitive).

    The bracket check was previously re-evaluated inside the keyword loop;
    it is loop-invariant, so it is hoisted out here.
    """
    lin = lin.lower()
    if not _STRUCTURE_TAG_RE.match(lin):
        return False
    return any(tag in lin for tag in VALID_STRUCTURE_PATTERN)
|
||||
|
||||
|
||||
# Re-tokenization logic: per-line language detection feeds the lyric tokenizer.
# Mapping of ISO language code -> integer id; presumably the language-marker
# token ids in the VoiceBpeTokenizer vocabulary — TODO confirm against the
# tokenizer's vocab.
SUPPORT_LANGUAGES = {
    "en": 259, "de": 260, "fr": 262, "es": 284, "it": 285,
    "pt": 286, "pl": 294, "tr": 295, "ru": 267, "cs": 293,
    "nl": 297, "ar": 5022, "zh": 5023, "ja": 5412, "hu": 5753,
    "ko": 6152, "hi": 6680
}

# Non-greedy bracketed span at the start of a lyric line, e.g. "[Chorus]".
structure_pattern = re.compile(r"\[.*?\]")
|
||||
|
||||
|
||||
class InferDemo:
    """Text-to-music inference wrapper behind the gradio demo UI.

    Owns the diffusion pipeline checkpoint, the lyric tokenizer, per-line
    language detection, the diffusion sampling loop, and latent-to-audio
    decoding/saving.
    """

    def __init__(self, args):
        """Set up device/dtype and text utilities; the model itself is
        loaded lazily by reload_model() on first request."""
        logger.info(f"init model with checkpoint: {args.checkpoint_path}")
        model_checkpoint_name = "AceFlow3_250401" + Path(args.checkpoint_path).stem
        if args.bf16:
            self.dtype = torch.bfloat16
        else:
            self.dtype = torch.float32
        # "cuda:0" is always device 0 of the visible set, because
        # CUDA_VISIBLE_DEVICES was restricted at import time.
        self.device = "cuda:0"

        self.model_checkpoint_name = model_checkpoint_name

        # Empty string marks "no checkpoint loaded yet" for reload_model().
        self.checkpoint_path = ""

        lang_segment = LangSegment()

        # Allow the full language set the detector supports; tokenize_lyrics
        # later narrows unsupported codes down to "en".
        lang_segment.setfilters([
            'af', 'am', 'an', 'ar', 'as', 'az', 'be', 'bg', 'bn', 'br', 'bs', 'ca', 'cs', 'cy', 'da', 'de', 'dz', 'el',
            'en', 'eo', 'es', 'et', 'eu', 'fa', 'fi', 'fo', 'fr', 'ga', 'gl', 'gu', 'he', 'hi', 'hr', 'ht', 'hu', 'hy',
            'id', 'is', 'it', 'ja', 'jv', 'ka', 'kk', 'km', 'kn', 'ko', 'ku', 'ky', 'la', 'lb', 'lo', 'lt', 'lv', 'mg',
            'mk', 'ml', 'mn', 'mr', 'ms', 'mt', 'nb', 'ne', 'nl', 'nn', 'no', 'oc', 'or', 'pa', 'pl', 'ps', 'pt', 'qu',
            'ro', 'ru', 'rw', 'se', 'si', 'sk', 'sl', 'sq', 'sr', 'sv', 'sw', 'ta', 'te', 'th', 'tl', 'tr', 'ug', 'uk',
            'ur', 'vi', 'vo', 'wa', 'xh', 'zh', 'zu'
        ])
        self.lang_segment = lang_segment
        self.lyric_tokenizer = VoiceBpeTokenizer()

    def reload_model(self, checkpoint_path):
        """Load *checkpoint_path* into a fresh Pipeline unless it is already active.

        NOTE(review): the substring test `checkpoint_path in self.checkpoint_path`
        makes the equality check redundant and would also skip a reload when the
        new path is a substring of the current one — confirm this is intended.
        """
        if checkpoint_path in self.checkpoint_path or self.checkpoint_path == checkpoint_path:
            return

        logger.info(f"re-init model with checkpoint: {checkpoint_path}")
        model_checkpoint_name = "AceFlow3_250401" + Path(checkpoint_path).stem
        # Load on CPU first, then move to device below, to avoid GPU OOM spikes.
        checkpoint = torch.load(checkpoint_path, map_location='cpu')

        # Imported lazily so module import does not pull in the training stack.
        from main_text2music_large_sana_dcae_0331_finetune import Pipeline

        model = Pipeline(infer=True, train=False)
        # strict=False: checkpoint may omit/add keys relative to the inference graph.
        model.load_state_dict(checkpoint, strict=False)

        self.model = model.eval().to(self.device).to(self.dtype)
        self.model_checkpoint_name = model_checkpoint_name
        self.checkpoint_path = checkpoint_path
        self.tokenizer = VoiceBpeTokenizer()

    def save_wav_file(self, target_wav, idx, sample_rate=48000):
        """Write one waveform tensor to a timestamped FLAC file and return its path.

        (Original comment said "compress to mp3", but the output is FLAC via ffmpeg.)
        """
        base_path = f"./test_results/{self.model_checkpoint_name}/demo_outputs"
        ensure_directory_exists(base_path)
        output_path_flac = f"{base_path}/output_{time.strftime('%Y%m%d%H%M%S')}_{idx}.flac"
        # torchaudio needs float32; bf16 tensors are not writable directly.
        target_wav = target_wav.float()
        torchaudio.save(output_path_flac, target_wav, sample_rate=sample_rate, format='flac', backend="ffmpeg", compression=torchaudio.io.CodecConfig(bit_rate=320000))
        return output_path_flac

    def set_seeds(self, batch_size, manual_seeds=None):
        """Build one torch.Generator per batch item.

        manual_seeds may be None (random per item), a digit string (same seed
        for every item), or a comma-separated string (one seed per item).
        Returns (generators, actual_seeds).
        """
        seeds = None
        if manual_seeds is not None:
            if isinstance(manual_seeds, str):
                if "," in manual_seeds:
                    seeds = list(map(int, manual_seeds.split(",")))
                elif manual_seeds.isdigit():
                    seeds = int(manual_seeds)

        random_generators = [torch.Generator(device=self.device) for _ in range(batch_size)]
        actual_seeds = []
        for i in range(batch_size):
            seed = None
            if seeds is None:
                seed = torch.randint(0, 2**32, (1,)).item()
            if isinstance(seeds, int):
                seed = seeds
            if isinstance(seeds, list):
                seed = seeds[i]
            logger.info(f"batch idx: {i}, seed: {seed}")
            random_generators[i].manual_seed(seed)
            actual_seeds.append(seed)
        return random_generators, actual_seeds

    def latents2audio(self, latents, target_wav_duration_second=30, sample_rate=48000):
        """Decode diffusion latents through the VAE and save each item to disk.

        Returns the list of written audio file paths (one per batch item).
        """
        output_audio_paths = []
        bs = latents.shape[0]
        # NOTE(review): audio_lengths is computed but never used — confirm
        # whether trimming to this length was intended.
        audio_lengths = [target_wav_duration_second * sample_rate] * bs
        pred_latents = latents
        with torch.no_grad():
            _, pred_wavs = self.model.vae.decode(pred_latents, sr=sample_rate)
        pred_wavs = [pred_wav.cpu().float() for pred_wav in pred_wavs]
        for i in tqdm(range(bs)):
            output_audio_path = self.save_wav_file(pred_wavs[i], i, sample_rate=sample_rate)
            output_audio_paths.append(output_audio_path)
        return output_audio_paths

    def get_lang(self, text):
        """Detect the dominant language of *text*; falls back to "en" on any error.

        When the top detection is "en" but another language is present, the
        second-ranked language wins (lyrics often mix English into CJK lines).
        """
        language = "en"
        try:
            langs = self.lang_segment.getTexts(text)
            langCounts = self.lang_segment.getCounts()
            language = langCounts[0][0]
            if len(langCounts) > 1 and language == "en":
                language = langCounts[1][0]
        except Exception as err:
            language = "en"
        return language

    def tokenize_lyrics(self, lyrics, debug=False):
        """Tokenize lyrics line by line with per-line language detection.

        Starts with token 261 (presumably a BOS/lyric-start marker — TODO
        confirm against the tokenizer vocab); empty lines and line ends are
        encoded as token 2. Structure tags like "[Chorus]" are always
        tokenized as English. Returns a flat list of token ids.
        """
        lines = lyrics.split("\n")
        lyric_token_idx = [261]
        for line in lines:
            line = line.strip()
            if not line:
                lyric_token_idx += [2]
                continue

            lang = self.get_lang(line)

            if lang not in SUPPORT_LANGUAGES:
                lang = "en"
            # Collapse regional variants (e.g. "zh-cn") to the base code.
            if "zh" in lang:
                lang = "zh"
            if "spa" in lang:
                lang = "es"

            try:
                if structure_pattern.match(line):
                    token_idx = self.lyric_tokenizer.encode(line, "en")
                else:
                    token_idx = self.lyric_tokenizer.encode(line, lang)
                if debug:
                    toks = self.lyric_tokenizer.batch_decode([[tok_id] for tok_id in token_idx])
                    logger.info(f"debbug {line} --> {lang} --> {toks}")
                lyric_token_idx = lyric_token_idx + token_idx + [2]
            except Exception as e:
                # Best-effort: a single bad line is logged and skipped, not fatal.
                print("tokenize error", e, "for line", line, "major_language", lang)
        return lyric_token_idx

    @torch.no_grad()
    def text2music_diffusion_process(
        self,
        duration,
        encoder_text_hidden_states,
        text_attention_mask,
        speaker_embds,
        lyric_token_ids,
        lyric_mask,
        random_generators=None,
        infer_steps=60,
        guidance_scale=15.0,
        omega_scale=10.0,
        scheduler_type="euler",
        cfg_type="apg",
    ):
        """Run the flow-matching diffusion loop and return the final latents.

        Classifier-free guidance doubles the batch (cond + uncond); the
        unconditional half uses zeroed text/speaker/lyric conditioning.
        cfg_type selects APG (projected guidance) vs classic CFG.
        """

        logger.info("cfg_type: {}, guidance_scale: {}, omega_scale: {}".format(cfg_type, guidance_scale, omega_scale))
        do_classifier_free_guidance = True
        # Scale 0/1 means guidance is a no-op, so skip the doubled batch entirely.
        if guidance_scale == 0.0 or guidance_scale == 1.0:
            do_classifier_free_guidance = False

        device = encoder_text_hidden_states.device
        dtype = encoder_text_hidden_states.dtype
        bsz = encoder_text_hidden_states.shape[0]

        if scheduler_type == "euler":
            scheduler = FlowMatchEulerDiscreteScheduler(
                num_train_timesteps=1000,
                shift=3.0,
            )
        elif scheduler_type == "heun":
            scheduler = FlowMatchHeunDiscreteScheduler(
                num_train_timesteps=1000,
                shift=3.0,
            )
        # Latent frames: 44.1 kHz audio, hop 512, VAE downsamples by 8.
        frame_length = int(duration * 44100 / 512 / 8)
        timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps=infer_steps, device=device, timesteps=None)
        # Latent layout is (batch, 8, 16, frames).
        target_latents = randn_tensor(shape=(bsz, 8, 16, frame_length), generator=random_generators, device=device, dtype=dtype)
        attention_mask = torch.ones(bsz, frame_length, device=device, dtype=dtype)

        if do_classifier_free_guidance:
            # Second half of the batch is the unconditional branch: same latent
            # mask, zeroed text/speaker/lyric conditioning.
            attention_mask = torch.cat([attention_mask] * 2, dim=0)
            encoder_text_hidden_states = torch.cat([encoder_text_hidden_states, torch.zeros_like(encoder_text_hidden_states)], 0)
            text_attention_mask = torch.cat([text_attention_mask] * 2, dim=0)

            speaker_embds = torch.cat([speaker_embds, torch.zeros_like(speaker_embds)], 0)

            lyric_token_ids = torch.cat([lyric_token_ids, torch.zeros_like(lyric_token_ids)], 0)
            lyric_mask = torch.cat([lyric_mask, torch.zeros_like(lyric_mask)], 0)

        # Shared across all steps so APG momentum accumulates over the trajectory.
        momentum_buffer = MomentumBuffer()

        for i, t in tqdm(enumerate(timesteps), total=num_inference_steps):
            # expand the latents if we are doing classifier free guidance
            latents = target_latents
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            timestep = t.expand(latent_model_input.shape[0])
            noise_pred = self.model.transformers(
                hidden_states=latent_model_input,
                attention_mask=attention_mask,
                encoder_text_hidden_states=encoder_text_hidden_states,
                text_attention_mask=text_attention_mask,
                speaker_embeds=speaker_embds,
                lyric_token_idx=lyric_token_ids,
                lyric_mask=lyric_mask,
                timestep=timestep,
            ).sample

            if do_classifier_free_guidance:
                noise_pred_with_cond, noise_pred_uncond = noise_pred.chunk(2)
                if cfg_type == "apg":
                    noise_pred = apg_forward(
                        pred_cond=noise_pred_with_cond,
                        pred_uncond=noise_pred_uncond,
                        guidance_scale=guidance_scale,
                        momentum_buffer=momentum_buffer,
                    )
                else:
                    noise_pred = cfg_forward(
                        cond_output=noise_pred_with_cond,
                        uncond_output=noise_pred_uncond,
                        cfg_strength=guidance_scale,
                    )

            target_latents = scheduler.step(model_output=noise_pred, timestep=t, sample=target_latents, return_dict=False, omega=omega_scale)[0]

        return target_latents

    @torch.no_grad()
    def process_text2music(
        self,
        audio_duration,
        prompt,
        lyrics,
        input_params_json,
        selected_checkpoint,
        scheduler_type,
        cfg_type,
        infer_step,
        guidance_scale,
        omega_scale,
        manual_seeds,
    ):
        """End-to-end UI entry point: prompt + lyrics -> audio files + params dict.

        Returns the list of output audio paths with the (updated)
        input-parameter dict appended as the last element.
        """
        # 1 check if need to reload model
        if selected_checkpoint is not None and self.checkpoint_path != selected_checkpoint:
            self.reload_model(selected_checkpoint)

        # Always generates two candidates per request.
        batch_size = 2

        # 2 set seed
        random_generators, actual_seeds = self.set_seeds(batch_size, manual_seeds)

        # 8 x 16 x T//8
        # 4 prompt
        texts = [prompt]
        encoder_text_hidden_states, text_attention_mask = self.model.lyric_processor.get_text_embeddings(texts, self.device)
        encoder_text_hidden_states = encoder_text_hidden_states.repeat(batch_size, 1, 1)
        text_attention_mask = text_attention_mask.repeat(batch_size, 1)

        # Zero speaker embedding = no specific voice conditioning.
        speaker_embeds = torch.zeros(batch_size, 512).to(self.device).to(self.dtype)

        # 6 lyric
        lyric_token_idx = torch.tensor([0]).repeat(batch_size, 1).to(self.device).long()
        lyric_mask = torch.tensor([0]).repeat(batch_size, 1).to(self.device).long()
        if len(lyrics) > 0:
            lyric_token_idx = self.tokenize_lyrics(lyrics, debug=True)
            lyric_mask = [1] * len(lyric_token_idx)
            lyric_token_idx = torch.tensor(lyric_token_idx).unsqueeze(0).to(self.device).repeat(batch_size, 1)
            lyric_mask = torch.tensor(lyric_mask).unsqueeze(0).to(self.device).repeat(batch_size, 1)

        # Non-positive duration means "surprise me": pick 30s-300s at random.
        if audio_duration <= 0:
            audio_duration = random.uniform(30.0, 300.0)
            logger.info(f"random audio duration: {audio_duration}")

        # 7. encode
        target_latents = self.text2music_diffusion_process(
            audio_duration,
            encoder_text_hidden_states=encoder_text_hidden_states,
            text_attention_mask=text_attention_mask,
            speaker_embds=speaker_embeds,
            lyric_token_ids=lyric_token_idx,
            lyric_mask=lyric_mask,
            guidance_scale=guidance_scale,
            omega_scale=omega_scale,
            infer_steps=infer_step,
            random_generators=random_generators,
            scheduler_type=scheduler_type,
            cfg_type=cfg_type,
        )

        # 8 latents2audio
        output_paths = self.latents2audio(latents=target_latents, target_wav_duration_second=audio_duration)
        # Record every knob used for this generation so the UI can display /
        # reproduce it later.
        if input_params_json is None:
            input_params_json = {}
        input_params_json["prompt"] = prompt
        input_params_json["lyrics"] = lyrics
        input_params_json["infer_steps"] = infer_step
        input_params_json["guidance_scale"] = guidance_scale
        input_params_json["manual_seeds"] = manual_seeds
        input_params_json["actual_seeds"] = actual_seeds
        input_params_json["checkpoint_path"] = self.checkpoint_path
        input_params_json["omega_scale"] = omega_scale
        input_params_json["scheduler_type"] = scheduler_type
        input_params_json["cfg_type"] = cfg_type
        input_params_json["audio_duration"] = audio_duration
        logger.info(json.dumps(input_params_json, indent=4, ensure_ascii=False))

        return output_paths + [input_params_json]
|
||||
|
||||
|
||||
def main(args):
    """Build the InferDemo model wrapper and launch the gradio demo UI."""

    model_demo = InferDemo(args)

    demo = create_main_demo_ui(
        checkpoint_path=args.checkpoint_path,
        text2music_process_func=model_demo.process_text2music,
    )
    # Listen on all interfaces; authentication is delegated to the shared
    # same_auth callback used by the other demo UIs.
    demo.launch(
        server_name="0.0.0.0",
        server_port=args.port,
        auth=same_auth,
        share=args.share
    )


if __name__ == "__main__":
    # `args` was parsed at module import time (top of file).
    main(args)
|
||||
866
language_segmentation/LangSegment.py
Normal file
866
language_segmentation/LangSegment.py
Normal file
@@ -0,0 +1,866 @@
|
||||
"""
|
||||
This file bundles language identification functions.
|
||||
|
||||
Modifications (fork): Copyright (c) 2021, Adrien Barbaresi.
|
||||
|
||||
Original code: Copyright (c) 2011 Marco Lui <saffsd@gmail.com>.
|
||||
Based on research by Marco Lui and Tim Baldwin.
|
||||
|
||||
See LICENSE file for more info.
|
||||
https://github.com/adbar/py3langid
|
||||
|
||||
Projects:
|
||||
https://github.com/juntaosun/LangSegment
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import numpy as np
|
||||
from collections import Counter
|
||||
from collections import defaultdict
|
||||
|
||||
# import langid
|
||||
# import py3langid as langid
|
||||
# pip install py3langid==0.2.2
|
||||
|
||||
# 启用语言预测概率归一化,概率预测的分数。因此,实现重新规范化 产生 0-1 范围内的输出。
|
||||
# langid disables probability normalization by default. For command-line usages of , it can be enabled by passing the flag.
|
||||
# For probability normalization in library use, the user must instantiate their own . An example of such usage is as follows:
|
||||
from py3langid.langid import LanguageIdentifier, MODEL_FILE
|
||||
|
||||
# Digital processing
|
||||
try:from .utils.num import num2str
|
||||
except ImportError:
|
||||
try:from utils.num import num2str
|
||||
except ImportError as e:
|
||||
raise e
|
||||
|
||||
# -----------------------------------
|
||||
# 更新日志:新版本分词更加精准。
|
||||
# Changelog: The new version of the word segmentation is more accurate.
|
||||
# チェンジログ:新しいバージョンの単語セグメンテーションはより正確です。
|
||||
# Changelog: 분할이라는 단어의 새로운 버전이 더 정확합니다.
|
||||
# -----------------------------------
|
||||
|
||||
|
||||
# Word segmentation function:
|
||||
# automatically identify and split the words (Chinese/English/Japanese/Korean) in the article or sentence according to different languages,
|
||||
# making it more suitable for TTS processing.
|
||||
# This code is designed for front-end text multi-lingual mixed annotation distinction, multi-language mixed training and inference of various TTS projects.
|
||||
# This processing result is mainly for (Chinese = zh, Japanese = ja, English = en, Korean = ko), and can actually support up to 97 different language mixing processing.
|
||||
|
||||
#===========================================================================================================
|
||||
#分かち書き機能:文章や文章の中の例えば(中国語/英語/日本語/韓国語)を、異なる言語で自動的に認識して分割し、TTS処理により適したものにします。
|
||||
#このコードは、さまざまなTTSプロジェクトのフロントエンドテキストの多言語混合注釈区別、多言語混合トレーニング、および推論のために特別に作成されています。
|
||||
#===========================================================================================================
|
||||
#(1)自動分詞:「韓国語では何を読むのですかあなたの体育の先生は誰ですか?今回の発表会では、iPhone 15シリーズの4機種が登場しました」
|
||||
#(2)手动分词:“あなたの名前は<ja>佐々木ですか?<ja>ですか?”
|
||||
#この処理結果は主に(中国語=ja、日本語=ja、英語=en、韓国語=ko)を対象としており、実際には最大97の異なる言語の混合処理をサポートできます。
|
||||
#===========================================================================================================
|
||||
|
||||
#===========================================================================================================
|
||||
# 단어 분할 기능: 기사 또는 문장에서 단어(중국어/영어/일본어/한국어)를 다른 언어에 따라 자동으로 식별하고 분할하여 TTS 처리에 더 적합합니다.
|
||||
# 이 코드는 프런트 엔드 텍스트 다국어 혼합 주석 분화, 다국어 혼합 교육 및 다양한 TTS 프로젝트의 추론을 위해 설계되었습니다.
|
||||
#===========================================================================================================
|
||||
# (1) 자동 단어 분할: "한국어로 무엇을 읽습니까? 스포츠 씨? 이 컨퍼런스는 4개의 iPhone 15 시리즈 모델을 제공합니다."
|
||||
# (2) 수동 참여: "이름이 <ja>Saki입니까? <ja>?"
|
||||
# 이 처리 결과는 주로 (중국어 = zh, 일본어 = ja, 영어 = en, 한국어 = ko)를 위한 것이며 실제로 혼합 처리를 위해 최대 97개의 언어를 지원합니다.
|
||||
#===========================================================================================================
|
||||
|
||||
# ===========================================================================================================
|
||||
# 分词功能:将文章或句子里的例如(中/英/日/韩),按不同语言自动识别并拆分,让它更适合TTS处理。
|
||||
# 本代码专为各种 TTS 项目的前端文本多语种混合标注区分,多语言混合训练和推理而编写。
|
||||
# ===========================================================================================================
|
||||
# (1)自动分词:“韩语中的오빠读什么呢?あなたの体育の先生は誰ですか? 此次发布会带来了四款iPhone 15系列机型”
|
||||
# (2)手动分词:“你的名字叫<ja>佐々木?<ja>吗?”
|
||||
# 本处理结果主要针对(中文=zh , 日文=ja , 英文=en , 韩语=ko), 实际上可支持多达 97 种不同的语言混合处理。
|
||||
# ===========================================================================================================
|
||||
|
||||
|
||||
# 手动分词标签规范:<语言标签>文本内容</语言标签>
|
||||
# 수동 단어 분할 태그 사양: <언어 태그> 텍스트 내용</언어 태그>
|
||||
# Manual word segmentation tag specification: <language tags> text content </language tags>
|
||||
# 手動分詞タグ仕様:<言語タグ>テキスト内容</言語タグ>
|
||||
# ===========================================================================================================
|
||||
# For manual word segmentation, labels need to appear in pairs, such as:
|
||||
# 如需手动分词,标签需要成对出现,例如:“<ja>佐々木<ja>” 或者 “<ja>佐々木</ja>”
|
||||
# 错误示范:“你的名字叫<ja>佐々木。” 此句子中出现的单个<ja>标签将被忽略,不会处理。
|
||||
# Error demonstration: "Your name is <ja>佐々木。" Single <ja> tags that appear in this sentence will be ignored and will not be processed.
|
||||
# ===========================================================================================================
|
||||
|
||||
|
||||
# ===========================================================================================================
|
||||
# 语音合成标记语言 SSML , 这里只支持它的标签(非 XML)Speech Synthesis Markup Language SSML, only its tags are supported here (not XML)
|
||||
# 想支持更多的 SSML 标签?欢迎 PR! Want to support more SSML tags? PRs are welcome!
|
||||
# 说明:除了中文以外,它也可改造成支持多语种 SSML ,不仅仅是中文。
|
||||
# Note: In addition to Chinese, it can also be modified to support multi-language SSML, not just Chinese.
|
||||
# ===========================================================================================================
|
||||
# 中文实现:Chinese implementation:
|
||||
# 【SSML】<number>=中文大写数字读法(单字)
|
||||
# 【SSML】<telephone>=数字转成中文电话号码大写汉字(单字)
|
||||
# 【SSML】<currency>=按金额发音。
|
||||
# 【SSML】<date>=按日期发音。支持 2024年08月24, 2024/8/24, 2024-08, 08-24, 24 等输入。
|
||||
# ===========================================================================================================
|
||||
class LangSSML:
    """Renders SSML-style tags as spoken Chinese text.

    Supported tags: <number> (digit-by-digit reading), <telephone>
    (phone-number reading), <currency> (amount reading via num2str),
    <date> (date/time reading). Only the tag semantics are implemented,
    not full XML SSML.
    """

    def __init__(self):
        # Plain digits -> Chinese numeral characters (single-character readings).
        self._zh_numerals_number = {
            '0': '零',
            '1': '一',
            '2': '二',
            '3': '三',
            '4': '四',
            '5': '五',
            '6': '六',
            '7': '七',
            '8': '八',
            '9': '九'
        }

    # Standardize 2024/8/24, 2024-08, 08-24, 24 into a "year-month-day" reading.
    def _format_chinese_data(self, date_str:str):
        # Normalize the date/time string and render it with Chinese unit characters.
        input_date = date_str
        if date_str is None or date_str.strip() == "":return ""
        # NOTE(review): the `|` characters inside this character class are
        # literal pipes, not alternation — the class is /, ., _, |, 年, 月.
        # Harmless for normal input but worth confirming.
        date_str = re.sub(r"[\/\._|年|月]","-",date_str)
        date_str = re.sub(r"日",r"",date_str)
        date_arrs = date_str.split(' ')
        # A single colon-containing field is a bare time ("HH:MM"), no date part.
        if len(date_arrs) == 1 and ":" in date_arrs[0]:
            time_str = date_arrs[0]
            date_arrs = []
        else:
            time_str = date_arrs[1] if len(date_arrs) >=2 else ""
        # Append unit `cn` only for non-empty, non-zero fields; `func` converts
        # the digits to their Chinese reading first.
        def nonZero(num,cn,func=None):
            if func is not None:num=func(num)
            return f"{num}{cn}" if num is not None and num != "" and num != "0" else ""
        f_number = self.to_chinese_number
        f_currency = self.to_chinese_currency
        # year, month, day
        year_month_day = ""
        if len(date_arrs) > 0:
            year, month, day = "","",""
            parts = date_arrs[0].split('-')
            if len(parts) == 3: # format YYYY-MM-DD
                year, month, day = parts
            elif len(parts) == 2: # format MM-DD or YYYY-MM
                if len(parts[0]) == 4: # year-month
                    year, month = parts
                else:month, day = parts # month-day
            elif len(parts[0]) > 0: # only a single field: year or day
                if len(parts[0]) == 4:
                    year = parts[0]
                else:day = parts[0]
            # Year is read digit-by-digit; month/day as quantities.
            year,month,day = nonZero(year,"年",f_number),nonZero(month,"月",f_currency),nonZero(day,"日",f_currency)
            # Collapse accidentally repeated unit characters.
            year_month_day = re.sub(r"([年|月|日])+",r"\1",f"{year}{month}{day}")
        # hours, minutes, seconds
        time_str = re.sub(r"[\/\.\-:_]",":",time_str)
        time_arrs = time_str.split(":")
        hours, minutes, seconds = "","",""
        if len(time_arrs) == 3: # H/M/S
            hours, minutes, seconds = time_arrs
        elif len(time_arrs) == 2:# H/M
            hours, minutes = time_arrs
        elif len(time_arrs[0]) > 0:hours = f'{time_arrs[0]}点' # H only
        if len(time_arrs) > 1:
            hours, minutes, seconds = nonZero(hours,"点",f_currency),nonZero(minutes,"分",f_currency),nonZero(seconds,"秒",f_currency)
        hours_minutes_seconds = re.sub(r"([点|分|秒])+",r"\1",f"{hours}{minutes}{seconds}")
        output_date = f"{year_month_day}{hours_minutes_seconds}"
        return output_date

    # <number>: read each digit as its Chinese numeral (single characters).
    def to_chinese_number(self, num:str):
        pattern = r'(\d+)'
        zh_numerals = self._zh_numerals_number
        arrs = re.split(pattern, num)
        output = ""
        for item in arrs:
            if re.match(pattern,item):
                output += ''.join(zh_numerals[digit] if digit in zh_numerals else "" for digit in str(item))
            else:output += item
        output = output.replace(".","点")
        return output

    # <telephone>: digits read as a Chinese phone number.
    def to_chinese_telephone(self, num:str):
        output = self.to_chinese_number(num.replace("+86","")) # strip the +86 country code
        output = output.replace("一","幺") # "1" is read "yao" in phone numbers
        return output

    # <currency>: read numeric runs as amounts.
    # Digit processing comes from GPT_SoVITS num.py (thanks).
    def to_chinese_currency(self, num:str):
        pattern = r'(\d+)'
        arrs = re.split(pattern, num)
        output = ""
        for item in arrs:
            if re.match(pattern,item):
                output += num2str(item)
            else:output += item
        output = output.replace(".","点")
        return output

    # <date>: read as a date. Accepts 2024年08月24, 2024/8/24, 2024-08, 08-24, 24, etc.
    def to_chinese_date(self, num:str):
        chinese_date = self._format_chinese_data(num)
        return chinese_date
|
||||
|
||||
|
||||
class LangSegment:
|
||||
|
||||
def __init__(self):
|
||||
|
||||
self.langid = LanguageIdentifier.from_pickled_model(MODEL_FILE, norm_probs=True)
|
||||
|
||||
self._text_cache = None
|
||||
self._text_lasts = None
|
||||
self._text_langs = None
|
||||
self._lang_count = None
|
||||
self._lang_eos = None
|
||||
|
||||
# 可自定义语言匹配标签:カスタマイズ可能な言語対応タグ:사용자 지정 가능한 언어 일치 태그:
|
||||
# Customizable language matching tags: These are supported,이 표현들은 모두 지지합니다
|
||||
# <zh>你好<zh> , <ja>佐々木</ja> , <en>OK<en> , <ko>오빠</ko> 这些写法均支持
|
||||
self.SYMBOLS_PATTERN = r'(<([a-zA-Z|-]*)>(.*?)<\/*[a-zA-Z|-]*>)'
|
||||
|
||||
# 语言过滤组功能, 可以指定保留语言。不在过滤组中的语言将被清除。您可随心搭配TTS语音合成所支持的语言。
|
||||
# 언어 필터 그룹 기능을 사용하면 예약된 언어를 지정할 수 있습니다. 필터 그룹에 없는 언어는 지워집니다. TTS 텍스트에서 지원하는 언어를 원하는 대로 일치시킬 수 있습니다.
|
||||
# 言語フィルターグループ機能では、予約言語を指定できます。フィルターグループに含まれていない言語はクリアされます。TTS音声合成がサポートする言語を自由に組み合わせることができます。
|
||||
# The language filter group function allows you to specify reserved languages.
|
||||
# Languages not in the filter group will be cleared. You can match the languages supported by TTS Text To Speech as you like.
|
||||
# 排名越前,优先级越高,The higher the ranking, the higher the priority,ランキングが上位になるほど、優先度が高くなります。
|
||||
|
||||
# 系统默认过滤器。System default filter。(ISO 639-1 codes given)
|
||||
# ----------------------------------------------------------------------------------------------------------------------------------
|
||||
# "zh"中文=Chinese ,"en"英语=English ,"ja"日语=Japanese ,"ko"韩语=Korean ,"fr"法语=French ,"vi"越南语=Vietnamese , "ru"俄语=Russian
|
||||
# "th"泰语=Thai
|
||||
# ----------------------------------------------------------------------------------------------------------------------------------
|
||||
self.DEFAULT_FILTERS = ["zh", "ja", "ko", "en"]
|
||||
|
||||
# 用户可自定义过滤器。User-defined filters
|
||||
self.Langfilters = self.DEFAULT_FILTERS[:] # 创建副本
|
||||
|
||||
# 合并文本
|
||||
self.isLangMerge = True
|
||||
|
||||
# 试验性支持:您可自定义添加:"fr"法语 , "vi"越南语。Experimental: You can customize to add: "fr" French, "vi" Vietnamese.
|
||||
# 请使用API启用:self.setfilters(["zh", "en", "ja", "ko", "fr", "vi" , "ru" , "th"]) # 您可自定义添加,如:"fr"法语 , "vi"越南语。
|
||||
|
||||
# 预览版功能,自动启用或禁用,无需设置
|
||||
# Preview feature, automatically enabled or disabled, no settings required
|
||||
self.EnablePreview = False
|
||||
|
||||
# 除此以外,它支持简写过滤器,只需按不同语种任意组合即可。
|
||||
# In addition to that, it supports abbreviation filters, allowing for any combination of different languages.
|
||||
# 示例:您可以任意指定多种组合,进行过滤
|
||||
# Example: You can specify any combination to filter
|
||||
|
||||
# 中/日语言优先级阀值(评分范围为 0 ~ 1):评分低于设定阀值 <0.89 时,启用 filters 中的优先级。\n
|
||||
# 중/일본어 우선 순위 임계값(점수 범위 0-1): 점수가 설정된 임계값 <0.89보다 낮을 때 필터에서 우선 순위를 활성화합니다.
|
||||
# 中国語/日本語の優先度しきい値(スコア範囲0〜1):スコアが設定されたしきい値<0.89未満の場合、フィルターの優先度が有効になります。\n
|
||||
# Chinese and Japanese language priority threshold (score range is 0 ~ 1): The default threshold is 0.89. \n
|
||||
# Only the common characters between Chinese and Japanese are processed with confidence and priority. \n
|
||||
self.LangPriorityThreshold = 0.89
|
||||
|
||||
# Langfilters = ["zh"] # 按中文识别
|
||||
# Langfilters = ["en"] # 按英文识别
|
||||
# Langfilters = ["ja"] # 按日文识别
|
||||
# Langfilters = ["ko"] # 按韩文识别
|
||||
# Langfilters = ["zh_ja"] # 中日混合识别
|
||||
# Langfilters = ["zh_en"] # 中英混合识别
|
||||
# Langfilters = ["ja_en"] # 日英混合识别
|
||||
# Langfilters = ["zh_ko"] # 中韩混合识别
|
||||
# Langfilters = ["ja_ko"] # 日韩混合识别
|
||||
# Langfilters = ["en_ko"] # 英韩混合识别
|
||||
# Langfilters = ["zh_ja_en"] # 中日英混合识别
|
||||
# Langfilters = ["zh_ja_en_ko"] # 中日英韩混合识别
|
||||
|
||||
# 更多过滤组合,请您随意。。。For more filter combinations, please feel free to......
|
||||
# より多くのフィルターの組み合わせ、お気軽に。。。더 많은 필터 조합을 원하시면 자유롭게 해주세요. .....
|
||||
|
||||
# 可选保留:支持中文数字拼音格式,更方便前端实现拼音音素修改和推理,默认关闭 False 。
|
||||
# 开启后 True ,括号内的数字拼音格式均保留,并识别输出为:"zh"中文。
|
||||
self.keepPinyin = False
|
||||
|
||||
# DEFINITION
|
||||
self.PARSE_TAG = re.compile(r'(⑥\$*\d+[\d]{6,}⑥)')
|
||||
|
||||
self.LangSSML = LangSSML()
|
||||
|
||||
def _clears(self):
    """Reset all per-parse cached state so a fresh text can be processed."""
    self._text_cache = None   # placeholder-tag -> (handler, data) lookup built during parsing
    self._text_lasts = None   # last input text, used to short-circuit repeated calls
    self._text_langs = None   # last parse result (list of {"lang","text",...} dicts)
    self._text_waits = None   # pending ambiguous segments (e.g. "zh|ja") awaiting resolution
    self._lang_count = None   # per-language character statistics
    self._lang_eos = None     # whether the current segment is at end-of-sentence
|
||||
|
||||
def _is_english_word(self, word):
|
||||
return bool(re.match(r'^[a-zA-Z]+$', word))
|
||||
|
||||
def _is_chinese(self, word):
|
||||
for char in word:
|
||||
if '\u4e00' <= char <= '\u9fff':
|
||||
return True
|
||||
return False
|
||||
|
||||
def _is_japanese_kana(self, word):
|
||||
pattern = re.compile(r'[\u3040-\u309F\u30A0-\u30FF]+')
|
||||
matches = pattern.findall(word)
|
||||
return len(matches) > 0
|
||||
|
||||
def _insert_english_uppercase(self, word):
|
||||
modified_text = re.sub(r'(?<!\b)([A-Z])', r' \1', word)
|
||||
modified_text = modified_text.strip('-')
|
||||
return modified_text + " "
|
||||
|
||||
def _split_camel_case(self, word):
|
||||
return re.sub(r'(?<!^)(?=[A-Z])', ' ', word)
|
||||
|
||||
def _statistics(self, language, text):
|
||||
# Language word statistics:
|
||||
# Chinese characters usually occupy double bytes
|
||||
if self._lang_count is None or not isinstance(self._lang_count, defaultdict):
|
||||
self._lang_count = defaultdict(int)
|
||||
lang_count = self._lang_count
|
||||
if not "|" in language:
|
||||
lang_count[language] += int(len(text)*2) if language == "zh" else len(text)
|
||||
self._lang_count = lang_count
|
||||
|
||||
def _clear_text_number(self, text):
|
||||
if text == "\n":return text,False # Keep Line Breaks
|
||||
clear_text = re.sub(r'([^\w\s]+)','',re.sub(r'\n+','',text)).strip()
|
||||
is_number = len(re.sub(re.compile(r'(\d+)'),'',clear_text)) == 0
|
||||
return clear_text,is_number
|
||||
|
||||
def _saveData(self, words,language:str,text:str,score:float,symbol=None):
    """Merge *text* into the running word list, coalescing with the
    previous entry when languages match, then append it if it passes
    the active language filters.

    Returns the dict that now holds *text* (either the merged previous
    entry or the freshly built one).
    """
    # Pre-detection: punctuation-free text and whether it is digits-only.
    clear_text , is_number = self._clear_text_number(text)
    # Merge the same language and save the results
    preData = words[-1] if len(words) > 0 else None
    if symbol is not None:pass  # explicit symbol segments are never re-labelled
    elif preData is not None and preData["symbol"] is None:
        # Empty or numeric fragments inherit the previous segment's language.
        if len(clear_text) == 0:language = preData["lang"]
        elif is_number == True:language = preData["lang"]
        _ , pre_is_number = self._clear_text_number(preData["text"])
        if (preData["lang"] == language):
            # Same language: extend the previous entry in place.
            self._statistics(preData["lang"],text)
            text = preData["text"] + text
            preData["text"] = text
            return preData
        elif pre_is_number == True:
            # Previous entry was a bare number: fold it into this segment.
            text = f'{preData["text"]}{text}'
            words.pop()
    elif is_number == True:
        # Leading bare number: label it with the highest-priority filter language.
        priority_language = self._get_filters_string()[:2]
        if priority_language in "ja-zh-en-ko-fr-vi":language = priority_language
    data = {"lang":language,"text": text,"score":score,"symbol":symbol}
    filters = self.Langfilters
    # Accept when: no filters configured, language still ambiguous ("?"),
    # language is allowed by the filters, or a wildcard/mixed filter is active.
    if filters is None or len(filters) == 0 or "?" in language or \
        language in filters or language in filters[0] or \
        filters[0] == "*" or filters[0] in "alls-mixs-autos":
        words.append(data)
        self._statistics(data["lang"],data["text"])
    return data
|
||||
|
||||
def _addwords(self, words,language,text,score,symbol=None):
    """Route one classified segment into *words*, resolving any pending
    ambiguous ("zh|ja"-style) segment first.

    Returns True when the segment is empty and was discarded, False
    otherwise.
    """
    if text == "\n":pass # Keep Line Breaks
    elif text is None or len(text.strip()) == 0:return True
    if language is None:language = ""
    language = language.lower()
    if language == 'en':text = self._insert_english_uppercase(text)
    # text = re.sub(r'[(())]', ',' , text) # Keep it.
    text_waits = self._text_waits
    ispre_waits = len(text_waits)>0
    preResult = text_waits.pop() if ispre_waits else None
    if preResult is None:preResult = words[-1] if len(words) > 0 else None
    # Resolve a pending ambiguous segment using the current language as a hint.
    if preResult and ("|" in preResult["lang"]):
        pre_lang = preResult["lang"]
        if language in pre_lang:preResult["lang"] = language = language.split("|")[0]
        else:preResult["lang"]=pre_lang.split("|")[0]
        if ispre_waits:preResult = self._saveData(words,preResult["lang"],preResult["text"],preResult["score"],preResult["symbol"])
    pre_lang = preResult["lang"] if preResult else None
    # An ambiguous label collapses to its first choice when it contradicts the previous segment.
    if ("|" in language) and (pre_lang and not pre_lang in language and not "…" in language):language = language.split("|")[0]
    # Still ambiguous: park it until a neighbouring segment disambiguates it.
    if "|" in language:self._text_waits.append({"lang":language,"text": text,"score":score,"symbol":symbol})
    else:self._saveData(words,language,text,score,symbol)
    return False
|
||||
|
||||
def _get_prev_data(self, words):
|
||||
data = words[-1] if words and len(words) > 0 else None
|
||||
if data:return (data["lang"] , data["text"])
|
||||
return (None,"")
|
||||
|
||||
def _match_ending(self, input , index):
|
||||
if input is None or len(input) == 0:return False,None
|
||||
input = re.sub(r'\s+', '', input)
|
||||
if len(input) == 0 or abs(index) > len(input):return False,None
|
||||
ending_pattern = re.compile(r'([「」“”‘’"\'::。.!!?.?])')
|
||||
return ending_pattern.match(input[index]),input[index]
|
||||
|
||||
def _cleans_text(self, cleans_text):
|
||||
cleans_text = re.sub(r'(.*?)([^\w]+)', r'\1 ', cleans_text)
|
||||
cleans_text = re.sub(r'(.)\1+', r'\1', cleans_text)
|
||||
return cleans_text.strip()
|
||||
|
||||
def _mean_processing(self, text:str):
    """Fallback classifier: split camel-case tokens and vote per token.

    Returns (majority_language, 1.0), or (None, 0.0) when no token is
    long enough (> 3 chars) to classify.
    """
    if text is None or (text.strip()) == "":return None , 0.0
    arrs = self._split_camel_case(text).split(" ")
    langs = []
    for t in arrs:
        if len(t.strip()) <= 3:continue  # too short for a reliable classification
        language, score = self.langid.classify(t)
        langs.append({"lang":language})
    if len(langs) == 0:return None , 0.0
    # Majority vote across the classified tokens; per-token scores are discarded.
    return Counter([item['lang'] for item in langs]).most_common(1)[0][0],1.0
|
||||
|
||||
def _lang_classify(self, cleans_text):
    """Classify *cleans_text* with the underlying langid model.

    Returns (language, score) with the score coerced to a plain float
    rounded to 3 decimals.
    """
    language, score = self.langid.classify(cleans_text)
    # fix: Huggingface is np.float32 — unwrap numpy scalars to plain Python floats.
    if score is not None and isinstance(score, np.generic) and hasattr(score,"item"):
        score = score.item()
    score = round(score , 3)
    return language, score
|
||||
|
||||
def _get_filters_string(self):
|
||||
filters = self.Langfilters
|
||||
return "-".join(filters).lower().strip() if filters is not None else ""
|
||||
|
||||
def _parse_language(self, words , segment):
    """Split *segment* at punctuation and classify each piece, with
    dedicated handling for the zh/ja ambiguity (shared Han characters).
    """
    LANG_JA = "ja"
    LANG_ZH = "zh"
    LANG_ZH_JA = f'{LANG_ZH}|{LANG_JA}'
    LANG_JA_ZH = f'{LANG_JA}|{LANG_ZH}'
    language = LANG_ZH
    regex_pattern = re.compile(r'([^\w\s]+)')
    lines = regex_pattern.split(segment)
    lines_max = len(lines)
    LANG_EOS =self._lang_eos
    for index, text in enumerate(lines):
        if len(text) == 0:continue
        EOS = index >= (lines_max - 1)
        nextId = index + 1
        nextText = lines[nextId] if not EOS else ""
        # Punctuation-only pieces are glued onto their neighbour.
        nextPunc = len(re.sub(regex_pattern,'',re.sub(r'\n+','',nextText)).strip()) == 0
        textPunc = len(re.sub(regex_pattern,'',re.sub(r'\n+','',text)).strip()) == 0
        if not EOS and (textPunc == True or ( len(nextText.strip()) >= 0 and nextPunc == True)):
            lines[nextId] = f'{text}{nextText}'
            continue
        # Strip placeholder tags and digits before classification.
        number_tags = re.compile(r'(⑥\d{6,}⑥)')
        cleans_text = re.sub(number_tags, '' ,text)
        cleans_text = re.sub(r'\d+', '' ,cleans_text)
        cleans_text = self._cleans_text(cleans_text)
        # fix:Langid's recognition of short sentences is inaccurate, and it is spliced longer.
        if not EOS and len(cleans_text) <= 2:
            lines[nextId] = f'{text}{nextText}'
            continue
        language,score = self._lang_classify(cleans_text)
        prev_language , prev_text = self._get_prev_data(words)
        # All-Han text is forced to Chinese regardless of the classifier's output.
        if language != LANG_ZH and all('\u4e00' <= c <= '\u9fff' for c in re.sub(r'\s','',cleans_text)):language,score = LANG_ZH,1
        if len(cleans_text) <= 5 and self._is_chinese(cleans_text):
            filters_string = self._get_filters_string()
            # Low-confidence short Han text: let the filter order decide zh vs ja.
            if score < self.LangPriorityThreshold and len(filters_string) > 0:
                index_ja , index_zh = filters_string.find(LANG_JA) , filters_string.find(LANG_ZH)
                if index_ja != -1 and index_ja < index_zh:language = LANG_JA
                elif index_zh != -1 and index_zh < index_ja:language = LANG_ZH
            if self._is_japanese_kana(cleans_text):language = LANG_JA  # kana is unambiguous
            elif len(cleans_text) > 2 and score > 0.90:pass
            elif EOS and LANG_EOS:language = LANG_ZH if len(cleans_text) <= 1 else language
            else:
                # Still ambiguous: emit a "zh|ja"-style label for later resolution.
                LANG_UNKNOWN = LANG_ZH_JA if language == LANG_ZH or (len(cleans_text) <=2 and prev_language == LANG_ZH) else LANG_JA_ZH
                match_end,match_char = self._match_ending(text, -1)
                referen = prev_language in LANG_UNKNOWN or LANG_UNKNOWN in prev_language if prev_language else False
                if match_char in "。.": language = prev_language if referen and len(words) > 0 else language
                else:language = f"{LANG_UNKNOWN}|…"
        # Put the original numbers back before recording the segment.
        text,*_ = re.subn(number_tags , self._restore_number , text )
        self._addwords(words,language,text,score)
|
||||
|
||||
# ----------------------------------------------------------
# [SSML] Chinese number processing (SSML support).
# Everything here defaults to Chinese; other languages could be handled
# the same way.  Examples:
#   Chinese telephone number: <telephone>1234567</telephone>
#   Chinese number reading:   <number>1234567</number>
def _process_symbol_SSML(self, words,data):
    """Expand an SSML tag (<telephone>/<number>/<currency>/<date>) into
    its spoken-Chinese form and record it as a "zh" segment."""
    tag , match = data
    language = SSML = match[1]
    text = match[2]
    score = 1.0
    if SSML == "telephone":
        # Chinese: read as a telephone number
        language = "zh"
        text = self.LangSSML.to_chinese_telephone(text)
    elif SSML == "number":
        # Chinese: read the digit sequence as a number
        language = "zh"
        text = self.LangSSML.to_chinese_number(text)
    elif SSML == "currency":
        # Chinese: read as a monetary amount
        language = "zh"
        text = self.LangSSML.to_chinese_currency(text)
    elif SSML == "date":
        # Chinese: read as a date (original comment wrongly said "amount")
        language = "zh"
        text = self.LangSSML.to_chinese_date(text)
    self._addwords(words,language,text,score,SSML)
|
||||
|
||||
# ----------------------------------------------------------
|
||||
def _restore_number(self, matche):
|
||||
value = matche.group(0)
|
||||
text_cache = self._text_cache
|
||||
if value in text_cache:
|
||||
process , data = text_cache[value]
|
||||
tag , match = data
|
||||
value = match
|
||||
return value
|
||||
|
||||
def _pattern_symbols(self, item , text):
    """Replace every match of *pattern* in *text* with a unique ⑥tagNNNNNN⑥
    placeholder and remember (handler, match) in the text cache.

    When the entire text is one single match it is left untouched so its
    dedicated handler can run on it directly later.
    """
    if text is None:return text
    tag , pattern , process = item
    matches = pattern.findall(text)
    if len(matches) == 1 and "".join(matches[0]) == text:
        return text
    for i , match in enumerate(matches):
        key = f"⑥{tag}{i:06d}⑥"
        # Replace one occurrence at a time so the i-th key maps to the i-th match.
        text = re.sub(pattern , key , text , count=1)
        self._text_cache[key] = (process , (tag , match))
    return text
|
||||
|
||||
def _process_symbol(self, words,data):
    """Handle an explicit <lang>…</lang> tag: honour it when the language
    is in the active filters, otherwise try to interpret it as an SSML tag."""
    tag , match = data
    language = match[1]
    text = match[2]
    score = 1.0
    filters = self._get_filters_string()
    if language not in filters:
        self._process_symbol_SSML(words,data)
    else:
        self._addwords(words,language,text,score,True)
|
||||
|
||||
def _process_english(self, words,data):
    """Handle a Latin-script segment.

    Default: label it "en" with full confidence.  With EnablePreview,
    split on sentence enders and re-classify each sentence so other
    Latin-script languages (fr/vi/...) can be recognised.
    """
    tag , match = data
    text = match[0]
    filters = self._get_filters_string()
    priority_language = filters[:2]  # first filter entry wins ties
    # Preview feature, other language segmentation processing
    enablePreview = self.EnablePreview
    if enablePreview == True:
        # Experimental: Other language support
        regex_pattern = re.compile(r'(.*?[。.??!!]+[\n]{,1})')
        lines = regex_pattern.split(text)
        for index , text in enumerate(lines):
            if len(text.strip()) == 0:continue
            cleans_text = self._cleans_text(text)
            language,score = self._lang_classify(cleans_text)
            if language not in filters:
                # Fall back to the token-voting classifier.
                language,score = self._mean_processing(cleans_text)
            if language is None or score <= 0.0:continue
            elif language in filters:pass # pass
            elif score >= 0.95:continue # High score, but not in the filter, excluded.
            elif score <= 0.15 and filters[:2] == "fr":language = priority_language
            else:language = "en"
            self._addwords(words,language,text,score)
    else:
        # Default is English
        language, score = "en", 1.0
        self._addwords(words,language,text,score)
|
||||
|
||||
def _process_Russian(self, words,data):
|
||||
tag , match = data
|
||||
text = match[0]
|
||||
language = "ru"
|
||||
score = 1.0
|
||||
self._addwords(words,language,text,score)
|
||||
|
||||
def _process_Thai(self, words,data):
|
||||
tag , match = data
|
||||
text = match[0]
|
||||
language = "th"
|
||||
score = 1.0
|
||||
self._addwords(words,language,text,score)
|
||||
|
||||
def _process_korean(self, words,data):
|
||||
tag , match = data
|
||||
text = match[0]
|
||||
language = "ko"
|
||||
score = 1.0
|
||||
self._addwords(words,language,text,score)
|
||||
|
||||
def _process_quotes(self, words,data):
    """Handle a quoted span: recurse when it still contains placeholder
    tags, classify it whole when long enough, otherwise fall back to
    fine-grained parsing."""
    tag , match = data
    text = "".join(match)
    childs = self.PARSE_TAG.findall(text)
    if len(childs) > 0:
        # Nested placeholders: process the quoted text as a sub-document.
        self._process_tags(words , text , False)
    else:
        # match[1] is the quote's inner text (without the quote marks).
        cleans_text = self._cleans_text(match[1])
        if len(cleans_text) <= 5:
            # Too short to classify reliably as a whole.
            self._parse_language(words,text)
        else:
            language,score = self._lang_classify(cleans_text)
            self._addwords(words,language,text,score)
|
||||
|
||||
def _process_pinyin(self, words,data):
|
||||
tag , match = data
|
||||
text = match
|
||||
language = "zh"
|
||||
score = 1.0
|
||||
self._addwords(words,language,text,score)
|
||||
|
||||
def _process_number(self, words,data): # "$0" process only
|
||||
"""
|
||||
Numbers alone cannot accurately identify language.
|
||||
Because numbers are universal in all languages.
|
||||
So it won't be executed here, just for testing.
|
||||
"""
|
||||
tag , match = data
|
||||
language = words[0]["lang"] if len(words) > 0 else "zh"
|
||||
text = match
|
||||
score = 0.0
|
||||
self._addwords(words,language,text,score)
|
||||
|
||||
def _process_tags(self, words , text , root_tag):
    """Walk *text* split on ⑥…⑥ placeholders, dispatching each
    placeholder to its cached handler and parsing plain spans inline.

    *root_tag* marks a top-level call, which maintains the
    end-of-sentence flag consumed by _parse_language.  Returns *words*.
    """
    text_cache = self._text_cache
    segments = re.split(self.PARSE_TAG, text)
    segments_len = len(segments) - 1
    for index , text in enumerate(segments):
        if root_tag:self._lang_eos = index >= segments_len
        if self.PARSE_TAG.match(text):
            # Placeholder: run the handler that was registered for it.
            process , data = text_cache[text]
            if process:process(words , data)
        else:
            self._parse_language(words , text)
    return words
|
||||
|
||||
def _merge_results(self, words):
|
||||
new_word = []
|
||||
for index , cur_data in enumerate(words):
|
||||
if "symbol" in cur_data:del cur_data["symbol"]
|
||||
if index == 0:new_word.append(cur_data)
|
||||
else:
|
||||
pre_data = new_word[-1]
|
||||
if cur_data["lang"] == pre_data["lang"]:
|
||||
pre_data["text"] = f'{pre_data["text"]}{cur_data["text"]}'
|
||||
else:new_word.append(cur_data)
|
||||
return new_word
|
||||
|
||||
def _parse_symbols(self, text):
    """Top-level parser: tag special spans (symbol tags, non-Latin scripts,
    quotes, numbers) with placeholders, then walk each line dispatching
    to the per-kind handlers.

    Returns the list of {"lang","text","score","symbol"} segments and
    refreshes the per-language statistics.
    """
    TAG_NUM = "00" # "00" => default channels , "$0" => testing channel
    TAG_S1,TAG_S2,TAG_P1,TAG_P2,TAG_EN,TAG_KO,TAG_RU,TAG_TH = "$1" ,"$2" ,"$3" ,"$4" ,"$5" ,"$6" ,"$7","$8"
    # Script-run template; "LANGUAGE" is substituted with a character class below.
    TAG_BASE = re.compile(fr'(([【《((“‘"\']*[LANGUAGE]+[\W\s]*)+)')
    # Get custom language filter
    filters = self.Langfilters
    filters = filters if filters is not None else ""
    # Experimental other-language support (fr/vi/ru/th).  Missing accented
    # characters for these languages can be contributed upstream.
    # Preview mode is switched on automatically when the filters ask for
    # French or Vietnamese.
    enablePreview = self.EnablePreview
    if "fr" in filters or \
        "vi" in filters:enablePreview = True
    self.EnablePreview = enablePreview
    # Experimental: French accented characters.
    RE_FR = "" if not enablePreview else "àáâãäåæçèéêëìíîïðñòóôõöùúûüýþÿ"
    # Experimental: Vietnamese characters.
    RE_VI = "" if not enablePreview else "đơưăáàảãạắằẳẵặấầẩẫậéèẻẽẹếềểễệíìỉĩịóòỏõọốồổỗộớờởỡợúùủũụứừửữựôâêơưỷỹ"
    # Basic options: (tag, pattern, handler) — applied in this order.
    process_list = [
        ( TAG_S1 , re.compile(self.SYMBOLS_PATTERN) , self._process_symbol ), # Symbol Tag
        ( TAG_KO , re.compile(re.sub(r'LANGUAGE',f'\uac00-\ud7a3',TAG_BASE.pattern)) , self._process_korean ), # Korean words
        ( TAG_TH , re.compile(re.sub(r'LANGUAGE',f'\u0E00-\u0E7F',TAG_BASE.pattern)) , self._process_Thai ), # Thai words support.
        ( TAG_RU , re.compile(re.sub(r'LANGUAGE',f'А-Яа-яЁё',TAG_BASE.pattern)) , self._process_Russian ), # Russian words support.
        ( TAG_NUM , re.compile(r'(\W*\d+\W+\d*\W*\d*)') , self._process_number ), # Number words, Universal in all languages, Ignore it.
        ( TAG_EN , re.compile(re.sub(r'LANGUAGE',f'a-zA-Z{RE_FR}{RE_VI}',TAG_BASE.pattern)) , self._process_english ), # English words + Other language support.
        ( TAG_P1 , re.compile(r'(["\'])(.*?)(\1)') , self._process_quotes ), # Regular quotes
        ( TAG_P2 , re.compile(r'([\n]*[【《((“‘])([^【《((“‘’”))》】]{3,})([’”))》】][\W\s]*[\n]{,1})') , self._process_quotes ), # Special quotes, There are left and right.
    ]
    # Extended options: Default False
    if self.keepPinyin == True:process_list.insert(1 ,
        ( TAG_S2 , re.compile(r'([\(({](?:\s*\w*\d\w*\s*)+[})\)])') , self._process_pinyin ), # Chinese Pinyin Tag.
    )
    words = []
    lines = re.findall(r'.*\n*', re.sub(self.PARSE_TAG, '' ,text))
    for index , text in enumerate(lines):
        if len(text.strip()) == 0:continue
        self._lang_eos = False
        self._text_cache = {}
        # First replace every special span with a placeholder tag...
        for item in process_list:
            text = self._pattern_symbols(item , text)
        # ...then dispatch placeholders plus the remaining plain text.
        cur_word = self._process_tags([] , text , True)
        if len(cur_word) == 0:continue
        cur_data = cur_word[0] if len(cur_word) > 0 else None
        pre_data = words[-1] if len(words) > 0 else None
        # Stitch a line-leading plain segment onto a same-language symbol
        # segment carried over from the previous line.
        if cur_data and pre_data and cur_data["lang"] == pre_data["lang"] \
            and cur_data["symbol"] == False and pre_data["symbol"] :
            cur_data["text"] = f'{pre_data["text"]}{cur_data["text"]}'
            words.pop()
        words += cur_word
    if self.isLangMerge == True:words = self._merge_results(words)
    lang_count = self._lang_count
    if lang_count and len(lang_count) > 0:
        # Expose statistics as a list of (lang, count) sorted descending.
        lang_count = dict(sorted(lang_count.items(), key=lambda x: x[1], reverse=True))
        lang_count = list(lang_count.items())
        self._lang_count = lang_count
    return words
|
||||
|
||||
def setfilters(self, filters):
    """Install a new language filter list.

    Any change invalidates all cached parse state, since previous
    results were produced under the old filters.
    """
    if self.Langfilters != filters:
        self._clears()
        self.Langfilters = filters
|
||||
|
||||
def getfilters(self):
    """Return the active language filter list (or None when unset)."""
    current = self.Langfilters
    return current
|
||||
|
||||
def setPriorityThreshold(self, threshold:float):
    """Set the zh/ja priority threshold (score range 0–1)."""
    self.LangPriorityThreshold = threshold
|
||||
|
||||
def getPriorityThreshold(self):
    """Return the current zh/ja priority threshold."""
    threshold = self.LangPriorityThreshold
    return threshold
|
||||
|
||||
def getCounts(self):
    """Return [(lang, weighted_char_count), ...] sorted descending.

    Uses the cached statistics when available; otherwise recomputes them
    from the last parse result (Chinese text weighted double).  Falls
    back to [("zh", 0)] when nothing has been parsed.
    """
    if self._lang_count is not None:
        return self._lang_count
    parsed = self._text_langs
    if parsed is None or len(parsed) == 0:
        return [("zh", 0)]
    totals = defaultdict(int)
    for item in parsed:
        weight = len(item['text']) * 2 if item['lang'] == "zh" else len(item['text'])
        totals[item['lang']] += int(weight)
    ranked = sorted(totals.items(), key=lambda kv: kv[1], reverse=True)
    self._lang_count = ranked
    return ranked
|
||||
|
||||
def getTexts(self, text:str):
    """Segment *text* by language.

    Returns a list of {"lang","text",...} dicts.  Results are cached:
    calling again with the same text returns the previous result
    without re-parsing.  Empty input clears all state.
    """
    if text is None or len(text.strip()) == 0:
        self._clears()
        return []
    # lasts: serve the cached result for a repeated input.
    text_langs = self._text_langs
    if self._text_lasts == text and text_langs is not None:return text_langs
    # parse: reset transient state, then run the full pipeline.
    self._text_waits = []
    self._lang_count = None
    self._text_lasts = text
    text = self._parse_symbols(text)
    self._text_langs = text
    return text
|
||||
|
||||
def classify(self, text:str):
    """Alias for getTexts(), kept for API compatibility."""
    result = self.getTexts(text)
    return result
|
||||
|
||||
def printList(langlist):
    """Pretty-print a segmentation result list, one entry per line."""
    print("\n===================【打印结果】===================")
    if langlist is None or len(langlist) == 0:
        print("无内容结果,No content result")
        return
    for line in langlist:
        print(line)
|
||||
|
||||
|
||||
|
||||
def main():
    """Demo entry point: segment a multilingual sample and print the
    per-language breakdown plus statistics.

    Changelog: the new version of the word segmentation is more accurate.
    Other sample inputs (mixed zh/ja/ko/en, inline <ja>…<ja> tags, etc.)
    are kept in the upstream examples.
    """
    # Experimental language support: "fr" French, "vi" Vietnamese,
    # "ru" Russian, "th" Thai — in addition to ja/zh/ko/en.
    langsegment = LangSegment()
    langsegment.setfilters(["fr", "vi" , "ja", "zh", "ko", "en" , "ru" , "th"])
    text = """
我喜欢在雨天里听音乐。
I enjoy listening to music on rainy days.
雨の日に音楽を聴くのが好きです。
비 오는 날에 음악을 듣는 것을 즐깁니다。
J'aime écouter de la musique les jours de pluie.
Tôi thích nghe nhạc vào những ngày mưa.
Мне нравится слушать музыку в дождливую погоду.
ฉันชอบฟังเพลงในวันที่ฝนตก
    """
    # Segmentation: a TTS pipeline only needs this single call.
    langlist = langsegment.getTexts(text)
    printList(langlist)
    # Language statistics:
    print("\n===================【语种统计】===================")
    # All languages, sorted by weighted character count, descending.
    langCounts = langsegment.getCounts()
    print(langCounts , "\n")
    # Dominant language of the input (language, char count incl. punctuation).
    lang , count = langCounts[0]
    print(f"输入内容的主要语言为 = {lang} ,字数 = {count}")
    print("==================================================\n")
    # Example output shape:
    #   {'lang': 'zh', 'text': '...'} — one dict per language run
    #   [('zh', 51), ('ja', 19), ('en', 18), ('ko', 5)]  — getCounts()
||||
|
||||
|
||||
if __name__ == "__main__":
    # Run the demo when this module is executed directly.
    main()
|
||||
9
language_segmentation/__init__.py
Normal file
9
language_segmentation/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from .LangSegment import LangSegment
|
||||
|
||||
|
||||
# release
|
||||
__version__ = '0.3.5'
|
||||
|
||||
|
||||
# develop
|
||||
__develop__ = 'dev-0.0.1'
|
||||
0
language_segmentation/utils/__init__.py
Normal file
0
language_segmentation/utils/__init__.py
Normal file
327
language_segmentation/utils/num.py
Normal file
327
language_segmentation/utils/num.py
Normal file
@@ -0,0 +1,327 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# Digital processing from GPT_SoVITS num.py (thanks)
|
||||
"""
|
||||
Rules to verbalize numbers into Chinese characters.
|
||||
https://zh.wikipedia.org/wiki/中文数字#現代中文
|
||||
"""
|
||||
|
||||
import re
|
||||
from collections import OrderedDict
|
||||
from typing import List
|
||||
|
||||
# Map each ASCII digit '0'-'9' to its Chinese character reading.
DIGITS = {str(i): tran for i, tran in enumerate('零一二三四五六七八九')}
# Positional unit names keyed by power of ten, in ascending order so the
# largest fitting unit can be found by a reversed scan
# (10^1 十, 10^2 百, 10^3 千, 10^4 万, 10^8 亿).
UNITS = OrderedDict({
    1: '十',
    2: '百',
    3: '千',
    4: '万',
    8: '亿',
})
|
||||
|
||||
COM_QUANTIFIERS = '(处|台|架|枚|趟|幅|平|方|堵|间|床|株|批|项|例|列|篇|栋|注|亩|封|艘|把|目|套|段|人|所|朵|匹|张|座|回|场|尾|条|个|首|阙|阵|网|炮|顶|丘|棵|只|支|袭|辆|挑|担|颗|壳|窠|曲|墙|群|腔|砣|座|客|贯|扎|捆|刀|令|打|手|罗|坡|山|岭|江|溪|钟|队|单|双|对|出|口|头|脚|板|跳|枝|件|贴|针|线|管|名|位|身|堂|课|本|页|家|户|层|丝|毫|厘|分|钱|两|斤|担|铢|石|钧|锱|忽|(千|毫|微)克|毫|厘|(公)分|分|寸|尺|丈|里|寻|常|铺|程|(千|分|厘|毫|微)米|米|撮|勺|合|升|斗|石|盘|碗|碟|叠|桶|笼|盆|盒|杯|钟|斛|锅|簋|篮|盘|桶|罐|瓶|壶|卮|盏|箩|箱|煲|啖|袋|钵|年|月|日|季|刻|时|周|天|秒|分|小时|旬|纪|岁|世|更|夜|春|夏|秋|冬|代|伏|辈|丸|泡|粒|颗|幢|堆|条|根|支|道|面|片|张|颗|块|元|(亿|千万|百万|万|千|百)|(亿|千万|百万|万|千|百|美|)元|(亿|千万|百万|万|千|百|十|)吨|(亿|千万|百万|万|千|百|)块|角|毛|分)'
|
||||
|
||||
# Fraction expressions, e.g. "2/3" or "-1/4".
RE_FRAC = re.compile(r'(-?)(\d+)/(\d+)')


def replace_frac(match) -> str:
    """Verbalize a fraction match as Chinese ("2/3" -> "三分之二").

    Args:
        match (re.Match)
    Returns:
        str
    """
    sign = match.group(1)
    nominator = match.group(2)
    denominator = match.group(3)
    sign: str = "负" if sign else ""
    nominator: str = num2str(nominator)
    denominator: str = num2str(denominator)
    # Chinese reads the denominator first: "<den>分之<num>".
    result = f"{sign}{denominator}分之{nominator}"
    return result
|
||||
|
||||
|
||||
# Percentage expressions, e.g. "12.5%" or "-3%".
RE_PERCENTAGE = re.compile(r'(-?)(\d+(\.\d+)?)%')


def replace_percentage(match) -> str:
    """Verbalize a percentage match as Chinese ("3%" -> "百分之三").

    Args:
        match (re.Match)
    Returns:
        str
    """
    sign = match.group(1)
    percent = match.group(2)
    sign: str = "负" if sign else ""
    percent: str = num2str(percent)
    result = f"{sign}百分之{percent}"
    return result
|
||||
|
||||
|
||||
# Integer expressions: only explicitly negative integers, e.g. "-10".
RE_INTEGER = re.compile(r'(-)' r'(\d+)')


def replace_negative_num(match) -> str:
    """Verbalize a negative integer match ("-10" -> "负十").

    Args:
        match (re.Match)
    Returns:
        str
    """
    sign = match.group(1)
    number = match.group(2)
    sign: str = "负" if sign else ""
    number: str = num2str(number)
    result = f"{sign}{number}"
    return result
|
||||
|
||||
|
||||
# Serial numbers / unsigned codes of 3+ digits, e.g. "00078".
RE_DEFAULT_NUM = re.compile(r'\d{3}\d*')


def replace_default_num(match):
    """Verbalize a code digit-by-digit.

    alt_one presumably makes 1 read as 幺 (telephone style) — confirm
    against verbalize_digit.

    Args:
        match (re.Match)
    Returns:
        str
    """
    number = match.group(0)
    return verbalize_digit(number, alt_one=True)
|
||||
|
||||
|
||||
# Arithmetic expressions (add/subtract/multiply/divide/equals), allowing
# superscript powers and single-letter variables on either side.
# RE_ASMD = re.compile(
#     r'((-?)((\d+)(\.\d+)?)|(\.(\d+)))([\+\-\×÷=])((-?)((\d+)(\.\d+)?)|(\.(\d+)))')
RE_ASMD = re.compile(
    r'((-?)((\d+)(\.\d+)?[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|(\.\d+[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|([A-Za-z][⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*))([\+\-\×÷=])((-?)((\d+)(\.\d+)?[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|(\.\d+[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|([A-Za-z][⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*))')

# Operator symbol -> spoken Chinese word.
asmd_map = {
    '+': '加',
    '-': '减',
    '×': '乘',
    '÷': '除',
    '=': '等于'
}


def replace_asmd(match) -> str:
    """Verbalize the operator of an arithmetic-expression match.

    Args:
        match (re.Match)
    Returns:
        str: left operand + spoken operator + right operand.
    """
    left, operator, right = match.group(1), match.group(8), match.group(9)
    return left + asmd_map[operator] + right
|
||||
|
||||
|
||||
# Superscript exponents (powers).
RE_POWER = re.compile(r'[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]+')

# Superscript character -> its plain-text equivalent.
power_map = {
    '⁰': '0',
    '¹': '1',
    '²': '2',
    '³': '3',
    '⁴': '4',
    '⁵': '5',
    '⁶': '6',
    '⁷': '7',
    '⁸': '8',
    '⁹': '9',
    'ˣ': 'x',
    'ʸ': 'y',
    'ⁿ': 'n'
}


def replace_power(match) -> str:
    """Verbalize a superscript run as "的<exp>次方" ("to the power of").

    Args:
        match (re.Match)
    Returns:
        str
    """
    exponent = "".join(power_map[ch] for ch in match.group(0))
    return "的" + exponent + "次方"
|
||||
|
||||
|
||||
# Number expressions
# Pure decimals, e.g. "3.14" or ".5".
RE_DECIMAL_NUM = re.compile(r'(-?)((\d+)(\.\d+))' r'|(\.(\d+))')
# Positive integer + optional approximator (多/余/几/+) + measure word.
RE_POSITIVE_QUANTIFIERS = re.compile(r"(\d+)([多余几\+])?" + COM_QUANTIFIERS)
# Any integer or decimal, optionally negative.
RE_NUMBER = re.compile(r'(-?)((\d+)(\.\d+)?)' r'|(\.(\d+))')
|
||||
|
||||
|
||||
def replace_positive_quantifier(match) -> str:
    """Verbalize "<number><approximator?><measure word>" ("3个" -> "三个").

    Args:
        match (re.Match)
    Returns:
        str
    """
    number = match.group(1)
    match_2 = match.group(2)
    # "+" is spoken as 多 ("more than").
    if match_2 == "+":
        match_2 = "多"
    match_2: str = match_2 if match_2 else ""
    quantifiers: str = match.group(3)
    number: str = num2str(number)
    result = f"{number}{match_2}{quantifiers}"
    return result
|
||||
|
||||
|
||||
def replace_number(match) -> str:
    """Verbalize a plain number match (integer, decimal, or ".5"-style).

    Args:
        match (re.Match)
    Returns:
        str
    """
    sign = match.group(1)
    number = match.group(2)
    pure_decimal = match.group(5)
    if pure_decimal:
        # ".5"-style match: the sign group does not apply here.
        result = num2str(pure_decimal)
    else:
        sign: str = "负" if sign else ""
        number: str = num2str(number)
        result = f"{sign}{number}"
    return result
|
||||
|
||||
|
||||
# 范围表达式
|
||||
# match.group(1) and match.group(8) are copy from RE_NUMBER
|
||||
|
||||
RE_RANGE = re.compile(
|
||||
r"""
|
||||
(?<![\d\+\-\×÷=]) # 使用反向前瞻以确保数字范围之前没有其他数字和操作符
|
||||
((-?)((\d+)(\.\d+)?)) # 匹配范围起始的负数或正数(整数或小数)
|
||||
[-~] # 匹配范围分隔符
|
||||
((-?)((\d+)(\.\d+)?)) # 匹配范围结束的负数或正数(整数或小数)
|
||||
(?![\d\+\-\×÷=]) # 使用正向前瞻以确保数字范围之后没有其他数字和操作符
|
||||
""", re.VERBOSE)
|
||||
|
||||
|
||||
def replace_range(match) -> str:
    """Verbalize a numeric range "A-B" / "A~B" as "A到B".

    Args:
        match (re.Match): RE_RANGE match; group(1)/group(6) are the
            range endpoints.

    Returns:
        str
    """
    lower = RE_NUMBER.sub(replace_number, match.group(1))
    upper = RE_NUMBER.sub(replace_number, match.group(6))
    return f"{lower}到{upper}"
|
||||
|
||||
|
||||
# "A~B" ranges with a unit suffix on both sides (e.g. "3℃~5℃"):
# the tilde is read as 至.
RE_TO_RANGE = re.compile(
    r'((-?)((\d+)(\.\d+)?)|(\.(\d+)))(%|°C|℃|度|摄氏度|cm2|cm²|cm3|cm³|cm|db|ds|kg|km|m2|m²|m³|m3|ml|m|mm|s)[~]((-?)((\d+)(\.\d+)?)|(\.(\d+)))(%|°C|℃|度|摄氏度|cm2|cm²|cm3|cm³|cm|db|ds|kg|km|m2|m²|m³|m3|ml|m|mm|s)')
|
||||
|
||||
def replace_to_range(match) -> str:
    """Rewrite a matched "A~B" (with units) as "A至B"; only '~' changes.

    Args:
        match (re.Match)

    Returns:
        str
    """
    return match.group(0).replace('~', '至')
|
||||
|
||||
|
||||
def _get_value(value_string: str, use_zero: bool=True) -> List[str]:
    """Recursively convert a digit string into Chinese digit/unit symbols.

    Args:
        value_string: non-negative digit string (may carry leading zeros).
        use_zero: when True, leading zeros before a single significant
            digit contribute one '零' (e.g. '05' -> 零五).

    Returns:
        List[str]: symbols drawn from DIGITS/UNITS (defined elsewhere in
        this module).
    """
    stripped = value_string.lstrip('0')
    if len(stripped) == 0:
        # all zeros (or empty): contributes nothing at this level
        return []
    elif len(stripped) == 1:
        if use_zero and len(stripped) < len(value_string):
            return [DIGITS['0'], DIGITS[stripped]]
        else:
            return [DIGITS[stripped]]
    else:
        # Largest unit power strictly smaller than the count of significant
        # digits; relies on UNITS iterating in ascending key order.
        largest_unit = next(
            power for power in reversed(UNITS.keys()) if power < len(stripped))
        # NOTE(review): the split below slices value_string (leading zeros
        # included) while largest_unit was chosen from the stripped length —
        # presumably intentional for zero-padded groups; verify against UNITS.
        first_part = value_string[:-largest_unit]
        second_part = value_string[-largest_unit:]
        return _get_value(first_part) + [UNITS[largest_unit]] + _get_value(
            second_part)
|
||||
|
||||
|
||||
def verbalize_cardinal(value_string: str) -> str:
    """Verbalize an integer digit string as a Chinese cardinal number.

    '000' and '0' both read as '零'; a leading '一十' is abbreviated '十'.
    Returns '' for an empty input.
    """
    if not value_string:
        return ''

    # Strings consisting only of zeros collapse to a single '零'.
    significant = value_string.lstrip('0')
    if not significant:
        return DIGITS['0']

    symbols = _get_value(significant)
    # Conventional abbreviation: numbers starting '一十…' are read '十…'.
    starts_with_yi_shi = (
        len(symbols) >= 2
        and symbols[0] == DIGITS['1']
        and symbols[1] == UNITS[1]
    )
    if starts_with_yi_shi:
        symbols = symbols[1:]
    return ''.join(symbols)
|
||||
|
||||
|
||||
def verbalize_digit(value_string: str, alt_one=False) -> str:
    """Read digits one by one; optionally read '一' as '幺' (phone style).

    Args:
        value_string: digit string.
        alt_one: replace every '一' with '幺' when True.
    """
    spelled = ''.join(DIGITS[digit] for digit in value_string)
    return spelled.replace("一", "幺") if alt_one else spelled
|
||||
|
||||
|
||||
def num2str(value_string: str) -> str:
    """Convert a decimal digit string (e.g. '3.20') to spoken Chinese.

    Args:
        value_string: digits with at most one decimal point.

    Returns:
        str: the verbalized number; trailing zeros in the fractional part
        are dropped ('3.20' -> 三点二), and a bare fraction gains a leading
        零 ('.22' -> 零点二二).

    Raises:
        ValueError: if the string contains more than one '.'.
    """
    integer_decimal = value_string.split('.')
    if len(integer_decimal) == 1:
        integer = integer_decimal[0]
        decimal = ''
    elif len(integer_decimal) == 2:
        integer, decimal = integer_decimal
    else:
        # Bug fix: the message previously rendered a literal '$' as in
        # "'${value_string}'" — JS template-literal syntax left inside an
        # f-string. The placeholder alone interpolates correctly.
        raise ValueError(
            f"The value string: '{value_string}' has more than one point in it."
        )

    result = verbalize_cardinal(integer)

    decimal = decimal.rstrip('0')
    if decimal:
        # '.22' is verbalized as '零点二二'
        # '3.20' is verbalized as '三点二'
        result = result if result else "零"
        result += '点' + verbalize_digit(decimal)
    return result
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Ad-hoc manual check: verbalize a digit string and print the result.
    text = ""
    text = num2str(text)
    print(text)
    pass
|
||||
412
lyric_processor_v2.py
Normal file
412
lyric_processor_v2.py
Normal file
@@ -0,0 +1,412 @@
|
||||
import torch.nn as nn
|
||||
import torch
|
||||
import random
|
||||
from loguru import logger
|
||||
|
||||
from transformers import UMT5EncoderModel, AutoTokenizer, AutoModel
|
||||
import re
|
||||
from typing import List, Tuple, Dict, Set
|
||||
import sys
|
||||
import os
|
||||
|
||||
|
||||
# Silence the HuggingFace tokenizers fork-safety warning in dataloader workers.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
|
||||
class TrieNode:
    """Single node of a character trie."""

    def __init__(self):
        # Outgoing edges, keyed by the next character.
        self.children: Dict[str, 'TrieNode'] = {}
        # True when the root-to-here path spells a complete word.
        self.is_end_of_word: bool = False
|
||||
|
||||
|
||||
class Trie:
    """Character trie supporting all-prefix phoneme lookup."""

    def __init__(self):
        self.root = TrieNode()

    def insert(self, word: str):
        """Add ``word`` to the trie, creating nodes along its path."""
        cursor = self.root
        for ch in word:
            cursor = cursor.children.setdefault(ch, TrieNode())
        cursor.is_end_of_word = True

    def search_from(self, s: str, start: int) -> List[str]:
        """Collect every dictionary word beginning exactly at ``start`` in ``s``.

        Walks the trie character by character from position ``start`` and
        records each prefix that is a complete word (shortest first);
        stops at the first character with no trie edge.
        """
        cursor = self.root
        found: List[str] = []
        prefix = []
        for ch in s[start:]:
            if ch not in cursor.children:
                break
            cursor = cursor.children[ch]
            prefix.append(ch)
            if cursor.is_end_of_word:
                found.append(''.join(prefix))
        return found
|
||||
|
||||
|
||||
class PhonemeMatcher:
    def __init__(self, word_dict: Set[str]):
        """Build a Trie over the phoneme vocabulary.

        :param word_dict: Set[str] - set containing every known phoneme
        """
        self.trie = Trie()
        for word in word_dict:
            self.trie.insert(word)

    def tokenize(self, s: str) -> List[str]:
        """Split an X-SAMPA string into a phoneme sequence.

        Prefers phonemes from the vocabulary; characters that match no
        phoneme are kept as single-character tokens at edit cost 1. Among
        equal-cost segmentations, the one with fewer phonemes wins
        (dynamic programming over prefixes of ``s``).

        :param s: str - input X-SAMPA string
        :return: List[str] - resulting phoneme sequence
        """
        n = len(s)
        # DP over prefixes: dp[i] = (cost, phoneme_count, phone_list)
        dp: List[Tuple[int, int, List[str]]] = [(sys.maxsize, sys.maxsize, []) for _ in range(n + 1)]
        dp[0] = (0, 0, [])

        for i in range(n):
            current_cost, current_count, current_list = dp[i]
            if current_cost == sys.maxsize:
                continue  # position not reachable

            # All vocabulary phonemes that start at position i.
            matches = self.trie.search_from(s, i)

            if matches:
                for phoneme in matches:
                    end = i + len(phoneme)
                    new_cost = current_cost  # exact match: no edit cost
                    new_count = current_count + 1
                    new_list = current_list + [phoneme]

                    if new_cost < dp[end][0]:
                        dp[end] = (new_cost, new_count, new_list)
                    elif new_cost == dp[end][0]:
                        if new_count < dp[end][1]:
                            dp[end] = (new_cost, new_count, new_list)
            else:
                # No phoneme matches here: skip one character at edit cost 1.
                new_cost = current_cost + 1
                end = i + 1
                new_count = current_count + 1  # a skipped char also counts as one phoneme
                new_list = current_list + [s[i]]

                if new_cost < dp[end][0]:
                    dp[end] = (new_cost, new_count, new_list)
                elif new_cost == dp[end][0]:
                    if new_count < dp[end][1]:
                        dp[end] = (new_cost, new_count, new_list)

        # Fallback when the end of the string is unreachable.
        # NOTE(review): dp[n] appears always reachable (each position can at
        # least skip one char), so this branch looks like dead code; if it
        # ever ran, dp[0] has cost 0, making `best` the empty prefix — confirm.
        if dp[n][0] == sys.maxsize:
            # smallest edit cost over every prefix
            min_cost = min(dp[i][0] for i in range(n + 1))
            # keep prefixes achieving that minimum cost
            candidates = [dp[i] for i in range(n + 1) if dp[i][0] == min_cost]
            if candidates:
                # fewest phonemes wins the tie
                best = min(candidates, key=lambda x: x[1])
                return best[2]
            else:
                return []

        return dp[n][2]
|
||||
|
||||
|
||||
# Song-structure section labels (Harmonix Set-style annotation vocabulary).
HARMONIX_LABELS = [
    'start',
    'end',
    'intro',
    'outro',
    'break',
    'bridge',
    'inst',
    'solo',
    'verse',
    'chorus',
]
|
||||
|
||||
|
||||
def timestamp2second(timestamps):
    """Convert {'start','end'} offsets from 8 kHz sample units to seconds.

    Each value is divided by 8000 and rounded to two decimals; the input
    list is not modified.
    """
    return [
        {
            "start": round(entry["start"] / 8000, 2),
            "end": round(entry["end"] / 8000, 2),
        }
        for entry in timestamps
    ]
|
||||
|
||||
|
||||
def sample_lyric_mask(voiced_timestamp, max_length):
    """Sample a (mask_start, mask_end) window aligned to silence gaps.

    Tries progressively smaller minimum-gap thresholds (5s down to 1s)
    until at least two splittable breaks exist within 360s, then derives
    a window and the number of 10-second chunks it spans.

    Args:
        voiced_timestamp: list of {'start','end'} in 8 kHz sample units.
        max_length: maximum window end, in seconds.

    Returns:
        (mask_start, mask_end, min_cut_level): window bounds in seconds
        and the spanned 10s-chunk count.
    """
    voiced_timestamps = timestamp2second(voiced_timestamp)

    min_gaps = [1,2,3,4,5]
    while len(min_gaps) > 0:
        # pop() takes from the tail, so thresholds are tried 5s -> 1s
        min_gap = min_gaps.pop()
        can_split_breaks = []
        last_end = 0.00
        for item in voiced_timestamps:
            if item["start"] - last_end >= min_gap:
                if last_end == 0.00:
                    # leading silence: no 0.5s guard needed at the front
                    can_split_breaks.append((last_end, item["start"] - 0.5))
                else:
                    # keep 0.5s away from voiced audio on both sides
                    can_split_breaks.append((last_end + 0.5, item["start"] - 0.5))
            last_end = item["end"]
        if len(can_split_breaks) > 1:
            if can_split_breaks[1][0] <= 360:
                break
        else:
            if min_gap == 1:
                # no usable breaks even at the loosest threshold: fixed 360s window
                return 0.0, 360.0, 36

    if len(can_split_breaks) == 0:
        mask_start, mask_end = 0.0, max_length
        min_cut_level = int(mask_end//10 - mask_start//10 + 1)
        return 0.0, mask_end, min_cut_level

    if len(can_split_breaks) == 1:
        # single break: randomly anchor the window at the start or the break
        mask_start = random.choice(["start", "middle"])
        if mask_start == "start":
            mask_start = 0.0
            mask_end = random.uniform(can_split_breaks[0][0], can_split_breaks[0][1])
        else:
            mask_start = random.uniform(can_split_breaks[0][0], can_split_breaks[0][1])
            mask_end = max_length
        min_cut_level = int(mask_end//10 - mask_start//10 + 1)
        return mask_start, mask_end, min_cut_level

    mask_start, mask_end = 0.0, 370
    min_cut_level = 37
    breaths_gap = [end-start for start, end in can_split_breaks]
    max_tried = 5
    # NOTE(review): with the initial values, 370 - 0.0 > 370 is False, so
    # this loop body never executes and (0.0, 370) flows straight to the
    # rounding below — confirm whether `>=` was intended.
    while mask_end - mask_start > 370 and min_cut_level > 0 and min_cut_level > 36:
        total_breaths = len(can_split_breaks)
        # pick break indices weighted by gap length
        start = random.choices(range(total_breaths-1), weights=breaths_gap[:-1])[0]
        end = random.choices(range(start + 1, total_breaths), weights=breaths_gap[start+1:], k=1)[0]
        start_break, end_break = can_split_breaks[start], can_split_breaks[end]
        mask_start, mask_end = random.uniform(start_break[0], start_break[1]), random.uniform(end_break[0], end_break[1])
        min_cut_level = int(mask_end//10 - mask_start//10 + 1)
        if min_cut_level < 36:
            min_cut_level = random.randint(min_cut_level, 36)
        if max_tried == 0:
            print("max tried", mask_start, mask_end, min_cut_level, "breaths_gap", breaths_gap, "can_split_breaks", can_split_breaks)
            break
        max_tried -= 1
    mask_start, mask_end = round(mask_start, 2), min(round(mask_end, 2), max_length)
    return mask_start, mask_end, min_cut_level
|
||||
|
||||
|
||||
def check_valid_lyric_lines(lyric_lines):
    """Return True iff at least one selected entry carries actual lyrics.

    Each entry is a tuple whose second element holds the lyric text or
    phonemes; structure-only entries have an empty second element.
    """
    # must contain entries at all
    if not lyric_lines:
        return False
    return any(len(entry[1]) > 0 for entry in lyric_lines)
|
||||
|
||||
|
||||
def select_valid_lyric_lines(lyric_lines, mask_start, mask_end):
    """Select lyric lines falling inside the [mask_start, mask_end] window.

    Selection principle: err on the side of including too much rather
    than too little. If mask_end falls inside a lyric line, that line is
    still included, but no trailing structure tag is appended after it.

    Args:
        lyric_lines: dicts with 'start', 'end', 'structure', 'lyric_line'
            and optionally 'phoneme_line_ipa'.
        mask_start, mask_end: window bounds in seconds.

    Returns:
        List of tuples (structure_or_None, lyric_text, ipa_phonemes,
        (start, end)).

    NOTE: strips each selected line's text in place, mutating the
    caller's dicts.
    """
    valid_lyric_lines = []
    add_tail_structure = True
    for lyric_line in lyric_lines:
        if lyric_line["start"] > lyric_line["end"]:
            continue  # skip malformed timestamps
        # line fully inside the window (1s tolerance at both edges)
        if lyric_line["start"]+1.0 >= mask_start and lyric_line["end"]-1.0 <= mask_end:
            if len(valid_lyric_lines) > 0:
                # emit a structure tag only when it differs from the previous one
                if valid_lyric_lines[-1][0] is not None and valid_lyric_lines[-1][0] != lyric_line["structure"] and lyric_line["structure"] != "":
                    valid_lyric_lines.append((lyric_line["structure"], [], [], (lyric_line["start"], lyric_line["end"])))
            elif lyric_line["structure"] != "":
                valid_lyric_lines.append((lyric_line["structure"], [], [], (lyric_line["start"], lyric_line["end"])))
            lyric_line["lyric_line"] = lyric_line["lyric_line"].strip()
            if lyric_line["lyric_line"] and "phoneme_line_ipa" in lyric_line and len(lyric_line["phoneme_line_ipa"]) > 0:
                valid_lyric_lines.append((None, lyric_line["lyric_line"], lyric_line["phoneme_line_ipa"], (lyric_line["start"], lyric_line["end"])))
        # line straddles mask_end: include it, then stop selecting
        elif mask_start < lyric_line["start"] and lyric_line["start"] < mask_end and lyric_line["end"] > mask_end:
            lyric_line["lyric_line"] = lyric_line["lyric_line"].strip()
            if lyric_line["lyric_line"] and "phoneme_line_ipa" in lyric_line and len(lyric_line["phoneme_line_ipa"]) > 0:
                valid_lyric_lines.append((None, lyric_line["lyric_line"], lyric_line["phoneme_line_ipa"], (lyric_line["start"], lyric_line["end"])))
            add_tail_structure = False
            break
        # a bare structure tag starting inside the window ends the selection
        elif lyric_line["start"] > mask_start and lyric_line["start"] < mask_end and not lyric_line["lyric_line"] and add_tail_structure:
            valid_lyric_lines.append((lyric_line["structure"], [], [], (lyric_line["start"], lyric_line["end"])))
            add_tail_structure = False
            break
    # optionally append the song's final structure tag if it starts in-window
    if len(valid_lyric_lines) > 0 and len(lyric_lines) > 0 and add_tail_structure:
        if lyric_lines[-1]["structure"] != "" and lyric_lines[-1]["structure"] != valid_lyric_lines[-1][0]:
            if lyric_lines[-1]["start"] > mask_start and lyric_lines[-1]["start"] < mask_end:
                valid_lyric_lines.append((lyric_lines[-1]["structure"], [], [], (lyric_lines[-1]["start"], lyric_lines[-1]["end"])))
    return valid_lyric_lines
|
||||
|
||||
|
||||
def sample_lyric_mask_with_cut_levels(voiced_timestamp, cut_level, n_chunks, lyric_lines):
    """Enumerate candidate windows of ``cut_level`` 10s chunks with valid lyrics.

    A window is preferred when both of its boundaries fall outside voiced
    regions; if no such window exists, the voicing constraint is dropped
    and any window with valid lyrics is returned.

    Returns:
        List of (start_second, end_second, valid_lyric_lines); empty when
        no window yields valid lyrics at all.
    """
    voiced = timestamp2second(voiced_timestamp)

    def _window(idx):
        # candidate window [idx*10, (idx+cut_level)*10) in seconds
        return idx * 10, (idx + cut_level) * 10

    def _inside_voiced(second):
        # True when `second` falls strictly inside some voiced segment
        return any(seg["start"] < second < seg["end"] for seg in voiced)

    # First pass: both window edges must avoid voiced audio.
    strict_spans = []
    for idx in range(n_chunks):
        lo, hi = _window(idx)
        # select on every window (it also strips lyric text in place,
        # matching the original call pattern)
        lines = select_valid_lyric_lines(lyric_lines, lo, hi)
        if _inside_voiced(lo) or _inside_voiced(hi):
            continue
        if check_valid_lyric_lines(lines):
            strict_spans.append((lo, hi, lines))

    if strict_spans:
        return strict_spans

    # Relaxed pass: ignore voicing, require only valid lyrics.
    relaxed_spans = []
    for idx in range(n_chunks):
        lo, hi = _window(idx)
        lines = select_valid_lyric_lines(lyric_lines, lo, hi)
        if check_valid_lyric_lines(lines):
            relaxed_spans.append((lo, hi, lines))
    return relaxed_spans
|
||||
|
||||
|
||||
def sample_lyric_mask_with_lyric_timestamp(cut_level, lyric_lines, expected_num_example, n_chunks, start_pad_offset=1.0):
    """Sample lyric windows of ``cut_level`` 10s chunks from line timestamps.

    Builds "full" spans (the window cleanly covers whole lines) and
    "partial" spans (the window cuts into a line), preferring full ones.
    Structure tags are kept: bare tags become "[tag]" lines, and the first
    lyric line of a new section gets its "[tag]" header prepended.

    Args:
        cut_level: window length in 10-second chunks.
        lyric_lines: dicts with 'start', 'end', 'lyric_line' and optionally
            'structure' / 'phoneme_line_ipa'.
        expected_num_example: cap on returned spans (None returns all).
        n_chunks: unused in this function — TODO confirm whether it should
            bound the search as in sample_lyric_mask_with_cut_levels.
        start_pad_offset: tolerance in seconds around line boundaries.

    Returns:
        List of (start_second, end_second, lines) spans, where lines are
        tuples (None, lyric_text, ipa_phonemes, (start, end)).
    """
    # 1. (older approach, kept for reference) drop structure tags entirely:
    # non_structure_lyric_lines = [lyric_line for lyric_line in lyric_lines if lyric_line["lyric_line"] and "phoneme_line_ipa" in lyric_line and len(lyric_line["phoneme_line_ipa"]) > 0 and lyric_line["start"] < lyric_line["end"]]
    # Current approach: keep structure tags.
    valid_lyric_lines = []
    last_structure = ""
    for lyric_line in lyric_lines:
        if "structure" not in lyric_line:
            lyric_line["structure"] = ""
        if lyric_line["start"] < lyric_line["end"]:
            new_line = lyric_line.copy()
            if not lyric_line["lyric_line"] or "phoneme_line_ipa" not in lyric_line or len(lyric_line["phoneme_line_ipa"]) == 0:
                # structure-only line: render the tag itself as text
                if lyric_line["structure"] != "":
                    new_line["lyric_line"] = "["+lyric_line["structure"]+"]"
                    new_line["phoneme_line_ipa"] = ["_"]
                else:
                    last_structure = lyric_line["structure"]
                    continue
            else:
                # first lyric line of a new section: prepend its "[tag]" header
                if new_line["structure"] != "" and new_line["structure"] != last_structure:
                    if new_line["lyric_line"] != "[" + new_line["structure"] + "]":
                        new_line["lyric_line"] = f"[{new_line['structure']}]\n{new_line['lyric_line']}"
                        new_line["phoneme_line_ipa"] = ["_", "_"] + new_line["phoneme_line_ipa"]

            valid_lyric_lines.append(new_line)
            last_structure = lyric_line["structure"]

    # 2. prefer windows that exactly cover whole lines
    full_spans = []
    partial_spans = []
    # print("non_structure_lyric_lines", non_structure_lyric_lines, n_chunks)
    for start_idx in range(len(valid_lyric_lines)):
        for end_idx in range(start_idx, len(valid_lyric_lines)):
            start = valid_lyric_lines[start_idx]["start"]
            end = start + cut_level * 10

            # print("start_idx:", start_idx, "end_idx:", end_idx, "start:", start, "end:", end, "non_structure_lyric_lines[end_idx]:", non_structure_lyric_lines[end_idx])

            # single line longer than the whole window -> partial span
            # NOTE(review): `line_idx` is unused in this comprehension; with
            # start_idx == end_idx the range has length 1, so it is harmless.
            if start_idx == end_idx and valid_lyric_lines[start_idx]["end"] > end:
                res = [(None, valid_lyric_lines[start_idx]["lyric_line"], valid_lyric_lines[start_idx]["phoneme_line_ipa"], (valid_lyric_lines[start_idx]["start"], valid_lyric_lines[start_idx]["end"])) for line_idx in range(start_idx, end_idx+1)]
                if len(res) > 0:
                    partial_spans.append((start, end, res))
                break

            # window ends in the gap before line end_idx -> full span
            if end_idx > 0 and end < valid_lyric_lines[end_idx]["start"] and valid_lyric_lines[end_idx-1]["end"] + start_pad_offset < end:
                res = [(None, valid_lyric_lines[line_idx]["lyric_line"], valid_lyric_lines[line_idx]["phoneme_line_ipa"], (valid_lyric_lines[line_idx]["start"], valid_lyric_lines[line_idx]["end"])) for line_idx in range(start_idx, end_idx)]
                if len(res) > 0:
                    full_spans.append((start, end, res))
                break

            # window ends inside line end_idx -> partial span
            if end < valid_lyric_lines[end_idx]["end"] + start_pad_offset and end > valid_lyric_lines[end_idx]["start"]:
                res = [(None, valid_lyric_lines[line_idx]["lyric_line"], valid_lyric_lines[line_idx]["phoneme_line_ipa"], (valid_lyric_lines[line_idx]["start"], valid_lyric_lines[line_idx]["end"])) for line_idx in range(start_idx, end_idx)]
                if len(res) > 0:
                    partial_spans.append((start, end, res))
                break

            if valid_lyric_lines[end_idx]["start"] > end:
                break

    # Fallback: nothing matched anywhere — take every line as one span.
    # NOTE(review): `start`/`end` here are leftovers from the last loop
    # iteration; confirm this fallback is only expected for very short songs.
    if start_idx == 0 and end_idx == len(valid_lyric_lines) - 1 and len(full_spans) == 0 and len(partial_spans) == 0:
        res = [(None, valid_lyric_lines[line_idx]["lyric_line"], valid_lyric_lines[line_idx]["phoneme_line_ipa"], (valid_lyric_lines[line_idx]["start"], valid_lyric_lines[line_idx]["end"])) for line_idx in range(start_idx, end_idx+1)]
        if len(res) > 0:
            full_spans.append((start, end, res))
    if expected_num_example is not None:
        if len(full_spans) >= expected_num_example or len(partial_spans) == 0:
            return full_spans
        if len(full_spans) + len(partial_spans) >= expected_num_example:
            left = expected_num_example - len(full_spans)
            return full_spans + random.sample(partial_spans, left)
    # print("full_spans:", full_spans)
    # print("partial_spans:", partial_spans)
    return full_spans + partial_spans
|
||||
|
||||
|
||||
class LyricProcessor(nn.Module):
|
||||
    def __init__(self, infer=False):
        """Load the frozen UMT5 lyric/text encoder and its tokenizer.

        Args:
            infer: unused here; presumably kept for interface parity with
                callers — TODO confirm.
        """
        super().__init__()
        # fp16, eval-mode text encoder; weights are never updated
        self.lyric_text_model = UMT5EncoderModel.from_pretrained("./checkpoints/umt5-base", local_files_only=True).eval().half()
        # not required gradient
        self.lyric_text_model.requires_grad_(False)
        self.lyric_text_tokenizer = AutoTokenizer.from_pretrained("./checkpoints/umt5-base", local_files_only=True)
|
||||
|
||||
|
||||
    def get_text_embeddings(self, texts, device, text_max_length=256):
        """Encode a batch of texts with the frozen UMT5 encoder.

        Args:
            texts: string or list of strings for the tokenizer.
            device: torch device to run the encoder on.
            text_max_length: token truncation length.

        Returns:
            (last_hidden_states, attention_mask): encoder hidden states and
            the tokenizer's padding mask.
        """
        inputs = self.lyric_text_tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=text_max_length)
        inputs = {key: value.to(device) for key, value in inputs.items()}
        # lazily move the encoder whenever a different device is requested
        if self.lyric_text_model.device != device:
            self.lyric_text_model.to(device)
        with torch.no_grad():
            outputs = self.lyric_text_model(**inputs)
            last_hidden_states = outputs.last_hidden_state
        attention_mask = inputs["attention_mask"]
        return last_hidden_states, attention_mask
|
||||
|
||||
    def preprocess(self, valid_lyric_lines):
        """Flatten selected lyric lines into one text string and one IPA string.

        Args:
            valid_lyric_lines: dicts with 'structure', 'lyric', 'ipa' keys;
                ipa == "_" marks a structure-only placeholder line.

        Returns:
            (lyric_text, ipa_text): newline-joined lyric text and
            " _ "-joined IPA phoneme text.
        """
        lyric_texts = []
        ipa_texts = []
        for valid_line in valid_lyric_lines:
            structure, lyric_line, ipa_line = valid_line["structure"], valid_line["lyric"], valid_line["ipa"]
            if len(structure) > 0:
                lyric_texts.append(structure)
            if len(lyric_line) > 0:
                lyric_texts.append(lyric_line)
            if len(structure) == 0 and len(lyric_line) == 0:
                # keep an empty slot so lyric/IPA lists stay aligned
                lyric_texts.append("")

            if ipa_line != "_":
                # split_unk is defined elsewhere in this class
                ipa_line = self.split_unk(ipa_line.split(" "))
                ipa_line_str = " ".join(ipa_line)
                # work around a G2P bug: drop spurious "z ə z (ə z ...)" runs
                ipa_line_str = re.sub(r'\bz(?:\s+ə\s+z)+\b', "", ipa_line_str)
                ipa_line_str = re.sub(r'\s+', ' ', ipa_line_str).strip()
                ipa_texts.append(ipa_line_str)
            else:
                ipa_texts.append(ipa_line)

        lyric_text = "\n".join(lyric_texts)
        ipa_text = " _ ".join(ipa_texts)
        return lyric_text, ipa_text
|
||||
|
||||
644
main_text2music_large_sana_dcae_0331_finetune.py
Normal file
644
main_text2music_large_sana_dcae_0331_finetune.py
Normal file
@@ -0,0 +1,644 @@
|
||||
import json
|
||||
|
||||
import matplotlib
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.data
|
||||
from pytorch_lightning.core import LightningModule
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
# from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
|
||||
|
||||
from models.transformer_sana_text2music_large_dcae_0319 import ACEFlowBaseModel
|
||||
from loguru import logger
|
||||
from transformers import AutoModel
|
||||
from lyric_processor_v2 import LyricProcessor
|
||||
from optimizers.cosine_wsd import configure_lr_scheduler
|
||||
import traceback
|
||||
import torchaudio
|
||||
from transformers import Wav2Vec2FeatureExtractor
|
||||
from music_dcae.music_dcae_pipeline import MusicDCAE
|
||||
from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import retrieve_timesteps
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
from apg_guidance import apg_forward, MomentumBuffer
|
||||
from tqdm import tqdm
|
||||
import random
|
||||
import os
|
||||
|
||||
|
||||
# Headless plotting backend: figures are only ever written to files.
matplotlib.use("Agg")
torch.backends.cudnn.benchmark = False
torch.set_float32_matmul_precision('high')

# Enable TF32 for faster training on Ampere GPUs,
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
# torch.backends.cuda.matmul.allow_tf32 = True
|
||||
|
||||
|
||||
class Pipeline(LightningModule):
|
||||
    def __init__(
        self,
        learning_rate: float = 1e-4,
        num_workers: int = 4,
        infer: bool = False,
        train: bool = True,
        T: int = 1000,
        minibatch_size: int = 32,
        batch_size: int = 1,
        snr_gamma: float = 0.5,
        prediction_type: str = "v_prediction",  # epsilon, sample, v_prediction
        beta_start: float = 0.0015,
        beta_end: float = 0.0195,
        noise_offset: float = 0.1,
        input_perturbation: float = 0.1,
        use_ema: bool = False,
        enable_xformers_memory_efficient_attention: bool = False,
        weight_decay: float = 1e-2,
        num_chunk: int = 2,
        beta_schedule: str = "scaled_linear",
        scheduler_type: str = "ddpm",
        every_plot_step: int = 2000,
        vocal_noise: float = 0,
        max_length: int = 6400,
        sample_size: int = None,
        target_orig: bool = True,
        csv_path: str = None,
        config_path: str = "./models/config_sana_text2music_dcae_0225_3.5B_simple.json",
        shift: float = 3.0,
        logit_mean: float = 0.0,
        logit_std: float = 1.0,
        timestep_densities_type: str = "logit_normal",
        ssl_coeff: float = 1.0,
        wav_max_seconds: float = 30.0,
        max_steps: int = -1,
        fix_cut_level: int = 3,
        ipa_max_length: int = 8192,
        text_max_length: int = 1024,
    ):
        """Build the text2music flow-matching training/inference pipeline.

        Instantiates the flow-matching scheduler, the ACE transformer
        (configured from ``config_path``), the frozen lyric/text encoder,
        the MusicDCAE VAE and — in training mode only — the frozen MERT
        and mHuBERT models used as auxiliary SSL supervision.

        All arguments are captured via ``save_hyperparameters()``; many are
        consumed elsewhere in the LightningModule rather than here. Key
        ones used in this constructor: ``train``/``infer`` select training
        vs inference setup, ``T``/``beta_start``/``beta_end`` parameterize
        the scheduler, ``config_path`` locates the transformer config and
        ``ssl_coeff`` weights the SSL losses.
        """
        super().__init__()

        self.save_hyperparameters()
        self.is_train = train
        self.T = T
        self.beta_start = beta_start
        self.beta_end = beta_end

        self.scheduler = self.get_scheduler()
        with open(config_path, "r") as f:
            self.config = json.load(f)
        self.transformers = ACEFlowBaseModel(**self.config)

        # frozen lyric/text encoder
        self.lyric_processor = LyricProcessor()
        self.lyric_processor.requires_grad_(False)

        if not infer and self.is_train:
            # MERT (24 kHz) and mHuBERT (16 kHz) provide frozen SSL targets
            self.mert_model = AutoModel.from_pretrained("./checkpoints/MERT-v1-330M", trust_remote_code=True).eval()
            self.mert_model.requires_grad_(False)
            self.resampler_mert = torchaudio.transforms.Resample(orig_freq=48000, new_freq=24000)
            self.processor_mert = Wav2Vec2FeatureExtractor.from_pretrained("./checkpoints/MERT-v1-330M", trust_remote_code=True)

            self.hubert_model = AutoModel.from_pretrained("checkpoints/mHuBERT-147", local_files_only=True).eval()
            self.hubert_model.requires_grad_(False)
            self.resampler_mhubert = torchaudio.transforms.Resample(orig_freq=48000, new_freq=16000)
            self.processor_mhubert = Wav2Vec2FeatureExtractor.from_pretrained("checkpoints/mHuBERT-147", local_files_only=True)

            self.ssl_coeff = ssl_coeff

            # frozen audio latent codec
            self.vae = MusicDCAE(encoder_only=False).eval()
            self.vae.requires_grad_(False)

            # self.mert_model = torch.compile(self.mert_model)
            # self.hubert_model = torch.compile(self.hubert_model)
            # self.vae = torch.compile(self.vae)
            # self.transformers = torch.compile(self.transformers)
        else:
            self.vae = MusicDCAE(encoder_only=False).eval()
            self.vae.requires_grad_(False)
|
||||
|
||||
    def infer_mert_ssl(self, target_wavs, wav_lengths):
        """Extract frozen MERT hidden states from 48 kHz stereo audio.

        Args:
            target_wavs: (N, 2, T) waveforms at 48 kHz (zero-padded).
            wav_lengths: per-sample valid lengths in 48 kHz samples.

        Returns:
            List of N tensors, each (num_frames_i, hidden_size): the
            concatenated MERT features over that sample's 5-second chunks.
        """
        # Input N x 2 x T (48 kHz) -> mono N x T at 24 kHz
        mert_input_wavs_mono_24k = self.resampler_mert(target_wavs.mean(dim=1))
        bsz = target_wavs.shape[0]
        actual_lengths_24k = wav_lengths // 2  # 48kHz -> 24kHz

        # normalize each sample over its actual (non-padded) audio only
        means = torch.stack([mert_input_wavs_mono_24k[i, :actual_lengths_24k[i]].mean() for i in range(bsz)])
        vars = torch.stack([mert_input_wavs_mono_24k[i, :actual_lengths_24k[i]].var() for i in range(bsz)])
        mert_input_wavs_mono_24k = (mert_input_wavs_mono_24k - means.view(-1, 1)) / torch.sqrt(vars.view(-1, 1) + 1e-7)

        # MERT SSL supervision
        # chunk length: 5 seconds of 24 kHz samples
        chunk_size = 24000 * 5  # 5 seconds x 24000 samples/second
        total_length = mert_input_wavs_mono_24k.shape[1]
        # NOTE(review): total_length is unused below.

        num_chunks_per_audio = (actual_lengths_24k + chunk_size - 1) // chunk_size

        # split every sample into fixed-size chunks (tail zero-padded)
        all_chunks = []
        chunk_actual_lengths = []
        for i in range(bsz):
            audio = mert_input_wavs_mono_24k[i]
            actual_length = actual_lengths_24k[i]
            for start in range(0, actual_length, chunk_size):
                end = min(start + chunk_size, actual_length)
                chunk = audio[start:end]
                if len(chunk) < chunk_size:
                    chunk = F.pad(chunk, (0, chunk_size - len(chunk)))  # zero-pad the remainder
                all_chunks.append(chunk)
                chunk_actual_lengths.append(end - start)

        # stack all chunks into (total_chunks, chunk_size)
        all_chunks = torch.stack(all_chunks, dim=0)

        # batched inference
        with torch.no_grad():
            # output shape: (total_chunks, seq_len, hidden_size)
            mert_ssl_hidden_states = self.mert_model(all_chunks).last_hidden_state

        # valid feature frames per chunk (model stride of 320 samples)
        chunk_num_features = [(length + 319) // 320 for length in chunk_actual_lengths]

        # trim padded frames from each chunk's hidden states
        chunk_hidden_states = [mert_ssl_hidden_states[i, :chunk_num_features[i], :] for i in range(len(all_chunks))]

        # regroup hidden states by source audio
        mert_ssl_hidden_states_list = []
        chunk_idx = 0
        for i in range(bsz):
            audio_chunks = chunk_hidden_states[chunk_idx:chunk_idx + num_chunks_per_audio[i]]
            audio_hidden = torch.cat(audio_chunks, dim=0)  # concatenate this audio's chunks
            mert_ssl_hidden_states_list.append(audio_hidden)
            chunk_idx += num_chunks_per_audio[i]

        return mert_ssl_hidden_states_list
|
||||
|
||||
|
||||
    def infer_mhubert_ssl(self, target_wavs, wav_lengths):
        """Extract frozen mHuBERT hidden states from 48 kHz stereo audio.

        Args:
            target_wavs: (N, 2, T) waveforms at 48 kHz (zero-padded).
            wav_lengths: per-sample valid lengths in 48 kHz samples.

        Returns:
            List of N tensors, each (num_frames_i, hidden_size): the
            concatenated mHuBERT features over that sample's 30s chunks.
        """
        # Step 1: Preprocess audio
        # Input: N x 2 x T (48kHz, stereo) -> N x T (16kHz, mono)
        mhubert_input_wavs_mono_16k = self.resampler_mhubert(target_wavs.mean(dim=1))
        bsz = target_wavs.shape[0]
        actual_lengths_16k = wav_lengths // 3  # Convert lengths from 48kHz to 16kHz

        # Step 2: Zero-mean unit-variance normalization (only on actual audio)
        means = torch.stack([mhubert_input_wavs_mono_16k[i, :actual_lengths_16k[i]].mean()
                             for i in range(bsz)])
        vars = torch.stack([mhubert_input_wavs_mono_16k[i, :actual_lengths_16k[i]].var()
                            for i in range(bsz)])
        mhubert_input_wavs_mono_16k = (mhubert_input_wavs_mono_16k - means.view(-1, 1)) / \
            torch.sqrt(vars.view(-1, 1) + 1e-7)

        # Step 3: Define chunk size for MHubert (30 seconds at 16kHz)
        chunk_size = 16000 * 30  # 30 seconds = 480,000 samples

        # Step 4: Split audio into chunks
        num_chunks_per_audio = (actual_lengths_16k + chunk_size - 1) // chunk_size  # Ceiling division
        all_chunks = []
        chunk_actual_lengths = []

        for i in range(bsz):
            audio = mhubert_input_wavs_mono_16k[i]
            actual_length = actual_lengths_16k[i]
            for start in range(0, actual_length, chunk_size):
                end = min(start + chunk_size, actual_length)
                chunk = audio[start:end]
                if len(chunk) < chunk_size:
                    chunk = F.pad(chunk, (0, chunk_size - len(chunk)))  # Pad with zeros
                all_chunks.append(chunk)
                chunk_actual_lengths.append(end - start)

        # Step 5: Stack all chunks for batch inference
        all_chunks = torch.stack(all_chunks, dim=0)  # Shape: (total_chunks, chunk_size)

        # Step 6: Batch inference with MHubert model
        with torch.no_grad():
            mhubert_ssl_hidden_states = self.hubert_model(all_chunks).last_hidden_state
            # Shape: (total_chunks, seq_len, hidden_size)

        # Step 7: Compute number of features per chunk (assuming model stride of 320)
        chunk_num_features = [(length + 319) // 320 for length in chunk_actual_lengths]

        # Step 8: Trim hidden states to remove padding effects
        chunk_hidden_states = [mhubert_ssl_hidden_states[i, :chunk_num_features[i], :] for i in range(len(all_chunks))]

        # Step 9: Reorganize hidden states by original audio
        mhubert_ssl_hidden_states_list = []
        chunk_idx = 0
        for i in range(bsz):
            audio_chunks = chunk_hidden_states[chunk_idx:chunk_idx + num_chunks_per_audio[i]]
            audio_hidden = torch.cat(audio_chunks, dim=0)  # Concatenate chunks for this audio
            mhubert_ssl_hidden_states_list.append(audio_hidden)
            chunk_idx += num_chunks_per_audio[i]
        return mhubert_ssl_hidden_states_list
|
||||
|
||||
def preprocess(self, batch, train=True):
    """Turn a raw dataloader batch into model-ready tensors.

    Encodes text prompts, VAE-encodes the target audio into latents, and —
    in training mode only — computes SSL teacher features (MERT + MHubert)
    and applies classifier-free-guidance condition dropout.

    Args:
        batch: dict with keys "target_wavs", "wav_lengths", "prompts",
            "speaker_embs", "keys", "lyric_token_ids", "lyric_masks".
        train: when True, run SSL teachers and randomly zero out conditions
            for CFG; when False (inference), SSL states stay None and no
            dropout is applied.

    Returns:
        Tuple of (keys, target_latents, attention_mask,
        encoder_text_hidden_states, text_attention_mask, speaker_embds,
        lyric_token_ids, lyric_mask, mert_ssl_hidden_states,
        mhubert_ssl_hidden_states).
    """
    target_wavs = batch["target_wavs"]
    wav_lengths = batch["wav_lengths"]

    dtype = target_wavs.dtype
    bs = target_wavs.shape[0]
    device = target_wavs.device

    # SSL constraints (teacher features); only computed during training.
    mert_ssl_hidden_states = None
    mhubert_ssl_hidden_states = None
    # is_long = target_wavs.shape[-1] >= 48000 * 150
    if train:
        with torch.amp.autocast(device_type="cuda", dtype=dtype):
            mert_ssl_hidden_states = self.infer_mert_ssl(target_wavs, wav_lengths)
            # mhubert_ssl_hidden_states = self.infer_mhubert_ssl(batch["vocal_wavs"], wav_lengths)
            mhubert_ssl_hidden_states = self.infer_mhubert_ssl(target_wavs, wav_lengths)

    # 1: text embedding
    texts = batch["prompts"]
    encoder_text_hidden_states, text_attention_mask = self.lyric_processor.get_text_embeddings(texts, device)
    encoder_text_hidden_states = encoder_text_hidden_states.to(dtype)

    # Audio -> latent; mask is all-ones over the latent time axis
    # (per-sample padding is presumably handled inside vae.encode — TODO confirm).
    target_latents, _ = self.vae.encode(target_wavs, wav_lengths)
    attention_mask = torch.ones(bs, target_latents.shape[-1], device=device, dtype=dtype)

    speaker_embds = batch["speaker_embs"].to(dtype)
    keys = batch["keys"]
    lyric_token_ids = batch["lyric_token_ids"]
    lyric_mask = batch["lyric_masks"]

    # Pretrain stage 2 needs CFG: randomly drop conditions with p=0.15 so the
    # model learns an unconditional branch usable for guidance at inference.
    if train:
        # 1 = keep condition, 0 = drop (per sample).
        full_cfg_condition_mask = torch.where(
            (torch.rand(size=(bs,), device=device) < 0.15),
            torch.zeros(size=(bs,), device=device),
            torch.ones(size=(bs,), device=device)
        ).long()
        # N x T x 768
        encoder_text_hidden_states = torch.where(full_cfg_condition_mask.unsqueeze(1).unsqueeze(1).bool(), encoder_text_hidden_states, torch.zeros_like(encoder_text_hidden_states))

        # full_cfg_condition_mask = torch.where(
        #     (torch.rand(size=(bs,), device=device) < 0.50),
        #     torch.zeros(size=(bs,), device=device),
        #     torch.ones(size=(bs,), device=device)
        # ).long()
        # # N x 512
        # speaker_embds = torch.where(full_cfg_condition_mask.unsqueeze(1).bool(), speaker_embds, torch.zeros_like(speaker_embds))

        # Lyrics: an independent 15% dropout draw (separate from the text one).
        full_cfg_condition_mask = torch.where(
            (torch.rand(size=(bs,), device=device) < 0.15),
            torch.zeros(size=(bs,), device=device),
            torch.ones(size=(bs,), device=device)
        ).long()
        lyric_token_ids = torch.where(full_cfg_condition_mask.unsqueeze(1).bool(), lyric_token_ids, torch.zeros_like(lyric_token_ids))
        lyric_mask = torch.where(full_cfg_condition_mask.unsqueeze(1).bool(), lyric_mask, torch.zeros_like(lyric_mask))

    return (
        keys,
        target_latents,
        attention_mask,
        encoder_text_hidden_states,
        text_attention_mask,
        speaker_embds,
        lyric_token_ids,
        lyric_mask,
        mert_ssl_hidden_states,
        mhubert_ssl_hidden_states,
    )
|
||||
|
||||
def get_scheduler(self):
    """Build a fresh flow-matching Euler scheduler from the module hparams."""
    scheduler = FlowMatchEulerDiscreteScheduler(
        num_train_timesteps=self.T,
        shift=self.hparams.shift,
    )
    return scheduler
|
||||
|
||||
def configure_optimizers(self):
    """Lightning hook: AdamW over the transformer parameters plus an LR schedule.

    Returns ([optimizer], lr_scheduler) as expected by Lightning.
    """
    param_groups = [
        {'params': self.transformers.parameters()},
    ]
    optimizer = torch.optim.AdamW(
        params=param_groups,
        lr=self.hparams.learning_rate,
        weight_decay=self.hparams.weight_decay,
        betas=(0.8, 0.9),
    )

    max_steps = self.hparams.max_steps
    # Decay is stepped: the final 10% of training is split into 20% chunks.
    decay_interval = int(max_steps * (1 - 0.9) * 0.2)
    lr_scheduler = configure_lr_scheduler(
        optimizer,
        total_steps_per_epoch=max_steps,
        epochs=1,
        decay_ratio=0.9,
        decay_interval=decay_interval,
        warmup_iters=4000,
    )
    return [optimizer], lr_scheduler
|
||||
|
||||
def get_sd3_sigmas(self, timesteps, device, n_dim=4, dtype=torch.float32):
    """Look up the scheduler sigma for each given timestep.

    Returns a tensor of shape (len(timesteps), 1, ..., 1) with trailing
    singleton dims appended until it has n_dim dimensions, so it broadcasts
    against a latent of that rank.
    """
    all_sigmas = self.scheduler.sigmas.to(device=device, dtype=dtype)
    schedule_timesteps = self.scheduler.timesteps.to(device)
    timesteps = timesteps.to(device)

    # Map each requested timestep to its index in the schedule.
    indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
    sigma = all_sigmas[indices].flatten()

    # Pad trailing dims for broadcasting.
    while sigma.ndim < n_dim:
        sigma = sigma.unsqueeze(-1)
    return sigma
|
||||
|
||||
def get_timestep(self, bsz, device):
    """Sample a batch of training timesteps according to the configured density.

    Supported ``hparams.timestep_densities_type`` values:
      * "logit_normal" — SD3 (sec. 3.1): u ~ N(logit_mean, logit_std) pushed
        through a sigmoid.
      * "u_shape" — U-shaped density, closed-form inverse-CDF with a=4.

    Args:
        bsz: number of timesteps to draw.
        device: device the returned timesteps are moved to.

    Returns:
        Tensor of shape (bsz,) of scheduler timesteps.

    Raises:
        ValueError: on an unknown density type (previously this silently fell
        through and crashed with UnboundLocalError on ``timesteps``).
    """
    density = self.hparams.timestep_densities_type
    if density == "logit_normal":
        # Sample u from a normal and map through the standard logistic function.
        u = torch.normal(mean=self.hparams.logit_mean, std=self.hparams.logit_std, size=(bsz,), device="cpu")
        u = torch.nn.functional.sigmoid(u)
    elif density == "u_shape":
        # Parameter a controls the U shape; a=4 reported to work well.
        a = 4.0
        # Inverse-CDF sampling: u = 0.5 + (1/a) * asinh(sinh(a/2) * (2v - 1))
        v = torch.rand(bsz)
        s = torch.sinh(torch.tensor(a / 2))
        u = 0.5 + (1.0 / a) * torch.asinh(s * (2 * v - 1))
        # Guard against tiny numerical excursions outside [0, 1].
        u = torch.clamp(u, 0.0, 1.0)
    else:
        raise ValueError(f"Unknown timestep_densities_type: {density!r}")

    # Map continuous u in [0, 1] onto discrete scheduler timesteps.
    num_train_timesteps = self.scheduler.config.num_train_timesteps
    indices = (u * num_train_timesteps).long()
    indices = torch.clamp(indices, 0, num_train_timesteps - 1)
    timesteps = self.scheduler.timesteps[indices].to(device)

    return timesteps
|
||||
|
||||
def run_step(self, batch, batch_idx):
    """One flow-matching training step: noise latents, predict, compute loss.

    Also triggers periodic evaluation plotting and logs denoising loss,
    per-teacher projection losses, the combined loss, and the current LR.

    Returns:
        Scalar loss tensor (denoising MSE + ssl_coeff * mean projection loss).
    """
    self.plot_step(batch, batch_idx)
    (
        keys,
        target_latents,
        attention_mask,
        encoder_text_hidden_states,
        text_attention_mask,
        speaker_embds,
        lyric_token_ids,
        lyric_mask,
        mert_ssl_hidden_states,
        mhubert_ssl_hidden_states,
    ) = self.preprocess(batch)

    target_image = target_latents
    device = target_image.device
    dtype = target_image.dtype
    # check dtype
    # logger.info(f"target_image dtype: {target_image.dtype} model dtype: {self.transformers.dtype}")
    # step 1: sample noise and draw per-sample timesteps
    noise = torch.randn_like(target_image, device=device)
    bsz = target_image.shape[0]
    timesteps = self.get_timestep(bsz, device)

    # Add noise according to flow matching:
    # x_t = sigma * noise + (1 - sigma) * x_0
    sigmas = self.get_sd3_sigmas(timesteps=timesteps, device=device, n_dim=target_image.ndim, dtype=dtype)
    noisy_image = sigmas * noise + (1.0 - sigmas) * target_image

    # This is the flow-matching target for vanilla SD3.
    target = target_image

    # Collect SSL teacher states (MERT / MHubert) for the projection losses.
    all_ssl_hiden_states = []
    if mert_ssl_hidden_states is not None:
        all_ssl_hiden_states.append(mert_ssl_hidden_states)
    if mhubert_ssl_hidden_states is not None:
        all_ssl_hiden_states.append(mhubert_ssl_hidden_states)

    # N x H -> N x c x W x H
    x = noisy_image
    # step 5: predict noise
    transformer_output = self.transformers(
        hidden_states=x,
        attention_mask=attention_mask,
        encoder_text_hidden_states=encoder_text_hidden_states,
        text_attention_mask=text_attention_mask,
        speaker_embeds=speaker_embds,
        lyric_token_idx=lyric_token_ids,
        lyric_mask=lyric_mask,
        timestep=timesteps.to(device).to(dtype),
        ssl_hidden_states=all_ssl_hiden_states,
    )
    model_pred = transformer_output.sample
    proj_losses = transformer_output.proj_losses

    # Follow: Section 5 of https://arxiv.org/abs/2206.00364.
    # Preconditioning of the model outputs: recover x_0 estimate from the
    # velocity-style prediction.
    model_pred = model_pred * (-sigmas) + noisy_image

    # Compute loss only where the attention mask is 1 (no padding).
    # N x T -> N x c x W x T
    mask = attention_mask.unsqueeze(1).unsqueeze(1).expand(-1, target_image.shape[1], target_image.shape[2], -1)

    selected_model_pred = (model_pred * mask).reshape(bsz, -1).contiguous()
    selected_target = (target * mask).reshape(bsz, -1).contiguous()

    loss = F.mse_loss(selected_model_pred, selected_target, reduction="none")
    loss = loss.mean(1)
    # Scale each sample's loss by its valid-token fraction before averaging.
    loss = loss * mask.reshape(bsz, -1).mean(1)
    loss = loss.mean()

    prefix = "train"

    self.log(f"{prefix}/denoising_loss", loss, on_step=True, on_epoch=False, prog_bar=True)

    # Average the per-teacher SSL projection losses and add them to the total.
    total_proj_loss = 0.0
    for k, v in proj_losses:
        self.log(f"{prefix}/{k}_loss", v, on_step=True, on_epoch=False, prog_bar=True)
        total_proj_loss += v

    if len(proj_losses) > 0:
        total_proj_loss = total_proj_loss / len(proj_losses)

    loss = loss + total_proj_loss * self.ssl_coeff
    self.log(f"{prefix}/loss", loss, on_step=True, on_epoch=False, prog_bar=True)

    learning_rate = self.lr_schedulers().get_last_lr()[0]
    self.log(f"{prefix}/learning_rate", learning_rate, on_step=True, on_epoch=False, prog_bar=True)
    # with torch.autograd.detect_anomaly():
    #     self.manual_backward(loss)
    return loss
|
||||
|
||||
def training_step(self, batch, batch_idx):
    """Lightning training hook; all work is delegated to run_step."""
    loss = self.run_step(batch, batch_idx)
    return loss
|
||||
|
||||
@torch.no_grad()
def diffusion_process(
    self,
    duration,
    encoder_text_hidden_states,
    text_attention_mask,
    speaker_embds,
    lyric_token_ids,
    lyric_mask,
    random_generators=None,
    infer_steps=60,
    guidance_scale=15.0,
    omega_scale=10.0,
):
    """Run the full denoising loop and return generated latents.

    Uses APG (adaptive projected guidance) with a momentum buffer instead of
    plain CFG when guidance is enabled.

    Args:
        duration: target audio length in seconds; converted to a latent frame
            count via 44100 Hz / 512 hop / 8x patching.
        encoder_text_hidden_states, text_attention_mask: text conditioning.
        speaker_embds, lyric_token_ids, lyric_mask: remaining conditions.
        random_generators: optional per-sample torch.Generator list for
            reproducible initial noise.
        infer_steps: number of scheduler steps.
        guidance_scale: 0.0 or 1.0 disables classifier-free guidance.
        omega_scale: extra ``omega`` argument forwarded to scheduler.step.

    Returns:
        Latent tensor of shape (bsz, 8, 16, frame_length).
    """

    do_classifier_free_guidance = True
    if guidance_scale == 0.0 or guidance_scale == 1.0:
        do_classifier_free_guidance = False

    device = encoder_text_hidden_states.device
    dtype = encoder_text_hidden_states.dtype
    bsz = encoder_text_hidden_states.shape[0]

    # Inference uses a fixed shift=3.0 schedule (independent of hparams.shift).
    scheduler = FlowMatchEulerDiscreteScheduler(
        num_train_timesteps=1000,
        shift=3.0,
    )

    frame_length = int(duration * 44100 / 512 / 8)
    timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps=infer_steps, device=device, timesteps=None)

    target_latents = randn_tensor(shape=(bsz, 8, 16, frame_length), generator=random_generators, device=device, dtype=dtype)
    attention_mask = torch.ones(bsz, frame_length, device=device, dtype=dtype)
    if do_classifier_free_guidance:
        # Double the batch: [conditional | unconditional(all-zero conditions)].
        attention_mask = torch.cat([attention_mask] * 2, dim=0)
        encoder_text_hidden_states = torch.cat([encoder_text_hidden_states, torch.zeros_like(encoder_text_hidden_states)], 0)
        text_attention_mask = torch.cat([text_attention_mask] * 2, dim=0)

        speaker_embds = torch.cat([speaker_embds, torch.zeros_like(speaker_embds)], 0)

        lyric_token_ids = torch.cat([lyric_token_ids, torch.zeros_like(lyric_token_ids)], 0)
        lyric_mask = torch.cat([lyric_mask, torch.zeros_like(lyric_mask)], 0)

    # Momentum state shared across all denoising steps for APG.
    momentum_buffer = MomentumBuffer()

    for i, t in tqdm(enumerate(timesteps), total=num_inference_steps):
        # expand the latents if we are doing classifier free guidance
        latents = target_latents
        latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
        timestep = t.expand(latent_model_input.shape[0])
        noise_pred = self.transformers(
            hidden_states=latent_model_input,
            attention_mask=attention_mask,
            encoder_text_hidden_states=encoder_text_hidden_states,
            text_attention_mask=text_attention_mask,
            speaker_embeds=speaker_embds,
            lyric_token_idx=lyric_token_ids,
            lyric_mask=lyric_mask,
            timestep=timestep,
        ).sample

        if do_classifier_free_guidance:
            noise_pred_with_cond, noise_pred_uncond = noise_pred.chunk(2)
            # APG: projected guidance update instead of the plain CFG formula.
            noise_pred = apg_forward(
                pred_cond=noise_pred_with_cond,
                pred_uncond=noise_pred_uncond,
                guidance_scale=guidance_scale,
                momentum_buffer=momentum_buffer,
            )

        target_latents = scheduler.step(model_output=noise_pred, timestep=t, sample=target_latents, return_dict=False, omega=omega_scale)[0]

    return target_latents
|
||||
|
||||
def predict_step(self, batch, batch_idx):
    """Generate audio for a batch and return it alongside the ground truth.

    Fixed inference hyperparameters (60 steps, guidance 15.0, omega 10.0) and
    a fixed master seed (1234) from which per-sample generator seeds are drawn,
    so repeated runs over the same batch order are reproducible.

    Returns:
        dict with target/pred waveforms, keys, prompts, lyric chunks, sample
        rate, and the per-sample seeds used.
    """
    (
        keys,
        target_latents,
        attention_mask,
        encoder_text_hidden_states,
        text_attention_mask,
        speaker_embds,
        lyric_token_ids,
        lyric_mask,
        mert_ssl_hidden_states,  # unused at inference (None when train=False)
        mhubert_ssl_hidden_states,  # unused at inference (None when train=False)
    ) = self.preprocess(batch, train=False)

    infer_steps = 60
    guidance_scale = 15.0
    omega_scale = 10.0
    seed_num = 1234
    random.seed(seed_num)
    bsz = target_latents.shape[0]
    # One generator per sample; seeds recorded for reproducibility reporting.
    random_generators = [torch.Generator(device=self.device) for _ in range(bsz)]
    seeds = []
    for i in range(bsz):
        seed = random.randint(0, 2**32 - 1)
        random_generators[i].manual_seed(seed)
        seeds.append(seed)
    # presumably fix_cut_level is in 10-second units; verify against config
    duration = self.hparams.fix_cut_level * 10
    pred_latents = self.diffusion_process(
        duration=duration,
        encoder_text_hidden_states=encoder_text_hidden_states,
        text_attention_mask=text_attention_mask,
        speaker_embds=speaker_embds,
        lyric_token_ids=lyric_token_ids,
        lyric_mask=lyric_mask,
        random_generators=random_generators,
        infer_steps=infer_steps,
        guidance_scale=guidance_scale,
        omega_scale=omega_scale,
    )

    audio_lengths = batch["wav_lengths"]
    sr, pred_wavs = self.vae.decode(pred_latents, audio_lengths=audio_lengths, sr=48000)
    return {
        "target_wavs": batch["target_wavs"],
        "pred_wavs": pred_wavs,
        "keys": keys,
        "prompts": batch["prompts"],
        "candidate_lyric_chunks": batch["candidate_lyric_chunks"],
        "sr": sr,
        "seeds": seeds,
    }
|
||||
|
||||
def construct_lyrics(self, candidate_lyric_chunk):
    """Join the ``lyric`` field of every chunk into one newline-separated string."""
    return "\n".join(chunk["lyric"] for chunk in candidate_lyric_chunk)
|
||||
|
||||
def plot_step(self, batch, batch_idx):
    """Every ``every_plot_step`` batches, generate audio on rank 0 and dump
    target/pred FLACs plus a key/prompt/lyric/seed text file under the
    logger's eval_results directory.

    NOTE(review): torch.distributed.get_rank() raises if the default process
    group is not initialized — this assumes distributed training; confirm for
    single-process runs.
    """

    # Skip unless this is a plot step AND we are the first rank/device.
    if batch_idx % self.hparams.every_plot_step != 0 or self.local_rank != 0 or torch.distributed.get_rank() != 0 or torch.cuda.current_device() != 0:
        return
    results = self.predict_step(batch, batch_idx)

    target_wavs = results["target_wavs"]
    pred_wavs = results["pred_wavs"]
    keys = results["keys"]
    prompts = results["prompts"]
    candidate_lyric_chunks = results["candidate_lyric_chunks"]
    sr = results["sr"]
    seeds = results["seeds"]
    i = 0  # running sample index used to disambiguate filenames
    for key, target_wav, pred_wav, prompt, candidate_lyric_chunk, seed in zip(keys, target_wavs, pred_wavs, prompts, candidate_lyric_chunks, seeds):
        key = key  # no-op rebinding kept from the original
        prompt = prompt  # no-op rebinding kept from the original
        lyric = self.construct_lyrics(candidate_lyric_chunk)
        # Human-readable metadata sidecar for this sample.
        key_prompt_lyric = f"# KEY\n\n{key}\n\n\n# PROMPT\n\n{prompt}\n\n\n# LYRIC\n\n{lyric}\n\n# SEED\n\n{seed}\n\n"
        log_dir = self.logger.log_dir
        save_dir = f"{log_dir}/eval_results/step_{self.global_step}"
        if not os.path.exists(save_dir):
            os.makedirs(save_dir, exist_ok=True)
        torchaudio.save(f"{save_dir}/target_wav_{key}_{i}.flac", target_wav.float().cpu(), sr)
        torchaudio.save(f"{save_dir}/pred_wav_{key}_{i}.flac", pred_wav.float().cpu(), sr)
        with open(f"{save_dir}/key_prompt_lyric_{key}_{i}.txt", "w") as f:
            f.write(key_prompt_lyric)
        i += 1
|
||||
1278
models/attention.py
Normal file
1278
models/attention.py
Normal file
File diff suppressed because it is too large
Load Diff
23
models/config_sana_text2music_dcae_0225_3.5B_simple.json
Normal file
23
models/config_sana_text2music_dcae_0225_3.5B_simple.json
Normal file
@@ -0,0 +1,23 @@
|
||||
{
|
||||
"_class_name": "Transformer2DModel",
|
||||
"_diffusers_version": "0.27.2",
|
||||
"in_channels": 8,
|
||||
"num_layers": 24,
|
||||
"inner_dim": 2560,
|
||||
"attention_head_dim": 128,
|
||||
"num_attention_heads": 20,
|
||||
"mlp_ratio": 2.5,
|
||||
"out_channels": 8,
|
||||
"max_position": 32768,
|
||||
"rope_theta": 1000000.0,
|
||||
"speaker_embedding_dim": 512,
|
||||
"text_embedding_dim": 768,
|
||||
"ssl_encoder_depths": [8, 8],
|
||||
"ssl_names": ["mert", "m-hubert"],
|
||||
"ssl_latent_dims": [1024, 768],
|
||||
"patch_size": [16, 1],
|
||||
"max_height": 16,
|
||||
"max_width": 32768,
|
||||
"lyric_encoder_vocab_size": 6693,
|
||||
"lyric_hidden_size": 1024
|
||||
}
|
||||
1529
models/customer_attention_processor.py
Normal file
1529
models/customer_attention_processor.py
Normal file
File diff suppressed because it is too large
Load Diff
1070
models/lyrics_utils/lyric_encoder.py
Normal file
1070
models/lyrics_utils/lyric_encoder.py
Normal file
File diff suppressed because it is too large
Load Diff
66
models/lyrics_utils/lyric_normalizer.py
Normal file
66
models/lyrics_utils/lyric_normalizer.py
Normal file
@@ -0,0 +1,66 @@
|
||||
import re
|
||||
from opencc import OpenCC
|
||||
|
||||
|
||||
t2s_converter = OpenCC('t2s')
|
||||
s2t_converter = OpenCC('s2t')
|
||||
|
||||
|
||||
# Matches emoticon emoji (U+1F600-U+1F64F) for removal.
EMOJI_PATTERN = re.compile(
    "["
    "\U0001F600-\U0001F64F"  # Emoticons
    "]+", flags=re.UNICODE
)

# Translation table: map '-' and the ideographic space to a plain space and
# strip common ASCII / full-width punctuation.
TRANSLATION_TABLE = str.maketrans({
    '-': ' ',  # hyphen becomes a space
    ',': None,
    '.': None,
    ',': None,
    '。': None,
    '!': None,
    '!': None,
    '?': None,
    '?': None,
    '…': None,
    ';': None,
    ';': None,
    ':': None,
    ':': None,
    '\u3000': ' ',  # ideographic (full-width) space becomes a plain space
})

# Matches parenthesized or bracketed content, e.g. "(chorus)" or "[x2]".
# NOTE(review): defined but never applied in normalize_text — confirm intent.
BACKSLASH_PATTERN = re.compile(r'\(.*?\)|\[.*?\]')

# Collapses internal runs of whitespace to a single space while leaving
# leading/trailing whitespace alone. Raw string fixes the invalid '\s'
# escape warning the original non-raw literal produced.
SPACE_PATTERN = re.compile(r'(?<!^)\s+(?!$)')


def normalize_text(text, language, strip=True):
    """Normalize lyric text for tokenization.

    Drops punctuation and emoji, collapses internal whitespace, lowercases,
    and applies per-language script conversion (zh: traditional->simplified,
    yue: simplified->traditional).

    Args:
        text: input string.
        language: ISO-ish language code ("zh", "yue", ...).
        strip: when True, trim leading/trailing whitespace.

    Returns:
        The normalized string.
    """
    # Step 1: replace '-' with ' ' and strip punctuation.
    text = text.translate(TRANSLATION_TABLE)

    # Step 2: remove emoji.
    text = EMOJI_PATTERN.sub('', text)

    # Step 3: collapse consecutive internal whitespace into one space.
    text = SPACE_PATTERN.sub(' ', text)

    # Step 4: trim surrounding whitespace if requested.
    if strip:
        text = text.strip()

    # Step 5: lowercase.
    text = text.lower()

    # Step 6: language-specific script conversion.
    if language == "zh":
        text = t2s_converter.convert(text)
    if language == "yue":
        text = s2t_converter.convert(text)
    # add other languages here as needed
    return text
|
||||
883
models/lyrics_utils/lyric_tokenizer.py
Normal file
883
models/lyrics_utils/lyric_tokenizer.py
Normal file
@@ -0,0 +1,883 @@
|
||||
import os
|
||||
import re
|
||||
import textwrap
|
||||
from functools import cached_property
|
||||
|
||||
import pypinyin
|
||||
import torch
|
||||
from hangul_romanize import Transliter
|
||||
from hangul_romanize.rule import academic
|
||||
from num2words import num2words
|
||||
from spacy.lang.ar import Arabic
|
||||
from spacy.lang.en import English
|
||||
from spacy.lang.es import Spanish
|
||||
from spacy.lang.ja import Japanese
|
||||
from spacy.lang.zh import Chinese
|
||||
from tokenizers import Tokenizer
|
||||
|
||||
from .zh_num2words import TextNorm as zh_num2words
|
||||
from typing import Dict, List, Optional, Set, Union
|
||||
|
||||
|
||||
#copy from https://github.com/coqui-ai/TTS/blob/dbf1a08a0d4e47fdad6172e433eeb34bc6b13b4e/TTS/tts/layers/xtts/tokenizer.py
|
||||
def get_spacy_lang(lang):
    """Return a blank spaCy language pipeline for *lang*.

    Only zh/ja/ar/es get dedicated tokenizers; for every other code the
    English pipeline does the job.
    """
    if lang == "zh":
        return Chinese()
    if lang == "ja":
        return Japanese()
    if lang == "ar":
        return Arabic()
    if lang == "es":
        return Spanish()
    # For most languages, English does the job.
    return English()
|
||||
|
||||
|
||||
def split_sentence(text, lang, text_split_length=250):
    """Split *text* into chunks no longer than ``text_split_length``.

    Short inputs (or text_split_length=None) are returned as a single-element
    list. Long inputs are segmented with spaCy's sentencizer; adjacent
    sentences are greedily packed into chunks, and any single sentence longer
    than the limit is hard-wrapped with textwrap.

    Returns:
        list[str] of text chunks.
    """
    text_splits = []
    if text_split_length is not None and len(text) >= text_split_length:
        # Seed with an empty accumulator chunk that sentences are packed into.
        text_splits.append("")
        nlp = get_spacy_lang(lang)
        nlp.add_pipe("sentencizer")
        doc = nlp(text)
        for sentence in doc.sents:
            if len(text_splits[-1]) + len(str(sentence)) <= text_split_length:
                # if the last sentence + the current sentence is less than the text_split_length
                # then add the current sentence to the last sentence
                text_splits[-1] += " " + str(sentence)
                text_splits[-1] = text_splits[-1].lstrip()
            elif len(str(sentence)) > text_split_length:
                # if the current sentence is greater than the text_split_length
                for line in textwrap.wrap(
                    str(sentence),
                    width=text_split_length,
                    drop_whitespace=True,
                    break_on_hyphens=False,
                    tabsize=1,
                ):
                    text_splits.append(str(line))
            else:
                # sentence does not fit in the current chunk: start a new one
                text_splits.append(str(sentence))

        # Drop the seed chunk if it never received any sentence.
        if len(text_splits) > 1:
            if text_splits[0] == "":
                del text_splits[0]
    else:
        text_splits = [text.lstrip()]

    return text_splits
|
||||
|
||||
|
||||
_whitespace_re = re.compile(r"\s+")
|
||||
|
||||
# List of (regular expression, replacement) pairs for abbreviations:
|
||||
_abbreviations = {
|
||||
"en": [
|
||||
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
|
||||
for x in [
|
||||
("mrs", "misess"),
|
||||
("mr", "mister"),
|
||||
("dr", "doctor"),
|
||||
("st", "saint"),
|
||||
("co", "company"),
|
||||
("jr", "junior"),
|
||||
("maj", "major"),
|
||||
("gen", "general"),
|
||||
("drs", "doctors"),
|
||||
("rev", "reverend"),
|
||||
("lt", "lieutenant"),
|
||||
("hon", "honorable"),
|
||||
("sgt", "sergeant"),
|
||||
("capt", "captain"),
|
||||
("esq", "esquire"),
|
||||
("ltd", "limited"),
|
||||
("col", "colonel"),
|
||||
("ft", "fort"),
|
||||
]
|
||||
],
|
||||
"es": [
|
||||
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
|
||||
for x in [
|
||||
("sra", "señora"),
|
||||
("sr", "señor"),
|
||||
("dr", "doctor"),
|
||||
("dra", "doctora"),
|
||||
("st", "santo"),
|
||||
("co", "compañía"),
|
||||
("jr", "junior"),
|
||||
("ltd", "limitada"),
|
||||
]
|
||||
],
|
||||
"fr": [
|
||||
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
|
||||
for x in [
|
||||
("mme", "madame"),
|
||||
("mr", "monsieur"),
|
||||
("dr", "docteur"),
|
||||
("st", "saint"),
|
||||
("co", "compagnie"),
|
||||
("jr", "junior"),
|
||||
("ltd", "limitée"),
|
||||
]
|
||||
],
|
||||
"de": [
|
||||
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
|
||||
for x in [
|
||||
("fr", "frau"),
|
||||
("dr", "doktor"),
|
||||
("st", "sankt"),
|
||||
("co", "firma"),
|
||||
("jr", "junior"),
|
||||
]
|
||||
],
|
||||
"pt": [
|
||||
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
|
||||
for x in [
|
||||
("sra", "senhora"),
|
||||
("sr", "senhor"),
|
||||
("dr", "doutor"),
|
||||
("dra", "doutora"),
|
||||
("st", "santo"),
|
||||
("co", "companhia"),
|
||||
("jr", "júnior"),
|
||||
("ltd", "limitada"),
|
||||
]
|
||||
],
|
||||
"it": [
|
||||
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
|
||||
for x in [
|
||||
# ("sig.ra", "signora"),
|
||||
("sig", "signore"),
|
||||
("dr", "dottore"),
|
||||
("st", "santo"),
|
||||
("co", "compagnia"),
|
||||
("jr", "junior"),
|
||||
("ltd", "limitata"),
|
||||
]
|
||||
],
|
||||
"pl": [
|
||||
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
|
||||
for x in [
|
||||
("p", "pani"),
|
||||
("m", "pan"),
|
||||
("dr", "doktor"),
|
||||
("sw", "święty"),
|
||||
("jr", "junior"),
|
||||
]
|
||||
],
|
||||
"ar": [
|
||||
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
|
||||
for x in [
|
||||
# There are not many common abbreviations in Arabic as in English.
|
||||
]
|
||||
],
|
||||
"zh": [
|
||||
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
|
||||
for x in [
|
||||
# Chinese doesn't typically use abbreviations in the same way as Latin-based scripts.
|
||||
]
|
||||
],
|
||||
"cs": [
|
||||
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
|
||||
for x in [
|
||||
("dr", "doktor"), # doctor
|
||||
("ing", "inženýr"), # engineer
|
||||
("p", "pan"), # Could also map to pani for woman but no easy way to do it
|
||||
# Other abbreviations would be specialized and not as common.
|
||||
]
|
||||
],
|
||||
"ru": [
|
||||
(re.compile("\\b%s\\b" % x[0], re.IGNORECASE), x[1])
|
||||
for x in [
|
||||
("г-жа", "госпожа"), # Mrs.
|
||||
("г-н", "господин"), # Mr.
|
||||
("д-р", "доктор"), # doctor
|
||||
# Other abbreviations are less common or specialized.
|
||||
]
|
||||
],
|
||||
"nl": [
|
||||
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
|
||||
for x in [
|
||||
("dhr", "de heer"), # Mr.
|
||||
("mevr", "mevrouw"), # Mrs.
|
||||
("dr", "dokter"), # doctor
|
||||
("jhr", "jonkheer"), # young lord or nobleman
|
||||
# Dutch uses more abbreviations, but these are the most common ones.
|
||||
]
|
||||
],
|
||||
"tr": [
|
||||
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
|
||||
for x in [
|
||||
("b", "bay"), # Mr.
|
||||
("byk", "büyük"), # büyük
|
||||
("dr", "doktor"), # doctor
|
||||
# Add other Turkish abbreviations here if needed.
|
||||
]
|
||||
],
|
||||
"hu": [
|
||||
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
|
||||
for x in [
|
||||
("dr", "doktor"), # doctor
|
||||
("b", "bácsi"), # Mr.
|
||||
("nőv", "nővér"), # nurse
|
||||
# Add other Hungarian abbreviations here if needed.
|
||||
]
|
||||
],
|
||||
"ko": [
|
||||
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
|
||||
for x in [
|
||||
# Korean doesn't typically use abbreviations in the same way as Latin-based scripts.
|
||||
]
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def expand_abbreviations_multilingual(text, lang="en"):
    """Expand common abbreviations (e.g. "dr." -> "doctor") for *lang*.

    Languages without an entry in ``_abbreviations`` leave the text unchanged
    (previously an unsupported language raised ``KeyError``).
    """
    for regex, replacement in _abbreviations.get(lang, []):
        text = re.sub(regex, replacement, text)
    return text
|
||||
|
||||
|
||||
_symbols_multilingual = {
|
||||
"en": [
|
||||
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
|
||||
for x in [
|
||||
("&", " and "),
|
||||
("@", " at "),
|
||||
("%", " percent "),
|
||||
("#", " hash "),
|
||||
("$", " dollar "),
|
||||
("£", " pound "),
|
||||
("°", " degree "),
|
||||
]
|
||||
],
|
||||
"es": [
|
||||
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
|
||||
for x in [
|
||||
("&", " y "),
|
||||
("@", " arroba "),
|
||||
("%", " por ciento "),
|
||||
("#", " numeral "),
|
||||
("$", " dolar "),
|
||||
("£", " libra "),
|
||||
("°", " grados "),
|
||||
]
|
||||
],
|
||||
"fr": [
|
||||
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
|
||||
for x in [
|
||||
("&", " et "),
|
||||
("@", " arobase "),
|
||||
("%", " pour cent "),
|
||||
("#", " dièse "),
|
||||
("$", " dollar "),
|
||||
("£", " livre "),
|
||||
("°", " degrés "),
|
||||
]
|
||||
],
|
||||
"de": [
|
||||
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
|
||||
for x in [
|
||||
("&", " und "),
|
||||
("@", " at "),
|
||||
("%", " prozent "),
|
||||
("#", " raute "),
|
||||
("$", " dollar "),
|
||||
("£", " pfund "),
|
||||
("°", " grad "),
|
||||
]
|
||||
],
|
||||
"pt": [
|
||||
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
|
||||
for x in [
|
||||
("&", " e "),
|
||||
("@", " arroba "),
|
||||
("%", " por cento "),
|
||||
("#", " cardinal "),
|
||||
("$", " dólar "),
|
||||
("£", " libra "),
|
||||
("°", " graus "),
|
||||
]
|
||||
],
|
||||
"it": [
|
||||
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
|
||||
for x in [
|
||||
("&", " e "),
|
||||
("@", " chiocciola "),
|
||||
("%", " per cento "),
|
||||
("#", " cancelletto "),
|
||||
("$", " dollaro "),
|
||||
("£", " sterlina "),
|
||||
("°", " gradi "),
|
||||
]
|
||||
],
|
||||
"pl": [
|
||||
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
|
||||
for x in [
|
||||
("&", " i "),
|
||||
("@", " małpa "),
|
||||
("%", " procent "),
|
||||
("#", " krzyżyk "),
|
||||
("$", " dolar "),
|
||||
("£", " funt "),
|
||||
("°", " stopnie "),
|
||||
]
|
||||
],
|
||||
"ar": [
|
||||
# Arabic
|
||||
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
|
||||
for x in [
|
||||
("&", " و "),
|
||||
("@", " على "),
|
||||
("%", " في المئة "),
|
||||
("#", " رقم "),
|
||||
("$", " دولار "),
|
||||
("£", " جنيه "),
|
||||
("°", " درجة "),
|
||||
]
|
||||
],
|
||||
"zh": [
|
||||
# Chinese
|
||||
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
|
||||
for x in [
|
||||
("&", " 和 "),
|
||||
("@", " 在 "),
|
||||
("%", " 百分之 "),
|
||||
("#", " 号 "),
|
||||
("$", " 美元 "),
|
||||
("£", " 英镑 "),
|
||||
("°", " 度 "),
|
||||
]
|
||||
],
|
||||
"cs": [
|
||||
# Czech
|
||||
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
|
||||
for x in [
|
||||
("&", " a "),
|
||||
("@", " na "),
|
||||
("%", " procento "),
|
||||
("#", " křížek "),
|
||||
("$", " dolar "),
|
||||
("£", " libra "),
|
||||
("°", " stupně "),
|
||||
]
|
||||
],
|
||||
"ru": [
|
||||
# Russian
|
||||
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
|
||||
for x in [
|
||||
("&", " и "),
|
||||
("@", " собака "),
|
||||
("%", " процентов "),
|
||||
("#", " номер "),
|
||||
("$", " доллар "),
|
||||
("£", " фунт "),
|
||||
("°", " градус "),
|
||||
]
|
||||
],
|
||||
"nl": [
|
||||
# Dutch
|
||||
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
|
||||
for x in [
|
||||
("&", " en "),
|
||||
("@", " bij "),
|
||||
("%", " procent "),
|
||||
("#", " hekje "),
|
||||
("$", " dollar "),
|
||||
("£", " pond "),
|
||||
("°", " graden "),
|
||||
]
|
||||
],
|
||||
"tr": [
|
||||
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
|
||||
for x in [
|
||||
("&", " ve "),
|
||||
("@", " at "),
|
||||
("%", " yüzde "),
|
||||
("#", " diyez "),
|
||||
("$", " dolar "),
|
||||
("£", " sterlin "),
|
||||
("°", " derece "),
|
||||
]
|
||||
],
|
||||
"hu": [
|
||||
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
|
||||
for x in [
|
||||
("&", " és "),
|
||||
("@", " kukac "),
|
||||
("%", " százalék "),
|
||||
("#", " kettőskereszt "),
|
||||
("$", " dollár "),
|
||||
("£", " font "),
|
||||
("°", " fok "),
|
||||
]
|
||||
],
|
||||
"ko": [
|
||||
# Korean
|
||||
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
|
||||
for x in [
|
||||
("&", " 그리고 "),
|
||||
("@", " 에 "),
|
||||
("%", " 퍼센트 "),
|
||||
("#", " 번호 "),
|
||||
("$", " 달러 "),
|
||||
("£", " 파운드 "),
|
||||
("°", " 도 "),
|
||||
]
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def expand_symbols_multilingual(text, lang="en"):
    """Replace symbols such as '&', '@', '%' with their spoken form for *lang*.

    Fix: the double-space cleanup was a no-op ``replace(" ", " ")``; it now
    actually collapses double spaces introduced by the padded replacements,
    matching its own comment.
    """
    for regex, replacement in _symbols_multilingual[lang]:
        text = re.sub(regex, replacement, text)
        text = text.replace("  ", " ")  # Ensure there are no double spaces
    return text.strip()
|
||||
|
||||
|
||||
_ordinal_re = {
|
||||
"en": re.compile(r"([0-9]+)(st|nd|rd|th)"),
|
||||
"es": re.compile(r"([0-9]+)(º|ª|er|o|a|os|as)"),
|
||||
"fr": re.compile(r"([0-9]+)(º|ª|er|re|e|ème)"),
|
||||
"de": re.compile(r"([0-9]+)(st|nd|rd|th|º|ª|\.(?=\s|$))"),
|
||||
"pt": re.compile(r"([0-9]+)(º|ª|o|a|os|as)"),
|
||||
"it": re.compile(r"([0-9]+)(º|°|ª|o|a|i|e)"),
|
||||
"pl": re.compile(r"([0-9]+)(º|ª|st|nd|rd|th)"),
|
||||
"ar": re.compile(r"([0-9]+)(ون|ين|ث|ر|ى)"),
|
||||
"cs": re.compile(r"([0-9]+)\.(?=\s|$)"), # In Czech, a dot is often used after the number to indicate ordinals.
|
||||
"ru": re.compile(r"([0-9]+)(-й|-я|-е|-ое|-ье|-го)"),
|
||||
"nl": re.compile(r"([0-9]+)(de|ste|e)"),
|
||||
"tr": re.compile(r"([0-9]+)(\.|inci|nci|uncu|üncü|\.)"),
|
||||
"hu": re.compile(r"([0-9]+)(\.|adik|edik|odik|edik|ödik|ödike|ik)"),
|
||||
"ko": re.compile(r"([0-9]+)(번째|번|차|째)"),
|
||||
}
|
||||
_number_re = re.compile(r"[0-9]+")
|
||||
_currency_re = {
|
||||
"USD": re.compile(r"((\$[0-9\.\,]*[0-9]+)|([0-9\.\,]*[0-9]+\$))"),
|
||||
"GBP": re.compile(r"((£[0-9\.\,]*[0-9]+)|([0-9\.\,]*[0-9]+£))"),
|
||||
"EUR": re.compile(r"(([0-9\.\,]*[0-9]+€)|((€[0-9\.\,]*[0-9]+)))"),
|
||||
}
|
||||
|
||||
_comma_number_re = re.compile(r"\b\d{1,3}(,\d{3})*(\.\d+)?\b")
|
||||
_dot_number_re = re.compile(r"\b\d{1,3}(.\d{3})*(\,\d+)?\b")
|
||||
_decimal_number_re = re.compile(r"([0-9]+[.,][0-9]+)")
|
||||
|
||||
|
||||
def _remove_commas(m):
|
||||
text = m.group(0)
|
||||
if "," in text:
|
||||
text = text.replace(",", "")
|
||||
return text
|
||||
|
||||
|
||||
def _remove_dots(m):
|
||||
text = m.group(0)
|
||||
if "." in text:
|
||||
text = text.replace(".", "")
|
||||
return text
|
||||
|
||||
|
||||
def _expand_decimal_point(m, lang="en"):
    """Spell out a decimal number such as "12,5"/"12.5" in words for *lang*."""
    normalized = m.group(1).replace(",", ".")
    # num2words uses the "cz" code for Czech rather than ISO "cs".
    target_lang = "cz" if lang == "cs" else lang
    return num2words(float(normalized), lang=target_lang)
|
||||
|
||||
|
||||
def _expand_currency(m, lang="en", currency="USD"):
    """Spell out a matched currency amount (e.g. "$20", "20,15€") in *lang*."""
    # Normalise to a float: decimal commas become dots, then everything except
    # digits and dots is stripped (this removes the currency symbol).
    amount = float((re.sub(r"[^\d.]", "", m.group(0).replace(",", "."))))
    # num2words uses "cz" for Czech rather than ISO "cs".
    full_amount = num2words(amount, to="currency", currency=currency, lang=lang if lang != "cs" else "cz")

    # Connector num2words places between the whole and fractional parts of a
    # currency phrase, per language.
    and_equivalents = {
        "en": ", ",
        "es": " con ",
        "fr": " et ",
        "de": " und ",
        "pt": " e ",
        "it": " e ",
        "pl": ", ",
        "cs": ", ",
        "ru": ", ",
        "nl": ", ",
        "ar": ", ",
        "tr": ", ",
        "hu": ", ",
        "ko": ", ",
    }

    if amount.is_integer():
        # For whole amounts, cut everything after the last connector —
        # presumably to drop the "zero cents" tail num2words appends;
        # NOTE(review): confirm against num2words currency output.
        last_and = full_amount.rfind(and_equivalents[lang])
        if last_and != -1:
            full_amount = full_amount[:last_and]

    return full_amount
|
||||
|
||||
|
||||
def _expand_ordinal(m, lang="en"):
    """Spell out a captured ordinal number ("1st" -> "first") for *lang*."""
    target_lang = "cz" if lang == "cs" else lang
    return num2words(int(m.group(1)), ordinal=True, lang=target_lang)
|
||||
|
||||
|
||||
def _expand_number(m, lang="en"):
    """Spell out a plain integer match ("50" -> "fifty") for *lang*."""
    target_lang = "cz" if lang == "cs" else lang
    return num2words(int(m.group(0)), lang=target_lang)
|
||||
|
||||
|
||||
def expand_numbers_multilingual(text, lang="en"):
    """Spell out digits, ordinals, decimals and currency amounts in *text*.

    Chinese is delegated entirely to ``zh_num2words``.  For other languages
    the thousands separators are stripped first so the remaining regexes see
    plain digit runs.  Currency expansion is best-effort: num2words does not
    support every (currency, language) pair, so its failures are ignored.
    """
    if lang == "zh":
        text = zh_num2words()(text)
    else:
        if lang in ["en", "ru"]:
            # en/ru write "1,000.5" — drop the commas.
            text = re.sub(_comma_number_re, _remove_commas, text)
        else:
            # Most other languages write "1.000,5" — drop the dots.
            text = re.sub(_dot_number_re, _remove_dots, text)
        try:
            text = re.sub(_currency_re["GBP"], lambda m: _expand_currency(m, lang, "GBP"), text)
            text = re.sub(_currency_re["USD"], lambda m: _expand_currency(m, lang, "USD"), text)
            text = re.sub(_currency_re["EUR"], lambda m: _expand_currency(m, lang, "EUR"), text)
        except Exception:
            # Narrowed from a bare `except:` — keeps the best-effort behaviour
            # without swallowing SystemExit/KeyboardInterrupt.
            pass
        if lang != "tr":
            # Decimal expansion is skipped for Turkish (unsupported downstream).
            text = re.sub(_decimal_number_re, lambda m: _expand_decimal_point(m, lang), text)
        text = re.sub(_ordinal_re[lang], lambda m: _expand_ordinal(m, lang), text)
        text = re.sub(_number_re, lambda m: _expand_number(m, lang), text)
    return text
|
||||
|
||||
|
||||
def lowercase(text):
    """Return *text* converted to lowercase."""
    lowered = text.lower()
    return lowered
|
||||
|
||||
|
||||
def collapse_whitespace(text):
    """Collapse every run matched by the module-level _whitespace_re into one space."""
    return re.sub(_whitespace_re, " ", text)
|
||||
|
||||
|
||||
def multilingual_cleaners(text, lang):
    """Full normalisation pipeline: lowercase + number/abbreviation/symbol expansion.

    Each expansion step is best-effort: a failure in one step must not abort
    cleaning, so exceptions are swallowed per step.
    """
    text = text.replace('"', "")
    if lang == "tr":
        # Turkish dotted capitals don't lowercase correctly with str.lower();
        # map them manually before lowercasing.
        text = text.replace("İ", "i")
        text = text.replace("Ö", "ö")
        text = text.replace("Ü", "ü")
    text = lowercase(text)
    # Narrowed from bare `except:` clauses so SystemExit/KeyboardInterrupt
    # still propagate while keeping the best-effort behaviour.
    try:
        text = expand_numbers_multilingual(text, lang)
    except Exception:
        pass
    try:
        text = expand_abbreviations_multilingual(text, lang)
    except Exception:
        pass
    try:
        text = expand_symbols_multilingual(text, lang=lang)
    except Exception:
        pass
    text = collapse_whitespace(text)
    return text
|
||||
|
||||
|
||||
def basic_cleaners(text):
    """Basic pipeline that lowercases and collapses whitespace without transliteration."""
    return collapse_whitespace(lowercase(text))
|
||||
|
||||
|
||||
def chinese_transliterate(text):
    """Convert Chinese text to pinyin with numeric tone marks (TONE3 style)."""
    syllables = pypinyin.pinyin(
        text, style=pypinyin.Style.TONE3, heteronym=False, neutral_tone_with_five=True
    )
    return "".join(entry[0] for entry in syllables)
|
||||
|
||||
|
||||
def japanese_cleaners(text, katsu):
    """Romanise Japanese text with the given cutlet instance, then lowercase it."""
    romanized = katsu.romaji(text)
    return lowercase(romanized)
|
||||
|
||||
|
||||
def korean_transliterate(text):
    """Romanise Hangul text using the academic transliteration scheme."""
    transliterator = Transliter(academic)
    return transliterator.translit(text)
|
||||
|
||||
|
||||
# Default BPE vocabulary shipped next to this module.
DEFAULT_VOCAB_FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), "vocab.json")
|
||||
|
||||
|
||||
class VoiceBpeTokenizer:
    """BPE tokenizer for voice models.

    Text is cleaned per language (number/abbreviation/symbol expansion,
    transliteration for zh/ja/ko), prefixed with a "[lang]" tag, spaces are
    replaced by the literal "[SPACE]" token, and the result is encoded with a
    HuggingFace ``tokenizers.Tokenizer`` loaded from ``vocab_file``.
    """

    def __init__(self, vocab_file=DEFAULT_VOCAB_FILE):
        # `tokenizer` stays None when no vocab file is supplied.
        self.tokenizer = None
        if vocab_file is not None:
            self.tokenizer = Tokenizer.from_file(vocab_file)
        # Per-language input-length limits (characters); unknown languages
        # fall back to 250 in check_input_length.
        self.char_limits = {
            "en": 10000,
            "de": 253,
            "fr": 273,
            "es": 239,
            "it": 213,
            "pt": 203,
            "pl": 224,
            "zh": 82,
            "ar": 166,
            "cs": 186,
            "ru": 182,
            "nl": 251,
            "tr": 226,
            "ja": 71,
            "hu": 224,
            "ko": 95,
        }

    @cached_property
    def katsu(self):
        # Lazy import: cutlet (and its MeCab dependency) is only needed for Japanese.
        import cutlet

        return cutlet.Cutlet()

    def check_input_length(self, txt, lang):
        """Length check kept for API compatibility; the warning is currently disabled."""
        lang = lang.split("-")[0]  # remove the region
        limit = self.char_limits.get(lang, 250)
        # if len(txt) > limit:
        #     print(
        #         f"[!] Warning: The text length exceeds the character limit of {limit} for language '{lang}', this might cause truncated audio."
        #     )

    def preprocess_text(self, txt, lang):
        """Apply language-specific cleaning/transliteration before tokenization.

        Raises:
            NotImplementedError: for languages outside the supported set.
        """
        if lang in {"ar", "cs", "de", "en", "es", "fr", "hu", "it", "nl", "pl", "pt", "ru", "tr", "zh", "ko"}:
            txt = multilingual_cleaners(txt, lang)
            if lang == "zh":
                txt = chinese_transliterate(txt)
            if lang == "ko":
                txt = korean_transliterate(txt)
        elif lang == "ja":
            txt = japanese_cleaners(txt, self.katsu)
        elif lang == "hi":
            # @manmay will implement this
            txt = basic_cleaners(txt)
        else:
            raise NotImplementedError(f"Language '{lang}' is not supported.")
        return txt

    def encode(self, txt, lang):
        """Return token ids for *txt*; *lang* may carry a region suffix ("zh-cn")."""
        lang = lang.split("-")[0]  # remove the region
        self.check_input_length(txt, lang)
        txt = self.preprocess_text(txt, lang)
        lang = "zh-cn" if lang == "zh" else lang
        txt = f"[{lang}]{txt}"
        txt = txt.replace(" ", "[SPACE]")
        return self.tokenizer.encode(txt).ids

    def decode(self, seq, skip_special_tokens=False):
        """Decode token ids back to text, undoing the [SPACE]/[STOP] encoding."""
        if isinstance(seq, torch.Tensor):
            seq = seq.cpu().numpy()
        # Fix: forward `skip_special_tokens` instead of hard-coding False so the
        # parameter actually has an effect (default behaviour is unchanged).
        txt = self.tokenizer.decode(seq, skip_special_tokens=skip_special_tokens).replace(" ", "")
        txt = txt.replace("[SPACE]", " ")
        txt = txt.replace("[STOP]", "")
        # txt = txt.replace("[UNK]", "")
        return txt

    # copy from https://github.com/huggingface/transformers/blob/main/src/transformers/tokenization_utils_base.py#L3936
    def batch_decode(
        self,
        sequences: Union[List[int], List[List[int]], "np.ndarray", "torch.Tensor", "tf.Tensor"],
        skip_special_tokens: bool = False,
    ) -> List[str]:
        """
        Convert a list of lists of token ids into a list of strings by calling decode.

        Args:
            sequences (`Union[List[int], List[List[int]], np.ndarray, torch.Tensor, tf.Tensor]`):
                List of tokenized input ids. Can be obtained using the `__call__` method.
            skip_special_tokens (`bool`, *optional*, defaults to `False`):
                Whether or not to remove special tokens in the decoding.

        Returns:
            `List[str]`: The list of decoded sentences.
        """
        # Fix: propagate `skip_special_tokens` to decode; it was previously ignored.
        return [
            self.decode(seq, skip_special_tokens=skip_special_tokens)
            for seq in sequences
        ]

    # https://github.com/coqui-ai/TTS/blob/dev/TTS/tts/layers/xtts/trainer/dataset.py#L202
    # def pad(self):

    def __len__(self):
        # Vocabulary size as reported by the underlying tokenizer.
        return self.tokenizer.get_vocab_size()

    def get_number_tokens(self):
        """Return one more than the largest token id in the vocabulary."""
        return max(self.tokenizer.get_vocab().values()) + 1
|
||||
|
||||
|
||||
def test_expand_numbers_multilingual():
    """Smoke-test number/ordinal/currency expansion across all supported languages."""
    # Each case is (input, expected_output, language_code).
    test_cases = [
        # English
        ("In 12.5 seconds.", "In twelve point five seconds.", "en"),
        ("There were 50 soldiers.", "There were fifty soldiers.", "en"),
        ("This is a 1st test", "This is a first test", "en"),
        ("That will be $20 sir.", "That will be twenty dollars sir.", "en"),
        ("That will be 20€ sir.", "That will be twenty euro sir.", "en"),
        ("That will be 20.15€ sir.", "That will be twenty euro, fifteen cents sir.", "en"),
        ("That's 100,000.5.", "That's one hundred thousand point five.", "en"),
        # French
        ("En 12,5 secondes.", "En douze virgule cinq secondes.", "fr"),
        ("Il y avait 50 soldats.", "Il y avait cinquante soldats.", "fr"),
        ("Ceci est un 1er test", "Ceci est un premier test", "fr"),
        ("Cela vous fera $20 monsieur.", "Cela vous fera vingt dollars monsieur.", "fr"),
        ("Cela vous fera 20€ monsieur.", "Cela vous fera vingt euros monsieur.", "fr"),
        ("Cela vous fera 20,15€ monsieur.", "Cela vous fera vingt euros et quinze centimes monsieur.", "fr"),
        ("Ce sera 100.000,5.", "Ce sera cent mille virgule cinq.", "fr"),
        # German
        ("In 12,5 Sekunden.", "In zwölf Komma fünf Sekunden.", "de"),
        ("Es gab 50 Soldaten.", "Es gab fünfzig Soldaten.", "de"),
        ("Dies ist ein 1. Test", "Dies ist ein erste Test", "de"),  # Issue with gender
        ("Das macht $20 Herr.", "Das macht zwanzig Dollar Herr.", "de"),
        ("Das macht 20€ Herr.", "Das macht zwanzig Euro Herr.", "de"),
        ("Das macht 20,15€ Herr.", "Das macht zwanzig Euro und fünfzehn Cent Herr.", "de"),
        # Spanish
        ("En 12,5 segundos.", "En doce punto cinco segundos.", "es"),
        ("Había 50 soldados.", "Había cincuenta soldados.", "es"),
        ("Este es un 1er test", "Este es un primero test", "es"),
        ("Eso le costará $20 señor.", "Eso le costará veinte dólares señor.", "es"),
        ("Eso le costará 20€ señor.", "Eso le costará veinte euros señor.", "es"),
        ("Eso le costará 20,15€ señor.", "Eso le costará veinte euros con quince céntimos señor.", "es"),
        # Italian
        ("In 12,5 secondi.", "In dodici virgola cinque secondi.", "it"),
        ("C'erano 50 soldati.", "C'erano cinquanta soldati.", "it"),
        ("Questo è un 1° test", "Questo è un primo test", "it"),
        ("Ti costerà $20 signore.", "Ti costerà venti dollari signore.", "it"),
        ("Ti costerà 20€ signore.", "Ti costerà venti euro signore.", "it"),
        ("Ti costerà 20,15€ signore.", "Ti costerà venti euro e quindici centesimi signore.", "it"),
        # Portuguese
        ("Em 12,5 segundos.", "Em doze vírgula cinco segundos.", "pt"),
        ("Havia 50 soldados.", "Havia cinquenta soldados.", "pt"),
        ("Este é um 1º teste", "Este é um primeiro teste", "pt"),
        ("Isso custará $20 senhor.", "Isso custará vinte dólares senhor.", "pt"),
        ("Isso custará 20€ senhor.", "Isso custará vinte euros senhor.", "pt"),
        (
            "Isso custará 20,15€ senhor.",
            "Isso custará vinte euros e quinze cêntimos senhor.",
            "pt",
        ),  # "cêntimos" should be "centavos" num2words issue
        # Polish
        ("W 12,5 sekundy.", "W dwanaście przecinek pięć sekundy.", "pl"),
        ("Było 50 żołnierzy.", "Było pięćdziesiąt żołnierzy.", "pl"),
        ("To będzie kosztować 20€ panie.", "To będzie kosztować dwadzieścia euro panie.", "pl"),
        ("To będzie kosztować 20,15€ panie.", "To będzie kosztować dwadzieścia euro, piętnaście centów panie.", "pl"),
        # Arabic
        ("في الـ 12,5 ثانية.", "في الـ اثنا عشر , خمسون ثانية.", "ar"),
        ("كان هناك 50 جنديًا.", "كان هناك خمسون جنديًا.", "ar"),
        # ("ستكون النتيجة $20 يا سيد.", 'ستكون النتيجة عشرون دولار يا سيد.', 'ar'), # $ and € are mising from num2words
        # ("ستكون النتيجة 20€ يا سيد.", 'ستكون النتيجة عشرون يورو يا سيد.', 'ar'),
        # Czech
        ("Za 12,5 vteřiny.", "Za dvanáct celá pět vteřiny.", "cs"),
        ("Bylo tam 50 vojáků.", "Bylo tam padesát vojáků.", "cs"),
        ("To bude stát 20€ pane.", "To bude stát dvacet euro pane.", "cs"),
        ("To bude 20.15€ pane.", "To bude dvacet euro, patnáct centů pane.", "cs"),
        # Russian
        ("Через 12.5 секунды.", "Через двенадцать запятая пять секунды.", "ru"),
        ("Там было 50 солдат.", "Там было пятьдесят солдат.", "ru"),
        ("Это будет 20.15€ сэр.", "Это будет двадцать евро, пятнадцать центов сэр.", "ru"),
        ("Это будет стоить 20€ господин.", "Это будет стоить двадцать евро господин.", "ru"),
        # Dutch
        ("In 12,5 seconden.", "In twaalf komma vijf seconden.", "nl"),
        ("Er waren 50 soldaten.", "Er waren vijftig soldaten.", "nl"),
        ("Dat wordt dan $20 meneer.", "Dat wordt dan twintig dollar meneer.", "nl"),
        ("Dat wordt dan 20€ meneer.", "Dat wordt dan twintig euro meneer.", "nl"),
        # Chinese (Simplified)
        ("在12.5秒内", "在十二点五秒内", "zh"),
        ("有50名士兵", "有五十名士兵", "zh"),
        # ("那将是$20先生", '那将是二十美元先生', 'zh'), currency doesn't work
        # ("那将是20€先生", '那将是二十欧元先生', 'zh'),
        # Turkish
        # ("12,5 saniye içinde.", 'On iki virgül beş saniye içinde.', 'tr'), # decimal doesn't work for TR
        ("50 asker vardı.", "elli asker vardı.", "tr"),
        ("Bu 1. test", "Bu birinci test", "tr"),
        # ("Bu 100.000,5.", 'Bu yüz bin virgül beş.', 'tr'),
        # Hungarian
        ("12,5 másodperc alatt.", "tizenkettő egész öt tized másodperc alatt.", "hu"),
        ("50 katona volt.", "ötven katona volt.", "hu"),
        ("Ez az 1. teszt", "Ez az első teszt", "hu"),
        # Korean
        ("12.5 초 안에.", "십이 점 다섯 초 안에.", "ko"),
        ("50 명의 병사가 있었다.", "오십 명의 병사가 있었다.", "ko"),
        ("이것은 1 번째 테스트입니다", "이것은 첫 번째 테스트입니다", "ko"),
    ]
    for a, b, lang in test_cases:
        out = expand_numbers_multilingual(a, lang=lang)
        assert out == b, f"'{out}' vs '{b}'"
|
||||
|
||||
|
||||
def test_abbreviations_multilingual():
    """Smoke-test abbreviation expansion (Mr., Dr., ...) across languages."""
    # Each case is (input, expected_output, language_code).
    test_cases = [
        # English
        ("Hello Mr. Smith.", "Hello mister Smith.", "en"),
        ("Dr. Jones is here.", "doctor Jones is here.", "en"),
        # Spanish
        ("Hola Sr. Garcia.", "Hola señor Garcia.", "es"),
        ("La Dra. Martinez es muy buena.", "La doctora Martinez es muy buena.", "es"),
        # French
        ("Bonjour Mr. Dupond.", "Bonjour monsieur Dupond.", "fr"),
        ("Mme. Moreau est absente aujourd'hui.", "madame Moreau est absente aujourd'hui.", "fr"),
        # German
        ("Frau Dr. Müller ist sehr klug.", "Frau doktor Müller ist sehr klug.", "de"),
        # Portuguese
        ("Olá Sr. Silva.", "Olá senhor Silva.", "pt"),
        ("Dra. Costa, você está disponível?", "doutora Costa, você está disponível?", "pt"),
        # Italian
        ("Buongiorno, Sig. Rossi.", "Buongiorno, signore Rossi.", "it"),
        # ("Sig.ra Bianchi, posso aiutarti?", 'signora Bianchi, posso aiutarti?', 'it'), # Issue with matching that pattern
        # Polish
        ("Dzień dobry, P. Kowalski.", "Dzień dobry, pani Kowalski.", "pl"),
        ("M. Nowak, czy mogę zadać pytanie?", "pan Nowak, czy mogę zadać pytanie?", "pl"),
        # Czech
        ("P. Novák", "pan Novák", "cs"),
        ("Dr. Vojtěch", "doktor Vojtěch", "cs"),
        # Dutch
        ("Dhr. Jansen", "de heer Jansen", "nl"),
        ("Mevr. de Vries", "mevrouw de Vries", "nl"),
        # Russian
        ("Здравствуйте Г-н Иванов.", "Здравствуйте господин Иванов.", "ru"),
        ("Д-р Смирнов здесь, чтобы увидеть вас.", "доктор Смирнов здесь, чтобы увидеть вас.", "ru"),
        # Turkish
        ("Merhaba B. Yılmaz.", "Merhaba bay Yılmaz.", "tr"),
        ("Dr. Ayşe burada.", "doktor Ayşe burada.", "tr"),
        # Hungarian
        ("Dr. Szabó itt van.", "doktor Szabó itt van.", "hu"),
    ]

    for a, b, lang in test_cases:
        out = expand_abbreviations_multilingual(a, lang=lang)
        assert out == b, f"'{out}' vs '{b}'"
|
||||
|
||||
|
||||
def test_symbols_multilingual():
    """Smoke-test symbol expansion (%, @, £, °, ...) across languages."""
    # Each case is (input, expected_output, language_code).
    test_cases = [
        ("I have 14% battery", "I have 14 percent battery", "en"),
        ("Te veo @ la fiesta", "Te veo arroba la fiesta", "es"),
        ("J'ai 14° de fièvre", "J'ai 14 degrés de fièvre", "fr"),
        ("Die Rechnung beträgt £ 20", "Die Rechnung beträgt pfund 20", "de"),
        ("O meu email é ana&joao@gmail.com", "O meu email é ana e joao arroba gmail.com", "pt"),
        ("linguaggio di programmazione C#", "linguaggio di programmazione C cancelletto", "it"),
        ("Moja temperatura to 36.6°", "Moja temperatura to 36.6 stopnie", "pl"),
        ("Mám 14% baterie", "Mám 14 procento baterie", "cs"),
        ("Těším se na tebe @ party", "Těším se na tebe na party", "cs"),
        ("У меня 14% заряда", "У меня 14 процентов заряда", "ru"),
        ("Я буду @ дома", "Я буду собака дома", "ru"),
        ("Ik heb 14% batterij", "Ik heb 14 procent batterij", "nl"),
        ("Ik zie je @ het feest", "Ik zie je bij het feest", "nl"),
        ("لدي 14% في البطارية", "لدي 14 في المئة في البطارية", "ar"),
        ("我的电量为 14%", "我的电量为 14 百分之", "zh"),
        ("Pilim %14 dolu.", "Pilim yüzde 14 dolu.", "tr"),
        ("Az akkumulátorom töltöttsége 14%", "Az akkumulátorom töltöttsége 14 százalék", "hu"),
        ("배터리 잔량이 14%입니다.", "배터리 잔량이 14 퍼센트입니다.", "ko"),
    ]

    for a, b, lang in test_cases:
        out = expand_symbols_multilingual(a, lang=lang)
        assert out == b, f"'{out}' vs '{b}'"
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Run the normalisation smoke tests when executed as a script.
    test_expand_numbers_multilingual()
    test_abbreviations_multilingual()
    test_symbols_multilingual()
|
||||
15535
models/lyrics_utils/vocab.json
Normal file
15535
models/lyrics_utils/vocab.json
Normal file
File diff suppressed because it is too large
Load Diff
1209
models/lyrics_utils/zh_num2words.py
Normal file
1209
models/lyrics_utils/zh_num2words.py
Normal file
File diff suppressed because it is too large
Load Diff
8
models/singer_presets.json
Normal file
8
models/singer_presets.json
Normal file
@@ -0,0 +1,8 @@
|
||||
[
|
||||
{
|
||||
"id": 1,
|
||||
"name": "Dump Singer 1",
|
||||
"description": "This is the first singer preset",
|
||||
"spk_emb_path": "path/to/singer1/spk_emb"
|
||||
}
|
||||
]
|
||||
482
models/transformer_sana_text2music_large_dcae_0319.py
Normal file
482
models/transformer_sana_text2music_large_dcae_0319.py
Normal file
@@ -0,0 +1,482 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional, Tuple, List, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
||||
from diffusers.utils import BaseOutput, logging, is_torch_version
|
||||
from diffusers.models.modeling_utils import ModelMixin
|
||||
from diffusers.models.embeddings import TimestepEmbedding, Timesteps, get_2d_sincos_pos_embed
|
||||
from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
|
||||
|
||||
try:
|
||||
from .attention import LinearTransformerBlock, t2i_modulate
|
||||
from .lyrics_utils.lyric_encoder import ConformerEncoder as LyricEncoder
|
||||
except ImportError:
|
||||
from attention import LinearTransformerBlock, t2i_modulate
|
||||
from lyrics_utils.lyric_encoder import ConformerEncoder as LyricEncoder
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
def cross_norm(hidden_states, controlnet_input):
    """Re-normalise `controlnet_input` to the per-sample mean/std of `hidden_states`.

    Both inputs are N x T x C; statistics are taken over the (T, C) axes.
    """
    ref_mean = hidden_states.mean(dim=(1, 2), keepdim=True)
    ref_std = hidden_states.std(dim=(1, 2), keepdim=True)
    src_mean = controlnet_input.mean(dim=(1, 2), keepdim=True)
    src_std = controlnet_input.std(dim=(1, 2), keepdim=True)
    # 1e-12 guards against division by zero for a constant control signal.
    rescaled = (controlnet_input - src_mean) * (ref_std / (src_std + 1e-12)) + ref_mean
    return rescaled
|
||||
|
||||
|
||||
# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Qwen2
|
||||
class Qwen2RotaryEmbedding(nn.Module):
    """Rotary position embedding with a lazily-grown cos/sin cache.

    Copied from transformers' Mixtral rotary embedding (Mixtral->Qwen2).
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        # Inverse frequencies base^(-2i/dim) for the even channel indices.
        exponents = torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim
        self.register_buffer("inv_freq", 1.0 / (self.base**exponents), persistent=False)

        # Pre-build the cache so `torch.jit.trace` sees the buffers.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """(Re)build cos/sin tables covering positions [0, seq_len)."""
        self.max_seq_len_cached = seq_len
        positions = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)

        freqs = torch.outer(positions, self.inv_freq)
        # Duplicate rather than interleave — matches the HF permutation convention.
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        if seq_len > self.max_seq_len_cached:
            # Grow the cache on demand for longer sequences.
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        cos = self.cos_cached[:seq_len].to(dtype=x.dtype)
        sin = self.sin_cached[:seq_len].to(dtype=x.dtype)
        return (cos, sin)
|
||||
|
||||
|
||||
class T2IFinalLayer(nn.Module):
    """
    The final layer of Sana.

    Applies timestep-conditioned shift/scale modulation, projects each token
    to one patch worth of output channels, then unpatchifies back to the
    latent layout.
    """

    def __init__(self, hidden_size, patch_size=[16, 1], out_channels=256):
        super().__init__()
        self.norm_final = nn.RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        # Each token maps to patch_h * patch_w * out_channels output values.
        self.linear = nn.Linear(hidden_size, patch_size[0] * patch_size[1] * out_channels, bias=True)
        # Learned base shift/scale added to the timestep embedding in forward().
        self.scale_shift_table = nn.Parameter(torch.randn(2, hidden_size) / hidden_size**0.5)
        self.out_channels = out_channels
        self.patch_size = patch_size

    def unpatchfy(
        self,
        hidden_states: torch.Tensor,
        width: int,
    ):
        # 4 unpatchify
        # Tokens are treated as a 1 x T grid; rebuild N x C x patch_h x (T * patch_w),
        # then pad with zeros or crop along the width axis to exactly `width`.
        # NOTE(review): the pad/crop compares `width` to the token count, which
        # only equals the output width when patch_size[1] == 1 — confirm.
        new_height, new_width = 1, hidden_states.size(1)
        hidden_states = hidden_states.reshape(
            shape=(hidden_states.shape[0], new_height, new_width, self.patch_size[0], self.patch_size[1], self.out_channels)
        ).contiguous()
        hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
        output = hidden_states.reshape(
            shape=(hidden_states.shape[0], self.out_channels, new_height * self.patch_size[0], new_width * self.patch_size[1])
        ).contiguous()
        if width > new_width:
            output = torch.nn.functional.pad(output, (0, width - new_width, 0, 0), 'constant', 0)
        elif width < new_width:
            output = output[:, :, :, :width]
        return output

    def forward(self, x, t, output_length):
        # Combine the learned table with the per-sample timestep embedding `t`,
        # split into (shift, scale) and modulate the normalised tokens.
        shift, scale = (self.scale_shift_table[None] + t[:, None]).chunk(2, dim=1)
        x = t2i_modulate(self.norm_final(x), shift, scale)
        x = self.linear(x)
        # unpatchify
        output = self.unpatchfy(x, output_length)
        return output
|
||||
|
||||
|
||||
class PatchEmbed(nn.Module):
    """2D Image to Patch Embedding via strided convolutions."""

    def __init__(
        self,
        height=16,
        width=4096,
        patch_size=(16, 1),
        in_channels=8,
        embed_dim=1152,
        bias=True,
    ):
        super().__init__()
        patch_h, patch_w = patch_size
        mid_channels = in_channels * 256
        # Strided conv cuts the latent into patches, GroupNorm stabilises the
        # patch vectors, and the 1x1 conv projects to the transformer width.
        self.early_conv_layers = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=patch_size, stride=patch_size, padding=0, bias=bias),
            torch.nn.GroupNorm(num_groups=32, num_channels=mid_channels, eps=1e-6, affine=True),
            nn.Conv2d(mid_channels, embed_dim, kernel_size=1, stride=1, padding=0, bias=bias),
        )
        self.patch_size = patch_size
        self.height, self.width = height // patch_h, width // patch_w
        self.base_size = self.width

    def forward(self, latent):
        """Embed `latent` (N, C, H, W) into a token sequence (N, tokens, embed_dim)."""
        embedded = self.early_conv_layers(latent)
        return embedded.flatten(2).transpose(1, 2)  # BCHW -> BNC
|
||||
|
||||
|
||||
@dataclass
class Transformer1DModelOutput(BaseOutput):
    """Output container for the 1D music transformer."""

    # Predicted sample (model output) from the transformer.
    sample: torch.FloatTensor
    # Optional (name, loss) pairs from the SSL projection heads; None when
    # SSL supervision is not used.
    proj_losses: Optional[Tuple[Tuple[str, torch.Tensor]]] = None
|
||||
|
||||
|
||||
class ACEFlowBaseModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
    @register_to_config
    def __init__(
        self,
        in_channels: Optional[int] = 8,
        num_layers: int = 28,
        inner_dim: int = 1536,
        attention_head_dim: int = 64,
        num_attention_heads: int = 24,
        mlp_ratio: float = 4.0,
        out_channels: int = 8,
        max_position: int = 32768,
        rope_theta: float = 1000000.0,
        speaker_embedding_dim: int = 512,
        text_embedding_dim: int = 768,
        ssl_encoder_depths: List[int] = [9, 9],
        ssl_names: List[str] = ["mert", "m-hubert"],
        ssl_latent_dims: List[int] = [1024, 768],
        lyric_encoder_vocab_size: int = 6681,
        lyric_hidden_size: int = 1024,
        patch_size: List[int] = [16, 1],
        max_height: int = 16,
        max_width: int = 4096,
        **kwargs,
    ):
        """Build the flow-matching music transformer and its conditioning encoders."""
        super().__init__()

        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        # NOTE(review): the `inner_dim` argument is overridden here — the
        # effective width is always num_attention_heads * attention_head_dim.
        inner_dim = num_attention_heads * attention_head_dim
        self.inner_dim = inner_dim
        self.out_channels = out_channels
        self.max_position = max_position
        self.patch_size = patch_size

        self.rope_theta = rope_theta

        # Rotary position embedding shared by the attention blocks.
        self.rotary_emb = Qwen2RotaryEmbedding(
            dim=self.attention_head_dim,
            max_position_embeddings=self.max_position,
            base=self.rope_theta,
        )

        # 2. Define input layers
        self.in_channels = in_channels

        # 3. Define transformers blocks
        self.transformer_blocks = nn.ModuleList(
            [
                LinearTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=self.num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    mlp_ratio=mlp_ratio,
                    add_cross_attention=True,
                    add_cross_attention_dim=self.inner_dim,
                )
                for i in range(self.config.num_layers)
            ]
        )
        self.num_layers = num_layers

        # Timestep conditioning: sinusoidal projection -> embedding MLP ->
        # 6 modulation vectors consumed by the transformer blocks.
        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=self.inner_dim)
        self.t_block = nn.Sequential(nn.SiLU(), nn.Linear(self.inner_dim, 6 * self.inner_dim, bias=True))

        # speaker
        self.speaker_embedder = nn.Linear(speaker_embedding_dim, self.inner_dim)

        # genre
        self.genre_embedder = nn.Linear(text_embedding_dim, self.inner_dim)

        # lyric
        self.lyric_embs = nn.Embedding(lyric_encoder_vocab_size, lyric_hidden_size)
        self.lyric_encoder = LyricEncoder(input_size=lyric_hidden_size, static_chunk_size=0)
        self.lyric_proj = nn.Linear(lyric_hidden_size, self.inner_dim)

        projector_dim = 2 * self.inner_dim

        # One MLP projector per SSL teacher (see `ssl_names`), mapping
        # transformer features to that teacher's latent dimension.
        self.projectors = nn.ModuleList([
            nn.Sequential(
                nn.Linear(self.inner_dim, projector_dim),
                nn.SiLU(),
                nn.Linear(projector_dim, projector_dim),
                nn.SiLU(),
                nn.Linear(projector_dim, ssl_dim),
            ) for ssl_dim in ssl_latent_dims
        ])

        self.ssl_latent_dims = ssl_latent_dims
        self.ssl_encoder_depths = ssl_encoder_depths

        self.cosine_loss = torch.nn.CosineEmbeddingLoss(margin=0.0, reduction='mean')
        self.ssl_names = ssl_names

        # NOTE(review): PatchEmbed is constructed without `in_channels`, so it
        # always uses its default (8) even if the `in_channels` argument
        # differs — confirm this is intended.
        self.proj_in = PatchEmbed(
            height=max_height,
            width=max_width,
            patch_size=patch_size,
            embed_dim=self.inner_dim,
            bias=True,
        )

        self.final_layer = T2IFinalLayer(self.inner_dim, patch_size=patch_size, out_channels=out_channels)
        self.gradient_checkpointing = False
|
||||
|
||||
# Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
|
||||
def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
|
||||
"""
|
||||
Sets the attention processor to use [feed forward
|
||||
chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
|
||||
|
||||
Parameters:
|
||||
chunk_size (`int`, *optional*):
|
||||
The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
|
||||
over each tensor of dim=`dim`.
|
||||
dim (`int`, *optional*, defaults to `0`):
|
||||
The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
|
||||
or dim=1 (sequence length).
|
||||
"""
|
||||
if dim not in [0, 1]:
|
||||
raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
|
||||
|
||||
# By default chunk size is 1
|
||||
chunk_size = chunk_size or 1
|
||||
|
||||
def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
|
||||
if hasattr(module, "set_chunk_feed_forward"):
|
||||
module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
|
||||
|
||||
for child in module.children():
|
||||
fn_recursive_feed_forward(child, chunk_size, dim)
|
||||
|
||||
for module in self.children():
|
||||
fn_recursive_feed_forward(module, chunk_size, dim)
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if hasattr(module, "gradient_checkpointing"):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
    def forward_lyric_encoder(
        self,
        lyric_token_idx: Optional[torch.LongTensor] = None,
        lyric_mask: Optional[torch.LongTensor] = None,
    ):
        """Embed lyric token ids and encode them with the conformer lyric encoder.

        Returns lyric hidden states projected to the transformer width.
        """
        # N x T x D
        lyric_embs = self.lyric_embs(lyric_token_idx)
        # NOTE(review): decoding_chunk_size=1 / num_decoding_left_chunks=-1 —
        # presumably non-streaming encoding; confirm against ConformerEncoder.
        prompt_prenet_out, _mask = self.lyric_encoder(lyric_embs, lyric_mask, decoding_chunk_size=1, num_decoding_left_chunks=-1)
        prompt_prenet_out = self.lyric_proj(prompt_prenet_out)
        return prompt_prenet_out
|
||||
|
||||
    def encode(
        self,
        encoder_text_hidden_states: Optional[torch.Tensor] = None,
        text_attention_mask: Optional[torch.LongTensor] = None,
        speaker_embeds: Optional[torch.FloatTensor] = None,
        lyric_token_idx: Optional[torch.LongTensor] = None,
        lyric_mask: Optional[torch.LongTensor] = None,
    ):
        """Build the cross-attention conditioning sequence and its mask.

        Concatenates, along the time axis: one projected speaker token,
        projected genre/text embeddings, and encoded lyrics.  The returned
        mask concatenates a ones-column (speaker), the text attention mask
        and the lyric mask in the same order.
        """

        bs = encoder_text_hidden_states.shape[0]
        device = encoder_text_hidden_states.device

        # speaker embedding — a single always-attended token per sample.
        encoder_spk_hidden_states = self.speaker_embedder(speaker_embeds).unsqueeze(1)
        speaker_mask = torch.ones(bs, 1, device=device)

        # genre embedding
        encoder_text_hidden_states = self.genre_embedder(encoder_text_hidden_states)

        # lyric
        encoder_lyric_hidden_states = self.forward_lyric_encoder(
            lyric_token_idx=lyric_token_idx,
            lyric_mask=lyric_mask,
        )

        encoder_hidden_states = torch.cat([encoder_spk_hidden_states, encoder_text_hidden_states, encoder_lyric_hidden_states], dim=1)
        encoder_hidden_mask = torch.cat([speaker_mask, text_attention_mask, lyric_mask], dim=1)
        return encoder_hidden_states, encoder_hidden_mask
|
||||
|
||||
    def decode(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        encoder_hidden_mask: torch.Tensor,
        timestep: Optional[torch.Tensor],
        ssl_hidden_states: Optional[List[torch.Tensor]] = None,
        output_length: int = 0,
        block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
        controlnet_scale: Union[float, torch.Tensor] = 1.0,
        return_dict: bool = True,
    ):
        """Run the transformer stack over the latent sequence.

        Args:
            hidden_states: latent input sequence (batch-first).
            attention_mask: self-attention mask for ``hidden_states``.
            encoder_hidden_states / encoder_hidden_mask: conditioning
                sequence and mask as produced by ``encode``.
            timestep: diffusion timestep(s), embedded for adaptive norm.
            ssl_hidden_states: optional SSL teacher features; when present a
                cosine projection loss is computed at selected block depths.
            output_length: target length handed to the final output layer.
            block_controlnet_hidden_states: optional controlnet residual(s)
                added after the input projection.
            controlnet_scale: weight of the controlnet residual.
            return_dict: return a Transformer1DModelOutput instead of a tuple.

        Returns:
            Transformer1DModelOutput(sample, proj_losses), or the tuple
            ``(output, proj_losses)`` when ``return_dict`` is False.
        """

        # Timestep -> sinusoidal projection -> embedding -> modulation vector.
        embedded_timestep = self.timestep_embedder(self.time_proj(timestep).to(dtype=hidden_states.dtype))
        temb = self.t_block(embedded_timestep)

        hidden_states = self.proj_in(hidden_states)

        # controlnet logic: normalize the control signal against the current
        # activations, then mix it in as a scaled residual.
        if block_controlnet_hidden_states is not None:
            control_condi = cross_norm(hidden_states, block_controlnet_hidden_states)
            hidden_states = hidden_states + control_condi * controlnet_scale

        inner_hidden_states = []  # activations tapped at self.ssl_encoder_depths

        # Rotary position embeddings for self- and cross-attention.
        rotary_freqs_cis = self.rotary_emb(hidden_states, seq_len=hidden_states.shape[1])
        encoder_rotary_freqs_cis = self.rotary_emb(encoder_hidden_states, seq_len=encoder_hidden_states.shape[1])

        for index_block, block in enumerate(self.transformer_blocks):

            if self.training and self.gradient_checkpointing:
                # Recompute activations during backward to save memory.

                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)

                    return custom_forward

                # use_reentrant=False is only available from torch 1.11 on.
                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states=hidden_states,
                    attention_mask=attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_hidden_mask,
                    rotary_freqs_cis=rotary_freqs_cis,
                    rotary_freqs_cis_cross=encoder_rotary_freqs_cis,
                    temb=temb,
                    **ckpt_kwargs,
                )

            else:
                hidden_states = block(
                    hidden_states=hidden_states,
                    attention_mask=attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_hidden_mask,
                    rotary_freqs_cis=rotary_freqs_cis,
                    rotary_freqs_cis_cross=encoder_rotary_freqs_cis,
                    temb=temb,
                )

            # Collect intermediate activations for the SSL projection loss.
            for ssl_encoder_depth in self.ssl_encoder_depths:
                if index_block == ssl_encoder_depth:
                    inner_hidden_states.append(hidden_states)

        proj_losses = []
        if len(inner_hidden_states) > 0 and ssl_hidden_states is not None and len(ssl_hidden_states) > 0:

            for inner_hidden_state, projector, ssl_hidden_state, ssl_name in zip(inner_hidden_states, self.projectors, ssl_hidden_states, self.ssl_names):
                if ssl_hidden_state is None:
                    continue
                # 1. N x T x D1 -> N x D x D2: project into the teacher's space
                est_ssl_hidden_state = projector(inner_hidden_state)
                # 3. projection loss, averaged per batch element below
                bs = inner_hidden_state.shape[0]
                proj_loss = 0.0
                for i, (z, z_tilde) in enumerate(zip(ssl_hidden_state, est_ssl_hidden_state)):
                    # 2. interpolate the estimate to the teacher's time length
                    z_tilde = F.interpolate(z_tilde.unsqueeze(0).transpose(1, 2), size=len(z), mode='linear', align_corners=False).transpose(1, 2).squeeze(0)

                    # cosine distance between L2-normalized feature rows
                    z_tilde = torch.nn.functional.normalize(z_tilde, dim=-1)
                    z = torch.nn.functional.normalize(z, dim=-1)
                    # T x d -> T x 1 -> 1
                    target = torch.ones(z.shape[0], device=z.device)
                    proj_loss += self.cosine_loss(z, z_tilde, target)
                proj_losses.append((ssl_name, proj_loss / bs))

        output = self.final_layer(hidden_states, embedded_timestep, output_length)
        if not return_dict:
            return (output, proj_losses)

        return Transformer1DModelOutput(sample=output, proj_losses=proj_losses)
# @torch.compile
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
encoder_text_hidden_states: Optional[torch.Tensor] = None,
|
||||
text_attention_mask: Optional[torch.LongTensor] = None,
|
||||
speaker_embeds: Optional[torch.FloatTensor] = None,
|
||||
lyric_token_idx: Optional[torch.LongTensor] = None,
|
||||
lyric_mask: Optional[torch.LongTensor] = None,
|
||||
timestep: Optional[torch.Tensor] = None,
|
||||
ssl_hidden_states: Optional[List[torch.Tensor]] = None,
|
||||
block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
|
||||
controlnet_scale: Union[float, torch.Tensor] = 1.0,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
encoder_hidden_states, encoder_hidden_mask = self.encode(
|
||||
encoder_text_hidden_states=encoder_text_hidden_states,
|
||||
text_attention_mask=text_attention_mask,
|
||||
speaker_embeds=speaker_embeds,
|
||||
lyric_token_idx=lyric_token_idx,
|
||||
lyric_mask=lyric_mask,
|
||||
)
|
||||
|
||||
output_length = hidden_states.shape[-1]
|
||||
|
||||
output = self.decode(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_hidden_mask=encoder_hidden_mask,
|
||||
timestep=timestep,
|
||||
ssl_hidden_states=ssl_hidden_states,
|
||||
output_length=output_length,
|
||||
block_controlnet_hidden_states=block_controlnet_hidden_states,
|
||||
controlnet_scale=controlnet_scale,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
return output
|
||||
2
music_dcae/__init__.py
Normal file
2
music_dcae/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from .music_dcae import MusicDCAE
|
||||
from .music_log_mel import LogMelSpectrogram, get_mel_transform
|
||||
92
music_dcae/balancer.py
Normal file
92
music_dcae/balancer.py
Normal file
@@ -0,0 +1,92 @@
|
||||
import torch
|
||||
from torch.autograd import grad
|
||||
|
||||
|
||||
class Balancer:
|
||||
"""
|
||||
Balancer for dynamically re-weighting multiple losses based on gradient norms.
|
||||
|
||||
Args:
|
||||
weights (dict): Predefined weights for each loss.
|
||||
Example: {"mse_loss": 1.0, "adv_loss": 1.0}
|
||||
ema_decay (float): Decay factor for exponential moving average (default: 0.99).
|
||||
epsilon (float): Small value to avoid division by zero (default: 1e-8).
|
||||
"""
|
||||
def __init__(self, weights, ema_decay=0.99, epsilon=1e-8):
|
||||
self.weights = weights
|
||||
self.ema_decay = ema_decay
|
||||
self.epsilon = epsilon
|
||||
self.ema_values = {key: 0.0 for key in weights} # Initialize EMA for each loss
|
||||
|
||||
def forward(self, losses, grad_inputs):
|
||||
"""
|
||||
Re-weight the input losses based on gradient norms and return a combined loss.
|
||||
|
||||
Args:
|
||||
losses (dict): Dictionary of losses with names as keys and loss tensors as values.
|
||||
Example: {"mse_loss": mse_loss, "adv_loss": adv_loss}
|
||||
grad_inputs (dict): Dictionary of inputs for autograd.grad corresponding to each loss.
|
||||
Example: {"mse_loss": recon_mels, "adv_loss": recon_mels}
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Combined weighted loss.
|
||||
"""
|
||||
# Validate inputs
|
||||
if set(losses.keys()) != set(grad_inputs.keys()):
|
||||
raise ValueError("Keys of losses and grad_inputs must match.")
|
||||
|
||||
norm_values = {}
|
||||
|
||||
# Compute gradient norms for each loss
|
||||
for name, loss in losses.items():
|
||||
loss_grad, = grad(loss.mean(), [grad_inputs[name]], create_graph=True)
|
||||
dims = tuple(range(1, loss_grad.ndim)) # Exclude batch dimension
|
||||
grad_norm = torch.linalg.vector_norm(loss_grad, ord=2, dim=dims).mean()
|
||||
|
||||
# Update EMA for the gradient norm
|
||||
if self.ema_values[name] == 0.0:
|
||||
self.ema_values[name] = grad_norm.item()
|
||||
else:
|
||||
self.ema_values[name] = (
|
||||
self.ema_values[name] * self.ema_decay + grad_norm.item() * (1 - self.ema_decay)
|
||||
)
|
||||
|
||||
# Normalize gradient norm
|
||||
norm_values[name] = grad_norm / (self.ema_values[name] + self.epsilon)
|
||||
|
||||
# Compute dynamic weights
|
||||
total_norm = sum(norm_values.values())
|
||||
dynamic_weights = {name: norm / total_norm for name, norm in norm_values.items()}
|
||||
|
||||
# Combine losses with dynamic weights
|
||||
loss = 0.0
|
||||
log_weights = {}
|
||||
for name in losses:
|
||||
loss = loss + self.weights[name] * dynamic_weights[name] * losses[name]
|
||||
log_weights[f"{name}_weight"] = dynamic_weights[name]
|
||||
return loss, log_weights
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Example usage: a tiny generator/discriminator pair whose reconstruction
    # and adversarial losses are combined through the Balancer.
    mel_real = torch.randn(1, 80, 10)
    generator = torch.nn.Linear(10, 10)
    recon_mels = generator(mel_real)
    discriminator = torch.nn.Linear(10, 1)
    disc_out = discriminator(recon_mels)

    mse_loss = torch.nn.functional.mse_loss(recon_mels, mel_real).mean()
    adv_loss = torch.nn.functional.softplus(-disc_out).mean()
    losses = {"mse_loss": mse_loss, "adv_loss": adv_loss}
    # Both losses are differentiated w.r.t. the same generator output.
    grad_inputs = {"mse_loss": recon_mels, "adv_loss": recon_mels}
    print("losses", losses)
    # Define predefined weights for each loss
    weights = {"mse_loss": 1.0, "adv_loss": 1.0}

    # Initialize balancer
    balancer = Balancer(weights)

    # Forward pass
    loss, log_weights = balancer.forward(losses, grad_inputs)
    print("Combined Loss:", loss)
    print("Dynamic Weights:", log_weights)
69
music_dcae/config_f8c8.json
Normal file
69
music_dcae/config_f8c8.json
Normal file
@@ -0,0 +1,69 @@
|
||||
{
|
||||
"_class_name": "AutoencoderDC",
|
||||
"_diffusers_version": "0.32.1",
|
||||
"_name_or_path": "mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers",
|
||||
"attention_head_dim": 32,
|
||||
"decoder_act_fns": "silu",
|
||||
"decoder_block_out_channels": [
|
||||
128,
|
||||
256,
|
||||
512,
|
||||
1024
|
||||
],
|
||||
"decoder_block_types": [
|
||||
"ResBlock",
|
||||
"ResBlock",
|
||||
"ResBlock",
|
||||
"EfficientViTBlock"
|
||||
],
|
||||
"decoder_layers_per_block": [
|
||||
3,
|
||||
3,
|
||||
3,
|
||||
3
|
||||
],
|
||||
"decoder_norm_types": "rms_norm",
|
||||
"decoder_qkv_multiscales": [
|
||||
[],
|
||||
[],
|
||||
[
|
||||
5
|
||||
],
|
||||
[
|
||||
5
|
||||
]
|
||||
],
|
||||
"downsample_block_type": "Conv",
|
||||
"encoder_block_out_channels": [
|
||||
128,
|
||||
256,
|
||||
512,
|
||||
1024
|
||||
],
|
||||
"encoder_block_types": [
|
||||
"ResBlock",
|
||||
"ResBlock",
|
||||
"ResBlock",
|
||||
"EfficientViTBlock"
|
||||
],
|
||||
"encoder_layers_per_block": [
|
||||
2,
|
||||
2,
|
||||
3,
|
||||
3
|
||||
],
|
||||
"encoder_qkv_multiscales": [
|
||||
[],
|
||||
[],
|
||||
[
|
||||
5
|
||||
],
|
||||
[
|
||||
5
|
||||
]
|
||||
],
|
||||
"in_channels": 2,
|
||||
"latent_channels": 8,
|
||||
"scaling_factor": 0.41407,
|
||||
"upsample_block_type": "interpolate"
|
||||
}
|
||||
124
music_dcae/distrib.py
Normal file
124
music_dcae/distrib.py
Normal file
@@ -0,0 +1,124 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""Torch distributed utilities."""
|
||||
|
||||
import typing as tp
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def rank():
    """Return this process's distributed rank, or 0 when not distributed."""
    if not torch.distributed.is_initialized():
        return 0
    return torch.distributed.get_rank()
def world_size():
    """Return the total number of workers, or 1 when not distributed."""
    if not torch.distributed.is_initialized():
        return 1
    return torch.distributed.get_world_size()
def is_distributed():
    """Whether this process is part of a multi-worker run."""
    num_workers = world_size()
    return num_workers > 1
def all_reduce(tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM):
    """All-reduce *tensor* in place across workers; no-op when single-process."""
    if not is_distributed():
        return None
    return torch.distributed.all_reduce(tensor, op)
def _is_complex_or_float(tensor):
|
||||
return torch.is_floating_point(tensor) or torch.is_complex(tensor)
|
||||
|
||||
|
||||
def _check_number_of_params(params: tp.List[torch.Tensor]):
    """Verify every worker holds the same number of params.

    A mismatch would deadlock the collective all-reduce/broadcast, so fail
    fast with a RuntimeError instead.
    """
    if not is_distributed() or not params:
        return
    count = torch.tensor([len(params)], device=params[0].device, dtype=torch.long)
    all_reduce(count)
    if count.item() != len(params) * world_size():
        # If not all workers have the same number, at least one of them
        # will observe this inequality.
        raise RuntimeError(f"Mismatch in number of params: ours is {len(params)}, "
                           "at least one worker has a different one.")
def broadcast_tensors(tensors: tp.Iterable[torch.Tensor], src: int = 0):
    """Broadcast the given tensors from worker `src` to all workers.

    This can be used to ensure that all workers have the same model to start with.
    Only floating-point/complex tensors participate.
    """
    if not is_distributed():
        return
    floatish = [t for t in tensors if _is_complex_or_float(t)]
    _check_number_of_params(floatish)
    # launch all broadcasts asynchronously, then wait on them together
    pending = [
        torch.distributed.broadcast(t.data, src=src, async_op=True)
        for t in floatish
    ]
    for handle in pending:
        handle.wait()
def sync_buffer(buffers, average=True):
    """Synchronize module buffers across workers.

    When ``average`` is True each floating-point buffer is replaced by the
    mean over all workers; otherwise rank 0's value is broadcast to everyone.
    No-op outside of distributed runs.
    """
    if not is_distributed():
        return
    handles = []
    for buffer in buffers:
        if torch.is_floating_point(buffer.data):
            if average:
                handle = torch.distributed.all_reduce(
                    buffer.data, op=torch.distributed.ReduceOp.SUM, async_op=True)
            else:
                handle = torch.distributed.broadcast(
                    buffer.data, src=0, async_op=True)
            handles.append((buffer, handle))
    for buffer, handle in handles:
        handle.wait()
        if average:
            # FIX: the original divided by the `world_size` function object
            # (missing call parentheses), raising a TypeError at runtime.
            buffer.data /= world_size()
def sync_grad(params):
    """Average gradients across workers after ``backward()``.

    Simpler alternative to DistributedDataParallel, that doesn't rely
    on any black magic. For simple models it can also be as fast.
    Just call this on your model parameters after the call to backward!
    """
    if not is_distributed():
        return
    pending = []
    for param in params:
        if param.grad is not None:
            work = torch.distributed.all_reduce(
                param.grad.data, op=torch.distributed.ReduceOp.SUM, async_op=True)
            pending.append((param, work))
    for param, work in pending:
        work.wait()
        param.grad.data /= world_size()
def average_metrics(metrics: tp.Dict[str, float], count=1.):
    """Average a dictionary of metrics across all workers, using the optional
    `count` as unnormalized weight.
    """
    if not is_distributed():
        return metrics
    keys, values = zip(*metrics.items())
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # append a trailing 1 so the total weight rides along in the same reduce
    packed = torch.tensor(list(values) + [1], device=device, dtype=torch.float32)
    packed *= count
    all_reduce(packed)
    averaged = (packed[:-1] / packed[-1]).cpu().tolist()
    return dict(zip(keys, averaged))
78
music_dcae/music_dcae.py
Normal file
78
music_dcae/music_dcae.py
Normal file
@@ -0,0 +1,78 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from diffusers import AutoencoderDC
|
||||
import json
|
||||
|
||||
|
||||
DEFAULT_CONFIG_PATH = "/root/sag_train/music_dcae/config_f32c32_large.json"
|
||||
|
||||
class MusicDCAE(nn.Module):
    """Thin wrapper around diffusers' AutoencoderDC, configured from a JSON file."""

    def __init__(self, config_path=DEFAULT_CONFIG_PATH):
        super().__init__()
        with open(config_path) as cfg_file:
            cfg = json.load(cfg_file)
        self.dcae = AutoencoderDC(**cfg)

    def encode(self, x):
        """Map inputs to DCAE latents."""
        return self.dcae.encode(x).latent

    def decode(self, latent):
        """Map DCAE latents back to the input space."""
        return self.dcae.decode(latent).sample

    def forward(self, x):
        """Full autoencoding round trip."""
        return self.dcae(x).sample

    def return_middle_layers(self):
        """Layers surrounding the bottleneck (deepest encoder stage, the two
        bottleneck convs, and the first decoder stage)."""
        encoder = self.dcae.encoder
        decoder = self.dcae.decoder
        return [
            encoder.down_blocks[-1],
            encoder.conv_out,
            decoder.conv_in,
            decoder.up_blocks[0],
        ]

    def return_head_layers(self):
        """Final decoder layers that produce the output."""
        decoder = self.dcae.decoder
        return [decoder.up_blocks[-1], decoder.conv_out]
if __name__ == "__main__":
    # Smoke test: build the f8c8 model, run an encode and a full round trip
    # on a random stereo mel batch, and report parameter counts.
    # NOTE(review): hard-coded absolute config path — adjust for your setup.
    model = MusicDCAE("/root/sag_train/music_dcae/config_f8c8_large.json")

    x = torch.randn(1, 2, 128, 1024)
    # mask = None
    # if mask is None:
    #     mask = torch.ones(x.shape[0], 1, x.shape[2], x.shape[3]).to(x.device)
    #     # N x 1024
    # elif len(mask.shape) == 2:
    #     mask = mask.unsqueeze(1).unsqueeze(1).float()
    #     mask = mask.repeat(1, 1, x.shape[2], 1)
    latent = model.encode(x)
    print("latent shape: ", latent.shape)
    y = model(x)
    print("y", y.shape)
    total_params = sum(p.numel() for p in model.parameters())
    print(f"模型参数总数: {total_params / 1e6:.2f}M")

    # middle_layers = model.return_middle_layers()
    # middle_params_count = 0
    # for layer in middle_layers:
    #     for name, param in layer.named_parameters():
    #         layer_param_count = param.numel()
    #         middle_params_count += layer_param_count
    #         print(f"{name}: {param.shape}, params: {layer_param_count/1e6:.2f}M")

    # print(f"middle layers total params: {middle_params_count/1e6:.2f}M")

    # head_layers = model.return_head_layers()
    # head_params_count = 0
    # for layer in head_layers:
    #     for name, param in layer.named_parameters():
    #         layer_param_count = param.numel()
    #         head_params_count += layer_param_count
    #         print(f"{name}: {param.shape}, params: {layer_param_count/1e6:.2f}M")

    # print(f"head layers total params: {head_params_count/1e6:.2f}M")
155
music_dcae/music_dcae_pipeline.py
Normal file
155
music_dcae/music_dcae_pipeline.py
Normal file
@@ -0,0 +1,155 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from diffusers import AutoencoderDC
|
||||
import torchaudio
|
||||
import torchvision.transforms as transforms
|
||||
import torchaudio
|
||||
|
||||
try:
|
||||
from .music_log_mel import get_mel_transform
|
||||
from .music_vocoder import ADaMoSHiFiGANV1
|
||||
except ImportError:
|
||||
from music_log_mel import get_mel_transform
|
||||
from music_vocoder import ADaMoSHiFiGANV1
|
||||
|
||||
import os
|
||||
|
||||
root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
DEFAULT_PRETRAINED_PATH = os.path.join(root_dir, "checkpoints", "music_dcae_f8c8")
|
||||
VOCODER_PRETRAINED_PATH = os.path.join(root_dir, "checkpoints", "music_vocoder.pt")
|
||||
|
||||
|
||||
class MusicDCAE(nn.Module):
    """End-to-end audio <-> latent pipeline.

    Encode: audio -> resample to 44.1 kHz -> log-mel -> DCAE encoder -> latent.
    Decode: latent -> DCAE decoder -> mel -> HiFi-GAN vocoder -> waveform.
    """

    def __init__(self, pretrained_path=DEFAULT_PRETRAINED_PATH, encoder_only=False, source_sample_rate=None):
        super(MusicDCAE, self).__init__()
        dcae = AutoencoderDC.from_pretrained(pretrained_path)
        self.encoder_only = encoder_only

        self.mel_transform = get_mel_transform()
        if encoder_only:
            # decoder and vocoder are omitted to save memory when only encoding
            self.encoder = dcae.encoder
        else:
            self.encoder = dcae.encoder
            self.decoder = dcae.decoder
            self.vocoder = ADaMoSHiFiGANV1(VOCODER_PRETRAINED_PATH).eval()

        if source_sample_rate is None:
            source_sample_rate = 48000

        # default resampler; encode() builds an ad-hoc one when sr differs
        self.resampler = torchaudio.transforms.Resample(source_sample_rate, 44100)

        # maps min-max normalized [0, 1] mels to [-1, 1] for the DCAE
        self.transform = transforms.Compose([
            transforms.Normalize(0.5, 0.5),
        ])
        # log-mel dynamic range used for min-max normalization
        # (presumably empirical dataset statistics — TODO confirm)
        self.min_mel_value = -11.0
        self.max_mel_value = 3.0
        # 1024 mel frames x hop 512 at 44.1 kHz, expressed in 48 kHz samples
        self.audio_chunk_size = int(round((1024 * 512 / 44100 * 48000)))
        self.mel_chunk_size = 1024
        # DCAE downsamples the time axis by 8 ("dimention" spelling kept
        # because callers reference this attribute name)
        self.time_dimention_multiple = 8
        self.latent_chunk_size = self.mel_chunk_size // self.time_dimention_multiple
        # latent standardization constants (presumably dataset statistics)
        self.scale_factor = 0.1786
        self.shift_factor = -1.9091

    def load_audio(self, audio_path):
        """Load an audio file; returns (waveform, sample_rate)."""
        audio, sr = torchaudio.load(audio_path)
        return audio, sr

    def forward_mel(self, audios):
        """Compute log-mel spectrograms for each batch element."""
        mels = []
        for i in range(len(audios)):
            image = self.mel_transform(audios[i])
            mels.append(image)
        mels = torch.stack(mels)
        return mels

    @torch.no_grad()
    def encode(self, audios, audio_lengths=None, sr=None):
        """Encode waveforms into normalized DCAE latents.

        Args:
            audios: N x 2 x T waveforms.
            audio_lengths: per-sample lengths in samples; defaults to full T.
            sr: input sample rate; defaults to 48000 (uses the cached resampler).

        Returns:
            (latents, latent_lengths)
        """
        if audio_lengths is None:
            audio_lengths = torch.tensor([audios.shape[2]] * audios.shape[0])
            audio_lengths = audio_lengths.to(audios.device)

        # audios: N x 2 x T, 48kHz
        device = audios.device
        dtype = audios.dtype

        if sr is None:
            sr = 48000
            resampler = self.resampler
        else:
            # ad-hoc resampler for a non-default input rate
            resampler = torchaudio.transforms.Resample(sr, 44100).to(device).to(dtype)

        audio = resampler(audios)

        # pad so the time axis divides evenly into latent frames (8 x hop 512)
        max_audio_len = audio.shape[-1]
        if max_audio_len % (8 * 512) != 0:
            audio = torch.nn.functional.pad(audio, (0, 8 * 512 - max_audio_len % (8 * 512)))

        mels = self.forward_mel(audio)
        # min-max normalize to [0, 1], then shift to [-1, 1]
        mels = (mels - self.min_mel_value) / (self.max_mel_value - self.min_mel_value)
        mels = self.transform(mels)
        latents = []
        # encode one sample at a time to bound peak memory
        for mel in mels:
            latent = self.encoder(mel.unsqueeze(0))
            latents.append(latent)
        latents = torch.cat(latents, dim=0)
        latent_lengths = (audio_lengths / sr * 44100 / 512 / self.time_dimention_multiple).long()
        # standardize latents
        latents = (latents - self.shift_factor) * self.scale_factor
        return latents, latent_lengths

    @torch.no_grad()
    def decode(self, latents, audio_lengths=None, sr=None):
        """Decode latents back to waveforms.

        Returns (sample_rate, list_of_waveforms). When `sr` is given, the
        44.1 kHz vocoder output is resampled to it.
        """
        # undo latent standardization
        latents = latents / self.scale_factor + self.shift_factor

        mels = []

        # decode one sample at a time to bound peak memory
        for latent in latents:
            mel = self.decoder(latent.unsqueeze(0))
            mels.append(mel)
        mels = torch.cat(mels, dim=0)

        # [-1, 1] -> [0, 1] -> original log-mel value range
        mels = mels * 0.5 + 0.5
        mels = mels * (self.max_mel_value - self.min_mel_value) + self.min_mel_value
        bsz, channels, num_mel, mel_width = mels.shape
        pred_wavs = []
        for i in range(bsz):
            mel = mels[i]
            wav = self.vocoder.decode(mel).squeeze(1)
            pred_wavs.append(wav)

        pred_wavs = torch.stack(pred_wavs)

        if sr is not None:
            resampler = torchaudio.transforms.Resample(44100, sr).to(latents.device).to(latents.dtype)
            pred_wavs = [resampler(wav) for wav in pred_wavs]
        else:
            sr = 44100
        if audio_lengths is not None:
            # NOTE(review): lengths are in source-rate samples; trimming at
            # the output rate assumes the two rates match — confirm.
            pred_wavs = [wav[:, :length].cpu() for wav, length in zip(pred_wavs, audio_lengths)]
        return sr, pred_wavs

    def forward(self, audios, audio_lengths=None, sr=None):
        """Round trip: encode then decode.

        Returns (sr, pred_wavs, latents, latent_lengths).
        """
        latents, latent_lengths = self.encode(audios=audios, audio_lengths=audio_lengths, sr=sr)
        sr, pred_wavs = self.decode(latents=latents, audio_lengths=audio_lengths, sr=sr)
        return sr, pred_wavs, latents, latent_lengths
if __name__ == "__main__":
    # Round-trip smoke test: load a wav, encode to latents, decode back,
    # and save the reconstruction. Paths are developer-machine specific.

    audio, sr = torchaudio.load("/root/data/repo/gongjunmin/sag_train/orig2.wav")
    audio_lengths = torch.tensor([audio.shape[1]])
    audios = audio.unsqueeze(0)

    # test encode only
    model = MusicDCAE()
    # latents, latent_lengths = model.encode(audios, audio_lengths)
    # print("latents shape: ", latents.shape)
    # print("latent_lengths: ", latent_lengths)

    # test encode and decode
    sr, pred_wavs, latents, latent_lengths = model(audios, audio_lengths, sr)
    print("reconstructed wavs: ", pred_wavs[0].shape)
    print("latents shape: ", latents.shape)
    print("latent_lengths: ", latent_lengths)
    print("sr: ", sr)
    torchaudio.save("/root/data/repo/gongjunmin/sag_train/reconstructed.wav", pred_wavs[0], sr)
    print("reconstructed wav saved to /root/data/repo/gongjunmin/sag_train/reconstructed.wav")
551
music_dcae/music_dcae_refiner.py
Normal file
551
music_dcae/music_dcae_refiner.py
Normal file
@@ -0,0 +1,551 @@
|
||||
from typing import Tuple, Union, Optional, Dict, Any
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from diffusers.models.autoencoders.autoencoder_dc import DCUpBlock2d, get_block, RMSNorm, Decoder
|
||||
from diffusers.models.transformers.sana_transformer import SanaTransformerBlock
|
||||
from diffusers.models.embeddings import get_2d_sincos_pos_embed
|
||||
from diffusers.models.normalization import AdaLayerNormSingle, RMSNorm
|
||||
from diffusers.models.modeling_utils import ModelMixin
|
||||
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
||||
from diffusers.models.attention_processor import AttentionProcessor
|
||||
from diffusers.models.modeling_outputs import Transformer2DModelOutput
|
||||
from diffusers.utils import is_torch_version
|
||||
from diffusers.models.unets import UNet2DModel
|
||||
|
||||
|
||||
class Encoder(nn.Module):
    """Latent refiner head built from DCAE building blocks.

    Despite the name, the layout mirrors a DCAE *decoder*: a conv-in at the
    widest channel count, upsampling stages down to ``block_out_channels[0]``,
    then RMSNorm / ReLU / conv-out. With zero stages it acts as identity.
    """

    def __init__(
        self,
        in_channels: int = 32,
        out_channels: int = 8,
        attention_head_dim: int = 32,
        block_out_channels: Tuple[int] = (512, 1024, 2048),
        layers_per_block: Tuple[int] = (3, 3, 3),
        block_type: str = "EfficientViTBlock",
        norm_type: str = "rms_norm",
        act_fn: str = "silu",
        qkv_multiscales: tuple = (5,),
    ):
        super(Encoder, self).__init__()

        num_blocks = len(block_out_channels)

        # with no stages configured the module degenerates to identity
        # ("dump" presumably means "dummy" — name kept for compatibility)
        self.dump_encoder = False
        if num_blocks == 0:
            self.dump_encoder = True
            return

        self.conv_in = nn.Conv2d(in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1)

        up_blocks = []
        # build stages from the deepest (widest) outwards
        for i, (out_channel, num_layers) in reversed(list(enumerate(zip(block_out_channels, layers_per_block)))):
            up_block_list = []

            # every stage except the deepest starts with an upsample
            if i < num_blocks - 1 and num_layers > 0:
                upsample_block = DCUpBlock2d(
                    block_out_channels[i + 1],
                    out_channel,
                    interpolate=True,
                    shortcut=True,
                )
                up_block_list.append(upsample_block)

            for _ in range(num_layers):
                block = get_block(
                    block_type,
                    out_channel,
                    out_channel,
                    attention_head_dim=attention_head_dim,
                    norm_type=norm_type,
                    act_fn=act_fn,
                    qkv_mutliscales=qkv_multiscales,
                )
                up_block_list.append(block)

            # prepend so self.up_blocks stays ordered shallow -> deep
            up_blocks.insert(0, nn.Sequential(*up_block_list))

        self.up_blocks = nn.ModuleList(up_blocks)

        self.norm_out = RMSNorm(block_out_channels[0], 1e-5, elementwise_affine=True, bias=True)
        self.conv_act = nn.ReLU()
        self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Refine `hidden_states`; identity when configured with zero stages."""
        if self.dump_encoder:
            return hidden_states

        hidden_states = self.conv_in(hidden_states)
        i = 0
        # run deepest stage first (up_blocks is stored shallow -> deep)
        for up_block in reversed(self.up_blocks):
            hidden_states = up_block(hidden_states)
            i += 1

        # RMSNorm operates on a channel-last layout
        hidden_states = self.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
        hidden_states = self.conv_act(hidden_states)
        hidden_states = self.conv_out(hidden_states)

        return hidden_states
class PatchEmbed(nn.Module):
|
||||
"""
|
||||
2D Image to Patch Embedding with support for SD3 cropping.
|
||||
|
||||
Args:
|
||||
height (`int`, defaults to `224`): The height of the image.
|
||||
width (`int`, defaults to `224`): The width of the image.
|
||||
patch_size (`int`, defaults to `16`): The size of the patches.
|
||||
in_channels (`int`, defaults to `3`): The number of input channels.
|
||||
embed_dim (`int`, defaults to `768`): The output dimension of the embedding.
|
||||
layer_norm (`bool`, defaults to `False`): Whether or not to use layer normalization.
|
||||
flatten (`bool`, defaults to `True`): Whether or not to flatten the output.
|
||||
bias (`bool`, defaults to `True`): Whether or not to use bias.
|
||||
interpolation_scale (`float`, defaults to `1`): The scale of the interpolation.
|
||||
pos_embed_type (`str`, defaults to `"sincos"`): The type of positional embedding.
|
||||
pos_embed_max_size (`int`, defaults to `None`): The maximum size of the positional embedding.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
height=16,
|
||||
width=128,
|
||||
patch_size=(16,1),
|
||||
in_channels=16,
|
||||
embed_dim=768,
|
||||
layer_norm=False,
|
||||
flatten=True,
|
||||
bias=True,
|
||||
interpolation_scale=1,
|
||||
pos_embed_type="sincos",
|
||||
pos_embed_max_size=None, # For SD3 cropping
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
num_patches = (height // patch_size[0]) * (width // patch_size[1])
|
||||
self.flatten = flatten
|
||||
self.layer_norm = layer_norm
|
||||
self.pos_embed_max_size = pos_embed_max_size
|
||||
|
||||
self.proj = nn.Conv2d(
|
||||
in_channels, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias
|
||||
)
|
||||
if layer_norm:
|
||||
self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
|
||||
else:
|
||||
self.norm = None
|
||||
|
||||
self.patch_size = patch_size
|
||||
self.height, self.width = height // patch_size[0], width // patch_size[1]
|
||||
self.base_size = height // patch_size[1]
|
||||
self.interpolation_scale = interpolation_scale
|
||||
|
||||
# Calculate positional embeddings based on max size or default
|
||||
if pos_embed_max_size:
|
||||
grid_size = pos_embed_max_size
|
||||
else:
|
||||
grid_size = int(num_patches**0.5)
|
||||
|
||||
if pos_embed_type is None:
|
||||
self.pos_embed = None
|
||||
elif pos_embed_type == "sincos":
|
||||
pos_embed = get_2d_sincos_pos_embed(
|
||||
embed_dim,
|
||||
grid_size,
|
||||
base_size=self.base_size,
|
||||
interpolation_scale=self.interpolation_scale,
|
||||
output_type="pt",
|
||||
)
|
||||
persistent = True if pos_embed_max_size else False
|
||||
self.register_buffer("pos_embed", pos_embed.float().unsqueeze(0), persistent=persistent)
|
||||
else:
|
||||
raise ValueError(f"Unsupported pos_embed_type: {pos_embed_type}")
|
||||
|
||||
def cropped_pos_embed(self, height, width):
|
||||
"""Crops positional embeddings for SD3 compatibility."""
|
||||
if self.pos_embed_max_size is None:
|
||||
raise ValueError("`pos_embed_max_size` must be set for cropping.")
|
||||
|
||||
height = height // self.patch_size
|
||||
width = width // self.patch_size
|
||||
if height > self.pos_embed_max_size:
|
||||
raise ValueError(
|
||||
f"Height ({height}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
|
||||
)
|
||||
if width > self.pos_embed_max_size:
|
||||
raise ValueError(
|
||||
f"Width ({width}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
|
||||
)
|
||||
|
||||
top = (self.pos_embed_max_size - height) // 2
|
||||
left = (self.pos_embed_max_size - width) // 2
|
||||
spatial_pos_embed = self.pos_embed.reshape(1, self.pos_embed_max_size, self.pos_embed_max_size, -1)
|
||||
spatial_pos_embed = spatial_pos_embed[:, top : top + height, left : left + width, :]
|
||||
spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1])
|
||||
return spatial_pos_embed
|
||||
|
||||
def forward(self, latent):
    """Patchify `latent`, add positional embeddings, and return tokens.

    Args:
        latent: input feature map; last two dims are spatial (H, W) —
            presumably (B, C, H, W), matching the Conv2d projection.

    Returns:
        Token tensor; (B, N, C) when ``self.flatten`` is set.
    """
    # With a max-size grid the crop works in post-patch units directly;
    # otherwise convert pixel dims to patch-grid dims first.
    if self.pos_embed_max_size is not None:
        height, width = latent.shape[-2:]
    else:
        height, width = latent.shape[-2] // self.patch_size[0], latent.shape[-1] // self.patch_size[1]
    latent = self.proj(latent)
    if self.flatten:
        latent = latent.flatten(2).transpose(1, 2)  # BCHW -> BNC
    if self.layer_norm:
        latent = self.norm(latent)
    if self.pos_embed is None:
        # No positional embedding configured; `.to(latent.dtype)` is a no-op
        # kept for symmetry with the return below.
        return latent.to(latent.dtype)
    # Interpolate or crop positional embeddings as needed
    if self.pos_embed_max_size:
        pos_embed = self.cropped_pos_embed(height, width)
    else:
        if self.height != height or self.width != width:
            # Input resolution differs from the one the buffer was built for:
            # regenerate sincos embeddings on the fly for this grid size.
            pos_embed = get_2d_sincos_pos_embed(
                embed_dim=self.pos_embed.shape[-1],
                grid_size=(height, width),
                base_size=self.base_size,
                interpolation_scale=self.interpolation_scale,
                device=latent.device,
                output_type="pt",
            )
            pos_embed = pos_embed.float().unsqueeze(0)
        else:
            pos_embed = self.pos_embed

    # Cast back to the token dtype (pos_embed is float32).
    return (latent + pos_embed).to(latent.dtype)
|
||||
|
||||
|
||||
class DiTDecoder(ModelMixin, ConfigMixin):
    """DiT-style decoder: patch-embeds a latent, runs Sana transformer blocks
    conditioned on a timestep, and unpatchifies back to an image-shaped output.
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        sample_size: Tuple[int, int] = (16, 128),
        in_channels: int = 16,
        out_channels: int = 8,
        patch_size: Tuple[int, int] = (16, 1),
        inner_dim: int = 1152,
        num_attention_heads: int = 36,
        attention_head_dim: int = 32,
        dropout: float = 0.0,
        cross_attention_dim: Optional[int] = None,
        num_cross_attention_heads: Optional[int] = None,
        cross_attention_head_dim: Optional[int] = None,
        attention_bias: bool = False,
        norm_elementwise_affine: bool = False,
        norm_eps: float = 1e-6,
        interpolation_scale: int = 1,
        mlp_ratio: float = 2.5,
        num_layers: int = 12,
    ):
        super(DiTDecoder, self).__init__()
        # NOTE(review): the fallback `max(sample_size // 64, 1)` would fail for
        # the tuple `sample_size` used here; it is dead code while the default
        # `interpolation_scale=1` is non-None — confirm before relying on it.
        interpolation_scale = interpolation_scale if interpolation_scale is not None else max(sample_size // 64, 1)
        self.interpolation_scale = interpolation_scale

        # Latent -> token sequence of width `inner_dim`.
        self.patch_embed = PatchEmbed(
            height=sample_size[0],
            width=sample_size[1],
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=inner_dim,
            interpolation_scale=interpolation_scale,
        )

        # Single adaptive layer-norm conditioning on the diffusion timestep.
        self.time_embed = AdaLayerNormSingle(inner_dim)

        self.transformer_blocks = nn.ModuleList(
            [
                SanaTransformerBlock(
                    inner_dim,
                    num_attention_heads,
                    attention_head_dim,
                    dropout=dropout,
                    num_cross_attention_heads=num_cross_attention_heads,
                    cross_attention_head_dim=cross_attention_head_dim,
                    cross_attention_dim=cross_attention_dim,
                    attention_bias=attention_bias,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                    mlp_ratio=mlp_ratio,
                )
                for _ in range(num_layers)
            ]
        )

        # Learned (shift, scale) pair added to the timestep embedding for the
        # final modulation; scaled init keeps magnitudes ~1/sqrt(inner_dim).
        self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim ** 0.5)
        # NOTE(review): `norm_out` is created but never applied in `forward`
        # (diffusers' Sana transformer normalizes before the final modulation)
        # — confirm whether omitting it is intentional.
        self.norm_out = nn.LayerNorm(inner_dim, eps=1e-6, elementwise_affine=False)
        self.proj_out = nn.Linear(inner_dim, patch_size[0] * patch_size[1] * out_channels)
        self.gradient_checkpointing = False

    def _set_gradient_checkpointing(self, module, value=False):
        # diffusers hook: toggles checkpointing on any submodule exposing it.
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value

    @property
    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model with
            indexed by its weight name.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor()

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.

        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: Optional[int] = None,
        return_dict: bool = True,
    ):
        """Denoise one step.

        Args:
            hidden_states: (B, C, H, W) latent input.
            timestep: diffusion timestep fed to ``AdaLayerNormSingle``.
            return_dict: when False, return a plain ``(output,)`` tuple.
        """
        # 1. Input
        batch_size, num_channels, height, width = hidden_states.shape
        patch_size = self.config.patch_size

        post_patch_height, post_patch_width = height // patch_size[0], width // patch_size[1]

        hidden_states = self.patch_embed(hidden_states)

        # `timestep` becomes the per-block conditioning vector;
        # `embedded_timestep` feeds the final scale/shift modulation.
        timestep, embedded_timestep = self.time_embed(
            timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype
        )

        # 2. Transformer blocks
        if torch.is_grad_enabled() and self.gradient_checkpointing:

            def create_custom_forward(module, return_dict=None):
                def custom_forward(*inputs):
                    if return_dict is not None:
                        return module(*inputs, return_dict=return_dict)
                    else:
                        return module(*inputs)

                return custom_forward

            ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}

            # The three `None`s are the (unused here) attention mask /
            # encoder_hidden_states / encoder mask slots of SanaTransformerBlock.
            for block in self.transformer_blocks:
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    None,
                    None,
                    None,
                    timestep,
                    post_patch_height,
                    post_patch_width,
                    **ckpt_kwargs,
                )

        else:
            for block in self.transformer_blocks:
                hidden_states = block(
                    hidden_states,
                    None,
                    None,
                    None,
                    timestep,
                    post_patch_height,
                    post_patch_width,
                )

        # 3. Normalization
        # NOTE(review): `self.norm_out` is not applied here — see __init__ note.
        shift, scale = (
            self.scale_shift_table[None] + embedded_timestep[:, None].to(self.scale_shift_table.device)
        ).chunk(2, dim=1)

        # 4. Modulation
        hidden_states = hidden_states * (1 + scale) + shift
        hidden_states = self.proj_out(hidden_states)

        # 5. Unpatchify: tokens -> (B, out_channels, H, W).
        hidden_states = hidden_states.reshape(
            batch_size, post_patch_height, post_patch_width, self.config.patch_size[0], self.config.patch_size[1], -1
        )
        hidden_states = hidden_states.permute(0, 5, 1, 3, 2, 4)
        output = hidden_states.reshape(batch_size, -1, post_patch_height * patch_size[0], post_patch_width * patch_size[1])

        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)
|
||||
|
||||
|
||||
class MusicDcaeRefiner(ModelMixin, ConfigMixin):
    """Refiner that conditions a decoder on features from an EfficientViT-style
    encoder: the encoder compresses the conditioning signal, which is then
    channel-concatenated with the noisy input and decoded (DiT or conv UNet).
    """

    @register_to_config
    def __init__(
        self,
        in_channels: int = 32,
        attention_head_dim: int = 32,
        block_out_channels: Tuple[int] = (512, 1024, 2048),
        layers_per_block: Tuple[int] = (3, 3, 3),
        conv_block_out_channels: Tuple[int] = (224, 448, 672, 896),
        out_channels: int = 8,
        block_type: str = "EfficientViTBlock",
        norm_type: str = "rms_norm",
        act_fn: str = "silu",
        qkv_multiscales: tuple = (5,),
        sample_size: Tuple[int, int] = (16, 128),
        patch_size: Tuple[int, int] = (16, 1),
        inner_dim: int = 1152,
        num_attention_heads: int = 36,
        dropout: float = 0.0,
        cross_attention_dim: Optional[int] = None,
        num_cross_attention_heads: Optional[int] = None,
        cross_attention_head_dim: Optional[int] = None,
        attention_bias: bool = False,
        norm_elementwise_affine: bool = False,
        norm_eps: float = 1e-6,
        interpolation_scale: int = 1,
        mlp_ratio: float = 2.5,
        num_layers: int = 12,
        decoder_type: str = "ConvDecoder",

    ):
        super(MusicDcaeRefiner, self).__init__()

        # Conditioning encoder: `in_channels` -> `out_channels` feature map.
        self.encoder = Encoder(
            in_channels=in_channels,
            out_channels=out_channels,
            attention_head_dim=attention_head_dim,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            block_type=block_type,
            norm_type=norm_type,
            act_fn=act_fn,
            qkv_multiscales=qkv_multiscales,
        )
        # Decoder input is `out_channels * 2`: noisy latent + encoder features
        # concatenated along the channel axis in `forward`.
        if decoder_type == "DiTDecoder":
            self.decoder = DiTDecoder(
                sample_size=sample_size,
                in_channels=out_channels * 2,
                out_channels=out_channels,
                patch_size=patch_size,
                inner_dim=inner_dim,
                num_attention_heads=num_attention_heads,
                attention_head_dim=attention_head_dim,
                dropout=dropout,
                cross_attention_dim=cross_attention_dim,
                num_cross_attention_heads=num_cross_attention_heads,
                cross_attention_head_dim=cross_attention_head_dim,
                attention_bias=attention_bias,
                norm_elementwise_affine=norm_elementwise_affine,
                norm_eps=norm_eps,
                interpolation_scale=interpolation_scale,
                mlp_ratio=mlp_ratio,
                num_layers=num_layers,
            )
        else:
            # Default ("ConvDecoder"): plain diffusers UNet2DModel.
            self.decoder = UNet2DModel(
                sample_size=sample_size,
                in_channels=out_channels * 2,
                out_channels=out_channels,
                block_out_channels=conv_block_out_channels,
            )

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        timestep: Optional[int] = None,
        return_dict: bool = True
    ):
        """Encode the conditioning signal, concat with the noisy latent on the
        channel dim, and decode. Assumes both tensors share spatial size after
        encoding — TODO confirm against callers.
        """
        encoder_hidden_states = self.encoder(encoder_hidden_states)
        hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
        output = self.decoder(hidden_states, timestep=timestep, return_dict=return_dict)
        return output
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Smoke test for MusicDcaeRefiner. The first (commented-out) variant
    # exercised the default f32c32 -> f8c8 config; the active one loads the
    # f8c8 -> mel refiner config from JSON.

    # f32c32 -> f8c8
    # model = MusicDcaeRefiner()

    # x = torch.randn(1, 8, 16, 128)
    # encoder_x = torch.randn(1, 32, 4, 32)
    # timestep = 0
    # y = model(x, encoder_x, timestep=timestep)
    # print("y", y.sample.shape)
    # total_params = sum(p.numel() for p in model.parameters())
    # print(f"模型参数总数: {total_params / 1e6:.2f}M")

    # # count encoder / decoder parameters separately
    # encoder_params_count = sum(p.numel() for p in model.encoder.parameters())
    # decoder_params_count = sum(p.numel() for p in model.decoder.parameters())
    # print(f"encoder参数量: {encoder_params_count/1e6:.2f}M")
    # print(f"decoder参数量: {decoder_params_count/1e6:.2f}M")

    # f8c8 -> mel
    import json
    with open("music_dcae/config_f8c8_to_mel_refiner.json", "r") as f:
        config = json.load(f)
    model = MusicDcaeRefiner(**config)

    # Dummy stereo mel-sized tensors: (batch, channels, n_mels, frames).
    x = torch.randn(1, 2, 128, 1024)
    encoder_x = torch.randn(1, 2, 128, 1024)
    timestep = 0
    y = model(x, encoder_x, timestep=timestep)
    print("y", y.sample.shape)
    total_params = sum(p.numel() for p in model.parameters())
    print(f"模型参数总数: {total_params / 1e6:.2f}M")

    # Count encoder and decoder parameters separately.
    encoder_params_count = sum(p.numel() for p in model.encoder.parameters())
    decoder_params_count = sum(p.numel() for p in model.decoder.parameters())
    print(f"encoder参数量: {encoder_params_count/1e6:.2f}M")
    print(f"decoder参数量: {decoder_params_count/1e6:.2f}M")
|
||||
157
music_dcae/music_dcae_vocoder.py
Normal file
157
music_dcae/music_dcae_vocoder.py
Normal file
@@ -0,0 +1,157 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from diffusers import AutoencoderDC
|
||||
import json
|
||||
import torchvision.transforms as transforms
|
||||
import torchaudio
|
||||
|
||||
try:
|
||||
from .music_vocoder import ADaMoSHiFiGANV1
|
||||
except ImportError:
|
||||
from music_vocoder import ADaMoSHiFiGANV1
|
||||
|
||||
|
||||
DEFAULT_CONFIG_PATH = "/root/sag_train/music_dcae/config_f32c32_large.json"
|
||||
DCAE_PRETRAINED_PATH = "/root/sag_train/checkpoints/music_dcae_f32c32"
|
||||
VOCODER_PRETRAINED_PATH = "/root/sag_train/checkpoints/music_vocoder.pt"
|
||||
|
||||
|
||||
class MusicDCAEVocoder(nn.Module):
    """Audio <-> latent pipeline: waveform -> log-mel -> DCAE latent and back,
    with an ADaMoS HiFi-GAN vocoder for mel -> waveform reconstruction.

    Mels are min/max-normalized to [0, 1] then shifted to [-1, 1] before the
    DCAE; the vocoder is frozen at construction.
    """

    def __init__(self, config_path=DEFAULT_CONFIG_PATH, pretrained_path=DCAE_PRETRAINED_PATH):
        super(MusicDCAEVocoder, self).__init__()
        # Build the DCAE either from a JSON config (random init) or from a
        # pretrained checkpoint directory.
        if pretrained_path is None:
            with open(config_path) as f:
                config = json.load(f)
            self.dcae = AutoencoderDC(**config)
        else:
            self.dcae = AutoencoderDC.from_pretrained(pretrained_path)
        self.vocoder = ADaMoSHiFiGANV1(VOCODER_PRETRAINED_PATH)
        self.freeze_vocoder()
        # Maps [0, 1]-normalized mels to [-1, 1] for the autoencoder.
        self.transform = transforms.Compose([
            transforms.Normalize(0.5, 0.5),
        ])
        # Empirical log-mel dynamic range used for min/max normalization.
        self.min_mel_value = -11.0
        self.max_mel_value = 3.0
        self.target_sr = 44100

    def load_audio(self, audio_path):
        """Load audio; duplicate a mono channel so output is always stereo."""
        audio, sr = torchaudio.load(audio_path)
        if audio.shape[0] == 1:
            audio = torch.cat([audio, audio], dim=0)
        return audio, sr

    def resample_audio(self, audio, sr=48000):
        """Resample from `sr` to the model's 44.1 kHz target rate."""
        resampler = torchaudio.transforms.Resample(sr, self.target_sr)
        resampler = resampler.to(audio.device)
        audio = resampler(audio)
        return audio

    def forward_mel(self, audios):
        """Per-item log-mel via the vocoder's mel transform.

        Assumes the vocoder exposes `mel_transform` — TODO confirm in
        ADaMoSHiFiGANV1.
        """
        mels = []
        for i in range(len(audios)):
            image = self.vocoder.mel_transform(audios[i])
            mels.append(image)
        mels = torch.stack(mels)
        return mels

    def norm_mel(self, mels):
        """Min/max-normalize log-mels to [0, 1], then shift to [-1, 1]."""
        normed_mels = (mels - self.min_mel_value) / (self.max_mel_value - self.min_mel_value)
        normed_mels = self.transform(normed_mels)
        return normed_mels

    def denorm_mel(self, normed_mels):
        """Inverse of `norm_mel`: [-1, 1] back to raw log-mel range."""
        mels = normed_mels * 0.5 + 0.5
        mels = mels * (self.max_mel_value - self.min_mel_value) + self.min_mel_value
        return mels

    def encode_latent(self, normed_mels):
        # N x 2 x 128 x W -> N x C x 128//F x W//F
        latent = self.dcae.encode(normed_mels).latent
        return latent

    def decode_mel(self, latent):
        # N x C x 128//F x W//F -> N x 2 x 128 x W
        normed_mels = self.dcae.decode(latent).sample
        return normed_mels

    def decode_audio(self, mels):
        # mels: N x 2 x 128 x W -> 2N x 128 x W
        # Vocoder is mono: fold the stereo channels into the batch, decode,
        # then unfold back to (N, 2, samples).
        bs = mels.shape[0]
        mono_mels = mels.reshape(-1, 128, mels.shape[-1])
        mono_audios = self.vocoder(mono_mels)
        audios = mono_audios.reshape(bs, 2, -1)
        return audios

    def encode(self, audios):
        """Waveforms -> (DCAE latent, raw log-mels)."""
        mels = self.forward_mel(audios)
        normed_mels = self.norm_mel(mels)
        latent = self.encode_latent(normed_mels)
        return latent, mels

    def decode(self, latent):
        """Latent -> (reconstructed waveforms, reconstructed log-mels)."""
        recon_normed_mels = self.decode_mel(latent)
        recon_mels = self.denorm_mel(recon_normed_mels)
        recon_audios = self.decode_audio(recon_mels)
        return recon_audios, recon_mels

    def forward(self, audios):
        """Round-trip encode/decode; output is trimmed/zero-padded to the
        input length so shapes match for reconstruction losses."""
        audios_len = audios.shape[-1]
        latent, mels = self.encode(audios)
        recon_audios, recon_mels = self.decode(latent)
        if recon_audios.shape[-1] > audios_len:
            recon_audios = recon_audios[:, :, :audios_len]
        elif recon_audios.shape[-1] < audios_len:
            recon_audios = F.pad(recon_audios, (0, audios_len - recon_audios.shape[-1]))
        return recon_audios, mels, recon_mels, latent

    def freeze_vocoder(self):
        """Put the vocoder in eval mode and stop its gradients."""
        self.vocoder.eval()
        self.vocoder.requires_grad_(True if False else False)

    def unfreeze_vocoder(self):
        """Re-enable vocoder training."""
        self.vocoder.train()
        self.vocoder.requires_grad_(True)

    def return_middle_layers(self):
        """Bottleneck-adjacent DCAE layers, e.g. for partial fine-tuning."""
        last_down_block = self.dcae.encoder.down_blocks[-1]
        encoder_conv_out = self.dcae.encoder.conv_out
        decoder_conv_in = self.dcae.decoder.conv_in
        decoder_up_blocks = self.dcae.decoder.up_blocks[0]
        middle_layers = [last_down_block, encoder_conv_out, decoder_conv_in, decoder_up_blocks]
        return middle_layers

    def return_head_layers(self):
        """Final DCAE decoder layers (last up block + output conv)."""
        decoder_up_blocks = self.dcae.decoder.up_blocks[-1]
        conv_out = self.dcae.decoder.conv_out
        head_layers = [decoder_up_blocks, conv_out]
        return head_layers
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Round-trip smoke test: load a wav, encode to DCAE latent, decode back,
    # and save the reconstruction.
    model = MusicDCAEVocoder()

    audio_path = "/root/sag_train/orig2.wav"
    audio, sr = model.load_audio(audio_path)
    audio = model.resample_audio(audio, sr)

    model.eval()
    model = model.to("cuda:0")
    audio = audio.to("cuda:0")
    with torch.no_grad():
        audios_len = audio.shape[-1]
        # Align length to hop_length (512) * DCAE downsample factor (32) so
        # the latent grid divides evenly.
        min_frame = 512 * 32
        if audios_len % min_frame != 0:
            # Bug fix: the original referenced the not-yet-defined `audios`
            # (NameError), built a 3-D zeros tensor for a 2-D (channels, T)
            # waveform, and then ran the model on the *unpadded* audio anyway.
            pad_len = min_frame - audios_len % min_frame
            padding = torch.zeros(audio.shape[0], pad_len, device=audio.device)
            audio = torch.cat([audio, padding], dim=-1)
        recon_audios, mels, recon_mels, latent = model(audio.unsqueeze(0))
        # Trim back to the original (pre-padding) length.
        recon_audios = recon_audios[:, :, :audios_len]

    print("latent shape: ", latent.shape)
    print("recon_audios", recon_audios.shape)
    print("mels", mels.shape, "min:", mels.min(), "max:", mels.max(), "mean:", mels.mean(), "std:", mels.std())
    print("recon_mels", recon_mels.shape, "min:", recon_mels.min(), "max:", recon_mels.max(), "mean:", recon_mels.mean(), "std:", recon_mels.std())
    total_params = sum(p.numel() for p in model.parameters())
    print(f"模型参数总数: {total_params / 1e6:.2f}M")

    torchaudio.save("/root/sag_train/recon2.wav", recon_audios[0].cpu(), 44100)
|
||||
119
music_dcae/music_log_mel.py
Executable file
119
music_dcae/music_log_mel.py
Executable file
@@ -0,0 +1,119 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
from torchaudio.transforms import MelScale
|
||||
|
||||
|
||||
class LinearSpectrogram(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
n_fft=2048,
|
||||
win_length=2048,
|
||||
hop_length=512,
|
||||
center=False,
|
||||
mode="pow2_sqrt",
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.n_fft = n_fft
|
||||
self.win_length = win_length
|
||||
self.hop_length = hop_length
|
||||
self.center = center
|
||||
self.mode = mode
|
||||
|
||||
self.register_buffer("window", torch.hann_window(win_length))
|
||||
|
||||
def forward(self, y: Tensor) -> Tensor:
|
||||
if y.ndim == 3:
|
||||
y = y.squeeze(1)
|
||||
|
||||
y = torch.nn.functional.pad(
|
||||
y.unsqueeze(1),
|
||||
(
|
||||
(self.win_length - self.hop_length) // 2,
|
||||
(self.win_length - self.hop_length + 1) // 2,
|
||||
),
|
||||
mode="reflect",
|
||||
).squeeze(1)
|
||||
dtype = y.dtype
|
||||
spec = torch.stft(
|
||||
y.float(),
|
||||
self.n_fft,
|
||||
hop_length=self.hop_length,
|
||||
win_length=self.win_length,
|
||||
window=self.window,
|
||||
center=self.center,
|
||||
pad_mode="reflect",
|
||||
normalized=False,
|
||||
onesided=True,
|
||||
return_complex=True,
|
||||
)
|
||||
spec = torch.view_as_real(spec)
|
||||
|
||||
if self.mode == "pow2_sqrt":
|
||||
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
|
||||
spec = spec.to(dtype)
|
||||
return spec
|
||||
|
||||
|
||||
class LogMelSpectrogram(nn.Module):
    """Log-compressed mel spectrogram: linear magnitude STFT -> Slaney-style
    mel filterbank -> ``log(clamp(x, 1e-5))``.
    """

    def __init__(
        self,
        sample_rate=44100,
        n_fft=2048,
        win_length=2048,
        hop_length=512,
        n_mels=128,
        center=False,
        f_min=0.0,
        f_max=None,
    ):
        super().__init__()

        self.sample_rate = sample_rate
        self.n_fft = n_fft
        self.win_length = win_length
        self.hop_length = hop_length
        self.center = center
        self.n_mels = n_mels
        self.f_min = f_min
        # Default the upper band edge to Nyquist.
        self.f_max = f_max or sample_rate // 2

        self.spectrogram = LinearSpectrogram(n_fft, win_length, hop_length, center)
        # Slaney norm + Slaney mel scale over n_fft//2 + 1 STFT bins.
        self.mel_scale = MelScale(
            self.n_mels,
            self.sample_rate,
            self.f_min,
            self.f_max,
            self.n_fft // 2 + 1,
            "slaney",
            "slaney",
        )

    def compress(self, x: Tensor) -> Tensor:
        """Log-compress, flooring at 1e-5 to avoid log(0)."""
        return torch.log(torch.clamp(x, min=1e-5))

    def decompress(self, x: Tensor) -> Tensor:
        """Inverse of :meth:`compress` (up to the clamp floor)."""
        return torch.exp(x)

    def forward(self, x: Tensor, return_linear: bool = False) -> Tensor:
        linear = self.spectrogram(x)
        mel = self.compress(self.mel_scale(linear))
        if return_linear:
            # Also hand back the log-compressed linear spectrogram.
            return mel, self.compress(linear)
        return mel
|
||||
|
||||
|
||||
def get_mel_transform():
    """Build the project's canonical 44.1 kHz, 128-bin log-mel front end
    (40 Hz – 16 kHz band, 2048-point FFT, hop 512)."""
    mel_settings = {
        "sample_rate": 44100,
        "n_fft": 2048,
        "win_length": 2048,
        "hop_length": 512,
        "f_min": 40,
        "f_max": 16000,
        "n_mels": 128,
    }
    return LogMelSpectrogram(**mel_settings)
|
||||
565
music_dcae/music_vocoder.py
Executable file
565
music_dcae/music_vocoder.py
Executable file
@@ -0,0 +1,565 @@
|
||||
import librosa
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from functools import partial
|
||||
from math import prod
|
||||
from typing import Callable, Tuple, List
|
||||
|
||||
import numpy as np
|
||||
import torch.nn.functional as F
|
||||
from torch.nn import Conv1d
|
||||
from torch.nn.utils import weight_norm
|
||||
from torch.nn.utils.parametrize import remove_parametrizations as remove_weight_norm
|
||||
|
||||
try:
|
||||
from music_log_mel import LogMelSpectrogram
|
||||
except ImportError:
|
||||
from .music_log_mel import LogMelSpectrogram
|
||||
|
||||
|
||||
def drop_path(
    x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Zeroes entire samples of a residual branch with probability `drop_prob`
    during training; surviving samples are rescaled by 1/keep_prob when
    `scale_by_keep` is set so the expectation is unchanged. Identity at
    inference time or when `drop_prob` is 0. See
    https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 for
    why this is named "drop path" rather than DropConnect.
    """  # noqa: E501
    if not training or drop_prob == 0.0:
        return x
    keep_prob = 1 - drop_prob
    # One Bernoulli draw per sample, broadcast across every remaining dim —
    # works for tensors of any rank, not just 2D ConvNet activations.
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
    if scale_by_keep and keep_prob > 0.0:
        mask.div_(keep_prob)
    return x * mask
|
||||
|
||||
|
||||
class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""  # noqa: E501

    def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
        super().__init__()
        # Stored as plain attributes; the actual work happens in drop_path().
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        # Delegate to the functional form, keyed on the module's train mode.
        return drop_path(
            x,
            drop_prob=self.drop_prob,
            training=self.training,
            scale_by_keep=self.scale_by_keep,
        )

    def extra_repr(self):
        return f"drop_prob={round(self.drop_prob,3):0.3f}"
|
||||
|
||||
|
||||
class LayerNorm(nn.Module):
    r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
    The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
    shape (batch_size, height, width, channels) while channels_first corresponds to inputs
    with shape (batch_size, channels, height, width).
    """  # noqa: E501

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if data_format not in ("channels_last", "channels_first"):
            raise NotImplementedError
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        if self.data_format == "channels_last":
            # Trailing-dim case maps directly onto the fused kernel.
            return F.layer_norm(
                x, self.normalized_shape, self.weight, self.bias, self.eps
            )
        # channels_first: normalize over dim 1 by hand (biased variance, as
        # F.layer_norm uses), then apply the affine with broadcasting over
        # the trailing spatial dim.
        mean = x.mean(1, keepdim=True)
        var = (x - mean).pow(2).mean(1, keepdim=True)
        normed = (x - mean) / torch.sqrt(var + self.eps)
        return self.weight[:, None] * normed + self.bias[:, None]
|
||||
|
||||
|
||||
class ConvNeXtBlock(nn.Module):
    r"""ConvNeXt Block. There are two equivalent implementations:
    (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
    (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
    We use (2) as we find it slightly faster in PyTorch

    Args:
        dim (int): Number of input channels.
        drop_path (float): Stochastic depth rate. Default: 0.0
        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
        kernel_size (int): Kernel size for depthwise conv. Default: 7.
        dilation (int): Dilation for depthwise conv. Default: 1.
    """  # noqa: E501

    def __init__(
        self,
        dim: int,
        drop_path: float = 0.0,
        layer_scale_init_value: float = 1e-6,
        mlp_ratio: float = 4.0,
        kernel_size: int = 7,
        dilation: int = 1,
    ):
        super().__init__()

        hidden_dim = int(mlp_ratio * dim)
        # Depthwise conv over the sequence axis, "same"-style padding.
        self.dwconv = nn.Conv1d(
            dim,
            dim,
            kernel_size=kernel_size,
            padding=int(dilation * (kernel_size - 1) / 2),
            groups=dim,
        )
        self.norm = LayerNorm(dim, eps=1e-6)
        # Pointwise (1x1) convs expressed as Linear layers on (N, L, C).
        self.pwconv1 = nn.Linear(dim, hidden_dim)
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(hidden_dim, dim)
        # Per-channel layer-scale; disabled when init value is non-positive.
        if layer_scale_init_value > 0:
            self.gamma = nn.Parameter(
                layer_scale_init_value * torch.ones(dim), requires_grad=True
            )
        else:
            self.gamma = None
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

    def forward(self, x, apply_residual: bool = True):
        residual = x

        x = self.dwconv(x)
        x = x.permute(0, 2, 1)  # (N, C, L) -> (N, L, C)
        x = self.pwconv2(self.act(self.pwconv1(self.norm(x))))

        if self.gamma is not None:
            x = self.gamma * x

        x = x.permute(0, 2, 1)  # (N, L, C) -> (N, C, L)
        x = self.drop_path(x)

        # Skippable so ParallelConvNeXtBlock can sum branches itself.
        if apply_residual:
            x = residual + x

        return x
|
||||
|
||||
|
||||
class ParallelConvNeXtBlock(nn.Module):
    """Runs several ConvNeXt blocks (one per kernel size) in parallel on the
    same input and sums their outputs together with an identity branch."""

    def __init__(self, kernel_sizes: List[int], *args, **kwargs):
        super().__init__()
        self.blocks = nn.ModuleList(
            ConvNeXtBlock(kernel_size=k, *args, **kwargs) for k in kernel_sizes
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Branches skip their own residual; the identity branch is appended
        # last and everything is reduced with a single stacked sum (kept in
        # this exact order to match the original floating-point summation).
        branches = [block(x, apply_residual=False) for block in self.blocks]
        branches.append(x)
        return torch.stack(branches, dim=1).sum(dim=1)
|
||||
|
||||
|
||||
class ConvNeXtEncoder(nn.Module):
    """1-D ConvNeXt backbone: a stem + per-stage 1x1 channel projections,
    each followed by a stack of ConvNeXt blocks, with stochastic depth rates
    increasing linearly across all blocks.
    """

    def __init__(
        self,
        input_channels=3,
        depths=[3, 3, 9, 3],
        dims=[96, 192, 384, 768],
        drop_path_rate=0.0,
        layer_scale_init_value=1e-6,
        kernel_sizes: Tuple[int] = (7,),
    ):
        super().__init__()
        assert len(depths) == len(dims)

        # channel_layers[i] adapts channels before stage i: a conv stem for
        # stage 0, then LayerNorm + 1x1 conv transitions (no downsampling —
        # sequence length is preserved throughout).
        self.channel_layers = nn.ModuleList()
        stem = nn.Sequential(
            nn.Conv1d(
                input_channels,
                dims[0],
                kernel_size=7,
                padding=3,
                padding_mode="replicate",
            ),
            LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
        )
        self.channel_layers.append(stem)

        for i in range(len(depths) - 1):
            mid_layer = nn.Sequential(
                LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
                nn.Conv1d(dims[i], dims[i + 1], kernel_size=1),
            )
            self.channel_layers.append(mid_layer)

        # Single kernel size -> plain ConvNeXtBlock; several -> parallel
        # multi-kernel blocks.
        block_fn = (
            partial(ConvNeXtBlock, kernel_size=kernel_sizes[0])
            if len(kernel_sizes) == 1
            else partial(ParallelConvNeXtBlock, kernel_sizes=kernel_sizes)
        )

        self.stages = nn.ModuleList()
        # Linearly ramp the drop-path rate from 0 to drop_path_rate over the
        # total number of blocks.
        drop_path_rates = [
            x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
        ]

        cur = 0
        for i in range(len(depths)):
            stage = nn.Sequential(
                *[
                    block_fn(
                        dim=dims[i],
                        drop_path=drop_path_rates[cur + j],
                        layer_scale_init_value=layer_scale_init_value,
                    )
                    for j in range(depths[i])
                ]
            )
            self.stages.append(stage)
            cur += depths[i]

        # Final norm over the last stage's channels.
        self.norm = LayerNorm(dims[-1], eps=1e-6, data_format="channels_first")
        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Truncated-normal weights, zero bias for all conv/linear layers.
        # NOTE(review): assumes every Conv1d/Linear has a bias — a bias-free
        # layer would make `constant_(m.bias, 0)` fail; confirm none are used.
        if isinstance(m, (nn.Conv1d, nn.Linear)):
            nn.init.trunc_normal_(m.weight, std=0.02)
            nn.init.constant_(m.bias, 0)

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """(N, input_channels, L) -> (N, dims[-1], L)."""
        for channel_layer, stage in zip(self.channel_layers, self.stages):
            x = channel_layer(x)
            x = stage(x)

        return self.norm(x)
|
||||
|
||||
|
||||
def init_weights(m, mean=0.0, std=0.01):
    """Normal-initialize the weight of any module whose class name contains
    "Conv" (HiFi-GAN convention); other modules are left untouched."""
    if "Conv" in m.__class__.__name__:
        m.weight.data.normal_(mean, std)
|
||||
|
||||
def get_padding(kernel_size, dilation=1):
    """Padding that keeps the output length equal to the input length for an
    odd `kernel_size` with the given `dilation` ("same" padding)."""
    return dilation * (kernel_size - 1) // 2
||||
|
||||
class ResBlock1(torch.nn.Module):
    """HiFi-GAN ResBlock1: pairs of weight-normalized dilated convs with SiLU
    activations and residual connections.

    Args:
        channels: channel count (preserved through the block).
        kernel_size: conv kernel size (odd, so "same" padding holds).
        dilation: dilation of the *first* conv in each pair; the second conv
            of a pair always uses dilation 1. One pair is built per entry
            (generalized from the original, which hard-coded exactly three).
    """

    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
        super().__init__()

        # First conv of each pair: increasing dilation widens the receptive
        # field while "same" padding keeps the sequence length fixed.
        self.convs1 = nn.ModuleList(
            [
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=d,
                        padding=get_padding(kernel_size, d),
                    )
                )
                for d in dilation
            ]
        )
        self.convs1.apply(init_weights)

        # Second conv of each pair: plain (dilation-1) conv.
        self.convs2 = nn.ModuleList(
            [
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=1,
                        padding=get_padding(kernel_size, 1),
                    )
                )
                for _ in dilation
            ]
        )
        self.convs2.apply(init_weights)

    def forward(self, x):
        """(N, channels, L) -> (N, channels, L)."""
        for c1, c2 in zip(self.convs1, self.convs2):
            xt = F.silu(x)
            xt = c1(xt)
            xt = F.silu(xt)
            xt = c2(xt)
            x = xt + x
        return x

    def remove_weight_norm(self):
        """Strip weight norm from all convs (for inference/export).

        Bug fix: the module-level name `remove_weight_norm` is an alias of
        `torch.nn.utils.parametrize.remove_parametrizations`, which requires
        a tensor name and does not undo the hook-based `weight_norm` used in
        `__init__` — calling it with a bare module raised TypeError. Use the
        matching hook-based remover instead.
        """
        from torch.nn.utils import remove_weight_norm as _strip_weight_norm

        for conv in self.convs1:
            _strip_weight_norm(conv)
        for conv in self.convs2:
            _strip_weight_norm(conv)
||||
|
||||
class HiFiGANGenerator(nn.Module):
    """HiFi-GAN style vocoder head: upsamples a feature sequence to waveform.

    A pre-conv projects `num_mels` channels to `upsample_initial_channel`;
    each transposed-conv stage halves the channel count while upsampling in
    time, followed by a bank of multi-kernel ResBlocks whose outputs are
    averaged; a final conv + tanh produces the 1-channel waveform.
    """

    def __init__(
        self,
        *,
        hop_length: int = 512,
        upsample_rates: Tuple[int] = (8, 8, 2, 2, 2),
        upsample_kernel_sizes: Tuple[int] = (16, 16, 8, 2, 2),
        resblock_kernel_sizes: Tuple[int] = (3, 7, 11),
        resblock_dilation_sizes: Tuple[Tuple[int]] = (
            (1, 3, 5), (1, 3, 5), (1, 3, 5)),
        num_mels: int = 128,
        upsample_initial_channel: int = 512,
        use_template: bool = True,
        pre_conv_kernel_size: int = 7,
        post_conv_kernel_size: int = 7,
        post_activation: Callable = partial(nn.SiLU, inplace=True),
    ):
        super().__init__()

        # Total temporal upsampling must reproduce exactly one STFT hop.
        assert (
            prod(upsample_rates) == hop_length
        ), f"hop_length must be {prod(upsample_rates)}"

        self.conv_pre = weight_norm(
            nn.Conv1d(
                num_mels,
                upsample_initial_channel,
                pre_conv_kernel_size,
                1,
                padding=get_padding(pre_conv_kernel_size),
            )
        )

        self.num_upsamples = len(upsample_rates)
        self.num_kernels = len(resblock_kernel_sizes)

        self.noise_convs = nn.ModuleList()
        self.use_template = use_template
        self.ups = nn.ModuleList()

        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            # Channel count after this stage (halved at every stage).
            c_cur = upsample_initial_channel // (2 ** (i + 1))
            self.ups.append(
                weight_norm(
                    nn.ConvTranspose1d(
                        upsample_initial_channel // (2**i),
                        upsample_initial_channel // (2 ** (i + 1)),
                        k,
                        u,
                        padding=(k - u) // 2,
                    )
                )
            )

            # Template-injection convs are only built when use_template is on.
            if not use_template:
                continue

            if i + 1 < len(upsample_rates):
                # Bring the template down to this stage's temporal resolution
                # (product of the remaining upsample rates).
                stride_f0 = np.prod(upsample_rates[i + 1:])
                self.noise_convs.append(
                    Conv1d(
                        1,
                        c_cur,
                        kernel_size=stride_f0 * 2,
                        stride=stride_f0,
                        padding=stride_f0 // 2,
                    )
                )
            else:
                self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))

        # num_kernels parallel ResBlocks per upsample stage, stored flat.
        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = upsample_initial_channel // (2 ** (i + 1))
            for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes):
                self.resblocks.append(ResBlock1(ch, k, d))

        self.activation_post = post_activation()
        # `ch` still holds the channel count after the last upsample stage.
        self.conv_post = weight_norm(
            nn.Conv1d(
                ch,
                1,
                post_conv_kernel_size,
                1,
                padding=get_padding(post_conv_kernel_size),
            )
        )
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)

    def forward(self, x, template=None):
        """Map features (B, num_mels, T) to waveform (B, 1, T * hop_length).

        `template` (presumably an excitation/F0 signal of shape (B, 1, T'))
        is required when use_template is True — TODO confirm with callers.
        """
        x = self.conv_pre(x)

        for i in range(self.num_upsamples):
            x = F.silu(x, inplace=True)
            x = self.ups[i](x)

            if self.use_template:
                x = x + self.noise_convs[i](template)

            # Average the outputs of this stage's parallel ResBlocks.
            xs = None

            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)

            x = xs / self.num_kernels

        x = self.activation_post(x)
        x = self.conv_post(x)
        # tanh bounds the waveform to [-1, 1].
        x = torch.tanh(x)

        return x

    def remove_weight_norm(self):
        """Strip weight-norm from all convs (inference/export time)."""
        for up in self.ups:
            remove_weight_norm(up)
        for block in self.resblocks:
            block.remove_weight_norm()
        remove_weight_norm(self.conv_pre)
        remove_weight_norm(self.conv_post)
|
||||
|
||||
|
||||
class ADaMoSHiFiGANV1(nn.Module):
    """ADaMoS vocoder: ConvNeXt backbone feeding a HiFi-GAN head (44.1 kHz).

    Pretrained generator weights are loaded at construction and the module
    is switched to eval mode. `encode` maps waveform -> log-mel features,
    `decode` maps log-mel -> waveform.
    """

    def __init__(
        self,
        checkpoint_path: str = "checkpoints/adamos-generator-1640000.pth",
    ):
        super().__init__()

        # 128-channel mel input -> 512-channel feature encoder.
        self.backbone = ConvNeXtEncoder(
            input_channels=128,
            depths=[3, 3, 9, 3],
            dims=[128, 256, 384, 512],
            drop_path_rate=0,
            kernel_sizes=(7,),
        )

        # 512-channel features -> waveform; 4*4*2*2*2*2*2 = 512 = hop_length.
        self.head = HiFiGANGenerator(
            hop_length=512,
            upsample_rates=(4, 4, 2, 2, 2, 2, 2),
            upsample_kernel_sizes=(8, 8, 4, 4, 4, 4, 4),
            resblock_kernel_sizes=(3, 7, 11, 13),
            resblock_dilation_sizes=(
                (1, 3, 5), (1, 3, 5), (1, 3, 5), (1, 3, 5)),
            num_mels=512,
            upsample_initial_channel=1024,
            use_template=False,
            pre_conv_kernel_size=13,
            post_conv_kernel_size=13,
        )
        self.sampling_rate = 44100

        # NOTE(review): torch.load without weights_only unpickles arbitrary
        # objects — only load trusted checkpoints.
        ckpt_state = torch.load(checkpoint_path, map_location="cpu")

        # Lightning-style checkpoints nest the weights under "state_dict".
        if "state_dict" in ckpt_state:
            ckpt_state = ckpt_state["state_dict"]

        # Keep only generator weights and strip their prefix.
        # NOTE(review): str.replace removes "generator." anywhere in the key,
        # not only as a prefix — fine as long as no key embeds it mid-name.
        if any(k.startswith("generator.") for k in ckpt_state):
            ckpt_state = {
                k.replace("generator.", ""): v
                for k, v in ckpt_state.items()
                if k.startswith("generator.")
            }

        self.load_state_dict(ckpt_state)
        self.eval()

        # Waveform -> log-mel front-end; parameters presumably match the
        # training configuration of the checkpoint — TODO confirm.
        self.mel_transform = LogMelSpectrogram(
            sample_rate=44100,
            n_fft=2048,
            win_length=2048,
            hop_length=512,
            f_min=40,
            f_max=16000,
            n_mels=128,
        )

    @torch.no_grad()
    def decode(self, mel):
        """Render a waveform from a log-mel spectrogram (no autograd)."""
        y = self.backbone(mel)
        y = self.head(y)
        return y

    @torch.no_grad()
    def encode(self, x):
        """Compute the log-mel spectrogram of waveform `x` (no autograd)."""
        return self.mel_transform(x)

    def forward(self, mel):
        # Same pipeline as decode(), but differentiable (no no_grad wrapper).
        y = self.backbone(mel)
        y = self.head(y)
        return y
|
||||
|
||||
|
||||
if __name__ == "__main__":
    import soundfile as sf

    # Smoke test: round-trip a wav file through mel encode -> vocoder decode.
    x = "./test.wav"
    model = ADaMoSHiFiGANV1(checkpoint_path='./step_001640000.pth')

    # librosa is presumably imported at the top of this file — TODO confirm.
    wav, sr = librosa.load(x, sr=44100, mono=True)
    wav = torch.from_numpy(wav).float()[None]
    mel = model.encode(wav)

    # .mT swaps the last two dims -> (samples, channels) for soundfile.
    wav = model.decode(mel)[0].mT
    sf.write("test_out.wav", wav.cpu().numpy(), 44100)
|
||||
39
optimizers/cosine_wsd.py
Normal file
39
optimizers/cosine_wsd.py
Normal file
@@ -0,0 +1,39 @@
|
||||
from torch.optim.lr_scheduler import _LRScheduler
|
||||
import torch
|
||||
|
||||
|
||||
class CosineWSD(_LRScheduler):
    """Warmup-Stable-Decay LR schedule.

    Linear warmup for `warmup_iters` steps, constant hold until `step_size`,
    then exponential halving every `decay_interval` steps for `decay_length`
    steps, finally flat at `eta_min`. (Despite the name, the decay phase is
    exponential rather than cosine.)
    """

    def __init__(self, optimizer, warmup_iters, step_size, decay_length, decay_interval, eta_min=0, last_epoch=-1):
        self.warmup_iters = warmup_iters      # steps of linear warmup
        self.step_size = step_size            # step at which decay begins
        self.decay_length = decay_length      # duration of the decay phase
        self.decay_interval = decay_interval  # halving period during decay
        self.eta_min = eta_min                # floor LR once decay has ended
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        epoch = self.last_epoch
        if epoch < self.warmup_iters:
            # Linear ramp from 0 up to base_lr.
            return [base_lr * epoch / self.warmup_iters for base_lr in self.base_lrs]
        if epoch < self.step_size:
            # Stable phase: hold at base_lr.
            return list(self.base_lrs)
        if epoch <= self.step_size + self.decay_length:
            # Exponential decay: halve every decay_interval steps.
            factor = 0.5 ** ((epoch - self.step_size) / self.decay_interval)
            return [base_lr * factor for base_lr in self.base_lrs]
        # Past the decay window: flat minimum.
        return [self.eta_min] * len(self.base_lrs)
|
||||
|
||||
|
||||
def configure_lr_scheduler(optimizer, total_steps_per_epoch, epochs=10, decay_ratio=0.9, decay_interval=1000, warmup_iters=4000):
    """Build a CosineWSD scheduler config (Lightning-style dict list).

    Decay begins after `decay_ratio` of all training steps and runs for the
    remaining steps; the scheduler advances once per optimizer step.
    """
    total_steps = total_steps_per_epoch * epochs
    decay_start = total_steps * decay_ratio
    scheduler = CosineWSD(
        optimizer,
        warmup_iters=warmup_iters,
        step_size=decay_start,
        decay_length=total_steps - decay_start,
        decay_interval=decay_interval,
    )
    return [{"scheduler": scheduler, "name": "CosineWSD", "interval": "step"}]
|
||||
19
requirements.txt
Normal file
19
requirements.txt
Normal file
@@ -0,0 +1,19 @@
|
||||
datasets==3.4.1
|
||||
diffusers==0.32.2
|
||||
gradio==5.23.3
|
||||
librosa==0.11.0
|
||||
loguru==0.7.3
|
||||
matplotlib==3.10.1
|
||||
numpy
|
||||
pypinyin==0.53.0
|
||||
pytorch_lightning==2.5.1
|
||||
soundfile==0.13.1
|
||||
torch
|
||||
torchaudio
|
||||
torchvision
|
||||
tqdm==4.67.1
|
||||
transformers==4.50.0
|
||||
py3langid==0.3.0
|
||||
hangul-romanize==0.1.0
|
||||
num2words==0.5.14
|
||||
spacy==3.8.4
|
||||
394
schedulers/scheduling_flow_match_euler_discrete.py
Normal file
394
schedulers/scheduling_flow_match_euler_discrete.py
Normal file
@@ -0,0 +1,394 @@
|
||||
# Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
||||
from diffusers.utils import BaseOutput, logging
|
||||
from diffusers.schedulers.scheduling_utils import SchedulerMixin
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
@dataclass
class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput):
    """
    Output class for the scheduler's `step` function output.

    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
    """

    # The sample advanced one Euler step toward sigma = 0.
    prev_sample: torch.FloatTensor
|
||||
|
||||
|
||||
class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
Euler scheduler.
|
||||
|
||||
This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
|
||||
methods the library implements for all schedulers such as loading and saving.
|
||||
|
||||
Args:
|
||||
num_train_timesteps (`int`, defaults to 1000):
|
||||
The number of diffusion steps to train the model.
|
||||
timestep_spacing (`str`, defaults to `"linspace"`):
|
||||
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
|
||||
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
|
||||
shift (`float`, defaults to 1.0):
|
||||
The shift value for the timestep schedule.
|
||||
"""
|
||||
|
||||
_compatibles = []
|
||||
order = 1
|
||||
|
||||
    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        shift: float = 1.0,
        use_dynamic_shifting=False,
        # base_shift/max_shift and the seq-len bounds are stored on
        # self.config by @register_to_config; they are not read here.
        base_shift: Optional[float] = 0.5,
        max_shift: Optional[float] = 1.15,
        base_image_seq_len: Optional[int] = 256,
        max_image_seq_len: Optional[int] = 4096,
    ):
        # Training timesteps counted down from num_train_timesteps to 1.
        timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
        timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)

        # Flow-matching sigma is the normalized time t in (0, 1].
        sigmas = timesteps / num_train_timesteps
        if not use_dynamic_shifting:
            # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution
            sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)

        self.timesteps = sigmas * num_train_timesteps

        self._step_index = None
        self._begin_index = None

        self.sigmas = sigmas.to("cpu")  # to avoid too much CPU/GPU communication
        self.sigma_min = self.sigmas[-1].item()
        self.sigma_max = self.sigmas[0].item()
|
||||
|
||||
    # -- step-index bookkeeping accessors --

    @property
    def step_index(self):
        """
        The index counter for current timestep. It will increase 1 after each scheduler step.
        """
        return self._step_index

    @property
    def begin_index(self):
        """
        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
        """
        return self._begin_index

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
    def set_begin_index(self, begin_index: int = 0):
        """
        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.

        Args:
            begin_index (`int`):
                The begin index for the scheduler.
        """
        self._begin_index = begin_index
|
||||
|
||||
    def scale_noise(
        self,
        sample: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        noise: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        """
        Forward process in flow-matching

        Args:
            sample (`torch.FloatTensor`):
                The input sample.
            timestep (`int`, *optional*):
                The current timestep in the diffusion chain.
            noise (`torch.FloatTensor`, *optional*):
                Gaussian noise to mix in; must be provided (no default draw).

        Returns:
            `torch.FloatTensor`:
                A scaled input sample.
        """
        # Make sure sigmas and timesteps have the same device and dtype as original_samples
        sigmas = self.sigmas.to(device=sample.device, dtype=sample.dtype)

        if sample.device.type == "mps" and torch.is_floating_point(timestep):
            # mps does not support float64
            schedule_timesteps = self.timesteps.to(sample.device, dtype=torch.float32)
            timestep = timestep.to(sample.device, dtype=torch.float32)
        else:
            schedule_timesteps = self.timesteps.to(sample.device)
            timestep = timestep.to(sample.device)

        # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
        if self.begin_index is None:
            step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timestep]
        elif self.step_index is not None:
            # add_noise is called after first denoising step (for inpainting)
            step_indices = [self.step_index] * timestep.shape[0]
        else:
            # add noise is called before first denoising step to create initial latent(img2img)
            step_indices = [self.begin_index] * timestep.shape[0]

        sigma = sigmas[step_indices].flatten()
        # Broadcast sigma over all trailing sample dimensions.
        while len(sigma.shape) < len(sample.shape):
            sigma = sigma.unsqueeze(-1)

        # Linear interpolation between data and noise: x_t = sigma*eps + (1 - sigma)*x0.
        sample = sigma * noise + (1.0 - sigma) * sample

        return sample
|
||||
|
||||
    def _sigma_to_t(self, sigma):
        # Map sigma in [0, 1] back to a (possibly fractional) training timestep.
        return sigma * self.config.num_train_timesteps

    def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
        # Dynamic (resolution-dependent) shift: exp(mu)-weighted logistic warp of t.
        return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
|
||||
|
||||
def set_timesteps(
|
||||
self,
|
||||
num_inference_steps: int = None,
|
||||
device: Union[str, torch.device] = None,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
mu: Optional[float] = None,
|
||||
):
|
||||
"""
|
||||
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
||||
|
||||
Args:
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
"""
|
||||
|
||||
if self.config.use_dynamic_shifting and mu is None:
|
||||
raise ValueError(" you have a pass a value for `mu` when `use_dynamic_shifting` is set to be `True`")
|
||||
|
||||
if sigmas is None:
|
||||
self.num_inference_steps = num_inference_steps
|
||||
timesteps = np.linspace(
|
||||
self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps
|
||||
)
|
||||
|
||||
sigmas = timesteps / self.config.num_train_timesteps
|
||||
|
||||
if self.config.use_dynamic_shifting:
|
||||
sigmas = self.time_shift(mu, 1.0, sigmas)
|
||||
else:
|
||||
sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)
|
||||
|
||||
sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
|
||||
timesteps = sigmas * self.config.num_train_timesteps
|
||||
|
||||
self.timesteps = timesteps.to(device=device)
|
||||
self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
|
||||
|
||||
self._step_index = None
|
||||
self._begin_index = None
|
||||
|
||||
    def index_for_timestep(self, timestep, schedule_timesteps=None):
        """Return the schedule index whose timestep value equals `timestep`."""
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

        indices = (schedule_timesteps == timestep).nonzero()

        # The sigma index that is taken for the **very** first `step`
        # is always the second index (or the last index if there is only 1)
        # This way we can ensure we don't accidentally skip a sigma in
        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
        pos = 1 if len(indices) > 1 else 0

        return indices[pos].item()
|
||||
|
||||
    def _init_step_index(self, timestep):
        """Resolve the starting step index from `timestep`, or use `begin_index` if set."""
        if self.begin_index is None:
            if isinstance(timestep, torch.Tensor):
                timestep = timestep.to(self.timesteps.device)
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index
|
||||
|
||||
    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        sample: torch.FloatTensor,
        s_churn: float = 0.0,
        s_tmin: float = 0.0,
        s_tmax: float = float("inf"),
        s_noise: float = 1.0,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
        omega: Union[float, np.array] = 0.0
    ) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.FloatTensor`):
                The direct output from learned diffusion model.
            timestep (`float`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                A current instance of a sample created by the diffusion process.
            s_churn (`float`):
            s_tmin (`float`):
            s_tmax (`float`):
            s_noise (`float`, defaults to 1.0):
                Scaling factor for noise added to the sample.
            generator (`torch.Generator`, *optional*):
                A random number generator.
            return_dict (`bool`):
                Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or
                tuple.
            omega (`float` or `np.array`, defaults to 0.0):
                Guidance-rescale knob; mapped through a logistic curve into
                (0.9, 1.1) and applied to the zero-mean part of the Euler
                increment ("mean shift") below.

        Returns:
            [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is
                returned, otherwise a tuple is returned where the first element is the sample tensor.
        """

        def logistic_function(x, L=0.9, U=1.1, x_0=0.0, k=1):
            # L = Lower bound
            # U = Upper bound
            # x_0 = Midpoint (x corresponding to y = 1.0)
            # k = Steepness, can adjust based on preference

            if isinstance(x, torch.Tensor):
                device_ = x.device
                x = x.to(torch.float).cpu().numpy()

            new_x = L + (U - L) / (1 + np.exp(-k * (x - x_0)))

            if isinstance(new_x, np.ndarray):
                new_x = torch.from_numpy(new_x).to(device_)
            return new_x

        # Keep both raw and rescaled omega around for inspection/debugging.
        self.omega_bef_rescale = omega
        omega = logistic_function(omega, k=0.1)
        self.omega_aft_rescale = omega

        if (
            isinstance(timestep, int)
            or isinstance(timestep, torch.IntTensor)
            or isinstance(timestep, torch.LongTensor)
        ):
            raise ValueError(
                (
                    "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                    " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
                    " one of the `scheduler.timesteps` as a timestep."
                ),
            )

        if self.step_index is None:
            self._init_step_index(timestep)

        # Upcast to avoid precision issues when computing prev_sample
        sample = sample.to(torch.float32)

        sigma = self.sigmas[self.step_index]
        sigma_next = self.sigmas[self.step_index + 1]

        # Euler update with "mean shift": scale the zero-mean part of the
        # increment by omega while preserving its global mean m.
        # (Alternative variants — per-channel/per-spatial means, rescaling the
        # model output directly — were explored and removed here.)
        dx = (sigma_next - sigma) * model_output
        m = dx.mean()
        dx_ = (dx - m) * omega + m
        prev_sample = sample + dx_

        # Cast sample back to model compatible dtype
        prev_sample = prev_sample.to(model_output.dtype)

        # upon completion increase step index by one
        self._step_index += 1

        if not return_dict:
            return (prev_sample,)

        return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample)
|
||||
|
||||
    def __len__(self):
        # Scheduler "length" is the size of the training discretization.
        return self.config.num_train_timesteps
|
||||
348
schedulers/scheduling_flow_match_heun_discrete.py
Normal file
348
schedulers/scheduling_flow_match_heun_discrete.py
Normal file
@@ -0,0 +1,348 @@
|
||||
# Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
||||
from diffusers.utils import BaseOutput, logging
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
from diffusers.schedulers.scheduling_utils import SchedulerMixin
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
@dataclass
class FlowMatchHeunDiscreteSchedulerOutput(BaseOutput):
    """
    Output class for the scheduler's `step` function output.

    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
    """

    # The sample advanced one (half-)step of the Heun integrator.
    prev_sample: torch.FloatTensor
|
||||
|
||||
|
||||
class FlowMatchHeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
Heun scheduler.
|
||||
|
||||
This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
|
||||
methods the library implements for all schedulers such as loading and saving.
|
||||
|
||||
Args:
|
||||
num_train_timesteps (`int`, defaults to 1000):
|
||||
The number of diffusion steps to train the model.
|
||||
timestep_spacing (`str`, defaults to `"linspace"`):
|
||||
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
|
||||
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
|
||||
shift (`float`, defaults to 1.0):
|
||||
The shift value for the timestep schedule.
|
||||
"""
|
||||
|
||||
_compatibles = []
|
||||
order = 2
|
||||
|
||||
    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        shift: float = 1.0,
    ):
        # Training timesteps counted down from num_train_timesteps to 1.
        timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
        timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)

        # Flow-matching sigma = normalized time, warped by the static shift.
        sigmas = timesteps / num_train_timesteps
        sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)

        self.timesteps = sigmas * num_train_timesteps

        self._step_index = None
        self._begin_index = None

        self.sigmas = sigmas.to("cpu")  # to avoid too much CPU/GPU communication
        self.sigma_min = self.sigmas[-1].item()
        self.sigma_max = self.sigmas[0].item()

    @property
    def step_index(self):
        """
        The index counter for current timestep. It will increase 1 after each scheduler step.
        """
        return self._step_index

    @property
    def begin_index(self):
        """
        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
        """
        return self._begin_index

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
    def set_begin_index(self, begin_index: int = 0):
        """
        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.

        Args:
            begin_index (`int`):
                The begin index for the scheduler.
        """
        self._begin_index = begin_index
|
||||
|
||||
    def scale_noise(
        self,
        sample: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        noise: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        """
        Forward process in flow-matching

        Args:
            sample (`torch.FloatTensor`):
                The input sample.
            timestep (`int`, *optional*):
                The current timestep in the diffusion chain.
            noise (`torch.FloatTensor`, *optional*):
                Gaussian noise to mix in; must be provided (no default draw).

        Returns:
            `torch.FloatTensor`:
                A scaled input sample.
        """
        if self.step_index is None:
            self._init_step_index(timestep)

        # Linear interpolation between data and noise at the current sigma.
        sigma = self.sigmas[self.step_index]
        sample = sigma * noise + (1.0 - sigma) * sample

        return sample

    def _sigma_to_t(self, sigma):
        # Map sigma in [0, 1] back to a (possibly fractional) training timestep.
        return sigma * self.config.num_train_timesteps
|
||||
|
||||
    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        """
        self.num_inference_steps = num_inference_steps

        # Evenly spaced timesteps from sigma_max down to sigma_min.
        timesteps = np.linspace(
            self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps
        )

        sigmas = timesteps / self.config.num_train_timesteps
        sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)
        sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)

        timesteps = sigmas * self.config.num_train_timesteps
        # Heun is a two-stage (predictor/corrector) method: every interior
        # timestep/sigma is visited twice, hence the repeat_interleave(2).
        timesteps = torch.cat([timesteps[:1], timesteps[1:].repeat_interleave(2)])
        self.timesteps = timesteps.to(device=device)

        # Terminal sigma of 0 appended so the final step can index sigma_{i+1}.
        sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
        self.sigmas = torch.cat([sigmas[:1], sigmas[1:-1].repeat_interleave(2), sigmas[-1:]])

        # empty dt and derivative
        self.prev_derivative = None
        self.dt = None

        self._step_index = None
        self._begin_index = None
|
||||
|
||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
    """Return the sigma/schedule index corresponding to ``timestep``.

    Falls back to ``self.timesteps`` when no explicit schedule is given.
    """
    schedule = self.timesteps if schedule_timesteps is None else schedule_timesteps

    matches = (schedule == timestep).nonzero()

    # The sigma index taken for the **very** first `step` is always the
    # second match (or the only one when there is a single match).  This
    # avoids accidentally skipping a sigma when denoising starts in the
    # middle of the schedule (e.g. image-to-image).
    if len(matches) > 1:
        return matches[1].item()
    return matches[0].item()
||||
def _init_step_index(self, timestep):
    """Initialise ``_step_index`` from ``timestep`` (or from ``begin_index``).

    When a begin index was set explicitly (e.g. by an image-to-image
    pipeline) it takes precedence; otherwise the index is looked up in the
    current timestep schedule.
    """
    if self.begin_index is not None:
        self._step_index = self._begin_index
        return

    if isinstance(timestep, torch.Tensor):
        # Make sure the comparison in index_for_timestep happens on-device.
        timestep = timestep.to(self.timesteps.device)
    self._step_index = self.index_for_timestep(timestep)
||||
@property
def state_in_first_order(self):
    # Heun alternates between a 1st-order (predictor) and a 2nd-order
    # (corrector) sub-step; `dt` is only populated while a 2nd-order
    # sub-step is pending.
    return self.dt is None
||||
def step(
    self,
    model_output: torch.FloatTensor,
    timestep: Union[float, torch.FloatTensor],
    sample: torch.FloatTensor,
    s_churn: float = 0.0,
    s_tmin: float = 0.0,
    s_tmax: float = float("inf"),
    s_noise: float = 1.0,
    generator: Optional[torch.Generator] = None,
    return_dict: bool = True,
    omega: Union[float, np.array] = 0.0
) -> Union[FlowMatchHeunDiscreteSchedulerOutput, Tuple]:
    """
    Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
    process from the learned model outputs (most often the predicted noise).

    Args:
        model_output (`torch.FloatTensor`):
            The direct output from learned diffusion model.
        timestep (`float`):
            The current discrete timestep in the diffusion chain.
        sample (`torch.FloatTensor`):
            A current instance of a sample created by the diffusion process.
        s_churn (`float`):
        s_tmin (`float`):
        s_tmax (`float`):
        s_noise (`float`, defaults to 1.0):
            Scaling factor for noise added to the sample.
        generator (`torch.Generator`, *optional*):
            A random number generator.
        return_dict (`bool`):
            Whether or not to return a [`~schedulers.scheduling_Heun_discrete.HeunDiscreteSchedulerOutput`] or
            tuple.
        omega (`float` or `np.array`, defaults to 0.0):
            Granularity control: rescales the deviation of the update around
            its mean.  Squashed through a logistic function so the effective
            multiplier stays inside [0.9, 1.1].

    Returns:
        [`~schedulers.scheduling_Heun_discrete.HeunDiscreteSchedulerOutput`] or `tuple`:
            If return_dict is `True`, [`~schedulers.scheduling_Heun_discrete.HeunDiscreteSchedulerOutput`] is
            returned, otherwise a tuple is returned where the first element is the sample tensor.
    """

    def logistic_function(x, L=0.9, U=1.1, x_0=0.0, k=1):
        # L = Lower bound
        # U = Upper bound
        # x_0 = Midpoint (x corresponding to y = 1.0)
        # k = Steepness, can adjust based on preference

        if isinstance(x, torch.Tensor):
            device_ = x.device
            x = x.to(torch.float).cpu().numpy()

        new_x = L + (U - L) / (1 + np.exp(-k * (x - x_0)))

        if isinstance(new_x, np.ndarray):
            new_x = torch.from_numpy(new_x).to(device_)
        return new_x

    # Keep both raw and squashed omega around for inspection/debugging.
    self.omega_bef_rescale = omega
    omega = logistic_function(omega, k=0.1)
    self.omega_aft_rescale = omega

    if (
        isinstance(timestep, int)
        or isinstance(timestep, torch.IntTensor)
        or isinstance(timestep, torch.LongTensor)
    ):
        raise ValueError(
            (
                "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                " `HeunDiscreteScheduler.step()` is not supported. Make sure to pass"
                " one of the `scheduler.timesteps` as a timestep."
            ),
        )

    if self.step_index is None:
        self._init_step_index(timestep)

    # Upcast to avoid precision issues when computing prev_sample
    sample = sample.to(torch.float32)

    if self.state_in_first_order:
        sigma = self.sigmas[self.step_index]
        sigma_next = self.sigmas[self.step_index + 1]
    else:
        # 2nd order / Heun's method
        sigma = self.sigmas[self.step_index - 1]
        sigma_next = self.sigmas[self.step_index]

    # Stochastic churn: optionally inflate sigma and re-inject matching noise
    # (only active when s_churn > 0 and sigma lies inside [s_tmin, s_tmax]).
    gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0

    sigma_hat = sigma * (gamma + 1)

    if gamma > 0:
        noise = randn_tensor(
            model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
        )
        eps = noise * s_noise
        sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5

    if self.state_in_first_order:
        # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
        denoised = sample - model_output * sigma
        # 2. convert to an ODE derivative for 1st order
        derivative = (sample - denoised) / sigma_hat
        # 3. Delta timestep
        dt = sigma_next - sigma_hat

        # store for 2nd order step
        self.prev_derivative = derivative
        self.dt = dt
        self.sample = sample
    else:
        # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
        denoised = sample - model_output * sigma_next
        # 2. 2nd order / Heun's method: average the two slope estimates
        derivative = (sample - denoised) / sigma_next
        derivative = 0.5 * (self.prev_derivative + derivative)

        # 3. take prev timestep & sample
        dt = self.dt
        sample = self.sample

        # free dt and derivative
        # Note, this puts the scheduler in "first order mode"
        self.prev_derivative = None
        self.dt = None
        self.sample = None

    # original sample way (plain Euler/Heun update, kept for reference):
    # prev_sample = sample + derivative * dt

    # Granularity control: scale the update's deviation around its mean by
    # omega instead of applying the raw update directly.
    dx = derivative * dt
    m = dx.mean()
    dx_ = (dx - m) * omega + m
    prev_sample = sample + dx_

    # Cast sample back to model compatible dtype
    prev_sample = prev_sample.to(model_output.dtype)

    # upon completion increase step index by one
    self._step_index += 1

    if not return_dict:
        return (prev_sample,)

    return FlowMatchHeunDiscreteSchedulerOutput(prev_sample=prev_sample)
||||
def __len__(self):
    # Length of the scheduler is the number of training timesteps.
    return self.config.num_train_timesteps
3
ui/auth.py
Normal file
3
ui/auth.py
Normal file
@@ -0,0 +1,3 @@
|
||||
|
||||
def same_auth(username, password):
    """Gradio auth callback: accept only the single demo credential pair.

    Fix: the original used plain ``==`` string comparison, which short-circuits
    on the first differing character and can leak credential prefixes through
    timing.  ``hmac.compare_digest`` compares in constant time.  Inputs are
    encoded to bytes so arbitrary (including non-ASCII) user input is handled
    without raising.

    Args:
        username: username entered in the login form.
        password: password entered in the login form.

    Returns:
        bool: True only when both fields match the demo credentials exactly.
    """
    import hmac  # local import keeps this tiny auth module self-contained

    user_ok = hmac.compare_digest(username.encode("utf-8"), b"timedomain_text2music_team")
    pass_ok = hmac.compare_digest(password.encode("utf-8"), b"TimeDomain_ACEFlow_DEMO")
    return user_ok and pass_ok
44
ui/llm_prompt_gen.py
Normal file
44
ui/llm_prompt_gen.py
Normal file
@@ -0,0 +1,44 @@
|
||||
from openai import OpenAI
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
# System prompt asking the LLM for a random, short (<30 words) music
# description that includes genre tags; few-shot examples are embedded.
random_genre_prompt = """randomly give me a short prompt that describes a music (with genre tag). less than 30 words
Here are some examples:
fusion jazz with synth, bass, drums, saxophone
Electronic, eerie, swing, dreamy, melodic, electro, sad, emotional
90s hip-hop, old school rap, turntablism, vinyl samples, instrumental loop
"""
||||
|
||||
|
||||
def random_genre():
    """Ask the LLM for a random short music-genre prompt.

    Returns:
        str: the generated description (<= 30 tokens).
    """
    response = OpenAI().chat.completions.create(
        model="gpt-4o-mini",
        temperature=0.7,
        max_tokens=30,
        messages=[{"role": "system", "content": random_genre_prompt}],
    )
    return response.choices[0].message.content
||||
|
||||
|
||||
# System prompt for rewriting a user-supplied music description into a more
# genre-specific one (<30 words); few-shot output examples are embedded.
# Fix: the "descirption" typos are corrected to "description" so the
# instruction and the input-section header read cleanly for the LLM.
optimize_genre_prompt = """optimize the following music description and make it more genre specific. less than 30 words
output examples:
fusion jazz with synth, bass, drums, saxophone
Electronic, eerie, swing, dreamy, melodic, electro, sad, emotional
90s hip-hop, old school rap, turntablism, vinyl samples, instrumental loop

## input music description
"""
||||
|
||||
|
||||
def optimize_genre(prompt):
    """Rewrite *prompt* into a more genre-specific music description via the LLM.

    Args:
        prompt: free-form music description supplied by the user.

    Returns:
        str: the optimized description (<= 30 tokens).
    """
    client = OpenAI()
    request_messages = [{"role": "system", "content": optimize_genre_prompt + prompt}]
    completion = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=request_messages,
        max_tokens=30,
        temperature=0.7,
    )
    return completion.choices[0].message.content
||||
323
ui/text2music_large_lyric_components_v3.py
Normal file
323
ui/text2music_large_lyric_components_v3.py
Normal file
@@ -0,0 +1,323 @@
|
||||
import gradio as gr
|
||||
from pathlib import Path
|
||||
import json
|
||||
from collections import OrderedDict, Counter
|
||||
import sys
|
||||
import os
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
from language_segmentation import LangSegment
|
||||
|
||||
# Default/maximum generation length in seconds used by the duration sliders.
MAX_GENERATE_LEN = 60


# ISO-639-1 (plus "yue") language codes accepted for lyrics input.
SUPPORT_LANGUAGES = [
    "af", "sq", "am", "ar", "an", "hy", "az", "ba", "eu", "be", "bn", "bs", "bg", "my", "ca", "zh", "cs", "da", "nl", "en", "eo", "et", "fi", "fr", "gd", "ka", "de", "el", "gn", "gu", "hi", "hu", "io", "id", "ia", "it", "ja", "kk", "km", "ko", "ku", "la", "lt", "lb", "mk", "mt", "nb", "no", "or", "fa", "pl", "pt", "ro", "ru", "sa", "sr", "sd", "sk", "sl", "es", "sw", "sv", "tl", "ta", "tt", "th", "tr", "tk", "uk", "vi", "cy", "is", "ga", "gl", "se", "yue"
]


# Shared language-segmentation helper used by `detect_language` below.
langseg = LangSegment()

# Restrict detection to the languages the segmenter can classify.
langseg.setfilters([
    'af', 'am', 'an', 'ar', 'as', 'az', 'be', 'bg', 'bn', 'br', 'bs', 'ca', 'cs', 'cy', 'da', 'de', 'dz', 'el',
    'en', 'eo', 'es', 'et', 'eu', 'fa', 'fi', 'fo', 'fr', 'ga', 'gl', 'gu', 'he', 'hi', 'hr', 'ht', 'hu', 'hy',
    'id', 'is', 'it', 'ja', 'jv', 'ka', 'kk', 'km', 'kn', 'ko', 'ku', 'ky', 'la', 'lb', 'lo', 'lt', 'lv', 'mg',
    'mk', 'ml', 'mn', 'mr', 'ms', 'mt', 'nb', 'ne', 'nl', 'nn', 'no', 'oc', 'or', 'pa', 'pl', 'ps', 'pt', 'qu',
    'ro', 'ru', 'rw', 'se', 'si', 'sk', 'sl', 'sq', 'sr', 'sv', 'sw', 'ta', 'te', 'th', 'tl', 'tr', 'ug', 'uk',
    'ur', 'vi', 'vo', 'wa', 'xh', 'zh', 'zu'
])

# Maps a human-readable key-scale name to the model's 1-based keyscale index
# (12 major keys, then 12 minor keys; 0 is reserved for "auto" in the UI).
keyscale_idx_mapping = OrderedDict({
    "C major": 1,
    "C# major": 2,
    "D major": 3,
    "Eb major": 4,
    "E major": 5,
    "F major": 6,
    "F# major": 7,
    "G major": 8,
    "Ab major": 9,
    "A major": 10,
    "Bb major": 11,
    "B major": 12,
    "A minor": 13,
    "Bb minor": 14,
    "B minor": 15,
    "C minor": 16,
    "C# minor": 17,
    "D minor": 18,
    "Eb minor": 19,
    "E minor": 20,
    "F minor": 21,
    "F# minor": 22,
    "G minor": 23,
    "Ab minor": 24
})
|
||||
|
||||
def get_checkpoint_paths(checkpoint_path):
    """List every ``*.ckpt`` file in the directory containing *checkpoint_path*.

    Args:
        checkpoint_path: path to any file inside the checkpoint directory.

    Returns:
        list[str]: paths of all sibling ``*.ckpt`` files.
    """
    # Collect all checkpoint files that live next to the given path.
    parent_dir = Path(checkpoint_path).parent
    checkpoints = [str(candidate) for candidate in parent_dir.glob("*.ckpt")]
    print(checkpoints)
    return checkpoints
|
||||
|
||||
def create_list_checkpoint_path_ui(checkpoint_path):
    """Build the checkpoint-selection UI block.

    Renders a dropdown pre-populated with the ``*.ckpt`` files found next to
    *checkpoint_path*, plus a refresh button that re-globs the directory.

    Args:
        checkpoint_path: default checkpoint path; its directory is scanned.

    Returns:
        gr.Dropdown: the checkpoint selector component.
    """
    with gr.Column():
        gr.Markdown("Checkpoint Selection")
        with gr.Group():
            with gr.Row(equal_height=True):
                with gr.Column(scale=9):
                    selected_checkpoint = gr.Dropdown(
                        choices=get_checkpoint_paths(checkpoint_path),
                        label="Select Model",
                        interactive=True,
                        value=checkpoint_path,
                    )
                with gr.Column(scale=1):
                    refresh_button = gr.Button("Refresh Checkpoints", elem_id="refresh_button", variant="primary")
                    # Re-scan the checkpoint directory on demand.
                    refresh_button.click(
                        fn=lambda: gr.update(choices=get_checkpoint_paths(checkpoint_path)),
                        inputs=None,
                        outputs=[selected_checkpoint]
                    )
    return selected_checkpoint
|
||||
def create_keyscale_bpm_time_signature_input_ui(options=["auto", "manual"]):
    """Build the "Time and Keyscale Control" UI block.

    A hidden ``gr.List`` accumulates the five control values
    (keyscale, bpm, timesignature, is_music_start, is_music_end) as
    ``[name, value]`` pairs; each visible control updates its slot through a
    closure over ``results``.  Value 0 everywhere means "auto".

    NOTE(review): the mutable default ``options`` is never mutated here, so it
    is safe, but a tuple would be cleaner — confirm before changing.

    Returns:
        list: ``[audio_duration, keyscale_bpm_time_signature_input]``
        components for wiring into the generate callback.
    """
    gr.Markdown("### Time and Keyscale Control")
    with gr.Group():
        # Backing store mirrored into the hidden gr.List component below.
        results = [
            ["keyscale", 0],
            ["bpm", 0],
            ["timesignature", 0],
            ["is_music_start", 0],
            ["is_music_end", 0],
        ]
        keyscale_bpm_time_signature_input = gr.List(visible=False, elem_id="keyscale_bpm_time_signature_input", value=results)
        audio_duration = gr.Slider(10, 600, step=1, value=MAX_GENERATE_LEN, label="Audio Duration", interactive=True)
        with gr.Row():
            is_music_start_input = gr.Radio(["auto", "start", "not_start"], value="auto", label="Is Music Start", elem_id="is_music_start_input")
            is_music_end_input = gr.Radio(["auto", "end", "not_end"], value="auto", label="Is Music End", elem_id="is_music_end_input")

        def when_is_music_start_input_change(
            is_music_start_input,
        ):
            # Encode auto/start/not_start as 0/1/2 into results[3].
            nonlocal results
            if is_music_start_input == "auto":
                is_music_start = 0
            elif is_music_start_input == "start":
                is_music_start = 1
            else:
                is_music_start = 2
            results[3][1] = is_music_start
            return gr.update(elem_id="keyscale_bpm_time_signature_input", value=results)

        is_music_start_input.change(
            when_is_music_start_input_change,
            inputs=[is_music_start_input],
            outputs=[keyscale_bpm_time_signature_input]
        )

        def when_is_music_end_input_change(
            is_music_end_input,
        ):
            # Encode auto/end/not_end as 0/1/2 into results[4].
            nonlocal results
            if is_music_end_input == "auto":
                is_music_end = 0
            elif is_music_end_input == "end":
                is_music_end = 1
            else:
                is_music_end = 2
            results[4][1] = is_music_end
            return gr.update(elem_id="keyscale_bpm_time_signature_input", value=results)

        is_music_end_input.change(
            when_is_music_end_input_change,
            inputs=[is_music_end_input],
            outputs=[keyscale_bpm_time_signature_input]
        )

        with gr.Row():
            keyscale_control = gr.Radio(options, value="auto", label="Keyscale", elem_id="keyscale_control")
            bpm_control = gr.Radio(options, value="auto", label="BPM", elem_id="bpm_control")
            time_signature_control = gr.Radio(options, value="auto", label="Time Signature", elem_id="time_signature_control")

        keyscale_input = gr.Dropdown(list(keyscale_idx_mapping.keys()), label="Keyscale", info="the keyscale of the music", visible=False, elem_id="keyscale_input")

        def when_keyscale_change(
            keyscale_input,
            keyscale_control,
        ):
            # Store the chosen keyscale (or 0 for auto) and toggle visibility
            # of the manual dropdown.
            nonlocal results
            keyscale = keyscale_input
            if keyscale_control == "auto":
                keyscale = 0
            results[0][1] = keyscale
            return [gr.update(elem_id="keyscale_bpm_time_signature_input", value=results), gr.update(elem_id="keyscale_input", visible=(keyscale_control == "manual"))]

        keyscale_input.change(
            when_keyscale_change,
            inputs=[keyscale_input, keyscale_control],
            outputs=[keyscale_bpm_time_signature_input, keyscale_input]
        )
        keyscale_control.change(
            fn=when_keyscale_change,
            inputs=[keyscale_input, keyscale_control],
            outputs=[keyscale_bpm_time_signature_input, keyscale_input]
        )

        bpm_input = gr.Slider(30, 200, step=1, value=120, label="BPM", info="the beats per minute of the music", visible=False, interactive=True, elem_id="bpm_input")

        def when_bmp_change(
            bpm_input,
            bpm_control,
        ):
            # Store the chosen BPM (or 0 for auto) and toggle slider visibility.
            nonlocal results
            bpm = bpm_input
            if bpm_control == "auto":
                bpm = 0
            results[1][1] = bpm
            updates = [gr.update(elem_id="keyscale_bpm_time_signature_input", value=results), gr.update(elem_id="bpm_input", visible=(bpm_control == "manual"))]
            return updates

        bpm_control.change(
            fn=when_bmp_change,
            inputs=[bpm_input, bpm_control],
            outputs=[keyscale_bpm_time_signature_input, bpm_input]
        )

        bpm_input.change(
            when_bmp_change,
            inputs=[bpm_input, bpm_control],
            outputs=[keyscale_bpm_time_signature_input, bpm_input]
        )

        time_signature_input = gr.Slider(1, 12, step=1, value=4, label="Time Signature", info="the time signature of the music", visible=False, interactive=True, elem_id="time_signature_input")

        def when_time_signature_change(
            time_signature_input,
            time_signature_control,
        ):
            # Store the chosen time signature (or 0 for auto) and toggle
            # slider visibility.
            nonlocal results
            time_signature = time_signature_input
            if time_signature_control == "auto":
                time_signature = 0
            results[2][1] = time_signature
            return [gr.update(elem_id="keyscale_bpm_time_signature_input", value=results), gr.update(elem_id="time_signature_input", visible=(time_signature_control == "manual"))]

        time_signature_input.change(
            when_time_signature_change,
            inputs=[time_signature_input, time_signature_control],
            outputs=[keyscale_bpm_time_signature_input, time_signature_input]
        )
        time_signature_control.change(
            fn=when_time_signature_change,
            inputs=[time_signature_input, time_signature_control],
            outputs=[keyscale_bpm_time_signature_input, time_signature_input]
        )

    return [audio_duration, keyscale_bpm_time_signature_input]
||||
|
||||
def detect_language(lyrics: str) -> str:
    """Detect the dominant language of *lyrics*.

    The text is segmented with the module-level ``langseg`` and each
    segment's language is weighted by its character count; the most
    frequent language wins.  Falls back to ``"en"`` for empty or
    whitespace-only input.

    Fixes: the empty-input branch returned ``gr.update(value="en")`` while
    the main path returned a plain string — inconsistent return types for
    callers.  A plain ``"en"`` has the same effect as a gradio output value.
    The return annotation is also corrected from ``list`` to ``str``.

    Args:
        lyrics: the lyrics text to classify.

    Returns:
        str: a language code such as ``"en"`` or ``"zh"``.
    """
    lyrics = lyrics.strip()
    if not lyrics:
        return "en"
    lang_counter = Counter()
    for segment in langseg.getTexts(lyrics):
        # Weight each detected language by how much text it covers.
        lang_counter[segment["lang"]] += len(segment["text"])
    return lang_counter.most_common(1)[0][0]
|
||||
|
||||
def create_output_ui():
    """Build the output column: target audio, two generated takes, and params.

    Returns:
        tuple: ``(outputs, target_audio, input_params_json)`` where
        ``outputs`` is the list of the two generated-audio components.
    """
    target_audio = gr.Audio(type="filepath", label="Target Audio")
    output_audio1 = gr.Audio(type="filepath", label="Generated Audio 1")
    output_audio2 = gr.Audio(type="filepath", label="Generated Audio 2")
    # JSON echo of the exact generation parameters, for reproducibility.
    input_params_json = gr.JSON(label="Input Parameters")
    outputs = [output_audio1, output_audio2]
    return outputs, target_audio, input_params_json
|
||||
|
||||
def dump_func(*args):
    """Debug stand-in for the real processing callbacks.

    Prints whatever it receives and yields no outputs.
    """
    print(args)
    empty_outputs = []
    return empty_outputs
|
||||
|
||||
def create_main_demo_ui(
    checkpoint_path="checkpoints/aceflow3_0311/1d_epoch=16-step=140k.ckpt",
    text2music_process_func=dump_func,
    sample_data_func=dump_func,
):
    """Assemble the full text2music demo page.

    Args:
        checkpoint_path: default checkpoint shown in the selector.
        text2music_process_func: callback wired to the generate button.
        sample_data_func: callback that loads a dataset example by index.

    Returns:
        gr.Blocks: the assembled (not yet launched) demo.
    """
    with gr.Blocks(
        title="AceFlow 3.0 DEMO (3.5B)",
    ) as demo:
        gr.Markdown(
            """
            <h1 style="text-align: center;">AceFlow 3.0 DEMO</h1>
            """
        )
        selected_checkpoint = create_list_checkpoint_path_ui(checkpoint_path)

        gr.Markdown("Dataset Filter")
        with gr.Group():
            with gr.Row(equal_height=True):
                # NOTE(review): `language` is created but never wired to a
                # callback — presumably intended for sample_data_func; confirm.
                language = gr.Dropdown(["en", "zh"], label="Language", value="en", elem_id="language")
                dataset_example_idx = gr.Number(
                    value=-1,
                    label="Dataset Example Index",
                    interactive=True
                )
                sample_bnt = gr.Button(value="Sample Data", elem_id="sample_bnt", variant="primary")

        with gr.Row():
            with gr.Column():
                audio_duration = gr.Slider(10, 600, step=1, value=MAX_GENERATE_LEN, label="Audio Duration", interactive=True)

                prompt = gr.Textbox(lines=2, label="Tags", max_lines=4)
                lyrics = gr.Textbox(lines=9, label="Lyrics", max_lines=9)

                # Sampler / guidance configuration.
                scheduler_type = gr.Radio(["euler", "heun"], value="euler", label="Scheduler Type", elem_id="scheduler_type")
                cfg_type = gr.Radio(["cfg", "apg"], value="apg", label="CFG Type", elem_id="cfg_type")
                infer_step = gr.Slider(minimum=1, maximum=1000, step=1, value=60, label="Infer Steps", interactive=True)
                guidance_scale = gr.Slider(minimum=0.0, maximum=200.0, step=0.1, value=15.0, label="Guidance Scale", interactive=True)
                omega_scale = gr.Slider(minimum=-100.0, maximum=100.0, step=0.1, value=10.0, label="Granularity Scale", interactive=True)
                manual_seeds = gr.Textbox(label="manual seeds (default None)", placeholder="1,2,3,4", value=None)

                text2music_bnt = gr.Button(variant="primary")
            with gr.Column():
                outputs, target_audio, input_params_json = create_output_ui()

        sample_bnt.click(
            sample_data_func,
            inputs=[dataset_example_idx, audio_duration],
            outputs=[target_audio, prompt, lyrics, input_params_json],
        )
        text2music_bnt.click(
            fn=text2music_process_func,
            inputs=[
                audio_duration,
                prompt,
                lyrics,
                input_params_json,
                selected_checkpoint,
                scheduler_type,
                cfg_type,
                infer_step,
                guidance_scale,
                omega_scale,
                manual_seeds,
            ], outputs=outputs + [input_params_json]
        )

    return demo
|
||||
|
||||
if __name__ == "__main__":
    # Standalone launch for local testing of the UI layout.
    create_main_demo_ui().launch(server_name="0.0.0.0", server_port=7860)
||||
197
utils.py
Normal file
197
utils.py
Normal file
@@ -0,0 +1,197 @@
|
||||
from loguru import logger
|
||||
import functools
|
||||
import numpy as np
|
||||
import time
|
||||
import librosa
|
||||
import sys
|
||||
import yaml
|
||||
from threading import Thread
|
||||
|
||||
|
||||
# Route loguru output to stderr with a compact format; remove the default
# sink first so messages are not emitted twice.
logger.remove()
logger.add(sys.stderr, format="{time} {level} {message}", level="INFO")
|
||||
|
||||
def async_thread(f):
    """Decorator: run *f* in a background thread (fire-and-forget).

    Fixes:
      * ``functools.wraps`` preserves the wrapped function's name/docstring.
      * The started ``Thread`` is returned so callers *can* ``join()`` it;
        the original returned ``None``, so this is backward-compatible.

    Args:
        f: the function to run asynchronously.

    Returns:
        callable: a wrapper that starts *f* in a new thread and returns it.
    """

    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        t = Thread(target=f, args=args, kwargs=kwargs)
        t.start()
        return t

    return wrapper
|
||||
|
||||
def timecost(func):
    """Decorator that logs the wall-clock duration of each call to *func*.

    Fix: measure with ``time.perf_counter()`` instead of ``time.time()`` —
    it is monotonic and high-resolution, so durations are not skewed by
    system clock adjustments.

    Args:
        func: the function to instrument.

    Returns:
        callable: a wrapper that logs elapsed seconds and forwards the result.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        start = time.perf_counter()
        result = func(*args, **kwargs)
        end = time.perf_counter()
        logger.info(f"{func.__name__} took {end - start} seconds to run")
        return result
    return wrapper
|
||||
|
||||
def autocut(wav, min_cut_second=9.9, sample_rate=16_000, frame_length=2048, hop_length=512, cut_threshold=[2e-5, 1, 2**0.5], min_mute_duration=120, min_tail_second=2):
    """Cut a waveform into segments no longer than ``min_cut_second`` seconds,
    preferring to cut in the middle of silent stretches.

    An adaptive search loosens the RMS silence threshold (and shortens the
    analysis windows) up to 8 times until every segment fits under the limit;
    if no usable silence is found, the audio is force-split at fixed length.

    Args:
        wav: 1-D audio samples.
        min_cut_second: maximum segment length in seconds.
        sample_rate: sample rate of ``wav`` in Hz.
        frame_length / hop_length: initial RMS analysis window parameters.
        cut_threshold: ``[initial_threshold, max_threshold, step_multiple]``
            for the adaptive RMS silence threshold.
            NOTE(review): mutable default list — safe here (it is unpacked,
            never mutated), but a tuple would be cleaner.
        min_mute_duration: minimum silence length (ms) eligible as a cut point.
        min_tail_second: a trailing segment shorter than this is merged back.

    Returns:
        tuple[list, list]: ``(segs, seg_lengths)`` — the segments and their
        lengths in samples.
    """
    segs = []
    seg_lengths = []
    longest_wav_frames = int(min_cut_second * sample_rate)
    if len(wav) < longest_wav_frames:
        # Already short enough: return as a single segment.
        segs.append(wav)
        seg_lengths.append(len(wav))
        return segs, seg_lengths

    # Adaptive-threshold search for silence-based cut points.
    candidate_cut_positions = []
    candidate_cut_durations = []
    cut_threshold, cut_threshold_max, cut_step_multiple = cut_threshold

    for i in range(8):
        rms = librosa.feature.rms(y=wav, frame_length=frame_length, hop_length=hop_length)[0]
        is_mute_mask = rms <= cut_threshold
        is_mute = np.zeros_like(rms, dtype='bool')
        is_mute[is_mute_mask], is_mute[~is_mute_mask] = True, False
        # logger.info(f"{rms.mean()=}, {rms.min()=}, {rms.max()=}, {cut_threshold=}, {is_mute_mask.sum()=}")
        last_start = 0
        last_position = 0
        curr_cut_positions = []
        curr_cut_durations = []
        interrupt = False
        for i in range(len(is_mute) - 1):
            # Transition: sound -> silence
            if not is_mute[i] and is_mute[i + 1]:
                last_start = i
            # Transition: silence -> sound
            if is_mute[i] and not is_mute[i + 1]:
                # The silent stretch must last at least min_mute_duration (ms).
                mute_duration = (i - last_start) * \
                    hop_length / (sample_rate / 1000)
                if mute_duration >= min_mute_duration:
                    # Cut rule: split at the middle of the silent stretch,
                    # converted back to waveform sample positions.
                    mid = (i + last_start) // 2
                    cut_position = mid * hop_length
                    curr_duration = cut_position - last_position
                    # If the segment would exceed half the limit, add two cut
                    # points (quarter positions of the silence) instead of one.
                    if (longest_wav_frames // 2) < curr_duration:
                        left_cut_position = (last_start+mid) // 2 * hop_length
                        left_curr_duration = left_cut_position - last_position
                        curr_cut_positions.append(left_cut_position)
                        curr_cut_durations.append(left_curr_duration)
                        last_position = left_cut_position

                        right_cut_position = (mid+i) // 2 * hop_length
                        right_curr_duration = right_cut_position - last_position
                        curr_cut_positions.append(right_cut_position)
                        curr_cut_durations.append(right_curr_duration)
                        last_position = right_cut_position
                    else:
                        curr_cut_positions.append(cut_position)
                        curr_cut_durations.append(curr_duration)
                        last_position = cut_position

        candidate_cut_positions = curr_cut_positions
        candidate_cut_durations = curr_cut_durations
        if cut_threshold >= cut_threshold_max:
            break
        if cut_threshold < cut_threshold_max:
            # Close the final segment at the end of the waveform.
            if len(curr_cut_durations) == 0:
                curr_cut_positions.append(len(wav))
                curr_cut_durations.append(len(wav))
            else:
                curr_cut_positions.append(len(wav))
                curr_cut_durations.append(
                    curr_cut_positions[-1] - curr_cut_positions[-2])
            max_duration = max(curr_cut_durations)
            if max_duration >= longest_wav_frames:
                interrupt = True
            # Loosen the silence criteria for the next attempt.
            cut_threshold = cut_threshold * cut_step_multiple
            min_mute_duration = int(max(min_mute_duration/cut_step_multiple, 10))
            frame_length = int(max(frame_length / cut_step_multiple, 256))
            hop_length = int(max(hop_length / cut_step_multiple, 64))
            # logger.info(f"Adaptively adjust the threshold: {cut_threshold=} {min_mute_duration=} {frame_length=} {hop_length=} {len(curr_cut_durations)=}")
            if not interrupt and len(curr_cut_durations) > 0:
                # Every segment fits: accept this set of cut points.
                candidate_cut_positions = curr_cut_positions
                candidate_cut_durations = curr_cut_durations
                break

    # logger.info(f"candidate_cut_positions {candidate_cut_positions}")
    # logger.info(f"candidate_cut_durations {candidate_cut_durations}")
    # Greedily merge candidate cut points up to the maximum segment length.
    curr_duration = 0
    last_start = 0
    for i, duration in enumerate(candidate_cut_durations):
        curr_duration += duration
        # Over the limit: cut at the previous candidate position.
        if curr_duration > longest_wav_frames:
            segs.append(wav[last_start:candidate_cut_positions[i - 1]])
            seg_lengths.append(curr_duration - duration)
            curr_duration = duration
            last_start = candidate_cut_positions[i - 1]
    if len(candidate_cut_durations) == 0 or (len(candidate_cut_durations)==1 and candidate_cut_durations[0] >= len(wav)):
        logger.info("自动切分算法失败,按最长强制切分")
        # Silence detection failed: force-split at the maximum length.
        last_start = 0
        segs = []
        seg_lengths = []
        for end in range(longest_wav_frames, max(longest_wav_frames, len(wav)), longest_wav_frames):
            segs.append(wav[last_start:end])
            seg_lengths.append(end-last_start)
            last_start = end
    # Handle the trailing remainder that the loops above did not cover.
    if sum(seg_lengths) < len(wav):
        for end in range(last_start+longest_wav_frames, max(longest_wav_frames, len(wav)), longest_wav_frames):
            segs.append(wav[last_start:end])
            seg_lengths.append(end - last_start)
            last_start = end
        if sum(seg_lengths) < len(wav):
            last_start = sum(seg_lengths)
            tail_frame = len(wav) - last_start
            # A too-short tail is merged with the previous segment.
            if len(segs) > 0 and tail_frame < min_tail_second*sample_rate:
                segs.pop()
                seg_lengths.pop()
                last_start = sum(seg_lengths)
            segs.append(wav[last_start:])
            seg_lengths.append(len(wav) - last_start)

    # Final safety net: hard-split any segment still exceeding the limit.
    if any([len(seg) > longest_wav_frames for seg in segs]):
        new_segs = []
        new_seg_lengths = []
        for seg, seg_length in zip(segs, seg_lengths):
            num_cut = len(seg) // longest_wav_frames
            num_cut += 1 if len(seg) % longest_wav_frames > 0 else 0
            for i in range(num_cut):
                new_segs.append(seg[i*longest_wav_frames:(i+1)*longest_wav_frames])
                new_seg_lengths.append(len(new_segs[-1]))
        segs, seg_lengths = new_segs, new_seg_lengths
    return segs, seg_lengths
|
||||
|
||||
class ConfigObj:
    """Lightweight namespace wrapper over a plain dict.

    Exposes the mapping's entries both as attributes (``cfg.key``) and via
    mapping-style access (``cfg["key"]``, ``cfg.get("key")``).
    """

    def __init__(self, d):
        self.__dict__.update(d)

    def __repr__(self) -> str:
        return repr(self.__dict__)

    def __str__(self) -> str:
        return str(self.__dict__)

    def __getitem__(self, k):
        return self.__dict__[k]

    def get(self, k, default=None):
        """Return the value for *k*, or *default* when absent (dict.get semantics)."""
        return self.__dict__.get(k, default)

    def __setitem__(self, k, v):
        self.__dict__[k] = v
|
||||
|
||||
def load_config(config_path):
    """Load a YAML file and wrap the resulting mapping in a ConfigObj.

    Args:
        config_path: path to the YAML configuration file.

    Returns:
        ConfigObj: attribute-style view of the parsed configuration.
    """
    with open(config_path, encoding='utf-8') as config_file:
        parsed = yaml.safe_load(config_file)
    return ConfigObj(parsed)
||||
Reference in New Issue
Block a user