Compare commits

...

299 Commits

Author SHA1 Message Date
Alexander Piskun
afb54219fa feat(api-nodes): allow to use "IMAGE+TEXT" in NanoBanana2 (#12729) 2026-03-01 23:24:33 -08:00
rattus
7175c11a4e comfy aimdo 0.2.4 (#12727)
Comfy Aimdo 0.2.4 fixes a VRAM buffer alignment issue that happens in
someworkflows where action is able to bypass the pytorch allocator
and go straight to the cuda hook.
2026-03-01 22:21:41 -08:00
rattus
dfbf99a061 model_mangament: make dynamic --disable-smart-memory work (#12724)
This was previously considering the pool of dynamic models as one giant
entity for the sake of smart memory, but that isnt really the useful
or what a user would reasonably expect. Make Dynamic VRAM properly purge
its models just like the old --disable-smart-memory but conditioning
the dynamic-for-dynamic bypass on smart memory.

Re-enable dynamic smart memory.
2026-03-01 19:18:56 -08:00
comfyanonymous
602f6bd82c Make --disable-smart-memory disable dynamic vram. (#12722) 2026-03-01 15:28:39 -05:00
rattus
c0d472e5b9 comfy-aimdo 0.2.3 (#12720) 2026-03-01 11:14:56 -08:00
drozbay
4d79f4f028 fix: handle substep sigmas in context window set_step (#12719)
Multi-step samplers (eg. dpmpp_2s_ancestral) call the model at intermediate sigma values not present in the schedule. This caused set_step to crash with "No sample_sigmas matched current timestep" when context windows were enabled.

The fix is to keep self._step from the last exact match when a substep sigma is encountered, since substeps are still logically part of their parent step and should use the same context windows.

Co-authored-by: ozbayb <17261091+ozbayb@users.noreply.github.com>
2026-03-01 09:38:30 -08:00
Christian Byrne
850e8b42ff feat: add text preview support to jobs API (#12169)
* feat: add text preview support to jobs API

Amp-Thread-ID: https://ampcode.com/threads/T-019c0be0-9fc6-71ac-853a-7c7cc846b375
Co-authored-by: Amp <amp@ampcode.com>

* test: update tests to expect text as previewable media type

Amp-Thread-ID: https://ampcode.com/threads/T-019c0be0-9fc6-71ac-853a-7c7cc846b375

---------
2026-02-28 21:38:19 -08:00
Christian Byrne
d159142615 refactor: rename Mahiro CFG to Similarity-Adaptive Guidance (#12172)
* refactor: rename Mahiro CFG to Similarity-Adaptive Guidance

Rename the display name to better describe what the node does:
adaptively blends guidance based on cosine similarity between
positive and negative conditions.

Amp-Thread-ID: https://ampcode.com/threads/T-019c0d36-8b43-745f-b7b2-e35b53f17fa1
Co-authored-by: Amp <amp@ampcode.com>

* feat: add search aliases for old mahiro name

Amp-Thread-ID: https://ampcode.com/threads/T-019c0d36-8b43-745f-b7b2-e35b53f17fa1

* rename: Similarity-Adaptive Guidance → Positive-Biased Guidance (per reviewer)

- display_name changed to 'Positive-Biased Guidance' to avoid SAG acronym collision
- search_aliases expanded: mahiro, mahiro cfg, similarity-adaptive guidance, positive-biased cfg
- ruff format applied

---------

Co-authored-by: Amp <amp@ampcode.com>
Co-authored-by: Jedrzej Kosinski <kosinkadink1@gmail.com>
2026-02-28 20:59:24 -08:00
comfyanonymous
1080bd442a Disable dynamic vram on wsl. (#12706) 2026-02-28 22:23:28 -05:00
comfyanonymous
17106cb124 Move parsing of requirements logic to function. (#12701) 2026-02-28 22:21:32 -05:00
rattus
48bb0bd18a cli_args: Default comfy to DynamicVram mode (#12658) 2026-02-28 16:52:30 -05:00
rattus
5f41584e96 Disable dynamic_vram when weight hooks applied (#12653)
* sd: add support for clip model reconstruction

* nodes: SetClipHooks: Demote the dynamic model patcher

* mp: Make dynamic_disable more robust

The backup need to not be cloned. In addition add a delegate object
to ModelPatcherDynamic so that non-cloning code can do
ModelPatcherDynamic demotion

* sampler_helpers: Demote to non-dynamic model patcher when hooking

* code rabbit review comments
2026-02-28 16:50:18 -05:00
Jukka Seppänen
1f6744162f feat: Support SCAIL WanVideo model (#12614) 2026-02-28 16:49:12 -05:00
fappaz
95e1059661 fix(ace15): handle missing lm_metadata in memory estimation during checkpoint export #12669 (#12686) 2026-02-28 01:18:40 -05:00
Christian Byrne
80d49441e5 refactor: use AspectRatio enum members as ASPECT_RATIOS dict keys (#12689)
Amp-Thread-ID: https://ampcode.com/threads/T-019ca1cb-0150-7549-8b1b-6713060d3408

Co-authored-by: Jedrzej Kosinski <kosinkadink1@gmail.com>
2026-02-27 20:53:46 -08:00
comfyanonymous
9d0e114ee3 PyOpenGL-accelerate is not necessary. (#12692) 2026-02-27 23:34:58 -05:00
Talmaj
ac4412d0fa Native LongCat-Image implementation (#12597) 2026-02-27 23:04:34 -05:00
comfyanonymous
94f1a1cc9d Limit overlap in image tile and combine nodes to prevent issues. (#12688) 2026-02-27 20:16:24 -05:00
rattus
e721e24136 ops: implement lora requanting for non QuantizedTensor fp8 (#12668)
Allow non QuantizedTensor layer to set want_requant to get the post lora
calculation stochastic cast down to the original input dtype.

This is then used by the legacy fp8 Linear implementation to set the
compute_dtype to the preferred lora dtype but then want_requant it back
down to fp8.

This fixes the issue with --fast fp8_matrix_mult is combined with
--fast dynamic_vram which doing a lora on an fp8_ non QT model.
2026-02-27 19:05:51 -05:00
Reiner "Tiles" Prokein
25ec3d96a3 Class WanVAE, def encode, feat_map is using self.decoder instead of self.encoder (#12682) 2026-02-27 19:03:45 -05:00
Christian Byrne
1f1ec377ce feat: add ResolutionSelector node for aspect ratio and megapixel-based resolution calculation (#12199)
Amp-Thread-ID: https://ampcode.com/threads/T-019c179e-cd8c-768f-ae66-207c7a53c01d

Co-authored-by: Jedrzej Kosinski <kosinkadink1@gmail.com>
2026-02-27 09:13:57 -08:00
pythongosssss
0a7f8e11b6 fix torch.cat requiring inputs to all be same dimensions (#12673) 2026-02-27 08:13:24 -08:00
vickytsang
35e9fce775 Enable Pytorch Attention for gfx950 (#12641) 2026-02-26 20:16:12 -05:00
Jukka Seppänen
c7f7d52b68 feat: Support SDPose-OOD (#12661) 2026-02-26 19:59:05 -05:00
rattus
08b26ed7c2 bug_report template: Push harder for logs (#12657)
We get a lot od bug reports without logs, especially for performance
issues.
2026-02-26 18:59:24 -05:00
fappaz
b233dbe0bc feat(ace-step): add ACE-Step 1.5 lycoris key alias mapping for LoKR #12638 (#12665) 2026-02-26 18:19:19 -05:00
comfyanonymous
3811780e4f Portable with cu128 isn't useful anymore. (#12666)
Users should either use the cu126 one or the regular one (cu130 at the moment)

The cu128 portable is still included in the latest github release but I will stop including it as soon as it becomes slightly annoying to deal with. This might happen as soon as next week.
2026-02-26 17:12:29 -05:00
comfyanonymous
3dd10a59c0 ComfyUI v0.15.1 2026-02-26 15:59:22 -05:00
ComfyUI Wiki
88d05fe483 chore: update workflow templates to v0.9.4 (#12664) 2026-02-26 15:52:45 -05:00
Alexander Piskun
fd41ec97cc feat(api-nodes): add NanoBanana2 (#12660) 2026-02-26 15:52:10 -05:00
rattus
420e900f69 main: load aimdo earlier (#12655)
Some custom node packs are naughty, and violate the
dont-load-torch-on-load rule. This causes aimdo to lose preference on
its allocator hook on linux.

Go super early on the aimdo first-stage init before custom nodes
are mentioned at all.
2026-02-26 15:19:38 -05:00
pythongosssss
38ca94599f pyopengl-accelerate can cause object to be numpy ints instead of bare ints, which the glDeleteTextures function does not accept, explicitly cast to int (#12650) 2026-02-26 03:07:35 -08:00
Christian Byrne
74b5a337dc fix: move essentials_category to correct replacement nodes (#12568)
Move essentials_category from deprecated/incorrect nodes to their replacements:
- ImageBatch → BatchImagesNode (ImageBatch is deprecated)
- Blur → removed (should use subgraph blueprint)
- GetVideoComponents → Video Slice

Amp-Thread-ID: https://ampcode.com/threads/T-019c8340-4da2-723b-a09f-83895c5bbda5
2026-02-26 01:00:32 -08:00
comfyanonymous
8a4d85c708 Cleanups to the last PR. (#12646) 2026-02-26 01:30:31 -05:00
Tavi Halperin
a4522017c5 feat: per-guide attention strength control in self-attention (#12518)
Implements per-guide attention attenuation via log-space additive bias
in self-attention. Each guide reference tracks its own strength and
optional spatial mask in conditioning metadata (guide_attention_entries).
2026-02-26 01:25:23 -05:00
Jukka Seppänen
907e5dcbbf initial FlowRVS support (#12637) 2026-02-25 23:38:46 -05:00
comfyanonymous
7253531670 Fix ltxav te mem estimation. (#12643) 2026-02-25 23:13:47 -05:00
comfyanonymous
e14b04478c Fix LTXAV text enc min length. (#12640)
Should have been 1024 instead of 512
2026-02-25 22:36:02 -05:00
Christian Byrne
eb8737d675 Update requirements.txt (#12642) 2026-02-25 18:30:48 -08:00
rattus
0467f690a8 comfy aimdo 0.2.2 (#12635)
Comfy Aimdo 0.2.2 moves the cuda allocator hook from the cudart API to
the cuda driver API on windows. This is needed to handle Windows+cu13
where cudart is statically linked.
2026-02-25 16:50:05 -05:00
rattus
4f5b7dbf1f Fix Aimdo fallback on probe to not use zero-copy sft (#12634)
* utils: dont use comfy sft loader in aimdo fallback

This was going to the raw command line switch and should respect main.py
probe of whether aimdo actually loaded successfully.

* ops: dont use deferred linear load in Aimdo fallback

Avoid changes of behaviour on --fast dynamic_vram when aimdo doesnt work.
2026-02-25 16:49:48 -05:00
rattus
3ebe1ac22e Disable dynamic_vram when using torch compiler (#12612)
* mp: attach re-construction arguments to model patcher

When making a model-patcher from a unet or ckpt, attach a callable
function that can be called to replay the model construction. This
can be used to deep clone model patcher WRT the actual model.

Originally written by Kosinkadink
f4b99bc623

* mp: Add disable_dynamic clone argument

Add a clone argument that lets a caller clone a ModelPatcher but disable
dynamic to demote the clone to regular MP. This is useful for legacy
features where dynamic_vram support is missing or TBD.

* torch_compile: disable dynamic_vram

This is a bigger feature. Disable for the interim to preserve
functionality.
2026-02-24 19:13:46 -05:00
rattus
befa83d434 comfy aimdo 0.2.1 (#12620)
Changes:

throttle VRAM threshold checks to restore performance in high-layer-rate
conditions.
2026-02-24 16:02:26 -05:00
Jedrzej Kosinski
33f83d53ae Fix KeyError when prompt entries lack class_type key (#12595)
Skip entries in the prompt dict that don't contain a class_type key
in apply_replacements(), preventing crashes on metadata or non-node
entries.

Fixes Comfy-Org/ComfyUI#12517
2026-02-24 16:02:05 -05:00
comfyanonymous
b874bd2b8c ComfyUI v0.15.0 2026-02-24 12:42:15 -05:00
ComfyUI Wiki
0aa02453bb chore: update embedded docs to v0.4.3 (#12601) 2026-02-24 12:41:36 -05:00
comfyanonymous
599f9c5010 Don't crash right away if op is uninitialized. (#12615) 2026-02-24 12:28:25 -05:00
ComfyUI Wiki
11fefa58e9 chore: update workflow templates to v0.9.3 (#12610) 2026-02-24 09:04:51 -08:00
Alexander Piskun
d8090013b8 feat(api-nodes): add ByteDance Seedream-5 model (#12609)
* feat(api-nodes): add ByteDance Seedream-5 model

* made error message more correct

* rename seedream 5.0 model
2026-02-24 09:03:30 -08:00
Christian Byrne
048dd2f321 Patch frontend to 1.39.16 (from 1.39.14) (#12604)
* Update requirements.txt

* Update requirements.txt

---------

Co-authored-by: Jedrzej Kosinski <kosinkadink1@gmail.com>
2026-02-24 00:44:40 -08:00
comfyanonymous
84aba95e03 Temporality unbreak some LTXAV workflows to give people time to migrate. (#12605) 2026-02-24 00:50:03 -05:00
comfyanonymous
9b1c63eb69 Add SplitImageToTileList and ImageMergeTileList nodes. (#12599)
With these you can split an image into tiles, do operations and then combine it back to a single image.
2026-02-23 21:01:17 -05:00
ComfyUI Wiki
7a7debcaf1 chore: update workflow templates to v0.9.2 (#12596) 2026-02-23 18:27:20 -05:00
Alexander Piskun
dba2766e53 feat(api-nodes): add KlingAvatar node (#12591) 2026-02-23 11:27:16 -08:00
comfyanonymous
caa43d2395 Fix issue loading fp8 ltxav checkpoints. (#12582) 2026-02-22 16:00:02 -05:00
comfyanonymous
07ca6852e8 Fix dtype issue in embeddings connector. (#12570) 2026-02-22 03:18:20 -05:00
comfyanonymous
f266b8d352 Move LTXAV av embedding connectors to diffusion model. (#12569) 2026-02-21 22:29:58 -05:00
Christian Byrne
b6cb30bab5 chore: tune CodeRabbit config to limit review scope and disable for drafts (#12567)
* chore: tune CodeRabbit config to limit review scope and disable for drafts

- Add tone_instructions to focus only on newly introduced issues
- Add global path_instructions entry to ignore pre-existing issues in moved/reformatted code
- Disable draft PR reviews (drafts: false) and add WIP title keywords
- Disable ruff tool to prevent linter-based outside-diff-range comments

Addresses feedback from maintainers about CodeRabbit flagging pre-existing
issues in code that was merely moved or de-indented (e.g., PR #12557),
which can discourage community contributions and cause scope creep.

Amp-Thread-ID: https://ampcode.com/threads/T-019c82de-0481-7253-ad42-20cb595bb1ba

* chore: add 'DO NOT MERGE' to ignore_title_keywords

Amp-Thread-ID: https://ampcode.com/threads/T-019c82de-0481-7253-ad42-20cb595bb1ba
2026-02-21 18:32:15 -08:00
Christian Byrne
ee72752162 Add category to Normalized Attention Guidance node (#12565) 2026-02-21 19:51:21 -05:00
Alexander Brown
7591d781a7 fix: specify UTF-8 encoding when reading subgraph files (#12563)
On Windows, Python defaults to cp1252 encoding when no encoding is
specified. JSON files containing UTF-8 characters (e.g., non-ASCII
characters) cause UnicodeDecodeError when read with cp1252.

This fixes the error that occurs when loading blueprint subgraphs
on Windows systems.

https://claude.ai/code/session_014WHi3SL9Gzsi3U6kbSjbSb

Co-authored-by: Claude <noreply@anthropic.com>
2026-02-21 15:05:00 -08:00
rattus
0bfb936ab4 comfy-aimdo 0.2 - Improved pytorch allocator integration (#12557)
Integrate comfy-aimdo 0.2 which takes a different approach to
installing the memory allocator hook. Instead of using the complicated
and buggy pytorch MemPool+CudaPluggableAlloctor, cuda is directly hooked
making the process much more transparent to both comfy and pytorch. As
far as pytorch knows, aimdo doesnt exist anymore, and just operates
behind the scenes.

Remove all the mempool setup stuff for dynamic_vram and bump the
comfy-aimdo version. Remove the allocator object from memory_management
and demote its use as an enablment check to a boolean flag.

Comfy-aimdo 0.2 also support the pytorch cuda async allocator, so
remove the dynamic_vram based force disablement of cuda_malloc and
just go back to the old settings of allocators based on command line
input.
2026-02-21 10:52:57 -08:00
pythongosssss
602b2505a4 add support for pyopengl < 3.1.4 where the size parameter does not exist (#12555) 2026-02-21 06:14:57 -08:00
Christian Byrne
04a55d5019 fix: swap essentials_category from CLIPTextEncode to PrimitiveStringMultiline (#12553)
Remove CLIPTextEncode from Basics essentials category and add
PrimitiveStringMultiline (String Multiline) in its place.

Amp-Thread-ID: https://ampcode.com/threads/T-019c7efb-d916-7244-8c43-77b615ba0622

Co-authored-by: Jedrzej Kosinski <kosinkadink1@gmail.com>
2026-02-20 23:46:46 -08:00
Christian Byrne
5fb8f06495 feat: add essential subgraph blueprints (#12552)
Add 24 non-cloud essential blueprints from comfyui-wiki/Subgraph-Blueprints.
These cover common workflows: text/image/video generation, editing,
inpainting, outpainting, upscaling, depth maps, pose, captioning, and more.

Cloud-only blueprints (5) are excluded and will be added once
client-side distribution filtering lands.

Amp-Thread-ID: https://ampcode.com/threads/T-019c6f43-6212-7308-bea6-bfc35a486cbf
2026-02-20 23:40:13 -08:00
Terry Jia
5a182bfaf1 update glsl blueprint with gradient (#12548)
Co-authored-by: Jedrzej Kosinski <kosinkadink1@gmail.com>
2026-02-20 23:02:36 -08:00
Terry Jia
f394af8d0f feat: add gradient-slider display mode for FLOAT inputs (#12536)
* feat: add gradient-slider display mode for FLOAT inputs

* fix: use precise type annotation list[list[float]] for gradient_stops

Amp-Thread-ID: https://ampcode.com/threads/T-019c7eea-be2b-72ce-a51f-838376f9b7a7

---------

Co-authored-by: Jedrzej Kosinski <kosinkadink1@gmail.com>
Co-authored-by: bymyself <cbyrne@comfy.org>
2026-02-20 22:52:32 -08:00
Arthur R Longbottom
aeb5bdc8f6 Fix non-contiguous audio waveform crash in video save (#12550)
Fixes #12549
2026-02-20 23:37:55 -05:00
comfyanonymous
64953bda0a Update nightly installation command for ROCm (#12547) 2026-02-20 21:32:05 -05:00
Christian Byrne
b254cecd03 chore: add CodeRabbit configuration for automated code review (#12539)
Amp-Thread-ID: https://ampcode.com/threads/T-019c7915-2abf-743c-9c74-b93d87d63055

Co-authored-by: Jedrzej Kosinski <kosinkadink1@gmail.com>
2026-02-20 00:25:54 -08:00
Alexander Piskun
1bb956fb66 [API Nodes] add ElevenLabs nodes (#12207)
* feat(api-nodes): add ElevenLabs API nodes

* added price badge for ElevenLabsInstantVoiceClone node

---------

Co-authored-by: Jedrzej Kosinski <kosinkadink1@gmail.com>
2026-02-19 22:12:28 -08:00
pythongosssss
96d6bd1a4a Add GLSL shader node using PyOpenGL (#12148)
* adds support for executing simple glsl shaders
using moderngl package

* tidy

* Support multiple outputs

* Try fix build

* fix casing

* fix line endings

* convert to using PyOpenGL and glfw

* remove cpu support

* tidy

* add additional support for egl & osmesa backends

* fix ci
perf: only read required outputs

* add diagnostics, update mac initialization

* GLSL glueprints + node fixes (#12492)

* Add image operation blueprints

* Add channels

* Add glow

* brightness/contrast

* hsb

* add glsl shader update system

* shader nit iteration

* add multipass for faster blur

* more fixes

* rebuild blueprints

* print -> logger

* Add edge preserving blur

* fix: move _initialized flag to end of GLContext.__init__

Prevents '_vao' attribute error when init fails partway through
and subsequent calls skip initialization due to early _initialized flag.

* update valid ranges
- threshold 0-100
- step 0+

* fix value ranges

* rebuild node to remove extra inputs

* Fix gamma step

* clamp saturation in colorize instead of wrapping

* Fix crash on 1x1 px images

* rework description

* remove unnecessary f


Co-authored-by: Jedrzej Kosinski <kosinkadink1@gmail.com>
Co-authored-by: Hunter Senft-Grupp <hunter@comfy.org>
2026-02-19 23:22:13 -05:00
comfyanonymous
5f2117528a Force min length 1 when tokenizing for text generation. (#12538) 2026-02-19 22:57:44 -05:00
comfyanonymous
0301ccf745 Small cleanup and try to get qwen 3 work with the text gen. (#12537) 2026-02-19 22:42:28 -05:00
Christian Byrne
4d172e9ad7 feat: mark 429 widgets as advanced for collapsible UI (#12197)
* feat: mark 429 widgets as advanced for collapsible UI

Mark widgets as advanced across core, comfy_extras, and comfy_api_nodes
to support the new collapsible advanced inputs section in the frontend.

Changes:
- 267 advanced markers in comfy_extras/
- 162 advanced markers in comfy_api_nodes/
- All files pass python3 -m py_compile verification

Widgets marked advanced (hidden by default):
- Scheduler internals: sigma_max, sigma_min, rho, mu, beta, alpha
- Sampler internals: eta, s_noise, order, rtol, atol, h_init, pcoeff, etc.
- Memory optimization: tile_size, overlap, temporal_size, temporal_overlap
- Pipeline controls: add_noise, start_at_step, end_at_step
- Timing controls: start_percent, end_percent
- Layer selection: stop_at_clip_layer, layers, block_number
- Video encoding: codec, crf, format
- Device/dtype: device, noise_device, dtype, weight_dtype

Widgets kept basic (always visible):
- Core params: strength, steps, cfg, denoise, seed, width, height
- Model selectors: ckpt_name, lora_name, vae_name, sampler_name
- Common controls: upscale_method, crop, batch_size, fps, opacity

Related: frontend PR #11939
Amp-Thread-ID: https://ampcode.com/threads/T-019c1734-6b61-702e-b333-f02c399963fc

* fix: remove advanced=True from DynamicCombo.Input (unsupported)

Amp-Thread-ID: https://ampcode.com/threads/T-019c1734-6b61-702e-b333-f02c399963fc

* fix: address review - un-mark model merge, video, image, and training node widgets as advanced

Per comfyanonymous review:
- Model merge arguments should not be advanced (all 14 model-specific merge classes)
- SaveAnimatedWEBP lossless/quality/method should not be advanced
- SaveWEBM/SaveVideo codec/crf/format should not be advanced
- TrainLoraNode options should not be advanced (7 inputs)

Amp-Thread-ID: https://ampcode.com/threads/T-019c322b-a3a8-71b7-9962-d44573ca6352

* fix: un-mark batch_size and webcam width/height as advanced (should stay basic)

Amp-Thread-ID: https://ampcode.com/threads/T-019c3236-1417-74aa-82a3-bcb365fbe9d1

---------

Co-authored-by: Jedrzej Kosinski <kosinkadink1@gmail.com>
2026-02-19 19:20:02 -08:00
Yourz
5632b2df9d feat: add essentials_category (#12357)
* feat: add essentials_category field to node schema

Amp-Thread-ID: https://ampcode.com/threads/T-019c2b25-cd90-7218-9071-03cb46b351b3

* feat: add ESSENTIALS_CATEGORY to core nodes

Marked nodes:
- Basic: LoadImage, SaveImage, LoadVideo, SaveVideo, Load3D, CLIPTextEncode
- Image Tools: ImageScale, ImageInvert, ImageBatch, ImageCrop, ImageRotate, ImageBlur
- Image Tools/Preprocessing: Canny
- Image Generation: LoraLoader
- Audio: LoadAudio, SaveAudio

Amp-Thread-ID: https://ampcode.com/threads/T-019c2b25-cd90-7218-9071-03cb46b351b3

* Add ESSENTIALS_CATEGORY to more nodes

- SaveGLB (Basic)
- GetVideoComponents (Video Tools)
- TencentTextToModelNode, TencentImageToModelNode (3D)
- RecraftRemoveBackgroundNode (Image Tools)
- KlingLipSyncAudioToVideoNode (Video Generation)
- OpenAIChatNode (Text Generation)
- StabilityTextToAudio (Audio)

Amp-Thread-ID: https://ampcode.com/threads/T-019c2b69-81c1-71c3-8096-450a39e20910

* fix: correct essentials category for Canny node

Amp-Thread-ID: https://ampcode.com/threads/T-019c7303-ab53-7341-be76-a5da1f7a657e
Co-authored-by: Amp <amp@ampcode.com>

* refactor: replace essentials_category string literals with constants

Amp-Thread-ID: https://ampcode.com/threads/T-019c7303-ab53-7341-be76-a5da1f7a657e
Co-authored-by: Amp <amp@ampcode.com>

* refactor: revert constants, use string literals for essentials_category

Amp-Thread-ID: https://ampcode.com/threads/T-019c7303-ab53-7341-be76-a5da1f7a657e
Co-authored-by: Amp <amp@ampcode.com>

* fix: update basics

---------

Co-authored-by: bymyself <cbyrne@comfy.org>
Co-authored-by: Amp <amp@ampcode.com>
Co-authored-by: Jedrzej Kosinski <kosinkadink1@gmail.com>
2026-02-19 19:00:26 -08:00
Alexander Piskun
2687652530 fix(api-nodes): force Gemini to return uncompressed images (#12516) 2026-02-19 04:10:39 -08:00
Jukka Seppänen
6d11cc7354 feat: Add basic text generation support with native models, initially supporting Gemma3 (#12392) 2026-02-18 20:49:43 -05:00
comfyanonymous
f262444dd4 Add simple 3 band equalizer node for audio. (#12519) 2026-02-18 18:36:35 -05:00
Alexander Piskun
239ddd3327 fix(api-nodes): add price badge for Rodin Gen-2 node (#12512) 2026-02-17 23:15:23 -08:00
Hunter
83dd65f23a fix: use glob matching for Gemini image MIME types (#12511)
gemini-3-pro-image-preview nondeterministically returns image/jpeg
instead of image/png. get_image_from_response() hardcoded
get_parts_by_type(response, "image/png"), silently dropping JPEG
responses and falling back to torch.zeros (all-black output).

Add _mime_matches() helper using fnmatch for glob-style MIME matching.
Change get_image_from_response() to request "image/*" so any image
format returned by the API is correctly captured.
2026-02-18 00:03:54 -05:00
Terry Jia
8ad38d2073 BBox widget (#11594)
* Boundingbox widget

* code improve

---------

Co-authored-by: Jedrzej Kosinski <kosinkadink1@gmail.com>
Co-authored-by: Christian Byrne <cbyrne@comfy.org>
2026-02-17 17:13:39 -08:00
Comfy Org PR Bot
6c14f129af Bump comfyui-frontend-package to 1.39.14 (#12494)
* Bump comfyui-frontend-package to 1.39.13

* Update requirements.txt

---------

Co-authored-by: Christian Byrne <cbyrne@comfy.org>
2026-02-17 13:41:34 -08:00
rattus
58dcc97dcf ops: limit return of requants (#12506)
This check was far too broad and the dtype is not a reliable indicator
of wanting the requant (as QT returns the compute dtype as the dtype).
So explictly plumb whether fp8mm wants the requant or not.
2026-02-17 15:32:27 -05:00
comfyanonymous
19236edfa4 ComfyUI v0.14.1 2026-02-17 13:28:06 -05:00
ComfyUI Wiki
73c3f86973 chore: update workflow templates to v0.8.43 (#12507) 2026-02-17 13:25:55 -05:00
Alexander Piskun
262abf437b feat(api-nodes): add Recraft V4 nodes (#12502) 2026-02-17 13:25:44 -05:00
Alexander Piskun
5284e6bf69 feat(api-nodes): add "viduq3-turbo" model and Vidu3StartEnd node; fix the price badges (#12482) 2026-02-17 10:07:14 -08:00
chaObserv
44f8598521 Fix anima LLM adapter forward when manual cast (#12504) 2026-02-17 07:56:44 -08:00
comfyanonymous
fe52843fe5 ComfyUI v0.14.0 2026-02-17 00:39:54 -05:00
comfyanonymous
c39653163d Fix anima preprocess text embeds not using right inference dtype. (#12501) 2026-02-17 00:29:20 -05:00
comfyanonymous
18927538a1 Implement NAG on all the models based on the Flux code. (#12500)
Use the Normalized Attention Guidance node.

Flux, Flux2, Klein, Chroma, Chroma radiance, Hunyuan Video, etc..
2026-02-16 23:30:34 -05:00
Jedrzej Kosinski
8a6fbc2dc2 Allow control_after_generate to be type ControlAfterGenerate in v3 schema (#12187) 2026-02-16 22:20:21 -05:00
Alex Butler
b44fc4c589 add venv* to gitignore (#12431) 2026-02-16 22:16:19 -05:00
comfyanonymous
4454fab7f0 Remove code to support RMSNorm on old pytorch. (#12499) 2026-02-16 20:09:24 -05:00
ComfyUI Wiki
1978f59ffd chore: update workflow templates to v0.8.42 (#12491) 2026-02-16 17:33:43 -05:00
comfyanonymous
88e6370527 Remove workaround for old pytorch. (#12480) 2026-02-15 20:43:53 -05:00
rattus
c0370044cd MPDynamic: force load flux img_in weight (Fixes flux1 canny+depth lora crash) (#12446)
* lora: add weight shape calculations.

This lets the loader know if a lora will change the shape of a weight
so it can take appropriate action.

* MPDynamic: force load flux img_in weight

This weight is a bit special, in that the lora changes its geometry.
This is rather unique, not handled by existing estimate and doesn't
work for either offloading or dynamic_vram.

Fix for dynamic_vram as a special case. Ideally we can fully precalculate
these lora geometry changes at load time, but just get these models
working first.
2026-02-15 20:30:09 -05:00
rattus
ecd2a19661 Fix lora Extraction in offload conditions (+ dynamic_vram mode) (#12479)
* lora_extract: Add a trange

If you bite off more than your GPU can chew, this kinda just hangs.
Give a rough indication of progress counting the weights in a trange.

* lora_extract: Support on-the-fly patching

Use the on-the-fly approach from the regular model saving logic for
lora extraction too. Switch off force_cast_weights accordingly.

This gets extraction working in dynamic vram while also supporting
extraction on GPU offloaded.
2026-02-15 20:28:51 -05:00
Alexander Piskun
2c1d06a4e3 feat(api-nodes): add Bria RMBG nodes (#12465)
Co-authored-by: Jedrzej Kosinski <kosinkadink1@gmail.com>
2026-02-15 17:22:30 -08:00
Alexander Piskun
e2c71ceb00 feat(api-nodes-Tencent): add ModelTo3DUV, 3DTextureEdit, 3DParts nodes (#12428)
* feat(api-nodes-Tencent): add ModelTo3DUV, 3DTextureEdit, 3DParts nodes

* add image output to TencentModelTo3DUV node

* commented out two nodes

* added rate_limit check to other hunyuan3d nodes
2026-02-15 05:33:18 -08:00
Jedrzej Kosinski
596ed68691 Node Replacement API (#12014) 2026-02-15 02:12:30 -08:00
Alexander Piskun
ce4a1ab48d chore(api-nodes): remove "gpt-4o" model (#12467) 2026-02-15 01:31:59 -08:00
comfyanonymous
e1ede29d82 Remove unsafe pickle loading code that was used on pytorch older than 2.4 (#12473)
ComfyUI hasn't started on pytorch 2.4 since last month.
2026-02-14 22:53:52 -05:00
Christian Byrne
df1e5e8514 Update frontend package to 1.38.14 (#12469) 2026-02-14 11:01:10 -08:00
krigeta
dc9822b7df Add working Qwen 2512 ControlNet (Fun ControlNet) support (#12359) 2026-02-13 22:23:52 -05:00
comfyanonymous
712efb466b Add left padding to LTXAV text encoder. (#12456) 2026-02-13 21:56:54 -05:00
comfyanonymous
726af73867 Fix some custom nodes. (#12455) 2026-02-13 20:21:10 -05:00
comfyanonymous
831351a29e Support generating attention masks for left padded text encoders. (#12454) 2026-02-13 20:15:23 -05:00
comfyanonymous
e1add563f9 Use torch RMSNorm for flux models and refactor hunyuan video code. (#12432) 2026-02-13 15:35:13 -05:00
rattus
8902907d7a dynamic_vram: Training fixes (#12442) 2026-02-13 15:29:37 -05:00
comfyanonymous
e03fe8b591 Update command to install AMD stable linux pytorch. (#12437) 2026-02-12 23:29:12 -05:00
rattus
ae79e33345 llama: use a more efficient rope implementation (#12434)
Get rid of the cat and unary negation and inplace add-cmul the two
halves of the rope. Precompute -sin once at the start of the model
rather than every transformer block.

This is slightly faster on both GPU and CPU bound setups.
2026-02-12 19:56:42 -05:00
rattus
117e214354 ModelPatcherDynamic: force load non leaf weights (#12433)
The current behaviour of the default ModelPatcher is to .to a model
only if its fully loaded, which is how random non-leaf weights get
loaded in non-LowVRAM conditions.

The however means they never get loaded in dynamic_vram. In the
dynamic_vram case, force load them to the GPU.
2026-02-12 19:51:50 -05:00
Alexander Piskun
4a93a62371 fix(api-nodes): add separate retry budget for 429 rate limit responses (#12421) 2026-02-12 01:38:51 -08:00
comfyanonymous
66c18522fb Add a tip for common error. (#12414) 2026-02-11 22:12:16 -05:00
askmyteapot
e5ae670a40 Update ace15.py to allow min_p sampling (#12373) 2026-02-11 20:28:48 -05:00
rattus
3fe61cedda model_patcher: guard against none model_dtype (#12410)
Handle the case where the _model_dtype exists but is none with the
intended fallback.
2026-02-11 14:54:02 -05:00
rattus
2a4328d639 ace15: Use dynamic_vram friendly trange (#12409)
Factor out the ksampler trange and use it in ACE LLM to prevent the
silent stall at 0 and rate distortion due to first-step model load.
2026-02-11 14:53:42 -05:00
rattus
d297a749a2 dynamic_vram: Fix windows Aimdo crash + Fix LLM performance (#12408)
* model_management: lazy-cache aimdo_tensor

These tensors cosntructed from aimdo-allocations are CPU expensive to
make on the pytorch side. Add a cache version that will be valid with
signature match to fast path past whatever torch is doing.

* dynamic_vram: Minimize fast path CPU work

Move as much as possible inside the not resident if block and cache
the formed weight and bias rather than the flat intermediates. In
extreme layer weight rates this adds up.
2026-02-11 14:50:16 -05:00
Alexander Piskun
2b7cc7e3b6 [API Nodes] enable Magnific Upscalers (#12179)
* feat(api-nodes): enable Magnific Upscalers

* update price badges

---------

Co-authored-by: Jedrzej Kosinski <kosinkadink1@gmail.com>
2026-02-11 11:30:19 -08:00
Benjamin Lu
4993411fd9 Dispatch desktop auto-bump when a ComfyUI release is published (#12398)
* Dispatch desktop auto-bump on ComfyUI release publish

* Fix release webhook secret checks in step conditions

* Require desktop dispatch token in release webhook

* Apply suggestion from @Copilot

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

---------

Co-authored-by: Luke Mino-Altherr <lminoaltherr@gmail.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Jedrzej Kosinski <kosinkadink1@gmail.com>
2026-02-11 11:15:13 -08:00
Alexander Piskun
2c7cef4a23 fix(api-nodes): retry on connection errors during polling instead of aborting (#12393) 2026-02-11 10:51:49 -08:00
comfyanonymous
76a7fa96db Make built in lora training work on anima. (#12402) 2026-02-10 22:04:32 -05:00
Kohaku-Blueleaf
cdcf4119b3 [Trainer] training with proper offloading (#12189)
* Fix bypass dtype/device moving

* Force offloading mode for training

* training context var

* offloading implementation in training node

* fix wrong input type

* Support bypass load lora model, correct adapter/offloading handling
2026-02-10 21:45:19 -05:00
AustinMroz
dbe70b6821 Add a VideoSlice node (#12107)
* Base TrimVideo implementation

* Raise error if as_trimmed call fails

* Bigger max start_time, tooltips, and formatting

* Count packets unless codec has subframes

* Remove incorrect nested decode

* Add null check for audio streams

* Support non-strict duration

* Added strict_duration bool to node definition

* Empty commit for approval

* Fix duration

* Support 5.1 audio layout on save

---------

Co-authored-by: Jedrzej Kosinski <kosinkadink1@gmail.com>
2026-02-10 14:42:21 -08:00
guill
00fff6019e feat(jobs): add 3d to PREVIEWABLE_MEDIA_TYPES for first-class 3D output support (#12381)
Co-authored-by: Jedrzej Kosinski <kosinkadink1@gmail.com>
2026-02-10 14:37:14 -08:00
rattus
123a7874a9 ops: Fix vanilla-fp8 loaded lora quality (#12390)
This was missing the stochastic rounding required for fp8 downcast
to be consistent with model_patcher.patch_weight_to_device.

Missed in testing as I spend too much time with quantized tensors
and overlooked the simpler ones.
2026-02-10 13:38:28 -05:00
rattus
f719f9c062 sd: delay VAE dtype archive until after override (#12388)
VAEs have host specific dtype logic that should override the dynamic
_model_dtype. Defer the archiving of model dtypes until after.
2026-02-10 13:37:46 -05:00
rattus
fe053ba5eb mp: dont deep-clone objects from model_options (#12382)
If there are non-trivial python objects nested in the model_options, this
causes all sorts of issues. Traverse lists and dicts so clones can safely
overide settings and BYO objects but stop there on the deepclone.
2026-02-10 13:37:17 -05:00
comfyanonymous
6648ab68bc ComfyUI v0.13.0 2026-02-10 13:26:29 -05:00
ComfyUI Wiki
6615db925c chore: update workflow templates to v0.8.38 (#12394) 2026-02-10 13:24:56 -05:00
Alexander Piskun
8ca842a8ed feat(api-nodes-Kling): add new models (V3, O3) (#12389)
* feat(api-nodes-Kling): add new models (V3, O3)

* remove storyboard from VideoToVideo node

* added check for total duration of storyboards

* fixed other small things

* updated display name for nodes

* added "fake" seed
2026-02-10 09:34:54 -08:00
Alexander Piskun
c1b63a7e78 fix(Moonvalley-API-Nodes): adjust "steps" parameter to not raise exception (#12370) 2026-02-09 21:58:27 -05:00
ComfyUI Wiki
349a636a2b chore: update workflow templates to v0.8.37 (#12377) 2026-02-09 21:25:34 -05:00
comfyanonymous
a4be04c5d7 Ace step prompts match now. (#12376) 2026-02-09 19:45:56 -05:00
blepping
baf8c87455 Iimprovements to ACE-Steps 1.5 text encoding (part 2) (#12350) 2026-02-09 19:41:49 -05:00
rattus
62315fbb15 Dynamic VRAM fixes - Ace 1.5 performance + a VRAM leak (#12368)
* revert threaded model loader change

This change was only needed to get around the pytorch 2.7 mempool bugs,
and should have been reverted along with #12260. This fixes a different
memory leak where pytorch gets confused about cache emptying.

* load non comfy weights

* MPDynamic: Pre-generate the tensors for vbars

Apparently this is an expensive operation that slows down things.

* bump to aimdo 1.8

New features:
watermark limit feature
logging enhancements
-O2 build on linux
2026-02-09 16:16:08 -05:00
comfyanonymous
a0302cc6a8 Make tonemap latent work on any dim latents. (#12363) 2026-02-08 21:16:40 -05:00
comfyanonymous
f350a84261 Disable prompt weights for ltxv2. (#12354) 2026-02-07 19:16:28 -05:00
ComfyUI Wiki
3760d74005 chore: update embedded docs to v0.4.1 (#12346) 2026-02-07 18:34:52 -05:00
chaObserv
9bf5aa54db Add search_aliases to sa-solver and seeds-2 node (#12327) 2026-02-07 17:38:51 -05:00
Jukka Seppänen
5ff4fdedba Fix LazyCache (#12344) 2026-02-07 11:25:30 -08:00
comfyanonymous
17e7df43d1 Pad ace step 1.5 ref audio if not long enough. (#12341) 2026-02-07 00:02:11 -05:00
comfyanonymous
039955c527 Some fixes to previous pr. (#12339) 2026-02-06 20:14:52 -05:00
tdrussell
6a26328842 Support fp16 for Cosmos-Predict2 and Anima (#12249) 2026-02-06 20:12:15 -05:00
comfyanonymous
204e65b8dc Fix bug with last pr (#12338) 2026-02-06 19:48:20 -05:00
asagi4
a831c19b70 Fix return_word_ids=True with Anima tokenizer (#12328) 2026-02-06 19:38:04 -05:00
comfyanonymous
eba6c940fd Make ace step 1.5 base model work properly with default workflow. (#12337) 2026-02-06 19:14:56 -05:00
Jukka Seppänen
a1c101f861 EasyCache: Support LTX2 (#12231) 2026-02-06 00:43:09 -05:00
comfyanonymous
c2d7f07dbf Fix issue when using disable_unet_model_creation (#12315) 2026-02-05 19:24:09 -05:00
comfyanonymous
458292fef0 Fix some lowvram stuff with ace step 1.5 (#12312) 2026-02-05 19:15:04 -05:00
comfyanonymous
6555dc65b8 Make ace step 1.5 work without the llm. (#12311) 2026-02-05 16:43:45 -05:00
AustinMroz
2b70ab9ad0 Add a Create List node (#12173) 2026-02-05 01:18:21 -05:00
Comfy Org PR Bot
00efcc6cd0 Bump comfyui-frontend-package to 1.38.13 (#12238) 2026-02-05 01:17:37 -05:00
comfyanonymous
cb459573c8 ComfyUI v0.12.3 2026-02-05 01:13:35 -05:00
comfyanonymous
35183543e0 Add VAE tiled decode node for audio. (#12299) 2026-02-05 01:12:04 -05:00
blepping
a246cc02b2 Improvements to ACE-Steps 1.5 text encoding (#12283) 2026-02-05 00:17:37 -05:00
comfyanonymous
a50c32d63f Disable sage attention on ace step 1.5 (#12297) 2026-02-04 22:15:30 -05:00
comfyanonymous
6125b80979 Add llm sampling options and make reference audio work on ace step 1.5 (#12295) 2026-02-04 21:29:22 -05:00
comfyanonymous
c8fcbd66ee Try to fix ace text encoder slowness on some configs. (#12290) 2026-02-04 19:37:05 -05:00
comfyanonymous
26dd7eb421 Fix ace step nan issue on some hardware/pytorch configs. (#12289) 2026-02-04 18:25:06 -05:00
Alexander Piskun
e77b34dfea add File3DAny output to Load3D node; extend SaveGLB to accept File3DAny as input (#12276)
* add File3DAny output to Load3D node; extend SaveGLB node to accept File3DAny as input

* fix(grammar): capitalize letter
2026-02-04 11:35:38 -08:00
rattus
ef73070ea4 mp: Fix checkpoint saving (#12268)
Fix regression in the recent model saving refactor. Pass the non unet
pieces down the layers so that checkpoints are complete.
2026-02-04 02:08:45 -05:00
rattus
d30c609f5a utils: safetensors: dont slice data on torch level (#12266)
Torch has alignment enforcement when viewing with data type changes
but only relative to itself. Do all tensor constructions straight
off the memory-view individually so pytorch doesnt see an alignment
problem.

The is needed for handling misaligned safetensors weights, which are
reasonably common in third party models.

This limits usage of this safetensors loader to GPU compute only
as CPUs kernnel are very likely to bus error. But it works for
dynamic_vram, where we really dont want to take a deep copy and we
always use GPU copy_ which disentangles the misalignment.
2026-02-04 01:48:47 -05:00
comfyanonymous
5087f1d497 ComfyUI v0.12.2 2026-02-04 00:08:59 -05:00
comfyanonymous
a31681564d Fix crash with ace step 1.5 (#12264) 2026-02-04 00:03:21 -05:00
rattus
855849c658 mm: Remove Aimdo exemption for empty_cache (#12260)
Its more important to get the torch caching allocator GC up and running
than supporting the pyt2.7 bug. Switch it on.

Defeature dynamic_vram + pyt2.7.
2026-02-03 21:39:19 -05:00
comfyanonymous
fe2511468d Support the 4B ace step 1.5 lm model. (#12257)
Can be used as an alternative to the 1.7B
2026-02-03 19:01:38 -05:00
comfyanonymous
3be0175166 ComfyUI v0.12.1 2026-02-03 15:01:46 -05:00
comfyanonymous
b8315e66cb Fix tiled vae for ace step 1.5 (#12253) 2026-02-03 14:40:45 -05:00
comfyanonymous
ab1050bec3 Support ace step 1.5 base model loras. (#12252) 2026-02-03 13:54:23 -05:00
Alexander Piskun
fb23935c11 feat(comfy_api): add basic 3D Model file types (#12129)
* feat(comfy_api): add basic 3D Model file types

* update Tripo nodes to use File3DGLB

* update Rodin3D nodes to use File3DGLB

* address PR review feedback:

- Rename File3D parameter 'path' to 'source'
- Convert File3D.data property to get_data()
- Make .glb extension check case-insensitive in nodes_rodin.py
- Restrict SaveGLB node to only accept File3DGLB

* Fixed a bug in the Meshy Rig and Animation nodes

* Fix backward compatability
2026-02-03 10:31:46 -08:00
comfyanonymous
85fc35e8fa Fix mac issue. (#12250) 2026-02-03 12:19:39 -05:00
comfyanonymous
223364743c llama: cast logits as a comfy-weight (#12248)
This is using a different layers weight with .to(). Change it to use
the ops caster if the original layer is a comfy weight so that it picks
up dynamic_vram and async_offload functionality in full.

Co-authored-by: Rattus <rattus128@gmail.com>
2026-02-03 11:31:36 -05:00
comfyanonymous
affe881354 Fix some issues with mac. (#12247) 2026-02-03 11:07:04 -05:00
comfyanonymous
f5030e26fd Add progress bar to ace step. (#12242) 2026-02-03 04:09:30 -05:00
comfyanonymous
66e1b07402 ComfyUI v0.12.0 2026-02-03 02:20:59 -05:00
ComfyUI Wiki
be4345d1c9 chore: update workflow templates to v0.8.31 (#12239) 2026-02-02 23:08:43 -08:00
comfyanonymous
3c1a1a2df8 Basic support for the ace step 1.5 model. (#12237) 2026-02-03 00:06:18 -05:00
Alexander Piskun
ba5bf3f1a8 [API Nodes] HitPaw API nodes (#12117)
* feat(api-nodes): add HitPaw API nodes

* remove face_soft_2x model as not working

---------

Co-authored-by: Robin Huang <robin.j.huang@gmail.com>
2026-02-02 19:17:59 -08:00
comfyanonymous
c05a08ae66 Add back function. (#12234) 2026-02-02 19:52:07 -05:00
rattus
de9ada6a41 Dynamic VRAM unloading fix (#12227)
* mp: fix full dynamic unloading

This was not unloading dynamic models when requesting a full unload via
the unpatch() code path.

This was ok, i your workflow was all dynamic models but fails with big
VRAM leaks if you need to fully unload something for a regular ModelPatcher

It also fices the "unload models" button.

* mm: load models outside of Aimdo Mempool

In dynamic_vram mode, escape the Aimdo mempool and load into the regular
mempool. Use a dummy thread to do it.
2026-02-02 17:35:20 -05:00
rattus
37f711d4a1 mm: Fix cast buffers with intel offloading (#12229)
Intel has offloading support but there were some nvidia calls in the
new cast buffer stuff.
2026-02-02 17:34:46 -05:00
comfyanonymous
dd86b15521 Enable embeddings for some qwen 3 models. (#12218) 2026-02-02 03:51:09 -05:00
comfyanonymous
021ba20719 Fix issue with parameters on root model object. (#12216) 2026-02-01 20:12:52 -05:00
rattus
b60be02aaf requirements: bump comfy-aimdo to 0.1.7 (#12211) 2026-02-01 20:10:15 -05:00
rattus
2b5da3b72e dynamic_vram: silence pytorch buffer warning (#12210)
This is log clutter and concerning to users. Its a false alarm.
2026-02-01 20:09:55 -05:00
rattus
794d05bdb1 dynamic_vram: respect argument cast dtypes in non-comfy weights (#12209)
This function has a dtype argument that allows the caller to set the
dtype in the cast. TIL Some models override this on weight casts, which
means its the highest priority.

Priority scheme is: argument > model dtype > state dict dtype
2026-02-01 20:09:21 -05:00
rattus
361b9a82a3 fix pinning with model defined dtype (#12208)
pinned memory was converted back to pinning the CPU side weight without
any changes. Fix the pinner to use the CPU weight and not the model defined
geometry. This will either save RAM or stop buffer overruns when the types
mismatch.

Fix the model defined weight caster to use the [ s.weight, s.bias ]
interpretation, as xfer_dest might be the flattened pin now. Fix the detection
of needing to cast to not be conditional on !pin.
2026-02-01 08:42:32 -08:00
comfyanonymous
667a1b8878 Fix some custom nodes breaking. (#12203) 2026-02-01 01:55:18 -05:00
Christian Byrne
32621c6a11 fix: improve error message when node type is missing (#12194)
- Change error type from 'invalid_prompt' to 'missing_node_type' for frontend detection
- Add extra_info with node_id, class_type, and node_title (from _meta.title)
- Improve user-facing message: 'Node X not found. The custom node may not be installed.'
2026-02-01 01:13:48 -05:00
rattus
f8acd9c402 Reduce RAM usage, fix VRAM OOMs, and fix Windows shared memory spilling with adaptive model loading (#11845) 2026-02-01 01:01:11 -05:00
comfyanonymous
873de5f37a KV cache implementation for using llama models for text generation. (#12195) 2026-01-31 21:11:11 -05:00
Jedrzej Kosinski
aa6f7a83bb Send is_input_list on v1 and v3 schema to frontend (#12188) 2026-01-31 20:05:11 -05:00
Jedrzej Kosinski
6ea8c128a3 Assets Part 2 - add more endpoints (#12125) 2026-01-31 02:22:05 -05:00
Alexander Piskun
6e469a3f35 feat(api-nodes): add Q3 models and support for Extend and MultiFrame Vidu endpoints (#12175)
Co-authored-by: Jedrzej Kosinski <kosinkadink1@gmail.com>
2026-01-30 22:44:08 -08:00
comfyanonymous
b8f848bfe3 Fix model not working with any res. (#12186) 2026-01-31 00:12:48 -05:00
comfyanonymous
4064062e7d Update python patch version in dep workflow. (#12184) 2026-01-30 20:20:06 -05:00
pythongosssss
8aabe2403e Add color type and Color to RGB Int node (#12145)
* add color type and color to rgb int node

* review fix for allowing output

---------

Co-authored-by: Jedrzej Kosinski <kosinkadink1@gmail.com>
2026-01-30 15:01:33 -08:00
Alexander Piskun
0167653781 feat(api-nodes): add RecraftCreateStyleNode node (#12055)
Co-authored-by: Jedrzej Kosinski <kosinkadink1@gmail.com>
2026-01-30 14:04:43 -08:00
Jedrzej Kosinski
0a7993729c Remove NodeInfoV3-related code; we are almost 100% guaranteed to stick with NodeInfoV1 for the foreseable future (#12147)
Co-authored-by: guill <jacob.e.segal@gmail.com>
2026-01-30 10:21:48 -08:00
comfyanonymous
bbe2c13a70 Make empty hunyuan latent 1.0 work with the 1.5 model. (#12171) 2026-01-29 23:52:22 -05:00
Christian Byrne
3aace5c8dc fix: count non-dict items in outputs_count (#12166)
Move count increment before isinstance(item, dict) check so that
non-dict output items (like text strings from PreviewAny node)
are included in outputs_count.

This aligns OSS Python with Cloud's Go implementation which uses
len(itemsArray) to count ALL items regardless of type.

Amp-Thread-ID: https://ampcode.com/threads/T-019c0bb5-14e0-744f-8808-1e57653f3ae3

Co-authored-by: Amp <amp@ampcode.com>
2026-01-29 17:10:08 -08:00
comfyanonymous
b0d9708974 ComfyUI v0.11.1 2026-01-29 00:27:23 -05:00
comfyanonymous
c9b633d84f Add missing spacial downscale ratios. (#12146) 2026-01-28 20:52:51 -05:00
ComfyUI Wiki
1711020904 chore: update workflow templates to v0.8.27 (#12141) 2026-01-28 12:48:02 -05:00
Dr.Lt.Data
d9b8567547 bump manager version to 4.1b1 (#12140) 2026-01-28 12:47:37 -05:00
Alexander Piskun
6c5f906bf2 feat(api-nodes): add Grok Imagine nodes (#12136) 2026-01-28 12:46:57 -05:00
comfyanonymous
4f5bd39b1c Update Python 3.14 compatibility notes in README (#12127) 2026-01-27 19:58:48 -05:00
guill
dcff27fe3f Add support for dev-only nodes. (#12106)
When a node is declared as dev-only, it doesn't show in the default UI
unless the dev mode is enabled in the settings. The intention is to
allow nodes related to unit testing to be included in ComfyUI
distributions without confusing the average user.
2026-01-27 13:03:29 -08:00
comfyanonymous
09725967cf ComfyUI version v0.11.0 2026-01-26 23:08:01 -05:00
ComfyUI Wiki
5f62440fbb chore: update workflow templates to v0.8.24 (#12103) 2026-01-26 22:47:33 -05:00
ComfyUI Wiki
ac91c340f4 Update workflow templates to v0.8.23 (#12102) 2026-01-26 21:39:39 -05:00
comfyanonymous
2db3b0ff90 Update amd portable for rocm 7.2 (#12101)
* Update amd portable for rocm 7.2

* Update Python patch version in release workflow
2026-01-26 19:49:31 -05:00
rattus
6516ab335d wan-vae: Switch off feature cache for single frame (#12090)
The code throughout is None safe to just skip the feature cache saving
step if none. Set it none in single frame use so qwen doesn't burn VRAM
on the unused cache.
2026-01-26 19:40:19 -05:00
Jukka Seppänen
ad53e78f11 Fix Noise_EmptyNoise when using nested latents (#12089) 2026-01-26 19:25:00 -05:00
Alexander Piskun
29011ba87e [API Nodes] add Magnific nodes (#11986)
* feat(api-nodes): add Magnific nodes

* aggressive downscaling should not be performed

* disable upscaler nodes

---------

Co-authored-by: Jedrzej Kosinski <kosinkadink1@gmail.com>
2026-01-26 14:10:09 -08:00
Alexander Piskun
cd4985e2f3 chore(api-nodes): remove ByteDanceImageEditNode node (seededit) (#12069)
Co-authored-by: Jedrzej Kosinski <kosinkadink1@gmail.com>
2026-01-26 13:58:33 -08:00
Tavi Halperin
bfe31d0b9d IC-LoRA: support small grid (#12074) 2026-01-26 15:33:19 -05:00
comfyanonymous
2129e7d278 Fix mistral 3 tokenizer code failing on latest transformers version and other breakage. (#12095)
* Fix mistral 3 tokenizer code failing on latest transformers version.

* Add requests to the requirements
2026-01-26 11:39:00 -05:00
comfyanonymous
7ee77ff038 Add name to LoraLoaderModelOnly. (#12078) 2026-01-25 21:01:55 -05:00
comfyanonymous
26c5bbb875 Move nodes from previous PR into their own file. (#12066) 2026-01-24 23:02:32 -05:00
Kohaku-Blueleaf
a97c98068f [Weight-adapter/Trainer] Bypass forward mode in Weight adapter system (#11958)
* Add API of bypass forward module

* bypass implementation

* add bypass fwd into nodes list/trainer
2026-01-24 22:56:22 -05:00
comfyanonymous
635406e283 Only enable fp16 on z image models that actually support it. (#12065) 2026-01-24 22:32:28 -05:00
pythongosssss
ed6002cb60 add support for kwargs inputs to allow arbitrary inputs from frontend (#12063)
used to output selected combo index

Co-authored-by: Jedrzej Kosinski <kosinkadink1@gmail.com>
2026-01-24 17:30:40 -08:00
Alexander Piskun
bc72d7f8d1 [API Nodes] add TencentHunyuan3D nodes (#12026)
* feat(api-nodes): add TencentHunyuan3D nodes

* add "(Pro)" to display name

---------

Co-authored-by: Jedrzej Kosinski <kosinkadink1@gmail.com>
2026-01-24 17:10:09 -08:00
comfyanonymous
aef4e13588 Make empty latent node work with other models. (#12062) 2026-01-24 19:23:20 -05:00
rattus
4e6a1b66a9 speed up and reduce VRAM of QWEN VAE and WAN (less so) (#12036)
* ops: introduce autopad for conv3d

This works around pytorch missing ability to causal pad as part of the
kernel and avoids massive weight duplications for padding.

* wan-vae: rework causal padding

This currently uses F.pad which takes a full deep copy and is liable to
be the VRAM peak. Instead, kick spatial padding back to the op and
consolidate the temporal padding with the cat for the cache.

* wan-vae: implement zero pad fast path

The WAN VAE is also QWEN where it is used single-image. These
convolutions are however zero padded 3d convolutions, which means the
VAE is actually just 2D down the last element of the conv weight in
the temporal dimension. Fast path this, to avoid adding zeros that
then just evaporate in convoluton math but cost computation.
2026-01-23 19:56:14 -05:00
comfyanonymous
9cf299a9f9 Make regular empty latent node work properly on flux 2 variants. (#12050) 2026-01-23 19:50:48 -05:00
ComfyUI Wiki
e89b22993a Support ModelScope-Trainer/DiffSynth LoRA format for Flux.2 Klein models (#12042) 2026-01-23 15:27:49 -05:00
Jukka Seppänen
55bd606e92 LTX2: Refactor forward function for better VRAM efficiency and fix spatial inpainting (#12046)
* Disable timestep embed compression when inpainting

Spatial inpainting not compatible with the compression

* Reduce crossattn peak VRAM

* LTX2: Refactor forward function for better VRAM efficiency
2026-01-23 15:26:38 -05:00
Christian Byrne
79cdbc81cb feat: Improve ResizeImageMaskNode UX with tooltips and search aliases (#12040)
- Add search_aliases for discoverability: resize, scale, dimensions, etc.
- Add node description for hover tooltip
- Add tooltips to all inputs explaining their behavior
- Reorder options: most common (scale dimensions) first, most technical (scale to multiple) last

Addresses user feedback that 'resize' search returned nothing useful and
options like 'match size' and 'scale to multiple' were not self-explanatory.
2026-01-22 22:04:27 -08:00
comfyanonymous
f443b9f2ca Revert "feat: Improve ResizeImageMaskNode UX with tooltips and search aliases…" (#12038)
This reverts commit 4e3038114a.
2026-01-22 23:02:37 -05:00
Christian Byrne
4e3038114a feat: Improve ResizeImageMaskNode UX with tooltips and search aliases (#12013)
- Add search_aliases for discoverability: resize, scale, dimensions, etc.
- Add node description for hover tooltip
- Add tooltips to all inputs explaining their behavior
- Reorder options: most common (scale dimensions) first, most technical (scale to multiple) last

Addresses user feedback that 'resize' search returned nothing useful and
options like 'match size' and 'scale to multiple' were not self-explanatory.
2026-01-22 18:46:55 -08:00
Christian Byrne
bbb8864778 add search aliases to all nodes (#12035)
* feat: Add search_aliases field to node schema

Adds `search_aliases` field to improve node discoverability. Users can define alternative search terms for nodes (e.g., "text concat" → StringConcatenate).

Changes:
- Add `search_aliases: list[str]` to V3 Schema
- Add `SEARCH_ALIASES` support for V1 nodes
- Include field in `/object_info` response
- Add aliases to high-priority core nodes

V1 usage:
```python
class MyNode:
    SEARCH_ALIASES = ["alt name", "synonym"]
```

V3 usage:
```python
io.Schema(
    node_id="MyNode",
    search_aliases=["alt name", "synonym"],
    ...
)
```

## Related PRs
- Frontend: Comfy-Org/ComfyUI_frontend#XXXX (draft - merge after this)
- Docs: Comfy-Org/docs#XXXX (draft - merge after stable)

* Propagate search_aliases through V3 Schema.get_v1_info to NodeInfoV1

* feat: add SEARCH_ALIASES for core nodes (#12016)

Add search aliases to 22 core nodes in nodes.py to improve node discoverability:
- Checkpoint/model loaders: CheckpointLoader, DiffusersLoader
- Conditioning nodes: ConditioningAverage, ConditioningSetArea, ConditioningSetMask, ConditioningZeroOut
- Style nodes: StyleModelApply
- Image nodes: LoadImageMask, LoadImageOutput, ImageBatch, ImageInvert, ImagePadForOutpaint
- Latent nodes: LoadLatent, SaveLatent, LatentBlend, LatentComposite, LatentCrop, LatentFlip, LatentFromBatch, LatentUpscale, LatentUpscaleBy, RepeatLatentBatch

* feat: add SEARCH_ALIASES for image, mask, and string nodes (#12017)

Add search aliases to nodes in comfy_extras for better discoverability:
- nodes_mask.py: mask manipulation nodes
- nodes_images.py: image processing nodes
- nodes_post_processing.py: post-processing effect nodes
- nodes_string.py: string manipulation nodes
- nodes_compositing.py: compositing nodes
- nodes_morphology.py: morphological operation nodes
- nodes_latent.py: latent space nodes

Uses search_aliases parameter in io.Schema() for v3 nodes.

* feat: add SEARCH_ALIASES for audio and video nodes (#12018)

Add search aliases to audio and video nodes for better discoverability:
- nodes_audio.py: audio loading, saving, and processing nodes
- nodes_video.py: video loading and processing nodes
- nodes_wan.py: WAN model nodes

Uses search_aliases parameter in io.Schema() for v3 nodes.

* feat: add SEARCH_ALIASES for model and misc nodes (#12019)

Add search aliases to model-related and miscellaneous nodes:
- Model nodes: nodes_model_merging.py, nodes_model_advanced.py, nodes_lora_extract.py
- Sampler nodes: nodes_custom_sampler.py, nodes_align_your_steps.py
- Control nodes: nodes_controlnet.py, nodes_attention_multiply.py, nodes_hooks.py
- Training nodes: nodes_train.py, nodes_dataset.py
- Utility nodes: nodes_logic.py, nodes_canny.py, nodes_differential_diffusion.py
- Architecture-specific: nodes_sd3.py, nodes_pixart.py, nodes_lumina2.py, nodes_kandinsky5.py, nodes_hidream.py, nodes_fresca.py, nodes_hunyuan3d.py
- Media nodes: nodes_load_3d.py, nodes_webcam.py, nodes_preview_any.py, nodes_wanmove.py

Uses search_aliases parameter in io.Schema() for v3 nodes, SEARCH_ALIASES class attribute for legacy nodes.
2026-01-22 18:36:58 -08:00
Omri Marom
d7f3241bf6 qwen_image: propagate attention mask. (#11966) 2026-01-22 20:02:31 -05:00
comfyanonymous
09a2e67151 Support loading flux 2 klein checkpoints saved with SaveCheckpoint. (#12033) 2026-01-22 18:20:48 -05:00
rattus
0fd1b78736 Reduce LTX2 VAE VRAM consumption (#12028)
* causal_video_ae: Remove attention ResNet

This attention_head_dim argument does not exist on this constructor so
this is dead code. Remove as generic attention mid VAE conflicts with
temporal roll.

* ltx-vae: consoldate causal/non-causal code paths

* ltx-vae: add cache rolling adder

* ltx-vae: use cached adder for resnet

* ltx-vae: Implement rolling VAE

Implement a temporal rolling VAE for the LTX2 VAE.

Usually when doing temporal rolling VAEs you can just chunk on time relying
on causality and cache behind you as you go. The LTX VAE is however
non-causal.

So go whole hog and implement per layer run ahead and backpressure between
the decoder layers using recursive state beween the layers.

Operations are ammended with temporal_cache_state{} which they can use to
hold any state then need for partial execution. Convolutions cache their
inputs behind the up to N-1 frames, and skip connections need to cache the
mismatch between convolution input and output that happens due to missing
future (non-causal) input.

Each call to run_up() processes a layer accross a range on input that
may or may not be complete. It goes depth first to process as much as
possible to try and digest frames to the final output ASAP. If layers run
out of input due to convolution losses, they simply return without action
effectively applying back-pressure to the earlier layers. As the earlier
layers do more work and caller deeper, the partial states are reconciled
and output continues to digest depth first as much as possible.

Chunking is done using a size quota rather than a fixed frame length and
any layer can initiate chunking, and multiple layers can chunk at different
granulatiries. This remove the old limitation of always having to process
1 latent frame to entirety and having to hold 8 full decoded frames as
the VRAM peak.
2026-01-22 16:54:18 -05:00
Terry Jia
8490eedadf add ply & 3dgs format in 3d node (#11474) 2026-01-22 09:46:56 -08:00
Alexander Piskun
72f6be1690 chore(api-nodes): rename BriaImage and OpenAIGImage nodes (#12022) 2026-01-21 23:42:04 -08:00
Jukka Seppänen
16b9aabd52 Support Multi/InfiniteTalk (#10179)
* re-init

* Update model_multitalk.py

* whitespace...

* Update model_multitalk.py

* remove print

* this is redundant

* remove import

* Restore preview functionality

* Move block_idx to transformer_options

* Remove LoopingSamplerCustomAdvanced

* Remove looping functionality, keep extension functionality

* Update model_multitalk.py

* Handle ref_attn_mask with separate patch to avoid having to always return q and k from self_attn

* Chunk attention map calculation for multiple speakers to reduce peak VRAM usage

* Update model_multitalk.py

* Add ModelPatch type back

* Fix for latest upstream

* Use DynamicCombo for cleaner node

Basically just so that single_speaker mode hides mask inputs and 2nd audio input

* Update nodes_wan.py
2026-01-21 23:09:48 -05:00
Jukka Seppänen
245f6139b6 More targeted embedding_connector loading for LTX2 text encoder (#11992)
Reduces errors
2026-01-21 23:05:06 -05:00
Jukka Seppänen
3365ad18a5 Support LTX2 tiny vae (taeltx_2) (#11929) 2026-01-21 23:03:51 -05:00
Jedrzej Kosinski
f09904720d Fix for edge case of EasyCache when conditionings change during a sampling run (like with timestep scheduling) (#12020) 2026-01-21 23:01:35 -05:00
comfyanonymous
abe2ec26a6 Support the Anima model. (#12012) 2026-01-21 19:44:28 -05:00
Christian Byrne
bdeac8897e feat: Add search_aliases field to node schema (#12010)
* feat: Add search_aliases field to node schema

Adds `search_aliases` field to improve node discoverability. Users can define alternative search terms for nodes (e.g., "text concat" → StringConcatenate).

Changes:
- Add `search_aliases: list[str]` to V3 Schema
- Add `SEARCH_ALIASES` support for V1 nodes
- Include field in `/object_info` response
- Add aliases to high-priority core nodes

V1 usage:
```python
class MyNode:
    SEARCH_ALIASES = ["alt name", "synonym"]
```

V3 usage:
```python
io.Schema(
    node_id="MyNode",
    search_aliases=["alt name", "synonym"],
    ...
)
```

## Related PRs
- Frontend: Comfy-Org/ComfyUI_frontend#XXXX (draft - merge after this)
- Docs: Comfy-Org/docs#XXXX (draft - merge after stable)

* Propagate search_aliases through V3 Schema.get_v1_info to NodeInfoV1
2026-01-21 15:36:02 -08:00
Alexander Piskun
451af70154 fix(api-nodes-Vidu): allow passing up to 7 subjects in Vidu Reference node (#12002) 2026-01-21 04:03:45 -08:00
Markury
0fc15700be Add LyCoris LoKr MLP layer support for Flux2 (#11997) 2026-01-20 23:18:33 -05:00
comfyanonymous
e755268e7b Config for Qwen 3 0.6B model. (#11998) 2026-01-20 23:08:31 -05:00
Mylo
c4a14df9a3 Dynamically detect chroma radiance patch size (#11991) 2026-01-20 18:46:11 -05:00
Ivan Zorin
965d0ed509 fix: remove normalization of audio in LTX Mel spectrogram creation (#11990)
For LTX Audio VAE, remove normalization of audio during MEL spectrogram creation.
This aligs inference with training and prevents loud audio from being attenuated.
2026-01-20 18:44:28 -05:00
Alexander Piskun
ddc541ffda feat(api-nodes): add WaveSpeed nodes (#11945) 2026-01-20 13:05:40 -08:00
comfyanonymous
8ccc0c94fa Make omni stuff work on regular z image for easier testing. (#11985) 2026-01-20 00:32:00 -05:00
Comfy Org PR Bot
4edb87aa50 Bump comfyui-frontend-package to 1.37.11 (#11976) 2026-01-19 23:57:50 -05:00
ComfyUI Wiki
0fc3b6e3a6 chore: update workflow templates to v0.8.15 (#11984) 2026-01-19 23:17:56 -05:00
comfyanonymous
2108167f9f Support zimage omni base model. (#11979) 2026-01-19 23:17:38 -05:00
comfyanonymous
9d273d3ab1 ComfyUI v0.10.0 2026-01-19 22:40:18 -05:00
comfyanonymous
70c91b8248 Fix #11963 (#11982) 2026-01-19 22:32:40 -05:00
rkfg
0da5a0fe58 Convert mono audio to fake stereo for LTXV VAE encoding (#11965) 2026-01-19 22:12:02 -05:00
comfyanonymous
e0eacb0688 Simpler way to implement the #11980 loras. (#11981) 2026-01-19 22:00:36 -05:00
Jedrzej Kosinski
7458e20465 Make Autogrow validation work properly (#11977)
* In-progress autogrow validation fixes - properly looks at required/optional inputs, now working on the edge case that all inputs are optional and nothing is plugged in (should just be an empty dictionary passed into node)

* Allow autogrow to work with all inputs being optional

* Revert accidentally pushed changes to nodes_logic.py
2026-01-19 16:58:30 -08:00
Jedrzej Kosinski
b931b37e30 feat(api-nodes): add Bria Edit node (#11978)
Co-authored-by: Alexander Piskun <bigcat88@icloud.com>
2026-01-19 16:47:14 -08:00
ComfyUI Wiki
866a4619db chore: update workflow templates to v0.8.14 (#11974) 2026-01-19 14:21:35 -08:00
comfyanonymous
1a72bf2046 Readme update. (#11957) 2026-01-18 19:53:43 -08:00
Alexander Piskun
034fac7054 chore(api-nodes): auto-discover all nodes_*.py files to avoid merge conflicts when adding new API nodes (#11943) 2026-01-17 22:40:39 -08:00
Christian Byrne
a498556d0d feat: add advanced parameter to Input classes for advanced widgets support (#11939)
Add 'advanced' boolean parameter to Input and WidgetInput base classes
and propagate to all typed Input subclasses (Boolean, Int, Float, String,
Combo, MultiCombo, Webcam, MultiType, MatchType, ImageCompare).

When set to True, the frontend will hide these inputs by default in a
collapsible 'Advanced Inputs' section in the right side panel, reducing
visual clutter for power-user options.

This enables nodes to expose advanced configuration options (like encoding
parameters, quality settings, etc.) without overwhelming typical users.

Frontend support: ComfyUI_frontend PR #7812
2026-01-17 19:06:03 -08:00
Alexander Piskun
f7ca41ff62 chore(api-nodes): remove check for pyav>=14.2 in code (it was added to requirements.txt long ago) (#11934) 2026-01-17 18:57:57 -08:00
Alexander Piskun
ac26065e61 chore(api-nodes): remove non-used; extract model to separate files (#11927)
* chore(api-nodes): remove non-used; extract model to separate files

* chore(api-nodes): remove non-needed prefix in filenames
2026-01-17 18:52:45 -08:00
comfyanonymous
190c4416cc Bump comfy-kitchen dependency to version 0.2.7 (#11941) 2026-01-17 21:20:35 -05:00
Theephop
0fd10ffa09 fix: use .cpu() for waveform conversion in AudioFrame creation (#11787) 2026-01-17 20:18:24 -05:00
Alex Butler
00c775950a Update readme rdna3 nightly url (#11937) 2026-01-17 20:18:04 -05:00
comfyanonymous
7ac999bf30 Add image sizes to clip vision outputs. (#11923) 2026-01-16 23:02:28 -05:00
ComfyUI Wiki
0c6b36c6ac chore: update workflow templates to v0.8.11 (#11918) 2026-01-16 17:22:50 -05:00
Alexander Piskun
9125613b53 feat(api-nodes): extend ByteDance nodes with seedance-1-5-pro model (#11871) 2026-01-15 22:09:07 -08:00
Jedrzej Kosinski
732b707397 Added try-except around seed_assets call in get_object_info with a logging statement (#11901) 2026-01-15 23:15:15 -05:00
comfyanonymous
4c816d5c69 Adjust memory usage factor calculation for flux2 klein. (#11900) 2026-01-15 20:06:40 -05:00
ComfyUI Wiki
6125b3a5e7 Update workflow templates to v0.8.10 (#11899)
* chore: update workflow templates to v0.8.9

* Update requirements.txt
2026-01-15 13:12:13 -08:00
ComfyUI Wiki
12918a5f78 chore: update workflow templates to v0.8.7 (#11896) 2026-01-15 11:08:21 -08:00
comfyanonymous
8f40b43e02 ComfyUI v0.9.2 2026-01-15 10:57:35 -05:00
comfyanonymous
3b832231bb Flux2 Klein support. (#11890) 2026-01-15 10:33:15 -05:00
Jukka Seppänen
be518db5a7 Remove extraneous clip missing warnings when loading LTX2 embeddings_connector weights (#11874) 2026-01-14 17:54:04 -05:00
rattus
80441eb15e utils: fix lanczos grayscale upscaling (#11873) 2026-01-14 17:53:16 -05:00
Alexander Piskun
07f2462eae feat(api-nodes): add Meshy 3D nodes (#11843)
* feat(api-nodes): add Meshy 3D nodes

* rebased, added JSONata price badges
2026-01-14 11:25:38 -08:00
comfyanonymous
d150440466 Fix VAELoader (#11880) 2026-01-14 10:54:50 -08:00
comfyanonymous
6165c38cb5 Optimize nvfp4 lora applying. (#11866)
This changes results a bit but it also speeds up things a lot.
2026-01-14 00:49:38 -05:00
Silver
712cca36a1 feat: throttle ProgressBar updates to reduce WebSocket flooding (#11504) 2026-01-13 22:41:44 -05:00
Johnpaul Chiwetelu
ac4d8ea9b3 feat: add CI container version bump automation (#11692)
* feat: add CI container version bump automation

Adds a workflow that triggers on releases to create PRs in the
comfyui-ci-container repo, updating the ComfyUI version in the Dockerfile.

Supports both release events and manual workflow dispatch for testing.

* feat: add CI container version bump automation

Adds a workflow that triggers on releases to create PRs in the
comfyui-ci-container repo, updating the ComfyUI version in the Dockerfile.

Supports both release events and manual workflow dispatch for testing.

* ci: update CI container repository owner

* refactor: rename `update-ci-container.yaml` workflow to `update-ci-container.yml`

* Remove post-merge instructions from the CI container update workflow.
2026-01-13 22:39:22 -05:00
nomadoor
c9196f355e Fix scale_shorter_dimension portrait check (#11862) 2026-01-13 18:25:09 -08:00
Christian Byrne
7eb959ce93 fix: update ComfyUI repo reference to Comfy-Org/ComfyUI (#11858) 2026-01-13 21:03:16 -05:00
nomadoor
469dd9c16a Adds crop to multiple mode to ResizeImageMaskNode. (#11838)
* Add crop-to-multiple resize mode

* Make scale-to-multiple shape handling explicit
2026-01-13 16:48:10 -08:00
comfyanonymous
eff2b9d412 Optimize nvfp4 lora applying. (#11856) 2026-01-13 19:37:19 -05:00
comfyanonymous
15b312de7a Optimize nvfp4 lora applying. (#11854) 2026-01-13 19:23:58 -05:00
Alexander Piskun
1419047fdb [Api Nodes]: Improve Price Badge Declarations (#11582)
* api nodes: price badges moved to nodes code

* added price badges for 4 more node-packs

* added price badges for 10 more node-packs

* added new price badges for Omni STD mode

* add support for autogrow groups

* use full names for "widgets", "inputs" and "groups"

* add strict typing for JSONata rules

* add price badge for WanReferenceVideoApi node

* add support for DynamicCombo

* sync price badges changes (https://github.com/Comfy-Org/ComfyUI_frontend/pull/7900)

* sync badges for Vidu2 nodes

* fixed incorrect price for RecraftCrispUpscaleNode

* fixed incorrect price badges for LTXV nodes

* fixed price badge for MinimaxHailuoVideoNode

* fixed price badges for PixVerse nodes
2026-01-13 16:18:28 -08:00
ric-yu
79f6bb5e4f add blueprints dir for built-in blueprints (#11853) 2026-01-13 16:14:40 -08:00
Jukka Seppänen
e4b4fb3479 Load metadata on VAELoader (#11846)
Needed to load the proper LTX2 VAE if separated from checkpoint
2026-01-13 17:37:21 -05:00
Acly
d9dc02a7d6 Support "lite" version of alibaba-pai Z-Image Controlnet (#11849)
* reduced number of control layers (3) compared to full model
2026-01-13 15:03:53 -05:00
Alexander Piskun
c543ad81c3 fix(api-nodes-gemini): raise exception when no candidates due to safety block (#11848) 2026-01-13 08:30:13 -08:00
comfyanonymous
5ac1372533 ComfyUI v0.9.1 2026-01-13 01:44:06 -05:00
comfyanonymous
1dcbd9efaf Bump ltxav mem estimation a bit. (#11842) 2026-01-13 01:42:07 -05:00
327 changed files with 29924 additions and 2822 deletions

127
.coderabbit.yaml Normal file
View File

@@ -0,0 +1,127 @@
# yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json
language: "en-US"
early_access: false
tone_instructions: "Only comment on issues introduced by this PR's changes. Do not flag pre-existing problems in moved, re-indented, or reformatted code."
reviews:
profile: "chill"
request_changes_workflow: false
high_level_summary: false
poem: false
review_status: false
review_details: false
commit_status: true
collapse_walkthrough: true
changed_files_summary: false
sequence_diagrams: false
estimate_code_review_effort: false
assess_linked_issues: false
related_issues: false
related_prs: false
suggested_labels: false
auto_apply_labels: false
suggested_reviewers: false
auto_assign_reviewers: false
in_progress_fortune: false
enable_prompt_for_ai_agents: true
path_filters:
- "!comfy_api_nodes/apis/**"
- "!**/generated/*.pyi"
- "!.ci/**"
- "!script_examples/**"
- "!**/__pycache__/**"
- "!**/*.ipynb"
- "!**/*.png"
- "!**/*.bat"
path_instructions:
- path: "**"
instructions: |
IMPORTANT: Only comment on issues directly introduced by this PR's code changes.
Do NOT flag pre-existing issues in code that was merely moved, re-indented,
de-indented, or reformatted without logic changes. If code appears in the diff
only due to whitespace or structural reformatting (e.g., removing a `with:` block),
treat it as unchanged. Contributors should not feel obligated to address
pre-existing issues outside the scope of their contribution.
- path: "comfy/**"
instructions: |
Core ML/diffusion engine. Focus on:
- Backward compatibility (breaking changes affect all custom nodes)
- Memory management and GPU resource handling
- Performance implications in hot paths
- Thread safety for concurrent execution
- path: "comfy_api_nodes/**"
instructions: |
Third-party API integration nodes. Focus on:
- No hardcoded API keys or secrets
- Proper error handling for API failures (timeouts, rate limits, auth errors)
- Correct Pydantic model usage
- Security of user data passed to external APIs
- path: "comfy_extras/**"
instructions: |
Community-contributed extra nodes. Focus on:
- Consistency with node patterns (INPUT_TYPES, RETURN_TYPES, FUNCTION, CATEGORY)
- No breaking changes to existing node interfaces
- path: "comfy_execution/**"
instructions: |
Execution engine (graph execution, caching, jobs). Focus on:
- Caching correctness
- Concurrent execution safety
- Graph validation edge cases
- path: "nodes.py"
instructions: |
Core node definitions (2500+ lines). Focus on:
- Backward compatibility of NODE_CLASS_MAPPINGS
- Consistency of INPUT_TYPES return format
- path: "alembic_db/**"
instructions: |
Database migrations. Focus on:
- Migration safety and rollback support
- Data preservation during schema changes
auto_review:
enabled: true
auto_incremental_review: true
drafts: false
ignore_title_keywords:
- "WIP"
- "DO NOT REVIEW"
- "DO NOT MERGE"
finishing_touches:
docstrings:
enabled: false
unit_tests:
enabled: false
tools:
ruff:
enabled: false
pylint:
enabled: false
flake8:
enabled: false
gitleaks:
enabled: true
shellcheck:
enabled: false
markdownlint:
enabled: false
yamllint:
enabled: false
languagetool:
enabled: false
github-checks:
enabled: true
timeout_ms: 90000
ast-grep:
essential_rules: true
chat:
auto_reply: true
knowledge_base:
opt_out: false
learnings:
scope: "auto"

View File

@@ -16,7 +16,7 @@ body:
## Very Important
Please make sure that you post ALL your ComfyUI logs in the bug report. A bug report without logs will likely be ignored.
Please make sure that you post ALL your ComfyUI logs in the bug report **even if there is no crash**. Just paste everything. The startup log (everything before "To see the GUI go to: ...") contains critical information to developers trying to help. For a performance issue or crash, paste everything from "got prompt" to the end, including the crash. More is better - always. A bug report without logs will likely be ignored.
- type: checkboxes
id: custom-nodes-test
attributes:

View File

@@ -20,7 +20,7 @@ jobs:
git_tag: ${{ inputs.git_tag }}
cache_tag: "cu130"
python_minor: "13"
python_patch: "9"
python_patch: "11"
rel_name: "nvidia"
rel_extra_name: ""
test_release: true
@@ -65,11 +65,11 @@ jobs:
contents: "write"
packages: "write"
pull-requests: "read"
name: "Release AMD ROCm 7.1.1"
name: "Release AMD ROCm 7.2"
uses: ./.github/workflows/stable-release.yml
with:
git_tag: ${{ inputs.git_tag }}
cache_tag: "rocm711"
cache_tag: "rocm72"
python_minor: "12"
python_patch: "10"
rel_name: "amd"

View File

@@ -7,6 +7,8 @@ on:
jobs:
send-webhook:
runs-on: ubuntu-latest
env:
DESKTOP_REPO_DISPATCH_TOKEN: ${{ secrets.DESKTOP_REPO_DISPATCH_TOKEN }}
steps:
- name: Send release webhook
env:
@@ -106,3 +108,37 @@ jobs:
--fail --silent --show-error
echo "✅ Release webhook sent successfully"
- name: Send repository dispatch to desktop
env:
DISPATCH_TOKEN: ${{ env.DESKTOP_REPO_DISPATCH_TOKEN }}
RELEASE_TAG: ${{ github.event.release.tag_name }}
RELEASE_URL: ${{ github.event.release.html_url }}
run: |
set -euo pipefail
if [ -z "${DISPATCH_TOKEN:-}" ]; then
echo "::error::DESKTOP_REPO_DISPATCH_TOKEN is required but not set."
exit 1
fi
PAYLOAD="$(jq -n \
--arg release_tag "$RELEASE_TAG" \
--arg release_url "$RELEASE_URL" \
'{
event_type: "comfyui_release_published",
client_payload: {
release_tag: $release_tag,
release_url: $release_url
}
}')"
curl -fsSL \
-X POST \
-H "Accept: application/vnd.github+json" \
-H "Content-Type: application/json" \
-H "Authorization: Bearer ${DISPATCH_TOKEN}" \
https://api.github.com/repos/Comfy-Org/desktop/dispatches \
-d "$PAYLOAD"
echo "✅ Dispatched ComfyUI release ${RELEASE_TAG} to Comfy-Org/desktop"

View File

@@ -13,7 +13,7 @@ jobs:
- name: Checkout ComfyUI
uses: actions/checkout@v4
with:
repository: "comfyanonymous/ComfyUI"
repository: "Comfy-Org/ComfyUI"
path: "ComfyUI"
- uses: actions/setup-python@v4
with:

View File

@@ -0,0 +1,59 @@
name: "CI: Update CI Container"
on:
release:
types: [published]
workflow_dispatch:
inputs:
version:
description: 'ComfyUI version (e.g., v0.7.0)'
required: true
type: string
jobs:
update-ci-container:
runs-on: ubuntu-latest
# Skip pre-releases unless manually triggered
if: github.event_name == 'workflow_dispatch' || !github.event.release.prerelease
steps:
- name: Get version
id: version
run: |
if [ "${{ github.event_name }}" = "release" ]; then
VERSION="${{ github.event.release.tag_name }}"
else
VERSION="${{ inputs.version }}"
fi
echo "version=$VERSION" >> $GITHUB_OUTPUT
- name: Checkout comfyui-ci-container
uses: actions/checkout@v4
with:
repository: comfy-org/comfyui-ci-container
token: ${{ secrets.CI_CONTAINER_PAT }}
- name: Check current version
id: current
run: |
CURRENT=$(grep -oP 'ARG COMFYUI_VERSION=\K.*' Dockerfile || echo "unknown")
echo "current_version=$CURRENT" >> $GITHUB_OUTPUT
- name: Update Dockerfile
run: |
VERSION="${{ steps.version.outputs.version }}"
sed -i "s/^ARG COMFYUI_VERSION=.*/ARG COMFYUI_VERSION=${VERSION}/" Dockerfile
- name: Create Pull Request
id: create-pr
uses: peter-evans/create-pull-request@v7
with:
token: ${{ secrets.CI_CONTAINER_PAT }}
branch: automation/comfyui-${{ steps.version.outputs.version }}
title: "chore: bump ComfyUI to ${{ steps.version.outputs.version }}"
body: |
Updates ComfyUI version from `${{ steps.current.outputs.current_version }}` to `${{ steps.version.outputs.version }}`
**Triggered by:** ${{ github.event_name == 'release' && format('[Release {0}]({1})', github.event.release.tag_name, github.event.release.html_url) || 'Manual workflow dispatch' }}
labels: automation
commit-message: "chore: bump ComfyUI to ${{ steps.version.outputs.version }}"

View File

@@ -29,7 +29,7 @@ on:
description: 'python patch version'
required: true
type: string
default: "9"
default: "11"
# push:
# branches:
# - master

2
.gitignore vendored
View File

@@ -11,7 +11,7 @@ extra_model_paths.yaml
/.vs
.vscode/
.idea/
venv/
venv*/
.venv/
/web/extensions/*
!/web/extensions/logging.js.example

View File

@@ -108,7 +108,7 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
- [LCM models and Loras](https://comfyanonymous.github.io/ComfyUI_examples/lcm/)
- Latent previews with [TAESD](#how-to-show-high-quality-previews)
- Works fully offline: core will never download anything unless you want to.
- Optional API nodes to use paid models from external providers through the online [Comfy API](https://docs.comfy.org/tutorials/api-nodes/overview).
- Optional API nodes to use paid models from external providers through the online [Comfy API](https://docs.comfy.org/tutorials/api-nodes/overview) disable with: `--disable-api-nodes`
- [Config file](extra_model_paths.yaml.example) to set the search paths for models.
Workflow examples can be found on the [Examples page](https://comfyanonymous.github.io/ComfyUI_examples/)
@@ -189,8 +189,6 @@ The portable above currently comes with python 3.13 and pytorch cuda 13.0. Updat
[Experimental portable for AMD GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_amd.7z)
[Portable with pytorch cuda 12.8 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu128.7z).
[Portable with pytorch cuda 12.6 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu126.7z) (Supports Nvidia 10 series and older GPUs).
#### How do I share models between another UI and ComfyUI?
@@ -208,11 +206,11 @@ comfy install
## Manual Install (Windows, Linux)
Python 3.14 works but you may encounter issues with the torch compile node. The free threaded variant is still missing some dependencies.
Python 3.14 works but some custom nodes may have issues. The free threaded variant works but some dependencies will enable the GIL so it's not fully supported.
Python 3.13 is very well supported. If you have trouble with some custom node dependencies on 3.13 you can try 3.12
torch 2.4 and above is supported but some features might only work on newer versions. We generally recommend using the latest major version of pytorch with the latest cuda version unless it is less than 2 weeks old.
torch 2.4 and above is supported but some features and optimizations might only work on newer versions. We generally recommend using the latest major version of pytorch with the latest cuda version unless it is less than 2 weeks old.
### Instructions:
@@ -227,11 +225,11 @@ Put your VAE in: models/vae
AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.4```
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm7.1```
This is the command to install the nightly with ROCm 7.0 which might have some performance improvements:
This is the command to install the nightly with ROCm 7.2 which might have some performance improvements:
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm7.1```
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm7.2```
### AMD GPUs (Experimental: Windows and Linux), RDNA 3, 3.5 and 4 only.
@@ -240,7 +238,7 @@ These have less hardware support than the builds above but they work on windows.
RDNA 3 (RX 7000 series):
```pip install --pre torch torchvision torchaudio --index-url https://rocm.nightlies.amd.com/v2/gfx110X-dgpu/```
```pip install --pre torch torchvision torchaudio --index-url https://rocm.nightlies.amd.com/v2/gfx110X-all/```
RDNA 3.5 (Strix halo/Ryzen AI Max+ 365):

View File

@@ -1,5 +1,8 @@
import logging
import uuid
import urllib.parse
import os
import contextlib
from aiohttp import web
from pydantic import ValidationError
@@ -8,6 +11,9 @@ import app.assets.manager as manager
from app import user_manager
from app.assets.api import schemas_in
from app.assets.helpers import get_query_dict
from app.assets.scanner import seed_assets
import folder_paths
ROUTES = web.RouteTableDef()
USER_MANAGER: user_manager.UserManager | None = None
@@ -15,6 +21,9 @@ USER_MANAGER: user_manager.UserManager | None = None
# UUID regex (canonical hyphenated form, case-insensitive)
UUID_RE = r"[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}"
# Note to any custom node developers reading this code:
# The assets system is not yet fully implemented, do not rely on the code in /app/assets remaining the same.
def register_assets_system(app: web.Application, user_manager_instance: user_manager.UserManager) -> None:
global USER_MANAGER
USER_MANAGER = user_manager_instance
@@ -28,6 +37,18 @@ def _validation_error_response(code: str, ve: ValidationError) -> web.Response:
return _error_response(400, code, "Validation failed.", {"errors": ve.json()})
@ROUTES.head("/api/assets/hash/{hash}")
async def head_asset_by_hash(request: web.Request) -> web.Response:
hash_str = request.match_info.get("hash", "").strip().lower()
if not hash_str or ":" not in hash_str:
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
algo, digest = hash_str.split(":", 1)
if algo != "blake3" or not digest or any(c for c in digest if c not in "0123456789abcdef"):
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
exists = manager.asset_exists(asset_hash=hash_str)
return web.Response(status=200 if exists else 404)
@ROUTES.get("/api/assets")
async def list_assets(request: web.Request) -> web.Response:
"""
@@ -50,7 +71,7 @@ async def list_assets(request: web.Request) -> web.Response:
order=q.order,
owner_id=USER_MANAGER.get_request_user_id(request),
)
return web.json_response(payload.model_dump(mode="json"))
return web.json_response(payload.model_dump(mode="json", exclude_none=True))
@ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}")
@@ -76,6 +97,314 @@ async def get_asset(request: web.Request) -> web.Response:
return web.json_response(result.model_dump(mode="json"), status=200)
@ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}/content")
async def download_asset_content(request: web.Request) -> web.Response:
# question: do we need disposition? could we just stick with one of these?
disposition = request.query.get("disposition", "attachment").lower().strip()
if disposition not in {"inline", "attachment"}:
disposition = "attachment"
try:
abs_path, content_type, filename = manager.resolve_asset_content_for_download(
asset_info_id=str(uuid.UUID(request.match_info["id"])),
owner_id=USER_MANAGER.get_request_user_id(request),
)
except ValueError as ve:
return _error_response(404, "ASSET_NOT_FOUND", str(ve))
except NotImplementedError as nie:
return _error_response(501, "BACKEND_UNSUPPORTED", str(nie))
except FileNotFoundError:
return _error_response(404, "FILE_NOT_FOUND", "Underlying file not found on disk.")
quoted = (filename or "").replace("\r", "").replace("\n", "").replace('"', "'")
cd = f'{disposition}; filename="{quoted}"; filename*=UTF-8\'\'{urllib.parse.quote(filename)}'
file_size = os.path.getsize(abs_path)
logging.info(
"download_asset_content: path=%s, size=%d bytes (%.2f MB), content_type=%s, filename=%s",
abs_path,
file_size,
file_size / (1024 * 1024),
content_type,
filename,
)
async def file_sender():
chunk_size = 64 * 1024
with open(abs_path, "rb") as f:
while True:
chunk = f.read(chunk_size)
if not chunk:
break
yield chunk
return web.Response(
body=file_sender(),
content_type=content_type,
headers={
"Content-Disposition": cd,
"Content-Length": str(file_size),
},
)
@ROUTES.post("/api/assets/from-hash")
async def create_asset_from_hash(request: web.Request) -> web.Response:
try:
payload = await request.json()
body = schemas_in.CreateFromHashBody.model_validate(payload)
except ValidationError as ve:
return _validation_error_response("INVALID_BODY", ve)
except Exception:
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
result = manager.create_asset_from_hash(
hash_str=body.hash,
name=body.name,
tags=body.tags,
user_metadata=body.user_metadata,
owner_id=USER_MANAGER.get_request_user_id(request),
)
if result is None:
return _error_response(404, "ASSET_NOT_FOUND", f"Asset content {body.hash} does not exist")
return web.json_response(result.model_dump(mode="json"), status=201)
@ROUTES.post("/api/assets")
async def upload_asset(request: web.Request) -> web.Response:
"""Multipart/form-data endpoint for Asset uploads."""
if not (request.content_type or "").lower().startswith("multipart/"):
return _error_response(415, "UNSUPPORTED_MEDIA_TYPE", "Use multipart/form-data for uploads.")
reader = await request.multipart()
file_present = False
file_client_name: str | None = None
tags_raw: list[str] = []
provided_name: str | None = None
user_metadata_raw: str | None = None
provided_hash: str | None = None
provided_hash_exists: bool | None = None
file_written = 0
tmp_path: str | None = None
while True:
field = await reader.next()
if field is None:
break
fname = getattr(field, "name", "") or ""
if fname == "hash":
try:
s = ((await field.text()) or "").strip().lower()
except Exception:
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
if s:
if ":" not in s:
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
algo, digest = s.split(":", 1)
if algo != "blake3" or not digest or any(c for c in digest if c not in "0123456789abcdef"):
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
provided_hash = f"{algo}:{digest}"
try:
provided_hash_exists = manager.asset_exists(asset_hash=provided_hash)
except Exception:
provided_hash_exists = None # do not fail the whole request here
elif fname == "file":
file_present = True
file_client_name = (field.filename or "").strip()
if provided_hash and provided_hash_exists is True:
# If client supplied a hash that we know exists, drain but do not write to disk
try:
while True:
chunk = await field.read_chunk(8 * 1024 * 1024)
if not chunk:
break
file_written += len(chunk)
except Exception:
return _error_response(500, "UPLOAD_IO_ERROR", "Failed to receive uploaded file.")
continue # Do not create temp file; we will create AssetInfo from the existing content
# Otherwise, store to temp for hashing/ingest
uploads_root = os.path.join(folder_paths.get_temp_directory(), "uploads")
unique_dir = os.path.join(uploads_root, uuid.uuid4().hex)
os.makedirs(unique_dir, exist_ok=True)
tmp_path = os.path.join(unique_dir, ".upload.part")
try:
with open(tmp_path, "wb") as f:
while True:
chunk = await field.read_chunk(8 * 1024 * 1024)
if not chunk:
break
f.write(chunk)
file_written += len(chunk)
except Exception:
try:
if os.path.exists(tmp_path or ""):
os.remove(tmp_path)
finally:
return _error_response(500, "UPLOAD_IO_ERROR", "Failed to receive and store uploaded file.")
elif fname == "tags":
tags_raw.append((await field.text()) or "")
elif fname == "name":
provided_name = (await field.text()) or None
elif fname == "user_metadata":
user_metadata_raw = (await field.text()) or None
# If client did not send file, and we are not doing a from-hash fast path -> error
if not file_present and not (provided_hash and provided_hash_exists):
return _error_response(400, "MISSING_FILE", "Form must include a 'file' part or a known 'hash'.")
if file_present and file_written == 0 and not (provided_hash and provided_hash_exists):
# Empty upload is only acceptable if we are fast-pathing from existing hash
try:
if tmp_path and os.path.exists(tmp_path):
os.remove(tmp_path)
finally:
return _error_response(400, "EMPTY_UPLOAD", "Uploaded file is empty.")
try:
spec = schemas_in.UploadAssetSpec.model_validate({
"tags": tags_raw,
"name": provided_name,
"user_metadata": user_metadata_raw,
"hash": provided_hash,
})
except ValidationError as ve:
try:
if tmp_path and os.path.exists(tmp_path):
os.remove(tmp_path)
finally:
return _validation_error_response("INVALID_BODY", ve)
# Validate models category against configured folders (consistent with previous behavior)
if spec.tags and spec.tags[0] == "models":
if len(spec.tags) < 2 or spec.tags[1] not in folder_paths.folder_names_and_paths:
if tmp_path and os.path.exists(tmp_path):
os.remove(tmp_path)
return _error_response(
400, "INVALID_BODY", f"unknown models category '{spec.tags[1] if len(spec.tags) >= 2 else ''}'"
)
owner_id = USER_MANAGER.get_request_user_id(request)
# Fast path: if a valid provided hash exists, create AssetInfo without writing anything
if spec.hash and provided_hash_exists is True:
try:
result = manager.create_asset_from_hash(
hash_str=spec.hash,
name=spec.name or (spec.hash.split(":", 1)[1]),
tags=spec.tags,
user_metadata=spec.user_metadata or {},
owner_id=owner_id,
)
except Exception:
logging.exception("create_asset_from_hash failed for hash=%s, owner_id=%s", spec.hash, owner_id)
return _error_response(500, "INTERNAL", "Unexpected server error.")
if result is None:
return _error_response(404, "ASSET_NOT_FOUND", f"Asset content {spec.hash} does not exist")
# Drain temp if we accidentally saved (e.g., hash field came after file)
if tmp_path and os.path.exists(tmp_path):
with contextlib.suppress(Exception):
os.remove(tmp_path)
status = 200 if (not result.created_new) else 201
return web.json_response(result.model_dump(mode="json"), status=status)
# Otherwise, we must have a temp file path to ingest
if not tmp_path or not os.path.exists(tmp_path):
# The only case we reach here without a temp file is: client sent a hash that does not exist and no file
return _error_response(404, "ASSET_NOT_FOUND", "Provided hash not found and no file uploaded.")
try:
created = manager.upload_asset_from_temp_path(
spec,
temp_path=tmp_path,
client_filename=file_client_name,
owner_id=owner_id,
expected_asset_hash=spec.hash,
)
status = 201 if created.created_new else 200
return web.json_response(created.model_dump(mode="json"), status=status)
except ValueError as e:
if tmp_path and os.path.exists(tmp_path):
os.remove(tmp_path)
msg = str(e)
if "HASH_MISMATCH" in msg or msg.strip().upper() == "HASH_MISMATCH":
return _error_response(
400,
"HASH_MISMATCH",
"Uploaded file hash does not match provided hash.",
)
return _error_response(400, "BAD_REQUEST", "Invalid inputs.")
except Exception:
if tmp_path and os.path.exists(tmp_path):
os.remove(tmp_path)
logging.exception("upload_asset_from_temp_path failed for tmp_path=%s, owner_id=%s", tmp_path, owner_id)
return _error_response(500, "INTERNAL", "Unexpected server error.")
@ROUTES.put(f"/api/assets/{{id:{UUID_RE}}}")
async def update_asset(request: web.Request) -> web.Response:
asset_info_id = str(uuid.UUID(request.match_info["id"]))
try:
body = schemas_in.UpdateAssetBody.model_validate(await request.json())
except ValidationError as ve:
return _validation_error_response("INVALID_BODY", ve)
except Exception:
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
try:
result = manager.update_asset(
asset_info_id=asset_info_id,
name=body.name,
user_metadata=body.user_metadata,
owner_id=USER_MANAGER.get_request_user_id(request),
)
except (ValueError, PermissionError) as ve:
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
except Exception:
logging.exception(
"update_asset failed for asset_info_id=%s, owner_id=%s",
asset_info_id,
USER_MANAGER.get_request_user_id(request),
)
return _error_response(500, "INTERNAL", "Unexpected server error.")
return web.json_response(result.model_dump(mode="json"), status=200)
@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}")
async def delete_asset(request: web.Request) -> web.Response:
asset_info_id = str(uuid.UUID(request.match_info["id"]))
delete_content = request.query.get("delete_content")
delete_content = True if delete_content is None else delete_content.lower() not in {"0", "false", "no"}
try:
deleted = manager.delete_asset_reference(
asset_info_id=asset_info_id,
owner_id=USER_MANAGER.get_request_user_id(request),
delete_content_if_orphan=delete_content,
)
except Exception:
logging.exception(
"delete_asset_reference failed for asset_info_id=%s, owner_id=%s",
asset_info_id,
USER_MANAGER.get_request_user_id(request),
)
return _error_response(500, "INTERNAL", "Unexpected server error.")
if not deleted:
return _error_response(404, "ASSET_NOT_FOUND", f"AssetInfo {asset_info_id} not found.")
return web.Response(status=204)
@ROUTES.get("/api/tags")
async def get_tags(request: web.Request) -> web.Response:
"""
@@ -100,3 +429,86 @@ async def get_tags(request: web.Request) -> web.Response:
owner_id=USER_MANAGER.get_request_user_id(request),
)
return web.json_response(result.model_dump(mode="json"))
@ROUTES.post(f"/api/assets/{{id:{UUID_RE}}}/tags")
async def add_asset_tags(request: web.Request) -> web.Response:
asset_info_id = str(uuid.UUID(request.match_info["id"]))
try:
payload = await request.json()
data = schemas_in.TagsAdd.model_validate(payload)
except ValidationError as ve:
return _error_response(400, "INVALID_BODY", "Invalid JSON body for tags add.", {"errors": ve.errors()})
except Exception:
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
try:
result = manager.add_tags_to_asset(
asset_info_id=asset_info_id,
tags=data.tags,
origin="manual",
owner_id=USER_MANAGER.get_request_user_id(request),
)
except (ValueError, PermissionError) as ve:
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
except Exception:
logging.exception(
"add_tags_to_asset failed for asset_info_id=%s, owner_id=%s",
asset_info_id,
USER_MANAGER.get_request_user_id(request),
)
return _error_response(500, "INTERNAL", "Unexpected server error.")
return web.json_response(result.model_dump(mode="json"), status=200)
@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}/tags")
async def delete_asset_tags(request: web.Request) -> web.Response:
asset_info_id = str(uuid.UUID(request.match_info["id"]))
try:
payload = await request.json()
data = schemas_in.TagsRemove.model_validate(payload)
except ValidationError as ve:
return _error_response(400, "INVALID_BODY", "Invalid JSON body for tags remove.", {"errors": ve.errors()})
except Exception:
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
try:
result = manager.remove_tags_from_asset(
asset_info_id=asset_info_id,
tags=data.tags,
owner_id=USER_MANAGER.get_request_user_id(request),
)
except ValueError as ve:
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
except Exception:
logging.exception(
"remove_tags_from_asset failed for asset_info_id=%s, owner_id=%s",
asset_info_id,
USER_MANAGER.get_request_user_id(request),
)
return _error_response(500, "INTERNAL", "Unexpected server error.")
return web.json_response(result.model_dump(mode="json"), status=200)
@ROUTES.post("/api/assets/seed")
async def seed_assets_endpoint(request: web.Request) -> web.Response:
"""Trigger asset seeding for specified roots (models, input, output)."""
try:
payload = await request.json()
roots = payload.get("roots", ["models", "input", "output"])
except Exception:
roots = ["models", "input", "output"]
valid_roots = [r for r in roots if r in ("models", "input", "output")]
if not valid_roots:
return _error_response(400, "INVALID_BODY", "No valid roots specified")
try:
seed_assets(tuple(valid_roots))
except Exception:
logging.exception("seed_assets failed for roots=%s", valid_roots)
return _error_response(500, "INTERNAL", "Seed operation failed")
return web.json_response({"seeded": valid_roots}, status=200)

View File

@@ -1,5 +1,4 @@
import json
import uuid
from typing import Any, Literal
from pydantic import (
@@ -8,9 +7,9 @@ from pydantic import (
Field,
conint,
field_validator,
model_validator,
)
class ListAssetsQuery(BaseModel):
include_tags: list[str] = Field(default_factory=list)
exclude_tags: list[str] = Field(default_factory=list)
@@ -57,6 +56,57 @@ class ListAssetsQuery(BaseModel):
return None
class UpdateAssetBody(BaseModel):
name: str | None = None
user_metadata: dict[str, Any] | None = None
@model_validator(mode="after")
def _at_least_one(self):
if self.name is None and self.user_metadata is None:
raise ValueError("Provide at least one of: name, user_metadata.")
return self
class CreateFromHashBody(BaseModel):
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
hash: str
name: str
tags: list[str] = Field(default_factory=list)
user_metadata: dict[str, Any] = Field(default_factory=dict)
@field_validator("hash")
@classmethod
def _require_blake3(cls, v):
s = (v or "").strip().lower()
if ":" not in s:
raise ValueError("hash must be 'blake3:<hex>'")
algo, digest = s.split(":", 1)
if algo != "blake3":
raise ValueError("only canonical 'blake3:<hex>' is accepted here")
if not digest or any(c for c in digest if c not in "0123456789abcdef"):
raise ValueError("hash digest must be lowercase hex")
return s
@field_validator("tags", mode="before")
@classmethod
def _tags_norm(cls, v):
if v is None:
return []
if isinstance(v, list):
out = [str(t).strip().lower() for t in v if str(t).strip()]
seen = set()
dedup = []
for t in out:
if t not in seen:
seen.add(t)
dedup.append(t)
return dedup
if isinstance(v, str):
return [t.strip().lower() for t in v.split(",") if t.strip()]
return []
class TagsListQuery(BaseModel):
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
@@ -75,20 +125,140 @@ class TagsListQuery(BaseModel):
return v.lower() or None
class SetPreviewBody(BaseModel):
"""Set or clear the preview for an AssetInfo. Provide an Asset.id or null."""
preview_id: str | None = None
class TagsAdd(BaseModel):
model_config = ConfigDict(extra="ignore")
tags: list[str] = Field(..., min_length=1)
@field_validator("preview_id", mode="before")
@field_validator("tags")
@classmethod
def _norm_uuid(cls, v):
def normalize_tags(cls, v: list[str]) -> list[str]:
out = []
for t in v:
if not isinstance(t, str):
raise TypeError("tags must be strings")
tnorm = t.strip().lower()
if tnorm:
out.append(tnorm)
seen = set()
deduplicated = []
for x in out:
if x not in seen:
seen.add(x)
deduplicated.append(x)
return deduplicated
class TagsRemove(TagsAdd):
pass
class UploadAssetSpec(BaseModel):
"""Upload Asset operation.
- tags: ordered; first is root ('models'|'input'|'output');
if root == 'models', second must be a valid category from folder_paths.folder_names_and_paths
- name: display name
- user_metadata: arbitrary JSON object (optional)
- hash: optional canonical 'blake3:<hex>' provided by the client for validation / fast-path
Files created via this endpoint are stored on disk using the **content hash** as the filename stem
and the original extension is preserved when available.
"""
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
tags: list[str] = Field(..., min_length=1)
name: str | None = Field(default=None, max_length=512, description="Display Name")
user_metadata: dict[str, Any] = Field(default_factory=dict)
hash: str | None = Field(default=None)
@field_validator("hash", mode="before")
@classmethod
def _parse_hash(cls, v):
if v is None:
return None
s = str(v).strip()
s = str(v).strip().lower()
if not s:
return None
try:
uuid.UUID(s)
except Exception:
raise ValueError("preview_id must be a UUID")
return s
if ":" not in s:
raise ValueError("hash must be 'blake3:<hex>'")
algo, digest = s.split(":", 1)
if algo != "blake3":
raise ValueError("only canonical 'blake3:<hex>' is accepted here")
if not digest or any(c for c in digest if c not in "0123456789abcdef"):
raise ValueError("hash digest must be lowercase hex")
return f"{algo}:{digest}"
@field_validator("tags", mode="before")
@classmethod
def _parse_tags(cls, v):
"""
Accepts a list of strings (possibly multiple form fields),
where each string can be:
- JSON array (e.g., '["models","loras","foo"]')
- comma-separated ('models, loras, foo')
- single token ('models')
Returns a normalized, deduplicated, ordered list.
"""
items: list[str] = []
if v is None:
return []
if isinstance(v, str):
v = [v]
if isinstance(v, list):
for item in v:
if item is None:
continue
s = str(item).strip()
if not s:
continue
if s.startswith("["):
try:
arr = json.loads(s)
if isinstance(arr, list):
items.extend(str(x) for x in arr)
continue
except Exception:
pass # fallback to CSV parse below
items.extend([p for p in s.split(",") if p.strip()])
else:
return []
# normalize + dedupe
norm = []
seen = set()
for t in items:
tnorm = str(t).strip().lower()
if tnorm and tnorm not in seen:
seen.add(tnorm)
norm.append(tnorm)
return norm
@field_validator("user_metadata", mode="before")
@classmethod
def _parse_metadata_json(cls, v):
if v is None or isinstance(v, dict):
return v or {}
if isinstance(v, str):
s = v.strip()
if not s:
return {}
try:
parsed = json.loads(s)
except Exception as e:
raise ValueError(f"user_metadata must be JSON: {e}") from e
if not isinstance(parsed, dict):
raise ValueError("user_metadata must be a JSON object")
return parsed
return {}
@model_validator(mode="after")
def _validate_order(self):
if not self.tags:
raise ValueError("tags must be provided and non-empty")
root = self.tags[0]
if root not in {"models", "input", "output"}:
raise ValueError("first tag must be one of: models, input, output")
if root == "models":
if len(self.tags) < 2:
raise ValueError("models uploads require a category tag as the second tag")
return self

View File

@@ -29,6 +29,21 @@ class AssetsList(BaseModel):
has_more: bool
class AssetUpdated(BaseModel):
id: str
name: str
asset_hash: str | None = None
tags: list[str] = Field(default_factory=list)
user_metadata: dict[str, Any] = Field(default_factory=dict)
updated_at: datetime | None = None
model_config = ConfigDict(from_attributes=True)
@field_serializer("updated_at")
def _ser_updated(self, v: datetime | None, _info):
return v.isoformat() if v else None
class AssetDetail(BaseModel):
id: str
name: str
@@ -48,6 +63,10 @@ class AssetDetail(BaseModel):
return v.isoformat() if v else None
class AssetCreated(AssetDetail):
created_new: bool
class TagUsage(BaseModel):
name: str
count: int
@@ -58,3 +77,17 @@ class TagsList(BaseModel):
tags: list[TagUsage] = Field(default_factory=list)
total: int
has_more: bool
class TagsAdd(BaseModel):
model_config = ConfigDict(str_strip_whitespace=True)
added: list[str] = Field(default_factory=list)
already_present: list[str] = Field(default_factory=list)
total_tags: list[str] = Field(default_factory=list)
class TagsRemove(BaseModel):
model_config = ConfigDict(str_strip_whitespace=True)
removed: list[str] = Field(default_factory=list)
not_present: list[str] = Field(default_factory=list)
total_tags: list[str] = Field(default_factory=list)

View File

@@ -1,9 +1,17 @@
import os
import logging
import sqlalchemy as sa
from collections import defaultdict
from sqlalchemy import select, exists, func
from datetime import datetime
from typing import Iterable, Any
from sqlalchemy import select, delete, exists, func
from sqlalchemy.dialects import sqlite
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session, contains_eager, noload
from app.assets.database.models import Asset, AssetInfo, AssetInfoMeta, AssetInfoTag, Tag
from app.assets.helpers import escape_like_prefix, normalize_tags
from app.assets.database.models import Asset, AssetInfo, AssetCacheState, AssetInfoMeta, AssetInfoTag, Tag
from app.assets.helpers import (
compute_relative_filename, escape_like_prefix, normalize_tags, project_kv, utcnow
)
from typing import Sequence
@@ -15,6 +23,22 @@ def visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement:
return AssetInfo.owner_id.in_(["", owner_id])
def pick_best_live_path(states: Sequence[AssetCacheState]) -> str:
"""
Return the best on-disk path among cache states:
1) Prefer a path that exists with needs_verify == False (already verified).
2) Otherwise, pick the first path that exists.
3) Otherwise return empty string.
"""
alive = [s for s in states if getattr(s, "file_path", None) and os.path.isfile(s.file_path)]
if not alive:
return ""
for s in alive:
if not getattr(s, "needs_verify", False):
return s.file_path
return alive[0].file_path
def apply_tag_filters(
stmt: sa.sql.Select,
include_tags: Sequence[str] | None = None,
@@ -42,6 +66,7 @@ def apply_tag_filters(
)
return stmt
def apply_metadata_filter(
stmt: sa.sql.Select,
metadata_filter: dict | None = None,
@@ -94,7 +119,11 @@ def apply_metadata_filter(
return stmt
def asset_exists_by_hash(session: Session, asset_hash: str) -> bool:
def asset_exists_by_hash(
session: Session,
*,
asset_hash: str,
) -> bool:
"""
Check if an asset with a given hash exists in database.
"""
@@ -105,9 +134,39 @@ def asset_exists_by_hash(session: Session, asset_hash: str) -> bool:
).first()
return row is not None
def get_asset_info_by_id(session: Session, asset_info_id: str) -> AssetInfo | None:
def asset_info_exists_for_asset_id(
session: Session,
*,
asset_id: str,
) -> bool:
q = (
select(sa.literal(True))
.select_from(AssetInfo)
.where(AssetInfo.asset_id == asset_id)
.limit(1)
)
return (session.execute(q)).first() is not None
def get_asset_by_hash(
session: Session,
*,
asset_hash: str,
) -> Asset | None:
return (
session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))
).scalars().first()
def get_asset_info_by_id(
session: Session,
*,
asset_info_id: str,
) -> AssetInfo | None:
return session.get(AssetInfo, asset_info_id)
def list_asset_infos_page(
session: Session,
owner_id: str = "",
@@ -171,12 +230,14 @@ def list_asset_infos_page(
select(AssetInfoTag.asset_info_id, Tag.name)
.join(Tag, Tag.name == AssetInfoTag.tag_name)
.where(AssetInfoTag.asset_info_id.in_(id_list))
.order_by(AssetInfoTag.added_at)
)
for aid, tag_name in rows.all():
tag_map[aid].append(tag_name)
return infos, tag_map, total
def fetch_asset_info_asset_and_tags(
session: Session,
asset_info_id: str,
@@ -208,6 +269,494 @@ def fetch_asset_info_asset_and_tags(
tags.append(tag_name)
return first_info, first_asset, tags
def fetch_asset_info_and_asset(
session: Session,
*,
asset_info_id: str,
owner_id: str = "",
) -> tuple[AssetInfo, Asset] | None:
stmt = (
select(AssetInfo, Asset)
.join(Asset, Asset.id == AssetInfo.asset_id)
.where(
AssetInfo.id == asset_info_id,
visible_owner_clause(owner_id),
)
.limit(1)
.options(noload(AssetInfo.tags))
)
row = session.execute(stmt)
pair = row.first()
if not pair:
return None
return pair[0], pair[1]
def list_cache_states_by_asset_id(
session: Session, *, asset_id: str
) -> Sequence[AssetCacheState]:
return (
session.execute(
select(AssetCacheState)
.where(AssetCacheState.asset_id == asset_id)
.order_by(AssetCacheState.id.asc())
)
).scalars().all()
def touch_asset_info_by_id(
session: Session,
*,
asset_info_id: str,
ts: datetime | None = None,
only_if_newer: bool = True,
) -> None:
ts = ts or utcnow()
stmt = sa.update(AssetInfo).where(AssetInfo.id == asset_info_id)
if only_if_newer:
stmt = stmt.where(
sa.or_(AssetInfo.last_access_time.is_(None), AssetInfo.last_access_time < ts)
)
session.execute(stmt.values(last_access_time=ts))
def create_asset_info_for_existing_asset(
session: Session,
*,
asset_hash: str,
name: str,
user_metadata: dict | None = None,
tags: Sequence[str] | None = None,
tag_origin: str = "manual",
owner_id: str = "",
) -> AssetInfo:
"""Create or return an existing AssetInfo for an Asset identified by asset_hash."""
now = utcnow()
asset = get_asset_by_hash(session, asset_hash=asset_hash)
if not asset:
raise ValueError(f"Unknown asset hash {asset_hash}")
info = AssetInfo(
owner_id=owner_id,
name=name,
asset_id=asset.id,
preview_id=None,
created_at=now,
updated_at=now,
last_access_time=now,
)
try:
with session.begin_nested():
session.add(info)
session.flush()
except IntegrityError:
existing = (
session.execute(
select(AssetInfo)
.options(noload(AssetInfo.tags))
.where(
AssetInfo.asset_id == asset.id,
AssetInfo.name == name,
AssetInfo.owner_id == owner_id,
)
.limit(1)
)
).unique().scalars().first()
if not existing:
raise RuntimeError("AssetInfo upsert failed to find existing row after conflict.")
return existing
# metadata["filename"] hack
new_meta = dict(user_metadata or {})
computed_filename = None
try:
p = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=asset.id))
if p:
computed_filename = compute_relative_filename(p)
except Exception:
computed_filename = None
if computed_filename:
new_meta["filename"] = computed_filename
if new_meta:
replace_asset_info_metadata_projection(
session,
asset_info_id=info.id,
user_metadata=new_meta,
)
if tags is not None:
set_asset_info_tags(
session,
asset_info_id=info.id,
tags=tags,
origin=tag_origin,
)
return info
def set_asset_info_tags(
session: Session,
*,
asset_info_id: str,
tags: Sequence[str],
origin: str = "manual",
) -> dict:
desired = normalize_tags(tags)
current = set(
tag_name for (tag_name,) in (
session.execute(select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id))
).all()
)
to_add = [t for t in desired if t not in current]
to_remove = [t for t in current if t not in desired]
if to_add:
ensure_tags_exist(session, to_add, tag_type="user")
session.add_all([
AssetInfoTag(asset_info_id=asset_info_id, tag_name=t, origin=origin, added_at=utcnow())
for t in to_add
])
session.flush()
if to_remove:
session.execute(
delete(AssetInfoTag)
.where(AssetInfoTag.asset_info_id == asset_info_id, AssetInfoTag.tag_name.in_(to_remove))
)
session.flush()
return {"added": to_add, "removed": to_remove, "total": desired}
def replace_asset_info_metadata_projection(
session: Session,
*,
asset_info_id: str,
user_metadata: dict | None = None,
) -> None:
info = session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
info.user_metadata = user_metadata or {}
info.updated_at = utcnow()
session.flush()
session.execute(delete(AssetInfoMeta).where(AssetInfoMeta.asset_info_id == asset_info_id))
session.flush()
if not user_metadata:
return
rows: list[AssetInfoMeta] = []
for k, v in user_metadata.items():
for r in project_kv(k, v):
rows.append(
AssetInfoMeta(
asset_info_id=asset_info_id,
key=r["key"],
ordinal=int(r["ordinal"]),
val_str=r.get("val_str"),
val_num=r.get("val_num"),
val_bool=r.get("val_bool"),
val_json=r.get("val_json"),
)
)
if rows:
session.add_all(rows)
session.flush()
def ingest_fs_asset(
session: Session,
*,
asset_hash: str,
abs_path: str,
size_bytes: int,
mtime_ns: int,
mime_type: str | None = None,
info_name: str | None = None,
owner_id: str = "",
preview_id: str | None = None,
user_metadata: dict | None = None,
tags: Sequence[str] = (),
tag_origin: str = "manual",
require_existing_tags: bool = False,
) -> dict:
"""
Idempotently upsert:
- Asset by content hash (create if missing)
- AssetCacheState(file_path) pointing to asset_id
- Optionally AssetInfo + tag links and metadata projection
Returns flags and ids.
"""
locator = os.path.abspath(abs_path)
now = utcnow()
if preview_id:
if not session.get(Asset, preview_id):
preview_id = None
out: dict[str, Any] = {
"asset_created": False,
"asset_updated": False,
"state_created": False,
"state_updated": False,
"asset_info_id": None,
}
# 1) Asset by hash
asset = (
session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))
).scalars().first()
if not asset:
vals = {
"hash": asset_hash,
"size_bytes": int(size_bytes),
"mime_type": mime_type,
"created_at": now,
}
res = session.execute(
sqlite.insert(Asset)
.values(**vals)
.on_conflict_do_nothing(index_elements=[Asset.hash])
)
if int(res.rowcount or 0) > 0:
out["asset_created"] = True
asset = (
session.execute(
select(Asset).where(Asset.hash == asset_hash).limit(1)
)
).scalars().first()
if not asset:
raise RuntimeError("Asset row not found after upsert.")
else:
changed = False
if asset.size_bytes != int(size_bytes) and int(size_bytes) > 0:
asset.size_bytes = int(size_bytes)
changed = True
if mime_type and asset.mime_type != mime_type:
asset.mime_type = mime_type
changed = True
if changed:
out["asset_updated"] = True
# 2) AssetCacheState upsert by file_path (unique)
vals = {
"asset_id": asset.id,
"file_path": locator,
"mtime_ns": int(mtime_ns),
}
ins = (
sqlite.insert(AssetCacheState)
.values(**vals)
.on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
)
res = session.execute(ins)
if int(res.rowcount or 0) > 0:
out["state_created"] = True
else:
upd = (
sa.update(AssetCacheState)
.where(AssetCacheState.file_path == locator)
.where(
sa.or_(
AssetCacheState.asset_id != asset.id,
AssetCacheState.mtime_ns.is_(None),
AssetCacheState.mtime_ns != int(mtime_ns),
)
)
.values(asset_id=asset.id, mtime_ns=int(mtime_ns))
)
res2 = session.execute(upd)
if int(res2.rowcount or 0) > 0:
out["state_updated"] = True
# 3) Optional AssetInfo + tags + metadata
if info_name:
try:
with session.begin_nested():
info = AssetInfo(
owner_id=owner_id,
name=info_name,
asset_id=asset.id,
preview_id=preview_id,
created_at=now,
updated_at=now,
last_access_time=now,
)
session.add(info)
session.flush()
out["asset_info_id"] = info.id
except IntegrityError:
pass
existing_info = (
session.execute(
select(AssetInfo)
.where(
AssetInfo.asset_id == asset.id,
AssetInfo.name == info_name,
(AssetInfo.owner_id == owner_id),
)
.limit(1)
)
).unique().scalar_one_or_none()
if not existing_info:
raise RuntimeError("Failed to update or insert AssetInfo.")
if preview_id and existing_info.preview_id != preview_id:
existing_info.preview_id = preview_id
existing_info.updated_at = now
if existing_info.last_access_time < now:
existing_info.last_access_time = now
session.flush()
out["asset_info_id"] = existing_info.id
norm = [t.strip().lower() for t in (tags or []) if (t or "").strip()]
if norm and out["asset_info_id"] is not None:
if not require_existing_tags:
ensure_tags_exist(session, norm, tag_type="user")
existing_tag_names = set(
name for (name,) in (session.execute(select(Tag.name).where(Tag.name.in_(norm)))).all()
)
missing = [t for t in norm if t not in existing_tag_names]
if missing and require_existing_tags:
raise ValueError(f"Unknown tags: {missing}")
existing_links = set(
tag_name
for (tag_name,) in (
session.execute(
select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == out["asset_info_id"])
)
).all()
)
to_add = [t for t in norm if t in existing_tag_names and t not in existing_links]
if to_add:
session.add_all(
[
AssetInfoTag(
asset_info_id=out["asset_info_id"],
tag_name=t,
origin=tag_origin,
added_at=now,
)
for t in to_add
]
)
session.flush()
# metadata["filename"] hack
if out["asset_info_id"] is not None:
primary_path = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=asset.id))
computed_filename = compute_relative_filename(primary_path) if primary_path else None
current_meta = existing_info.user_metadata or {}
new_meta = dict(current_meta)
if user_metadata is not None:
for k, v in user_metadata.items():
new_meta[k] = v
if computed_filename:
new_meta["filename"] = computed_filename
if new_meta != current_meta:
replace_asset_info_metadata_projection(
session,
asset_info_id=out["asset_info_id"],
user_metadata=new_meta,
)
try:
remove_missing_tag_for_asset_id(session, asset_id=asset.id)
except Exception:
logging.exception("Failed to clear 'missing' tag for asset %s", asset.id)
return out
def update_asset_info_full(
session: Session,
*,
asset_info_id: str,
name: str | None = None,
tags: Sequence[str] | None = None,
user_metadata: dict | None = None,
tag_origin: str = "manual",
asset_info_row: Any = None,
) -> AssetInfo:
if not asset_info_row:
info = session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
else:
info = asset_info_row
touched = False
if name is not None and name != info.name:
info.name = name
touched = True
computed_filename = None
try:
p = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=info.asset_id))
if p:
computed_filename = compute_relative_filename(p)
except Exception:
computed_filename = None
if user_metadata is not None:
new_meta = dict(user_metadata)
if computed_filename:
new_meta["filename"] = computed_filename
replace_asset_info_metadata_projection(
session, asset_info_id=asset_info_id, user_metadata=new_meta
)
touched = True
else:
if computed_filename:
current_meta = info.user_metadata or {}
if current_meta.get("filename") != computed_filename:
new_meta = dict(current_meta)
new_meta["filename"] = computed_filename
replace_asset_info_metadata_projection(
session, asset_info_id=asset_info_id, user_metadata=new_meta
)
touched = True
if tags is not None:
set_asset_info_tags(
session,
asset_info_id=asset_info_id,
tags=tags,
origin=tag_origin,
)
touched = True
if touched and user_metadata is None:
info.updated_at = utcnow()
session.flush()
return info
def delete_asset_info_by_id(
session: Session,
*,
asset_info_id: str,
owner_id: str,
) -> bool:
stmt = sa.delete(AssetInfo).where(
AssetInfo.id == asset_info_id,
visible_owner_clause(owner_id),
)
return int((session.execute(stmt)).rowcount or 0) > 0
def list_tags_with_usage(
session: Session,
prefix: str | None = None,
@@ -265,3 +814,163 @@ def list_tags_with_usage(
rows_norm = [(name, ttype, int(count or 0)) for (name, ttype, count) in rows]
return rows_norm, int(total or 0)
def ensure_tags_exist(session: Session, names: Iterable[str], tag_type: str = "user") -> None:
wanted = normalize_tags(list(names))
if not wanted:
return
rows = [{"name": n, "tag_type": tag_type} for n in list(dict.fromkeys(wanted))]
ins = (
sqlite.insert(Tag)
.values(rows)
.on_conflict_do_nothing(index_elements=[Tag.name])
)
session.execute(ins)
def get_asset_tags(session: Session, *, asset_info_id: str) -> list[str]:
return [
tag_name for (tag_name,) in (
session.execute(
select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
)
).all()
]
def add_tags_to_asset_info(
session: Session,
*,
asset_info_id: str,
tags: Sequence[str],
origin: str = "manual",
create_if_missing: bool = True,
asset_info_row: Any = None,
) -> dict:
if not asset_info_row:
info = session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
norm = normalize_tags(tags)
if not norm:
total = get_asset_tags(session, asset_info_id=asset_info_id)
return {"added": [], "already_present": [], "total_tags": total}
if create_if_missing:
ensure_tags_exist(session, norm, tag_type="user")
current = {
tag_name
for (tag_name,) in (
session.execute(
sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
)
).all()
}
want = set(norm)
to_add = sorted(want - current)
if to_add:
with session.begin_nested() as nested:
try:
session.add_all(
[
AssetInfoTag(
asset_info_id=asset_info_id,
tag_name=t,
origin=origin,
added_at=utcnow(),
)
for t in to_add
]
)
session.flush()
except IntegrityError:
nested.rollback()
after = set(get_asset_tags(session, asset_info_id=asset_info_id))
return {
"added": sorted(((after - current) & want)),
"already_present": sorted(want & current),
"total_tags": sorted(after),
}
def remove_tags_from_asset_info(
session: Session,
*,
asset_info_id: str,
tags: Sequence[str],
) -> dict:
info = session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
norm = normalize_tags(tags)
if not norm:
total = get_asset_tags(session, asset_info_id=asset_info_id)
return {"removed": [], "not_present": [], "total_tags": total}
existing = {
tag_name
for (tag_name,) in (
session.execute(
sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
)
).all()
}
to_remove = sorted(set(t for t in norm if t in existing))
not_present = sorted(set(t for t in norm if t not in existing))
if to_remove:
session.execute(
delete(AssetInfoTag)
.where(
AssetInfoTag.asset_info_id == asset_info_id,
AssetInfoTag.tag_name.in_(to_remove),
)
)
session.flush()
total = get_asset_tags(session, asset_info_id=asset_info_id)
return {"removed": to_remove, "not_present": not_present, "total_tags": total}
def remove_missing_tag_for_asset_id(
session: Session,
*,
asset_id: str,
) -> None:
session.execute(
sa.delete(AssetInfoTag).where(
AssetInfoTag.asset_info_id.in_(sa.select(AssetInfo.id).where(AssetInfo.asset_id == asset_id)),
AssetInfoTag.tag_name == "missing",
)
)
def set_asset_info_preview(
session: Session,
*,
asset_info_id: str,
preview_asset_id: str | None = None,
) -> None:
"""Set or clear preview_id and bump updated_at. Raises on unknown IDs."""
info = session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if preview_asset_id is None:
info.preview_id = None
else:
# validate preview asset exists
if not session.get(Asset, preview_asset_id):
raise ValueError(f"Preview Asset {preview_asset_id} not found")
info.preview_id = preview_asset_id
info.updated_at = utcnow()
session.flush()

View File

@@ -1,5 +1,6 @@
import contextlib
import os
from decimal import Decimal
from aiohttp import web
from datetime import datetime, timezone
from pathlib import Path
@@ -87,6 +88,40 @@ def get_comfy_models_folders() -> list[tuple[str, list[str]]]:
targets.append((name, paths))
return targets
def resolve_destination_from_tags(tags: list[str]) -> tuple[str, list[str]]:
"""Validates and maps tags -> (base_dir, subdirs_for_fs)"""
root = tags[0]
if root == "models":
if len(tags) < 2:
raise ValueError("at least two tags required for model asset")
try:
bases = folder_paths.folder_names_and_paths[tags[1]][0]
except KeyError:
raise ValueError(f"unknown model category '{tags[1]}'")
if not bases:
raise ValueError(f"no base path configured for category '{tags[1]}'")
base_dir = os.path.abspath(bases[0])
raw_subdirs = tags[2:]
else:
base_dir = os.path.abspath(
folder_paths.get_input_directory() if root == "input" else folder_paths.get_output_directory()
)
raw_subdirs = tags[1:]
for i in raw_subdirs:
if i in (".", ".."):
raise ValueError("invalid path component in tags")
return base_dir, raw_subdirs if raw_subdirs else []
def ensure_within_base(candidate: str, base: str) -> None:
cand_abs = os.path.abspath(candidate)
base_abs = os.path.abspath(base)
try:
if os.path.commonpath([cand_abs, base_abs]) != base_abs:
raise ValueError("destination escapes base directory")
except Exception:
raise ValueError("invalid destination path")
def compute_relative_filename(file_path: str) -> str | None:
"""
Return the model's path relative to the last well-known folder (the model category),
@@ -113,7 +148,6 @@ def compute_relative_filename(file_path: str) -> str | None:
return "/".join(inside)
return "/".join(parts) # input/output: keep all parts
def get_relative_to_root_category_path_of_asset(file_path: str) -> tuple[Literal["input", "output", "models"], str]:
"""Given an absolute or relative file path, determine which root category the path belongs to:
- 'input' if the file resides under `folder_paths.get_input_directory()`
@@ -215,3 +249,64 @@ def collect_models_files() -> list[str]:
if allowed:
out.append(abs_path)
return out
def is_scalar(v):
if v is None:
return True
if isinstance(v, bool):
return True
if isinstance(v, (int, float, Decimal, str)):
return True
return False
def project_kv(key: str, value):
"""
Turn a metadata key/value into typed projection rows.
Returns list[dict] with keys:
key, ordinal, and one of val_str / val_num / val_bool / val_json (others None)
"""
rows: list[dict] = []
def _null_row(ordinal: int) -> dict:
return {
"key": key, "ordinal": ordinal,
"val_str": None, "val_num": None, "val_bool": None, "val_json": None
}
if value is None:
rows.append(_null_row(0))
return rows
if is_scalar(value):
if isinstance(value, bool):
rows.append({"key": key, "ordinal": 0, "val_bool": bool(value)})
elif isinstance(value, (int, float, Decimal)):
num = value if isinstance(value, Decimal) else Decimal(str(value))
rows.append({"key": key, "ordinal": 0, "val_num": num})
elif isinstance(value, str):
rows.append({"key": key, "ordinal": 0, "val_str": value})
else:
rows.append({"key": key, "ordinal": 0, "val_json": value})
return rows
if isinstance(value, list):
if all(is_scalar(x) for x in value):
for i, x in enumerate(value):
if x is None:
rows.append(_null_row(i))
elif isinstance(x, bool):
rows.append({"key": key, "ordinal": i, "val_bool": bool(x)})
elif isinstance(x, (int, float, Decimal)):
num = x if isinstance(x, Decimal) else Decimal(str(x))
rows.append({"key": key, "ordinal": i, "val_num": num})
elif isinstance(x, str):
rows.append({"key": key, "ordinal": i, "val_str": x})
else:
rows.append({"key": key, "ordinal": i, "val_json": x})
return rows
for i, x in enumerate(value):
rows.append({"key": key, "ordinal": i, "val_json": x})
return rows
rows.append({"key": key, "ordinal": 0, "val_json": value})
return rows

View File

@@ -1,13 +1,33 @@
import os
import mimetypes
import contextlib
from typing import Sequence
from app.database.db import create_session
from app.assets.api import schemas_out
from app.assets.api import schemas_out, schemas_in
from app.assets.database.queries import (
asset_exists_by_hash,
asset_info_exists_for_asset_id,
get_asset_by_hash,
get_asset_info_by_id,
fetch_asset_info_asset_and_tags,
fetch_asset_info_and_asset,
create_asset_info_for_existing_asset,
touch_asset_info_by_id,
update_asset_info_full,
delete_asset_info_by_id,
list_cache_states_by_asset_id,
list_asset_infos_page,
list_tags_with_usage,
get_asset_tags,
add_tags_to_asset_info,
remove_tags_from_asset_info,
pick_best_live_path,
ingest_fs_asset,
set_asset_info_preview,
)
from app.assets.helpers import resolve_destination_from_tags, ensure_within_base
from app.assets.database.models import Asset
def _safe_sort_field(requested: str | None) -> str:
@@ -19,11 +39,28 @@ def _safe_sort_field(requested: str | None) -> str:
return "created_at"
def asset_exists(asset_hash: str) -> bool:
def _get_size_mtime_ns(path: str) -> tuple[int, int]:
st = os.stat(path, follow_symlinks=True)
return st.st_size, getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000))
def _safe_filename(name: str | None, fallback: str) -> str:
n = os.path.basename((name or "").strip() or fallback)
if n:
return n
return fallback
def asset_exists(*, asset_hash: str) -> bool:
"""
Check if an asset with a given hash exists in database.
"""
with create_session() as session:
return asset_exists_by_hash(session, asset_hash=asset_hash)
def list_assets(
*,
include_tags: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
name_contains: str | None = None,
@@ -63,7 +100,6 @@ def list_assets(
size=int(asset.size_bytes) if asset else None,
mime_type=asset.mime_type if asset else None,
tags=tags,
preview_url=f"/api/assets/{info.id}/content",
created_at=info.created_at,
updated_at=info.updated_at,
last_access_time=info.last_access_time,
@@ -76,7 +112,12 @@ def list_assets(
has_more=(offset + len(summaries)) < total,
)
def get_asset(asset_info_id: str, owner_id: str = "") -> schemas_out.AssetDetail:
def get_asset(
*,
asset_info_id: str,
owner_id: str = "",
) -> schemas_out.AssetDetail:
with create_session() as session:
res = fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id)
if not res:
@@ -97,6 +138,358 @@ def get_asset(asset_info_id: str, owner_id: str = "") -> schemas_out.AssetDetail
last_access_time=info.last_access_time,
)
def resolve_asset_content_for_download(
*,
asset_info_id: str,
owner_id: str = "",
) -> tuple[str, str, str]:
with create_session() as session:
pair = fetch_asset_info_and_asset(session, asset_info_id=asset_info_id, owner_id=owner_id)
if not pair:
raise ValueError(f"AssetInfo {asset_info_id} not found")
info, asset = pair
states = list_cache_states_by_asset_id(session, asset_id=asset.id)
abs_path = pick_best_live_path(states)
if not abs_path:
raise FileNotFoundError
touch_asset_info_by_id(session, asset_info_id=asset_info_id)
session.commit()
ctype = asset.mime_type or mimetypes.guess_type(info.name or abs_path)[0] or "application/octet-stream"
download_name = info.name or os.path.basename(abs_path)
return abs_path, ctype, download_name
def upload_asset_from_temp_path(
spec: schemas_in.UploadAssetSpec,
*,
temp_path: str,
client_filename: str | None = None,
owner_id: str = "",
expected_asset_hash: str | None = None,
) -> schemas_out.AssetCreated:
"""
Create new asset or update existing asset from a temporary file path.
"""
try:
# NOTE: blake3 is not required right now, so this will fail if blake3 is not installed in local environment
import app.assets.hashing as hashing
digest = hashing.blake3_hash(temp_path)
except Exception as e:
raise RuntimeError(f"failed to hash uploaded file: {e}")
asset_hash = "blake3:" + digest
if expected_asset_hash and asset_hash != expected_asset_hash.strip().lower():
raise ValueError("HASH_MISMATCH")
with create_session() as session:
existing = get_asset_by_hash(session, asset_hash=asset_hash)
if existing is not None:
with contextlib.suppress(Exception):
if temp_path and os.path.exists(temp_path):
os.remove(temp_path)
display_name = _safe_filename(spec.name or (client_filename or ""), fallback=digest)
info = create_asset_info_for_existing_asset(
session,
asset_hash=asset_hash,
name=display_name,
user_metadata=spec.user_metadata or {},
tags=spec.tags or [],
tag_origin="manual",
owner_id=owner_id,
)
tag_names = get_asset_tags(session, asset_info_id=info.id)
session.commit()
return schemas_out.AssetCreated(
id=info.id,
name=info.name,
asset_hash=existing.hash,
size=int(existing.size_bytes) if existing.size_bytes is not None else None,
mime_type=existing.mime_type,
tags=tag_names,
user_metadata=info.user_metadata or {},
preview_id=info.preview_id,
created_at=info.created_at,
last_access_time=info.last_access_time,
created_new=False,
)
base_dir, subdirs = resolve_destination_from_tags(spec.tags)
dest_dir = os.path.join(base_dir, *subdirs) if subdirs else base_dir
os.makedirs(dest_dir, exist_ok=True)
src_for_ext = (client_filename or spec.name or "").strip()
_ext = os.path.splitext(os.path.basename(src_for_ext))[1] if src_for_ext else ""
ext = _ext if 0 < len(_ext) <= 16 else ""
hashed_basename = f"{digest}{ext}"
dest_abs = os.path.abspath(os.path.join(dest_dir, hashed_basename))
ensure_within_base(dest_abs, base_dir)
content_type = (
mimetypes.guess_type(os.path.basename(src_for_ext), strict=False)[0]
or mimetypes.guess_type(hashed_basename, strict=False)[0]
or "application/octet-stream"
)
try:
os.replace(temp_path, dest_abs)
except Exception as e:
raise RuntimeError(f"failed to move uploaded file into place: {e}")
try:
size_bytes, mtime_ns = _get_size_mtime_ns(dest_abs)
except OSError as e:
raise RuntimeError(f"failed to stat destination file: {e}")
with create_session() as session:
result = ingest_fs_asset(
session,
asset_hash=asset_hash,
abs_path=dest_abs,
size_bytes=size_bytes,
mtime_ns=mtime_ns,
mime_type=content_type,
info_name=_safe_filename(spec.name or (client_filename or ""), fallback=digest),
owner_id=owner_id,
preview_id=None,
user_metadata=spec.user_metadata or {},
tags=spec.tags,
tag_origin="manual",
require_existing_tags=False,
)
info_id = result["asset_info_id"]
if not info_id:
raise RuntimeError("failed to create asset metadata")
pair = fetch_asset_info_and_asset(session, asset_info_id=info_id, owner_id=owner_id)
if not pair:
raise RuntimeError("inconsistent DB state after ingest")
info, asset = pair
tag_names = get_asset_tags(session, asset_info_id=info.id)
created_result = schemas_out.AssetCreated(
id=info.id,
name=info.name,
asset_hash=asset.hash,
size=int(asset.size_bytes),
mime_type=asset.mime_type,
tags=tag_names,
user_metadata=info.user_metadata or {},
preview_id=info.preview_id,
created_at=info.created_at,
last_access_time=info.last_access_time,
created_new=result["asset_created"],
)
session.commit()
return created_result
def update_asset(
*,
asset_info_id: str,
name: str | None = None,
tags: list[str] | None = None,
user_metadata: dict | None = None,
owner_id: str = "",
) -> schemas_out.AssetUpdated:
with create_session() as session:
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
if not info_row:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if info_row.owner_id and info_row.owner_id != owner_id:
raise PermissionError("not owner")
info = update_asset_info_full(
session,
asset_info_id=asset_info_id,
name=name,
tags=tags,
user_metadata=user_metadata,
tag_origin="manual",
asset_info_row=info_row,
)
tag_names = get_asset_tags(session, asset_info_id=asset_info_id)
result = schemas_out.AssetUpdated(
id=info.id,
name=info.name,
asset_hash=info.asset.hash if info.asset else None,
tags=tag_names,
user_metadata=info.user_metadata or {},
updated_at=info.updated_at,
)
session.commit()
return result
def set_asset_preview(
*,
asset_info_id: str,
preview_asset_id: str | None = None,
owner_id: str = "",
) -> schemas_out.AssetDetail:
with create_session() as session:
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
if not info_row:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if info_row.owner_id and info_row.owner_id != owner_id:
raise PermissionError("not owner")
set_asset_info_preview(
session,
asset_info_id=asset_info_id,
preview_asset_id=preview_asset_id,
)
res = fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id)
if not res:
raise RuntimeError("State changed during preview update")
info, asset, tags = res
result = schemas_out.AssetDetail(
id=info.id,
name=info.name,
asset_hash=asset.hash if asset else None,
size=int(asset.size_bytes) if asset and asset.size_bytes is not None else None,
mime_type=asset.mime_type if asset else None,
tags=tags,
user_metadata=info.user_metadata or {},
preview_id=info.preview_id,
created_at=info.created_at,
last_access_time=info.last_access_time,
)
session.commit()
return result
def delete_asset_reference(*, asset_info_id: str, owner_id: str, delete_content_if_orphan: bool = True) -> bool:
with create_session() as session:
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
asset_id = info_row.asset_id if info_row else None
deleted = delete_asset_info_by_id(session, asset_info_id=asset_info_id, owner_id=owner_id)
if not deleted:
session.commit()
return False
if not delete_content_if_orphan or not asset_id:
session.commit()
return True
still_exists = asset_info_exists_for_asset_id(session, asset_id=asset_id)
if still_exists:
session.commit()
return True
states = list_cache_states_by_asset_id(session, asset_id=asset_id)
file_paths = [s.file_path for s in (states or []) if getattr(s, "file_path", None)]
asset_row = session.get(Asset, asset_id)
if asset_row is not None:
session.delete(asset_row)
session.commit()
for p in file_paths:
with contextlib.suppress(Exception):
if p and os.path.isfile(p):
os.remove(p)
return True
def create_asset_from_hash(
*,
hash_str: str,
name: str,
tags: list[str] | None = None,
user_metadata: dict | None = None,
owner_id: str = "",
) -> schemas_out.AssetCreated | None:
canonical = hash_str.strip().lower()
with create_session() as session:
asset = get_asset_by_hash(session, asset_hash=canonical)
if not asset:
return None
info = create_asset_info_for_existing_asset(
session,
asset_hash=canonical,
name=_safe_filename(name, fallback=canonical.split(":", 1)[1]),
user_metadata=user_metadata or {},
tags=tags or [],
tag_origin="manual",
owner_id=owner_id,
)
tag_names = get_asset_tags(session, asset_info_id=info.id)
result = schemas_out.AssetCreated(
id=info.id,
name=info.name,
asset_hash=asset.hash,
size=int(asset.size_bytes),
mime_type=asset.mime_type,
tags=tag_names,
user_metadata=info.user_metadata or {},
preview_id=info.preview_id,
created_at=info.created_at,
last_access_time=info.last_access_time,
created_new=False,
)
session.commit()
return result
def add_tags_to_asset(
*,
asset_info_id: str,
tags: list[str],
origin: str = "manual",
owner_id: str = "",
) -> schemas_out.TagsAdd:
with create_session() as session:
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
if not info_row:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if info_row.owner_id and info_row.owner_id != owner_id:
raise PermissionError("not owner")
data = add_tags_to_asset_info(
session,
asset_info_id=asset_info_id,
tags=tags,
origin=origin,
create_if_missing=True,
asset_info_row=info_row,
)
session.commit()
return schemas_out.TagsAdd(**data)
def remove_tags_from_asset(
*,
asset_info_id: str,
tags: list[str],
owner_id: str = "",
) -> schemas_out.TagsRemove:
with create_session() as session:
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
if not info_row:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if info_row.owner_id and info_row.owner_id != owner_id:
raise PermissionError("not owner")
data = remove_tags_from_asset_info(
session,
asset_info_id=asset_info_id,
tags=tags,
)
session.commit()
return schemas_out.TagsRemove(**data)
def list_tags(
prefix: str | None = None,
limit: int = 100,

View File

@@ -27,6 +27,7 @@ def seed_assets(roots: tuple[RootType, ...], enable_logging: bool = False) -> No
t_start = time.perf_counter()
created = 0
skipped_existing = 0
orphans_pruned = 0
paths: list[str] = []
try:
existing_paths: set[str] = set()
@@ -38,6 +39,11 @@ def seed_assets(roots: tuple[RootType, ...], enable_logging: bool = False) -> No
except Exception as e:
logging.exception("fast DB scan failed for %s: %s", r, e)
try:
orphans_pruned = _prune_orphaned_assets(roots)
except Exception as e:
logging.exception("orphan pruning failed: %s", e)
if "models" in roots:
paths.extend(collect_models_files())
if "input" in roots:
@@ -85,15 +91,43 @@ def seed_assets(roots: tuple[RootType, ...], enable_logging: bool = False) -> No
finally:
if enable_logging:
logging.info(
"Assets scan(roots=%s) completed in %.3fs (created=%d, skipped_existing=%d, total_seen=%d)",
"Assets scan(roots=%s) completed in %.3fs (created=%d, skipped_existing=%d, orphans_pruned=%d, total_seen=%d)",
roots,
time.perf_counter() - t_start,
created,
skipped_existing,
orphans_pruned,
len(paths),
)
def _prune_orphaned_assets(roots: tuple[RootType, ...]) -> int:
"""Prune cache states outside configured prefixes, then delete orphaned seed assets."""
all_prefixes = [os.path.abspath(p) for r in roots for p in prefixes_for_root(r)]
if not all_prefixes:
return 0
def make_prefix_condition(prefix: str):
base = prefix if prefix.endswith(os.sep) else prefix + os.sep
escaped, esc = escape_like_prefix(base)
return AssetCacheState.file_path.like(escaped + "%", escape=esc)
matches_valid_prefix = sqlalchemy.or_(*[make_prefix_condition(p) for p in all_prefixes])
orphan_subq = (
sqlalchemy.select(Asset.id)
.outerjoin(AssetCacheState, AssetCacheState.asset_id == Asset.id)
.where(Asset.hash.is_(None), AssetCacheState.id.is_(None))
).scalar_subquery()
with create_session() as sess:
sess.execute(sqlalchemy.delete(AssetCacheState).where(~matches_valid_prefix))
sess.execute(sqlalchemy.delete(AssetInfo).where(AssetInfo.asset_id.in_(orphan_subq)))
result = sess.execute(sqlalchemy.delete(Asset).where(Asset.id.in_(orphan_subq)))
sess.commit()
return result.rowcount
def _fast_db_consistency_pass(
root: RootType,
*,

View File

@@ -17,7 +17,7 @@ from importlib.metadata import version
import requests
from typing_extensions import NotRequired
from utils.install_util import get_missing_requirements_message, requirements_path
from utils.install_util import get_missing_requirements_message, get_required_packages_versions
from comfy.cli_args import DEFAULT_VERSION_STRING
import app.logger
@@ -45,25 +45,7 @@ def get_installed_frontend_version():
def get_required_frontend_version():
"""Get the required frontend version from requirements.txt."""
try:
with open(requirements_path, "r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if line.startswith("comfyui-frontend-package=="):
version_str = line.split("==")[-1]
if not is_valid_version(version_str):
logging.error(f"Invalid version format in requirements.txt: {version_str}")
return None
return version_str
logging.error("comfyui-frontend-package not found in requirements.txt")
return None
except FileNotFoundError:
logging.error("requirements.txt not found. Cannot determine required frontend version.")
return None
except Exception as e:
logging.error(f"Error reading requirements.txt: {e}")
return None
return get_required_packages_versions().get("comfyui-frontend-package", None)
def check_frontend_version():
@@ -217,25 +199,7 @@ class FrontendManager:
@classmethod
def get_required_templates_version(cls) -> str:
"""Get the required workflow templates version from requirements.txt."""
try:
with open(requirements_path, "r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if line.startswith("comfyui-workflow-templates=="):
version_str = line.split("==")[-1]
if not is_valid_version(version_str):
logging.error(f"Invalid templates version format in requirements.txt: {version_str}")
return None
return version_str
logging.error("comfyui-workflow-templates not found in requirements.txt")
return None
except FileNotFoundError:
logging.error("requirements.txt not found. Cannot determine required templates version.")
return None
except Exception as e:
logging.error(f"Error reading requirements.txt: {e}")
return None
return get_required_packages_versions().get("comfyui-workflow-templates", None)
@classmethod
def default_frontend_path(cls) -> str:

107
app/node_replace_manager.py Normal file
View File

@@ -0,0 +1,107 @@
from __future__ import annotations
from aiohttp import web
from typing import TYPE_CHECKING, TypedDict
if TYPE_CHECKING:
from comfy_api.latest._io_public import NodeReplace
from comfy_execution.graph_utils import is_link
import nodes
class NodeStruct(TypedDict):
inputs: dict[str, str | int | float | bool | tuple[str, int]]
class_type: str
_meta: dict[str, str]
def copy_node_struct(node_struct: NodeStruct, empty_inputs: bool = False) -> NodeStruct:
new_node_struct = node_struct.copy()
if empty_inputs:
new_node_struct["inputs"] = {}
else:
new_node_struct["inputs"] = node_struct["inputs"].copy()
new_node_struct["_meta"] = node_struct["_meta"].copy()
return new_node_struct
class NodeReplaceManager:
"""Manages node replacement registrations."""
def __init__(self):
self._replacements: dict[str, list[NodeReplace]] = {}
def register(self, node_replace: NodeReplace):
"""Register a node replacement mapping."""
self._replacements.setdefault(node_replace.old_node_id, []).append(node_replace)
def get_replacement(self, old_node_id: str) -> list[NodeReplace] | None:
"""Get replacements for an old node ID."""
return self._replacements.get(old_node_id)
def has_replacement(self, old_node_id: str) -> bool:
"""Check if a replacement exists for an old node ID."""
return old_node_id in self._replacements
def apply_replacements(self, prompt: dict[str, NodeStruct]):
connections: dict[str, list[tuple[str, str, int]]] = {}
need_replacement: set[str] = set()
for node_number, node_struct in prompt.items():
if "class_type" not in node_struct or "inputs" not in node_struct:
continue
class_type = node_struct["class_type"]
# need replacement if not in NODE_CLASS_MAPPINGS and has replacement
if class_type not in nodes.NODE_CLASS_MAPPINGS.keys() and self.has_replacement(class_type):
need_replacement.add(node_number)
# keep track of connections
for input_id, input_value in node_struct["inputs"].items():
if is_link(input_value):
conn_number = input_value[0]
connections.setdefault(conn_number, []).append((node_number, input_id, input_value[1]))
for node_number in need_replacement:
node_struct = prompt[node_number]
class_type = node_struct["class_type"]
replacements = self.get_replacement(class_type)
if replacements is None:
continue
# just use the first replacement
replacement = replacements[0]
new_node_id = replacement.new_node_id
# if replacement is not a valid node, skip trying to replace it as will only cause confusion
if new_node_id not in nodes.NODE_CLASS_MAPPINGS.keys():
continue
# first, replace node id (class_type)
new_node_struct = copy_node_struct(node_struct, empty_inputs=True)
new_node_struct["class_type"] = new_node_id
# TODO: consider replacing display_name in _meta as well for error reporting purposes; would need to query node schema
# second, replace inputs
if replacement.input_mapping is not None:
for input_map in replacement.input_mapping:
if "set_value" in input_map:
new_node_struct["inputs"][input_map["new_id"]] = input_map["set_value"]
elif "old_id" in input_map:
new_node_struct["inputs"][input_map["new_id"]] = node_struct["inputs"][input_map["old_id"]]
# finalize input replacement
prompt[node_number] = new_node_struct
# third, replace outputs
if replacement.output_mapping is not None:
# re-mapping outputs requires changing the input values of nodes that receive connections from this one
if node_number in connections:
for conns in connections[node_number]:
conn_node_number, conn_input_id, old_output_idx = conns
for output_map in replacement.output_mapping:
if output_map["old_idx"] == old_output_idx:
new_output_idx = output_map["new_idx"]
previous_input = prompt[conn_node_number]["inputs"][conn_input_id]
previous_input[1] = new_output_idx
def as_dict(self):
"""Serialize all replacements to dict."""
return {
k: [v.as_dict() for v in v_list]
for k, v_list in self._replacements.items()
}
def add_routes(self, routes):
@routes.get("/node_replacements")
async def get_node_replacements(request):
return web.json_response(self.as_dict())

View File

@@ -10,6 +10,7 @@ import hashlib
class Source:
custom_node = "custom_node"
templates = "templates"
class SubgraphEntry(TypedDict):
source: str
@@ -38,9 +39,21 @@ class CustomNodeSubgraphEntryInfo(TypedDict):
class SubgraphManager:
def __init__(self):
self.cached_custom_node_subgraphs: dict[SubgraphEntry] | None = None
self.cached_blueprint_subgraphs: dict[SubgraphEntry] | None = None
def _create_entry(self, file: str, source: str, node_pack: str) -> tuple[str, SubgraphEntry]:
"""Create a subgraph entry from a file path. Expects normalized path (forward slashes)."""
entry_id = hashlib.sha256(f"{source}{file}".encode()).hexdigest()
entry: SubgraphEntry = {
"source": source,
"name": os.path.splitext(os.path.basename(file))[0],
"path": file,
"info": {"node_pack": node_pack},
}
return entry_id, entry
async def load_entry_data(self, entry: SubgraphEntry):
with open(entry['path'], 'r') as f:
with open(entry['path'], 'r', encoding='utf-8') as f:
entry['data'] = f.read()
return entry
@@ -60,53 +73,60 @@ class SubgraphManager:
return entries
async def get_custom_node_subgraphs(self, loadedModules, force_reload=False):
# if not forced to reload and cached, return cache
"""Load subgraphs from custom nodes."""
if not force_reload and self.cached_custom_node_subgraphs is not None:
return self.cached_custom_node_subgraphs
# Load subgraphs from custom nodes
subfolder = "subgraphs"
subgraphs_dict: dict[SubgraphEntry] = {}
subgraphs_dict: dict[SubgraphEntry] = {}
for folder in folder_paths.get_folder_paths("custom_nodes"):
pattern = os.path.join(folder, f"*/{subfolder}/*.json")
matched_files = glob.glob(pattern)
for file in matched_files:
# replace backslashes with forward slashes
pattern = os.path.join(folder, "*/subgraphs/*.json")
for file in glob.glob(pattern):
file = file.replace('\\', '/')
info: CustomNodeSubgraphEntryInfo = {
"node_pack": "custom_nodes." + file.split('/')[-3]
}
source = Source.custom_node
# hash source + path to make sure id will be as unique as possible, but
# reproducible across backend reloads
id = hashlib.sha256(f"{source}{file}".encode()).hexdigest()
entry: SubgraphEntry = {
"source": Source.custom_node,
"name": os.path.splitext(os.path.basename(file))[0],
"path": file,
"info": info,
}
subgraphs_dict[id] = entry
node_pack = "custom_nodes." + file.split('/')[-3]
entry_id, entry = self._create_entry(file, Source.custom_node, node_pack)
subgraphs_dict[entry_id] = entry
self.cached_custom_node_subgraphs = subgraphs_dict
return subgraphs_dict
async def get_custom_node_subgraph(self, id: str, loadedModules):
subgraphs = await self.get_custom_node_subgraphs(loadedModules)
entry: SubgraphEntry = subgraphs.get(id, None)
if entry is not None and entry.get('data', None) is None:
async def get_blueprint_subgraphs(self, force_reload=False):
"""Load subgraphs from the blueprints directory."""
if not force_reload and self.cached_blueprint_subgraphs is not None:
return self.cached_blueprint_subgraphs
subgraphs_dict: dict[SubgraphEntry] = {}
blueprints_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'blueprints')
if os.path.exists(blueprints_dir):
for file in glob.glob(os.path.join(blueprints_dir, "*.json")):
file = file.replace('\\', '/')
entry_id, entry = self._create_entry(file, Source.templates, "comfyui")
subgraphs_dict[entry_id] = entry
self.cached_blueprint_subgraphs = subgraphs_dict
return subgraphs_dict
async def get_all_subgraphs(self, loadedModules, force_reload=False):
"""Get all subgraphs from all sources (custom nodes and blueprints)."""
custom_node_subgraphs = await self.get_custom_node_subgraphs(loadedModules, force_reload)
blueprint_subgraphs = await self.get_blueprint_subgraphs(force_reload)
return {**custom_node_subgraphs, **blueprint_subgraphs}
async def get_subgraph(self, id: str, loadedModules):
"""Get a specific subgraph by ID from any source."""
entry = (await self.get_all_subgraphs(loadedModules)).get(id)
if entry is not None and entry.get('data') is None:
await self.load_entry_data(entry)
return entry
def add_routes(self, routes, loadedModules):
@routes.get("/global_subgraphs")
async def get_global_subgraphs(request):
subgraphs_dict = await self.get_custom_node_subgraphs(loadedModules)
# NOTE: we may want to include other sources of global subgraphs such as templates in the future;
# that's the reasoning for the current implementation
subgraphs_dict = await self.get_all_subgraphs(loadedModules)
return web.json_response(await self.sanitize_entries(subgraphs_dict, remove_data=True))
@routes.get("/global_subgraphs/{id}")
async def get_global_subgraph(request):
id = request.match_info.get("id", None)
subgraph = await self.get_custom_node_subgraph(id, loadedModules)
subgraph = await self.get_subgraph(id, loadedModules)
return web.json_response(await self.sanitize_entry(subgraph))

View File

@@ -0,0 +1,44 @@
#version 300 es
precision highp float;
uniform sampler2D u_image0;
uniform float u_float0; // Brightness slider -100..100
uniform float u_float1; // Contrast slider -100..100
in vec2 v_texCoord;
out vec4 fragColor;
const float MID_GRAY = 0.18; // 18% reflectance
// sRGB gamma 2.2 approximation
vec3 srgbToLinear(vec3 c) {
return pow(max(c, 0.0), vec3(2.2));
}
vec3 linearToSrgb(vec3 c) {
return pow(max(c, 0.0), vec3(1.0/2.2));
}
float mapBrightness(float b) {
return clamp(b / 100.0, -1.0, 1.0);
}
float mapContrast(float c) {
return clamp(c / 100.0 + 1.0, 0.0, 2.0);
}
void main() {
vec4 orig = texture(u_image0, v_texCoord);
float brightness = mapBrightness(u_float0);
float contrast = mapContrast(u_float1);
vec3 lin = srgbToLinear(orig.rgb);
lin = (lin - MID_GRAY) * contrast + brightness + MID_GRAY;
// Convert back to sRGB
vec3 result = linearToSrgb(clamp(lin, 0.0, 1.0));
fragColor = vec4(result, orig.a);
}

View File

@@ -0,0 +1,72 @@
#version 300 es
precision highp float;
uniform sampler2D u_image0;
uniform vec2 u_resolution;
uniform int u_int0; // Mode
uniform float u_float0; // Amount (0 to 100)
in vec2 v_texCoord;
out vec4 fragColor;
const int MODE_LINEAR = 0;
const int MODE_RADIAL = 1;
const int MODE_BARREL = 2;
const int MODE_SWIRL = 3;
const int MODE_DIAGONAL = 4;
const float AMOUNT_SCALE = 0.0005;
const float RADIAL_MULT = 4.0;
const float BARREL_MULT = 8.0;
const float INV_SQRT2 = 0.70710678118;
void main() {
vec2 uv = v_texCoord;
vec4 original = texture(u_image0, uv);
float amount = u_float0 * AMOUNT_SCALE;
if (amount < 0.000001) {
fragColor = original;
return;
}
// Aspect-corrected coordinates for circular effects
float aspect = u_resolution.x / u_resolution.y;
vec2 centered = uv - 0.5;
vec2 corrected = vec2(centered.x * aspect, centered.y);
float r = length(corrected);
vec2 dir = r > 0.0001 ? corrected / r : vec2(0.0);
vec2 offset = vec2(0.0);
if (u_int0 == MODE_LINEAR) {
// Horizontal shift (no aspect correction needed)
offset = vec2(amount, 0.0);
}
else if (u_int0 == MODE_RADIAL) {
// Outward from center, stronger at edges
offset = dir * r * amount * RADIAL_MULT;
offset.x /= aspect; // Convert back to UV space
}
else if (u_int0 == MODE_BARREL) {
// Lens distortion simulation (r² falloff)
offset = dir * r * r * amount * BARREL_MULT;
offset.x /= aspect; // Convert back to UV space
}
else if (u_int0 == MODE_SWIRL) {
// Perpendicular to radial (rotational aberration)
vec2 perp = vec2(-dir.y, dir.x);
offset = perp * r * amount * RADIAL_MULT;
offset.x /= aspect; // Convert back to UV space
}
else if (u_int0 == MODE_DIAGONAL) {
// 45° offset (no aspect correction needed)
offset = vec2(amount, amount) * INV_SQRT2;
}
float red = texture(u_image0, uv + offset).r;
float green = original.g;
float blue = texture(u_image0, uv - offset).b;
fragColor = vec4(red, green, blue, original.a);
}

View File

@@ -0,0 +1,78 @@
#version 300 es
precision highp float;
uniform sampler2D u_image0;
uniform float u_float0; // temperature (-100 to 100)
uniform float u_float1; // tint (-100 to 100)
uniform float u_float2; // vibrance (-100 to 100)
uniform float u_float3; // saturation (-100 to 100)
in vec2 v_texCoord;
out vec4 fragColor;
const float INPUT_SCALE = 0.01;
const float TEMP_TINT_PRIMARY = 0.3;
const float TEMP_TINT_SECONDARY = 0.15;
const float VIBRANCE_BOOST = 2.0;
const float SATURATION_BOOST = 2.0;
const float SKIN_PROTECTION = 0.5;
const float EPSILON = 0.001;
const vec3 LUMA_WEIGHTS = vec3(0.299, 0.587, 0.114);
void main() {
vec4 tex = texture(u_image0, v_texCoord);
vec3 color = tex.rgb;
// Scale inputs: -100/100 → -1/1
float temperature = u_float0 * INPUT_SCALE;
float tint = u_float1 * INPUT_SCALE;
float vibrance = u_float2 * INPUT_SCALE;
float saturation = u_float3 * INPUT_SCALE;
// Temperature (warm/cool): positive = warm, negative = cool
color.r += temperature * TEMP_TINT_PRIMARY;
color.b -= temperature * TEMP_TINT_PRIMARY;
// Tint (green/magenta): positive = green, negative = magenta
color.g += tint * TEMP_TINT_PRIMARY;
color.r -= tint * TEMP_TINT_SECONDARY;
color.b -= tint * TEMP_TINT_SECONDARY;
// Single clamp after temperature/tint
color = clamp(color, 0.0, 1.0);
// Vibrance with skin protection
if (vibrance != 0.0) {
float maxC = max(color.r, max(color.g, color.b));
float minC = min(color.r, min(color.g, color.b));
float sat = maxC - minC;
float gray = dot(color, LUMA_WEIGHTS);
if (vibrance < 0.0) {
// Desaturate: -100 → gray
color = mix(vec3(gray), color, 1.0 + vibrance);
} else {
// Boost less saturated colors more
float vibranceAmt = vibrance * (1.0 - sat);
// Branchless skin tone protection
float isWarmTone = step(color.b, color.g) * step(color.g, color.r);
float warmth = (color.r - color.b) / max(maxC, EPSILON);
float skinTone = isWarmTone * warmth * sat * (1.0 - sat);
vibranceAmt *= (1.0 - skinTone * SKIN_PROTECTION);
color = mix(vec3(gray), color, 1.0 + vibranceAmt * VIBRANCE_BOOST);
}
}
// Saturation
if (saturation != 0.0) {
float gray = dot(color, LUMA_WEIGHTS);
float satMix = saturation < 0.0
? 1.0 + saturation // -100 → gray
: 1.0 + saturation * SATURATION_BOOST; // +100 → 3x boost
color = mix(vec3(gray), color, satMix);
}
fragColor = vec4(clamp(color, 0.0, 1.0), tex.a);
}

View File

@@ -0,0 +1,94 @@
#version 300 es
precision highp float;
uniform sampler2D u_image0;
uniform float u_float0; // Blur radius (020, default ~5)
uniform float u_float1; // Edge threshold (0100, default ~30)
uniform int u_int0; // Step size (0/1 = every pixel, 2+ = skip pixels)
in vec2 v_texCoord;
out vec4 fragColor;
const int MAX_RADIUS = 20;
const float EPSILON = 0.0001;
// Perceptual luminance
float getLuminance(vec3 rgb) {
return dot(rgb, vec3(0.299, 0.587, 0.114));
}
vec4 bilateralFilter(vec2 uv, vec2 texelSize, int radius,
float sigmaSpatial, float sigmaColor)
{
vec4 center = texture(u_image0, uv);
vec3 centerRGB = center.rgb;
float invSpatial2 = -0.5 / (sigmaSpatial * sigmaSpatial);
float invColor2 = -0.5 / (sigmaColor * sigmaColor + EPSILON);
vec3 sumRGB = vec3(0.0);
float sumWeight = 0.0;
int step = max(u_int0, 1);
float radius2 = float(radius * radius);
for (int dy = -MAX_RADIUS; dy <= MAX_RADIUS; dy++) {
if (dy < -radius || dy > radius) continue;
if (abs(dy) % step != 0) continue;
for (int dx = -MAX_RADIUS; dx <= MAX_RADIUS; dx++) {
if (dx < -radius || dx > radius) continue;
if (abs(dx) % step != 0) continue;
vec2 offset = vec2(float(dx), float(dy));
float dist2 = dot(offset, offset);
if (dist2 > radius2) continue;
vec3 sampleRGB = texture(u_image0, uv + offset * texelSize).rgb;
// Spatial Gaussian
float spatialWeight = exp(dist2 * invSpatial2);
// Perceptual color distance (weighted RGB)
vec3 diff = sampleRGB - centerRGB;
float colorDist = dot(diff * diff, vec3(0.299, 0.587, 0.114));
float colorWeight = exp(colorDist * invColor2);
float w = spatialWeight * colorWeight;
sumRGB += sampleRGB * w;
sumWeight += w;
}
}
vec3 resultRGB = sumRGB / max(sumWeight, EPSILON);
return vec4(resultRGB, center.a); // preserve center alpha
}
void main() {
vec2 texelSize = 1.0 / vec2(textureSize(u_image0, 0));
float radiusF = clamp(u_float0, 0.0, float(MAX_RADIUS));
int radius = int(radiusF + 0.5);
if (radius == 0) {
fragColor = texture(u_image0, v_texCoord);
return;
}
// Edge threshold → color sigma
// Squared curve for better low-end control
float t = clamp(u_float1, 0.0, 100.0) / 100.0;
t *= t;
float sigmaColor = mix(0.01, 0.5, t);
// Spatial sigma tied to radius
float sigmaSpatial = max(radiusF * 0.75, 0.5);
fragColor = bilateralFilter(
v_texCoord,
texelSize,
radius,
sigmaSpatial,
sigmaColor
);
}

View File

@@ -0,0 +1,124 @@
#version 300 es
precision highp float;
uniform sampler2D u_image0;
uniform vec2 u_resolution;
uniform float u_float0; // grain amount [0.0 1.0] typical: 0.20.8
uniform float u_float1; // grain size [0.3 3.0] lower = finer grain
uniform float u_float2; // color amount [0.0 1.0] 0 = monochrome, 1 = RGB grain
uniform float u_float3; // luminance bias [0.0 1.0] 0 = uniform, 1 = shadows only
uniform int u_int0; // noise mode [0 or 1] 0 = smooth, 1 = grainy
in vec2 v_texCoord;
layout(location = 0) out vec4 fragColor0;
// High-quality integer hash (pcg-like)
uint pcg(uint v) {
uint state = v * 747796405u + 2891336453u;
uint word = ((state >> ((state >> 28u) + 4u)) ^ state) * 277803737u;
return (word >> 22u) ^ word;
}
// 2D -> 1D hash input
uint hash2d(uvec2 p) {
return pcg(p.x + pcg(p.y));
}
// Hash to float [0, 1]
float hashf(uvec2 p) {
return float(hash2d(p)) / float(0xffffffffu);
}
// Hash to float with offset (for RGB channels)
float hashf(uvec2 p, uint offset) {
return float(pcg(hash2d(p) + offset)) / float(0xffffffffu);
}
// Convert uniform [0,1] to roughly Gaussian distribution
// Using simple approximation: average of multiple samples
float toGaussian(uvec2 p) {
float sum = hashf(p, 0u) + hashf(p, 1u) + hashf(p, 2u) + hashf(p, 3u);
return (sum - 2.0) * 0.7; // Centered, scaled
}
float toGaussian(uvec2 p, uint offset) {
float sum = hashf(p, offset) + hashf(p, offset + 1u)
+ hashf(p, offset + 2u) + hashf(p, offset + 3u);
return (sum - 2.0) * 0.7;
}
// Smooth noise with better interpolation
float smoothNoise(vec2 p) {
vec2 i = floor(p);
vec2 f = fract(p);
// Quintic interpolation (less banding than cubic)
f = f * f * f * (f * (f * 6.0 - 15.0) + 10.0);
uvec2 ui = uvec2(i);
float a = toGaussian(ui);
float b = toGaussian(ui + uvec2(1u, 0u));
float c = toGaussian(ui + uvec2(0u, 1u));
float d = toGaussian(ui + uvec2(1u, 1u));
return mix(mix(a, b, f.x), mix(c, d, f.x), f.y);
}
float smoothNoise(vec2 p, uint offset) {
vec2 i = floor(p);
vec2 f = fract(p);
f = f * f * f * (f * (f * 6.0 - 15.0) + 10.0);
uvec2 ui = uvec2(i);
float a = toGaussian(ui, offset);
float b = toGaussian(ui + uvec2(1u, 0u), offset);
float c = toGaussian(ui + uvec2(0u, 1u), offset);
float d = toGaussian(ui + uvec2(1u, 1u), offset);
return mix(mix(a, b, f.x), mix(c, d, f.x), f.y);
}
void main() {
vec4 color = texture(u_image0, v_texCoord);
// Luminance (Rec.709)
float luma = dot(color.rgb, vec3(0.2126, 0.7152, 0.0722));
// Grain UV (resolution-independent)
vec2 grainUV = v_texCoord * u_resolution / max(u_float1, 0.01);
uvec2 grainPixel = uvec2(grainUV);
float g;
vec3 grainRGB;
if (u_int0 == 1) {
// Grainy mode: pure hash noise (no interpolation = no banding)
g = toGaussian(grainPixel);
grainRGB = vec3(
toGaussian(grainPixel, 100u),
toGaussian(grainPixel, 200u),
toGaussian(grainPixel, 300u)
);
} else {
// Smooth mode: interpolated with quintic curve
g = smoothNoise(grainUV);
grainRGB = vec3(
smoothNoise(grainUV, 100u),
smoothNoise(grainUV, 200u),
smoothNoise(grainUV, 300u)
);
}
// Luminance weighting (less grain in highlights)
float lumWeight = mix(1.0, 1.0 - luma, clamp(u_float3, 0.0, 1.0));
// Strength
float strength = u_float0 * 0.15;
// Color vs monochrome grain
vec3 grainColor = mix(vec3(g), grainRGB, clamp(u_float2, 0.0, 1.0));
color.rgb += grainColor * strength * lumWeight;
fragColor0 = vec4(clamp(color.rgb, 0.0, 1.0), color.a);
}

View File

@@ -0,0 +1,133 @@
#version 300 es
precision mediump float;
uniform sampler2D u_image0;
uniform vec2 u_resolution;
uniform int u_int0; // Blend mode
uniform int u_int1; // Color tint
uniform float u_float0; // Intensity
uniform float u_float1; // Radius
uniform float u_float2; // Threshold
in vec2 v_texCoord;
out vec4 fragColor;
const int BLEND_ADD = 0;
const int BLEND_SCREEN = 1;
const int BLEND_SOFT = 2;
const int BLEND_OVERLAY = 3;
const int BLEND_LIGHTEN = 4;
const float GOLDEN_ANGLE = 2.39996323;
const int MAX_SAMPLES = 48;
const vec3 LUMA = vec3(0.299, 0.587, 0.114);
float hash(vec2 p) {
p = fract(p * vec2(123.34, 456.21));
p += dot(p, p + 45.32);
return fract(p.x * p.y);
}
vec3 hexToRgb(int h) {
return vec3(
float((h >> 16) & 255),
float((h >> 8) & 255),
float(h & 255)
) * (1.0 / 255.0);
}
vec3 blend(vec3 base, vec3 glow, int mode) {
if (mode == BLEND_SCREEN) {
return 1.0 - (1.0 - base) * (1.0 - glow);
}
if (mode == BLEND_SOFT) {
return mix(
base - (1.0 - 2.0 * glow) * base * (1.0 - base),
base + (2.0 * glow - 1.0) * (sqrt(base) - base),
step(0.5, glow)
);
}
if (mode == BLEND_OVERLAY) {
return mix(
2.0 * base * glow,
1.0 - 2.0 * (1.0 - base) * (1.0 - glow),
step(0.5, base)
);
}
if (mode == BLEND_LIGHTEN) {
return max(base, glow);
}
return base + glow;
}
void main() {
vec4 original = texture(u_image0, v_texCoord);
float intensity = u_float0 * 0.05;
float radius = u_float1 * u_float1 * 0.012;
if (intensity < 0.001 || radius < 0.1) {
fragColor = original;
return;
}
float threshold = 1.0 - u_float2 * 0.01;
float t0 = threshold - 0.15;
float t1 = threshold + 0.15;
vec2 texelSize = 1.0 / u_resolution;
float radius2 = radius * radius;
float sampleScale = clamp(radius * 0.75, 0.35, 1.0);
int samples = int(float(MAX_SAMPLES) * sampleScale);
float noise = hash(gl_FragCoord.xy);
float angleOffset = noise * GOLDEN_ANGLE;
float radiusJitter = 0.85 + noise * 0.3;
float ca = cos(GOLDEN_ANGLE);
float sa = sin(GOLDEN_ANGLE);
vec2 dir = vec2(cos(angleOffset), sin(angleOffset));
vec3 glow = vec3(0.0);
float totalWeight = 0.0;
// Center tap
float centerMask = smoothstep(t0, t1, dot(original.rgb, LUMA));
glow += original.rgb * centerMask * 2.0;
totalWeight += 2.0;
for (int i = 1; i < MAX_SAMPLES; i++) {
if (i >= samples) break;
float fi = float(i);
float dist = sqrt(fi / float(samples)) * radius * radiusJitter;
vec2 offset = dir * dist * texelSize;
vec3 c = texture(u_image0, v_texCoord + offset).rgb;
float mask = smoothstep(t0, t1, dot(c, LUMA));
float w = 1.0 - (dist * dist) / (radius2 * 1.5);
w = max(w, 0.0);
w *= w;
glow += c * mask * w;
totalWeight += w;
dir = vec2(
dir.x * ca - dir.y * sa,
dir.x * sa + dir.y * ca
);
}
glow *= intensity / max(totalWeight, 0.001);
if (u_int1 > 0) {
glow *= hexToRgb(u_int1);
}
vec3 result = blend(original.rgb, glow, u_int0);
result += (noise - 0.5) * (1.0 / 255.0);
fragColor = vec4(clamp(result, 0.0, 1.0), original.a);
}

View File

@@ -0,0 +1,222 @@
#version 300 es
precision highp float;
uniform sampler2D u_image0;
uniform int u_int0; // Mode: 0=Master, 1=Reds, 2=Yellows, 3=Greens, 4=Cyans, 5=Blues, 6=Magentas, 7=Colorize
uniform int u_int1; // Color Space: 0=HSL, 1=HSB/HSV
uniform float u_float0; // Hue (-180 to 180)
uniform float u_float1; // Saturation (-100 to 100)
uniform float u_float2; // Lightness/Brightness (-100 to 100)
uniform float u_float3; // Overlap (0 to 100) - feathering between adjacent color ranges
in vec2 v_texCoord;
out vec4 fragColor;
// Color range modes
const int MODE_MASTER = 0;
const int MODE_RED = 1;
const int MODE_YELLOW = 2;
const int MODE_GREEN = 3;
const int MODE_CYAN = 4;
const int MODE_BLUE = 5;
const int MODE_MAGENTA = 6;
const int MODE_COLORIZE = 7;
// Color space modes
const int COLORSPACE_HSL = 0;
const int COLORSPACE_HSB = 1;
const float EPSILON = 0.0001;
//=============================================================================
// RGB <-> HSL Conversions
//=============================================================================
vec3 rgb2hsl(vec3 c) {
float maxC = max(max(c.r, c.g), c.b);
float minC = min(min(c.r, c.g), c.b);
float delta = maxC - minC;
float h = 0.0;
float s = 0.0;
float l = (maxC + minC) * 0.5;
if (delta > EPSILON) {
s = l < 0.5
? delta / (maxC + minC)
: delta / (2.0 - maxC - minC);
if (maxC == c.r) {
h = (c.g - c.b) / delta + (c.g < c.b ? 6.0 : 0.0);
} else if (maxC == c.g) {
h = (c.b - c.r) / delta + 2.0;
} else {
h = (c.r - c.g) / delta + 4.0;
}
h /= 6.0;
}
return vec3(h, s, l);
}
float hue2rgb(float p, float q, float t) {
t = fract(t);
if (t < 1.0/6.0) return p + (q - p) * 6.0 * t;
if (t < 0.5) return q;
if (t < 2.0/3.0) return p + (q - p) * (2.0/3.0 - t) * 6.0;
return p;
}
vec3 hsl2rgb(vec3 hsl) {
if (hsl.y < EPSILON) return vec3(hsl.z);
float q = hsl.z < 0.5
? hsl.z * (1.0 + hsl.y)
: hsl.z + hsl.y - hsl.z * hsl.y;
float p = 2.0 * hsl.z - q;
return vec3(
hue2rgb(p, q, hsl.x + 1.0/3.0),
hue2rgb(p, q, hsl.x),
hue2rgb(p, q, hsl.x - 1.0/3.0)
);
}
vec3 rgb2hsb(vec3 c) {
float maxC = max(max(c.r, c.g), c.b);
float minC = min(min(c.r, c.g), c.b);
float delta = maxC - minC;
float h = 0.0;
float s = (maxC > EPSILON) ? delta / maxC : 0.0;
float b = maxC;
if (delta > EPSILON) {
if (maxC == c.r) {
h = (c.g - c.b) / delta + (c.g < c.b ? 6.0 : 0.0);
} else if (maxC == c.g) {
h = (c.b - c.r) / delta + 2.0;
} else {
h = (c.r - c.g) / delta + 4.0;
}
h /= 6.0;
}
return vec3(h, s, b);
}
vec3 hsb2rgb(vec3 hsb) {
vec3 rgb = clamp(abs(mod(hsb.x * 6.0 + vec3(0.0, 4.0, 2.0), 6.0) - 3.0) - 1.0, 0.0, 1.0);
return hsb.z * mix(vec3(1.0), rgb, hsb.y);
}
//=============================================================================
// Color Range Weight Calculation
//=============================================================================
float hueDistance(float a, float b) {
float d = abs(a - b);
return min(d, 1.0 - d);
}
float getHueWeight(float hue, float center, float overlap) {
float baseWidth = 1.0 / 6.0;
float feather = baseWidth * overlap;
float d = hueDistance(hue, center);
float inner = baseWidth * 0.5;
float outer = inner + feather;
return 1.0 - smoothstep(inner, outer, d);
}
float getModeWeight(float hue, int mode, float overlap) {
if (mode == MODE_MASTER || mode == MODE_COLORIZE) return 1.0;
if (mode == MODE_RED) {
return max(
getHueWeight(hue, 0.0, overlap),
getHueWeight(hue, 1.0, overlap)
);
}
float center = float(mode - 1) / 6.0;
return getHueWeight(hue, center, overlap);
}
//=============================================================================
// Adjustment Functions
//=============================================================================
float adjustLightness(float l, float amount) {
return amount > 0.0
? l + (1.0 - l) * amount
: l + l * amount;
}
float adjustBrightness(float b, float amount) {
return clamp(b + amount, 0.0, 1.0);
}
float adjustSaturation(float s, float amount) {
return amount > 0.0
? s + (1.0 - s) * amount
: s + s * amount;
}
vec3 colorize(vec3 rgb, float hue, float sat, float light) {
float lum = dot(rgb, vec3(0.299, 0.587, 0.114));
float l = adjustLightness(lum, light);
vec3 hsl = vec3(fract(hue), clamp(sat, 0.0, 1.0), clamp(l, 0.0, 1.0));
return hsl2rgb(hsl);
}
//=============================================================================
// Main
//=============================================================================
void main() {
vec4 original = texture(u_image0, v_texCoord);
float hueShift = u_float0 / 360.0; // -180..180 -> -0.5..0.5
float satAmount = u_float1 / 100.0; // -100..100 -> -1..1
float lightAmount= u_float2 / 100.0; // -100..100 -> -1..1
float overlap = u_float3 / 100.0; // 0..100 -> 0..1
vec3 result;
if (u_int0 == MODE_COLORIZE) {
result = colorize(original.rgb, hueShift, satAmount, lightAmount);
fragColor = vec4(result, original.a);
return;
}
vec3 hsx = (u_int1 == COLORSPACE_HSL)
? rgb2hsl(original.rgb)
: rgb2hsb(original.rgb);
float weight = getModeWeight(hsx.x, u_int0, overlap);
if (u_int0 != MODE_MASTER && hsx.y < EPSILON) {
weight = 0.0;
}
if (weight > EPSILON) {
float h = fract(hsx.x + hueShift * weight);
float s = clamp(adjustSaturation(hsx.y, satAmount * weight), 0.0, 1.0);
float v = (u_int1 == COLORSPACE_HSL)
? clamp(adjustLightness(hsx.z, lightAmount * weight), 0.0, 1.0)
: clamp(adjustBrightness(hsx.z, lightAmount * weight), 0.0, 1.0);
vec3 adjusted = vec3(h, s, v);
result = (u_int1 == COLORSPACE_HSL)
? hsl2rgb(adjusted)
: hsb2rgb(adjusted);
} else {
result = original.rgb;
}
fragColor = vec4(result, original.a);
}

View File

@@ -0,0 +1,111 @@
#version 300 es
#pragma passes 2
precision highp float;
// Blur type constants
const int BLUR_GAUSSIAN = 0;
const int BLUR_BOX = 1;
const int BLUR_RADIAL = 2;
// Radial blur config
const int RADIAL_SAMPLES = 12;
const float RADIAL_STRENGTH = 0.0003;
uniform sampler2D u_image0;
uniform vec2 u_resolution;
uniform int u_int0; // Blur type (BLUR_GAUSSIAN, BLUR_BOX, BLUR_RADIAL)
uniform float u_float0; // Blur radius/amount
uniform int u_pass; // Pass index (0 = horizontal, 1 = vertical)
in vec2 v_texCoord;
layout(location = 0) out vec4 fragColor0;
float gaussian(float x, float sigma) {
return exp(-(x * x) / (2.0 * sigma * sigma));
}
void main() {
vec2 texelSize = 1.0 / u_resolution;
float radius = max(u_float0, 0.0);
// Radial (angular) blur - single pass, doesn't use separable
if (u_int0 == BLUR_RADIAL) {
// Only execute on first pass
if (u_pass > 0) {
fragColor0 = texture(u_image0, v_texCoord);
return;
}
vec2 center = vec2(0.5);
vec2 dir = v_texCoord - center;
float dist = length(dir);
if (dist < 1e-4) {
fragColor0 = texture(u_image0, v_texCoord);
return;
}
vec4 sum = vec4(0.0);
float totalWeight = 0.0;
float angleStep = radius * RADIAL_STRENGTH;
dir /= dist;
float cosStep = cos(angleStep);
float sinStep = sin(angleStep);
float negAngle = -float(RADIAL_SAMPLES) * angleStep;
vec2 rotDir = vec2(
dir.x * cos(negAngle) - dir.y * sin(negAngle),
dir.x * sin(negAngle) + dir.y * cos(negAngle)
);
for (int i = -RADIAL_SAMPLES; i <= RADIAL_SAMPLES; i++) {
vec2 uv = center + rotDir * dist;
float w = 1.0 - abs(float(i)) / float(RADIAL_SAMPLES);
sum += texture(u_image0, uv) * w;
totalWeight += w;
rotDir = vec2(
rotDir.x * cosStep - rotDir.y * sinStep,
rotDir.x * sinStep + rotDir.y * cosStep
);
}
fragColor0 = sum / max(totalWeight, 0.001);
return;
}
// Separable Gaussian / Box blur
int samples = int(ceil(radius));
if (samples == 0) {
fragColor0 = texture(u_image0, v_texCoord);
return;
}
// Direction: pass 0 = horizontal, pass 1 = vertical
vec2 dir = (u_pass == 0) ? vec2(1.0, 0.0) : vec2(0.0, 1.0);
vec4 color = vec4(0.0);
float totalWeight = 0.0;
float sigma = radius / 2.0;
for (int i = -samples; i <= samples; i++) {
vec2 offset = dir * float(i) * texelSize;
vec4 sample_color = texture(u_image0, v_texCoord + offset);
float weight;
if (u_int0 == BLUR_GAUSSIAN) {
weight = gaussian(float(i), sigma);
} else {
// BLUR_BOX
weight = 1.0;
}
color += sample_color * weight;
totalWeight += weight;
}
fragColor0 = color / totalWeight;
}

View File

@@ -0,0 +1,19 @@
#version 300 es
precision highp float;
uniform sampler2D u_image0;
in vec2 v_texCoord;
layout(location = 0) out vec4 fragColor0;
layout(location = 1) out vec4 fragColor1;
layout(location = 2) out vec4 fragColor2;
layout(location = 3) out vec4 fragColor3;
void main() {
vec4 color = texture(u_image0, v_texCoord);
// Output each channel as grayscale to separate render targets
fragColor0 = vec4(vec3(color.r), 1.0); // Red channel
fragColor1 = vec4(vec3(color.g), 1.0); // Green channel
fragColor2 = vec4(vec3(color.b), 1.0); // Blue channel
fragColor3 = vec4(vec3(color.a), 1.0); // Alpha channel
}

View File

@@ -0,0 +1,71 @@
#version 300 es
precision highp float;
// Levels Adjustment
// u_int0: channel (0=RGB, 1=R, 2=G, 3=B) default: 0
// u_float0: input black (0-255) default: 0
// u_float1: input white (0-255) default: 255
// u_float2: gamma (0.01-9.99) default: 1.0
// u_float3: output black (0-255) default: 0
// u_float4: output white (0-255) default: 255
uniform sampler2D u_image0;
uniform int u_int0;
uniform float u_float0;
uniform float u_float1;
uniform float u_float2;
uniform float u_float3;
uniform float u_float4;
in vec2 v_texCoord;
out vec4 fragColor;
vec3 applyLevels(vec3 color, float inBlack, float inWhite, float gamma, float outBlack, float outWhite) {
float inRange = max(inWhite - inBlack, 0.0001);
vec3 result = clamp((color - inBlack) / inRange, 0.0, 1.0);
result = pow(result, vec3(1.0 / gamma));
result = mix(vec3(outBlack), vec3(outWhite), result);
return result;
}
float applySingleChannel(float value, float inBlack, float inWhite, float gamma, float outBlack, float outWhite) {
float inRange = max(inWhite - inBlack, 0.0001);
float result = clamp((value - inBlack) / inRange, 0.0, 1.0);
result = pow(result, 1.0 / gamma);
result = mix(outBlack, outWhite, result);
return result;
}
void main() {
vec4 texColor = texture(u_image0, v_texCoord);
vec3 color = texColor.rgb;
float inBlack = u_float0 / 255.0;
float inWhite = u_float1 / 255.0;
float gamma = u_float2;
float outBlack = u_float3 / 255.0;
float outWhite = u_float4 / 255.0;
vec3 result;
if (u_int0 == 0) {
result = applyLevels(color, inBlack, inWhite, gamma, outBlack, outWhite);
}
else if (u_int0 == 1) {
result = color;
result.r = applySingleChannel(color.r, inBlack, inWhite, gamma, outBlack, outWhite);
}
else if (u_int0 == 2) {
result = color;
result.g = applySingleChannel(color.g, inBlack, inWhite, gamma, outBlack, outWhite);
}
else if (u_int0 == 3) {
result = color;
result.b = applySingleChannel(color.b, inBlack, inWhite, gamma, outBlack, outWhite);
}
else {
result = color;
}
fragColor = vec4(result, texColor.a);
}

View File

@@ -0,0 +1,28 @@
# GLSL Shader Sources
This folder contains the GLSL fragment shaders extracted from blueprint JSON files for easier editing and version control.
## File Naming Convention
`{Blueprint_Name}_{node_id}.frag`
- **Blueprint_Name**: The JSON filename with spaces/special chars replaced by underscores
- **node_id**: The GLSLShader node ID within the subgraph
## Usage
```bash
# Extract shaders from blueprint JSONs to this folder
python update_blueprints.py extract
# Patch edited shaders back into blueprint JSONs
python update_blueprints.py patch
```
## Workflow
1. Run `extract` to pull current shaders from JSONs
2. Edit `.frag` files
3. Run `patch` to update the blueprint JSONs
4. Test
5. Commit both `.frag` files and updated JSONs

View File

@@ -0,0 +1,28 @@
#version 300 es
precision highp float;
uniform sampler2D u_image0;
uniform vec2 u_resolution;
uniform float u_float0; // strength [0.0 2.0] typical: 0.31.0
in vec2 v_texCoord;
layout(location = 0) out vec4 fragColor0;
void main() {
vec2 texel = 1.0 / u_resolution;
// Sample center and neighbors
vec4 center = texture(u_image0, v_texCoord);
vec4 top = texture(u_image0, v_texCoord + vec2( 0.0, -texel.y));
vec4 bottom = texture(u_image0, v_texCoord + vec2( 0.0, texel.y));
vec4 left = texture(u_image0, v_texCoord + vec2(-texel.x, 0.0));
vec4 right = texture(u_image0, v_texCoord + vec2( texel.x, 0.0));
// Edge enhancement (Laplacian)
vec4 edges = center * 4.0 - top - bottom - left - right;
// Add edges back scaled by strength
vec4 sharpened = center + edges * u_float0;
fragColor0 = vec4(clamp(sharpened.rgb, 0.0, 1.0), center.a);
}

View File

@@ -0,0 +1,61 @@
#version 300 es
precision highp float;
uniform sampler2D u_image0;
uniform vec2 u_resolution;
uniform float u_float0; // amount [0.0 - 3.0] typical: 0.5-1.5
uniform float u_float1; // radius [0.5 - 10.0] blur radius in pixels
uniform float u_float2; // threshold [0.0 - 0.1] min difference to sharpen
in vec2 v_texCoord;
layout(location = 0) out vec4 fragColor0;
float gaussian(float x, float sigma) {
return exp(-(x * x) / (2.0 * sigma * sigma));
}
float getLuminance(vec3 color) {
return dot(color, vec3(0.2126, 0.7152, 0.0722));
}
void main() {
vec2 texel = 1.0 / u_resolution;
float radius = max(u_float1, 0.5);
float amount = u_float0;
float threshold = u_float2;
vec4 original = texture(u_image0, v_texCoord);
// Gaussian blur for the "unsharp" mask
int samples = int(ceil(radius));
float sigma = radius / 2.0;
vec4 blurred = vec4(0.0);
float totalWeight = 0.0;
for (int x = -samples; x <= samples; x++) {
for (int y = -samples; y <= samples; y++) {
vec2 offset = vec2(float(x), float(y)) * texel;
vec4 sample_color = texture(u_image0, v_texCoord + offset);
float dist = length(vec2(float(x), float(y)));
float weight = gaussian(dist, sigma);
blurred += sample_color * weight;
totalWeight += weight;
}
}
blurred /= totalWeight;
// Unsharp mask = original - blurred
vec3 mask = original.rgb - blurred.rgb;
// Luminance-based threshold with smooth falloff
float lumaDelta = abs(getLuminance(original.rgb) - getLuminance(blurred.rgb));
float thresholdScale = smoothstep(0.0, threshold, lumaDelta);
mask *= thresholdScale;
// Sharpen: original + mask * amount
vec3 sharpened = original.rgb + mask * amount;
fragColor0 = vec4(clamp(sharpened, 0.0, 1.0), original.a);
}

View File

@@ -0,0 +1,159 @@
#!/usr/bin/env python3
"""
Shader Blueprint Updater
Syncs GLSL shader files between this folder and blueprint JSON files.
File naming convention:
{Blueprint Name}_{node_id}.frag
Usage:
python update_blueprints.py extract # Extract shaders from JSONs to here
python update_blueprints.py patch # Patch shaders back into JSONs
python update_blueprints.py # Same as patch (default)
"""
import json
import logging
import sys
import re
from pathlib import Path
logging.basicConfig(level=logging.INFO, format='%(message)s')
logger = logging.getLogger(__name__)
GLSL_DIR = Path(__file__).parent
BLUEPRINTS_DIR = GLSL_DIR.parent
def get_blueprint_files():
"""Get all blueprint JSON files."""
return sorted(BLUEPRINTS_DIR.glob("*.json"))
def sanitize_filename(name):
"""Convert blueprint name to safe filename."""
return re.sub(r'[^\w\-]', '_', name)
def extract_shaders():
"""Extract all shaders from blueprint JSONs to this folder."""
extracted = 0
for json_path in get_blueprint_files():
blueprint_name = json_path.stem
try:
with open(json_path, 'r') as f:
data = json.load(f)
except (json.JSONDecodeError, IOError) as e:
logger.warning("Skipping %s: %s", json_path.name, e)
continue
# Find GLSLShader nodes in subgraphs
for subgraph in data.get('definitions', {}).get('subgraphs', []):
for node in subgraph.get('nodes', []):
if node.get('type') == 'GLSLShader':
node_id = node.get('id')
widgets = node.get('widgets_values', [])
# Find shader code (first string that looks like GLSL)
for widget in widgets:
if isinstance(widget, str) and widget.startswith('#version'):
safe_name = sanitize_filename(blueprint_name)
frag_name = f"{safe_name}_{node_id}.frag"
frag_path = GLSL_DIR / frag_name
with open(frag_path, 'w') as f:
f.write(widget)
logger.info(" Extracted: %s", frag_name)
extracted += 1
break
logger.info("\nExtracted %d shader(s)", extracted)
def patch_shaders():
"""Patch shaders from this folder back into blueprint JSONs."""
# Build lookup: blueprint_name -> [(node_id, shader_code), ...]
shader_updates = {}
for frag_path in sorted(GLSL_DIR.glob("*.frag")):
# Parse filename: {blueprint_name}_{node_id}.frag
parts = frag_path.stem.rsplit('_', 1)
if len(parts) != 2:
logger.warning("Skipping %s: invalid filename format", frag_path.name)
continue
blueprint_name, node_id_str = parts
try:
node_id = int(node_id_str)
except ValueError:
logger.warning("Skipping %s: invalid node_id", frag_path.name)
continue
with open(frag_path, 'r') as f:
shader_code = f.read()
if blueprint_name not in shader_updates:
shader_updates[blueprint_name] = []
shader_updates[blueprint_name].append((node_id, shader_code))
# Apply updates to JSON files
patched = 0
for json_path in get_blueprint_files():
blueprint_name = sanitize_filename(json_path.stem)
if blueprint_name not in shader_updates:
continue
try:
with open(json_path, 'r') as f:
data = json.load(f)
except (json.JSONDecodeError, IOError) as e:
logger.error("Error reading %s: %s", json_path.name, e)
continue
modified = False
for node_id, shader_code in shader_updates[blueprint_name]:
# Find the node and update
for subgraph in data.get('definitions', {}).get('subgraphs', []):
for node in subgraph.get('nodes', []):
if node.get('id') == node_id and node.get('type') == 'GLSLShader':
widgets = node.get('widgets_values', [])
if len(widgets) > 0 and widgets[0] != shader_code:
widgets[0] = shader_code
modified = True
logger.info(" Patched: %s (node %d)", json_path.name, node_id)
patched += 1
if modified:
with open(json_path, 'w') as f:
json.dump(data, f)
if patched == 0:
logger.info("No changes to apply.")
else:
logger.info("\nPatched %d shader(s)", patched)
def main():
if len(sys.argv) < 2:
command = "patch"
else:
command = sys.argv[1].lower()
if command == "extract":
logger.info("Extracting shaders from blueprints...")
extract_shaders()
elif command in ("patch", "update", "apply"):
logger.info("Patching shaders into blueprints...")
patch_shaders()
else:
logger.info(__doc__)
sys.exit(1)
if __name__ == "__main__":
main()

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

1
blueprints/Glow.json Normal file

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1 @@
{"revision": 0, "last_node_id": 29, "last_link_id": 0, "nodes": [{"id": 29, "type": "4c9d6ea4-b912-40e5-8766-6793a9758c53", "pos": [1970, -230], "size": [180, 86], "flags": {}, "order": 5, "mode": 0, "inputs": [{"label": "image", "localized_name": "images.image0", "name": "images.image0", "type": "IMAGE", "link": null}], "outputs": [{"label": "R", "localized_name": "IMAGE0", "name": "IMAGE0", "type": "IMAGE", "links": []}, {"label": "G", "localized_name": "IMAGE1", "name": "IMAGE1", "type": "IMAGE", "links": []}, {"label": "B", "localized_name": "IMAGE2", "name": "IMAGE2", "type": "IMAGE", "links": []}, {"label": "A", "localized_name": "IMAGE3", "name": "IMAGE3", "type": "IMAGE", "links": []}], "title": "Image Channels", "properties": {"proxyWidgets": []}, "widgets_values": []}], "links": [], "version": 0.4, "definitions": {"subgraphs": [{"id": "4c9d6ea4-b912-40e5-8766-6793a9758c53", "version": 1, "state": {"lastGroupId": 0, "lastNodeId": 28, "lastLinkId": 39, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "Image Channels", "inputNode": {"id": -10, "bounding": [1820, -185, 120, 60]}, "outputNode": {"id": -20, "bounding": [2460, -215, 120, 120]}, "inputs": [{"id": "3522932b-2d86-4a1f-a02a-cb29f3a9d7fe", "name": "images.image0", "type": "IMAGE", "linkIds": [39], "localized_name": "images.image0", "label": "image", "pos": [1920, -165]}], "outputs": [{"id": "605cb9c3-b065-4d9b-81d2-3ec331889b2b", "name": "IMAGE0", "type": "IMAGE", "linkIds": [26], "localized_name": "IMAGE0", "label": "R", "pos": [2480, -195]}, {"id": "fb44a77e-0522-43e9-9527-82e7465b3596", "name": "IMAGE1", "type": "IMAGE", "linkIds": [27], "localized_name": "IMAGE1", "label": "G", "pos": [2480, -175]}, {"id": "81460ee6-0131-402a-874f-6bf3001fc4ff", "name": "IMAGE2", "type": "IMAGE", "linkIds": [28], "localized_name": "IMAGE2", "label": "B", "pos": [2480, -155]}, {"id": "ae690246-80d4-4951-b1d9-9306d8a77417", "name": "IMAGE3", "type": "IMAGE", "linkIds": [29], "localized_name": "IMAGE3", "label": "A", "pos": [2480, -135]}], "widgets": [], "nodes": [{"id": 23, "type": "GLSLShader", "pos": [2000, -330], "size": [400, 172], "flags": {}, "order": 0, "mode": 0, "inputs": [{"label": "image", "localized_name": "images.image0", "name": "images.image0", "type": "IMAGE", "link": 39}, {"localized_name": "fragment_shader", "name": "fragment_shader", "type": "STRING", "widget": {"name": "fragment_shader"}, "link": null}, {"localized_name": "size_mode", "name": "size_mode", "type": "COMFY_DYNAMICCOMBO_V3", "widget": {"name": "size_mode"}, "link": null}, {"label": "image1", "localized_name": "images.image1", "name": "images.image1", "shape": 7, "type": "IMAGE", "link": null}], "outputs": [{"label": "R", "localized_name": "IMAGE0", "name": "IMAGE0", "type": "IMAGE", "links": [26]}, {"label": "G", "localized_name": "IMAGE1", "name": "IMAGE1", "type": "IMAGE", "links": [27]}, {"label": "B", "localized_name": "IMAGE2", "name": "IMAGE2", "type": "IMAGE", "links": [28]}, {"label": "A", "localized_name": "IMAGE3", "name": "IMAGE3", "type": "IMAGE", "links": [29]}], "properties": {"Node name for S&R": "GLSLShader"}, "widgets_values": ["#version 300 es\nprecision highp float;\n\nuniform sampler2D u_image0;\n\nin vec2 v_texCoord;\nlayout(location = 0) out vec4 fragColor0;\nlayout(location = 1) out vec4 fragColor1;\nlayout(location = 2) out vec4 fragColor2;\nlayout(location = 3) out vec4 fragColor3;\n\nvoid main() {\n vec4 color = texture(u_image0, v_texCoord);\n // Output each channel as grayscale to separate render targets\n fragColor0 = vec4(vec3(color.r), 1.0); // Red channel\n fragColor1 = vec4(vec3(color.g), 1.0); // Green channel\n fragColor2 = vec4(vec3(color.b), 1.0); // Blue channel\n fragColor3 = vec4(vec3(color.a), 1.0); // Alpha channel\n}\n", "from_input"]}], "groups": [], "links": [{"id": 39, "origin_id": -10, "origin_slot": 0, "target_id": 23, "target_slot": 0, "type": "IMAGE"}, {"id": 26, "origin_id": 23, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "IMAGE"}, {"id": 27, "origin_id": 23, "origin_slot": 1, "target_id": -20, "target_slot": 1, "type": "IMAGE"}, {"id": 28, "origin_id": 23, "origin_slot": 2, "target_id": -20, "target_slot": 2, "type": "IMAGE"}, {"id": 29, "origin_id": 23, "origin_slot": 3, "target_id": -20, "target_slot": 3, "type": "IMAGE"}], "extra": {"workflowRendererVersion": "LG"}, "category": "Image Tools/Color adjust"}]}}

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1 @@
{"revision": 0, "last_node_id": 15, "last_link_id": 0, "nodes": [{"id": 15, "type": "24d8bbfd-39d4-4774-bff0-3de40cc7a471", "pos": [-1490, 2040], "size": [400, 260], "flags": {}, "order": 0, "mode": 0, "inputs": [{"name": "prompt", "type": "STRING", "widget": {"name": "prompt"}, "link": null}, {"label": "reference images", "name": "images", "type": "IMAGE", "link": null}], "outputs": [{"name": "STRING", "type": "STRING", "links": null}], "title": "Prompt Enhance", "properties": {"proxyWidgets": [["-1", "prompt"]], "cnr_id": "comfy-core", "ver": "0.14.1"}, "widgets_values": [""]}], "links": [], "version": 0.4, "definitions": {"subgraphs": [{"id": "24d8bbfd-39d4-4774-bff0-3de40cc7a471", "version": 1, "state": {"lastGroupId": 0, "lastNodeId": 15, "lastLinkId": 14, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "Prompt Enhance", "inputNode": {"id": -10, "bounding": [-2170, 2110, 138.876953125, 80]}, "outputNode": {"id": -20, "bounding": [-640, 2110, 120, 60]}, "inputs": [{"id": "aeab7216-00e0-4528-a09b-bba50845c5a6", "name": "prompt", "type": "STRING", "linkIds": [11], "pos": [-2051.123046875, 2130]}, {"id": "7b73fd36-aa31-4771-9066-f6c83879994b", "name": "images", "type": "IMAGE", "linkIds": [14], "label": "reference images", "pos": [-2051.123046875, 2150]}], "outputs": [{"id": "c7b0d930-68a1-48d1-b496-0519e5837064", "name": "STRING", "type": "STRING", "linkIds": [13], "pos": [-620, 2130]}], "widgets": [], "nodes": [{"id": 11, "type": "GeminiNode", "pos": [-1560, 1990], "size": [470, 470], "flags": {}, "order": 0, "mode": 0, "inputs": [{"localized_name": "images", "name": "images", "shape": 7, "type": "IMAGE", "link": 14}, {"localized_name": "audio", "name": "audio", "shape": 7, "type": "AUDIO", "link": null}, {"localized_name": "video", "name": "video", "shape": 7, "type": "VIDEO", "link": null}, {"localized_name": "files", "name": "files", "shape": 7, "type": "GEMINI_INPUT_FILES", "link": null}, {"localized_name": "prompt", "name": "prompt", "type": "STRING", "widget": {"name": "prompt"}, "link": 11}, {"localized_name": "model", "name": "model", "type": "COMBO", "widget": {"name": "model"}, "link": null}, {"localized_name": "seed", "name": "seed", "type": "INT", "widget": {"name": "seed"}, "link": null}, {"localized_name": "system_prompt", "name": "system_prompt", "shape": 7, "type": "STRING", "widget": {"name": "system_prompt"}, "link": null}], "outputs": [{"localized_name": "STRING", "name": "STRING", "type": "STRING", "links": [13]}], "properties": {"cnr_id": "comfy-core", "ver": "0.14.1", "Node name for S&R": "GeminiNode"}, "widgets_values": ["", "gemini-3-pro-preview", 42, "randomize", "You are an expert in prompt writing.\nBased on the input, rewrite the user's input into a detailed prompt.\nincluding camera settings, lighting, composition, and style.\nReturn the prompt only"], "color": "#432", "bgcolor": "#653"}], "groups": [], "links": [{"id": 11, "origin_id": -10, "origin_slot": 0, "target_id": 11, "target_slot": 4, "type": "STRING"}, {"id": 13, "origin_id": 11, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "STRING"}, {"id": 14, "origin_id": -10, "origin_slot": 1, "target_id": 11, "target_slot": 0, "type": "IMAGE"}], "extra": {"workflowRendererVersion": "LG"}, "category": "Text generation/Prompt enhance"}]}, "extra": {}}

1
blueprints/Sharpen.json Normal file
View File

@@ -0,0 +1 @@
{"revision": 0, "last_node_id": 25, "last_link_id": 0, "nodes": [{"id": 25, "type": "621ba4e2-22a8-482d-a369-023753198b7b", "pos": [4610, -790], "size": [230, 58], "flags": {}, "order": 4, "mode": 0, "inputs": [{"label": "image", "localized_name": "images.image0", "name": "images.image0", "type": "IMAGE", "link": null}], "outputs": [{"label": "IMAGE", "localized_name": "IMAGE0", "name": "IMAGE0", "type": "IMAGE", "links": []}], "title": "Sharpen", "properties": {"proxyWidgets": [["24", "value"]]}, "widgets_values": []}], "links": [], "version": 0.4, "definitions": {"subgraphs": [{"id": "621ba4e2-22a8-482d-a369-023753198b7b", "version": 1, "state": {"lastGroupId": 0, "lastNodeId": 24, "lastLinkId": 36, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "Sharpen", "inputNode": {"id": -10, "bounding": [4090, -825, 120, 60]}, "outputNode": {"id": -20, "bounding": [5150, -825, 120, 60]}, "inputs": [{"id": "37011fb7-14b7-4e0e-b1a0-6a02e8da1fd7", "name": "images.image0", "type": "IMAGE", "linkIds": [34], "localized_name": "images.image0", "label": "image", "pos": [4190, -805]}], "outputs": [{"id": "e9182b3f-635c-4cd4-a152-4b4be17ae4b9", "name": "IMAGE0", "type": "IMAGE", "linkIds": [35], "localized_name": "IMAGE0", "label": "IMAGE", "pos": [5170, -805]}], "widgets": [], "nodes": [{"id": 24, "type": "PrimitiveFloat", "pos": [4280, -1240], "size": [270, 58], "flags": {}, "order": 0, "mode": 0, "inputs": [{"label": "strength", "localized_name": "value", "name": "value", "type": "FLOAT", "widget": {"name": "value"}, "link": null}], "outputs": [{"localized_name": "FLOAT", "name": "FLOAT", "type": "FLOAT", "links": [36]}], "properties": {"Node name for S&R": "PrimitiveFloat", "min": 0, "max": 3, "precision": 2, "step": 0.05}, "widgets_values": [0.5]}, {"id": 23, "type": "GLSLShader", "pos": [4570, -1240], "size": [370, 192], "flags": {}, "order": 1, "mode": 0, "inputs": [{"label": "image0", "localized_name": "images.image0", "name": "images.image0", "type": "IMAGE", "link": 34}, {"label": "image1", "localized_name": "images.image1", "name": "images.image1", "shape": 7, "type": "IMAGE", "link": null}, {"label": "u_float0", "localized_name": "floats.u_float0", "name": "floats.u_float0", "shape": 7, "type": "FLOAT", "link": 36}, {"label": "u_float1", "localized_name": "floats.u_float1", "name": "floats.u_float1", "shape": 7, "type": "FLOAT", "link": null}, {"label": "u_int0", "localized_name": "ints.u_int0", "name": "ints.u_int0", "shape": 7, "type": "INT", "link": null}, {"localized_name": "fragment_shader", "name": "fragment_shader", "type": "STRING", "widget": {"name": "fragment_shader"}, "link": null}, {"localized_name": "size_mode", "name": "size_mode", "type": "COMFY_DYNAMICCOMBO_V3", "widget": {"name": "size_mode"}, "link": null}], "outputs": [{"localized_name": "IMAGE0", "name": "IMAGE0", "type": "IMAGE", "links": [35]}, {"localized_name": "IMAGE1", "name": "IMAGE1", "type": "IMAGE", "links": null}, {"localized_name": "IMAGE2", "name": "IMAGE2", "type": "IMAGE", "links": null}, {"localized_name": "IMAGE3", "name": "IMAGE3", "type": "IMAGE", "links": null}], "properties": {"Node name for S&R": "GLSLShader"}, "widgets_values": ["#version 300 es\nprecision highp float;\n\nuniform sampler2D u_image0;\nuniform vec2 u_resolution;\nuniform float u_float0; // strength [0.0 2.0] typical: 0.31.0\n\nin vec2 v_texCoord;\nlayout(location = 0) out vec4 fragColor0;\n\nvoid main() {\n vec2 texel = 1.0 / u_resolution;\n \n // Sample center and neighbors\n vec4 center = texture(u_image0, v_texCoord);\n vec4 top = texture(u_image0, v_texCoord + vec2( 0.0, -texel.y));\n vec4 bottom = texture(u_image0, v_texCoord + vec2( 0.0, texel.y));\n vec4 left = texture(u_image0, v_texCoord + vec2(-texel.x, 0.0));\n vec4 right = texture(u_image0, v_texCoord + vec2( texel.x, 0.0));\n \n // Edge enhancement (Laplacian)\n vec4 edges = center * 4.0 - top - bottom - left - right;\n \n // Add edges back scaled by strength\n vec4 sharpened = center + edges * u_float0;\n \n fragColor0 = vec4(clamp(sharpened.rgb, 0.0, 1.0), center.a);\n}", "from_input"]}], "groups": [], "links": [{"id": 36, "origin_id": 24, "origin_slot": 0, "target_id": 23, "target_slot": 2, "type": "FLOAT"}, {"id": 34, "origin_id": -10, "origin_slot": 0, "target_id": 23, "target_slot": 0, "type": "IMAGE"}, {"id": 35, "origin_id": 23, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "IMAGE"}], "extra": {"workflowRendererVersion": "LG"}, "category": "Image Tools/Sharpen"}]}}

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1 @@
{"revision": 0, "last_node_id": 13, "last_link_id": 0, "nodes": [{"id": 13, "type": "cf95b747-3e17-46cb-8097-cac60ff9b2e1", "pos": [1120, 330], "size": [240, 58], "flags": {}, "order": 3, "mode": 0, "inputs": [{"localized_name": "video", "name": "video", "type": "VIDEO", "link": null}, {"name": "model_name", "type": "COMBO", "widget": {"name": "model_name"}, "link": null}], "outputs": [{"localized_name": "VIDEO", "name": "VIDEO", "type": "VIDEO", "links": []}], "title": "Video Upscale(GAN x4)", "properties": {"proxyWidgets": [["-1", "model_name"]], "cnr_id": "comfy-core", "ver": "0.14.1"}, "widgets_values": ["RealESRGAN_x4plus.safetensors"]}], "links": [], "version": 0.4, "definitions": {"subgraphs": [{"id": "cf95b747-3e17-46cb-8097-cac60ff9b2e1", "version": 1, "state": {"lastGroupId": 0, "lastNodeId": 13, "lastLinkId": 19, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "Video Upscale(GAN x4)", "inputNode": {"id": -10, "bounding": [550, 460, 120, 80]}, "outputNode": {"id": -20, "bounding": [1490, 460, 120, 60]}, "inputs": [{"id": "666d633e-93e7-42dc-8d11-2b7b99b0f2a6", "name": "video", "type": "VIDEO", "linkIds": [10], "localized_name": "video", "pos": [650, 480]}, {"id": "2e23a087-caa8-4d65-99e6-662761aa905a", "name": "model_name", "type": "COMBO", "linkIds": [19], "pos": [650, 500]}], "outputs": [{"id": "0c1768ea-3ec2-412f-9af6-8e0fa36dae70", "name": "VIDEO", "type": "VIDEO", "linkIds": [15], "localized_name": "VIDEO", "pos": [1510, 480]}], "widgets": [], "nodes": [{"id": 2, "type": "ImageUpscaleWithModel", "pos": [1110, 450], "size": [320, 46], "flags": {}, "order": 1, "mode": 0, "inputs": [{"localized_name": "upscale_model", "name": "upscale_model", "type": "UPSCALE_MODEL", "link": 1}, {"localized_name": "image", "name": "image", "type": "IMAGE", "link": 14}], "outputs": [{"localized_name": "IMAGE", "name": "IMAGE", "type": "IMAGE", "links": [13]}], "properties": {"cnr_id": "comfy-core", "ver": "0.10.0", "Node name for S&R": "ImageUpscaleWithModel"}}, {"id": 11, "type": "CreateVideo", "pos": [1110, 550], "size": [320, 78], "flags": {}, "order": 3, "mode": 0, "inputs": [{"localized_name": "images", "name": "images", "type": "IMAGE", "link": 13}, {"localized_name": "audio", "name": "audio", "shape": 7, "type": "AUDIO", "link": 16}, {"localized_name": "fps", "name": "fps", "type": "FLOAT", "widget": {"name": "fps"}, "link": 12}], "outputs": [{"localized_name": "VIDEO", "name": "VIDEO", "type": "VIDEO", "links": [15]}], "properties": {"cnr_id": "comfy-core", "ver": "0.10.0", "Node name for S&R": "CreateVideo"}, "widgets_values": [30]}, {"id": 10, "type": "GetVideoComponents", "pos": [1110, 330], "size": [320, 70], "flags": {}, "order": 2, "mode": 0, "inputs": [{"localized_name": "video", "name": "video", "type": "VIDEO", "link": 10}], "outputs": [{"localized_name": "images", "name": "images", "type": "IMAGE", "links": [14]}, {"localized_name": "audio", "name": "audio", "type": "AUDIO", "links": [16]}, {"localized_name": "fps", "name": "fps", "type": "FLOAT", "links": [12]}], "properties": {"cnr_id": "comfy-core", "ver": "0.10.0", "Node name for S&R": "GetVideoComponents"}}, {"id": 1, "type": "UpscaleModelLoader", "pos": [750, 450], "size": [280, 60], "flags": {}, "order": 0, "mode": 0, "inputs": [{"localized_name": "model_name", "name": "model_name", "type": "COMBO", "widget": {"name": "model_name"}, "link": 19}], "outputs": [{"localized_name": "UPSCALE_MODEL", "name": "UPSCALE_MODEL", "type": "UPSCALE_MODEL", "links": [1]}], "properties": {"cnr_id": "comfy-core", "ver": "0.10.0", "Node name for S&R": "UpscaleModelLoader", "models": [{"name": "RealESRGAN_x4plus.safetensors", "url": "https://huggingface.co/Comfy-Org/Real-ESRGAN_repackaged/resolve/main/RealESRGAN_x4plus.safetensors", "directory": "upscale_models"}]}, "widgets_values": ["RealESRGAN_x4plus.safetensors"]}], "groups": [], "links": [{"id": 1, "origin_id": 1, "origin_slot": 0, "target_id": 2, "target_slot": 0, "type": "UPSCALE_MODEL"}, {"id": 14, "origin_id": 10, "origin_slot": 0, "target_id": 2, "target_slot": 1, "type": "IMAGE"}, {"id": 13, "origin_id": 2, "origin_slot": 0, "target_id": 11, "target_slot": 0, "type": "IMAGE"}, {"id": 16, "origin_id": 10, "origin_slot": 1, "target_id": 11, "target_slot": 1, "type": "AUDIO"}, {"id": 12, "origin_id": 10, "origin_slot": 2, "target_id": 11, "target_slot": 2, "type": "FLOAT"}, {"id": 10, "origin_id": -10, "origin_slot": 0, "target_id": 10, "target_slot": 0, "type": "VIDEO"}, {"id": 15, "origin_id": 11, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "VIDEO"}, {"id": 19, "origin_id": -10, "origin_slot": 1, "target_id": 1, "target_slot": 0, "type": "COMBO"}], "extra": {"workflowRendererVersion": "LG"}, "category": "Video generation and editing/Enhance video"}]}, "extra": {}}

View File

View File

@@ -25,11 +25,11 @@ class AudioEncoderModel():
elif model_type == "whisper3":
self.model = WhisperLargeV3(**model_config)
self.model.eval()
self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
self.model_sample_rate = 16000
def load_sd(self, sd):
return self.model.load_state_dict(sd, strict=False)
return self.model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic())
def get_sd(self):
return self.model.state_dict()

View File

@@ -1,13 +0,0 @@
import pickle
load = pickle.load
class Empty:
pass
class Unpickler(pickle.Unpickler):
def find_class(self, module, name):
#TODO: safe unpickle
if module.startswith("pytorch_lightning"):
return Empty
return super().find_class(module, name)

View File

@@ -146,6 +146,7 @@ parser.add_argument("--reserve-vram", type=float, default=None, help="Set the am
parser.add_argument("--async-offload", nargs='?', const=2, type=int, default=None, metavar="NUM_STREAMS", help="Use async weight offloading. An optional argument controls the amount of offload streams. Default is 2. Enabled by default on Nvidia.")
parser.add_argument("--disable-async-offload", action="store_true", help="Disable async weight offloading.")
parser.add_argument("--disable-dynamic-vram", action="store_true", help="Disable dynamic VRAM and use estimate based model loading.")
parser.add_argument("--force-non-blocking", action="store_true", help="Force ComfyUI to use non-blocking operations for all applicable tensors. This may improve performance on some non-Nvidia systems but can cause issues with some workflows.")
@@ -257,3 +258,6 @@ elif args.fast == []:
# '--fast' is provided with a list of performance features, use that list
else:
args.fast = set(args.fast)
def enables_dynamic_vram():
return not args.disable_dynamic_vram and not args.highvram and not args.gpu_only and not args.novram and not args.cpu

View File

@@ -47,10 +47,10 @@ class ClipVisionModel():
self.model = model_class(config, self.dtype, offload_device, comfy.ops.manual_cast)
self.model.eval()
self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
def load_sd(self, sd):
return self.model.load_state_dict(sd, strict=False)
return self.model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic())
def get_sd(self):
return self.model.state_dict()
@@ -66,6 +66,7 @@ class ClipVisionModel():
outputs = Output()
outputs["last_hidden_state"] = out[0].to(comfy.model_management.intermediate_device())
outputs["image_embeds"] = out[2].to(comfy.model_management.intermediate_device())
outputs["image_sizes"] = [pixel_values.shape[1:]] * pixel_values.shape[0]
if self.return_all_hidden_states:
all_hs = out[1].to(comfy.model_management.intermediate_device())
outputs["penultimate_hidden_states"] = all_hs[:, -2]

View File

@@ -176,6 +176,8 @@ class InputTypeOptions(TypedDict):
"""COMBO type only. Specifies the configuration for a multi-select widget.
Available after ComfyUI frontend v1.13.4
https://github.com/Comfy-Org/ComfyUI_frontend/pull/2987"""
gradient_stops: NotRequired[list[list[float]]]
"""Gradient color stops for gradientslider display mode. Each stop is [offset, r, g, b] (``FLOAT``)."""
class HiddenInputTypeDict(TypedDict):
@@ -236,6 +238,8 @@ class ComfyNodeABC(ABC):
"""Flags a node as experimental, informing users that it may change or not work as expected."""
DEPRECATED: bool
"""Flags a node as deprecated, indicating to users that they should find alternatives to this node."""
DEV_ONLY: bool
"""Flags a node as dev-only, hiding it from search/menus unless dev mode is enabled."""
API_NODE: Optional[bool]
"""Flags a node as an API node. See: https://docs.comfy.org/tutorials/api-nodes/overview."""

View File

@@ -4,6 +4,25 @@ import comfy.utils
import logging
def is_equal(x, y):
if torch.is_tensor(x) and torch.is_tensor(y):
return torch.equal(x, y)
elif isinstance(x, dict) and isinstance(y, dict):
if x.keys() != y.keys():
return False
return all(is_equal(x[k], y[k]) for k in x)
elif isinstance(x, (list, tuple)) and isinstance(y, (list, tuple)):
if type(x) is not type(y) or len(x) != len(y):
return False
return all(is_equal(a, b) for a, b in zip(x, y))
else:
try:
return x == y
except Exception:
logging.warning("comparison issue with COND")
return False
class CONDRegular:
def __init__(self, cond):
self.cond = cond
@@ -84,7 +103,7 @@ class CONDConstant(CONDRegular):
return self._copy_with(self.cond)
def can_concat(self, other):
if self.cond != other.cond:
if not is_equal(self.cond, other.cond):
return False
return True

View File

@@ -214,7 +214,7 @@ class IndexListContextHandler(ContextHandlerABC):
mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep[0], rtol=0.0001)
matches = torch.nonzero(mask)
if torch.numel(matches) == 0:
raise Exception("No sample_sigmas matched current timestep; something went wrong.")
return # substep from multi-step sampler: keep self._step from the last full step
self._step = int(matches[0].item())
def get_context_windows(self, model: BaseModel, x_in: torch.Tensor, model_options: dict[str]) -> list[IndexListContextWindow]:

View File

@@ -203,7 +203,7 @@ class ControlNet(ControlBase):
self.control_model = control_model
self.load_device = load_device
if control_model is not None:
self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device())
self.control_model_wrapped = comfy.model_patcher.CoreModelPatcher(self.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device())
self.compression_ratio = compression_ratio
self.global_average_pooling = global_average_pooling
@@ -297,6 +297,30 @@ class ControlNet(ControlBase):
self.model_sampling_current = None
super().cleanup()
class QwenFunControlNet(ControlNet):
def get_control(self, x_noisy, t, cond, batched_number, transformer_options):
# Fun checkpoints are more sensitive to high strengths in the generic
# ControlNet merge path. Use a soft response curve so strength=1.0 stays
# unchanged while >1 grows more gently.
original_strength = self.strength
self.strength = math.sqrt(max(self.strength, 0.0))
try:
return super().get_control(x_noisy, t, cond, batched_number, transformer_options)
finally:
self.strength = original_strength
def pre_run(self, model, percent_to_timestep_function):
super().pre_run(model, percent_to_timestep_function)
self.set_extra_arg("base_model", model.diffusion_model)
def copy(self):
c = QwenFunControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
c.control_model = self.control_model
c.control_model_wrapped = self.control_model_wrapped
self.copy_to(c)
return c
class ControlLoraOps:
class Linear(torch.nn.Module, comfy.ops.CastWeightBiasOp):
def __init__(self, in_features: int, out_features: int, bias: bool = True,
@@ -560,6 +584,7 @@ def load_controlnet_hunyuandit(controlnet_data, model_options={}):
def load_controlnet_flux_xlabs_mistoline(sd, mistoline=False, model_options={}):
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd, model_options=model_options)
control_model = comfy.ldm.flux.controlnet.ControlNetFlux(mistoline=mistoline, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
sd = model_config.process_unet_state_dict(sd)
control_model = controlnet_load_state_dict(control_model, sd)
extra_conds = ['y', 'guidance']
control = ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
@@ -605,6 +630,53 @@ def load_controlnet_qwen_instantx(sd, model_options={}):
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
return control
def load_controlnet_qwen_fun(sd, model_options={}):
load_device = comfy.model_management.get_torch_device()
weight_dtype = comfy.utils.weight_dtype(sd)
unet_dtype = model_options.get("dtype", weight_dtype)
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
operations = model_options.get("custom_operations", None)
if operations is None:
operations = comfy.ops.pick_operations(unet_dtype, manual_cast_dtype, disable_fast_fp8=True)
in_features = sd["control_img_in.weight"].shape[1]
inner_dim = sd["control_img_in.weight"].shape[0]
block_weight = sd["control_blocks.0.attn.to_q.weight"]
attention_head_dim = sd["control_blocks.0.attn.norm_q.weight"].shape[0]
num_attention_heads = max(1, block_weight.shape[0] // max(1, attention_head_dim))
model = comfy.ldm.qwen_image.controlnet.QwenImageFunControlNetModel(
control_in_features=in_features,
inner_dim=inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
num_control_blocks=5,
main_model_double=60,
injection_layers=(0, 12, 24, 36, 48),
operations=operations,
device=comfy.model_management.unet_offload_device(),
dtype=unet_dtype,
)
model = controlnet_load_state_dict(model, sd)
latent_format = comfy.latent_formats.Wan21()
control = QwenFunControlNet(
model,
compression_ratio=1,
latent_format=latent_format,
# Fun checkpoints already expect their own 33-channel context handling.
# Enabling generic concat_mask injects an extra mask channel at apply-time
# and breaks the intended fallback packing path.
concat_mask=False,
load_device=load_device,
manual_cast_dtype=manual_cast_dtype,
extra_conds=[],
)
return control
def convert_mistoline(sd):
return comfy.utils.state_dict_prefix_replace(sd, {"single_controlnet_blocks.": "controlnet_single_blocks."})
@@ -682,6 +754,8 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}):
return load_controlnet_qwen_instantx(controlnet_data, model_options=model_options)
elif "controlnet_x_embedder.weight" in controlnet_data:
return load_controlnet_flux_instantx(controlnet_data, model_options=model_options)
elif "control_blocks.0.after_proj.weight" in controlnet_data and "control_img_in.weight" in controlnet_data:
return load_controlnet_qwen_fun(controlnet_data, model_options=model_options)
elif "controlnet_blocks.0.linear.weight" in controlnet_data: #mistoline flux
return load_controlnet_flux_xlabs_mistoline(convert_mistoline(controlnet_data), mistoline=True, model_options=model_options)

View File

@@ -137,10 +137,44 @@ def to_blocked(input_matrix, flatten: bool = True) -> torch.Tensor:
return rearranged.reshape(padded_rows, padded_cols)
def stochastic_round_quantize_nvfp4(x, per_tensor_scale, pad_16x, seed=0):
def stochastic_round_quantize_nvfp4_block(x, per_tensor_scale, generator):
F4_E2M1_MAX = 6.0
F8_E4M3_MAX = 448.0
orig_shape = x.shape
block_size = 16
x = x.reshape(orig_shape[0], -1, block_size)
scaled_block_scales_fp8 = torch.clamp(((torch.amax(torch.abs(x), dim=-1)) / F4_E2M1_MAX) / per_tensor_scale.to(x.dtype), max=F8_E4M3_MAX).to(torch.float8_e4m3fn)
x = x / (per_tensor_scale.to(x.dtype) * scaled_block_scales_fp8.to(x.dtype)).unsqueeze(-1)
x = x.view(orig_shape).nan_to_num()
data_lp = stochastic_float_to_fp4_e2m1(x, generator=generator)
return data_lp, scaled_block_scales_fp8
def stochastic_round_quantize_nvfp4(x, per_tensor_scale, pad_16x, seed=0):
def roundup(x: int, multiple: int) -> int:
"""Round up x to the nearest multiple."""
return ((x + multiple - 1) // multiple) * multiple
generator = torch.Generator(device=x.device)
generator.manual_seed(seed)
# Handle padding
if pad_16x:
rows, cols = x.shape
padded_rows = roundup(rows, 16)
padded_cols = roundup(cols, 16)
if padded_rows != rows or padded_cols != cols:
x = torch.nn.functional.pad(x, (0, padded_cols - cols, 0, padded_rows - rows))
x, blocked_scaled = stochastic_round_quantize_nvfp4_block(x, per_tensor_scale, generator)
return x, to_blocked(blocked_scaled, flatten=False)
def stochastic_round_quantize_nvfp4_by_block(x, per_tensor_scale, pad_16x, seed=0, block_size=4096 * 4096):
def roundup(x: int, multiple: int) -> int:
"""Round up x to the nearest multiple."""
return ((x + multiple - 1) // multiple) * multiple
@@ -158,28 +192,20 @@ def stochastic_round_quantize_nvfp4(x, per_tensor_scale, pad_16x, seed=0):
# what we want to produce. If we pad here, we want the padded output.
orig_shape = x.shape
block_size = 16
orig_shape = list(orig_shape)
x = x.reshape(orig_shape[0], -1, block_size)
max_abs = torch.amax(torch.abs(x), dim=-1)
block_scale = max_abs / F4_E2M1_MAX
scaled_block_scales = block_scale / per_tensor_scale.to(block_scale.dtype)
scaled_block_scales_fp8 = torch.clamp(scaled_block_scales, max=F8_E4M3_MAX).to(torch.float8_e4m3fn)
total_scale = per_tensor_scale.to(x.dtype) * scaled_block_scales_fp8.to(x.dtype)
# Handle zero blocks (from padding): avoid 0/0 NaN
zero_scale_mask = (total_scale == 0)
total_scale_safe = torch.where(zero_scale_mask, torch.ones_like(total_scale), total_scale)
x = x / total_scale_safe.unsqueeze(-1)
output_fp4 = torch.empty(orig_shape[:-1] + [orig_shape[-1] // 2], dtype=torch.uint8, device=x.device)
output_block = torch.empty(orig_shape[:-1] + [orig_shape[-1] // 16], dtype=torch.float8_e4m3fn, device=x.device)
generator = torch.Generator(device=x.device)
generator.manual_seed(seed)
x = torch.where(zero_scale_mask.unsqueeze(-1), torch.zeros_like(x), x)
num_slices = max(1, (x.numel() / block_size))
slice_size = max(1, (round(x.shape[0] / num_slices)))
x = x.view(orig_shape)
data_lp = stochastic_float_to_fp4_e2m1(x, generator=generator)
for i in range(0, x.shape[0], slice_size):
fp4, block = stochastic_round_quantize_nvfp4_block(x[i: i + slice_size], per_tensor_scale, generator=generator)
output_fp4[i:i + slice_size].copy_(fp4)
output_block[i:i + slice_size].copy_(block)
blocked_scales = to_blocked(scaled_block_scales_fp8, flatten=False)
return data_lp, blocked_scales
return output_fp4, to_blocked(output_block, flatten=False)

View File

@@ -5,7 +5,7 @@ from scipy import integrate
import torch
from torch import nn
import torchsde
from tqdm.auto import trange, tqdm
from tqdm.auto import tqdm
from . import utils
from . import deis
@@ -13,6 +13,9 @@ from . import sa_solver
import comfy.model_patcher
import comfy.model_sampling
import comfy.memory_management
from comfy.utils import model_trange as trange
def append_zero(x):
return torch.cat([x, x.new_zeros([1])])

View File

@@ -8,6 +8,7 @@ class LatentFormat:
latent_rgb_factors_bias = None
latent_rgb_factors_reshape = None
taesd_decoder_name = None
spacial_downscale_ratio = 8
def process_in(self, latent):
return latent * self.scale_factor
@@ -80,6 +81,7 @@ class SD_X4(LatentFormat):
class SC_Prior(LatentFormat):
latent_channels = 16
spacial_downscale_ratio = 42
def __init__(self):
self.scale_factor = 1.0
self.latent_rgb_factors = [
@@ -102,6 +104,7 @@ class SC_Prior(LatentFormat):
]
class SC_B(LatentFormat):
spacial_downscale_ratio = 4
def __init__(self):
self.scale_factor = 1.0 / 0.43
self.latent_rgb_factors = [
@@ -181,6 +184,7 @@ class Flux(SD3):
class Flux2(LatentFormat):
latent_channels = 128
spacial_downscale_ratio = 16
def __init__(self):
self.latent_rgb_factors =[
@@ -272,6 +276,7 @@ class Mochi(LatentFormat):
class LTXV(LatentFormat):
latent_channels = 128
latent_dimensions = 3
spacial_downscale_ratio = 32
def __init__(self):
self.latent_rgb_factors = [
@@ -515,6 +520,7 @@ class Wan21(LatentFormat):
class Wan22(Wan21):
latent_channels = 48
latent_dimensions = 3
spacial_downscale_ratio = 16
latent_rgb_factors = [
[ 0.0119, 0.0103, 0.0046],
@@ -592,6 +598,7 @@ class Wan22(Wan21):
class HunyuanImage21(LatentFormat):
latent_channels = 64
latent_dimensions = 2
spacial_downscale_ratio = 32
scale_factor = 0.75289
latent_rgb_factors = [
@@ -725,6 +732,7 @@ class HunyuanVideo15(LatentFormat):
latent_rgb_factors_bias = [ 0.0456, -0.0202, -0.0644]
latent_channels = 32
latent_dimensions = 3
spacial_downscale_ratio = 16
scale_factor = 1.03682
taesd_decoder_name = "lighttaehy1_5"
@@ -747,8 +755,13 @@ class ACEAudio(LatentFormat):
latent_channels = 8
latent_dimensions = 2
class ACEAudio15(LatentFormat):
latent_channels = 64
latent_dimensions = 1
class ChromaRadiance(LatentFormat):
latent_channels = 3
spacial_downscale_ratio = 1
def __init__(self):
self.latent_rgb_factors = [

1155
comfy/ldm/ace/ace_step15.py Normal file

File diff suppressed because it is too large Load Diff

214
comfy/ldm/anima/model.py Normal file
View File

@@ -0,0 +1,214 @@
from comfy.ldm.cosmos.predict2 import MiniTrainDIT
import torch
from torch import nn
import torch.nn.functional as F
def rotate_half(x):
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(x, cos, sin, unsqueeze_dim=1):
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
x_embed = (x * cos) + (rotate_half(x) * sin)
return x_embed
class RotaryEmbedding(nn.Module):
def __init__(self, head_dim):
super().__init__()
self.rope_theta = 10000
inv_freq = 1.0 / (self.rope_theta ** (torch.arange(0, head_dim, 2, dtype=torch.int64).to(dtype=torch.float) / head_dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
@torch.no_grad()
def forward(self, x, position_ids):
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
position_ids_expanded = position_ids[:, None, :].float()
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False): # Force float32
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
class Attention(nn.Module):
def __init__(self, query_dim, context_dim, n_heads, head_dim, device=None, dtype=None, operations=None):
super().__init__()
inner_dim = head_dim * n_heads
self.n_heads = n_heads
self.head_dim = head_dim
self.query_dim = query_dim
self.context_dim = context_dim
self.q_proj = operations.Linear(query_dim, inner_dim, bias=False, device=device, dtype=dtype)
self.q_norm = operations.RMSNorm(self.head_dim, eps=1e-6, device=device, dtype=dtype)
self.k_proj = operations.Linear(context_dim, inner_dim, bias=False, device=device, dtype=dtype)
self.k_norm = operations.RMSNorm(self.head_dim, eps=1e-6, device=device, dtype=dtype)
self.v_proj = operations.Linear(context_dim, inner_dim, bias=False, device=device, dtype=dtype)
self.o_proj = operations.Linear(inner_dim, query_dim, bias=False, device=device, dtype=dtype)
def forward(self, x, mask=None, context=None, position_embeddings=None, position_embeddings_context=None):
context = x if context is None else context
input_shape = x.shape[:-1]
q_shape = (*input_shape, self.n_heads, self.head_dim)
context_shape = context.shape[:-1]
kv_shape = (*context_shape, self.n_heads, self.head_dim)
query_states = self.q_norm(self.q_proj(x).view(q_shape)).transpose(1, 2)
key_states = self.k_norm(self.k_proj(context).view(kv_shape)).transpose(1, 2)
value_states = self.v_proj(context).view(kv_shape).transpose(1, 2)
if position_embeddings is not None:
assert position_embeddings_context is not None
cos, sin = position_embeddings
query_states = apply_rotary_pos_emb(query_states, cos, sin)
cos, sin = position_embeddings_context
key_states = apply_rotary_pos_emb(key_states, cos, sin)
attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=mask)
attn_output = attn_output.transpose(1, 2).reshape(*input_shape, -1).contiguous()
attn_output = self.o_proj(attn_output)
return attn_output
def init_weights(self):
torch.nn.init.zeros_(self.o_proj.weight)
class TransformerBlock(nn.Module):
def __init__(self, source_dim, model_dim, num_heads=16, mlp_ratio=4.0, use_self_attn=False, layer_norm=False, device=None, dtype=None, operations=None):
super().__init__()
self.use_self_attn = use_self_attn
if self.use_self_attn:
self.norm_self_attn = operations.LayerNorm(model_dim, device=device, dtype=dtype) if layer_norm else operations.RMSNorm(model_dim, eps=1e-6, device=device, dtype=dtype)
self.self_attn = Attention(
query_dim=model_dim,
context_dim=model_dim,
n_heads=num_heads,
head_dim=model_dim//num_heads,
device=device,
dtype=dtype,
operations=operations,
)
self.norm_cross_attn = operations.LayerNorm(model_dim, device=device, dtype=dtype) if layer_norm else operations.RMSNorm(model_dim, eps=1e-6, device=device, dtype=dtype)
self.cross_attn = Attention(
query_dim=model_dim,
context_dim=source_dim,
n_heads=num_heads,
head_dim=model_dim//num_heads,
device=device,
dtype=dtype,
operations=operations,
)
self.norm_mlp = operations.LayerNorm(model_dim, device=device, dtype=dtype) if layer_norm else operations.RMSNorm(model_dim, eps=1e-6, device=device, dtype=dtype)
self.mlp = nn.Sequential(
operations.Linear(model_dim, int(model_dim * mlp_ratio), device=device, dtype=dtype),
nn.GELU(),
operations.Linear(int(model_dim * mlp_ratio), model_dim, device=device, dtype=dtype)
)
def forward(self, x, context, target_attention_mask=None, source_attention_mask=None, position_embeddings=None, position_embeddings_context=None):
if self.use_self_attn:
normed = self.norm_self_attn(x)
attn_out = self.self_attn(normed, mask=target_attention_mask, position_embeddings=position_embeddings, position_embeddings_context=position_embeddings)
x = x + attn_out
normed = self.norm_cross_attn(x)
attn_out = self.cross_attn(normed, mask=source_attention_mask, context=context, position_embeddings=position_embeddings, position_embeddings_context=position_embeddings_context)
x = x + attn_out
x = x + self.mlp(self.norm_mlp(x))
return x
def init_weights(self):
torch.nn.init.zeros_(self.mlp[2].weight)
self.cross_attn.init_weights()
class LLMAdapter(nn.Module):
def __init__(
self,
source_dim=1024,
target_dim=1024,
model_dim=1024,
num_layers=6,
num_heads=16,
use_self_attn=True,
layer_norm=False,
device=None,
dtype=None,
operations=None,
):
super().__init__()
self.embed = operations.Embedding(32128, target_dim, device=device, dtype=dtype)
if model_dim != target_dim:
self.in_proj = operations.Linear(target_dim, model_dim, device=device, dtype=dtype)
else:
self.in_proj = nn.Identity()
self.rotary_emb = RotaryEmbedding(model_dim//num_heads)
self.blocks = nn.ModuleList([
TransformerBlock(source_dim, model_dim, num_heads=num_heads, use_self_attn=use_self_attn, layer_norm=layer_norm, device=device, dtype=dtype, operations=operations) for _ in range(num_layers)
])
self.out_proj = operations.Linear(model_dim, target_dim, device=device, dtype=dtype)
self.norm = operations.RMSNorm(target_dim, eps=1e-6, device=device, dtype=dtype)
def forward(self, source_hidden_states, target_input_ids, target_attention_mask=None, source_attention_mask=None):
if target_attention_mask is not None:
target_attention_mask = target_attention_mask.to(torch.bool)
if target_attention_mask.ndim == 2:
target_attention_mask = target_attention_mask.unsqueeze(1).unsqueeze(1)
if source_attention_mask is not None:
source_attention_mask = source_attention_mask.to(torch.bool)
if source_attention_mask.ndim == 2:
source_attention_mask = source_attention_mask.unsqueeze(1).unsqueeze(1)
context = source_hidden_states
x = self.in_proj(self.embed(target_input_ids, out_dtype=context.dtype))
position_ids = torch.arange(x.shape[1], device=x.device).unsqueeze(0)
position_ids_context = torch.arange(context.shape[1], device=x.device).unsqueeze(0)
position_embeddings = self.rotary_emb(x, position_ids)
position_embeddings_context = self.rotary_emb(x, position_ids_context)
for block in self.blocks:
x = block(x, context, target_attention_mask=target_attention_mask, source_attention_mask=source_attention_mask, position_embeddings=position_embeddings, position_embeddings_context=position_embeddings_context)
return self.norm(self.out_proj(x))
class Anima(MiniTrainDIT):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.llm_adapter = LLMAdapter(device=kwargs.get("device"), dtype=kwargs.get("dtype"), operations=kwargs.get("operations"))
def preprocess_text_embeds(self, text_embeds, text_ids, t5xxl_weights=None):
if text_ids is not None:
out = self.llm_adapter(text_embeds, text_ids)
if t5xxl_weights is not None:
out = out * t5xxl_weights
if out.shape[1] < 512:
out = torch.nn.functional.pad(out, (0, 0, 0, 512 - out.shape[1]))
return out
else:
return text_embeds
def forward(self, x, timesteps, context, **kwargs):
t5xxl_ids = kwargs.pop("t5xxl_ids", None)
if t5xxl_ids is not None:
context = self.preprocess_text_embeds(context, t5xxl_ids, t5xxl_weights=kwargs.pop("t5xxl_weights", None))
return super().forward(x, timesteps, context, **kwargs)

View File

@@ -3,7 +3,6 @@ from torch import Tensor, nn
from comfy.ldm.flux.layers import (
MLPEmbedder,
RMSNorm,
ModulationOut,
)
@@ -29,7 +28,7 @@ class Approximator(nn.Module):
super().__init__()
self.in_proj = operations.Linear(in_dim, hidden_dim, bias=True, dtype=dtype, device=device)
self.layers = nn.ModuleList([MLPEmbedder(hidden_dim, hidden_dim, dtype=dtype, device=device, operations=operations) for x in range( n_layers)])
self.norms = nn.ModuleList([RMSNorm(hidden_dim, dtype=dtype, device=device, operations=operations) for x in range( n_layers)])
self.norms = nn.ModuleList([operations.RMSNorm(hidden_dim, dtype=dtype, device=device) for x in range( n_layers)])
self.out_proj = operations.Linear(hidden_dim, out_dim, dtype=dtype, device=device)
@property

View File

@@ -152,6 +152,7 @@ class Chroma(nn.Module):
transformer_options={},
attn_mask: Tensor = None,
) -> Tensor:
transformer_options = transformer_options.copy()
patches_replace = transformer_options.get("patches_replace", {})
# running on sequences img
@@ -228,6 +229,7 @@ class Chroma(nn.Module):
transformer_options["total_blocks"] = len(self.single_blocks)
transformer_options["block_type"] = "single"
transformer_options["img_slice"] = [txt.shape[1], img.shape[1]]
for i, block in enumerate(self.single_blocks):
transformer_options["block_index"] = i
if i not in self.skip_dit:

View File

@@ -4,8 +4,6 @@ from functools import lru_cache
import torch
from torch import nn
from comfy.ldm.flux.layers import RMSNorm
class NerfEmbedder(nn.Module):
"""
@@ -145,7 +143,7 @@ class NerfGLUBlock(nn.Module):
# We now need to generate parameters for 3 matrices.
total_params = 3 * hidden_size_x**2 * mlp_ratio
self.param_generator = operations.Linear(hidden_size_s, total_params, dtype=dtype, device=device)
self.norm = RMSNorm(hidden_size_x, dtype=dtype, device=device, operations=operations)
self.norm = operations.RMSNorm(hidden_size_x, dtype=dtype, device=device)
self.mlp_ratio = mlp_ratio
@@ -178,7 +176,7 @@ class NerfGLUBlock(nn.Module):
class NerfFinalLayer(nn.Module):
def __init__(self, hidden_size, out_channels, dtype=None, device=None, operations=None):
super().__init__()
self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations)
self.norm = operations.RMSNorm(hidden_size, dtype=dtype, device=device)
self.linear = operations.Linear(hidden_size, out_channels, dtype=dtype, device=device)
def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -190,7 +188,7 @@ class NerfFinalLayer(nn.Module):
class NerfFinalLayerConv(nn.Module):
def __init__(self, hidden_size: int, out_channels: int, dtype=None, device=None, operations=None):
super().__init__()
self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations)
self.norm = operations.RMSNorm(hidden_size, dtype=dtype, device=device)
self.conv = operations.Conv2d(
in_channels=hidden_size,
out_channels=out_channels,

View File

@@ -13,6 +13,7 @@ from torchvision import transforms
import comfy.patcher_extension
from comfy.ldm.modules.attention import optimized_attention
import comfy.ldm.common_dit
def apply_rotary_pos_emb(
t: torch.Tensor,
@@ -334,7 +335,7 @@ class FinalLayer(nn.Module):
device=None, dtype=None, operations=None
):
super().__init__()
self.layer_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.layer_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = operations.Linear(
hidden_size, spatial_patch_size * spatial_patch_size * temporal_patch_size * out_channels, bias=False, device=device, dtype=dtype
)
@@ -462,6 +463,8 @@ class Block(nn.Module):
extra_per_block_pos_emb: Optional[torch.Tensor] = None,
transformer_options: Optional[dict] = {},
) -> torch.Tensor:
residual_dtype = x_B_T_H_W_D.dtype
compute_dtype = emb_B_T_D.dtype
if extra_per_block_pos_emb is not None:
x_B_T_H_W_D = x_B_T_H_W_D + extra_per_block_pos_emb
@@ -511,7 +514,7 @@ class Block(nn.Module):
result_B_T_H_W_D = rearrange(
self.self_attn(
# normalized_x_B_T_HW_D,
rearrange(normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"),
rearrange(normalized_x_B_T_H_W_D.to(compute_dtype), "b t h w d -> b (t h w) d"),
None,
rope_emb=rope_emb_L_1_1_D,
transformer_options=transformer_options,
@@ -521,7 +524,7 @@ class Block(nn.Module):
h=H,
w=W,
)
x_B_T_H_W_D = x_B_T_H_W_D + gate_self_attn_B_T_1_1_D * result_B_T_H_W_D
x_B_T_H_W_D = x_B_T_H_W_D + gate_self_attn_B_T_1_1_D.to(residual_dtype) * result_B_T_H_W_D.to(residual_dtype)
def _x_fn(
_x_B_T_H_W_D: torch.Tensor,
@@ -535,7 +538,7 @@ class Block(nn.Module):
)
_result_B_T_H_W_D = rearrange(
self.cross_attn(
rearrange(_normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"),
rearrange(_normalized_x_B_T_H_W_D.to(compute_dtype), "b t h w d -> b (t h w) d"),
crossattn_emb,
rope_emb=rope_emb_L_1_1_D,
transformer_options=transformer_options,
@@ -554,7 +557,7 @@ class Block(nn.Module):
shift_cross_attn_B_T_1_1_D,
transformer_options=transformer_options,
)
x_B_T_H_W_D = result_B_T_H_W_D * gate_cross_attn_B_T_1_1_D + x_B_T_H_W_D
x_B_T_H_W_D = result_B_T_H_W_D.to(residual_dtype) * gate_cross_attn_B_T_1_1_D.to(residual_dtype) + x_B_T_H_W_D
normalized_x_B_T_H_W_D = _fn(
x_B_T_H_W_D,
@@ -562,8 +565,8 @@ class Block(nn.Module):
scale_mlp_B_T_1_1_D,
shift_mlp_B_T_1_1_D,
)
result_B_T_H_W_D = self.mlp(normalized_x_B_T_H_W_D)
x_B_T_H_W_D = x_B_T_H_W_D + gate_mlp_B_T_1_1_D * result_B_T_H_W_D
result_B_T_H_W_D = self.mlp(normalized_x_B_T_H_W_D.to(compute_dtype))
x_B_T_H_W_D = x_B_T_H_W_D + gate_mlp_B_T_1_1_D.to(residual_dtype) * result_B_T_H_W_D.to(residual_dtype)
return x_B_T_H_W_D
@@ -835,6 +838,8 @@ class MiniTrainDIT(nn.Module):
padding_mask: Optional[torch.Tensor] = None,
**kwargs,
):
orig_shape = list(x.shape)
x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_temporal, self.patch_spatial, self.patch_spatial))
x_B_C_T_H_W = x
timesteps_B_T = timesteps
crossattn_emb = context
@@ -873,6 +878,14 @@ class MiniTrainDIT(nn.Module):
"extra_per_block_pos_emb": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D,
"transformer_options": kwargs.get("transformer_options", {}),
}
# The residual stream for this model has large values. To make fp16 compute_dtype work, we keep the residual stream
# in fp32, but run attention and MLP modules in fp16.
# An alternate method that clamps fp16 values "works" in the sense that it makes coherent images, but there is noticeable
# quality degradation and visual artifacts.
if x_B_T_H_W_D.dtype == torch.float16:
x_B_T_H_W_D = x_B_T_H_W_D.float()
for block in self.blocks:
x_B_T_H_W_D = block(
x_B_T_H_W_D,
@@ -881,6 +894,6 @@ class MiniTrainDIT(nn.Module):
**block_kwargs,
)
x_B_T_H_W_O = self.final_layer(x_B_T_H_W_D, t_embedding_B_T_D, adaln_lora_B_T_3D=adaln_lora_B_T_3D)
x_B_C_Tt_Hp_Wp = self.unpatchify(x_B_T_H_W_O)
x_B_T_H_W_O = self.final_layer(x_B_T_H_W_D.to(crossattn_emb.dtype), t_embedding_B_T_D, adaln_lora_B_T_3D=adaln_lora_B_T_3D)
x_B_C_Tt_Hp_Wp = self.unpatchify(x_B_T_H_W_O)[:, :, :orig_shape[-3], :orig_shape[-2], :orig_shape[-1]]
return x_B_C_Tt_Hp_Wp

View File

@@ -5,9 +5,9 @@ import torch
from torch import Tensor, nn
from .math import attention, rope
import comfy.ops
import comfy.ldm.common_dit
# Fix import for some custom nodes, TODO: delete eventually.
RMSNorm = None
class EmbedND(nn.Module):
def __init__(self, dim: int, theta: int, axes_dim: list):
@@ -87,20 +87,12 @@ def build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=False, yak_mlp=False, dt
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
)
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, dtype=None, device=None, operations=None):
super().__init__()
self.scale = nn.Parameter(torch.empty((dim), dtype=dtype, device=device))
def forward(self, x: Tensor):
return comfy.ldm.common_dit.rms_norm(x, self.scale, 1e-6)
class QKNorm(torch.nn.Module):
def __init__(self, dim: int, dtype=None, device=None, operations=None):
super().__init__()
self.query_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations)
self.key_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations)
self.query_norm = operations.RMSNorm(dim, dtype=dtype, device=device)
self.key_norm = operations.RMSNorm(dim, dtype=dtype, device=device)
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple:
q = self.query_norm(q)
@@ -169,7 +161,7 @@ class SiLUActivation(nn.Module):
class DoubleStreamBlock(nn.Module):
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, mlp_silu_act=False, proj_bias=True, yak_mlp=False, dtype=None, device=None, operations=None):
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, modulation=True, mlp_silu_act=False, proj_bias=True, yak_mlp=False, dtype=None, device=None, operations=None):
super().__init__()
mlp_hidden_dim = int(hidden_size * mlp_ratio)
@@ -197,8 +189,6 @@ class DoubleStreamBlock(nn.Module):
self.txt_mlp = build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=mlp_silu_act, yak_mlp=yak_mlp, dtype=dtype, device=device, operations=operations)
self.flipped_img_txt = flipped_img_txt
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None, transformer_options={}):
if self.modulation:
img_mod1, img_mod2 = self.img_mod(vec)
@@ -206,6 +196,9 @@ class DoubleStreamBlock(nn.Module):
else:
(img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec
transformer_patches = transformer_options.get("patches", {})
extra_options = transformer_options.copy()
# prepare image for attention
img_modulated = self.img_norm1(img)
img_modulated = apply_mod(img_modulated, (1 + img_mod1.scale), img_mod1.shift, modulation_dims_img)
@@ -224,32 +217,23 @@ class DoubleStreamBlock(nn.Module):
del txt_qkv
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
if self.flipped_img_txt:
q = torch.cat((img_q, txt_q), dim=2)
del img_q, txt_q
k = torch.cat((img_k, txt_k), dim=2)
del img_k, txt_k
v = torch.cat((img_v, txt_v), dim=2)
del img_v, txt_v
# run actual attention
attn = attention(q, k, v,
pe=pe, mask=attn_mask, transformer_options=transformer_options)
del q, k, v
q = torch.cat((txt_q, img_q), dim=2)
del txt_q, img_q
k = torch.cat((txt_k, img_k), dim=2)
del txt_k, img_k
v = torch.cat((txt_v, img_v), dim=2)
del txt_v, img_v
# run actual attention
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
del q, k, v
img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1]:]
else:
q = torch.cat((txt_q, img_q), dim=2)
del txt_q, img_q
k = torch.cat((txt_k, img_k), dim=2)
del txt_k, img_k
v = torch.cat((txt_v, img_v), dim=2)
del txt_v, img_v
# run actual attention
attn = attention(q, k, v,
pe=pe, mask=attn_mask, transformer_options=transformer_options)
del q, k, v
if "attn1_output_patch" in transformer_patches:
extra_options["img_slice"] = [txt.shape[1], attn.shape[1]]
patch = transformer_patches["attn1_output_patch"]
for p in patch:
attn = p(attn, extra_options)
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
# calculate the img bloks
img += apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img)
@@ -328,6 +312,9 @@ class SingleStreamBlock(nn.Module):
else:
mod = vec
transformer_patches = transformer_options.get("patches", {})
extra_options = transformer_options.copy()
qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim_first], dim=-1)
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
@@ -337,6 +324,12 @@ class SingleStreamBlock(nn.Module):
# compute attention
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
del q, k, v
if "attn1_output_patch" in transformer_patches:
patch = transformer_patches["attn1_output_patch"]
for p in patch:
attn = p(attn, extra_options)
# compute activation in mlp stream, cat again and run second linear layer
if self.yak_mlp:
mlp = self.mlp_act(mlp[..., self.mlp_hidden_dim_first // 2:]) * mlp[..., :self.mlp_hidden_dim_first // 2]

View File

@@ -29,19 +29,34 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
return out.to(dtype=torch.float32, device=pos.device)
def _apply_rope1(x: Tensor, freqs_cis: Tensor):
x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
x_out = freqs_cis[..., 0] * x_[..., 0]
x_out.addcmul_(freqs_cis[..., 1], x_[..., 1])
return x_out.reshape(*x.shape).type_as(x)
def _apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)
try:
import comfy.quant_ops
apply_rope = comfy.quant_ops.ck.apply_rope
apply_rope1 = comfy.quant_ops.ck.apply_rope1
q_apply_rope = comfy.quant_ops.ck.apply_rope
q_apply_rope1 = comfy.quant_ops.ck.apply_rope1
def apply_rope(xq, xk, freqs_cis):
if comfy.model_management.in_training:
return _apply_rope(xq, xk, freqs_cis)
else:
return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)
def apply_rope1(x, freqs_cis):
if comfy.model_management.in_training:
return _apply_rope1(x, freqs_cis)
else:
return q_apply_rope1(x, freqs_cis)
except:
logging.warning("No comfy kitchen, using old apply_rope functions.")
def apply_rope1(x: Tensor, freqs_cis: Tensor):
x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
x_out = freqs_cis[..., 0] * x_[..., 0]
x_out.addcmul_(freqs_cis[..., 1], x_[..., 1])
return x_out.reshape(*x.shape).type_as(x)
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)
apply_rope = _apply_rope
apply_rope1 = _apply_rope1

View File

@@ -16,7 +16,6 @@ from .layers import (
SingleStreamBlock,
timestep_embedding,
Modulation,
RMSNorm
)
@dataclass
@@ -81,7 +80,7 @@ class Flux(nn.Module):
self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device)
if params.txt_norm:
self.txt_norm = RMSNorm(params.context_in_dim, dtype=dtype, device=device, operations=operations)
self.txt_norm = operations.RMSNorm(params.context_in_dim, dtype=dtype, device=device)
else:
self.txt_norm = None
@@ -143,6 +142,7 @@ class Flux(nn.Module):
attn_mask: Tensor = None,
) -> Tensor:
transformer_options = transformer_options.copy()
patches = transformer_options.get("patches", {})
patches_replace = transformer_options.get("patches_replace", {})
if img.ndim != 3 or txt.ndim != 3:
@@ -232,6 +232,7 @@ class Flux(nn.Module):
transformer_options["total_blocks"] = len(self.single_blocks)
transformer_options["block_type"] = "single"
transformer_options["img_slice"] = [txt.shape[1], img.shape[1]]
for i, block in enumerate(self.single_blocks):
transformer_options["block_index"] = i
if ("single_block", i) in blocks_replace:

View File

@@ -241,7 +241,6 @@ class HunyuanVideo(nn.Module):
self.num_heads,
mlp_ratio=params.mlp_ratio,
qkv_bias=params.qkv_bias,
flipped_img_txt=True,
dtype=dtype, device=device, operations=operations
)
for _ in range(params.depth)
@@ -305,6 +304,7 @@ class HunyuanVideo(nn.Module):
control=None,
transformer_options={},
) -> Tensor:
transformer_options = transformer_options.copy()
patches_replace = transformer_options.get("patches_replace", {})
initial_shape = list(img.shape)
@@ -378,14 +378,14 @@ class HunyuanVideo(nn.Module):
extra_txt_ids = torch.zeros((txt_ids.shape[0], txt_vision_states.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype)
txt_ids = torch.cat((txt_ids, extra_txt_ids), dim=1)
ids = torch.cat((img_ids, txt_ids), dim=1)
ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids)
img_len = img.shape[1]
if txt_mask is not None:
attn_mask_len = img_len + txt.shape[1]
attn_mask = torch.zeros((1, 1, attn_mask_len), dtype=img.dtype, device=img.device)
attn_mask[:, 0, img_len:] = txt_mask
attn_mask[:, 0, :txt.shape[1]] = txt_mask
else:
attn_mask = None
@@ -413,10 +413,11 @@ class HunyuanVideo(nn.Module):
if add is not None:
img += add
img = torch.cat((img, txt), 1)
img = torch.cat((txt, img), 1)
transformer_options["total_blocks"] = len(self.single_blocks)
transformer_options["block_type"] = "single"
transformer_options["img_slice"] = [txt.shape[1], img.shape[1]]
for i, block in enumerate(self.single_blocks):
transformer_options["block_index"] = i
if ("single_block", i) in blocks_replace:
@@ -435,9 +436,9 @@ class HunyuanVideo(nn.Module):
if i < len(control_o):
add = control_o[i]
if add is not None:
img[:, : img_len] += add
img[:, txt.shape[1]: img_len + txt.shape[1]] += add
img = img[:, : img_len]
img = img[:, txt.shape[1]: img_len + txt.shape[1]]
if ref_latent is not None:
img = img[:, ref_latent.shape[1]:]

View File

@@ -109,10 +109,10 @@ class HunyuanVideo15SRModel():
self.model_class = UPSAMPLERS.get(model_type)
self.model = self.model_class(**config).eval()
self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
def load_sd(self, sd):
return self.model.load_state_dict(sd, strict=True)
return self.model.load_state_dict(sd, strict=True, assign=self.patcher.is_dynamic())
def get_sd(self):
return self.model.state_dict()

View File

@@ -9,6 +9,7 @@ from comfy.ldm.lightricks.model import (
LTXVModel,
)
from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector
import comfy.ldm.common_dit
class CompressedTimestep:
@@ -18,12 +19,12 @@ class CompressedTimestep:
def __init__(self, tensor: torch.Tensor, patches_per_frame: int):
"""
tensor: [batch_size, num_tokens, feature_dim] tensor where num_tokens = num_frames * patches_per_frame
patches_per_frame: Number of spatial patches per frame (height * width in latent space)
patches_per_frame: Number of spatial patches per frame (height * width in latent space), or None to disable compression
"""
self.batch_size, num_tokens, self.feature_dim = tensor.shape
# Check if compression is valid (num_tokens must be divisible by patches_per_frame)
if num_tokens % patches_per_frame == 0 and num_tokens >= patches_per_frame:
if patches_per_frame is not None and num_tokens % patches_per_frame == 0 and num_tokens >= patches_per_frame:
self.patches_per_frame = patches_per_frame
self.num_frames = num_tokens // patches_per_frame
@@ -215,22 +216,9 @@ class BasicAVTransformerBlock(nn.Module):
return (*scale_shift_ada_values, *gate_ada_values)
def forward(
self,
x: Tuple[torch.Tensor, torch.Tensor],
v_context=None,
a_context=None,
attention_mask=None,
v_timestep=None,
a_timestep=None,
v_pe=None,
a_pe=None,
v_cross_pe=None,
a_cross_pe=None,
v_cross_scale_shift_timestep=None,
a_cross_scale_shift_timestep=None,
v_cross_gate_timestep=None,
a_cross_gate_timestep=None,
transformer_options=None,
self, x: Tuple[torch.Tensor, torch.Tensor], v_context=None, a_context=None, attention_mask=None, v_timestep=None, a_timestep=None,
v_pe=None, a_pe=None, v_cross_pe=None, a_cross_pe=None, v_cross_scale_shift_timestep=None, a_cross_scale_shift_timestep=None,
v_cross_gate_timestep=None, a_cross_gate_timestep=None, transformer_options=None, self_attention_mask=None,
) -> Tuple[torch.Tensor, torch.Tensor]:
run_vx = transformer_options.get("run_vx", True)
run_ax = transformer_options.get("run_ax", True)
@@ -240,144 +228,102 @@ class BasicAVTransformerBlock(nn.Module):
run_a2v = run_vx and transformer_options.get("a2v_cross_attn", True) and ax.numel() > 0
run_v2a = run_ax and transformer_options.get("v2a_cross_attn", True)
# video
if run_vx:
vshift_msa, vscale_msa, vgate_msa = (
self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(0, 3))
)
# video self-attention
vshift_msa, vscale_msa = (self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(0, 2)))
norm_vx = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_msa) + vshift_msa
vx += self.attn1(norm_vx, pe=v_pe, transformer_options=transformer_options) * vgate_msa
vx += self.attn2(
comfy.ldm.common_dit.rms_norm(vx),
context=v_context,
mask=attention_mask,
transformer_options=transformer_options,
)
del vshift_msa, vscale_msa, vgate_msa
del vshift_msa, vscale_msa
attn1_out = self.attn1(norm_vx, pe=v_pe, mask=self_attention_mask, transformer_options=transformer_options)
del norm_vx
# video cross-attention
vgate_msa = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(2, 3))[0]
vx.addcmul_(attn1_out, vgate_msa)
del vgate_msa, attn1_out
vx.add_(self.attn2(comfy.ldm.common_dit.rms_norm(vx), context=v_context, mask=attention_mask, transformer_options=transformer_options))
# audio
if run_ax:
ashift_msa, ascale_msa, agate_msa = (
self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(0, 3))
)
# audio self-attention
ashift_msa, ascale_msa = (self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(0, 2)))
norm_ax = comfy.ldm.common_dit.rms_norm(ax) * (1 + ascale_msa) + ashift_msa
ax += (
self.audio_attn1(norm_ax, pe=a_pe, transformer_options=transformer_options)
* agate_msa
)
ax += self.audio_attn2(
comfy.ldm.common_dit.rms_norm(ax),
context=a_context,
mask=attention_mask,
transformer_options=transformer_options,
)
del ashift_msa, ascale_msa
attn1_out = self.audio_attn1(norm_ax, pe=a_pe, transformer_options=transformer_options)
del norm_ax
# audio cross-attention
agate_msa = self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(2, 3))[0]
ax.addcmul_(attn1_out, agate_msa)
del agate_msa, attn1_out
ax.add_(self.audio_attn2(comfy.ldm.common_dit.rms_norm(ax), context=a_context, mask=attention_mask, transformer_options=transformer_options))
del ashift_msa, ascale_msa, agate_msa
# Audio - Video cross attention.
# video - audio cross attention.
if run_a2v or run_v2a:
# norm3
vx_norm3 = comfy.ldm.common_dit.rms_norm(vx)
ax_norm3 = comfy.ldm.common_dit.rms_norm(ax)
(
scale_ca_audio_hidden_states_a2v,
shift_ca_audio_hidden_states_a2v,
scale_ca_audio_hidden_states_v2a,
shift_ca_audio_hidden_states_v2a,
gate_out_v2a,
) = self.get_av_ca_ada_values(
self.scale_shift_table_a2v_ca_audio,
ax.shape[0],
a_cross_scale_shift_timestep,
a_cross_gate_timestep,
)
(
scale_ca_video_hidden_states_a2v,
shift_ca_video_hidden_states_a2v,
scale_ca_video_hidden_states_v2a,
shift_ca_video_hidden_states_v2a,
gate_out_a2v,
) = self.get_av_ca_ada_values(
self.scale_shift_table_a2v_ca_video,
vx.shape[0],
v_cross_scale_shift_timestep,
v_cross_gate_timestep,
)
# audio to video cross attention
if run_a2v:
vx_scaled = (
vx_norm3 * (1 + scale_ca_video_hidden_states_a2v)
+ shift_ca_video_hidden_states_a2v
)
ax_scaled = (
ax_norm3 * (1 + scale_ca_audio_hidden_states_a2v)
+ shift_ca_audio_hidden_states_a2v
)
vx += (
self.audio_to_video_attn(
vx_scaled,
context=ax_scaled,
pe=v_cross_pe,
k_pe=a_cross_pe,
transformer_options=transformer_options,
)
* gate_out_a2v
)
scale_ca_audio_hidden_states_a2v, shift_ca_audio_hidden_states_a2v = self.get_ada_values(
self.scale_shift_table_a2v_ca_audio[:4, :], ax.shape[0], a_cross_scale_shift_timestep)[:2]
scale_ca_video_hidden_states_a2v_v, shift_ca_video_hidden_states_a2v_v = self.get_ada_values(
self.scale_shift_table_a2v_ca_video[:4, :], vx.shape[0], v_cross_scale_shift_timestep)[:2]
del gate_out_a2v
del scale_ca_video_hidden_states_a2v,\
shift_ca_video_hidden_states_a2v,\
scale_ca_audio_hidden_states_a2v,\
shift_ca_audio_hidden_states_a2v,\
vx_scaled = vx_norm3 * (1 + scale_ca_video_hidden_states_a2v_v) + shift_ca_video_hidden_states_a2v_v
ax_scaled = ax_norm3 * (1 + scale_ca_audio_hidden_states_a2v) + shift_ca_audio_hidden_states_a2v
del scale_ca_video_hidden_states_a2v_v, shift_ca_video_hidden_states_a2v_v, scale_ca_audio_hidden_states_a2v, shift_ca_audio_hidden_states_a2v
a2v_out = self.audio_to_video_attn(vx_scaled, context=ax_scaled, pe=v_cross_pe, k_pe=a_cross_pe, transformer_options=transformer_options)
del vx_scaled, ax_scaled
gate_out_a2v = self.get_ada_values(self.scale_shift_table_a2v_ca_video[4:, :], vx.shape[0], v_cross_gate_timestep)[0]
vx.addcmul_(a2v_out, gate_out_a2v)
del gate_out_a2v, a2v_out
# video to audio cross attention
if run_v2a:
ax_scaled = (
ax_norm3 * (1 + scale_ca_audio_hidden_states_v2a)
+ shift_ca_audio_hidden_states_v2a
)
vx_scaled = (
vx_norm3 * (1 + scale_ca_video_hidden_states_v2a)
+ shift_ca_video_hidden_states_v2a
)
ax += (
self.video_to_audio_attn(
ax_scaled,
context=vx_scaled,
pe=a_cross_pe,
k_pe=v_cross_pe,
transformer_options=transformer_options,
)
* gate_out_v2a
)
scale_ca_audio_hidden_states_v2a, shift_ca_audio_hidden_states_v2a = self.get_ada_values(
self.scale_shift_table_a2v_ca_audio[:4, :], ax.shape[0], a_cross_scale_shift_timestep)[2:4]
scale_ca_video_hidden_states_v2a, shift_ca_video_hidden_states_v2a = self.get_ada_values(
self.scale_shift_table_a2v_ca_video[:4, :], vx.shape[0], v_cross_scale_shift_timestep)[2:4]
del gate_out_v2a
del scale_ca_video_hidden_states_v2a,\
shift_ca_video_hidden_states_v2a,\
scale_ca_audio_hidden_states_v2a,\
shift_ca_audio_hidden_states_v2a
ax_scaled = ax_norm3 * (1 + scale_ca_audio_hidden_states_v2a) + shift_ca_audio_hidden_states_v2a
vx_scaled = vx_norm3 * (1 + scale_ca_video_hidden_states_v2a) + shift_ca_video_hidden_states_v2a
del scale_ca_video_hidden_states_v2a, shift_ca_video_hidden_states_v2a, scale_ca_audio_hidden_states_v2a, shift_ca_audio_hidden_states_v2a
v2a_out = self.video_to_audio_attn(ax_scaled, context=vx_scaled, pe=a_cross_pe, k_pe=v_cross_pe, transformer_options=transformer_options)
del ax_scaled, vx_scaled
gate_out_v2a = self.get_ada_values(self.scale_shift_table_a2v_ca_audio[4:, :], ax.shape[0], a_cross_gate_timestep)[0]
ax.addcmul_(v2a_out, gate_out_v2a)
del gate_out_v2a, v2a_out
del vx_norm3, ax_norm3
# video feedforward
if run_vx:
vshift_mlp, vscale_mlp, vgate_mlp = (
self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(3, None))
)
vshift_mlp, vscale_mlp = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(3, 5))
vx_scaled = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_mlp) + vshift_mlp
vx += self.ff(vx_scaled) * vgate_mlp
del vshift_mlp, vscale_mlp, vgate_mlp
del vshift_mlp, vscale_mlp
ff_out = self.ff(vx_scaled)
del vx_scaled
vgate_mlp = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(5, 6))[0]
vx.addcmul_(ff_out, vgate_mlp)
del vgate_mlp, ff_out
# audio feedforward
if run_ax:
ashift_mlp, ascale_mlp, agate_mlp = (
self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(3, None))
)
ashift_mlp, ascale_mlp = self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(3, 5))
ax_scaled = comfy.ldm.common_dit.rms_norm(ax) * (1 + ascale_mlp) + ashift_mlp
ax += self.audio_ff(ax_scaled) * agate_mlp
del ashift_mlp, ascale_mlp
del ashift_mlp, ascale_mlp, agate_mlp
ff_out = self.audio_ff(ax_scaled)
del ax_scaled
agate_mlp = self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(5, 6))[0]
ax.addcmul_(ff_out, agate_mlp)
del agate_mlp, ff_out
return vx, ax
@@ -505,6 +451,29 @@ class LTXAVModel(LTXVModel):
operations=self.operations,
)
self.audio_embeddings_connector = Embeddings1DConnector(
split_rope=True,
double_precision_rope=True,
dtype=dtype,
device=device,
operations=self.operations,
)
self.video_embeddings_connector = Embeddings1DConnector(
split_rope=True,
double_precision_rope=True,
dtype=dtype,
device=device,
operations=self.operations,
)
def preprocess_text_embeds(self, context):
if context.shape[-1] == self.caption_channels * 2:
return context
out_vid = self.video_embeddings_connector(context)[0]
out_audio = self.audio_embeddings_connector(context)[0]
return torch.concat((out_vid, out_audio), dim=-1)
def _init_transformer_blocks(self, device, dtype, **kwargs):
"""Initialize transformer blocks for LTXAV."""
self.transformer_blocks = nn.ModuleList(
@@ -589,9 +558,20 @@ class LTXAVModel(LTXVModel):
audio_length = kwargs.get("audio_length", 0)
# Separate audio and video latents
vx, ax = self.separate_audio_and_video_latents(x, audio_length)
has_spatial_mask = False
if denoise_mask is not None:
# check if any frame has spatial variation (inpainting)
for frame_idx in range(denoise_mask.shape[2]):
frame_mask = denoise_mask[0, 0, frame_idx]
if frame_mask.numel() > 0 and frame_mask.min() != frame_mask.max():
has_spatial_mask = True
break
[vx, v_pixel_coords, additional_args] = super()._process_input(
vx, keyframe_idxs, denoise_mask, **kwargs
)
additional_args["has_spatial_mask"] = has_spatial_mask
ax, a_latent_coords = self.a_patchifier.patchify(ax)
ax = self.audio_patchify_proj(ax)
@@ -618,8 +598,9 @@ class LTXAVModel(LTXVModel):
# Calculate patches_per_frame from orig_shape: [batch, channels, frames, height, width]
# Video tokens are arranged as (frames * height * width), so patches_per_frame = height * width
orig_shape = kwargs.get("orig_shape")
has_spatial_mask = kwargs.get("has_spatial_mask", None)
v_patches_per_frame = None
if orig_shape is not None and len(orig_shape) == 5:
if not has_spatial_mask and orig_shape is not None and len(orig_shape) == 5:
# orig_shape[3] = height, orig_shape[4] = width (in latent space)
v_patches_per_frame = orig_shape[3] * orig_shape[4]
@@ -662,10 +643,11 @@ class LTXAVModel(LTXVModel):
)
# Compress cross-attention timesteps (only video side, audio is too small to benefit)
# v_patches_per_frame is None for spatial masks, set for temporal masks or no mask
cross_av_timestep_ss = [
av_ca_audio_scale_shift_timestep.view(batch_size, -1, av_ca_audio_scale_shift_timestep.shape[-1]),
CompressedTimestep(av_ca_video_scale_shift_timestep.view(batch_size, -1, av_ca_video_scale_shift_timestep.shape[-1]), v_patches_per_frame), # video - compressed
CompressedTimestep(av_ca_a2v_gate_noise_timestep.view(batch_size, -1, av_ca_a2v_gate_noise_timestep.shape[-1]), v_patches_per_frame), # video - compressed
CompressedTimestep(av_ca_video_scale_shift_timestep.view(batch_size, -1, av_ca_video_scale_shift_timestep.shape[-1]), v_patches_per_frame), # video - compressed if possible
CompressedTimestep(av_ca_a2v_gate_noise_timestep.view(batch_size, -1, av_ca_a2v_gate_noise_timestep.shape[-1]), v_patches_per_frame), # video - compressed if possible
av_ca_v2a_gate_noise_timestep.view(batch_size, -1, av_ca_v2a_gate_noise_timestep.shape[-1]),
]
@@ -744,7 +726,7 @@ class LTXAVModel(LTXVModel):
return [(v_pe, av_cross_video_freq_cis), (a_pe, av_cross_audio_freq_cis)]
def _process_transformer_blocks(
self, x, context, attention_mask, timestep, pe, transformer_options={}, **kwargs
self, x, context, attention_mask, timestep, pe, transformer_options={}, self_attention_mask=None, **kwargs
):
vx = x[0]
ax = x[1]
@@ -788,6 +770,7 @@ class LTXAVModel(LTXVModel):
v_cross_gate_timestep=args["v_cross_gate_timestep"],
a_cross_gate_timestep=args["a_cross_gate_timestep"],
transformer_options=args["transformer_options"],
self_attention_mask=args.get("self_attention_mask"),
)
return out
@@ -808,6 +791,7 @@ class LTXAVModel(LTXVModel):
"v_cross_gate_timestep": av_ca_a2v_gate_noise_timestep,
"a_cross_gate_timestep": av_ca_v2a_gate_noise_timestep,
"transformer_options": transformer_options,
"self_attention_mask": self_attention_mask,
},
{"original_block": block_wrap},
)
@@ -829,6 +813,7 @@ class LTXAVModel(LTXVModel):
v_cross_gate_timestep=av_ca_a2v_gate_noise_timestep,
a_cross_gate_timestep=av_ca_v2a_gate_noise_timestep,
transformer_options=transformer_options,
self_attention_mask=self_attention_mask,
)
return [vx, ax]

View File

@@ -157,11 +157,9 @@ class Embeddings1DConnector(nn.Module):
self.num_learnable_registers = num_learnable_registers
if self.num_learnable_registers:
self.learnable_registers = nn.Parameter(
torch.rand(
torch.empty(
self.num_learnable_registers, inner_dim, dtype=dtype, device=device
)
* 2.0
- 1.0
)
def get_fractional_positions(self, indices_grid):
@@ -234,7 +232,7 @@ class Embeddings1DConnector(nn.Module):
return indices
def precompute_freqs_cis(self, indices_grid, spacing="exp"):
def precompute_freqs_cis(self, indices_grid, spacing="exp", out_dtype=None):
dim = self.inner_dim
n_elem = 2 # 2 because of cos and sin
freqs = self.precompute_freqs(indices_grid, spacing)
@@ -247,7 +245,7 @@ class Embeddings1DConnector(nn.Module):
)
else:
cos_freq, sin_freq = interleaved_freqs_cis(freqs, dim % n_elem)
return cos_freq.to(self.dtype), sin_freq.to(self.dtype), self.split_rope
return cos_freq.to(dtype=out_dtype), sin_freq.to(dtype=out_dtype), self.split_rope
def forward(
self,
@@ -288,7 +286,7 @@ class Embeddings1DConnector(nn.Module):
hidden_states.shape[1], dtype=torch.float32, device=hidden_states.device
)
indices_grid = indices_grid[None, None, :]
freqs_cis = self.precompute_freqs_cis(indices_grid)
freqs_cis = self.precompute_freqs_cis(indices_grid, out_dtype=hidden_states.dtype)
# 2. Blocks
for block_idx, block in enumerate(self.transformer_1d_blocks):

View File

@@ -1,6 +1,7 @@
from abc import ABC, abstractmethod
from enum import Enum
import functools
import logging
import math
from typing import Dict, Optional, Tuple
@@ -14,6 +15,8 @@ import comfy.ldm.common_dit
from .symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords
logger = logging.getLogger(__name__)
def _log_base(x, base):
return np.log(x) / np.log(base)
@@ -415,12 +418,12 @@ class BasicTransformerBlock(nn.Module):
self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype))
def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}):
def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}, self_attention_mask=None):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2)
attn1_input = comfy.ldm.common_dit.rms_norm(x)
attn1_input = torch.addcmul(attn1_input, attn1_input, scale_msa).add_(shift_msa)
attn1_input = self.attn1(attn1_input, pe=pe, transformer_options=transformer_options)
attn1_input = self.attn1(attn1_input, pe=pe, mask=self_attention_mask, transformer_options=transformer_options)
x.addcmul_(attn1_input, gate_msa)
del attn1_input
@@ -638,8 +641,16 @@ class LTXBaseModel(torch.nn.Module, ABC):
"""Process input data. Must be implemented by subclasses."""
pass
def _build_guide_self_attention_mask(self, x, transformer_options, merged_args):
"""Build self-attention mask for per-guide attention attenuation.
Base implementation returns None (no attenuation). Subclasses that
support guide-based attention control should override this.
"""
return None
@abstractmethod
def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, **kwargs):
def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, self_attention_mask=None, **kwargs):
"""Process transformer blocks. Must be implemented by subclasses."""
pass
@@ -788,9 +799,17 @@ class LTXBaseModel(torch.nn.Module, ABC):
attention_mask = self._prepare_attention_mask(attention_mask, input_dtype)
pe = self._prepare_positional_embeddings(pixel_coords, frame_rate, input_dtype)
# Build self-attention mask for per-guide attenuation
self_attention_mask = self._build_guide_self_attention_mask(
x, transformer_options, merged_args
)
# Process transformer blocks
x = self._process_transformer_blocks(
x, context, attention_mask, timestep, pe, transformer_options=transformer_options, **merged_args
x, context, attention_mask, timestep, pe,
transformer_options=transformer_options,
self_attention_mask=self_attention_mask,
**merged_args,
)
# Process output
@@ -890,13 +909,243 @@ class LTXVModel(LTXBaseModel):
pixel_coords = pixel_coords[:, :, grid_mask, ...]
kf_grid_mask = grid_mask[-keyframe_idxs.shape[2]:]
# Compute per-guide surviving token counts from guide_attention_entries.
# Each entry tracks one guide reference; they are appended in order and
# their pre_filter_counts partition the kf_grid_mask.
guide_entries = kwargs.get("guide_attention_entries", None)
if guide_entries:
total_pfc = sum(e["pre_filter_count"] for e in guide_entries)
if total_pfc != len(kf_grid_mask):
raise ValueError(
f"guide pre_filter_counts ({total_pfc}) != "
f"keyframe grid mask length ({len(kf_grid_mask)})"
)
resolved_entries = []
offset = 0
for entry in guide_entries:
pfc = entry["pre_filter_count"]
entry_mask = kf_grid_mask[offset:offset + pfc]
surviving = int(entry_mask.sum().item())
resolved_entries.append({
**entry,
"surviving_count": surviving,
})
offset += pfc
additional_args["resolved_guide_entries"] = resolved_entries
keyframe_idxs = keyframe_idxs[..., kf_grid_mask, :]
pixel_coords[:, :, -keyframe_idxs.shape[2]:, :] = keyframe_idxs
# Total surviving guide tokens (all guides)
additional_args["num_guide_tokens"] = keyframe_idxs.shape[2]
x = self.patchify_proj(x)
return x, pixel_coords, additional_args
def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, transformer_options={}, **kwargs):
def _build_guide_self_attention_mask(self, x, transformer_options, merged_args):
"""Build self-attention mask for per-guide attention attenuation.
Reads resolved_guide_entries from merged_args (computed in _process_input)
to build a log-space additive bias mask that attenuates noisy ↔ guide
attention for each guide reference independently.
Returns None if no attenuation is needed (all strengths == 1.0 and no
spatial masks, or no guide tokens).
"""
if isinstance(x, list):
# AV model: x = [vx, ax]; use vx for token count and device
total_tokens = x[0].shape[1]
device = x[0].device
dtype = x[0].dtype
else:
total_tokens = x.shape[1]
device = x.device
dtype = x.dtype
num_guide_tokens = merged_args.get("num_guide_tokens", 0)
if num_guide_tokens == 0:
return None
resolved_entries = merged_args.get("resolved_guide_entries", None)
if not resolved_entries:
return None
# Check if any attenuation is actually needed
needs_attenuation = any(
e["strength"] < 1.0 or e.get("pixel_mask") is not None
for e in resolved_entries
)
if not needs_attenuation:
return None
# Build per-guide-token weights for all tracked guide tokens.
# Guides are appended in order at the end of the sequence.
guide_start = total_tokens - num_guide_tokens
all_weights = []
total_tracked = 0
for entry in resolved_entries:
surviving = entry["surviving_count"]
if surviving == 0:
continue
strength = entry["strength"]
pixel_mask = entry.get("pixel_mask")
latent_shape = entry.get("latent_shape")
if pixel_mask is not None and latent_shape is not None:
f_lat, h_lat, w_lat = latent_shape
per_token = self._downsample_mask_to_latent(
pixel_mask.to(device=device, dtype=dtype),
f_lat, h_lat, w_lat,
)
# per_token shape: (B, f_lat*h_lat*w_lat).
# Collapse batch dim — the mask is assumed identical across the
# batch; validate and take the first element to get (1, tokens).
if per_token.shape[0] > 1:
ref = per_token[0]
for bi in range(1, per_token.shape[0]):
if not torch.equal(ref, per_token[bi]):
logger.warning(
"pixel_mask differs across batch elements; "
"using first element only."
)
break
per_token = per_token[:1]
# `surviving` is the post-grid_mask token count.
# Clamp to surviving to handle any mismatch safely.
n_weights = min(per_token.shape[1], surviving)
weights = per_token[:, :n_weights] * strength # (1, n_weights)
else:
weights = torch.full(
(1, surviving), strength, device=device, dtype=dtype
)
all_weights.append(weights)
total_tracked += weights.shape[1]
if not all_weights:
return None
# Concatenate per-token weights for all tracked guides
tracked_weights = torch.cat(all_weights, dim=1) # (1, total_tracked)
# Check if any weight is actually < 1.0 (otherwise no attenuation needed)
if (tracked_weights >= 1.0).all():
return None
# Build the mask: guide tokens are at the end of the sequence.
# Tracked guides come first (in order), untracked follow.
return self._build_self_attention_mask(
total_tokens, num_guide_tokens, total_tracked,
tracked_weights, guide_start, device, dtype,
)
@staticmethod
def _downsample_mask_to_latent(mask, f_lat, h_lat, w_lat):
"""Downsample a pixel-space mask to per-token latent weights.
Args:
mask: (B, 1, F_pix, H_pix, W_pix) pixel-space mask with values in [0, 1].
f_lat: Number of latent frames (pre-dilation original count).
h_lat: Latent height (pre-dilation original height).
w_lat: Latent width (pre-dilation original width).
Returns:
(B, F_lat * H_lat * W_lat) flattened per-token weights.
"""
b = mask.shape[0]
f_pix = mask.shape[2]
# Spatial downsampling: area interpolation per frame
spatial_down = torch.nn.functional.interpolate(
rearrange(mask, "b 1 f h w -> (b f) 1 h w"),
size=(h_lat, w_lat),
mode="area",
)
spatial_down = rearrange(spatial_down, "(b f) 1 h w -> b 1 f h w", b=b)
# Temporal downsampling: first pixel frame maps to first latent frame,
# remaining pixel frames are averaged in groups for causal temporal structure.
first_frame = spatial_down[:, :, :1, :, :]
if f_pix > 1 and f_lat > 1:
remaining_pix = f_pix - 1
remaining_lat = f_lat - 1
t = remaining_pix // remaining_lat
if t < 1:
# Fewer pixel frames than latent frames — upsample by repeating
# the available pixel frames via nearest interpolation.
rest_flat = rearrange(
spatial_down[:, :, 1:, :, :],
"b 1 f h w -> (b h w) 1 f",
)
rest_up = torch.nn.functional.interpolate(
rest_flat, size=remaining_lat, mode="nearest",
)
rest = rearrange(
rest_up, "(b h w) 1 f -> b 1 f h w",
b=b, h=h_lat, w=w_lat,
)
else:
# Trim trailing pixel frames that don't fill a complete group
usable = remaining_lat * t
rest = rearrange(
spatial_down[:, :, 1:1 + usable, :, :],
"b 1 (f t) h w -> b 1 f t h w",
t=t,
)
rest = rest.mean(dim=3)
latent_mask = torch.cat([first_frame, rest], dim=2)
elif f_lat > 1:
# Single pixel frame but multiple latent frames — repeat the
# single frame across all latent frames.
latent_mask = first_frame.expand(-1, -1, f_lat, -1, -1)
else:
latent_mask = first_frame
return rearrange(latent_mask, "b 1 f h w -> b (f h w)")
@staticmethod
def _build_self_attention_mask(total_tokens, num_guide_tokens, tracked_count,
tracked_weights, guide_start, device, dtype):
"""Build a log-space additive self-attention bias mask.
Attenuates attention between noisy tokens and tracked guide tokens.
Untracked guide tokens (at the end of the guide portion) keep full attention.
Args:
total_tokens: Total sequence length.
num_guide_tokens: Total guide tokens (all guides) at end of sequence.
tracked_count: Number of tracked guide tokens (first in the guide portion).
tracked_weights: (1, tracked_count) tensor, values in [0, 1].
guide_start: Index where guide tokens begin in the sequence.
device: Target device.
dtype: Target dtype.
Returns:
(1, 1, total_tokens, total_tokens) additive bias mask.
0.0 = full attention, negative = attenuated, finfo.min = effectively fully masked.
"""
finfo = torch.finfo(dtype)
mask = torch.zeros((1, 1, total_tokens, total_tokens), device=device, dtype=dtype)
tracked_end = guide_start + tracked_count
# Convert weights to log-space bias
w = tracked_weights.to(device=device, dtype=dtype) # (1, tracked_count)
log_w = torch.full_like(w, finfo.min)
positive_mask = w > 0
if positive_mask.any():
log_w[positive_mask] = torch.log(w[positive_mask].clamp(min=finfo.tiny))
# noisy → tracked guides: each noisy row gets the same per-guide weight
mask[:, :, :guide_start, guide_start:tracked_end] = log_w.view(1, 1, 1, -1)
# tracked guides → noisy: each guide row broadcasts its weight across noisy cols
mask[:, :, guide_start:tracked_end, :guide_start] = log_w.view(1, 1, -1, 1)
return mask
def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, transformer_options={}, self_attention_mask=None, **kwargs):
"""Process transformer blocks for LTXV."""
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
@@ -906,10 +1155,10 @@ class LTXVModel(LTXBaseModel):
def block_wrap(args):
out = {}
out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"])
out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"], self_attention_mask=args.get("self_attention_mask"))
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe, "transformer_options": transformer_options}, {"original_block": block_wrap})
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe, "transformer_options": transformer_options, "self_attention_mask": self_attention_mask}, {"original_block": block_wrap})
x = out["img"]
else:
x = block(
@@ -919,6 +1168,7 @@ class LTXVModel(LTXBaseModel):
timestep=timestep,
pe=pe,
transformer_options=transformer_options,
self_attention_mask=self_attention_mask,
)
return x

View File

@@ -103,20 +103,10 @@ class AudioPreprocessor:
return waveform
return torchaudio.functional.resample(waveform, source_rate, self.target_sample_rate)
@staticmethod
def normalize_amplitude(
waveform: torch.Tensor, max_amplitude: float = 0.5, eps: float = 1e-5
) -> torch.Tensor:
waveform = waveform - waveform.mean(dim=2, keepdim=True)
peak = torch.max(torch.abs(waveform)) + eps
scale = peak.clamp(max=max_amplitude) / peak
return waveform * scale
def waveform_to_mel(
self, waveform: torch.Tensor, waveform_sample_rate: int, device
) -> torch.Tensor:
waveform = self.resample(waveform, waveform_sample_rate)
waveform = self.normalize_amplitude(waveform)
mel_transform = torchaudio.transforms.MelSpectrogram(
sample_rate=self.target_sample_rate,
@@ -189,9 +179,12 @@ class AudioVAE(torch.nn.Module):
waveform = self.device_manager.move_to_load_device(waveform)
expected_channels = self.autoencoder.encoder.in_channels
if waveform.shape[1] != expected_channels:
raise ValueError(
f"Input audio must have {expected_channels} channels, got {waveform.shape[1]}"
)
if waveform.shape[1] == 1:
waveform = waveform.expand(-1, expected_channels, *waveform.shape[2:])
else:
raise ValueError(
f"Input audio must have {expected_channels} channels, got {waveform.shape[1]}"
)
mel_spec = self.preprocessor.waveform_to_mel(
waveform, waveform_sample_rate, device=self.device_manager.load_device

View File

@@ -1,11 +1,11 @@
from typing import Tuple, Union
import threading
import torch
import torch.nn as nn
import comfy.ops
ops = comfy.ops.disable_weight_init
class CausalConv3d(nn.Module):
def __init__(
self,
@@ -42,23 +42,34 @@ class CausalConv3d(nn.Module):
padding_mode=spatial_padding_mode,
groups=groups,
)
self.temporal_cache_state={}
def forward(self, x, causal: bool = True):
if causal:
first_frame_pad = x[:, :, :1, :, :].repeat(
(1, 1, self.time_kernel_size - 1, 1, 1)
)
x = torch.concatenate((first_frame_pad, x), dim=2)
else:
first_frame_pad = x[:, :, :1, :, :].repeat(
(1, 1, (self.time_kernel_size - 1) // 2, 1, 1)
)
last_frame_pad = x[:, :, -1:, :, :].repeat(
(1, 1, (self.time_kernel_size - 1) // 2, 1, 1)
)
x = torch.concatenate((first_frame_pad, x, last_frame_pad), dim=2)
x = self.conv(x)
return x
tid = threading.get_ident()
cached, is_end = self.temporal_cache_state.get(tid, (None, False))
if cached is None:
padding_length = self.time_kernel_size - 1
if not causal:
padding_length = padding_length // 2
if x.shape[2] == 0:
return x
cached = x[:, :, :1, :, :].repeat((1, 1, padding_length, 1, 1))
pieces = [ cached, x ]
if is_end and not causal:
pieces.append(x[:, :, -1:, :, :].repeat((1, 1, (self.time_kernel_size - 1) // 2, 1, 1)))
needs_caching = not is_end
if needs_caching and x.shape[2] >= self.time_kernel_size - 1:
needs_caching = False
self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):, :, :], False)
x = torch.cat(pieces, dim=2)
if needs_caching:
self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):, :, :], False)
return self.conv(x) if x.shape[2] >= self.time_kernel_size else x[:, :, :0, :, :]
@property
def weight(self):

View File

@@ -1,4 +1,5 @@
from __future__ import annotations
import threading
import torch
from torch import nn
from functools import partial
@@ -6,12 +7,35 @@ import math
from einops import rearrange
from typing import List, Optional, Tuple, Union
from .conv_nd_factory import make_conv_nd, make_linear_nd
from .causal_conv3d import CausalConv3d
from .pixel_norm import PixelNorm
from ..model import PixArtAlphaCombinedTimestepSizeEmbeddings
import comfy.ops
from comfy.ldm.modules.diffusionmodules.model import torch_cat_if_needed
ops = comfy.ops.disable_weight_init
def mark_conv3d_ended(module):
tid = threading.get_ident()
for _, m in module.named_modules():
if isinstance(m, CausalConv3d):
current = m.temporal_cache_state.get(tid, (None, False))
m.temporal_cache_state[tid] = (current[0], True)
def split2(tensor, split_point, dim=2):
return torch.split(tensor, [split_point, tensor.shape[dim] - split_point], dim=dim)
def add_exchange_cache(dest, cache_in, new_input, dim=2):
if dest is not None:
if cache_in is not None:
cache_to_dest = min(dest.shape[dim], cache_in.shape[dim])
lead_in_dest, dest = split2(dest, cache_to_dest, dim=dim)
lead_in_source, cache_in = split2(cache_in, cache_to_dest, dim=dim)
lead_in_dest.add_(lead_in_source)
body, new_input = split2(new_input, dest.shape[dim], dim)
dest.add_(body)
return torch_cat_if_needed([cache_in, new_input], dim=dim)
class Encoder(nn.Module):
r"""
The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation.
@@ -205,7 +229,7 @@ class Encoder(nn.Module):
self.gradient_checkpointing = False
def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
def forward_orig(self, sample: torch.FloatTensor) -> torch.FloatTensor:
r"""The forward method of the `Encoder` class."""
sample = patchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
@@ -254,6 +278,22 @@ class Encoder(nn.Module):
return sample
def forward(self, *args, **kwargs):
#No encoder support so just flag the end so it doesnt use the cache.
mark_conv3d_ended(self)
try:
return self.forward_orig(*args, **kwargs)
finally:
tid = threading.get_ident()
for _, module in self.named_modules():
# ComfyUI doesn't thread this kind of stuff today, but just in case
# we key on the thread to make it thread safe.
tid = threading.get_ident()
if hasattr(module, "temporal_cache_state"):
module.temporal_cache_state.pop(tid, None)
MAX_CHUNK_SIZE=(128 * 1024 ** 2)
class Decoder(nn.Module):
r"""
@@ -341,18 +381,6 @@ class Decoder(nn.Module):
timestep_conditioning=timestep_conditioning,
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "attn_res_x":
block = UNetMidBlock3D(
dims=dims,
in_channels=input_channel,
num_layers=block_params["num_layers"],
resnet_groups=norm_num_groups,
norm_layer=norm_layer,
inject_noise=block_params.get("inject_noise", False),
timestep_conditioning=timestep_conditioning,
attention_head_dim=block_params["attention_head_dim"],
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "res_x_y":
output_channel = output_channel // block_params.get("multiplier", 2)
block = ResnetBlock3D(
@@ -428,8 +456,9 @@ class Decoder(nn.Module):
)
self.last_scale_shift_table = nn.Parameter(torch.empty(2, output_channel))
# def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor:
def forward(
def forward_orig(
self,
sample: torch.FloatTensor,
timestep: Optional[torch.Tensor] = None,
@@ -437,6 +466,7 @@ class Decoder(nn.Module):
r"""The forward method of the `Decoder` class."""
batch_size = sample.shape[0]
mark_conv3d_ended(self.conv_in)
sample = self.conv_in(sample, causal=self.causal)
checkpoint_fn = (
@@ -445,24 +475,12 @@ class Decoder(nn.Module):
else lambda x: x
)
scaled_timestep = None
timestep_shift_scale = None
if self.timestep_conditioning:
assert (
timestep is not None
), "should pass timestep with timestep_conditioning=True"
scaled_timestep = timestep * self.timestep_scale_multiplier.to(dtype=sample.dtype, device=sample.device)
for up_block in self.up_blocks:
if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D):
sample = checkpoint_fn(up_block)(
sample, causal=self.causal, timestep=scaled_timestep
)
else:
sample = checkpoint_fn(up_block)(sample, causal=self.causal)
sample = self.conv_norm_out(sample)
if self.timestep_conditioning:
embedded_timestep = self.last_time_embedder(
timestep=scaled_timestep.flatten(),
resolution=None,
@@ -483,16 +501,62 @@ class Decoder(nn.Module):
embedded_timestep.shape[-2],
embedded_timestep.shape[-1],
)
shift, scale = ada_values.unbind(dim=1)
sample = sample * (1 + scale) + shift
timestep_shift_scale = ada_values.unbind(dim=1)
sample = self.conv_act(sample)
sample = self.conv_out(sample, causal=self.causal)
output = []
def run_up(idx, sample, ended):
if idx >= len(self.up_blocks):
sample = self.conv_norm_out(sample)
if timestep_shift_scale is not None:
shift, scale = timestep_shift_scale
sample = sample * (1 + scale) + shift
sample = self.conv_act(sample)
if ended:
mark_conv3d_ended(self.conv_out)
sample = self.conv_out(sample, causal=self.causal)
if sample is not None and sample.shape[2] > 0:
output.append(sample)
return
up_block = self.up_blocks[idx]
if (ended):
mark_conv3d_ended(up_block)
if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D):
sample = checkpoint_fn(up_block)(
sample, causal=self.causal, timestep=scaled_timestep
)
else:
sample = checkpoint_fn(up_block)(sample, causal=self.causal)
if sample is None or sample.shape[2] == 0:
return
total_bytes = sample.numel() * sample.element_size()
num_chunks = (total_bytes + MAX_CHUNK_SIZE - 1) // MAX_CHUNK_SIZE
samples = torch.chunk(sample, chunks=num_chunks, dim=2)
for chunk_idx, sample1 in enumerate(samples):
run_up(idx + 1, sample1, ended and chunk_idx == len(samples) - 1)
run_up(0, sample, True)
sample = torch.cat(output, dim=2)
sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
return sample
def forward(self, *args, **kwargs):
try:
return self.forward_orig(*args, **kwargs)
finally:
for _, module in self.named_modules():
#ComfyUI doesn't thread this kind of stuff today, but just incase
#we key on the thread to make it thread safe.
tid = threading.get_ident()
if hasattr(module, "temporal_cache_state"):
module.temporal_cache_state.pop(tid, None)
class UNetMidBlock3D(nn.Module):
"""
@@ -663,8 +727,22 @@ class DepthToSpaceUpsample(nn.Module):
)
self.residual = residual
self.out_channels_reduction_factor = out_channels_reduction_factor
self.temporal_cache_state = {}
def forward(self, x, causal: bool = True, timestep: Optional[torch.Tensor] = None):
tid = threading.get_ident()
cached, drop_first_conv, drop_first_res = self.temporal_cache_state.get(tid, (None, True, True))
y = self.conv(x, causal=causal)
y = rearrange(
y,
"b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)",
p1=self.stride[0],
p2=self.stride[1],
p3=self.stride[2],
)
if self.stride[0] == 2 and y.shape[2] > 0 and drop_first_conv:
y = y[:, :, 1:, :, :]
drop_first_conv = False
if self.residual:
# Reshape and duplicate the input to match the output shape
x_in = rearrange(
@@ -676,21 +754,20 @@ class DepthToSpaceUpsample(nn.Module):
)
num_repeat = math.prod(self.stride) // self.out_channels_reduction_factor
x_in = x_in.repeat(1, num_repeat, 1, 1, 1)
if self.stride[0] == 2:
if self.stride[0] == 2 and x_in.shape[2] > 0 and drop_first_res:
x_in = x_in[:, :, 1:, :, :]
x = self.conv(x, causal=causal)
x = rearrange(
x,
"b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)",
p1=self.stride[0],
p2=self.stride[1],
p3=self.stride[2],
)
if self.stride[0] == 2:
x = x[:, :, 1:, :, :]
if self.residual:
x = x + x_in
return x
drop_first_res = False
if y.shape[2] == 0:
y = None
cached = add_exchange_cache(y, cached, x_in, dim=2)
self.temporal_cache_state[tid] = (cached, drop_first_conv, drop_first_res)
else:
self.temporal_cache_state[tid] = (None, drop_first_conv, False)
return y
class LayerNorm(nn.Module):
def __init__(self, dim, eps, elementwise_affine=True) -> None:
@@ -807,6 +884,8 @@ class ResnetBlock3D(nn.Module):
torch.randn(4, in_channels) / in_channels**0.5
)
self.temporal_cache_state={}
def _feed_spatial_noise(
self, hidden_states: torch.FloatTensor, per_channel_scale: torch.FloatTensor
) -> torch.FloatTensor:
@@ -880,9 +959,12 @@ class ResnetBlock3D(nn.Module):
input_tensor = self.conv_shortcut(input_tensor)
output_tensor = input_tensor + hidden_states
tid = threading.get_ident()
cached = self.temporal_cache_state.get(tid, None)
cached = add_exchange_cache(hidden_states, cached, input_tensor, dim=2)
self.temporal_cache_state[tid] = cached
return output_tensor
return hidden_states
def patchify(x, patch_size_hw, patch_size_t=1):

View File

@@ -13,10 +13,53 @@ from comfy.ldm.modules.attention import optimized_attention_masked
from comfy.ldm.flux.layers import EmbedND
from comfy.ldm.flux.math import apply_rope
import comfy.patcher_extension
import comfy.utils
def modulate(x, scale):
return x * (1 + scale.unsqueeze(1))
def invert_slices(slices, length):
sorted_slices = sorted(slices)
result = []
current = 0
for start, end in sorted_slices:
if current < start:
result.append((current, start))
current = max(current, end)
if current < length:
result.append((current, length))
return result
def modulate(x, scale, timestep_zero_index=None):
if timestep_zero_index is None:
return x * (1 + scale.unsqueeze(1))
else:
scale = (1 + scale.unsqueeze(1))
actual_batch = scale.size(0) // 2
slices = timestep_zero_index
invert = invert_slices(timestep_zero_index, x.shape[1])
for s in slices:
x[:, s[0]:s[1]] *= scale[actual_batch:]
for s in invert:
x[:, s[0]:s[1]] *= scale[:actual_batch]
return x
def apply_gate(gate, x, timestep_zero_index=None):
if timestep_zero_index is None:
return gate * x
else:
actual_batch = gate.size(0) // 2
slices = timestep_zero_index
invert = invert_slices(timestep_zero_index, x.shape[1])
for s in slices:
x[:, s[0]:s[1]] *= gate[actual_batch:]
for s in invert:
x[:, s[0]:s[1]] *= gate[:actual_batch]
return x
#############################################################################
# Core NextDiT Model #
@@ -258,6 +301,7 @@ class JointTransformerBlock(nn.Module):
x_mask: torch.Tensor,
freqs_cis: torch.Tensor,
adaln_input: Optional[torch.Tensor]=None,
timestep_zero_index=None,
transformer_options={},
):
"""
@@ -276,18 +320,18 @@ class JointTransformerBlock(nn.Module):
assert adaln_input is not None
scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).chunk(4, dim=1)
x = x + gate_msa.unsqueeze(1).tanh() * self.attention_norm2(
x = x + apply_gate(gate_msa.unsqueeze(1).tanh(), self.attention_norm2(
clamp_fp16(self.attention(
modulate(self.attention_norm1(x), scale_msa),
modulate(self.attention_norm1(x), scale_msa, timestep_zero_index=timestep_zero_index),
x_mask,
freqs_cis,
transformer_options=transformer_options,
))
))), timestep_zero_index=timestep_zero_index
)
x = x + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(
x = x + apply_gate(gate_mlp.unsqueeze(1).tanh(), self.ffn_norm2(
clamp_fp16(self.feed_forward(
modulate(self.ffn_norm1(x), scale_mlp),
))
modulate(self.ffn_norm1(x), scale_mlp, timestep_zero_index=timestep_zero_index),
))), timestep_zero_index=timestep_zero_index
)
else:
assert adaln_input is None
@@ -345,13 +389,37 @@ class FinalLayer(nn.Module):
),
)
def forward(self, x, c):
def forward(self, x, c, timestep_zero_index=None):
scale = self.adaLN_modulation(c)
x = modulate(self.norm_final(x), scale)
x = modulate(self.norm_final(x), scale, timestep_zero_index=timestep_zero_index)
x = self.linear(x)
return x
def pad_zimage(feats, pad_token, pad_tokens_multiple):
pad_extra = (-feats.shape[1]) % pad_tokens_multiple
return torch.cat((feats, pad_token.to(device=feats.device, dtype=feats.dtype, copy=True).unsqueeze(0).repeat(feats.shape[0], pad_extra, 1)), dim=1), pad_extra
def pos_ids_x(start_t, H_tokens, W_tokens, batch_size, device, transformer_options={}):
rope_options = transformer_options.get("rope_options", None)
h_scale = 1.0
w_scale = 1.0
h_start = 0
w_start = 0
if rope_options is not None:
h_scale = rope_options.get("scale_y", 1.0)
w_scale = rope_options.get("scale_x", 1.0)
h_start = rope_options.get("shift_y", 0.0)
w_start = rope_options.get("shift_x", 0.0)
x_pos_ids = torch.zeros((batch_size, H_tokens * W_tokens, 3), dtype=torch.float32, device=device)
x_pos_ids[:, :, 0] = start_t
x_pos_ids[:, :, 1] = (torch.arange(H_tokens, dtype=torch.float32, device=device) * h_scale + h_start).view(-1, 1).repeat(1, W_tokens).flatten()
x_pos_ids[:, :, 2] = (torch.arange(W_tokens, dtype=torch.float32, device=device) * w_scale + w_start).view(1, -1).repeat(H_tokens, 1).flatten()
return x_pos_ids
class NextDiT(nn.Module):
"""
Diffusion model with a Transformer backbone.
@@ -378,10 +446,12 @@ class NextDiT(nn.Module):
time_scale=1.0,
pad_tokens_multiple=None,
clip_text_dim=None,
siglip_feat_dim=None,
image_model=None,
device=None,
dtype=None,
operations=None,
**kwargs,
) -> None:
super().__init__()
self.dtype = dtype
@@ -491,6 +561,41 @@ class NextDiT(nn.Module):
for layer_id in range(n_layers)
]
)
if siglip_feat_dim is not None:
self.siglip_embedder = nn.Sequential(
operation_settings.get("operations").RMSNorm(siglip_feat_dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")),
operation_settings.get("operations").Linear(
siglip_feat_dim,
dim,
bias=True,
device=operation_settings.get("device"),
dtype=operation_settings.get("dtype"),
),
)
self.siglip_refiner = nn.ModuleList(
[
JointTransformerBlock(
layer_id,
dim,
n_heads,
n_kv_heads,
multiple_of,
ffn_dim_multiplier,
norm_eps,
qk_norm,
modulation=False,
operation_settings=operation_settings,
)
for layer_id in range(n_refiner_layers)
]
)
self.siglip_pad_token = nn.Parameter(torch.empty((1, dim), device=device, dtype=dtype))
else:
self.siglip_embedder = None
self.siglip_refiner = None
self.siglip_pad_token = None
# This norm final is in the lumina 2.0 code but isn't actually used for anything.
# self.norm_final = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.final_layer = FinalLayer(dim, patch_size, self.out_channels, z_image_modulation=z_image_modulation, operation_settings=operation_settings)
@@ -531,70 +636,168 @@ class NextDiT(nn.Module):
imgs = torch.stack(imgs, dim=0)
return imgs
def patchify_and_embed(
self, x: List[torch.Tensor] | torch.Tensor, cap_feats: torch.Tensor, cap_mask: torch.Tensor, t: torch.Tensor, num_tokens, transformer_options={}
) -> Tuple[torch.Tensor, torch.Tensor, List[Tuple[int, int]], List[int], torch.Tensor]:
bsz = len(x)
pH = pW = self.patch_size
device = x[0].device
orig_x = x
if self.pad_tokens_multiple is not None:
pad_extra = (-cap_feats.shape[1]) % self.pad_tokens_multiple
cap_feats = torch.cat((cap_feats, self.cap_pad_token.to(device=cap_feats.device, dtype=cap_feats.dtype, copy=True).unsqueeze(0).repeat(cap_feats.shape[0], pad_extra, 1)), dim=1)
def embed_cap(self, cap_feats=None, offset=0, bsz=1, device=None, dtype=None):
if cap_feats is not None:
cap_feats = self.cap_embedder(cap_feats)
cap_feats_len = cap_feats.shape[1]
if self.pad_tokens_multiple is not None:
cap_feats, _ = pad_zimage(cap_feats, self.cap_pad_token, self.pad_tokens_multiple)
else:
cap_feats_len = 0
cap_feats = self.cap_pad_token.to(device=device, dtype=dtype, copy=True).unsqueeze(0).repeat(bsz, self.pad_tokens_multiple, 1)
cap_pos_ids = torch.zeros(bsz, cap_feats.shape[1], 3, dtype=torch.float32, device=device)
cap_pos_ids[:, :, 0] = torch.arange(cap_feats.shape[1], dtype=torch.float32, device=device) + 1.0
cap_pos_ids[:, :, 0] = torch.arange(cap_feats.shape[1], dtype=torch.float32, device=device) + 1.0 + offset
embeds = (cap_feats,)
freqs_cis = (self.rope_embedder(cap_pos_ids).movedim(1, 2),)
return embeds, freqs_cis, cap_feats_len
def embed_all(self, x, cap_feats=None, siglip_feats=None, offset=0, omni=False, transformer_options={}):
bsz = 1
pH = pW = self.patch_size
device = x.device
embeds, freqs_cis, cap_feats_len = self.embed_cap(cap_feats, offset=offset, bsz=bsz, device=device, dtype=x.dtype)
if (not omni) or self.siglip_embedder is None:
cap_feats_len = embeds[0].shape[1] + offset
embeds += (None,)
freqs_cis += (None,)
else:
cap_feats_len += offset
if siglip_feats is not None:
b, h, w, c = siglip_feats.shape
siglip_feats = siglip_feats.permute(0, 3, 1, 2).reshape(b, h * w, c)
siglip_feats = self.siglip_embedder(siglip_feats)
siglip_pos_ids = torch.zeros((bsz, siglip_feats.shape[1], 3), dtype=torch.float32, device=device)
siglip_pos_ids[:, :, 0] = cap_feats_len + 2
siglip_pos_ids[:, :, 1] = (torch.linspace(0, h * 8 - 1, steps=h, dtype=torch.float32, device=device).floor()).view(-1, 1).repeat(1, w).flatten()
siglip_pos_ids[:, :, 2] = (torch.linspace(0, w * 8 - 1, steps=w, dtype=torch.float32, device=device).floor()).view(1, -1).repeat(h, 1).flatten()
if self.siglip_pad_token is not None:
siglip_feats, pad_extra = pad_zimage(siglip_feats, self.siglip_pad_token, self.pad_tokens_multiple) # TODO: double check
siglip_pos_ids = torch.nn.functional.pad(siglip_pos_ids, (0, 0, 0, pad_extra))
else:
if self.siglip_pad_token is not None:
siglip_feats = self.siglip_pad_token.to(device=device, dtype=x.dtype, copy=True).unsqueeze(0).repeat(bsz, self.pad_tokens_multiple, 1)
siglip_pos_ids = torch.zeros((bsz, siglip_feats.shape[1], 3), dtype=torch.float32, device=device)
if siglip_feats is None:
embeds += (None,)
freqs_cis += (None,)
else:
embeds += (siglip_feats,)
freqs_cis += (self.rope_embedder(siglip_pos_ids).movedim(1, 2),)
B, C, H, W = x.shape
x = self.x_embedder(x.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2))
rope_options = transformer_options.get("rope_options", None)
h_scale = 1.0
w_scale = 1.0
h_start = 0
w_start = 0
if rope_options is not None:
h_scale = rope_options.get("scale_y", 1.0)
w_scale = rope_options.get("scale_x", 1.0)
h_start = rope_options.get("shift_y", 0.0)
w_start = rope_options.get("shift_x", 0.0)
H_tokens, W_tokens = H // pH, W // pW
x_pos_ids = torch.zeros((bsz, x.shape[1], 3), dtype=torch.float32, device=device)
x_pos_ids[:, :, 0] = cap_feats.shape[1] + 1
x_pos_ids[:, :, 1] = (torch.arange(H_tokens, dtype=torch.float32, device=device) * h_scale + h_start).view(-1, 1).repeat(1, W_tokens).flatten()
x_pos_ids[:, :, 2] = (torch.arange(W_tokens, dtype=torch.float32, device=device) * w_scale + w_start).view(1, -1).repeat(H_tokens, 1).flatten()
x_pos_ids = pos_ids_x(cap_feats_len + 1, H // pH, W // pW, bsz, device, transformer_options=transformer_options)
if self.pad_tokens_multiple is not None:
pad_extra = (-x.shape[1]) % self.pad_tokens_multiple
x = torch.cat((x, self.x_pad_token.to(device=x.device, dtype=x.dtype, copy=True).unsqueeze(0).repeat(x.shape[0], pad_extra, 1)), dim=1)
x, pad_extra = pad_zimage(x, self.x_pad_token, self.pad_tokens_multiple)
x_pos_ids = torch.nn.functional.pad(x_pos_ids, (0, 0, 0, pad_extra))
freqs_cis = self.rope_embedder(torch.cat((cap_pos_ids, x_pos_ids), dim=1)).movedim(1, 2)
embeds += (x,)
freqs_cis += (self.rope_embedder(x_pos_ids).movedim(1, 2),)
return embeds, freqs_cis, cap_feats_len + len(freqs_cis) - 1
def patchify_and_embed(
self, x: torch.Tensor, cap_feats: torch.Tensor, cap_mask: torch.Tensor, t: torch.Tensor, num_tokens, ref_latents=[], ref_contexts=[], siglip_feats=[], transformer_options={}
) -> Tuple[torch.Tensor, torch.Tensor, List[Tuple[int, int]], List[int], torch.Tensor]:
bsz = x.shape[0]
cap_mask = None # TODO?
main_siglip = None
orig_x = x
embeds = ([], [], [])
freqs_cis = ([], [], [])
leftover_cap = []
start_t = 0
omni = len(ref_latents) > 0
if omni:
for i, ref in enumerate(ref_latents):
if i < len(ref_contexts):
ref_con = ref_contexts[i]
else:
ref_con = None
if i < len(siglip_feats):
sig_feat = siglip_feats[i]
else:
sig_feat = None
out = self.embed_all(ref, ref_con, sig_feat, offset=start_t, omni=omni, transformer_options=transformer_options)
for i, e in enumerate(out[0]):
if e is not None:
embeds[i].append(comfy.utils.repeat_to_batch_size(e, bsz))
freqs_cis[i].append(out[1][i])
start_t = out[2]
leftover_cap = ref_contexts[len(ref_latents):]
H, W = x.shape[-2], x.shape[-1]
img_sizes = [(H, W)] * bsz
out = self.embed_all(x, cap_feats, main_siglip, offset=start_t, omni=omni, transformer_options=transformer_options)
img_len = out[0][-1].shape[1]
cap_len = out[0][0].shape[1]
for i, e in enumerate(out[0]):
if e is not None:
e = comfy.utils.repeat_to_batch_size(e, bsz)
embeds[i].append(e)
freqs_cis[i].append(out[1][i])
start_t = out[2]
for cap in leftover_cap:
out = self.embed_cap(cap, offset=start_t, bsz=bsz, device=x.device, dtype=x.dtype)
cap_len += out[0][0].shape[1]
embeds[0].append(comfy.utils.repeat_to_batch_size(out[0][0], bsz))
freqs_cis[0].append(out[1][0])
start_t += out[2]
patches = transformer_options.get("patches", {})
# refine context
cap_feats = torch.cat(embeds[0], dim=1)
cap_freqs_cis = torch.cat(freqs_cis[0], dim=1)
for layer in self.context_refiner:
cap_feats = layer(cap_feats, cap_mask, freqs_cis[:, :cap_pos_ids.shape[1]], transformer_options=transformer_options)
cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis, transformer_options=transformer_options)
feats = (cap_feats,)
fc = (cap_freqs_cis,)
if omni and len(embeds[1]) > 0:
siglip_mask = None
siglip_feats_combined = torch.cat(embeds[1], dim=1)
siglip_feats_freqs_cis = torch.cat(freqs_cis[1], dim=1)
if self.siglip_refiner is not None:
for layer in self.siglip_refiner:
siglip_feats_combined = layer(siglip_feats_combined, siglip_mask, siglip_feats_freqs_cis, transformer_options=transformer_options)
feats += (siglip_feats_combined,)
fc += (siglip_feats_freqs_cis,)
padded_img_mask = None
x = torch.cat(embeds[-1], dim=1)
fc_x = torch.cat(freqs_cis[-1], dim=1)
if omni:
timestep_zero_index = [(x.shape[1] - img_len, x.shape[1])]
else:
timestep_zero_index = None
x_input = x
for i, layer in enumerate(self.noise_refiner):
x = layer(x, padded_img_mask, freqs_cis[:, cap_pos_ids.shape[1]:], t, transformer_options=transformer_options)
x = layer(x, padded_img_mask, fc_x, t, timestep_zero_index=timestep_zero_index, transformer_options=transformer_options)
if "noise_refiner" in patches:
for p in patches["noise_refiner"]:
out = p({"img": x, "img_input": x_input, "txt": cap_feats, "pe": freqs_cis[:, cap_pos_ids.shape[1]:], "vec": t, "x": orig_x, "block_index": i, "transformer_options": transformer_options, "block_type": "noise_refiner"})
out = p({"img": x, "img_input": x_input, "txt": cap_feats, "pe": fc_x, "vec": t, "x": orig_x, "block_index": i, "transformer_options": transformer_options, "block_type": "noise_refiner"})
if "img" in out:
x = out["img"]
padded_full_embed = torch.cat((cap_feats, x), dim=1)
padded_full_embed = torch.cat(feats + (x,), dim=1)
if timestep_zero_index is not None:
ind = padded_full_embed.shape[1] - x.shape[1]
timestep_zero_index = [(ind + x.shape[1] - img_len, ind + x.shape[1])]
timestep_zero_index.append((feats[0].shape[1] - cap_len, feats[0].shape[1]))
mask = None
img_sizes = [(H, W)] * bsz
l_effective_cap_len = [cap_feats.shape[1]] * bsz
return padded_full_embed, mask, img_sizes, l_effective_cap_len, freqs_cis
l_effective_cap_len = [padded_full_embed.shape[1] - img_len] * bsz
return padded_full_embed, mask, img_sizes, l_effective_cap_len, torch.cat(fc + (fc_x,), dim=1), timestep_zero_index
def forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
@@ -604,7 +807,11 @@ class NextDiT(nn.Module):
).execute(x, timesteps, context, num_tokens, attention_mask, **kwargs)
# def forward(self, x, t, cap_feats, cap_mask):
def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, transformer_options={}, **kwargs):
def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, ref_latents=[], ref_contexts=[], siglip_feats=[], transformer_options={}, **kwargs):
omni = len(ref_latents) > 0
if omni:
timesteps = torch.cat([timesteps * 0, timesteps], dim=0)
t = 1.0 - timesteps
cap_feats = context
cap_mask = attention_mask
@@ -619,8 +826,6 @@ class NextDiT(nn.Module):
t = self.t_embedder(t * self.time_scale, dtype=x.dtype) # (N, D)
adaln_input = t
cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. redundant compute
if self.clip_text_pooled_proj is not None:
pooled = kwargs.get("clip_text_pooled", None)
if pooled is not None:
@@ -632,7 +837,7 @@ class NextDiT(nn.Module):
patches = transformer_options.get("patches", {})
x_is_tensor = isinstance(x, torch.Tensor)
img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, adaln_input, num_tokens, transformer_options=transformer_options)
img, mask, img_size, cap_size, freqs_cis, timestep_zero_index = self.patchify_and_embed(x, cap_feats, cap_mask, adaln_input, num_tokens, ref_latents=ref_latents, ref_contexts=ref_contexts, siglip_feats=siglip_feats, transformer_options=transformer_options)
freqs_cis = freqs_cis.to(img.device)
transformer_options["total_blocks"] = len(self.layers)
@@ -640,7 +845,7 @@ class NextDiT(nn.Module):
img_input = img
for i, layer in enumerate(self.layers):
transformer_options["block_index"] = i
img = layer(img, mask, freqs_cis, adaln_input, transformer_options=transformer_options)
img = layer(img, mask, freqs_cis, adaln_input, timestep_zero_index=timestep_zero_index, transformer_options=transformer_options)
if "double_block" in patches:
for p in patches["double_block"]:
out = p({"img": img[:, cap_size[0]:], "img_input": img_input[:, cap_size[0]:], "txt": img[:, :cap_size[0]], "pe": freqs_cis[:, cap_size[0]:], "vec": adaln_input, "x": x, "block_index": i, "transformer_options": transformer_options})
@@ -649,8 +854,7 @@ class NextDiT(nn.Module):
if "txt" in out:
img[:, :cap_size[0]] = out["txt"]
img = self.final_layer(img, adaln_input)
img = self.final_layer(img, adaln_input, timestep_zero_index=timestep_zero_index)
img = self.unpatchify(img, img_size, cap_size, return_tensor=x_is_tensor)[:, :, :h, :w]
return -img

View File

@@ -524,6 +524,9 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
@wrap_attn
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
if kwargs.get("low_precision_attention", True) is False:
return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=skip_reshape, skip_output_reshape=skip_output_reshape, **kwargs)
exception_fallback = False
if skip_reshape:
b, _, _, dim_head = q.shape

Some files were not shown because too many files have changed in this diff Show More