mirror of
https://github.com/Comfy-Org/ComfyUI.git
synced 2026-03-10 02:19:31 +00:00
Compare commits
12 Commits
node-essen
...
range-edit
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4d721bff59 | ||
|
|
de67bb0870 | ||
|
|
9d70c2626b | ||
|
|
cb64b957f2 | ||
|
|
07ca6852e8 | ||
|
|
f266b8d352 | ||
|
|
b6cb30bab5 | ||
|
|
ee72752162 | ||
|
|
7591d781a7 | ||
|
|
0bfb936ab4 | ||
|
|
602b2505a4 | ||
|
|
04a55d5019 |
@@ -1,6 +1,7 @@
|
||||
# yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json
|
||||
language: "en-US"
|
||||
early_access: false
|
||||
tone_instructions: "Only comment on issues introduced by this PR's changes. Do not flag pre-existing problems in moved, re-indented, or reformatted code."
|
||||
|
||||
reviews:
|
||||
profile: "chill"
|
||||
@@ -35,6 +36,14 @@ reviews:
|
||||
- "!**/*.bat"
|
||||
|
||||
path_instructions:
|
||||
- path: "**"
|
||||
instructions: |
|
||||
IMPORTANT: Only comment on issues directly introduced by this PR's code changes.
|
||||
Do NOT flag pre-existing issues in code that was merely moved, re-indented,
|
||||
de-indented, or reformatted without logic changes. If code appears in the diff
|
||||
only due to whitespace or structural reformatting (e.g., removing a `with:` block),
|
||||
treat it as unchanged. Contributors should not feel obligated to address
|
||||
pre-existing issues outside the scope of their contribution.
|
||||
- path: "comfy/**"
|
||||
instructions: |
|
||||
Core ML/diffusion engine. Focus on:
|
||||
@@ -74,7 +83,11 @@ reviews:
|
||||
auto_review:
|
||||
enabled: true
|
||||
auto_incremental_review: true
|
||||
drafts: true
|
||||
drafts: false
|
||||
ignore_title_keywords:
|
||||
- "WIP"
|
||||
- "DO NOT REVIEW"
|
||||
- "DO NOT MERGE"
|
||||
|
||||
finishing_touches:
|
||||
docstrings:
|
||||
@@ -84,7 +97,7 @@ reviews:
|
||||
|
||||
tools:
|
||||
ruff:
|
||||
enabled: true
|
||||
enabled: false
|
||||
pylint:
|
||||
enabled: false
|
||||
flake8:
|
||||
|
||||
@@ -53,7 +53,7 @@ class SubgraphManager:
|
||||
return entry_id, entry
|
||||
|
||||
async def load_entry_data(self, entry: SubgraphEntry):
|
||||
with open(entry['path'], 'r') as f:
|
||||
with open(entry['path'], 'r', encoding='utf-8') as f:
|
||||
entry['data'] = f.read()
|
||||
return entry
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ from comfy.ldm.lightricks.model import (
|
||||
LTXVModel,
|
||||
)
|
||||
from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
|
||||
from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector
|
||||
import comfy.ldm.common_dit
|
||||
|
||||
class CompressedTimestep:
|
||||
@@ -450,6 +451,29 @@ class LTXAVModel(LTXVModel):
|
||||
operations=self.operations,
|
||||
)
|
||||
|
||||
self.audio_embeddings_connector = Embeddings1DConnector(
|
||||
split_rope=True,
|
||||
double_precision_rope=True,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=self.operations,
|
||||
)
|
||||
|
||||
self.video_embeddings_connector = Embeddings1DConnector(
|
||||
split_rope=True,
|
||||
double_precision_rope=True,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=self.operations,
|
||||
)
|
||||
|
||||
def preprocess_text_embeds(self, context):
|
||||
if context.shape[-1] == self.caption_channels * 2:
|
||||
return context
|
||||
out_vid = self.video_embeddings_connector(context)[0]
|
||||
out_audio = self.audio_embeddings_connector(context)[0]
|
||||
return torch.concat((out_vid, out_audio), dim=-1)
|
||||
|
||||
def _init_transformer_blocks(self, device, dtype, **kwargs):
|
||||
"""Initialize transformer blocks for LTXAV."""
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
|
||||
@@ -234,7 +234,7 @@ class Embeddings1DConnector(nn.Module):
|
||||
|
||||
return indices
|
||||
|
||||
def precompute_freqs_cis(self, indices_grid, spacing="exp"):
|
||||
def precompute_freqs_cis(self, indices_grid, spacing="exp", out_dtype=None):
|
||||
dim = self.inner_dim
|
||||
n_elem = 2 # 2 because of cos and sin
|
||||
freqs = self.precompute_freqs(indices_grid, spacing)
|
||||
@@ -247,7 +247,7 @@ class Embeddings1DConnector(nn.Module):
|
||||
)
|
||||
else:
|
||||
cos_freq, sin_freq = interleaved_freqs_cis(freqs, dim % n_elem)
|
||||
return cos_freq.to(self.dtype), sin_freq.to(self.dtype), self.split_rope
|
||||
return cos_freq.to(dtype=out_dtype), sin_freq.to(dtype=out_dtype), self.split_rope
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -288,7 +288,7 @@ class Embeddings1DConnector(nn.Module):
|
||||
hidden_states.shape[1], dtype=torch.float32, device=hidden_states.device
|
||||
)
|
||||
indices_grid = indices_grid[None, None, :]
|
||||
freqs_cis = self.precompute_freqs_cis(indices_grid)
|
||||
freqs_cis = self.precompute_freqs_cis(indices_grid, out_dtype=hidden_states.dtype)
|
||||
|
||||
# 2. Blocks
|
||||
for block_idx, block in enumerate(self.transformer_1d_blocks):
|
||||
|
||||
@@ -78,4 +78,4 @@ def interpret_gathered_like(tensors, gathered):
|
||||
|
||||
return dest_views
|
||||
|
||||
aimdo_allocator = None
|
||||
aimdo_enabled = False
|
||||
|
||||
@@ -988,10 +988,14 @@ class LTXAV(BaseModel):
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
attention_mask = kwargs.get("attention_mask", None)
|
||||
device = kwargs["device"]
|
||||
|
||||
if attention_mask is not None:
|
||||
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
if cross_attn is not None:
|
||||
if hasattr(self.diffusion_model, "preprocess_text_embeds"):
|
||||
cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype_inference()))
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
|
||||
out['frame_rate'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", 25))
|
||||
|
||||
@@ -836,7 +836,7 @@ def unet_inital_load_device(parameters, dtype):
|
||||
|
||||
mem_dev = get_free_memory(torch_dev)
|
||||
mem_cpu = get_free_memory(cpu_dev)
|
||||
if mem_dev > mem_cpu and model_size < mem_dev and comfy.memory_management.aimdo_allocator is None:
|
||||
if mem_dev > mem_cpu and model_size < mem_dev and comfy.memory_management.aimdo_enabled:
|
||||
return torch_dev
|
||||
else:
|
||||
return cpu_dev
|
||||
@@ -1121,7 +1121,6 @@ def get_cast_buffer(offload_stream, device, size, ref):
|
||||
synchronize()
|
||||
del STREAM_CAST_BUFFERS[offload_stream]
|
||||
del cast_buffer
|
||||
#FIXME: This doesn't work in Aimdo because mempool cant clear cache
|
||||
soft_empty_cache()
|
||||
with wf_context:
|
||||
cast_buffer = torch.empty((size), dtype=torch.int8, device=device)
|
||||
|
||||
@@ -3,7 +3,6 @@ import os
|
||||
from transformers import T5TokenizerFast
|
||||
from .spiece_tokenizer import SPieceTokenizer
|
||||
import comfy.text_encoders.genmo
|
||||
from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector
|
||||
import torch
|
||||
import comfy.utils
|
||||
import math
|
||||
@@ -109,22 +108,6 @@ class LTXAVTEModel(torch.nn.Module):
|
||||
operations = self.gemma3_12b.operations # TODO
|
||||
self.text_embedding_projection = operations.Linear(3840 * 49, 3840, bias=False, dtype=dtype, device=device)
|
||||
|
||||
self.audio_embeddings_connector = Embeddings1DConnector(
|
||||
split_rope=True,
|
||||
double_precision_rope=True,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
self.video_embeddings_connector = Embeddings1DConnector(
|
||||
split_rope=True,
|
||||
double_precision_rope=True,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
def set_clip_options(self, options):
|
||||
self.execution_device = options.get("execution_device", self.execution_device)
|
||||
self.gemma3_12b.set_clip_options(options)
|
||||
@@ -146,10 +129,6 @@ class LTXAVTEModel(torch.nn.Module):
|
||||
out = out.reshape((out.shape[0], out.shape[1], -1))
|
||||
out = self.text_embedding_projection(out)
|
||||
out = out.float()
|
||||
out_vid = self.video_embeddings_connector(out)[0]
|
||||
out_audio = self.audio_embeddings_connector(out)[0]
|
||||
out = torch.concat((out_vid, out_audio), dim=-1)
|
||||
|
||||
return out.to(out_device), pooled
|
||||
|
||||
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed):
|
||||
@@ -159,14 +138,14 @@ class LTXAVTEModel(torch.nn.Module):
|
||||
if "model.layers.47.self_attn.q_norm.weight" in sd:
|
||||
return self.gemma3_12b.load_sd(sd)
|
||||
else:
|
||||
sdo = comfy.utils.state_dict_prefix_replace(sd, {"text_embedding_projection.aggregate_embed.weight": "text_embedding_projection.weight", "model.diffusion_model.video_embeddings_connector.": "video_embeddings_connector.", "model.diffusion_model.audio_embeddings_connector.": "audio_embeddings_connector."}, filter_keys=True)
|
||||
sdo = comfy.utils.state_dict_prefix_replace(sd, {"text_embedding_projection.aggregate_embed.weight": "text_embedding_projection.weight"}, filter_keys=True)
|
||||
if len(sdo) == 0:
|
||||
sdo = sd
|
||||
|
||||
missing_all = []
|
||||
unexpected_all = []
|
||||
|
||||
for prefix, component in [("text_embedding_projection.", self.text_embedding_projection), ("video_embeddings_connector.", self.video_embeddings_connector), ("audio_embeddings_connector.", self.audio_embeddings_connector)]:
|
||||
for prefix, component in [("text_embedding_projection.", self.text_embedding_projection)]:
|
||||
component_sd = {k.replace(prefix, ""): v for k, v in sdo.items() if k.startswith(prefix)}
|
||||
if component_sd:
|
||||
missing, unexpected = component.load_state_dict(component_sd, strict=False, assign=getattr(self, "can_assign_sd", False))
|
||||
|
||||
@@ -1154,7 +1154,7 @@ def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_am
|
||||
return tiled_scale_multidim(samples, function, (tile_y, tile_x), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=output_device, pbar=pbar)
|
||||
|
||||
def model_trange(*args, **kwargs):
|
||||
if comfy.memory_management.aimdo_allocator is None:
|
||||
if not comfy.memory_management.aimdo_enabled:
|
||||
return trange(*args, **kwargs)
|
||||
|
||||
pbar = trange(*args, **kwargs, smoothing=1.0)
|
||||
|
||||
@@ -1237,6 +1237,82 @@ class BoundingBox(ComfyTypeIO):
|
||||
return d
|
||||
|
||||
|
||||
@comfytype(io_type="CURVE")
|
||||
class Curve(ComfyTypeIO):
|
||||
Type = list
|
||||
|
||||
class Input(WidgetInput):
|
||||
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None,
|
||||
socketless: bool=True, default: list=None, advanced: bool=None):
|
||||
super().__init__(id, display_name, optional, tooltip, None, default, socketless, None, None, None, None, advanced)
|
||||
if default is None:
|
||||
self.default = [[0, 0], [1, 1]]
|
||||
|
||||
def as_dict(self):
|
||||
return super().as_dict()
|
||||
|
||||
|
||||
@comfytype(io_type="RANGE")
|
||||
class Range(ComfyTypeIO):
|
||||
Type = dict # {"min": float, "max": float, "midpoint"?: float}
|
||||
|
||||
class Input(WidgetInput):
|
||||
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None,
|
||||
socketless: bool=True, default: dict=None,
|
||||
display: str=None,
|
||||
gradient_stops: list=None,
|
||||
show_midpoint: bool=None,
|
||||
midpoint_scale: str=None,
|
||||
value_min: float=None,
|
||||
value_max: float=None,
|
||||
advanced: bool=None):
|
||||
super().__init__(id, display_name, optional, tooltip, None, default, socketless, None, None, None, None, advanced)
|
||||
if default is None:
|
||||
self.default = {"min": 0.0, "max": 1.0}
|
||||
self.display = display
|
||||
self.gradient_stops = gradient_stops
|
||||
self.show_midpoint = show_midpoint
|
||||
self.midpoint_scale = midpoint_scale
|
||||
self.value_min = value_min
|
||||
self.value_max = value_max
|
||||
|
||||
def as_dict(self):
|
||||
return super().as_dict() | prune_dict({
|
||||
"display": self.display,
|
||||
"gradient_stops": self.gradient_stops,
|
||||
"show_midpoint": self.show_midpoint,
|
||||
"midpoint_scale": self.midpoint_scale,
|
||||
"value_min": self.value_min,
|
||||
"value_max": self.value_max,
|
||||
})
|
||||
|
||||
|
||||
@comfytype(io_type="COLOR_CURVES")
|
||||
class ColorCurves(ComfyTypeIO):
|
||||
class ColorCurvesDict(TypedDict):
|
||||
rgb: list[list[float]]
|
||||
red: list[list[float]]
|
||||
green: list[list[float]]
|
||||
blue: list[list[float]]
|
||||
|
||||
Type = ColorCurvesDict
|
||||
|
||||
class Input(WidgetInput):
|
||||
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None,
|
||||
socketless: bool=True, default: dict=None, advanced: bool=None):
|
||||
super().__init__(id, display_name, optional, tooltip, None, default, socketless, None, None, None, None, advanced)
|
||||
if default is None:
|
||||
self.default = {
|
||||
"rgb": [[0, 0], [1, 1]],
|
||||
"red": [[0, 0], [1, 1]],
|
||||
"green": [[0, 0], [1, 1]],
|
||||
"blue": [[0, 0], [1, 1]]
|
||||
}
|
||||
|
||||
def as_dict(self):
|
||||
return super().as_dict()
|
||||
|
||||
|
||||
DYNAMIC_INPUT_LOOKUP: dict[str, Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]] = {}
|
||||
def register_dynamic_input_func(io_type: str, func: Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]):
|
||||
DYNAMIC_INPUT_LOOKUP[io_type] = func
|
||||
@@ -2223,5 +2299,7 @@ __all__ = [
|
||||
"PriceBadgeDepends",
|
||||
"PriceBadge",
|
||||
"BoundingBox",
|
||||
"Curve",
|
||||
"ColorCurves",
|
||||
"NodeReplace",
|
||||
]
|
||||
|
||||
137
comfy_extras/nodes_color_curves.py
Normal file
137
comfy_extras/nodes_color_curves.py
Normal file
@@ -0,0 +1,137 @@
|
||||
from typing_extensions import override
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
from comfy_api.latest import ComfyExtension, io, ui
|
||||
|
||||
|
||||
def _monotone_cubic_hermite(xs, ys, x_query):
|
||||
"""Evaluate monotone cubic Hermite interpolation at x_query points."""
|
||||
n = len(xs)
|
||||
if n == 0:
|
||||
return np.zeros_like(x_query)
|
||||
if n == 1:
|
||||
return np.full_like(x_query, ys[0])
|
||||
|
||||
# Compute slopes
|
||||
deltas = np.diff(ys) / np.maximum(np.diff(xs), 1e-10)
|
||||
|
||||
# Compute tangents (Fritsch-Carlson)
|
||||
slopes = np.zeros(n)
|
||||
slopes[0] = deltas[0]
|
||||
slopes[-1] = deltas[-1]
|
||||
for i in range(1, n - 1):
|
||||
if deltas[i - 1] * deltas[i] <= 0:
|
||||
slopes[i] = 0
|
||||
else:
|
||||
slopes[i] = (deltas[i - 1] + deltas[i]) / 2
|
||||
|
||||
# Enforce monotonicity
|
||||
for i in range(n - 1):
|
||||
if deltas[i] == 0:
|
||||
slopes[i] = 0
|
||||
slopes[i + 1] = 0
|
||||
else:
|
||||
alpha = slopes[i] / deltas[i]
|
||||
beta = slopes[i + 1] / deltas[i]
|
||||
s = alpha ** 2 + beta ** 2
|
||||
if s > 9:
|
||||
t = 3 / np.sqrt(s)
|
||||
slopes[i] = t * alpha * deltas[i]
|
||||
slopes[i + 1] = t * beta * deltas[i]
|
||||
|
||||
# Evaluate
|
||||
result = np.zeros_like(x_query, dtype=np.float64)
|
||||
indices = np.searchsorted(xs, x_query, side='right') - 1
|
||||
indices = np.clip(indices, 0, n - 2)
|
||||
|
||||
for i in range(n - 1):
|
||||
mask = indices == i
|
||||
if not np.any(mask):
|
||||
continue
|
||||
dx = xs[i + 1] - xs[i]
|
||||
if dx == 0:
|
||||
result[mask] = ys[i]
|
||||
continue
|
||||
t = (x_query[mask] - xs[i]) / dx
|
||||
t2 = t * t
|
||||
t3 = t2 * t
|
||||
h00 = 2 * t3 - 3 * t2 + 1
|
||||
h10 = t3 - 2 * t2 + t
|
||||
h01 = -2 * t3 + 3 * t2
|
||||
h11 = t3 - t2
|
||||
result[mask] = h00 * ys[i] + h10 * dx * slopes[i] + h01 * ys[i + 1] + h11 * dx * slopes[i + 1]
|
||||
|
||||
# Clamp edges
|
||||
result[x_query <= xs[0]] = ys[0]
|
||||
result[x_query >= xs[-1]] = ys[-1]
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def _build_lut(points):
|
||||
"""Build a 256-entry LUT from curve control points in [0,1] space."""
|
||||
if not points or len(points) < 2:
|
||||
return np.arange(256, dtype=np.float64) / 255.0
|
||||
|
||||
pts = sorted(points, key=lambda p: p[0])
|
||||
xs = np.array([p[0] for p in pts], dtype=np.float64)
|
||||
ys = np.array([p[1] for p in pts], dtype=np.float64)
|
||||
|
||||
x_query = np.linspace(0, 1, 256)
|
||||
lut = _monotone_cubic_hermite(xs, ys, x_query)
|
||||
return np.clip(lut, 0, 1)
|
||||
|
||||
|
||||
class ColorCurvesNode(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="ColorCurves",
|
||||
display_name="Color Curves",
|
||||
category="image/adjustment",
|
||||
inputs=[
|
||||
io.Image.Input("image"),
|
||||
io.ColorCurves.Input("settings"),
|
||||
],
|
||||
outputs=[
|
||||
io.Image.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, image: torch.Tensor, settings: dict) -> io.NodeOutput:
|
||||
rgb_pts = settings.get("rgb", [[0, 0], [1, 1]])
|
||||
red_pts = settings.get("red", [[0, 0], [1, 1]])
|
||||
green_pts = settings.get("green", [[0, 0], [1, 1]])
|
||||
blue_pts = settings.get("blue", [[0, 0], [1, 1]])
|
||||
|
||||
rgb_lut = _build_lut(rgb_pts)
|
||||
red_lut = _build_lut(red_pts)
|
||||
green_lut = _build_lut(green_pts)
|
||||
blue_lut = _build_lut(blue_pts)
|
||||
|
||||
# Convert to numpy for LUT application
|
||||
img_np = image.cpu().numpy().copy()
|
||||
|
||||
# Apply per-channel curves then RGB master curve.
|
||||
# Index with floor(val * 256) clamped to [0, 255] to match GPU NEAREST
|
||||
# texture sampling on a 256-wide LUT texture.
|
||||
for ch, ch_lut in enumerate([red_lut, green_lut, blue_lut]):
|
||||
indices = np.clip((img_np[..., ch] * 256).astype(np.int32), 0, 255)
|
||||
img_np[..., ch] = ch_lut[indices]
|
||||
indices = np.clip((img_np[..., ch] * 256).astype(np.int32), 0, 255)
|
||||
img_np[..., ch] = rgb_lut[indices]
|
||||
|
||||
result = torch.from_numpy(np.clip(img_np, 0, 1)).to(image.device, dtype=image.dtype)
|
||||
return io.NodeOutput(result, ui=ui.PreviewImage(result))
|
||||
|
||||
|
||||
class ColorCurvesExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [ColorCurvesNode]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> ColorCurvesExtension:
|
||||
return ColorCurvesExtension()
|
||||
@@ -716,12 +716,12 @@ def _render_shader_batch(
|
||||
gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, 0)
|
||||
gl.glUseProgram(0)
|
||||
|
||||
if input_textures:
|
||||
gl.glDeleteTextures(len(input_textures), input_textures)
|
||||
if output_textures:
|
||||
gl.glDeleteTextures(len(output_textures), output_textures)
|
||||
if ping_pong_textures:
|
||||
gl.glDeleteTextures(len(ping_pong_textures), ping_pong_textures)
|
||||
for tex in input_textures:
|
||||
gl.glDeleteTextures(tex)
|
||||
for tex in output_textures:
|
||||
gl.glDeleteTextures(tex)
|
||||
for tex in ping_pong_textures:
|
||||
gl.glDeleteTextures(tex)
|
||||
if fbo is not None:
|
||||
gl.glDeleteFramebuffers(1, [fbo])
|
||||
for pp_fbo in ping_pong_fbos:
|
||||
|
||||
@@ -10,7 +10,7 @@ class NAGuidance(io.ComfyNode):
|
||||
node_id="NAGuidance",
|
||||
display_name="Normalized Attention Guidance",
|
||||
description="Applies Normalized Attention Guidance to models, enabling negative prompts on distilled/schnell models.",
|
||||
category="",
|
||||
category="advanced/guidance",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.Model.Input("model", tooltip="The model to apply NAG to."),
|
||||
|
||||
@@ -29,6 +29,7 @@ class StringMultiline(io.ComfyNode):
|
||||
node_id="PrimitiveStringMultiline",
|
||||
display_name="String (Multiline)",
|
||||
category="utils/primitive",
|
||||
essentials_category="Basics",
|
||||
inputs=[
|
||||
io.String.Input("value", multiline=True),
|
||||
],
|
||||
|
||||
@@ -1,10 +1,8 @@
|
||||
import os
|
||||
import importlib.util
|
||||
from comfy.cli_args import args, PerformanceFeature, enables_dynamic_vram
|
||||
from comfy.cli_args import args, PerformanceFeature
|
||||
import subprocess
|
||||
|
||||
import comfy_aimdo.control
|
||||
|
||||
#Can't use pytorch to get the GPU names because the cuda malloc has to be set before the first import.
|
||||
def get_gpu_names():
|
||||
if os.name == 'nt':
|
||||
@@ -87,10 +85,6 @@ if not args.cuda_malloc:
|
||||
except:
|
||||
pass
|
||||
|
||||
if enables_dynamic_vram() and comfy_aimdo.control.init():
|
||||
args.cuda_malloc = False
|
||||
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = ""
|
||||
|
||||
if args.disable_cuda_malloc:
|
||||
args.cuda_malloc = False
|
||||
|
||||
|
||||
22
execution.py
22
execution.py
@@ -9,7 +9,6 @@ import traceback
|
||||
from enum import Enum
|
||||
from typing import List, Literal, NamedTuple, Optional, Union
|
||||
import asyncio
|
||||
from contextlib import nullcontext
|
||||
|
||||
import torch
|
||||
|
||||
@@ -521,19 +520,14 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
||||
# TODO - How to handle this with async functions without contextvars (which requires Python 3.12)?
|
||||
GraphBuilder.set_default_prefix(unique_id, call_index, 0)
|
||||
|
||||
#Do comfy_aimdo mempool chunking here on the per-node level. Multi-model workflows
|
||||
#will cause all sorts of incompatible memory shapes to fragment the pytorch alloc
|
||||
#that we just want to cull out each model run.
|
||||
allocator = comfy.memory_management.aimdo_allocator
|
||||
with nullcontext() if allocator is None else torch.cuda.use_mem_pool(torch.cuda.MemPool(allocator.allocator())):
|
||||
try:
|
||||
output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data)
|
||||
finally:
|
||||
if allocator is not None:
|
||||
if args.verbose == "DEBUG":
|
||||
comfy_aimdo.model_vbar.vbars_analyze()
|
||||
comfy.model_management.reset_cast_buffers()
|
||||
comfy_aimdo.model_vbar.vbars_reset_watermark_limits()
|
||||
try:
|
||||
output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data)
|
||||
finally:
|
||||
if comfy.memory_management.aimdo_enabled:
|
||||
if args.verbose == "DEBUG":
|
||||
comfy_aimdo.control.analyze()
|
||||
comfy.model_management.reset_cast_buffers()
|
||||
comfy_aimdo.model_vbar.vbars_reset_watermark_limits()
|
||||
|
||||
if has_pending_tasks:
|
||||
pending_async_nodes[unique_id] = output_data
|
||||
|
||||
11
main.py
11
main.py
@@ -173,6 +173,10 @@ import gc
|
||||
if 'torch' in sys.modules:
|
||||
logging.warning("WARNING: Potential Error in code: Torch already imported, torch should never be imported before this point.")
|
||||
|
||||
import comfy_aimdo.control
|
||||
|
||||
if enables_dynamic_vram():
|
||||
comfy_aimdo.control.init()
|
||||
|
||||
import comfy.utils
|
||||
|
||||
@@ -188,13 +192,9 @@ import hook_breaker_ac10a0
|
||||
import comfy.memory_management
|
||||
import comfy.model_patcher
|
||||
|
||||
import comfy_aimdo.control
|
||||
import comfy_aimdo.torch
|
||||
|
||||
if enables_dynamic_vram():
|
||||
if comfy.model_management.torch_version_numeric < (2, 8):
|
||||
logging.warning("Unsupported Pytorch detected. DynamicVRAM support requires Pytorch version 2.8 or later. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows")
|
||||
comfy.memory_management.aimdo_allocator = None
|
||||
elif comfy_aimdo.control.init_device(comfy.model_management.get_torch_device().index):
|
||||
if args.verbose == 'DEBUG':
|
||||
comfy_aimdo.control.set_log_debug()
|
||||
@@ -208,11 +208,10 @@ if enables_dynamic_vram():
|
||||
comfy_aimdo.control.set_log_info()
|
||||
|
||||
comfy.model_patcher.CoreModelPatcher = comfy.model_patcher.ModelPatcherDynamic
|
||||
comfy.memory_management.aimdo_allocator = comfy_aimdo.torch.get_torch_allocator()
|
||||
comfy.memory_management.aimdo_enabled = True
|
||||
logging.info("DynamicVRAM support detected and enabled")
|
||||
else:
|
||||
logging.warning("No working comfy-aimdo install detected. DynamicVRAM support disabled. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows")
|
||||
comfy.memory_management.aimdo_allocator = None
|
||||
|
||||
|
||||
def cuda_malloc_warning():
|
||||
|
||||
147
nodes.py
147
nodes.py
@@ -70,7 +70,6 @@ class CLIPTextEncode(ComfyNodeABC):
|
||||
FUNCTION = "encode"
|
||||
|
||||
CATEGORY = "conditioning"
|
||||
ESSENTIALS_CATEGORY = "Basics"
|
||||
DESCRIPTION = "Encodes a text prompt using a CLIP model into an embedding that can be used to guide the diffusion model towards generating specific images."
|
||||
SEARCH_ALIASES = ["text", "prompt", "text prompt", "positive prompt", "negative prompt", "encode text", "text encoder", "encode prompt"]
|
||||
|
||||
@@ -2036,6 +2035,144 @@ class ImagePadForOutpaint:
|
||||
return (new_image, mask.unsqueeze(0))
|
||||
|
||||
|
||||
class TestCurveWidget:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"curve": ("CURVE", {"default": [[0, 0], [1, 1]]}),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("STRING",)
|
||||
RETURN_NAMES = ("points",)
|
||||
FUNCTION = "execute"
|
||||
OUTPUT_NODE = True
|
||||
CATEGORY = "testing"
|
||||
|
||||
def execute(self, curve):
|
||||
import json
|
||||
result = json.dumps(curve, indent=2)
|
||||
print("Curve points:", result)
|
||||
return {"ui": {"text": [result]}, "result": (result,)}
|
||||
|
||||
|
||||
class TestRangePlain:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"range": ("RANGE", {"default": {"min": 0.0, "max": 1.0}}),
|
||||
"range_midpoint": ("RANGE", {
|
||||
"default": {"min": 0.2, "max": 0.8, "midpoint": 0.5},
|
||||
"show_midpoint": True,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("STRING",)
|
||||
FUNCTION = "execute"
|
||||
OUTPUT_NODE = True
|
||||
CATEGORY = "testing"
|
||||
|
||||
def execute(self, **kwargs):
|
||||
import json
|
||||
result = json.dumps(kwargs, indent=2)
|
||||
return {"ui": {"text": [result]}, "result": (result,)}
|
||||
|
||||
|
||||
class TestRangeGradient:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"range": ("RANGE", {
|
||||
"default": {"min": 0.0, "max": 1.0},
|
||||
"display": "gradient",
|
||||
"gradient_stops": [
|
||||
{"offset": 0.0, "color": [0, 0, 0]},
|
||||
{"offset": 1.0, "color": [255, 255, 255]}
|
||||
],
|
||||
}),
|
||||
"range_midpoint": ("RANGE", {
|
||||
"default": {"min": 0.0, "max": 1.0, "midpoint": 0.5},
|
||||
"display": "gradient",
|
||||
"gradient_stops": [
|
||||
{"offset": 0.0, "color": [0, 0, 0]},
|
||||
{"offset": 1.0, "color": [255, 255, 255]}
|
||||
],
|
||||
"show_midpoint": True,
|
||||
"midpoint_scale": "gamma",
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("STRING",)
|
||||
FUNCTION = "execute"
|
||||
OUTPUT_NODE = True
|
||||
CATEGORY = "testing"
|
||||
|
||||
def execute(self, **kwargs):
|
||||
import json
|
||||
result = json.dumps(kwargs, indent=2)
|
||||
return {"ui": {"text": [result]}, "result": (result,)}
|
||||
|
||||
|
||||
class TestRangeHistogram:
|
||||
RANGE_OPTS = {
|
||||
"display": "histogram",
|
||||
"show_midpoint": True,
|
||||
"midpoint_scale": "gamma",
|
||||
"value_min": 0,
|
||||
"value_max": 255,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
default = {"min": 0, "max": 255, "midpoint": 0.5}
|
||||
return {
|
||||
"required": {
|
||||
"image": ("IMAGE",),
|
||||
"rgb": ("RANGE", {"default": {**default}, **s.RANGE_OPTS}),
|
||||
"red": ("RANGE", {"default": {**default}, **s.RANGE_OPTS}),
|
||||
"green": ("RANGE", {"default": {**default}, **s.RANGE_OPTS}),
|
||||
"blue": ("RANGE", {"default": {**default}, **s.RANGE_OPTS}),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("STRING",)
|
||||
FUNCTION = "execute"
|
||||
OUTPUT_NODE = True
|
||||
CATEGORY = "testing"
|
||||
|
||||
def execute(self, image, rgb, red, green, blue):
|
||||
import json
|
||||
import numpy as np
|
||||
|
||||
img = image[0].cpu().numpy() # (H, W, C)
|
||||
|
||||
# Per-channel histograms
|
||||
hist_r, _ = np.histogram(img[:, :, 0].flatten(), bins=256, range=(0.0, 1.0))
|
||||
hist_g, _ = np.histogram(img[:, :, 1].flatten(), bins=256, range=(0.0, 1.0))
|
||||
hist_b, _ = np.histogram(img[:, :, 2].flatten(), bins=256, range=(0.0, 1.0))
|
||||
|
||||
# Luminance histogram (BT.709)
|
||||
luminance = 0.2126 * img[:, :, 0] + 0.7152 * img[:, :, 1] + 0.0722 * img[:, :, 2]
|
||||
hist_rgb, _ = np.histogram(luminance.flatten(), bins=256, range=(0.0, 1.0))
|
||||
|
||||
result = json.dumps({"rgb": rgb, "red": red, "green": green, "blue": blue}, indent=2)
|
||||
return {
|
||||
"ui": {
|
||||
"text": [result],
|
||||
"range_histogram_rgb": hist_rgb.astype(np.uint32).tolist(),
|
||||
"range_histogram_red": hist_r.astype(np.uint32).tolist(),
|
||||
"range_histogram_green": hist_g.astype(np.uint32).tolist(),
|
||||
"range_histogram_blue": hist_b.astype(np.uint32).tolist(),
|
||||
},
|
||||
"result": (result,)
|
||||
}
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"KSampler": KSampler,
|
||||
"CheckpointLoaderSimple": CheckpointLoaderSimple,
|
||||
@@ -2104,6 +2241,10 @@ NODE_CLASS_MAPPINGS = {
|
||||
"ConditioningZeroOut": ConditioningZeroOut,
|
||||
"ConditioningSetTimestepRange": ConditioningSetTimestepRange,
|
||||
"LoraLoaderModelOnly": LoraLoaderModelOnly,
|
||||
"TestCurveWidget": TestCurveWidget,
|
||||
"TestRangePlain": TestRangePlain,
|
||||
"TestRangeGradient": TestRangeGradient,
|
||||
"TestRangeHistogram": TestRangeHistogram,
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
@@ -2172,6 +2313,10 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
# _for_testing
|
||||
"VAEDecodeTiled": "VAE Decode (Tiled)",
|
||||
"VAEEncodeTiled": "VAE Encode (Tiled)",
|
||||
"TestCurveWidget": "Test Curve Widget",
|
||||
"TestRangePlain": "Test Range (Plain)",
|
||||
"TestRangeGradient": "Test Range (Gradient)",
|
||||
"TestRangeHistogram": "Test Range (Histogram)",
|
||||
}
|
||||
|
||||
EXTENSION_WEB_DIRS = {}
|
||||
|
||||
@@ -22,7 +22,7 @@ alembic
|
||||
SQLAlchemy
|
||||
av>=14.2.0
|
||||
comfy-kitchen>=0.2.7
|
||||
comfy-aimdo>=0.1.8
|
||||
comfy-aimdo>=0.2.0
|
||||
requests
|
||||
|
||||
#non essential dependencies:
|
||||
|
||||
Reference in New Issue
Block a user