Compare commits

..

26 Commits

Author SHA1 Message Date
comfyanonymous
1c21828236 ComfyUI v0.16.3 2026-03-05 17:25:49 -05:00
Tavi Halperin
58017e8726 feat: add causal_fix parameter to add_keyframe_index and append_keyframe (#12797)
Allows explicit control over the causal_fix flag passed to
latent_to_pixel_coords. Defaults to frame_idx == 0 when not
specified, fixing the previous heuristic.
2026-03-05 16:51:20 -05:00
comfyanonymous
17b43c2b87 LTX audio vae novram fixes. (#12796) 2026-03-05 16:31:28 -05:00
Jukka Seppänen
8befce5c7b Add manual cast to LTX2 vocoder conv_transpose1d (#12795)
* Add manual cast to LTX2 vocoder

* Update vocoder.py
2026-03-05 12:37:25 -08:00
comfyanonymous
50549aa252 ComfyUI v0.16.2 2026-03-05 13:41:06 -05:00
comfyanonymous
1c3b651c0a Refactor. (#12794) 2026-03-05 13:35:56 -05:00
ComfyUI Wiki
5073da57ad chore: update workflow templates to v0.9.10 (#12793) 2026-03-05 10:22:38 -08:00
rattus
42e0e023ee ops: Handle CPU weight in VBAR caster (#12792)
This shouldn't happen but custom nodes gets there. Handle it as best
we can.
2026-03-05 10:22:17 -08:00
rattus
6481569ad4 comfy-aimdo 0.2.7 (#12791)
Comfy-aimdo 0.2.7 fixes a crash when a spurious cudaAsyncFree comes in
and would cause an infinite stack overflow (via detours hooks).

A lock is also introduced on the link list holding the free sections
to avoid any possibility of threaded miscellaneous cuda allocations
being the root cause.
2026-03-05 09:04:24 -08:00
comfyanonymous
6ef82a89b8 ComfyUI v0.16.1 2026-03-05 10:38:33 -05:00
ComfyUI Wiki
da29b797ce Update workflow templates to v0.9.8 (#12788) 2026-03-05 07:23:23 -08:00
Alexander Piskun
9cdfd7403b feat(api-nodes): enable Kling 3.0 Motion Control (#12785) 2026-03-05 07:12:38 -08:00
Alexander Piskun
bd21363563 feat(api-nodes-xAI): updated models, pricing, added features (#12756) 2026-03-05 04:29:39 -08:00
comfyanonymous
e04d0dbeb8 ComfyUI v0.16.0 2026-03-05 04:06:29 -05:00
ComfyUI Wiki
c8428541a6 chore: update workflow templates to v0.9.7 (#12780) 2026-03-05 03:58:25 -05:00
comfyanonymous
4941671b5a Fix cuda getting initialized in cpu mode. (#12779) 2026-03-05 02:39:51 -05:00
ComfyUI Wiki
c5fe8ace68 chore: update workflow templates to v0.9.6 (#12778) 2026-03-05 02:37:35 -05:00
comfyanonymous
f2ee7f2d36 Fix cublas ops on dynamic vram. (#12776) 2026-03-05 01:21:55 -05:00
comfyanonymous
43c64b6308 Support the LTXAV 2.3 model. (#12773) 2026-03-04 20:06:20 -05:00
comfyanonymous
ac4a943ff3 Initial load device should be cpu when using dynamic vram. (#12766) 2026-03-04 16:33:14 -05:00
rattus
8811db52db comfy-aimdo 0.2.6 (#12764)
Comfy Aimdo 0.2.6 fixes a GPU virtual address leak. This would manfiest
as an error after a number of workflow runs.
2026-03-04 12:12:37 -08:00
Jukka Seppänen
0a7446ade4 Pass tokens when loading text gen model for text generation (#12755)
Co-authored-by: Jedrzej Kosinski <kosinkadink1@gmail.com>
2026-03-04 08:59:56 -08:00
rattus
9b85cf9558 Comfy Aimdo 0.2.5 + Fix offload performance in DynamicVram (#12754)
* ops: dont unpin nothing

This was calling into aimdo in the none case (offloaded weight). Whats worse,
is aimdo syncs for unpinning an offloaded weight, as that is the corner case of
a weight getting evicted by its own use which does require a sync. But this
was heppening every offloaded weight causing slowdown.

* mp: fix get_free_memory policy

The ModelPatcherDynamic get_free_memory was deducting the model from
to try and estimate the conceptual free memory with doing any
offloading. This is kind of what the old memory_memory_required
was estimating in ModelPatcher load logic, however in practical
reality, between over-estimates and padding, the loader usually
underloaded models enough such that sampling could send CFG +/-
through together even when partially loaded.

So don't regress from the status quo and instead go all in on the
idea that offloading is less of an issue than debatching. Tell the
sampler it can use everything.
2026-03-04 07:49:13 -08:00
rattus
d531e3fb2a model_patcher: Improve dynamic offload heuristic (#12759)
Define a threshold below which a weight loading takes priority. This
actually makes the offload consistent with non-dynamic, because what
happens, is when non-dynamic fills ints to_load list, it will fill-up
any left-over pieces that could fix large weights with small weights
and load them, even though they were lower priority. This actually
improves performance because the timy weights dont cost any VRAM and
arent worth the control overhead of the DMA etc.
2026-03-04 07:47:44 -08:00
Arthur R Longbottom
eb011733b6 Fix VideoFromComponents.save_to crash when writing to BytesIO (#12683)
* Fix VideoFromComponents.save_to crash when writing to BytesIO

When `get_container_format()` or `get_stream_source()` is called on a
tensor-based video (VideoFromComponents), it calls `save_to(BytesIO())`.
Since BytesIO has no file extension, `av.open` can't infer the output
format and throws `ValueError: Could not determine output format`.

The sibling class `VideoFromFile` already handles this correctly via
`get_open_write_kwargs()`, which detects BytesIO and sets the format
explicitly. `VideoFromComponents` just never got the same treatment.

This surfaces when any downstream node validates the container format
of a tensor-based video, like TopazVideoEnhance or any node that calls
`validate_container_format_is_mp4()`.

Three-line fix in `comfy_api/latest/_input_impl/video_types.py`.

* Add docstring to save_to to satisfy CI coverage check
2026-03-04 00:29:00 -05:00
rattus
ac6513e142 DynamicVram: Add casting / fix torch Buffer weights (#12749)
* respect model dtype in non-comfy caster

* utils: factor out parent and name functionality of set_attr

* utils: implement set_attr_buffer for torch buffers

* ModelPatcherDynamic: Implement torch Buffer loading

If there is a buffer in dynamic - force load it.
2026-03-03 18:19:40 -08:00
25 changed files with 1164 additions and 258 deletions

View File

@@ -2,11 +2,16 @@ from typing import Tuple
import torch
import torch.nn as nn
from comfy.ldm.lightricks.model import (
ADALN_BASE_PARAMS_COUNT,
ADALN_CROSS_ATTN_PARAMS_COUNT,
CrossAttention,
FeedForward,
AdaLayerNormSingle,
PixArtAlphaTextProjection,
NormSingleLinearTextProjection,
LTXVModel,
apply_cross_attention_adaln,
compute_prompt_timestep,
)
from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector
@@ -87,6 +92,8 @@ class BasicAVTransformerBlock(nn.Module):
v_context_dim=None,
a_context_dim=None,
attn_precision=None,
apply_gated_attention=False,
cross_attention_adaln=False,
dtype=None,
device=None,
operations=None,
@@ -94,6 +101,7 @@ class BasicAVTransformerBlock(nn.Module):
super().__init__()
self.attn_precision = attn_precision
self.cross_attention_adaln = cross_attention_adaln
self.attn1 = CrossAttention(
query_dim=v_dim,
@@ -101,6 +109,7 @@ class BasicAVTransformerBlock(nn.Module):
dim_head=vd_head,
context_dim=None,
attn_precision=self.attn_precision,
apply_gated_attention=apply_gated_attention,
dtype=dtype,
device=device,
operations=operations,
@@ -111,6 +120,7 @@ class BasicAVTransformerBlock(nn.Module):
dim_head=ad_head,
context_dim=None,
attn_precision=self.attn_precision,
apply_gated_attention=apply_gated_attention,
dtype=dtype,
device=device,
operations=operations,
@@ -122,6 +132,7 @@ class BasicAVTransformerBlock(nn.Module):
heads=v_heads,
dim_head=vd_head,
attn_precision=self.attn_precision,
apply_gated_attention=apply_gated_attention,
dtype=dtype,
device=device,
operations=operations,
@@ -132,6 +143,7 @@ class BasicAVTransformerBlock(nn.Module):
heads=a_heads,
dim_head=ad_head,
attn_precision=self.attn_precision,
apply_gated_attention=apply_gated_attention,
dtype=dtype,
device=device,
operations=operations,
@@ -144,6 +156,7 @@ class BasicAVTransformerBlock(nn.Module):
heads=a_heads,
dim_head=ad_head,
attn_precision=self.attn_precision,
apply_gated_attention=apply_gated_attention,
dtype=dtype,
device=device,
operations=operations,
@@ -156,6 +169,7 @@ class BasicAVTransformerBlock(nn.Module):
heads=a_heads,
dim_head=ad_head,
attn_precision=self.attn_precision,
apply_gated_attention=apply_gated_attention,
dtype=dtype,
device=device,
operations=operations,
@@ -168,11 +182,16 @@ class BasicAVTransformerBlock(nn.Module):
a_dim, dim_out=a_dim, glu=True, dtype=dtype, device=device, operations=operations
)
self.scale_shift_table = nn.Parameter(torch.empty(6, v_dim, device=device, dtype=dtype))
num_ada_params = ADALN_CROSS_ATTN_PARAMS_COUNT if cross_attention_adaln else ADALN_BASE_PARAMS_COUNT
self.scale_shift_table = nn.Parameter(torch.empty(num_ada_params, v_dim, device=device, dtype=dtype))
self.audio_scale_shift_table = nn.Parameter(
torch.empty(6, a_dim, device=device, dtype=dtype)
torch.empty(num_ada_params, a_dim, device=device, dtype=dtype)
)
if cross_attention_adaln:
self.prompt_scale_shift_table = nn.Parameter(torch.empty(2, v_dim, device=device, dtype=dtype))
self.audio_prompt_scale_shift_table = nn.Parameter(torch.empty(2, a_dim, device=device, dtype=dtype))
self.scale_shift_table_a2v_ca_audio = nn.Parameter(
torch.empty(5, a_dim, device=device, dtype=dtype)
)
@@ -215,10 +234,30 @@ class BasicAVTransformerBlock(nn.Module):
return (*scale_shift_ada_values, *gate_ada_values)
def _apply_text_cross_attention(
self, x, context, attn, scale_shift_table, prompt_scale_shift_table,
timestep, prompt_timestep, attention_mask, transformer_options,
):
"""Apply text cross-attention, with optional ADaLN modulation."""
if self.cross_attention_adaln:
shift_q, scale_q, gate = self.get_ada_values(
scale_shift_table, x.shape[0], timestep, slice(6, 9)
)
return apply_cross_attention_adaln(
x, context, attn, shift_q, scale_q, gate,
prompt_scale_shift_table, prompt_timestep,
attention_mask, transformer_options,
)
return attn(
comfy.ldm.common_dit.rms_norm(x), context=context,
mask=attention_mask, transformer_options=transformer_options,
)
def forward(
self, x: Tuple[torch.Tensor, torch.Tensor], v_context=None, a_context=None, attention_mask=None, v_timestep=None, a_timestep=None,
v_pe=None, a_pe=None, v_cross_pe=None, a_cross_pe=None, v_cross_scale_shift_timestep=None, a_cross_scale_shift_timestep=None,
v_cross_gate_timestep=None, a_cross_gate_timestep=None, transformer_options=None, self_attention_mask=None,
v_prompt_timestep=None, a_prompt_timestep=None,
) -> Tuple[torch.Tensor, torch.Tensor]:
run_vx = transformer_options.get("run_vx", True)
run_ax = transformer_options.get("run_ax", True)
@@ -240,7 +279,11 @@ class BasicAVTransformerBlock(nn.Module):
vgate_msa = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(2, 3))[0]
vx.addcmul_(attn1_out, vgate_msa)
del vgate_msa, attn1_out
vx.add_(self.attn2(comfy.ldm.common_dit.rms_norm(vx), context=v_context, mask=attention_mask, transformer_options=transformer_options))
vx.add_(self._apply_text_cross_attention(
vx, v_context, self.attn2, self.scale_shift_table,
getattr(self, 'prompt_scale_shift_table', None),
v_timestep, v_prompt_timestep, attention_mask, transformer_options,)
)
# audio
if run_ax:
@@ -254,7 +297,11 @@ class BasicAVTransformerBlock(nn.Module):
agate_msa = self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(2, 3))[0]
ax.addcmul_(attn1_out, agate_msa)
del agate_msa, attn1_out
ax.add_(self.audio_attn2(comfy.ldm.common_dit.rms_norm(ax), context=a_context, mask=attention_mask, transformer_options=transformer_options))
ax.add_(self._apply_text_cross_attention(
ax, a_context, self.audio_attn2, self.audio_scale_shift_table,
getattr(self, 'audio_prompt_scale_shift_table', None),
a_timestep, a_prompt_timestep, attention_mask, transformer_options,)
)
# video - audio cross attention.
if run_a2v or run_v2a:
@@ -351,6 +398,9 @@ class LTXAVModel(LTXVModel):
use_middle_indices_grid=False,
timestep_scale_multiplier=1000.0,
av_ca_timestep_scale_multiplier=1.0,
apply_gated_attention=False,
caption_proj_before_connector=False,
cross_attention_adaln=False,
dtype=None,
device=None,
operations=None,
@@ -362,6 +412,7 @@ class LTXAVModel(LTXVModel):
self.audio_attention_head_dim = audio_attention_head_dim
self.audio_num_attention_heads = audio_num_attention_heads
self.audio_positional_embedding_max_pos = audio_positional_embedding_max_pos
self.apply_gated_attention = apply_gated_attention
# Calculate audio dimensions
self.audio_inner_dim = audio_num_attention_heads * audio_attention_head_dim
@@ -386,6 +437,8 @@ class LTXAVModel(LTXVModel):
vae_scale_factors=vae_scale_factors,
use_middle_indices_grid=use_middle_indices_grid,
timestep_scale_multiplier=timestep_scale_multiplier,
caption_proj_before_connector=caption_proj_before_connector,
cross_attention_adaln=cross_attention_adaln,
dtype=dtype,
device=device,
operations=operations,
@@ -400,14 +453,28 @@ class LTXAVModel(LTXVModel):
)
# Audio-specific AdaLN
audio_embedding_coefficient = ADALN_CROSS_ATTN_PARAMS_COUNT if self.cross_attention_adaln else ADALN_BASE_PARAMS_COUNT
self.audio_adaln_single = AdaLayerNormSingle(
self.audio_inner_dim,
embedding_coefficient=audio_embedding_coefficient,
use_additional_conditions=False,
dtype=dtype,
device=device,
operations=self.operations,
)
if self.cross_attention_adaln:
self.audio_prompt_adaln_single = AdaLayerNormSingle(
self.audio_inner_dim,
embedding_coefficient=2,
use_additional_conditions=False,
dtype=dtype,
device=device,
operations=self.operations,
)
else:
self.audio_prompt_adaln_single = None
num_scale_shift_values = 4
self.av_ca_video_scale_shift_adaln_single = AdaLayerNormSingle(
self.inner_dim,
@@ -443,35 +510,73 @@ class LTXAVModel(LTXVModel):
)
# Audio caption projection
self.audio_caption_projection = PixArtAlphaTextProjection(
in_features=self.caption_channels,
hidden_size=self.audio_inner_dim,
dtype=dtype,
device=device,
operations=self.operations,
)
if self.caption_proj_before_connector:
if self.caption_projection_first_linear:
self.audio_caption_projection = NormSingleLinearTextProjection(
in_features=self.caption_channels,
hidden_size=self.audio_inner_dim,
dtype=dtype,
device=device,
operations=self.operations,
)
else:
self.audio_caption_projection = lambda a: a
else:
self.audio_caption_projection = PixArtAlphaTextProjection(
in_features=self.caption_channels,
hidden_size=self.audio_inner_dim,
dtype=dtype,
device=device,
operations=self.operations,
)
connector_split_rope = kwargs.get("rope_type", "split") == "split"
connector_gated_attention = kwargs.get("connector_apply_gated_attention", False)
attention_head_dim = kwargs.get("connector_attention_head_dim", 128)
num_attention_heads = kwargs.get("connector_num_attention_heads", 30)
num_layers = kwargs.get("connector_num_layers", 2)
self.audio_embeddings_connector = Embeddings1DConnector(
split_rope=True,
attention_head_dim=kwargs.get("audio_connector_attention_head_dim", attention_head_dim),
num_attention_heads=kwargs.get("audio_connector_num_attention_heads", num_attention_heads),
num_layers=num_layers,
split_rope=connector_split_rope,
double_precision_rope=True,
apply_gated_attention=connector_gated_attention,
dtype=dtype,
device=device,
operations=self.operations,
)
self.video_embeddings_connector = Embeddings1DConnector(
split_rope=True,
attention_head_dim=attention_head_dim,
num_attention_heads=num_attention_heads,
num_layers=num_layers,
split_rope=connector_split_rope,
double_precision_rope=True,
apply_gated_attention=connector_gated_attention,
dtype=dtype,
device=device,
operations=self.operations,
)
def preprocess_text_embeds(self, context):
if context.shape[-1] == self.caption_channels * 2:
return context
out_vid = self.video_embeddings_connector(context)[0]
out_audio = self.audio_embeddings_connector(context)[0]
def preprocess_text_embeds(self, context, unprocessed=False):
# LTXv2 fully processed context has dimension of self.caption_channels * 2
# LTXv2.3 fully processed context has dimension of self.cross_attention_dim + self.audio_cross_attention_dim
if not unprocessed:
if context.shape[-1] in (self.cross_attention_dim + self.audio_cross_attention_dim, self.caption_channels * 2):
return context
if context.shape[-1] == self.cross_attention_dim + self.audio_cross_attention_dim:
context_vid = context[:, :, :self.cross_attention_dim]
context_audio = context[:, :, self.cross_attention_dim:]
else:
context_vid = context
context_audio = context
if self.caption_proj_before_connector:
context_vid = self.caption_projection(context_vid)
context_audio = self.audio_caption_projection(context_audio)
out_vid = self.video_embeddings_connector(context_vid)[0]
out_audio = self.audio_embeddings_connector(context_audio)[0]
return torch.concat((out_vid, out_audio), dim=-1)
def _init_transformer_blocks(self, device, dtype, **kwargs):
@@ -487,6 +592,8 @@ class LTXAVModel(LTXVModel):
ad_head=self.audio_attention_head_dim,
v_context_dim=self.cross_attention_dim,
a_context_dim=self.audio_cross_attention_dim,
apply_gated_attention=self.apply_gated_attention,
cross_attention_adaln=self.cross_attention_adaln,
dtype=dtype,
device=device,
operations=self.operations,
@@ -608,6 +715,10 @@ class LTXAVModel(LTXVModel):
v_timestep = CompressedTimestep(v_timestep.view(batch_size, -1, v_timestep.shape[-1]), v_patches_per_frame)
v_embedded_timestep = CompressedTimestep(v_embedded_timestep.view(batch_size, -1, v_embedded_timestep.shape[-1]), v_patches_per_frame)
v_prompt_timestep = compute_prompt_timestep(
self.prompt_adaln_single, timestep_scaled, batch_size, hidden_dtype
)
# Prepare audio timestep
a_timestep = kwargs.get("a_timestep")
if a_timestep is not None:
@@ -618,25 +729,25 @@ class LTXAVModel(LTXVModel):
# Cross-attention timesteps - compress these too
av_ca_audio_scale_shift_timestep, _ = self.av_ca_audio_scale_shift_adaln_single(
a_timestep_flat,
timestep.max().expand_as(a_timestep_flat),
{"resolution": None, "aspect_ratio": None},
batch_size=batch_size,
hidden_dtype=hidden_dtype,
)
av_ca_video_scale_shift_timestep, _ = self.av_ca_video_scale_shift_adaln_single(
timestep_flat,
a_timestep.max().expand_as(timestep_flat),
{"resolution": None, "aspect_ratio": None},
batch_size=batch_size,
hidden_dtype=hidden_dtype,
)
av_ca_a2v_gate_noise_timestep, _ = self.av_ca_a2v_gate_adaln_single(
timestep_flat * av_ca_factor,
a_timestep.max().expand_as(timestep_flat) * av_ca_factor,
{"resolution": None, "aspect_ratio": None},
batch_size=batch_size,
hidden_dtype=hidden_dtype,
)
av_ca_v2a_gate_noise_timestep, _ = self.av_ca_v2a_gate_adaln_single(
a_timestep_flat * av_ca_factor,
timestep.max().expand_as(a_timestep_flat) * av_ca_factor,
{"resolution": None, "aspect_ratio": None},
batch_size=batch_size,
hidden_dtype=hidden_dtype,
@@ -660,29 +771,40 @@ class LTXAVModel(LTXVModel):
# Audio timesteps
a_timestep = a_timestep.view(batch_size, -1, a_timestep.shape[-1])
a_embedded_timestep = a_embedded_timestep.view(batch_size, -1, a_embedded_timestep.shape[-1])
a_prompt_timestep = compute_prompt_timestep(
self.audio_prompt_adaln_single, a_timestep_scaled, batch_size, hidden_dtype
)
else:
a_timestep = timestep_scaled
a_embedded_timestep = kwargs.get("embedded_timestep")
cross_av_timestep_ss = []
a_prompt_timestep = None
return [v_timestep, a_timestep, cross_av_timestep_ss], [
return [v_timestep, a_timestep, cross_av_timestep_ss, v_prompt_timestep, a_prompt_timestep], [
v_embedded_timestep,
a_embedded_timestep,
]
], None
def _prepare_context(self, context, batch_size, x, attention_mask=None):
vx = x[0]
ax = x[1]
video_dim = vx.shape[-1]
audio_dim = ax.shape[-1]
v_context_dim = self.caption_channels if self.caption_proj_before_connector is False else video_dim
a_context_dim = self.caption_channels if self.caption_proj_before_connector is False else audio_dim
v_context, a_context = torch.split(
context, int(context.shape[-1] / 2), len(context.shape) - 1
context, [v_context_dim, a_context_dim], len(context.shape) - 1
)
v_context, attention_mask = super()._prepare_context(
v_context, batch_size, vx, attention_mask
)
if self.audio_caption_projection is not None:
if self.caption_proj_before_connector is False:
a_context = self.audio_caption_projection(a_context)
a_context = a_context.view(batch_size, -1, ax.shape[-1])
a_context = a_context.view(batch_size, -1, audio_dim)
return [v_context, a_context], attention_mask
@@ -744,6 +866,9 @@ class LTXAVModel(LTXVModel):
av_ca_v2a_gate_noise_timestep,
) = timestep[2]
v_prompt_timestep = timestep[3]
a_prompt_timestep = timestep[4]
"""Process transformer blocks for LTXAV."""
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
@@ -771,6 +896,8 @@ class LTXAVModel(LTXVModel):
a_cross_gate_timestep=args["a_cross_gate_timestep"],
transformer_options=args["transformer_options"],
self_attention_mask=args.get("self_attention_mask"),
v_prompt_timestep=args.get("v_prompt_timestep"),
a_prompt_timestep=args.get("a_prompt_timestep"),
)
return out
@@ -792,6 +919,8 @@ class LTXAVModel(LTXVModel):
"a_cross_gate_timestep": av_ca_v2a_gate_noise_timestep,
"transformer_options": transformer_options,
"self_attention_mask": self_attention_mask,
"v_prompt_timestep": v_prompt_timestep,
"a_prompt_timestep": a_prompt_timestep,
},
{"original_block": block_wrap},
)
@@ -814,6 +943,8 @@ class LTXAVModel(LTXVModel):
a_cross_gate_timestep=av_ca_v2a_gate_noise_timestep,
transformer_options=transformer_options,
self_attention_mask=self_attention_mask,
v_prompt_timestep=v_prompt_timestep,
a_prompt_timestep=a_prompt_timestep,
)
return [vx, ax]

View File

@@ -50,6 +50,7 @@ class BasicTransformerBlock1D(nn.Module):
d_head,
context_dim=None,
attn_precision=None,
apply_gated_attention=False,
dtype=None,
device=None,
operations=None,
@@ -63,6 +64,7 @@ class BasicTransformerBlock1D(nn.Module):
heads=n_heads,
dim_head=d_head,
context_dim=None,
apply_gated_attention=apply_gated_attention,
dtype=dtype,
device=device,
operations=operations,
@@ -121,6 +123,7 @@ class Embeddings1DConnector(nn.Module):
positional_embedding_max_pos=[4096],
causal_temporal_positioning=False,
num_learnable_registers: Optional[int] = 128,
apply_gated_attention=False,
dtype=None,
device=None,
operations=None,
@@ -145,6 +148,7 @@ class Embeddings1DConnector(nn.Module):
num_attention_heads,
attention_head_dim,
context_dim=cross_attention_dim,
apply_gated_attention=apply_gated_attention,
dtype=dtype,
device=device,
operations=operations,

View File

@@ -275,6 +275,30 @@ class PixArtAlphaTextProjection(nn.Module):
return hidden_states
class NormSingleLinearTextProjection(nn.Module):
"""Text projection for 20B models - single linear with RMSNorm (no activation)."""
def __init__(
self, in_features, hidden_size, dtype=None, device=None, operations=None
):
super().__init__()
if operations is None:
operations = comfy.ops.disable_weight_init
self.in_norm = operations.RMSNorm(
in_features, eps=1e-6, elementwise_affine=False
)
self.linear_1 = operations.Linear(
in_features, hidden_size, bias=True, dtype=dtype, device=device
)
self.hidden_size = hidden_size
self.in_features = in_features
def forward(self, caption):
caption = self.in_norm(caption)
caption = caption * (self.hidden_size / self.in_features) ** 0.5
return self.linear_1(caption)
class GELU_approx(nn.Module):
def __init__(self, dim_in, dim_out, dtype=None, device=None, operations=None):
super().__init__()
@@ -343,6 +367,7 @@ class CrossAttention(nn.Module):
dim_head=64,
dropout=0.0,
attn_precision=None,
apply_gated_attention=False,
dtype=None,
device=None,
operations=None,
@@ -362,6 +387,12 @@ class CrossAttention(nn.Module):
self.to_k = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)
self.to_v = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)
# Optional per-head gating
if apply_gated_attention:
self.to_gate_logits = operations.Linear(query_dim, heads, bias=True, dtype=dtype, device=device)
else:
self.to_gate_logits = None
self.to_out = nn.Sequential(
operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout)
)
@@ -383,16 +414,30 @@ class CrossAttention(nn.Module):
out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options)
else:
out = comfy.ldm.modules.attention.optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision, transformer_options=transformer_options)
# Apply per-head gating if enabled
if self.to_gate_logits is not None:
gate_logits = self.to_gate_logits(x) # (B, T, H)
b, t, _ = out.shape
out = out.view(b, t, self.heads, self.dim_head)
gates = 2.0 * torch.sigmoid(gate_logits) # zero-init -> identity
out = out * gates.unsqueeze(-1)
out = out.view(b, t, self.heads * self.dim_head)
return self.to_out(out)
# 6 base ADaLN params (shift/scale/gate for MSA + MLP), +3 for cross-attention Q (shift/scale/gate)
ADALN_BASE_PARAMS_COUNT = 6
ADALN_CROSS_ATTN_PARAMS_COUNT = 9
class BasicTransformerBlock(nn.Module):
def __init__(
self, dim, n_heads, d_head, context_dim=None, attn_precision=None, dtype=None, device=None, operations=None
self, dim, n_heads, d_head, context_dim=None, attn_precision=None, cross_attention_adaln=False, dtype=None, device=None, operations=None
):
super().__init__()
self.attn_precision = attn_precision
self.cross_attention_adaln = cross_attention_adaln
self.attn1 = CrossAttention(
query_dim=dim,
heads=n_heads,
@@ -416,18 +461,25 @@ class BasicTransformerBlock(nn.Module):
operations=operations,
)
self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype))
num_ada_params = ADALN_CROSS_ATTN_PARAMS_COUNT if cross_attention_adaln else ADALN_BASE_PARAMS_COUNT
self.scale_shift_table = nn.Parameter(torch.empty(num_ada_params, dim, device=device, dtype=dtype))
def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}, self_attention_mask=None):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2)
if cross_attention_adaln:
self.prompt_scale_shift_table = nn.Parameter(torch.empty(2, dim, device=device, dtype=dtype))
attn1_input = comfy.ldm.common_dit.rms_norm(x)
attn1_input = torch.addcmul(attn1_input, attn1_input, scale_msa).add_(shift_msa)
attn1_input = self.attn1(attn1_input, pe=pe, mask=self_attention_mask, transformer_options=transformer_options)
x.addcmul_(attn1_input, gate_msa)
del attn1_input
def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}, self_attention_mask=None, prompt_timestep=None):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None, :6].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)[:, :, :6, :]).unbind(dim=2)
x += self.attn2(x, context=context, mask=attention_mask, transformer_options=transformer_options)
x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe, mask=self_attention_mask, transformer_options=transformer_options) * gate_msa
if self.cross_attention_adaln:
shift_q_mca, scale_q_mca, gate_mca = (self.scale_shift_table[None, None, 6:9].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)[:, :, 6:9, :]).unbind(dim=2)
x += apply_cross_attention_adaln(
x, context, self.attn2, shift_q_mca, scale_q_mca, gate_mca,
self.prompt_scale_shift_table, prompt_timestep, attention_mask, transformer_options,
)
else:
x += self.attn2(x, context=context, mask=attention_mask, transformer_options=transformer_options)
y = comfy.ldm.common_dit.rms_norm(x)
y = torch.addcmul(y, y, scale_mlp).add_(shift_mlp)
@@ -435,6 +487,47 @@ class BasicTransformerBlock(nn.Module):
return x
def compute_prompt_timestep(adaln_module, timestep_scaled, batch_size, hidden_dtype):
"""Compute a single global prompt timestep for cross-attention ADaLN.
Uses the max across tokens (matching JAX max_per_segment) and broadcasts
over text tokens. Returns None when *adaln_module* is None.
"""
if adaln_module is None:
return None
ts_input = (
timestep_scaled.max(dim=1, keepdim=True).values.flatten()
if timestep_scaled.dim() > 1
else timestep_scaled.flatten()
)
prompt_ts, _ = adaln_module(
ts_input,
{"resolution": None, "aspect_ratio": None},
batch_size=batch_size,
hidden_dtype=hidden_dtype,
)
return prompt_ts.view(batch_size, 1, prompt_ts.shape[-1])
def apply_cross_attention_adaln(
x, context, attn, q_shift, q_scale, q_gate,
prompt_scale_shift_table, prompt_timestep,
attention_mask=None, transformer_options={},
):
"""Apply cross-attention with ADaLN modulation (shift/scale/gate on Q and KV).
Q params (q_shift, q_scale, q_gate) are pre-extracted by the caller so
that both regular tensors and CompressedTimestep are supported.
"""
batch_size = x.shape[0]
shift_kv, scale_kv = (
prompt_scale_shift_table[None, None].to(device=x.device, dtype=x.dtype)
+ prompt_timestep.reshape(batch_size, prompt_timestep.shape[1], 2, -1)
).unbind(dim=2)
attn_input = comfy.ldm.common_dit.rms_norm(x) * (1 + q_scale) + q_shift
encoder_hidden_states = context * (1 + scale_kv) + shift_kv
return attn(attn_input, context=encoder_hidden_states, mask=attention_mask, transformer_options=transformer_options) * q_gate
def get_fractional_positions(indices_grid, max_pos):
n_pos_dims = indices_grid.shape[1]
assert n_pos_dims == len(max_pos), f'Number of position dimensions ({n_pos_dims}) must match max_pos length ({len(max_pos)})'
@@ -556,6 +649,9 @@ class LTXBaseModel(torch.nn.Module, ABC):
vae_scale_factors: tuple = (8, 32, 32),
use_middle_indices_grid=False,
timestep_scale_multiplier = 1000.0,
caption_proj_before_connector=False,
cross_attention_adaln=False,
caption_projection_first_linear=True,
dtype=None,
device=None,
operations=None,
@@ -582,6 +678,9 @@ class LTXBaseModel(torch.nn.Module, ABC):
self.causal_temporal_positioning = causal_temporal_positioning
self.operations = operations
self.timestep_scale_multiplier = timestep_scale_multiplier
self.caption_proj_before_connector = caption_proj_before_connector
self.cross_attention_adaln = cross_attention_adaln
self.caption_projection_first_linear = caption_projection_first_linear
# Common dimensions
self.inner_dim = num_attention_heads * attention_head_dim
@@ -609,17 +708,37 @@ class LTXBaseModel(torch.nn.Module, ABC):
self.in_channels, self.inner_dim, bias=True, dtype=dtype, device=device
)
embedding_coefficient = ADALN_CROSS_ATTN_PARAMS_COUNT if self.cross_attention_adaln else ADALN_BASE_PARAMS_COUNT
self.adaln_single = AdaLayerNormSingle(
self.inner_dim, use_additional_conditions=False, dtype=dtype, device=device, operations=self.operations
self.inner_dim, embedding_coefficient=embedding_coefficient, use_additional_conditions=False, dtype=dtype, device=device, operations=self.operations
)
self.caption_projection = PixArtAlphaTextProjection(
in_features=self.caption_channels,
hidden_size=self.inner_dim,
dtype=dtype,
device=device,
operations=self.operations,
)
if self.cross_attention_adaln:
self.prompt_adaln_single = AdaLayerNormSingle(
self.inner_dim, embedding_coefficient=2, use_additional_conditions=False, dtype=dtype, device=device, operations=self.operations
)
else:
self.prompt_adaln_single = None
if self.caption_proj_before_connector:
if self.caption_projection_first_linear:
self.caption_projection = NormSingleLinearTextProjection(
in_features=self.caption_channels,
hidden_size=self.inner_dim,
dtype=dtype,
device=device,
operations=self.operations,
)
else:
self.caption_projection = lambda a: a
else:
self.caption_projection = PixArtAlphaTextProjection(
in_features=self.caption_channels,
hidden_size=self.inner_dim,
dtype=dtype,
device=device,
operations=self.operations,
)
@abstractmethod
def _init_model_components(self, device, dtype, **kwargs):
@@ -665,9 +784,9 @@ class LTXBaseModel(torch.nn.Module, ABC):
if grid_mask is not None:
timestep = timestep[:, grid_mask]
timestep = timestep * self.timestep_scale_multiplier
timestep_scaled = timestep * self.timestep_scale_multiplier
timestep, embedded_timestep = self.adaln_single(
timestep.flatten(),
timestep_scaled.flatten(),
{"resolution": None, "aspect_ratio": None},
batch_size=batch_size,
hidden_dtype=hidden_dtype,
@@ -677,14 +796,18 @@ class LTXBaseModel(torch.nn.Module, ABC):
timestep = timestep.view(batch_size, -1, timestep.shape[-1])
embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.shape[-1])
return timestep, embedded_timestep
prompt_timestep = compute_prompt_timestep(
self.prompt_adaln_single, timestep_scaled, batch_size, hidden_dtype
)
return timestep, embedded_timestep, prompt_timestep
def _prepare_context(self, context, batch_size, x, attention_mask=None):
"""Prepare context for transformer blocks."""
if self.caption_projection is not None:
if self.caption_proj_before_connector is False:
context = self.caption_projection(context)
context = context.view(batch_size, -1, x.shape[-1])
context = context.view(batch_size, -1, x.shape[-1])
return context, attention_mask
def _precompute_freqs_cis(
@@ -792,7 +915,8 @@ class LTXBaseModel(torch.nn.Module, ABC):
merged_args.update(additional_args)
# Prepare timestep and context
timestep, embedded_timestep = self._prepare_timestep(timestep, batch_size, input_dtype, **merged_args)
timestep, embedded_timestep, prompt_timestep = self._prepare_timestep(timestep, batch_size, input_dtype, **merged_args)
merged_args["prompt_timestep"] = prompt_timestep
context, attention_mask = self._prepare_context(context, batch_size, x, attention_mask)
# Prepare attention mask and positional embeddings
@@ -833,7 +957,9 @@ class LTXVModel(LTXBaseModel):
causal_temporal_positioning=False,
vae_scale_factors=(8, 32, 32),
use_middle_indices_grid=False,
timestep_scale_multiplier = 1000.0,
timestep_scale_multiplier=1000.0,
caption_proj_before_connector=False,
cross_attention_adaln=False,
dtype=None,
device=None,
operations=None,
@@ -852,6 +978,8 @@ class LTXVModel(LTXBaseModel):
vae_scale_factors=vae_scale_factors,
use_middle_indices_grid=use_middle_indices_grid,
timestep_scale_multiplier=timestep_scale_multiplier,
caption_proj_before_connector=caption_proj_before_connector,
cross_attention_adaln=cross_attention_adaln,
dtype=dtype,
device=device,
operations=operations,
@@ -860,7 +988,6 @@ class LTXVModel(LTXBaseModel):
def _init_model_components(self, device, dtype, **kwargs):
"""Initialize LTXV-specific components."""
# No additional components needed for LTXV beyond base class
pass
def _init_transformer_blocks(self, device, dtype, **kwargs):
@@ -872,6 +999,7 @@ class LTXVModel(LTXBaseModel):
self.num_attention_heads,
self.attention_head_dim,
context_dim=self.cross_attention_dim,
cross_attention_adaln=self.cross_attention_adaln,
dtype=dtype,
device=device,
operations=self.operations,
@@ -1149,16 +1277,17 @@ class LTXVModel(LTXBaseModel):
"""Process transformer blocks for LTXV."""
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
prompt_timestep = kwargs.get("prompt_timestep", None)
for i, block in enumerate(self.transformer_blocks):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"], self_attention_mask=args.get("self_attention_mask"))
out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"], self_attention_mask=args.get("self_attention_mask"), prompt_timestep=args.get("prompt_timestep"))
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe, "transformer_options": transformer_options, "self_attention_mask": self_attention_mask}, {"original_block": block_wrap})
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe, "transformer_options": transformer_options, "self_attention_mask": self_attention_mask, "prompt_timestep": prompt_timestep}, {"original_block": block_wrap})
x = out["img"]
else:
x = block(
@@ -1169,6 +1298,7 @@ class LTXVModel(LTXBaseModel):
pe=pe,
transformer_options=transformer_options,
self_attention_mask=self_attention_mask,
prompt_timestep=prompt_timestep,
)
return x

View File

@@ -13,7 +13,7 @@ from comfy.ldm.lightricks.vae.causal_audio_autoencoder import (
CausalityAxis,
CausalAudioAutoencoder,
)
from comfy.ldm.lightricks.vocoders.vocoder import Vocoder
from comfy.ldm.lightricks.vocoders.vocoder import Vocoder, VocoderWithBWE
LATENT_DOWNSAMPLE_FACTOR = 4
@@ -141,7 +141,10 @@ class AudioVAE(torch.nn.Module):
vocoder_sd = utils.state_dict_prefix_replace(state_dict, {"vocoder.": ""}, filter_keys=True)
self.autoencoder = CausalAudioAutoencoder(config=component_config.autoencoder)
self.vocoder = Vocoder(config=component_config.vocoder)
if "bwe" in component_config.vocoder:
self.vocoder = VocoderWithBWE(config=component_config.vocoder)
else:
self.vocoder = Vocoder(config=component_config.vocoder)
self.autoencoder.load_state_dict(vae_sd, strict=False)
self.vocoder.load_state_dict(vocoder_sd, strict=False)

View File

@@ -822,26 +822,23 @@ class CausalAudioAutoencoder(nn.Module):
super().__init__()
if config is None:
config = self._guess_config()
config = self.get_default_config()
# Extract encoder and decoder configs from the new format
model_config = config.get("model", {}).get("params", {})
variables_config = config.get("variables", {})
self.sampling_rate = variables_config.get(
"sampling_rate",
model_config.get("sampling_rate", config.get("sampling_rate", 16000)),
self.sampling_rate = model_config.get(
"sampling_rate", config.get("sampling_rate", 16000)
)
encoder_config = model_config.get("encoder", model_config.get("ddconfig", {}))
decoder_config = model_config.get("decoder", encoder_config)
# Load mel spectrogram parameters
self.mel_bins = encoder_config.get("mel_bins", 64)
self.mel_hop_length = model_config.get("preprocessing", {}).get("stft", {}).get("hop_length", 160)
self.n_fft = model_config.get("preprocessing", {}).get("stft", {}).get("filter_length", 1024)
self.mel_hop_length = config.get("preprocessing", {}).get("stft", {}).get("hop_length", 160)
self.n_fft = config.get("preprocessing", {}).get("stft", {}).get("filter_length", 1024)
# Store causality configuration at VAE level (not just in encoder internals)
causality_axis_value = encoder_config.get("causality_axis", CausalityAxis.WIDTH.value)
causality_axis_value = encoder_config.get("causality_axis", CausalityAxis.HEIGHT.value)
self.causality_axis = CausalityAxis.str_to_enum(causality_axis_value)
self.is_causal = self.causality_axis == CausalityAxis.HEIGHT
@@ -850,44 +847,38 @@ class CausalAudioAutoencoder(nn.Module):
self.per_channel_statistics = processor()
def _guess_config(self):
encoder_config = {
# Required parameters - based on ltx-video-av-1679000 model metadata
"ch": 128,
"out_ch": 8,
"ch_mult": [1, 2, 4], # Based on metadata: [1, 2, 4] not [1, 2, 4, 8]
"num_res_blocks": 2,
"attn_resolutions": [], # Based on metadata: empty list, no attention
"dropout": 0.0,
"resamp_with_conv": True,
"in_channels": 2, # stereo
"resolution": 256,
"z_channels": 8,
def get_default_config(self):
ddconfig = {
"double_z": True,
"attn_type": "vanilla",
"mid_block_add_attention": False, # Based on metadata: false
"mel_bins": 64,
"z_channels": 8,
"resolution": 256,
"downsample_time": False,
"in_channels": 2,
"out_ch": 2,
"ch": 128,
"ch_mult": [1, 2, 4],
"num_res_blocks": 2,
"attn_resolutions": [],
"dropout": 0.0,
"mid_block_add_attention": False,
"norm_type": "pixel",
"causality_axis": "height", # Based on metadata
"mel_bins": 64, # Based on metadata: mel_bins = 64
}
decoder_config = {
# Inherits encoder config, can override specific params
**encoder_config,
"out_ch": 2, # Stereo audio output (2 channels)
"give_pre_end": False,
"tanh_out": False,
"causality_axis": "height",
}
config = {
"_class_name": "CausalAudioAutoencoder",
"sampling_rate": 16000,
"model": {
"params": {
"encoder": encoder_config,
"decoder": decoder_config,
"ddconfig": ddconfig,
"sampling_rate": 16000,
}
},
"preprocessing": {
"stft": {
"filter_length": 1024,
"hop_length": 160,
},
},
}
return config

View File

@@ -15,6 +15,9 @@ from comfy.ldm.modules.diffusionmodules.model import torch_cat_if_needed
ops = comfy.ops.disable_weight_init
def in_meta_context():
return torch.device("meta") == torch.empty(0).device
def mark_conv3d_ended(module):
tid = threading.get_ident()
for _, m in module.named_modules():
@@ -350,6 +353,10 @@ class Decoder(nn.Module):
output_channel = output_channel * block_params.get("multiplier", 2)
if block_name == "compress_all":
output_channel = output_channel * block_params.get("multiplier", 1)
if block_name == "compress_space":
output_channel = output_channel * block_params.get("multiplier", 1)
if block_name == "compress_time":
output_channel = output_channel * block_params.get("multiplier", 1)
self.conv_in = make_conv_nd(
dims,
@@ -395,17 +402,21 @@ class Decoder(nn.Module):
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_time":
output_channel = output_channel // block_params.get("multiplier", 1)
block = DepthToSpaceUpsample(
dims=dims,
in_channels=input_channel,
stride=(2, 1, 1),
out_channels_reduction_factor=block_params.get("multiplier", 1),
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_space":
output_channel = output_channel // block_params.get("multiplier", 1)
block = DepthToSpaceUpsample(
dims=dims,
in_channels=input_channel,
stride=(1, 2, 2),
out_channels_reduction_factor=block_params.get("multiplier", 1),
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_all":
@@ -455,6 +466,15 @@ class Decoder(nn.Module):
output_channel * 2, 0, operations=ops,
)
self.last_scale_shift_table = nn.Parameter(torch.empty(2, output_channel))
else:
self.register_buffer(
"last_scale_shift_table",
torch.tensor(
[0.0, 0.0],
device="cpu" if in_meta_context() else None
).unsqueeze(1).expand(2, output_channel),
persistent=False,
)
# def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor:
@@ -883,6 +903,15 @@ class ResnetBlock3D(nn.Module):
self.scale_shift_table = nn.Parameter(
torch.randn(4, in_channels) / in_channels**0.5
)
else:
self.register_buffer(
"scale_shift_table",
torch.tensor(
[0.0, 0.0, 0.0, 0.0],
device="cpu" if in_meta_context() else None
).unsqueeze(1).expand(4, in_channels),
persistent=False,
)
self.temporal_cache_state={}
@@ -1012,9 +1041,6 @@ class processor(nn.Module):
super().__init__()
self.register_buffer("std-of-means", torch.empty(128))
self.register_buffer("mean-of-means", torch.empty(128))
self.register_buffer("mean-of-stds", torch.empty(128))
self.register_buffer("mean-of-stds_over_std-of-means", torch.empty(128))
self.register_buffer("channel", torch.empty(128))
def un_normalize(self, x):
return (x * self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)) + self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)
@@ -1027,9 +1053,12 @@ class VideoVAE(nn.Module):
super().__init__()
if config is None:
config = self.guess_config(version)
config = self.get_default_config(version)
self.config = config
self.timestep_conditioning = config.get("timestep_conditioning", False)
self.decode_noise_scale = config.get("decode_noise_scale", 0.025)
self.decode_timestep = config.get("decode_timestep", 0.05)
double_z = config.get("double_z", True)
latent_log_var = config.get(
"latent_log_var", "per_channel" if double_z else "none"
@@ -1044,6 +1073,7 @@ class VideoVAE(nn.Module):
latent_log_var=latent_log_var,
norm_layer=config.get("norm_layer", "group_norm"),
spatial_padding_mode=config.get("spatial_padding_mode", "zeros"),
base_channels=config.get("encoder_base_channels", 128),
)
self.decoder = Decoder(
@@ -1051,6 +1081,7 @@ class VideoVAE(nn.Module):
in_channels=config["latent_channels"],
out_channels=config.get("out_channels", 3),
blocks=config.get("decoder_blocks", config.get("decoder_blocks", config.get("blocks"))),
base_channels=config.get("decoder_base_channels", 128),
patch_size=config.get("patch_size", 1),
norm_layer=config.get("norm_layer", "group_norm"),
causal=config.get("causal_decoder", False),
@@ -1060,7 +1091,7 @@ class VideoVAE(nn.Module):
self.per_channel_statistics = processor()
def guess_config(self, version):
def get_default_config(self, version):
if version == 0:
config = {
"_class_name": "CausalVideoAutoencoder",
@@ -1167,8 +1198,7 @@ class VideoVAE(nn.Module):
means, logvar = torch.chunk(self.encoder(x), 2, dim=1)
return self.per_channel_statistics.normalize(means)
def decode(self, x, timestep=0.05, noise_scale=0.025):
def decode(self, x):
if self.timestep_conditioning: #TODO: seed
x = torch.randn_like(x) * noise_scale + (1.0 - noise_scale) * x
return self.decoder(self.per_channel_statistics.un_normalize(x), timestep=timestep)
x = torch.randn_like(x) * self.decode_noise_scale + (1.0 - self.decode_noise_scale) * x
return self.decoder(self.per_channel_statistics.un_normalize(x), timestep=self.decode_timestep)

View File

@@ -2,7 +2,9 @@ import torch
import torch.nn.functional as F
import torch.nn as nn
import comfy.ops
import comfy.model_management
import numpy as np
import math
ops = comfy.ops.disable_weight_init
@@ -12,6 +14,307 @@ def get_padding(kernel_size, dilation=1):
return int((kernel_size * dilation - dilation) / 2)
# ---------------------------------------------------------------------------
# Anti-aliased resampling helpers (kaiser-sinc filters) for BigVGAN v2
# Adopted from https://github.com/NVIDIA/BigVGAN
# ---------------------------------------------------------------------------
def _sinc(x: torch.Tensor):
return torch.where(
x == 0,
torch.tensor(1.0, device=x.device, dtype=x.dtype),
torch.sin(math.pi * x) / math.pi / x,
)
def kaiser_sinc_filter1d(cutoff, half_width, kernel_size):
even = kernel_size % 2 == 0
half_size = kernel_size // 2
delta_f = 4 * half_width
A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
if A > 50.0:
beta = 0.1102 * (A - 8.7)
elif A >= 21.0:
beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0)
else:
beta = 0.0
window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
if even:
time = torch.arange(-half_size, half_size) + 0.5
else:
time = torch.arange(kernel_size) - half_size
if cutoff == 0:
filter_ = torch.zeros_like(time)
else:
filter_ = 2 * cutoff * window * _sinc(2 * cutoff * time)
filter_ /= filter_.sum()
filter = filter_.view(1, 1, kernel_size)
return filter
class LowPassFilter1d(nn.Module):
def __init__(
self,
cutoff=0.5,
half_width=0.6,
stride=1,
padding=True,
padding_mode="replicate",
kernel_size=12,
):
super().__init__()
if cutoff < -0.0:
raise ValueError("Minimum cutoff must be larger than zero.")
if cutoff > 0.5:
raise ValueError("A cutoff above 0.5 does not make sense.")
self.kernel_size = kernel_size
self.even = kernel_size % 2 == 0
self.pad_left = kernel_size // 2 - int(self.even)
self.pad_right = kernel_size // 2
self.stride = stride
self.padding = padding
self.padding_mode = padding_mode
filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
self.register_buffer("filter", filter)
def forward(self, x):
_, C, _ = x.shape
if self.padding:
x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
return F.conv1d(x, comfy.model_management.cast_to(self.filter.expand(C, -1, -1), dtype=x.dtype, device=x.device), stride=self.stride, groups=C)
class UpSample1d(nn.Module):
def __init__(self, ratio=2, kernel_size=None, persistent=True, window_type="kaiser"):
super().__init__()
self.ratio = ratio
self.stride = ratio
if window_type == "hann":
# Hann-windowed sinc filter — identical to torchaudio.functional.resample
# with its default parameters (rolloff=0.99, lowpass_filter_width=6).
# Uses replicate boundary padding, matching the reference resampler exactly.
rolloff = 0.99
lowpass_filter_width = 6
width = math.ceil(lowpass_filter_width / rolloff)
self.kernel_size = 2 * width * ratio + 1
self.pad = width
self.pad_left = 2 * width * ratio
self.pad_right = self.kernel_size - ratio
t = (torch.arange(self.kernel_size) / ratio - width) * rolloff
t_clamped = t.clamp(-lowpass_filter_width, lowpass_filter_width)
window = torch.cos(t_clamped * math.pi / lowpass_filter_width / 2) ** 2
filter = (torch.sinc(t) * window * rolloff / ratio).view(1, 1, -1)
else:
# Kaiser-windowed sinc filter (BigVGAN default).
self.kernel_size = (
int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
)
self.pad = self.kernel_size // ratio - 1
self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
self.pad_right = (
self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
)
filter = kaiser_sinc_filter1d(
cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size
)
self.register_buffer("filter", filter, persistent=persistent)
def forward(self, x):
_, C, _ = x.shape
x = F.pad(x, (self.pad, self.pad), mode="replicate")
x = self.ratio * F.conv_transpose1d(
x, comfy.model_management.cast_to(self.filter.expand(C, -1, -1), dtype=x.dtype, device=x.device), stride=self.stride, groups=C
)
x = x[..., self.pad_left : -self.pad_right]
return x
class DownSample1d(nn.Module):
def __init__(self, ratio=2, kernel_size=None):
super().__init__()
self.ratio = ratio
self.kernel_size = (
int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
)
self.lowpass = LowPassFilter1d(
cutoff=0.5 / ratio,
half_width=0.6 / ratio,
stride=ratio,
kernel_size=self.kernel_size,
)
def forward(self, x):
return self.lowpass(x)
class Activation1d(nn.Module):
def __init__(
self,
activation,
up_ratio=2,
down_ratio=2,
up_kernel_size=12,
down_kernel_size=12,
):
super().__init__()
self.act = activation
self.upsample = UpSample1d(up_ratio, up_kernel_size)
self.downsample = DownSample1d(down_ratio, down_kernel_size)
def forward(self, x):
x = self.upsample(x)
x = self.act(x)
x = self.downsample(x)
return x
# ---------------------------------------------------------------------------
# BigVGAN v2 activations (Snake / SnakeBeta)
# ---------------------------------------------------------------------------
class Snake(nn.Module):
def __init__(
self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True
):
super().__init__()
self.alpha_logscale = alpha_logscale
self.alpha = nn.Parameter(
torch.zeros(in_features)
if alpha_logscale
else torch.ones(in_features) * alpha
)
self.alpha.requires_grad = alpha_trainable
self.eps = 1e-9
def forward(self, x):
a = comfy.model_management.cast_to(self.alpha.unsqueeze(0).unsqueeze(-1), dtype=x.dtype, device=x.device)
if self.alpha_logscale:
a = torch.exp(a)
return x + (1.0 / (a + self.eps)) * torch.sin(x * a).pow(2)
class SnakeBeta(nn.Module):
def __init__(
self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True
):
super().__init__()
self.alpha_logscale = alpha_logscale
self.alpha = nn.Parameter(
torch.zeros(in_features)
if alpha_logscale
else torch.ones(in_features) * alpha
)
self.alpha.requires_grad = alpha_trainable
self.beta = nn.Parameter(
torch.zeros(in_features)
if alpha_logscale
else torch.ones(in_features) * alpha
)
self.beta.requires_grad = alpha_trainable
self.eps = 1e-9
def forward(self, x):
a = comfy.model_management.cast_to(self.alpha.unsqueeze(0).unsqueeze(-1), dtype=x.dtype, device=x.device)
b = comfy.model_management.cast_to(self.beta.unsqueeze(0).unsqueeze(-1), dtype=x.dtype, device=x.device)
if self.alpha_logscale:
a = torch.exp(a)
b = torch.exp(b)
return x + (1.0 / (b + self.eps)) * torch.sin(x * a).pow(2)
# ---------------------------------------------------------------------------
# BigVGAN v2 AMPBlock (Anti-aliased Multi-Periodicity)
# ---------------------------------------------------------------------------
class AMPBlock1(torch.nn.Module):
def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), activation="snake"):
super().__init__()
act_cls = SnakeBeta if activation == "snakebeta" else Snake
self.convs1 = nn.ModuleList(
[
ops.Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0]),
),
ops.Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1]),
),
ops.Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[2],
padding=get_padding(kernel_size, dilation[2]),
),
]
)
self.convs2 = nn.ModuleList(
[
ops.Conv1d(
channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1),
),
ops.Conv1d(
channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1),
),
ops.Conv1d(
channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1),
),
]
)
self.acts1 = nn.ModuleList(
[Activation1d(act_cls(channels)) for _ in range(len(self.convs1))]
)
self.acts2 = nn.ModuleList(
[Activation1d(act_cls(channels)) for _ in range(len(self.convs2))]
)
def forward(self, x):
for c1, c2, a1, a2 in zip(self.convs1, self.convs2, self.acts1, self.acts2):
xt = a1(x)
xt = c1(xt)
xt = a2(xt)
xt = c2(xt)
x = x + xt
return x
# ---------------------------------------------------------------------------
# HiFi-GAN residual blocks
# ---------------------------------------------------------------------------
class ResBlock1(torch.nn.Module):
def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
super(ResBlock1, self).__init__()
@@ -119,6 +422,7 @@ class Vocoder(torch.nn.Module):
"""
Vocoder model for synthesizing audio from spectrograms, based on: https://github.com/jik876/hifi-gan.
Supports both HiFi-GAN (resblock "1"/"2") and BigVGAN v2 (resblock "AMP1").
"""
def __init__(self, config=None):
@@ -128,19 +432,39 @@ class Vocoder(torch.nn.Module):
config = self.get_default_config()
resblock_kernel_sizes = config.get("resblock_kernel_sizes", [3, 7, 11])
upsample_rates = config.get("upsample_rates", [6, 5, 2, 2, 2])
upsample_kernel_sizes = config.get("upsample_kernel_sizes", [16, 15, 8, 4, 4])
upsample_rates = config.get("upsample_rates", [5, 4, 2, 2, 2])
upsample_kernel_sizes = config.get("upsample_kernel_sizes", [16, 16, 8, 4, 4])
resblock_dilation_sizes = config.get("resblock_dilation_sizes", [[1, 3, 5], [1, 3, 5], [1, 3, 5]])
upsample_initial_channel = config.get("upsample_initial_channel", 1024)
stereo = config.get("stereo", True)
resblock = config.get("resblock", "1")
activation = config.get("activation", "snake")
use_bias_at_final = config.get("use_bias_at_final", True)
# "output_sample_rate" is not present in recent checkpoint configs.
# When absent (None), AudioVAE.output_sample_rate computes it as:
# sample_rate * vocoder.upsample_factor / mel_hop_length
# where upsample_factor = product of all upsample stride lengths,
# and mel_hop_length is loaded from the autoencoder config at
# preprocessing.stft.hop_length (see CausalAudioAutoencoder).
self.output_sample_rate = config.get("output_sample_rate")
self.resblock = config.get("resblock", "1")
self.use_tanh_at_final = config.get("use_tanh_at_final", True)
self.apply_final_activation = config.get("apply_final_activation", True)
self.num_kernels = len(resblock_kernel_sizes)
self.num_upsamples = len(upsample_rates)
in_channels = 128 if stereo else 64
self.conv_pre = ops.Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3)
resblock_class = ResBlock1 if resblock == "1" else ResBlock2
if self.resblock == "1":
resblock_cls = ResBlock1
elif self.resblock == "2":
resblock_cls = ResBlock2
elif self.resblock == "AMP1":
resblock_cls = AMPBlock1
else:
raise ValueError(f"Unknown resblock type: {self.resblock}")
self.ups = nn.ModuleList()
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
@@ -157,25 +481,40 @@ class Vocoder(torch.nn.Module):
self.resblocks = nn.ModuleList()
for i in range(len(self.ups)):
ch = upsample_initial_channel // (2 ** (i + 1))
for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
self.resblocks.append(resblock_class(ch, k, d))
for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes):
if self.resblock == "AMP1":
self.resblocks.append(resblock_cls(ch, k, d, activation=activation))
else:
self.resblocks.append(resblock_cls(ch, k, d))
out_channels = 2 if stereo else 1
self.conv_post = ops.Conv1d(ch, out_channels, 7, 1, padding=3)
if self.resblock == "AMP1":
act_cls = SnakeBeta if activation == "snakebeta" else Snake
self.act_post = Activation1d(act_cls(ch))
else:
self.act_post = nn.LeakyReLU()
self.conv_post = ops.Conv1d(
ch, out_channels, 7, 1, padding=3, bias=use_bias_at_final
)
self.upsample_factor = np.prod([self.ups[i].stride[0] for i in range(len(self.ups))])
def get_default_config(self):
"""Generate default configuration for the vocoder."""
config = {
"resblock_kernel_sizes": [3, 7, 11],
"upsample_rates": [6, 5, 2, 2, 2],
"upsample_kernel_sizes": [16, 15, 8, 4, 4],
"upsample_rates": [5, 4, 2, 2, 2],
"upsample_kernel_sizes": [16, 16, 8, 4, 4],
"resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
"upsample_initial_channel": 1024,
"stereo": True,
"resblock": "1",
"activation": "snake",
"use_bias_at_final": True,
"use_tanh_at_final": True,
}
return config
@@ -196,8 +535,10 @@ class Vocoder(torch.nn.Module):
assert x.shape[1] == 2, "Input must have 2 channels for stereo"
x = torch.cat((x[:, 0, :, :], x[:, 1, :, :]), dim=1)
x = self.conv_pre(x)
for i in range(self.num_upsamples):
x = F.leaky_relu(x, LRELU_SLOPE)
if self.resblock != "AMP1":
x = F.leaky_relu(x, LRELU_SLOPE)
x = self.ups[i](x)
xs = None
for j in range(self.num_kernels):
@@ -206,8 +547,167 @@ class Vocoder(torch.nn.Module):
else:
xs += self.resblocks[i * self.num_kernels + j](x)
x = xs / self.num_kernels
x = F.leaky_relu(x)
x = self.act_post(x)
x = self.conv_post(x)
x = torch.tanh(x)
if self.apply_final_activation:
if self.use_tanh_at_final:
x = torch.tanh(x)
else:
x = torch.clamp(x, -1, 1)
return x
class _STFTFn(nn.Module):
"""Implements STFT as a convolution with precomputed DFT × Hann-window bases.
The DFT basis rows (real and imaginary parts interleaved) multiplied by the causal
Hann window are stored as buffers and loaded from the checkpoint. Using the exact
bfloat16 bases from training ensures the mel values fed to the BWE generator are
bit-identical to what it was trained on.
"""
def __init__(self, filter_length: int, hop_length: int, win_length: int):
super().__init__()
self.hop_length = hop_length
self.win_length = win_length
n_freqs = filter_length // 2 + 1
self.register_buffer("forward_basis", torch.zeros(n_freqs * 2, 1, filter_length))
self.register_buffer("inverse_basis", torch.zeros(n_freqs * 2, 1, filter_length))
def forward(self, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""Compute magnitude and phase spectrogram from a batch of waveforms.
Applies causal (left-only) padding of win_length - hop_length samples so that
each output frame depends only on past and present input — no lookahead.
The STFT is computed by convolving the padded signal with forward_basis.
Args:
y: Waveform tensor of shape (B, T).
Returns:
magnitude: Linear amplitude spectrogram, shape (B, n_freqs, T_frames).
phase: Phase spectrogram in radians, shape (B, n_freqs, T_frames).
Computed in float32 for numerical stability, then cast back to
the input dtype.
"""
if y.dim() == 2:
y = y.unsqueeze(1) # (B, 1, T)
left_pad = max(0, self.win_length - self.hop_length) # causal: left-only
y = F.pad(y, (left_pad, 0))
spec = F.conv1d(y, comfy.model_management.cast_to(self.forward_basis, dtype=y.dtype, device=y.device), stride=self.hop_length, padding=0)
n_freqs = spec.shape[1] // 2
real, imag = spec[:, :n_freqs], spec[:, n_freqs:]
magnitude = torch.sqrt(real ** 2 + imag ** 2)
phase = torch.atan2(imag.float(), real.float()).to(real.dtype)
return magnitude, phase
class MelSTFT(nn.Module):
"""Causal log-mel spectrogram module whose buffers are loaded from the checkpoint.
Computes a log-mel spectrogram by running the causal STFT (_STFTFn) on the input
waveform and projecting the linear magnitude spectrum onto the mel filterbank.
The module's state dict layout matches the 'mel_stft.*' keys stored in the checkpoint
(mel_basis, stft_fn.forward_basis, stft_fn.inverse_basis).
"""
def __init__(
self,
filter_length: int,
hop_length: int,
win_length: int,
n_mel_channels: int,
sampling_rate: int,
mel_fmin: float,
mel_fmax: float,
):
super().__init__()
self.stft_fn = _STFTFn(filter_length, hop_length, win_length)
n_freqs = filter_length // 2 + 1
self.register_buffer("mel_basis", torch.zeros(n_mel_channels, n_freqs))
def mel_spectrogram(
self, y: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Compute log-mel spectrogram and auxiliary spectral quantities.
Args:
y: Waveform tensor of shape (B, T).
Returns:
log_mel: Log-compressed mel spectrogram, shape (B, n_mel_channels, T_frames).
Computed as log(clamp(mel_basis @ magnitude, min=1e-5)).
magnitude: Linear amplitude spectrogram, shape (B, n_freqs, T_frames).
phase: Phase spectrogram in radians, shape (B, n_freqs, T_frames).
energy: Per-frame energy (L2 norm over frequency), shape (B, T_frames).
"""
magnitude, phase = self.stft_fn(y)
energy = torch.norm(magnitude, dim=1)
mel = torch.matmul(comfy.model_management.cast_to(self.mel_basis, dtype=magnitude.dtype, device=y.device), magnitude)
log_mel = torch.log(torch.clamp(mel, min=1e-5))
return log_mel, magnitude, phase, energy
class VocoderWithBWE(torch.nn.Module):
"""Vocoder with bandwidth extension (BWE) for higher sample rate output.
Chains a base vocoder (mel → low-rate waveform) with a BWE stage that upsamples
to a higher rate. The BWE computes a mel spectrogram from the low-rate waveform.
"""
def __init__(self, config):
super().__init__()
vocoder_config = config["vocoder"]
bwe_config = config["bwe"]
self.vocoder = Vocoder(config=vocoder_config)
self.bwe_generator = Vocoder(
config={**bwe_config, "apply_final_activation": False}
)
self.input_sample_rate = bwe_config["input_sampling_rate"]
self.output_sample_rate = bwe_config["output_sampling_rate"]
self.hop_length = bwe_config["hop_length"]
self.mel_stft = MelSTFT(
filter_length=bwe_config["n_fft"],
hop_length=bwe_config["hop_length"],
win_length=bwe_config["n_fft"],
n_mel_channels=bwe_config["num_mels"],
sampling_rate=bwe_config["input_sampling_rate"],
mel_fmin=0.0,
mel_fmax=bwe_config["input_sampling_rate"] / 2.0,
)
self.resampler = UpSample1d(
ratio=bwe_config["output_sampling_rate"] // bwe_config["input_sampling_rate"],
persistent=False,
window_type="hann",
)
def _compute_mel(self, audio):
"""Compute log-mel spectrogram from waveform using causal STFT bases."""
B, C, T = audio.shape
flat = audio.reshape(B * C, -1) # (B*C, T)
mel, _, _, _ = self.mel_stft.mel_spectrogram(flat) # (B*C, n_mels, T_frames)
return mel.reshape(B, C, mel.shape[1], mel.shape[2]) # (B, C, n_mels, T_frames)
def forward(self, mel_spec):
x = self.vocoder(mel_spec)
_, _, T_low = x.shape
T_out = T_low * self.output_sample_rate // self.input_sample_rate
remainder = T_low % self.hop_length
if remainder != 0:
x = F.pad(x, (0, self.hop_length - remainder))
mel = self._compute_mel(x)
residual = self.bwe_generator(mel)
skip = self.resampler(x)
assert residual.shape == skip.shape, f"residual {residual.shape} != skip {skip.shape}"
return torch.clamp(residual + skip, -1, 1)[..., :T_out]

View File

@@ -1021,7 +1021,7 @@ class LTXAV(BaseModel):
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
if hasattr(self.diffusion_model, "preprocess_text_embeds"):
cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype_inference()))
cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype_inference()), unprocessed=kwargs.get("unprocessed_ltxav_embeds", False))
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
out['frame_rate'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", 25))

View File

@@ -796,6 +796,8 @@ def archive_model_dtypes(model):
for name, module in model.named_modules():
for param_name, param in module.named_parameters(recurse=False):
setattr(module, f"{param_name}_comfy_model_dtype", param.dtype)
for buf_name, buf in module.named_buffers(recurse=False):
setattr(module, f"{buf_name}_comfy_model_dtype", buf.dtype)
def cleanup_models():
@@ -828,11 +830,14 @@ def unet_offload_device():
return torch.device("cpu")
def unet_inital_load_device(parameters, dtype):
cpu_dev = torch.device("cpu")
if comfy.memory_management.aimdo_enabled:
return cpu_dev
torch_dev = get_torch_device()
if vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.SHARED:
return torch_dev
cpu_dev = torch.device("cpu")
if DISABLE_SMART_MEMORY or vram_state == VRAMState.NO_VRAM:
return cpu_dev
@@ -840,7 +845,7 @@ def unet_inital_load_device(parameters, dtype):
mem_dev = get_free_memory(torch_dev)
mem_cpu = get_free_memory(cpu_dev)
if mem_dev > mem_cpu and model_size < mem_dev and comfy.memory_management.aimdo_enabled:
if mem_dev > mem_cpu and model_size < mem_dev:
return torch_dev
else:
return cpu_dev
@@ -943,6 +948,9 @@ def text_encoder_device():
return torch.device("cpu")
def text_encoder_initial_device(load_device, offload_device, model_size=0):
if comfy.memory_management.aimdo_enabled:
return offload_device
if load_device == offload_device or model_size <= 1024 * 1024 * 1024:
return offload_device
@@ -1658,12 +1666,16 @@ def lora_compute_dtype(device):
return dtype
def synchronize():
if cpu_mode():
return
if is_intel_xpu():
torch.xpu.synchronize()
elif torch.cuda.is_available():
torch.cuda.synchronize()
def soft_empty_cache(force=False):
if cpu_mode():
return
global cpu_state
if cpu_state == CPUState.MPS:
torch.mps.empty_cache()

View File

@@ -241,6 +241,7 @@ class ModelPatcher:
self.patches = {}
self.backup = {}
self.backup_buffers = {}
self.object_patches = {}
self.object_patches_backup = {}
self.weight_wrapper_patches = {}
@@ -306,10 +307,16 @@ class ModelPatcher:
return self.model.lowvram_patch_counter
def get_free_memory(self, device):
return comfy.model_management.get_free_memory(device)
#Prioritize batching (incl. CFG/conds etc) over keeping the model resident. In
#the vast majority of setups a little bit of offloading on the giant model more
#than pays for CFG. So return everything both torch and Aimdo could give us
aimdo_mem = 0
if comfy.memory_management.aimdo_enabled:
aimdo_mem = comfy_aimdo.model_vbar.vbars_analyze()
return comfy.model_management.get_free_memory(device) + aimdo_mem
def get_clone_model_override(self):
return self.model, (self.backup, self.object_patches_backup, self.pinned)
return self.model, (self.backup, self.backup_buffers, self.object_patches_backup, self.pinned)
def clone(self, disable_dynamic=False, model_override=None):
class_ = self.__class__
@@ -336,7 +343,7 @@ class ModelPatcher:
n.force_cast_weights = self.force_cast_weights
n.backup, n.object_patches_backup, n.pinned = model_override[1]
n.backup, n.backup_buffers, n.object_patches_backup, n.pinned = model_override[1]
# attachments
n.attachments = {}
@@ -698,7 +705,7 @@ class ModelPatcher:
for key in list(self.pinned):
self.unpin_weight(key)
def _load_list(self, prio_comfy_cast_weights=False, default_device=None):
def _load_list(self, for_dynamic=False, default_device=None):
loading = []
for n, m in self.model.named_modules():
default = False
@@ -726,8 +733,13 @@ class ModelPatcher:
return 0
module_offload_mem += check_module_offload_mem("{}.weight".format(n))
module_offload_mem += check_module_offload_mem("{}.bias".format(n))
prepend = (not hasattr(m, "comfy_cast_weights"),) if prio_comfy_cast_weights else ()
loading.append(prepend + (module_offload_mem, module_mem, n, m, params))
# Dynamic: small weights (<64KB) first, then larger weights prioritized by size.
# Non-dynamic: prioritize by module offload cost.
if for_dynamic:
sort_criteria = (module_offload_mem >= 64 * 1024, -module_offload_mem)
else:
sort_criteria = (module_offload_mem,)
loading.append(sort_criteria + (module_mem, n, m, params))
return loading
def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
@@ -1459,12 +1471,6 @@ class ModelPatcherDynamic(ModelPatcher):
vbar = self._vbar_get()
return (vbar.loaded_size() if vbar is not None else 0) + self.model.model_loaded_weight_memory
def get_free_memory(self, device):
#NOTE: on high condition / batch counts, estimate should have already vacated
#all non-dynamic models so this is safe even if its not 100% true that this
#would all be avaiable for inference use.
return comfy.model_management.get_total_memory(device) - self.model_size()
#Pinning is deferred to ops time. Assert against this API to avoid pin leaks.
def pin_weight_to_device(self, key):
@@ -1507,11 +1513,11 @@ class ModelPatcherDynamic(ModelPatcher):
if vbar is not None:
vbar.prioritize()
loading = self._load_list(prio_comfy_cast_weights=True, default_device=device_to)
loading.sort(reverse=True)
loading = self._load_list(for_dynamic=True, default_device=device_to)
loading.sort()
for x in loading:
_, _, _, n, m, params = x
*_, module_mem, n, m, params = x
def set_dirty(item, dirty):
if dirty or not hasattr(item, "_v_signature"):
@@ -1579,11 +1585,22 @@ class ModelPatcherDynamic(ModelPatcher):
weight, _, _ = get_key_weight(self.model, key)
if key not in self.backup:
self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight, False)
comfy.utils.set_attr_param(self.model, key, weight.to(device=device_to))
self.model.model_loaded_weight_memory += weight.numel() * weight.element_size()
model_dtype = getattr(m, param + "_comfy_model_dtype", None)
casted_weight = weight.to(dtype=model_dtype, device=device_to)
comfy.utils.set_attr_param(self.model, key, casted_weight)
self.model.model_loaded_weight_memory += casted_weight.numel() * casted_weight.element_size()
move_weight_functions(m, device_to)
for key, buf in self.model.named_buffers(recurse=True):
if key not in self.backup_buffers:
self.backup_buffers[key] = buf
module, buf_name = comfy.utils.resolve_attr(self.model, key)
model_dtype = getattr(module, buf_name + "_comfy_model_dtype", None)
casted_buf = buf.to(dtype=model_dtype, device=device_to)
comfy.utils.set_attr_buffer(self.model, key, casted_buf)
self.model.model_loaded_weight_memory += casted_buf.numel() * casted_buf.element_size()
force_load_stat = f" Force pre-loaded {len(self.backup)} weights: {self.model.model_loaded_weight_memory // 1024} KB." if len(self.backup) > 0 else ""
logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.{force_load_stat}")
@@ -1607,15 +1624,17 @@ class ModelPatcherDynamic(ModelPatcher):
for key in list(self.backup.keys()):
bk = self.backup.pop(key)
comfy.utils.set_attr_param(self.model, key, bk.weight)
for key in list(self.backup_buffers.keys()):
comfy.utils.set_attr_buffer(self.model, key, self.backup_buffers.pop(key))
freed += self.model.model_loaded_weight_memory
self.model.model_loaded_weight_memory = 0
return freed
def partially_unload_ram(self, ram_to_unload):
loading = self._load_list(prio_comfy_cast_weights=True, default_device=self.offload_device)
loading = self._load_list(for_dynamic=True, default_device=self.offload_device)
for x in loading:
_, _, _, _, m, _ = x
*_, m, _ = x
ram_to_unload -= comfy.pinned_memory.unpin_memory(m)
if ram_to_unload <= 0:
return

View File

@@ -80,6 +80,21 @@ def cast_to_input(weight, input, non_blocking=False, copy=True):
def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant):
#vbar doesn't support CPU weights, but some custom nodes have weird paths
#that might switch the layer to the CPU and expect it to work. We have to take
#a clone conservatively as we are mmapped and some SFT files are packed misaligned
#If you are a custom node author reading this, please move your layer to the GPU
#or declare your ModelPatcher as CPU in the first place.
if comfy.model_management.is_device_cpu(device):
weight = s.weight.to(dtype=dtype, copy=True)
if isinstance(weight, QuantizedTensor):
weight = weight.dequantize()
bias = None
if s.bias is not None:
bias = s.bias.to(dtype=bias_dtype, copy=True)
return weight, bias, (None, None, None)
offload_stream = None
xfer_dest = None
@@ -269,8 +284,8 @@ def uncast_bias_weight(s, weight, bias, offload_stream):
return
os, weight_a, bias_a = offload_stream
device=None
#FIXME: This is not good RTTI
if not isinstance(weight_a, torch.Tensor):
#FIXME: This is really bad RTTI
if weight_a is not None and not isinstance(weight_a, torch.Tensor):
comfy_aimdo.model_vbar.vbar_unpin(s._v)
device = weight_a
if os is None:
@@ -660,23 +675,29 @@ class fp8_ops(manual_cast):
CUBLAS_IS_AVAILABLE = False
try:
from cublas_ops import CublasLinear
from cublas_ops import CublasLinear, cublas_half_matmul
CUBLAS_IS_AVAILABLE = True
except ImportError:
pass
if CUBLAS_IS_AVAILABLE:
class cublas_ops(disable_weight_init):
class Linear(CublasLinear, disable_weight_init.Linear):
class cublas_ops(manual_cast):
class Linear(CublasLinear, manual_cast.Linear):
def reset_parameters(self):
return None
def forward_comfy_cast_weights(self, input):
return super().forward(input)
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
x = cublas_half_matmul(input, weight, bias, self._epilogue_str, self.has_bias)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
def forward(self, *args, **kwargs):
return super().forward(*args, **kwargs)
run_every_op()
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
# ==============================================================================
# Mixed Precision Operations

View File

@@ -428,7 +428,7 @@ class CLIP:
def generate(self, tokens, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.95, min_p=0.0, repetition_penalty=1.0, seed=None):
self.cond_stage_model.reset_clip_options()
self.load_model()
self.load_model(tokens)
self.cond_stage_model.set_clip_options({"layer": None})
self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
return self.cond_stage_model.generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed)
@@ -1467,7 +1467,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
clip_target.clip = comfy.text_encoders.kandinsky5.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage
elif clip_type == CLIPType.LTXV:
clip_target.clip = comfy.text_encoders.lt.ltxav_te(**llama_detect(clip_data))
clip_target.clip = comfy.text_encoders.lt.ltxav_te(**llama_detect(clip_data), **comfy.text_encoders.lt.sd_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.lt.LTXAVGemmaTokenizer
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
elif clip_type == CLIPType.NEWBIE:

View File

@@ -97,18 +97,39 @@ class Gemma3_12BModel(sd1_clip.SDClipModel):
comfy.utils.normalize_image_embeddings(embeds, embeds_info, self.transformer.model.config.hidden_size ** 0.5)
return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[106]) # 106 is <end_of_turn>
class DualLinearProjection(torch.nn.Module):
def __init__(self, in_dim, out_dim_video, out_dim_audio, dtype=None, device=None, operations=None):
super().__init__()
self.audio_aggregate_embed = operations.Linear(in_dim, out_dim_audio, bias=True, dtype=dtype, device=device)
self.video_aggregate_embed = operations.Linear(in_dim, out_dim_video, bias=True, dtype=dtype, device=device)
def forward(self, x):
source_dim = x.shape[-1]
x = x.movedim(1, -1)
x = (x * torch.rsqrt(torch.mean(x**2, dim=2, keepdim=True) + 1e-6)).flatten(start_dim=2)
video = self.video_aggregate_embed(x * math.sqrt(self.video_aggregate_embed.out_features / source_dim))
audio = self.audio_aggregate_embed(x * math.sqrt(self.audio_aggregate_embed.out_features / source_dim))
return torch.cat((video, audio), dim=-1)
class LTXAVTEModel(torch.nn.Module):
def __init__(self, dtype_llama=None, device="cpu", dtype=None, model_options={}):
def __init__(self, dtype_llama=None, device="cpu", dtype=None, text_projection_type="single_linear", model_options={}):
super().__init__()
self.dtypes = set()
self.dtypes.add(dtype)
self.compat_mode = False
self.text_projection_type = text_projection_type
self.gemma3_12b = Gemma3_12BModel(device=device, dtype=dtype_llama, model_options=model_options, layer="all", layer_idx=None)
self.dtypes.add(dtype_llama)
operations = self.gemma3_12b.operations # TODO
self.text_embedding_projection = operations.Linear(3840 * 49, 3840, bias=False, dtype=dtype, device=device)
if self.text_projection_type == "single_linear":
self.text_embedding_projection = operations.Linear(3840 * 49, 3840, bias=False, dtype=dtype, device=device)
elif self.text_projection_type == "dual_linear":
self.text_embedding_projection = DualLinearProjection(3840 * 49, 4096, 2048, dtype=dtype, device=device, operations=operations)
def enable_compat_mode(self): # TODO: remove
from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector
@@ -148,18 +169,25 @@ class LTXAVTEModel(torch.nn.Module):
out_device = out.device
if comfy.model_management.should_use_bf16(self.execution_device):
out = out.to(device=self.execution_device, dtype=torch.bfloat16)
out = out.movedim(1, -1).to(self.execution_device)
out = 8.0 * (out - out.mean(dim=(1, 2), keepdim=True)) / (out.amax(dim=(1, 2), keepdim=True) - out.amin(dim=(1, 2), keepdim=True) + 1e-6)
out = out.reshape((out.shape[0], out.shape[1], -1))
out = self.text_embedding_projection(out)
out = out.float()
if self.compat_mode:
out_vid = self.video_embeddings_connector(out)[0]
out_audio = self.audio_embeddings_connector(out)[0]
out = torch.concat((out_vid, out_audio), dim=-1)
if self.text_projection_type == "single_linear":
out = out.movedim(1, -1).to(self.execution_device)
out = 8.0 * (out - out.mean(dim=(1, 2), keepdim=True)) / (out.amax(dim=(1, 2), keepdim=True) - out.amin(dim=(1, 2), keepdim=True) + 1e-6)
out = out.reshape((out.shape[0], out.shape[1], -1))
out = self.text_embedding_projection(out)
return out.to(out_device), pooled
if self.compat_mode:
out_vid = self.video_embeddings_connector(out)[0]
out_audio = self.audio_embeddings_connector(out)[0]
out = torch.concat((out_vid, out_audio), dim=-1)
extra = {}
else:
extra = {"unprocessed_ltxav_embeds": True}
elif self.text_projection_type == "dual_linear":
out = self.text_embedding_projection(out)
extra = {"unprocessed_ltxav_embeds": True}
return out.to(device=out_device, dtype=torch.float), pooled, extra
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed):
return self.gemma3_12b.generate(tokens["gemma3_12b"], do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed)
@@ -168,7 +196,7 @@ class LTXAVTEModel(torch.nn.Module):
if "model.layers.47.self_attn.q_norm.weight" in sd:
return self.gemma3_12b.load_sd(sd)
else:
sdo = comfy.utils.state_dict_prefix_replace(sd, {"text_embedding_projection.aggregate_embed.weight": "text_embedding_projection.weight"}, filter_keys=True)
sdo = comfy.utils.state_dict_prefix_replace(sd, {"text_embedding_projection.aggregate_embed.weight": "text_embedding_projection.weight", "text_embedding_projection.": "text_embedding_projection."}, filter_keys=True)
if len(sdo) == 0:
sdo = sd
@@ -206,7 +234,7 @@ class LTXAVTEModel(torch.nn.Module):
num_tokens = max(num_tokens, 642)
return num_tokens * constant * 1024 * 1024
def ltxav_te(dtype_llama=None, llama_quantization_metadata=None):
def ltxav_te(dtype_llama=None, llama_quantization_metadata=None, text_projection_type="single_linear"):
class LTXAVTEModel_(LTXAVTEModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
if llama_quantization_metadata is not None:
@@ -214,9 +242,19 @@ def ltxav_te(dtype_llama=None, llama_quantization_metadata=None):
model_options["llama_quantization_metadata"] = llama_quantization_metadata
if dtype_llama is not None:
dtype = dtype_llama
super().__init__(dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options)
super().__init__(dtype_llama=dtype_llama, device=device, dtype=dtype, text_projection_type=text_projection_type, model_options=model_options)
return LTXAVTEModel_
def sd_detect(state_dict_list, prefix=""):
for sd in state_dict_list:
if "{}text_embedding_projection.audio_aggregate_embed.bias".format(prefix) in sd:
return {"text_projection_type": "dual_linear"}
if "{}text_embedding_projection.weight".format(prefix) in sd or "{}text_embedding_projection.aggregate_embed.weight".format(prefix) in sd:
return {"text_projection_type": "single_linear"}
return {}
def gemma3_te(dtype_llama=None, llama_quantization_metadata=None):
class Gemma3_12BModel_(Gemma3_12BModel):
def __init__(self, device="cpu", dtype=None, model_options={}):

View File

@@ -869,20 +869,31 @@ def safetensors_header(safetensors_path, max_size=100*1024*1024):
ATTR_UNSET={}
def set_attr(obj, attr, value):
def resolve_attr(obj, attr):
attrs = attr.split(".")
for name in attrs[:-1]:
obj = getattr(obj, name)
prev = getattr(obj, attrs[-1], ATTR_UNSET)
return obj, attrs[-1]
def set_attr(obj, attr, value):
obj, name = resolve_attr(obj, attr)
prev = getattr(obj, name, ATTR_UNSET)
if value is ATTR_UNSET:
delattr(obj, attrs[-1])
delattr(obj, name)
else:
setattr(obj, attrs[-1], value)
setattr(obj, name, value)
return prev
def set_attr_param(obj, attr, value):
return set_attr(obj, attr, torch.nn.Parameter(value, requires_grad=False))
def set_attr_buffer(obj, attr, value):
obj, name = resolve_attr(obj, attr)
prev = getattr(obj, name, ATTR_UNSET)
persistent = name not in getattr(obj, "_non_persistent_buffers_set", set())
obj.register_buffer(name, value, persistent=persistent)
return prev
def copy_to_param(obj, attr, value):
# inplace update tensor instead of replacing it
attrs = attr.split(".")

View File

@@ -401,6 +401,7 @@ class VideoFromComponents(VideoInput):
codec: VideoCodec = VideoCodec.AUTO,
metadata: Optional[dict] = None,
):
"""Save the video to a file path or BytesIO buffer."""
if format != VideoContainer.AUTO and format != VideoContainer.MP4:
raise ValueError("Only MP4 format is supported for now")
if codec != VideoCodec.AUTO and codec != VideoCodec.H264:
@@ -408,6 +409,10 @@ class VideoFromComponents(VideoInput):
extra_kwargs = {}
if isinstance(format, VideoContainer) and format != VideoContainer.AUTO:
extra_kwargs["format"] = format.value
elif isinstance(path, io.BytesIO):
# BytesIO has no file extension, so av.open can't infer the format.
# Default to mp4 since that's the only supported format anyway.
extra_kwargs["format"] = "mp4"
with av.open(path, mode='w', options={'movflags': 'use_metadata_tags'}, **extra_kwargs) as output:
# Add metadata before writing any streams
if metadata is not None:

View File

@@ -7,7 +7,8 @@ class ImageGenerationRequest(BaseModel):
aspect_ratio: str = Field(...)
n: int = Field(...)
seed: int = Field(...)
response_for: str = Field("url")
response_format: str = Field("url")
resolution: str = Field(...)
class InputUrlObject(BaseModel):
@@ -16,12 +17,13 @@ class InputUrlObject(BaseModel):
class ImageEditRequest(BaseModel):
model: str = Field(...)
image: InputUrlObject = Field(...)
images: list[InputUrlObject] = Field(...)
prompt: str = Field(...)
resolution: str = Field(...)
n: int = Field(...)
seed: int = Field(...)
response_for: str = Field("url")
response_format: str = Field("url")
aspect_ratio: str | None = Field(...)
class VideoGenerationRequest(BaseModel):
@@ -47,8 +49,13 @@ class ImageResponseObject(BaseModel):
revised_prompt: str | None = Field(None)
class UsageObject(BaseModel):
cost_in_usd_ticks: int | None = Field(None)
class ImageGenerationResponse(BaseModel):
data: list[ImageResponseObject] = Field(...)
usage: UsageObject | None = Field(None)
class VideoGenerationResponse(BaseModel):
@@ -65,3 +72,4 @@ class VideoStatusResponse(BaseModel):
status: str | None = Field(None)
video: VideoResponseObject | None = Field(None)
model: str | None = Field(None)
usage: UsageObject | None = Field(None)

View File

@@ -148,3 +148,4 @@ class MotionControlRequest(BaseModel):
keep_original_sound: str = Field(...)
character_orientation: str = Field(...)
mode: str = Field(..., description="'pro' or 'std'")
model_name: str = Field(...)

View File

@@ -27,6 +27,12 @@ from comfy_api_nodes.util import (
)
def _extract_grok_price(response) -> float | None:
if response.usage and response.usage.cost_in_usd_ticks is not None:
return response.usage.cost_in_usd_ticks / 10_000_000_000
return None
class GrokImageNode(IO.ComfyNode):
@classmethod
@@ -37,7 +43,10 @@ class GrokImageNode(IO.ComfyNode):
category="api node/image/Grok",
description="Generate images using Grok based on a text prompt",
inputs=[
IO.Combo.Input("model", options=["grok-imagine-image-beta"]),
IO.Combo.Input(
"model",
options=["grok-imagine-image-pro", "grok-imagine-image", "grok-imagine-image-beta"],
),
IO.String.Input(
"prompt",
multiline=True,
@@ -81,6 +90,7 @@ class GrokImageNode(IO.ComfyNode):
tooltip="Seed to determine if node should re-run; "
"actual results are nondeterministic regardless of seed.",
),
IO.Combo.Input("resolution", options=["1K", "2K"], optional=True),
],
outputs=[
IO.Image.Output(),
@@ -92,8 +102,13 @@ class GrokImageNode(IO.ComfyNode):
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["number_of_images"]),
expr="""{"type":"usd","usd":0.033 * widgets.number_of_images}""",
depends_on=IO.PriceBadgeDepends(widgets=["model", "number_of_images"]),
expr="""
(
$rate := $contains(widgets.model, "pro") ? 0.07 : 0.02;
{"type":"usd","usd": $rate * widgets.number_of_images}
)
""",
),
)
@@ -105,6 +120,7 @@ class GrokImageNode(IO.ComfyNode):
aspect_ratio: str,
number_of_images: int,
seed: int,
resolution: str = "1K",
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=True, min_length=1)
response = await sync_op(
@@ -116,8 +132,10 @@ class GrokImageNode(IO.ComfyNode):
aspect_ratio=aspect_ratio,
n=number_of_images,
seed=seed,
resolution=resolution.lower(),
),
response_model=ImageGenerationResponse,
price_extractor=_extract_grok_price,
)
if len(response.data) == 1:
return IO.NodeOutput(await download_url_to_image_tensor(response.data[0].url))
@@ -138,14 +156,17 @@ class GrokImageEditNode(IO.ComfyNode):
category="api node/image/Grok",
description="Modify an existing image based on a text prompt",
inputs=[
IO.Combo.Input("model", options=["grok-imagine-image-beta"]),
IO.Image.Input("image"),
IO.Combo.Input(
"model",
options=["grok-imagine-image-pro", "grok-imagine-image", "grok-imagine-image-beta"],
),
IO.Image.Input("image", display_name="images"),
IO.String.Input(
"prompt",
multiline=True,
tooltip="The text prompt used to generate the image",
),
IO.Combo.Input("resolution", options=["1K"]),
IO.Combo.Input("resolution", options=["1K", "2K"]),
IO.Int.Input(
"number_of_images",
default=1,
@@ -166,6 +187,27 @@ class GrokImageEditNode(IO.ComfyNode):
tooltip="Seed to determine if node should re-run; "
"actual results are nondeterministic regardless of seed.",
),
IO.Combo.Input(
"aspect_ratio",
options=[
"auto",
"1:1",
"2:3",
"3:2",
"3:4",
"4:3",
"9:16",
"16:9",
"9:19.5",
"19.5:9",
"9:20",
"20:9",
"1:2",
"2:1",
],
optional=True,
tooltip="Only allowed when multiple images are connected to the image input.",
),
],
outputs=[
IO.Image.Output(),
@@ -177,8 +219,13 @@ class GrokImageEditNode(IO.ComfyNode):
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["number_of_images"]),
expr="""{"type":"usd","usd":0.002 + 0.033 * widgets.number_of_images}""",
depends_on=IO.PriceBadgeDepends(widgets=["model", "number_of_images"]),
expr="""
(
$rate := $contains(widgets.model, "pro") ? 0.07 : 0.02;
{"type":"usd","usd": 0.002 + $rate * widgets.number_of_images}
)
""",
),
)
@@ -191,22 +238,32 @@ class GrokImageEditNode(IO.ComfyNode):
resolution: str,
number_of_images: int,
seed: int,
aspect_ratio: str = "auto",
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=True, min_length=1)
if get_number_of_images(image) != 1:
raise ValueError("Only one input image is supported.")
if model == "grok-imagine-image-pro":
if get_number_of_images(image) > 1:
raise ValueError("The pro model supports only 1 input image.")
elif get_number_of_images(image) > 3:
raise ValueError("A maximum of 3 input images is supported.")
if aspect_ratio != "auto" and get_number_of_images(image) == 1:
raise ValueError(
"Custom aspect ratio is only allowed when multiple images are connected to the image input."
)
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/xai/v1/images/edits", method="POST"),
data=ImageEditRequest(
model=model,
image=InputUrlObject(url=f"data:image/png;base64,{tensor_to_base64_string(image)}"),
images=[InputUrlObject(url=f"data:image/png;base64,{tensor_to_base64_string(i)}") for i in image],
prompt=prompt,
resolution=resolution.lower(),
n=number_of_images,
seed=seed,
aspect_ratio=None if aspect_ratio == "auto" else aspect_ratio,
),
response_model=ImageGenerationResponse,
price_extractor=_extract_grok_price,
)
if len(response.data) == 1:
return IO.NodeOutput(await download_url_to_image_tensor(response.data[0].url))
@@ -227,7 +284,7 @@ class GrokVideoNode(IO.ComfyNode):
category="api node/video/Grok",
description="Generate video from a prompt or an image",
inputs=[
IO.Combo.Input("model", options=["grok-imagine-video-beta"]),
IO.Combo.Input("model", options=["grok-imagine-video", "grok-imagine-video-beta"]),
IO.String.Input(
"prompt",
multiline=True,
@@ -275,10 +332,11 @@ class GrokVideoNode(IO.ComfyNode):
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["duration"], inputs=["image"]),
depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution"], inputs=["image"]),
expr="""
(
$base := 0.181 * widgets.duration;
$rate := widgets.resolution = "720p" ? 0.07 : 0.05;
$base := $rate * widgets.duration;
{"type":"usd","usd": inputs.image.connected ? $base + 0.002 : $base}
)
""",
@@ -321,6 +379,7 @@ class GrokVideoNode(IO.ComfyNode):
ApiEndpoint(path=f"/proxy/xai/v1/videos/{initial_response.request_id}"),
status_extractor=lambda r: r.status if r.status is not None else "complete",
response_model=VideoStatusResponse,
price_extractor=_extract_grok_price,
)
return IO.NodeOutput(await download_url_to_video_output(response.video.url))
@@ -335,7 +394,7 @@ class GrokVideoEditNode(IO.ComfyNode):
category="api node/video/Grok",
description="Edit an existing video based on a text prompt.",
inputs=[
IO.Combo.Input("model", options=["grok-imagine-video-beta"]),
IO.Combo.Input("model", options=["grok-imagine-video", "grok-imagine-video-beta"]),
IO.String.Input(
"prompt",
multiline=True,
@@ -364,7 +423,7 @@ class GrokVideoEditNode(IO.ComfyNode):
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd": 0.191, "format": {"suffix": "/sec", "approximate": true}}""",
expr="""{"type":"usd","usd": 0.06, "format": {"suffix": "/sec", "approximate": true}}""",
),
)
@@ -398,6 +457,7 @@ class GrokVideoEditNode(IO.ComfyNode):
ApiEndpoint(path=f"/proxy/xai/v1/videos/{initial_response.request_id}"),
status_extractor=lambda r: r.status if r.status is not None else "complete",
response_model=VideoStatusResponse,
price_extractor=_extract_grok_price,
)
return IO.NodeOutput(await download_url_to_video_output(response.video.url))

View File

@@ -2747,6 +2747,7 @@ class MotionControl(IO.ComfyNode):
"but the character orientation matches the reference image (camera/other details via prompt).",
),
IO.Combo.Input("mode", options=["pro", "std"]),
IO.Combo.Input("model", options=["kling-v3", "kling-v2-6"], optional=True),
],
outputs=[
IO.Video.Output(),
@@ -2777,6 +2778,7 @@ class MotionControl(IO.ComfyNode):
keep_original_sound: bool,
character_orientation: str,
mode: str,
model: str = "kling-v2-6",
) -> IO.NodeOutput:
validate_string(prompt, max_length=2500)
validate_image_dimensions(reference_image, min_width=340, min_height=340)
@@ -2797,6 +2799,7 @@ class MotionControl(IO.ComfyNode):
keep_original_sound="yes" if keep_original_sound else "no",
character_orientation=character_orientation,
mode=mode,
model_name=model,
),
)
if response.code:

View File

@@ -1049,48 +1049,6 @@ class ManualSigmas(io.ComfyNode):
sigmas = torch.FloatTensor(sigmas)
return io.NodeOutput(sigmas)
class CurveToSigmas(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="CurveToSigmas",
display_name="Curve to Sigmas",
category="sampling/custom_sampling/schedulers",
inputs=[
io.Curve.Input("curve", default=[[0.0, 1.0], [1.0, 0.0]]),
io.Model.Input("model", optional=True),
io.Int.Input("steps", default=20, min=1, max=10000),
io.Float.Input("sigma_max", default=14.614642, min=0.0, max=5000.0, step=0.01, round=False, optional=True),
],
outputs=[io.Sigmas.Output()]
)
@classmethod
def execute(cls, curve, steps, sigma_max=14.614642, model=None) -> io.NodeOutput:
points = sorted(curve, key=lambda p: p[0])
model_sampling = model.get_model_object("model_sampling") if model is not None else None
sigmas = []
for i in range(steps + 1):
t = i / steps
y = points[0][1] if t < points[0][0] else points[-1][1]
for j in range(len(points) - 1):
if points[j][0] <= t <= points[j + 1][0]:
x0, y0 = points[j]
x1, y1 = points[j + 1]
y = y0 if x1 == x0 else y0 + (y1 - y0) * (t - x0) / (x1 - x0)
break
if model_sampling is not None:
sigmas.append(float(model_sampling.percent_to_sigma(1.0 - y)))
else:
sigmas.append(y * sigma_max)
sigmas[-1] = 0.0
return io.NodeOutput(torch.FloatTensor(sigmas))
get_sigmas = execute
class CustomSamplersExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
@@ -1130,7 +1088,6 @@ class CustomSamplersExtension(ComfyExtension):
AddNoise,
SamplerCustomAdvanced,
ManualSigmas,
CurveToSigmas,
]

View File

@@ -253,10 +253,12 @@ class LTXVAddGuide(io.ComfyNode):
return frame_idx, latent_idx
@classmethod
def add_keyframe_index(cls, cond, frame_idx, guiding_latent, scale_factors, latent_downscale_factor=1):
def add_keyframe_index(cls, cond, frame_idx, guiding_latent, scale_factors, latent_downscale_factor=1, causal_fix=None):
keyframe_idxs, _ = get_keyframe_idxs(cond)
_, latent_coords = cls.PATCHIFIER.patchify(guiding_latent)
pixel_coords = latent_to_pixel_coords(latent_coords, scale_factors, causal_fix=frame_idx == 0) # we need the causal fix only if we're placing the new latents at index 0
if causal_fix is None:
causal_fix = frame_idx == 0 or guiding_latent.shape[2] == 1
pixel_coords = latent_to_pixel_coords(latent_coords, scale_factors, causal_fix=causal_fix)
pixel_coords[:, 0] += frame_idx
# The following adjusts keyframe end positions for small grid IC-LoRA.
@@ -278,12 +280,12 @@ class LTXVAddGuide(io.ComfyNode):
return node_helpers.conditioning_set_values(cond, {"keyframe_idxs": keyframe_idxs})
@classmethod
def append_keyframe(cls, positive, negative, frame_idx, latent_image, noise_mask, guiding_latent, strength, scale_factors, guide_mask=None, in_channels=128, latent_downscale_factor=1):
def append_keyframe(cls, positive, negative, frame_idx, latent_image, noise_mask, guiding_latent, strength, scale_factors, guide_mask=None, in_channels=128, latent_downscale_factor=1, causal_fix=None):
if latent_image.shape[1] != in_channels or guiding_latent.shape[1] != in_channels:
raise ValueError("Adding guide to a combined AV latent is not supported.")
positive = cls.add_keyframe_index(positive, frame_idx, guiding_latent, scale_factors, latent_downscale_factor)
negative = cls.add_keyframe_index(negative, frame_idx, guiding_latent, scale_factors, latent_downscale_factor)
positive = cls.add_keyframe_index(positive, frame_idx, guiding_latent, scale_factors, latent_downscale_factor, causal_fix=causal_fix)
negative = cls.add_keyframe_index(negative, frame_idx, guiding_latent, scale_factors, latent_downscale_factor, causal_fix=causal_fix)
if guide_mask is not None:
target_h = max(noise_mask.shape[3], guide_mask.shape[3])

View File

@@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
__version__ = "0.15.1"
__version__ = "0.16.3"

View File

@@ -2034,24 +2034,6 @@ class ImagePadForOutpaint:
return (new_image, mask.unsqueeze(0))
class CurveEditor:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"curve": ("CURVE", {"default": [[0, 0], [1, 1]]}),
}
}
RETURN_TYPES = ("CURVE",)
RETURN_NAMES = ("curve",)
FUNCTION = "execute"
CATEGORY = "utils"
def execute(self, curve):
return (curve,)
NODE_CLASS_MAPPINGS = {
"KSampler": KSampler,
"CheckpointLoaderSimple": CheckpointLoaderSimple,
@@ -2120,7 +2102,6 @@ NODE_CLASS_MAPPINGS = {
"ConditioningZeroOut": ConditioningZeroOut,
"ConditioningSetTimestepRange": ConditioningSetTimestepRange,
"LoraLoaderModelOnly": LoraLoaderModelOnly,
"CurveEditor": CurveEditor,
}
NODE_DISPLAY_NAME_MAPPINGS = {
@@ -2189,7 +2170,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
# _for_testing
"VAEDecodeTiled": "VAE Decode (Tiled)",
"VAEEncodeTiled": "VAE Encode (Tiled)",
"CurveEditor": "Curve Editor",
}
EXTENSION_WEB_DIRS = {}

View File

@@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
version = "0.15.1"
version = "0.16.3"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.10"

View File

@@ -1,5 +1,5 @@
comfyui-frontend-package==1.39.19
comfyui-workflow-templates==0.9.5
comfyui-workflow-templates==0.9.10
comfyui-embedded-docs==0.4.3
torch
torchsde
@@ -22,7 +22,7 @@ alembic
SQLAlchemy
av>=14.2.0
comfy-kitchen>=0.2.7
comfy-aimdo>=0.2.4
comfy-aimdo>=0.2.7
requests
#non essential dependencies: