Compare commits

...

34 Commits

Author SHA1 Message Date
bymyself
ded0032026 fix: remove essentials_category from CreateVideo (not in spec) 2026-03-15 04:22:24 -07:00
bymyself
91c6ccd39f refactor: keep only node class ESSENTIALS_CATEGORY, remove blueprint/subgraph changes
Frontend will own blueprint categorization separately.
2026-03-15 07:03:03 +00:00
bymyself
d88b19d96f fix: import NotRequired from typing_extensions for Python 3.10 compat 2026-03-15 05:50:04 +00:00
Christian Byrne
4d40630dd9 Merge branch 'master' into toolkit/wire-essentials-categorization 2026-03-14 21:33:16 -07:00
Jukka Seppänen
0904cc3fe5 LTXV: Accumulate VAE decode results on intermediate_device (#12955) 2026-03-14 18:09:09 -07:00
comfyanonymous
4941cd046e Update comfyui-frontend-package to version 1.41.20 (#12954) 2026-03-14 19:53:31 -04:00
comfyanonymous
c711b8f437 Add --fp16-intermediates to use fp16 for intermediate values between nodes (#12953)
This is an experimental WIP option that might not work in your workflow but
should lower memory usage if it does.

Currently only the VAE and the load image node will output in fp16 when
this option is turned on.
2026-03-14 19:18:19 -04:00
Jukka Seppänen
1c5db7397d feat: Support mxfp8 (#12907) 2026-03-14 18:36:29 -04:00
Christian Byrne
e0982a7174 fix: use no-store cache headers to prevent stale frontend chunks (#12911)
After a frontend update (e.g. nightly build), browsers could load
outdated cached index.html and JS/CSS chunks, causing dynamically
imported modules to fail with MIME type errors and vite:preloadError.

Hard refresh (Ctrl+Shift+R) was insufficient to fix the issue because
Cache-Control: no-cache still allows the browser to cache and
revalidate via ETags. aiohttp's FileResponse auto-generates ETags
based on file mtime+size, which may not change after pip reinstall,
so the browser gets 304 Not Modified and serves stale content.

Clearing ALL site data in DevTools did fix it, confirming the HTTP
cache was the root cause.

The fix changes:
- index.html: no-cache -> no-store, must-revalidate
- JS/CSS/JSON entry points: no-cache -> no-store

no-store instructs browsers to never cache these responses, ensuring
every page load fetches the current index.html with correct chunk
references. This is a small tradeoff (~5KB re-download per page load)
for guaranteed correctness after updates.
2026-03-14 18:25:09 -04:00
rattus
4c4be1bba5 comfy-aimdo 0.2.12 (#12941)
comfy-aimdo 0.2.12 fixes support for non-ASCII filepaths in the new
mmap helper.
2026-03-14 07:53:00 -07:00
comfyanonymous
16cd8d8a8f Update README. (#12931) 2026-03-13 22:33:28 -04:00
rattus
7810f49702 comfy aimdo 0.2.11 + Improved RAM Pressure release strategies - Windows speedups (#12925)
* Implement seek and read for pins

Source pins from an mmap is pad because its its a CPU->CPU copy that
attempts to fully buffer the same data twice. Instead, use seek and
read which avoids the mmap buffering while usually being a faster
read in the first place (avoiding mmap faulting etc).

* pinned_memory: Use Aimdo pinner

The aimdo pinner bypasses pytorches CPU allocator which can leak
windows commit charge.

* ops: bypass init() of weight for embedding layer

This similarly consumes large commit charge especially for TEs. It can
cause a permanement leaked commit charge which can destabilize on
systems close to the commit ceiling and generally confuses the RAM
stats.

* model_patcher: implement pinned memory counter

Implement a pinned memory counter for better accounting of what volume
of memory pins have.

* implement touch accounting

Implement accounting of touching mmapped tensors.

* mm+mp: add residency mmap getter

* utils: use the aimdo mmap to load sft files

* model_management: Implement tigher RAM pressure semantics

Implement a pressure release on entire MMAPs as windows does perform
faster when mmaps are unloaded and model loads free ramp into fully
unallocated RAM.

Make the concept of freeing for pins a completely separate concept.
Now that pins are loadable directly from original file and don' touch
the mmap, tighten the freeing budget to just the current loaded model
- what you have left over. This still over-frees pins, but its a lot
better than before.

So after the pins are freed with that algorithm, bounce entire MMAPs
to free RAM based on what the model needs, deducting off any known
resident-in-mmap tensors to the free quota to keep it as tight as
possible.

* comfy-aimdo 0.2.11

Comfy aimdo 0.2.11

* mm: Implement file_slice path for QT

* ruff

* ops: put meta-tensors in place to allow custom nodes to check geo
2026-03-13 22:18:08 -04:00
Dr.Lt.Data
e1f10ca093 bump manager version to 4.1b4 (#12930) 2026-03-13 20:14:27 -04:00
Comfy Org PR Bot
6cd35a0c5f Bump comfyui-frontend-package to 1.41.19 (#12923) 2026-03-13 14:31:25 -04:00
Alexander Piskun
f9ceed9eef fix(api-nodes): Tencent TextToModel and ImageToModel nodes (#12680)
* fix(api-nodes): added "texture_image" output to TencentTextToModel and TencentImageToModel nodes. Fixed `OBJ` output when it is zipped

* support additional solid texture outputs

* fixed and enabled Tencent3DTextureEdit node
2026-03-13 10:10:40 -07:00
Deep Mehta
4a8cf359fe Revert "Revert "feat: Add CacheProvider API for external distributed caching"" (#12915)
* Revert "Revert "feat: Add CacheProvider API for external distributed caching …"

This reverts commit d1d53c14be.

* fix: gate provider lookups to outputs cache and fix UI coercion

- Add `enable_providers` flag to BasicCache so only the outputs cache
  triggers external provider lookups/stores. The objects cache stores
  node class instances, not CacheEntry values, so provider calls were
  wasted round-trips that always missed.
- Remove `or {}` coercion on `result.ui` — an empty dict passes the
  `is not None` gate in execution.py and causes KeyError when the
  history builder indexes `["output"]` and `["meta"]`. Preserving
  `None` correctly skips the ui_node_outputs addition.
2026-03-12 21:17:50 -07:00
comfyanonymous
63d1bbdb40 ComfyUI v0.17.0 2026-03-12 20:44:22 -04:00
PxTicks
5df1427124 Fix audio extraction and truncation bugs (#12652)
Bug report in #12651

- to_skip fix: Prevents negative array slicing when the start offset is negative.
- __duration check: Prevents the extraction loop from breaking after a single audio chunk when the requested duration is 0 (which is a sentinel for unlimited).
2026-03-12 20:44:15 -04:00
comfyanonymous
d1d53c14be Revert "feat: Add CacheProvider API for external distributed caching (#12056)" (#12912)
This reverts commit af7b4a921d.
2026-03-12 20:21:23 -04:00
Deep Mehta
af7b4a921d feat: Add CacheProvider API for external distributed caching (#12056)
* feat: Add CacheProvider API for external distributed caching

Introduces a public API for external cache providers, enabling distributed
caching across multiple ComfyUI instances (e.g., Kubernetes pods).

New files:
- comfy_execution/cache_provider.py: CacheProvider ABC, CacheContext/CacheValue
  dataclasses, thread-safe provider registry, serialization utilities

Modified files:
- comfy_execution/caching.py: Add provider hooks to BasicCache (_notify_providers_store,
  _check_providers_lookup), subcache exclusion, prompt ID propagation
- execution.py: Add prompt lifecycle hooks (on_prompt_start/on_prompt_end) to
  PromptExecutor, set _current_prompt_id on caches

Key features:
- Local-first caching (check local before external for performance)
- NaN detection to prevent incorrect external cache hits
- Subcache exclusion (ephemeral subgraph results not cached externally)
- Thread-safe provider snapshot caching
- Graceful error handling (provider errors logged, never break execution)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

* fix: use deterministic hash for cache keys instead of pickle

Pickle serialization is NOT deterministic across Python sessions due
to hash randomization affecting frozenset iteration order. This causes
distributed caching to fail because different pods compute different
hashes for identical cache keys.

Fix: Use _canonicalize() + JSON serialization which ensures deterministic
ordering regardless of Python's hash randomization.

This is critical for cross-pod cache key consistency in Kubernetes
deployments.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

* test: add unit tests for CacheProvider API

- Add comprehensive tests for _canonicalize deterministic ordering
- Add tests for serialize_cache_key hash consistency
- Add tests for contains_nan utility
- Add tests for estimate_value_size
- Add tests for provider registry (register, unregister, clear)
- Move json import to top-level (fix inline import)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

* style: remove unused imports in test_cache_provider.py

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

* fix: move _torch_available before usage and use importlib.util.find_spec

Fixes ruff F821 (undefined name) and F401 (unused import) errors.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

* fix: use hashable types in frozenset test and add dict test

Frozensets can only contain hashable types, so use nested frozensets
instead of dicts. Added separate test for dict handling via serialize_cache_key.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

* refactor: expose CacheProvider API via comfy_api.latest.Caching

- Add Caching class to comfy_api/latest/__init__.py that re-exports
  from comfy_execution.cache_provider (source of truth)
- Fix docstring: "Skip large values" instead of "Skip small values"
  (small compute-heavy values are good cache targets)
- Maintain backward compatibility: comfy_execution.cache_provider
  imports still work

Usage:
    from comfy_api.latest import Caching

    class MyProvider(Caching.CacheProvider):
        def on_lookup(self, context): ...
        def on_store(self, context, value): ...

    Caching.register_provider(MyProvider())

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

* docs: clarify should_cache filtering criteria

Change docstring from "Skip large values" to "Skip if download time > compute time"
which better captures the cost/benefit tradeoff for external caching.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

* docs: make should_cache docstring implementation-agnostic

Remove prescriptive filtering suggestions - let implementations
decide their own caching logic based on their use case.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

* feat: add optional ui field to CacheValue

- Add ui field to CacheValue dataclass (default None)
- Pass ui when creating CacheValue for external providers
- Use result.ui (or default {}) when returning from external cache lookup

This allows external cache implementations to store/retrieve UI data
if desired, while remaining optional for implementations that skip it.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

* refactor: rename _is_cacheable_value to _is_external_cacheable_value

Clearer name since objects are also cached locally - this specifically
checks for external caching eligibility.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

* refactor: async CacheProvider API + reduce public surface

- Make on_lookup/on_store async on CacheProvider ABC
- Simplify CacheContext: replace cache_key + cache_key_bytes with
  cache_key_hash (str hex digest)
- Make registry/utility functions internal (_prefix)
- Trim comfy_api.latest.Caching exports to core API only
- Make cache get/set async throughout caching.py hierarchy
- Use asyncio.create_task for fire-and-forget on_store
- Add NaN gating before provider calls in Core
- Add await to 5 cache call sites in execution.py

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* fix: remove unused imports (ruff) and update tests for internal API

- Remove unused CacheContext and _serialize_cache_key imports from
  caching.py (now handled by _build_context helper)
- Update test_cache_provider.py to use _-prefixed internal names
- Update tests for new CacheContext.cache_key_hash field (str)
- Make MockCacheProvider methods async to match ABC

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* fix: address coderabbit review feedback

- Add try/except to _build_context, return None when hash fails
- Return None from _serialize_cache_key on total failure (no id()-based fallback)
- Replace hex-like test literal with non-secret placeholder

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* fix: use _-prefixed imports in _notify_prompt_lifecycle

The lifecycle notification method was importing the old non-prefixed
names (has_cache_providers, get_cache_providers, logger) which no
longer exist after the API cleanup.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* fix: add sync get_local/set_local for graph traversal

ExecutionList in graph.py calls output_cache.get() and .set() from
sync methods (is_cached, cache_link, get_cache). These cannot await
the now-async get/set. Add get_local/set_local that bypass external
providers and only access the local dict — which is all graph
traversal needs.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* chore: remove cloud-specific language from cache provider API

Make all docstrings and comments generic for the OSS codebase.
Remove references to Kubernetes, Redis, GCS, pods, and other
infrastructure-specific terminology.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* style: align documentation with codebase conventions

Strip verbose docstrings and section banners to match existing minimal
documentation style used throughout the codebase.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* fix: add usage example to Caching class, remove pickle fallback

- Add docstring with usage example to Caching class matching the
  convention used by sibling APIs (Execution.set_progress, ComfyExtension)
- Remove non-deterministic pickle fallback from _serialize_cache_key;
  return None on JSON failure instead of producing unretrievable hashes
- Move cache_provider imports to top of execution.py (no circular dep)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* refactor: move public types to comfy_api, eager provider snapshot

Address review feedback:
- Move CacheProvider/CacheContext/CacheValue definitions to
  comfy_api/latest/_caching.py (source of truth for public API)
- comfy_execution/cache_provider.py re-exports types from there
- Build _providers_snapshot eagerly on register/unregister instead
  of lazy memoization in _get_cache_providers

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* fix: generalize self-inequality check, fail-closed canonicalization

Address review feedback from guill:
- Rename _contains_nan to _contains_self_unequal, use not (x == x)
  instead of math.isnan to catch any self-unequal value
- Remove Unhashable and repr() fallbacks from _canonicalize; raise
  ValueError for unknown types so _serialize_cache_key returns None
  and external caching is skipped (fail-closed)
- Update tests for renamed function and new fail-closed behavior

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* fix: suppress ruff F401 for re-exported CacheContext

CacheContext is imported from _caching and re-exported for use by
caching.py. Add noqa comment to satisfy the linter.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* fix: enable external caching for subcache (expanded) nodes

Subcache nodes (from node expansion) now participate in external
provider store/lookup. Previously skipped to avoid duplicates, but
the cost of missing partial-expansion cache hits outweighs redundant
stores — especially with looping behavior on the horizon.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* fix: wrap register/unregister as explicit static methods

Define register_provider and unregister_provider as wrapper functions
in the Caching class instead of re-importing. This locks the public
API signature in comfy_api/ so internal changes can't accidentally
break it.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* fix: use debug-level logging for provider registration

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* fix: follow ProxiedSingleton pattern for Caching class

Add Caching as a nested class inside ComfyAPI_latest inheriting from
ProxiedSingleton with async instance methods, matching the Execution
and NodeReplacement patterns. Retains standalone Caching class for
direct import convenience.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* fix: inline registration logic in Caching class

Follow the Execution/NodeReplacement pattern — the public API methods
contain the actual logic operating on cache_provider module state,
not wrapper functions delegating to free functions.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* fix: single Caching definition inside ComfyAPI_latest

Remove duplicate standalone Caching class. Define it once as a nested
class in ComfyAPI_latest (matching Execution/NodeReplacement pattern),
with a module-level alias for import convenience.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* fix: remove prompt_id from CacheContext, type-safe canonicalization

Remove prompt_id from CacheContext — it's not relevant for cache
matching and added unnecessary plumbing (_current_prompt_id on every
cache). Lifecycle hooks still receive prompt_id directly.

Include type name in canonicalized primitives so that int 7 and
str "7" produce distinct hashes. Also canonicalize dict keys properly
instead of str() coercion.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* fix: address review feedback on cache provider API

- Hold references to pending store tasks to prevent "Task was destroyed
  but it is still pending" warnings (bigcat88)
- Parallel cache lookups with asyncio.gather instead of sequential
  awaits for better performance (bigcat88)
- Delegate Caching.register/unregister_provider to existing functions
  in cache_provider.py instead of reimplementing (bigcat88)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

---------

Co-authored-by: Claude <noreply@anthropic.com>
2026-03-12 16:09:07 -07:00
Christian Byrne
8d9faaa181 Update requirements.txt (#12910) 2026-03-12 18:14:59 -04:00
comfyanonymous
47e1e316c5 Lower kv cache memory usage. (#12909) 2026-03-12 16:54:38 -04:00
ComfyUI Wiki
712411d539 chore: update workflow templates to v0.9.21 (#12908) 2026-03-12 12:16:54 -07:00
Terry Jia
3fa8c5686d fix: use frontend-compatible format for Float gradient_stops (#12789)
Co-authored-by: guill <jacob.e.segal@gmail.com>
Co-authored-by: Jedrzej Kosinski <kosinkadink1@gmail.com>
2026-03-12 10:14:28 -07:00
Terry Jia
73d9599495 add painter node (#12294)
* add painter node

* use io.Color

* code improve

---------

Co-authored-by: guill <jacob.e.segal@gmail.com>
2026-03-12 09:55:29 -07:00
comfyanonymous
44f1246c89 Support flux 2 klein kv cache model: Use the FluxKVCache node. (#12905) 2026-03-12 11:30:50 -04:00
comfyanonymous
8f9ea49571 Bump comfy-kitchen version to 0.2.8 (#12895) 2026-03-12 00:17:31 -04:00
Comfy Org PR Bot
9ce4c3dd87 Bump comfyui-frontend-package to 1.41.16 (#12894)
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2026-03-11 18:16:30 -07:00
Comfy Org PR Bot
abc87d3669 Bump comfyui-frontend-package to 1.41.15 (#12891)
---------

Co-authored-by: Alexander Brown <DrJKL0424@gmail.com>
2026-03-11 17:04:51 -04:00
comfyanonymous
f6274c06b4 Fix issue with batch_size > 1 on some models. (#12892) 2026-03-11 16:37:31 -04:00
Adi Borochov
4f4f8659c2 fix: guard torch.AcceleratorError for compatibility with torch < 2.8.0 (#12874)
* fix: guard torch.AcceleratorError for compatibility with torch < 2.8.0

torch.AcceleratorError was introduced in PyTorch 2.8.0. Accessing it
directly raises AttributeError on older versions. Use a try/except
fallback at module load time, consistent with the existing pattern used
for OOM_EXCEPTION.


* fix: address review feedback for AcceleratorError compat

- Fall back to RuntimeError instead of type(None) for ACCELERATOR_ERROR,
  consistent with OOM_EXCEPTION fallback pattern and valid for except clauses
- Add "out of memory" message introspection for RuntimeError fallback case
- Use RuntimeError directly in discard_cuda_async_error except clause
---------
2026-03-11 10:04:13 -07:00
Alexander Piskun
3365008dfe feat(api-nodes): add Reve Image nodes (#12848) 2026-03-11 09:53:55 -07:00
rattus
980621da83 comfy-aimdo 0.2.10 (#12890)
Comfy Aimdo 0.2.10 fixes the aimdo allocator hook for legacy cudaMalloc
consumers. Some consumers of cudaMalloc assume implicit synchronization
built in closed source logic inside cuda. This is preserved by passing
through to cuda as-is and accouting after the fact as opposed to
integrating these hooks with Aimdos VMA based allocator.
2026-03-11 08:49:38 -07:00
bymyself
ad5b8ca494 feat: add essentials_category to nodes and blueprints for Essentials tab
Add ESSENTIALS_CATEGORY or essentials_category to 12 node classes and all
36 blueprint JSONs. Update SubgraphEntry TypedDict and subgraph_manager to
extract and pass through the field.

Fixes COM-15221

Amp-Thread-ID: https://ampcode.com/threads/T-019c83de-f7ab-7779-a451-0ba5940b56a9
2026-02-22 01:40:34 -08:00
44 changed files with 2144 additions and 206 deletions

View File

@@ -38,6 +38,8 @@ ComfyUI lets you design and execute advanced stable diffusion pipelines using a
## Get Started
### Local
#### [Desktop Application](https://www.comfy.org/download)
- The easiest way to get started.
- Available on Windows & macOS.
@@ -49,8 +51,13 @@ ComfyUI lets you design and execute advanced stable diffusion pipelines using a
#### [Manual Install](#manual-install-windows-linux)
Supports all operating systems and GPU types (NVIDIA, AMD, Intel, Apple Silicon, Ascend).
## [Examples](https://comfyanonymous.github.io/ComfyUI_examples/)
See what ComfyUI can do with the [example workflows](https://comfyanonymous.github.io/ComfyUI_examples/).
### Cloud
#### [Comfy Cloud](https://www.comfy.org/cloud)
- Our official paid cloud version for those who can't afford local hardware.
## Examples
See what ComfyUI can do with the [newer template workflows](https://comfy.org/workflows) or old [example workflows](https://comfyanonymous.github.io/ComfyUI_examples/).
## Features
- Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.

View File

@@ -83,6 +83,8 @@ fpte_group.add_argument("--fp16-text-enc", action="store_true", help="Store text
fpte_group.add_argument("--fp32-text-enc", action="store_true", help="Store text encoder weights in fp32.")
fpte_group.add_argument("--bf16-text-enc", action="store_true", help="Store text encoder weights in bf16.")
parser.add_argument("--fp16-intermediates", action="store_true", help="Experimental: Use fp16 for intermediate tensors between nodes instead of fp32.")
parser.add_argument("--force-channels-last", action="store_true", help="Force channels last format when inferencing the models.")
parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")

View File

@@ -176,8 +176,8 @@ class InputTypeOptions(TypedDict):
"""COMBO type only. Specifies the configuration for a multi-select widget.
Available after ComfyUI frontend v1.13.4
https://github.com/Comfy-Org/ComfyUI_frontend/pull/2987"""
gradient_stops: NotRequired[list[list[float]]]
"""Gradient color stops for gradientslider display mode. Each stop is [offset, r, g, b] (``FLOAT``)."""
gradient_stops: NotRequired[list[dict]]
"""Gradient color stops for gradientslider display mode. Each stop is {"offset": float, "color": [r, g, b]}."""
class HiddenInputTypeDict(TypedDict):

View File

@@ -209,3 +209,39 @@ def stochastic_round_quantize_nvfp4_by_block(x, per_tensor_scale, pad_16x, seed=
output_block[i:i + slice_size].copy_(block)
return output_fp4, to_blocked(output_block, flatten=False)
def stochastic_round_quantize_mxfp8_by_block(x, pad_32x, seed=0):
def roundup(x_val, multiple):
return ((x_val + multiple - 1) // multiple) * multiple
if pad_32x:
rows, cols = x.shape
padded_rows = roundup(rows, 32)
padded_cols = roundup(cols, 32)
if padded_rows != rows or padded_cols != cols:
x = torch.nn.functional.pad(x, (0, padded_cols - cols, 0, padded_rows - rows))
F8_E4M3_MAX = 448.0
E8M0_BIAS = 127
BLOCK_SIZE = 32
rows, cols = x.shape
x_blocked = x.reshape(rows, -1, BLOCK_SIZE)
max_abs = torch.amax(torch.abs(x_blocked), dim=-1)
# E8M0 block scales (power-of-2 exponents)
scale_needed = torch.clamp(max_abs.float() / F8_E4M3_MAX, min=2**(-127))
exp_biased = torch.clamp(torch.ceil(torch.log2(scale_needed)).to(torch.int32) + E8M0_BIAS, 0, 254)
block_scales_e8m0 = exp_biased.to(torch.uint8)
zero_mask = (max_abs == 0)
block_scales_f32 = (block_scales_e8m0.to(torch.int32) << 23).view(torch.float32)
block_scales_f32 = torch.where(zero_mask, torch.ones_like(block_scales_f32), block_scales_f32)
# Scale per-block then stochastic round
data_scaled = (x_blocked.float() / block_scales_f32.unsqueeze(-1)).reshape(rows, cols)
output_fp8 = stochastic_rounding(data_scaled, torch.float8_e4m3fn, seed=seed)
block_scales_e8m0 = torch.where(zero_mask, torch.zeros_like(block_scales_e8m0), block_scales_e8m0)
return output_fp8, to_blocked(block_scales_e8m0, flatten=False).view(torch.float8_e8m0fnu)

View File

@@ -144,9 +144,9 @@ def apply_mod(tensor, m_mult, m_add=None, modulation_dims=None):
return tensor * m_mult
else:
for d in modulation_dims:
tensor[:, d[0]:d[1]] *= m_mult[:, d[2]]
tensor[:, d[0]:d[1]] *= m_mult[:, d[2]:d[2] + 1]
if m_add is not None:
tensor[:, d[0]:d[1]] += m_add[:, d[2]]
tensor[:, d[0]:d[1]] += m_add[:, d[2]:d[2] + 1]
return tensor

View File

@@ -44,6 +44,22 @@ class FluxParams:
txt_norm: bool = False
def invert_slices(slices, length):
sorted_slices = sorted(slices)
result = []
current = 0
for start, end in sorted_slices:
if current < start:
result.append((current, start))
current = max(current, end)
if current < length:
result.append((current, length))
return result
class Flux(nn.Module):
"""
Transformer model for flow matching on sequences.
@@ -138,6 +154,7 @@ class Flux(nn.Module):
y: Tensor,
guidance: Tensor = None,
control = None,
timestep_zero_index=None,
transformer_options={},
attn_mask: Tensor = None,
) -> Tensor:
@@ -164,10 +181,6 @@ class Flux(nn.Module):
txt = self.txt_norm(txt)
txt = self.txt_in(txt)
vec_orig = vec
if self.params.global_modulation:
vec = (self.double_stream_modulation_img(vec_orig), self.double_stream_modulation_txt(vec_orig))
if "post_input" in patches:
for p in patches["post_input"]:
out = p({"img": img, "txt": txt, "img_ids": img_ids, "txt_ids": txt_ids, "transformer_options": transformer_options})
@@ -182,6 +195,24 @@ class Flux(nn.Module):
else:
pe = None
vec_orig = vec
txt_vec = vec
extra_kwargs = {}
if timestep_zero_index is not None:
modulation_dims = []
batch = vec.shape[0] // 2
vec_orig = vec_orig.reshape(2, batch, vec.shape[1]).movedim(0, 1)
invert = invert_slices(timestep_zero_index, img.shape[1])
for s in invert:
modulation_dims.append((s[0], s[1], 0))
for s in timestep_zero_index:
modulation_dims.append((s[0], s[1], 1))
extra_kwargs["modulation_dims_img"] = modulation_dims
txt_vec = vec[:batch]
if self.params.global_modulation:
vec = (self.double_stream_modulation_img(vec_orig), self.double_stream_modulation_txt(txt_vec))
blocks_replace = patches_replace.get("dit", {})
transformer_options["total_blocks"] = len(self.double_blocks)
transformer_options["block_type"] = "double"
@@ -195,7 +226,8 @@ class Flux(nn.Module):
vec=args["vec"],
pe=args["pe"],
attn_mask=args.get("attn_mask"),
transformer_options=args.get("transformer_options"))
transformer_options=args.get("transformer_options"),
**extra_kwargs)
return out
out = blocks_replace[("double_block", i)]({"img": img,
@@ -213,7 +245,8 @@ class Flux(nn.Module):
vec=vec,
pe=pe,
attn_mask=attn_mask,
transformer_options=transformer_options)
transformer_options=transformer_options,
**extra_kwargs)
if control is not None: # Controlnet
control_i = control.get("input")
@@ -230,6 +263,12 @@ class Flux(nn.Module):
if self.params.global_modulation:
vec, _ = self.single_stream_modulation(vec_orig)
extra_kwargs = {}
if timestep_zero_index is not None:
lambda a: 0 if a == 0 else a + txt.shape[1]
modulation_dims_combined = list(map(lambda x: (0 if x[0] == 0 else x[0] + txt.shape[1], x[1] + txt.shape[1], x[2]), modulation_dims))
extra_kwargs["modulation_dims"] = modulation_dims_combined
transformer_options["total_blocks"] = len(self.single_blocks)
transformer_options["block_type"] = "single"
transformer_options["img_slice"] = [txt.shape[1], img.shape[1]]
@@ -242,7 +281,8 @@ class Flux(nn.Module):
vec=args["vec"],
pe=args["pe"],
attn_mask=args.get("attn_mask"),
transformer_options=args.get("transformer_options"))
transformer_options=args.get("transformer_options"),
**extra_kwargs)
return out
out = blocks_replace[("single_block", i)]({"img": img,
@@ -253,7 +293,7 @@ class Flux(nn.Module):
{"original_block": block_wrap})
img = out["img"]
else:
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options)
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options, **extra_kwargs)
if control is not None: # Controlnet
control_o = control.get("output")
@@ -264,7 +304,11 @@ class Flux(nn.Module):
img = img[:, txt.shape[1] :, ...]
img = self.final_layer(img, vec_orig) # (N, T, patch_size ** 2 * out_channels)
extra_kwargs = {}
if timestep_zero_index is not None:
extra_kwargs["modulation_dims"] = modulation_dims
img = self.final_layer(img, vec_orig, **extra_kwargs) # (N, T, patch_size ** 2 * out_channels)
return img
def process_img(self, x, index=0, h_offset=0, w_offset=0, transformer_options={}):
@@ -312,13 +356,16 @@ class Flux(nn.Module):
w_len = ((w_orig + (patch_size // 2)) // patch_size)
img, img_ids = self.process_img(x, transformer_options=transformer_options)
img_tokens = img.shape[1]
timestep_zero_index = None
if ref_latents is not None:
ref_num_tokens = []
h = 0
w = 0
index = 0
ref_latents_method = kwargs.get("ref_latents_method", self.params.default_ref_method)
timestep_zero = ref_latents_method == "index_timestep_zero"
for ref in ref_latents:
if ref_latents_method == "index":
if ref_latents_method in ("index", "index_timestep_zero"):
index += self.params.ref_index_scale
h_offset = 0
w_offset = 0
@@ -342,6 +389,13 @@ class Flux(nn.Module):
kontext, kontext_ids = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset)
img = torch.cat([img, kontext], dim=1)
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
ref_num_tokens.append(kontext.shape[1])
if timestep_zero:
if index > 0:
timestep = torch.cat([timestep, timestep * 0], dim=0)
timestep_zero_index = [[img_tokens, img_ids.shape[1]]]
transformer_options = transformer_options.copy()
transformer_options["reference_image_num_tokens"] = ref_num_tokens
txt_ids = torch.zeros((bs, context.shape[1], len(self.params.axes_dim)), device=x.device, dtype=torch.float32)
@@ -349,6 +403,6 @@ class Flux(nn.Module):
for i in self.params.txt_ids_dims:
txt_ids[:, :, i] = torch.linspace(0, context.shape[1] - 1, steps=context.shape[1], device=x.device, dtype=torch.float32)
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, timestep_zero_index=timestep_zero_index, transformer_options=transformer_options, attn_mask=kwargs.get("attention_mask", None))
out = out[:, :img_tokens]
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=self.patch_size, pw=self.patch_size)[:,:,:h_orig,:w_orig]

View File

@@ -11,6 +11,7 @@ from .causal_conv3d import CausalConv3d
from .pixel_norm import PixelNorm
from ..model import PixArtAlphaCombinedTimestepSizeEmbeddings
import comfy.ops
import comfy.model_management
from comfy.ldm.modules.diffusionmodules.model import torch_cat_if_needed
ops = comfy.ops.disable_weight_init
@@ -536,7 +537,7 @@ class Decoder(nn.Module):
mark_conv3d_ended(self.conv_out)
sample = self.conv_out(sample, causal=self.causal)
if sample is not None and sample.shape[2] > 0:
output.append(sample)
output.append(sample.to(comfy.model_management.intermediate_device()))
return
up_block = self.up_blocks[idx]

View File

@@ -1,9 +1,68 @@
import math
import ctypes
import threading
import dataclasses
import torch
from typing import NamedTuple
from comfy.quant_ops import QuantizedTensor
class TensorFileSlice(NamedTuple):
file_ref: object
thread_id: int
offset: int
size: int
def read_tensor_file_slice_into(tensor, destination):
if isinstance(tensor, QuantizedTensor):
if not isinstance(destination, QuantizedTensor):
return False
if tensor._layout_cls != destination._layout_cls:
return False
if not read_tensor_file_slice_into(tensor._qdata, destination._qdata):
return False
dst_orig_dtype = destination._params.orig_dtype
destination._params.copy_from(tensor._params, non_blocking=False)
destination._params = dataclasses.replace(destination._params, orig_dtype=dst_orig_dtype)
return True
info = getattr(tensor.untyped_storage(), "_comfy_tensor_file_slice", None)
if info is None:
return False
file_obj = info.file_ref
if (destination.device.type != "cpu"
or file_obj is None
or threading.get_ident() != info.thread_id
or destination.numel() * destination.element_size() < info.size):
return False
if info.size == 0:
return True
buf_type = ctypes.c_ubyte * info.size
view = memoryview(buf_type.from_address(destination.data_ptr()))
try:
file_obj.seek(info.offset)
done = 0
while done < info.size:
try:
n = file_obj.readinto(view[done:])
except OSError:
return False
if n <= 0:
return False
done += n
return True
finally:
view.release()
class TensorGeometry(NamedTuple):
shape: any
dtype: torch.dtype

View File

@@ -270,10 +270,15 @@ try:
except:
OOM_EXCEPTION = Exception
try:
ACCELERATOR_ERROR = torch.AcceleratorError
except AttributeError:
ACCELERATOR_ERROR = RuntimeError
def is_oom(e):
if isinstance(e, OOM_EXCEPTION):
return True
if isinstance(e, torch.AcceleratorError) and getattr(e, 'error_code', None) == 2:
if isinstance(e, ACCELERATOR_ERROR) and (getattr(e, 'error_code', None) == 2 or "out of memory" in str(e).lower()):
discard_cuda_async_error()
return True
return False
@@ -500,6 +505,28 @@ def module_size(module):
module_mem += t.nbytes
return module_mem
def module_mmap_residency(module, free=False):
mmap_touched_mem = 0
module_mem = 0
bounced_mmaps = set()
sd = module.state_dict()
for k in sd:
t = sd[k]
module_mem += t.nbytes
storage = t._qdata.untyped_storage() if isinstance(t, comfy.quant_ops.QuantizedTensor) else t.untyped_storage()
if not getattr(storage, "_comfy_tensor_mmap_touched", False):
continue
mmap_touched_mem += t.nbytes
if not free:
continue
storage._comfy_tensor_mmap_touched = False
mmap_obj = storage._comfy_tensor_mmap_refs[0]
if mmap_obj in bounced_mmaps:
continue
mmap_obj.bounce()
bounced_mmaps.add(mmap_obj)
return mmap_touched_mem, module_mem
class LoadedModel:
def __init__(self, model):
self._set_model(model)
@@ -527,6 +554,9 @@ class LoadedModel:
def model_memory(self):
return self.model.model_size()
def model_mmap_residency(self, free=False):
return self.model.model_mmap_residency(free=free)
def model_loaded_memory(self):
return self.model.loaded_size()
@@ -628,7 +658,7 @@ def extra_reserved_memory():
def minimum_inference_memory():
return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory()
def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_required=0):
def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins_required=0, ram_required=0):
cleanup_models_gc()
unloaded_model = []
can_unload = []
@@ -641,13 +671,14 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_
can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
shift_model.currently_used = False
for x in sorted(can_unload):
can_unload_sorted = sorted(can_unload)
for x in can_unload_sorted:
i = x[-1]
memory_to_free = 1e32
ram_to_free = 1e32
pins_to_free = 1e32
if not DISABLE_SMART_MEMORY:
memory_to_free = memory_required - get_free_memory(device)
ram_to_free = ram_required - get_free_ram()
pins_to_free = pins_required - get_free_ram()
if current_loaded_models[i].model.is_dynamic() and for_dynamic:
#don't actually unload dynamic models for the sake of other dynamic models
#as that works on-demand.
@@ -656,9 +687,18 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_
if memory_to_free > 0 and current_loaded_models[i].model_unload(memory_to_free):
logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
unloaded_model.append(i)
if ram_to_free > 0:
if pins_to_free > 0:
logging.debug(f"PIN Unloading {current_loaded_models[i].model.model.__class__.__name__}")
current_loaded_models[i].model.partially_unload_ram(pins_to_free)
for x in can_unload_sorted:
i = x[-1]
ram_to_free = ram_required - psutil.virtual_memory().available
if ram_to_free <= 0 and i not in unloaded_model:
continue
resident_memory, _ = current_loaded_models[i].model_mmap_residency(free=True)
if resident_memory > 0:
logging.debug(f"RAM Unloading {current_loaded_models[i].model.model.__class__.__name__}")
current_loaded_models[i].model.partially_unload_ram(ram_to_free)
for i in sorted(unloaded_model, reverse=True):
unloaded_models.append(current_loaded_models.pop(i))
@@ -724,17 +764,27 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
total_memory_required = {}
total_pins_required = {}
total_ram_required = {}
for loaded_model in models_to_load:
total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
#x2, one to make sure the OS can fit the model for loading in disk cache, and for us to do any pinning we
#want to do.
#FIXME: This should subtract off the to_load current pin consumption.
total_ram_required[loaded_model.device] = total_ram_required.get(loaded_model.device, 0) + loaded_model.model_memory() * 2
device = loaded_model.device
total_memory_required[device] = total_memory_required.get(device, 0) + loaded_model.model_memory_required(device)
resident_memory, model_memory = loaded_model.model_mmap_residency()
pinned_memory = loaded_model.model.pinned_memory_size()
#FIXME: This can over-free the pins as it budgets to pin the entire model. We should
#make this JIT to keep as much pinned as possible.
pins_required = model_memory - pinned_memory
ram_required = model_memory - resident_memory
total_pins_required[device] = total_pins_required.get(device, 0) + pins_required
total_ram_required[device] = total_ram_required.get(device, 0) + ram_required
for device in total_memory_required:
if device != torch.device("cpu"):
free_memory(total_memory_required[device] * 1.1 + extra_mem, device, for_dynamic=free_for_dynamic, ram_required=total_ram_required[device])
free_memory(total_memory_required[device] * 1.1 + extra_mem,
device,
for_dynamic=free_for_dynamic,
pins_required=total_pins_required[device],
ram_required=total_ram_required[device])
for device in total_memory_required:
if device != torch.device("cpu"):
@@ -1000,6 +1050,12 @@ def intermediate_device():
else:
return torch.device("cpu")
def intermediate_dtype():
if args.fp16_intermediates:
return torch.float16
else:
return torch.float32
def vae_device():
if args.cpu_vae:
return torch.device("cpu")
@@ -1220,6 +1276,11 @@ def cast_to_gathered(tensors, r, non_blocking=False, stream=None):
dest_view = dest_views.pop(0)
if tensor is None:
continue
if comfy.memory_management.read_tensor_file_slice_into(tensor, dest_view):
continue
storage = tensor._qdata.untyped_storage() if isinstance(tensor, comfy.quant_ops.QuantizedTensor) else tensor.untyped_storage()
if hasattr(storage, "_comfy_tensor_mmap_touched"):
storage._comfy_tensor_mmap_touched = True
dest_view.copy_(tensor, non_blocking=non_blocking)
@@ -1275,7 +1336,7 @@ def discard_cuda_async_error():
b = torch.tensor([1], dtype=torch.uint8, device=get_torch_device())
_ = a + b
synchronize()
except torch.AcceleratorError:
except RuntimeError:
#Dump it! We already know about it from the synchronous return
pass
@@ -1657,6 +1718,19 @@ def supports_nvfp4_compute(device=None):
return True
def supports_mxfp8_compute(device=None):
if not is_nvidia():
return False
if torch_version_numeric < (2, 10):
return False
props = torch.cuda.get_device_properties(device)
if props.major < 10:
return False
return True
def extended_fp16_support():
# TODO: check why some models work with fp16 on newer torch versions but not on older
if torch_version_numeric < (2, 7):

View File

@@ -297,6 +297,9 @@ class ModelPatcher:
self.size = comfy.model_management.module_size(self.model)
return self.size
def model_mmap_residency(self, free=False):
return comfy.model_management.module_mmap_residency(self.model, free=free)
def get_ram_usage(self):
return self.model_size()
@@ -1063,6 +1066,10 @@ class ModelPatcher:
return self.model.model_loaded_weight_memory - current_used
def pinned_memory_size(self):
# Pinned memory pressure tracking is only implemented for DynamicVram loading
return 0
def partially_unload_ram(self, ram_to_unload):
pass
@@ -1653,6 +1660,16 @@ class ModelPatcherDynamic(ModelPatcher):
return freed
def pinned_memory_size(self):
total = 0
loading = self._load_list(for_dynamic=True)
for x in loading:
_, _, _, _, m, _ = x
pin = comfy.pinned_memory.get_pin(m)
if pin is not None:
total += pin.numel() * pin.element_size()
return total
def partially_unload_ram(self, ram_to_unload):
loading = self._load_list(for_dynamic=True, default_device=self.offload_device)
for x in loading:

View File

@@ -306,6 +306,33 @@ class CastWeightBiasOp:
bias_function = []
class disable_weight_init:
@staticmethod
def _lazy_load_from_state_dict(module, state_dict, prefix, local_metadata,
missing_keys, unexpected_keys, weight_shape,
bias_shape=None):
assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False)
prefix_len = len(prefix)
for k, v in state_dict.items():
key = k[prefix_len:]
if key == "weight":
if not assign_to_params_buffers:
v = v.clone()
module.weight = torch.nn.Parameter(v, requires_grad=False)
elif bias_shape is not None and key == "bias" and v is not None:
if not assign_to_params_buffers:
v = v.clone()
module.bias = torch.nn.Parameter(v, requires_grad=False)
else:
unexpected_keys.append(k)
if module.weight is None:
module.weight = torch.nn.Parameter(torch.zeros(weight_shape), requires_grad=False)
missing_keys.append(prefix + "weight")
if bias_shape is not None and module.bias is None and getattr(module, "comfy_need_lazy_init_bias", False):
module.bias = torch.nn.Parameter(torch.zeros(bias_shape), requires_grad=False)
missing_keys.append(prefix + "bias")
class Linear(torch.nn.Linear, CastWeightBiasOp):
def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
@@ -333,29 +360,16 @@ class disable_weight_init:
if not comfy.model_management.WINDOWS or not comfy.memory_management.aimdo_enabled:
return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs)
assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False)
prefix_len = len(prefix)
for k,v in state_dict.items():
if k[prefix_len:] == "weight":
if not assign_to_params_buffers:
v = v.clone()
self.weight = torch.nn.Parameter(v, requires_grad=False)
elif k[prefix_len:] == "bias" and v is not None:
if not assign_to_params_buffers:
v = v.clone()
self.bias = torch.nn.Parameter(v, requires_grad=False)
else:
unexpected_keys.append(k)
#Reconcile default construction of the weight if its missing.
if self.weight is None:
v = torch.zeros(self.in_features, self.out_features)
self.weight = torch.nn.Parameter(v, requires_grad=False)
missing_keys.append(prefix+"weight")
if self.bias is None and self.comfy_need_lazy_init_bias:
v = torch.zeros(self.out_features,)
self.bias = torch.nn.Parameter(v, requires_grad=False)
missing_keys.append(prefix+"bias")
disable_weight_init._lazy_load_from_state_dict(
self,
state_dict,
prefix,
local_metadata,
missing_keys,
unexpected_keys,
weight_shape=(self.in_features, self.out_features),
bias_shape=(self.out_features,),
)
def reset_parameters(self):
@@ -547,6 +561,48 @@ class disable_weight_init:
return super().forward(*args, **kwargs)
class Embedding(torch.nn.Embedding, CastWeightBiasOp):
def __init__(self, num_embeddings, embedding_dim, padding_idx=None, max_norm=None,
norm_type=2.0, scale_grad_by_freq=False, sparse=False, _weight=None,
_freeze=False, device=None, dtype=None):
if not comfy.model_management.WINDOWS or not comfy.memory_management.aimdo_enabled:
super().__init__(num_embeddings, embedding_dim, padding_idx, max_norm,
norm_type, scale_grad_by_freq, sparse, _weight,
_freeze, device, dtype)
return
torch.nn.Module.__init__(self)
self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim
self.padding_idx = padding_idx
self.max_norm = max_norm
self.norm_type = norm_type
self.scale_grad_by_freq = scale_grad_by_freq
self.sparse = sparse
# Keep shape/dtype visible for module introspection without reserving storage.
embedding_dtype = dtype if dtype is not None else torch.get_default_dtype()
self.weight = torch.nn.Parameter(
torch.empty((num_embeddings, embedding_dim), device="meta", dtype=embedding_dtype),
requires_grad=False,
)
self.bias = None
self.weight_comfy_model_dtype = dtype
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
strict, missing_keys, unexpected_keys, error_msgs):
if not comfy.model_management.WINDOWS or not comfy.memory_management.aimdo_enabled:
return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs)
disable_weight_init._lazy_load_from_state_dict(
self,
state_dict,
prefix,
local_metadata,
missing_keys,
unexpected_keys,
weight_shape=(self.num_embeddings, self.embedding_dim),
)
def reset_parameters(self):
self.bias = None
return None
@@ -801,6 +857,22 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
orig_shape=(self.out_features, self.in_features),
)
elif self.quant_format == "mxfp8":
# MXFP8: E8M0 block scales stored as uint8 in safetensors
block_scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys,
dtype=torch.uint8)
if block_scale is None:
raise ValueError(f"Missing MXFP8 block scales for layer {layer_name}")
block_scale = block_scale.view(torch.float8_e8m0fnu)
params = layout_cls.Params(
scale=block_scale,
orig_dtype=MixedPrecisionOps._compute_dtype,
orig_shape=(self.out_features, self.in_features),
)
elif self.quant_format == "nvfp4":
# NVFP4: tensor_scale (weight_scale_2) + block_scale (weight_scale)
tensor_scale = self._load_scale_param(state_dict, prefix, "weight_scale_2", device, manually_loaded_keys)
@@ -950,12 +1022,15 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None):
fp8_compute = comfy.model_management.supports_fp8_compute(load_device) # TODO: if we support more ops this needs to be more granular
nvfp4_compute = comfy.model_management.supports_nvfp4_compute(load_device)
mxfp8_compute = comfy.model_management.supports_mxfp8_compute(load_device)
if model_config and hasattr(model_config, 'quant_config') and model_config.quant_config:
logging.info("Using mixed precision operations")
disabled = set()
if not nvfp4_compute:
disabled.add("nvfp4")
if not mxfp8_compute:
disabled.add("mxfp8")
if not fp8_compute:
disabled.add("float8_e4m3fn")
disabled.add("float8_e5m2")

View File

@@ -1,6 +1,7 @@
import torch
import comfy.model_management
import comfy.memory_management
import comfy_aimdo.host_buffer
import comfy_aimdo.torch
from comfy.cli_args import args
@@ -12,18 +13,31 @@ def pin_memory(module):
return
#FIXME: This is a RAM cache trigger event
size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ])
pin = torch.empty((size,), dtype=torch.uint8)
if comfy.model_management.pin_memory(pin):
module._pin = pin
else:
if comfy.model_management.MAX_PINNED_MEMORY <= 0 or (comfy.model_management.TOTAL_PINNED_MEMORY + size) > comfy.model_management.MAX_PINNED_MEMORY:
module.pin_failed = True
return False
try:
hostbuf = comfy_aimdo.host_buffer.HostBuffer(size)
except RuntimeError:
module.pin_failed = True
return False
module._pin = comfy_aimdo.torch.hostbuf_to_tensor(hostbuf)
module._pin_hostbuf = hostbuf
comfy.model_management.TOTAL_PINNED_MEMORY += size
return True
def unpin_memory(module):
if get_pin(module) is None:
return 0
size = module._pin.numel() * module._pin.element_size()
comfy.model_management.unpin_memory(module._pin)
comfy.model_management.TOTAL_PINNED_MEMORY -= size
if comfy.model_management.TOTAL_PINNED_MEMORY < 0:
comfy.model_management.TOTAL_PINNED_MEMORY = 0
del module._pin
del module._pin_hostbuf
return size

View File

@@ -43,6 +43,18 @@ except ImportError as e:
def get_layout_class(name):
return None
_CK_MXFP8_AVAILABLE = False
if _CK_AVAILABLE:
try:
from comfy_kitchen.tensor import TensorCoreMXFP8Layout as _CKMxfp8Layout
_CK_MXFP8_AVAILABLE = True
except ImportError:
logging.warning("comfy_kitchen does not support MXFP8, please update comfy_kitchen.")
if not _CK_MXFP8_AVAILABLE:
class _CKMxfp8Layout:
pass
import comfy.float
# ==============================================================================
@@ -84,6 +96,31 @@ class _TensorCoreFP8LayoutBase(_CKFp8Layout):
return qdata, params
class TensorCoreMXFP8Layout(_CKMxfp8Layout):
@classmethod
def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False):
if tensor.dim() != 2:
raise ValueError(f"MXFP8 requires 2D tensor, got {tensor.dim()}D")
orig_dtype = tensor.dtype
orig_shape = tuple(tensor.shape)
padded_shape = cls.get_padded_shape(orig_shape)
needs_padding = padded_shape != orig_shape
if stochastic_rounding > 0:
qdata, block_scale = comfy.float.stochastic_round_quantize_mxfp8_by_block(tensor, pad_32x=needs_padding, seed=stochastic_rounding)
else:
qdata, block_scale = ck.quantize_mxfp8(tensor, pad_32x=needs_padding)
params = cls.Params(
scale=block_scale,
orig_dtype=orig_dtype,
orig_shape=orig_shape,
)
return qdata, params
class TensorCoreNVFP4Layout(_CKNvfp4Layout):
@classmethod
def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False):
@@ -137,6 +174,8 @@ register_layout_class("TensorCoreFP8Layout", TensorCoreFP8Layout)
register_layout_class("TensorCoreFP8E4M3Layout", TensorCoreFP8E4M3Layout)
register_layout_class("TensorCoreFP8E5M2Layout", TensorCoreFP8E5M2Layout)
register_layout_class("TensorCoreNVFP4Layout", TensorCoreNVFP4Layout)
if _CK_MXFP8_AVAILABLE:
register_layout_class("TensorCoreMXFP8Layout", TensorCoreMXFP8Layout)
QUANT_ALGOS = {
"float8_e4m3fn": {
@@ -157,6 +196,14 @@ QUANT_ALGOS = {
},
}
if _CK_MXFP8_AVAILABLE:
QUANT_ALGOS["mxfp8"] = {
"storage_t": torch.float8_e4m3fn,
"parameters": {"weight_scale", "input_scale"},
"comfy_tensor_layout": "TensorCoreMXFP8Layout",
"group_size": 32,
}
# ==============================================================================
# Re-exports for backward compatibility

View File

@@ -871,13 +871,16 @@ class VAE:
pixels = torch.nn.functional.pad(pixels, (0, self.output_channels - pixels.shape[-1]), mode=mode, value=value)
return pixels
def vae_output_dtype(self):
return model_management.intermediate_dtype()
def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16):
steps = samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap)
steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x // 2, tile_y * 2, overlap)
steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x * 2, tile_y // 2, overlap)
pbar = comfy.utils.ProgressBar(steps)
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
output = self.process_output(
(comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) +
comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) +
@@ -887,16 +890,16 @@ class VAE:
def decode_tiled_1d(self, samples, tile_x=256, overlap=32):
if samples.ndim == 3:
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
else:
og_shape = samples.shape
samples = samples.reshape((og_shape[0], og_shape[1] * og_shape[2], -1))
decode_fn = lambda a: self.first_stage_model.decode(a.reshape((-1, og_shape[1], og_shape[2], a.shape[-1])).to(self.vae_dtype).to(self.device)).float()
decode_fn = lambda a: self.first_stage_model.decode(a.reshape((-1, og_shape[1], og_shape[2], a.shape[-1])).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device))
def decode_tiled_3d(self, samples, tile_t=999, tile_x=32, tile_y=32, overlap=(1, 8, 8)):
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, index_formulas=self.upscale_index_formula, output_device=self.output_device))
def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
@@ -905,7 +908,7 @@ class VAE:
steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap)
pbar = comfy.utils.ProgressBar(steps)
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
@@ -914,7 +917,7 @@ class VAE:
def encode_tiled_1d(self, samples, tile_x=256 * 2048, overlap=64 * 2048):
if self.latent_dim == 1:
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
out_channels = self.latent_channels
upscale_amount = 1 / self.downscale_ratio
else:
@@ -923,7 +926,7 @@ class VAE:
tile_x = tile_x // extra_channel_size
overlap = overlap // extra_channel_size
upscale_amount = 1 / self.downscale_ratio
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).reshape(1, out_channels, -1).float()
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).reshape(1, out_channels, -1).to(dtype=self.vae_output_dtype())
out = comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=self.output_device)
if self.latent_dim == 1:
@@ -932,7 +935,7 @@ class VAE:
return out.reshape(samples.shape[0], self.latent_channels, extra_channel_size, -1)
def encode_tiled_3d(self, samples, tile_t=9999, tile_x=512, tile_y=512, overlap=(1, 64, 64)):
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device)
def decode(self, samples_in, vae_options={}):
@@ -950,9 +953,9 @@ class VAE:
for x in range(0, samples_in.shape[0], batch_number):
samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
out = self.process_output(self.first_stage_model.decode(samples, **vae_options).to(self.output_device).float())
out = self.process_output(self.first_stage_model.decode(samples, **vae_options).to(self.output_device).to(dtype=self.vae_output_dtype()))
if pixel_samples is None:
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
pixel_samples[x:x+batch_number] = out
except Exception as e:
model_management.raise_non_oom(e)
@@ -1025,9 +1028,9 @@ class VAE:
samples = None
for x in range(0, pixel_samples.shape[0], batch_number):
pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype).to(self.device)
out = self.first_stage_model.encode(pixels_in).to(self.output_device).float()
out = self.first_stage_model.encode(pixels_in).to(self.output_device).to(dtype=self.vae_output_dtype())
if samples is None:
samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
samples[x:x + batch_number] = out
except Exception as e:

View File

@@ -20,6 +20,8 @@
import torch
import math
import struct
import ctypes
import os
import comfy.memory_management
import safetensors.torch
import numpy as np
@@ -32,7 +34,7 @@ from einops import rearrange
from comfy.cli_args import args
import json
import time
import mmap
import threading
import warnings
MMAP_TORCH_FILES = args.mmap_torch_files
@@ -81,14 +83,17 @@ _TYPES = {
}
def load_safetensors(ckpt):
f = open(ckpt, "rb")
mapping = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)
mv = memoryview(mapping)
import comfy_aimdo.model_mmap
header_size = struct.unpack("<Q", mapping[:8])[0]
header = json.loads(mapping[8:8+header_size].decode("utf-8"))
f = open(ckpt, "rb", buffering=0)
model_mmap = comfy_aimdo.model_mmap.ModelMMAP(ckpt)
file_size = os.path.getsize(ckpt)
mv = memoryview((ctypes.c_uint8 * file_size).from_address(model_mmap.get()))
mv = mv[8 + header_size:]
header_size = struct.unpack("<Q", mv[:8])[0]
header = json.loads(mv[8:8 + header_size].tobytes().decode("utf-8"))
mv = mv[(data_base_offset := 8 + header_size):]
sd = {}
for name, info in header.items():
@@ -102,7 +107,14 @@ def load_safetensors(ckpt):
with warnings.catch_warnings():
#We are working with read-only RAM by design
warnings.filterwarnings("ignore", message="The given buffer is not writable")
sd[name] = torch.frombuffer(mv[start:end], dtype=_TYPES[info["dtype"]]).view(info["shape"])
tensor = torch.frombuffer(mv[start:end], dtype=_TYPES[info["dtype"]]).view(info["shape"])
storage = tensor.untyped_storage()
setattr(storage,
"_comfy_tensor_file_slice",
comfy.memory_management.TensorFileSlice(f, threading.get_ident(), data_base_offset + start, end - start))
setattr(storage, "_comfy_tensor_mmap_refs", (model_mmap, mv))
setattr(storage, "_comfy_tensor_mmap_touched", False)
sd[name] = tensor
return sd, header.get("__metadata__", {}),

View File

@@ -25,6 +25,7 @@ class ComfyAPI_latest(ComfyAPIBase):
super().__init__()
self.node_replacement = self.NodeReplacement()
self.execution = self.Execution()
self.caching = self.Caching()
class NodeReplacement(ProxiedSingleton):
async def register(self, node_replace: io.NodeReplace) -> None:
@@ -84,6 +85,36 @@ class ComfyAPI_latest(ComfyAPIBase):
image=to_display,
)
class Caching(ProxiedSingleton):
"""
External cache provider API for sharing cached node outputs
across ComfyUI instances.
Example::
from comfy_api.latest import Caching
class MyCacheProvider(Caching.CacheProvider):
async def on_lookup(self, context):
... # check external storage
async def on_store(self, context, value):
... # store to external storage
Caching.register_provider(MyCacheProvider())
"""
from ._caching import CacheProvider, CacheContext, CacheValue
async def register_provider(self, provider: "ComfyAPI_latest.Caching.CacheProvider") -> None:
"""Register an external cache provider. Providers are called in registration order."""
from comfy_execution.cache_provider import register_cache_provider
register_cache_provider(provider)
async def unregister_provider(self, provider: "ComfyAPI_latest.Caching.CacheProvider") -> None:
"""Unregister a previously registered cache provider."""
from comfy_execution.cache_provider import unregister_cache_provider
unregister_cache_provider(provider)
class ComfyExtension(ABC):
async def on_load(self) -> None:
"""
@@ -116,6 +147,9 @@ class Types:
VOXEL = VOXEL
File3D = File3D
Caching = ComfyAPI_latest.Caching
ComfyAPI = ComfyAPI_latest
# Create a synchronous version of the API
@@ -135,6 +169,7 @@ __all__ = [
"Input",
"InputImpl",
"Types",
"Caching",
"ComfyExtension",
"io",
"IO",

View File

@@ -0,0 +1,42 @@
from abc import ABC, abstractmethod
from typing import Optional
from dataclasses import dataclass
@dataclass
class CacheContext:
node_id: str
class_type: str
cache_key_hash: str # SHA256 hex digest
@dataclass
class CacheValue:
outputs: list
ui: dict = None
class CacheProvider(ABC):
"""Abstract base class for external cache providers.
Exceptions from provider methods are caught by the caller and never break execution.
"""
@abstractmethod
async def on_lookup(self, context: CacheContext) -> Optional[CacheValue]:
"""Called on local cache miss. Return CacheValue if found, None otherwise."""
pass
@abstractmethod
async def on_store(self, context: CacheContext, value: CacheValue) -> None:
"""Called after local store. Dispatched via asyncio.create_task."""
pass
def should_cache(self, context: CacheContext, value: Optional[CacheValue] = None) -> bool:
"""Return False to skip external caching for this node. Default: True."""
return True
def on_prompt_start(self, prompt_id: str) -> None:
pass
def on_prompt_end(self, prompt_id: str) -> None:
pass

View File

@@ -272,7 +272,7 @@ class VideoFromFile(VideoInput):
has_first_frame = False
for frame in frames:
offset_seconds = start_time - frame.pts * audio_stream.time_base
to_skip = int(offset_seconds * audio_stream.sample_rate)
to_skip = max(0, int(offset_seconds * audio_stream.sample_rate))
if to_skip < frame.samples:
has_first_frame = True
break
@@ -280,7 +280,7 @@ class VideoFromFile(VideoInput):
audio_frames.append(frame.to_ndarray()[..., to_skip:])
for frame in frames:
if frame.time > start_time + self.__duration:
if self.__duration and frame.time > start_time + self.__duration:
break
audio_frames.append(frame.to_ndarray()) # shape: (channels, samples)
if len(audio_frames) > 0:

View File

@@ -297,7 +297,7 @@ class Float(ComfyTypeIO):
'''Float input.'''
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
default: float=None, min: float=None, max: float=None, step: float=None, round: float=None,
display_mode: NumberDisplay=None, gradient_stops: list[list[float]]=None,
display_mode: NumberDisplay=None, gradient_stops: list[dict]=None,
socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None):
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link, advanced)
self.min = min

View File

@@ -0,0 +1,68 @@
from pydantic import BaseModel, Field
class RevePostprocessingOperation(BaseModel):
process: str = Field(..., description="The postprocessing operation: upscale or remove_background.")
upscale_factor: int | None = Field(
None,
description="Upscale factor (2, 3, or 4). Only used when process is upscale.",
ge=2,
le=4,
)
class ReveImageCreateRequest(BaseModel):
prompt: str = Field(...)
aspect_ratio: str | None = Field(...)
version: str = Field(...)
test_time_scaling: int = Field(
...,
description="If included, the model will spend more effort making better images. Values between 1 and 15.",
ge=1,
le=15,
)
postprocessing: list[RevePostprocessingOperation] | None = Field(
None, description="Optional postprocessing operations to apply after generation."
)
class ReveImageEditRequest(BaseModel):
edit_instruction: str = Field(...)
reference_image: str = Field(..., description="A base64 encoded image to use as reference for the edit.")
aspect_ratio: str | None = Field(...)
version: str = Field(...)
test_time_scaling: int | None = Field(
...,
description="If included, the model will spend more effort making better images. Values between 1 and 15.",
ge=1,
le=15,
)
postprocessing: list[RevePostprocessingOperation] | None = Field(
None, description="Optional postprocessing operations to apply after generation."
)
class ReveImageRemixRequest(BaseModel):
prompt: str = Field(...)
reference_images: list[str] = Field(..., description="A list of 1-6 base64 encoded reference images.")
aspect_ratio: str | None = Field(...)
version: str = Field(...)
test_time_scaling: int | None = Field(
...,
description="If included, the model will spend more effort making better images. Values between 1 and 15.",
ge=1,
le=15,
)
postprocessing: list[RevePostprocessingOperation] | None = Field(
None, description="Optional postprocessing operations to apply after generation."
)
class ReveImageResponse(BaseModel):
image: str | None = Field(None, description="The base64 encoded image data.")
request_id: str | None = Field(None, description="A unique id for the request.")
credits_used: float | None = Field(None, description="The number of credits used for this request.")
version: str | None = Field(None, description="The specific model version used.")
content_violation: bool | None = Field(
None, description="Indicates whether the generated image violates the content policy."
)

View File

@@ -1,3 +1,7 @@
import zipfile
from io import BytesIO
import torch
from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension, Input, Types
@@ -17,7 +21,10 @@ from comfy_api_nodes.apis.hunyuan3d import (
)
from comfy_api_nodes.util import (
ApiEndpoint,
bytesio_to_image_tensor,
download_url_to_bytesio,
download_url_to_file_3d,
download_url_to_image_tensor,
downscale_image_tensor_by_max_side,
poll_op,
sync_op,
@@ -36,6 +43,68 @@ def _is_tencent_rate_limited(status: int, body: object) -> bool:
)
class ObjZipResult:
__slots__ = ("obj", "texture", "metallic", "normal", "roughness")
def __init__(
self,
obj: Types.File3D,
texture: Input.Image | None = None,
metallic: Input.Image | None = None,
normal: Input.Image | None = None,
roughness: Input.Image | None = None,
):
self.obj = obj
self.texture = texture
self.metallic = metallic
self.normal = normal
self.roughness = roughness
async def download_and_extract_obj_zip(url: str) -> ObjZipResult:
"""The Tencent API returns OBJ results as ZIP archives containing the .obj mesh, and texture images.
When PBR is enabled, the ZIP may contain additional metallic, normal, and roughness maps
identified by their filename suffixes.
"""
data = BytesIO()
await download_url_to_bytesio(url, data)
data.seek(0)
if not zipfile.is_zipfile(data):
data.seek(0)
return ObjZipResult(obj=Types.File3D(source=data, file_format="obj"))
data.seek(0)
obj_bytes = None
textures: dict[str, Input.Image] = {}
with zipfile.ZipFile(data) as zf:
for name in zf.namelist():
lower = name.lower()
if lower.endswith(".obj"):
obj_bytes = zf.read(name)
elif any(lower.endswith(ext) for ext in (".png", ".jpg", ".jpeg", ".bmp", ".tiff", ".webp")):
stem = lower.rsplit(".", 1)[0]
tensor = bytesio_to_image_tensor(BytesIO(zf.read(name)), mode="RGB")
matched_key = "texture"
for suffix, key in {
"_metallic": "metallic",
"_normal": "normal",
"_roughness": "roughness",
}.items():
if stem.endswith(suffix):
matched_key = key
break
textures[matched_key] = tensor
if obj_bytes is None:
raise ValueError("ZIP archive does not contain an OBJ file.")
return ObjZipResult(
obj=Types.File3D(source=BytesIO(obj_bytes), file_format="obj"),
texture=textures.get("texture"),
metallic=textures.get("metallic"),
normal=textures.get("normal"),
roughness=textures.get("roughness"),
)
def get_file_from_response(
response_objs: list[ResultFile3D], file_type: str, raise_if_not_found: bool = True
) -> ResultFile3D | None:
@@ -93,6 +162,7 @@ class TencentTextToModelNode(IO.ComfyNode):
IO.String.Output(display_name="model_file"), # for backward compatibility only
IO.File3DGLB.Output(display_name="GLB"),
IO.File3DOBJ.Output(display_name="OBJ"),
IO.Image.Output(display_name="texture_image"),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
@@ -151,14 +221,14 @@ class TencentTextToModelNode(IO.ComfyNode):
response_model=To3DProTaskResultResponse,
status_extractor=lambda r: r.Status,
)
obj_result = await download_and_extract_obj_zip(get_file_from_response(result.ResultFile3Ds, "obj").Url)
return IO.NodeOutput(
f"{task_id}.glb",
await download_url_to_file_3d(
get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb", task_id=task_id
),
await download_url_to_file_3d(
get_file_from_response(result.ResultFile3Ds, "obj").Url, "obj", task_id=task_id
),
obj_result.obj,
obj_result.texture,
)
@@ -211,6 +281,10 @@ class TencentImageToModelNode(IO.ComfyNode):
IO.String.Output(display_name="model_file"), # for backward compatibility only
IO.File3DGLB.Output(display_name="GLB"),
IO.File3DOBJ.Output(display_name="OBJ"),
IO.Image.Output(display_name="texture_image"),
IO.Image.Output(display_name="optional_metallic"),
IO.Image.Output(display_name="optional_normal"),
IO.Image.Output(display_name="optional_roughness"),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
@@ -304,14 +378,17 @@ class TencentImageToModelNode(IO.ComfyNode):
response_model=To3DProTaskResultResponse,
status_extractor=lambda r: r.Status,
)
obj_result = await download_and_extract_obj_zip(get_file_from_response(result.ResultFile3Ds, "obj").Url)
return IO.NodeOutput(
f"{task_id}.glb",
await download_url_to_file_3d(
get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb", task_id=task_id
),
await download_url_to_file_3d(
get_file_from_response(result.ResultFile3Ds, "obj").Url, "obj", task_id=task_id
),
obj_result.obj,
obj_result.texture,
obj_result.metallic if obj_result.metallic is not None else torch.zeros(1, 1, 1, 3),
obj_result.normal if obj_result.normal is not None else torch.zeros(1, 1, 1, 3),
obj_result.roughness if obj_result.roughness is not None else torch.zeros(1, 1, 1, 3),
)
@@ -431,7 +508,8 @@ class Tencent3DTextureEditNode(IO.ComfyNode):
],
outputs=[
IO.File3DGLB.Output(display_name="GLB"),
IO.File3DFBX.Output(display_name="FBX"),
IO.File3DOBJ.Output(display_name="OBJ"),
IO.Image.Output(display_name="texture_image"),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
@@ -480,7 +558,8 @@ class Tencent3DTextureEditNode(IO.ComfyNode):
)
return IO.NodeOutput(
await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb"),
await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "fbx").Url, "fbx"),
await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "obj").Url, "obj"),
await download_url_to_image_tensor(get_file_from_response(result.ResultFile3Ds, "texture_image").Url),
)
@@ -654,7 +733,7 @@ class TencentHunyuan3DExtension(ComfyExtension):
TencentTextToModelNode,
TencentImageToModelNode,
TencentModelTo3DUVNode,
# Tencent3DTextureEditNode,
Tencent3DTextureEditNode,
Tencent3DPartNode,
TencentSmartTopologyNode,
]

View File

@@ -1459,6 +1459,7 @@ class OmniProEditVideoNode(IO.ComfyNode):
node_id="KlingOmniProEditVideoNode",
display_name="Kling 3.0 Omni Edit Video",
category="api node/video/Kling",
essentials_category="Video Generation",
description="Edit an existing video with the latest model from Kling.",
inputs=[
IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-video-o1"]),

View File

@@ -833,6 +833,7 @@ class RecraftVectorizeImageNode(IO.ComfyNode):
node_id="RecraftVectorizeImageNode",
display_name="Recraft Vectorize Image",
category="api node/image/Recraft",
essentials_category="Image Tools",
description="Generates SVG synchronously from an input image.",
inputs=[
IO.Image.Input("image"),

View File

@@ -0,0 +1,395 @@
from io import BytesIO
from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api_nodes.apis.reve import (
ReveImageCreateRequest,
ReveImageEditRequest,
ReveImageRemixRequest,
RevePostprocessingOperation,
)
from comfy_api_nodes.util import (
ApiEndpoint,
bytesio_to_image_tensor,
sync_op_raw,
tensor_to_base64_string,
validate_string,
)
def _build_postprocessing(upscale: dict, remove_background: bool) -> list[RevePostprocessingOperation] | None:
ops = []
if upscale["upscale"] == "enabled":
ops.append(
RevePostprocessingOperation(
process="upscale",
upscale_factor=upscale["upscale_factor"],
)
)
if remove_background:
ops.append(RevePostprocessingOperation(process="remove_background"))
return ops or None
def _postprocessing_inputs():
return [
IO.DynamicCombo.Input(
"upscale",
options=[
IO.DynamicCombo.Option("disabled", []),
IO.DynamicCombo.Option(
"enabled",
[
IO.Int.Input(
"upscale_factor",
default=2,
min=2,
max=4,
step=1,
tooltip="Upscale factor (2x, 3x, or 4x).",
),
],
),
],
tooltip="Upscale the generated image. May add additional cost.",
),
IO.Boolean.Input(
"remove_background",
default=False,
tooltip="Remove the background from the generated image. May add additional cost.",
),
]
def _reve_price_extractor(headers: dict) -> float | None:
credits_used = headers.get("x-reve-credits-used")
if credits_used is not None:
return float(credits_used) / 524.48
return None
def _reve_response_header_validator(headers: dict) -> None:
error_code = headers.get("x-reve-error-code")
if error_code:
raise ValueError(f"Reve API error: {error_code}")
if headers.get("x-reve-content-violation", "").lower() == "true":
raise ValueError("The generated image was flagged for content policy violation.")
def _model_inputs(versions: list[str], aspect_ratios: list[str]):
return [
IO.DynamicCombo.Option(
version,
[
IO.Combo.Input(
"aspect_ratio",
options=aspect_ratios,
tooltip="Aspect ratio of the output image.",
),
IO.Int.Input(
"test_time_scaling",
default=1,
min=1,
max=5,
step=1,
tooltip="Higher values produce better images but cost more credits.",
advanced=True,
),
],
)
for version in versions
]
class ReveImageCreateNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="ReveImageCreateNode",
display_name="Reve Image Create",
category="api node/image/Reve",
description="Generate images from text descriptions using Reve.",
inputs=[
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Text description of the desired image. Maximum 2560 characters.",
),
IO.DynamicCombo.Input(
"model",
options=_model_inputs(
["reve-create@20250915"],
aspect_ratios=["3:2", "16:9", "9:16", "2:3", "4:3", "3:4", "1:1"],
),
tooltip="Model version to use for generation.",
),
*_postprocessing_inputs(),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
control_after_generate=True,
tooltip="Seed controls whether the node should re-run; "
"results are non-deterministic regardless of seed.",
),
],
outputs=[IO.Image.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.03432,"format":{"approximate":true,"note":"(base)"}}""",
),
)
@classmethod
async def execute(
cls,
prompt: str,
model: dict,
upscale: dict,
remove_background: bool,
seed: int,
) -> IO.NodeOutput:
validate_string(prompt, min_length=1, max_length=2560)
response = await sync_op_raw(
cls,
ApiEndpoint(
path="/proxy/reve/v1/image/create",
method="POST",
headers={"Accept": "image/webp"},
),
as_binary=True,
price_extractor=_reve_price_extractor,
response_header_validator=_reve_response_header_validator,
data=ReveImageCreateRequest(
prompt=prompt,
aspect_ratio=model["aspect_ratio"],
version=model["model"],
test_time_scaling=model["test_time_scaling"],
postprocessing=_build_postprocessing(upscale, remove_background),
),
)
return IO.NodeOutput(bytesio_to_image_tensor(BytesIO(response)))
class ReveImageEditNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="ReveImageEditNode",
display_name="Reve Image Edit",
category="api node/image/Reve",
description="Edit images using natural language instructions with Reve.",
inputs=[
IO.Image.Input("image", tooltip="The image to edit."),
IO.String.Input(
"edit_instruction",
multiline=True,
default="",
tooltip="Text description of how to edit the image. Maximum 2560 characters.",
),
IO.DynamicCombo.Input(
"model",
options=_model_inputs(
["reve-edit@20250915", "reve-edit-fast@20251030"],
aspect_ratios=["auto", "16:9", "9:16", "3:2", "2:3", "4:3", "3:4", "1:1"],
),
tooltip="Model version to use for editing.",
),
*_postprocessing_inputs(),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
control_after_generate=True,
tooltip="Seed controls whether the node should re-run; "
"results are non-deterministic regardless of seed.",
),
],
outputs=[IO.Image.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(
widgets=["model"],
),
expr="""
(
$isFast := $contains(widgets.model, "fast");
$base := $isFast ? 0.01001 : 0.0572;
{"type": "usd", "usd": $base, "format": {"approximate": true, "note": "(base)"}}
)
""",
),
)
@classmethod
async def execute(
cls,
image: Input.Image,
edit_instruction: str,
model: dict,
upscale: dict,
remove_background: bool,
seed: int,
) -> IO.NodeOutput:
validate_string(edit_instruction, min_length=1, max_length=2560)
tts = model["test_time_scaling"]
ar = model["aspect_ratio"]
response = await sync_op_raw(
cls,
ApiEndpoint(
path="/proxy/reve/v1/image/edit",
method="POST",
headers={"Accept": "image/webp"},
),
as_binary=True,
price_extractor=_reve_price_extractor,
response_header_validator=_reve_response_header_validator,
data=ReveImageEditRequest(
edit_instruction=edit_instruction,
reference_image=tensor_to_base64_string(image),
aspect_ratio=ar if ar != "auto" else None,
version=model["model"],
test_time_scaling=tts if tts and tts > 1 else None,
postprocessing=_build_postprocessing(upscale, remove_background),
),
)
return IO.NodeOutput(bytesio_to_image_tensor(BytesIO(response)))
class ReveImageRemixNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="ReveImageRemixNode",
display_name="Reve Image Remix",
category="api node/image/Reve",
description="Combine reference images with text prompts to create new images using Reve.",
inputs=[
IO.Autogrow.Input(
"reference_images",
template=IO.Autogrow.TemplatePrefix(
IO.Image.Input("image"),
prefix="image_",
min=1,
max=6,
),
),
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Text description of the desired image. "
"May include XML img tags to reference specific images by index, "
"e.g. <img>0</img>, <img>1</img>, etc.",
),
IO.DynamicCombo.Input(
"model",
options=_model_inputs(
["reve-remix@20250915", "reve-remix-fast@20251030"],
aspect_ratios=["auto", "16:9", "9:16", "3:2", "2:3", "4:3", "3:4", "1:1"],
),
tooltip="Model version to use for remixing.",
),
*_postprocessing_inputs(),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
control_after_generate=True,
tooltip="Seed controls whether the node should re-run; "
"results are non-deterministic regardless of seed.",
),
],
outputs=[IO.Image.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(
widgets=["model"],
),
expr="""
(
$isFast := $contains(widgets.model, "fast");
$base := $isFast ? 0.01001 : 0.0572;
{"type": "usd", "usd": $base, "format": {"approximate": true, "note": "(base)"}}
)
""",
),
)
@classmethod
async def execute(
cls,
reference_images: IO.Autogrow.Type,
prompt: str,
model: dict,
upscale: dict,
remove_background: bool,
seed: int,
) -> IO.NodeOutput:
validate_string(prompt, min_length=1, max_length=2560)
if not reference_images:
raise ValueError("At least one reference image is required.")
ref_base64_list = []
for key in reference_images:
ref_base64_list.append(tensor_to_base64_string(reference_images[key]))
if len(ref_base64_list) > 6:
raise ValueError("Maximum 6 reference images are allowed.")
tts = model["test_time_scaling"]
ar = model["aspect_ratio"]
response = await sync_op_raw(
cls,
ApiEndpoint(
path="/proxy/reve/v1/image/remix",
method="POST",
headers={"Accept": "image/webp"},
),
as_binary=True,
price_extractor=_reve_price_extractor,
response_header_validator=_reve_response_header_validator,
data=ReveImageRemixRequest(
prompt=prompt,
reference_images=ref_base64_list,
aspect_ratio=ar if ar != "auto" else None,
version=model["model"],
test_time_scaling=tts if tts and tts > 1 else None,
postprocessing=_build_postprocessing(upscale, remove_background),
),
)
return IO.NodeOutput(bytesio_to_image_tensor(BytesIO(response)))
class ReveExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
ReveImageCreateNode,
ReveImageEditNode,
ReveImageRemixNode,
]
async def comfy_entrypoint() -> ReveExtension:
return ReveExtension()

View File

@@ -67,6 +67,7 @@ class _RequestConfig:
progress_origin_ts: float | None = None
price_extractor: Callable[[dict[str, Any]], float | None] | None = None
is_rate_limited: Callable[[int, Any], bool] | None = None
response_header_validator: Callable[[dict[str, str]], None] | None = None
@dataclass
@@ -202,11 +203,13 @@ async def sync_op_raw(
monitor_progress: bool = True,
max_retries_on_rate_limit: int = 16,
is_rate_limited: Callable[[int, Any], bool] | None = None,
response_header_validator: Callable[[dict[str, str]], None] | None = None,
) -> dict[str, Any] | bytes:
"""
Make a single network request.
- If as_binary=False (default): returns JSON dict (or {'_raw': '<text>'} if non-JSON).
- If as_binary=True: returns bytes.
- response_header_validator: optional callback receiving response headers dict
"""
if isinstance(data, BaseModel):
data = data.model_dump(exclude_none=True)
@@ -232,6 +235,7 @@ async def sync_op_raw(
price_extractor=price_extractor,
max_retries_on_rate_limit=max_retries_on_rate_limit,
is_rate_limited=is_rate_limited,
response_header_validator=response_header_validator,
)
return await _request_base(cfg, expect_binary=as_binary)
@@ -769,6 +773,12 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
cfg.node_cls, cfg.wait_label, int(now - start_time), cfg.estimated_total
)
bytes_payload = bytes(buff)
resp_headers = {k.lower(): v for k, v in resp.headers.items()}
if cfg.price_extractor:
with contextlib.suppress(Exception):
extracted_price = cfg.price_extractor(resp_headers)
if cfg.response_header_validator:
cfg.response_header_validator(resp_headers)
operation_succeeded = True
final_elapsed_seconds = int(time.monotonic() - start_time)
request_logger.log_request_response(
@@ -776,7 +786,7 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
request_method=method,
request_url=url,
response_status_code=resp.status,
response_headers=dict(resp.headers),
response_headers=resp_headers,
response_content=bytes_payload,
)
return bytes_payload

View File

@@ -0,0 +1,138 @@
from typing import Any, Optional, Tuple, List
import hashlib
import json
import logging
import threading
# Public types — source of truth is comfy_api.latest._caching
from comfy_api.latest._caching import CacheProvider, CacheContext, CacheValue # noqa: F401 (re-exported)
_logger = logging.getLogger(__name__)
_providers: List[CacheProvider] = []
_providers_lock = threading.Lock()
_providers_snapshot: Tuple[CacheProvider, ...] = ()
def register_cache_provider(provider: CacheProvider) -> None:
"""Register an external cache provider. Providers are called in registration order."""
global _providers_snapshot
with _providers_lock:
if provider in _providers:
_logger.warning(f"Provider {provider.__class__.__name__} already registered")
return
_providers.append(provider)
_providers_snapshot = tuple(_providers)
_logger.debug(f"Registered cache provider: {provider.__class__.__name__}")
def unregister_cache_provider(provider: CacheProvider) -> None:
global _providers_snapshot
with _providers_lock:
try:
_providers.remove(provider)
_providers_snapshot = tuple(_providers)
_logger.debug(f"Unregistered cache provider: {provider.__class__.__name__}")
except ValueError:
_logger.warning(f"Provider {provider.__class__.__name__} was not registered")
def _get_cache_providers() -> Tuple[CacheProvider, ...]:
return _providers_snapshot
def _has_cache_providers() -> bool:
return bool(_providers_snapshot)
def _clear_cache_providers() -> None:
global _providers_snapshot
with _providers_lock:
_providers.clear()
_providers_snapshot = ()
def _canonicalize(obj: Any) -> Any:
# Convert to canonical JSON-serializable form with deterministic ordering.
# Frozensets have non-deterministic iteration order between Python sessions.
# Raises ValueError for non-cacheable types (Unhashable, unknown) so that
# _serialize_cache_key returns None and external caching is skipped.
if isinstance(obj, frozenset):
return ("__frozenset__", sorted(
[_canonicalize(item) for item in obj],
key=lambda x: json.dumps(x, sort_keys=True)
))
elif isinstance(obj, set):
return ("__set__", sorted(
[_canonicalize(item) for item in obj],
key=lambda x: json.dumps(x, sort_keys=True)
))
elif isinstance(obj, tuple):
return ("__tuple__", [_canonicalize(item) for item in obj])
elif isinstance(obj, list):
return [_canonicalize(item) for item in obj]
elif isinstance(obj, dict):
return {"__dict__": sorted(
[[_canonicalize(k), _canonicalize(v)] for k, v in obj.items()],
key=lambda x: json.dumps(x, sort_keys=True)
)}
elif isinstance(obj, (int, float, str, bool, type(None))):
return (type(obj).__name__, obj)
elif isinstance(obj, bytes):
return ("__bytes__", obj.hex())
else:
raise ValueError(f"Cannot canonicalize type: {type(obj).__name__}")
def _serialize_cache_key(cache_key: Any) -> Optional[str]:
# Returns deterministic SHA256 hex digest, or None on failure.
# Uses JSON (not pickle) because pickle is non-deterministic across sessions.
try:
canonical = _canonicalize(cache_key)
json_str = json.dumps(canonical, sort_keys=True, separators=(',', ':'))
return hashlib.sha256(json_str.encode('utf-8')).hexdigest()
except Exception as e:
_logger.warning(f"Failed to serialize cache key: {e}")
return None
def _contains_self_unequal(obj: Any) -> bool:
# Local cache matches by ==. Values where not (x == x) (NaN, etc.) will
# never hit locally, but serialized form would match externally. Skip these.
try:
if not (obj == obj):
return True
except Exception:
return True
if isinstance(obj, (frozenset, tuple, list, set)):
return any(_contains_self_unequal(item) for item in obj)
if isinstance(obj, dict):
return any(_contains_self_unequal(k) or _contains_self_unequal(v) for k, v in obj.items())
if hasattr(obj, 'value'):
return _contains_self_unequal(obj.value)
return False
def _estimate_value_size(value: CacheValue) -> int:
try:
import torch
except ImportError:
return 0
total = 0
def estimate(obj):
nonlocal total
if isinstance(obj, torch.Tensor):
total += obj.numel() * obj.element_size()
elif isinstance(obj, dict):
for v in obj.values():
estimate(v)
elif isinstance(obj, (list, tuple)):
for item in obj:
estimate(item)
for output in value.outputs:
estimate(output)
return total

View File

@@ -1,3 +1,4 @@
import asyncio
import bisect
import gc
import itertools
@@ -147,13 +148,15 @@ class CacheKeySetInputSignature(CacheKeySet):
self.get_ordered_ancestry_internal(dynprompt, ancestor_id, ancestors, order_mapping)
class BasicCache:
def __init__(self, key_class):
def __init__(self, key_class, enable_providers=False):
self.key_class = key_class
self.initialized = False
self.enable_providers = enable_providers
self.dynprompt: DynamicPrompt
self.cache_key_set: CacheKeySet
self.cache = {}
self.subcaches = {}
self._pending_store_tasks: set = set()
async def set_prompt(self, dynprompt, node_ids, is_changed_cache):
self.dynprompt = dynprompt
@@ -196,18 +199,138 @@ class BasicCache:
def poll(self, **kwargs):
pass
def _set_immediate(self, node_id, value):
assert self.initialized
cache_key = self.cache_key_set.get_data_key(node_id)
self.cache[cache_key] = value
def _get_immediate(self, node_id):
def get_local(self, node_id):
if not self.initialized:
return None
cache_key = self.cache_key_set.get_data_key(node_id)
if cache_key in self.cache:
return self.cache[cache_key]
else:
return None
def set_local(self, node_id, value):
assert self.initialized
cache_key = self.cache_key_set.get_data_key(node_id)
self.cache[cache_key] = value
async def _set_immediate(self, node_id, value):
assert self.initialized
cache_key = self.cache_key_set.get_data_key(node_id)
self.cache[cache_key] = value
await self._notify_providers_store(node_id, cache_key, value)
async def _get_immediate(self, node_id):
if not self.initialized:
return None
cache_key = self.cache_key_set.get_data_key(node_id)
if cache_key in self.cache:
return self.cache[cache_key]
external_result = await self._check_providers_lookup(node_id, cache_key)
if external_result is not None:
self.cache[cache_key] = external_result
return external_result
return None
async def _notify_providers_store(self, node_id, cache_key, value):
from comfy_execution.cache_provider import (
_has_cache_providers, _get_cache_providers,
CacheValue, _contains_self_unequal, _logger
)
if not self.enable_providers:
return
if not _has_cache_providers():
return
if not self._is_external_cacheable_value(value):
return
if _contains_self_unequal(cache_key):
return
context = self._build_context(node_id, cache_key)
if context is None:
return
cache_value = CacheValue(outputs=value.outputs, ui=value.ui)
for provider in _get_cache_providers():
try:
if provider.should_cache(context, cache_value):
task = asyncio.create_task(self._safe_provider_store(provider, context, cache_value))
self._pending_store_tasks.add(task)
task.add_done_callback(self._pending_store_tasks.discard)
except Exception as e:
_logger.warning(f"Cache provider {provider.__class__.__name__} error on store: {e}")
@staticmethod
async def _safe_provider_store(provider, context, cache_value):
from comfy_execution.cache_provider import _logger
try:
await provider.on_store(context, cache_value)
except Exception as e:
_logger.warning(f"Cache provider {provider.__class__.__name__} async store error: {e}")
async def _check_providers_lookup(self, node_id, cache_key):
from comfy_execution.cache_provider import (
_has_cache_providers, _get_cache_providers,
CacheValue, _contains_self_unequal, _logger
)
if not self.enable_providers:
return None
if not _has_cache_providers():
return None
if _contains_self_unequal(cache_key):
return None
context = self._build_context(node_id, cache_key)
if context is None:
return None
for provider in _get_cache_providers():
try:
if not provider.should_cache(context):
continue
result = await provider.on_lookup(context)
if result is not None:
if not isinstance(result, CacheValue):
_logger.warning(f"Provider {provider.__class__.__name__} returned invalid type")
continue
if not isinstance(result.outputs, (list, tuple)):
_logger.warning(f"Provider {provider.__class__.__name__} returned invalid outputs")
continue
from execution import CacheEntry
return CacheEntry(ui=result.ui, outputs=list(result.outputs))
except Exception as e:
_logger.warning(f"Cache provider {provider.__class__.__name__} error on lookup: {e}")
return None
def _is_external_cacheable_value(self, value):
return hasattr(value, 'outputs') and hasattr(value, 'ui')
def _get_class_type(self, node_id):
if not self.initialized or not self.dynprompt:
return ''
try:
return self.dynprompt.get_node(node_id).get('class_type', '')
except Exception:
return ''
def _build_context(self, node_id, cache_key):
from comfy_execution.cache_provider import CacheContext, _serialize_cache_key, _logger
try:
cache_key_hash = _serialize_cache_key(cache_key)
if cache_key_hash is None:
return None
return CacheContext(
node_id=node_id,
class_type=self._get_class_type(node_id),
cache_key_hash=cache_key_hash,
)
except Exception as e:
_logger.warning(f"Failed to build cache context for node {node_id}: {e}")
return None
async def _ensure_subcache(self, node_id, children_ids):
@@ -236,8 +359,8 @@ class BasicCache:
return result
class HierarchicalCache(BasicCache):
def __init__(self, key_class):
super().__init__(key_class)
def __init__(self, key_class, enable_providers=False):
super().__init__(key_class, enable_providers=enable_providers)
def _get_cache_for(self, node_id):
assert self.dynprompt is not None
@@ -257,16 +380,27 @@ class HierarchicalCache(BasicCache):
return None
return cache
def get(self, node_id):
async def get(self, node_id):
cache = self._get_cache_for(node_id)
if cache is None:
return None
return cache._get_immediate(node_id)
return await cache._get_immediate(node_id)
def set(self, node_id, value):
def get_local(self, node_id):
cache = self._get_cache_for(node_id)
if cache is None:
return None
return BasicCache.get_local(cache, node_id)
async def set(self, node_id, value):
cache = self._get_cache_for(node_id)
assert cache is not None
cache._set_immediate(node_id, value)
await cache._set_immediate(node_id, value)
def set_local(self, node_id, value):
cache = self._get_cache_for(node_id)
assert cache is not None
BasicCache.set_local(cache, node_id, value)
async def ensure_subcache_for(self, node_id, children_ids):
cache = self._get_cache_for(node_id)
@@ -287,18 +421,24 @@ class NullCache:
def poll(self, **kwargs):
pass
def get(self, node_id):
async def get(self, node_id):
return None
def set(self, node_id, value):
def get_local(self, node_id):
return None
async def set(self, node_id, value):
pass
def set_local(self, node_id, value):
pass
async def ensure_subcache_for(self, node_id, children_ids):
return self
class LRUCache(BasicCache):
def __init__(self, key_class, max_size=100):
super().__init__(key_class)
def __init__(self, key_class, max_size=100, enable_providers=False):
super().__init__(key_class, enable_providers=enable_providers)
self.max_size = max_size
self.min_generation = 0
self.generation = 0
@@ -322,18 +462,18 @@ class LRUCache(BasicCache):
del self.children[key]
self._clean_subcaches()
def get(self, node_id):
async def get(self, node_id):
self._mark_used(node_id)
return self._get_immediate(node_id)
return await self._get_immediate(node_id)
def _mark_used(self, node_id):
cache_key = self.cache_key_set.get_data_key(node_id)
if cache_key is not None:
self.used_generation[cache_key] = self.generation
def set(self, node_id, value):
async def set(self, node_id, value):
self._mark_used(node_id)
return self._set_immediate(node_id, value)
return await self._set_immediate(node_id, value)
async def ensure_subcache_for(self, node_id, children_ids):
# Just uses subcaches for tracking 'live' nodes
@@ -366,20 +506,20 @@ RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER = 1.3
class RAMPressureCache(LRUCache):
def __init__(self, key_class):
super().__init__(key_class, 0)
def __init__(self, key_class, enable_providers=False):
super().__init__(key_class, 0, enable_providers=enable_providers)
self.timestamps = {}
def clean_unused(self):
self._clean_subcaches()
def set(self, node_id, value):
async def set(self, node_id, value):
self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
super().set(node_id, value)
await super().set(node_id, value)
def get(self, node_id):
async def get(self, node_id):
self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
return super().get(node_id)
return await super().get(node_id)
def poll(self, ram_headroom):
def _ram_gb():

View File

@@ -204,12 +204,12 @@ class ExecutionList(TopologicalSort):
self.execution_cache_listeners = {}
def is_cached(self, node_id):
return self.output_cache.get(node_id) is not None
return self.output_cache.get_local(node_id) is not None
def cache_link(self, from_node_id, to_node_id):
if to_node_id not in self.execution_cache:
self.execution_cache[to_node_id] = {}
self.execution_cache[to_node_id][from_node_id] = self.output_cache.get(from_node_id)
self.execution_cache[to_node_id][from_node_id] = self.output_cache.get_local(from_node_id)
if from_node_id not in self.execution_cache_listeners:
self.execution_cache_listeners[from_node_id] = set()
self.execution_cache_listeners[from_node_id].add(to_node_id)
@@ -221,7 +221,7 @@ class ExecutionList(TopologicalSort):
if value is None:
return None
#Write back to the main cache on touch.
self.output_cache.set(from_node_id, value)
self.output_cache.set_local(from_node_id, value)
return value
def cache_update(self, node_id, value):

View File

@@ -19,6 +19,7 @@ class EmptyLatentAudio(IO.ComfyNode):
node_id="EmptyLatentAudio",
display_name="Empty Latent Audio",
category="latent/audio",
essentials_category="Audio",
inputs=[
IO.Float.Input("seconds", default=47.6, min=1.0, max=1000.0, step=0.1),
IO.Int.Input(
@@ -185,6 +186,7 @@ class SaveAudioMP3(IO.ComfyNode):
search_aliases=["export mp3"],
display_name="Save Audio (MP3)",
category="audio",
essentials_category="Audio",
inputs=[
IO.Audio.Input("audio"),
IO.String.Input("filename_prefix", default="audio/ComfyUI"),

View File

@@ -6,6 +6,7 @@ import comfy.model_management
import torch
import math
import nodes
import comfy.ldm.flux.math
class CLIPTextEncodeFlux(io.ComfyNode):
@classmethod
@@ -231,6 +232,68 @@ class Flux2Scheduler(io.ComfyNode):
sigmas = get_schedule(steps, round(seq_len))
return io.NodeOutput(sigmas)
class KV_Attn_Input:
def __init__(self):
self.cache = {}
def __call__(self, q, k, v, extra_options, **kwargs):
reference_image_num_tokens = extra_options.get("reference_image_num_tokens", [])
if len(reference_image_num_tokens) == 0:
return {}
ref_toks = sum(reference_image_num_tokens)
cache_key = "{}_{}".format(extra_options["block_type"], extra_options["block_index"])
if cache_key in self.cache:
kk, vv = self.cache[cache_key]
self.set_cache = False
return {"q": q, "k": torch.cat((k, kk), dim=2), "v": torch.cat((v, vv), dim=2)}
self.cache[cache_key] = (k[:, :, -ref_toks:].clone(), v[:, :, -ref_toks:].clone())
self.set_cache = True
return {"q": q, "k": k, "v": v}
def cleanup(self):
self.cache = {}
class FluxKVCache(io.ComfyNode):
@classmethod
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="FluxKVCache",
display_name="Flux KV Cache",
description="Enables KV Cache optimization for reference images on Flux family models.",
category="",
is_experimental=True,
inputs=[
io.Model.Input("model", tooltip="The model to use KV Cache on."),
],
outputs=[
io.Model.Output(tooltip="The patched model with KV Cache enabled."),
],
)
@classmethod
def execute(cls, model: io.Model.Type) -> io.NodeOutput:
m = model.clone()
input_patch_obj = KV_Attn_Input()
def model_input_patch(inputs):
if len(input_patch_obj.cache) > 0:
ref_image_tokens = sum(inputs["transformer_options"].get("reference_image_num_tokens", []))
if ref_image_tokens > 0:
img = inputs["img"]
inputs["img"] = img[:, :-ref_image_tokens]
return inputs
m.set_model_attn1_patch(input_patch_obj)
m.set_model_post_input_patch(model_input_patch)
if hasattr(model.model.diffusion_model, "params"):
m.add_object_patch("diffusion_model.params.default_ref_method", "index_timestep_zero")
else:
m.add_object_patch("diffusion_model.default_ref_method", "index_timestep_zero")
return io.NodeOutput(m)
class FluxExtension(ComfyExtension):
@override
@@ -243,6 +306,7 @@ class FluxExtension(ComfyExtension):
FluxKontextMultiReferenceLatentMethod,
EmptyFlux2LatentImage,
Flux2Scheduler,
FluxKVCache,
]

View File

@@ -14,6 +14,7 @@ class ImageCompare(IO.ComfyNode):
display_name="Image Compare",
description="Compares two images side by side with a slider.",
category="image",
essentials_category="Image Tools",
is_experimental=True,
is_output_node=True,
inputs=[

View File

@@ -58,6 +58,7 @@ class ImageCropV2(IO.ComfyNode):
search_aliases=["trim"],
display_name="Image Crop",
category="image/transform",
essentials_category="Image Tools",
inputs=[
IO.Image.Input("image"),
IO.BoundingBox.Input("crop_region", component="ImageCrop"),

View File

@@ -0,0 +1,127 @@
from __future__ import annotations
import hashlib
import os
import numpy as np
import torch
from PIL import Image
import folder_paths
import node_helpers
from comfy_api.latest import ComfyExtension, io, UI
from typing_extensions import override
def hex_to_rgb(hex_color: str) -> tuple[float, float, float]:
hex_color = hex_color.lstrip("#")
if len(hex_color) != 6:
return (0.0, 0.0, 0.0)
r = int(hex_color[0:2], 16) / 255.0
g = int(hex_color[2:4], 16) / 255.0
b = int(hex_color[4:6], 16) / 255.0
return (r, g, b)
class PainterNode(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="Painter",
display_name="Painter",
category="image",
inputs=[
io.Image.Input(
"image",
optional=True,
tooltip="Optional base image to paint over",
),
io.String.Input(
"mask",
default="",
socketless=True,
extra_dict={"widgetType": "PAINTER", "image_upload": True},
),
io.Int.Input(
"width",
default=512,
min=64,
max=4096,
step=64,
socketless=True,
extra_dict={"hidden": True},
),
io.Int.Input(
"height",
default=512,
min=64,
max=4096,
step=64,
socketless=True,
extra_dict={"hidden": True},
),
io.Color.Input("bg_color", default="#000000"),
],
outputs=[
io.Image.Output("IMAGE"),
io.Mask.Output("MASK"),
],
)
@classmethod
def execute(cls, mask, width, height, bg_color="#000000", image=None) -> io.NodeOutput:
if image is not None:
base_image = image[:1]
h, w = base_image.shape[1], base_image.shape[2]
else:
h, w = height, width
r, g, b = hex_to_rgb(bg_color)
base_image = torch.zeros((1, h, w, 3), dtype=torch.float32)
base_image[0, :, :, 0] = r
base_image[0, :, :, 1] = g
base_image[0, :, :, 2] = b
if mask and mask.strip():
mask_path = folder_paths.get_annotated_filepath(mask)
painter_img = node_helpers.pillow(Image.open, mask_path)
painter_img = painter_img.convert("RGBA")
if painter_img.size != (w, h):
painter_img = painter_img.resize((w, h), Image.LANCZOS)
painter_np = np.array(painter_img).astype(np.float32) / 255.0
painter_rgb = painter_np[:, :, :3]
painter_alpha = painter_np[:, :, 3:4]
mask_tensor = torch.from_numpy(painter_np[:, :, 3]).unsqueeze(0)
base_np = base_image[0].cpu().numpy()
composited = painter_rgb * painter_alpha + base_np * (1.0 - painter_alpha)
out_image = torch.from_numpy(composited).unsqueeze(0)
else:
mask_tensor = torch.zeros((1, h, w), dtype=torch.float32)
out_image = base_image
return io.NodeOutput(out_image, mask_tensor, ui=UI.PreviewImage(out_image))
@classmethod
def fingerprint_inputs(cls, mask, width, height, bg_color="#000000", image=None):
if mask and mask.strip():
mask_path = folder_paths.get_annotated_filepath(mask)
if os.path.exists(mask_path):
m = hashlib.sha256()
with open(mask_path, "rb") as f:
m.update(f.read())
return m.digest().hex()
return ""
class PainterExtension(ComfyExtension):
@override
async def get_node_list(self):
return [PainterNode]
async def comfy_entrypoint():
return PainterExtension()

View File

@@ -21,6 +21,7 @@ class Blend(io.ComfyNode):
node_id="ImageBlend",
display_name="Image Blend",
category="image/postprocessing",
essentials_category="Image Tools",
inputs=[
io.Image.Input("image1"),
io.Image.Input("image2"),

View File

@@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
__version__ = "0.16.4"
__version__ = "0.17.0"

View File

@@ -40,6 +40,7 @@ from comfy_execution.progress import get_progress_state, reset_progress_state, a
from comfy_execution.utils import CurrentNodeContext
from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func
from comfy_api.latest import io, _io
from comfy_execution.cache_provider import _has_cache_providers, _get_cache_providers, _logger as _cache_logger
class ExecutionResult(Enum):
@@ -126,15 +127,15 @@ class CacheSet:
# Performs like the old cache -- dump data ASAP
def init_classic_cache(self):
self.outputs = HierarchicalCache(CacheKeySetInputSignature)
self.outputs = HierarchicalCache(CacheKeySetInputSignature, enable_providers=True)
self.objects = HierarchicalCache(CacheKeySetID)
def init_lru_cache(self, cache_size):
self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size, enable_providers=True)
self.objects = HierarchicalCache(CacheKeySetID)
def init_ram_cache(self, min_headroom):
self.outputs = RAMPressureCache(CacheKeySetInputSignature)
self.outputs = RAMPressureCache(CacheKeySetInputSignature, enable_providers=True)
self.objects = HierarchicalCache(CacheKeySetID)
def init_null_cache(self):
@@ -418,7 +419,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
inputs = dynprompt.get_node(unique_id)['inputs']
class_type = dynprompt.get_node(unique_id)['class_type']
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
cached = caches.outputs.get(unique_id)
cached = await caches.outputs.get(unique_id)
if cached is not None:
if server.client_id is not None:
cached_ui = cached.ui or {}
@@ -474,10 +475,10 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
server.last_node_id = display_node_id
server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id)
obj = caches.objects.get(unique_id)
obj = await caches.objects.get(unique_id)
if obj is None:
obj = class_def()
caches.objects.set(unique_id, obj)
await caches.objects.set(unique_id, obj)
if issubclass(class_def, _ComfyNodeInternal):
lazy_status_present = first_real_override(class_def, "check_lazy_status") is not None
@@ -588,7 +589,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
cache_entry = CacheEntry(ui=ui_outputs.get(unique_id), outputs=output_data)
execution_list.cache_update(unique_id, cache_entry)
caches.outputs.set(unique_id, cache_entry)
await caches.outputs.set(unique_id, cache_entry)
except comfy.model_management.InterruptProcessingException as iex:
logging.info("Processing interrupted")
@@ -684,6 +685,19 @@ class PromptExecutor:
}
self.add_message("execution_error", mes, broadcast=False)
def _notify_prompt_lifecycle(self, event: str, prompt_id: str):
if not _has_cache_providers():
return
for provider in _get_cache_providers():
try:
if event == "start":
provider.on_prompt_start(prompt_id)
elif event == "end":
provider.on_prompt_end(prompt_id)
except Exception as e:
_cache_logger.warning(f"Cache provider {provider.__class__.__name__} error on {event}: {e}")
def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
asyncio.run(self.execute_async(prompt, prompt_id, extra_data, execute_outputs))
@@ -700,66 +714,75 @@ class PromptExecutor:
self.status_messages = []
self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False)
with torch.inference_mode():
dynamic_prompt = DynamicPrompt(prompt)
reset_progress_state(prompt_id, dynamic_prompt)
add_progress_handler(WebUIProgressHandler(self.server))
is_changed_cache = IsChangedCache(prompt_id, dynamic_prompt, self.caches.outputs)
for cache in self.caches.all:
await cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache)
cache.clean_unused()
self._notify_prompt_lifecycle("start", prompt_id)
cached_nodes = []
for node_id in prompt:
if self.caches.outputs.get(node_id) is not None:
cached_nodes.append(node_id)
try:
with torch.inference_mode():
dynamic_prompt = DynamicPrompt(prompt)
reset_progress_state(prompt_id, dynamic_prompt)
add_progress_handler(WebUIProgressHandler(self.server))
is_changed_cache = IsChangedCache(prompt_id, dynamic_prompt, self.caches.outputs)
for cache in self.caches.all:
await cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache)
cache.clean_unused()
comfy.model_management.cleanup_models_gc()
self.add_message("execution_cached",
{ "nodes": cached_nodes, "prompt_id": prompt_id},
broadcast=False)
pending_subgraph_results = {}
pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results
ui_node_outputs = {}
executed = set()
execution_list = ExecutionList(dynamic_prompt, self.caches.outputs)
current_outputs = self.caches.outputs.all_node_ids()
for node_id in list(execute_outputs):
execution_list.add_node(node_id)
node_ids = list(prompt.keys())
cache_results = await asyncio.gather(
*(self.caches.outputs.get(node_id) for node_id in node_ids)
)
cached_nodes = [
node_id for node_id, result in zip(node_ids, cache_results)
if result is not None
]
while not execution_list.is_empty():
node_id, error, ex = await execution_list.stage_node_execution()
if error is not None:
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
break
comfy.model_management.cleanup_models_gc()
self.add_message("execution_cached",
{ "nodes": cached_nodes, "prompt_id": prompt_id},
broadcast=False)
pending_subgraph_results = {}
pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results
ui_node_outputs = {}
executed = set()
execution_list = ExecutionList(dynamic_prompt, self.caches.outputs)
current_outputs = self.caches.outputs.all_node_ids()
for node_id in list(execute_outputs):
execution_list.add_node(node_id)
assert node_id is not None, "Node ID should not be None at this point"
result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs)
self.success = result != ExecutionResult.FAILURE
if result == ExecutionResult.FAILURE:
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
break
elif result == ExecutionResult.PENDING:
execution_list.unstage_node_execution()
else: # result == ExecutionResult.SUCCESS:
execution_list.complete_node_execution()
self.caches.outputs.poll(ram_headroom=self.cache_args["ram"])
else:
# Only execute when the while-loop ends without break
self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)
while not execution_list.is_empty():
node_id, error, ex = await execution_list.stage_node_execution()
if error is not None:
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
break
ui_outputs = {}
meta_outputs = {}
for node_id, ui_info in ui_node_outputs.items():
ui_outputs[node_id] = ui_info["output"]
meta_outputs[node_id] = ui_info["meta"]
self.history_result = {
"outputs": ui_outputs,
"meta": meta_outputs,
}
self.server.last_node_id = None
if comfy.model_management.DISABLE_SMART_MEMORY:
comfy.model_management.unload_all_models()
assert node_id is not None, "Node ID should not be None at this point"
result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs)
self.success = result != ExecutionResult.FAILURE
if result == ExecutionResult.FAILURE:
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
break
elif result == ExecutionResult.PENDING:
execution_list.unstage_node_execution()
else: # result == ExecutionResult.SUCCESS:
execution_list.complete_node_execution()
self.caches.outputs.poll(ram_headroom=self.cache_args["ram"])
else:
# Only execute when the while-loop ends without break
self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)
ui_outputs = {}
meta_outputs = {}
for node_id, ui_info in ui_node_outputs.items():
ui_outputs[node_id] = ui_info["output"]
meta_outputs[node_id] = ui_info["meta"]
self.history_result = {
"outputs": ui_outputs,
"meta": meta_outputs,
}
self.server.last_node_id = None
if comfy.model_management.DISABLE_SMART_MEMORY:
comfy.model_management.unload_all_models()
finally:
self._notify_prompt_lifecycle("end", prompt_id)
async def validate_inputs(prompt_id, prompt, item, validated):

View File

@@ -1 +1 @@
comfyui_manager==4.1b2
comfyui_manager==4.1b4

View File

@@ -32,7 +32,7 @@ async def cache_control(
)
if request.path.endswith(".js") or request.path.endswith(".css") or is_entry_point:
response.headers.setdefault("Cache-Control", "no-cache")
response.headers.setdefault("Cache-Control", "no-store")
return response
# Early return for non-image files - no cache headers needed

View File

@@ -81,6 +81,7 @@ class CLIPTextEncode(ComfyNodeABC):
class ConditioningCombine:
ESSENTIALS_CATEGORY = "Image Generation"
@classmethod
def INPUT_TYPES(s):
return {"required": {"conditioning_1": ("CONDITIONING", ), "conditioning_2": ("CONDITIONING", )}}
@@ -1724,6 +1725,8 @@ class LoadImage:
output_masks = []
w, h = None, None
dtype = comfy.model_management.intermediate_dtype()
for i in ImageSequence.Iterator(img):
i = node_helpers.pillow(ImageOps.exif_transpose, i)
@@ -1748,8 +1751,8 @@ class LoadImage:
mask = 1. - torch.from_numpy(mask)
else:
mask = torch.zeros((64,64), dtype=torch.float32, device="cpu")
output_images.append(image)
output_masks.append(mask.unsqueeze(0))
output_images.append(image.to(dtype=dtype))
output_masks.append(mask.unsqueeze(0).to(dtype=dtype))
if img.format == "MPO":
break # ignore all frames except the first one for MPO format
@@ -1779,6 +1782,7 @@ class LoadImage:
return True
class LoadImageMask:
ESSENTIALS_CATEGORY = "Image Tools"
SEARCH_ALIASES = ["import mask", "alpha mask", "channel mask"]
_color_channels = ["alpha", "red", "green", "blue"]
@@ -1887,6 +1891,7 @@ class ImageScale:
return (s,)
class ImageScaleBy:
ESSENTIALS_CATEGORY = "Image Tools"
upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"]
@classmethod
@@ -2450,6 +2455,7 @@ async def init_builtin_extra_nodes():
"nodes_nag.py",
"nodes_sdpose.py",
"nodes_math.py",
"nodes_painter.py",
]
import_failed = []

View File

@@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
version = "0.16.4"
version = "0.17.0"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.10"

View File

@@ -1,5 +1,5 @@
comfyui-frontend-package==1.39.19
comfyui-workflow-templates==0.9.18
comfyui-frontend-package==1.41.20
comfyui-workflow-templates==0.9.21
comfyui-embedded-docs==0.4.3
torch
torchsde
@@ -22,8 +22,8 @@ alembic
SQLAlchemy
filelock
av>=14.2.0
comfy-kitchen>=0.2.7
comfy-aimdo>=0.2.9
comfy-kitchen>=0.2.8
comfy-aimdo>=0.2.12
requests
simpleeval>=1.0.0
blake3

View File

@@ -310,7 +310,7 @@ class PromptServer():
@routes.get("/")
async def get_root(request):
response = web.FileResponse(os.path.join(self.web_root, "index.html"))
response.headers['Cache-Control'] = 'no-cache'
response.headers['Cache-Control'] = 'no-store, must-revalidate'
response.headers["Pragma"] = "no-cache"
response.headers["Expires"] = "0"
return response

View File

@@ -0,0 +1,403 @@
"""Tests for external cache provider API."""
import importlib.util
import pytest
from typing import Optional
def _torch_available() -> bool:
"""Check if PyTorch is available."""
return importlib.util.find_spec("torch") is not None
from comfy_execution.cache_provider import (
CacheProvider,
CacheContext,
CacheValue,
register_cache_provider,
unregister_cache_provider,
_get_cache_providers,
_has_cache_providers,
_clear_cache_providers,
_serialize_cache_key,
_contains_self_unequal,
_estimate_value_size,
_canonicalize,
)
class TestCanonicalize:
"""Test _canonicalize function for deterministic ordering."""
def test_frozenset_ordering_is_deterministic(self):
"""Frozensets should produce consistent canonical form regardless of iteration order."""
# Create two frozensets with same content
fs1 = frozenset([("a", 1), ("b", 2), ("c", 3)])
fs2 = frozenset([("c", 3), ("a", 1), ("b", 2)])
result1 = _canonicalize(fs1)
result2 = _canonicalize(fs2)
assert result1 == result2
def test_nested_frozenset_ordering(self):
"""Nested frozensets should also be deterministically ordered."""
inner1 = frozenset([1, 2, 3])
inner2 = frozenset([3, 2, 1])
fs1 = frozenset([("key", inner1)])
fs2 = frozenset([("key", inner2)])
result1 = _canonicalize(fs1)
result2 = _canonicalize(fs2)
assert result1 == result2
def test_dict_ordering(self):
"""Dicts should be sorted by key."""
d1 = {"z": 1, "a": 2, "m": 3}
d2 = {"a": 2, "m": 3, "z": 1}
result1 = _canonicalize(d1)
result2 = _canonicalize(d2)
assert result1 == result2
def test_tuple_preserved(self):
"""Tuples should be marked and preserved."""
t = (1, 2, 3)
result = _canonicalize(t)
assert result[0] == "__tuple__"
def test_list_preserved(self):
"""Lists should be recursively canonicalized."""
lst = [{"b": 2, "a": 1}, frozenset([3, 2, 1])]
result = _canonicalize(lst)
# First element should be canonicalized dict
assert "__dict__" in result[0]
# Second element should be canonicalized frozenset
assert result[1][0] == "__frozenset__"
def test_primitives_include_type(self):
"""Primitive types should include type name for disambiguation."""
assert _canonicalize(42) == ("int", 42)
assert _canonicalize(3.14) == ("float", 3.14)
assert _canonicalize("hello") == ("str", "hello")
assert _canonicalize(True) == ("bool", True)
assert _canonicalize(None) == ("NoneType", None)
def test_int_and_str_distinguished(self):
"""int 7 and str '7' must produce different canonical forms."""
assert _canonicalize(7) != _canonicalize("7")
def test_bytes_converted(self):
"""Bytes should be converted to hex string."""
b = b"\x00\xff"
result = _canonicalize(b)
assert result[0] == "__bytes__"
assert result[1] == "00ff"
def test_set_ordering(self):
"""Sets should be sorted like frozensets."""
s1 = {3, 1, 2}
s2 = {1, 2, 3}
result1 = _canonicalize(s1)
result2 = _canonicalize(s2)
assert result1 == result2
assert result1[0] == "__set__"
def test_unknown_type_raises(self):
"""Unknown types should raise ValueError (fail-closed)."""
class CustomObj:
pass
with pytest.raises(ValueError):
_canonicalize(CustomObj())
def test_object_with_value_attr_raises(self):
"""Objects with .value attribute (Unhashable-like) should raise ValueError."""
class FakeUnhashable:
def __init__(self):
self.value = float('nan')
with pytest.raises(ValueError):
_canonicalize(FakeUnhashable())
class TestSerializeCacheKey:
"""Test _serialize_cache_key for deterministic hashing."""
def test_same_content_same_hash(self):
"""Same content should produce same hash."""
key1 = frozenset([("node_1", frozenset([("input", "value")]))])
key2 = frozenset([("node_1", frozenset([("input", "value")]))])
hash1 = _serialize_cache_key(key1)
hash2 = _serialize_cache_key(key2)
assert hash1 == hash2
def test_different_content_different_hash(self):
"""Different content should produce different hash."""
key1 = frozenset([("node_1", "value_a")])
key2 = frozenset([("node_1", "value_b")])
hash1 = _serialize_cache_key(key1)
hash2 = _serialize_cache_key(key2)
assert hash1 != hash2
def test_returns_hex_string(self):
"""Should return hex string (SHA256 hex digest)."""
key = frozenset([("test", 123)])
result = _serialize_cache_key(key)
assert isinstance(result, str)
assert len(result) == 64 # SHA256 hex digest is 64 chars
def test_complex_nested_structure(self):
"""Complex nested structures should hash deterministically."""
# Note: frozensets can only contain hashable types, so we use
# nested frozensets of tuples to represent dict-like structures
key = frozenset([
("node_1", frozenset([
("input_a", ("tuple", "value")),
("input_b", frozenset([("nested", "dict")])),
])),
("node_2", frozenset([
("param", 42),
])),
])
# Hash twice to verify determinism
hash1 = _serialize_cache_key(key)
hash2 = _serialize_cache_key(key)
assert hash1 == hash2
def test_dict_in_cache_key(self):
"""Dicts passed directly to _serialize_cache_key should work."""
key = {"node_1": {"input": "value"}, "node_2": 42}
hash1 = _serialize_cache_key(key)
hash2 = _serialize_cache_key(key)
assert hash1 == hash2
assert isinstance(hash1, str)
assert len(hash1) == 64
def test_unknown_type_returns_none(self):
"""Non-cacheable types should return None (fail-closed)."""
class CustomObj:
pass
assert _serialize_cache_key(CustomObj()) is None
class TestContainsSelfUnequal:
"""Test _contains_self_unequal utility function."""
def test_nan_float_detected(self):
"""NaN floats should be detected (not equal to itself)."""
assert _contains_self_unequal(float('nan')) is True
def test_regular_float_not_detected(self):
"""Regular floats are equal to themselves."""
assert _contains_self_unequal(3.14) is False
assert _contains_self_unequal(0.0) is False
assert _contains_self_unequal(-1.5) is False
def test_infinity_not_detected(self):
"""Infinity is equal to itself."""
assert _contains_self_unequal(float('inf')) is False
assert _contains_self_unequal(float('-inf')) is False
def test_nan_in_list(self):
"""NaN in list should be detected."""
assert _contains_self_unequal([1, 2, float('nan'), 4]) is True
assert _contains_self_unequal([1, 2, 3, 4]) is False
def test_nan_in_tuple(self):
"""NaN in tuple should be detected."""
assert _contains_self_unequal((1, float('nan'))) is True
assert _contains_self_unequal((1, 2, 3)) is False
def test_nan_in_frozenset(self):
"""NaN in frozenset should be detected."""
assert _contains_self_unequal(frozenset([1, float('nan')])) is True
assert _contains_self_unequal(frozenset([1, 2, 3])) is False
def test_nan_in_dict_value(self):
"""NaN in dict value should be detected."""
assert _contains_self_unequal({"key": float('nan')}) is True
assert _contains_self_unequal({"key": 42}) is False
def test_nan_in_nested_structure(self):
"""NaN in deeply nested structure should be detected."""
nested = {"level1": [{"level2": (1, 2, float('nan'))}]}
assert _contains_self_unequal(nested) is True
def test_non_numeric_types(self):
"""Non-numeric types should not be self-unequal."""
assert _contains_self_unequal("string") is False
assert _contains_self_unequal(None) is False
assert _contains_self_unequal(True) is False
def test_object_with_nan_value_attribute(self):
"""Objects wrapping NaN in .value should be detected."""
class NanWrapper:
def __init__(self):
self.value = float('nan')
assert _contains_self_unequal(NanWrapper()) is True
def test_custom_self_unequal_object(self):
"""Custom objects where not (x == x) should be detected."""
class NeverEqual:
def __eq__(self, other):
return False
assert _contains_self_unequal(NeverEqual()) is True
class TestEstimateValueSize:
"""Test _estimate_value_size utility function."""
def test_empty_outputs(self):
"""Empty outputs should have zero size."""
value = CacheValue(outputs=[])
assert _estimate_value_size(value) == 0
@pytest.mark.skipif(
not _torch_available(),
reason="PyTorch not available"
)
def test_tensor_size_estimation(self):
"""Tensor size should be estimated correctly."""
import torch
# 1000 float32 elements = 4000 bytes
tensor = torch.zeros(1000, dtype=torch.float32)
value = CacheValue(outputs=[[tensor]])
size = _estimate_value_size(value)
assert size == 4000
@pytest.mark.skipif(
not _torch_available(),
reason="PyTorch not available"
)
def test_nested_tensor_in_dict(self):
"""Tensors nested in dicts should be counted."""
import torch
tensor = torch.zeros(100, dtype=torch.float32) # 400 bytes
value = CacheValue(outputs=[[{"samples": tensor}]])
size = _estimate_value_size(value)
assert size == 400
class TestProviderRegistry:
"""Test cache provider registration and retrieval."""
def setup_method(self):
"""Clear providers before each test."""
_clear_cache_providers()
def teardown_method(self):
"""Clear providers after each test."""
_clear_cache_providers()
def test_register_provider(self):
"""Provider should be registered successfully."""
provider = MockCacheProvider()
register_cache_provider(provider)
assert _has_cache_providers() is True
providers = _get_cache_providers()
assert len(providers) == 1
assert providers[0] is provider
def test_unregister_provider(self):
"""Provider should be unregistered successfully."""
provider = MockCacheProvider()
register_cache_provider(provider)
unregister_cache_provider(provider)
assert _has_cache_providers() is False
def test_multiple_providers(self):
"""Multiple providers can be registered."""
provider1 = MockCacheProvider()
provider2 = MockCacheProvider()
register_cache_provider(provider1)
register_cache_provider(provider2)
providers = _get_cache_providers()
assert len(providers) == 2
def test_duplicate_registration_ignored(self):
"""Registering same provider twice should be ignored."""
provider = MockCacheProvider()
register_cache_provider(provider)
register_cache_provider(provider) # Should be ignored
providers = _get_cache_providers()
assert len(providers) == 1
def test_clear_providers(self):
"""_clear_cache_providers should remove all providers."""
provider1 = MockCacheProvider()
provider2 = MockCacheProvider()
register_cache_provider(provider1)
register_cache_provider(provider2)
_clear_cache_providers()
assert _has_cache_providers() is False
assert len(_get_cache_providers()) == 0
class TestCacheContext:
"""Test CacheContext dataclass."""
def test_context_creation(self):
"""CacheContext should be created with all fields."""
context = CacheContext(
node_id="node-456",
class_type="KSampler",
cache_key_hash="a" * 64,
)
assert context.node_id == "node-456"
assert context.class_type == "KSampler"
assert context.cache_key_hash == "a" * 64
class TestCacheValue:
"""Test CacheValue dataclass."""
def test_value_creation(self):
"""CacheValue should be created with outputs."""
outputs = [[{"samples": "tensor_data"}]]
value = CacheValue(outputs=outputs)
assert value.outputs == outputs
class MockCacheProvider(CacheProvider):
"""Mock cache provider for testing."""
def __init__(self):
self.lookups = []
self.stores = []
async def on_lookup(self, context: CacheContext) -> Optional[CacheValue]:
self.lookups.append(context)
return None
async def on_store(self, context: CacheContext, value: CacheValue) -> None:
self.stores.append((context, value))

View File

@@ -28,31 +28,31 @@ CACHE_SCENARIOS = [
},
# JavaScript/CSS scenarios
{
"name": "js_no_cache",
"name": "js_no_store",
"path": "/script.js",
"status": 200,
"expected_cache": "no-cache",
"expected_cache": "no-store",
"should_have_header": True,
},
{
"name": "css_no_cache",
"name": "css_no_store",
"path": "/styles.css",
"status": 200,
"expected_cache": "no-cache",
"expected_cache": "no-store",
"should_have_header": True,
},
{
"name": "index_json_no_cache",
"name": "index_json_no_store",
"path": "/api/index.json",
"status": 200,
"expected_cache": "no-cache",
"expected_cache": "no-store",
"should_have_header": True,
},
{
"name": "localized_index_json_no_cache",
"name": "localized_index_json_no_store",
"path": "/templates/index.zh.json",
"status": 200,
"expected_cache": "no-cache",
"expected_cache": "no-store",
"should_have_header": True,
},
# Non-matching files