all inference code

This commit is contained in:
sean
2025-04-03 12:25:30 +08:00
parent ea09512de9
commit d6f5a2d911
36 changed files with 28075 additions and 0 deletions

65
apg_guidance.py Normal file
View File

@@ -0,0 +1,65 @@
import torch
class MomentumBuffer:
    """Momentum accumulator for APG guidance.

    Keeps a running average of guidance updates; with the default negative
    momentum the previous average is partially subtracted from each new value.
    """

    def __init__(self, momentum: float = -0.75):
        # Weight applied to the previous running average on every update.
        self.momentum = momentum
        self.running_average = 0

    def update(self, update_value: torch.Tensor):
        """Fold *update_value* into the running average in place."""
        decayed_history = self.momentum * self.running_average
        self.running_average = update_value + decayed_history
def project(
    v0: torch.Tensor,  # [B, C, H, W]
    v1: torch.Tensor,  # [B, C, H, W]
    dims=(-1, -2),
):
    """Decompose *v0* into components parallel and orthogonal to *v1*.

    Args:
        v0: tensor to decompose.
        v1: direction tensor; normalized over *dims* before projection.
        dims: dimensions over which to normalize and reduce. Changed from the
            original mutable list default to an equivalent immutable tuple.

    Returns:
        (v0_parallel, v0_orthogonal), both cast back to v0's original dtype.
    """
    dtype = v0.dtype
    # Work in float64 so the normalization and dot product are numerically stable.
    v0, v1 = v0.double(), v1.double()
    v1 = torch.nn.functional.normalize(v1, dim=dims)
    v0_parallel = (v0 * v1).sum(dim=dims, keepdim=True) * v1
    v0_orthogonal = v0 - v0_parallel
    return v0_parallel.to(dtype), v0_orthogonal.to(dtype)
def apg_forward(
    pred_cond: torch.Tensor,  # [B, C, H, W]
    pred_uncond: torch.Tensor,  # [B, C, H, W]
    guidance_scale: float,
    momentum_buffer: MomentumBuffer = None,
    eta: float = 0.0,
    norm_threshold: float = 2.5,
    dims=(-1, -2),
):
    """Adaptive Projected Guidance (APG).

    Builds the guidance direction ``pred_cond - pred_uncond``, optionally
    smooths it with *momentum_buffer*, rescales it when its norm exceeds
    *norm_threshold*, then projects it onto/against ``pred_cond`` and applies
    only ``orthogonal + eta * parallel`` scaled by ``guidance_scale - 1``.

    Args:
        pred_cond: conditional model prediction.
        pred_uncond: unconditional model prediction.
        guidance_scale: CFG-style guidance strength.
        momentum_buffer: optional MomentumBuffer updated in place per call.
        eta: weight of the parallel component (0 drops it entirely).
        norm_threshold: max L2 norm of the guidance diff over *dims*;
            <= 0 disables rescaling.
        dims: reduction dims for the norm and the projection
            (immutable tuple instead of the original mutable list default).

    Returns:
        The guided prediction, same shape as *pred_cond*.
    """
    diff = pred_cond - pred_uncond
    if momentum_buffer is not None:
        momentum_buffer.update(diff)
        diff = momentum_buffer.running_average
    if norm_threshold > 0:
        # Only shrink when the norm exceeds the threshold: the scale factor
        # is clamped to at most 1, so small diffs pass through unchanged.
        ones = torch.ones_like(diff)
        diff_norm = diff.norm(p=2, dim=dims, keepdim=True)
        scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
        diff = diff * scale_factor
    diff_parallel, diff_orthogonal = project(diff, pred_cond, dims)
    normalized_update = diff_orthogonal + eta * diff_parallel
    pred_guided = pred_cond + (guidance_scale - 1) * normalized_update
    return pred_guided
def cfg_forward(cond_output, uncond_output, cfg_strength):
    """Vanilla classifier-free guidance: move from uncond toward cond by cfg_strength."""
    guidance_delta = cond_output - uncond_output
    return uncond_output + cfg_strength * guidance_delta

View File

@@ -0,0 +1,390 @@
# CLI flags are parsed before importing torch so that --device_id can be
# applied to CUDA_VISIBLE_DEVICES below, before CUDA initializes.
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--checkpoint_path", type=str, default="checkpoints/epoch=22-step=460k_pretrained_ft_80k.ckpt")
parser.add_argument("--port", type=int, default=7862)
parser.add_argument("--device_id", type=int, default=0)
parser.add_argument("--share", action='store_true', default=False)
# NOTE(review): action='store_true' with default=True means --bf16 can never be
# disabled from the command line; a --no-bf16 style option was likely intended.
parser.add_argument("--bf16", action='store_true', default=True)
parser.add_argument("--hide_dataset_sampler", action='store_true', default=False)
args = parser.parse_args()
import os
# Pin this process to the selected GPU; must happen before `import torch`.
os.environ["CUDA_VISIBLE_DEVICES"] = str(args.device_id)
import torch
import torchaudio
import torch.nn.functional as F
from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import retrieve_timesteps
from diffusers.utils.torch_utils import randn_tensor
from pathlib import Path
import time
from tqdm import tqdm
from loguru import logger
import json
from ui.auth import same_auth
from ui.text2music_large_lyric_components_v3 import create_main_demo_ui
from models.lyrics_utils.lyric_tokenizer import VoiceBpeTokenizer
from schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
from schedulers.scheduling_flow_match_heun_discrete import FlowMatchHeunDiscreteScheduler
from apg_guidance import apg_forward, MomentumBuffer, cfg_forward
from language_segmentation import LangSegment
import random
import re
# Mirror INFO-level logs to a local file for post-hoc debugging.
logger.add("demo_v3.log", level="INFO")
def ensure_directory_exists(directory):
    """Create *directory* (including parents) if nothing exists at that path.

    Accepts str or pathlib.Path. ``exist_ok=True`` closes the race where
    another process creates the directory between the existence check and
    ``makedirs`` (the original raised FileExistsError in that window). The
    outer ``exists`` guard is kept so a non-directory file at the path is
    still silently skipped, exactly as before.
    """
    directory = str(directory)
    if not os.path.exists(directory):
        os.makedirs(directory, exist_ok=True)
# Canonical section names that may appear inside bracketed structure tags.
VALID_STRUCTURE_PATTERN = ["hook", "break", "pre-chorus", "solo", "inst", "end", "outro", "bridge", "chorus", "verse", "intro", "start"]

# Compiled once at import time instead of on every call.
_STRUCTURE_TAG_RE = re.compile(r"\[.*\]")


def is_structure_tag(lin):
    """Return True if *lin* is a bracketed structure tag such as "[Chorus]".

    The line (case-insensitively) must start with a "[...]" group and contain
    one of the known section names anywhere in it.
    """
    lin = lin.lower()
    if not _STRUCTURE_TAG_RE.match(lin):
        return False
    return any(tag in lin for tag in VALID_STRUCTURE_PATTERN)
# Lyric re-tokenization helpers.
# ISO 639-1 code -> language-id token understood by VoiceBpeTokenizer.
SUPPORT_LANGUAGES = {
    "en": 259, "de": 260, "fr": 262, "es": 284, "it": 285,
    "pt": 286, "pl": 294, "tr": 295, "ru": 267, "cs": 293,
    "nl": 297, "ar": 5022, "zh": 5023, "ja": 5412, "hu": 5753,
    "ko": 6152, "hi": 6680
}
# Matches a bracketed structure tag such as "[Chorus]" (non-greedy).
structure_pattern = re.compile(r"\[.*?\]")
class InferDemo:
    """End-to-end text2music inference wrapper used by the Gradio demo.

    Owns the diffusion pipeline checkpoint, the lyric tokenizer and the
    language segmenter, and exposes :meth:`process_text2music` as the UI
    callback. The heavy model is loaded lazily via :meth:`reload_model`.
    """

    def __init__(self, args):
        """Prepare tokenizer and language tools; no checkpoint is loaded yet."""
        logger.info(f"init model with checkpoint: {args.checkpoint_path}")
        model_checkpoint_name = "AceFlow3_250401" + Path(args.checkpoint_path).stem
        # bf16 halves memory/bandwidth on supporting GPUs; otherwise fp32.
        if args.bf16:
            self.dtype = torch.bfloat16
        else:
            self.dtype = torch.float32
        # CUDA_VISIBLE_DEVICES was already set from --device_id, so cuda:0
        # is the selected GPU.
        self.device = "cuda:0"
        self.model_checkpoint_name = model_checkpoint_name
        # Empty string = "no checkpoint loaded yet" (see reload_model).
        self.checkpoint_path = ""
        lang_segment = LangSegment()
        # Restrict detection to the language codes the pipeline can handle.
        lang_segment.setfilters([
            'af', 'am', 'an', 'ar', 'as', 'az', 'be', 'bg', 'bn', 'br', 'bs', 'ca', 'cs', 'cy', 'da', 'de', 'dz', 'el',
            'en', 'eo', 'es', 'et', 'eu', 'fa', 'fi', 'fo', 'fr', 'ga', 'gl', 'gu', 'he', 'hi', 'hr', 'ht', 'hu', 'hy',
            'id', 'is', 'it', 'ja', 'jv', 'ka', 'kk', 'km', 'kn', 'ko', 'ku', 'ky', 'la', 'lb', 'lo', 'lt', 'lv', 'mg',
            'mk', 'ml', 'mn', 'mr', 'ms', 'mt', 'nb', 'ne', 'nl', 'nn', 'no', 'oc', 'or', 'pa', 'pl', 'ps', 'pt', 'qu',
            'ro', 'ru', 'rw', 'se', 'si', 'sk', 'sl', 'sq', 'sr', 'sv', 'sw', 'ta', 'te', 'th', 'tl', 'tr', 'ug', 'uk',
            'ur', 'vi', 'vo', 'wa', 'xh', 'zh', 'zu'
        ])
        self.lang_segment = lang_segment
        self.lyric_tokenizer = VoiceBpeTokenizer()

    def reload_model(self, checkpoint_path):
        """Load *checkpoint_path* into a fresh Pipeline unless already active."""
        # NOTE(review): the substring test (`in`) makes any requested path that
        # is contained in the current one count as "already loaded"; the second
        # equality test is then redundant. Confirm the substring semantics are
        # intentional.
        if checkpoint_path in self.checkpoint_path or self.checkpoint_path == checkpoint_path:
            return
        logger.info(f"re-init model with checkpoint: {checkpoint_path}")
        model_checkpoint_name = "AceFlow3_250401" + Path(checkpoint_path).stem
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        # Imported here so the heavy training module is only paid for on first load.
        from main_text2music_large_sana_dcae_0331_finetune import Pipeline
        model = Pipeline(infer=True, train=False)
        # strict=False tolerates missing/extra keys between checkpoint and model.
        model.load_state_dict(checkpoint, strict=False)
        self.model = model.eval().to(self.device).to(self.dtype)
        self.model_checkpoint_name = model_checkpoint_name
        self.checkpoint_path = checkpoint_path
        self.tokenizer = VoiceBpeTokenizer()

    def save_wav_file(self, target_wav, idx, sample_rate=48000):
        """Write one waveform to a timestamped FLAC file and return its path."""
        base_path = f"./test_results/{self.model_checkpoint_name}/demo_outputs"
        ensure_directory_exists(base_path)
        # Saved as FLAC via ffmpeg (original comment said "compress to mp3",
        # but the code writes lossless FLAC).
        output_path_flac = f"{base_path}/output_{time.strftime('%Y%m%d%H%M%S')}_{idx}.flac"
        target_wav = target_wav.float()
        torchaudio.save(output_path_flac, target_wav, sample_rate=sample_rate, format='flac', backend="ffmpeg", compression=torchaudio.io.CodecConfig(bit_rate=320000))
        return output_path_flac

    def set_seeds(self, batch_size, manual_seeds=None):
        """Build one torch.Generator per batch element.

        *manual_seeds* may be a comma-separated string ("1,2"), a single
        numeric string (shared by all elements), or None (fresh random seed
        per element). Returns (generators, actual_seeds).
        """
        seeds = None
        if manual_seeds is not None:
            if isinstance(manual_seeds, str):
                if "," in manual_seeds:
                    seeds = list(map(int, manual_seeds.split(",")))
                elif manual_seeds.isdigit():
                    seeds = int(manual_seeds)
        random_generators = [torch.Generator(device=self.device) for _ in range(batch_size)]
        actual_seeds = []
        for i in range(batch_size):
            seed = None
            if seeds is None:
                seed = torch.randint(0, 2**32, (1,)).item()
            if isinstance(seeds, int):
                seed = seeds
            if isinstance(seeds, list):
                # NOTE(review): assumes len(seeds) >= batch_size; a shorter
                # list raises IndexError here.
                seed = seeds[i]
            logger.info(f"batch idx: {i}, seed: {seed}")
            random_generators[i].manual_seed(seed)
            actual_seeds.append(seed)
        return random_generators, actual_seeds

    def latents2audio(self, latents, target_wav_duration_second=30, sample_rate=48000):
        """Decode latents with the VAE and save one audio file per batch item."""
        output_audio_paths = []
        bs = latents.shape[0]
        # NOTE(review): audio_lengths is computed but never used below.
        audio_lengths = [target_wav_duration_second * sample_rate] * bs
        pred_latents = latents
        with torch.no_grad():
            _, pred_wavs = self.model.vae.decode(pred_latents, sr=sample_rate)
        pred_wavs = [pred_wav.cpu().float() for pred_wav in pred_wavs]
        for i in tqdm(range(bs)):
            output_audio_path = self.save_wav_file(pred_wavs[i], i, sample_rate=sample_rate)
            output_audio_paths.append(output_audio_path)
        return output_audio_paths

    def get_lang(self, text):
        """Best-effort language code for *text*; falls back to "en" on failure."""
        language = "en"
        try:
            # getTexts() runs segmentation; getCounts() then reads the
            # per-language statistics it produced.
            langs = self.lang_segment.getTexts(text)
            langCounts = self.lang_segment.getCounts()
            language = langCounts[0][0]
            # Prefer the runner-up when English merely tops a mixed result.
            if len(langCounts) > 1 and language == "en":
                language = langCounts[1][0]
        except Exception as err:
            language = "en"
        return language

    def tokenize_lyrics(self, lyrics, debug=False):
        """Tokenize lyrics line by line with per-line language detection.

        Token 261 starts the sequence and token 2 terminates each line (a
        blank line contributes a lone 2). Structure tags like "[Chorus]"
        are always tokenized as English.
        """
        lines = lyrics.split("\n")
        lyric_token_idx = [261]
        for line in lines:
            line = line.strip()
            if not line:
                lyric_token_idx += [2]
                continue
            lang = self.get_lang(line)
            if lang not in SUPPORT_LANGUAGES:
                lang = "en"
            # Normalize region variants (e.g. "zh-cn", "spa-...") to base codes.
            if "zh" in lang:
                lang = "zh"
            if "spa" in lang:
                lang = "es"
            try:
                if structure_pattern.match(line):
                    token_idx = self.lyric_tokenizer.encode(line, "en")
                else:
                    token_idx = self.lyric_tokenizer.encode(line, lang)
                if debug:
                    toks = self.lyric_tokenizer.batch_decode([[tok_id] for tok_id in token_idx])
                    logger.info(f"debbug {line} --> {lang} --> {toks}")
                lyric_token_idx = lyric_token_idx + token_idx + [2]
            except Exception as e:
                # Best-effort: a failing line is skipped, not fatal.
                print("tokenize error", e, "for line", line, "major_language", lang)
        return lyric_token_idx

    @torch.no_grad()
    def text2music_diffusion_process(
        self,
        duration,
        encoder_text_hidden_states,
        text_attention_mask,
        speaker_embds,
        lyric_token_ids,
        lyric_mask,
        random_generators=None,
        infer_steps=60,
        guidance_scale=15.0,
        omega_scale=10.0,
        scheduler_type="euler",
        cfg_type="apg",
    ):
        """Run the flow-matching denoising loop and return the final latents.

        duration is in seconds; cfg_type selects APG or classic CFG guidance;
        scheduler_type selects the Euler or Heun flow-match scheduler.
        """
        logger.info("cfg_type: {}, guidance_scale: {}, omega_scale: {}".format(cfg_type, guidance_scale, omega_scale))
        do_classifier_free_guidance = True
        # Scale 0/1 makes guidance a no-op, so skip the doubled batch entirely.
        if guidance_scale == 0.0 or guidance_scale == 1.0:
            do_classifier_free_guidance = False
        device = encoder_text_hidden_states.device
        dtype = encoder_text_hidden_states.dtype
        bsz = encoder_text_hidden_states.shape[0]
        if scheduler_type == "euler":
            scheduler = FlowMatchEulerDiscreteScheduler(
                num_train_timesteps=1000,
                shift=3.0,
            )
        elif scheduler_type == "heun":
            scheduler = FlowMatchHeunDiscreteScheduler(
                num_train_timesteps=1000,
                shift=3.0,
            )
        # Latent frames: 44.1 kHz audio, 512-sample hop, 8x latent downsampling.
        frame_length = int(duration * 44100 / 512 / 8)
        timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps=infer_steps, device=device, timesteps=None)
        target_latents = randn_tensor(shape=(bsz, 8, 16, frame_length), generator=random_generators, device=device, dtype=dtype)
        attention_mask = torch.ones(bsz, frame_length, device=device, dtype=dtype)
        if do_classifier_free_guidance:
            # Double the batch: first half conditional, second half
            # unconditional (zeroed text/speaker/lyric inputs).
            attention_mask = torch.cat([attention_mask] * 2, dim=0)
            encoder_text_hidden_states = torch.cat([encoder_text_hidden_states, torch.zeros_like(encoder_text_hidden_states)], 0)
            text_attention_mask = torch.cat([text_attention_mask] * 2, dim=0)
            speaker_embds = torch.cat([speaker_embds, torch.zeros_like(speaker_embds)], 0)
            lyric_token_ids = torch.cat([lyric_token_ids, torch.zeros_like(lyric_token_ids)], 0)
            lyric_mask = torch.cat([lyric_mask, torch.zeros_like(lyric_mask)], 0)
        momentum_buffer = MomentumBuffer()
        for i, t in tqdm(enumerate(timesteps), total=num_inference_steps):
            # expand the latents if we are doing classifier free guidance
            latents = target_latents
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            timestep = t.expand(latent_model_input.shape[0])
            noise_pred = self.model.transformers(
                hidden_states=latent_model_input,
                attention_mask=attention_mask,
                encoder_text_hidden_states=encoder_text_hidden_states,
                text_attention_mask=text_attention_mask,
                speaker_embeds=speaker_embds,
                lyric_token_idx=lyric_token_ids,
                lyric_mask=lyric_mask,
                timestep=timestep,
            ).sample
            if do_classifier_free_guidance:
                noise_pred_with_cond, noise_pred_uncond = noise_pred.chunk(2)
                if cfg_type == "apg":
                    noise_pred = apg_forward(
                        pred_cond=noise_pred_with_cond,
                        pred_uncond=noise_pred_uncond,
                        guidance_scale=guidance_scale,
                        momentum_buffer=momentum_buffer,
                    )
                else:
                    noise_pred = cfg_forward(
                        cond_output=noise_pred_with_cond,
                        uncond_output=noise_pred_uncond,
                        cfg_strength=guidance_scale,
                    )
            target_latents = scheduler.step(model_output=noise_pred, timestep=t, sample=target_latents, return_dict=False, omega=omega_scale)[0]
        return target_latents

    @torch.no_grad()
    def process_text2music(
        self,
        audio_duration,
        prompt,
        lyrics,
        input_params_json,
        selected_checkpoint,
        scheduler_type,
        cfg_type,
        infer_step,
        guidance_scale,
        omega_scale,
        manual_seeds,
    ):
        """Gradio callback: generate audio for (prompt, lyrics).

        Returns [audio_path_0, audio_path_1, params_dict], where params_dict
        echoes every knob (plus the actual seeds) for reproducibility.
        """
        # 1 check if need to reload model
        if selected_checkpoint is not None and self.checkpoint_path != selected_checkpoint:
            self.reload_model(selected_checkpoint)
        # Two samples per request so the UI can offer a choice.
        batch_size = 2
        # 2 set seed
        random_generators, actual_seeds = self.set_seeds(batch_size, manual_seeds)
        # Latent layout: 8 x 16 x T//8
        # 4 prompt
        texts = [prompt]
        encoder_text_hidden_states, text_attention_mask = self.model.lyric_processor.get_text_embeddings(texts, self.device)
        encoder_text_hidden_states = encoder_text_hidden_states.repeat(batch_size, 1, 1)
        text_attention_mask = text_attention_mask.repeat(batch_size, 1)
        # Zero speaker embedding: no speaker conditioning in this demo.
        speaker_embeds = torch.zeros(batch_size, 512).to(self.device).to(self.dtype)
        # 6 lyric
        lyric_token_idx = torch.tensor([0]).repeat(batch_size, 1).to(self.device).long()
        lyric_mask = torch.tensor([0]).repeat(batch_size, 1).to(self.device).long()
        if len(lyrics) > 0:
            lyric_token_idx = self.tokenize_lyrics(lyrics, debug=True)
            lyric_mask = [1] * len(lyric_token_idx)
            lyric_token_idx = torch.tensor(lyric_token_idx).unsqueeze(0).to(self.device).repeat(batch_size, 1)
            lyric_mask = torch.tensor(lyric_mask).unsqueeze(0).to(self.device).repeat(batch_size, 1)
        # Non-positive duration = "surprise me".
        if audio_duration <= 0:
            audio_duration = random.uniform(30.0, 300.0)
            logger.info(f"random audio duration: {audio_duration}")
        # 7. encode
        target_latents = self.text2music_diffusion_process(
            audio_duration,
            encoder_text_hidden_states=encoder_text_hidden_states,
            text_attention_mask=text_attention_mask,
            speaker_embds=speaker_embeds,
            lyric_token_ids=lyric_token_idx,
            lyric_mask=lyric_mask,
            guidance_scale=guidance_scale,
            omega_scale=omega_scale,
            infer_steps=infer_step,
            random_generators=random_generators,
            scheduler_type=scheduler_type,
            cfg_type=cfg_type,
        )
        # 8 latents2audio
        output_paths = self.latents2audio(latents=target_latents, target_wav_duration_second=audio_duration)
        # Echo all parameters back to the UI.
        if input_params_json is None:
            input_params_json = {}
        input_params_json["prompt"] = prompt
        input_params_json["lyrics"] = lyrics
        input_params_json["infer_steps"] = infer_step
        input_params_json["guidance_scale"] = guidance_scale
        input_params_json["manual_seeds"] = manual_seeds
        input_params_json["actual_seeds"] = actual_seeds
        input_params_json["checkpoint_path"] = self.checkpoint_path
        input_params_json["omega_scale"] = omega_scale
        input_params_json["scheduler_type"] = scheduler_type
        input_params_json["cfg_type"] = cfg_type
        input_params_json["audio_duration"] = audio_duration
        logger.info(json.dumps(input_params_json, indent=4, ensure_ascii=False))
        return output_paths + [input_params_json]
def main(args):
    """Build the Gradio demo around an InferDemo instance and serve it."""
    infer_demo = InferDemo(args)
    demo = create_main_demo_ui(
        checkpoint_path=args.checkpoint_path,
        text2music_process_func=infer_demo.process_text2music,
    )
    # Bind on all interfaces; auth and public sharing come from the CLI flags.
    demo.launch(
        server_name="0.0.0.0",
        server_port=args.port,
        auth=same_auth,
        share=args.share,
    )


if __name__ == "__main__":
    main(args)

View File

@@ -0,0 +1,866 @@
"""
This file bundles language identification functions.
Modifications (fork): Copyright (c) 2021, Adrien Barbaresi.
Original code: Copyright (c) 2011 Marco Lui <saffsd@gmail.com>.
Based on research by Marco Lui and Tim Baldwin.
See LICENSE file for more info.
https://github.com/adbar/py3langid
Projects:
https://github.com/juntaosun/LangSegment
"""
import os
import re
import sys
import numpy as np
from collections import Counter
from collections import defaultdict
# import langid
# import py3langid as langid
# pip install py3langid==0.2.2
# Enable normalized language-probability scores: langid disables probability
# normalization by default, so LangSegment.__init__ instantiates its own
# LanguageIdentifier with norm_probs=True to get scores in the 0-1 range.
from py3langid.langid import LanguageIdentifier, MODEL_FILE
# Digital (number-to-Chinese-reading) helper; import works both when this file
# is used as a package module and when it is run as a flat script.
try:
    from .utils.num import num2str
except ImportError:
    try:
        from utils.num import num2str
    except ImportError as e:
        raise e
# -----------------------------------
# 更新日志:新版本分词更加精准。
# Changelog: The new version of the word segmentation is more accurate.
# チェンジログ:新しいバージョンの単語セグメンテーションはより正確です。
# Changelog: 분할이라는 단어의 새로운 버전이 더 정확합니다.
# -----------------------------------
# Word segmentation function:
# automatically identify and split the words (Chinese/English/Japanese/Korean) in the article or sentence according to different languages,
# making it more suitable for TTS processing.
# This code is designed for front-end text multi-lingual mixed annotation distinction, multi-language mixed training and inference of various TTS projects.
# This processing result is mainly for (Chinese = zh, Japanese = ja, English = en, Korean = ko), and can actually support up to 97 different language mixing processing.
#===========================================================================================================
#分かち書き機能:文章や文章の中の例えば(中国語/英語/日本語/韓国語を、異なる言語で自動的に認識して分割し、TTS処理により適したものにします。
#このコードは、さまざまなTTSプロジェクトのフロントエンドテキストの多言語混合注釈区別、多言語混合トレーニング、および推論のために特別に作成されています。
#===========================================================================================================
#(1)自動分詞:「韓国語では何を読むのですかあなたの体育の先生は誰ですか?今回の発表会では、iPhone 15シリーズの4機種が登場しました」
#2手动分词:“あなたの名前は<ja>佐々木ですか?<ja>ですか?”
#この処理結果は主に(中国語=ja、日本語=ja、英語=en、韓国語=koを対象としており、実際には最大97の異なる言語の混合処理をサポートできます。
#===========================================================================================================
#===========================================================================================================
# 단어 분할 기능: 기사 또는 문장에서 단어(중국어/영어/일본어/한국어)를 다른 언어에 따라 자동으로 식별하고 분할하여 TTS 처리에 더 적합합니다.
# 이 코드는 프런트 엔드 텍스트 다국어 혼합 주석 분화, 다국어 혼합 교육 및 다양한 TTS 프로젝트의 추론을 위해 설계되었습니다.
#===========================================================================================================
# (1) 자동 단어 분할: "한국어로 무엇을 읽습니까? 스포츠 씨? 이 컨퍼런스는 4개의 iPhone 15 시리즈 모델을 제공합니다."
# (2) 수동 참여: "이름이 <ja>Saki입니까? <ja>?"
# 이 처리 결과는 주로 (중국어 = zh, 일본어 = ja, 영어 = en, 한국어 = ko)를 위한 것이며 실제로 혼합 처리를 위해 최대 97개의 언어를 지원합니다.
#===========================================================================================================
# ===========================================================================================================
# 分词功能:将文章或句子里的例如(中/英/日/韩按不同语言自动识别并拆分让它更适合TTS处理。
# 本代码专为各种 TTS 项目的前端文本多语种混合标注区分,多语言混合训练和推理而编写。
# ===========================================================================================================
# 1自动分词“韩语中的오빠读什么呢あなたの体育の先生は誰ですか? 此次发布会带来了四款iPhone 15系列机型”
# 2手动分词“你的名字叫<ja>佐々木?<ja>吗?”
# 本处理结果主要针对(中文=zh , 日文=ja , 英文=en , 韩语=ko, 实际上可支持多达 97 种不同的语言混合处理。
# ===========================================================================================================
# 手动分词标签规范:<语言标签>文本内容</语言标签>
# 수동 단어 분할 태그 사양: <언어 태그> 텍스트 내용</언어 태그>
# Manual word segmentation tag specification: <language tags> text content </language tags>
# 手動分詞タグ仕様:<言語タグ>テキスト内容</言語タグ>
# ===========================================================================================================
# For manual word segmentation, labels need to appear in pairs, such as:
# 如需手动分词,标签需要成对出现,例如:“<ja>佐々木<ja>” 或者 “<ja>佐々木</ja>”
# 错误示范:“你的名字叫<ja>佐々木。” 此句子中出现的单个<ja>标签将被忽略,不会处理。
# Error demonstration: "Your name is <ja>佐々木。" Single <ja> tags that appear in this sentence will be ignored and will not be processed.
# ===========================================================================================================
# ===========================================================================================================
# 语音合成标记语言 SSML , 这里只支持它的标签(非 XMLSpeech Synthesis Markup Language SSML, only its tags are supported here (not XML)
# 想支持更多的 SSML 标签?欢迎 PR Want to support more SSML tags? PRs are welcome!
# 说明:除了中文以外,它也可改造成支持多语种 SSML ,不仅仅是中文。
# Note: In addition to Chinese, it can also be modified to support multi-language SSML, not just Chinese.
# ===========================================================================================================
# 中文实现Chinese implementation:
# 【SSML】<number>=中文大写数字读法(单字)
# 【SSML】<telephone>=数字转成中文电话号码大写汉字(单字)
# 【SSML】<currency>=按金额发音。
# 【SSML】<date>=按日期发音。支持 2024年08月24, 2024/8/24, 2024-08, 08-24, 24 等输入。
# ===========================================================================================================
class LangSSML:
    """Minimal SSML-style Chinese readers for numbers, telephone, currency, date.

    NOTE(review): several Chinese string literals in this block arrived empty
    in the source (apparently lost during extraction/transcoding); they have
    been restored to the standard readings implied by the surviving regexes
    (e.g. ``([年|月|日])+`` and ``([点|分|秒])+``). Confirm against upstream
    LangSegment.
    """

    def __init__(self):
        # Digit -> Chinese numeral, single-character reading (restored; the
        # source had empty strings, which made to_chinese_number drop digits).
        self._zh_numerals_number = {
            '0': '零',
            '1': '一',
            '2': '二',
            '3': '三',
            '4': '四',
            '5': '五',
            '6': '六',
            '7': '七',
            '8': '八',
            '9': '九'
        }

    # Standardize 2024/8/24, 2024-08, 08-24, 24 etc. to a spoken Chinese date.
    def _format_chinese_data(self, date_str: str):
        """Convert a date/time string to its Chinese reading ("" for empty input)."""
        if date_str is None or date_str.strip() == "":
            return ""
        # Unify separators (and the year/month markers) to "-".
        date_str = re.sub(r"[\/\._|年|月]", "-", date_str)
        # Drop the day marker; the reading re-adds 日 below.
        # NOTE(review): restored from an empty pattern in the source.
        date_str = re.sub(r"日", r"", date_str)
        date_arrs = date_str.split(' ')
        if len(date_arrs) == 1 and ":" in date_arrs[0]:
            # Time-only input such as "12:30".
            time_str = date_arrs[0]
            date_arrs = []
        else:
            time_str = date_arrs[1] if len(date_arrs) >= 2 else ""

        def nonZero(num, cn, func=None):
            # Read *num* through *func* and append the unit *cn*, unless empty/zero.
            if func is not None:
                num = func(num)
            return f"{num}{cn}" if num is not None and num != "" and num != "0" else ""

        f_number = self.to_chinese_number
        f_currency = self.to_chinese_currency
        # year, month, day
        year_month_day = ""
        if len(date_arrs) > 0:
            year, month, day = "", "", ""
            parts = date_arrs[0].split('-')
            if len(parts) == 3:  # YYYY-MM-DD
                year, month, day = parts
            elif len(parts) == 2:  # MM-DD or YYYY-MM
                if len(parts[0]) == 4:  # year-month
                    year, month = parts
                else:
                    month, day = parts  # month-day
            elif len(parts[0]) > 0:  # lone year or day
                if len(parts[0]) == 4:
                    year = parts[0]
                else:
                    day = parts[0]
            year, month, day = nonZero(year, "年", f_number), nonZero(month, "月", f_currency), nonZero(day, "日", f_currency)
            year_month_day = re.sub(r"([年|月|日])+", r"\1", f"{year}{month}{day}")
        # hours, minutes, seconds
        time_str = re.sub(r"[\/\.\-_]", ":", time_str)
        time_arrs = time_str.split(":")
        hours, minutes, seconds = "", "", ""
        if len(time_arrs) == 3:  # H:M:S
            hours, minutes, seconds = time_arrs
        elif len(time_arrs) == 2:  # H:M
            hours, minutes = time_arrs
        elif len(time_arrs[0]) > 0:
            hours = f'{time_arrs[0]}点'  # bare hour
        if len(time_arrs) > 1:
            hours, minutes, seconds = nonZero(hours, "点", f_currency), nonZero(minutes, "分", f_currency), nonZero(seconds, "秒", f_currency)
        hours_minutes_seconds = re.sub(r"([点|分|秒])+", r"\1", f"{hours}{minutes}{seconds}")
        output_date = f"{year_month_day}{hours_minutes_seconds}"
        return output_date

    # 【SSML】number: digit-by-digit Chinese reading (single characters).
    def to_chinese_number(self, num: str):
        """Read each digit as its Chinese numeral; '.' becomes 点."""
        pattern = r'(\d+)'
        zh_numerals = self._zh_numerals_number
        arrs = re.split(pattern, num)
        output = ""
        for item in arrs:
            if re.match(pattern, item):
                output += ''.join(zh_numerals[digit] if digit in zh_numerals else "" for digit in str(item))
            else:
                output += item
        output = output.replace(".", "点")
        return output

    # 【SSML】telephone: phone-number reading (1 is read 幺/yāo).
    def to_chinese_telephone(self, num: str):
        """Digit-by-digit telephone reading; strips a leading +86 country code."""
        output = self.to_chinese_number(num.replace("+86", ""))  # zh +86
        output = output.replace("一", "幺")
        return output

    # 【SSML】currency: amount-style reading via num2str (from GPT_SoVITS num.py).
    def to_chinese_currency(self, num: str):
        """Read numeric runs as amounts through num2str; '.' becomes 点."""
        pattern = r'(\d+)'
        arrs = re.split(pattern, num)
        output = ""
        for item in arrs:
            if re.match(pattern, item):
                output += num2str(item)
            else:
                output += item
        output = output.replace(".", "点")
        return output

    # 【SSML】date: supports 2024年08月24, 2024/8/24, 2024-08, 08-24, 24 ...
    def to_chinese_date(self, num: str):
        """Date reading; thin wrapper over _format_chinese_data."""
        chinese_date = self._format_chinese_data(num)
        return chinese_date
class LangSegment:
    """Multi-language text segmenter built on py3langid.

    Splits mixed-language text (primarily zh/ja/en/ko; up to ~97 languages)
    into per-language runs for TTS front ends. Supports manual paired tags
    such as ``<ja>佐々木</ja>`` and a configurable language filter list.
    """

    def __init__(self):
        # Language identifier with normalized (0-1) probability scores;
        # langid leaves normalization off by default.
        self.langid = LanguageIdentifier.from_pickled_model(MODEL_FILE, norm_probs=True)
        # Per-run caches, reset by _clears().
        # NOTE(review): _text_waits is first assigned in _clears()/_addwords(),
        # not here — accessing it before a run is prepared would raise.
        self._text_cache = None
        self._text_lasts = None
        self._text_langs = None
        self._lang_count = None
        self._lang_eos = None
        # Manual inline language tags, e.g. <zh>你好<zh>, <ja>佐々木</ja>,
        # <en>OK<en>, <ko>오빠</ko> — both <x>…<x> and <x>…</x> forms match.
        self.SYMBOLS_PATTERN = r'(<([a-zA-Z|-]*)>(.*?)<\/*[a-zA-Z|-]*>)'
        # Language filter group: languages not listed are discarded from the
        # output. Earlier entries have higher priority. (ISO 639-1 codes.)
        self.DEFAULT_FILTERS = ["zh", "ja", "ko", "en"]
        # User-customizable copy so edits never mutate the default list.
        self.Langfilters = self.DEFAULT_FILTERS[:]
        # Merge adjacent runs of the same language into one segment.
        self.isLangMerge = True
        # Experimental languages (e.g. "fr", "vi") can be enabled via
        # setfilters(["zh", "en", "ja", "ko", "fr", "vi", "ru", "th"]).
        # Preview features toggle automatically; no setting required.
        self.EnablePreview = False
        # zh/ja priority threshold (score range 0-1): when the classifier
        # score falls below this, filter-order priority decides between the
        # CJK characters shared by Chinese and Japanese.
        self.LangPriorityThreshold = 0.89
        # Keep pinyin-in-parentheses runs (recognized as "zh") when True.
        self.keepPinyin = False
        # DEFINITION
        # Internal placeholder tag wrapped around protected spans.
        self.PARSE_TAG = re.compile(r'(⑥\$*\d+[\d]{6,}⑥)')
        self.LangSSML = LangSSML()
def _clears(self):
self._text_cache = None
self._text_lasts = None
self._text_langs = None
self._text_waits = None
self._lang_count = None
self._lang_eos = None
def _is_english_word(self, word):
return bool(re.match(r'^[a-zA-Z]+$', word))
def _is_chinese(self, word):
for char in word:
if '\u4e00' <= char <= '\u9fff':
return True
return False
def _is_japanese_kana(self, word):
pattern = re.compile(r'[\u3040-\u309F\u30A0-\u30FF]+')
matches = pattern.findall(word)
return len(matches) > 0
def _insert_english_uppercase(self, word):
modified_text = re.sub(r'(?<!\b)([A-Z])', r' \1', word)
modified_text = modified_text.strip('-')
return modified_text + " "
def _split_camel_case(self, word):
return re.sub(r'(?<!^)(?=[A-Z])', ' ', word)
def _statistics(self, language, text):
# Language word statistics:
# Chinese characters usually occupy double bytes
if self._lang_count is None or not isinstance(self._lang_count, defaultdict):
self._lang_count = defaultdict(int)
lang_count = self._lang_count
if not "|" in language:
lang_count[language] += int(len(text)*2) if language == "zh" else len(text)
self._lang_count = lang_count
def _clear_text_number(self, text):
if text == "\n":return text,False # Keep Line Breaks
clear_text = re.sub(r'([^\w\s]+)','',re.sub(r'\n+','',text)).strip()
is_number = len(re.sub(re.compile(r'(\d+)'),'',clear_text)) == 0
return clear_text,is_number
    def _saveData(self, words, language: str, text: str, score: float, symbol=None):
        """Append a {lang,text,score,symbol} record to *words*, merging with the
        previous record when they share a language (or the text is bare
        punctuation/digits), then applying the configured language filters.

        Returns the record that now holds *text* — either the merged previous
        record or the freshly built one (which may or may not have been
        appended, depending on the filters).
        """
        # Pre-detection
        clear_text, is_number = self._clear_text_number(text)
        # Merge the same language and save the results
        preData = words[-1] if len(words) > 0 else None
        if symbol is not None:
            # Explicit user tag (<ja>…</ja>): never merge across it.
            pass
        elif preData is not None and preData["symbol"] is None:
            # Bare punctuation or digits inherit the previous segment's language.
            if len(clear_text) == 0:
                language = preData["lang"]
            elif is_number == True:
                language = preData["lang"]
            _, pre_is_number = self._clear_text_number(preData["text"])
            if (preData["lang"] == language):
                # Same language: extend the previous record in place.
                self._statistics(preData["lang"], text)
                text = preData["text"] + text
                preData["text"] = text
                return preData
            elif pre_is_number == True:
                # Previous record was bare digits: fold it into this one.
                text = f'{preData["text"]}{text}'
                words.pop()
        elif is_number == True:
            # Digits with no usable context default to the top-priority filter.
            priority_language = self._get_filters_string()[:2]
            if priority_language in "ja-zh-en-ko-fr-vi":
                language = priority_language
        data = {"lang": language, "text": text, "score": score, "symbol": symbol}
        filters = self.Langfilters
        # Keep the record when filtering is effectively disabled ("*", "all"…),
        # the language is uncertain ("?"), or the language passes the filters.
        if filters is None or len(filters) == 0 or "?" in language or \
            language in filters or language in filters[0] or \
            filters[0] == "*" or filters[0] in "alls-mixs-autos":
            words.append(data)
            self._statistics(data["lang"], data["text"])
        return data
    def _addwords(self, words, language, text, score, symbol=None):
        """Route one detected segment into *words*, resolving any pending
        ambiguous ("zh|ja"-style) segment first.

        Returns True when the text was empty and ignored, otherwise False.
        """
        if text == "\n":
            pass  # Keep Line Breaks
        elif text is None or len(text.strip()) == 0:
            return True
        if language is None:
            language = ""
        language = language.lower()
        if language == 'en':
            text = self._insert_english_uppercase(text)
        # text = re.sub(r'[()]', ',' , text) # Keep it.
        text_waits = self._text_waits
        ispre_waits = len(text_waits) > 0
        preResult = text_waits.pop() if ispre_waits else None
        if preResult is None:
            preResult = words[-1] if len(words) > 0 else None
        if preResult and ("|" in preResult["lang"]):
            # Resolve a pending ambiguous segment using the current language.
            pre_lang = preResult["lang"]
            if language in pre_lang:
                preResult["lang"] = language = language.split("|")[0]
            else:
                preResult["lang"] = pre_lang.split("|")[0]
            if ispre_waits:
                preResult = self._saveData(words, preResult["lang"], preResult["text"], preResult["score"], preResult["symbol"])
        pre_lang = preResult["lang"] if preResult else None
        # NOTE(review): `"" in language` is always True, so `not "" in language`
        # is always False and this branch is dead as written — a character was
        # likely lost from the string literal. Preserved byte-for-byte.
        if ("|" in language) and (pre_lang and not pre_lang in language and not "" in language):
            language = language.split("|")[0]
        if "|" in language:
            # Still ambiguous: defer until the next segment provides context.
            self._text_waits.append({"lang": language, "text": text, "score": score, "symbol": symbol})
        else:
            self._saveData(words, language, text, score, symbol)
        return False
def _get_prev_data(self, words):
data = words[-1] if words and len(words) > 0 else None
if data:return (data["lang"] , data["text"])
return (None,"")
def _match_ending(self, input , index):
if input is None or len(input) == 0:return False,None
input = re.sub(r'\s+', '', input)
if len(input) == 0 or abs(index) > len(input):return False,None
ending_pattern = re.compile(r'([「」“”‘’"\'::。.!?])')
return ending_pattern.match(input[index]),input[index]
def _cleans_text(self, cleans_text):
cleans_text = re.sub(r'(.*?)([^\w]+)', r'\1 ', cleans_text)
cleans_text = re.sub(r'(.)\1+', r'\1', cleans_text)
return cleans_text.strip()
def _mean_processing(self, text: str):
    """Majority-vote a language over the words of `text` (after camel-case
    splitting). Returns (lang, 1.0), or (None, 0.0) when nothing classifiable."""
    if text is None or text.strip() == "":
        return None, 0.0
    votes = []
    for token in self._split_camel_case(text).split(" "):
        if len(token.strip()) <= 3:
            continue  # too short for a reliable per-word guess
        lang, _score = self.langid.classify(token)
        votes.append(lang)
    if not votes:
        return None, 0.0
    return Counter(votes).most_common(1)[0][0], 1.0
def _lang_classify(self, cleans_text):
    """Classify text with the langid backend; return (language, score) with
    the score rounded to 3 decimals.

    Converts numpy scalar scores (HuggingFace backend returns np.float32)
    to plain Python floats first."""
    language, score = self.langid.classify(cleans_text)
    if score is not None and isinstance(score, np.generic) and hasattr(score, "item"):
        score = score.item()
    return language, round(score, 3)
def _get_filters_string(self):
    """Join the active language filters into a lowercase 'a-b-c' string ('' if unset)."""
    if self.Langfilters is None:
        return ""
    return "-".join(self.Langfilters).lower().strip()
def _parse_language(self, words, segment):
    """Split `segment` on punctuation and classify each piece, mainly to
    disambiguate Chinese vs. Japanese; each decided piece is appended to
    `words` via _addwords. Pieces that are too short or punctuation-only
    are merged into the next piece before classification."""
    LANG_JA = "ja"
    LANG_ZH = "zh"
    LANG_ZH_JA = f'{LANG_ZH}|{LANG_JA}'   # ambiguous, zh preferred
    LANG_JA_ZH = f'{LANG_JA}|{LANG_ZH}'   # ambiguous, ja preferred
    language = LANG_ZH
    regex_pattern = re.compile(r'([^\w\s]+)')
    lines = regex_pattern.split(segment)
    lines_max = len(lines)
    LANG_EOS =self._lang_eos
    for index, text in enumerate(lines):
        if len(text) == 0:continue
        EOS = index >= (lines_max - 1)
        nextId = index + 1
        nextText = lines[nextId] if not EOS else ""
        # punctuation-only checks for the current and the next piece
        nextPunc = len(re.sub(regex_pattern,'',re.sub(r'\n+','',nextText)).strip()) == 0
        textPunc = len(re.sub(regex_pattern,'',re.sub(r'\n+','',text)).strip()) == 0
        if not EOS and (textPunc == True or ( len(nextText.strip()) >= 0 and nextPunc == True)):
            lines[nextId] = f'{text}{nextText}'   # merge forward and postpone the decision
            continue
        number_tags = re.compile(r'(⑥\d{6,}⑥)')  # placeholders injected by _pattern_symbols
        cleans_text = re.sub(number_tags, '' ,text)
        cleans_text = re.sub(r'\d+', '' ,cleans_text)
        cleans_text = self._cleans_text(cleans_text)
        # fix: langid's recognition of short sentences is inaccurate, so splice them longer.
        if not EOS and len(cleans_text) <= 2:
            lines[nextId] = f'{text}{nextText}'
            continue
        language,score = self._lang_classify(cleans_text)
        prev_language , prev_text = self._get_prev_data(words)
        # All CJK ideographs but classified otherwise -> force zh with full confidence.
        if language != LANG_ZH and all('\u4e00' <= c <= '\u9fff' for c in re.sub(r'\s','',cleans_text)):language,score = LANG_ZH,1
        if len(cleans_text) <= 5 and self._is_chinese(cleans_text):
            filters_string = self._get_filters_string()
            # Low-confidence short Chinese-script text: fall back to the filter
            # order — whichever of ja/zh the user listed first wins.
            if score < self.LangPriorityThreshold and len(filters_string) > 0:
                index_ja , index_zh = filters_string.find(LANG_JA) , filters_string.find(LANG_ZH)
                if index_ja != -1 and index_ja < index_zh:language = LANG_JA
                elif index_zh != -1 and index_zh < index_ja:language = LANG_ZH
            if self._is_japanese_kana(cleans_text):language = LANG_JA   # kana is unambiguously Japanese
            elif len(cleans_text) > 2 and score > 0.90:pass             # trust a confident classification
            elif EOS and LANG_EOS:language = LANG_ZH if len(cleans_text) <= 1 else language
            else:
                # Still ambiguous: tag "zh|ja"/"ja|zh" and let later context decide.
                LANG_UNKNOWN = LANG_ZH_JA if language == LANG_ZH or (len(cleans_text) <=2 and prev_language == LANG_ZH) else LANG_JA_ZH
                match_end,match_char = self._match_ending(text, -1)
                referen = prev_language in LANG_UNKNOWN or LANG_UNKNOWN in prev_language if prev_language else False
                if match_char in "。.": language = prev_language if referen and len(words) > 0 else language
                else:language = f"{LANG_UNKNOWN}|…"   # "…" marks an undecided tail
        text,*_ = re.subn(number_tags , self._restore_number , text )
        self._addwords(words,language,text,score)
# ----------------------------------------------------------
# [SSML] Chinese number processing (SSML support)
# The default here is Chinese, used to process SSML Chinese tags; any
# language could be supported the same way, for example:
# Chinese telephone number: <telephone>1234567</telephone>
# Chinese number reading:   <number>1234567</number>
def _process_symbol_SSML(self, words,data):
    """Expand one SSML tag match (telephone/number/currency/date) into
    spoken-Chinese text via LangSSML and append it through _addwords."""
    tag , match = data
    language = SSML = match[1]
    text = match[2]
    score = 1.0
    if SSML == "telephone":
        # Chinese - telephone-number reading
        language = "zh"
        text = self.LangSSML.to_chinese_telephone(text)
    elif SSML == "number":
        # Chinese - digit/number reading
        language = "zh"
        text = self.LangSSML.to_chinese_number(text)
    elif SSML == "currency":
        # Chinese - currency-amount reading
        language = "zh"
        text = self.LangSSML.to_chinese_currency(text)
    elif SSML == "date":
        # Chinese - date reading (original comment said "currency" — copy-paste slip)
        language = "zh"
        text = self.LangSSML.to_chinese_date(text)
    self._addwords(words,language,text,score,SSML)
# ----------------------------------------------------------
def _restore_number(self, matche):
    """re.sub callback: swap a ⑥…⑥ placeholder back for its cached original match."""
    key = matche.group(0)
    cached = self._text_cache.get(key)
    if cached is not None:
        _process, (_tag, original) = cached
        return original
    return key
def _pattern_symbols(self, item , text):
if text is None:return text
tag , pattern , process = item
matches = pattern.findall(text)
if len(matches) == 1 and "".join(matches[0]) == text:
return text
for i , match in enumerate(matches):
key = f"{tag}{i:06d}"
text = re.sub(pattern , key , text , count=1)
self._text_cache[key] = (process , (tag , match))
return text
def _process_symbol(self, words, data):
    """Handle a <lang>text</lang> tag: when the tag names an enabled filter
    language, add the text directly (symbol=True); otherwise treat it as
    an SSML processing tag."""
    _tag, match = data
    language, text = match[1], match[2]
    if language in self._get_filters_string():
        self._addwords(words, language, text, 1.0, True)
    else:
        self._process_symbol_SSML(words, data)
def _process_english(self, words,data):
    """Handle a run of Latin-script text. The default path tags it as English
    with full confidence; with EnablePreview, each sentence is re-classified
    so other Latin-script filter languages (fr/vi/...) can be recognized."""
    tag , match = data
    text = match[0]
    filters = self._get_filters_string()
    priority_language = filters[:2]   # first language in the filter string
    # Preview feature, other language segmentation processing
    enablePreview = self.EnablePreview
    if enablePreview == True:
        # Experimental: Other language support
        regex_pattern = re.compile(r'(.*?[。.?!]+[\n]{,1})')
        lines = regex_pattern.split(text)
        for index , text in enumerate(lines):
            if len(text.strip()) == 0:continue
            cleans_text = self._cleans_text(text)
            language,score = self._lang_classify(cleans_text)
            if language not in filters:
                # Not an enabled language: fall back to per-word majority vote.
                language,score = self._mean_processing(cleans_text)
            if language is None or score <= 0.0:continue
            elif language in filters:pass # keep the classified language
            elif score >= 0.95:continue # High score, but not in the filter, excluded.
            elif score <= 0.15 and filters[:2] == "fr":language = priority_language
            else:language = "en"
            self._addwords(words,language,text,score)
    else:
        # Default is English
        language, score = "en", 1.0
        self._addwords(words,language,text,score)
def _process_Russian(self, words, data):
    """Tag a run of Cyrillic text as Russian with full confidence."""
    _tag, match = data
    self._addwords(words, "ru", match[0], 1.0)
def _process_Thai(self, words, data):
    """Tag a run of Thai-script text as Thai with full confidence."""
    _tag, match = data
    self._addwords(words, "th", match[0], 1.0)
def _process_korean(self, words, data):
    """Tag a run of Hangul text as Korean with full confidence."""
    _tag, match = data
    self._addwords(words, "ko", match[0], 1.0)
def _process_quotes(self, words, data):
    """Handle a quoted span: nested placeholder tags are re-dispatched;
    short quotes go through full per-piece language parsing; longer quotes
    are classified as a whole."""
    _tag, match = data
    text = "".join(match)
    if self.PARSE_TAG.findall(text):
        self._process_tags(words, text, False)
        return
    cleaned = self._cleans_text(match[1])
    if len(cleaned) <= 5:
        self._parse_language(words, text)
    else:
        language, score = self._lang_classify(cleaned)
        self._addwords(words, language, text, score)
def _process_pinyin(self, words, data):
    """Tag a Chinese pinyin annotation (kept verbatim) as zh with full confidence."""
    _tag, match = data
    self._addwords(words, "zh", match, 1.0)
def _process_number(self, words, data):  # "$0" process only
    """Testing-only channel. Numbers alone cannot identify a language
    (they are universal), so this inherits the language of the first
    collected word (default "zh") with zero confidence."""
    _tag, match = data
    language = words[0]["lang"] if words else "zh"
    self._addwords(words, language, match, 0.0)
def _process_tags(self, words, text, root_tag):
    """Split `text` on placeholder tags; dispatch each cached tag to its
    registered processor and send plain segments to _parse_language.
    `root_tag` marks the top-level call, where the last segment toggles
    self._lang_eos (end-of-segment flag used by _parse_language)."""
    text_cache = self._text_cache
    segments = re.split(self.PARSE_TAG, text)
    segments_len = len(segments) - 1
    for index , text in enumerate(segments):
        if root_tag:self._lang_eos = index >= segments_len
        if self.PARSE_TAG.match(text):
            # Placeholder: look up its (processor, data) pair and run it.
            process , data = text_cache[text]
            if process:process(words , data)
        else:
            self._parse_language(words , text)
    return words
def _merge_results(self, words):
new_word = []
for index , cur_data in enumerate(words):
if "symbol" in cur_data:del cur_data["symbol"]
if index == 0:new_word.append(cur_data)
else:
pre_data = new_word[-1]
if cur_data["lang"] == pre_data["lang"]:
pre_data["text"] = f'{pre_data["text"]}{cur_data["text"]}'
else:new_word.append(cur_data)
return new_word
def _parse_symbols(self, text):
    """Top-level segmentation pipeline: protect script-specific runs
    (symbol tags, Korean, Thai, Russian, numbers, Latin, quotes) behind
    placeholder tags, parse each input line, then optionally merge adjacent
    same-language results and normalize the language statistics."""
    TAG_NUM = "00" # "00" => default channels , "$0" => testing channel
    TAG_S1,TAG_S2,TAG_P1,TAG_P2,TAG_EN,TAG_KO,TAG_RU,TAG_TH = "$1" ,"$2" ,"$3" ,"$4" ,"$5" ,"$6" ,"$7","$8"
    TAG_BASE = re.compile(fr'(([【《((“‘"\']*[LANGUAGE]+[\W\s]*)+)')
    # Get custom language filter
    filters = self.Langfilters
    filters = filters if filters is not None else ""
    # =======================================================================================================
    # Experimental: other-language support. If pronunciation characters for a
    # language are missing below, contributions completing them are welcome.
    # -------------------------------------------------------------------------------------------------------
    # Preview feature, other language support
    enablePreview = self.EnablePreview
    if "fr" in filters or \
        "vi" in filters:enablePreview = True
    self.EnablePreview = enablePreview
    # Experimental French character support
    RE_FR = "" if not enablePreview else "àáâãäåæçèéêëìíîïðñòóôõöùúûüýþÿ"
    # Experimental Vietnamese character support
    RE_VI = "" if not enablePreview else "đơưăáàảãạắằẳẵặấầẩẫậéèẻẽẹếềểễệíìỉĩịóòỏõọốồổỗộớờởỡợúùủũụứừửữựôâêơưỷỹ"
    # -------------------------------------------------------------------------------------------------------
    # Basic options: (tag, pattern, handler) applied in order by _pattern_symbols.
    process_list = [
        ( TAG_S1 , re.compile(self.SYMBOLS_PATTERN) , self._process_symbol ), # Symbol Tag
        ( TAG_KO , re.compile(re.sub(r'LANGUAGE',f'\uac00-\ud7a3',TAG_BASE.pattern)) , self._process_korean ), # Korean words
        ( TAG_TH , re.compile(re.sub(r'LANGUAGE',f'\u0E00-\u0E7F',TAG_BASE.pattern)) , self._process_Thai ), # Thai words support.
        ( TAG_RU , re.compile(re.sub(r'LANGUAGE',f'А-Яа-яЁё',TAG_BASE.pattern)) , self._process_Russian ), # Russian words support.
        ( TAG_NUM , re.compile(r'(\W*\d+\W+\d*\W*\d*)') , self._process_number ), # Number words, Universal in all languages, Ignore it.
        ( TAG_EN , re.compile(re.sub(r'LANGUAGE',f'a-zA-Z{RE_FR}{RE_VI}',TAG_BASE.pattern)) , self._process_english ), # English words + Other language support.
        ( TAG_P1 , re.compile(r'(["\'])(.*?)(\1)') , self._process_quotes ), # Regular quotes
        ( TAG_P2 , re.compile(r'([\n]*[【《((“‘])([^【《((“‘’”))》】]{3,})([’”))》】][\W\s]*[\n]{,1})') , self._process_quotes ), # Special quotes, There are left and right.
    ]
    # Extended options: Default False
    if self.keepPinyin == True:process_list.insert(1 ,
        ( TAG_S2 , re.compile(r'([\({](?:\s*\w*\d\w*\s*)+[}\)])') , self._process_pinyin ), # Chinese Pinyin Tag.
    )
    # -------------------------------------------------------------------------------------------------------
    words = []
    lines = re.findall(r'.*\n*', re.sub(self.PARSE_TAG, '' ,text))
    for index , text in enumerate(lines):
        if len(text.strip()) == 0:continue
        self._lang_eos = False
        self._text_cache = {}
        # Replace every special run with a placeholder tag, then parse the line.
        for item in process_list:
            text = self._pattern_symbols(item , text)
        cur_word = self._process_tags([] , text , True)
        if len(cur_word) == 0:continue
        cur_data = cur_word[0] if len(cur_word) > 0 else None
        pre_data = words[-1] if len(words) > 0 else None
        # Glue a non-symbol head onto a preceding symbol-tagged tail of the
        # same language so tagged text reads continuously.
        if cur_data and pre_data and cur_data["lang"] == pre_data["lang"] \
            and cur_data["symbol"] == False and pre_data["symbol"] :
            cur_data["text"] = f'{pre_data["text"]}{cur_data["text"]}'
            words.pop()
        words += cur_word
    if self.isLangMerge == True:words = self._merge_results(words)
    # Normalize the running statistics into a count-sorted list of tuples.
    lang_count = self._lang_count
    if lang_count and len(lang_count) > 0:
        lang_count = dict(sorted(lang_count.items(), key=lambda x: x[1], reverse=True))
        lang_count = list(lang_count.items())
        self._lang_count = lang_count
    return words
def setfilters(self, filters):
    """Set the language filter list; the cache is cleared when it changes."""
    if self.Langfilters != filters:
        self._clears()
        self.Langfilters = filters
def getfilters(self):
    """Return the active language filter list (None when unset)."""
    return self.Langfilters
def setPriorityThreshold(self, threshold:float):
    """Set the langid confidence threshold below which the filter order
    decides ambiguous zh/ja classifications (see _parse_language)."""
    self.LangPriorityThreshold = threshold
def getPriorityThreshold(self):
    """Return the current language-priority confidence threshold."""
    return self.LangPriorityThreshold
def getCounts(self):
    """Return [(lang, weighted_char_count), ...] sorted descending.
    Chinese characters count double. Result is cached in self._lang_count;
    [("zh", 0)] is returned when there is no segmentation result yet."""
    if self._lang_count is not None:
        return self._lang_count
    segments = self._text_langs
    if not segments:
        return [("zh", 0)]
    totals = defaultdict(int)
    for seg in segments:
        weight = 2 if seg['lang'] == "zh" else 1
        totals[seg['lang']] += len(seg['text']) * weight
    ranked = sorted(totals.items(), key=lambda kv: kv[1], reverse=True)
    self._lang_count = ranked
    return ranked
def getTexts(self, text: str):
    """Segment `text` into [{'lang': ..., 'text': ...}, ...].
    Empty input clears all cached state and returns []; repeating the
    previous input returns the cached result without re-parsing."""
    if text is None or not text.strip():
        self._clears()
        return []
    if self._text_lasts == text and self._text_langs is not None:
        return self._text_langs
    # Fresh parse: reset per-run state before running the pipeline.
    self._text_waits = []
    self._lang_count = None
    self._text_lasts = text
    result = self._parse_symbols(text)
    self._text_langs = result
    return result
def classify(self, text: str):
    """Alias for getTexts()."""
    return self.getTexts(text)
def printList(langlist):
    """Pretty-print a segmentation result list, one entry per line."""
    print("\n===================【打印结果】===================")
    if not langlist:
        print("无内容结果,No content result")
        return
    for line in langlist:
        print(line)
def main():
    """Demo entry point: segment a multilingual sample with LangSegment and
    print the per-segment breakdown plus language statistics."""
    # -----------------------------------
    # Changelog: The new version of the word segmentation is more accurate.
    # -----------------------------------
    # Input Example 1: (including Japanese, Chinese)
    # text = "“昨日は雨が降った,音楽、映画。。。”你今天学习日语了吗?春は桜の季節です。语种分词是语音合成必不可少的环节。言語分詞は音声合成に欠かせない環節である!"
    # Input Example 2: (including Japanese, Chinese)
    # text = "欢迎来玩。東京,は日本の首都です。欢迎来玩. 太好了!"
    # Input Example 3: (including Japanese, Chinese)
    # text = "明日、私たちは海辺にバカンスに行きます。你会说日语吗:“中国語、話せますか” 你的日语真好啊!"
    # Input Example 4: (including Japanese, Chinese, Korean, English)
    # text = "你的名字叫<ja>佐々木?<ja>吗?韩语中的안녕 오빠读什么呢?あなたの体育の先生は誰ですか? 此次发布会带来了四款iPhone 15系列机型和三款Apple Watch等一系列新品这次的iPad Air采用了LCD屏幕"
    # Experimental language support: "fr" French, "vi" Vietnamese, "ru" Russian, "th" Thai.
    langsegment = LangSegment()
    langsegment.setfilters(["fr", "vi" , "ja", "zh", "ko", "en" , "ru" , "th"])
    text = """
我喜欢在雨天里听音乐。
I enjoy listening to music on rainy days.
雨の日に音楽を聴くのが好きです。
비 오는 날에 음악을 듣는 것을 즐깁니다。
J'aime écouter de la musique les jours de pluie.
Tôi thích nghe nhạc vào những ngày mưa.
Мне нравится слушать музыку в дождливую погоду.
ฉันชอบฟังเพลงในวันที่ฝนตก
    """
    # Segmentation: only one line of code is needed to plug into a TTS project.
    langlist = langsegment.getTexts(text)
    printList(langlist)
    # Language statistics:
    print("\n===================【语种统计】===================")
    # Get per-language results, sorted descending by content length.
    langCounts = langsegment.getCounts()
    print(langCounts , "\n")
    # Main language of the content: (language, char count incl. punctuation).
    lang , count = langCounts[0]
    print(f"输入内容的主要语言为 = {lang} ,字数 = {count}")
    print("==================================================\n")
    # Word output: lang = language, text = content. Sample output for Example 4:
    # ===================【打印结果】===================
    # {'lang': 'zh', 'text': '你的名字叫'}
    # {'lang': 'ja', 'text': '佐々木?'}
    # {'lang': 'zh', 'text': '吗?韩语中的'}
    # {'lang': 'ko', 'text': '안녕 오빠'}
    # {'lang': 'zh', 'text': '读什么呢?'}
    # {'lang': 'ja', 'text': 'あなたの体育の先生は誰ですか?'}
    # {'lang': 'zh', 'text': ' 此次发布会带来了四款'}
    # {'lang': 'en', 'text': 'i Phone '}
    # {'lang': 'zh', 'text': '15系列机型和三款'}
    # {'lang': 'en', 'text': 'Apple Watch '}
    # {'lang': 'zh', 'text': '等一系列新品,这次的'}
    # {'lang': 'en', 'text': 'i Pad Air '}
    # {'lang': 'zh', 'text': '采用了'}
    # {'lang': 'en', 'text': 'L C D '}
    # {'lang': 'zh', 'text': '屏幕'}
    # ===================【语种统计】===================
    # [('zh', 51), ('ja', 19), ('en', 18), ('ko', 5)]
    # The main language of the input content is = zh, word count = 51
    # ==================================================
# Script entry point.
if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,9 @@
from .LangSegment import LangSegment
# release version of the published package
__version__ = '0.3.5'
# internal development marker
__develop__ = 'dev-0.0.1'

View File

View File

@@ -0,0 +1,327 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Digital processing from GPT_SoVITS num.py thanks
"""
Rules to verbalize numbers into Chinese characters.
https://zh.wikipedia.org/wiki/中文数字#現代中文
"""
import re
from collections import OrderedDict
from typing import List
# Map ASCII digit -> Chinese character.
DIGITS = {str(i): tran for i, tran in enumerate('零一二三四五六七八九')}
# Positional units keyed by power of ten (10 -> 十, 10^2 -> 百, 10^3 -> 千,
# 10^4 -> 万, 10^8 -> 亿). The unit characters were stripped to '' by a bad
# encoding pass; restored per the upstream zh_normalization rules.
UNITS = OrderedDict({
    1: '十',
    2: '百',
    3: '千',
    4: '万',
    8: '亿',
})
# Common Chinese measure words (quantifiers) that may follow a number.
COM_QUANTIFIERS = '(处|台|架|枚|趟|幅|平|方|堵|间|床|株|批|项|例|列|篇|栋|注|亩|封|艘|把|目|套|段|人|所|朵|匹|张|座|回|场|尾|条|个|首|阙|阵|网|炮|顶|丘|棵|只|支|袭|辆|挑|担|颗|壳|窠|曲|墙|群|腔|砣|座|客|贯|扎|捆|刀|令|打|手|罗|坡|山|岭|江|溪|钟|队|单|双|对|出|口|头|脚|板|跳|枝|件|贴|针|线|管|名|位|身|堂|课|本|页|家|户|层|丝|毫|厘|分|钱|两|斤|担|铢|石|钧|锱|忽|(千|毫|微)克|毫|厘|(公)分|分|寸|尺|丈|里|寻|常|铺|程|(千|分|厘|毫|微)米|米|撮|勺|合|升|斗|石|盘|碗|碟|叠|桶|笼|盆|盒|杯|钟|斛|锅|簋|篮|盘|桶|罐|瓶|壶|卮|盏|箩|箱|煲|啖|袋|钵|年|月|日|季|刻|时|周|天|秒|分|小时|旬|纪|岁|世|更|夜|春|夏|秋|冬|代|伏|辈|丸|泡|粒|颗|幢|堆|条|根|支|道|面|片|张|颗|块|元|(亿|千万|百万|万|千|百)|(亿|千万|百万|万|千|百|美|)元|(亿|千万|百万|万|千|百|十|)吨|(亿|千万|百万|万|千|百|)块|角|毛|分)'
# Fraction expressions, e.g. "-1/3".
RE_FRAC = re.compile(r'(-?)(\d+)/(\d+)')
def replace_frac(match) -> str:
    """Verbalize a fraction match in Chinese, e.g. "-1/3" -> "负三分之一".

    Args:
        match (re.Match): match of RE_FRAC.
    Returns:
        str
    """
    sign = match.group(1)
    nominator = match.group(2)
    denominator = match.group(3)
    # "负" (negative) restored — stripped to '' by an encoding-mangling pass.
    sign: str = "负" if sign else ""
    nominator: str = num2str(nominator)
    denominator: str = num2str(denominator)
    result = f"{sign}{denominator}分之{nominator}"
    return result
# Percentage expressions, e.g. "-12.5%".
RE_PERCENTAGE = re.compile(r'(-?)(\d+(\.\d+)?)%')
def replace_percentage(match) -> str:
    """Verbalize a percentage in Chinese, e.g. "-12.5%" -> "负百分之十二点五".

    Args:
        match (re.Match): match of RE_PERCENTAGE.
    Returns:
        str
    """
    sign = match.group(1)
    percent = match.group(2)
    # "负" (negative) restored — stripped to '' by an encoding-mangling pass.
    sign: str = "负" if sign else ""
    percent: str = num2str(percent)
    result = f"{sign}百分之{percent}"
    return result
# Integer expressions.
# Negative integer, e.g. "-10".
RE_INTEGER = re.compile(r'(-)' r'(\d+)')
def replace_negative_num(match) -> str:
    """Verbalize a signed integer in Chinese, e.g. "-10" -> "负十".

    Args:
        match (re.Match): match of RE_INTEGER.
    Returns:
        str
    """
    sign = match.group(1)
    number = match.group(2)
    # "负" (negative) restored — stripped to '' by an encoding-mangling pass.
    sign: str = "负" if sign else ""
    number: str = num2str(number)
    result = f"{sign}{number}"
    return result
# Identifier-style unsigned number, e.g. "00078".
RE_DEFAULT_NUM = re.compile(r'\d{3}\d*')
def replace_default_num(match):
    """Read an identifier-style number digit by digit (telephone-style 幺 for 一).

    Args:
        match (re.Match): match of RE_DEFAULT_NUM.
    Returns:
        str
    """
    return verbalize_digit(match.group(0), alt_one=True)
# Arithmetic (add/subtract/multiply/divide/equals) expressions, including
# superscript exponents and single-letter variables, e.g. "3+4", "x²=9".
RE_ASMD = re.compile(
    r'((-?)((\d+)(\.\d+)?[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|(\.\d+[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|([A-Za-z][⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*))([\+\-\×÷=])((-?)((\d+)(\.\d+)?[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|(\.\d+[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|([A-Za-z][⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*))')
# Operator -> Chinese reading. The readings 加/减/乘/除 were stripped to ''
# by an encoding-mangling pass (only "等于" survived); restored.
asmd_map = {
    '+': '加',
    '-': '减',
    '×': '乘',
    '÷': '除',
    '=': '等于'
}
def replace_asmd(match) -> str:
    """Replace the operator of an arithmetic match with its Chinese reading.

    Args:
        match (re.Match): match of RE_ASMD.
    Returns:
        str: left operand + spoken operator + right operand.
    """
    result = match.group(1) + asmd_map[match.group(8)] + match.group(9)
    return result
# Exponent expressions written with superscript characters.
RE_POWER = re.compile(r'[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]+')
# Superscript -> plain character. Several keys/values were stripped by an
# encoding-mangling pass (only ¹ ² ³ ˣ ʸ survived); restored.
power_map = {
    '⁰': '0',
    '¹': '1',
    '²': '2',
    '³': '3',
    '⁴': '4',
    '⁵': '5',
    '⁶': '6',
    '⁷': '7',
    '⁸': '8',
    '⁹': '9',
    'ˣ': 'x',
    'ʸ': 'y',
    'ⁿ': 'n'
}
def replace_power(match) -> str:
    """Verbalize a superscript run as "的<exponent>次方" (to the power of).

    Args:
        match (re.Match): match of RE_POWER.
    Returns:
        str
    """
    power_num = "".join(power_map[m] for m in match.group(0))
    # "的" restored — stripped to '' by an encoding-mangling pass.
    result = "的" + power_num + "次方"
    return result
# Number expressions.
# Pure decimal, e.g. "3.14" or ".5".
RE_DECIMAL_NUM = re.compile(r'(-?)((\d+)(\.\d+))' r'|(\.(\d+))')
# Positive integer + optional 多/余/几/+ marker + measure word.
RE_POSITIVE_QUANTIFIERS = re.compile(r"(\d+)([多余几\+])?" + COM_QUANTIFIERS)
# Generic signed integer/decimal number.
RE_NUMBER = re.compile(r'(-?)((\d+)(\.\d+)?)' r'|(\.(\d+))')
def replace_positive_quantifier(match) -> str:
    """Verbalize "<number><多/余/几/+><quantifier>", e.g. "3个" -> "三个".

    Args:
        match (re.Match): match of RE_POSITIVE_QUANTIFIERS.
    Returns:
        str
    """
    number = match.group(1)
    match_2 = match.group(2)
    if match_2 == "+":
        # "多" ("more than") restored — stripped to '' by an encoding-mangling pass.
        match_2 = "多"
    match_2: str = match_2 if match_2 else ""
    quantifiers: str = match.group(3)
    number: str = num2str(number)
    result = f"{number}{match_2}{quantifiers}"
    return result
def replace_number(match) -> str:
    """Verbalize a signed integer/decimal match of RE_NUMBER in Chinese.

    Args:
        match (re.Match): match of RE_NUMBER.
    Returns:
        str
    """
    sign = match.group(1)
    number = match.group(2)
    pure_decimal = match.group(5)
    if pure_decimal:
        # ".22"-style input: group(5) holds the whole pure decimal.
        result = num2str(pure_decimal)
    else:
        # "负" (negative) restored — stripped to '' by an encoding-mangling pass.
        sign: str = "负" if sign else ""
        number: str = num2str(number)
        result = f"{sign}{number}"
    return result
# Range expressions, e.g. "1-10" / "5~8".
# match.group(1) and match.group(6) are copies of the RE_NUMBER structure.
RE_RANGE = re.compile(
    r"""
    (?<![\d\+\-\×÷=]) # 使用反向前瞻以确保数字范围之前没有其他数字和操作符
    ((-?)((\d+)(\.\d+)?)) # 匹配范围起始的负数或正数(整数或小数)
    [-~] # 匹配范围分隔符
    ((-?)((\d+)(\.\d+)?)) # 匹配范围结束的负数或正数(整数或小数)
    (?![\d\+\-\×÷=]) # 使用正向前瞻以确保数字范围之后没有其他数字和操作符
    """, re.VERBOSE)
def replace_range(match) -> str:
    """Verbalize a numeric range as "<first>到<second>".

    Args:
        match (re.Match): match of RE_RANGE.
    Returns:
        str
    """
    first, second = match.group(1), match.group(6)
    first = RE_NUMBER.sub(replace_number, first)
    second = RE_NUMBER.sub(replace_number, second)
    # "到" (to) restored — stripped to '' by an encoding-mangling pass.
    result = f"{first}到{second}"
    return result
# "~" ranges carrying units, e.g. "5kg~10kg".
RE_TO_RANGE = re.compile(
    r'((-?)((\d+)(\.\d+)?)|(\.(\d+)))(%|°C|℃|度|摄氏度|cm2|cm²|cm3|cm³|cm|db|ds|kg|km|m2|m²|m³|m3|ml|m|mm|s)[~]((-?)((\d+)(\.\d+)?)|(\.(\d+)))(%|°C|℃|度|摄氏度|cm2|cm²|cm3|cm³|cm|db|ds|kg|km|m2|m²|m³|m3|ml|m|mm|s)')
def replace_to_range(match) -> str:
    """Verbalize "X<unit>~Y<unit>" as "X<unit>至Y<unit>".

    Args:
        match (re.Match): match of RE_TO_RANGE.
    Returns:
        str
    """
    # "至" (to) restored — stripped to '' by an encoding-mangling pass,
    # which had turned the replacement into a no-op.
    result = match.group(0).replace('~', '至')
    return result
def _get_value(value_string: str, use_zero: bool=True) -> List[str]:
    """Recursively convert a digit string into Chinese digit/unit symbols.

    `use_zero` inserts a leading zero symbol when the value had leading zeros
    (e.g. "05" inside a larger number)."""
    stripped = value_string.lstrip('0')
    if not stripped:
        return []
    if len(stripped) == 1:
        if use_zero and len(stripped) < len(value_string):
            return [DIGITS['0'], DIGITS[stripped]]
        return [DIGITS[stripped]]
    # Split at the largest positional unit smaller than the number's length.
    largest_unit = next(
        power for power in reversed(UNITS.keys()) if power < len(stripped))
    head = value_string[:-largest_unit]
    tail = value_string[-largest_unit:]
    return _get_value(head) + [UNITS[largest_unit]] + _get_value(tail)
def verbalize_cardinal(value_string: str) -> str:
    """Read a digit string as a Chinese cardinal number ('' stays '')."""
    if not value_string:
        return ''
    # '000' -> '零' , '0' -> '零'
    significant = value_string.lstrip('0')
    if not significant:
        return DIGITS['0']
    symbols = _get_value(significant)
    # A leading "one ten" prefix is conventionally abbreviated (一十* -> 十*).
    if len(symbols) >= 2 and symbols[0] == DIGITS['1'] and symbols[1] == UNITS[1]:
        symbols = symbols[1:]
    return ''.join(symbols)
def verbalize_digit(value_string: str, alt_one=False) -> str:
    """Read a digit string digit by digit; with `alt_one`, 一 is read as 幺
    (telephone/serial-number style).

    Args:
        value_string: string of ASCII digits.
    Returns:
        str
    """
    result = ''.join(DIGITS[digit] for digit in value_string)
    if alt_one:
        # "一" -> "幺" restored — both characters were stripped by an
        # encoding-mangling pass, leaving a no-op replace("", "").
        result = result.replace("一", "幺")
    return result
def num2str(value_string: str) -> str:
    """Verbalize a (possibly decimal) number string in Chinese,
    e.g. "3.14" -> "三点一四", ".22" -> "零点二二".

    Args:
        value_string: numeric string with at most one decimal point.
    Returns:
        str
    Raises:
        ValueError: if more than one decimal point is present.
    """
    integer_decimal = value_string.split('.')
    if len(integer_decimal) == 1:
        integer = integer_decimal[0]
        decimal = ''
    elif len(integer_decimal) == 2:
        integer, decimal = integer_decimal
    else:
        # fix: the message used JS-style "'${value_string}'" inside an f-string,
        # which printed a literal '$' before the value.
        raise ValueError(
            f"The value string: '{value_string}' has more than one point in it."
        )
    result = verbalize_cardinal(integer)
    decimal = decimal.rstrip('0')
    if decimal:
        # '.22' is verbalized as '零点二二'
        # '3.20' is verbalized as '三点二'
        # "零" and "点" restored — stripped to '' by an encoding-mangling pass.
        result = result if result else "零"
        result += '点' + verbalize_digit(decimal)
    return result
# Ad-hoc manual test entry point.
if __name__ == "__main__":
    text = ""
    text = num2str(text)
    print(text)
    pass

412
lyric_processor_v2.py Normal file
View File

@@ -0,0 +1,412 @@
import torch.nn as nn
import torch
import random
from loguru import logger
from transformers import UMT5EncoderModel, AutoTokenizer, AutoModel
import re
from typing import List, Tuple, Dict, Set
import sys
import os
# Disable HuggingFace tokenizers' parallelism to avoid fork-related warnings/deadlocks.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
class TrieNode:
    """One node of a phoneme trie."""
    def __init__(self):
        # Mapping from next character to child node.
        self.children: Dict[str, 'TrieNode'] = {}
        # True when a vocabulary word ends exactly at this node.
        self.is_end_of_word: bool = False
class Trie:
    """Prefix tree over phoneme strings supporting incremental prefix matching."""
    def __init__(self):
        self.root = TrieNode()
    def insert(self, word: str):
        """Add `word` to the trie, creating nodes as needed."""
        cursor = self.root
        for ch in word:
            cursor = cursor.children.setdefault(ch, TrieNode())
        cursor.is_end_of_word = True
    def search_from(self, s: str, start: int) -> List[str]:
        """Walk the trie along s[start:] and return every stored word that is
        a prefix of that slice, shortest first."""
        cursor = self.root
        matched = []
        buffer = []
        for ch in s[start:]:
            nxt = cursor.children.get(ch)
            if nxt is None:
                break
            buffer.append(ch)
            cursor = nxt
            if cursor.is_end_of_word:
                matched.append(''.join(buffer))
        return matched
class PhonemeMatcher:
    def __init__(self, word_dict: Set[str]):
        """Build the phoneme Trie.

        :param word_dict: Set[str] - set of all known phonemes
        """
        self.trie = Trie()
        for word in word_dict:
            self.trie.insert(word)
    def tokenize(self, s: str) -> List[str]:
        """Split an X-SAMPA string into a phoneme sequence, preferring phonemes
        from the vocabulary; when a full match is impossible, choose the
        sequence with minimal edit distance and, among ties, fewest phonemes.

        :param s: str - input X-SAMPA string
        :return: List[str] - output phoneme sequence
        """
        n = len(s)
        # DP over prefix length: dp[i] = (cost, phoneme_count, phoneme_list)
        dp: List[Tuple[int, int, List[str]]] = [(sys.maxsize, sys.maxsize, []) for _ in range(n + 1)]
        dp[0] = (0, 0, [])
        for i in range(n):
            current_cost, current_count, current_list = dp[i]
            if current_cost == sys.maxsize:
                continue  # position i is unreachable
            # All vocabulary phonemes starting at position i.
            matches = self.trie.search_from(s, i)
            if matches:
                for phoneme in matches:
                    end = i + len(phoneme)
                    new_cost = current_cost  # exact match: no edit cost added
                    new_count = current_count + 1
                    new_list = current_list + [phoneme]
                    if new_cost < dp[end][0]:
                        dp[end] = (new_cost, new_count, new_list)
                    elif new_cost == dp[end][0]:
                        if new_count < dp[end][1]:
                            dp[end] = (new_cost, new_count, new_list)
            else:
                # No phoneme matches here: skip this character at edit cost 1;
                # the skipped character still counts as one output phoneme.
                new_cost = current_cost + 1
                end = i + 1
                new_count = current_count + 1
                new_list = current_list + [s[i]]
                if new_cost < dp[end][0]:
                    dp[end] = (new_cost, new_count, new_list)
                elif new_cost == dp[end][0]:
                    if new_count < dp[end][1]:
                        dp[end] = (new_cost, new_count, new_list)
        # Whole string unreachable: fall back to the best partial parse
        # (minimal edit distance, then fewest phonemes).
        if dp[n][0] == sys.maxsize:
            min_cost = min(dp[i][0] for i in range(n + 1))
            candidates = [dp[i] for i in range(n + 1) if dp[i][0] == min_cost]
            if candidates:
                best = min(candidates, key=lambda x: x[1])
                return best[2]
            else:
                return []
        return dp[n][2]
# Song-structure section labels (Harmonix Set-style vocabulary);
# 'start'/'end' mark track boundaries.
HARMONIX_LABELS = [
    'start',
    'end',
    'intro',
    'outro',
    'break',
    'bridge',
    'inst',
    'solo',
    'verse',
    'chorus',
]
def timestamp2second(timestamps):
    """Convert [{'start', 'end'}] spans from 8 kHz sample counts to seconds,
    rounded to 2 decimals."""
    converted = []
    for span in timestamps:
        converted.append({
            "start": round(span["start"] / 8000, 2),
            "end": round(span["end"] / 8000, 2),
        })
    return converted
def sample_lyric_mask(voiced_timestamp, max_length):
    """Randomly sample a (mask_start, mask_end, min_cut_level) span whose
    boundaries fall inside silent gaps between voiced regions.
    `min_cut_level` is the span length in 10-second chunks (capped near 36,
    i.e. ~360 s). NOTE(review): `voiced_timestamp` is assumed to be in 8 kHz
    sample counts (see timestamp2second) — confirm with the caller."""
    voiced_timestamps = timestamp2second(voiced_timestamp)
    min_gaps = [1,2,3,4,5]
    while len(min_gaps) > 0:
        # Try the largest silence threshold first (pops from the tail: 5s -> 1s).
        min_gap = min_gaps.pop()
        can_split_breaks = []
        last_end = 0.00
        for item in voiced_timestamps:
            if item["start"] - last_end >= min_gap:
                # Keep a 0.5 s margin away from voiced audio (except at t=0).
                if last_end == 0.00:
                    can_split_breaks.append((last_end, item["start"] - 0.5))
                else:
                    can_split_breaks.append((last_end + 0.5, item["start"] - 0.5))
            last_end = item["end"]
        if len(can_split_breaks) > 1:
            if can_split_breaks[1][0] <= 360:
                break
        else:
            if min_gap == 1:
                return 0.0, 360.0, 36
    if len(can_split_breaks) == 0:
        # No usable gap: mask the whole clip.
        mask_start, mask_end = 0.0, max_length
        min_cut_level = int(mask_end//10 - mask_start//10 + 1)
        return 0.0, mask_end, min_cut_level
    if len(can_split_breaks) == 1:
        # Randomly anchor the mask at the clip start or at the single gap.
        mask_start = random.choice(["start", "middle"])
        if mask_start == "start":
            mask_start = 0.0
            mask_end = random.uniform(can_split_breaks[0][0], can_split_breaks[0][1])
        else:
            mask_start = random.uniform(can_split_breaks[0][0], can_split_breaks[0][1])
            mask_end = max_length
        min_cut_level = int(mask_end//10 - mask_start//10 + 1)
        return mask_start, mask_end, min_cut_level
    mask_start, mask_end = 0.0, 370
    min_cut_level = 37
    breaths_gap = [end-start for start, end in can_split_breaks]
    max_tried = 5
    # NOTE(review): with the initial values, mask_end - mask_start == 370 and
    # the condition is strictly ">", so this loop body never executes on the
    # first test; also "min_cut_level > 0" is subsumed by "> 36". Confirm
    # whether ">=" was intended before changing behavior.
    while mask_end - mask_start > 370 and min_cut_level > 0 and min_cut_level > 36:
        total_breaths = len(can_split_breaks)
        # Sample start/end break points weighted by silence-gap length.
        start = random.choices(range(total_breaths-1), weights=breaths_gap[:-1])[0]
        end = random.choices(range(start + 1, total_breaths), weights=breaths_gap[start+1:], k=1)[0]
        start_break, end_break = can_split_breaks[start], can_split_breaks[end]
        mask_start, mask_end = random.uniform(start_break[0], start_break[1]), random.uniform(end_break[0], end_break[1])
        min_cut_level = int(mask_end//10 - mask_start//10 + 1)
        if min_cut_level < 36:
            min_cut_level = random.randint(min_cut_level, 36)
        if max_tried == 0:
            print("max tried", mask_start, mask_end, min_cut_level, "breaths_gap", breaths_gap, "can_split_breaks", can_split_breaks)
            break
        max_tried -= 1
    mask_start, mask_end = round(mask_start, 2), min(round(mask_end, 2), max_length)
    return mask_start, mask_end, min_cut_level
def check_valid_lyric_lines(lyric_lines):
    """Return True iff at least one entry has non-empty lyric text (index 1);
    an empty list is invalid."""
    return any(len(entry[1]) > 0 for entry in lyric_lines)
def select_valid_lyric_lines(lyric_lines, mask_start, mask_end):
    """Collect lyric lines that fall inside [mask_start, mask_end].

    Principle (translated from the original notes): rather include too much
    than too little — a lyric line that straddles mask_end is still included,
    but then no trailing structure tag is appended.
    Each output element is (structure_or_None, lyric_text, phonemes, (start, end)).
    NOTE: mutates lyric_line dicts in place (strips "lyric_line")."""
    valid_lyric_lines = []
    add_tail_structure = True
    for lyric_line in lyric_lines:
        if lyric_line["start"] > lyric_line["end"]:
            continue  # malformed timing, skip
        # Fully inside the mask, with a 1-second tolerance on each side.
        if lyric_line["start"]+1.0 >= mask_start and lyric_line["end"]-1.0 <= mask_end:
            # Emit a structure marker when the section changes.
            if len(valid_lyric_lines) > 0:
                if valid_lyric_lines[-1][0] is not None and valid_lyric_lines[-1][0] != lyric_line["structure"] and lyric_line["structure"] != "":
                    valid_lyric_lines.append((lyric_line["structure"], [], [], (lyric_line["start"], lyric_line["end"])))
            elif lyric_line["structure"] != "":
                valid_lyric_lines.append((lyric_line["structure"], [], [], (lyric_line["start"], lyric_line["end"])))
            lyric_line["lyric_line"] = lyric_line["lyric_line"].strip()
            if lyric_line["lyric_line"] and "phoneme_line_ipa" in lyric_line and len(lyric_line["phoneme_line_ipa"]) > 0:
                valid_lyric_lines.append((None, lyric_line["lyric_line"], lyric_line["phoneme_line_ipa"], (lyric_line["start"], lyric_line["end"])))
        # Starts inside the mask but ends after it: include it and stop.
        elif mask_start < lyric_line["start"] and lyric_line["start"] < mask_end and lyric_line["end"] > mask_end:
            lyric_line["lyric_line"] = lyric_line["lyric_line"].strip()
            if lyric_line["lyric_line"] and "phoneme_line_ipa" in lyric_line and len(lyric_line["phoneme_line_ipa"]) > 0:
                valid_lyric_lines.append((None, lyric_line["lyric_line"], lyric_line["phoneme_line_ipa"], (lyric_line["start"], lyric_line["end"])))
            add_tail_structure = False
            break
        # Pure structure line starting inside the mask: emit it once and stop.
        elif lyric_line["start"] > mask_start and lyric_line["start"] < mask_end and not lyric_line["lyric_line"] and add_tail_structure:
            valid_lyric_lines.append((lyric_line["structure"], [], [], (lyric_line["start"], lyric_line["end"])))
            add_tail_structure = False
            break
    # Optionally append the final structure tag when it differs from the last
    # emitted one and begins inside the mask.
    if len(valid_lyric_lines) > 0 and len(lyric_lines) > 0 and add_tail_structure:
        if lyric_lines[-1]["structure"] != "" and lyric_lines[-1]["structure"] != valid_lyric_lines[-1][0]:
            if lyric_lines[-1]["start"] > mask_start and lyric_lines[-1]["start"] < mask_end:
                valid_lyric_lines.append((lyric_lines[-1]["structure"], [], [], (lyric_lines[-1]["start"], lyric_lines[-1]["end"])))
    return valid_lyric_lines
def sample_lyric_mask_with_cut_levels(voiced_timestamp, cut_level, n_chunks, lyric_lines):
    """Enumerate candidate mask spans on a 10-second grid.

    A candidate span starts at ``chunk_idx * 10`` seconds and lasts
    ``cut_level * 10`` seconds.  The strict pass keeps spans whose boundaries
    do not cut through a voiced segment AND whose covered lyric lines pass
    ``check_valid_lyric_lines``; if nothing survives, a relaxed pass drops the
    voiced-boundary constraint.  Returns a (possibly empty) list of
    ``(start_second, end_second, valid_lyric_lines)`` tuples.
    """
    voiced_spans = timestamp2second(voiced_timestamp)

    def _cuts_voiced(second):
        # True when `second` falls strictly inside some voiced segment.
        return any(seg["start"] < second < seg["end"] for seg in voiced_spans)

    strict_candidates = []
    for chunk_idx in range(n_chunks):
        span_start = chunk_idx * 10
        span_end = (chunk_idx + cut_level) * 10
        boundary_ok = not (_cuts_voiced(span_start) or _cuts_voiced(span_end))
        selected = select_valid_lyric_lines(lyric_lines, span_start, span_end)
        if boundary_ok and check_valid_lyric_lines(selected):
            strict_candidates.append((span_start, span_end, selected))
    if strict_candidates:
        return strict_candidates

    # Relaxed fallback: ignore voiced boundaries, only require valid lyric lines.
    relaxed_candidates = []
    for chunk_idx in range(n_chunks):
        span_start = chunk_idx * 10
        span_end = (chunk_idx + cut_level) * 10
        selected = select_valid_lyric_lines(lyric_lines, span_start, span_end)
        if check_valid_lyric_lines(selected):
            relaxed_candidates.append((span_start, span_end, selected))
    return relaxed_candidates
def sample_lyric_mask_with_lyric_timestamp(cut_level, lyric_lines, expected_num_example, n_chunks, start_pad_offset=1.0):
    """Sample candidate mask spans anchored at lyric-line start times.

    Each candidate span starts at some lyric line's ``start`` and lasts
    ``cut_level * 10`` seconds.  Spans that cleanly cover whole lyric lines go
    to ``full_spans``; spans that cut a line in half go to ``partial_spans``.
    When ``expected_num_example`` is not None, full spans are preferred and
    partial spans are randomly sampled to fill the remainder.  Returns a list
    of ``(start, end, lines)`` tuples, where ``lines`` is a list of
    ``(None, lyric_line, phoneme_line_ipa, (start, end))`` tuples.
    """
    # Option 1 (disabled): drop structure-only lines entirely.
    # non_structure_lyric_lines = [lyric_line for lyric_line in lyric_lines if lyric_line["lyric_line"] and "phoneme_line_ipa" in lyric_line and len(lyric_line["phoneme_line_ipa"]) > 0 and lyric_line["start"] < lyric_line["end"]]
    # Option 2 (active): keep structure tags, folding them into the lyric text.
    valid_lyric_lines = []
    last_structure = ""
    for lyric_line in lyric_lines:
        if "structure" not in lyric_line:
            lyric_line["structure"] = ""
        if lyric_line["start"] < lyric_line["end"]:
            new_line = lyric_line.copy()
            # Lines with no lyric text / no IPA become "[structure]" markers;
            # lines with neither text nor structure are skipped entirely.
            if not lyric_line["lyric_line"] or "phoneme_line_ipa" not in lyric_line or len(lyric_line["phoneme_line_ipa"]) == 0:
                if lyric_line["structure"] != "":
                    new_line["lyric_line"] = "["+lyric_line["structure"]+"]"
                    new_line["phoneme_line_ipa"] = ["_"]
                else:
                    last_structure = lyric_line["structure"]
                    continue
            else:
                # First line of a new structure section gets the tag prepended
                # (two "_" placeholders pad the IPA to match the added tokens).
                if new_line["structure"] != "" and new_line["structure"] != last_structure:
                    if new_line["lyric_line"] != "[" + new_line["structure"] + "]":
                        new_line["lyric_line"] = f"[{new_line['structure']}]\n{new_line['lyric_line']}"
                        new_line["phoneme_line_ipa"] = ["_", "_"] + new_line["phoneme_line_ipa"]
            valid_lyric_lines.append(new_line)
            last_structure = lyric_line["structure"]
    # Step 2: prefer spans that exactly contain whole lines.
    full_spans = []
    partial_spans = []
    # print("non_structure_lyric_lines", non_structure_lyric_lines, n_chunks)
    for start_idx in range(len(valid_lyric_lines)):
        for end_idx in range(start_idx, len(valid_lyric_lines)):
            start = valid_lyric_lines[start_idx]["start"]
            end = start + cut_level * 10
            # print("start_idx:", start_idx, "end_idx:", end_idx, "start:", start, "end:", end, "non_structure_lyric_lines[end_idx]:", non_structure_lyric_lines[end_idx])
            # Case 1: a single line longer than the span -> partial.
            if start_idx == end_idx and valid_lyric_lines[start_idx]["end"] > end:
                res = [(None, valid_lyric_lines[start_idx]["lyric_line"], valid_lyric_lines[start_idx]["phoneme_line_ipa"], (valid_lyric_lines[start_idx]["start"], valid_lyric_lines[start_idx]["end"])) for line_idx in range(start_idx, end_idx+1)]
                if len(res) > 0:
                    partial_spans.append((start, end, res))
                break
            # Case 2: span ends in the gap after line end_idx-1 -> full.
            if end_idx > 0 and end < valid_lyric_lines[end_idx]["start"] and valid_lyric_lines[end_idx-1]["end"] + start_pad_offset < end:
                res = [(None, valid_lyric_lines[line_idx]["lyric_line"], valid_lyric_lines[line_idx]["phoneme_line_ipa"], (valid_lyric_lines[line_idx]["start"], valid_lyric_lines[line_idx]["end"])) for line_idx in range(start_idx, end_idx)]
                if len(res) > 0:
                    full_spans.append((start, end, res))
                break
            # Case 3: span ends inside line end_idx -> partial.
            if end < valid_lyric_lines[end_idx]["end"] + start_pad_offset and end > valid_lyric_lines[end_idx]["start"]:
                res = [(None, valid_lyric_lines[line_idx]["lyric_line"], valid_lyric_lines[line_idx]["phoneme_line_ipa"], (valid_lyric_lines[line_idx]["start"], valid_lyric_lines[line_idx]["end"])) for line_idx in range(start_idx, end_idx)]
                if len(res) > 0:
                    partial_spans.append((start, end, res))
                break
            if valid_lyric_lines[end_idx]["start"] > end:
                break
            # Fallback: everything fits inside one span and nothing was found yet.
            if start_idx == 0 and end_idx == len(valid_lyric_lines) - 1 and len(full_spans) == 0 and len(partial_spans) == 0:
                res = [(None, valid_lyric_lines[line_idx]["lyric_line"], valid_lyric_lines[line_idx]["phoneme_line_ipa"], (valid_lyric_lines[line_idx]["start"], valid_lyric_lines[line_idx]["end"])) for line_idx in range(start_idx, end_idx+1)]
                if len(res) > 0:
                    full_spans.append((start, end, res))
    if expected_num_example is not None:
        if len(full_spans) >= expected_num_example or len(partial_spans) == 0:
            return full_spans
        if len(full_spans) + len(partial_spans) >= expected_num_example:
            left = expected_num_example - len(full_spans)
            return full_spans + random.sample(partial_spans, left)
    # print("full_spans:", full_spans)
    # print("partial_spans:", partial_spans)
    return full_spans + partial_spans
class LyricProcessor(nn.Module):
    """Frozen UMT5-based lyric/prompt text encoder plus lyric preprocessing.

    Loads a local umt5-base checkpoint in half precision with gradients
    disabled; `get_text_embeddings` produces encoder hidden states, while
    `preprocess` flattens lyric lines into a lyric string and an IPA string.
    """
    def __init__(self, infer=False):
        # NOTE(review): `infer` is accepted but unused in this span — confirm
        # it is needed by callers or by code outside this view.
        super().__init__()
        self.lyric_text_model = UMT5EncoderModel.from_pretrained("./checkpoints/umt5-base", local_files_only=True).eval().half()
        # Frozen encoder: no gradients required.
        self.lyric_text_model.requires_grad_(False)
        self.lyric_text_tokenizer = AutoTokenizer.from_pretrained("./checkpoints/umt5-base", local_files_only=True)
    def get_text_embeddings(self, texts, device, text_max_length=256):
        """Tokenize `texts` and return (last_hidden_states, attention_mask).

        Moves the encoder to `device` lazily on first use from a new device.
        """
        inputs = self.lyric_text_tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=text_max_length)
        inputs = {key: value.to(device) for key, value in inputs.items()}
        if self.lyric_text_model.device != device:
            self.lyric_text_model.to(device)
        with torch.no_grad():
            outputs = self.lyric_text_model(**inputs)
        last_hidden_states = outputs.last_hidden_state
        attention_mask = inputs["attention_mask"]
        return last_hidden_states, attention_mask
    def preprocess(self, valid_lyric_lines):
        """Flatten lyric lines into one lyric string and one IPA string.

        Lyric lines are joined with newlines; IPA lines are joined with " _ ".
        A bare "_" IPA line is kept as-is (placeholder for structure tags).
        """
        lyric_texts = []
        ipa_texts = []
        for valid_line in valid_lyric_lines:
            structure, lyric_line, ipa_line = valid_line["structure"], valid_line["lyric"], valid_line["ipa"]
            if len(structure) > 0:
                lyric_texts.append(structure)
            if len(lyric_line) > 0:
                lyric_texts.append(lyric_line)
            if len(structure) == 0 and len(lyric_line) == 0:
                lyric_texts.append("")
            if ipa_line != "_":
                # NOTE(review): split_unk is not defined in this class here —
                # confirm it exists on the class elsewhere in the file.
                ipa_line = self.split_unk(ipa_line.split(" "))
                ipa_line_str = " ".join(ipa_line)
                # Work around a G2P bug: drop spurious repeated "z ə z" runs.
                ipa_line_str = re.sub(r'\bz(?:\s+ə\s+z)+\b', "", ipa_line_str)
                ipa_line_str = re.sub(r'\s+', ' ', ipa_line_str).strip()
                ipa_texts.append(ipa_line_str)
            else:
                ipa_texts.append(ipa_line)
        lyric_text = "\n".join(lyric_texts)
        ipa_text = " _ ".join(ipa_texts)
        return lyric_text, ipa_text

View File

@@ -0,0 +1,644 @@
import json
import matplotlib
import torch
import torch.nn.functional as F
import torch.utils.data
from pytorch_lightning.core import LightningModule
from torch.utils.data import DataLoader
# from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
from models.transformer_sana_text2music_large_dcae_0319 import ACEFlowBaseModel
from loguru import logger
from transformers import AutoModel
from lyric_processor_v2 import LyricProcessor
from optimizers.cosine_wsd import configure_lr_scheduler
import traceback
import torchaudio
from transformers import Wav2Vec2FeatureExtractor
from music_dcae.music_dcae_pipeline import MusicDCAE
from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import retrieve_timesteps
from diffusers.utils.torch_utils import randn_tensor
from apg_guidance import apg_forward, MomentumBuffer
from tqdm import tqdm
import random
import os
# Headless matplotlib backend: figures are only ever written to disk.
matplotlib.use("Agg")
# cudnn autotuning disabled — presumably because input sizes vary between
# batches; confirm before re-enabling.
torch.backends.cudnn.benchmark = False
torch.set_float32_matmul_precision('high')
# Enable TF32 for faster training on Ampere GPUs,
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
# torch.backends.cuda.matmul.allow_tf32 = True
class Pipeline(LightningModule):
    """Flow-matching text2music training / inference pipeline.

    Bundles the ACE transformer, a DCAE latent VAE, a frozen UMT5 lyric
    encoder and (training only) MERT / mHuBERT SSL teacher models into one
    LightningModule.  Training uses logit-normal (or U-shaped) timestep
    sampling and CFG condition dropout; inference runs an Euler flow-matching
    sampler with APG guidance.
    """
    def __init__(
        self,
        learning_rate: float = 1e-4,
        num_workers: int = 4,
        infer: bool = False,
        train: bool = True,
        T: int = 1000,
        minibatch_size: int = 32,
        batch_size: int = 1,
        snr_gamma: float = 0.5,
        prediction_type: str = "v_prediction", # epsilon, sample, v_prediction
        beta_start: float = 0.0015,
        beta_end: float = 0.0195,
        noise_offset: float = 0.1,
        input_perturbation: float = 0.1,
        use_ema: bool = False,
        enable_xformers_memory_efficient_attention: bool = False,
        weight_decay: float = 1e-2,
        num_chunk: int = 2,
        beta_schedule: str = "scaled_linear",
        scheduler_type: str = "ddpm",
        every_plot_step: int = 2000,
        vocal_noise: float = 0,
        max_length: int = 6400,
        sample_size: int = None,
        target_orig: bool = True,
        csv_path: str = None,
        config_path: str = "./models/config_sana_text2music_dcae_0225_3.5B_simple.json",
        shift: float = 3.0,
        logit_mean: float = 0.0,
        logit_std: float = 1.0,
        timestep_densities_type: str = "logit_normal",
        ssl_coeff: float = 1.0,
        wav_max_seconds: float = 30.0,
        max_steps: int = -1,
        fix_cut_level: int = 3,
        ipa_max_length: int = 8192,
        text_max_length: int = 1024,
    ):
        # NOTE(review): several hyperparameters (snr_gamma, beta_*, noise_offset,
        # input_perturbation, ...) are saved but not read in this file — they may
        # be consumed elsewhere or be legacy; confirm before pruning.
        super().__init__()
        self.save_hyperparameters()
        self.is_train = train
        self.T = T
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.scheduler = self.get_scheduler()
        with open(config_path, "r") as f:
            self.config = json.load(f)
        self.transformers = ACEFlowBaseModel(**self.config)
        self.lyric_processor = LyricProcessor()
        self.lyric_processor.requires_grad_(False)
        if not infer and self.is_train:
            # Training-only frozen SSL teachers: MERT (24 kHz) and mHuBERT (16 kHz).
            self.mert_model = AutoModel.from_pretrained("./checkpoints/MERT-v1-330M", trust_remote_code=True).eval()
            self.mert_model.requires_grad_(False)
            self.resampler_mert = torchaudio.transforms.Resample(orig_freq=48000, new_freq=24000)
            self.processor_mert = Wav2Vec2FeatureExtractor.from_pretrained("./checkpoints/MERT-v1-330M", trust_remote_code=True)
            self.hubert_model = AutoModel.from_pretrained("checkpoints/mHuBERT-147", local_files_only=True).eval()
            self.hubert_model.requires_grad_(False)
            self.resampler_mhubert = torchaudio.transforms.Resample(orig_freq=48000, new_freq=16000)
            self.processor_mhubert = Wav2Vec2FeatureExtractor.from_pretrained("checkpoints/mHuBERT-147", local_files_only=True)
            self.ssl_coeff = ssl_coeff
            self.vae = MusicDCAE(encoder_only=False).eval()
            self.vae.requires_grad_(False)
            # self.mert_model = torch.compile(self.mert_model)
            # self.hubert_model = torch.compile(self.hubert_model)
            # self.vae = torch.compile(self.vae)
            # self.transformers = torch.compile(self.transformers)
        else:
            self.vae = MusicDCAE(encoder_only=False).eval()
            self.vae.requires_grad_(False)
    def infer_mert_ssl(self, target_wavs, wav_lengths):
        """Run MERT on 5-second chunks and return per-audio hidden states.

        Returns a list (length = batch) of (frames, hidden_size) tensors.
        """
        # Input: N x 2 x T (48 kHz stereo) -> N x T (24 kHz mono).
        mert_input_wavs_mono_24k = self.resampler_mert(target_wavs.mean(dim=1))
        bsz = target_wavs.shape[0]
        actual_lengths_24k = wav_lengths // 2 # 48kHz -> 24kHz
        # Normalize only the actual (non-padded) audio region.
        means = torch.stack([mert_input_wavs_mono_24k[i, :actual_lengths_24k[i]].mean() for i in range(bsz)])
        vars = torch.stack([mert_input_wavs_mono_24k[i, :actual_lengths_24k[i]].var() for i in range(bsz)])
        mert_input_wavs_mono_24k = (mert_input_wavs_mono_24k - means.view(-1, 1)) / torch.sqrt(vars.view(-1, 1) + 1e-7)
        # MERT SSL constraint.
        # Chunk length: 5 seconds of samples.
        chunk_size = 24000 * 5 # 5 seconds at 24000 samples/s
        total_length = mert_input_wavs_mono_24k.shape[1]
        num_chunks_per_audio = (actual_lengths_24k + chunk_size - 1) // chunk_size
        # Split each audio into chunks.
        all_chunks = []
        chunk_actual_lengths = []
        for i in range(bsz):
            audio = mert_input_wavs_mono_24k[i]
            actual_length = actual_lengths_24k[i]
            for start in range(0, actual_length, chunk_size):
                end = min(start + chunk_size, actual_length)
                chunk = audio[start:end]
                if len(chunk) < chunk_size:
                    chunk = F.pad(chunk, (0, chunk_size - len(chunk))) # zero-pad the last partial chunk
                all_chunks.append(chunk)
                chunk_actual_lengths.append(end - start)
        # Stack all chunks to (total_chunks, chunk_size).
        all_chunks = torch.stack(all_chunks, dim=0)
        # Batched inference.
        with torch.no_grad():
            # Output shape: (total_chunks, seq_len, hidden_size).
            mert_ssl_hidden_states = self.mert_model(all_chunks).last_hidden_state
        # Number of features per chunk (model stride of 320 samples per frame).
        chunk_num_features = [(length + 319) // 320 for length in chunk_actual_lengths]
        # Trim each chunk's hidden states to drop padding frames.
        chunk_hidden_states = [mert_ssl_hidden_states[i, :chunk_num_features[i], :] for i in range(len(all_chunks))]
        # Regroup hidden states by source audio.
        mert_ssl_hidden_states_list = []
        chunk_idx = 0
        for i in range(bsz):
            audio_chunks = chunk_hidden_states[chunk_idx:chunk_idx + num_chunks_per_audio[i]]
            audio_hidden = torch.cat(audio_chunks, dim=0) # concatenate chunks of the same audio
            mert_ssl_hidden_states_list.append(audio_hidden)
            chunk_idx += num_chunks_per_audio[i]
        return mert_ssl_hidden_states_list
    def infer_mhubert_ssl(self, target_wavs, wav_lengths):
        """Run mHuBERT on 30-second chunks and return per-audio hidden states.

        Returns a list (length = batch) of (frames, hidden_size) tensors.
        """
        # Step 1: Preprocess audio
        # Input: N x 2 x T (48kHz, stereo) -> N x T (16kHz, mono)
        mhubert_input_wavs_mono_16k = self.resampler_mhubert(target_wavs.mean(dim=1))
        bsz = target_wavs.shape[0]
        actual_lengths_16k = wav_lengths // 3  # Convert lengths from 48kHz to 16kHz
        # Step 2: Zero-mean unit-variance normalization (only on actual audio)
        means = torch.stack([mhubert_input_wavs_mono_16k[i, :actual_lengths_16k[i]].mean()
                            for i in range(bsz)])
        vars = torch.stack([mhubert_input_wavs_mono_16k[i, :actual_lengths_16k[i]].var()
                           for i in range(bsz)])
        mhubert_input_wavs_mono_16k = (mhubert_input_wavs_mono_16k - means.view(-1, 1)) / \
            torch.sqrt(vars.view(-1, 1) + 1e-7)
        # Step 3: Define chunk size for MHubert (30 seconds at 16kHz)
        chunk_size = 16000 * 30  # 30 seconds = 480,000 samples
        # Step 4: Split audio into chunks
        num_chunks_per_audio = (actual_lengths_16k + chunk_size - 1) // chunk_size  # Ceiling division
        all_chunks = []
        chunk_actual_lengths = []
        for i in range(bsz):
            audio = mhubert_input_wavs_mono_16k[i]
            actual_length = actual_lengths_16k[i]
            for start in range(0, actual_length, chunk_size):
                end = min(start + chunk_size, actual_length)
                chunk = audio[start:end]
                if len(chunk) < chunk_size:
                    chunk = F.pad(chunk, (0, chunk_size - len(chunk)))  # Pad with zeros
                all_chunks.append(chunk)
                chunk_actual_lengths.append(end - start)
        # Step 5: Stack all chunks for batch inference
        all_chunks = torch.stack(all_chunks, dim=0)  # Shape: (total_chunks, chunk_size)
        # Step 6: Batch inference with MHubert model
        with torch.no_grad():
            mhubert_ssl_hidden_states = self.hubert_model(all_chunks).last_hidden_state
            # Shape: (total_chunks, seq_len, hidden_size)
        # Step 7: Compute number of features per chunk (assuming model stride of 320)
        chunk_num_features = [(length + 319) // 320 for length in chunk_actual_lengths]
        # Step 8: Trim hidden states to remove padding effects
        chunk_hidden_states = [mhubert_ssl_hidden_states[i, :chunk_num_features[i], :] for i in range(len(all_chunks))]
        # Step 9: Reorganize hidden states by original audio
        mhubert_ssl_hidden_states_list = []
        chunk_idx = 0
        for i in range(bsz):
            audio_chunks = chunk_hidden_states[chunk_idx:chunk_idx + num_chunks_per_audio[i]]
            audio_hidden = torch.cat(audio_chunks, dim=0)  # Concatenate chunks for this audio
            mhubert_ssl_hidden_states_list.append(audio_hidden)
            chunk_idx += num_chunks_per_audio[i]
        return mhubert_ssl_hidden_states_list
    def preprocess(self, batch, train=True):
        """Turn a raw batch into model inputs.

        Encodes audio to VAE latents and prompts to text embeddings; when
        `train` is True also computes SSL teacher states and applies CFG
        condition dropout (15%) on text and lyric conditions.
        """
        target_wavs = batch["target_wavs"]
        wav_lengths = batch["wav_lengths"]
        dtype = target_wavs.dtype
        bs = target_wavs.shape[0]
        device = target_wavs.device
        # SSL constraints (training only).
        mert_ssl_hidden_states = None
        mhubert_ssl_hidden_states = None
        # is_long = target_wavs.shape[-1] >= 48000 * 150
        if train:
            with torch.amp.autocast(device_type="cuda", dtype=dtype):
                mert_ssl_hidden_states = self.infer_mert_ssl(target_wavs, wav_lengths)
                # mhubert_ssl_hidden_states = self.infer_mhubert_ssl(batch["vocal_wavs"], wav_lengths)
                mhubert_ssl_hidden_states = self.infer_mhubert_ssl(target_wavs, wav_lengths)
        # 1: text embedding
        texts = batch["prompts"]
        encoder_text_hidden_states, text_attention_mask = self.lyric_processor.get_text_embeddings(texts, device)
        encoder_text_hidden_states = encoder_text_hidden_states.to(dtype)
        target_latents, _ = self.vae.encode(target_wavs, wav_lengths)
        attention_mask = torch.ones(bs, target_latents.shape[-1], device=device, dtype=dtype)
        speaker_embds = batch["speaker_embs"].to(dtype)
        keys = batch["keys"]
        lyric_token_ids = batch["lyric_token_ids"]
        lyric_mask = batch["lyric_masks"]
        # Pretrain stage 2 needs CFG condition dropout.
        if train:
            # Zero out the text condition for ~15% of samples.
            full_cfg_condition_mask = torch.where(
                (torch.rand(size=(bs,), device=device) < 0.15),
                torch.zeros(size=(bs,), device=device),
                torch.ones(size=(bs,), device=device)
            ).long()
            # N x T x 768
            encoder_text_hidden_states = torch.where(full_cfg_condition_mask.unsqueeze(1).unsqueeze(1).bool(), encoder_text_hidden_states, torch.zeros_like(encoder_text_hidden_states))
            # full_cfg_condition_mask = torch.where(
            #     (torch.rand(size=(bs,), device=device) < 0.50),
            #     torch.zeros(size=(bs,), device=device),
            #     torch.ones(size=(bs,), device=device)
            # ).long()
            # # N x 512
            # speaker_embds = torch.where(full_cfg_condition_mask.unsqueeze(1).bool(), speaker_embds, torch.zeros_like(speaker_embds))
            # Lyrics: independently drop the lyric condition for ~15% of samples.
            full_cfg_condition_mask = torch.where(
                (torch.rand(size=(bs,), device=device) < 0.15),
                torch.zeros(size=(bs,), device=device),
                torch.ones(size=(bs,), device=device)
            ).long()
            lyric_token_ids = torch.where(full_cfg_condition_mask.unsqueeze(1).bool(), lyric_token_ids, torch.zeros_like(lyric_token_ids))
            lyric_mask = torch.where(full_cfg_condition_mask.unsqueeze(1).bool(), lyric_mask, torch.zeros_like(lyric_mask))
        return (
            keys,
            target_latents,
            attention_mask,
            encoder_text_hidden_states,
            text_attention_mask,
            speaker_embds,
            lyric_token_ids,
            lyric_mask,
            mert_ssl_hidden_states,
            mhubert_ssl_hidden_states,
        )
    def get_scheduler(self):
        """Build the flow-matching Euler scheduler used for training."""
        return FlowMatchEulerDiscreteScheduler(
            num_train_timesteps=self.T,
            shift=self.hparams.shift,
        )
    def configure_optimizers(self):
        """AdamW over the transformer + cosine-WSD LR schedule."""
        # trainable_parameters = self.transformers.get_trainable_parameters()
        # optimizer = get_muon_optimizer(
        #     self.transformers.named_parameters(),
        #     lr=self.hparams.learning_rate,
        #     wd=self.hparams.weight_decay,
        # )
        # optimizer = CAME8BitWrapper(
        #     params=[
        #         {'params': self.transformers.parameters()},
        #     ],
        #     lr=self.hparams.learning_rate,
        #     weight_decay=self.hparams.weight_decay,
        #     betas=(0.8, 0.9),
        # )
        optimizer = torch.optim.AdamW(
            params=[
                {'params': self.transformers.parameters()},
            ],
            lr=self.hparams.learning_rate,
            weight_decay=self.hparams.weight_decay,
            betas=(0.8, 0.9),
        )
        max_steps = self.hparams.max_steps
        # ~200k training steps expected.
        decay_interval = int(max_steps * (1 - 0.9) * 0.2)
        lr_scheduler = configure_lr_scheduler(optimizer, total_steps_per_epoch=max_steps, epochs=1, decay_ratio=0.9, decay_interval=decay_interval, warmup_iters=4000)
        return [optimizer], lr_scheduler
    def get_sd3_sigmas(self, timesteps, device, n_dim=4, dtype=torch.float32):
        """Look up per-sample sigmas for `timesteps`, broadcast to `n_dim` dims."""
        sigmas = self.scheduler.sigmas.to(device=device, dtype=dtype)
        schedule_timesteps = self.scheduler.timesteps.to(device)
        timesteps = timesteps.to(device)
        step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
        sigma = sigmas[step_indices].flatten()
        while len(sigma.shape) < n_dim:
            sigma = sigma.unsqueeze(-1)
        return sigma
    def get_timestep(self, bsz, device):
        """Sample a batch of training timesteps from the configured density.

        NOTE(review): if `timestep_densities_type` is neither "logit_normal"
        nor "u_shape", `timesteps` is never assigned and this raises
        UnboundLocalError — confirm an explicit fallback is not needed.
        """
        if self.hparams.timestep_densities_type == "logit_normal":
            # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
            # In practice, we sample the random variable u from a normal distribution u N (u; m, s)
            # and map it through the standard logistic function
            u = torch.normal(mean=self.hparams.logit_mean, std=self.hparams.logit_std, size=(bsz, ), device="cpu")
            u = torch.nn.functional.sigmoid(u)
            indices = (u * self.scheduler.config.num_train_timesteps).long()
            indices = torch.clamp(indices, 0, self.scheduler.config.num_train_timesteps - 1)
            timesteps = self.scheduler.timesteps[indices].to(device)
        if self.hparams.timestep_densities_type == "u_shape":
            # Parameter a controls how U-shaped the density is; a=4 works well in the paper.
            a = 4.0
            # Sample v from a uniform distribution.
            v = torch.rand(bsz)
            # Compute u via the closed form:
            # u = 0.5 + (1/a)*asinh( sinh(a/2)*(2*v -1) )
            s = torch.sinh(torch.tensor(a/2))
            argument = s * (2 * v - 1)
            u = 0.5 + (1.0 / a) * torch.asinh(argument)
            # Clamp to guard against tiny numerical drift.
            u = torch.clamp(u, 0.0, 1.0)
            # Map continuous u in [0,1] to discrete timesteps.
            indices = (u * self.scheduler.config.num_train_timesteps).long()
            indices = torch.clamp(indices, 0, self.scheduler.config.num_train_timesteps - 1)
            timesteps = self.scheduler.timesteps[indices].to(device)
        return timesteps
    def run_step(self, batch, batch_idx):
        """One training step: noise latents, predict, and compute losses."""
        self.plot_step(batch, batch_idx)
        (
            keys,
            target_latents,
            attention_mask,
            encoder_text_hidden_states,
            text_attention_mask,
            speaker_embds,
            lyric_token_ids,
            lyric_mask,
            mert_ssl_hidden_states,
            mhubert_ssl_hidden_states,
        ) = self.preprocess(batch)
        target_image = target_latents
        device = target_image.device
        dtype = target_image.dtype
        # check dtype
        # logger.info(f"target_image dtype: {target_image.dtype} model dtype: {self.transformers.dtype}")
        # Step 1: sample noise and initialize.
        noise = torch.randn_like(target_image, device=device)
        bsz = target_image.shape[0]
        timesteps = self.get_timestep(bsz, device)
        # Add noise according to flow matching.
        sigmas = self.get_sd3_sigmas(timesteps=timesteps, device=device, n_dim=target_image.ndim, dtype=dtype)
        noisy_image = sigmas * noise + (1.0 - sigmas) * target_image
        # This is the flow-matching target for vanilla SD3.
        target = target_image
        # CLAP SSL constraints and vocal_latent_channel2 constraints.
        all_ssl_hiden_states = []
        if mert_ssl_hidden_states is not None:
            all_ssl_hiden_states.append(mert_ssl_hidden_states)
        if mhubert_ssl_hidden_states is not None:
            all_ssl_hiden_states.append(mhubert_ssl_hidden_states)
        # N x H -> N x c x W x H
        x = noisy_image
        # step 5: predict noise
        transformer_output = self.transformers(
            hidden_states=x,
            attention_mask=attention_mask,
            encoder_text_hidden_states=encoder_text_hidden_states,
            text_attention_mask=text_attention_mask,
            speaker_embeds=speaker_embds,
            lyric_token_idx=lyric_token_ids,
            lyric_mask=lyric_mask,
            timestep=timesteps.to(device).to(dtype),
            ssl_hidden_states=all_ssl_hiden_states,
        )
        model_pred = transformer_output.sample
        proj_losses = transformer_output.proj_losses
        # Follow: Section 5 of https://arxiv.org/abs/2206.00364.
        # Preconditioning of the model outputs.
        model_pred = model_pred * (-sigmas) + noisy_image
        # Compute loss only where chunk_mask is 1 and there is no padding.
        # N x T x 64
        # chunk_masks_to_cat
        # N x T -> N x c x W x T
        mask = attention_mask.unsqueeze(1).unsqueeze(1).expand(-1, target_image.shape[1], target_image.shape[2], -1)
        selected_model_pred = (model_pred * mask).reshape(bsz, -1).contiguous()
        selected_target = (target * mask).reshape(bsz, -1).contiguous()
        loss = F.mse_loss(selected_model_pred, selected_target, reduction="none")
        loss = loss.mean(1)
        # NOTE(review): multiplying by the mask mean DOWN-weights heavily-masked
        # samples; dividing would instead normalize per valid element — confirm
        # this weighting is intentional.
        loss = loss * mask.reshape(bsz, -1).mean(1)
        loss = loss.mean()
        prefix = "train"
        self.log(f"{prefix}/denoising_loss", loss, on_step=True, on_epoch=False, prog_bar=True)
        total_proj_loss = 0.0
        for k, v in proj_losses:
            self.log(f"{prefix}/{k}_loss", v, on_step=True, on_epoch=False, prog_bar=True)
            total_proj_loss += v
        if len(proj_losses) > 0:
            total_proj_loss = total_proj_loss / len(proj_losses)
        loss = loss + total_proj_loss * self.ssl_coeff
        self.log(f"{prefix}/loss", loss, on_step=True, on_epoch=False, prog_bar=True)
        learning_rate = self.lr_schedulers().get_last_lr()[0]
        self.log(f"{prefix}/learning_rate", learning_rate, on_step=True, on_epoch=False, prog_bar=True)
        # with torch.autograd.detect_anomaly():
        #     self.manual_backward(loss)
        return loss
    def training_step(self, batch, batch_idx):
        """Lightning hook: delegates to run_step."""
        return self.run_step(batch, batch_idx)
    @torch.no_grad()
    def diffusion_process(
        self,
        duration,
        encoder_text_hidden_states,
        text_attention_mask,
        speaker_embds,
        lyric_token_ids,
        lyric_mask,
        random_generators=None,
        infer_steps=60,
        guidance_scale=15.0,
        omega_scale=10.0,
    ):
        """Euler flow-matching sampling loop with APG classifier-free guidance.

        Returns latents of shape (bsz, 8, 16, frame_length) where frame_length
        is derived from `duration` seconds at 44100 Hz / 512 hop / 8 downsample.
        """
        do_classifier_free_guidance = True
        if guidance_scale == 0.0 or guidance_scale == 1.0:
            do_classifier_free_guidance = False
        device = encoder_text_hidden_states.device
        dtype = encoder_text_hidden_states.dtype
        bsz = encoder_text_hidden_states.shape[0]
        scheduler = FlowMatchEulerDiscreteScheduler(
            num_train_timesteps=1000,
            shift=3.0,
        )
        frame_length = int(duration * 44100 / 512 / 8)
        timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps=infer_steps, device=device, timesteps=None)
        target_latents = randn_tensor(shape=(bsz, 8, 16, frame_length), generator=random_generators, device=device, dtype=dtype)
        attention_mask = torch.ones(bsz, frame_length, device=device, dtype=dtype)
        if do_classifier_free_guidance:
            # Duplicate inputs: first half conditional, second half unconditional.
            attention_mask = torch.cat([attention_mask] * 2, dim=0)
            encoder_text_hidden_states = torch.cat([encoder_text_hidden_states, torch.zeros_like(encoder_text_hidden_states)], 0)
            text_attention_mask = torch.cat([text_attention_mask] * 2, dim=0)
            speaker_embds = torch.cat([speaker_embds, torch.zeros_like(speaker_embds)], 0)
            lyric_token_ids = torch.cat([lyric_token_ids, torch.zeros_like(lyric_token_ids)], 0)
            lyric_mask = torch.cat([lyric_mask, torch.zeros_like(lyric_mask)], 0)
        momentum_buffer = MomentumBuffer()
        for i, t in tqdm(enumerate(timesteps), total=num_inference_steps):
            # expand the latents if we are doing classifier free guidance
            latents = target_latents
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            timestep = t.expand(latent_model_input.shape[0])
            noise_pred = self.transformers(
                hidden_states=latent_model_input,
                attention_mask=attention_mask,
                encoder_text_hidden_states=encoder_text_hidden_states,
                text_attention_mask=text_attention_mask,
                speaker_embeds=speaker_embds,
                lyric_token_idx=lyric_token_ids,
                lyric_mask=lyric_mask,
                timestep=timestep,
            ).sample
            if do_classifier_free_guidance:
                noise_pred_with_cond, noise_pred_uncond = noise_pred.chunk(2)
                noise_pred = apg_forward(
                    pred_cond=noise_pred_with_cond,
                    pred_uncond=noise_pred_uncond,
                    guidance_scale=guidance_scale,
                    momentum_buffer=momentum_buffer,
                )
            target_latents = scheduler.step(model_output=noise_pred, timestep=t, sample=target_latents, return_dict=False, omega=omega_scale)[0]
        return target_latents
    def predict_step(self, batch, batch_idx):
        """Generate audio for a batch with fixed sampling hyperparameters.

        Seeds are derived from a fixed base seed (1234) so runs are repeatable;
        the per-sample seeds are returned for logging.
        """
        (
            keys,
            target_latents,
            attention_mask,
            encoder_text_hidden_states,
            text_attention_mask,
            speaker_embds,
            lyric_token_ids,
            lyric_mask,
            mert_ssl_hidden_states,
            mhubert_ssl_hidden_states,
        ) = self.preprocess(batch, train=False)
        infer_steps = 60
        guidance_scale = 15.0
        omega_scale = 10.0
        seed_num = 1234
        random.seed(seed_num)
        bsz = target_latents.shape[0]
        random_generators = [torch.Generator(device=self.device) for _ in range(bsz)]
        seeds = []
        for i in range(bsz):
            seed = random.randint(0, 2**32 - 1)
            random_generators[i].manual_seed(seed)
            seeds.append(seed)
        duration = self.hparams.fix_cut_level * 10
        pred_latents = self.diffusion_process(
            duration=duration,
            encoder_text_hidden_states=encoder_text_hidden_states,
            text_attention_mask=text_attention_mask,
            speaker_embds=speaker_embds,
            lyric_token_ids=lyric_token_ids,
            lyric_mask=lyric_mask,
            random_generators=random_generators,
            infer_steps=infer_steps,
            guidance_scale=guidance_scale,
            omega_scale=omega_scale,
        )
        audio_lengths = batch["wav_lengths"]
        sr, pred_wavs = self.vae.decode(pred_latents, audio_lengths=audio_lengths, sr=48000)
        return {
            "target_wavs": batch["target_wavs"],
            "pred_wavs": pred_wavs,
            "keys": keys,
            "prompts": batch["prompts"],
            "candidate_lyric_chunks": batch["candidate_lyric_chunks"],
            "sr": sr,
            "seeds": seeds,
        }
    def construct_lyrics(self, candidate_lyric_chunk):
        """Join the "lyric" field of each chunk with newlines."""
        lyrics = []
        for chunk in candidate_lyric_chunk:
            lyrics.append(chunk["lyric"])
        lyrics = "\n".join(lyrics)
        return lyrics
    def plot_step(self, batch, batch_idx):
        """Periodically (rank 0 only) generate samples and save audio + metadata.

        NOTE(review): torch.distributed.get_rank() raises if the default process
        group is not initialized (e.g. single-GPU runs) — confirm the guard is
        safe in all launch configurations.
        """
        if batch_idx % self.hparams.every_plot_step != 0 or self.local_rank != 0 or torch.distributed.get_rank() != 0 or torch.cuda.current_device() != 0:
            return
        results = self.predict_step(batch, batch_idx)
        target_wavs = results["target_wavs"]
        pred_wavs = results["pred_wavs"]
        keys = results["keys"]
        prompts = results["prompts"]
        candidate_lyric_chunks = results["candidate_lyric_chunks"]
        sr = results["sr"]
        seeds = results["seeds"]
        i = 0
        for key, target_wav, pred_wav, prompt, candidate_lyric_chunk, seed in zip(keys, target_wavs, pred_wavs, prompts, candidate_lyric_chunks, seeds):
            key = key
            prompt = prompt
            lyric = self.construct_lyrics(candidate_lyric_chunk)
            key_prompt_lyric = f"# KEY\n\n{key}\n\n\n# PROMPT\n\n{prompt}\n\n\n# LYRIC\n\n{lyric}\n\n# SEED\n\n{seed}\n\n"
            log_dir = self.logger.log_dir
            save_dir = f"{log_dir}/eval_results/step_{self.global_step}"
            if not os.path.exists(save_dir):
                os.makedirs(save_dir, exist_ok=True)
            torchaudio.save(f"{save_dir}/target_wav_{key}_{i}.flac", target_wav.float().cpu(), sr)
            torchaudio.save(f"{save_dir}/pred_wav_{key}_{i}.flac", pred_wav.float().cpu(), sr)
            with open(f"{save_dir}/key_prompt_lyric_{key}_{i}.txt", "w") as f:
                f.write(key_prompt_lyric)
            i += 1

1278
models/attention.py Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,23 @@
{
"_class_name": "Transformer2DModel",
"_diffusers_version": "0.27.2",
"in_channels": 8,
"num_layers": 24,
"inner_dim": 2560,
"attention_head_dim": 128,
"num_attention_heads": 20,
"mlp_ratio": 2.5,
"out_channels": 8,
"max_position": 32768,
"rope_theta": 1000000.0,
"speaker_embedding_dim": 512,
"text_embedding_dim": 768,
"ssl_encoder_depths": [8, 8],
"ssl_names": ["mert", "m-hubert"],
"ssl_latent_dims": [1024, 768],
"patch_size": [16, 1],
"max_height": 16,
"max_width": 32768,
"lyric_encoder_vocab_size": 6693,
"lyric_hidden_size": 1024
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,66 @@
import re
from opencc import OpenCC
# OpenCC converters for Chinese script conversion:
# t2s: Traditional -> Simplified, s2t: Simplified -> Traditional.
t2s_converter = OpenCC('t2s')
s2t_converter = OpenCC('s2t')
# Matches one or more emoticon code points (U+1F600-U+1F64F).
EMOJI_PATTERN = re.compile(
    "["
    "\U0001F600-\U0001F64F" # Emoticons
    "]+", flags=re.UNICODE
)
# Translation table used to replace or remove punctuation characters.
# NOTE(review): the '' (empty-string) keys below look like full-width CJK
# punctuation (,。!?;: etc.) lost to an encoding/rendering step; as
# written they are duplicate zero-length keys and str.maketrans raises
# ValueError for them, so this module cannot import as-is — restore the
# original characters.
TRANSLATION_TABLE = str.maketrans({
    '-': ' ', # replace '-' with a space
    ',': None,
    '.': None,
    '': None,
    '': None,
    '!': None,
    '': None,
    '?': None,
    '': None,
    '': None,
    ';': None,
    '': None,
    ':': None,
    '': None,
    '\u3000': ' ', # replace the full-width space with an ASCII space
})
# Strip bracketed content, both square brackets and parentheses.
BACKSLASH_PATTERN = re.compile(r'\(.*?\)|\[.*?\]')
# Collapse interior whitespace runs; leading/trailing runs are left alone.
SPACE_PATTERN = re.compile('(?<!^)\s+(?!$)')
def normalize_text(text, language, strip=True):
    """Normalize lyric text for matching.

    Strips punctuation and emoji, collapses interior whitespace, lowercases,
    and applies per-language Chinese script conversion (zh: traditional ->
    simplified; yue: simplified -> traditional).
    """
    # Replace '-' with a space and drop punctuation via the shared table.
    cleaned = text.translate(TRANSLATION_TABLE)
    # Remove emoticon code points.
    cleaned = EMOJI_PATTERN.sub('', cleaned)
    # Collapse consecutive whitespace into a single space (interior only).
    cleaned = SPACE_PATTERN.sub(' ', cleaned)
    if strip:
        cleaned = cleaned.strip()
    cleaned = cleaned.lower()
    # Per-language script conversion; extend here for other languages.
    if language == "zh":
        cleaned = t2s_converter.convert(cleaned)
    if language == "yue":
        cleaned = s2t_converter.convert(cleaned)
    return cleaned

View File

@@ -0,0 +1,883 @@
import os
import re
import textwrap
from functools import cached_property
import pypinyin
import torch
from hangul_romanize import Transliter
from hangul_romanize.rule import academic
from num2words import num2words
from spacy.lang.ar import Arabic
from spacy.lang.en import English
from spacy.lang.es import Spanish
from spacy.lang.ja import Japanese
from spacy.lang.zh import Chinese
from tokenizers import Tokenizer
from .zh_num2words import TextNorm as zh_num2words
from typing import Dict, List, Optional, Set, Union
#copy from https://github.com/coqui-ai/TTS/blob/dbf1a08a0d4e47fdad6172e433eeb34bc6b13b4e/TTS/tts/layers/xtts/tokenizer.py
def get_spacy_lang(lang):
    """Return a fresh spaCy language pipeline for `lang`.

    Dedicated tokenizers exist for zh/ja/ar/es; for most other languages the
    English pipeline does the job, so it is the fallback.
    """
    language_classes = {
        "zh": Chinese,
        "ja": Japanese,
        "ar": Arabic,
        "es": Spanish,
    }
    return language_classes.get(lang, English)()
def split_sentence(text, lang, text_split_length=250):
    """Preprocess the input text"""
    # Split `text` into chunks of at most `text_split_length` characters,
    # merging consecutive sentences (via spaCy's sentencizer) into one chunk
    # while they fit, and hard-wrapping any single sentence that is too long.
    text_splits = []
    if text_split_length is not None and len(text) >= text_split_length:
        # Seed with an empty chunk so the loop can always append to [-1].
        text_splits.append("")
        nlp = get_spacy_lang(lang)
        nlp.add_pipe("sentencizer")
        doc = nlp(text)
        for sentence in doc.sents:
            if len(text_splits[-1]) + len(str(sentence)) <= text_split_length:
                # if the last sentence + the current sentence is less than the text_split_length
                # then add the current sentence to the last sentence
                text_splits[-1] += " " + str(sentence)
                text_splits[-1] = text_splits[-1].lstrip()
            elif len(str(sentence)) > text_split_length:
                # if the current sentence is greater than the text_split_length
                for line in textwrap.wrap(
                    str(sentence),
                    width=text_split_length,
                    drop_whitespace=True,
                    break_on_hyphens=False,
                    tabsize=1,
                ):
                    text_splits.append(str(line))
            else:
                text_splits.append(str(sentence))
        if len(text_splits) > 1:
            if text_splits[0] == "":
                # Drop the seed chunk when nothing was merged into it.
                del text_splits[0]
    else:
        # Short input: return unchanged as a single chunk.
        text_splits = [text.lstrip()]
    return text_splits
# Matches any whitespace run; used by collapse_whitespace().
_whitespace_re = re.compile(r"\s+")
# List of (regular expression, replacement) pairs for abbreviations:
# Keyed by ISO 639-1 language code.  Each pattern matches the abbreviation
# followed by a literal dot (word boundary for Russian), case-insensitively.
_abbreviations = {
    "en": [
        (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
        for x in [
            ("mrs", "misess"),
            ("mr", "mister"),
            ("dr", "doctor"),
            ("st", "saint"),
            ("co", "company"),
            ("jr", "junior"),
            ("maj", "major"),
            ("gen", "general"),
            ("drs", "doctors"),
            ("rev", "reverend"),
            ("lt", "lieutenant"),
            ("hon", "honorable"),
            ("sgt", "sergeant"),
            ("capt", "captain"),
            ("esq", "esquire"),
            ("ltd", "limited"),
            ("col", "colonel"),
            ("ft", "fort"),
        ]
    ],
    "es": [
        (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
        for x in [
            ("sra", "señora"),
            ("sr", "señor"),
            ("dr", "doctor"),
            ("dra", "doctora"),
            ("st", "santo"),
            ("co", "compañía"),
            ("jr", "junior"),
            ("ltd", "limitada"),
        ]
    ],
    "fr": [
        (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
        for x in [
            ("mme", "madame"),
            ("mr", "monsieur"),
            ("dr", "docteur"),
            ("st", "saint"),
            ("co", "compagnie"),
            ("jr", "junior"),
            ("ltd", "limitée"),
        ]
    ],
    "de": [
        (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
        for x in [
            ("fr", "frau"),
            ("dr", "doktor"),
            ("st", "sankt"),
            ("co", "firma"),
            ("jr", "junior"),
        ]
    ],
    "pt": [
        (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
        for x in [
            ("sra", "senhora"),
            ("sr", "senhor"),
            ("dr", "doutor"),
            ("dra", "doutora"),
            ("st", "santo"),
            ("co", "companhia"),
            ("jr", "júnior"),
            ("ltd", "limitada"),
        ]
    ],
    "it": [
        (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
        for x in [
            # ("sig.ra", "signora"),
            ("sig", "signore"),
            ("dr", "dottore"),
            ("st", "santo"),
            ("co", "compagnia"),
            ("jr", "junior"),
            ("ltd", "limitata"),
        ]
    ],
    "pl": [
        (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
        for x in [
            ("p", "pani"),
            ("m", "pan"),
            ("dr", "doktor"),
            ("sw", "święty"),
            ("jr", "junior"),
        ]
    ],
    "ar": [
        (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
        for x in [
            # There are not many common abbreviations in Arabic as in English.
        ]
    ],
    "zh": [
        (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
        for x in [
            # Chinese doesn't typically use abbreviations in the same way as Latin-based scripts.
        ]
    ],
    "cs": [
        (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
        for x in [
            ("dr", "doktor"),  # doctor
            ("ing", "inženýr"),  # engineer
            ("p", "pan"),  # Could also map to pani for woman but no easy way to do it
            # Other abbreviations would be specialized and not as common.
        ]
    ],
    "ru": [
        # Russian uses \b instead of a trailing dot (e.g. "г-жа" has no dot).
        (re.compile("\\b%s\\b" % x[0], re.IGNORECASE), x[1])
        for x in [
            ("г-жа", "госпожа"),  # Mrs.
            ("г", "господин"),  # Mr.
            ("д-р", "доктор"),  # doctor
            # Other abbreviations are less common or specialized.
        ]
    ],
    "nl": [
        (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
        for x in [
            ("dhr", "de heer"),  # Mr.
            ("mevr", "mevrouw"),  # Mrs.
            ("dr", "dokter"),  # doctor
            ("jhr", "jonkheer"),  # young lord or nobleman
            # Dutch uses more abbreviations, but these are the most common ones.
        ]
    ],
    "tr": [
        (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
        for x in [
            ("b", "bay"),  # Mr.
            ("byk", "büyük"),  # büyük
            ("dr", "doktor"),  # doctor
            # Add other Turkish abbreviations here if needed.
        ]
    ],
    "hu": [
        (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
        for x in [
            ("dr", "doktor"),  # doctor
            ("b", "bácsi"),  # Mr.
            ("nőv", "nővér"),  # nurse
            # Add other Hungarian abbreviations here if needed.
        ]
    ],
    "ko": [
        (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
        for x in [
            # Korean doesn't typically use abbreviations in the same way as Latin-based scripts.
        ]
    ],
}
def expand_abbreviations_multilingual(text, lang="en"):
    """Expand language-specific abbreviations (e.g. "dr." -> "doctor").

    Raises KeyError for languages without an entry in ``_abbreviations``.
    """
    for pattern, expansion in _abbreviations[lang]:
        text = pattern.sub(expansion, text)
    return text
_symbols_multilingual = {
"en": [
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
for x in [
("&", " and "),
("@", " at "),
("%", " percent "),
("#", " hash "),
("$", " dollar "),
("£", " pound "),
("°", " degree "),
]
],
"es": [
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
for x in [
("&", " y "),
("@", " arroba "),
("%", " por ciento "),
("#", " numeral "),
("$", " dolar "),
("£", " libra "),
("°", " grados "),
]
],
"fr": [
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
for x in [
("&", " et "),
("@", " arobase "),
("%", " pour cent "),
("#", " dièse "),
("$", " dollar "),
("£", " livre "),
("°", " degrés "),
]
],
"de": [
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
for x in [
("&", " und "),
("@", " at "),
("%", " prozent "),
("#", " raute "),
("$", " dollar "),
("£", " pfund "),
("°", " grad "),
]
],
"pt": [
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
for x in [
("&", " e "),
("@", " arroba "),
("%", " por cento "),
("#", " cardinal "),
("$", " dólar "),
("£", " libra "),
("°", " graus "),
]
],
"it": [
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
for x in [
("&", " e "),
("@", " chiocciola "),
("%", " per cento "),
("#", " cancelletto "),
("$", " dollaro "),
("£", " sterlina "),
("°", " gradi "),
]
],
"pl": [
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
for x in [
("&", " i "),
("@", " małpa "),
("%", " procent "),
("#", " krzyżyk "),
("$", " dolar "),
("£", " funt "),
("°", " stopnie "),
]
],
"ar": [
# Arabic
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
for x in [
("&", " و "),
("@", " على "),
("%", " في المئة "),
("#", " رقم "),
("$", " دولار "),
("£", " جنيه "),
("°", " درجة "),
]
],
"zh": [
# Chinese
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
for x in [
("&", ""),
("@", ""),
("%", " 百分之 "),
("#", ""),
("$", " 美元 "),
("£", " 英镑 "),
("°", ""),
]
],
"cs": [
# Czech
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
for x in [
("&", " a "),
("@", " na "),
("%", " procento "),
("#", " křížek "),
("$", " dolar "),
("£", " libra "),
("°", " stupně "),
]
],
"ru": [
# Russian
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
for x in [
("&", " и "),
("@", " собака "),
("%", " процентов "),
("#", " номер "),
("$", " доллар "),
("£", " фунт "),
("°", " градус "),
]
],
"nl": [
# Dutch
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
for x in [
("&", " en "),
("@", " bij "),
("%", " procent "),
("#", " hekje "),
("$", " dollar "),
("£", " pond "),
("°", " graden "),
]
],
"tr": [
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
for x in [
("&", " ve "),
("@", " at "),
("%", " yüzde "),
("#", " diyez "),
("$", " dolar "),
("£", " sterlin "),
("°", " derece "),
]
],
"hu": [
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
for x in [
("&", " és "),
("@", " kukac "),
("%", " százalék "),
("#", " kettőskereszt "),
("$", " dollár "),
("£", " font "),
("°", " fok "),
]
],
"ko": [
# Korean
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
for x in [
("&", " 그리고 "),
("@", ""),
("%", " 퍼센트 "),
("#", " 번호 "),
("$", " 달러 "),
("£", " 파운드 "),
("°", ""),
]
],
}
def expand_symbols_multilingual(text, lang="en"):
    """Replace symbols like "&", "@", "%" with their spoken form for *lang*."""
    for regex, replacement in _symbols_multilingual[lang]:
        text = re.sub(regex, replacement, text)
    # Ensure there are no double spaces.
    # Fix: the original called text.replace(" ", " ") — a no-op; the padded
    # replacements above can introduce double spaces that must be collapsed.
    text = text.replace("  ", " ")
    return text.strip()
_ordinal_re = {
"en": re.compile(r"([0-9]+)(st|nd|rd|th)"),
"es": re.compile(r"([0-9]+)(º|ª|er|o|a|os|as)"),
"fr": re.compile(r"([0-9]+)(º|ª|er|re|e|ème)"),
"de": re.compile(r"([0-9]+)(st|nd|rd|th|º|ª|\.(?=\s|$))"),
"pt": re.compile(r"([0-9]+)(º|ª|o|a|os|as)"),
"it": re.compile(r"([0-9]+)(º|°|ª|o|a|i|e)"),
"pl": re.compile(r"([0-9]+)(º|ª|st|nd|rd|th)"),
"ar": re.compile(r"([0-9]+)(ون|ين|ث|ر|ى)"),
"cs": re.compile(r"([0-9]+)\.(?=\s|$)"), # In Czech, a dot is often used after the number to indicate ordinals.
"ru": re.compile(r"([0-9]+)(-й|-я|-е|-ое|-ье|-го)"),
"nl": re.compile(r"([0-9]+)(de|ste|e)"),
"tr": re.compile(r"([0-9]+)(\.|inci|nci|uncu|üncü|\.)"),
"hu": re.compile(r"([0-9]+)(\.|adik|edik|odik|edik|ödik|ödike|ik)"),
"ko": re.compile(r"([0-9]+)(번째|번|차|째)"),
}
_number_re = re.compile(r"[0-9]+")
_currency_re = {
"USD": re.compile(r"((\$[0-9\.\,]*[0-9]+)|([0-9\.\,]*[0-9]+\$))"),
"GBP": re.compile(r"((£[0-9\.\,]*[0-9]+)|([0-9\.\,]*[0-9]+£))"),
"EUR": re.compile(r"(([0-9\.\,]*[0-9]+€)|((€[0-9\.\,]*[0-9]+)))"),
}
_comma_number_re = re.compile(r"\b\d{1,3}(,\d{3})*(\.\d+)?\b")
_dot_number_re = re.compile(r"\b\d{1,3}(.\d{3})*(\,\d+)?\b")
_decimal_number_re = re.compile(r"([0-9]+[.,][0-9]+)")
def _remove_commas(m):
text = m.group(0)
if "," in text:
text = text.replace(",", "")
return text
def _remove_dots(m):
text = m.group(0)
if "." in text:
text = text.replace(".", "")
return text
def _expand_decimal_point(m, lang="en"):
    """Spell out a decimal like "12.5"/"12,5" (e.g. "twelve point five")."""
    # num2words uses "cz" for Czech, unlike the ISO code "cs" used elsewhere.
    normalized = m.group(1).replace(",", ".")
    return num2words(float(normalized), lang="cz" if lang == "cs" else lang)
def _expand_currency(m, lang="en", currency="USD"):
    # Parse the numeric amount: commas are treated as decimal separators
    # (European style) before every non-digit/non-dot character is stripped.
    amount = float((re.sub(r"[^\d.]", "", m.group(0).replace(",", "."))))
    # num2words renders e.g. "twenty euro, fifteen cents"; "cz" is its code for Czech.
    full_amount = num2words(amount, to="currency", currency=currency, lang=lang if lang != "cs" else "cz")
    # Per-language separator num2words puts between major and minor units;
    # needed below to trim the redundant "zero cents" tail.
    and_equivalents = {
        "en": ", ",
        "es": " con ",
        "fr": " et ",
        "de": " und ",
        "pt": " e ",
        "it": " e ",
        "pl": ", ",
        "cs": ", ",
        "ru": ", ",
        "nl": ", ",
        "ar": ", ",
        "tr": ", ",
        "hu": ", ",
        "ko": ", ",
    }
    if amount.is_integer():
        # Whole amounts: cut everything after the last separator so the
        # spoken form doesn't include "zero cents".
        last_and = full_amount.rfind(and_equivalents[lang])
        if last_and != -1:
            full_amount = full_amount[:last_and]
    return full_amount
def _expand_ordinal(m, lang="en"):
    """Spell the captured integer as an ordinal word in *lang*."""
    # num2words uses "cz" for Czech, unlike the ISO code "cs" used elsewhere.
    target = "cz" if lang == "cs" else lang
    return num2words(int(m.group(1)), ordinal=True, lang=target)
def _expand_number(m, lang="en"):
    """Spell the whole matched integer as cardinal words in *lang*."""
    target = "cz" if lang == "cs" else lang
    return num2words(int(m.group(0)), lang=target)
def expand_numbers_multilingual(text, lang="en"):
    """Spell out currencies, decimals, ordinals and plain numbers in *text*.

    Chinese is delegated entirely to zh_num2words; everything else goes
    through num2words after locale-aware thousands-separator stripping.
    """
    if lang == "zh":
        text = zh_num2words()(text)
    else:
        # en/ru group digits with commas ("1,000"); most others use dots.
        if lang in ["en", "ru"]:
            text = re.sub(_comma_number_re, _remove_commas, text)
        else:
            text = re.sub(_dot_number_re, _remove_dots, text)
        try:
            text = re.sub(_currency_re["GBP"], lambda m: _expand_currency(m, lang, "GBP"), text)
            text = re.sub(_currency_re["USD"], lambda m: _expand_currency(m, lang, "USD"), text)
            text = re.sub(_currency_re["EUR"], lambda m: _expand_currency(m, lang, "EUR"), text)
        except Exception:
            # Fix: was a bare `except:` (also swallowed KeyboardInterrupt).
            # num2words lacks currency support for some languages; best-effort.
            pass
        if lang != "tr":
            # num2words cannot expand decimals for Turkish.
            text = re.sub(_decimal_number_re, lambda m: _expand_decimal_point(m, lang), text)
        text = re.sub(_ordinal_re[lang], lambda m: _expand_ordinal(m, lang), text)
        text = re.sub(_number_re, lambda m: _expand_number(m, lang), text)
    return text
def lowercase(text):
    """Return *text* lower-cased (plain str.lower, not casefold)."""
    return str.lower(text)
def collapse_whitespace(text):
    """Collapse every whitespace run in *text* to a single space."""
    # Same pattern as the module-level _whitespace_re, inlined here.
    return re.sub(r"\s+", " ", text)
def multilingual_cleaners(text, lang):
    """Full cleaning pipeline: drop double quotes, lowercase, then expand
    numbers, abbreviations and symbols (each step best-effort), and finally
    collapse whitespace."""
    text = text.replace('"', "")
    if lang == "tr":
        # Turkish dotted capital İ (and Ö/Ü) need mapping before lower().
        text = text.replace("İ", "i")
        text = text.replace("Ö", "ö")
        text = text.replace("Ü", "ü")
    text = lowercase(text)
    # Fix: the three handlers below used bare `except:` clauses, which also
    # swallow SystemExit/KeyboardInterrupt; narrowed to Exception while
    # keeping the deliberate best-effort (skip-on-failure) behavior.
    try:
        text = expand_numbers_multilingual(text, lang)
    except Exception:
        pass
    try:
        text = expand_abbreviations_multilingual(text, lang)
    except Exception:
        pass
    try:
        text = expand_symbols_multilingual(text, lang=lang)
    except Exception:
        pass
    text = collapse_whitespace(text)
    return text
def basic_cleaners(text):
    """Basic pipeline that lowercases and collapses whitespace without transliteration."""
    # lowercase() and collapse_whitespace() from this module, inlined.
    return re.sub(r"\s+", " ", text.lower())
def chinese_transliterate(text):
    """Convert Chinese text to pinyin with numeric tones (neutral tone -> 5)."""
    syllables = pypinyin.pinyin(
        text, style=pypinyin.Style.TONE3, heteronym=False, neutral_tone_with_five=True
    )
    # pinyin() returns one candidate list per character; take the first.
    return "".join(candidates[0] for candidates in syllables)
def japanese_cleaners(text, katsu):
    """Romanize Japanese *text* with the supplied cutlet instance, lowercased."""
    # Inlines lowercase() from this module.
    romanized = katsu.romaji(text)
    return romanized.lower()
def korean_transliterate(text):
    """Romanize Korean *text* using the academic transliteration rule set."""
    transliter = Transliter(academic)
    return transliter.translit(text)
# Default BPE vocabulary shipped next to this module.
DEFAULT_VOCAB_FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), "vocab.json")


class VoiceBpeTokenizer:
    """BPE tokenizer wrapper that normalizes text per language before encoding.

    Text is cleaned (numbers/abbreviations/symbols expanded), transliterated
    for zh/ja/ko, tagged with a leading "[lang]" marker, and spaces are escaped
    as "[SPACE]" before hitting the underlying `tokenizers.Tokenizer`.
    """

    def __init__(self, vocab_file=DEFAULT_VOCAB_FILE):
        self.tokenizer = None
        if vocab_file is not None:
            self.tokenizer = Tokenizer.from_file(vocab_file)
        # Rough per-language character limits beyond which audio may truncate.
        self.char_limits = {
            "en": 10000,
            "de": 253,
            "fr": 273,
            "es": 239,
            "it": 213,
            "pt": 203,
            "pl": 224,
            "zh": 82,
            "ar": 166,
            "cs": 186,
            "ru": 182,
            "nl": 251,
            "tr": 226,
            "ja": 71,
            "hu": 224,
            "ko": 95,
        }

    @cached_property
    def katsu(self):
        # Lazy import: cutlet (Japanese romanizer) is only needed for lang="ja".
        import cutlet
        return cutlet.Cutlet()

    def check_input_length(self, txt, lang):
        """Length check; currently a no-op (the warning is kept for reference)."""
        lang = lang.split("-")[0]  # remove the region
        limit = self.char_limits.get(lang, 250)
        # if len(txt) > limit:
        #     print(
        #         f"[!] Warning: The text length exceeds the character limit of {limit} for language '{lang}', this might cause truncated audio."
        #     )

    def preprocess_text(self, txt, lang):
        """Apply the language-appropriate cleaning/transliteration pipeline."""
        if lang in {"ar", "cs", "de", "en", "es", "fr", "hu", "it", "nl", "pl", "pt", "ru", "tr", "zh", "ko"}:
            txt = multilingual_cleaners(txt, lang)
            if lang == "zh":
                txt = chinese_transliterate(txt)
            if lang == "ko":
                txt = korean_transliterate(txt)
        elif lang == "ja":
            txt = japanese_cleaners(txt, self.katsu)
        elif lang == "hi":
            # @manmay will implement this
            txt = basic_cleaners(txt)
        else:
            raise NotImplementedError(f"Language '{lang}' is not supported.")
        return txt

    def encode(self, txt, lang):
        """Return token ids for *txt*; *lang* may carry a region tag ("en-US")."""
        lang = lang.split("-")[0]  # remove the region
        self.check_input_length(txt, lang)
        txt = self.preprocess_text(txt, lang)
        lang = "zh-cn" if lang == "zh" else lang
        txt = f"[{lang}]{txt}"
        txt = txt.replace(" ", "[SPACE]")
        return self.tokenizer.encode(txt).ids

    def decode(self, seq, skip_special_tokens=False):
        """Decode ids to text, un-escaping "[SPACE]" and dropping "[STOP]".

        Fix: `skip_special_tokens` was previously ignored (always passed as
        False to the underlying tokenizer); it is now forwarded. The default
        behavior is unchanged.
        """
        if isinstance(seq, torch.Tensor):
            seq = seq.cpu().numpy()
        txt = self.tokenizer.decode(seq, skip_special_tokens=skip_special_tokens).replace(" ", "")
        txt = txt.replace("[SPACE]", " ")
        txt = txt.replace("[STOP]", "")
        # txt = txt.replace("[UNK]", "")
        return txt

    # copy from https://github.com/huggingface/transformers/blob/main/src/transformers/tokenization_utils_base.py#L3936
    def batch_decode(
        self,
        sequences: Union[List[int], List[List[int]], "np.ndarray", "torch.Tensor", "tf.Tensor"],
        skip_special_tokens: bool = False,
    ) -> List[str]:
        """
        Convert a list of lists of token ids into a list of strings by calling decode.
        Args:
            sequences (`Union[List[int], List[List[int]], np.ndarray, torch.Tensor, tf.Tensor]`):
                List of tokenized input ids. Can be obtained using the `__call__` method.
            skip_special_tokens (`bool`, *optional*, defaults to `False`):
                Whether or not to remove special tokens in the decoding.
        Returns:
            `List[str]`: The list of decoded sentences.
        """
        # Fix: forward skip_special_tokens instead of silently dropping it.
        return [
            self.decode(seq, skip_special_tokens=skip_special_tokens)
            for seq in sequences
        ]

    # https://github.com/coqui-ai/TTS/blob/dev/TTS/tts/layers/xtts/trainer/dataset.py#L202
    # def pad(self):
    def __len__(self):
        return self.tokenizer.get_vocab_size()

    def get_number_tokens(self):
        # Highest id + 1; can exceed get_vocab_size() when tokens were added.
        return max(self.tokenizer.get_vocab().values()) + 1
def test_expand_numbers_multilingual():
    # Golden (input, expected, lang) triples covering decimals, cardinals,
    # ordinals and currencies for each supported language.
    test_cases = [
        # English
        ("In 12.5 seconds.", "In twelve point five seconds.", "en"),
        ("There were 50 soldiers.", "There were fifty soldiers.", "en"),
        ("This is a 1st test", "This is a first test", "en"),
        ("That will be $20 sir.", "That will be twenty dollars sir.", "en"),
        ("That will be 20€ sir.", "That will be twenty euro sir.", "en"),
        ("That will be 20.15€ sir.", "That will be twenty euro, fifteen cents sir.", "en"),
        ("That's 100,000.5.", "That's one hundred thousand point five.", "en"),
        # French
        ("En 12,5 secondes.", "En douze virgule cinq secondes.", "fr"),
        ("Il y avait 50 soldats.", "Il y avait cinquante soldats.", "fr"),
        ("Ceci est un 1er test", "Ceci est un premier test", "fr"),
        ("Cela vous fera $20 monsieur.", "Cela vous fera vingt dollars monsieur.", "fr"),
        ("Cela vous fera 20€ monsieur.", "Cela vous fera vingt euros monsieur.", "fr"),
        ("Cela vous fera 20,15€ monsieur.", "Cela vous fera vingt euros et quinze centimes monsieur.", "fr"),
        ("Ce sera 100.000,5.", "Ce sera cent mille virgule cinq.", "fr"),
        # German
        ("In 12,5 Sekunden.", "In zwölf Komma fünf Sekunden.", "de"),
        ("Es gab 50 Soldaten.", "Es gab fünfzig Soldaten.", "de"),
        ("Dies ist ein 1. Test", "Dies ist ein erste Test", "de"),  # Issue with gender
        ("Das macht $20 Herr.", "Das macht zwanzig Dollar Herr.", "de"),
        ("Das macht 20€ Herr.", "Das macht zwanzig Euro Herr.", "de"),
        ("Das macht 20,15€ Herr.", "Das macht zwanzig Euro und fünfzehn Cent Herr.", "de"),
        # Spanish
        ("En 12,5 segundos.", "En doce punto cinco segundos.", "es"),
        ("Había 50 soldados.", "Había cincuenta soldados.", "es"),
        ("Este es un 1er test", "Este es un primero test", "es"),
        ("Eso le costará $20 señor.", "Eso le costará veinte dólares señor.", "es"),
        ("Eso le costará 20€ señor.", "Eso le costará veinte euros señor.", "es"),
        ("Eso le costará 20,15€ señor.", "Eso le costará veinte euros con quince céntimos señor.", "es"),
        # Italian
        ("In 12,5 secondi.", "In dodici virgola cinque secondi.", "it"),
        ("C'erano 50 soldati.", "C'erano cinquanta soldati.", "it"),
        ("Questo è un 1° test", "Questo è un primo test", "it"),
        ("Ti costerà $20 signore.", "Ti costerà venti dollari signore.", "it"),
        ("Ti costerà 20€ signore.", "Ti costerà venti euro signore.", "it"),
        ("Ti costerà 20,15€ signore.", "Ti costerà venti euro e quindici centesimi signore.", "it"),
        # Portuguese
        ("Em 12,5 segundos.", "Em doze vírgula cinco segundos.", "pt"),
        ("Havia 50 soldados.", "Havia cinquenta soldados.", "pt"),
        ("Este é um 1º teste", "Este é um primeiro teste", "pt"),
        ("Isso custará $20 senhor.", "Isso custará vinte dólares senhor.", "pt"),
        ("Isso custará 20€ senhor.", "Isso custará vinte euros senhor.", "pt"),
        (
            "Isso custará 20,15€ senhor.",
            "Isso custará vinte euros e quinze cêntimos senhor.",
            "pt",
        ),  # "cêntimos" should be "centavos" num2words issue
        # Polish
        ("W 12,5 sekundy.", "W dwanaście przecinek pięć sekundy.", "pl"),
        ("Było 50 żołnierzy.", "Było pięćdziesiąt żołnierzy.", "pl"),
        ("To będzie kosztować 20€ panie.", "To będzie kosztować dwadzieścia euro panie.", "pl"),
        ("To będzie kosztować 20,15€ panie.", "To będzie kosztować dwadzieścia euro, piętnaście centów panie.", "pl"),
        # Arabic
        ("في الـ 12,5 ثانية.", "في الـ اثنا عشر , خمسون ثانية.", "ar"),
        ("كان هناك 50 جنديًا.", "كان هناك خمسون جنديًا.", "ar"),
        # ("ستكون النتيجة $20 يا سيد.", 'ستكون النتيجة عشرون دولار يا سيد.', 'ar'), # $ and € are mising from num2words
        # ("ستكون النتيجة 20€ يا سيد.", 'ستكون النتيجة عشرون يورو يا سيد.', 'ar'),
        # Czech
        ("Za 12,5 vteřiny.", "Za dvanáct celá pět vteřiny.", "cs"),
        ("Bylo tam 50 vojáků.", "Bylo tam padesát vojáků.", "cs"),
        ("To bude stát 20€ pane.", "To bude stát dvacet euro pane.", "cs"),
        ("To bude 20.15€ pane.", "To bude dvacet euro, patnáct centů pane.", "cs"),
        # Russian
        ("Через 12.5 секунды.", "Через двенадцать запятая пять секунды.", "ru"),
        ("Там было 50 солдат.", "Там было пятьдесят солдат.", "ru"),
        ("Это будет 20.15€ сэр.", "Это будет двадцать евро, пятнадцать центов сэр.", "ru"),
        ("Это будет стоить 20€ господин.", "Это будет стоить двадцать евро господин.", "ru"),
        # Dutch
        ("In 12,5 seconden.", "In twaalf komma vijf seconden.", "nl"),
        ("Er waren 50 soldaten.", "Er waren vijftig soldaten.", "nl"),
        ("Dat wordt dan $20 meneer.", "Dat wordt dan twintig dollar meneer.", "nl"),
        ("Dat wordt dan 20€ meneer.", "Dat wordt dan twintig euro meneer.", "nl"),
        # Chinese (Simplified)
        ("在12.5秒内", "在十二点五秒内", "zh"),
        ("有50名士兵", "有五十名士兵", "zh"),
        # ("那将是$20先生", '那将是二十美元先生', 'zh'), currency doesn't work
        # ("那将是20€先生", '那将是二十欧元先生', 'zh'),
        # Turkish
        # ("12,5 saniye içinde.", 'On iki virgül beş saniye içinde.', 'tr'), # decimal doesn't work for TR
        ("50 asker vardı.", "elli asker vardı.", "tr"),
        ("Bu 1. test", "Bu birinci test", "tr"),
        # ("Bu 100.000,5.", 'Bu yüz bin virgül beş.', 'tr'),
        # Hungarian
        ("12,5 másodperc alatt.", "tizenkettő egész öt tized másodperc alatt.", "hu"),
        ("50 katona volt.", "ötven katona volt.", "hu"),
        ("Ez az 1. teszt", "Ez az első teszt", "hu"),
        # Korean
        ("12.5 초 안에.", "십이 점 다섯 초 안에.", "ko"),
        ("50 명의 병사가 있었다.", "오십 명의 병사가 있었다.", "ko"),
        ("이것은 1 번째 테스트입니다", "이것은 첫 번째 테스트입니다", "ko"),
    ]
    for a, b, lang in test_cases:
        out = expand_numbers_multilingual(a, lang=lang)
        assert out == b, f"'{out}' vs '{b}'"
def test_abbreviations_multilingual():
    # Golden (input, expected, lang) triples for abbreviation expansion.
    test_cases = [
        # English
        ("Hello Mr. Smith.", "Hello mister Smith.", "en"),
        ("Dr. Jones is here.", "doctor Jones is here.", "en"),
        # Spanish
        ("Hola Sr. Garcia.", "Hola señor Garcia.", "es"),
        ("La Dra. Martinez es muy buena.", "La doctora Martinez es muy buena.", "es"),
        # French
        ("Bonjour Mr. Dupond.", "Bonjour monsieur Dupond.", "fr"),
        ("Mme. Moreau est absente aujourd'hui.", "madame Moreau est absente aujourd'hui.", "fr"),
        # German
        ("Frau Dr. Müller ist sehr klug.", "Frau doktor Müller ist sehr klug.", "de"),
        # Portuguese
        ("Olá Sr. Silva.", "Olá senhor Silva.", "pt"),
        ("Dra. Costa, você está disponível?", "doutora Costa, você está disponível?", "pt"),
        # Italian
        ("Buongiorno, Sig. Rossi.", "Buongiorno, signore Rossi.", "it"),
        # ("Sig.ra Bianchi, posso aiutarti?", 'signora Bianchi, posso aiutarti?', 'it'), # Issue with matching that pattern
        # Polish
        ("Dzień dobry, P. Kowalski.", "Dzień dobry, pani Kowalski.", "pl"),
        ("M. Nowak, czy mogę zadać pytanie?", "pan Nowak, czy mogę zadać pytanie?", "pl"),
        # Czech
        ("P. Novák", "pan Novák", "cs"),
        ("Dr. Vojtěch", "doktor Vojtěch", "cs"),
        # Dutch
        ("Dhr. Jansen", "de heer Jansen", "nl"),
        ("Mevr. de Vries", "mevrouw de Vries", "nl"),
        # Russian
        ("Здравствуйте Г-н Иванов.", "Здравствуйте господин Иванов.", "ru"),
        ("Д-р Смирнов здесь, чтобы увидеть вас.", "доктор Смирнов здесь, чтобы увидеть вас.", "ru"),
        # Turkish
        ("Merhaba B. Yılmaz.", "Merhaba bay Yılmaz.", "tr"),
        ("Dr. Ayşe burada.", "doktor Ayşe burada.", "tr"),
        # Hungarian
        ("Dr. Szabó itt van.", "doktor Szabó itt van.", "hu"),
    ]
    for a, b, lang in test_cases:
        out = expand_abbreviations_multilingual(a, lang=lang)
        assert out == b, f"'{out}' vs '{b}'"
def test_symbols_multilingual():
    # Golden (input, expected, lang) triples for symbol expansion.
    test_cases = [
        ("I have 14% battery", "I have 14 percent battery", "en"),
        ("Te veo @ la fiesta", "Te veo arroba la fiesta", "es"),
        ("J'ai 14° de fièvre", "J'ai 14 degrés de fièvre", "fr"),
        ("Die Rechnung beträgt £ 20", "Die Rechnung beträgt pfund 20", "de"),
        ("O meu email é ana&joao@gmail.com", "O meu email é ana e joao arroba gmail.com", "pt"),
        ("linguaggio di programmazione C#", "linguaggio di programmazione C cancelletto", "it"),
        ("Moja temperatura to 36.6°", "Moja temperatura to 36.6 stopnie", "pl"),
        ("Mám 14% baterie", "Mám 14 procento baterie", "cs"),
        ("Těším se na tebe @ party", "Těším se na tebe na party", "cs"),
        ("У меня 14% заряда", "У меня 14 процентов заряда", "ru"),
        ("Я буду @ дома", "Я буду собака дома", "ru"),
        ("Ik heb 14% batterij", "Ik heb 14 procent batterij", "nl"),
        ("Ik zie je @ het feest", "Ik zie je bij het feest", "nl"),
        ("لدي 14% في البطارية", "لدي 14 في المئة في البطارية", "ar"),
        ("我的电量为 14%", "我的电量为 14 百分之", "zh"),
        ("Pilim %14 dolu.", "Pilim yüzde 14 dolu.", "tr"),
        ("Az akkumulátorom töltöttsége 14%", "Az akkumulátorom töltöttsége 14 százalék", "hu"),
        ("배터리 잔량이 14%입니다.", "배터리 잔량이 14 퍼센트입니다.", "ko"),
    ]
    for a, b, lang in test_cases:
        out = expand_symbols_multilingual(a, lang=lang)
        assert out == b, f"'{out}' vs '{b}'"
# Running this module directly executes the self-tests defined above.
if __name__ == "__main__":
    test_expand_numbers_multilingual()
    test_abbreviations_multilingual()
    test_symbols_multilingual()

15535
models/lyrics_utils/vocab.json Normal file

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,8 @@
[
{
"id": 1,
"name": "Dump Singer 1",
"description": "This is the first singer preset",
"spk_emb_path": "path/to/singer1/spk_emb"
}
]

View File

@@ -0,0 +1,482 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, List, Union
import torch
import torch.nn.functional as F
from torch import nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.utils import BaseOutput, logging, is_torch_version
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.embeddings import TimestepEmbedding, Timesteps, get_2d_sincos_pos_embed
from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
try:
from .attention import LinearTransformerBlock, t2i_modulate
from .lyrics_utils.lyric_encoder import ConformerEncoder as LyricEncoder
except ImportError:
from attention import LinearTransformerBlock, t2i_modulate
from lyrics_utils.lyric_encoder import ConformerEncoder as LyricEncoder
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def cross_norm(hidden_states, controlnet_input):
    """Re-normalize *controlnet_input* so its per-sample mean/std (taken over
    the time and channel dims) match those of *hidden_states*.

    Both inputs are N x T x C; the output has controlnet_input's content with
    hidden_states' first- and second-order statistics.
    """
    stat_dims = (1, 2)
    ref_mean = hidden_states.mean(dim=stat_dims, keepdim=True)
    ref_std = hidden_states.std(dim=stat_dims, keepdim=True)
    src_mean = controlnet_input.mean(dim=stat_dims, keepdim=True)
    src_std = controlnet_input.std(dim=stat_dims, keepdim=True)
    # Whiten to the source stats, re-color with the reference stats; the
    # epsilon guards against a zero source standard deviation.
    rescaled = (controlnet_input - src_mean) * (ref_std / (src_std + 1e-12))
    return rescaled + ref_mean
# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Qwen2
class Qwen2RotaryEmbedding(nn.Module):
    # Rotary position embedding (RoPE) tables with a lazily grown cos/sin cache.
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        # Per channel-pair inverse frequencies: base^(-2i/dim).
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
        )
    def _set_cos_sin_cache(self, seq_len, device, dtype):
        # (Re)build the cached cos/sin tables for positions [0, seq_len).
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
        freqs = torch.outer(t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        # Grow the cache on demand when a longer sequence shows up.
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
        # Returns (cos, sin) tables of shape [seq_len, dim] in x's dtype.
        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )
class T2IFinalLayer(nn.Module):
    """
    The final layer of Sana.
    """
    # Applies adaLN-style shift/scale modulation conditioned on the timestep
    # embedding, projects each token to patch pixels, and un-patchifies the
    # sequence back to N x C x H x W, padding or cropping to the target width.
    def __init__(self, hidden_size, patch_size=[16, 1], out_channels=256):
        super().__init__()
        self.norm_final = nn.RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, patch_size[0] * patch_size[1] * out_channels, bias=True)
        # Learned base shift/scale, added to the timestep-conditioned ones.
        self.scale_shift_table = nn.Parameter(torch.randn(2, hidden_size) / hidden_size**0.5)
        self.out_channels = out_channels
        self.patch_size = patch_size
    def unpatchfy(
        self,
        hidden_states: torch.Tensor,
        width: int,
    ):
        # 4 unpatchify
        # Token sequence N x L x (p_h*p_w*C) -> image N x C x p_h x (L*p_w),
        # then right-pad with zeros or crop so the last dim equals `width`.
        new_height, new_width = 1, hidden_states.size(1)
        hidden_states = hidden_states.reshape(
            shape=(hidden_states.shape[0], new_height, new_width, self.patch_size[0], self.patch_size[1], self.out_channels)
        ).contiguous()
        hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
        output = hidden_states.reshape(
            shape=(hidden_states.shape[0], self.out_channels, new_height * self.patch_size[0], new_width * self.patch_size[1])
        ).contiguous()
        if width > new_width:
            output = torch.nn.functional.pad(output, (0, width - new_width, 0, 0), 'constant', 0)
        elif width < new_width:
            output = output[:, :, :, :width]
        return output
    def forward(self, x, t, output_length):
        # t is the per-sample conditioning vector (e.g. timestep embedding);
        # combine it with the learned table to get shift and scale.
        shift, scale = (self.scale_shift_table[None] + t[:, None]).chunk(2, dim=1)
        x = t2i_modulate(self.norm_final(x), shift, scale)
        x = self.linear(x)
        # unpatchify
        output = self.unpatchfy(x, output_length)
        return output
class PatchEmbed(nn.Module):
    """2D Image to Patch Embedding"""
    # Patchifies with a strided conv (kernel == stride == patch_size), then a
    # GroupNorm and a 1x1 conv up to `embed_dim`, and flattens to a sequence.
    def __init__(
        self,
        height=16,
        width=4096,
        patch_size=(16, 1),
        in_channels=8,
        embed_dim=1152,
        bias=True,
    ):
        super().__init__()
        patch_size_h, patch_size_w = patch_size
        self.early_conv_layers = nn.Sequential(
            nn.Conv2d(in_channels, in_channels*256, kernel_size=patch_size, stride=patch_size, padding=0, bias=bias),
            torch.nn.GroupNorm(num_groups=32, num_channels=in_channels*256, eps=1e-6, affine=True),
            nn.Conv2d(in_channels*256, embed_dim, kernel_size=1, stride=1, padding=0, bias=bias)
        )
        self.patch_size = patch_size
        # Grid size in patches; base_size is used by callers for positional
        # embedding scaling (presumably — confirm against usage).
        self.height, self.width = height // patch_size_h, width // patch_size_w
        self.base_size = self.width
    def forward(self, latent):
        # early convolutions, N x C x H x W -> N x 256 * sqrt(patch_size) x H/patch_size x W/patch_size
        latent = self.early_conv_layers(latent)
        latent = latent.flatten(2).transpose(1, 2)  # BCHW -> BNC
        return latent
@dataclass
class Transformer1DModelOutput(BaseOutput):
    # sample: the model's predicted latent tensor.
    # proj_losses: optional (ssl_name, loss) pairs from the SSL projection heads.
    sample: torch.FloatTensor
    proj_losses: Optional[Tuple[Tuple[str, torch.Tensor]]] = None
class ACEFlowBaseModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
_supports_gradient_checkpointing = True
    @register_to_config
    def __init__(
        self,
        in_channels: Optional[int] = 8,
        num_layers: int = 28,
        inner_dim: int = 1536,
        attention_head_dim: int = 64,
        num_attention_heads: int = 24,
        mlp_ratio: float = 4.0,
        out_channels: int = 8,
        max_position: int = 32768,
        rope_theta: float = 1000000.0,
        speaker_embedding_dim: int = 512,
        text_embedding_dim: int = 768,
        ssl_encoder_depths: List[int] = [9, 9],
        ssl_names: List[str] = ["mert", "m-hubert"],
        ssl_latent_dims: List[int] = [1024, 768],
        lyric_encoder_vocab_size: int = 6681,
        lyric_hidden_size: int = 1024,
        patch_size: List[int] = [16, 1],
        max_height: int = 16,
        max_width: int = 4096,
        **kwargs,
    ):
        super().__init__()
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        # NOTE: the `inner_dim` argument is overridden here — the effective
        # width is always num_attention_heads * attention_head_dim.
        inner_dim = num_attention_heads * attention_head_dim
        self.inner_dim = inner_dim
        self.out_channels = out_channels
        self.max_position = max_position
        self.patch_size = patch_size
        self.rope_theta = rope_theta
        # Rotary position embedding shared across attention layers.
        self.rotary_emb = Qwen2RotaryEmbedding(
            dim=self.attention_head_dim,
            max_position_embeddings=self.max_position,
            base=self.rope_theta,
        )
        # 2. Define input layers
        self.in_channels = in_channels
        # 3. Define transformers blocks
        self.transformer_blocks = nn.ModuleList(
            [
                LinearTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=self.num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    mlp_ratio=mlp_ratio,
                    add_cross_attention=True,
                    add_cross_attention_dim=self.inner_dim,
                )
                for i in range(self.config.num_layers)
            ]
        )
        self.num_layers = num_layers
        # Timestep conditioning: sinusoidal proj -> MLP -> 6-way adaLN params.
        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=self.inner_dim)
        self.t_block = nn.Sequential(nn.SiLU(), nn.Linear(self.inner_dim, 6 * self.inner_dim, bias=True))
        # speaker
        self.speaker_embedder = nn.Linear(speaker_embedding_dim, self.inner_dim)
        # genre
        self.genre_embedder = nn.Linear(text_embedding_dim, self.inner_dim)
        # lyric
        self.lyric_embs = nn.Embedding(lyric_encoder_vocab_size, lyric_hidden_size)
        self.lyric_encoder = LyricEncoder(input_size=lyric_hidden_size, static_chunk_size=0)
        self.lyric_proj = nn.Linear(lyric_hidden_size, self.inner_dim)
        # Projection heads mapping hidden states to each SSL latent space,
        # one MLP per entry in ssl_latent_dims (paired with ssl_names).
        projector_dim = 2 * self.inner_dim
        self.projectors = nn.ModuleList([
            nn.Sequential(
                nn.Linear(self.inner_dim, projector_dim),
                nn.SiLU(),
                nn.Linear(projector_dim, projector_dim),
                nn.SiLU(),
                nn.Linear(projector_dim, ssl_dim),
            ) for ssl_dim in ssl_latent_dims
        ])
        self.ssl_latent_dims = ssl_latent_dims
        self.ssl_encoder_depths = ssl_encoder_depths
        self.cosine_loss = torch.nn.CosineEmbeddingLoss(margin=0.0, reduction='mean')
        self.ssl_names = ssl_names
        self.proj_in = PatchEmbed(
            height=max_height,
            width=max_width,
            patch_size=patch_size,
            embed_dim=self.inner_dim,
            bias=True,
        )
        self.final_layer = T2IFinalLayer(self.inner_dim, patch_size=patch_size, out_channels=out_channels)
        self.gradient_checkpointing = False
# Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
"""
Sets the attention processor to use [feed forward
chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
Parameters:
chunk_size (`int`, *optional*):
The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
over each tensor of dim=`dim`.
dim (`int`, *optional*, defaults to `0`):
The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
or dim=1 (sequence length).
"""
if dim not in [0, 1]:
raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
# By default chunk size is 1
chunk_size = chunk_size or 1
def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
if hasattr(module, "set_chunk_feed_forward"):
module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
for child in module.children():
fn_recursive_feed_forward(child, chunk_size, dim)
for module in self.children():
fn_recursive_feed_forward(module, chunk_size, dim)
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
def forward_lyric_encoder(
self,
lyric_token_idx: Optional[torch.LongTensor] = None,
lyric_mask: Optional[torch.LongTensor] = None,
):
# N x T x D
lyric_embs = self.lyric_embs(lyric_token_idx)
prompt_prenet_out, _mask = self.lyric_encoder(lyric_embs, lyric_mask, decoding_chunk_size=1, num_decoding_left_chunks=-1)
prompt_prenet_out = self.lyric_proj(prompt_prenet_out)
return prompt_prenet_out
def encode(
self,
encoder_text_hidden_states: Optional[torch.Tensor] = None,
text_attention_mask: Optional[torch.LongTensor] = None,
speaker_embeds: Optional[torch.FloatTensor] = None,
lyric_token_idx: Optional[torch.LongTensor] = None,
lyric_mask: Optional[torch.LongTensor] = None,
):
bs = encoder_text_hidden_states.shape[0]
device = encoder_text_hidden_states.device
# speaker embedding
encoder_spk_hidden_states = self.speaker_embedder(speaker_embeds).unsqueeze(1)
speaker_mask = torch.ones(bs, 1, device=device)
# genre embedding
encoder_text_hidden_states = self.genre_embedder(encoder_text_hidden_states)
# lyric
encoder_lyric_hidden_states = self.forward_lyric_encoder(
lyric_token_idx=lyric_token_idx,
lyric_mask=lyric_mask,
)
encoder_hidden_states = torch.cat([encoder_spk_hidden_states, encoder_text_hidden_states, encoder_lyric_hidden_states], dim=1)
encoder_hidden_mask = torch.cat([speaker_mask, text_attention_mask, lyric_mask], dim=1)
return encoder_hidden_states, encoder_hidden_mask
    def decode(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        encoder_hidden_mask: torch.Tensor,
        timestep: Optional[torch.Tensor],
        ssl_hidden_states: Optional[List[torch.Tensor]] = None,
        output_length: int = 0,
        block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
        controlnet_scale: Union[float, torch.Tensor] = 1.0,
        return_dict: bool = True,
    ):
        """Run the diffusion transformer over the latent sequence and predict the sample.

        Args:
            hidden_states: Latent input; fed through `proj_in` (a PatchEmbed), so it
                is assumed to be (batch, channels, height, width) — TODO confirm.
            attention_mask: Self-attention mask for the latent token sequence.
            encoder_hidden_states: Conditioning sequence (see `encode()`).
            encoder_hidden_mask: Mask matching `encoder_hidden_states`.
            timestep: Diffusion timestep(s) driving the timestep modulation.
            ssl_hidden_states: Optional per-SSL-model target features; when present,
                cosine projection losses are computed from intermediate block outputs.
            output_length: Length handed to `final_layer` for the output.
            block_controlnet_hidden_states: Optional controlnet features, added to
                the patched input.
            controlnet_scale: Scale applied to the controlnet contribution.
            return_dict: When False, return a `(sample, proj_losses)` tuple.

        Returns:
            `Transformer1DModelOutput` (or a tuple) holding the predicted sample and
            a list of `(ssl_name, loss)` pairs (empty when no SSL targets are given).
        """
        # Timestep -> sinusoidal features -> embedding; `temb` modulates every block.
        embedded_timestep = self.timestep_embedder(self.time_proj(timestep).to(dtype=hidden_states.dtype))
        temb = self.t_block(embedded_timestep)
        # Patchify the latent into a token sequence.
        hidden_states = self.proj_in(hidden_states)
        # controlnet logic: additively inject cross-normalized controlnet features.
        if block_controlnet_hidden_states is not None:
            control_condi = cross_norm(hidden_states, block_controlnet_hidden_states)
            hidden_states = hidden_states + control_condi * controlnet_scale
        inner_hidden_states = []
        # Rotary position embeddings for the self- and cross-attention sequences.
        rotary_freqs_cis = self.rotary_emb(hidden_states, seq_len=hidden_states.shape[1])
        encoder_rotary_freqs_cis = self.rotary_emb(encoder_hidden_states, seq_len=encoder_hidden_states.shape[1])
        for index_block, block in enumerate(self.transformer_blocks):
            if self.training and self.gradient_checkpointing:
                # Wrap the block so torch.utils.checkpoint can replay it in backward.
                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)
                    return custom_forward
                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states=hidden_states,
                    attention_mask=attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_hidden_mask,
                    rotary_freqs_cis=rotary_freqs_cis,
                    rotary_freqs_cis_cross=encoder_rotary_freqs_cis,
                    temb=temb,
                    **ckpt_kwargs,
                )
            else:
                hidden_states = block(
                    hidden_states=hidden_states,
                    attention_mask=attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_hidden_mask,
                    rotary_freqs_cis=rotary_freqs_cis,
                    rotary_freqs_cis_cross=encoder_rotary_freqs_cis,
                    temb=temb,
                )
            # Stash activations at the depths tapped by the SSL projectors.
            for ssl_encoder_depth in self.ssl_encoder_depths:
                if index_block == ssl_encoder_depth:
                    inner_hidden_states.append(hidden_states)
        # Optional SSL feature-matching (REPA-style) losses.
        proj_losses = []
        if len(inner_hidden_states) > 0 and ssl_hidden_states is not None and len(ssl_hidden_states) > 0:
            # NOTE(review): zip silently truncates if these four sequences differ in
            # length — assumed index-aligned; verify at the call site.
            for inner_hidden_state, projector, ssl_hidden_state, ssl_name in zip(inner_hidden_states, self.projectors, ssl_hidden_states, self.ssl_names):
                if ssl_hidden_state is None:
                    continue
                # 1. N x T x D1 -> N x T x D2: project into the SSL feature space.
                est_ssl_hidden_state = projector(inner_hidden_state)
                # 3. projection loss, averaged over the batch.
                bs = inner_hidden_state.shape[0]
                proj_loss = 0.0
                for i, (z, z_tilde) in enumerate(zip(ssl_hidden_state, est_ssl_hidden_state)):
                    # 2. interpolate the estimate to the SSL target's time length.
                    z_tilde = F.interpolate(z_tilde.unsqueeze(0).transpose(1, 2), size=len(z), mode='linear', align_corners=False).transpose(1, 2).squeeze(0)
                    z_tilde = torch.nn.functional.normalize(z_tilde, dim=-1)
                    z = torch.nn.functional.normalize(z, dim=-1)
                    # T x d -> T x 1 -> 1: cosine loss with target 1 (maximize similarity).
                    target = torch.ones(z.shape[0], device=z.device)
                    proj_loss += self.cosine_loss(z, z_tilde, target)
                proj_losses.append((ssl_name, proj_loss / bs))
        # Un-patchify / project into the final prediction.
        output = self.final_layer(hidden_states, embedded_timestep, output_length)
        if not return_dict:
            return (output, proj_losses)
        return Transformer1DModelOutput(sample=output, proj_losses=proj_losses)
# @torch.compile
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
encoder_text_hidden_states: Optional[torch.Tensor] = None,
text_attention_mask: Optional[torch.LongTensor] = None,
speaker_embeds: Optional[torch.FloatTensor] = None,
lyric_token_idx: Optional[torch.LongTensor] = None,
lyric_mask: Optional[torch.LongTensor] = None,
timestep: Optional[torch.Tensor] = None,
ssl_hidden_states: Optional[List[torch.Tensor]] = None,
block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
controlnet_scale: Union[float, torch.Tensor] = 1.0,
return_dict: bool = True,
):
encoder_hidden_states, encoder_hidden_mask = self.encode(
encoder_text_hidden_states=encoder_text_hidden_states,
text_attention_mask=text_attention_mask,
speaker_embeds=speaker_embeds,
lyric_token_idx=lyric_token_idx,
lyric_mask=lyric_mask,
)
output_length = hidden_states.shape[-1]
output = self.decode(
hidden_states=hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_hidden_mask=encoder_hidden_mask,
timestep=timestep,
ssl_hidden_states=ssl_hidden_states,
output_length=output_length,
block_controlnet_hidden_states=block_controlnet_hidden_states,
controlnet_scale=controlnet_scale,
return_dict=return_dict,
)
return output

2
music_dcae/__init__.py Normal file
View File

@@ -0,0 +1,2 @@
from .music_dcae import MusicDCAE
from .music_log_mel import LogMelSpectrogram, get_mel_transform

92
music_dcae/balancer.py Normal file
View File

@@ -0,0 +1,92 @@
import torch
from torch.autograd import grad
class Balancer:
    """
    Balancer for dynamically re-weighting multiple losses based on gradient norms.

    Each loss's gradient w.r.t. a shared reference tensor is measured, its norm is
    tracked with an exponential moving average, and the losses are recombined with
    weights proportional to their normalized gradient norms — so losses with very
    different gradient scales do not drown each other out.

    Args:
        weights (dict): Predefined base weights for each loss.
            Example: {"mse_loss": 1.0, "adv_loss": 1.0}
        ema_decay (float): Decay factor for the exponential moving average (default: 0.99).
        epsilon (float): Small value to avoid division by zero (default: 1e-8).
    """

    def __init__(self, weights, ema_decay=0.99, epsilon=1e-8):
        self.weights = weights
        self.ema_decay = ema_decay
        self.epsilon = epsilon
        # Running EMA of each loss's gradient norm; 0.0 marks "not initialized yet".
        self.ema_values = {key: 0.0 for key in weights}

    def forward(self, losses, grad_inputs):
        """
        Re-weight the input losses based on gradient norms and return a combined loss.

        Args:
            losses (dict): Loss tensors keyed by name.
                Example: {"mse_loss": mse_loss, "adv_loss": adv_loss}
            grad_inputs (dict): Tensor to differentiate each loss against; must be on
                that loss's autograd graph.
                Example: {"mse_loss": recon_mels, "adv_loss": recon_mels}

        Returns:
            tuple: `(loss, log_weights)` — the combined weighted loss tensor and a
            dict of the dynamic weight applied to each loss, keyed "<name>_weight".

        Raises:
            ValueError: if `losses` and `grad_inputs` do not share the same keys.
        """
        # Validate inputs
        if set(losses.keys()) != set(grad_inputs.keys()):
            raise ValueError("Keys of losses and grad_inputs must match.")
        norm_values = {}
        # Compute gradient norms for each loss
        for name, loss in losses.items():
            # Keep the graph so the combined loss can still be backpropagated.
            loss_grad, = grad(loss.mean(), [grad_inputs[name]], create_graph=True)
            dims = tuple(range(1, loss_grad.ndim))  # Exclude batch dimension
            grad_norm = torch.linalg.vector_norm(loss_grad, ord=2, dim=dims).mean()
            # Lazily initialize the EMA on first sight of this loss, then update it.
            if self.ema_values[name] == 0.0:
                self.ema_values[name] = grad_norm.item()
            else:
                self.ema_values[name] = (
                    self.ema_values[name] * self.ema_decay + grad_norm.item() * (1 - self.ema_decay)
                )
            # Normalize the instantaneous norm by its running average.
            norm_values[name] = grad_norm / (self.ema_values[name] + self.epsilon)
        # Compute dynamic weights. Fix: epsilon guards the degenerate case where
        # every gradient is exactly zero; previously `norm / total_norm` produced
        # NaN weights (0/0) there.
        total_norm = sum(norm_values.values())
        dynamic_weights = {name: norm / (total_norm + self.epsilon) for name, norm in norm_values.items()}
        # Combine losses with dynamic weights
        loss = 0.0
        log_weights = {}
        for name in losses:
            loss = loss + self.weights[name] * dynamic_weights[name] * losses[name]
            log_weights[f"{name}_weight"] = dynamic_weights[name]
        return loss, log_weights
if __name__ == "__main__":
    # Smoke test: balance an MSE reconstruction loss against an adversarial loss,
    # both differentiated w.r.t. the same reconstructed mels.
    mel_real = torch.randn(1, 80, 10)
    generator = torch.nn.Linear(10, 10)
    recon_mels = generator(mel_real)
    discriminator = torch.nn.Linear(10, 1)
    disc_out = discriminator(recon_mels)

    example_losses = {
        "mse_loss": torch.nn.functional.mse_loss(recon_mels, mel_real).mean(),
        "adv_loss": torch.nn.functional.softplus(-disc_out).mean(),
    }
    example_grad_inputs = {"mse_loss": recon_mels, "adv_loss": recon_mels}
    print("losses", example_losses)

    # Equal base weights; the balancer adjusts them from gradient norms.
    balancer = Balancer({"mse_loss": 1.0, "adv_loss": 1.0})
    combined_loss, weight_log = balancer.forward(example_losses, example_grad_inputs)
    print("Combined Loss:", combined_loss)
    print("Dynamic Weights:", weight_log)

View File

@@ -0,0 +1,69 @@
{
"_class_name": "AutoencoderDC",
"_diffusers_version": "0.32.1",
"_name_or_path": "mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers",
"attention_head_dim": 32,
"decoder_act_fns": "silu",
"decoder_block_out_channels": [
128,
256,
512,
1024
],
"decoder_block_types": [
"ResBlock",
"ResBlock",
"ResBlock",
"EfficientViTBlock"
],
"decoder_layers_per_block": [
3,
3,
3,
3
],
"decoder_norm_types": "rms_norm",
"decoder_qkv_multiscales": [
[],
[],
[
5
],
[
5
]
],
"downsample_block_type": "Conv",
"encoder_block_out_channels": [
128,
256,
512,
1024
],
"encoder_block_types": [
"ResBlock",
"ResBlock",
"ResBlock",
"EfficientViTBlock"
],
"encoder_layers_per_block": [
2,
2,
3,
3
],
"encoder_qkv_multiscales": [
[],
[],
[
5
],
[
5
]
],
"in_channels": 2,
"latent_channels": 8,
"scaling_factor": 0.41407,
"upsample_block_type": "interpolate"
}

124
music_dcae/distrib.py Normal file
View File

@@ -0,0 +1,124 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Torch distributed utilities."""
import typing as tp
import torch
def rank():
    """Return this process's rank in the default group (0 when not distributed)."""
    if not torch.distributed.is_initialized():
        return 0
    return torch.distributed.get_rank()
def world_size():
    """Return the number of distributed workers (1 when not distributed)."""
    if not torch.distributed.is_initialized():
        return 1
    return torch.distributed.get_world_size()
def is_distributed():
    """True when running under torch.distributed with more than one worker."""
    return world_size() > 1
def all_reduce(tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM):
    """All-reduce `tensor` in place across workers; no-op when not distributed."""
    if not is_distributed():
        return None
    return torch.distributed.all_reduce(tensor, op)
def _is_complex_or_float(tensor):
return torch.is_floating_point(tensor) or torch.is_complex(tensor)
def _check_number_of_params(params: tp.List[torch.Tensor]):
    """Verify every worker holds the same number of params.

    Diverging counts would deadlock the collective ops that follow, so fail
    loudly instead.
    """
    if not is_distributed() or not params:
        return
    count = torch.tensor([len(params)], device=params[0].device, dtype=torch.long)
    all_reduce(count)
    # After the sum, every worker should see len(params) * world_size(); any
    # mismatch means at least one worker had a different count.
    if count.item() != len(params) * world_size():
        raise RuntimeError(f"Mismatch in number of params: ours is {len(params)}, "
                           "at least one worker has a different one.")
def broadcast_tensors(tensors: tp.Iterable[torch.Tensor], src: int = 0):
    """Broadcast the tensors from the given parameters to all workers.

    This can be used to ensure that all workers have the same model to start with.
    """
    if not is_distributed():
        return
    # Only float/complex tensors are broadcast; integer buffers are skipped.
    eligible = [t for t in tensors if _is_complex_or_float(t)]
    _check_number_of_params(eligible)
    # Launch every broadcast asynchronously, then wait on all of them.
    pending = [
        torch.distributed.broadcast(t.data, src=src, async_op=True)
        for t in eligible
    ]
    for handle in pending:
        handle.wait()
def sync_buffer(buffers, average=True):
    """
    Sync grad for buffers. If average is False, broadcast instead of averaging.

    Args:
        buffers: Iterable of buffer tensors; only floating-point ones are synced.
        average: If True, all-reduce (sum) and divide by the world size; if
            False, broadcast rank 0's values to everyone.
    """
    if not is_distributed():
        return
    handles = []
    for buffer in buffers:
        if torch.is_floating_point(buffer.data):
            if average:
                handle = torch.distributed.all_reduce(
                    buffer.data, op=torch.distributed.ReduceOp.SUM, async_op=True)
            else:
                handle = torch.distributed.broadcast(
                    buffer.data, src=0, async_op=True)
            handles.append((buffer, handle))
    for buffer, handle in handles:
        handle.wait()
        if average:
            # Fix: the original divided by the `world_size` *function object*
            # (missing call parentheses) — a TypeError at runtime. Compare the
            # correct `world_size()` call in sync_grad below.
            buffer.data /= world_size()
def sync_grad(params):
    """
    Simpler alternative to DistributedDataParallel, that doesn't rely
    on any black magic. For simple models it can also be as fast.
    Just call this on your model parameters after the call to backward!
    """
    if not is_distributed():
        return
    # Kick off one async all-reduce per parameter that actually has a gradient.
    in_flight = []
    for param in params:
        if param.grad is None:
            continue
        op = torch.distributed.all_reduce(
            param.grad.data, op=torch.distributed.ReduceOp.SUM, async_op=True)
        in_flight.append((param, op))
    # Wait for completion, then average the summed gradients.
    denom = world_size()
    for param, op in in_flight:
        op.wait()
        param.grad.data /= denom
def average_metrics(metrics: tp.Dict[str, float], count=1.):
    """Average a dictionary of metrics across all workers, using the optional
    `count` as unnormalized weight.
    """
    if not is_distributed():
        return metrics
    keys, values = zip(*metrics.items())
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Append the weight itself so, after the sum-reduce, the last entry is the
    # total weight to normalize by.
    packed = torch.tensor(list(values) + [1], device=device, dtype=torch.float32)
    packed *= count
    all_reduce(packed)
    averaged = (packed[:-1] / packed[-1]).cpu().tolist()
    return dict(zip(keys, averaged))

78
music_dcae/music_dcae.py Normal file
View File

@@ -0,0 +1,78 @@
import torch
import torch.nn as nn
from diffusers import AutoencoderDC
import json
DEFAULT_CONFIG_PATH = "/root/sag_train/music_dcae/config_f32c32_large.json"
class MusicDCAE(nn.Module):
    """Thin wrapper around diffusers' AutoencoderDC built from a JSON config file."""

    def __init__(self, config_path=DEFAULT_CONFIG_PATH):
        super(MusicDCAE, self).__init__()
        # Instantiate the autoencoder straight from the JSON config dict.
        with open(config_path) as f:
            cfg = json.load(f)
        self.dcae = AutoencoderDC(**cfg)

    def encode(self, x):
        """Encode input into the DC-AE latent."""
        return self.dcae.encode(x).latent

    def decode(self, latent):
        """Decode a latent back to sample space."""
        return self.dcae.decode(latent).sample

    def forward(self, x):
        """Full encode/decode round trip."""
        return self.dcae(x).sample

    def return_middle_layers(self):
        """Return the modules around the bottleneck (deepest encoder and earliest decoder stages)."""
        return [
            self.dcae.encoder.down_blocks[-1],
            self.dcae.encoder.conv_out,
            self.dcae.decoder.conv_in,
            self.dcae.decoder.up_blocks[0],
        ]

    def return_head_layers(self):
        """Return the output-side decoder modules."""
        return [self.dcae.decoder.up_blocks[-1], self.dcae.decoder.conv_out]
if __name__ == "__main__":
    # Smoke test with the f8c8 config: encode a stereo mel-like tensor and run a
    # full reconstruction pass.
    model = MusicDCAE("/root/sag_train/music_dcae/config_f8c8_large.json")
    x = torch.randn(1, 2, 128, 1024)
    latent = model.encode(x)
    print("latent shape: ", latent.shape)
    y = model(x)
    print("y", y.shape)
    total_params = sum(p.numel() for p in model.parameters())
    print(f"模型参数总数: {total_params / 1e6:.2f}M")
    # NOTE(review): commented-out mask handling and per-layer parameter-count
    # breakdowns (return_middle_layers / return_head_layers) were removed here
    # for clarity; recover them from version control if needed.

View File

@@ -0,0 +1,155 @@
import torch
import torch.nn as nn
from diffusers import AutoencoderDC
import torchaudio
import torchvision.transforms as transforms
import torchaudio
try:
from .music_log_mel import get_mel_transform
from .music_vocoder import ADaMoSHiFiGANV1
except ImportError:
from music_log_mel import get_mel_transform
from music_vocoder import ADaMoSHiFiGANV1
import os
# Repository root: two directory levels up from this file.
root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# Directory holding the pretrained DC-AE weights (diffusers `from_pretrained` layout).
DEFAULT_PRETRAINED_PATH = os.path.join(root_dir, "checkpoints", "music_dcae_f8c8")
# Single-file checkpoint for the HiFi-GAN vocoder.
VOCODER_PRETRAINED_PATH = os.path.join(root_dir, "checkpoints", "music_vocoder.pt")
class MusicDCAE(nn.Module):
    """Waveform <-> latent codec: log-mel + DC-AE encoder/decoder + HiFi-GAN vocoder.

    `encode()` resamples audio to 44.1 kHz, computes normalized log-mels, and
    encodes them to standardized DC-AE latents; `decode()` inverts the latents to
    mels and renders audio with the vocoder.
    """

    def __init__(self, pretrained_path=DEFAULT_PRETRAINED_PATH, encoder_only=False, source_sample_rate=None):
        super(MusicDCAE, self).__init__()
        dcae = AutoencoderDC.from_pretrained(pretrained_path)
        self.encoder_only = encoder_only
        self.mel_transform = get_mel_transform()
        # NOTE(review): both branches assign the encoder; `encoder_only` only
        # skips building the decoder and vocoder.
        if encoder_only:
            self.encoder = dcae.encoder
        else:
            self.encoder = dcae.encoder
            self.decoder = dcae.decoder
            self.vocoder = ADaMoSHiFiGANV1(VOCODER_PRETRAINED_PATH).eval()
        if source_sample_rate is None:
            source_sample_rate = 48000
        # Default resampler: source rate -> the 44.1 kHz mel/vocoder rate.
        self.resampler = torchaudio.transforms.Resample(source_sample_rate, 44100)
        # Maps mels from [0, 1] to [-1, 1] for the autoencoder.
        self.transform = transforms.Compose([
            transforms.Normalize(0.5, 0.5),
        ])
        # Log-mel dynamic range used for min-max normalization — presumably
        # empirical dataset statistics; verify provenance.
        self.min_mel_value = -11.0
        self.max_mel_value = 3.0
        # 1024 mel frames x 512-sample hop @ 44.1 kHz, expressed in 48 kHz samples.
        self.audio_chunk_size = int(round((1024 * 512 / 44100 * 48000)))
        self.mel_chunk_size = 1024
        # The DC-AE downsamples time 8x ("dimention" spelling kept: it is part of
        # the attribute name).
        self.time_dimention_multiple = 8
        self.latent_chunk_size = self.mel_chunk_size // self.time_dimention_multiple
        # Latent standardization constants; inverted in decode().
        self.scale_factor = 0.1786
        self.shift_factor = -1.9091

    def load_audio(self, audio_path):
        """Load audio from disk; returns (waveform [channels, T], sample_rate)."""
        audio, sr = torchaudio.load(audio_path)
        return audio, sr

    def forward_mel(self, audios):
        """Apply the log-mel transform per batch item and stack the results."""
        mels = []
        for i in range(len(audios)):
            image = self.mel_transform(audios[i])
            mels.append(image)
        mels = torch.stack(mels)
        return mels

    @torch.no_grad()
    def encode(self, audios, audio_lengths=None, sr=None):
        """Encode waveforms to standardized latents.

        Args:
            audios: N x 2 x T waveform batch (48 kHz unless `sr` is given).
            audio_lengths: Valid sample count per item; defaults to the full length.
            sr: Source sample rate; a temporary resampler is built when provided.

        Returns:
            (latents, latent_lengths): standardized latents and per-item lengths
            measured in latent frames.
        """
        if audio_lengths is None:
            audio_lengths = torch.tensor([audios.shape[2]] * audios.shape[0])
            audio_lengths = audio_lengths.to(audios.device)
        # audios: N x 2 x T, 48kHz
        device = audios.device
        dtype = audios.dtype
        if sr is None:
            sr = 48000
            resampler = self.resampler
        else:
            resampler = torchaudio.transforms.Resample(sr, 44100).to(device).to(dtype)
        audio = resampler(audios)
        # Pad so the mel length divides evenly by the 8x latent downsampling.
        max_audio_len = audio.shape[-1]
        if max_audio_len % (8 * 512) != 0:
            audio = torch.nn.functional.pad(audio, (0, 8 * 512 - max_audio_len % (8 * 512)))
        mels = self.forward_mel(audio)
        # Min-max normalize to [0, 1], then shift/scale to [-1, 1].
        mels = (mels - self.min_mel_value) / (self.max_mel_value - self.min_mel_value)
        mels = self.transform(mels)
        # Encode one item at a time (keeps peak memory low).
        latents = []
        for mel in mels:
            latent = self.encoder(mel.unsqueeze(0))
            latents.append(latent)
        latents = torch.cat(latents, dim=0)
        # Samples -> 44.1 kHz samples -> mel frames (/512) -> latent frames (/8).
        latent_lengths = (audio_lengths / sr * 44100 / 512 / self.time_dimention_multiple).long()
        # Standardize; decode() applies the exact inverse.
        latents = (latents - self.shift_factor) * self.scale_factor
        return latents, latent_lengths

    @torch.no_grad()
    def decode(self, latents, audio_lengths=None, sr=None):
        """Decode standardized latents to waveforms via mels and the vocoder.

        Args:
            latents: Latents as produced by `encode()`.
            audio_lengths: Optional per-item sample counts used to trim the output
                (trimmed waveforms are moved to CPU).
            sr: Target sample rate; output stays at 44.1 kHz when None.

        Returns:
            (sr, pred_wavs): the output sample rate and the decoded waveforms.
        """
        # Undo the standardization applied at the end of encode().
        latents = latents / self.scale_factor + self.shift_factor
        mels = []
        for latent in latents:
            mel = self.decoder(latent.unsqueeze(0))
            mels.append(mel)
        mels = torch.cat(mels, dim=0)
        # [-1, 1] -> [0, 1] -> raw log-mel range.
        mels = mels * 0.5 + 0.5
        mels = mels * (self.max_mel_value - self.min_mel_value) + self.min_mel_value
        bsz, channels, num_mel, mel_width = mels.shape
        # Vocode each item independently.
        pred_wavs = []
        for i in range(bsz):
            mel = mels[i]
            wav = self.vocoder.decode(mel).squeeze(1)
            pred_wavs.append(wav)
        pred_wavs = torch.stack(pred_wavs)
        if sr is not None:
            resampler = torchaudio.transforms.Resample(44100, sr).to(latents.device).to(latents.dtype)
            pred_wavs = [resampler(wav) for wav in pred_wavs]
        else:
            sr = 44100
        # Trim the padding introduced during encoding.
        if audio_lengths is not None:
            pred_wavs = [wav[:, :length].cpu() for wav, length in zip(pred_wavs, audio_lengths)]
        return sr, pred_wavs

    def forward(self, audios, audio_lengths=None, sr=None):
        """Full round trip: encode to latents, decode back to waveforms."""
        latents, latent_lengths = self.encode(audios=audios, audio_lengths=audio_lengths, sr=sr)
        sr, pred_wavs = self.decode(latents=latents, audio_lengths=audio_lengths, sr=sr)
        return sr, pred_wavs, latents, latent_lengths
if __name__ == "__main__":
    # Round-trip smoke test: load a file, encode to latents, decode back, save.
    audio, sr = torchaudio.load("/root/data/repo/gongjunmin/sag_train/orig2.wav")
    audio_lengths = torch.tensor([audio.shape[1]])
    audios = audio.unsqueeze(0)
    model = MusicDCAE()
    sr, pred_wavs, latents, latent_lengths = model(audios, audio_lengths, sr)
    print("reconstructed wavs: ", pred_wavs[0].shape)
    print("latents shape: ", latents.shape)
    print("latent_lengths: ", latent_lengths)
    print("sr: ", sr)
    torchaudio.save("/root/data/repo/gongjunmin/sag_train/reconstructed.wav", pred_wavs[0], sr)
    print("reconstructed wav saved to /root/data/repo/gongjunmin/sag_train/reconstructed.wav")

View File

@@ -0,0 +1,551 @@
from typing import Tuple, Union, Optional, Dict, Any
import torch
import torch.nn as nn
from diffusers.models.autoencoders.autoencoder_dc import DCUpBlock2d, get_block, RMSNorm, Decoder
from diffusers.models.transformers.sana_transformer import SanaTransformerBlock
from diffusers.models.embeddings import get_2d_sincos_pos_embed
from diffusers.models.normalization import AdaLayerNormSingle, RMSNorm
from diffusers.models.modeling_utils import ModelMixin
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.attention_processor import AttentionProcessor
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.utils import is_torch_version
from diffusers.models.unets import UNet2DModel
class Encoder(nn.Module):
    """DC-AE-style stage built from upsampling blocks.

    NOTE(review): despite the name, the topology mirrors a *decoder*
    (DCUpBlock2d upsamplers, stages iterated deepest-first) — presumably the
    expanding half of a latent adapter; confirm against the caller.
    When `block_out_channels` is empty the module is a pure pass-through.
    """
    def __init__(
        self,
        in_channels: int = 32,
        out_channels: int = 8,
        attention_head_dim: int = 32,
        block_out_channels: Tuple[int] = (512, 1024, 2048),
        layers_per_block: Tuple[int] = (3, 3, 3),
        block_type: str = "EfficientViTBlock",
        norm_type: str = "rms_norm",
        act_fn: str = "silu",
        qkv_multiscales: tuple = (5,),
    ):
        super(Encoder, self).__init__()
        num_blocks = len(block_out_channels)
        # "dump" (presumably "dummy") encoder: with no stages configured the
        # forward pass is the identity.
        self.dump_encoder = False
        if num_blocks == 0:
            self.dump_encoder = True
            return
        # Project the input straight to the widest (deepest) channel count.
        self.conv_in = nn.Conv2d(in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1)
        up_blocks = []
        # Build stages deepest-first, inserting at index 0 so `up_blocks` ends up
        # ordered shallow -> deep.
        for i, (out_channel, num_layers) in reversed(list(enumerate(zip(block_out_channels, layers_per_block)))):
            up_block_list = []
            # Every stage but the deepest begins by upsampling from the
            # next-deeper channel width.
            if i < num_blocks - 1 and num_layers > 0:
                upsample_block = DCUpBlock2d(
                    block_out_channels[i + 1],
                    out_channel,
                    interpolate=True,
                    shortcut=True,
                )
                up_block_list.append(upsample_block)
            for _ in range(num_layers):
                block = get_block(
                    block_type,
                    out_channel,
                    out_channel,
                    attention_head_dim=attention_head_dim,
                    norm_type=norm_type,
                    act_fn=act_fn,
                    qkv_mutliscales=qkv_multiscales,  # (sic) diffusers' get_block spells the kwarg this way
                )
                up_block_list.append(block)
            up_blocks.insert(0, nn.Sequential(*up_block_list))
        self.up_blocks = nn.ModuleList(up_blocks)
        self.norm_out = RMSNorm(block_out_channels[0], 1e-5, elementwise_affine=True, bias=True)
        self.conv_act = nn.ReLU()
        self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Run the up-block stack deepest-first; identity when no blocks were built."""
        if self.dump_encoder:
            return hidden_states
        hidden_states = self.conv_in(hidden_states)
        i = 0  # NOTE(review): counter is never used
        for up_block in reversed(self.up_blocks):
            hidden_states = up_block(hidden_states)
            i += 1
        # RMSNorm acts on the channel dim: move channels last, normalize, move back.
        hidden_states = self.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
        hidden_states = self.conv_act(hidden_states)
        hidden_states = self.conv_out(hidden_states)
        return hidden_states
class PatchEmbed(nn.Module):
    """
    2D Image to Patch Embedding with support for SD3-style positional-embedding cropping.

    Args:
        height (`int`, defaults to `16`): The height of the input feature map.
        width (`int`, defaults to `128`): The width of the input feature map.
        patch_size (`tuple[int, int]`, defaults to `(16, 1)`): Patch size as (height, width).
        in_channels (`int`, defaults to `16`): The number of input channels.
        embed_dim (`int`, defaults to `768`): The output dimension of the embedding.
        layer_norm (`bool`, defaults to `False`): Whether or not to use layer normalization.
        flatten (`bool`, defaults to `True`): Whether to flatten the output to (B, N, C).
        bias (`bool`, defaults to `True`): Whether or not to use bias in the projection.
        interpolation_scale (`float`, defaults to `1`): The scale of the interpolation.
        pos_embed_type (`str`, defaults to `"sincos"`): The type of positional embedding.
        pos_embed_max_size (`int`, defaults to `None`): Maximum positional-embedding
            grid size; setting it enables SD3-style center cropping.
    """

    def __init__(
        self,
        height=16,
        width=128,
        patch_size=(16, 1),
        in_channels=16,
        embed_dim=768,
        layer_norm=False,
        flatten=True,
        bias=True,
        interpolation_scale=1,
        pos_embed_type="sincos",
        pos_embed_max_size=None,  # For SD3 cropping
    ):
        super().__init__()
        num_patches = (height // patch_size[0]) * (width // patch_size[1])
        self.flatten = flatten
        self.layer_norm = layer_norm
        self.pos_embed_max_size = pos_embed_max_size
        # Non-overlapping patch projection: kernel == stride == patch_size.
        self.proj = nn.Conv2d(
            in_channels, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias
        )
        if layer_norm:
            self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
        else:
            self.norm = None
        self.patch_size = patch_size
        self.height, self.width = height // patch_size[0], width // patch_size[1]
        # NOTE(review): base_size divides height by the patch *width*; looks
        # intentional for the (16, 1) default but verify for other patch sizes.
        self.base_size = height // patch_size[1]
        self.interpolation_scale = interpolation_scale
        # Positional-embedding grid: the fixed max size (cropping mode) or a
        # square grid inferred from the number of patches.
        if pos_embed_max_size:
            grid_size = pos_embed_max_size
        else:
            grid_size = int(num_patches**0.5)
        if pos_embed_type is None:
            self.pos_embed = None
        elif pos_embed_type == "sincos":
            pos_embed = get_2d_sincos_pos_embed(
                embed_dim,
                grid_size,
                base_size=self.base_size,
                interpolation_scale=self.interpolation_scale,
                output_type="pt",
            )
            persistent = True if pos_embed_max_size else False
            self.register_buffer("pos_embed", pos_embed.float().unsqueeze(0), persistent=persistent)
        else:
            raise ValueError(f"Unsupported pos_embed_type: {pos_embed_type}")

    def cropped_pos_embed(self, height, width):
        """Center-crop positional embeddings to (height, width) for SD3 compatibility.

        Raises:
            ValueError: if `pos_embed_max_size` is unset or the requested grid exceeds it.
        """
        if self.pos_embed_max_size is None:
            raise ValueError("`pos_embed_max_size` must be set for cropping.")
        # Fix: `patch_size` is an (h, w) tuple — the original divided by the tuple
        # itself, raising TypeError the first time this path executed.
        height = height // self.patch_size[0]
        width = width // self.patch_size[1]
        if height > self.pos_embed_max_size:
            raise ValueError(
                f"Height ({height}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
            )
        if width > self.pos_embed_max_size:
            raise ValueError(
                f"Width ({width}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
            )
        top = (self.pos_embed_max_size - height) // 2
        left = (self.pos_embed_max_size - width) // 2
        spatial_pos_embed = self.pos_embed.reshape(1, self.pos_embed_max_size, self.pos_embed_max_size, -1)
        spatial_pos_embed = spatial_pos_embed[:, top : top + height, left : left + width, :]
        spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1])
        return spatial_pos_embed

    def forward(self, latent):
        """Project `latent` (B, C, H, W) into patch tokens and add positional embeddings."""
        if self.pos_embed_max_size is not None:
            height, width = latent.shape[-2:]
        else:
            height, width = latent.shape[-2] // self.patch_size[0], latent.shape[-1] // self.patch_size[1]
        latent = self.proj(latent)
        if self.flatten:
            latent = latent.flatten(2).transpose(1, 2)  # BCHW -> BNC
        if self.layer_norm:
            latent = self.norm(latent)
        if self.pos_embed is None:
            return latent.to(latent.dtype)
        # Interpolate or crop positional embeddings as needed.
        if self.pos_embed_max_size:
            pos_embed = self.cropped_pos_embed(height, width)
        else:
            if self.height != height or self.width != width:
                # Grid mismatch: regenerate embeddings for the actual token grid.
                pos_embed = get_2d_sincos_pos_embed(
                    embed_dim=self.pos_embed.shape[-1],
                    grid_size=(height, width),
                    base_size=self.base_size,
                    interpolation_scale=self.interpolation_scale,
                    device=latent.device,
                    output_type="pt",
                )
                pos_embed = pos_embed.float().unsqueeze(0)
            else:
                pos_embed = self.pos_embed
        return (latent + pos_embed).to(latent.dtype)
class DiTDecoder(ModelMixin, ConfigMixin):
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
sample_size: Tuple[int, int] = (16, 128),
in_channels: int = 16,
out_channels: int = 8,
patch_size: Tuple[int, int] = (16, 1),
inner_dim: int = 1152,
num_attention_heads: int = 36,
attention_head_dim: int = 32,
dropout: float = 0.0,
cross_attention_dim: Optional[int] = None,
num_cross_attention_heads: Optional[int] = None,
cross_attention_head_dim: Optional[int] = None,
attention_bias: bool = False,
norm_elementwise_affine: bool = False,
norm_eps: float = 1e-6,
interpolation_scale: int = 1,
mlp_ratio: float = 2.5,
num_layers: int = 12,
):
super(DiTDecoder, self).__init__()
interpolation_scale = interpolation_scale if interpolation_scale is not None else max(sample_size // 64, 1)
self.interpolation_scale = interpolation_scale
self.patch_embed = PatchEmbed(
height=sample_size[0],
width=sample_size[1],
patch_size=patch_size,
in_channels=in_channels,
embed_dim=inner_dim,
interpolation_scale=interpolation_scale,
)
self.time_embed = AdaLayerNormSingle(inner_dim)
self.transformer_blocks = nn.ModuleList(
[
SanaTransformerBlock(
inner_dim,
num_attention_heads,
attention_head_dim,
dropout=dropout,
num_cross_attention_heads=num_cross_attention_heads,
cross_attention_head_dim=cross_attention_head_dim,
cross_attention_dim=cross_attention_dim,
attention_bias=attention_bias,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
mlp_ratio=mlp_ratio,
)
for _ in range(num_layers)
]
)
self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim ** 0.5)
self.norm_out = nn.LayerNorm(inner_dim, eps=1e-6, elementwise_affine=False)
self.proj_out = nn.Linear(inner_dim, patch_size[0] * patch_size[1] * out_channels)
self.gradient_checkpointing = False
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor()
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        Sets the attention processor to use to compute attention.
        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.
                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.
        """
        count = len(self.attn_processors.keys())
        # A dict must supply exactly one processor per attention layer.
        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )
        # Depth-first walk mirroring `attn_processors`: a single processor is
        # shared by all layers; a dict is consumed by popping the entry whose
        # key matches the module's dotted path.
        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))
            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)
    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: Optional[int] = None,
        return_dict: bool = True,
    ):
        """Denoise a batch of latents.

        Args:
            hidden_states: Input latents of shape (batch, channels, height, width).
            timestep: Diffusion timestep consumed by the AdaLN-single embedder.
            return_dict: When True return a `Transformer2DModelOutput`,
                otherwise a one-element tuple.
        """
        # 1. Input
        batch_size, num_channels, height, width = hidden_states.shape
        patch_size = self.config.patch_size
        post_patch_height, post_patch_width = height // patch_size[0], width // patch_size[1]
        hidden_states = self.patch_embed(hidden_states)
        timestep, embedded_timestep = self.time_embed(
            timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype
        )
        # 2. Transformer blocks
        if torch.is_grad_enabled() and self.gradient_checkpointing:
            # Gradient checkpointing: recompute block activations during the
            # backward pass to trade compute for memory.
            def create_custom_forward(module, return_dict=None):
                def custom_forward(*inputs):
                    if return_dict is not None:
                        return module(*inputs, return_dict=return_dict)
                    else:
                        return module(*inputs)
                return custom_forward
            # use_reentrant=False only exists from torch 1.11 onwards.
            ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
            for block in self.transformer_blocks:
                # The three None placeholders fill the block's (unused here)
                # attention-mask / encoder-hidden-states / encoder-mask slots.
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    None,
                    None,
                    None,
                    timestep,
                    post_patch_height,
                    post_patch_width,
                    **ckpt_kwargs,
                )
        else:
            for block in self.transformer_blocks:
                hidden_states = block(
                    hidden_states,
                    None,
                    None,
                    None,
                    timestep,
                    post_patch_height,
                    post_patch_width,
                )
        # 3. Normalization
        # AdaLN-single: the learned table plus the timestep embedding yields
        # per-sample (shift, scale). NOTE(review): self.norm_out is defined in
        # __init__ but never applied before the modulation below — confirm
        # this matches the training graph before "fixing" it.
        shift, scale = (
            self.scale_shift_table[None] + embedded_timestep[:, None].to(self.scale_shift_table.device)
        ).chunk(2, dim=1)
        # 4. Modulation
        hidden_states = hidden_states * (1 + scale) + shift
        hidden_states = self.proj_out(hidden_states)
        # 5. Unpatchify
        # (B, Hp*Wp, p0*p1*C) -> (B, Hp, Wp, p0, p1, C) -> (B, C, Hp, p0, Wp, p1)
        hidden_states = hidden_states.reshape(
            batch_size, post_patch_height, post_patch_width, self.config.patch_size[0], self.config.patch_size[1], -1
        )
        hidden_states = hidden_states.permute(0, 5, 1, 3, 2, 4)
        output = hidden_states.reshape(batch_size, -1, post_patch_height * patch_size[0], post_patch_width * patch_size[1])
        if not return_dict:
            return (output,)
        return Transformer2DModelOutput(sample=output)
class MusicDcaeRefiner(ModelMixin, ConfigMixin):
    """Latent refiner for music DC-AE latents.

    A DC-AE style `Encoder` compresses the conditioning signal; the result is
    concatenated channel-wise with the noisy input (hence the decoder's
    `in_channels = out_channels * 2`) and denoised by either a DiT decoder or
    a plain UNet, selected via `decoder_type`.
    """
    @register_to_config
    def __init__(
        self,
        in_channels: int = 32,
        attention_head_dim: int = 32,
        block_out_channels: Tuple[int] = (512, 1024, 2048),
        layers_per_block: Tuple[int] = (3, 3, 3),
        conv_block_out_channels: Tuple[int] = (224, 448, 672, 896),
        out_channels: int = 8,
        block_type: str = "EfficientViTBlock",
        norm_type: str = "rms_norm",
        act_fn: str = "silu",
        qkv_multiscales: tuple = (5,),
        sample_size: Tuple[int, int] = (16, 128),
        patch_size: Tuple[int, int] = (16, 1),
        inner_dim: int = 1152,
        num_attention_heads: int = 36,
        dropout: float = 0.0,
        cross_attention_dim: Optional[int] = None,
        num_cross_attention_heads: Optional[int] = None,
        cross_attention_head_dim: Optional[int] = None,
        attention_bias: bool = False,
        norm_elementwise_affine: bool = False,
        norm_eps: float = 1e-6,
        interpolation_scale: int = 1,
        mlp_ratio: float = 2.5,
        num_layers: int = 12,
        decoder_type: str = "ConvDecoder",
    ):
        super(MusicDcaeRefiner, self).__init__()
        # Conditioning encoder (DC-AE style efficient ViT/conv stack).
        self.encoder = Encoder(
            in_channels=in_channels,
            out_channels=out_channels,
            attention_head_dim=attention_head_dim,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            block_type=block_type,
            norm_type=norm_type,
            act_fn=act_fn,
            qkv_multiscales=qkv_multiscales,
        )
        # Decoder choice: transformer (DiT) vs. convolutional UNet. Both take
        # the channel-concatenated [noisy, condition] pair, so in_channels is
        # doubled.
        if decoder_type == "DiTDecoder":
            self.decoder = DiTDecoder(
                sample_size=sample_size,
                in_channels=out_channels * 2,
                out_channels=out_channels,
                patch_size=patch_size,
                inner_dim=inner_dim,
                num_attention_heads=num_attention_heads,
                attention_head_dim=attention_head_dim,
                dropout=dropout,
                cross_attention_dim=cross_attention_dim,
                num_cross_attention_heads=num_cross_attention_heads,
                cross_attention_head_dim=cross_attention_head_dim,
                attention_bias=attention_bias,
                norm_elementwise_affine=norm_elementwise_affine,
                norm_eps=norm_eps,
                interpolation_scale=interpolation_scale,
                mlp_ratio=mlp_ratio,
                num_layers=num_layers,
            )
        else:
            self.decoder = UNet2DModel(
                sample_size=sample_size,
                in_channels=out_channels * 2,
                out_channels=out_channels,
                block_out_channels=conv_block_out_channels,
            )
    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        timestep: Optional[int] = None,
        return_dict: bool = True
    ):
        """Refine `hidden_states` conditioned on `encoder_hidden_states`.

        The condition is encoded, concatenated on the channel axis, and the
        pair is denoised by the decoder at the given `timestep`.
        """
        encoder_hidden_states = self.encoder(encoder_hidden_states)
        hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
        output = self.decoder(hidden_states, timestep=timestep, return_dict=return_dict)
        return output
if __name__ == "__main__":
    # Smoke test. The first (commented-out) variant exercises the
    # f32c32 -> f8c8 configuration; the active path loads the f8c8 -> mel
    # refiner config and runs one forward pass on random tensors.
    # f32c32 -> f8c8
    # model = MusicDcaeRefiner()
    # x = torch.randn(1, 8, 16, 128)
    # encoder_x = torch.randn(1, 32, 4, 32)
    # timestep = 0
    # y = model(x, encoder_x, timestep=timestep)
    # print("y", y.sample.shape)
    # total_params = sum(p.numel() for p in model.parameters())
    # print(f"模型参数总数: {total_params / 1e6:.2f}M")
    # # Parameter counts for encoder and decoder separately.
    # encoder_params_count = sum(p.numel() for p in model.encoder.parameters())
    # decoder_params_count = sum(p.numel() for p in model.decoder.parameters())
    # print(f"encoder参数量: {encoder_params_count/1e6:.2f}M")
    # print(f"decoder参数量: {decoder_params_count/1e6:.2f}M")
    # f8c8 -> mel
    import json
    with open("music_dcae/config_f8c8_to_mel_refiner.json", "r") as f:
        config = json.load(f)
    model = MusicDcaeRefiner(**config)
    x = torch.randn(1, 2, 128, 1024)
    encoder_x = torch.randn(1, 2, 128, 1024)
    timestep = 0
    y = model(x, encoder_x, timestep=timestep)
    print("y", y.sample.shape)
    total_params = sum(p.numel() for p in model.parameters())
    print(f"模型参数总数: {total_params / 1e6:.2f}M")
    # Parameter counts for encoder and decoder separately.
    encoder_params_count = sum(p.numel() for p in model.encoder.parameters())
    decoder_params_count = sum(p.numel() for p in model.decoder.parameters())
    print(f"encoder参数量: {encoder_params_count/1e6:.2f}M")
    print(f"decoder参数量: {decoder_params_count/1e6:.2f}M")

View File

@@ -0,0 +1,157 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers import AutoencoderDC
import json
import torchvision.transforms as transforms
import torchaudio
try:
from .music_vocoder import ADaMoSHiFiGANV1
except ImportError:
from music_vocoder import ADaMoSHiFiGANV1
# NOTE(review): absolute, machine-specific default paths — these only resolve
# on the original training box; consider making them configurable.
DEFAULT_CONFIG_PATH = "/root/sag_train/music_dcae/config_f32c32_large.json"
DCAE_PRETRAINED_PATH = "/root/sag_train/checkpoints/music_dcae_f32c32"
VOCODER_PRETRAINED_PATH = "/root/sag_train/checkpoints/music_vocoder.pt"
class MusicDCAEVocoder(nn.Module):
    """DC-AE mel autoencoder paired with a HiFi-GAN style vocoder.

    Pipeline: waveform -> log-mel -> normalize to [-1, 1] -> DC-AE latent,
    and the reverse for decoding. The vocoder is frozen on construction.
    """
    def __init__(self, config_path=DEFAULT_CONFIG_PATH, pretrained_path=DCAE_PRETRAINED_PATH):
        super(MusicDCAEVocoder, self).__init__()
        # Build the DC-AE from a JSON config, or load pretrained weights.
        if pretrained_path is None:
            with open(config_path) as f:
                config = json.load(f)
            self.dcae = AutoencoderDC(**config)
        else:
            self.dcae = AutoencoderDC.from_pretrained(pretrained_path)
        self.vocoder = ADaMoSHiFiGANV1(VOCODER_PRETRAINED_PATH)
        self.freeze_vocoder()
        # Maps [0, 1] -> [-1, 1] (Normalize(mean=0.5, std=0.5)).
        self.transform = transforms.Compose([
            transforms.Normalize(0.5, 0.5),
        ])
        # Fixed log-mel dynamic range used for normalization.
        # NOTE(review): presumably dataset statistics — confirm their origin.
        self.min_mel_value = -11.0
        self.max_mel_value = 3.0
        self.target_sr = 44100
    def load_audio(self, audio_path):
        """Load audio as a (channels, T) tensor, duplicating mono to stereo."""
        audio, sr = torchaudio.load(audio_path)
        if audio.shape[0] == 1:
            audio = torch.cat([audio, audio], dim=0)
        return audio, sr
    def resample_audio(self, audio, sr=48000):
        """Resample from `sr` to the model's 44.1 kHz target rate."""
        resampler = torchaudio.transforms.Resample(sr, self.target_sr)
        resampler = resampler.to(audio.device)
        audio = resampler(audio)
        return audio
    def forward_mel(self, audios):
        """Compute raw log-mels per batch item via the vocoder's front end."""
        mels = []
        for i in range(len(audios)):
            image = self.vocoder.mel_transform(audios[i])
            mels.append(image)
        mels = torch.stack(mels)
        return mels
    def norm_mel(self, mels):
        """Scale raw log-mels into [-1, 1] using the fixed dynamic range."""
        normed_mels = (mels - self.min_mel_value) / (self.max_mel_value - self.min_mel_value)
        normed_mels = self.transform(normed_mels)
        return normed_mels
    def denorm_mel(self, normed_mels):
        """Inverse of `norm_mel`: [-1, 1] back to the raw log-mel range."""
        mels = normed_mels * 0.5 + 0.5
        mels = mels * (self.max_mel_value - self.min_mel_value) + self.min_mel_value
        return mels
    def encode_latent(self, normed_mels):
        # N x 2 x 128 x W -> N x C x 128//F x W//F
        latent = self.dcae.encode(normed_mels).latent
        return latent
    def decode_mel(self, latent):
        # N x C x 128//F x W//F -> N x 2 x 128 x W
        normed_mels = self.dcae.decode(latent).sample
        return normed_mels
    def decode_audio(self, mels):
        """Vocode stereo mels by folding the channel dim into the batch dim."""
        # mels: N x 2 x 128 x W -> 2N x 128 x W
        bs = mels.shape[0]
        mono_mels = mels.reshape(-1, 128, mels.shape[-1])
        mono_audios = self.vocoder(mono_mels)
        audios = mono_audios.reshape(bs, 2, -1)
        return audios
    def encode(self, audios):
        """Waveforms -> (DC-AE latent, raw log-mels)."""
        mels = self.forward_mel(audios)
        normed_mels = self.norm_mel(mels)
        latent = self.encode_latent(normed_mels)
        return latent, mels
    def decode(self, latent):
        """Latent -> (reconstructed waveforms, reconstructed log-mels)."""
        recon_normed_mels = self.decode_mel(latent)
        recon_mels = self.denorm_mel(recon_normed_mels)
        recon_audios = self.decode_audio(recon_mels)
        return recon_audios, recon_mels
    def forward(self, audios):
        """Full round trip; output is trimmed/padded to the input length."""
        audios_len = audios.shape[-1]
        latent, mels = self.encode(audios)
        recon_audios, recon_mels = self.decode(latent)
        if recon_audios.shape[-1] > audios_len:
            recon_audios = recon_audios[:, :, :audios_len]
        elif recon_audios.shape[-1] < audios_len:
            recon_audios = F.pad(recon_audios, (0, audios_len - recon_audios.shape[-1]))
        return recon_audios, mels, recon_mels, latent
    def freeze_vocoder(self):
        """Put the vocoder in eval mode and stop its gradients."""
        self.vocoder.eval()
        self.vocoder.requires_grad_(False)
    def unfreeze_vocoder(self):
        """Re-enable vocoder training and gradients."""
        self.vocoder.train()
        self.vocoder.requires_grad_(True)
    def return_middle_layers(self):
        """Layers around the DC-AE bottleneck (e.g. for partial fine-tuning)."""
        last_down_block = self.dcae.encoder.down_blocks[-1]
        encoder_conv_out = self.dcae.encoder.conv_out
        decoder_conv_in = self.dcae.decoder.conv_in
        decoder_up_blocks = self.dcae.decoder.up_blocks[0]
        middle_layers = [last_down_block, encoder_conv_out, decoder_conv_in, decoder_up_blocks]
        return middle_layers
    def return_head_layers(self):
        """Final DC-AE decoder layers (the output head)."""
        decoder_up_blocks = self.dcae.decoder.up_blocks[-1]
        conv_out = self.dcae.decoder.conv_out
        head_layers = [decoder_up_blocks, conv_out]
        return head_layers
if __name__ == "__main__":
    # Smoke test: round-trip a wav file through mel -> DCAE latent -> mel -> waveform.
    model = MusicDCAEVocoder()
    audio_path = "/root/sag_train/orig2.wav"
    audio, sr = model.load_audio(audio_path)
    audio = model.resample_audio(audio, sr)
    model.eval()
    model = model.to("cuda:0")
    audio = audio.to("cuda:0")
    with torch.no_grad():
        audios_len = audio.shape[-1]
        # Mel hop is 512 and the DC-AE downsamples time by 32, so the
        # waveform must be a multiple of 512 * 32 samples.
        min_frame = 512 * 32
        if audios_len % min_frame != 0:
            # Bug fix: the original built a 3-D `padding` against a 2-D
            # (channels, T) `audio`, referenced an undefined `audios` tensor
            # for the device (NameError), and then discarded the padded
            # result. Pad `audio` itself with a matching-rank tensor on the
            # same device and actually use it.
            pad_len = min_frame - audios_len % min_frame
            padding = torch.zeros(audio.shape[0], pad_len, device=audio.device)
            audio = torch.cat([audio, padding], dim=-1)
        recon_audios, mels, recon_mels, latent = model(audio.unsqueeze(0))
        # Trim the reconstruction back to the original (pre-padding) length.
        recon_audios = recon_audios[:, :, :audios_len]
        print("latent shape: ", latent.shape)
        print("recon_audios", recon_audios.shape)
        print("mels", mels.shape, "min:", mels.min(), "max:", mels.max(), "mean:", mels.mean(), "std:", mels.std())
        print("recon_mels", recon_mels.shape, "min:", recon_mels.min(), "max:", recon_mels.max(), "mean:", recon_mels.mean(), "std:", recon_mels.std())
        total_params = sum(p.numel() for p in model.parameters())
        print(f"模型参数总数: {total_params / 1e6:.2f}M")
        torchaudio.save("/root/sag_train/recon2.wav", recon_audios[0].cpu(), 44100)

119
music_dcae/music_log_mel.py Executable file
View File

@@ -0,0 +1,119 @@
import torch
import torch.nn as nn
from torch import Tensor
from torchaudio.transforms import MelScale
class LinearSpectrogram(nn.Module):
    """Magnitude STFT front end.

    Reflect-pads the waveform so frames align with the reference
    implementation, runs the STFT in float32, and returns the magnitude
    spectrogram cast back to the input dtype.
    """

    def __init__(
        self,
        n_fft=2048,
        win_length=2048,
        hop_length=512,
        center=False,
        mode="pow2_sqrt",
    ):
        super().__init__()
        self.n_fft = n_fft
        self.win_length = win_length
        self.hop_length = hop_length
        self.center = center
        self.mode = mode
        # Hann window as a buffer so it follows .to(device) / .cuda().
        self.register_buffer("window", torch.hann_window(win_length))

    def forward(self, y: Tensor) -> Tensor:
        # Accept (B, 1, T) as well as (B, T).
        if y.ndim == 3:
            y = y.squeeze(1)
        pad_left = (self.win_length - self.hop_length) // 2
        pad_right = (self.win_length - self.hop_length + 1) // 2
        y = torch.nn.functional.pad(
            y.unsqueeze(1), (pad_left, pad_right), mode="reflect"
        ).squeeze(1)
        in_dtype = y.dtype
        stft_out = torch.stft(
            y.float(),
            self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
            window=self.window,
            center=self.center,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True,
        )
        spec = torch.view_as_real(stft_out)
        if self.mode == "pow2_sqrt":
            # Magnitude with a small epsilon for numerical stability.
            spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
        return spec.to(in_dtype)
class LogMelSpectrogram(nn.Module):
    """Log-compressed mel spectrogram built on :class:`LinearSpectrogram`."""

    def __init__(
        self,
        sample_rate=44100,
        n_fft=2048,
        win_length=2048,
        hop_length=512,
        n_mels=128,
        center=False,
        f_min=0.0,
        f_max=None,
    ):
        super().__init__()
        self.sample_rate = sample_rate
        self.n_fft = n_fft
        self.win_length = win_length
        self.hop_length = hop_length
        self.center = center
        self.n_mels = n_mels
        self.f_min = f_min
        # Default the upper band edge to Nyquist.
        self.f_max = f_max or sample_rate // 2
        self.spectrogram = LinearSpectrogram(n_fft, win_length, hop_length, center)
        # Slaney-style mel filter bank with Slaney area normalization.
        self.mel_scale = MelScale(
            self.n_mels,
            self.sample_rate,
            self.f_min,
            self.f_max,
            self.n_fft // 2 + 1,
            "slaney",
            "slaney",
        )

    def compress(self, x: Tensor) -> Tensor:
        # Floor at 1e-5 so log() never sees zero energy.
        return torch.clamp(x, min=1e-5).log()

    def decompress(self, x: Tensor) -> Tensor:
        # Inverse of `compress` (up to the clamp floor).
        return x.exp()

    def forward(self, x: Tensor, return_linear: bool = False) -> Tensor:
        linear_spec = self.spectrogram(x)
        mel = self.compress(self.mel_scale(linear_spec))
        if return_linear:
            return mel, self.compress(linear_spec)
        return mel
def get_mel_transform():
    """Factory for the project-standard 44.1 kHz, 128-bin log-mel front end."""
    mel_kwargs = dict(
        sample_rate=44100,
        n_fft=2048,
        win_length=2048,
        hop_length=512,
        f_min=40,
        f_max=16000,
        n_mels=128,
    )
    return LogMelSpectrogram(**mel_kwargs)

565
music_dcae/music_vocoder.py Executable file
View File

@@ -0,0 +1,565 @@
import librosa
import torch
from torch import nn
from functools import partial
from math import prod
from typing import Callable, Tuple, List
import numpy as np
import torch.nn.functional as F
from torch.nn import Conv1d
from torch.nn.utils import weight_norm
from torch.nn.utils.parametrize import remove_parametrizations as remove_weight_norm
try:
from music_log_mel import LogMelSpectrogram
except ImportError:
from .music_log_mel import LogMelSpectrogram
def drop_path(
    x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
):
    """Stochastic depth: randomly zero whole samples of a residual branch.

    During evaluation, or when ``drop_prob`` is zero, the input is returned
    unchanged. Otherwise each sample in the batch is kept with probability
    ``1 - drop_prob``; kept samples are rescaled by ``1 / keep_prob`` when
    ``scale_by_keep`` is set so the expected activation is preserved.
    (Same scheme as timm's DropPath / "Drop Connect" variant.)
    """
    if not training or drop_prob == 0.0:
        return x
    keep_prob = 1 - drop_prob
    # One Bernoulli draw per sample, broadcast over all remaining dims.
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
    if scale_by_keep and keep_prob > 0.0:
        mask.div_(keep_prob)
    return x * mask
class DropPath(nn.Module):
    """Module wrapper around :func:`drop_path` (stochastic depth per sample)."""

    def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
        super().__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        # Delegates to the functional form; active only in training mode.
        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)

    def extra_repr(self):
        return "drop_prob={:0.3f}".format(round(self.drop_prob, 3))
class LayerNorm(nn.Module):
    r"""LayerNorm supporting channels_last and channels_first layouts.

    ``channels_last`` expects inputs shaped (..., C) and defers to
    ``F.layer_norm``; ``channels_first`` expects (N, C, ...) and normalizes
    over dim 1 manually. In this file it is applied to (N, C, L) sequences.
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        if data_format not in ("channels_last", "channels_first"):
            raise NotImplementedError
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(
                x, self.normalized_shape, self.weight, self.bias, self.eps
            )
        # channels_first: normalize across the channel dimension (dim 1).
        mean = x.mean(1, keepdim=True)
        var = (x - mean).pow(2).mean(1, keepdim=True)
        normed = (x - mean) / torch.sqrt(var + self.eps)
        return self.weight[:, None] * normed + self.bias[:, None]
class ConvNeXtBlock(nn.Module):
    r"""ConvNeXt block for 1-D sequences.

    Depthwise conv -> LayerNorm -> pointwise MLP (Linear / GELU / Linear),
    with an optional learnable per-channel layer scale and stochastic depth.
    The pointwise stage runs in (N, L, C) layout, matching the reference
    ConvNeXt implementation.

    Args:
        dim: Number of channels.
        drop_path: Stochastic-depth rate (0 disables it).
        layer_scale_init_value: Initial layer-scale value; <= 0 disables it.
        mlp_ratio: Hidden-dim multiplier of the pointwise MLP.
        kernel_size: Depthwise kernel size.
        dilation: Depthwise dilation.
    """

    def __init__(
        self,
        dim: int,
        drop_path: float = 0.0,
        layer_scale_init_value: float = 1e-6,
        mlp_ratio: float = 4.0,
        kernel_size: int = 7,
        dilation: int = 1,
    ):
        super().__init__()
        # Depthwise ("same"-padded) convolution.
        self.dwconv = nn.Conv1d(
            dim,
            dim,
            kernel_size=kernel_size,
            padding=int(dilation * (kernel_size - 1) / 2),
            groups=dim,
        )
        self.norm = LayerNorm(dim, eps=1e-6)
        hidden_dim = int(mlp_ratio * dim)
        # Pointwise 1x1 convs expressed as Linear layers on (N, L, C).
        self.pwconv1 = nn.Linear(dim, hidden_dim)
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(hidden_dim, dim)
        if layer_scale_init_value > 0:
            self.gamma = nn.Parameter(
                layer_scale_init_value * torch.ones((dim)), requires_grad=True
            )
        else:
            self.gamma = None
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

    def forward(self, x, apply_residual: bool = True):
        residual = x
        out = self.dwconv(x)
        out = out.permute(0, 2, 1)  # (N, C, L) -> (N, L, C)
        out = self.pwconv2(self.act(self.pwconv1(self.norm(out))))
        if self.gamma is not None:
            out = self.gamma * out
        out = out.permute(0, 2, 1)  # (N, L, C) -> (N, C, L)
        out = self.drop_path(out)
        if apply_residual:
            out = residual + out
        return out
class ParallelConvNeXtBlock(nn.Module):
    """Several ConvNeXt blocks with different kernel sizes applied in parallel.

    Branch outputs (computed without their internal residuals) are summed
    together with the input, i.e. one shared residual connection around the
    whole parallel group.
    """

    def __init__(self, kernel_sizes: List[int], *args, **kwargs):
        super().__init__()
        self.blocks = nn.ModuleList(
            [ConvNeXtBlock(kernel_size=k, *args, **kwargs) for k in kernel_sizes]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        branch_outputs = [block(x, apply_residual=False) for block in self.blocks]
        stacked = torch.stack(branch_outputs + [x], dim=1)
        return stacked.sum(dim=1)
class ConvNeXtEncoder(nn.Module):
    """ConvNeXt-style 1-D feature encoder.

    Alternates channel-adjustment layers (a stem, then 1x1 transitions) with
    stages of ConvNeXt blocks, and finishes with a channels-first LayerNorm.

    Args:
        input_channels: Channels of the input sequence.
        depths: Number of blocks per stage.
        dims: Channel width per stage (same length as ``depths``).
        drop_path_rate: Maximum stochastic-depth rate, ramped linearly over
            all blocks.
        layer_scale_init_value: Layer-scale init passed to each block.
        kernel_sizes: One kernel size selects plain ConvNeXt blocks; several
            select parallel multi-kernel blocks.
    """

    def __init__(
        self,
        input_channels=3,
        # Fix: tuples instead of the original mutable list defaults
        # (shared mutable default arguments are a classic Python pitfall).
        depths=(3, 3, 9, 3),
        dims=(96, 192, 384, 768),
        drop_path_rate=0.0,
        layer_scale_init_value=1e-6,
        kernel_sizes: Tuple[int] = (7,),
    ):
        super().__init__()
        assert len(depths) == len(dims)
        self.channel_layers = nn.ModuleList()
        # Stem: wide conv into the first stage's width.
        stem = nn.Sequential(
            nn.Conv1d(
                input_channels,
                dims[0],
                kernel_size=7,
                padding=3,
                padding_mode="replicate",
            ),
            LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
        )
        self.channel_layers.append(stem)
        # 1x1 transitions between consecutive stages.
        for i in range(len(depths) - 1):
            mid_layer = nn.Sequential(
                LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
                nn.Conv1d(dims[i], dims[i + 1], kernel_size=1),
            )
            self.channel_layers.append(mid_layer)
        block_fn = (
            partial(ConvNeXtBlock, kernel_size=kernel_sizes[0])
            if len(kernel_sizes) == 1
            else partial(ParallelConvNeXtBlock, kernel_sizes=kernel_sizes)
        )
        # Linearly increasing drop-path rate across all blocks.
        drop_path_rates = [
            r.item() for r in torch.linspace(0, drop_path_rate, sum(depths))
        ]
        self.stages = nn.ModuleList()
        cur = 0
        for depth, dim in zip(depths, dims):
            stage = nn.Sequential(
                *[
                    block_fn(
                        dim=dim,
                        drop_path=drop_path_rates[cur + j],
                        layer_scale_init_value=layer_scale_init_value,
                    )
                    for j in range(depth)
                ]
            )
            self.stages.append(stage)
            cur += depth
        self.norm = LayerNorm(dims[-1], eps=1e-6, data_format="channels_first")
        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Truncated-normal init for conv/linear weights, zero biases.
        if isinstance(m, (nn.Conv1d, nn.Linear)):
            nn.init.trunc_normal_(m.weight, std=0.02)
            nn.init.constant_(m.bias, 0)

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        # channel_layers and stages have equal length and are applied pairwise.
        for channel_layer, stage in zip(self.channel_layers, self.stages):
            x = stage(channel_layer(x))
        return self.norm(x)
def init_weights(m, mean=0.0, std=0.01):
    """Normal-initialize the weights of conv layers (matched by class name)."""
    if "Conv" in m.__class__.__name__:
        m.weight.data.normal_(mean, std)
def get_padding(kernel_size, dilation=1):
    """Padding that keeps a stride-1 dilated conv length-preserving ("same")."""
    return dilation * (kernel_size - 1) // 2
class ResBlock1(torch.nn.Module):
    """HiFi-GAN residual block: pairs of dilated + plain weight-normed convs.

    Each unit computes ``x + conv2(silu(conv1(silu(x))))`` where ``conv1`` is
    dilated and ``conv2`` has dilation 1; all convs preserve sequence length.

    Args:
        channels: Channel count kept throughout the block.
        kernel_size: Kernel size of every conv.
        dilation: Dilation per unit (also generalizes the original hard-coded
            three units to ``len(dilation)`` units — identical for the default).
    """
    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
        super().__init__()
        # Dilated convolutions, one per dilation value.
        self.convs1 = nn.ModuleList(
            [
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=d,
                        padding=get_padding(kernel_size, d),
                    )
                )
                for d in dilation
            ]
        )
        self.convs1.apply(init_weights)
        # Matching non-dilated convolutions.
        self.convs2 = nn.ModuleList(
            [
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=1,
                        padding=get_padding(kernel_size, 1),
                    )
                )
                for _ in dilation
            ]
        )
        self.convs2.apply(init_weights)
    def forward(self, x):
        for c1, c2 in zip(self.convs1, self.convs2):
            xt = F.silu(x)
            xt = c1(xt)
            xt = F.silu(xt)
            xt = c2(xt)
            x = xt + x
        return x
    def remove_weight_norm(self):
        """Strip weight norm from all convs (inference-time optimization).

        Bug fix: the module-level name ``remove_weight_norm`` is an alias of
        ``torch.nn.utils.parametrize.remove_parametrizations``, which requires
        a tensor name and only works on parametrize-based modules, while the
        convs above use the classic ``torch.nn.utils.weight_norm`` hooks —
        the original call raised TypeError. Call the matching classic remover
        explicitly instead.
        """
        for conv in self.convs1:
            torch.nn.utils.remove_weight_norm(conv)
        for conv in self.convs2:
            torch.nn.utils.remove_weight_norm(conv)
class HiFiGANGenerator(nn.Module):
    """HiFi-GAN style decoder: a stack of transposed-conv upsamplers with
    multi-receptive-field (MRF) residual blocks between stages.

    Args:
        hop_length: Total upsampling factor; must equal prod(upsample_rates).
        upsample_rates / upsample_kernel_sizes: Per-stage stride and kernel.
        resblock_kernel_sizes / resblock_dilation_sizes: MRF configuration.
        num_mels: Input channel count.
        upsample_initial_channel: Channels before the first upsample; halved
            at each stage.
        use_template: When True, inject a per-stage conditioning template
            (e.g. a source excitation) through ``noise_convs``.
        pre_conv_kernel_size / post_conv_kernel_size: Boundary conv kernels.
        post_activation: Activation applied before the final projection.
    """
    def __init__(
        self,
        *,
        hop_length: int = 512,
        upsample_rates: Tuple[int] = (8, 8, 2, 2, 2),
        upsample_kernel_sizes: Tuple[int] = (16, 16, 8, 2, 2),
        resblock_kernel_sizes: Tuple[int] = (3, 7, 11),
        resblock_dilation_sizes: Tuple[Tuple[int]] = (
            (1, 3, 5), (1, 3, 5), (1, 3, 5)),
        num_mels: int = 128,
        upsample_initial_channel: int = 512,
        use_template: bool = True,
        pre_conv_kernel_size: int = 7,
        post_conv_kernel_size: int = 7,
        post_activation: Callable = partial(nn.SiLU, inplace=True),
    ):
        super().__init__()
        assert (
            prod(upsample_rates) == hop_length
        ), f"hop_length must be {prod(upsample_rates)}"
        self.conv_pre = weight_norm(
            nn.Conv1d(
                num_mels,
                upsample_initial_channel,
                pre_conv_kernel_size,
                1,
                padding=get_padding(pre_conv_kernel_size),
            )
        )
        self.num_upsamples = len(upsample_rates)
        self.num_kernels = len(resblock_kernel_sizes)
        self.noise_convs = nn.ModuleList()
        self.use_template = use_template
        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            c_cur = upsample_initial_channel // (2 ** (i + 1))
            self.ups.append(
                weight_norm(
                    nn.ConvTranspose1d(
                        upsample_initial_channel // (2**i),
                        upsample_initial_channel // (2 ** (i + 1)),
                        k,
                        u,
                        padding=(k - u) // 2,
                    )
                )
            )
            if not use_template:
                continue
            # The template is strided down to match this stage's temporal rate.
            if i + 1 < len(upsample_rates):
                stride_f0 = np.prod(upsample_rates[i + 1:])
                self.noise_convs.append(
                    Conv1d(
                        1,
                        c_cur,
                        kernel_size=stride_f0 * 2,
                        stride=stride_f0,
                        padding=stride_f0 // 2,
                    )
                )
            else:
                self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))
        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = upsample_initial_channel // (2 ** (i + 1))
            for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes):
                self.resblocks.append(ResBlock1(ch, k, d))
        self.activation_post = post_activation()
        self.conv_post = weight_norm(
            nn.Conv1d(
                ch,
                1,
                post_conv_kernel_size,
                1,
                padding=get_padding(post_conv_kernel_size),
            )
        )
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)
    def forward(self, x, template=None):
        """Features (N, num_mels, T) -> waveform (N, 1, T * hop_length), tanh-bounded."""
        x = self.conv_pre(x)
        for i in range(self.num_upsamples):
            x = F.silu(x, inplace=True)
            x = self.ups[i](x)
            if self.use_template:
                x = x + self.noise_convs[i](template)
            # MRF: average the parallel resblocks of this stage.
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels
        x = self.activation_post(x)
        x = self.conv_post(x)
        x = torch.tanh(x)
        return x
    def remove_weight_norm(self):
        """Strip weight norm for inference.

        Bug fix: the module-level ``remove_weight_norm`` is an alias of
        ``torch.nn.utils.parametrize.remove_parametrizations`` (needs a tensor
        name, parametrize-based modules only), but the layers here use the
        classic ``torch.nn.utils.weight_norm`` hooks — the original call
        raised TypeError. Call the matching classic remover explicitly.
        """
        for up in self.ups:
            torch.nn.utils.remove_weight_norm(up)
        for block in self.resblocks:
            block.remove_weight_norm()
        torch.nn.utils.remove_weight_norm(self.conv_pre)
        torch.nn.utils.remove_weight_norm(self.conv_post)
class ADaMoSHiFiGANV1(nn.Module):
    """ConvNeXt backbone + HiFi-GAN head vocoder with a 44.1 kHz mel front end.

    Generator weights are loaded from `checkpoint_path` at construction and
    the module starts in eval mode.
    """
    def __init__(
        self,
        checkpoint_path: str = "checkpoints/adamos-generator-1640000.pth",
    ):
        super().__init__()
        # Mel (128 ch) -> 512-ch features consumed by the HiFi-GAN head.
        self.backbone = ConvNeXtEncoder(
            input_channels=128,
            depths=[3, 3, 9, 3],
            dims=[128, 256, 384, 512],
            drop_path_rate=0,
            kernel_sizes=(7,),
        )
        # Upsample rates multiply to 512 == the mel hop length.
        self.head = HiFiGANGenerator(
            hop_length=512,
            upsample_rates=(4, 4, 2, 2, 2, 2, 2),
            upsample_kernel_sizes=(8, 8, 4, 4, 4, 4, 4),
            resblock_kernel_sizes=(3, 7, 11, 13),
            resblock_dilation_sizes=(
                (1, 3, 5), (1, 3, 5), (1, 3, 5), (1, 3, 5)),
            num_mels=512,
            upsample_initial_channel=1024,
            use_template=False,
            pre_conv_kernel_size=13,
            post_conv_kernel_size=13,
        )
        self.sampling_rate = 44100
        # NOTE(review): torch.load without weights_only=True unpickles
        # arbitrary objects — only load trusted checkpoints.
        ckpt_state = torch.load(checkpoint_path, map_location="cpu")
        if "state_dict" in ckpt_state:
            ckpt_state = ckpt_state["state_dict"]
        # Accept Lightning-style checkpoints whose keys are "generator.*".
        if any(k.startswith("generator.") for k in ckpt_state):
            ckpt_state = {
                k.replace("generator.", ""): v
                for k, v in ckpt_state.items()
                if k.startswith("generator.")
            }
        self.load_state_dict(ckpt_state)
        self.eval()
        self.mel_transform = LogMelSpectrogram(
            sample_rate=44100,
            n_fft=2048,
            win_length=2048,
            hop_length=512,
            f_min=40,
            f_max=16000,
            n_mels=128,
        )
    @torch.no_grad()
    def decode(self, mel):
        """Mel (N, 128, T) -> waveform via backbone + head, without gradients."""
        y = self.backbone(mel)
        y = self.head(y)
        return y
    @torch.no_grad()
    def encode(self, x):
        """Waveform -> log-mel via the bundled mel transform."""
        return self.mel_transform(x)
    def forward(self, mel):
        # Same computation as decode(), but differentiable (used when
        # training through the vocoder).
        y = self.backbone(mel)
        y = self.head(y)
        return y
if __name__ == "__main__":
    # Smoke test: mel-encode a wav file and resynthesize it with the vocoder.
    import soundfile as sf
    x = "./test.wav"
    model = ADaMoSHiFiGANV1(checkpoint_path='./step_001640000.pth')
    wav, sr = librosa.load(x, sr=44100, mono=True)
    wav = torch.from_numpy(wav).float()[None]
    mel = model.encode(wav)
    # decode(...)[0] is (1, T); .mT transposes to (T, 1) for soundfile.
    wav = model.decode(mel)[0].mT
    sf.write("test_out.wav", wav.cpu().numpy(), 44100)

39
optimizers/cosine_wsd.py Normal file
View File

@@ -0,0 +1,39 @@
from torch.optim.lr_scheduler import _LRScheduler
import torch
class CosineWSD(_LRScheduler):
    """Warmup -> stable -> stepwise-halving decay -> floor LR schedule.

    Despite the name there is no cosine segment: after ``warmup_iters`` of
    linear warmup the LR holds at the base value until ``step_size``, is then
    halved every ``decay_interval`` steps for ``decay_length`` steps, and is
    finally pinned at ``eta_min``.
    """

    def __init__(self, optimizer, warmup_iters, step_size, decay_length, decay_interval, eta_min=0, last_epoch=-1):
        self.warmup_iters = warmup_iters
        self.step_size = step_size
        self.decay_length = decay_length
        self.decay_interval = decay_interval
        self.eta_min = eta_min
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        step = self.last_epoch
        if step < self.warmup_iters:
            # Linear warmup from 0 up to each base LR.
            return [base_lr * step / self.warmup_iters for base_lr in self.base_lrs]
        if step < self.step_size:
            # Stable phase at the base LR.
            return list(self.base_lrs)
        if step <= self.step_size + self.decay_length:
            # Exponential (halving) decay, continuous in the step count.
            factor = 0.5 ** ((step - self.step_size) / self.decay_interval)
            return [base_lr * factor for base_lr in self.base_lrs]
        # Past the decay window: constant floor.
        return [self.eta_min for _ in self.base_lrs]
def configure_lr_scheduler(optimizer, total_steps_per_epoch, epochs=10, decay_ratio=0.9, decay_interval=1000, warmup_iters=4000):
    """Build the CosineWSD schedule config used by the Lightning trainer.

    Args:
        optimizer: Optimizer whose LR is scheduled.
        total_steps_per_epoch: Training steps in one epoch.
        epochs: Epoch count used to size the overall schedule.
        decay_ratio: Fraction of training completed before decay begins.
        decay_interval: Steps per halving during the decay phase.
        warmup_iters: Linear warmup steps.

    Returns:
        A one-element list holding a Lightning-style scheduler dict that is
        stepped on every optimizer step.
    """
    total_steps = total_steps_per_epoch * epochs
    # Decay starts after `decay_ratio` of training and runs to the end.
    step_size = total_steps * decay_ratio
    decay_length = total_steps - step_size
    # Fix: removed the redundant no-op `decay_interval = decay_interval`.
    lr_scheduler = CosineWSD(
        optimizer,
        warmup_iters=warmup_iters,
        step_size=step_size,
        decay_length=decay_length,
        decay_interval=decay_interval
    )
    return [{"scheduler": lr_scheduler, "name": "CosineWSD", "interval": "step"}]

19
requirements.txt Normal file
View File

@@ -0,0 +1,19 @@
datasets==3.4.1
diffusers==0.32.2
gradio==5.23.3
librosa==0.11.0
loguru==0.7.3
matplotlib==3.10.1
numpy
pypinyin==0.53.0
pytorch_lightning==2.5.1
soundfile==0.13.1
torch
torchaudio
torchvision
tqdm==4.67.1
transformers==4.50.0
py3langid==0.3.0
hangul-romanize==0.1.0
num2words==0.5.14
spacy==3.8.4

View File

@@ -0,0 +1,394 @@
# Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.utils import BaseOutput, logging
from diffusers.schedulers.scheduling_utils import SchedulerMixin
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@dataclass
class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput):
    """
    Output class for the scheduler's `step` function output.
    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
    """
    # Denoised sample for this step; fed back as the next model input.
    prev_sample: torch.FloatTensor
class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
"""
Euler scheduler.
This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
methods the library implements for all schedulers such as loading and saving.
Args:
num_train_timesteps (`int`, defaults to 1000):
The number of diffusion steps to train the model.
timestep_spacing (`str`, defaults to `"linspace"`):
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
shift (`float`, defaults to 1.0):
The shift value for the timestep schedule.
"""
_compatibles = []
order = 1
    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        shift: float = 1.0,
        use_dynamic_shifting=False,
        base_shift: Optional[float] = 0.5,
        max_shift: Optional[float] = 1.15,
        base_image_seq_len: Optional[int] = 256,
        max_image_seq_len: Optional[int] = 4096,
    ):
        # Descending training timesteps: num_train_timesteps ... 1.
        timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
        timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)
        # Flow-matching sigmas are the timesteps normalized into (0, 1].
        sigmas = timesteps / num_train_timesteps
        if not use_dynamic_shifting:
            # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution
            sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
        self.timesteps = sigmas * num_train_timesteps
        # Step bookkeeping; initialized lazily by the pipeline.
        self._step_index = None
        self._begin_index = None
        self.sigmas = sigmas.to("cpu")  # to avoid too much CPU/GPU communication
        self.sigma_min = self.sigmas[-1].item()
        self.sigma_max = self.sigmas[0].item()
    @property
    def step_index(self):
        """
        The index counter for current timestep. It will increase 1 after each scheduler step.
        """
        # None until the pipeline initializes it on the first step() call.
        return self._step_index
@property
def begin_index(self):
"""
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
"""
return self._begin_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
def set_begin_index(self, begin_index: int = 0):
"""
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
begin_index (`int`):
The begin index for the scheduler.
"""
self._begin_index = begin_index
def scale_noise(
self,
sample: torch.FloatTensor,
timestep: Union[float, torch.FloatTensor],
noise: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
"""
Forward process in flow-matching
Args:
sample (`torch.FloatTensor`):
The input sample.
timestep (`int`, *optional*):
The current timestep in the diffusion chain.
Returns:
`torch.FloatTensor`:
A scaled input sample.
"""
# Make sure sigmas and timesteps have the same device and dtype as original_samples
sigmas = self.sigmas.to(device=sample.device, dtype=sample.dtype)
if sample.device.type == "mps" and torch.is_floating_point(timestep):
# mps does not support float64
schedule_timesteps = self.timesteps.to(sample.device, dtype=torch.float32)
timestep = timestep.to(sample.device, dtype=torch.float32)
else:
schedule_timesteps = self.timesteps.to(sample.device)
timestep = timestep.to(sample.device)
# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
if self.begin_index is None:
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timestep]
elif self.step_index is not None:
# add_noise is called after first denoising step (for inpainting)
step_indices = [self.step_index] * timestep.shape[0]
else:
# add noise is called before first denoising step to create initial latent(img2img)
step_indices = [self.begin_index] * timestep.shape[0]
sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < len(sample.shape):
sigma = sigma.unsqueeze(-1)
sample = sigma * noise + (1.0 - sigma) * sample
return sample
def _sigma_to_t(self, sigma):
return sigma * self.config.num_train_timesteps
def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
def set_timesteps(
self,
num_inference_steps: int = None,
device: Union[str, torch.device] = None,
sigmas: Optional[List[float]] = None,
mu: Optional[float] = None,
):
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
Args:
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
"""
if self.config.use_dynamic_shifting and mu is None:
raise ValueError(" you have a pass a value for `mu` when `use_dynamic_shifting` is set to be `True`")
if sigmas is None:
self.num_inference_steps = num_inference_steps
timesteps = np.linspace(
self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps
)
sigmas = timesteps / self.config.num_train_timesteps
if self.config.use_dynamic_shifting:
sigmas = self.time_shift(mu, 1.0, sigmas)
else:
sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)
sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
timesteps = sigmas * self.config.num_train_timesteps
self.timesteps = timesteps.to(device=device)
self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
self._step_index = None
self._begin_index = None
def index_for_timestep(self, timestep, schedule_timesteps=None):
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
indices = (schedule_timesteps == timestep).nonzero()
# The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
pos = 1 if len(indices) > 1 else 0
return indices[pos].item()
def _init_step_index(self, timestep):
if self.begin_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
self._step_index = self.index_for_timestep(timestep)
else:
self._step_index = self._begin_index
def step(
self,
model_output: torch.FloatTensor,
timestep: Union[float, torch.FloatTensor],
sample: torch.FloatTensor,
s_churn: float = 0.0,
s_tmin: float = 0.0,
s_tmax: float = float("inf"),
s_noise: float = 1.0,
generator: Optional[torch.Generator] = None,
return_dict: bool = True,
omega: Union[float, np.array] = 0.0
) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]:
"""
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
process from the learned model outputs (most often the predicted noise).
Args:
model_output (`torch.FloatTensor`):
The direct output from learned diffusion model.
timestep (`float`):
The current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
A current instance of a sample created by the diffusion process.
s_churn (`float`):
s_tmin (`float`):
s_tmax (`float`):
s_noise (`float`, defaults to 1.0):
Scaling factor for noise added to the sample.
generator (`torch.Generator`, *optional*):
A random number generator.
return_dict (`bool`):
Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or
tuple.
Returns:
[`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`:
If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is
returned, otherwise a tuple is returned where the first element is the sample tensor.
"""
def logistic_function(x, L=0.9, U=1.1, x_0=0.0, k=1):
# L = Lower bound
# U = Upper bound
# x_0 = Midpoint (x corresponding to y = 1.0)
# k = Steepness, can adjust based on preference
if isinstance(x, torch.Tensor):
device_ = x.device
x = x.to(torch.float).cpu().numpy()
new_x = L + (U - L) / (1 + np.exp(-k * (x - x_0)))
if isinstance(new_x, np.ndarray):
new_x = torch.from_numpy(new_x).to(device_)
return new_x
self.omega_bef_rescale = omega
omega = logistic_function(omega, k=0.1)
self.omega_aft_rescale = omega
if (
isinstance(timestep, int)
or isinstance(timestep, torch.IntTensor)
or isinstance(timestep, torch.LongTensor)
):
raise ValueError(
(
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
" `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
" one of the `scheduler.timesteps` as a timestep."
),
)
if self.step_index is None:
self._init_step_index(timestep)
# Upcast to avoid precision issues when computing prev_sample
sample = sample.to(torch.float32)
sigma = self.sigmas[self.step_index]
sigma_next = self.sigmas[self.step_index + 1]
## --
## mean shift 1
dx = (sigma_next - sigma) * model_output
m = dx.mean()
# print(dx.shape) # torch.Size([1, 16, 128, 128])
# print(f'm: {m}') # m: -0.0014209747314453125
# raise NotImplementedError
dx_ = (dx - m) * omega + m
prev_sample = sample + dx_
# ## --
# ## mean shift 2
# m = model_output.mean()
# model_output_ = (model_output - m) * omega + m
# prev_sample = sample + (sigma_next - sigma) * model_output_
# ## --
# ## original
# prev_sample = sample + (sigma_next - sigma) * model_output * omega
# ## --
# ## spatial mean 1
# dx = (sigma_next - sigma) * model_output
# m = dx.mean(dim=(0, 1), keepdim=True)
# # print(dx.shape) # torch.Size([1, 16, 128, 128])
# # print(m.shape) # torch.Size([1, 1, 128, 128])
# # raise NotImplementedError
# dx_ = (dx - m) * omega + m
# prev_sample = sample + dx_
# ## --
# ## spatial mean 2
# m = model_output.mean(dim=(0, 1), keepdim=True)
# model_output_ = (model_output - m) * omega + m
# prev_sample = sample + (sigma_next - sigma) * model_output_
# ## --
# ## channel mean 1
# m = model_output.mean(dim=(2, 3), keepdim=True)
# # print(m.shape) # torch.Size([1, 16, 1, 1])
# model_output_ = (model_output - m) * omega + m
# prev_sample = sample + (sigma_next - sigma) * model_output_
# ## --
# ## channel mean 2
# dx = (sigma_next - sigma) * model_output
# m = dx.mean(dim=(2, 3), keepdim=True)
# # print(m.shape) # torch.Size([1, 16, 1, 1])
# dx_ = (dx - m) * omega + m
# prev_sample = sample + dx_
# ## --
# ## keep sample mean
# m_tgt = sample.mean()
# prev_sample_ = sample + (sigma_next - sigma) * model_output * omega
# m_src = prev_sample_.mean()
# prev_sample = prev_sample_ - m_src + m_tgt
# ## --
# ## test
# # print(sample.mean())
# prev_sample = sample + (sigma_next - sigma) * model_output * omega
# # raise NotImplementedError
# Cast sample back to model compatible dtype
prev_sample = prev_sample.to(model_output.dtype)
# upon completion increase step index by one
self._step_index += 1
if not return_dict:
return (prev_sample,)
return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample)
def __len__(self):
return self.config.num_train_timesteps

View File

@@ -0,0 +1,348 @@
# Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import numpy as np
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.utils import BaseOutput, logging
from diffusers.utils.torch_utils import randn_tensor
from diffusers.schedulers.scheduling_utils import SchedulerMixin
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@dataclass
class FlowMatchHeunDiscreteSchedulerOutput(BaseOutput):
    """
    Output class for the scheduler's `step` function output.
    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
    """
    # Denoised latent to feed back into the model at the next step.
    prev_sample: torch.FloatTensor
class FlowMatchHeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
    """
    Heun scheduler.
    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.
    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        timestep_spacing (`str`, defaults to `"linspace"`):
            The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
            Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
        shift (`float`, defaults to 1.0):
            The shift value for the timestep schedule.
    """
    _compatibles = []
    # Heun is a 2nd-order solver: every denoising step takes two model evaluations.
    order = 2

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        shift: float = 1.0,
    ):
        # Descending timesteps [num_train_timesteps, ..., 1]; sigmas in (0, 1].
        timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
        timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)
        sigmas = timesteps / num_train_timesteps
        sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
        self.timesteps = sigmas * num_train_timesteps
        self._step_index = None
        self._begin_index = None
        self.sigmas = sigmas.to("cpu")  # to avoid too much CPU/GPU communication
        self.sigma_min = self.sigmas[-1].item()
        self.sigma_max = self.sigmas[0].item()

    @property
    def step_index(self):
        """
        The index counter for current timestep. It will increase 1 after each scheduler step.
        """
        return self._step_index

    @property
    def begin_index(self):
        """
        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
        """
        return self._begin_index

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
    def set_begin_index(self, begin_index: int = 0):
        """
        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
        Args:
            begin_index (`int`):
                The begin index for the scheduler.
        """
        self._begin_index = begin_index

    def scale_noise(
        self,
        sample: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        noise: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        """
        Forward process in flow-matching: blends `sample` with `noise` using the sigma
        at the current step index.
        Args:
            sample (`torch.FloatTensor`):
                The input sample.
            timestep (`int`, *optional*):
                The current timestep in the diffusion chain.
            noise (`torch.FloatTensor`, *optional*):
                NOTE(review): defaults to None but is used unconditionally below — callers
                must always pass it. Also note this variant lacks the device/dtype handling
                of the Euler scheduler's `scale_noise`; confirm that is intentional.
        Returns:
            `torch.FloatTensor`:
                A scaled input sample.
        """
        if self.step_index is None:
            self._init_step_index(timestep)
        sigma = self.sigmas[self.step_index]
        sample = sigma * noise + (1.0 - sigma) * sample
        return sample

    def _sigma_to_t(self, sigma):
        # Map a normalized sigma back to the (continuous) training timestep scale.
        return sigma * self.config.num_train_timesteps

    def set_timesteps(self, num_inference_steps: int, device: Optional[Union[str, torch.device]] = None):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).
        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        """
        self.num_inference_steps = num_inference_steps
        timesteps = np.linspace(
            self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps
        )
        sigmas = timesteps / self.config.num_train_timesteps
        sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)
        sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
        timesteps = sigmas * self.config.num_train_timesteps
        # Each Heun step consumes two sub-steps, so every timestep/sigma after the
        # first is duplicated (repeat_interleave) to align with the two model calls.
        timesteps = torch.cat([timesteps[:1], timesteps[1:].repeat_interleave(2)])
        self.timesteps = timesteps.to(device=device)
        # Terminal zero sigma appended so step() can always read the "next" sigma.
        sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
        self.sigmas = torch.cat([sigmas[:1], sigmas[1:-1].repeat_interleave(2), sigmas[-1:]])
        # empty dt and derivative
        self.prev_derivative = None
        self.dt = None
        self._step_index = None
        self._begin_index = None

    def index_for_timestep(self, timestep, schedule_timesteps=None):
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps
        indices = (schedule_timesteps == timestep).nonzero()
        # The sigma index that is taken for the **very** first `step`
        # is always the second index (or the last index if there is only 1)
        # This way we can ensure we don't accidentally skip a sigma in
        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
        pos = 1 if len(indices) > 1 else 0
        return indices[pos].item()

    def _init_step_index(self, timestep):
        # Resolve the starting step index either from the timestep value or the
        # externally configured begin index.
        if self.begin_index is None:
            if isinstance(timestep, torch.Tensor):
                timestep = timestep.to(self.timesteps.device)
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index

    @property
    def state_in_first_order(self):
        # True before the predictor half of a Heun step; False once dt is stored
        # and the corrector (2nd-order) half is pending.
        return self.dt is None

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        sample: torch.FloatTensor,
        s_churn: float = 0.0,
        s_tmin: float = 0.0,
        s_tmax: float = float("inf"),
        s_noise: float = 1.0,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
        omega: Union[float, np.ndarray] = 0.0
    ) -> Union[FlowMatchHeunDiscreteSchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise).
        Args:
            model_output (`torch.FloatTensor`):
                The direct output from learned diffusion model.
            timestep (`float`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                A current instance of a sample created by the diffusion process.
            s_churn (`float`):
            s_tmin (`float`):
            s_tmax (`float`):
            s_noise (`float`, defaults to 1.0):
                Scaling factor for noise added to the sample.
            generator (`torch.Generator`, *optional*):
                A random number generator.
            return_dict (`bool`):
                Whether or not to return a [`~schedulers.scheduling_Heun_discrete.HeunDiscreteSchedulerOutput`] or
                tuple.
            omega (`float` or `np.ndarray`, defaults to 0.0):
                Granularity control: squashed through a logistic into (0.9, 1.1) and used to rescale the
                final update's deviation from its mean.
        Returns:
            [`~schedulers.scheduling_Heun_discrete.HeunDiscreteSchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`~schedulers.scheduling_Heun_discrete.HeunDiscreteSchedulerOutput`] is
                returned, otherwise a tuple is returned where the first element is the sample tensor.
        """
        def logistic_function(x, L=0.9, U=1.1, x_0=0.0, k=1):
            # L = Lower bound
            # U = Upper bound
            # x_0 = Midpoint (x corresponding to y = 1.0)
            # k = Steepness, can adjust based on preference
            # NOTE(review): `device_` is only bound when `x` is a Tensor, so a raw
            # ndarray input would hit a NameError below — confirm callers never pass one.
            if isinstance(x, torch.Tensor):
                device_ = x.device
                x = x.to(torch.float).cpu().numpy()
            new_x = L + (U - L) / (1 + np.exp(-k * (x - x_0)))
            if isinstance(new_x, np.ndarray):
                new_x = torch.from_numpy(new_x).to(device_)
            return new_x
        # Keep both raw and rescaled omega for inspection by callers.
        self.omega_bef_rescale = omega
        omega = logistic_function(omega, k=0.1)
        self.omega_aft_rescale = omega
        if (
            isinstance(timestep, int)
            or isinstance(timestep, torch.IntTensor)
            or isinstance(timestep, torch.LongTensor)
        ):
            raise ValueError(
                (
                    "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                    " `HeunDiscreteScheduler.step()` is not supported. Make sure to pass"
                    " one of the `scheduler.timesteps` as a timestep."
                ),
            )
        if self.step_index is None:
            self._init_step_index(timestep)
        # Upcast to avoid precision issues when computing prev_sample
        sample = sample.to(torch.float32)
        if self.state_in_first_order:
            sigma = self.sigmas[self.step_index]
            sigma_next = self.sigmas[self.step_index + 1]
        else:
            # 2nd order / Heun's method
            sigma = self.sigmas[self.step_index - 1]
            sigma_next = self.sigmas[self.step_index]
        # Optional stochastic churn (Karras-style): only active when s_churn > 0
        # and sigma lies within [s_tmin, s_tmax].
        gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0
        sigma_hat = sigma * (gamma + 1)
        if gamma > 0:
            noise = randn_tensor(
                model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
            )
            eps = noise * s_noise
            sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5
        if self.state_in_first_order:
            # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
            denoised = sample - model_output * sigma
            # 2. convert to an ODE derivative for 1st order
            derivative = (sample - denoised) / sigma_hat
            # 3. Delta timestep
            dt = sigma_next - sigma_hat
            # store for 2nd order step
            self.prev_derivative = derivative
            self.dt = dt
            self.sample = sample
        else:
            # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
            denoised = sample - model_output * sigma_next
            # 2. 2nd order / Heun's method
            derivative = (sample - denoised) / sigma_next
            derivative = 0.5 * (self.prev_derivative + derivative)
            # 3. take prev timestep & sample
            dt = self.dt
            sample = self.sample
            # free dt and derivative
            # Note, this puts the scheduler in "first order mode"
            self.prev_derivative = None
            self.dt = None
            self.sample = None
        # original sample way
        # prev_sample = sample + derivative * dt
        # Mean-shift variant: rescale the update's deviation around its scalar mean
        # by omega (same scheme as the Euler scheduler's "mean shift 1").
        dx = derivative * dt
        m = dx.mean()
        dx_ = (dx - m) * omega + m
        prev_sample = sample + dx_
        # Cast sample back to model compatible dtype
        prev_sample = prev_sample.to(model_output.dtype)
        # upon completion increase step index by one
        self._step_index += 1
        if not return_dict:
            return (prev_sample,)
        return FlowMatchHeunDiscreteSchedulerOutput(prev_sample=prev_sample)

    def __len__(self):
        return self.config.num_train_timesteps

3
ui/auth.py Normal file
View File

@@ -0,0 +1,3 @@
def same_auth(username, password):
    """Validate the demo's single hard-coded gradio login credential pair."""
    expected = ("timedomain_text2music_team", "TimeDomain_ACEFlow_DEMO")
    return (username, password) == expected

44
ui/llm_prompt_gen.py Normal file
View File

@@ -0,0 +1,44 @@
from openai import OpenAI
from dotenv import load_dotenv
load_dotenv()
# System prompt with few-shot examples steering the LLM toward short,
# genre-tagged music descriptions (under 30 words).
random_genre_prompt = """randomly give me a short prompt that describes a music (with genre tag). less than 30 words
Here are some examples:
fusion jazz with synth, bass, drums, saxophone
Electronic, eerie, swing, dreamy, melodic, electro, sad, emotional
90s hip-hop, old school rap, turntablism, vinyl samples, instrumental loop
"""
def random_genre():
    """Ask gpt-4o-mini for one short, randomly chosen genre-tagged music prompt."""
    response = OpenAI().chat.completions.create(
        model="gpt-4o-mini",
        messages=[{"role": "system", "content": random_genre_prompt}],
        max_tokens=30,
        temperature=0.7,
    )
    return response.choices[0].message.content
# System prompt for rewriting a user's music description into a tighter,
# genre-specific prompt (few-shot output examples included).
# Fix: corrected the "descirption" typo (twice) so the LLM receives clean English.
optimize_genre_prompt = """optimize the following music description and make it more genre specific. less than 30 words
output examples:
fusion jazz with synth, bass, drums, saxophone
Electronic, eerie, swing, dreamy, melodic, electro, sad, emotional
90s hip-hop, old school rap, turntablism, vinyl samples, instrumental loop
## input music description
"""
def optimize_genre(prompt):
    """Rewrite *prompt* into a more genre-specific music description via gpt-4o-mini."""
    response = OpenAI().chat.completions.create(
        model="gpt-4o-mini",
        messages=[{"role": "system", "content": optimize_genre_prompt + prompt}],
        max_tokens=30,
        temperature=0.7,
    )
    return response.choices[0].message.content

View File

@@ -0,0 +1,323 @@
import gradio as gr
from pathlib import Path
import json
from collections import OrderedDict, Counter
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from language_segmentation import LangSegment
# Default generation length (seconds) used by the duration sliders.
MAX_GENERATE_LEN = 60
# Language codes accepted for lyrics input.
SUPPORT_LANGUAGES = [
    "af", "sq", "am", "ar", "an", "hy", "az", "ba", "eu", "be", "bn", "bs", "bg", "my", "ca", "zh", "cs", "da", "nl", "en", "eo", "et", "fi", "fr", "gd", "ka", "de", "el", "gn", "gu", "hi", "hu", "io", "id", "ia", "it", "ja", "kk", "km", "ko", "ku", "la", "lt", "lb", "mk", "mt", "nb", "no", "or", "fa", "pl", "pt", "ro", "ru", "sa", "sr", "sd", "sk", "sl", "es", "sw", "sv", "tl", "ta", "tt", "th", "tr", "tk", "uk", "vi", "cy", "is", "ga", "gl", "se", "yue"
]
# Shared language segmenter used by detect_language(); the filter list
# restricts which language tags it may emit.
langseg = LangSegment()
langseg.setfilters([
    'af', 'am', 'an', 'ar', 'as', 'az', 'be', 'bg', 'bn', 'br', 'bs', 'ca', 'cs', 'cy', 'da', 'de', 'dz', 'el',
    'en', 'eo', 'es', 'et', 'eu', 'fa', 'fi', 'fo', 'fr', 'ga', 'gl', 'gu', 'he', 'hi', 'hr', 'ht', 'hu', 'hy',
    'id', 'is', 'it', 'ja', 'jv', 'ka', 'kk', 'km', 'kn', 'ko', 'ku', 'ky', 'la', 'lb', 'lo', 'lt', 'lv', 'mg',
    'mk', 'ml', 'mn', 'mr', 'ms', 'mt', 'nb', 'ne', 'nl', 'nn', 'no', 'oc', 'or', 'pa', 'pl', 'ps', 'pt', 'qu',
    'ro', 'ru', 'rw', 'se', 'si', 'sk', 'sl', 'sq', 'sr', 'sv', 'sw', 'ta', 'te', 'th', 'tl', 'tr', 'ug', 'uk',
    'ur', 'vi', 'vo', 'wa', 'xh', 'zh', 'zu'
])
# Key/scale name -> 1-based index: the 12 majors (C..B, indices 1-12) followed
# by the 12 minors starting from A minor (indices 13-24).
_MAJOR_ROOTS = ("C", "C#", "D", "Eb", "E", "F", "F#", "G", "Ab", "A", "Bb", "B")
_MINOR_ROOTS = ("A", "Bb", "B", "C", "C#", "D", "Eb", "E", "F", "F#", "G", "Ab")
keyscale_idx_mapping = OrderedDict(
    (name, idx)
    for idx, name in enumerate(
        [f"{root} major" for root in _MAJOR_ROOTS]
        + [f"{root} minor" for root in _MINOR_ROOTS],
        start=1,
    )
)
def get_checkpoint_paths(checkpoint_path):
    """Return every *.ckpt file path found next to *checkpoint_path* (as strings)."""
    found = []
    for candidate in Path(checkpoint_path).parent.glob("*.ckpt"):
        found.append(str(candidate))
    print(found)
    return found
def create_list_checkpoint_path_ui(checkpoint_path):
    """Build the checkpoint-picker section of the UI.

    Renders a dropdown pre-populated with the *.ckpt files found next to
    `checkpoint_path` plus a refresh button that re-globs the directory.
    Returns the dropdown component so the caller can read the selection.
    """
    with gr.Column():
        gr.Markdown("Checkpoint Selection")
        with gr.Group():
            with gr.Row(equal_height=True):
                with gr.Column(scale=9):
                    selected_checkpoint = gr.Dropdown(
                        choices=get_checkpoint_paths(checkpoint_path),
                        label="Select Model",
                        interactive=True,
                        value=checkpoint_path,
                    )
                with gr.Column(scale=1):
                    refresh_button = gr.Button("Refresh Checkpoints", elem_id="refresh_button", variant="primary")
                # Re-glob the original checkpoint's directory on demand.
                refresh_button.click(
                    fn=lambda: gr.update(choices=get_checkpoint_paths(checkpoint_path)),
                    inputs=None,
                    outputs=[selected_checkpoint]
                )
    return selected_checkpoint
def create_keyscale_bpm_time_signature_input_ui(options=["auto", "manual"]):
    """Build the "Time and Keyscale Control" panel.

    A hidden `gr.List` (`results`) mirrors the control state as 5 rows:
    [keyscale, bpm, timesignature, is_music_start, is_music_end], where 0
    means "auto" for the first three and "auto" for the start/end flags
    (1/2 encode the explicit choices). Each visible control updates the
    hidden list through a `nonlocal` closure over `results`.

    NOTE(review): `options` is a mutable default argument; it is only read
    here, but consider `options=("auto", "manual")` — confirm no caller mutates it.

    Returns:
        [audio_duration, keyscale_bpm_time_signature_input] components.
    """
    gr.Markdown("### Time and Keyscale Control")
    with gr.Group():
        # Hidden shared state: [name, encoded_value] rows; 0 == "auto".
        results = [
            ["keyscale", 0],
            ["bpm", 0],
            ["timesignature", 0],
            ["is_music_start", 0],
            ["is_music_end", 0],
        ]
        keyscale_bpm_time_signature_input = gr.List(visible=False, elem_id="keyscale_bpm_time_signature_input", value=results)
        audio_duration = gr.Slider(10, 600, step=1, value=MAX_GENERATE_LEN, label="Audio Duration", interactive=True)
        with gr.Row():
            is_music_start_input = gr.Radio(["auto", "start", "not_start"], value="auto", label="Is Music Start", elem_id="is_music_start_input")
            is_music_end_input = gr.Radio(["auto", "end", "not_end"], value="auto", label="Is Music End", elem_id="is_music_end_input")

        def when_is_music_start_input_change(
            is_music_start_input,
        ):
            # Encode: auto=0, start=1, not_start=2 into results[3].
            nonlocal results
            if is_music_start_input == "auto":
                is_music_start = 0
            elif is_music_start_input == "start":
                is_music_start = 1
            else:
                is_music_start = 2
            results[3][1] = is_music_start
            return gr.update(elem_id="keyscale_bpm_time_signature_input", value=results)

        is_music_start_input.change(
            when_is_music_start_input_change,
            inputs=[is_music_start_input],
            outputs=[keyscale_bpm_time_signature_input]
        )

        def when_is_music_end_input_change(
            is_music_end_input,
        ):
            # Encode: auto=0, end=1, not_end=2 into results[4].
            nonlocal results
            if is_music_end_input == "auto":
                is_music_end = 0
            elif is_music_end_input == "end":
                is_music_end = 1
            else:
                is_music_end = 2
            results[4][1] = is_music_end
            return gr.update(elem_id="keyscale_bpm_time_signature_input", value=results)

        is_music_end_input.change(
            when_is_music_end_input_change,
            inputs=[is_music_end_input],
            outputs=[keyscale_bpm_time_signature_input]
        )
        with gr.Row():
            keyscale_control = gr.Radio(options, value="auto", label="Keyscale", elem_id="keyscale_control")
            bpm_control = gr.Radio(options, value="auto", label="BPM", elem_id="bpm_control")
            time_signature_control = gr.Radio(options, value="auto", label="Time Signature", elem_id="time_signature_control")
        keyscale_input = gr.Dropdown(list(keyscale_idx_mapping.keys()), label="Keyscale", info="the keyscale of the music", visible=False, elem_id="keyscale_input")

        def when_keyscale_change(
            keyscale_input,
            keyscale_control,
        ):
            # Store the chosen keyscale (or 0 for auto) and toggle dropdown visibility.
            nonlocal results
            keyscale = keyscale_input
            if keyscale_control == "auto":
                keyscale = 0
            results[0][1] = keyscale
            return [gr.update(elem_id="keyscale_bpm_time_signature_input", value=results), gr.update(elem_id="keyscale_input", visible=(keyscale_control == "manual"))]

        keyscale_input.change(
            when_keyscale_change,
            inputs=[keyscale_input, keyscale_control],
            outputs=[keyscale_bpm_time_signature_input, keyscale_input]
        )
        keyscale_control.change(
            fn=when_keyscale_change,
            inputs=[keyscale_input, keyscale_control],
            outputs=[keyscale_bpm_time_signature_input, keyscale_input]
        )
        bpm_input = gr.Slider(30, 200, step=1, value=120, label="BPM", info="the beats per minute of the music", visible=False, interactive=True, elem_id="bpm_input")

        def when_bmp_change(
            bpm_input,
            bpm_control,
        ):
            # Store the chosen BPM (or 0 for auto) and toggle slider visibility.
            nonlocal results
            bpm = bpm_input
            if bpm_control == "auto":
                bpm = 0
            results[1][1] = bpm
            updates = [gr.update(elem_id="keyscale_bpm_time_signature_input", value=results), gr.update(elem_id="bpm_input", visible=(bpm_control == "manual"))]
            return updates

        bpm_control.change(
            fn=when_bmp_change,
            inputs=[bpm_input, bpm_control],
            outputs=[keyscale_bpm_time_signature_input, bpm_input]
        )
        bpm_input.change(
            when_bmp_change,
            inputs=[bpm_input, bpm_control],
            outputs=[keyscale_bpm_time_signature_input, bpm_input]
        )
        time_signature_input = gr.Slider(1, 12, step=1, value=4, label="Time Signature", info="the time signature of the music", visible=False, interactive=True, elem_id="time_signature_input")

        def when_time_signature_change(
            time_signature_input,
            time_signature_control,
        ):
            # Store the chosen time signature (or 0 for auto) and toggle slider visibility.
            nonlocal results
            time_signature = time_signature_input
            if time_signature_control == "auto":
                time_signature = 0
            results[2][1] = time_signature
            return [gr.update(elem_id="keyscale_bpm_time_signature_input", value=results), gr.update(elem_id="time_signature_input", visible=(time_signature_control == "manual"))]

        time_signature_input.change(
            when_time_signature_change,
            inputs=[time_signature_input, time_signature_control],
            outputs=[keyscale_bpm_time_signature_input, time_signature_input]
        )
        time_signature_control.change(
            fn=when_time_signature_change,
            inputs=[time_signature_input, time_signature_control],
            outputs=[keyscale_bpm_time_signature_input, time_signature_input]
        )
    return [audio_duration, keyscale_bpm_time_signature_input]
def detect_language(lyrics: str) -> str:
    """Guess the dominant language of *lyrics*.

    Segments the text with the module-level `langseg` and returns the
    language code that covers the most characters; blank input falls back
    to "en".

    Fixes vs. original: the blank-input branch returned `gr.update(value="en")`
    while the normal path returned a plain string (and the annotation said
    `-> list`); the return type is now consistently `str`, which still works
    as a gradio output value.
    """
    lyrics = lyrics.strip()
    if not lyrics:
        return "en"
    langs = langseg.getTexts(lyrics)
    lang_counter = Counter()
    for lang in langs:
        # Weight each detected segment by its character count.
        lang_counter[lang["lang"]] += len(lang["text"])
    lang = lang_counter.most_common(1)[0][0]
    return lang
def create_output_ui():
    """Build the result column: a reference audio player, two generated-audio
    players, and a JSON viewer echoing the input parameters.

    Returns:
        ([output_audio1, output_audio2], target_audio, input_params_json)
    """
    target_audio = gr.Audio(type="filepath", label="Target Audio")
    output_audio1 = gr.Audio(type="filepath", label="Generated Audio 1")
    output_audio2 = gr.Audio(type="filepath", label="Generated Audio 2")
    input_params_json = gr.JSON(label="Input Parameters")
    outputs = [output_audio1, output_audio2]
    return outputs, target_audio, input_params_json
def dump_func(*args):
    """Debug stand-in for a real processing callback: echoes the received
    arguments to stdout and produces no outputs."""
    print(args)
    return list()
def create_main_demo_ui(
    checkpoint_path="checkpoints/aceflow3_0311/1d_epoch=16-step=140k.ckpt",
    text2music_process_func=dump_func,
    sample_data_func=dump_func,
):
    """Assemble the full text-to-music demo page.

    Args:
        checkpoint_path: default checkpoint shown in the model dropdown.
        text2music_process_func: callback invoked by the generate button;
            receives (duration, tags, lyrics, params_json, checkpoint,
            scheduler_type, cfg_type, infer_step, guidance_scale,
            omega_scale, manual_seeds) and returns the two audios plus the
            params JSON.
        sample_data_func: callback that loads a dataset example into the form.

    Returns:
        The constructed `gr.Blocks` demo (not yet launched).
    """
    with gr.Blocks(
        title="AceFlow 3.0 DEMO (3.5B)",
    ) as demo:
        gr.Markdown(
            """
            <h1 style="text-align: center;">AceFlow 3.0 DEMO</h1>
            """
        )
        selected_checkpoint = create_list_checkpoint_path_ui(checkpoint_path)
        gr.Markdown("Dataset Filter")
        with gr.Group():
            with gr.Row(equal_height=True):
                language = gr.Dropdown(["en", "zh"], label="Language", value="en", elem_id="language")
                dataset_example_idx = gr.Number(
                    value=-1,
                    label="Dataset Example Index",
                    interactive=True
                )
                sample_bnt = gr.Button(value="Sample Data", elem_id="sample_bnt", variant="primary")
        with gr.Row():
            with gr.Column():
                # Generation inputs: duration, tags, lyrics and sampler knobs.
                audio_duration = gr.Slider(10, 600, step=1, value=MAX_GENERATE_LEN, label="Audio Duration", interactive=True)
                prompt = gr.Textbox(lines=2, label="Tags", max_lines=4)
                lyrics = gr.Textbox(lines=9, label="Lyrics", max_lines=9)
                scheduler_type = gr.Radio(["euler", "heun"], value="euler", label="Scheduler Type", elem_id="scheduler_type")
                cfg_type = gr.Radio(["cfg", "apg"], value="apg", label="CFG Type", elem_id="cfg_type")
                infer_step = gr.Slider(minimum=1, maximum=1000, step=1, value=60, label="Infer Steps", interactive=True)
                guidance_scale = gr.Slider(minimum=0.0, maximum=200.0, step=0.1, value=15.0, label="Guidance Scale", interactive=True)
                omega_scale = gr.Slider(minimum=-100.0, maximum=100.0, step=0.1, value=10.0, label="Granularity Scale", interactive=True)
                manual_seeds = gr.Textbox(label="manual seeds (default None)", placeholder="1,2,3,4", value=None)
                text2music_bnt = gr.Button(variant="primary")
            with gr.Column():
                outputs, target_audio, input_params_json = create_output_ui()
        sample_bnt.click(
            sample_data_func,
            inputs=[dataset_example_idx, audio_duration],
            outputs=[target_audio, prompt, lyrics, input_params_json],
        )
        text2music_bnt.click(
            fn=text2music_process_func,
            inputs=[
                audio_duration,
                prompt,
                lyrics,
                input_params_json,
                selected_checkpoint,
                scheduler_type,
                cfg_type,
                infer_step,
                guidance_scale,
                omega_scale,
                manual_seeds,
            ], outputs=outputs + [input_params_json]
        )
    return demo
# Standalone entry point: serve the demo (with placeholder callbacks) on all
# interfaces at port 7860.
if __name__ == "__main__":
    demo = create_main_demo_ui()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
    )

197
utils.py Normal file
View File

@@ -0,0 +1,197 @@
from loguru import logger
import functools
import numpy as np
import time
import librosa
import sys
import yaml
from threading import Thread
# Replace loguru's default sink with a plain stderr sink at INFO level.
logger.remove()
logger.add(sys.stderr, format="{time} {level} {message}", level="INFO")
def async_thread(f):
    """Decorator that runs each call to *f* in a background thread.

    Fixes vs. original: the wrapper now carries the wrapped function's
    metadata via `functools.wraps` (matching the sibling `timecost`
    decorator) and returns the started `Thread` so callers can `join()`
    it; previously the handle was discarded and the wrapper returned
    `None` (still compatible: callers that ignore the return are unaffected).
    """
    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        t = Thread(target=f, args=args, kwargs=kwargs)
        t.start()
        return t
    return wrapper
def timecost(func):
    """Decorator that logs the wall-clock duration of every call to *func*."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        start = time.time()
        outcome = func(*args, **kwargs)
        end = time.time()
        logger.info(f"{func.__name__} took {end - start} seconds to run")
        return outcome
    return wrapper
def autocut(wav, min_cut_second=9.9, sample_rate=16_000, frame_length=2048, hop_length=512, cut_threshold=[2e-5, 1, 2**0.5], min_mute_duration=120, min_tail_second=2):
    """Split a waveform into segments of at most ``min_cut_second`` seconds,
    preferring cut points located in the middle of silent stretches.

    Args:
        wav: audio sample array (assumes 1-D mono samples — TODO confirm
            against callers).
        min_cut_second: maximum segment length, in seconds.
        sample_rate: sample rate of ``wav`` in Hz.
        frame_length: RMS analysis window size in samples (adaptively shrunk).
        hop_length: RMS hop size in samples (adaptively shrunk).
        cut_threshold: ``[initial, maximum, step_multiple]`` — RMS values at
            or below the current threshold count as silence; the threshold is
            multiplied by ``step_multiple`` each retry round, capped at
            ``maximum``. NOTE(review): mutable default list — harmless here
            since it is only unpacked, never mutated.
        min_mute_duration: minimum silence length in milliseconds for a
            stretch to qualify as a cut point (adaptively reduced on retries).
        min_tail_second: a trailing remainder shorter than this many seconds
            is merged with the preceding segment.

    Returns:
        ``(segs, seg_lengths)``: list of waveform slices and their lengths in
        samples.
    """
    segs = []
    seg_lengths = []
    longest_wav_frames = int(min_cut_second * sample_rate)
    # Short input: return it unchanged as a single segment.
    if len(wav) < longest_wav_frames:
        segs.append(wav)
        seg_lengths.append(len(wav))
        return segs, seg_lengths
    # Adaptive-threshold search for silence-based cut points.
    candidate_cut_positions = []
    candidate_cut_durations = []
    cut_threshold, cut_threshold_max, cut_step_multiple = cut_threshold
    for i in range(8):  # at most 8 threshold-relaxation rounds
        rms = librosa.feature.rms(y=wav, frame_length=frame_length, hop_length=hop_length)[0]
        is_mute_mask = rms <= cut_threshold
        is_mute = np.zeros_like(rms, dtype='bool')
        is_mute[is_mute_mask], is_mute[~is_mute_mask] = True, False
        # logger.info(f"{rms.mean()=}, {rms.min()=}, {rms.max()=}, {cut_threshold=}, {is_mute_mask.sum()=}")
        last_start = 0
        last_position = 0
        curr_cut_positions = []
        curr_cut_durations = []
        interrupt = False
        # NOTE(review): this inner loop reuses (shadows) the round counter
        # ``i``; the outer value is not read again afterwards, so it is
        # harmless but worth renaming eventually.
        for i in range(len(is_mute) - 1):
            # Transition sound -> silence: remember where silence began.
            if not is_mute[i] and is_mute[i + 1]:
                last_start = i
            # Transition silence -> sound: candidate cut point if long enough.
            if is_mute[i] and not is_mute[i + 1]:
                # The silent stretch must span at least min_mute_duration ms.
                mute_duration = (i - last_start) * \
                    hop_length / (sample_rate / 1000)
                if mute_duration >= min_mute_duration:
                    # Cut rule: use the midpoint of the silent stretch,
                    # converted from RMS frame index back to sample index.
                    mid = (i + last_start) // 2
                    cut_position = mid * hop_length
                    curr_duration = cut_position - last_position
                    # Segment would be oversized: emit two cuts at the
                    # quarter points of the silence instead of one midpoint.
                    if (longest_wav_frames // 2) < curr_duration:
                        left_cut_position = (last_start+mid) // 2 * hop_length
                        left_curr_duration = left_cut_position - last_position
                        curr_cut_positions.append(left_cut_position)
                        curr_cut_durations.append(left_curr_duration)
                        last_position = left_cut_position
                        right_cut_position = (mid+i) // 2 * hop_length
                        right_curr_duration = right_cut_position - last_position
                        curr_cut_positions.append(right_cut_position)
                        curr_cut_durations.append(right_curr_duration)
                        last_position = right_cut_position
                    else:
                        curr_cut_positions.append(cut_position)
                        curr_cut_durations.append(curr_duration)
                        last_position = cut_position
        candidate_cut_positions = curr_cut_positions
        candidate_cut_durations = curr_cut_durations
        if cut_threshold >= cut_threshold_max:
            break
        if cut_threshold < cut_threshold_max:
            # Close the final segment at the end of the waveform so the
            # max-duration check below also covers the tail.
            if len(curr_cut_durations) == 0:
                curr_cut_positions.append(len(wav))
                curr_cut_durations.append(len(wav))
            else:
                curr_cut_positions.append(len(wav))
                curr_cut_durations.append(
                    curr_cut_positions[-1] - curr_cut_positions[-2])
            max_duration = max(curr_cut_durations)
            if max_duration >= longest_wav_frames:
                interrupt = True
            # Relax the silence detector for the next round: higher RMS
            # threshold, shorter minimum silence, finer analysis windows
            # (each clamped at a hard lower bound).
            cut_threshold = cut_threshold * cut_step_multiple
            min_mute_duration = int(max(min_mute_duration/cut_step_multiple, 10))
            frame_length = int(max(frame_length / cut_step_multiple, 256))
            hop_length = int(max(hop_length / cut_step_multiple, 64))
            # logger.info(f"Adaptively adjust the threshold: {cut_threshold=} {min_mute_duration=} {frame_length=} {hop_length=} {len(curr_cut_durations)=}")
        if not interrupt and len(curr_cut_durations) > 0:
            # Every candidate segment fits: accept this round's cut points.
            candidate_cut_positions = curr_cut_positions
            candidate_cut_durations = curr_cut_durations
            break
    # logger.info(f"candidate_cut_positions {candidate_cut_positions}")
    # logger.info(f"candidate_cut_durations {candidate_cut_durations}")
    # Greedily merge candidate pieces into segments as close as possible to
    # (without exceeding) the maximum segment length.
    curr_duration = 0
    last_start = 0
    for i, duration in enumerate(candidate_cut_durations):
        curr_duration += duration
        # Over the limit: close the segment at the previous candidate point.
        if curr_duration > longest_wav_frames:
            segs.append(wav[last_start:candidate_cut_positions[i - 1]])
            seg_lengths.append(curr_duration - duration)
            curr_duration = duration
            last_start = candidate_cut_positions[i - 1]
    if len(candidate_cut_durations) == 0 or (len(candidate_cut_durations)==1 and candidate_cut_durations[0] >= len(wav)):
        # Automatic silence-based cutting failed; fall back to fixed-length
        # slicing at the maximum segment length. (Log message says the same
        # in Chinese; kept verbatim as it is runtime output.)
        logger.info("自动切分算法失败,按最长强制切分")
        # 按最长强制切分
        last_start = 0
        segs = []
        seg_lengths = []
        for end in range(longest_wav_frames, max(longest_wav_frames, len(wav)), longest_wav_frames):
            segs.append(wav[last_start:end])
            seg_lengths.append(end-last_start)
            last_start = end
    # Handle the leftover tail not yet covered by any segment.
    if sum(seg_lengths) < len(wav):
        for end in range(last_start+longest_wav_frames, max(longest_wav_frames, len(wav)), longest_wav_frames):
            segs.append(wav[last_start:end])
            seg_lengths.append(end - last_start)
            last_start = end
        if sum(seg_lengths) < len(wav):
            last_start = sum(seg_lengths)
            tail_frame = len(wav) - last_start
            # A too-short tail is merged with the previous segment: drop that
            # segment and re-emit it together with the tail as one slice.
            if len(segs) > 0 and tail_frame < min_tail_second*sample_rate:
                segs.pop()
                seg_lengths.pop()
                last_start = sum(seg_lengths)
            segs.append(wav[last_start:])
            seg_lengths.append(len(wav) - last_start)
    # Safety net: force-split any segment that still exceeds the limit.
    if any([len(seg) > longest_wav_frames for seg in segs]):
        new_segs = []
        new_seg_lengths = []
        for seg, seg_length in zip(segs, seg_lengths):
            num_cut = len(seg) // longest_wav_frames
            num_cut += 1 if len(seg) % longest_wav_frames > 0 else 0
            for i in range(num_cut):
                new_segs.append(seg[i*longest_wav_frames:(i+1)*longest_wav_frames])
                new_seg_lengths.append(len(new_segs[-1]))
        segs, seg_lengths = new_segs, new_seg_lengths
    return segs, seg_lengths
class ConfigObj:
    """Hybrid attribute/mapping wrapper around a plain dict.

    Every key of the dict passed to the constructor becomes an instance
    attribute, and items remain reachable via ``obj[key]``, ``obj.get(key)``
    and assignable via ``obj[key] = value``.
    """

    def __init__(self, d):
        # Expose every key of d directly as an attribute of this instance.
        self.__dict__.update(d)

    def __repr__(self) -> str:
        return repr(self.__dict__)

    def __str__(self) -> str:
        return str(self.__dict__)

    def __getitem__(self, k):
        return self.__dict__[k]

    def get(self, k, default=None):
        # dict.get already implements "return default when the key is absent".
        return self.__dict__.get(k, default)

    def __setitem__(self, k, v):
        self.__dict__[k] = v
def load_config(config_path):
    """Parse a YAML file and wrap the resulting mapping in a ConfigObj."""
    with open(config_path, encoding='utf-8') as fh:
        parsed = yaml.safe_load(fh)
    return ConfigObj(parsed)