Compare commits


52 Commits

Author SHA1 Message Date
Terry Jia
4d721bff59 Range Editor 2026-02-26 08:22:46 -05:00
Terry Jia
de67bb0870 feat: add COLOR_CURVES type and ColorCurves node 2026-02-22 20:42:53 -05:00
Terry Jia
9d70c2626b test 2026-02-22 19:24:54 -05:00
Terry Jia
cb64b957f2 CURVE type 2026-02-22 15:06:46 -05:00
comfyanonymous
07ca6852e8 Fix dtype issue in embeddings connector. (#12570) 2026-02-22 03:18:20 -05:00
comfyanonymous
f266b8d352 Move LTXAV av embedding connectors to diffusion model. (#12569) 2026-02-21 22:29:58 -05:00
Christian Byrne
b6cb30bab5 chore: tune CodeRabbit config to limit review scope and disable for drafts (#12567)
* chore: tune CodeRabbit config to limit review scope and disable for drafts

- Add tone_instructions to focus only on newly introduced issues
- Add global path_instructions entry to ignore pre-existing issues in moved/reformatted code
- Disable draft PR reviews (drafts: false) and add WIP title keywords
- Disable ruff tool to prevent linter-based outside-diff-range comments

Addresses feedback from maintainers about CodeRabbit flagging pre-existing
issues in code that was merely moved or de-indented (e.g., PR #12557),
which can discourage community contributions and cause scope creep.

Amp-Thread-ID: https://ampcode.com/threads/T-019c82de-0481-7253-ad42-20cb595bb1ba

* chore: add 'DO NOT MERGE' to ignore_title_keywords

Amp-Thread-ID: https://ampcode.com/threads/T-019c82de-0481-7253-ad42-20cb595bb1ba
2026-02-21 18:32:15 -08:00
Christian Byrne
ee72752162 Add category to Normalized Attention Guidance node (#12565) 2026-02-21 19:51:21 -05:00
Alexander Brown
7591d781a7 fix: specify UTF-8 encoding when reading subgraph files (#12563)
On Windows, Python defaults to cp1252 encoding when no encoding is
specified. JSON files containing UTF-8 characters (e.g., non-ASCII
characters) cause UnicodeDecodeError when read with cp1252.

This fixes the error that occurs when loading blueprint subgraphs
on Windows systems.

https://claude.ai/code/session_014WHi3SL9Gzsi3U6kbSjbSb

Co-authored-by: Claude <noreply@anthropic.com>
2026-02-21 15:05:00 -08:00
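For context, a minimal sketch of the failure mode and the fix (the file path here is hypothetical):

```python
import json

# On Windows, open() without an explicit encoding falls back to the locale
# codec (often cp1252), so UTF-8 bytes in a JSON file can raise
# UnicodeDecodeError. "subgraph.json" is a hypothetical example file.
path = "subgraph.json"

# Fails on Windows for non-cp1252 byte sequences:
#   data = json.load(open(path, "r"))

# The fix: name the encoding explicitly.
with open(path, "r", encoding="utf-8") as f:
    data = json.load(f)
```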
rattus
0bfb936ab4 comfy-aimdo 0.2 - Improved pytorch allocator integration (#12557)
Integrate comfy-aimdo 0.2, which takes a different approach to
installing the memory allocator hook. Instead of using the complicated
and buggy pytorch MemPool+CUDAPluggableAllocator, cuda is hooked
directly, making the process much more transparent to both comfy and
pytorch. As far as pytorch knows, aimdo doesn't exist anymore, and it
just operates behind the scenes.

Remove all the mempool setup stuff for dynamic_vram and bump the
comfy-aimdo version. Remove the allocator object from memory_management
and demote its use as an enablement check to a boolean flag.

Comfy-aimdo 0.2 also supports the pytorch cuda async allocator, so
remove the dynamic_vram-based force disablement of cuda_malloc and
just go back to the old allocator settings based on command line
input.
2026-02-21 10:52:57 -08:00
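For background, the pytorch-side mechanism being replaced looks roughly like this (a sketch of the generic pluggable-allocator API, not aimdo's actual code; the library path and symbol names are hypothetical):

```python
import torch

# The CUDAPluggableAllocator route that comfy-aimdo 0.2 moves away from:
# load a shared library exporting C malloc/free hooks and swap it in
# process-wide. Library name and symbols below are made up.
alloc = torch.cuda.memory.CUDAPluggableAllocator(
    "libmy_alloc.so", "my_malloc", "my_free"
)
torch.cuda.memory.change_current_allocator(alloc)
# Hooking cuda directly instead means pytorch never sees a custom
# allocator at all, which is the transparency the commit describes.
```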
pythongosssss
602b2505a4 add support for pyopengl < 3.1.4 where the size parameter does not exist (#12555) 2026-02-21 06:14:57 -08:00
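One hedged way to handle such an API difference is to branch on the installed version (a generic sketch; the exact pyopengl call the PR adjusts is not shown here):

```python
import re
import OpenGL.version

# pyopengl releases before 3.1.4 do not accept the newer `size`
# parameter, so detect the version once and branch at the call site.
_ver = tuple(int(m) for m in re.findall(r"\d+", OpenGL.version.__version__)[:3])
HAS_SIZE_PARAM = _ver >= (3, 1, 4)

def read_with_optional_size(fn, *args, size=None):
    # `fn` stands in for the pyopengl call whose signature changed.
    if size is not None and HAS_SIZE_PARAM:
        return fn(*args, size=size)
    return fn(*args)
```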
Christian Byrne
04a55d5019 fix: swap essentials_category from CLIPTextEncode to PrimitiveStringMultiline (#12553)
Remove CLIPTextEncode from Basics essentials category and add
PrimitiveStringMultiline (String Multiline) in its place.

Amp-Thread-ID: https://ampcode.com/threads/T-019c7efb-d916-7244-8c43-77b615ba0622

Co-authored-by: Jedrzej Kosinski <kosinkadink1@gmail.com>
2026-02-20 23:46:46 -08:00
Christian Byrne
5fb8f06495 feat: add essential subgraph blueprints (#12552)
Add 24 non-cloud essential blueprints from comfyui-wiki/Subgraph-Blueprints.
These cover common workflows: text/image/video generation, editing,
inpainting, outpainting, upscaling, depth maps, pose, captioning, and more.

Cloud-only blueprints (5) are excluded and will be added once
client-side distribution filtering lands.

Amp-Thread-ID: https://ampcode.com/threads/T-019c6f43-6212-7308-bea6-bfc35a486cbf
2026-02-20 23:40:13 -08:00
Terry Jia
5a182bfaf1 update glsl blueprint with gradient (#12548)
Co-authored-by: Jedrzej Kosinski <kosinkadink1@gmail.com>
2026-02-20 23:02:36 -08:00
Terry Jia
f394af8d0f feat: add gradient-slider display mode for FLOAT inputs (#12536)
* feat: add gradient-slider display mode for FLOAT inputs

* fix: use precise type annotation list[list[float]] for gradient_stops

Amp-Thread-ID: https://ampcode.com/threads/T-019c7eea-be2b-72ce-a51f-838376f9b7a7

---------

Co-authored-by: Jedrzej Kosinski <kosinkadink1@gmail.com>
Co-authored-by: bymyself <cbyrne@comfy.org>
2026-02-20 22:52:32 -08:00
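A hedged sketch of what declaring such an input could look like in the dict-style schema (the `display` and `gradient_stops` option names are inferred from the PR text and may not match the real API):

```python
class GradientSliderExample:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                # gradient-slider display mode with stops given as
                # list[list[float]], per the commit's type annotation.
                # Option names here are assumptions, not the real schema.
                "value": ("FLOAT", {
                    "default": 0.5, "min": 0.0, "max": 1.0,
                    "display": "gradient-slider",
                    "gradient_stops": [[0.0, 0.0], [0.5, 0.8], [1.0, 1.0]],
                }),
            }
        }
```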
Arthur R Longbottom
aeb5bdc8f6 Fix non-contiguous audio waveform crash in video save (#12550)
Fixes #12549
2026-02-20 23:37:55 -05:00
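The usual fix for this class of crash is to force a contiguous copy before handing the tensor to the encoder (a minimal sketch, assuming a torch waveform; not the PR's literal diff):

```python
import torch

# Slicing with a step produces a non-contiguous view, which many
# C-level audio/video encoders cannot consume directly.
waveform = torch.randn(2, 48000)[:, ::2]

if not waveform.is_contiguous():
    waveform = waveform.contiguous()  # materialize a dense copy
```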
comfyanonymous
64953bda0a Update nightly installation command for ROCm (#12547) 2026-02-20 21:32:05 -05:00
Christian Byrne
b254cecd03 chore: add CodeRabbit configuration for automated code review (#12539)
Amp-Thread-ID: https://ampcode.com/threads/T-019c7915-2abf-743c-9c74-b93d87d63055

Co-authored-by: Jedrzej Kosinski <kosinkadink1@gmail.com>
2026-02-20 00:25:54 -08:00
Alexander Piskun
1bb956fb66 [API Nodes] add ElevenLabs nodes (#12207)
* feat(api-nodes): add ElevenLabs API nodes

* added price badge for ElevenLabsInstantVoiceClone node

---------

Co-authored-by: Jedrzej Kosinski <kosinkadink1@gmail.com>
2026-02-19 22:12:28 -08:00
pythongosssss
96d6bd1a4a Add GLSL shader node using PyOpenGL (#12148)
* adds support for executing simple glsl shaders using the moderngl package

* tidy

* Support multiple outputs

* Try fix build

* fix casing

* fix line endings

* convert to using PyOpenGL and glfw

* remove cpu support

* tidy

* add additional support for egl & osmesa backends

* fix ci
perf: only read required outputs

* add diagnostics, update mac initialization

* GLSL blueprints + node fixes (#12492)

* Add image operation blueprints

* Add channels

* Add glow

* brightness/contrast

* hsb

* add glsl shader update system

* shader nit iteration

* add multipass for faster blur

* more fixes

* rebuild blueprints

* print -> logger

* Add edge preserving blur

* fix: move _initialized flag to end of GLContext.__init__

Prevents '_vao' attribute error when init fails partway through
and subsequent calls skip initialization due to the early _initialized flag.

* update valid ranges
- threshold 0-100
- step 0+

* fix value ranges

* rebuild node to remove extra inputs

* Fix gamma step

* clamp saturation in colorize instead of wrapping

* Fix crash on 1x1 px images

* rework description

* remove unnecessary f


Co-authored-by: Jedrzej Kosinski <kosinkadink1@gmail.com>
Co-authored-by: Hunter Senft-Grupp <hunter@comfy.org>
2026-02-19 23:22:13 -05:00
comfyanonymous
5f2117528a Force min length 1 when tokenizing for text generation. (#12538) 2026-02-19 22:57:44 -05:00
comfyanonymous
0301ccf745 Small cleanup and try to get qwen 3 work with the text gen. (#12537) 2026-02-19 22:42:28 -05:00
Christian Byrne
4d172e9ad7 feat: mark 429 widgets as advanced for collapsible UI (#12197)
* feat: mark 429 widgets as advanced for collapsible UI

Mark widgets as advanced across core, comfy_extras, and comfy_api_nodes
to support the new collapsible advanced inputs section in the frontend.

Changes:
- 267 advanced markers in comfy_extras/
- 162 advanced markers in comfy_api_nodes/
- All files pass python3 -m py_compile verification

Widgets marked advanced (hidden by default):
- Scheduler internals: sigma_max, sigma_min, rho, mu, beta, alpha
- Sampler internals: eta, s_noise, order, rtol, atol, h_init, pcoeff, etc.
- Memory optimization: tile_size, overlap, temporal_size, temporal_overlap
- Pipeline controls: add_noise, start_at_step, end_at_step
- Timing controls: start_percent, end_percent
- Layer selection: stop_at_clip_layer, layers, block_number
- Video encoding: codec, crf, format
- Device/dtype: device, noise_device, dtype, weight_dtype

Widgets kept basic (always visible):
- Core params: strength, steps, cfg, denoise, seed, width, height
- Model selectors: ckpt_name, lora_name, vae_name, sampler_name
- Common controls: upscale_method, crop, batch_size, fps, opacity

Related: frontend PR #11939
Amp-Thread-ID: https://ampcode.com/threads/T-019c1734-6b61-702e-b333-f02c399963fc

* fix: remove advanced=True from DynamicCombo.Input (unsupported)

Amp-Thread-ID: https://ampcode.com/threads/T-019c1734-6b61-702e-b333-f02c399963fc

* fix: address review - un-mark model merge, video, image, and training node widgets as advanced

Per comfyanonymous review:
- Model merge arguments should not be advanced (all 14 model-specific merge classes)
- SaveAnimatedWEBP lossless/quality/method should not be advanced
- SaveWEBM/SaveVideo codec/crf/format should not be advanced
- TrainLoraNode options should not be advanced (7 inputs)

Amp-Thread-ID: https://ampcode.com/threads/T-019c322b-a3a8-71b7-9962-d44573ca6352

* fix: un-mark batch_size and webcam width/height as advanced (should stay basic)

Amp-Thread-ID: https://ampcode.com/threads/T-019c3236-1417-74aa-82a3-bcb365fbe9d1

---------

Co-authored-by: Jedrzej Kosinski <kosinkadink1@gmail.com>
2026-02-19 19:20:02 -08:00
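A sketch of marking a widget advanced in the dict-style node schema (the `advanced` option key follows the commit text; the node itself is hypothetical):

```python
class SamplerExample:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                # Core parameter: always visible.
                "steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
                # Sampler internals: collapsed by default in the frontend.
                "eta": ("FLOAT", {"default": 1.0, "advanced": True}),
                "s_noise": ("FLOAT", {"default": 1.0, "advanced": True}),
            }
        }
```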
Yourz
5632b2df9d feat: add essentials_category (#12357)
* feat: add essentials_category field to node schema

Amp-Thread-ID: https://ampcode.com/threads/T-019c2b25-cd90-7218-9071-03cb46b351b3

* feat: add ESSENTIALS_CATEGORY to core nodes

Marked nodes:
- Basic: LoadImage, SaveImage, LoadVideo, SaveVideo, Load3D, CLIPTextEncode
- Image Tools: ImageScale, ImageInvert, ImageBatch, ImageCrop, ImageRotate, ImageBlur
- Image Tools/Preprocessing: Canny
- Image Generation: LoraLoader
- Audio: LoadAudio, SaveAudio

Amp-Thread-ID: https://ampcode.com/threads/T-019c2b25-cd90-7218-9071-03cb46b351b3

* Add ESSENTIALS_CATEGORY to more nodes

- SaveGLB (Basic)
- GetVideoComponents (Video Tools)
- TencentTextToModelNode, TencentImageToModelNode (3D)
- RecraftRemoveBackgroundNode (Image Tools)
- KlingLipSyncAudioToVideoNode (Video Generation)
- OpenAIChatNode (Text Generation)
- StabilityTextToAudio (Audio)

Amp-Thread-ID: https://ampcode.com/threads/T-019c2b69-81c1-71c3-8096-450a39e20910

* fix: correct essentials category for Canny node

Amp-Thread-ID: https://ampcode.com/threads/T-019c7303-ab53-7341-be76-a5da1f7a657e
Co-authored-by: Amp <amp@ampcode.com>

* refactor: replace essentials_category string literals with constants

Amp-Thread-ID: https://ampcode.com/threads/T-019c7303-ab53-7341-be76-a5da1f7a657e
Co-authored-by: Amp <amp@ampcode.com>

* refactor: revert constants, use string literals for essentials_category

Amp-Thread-ID: https://ampcode.com/threads/T-019c7303-ab53-7341-be76-a5da1f7a657e
Co-authored-by: Amp <amp@ampcode.com>

* fix: update basics

---------

Co-authored-by: bymyself <cbyrne@comfy.org>
Co-authored-by: Amp <amp@ampcode.com>
Co-authored-by: Jedrzej Kosinski <kosinkadink1@gmail.com>
2026-02-19 19:00:26 -08:00
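Per the commit, nodes opt in via a class-level marker; a minimal sketch (the attribute name follows the PR, the node body is hypothetical):

```python
class LoadImageExample:
    # Surfaces the node under the "Basic" essentials group in the frontend.
    ESSENTIALS_CATEGORY = "Basic"
    CATEGORY = "image"

    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {"image": ("STRING", {})}}
```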
Alexander Piskun
2687652530 fix(api-nodes): force Gemini to return uncompressed images (#12516) 2026-02-19 04:10:39 -08:00
Jukka Seppänen
6d11cc7354 feat: Add basic text generation support with native models, initially supporting Gemma3 (#12392) 2026-02-18 20:49:43 -05:00
comfyanonymous
f262444dd4 Add simple 3 band equalizer node for audio. (#12519) 2026-02-18 18:36:35 -05:00
Alexander Piskun
239ddd3327 fix(api-nodes): add price badge for Rodin Gen-2 node (#12512) 2026-02-17 23:15:23 -08:00
Hunter
83dd65f23a fix: use glob matching for Gemini image MIME types (#12511)
gemini-3-pro-image-preview nondeterministically returns image/jpeg
instead of image/png. get_image_from_response() hardcoded
get_parts_by_type(response, "image/png"), silently dropping JPEG
responses and falling back to torch.zeros (all-black output).

Add _mime_matches() helper using fnmatch for glob-style MIME matching.
Change get_image_from_response() to request "image/*" so any image
format returned by the API is correctly captured.
2026-02-18 00:03:54 -05:00
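A sketch of the helper the commit describes (the function name comes from the message; the charset-stripping detail is an assumption):

```python
from fnmatch import fnmatch

def _mime_matches(mime_type: str, pattern: str) -> bool:
    # Glob-style MIME match, ignoring any ";charset=..." parameters.
    return fnmatch(mime_type.split(";")[0].strip().lower(), pattern)

# Requesting "image/*" captures image/jpeg as well as image/png:
assert _mime_matches("image/jpeg", "image/*")
assert _mime_matches("image/png; charset=binary", "image/*")
```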
Terry Jia
8ad38d2073 BBox widget (#11594)
* Boundingbox widget

* code improve

---------

Co-authored-by: Jedrzej Kosinski <kosinkadink1@gmail.com>
Co-authored-by: Christian Byrne <cbyrne@comfy.org>
2026-02-17 17:13:39 -08:00
Comfy Org PR Bot
6c14f129af Bump comfyui-frontend-package to 1.39.14 (#12494)
* Bump comfyui-frontend-package to 1.39.13

* Update requirements.txt

---------

Co-authored-by: Christian Byrne <cbyrne@comfy.org>
2026-02-17 13:41:34 -08:00
rattus
58dcc97dcf ops: limit return of requants (#12506)
This check was far too broad, and the dtype is not a reliable indicator
of wanting the requant (as QT returns the compute dtype as the dtype),
so explicitly plumb through whether fp8mm wants the requant or not.
2026-02-17 15:32:27 -05:00
comfyanonymous
19236edfa4 ComfyUI v0.14.1 2026-02-17 13:28:06 -05:00
ComfyUI Wiki
73c3f86973 chore: update workflow templates to v0.8.43 (#12507) 2026-02-17 13:25:55 -05:00
Alexander Piskun
262abf437b feat(api-nodes): add Recraft V4 nodes (#12502) 2026-02-17 13:25:44 -05:00
Alexander Piskun
5284e6bf69 feat(api-nodes): add "viduq3-turbo" model and Vidu3StartEnd node; fix the price badges (#12482) 2026-02-17 10:07:14 -08:00
chaObserv
44f8598521 Fix anima LLM adapter forward when manual cast (#12504) 2026-02-17 07:56:44 -08:00
comfyanonymous
fe52843fe5 ComfyUI v0.14.0 2026-02-17 00:39:54 -05:00
comfyanonymous
c39653163d Fix anima preprocess text embeds not using right inference dtype. (#12501) 2026-02-17 00:29:20 -05:00
comfyanonymous
18927538a1 Implement NAG on all the models based on the Flux code. (#12500)
Use the Normalized Attention Guidance node.

Flux, Flux2, Klein, Chroma, Chroma radiance, Hunyuan Video, etc..
2026-02-16 23:30:34 -05:00
Jedrzej Kosinski
8a6fbc2dc2 Allow control_after_generate to be type ControlAfterGenerate in v3 schema (#12187) 2026-02-16 22:20:21 -05:00
Alex Butler
b44fc4c589 add venv* to gitignore (#12431) 2026-02-16 22:16:19 -05:00
comfyanonymous
4454fab7f0 Remove code to support RMSNorm on old pytorch. (#12499) 2026-02-16 20:09:24 -05:00
ComfyUI Wiki
1978f59ffd chore: update workflow templates to v0.8.42 (#12491) 2026-02-16 17:33:43 -05:00
comfyanonymous
88e6370527 Remove workaround for old pytorch. (#12480) 2026-02-15 20:43:53 -05:00
rattus
c0370044cd MPDynamic: force load flux img_in weight (Fixes flux1 canny+depth lora crash) (#12446)
* lora: add weight shape calculations.

This lets the loader know if a lora will change the shape of a weight
so it can take appropriate action.

* MPDynamic: force load flux img_in weight

This weight is a bit special, in that the lora changes its geometry.
This is rather unique, not handled by the existing estimate, and doesn't
work for either offloading or dynamic_vram.

Fix it for dynamic_vram as a special case. Ideally we can fully precalculate
these lora geometry changes at load time, but just get these models
working first.
2026-02-15 20:30:09 -05:00
rattus
ecd2a19661 Fix lora Extraction in offload conditions (+ dynamic_vram mode) (#12479)
* lora_extract: Add a trange

If you bite off more than your GPU can chew, this kinda just hangs.
Give a rough indication of progress by counting the weights in a trange.

* lora_extract: Support on-the-fly patching

Use the on-the-fly approach from the regular model saving logic for
lora extraction too. Switch off force_cast_weights accordingly.

This gets extraction working in dynamic vram while also supporting
extraction on GPU offloaded.
2026-02-15 20:28:51 -05:00
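The progress indication from the first commit amounts to wrapping the per-weight loop in tqdm's trange (a minimal sketch; the weight list is a stand-in):

```python
from tqdm import trange

# Hypothetical stand-in for the weights being decomposed.
weight_names = [f"blocks.{i}.attn.qkv.weight" for i in range(40)]

for i in trange(len(weight_names), desc="extracting lora"):
    name = weight_names[i]
    # per-weight extraction (e.g., low-rank decomposition) goes here
```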
Alexander Piskun
2c1d06a4e3 feat(api-nodes): add Bria RMBG nodes (#12465)
Co-authored-by: Jedrzej Kosinski <kosinkadink1@gmail.com>
2026-02-15 17:22:30 -08:00
Alexander Piskun
e2c71ceb00 feat(api-nodes-Tencent): add ModelTo3DUV, 3DTextureEdit, 3DParts nodes (#12428)
* feat(api-nodes-Tencent): add ModelTo3DUV, 3DTextureEdit, 3DParts nodes

* add image output to TencentModelTo3DUV node

* commented out two nodes

* added rate_limit check to other hunyuan3d nodes
2026-02-15 05:33:18 -08:00
Jedrzej Kosinski
596ed68691 Node Replacement API (#12014) 2026-02-15 02:12:30 -08:00
Alexander Piskun
ce4a1ab48d chore(api-nodes): remove "gpt-4o" model (#12467) 2026-02-15 01:31:59 -08:00
comfyanonymous
e1ede29d82 Remove unsafe pickle loading code that was used on pytorch older than 2.4 (#12473)
ComfyUI hasn't been able to start on pytorch 2.4 since last month.
2026-02-14 22:53:52 -05:00
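With pytorch >= 2.4 guaranteed, checkpoint loading can rely on the restricted unpickler unconditionally (a sketch; the checkpoint path is hypothetical):

```python
import torch

# weights_only=True uses torch's restricted unpickler, which refuses
# arbitrary pickle opcodes -- the safe default on pytorch >= 2.4.
state_dict = torch.load("model.ckpt", map_location="cpu", weights_only=True)
```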
180 changed files with 6312 additions and 810 deletions

.coderabbit.yaml (new file, +127)

@@ -0,0 +1,127 @@
# yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json
language: "en-US"
early_access: false
tone_instructions: "Only comment on issues introduced by this PR's changes. Do not flag pre-existing problems in moved, re-indented, or reformatted code."
reviews:
  profile: "chill"
  request_changes_workflow: false
  high_level_summary: false
  poem: false
  review_status: false
  review_details: false
  commit_status: true
  collapse_walkthrough: true
  changed_files_summary: false
  sequence_diagrams: false
  estimate_code_review_effort: false
  assess_linked_issues: false
  related_issues: false
  related_prs: false
  suggested_labels: false
  auto_apply_labels: false
  suggested_reviewers: false
  auto_assign_reviewers: false
  in_progress_fortune: false
  enable_prompt_for_ai_agents: true
  path_filters:
    - "!comfy_api_nodes/apis/**"
    - "!**/generated/*.pyi"
    - "!.ci/**"
    - "!script_examples/**"
    - "!**/__pycache__/**"
    - "!**/*.ipynb"
    - "!**/*.png"
    - "!**/*.bat"
  path_instructions:
    - path: "**"
      instructions: |
        IMPORTANT: Only comment on issues directly introduced by this PR's code changes.
        Do NOT flag pre-existing issues in code that was merely moved, re-indented,
        de-indented, or reformatted without logic changes. If code appears in the diff
        only due to whitespace or structural reformatting (e.g., removing a `with:` block),
        treat it as unchanged. Contributors should not feel obligated to address
        pre-existing issues outside the scope of their contribution.
    - path: "comfy/**"
      instructions: |
        Core ML/diffusion engine. Focus on:
        - Backward compatibility (breaking changes affect all custom nodes)
        - Memory management and GPU resource handling
        - Performance implications in hot paths
        - Thread safety for concurrent execution
    - path: "comfy_api_nodes/**"
      instructions: |
        Third-party API integration nodes. Focus on:
        - No hardcoded API keys or secrets
        - Proper error handling for API failures (timeouts, rate limits, auth errors)
        - Correct Pydantic model usage
        - Security of user data passed to external APIs
    - path: "comfy_extras/**"
      instructions: |
        Community-contributed extra nodes. Focus on:
        - Consistency with node patterns (INPUT_TYPES, RETURN_TYPES, FUNCTION, CATEGORY)
        - No breaking changes to existing node interfaces
    - path: "comfy_execution/**"
      instructions: |
        Execution engine (graph execution, caching, jobs). Focus on:
        - Caching correctness
        - Concurrent execution safety
        - Graph validation edge cases
    - path: "nodes.py"
      instructions: |
        Core node definitions (2500+ lines). Focus on:
        - Backward compatibility of NODE_CLASS_MAPPINGS
        - Consistency of INPUT_TYPES return format
    - path: "alembic_db/**"
      instructions: |
        Database migrations. Focus on:
        - Migration safety and rollback support
        - Data preservation during schema changes
  auto_review:
    enabled: true
    auto_incremental_review: true
    drafts: false
    ignore_title_keywords:
      - "WIP"
      - "DO NOT REVIEW"
      - "DO NOT MERGE"
  finishing_touches:
    docstrings:
      enabled: false
    unit_tests:
      enabled: false
  tools:
    ruff:
      enabled: false
    pylint:
      enabled: false
    flake8:
      enabled: false
    gitleaks:
      enabled: true
    shellcheck:
      enabled: false
    markdownlint:
      enabled: false
    yamllint:
      enabled: false
    languagetool:
      enabled: false
    github-checks:
      enabled: true
      timeout_ms: 90000
    ast-grep:
      essential_rules: true
chat:
  auto_reply: true
knowledge_base:
  opt_out: false
  learnings:
    scope: "auto"

.gitignore (vendored, 2 changed lines)

@@ -11,7 +11,7 @@ extra_model_paths.yaml
 /.vs
 .vscode/
 .idea/
-venv/
+venv*/
 .venv/
 /web/extensions/*
 !/web/extensions/logging.js.example

README.md

@@ -229,9 +229,9 @@ AMD users can install rocm and pytorch with pip if you don't have it already installed
 ```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm7.1```
-This is the command to install the nightly with ROCm 7.1 which might have some performance improvements:
+This is the command to install the nightly with ROCm 7.2 which might have some performance improvements:
-```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm7.1```
+```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm7.2```
 ### AMD GPUs (Experimental: Windows and Linux), RDNA 3, 3.5 and 4 only.

app/node_replace_manager.py (new file, +105)

@@ -0,0 +1,105 @@
from __future__ import annotations

from aiohttp import web
from typing import TYPE_CHECKING, TypedDict

if TYPE_CHECKING:
    from comfy_api.latest._io_public import NodeReplace
from comfy_execution.graph_utils import is_link
import nodes


class NodeStruct(TypedDict):
    inputs: dict[str, str | int | float | bool | tuple[str, int]]
    class_type: str
    _meta: dict[str, str]


def copy_node_struct(node_struct: NodeStruct, empty_inputs: bool = False) -> NodeStruct:
    new_node_struct = node_struct.copy()
    if empty_inputs:
        new_node_struct["inputs"] = {}
    else:
        new_node_struct["inputs"] = node_struct["inputs"].copy()
    new_node_struct["_meta"] = node_struct["_meta"].copy()
    return new_node_struct


class NodeReplaceManager:
    """Manages node replacement registrations."""

    def __init__(self):
        self._replacements: dict[str, list[NodeReplace]] = {}

    def register(self, node_replace: NodeReplace):
        """Register a node replacement mapping."""
        self._replacements.setdefault(node_replace.old_node_id, []).append(node_replace)

    def get_replacement(self, old_node_id: str) -> list[NodeReplace] | None:
        """Get replacements for an old node ID."""
        return self._replacements.get(old_node_id)

    def has_replacement(self, old_node_id: str) -> bool:
        """Check if a replacement exists for an old node ID."""
        return old_node_id in self._replacements

    def apply_replacements(self, prompt: dict[str, NodeStruct]):
        connections: dict[str, list[tuple[str, str, int]]] = {}
        need_replacement: set[str] = set()
        for node_number, node_struct in prompt.items():
            class_type = node_struct["class_type"]
            # need replacement if not in NODE_CLASS_MAPPINGS and has replacement
            if class_type not in nodes.NODE_CLASS_MAPPINGS.keys() and self.has_replacement(class_type):
                need_replacement.add(node_number)
            # keep track of connections
            for input_id, input_value in node_struct["inputs"].items():
                if is_link(input_value):
                    conn_number = input_value[0]
                    connections.setdefault(conn_number, []).append((node_number, input_id, input_value[1]))
        for node_number in need_replacement:
            node_struct = prompt[node_number]
            class_type = node_struct["class_type"]
            replacements = self.get_replacement(class_type)
            if replacements is None:
                continue
            # just use the first replacement
            replacement = replacements[0]
            new_node_id = replacement.new_node_id
            # if replacement is not a valid node, skip trying to replace it as will only cause confusion
            if new_node_id not in nodes.NODE_CLASS_MAPPINGS.keys():
                continue
            # first, replace node id (class_type)
            new_node_struct = copy_node_struct(node_struct, empty_inputs=True)
            new_node_struct["class_type"] = new_node_id
            # TODO: consider replacing display_name in _meta as well for error reporting purposes; would need to query node schema
            # second, replace inputs
            if replacement.input_mapping is not None:
                for input_map in replacement.input_mapping:
                    if "set_value" in input_map:
                        new_node_struct["inputs"][input_map["new_id"]] = input_map["set_value"]
                    elif "old_id" in input_map:
                        new_node_struct["inputs"][input_map["new_id"]] = node_struct["inputs"][input_map["old_id"]]
            # finalize input replacement
            prompt[node_number] = new_node_struct
            # third, replace outputs
            if replacement.output_mapping is not None:
                # re-mapping outputs requires changing the input values of nodes that receive connections from this one
                if node_number in connections:
                    for conns in connections[node_number]:
                        conn_node_number, conn_input_id, old_output_idx = conns
                        for output_map in replacement.output_mapping:
                            if output_map["old_idx"] == old_output_idx:
                                new_output_idx = output_map["new_idx"]
                                previous_input = prompt[conn_node_number]["inputs"][conn_input_id]
                                previous_input[1] = new_output_idx

    def as_dict(self):
        """Serialize all replacements to dict."""
        return {
            k: [v.as_dict() for v in v_list]
            for k, v_list in self._replacements.items()
        }

    def add_routes(self, routes):
        @routes.get("/node_replacements")
        async def get_node_replacements(request):
            return web.json_response(self.as_dict())
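A hedged usage sketch against the manager above (the NodeReplace construction is inferred from the fields this code reads -- old_node_id, new_node_id, input_mapping, output_mapping -- and may not match the real comfy_api signature):

```python
manager = NodeReplaceManager()

# Hypothetical replacement: retire "OldImageBlur" in favor of "ImageBlur".
manager.register(NodeReplace(
    old_node_id="OldImageBlur",
    new_node_id="ImageBlur",
    input_mapping=[
        {"new_id": "image", "old_id": "image"},    # carry input across
        {"new_id": "blur_radius", "set_value": 4}, # pin a new input
    ],
    output_mapping=[{"old_idx": 0, "new_idx": 0}],
))

prompt = {
    "1": {"inputs": {"image": ["2", 0]},
          "class_type": "OldImageBlur", "_meta": {"title": "Blur"}},
}
manager.apply_replacements(prompt)  # rewrites node "1" in place
```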


@@ -53,7 +53,7 @@ class SubgraphManager:
         return entry_id, entry
 
     async def load_entry_data(self, entry: SubgraphEntry):
-        with open(entry['path'], 'r') as f:
+        with open(entry['path'], 'r', encoding='utf-8') as f:
             entry['data'] = f.read()
         return entry


@@ -0,0 +1,44 @@
#version 300 es
precision highp float;
uniform sampler2D u_image0;
uniform float u_float0; // Brightness slider -100..100
uniform float u_float1; // Contrast slider -100..100
in vec2 v_texCoord;
out vec4 fragColor;
const float MID_GRAY = 0.18; // 18% reflectance
// sRGB gamma 2.2 approximation
vec3 srgbToLinear(vec3 c) {
return pow(max(c, 0.0), vec3(2.2));
}
vec3 linearToSrgb(vec3 c) {
return pow(max(c, 0.0), vec3(1.0/2.2));
}
float mapBrightness(float b) {
return clamp(b / 100.0, -1.0, 1.0);
}
float mapContrast(float c) {
return clamp(c / 100.0 + 1.0, 0.0, 2.0);
}
void main() {
vec4 orig = texture(u_image0, v_texCoord);
float brightness = mapBrightness(u_float0);
float contrast = mapContrast(u_float1);
vec3 lin = srgbToLinear(orig.rgb);
lin = (lin - MID_GRAY) * contrast + brightness + MID_GRAY;
// Convert back to sRGB
vec3 result = linearToSrgb(clamp(lin, 0.0, 1.0));
fragColor = vec4(result, orig.a);
}


@@ -0,0 +1,72 @@
#version 300 es
precision highp float;
uniform sampler2D u_image0;
uniform vec2 u_resolution;
uniform int u_int0; // Mode
uniform float u_float0; // Amount (0 to 100)
in vec2 v_texCoord;
out vec4 fragColor;
const int MODE_LINEAR = 0;
const int MODE_RADIAL = 1;
const int MODE_BARREL = 2;
const int MODE_SWIRL = 3;
const int MODE_DIAGONAL = 4;
const float AMOUNT_SCALE = 0.0005;
const float RADIAL_MULT = 4.0;
const float BARREL_MULT = 8.0;
const float INV_SQRT2 = 0.70710678118;
void main() {
vec2 uv = v_texCoord;
vec4 original = texture(u_image0, uv);
float amount = u_float0 * AMOUNT_SCALE;
if (amount < 0.000001) {
fragColor = original;
return;
}
// Aspect-corrected coordinates for circular effects
float aspect = u_resolution.x / u_resolution.y;
vec2 centered = uv - 0.5;
vec2 corrected = vec2(centered.x * aspect, centered.y);
float r = length(corrected);
vec2 dir = r > 0.0001 ? corrected / r : vec2(0.0);
vec2 offset = vec2(0.0);
if (u_int0 == MODE_LINEAR) {
// Horizontal shift (no aspect correction needed)
offset = vec2(amount, 0.0);
}
else if (u_int0 == MODE_RADIAL) {
// Outward from center, stronger at edges
offset = dir * r * amount * RADIAL_MULT;
offset.x /= aspect; // Convert back to UV space
}
else if (u_int0 == MODE_BARREL) {
// Lens distortion simulation (r² falloff)
offset = dir * r * r * amount * BARREL_MULT;
offset.x /= aspect; // Convert back to UV space
}
else if (u_int0 == MODE_SWIRL) {
// Perpendicular to radial (rotational aberration)
vec2 perp = vec2(-dir.y, dir.x);
offset = perp * r * amount * RADIAL_MULT;
offset.x /= aspect; // Convert back to UV space
}
else if (u_int0 == MODE_DIAGONAL) {
// 45° offset (no aspect correction needed)
offset = vec2(amount, amount) * INV_SQRT2;
}
float red = texture(u_image0, uv + offset).r;
float green = original.g;
float blue = texture(u_image0, uv - offset).b;
fragColor = vec4(red, green, blue, original.a);
}


@@ -0,0 +1,78 @@
#version 300 es
precision highp float;
uniform sampler2D u_image0;
uniform float u_float0; // temperature (-100 to 100)
uniform float u_float1; // tint (-100 to 100)
uniform float u_float2; // vibrance (-100 to 100)
uniform float u_float3; // saturation (-100 to 100)
in vec2 v_texCoord;
out vec4 fragColor;
const float INPUT_SCALE = 0.01;
const float TEMP_TINT_PRIMARY = 0.3;
const float TEMP_TINT_SECONDARY = 0.15;
const float VIBRANCE_BOOST = 2.0;
const float SATURATION_BOOST = 2.0;
const float SKIN_PROTECTION = 0.5;
const float EPSILON = 0.001;
const vec3 LUMA_WEIGHTS = vec3(0.299, 0.587, 0.114);
void main() {
vec4 tex = texture(u_image0, v_texCoord);
vec3 color = tex.rgb;
// Scale inputs: -100/100 → -1/1
float temperature = u_float0 * INPUT_SCALE;
float tint = u_float1 * INPUT_SCALE;
float vibrance = u_float2 * INPUT_SCALE;
float saturation = u_float3 * INPUT_SCALE;
// Temperature (warm/cool): positive = warm, negative = cool
color.r += temperature * TEMP_TINT_PRIMARY;
color.b -= temperature * TEMP_TINT_PRIMARY;
// Tint (green/magenta): positive = green, negative = magenta
color.g += tint * TEMP_TINT_PRIMARY;
color.r -= tint * TEMP_TINT_SECONDARY;
color.b -= tint * TEMP_TINT_SECONDARY;
// Single clamp after temperature/tint
color = clamp(color, 0.0, 1.0);
// Vibrance with skin protection
if (vibrance != 0.0) {
float maxC = max(color.r, max(color.g, color.b));
float minC = min(color.r, min(color.g, color.b));
float sat = maxC - minC;
float gray = dot(color, LUMA_WEIGHTS);
if (vibrance < 0.0) {
// Desaturate: -100 → gray
color = mix(vec3(gray), color, 1.0 + vibrance);
} else {
// Boost less saturated colors more
float vibranceAmt = vibrance * (1.0 - sat);
// Branchless skin tone protection
float isWarmTone = step(color.b, color.g) * step(color.g, color.r);
float warmth = (color.r - color.b) / max(maxC, EPSILON);
float skinTone = isWarmTone * warmth * sat * (1.0 - sat);
vibranceAmt *= (1.0 - skinTone * SKIN_PROTECTION);
color = mix(vec3(gray), color, 1.0 + vibranceAmt * VIBRANCE_BOOST);
}
}
// Saturation
if (saturation != 0.0) {
float gray = dot(color, LUMA_WEIGHTS);
float satMix = saturation < 0.0
? 1.0 + saturation // -100 → gray
: 1.0 + saturation * SATURATION_BOOST; // +100 → 3x boost
color = mix(vec3(gray), color, satMix);
}
fragColor = vec4(clamp(color, 0.0, 1.0), tex.a);
}


@@ -0,0 +1,94 @@
#version 300 es
precision highp float;
uniform sampler2D u_image0;
uniform float u_float0; // Blur radius (0–20, default ~5)
uniform float u_float1; // Edge threshold (0–100, default ~30)
uniform int u_int0; // Step size (0/1 = every pixel, 2+ = skip pixels)
in vec2 v_texCoord;
out vec4 fragColor;
const int MAX_RADIUS = 20;
const float EPSILON = 0.0001;
// Perceptual luminance
float getLuminance(vec3 rgb) {
return dot(rgb, vec3(0.299, 0.587, 0.114));
}
vec4 bilateralFilter(vec2 uv, vec2 texelSize, int radius,
float sigmaSpatial, float sigmaColor)
{
vec4 center = texture(u_image0, uv);
vec3 centerRGB = center.rgb;
float invSpatial2 = -0.5 / (sigmaSpatial * sigmaSpatial);
float invColor2 = -0.5 / (sigmaColor * sigmaColor + EPSILON);
vec3 sumRGB = vec3(0.0);
float sumWeight = 0.0;
int step = max(u_int0, 1);
float radius2 = float(radius * radius);
for (int dy = -MAX_RADIUS; dy <= MAX_RADIUS; dy++) {
if (dy < -radius || dy > radius) continue;
if (abs(dy) % step != 0) continue;
for (int dx = -MAX_RADIUS; dx <= MAX_RADIUS; dx++) {
if (dx < -radius || dx > radius) continue;
if (abs(dx) % step != 0) continue;
vec2 offset = vec2(float(dx), float(dy));
float dist2 = dot(offset, offset);
if (dist2 > radius2) continue;
vec3 sampleRGB = texture(u_image0, uv + offset * texelSize).rgb;
// Spatial Gaussian
float spatialWeight = exp(dist2 * invSpatial2);
// Perceptual color distance (weighted RGB)
vec3 diff = sampleRGB - centerRGB;
float colorDist = dot(diff * diff, vec3(0.299, 0.587, 0.114));
float colorWeight = exp(colorDist * invColor2);
float w = spatialWeight * colorWeight;
sumRGB += sampleRGB * w;
sumWeight += w;
}
}
vec3 resultRGB = sumRGB / max(sumWeight, EPSILON);
return vec4(resultRGB, center.a); // preserve center alpha
}
void main() {
vec2 texelSize = 1.0 / vec2(textureSize(u_image0, 0));
float radiusF = clamp(u_float0, 0.0, float(MAX_RADIUS));
int radius = int(radiusF + 0.5);
if (radius == 0) {
fragColor = texture(u_image0, v_texCoord);
return;
}
// Edge threshold → color sigma
// Squared curve for better low-end control
float t = clamp(u_float1, 0.0, 100.0) / 100.0;
t *= t;
float sigmaColor = mix(0.01, 0.5, t);
// Spatial sigma tied to radius
float sigmaSpatial = max(radiusF * 0.75, 0.5);
fragColor = bilateralFilter(
v_texCoord,
texelSize,
radius,
sigmaSpatial,
sigmaColor
);
}


@@ -0,0 +1,124 @@
#version 300 es
precision highp float;
uniform sampler2D u_image0;
uniform vec2 u_resolution;
uniform float u_float0; // grain amount [0.0–1.0] typical: 0.2–0.8
uniform float u_float1; // grain size [0.3–3.0] lower = finer grain
uniform float u_float2; // color amount [0.0–1.0] 0 = monochrome, 1 = RGB grain
uniform float u_float3; // luminance bias [0.0–1.0] 0 = uniform, 1 = shadows only
uniform int u_int0; // noise mode [0 or 1] 0 = smooth, 1 = grainy
in vec2 v_texCoord;
layout(location = 0) out vec4 fragColor0;
// High-quality integer hash (pcg-like)
uint pcg(uint v) {
uint state = v * 747796405u + 2891336453u;
uint word = ((state >> ((state >> 28u) + 4u)) ^ state) * 277803737u;
return (word >> 22u) ^ word;
}
// 2D -> 1D hash input
uint hash2d(uvec2 p) {
return pcg(p.x + pcg(p.y));
}
// Hash to float [0, 1]
float hashf(uvec2 p) {
return float(hash2d(p)) / float(0xffffffffu);
}
// Hash to float with offset (for RGB channels)
float hashf(uvec2 p, uint offset) {
return float(pcg(hash2d(p) + offset)) / float(0xffffffffu);
}
// Convert uniform [0,1] to roughly Gaussian distribution
// Using simple approximation: average of multiple samples
float toGaussian(uvec2 p) {
float sum = hashf(p, 0u) + hashf(p, 1u) + hashf(p, 2u) + hashf(p, 3u);
return (sum - 2.0) * 0.7; // Centered, scaled
}
float toGaussian(uvec2 p, uint offset) {
float sum = hashf(p, offset) + hashf(p, offset + 1u)
+ hashf(p, offset + 2u) + hashf(p, offset + 3u);
return (sum - 2.0) * 0.7;
}
// Smooth noise with better interpolation
float smoothNoise(vec2 p) {
vec2 i = floor(p);
vec2 f = fract(p);
// Quintic interpolation (less banding than cubic)
f = f * f * f * (f * (f * 6.0 - 15.0) + 10.0);
uvec2 ui = uvec2(i);
float a = toGaussian(ui);
float b = toGaussian(ui + uvec2(1u, 0u));
float c = toGaussian(ui + uvec2(0u, 1u));
float d = toGaussian(ui + uvec2(1u, 1u));
return mix(mix(a, b, f.x), mix(c, d, f.x), f.y);
}
float smoothNoise(vec2 p, uint offset) {
vec2 i = floor(p);
vec2 f = fract(p);
f = f * f * f * (f * (f * 6.0 - 15.0) + 10.0);
uvec2 ui = uvec2(i);
float a = toGaussian(ui, offset);
float b = toGaussian(ui + uvec2(1u, 0u), offset);
float c = toGaussian(ui + uvec2(0u, 1u), offset);
float d = toGaussian(ui + uvec2(1u, 1u), offset);
return mix(mix(a, b, f.x), mix(c, d, f.x), f.y);
}
void main() {
vec4 color = texture(u_image0, v_texCoord);
// Luminance (Rec.709)
float luma = dot(color.rgb, vec3(0.2126, 0.7152, 0.0722));
// Grain UV (resolution-independent)
vec2 grainUV = v_texCoord * u_resolution / max(u_float1, 0.01);
uvec2 grainPixel = uvec2(grainUV);
float g;
vec3 grainRGB;
if (u_int0 == 1) {
// Grainy mode: pure hash noise (no interpolation = no banding)
g = toGaussian(grainPixel);
grainRGB = vec3(
toGaussian(grainPixel, 100u),
toGaussian(grainPixel, 200u),
toGaussian(grainPixel, 300u)
);
} else {
// Smooth mode: interpolated with quintic curve
g = smoothNoise(grainUV);
grainRGB = vec3(
smoothNoise(grainUV, 100u),
smoothNoise(grainUV, 200u),
smoothNoise(grainUV, 300u)
);
}
// Luminance weighting (less grain in highlights)
float lumWeight = mix(1.0, 1.0 - luma, clamp(u_float3, 0.0, 1.0));
// Strength
float strength = u_float0 * 0.15;
// Color vs monochrome grain
vec3 grainColor = mix(vec3(g), grainRGB, clamp(u_float2, 0.0, 1.0));
color.rgb += grainColor * strength * lumWeight;
fragColor0 = vec4(clamp(color.rgb, 0.0, 1.0), color.a);
}


@@ -0,0 +1,133 @@
#version 300 es
precision mediump float;
uniform sampler2D u_image0;
uniform vec2 u_resolution;
uniform int u_int0; // Blend mode
uniform int u_int1; // Color tint
uniform float u_float0; // Intensity
uniform float u_float1; // Radius
uniform float u_float2; // Threshold
in vec2 v_texCoord;
out vec4 fragColor;
const int BLEND_ADD = 0;
const int BLEND_SCREEN = 1;
const int BLEND_SOFT = 2;
const int BLEND_OVERLAY = 3;
const int BLEND_LIGHTEN = 4;
const float GOLDEN_ANGLE = 2.39996323;
const int MAX_SAMPLES = 48;
const vec3 LUMA = vec3(0.299, 0.587, 0.114);
float hash(vec2 p) {
p = fract(p * vec2(123.34, 456.21));
p += dot(p, p + 45.32);
return fract(p.x * p.y);
}
vec3 hexToRgb(int h) {
return vec3(
float((h >> 16) & 255),
float((h >> 8) & 255),
float(h & 255)
) * (1.0 / 255.0);
}
vec3 blend(vec3 base, vec3 glow, int mode) {
if (mode == BLEND_SCREEN) {
return 1.0 - (1.0 - base) * (1.0 - glow);
}
if (mode == BLEND_SOFT) {
return mix(
base - (1.0 - 2.0 * glow) * base * (1.0 - base),
base + (2.0 * glow - 1.0) * (sqrt(base) - base),
step(0.5, glow)
);
}
if (mode == BLEND_OVERLAY) {
return mix(
2.0 * base * glow,
1.0 - 2.0 * (1.0 - base) * (1.0 - glow),
step(0.5, base)
);
}
if (mode == BLEND_LIGHTEN) {
return max(base, glow);
}
return base + glow;
}
void main() {
vec4 original = texture(u_image0, v_texCoord);
float intensity = u_float0 * 0.05;
float radius = u_float1 * u_float1 * 0.012;
if (intensity < 0.001 || radius < 0.1) {
fragColor = original;
return;
}
float threshold = 1.0 - u_float2 * 0.01;
float t0 = threshold - 0.15;
float t1 = threshold + 0.15;
vec2 texelSize = 1.0 / u_resolution;
float radius2 = radius * radius;
float sampleScale = clamp(radius * 0.75, 0.35, 1.0);
int samples = int(float(MAX_SAMPLES) * sampleScale);
float noise = hash(gl_FragCoord.xy);
float angleOffset = noise * GOLDEN_ANGLE;
float radiusJitter = 0.85 + noise * 0.3;
float ca = cos(GOLDEN_ANGLE);
float sa = sin(GOLDEN_ANGLE);
vec2 dir = vec2(cos(angleOffset), sin(angleOffset));
vec3 glow = vec3(0.0);
float totalWeight = 0.0;
// Center tap
float centerMask = smoothstep(t0, t1, dot(original.rgb, LUMA));
glow += original.rgb * centerMask * 2.0;
totalWeight += 2.0;
for (int i = 1; i < MAX_SAMPLES; i++) {
if (i >= samples) break;
float fi = float(i);
float dist = sqrt(fi / float(samples)) * radius * radiusJitter;
vec2 offset = dir * dist * texelSize;
vec3 c = texture(u_image0, v_texCoord + offset).rgb;
float mask = smoothstep(t0, t1, dot(c, LUMA));
float w = 1.0 - (dist * dist) / (radius2 * 1.5);
w = max(w, 0.0);
w *= w;
glow += c * mask * w;
totalWeight += w;
dir = vec2(
dir.x * ca - dir.y * sa,
dir.x * sa + dir.y * ca
);
}
glow *= intensity / max(totalWeight, 0.001);
if (u_int1 > 0) {
glow *= hexToRgb(u_int1);
}
vec3 result = blend(original.rgb, glow, u_int0);
result += (noise - 0.5) * (1.0 / 255.0);
fragColor = vec4(clamp(result, 0.0, 1.0), original.a);
}


@@ -0,0 +1,222 @@
#version 300 es
precision highp float;
uniform sampler2D u_image0;
uniform int u_int0; // Mode: 0=Master, 1=Reds, 2=Yellows, 3=Greens, 4=Cyans, 5=Blues, 6=Magentas, 7=Colorize
uniform int u_int1; // Color Space: 0=HSL, 1=HSB/HSV
uniform float u_float0; // Hue (-180 to 180)
uniform float u_float1; // Saturation (-100 to 100)
uniform float u_float2; // Lightness/Brightness (-100 to 100)
uniform float u_float3; // Overlap (0 to 100) - feathering between adjacent color ranges
in vec2 v_texCoord;
out vec4 fragColor;
// Color range modes
const int MODE_MASTER = 0;
const int MODE_RED = 1;
const int MODE_YELLOW = 2;
const int MODE_GREEN = 3;
const int MODE_CYAN = 4;
const int MODE_BLUE = 5;
const int MODE_MAGENTA = 6;
const int MODE_COLORIZE = 7;
// Color space modes
const int COLORSPACE_HSL = 0;
const int COLORSPACE_HSB = 1;
const float EPSILON = 0.0001;
//=============================================================================
// RGB <-> HSL Conversions
//=============================================================================
vec3 rgb2hsl(vec3 c) {
float maxC = max(max(c.r, c.g), c.b);
float minC = min(min(c.r, c.g), c.b);
float delta = maxC - minC;
float h = 0.0;
float s = 0.0;
float l = (maxC + minC) * 0.5;
if (delta > EPSILON) {
s = l < 0.5
? delta / (maxC + minC)
: delta / (2.0 - maxC - minC);
if (maxC == c.r) {
h = (c.g - c.b) / delta + (c.g < c.b ? 6.0 : 0.0);
} else if (maxC == c.g) {
h = (c.b - c.r) / delta + 2.0;
} else {
h = (c.r - c.g) / delta + 4.0;
}
h /= 6.0;
}
return vec3(h, s, l);
}
float hue2rgb(float p, float q, float t) {
t = fract(t);
if (t < 1.0/6.0) return p + (q - p) * 6.0 * t;
if (t < 0.5) return q;
if (t < 2.0/3.0) return p + (q - p) * (2.0/3.0 - t) * 6.0;
return p;
}
vec3 hsl2rgb(vec3 hsl) {
if (hsl.y < EPSILON) return vec3(hsl.z);
float q = hsl.z < 0.5
? hsl.z * (1.0 + hsl.y)
: hsl.z + hsl.y - hsl.z * hsl.y;
float p = 2.0 * hsl.z - q;
return vec3(
hue2rgb(p, q, hsl.x + 1.0/3.0),
hue2rgb(p, q, hsl.x),
hue2rgb(p, q, hsl.x - 1.0/3.0)
);
}
vec3 rgb2hsb(vec3 c) {
float maxC = max(max(c.r, c.g), c.b);
float minC = min(min(c.r, c.g), c.b);
float delta = maxC - minC;
float h = 0.0;
float s = (maxC > EPSILON) ? delta / maxC : 0.0;
float b = maxC;
if (delta > EPSILON) {
if (maxC == c.r) {
h = (c.g - c.b) / delta + (c.g < c.b ? 6.0 : 0.0);
} else if (maxC == c.g) {
h = (c.b - c.r) / delta + 2.0;
} else {
h = (c.r - c.g) / delta + 4.0;
}
h /= 6.0;
}
return vec3(h, s, b);
}
vec3 hsb2rgb(vec3 hsb) {
vec3 rgb = clamp(abs(mod(hsb.x * 6.0 + vec3(0.0, 4.0, 2.0), 6.0) - 3.0) - 1.0, 0.0, 1.0);
return hsb.z * mix(vec3(1.0), rgb, hsb.y);
}
//=============================================================================
// Color Range Weight Calculation
//=============================================================================
float hueDistance(float a, float b) {
float d = abs(a - b);
return min(d, 1.0 - d);
}
float getHueWeight(float hue, float center, float overlap) {
float baseWidth = 1.0 / 6.0;
float feather = baseWidth * overlap;
float d = hueDistance(hue, center);
float inner = baseWidth * 0.5;
float outer = inner + feather;
return 1.0 - smoothstep(inner, outer, d);
}
float getModeWeight(float hue, int mode, float overlap) {
if (mode == MODE_MASTER || mode == MODE_COLORIZE) return 1.0;
if (mode == MODE_RED) {
return max(
getHueWeight(hue, 0.0, overlap),
getHueWeight(hue, 1.0, overlap)
);
}
float center = float(mode - 1) / 6.0;
return getHueWeight(hue, center, overlap);
}
//=============================================================================
// Adjustment Functions
//=============================================================================
float adjustLightness(float l, float amount) {
return amount > 0.0
? l + (1.0 - l) * amount
: l + l * amount;
}
float adjustBrightness(float b, float amount) {
return clamp(b + amount, 0.0, 1.0);
}
float adjustSaturation(float s, float amount) {
return amount > 0.0
? s + (1.0 - s) * amount
: s + s * amount;
}
vec3 colorize(vec3 rgb, float hue, float sat, float light) {
float lum = dot(rgb, vec3(0.299, 0.587, 0.114));
float l = adjustLightness(lum, light);
vec3 hsl = vec3(fract(hue), clamp(sat, 0.0, 1.0), clamp(l, 0.0, 1.0));
return hsl2rgb(hsl);
}
//=============================================================================
// Main
//=============================================================================
void main() {
vec4 original = texture(u_image0, v_texCoord);
float hueShift = u_float0 / 360.0; // -180..180 -> -0.5..0.5
float satAmount = u_float1 / 100.0; // -100..100 -> -1..1
float lightAmount= u_float2 / 100.0; // -100..100 -> -1..1
float overlap = u_float3 / 100.0; // 0..100 -> 0..1
vec3 result;
if (u_int0 == MODE_COLORIZE) {
result = colorize(original.rgb, hueShift, satAmount, lightAmount);
fragColor = vec4(result, original.a);
return;
}
vec3 hsx = (u_int1 == COLORSPACE_HSL)
? rgb2hsl(original.rgb)
: rgb2hsb(original.rgb);
float weight = getModeWeight(hsx.x, u_int0, overlap);
if (u_int0 != MODE_MASTER && hsx.y < EPSILON) {
weight = 0.0;
}
if (weight > EPSILON) {
float h = fract(hsx.x + hueShift * weight);
float s = clamp(adjustSaturation(hsx.y, satAmount * weight), 0.0, 1.0);
float v = (u_int1 == COLORSPACE_HSL)
? clamp(adjustLightness(hsx.z, lightAmount * weight), 0.0, 1.0)
: clamp(adjustBrightness(hsx.z, lightAmount * weight), 0.0, 1.0);
vec3 adjusted = vec3(h, s, v);
result = (u_int1 == COLORSPACE_HSL)
? hsl2rgb(adjusted)
: hsb2rgb(adjusted);
} else {
result = original.rgb;
}
fragColor = vec4(result, original.a);
}


@@ -0,0 +1,111 @@
#version 300 es
#pragma passes 2
precision highp float;
// Blur type constants
const int BLUR_GAUSSIAN = 0;
const int BLUR_BOX = 1;
const int BLUR_RADIAL = 2;
// Radial blur config
const int RADIAL_SAMPLES = 12;
const float RADIAL_STRENGTH = 0.0003;
uniform sampler2D u_image0;
uniform vec2 u_resolution;
uniform int u_int0; // Blur type (BLUR_GAUSSIAN, BLUR_BOX, BLUR_RADIAL)
uniform float u_float0; // Blur radius/amount
uniform int u_pass; // Pass index (0 = horizontal, 1 = vertical)
in vec2 v_texCoord;
layout(location = 0) out vec4 fragColor0;
float gaussian(float x, float sigma) {
return exp(-(x * x) / (2.0 * sigma * sigma));
}
void main() {
vec2 texelSize = 1.0 / u_resolution;
float radius = max(u_float0, 0.0);
// Radial (angular) blur - single pass, doesn't use separable
if (u_int0 == BLUR_RADIAL) {
// Only execute on first pass
if (u_pass > 0) {
fragColor0 = texture(u_image0, v_texCoord);
return;
}
vec2 center = vec2(0.5);
vec2 dir = v_texCoord - center;
float dist = length(dir);
if (dist < 1e-4) {
fragColor0 = texture(u_image0, v_texCoord);
return;
}
vec4 sum = vec4(0.0);
float totalWeight = 0.0;
float angleStep = radius * RADIAL_STRENGTH;
dir /= dist;
float cosStep = cos(angleStep);
float sinStep = sin(angleStep);
float negAngle = -float(RADIAL_SAMPLES) * angleStep;
vec2 rotDir = vec2(
dir.x * cos(negAngle) - dir.y * sin(negAngle),
dir.x * sin(negAngle) + dir.y * cos(negAngle)
);
for (int i = -RADIAL_SAMPLES; i <= RADIAL_SAMPLES; i++) {
vec2 uv = center + rotDir * dist;
float w = 1.0 - abs(float(i)) / float(RADIAL_SAMPLES);
sum += texture(u_image0, uv) * w;
totalWeight += w;
rotDir = vec2(
rotDir.x * cosStep - rotDir.y * sinStep,
rotDir.x * sinStep + rotDir.y * cosStep
);
}
fragColor0 = sum / max(totalWeight, 0.001);
return;
}
// Separable Gaussian / Box blur
int samples = int(ceil(radius));
if (samples == 0) {
fragColor0 = texture(u_image0, v_texCoord);
return;
}
// Direction: pass 0 = horizontal, pass 1 = vertical
vec2 dir = (u_pass == 0) ? vec2(1.0, 0.0) : vec2(0.0, 1.0);
vec4 color = vec4(0.0);
float totalWeight = 0.0;
float sigma = radius / 2.0;
for (int i = -samples; i <= samples; i++) {
vec2 offset = dir * float(i) * texelSize;
vec4 sample_color = texture(u_image0, v_texCoord + offset);
float weight;
if (u_int0 == BLUR_GAUSSIAN) {
weight = gaussian(float(i), sigma);
} else {
// BLUR_BOX
weight = 1.0;
}
color += sample_color * weight;
totalWeight += weight;
}
fragColor0 = color / totalWeight;
}


@@ -0,0 +1,19 @@
#version 300 es
precision highp float;
uniform sampler2D u_image0;
in vec2 v_texCoord;
layout(location = 0) out vec4 fragColor0;
layout(location = 1) out vec4 fragColor1;
layout(location = 2) out vec4 fragColor2;
layout(location = 3) out vec4 fragColor3;
void main() {
vec4 color = texture(u_image0, v_texCoord);
// Output each channel as grayscale to separate render targets
fragColor0 = vec4(vec3(color.r), 1.0); // Red channel
fragColor1 = vec4(vec3(color.g), 1.0); // Green channel
fragColor2 = vec4(vec3(color.b), 1.0); // Blue channel
fragColor3 = vec4(vec3(color.a), 1.0); // Alpha channel
}


@@ -0,0 +1,71 @@
#version 300 es
precision highp float;
// Levels Adjustment
// u_int0: channel (0=RGB, 1=R, 2=G, 3=B) default: 0
// u_float0: input black (0-255) default: 0
// u_float1: input white (0-255) default: 255
// u_float2: gamma (0.01-9.99) default: 1.0
// u_float3: output black (0-255) default: 0
// u_float4: output white (0-255) default: 255
uniform sampler2D u_image0;
uniform int u_int0;
uniform float u_float0;
uniform float u_float1;
uniform float u_float2;
uniform float u_float3;
uniform float u_float4;
in vec2 v_texCoord;
out vec4 fragColor;
vec3 applyLevels(vec3 color, float inBlack, float inWhite, float gamma, float outBlack, float outWhite) {
float inRange = max(inWhite - inBlack, 0.0001);
vec3 result = clamp((color - inBlack) / inRange, 0.0, 1.0);
result = pow(result, vec3(1.0 / gamma));
result = mix(vec3(outBlack), vec3(outWhite), result);
return result;
}
float applySingleChannel(float value, float inBlack, float inWhite, float gamma, float outBlack, float outWhite) {
float inRange = max(inWhite - inBlack, 0.0001);
float result = clamp((value - inBlack) / inRange, 0.0, 1.0);
result = pow(result, 1.0 / gamma);
result = mix(outBlack, outWhite, result);
return result;
}
void main() {
vec4 texColor = texture(u_image0, v_texCoord);
vec3 color = texColor.rgb;
float inBlack = u_float0 / 255.0;
float inWhite = u_float1 / 255.0;
float gamma = u_float2;
float outBlack = u_float3 / 255.0;
float outWhite = u_float4 / 255.0;
vec3 result;
if (u_int0 == 0) {
result = applyLevels(color, inBlack, inWhite, gamma, outBlack, outWhite);
}
else if (u_int0 == 1) {
result = color;
result.r = applySingleChannel(color.r, inBlack, inWhite, gamma, outBlack, outWhite);
}
else if (u_int0 == 2) {
result = color;
result.g = applySingleChannel(color.g, inBlack, inWhite, gamma, outBlack, outWhite);
}
else if (u_int0 == 3) {
result = color;
result.b = applySingleChannel(color.b, inBlack, inWhite, gamma, outBlack, outWhite);
}
else {
result = color;
}
fragColor = vec4(result, texColor.a);
}


@@ -0,0 +1,28 @@
# GLSL Shader Sources
This folder contains the GLSL fragment shaders extracted from blueprint JSON files for easier editing and version control.
## File Naming Convention
`{Blueprint_Name}_{node_id}.frag`
- **Blueprint_Name**: The JSON filename with spaces/special chars replaced by underscores
- **node_id**: The GLSLShader node ID within the subgraph
## Usage
```bash
# Extract shaders from blueprint JSONs to this folder
python update_blueprints.py extract
# Patch edited shaders back into blueprint JSONs
python update_blueprints.py patch
```
## Workflow
1. Run `extract` to pull current shaders from JSONs
2. Edit `.frag` files
3. Run `patch` to update the blueprint JSONs
4. Test
5. Commit both `.frag` files and updated JSONs


@@ -0,0 +1,28 @@
#version 300 es
precision highp float;
uniform sampler2D u_image0;
uniform vec2 u_resolution;
uniform float u_float0; // strength [0.0–2.0] typical: 0.3–1.0
in vec2 v_texCoord;
layout(location = 0) out vec4 fragColor0;
void main() {
vec2 texel = 1.0 / u_resolution;
// Sample center and neighbors
vec4 center = texture(u_image0, v_texCoord);
vec4 top = texture(u_image0, v_texCoord + vec2( 0.0, -texel.y));
vec4 bottom = texture(u_image0, v_texCoord + vec2( 0.0, texel.y));
vec4 left = texture(u_image0, v_texCoord + vec2(-texel.x, 0.0));
vec4 right = texture(u_image0, v_texCoord + vec2( texel.x, 0.0));
// Edge enhancement (Laplacian)
vec4 edges = center * 4.0 - top - bottom - left - right;
// Add edges back scaled by strength
vec4 sharpened = center + edges * u_float0;
fragColor0 = vec4(clamp(sharpened.rgb, 0.0, 1.0), center.a);
}


@@ -0,0 +1,61 @@
#version 300 es
precision highp float;
uniform sampler2D u_image0;
uniform vec2 u_resolution;
uniform float u_float0; // amount [0.0 - 3.0] typical: 0.5-1.5
uniform float u_float1; // radius [0.5 - 10.0] blur radius in pixels
uniform float u_float2; // threshold [0.0 - 0.1] min difference to sharpen
in vec2 v_texCoord;
layout(location = 0) out vec4 fragColor0;
float gaussian(float x, float sigma) {
return exp(-(x * x) / (2.0 * sigma * sigma));
}
float getLuminance(vec3 color) {
return dot(color, vec3(0.2126, 0.7152, 0.0722));
}
void main() {
vec2 texel = 1.0 / u_resolution;
float radius = max(u_float1, 0.5);
float amount = u_float0;
float threshold = u_float2;
vec4 original = texture(u_image0, v_texCoord);
// Gaussian blur for the "unsharp" mask
int samples = int(ceil(radius));
float sigma = radius / 2.0;
vec4 blurred = vec4(0.0);
float totalWeight = 0.0;
for (int x = -samples; x <= samples; x++) {
for (int y = -samples; y <= samples; y++) {
vec2 offset = vec2(float(x), float(y)) * texel;
vec4 sample_color = texture(u_image0, v_texCoord + offset);
float dist = length(vec2(float(x), float(y)));
float weight = gaussian(dist, sigma);
blurred += sample_color * weight;
totalWeight += weight;
}
}
blurred /= totalWeight;
// Unsharp mask = original - blurred
vec3 mask = original.rgb - blurred.rgb;
// Luminance-based threshold with smooth falloff
float lumaDelta = abs(getLuminance(original.rgb) - getLuminance(blurred.rgb));
float thresholdScale = smoothstep(0.0, threshold, lumaDelta);
mask *= thresholdScale;
// Sharpen: original + mask * amount
vec3 sharpened = original.rgb + mask * amount;
fragColor0 = vec4(clamp(sharpened, 0.0, 1.0), original.a);
}


@@ -0,0 +1,159 @@
#!/usr/bin/env python3
"""
Shader Blueprint Updater

Syncs GLSL shader files between this folder and blueprint JSON files.

File naming convention:
    {Blueprint Name}_{node_id}.frag

Usage:
    python update_blueprints.py extract  # Extract shaders from JSONs to here
    python update_blueprints.py patch    # Patch shaders back into JSONs
    python update_blueprints.py          # Same as patch (default)
"""
import json
import logging
import sys
import re
from pathlib import Path

logging.basicConfig(level=logging.INFO, format='%(message)s')
logger = logging.getLogger(__name__)

GLSL_DIR = Path(__file__).parent
BLUEPRINTS_DIR = GLSL_DIR.parent


def get_blueprint_files():
    """Get all blueprint JSON files."""
    return sorted(BLUEPRINTS_DIR.glob("*.json"))


def sanitize_filename(name):
    """Convert blueprint name to safe filename."""
    return re.sub(r'[^\w\-]', '_', name)


def extract_shaders():
    """Extract all shaders from blueprint JSONs to this folder."""
    extracted = 0
    for json_path in get_blueprint_files():
        blueprint_name = json_path.stem
        try:
            with open(json_path, 'r') as f:
                data = json.load(f)
        except (json.JSONDecodeError, IOError) as e:
            logger.warning("Skipping %s: %s", json_path.name, e)
            continue
        # Find GLSLShader nodes in subgraphs
        for subgraph in data.get('definitions', {}).get('subgraphs', []):
            for node in subgraph.get('nodes', []):
                if node.get('type') == 'GLSLShader':
                    node_id = node.get('id')
                    widgets = node.get('widgets_values', [])
                    # Find shader code (first string that looks like GLSL)
                    for widget in widgets:
                        if isinstance(widget, str) and widget.startswith('#version'):
                            safe_name = sanitize_filename(blueprint_name)
                            frag_name = f"{safe_name}_{node_id}.frag"
                            frag_path = GLSL_DIR / frag_name
                            with open(frag_path, 'w') as f:
                                f.write(widget)
                            logger.info("  Extracted: %s", frag_name)
                            extracted += 1
                            break
    logger.info("\nExtracted %d shader(s)", extracted)


def patch_shaders():
    """Patch shaders from this folder back into blueprint JSONs."""
    # Build lookup: blueprint_name -> [(node_id, shader_code), ...]
    shader_updates = {}
    for frag_path in sorted(GLSL_DIR.glob("*.frag")):
        # Parse filename: {blueprint_name}_{node_id}.frag
        parts = frag_path.stem.rsplit('_', 1)
        if len(parts) != 2:
            logger.warning("Skipping %s: invalid filename format", frag_path.name)
            continue
        blueprint_name, node_id_str = parts
        try:
            node_id = int(node_id_str)
        except ValueError:
            logger.warning("Skipping %s: invalid node_id", frag_path.name)
            continue
        with open(frag_path, 'r') as f:
            shader_code = f.read()
        if blueprint_name not in shader_updates:
            shader_updates[blueprint_name] = []
        shader_updates[blueprint_name].append((node_id, shader_code))
    # Apply updates to JSON files
    patched = 0
    for json_path in get_blueprint_files():
        blueprint_name = sanitize_filename(json_path.stem)
        if blueprint_name not in shader_updates:
            continue
        try:
            with open(json_path, 'r') as f:
                data = json.load(f)
        except (json.JSONDecodeError, IOError) as e:
            logger.error("Error reading %s: %s", json_path.name, e)
            continue
        modified = False
        for node_id, shader_code in shader_updates[blueprint_name]:
            # Find the node and update
            for subgraph in data.get('definitions', {}).get('subgraphs', []):
                for node in subgraph.get('nodes', []):
                    if node.get('id') == node_id and node.get('type') == 'GLSLShader':
                        widgets = node.get('widgets_values', [])
                        if len(widgets) > 0 and widgets[0] != shader_code:
                            widgets[0] = shader_code
                            modified = True
                            logger.info("  Patched: %s (node %d)", json_path.name, node_id)
                            patched += 1
        if modified:
            with open(json_path, 'w') as f:
                json.dump(data, f)
    if patched == 0:
        logger.info("No changes to apply.")
    else:
        logger.info("\nPatched %d shader(s)", patched)


def main():
    if len(sys.argv) < 2:
        command = "patch"
    else:
        command = sys.argv[1].lower()
    if command == "extract":
        logger.info("Extracting shaders from blueprints...")
        extract_shaders()
    elif command in ("patch", "update", "apply"):
        logger.info("Patching shaders into blueprints...")
        patch_shaders()
    else:
        logger.info(__doc__)
        sys.exit(1)


if __name__ == "__main__":
    main()

File diffs suppressed for 9 files because one or more lines are too long

blueprints/Glow.json Normal file

4 file diffs suppressed because one or more lines are too long


@@ -0,0 +1 @@
{"revision": 0, "last_node_id": 29, "last_link_id": 0, "nodes": [{"id": 29, "type": "4c9d6ea4-b912-40e5-8766-6793a9758c53", "pos": [1970, -230], "size": [180, 86], "flags": {}, "order": 5, "mode": 0, "inputs": [{"label": "image", "localized_name": "images.image0", "name": "images.image0", "type": "IMAGE", "link": null}], "outputs": [{"label": "R", "localized_name": "IMAGE0", "name": "IMAGE0", "type": "IMAGE", "links": []}, {"label": "G", "localized_name": "IMAGE1", "name": "IMAGE1", "type": "IMAGE", "links": []}, {"label": "B", "localized_name": "IMAGE2", "name": "IMAGE2", "type": "IMAGE", "links": []}, {"label": "A", "localized_name": "IMAGE3", "name": "IMAGE3", "type": "IMAGE", "links": []}], "title": "Image Channels", "properties": {"proxyWidgets": []}, "widgets_values": []}], "links": [], "version": 0.4, "definitions": {"subgraphs": [{"id": "4c9d6ea4-b912-40e5-8766-6793a9758c53", "version": 1, "state": {"lastGroupId": 0, "lastNodeId": 28, "lastLinkId": 39, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "Image Channels", "inputNode": {"id": -10, "bounding": [1820, -185, 120, 60]}, "outputNode": {"id": -20, "bounding": [2460, -215, 120, 120]}, "inputs": [{"id": "3522932b-2d86-4a1f-a02a-cb29f3a9d7fe", "name": "images.image0", "type": "IMAGE", "linkIds": [39], "localized_name": "images.image0", "label": "image", "pos": [1920, -165]}], "outputs": [{"id": "605cb9c3-b065-4d9b-81d2-3ec331889b2b", "name": "IMAGE0", "type": "IMAGE", "linkIds": [26], "localized_name": "IMAGE0", "label": "R", "pos": [2480, -195]}, {"id": "fb44a77e-0522-43e9-9527-82e7465b3596", "name": "IMAGE1", "type": "IMAGE", "linkIds": [27], "localized_name": "IMAGE1", "label": "G", "pos": [2480, -175]}, {"id": "81460ee6-0131-402a-874f-6bf3001fc4ff", "name": "IMAGE2", "type": "IMAGE", "linkIds": [28], "localized_name": "IMAGE2", "label": "B", "pos": [2480, -155]}, {"id": "ae690246-80d4-4951-b1d9-9306d8a77417", "name": "IMAGE3", "type": "IMAGE", "linkIds": [29], "localized_name": "IMAGE3", "label": "A", "pos": [2480, -135]}], "widgets": [], "nodes": [{"id": 23, "type": "GLSLShader", "pos": [2000, -330], "size": [400, 172], "flags": {}, "order": 0, "mode": 0, "inputs": [{"label": "image", "localized_name": "images.image0", "name": "images.image0", "type": "IMAGE", "link": 39}, {"localized_name": "fragment_shader", "name": "fragment_shader", "type": "STRING", "widget": {"name": "fragment_shader"}, "link": null}, {"localized_name": "size_mode", "name": "size_mode", "type": "COMFY_DYNAMICCOMBO_V3", "widget": {"name": "size_mode"}, "link": null}, {"label": "image1", "localized_name": "images.image1", "name": "images.image1", "shape": 7, "type": "IMAGE", "link": null}], "outputs": [{"label": "R", "localized_name": "IMAGE0", "name": "IMAGE0", "type": "IMAGE", "links": [26]}, {"label": "G", "localized_name": "IMAGE1", "name": "IMAGE1", "type": "IMAGE", "links": [27]}, {"label": "B", "localized_name": "IMAGE2", "name": "IMAGE2", "type": "IMAGE", "links": [28]}, {"label": "A", "localized_name": "IMAGE3", "name": "IMAGE3", "type": "IMAGE", "links": [29]}], "properties": {"Node name for S&R": "GLSLShader"}, "widgets_values": ["#version 300 es\nprecision highp float;\n\nuniform sampler2D u_image0;\n\nin vec2 v_texCoord;\nlayout(location = 0) out vec4 fragColor0;\nlayout(location = 1) out vec4 fragColor1;\nlayout(location = 2) out vec4 fragColor2;\nlayout(location = 3) out vec4 fragColor3;\n\nvoid main() {\n vec4 color = texture(u_image0, v_texCoord);\n // Output each channel as grayscale to separate render targets\n fragColor0 = 
vec4(vec3(color.r), 1.0); // Red channel\n fragColor1 = vec4(vec3(color.g), 1.0); // Green channel\n fragColor2 = vec4(vec3(color.b), 1.0); // Blue channel\n fragColor3 = vec4(vec3(color.a), 1.0); // Alpha channel\n}\n", "from_input"]}], "groups": [], "links": [{"id": 39, "origin_id": -10, "origin_slot": 0, "target_id": 23, "target_slot": 0, "type": "IMAGE"}, {"id": 26, "origin_id": 23, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "IMAGE"}, {"id": 27, "origin_id": 23, "origin_slot": 1, "target_id": -20, "target_slot": 1, "type": "IMAGE"}, {"id": 28, "origin_id": 23, "origin_slot": 2, "target_id": -20, "target_slot": 2, "type": "IMAGE"}, {"id": 29, "origin_id": 23, "origin_slot": 3, "target_id": -20, "target_slot": 3, "type": "IMAGE"}], "extra": {"workflowRendererVersion": "LG"}, "category": "Image Tools/Color adjust"}]}}
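The shader embedded in this blueprint fans the input out to four render targets, one per channel. A rough tensor equivalent of what the GLSL computes (an illustration only, not code from this PR):

import torch

def split_channels(rgba: torch.Tensor):
    """rgba: [H, W, 4] in 0..1. Mirrors the shader: each output is
    vec4(vec3(channel), 1.0), i.e. the channel as grayscale RGB."""
    outs = []
    for c in range(4):
        gray = rgba[..., c:c + 1].expand(-1, -1, 3)
        alpha = torch.ones_like(rgba[..., :1])
        outs.append(torch.cat([gray, alpha], dim=-1))
    return outs  # corresponds to IMAGE0..IMAGE3 (R, G, B, A)

r, g, b, a = split_channels(torch.rand(64, 64, 4))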

12 file diffs suppressed because one or more lines are too long


@@ -0,0 +1 @@
{"revision": 0, "last_node_id": 15, "last_link_id": 0, "nodes": [{"id": 15, "type": "24d8bbfd-39d4-4774-bff0-3de40cc7a471", "pos": [-1490, 2040], "size": [400, 260], "flags": {}, "order": 0, "mode": 0, "inputs": [{"name": "prompt", "type": "STRING", "widget": {"name": "prompt"}, "link": null}, {"label": "reference images", "name": "images", "type": "IMAGE", "link": null}], "outputs": [{"name": "STRING", "type": "STRING", "links": null}], "title": "Prompt Enhance", "properties": {"proxyWidgets": [["-1", "prompt"]], "cnr_id": "comfy-core", "ver": "0.14.1"}, "widgets_values": [""]}], "links": [], "version": 0.4, "definitions": {"subgraphs": [{"id": "24d8bbfd-39d4-4774-bff0-3de40cc7a471", "version": 1, "state": {"lastGroupId": 0, "lastNodeId": 15, "lastLinkId": 14, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "Prompt Enhance", "inputNode": {"id": -10, "bounding": [-2170, 2110, 138.876953125, 80]}, "outputNode": {"id": -20, "bounding": [-640, 2110, 120, 60]}, "inputs": [{"id": "aeab7216-00e0-4528-a09b-bba50845c5a6", "name": "prompt", "type": "STRING", "linkIds": [11], "pos": [-2051.123046875, 2130]}, {"id": "7b73fd36-aa31-4771-9066-f6c83879994b", "name": "images", "type": "IMAGE", "linkIds": [14], "label": "reference images", "pos": [-2051.123046875, 2150]}], "outputs": [{"id": "c7b0d930-68a1-48d1-b496-0519e5837064", "name": "STRING", "type": "STRING", "linkIds": [13], "pos": [-620, 2130]}], "widgets": [], "nodes": [{"id": 11, "type": "GeminiNode", "pos": [-1560, 1990], "size": [470, 470], "flags": {}, "order": 0, "mode": 0, "inputs": [{"localized_name": "images", "name": "images", "shape": 7, "type": "IMAGE", "link": 14}, {"localized_name": "audio", "name": "audio", "shape": 7, "type": "AUDIO", "link": null}, {"localized_name": "video", "name": "video", "shape": 7, "type": "VIDEO", "link": null}, {"localized_name": "files", "name": "files", "shape": 7, "type": "GEMINI_INPUT_FILES", "link": null}, {"localized_name": "prompt", "name": "prompt", "type": "STRING", "widget": {"name": "prompt"}, "link": 11}, {"localized_name": "model", "name": "model", "type": "COMBO", "widget": {"name": "model"}, "link": null}, {"localized_name": "seed", "name": "seed", "type": "INT", "widget": {"name": "seed"}, "link": null}, {"localized_name": "system_prompt", "name": "system_prompt", "shape": 7, "type": "STRING", "widget": {"name": "system_prompt"}, "link": null}], "outputs": [{"localized_name": "STRING", "name": "STRING", "type": "STRING", "links": [13]}], "properties": {"cnr_id": "comfy-core", "ver": "0.14.1", "Node name for S&R": "GeminiNode"}, "widgets_values": ["", "gemini-3-pro-preview", 42, "randomize", "You are an expert in prompt writing.\nBased on the input, rewrite the user's input into a detailed prompt.\nincluding camera settings, lighting, composition, and style.\nReturn the prompt only"], "color": "#432", "bgcolor": "#653"}], "groups": [], "links": [{"id": 11, "origin_id": -10, "origin_slot": 0, "target_id": 11, "target_slot": 4, "type": "STRING"}, {"id": 13, "origin_id": 11, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "STRING"}, {"id": 14, "origin_id": -10, "origin_slot": 1, "target_id": 11, "target_slot": 0, "type": "IMAGE"}], "extra": {"workflowRendererVersion": "LG"}, "category": "Text generation/Prompt enhance"}]}, "extra": {}}

blueprints/Sharpen.json Normal file

@@ -0,0 +1 @@
{"revision": 0, "last_node_id": 25, "last_link_id": 0, "nodes": [{"id": 25, "type": "621ba4e2-22a8-482d-a369-023753198b7b", "pos": [4610, -790], "size": [230, 58], "flags": {}, "order": 4, "mode": 0, "inputs": [{"label": "image", "localized_name": "images.image0", "name": "images.image0", "type": "IMAGE", "link": null}], "outputs": [{"label": "IMAGE", "localized_name": "IMAGE0", "name": "IMAGE0", "type": "IMAGE", "links": []}], "title": "Sharpen", "properties": {"proxyWidgets": [["24", "value"]]}, "widgets_values": []}], "links": [], "version": 0.4, "definitions": {"subgraphs": [{"id": "621ba4e2-22a8-482d-a369-023753198b7b", "version": 1, "state": {"lastGroupId": 0, "lastNodeId": 24, "lastLinkId": 36, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "Sharpen", "inputNode": {"id": -10, "bounding": [4090, -825, 120, 60]}, "outputNode": {"id": -20, "bounding": [5150, -825, 120, 60]}, "inputs": [{"id": "37011fb7-14b7-4e0e-b1a0-6a02e8da1fd7", "name": "images.image0", "type": "IMAGE", "linkIds": [34], "localized_name": "images.image0", "label": "image", "pos": [4190, -805]}], "outputs": [{"id": "e9182b3f-635c-4cd4-a152-4b4be17ae4b9", "name": "IMAGE0", "type": "IMAGE", "linkIds": [35], "localized_name": "IMAGE0", "label": "IMAGE", "pos": [5170, -805]}], "widgets": [], "nodes": [{"id": 24, "type": "PrimitiveFloat", "pos": [4280, -1240], "size": [270, 58], "flags": {}, "order": 0, "mode": 0, "inputs": [{"label": "strength", "localized_name": "value", "name": "value", "type": "FLOAT", "widget": {"name": "value"}, "link": null}], "outputs": [{"localized_name": "FLOAT", "name": "FLOAT", "type": "FLOAT", "links": [36]}], "properties": {"Node name for S&R": "PrimitiveFloat", "min": 0, "max": 3, "precision": 2, "step": 0.05}, "widgets_values": [0.5]}, {"id": 23, "type": "GLSLShader", "pos": [4570, -1240], "size": [370, 192], "flags": {}, "order": 1, "mode": 0, "inputs": [{"label": "image0", "localized_name": "images.image0", "name": "images.image0", "type": "IMAGE", "link": 34}, {"label": "image1", "localized_name": "images.image1", "name": "images.image1", "shape": 7, "type": "IMAGE", "link": null}, {"label": "u_float0", "localized_name": "floats.u_float0", "name": "floats.u_float0", "shape": 7, "type": "FLOAT", "link": 36}, {"label": "u_float1", "localized_name": "floats.u_float1", "name": "floats.u_float1", "shape": 7, "type": "FLOAT", "link": null}, {"label": "u_int0", "localized_name": "ints.u_int0", "name": "ints.u_int0", "shape": 7, "type": "INT", "link": null}, {"localized_name": "fragment_shader", "name": "fragment_shader", "type": "STRING", "widget": {"name": "fragment_shader"}, "link": null}, {"localized_name": "size_mode", "name": "size_mode", "type": "COMFY_DYNAMICCOMBO_V3", "widget": {"name": "size_mode"}, "link": null}], "outputs": [{"localized_name": "IMAGE0", "name": "IMAGE0", "type": "IMAGE", "links": [35]}, {"localized_name": "IMAGE1", "name": "IMAGE1", "type": "IMAGE", "links": null}, {"localized_name": "IMAGE2", "name": "IMAGE2", "type": "IMAGE", "links": null}, {"localized_name": "IMAGE3", "name": "IMAGE3", "type": "IMAGE", "links": null}], "properties": {"Node name for S&R": "GLSLShader"}, "widgets_values": ["#version 300 es\nprecision highp float;\n\nuniform sampler2D u_image0;\nuniform vec2 u_resolution;\nuniform float u_float0; // strength [0.0 2.0] typical: 0.31.0\n\nin vec2 v_texCoord;\nlayout(location = 0) out vec4 fragColor0;\n\nvoid main() {\n vec2 texel = 1.0 / u_resolution;\n \n // Sample center and neighbors\n vec4 center = texture(u_image0, v_texCoord);\n vec4 
top = texture(u_image0, v_texCoord + vec2( 0.0, -texel.y));\n vec4 bottom = texture(u_image0, v_texCoord + vec2( 0.0, texel.y));\n vec4 left = texture(u_image0, v_texCoord + vec2(-texel.x, 0.0));\n vec4 right = texture(u_image0, v_texCoord + vec2( texel.x, 0.0));\n \n // Edge enhancement (Laplacian)\n vec4 edges = center * 4.0 - top - bottom - left - right;\n \n // Add edges back scaled by strength\n vec4 sharpened = center + edges * u_float0;\n \n fragColor0 = vec4(clamp(sharpened.rgb, 0.0, 1.0), center.a);\n}", "from_input"]}], "groups": [], "links": [{"id": 36, "origin_id": 24, "origin_slot": 0, "target_id": 23, "target_slot": 2, "type": "FLOAT"}, {"id": 34, "origin_id": -10, "origin_slot": 0, "target_id": 23, "target_slot": 0, "type": "IMAGE"}, {"id": 35, "origin_id": 23, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "IMAGE"}], "extra": {"workflowRendererVersion": "LG"}, "category": "Image Tools/Sharpen"}]}}
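The Sharpen blueprint's shader is a five-tap Laplacian: edges = 4*center - top - bottom - left - right, scaled by the strength uniform (u_float0, fed by the PrimitiveFloat node, default 0.5) and added back to the image. A standalone tensor sketch of the same filter (illustrative only; replicate padding at the borders is an assumption):

import torch
import torch.nn.functional as F

def sharpen(img: torch.Tensor, strength: float = 0.5) -> torch.Tensor:
    """img: [B, C, H, W] in 0..1; same Laplacian kernel as the shader."""
    k = torch.tensor([[0., -1., 0.],
                      [-1., 4., -1.],
                      [0., -1., 0.]], dtype=img.dtype, device=img.device)
    k = k.view(1, 1, 3, 3).repeat(img.shape[1], 1, 1, 1)
    edges = F.conv2d(F.pad(img, (1, 1, 1, 1), mode="replicate"),
                     k, groups=img.shape[1])
    # The shader clamps RGB and keeps the source alpha; clamping everything
    # keeps this sketch short.
    return (img + strength * edges).clamp(0.0, 1.0)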

7 file diffs suppressed because one or more lines are too long


@@ -0,0 +1 @@
{"revision": 0, "last_node_id": 13, "last_link_id": 0, "nodes": [{"id": 13, "type": "cf95b747-3e17-46cb-8097-cac60ff9b2e1", "pos": [1120, 330], "size": [240, 58], "flags": {}, "order": 3, "mode": 0, "inputs": [{"localized_name": "video", "name": "video", "type": "VIDEO", "link": null}, {"name": "model_name", "type": "COMBO", "widget": {"name": "model_name"}, "link": null}], "outputs": [{"localized_name": "VIDEO", "name": "VIDEO", "type": "VIDEO", "links": []}], "title": "Video Upscale(GAN x4)", "properties": {"proxyWidgets": [["-1", "model_name"]], "cnr_id": "comfy-core", "ver": "0.14.1"}, "widgets_values": ["RealESRGAN_x4plus.safetensors"]}], "links": [], "version": 0.4, "definitions": {"subgraphs": [{"id": "cf95b747-3e17-46cb-8097-cac60ff9b2e1", "version": 1, "state": {"lastGroupId": 0, "lastNodeId": 13, "lastLinkId": 19, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "Video Upscale(GAN x4)", "inputNode": {"id": -10, "bounding": [550, 460, 120, 80]}, "outputNode": {"id": -20, "bounding": [1490, 460, 120, 60]}, "inputs": [{"id": "666d633e-93e7-42dc-8d11-2b7b99b0f2a6", "name": "video", "type": "VIDEO", "linkIds": [10], "localized_name": "video", "pos": [650, 480]}, {"id": "2e23a087-caa8-4d65-99e6-662761aa905a", "name": "model_name", "type": "COMBO", "linkIds": [19], "pos": [650, 500]}], "outputs": [{"id": "0c1768ea-3ec2-412f-9af6-8e0fa36dae70", "name": "VIDEO", "type": "VIDEO", "linkIds": [15], "localized_name": "VIDEO", "pos": [1510, 480]}], "widgets": [], "nodes": [{"id": 2, "type": "ImageUpscaleWithModel", "pos": [1110, 450], "size": [320, 46], "flags": {}, "order": 1, "mode": 0, "inputs": [{"localized_name": "upscale_model", "name": "upscale_model", "type": "UPSCALE_MODEL", "link": 1}, {"localized_name": "image", "name": "image", "type": "IMAGE", "link": 14}], "outputs": [{"localized_name": "IMAGE", "name": "IMAGE", "type": "IMAGE", "links": [13]}], "properties": {"cnr_id": "comfy-core", "ver": "0.10.0", "Node name for S&R": "ImageUpscaleWithModel"}}, {"id": 11, "type": "CreateVideo", "pos": [1110, 550], "size": [320, 78], "flags": {}, "order": 3, "mode": 0, "inputs": [{"localized_name": "images", "name": "images", "type": "IMAGE", "link": 13}, {"localized_name": "audio", "name": "audio", "shape": 7, "type": "AUDIO", "link": 16}, {"localized_name": "fps", "name": "fps", "type": "FLOAT", "widget": {"name": "fps"}, "link": 12}], "outputs": [{"localized_name": "VIDEO", "name": "VIDEO", "type": "VIDEO", "links": [15]}], "properties": {"cnr_id": "comfy-core", "ver": "0.10.0", "Node name for S&R": "CreateVideo"}, "widgets_values": [30]}, {"id": 10, "type": "GetVideoComponents", "pos": [1110, 330], "size": [320, 70], "flags": {}, "order": 2, "mode": 0, "inputs": [{"localized_name": "video", "name": "video", "type": "VIDEO", "link": 10}], "outputs": [{"localized_name": "images", "name": "images", "type": "IMAGE", "links": [14]}, {"localized_name": "audio", "name": "audio", "type": "AUDIO", "links": [16]}, {"localized_name": "fps", "name": "fps", "type": "FLOAT", "links": [12]}], "properties": {"cnr_id": "comfy-core", "ver": "0.10.0", "Node name for S&R": "GetVideoComponents"}}, {"id": 1, "type": "UpscaleModelLoader", "pos": [750, 450], "size": [280, 60], "flags": {}, "order": 0, "mode": 0, "inputs": [{"localized_name": "model_name", "name": "model_name", "type": "COMBO", "widget": {"name": "model_name"}, "link": 19}], "outputs": [{"localized_name": "UPSCALE_MODEL", "name": "UPSCALE_MODEL", "type": "UPSCALE_MODEL", "links": [1]}], "properties": {"cnr_id": "comfy-core", "ver": 
"0.10.0", "Node name for S&R": "UpscaleModelLoader", "models": [{"name": "RealESRGAN_x4plus.safetensors", "url": "https://huggingface.co/Comfy-Org/Real-ESRGAN_repackaged/resolve/main/RealESRGAN_x4plus.safetensors", "directory": "upscale_models"}]}, "widgets_values": ["RealESRGAN_x4plus.safetensors"]}], "groups": [], "links": [{"id": 1, "origin_id": 1, "origin_slot": 0, "target_id": 2, "target_slot": 0, "type": "UPSCALE_MODEL"}, {"id": 14, "origin_id": 10, "origin_slot": 0, "target_id": 2, "target_slot": 1, "type": "IMAGE"}, {"id": 13, "origin_id": 2, "origin_slot": 0, "target_id": 11, "target_slot": 0, "type": "IMAGE"}, {"id": 16, "origin_id": 10, "origin_slot": 1, "target_id": 11, "target_slot": 1, "type": "AUDIO"}, {"id": 12, "origin_id": 10, "origin_slot": 2, "target_id": 11, "target_slot": 2, "type": "FLOAT"}, {"id": 10, "origin_id": -10, "origin_slot": 0, "target_id": 10, "target_slot": 0, "type": "VIDEO"}, {"id": 15, "origin_id": 11, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "VIDEO"}, {"id": 19, "origin_id": -10, "origin_slot": 1, "target_id": 1, "target_slot": 0, "type": "COMBO"}], "extra": {"workflowRendererVersion": "LG"}, "category": "Video generation and editing/Enhance video"}]}, "extra": {}}


@@ -1,13 +0,0 @@
import pickle

load = pickle.load

class Empty:
    pass

class Unpickler(pickle.Unpickler):
    def find_class(self, module, name):
        #TODO: safe unpickle
        if module.startswith("pytorch_lightning"):
            return Empty
        return super().find_class(module, name)


@@ -176,6 +176,8 @@ class InputTypeOptions(TypedDict):
"""COMBO type only. Specifies the configuration for a multi-select widget.
Available after ComfyUI frontend v1.13.4
https://github.com/Comfy-Org/ComfyUI_frontend/pull/2987"""
gradient_stops: NotRequired[list[list[float]]]
"""Gradient color stops for gradientslider display mode. Each stop is [offset, r, g, b] (``FLOAT``)."""
class HiddenInputTypeDict(TypedDict):
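As a usage illustration of the new field (the input name and color values are hypothetical; the [offset, r, g, b] stop layout and the gradientslider display mode are from this hunk, and wiring it through the "display" key is an assumption):

# Hypothetical node input definition using the new option.
"strength": ("FLOAT", {
    "default": 0.5,
    "min": 0.0,
    "max": 1.0,
    "display": "gradientslider",
    "gradient_stops": [
        [0.0, 0.0, 0.0, 0.0],  # offset 0.0 -> black
        [0.5, 1.0, 0.0, 0.0],  # offset 0.5 -> red
        [1.0, 1.0, 1.0, 1.0],  # offset 1.0 -> white
    ],
}),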


@@ -179,8 +179,8 @@ class LLMAdapter(nn.Module):
if source_attention_mask.ndim == 2:
source_attention_mask = source_attention_mask.unsqueeze(1).unsqueeze(1)
x = self.in_proj(self.embed(target_input_ids))
context = source_hidden_states
x = self.in_proj(self.embed(target_input_ids, out_dtype=context.dtype))
position_ids = torch.arange(x.shape[1], device=x.device).unsqueeze(0)
position_ids_context = torch.arange(context.shape[1], device=x.device).unsqueeze(0)
position_embeddings = self.rotary_emb(x, position_ids)


@@ -152,6 +152,7 @@ class Chroma(nn.Module):
transformer_options={},
attn_mask: Tensor = None,
) -> Tensor:
transformer_options = transformer_options.copy()
patches_replace = transformer_options.get("patches_replace", {})
# running on sequences img
@@ -228,6 +229,7 @@ class Chroma(nn.Module):
transformer_options["total_blocks"] = len(self.single_blocks)
transformer_options["block_type"] = "single"
transformer_options["img_slice"] = [txt.shape[1], img.shape[1]]
for i, block in enumerate(self.single_blocks):
transformer_options["block_index"] = i
if i not in self.skip_dit:


@@ -196,6 +196,9 @@ class DoubleStreamBlock(nn.Module):
else:
(img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec
transformer_patches = transformer_options.get("patches", {})
extra_options = transformer_options.copy()
# prepare image for attention
img_modulated = self.img_norm1(img)
img_modulated = apply_mod(img_modulated, (1 + img_mod1.scale), img_mod1.shift, modulation_dims_img)
@@ -224,6 +227,12 @@ class DoubleStreamBlock(nn.Module):
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
del q, k, v
if "attn1_output_patch" in transformer_patches:
extra_options["img_slice"] = [txt.shape[1], attn.shape[1]]
patch = transformer_patches["attn1_output_patch"]
for p in patch:
attn = p(attn, extra_options)
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
# calculate the img blocks
@@ -303,6 +312,9 @@ class SingleStreamBlock(nn.Module):
else:
mod = vec
transformer_patches = transformer_options.get("patches", {})
extra_options = transformer_options.copy()
qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim_first], dim=-1)
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
@@ -312,6 +324,12 @@ class SingleStreamBlock(nn.Module):
# compute attention
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
del q, k, v
if "attn1_output_patch" in transformer_patches:
patch = transformer_patches["attn1_output_patch"]
for p in patch:
attn = p(attn, extra_options)
# compute activation in mlp stream, cat again and run second linear layer
if self.yak_mlp:
mlp = self.mlp_act(mlp[..., self.mlp_hidden_dim_first // 2:]) * mlp[..., :self.mlp_hidden_dim_first // 2]
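These hunks copy transformer_options per block and record an img_slice so that attn1_output_patch callbacks can tell text tokens from image tokens in the fused attention output. A minimal sketch of such a patch (the callback body is hypothetical; the "patches" key, the p(attn, extra_options) call shape, and img_slice come from this diff):

import torch

def damp_img_attn(attn: torch.Tensor, extra_options: dict) -> torch.Tensor:
    """Example attn1_output_patch: scale down the image-token part of the
    attention output; img_slice[0] is the text length per these hunks."""
    txt_len = extra_options["img_slice"][0]
    attn = attn.clone()
    attn[:, txt_len:] *= 0.9
    return attn

# Assumed registration shape, mirroring how the blocks read it above:
transformer_options = {"patches": {"attn1_output_patch": [damp_img_attn]}}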


@@ -142,6 +142,7 @@ class Flux(nn.Module):
attn_mask: Tensor = None,
) -> Tensor:
transformer_options = transformer_options.copy()
patches = transformer_options.get("patches", {})
patches_replace = transformer_options.get("patches_replace", {})
if img.ndim != 3 or txt.ndim != 3:
@@ -231,6 +232,7 @@ class Flux(nn.Module):
transformer_options["total_blocks"] = len(self.single_blocks)
transformer_options["block_type"] = "single"
transformer_options["img_slice"] = [txt.shape[1], img.shape[1]]
for i, block in enumerate(self.single_blocks):
transformer_options["block_index"] = i
if ("single_block", i) in blocks_replace:


@@ -304,6 +304,7 @@ class HunyuanVideo(nn.Module):
control=None,
transformer_options={},
) -> Tensor:
transformer_options = transformer_options.copy()
patches_replace = transformer_options.get("patches_replace", {})
initial_shape = list(img.shape)
@@ -416,6 +417,7 @@ class HunyuanVideo(nn.Module):
transformer_options["total_blocks"] = len(self.single_blocks)
transformer_options["block_type"] = "single"
transformer_options["img_slice"] = [txt.shape[1], img.shape[1]]
for i, block in enumerate(self.single_blocks):
transformer_options["block_index"] = i
if ("single_block", i) in blocks_replace:


@@ -9,6 +9,7 @@ from comfy.ldm.lightricks.model import (
LTXVModel,
)
from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector
import comfy.ldm.common_dit
class CompressedTimestep:
@@ -450,6 +451,29 @@ class LTXAVModel(LTXVModel):
operations=self.operations,
)
self.audio_embeddings_connector = Embeddings1DConnector(
split_rope=True,
double_precision_rope=True,
dtype=dtype,
device=device,
operations=self.operations,
)
self.video_embeddings_connector = Embeddings1DConnector(
split_rope=True,
double_precision_rope=True,
dtype=dtype,
device=device,
operations=self.operations,
)
def preprocess_text_embeds(self, context):
if context.shape[-1] == self.caption_channels * 2:
return context
out_vid = self.video_embeddings_connector(context)[0]
out_audio = self.audio_embeddings_connector(context)[0]
return torch.concat((out_vid, out_audio), dim=-1)
def _init_transformer_blocks(self, device, dtype, **kwargs):
"""Initialize transformer blocks for LTXAV."""
self.transformer_blocks = nn.ModuleList(


@@ -234,7 +234,7 @@ class Embeddings1DConnector(nn.Module):
return indices
def precompute_freqs_cis(self, indices_grid, spacing="exp"):
def precompute_freqs_cis(self, indices_grid, spacing="exp", out_dtype=None):
dim = self.inner_dim
n_elem = 2 # 2 because of cos and sin
freqs = self.precompute_freqs(indices_grid, spacing)
@@ -247,7 +247,7 @@ class Embeddings1DConnector(nn.Module):
)
else:
cos_freq, sin_freq = interleaved_freqs_cis(freqs, dim % n_elem)
return cos_freq.to(self.dtype), sin_freq.to(self.dtype), self.split_rope
return cos_freq.to(dtype=out_dtype), sin_freq.to(dtype=out_dtype), self.split_rope
def forward(
self,
@@ -288,7 +288,7 @@ class Embeddings1DConnector(nn.Module):
hidden_states.shape[1], dtype=torch.float32, device=hidden_states.device
)
indices_grid = indices_grid[None, None, :]
freqs_cis = self.precompute_freqs_cis(indices_grid)
freqs_cis = self.precompute_freqs_cis(indices_grid, out_dtype=hidden_states.dtype)
# 2. Blocks
for block_idx, block in enumerate(self.transformer_1d_blocks):
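The connector fix threads out_dtype through precompute_freqs_cis so the returned cos/sin tables match the activations rather than the module's parameter dtype. Schematically (a standalone sketch, not the connector's actual code):

import torch

def rope_tables(n: int, dim: int, out_dtype=None):
    """Build rope cos/sin in float32, then cast to the activation dtype."""
    inv = 10000.0 ** (-torch.arange(0, dim, 2, dtype=torch.float32) / dim)
    ang = torch.arange(n, dtype=torch.float32)[:, None] * inv[None, :]
    return ang.cos().to(dtype=out_dtype), ang.sin().to(dtype=out_dtype)

x = torch.randn(1, 8, 64, dtype=torch.bfloat16)
cos, sin = rope_tables(x.shape[1], 32, out_dtype=x.dtype)
assert cos.dtype == torch.bfloat16  # tables now follow the hidden states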


@@ -102,19 +102,7 @@ class VideoConv3d(nn.Module):
return self.conv(x)
def interpolate_up(x, scale_factor):
    try:
        return torch.nn.functional.interpolate(x, scale_factor=scale_factor, mode="nearest")
    except: #operation not implemented for bf16
        orig_shape = list(x.shape)
        out_shape = orig_shape[:2]
        for i in range(len(orig_shape) - 2):
            out_shape.append(round(orig_shape[i + 2] * scale_factor[i]))
        out = torch.empty(out_shape, dtype=x.dtype, layout=x.layout, device=x.device)
        split = 8
        l = out.shape[1] // split
        for i in range(0, out.shape[1], l):
            out[:,i:i+l] = torch.nn.functional.interpolate(x[:,i:i+l].to(torch.float32), scale_factor=scale_factor, mode="nearest").to(x.dtype)
        return out
    return torch.nn.functional.interpolate(x, scale_factor=scale_factor, mode="nearest")
class Upsample(nn.Module):
def __init__(self, in_channels, with_conv, conv_op=ops.Conv2d, scale_factor=2.0):


@@ -374,6 +374,31 @@ def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Ten
return padded_tensor
def calculate_shape(patches, weight, key, original_weights=None):
current_shape = weight.shape
for p in patches:
v = p[1]
offset = p[3]
# Offsets restore the old shape; lists force a diff without metadata
if offset is not None or isinstance(v, list):
continue
if isinstance(v, weight_adapter.WeightAdapterBase):
adapter_shape = v.calculate_shape(key)
if adapter_shape is not None:
current_shape = adapter_shape
continue
# Standard diff logic with padding
if len(v) == 2:
patch_type, patch_data = v[0], v[1]
if patch_type == "diff" and len(patch_data) > 1 and patch_data[1]['pad_weight']:
current_shape = patch_data[0].shape
return current_shape
def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, original_weights=None):
for p in patches:
strength = p[0]


@@ -78,4 +78,4 @@ def interpret_gathered_like(tensors, gathered):
return dest_views
aimdo_allocator = None
aimdo_enabled = False


@@ -178,10 +178,7 @@ class BaseModel(torch.nn.Module):
xc = torch.cat([xc] + [comfy.model_management.cast_to_device(c_concat, xc.device, xc.dtype)], dim=1)
context = c_crossattn
dtype = self.get_dtype()
if self.manual_cast_dtype is not None:
dtype = self.manual_cast_dtype
dtype = self.get_dtype_inference()
xc = xc.to(dtype)
device = xc.device
@@ -218,6 +215,13 @@ class BaseModel(torch.nn.Module):
def get_dtype(self):
return self.diffusion_model.dtype
def get_dtype_inference(self):
dtype = self.get_dtype()
if self.manual_cast_dtype is not None:
dtype = self.manual_cast_dtype
return dtype
def encode_adm(self, **kwargs):
return None
@@ -372,9 +376,7 @@ class BaseModel(torch.nn.Module):
input_shapes += shape
if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention():
dtype = self.get_dtype()
if self.manual_cast_dtype is not None:
dtype = self.manual_cast_dtype
dtype = self.get_dtype_inference()
#TODO: this needs to be tweaked
area = sum(map(lambda input_shape: input_shape[0] * math.prod(input_shape[2:]), input_shapes))
return (area * comfy.model_management.dtype_size(dtype) * 0.01 * self.memory_usage_factor) * (1024 * 1024)
@@ -986,10 +988,14 @@ class LTXAV(BaseModel):
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
attention_mask = kwargs.get("attention_mask", None)
device = kwargs["device"]
if attention_mask is not None:
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
if hasattr(self.diffusion_model, "preprocess_text_embeds"):
cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype_inference()))
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
out['frame_rate'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", 25))
@@ -1165,7 +1171,7 @@ class Anima(BaseModel):
t5xxl_ids = t5xxl_ids.unsqueeze(0)
if torch.is_inference_mode_enabled(): # if not we are training
cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype()), t5xxl_ids.to(device=device), t5xxl_weights=t5xxl_weights.to(device=device, dtype=self.get_dtype()))
cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype_inference()), t5xxl_ids.to(device=device), t5xxl_weights=t5xxl_weights.to(device=device, dtype=self.get_dtype_inference()))
else:
out['t5xxl_ids'] = comfy.conds.CONDRegular(t5xxl_ids)
out['t5xxl_weights'] = comfy.conds.CONDRegular(t5xxl_weights)


@@ -836,7 +836,7 @@ def unet_inital_load_device(parameters, dtype):
mem_dev = get_free_memory(torch_dev)
mem_cpu = get_free_memory(cpu_dev)
if mem_dev > mem_cpu and model_size < mem_dev and comfy.memory_management.aimdo_allocator is None:
if mem_dev > mem_cpu and model_size < mem_dev and comfy.memory_management.aimdo_enabled:
return torch_dev
else:
return cpu_dev
@@ -1121,7 +1121,6 @@ def get_cast_buffer(offload_stream, device, size, ref):
synchronize()
del STREAM_CAST_BUFFERS[offload_stream]
del cast_buffer
#FIXME: This doesn't work in Aimdo because mempool cant clear cache
soft_empty_cache()
with wf_context:
cast_buffer = torch.empty((size), dtype=torch.int8, device=device)


@@ -406,13 +406,16 @@ class ModelPatcher:
def memory_required(self, input_shape):
return self.model.memory_required(input_shape=input_shape)
def disable_model_cfg1_optimization(self):
self.model_options["disable_cfg1_optimization"] = True
def set_model_sampler_cfg_function(self, sampler_cfg_function, disable_cfg1_optimization=False):
if len(inspect.signature(sampler_cfg_function).parameters) == 3:
self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way
else:
self.model_options["sampler_cfg_function"] = sampler_cfg_function
if disable_cfg1_optimization:
self.model_options["disable_cfg1_optimization"] = True
self.disable_model_cfg1_optimization()
def set_model_sampler_post_cfg_function(self, post_cfg_function, disable_cfg1_optimization=False):
self.model_options = set_model_options_post_cfg_function(self.model_options, post_cfg_function, disable_cfg1_optimization)
@@ -1514,8 +1517,10 @@ class ModelPatcherDynamic(ModelPatcher):
weight, _, _ = get_key_weight(self.model, key)
if weight is None:
return 0
return (False, 0)
if key in self.patches:
if comfy.lora.calculate_shape(self.patches[key], weight, key) != weight.shape:
return (True, 0)
setattr(m, param_key + "_lowvram_function", LowVramPatch(key, self.patches))
num_patches += 1
else:
@@ -1529,7 +1534,13 @@ class ModelPatcherDynamic(ModelPatcher):
model_dtype = getattr(m, param_key + "_comfy_model_dtype", None) or weight.dtype
weight._model_dtype = model_dtype
geometry = comfy.memory_management.TensorGeometry(shape=weight.shape, dtype=model_dtype)
return comfy.memory_management.vram_aligned_size(geometry)
return (False, comfy.memory_management.vram_aligned_size(geometry))
def force_load_param(self, param_key, device_to):
key = key_param_name_to_key(n, param_key)
if key in self.backup:
comfy.utils.set_attr_param(self.model, key, self.backup[key].weight)
self.patch_weight_to_device(key, device_to=device_to)
if hasattr(m, "comfy_cast_weights"):
m.comfy_cast_weights = True
@@ -1537,13 +1548,19 @@ class ModelPatcherDynamic(ModelPatcher):
m.seed_key = n
set_dirty(m, dirty)
v_weight_size = 0
v_weight_size += setup_param(self, m, n, "weight")
v_weight_size += setup_param(self, m, n, "bias")
force_load, v_weight_size = setup_param(self, m, n, "weight")
force_load_bias, v_weight_bias = setup_param(self, m, n, "bias")
force_load = force_load or force_load_bias
v_weight_size += v_weight_bias
if vbar is not None and not hasattr(m, "_v"):
m._v = vbar.alloc(v_weight_size)
allocated_size += v_weight_size
if force_load:
logging.info(f"Module {n} has resizing Lora - force loading")
force_load_param(self, "weight", device_to)
force_load_param(self, "bias", device_to)
else:
if vbar is not None and not hasattr(m, "_v"):
m._v = vbar.alloc(v_weight_size)
allocated_size += v_weight_size
else:
for param in params:
@@ -1606,6 +1623,11 @@ class ModelPatcherDynamic(ModelPatcher):
for m in self.model.modules():
move_weight_functions(m, device_to)
keys = list(self.backup.keys())
for k in keys:
bk = self.backup[k]
comfy.utils.set_attr_param(self.model, k, bk.weight)
def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
assert not force_patch_weights #See above
with self.use_ejected(skip_and_inject_on_exit_only=True):


@@ -21,7 +21,6 @@ import logging
import comfy.model_management
from comfy.cli_args import args, PerformanceFeature, enables_dynamic_vram
import comfy.float
import comfy.rmsnorm
import json
import comfy.memory_management
import comfy.pinned_memory
@@ -80,7 +79,7 @@ def cast_to_input(weight, input, non_blocking=False, copy=True):
return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype):
def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant):
offload_stream = None
xfer_dest = None
@@ -171,10 +170,10 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
#FIXME: this is not accurate, we need to be sensitive to the compute dtype
x = lowvram_fn(x)
if (isinstance(orig, QuantizedTensor) and
(orig.dtype == dtype and len(fns) == 0 or update_weight)):
(want_requant and len(fns) == 0 or update_weight)):
seed = comfy.utils.string_to_seed(s.seed_key)
y = QuantizedTensor.from_float(x, s.layout_type, scale="recalculate", stochastic_rounding=seed)
if orig.dtype == dtype and len(fns) == 0:
if want_requant and len(fns) == 0:
#The layer actually wants our freshly saved QT
x = y
elif update_weight:
@@ -195,7 +194,7 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
return weight, bias, (offload_stream, device if signature is not None else None, None)
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False, compute_dtype=None):
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False, compute_dtype=None, want_requant=False):
# NOTE: offloadable=False is a legacy default and if you are a custom node author reading this please pass
# offloadable=True and call uncast_bias_weight() after your last usage of the weight/bias. This
# will add async-offload support to your cast and improve performance.
@@ -213,7 +212,7 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
non_blocking = comfy.model_management.device_supports_non_blocking(device)
if hasattr(s, "_v"):
return cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype)
return cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant)
if offloadable and (device != s.weight.device or
(s.bias is not None and device != s.bias.device)):
@@ -463,7 +462,7 @@ class disable_weight_init:
else:
return super().forward(*args, **kwargs)
class RMSNorm(comfy.rmsnorm.RMSNorm, CastWeightBiasOp):
class RMSNorm(torch.nn.RMSNorm, CastWeightBiasOp):
def reset_parameters(self):
self.bias = None
return None
@@ -475,8 +474,7 @@ class disable_weight_init:
weight = None
bias = None
offload_stream = None
x = comfy.rmsnorm.rms_norm(input, weight, self.eps) # TODO: switch to commented out line when old torch is deprecated
# x = torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)
x = torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
@@ -852,8 +850,8 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
def _forward(self, input, weight, bias):
return torch.nn.functional.linear(input, weight, bias)
def forward_comfy_cast_weights(self, input, compute_dtype=None):
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True, compute_dtype=compute_dtype)
def forward_comfy_cast_weights(self, input, compute_dtype=None, want_requant=False):
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True, compute_dtype=compute_dtype, want_requant=want_requant)
x = self._forward(input, weight, bias)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
@@ -883,8 +881,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
scale = comfy.model_management.cast_to_device(scale, input.device, None)
input = QuantizedTensor.from_float(input_reshaped, self.layout_type, scale=scale)
output = self.forward_comfy_cast_weights(input, compute_dtype)
output = self.forward_comfy_cast_weights(input, compute_dtype, want_requant=isinstance(input, QuantizedTensor))
# Reshape output back to 3D if input was 3D
if reshaped_3d:
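Per the NOTE retained above, custom node authors should opt into the offloadable path. A sketch of the intended call pattern (the subclass is illustrative; cast_bias_weight/uncast_bias_weight and the three-value unpack are as used in this file):

import torch
import comfy.ops

class MyLinear(comfy.ops.disable_weight_init.Linear):
    def forward_comfy_cast_weights(self, input):
        weight, bias, offload_stream = comfy.ops.cast_bias_weight(
            self, input, offloadable=True)
        out = torch.nn.functional.linear(input, weight, bias)
        # Release the cast buffers once the weight is no longer needed.
        comfy.ops.uncast_bias_weight(self, weight, bias, offload_stream)
        return out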


@@ -1,57 +1,10 @@
import torch
import comfy.model_management
import numbers
import logging
RMSNorm = None
try:
rms_norm_torch = torch.nn.functional.rms_norm
RMSNorm = torch.nn.RMSNorm
except:
rms_norm_torch = None
logging.warning("Please update pytorch to use native RMSNorm")
RMSNorm = torch.nn.RMSNorm
def rms_norm(x, weight=None, eps=1e-6):
if rms_norm_torch is not None and not (torch.jit.is_tracing() or torch.jit.is_scripting()):
if weight is None:
return rms_norm_torch(x, (x.shape[-1],), eps=eps)
else:
return rms_norm_torch(x, weight.shape, weight=comfy.model_management.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps)
if weight is None:
return torch.nn.functional.rms_norm(x, (x.shape[-1],), eps=eps)
else:
r = x * torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps)
if weight is None:
return r
else:
return r * comfy.model_management.cast_to(weight, dtype=x.dtype, device=x.device)
if RMSNorm is None:
class RMSNorm(torch.nn.Module):
def __init__(
self,
normalized_shape,
eps=1e-6,
elementwise_affine=True,
device=None,
dtype=None,
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
if isinstance(normalized_shape, numbers.Integral):
# mypy error: incompatible types in assignment
normalized_shape = (normalized_shape,) # type: ignore[assignment]
self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type]
self.eps = eps
self.elementwise_affine = elementwise_affine
if self.elementwise_affine:
self.weight = torch.nn.Parameter(
torch.empty(self.normalized_shape, **factory_kwargs)
)
else:
self.register_parameter("weight", None)
self.bias = None
def forward(self, x):
return rms_norm(x, self.weight, self.eps)
return torch.nn.functional.rms_norm(x, weight.shape, weight=comfy.model_management.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps)
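With the older-PyTorch fallbacks removed, comfy.rmsnorm.rms_norm is a thin wrapper over the native op. For reference, the equivalent direct calls:

import torch

x = torch.randn(2, 16, 64)
w = torch.ones(64)

out = torch.nn.functional.rms_norm(x, (x.shape[-1],), w, 1e-6)           # weighted
out_plain = torch.nn.functional.rms_norm(x, (x.shape[-1],), eps=1e-6)    # weight=None path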


@@ -423,6 +423,17 @@ class CLIP:
def get_key_patches(self):
return self.patcher.get_key_patches()
def generate(self, tokens, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.95, min_p=0.0, repetition_penalty=1.0, seed=None):
self.cond_stage_model.reset_clip_options()
self.load_model()
self.cond_stage_model.set_clip_options({"layer": None})
self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
return self.cond_stage_model.generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed)
def decode(self, token_ids, skip_special_tokens=True):
return self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
class VAE:
def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None):
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
@@ -1182,6 +1193,7 @@ class TEModel(Enum):
JINA_CLIP_2 = 19
QWEN3_8B = 20
QWEN3_06B = 21
GEMMA_3_4B_VISION = 22
def detect_te_model(sd):
@@ -1210,7 +1222,10 @@ def detect_te_model(sd):
if 'model.layers.47.self_attn.q_norm.weight' in sd:
return TEModel.GEMMA_3_12B
if 'model.layers.0.self_attn.q_norm.weight' in sd:
return TEModel.GEMMA_3_4B
if 'vision_model.embeddings.patch_embedding.weight' in sd:
return TEModel.GEMMA_3_4B_VISION
else:
return TEModel.GEMMA_3_4B
return TEModel.GEMMA_2_2B
if 'model.layers.0.self_attn.k_proj.bias' in sd:
weight = sd['model.layers.0.self_attn.k_proj.bias']
@@ -1270,6 +1285,8 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
else:
if "text_projection" in clip_data[i]:
clip_data[i]["text_projection.weight"] = clip_data[i]["text_projection"].transpose(0, 1) #old models saved with the CLIPSave node
if "lm_head.weight" in clip_data[i]:
clip_data[i]["model.lm_head.weight"] = clip_data[i].pop("lm_head.weight") # prefix missing in some models
tokenizer_data = {}
clip_target = EmptyClass()
@@ -1335,6 +1352,14 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data), model_type="gemma3_4b")
clip_target.tokenizer = comfy.text_encoders.lumina2.NTokenizer
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
elif te_model == TEModel.GEMMA_3_4B_VISION:
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data), model_type="gemma3_4b_vision")
clip_target.tokenizer = comfy.text_encoders.lumina2.NTokenizer
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
elif te_model == TEModel.GEMMA_3_12B:
clip_target.clip = comfy.text_encoders.lt.gemma3_te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.lt.Gemma3_12BTokenizer
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
elif te_model == TEModel.LLAMA3_8:
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**llama_detect(clip_data),
clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None)


@@ -308,6 +308,15 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
def load_sd(self, sd):
return self.transformer.load_state_dict(sd, strict=False, assign=getattr(self, "can_assign_sd", False))
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed):
if isinstance(tokens, dict):
tokens_only = next(iter(tokens.values())) # todo: get this better?
else:
tokens_only = tokens
tokens_only = [[t[0] for t in b] for b in tokens_only]
embeds = self.process_tokens(tokens_only, device=self.execution_device)[0]
return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed)
def parse_parentheses(string):
result = []
current_item = ""
@@ -564,6 +573,8 @@ class SDTokenizer:
min_length = tokenizer_options.get("{}_min_length".format(self.embedding_key), self.min_length)
min_padding = tokenizer_options.get("{}_min_padding".format(self.embedding_key), self.min_padding)
min_length = kwargs.get("min_length", min_length)
text = escape_important(text)
if kwargs.get("disable_weights", self.disable_weights):
parsed_weights = [(text, 1.0)]
@@ -663,6 +674,9 @@ class SDTokenizer:
def state_dict(self):
return {}
def decode(self, token_ids, skip_special_tokens=True):
return self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
class SD1Tokenizer:
def __init__(self, embedding_directory=None, tokenizer_data={}, clip_name="l", tokenizer=SDTokenizer, name=None):
if name is not None:
@@ -686,6 +700,9 @@ class SD1Tokenizer:
def state_dict(self):
return getattr(self, self.clip).state_dict()
def decode(self, token_ids, skip_special_tokens=True):
return getattr(self, self.clip).decode(token_ids, skip_special_tokens=skip_special_tokens)
class SD1CheckpointClipModel(SDClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(device=device, return_projected_pooled=False, dtype=dtype, model_options=model_options)
@@ -722,3 +739,6 @@ class SD1ClipModel(torch.nn.Module):
def load_sd(self, sd):
return getattr(self, self.clip).load_sd(sd)
def generate(self, tokens, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.95, min_p=0.0, repetition_penalty=1.0, seed=None):
return getattr(self, self.clip).generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed)
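Together with CLIP.generate above, these decode hooks give a full text-generation round trip. A hedged usage sketch (assumes clip is a loaded comfy.sd.CLIP whose text encoder supports generation, e.g. a Qwen3 or Gemma3 based one; the prompt is illustrative):

tokens = clip.tokenize("Describe a foggy harbor at dawn.")
token_ids = clip.generate(tokens, do_sample=True, max_length=128,
                          temperature=0.8, top_k=50, top_p=0.95, seed=42)
print(clip.decode(token_ids, skip_special_tokens=True))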


@@ -33,6 +33,8 @@ class AnimaTokenizer:
def state_dict(self):
return {}
def decode(self, token_ids, **kwargs):
return self.qwen3_06b.decode(token_ids, **kwargs)
class Qwen3_06BModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):


@@ -3,6 +3,8 @@ import torch.nn as nn
from dataclasses import dataclass
from typing import Optional, Any, Tuple
import math
from tqdm import tqdm
import comfy.utils
from comfy.ldm.modules.attention import optimized_attention_for_device
import comfy.model_management
@@ -103,6 +105,7 @@ class Qwen3_06BConfig:
rope_scale = None
final_norm: bool = True
lm_head: bool = False
stop_tokens = [151643, 151645]
@dataclass
class Qwen3_06B_ACE15_Config:
@@ -126,6 +129,7 @@ class Qwen3_06B_ACE15_Config:
rope_scale = None
final_norm: bool = True
lm_head: bool = False
stop_tokens = [151643, 151645]
@dataclass
class Qwen3_2B_ACE15_lm_Config:
@@ -149,6 +153,7 @@ class Qwen3_2B_ACE15_lm_Config:
rope_scale = None
final_norm: bool = True
lm_head: bool = False
stop_tokens = [151643, 151645]
@dataclass
class Qwen3_4B_ACE15_lm_Config:
@@ -172,6 +177,7 @@ class Qwen3_4B_ACE15_lm_Config:
rope_scale = None
final_norm: bool = True
lm_head: bool = False
stop_tokens = [151643, 151645]
@dataclass
class Qwen3_4BConfig:
@@ -195,6 +201,7 @@ class Qwen3_4BConfig:
rope_scale = None
final_norm: bool = True
lm_head: bool = False
stop_tokens = [151643, 151645]
@dataclass
class Qwen3_8BConfig:
@@ -218,6 +225,7 @@ class Qwen3_8BConfig:
rope_scale = None
final_norm: bool = True
lm_head: bool = False
stop_tokens = [151643, 151645]
@dataclass
class Ovis25_2BConfig:
@@ -288,6 +296,7 @@ class Gemma2_2B_Config:
rope_scale = None
final_norm: bool = True
lm_head: bool = False
stop_tokens = [1]
@dataclass
class Gemma3_4B_Config:
@@ -312,6 +321,14 @@ class Gemma3_4B_Config:
rope_scale = [8.0, 1.0]
final_norm: bool = True
lm_head: bool = False
stop_tokens = [1, 106]
GEMMA3_VISION_CONFIG = {"num_channels": 3, "hidden_act": "gelu_pytorch_tanh", "hidden_size": 1152, "image_size": 896, "intermediate_size": 4304, "model_type": "siglip_vision_model", "num_attention_heads": 16, "num_hidden_layers": 27, "patch_size": 14}
@dataclass
class Gemma3_4B_Vision_Config(Gemma3_4B_Config):
vision_config = GEMMA3_VISION_CONFIG
mm_tokens_per_image = 256
@dataclass
class Gemma3_12B_Config:
@@ -336,8 +353,9 @@ class Gemma3_12B_Config:
rope_scale = [8.0, 1.0]
final_norm: bool = True
lm_head: bool = False
vision_config = {"num_channels": 3, "hidden_act": "gelu_pytorch_tanh", "hidden_size": 1152, "image_size": 896, "intermediate_size": 4304, "model_type": "siglip_vision_model", "num_attention_heads": 16, "num_hidden_layers": 27, "patch_size": 14}
vision_config = GEMMA3_VISION_CONFIG
mm_tokens_per_image = 256
stop_tokens = [1, 106]
class RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None):
@@ -441,8 +459,10 @@ class Attention(nn.Module):
freqs_cis: Optional[torch.Tensor] = None,
optimized_attention=None,
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
sliding_window: Optional[int] = None,
):
batch_size, seq_length, _ = hidden_states.shape
xq = self.q_proj(hidden_states)
xk = self.k_proj(hidden_states)
xv = self.v_proj(hidden_states)
@@ -477,6 +497,11 @@ class Attention(nn.Module):
else:
present_key_value = (xk, xv, index + num_tokens)
if sliding_window is not None and xk.shape[2] > sliding_window:
xk = xk[:, :, -sliding_window:]
xv = xv[:, :, -sliding_window:]
attention_mask = attention_mask[..., -sliding_window:] if attention_mask is not None else None
xk = xk.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
xv = xv.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
@@ -559,10 +584,12 @@ class TransformerBlockGemma2(nn.Module):
optimized_attention=None,
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
):
sliding_window = None
if self.transformer_type == 'gemma3':
if self.sliding_attention:
sliding_window = self.sliding_attention
if x.shape[1] > self.sliding_attention:
sliding_mask = torch.full((x.shape[1], x.shape[1]), float("-inf"), device=x.device, dtype=x.dtype)
sliding_mask = torch.full((x.shape[1], x.shape[1]), torch.finfo(x.dtype).min, device=x.device, dtype=x.dtype)
sliding_mask.tril_(diagonal=-self.sliding_attention)
if attention_mask is not None:
attention_mask = attention_mask + sliding_mask
@@ -581,6 +608,7 @@ class TransformerBlockGemma2(nn.Module):
freqs_cis=freqs_cis,
optimized_attention=optimized_attention,
past_key_value=past_key_value,
sliding_window=sliding_window,
)
x = self.post_attention_layernorm(x)
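Two things change in the sliding-attention path: the band mask is built from torch.finfo(dtype).min instead of float("-inf"), which is safer for fp16/bf16 softmax, and the KV cache is truncated to the window inside attention. The mask on its own (standalone sketch):

import torch

def sliding_mask(seq_len: int, window: int, dtype=torch.bfloat16):
    """Additive mask: position i may not attend to j once i - j >= window."""
    mask = torch.full((seq_len, seq_len), torch.finfo(dtype).min, dtype=dtype)
    mask.tril_(diagonal=-window)  # only the far-past band stays blocked
    return mask

print(sliding_mask(5, 2, torch.float32))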
@@ -765,6 +793,107 @@ class BaseLlama:
def forward(self, input_ids, *args, **kwargs):
return self.model(input_ids, *args, **kwargs)
class BaseGenerate:
    def logits(self, x):
        input = x[:, -1:]
        if hasattr(self.model, "lm_head"):
            module = self.model.lm_head
        else:
            module = self.model.embed_tokens
        offload_stream = None
        if module.comfy_cast_weights:
            weight, _, offload_stream = comfy.ops.cast_bias_weight(module, input, offloadable=True)
        else:
            weight = self.model.embed_tokens.weight.to(x)
        x = torch.nn.functional.linear(input, weight, None)
        comfy.ops.uncast_bias_weight(module, weight, None, offload_stream)
        return x

    def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=None, initial_tokens=[], execution_dtype=None, min_tokens=0):
        device = embeds.device
        model_config = self.model.config
        if stop_tokens is None:
            stop_tokens = self.model.config.stop_tokens
        if execution_dtype is None:
            if comfy.model_management.should_use_bf16(device):
                execution_dtype = torch.bfloat16
            else:
                execution_dtype = torch.float32
        embeds = embeds.to(execution_dtype)
        if embeds.ndim == 2:
            embeds = embeds.unsqueeze(0)

        past_key_values = [] #kv_cache init
        max_cache_len = embeds.shape[1] + max_length
        for x in range(model_config.num_hidden_layers):
            past_key_values.append((torch.empty([embeds.shape[0], model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype),
                                    torch.empty([embeds.shape[0], model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype), 0))

        generator = torch.Generator(device=device).manual_seed(seed) if do_sample else None
        generated_token_ids = []
        pbar = comfy.utils.ProgressBar(max_length)

        # Generation loop
        for step in tqdm(range(max_length), desc="Generating tokens"):
            x, _, past_key_values = self.model.forward(None, embeds=embeds, attention_mask=None, past_key_values=past_key_values)
            logits = self.logits(x)[:, -1]
            next_token = self.sample_token(logits, temperature, top_k, top_p, min_p, repetition_penalty, initial_tokens + generated_token_ids, generator, do_sample=do_sample)
            token_id = next_token[0].item()
            generated_token_ids.append(token_id)
            embeds = self.model.embed_tokens(next_token).to(execution_dtype)
            pbar.update(1)
            if token_id in stop_tokens:
                break

        return generated_token_ids

    def sample_token(self, logits, temperature, top_k, top_p, min_p, repetition_penalty, token_history, generator, do_sample=True):
        if not do_sample or temperature == 0.0:
            return torch.argmax(logits, dim=-1, keepdim=True)

        # Sampling mode
        if repetition_penalty != 1.0:
            for i in range(logits.shape[0]):
                for token_id in set(token_history):
                    logits[i, token_id] *= repetition_penalty if logits[i, token_id] < 0 else 1/repetition_penalty

        if temperature != 1.0:
            logits = logits / temperature

        if top_k > 0:
            indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
            logits[indices_to_remove] = torch.finfo(logits.dtype).min

        if min_p > 0.0:
            probs_before_filter = torch.nn.functional.softmax(logits, dim=-1)
            top_probs, _ = probs_before_filter.max(dim=-1, keepdim=True)
            min_threshold = min_p * top_probs
            indices_to_remove = probs_before_filter < min_threshold
            logits[indices_to_remove] = torch.finfo(logits.dtype).min

        if top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
            sorted_indices_to_remove = cumulative_probs > top_p
            sorted_indices_to_remove[..., 0] = False
            indices_to_remove = torch.zeros_like(logits, dtype=torch.bool)
            indices_to_remove.scatter_(1, sorted_indices, sorted_indices_to_remove)
            logits[indices_to_remove] = torch.finfo(logits.dtype).min

        probs = torch.nn.functional.softmax(logits, dim=-1)
        return torch.multinomial(probs, num_samples=1, generator=generator)
class BaseQwen3:
def logits(self, x):
input = x[:, -1:]
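sample_token above applies repetition penalty, temperature, top-k, min-p, and finally top-p before the multinomial draw. The top-p step in isolation, on toy logits (same logic as above, extracted for illustration):

import torch

def top_p_filter(logits: torch.Tensor, top_p: float) -> torch.Tensor:
    sorted_logits, sorted_idx = torch.sort(logits, descending=True)
    cum = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
    remove = cum > top_p
    remove[..., 0] = False  # always keep the single most likely token
    mask = torch.zeros_like(logits, dtype=torch.bool)
    mask.scatter_(1, sorted_idx, remove)
    return logits.masked_fill(mask, torch.finfo(logits.dtype).min)

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
print(torch.softmax(top_p_filter(logits, 0.9), dim=-1))  # tail mass goes to ~0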
@@ -808,7 +937,7 @@ class Qwen25_3B(BaseLlama, torch.nn.Module):
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
class Qwen3_06B(BaseLlama, BaseQwen3, torch.nn.Module):
class Qwen3_06B(BaseLlama, BaseQwen3, BaseGenerate, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Qwen3_06BConfig(**config_dict)
@@ -835,7 +964,7 @@ class Qwen3_2B_ACE15_lm(BaseLlama, BaseQwen3, torch.nn.Module):
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
class Qwen3_4B(BaseLlama, BaseQwen3, torch.nn.Module):
class Qwen3_4B(BaseLlama, BaseQwen3, BaseGenerate, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Qwen3_4BConfig(**config_dict)
@@ -853,7 +982,7 @@ class Qwen3_4B_ACE15_lm(BaseLlama, BaseQwen3, torch.nn.Module):
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
class Qwen3_8B(BaseLlama, BaseQwen3, torch.nn.Module):
class Qwen3_8B(BaseLlama, BaseQwen3, BaseGenerate, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Qwen3_8BConfig(**config_dict)
@@ -871,7 +1000,7 @@ class Ovis25_2B(BaseLlama, torch.nn.Module):
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
class Qwen25_7BVLI(BaseLlama, torch.nn.Module):
class Qwen25_7BVLI(BaseLlama, BaseGenerate, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Qwen25_7BVLI_Config(**config_dict)
@@ -881,6 +1010,9 @@ class Qwen25_7BVLI(BaseLlama, torch.nn.Module):
self.visual = qwen_vl.Qwen2VLVisionTransformer(hidden_size=1280, output_hidden_size=config.hidden_size, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
# todo: should this be tied or not?
#self.lm_head = operations.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)
def preprocess_embed(self, embed, device):
if embed["type"] == "image":
image, grid = qwen_vl.process_qwen2vl_images(embed["data"])
@@ -914,7 +1046,7 @@ class Qwen25_7BVLI(BaseLlama, torch.nn.Module):
return super().forward(x, attention_mask=attention_mask, embeds=embeds, num_tokens=num_tokens, intermediate_output=intermediate_output, final_layer_norm_intermediate=final_layer_norm_intermediate, dtype=dtype, position_ids=position_ids)
class Gemma2_2B(BaseLlama, torch.nn.Module):
class Gemma2_2B(BaseLlama, BaseGenerate, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Gemma2_2B_Config(**config_dict)
@@ -923,7 +1055,7 @@ class Gemma2_2B(BaseLlama, torch.nn.Module):
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
class Gemma3_4B(BaseLlama, torch.nn.Module):
class Gemma3_4B(BaseLlama, BaseGenerate, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Gemma3_4B_Config(**config_dict)
@@ -932,7 +1064,25 @@ class Gemma3_4B(BaseLlama, torch.nn.Module):
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
class Gemma3_12B(BaseLlama, torch.nn.Module):
class Gemma3_4B_Vision(BaseLlama, BaseGenerate, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Gemma3_4B_Vision_Config(**config_dict)
self.num_layers = config.num_hidden_layers
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
self.multi_modal_projector = Gemma3MultiModalProjector(config, dtype, device, operations)
self.vision_model = comfy.clip_model.CLIPVision(config.vision_config, dtype, device, operations)
self.image_size = config.vision_config["image_size"]
def preprocess_embed(self, embed, device):
if embed["type"] == "image":
image = comfy.clip_model.clip_preprocess(embed["data"], size=self.image_size, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], crop=True)
return self.multi_modal_projector(self.vision_model(image.to(device, dtype=torch.float32))[0]), None
return None, None
class Gemma3_12B(BaseLlama, BaseGenerate, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Gemma3_12B_Config(**config_dict)

View File

@@ -3,9 +3,9 @@ import os
from transformers import T5TokenizerFast
from .spiece_tokenizer import SPieceTokenizer
import comfy.text_encoders.genmo
from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector
import torch
import comfy.utils
import math
class T5XXLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
@@ -22,40 +22,79 @@ def ltxv_te(*args, **kwargs):
return comfy.text_encoders.genmo.mochi_te(*args, **kwargs)
class Gemma3_12BTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer = tokenizer_data.get("spiece_model", None)
super().__init__(tokenizer, pad_with_end=False, embedding_size=3840, embedding_key='gemma3_12b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_left=True, disable_weights=True, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data)
class Gemma3_Tokenizer():
def state_dict(self):
return {"spiece_model": self.tokenizer.serialize_model()}
def tokenize_with_weights(self, text, return_word_ids=False, image=None, llama_template=None, skip_template=True, **kwargs):
self.llama_template = "<start_of_turn>system\nYou are a helpful assistant.<end_of_turn>\n<start_of_turn>user\n{}<end_of_turn>\n<start_of_turn>model\n"
self.llama_template_images = "<start_of_turn>system\nYou are a helpful assistant.<end_of_turn>\n<start_of_turn>user\n\n<image_soft_token>{}<end_of_turn>\n\n<start_of_turn>model\n"
if image is None:
images = []
else:
samples = image.movedim(-1, 1)
total = int(896 * 896)
scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2]))
width = round(samples.shape[3] * scale_by)
height = round(samples.shape[2] * scale_by)
s = comfy.utils.common_upscale(samples, width, height, "area", "disabled").movedim(1, -1)
images = [s[:, :, :, :3]]
if text.startswith('<start_of_turn>'):
skip_template = True
if skip_template:
llama_text = text
else:
if llama_template is None:
if len(images) > 0:
llama_text = self.llama_template_images.format(text)
else:
llama_text = self.llama_template.format(text)
else:
llama_text = llama_template.format(text)
text_tokens = super().tokenize_with_weights(llama_text, return_word_ids)
if len(images) > 0:
embed_count = 0
for r in text_tokens:
for i, token in enumerate(r):
if token[0] == 262144 and embed_count < len(images):
r[i] = ({"type": "image", "data": images[embed_count]},) + token[1:]
embed_count += 1
return text_tokens
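Two things happen above: the input image is rescaled so its pixel count lands near 896×896 while preserving aspect ratio, and every <image_soft_token> id (262144) in the tokenized text is swapped for a dict carrying the image tensor. A quick check of the resize arithmetic, with illustrative dimensions:
import math
h, w = 720, 1280
total = 896 * 896                       # target pixel budget
scale_by = math.sqrt(total / (w * h))
new_w, new_h = round(w * scale_by), round(h * scale_by)
print(new_w, new_h, new_w * new_h)      # 1195 672 803040, close to 802816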
class Gemma3_12BTokenizer(Gemma3_Tokenizer, sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer = tokenizer_data.get("spiece_model", None)
special_tokens = {"<image_soft_token>": 262144, "<end_of_turn>": 106}
super().__init__(tokenizer, pad_with_end=False, embedding_size=3840, embedding_key='gemma3_12b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_left=True, disable_weights=True, tokenizer_args={"add_bos": True, "add_eos": False, "special_tokens": special_tokens}, tokenizer_data=tokenizer_data)
class LTXAVGemmaTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="gemma3_12b", tokenizer=Gemma3_12BTokenizer)
class Gemma3_12BModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="all", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
if llama_quantization_metadata is not None:
model_options = model_options.copy()
model_options["quantization_metadata"] = llama_quantization_metadata
self.dtypes = set()
self.dtypes.add(dtype)
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_12B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
def tokenize_with_weights(self, text, return_word_ids=False, llama_template="{}", image_embeds=None, **kwargs):
text = llama_template.format(text)
text_tokens = super().tokenize_with_weights(text, return_word_ids)
embed_count = 0
for k in text_tokens:
tt = text_tokens[k]
for r in tt:
for i in range(len(r)):
if r[i][0] == 262144:
if image_embeds is not None and embed_count < image_embeds.shape[0]:
r[i] = ({"type": "embedding", "data": image_embeds[embed_count], "original_type": "image"},) + r[i][1:]
embed_count += 1
return text_tokens
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed):
tokens_only = [[t[0] for t in b] for b in tokens]
embeds, _, _, embeds_info = self.process_tokens(tokens_only, self.execution_device)
comfy.utils.normalize_image_embeddings(embeds, embeds_info, self.transformer.model.config.hidden_size ** 0.5)
return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[106]) # 106 is <end_of_turn>
class LTXAVTEModel(torch.nn.Module):
def __init__(self, dtype_llama=None, device="cpu", dtype=None, model_options={}):
@@ -69,22 +108,6 @@ class LTXAVTEModel(torch.nn.Module):
operations = self.gemma3_12b.operations # TODO
self.text_embedding_projection = operations.Linear(3840 * 49, 3840, bias=False, dtype=dtype, device=device)
self.audio_embeddings_connector = Embeddings1DConnector(
split_rope=True,
double_precision_rope=True,
dtype=dtype,
device=device,
operations=operations,
)
self.video_embeddings_connector = Embeddings1DConnector(
split_rope=True,
double_precision_rope=True,
dtype=dtype,
device=device,
operations=operations,
)
def set_clip_options(self, options):
self.execution_device = options.get("execution_device", self.execution_device)
self.gemma3_12b.set_clip_options(options)
@@ -106,24 +129,23 @@ class LTXAVTEModel(torch.nn.Module):
out = out.reshape((out.shape[0], out.shape[1], -1))
out = self.text_embedding_projection(out)
out = out.float()
out_vid = self.video_embeddings_connector(out)[0]
out_audio = self.audio_embeddings_connector(out)[0]
out = torch.concat((out_vid, out_audio), dim=-1)
return out.to(out_device), pooled
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed):
return self.gemma3_12b.generate(tokens["gemma3_12b"], do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed)
def load_sd(self, sd):
if "model.layers.47.self_attn.q_norm.weight" in sd:
return self.gemma3_12b.load_sd(sd)
else:
sdo = comfy.utils.state_dict_prefix_replace(sd, {"text_embedding_projection.aggregate_embed.weight": "text_embedding_projection.weight", "model.diffusion_model.video_embeddings_connector.": "video_embeddings_connector.", "model.diffusion_model.audio_embeddings_connector.": "audio_embeddings_connector."}, filter_keys=True)
sdo = comfy.utils.state_dict_prefix_replace(sd, {"text_embedding_projection.aggregate_embed.weight": "text_embedding_projection.weight"}, filter_keys=True)
if len(sdo) == 0:
sdo = sd
missing_all = []
unexpected_all = []
for prefix, component in [("text_embedding_projection.", self.text_embedding_projection), ("video_embeddings_connector.", self.video_embeddings_connector), ("audio_embeddings_connector.", self.audio_embeddings_connector)]:
for prefix, component in [("text_embedding_projection.", self.text_embedding_projection)]:
component_sd = {k.replace(prefix, ""): v for k, v in sdo.items() if k.startswith(prefix)}
if component_sd:
missing, unexpected = component.load_state_dict(component_sd, strict=False, assign=getattr(self, "can_assign_sd", False))
@@ -152,3 +174,14 @@ def ltxav_te(dtype_llama=None, llama_quantization_metadata=None):
dtype = dtype_llama
super().__init__(dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options)
return LTXAVTEModel_
def gemma3_te(dtype_llama=None, llama_quantization_metadata=None):
class Gemma3_12BModel_(Gemma3_12BModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
if llama_quantization_metadata is not None:
model_options = model_options.copy()
model_options["llama_quantization_metadata"] = llama_quantization_metadata
if dtype_llama is not None:
dtype = dtype_llama
super().__init__(device=device, dtype=dtype, model_options=model_options)
return Gemma3_12BModel_
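ltxav_te and gemma3_te share the same factory shape: the outer function captures load-time options in a closure and returns a subclass whose constructor matches what the generic text-encoder loader expects. A stripped-down sketch of the pattern with hypothetical names:
def make_te_class(base, dtype_override=None, extra_metadata=None):
    # Hypothetical helper illustrating the closure-capture pattern above.
    class Configured(base):
        def __init__(self, device="cpu", dtype=None, model_options={}):
            if extra_metadata is not None:
                model_options = model_options.copy()
                model_options["llama_quantization_metadata"] = extra_metadata
            if dtype_override is not None:
                dtype = dtype_override
            super().__init__(device=device, dtype=dtype, model_options=model_options)
    return Configured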

View File

@@ -1,23 +1,23 @@
from comfy import sd1_clip
from .spiece_tokenizer import SPieceTokenizer
import comfy.text_encoders.llama
from comfy.text_encoders.lt import Gemma3_Tokenizer
import comfy.utils
class Gemma2BTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer = tokenizer_data.get("spiece_model", None)
super().__init__(tokenizer, pad_with_end=False, embedding_size=2304, embedding_key='gemma2_2b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data)
special_tokens = {"<end_of_turn>": 107}
super().__init__(tokenizer, pad_with_end=False, embedding_size=2304, embedding_key='gemma2_2b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False, "special_tokens": special_tokens}, tokenizer_data=tokenizer_data)
def state_dict(self):
return {"spiece_model": self.tokenizer.serialize_model()}
class Gemma3_4BTokenizer(sd1_clip.SDTokenizer):
class Gemma3_4BTokenizer(Gemma3_Tokenizer, sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer = tokenizer_data.get("spiece_model", None)
super().__init__(tokenizer, pad_with_end=False, embedding_size=2560, embedding_key='gemma3_4b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, disable_weights=True, tokenizer_data=tokenizer_data)
def state_dict(self):
return {"spiece_model": self.tokenizer.serialize_model()}
special_tokens = {"<image_soft_token>": 262144, "<end_of_turn>": 106}
super().__init__(tokenizer, pad_with_end=False, embedding_size=2560, embedding_key='gemma3_4b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False, "special_tokens": special_tokens}, disable_weights=True, tokenizer_data=tokenizer_data)
class LuminaTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
@@ -40,6 +40,20 @@ class Gemma3_4BModel(sd1_clip.SDClipModel):
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_4B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
class Gemma3_4B_Vision_Model(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}):
llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
if llama_quantization_metadata is not None:
model_options = model_options.copy()
model_options["quantization_metadata"] = llama_quantization_metadata
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_4B_Vision, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
def process_tokens(self, tokens, device):
embeds, _, _, embeds_info = super().process_tokens(tokens, device)
comfy.utils.normalize_image_embeddings(embeds, embeds_info, self.transformer.model.config.hidden_size ** 0.5)
return embeds
class LuminaModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}, name="gemma2_2b", clip_model=Gemma2_2BModel):
super().__init__(device=device, dtype=dtype, name=name, clip_model=clip_model, model_options=model_options)
@@ -50,6 +64,8 @@ def te(dtype_llama=None, llama_quantization_metadata=None, model_type="gemma2_2b
model = Gemma2_2BModel
elif model_type == "gemma3_4b":
model = Gemma3_4BModel
elif model_type == "gemma3_4b_vision":
model = Gemma3_4B_Vision_Model
class LuminaTEModel_(LuminaModel):
def __init__(self, device="cpu", dtype=None, model_options={}):

View File

@@ -6,9 +6,10 @@ class SPieceTokenizer:
def from_pretrained(path, **kwargs):
return SPieceTokenizer(path, **kwargs)
def __init__(self, tokenizer_path, add_bos=False, add_eos=True):
def __init__(self, tokenizer_path, add_bos=False, add_eos=True, special_tokens=None):
self.add_bos = add_bos
self.add_eos = add_eos
self.special_tokens = special_tokens
import sentencepiece
if torch.is_tensor(tokenizer_path):
tokenizer_path = tokenizer_path.numpy().tobytes()
@@ -27,8 +28,32 @@ class SPieceTokenizer:
return out
def __call__(self, string):
if self.special_tokens is not None:
import re
special_tokens_pattern = '|'.join(re.escape(token) for token in self.special_tokens.keys())
if special_tokens_pattern and re.search(special_tokens_pattern, string):
parts = re.split(f'({special_tokens_pattern})', string)
result = []
for part in parts:
if not part:
continue
if part in self.special_tokens:
result.append(self.special_tokens[part])
else:
encoded = self.tokenizer.encode(part, add_bos=False, add_eos=False)
result.extend(encoded)
return {"input_ids": result}
out = self.tokenizer.encode(string)
return {"input_ids": out}
def decode(self, token_ids, skip_special_tokens=False):
if skip_special_tokens and self.special_tokens:
special_token_ids = set(self.special_tokens.values())
token_ids = [tid for tid in token_ids if tid not in special_token_ids]
return self.tokenizer.decode(token_ids)
def serialize_model(self):
return torch.ByteTensor(list(self.tokenizer.serialized_model_proto()))
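Since SentencePiece itself has no notion of these control tokens, the __call__ above splits the string on the registered special tokens first and only sends the plain segments to the model. A runnable sketch of the same splitting, with a stand-in encoder so it works without a real .model file:
import re
special_tokens = {"<image_soft_token>": 262144, "<end_of_turn>": 106}
fake_encode = lambda s: [ord(c) for c in s]  # stand-in for self.tokenizer.encode
string = "hi<image_soft_token>there<end_of_turn>"
pattern = "|".join(re.escape(t) for t in special_tokens)
ids = []
for part in re.split(f"({pattern})", string):
    if not part:
        continue
    if part in special_tokens:
        ids.append(special_tokens[part])
    else:
        ids.extend(fake_encode(part))
print(ids)  # [104, 105, 262144, ..., 106]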

View File

@@ -20,7 +20,7 @@
import torch
import math
import struct
import comfy.checkpoint_pickle
import comfy.memory_management
import safetensors.torch
import numpy as np
from PIL import Image
@@ -38,26 +38,26 @@ import warnings
MMAP_TORCH_FILES = args.mmap_torch_files
DISABLE_MMAP = args.disable_mmap
ALWAYS_SAFE_LOAD = False
if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in pytorch 2.4, the unsafe path should be removed once earlier versions are deprecated
if True: # ckpt/pt file whitelist for safe loading of old sd files
class ModelCheckpoint:
pass
ModelCheckpoint.__module__ = "pytorch_lightning.callbacks.model_checkpoint"
def scalar(*args, **kwargs):
from numpy.core.multiarray import scalar as sc
return sc(*args, **kwargs)
return None
scalar.__module__ = "numpy.core.multiarray"
from numpy import dtype
from numpy.dtypes import Float64DType
from _codecs import encode
def encode(*args, **kwargs): # no longer necessary on newer torch
return None
encode.__module__ = "_codecs"
torch.serialization.add_safe_globals([ModelCheckpoint, scalar, dtype, Float64DType, encode])
ALWAYS_SAFE_LOAD = True
logging.info("Checkpoint files will always be loaded safely.")
else:
logging.warning("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended as older versions of pytorch are no longer supported.")
# Current as of safetensors 0.7.0
_TYPES = {
@@ -140,11 +140,8 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
if MMAP_TORCH_FILES:
torch_args["mmap"] = True
if safe_load or ALWAYS_SAFE_LOAD:
pl_sd = torch.load(ckpt, map_location=device, weights_only=True, **torch_args)
else:
logging.warning("WARNING: loading {} unsafely, upgrade your pytorch to 2.4 or newer to load this file safely.".format(ckpt))
pl_sd = torch.load(ckpt, map_location=device, pickle_module=comfy.checkpoint_pickle)
pl_sd = torch.load(ckpt, map_location=device, weights_only=True, **torch_args)
if "state_dict" in pl_sd:
sd = pl_sd["state_dict"]
else:
@@ -1157,7 +1154,7 @@ def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_am
return tiled_scale_multidim(samples, function, (tile_y, tile_x), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=output_device, pbar=pbar)
def model_trange(*args, **kwargs):
if comfy.memory_management.aimdo_allocator is None:
if not comfy.memory_management.aimdo_enabled:
return trange(*args, **kwargs)
pbar = trange(*args, **kwargs, smoothing=1.0)
@@ -1421,3 +1418,11 @@ def deepcopy_list_dict(obj, memo=None):
memo[obj_id] = res
return res
def normalize_image_embeddings(embeds, embeds_info, scale_factor):
"""Normalize image embeddings to match text embedding scale"""
for info in embeds_info:
if info.get("type") == "image":
start_idx = info["index"]
end_idx = start_idx + info["size"]
embeds[:, start_idx:end_idx, :] /= scale_factor
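A toy run of the helper above: only the spans flagged as image embeddings are divided by the scale factor, leaving text positions untouched (shapes and values illustrative):
import torch
embeds = torch.ones(1, 6, 4)
embeds_info = [{"type": "image", "index": 2, "size": 3}]
scale = 2560 ** 0.5   # callers pass hidden_size ** 0.5
for info in embeds_info:
    if info.get("type") == "image":
        s = info["index"]
        embeds[:, s:s + info["size"], :] /= scale
print(embeds[0, :, 0])  # positions 2-4 scaled down, the rest left at 1.0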

View File

@@ -49,6 +49,12 @@ class WeightAdapterBase:
"""
raise NotImplementedError
def calculate_shape(
self,
key
):
return None
def calculate_weight(
self,
weight,

View File

@@ -214,6 +214,13 @@ class LoRAAdapter(WeightAdapterBase):
else:
return None
def calculate_shape(
self,
key
):
reshape = self.weights[5]
return tuple(reshape) if reshape is not None else None
def calculate_weight(
self,
weight,

View File

@@ -14,6 +14,7 @@ SERVER_FEATURE_FLAGS: dict[str, Any] = {
"supports_preview_metadata": True,
"max_upload_size": args.max_upload_size * 1024 * 1024, # Convert MB to bytes
"extension": {"manager": {"supports_v4": True}},
"node_replacements": True,
}

View File

@@ -21,6 +21,17 @@ class ComfyAPI_latest(ComfyAPIBase):
VERSION = "latest"
STABLE = False
def __init__(self):
super().__init__()
self.node_replacement = self.NodeReplacement()
self.execution = self.Execution()
class NodeReplacement(ProxiedSingleton):
async def register(self, node_replace: io.NodeReplace) -> None:
"""Register a node replacement mapping."""
from server import PromptServer
PromptServer.instance.node_replace_manager.register(node_replace)
class Execution(ProxiedSingleton):
async def set_progress(
self,
@@ -73,8 +84,6 @@ class ComfyAPI_latest(ComfyAPIBase):
image=to_display,
)
execution: Execution
class ComfyExtension(ABC):
async def on_load(self) -> None:
"""

View File

@@ -444,7 +444,7 @@ class VideoFromComponents(VideoInput):
output.mux(packet)
if audio_stream and self.__components.audio:
frame = av.AudioFrame.from_ndarray(waveform.float().cpu().numpy(), format='fltp', layout=layout)
frame = av.AudioFrame.from_ndarray(waveform.float().cpu().contiguous().numpy(), format='fltp', layout=layout)
frame.sample_rate = audio_sample_rate
frame.pts = 0
output.mux(audio_stream.encode(frame))
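The added .contiguous() guards against waveform views (e.g., transposed channel layouts) whose storage is not C-contiguous; .numpy() preserves that layout, which PyAV's from_ndarray can reject. An illustrative check:
import torch
wave = torch.randn(48000, 2).t()   # (channels, samples) as a transposed view
print(wave.is_contiguous())        # False
print(wave.numpy().flags['C_CONTIGUOUS'])               # False: same storage, same strides
print(wave.contiguous().numpy().flags['C_CONTIGUOUS'])  # True: the fix copies first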

View File

@@ -73,8 +73,15 @@ class RemoteOptions:
class NumberDisplay(str, Enum):
number = "number"
slider = "slider"
gradient_slider = "gradientslider"
class ControlAfterGenerate(str, Enum):
fixed = "fixed"
increment = "increment"
decrement = "decrement"
randomize = "randomize"
class _ComfyType(ABC):
Type = Any
io_type: str = None
@@ -263,7 +270,7 @@ class Int(ComfyTypeIO):
class Input(WidgetInput):
'''Integer input.'''
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
default: int=None, min: int=None, max: int=None, step: int=None, control_after_generate: bool=None,
default: int=None, min: int=None, max: int=None, step: int=None, control_after_generate: bool | ControlAfterGenerate=None,
display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None):
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link, advanced)
self.min = min
@@ -290,13 +297,15 @@ class Float(ComfyTypeIO):
'''Float input.'''
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
default: float=None, min: float=None, max: float=None, step: float=None, round: float=None,
display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None):
display_mode: NumberDisplay=None, gradient_stops: list[list[float]]=None,
socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None):
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link, advanced)
self.min = min
self.max = max
self.step = step
self.round = round
self.display_mode = display_mode
self.gradient_stops = gradient_stops
self.default: float
def as_dict(self):
@@ -306,6 +315,7 @@ class Float(ComfyTypeIO):
"step": self.step,
"round": self.round,
"display": self.display_mode,
"gradient_stops": self.gradient_stops,
})
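Together with the new NumberDisplay.gradient_slider, this lets a float widget render over a gradient. A hedged sketch; the [position, value] interpretation of each stop is assumed from the list[list[float]] hint, not documented in this diff:
from comfy_api.latest import IO

opacity = IO.Float.Input(
    "opacity",
    default=1.0,
    min=0.0,
    max=1.0,
    step=0.01,
    display_mode=IO.NumberDisplay.gradient_slider,
    gradient_stops=[[0.0, 0.0], [0.5, 0.8], [1.0, 1.0]],  # assumed stop format
)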
@comfytype(io_type="STRING")
@@ -345,7 +355,7 @@ class Combo(ComfyTypeIO):
tooltip: str=None,
lazy: bool=None,
default: str | int | Enum = None,
control_after_generate: bool=None,
control_after_generate: bool | ControlAfterGenerate=None,
upload: UploadType=None,
image_folder: FolderType=None,
remote: RemoteOptions=None,
@@ -389,7 +399,7 @@ class MultiCombo(ComfyTypeI):
Type = list[str]
class Input(Combo.Input):
def __init__(self, id: str, options: list[str], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
default: list[str]=None, placeholder: str=None, chip: bool=None, control_after_generate: bool=None,
default: list[str]=None, placeholder: str=None, chip: bool=None, control_after_generate: bool | ControlAfterGenerate=None,
socketless: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None):
super().__init__(id, options, display_name, optional, tooltip, lazy, default, control_after_generate, socketless=socketless, extra_dict=extra_dict, raw_link=raw_link, advanced=advanced)
self.multiselect = True
@@ -1203,54 +1213,89 @@ class Color(ComfyTypeIO):
def as_dict(self):
return super().as_dict()
@comfytype(io_type="COLOR_CORRECT")
class ColorCorrect(ComfyTypeIO):
Type = dict
@comfytype(io_type="BOUNDING_BOX")
class BoundingBox(ComfyTypeIO):
class BoundingBoxDict(TypedDict):
x: int
y: int
width: int
height: int
Type = BoundingBoxDict
class Input(WidgetInput):
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None,
socketless: bool=True, default: dict=None, advanced: bool=None):
socketless: bool=True, default: dict=None, component: str=None):
super().__init__(id, display_name, optional, tooltip, None, default, socketless)
self.component = component
if default is None:
self.default = {"x": 0, "y": 0, "width": 512, "height": 512}
def as_dict(self):
d = super().as_dict()
if self.component:
d["component"] = self.component
return d
@comfytype(io_type="CURVE")
class Curve(ComfyTypeIO):
Type = list
class Input(WidgetInput):
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None,
socketless: bool=True, default: list=None, advanced: bool=None):
super().__init__(id, display_name, optional, tooltip, None, default, socketless, None, None, None, None, advanced)
if default is None:
self.default = {
"temperature": 0,
"hue": 0,
"brightness": 0,
"contrast": 0,
"saturation": 0,
"gamma": 1.0
}
self.default = [[0, 0], [1, 1]]
def as_dict(self):
return super().as_dict()
@comfytype(io_type="COLOR_BALANCE")
class ColorBalance(ComfyTypeIO):
Type = dict
@comfytype(io_type="RANGE")
class Range(ComfyTypeIO):
Type = dict # {"min": float, "max": float, "midpoint"?: float}
class Input(WidgetInput):
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None,
socketless: bool=True, default: dict=None, advanced: bool=None):
socketless: bool=True, default: dict=None,
display: str=None,
gradient_stops: list=None,
show_midpoint: bool=None,
midpoint_scale: str=None,
value_min: float=None,
value_max: float=None,
advanced: bool=None):
super().__init__(id, display_name, optional, tooltip, None, default, socketless, None, None, None, None, advanced)
if default is None:
self.default = {
"shadows_red": 0,
"shadows_green": 0,
"shadows_blue": 0,
"midtones_red": 0,
"midtones_green": 0,
"midtones_blue": 0,
"highlights_red": 0,
"highlights_green": 0,
"highlights_blue": 0
}
self.default = {"min": 0.0, "max": 1.0}
self.display = display
self.gradient_stops = gradient_stops
self.show_midpoint = show_midpoint
self.midpoint_scale = midpoint_scale
self.value_min = value_min
self.value_max = value_max
def as_dict(self):
return super().as_dict()
return super().as_dict() | prune_dict({
"display": self.display,
"gradient_stops": self.gradient_stops,
"show_midpoint": self.show_midpoint,
"midpoint_scale": self.midpoint_scale,
"value_min": self.value_min,
"value_max": self.value_max,
})
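A hedged sketch of the new RANGE widget: a {min, max} pair with an optional midpoint handle and display bounds. The accepted midpoint_scale strings are not spelled out in this diff, so the value below is an assumption:
from comfy_api.latest import IO

levels = IO.Range.Input(
    "levels",
    default={"min": 0.0, "max": 1.0},
    show_midpoint=True,
    midpoint_scale="linear",  # assumed value
    value_min=0.0,
    value_max=1.0,
)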
@comfytype(io_type="COLOR_CURVES")
class ColorCurves(ComfyTypeIO):
Type = dict
class ColorCurvesDict(TypedDict):
rgb: list[list[float]]
red: list[list[float]]
green: list[list[float]]
blue: list[list[float]]
Type = ColorCurvesDict
class Input(WidgetInput):
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None,
@@ -1267,6 +1312,7 @@ class ColorCurves(ComfyTypeIO):
def as_dict(self):
return super().as_dict()
DYNAMIC_INPUT_LOOKUP: dict[str, Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]] = {}
def register_dynamic_input_func(io_type: str, func: Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]):
DYNAMIC_INPUT_LOOKUP[io_type] = func
@@ -1373,6 +1419,7 @@ class NodeInfoV1:
api_node: bool=None
price_badge: dict | None = None
search_aliases: list[str]=None
essentials_category: str=None
@dataclass
@@ -1494,6 +1541,8 @@ class Schema:
"""Flags a node as expandable, allowing NodeOutput to include 'expand' property."""
accept_all_inputs: bool=False
"""When True, all inputs from the prompt will be passed to the node as kwargs, even if not defined in the schema."""
essentials_category: str | None = None
"""Optional category for the Essentials tab. Path-based like category field (e.g., 'Basic', 'Image Tools/Editing')."""
def validate(self):
'''Validate the schema:
@@ -1600,6 +1649,7 @@ class Schema:
python_module=getattr(cls, "RELATIVE_PYTHON_MODULE", "nodes"),
price_badge=self.price_badge.as_dict(self.inputs) if self.price_badge is not None else None,
search_aliases=self.search_aliases if self.search_aliases else None,
essentials_category=self.essentials_category,
)
return info
@@ -2094,11 +2144,74 @@ class _UIOutput(ABC):
...
class InputMapOldId(TypedDict):
"""Map an old node input to a new node input by ID."""
new_id: str
old_id: str
class InputMapSetValue(TypedDict):
"""Set a specific value for a new node input."""
new_id: str
set_value: Any
InputMap = InputMapOldId | InputMapSetValue
"""
Input mapping for node replacement. Type is inferred by dictionary keys:
- {"new_id": str, "old_id": str} - maps old input to new input
- {"new_id": str, "set_value": Any} - sets a specific value for new input
"""
class OutputMap(TypedDict):
"""Map outputs of node replacement via indexes."""
new_idx: int
old_idx: int
class NodeReplace:
"""
Defines a possible node replacement, mapping inputs and outputs of the old node to the new node.
Also supports assigning specific values to the input widgets of the new node.
Args:
new_node_id: The class name of the new replacement node.
old_node_id: The class name of the deprecated node.
old_widget_ids: Ordered list of input IDs for widgets that may not have an input slot
connected. The workflow JSON stores widget values by their relative position index,
not by ID. This list maps those positional indexes to input IDs, enabling the
replacement system to correctly identify widget values during node migration.
input_mapping: List of input mappings from old node to new node.
output_mapping: List of output mappings from old node to new node.
"""
def __init__(self,
new_node_id: str,
old_node_id: str,
old_widget_ids: list[str] | None=None,
input_mapping: list[InputMap] | None=None,
output_mapping: list[OutputMap] | None=None,
):
self.new_node_id = new_node_id
self.old_node_id = old_node_id
self.old_widget_ids = old_widget_ids
self.input_mapping = input_mapping
self.output_mapping = output_mapping
def as_dict(self):
"""Create serializable representation of the node replacement."""
return {
"new_node_id": self.new_node_id,
"old_node_id": self.old_node_id,
"old_widget_ids": self.old_widget_ids,
"input_mapping": list(self.input_mapping) if self.input_mapping else None,
"output_mapping": list(self.output_mapping) if self.output_mapping else None,
}
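Tying the pieces together: a replacement is declared with NodeReplace and handed to the NodeReplacement API added earlier, while the node_replacements server feature flag advertises the capability to the frontend. A hedged sketch with hypothetical node names; how an extension obtains the api handle is not shown in this diff:
from comfy_api.latest import IO

replacement = IO.NodeReplace(
    new_node_id="ImageScaleV2",   # hypothetical class names
    old_node_id="ImageScale",
    old_widget_ids=["width", "height"],
    input_mapping=[
        {"new_id": "width", "old_id": "width"},
        {"new_id": "height", "old_id": "height"},
        {"new_id": "method", "set_value": "lanczos"},  # keys select the mapping type
    ],
    output_mapping=[{"new_idx": 0, "old_idx": 0}],
)
# await api.node_replacement.register(replacement)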
__all__ = [
"FolderType",
"UploadType",
"RemoteOptions",
"NumberDisplay",
"ControlAfterGenerate",
"comfytype",
"Custom",
@@ -2185,7 +2298,8 @@ __all__ = [
"ImageCompare",
"PriceBadgeDepends",
"PriceBadge",
"ColorCorrect",
"ColorBalance",
"ColorCurves"
"BoundingBox",
"Curve",
"ColorCurves",
"NodeReplace",
]

View File

@@ -45,17 +45,55 @@ class BriaEditImageRequest(BaseModel):
)
class BriaRemoveBackgroundRequest(BaseModel):
image: str = Field(...)
sync: bool = Field(False)
visual_input_content_moderation: bool = Field(
False, description="If true, returns 422 on input image moderation failure."
)
visual_output_content_moderation: bool = Field(
False, description="If true, returns 422 on visual output moderation failure."
)
seed: int = Field(...)
class BriaStatusResponse(BaseModel):
request_id: str = Field(...)
status_url: str = Field(...)
warning: str | None = Field(None)
class BriaResult(BaseModel):
class BriaRemoveBackgroundResult(BaseModel):
image_url: str = Field(...)
class BriaRemoveBackgroundResponse(BaseModel):
status: str = Field(...)
result: BriaRemoveBackgroundResult | None = Field(None)
class BriaImageEditResult(BaseModel):
structured_prompt: str = Field(...)
image_url: str = Field(...)
class BriaResponse(BaseModel):
class BriaImageEditResponse(BaseModel):
status: str = Field(...)
result: BriaResult | None = Field(None)
result: BriaImageEditResult | None = Field(None)
class BriaRemoveVideoBackgroundRequest(BaseModel):
video: str = Field(...)
background_color: str = Field(default="transparent", description="Background color for the output video.")
output_container_and_codec: str = Field(...)
preserve_audio: bool = Field(True)
seed: int = Field(...)
class BriaRemoveVideoBackgroundResult(BaseModel):
video_url: str = Field(...)
class BriaRemoveVideoBackgroundResponse(BaseModel):
status: str = Field(...)
result: BriaRemoveVideoBackgroundResult | None = Field(None)
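The node implementations later in this diff all follow the same submit-then-poll shape against these models: the POST returns a BriaStatusResponse with a request_id, and only the final result model differs per endpoint. A hedged outline using the helpers those nodes import:
async def bria_roundtrip(cls, path, request_data, result_model):
    submitted = await sync_op(
        cls, ApiEndpoint(path=path, method="POST"),
        data=request_data, response_model=BriaStatusResponse,
    )
    return await poll_op(
        cls, ApiEndpoint(path=f"/proxy/bria/v2/status/{submitted.request_id}"),
        status_extractor=lambda r: r.status, response_model=result_model,
    )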

View File

@@ -0,0 +1,88 @@
from pydantic import BaseModel, Field
class SpeechToTextRequest(BaseModel):
model_id: str = Field(...)
cloud_storage_url: str = Field(...)
language_code: str | None = Field(None, description="ISO-639-1 or ISO-639-3 language code")
tag_audio_events: bool | None = Field(None, description="Annotate sounds like (laughter) in transcript")
num_speakers: int | None = Field(None, description="Max speakers predicted")
timestamps_granularity: str = Field(default="word", description="Timing precision: none, word, or character")
diarize: bool | None = Field(None, description="Annotate which speaker is talking")
diarization_threshold: float | None = Field(None, description="Speaker separation sensitivity")
temperature: float | None = Field(None, description="Randomness control")
seed: int = Field(..., description="Seed for deterministic sampling")
class SpeechToTextWord(BaseModel):
text: str = Field(..., description="The word text")
type: str = Field(default="word", description="Type of text element (word, spacing, etc.)")
start: float | None = Field(None, description="Start time in seconds (when timestamps enabled)")
end: float | None = Field(None, description="End time in seconds (when timestamps enabled)")
speaker_id: str | None = Field(None, description="Speaker identifier when diarization is enabled")
logprob: float | None = Field(None, description="Log probability of the word")
class SpeechToTextResponse(BaseModel):
language_code: str = Field(..., description="Detected or specified language code")
language_probability: float | None = Field(None, description="Confidence of language detection")
text: str = Field(..., description="Full transcript text")
words: list[SpeechToTextWord] | None = Field(None, description="Word-level timing information")
class TextToSpeechVoiceSettings(BaseModel):
stability: float | None = Field(None, description="Voice stability")
similarity_boost: float | None = Field(None, description="Similarity boost")
style: float | None = Field(None, description="Style exaggeration")
use_speaker_boost: bool | None = Field(None, description="Boost similarity to original speaker")
speed: float | None = Field(None, description="Speech speed")
class TextToSpeechRequest(BaseModel):
text: str = Field(..., description="Text to convert to speech")
model_id: str = Field(..., description="Model ID for TTS")
language_code: str | None = Field(None, description="ISO-639-1 or ISO-639-3 language code")
voice_settings: TextToSpeechVoiceSettings | None = Field(None, description="Voice settings")
seed: int = Field(..., description="Seed for deterministic sampling")
apply_text_normalization: str | None = Field(None, description="Text normalization mode: auto, on, off")
class TextToSoundEffectsRequest(BaseModel):
text: str = Field(..., description="Text prompt to convert into a sound effect")
duration_seconds: float = Field(..., description="Duration of generated sound in seconds")
prompt_influence: float = Field(..., description="How closely generation follows the prompt")
loop: bool | None = Field(None, description="Whether to create a smoothly looping sound effect")
class AddVoiceRequest(BaseModel):
name: str = Field(..., description="Name that identifies the voice")
remove_background_noise: bool = Field(..., description="Remove background noise from voice samples")
class AddVoiceResponse(BaseModel):
voice_id: str = Field(..., description="The newly created voice's unique identifier")
class SpeechToSpeechRequest(BaseModel):
model_id: str = Field(..., description="Model ID for speech-to-speech")
voice_settings: str = Field(..., description="JSON string of voice settings")
seed: int = Field(..., description="Seed for deterministic sampling")
remove_background_noise: bool = Field(..., description="Remove background noise from input audio")
class DialogueInput(BaseModel):
text: str = Field(..., description="Text content to convert to speech")
voice_id: str = Field(..., description="Voice identifier for this dialogue segment")
class DialogueSettings(BaseModel):
stability: float | None = Field(None, description="Voice stability (0-1)")
class TextToDialogueRequest(BaseModel):
inputs: list[DialogueInput] = Field(..., description="List of dialogue segments")
model_id: str = Field(..., description="Model ID for dialogue generation")
language_code: str | None = Field(None, description="ISO-639-1 language code")
settings: DialogueSettings | None = Field(None, description="Voice settings")
seed: int | None = Field(None, description="Seed for deterministic sampling")
apply_text_normalization: str | None = Field(None, description="Text normalization mode: auto, on, off")

View File

@@ -116,9 +116,15 @@ class GeminiGenerationConfig(BaseModel):
topP: float | None = Field(None, ge=0.0, le=1.0)
class GeminiImageOutputOptions(BaseModel):
mimeType: str = Field("image/png")
compressionQuality: int | None = Field(None)
class GeminiImageConfig(BaseModel):
aspectRatio: str | None = Field(None)
imageSize: str | None = Field(None)
imageOutputOptions: GeminiImageOutputOptions = Field(default_factory=GeminiImageOutputOptions)
class GeminiImageGenerationConfig(GeminiGenerationConfig):

View File

@@ -64,3 +64,23 @@ class To3DProTaskResultResponse(BaseModel):
class To3DProTaskQueryRequest(BaseModel):
JobId: str = Field(...)
class To3DUVFileInput(BaseModel):
Type: str = Field(..., description="File type: GLB, OBJ, or FBX")
Url: str = Field(...)
class To3DUVTaskRequest(BaseModel):
File: To3DUVFileInput = Field(...)
class TextureEditImageInfo(BaseModel):
Url: str = Field(...)
class TextureEditTaskRequest(BaseModel):
File3D: To3DUVFileInput = Field(...)
Image: TextureEditImageInfo | None = Field(None)
Prompt: str | None = Field(None)
EnablePBR: bool | None = Field(None)

View File

@@ -198,11 +198,6 @@ dict_recraft_substyles_v3 = {
}
class RecraftModel(str, Enum):
recraftv3 = 'recraftv3'
recraftv2 = 'recraftv2'
class RecraftImageSize(str, Enum):
res_1024x1024 = '1024x1024'
res_1365x1024 = '1365x1024'
@@ -221,6 +216,41 @@ class RecraftImageSize(str, Enum):
res_1707x1024 = '1707x1024'
RECRAFT_V4_SIZES = [
"1024x1024",
"1536x768",
"768x1536",
"1280x832",
"832x1280",
"1216x896",
"896x1216",
"1152x896",
"896x1152",
"832x1344",
"1280x896",
"896x1280",
"1344x768",
"768x1344",
]
RECRAFT_V4_PRO_SIZES = [
"2048x2048",
"3072x1536",
"1536x3072",
"2560x1664",
"1664x2560",
"2432x1792",
"1792x2432",
"2304x1792",
"1792x2304",
"1664x2688",
"1434x1024",
"1024x1434",
"2560x1792",
"1792x2560",
]
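With size relaxed to a plain string in the request model below, a node can validate against the list for the chosen model tier before building the request. A sketch, not the nodes' actual validation code:
def check_size(size: str, pro: bool) -> str:
    allowed = RECRAFT_V4_PRO_SIZES if pro else RECRAFT_V4_SIZES
    if size not in allowed:
        raise ValueError(f"Unsupported size {size!r}; expected one of {allowed}")
    return size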
class RecraftColorObject(BaseModel):
rgb: list[int] = Field(..., description='An array of 3 integer values in range of 0...255 defining RGB Color Model')
@@ -234,17 +264,16 @@ class RecraftControlsObject(BaseModel):
class RecraftImageGenerationRequest(BaseModel):
prompt: str = Field(..., description='The text prompt describing the image to generate')
size: RecraftImageSize | None = Field(None, description='The size of the generated image (e.g., "1024x1024")')
size: str | None = Field(None, description='The size of the generated image (e.g., "1024x1024")')
n: int = Field(..., description='The number of images to generate')
negative_prompt: str | None = Field(None, description='A text description of undesired elements on an image')
model: RecraftModel | None = Field(RecraftModel.recraftv3, description='The model to use for generation (e.g., "recraftv3")')
model: str = Field(...)
style: str | None = Field(None, description='The style to apply to the generated image (e.g., "digital_illustration")')
substyle: str | None = Field(None, description='The substyle to apply to the generated image, depending on the style input')
controls: RecraftControlsObject | None = Field(None, description='A set of custom parameters to tweak generation process')
style_id: str | None = Field(None, description='Use a previously uploaded style as a reference; UUID')
strength: float | None = Field(None, description='Defines the difference with the original image, should lie in [0, 1], where 0 means almost identical, and 1 means miserable similarity')
random_seed: int | None = Field(None, description="Seed for image generation")
# text_layout
class RecraftReturnedObject(BaseModel):

View File

@@ -57,6 +57,7 @@ class FluxProUltraImageNode(IO.ComfyNode):
tooltip="Whether to perform upsampling on the prompt. "
"If active, automatically modifies the prompt for more creative generation, "
"but results are nondeterministic (same seed will not produce exactly the same result).",
advanced=True,
),
IO.Int.Input(
"seed",
@@ -200,6 +201,7 @@ class FluxKontextProImageNode(IO.ComfyNode):
"prompt_upsampling",
default=False,
tooltip="Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).",
advanced=True,
),
IO.Image.Input(
"input_image",
@@ -296,6 +298,7 @@ class FluxProExpandNode(IO.ComfyNode):
tooltip="Whether to perform upsampling on the prompt. "
"If active, automatically modifies the prompt for more creative generation, "
"but results are nondeterministic (same seed will not produce exactly the same result).",
advanced=True,
),
IO.Int.Input(
"top",
@@ -433,6 +436,7 @@ class FluxProFillNode(IO.ComfyNode):
tooltip="Whether to perform upsampling on the prompt. "
"If active, automatically modifies the prompt for more creative generation, "
"but results are nondeterministic (same seed will not produce exactly the same result).",
advanced=True,
),
IO.Float.Input(
"guidance",
@@ -577,6 +581,7 @@ class Flux2ProImageNode(IO.ComfyNode):
default=True,
tooltip="Whether to perform upsampling on the prompt. "
"If active, automatically modifies the prompt for more creative generation.",
advanced=True,
),
IO.Image.Input("images", optional=True, tooltip="Up to 9 images to be used as references."),
],

View File

@@ -3,7 +3,11 @@ from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api_nodes.apis.bria import (
BriaEditImageRequest,
BriaResponse,
BriaRemoveBackgroundRequest,
BriaRemoveBackgroundResponse,
BriaRemoveVideoBackgroundRequest,
BriaRemoveVideoBackgroundResponse,
BriaImageEditResponse,
BriaStatusResponse,
InputModerationSettings,
)
@@ -11,10 +15,12 @@ from comfy_api_nodes.util import (
ApiEndpoint,
convert_mask_to_image,
download_url_to_image_tensor,
get_number_of_images,
download_url_to_video_output,
poll_op,
sync_op,
upload_images_to_comfyapi,
upload_image_to_comfyapi,
upload_video_to_comfyapi,
validate_video_duration,
)
@@ -73,21 +79,15 @@ class BriaImageEditNode(IO.ComfyNode):
IO.DynamicCombo.Input(
"moderation",
options=[
IO.DynamicCombo.Option("false", []),
IO.DynamicCombo.Option(
"true",
[
IO.Boolean.Input(
"prompt_content_moderation", default=False
),
IO.Boolean.Input(
"visual_input_moderation", default=False
),
IO.Boolean.Input(
"visual_output_moderation", default=True
),
IO.Boolean.Input("prompt_content_moderation", default=False),
IO.Boolean.Input("visual_input_moderation", default=False),
IO.Boolean.Input("visual_output_moderation", default=True),
],
),
IO.DynamicCombo.Option("false", []),
],
tooltip="Moderation settings",
),
@@ -127,50 +127,26 @@ class BriaImageEditNode(IO.ComfyNode):
mask: Input.Image | None = None,
) -> IO.NodeOutput:
if not prompt and not structured_prompt:
raise ValueError(
"One of prompt or structured_prompt is required to be non-empty."
)
if get_number_of_images(image) != 1:
raise ValueError("Exactly one input image is required.")
raise ValueError("One of prompt or structured_prompt is required to be non-empty.")
mask_url = None
if mask is not None:
mask_url = (
await upload_images_to_comfyapi(
cls,
convert_mask_to_image(mask),
max_images=1,
mime_type="image/png",
wait_label="Uploading mask",
)
)[0]
mask_url = await upload_image_to_comfyapi(cls, convert_mask_to_image(mask), wait_label="Uploading mask")
response = await sync_op(
cls,
ApiEndpoint(path="proxy/bria/v2/image/edit", method="POST"),
data=BriaEditImageRequest(
instruction=prompt if prompt else None,
structured_instruction=structured_prompt if structured_prompt else None,
images=await upload_images_to_comfyapi(
cls,
image,
max_images=1,
mime_type="image/png",
wait_label="Uploading image",
),
images=[await upload_image_to_comfyapi(cls, image, wait_label="Uploading image")],
mask=mask_url,
negative_prompt=negative_prompt if negative_prompt else None,
guidance_scale=guidance_scale,
seed=seed,
model_version=model,
steps_num=steps,
prompt_content_moderation=moderation.get(
"prompt_content_moderation", False
),
visual_input_content_moderation=moderation.get(
"visual_input_moderation", False
),
visual_output_content_moderation=moderation.get(
"visual_output_moderation", False
),
prompt_content_moderation=moderation.get("prompt_content_moderation", False),
visual_input_content_moderation=moderation.get("visual_input_moderation", False),
visual_output_content_moderation=moderation.get("visual_output_moderation", False),
),
response_model=BriaStatusResponse,
)
@@ -178,7 +154,7 @@ class BriaImageEditNode(IO.ComfyNode):
cls,
ApiEndpoint(path=f"/proxy/bria/v2/status/{response.request_id}"),
status_extractor=lambda r: r.status,
response_model=BriaResponse,
response_model=BriaImageEditResponse,
)
return IO.NodeOutput(
await download_url_to_image_tensor(response.result.image_url),
@@ -186,11 +162,167 @@ class BriaImageEditNode(IO.ComfyNode):
)
class BriaRemoveImageBackground(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="BriaRemoveImageBackground",
display_name="Bria Remove Image Background",
category="api node/image/Bria",
description="Remove the background from an image using Bria RMBG 2.0.",
inputs=[
IO.Image.Input("image"),
IO.DynamicCombo.Input(
"moderation",
options=[
IO.DynamicCombo.Option("false", []),
IO.DynamicCombo.Option(
"true",
[
IO.Boolean.Input("visual_input_moderation", default=False),
IO.Boolean.Input("visual_output_moderation", default=True),
],
),
],
tooltip="Moderation settings",
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
tooltip="Seed controls whether the node should re-run; "
"results are non-deterministic regardless of seed.",
),
],
outputs=[IO.Image.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.018}""",
),
)
@classmethod
async def execute(
cls,
image: Input.Image,
moderation: dict,
seed: int,
) -> IO.NodeOutput:
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/bria/v2/image/edit/remove_background", method="POST"),
data=BriaRemoveBackgroundRequest(
image=await upload_image_to_comfyapi(cls, image, wait_label="Uploading image"),
sync=False,
visual_input_content_moderation=moderation.get("visual_input_moderation", False),
visual_output_content_moderation=moderation.get("visual_output_moderation", False),
seed=seed,
),
response_model=BriaStatusResponse,
)
response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/bria/v2/status/{response.request_id}"),
status_extractor=lambda r: r.status,
response_model=BriaRemoveBackgroundResponse,
)
return IO.NodeOutput(await download_url_to_image_tensor(response.result.image_url))
class BriaRemoveVideoBackground(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="BriaRemoveVideoBackground",
display_name="Bria Remove Video Background",
category="api node/video/Bria",
description="Remove the background from a video using Bria. ",
inputs=[
IO.Video.Input("video"),
IO.Combo.Input(
"background_color",
options=[
"Black",
"White",
"Gray",
"Red",
"Green",
"Blue",
"Yellow",
"Cyan",
"Magenta",
"Orange",
],
tooltip="Background color for the output video.",
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
tooltip="Seed controls whether the node should re-run; "
"results are non-deterministic regardless of seed.",
),
],
outputs=[IO.Video.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.14,"format":{"suffix":"/second"}}""",
),
)
@classmethod
async def execute(
cls,
video: Input.Video,
background_color: str,
seed: int,
) -> IO.NodeOutput:
validate_video_duration(video, max_duration=60.0)
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/bria/v2/video/edit/remove_background", method="POST"),
data=BriaRemoveVideoBackgroundRequest(
video=await upload_video_to_comfyapi(cls, video),
background_color=background_color,
output_container_and_codec="mp4_h264",
seed=seed,
),
response_model=BriaStatusResponse,
)
response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/bria/v2/status/{response.request_id}"),
status_extractor=lambda r: r.status,
response_model=BriaRemoveVideoBackgroundResponse,
)
return IO.NodeOutput(await download_url_to_video_output(response.result.video_url))
class BriaExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
BriaImageEditNode,
BriaRemoveImageBackground,
BriaRemoveVideoBackground,
]

View File

@@ -114,6 +114,7 @@ class ByteDanceImageNode(IO.ComfyNode):
default=False,
tooltip='Whether to add an "AI generated" watermark to the image',
optional=True,
advanced=True,
),
],
outputs=[
@@ -259,12 +260,14 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
default=False,
tooltip='Whether to add an "AI generated" watermark to the image.',
optional=True,
advanced=True,
),
IO.Boolean.Input(
"fail_on_partial",
default=True,
tooltip="If enabled, abort execution if any requested images are missing or return an error.",
optional=True,
advanced=True,
),
],
outputs=[
@@ -432,18 +435,21 @@ class ByteDanceTextToVideoNode(IO.ComfyNode):
tooltip="Specifies whether to fix the camera. The platform appends an instruction "
"to fix the camera to your prompt, but does not guarantee the actual effect.",
optional=True,
advanced=True,
),
IO.Boolean.Input(
"watermark",
default=False,
tooltip='Whether to add an "AI generated" watermark to the video.',
optional=True,
advanced=True,
),
IO.Boolean.Input(
"generate_audio",
default=False,
tooltip="This parameter is ignored for any model except seedance-1-5-pro.",
optional=True,
advanced=True,
),
],
outputs=[
@@ -561,18 +567,21 @@ class ByteDanceImageToVideoNode(IO.ComfyNode):
tooltip="Specifies whether to fix the camera. The platform appends an instruction "
"to fix the camera to your prompt, but does not guarantee the actual effect.",
optional=True,
advanced=True,
),
IO.Boolean.Input(
"watermark",
default=False,
tooltip='Whether to add an "AI generated" watermark to the video.',
optional=True,
advanced=True,
),
IO.Boolean.Input(
"generate_audio",
default=False,
tooltip="This parameter is ignored for any model except seedance-1-5-pro.",
optional=True,
advanced=True,
),
],
outputs=[
@@ -694,18 +703,21 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode):
tooltip="Specifies whether to fix the camera. The platform appends an instruction "
"to fix the camera to your prompt, but does not guarantee the actual effect.",
optional=True,
advanced=True,
),
IO.Boolean.Input(
"watermark",
default=False,
tooltip='Whether to add an "AI generated" watermark to the video.',
optional=True,
advanced=True,
),
IO.Boolean.Input(
"generate_audio",
default=False,
tooltip="This parameter is ignored for any model except seedance-1-5-pro.",
optional=True,
advanced=True,
),
],
outputs=[
@@ -834,6 +846,7 @@ class ByteDanceImageReferenceNode(IO.ComfyNode):
default=False,
tooltip='Whether to add an "AI generated" watermark to the video.',
optional=True,
advanced=True,
),
],
outputs=[

View File

@@ -0,0 +1,924 @@
import json
import uuid
from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api_nodes.apis.elevenlabs import (
AddVoiceRequest,
AddVoiceResponse,
DialogueInput,
DialogueSettings,
SpeechToSpeechRequest,
SpeechToTextRequest,
SpeechToTextResponse,
TextToDialogueRequest,
TextToSoundEffectsRequest,
TextToSpeechRequest,
TextToSpeechVoiceSettings,
)
from comfy_api_nodes.util import (
ApiEndpoint,
audio_bytes_to_audio_input,
audio_ndarray_to_bytesio,
audio_tensor_to_contiguous_ndarray,
sync_op,
sync_op_raw,
upload_audio_to_comfyapi,
validate_string,
)
ELEVENLABS_MUSIC_SECTIONS = "ELEVENLABS_MUSIC_SECTIONS" # Custom type for music sections
ELEVENLABS_COMPOSITION_PLAN = "ELEVENLABS_COMPOSITION_PLAN" # Custom type for composition plan
ELEVENLABS_VOICE = "ELEVENLABS_VOICE" # Custom type for voice selection
# Predefined ElevenLabs voices: (voice_id, display_name, gender, accent)
ELEVENLABS_VOICES = [
("CwhRBWXzGAHq8TQ4Fs17", "Roger", "male", "american"),
("EXAVITQu4vr4xnSDxMaL", "Sarah", "female", "american"),
("FGY2WhTYpPnrIDTdsKH5", "Laura", "female", "american"),
("IKne3meq5aSn9XLyUdCD", "Charlie", "male", "australian"),
("JBFqnCBsd6RMkjVDRZzb", "George", "male", "british"),
("N2lVS1w4EtoT3dr4eOWO", "Callum", "male", "american"),
("SAz9YHcvj6GT2YYXdXww", "River", "neutral", "american"),
("SOYHLrjzK2X1ezoPC6cr", "Harry", "male", "american"),
("TX3LPaxmHKxFdv7VOQHJ", "Liam", "male", "american"),
("Xb7hH8MSUJpSbSDYk0k2", "Alice", "female", "british"),
("XrExE9yKIg1WjnnlVkGX", "Matilda", "female", "american"),
("bIHbv24MWmeRgasZH58o", "Will", "male", "american"),
("cgSgspJ2msm6clMCkdW9", "Jessica", "female", "american"),
("cjVigY5qzO86Huf0OWal", "Eric", "male", "american"),
("hpp4J3VqNfWAUOO0d1Us", "Bella", "female", "american"),
("iP95p4xoKVk53GoZ742B", "Chris", "male", "american"),
("nPczCjzI2devNBz1zQrb", "Brian", "male", "american"),
("onwK4e9ZLuTAKqWW03F9", "Daniel", "male", "british"),
("pFZP5JQG7iQjIQuC4Bku", "Lily", "female", "british"),
("pNInz6obpgDQGcFmaJgB", "Adam", "male", "american"),
("pqHfZKP75CvOlQylNhV4", "Bill", "male", "american"),
]
ELEVENLABS_VOICE_OPTIONS = [f"{name} ({gender}, {accent})" for _, name, gender, accent in ELEVENLABS_VOICES]
ELEVENLABS_VOICE_MAP = {
f"{name} ({gender}, {accent})": voice_id for voice_id, name, gender, accent in ELEVENLABS_VOICES
}
class ElevenLabsSpeechToText(IO.ComfyNode):
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="ElevenLabsSpeechToText",
display_name="ElevenLabs Speech to Text",
category="api node/audio/ElevenLabs",
description="Transcribe audio to text. "
"Supports automatic language detection, speaker diarization, and audio event tagging.",
inputs=[
IO.Audio.Input(
"audio",
tooltip="Audio to transcribe.",
),
IO.DynamicCombo.Input(
"model",
options=[
IO.DynamicCombo.Option(
"scribe_v2",
[
IO.Boolean.Input(
"tag_audio_events",
default=False,
tooltip="Annotate sounds like (laughter), (music), etc. in transcript.",
),
IO.Boolean.Input(
"diarize",
default=False,
tooltip="Annotate which speaker is talking.",
),
IO.Float.Input(
"diarization_threshold",
default=0.22,
min=0.1,
max=0.4,
step=0.01,
display_mode=IO.NumberDisplay.slider,
tooltip="Speaker separation sensitivity. "
"Lower values are more sensitive to speaker changes.",
),
IO.Float.Input(
"temperature",
default=0.0,
min=0.0,
max=2.0,
step=0.01,
display_mode=IO.NumberDisplay.slider,
tooltip="Randomness control. "
"0.0 uses model default. Higher values increase randomness.",
),
IO.Combo.Input(
"timestamps_granularity",
options=["word", "character", "none"],
default="word",
tooltip="Timing precision for transcript words.",
),
],
),
],
tooltip="Model to use for transcription.",
),
IO.String.Input(
"language_code",
default="",
tooltip="ISO-639-1 or ISO-639-3 language code (e.g., 'en', 'es', 'fra'). "
"Leave empty for automatic detection.",
),
IO.Int.Input(
"num_speakers",
default=0,
min=0,
max=32,
display_mode=IO.NumberDisplay.slider,
tooltip="Maximum number of speakers to predict. Set to 0 for automatic detection.",
),
IO.Int.Input(
"seed",
default=1,
min=0,
max=2147483647,
tooltip="Seed for reproducibility (determinism not guaranteed).",
),
],
outputs=[
IO.String.Output(display_name="text"),
IO.String.Output(display_name="language_code"),
IO.String.Output(display_name="words_json"),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.0073,"format":{"approximate":true,"suffix":"/minute"}}""",
),
)
@classmethod
async def execute(
cls,
audio: Input.Audio,
model: dict,
language_code: str,
num_speakers: int,
seed: int,
) -> IO.NodeOutput:
if model["diarize"] and num_speakers:
raise ValueError(
"Number of speakers cannot be specified when diarization is enabled. "
"Either disable diarization or set num_speakers to 0."
)
request = SpeechToTextRequest(
model_id=model["model"],
cloud_storage_url=await upload_audio_to_comfyapi(
cls, audio, container_format="mp4", codec_name="aac", mime_type="audio/mp4"
),
language_code=language_code if language_code.strip() else None,
tag_audio_events=model["tag_audio_events"],
num_speakers=num_speakers if num_speakers > 0 else None,
timestamps_granularity=model["timestamps_granularity"],
diarize=model["diarize"],
diarization_threshold=model["diarization_threshold"] if model["diarize"] else None,
seed=seed,
temperature=model["temperature"],
)
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/elevenlabs/v1/speech-to-text", method="POST"),
response_model=SpeechToTextResponse,
data=request,
content_type="multipart/form-data",
)
words_json = json.dumps(
[w.model_dump(exclude_none=True) for w in response.words] if response.words else [],
indent=2,
)
return IO.NodeOutput(response.text, response.language_code, words_json)
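# A downstream sketch for consuming the words_json output above; the field
# names ("text", "start", "end") follow the usual ElevenLabs Scribe word
# schema and are an assumption here, not guaranteed by this node:
def _example_iter_words(words_json: str):
    import json  # local import keeps the sketch self-contained
    for word in json.loads(words_json):
        yield word.get("text"), word.get("start"), word.get("end")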
class ElevenLabsVoiceSelector(IO.ComfyNode):
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="ElevenLabsVoiceSelector",
display_name="ElevenLabs Voice Selector",
category="api node/audio/ElevenLabs",
description="Select a predefined ElevenLabs voice for text-to-speech generation.",
inputs=[
IO.Combo.Input(
"voice",
options=ELEVENLABS_VOICE_OPTIONS,
tooltip="Choose a voice from the predefined ElevenLabs voices.",
),
],
outputs=[
IO.Custom(ELEVENLABS_VOICE).Output(display_name="voice"),
],
is_api_node=False,
)
@classmethod
def execute(cls, voice: str) -> IO.NodeOutput:
voice_id = ELEVENLABS_VOICE_MAP.get(voice)
if not voice_id:
raise ValueError(f"Unknown voice: {voice}")
return IO.NodeOutput(voice_id)
class ElevenLabsTextToSpeech(IO.ComfyNode):
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="ElevenLabsTextToSpeech",
display_name="ElevenLabs Text to Speech",
category="api node/audio/ElevenLabs",
description="Convert text to speech.",
inputs=[
IO.Custom(ELEVENLABS_VOICE).Input(
"voice",
tooltip="Voice to use for speech synthesis. Connect from Voice Selector or Instant Voice Clone.",
),
IO.String.Input(
"text",
multiline=True,
default="",
tooltip="The text to convert to speech.",
),
IO.Float.Input(
"stability",
default=0.5,
min=0.0,
max=1.0,
step=0.01,
display_mode=IO.NumberDisplay.slider,
tooltip="Voice stability. Lower values give broader emotional range, "
"higher values produce more consistent but potentially monotonous speech.",
),
IO.Combo.Input(
"apply_text_normalization",
options=["auto", "on", "off"],
tooltip="Text normalization mode. 'auto' lets the system decide, "
"'on' always applies normalization, 'off' skips it.",
),
IO.DynamicCombo.Input(
"model",
options=[
IO.DynamicCombo.Option(
"eleven_multilingual_v2",
[
IO.Float.Input(
"speed",
default=1.0,
min=0.7,
max=1.3,
step=0.01,
display_mode=IO.NumberDisplay.slider,
tooltip="Speech speed. 1.0 is normal, <1.0 slower, >1.0 faster.",
),
IO.Float.Input(
"similarity_boost",
default=0.75,
min=0.0,
max=1.0,
step=0.01,
display_mode=IO.NumberDisplay.slider,
tooltip="Similarity boost. Higher values make the voice more similar to the original.",
),
IO.Boolean.Input(
"use_speaker_boost",
default=False,
tooltip="Boost similarity to the original speaker voice.",
),
IO.Float.Input(
"style",
default=0.0,
min=0.0,
max=0.2,
step=0.01,
display_mode=IO.NumberDisplay.slider,
tooltip="Style exaggeration. Higher values increase stylistic expression "
"but may reduce stability.",
),
],
),
IO.DynamicCombo.Option(
"eleven_v3",
[
IO.Float.Input(
"speed",
default=1.0,
min=0.7,
max=1.3,
step=0.01,
display_mode=IO.NumberDisplay.slider,
tooltip="Speech speed. 1.0 is normal, <1.0 slower, >1.0 faster.",
),
IO.Float.Input(
"similarity_boost",
default=0.75,
min=0.0,
max=1.0,
step=0.01,
display_mode=IO.NumberDisplay.slider,
tooltip="Similarity boost. Higher values make the voice more similar to the original.",
),
],
),
],
tooltip="Model to use for text-to-speech.",
),
IO.String.Input(
"language_code",
default="",
tooltip="ISO-639-1 or ISO-639-3 language code (e.g., 'en', 'es', 'fra'). "
"Leave empty for automatic detection.",
),
IO.Int.Input(
"seed",
default=1,
min=0,
max=2147483647,
tooltip="Seed for reproducibility (determinism not guaranteed).",
),
IO.Combo.Input(
"output_format",
options=["mp3_44100_192", "opus_48000_192"],
tooltip="Audio output format.",
),
],
outputs=[
IO.Audio.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.24,"format":{"approximate":true,"suffix":"/1K chars"}}""",
),
)
@classmethod
async def execute(
cls,
voice: str,
text: str,
stability: float,
apply_text_normalization: str,
model: dict,
language_code: str,
seed: int,
output_format: str,
) -> IO.NodeOutput:
validate_string(text, min_length=1)
request = TextToSpeechRequest(
text=text,
model_id=model["model"],
language_code=language_code if language_code.strip() else None,
voice_settings=TextToSpeechVoiceSettings(
stability=stability,
similarity_boost=model["similarity_boost"],
speed=model["speed"],
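                # eleven_v3 exposes only "speed" and "similarity_boost", so the
                # two .get() lookups below fall back to None for that model: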
use_speaker_boost=model.get("use_speaker_boost", None),
style=model.get("style", None),
),
seed=seed,
apply_text_normalization=apply_text_normalization,
)
response = await sync_op_raw(
cls,
ApiEndpoint(
path=f"/proxy/elevenlabs/v1/text-to-speech/{voice}",
method="POST",
query_params={"output_format": output_format},
),
data=request,
as_binary=True,
)
return IO.NodeOutput(audio_bytes_to_audio_input(response))
class ElevenLabsAudioIsolation(IO.ComfyNode):
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="ElevenLabsAudioIsolation",
display_name="ElevenLabs Voice Isolation",
category="api node/audio/ElevenLabs",
description="Remove background noise from audio, isolating vocals or speech.",
inputs=[
IO.Audio.Input(
"audio",
tooltip="Audio to process for background noise removal.",
),
],
outputs=[
IO.Audio.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.24,"format":{"approximate":true,"suffix":"/minute"}}""",
),
)
@classmethod
async def execute(
cls,
audio: Input.Audio,
) -> IO.NodeOutput:
audio_data_np = audio_tensor_to_contiguous_ndarray(audio["waveform"])
audio_bytes_io = audio_ndarray_to_bytesio(audio_data_np, audio["sample_rate"], "mp4", "aac")
response = await sync_op_raw(
cls,
ApiEndpoint(path="/proxy/elevenlabs/v1/audio-isolation", method="POST"),
files={"audio": ("audio.mp4", audio_bytes_io, "audio/mp4")},
content_type="multipart/form-data",
as_binary=True,
)
return IO.NodeOutput(audio_bytes_to_audio_input(response))
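# A minimal sketch of the tensor -> AAC-in-MP4 byte path shared by the audio
# nodes in this file; `audio` is a ComfyUI AUDIO dict with "waveform" and
# "sample_rate" keys:
def _example_audio_to_mp4_bytes(audio: Input.Audio) -> bytes:
    data = audio_tensor_to_contiguous_ndarray(audio["waveform"])
    return audio_ndarray_to_bytesio(data, audio["sample_rate"], "mp4", "aac").getvalue()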
class ElevenLabsTextToSoundEffects(IO.ComfyNode):
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="ElevenLabsTextToSoundEffects",
display_name="ElevenLabs Text to Sound Effects",
category="api node/audio/ElevenLabs",
description="Generate sound effects from text descriptions.",
inputs=[
IO.String.Input(
"text",
multiline=True,
default="",
tooltip="Text description of the sound effect to generate.",
),
IO.DynamicCombo.Input(
"model",
options=[
IO.DynamicCombo.Option(
"eleven_sfx_v2",
[
IO.Float.Input(
"duration",
default=5.0,
min=0.5,
max=30.0,
step=0.1,
display_mode=IO.NumberDisplay.slider,
tooltip="Duration of generated sound in seconds.",
),
IO.Boolean.Input(
"loop",
default=False,
tooltip="Create a smoothly looping sound effect.",
),
IO.Float.Input(
"prompt_influence",
default=0.3,
min=0.0,
max=1.0,
step=0.01,
display_mode=IO.NumberDisplay.slider,
tooltip="How closely generation follows the prompt. "
"Higher values make the sound follow the text more closely.",
),
],
),
],
tooltip="Model to use for sound effect generation.",
),
IO.Combo.Input(
"output_format",
options=["mp3_44100_192", "opus_48000_192"],
tooltip="Audio output format.",
),
],
outputs=[
IO.Audio.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.14,"format":{"approximate":true,"suffix":"/minute"}}""",
),
)
@classmethod
async def execute(
cls,
text: str,
model: dict,
output_format: str,
) -> IO.NodeOutput:
validate_string(text, min_length=1)
response = await sync_op_raw(
cls,
ApiEndpoint(
path="/proxy/elevenlabs/v1/sound-generation",
method="POST",
query_params={"output_format": output_format},
),
data=TextToSoundEffectsRequest(
text=text,
duration_seconds=model["duration"],
prompt_influence=model["prompt_influence"],
loop=model.get("loop", None),
),
as_binary=True,
)
return IO.NodeOutput(audio_bytes_to_audio_input(response))
class ElevenLabsInstantVoiceClone(IO.ComfyNode):
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="ElevenLabsInstantVoiceClone",
display_name="ElevenLabs Instant Voice Clone",
category="api node/audio/ElevenLabs",
description="Create a cloned voice from audio samples. "
"Provide 1-8 audio recordings of the voice to clone.",
inputs=[
IO.Autogrow.Input(
"files",
template=IO.Autogrow.TemplatePrefix(
IO.Audio.Input("audio"),
prefix="audio",
min=1,
max=8,
),
tooltip="Audio recordings for voice cloning.",
),
IO.Boolean.Input(
"remove_background_noise",
default=False,
tooltip="Remove background noise from voice samples using audio isolation.",
),
],
outputs=[
IO.Custom(ELEVENLABS_VOICE).Output(display_name="voice"),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(expr="""{"type":"usd","usd":0.15}"""),
)
@classmethod
async def execute(
cls,
files: IO.Autogrow.Type,
remove_background_noise: bool,
) -> IO.NodeOutput:
file_tuples: list[tuple[str, tuple[str, bytes, str]]] = []
for key in files:
audio = files[key]
sample_rate: int = audio["sample_rate"]
waveform = audio["waveform"]
audio_data_np = audio_tensor_to_contiguous_ndarray(waveform)
audio_bytes_io = audio_ndarray_to_bytesio(audio_data_np, sample_rate, "mp4", "aac")
file_tuples.append(("files", (f"{key}.mp4", audio_bytes_io.getvalue(), "audio/mp4")))
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/elevenlabs/v1/voices/add", method="POST"),
response_model=AddVoiceResponse,
data=AddVoiceRequest(
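                # The cloned voice is registered under a throwaway UUID name;
                # only the returned voice_id is surfaced to downstream nodes.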
name=str(uuid.uuid4()),
remove_background_noise=remove_background_noise,
),
files=file_tuples,
content_type="multipart/form-data",
)
return IO.NodeOutput(response.voice_id)
ELEVENLABS_STS_VOICE_SETTINGS = [
IO.Float.Input(
"speed",
default=1.0,
min=0.7,
max=1.3,
step=0.01,
display_mode=IO.NumberDisplay.slider,
tooltip="Speech speed. 1.0 is normal, <1.0 slower, >1.0 faster.",
),
IO.Float.Input(
"similarity_boost",
default=0.75,
min=0.0,
max=1.0,
step=0.01,
display_mode=IO.NumberDisplay.slider,
tooltip="Similarity boost. Higher values make the voice more similar to the original.",
),
IO.Boolean.Input(
"use_speaker_boost",
default=False,
tooltip="Boost similarity to the original speaker voice.",
),
IO.Float.Input(
"style",
default=0.0,
min=0.0,
max=0.2,
step=0.01,
display_mode=IO.NumberDisplay.slider,
tooltip="Style exaggeration. Higher values increase stylistic expression but may reduce stability.",
),
]
class ElevenLabsSpeechToSpeech(IO.ComfyNode):
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="ElevenLabsSpeechToSpeech",
display_name="ElevenLabs Speech to Speech",
category="api node/audio/ElevenLabs",
description="Transform speech from one voice to another while preserving the original content and emotion.",
inputs=[
IO.Custom(ELEVENLABS_VOICE).Input(
"voice",
tooltip="Target voice for the transformation. "
"Connect from Voice Selector or Instant Voice Clone.",
),
IO.Audio.Input(
"audio",
tooltip="Source audio to transform.",
),
IO.Float.Input(
"stability",
default=0.5,
min=0.0,
max=1.0,
step=0.01,
display_mode=IO.NumberDisplay.slider,
tooltip="Voice stability. Lower values give broader emotional range, "
"higher values produce more consistent but potentially monotonous speech.",
),
IO.DynamicCombo.Input(
"model",
options=[
IO.DynamicCombo.Option(
"eleven_multilingual_sts_v2",
ELEVENLABS_STS_VOICE_SETTINGS,
),
IO.DynamicCombo.Option(
"eleven_english_sts_v2",
ELEVENLABS_STS_VOICE_SETTINGS,
),
],
tooltip="Model to use for speech-to-speech transformation.",
),
IO.Combo.Input(
"output_format",
options=["mp3_44100_192", "opus_48000_192"],
tooltip="Audio output format.",
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=4294967295,
tooltip="Seed for reproducibility.",
),
IO.Boolean.Input(
"remove_background_noise",
default=False,
tooltip="Remove background noise from input audio using audio isolation.",
),
],
outputs=[
IO.Audio.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.24,"format":{"approximate":true,"suffix":"/minute"}}""",
),
)
@classmethod
async def execute(
cls,
voice: str,
audio: Input.Audio,
stability: float,
model: dict,
output_format: str,
seed: int,
remove_background_noise: bool,
) -> IO.NodeOutput:
audio_data_np = audio_tensor_to_contiguous_ndarray(audio["waveform"])
audio_bytes_io = audio_ndarray_to_bytesio(audio_data_np, audio["sample_rate"], "mp4", "aac")
voice_settings = TextToSpeechVoiceSettings(
stability=stability,
similarity_boost=model["similarity_boost"],
style=model["style"],
use_speaker_boost=model["use_speaker_boost"],
speed=model["speed"],
)
response = await sync_op_raw(
cls,
ApiEndpoint(
path=f"/proxy/elevenlabs/v1/speech-to-speech/{voice}",
method="POST",
query_params={"output_format": output_format},
),
data=SpeechToSpeechRequest(
model_id=model["model"],
voice_settings=voice_settings.model_dump_json(exclude_none=True),
seed=seed,
remove_background_noise=remove_background_noise,
),
files={"audio": ("audio.mp4", audio_bytes_io.getvalue(), "audio/mp4")},
content_type="multipart/form-data",
as_binary=True,
)
return IO.NodeOutput(audio_bytes_to_audio_input(response))
def _generate_dialogue_inputs(count: int) -> list:
"""Generate input widgets for a given number of dialogue entries."""
inputs = []
for i in range(1, count + 1):
inputs.extend(
[
IO.String.Input(
f"text{i}",
multiline=True,
default="",
tooltip=f"Text content for dialogue entry {i}.",
),
IO.Custom(ELEVENLABS_VOICE).Input(
f"voice{i}",
tooltip=f"Voice for dialogue entry {i}. Connect from Voice Selector or Instant Voice Clone.",
),
]
)
return inputs
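# A usage sketch, assuming the "2" option is selected: the dynamic-combo dict
# that reaches execute() carries the chosen option under "inputs" plus one
# text/voice pair per entry, e.g.
#   {"inputs": "2", "text1": ..., "voice1": ..., "text2": ..., "voice2": ...}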
class ElevenLabsTextToDialogue(IO.ComfyNode):
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="ElevenLabsTextToDialogue",
display_name="ElevenLabs Text to Dialogue",
category="api node/audio/ElevenLabs",
description="Generate multi-speaker dialogue from text. Each dialogue entry has its own text and voice.",
inputs=[
IO.Float.Input(
"stability",
default=0.5,
min=0.0,
max=1.0,
step=0.5,
display_mode=IO.NumberDisplay.slider,
tooltip="Voice stability. Lower values give broader emotional range, "
"higher values produce more consistent but potentially monotonous speech.",
),
IO.Combo.Input(
"apply_text_normalization",
options=["auto", "on", "off"],
tooltip="Text normalization mode. 'auto' lets the system decide, "
"'on' always applies normalization, 'off' skips it.",
),
IO.Combo.Input(
"model",
options=["eleven_v3"],
tooltip="Model to use for dialogue generation.",
),
IO.DynamicCombo.Input(
"inputs",
options=[
IO.DynamicCombo.Option("1", _generate_dialogue_inputs(1)),
IO.DynamicCombo.Option("2", _generate_dialogue_inputs(2)),
IO.DynamicCombo.Option("3", _generate_dialogue_inputs(3)),
IO.DynamicCombo.Option("4", _generate_dialogue_inputs(4)),
IO.DynamicCombo.Option("5", _generate_dialogue_inputs(5)),
IO.DynamicCombo.Option("6", _generate_dialogue_inputs(6)),
IO.DynamicCombo.Option("7", _generate_dialogue_inputs(7)),
IO.DynamicCombo.Option("8", _generate_dialogue_inputs(8)),
IO.DynamicCombo.Option("9", _generate_dialogue_inputs(9)),
IO.DynamicCombo.Option("10", _generate_dialogue_inputs(10)),
],
tooltip="Number of dialogue entries.",
),
IO.String.Input(
"language_code",
default="",
tooltip="ISO-639-1 or ISO-639-3 language code (e.g., 'en', 'es', 'fra'). "
"Leave empty for automatic detection.",
),
IO.Int.Input(
"seed",
default=1,
min=0,
max=4294967295,
tooltip="Seed for reproducibility.",
),
IO.Combo.Input(
"output_format",
options=["mp3_44100_192", "opus_48000_192"],
tooltip="Audio output format.",
),
],
outputs=[
IO.Audio.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.24,"format":{"approximate":true,"suffix":"/1K chars"}}""",
),
)
@classmethod
async def execute(
cls,
stability: float,
apply_text_normalization: str,
model: str,
inputs: dict,
language_code: str,
seed: int,
output_format: str,
) -> IO.NodeOutput:
num_entries = int(inputs["inputs"])
dialogue_inputs: list[DialogueInput] = []
for i in range(1, num_entries + 1):
text = inputs[f"text{i}"]
voice_id = inputs[f"voice{i}"]
validate_string(text, min_length=1)
dialogue_inputs.append(DialogueInput(text=text, voice_id=voice_id))
request = TextToDialogueRequest(
inputs=dialogue_inputs,
model_id=model,
language_code=language_code if language_code.strip() else None,
settings=DialogueSettings(stability=stability),
seed=seed,
apply_text_normalization=apply_text_normalization,
)
response = await sync_op_raw(
cls,
ApiEndpoint(
path="/proxy/elevenlabs/v1/text-to-dialogue",
method="POST",
query_params={"output_format": output_format},
),
data=request,
as_binary=True,
)
return IO.NodeOutput(audio_bytes_to_audio_input(response))
class ElevenLabsExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
ElevenLabsSpeechToText,
ElevenLabsVoiceSelector,
ElevenLabsTextToSpeech,
ElevenLabsAudioIsolation,
ElevenLabsTextToSoundEffects,
ElevenLabsInstantVoiceClone,
ElevenLabsSpeechToSpeech,
ElevenLabsTextToDialogue,
]
async def comfy_entrypoint() -> ElevenLabsExtension:
return ElevenLabsExtension()

View File

@@ -6,6 +6,7 @@ See: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/infer
import base64
import os
from enum import Enum
from fnmatch import fnmatch
from io import BytesIO
from typing import Literal
@@ -119,6 +120,13 @@ async def create_image_parts(
return image_parts
def _mime_matches(mime: GeminiMimeType | None, pattern: str) -> bool:
"""Check if a MIME type matches a pattern. Supports fnmatch globs (e.g. 'image/*')."""
if mime is None:
return False
return fnmatch(mime.value, pattern)
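# Glob semantics via fnmatch, sketched; GeminiMimeType("image/png") assumes a
# value-based Enum lookup:
#   _mime_matches(GeminiMimeType("image/png"), "image/*")   -> True
#   _mime_matches(GeminiMimeType("image/png"), "image/png") -> True
#   _mime_matches(None, "image/*")                          -> False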
def get_parts_by_type(response: GeminiGenerateContentResponse, part_type: Literal["text"] | str) -> list[GeminiPart]:
"""
Filter response parts by their type.
@@ -151,9 +159,9 @@ def get_parts_by_type(response: GeminiGenerateContentResponse, part_type: Litera
for part in candidate.content.parts:
if part_type == "text" and part.text:
parts.append(part)
- elif part.inlineData and part.inlineData.mimeType == part_type:
+ elif part.inlineData and _mime_matches(part.inlineData.mimeType, part_type):
parts.append(part)
- elif part.fileData and part.fileData.mimeType == part_type:
+ elif part.fileData and _mime_matches(part.fileData.mimeType, part_type):
parts.append(part)
if not parts and blocked_reasons:
@@ -178,7 +186,7 @@ def get_text_from_response(response: GeminiGenerateContentResponse) -> str:
async def get_image_from_response(response: GeminiGenerateContentResponse) -> Input.Image:
image_tensors: list[Input.Image] = []
- parts = get_parts_by_type(response, "image/png")
+ parts = get_parts_by_type(response, "image/*")
for part in parts:
if part.inlineData:
image_data = base64.b64decode(part.inlineData.data)
@@ -308,6 +316,7 @@ class GeminiNode(IO.ComfyNode):
default="",
optional=True,
tooltip="Foundational instructions that dictate an AI's behavior.",
advanced=True,
),
],
outputs=[
@@ -585,6 +594,7 @@ class GeminiImage(IO.ComfyNode):
tooltip="Choose 'IMAGE' for image-only output, or "
"'IMAGE+TEXT' to return both the generated image and a text response.",
optional=True,
advanced=True,
),
IO.String.Input(
"system_prompt",
@@ -592,6 +602,7 @@ class GeminiImage(IO.ComfyNode):
default=GEMINI_IMAGE_SYS_PROMPT,
optional=True,
tooltip="Foundational instructions that dictate an AI's behavior.",
advanced=True,
),
],
outputs=[
@@ -626,7 +637,7 @@ class GeminiImage(IO.ComfyNode):
if not aspect_ratio:
aspect_ratio = "auto" # for backward compatability with old workflows; to-do remove this in December
- image_config = GeminiImageConfig(aspectRatio=aspect_ratio)
+ image_config = GeminiImageConfig() if aspect_ratio == "auto" else GeminiImageConfig(aspectRatio=aspect_ratio)
if images is not None:
parts.extend(await create_image_parts(cls, images))
@@ -646,7 +657,7 @@ class GeminiImage(IO.ComfyNode):
],
generationConfig=GeminiImageGenerationConfig(
responseModalities=(["IMAGE"] if response_modalities == "IMAGE" else ["TEXT", "IMAGE"]),
- imageConfig=None if aspect_ratio == "auto" else image_config,
+ imageConfig=image_config,
),
systemInstruction=gemini_system_prompt,
),
@@ -706,6 +717,7 @@ class GeminiImage2(IO.ComfyNode):
options=["IMAGE+TEXT", "IMAGE"],
tooltip="Choose 'IMAGE' for image-only output, or "
"'IMAGE+TEXT' to return both the generated image and a text response.",
advanced=True,
),
IO.Image.Input(
"images",
@@ -725,6 +737,7 @@ class GeminiImage2(IO.ComfyNode):
default=GEMINI_IMAGE_SYS_PROMPT,
optional=True,
tooltip="Foundational instructions that dictate an AI's behavior.",
advanced=True,
),
],
outputs=[

View File

@@ -1,31 +1,48 @@
from typing_extensions import override
- from comfy_api.latest import IO, ComfyExtension, Input
+ from comfy_api.latest import IO, ComfyExtension, Input, Types
from comfy_api_nodes.apis.hunyuan3d import (
Hunyuan3DViewImage,
InputGenerateType,
ResultFile3D,
TextureEditTaskRequest,
To3DProTaskCreateResponse,
To3DProTaskQueryRequest,
To3DProTaskRequest,
To3DProTaskResultResponse,
To3DUVFileInput,
To3DUVTaskRequest,
)
from comfy_api_nodes.util import (
ApiEndpoint,
download_url_to_file_3d,
download_url_to_image_tensor,
downscale_image_tensor_by_max_side,
poll_op,
sync_op,
upload_3d_model_to_comfyapi,
upload_image_to_comfyapi,
validate_image_dimensions,
validate_string,
)
- def get_file_from_response(response_objs: list[ResultFile3D], file_type: str) -> ResultFile3D | None:
def _is_tencent_rate_limited(status: int, body: object) -> bool:
return (
status == 400
and isinstance(body, dict)
and "RequestLimitExceeded" in str(body.get("Response", {}).get("Error", {}).get("Code", ""))
)
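# A sketch of the Tencent error envelope this predicate matches (the body
# shape is illustrative):
#   _is_tencent_rate_limited(400, {"Response": {"Error": {"Code": "RequestLimitExceeded"}}})  -> True
#   _is_tencent_rate_limited(400, {"Response": {"Error": {"Code": "InvalidParameter"}}})      -> False
#   _is_tencent_rate_limited(429, {"Response": {"Error": {"Code": "RequestLimitExceeded"}}})  -> False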
+ def get_file_from_response(
+     response_objs: list[ResultFile3D], file_type: str, raise_if_not_found: bool = True
+ ) -> ResultFile3D | None:
for i in response_objs:
if i.Type.lower() == file_type.lower():
return i
if raise_if_not_found:
raise ValueError(f"'{file_type}' file type is not found in the response.")
return None
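# Usage sketch: the default now raises when the requested type is absent,
# while raise_if_not_found=False preserves the old may-return-None behaviour:
#   glb = get_file_from_response(result.ResultFile3Ds, "glb")
#   obj = get_file_from_response(result.ResultFile3Ds, "obj", raise_if_not_found=False)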
@@ -35,8 +52,9 @@ class TencentTextToModelNode(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="TencentTextToModelNode",
display_name="Hunyuan3D: Text to Model (Pro)",
display_name="Hunyuan3D: Text to Model",
category="api node/3d/Tencent",
essentials_category="3D",
inputs=[
IO.Combo.Input(
"model",
@@ -120,6 +138,7 @@ class TencentTextToModelNode(IO.ComfyNode):
EnablePBR=generate_type.get("pbr", None),
PolygonType=generate_type.get("polygon_type", None),
),
is_rate_limited=_is_tencent_rate_limited,
)
if response.Error:
raise ValueError(f"Task creation failed with code {response.Error.Code}: {response.Error.Message}")
@@ -131,11 +150,14 @@ class TencentTextToModelNode(IO.ComfyNode):
response_model=To3DProTaskResultResponse,
status_extractor=lambda r: r.Status,
)
-        glb_result = get_file_from_response(result.ResultFile3Ds, "glb")
-        obj_result = get_file_from_response(result.ResultFile3Ds, "obj")
-        file_glb = await download_url_to_file_3d(glb_result.Url, "glb", task_id=task_id) if glb_result else None
         return IO.NodeOutput(
-            file_glb, file_glb, await download_url_to_file_3d(obj_result.Url, "obj", task_id=task_id) if obj_result else None
+            f"{task_id}.glb",
+            await download_url_to_file_3d(
+                get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb", task_id=task_id
+            ),
+            await download_url_to_file_3d(
+                get_file_from_response(result.ResultFile3Ds, "obj").Url, "obj", task_id=task_id
+            ),
)
@@ -145,8 +167,9 @@ class TencentImageToModelNode(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="TencentImageToModelNode",
display_name="Hunyuan3D: Image(s) to Model (Pro)",
display_name="Hunyuan3D: Image(s) to Model",
category="api node/3d/Tencent",
essentials_category="3D",
inputs=[
IO.Combo.Input(
"model",
@@ -268,6 +291,7 @@ class TencentImageToModelNode(IO.ComfyNode):
EnablePBR=generate_type.get("pbr", None),
PolygonType=generate_type.get("polygon_type", None),
),
is_rate_limited=_is_tencent_rate_limited,
)
if response.Error:
raise ValueError(f"Task creation failed with code {response.Error.Code}: {response.Error.Message}")
@@ -279,11 +303,257 @@ class TencentImageToModelNode(IO.ComfyNode):
response_model=To3DProTaskResultResponse,
status_extractor=lambda r: r.Status,
)
-        glb_result = get_file_from_response(result.ResultFile3Ds, "glb")
-        obj_result = get_file_from_response(result.ResultFile3Ds, "obj")
-        file_glb = await download_url_to_file_3d(glb_result.Url, "glb", task_id=task_id) if glb_result else None
         return IO.NodeOutput(
-            file_glb, file_glb, await download_url_to_file_3d(obj_result.Url, "obj", task_id=task_id) if obj_result else None
+            f"{task_id}.glb",
+            await download_url_to_file_3d(
+                get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb", task_id=task_id
+            ),
+            await download_url_to_file_3d(
+                get_file_from_response(result.ResultFile3Ds, "obj").Url, "obj", task_id=task_id
+            ),
)
class TencentModelTo3DUVNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="TencentModelTo3DUVNode",
display_name="Hunyuan3D: Model to UV",
category="api node/3d/Tencent",
description="Perform UV unfolding on a 3D model to generate UV texture. "
"Input model must have less than 30000 faces.",
inputs=[
IO.MultiType.Input(
"model_3d",
types=[IO.File3DGLB, IO.File3DOBJ, IO.File3DFBX, IO.File3DAny],
tooltip="Input 3D model (GLB, OBJ, or FBX)",
),
IO.Int.Input(
"seed",
default=1,
min=0,
max=2147483647,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
tooltip="Seed controls whether the node should re-run; "
"results are non-deterministic regardless of seed.",
),
],
outputs=[
IO.File3DOBJ.Output(display_name="OBJ"),
IO.File3DFBX.Output(display_name="FBX"),
IO.Image.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(expr='{"type":"usd","usd":0.2}'),
)
SUPPORTED_FORMATS = {"glb", "obj", "fbx"}
@classmethod
async def execute(
cls,
model_3d: Types.File3D,
seed: int,
) -> IO.NodeOutput:
_ = seed
file_format = model_3d.format.lower()
if file_format not in cls.SUPPORTED_FORMATS:
raise ValueError(
f"Unsupported file format: '{file_format}'. "
f"Supported formats: {', '.join(sorted(cls.SUPPORTED_FORMATS))}."
)
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/tencent/hunyuan/3d-uv", method="POST"),
response_model=To3DProTaskCreateResponse,
data=To3DUVTaskRequest(
File=To3DUVFileInput(
Type=file_format.upper(),
Url=await upload_3d_model_to_comfyapi(cls, model_3d, file_format),
)
),
is_rate_limited=_is_tencent_rate_limited,
)
if response.Error:
raise ValueError(f"Task creation failed with code {response.Error.Code}: {response.Error.Message}")
result = await poll_op(
cls,
ApiEndpoint(path="/proxy/tencent/hunyuan/3d-uv/query", method="POST"),
data=To3DProTaskQueryRequest(JobId=response.JobId),
response_model=To3DProTaskResultResponse,
status_extractor=lambda r: r.Status,
)
return IO.NodeOutput(
await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "obj").Url, "obj"),
await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "fbx").Url, "fbx"),
await download_url_to_image_tensor(get_file_from_response(result.ResultFile3Ds, "image").Url),
)
class Tencent3DTextureEditNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="Tencent3DTextureEditNode",
display_name="Hunyuan3D: 3D Texture Edit",
category="api node/3d/Tencent",
description="After inputting the 3D model, perform 3D model texture redrawing.",
inputs=[
IO.MultiType.Input(
"model_3d",
types=[IO.File3DFBX, IO.File3DAny],
tooltip="3D model in FBX format. Model should have less than 100000 faces.",
),
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Describes texture editing. Supports up to 1024 UTF-8 characters.",
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
tooltip="Seed controls whether the node should re-run; "
"results are non-deterministic regardless of seed.",
),
],
outputs=[
IO.File3DGLB.Output(display_name="GLB"),
IO.File3DFBX.Output(display_name="FBX"),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd": 0.6}""",
),
)
@classmethod
async def execute(
cls,
model_3d: Types.File3D,
prompt: str,
seed: int,
) -> IO.NodeOutput:
_ = seed
file_format = model_3d.format.lower()
if file_format != "fbx":
raise ValueError(f"Unsupported file format: '{file_format}'. Only FBX format is supported.")
validate_string(prompt, field_name="prompt", min_length=1, max_length=1024)
model_url = await upload_3d_model_to_comfyapi(cls, model_3d, file_format)
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/tencent/hunyuan/3d-texture-edit", method="POST"),
response_model=To3DProTaskCreateResponse,
data=TextureEditTaskRequest(
File3D=To3DUVFileInput(Type=file_format.upper(), Url=model_url),
Prompt=prompt,
EnablePBR=True,
),
is_rate_limited=_is_tencent_rate_limited,
)
if response.Error:
raise ValueError(f"Task creation failed with code {response.Error.Code}: {response.Error.Message}")
result = await poll_op(
cls,
ApiEndpoint(path="/proxy/tencent/hunyuan/3d-texture-edit/query", method="POST"),
data=To3DProTaskQueryRequest(JobId=response.JobId),
response_model=To3DProTaskResultResponse,
status_extractor=lambda r: r.Status,
)
return IO.NodeOutput(
await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb"),
await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "fbx").Url, "fbx"),
)
class Tencent3DPartNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="Tencent3DPartNode",
display_name="Hunyuan3D: 3D Part",
category="api node/3d/Tencent",
description="Automatically perform component identification and generation based on the model structure.",
inputs=[
IO.MultiType.Input(
"model_3d",
types=[IO.File3DFBX, IO.File3DAny],
tooltip="3D model in FBX format. Model should have less than 30000 faces.",
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
tooltip="Seed controls whether the node should re-run; "
"results are non-deterministic regardless of seed.",
),
],
outputs=[
IO.File3DFBX.Output(display_name="FBX"),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(expr='{"type":"usd","usd":0.6}'),
)
@classmethod
async def execute(
cls,
model_3d: Types.File3D,
seed: int,
) -> IO.NodeOutput:
_ = seed
file_format = model_3d.format.lower()
if file_format != "fbx":
raise ValueError(f"Unsupported file format: '{file_format}'. Only FBX format is supported.")
model_url = await upload_3d_model_to_comfyapi(cls, model_3d, file_format)
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/tencent/hunyuan/3d-part", method="POST"),
response_model=To3DProTaskCreateResponse,
data=To3DUVTaskRequest(
File=To3DUVFileInput(Type=file_format.upper(), Url=model_url),
),
is_rate_limited=_is_tencent_rate_limited,
)
if response.Error:
raise ValueError(f"Task creation failed with code {response.Error.Code}: {response.Error.Message}")
result = await poll_op(
cls,
ApiEndpoint(path="/proxy/tencent/hunyuan/3d-part/query", method="POST"),
data=To3DProTaskQueryRequest(JobId=response.JobId),
response_model=To3DProTaskResultResponse,
status_extractor=lambda r: r.Status,
)
return IO.NodeOutput(
await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "fbx").Url, "fbx"),
)
@@ -293,6 +563,9 @@ class TencentHunyuan3DExtension(ComfyExtension):
return [
TencentTextToModelNode,
TencentImageToModelNode,
# TencentModelTo3DUVNode,
# Tencent3DTextureEditNode,
Tencent3DPartNode,
]

View File

@@ -261,6 +261,7 @@ class IdeogramV1(IO.ComfyNode):
default="AUTO",
tooltip="Determine if MagicPrompt should be used in generation",
optional=True,
advanced=True,
),
IO.Int.Input(
"seed",
@@ -394,6 +395,7 @@ class IdeogramV2(IO.ComfyNode):
default="AUTO",
tooltip="Determine if MagicPrompt should be used in generation",
optional=True,
advanced=True,
),
IO.Int.Input(
"seed",
@@ -411,6 +413,7 @@ class IdeogramV2(IO.ComfyNode):
default="NONE",
tooltip="Style type for generation (V2 only)",
optional=True,
advanced=True,
),
IO.String.Input(
"negative_prompt",
@@ -564,6 +567,7 @@ class IdeogramV3(IO.ComfyNode):
default="AUTO",
tooltip="Determine if MagicPrompt should be used in generation",
optional=True,
advanced=True,
),
IO.Int.Input(
"seed",
@@ -590,6 +594,7 @@ class IdeogramV3(IO.ComfyNode):
default="DEFAULT",
tooltip="Controls the trade-off between generation speed and quality",
optional=True,
advanced=True,
),
IO.Image.Input(
"character_image",

View File

@@ -2262,6 +2262,7 @@ class KlingLipSyncAudioToVideoNode(IO.ComfyNode):
node_id="KlingLipSyncAudioToVideoNode",
display_name="Kling Lip Sync Video with Audio",
category="api node/video/Kling",
essentials_category="Video Generation",
description="Kling Lip Sync Audio to Video Node. Syncs mouth movements in a video file to the audio content of an audio file. When using, ensure that the audio contains clearly distinguishable vocals and that the video contains a distinct face. The audio file should not be larger than 5MB. The video file should not be larger than 100MB, should have height/width between 720px and 1920px, and should be between 2s and 10s in length.",
inputs=[
IO.Video.Input("video"),
@@ -2333,6 +2334,7 @@ class KlingLipSyncTextToVideoNode(IO.ComfyNode):
max=2.0,
display_mode=IO.NumberDisplay.slider,
tooltip="Speech Rate. Valid range: 0.8~2.0, accurate to one decimal place.",
advanced=True,
),
],
outputs=[
@@ -2454,6 +2456,7 @@ class KlingImageGenerationNode(IO.ComfyNode):
IO.Combo.Input(
"image_type",
options=[i.value for i in KlingImageGenImageReferenceType],
advanced=True,
),
IO.Float.Input(
"image_fidelity",
@@ -2463,6 +2466,7 @@ class KlingImageGenerationNode(IO.ComfyNode):
step=0.01,
display_mode=IO.NumberDisplay.slider,
tooltip="Reference intensity for user-uploaded images",
advanced=True,
),
IO.Float.Input(
"human_fidelity",
@@ -2472,6 +2476,7 @@ class KlingImageGenerationNode(IO.ComfyNode):
step=0.01,
display_mode=IO.NumberDisplay.slider,
tooltip="Subject reference similarity",
advanced=True,
),
IO.Combo.Input("model_name", options=["kling-v3", "kling-v2", "kling-v1-5"]),
IO.Combo.Input(
@@ -2587,7 +2592,7 @@ class TextToVideoWithAudio(IO.ComfyNode):
IO.Combo.Input("mode", options=["pro"]),
IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]),
IO.Combo.Input("duration", options=[5, 10]),
IO.Boolean.Input("generate_audio", default=True),
IO.Boolean.Input("generate_audio", default=True, advanced=True),
],
outputs=[
IO.Video.Output(),
@@ -2655,7 +2660,7 @@ class ImageToVideoWithAudio(IO.ComfyNode):
IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt."),
IO.Combo.Input("mode", options=["pro"]),
IO.Combo.Input("duration", options=[5, 10]),
IO.Boolean.Input("generate_audio", default=True),
IO.Boolean.Input("generate_audio", default=True, advanced=True),
],
outputs=[
IO.Video.Output(),

View File

@@ -74,6 +74,7 @@ class TextToVideoNode(IO.ComfyNode):
default=False,
optional=True,
tooltip="When true, the generated video will include AI-generated audio matching the scene.",
advanced=True,
),
],
outputs=[
@@ -151,6 +152,7 @@ class ImageToVideoNode(IO.ComfyNode):
default=False,
optional=True,
tooltip="When true, the generated video will include AI-generated audio matching the scene.",
advanced=True,
),
],
outputs=[

Some files were not shown because too many files have changed in this diff Show More