mirror of
https://github.com/Comfy-Org/ComfyUI.git
synced 2026-03-14 12:29:21 +00:00
Compare commits
6 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
16cd8d8a8f | ||
|
|
7810f49702 | ||
|
|
e1f10ca093 | ||
|
|
6cd35a0c5f | ||
|
|
f9ceed9eef | ||
|
|
4a8cf359fe |
11
README.md
11
README.md
@@ -38,6 +38,8 @@ ComfyUI lets you design and execute advanced stable diffusion pipelines using a
|
||||
|
||||
## Get Started
|
||||
|
||||
### Local
|
||||
|
||||
#### [Desktop Application](https://www.comfy.org/download)
|
||||
- The easiest way to get started.
|
||||
- Available on Windows & macOS.
|
||||
@@ -49,8 +51,13 @@ ComfyUI lets you design and execute advanced stable diffusion pipelines using a
|
||||
#### [Manual Install](#manual-install-windows-linux)
|
||||
Supports all operating systems and GPU types (NVIDIA, AMD, Intel, Apple Silicon, Ascend).
|
||||
|
||||
## [Examples](https://comfyanonymous.github.io/ComfyUI_examples/)
|
||||
See what ComfyUI can do with the [example workflows](https://comfyanonymous.github.io/ComfyUI_examples/).
|
||||
### Cloud
|
||||
|
||||
#### [Comfy Cloud](https://www.comfy.org/cloud)
|
||||
- Our official paid cloud version for those who can't afford local hardware.
|
||||
|
||||
## Examples
|
||||
See what ComfyUI can do with the [newer template workflows](https://comfy.org/workflows) or old [example workflows](https://comfyanonymous.github.io/ComfyUI_examples/).
|
||||
|
||||
## Features
|
||||
- Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.
|
||||
|
||||
@@ -1,9 +1,68 @@
|
||||
import math
|
||||
import ctypes
|
||||
import threading
|
||||
import dataclasses
|
||||
import torch
|
||||
from typing import NamedTuple
|
||||
|
||||
from comfy.quant_ops import QuantizedTensor
|
||||
|
||||
|
||||
class TensorFileSlice(NamedTuple):
|
||||
file_ref: object
|
||||
thread_id: int
|
||||
offset: int
|
||||
size: int
|
||||
|
||||
|
||||
def read_tensor_file_slice_into(tensor, destination):
|
||||
|
||||
if isinstance(tensor, QuantizedTensor):
|
||||
if not isinstance(destination, QuantizedTensor):
|
||||
return False
|
||||
if tensor._layout_cls != destination._layout_cls:
|
||||
return False
|
||||
|
||||
if not read_tensor_file_slice_into(tensor._qdata, destination._qdata):
|
||||
return False
|
||||
|
||||
dst_orig_dtype = destination._params.orig_dtype
|
||||
destination._params.copy_from(tensor._params, non_blocking=False)
|
||||
destination._params = dataclasses.replace(destination._params, orig_dtype=dst_orig_dtype)
|
||||
return True
|
||||
|
||||
info = getattr(tensor.untyped_storage(), "_comfy_tensor_file_slice", None)
|
||||
if info is None:
|
||||
return False
|
||||
|
||||
file_obj = info.file_ref
|
||||
if (destination.device.type != "cpu"
|
||||
or file_obj is None
|
||||
or threading.get_ident() != info.thread_id
|
||||
or destination.numel() * destination.element_size() < info.size):
|
||||
return False
|
||||
|
||||
if info.size == 0:
|
||||
return True
|
||||
|
||||
buf_type = ctypes.c_ubyte * info.size
|
||||
view = memoryview(buf_type.from_address(destination.data_ptr()))
|
||||
|
||||
try:
|
||||
file_obj.seek(info.offset)
|
||||
done = 0
|
||||
while done < info.size:
|
||||
try:
|
||||
n = file_obj.readinto(view[done:])
|
||||
except OSError:
|
||||
return False
|
||||
if n <= 0:
|
||||
return False
|
||||
done += n
|
||||
return True
|
||||
finally:
|
||||
view.release()
|
||||
|
||||
class TensorGeometry(NamedTuple):
|
||||
shape: any
|
||||
dtype: torch.dtype
|
||||
|
||||
@@ -505,6 +505,28 @@ def module_size(module):
|
||||
module_mem += t.nbytes
|
||||
return module_mem
|
||||
|
||||
def module_mmap_residency(module, free=False):
|
||||
mmap_touched_mem = 0
|
||||
module_mem = 0
|
||||
bounced_mmaps = set()
|
||||
sd = module.state_dict()
|
||||
for k in sd:
|
||||
t = sd[k]
|
||||
module_mem += t.nbytes
|
||||
storage = t._qdata.untyped_storage() if isinstance(t, comfy.quant_ops.QuantizedTensor) else t.untyped_storage()
|
||||
if not getattr(storage, "_comfy_tensor_mmap_touched", False):
|
||||
continue
|
||||
mmap_touched_mem += t.nbytes
|
||||
if not free:
|
||||
continue
|
||||
storage._comfy_tensor_mmap_touched = False
|
||||
mmap_obj = storage._comfy_tensor_mmap_refs[0]
|
||||
if mmap_obj in bounced_mmaps:
|
||||
continue
|
||||
mmap_obj.bounce()
|
||||
bounced_mmaps.add(mmap_obj)
|
||||
return mmap_touched_mem, module_mem
|
||||
|
||||
class LoadedModel:
|
||||
def __init__(self, model):
|
||||
self._set_model(model)
|
||||
@@ -532,6 +554,9 @@ class LoadedModel:
|
||||
def model_memory(self):
|
||||
return self.model.model_size()
|
||||
|
||||
def model_mmap_residency(self, free=False):
|
||||
return self.model.model_mmap_residency(free=free)
|
||||
|
||||
def model_loaded_memory(self):
|
||||
return self.model.loaded_size()
|
||||
|
||||
@@ -633,7 +658,7 @@ def extra_reserved_memory():
|
||||
def minimum_inference_memory():
|
||||
return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory()
|
||||
|
||||
def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_required=0):
|
||||
def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins_required=0, ram_required=0):
|
||||
cleanup_models_gc()
|
||||
unloaded_model = []
|
||||
can_unload = []
|
||||
@@ -646,13 +671,14 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_
|
||||
can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
|
||||
shift_model.currently_used = False
|
||||
|
||||
for x in sorted(can_unload):
|
||||
can_unload_sorted = sorted(can_unload)
|
||||
for x in can_unload_sorted:
|
||||
i = x[-1]
|
||||
memory_to_free = 1e32
|
||||
ram_to_free = 1e32
|
||||
pins_to_free = 1e32
|
||||
if not DISABLE_SMART_MEMORY:
|
||||
memory_to_free = memory_required - get_free_memory(device)
|
||||
ram_to_free = ram_required - get_free_ram()
|
||||
pins_to_free = pins_required - get_free_ram()
|
||||
if current_loaded_models[i].model.is_dynamic() and for_dynamic:
|
||||
#don't actually unload dynamic models for the sake of other dynamic models
|
||||
#as that works on-demand.
|
||||
@@ -661,9 +687,18 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_
|
||||
if memory_to_free > 0 and current_loaded_models[i].model_unload(memory_to_free):
|
||||
logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
|
||||
unloaded_model.append(i)
|
||||
if ram_to_free > 0:
|
||||
if pins_to_free > 0:
|
||||
logging.debug(f"PIN Unloading {current_loaded_models[i].model.model.__class__.__name__}")
|
||||
current_loaded_models[i].model.partially_unload_ram(pins_to_free)
|
||||
|
||||
for x in can_unload_sorted:
|
||||
i = x[-1]
|
||||
ram_to_free = ram_required - psutil.virtual_memory().available
|
||||
if ram_to_free <= 0 and i not in unloaded_model:
|
||||
continue
|
||||
resident_memory, _ = current_loaded_models[i].model_mmap_residency(free=True)
|
||||
if resident_memory > 0:
|
||||
logging.debug(f"RAM Unloading {current_loaded_models[i].model.model.__class__.__name__}")
|
||||
current_loaded_models[i].model.partially_unload_ram(ram_to_free)
|
||||
|
||||
for i in sorted(unloaded_model, reverse=True):
|
||||
unloaded_models.append(current_loaded_models.pop(i))
|
||||
@@ -729,17 +764,27 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
|
||||
|
||||
|
||||
total_memory_required = {}
|
||||
total_pins_required = {}
|
||||
total_ram_required = {}
|
||||
for loaded_model in models_to_load:
|
||||
total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
|
||||
#x2, one to make sure the OS can fit the model for loading in disk cache, and for us to do any pinning we
|
||||
#want to do.
|
||||
#FIXME: This should subtract off the to_load current pin consumption.
|
||||
total_ram_required[loaded_model.device] = total_ram_required.get(loaded_model.device, 0) + loaded_model.model_memory() * 2
|
||||
device = loaded_model.device
|
||||
total_memory_required[device] = total_memory_required.get(device, 0) + loaded_model.model_memory_required(device)
|
||||
resident_memory, model_memory = loaded_model.model_mmap_residency()
|
||||
pinned_memory = loaded_model.model.pinned_memory_size()
|
||||
#FIXME: This can over-free the pins as it budgets to pin the entire model. We should
|
||||
#make this JIT to keep as much pinned as possible.
|
||||
pins_required = model_memory - pinned_memory
|
||||
ram_required = model_memory - resident_memory
|
||||
total_pins_required[device] = total_pins_required.get(device, 0) + pins_required
|
||||
total_ram_required[device] = total_ram_required.get(device, 0) + ram_required
|
||||
|
||||
for device in total_memory_required:
|
||||
if device != torch.device("cpu"):
|
||||
free_memory(total_memory_required[device] * 1.1 + extra_mem, device, for_dynamic=free_for_dynamic, ram_required=total_ram_required[device])
|
||||
free_memory(total_memory_required[device] * 1.1 + extra_mem,
|
||||
device,
|
||||
for_dynamic=free_for_dynamic,
|
||||
pins_required=total_pins_required[device],
|
||||
ram_required=total_ram_required[device])
|
||||
|
||||
for device in total_memory_required:
|
||||
if device != torch.device("cpu"):
|
||||
@@ -1225,6 +1270,11 @@ def cast_to_gathered(tensors, r, non_blocking=False, stream=None):
|
||||
dest_view = dest_views.pop(0)
|
||||
if tensor is None:
|
||||
continue
|
||||
if comfy.memory_management.read_tensor_file_slice_into(tensor, dest_view):
|
||||
continue
|
||||
storage = tensor._qdata.untyped_storage() if isinstance(tensor, comfy.quant_ops.QuantizedTensor) else tensor.untyped_storage()
|
||||
if hasattr(storage, "_comfy_tensor_mmap_touched"):
|
||||
storage._comfy_tensor_mmap_touched = True
|
||||
dest_view.copy_(tensor, non_blocking=non_blocking)
|
||||
|
||||
|
||||
|
||||
@@ -297,6 +297,9 @@ class ModelPatcher:
|
||||
self.size = comfy.model_management.module_size(self.model)
|
||||
return self.size
|
||||
|
||||
def model_mmap_residency(self, free=False):
|
||||
return comfy.model_management.module_mmap_residency(self.model, free=free)
|
||||
|
||||
def get_ram_usage(self):
|
||||
return self.model_size()
|
||||
|
||||
@@ -1063,6 +1066,10 @@ class ModelPatcher:
|
||||
|
||||
return self.model.model_loaded_weight_memory - current_used
|
||||
|
||||
def pinned_memory_size(self):
|
||||
# Pinned memory pressure tracking is only implemented for DynamicVram loading
|
||||
return 0
|
||||
|
||||
def partially_unload_ram(self, ram_to_unload):
|
||||
pass
|
||||
|
||||
@@ -1653,6 +1660,16 @@ class ModelPatcherDynamic(ModelPatcher):
|
||||
|
||||
return freed
|
||||
|
||||
def pinned_memory_size(self):
|
||||
total = 0
|
||||
loading = self._load_list(for_dynamic=True)
|
||||
for x in loading:
|
||||
_, _, _, _, m, _ = x
|
||||
pin = comfy.pinned_memory.get_pin(m)
|
||||
if pin is not None:
|
||||
total += pin.numel() * pin.element_size()
|
||||
return total
|
||||
|
||||
def partially_unload_ram(self, ram_to_unload):
|
||||
loading = self._load_list(for_dynamic=True, default_device=self.offload_device)
|
||||
for x in loading:
|
||||
|
||||
102
comfy/ops.py
102
comfy/ops.py
@@ -306,6 +306,33 @@ class CastWeightBiasOp:
|
||||
bias_function = []
|
||||
|
||||
class disable_weight_init:
|
||||
@staticmethod
|
||||
def _lazy_load_from_state_dict(module, state_dict, prefix, local_metadata,
|
||||
missing_keys, unexpected_keys, weight_shape,
|
||||
bias_shape=None):
|
||||
assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False)
|
||||
prefix_len = len(prefix)
|
||||
for k, v in state_dict.items():
|
||||
key = k[prefix_len:]
|
||||
if key == "weight":
|
||||
if not assign_to_params_buffers:
|
||||
v = v.clone()
|
||||
module.weight = torch.nn.Parameter(v, requires_grad=False)
|
||||
elif bias_shape is not None and key == "bias" and v is not None:
|
||||
if not assign_to_params_buffers:
|
||||
v = v.clone()
|
||||
module.bias = torch.nn.Parameter(v, requires_grad=False)
|
||||
else:
|
||||
unexpected_keys.append(k)
|
||||
|
||||
if module.weight is None:
|
||||
module.weight = torch.nn.Parameter(torch.zeros(weight_shape), requires_grad=False)
|
||||
missing_keys.append(prefix + "weight")
|
||||
|
||||
if bias_shape is not None and module.bias is None and getattr(module, "comfy_need_lazy_init_bias", False):
|
||||
module.bias = torch.nn.Parameter(torch.zeros(bias_shape), requires_grad=False)
|
||||
missing_keys.append(prefix + "bias")
|
||||
|
||||
class Linear(torch.nn.Linear, CastWeightBiasOp):
|
||||
|
||||
def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
|
||||
@@ -333,29 +360,16 @@ class disable_weight_init:
|
||||
if not comfy.model_management.WINDOWS or not comfy.memory_management.aimdo_enabled:
|
||||
return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
|
||||
missing_keys, unexpected_keys, error_msgs)
|
||||
assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False)
|
||||
prefix_len = len(prefix)
|
||||
for k,v in state_dict.items():
|
||||
if k[prefix_len:] == "weight":
|
||||
if not assign_to_params_buffers:
|
||||
v = v.clone()
|
||||
self.weight = torch.nn.Parameter(v, requires_grad=False)
|
||||
elif k[prefix_len:] == "bias" and v is not None:
|
||||
if not assign_to_params_buffers:
|
||||
v = v.clone()
|
||||
self.bias = torch.nn.Parameter(v, requires_grad=False)
|
||||
else:
|
||||
unexpected_keys.append(k)
|
||||
|
||||
#Reconcile default construction of the weight if its missing.
|
||||
if self.weight is None:
|
||||
v = torch.zeros(self.in_features, self.out_features)
|
||||
self.weight = torch.nn.Parameter(v, requires_grad=False)
|
||||
missing_keys.append(prefix+"weight")
|
||||
if self.bias is None and self.comfy_need_lazy_init_bias:
|
||||
v = torch.zeros(self.out_features,)
|
||||
self.bias = torch.nn.Parameter(v, requires_grad=False)
|
||||
missing_keys.append(prefix+"bias")
|
||||
disable_weight_init._lazy_load_from_state_dict(
|
||||
self,
|
||||
state_dict,
|
||||
prefix,
|
||||
local_metadata,
|
||||
missing_keys,
|
||||
unexpected_keys,
|
||||
weight_shape=(self.in_features, self.out_features),
|
||||
bias_shape=(self.out_features,),
|
||||
)
|
||||
|
||||
|
||||
def reset_parameters(self):
|
||||
@@ -547,6 +561,48 @@ class disable_weight_init:
|
||||
return super().forward(*args, **kwargs)
|
||||
|
||||
class Embedding(torch.nn.Embedding, CastWeightBiasOp):
|
||||
def __init__(self, num_embeddings, embedding_dim, padding_idx=None, max_norm=None,
|
||||
norm_type=2.0, scale_grad_by_freq=False, sparse=False, _weight=None,
|
||||
_freeze=False, device=None, dtype=None):
|
||||
if not comfy.model_management.WINDOWS or not comfy.memory_management.aimdo_enabled:
|
||||
super().__init__(num_embeddings, embedding_dim, padding_idx, max_norm,
|
||||
norm_type, scale_grad_by_freq, sparse, _weight,
|
||||
_freeze, device, dtype)
|
||||
return
|
||||
|
||||
torch.nn.Module.__init__(self)
|
||||
self.num_embeddings = num_embeddings
|
||||
self.embedding_dim = embedding_dim
|
||||
self.padding_idx = padding_idx
|
||||
self.max_norm = max_norm
|
||||
self.norm_type = norm_type
|
||||
self.scale_grad_by_freq = scale_grad_by_freq
|
||||
self.sparse = sparse
|
||||
# Keep shape/dtype visible for module introspection without reserving storage.
|
||||
embedding_dtype = dtype if dtype is not None else torch.get_default_dtype()
|
||||
self.weight = torch.nn.Parameter(
|
||||
torch.empty((num_embeddings, embedding_dim), device="meta", dtype=embedding_dtype),
|
||||
requires_grad=False,
|
||||
)
|
||||
self.bias = None
|
||||
self.weight_comfy_model_dtype = dtype
|
||||
|
||||
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
|
||||
strict, missing_keys, unexpected_keys, error_msgs):
|
||||
|
||||
if not comfy.model_management.WINDOWS or not comfy.memory_management.aimdo_enabled:
|
||||
return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
|
||||
missing_keys, unexpected_keys, error_msgs)
|
||||
disable_weight_init._lazy_load_from_state_dict(
|
||||
self,
|
||||
state_dict,
|
||||
prefix,
|
||||
local_metadata,
|
||||
missing_keys,
|
||||
unexpected_keys,
|
||||
weight_shape=(self.num_embeddings, self.embedding_dim),
|
||||
)
|
||||
|
||||
def reset_parameters(self):
|
||||
self.bias = None
|
||||
return None
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import torch
|
||||
import comfy.model_management
|
||||
import comfy.memory_management
|
||||
import comfy_aimdo.host_buffer
|
||||
import comfy_aimdo.torch
|
||||
|
||||
from comfy.cli_args import args
|
||||
|
||||
@@ -12,18 +13,31 @@ def pin_memory(module):
|
||||
return
|
||||
#FIXME: This is a RAM cache trigger event
|
||||
size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ])
|
||||
pin = torch.empty((size,), dtype=torch.uint8)
|
||||
if comfy.model_management.pin_memory(pin):
|
||||
module._pin = pin
|
||||
else:
|
||||
|
||||
if comfy.model_management.MAX_PINNED_MEMORY <= 0 or (comfy.model_management.TOTAL_PINNED_MEMORY + size) > comfy.model_management.MAX_PINNED_MEMORY:
|
||||
module.pin_failed = True
|
||||
return False
|
||||
|
||||
try:
|
||||
hostbuf = comfy_aimdo.host_buffer.HostBuffer(size)
|
||||
except RuntimeError:
|
||||
module.pin_failed = True
|
||||
return False
|
||||
|
||||
module._pin = comfy_aimdo.torch.hostbuf_to_tensor(hostbuf)
|
||||
module._pin_hostbuf = hostbuf
|
||||
comfy.model_management.TOTAL_PINNED_MEMORY += size
|
||||
return True
|
||||
|
||||
def unpin_memory(module):
|
||||
if get_pin(module) is None:
|
||||
return 0
|
||||
size = module._pin.numel() * module._pin.element_size()
|
||||
comfy.model_management.unpin_memory(module._pin)
|
||||
|
||||
comfy.model_management.TOTAL_PINNED_MEMORY -= size
|
||||
if comfy.model_management.TOTAL_PINNED_MEMORY < 0:
|
||||
comfy.model_management.TOTAL_PINNED_MEMORY = 0
|
||||
|
||||
del module._pin
|
||||
del module._pin_hostbuf
|
||||
return size
|
||||
|
||||
@@ -20,6 +20,8 @@
|
||||
import torch
|
||||
import math
|
||||
import struct
|
||||
import ctypes
|
||||
import os
|
||||
import comfy.memory_management
|
||||
import safetensors.torch
|
||||
import numpy as np
|
||||
@@ -32,7 +34,7 @@ from einops import rearrange
|
||||
from comfy.cli_args import args
|
||||
import json
|
||||
import time
|
||||
import mmap
|
||||
import threading
|
||||
import warnings
|
||||
|
||||
MMAP_TORCH_FILES = args.mmap_torch_files
|
||||
@@ -81,14 +83,17 @@ _TYPES = {
|
||||
}
|
||||
|
||||
def load_safetensors(ckpt):
|
||||
f = open(ckpt, "rb")
|
||||
mapping = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)
|
||||
mv = memoryview(mapping)
|
||||
import comfy_aimdo.model_mmap
|
||||
|
||||
header_size = struct.unpack("<Q", mapping[:8])[0]
|
||||
header = json.loads(mapping[8:8+header_size].decode("utf-8"))
|
||||
f = open(ckpt, "rb", buffering=0)
|
||||
model_mmap = comfy_aimdo.model_mmap.ModelMMAP(ckpt)
|
||||
file_size = os.path.getsize(ckpt)
|
||||
mv = memoryview((ctypes.c_uint8 * file_size).from_address(model_mmap.get()))
|
||||
|
||||
mv = mv[8 + header_size:]
|
||||
header_size = struct.unpack("<Q", mv[:8])[0]
|
||||
header = json.loads(mv[8:8 + header_size].tobytes().decode("utf-8"))
|
||||
|
||||
mv = mv[(data_base_offset := 8 + header_size):]
|
||||
|
||||
sd = {}
|
||||
for name, info in header.items():
|
||||
@@ -102,7 +107,14 @@ def load_safetensors(ckpt):
|
||||
with warnings.catch_warnings():
|
||||
#We are working with read-only RAM by design
|
||||
warnings.filterwarnings("ignore", message="The given buffer is not writable")
|
||||
sd[name] = torch.frombuffer(mv[start:end], dtype=_TYPES[info["dtype"]]).view(info["shape"])
|
||||
tensor = torch.frombuffer(mv[start:end], dtype=_TYPES[info["dtype"]]).view(info["shape"])
|
||||
storage = tensor.untyped_storage()
|
||||
setattr(storage,
|
||||
"_comfy_tensor_file_slice",
|
||||
comfy.memory_management.TensorFileSlice(f, threading.get_ident(), data_base_offset + start, end - start))
|
||||
setattr(storage, "_comfy_tensor_mmap_refs", (model_mmap, mv))
|
||||
setattr(storage, "_comfy_tensor_mmap_touched", False)
|
||||
sd[name] = tensor
|
||||
|
||||
return sd, header.get("__metadata__", {}),
|
||||
|
||||
|
||||
@@ -25,6 +25,7 @@ class ComfyAPI_latest(ComfyAPIBase):
|
||||
super().__init__()
|
||||
self.node_replacement = self.NodeReplacement()
|
||||
self.execution = self.Execution()
|
||||
self.caching = self.Caching()
|
||||
|
||||
class NodeReplacement(ProxiedSingleton):
|
||||
async def register(self, node_replace: io.NodeReplace) -> None:
|
||||
@@ -84,6 +85,36 @@ class ComfyAPI_latest(ComfyAPIBase):
|
||||
image=to_display,
|
||||
)
|
||||
|
||||
class Caching(ProxiedSingleton):
|
||||
"""
|
||||
External cache provider API for sharing cached node outputs
|
||||
across ComfyUI instances.
|
||||
|
||||
Example::
|
||||
|
||||
from comfy_api.latest import Caching
|
||||
|
||||
class MyCacheProvider(Caching.CacheProvider):
|
||||
async def on_lookup(self, context):
|
||||
... # check external storage
|
||||
|
||||
async def on_store(self, context, value):
|
||||
... # store to external storage
|
||||
|
||||
Caching.register_provider(MyCacheProvider())
|
||||
"""
|
||||
from ._caching import CacheProvider, CacheContext, CacheValue
|
||||
|
||||
async def register_provider(self, provider: "ComfyAPI_latest.Caching.CacheProvider") -> None:
|
||||
"""Register an external cache provider. Providers are called in registration order."""
|
||||
from comfy_execution.cache_provider import register_cache_provider
|
||||
register_cache_provider(provider)
|
||||
|
||||
async def unregister_provider(self, provider: "ComfyAPI_latest.Caching.CacheProvider") -> None:
|
||||
"""Unregister a previously registered cache provider."""
|
||||
from comfy_execution.cache_provider import unregister_cache_provider
|
||||
unregister_cache_provider(provider)
|
||||
|
||||
class ComfyExtension(ABC):
|
||||
async def on_load(self) -> None:
|
||||
"""
|
||||
@@ -116,6 +147,9 @@ class Types:
|
||||
VOXEL = VOXEL
|
||||
File3D = File3D
|
||||
|
||||
|
||||
Caching = ComfyAPI_latest.Caching
|
||||
|
||||
ComfyAPI = ComfyAPI_latest
|
||||
|
||||
# Create a synchronous version of the API
|
||||
@@ -135,6 +169,7 @@ __all__ = [
|
||||
"Input",
|
||||
"InputImpl",
|
||||
"Types",
|
||||
"Caching",
|
||||
"ComfyExtension",
|
||||
"io",
|
||||
"IO",
|
||||
|
||||
42
comfy_api/latest/_caching.py
Normal file
42
comfy_api/latest/_caching.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class CacheContext:
|
||||
node_id: str
|
||||
class_type: str
|
||||
cache_key_hash: str # SHA256 hex digest
|
||||
|
||||
|
||||
@dataclass
|
||||
class CacheValue:
|
||||
outputs: list
|
||||
ui: dict = None
|
||||
|
||||
|
||||
class CacheProvider(ABC):
|
||||
"""Abstract base class for external cache providers.
|
||||
Exceptions from provider methods are caught by the caller and never break execution.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def on_lookup(self, context: CacheContext) -> Optional[CacheValue]:
|
||||
"""Called on local cache miss. Return CacheValue if found, None otherwise."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def on_store(self, context: CacheContext, value: CacheValue) -> None:
|
||||
"""Called after local store. Dispatched via asyncio.create_task."""
|
||||
pass
|
||||
|
||||
def should_cache(self, context: CacheContext, value: Optional[CacheValue] = None) -> bool:
|
||||
"""Return False to skip external caching for this node. Default: True."""
|
||||
return True
|
||||
|
||||
def on_prompt_start(self, prompt_id: str) -> None:
|
||||
pass
|
||||
|
||||
def on_prompt_end(self, prompt_id: str) -> None:
|
||||
pass
|
||||
@@ -1,3 +1,7 @@
|
||||
import zipfile
|
||||
from io import BytesIO
|
||||
|
||||
import torch
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import IO, ComfyExtension, Input, Types
|
||||
@@ -17,7 +21,10 @@ from comfy_api_nodes.apis.hunyuan3d import (
|
||||
)
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
bytesio_to_image_tensor,
|
||||
download_url_to_bytesio,
|
||||
download_url_to_file_3d,
|
||||
download_url_to_image_tensor,
|
||||
downscale_image_tensor_by_max_side,
|
||||
poll_op,
|
||||
sync_op,
|
||||
@@ -36,6 +43,68 @@ def _is_tencent_rate_limited(status: int, body: object) -> bool:
|
||||
)
|
||||
|
||||
|
||||
class ObjZipResult:
|
||||
__slots__ = ("obj", "texture", "metallic", "normal", "roughness")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
obj: Types.File3D,
|
||||
texture: Input.Image | None = None,
|
||||
metallic: Input.Image | None = None,
|
||||
normal: Input.Image | None = None,
|
||||
roughness: Input.Image | None = None,
|
||||
):
|
||||
self.obj = obj
|
||||
self.texture = texture
|
||||
self.metallic = metallic
|
||||
self.normal = normal
|
||||
self.roughness = roughness
|
||||
|
||||
|
||||
async def download_and_extract_obj_zip(url: str) -> ObjZipResult:
|
||||
"""The Tencent API returns OBJ results as ZIP archives containing the .obj mesh, and texture images.
|
||||
|
||||
When PBR is enabled, the ZIP may contain additional metallic, normal, and roughness maps
|
||||
identified by their filename suffixes.
|
||||
"""
|
||||
data = BytesIO()
|
||||
await download_url_to_bytesio(url, data)
|
||||
data.seek(0)
|
||||
if not zipfile.is_zipfile(data):
|
||||
data.seek(0)
|
||||
return ObjZipResult(obj=Types.File3D(source=data, file_format="obj"))
|
||||
data.seek(0)
|
||||
obj_bytes = None
|
||||
textures: dict[str, Input.Image] = {}
|
||||
with zipfile.ZipFile(data) as zf:
|
||||
for name in zf.namelist():
|
||||
lower = name.lower()
|
||||
if lower.endswith(".obj"):
|
||||
obj_bytes = zf.read(name)
|
||||
elif any(lower.endswith(ext) for ext in (".png", ".jpg", ".jpeg", ".bmp", ".tiff", ".webp")):
|
||||
stem = lower.rsplit(".", 1)[0]
|
||||
tensor = bytesio_to_image_tensor(BytesIO(zf.read(name)), mode="RGB")
|
||||
matched_key = "texture"
|
||||
for suffix, key in {
|
||||
"_metallic": "metallic",
|
||||
"_normal": "normal",
|
||||
"_roughness": "roughness",
|
||||
}.items():
|
||||
if stem.endswith(suffix):
|
||||
matched_key = key
|
||||
break
|
||||
textures[matched_key] = tensor
|
||||
if obj_bytes is None:
|
||||
raise ValueError("ZIP archive does not contain an OBJ file.")
|
||||
return ObjZipResult(
|
||||
obj=Types.File3D(source=BytesIO(obj_bytes), file_format="obj"),
|
||||
texture=textures.get("texture"),
|
||||
metallic=textures.get("metallic"),
|
||||
normal=textures.get("normal"),
|
||||
roughness=textures.get("roughness"),
|
||||
)
|
||||
|
||||
|
||||
def get_file_from_response(
|
||||
response_objs: list[ResultFile3D], file_type: str, raise_if_not_found: bool = True
|
||||
) -> ResultFile3D | None:
|
||||
@@ -93,6 +162,7 @@ class TencentTextToModelNode(IO.ComfyNode):
|
||||
IO.String.Output(display_name="model_file"), # for backward compatibility only
|
||||
IO.File3DGLB.Output(display_name="GLB"),
|
||||
IO.File3DOBJ.Output(display_name="OBJ"),
|
||||
IO.Image.Output(display_name="texture_image"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
@@ -151,14 +221,14 @@ class TencentTextToModelNode(IO.ComfyNode):
|
||||
response_model=To3DProTaskResultResponse,
|
||||
status_extractor=lambda r: r.Status,
|
||||
)
|
||||
obj_result = await download_and_extract_obj_zip(get_file_from_response(result.ResultFile3Ds, "obj").Url)
|
||||
return IO.NodeOutput(
|
||||
f"{task_id}.glb",
|
||||
await download_url_to_file_3d(
|
||||
get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb", task_id=task_id
|
||||
),
|
||||
await download_url_to_file_3d(
|
||||
get_file_from_response(result.ResultFile3Ds, "obj").Url, "obj", task_id=task_id
|
||||
),
|
||||
obj_result.obj,
|
||||
obj_result.texture,
|
||||
)
|
||||
|
||||
|
||||
@@ -211,6 +281,10 @@ class TencentImageToModelNode(IO.ComfyNode):
|
||||
IO.String.Output(display_name="model_file"), # for backward compatibility only
|
||||
IO.File3DGLB.Output(display_name="GLB"),
|
||||
IO.File3DOBJ.Output(display_name="OBJ"),
|
||||
IO.Image.Output(display_name="texture_image"),
|
||||
IO.Image.Output(display_name="optional_metallic"),
|
||||
IO.Image.Output(display_name="optional_normal"),
|
||||
IO.Image.Output(display_name="optional_roughness"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
@@ -304,14 +378,17 @@ class TencentImageToModelNode(IO.ComfyNode):
|
||||
response_model=To3DProTaskResultResponse,
|
||||
status_extractor=lambda r: r.Status,
|
||||
)
|
||||
obj_result = await download_and_extract_obj_zip(get_file_from_response(result.ResultFile3Ds, "obj").Url)
|
||||
return IO.NodeOutput(
|
||||
f"{task_id}.glb",
|
||||
await download_url_to_file_3d(
|
||||
get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb", task_id=task_id
|
||||
),
|
||||
await download_url_to_file_3d(
|
||||
get_file_from_response(result.ResultFile3Ds, "obj").Url, "obj", task_id=task_id
|
||||
),
|
||||
obj_result.obj,
|
||||
obj_result.texture,
|
||||
obj_result.metallic if obj_result.metallic is not None else torch.zeros(1, 1, 1, 3),
|
||||
obj_result.normal if obj_result.normal is not None else torch.zeros(1, 1, 1, 3),
|
||||
obj_result.roughness if obj_result.roughness is not None else torch.zeros(1, 1, 1, 3),
|
||||
)
|
||||
|
||||
|
||||
@@ -431,7 +508,8 @@ class Tencent3DTextureEditNode(IO.ComfyNode):
|
||||
],
|
||||
outputs=[
|
||||
IO.File3DGLB.Output(display_name="GLB"),
|
||||
IO.File3DFBX.Output(display_name="FBX"),
|
||||
IO.File3DOBJ.Output(display_name="OBJ"),
|
||||
IO.Image.Output(display_name="texture_image"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
@@ -480,7 +558,8 @@ class Tencent3DTextureEditNode(IO.ComfyNode):
|
||||
)
|
||||
return IO.NodeOutput(
|
||||
await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb"),
|
||||
await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "fbx").Url, "fbx"),
|
||||
await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "obj").Url, "obj"),
|
||||
await download_url_to_image_tensor(get_file_from_response(result.ResultFile3Ds, "texture_image").Url),
|
||||
)
|
||||
|
||||
|
||||
@@ -654,7 +733,7 @@ class TencentHunyuan3DExtension(ComfyExtension):
|
||||
TencentTextToModelNode,
|
||||
TencentImageToModelNode,
|
||||
TencentModelTo3DUVNode,
|
||||
# Tencent3DTextureEditNode,
|
||||
Tencent3DTextureEditNode,
|
||||
Tencent3DPartNode,
|
||||
TencentSmartTopologyNode,
|
||||
]
|
||||
|
||||
138
comfy_execution/cache_provider.py
Normal file
138
comfy_execution/cache_provider.py
Normal file
@@ -0,0 +1,138 @@
|
||||
from typing import Any, Optional, Tuple, List
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
|
||||
# Public types — source of truth is comfy_api.latest._caching
|
||||
from comfy_api.latest._caching import CacheProvider, CacheContext, CacheValue # noqa: F401 (re-exported)
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
_providers: List[CacheProvider] = []
|
||||
_providers_lock = threading.Lock()
|
||||
_providers_snapshot: Tuple[CacheProvider, ...] = ()
|
||||
|
||||
|
||||
def register_cache_provider(provider: CacheProvider) -> None:
|
||||
"""Register an external cache provider. Providers are called in registration order."""
|
||||
global _providers_snapshot
|
||||
with _providers_lock:
|
||||
if provider in _providers:
|
||||
_logger.warning(f"Provider {provider.__class__.__name__} already registered")
|
||||
return
|
||||
_providers.append(provider)
|
||||
_providers_snapshot = tuple(_providers)
|
||||
_logger.debug(f"Registered cache provider: {provider.__class__.__name__}")
|
||||
|
||||
|
||||
def unregister_cache_provider(provider: CacheProvider) -> None:
|
||||
global _providers_snapshot
|
||||
with _providers_lock:
|
||||
try:
|
||||
_providers.remove(provider)
|
||||
_providers_snapshot = tuple(_providers)
|
||||
_logger.debug(f"Unregistered cache provider: {provider.__class__.__name__}")
|
||||
except ValueError:
|
||||
_logger.warning(f"Provider {provider.__class__.__name__} was not registered")
|
||||
|
||||
|
||||
def _get_cache_providers() -> Tuple[CacheProvider, ...]:
|
||||
return _providers_snapshot
|
||||
|
||||
|
||||
def _has_cache_providers() -> bool:
|
||||
return bool(_providers_snapshot)
|
||||
|
||||
|
||||
def _clear_cache_providers() -> None:
|
||||
global _providers_snapshot
|
||||
with _providers_lock:
|
||||
_providers.clear()
|
||||
_providers_snapshot = ()
|
||||
|
||||
|
||||
def _canonicalize(obj: Any) -> Any:
|
||||
# Convert to canonical JSON-serializable form with deterministic ordering.
|
||||
# Frozensets have non-deterministic iteration order between Python sessions.
|
||||
# Raises ValueError for non-cacheable types (Unhashable, unknown) so that
|
||||
# _serialize_cache_key returns None and external caching is skipped.
|
||||
if isinstance(obj, frozenset):
|
||||
return ("__frozenset__", sorted(
|
||||
[_canonicalize(item) for item in obj],
|
||||
key=lambda x: json.dumps(x, sort_keys=True)
|
||||
))
|
||||
elif isinstance(obj, set):
|
||||
return ("__set__", sorted(
|
||||
[_canonicalize(item) for item in obj],
|
||||
key=lambda x: json.dumps(x, sort_keys=True)
|
||||
))
|
||||
elif isinstance(obj, tuple):
|
||||
return ("__tuple__", [_canonicalize(item) for item in obj])
|
||||
elif isinstance(obj, list):
|
||||
return [_canonicalize(item) for item in obj]
|
||||
elif isinstance(obj, dict):
|
||||
return {"__dict__": sorted(
|
||||
[[_canonicalize(k), _canonicalize(v)] for k, v in obj.items()],
|
||||
key=lambda x: json.dumps(x, sort_keys=True)
|
||||
)}
|
||||
elif isinstance(obj, (int, float, str, bool, type(None))):
|
||||
return (type(obj).__name__, obj)
|
||||
elif isinstance(obj, bytes):
|
||||
return ("__bytes__", obj.hex())
|
||||
else:
|
||||
raise ValueError(f"Cannot canonicalize type: {type(obj).__name__}")
|
||||
|
||||
|
||||
def _serialize_cache_key(cache_key: Any) -> Optional[str]:
|
||||
# Returns deterministic SHA256 hex digest, or None on failure.
|
||||
# Uses JSON (not pickle) because pickle is non-deterministic across sessions.
|
||||
try:
|
||||
canonical = _canonicalize(cache_key)
|
||||
json_str = json.dumps(canonical, sort_keys=True, separators=(',', ':'))
|
||||
return hashlib.sha256(json_str.encode('utf-8')).hexdigest()
|
||||
except Exception as e:
|
||||
_logger.warning(f"Failed to serialize cache key: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def _contains_self_unequal(obj: Any) -> bool:
|
||||
# Local cache matches by ==. Values where not (x == x) (NaN, etc.) will
|
||||
# never hit locally, but serialized form would match externally. Skip these.
|
||||
try:
|
||||
if not (obj == obj):
|
||||
return True
|
||||
except Exception:
|
||||
return True
|
||||
if isinstance(obj, (frozenset, tuple, list, set)):
|
||||
return any(_contains_self_unequal(item) for item in obj)
|
||||
if isinstance(obj, dict):
|
||||
return any(_contains_self_unequal(k) or _contains_self_unequal(v) for k, v in obj.items())
|
||||
if hasattr(obj, 'value'):
|
||||
return _contains_self_unequal(obj.value)
|
||||
return False
|
||||
|
||||
|
||||
def _estimate_value_size(value: CacheValue) -> int:
|
||||
try:
|
||||
import torch
|
||||
except ImportError:
|
||||
return 0
|
||||
|
||||
total = 0
|
||||
|
||||
def estimate(obj):
|
||||
nonlocal total
|
||||
if isinstance(obj, torch.Tensor):
|
||||
total += obj.numel() * obj.element_size()
|
||||
elif isinstance(obj, dict):
|
||||
for v in obj.values():
|
||||
estimate(v)
|
||||
elif isinstance(obj, (list, tuple)):
|
||||
for item in obj:
|
||||
estimate(item)
|
||||
|
||||
for output in value.outputs:
|
||||
estimate(output)
|
||||
return total
|
||||
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import bisect
|
||||
import gc
|
||||
import itertools
|
||||
@@ -147,13 +148,15 @@ class CacheKeySetInputSignature(CacheKeySet):
|
||||
self.get_ordered_ancestry_internal(dynprompt, ancestor_id, ancestors, order_mapping)
|
||||
|
||||
class BasicCache:
|
||||
def __init__(self, key_class):
|
||||
def __init__(self, key_class, enable_providers=False):
|
||||
self.key_class = key_class
|
||||
self.initialized = False
|
||||
self.enable_providers = enable_providers
|
||||
self.dynprompt: DynamicPrompt
|
||||
self.cache_key_set: CacheKeySet
|
||||
self.cache = {}
|
||||
self.subcaches = {}
|
||||
self._pending_store_tasks: set = set()
|
||||
|
||||
async def set_prompt(self, dynprompt, node_ids, is_changed_cache):
|
||||
self.dynprompt = dynprompt
|
||||
@@ -196,18 +199,138 @@ class BasicCache:
|
||||
def poll(self, **kwargs):
|
||||
pass
|
||||
|
||||
def _set_immediate(self, node_id, value):
|
||||
assert self.initialized
|
||||
cache_key = self.cache_key_set.get_data_key(node_id)
|
||||
self.cache[cache_key] = value
|
||||
|
||||
def _get_immediate(self, node_id):
|
||||
def get_local(self, node_id):
|
||||
if not self.initialized:
|
||||
return None
|
||||
cache_key = self.cache_key_set.get_data_key(node_id)
|
||||
if cache_key in self.cache:
|
||||
return self.cache[cache_key]
|
||||
else:
|
||||
return None
|
||||
|
||||
def set_local(self, node_id, value):
|
||||
assert self.initialized
|
||||
cache_key = self.cache_key_set.get_data_key(node_id)
|
||||
self.cache[cache_key] = value
|
||||
|
||||
async def _set_immediate(self, node_id, value):
|
||||
assert self.initialized
|
||||
cache_key = self.cache_key_set.get_data_key(node_id)
|
||||
self.cache[cache_key] = value
|
||||
|
||||
await self._notify_providers_store(node_id, cache_key, value)
|
||||
|
||||
async def _get_immediate(self, node_id):
|
||||
if not self.initialized:
|
||||
return None
|
||||
cache_key = self.cache_key_set.get_data_key(node_id)
|
||||
|
||||
if cache_key in self.cache:
|
||||
return self.cache[cache_key]
|
||||
|
||||
external_result = await self._check_providers_lookup(node_id, cache_key)
|
||||
if external_result is not None:
|
||||
self.cache[cache_key] = external_result
|
||||
return external_result
|
||||
|
||||
return None
|
||||
|
||||
async def _notify_providers_store(self, node_id, cache_key, value):
|
||||
from comfy_execution.cache_provider import (
|
||||
_has_cache_providers, _get_cache_providers,
|
||||
CacheValue, _contains_self_unequal, _logger
|
||||
)
|
||||
|
||||
if not self.enable_providers:
|
||||
return
|
||||
if not _has_cache_providers():
|
||||
return
|
||||
if not self._is_external_cacheable_value(value):
|
||||
return
|
||||
if _contains_self_unequal(cache_key):
|
||||
return
|
||||
|
||||
context = self._build_context(node_id, cache_key)
|
||||
if context is None:
|
||||
return
|
||||
cache_value = CacheValue(outputs=value.outputs, ui=value.ui)
|
||||
|
||||
for provider in _get_cache_providers():
|
||||
try:
|
||||
if provider.should_cache(context, cache_value):
|
||||
task = asyncio.create_task(self._safe_provider_store(provider, context, cache_value))
|
||||
self._pending_store_tasks.add(task)
|
||||
task.add_done_callback(self._pending_store_tasks.discard)
|
||||
except Exception as e:
|
||||
_logger.warning(f"Cache provider {provider.__class__.__name__} error on store: {e}")
|
||||
|
||||
@staticmethod
|
||||
async def _safe_provider_store(provider, context, cache_value):
|
||||
from comfy_execution.cache_provider import _logger
|
||||
try:
|
||||
await provider.on_store(context, cache_value)
|
||||
except Exception as e:
|
||||
_logger.warning(f"Cache provider {provider.__class__.__name__} async store error: {e}")
|
||||
|
||||
async def _check_providers_lookup(self, node_id, cache_key):
|
||||
from comfy_execution.cache_provider import (
|
||||
_has_cache_providers, _get_cache_providers,
|
||||
CacheValue, _contains_self_unequal, _logger
|
||||
)
|
||||
|
||||
if not self.enable_providers:
|
||||
return None
|
||||
if not _has_cache_providers():
|
||||
return None
|
||||
if _contains_self_unequal(cache_key):
|
||||
return None
|
||||
|
||||
context = self._build_context(node_id, cache_key)
|
||||
if context is None:
|
||||
return None
|
||||
|
||||
for provider in _get_cache_providers():
|
||||
try:
|
||||
if not provider.should_cache(context):
|
||||
continue
|
||||
result = await provider.on_lookup(context)
|
||||
if result is not None:
|
||||
if not isinstance(result, CacheValue):
|
||||
_logger.warning(f"Provider {provider.__class__.__name__} returned invalid type")
|
||||
continue
|
||||
if not isinstance(result.outputs, (list, tuple)):
|
||||
_logger.warning(f"Provider {provider.__class__.__name__} returned invalid outputs")
|
||||
continue
|
||||
from execution import CacheEntry
|
||||
return CacheEntry(ui=result.ui, outputs=list(result.outputs))
|
||||
except Exception as e:
|
||||
_logger.warning(f"Cache provider {provider.__class__.__name__} error on lookup: {e}")
|
||||
|
||||
return None
|
||||
|
||||
def _is_external_cacheable_value(self, value):
|
||||
return hasattr(value, 'outputs') and hasattr(value, 'ui')
|
||||
|
||||
def _get_class_type(self, node_id):
|
||||
if not self.initialized or not self.dynprompt:
|
||||
return ''
|
||||
try:
|
||||
return self.dynprompt.get_node(node_id).get('class_type', '')
|
||||
except Exception:
|
||||
return ''
|
||||
|
||||
def _build_context(self, node_id, cache_key):
|
||||
from comfy_execution.cache_provider import CacheContext, _serialize_cache_key, _logger
|
||||
try:
|
||||
cache_key_hash = _serialize_cache_key(cache_key)
|
||||
if cache_key_hash is None:
|
||||
return None
|
||||
return CacheContext(
|
||||
node_id=node_id,
|
||||
class_type=self._get_class_type(node_id),
|
||||
cache_key_hash=cache_key_hash,
|
||||
)
|
||||
except Exception as e:
|
||||
_logger.warning(f"Failed to build cache context for node {node_id}: {e}")
|
||||
return None
|
||||
|
||||
async def _ensure_subcache(self, node_id, children_ids):
|
||||
@@ -236,8 +359,8 @@ class BasicCache:
|
||||
return result
|
||||
|
||||
class HierarchicalCache(BasicCache):
|
||||
def __init__(self, key_class):
|
||||
super().__init__(key_class)
|
||||
def __init__(self, key_class, enable_providers=False):
|
||||
super().__init__(key_class, enable_providers=enable_providers)
|
||||
|
||||
def _get_cache_for(self, node_id):
|
||||
assert self.dynprompt is not None
|
||||
@@ -257,16 +380,27 @@ class HierarchicalCache(BasicCache):
|
||||
return None
|
||||
return cache
|
||||
|
||||
def get(self, node_id):
|
||||
async def get(self, node_id):
|
||||
cache = self._get_cache_for(node_id)
|
||||
if cache is None:
|
||||
return None
|
||||
return cache._get_immediate(node_id)
|
||||
return await cache._get_immediate(node_id)
|
||||
|
||||
def set(self, node_id, value):
|
||||
def get_local(self, node_id):
|
||||
cache = self._get_cache_for(node_id)
|
||||
if cache is None:
|
||||
return None
|
||||
return BasicCache.get_local(cache, node_id)
|
||||
|
||||
async def set(self, node_id, value):
|
||||
cache = self._get_cache_for(node_id)
|
||||
assert cache is not None
|
||||
cache._set_immediate(node_id, value)
|
||||
await cache._set_immediate(node_id, value)
|
||||
|
||||
def set_local(self, node_id, value):
|
||||
cache = self._get_cache_for(node_id)
|
||||
assert cache is not None
|
||||
BasicCache.set_local(cache, node_id, value)
|
||||
|
||||
async def ensure_subcache_for(self, node_id, children_ids):
|
||||
cache = self._get_cache_for(node_id)
|
||||
@@ -287,18 +421,24 @@ class NullCache:
|
||||
def poll(self, **kwargs):
|
||||
pass
|
||||
|
||||
def get(self, node_id):
|
||||
async def get(self, node_id):
|
||||
return None
|
||||
|
||||
def set(self, node_id, value):
|
||||
def get_local(self, node_id):
|
||||
return None
|
||||
|
||||
async def set(self, node_id, value):
|
||||
pass
|
||||
|
||||
def set_local(self, node_id, value):
|
||||
pass
|
||||
|
||||
async def ensure_subcache_for(self, node_id, children_ids):
|
||||
return self
|
||||
|
||||
class LRUCache(BasicCache):
|
||||
def __init__(self, key_class, max_size=100):
|
||||
super().__init__(key_class)
|
||||
def __init__(self, key_class, max_size=100, enable_providers=False):
|
||||
super().__init__(key_class, enable_providers=enable_providers)
|
||||
self.max_size = max_size
|
||||
self.min_generation = 0
|
||||
self.generation = 0
|
||||
@@ -322,18 +462,18 @@ class LRUCache(BasicCache):
|
||||
del self.children[key]
|
||||
self._clean_subcaches()
|
||||
|
||||
def get(self, node_id):
|
||||
async def get(self, node_id):
|
||||
self._mark_used(node_id)
|
||||
return self._get_immediate(node_id)
|
||||
return await self._get_immediate(node_id)
|
||||
|
||||
def _mark_used(self, node_id):
|
||||
cache_key = self.cache_key_set.get_data_key(node_id)
|
||||
if cache_key is not None:
|
||||
self.used_generation[cache_key] = self.generation
|
||||
|
||||
def set(self, node_id, value):
|
||||
async def set(self, node_id, value):
|
||||
self._mark_used(node_id)
|
||||
return self._set_immediate(node_id, value)
|
||||
return await self._set_immediate(node_id, value)
|
||||
|
||||
async def ensure_subcache_for(self, node_id, children_ids):
|
||||
# Just uses subcaches for tracking 'live' nodes
|
||||
@@ -366,20 +506,20 @@ RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER = 1.3
|
||||
|
||||
class RAMPressureCache(LRUCache):
|
||||
|
||||
def __init__(self, key_class):
|
||||
super().__init__(key_class, 0)
|
||||
def __init__(self, key_class, enable_providers=False):
|
||||
super().__init__(key_class, 0, enable_providers=enable_providers)
|
||||
self.timestamps = {}
|
||||
|
||||
def clean_unused(self):
|
||||
self._clean_subcaches()
|
||||
|
||||
def set(self, node_id, value):
|
||||
async def set(self, node_id, value):
|
||||
self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
|
||||
super().set(node_id, value)
|
||||
await super().set(node_id, value)
|
||||
|
||||
def get(self, node_id):
|
||||
async def get(self, node_id):
|
||||
self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
|
||||
return super().get(node_id)
|
||||
return await super().get(node_id)
|
||||
|
||||
def poll(self, ram_headroom):
|
||||
def _ram_gb():
|
||||
|
||||
@@ -204,12 +204,12 @@ class ExecutionList(TopologicalSort):
|
||||
self.execution_cache_listeners = {}
|
||||
|
||||
def is_cached(self, node_id):
|
||||
return self.output_cache.get(node_id) is not None
|
||||
return self.output_cache.get_local(node_id) is not None
|
||||
|
||||
def cache_link(self, from_node_id, to_node_id):
|
||||
if to_node_id not in self.execution_cache:
|
||||
self.execution_cache[to_node_id] = {}
|
||||
self.execution_cache[to_node_id][from_node_id] = self.output_cache.get(from_node_id)
|
||||
self.execution_cache[to_node_id][from_node_id] = self.output_cache.get_local(from_node_id)
|
||||
if from_node_id not in self.execution_cache_listeners:
|
||||
self.execution_cache_listeners[from_node_id] = set()
|
||||
self.execution_cache_listeners[from_node_id].add(to_node_id)
|
||||
@@ -221,7 +221,7 @@ class ExecutionList(TopologicalSort):
|
||||
if value is None:
|
||||
return None
|
||||
#Write back to the main cache on touch.
|
||||
self.output_cache.set(from_node_id, value)
|
||||
self.output_cache.set_local(from_node_id, value)
|
||||
return value
|
||||
|
||||
def cache_update(self, node_id, value):
|
||||
|
||||
147
execution.py
147
execution.py
@@ -40,6 +40,7 @@ from comfy_execution.progress import get_progress_state, reset_progress_state, a
|
||||
from comfy_execution.utils import CurrentNodeContext
|
||||
from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func
|
||||
from comfy_api.latest import io, _io
|
||||
from comfy_execution.cache_provider import _has_cache_providers, _get_cache_providers, _logger as _cache_logger
|
||||
|
||||
|
||||
class ExecutionResult(Enum):
|
||||
@@ -126,15 +127,15 @@ class CacheSet:
|
||||
|
||||
# Performs like the old cache -- dump data ASAP
|
||||
def init_classic_cache(self):
|
||||
self.outputs = HierarchicalCache(CacheKeySetInputSignature)
|
||||
self.outputs = HierarchicalCache(CacheKeySetInputSignature, enable_providers=True)
|
||||
self.objects = HierarchicalCache(CacheKeySetID)
|
||||
|
||||
def init_lru_cache(self, cache_size):
|
||||
self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
|
||||
self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size, enable_providers=True)
|
||||
self.objects = HierarchicalCache(CacheKeySetID)
|
||||
|
||||
def init_ram_cache(self, min_headroom):
|
||||
self.outputs = RAMPressureCache(CacheKeySetInputSignature)
|
||||
self.outputs = RAMPressureCache(CacheKeySetInputSignature, enable_providers=True)
|
||||
self.objects = HierarchicalCache(CacheKeySetID)
|
||||
|
||||
def init_null_cache(self):
|
||||
@@ -418,7 +419,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
||||
inputs = dynprompt.get_node(unique_id)['inputs']
|
||||
class_type = dynprompt.get_node(unique_id)['class_type']
|
||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||
cached = caches.outputs.get(unique_id)
|
||||
cached = await caches.outputs.get(unique_id)
|
||||
if cached is not None:
|
||||
if server.client_id is not None:
|
||||
cached_ui = cached.ui or {}
|
||||
@@ -474,10 +475,10 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
||||
server.last_node_id = display_node_id
|
||||
server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id)
|
||||
|
||||
obj = caches.objects.get(unique_id)
|
||||
obj = await caches.objects.get(unique_id)
|
||||
if obj is None:
|
||||
obj = class_def()
|
||||
caches.objects.set(unique_id, obj)
|
||||
await caches.objects.set(unique_id, obj)
|
||||
|
||||
if issubclass(class_def, _ComfyNodeInternal):
|
||||
lazy_status_present = first_real_override(class_def, "check_lazy_status") is not None
|
||||
@@ -588,7 +589,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
||||
|
||||
cache_entry = CacheEntry(ui=ui_outputs.get(unique_id), outputs=output_data)
|
||||
execution_list.cache_update(unique_id, cache_entry)
|
||||
caches.outputs.set(unique_id, cache_entry)
|
||||
await caches.outputs.set(unique_id, cache_entry)
|
||||
|
||||
except comfy.model_management.InterruptProcessingException as iex:
|
||||
logging.info("Processing interrupted")
|
||||
@@ -684,6 +685,19 @@ class PromptExecutor:
|
||||
}
|
||||
self.add_message("execution_error", mes, broadcast=False)
|
||||
|
||||
def _notify_prompt_lifecycle(self, event: str, prompt_id: str):
|
||||
if not _has_cache_providers():
|
||||
return
|
||||
|
||||
for provider in _get_cache_providers():
|
||||
try:
|
||||
if event == "start":
|
||||
provider.on_prompt_start(prompt_id)
|
||||
elif event == "end":
|
||||
provider.on_prompt_end(prompt_id)
|
||||
except Exception as e:
|
||||
_cache_logger.warning(f"Cache provider {provider.__class__.__name__} error on {event}: {e}")
|
||||
|
||||
def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
|
||||
asyncio.run(self.execute_async(prompt, prompt_id, extra_data, execute_outputs))
|
||||
|
||||
@@ -700,66 +714,75 @@ class PromptExecutor:
|
||||
self.status_messages = []
|
||||
self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False)
|
||||
|
||||
with torch.inference_mode():
|
||||
dynamic_prompt = DynamicPrompt(prompt)
|
||||
reset_progress_state(prompt_id, dynamic_prompt)
|
||||
add_progress_handler(WebUIProgressHandler(self.server))
|
||||
is_changed_cache = IsChangedCache(prompt_id, dynamic_prompt, self.caches.outputs)
|
||||
for cache in self.caches.all:
|
||||
await cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache)
|
||||
cache.clean_unused()
|
||||
self._notify_prompt_lifecycle("start", prompt_id)
|
||||
|
||||
cached_nodes = []
|
||||
for node_id in prompt:
|
||||
if self.caches.outputs.get(node_id) is not None:
|
||||
cached_nodes.append(node_id)
|
||||
try:
|
||||
with torch.inference_mode():
|
||||
dynamic_prompt = DynamicPrompt(prompt)
|
||||
reset_progress_state(prompt_id, dynamic_prompt)
|
||||
add_progress_handler(WebUIProgressHandler(self.server))
|
||||
is_changed_cache = IsChangedCache(prompt_id, dynamic_prompt, self.caches.outputs)
|
||||
for cache in self.caches.all:
|
||||
await cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache)
|
||||
cache.clean_unused()
|
||||
|
||||
comfy.model_management.cleanup_models_gc()
|
||||
self.add_message("execution_cached",
|
||||
{ "nodes": cached_nodes, "prompt_id": prompt_id},
|
||||
broadcast=False)
|
||||
pending_subgraph_results = {}
|
||||
pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results
|
||||
ui_node_outputs = {}
|
||||
executed = set()
|
||||
execution_list = ExecutionList(dynamic_prompt, self.caches.outputs)
|
||||
current_outputs = self.caches.outputs.all_node_ids()
|
||||
for node_id in list(execute_outputs):
|
||||
execution_list.add_node(node_id)
|
||||
node_ids = list(prompt.keys())
|
||||
cache_results = await asyncio.gather(
|
||||
*(self.caches.outputs.get(node_id) for node_id in node_ids)
|
||||
)
|
||||
cached_nodes = [
|
||||
node_id for node_id, result in zip(node_ids, cache_results)
|
||||
if result is not None
|
||||
]
|
||||
|
||||
while not execution_list.is_empty():
|
||||
node_id, error, ex = await execution_list.stage_node_execution()
|
||||
if error is not None:
|
||||
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
|
||||
break
|
||||
comfy.model_management.cleanup_models_gc()
|
||||
self.add_message("execution_cached",
|
||||
{ "nodes": cached_nodes, "prompt_id": prompt_id},
|
||||
broadcast=False)
|
||||
pending_subgraph_results = {}
|
||||
pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results
|
||||
ui_node_outputs = {}
|
||||
executed = set()
|
||||
execution_list = ExecutionList(dynamic_prompt, self.caches.outputs)
|
||||
current_outputs = self.caches.outputs.all_node_ids()
|
||||
for node_id in list(execute_outputs):
|
||||
execution_list.add_node(node_id)
|
||||
|
||||
assert node_id is not None, "Node ID should not be None at this point"
|
||||
result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs)
|
||||
self.success = result != ExecutionResult.FAILURE
|
||||
if result == ExecutionResult.FAILURE:
|
||||
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
|
||||
break
|
||||
elif result == ExecutionResult.PENDING:
|
||||
execution_list.unstage_node_execution()
|
||||
else: # result == ExecutionResult.SUCCESS:
|
||||
execution_list.complete_node_execution()
|
||||
self.caches.outputs.poll(ram_headroom=self.cache_args["ram"])
|
||||
else:
|
||||
# Only execute when the while-loop ends without break
|
||||
self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)
|
||||
while not execution_list.is_empty():
|
||||
node_id, error, ex = await execution_list.stage_node_execution()
|
||||
if error is not None:
|
||||
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
|
||||
break
|
||||
|
||||
ui_outputs = {}
|
||||
meta_outputs = {}
|
||||
for node_id, ui_info in ui_node_outputs.items():
|
||||
ui_outputs[node_id] = ui_info["output"]
|
||||
meta_outputs[node_id] = ui_info["meta"]
|
||||
self.history_result = {
|
||||
"outputs": ui_outputs,
|
||||
"meta": meta_outputs,
|
||||
}
|
||||
self.server.last_node_id = None
|
||||
if comfy.model_management.DISABLE_SMART_MEMORY:
|
||||
comfy.model_management.unload_all_models()
|
||||
assert node_id is not None, "Node ID should not be None at this point"
|
||||
result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs)
|
||||
self.success = result != ExecutionResult.FAILURE
|
||||
if result == ExecutionResult.FAILURE:
|
||||
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
|
||||
break
|
||||
elif result == ExecutionResult.PENDING:
|
||||
execution_list.unstage_node_execution()
|
||||
else: # result == ExecutionResult.SUCCESS:
|
||||
execution_list.complete_node_execution()
|
||||
self.caches.outputs.poll(ram_headroom=self.cache_args["ram"])
|
||||
else:
|
||||
# Only execute when the while-loop ends without break
|
||||
self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)
|
||||
|
||||
ui_outputs = {}
|
||||
meta_outputs = {}
|
||||
for node_id, ui_info in ui_node_outputs.items():
|
||||
ui_outputs[node_id] = ui_info["output"]
|
||||
meta_outputs[node_id] = ui_info["meta"]
|
||||
self.history_result = {
|
||||
"outputs": ui_outputs,
|
||||
"meta": meta_outputs,
|
||||
}
|
||||
self.server.last_node_id = None
|
||||
if comfy.model_management.DISABLE_SMART_MEMORY:
|
||||
comfy.model_management.unload_all_models()
|
||||
finally:
|
||||
self._notify_prompt_lifecycle("end", prompt_id)
|
||||
|
||||
|
||||
async def validate_inputs(prompt_id, prompt, item, validated):
|
||||
|
||||
@@ -1 +1 @@
|
||||
comfyui_manager==4.1b2
|
||||
comfyui_manager==4.1b4
|
||||
@@ -1,4 +1,4 @@
|
||||
comfyui-frontend-package==1.41.18
|
||||
comfyui-frontend-package==1.41.19
|
||||
comfyui-workflow-templates==0.9.21
|
||||
comfyui-embedded-docs==0.4.3
|
||||
torch
|
||||
@@ -23,7 +23,7 @@ SQLAlchemy
|
||||
filelock
|
||||
av>=14.2.0
|
||||
comfy-kitchen>=0.2.8
|
||||
comfy-aimdo>=0.2.10
|
||||
comfy-aimdo>=0.2.11
|
||||
requests
|
||||
simpleeval>=1.0.0
|
||||
blake3
|
||||
|
||||
403
tests-unit/execution_test/test_cache_provider.py
Normal file
403
tests-unit/execution_test/test_cache_provider.py
Normal file
@@ -0,0 +1,403 @@
|
||||
"""Tests for external cache provider API."""
|
||||
|
||||
import importlib.util
|
||||
import pytest
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def _torch_available() -> bool:
|
||||
"""Check if PyTorch is available."""
|
||||
return importlib.util.find_spec("torch") is not None
|
||||
|
||||
|
||||
from comfy_execution.cache_provider import (
|
||||
CacheProvider,
|
||||
CacheContext,
|
||||
CacheValue,
|
||||
register_cache_provider,
|
||||
unregister_cache_provider,
|
||||
_get_cache_providers,
|
||||
_has_cache_providers,
|
||||
_clear_cache_providers,
|
||||
_serialize_cache_key,
|
||||
_contains_self_unequal,
|
||||
_estimate_value_size,
|
||||
_canonicalize,
|
||||
)
|
||||
|
||||
|
||||
class TestCanonicalize:
|
||||
"""Test _canonicalize function for deterministic ordering."""
|
||||
|
||||
def test_frozenset_ordering_is_deterministic(self):
|
||||
"""Frozensets should produce consistent canonical form regardless of iteration order."""
|
||||
# Create two frozensets with same content
|
||||
fs1 = frozenset([("a", 1), ("b", 2), ("c", 3)])
|
||||
fs2 = frozenset([("c", 3), ("a", 1), ("b", 2)])
|
||||
|
||||
result1 = _canonicalize(fs1)
|
||||
result2 = _canonicalize(fs2)
|
||||
|
||||
assert result1 == result2
|
||||
|
||||
def test_nested_frozenset_ordering(self):
|
||||
"""Nested frozensets should also be deterministically ordered."""
|
||||
inner1 = frozenset([1, 2, 3])
|
||||
inner2 = frozenset([3, 2, 1])
|
||||
|
||||
fs1 = frozenset([("key", inner1)])
|
||||
fs2 = frozenset([("key", inner2)])
|
||||
|
||||
result1 = _canonicalize(fs1)
|
||||
result2 = _canonicalize(fs2)
|
||||
|
||||
assert result1 == result2
|
||||
|
||||
def test_dict_ordering(self):
|
||||
"""Dicts should be sorted by key."""
|
||||
d1 = {"z": 1, "a": 2, "m": 3}
|
||||
d2 = {"a": 2, "m": 3, "z": 1}
|
||||
|
||||
result1 = _canonicalize(d1)
|
||||
result2 = _canonicalize(d2)
|
||||
|
||||
assert result1 == result2
|
||||
|
||||
def test_tuple_preserved(self):
|
||||
"""Tuples should be marked and preserved."""
|
||||
t = (1, 2, 3)
|
||||
result = _canonicalize(t)
|
||||
|
||||
assert result[0] == "__tuple__"
|
||||
|
||||
def test_list_preserved(self):
|
||||
"""Lists should be recursively canonicalized."""
|
||||
lst = [{"b": 2, "a": 1}, frozenset([3, 2, 1])]
|
||||
result = _canonicalize(lst)
|
||||
|
||||
# First element should be canonicalized dict
|
||||
assert "__dict__" in result[0]
|
||||
# Second element should be canonicalized frozenset
|
||||
assert result[1][0] == "__frozenset__"
|
||||
|
||||
def test_primitives_include_type(self):
|
||||
"""Primitive types should include type name for disambiguation."""
|
||||
assert _canonicalize(42) == ("int", 42)
|
||||
assert _canonicalize(3.14) == ("float", 3.14)
|
||||
assert _canonicalize("hello") == ("str", "hello")
|
||||
assert _canonicalize(True) == ("bool", True)
|
||||
assert _canonicalize(None) == ("NoneType", None)
|
||||
|
||||
def test_int_and_str_distinguished(self):
|
||||
"""int 7 and str '7' must produce different canonical forms."""
|
||||
assert _canonicalize(7) != _canonicalize("7")
|
||||
|
||||
def test_bytes_converted(self):
|
||||
"""Bytes should be converted to hex string."""
|
||||
b = b"\x00\xff"
|
||||
result = _canonicalize(b)
|
||||
|
||||
assert result[0] == "__bytes__"
|
||||
assert result[1] == "00ff"
|
||||
|
||||
def test_set_ordering(self):
|
||||
"""Sets should be sorted like frozensets."""
|
||||
s1 = {3, 1, 2}
|
||||
s2 = {1, 2, 3}
|
||||
|
||||
result1 = _canonicalize(s1)
|
||||
result2 = _canonicalize(s2)
|
||||
|
||||
assert result1 == result2
|
||||
assert result1[0] == "__set__"
|
||||
|
||||
def test_unknown_type_raises(self):
|
||||
"""Unknown types should raise ValueError (fail-closed)."""
|
||||
class CustomObj:
|
||||
pass
|
||||
with pytest.raises(ValueError):
|
||||
_canonicalize(CustomObj())
|
||||
|
||||
def test_object_with_value_attr_raises(self):
|
||||
"""Objects with .value attribute (Unhashable-like) should raise ValueError."""
|
||||
class FakeUnhashable:
|
||||
def __init__(self):
|
||||
self.value = float('nan')
|
||||
with pytest.raises(ValueError):
|
||||
_canonicalize(FakeUnhashable())
|
||||
|
||||
|
||||
class TestSerializeCacheKey:
|
||||
"""Test _serialize_cache_key for deterministic hashing."""
|
||||
|
||||
def test_same_content_same_hash(self):
|
||||
"""Same content should produce same hash."""
|
||||
key1 = frozenset([("node_1", frozenset([("input", "value")]))])
|
||||
key2 = frozenset([("node_1", frozenset([("input", "value")]))])
|
||||
|
||||
hash1 = _serialize_cache_key(key1)
|
||||
hash2 = _serialize_cache_key(key2)
|
||||
|
||||
assert hash1 == hash2
|
||||
|
||||
def test_different_content_different_hash(self):
|
||||
"""Different content should produce different hash."""
|
||||
key1 = frozenset([("node_1", "value_a")])
|
||||
key2 = frozenset([("node_1", "value_b")])
|
||||
|
||||
hash1 = _serialize_cache_key(key1)
|
||||
hash2 = _serialize_cache_key(key2)
|
||||
|
||||
assert hash1 != hash2
|
||||
|
||||
def test_returns_hex_string(self):
|
||||
"""Should return hex string (SHA256 hex digest)."""
|
||||
key = frozenset([("test", 123)])
|
||||
result = _serialize_cache_key(key)
|
||||
|
||||
assert isinstance(result, str)
|
||||
assert len(result) == 64 # SHA256 hex digest is 64 chars
|
||||
|
||||
def test_complex_nested_structure(self):
|
||||
"""Complex nested structures should hash deterministically."""
|
||||
# Note: frozensets can only contain hashable types, so we use
|
||||
# nested frozensets of tuples to represent dict-like structures
|
||||
key = frozenset([
|
||||
("node_1", frozenset([
|
||||
("input_a", ("tuple", "value")),
|
||||
("input_b", frozenset([("nested", "dict")])),
|
||||
])),
|
||||
("node_2", frozenset([
|
||||
("param", 42),
|
||||
])),
|
||||
])
|
||||
|
||||
# Hash twice to verify determinism
|
||||
hash1 = _serialize_cache_key(key)
|
||||
hash2 = _serialize_cache_key(key)
|
||||
|
||||
assert hash1 == hash2
|
||||
|
||||
def test_dict_in_cache_key(self):
|
||||
"""Dicts passed directly to _serialize_cache_key should work."""
|
||||
key = {"node_1": {"input": "value"}, "node_2": 42}
|
||||
|
||||
hash1 = _serialize_cache_key(key)
|
||||
hash2 = _serialize_cache_key(key)
|
||||
|
||||
assert hash1 == hash2
|
||||
assert isinstance(hash1, str)
|
||||
assert len(hash1) == 64
|
||||
|
||||
def test_unknown_type_returns_none(self):
|
||||
"""Non-cacheable types should return None (fail-closed)."""
|
||||
class CustomObj:
|
||||
pass
|
||||
assert _serialize_cache_key(CustomObj()) is None
|
||||
|
||||
|
||||
class TestContainsSelfUnequal:
|
||||
"""Test _contains_self_unequal utility function."""
|
||||
|
||||
def test_nan_float_detected(self):
|
||||
"""NaN floats should be detected (not equal to itself)."""
|
||||
assert _contains_self_unequal(float('nan')) is True
|
||||
|
||||
def test_regular_float_not_detected(self):
|
||||
"""Regular floats are equal to themselves."""
|
||||
assert _contains_self_unequal(3.14) is False
|
||||
assert _contains_self_unequal(0.0) is False
|
||||
assert _contains_self_unequal(-1.5) is False
|
||||
|
||||
def test_infinity_not_detected(self):
|
||||
"""Infinity is equal to itself."""
|
||||
assert _contains_self_unequal(float('inf')) is False
|
||||
assert _contains_self_unequal(float('-inf')) is False
|
||||
|
||||
def test_nan_in_list(self):
|
||||
"""NaN in list should be detected."""
|
||||
assert _contains_self_unequal([1, 2, float('nan'), 4]) is True
|
||||
assert _contains_self_unequal([1, 2, 3, 4]) is False
|
||||
|
||||
def test_nan_in_tuple(self):
|
||||
"""NaN in tuple should be detected."""
|
||||
assert _contains_self_unequal((1, float('nan'))) is True
|
||||
assert _contains_self_unequal((1, 2, 3)) is False
|
||||
|
||||
def test_nan_in_frozenset(self):
|
||||
"""NaN in frozenset should be detected."""
|
||||
assert _contains_self_unequal(frozenset([1, float('nan')])) is True
|
||||
assert _contains_self_unequal(frozenset([1, 2, 3])) is False
|
||||
|
||||
def test_nan_in_dict_value(self):
|
||||
"""NaN in dict value should be detected."""
|
||||
assert _contains_self_unequal({"key": float('nan')}) is True
|
||||
assert _contains_self_unequal({"key": 42}) is False
|
||||
|
||||
def test_nan_in_nested_structure(self):
|
||||
"""NaN in deeply nested structure should be detected."""
|
||||
nested = {"level1": [{"level2": (1, 2, float('nan'))}]}
|
||||
assert _contains_self_unequal(nested) is True
|
||||
|
||||
def test_non_numeric_types(self):
|
||||
"""Non-numeric types should not be self-unequal."""
|
||||
assert _contains_self_unequal("string") is False
|
||||
assert _contains_self_unequal(None) is False
|
||||
assert _contains_self_unequal(True) is False
|
||||
|
||||
def test_object_with_nan_value_attribute(self):
|
||||
"""Objects wrapping NaN in .value should be detected."""
|
||||
class NanWrapper:
|
||||
def __init__(self):
|
||||
self.value = float('nan')
|
||||
assert _contains_self_unequal(NanWrapper()) is True
|
||||
|
||||
def test_custom_self_unequal_object(self):
|
||||
"""Custom objects where not (x == x) should be detected."""
|
||||
class NeverEqual:
|
||||
def __eq__(self, other):
|
||||
return False
|
||||
assert _contains_self_unequal(NeverEqual()) is True
|
||||
|
||||
|
||||
class TestEstimateValueSize:
|
||||
"""Test _estimate_value_size utility function."""
|
||||
|
||||
def test_empty_outputs(self):
|
||||
"""Empty outputs should have zero size."""
|
||||
value = CacheValue(outputs=[])
|
||||
assert _estimate_value_size(value) == 0
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not _torch_available(),
|
||||
reason="PyTorch not available"
|
||||
)
|
||||
def test_tensor_size_estimation(self):
|
||||
"""Tensor size should be estimated correctly."""
|
||||
import torch
|
||||
|
||||
# 1000 float32 elements = 4000 bytes
|
||||
tensor = torch.zeros(1000, dtype=torch.float32)
|
||||
value = CacheValue(outputs=[[tensor]])
|
||||
|
||||
size = _estimate_value_size(value)
|
||||
assert size == 4000
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not _torch_available(),
|
||||
reason="PyTorch not available"
|
||||
)
|
||||
def test_nested_tensor_in_dict(self):
|
||||
"""Tensors nested in dicts should be counted."""
|
||||
import torch
|
||||
|
||||
tensor = torch.zeros(100, dtype=torch.float32) # 400 bytes
|
||||
value = CacheValue(outputs=[[{"samples": tensor}]])
|
||||
|
||||
size = _estimate_value_size(value)
|
||||
assert size == 400
|
||||
|
||||
|
||||
class TestProviderRegistry:
|
||||
"""Test cache provider registration and retrieval."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Clear providers before each test."""
|
||||
_clear_cache_providers()
|
||||
|
||||
def teardown_method(self):
|
||||
"""Clear providers after each test."""
|
||||
_clear_cache_providers()
|
||||
|
||||
def test_register_provider(self):
|
||||
"""Provider should be registered successfully."""
|
||||
provider = MockCacheProvider()
|
||||
register_cache_provider(provider)
|
||||
|
||||
assert _has_cache_providers() is True
|
||||
providers = _get_cache_providers()
|
||||
assert len(providers) == 1
|
||||
assert providers[0] is provider
|
||||
|
||||
def test_unregister_provider(self):
|
||||
"""Provider should be unregistered successfully."""
|
||||
provider = MockCacheProvider()
|
||||
register_cache_provider(provider)
|
||||
unregister_cache_provider(provider)
|
||||
|
||||
assert _has_cache_providers() is False
|
||||
|
||||
def test_multiple_providers(self):
|
||||
"""Multiple providers can be registered."""
|
||||
provider1 = MockCacheProvider()
|
||||
provider2 = MockCacheProvider()
|
||||
|
||||
register_cache_provider(provider1)
|
||||
register_cache_provider(provider2)
|
||||
|
||||
providers = _get_cache_providers()
|
||||
assert len(providers) == 2
|
||||
|
||||
def test_duplicate_registration_ignored(self):
|
||||
"""Registering same provider twice should be ignored."""
|
||||
provider = MockCacheProvider()
|
||||
|
||||
register_cache_provider(provider)
|
||||
register_cache_provider(provider) # Should be ignored
|
||||
|
||||
providers = _get_cache_providers()
|
||||
assert len(providers) == 1
|
||||
|
||||
def test_clear_providers(self):
|
||||
"""_clear_cache_providers should remove all providers."""
|
||||
provider1 = MockCacheProvider()
|
||||
provider2 = MockCacheProvider()
|
||||
|
||||
register_cache_provider(provider1)
|
||||
register_cache_provider(provider2)
|
||||
_clear_cache_providers()
|
||||
|
||||
assert _has_cache_providers() is False
|
||||
assert len(_get_cache_providers()) == 0
|
||||
|
||||
|
||||
class TestCacheContext:
|
||||
"""Test CacheContext dataclass."""
|
||||
|
||||
def test_context_creation(self):
|
||||
"""CacheContext should be created with all fields."""
|
||||
context = CacheContext(
|
||||
node_id="node-456",
|
||||
class_type="KSampler",
|
||||
cache_key_hash="a" * 64,
|
||||
)
|
||||
|
||||
assert context.node_id == "node-456"
|
||||
assert context.class_type == "KSampler"
|
||||
assert context.cache_key_hash == "a" * 64
|
||||
|
||||
|
||||
class TestCacheValue:
|
||||
"""Test CacheValue dataclass."""
|
||||
|
||||
def test_value_creation(self):
|
||||
"""CacheValue should be created with outputs."""
|
||||
outputs = [[{"samples": "tensor_data"}]]
|
||||
value = CacheValue(outputs=outputs)
|
||||
|
||||
assert value.outputs == outputs
|
||||
|
||||
|
||||
class MockCacheProvider(CacheProvider):
|
||||
"""Mock cache provider for testing."""
|
||||
|
||||
def __init__(self):
|
||||
self.lookups = []
|
||||
self.stores = []
|
||||
|
||||
async def on_lookup(self, context: CacheContext) -> Optional[CacheValue]:
|
||||
self.lookups.append(context)
|
||||
return None
|
||||
|
||||
async def on_store(self, context: CacheContext, value: CacheValue) -> None:
|
||||
self.stores.append((context, value))
|
||||
Reference in New Issue
Block a user