Compare commits

..

17 Commits

Author SHA1 Message Date
xeinherjer
dff0a4a158 Fix VAEDecodeAudioTiled ignoring tile_size input (#12735) (#12738) 2026-03-02 20:17:51 -05:00
Lodestone
9ebee0a217 Feat: z-image pixel space (model still training atm) (#12709)
* draft zeta (z-image pixel space)

* revert gitignore

* model loaded and able to run however vector direction still wrong tho

* flip the vector direction to original again this time

* Move wrongly positioned Z image pixel space class

* inherit Radiance LatentFormat class

* Fix parameters in classes for Zeta x0 dino

* remove arbitrary nn.init instances

* Remove unused import of lru_cache

---------

Co-authored-by: silveroxides <ishimarukaito@gmail.com>
2026-03-02 19:43:47 -05:00
comfyanonymous
57dd6c1aad Support loading zeta chroma weights properly. (#12734) 2026-03-02 18:54:18 -05:00
ComfyUI Wiki
f1f8996e15 chore: update workflow templates to v0.9.5 (#12732) 2026-03-02 09:13:42 -08:00
Alexander Piskun
afb54219fa feat(api-nodes): allow to use "IMAGE+TEXT" in NanoBanana2 (#12729) 2026-03-01 23:24:33 -08:00
rattus
7175c11a4e comfy aimdo 0.2.4 (#12727)
Comfy Aimdo 0.2.4 fixes a VRAM buffer alignment issue that happens in
someworkflows where action is able to bypass the pytorch allocator
and go straight to the cuda hook.
2026-03-01 22:21:41 -08:00
rattus
dfbf99a061 model_mangament: make dynamic --disable-smart-memory work (#12724)
This was previously considering the pool of dynamic models as one giant
entity for the sake of smart memory, but that isnt really the useful
or what a user would reasonably expect. Make Dynamic VRAM properly purge
its models just like the old --disable-smart-memory but conditioning
the dynamic-for-dynamic bypass on smart memory.

Re-enable dynamic smart memory.
2026-03-01 19:18:56 -08:00
comfyanonymous
602f6bd82c Make --disable-smart-memory disable dynamic vram. (#12722) 2026-03-01 15:28:39 -05:00
rattus
c0d472e5b9 comfy-aimdo 0.2.3 (#12720) 2026-03-01 11:14:56 -08:00
drozbay
4d79f4f028 fix: handle substep sigmas in context window set_step (#12719)
Multi-step samplers (eg. dpmpp_2s_ancestral) call the model at intermediate sigma values not present in the schedule. This caused set_step to crash with "No sample_sigmas matched current timestep" when context windows were enabled.

The fix is to keep self._step from the last exact match when a substep sigma is encountered, since substeps are still logically part of their parent step and should use the same context windows.

Co-authored-by: ozbayb <17261091+ozbayb@users.noreply.github.com>
2026-03-01 09:38:30 -08:00
Christian Byrne
850e8b42ff feat: add text preview support to jobs API (#12169)
* feat: add text preview support to jobs API

Amp-Thread-ID: https://ampcode.com/threads/T-019c0be0-9fc6-71ac-853a-7c7cc846b375
Co-authored-by: Amp <amp@ampcode.com>

* test: update tests to expect text as previewable media type

Amp-Thread-ID: https://ampcode.com/threads/T-019c0be0-9fc6-71ac-853a-7c7cc846b375

---------
2026-02-28 21:38:19 -08:00
Christian Byrne
d159142615 refactor: rename Mahiro CFG to Similarity-Adaptive Guidance (#12172)
* refactor: rename Mahiro CFG to Similarity-Adaptive Guidance

Rename the display name to better describe what the node does:
adaptively blends guidance based on cosine similarity between
positive and negative conditions.

Amp-Thread-ID: https://ampcode.com/threads/T-019c0d36-8b43-745f-b7b2-e35b53f17fa1
Co-authored-by: Amp <amp@ampcode.com>

* feat: add search aliases for old mahiro name

Amp-Thread-ID: https://ampcode.com/threads/T-019c0d36-8b43-745f-b7b2-e35b53f17fa1

* rename: Similarity-Adaptive Guidance → Positive-Biased Guidance (per reviewer)

- display_name changed to 'Positive-Biased Guidance' to avoid SAG acronym collision
- search_aliases expanded: mahiro, mahiro cfg, similarity-adaptive guidance, positive-biased cfg
- ruff format applied

---------

Co-authored-by: Amp <amp@ampcode.com>
Co-authored-by: Jedrzej Kosinski <kosinkadink1@gmail.com>
2026-02-28 20:59:24 -08:00
comfyanonymous
1080bd442a Disable dynamic vram on wsl. (#12706) 2026-02-28 22:23:28 -05:00
comfyanonymous
17106cb124 Move parsing of requirements logic to function. (#12701) 2026-02-28 22:21:32 -05:00
rattus
48bb0bd18a cli_args: Default comfy to DynamicVram mode (#12658) 2026-02-28 16:52:30 -05:00
rattus
5f41584e96 Disable dynamic_vram when weight hooks applied (#12653)
* sd: add support for clip model reconstruction

* nodes: SetClipHooks: Demote the dynamic model patcher

* mp: Make dynamic_disable more robust

The backup need to not be cloned. In addition add a delegate object
to ModelPatcherDynamic so that non-cloning code can do
ModelPatcherDynamic demotion

* sampler_helpers: Demote to non-dynamic model patcher when hooking

* code rabbit review comments
2026-02-28 16:50:18 -05:00
Jukka Seppänen
1f6744162f feat: Support SCAIL WanVideo model (#12614) 2026-02-28 16:49:12 -05:00
65 changed files with 857 additions and 7617 deletions

View File

@@ -17,7 +17,7 @@ from importlib.metadata import version
import requests
from typing_extensions import NotRequired
from utils.install_util import get_missing_requirements_message, requirements_path
from utils.install_util import get_missing_requirements_message, get_required_packages_versions
from comfy.cli_args import DEFAULT_VERSION_STRING
import app.logger
@@ -45,25 +45,7 @@ def get_installed_frontend_version():
def get_required_frontend_version():
"""Get the required frontend version from requirements.txt."""
try:
with open(requirements_path, "r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if line.startswith("comfyui-frontend-package=="):
version_str = line.split("==")[-1]
if not is_valid_version(version_str):
logging.error(f"Invalid version format in requirements.txt: {version_str}")
return None
return version_str
logging.error("comfyui-frontend-package not found in requirements.txt")
return None
except FileNotFoundError:
logging.error("requirements.txt not found. Cannot determine required frontend version.")
return None
except Exception as e:
logging.error(f"Error reading requirements.txt: {e}")
return None
return get_required_packages_versions().get("comfyui-frontend-package", None)
def check_frontend_version():
@@ -217,25 +199,7 @@ class FrontendManager:
@classmethod
def get_required_templates_version(cls) -> str:
"""Get the required workflow templates version from requirements.txt."""
try:
with open(requirements_path, "r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if line.startswith("comfyui-workflow-templates=="):
version_str = line.split("==")[-1]
if not is_valid_version(version_str):
logging.error(f"Invalid templates version format in requirements.txt: {version_str}")
return None
return version_str
logging.error("comfyui-workflow-templates not found in requirements.txt")
return None
except FileNotFoundError:
logging.error("requirements.txt not found. Cannot determine required templates version.")
return None
except Exception as e:
logging.error(f"Error reading requirements.txt: {e}")
return None
return get_required_packages_versions().get("comfyui-workflow-templates", None)
@classmethod
def default_frontend_path(cls) -> str:

View File

@@ -146,6 +146,7 @@ parser.add_argument("--reserve-vram", type=float, default=None, help="Set the am
parser.add_argument("--async-offload", nargs='?', const=2, type=int, default=None, metavar="NUM_STREAMS", help="Use async weight offloading. An optional argument controls the amount of offload streams. Default is 2. Enabled by default on Nvidia.")
parser.add_argument("--disable-async-offload", action="store_true", help="Disable async weight offloading.")
parser.add_argument("--disable-dynamic-vram", action="store_true", help="Disable dynamic VRAM and use estimate based model loading.")
parser.add_argument("--force-non-blocking", action="store_true", help="Force ComfyUI to use non-blocking operations for all applicable tensors. This may improve performance on some non-Nvidia systems but can cause issues with some workflows.")
@@ -159,7 +160,6 @@ class PerformanceFeature(enum.Enum):
Fp8MatrixMultiplication = "fp8_matrix_mult"
CublasOps = "cublas_ops"
AutoTune = "autotune"
DynamicVRAM = "dynamic_vram"
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. This is used to test new features so using it might crash your comfyui. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature))))
@@ -179,8 +179,6 @@ parser.add_argument("--disable-api-nodes", action="store_true", help="Disable lo
parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.")
parser.add_argument("--use-process-isolation", action="store_true", help="Enable process isolation for custom nodes with pyisolate.yaml manifests.")
parser.add_argument("--verbose", default='INFO', const='DEBUG', nargs="?", choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], help='Set the logging level')
parser.add_argument("--log-stdout", action="store_true", help="Send normal process output to stdout instead of stderr (default).")
@@ -262,4 +260,4 @@ else:
args.fast = set(args.fast)
def enables_dynamic_vram():
return PerformanceFeature.DynamicVRAM in args.fast and not args.highvram and not args.gpu_only
return not args.disable_dynamic_vram and not args.highvram and not args.gpu_only and not args.novram and not args.cpu

View File

@@ -214,7 +214,7 @@ class IndexListContextHandler(ContextHandlerABC):
mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep[0], rtol=0.0001)
matches = torch.nonzero(mask)
if torch.numel(matches) == 0:
raise Exception("No sample_sigmas matched current timestep; something went wrong.")
return # substep from multi-step sampler: keep self._step from the last full step
self._step = int(matches[0].item())
def get_context_windows(self, model: BaseModel, x_in: torch.Tensor, model_options: dict[str]) -> list[IndexListContextWindow]:

View File

@@ -14,9 +14,6 @@ if TYPE_CHECKING:
import comfy.lora
import comfy.model_management
import comfy.patcher_extension
from comfy.cli_args import args
import uuid
import os
from node_helpers import conditioning_set_values
# #######################################################################################################
@@ -64,37 +61,8 @@ class EnumHookScope(enum.Enum):
HookedOnly = "hooked_only"
_ISOLATION_HOOKREF_MODE = args.use_process_isolation or os.environ.get("PYISOLATE_ISOLATION_ACTIVE") == "1"
class _HookRef:
def __init__(self):
if _ISOLATION_HOOKREF_MODE:
self._pyisolate_id = str(uuid.uuid4())
def _ensure_pyisolate_id(self):
pyisolate_id = getattr(self, "_pyisolate_id", None)
if pyisolate_id is None:
pyisolate_id = str(uuid.uuid4())
self._pyisolate_id = pyisolate_id
return pyisolate_id
def __eq__(self, other):
if not _ISOLATION_HOOKREF_MODE:
return self is other
if not isinstance(other, _HookRef):
return False
return self._ensure_pyisolate_id() == other._ensure_pyisolate_id()
def __hash__(self):
if not _ISOLATION_HOOKREF_MODE:
return id(self)
return hash(self._ensure_pyisolate_id())
def __str__(self):
if not _ISOLATION_HOOKREF_MODE:
return super().__str__()
return f"PYISOLATE_HOOKREF:{self._ensure_pyisolate_id()}"
pass
def default_should_register(hook: Hook, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
@@ -200,8 +168,6 @@ class WeightHook(Hook):
key_map = comfy.lora.model_lora_keys_clip(model.model, key_map)
else:
key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
if self.weights is None:
self.weights = {}
weights = comfy.lora.load_lora(self.weights, key_map, log_missing=False)
else:
if target == EnumWeightTarget.Clip:

View File

@@ -1,327 +0,0 @@
# pylint: disable=consider-using-from-import,cyclic-import,global-statement,global-variable-not-assigned,import-outside-toplevel,logging-fstring-interpolation
from __future__ import annotations
import asyncio
import inspect
import logging
import os
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Set, TYPE_CHECKING
import folder_paths
from .extension_loader import load_isolated_node
from .manifest_loader import find_manifest_directories
from .runtime_helpers import build_stub_class, get_class_types_for_extension
from .shm_forensics import scan_shm_forensics, start_shm_forensics
if TYPE_CHECKING:
from pyisolate import ExtensionManager
from .extension_wrapper import ComfyNodeExtension
LOG_PREFIX = "]["
isolated_node_timings: List[tuple[float, Path, int]] = []
PYISOLATE_VENV_ROOT = Path(folder_paths.base_path) / ".pyisolate_venvs"
PYISOLATE_VENV_ROOT.mkdir(parents=True, exist_ok=True)
logger = logging.getLogger(__name__)
_WORKFLOW_BOUNDARY_MIN_FREE_VRAM_BYTES = 2 * 1024 * 1024 * 1024
def initialize_proxies() -> None:
from .child_hooks import is_child_process
is_child = is_child_process()
if is_child:
from .child_hooks import initialize_child_process
initialize_child_process()
else:
from .host_hooks import initialize_host_process
initialize_host_process()
start_shm_forensics()
@dataclass(frozen=True)
class IsolatedNodeSpec:
node_name: str
display_name: str
stub_class: type
module_path: Path
_ISOLATED_NODE_SPECS: List[IsolatedNodeSpec] = []
_CLAIMED_PATHS: Set[Path] = set()
_ISOLATION_SCAN_ATTEMPTED = False
_EXTENSION_MANAGERS: List["ExtensionManager"] = []
_RUNNING_EXTENSIONS: Dict[str, "ComfyNodeExtension"] = {}
_ISOLATION_BACKGROUND_TASK: Optional["asyncio.Task[List[IsolatedNodeSpec]]"] = None
_EARLY_START_TIME: Optional[float] = None
def start_isolation_loading_early(loop: "asyncio.AbstractEventLoop") -> None:
global _ISOLATION_BACKGROUND_TASK, _EARLY_START_TIME
if _ISOLATION_BACKGROUND_TASK is not None:
return
_EARLY_START_TIME = time.perf_counter()
_ISOLATION_BACKGROUND_TASK = loop.create_task(initialize_isolation_nodes())
async def await_isolation_loading() -> List[IsolatedNodeSpec]:
global _ISOLATION_BACKGROUND_TASK, _EARLY_START_TIME
if _ISOLATION_BACKGROUND_TASK is not None:
specs = await _ISOLATION_BACKGROUND_TASK
return specs
return await initialize_isolation_nodes()
async def initialize_isolation_nodes() -> List[IsolatedNodeSpec]:
global _ISOLATED_NODE_SPECS, _ISOLATION_SCAN_ATTEMPTED, _CLAIMED_PATHS
if _ISOLATED_NODE_SPECS:
return _ISOLATED_NODE_SPECS
if _ISOLATION_SCAN_ATTEMPTED:
return []
_ISOLATION_SCAN_ATTEMPTED = True
manifest_entries = find_manifest_directories()
_CLAIMED_PATHS = {entry[0].resolve() for entry in manifest_entries}
if not manifest_entries:
return []
os.environ["PYISOLATE_ISOLATION_ACTIVE"] = "1"
concurrency_limit = max(1, (os.cpu_count() or 4) // 2)
semaphore = asyncio.Semaphore(concurrency_limit)
async def load_with_semaphore(
node_dir: Path, manifest: Path
) -> List[IsolatedNodeSpec]:
async with semaphore:
load_start = time.perf_counter()
spec_list = await load_isolated_node(
node_dir,
manifest,
logger,
lambda name, info, extension: build_stub_class(
name,
info,
extension,
_RUNNING_EXTENSIONS,
logger,
),
PYISOLATE_VENV_ROOT,
_EXTENSION_MANAGERS,
)
spec_list = [
IsolatedNodeSpec(
node_name=node_name,
display_name=display_name,
stub_class=stub_cls,
module_path=node_dir,
)
for node_name, display_name, stub_cls in spec_list
]
isolated_node_timings.append(
(time.perf_counter() - load_start, node_dir, len(spec_list))
)
return spec_list
tasks = [
load_with_semaphore(node_dir, manifest)
for node_dir, manifest in manifest_entries
]
results = await asyncio.gather(*tasks, return_exceptions=True)
specs: List[IsolatedNodeSpec] = []
for result in results:
if isinstance(result, Exception):
logger.error(
"%s Isolated node failed during startup; continuing: %s",
LOG_PREFIX,
result,
)
continue
specs.extend(result)
_ISOLATED_NODE_SPECS = specs
return list(_ISOLATED_NODE_SPECS)
def _get_class_types_for_extension(extension_name: str) -> Set[str]:
"""Get all node class types (node names) belonging to an extension."""
extension = _RUNNING_EXTENSIONS.get(extension_name)
if not extension:
return set()
ext_path = Path(extension.module_path)
class_types = set()
for spec in _ISOLATED_NODE_SPECS:
if spec.module_path.resolve() == ext_path.resolve():
class_types.add(spec.node_name)
return class_types
async def notify_execution_graph(needed_class_types: Set[str]) -> None:
"""Evict running extensions not needed for current execution."""
async def _stop_extension(
ext_name: str, extension: "ComfyNodeExtension", reason: str
) -> None:
logger.info("%s ISO:eject_start ext=%s reason=%s", LOG_PREFIX, ext_name, reason)
logger.debug("%s ISO:stop_start ext=%s", LOG_PREFIX, ext_name)
stop_result = extension.stop()
if inspect.isawaitable(stop_result):
await stop_result
_RUNNING_EXTENSIONS.pop(ext_name, None)
logger.debug("%s ISO:stop_done ext=%s", LOG_PREFIX, ext_name)
scan_shm_forensics("ISO:stop_extension", refresh_model_context=True)
scan_shm_forensics("ISO:notify_graph_start", refresh_model_context=True)
logger.debug(
"%s ISO:notify_graph_start running=%d needed=%d",
LOG_PREFIX,
len(_RUNNING_EXTENSIONS),
len(needed_class_types),
)
for ext_name, extension in list(_RUNNING_EXTENSIONS.items()):
ext_class_types = _get_class_types_for_extension(ext_name)
# If NONE of this extension's nodes are in the execution graph → evict
if not ext_class_types.intersection(needed_class_types):
await _stop_extension(
ext_name,
extension,
"isolated custom_node not in execution graph, evicting",
)
# Isolated child processes add steady VRAM pressure; reclaim host-side models
# at workflow boundaries so subsequent host nodes (e.g. CLIP encode) keep headroom.
try:
import comfy.model_management as model_management
device = model_management.get_torch_device()
if getattr(device, "type", None) == "cuda":
required = max(
model_management.minimum_inference_memory(),
_WORKFLOW_BOUNDARY_MIN_FREE_VRAM_BYTES,
)
free_before = model_management.get_free_memory(device)
if free_before < required and _RUNNING_EXTENSIONS:
for ext_name, extension in list(_RUNNING_EXTENSIONS.items()):
await _stop_extension(
ext_name,
extension,
f"boundary low-vram restart (free={int(free_before)} target={int(required)})",
)
if model_management.get_free_memory(device) < required:
model_management.unload_all_models()
model_management.cleanup_models_gc()
model_management.cleanup_models()
if model_management.get_free_memory(device) < required:
model_management.free_memory(required, device, for_dynamic=False)
model_management.soft_empty_cache()
except Exception:
logger.debug(
"%s workflow-boundary host VRAM relief failed", LOG_PREFIX, exc_info=True
)
finally:
scan_shm_forensics("ISO:notify_graph_done", refresh_model_context=True)
logger.debug(
"%s ISO:notify_graph_done running=%d", LOG_PREFIX, len(_RUNNING_EXTENSIONS)
)
async def flush_running_extensions_transport_state() -> int:
total_flushed = 0
for ext_name, extension in list(_RUNNING_EXTENSIONS.items()):
flush_fn = getattr(extension, "flush_transport_state", None)
if not callable(flush_fn):
continue
try:
flushed = await flush_fn()
if isinstance(flushed, int):
total_flushed += flushed
if flushed > 0:
logger.debug(
"%s %s workflow-end flush released=%d",
LOG_PREFIX,
ext_name,
flushed,
)
except Exception:
logger.debug(
"%s %s workflow-end flush failed", LOG_PREFIX, ext_name, exc_info=True
)
scan_shm_forensics(
"ISO:flush_running_extensions_transport_state", refresh_model_context=True
)
return total_flushed
def get_claimed_paths() -> Set[Path]:
return _CLAIMED_PATHS
def update_rpc_event_loops(loop: "asyncio.AbstractEventLoop | None" = None) -> None:
"""Update all active RPC instances with the current event loop.
This MUST be called at the start of each workflow execution to ensure
RPC calls are scheduled on the correct event loop. This handles the case
where asyncio.run() creates a new event loop for each workflow.
Args:
loop: The event loop to use. If None, uses asyncio.get_running_loop().
"""
if loop is None:
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = asyncio.get_event_loop()
update_count = 0
# Update RPCs from ExtensionManagers
for manager in _EXTENSION_MANAGERS:
if not hasattr(manager, "extensions"):
continue
for name, extension in manager.extensions.items():
if hasattr(extension, "rpc") and extension.rpc is not None:
if hasattr(extension.rpc, "update_event_loop"):
extension.rpc.update_event_loop(loop)
update_count += 1
logger.debug(f"{LOG_PREFIX}Updated loop on extension '{name}'")
# Also update RPCs from running extensions (they may have direct RPC refs)
for name, extension in _RUNNING_EXTENSIONS.items():
if hasattr(extension, "rpc") and extension.rpc is not None:
if hasattr(extension.rpc, "update_event_loop"):
extension.rpc.update_event_loop(loop)
update_count += 1
logger.debug(f"{LOG_PREFIX}Updated loop on running extension '{name}'")
if update_count > 0:
logger.debug(f"{LOG_PREFIX}Updated event loop on {update_count} RPC instances")
else:
logger.debug(
f"{LOG_PREFIX}No RPC instances found to update (managers={len(_EXTENSION_MANAGERS)}, running={len(_RUNNING_EXTENSIONS)})"
)
__all__ = [
"LOG_PREFIX",
"initialize_proxies",
"initialize_isolation_nodes",
"start_isolation_loading_early",
"await_isolation_loading",
"notify_execution_graph",
"flush_running_extensions_transport_state",
"get_claimed_paths",
"update_rpc_event_loops",
"IsolatedNodeSpec",
"get_class_types_for_extension",
]

View File

@@ -1,505 +0,0 @@
# pylint: disable=import-outside-toplevel,logging-fstring-interpolation,protected-access,raise-missing-from,useless-return,wrong-import-position
from __future__ import annotations
import logging
import os
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional
from pyisolate.interfaces import IsolationAdapter, SerializerRegistryProtocol # type: ignore[import-untyped]
from pyisolate._internal.rpc_protocol import AsyncRPC, ProxiedSingleton # type: ignore[import-untyped]
try:
from comfy.isolation.clip_proxy import CLIPProxy, CLIPRegistry
from comfy.isolation.model_patcher_proxy import (
ModelPatcherProxy,
ModelPatcherRegistry,
)
from comfy.isolation.model_sampling_proxy import (
ModelSamplingProxy,
ModelSamplingRegistry,
)
from comfy.isolation.vae_proxy import VAEProxy, VAERegistry, FirstStageModelRegistry
from comfy.isolation.proxies.folder_paths_proxy import FolderPathsProxy
from comfy.isolation.proxies.model_management_proxy import ModelManagementProxy
from comfy.isolation.proxies.prompt_server_impl import PromptServerService
from comfy.isolation.proxies.utils_proxy import UtilsProxy
from comfy.isolation.proxies.progress_proxy import ProgressProxy
except ImportError as exc: # Fail loud if Comfy environment is incomplete
raise ImportError(f"ComfyUI environment incomplete: {exc}")
logger = logging.getLogger(__name__)
# Force /dev/shm for shared memory (bwrap makes /tmp private)
import tempfile
if os.path.exists("/dev/shm"):
# Only override if not already set or if default is not /dev/shm
current_tmp = tempfile.gettempdir()
if not current_tmp.startswith("/dev/shm"):
logger.debug(
f"Configuring shared memory: Changing TMPDIR from {current_tmp} to /dev/shm"
)
os.environ["TMPDIR"] = "/dev/shm"
tempfile.tempdir = None # Clear cache to force re-evaluation
class ComfyUIAdapter(IsolationAdapter):
# ComfyUI-specific IsolationAdapter implementation
@property
def identifier(self) -> str:
return "comfyui"
def get_path_config(self, module_path: str) -> Optional[Dict[str, Any]]:
if "ComfyUI" in module_path and "custom_nodes" in module_path:
parts = module_path.split("ComfyUI")
if len(parts) > 1:
comfy_root = parts[0] + "ComfyUI"
return {
"preferred_root": comfy_root,
"additional_paths": [
os.path.join(comfy_root, "custom_nodes"),
os.path.join(comfy_root, "comfy"),
],
}
return None
def setup_child_environment(self, snapshot: Dict[str, Any]) -> None:
comfy_root = snapshot.get("preferred_root")
if not comfy_root:
return
requirements_path = Path(comfy_root) / "requirements.txt"
if requirements_path.exists():
import re
for line in requirements_path.read_text().splitlines():
line = line.strip()
if not line or line.startswith("#"):
continue
pkg_name = re.split(r"[<>=!~\[]", line)[0].strip()
if pkg_name:
logging.getLogger(pkg_name).setLevel(logging.ERROR)
def register_serializers(self, registry: SerializerRegistryProtocol) -> None:
def serialize_model_patcher(obj: Any) -> Dict[str, Any]:
# Child-side: must already have _instance_id (proxy)
if os.environ.get("PYISOLATE_CHILD") == "1":
if hasattr(obj, "_instance_id"):
return {"__type__": "ModelPatcherRef", "model_id": obj._instance_id}
raise RuntimeError(
f"ModelPatcher in child lacks _instance_id: "
f"{type(obj).__module__}.{type(obj).__name__}"
)
# Host-side: register with registry
if hasattr(obj, "_instance_id"):
return {"__type__": "ModelPatcherRef", "model_id": obj._instance_id}
model_id = ModelPatcherRegistry().register(obj)
return {"__type__": "ModelPatcherRef", "model_id": model_id}
def deserialize_model_patcher(data: Any) -> Any:
"""Deserialize ModelPatcher refs; pass through already-materialized objects."""
if isinstance(data, dict):
return ModelPatcherProxy(
data["model_id"], registry=None, manage_lifecycle=False
)
return data
def deserialize_model_patcher_ref(data: Dict[str, Any]) -> Any:
"""Context-aware ModelPatcherRef deserializer for both host and child."""
is_child = os.environ.get("PYISOLATE_CHILD") == "1"
if is_child:
return ModelPatcherProxy(
data["model_id"], registry=None, manage_lifecycle=False
)
else:
return ModelPatcherRegistry()._get_instance(data["model_id"])
# Register ModelPatcher type for serialization
registry.register(
"ModelPatcher", serialize_model_patcher, deserialize_model_patcher
)
# Register ModelPatcherProxy type (already a proxy, just return ref)
registry.register(
"ModelPatcherProxy", serialize_model_patcher, deserialize_model_patcher
)
# Register ModelPatcherRef for deserialization (context-aware: host or child)
registry.register("ModelPatcherRef", None, deserialize_model_patcher_ref)
def serialize_clip(obj: Any) -> Dict[str, Any]:
if hasattr(obj, "_instance_id"):
return {"__type__": "CLIPRef", "clip_id": obj._instance_id}
clip_id = CLIPRegistry().register(obj)
return {"__type__": "CLIPRef", "clip_id": clip_id}
def deserialize_clip(data: Any) -> Any:
if isinstance(data, dict):
return CLIPProxy(data["clip_id"], registry=None, manage_lifecycle=False)
return data
def deserialize_clip_ref(data: Dict[str, Any]) -> Any:
"""Context-aware CLIPRef deserializer for both host and child."""
is_child = os.environ.get("PYISOLATE_CHILD") == "1"
if is_child:
return CLIPProxy(data["clip_id"], registry=None, manage_lifecycle=False)
else:
return CLIPRegistry()._get_instance(data["clip_id"])
# Register CLIP type for serialization
registry.register("CLIP", serialize_clip, deserialize_clip)
# Register CLIPProxy type (already a proxy, just return ref)
registry.register("CLIPProxy", serialize_clip, deserialize_clip)
# Register CLIPRef for deserialization (context-aware: host or child)
registry.register("CLIPRef", None, deserialize_clip_ref)
def serialize_vae(obj: Any) -> Dict[str, Any]:
if hasattr(obj, "_instance_id"):
return {"__type__": "VAERef", "vae_id": obj._instance_id}
vae_id = VAERegistry().register(obj)
return {"__type__": "VAERef", "vae_id": vae_id}
def deserialize_vae(data: Any) -> Any:
if isinstance(data, dict):
return VAEProxy(data["vae_id"])
return data
def deserialize_vae_ref(data: Dict[str, Any]) -> Any:
"""Context-aware VAERef deserializer for both host and child."""
is_child = os.environ.get("PYISOLATE_CHILD") == "1"
if is_child:
# Child: create a proxy
return VAEProxy(data["vae_id"])
else:
# Host: lookup real VAE from registry
return VAERegistry()._get_instance(data["vae_id"])
# Register VAE type for serialization
registry.register("VAE", serialize_vae, deserialize_vae)
# Register VAEProxy type (already a proxy, just return ref)
registry.register("VAEProxy", serialize_vae, deserialize_vae)
# Register VAERef for deserialization (context-aware: host or child)
registry.register("VAERef", None, deserialize_vae_ref)
# ModelSampling serialization - handles ModelSampling* types
# copyreg removed - no pickle fallback allowed
def serialize_model_sampling(obj: Any) -> Dict[str, Any]:
# Child-side: must already have _instance_id (proxy)
if os.environ.get("PYISOLATE_CHILD") == "1":
if hasattr(obj, "_instance_id"):
return {"__type__": "ModelSamplingRef", "ms_id": obj._instance_id}
raise RuntimeError(
f"ModelSampling in child lacks _instance_id: "
f"{type(obj).__module__}.{type(obj).__name__}"
)
# Host-side: register with ModelSamplingRegistry and return JSON-safe dict
ms_id = ModelSamplingRegistry().register(obj)
return {"__type__": "ModelSamplingRef", "ms_id": ms_id}
def deserialize_model_sampling(data: Any) -> Any:
"""Deserialize ModelSampling refs; pass through already-materialized objects."""
if isinstance(data, dict):
return ModelSamplingProxy(data["ms_id"])
return data
def deserialize_model_sampling_ref(data: Dict[str, Any]) -> Any:
"""Context-aware ModelSamplingRef deserializer for both host and child."""
is_child = os.environ.get("PYISOLATE_CHILD") == "1"
if is_child:
return ModelSamplingProxy(data["ms_id"])
else:
return ModelSamplingRegistry()._get_instance(data["ms_id"])
# Register ModelSampling type and proxy
registry.register(
"ModelSamplingDiscrete",
serialize_model_sampling,
deserialize_model_sampling,
)
registry.register(
"ModelSamplingContinuousEDM",
serialize_model_sampling,
deserialize_model_sampling,
)
registry.register(
"ModelSamplingContinuousV",
serialize_model_sampling,
deserialize_model_sampling,
)
registry.register(
"ModelSamplingProxy", serialize_model_sampling, deserialize_model_sampling
)
# Register ModelSamplingRef for deserialization (context-aware: host or child)
registry.register("ModelSamplingRef", None, deserialize_model_sampling_ref)
def serialize_cond(obj: Any) -> Dict[str, Any]:
type_key = f"{type(obj).__module__}.{type(obj).__name__}"
return {
"__type__": type_key,
"cond": obj.cond,
}
def deserialize_cond(data: Dict[str, Any]) -> Any:
import importlib
type_key = data["__type__"]
module_name, class_name = type_key.rsplit(".", 1)
module = importlib.import_module(module_name)
cls = getattr(module, class_name)
return cls(data["cond"])
def _serialize_public_state(obj: Any) -> Dict[str, Any]:
state: Dict[str, Any] = {}
for key, value in obj.__dict__.items():
if key.startswith("_"):
continue
if callable(value):
continue
state[key] = value
return state
def serialize_latent_format(obj: Any) -> Dict[str, Any]:
type_key = f"{type(obj).__module__}.{type(obj).__name__}"
return {
"__type__": type_key,
"state": _serialize_public_state(obj),
}
def deserialize_latent_format(data: Dict[str, Any]) -> Any:
import importlib
type_key = data["__type__"]
module_name, class_name = type_key.rsplit(".", 1)
module = importlib.import_module(module_name)
cls = getattr(module, class_name)
obj = cls()
for key, value in data.get("state", {}).items():
prop = getattr(type(obj), key, None)
if isinstance(prop, property) and prop.fset is None:
continue
setattr(obj, key, value)
return obj
import comfy.conds
for cond_cls in vars(comfy.conds).values():
if not isinstance(cond_cls, type):
continue
if not issubclass(cond_cls, comfy.conds.CONDRegular):
continue
type_key = f"{cond_cls.__module__}.{cond_cls.__name__}"
registry.register(type_key, serialize_cond, deserialize_cond)
registry.register(cond_cls.__name__, serialize_cond, deserialize_cond)
import comfy.latent_formats
for latent_cls in vars(comfy.latent_formats).values():
if not isinstance(latent_cls, type):
continue
if not issubclass(latent_cls, comfy.latent_formats.LatentFormat):
continue
type_key = f"{latent_cls.__module__}.{latent_cls.__name__}"
registry.register(
type_key, serialize_latent_format, deserialize_latent_format
)
registry.register(
latent_cls.__name__, serialize_latent_format, deserialize_latent_format
)
# V3 API: unwrap NodeOutput.args
def deserialize_node_output(data: Any) -> Any:
return getattr(data, "args", data)
registry.register("NodeOutput", None, deserialize_node_output)
# KSAMPLER serializer: stores sampler name instead of function object
# sampler_function is a callable which gets filtered out by JSONSocketTransport
def serialize_ksampler(obj: Any) -> Dict[str, Any]:
func_name = obj.sampler_function.__name__
# Map function name back to sampler name
if func_name == "sample_unipc":
sampler_name = "uni_pc"
elif func_name == "sample_unipc_bh2":
sampler_name = "uni_pc_bh2"
elif func_name == "dpm_fast_function":
sampler_name = "dpm_fast"
elif func_name == "dpm_adaptive_function":
sampler_name = "dpm_adaptive"
elif func_name.startswith("sample_"):
sampler_name = func_name[7:] # Remove "sample_" prefix
else:
sampler_name = func_name
return {
"__type__": "KSAMPLER",
"sampler_name": sampler_name,
"extra_options": obj.extra_options,
"inpaint_options": obj.inpaint_options,
}
def deserialize_ksampler(data: Dict[str, Any]) -> Any:
import comfy.samplers
return comfy.samplers.ksampler(
data["sampler_name"],
data.get("extra_options", {}),
data.get("inpaint_options", {}),
)
registry.register("KSAMPLER", serialize_ksampler, deserialize_ksampler)
from comfy.isolation.model_patcher_proxy_utils import register_hooks_serializers
register_hooks_serializers(registry)
# Generic Numpy Serializer
def serialize_numpy(obj: Any) -> Any:
import torch
try:
# Attempt zero-copy conversion to Tensor
return torch.from_numpy(obj)
except Exception:
# Fallback for non-numeric arrays (strings, objects, mixes)
return obj.tolist()
registry.register("ndarray", serialize_numpy, None)
def provide_rpc_services(self) -> List[type[ProxiedSingleton]]:
return [
PromptServerService,
FolderPathsProxy,
ModelManagementProxy,
UtilsProxy,
ProgressProxy,
VAERegistry,
CLIPRegistry,
ModelPatcherRegistry,
ModelSamplingRegistry,
FirstStageModelRegistry,
]
def handle_api_registration(self, api: ProxiedSingleton, rpc: AsyncRPC) -> None:
# Resolve the real name whether it's an instance or the Singleton class itself
api_name = api.__name__ if isinstance(api, type) else api.__class__.__name__
if api_name == "FolderPathsProxy":
import folder_paths
# Replace module-level functions with proxy methods
# This is aggressive but necessary for transparent proxying
# Handle both instance and class cases
instance = api() if isinstance(api, type) else api
for name in dir(instance):
if not name.startswith("_"):
setattr(folder_paths, name, getattr(instance, name))
return
if api_name == "ModelManagementProxy":
import comfy.model_management
instance = api() if isinstance(api, type) else api
# Replace module-level functions with proxy methods
for name in dir(instance):
if not name.startswith("_"):
setattr(comfy.model_management, name, getattr(instance, name))
return
if api_name == "UtilsProxy":
import comfy.utils
# Static Injection of RPC mechanism to ensure Child can access it
# independent of instance lifecycle.
api.set_rpc(rpc)
# Don't overwrite host hook (infinite recursion)
return
if api_name == "PromptServerProxy":
# Defer heavy import to child context
import server
instance = api() if isinstance(api, type) else api
proxy = (
instance.instance
) # PromptServerProxy instance has .instance property returning self
original_register_route = proxy.register_route
def register_route_wrapper(
method: str, path: str, handler: Callable[..., Any]
) -> None:
callback_id = rpc.register_callback(handler)
loop = getattr(rpc, "loop", None)
if loop and loop.is_running():
import asyncio
asyncio.create_task(
original_register_route(
method, path, handler=callback_id, is_callback=True
)
)
else:
original_register_route(
method, path, handler=callback_id, is_callback=True
)
return None
proxy.register_route = register_route_wrapper
class RouteTableDefProxy:
def __init__(self, proxy_instance: Any):
self.proxy = proxy_instance
def get(
self, path: str, **kwargs: Any
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
def decorator(handler: Callable[..., Any]) -> Callable[..., Any]:
self.proxy.register_route("GET", path, handler)
return handler
return decorator
def post(
self, path: str, **kwargs: Any
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
def decorator(handler: Callable[..., Any]) -> Callable[..., Any]:
self.proxy.register_route("POST", path, handler)
return handler
return decorator
def patch(
self, path: str, **kwargs: Any
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
def decorator(handler: Callable[..., Any]) -> Callable[..., Any]:
self.proxy.register_route("PATCH", path, handler)
return handler
return decorator
def put(
self, path: str, **kwargs: Any
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
def decorator(handler: Callable[..., Any]) -> Callable[..., Any]:
self.proxy.register_route("PUT", path, handler)
return handler
return decorator
def delete(
self, path: str, **kwargs: Any
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
def decorator(handler: Callable[..., Any]) -> Callable[..., Any]:
self.proxy.register_route("DELETE", path, handler)
return handler
return decorator
proxy.routes = RouteTableDefProxy(proxy)
if (
hasattr(server, "PromptServer")
and getattr(server.PromptServer, "instance", None) != proxy
):
server.PromptServer.instance = proxy

View File

@@ -1,141 +0,0 @@
# pylint: disable=import-outside-toplevel,logging-fstring-interpolation
# Child process initialization for PyIsolate
import logging
import os
logger = logging.getLogger(__name__)
def is_child_process() -> bool:
return os.environ.get("PYISOLATE_CHILD") == "1"
def initialize_child_process() -> None:
# Manual RPC injection
try:
from pyisolate._internal.rpc_protocol import get_child_rpc_instance
rpc = get_child_rpc_instance()
if rpc:
_setup_prompt_server_stub(rpc)
_setup_utils_proxy(rpc)
else:
logger.warning("Could not get child RPC instance for manual injection")
_setup_prompt_server_stub()
_setup_utils_proxy()
except Exception as e:
logger.error(f"Manual RPC Injection failed: {e}")
_setup_prompt_server_stub()
_setup_utils_proxy()
_setup_logging()
def _setup_prompt_server_stub(rpc=None) -> None:
try:
from .proxies.prompt_server_impl import PromptServerStub
import sys
import types
# Mock server module
if "server" not in sys.modules:
mock_server = types.ModuleType("server")
sys.modules["server"] = mock_server
server = sys.modules["server"]
if not hasattr(server, "PromptServer"):
class MockPromptServer:
pass
server.PromptServer = MockPromptServer
stub = PromptServerStub()
if rpc:
PromptServerStub.set_rpc(rpc)
if hasattr(stub, "set_rpc"):
stub.set_rpc(rpc)
server.PromptServer.instance = stub
except Exception as e:
logger.error(f"Failed to setup PromptServerStub: {e}")
def _setup_utils_proxy(rpc=None) -> None:
try:
import comfy.utils
import asyncio
# Capture main loop during initialization (safe context)
main_loop = None
try:
main_loop = asyncio.get_running_loop()
except RuntimeError:
try:
main_loop = asyncio.get_event_loop()
except RuntimeError:
pass
try:
from .proxies.base import set_global_loop
if main_loop:
set_global_loop(main_loop)
except ImportError:
pass
# Sync hook wrapper for progress updates
def sync_hook_wrapper(
value: int, total: int, preview: None = None, node_id: None = None
) -> None:
if node_id is None:
try:
from comfy_execution.utils import get_executing_context
ctx = get_executing_context()
if ctx:
node_id = ctx.node_id
else:
pass
except Exception:
pass
# Bypass blocked event loop by direct outbox injection
if rpc:
try:
# Use captured main loop if available (for threaded execution), or current loop
loop = main_loop
if loop is None:
loop = asyncio.get_event_loop()
rpc.outbox.put(
{
"kind": "call",
"object_id": "UtilsProxy",
"parent_call_id": None, # We are root here usually
"calling_loop": loop,
"future": loop.create_future(), # Dummy future
"method": "progress_bar_hook",
"args": (value, total, preview, node_id),
"kwargs": {},
}
)
except Exception as e:
logging.getLogger(__name__).error(f"Manual Inject Failed: {e}")
else:
logging.getLogger(__name__).warning(
"No RPC instance available for progress update"
)
comfy.utils.PROGRESS_BAR_HOOK = sync_hook_wrapper
except Exception as e:
logger.error(f"Failed to setup UtilsProxy hook: {e}")
def _setup_logging() -> None:
logging.getLogger().setLevel(logging.INFO)

View File

@@ -1,327 +0,0 @@
# pylint: disable=attribute-defined-outside-init,import-outside-toplevel,logging-fstring-interpolation
# CLIP Proxy implementation
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Any, Optional
from comfy.isolation.proxies.base import (
IS_CHILD_PROCESS,
BaseProxy,
BaseRegistry,
detach_if_grad,
)
if TYPE_CHECKING:
from comfy.isolation.model_patcher_proxy import ModelPatcherProxy
class CondStageModelRegistry(BaseRegistry[Any]):
_type_prefix = "cond_stage_model"
async def get_property(self, instance_id: str, name: str) -> Any:
obj = self._get_instance(instance_id)
return getattr(obj, name)
class CondStageModelProxy(BaseProxy[CondStageModelRegistry]):
_registry_class = CondStageModelRegistry
__module__ = "comfy.sd"
def __getattr__(self, name: str) -> Any:
try:
return self._call_rpc("get_property", name)
except Exception as e:
raise AttributeError(
f"'{self.__class__.__name__}' object has no attribute '{name}'"
) from e
def __repr__(self) -> str:
return f"<CondStageModelProxy {self._instance_id}>"
class TokenizerRegistry(BaseRegistry[Any]):
_type_prefix = "tokenizer"
async def get_property(self, instance_id: str, name: str) -> Any:
obj = self._get_instance(instance_id)
return getattr(obj, name)
class TokenizerProxy(BaseProxy[TokenizerRegistry]):
_registry_class = TokenizerRegistry
__module__ = "comfy.sd"
def __getattr__(self, name: str) -> Any:
try:
return self._call_rpc("get_property", name)
except Exception as e:
raise AttributeError(
f"'{self.__class__.__name__}' object has no attribute '{name}'"
) from e
def __repr__(self) -> str:
return f"<TokenizerProxy {self._instance_id}>"
logger = logging.getLogger(__name__)
class CLIPRegistry(BaseRegistry[Any]):
_type_prefix = "clip"
_allowed_setters = {
"layer_idx",
"tokenizer_options",
"use_clip_schedule",
"apply_hooks_to_conds",
}
async def get_ram_usage(self, instance_id: str) -> int:
return self._get_instance(instance_id).get_ram_usage()
async def get_patcher_id(self, instance_id: str) -> str:
from comfy.isolation.model_patcher_proxy import ModelPatcherRegistry
return ModelPatcherRegistry().register(self._get_instance(instance_id).patcher)
async def get_cond_stage_model_id(self, instance_id: str) -> str:
return CondStageModelRegistry().register(
self._get_instance(instance_id).cond_stage_model
)
async def get_tokenizer_id(self, instance_id: str) -> str:
return TokenizerRegistry().register(self._get_instance(instance_id).tokenizer)
async def load_model(self, instance_id: str) -> None:
self._get_instance(instance_id).load_model()
async def clip_layer(self, instance_id: str, layer_idx: int) -> None:
self._get_instance(instance_id).clip_layer(layer_idx)
async def set_tokenizer_option(
self, instance_id: str, option_name: str, value: Any
) -> None:
self._get_instance(instance_id).set_tokenizer_option(option_name, value)
async def get_property(self, instance_id: str, name: str) -> Any:
return getattr(self._get_instance(instance_id), name)
async def set_property(self, instance_id: str, name: str, value: Any) -> None:
if name not in self._allowed_setters:
raise PermissionError(f"Setting '{name}' is not allowed via RPC")
setattr(self._get_instance(instance_id), name, value)
async def tokenize(
self, instance_id: str, text: str, return_word_ids: bool = False, **kwargs: Any
) -> Any:
return self._get_instance(instance_id).tokenize(
text, return_word_ids=return_word_ids, **kwargs
)
async def encode(self, instance_id: str, text: str) -> Any:
return detach_if_grad(self._get_instance(instance_id).encode(text))
async def encode_from_tokens(
self,
instance_id: str,
tokens: Any,
return_pooled: bool = False,
return_dict: bool = False,
) -> Any:
return detach_if_grad(
self._get_instance(instance_id).encode_from_tokens(
tokens, return_pooled=return_pooled, return_dict=return_dict
)
)
async def encode_from_tokens_scheduled(
self,
instance_id: str,
tokens: Any,
unprojected: bool = False,
add_dict: Optional[dict] = None,
show_pbar: bool = True,
) -> Any:
add_dict = add_dict or {}
return detach_if_grad(
self._get_instance(instance_id).encode_from_tokens_scheduled(
tokens, unprojected=unprojected, add_dict=add_dict, show_pbar=show_pbar
)
)
async def add_patches(
self,
instance_id: str,
patches: Any,
strength_patch: float = 1.0,
strength_model: float = 1.0,
) -> Any:
return self._get_instance(instance_id).add_patches(
patches, strength_patch=strength_patch, strength_model=strength_model
)
async def get_key_patches(self, instance_id: str) -> Any:
return self._get_instance(instance_id).get_key_patches()
async def load_sd(
self, instance_id: str, sd: dict, full_model: bool = False
) -> Any:
return self._get_instance(instance_id).load_sd(sd, full_model=full_model)
async def get_sd(self, instance_id: str) -> Any:
return self._get_instance(instance_id).get_sd()
async def clone(self, instance_id: str) -> str:
return self.register(self._get_instance(instance_id).clone())
class CLIPProxy(BaseProxy[CLIPRegistry]):
_registry_class = CLIPRegistry
__module__ = "comfy.sd"
def get_ram_usage(self) -> int:
return self._call_rpc("get_ram_usage")
@property
def patcher(self) -> "ModelPatcherProxy":
from comfy.isolation.model_patcher_proxy import ModelPatcherProxy
if not hasattr(self, "_patcher_proxy"):
patcher_id = self._call_rpc("get_patcher_id")
self._patcher_proxy = ModelPatcherProxy(patcher_id, manage_lifecycle=False)
return self._patcher_proxy
@patcher.setter
def patcher(self, value: Any) -> None:
from comfy.isolation.model_patcher_proxy import ModelPatcherProxy
if isinstance(value, ModelPatcherProxy):
self._patcher_proxy = value
else:
logger.warning(
f"Attempted to set CLIPProxy.patcher to non-proxy object: {value}"
)
@property
def cond_stage_model(self) -> CondStageModelProxy:
if not hasattr(self, "_cond_stage_model_proxy"):
csm_id = self._call_rpc("get_cond_stage_model_id")
self._cond_stage_model_proxy = CondStageModelProxy(
csm_id, manage_lifecycle=False
)
return self._cond_stage_model_proxy
@property
def tokenizer(self) -> TokenizerProxy:
if not hasattr(self, "_tokenizer_proxy"):
tok_id = self._call_rpc("get_tokenizer_id")
self._tokenizer_proxy = TokenizerProxy(tok_id, manage_lifecycle=False)
return self._tokenizer_proxy
def load_model(self) -> ModelPatcherProxy:
self._call_rpc("load_model")
return self.patcher
@property
def layer_idx(self) -> Optional[int]:
return self._call_rpc("get_property", "layer_idx")
@layer_idx.setter
def layer_idx(self, value: Optional[int]) -> None:
self._call_rpc("set_property", "layer_idx", value)
@property
def tokenizer_options(self) -> dict:
return self._call_rpc("get_property", "tokenizer_options")
@tokenizer_options.setter
def tokenizer_options(self, value: dict) -> None:
self._call_rpc("set_property", "tokenizer_options", value)
@property
def use_clip_schedule(self) -> bool:
return self._call_rpc("get_property", "use_clip_schedule")
@use_clip_schedule.setter
def use_clip_schedule(self, value: bool) -> None:
self._call_rpc("set_property", "use_clip_schedule", value)
@property
def apply_hooks_to_conds(self) -> Any:
return self._call_rpc("get_property", "apply_hooks_to_conds")
@apply_hooks_to_conds.setter
def apply_hooks_to_conds(self, value: Any) -> None:
self._call_rpc("set_property", "apply_hooks_to_conds", value)
def clip_layer(self, layer_idx: int) -> None:
return self._call_rpc("clip_layer", layer_idx)
def set_tokenizer_option(self, option_name: str, value: Any) -> None:
return self._call_rpc("set_tokenizer_option", option_name, value)
def tokenize(self, text: str, return_word_ids: bool = False, **kwargs: Any) -> Any:
return self._call_rpc(
"tokenize", text, return_word_ids=return_word_ids, **kwargs
)
def encode(self, text: str) -> Any:
return self._call_rpc("encode", text)
def encode_from_tokens(
self, tokens: Any, return_pooled: bool = False, return_dict: bool = False
) -> Any:
res = self._call_rpc(
"encode_from_tokens",
tokens,
return_pooled=return_pooled,
return_dict=return_dict,
)
if return_pooled and isinstance(res, list) and not return_dict:
return tuple(res)
return res
def encode_from_tokens_scheduled(
self,
tokens: Any,
unprojected: bool = False,
add_dict: Optional[dict] = None,
show_pbar: bool = True,
) -> Any:
add_dict = add_dict or {}
return self._call_rpc(
"encode_from_tokens_scheduled",
tokens,
unprojected=unprojected,
add_dict=add_dict,
show_pbar=show_pbar,
)
def add_patches(
self, patches: Any, strength_patch: float = 1.0, strength_model: float = 1.0
) -> Any:
return self._call_rpc(
"add_patches",
patches,
strength_patch=strength_patch,
strength_model=strength_model,
)
def get_key_patches(self) -> Any:
return self._call_rpc("get_key_patches")
def load_sd(self, sd: dict, full_model: bool = False) -> Any:
return self._call_rpc("load_sd", sd, full_model=full_model)
def get_sd(self) -> Any:
return self._call_rpc("get_sd")
def clone(self) -> CLIPProxy:
new_id = self._call_rpc("clone")
return CLIPProxy(new_id, self._registry, manage_lifecycle=not IS_CHILD_PROCESS)
if not IS_CHILD_PROCESS:
_CLIP_REGISTRY_SINGLETON = CLIPRegistry()
_COND_STAGE_MODEL_REGISTRY_SINGLETON = CondStageModelRegistry()
_TOKENIZER_REGISTRY_SINGLETON = TokenizerRegistry()

View File

@@ -1,248 +0,0 @@
# pylint: disable=cyclic-import,import-outside-toplevel,redefined-outer-name
from __future__ import annotations
import logging
import os
import inspect
import sys
import types
import platform
from pathlib import Path
from typing import Callable, Dict, List, Tuple
import pyisolate
from pyisolate import ExtensionManager, ExtensionManagerConfig
from .extension_wrapper import ComfyNodeExtension
from .manifest_loader import is_cache_valid, load_from_cache, save_to_cache
from .host_policy import load_host_policy
try:
import tomllib
except ImportError:
import tomli as tomllib # type: ignore[no-redef]
logger = logging.getLogger(__name__)
async def _stop_extension_safe(
extension: ComfyNodeExtension, extension_name: str
) -> None:
try:
stop_result = extension.stop()
if inspect.isawaitable(stop_result):
await stop_result
except Exception:
logger.debug("][ %s stop failed", extension_name, exc_info=True)
def _normalize_dependency_spec(dep: str, base_paths: list[Path]) -> str:
req, sep, marker = dep.partition(";")
req = req.strip()
marker_suffix = f";{marker}" if sep else ""
def _resolve_local_path(local_path: str) -> Path | None:
for base in base_paths:
candidate = (base / local_path).resolve()
if candidate.exists():
return candidate
return None
if req.startswith("./") or req.startswith("../"):
resolved = _resolve_local_path(req)
if resolved is not None:
return f"{resolved}{marker_suffix}"
if req.startswith("file://"):
raw = req[len("file://") :]
if raw.startswith("./") or raw.startswith("../"):
resolved = _resolve_local_path(raw)
if resolved is not None:
return f"file://{resolved}{marker_suffix}"
return dep
def get_enforcement_policy() -> Dict[str, bool]:
return {
"force_isolated": os.environ.get("PYISOLATE_ENFORCE_ISOLATED") == "1",
"force_sandbox": os.environ.get("PYISOLATE_ENFORCE_SANDBOX") == "1",
}
class ExtensionLoadError(RuntimeError):
pass
def register_dummy_module(extension_name: str, node_dir: Path) -> None:
normalized_name = extension_name.replace("-", "_").replace(".", "_")
if normalized_name not in sys.modules:
dummy_module = types.ModuleType(normalized_name)
dummy_module.__file__ = str(node_dir / "__init__.py")
dummy_module.__path__ = [str(node_dir)]
dummy_module.__package__ = normalized_name
sys.modules[normalized_name] = dummy_module
def _is_stale_node_cache(cached_data: Dict[str, Dict]) -> bool:
for details in cached_data.values():
if not isinstance(details, dict):
return True
if details.get("is_v3") and "schema_v1" not in details:
return True
return False
async def load_isolated_node(
node_dir: Path,
manifest_path: Path,
logger: logging.Logger,
build_stub_class: Callable[[str, Dict[str, object], ComfyNodeExtension], type],
venv_root: Path,
extension_managers: List[ExtensionManager],
) -> List[Tuple[str, str, type]]:
try:
with manifest_path.open("rb") as handle:
manifest_data = tomllib.load(handle)
except Exception as e:
logger.warning(f"][ Failed to parse {manifest_path}: {e}")
return []
# Parse [tool.comfy.isolation]
tool_config = manifest_data.get("tool", {}).get("comfy", {}).get("isolation", {})
can_isolate = tool_config.get("can_isolate", False)
share_torch = tool_config.get("share_torch", False)
# Parse [project] dependencies
project_config = manifest_data.get("project", {})
dependencies = project_config.get("dependencies", [])
if not isinstance(dependencies, list):
dependencies = []
# Get extension name (default to folder name if not in project.name)
extension_name = project_config.get("name", node_dir.name)
# LOGIC: Isolation Decision
policy = get_enforcement_policy()
isolated = can_isolate or policy["force_isolated"]
if not isolated:
return []
logger.info(f"][ Loading isolated node: {extension_name}")
import folder_paths
base_paths = [Path(folder_paths.base_path), node_dir]
dependencies = [
_normalize_dependency_spec(dep, base_paths) if isinstance(dep, str) else dep
for dep in dependencies
]
manager_config = ExtensionManagerConfig(venv_root_path=str(venv_root))
manager: ExtensionManager = pyisolate.ExtensionManager(
ComfyNodeExtension, manager_config
)
extension_managers.append(manager)
host_policy = load_host_policy(Path(folder_paths.base_path))
sandbox_config = {}
is_linux = platform.system() == "Linux"
if is_linux and isolated:
sandbox_config = {
"network": host_policy["allow_network"],
"writable_paths": host_policy["writable_paths"],
"readonly_paths": host_policy["readonly_paths"],
}
share_cuda_ipc = share_torch and is_linux
extension_config = {
"name": extension_name,
"module_path": str(node_dir),
"isolated": True,
"dependencies": dependencies,
"share_torch": share_torch,
"share_cuda_ipc": share_cuda_ipc,
"sandbox": sandbox_config,
}
extension = manager.load_extension(extension_config)
register_dummy_module(extension_name, node_dir)
# Try cache first (lazy spawn)
if is_cache_valid(node_dir, manifest_path, venv_root):
cached_data = load_from_cache(node_dir, venv_root)
if cached_data:
if _is_stale_node_cache(cached_data):
logger.debug(
"][ %s cache is stale/incompatible; rebuilding metadata",
extension_name,
)
else:
logger.debug(f"][ {extension_name} loaded from cache")
specs: List[Tuple[str, str, type]] = []
for node_name, details in cached_data.items():
stub_cls = build_stub_class(node_name, details, extension)
specs.append(
(node_name, details.get("display_name", node_name), stub_cls)
)
return specs
# Cache miss - spawn process and get metadata
logger.debug(f"][ {extension_name} cache miss, spawning process for metadata")
try:
remote_nodes: Dict[str, str] = await extension.list_nodes()
except Exception as exc:
logger.warning(
"][ %s metadata discovery failed, skipping isolated load: %s",
extension_name,
exc,
)
await _stop_extension_safe(extension, extension_name)
return []
if not remote_nodes:
logger.debug("][ %s exposed no isolated nodes; skipping", extension_name)
await _stop_extension_safe(extension, extension_name)
return []
specs: List[Tuple[str, str, type]] = []
cache_data: Dict[str, Dict] = {}
for node_name, display_name in remote_nodes.items():
try:
details = await extension.get_node_details(node_name)
except Exception as exc:
logger.warning(
"][ %s failed to load metadata for %s, skipping node: %s",
extension_name,
node_name,
exc,
)
continue
details["display_name"] = display_name
cache_data[node_name] = details
stub_cls = build_stub_class(node_name, details, extension)
specs.append((node_name, display_name, stub_cls))
if not specs:
logger.warning(
"][ %s produced no usable nodes after metadata scan; skipping",
extension_name,
)
await _stop_extension_safe(extension, extension_name)
return []
# Save metadata to cache for future runs
save_to_cache(node_dir, venv_root, cache_data, manifest_path)
logger.debug(f"][ {extension_name} metadata cached")
# EJECT: Kill process after getting metadata (will respawn on first execution)
await _stop_extension_safe(extension, extension_name)
return specs
__all__ = ["ExtensionLoadError", "register_dummy_module", "load_isolated_node"]

View File

@@ -1,673 +0,0 @@
# pylint: disable=consider-using-from-import,cyclic-import,import-outside-toplevel,logging-fstring-interpolation,protected-access,wrong-import-position
from __future__ import annotations
import asyncio
import torch
class AttrDict(dict):
def __getattr__(self, item):
try:
return self[item]
except KeyError as e:
raise AttributeError(item) from e
def copy(self):
return AttrDict(super().copy())
import importlib
import inspect
import json
import logging
import os
import sys
import uuid
from dataclasses import asdict
from typing import Any, Dict, List, Tuple
from pyisolate import ExtensionBase
from comfy_api.internal import _ComfyNodeInternal
LOG_PREFIX = "]["
V3_DISCOVERY_TIMEOUT = 30
_PRE_EXEC_MIN_FREE_VRAM_BYTES = 2 * 1024 * 1024 * 1024
logger = logging.getLogger(__name__)
def _flush_tensor_transport_state(marker: str) -> int:
try:
from pyisolate import flush_tensor_keeper # type: ignore[attr-defined]
except Exception:
return 0
if not callable(flush_tensor_keeper):
return 0
flushed = flush_tensor_keeper()
if flushed > 0:
logger.debug(
"%s %s flush_tensor_keeper released=%d", LOG_PREFIX, marker, flushed
)
return flushed
def _relieve_child_vram_pressure(marker: str) -> None:
import comfy.model_management as model_management
model_management.cleanup_models_gc()
model_management.cleanup_models()
device = model_management.get_torch_device()
if not hasattr(device, "type") or device.type == "cpu":
return
required = max(
model_management.minimum_inference_memory(),
_PRE_EXEC_MIN_FREE_VRAM_BYTES,
)
if model_management.get_free_memory(device) < required:
model_management.free_memory(required, device, for_dynamic=True)
if model_management.get_free_memory(device) < required:
model_management.free_memory(required, device, for_dynamic=False)
model_management.cleanup_models()
model_management.soft_empty_cache()
logger.debug("%s %s free_memory target=%d", LOG_PREFIX, marker, required)
def _sanitize_for_transport(value):
primitives = (str, int, float, bool, type(None))
if isinstance(value, primitives):
return value
cls_name = value.__class__.__name__
if cls_name == "FlexibleOptionalInputType":
return {
"__pyisolate_flexible_optional__": True,
"type": _sanitize_for_transport(getattr(value, "type", "*")),
}
if cls_name == "AnyType":
return {"__pyisolate_any_type__": True, "value": str(value)}
if cls_name == "ByPassTypeTuple":
return {
"__pyisolate_bypass_tuple__": [
_sanitize_for_transport(v) for v in tuple(value)
]
}
if isinstance(value, dict):
return {k: _sanitize_for_transport(v) for k, v in value.items()}
if isinstance(value, tuple):
return {"__pyisolate_tuple__": [_sanitize_for_transport(v) for v in value]}
if isinstance(value, list):
return [_sanitize_for_transport(v) for v in value]
return str(value)
# Re-export RemoteObjectHandle from pyisolate for backward compatibility
# The canonical definition is now in pyisolate._internal.remote_handle
from pyisolate._internal.remote_handle import RemoteObjectHandle # noqa: E402,F401
class ComfyNodeExtension(ExtensionBase):
def __init__(self) -> None:
super().__init__()
self.node_classes: Dict[str, type] = {}
self.display_names: Dict[str, str] = {}
self.node_instances: Dict[str, Any] = {}
self.remote_objects: Dict[str, Any] = {}
self._route_handlers: Dict[str, Any] = {}
self._module: Any = None
async def on_module_loaded(self, module: Any) -> None:
self._module = module
# Registries are initialized in host_hooks.py initialize_host_process()
# They auto-register via ProxiedSingleton when instantiated
# NO additional setup required here - if a registry is missing from host_hooks, it WILL fail
self.node_classes = getattr(module, "NODE_CLASS_MAPPINGS", {}) or {}
self.display_names = getattr(module, "NODE_DISPLAY_NAME_MAPPINGS", {}) or {}
try:
from comfy_api.latest import ComfyExtension
for name, obj in inspect.getmembers(module):
if not (
inspect.isclass(obj)
and issubclass(obj, ComfyExtension)
and obj is not ComfyExtension
):
continue
if not obj.__module__.startswith(module.__name__):
continue
try:
ext_instance = obj()
try:
await asyncio.wait_for(
ext_instance.on_load(), timeout=V3_DISCOVERY_TIMEOUT
)
except asyncio.TimeoutError:
logger.error(
"%s V3 Extension %s timed out in on_load()",
LOG_PREFIX,
name,
)
continue
try:
v3_nodes = await asyncio.wait_for(
ext_instance.get_node_list(), timeout=V3_DISCOVERY_TIMEOUT
)
except asyncio.TimeoutError:
logger.error(
"%s V3 Extension %s timed out in get_node_list()",
LOG_PREFIX,
name,
)
continue
for node_cls in v3_nodes:
if hasattr(node_cls, "GET_SCHEMA"):
schema = node_cls.GET_SCHEMA()
self.node_classes[schema.node_id] = node_cls
if schema.display_name:
self.display_names[schema.node_id] = schema.display_name
except Exception as e:
logger.error("%s V3 Extension %s failed: %s", LOG_PREFIX, name, e)
except ImportError:
pass
module_name = getattr(module, "__name__", "isolated_nodes")
for node_cls in self.node_classes.values():
if hasattr(node_cls, "__module__") and "/" in str(node_cls.__module__):
node_cls.__module__ = module_name
self.node_instances = {}
async def list_nodes(self) -> Dict[str, str]:
return {name: self.display_names.get(name, name) for name in self.node_classes}
async def get_node_info(self, node_name: str) -> Dict[str, Any]:
return await self.get_node_details(node_name)
async def get_node_details(self, node_name: str) -> Dict[str, Any]:
node_cls = self._get_node_class(node_name)
is_v3 = issubclass(node_cls, _ComfyNodeInternal)
input_types_raw = (
node_cls.INPUT_TYPES() if hasattr(node_cls, "INPUT_TYPES") else {}
)
output_is_list = getattr(node_cls, "OUTPUT_IS_LIST", None)
if output_is_list is not None:
output_is_list = tuple(bool(x) for x in output_is_list)
details: Dict[str, Any] = {
"input_types": _sanitize_for_transport(input_types_raw),
"return_types": tuple(
str(t) for t in getattr(node_cls, "RETURN_TYPES", ())
),
"return_names": getattr(node_cls, "RETURN_NAMES", None),
"function": str(getattr(node_cls, "FUNCTION", "execute")),
"category": str(getattr(node_cls, "CATEGORY", "")),
"output_node": bool(getattr(node_cls, "OUTPUT_NODE", False)),
"output_is_list": output_is_list,
"is_v3": is_v3,
}
if is_v3:
try:
schema = node_cls.GET_SCHEMA()
schema_v1 = asdict(schema.get_v1_info(node_cls))
try:
schema_v3 = asdict(schema.get_v3_info(node_cls))
except (AttributeError, TypeError):
schema_v3 = self._build_schema_v3_fallback(schema)
details.update(
{
"schema_v1": schema_v1,
"schema_v3": schema_v3,
"hidden": [h.value for h in (schema.hidden or [])],
"description": getattr(schema, "description", ""),
"deprecated": bool(getattr(node_cls, "DEPRECATED", False)),
"experimental": bool(getattr(node_cls, "EXPERIMENTAL", False)),
"api_node": bool(getattr(node_cls, "API_NODE", False)),
"input_is_list": bool(
getattr(node_cls, "INPUT_IS_LIST", False)
),
"not_idempotent": bool(
getattr(node_cls, "NOT_IDEMPOTENT", False)
),
}
)
except Exception as exc:
logger.warning(
"%s V3 schema serialization failed for %s: %s",
LOG_PREFIX,
node_name,
exc,
)
return details
def _build_schema_v3_fallback(self, schema) -> Dict[str, Any]:
input_dict: Dict[str, Any] = {}
output_dict: Dict[str, Any] = {}
hidden_list: List[str] = []
if getattr(schema, "inputs", None):
for inp in schema.inputs:
self._add_schema_io_v3(inp, input_dict)
if getattr(schema, "outputs", None):
for out in schema.outputs:
self._add_schema_io_v3(out, output_dict)
if getattr(schema, "hidden", None):
for h in schema.hidden:
hidden_list.append(getattr(h, "value", str(h)))
return {
"input": input_dict,
"output": output_dict,
"hidden": hidden_list,
"name": getattr(schema, "node_id", None),
"display_name": getattr(schema, "display_name", None),
"description": getattr(schema, "description", None),
"category": getattr(schema, "category", None),
"output_node": getattr(schema, "is_output_node", False),
"deprecated": getattr(schema, "is_deprecated", False),
"experimental": getattr(schema, "is_experimental", False),
"api_node": getattr(schema, "is_api_node", False),
}
def _add_schema_io_v3(self, io_obj: Any, target: Dict[str, Any]) -> None:
io_id = getattr(io_obj, "id", None)
if io_id is None:
return
io_type_fn = getattr(io_obj, "get_io_type", None)
io_type = (
io_type_fn() if callable(io_type_fn) else getattr(io_obj, "io_type", None)
)
as_dict_fn = getattr(io_obj, "as_dict", None)
payload = as_dict_fn() if callable(as_dict_fn) else {}
target[str(io_id)] = (io_type, payload)
async def get_input_types(self, node_name: str) -> Dict[str, Any]:
node_cls = self._get_node_class(node_name)
if hasattr(node_cls, "INPUT_TYPES"):
return node_cls.INPUT_TYPES()
return {}
async def execute_node(self, node_name: str, **inputs: Any) -> Tuple[Any, ...]:
logger.debug(
"%s ISO:child_execute_start ext=%s node=%s input_keys=%d",
LOG_PREFIX,
getattr(self, "name", "?"),
node_name,
len(inputs),
)
if os.environ.get("PYISOLATE_ISOLATION_ACTIVE") == "1":
_relieve_child_vram_pressure("EXT:pre_execute")
resolved_inputs = self._resolve_remote_objects(inputs)
instance = self._get_node_instance(node_name)
node_cls = self._get_node_class(node_name)
# V3 API nodes expect hidden parameters in cls.hidden, not as kwargs
# Hidden params come through RPC as string keys like "Hidden.prompt"
from comfy_api.latest._io import Hidden, HiddenHolder
# Map string representations back to Hidden enum keys
hidden_string_map = {
"Hidden.unique_id": Hidden.unique_id,
"Hidden.prompt": Hidden.prompt,
"Hidden.extra_pnginfo": Hidden.extra_pnginfo,
"Hidden.dynprompt": Hidden.dynprompt,
"Hidden.auth_token_comfy_org": Hidden.auth_token_comfy_org,
"Hidden.api_key_comfy_org": Hidden.api_key_comfy_org,
}
# Find and extract hidden parameters (both enum and string form)
hidden_found = {}
keys_to_remove = []
for key in list(resolved_inputs.keys()):
# Check string form first (from RPC serialization)
if key in hidden_string_map:
hidden_found[hidden_string_map[key]] = resolved_inputs[key]
keys_to_remove.append(key)
# Also check enum form (direct calls)
elif isinstance(key, Hidden):
hidden_found[key] = resolved_inputs[key]
keys_to_remove.append(key)
# Remove hidden params from kwargs
for key in keys_to_remove:
resolved_inputs.pop(key)
# Set hidden on node class if any hidden params found
if hidden_found:
if not hasattr(node_cls, "hidden") or node_cls.hidden is None:
node_cls.hidden = HiddenHolder.from_dict(hidden_found)
else:
# Update existing hidden holder
for key, value in hidden_found.items():
setattr(node_cls.hidden, key.value.lower(), value)
function_name = getattr(node_cls, "FUNCTION", "execute")
if not hasattr(instance, function_name):
raise AttributeError(f"Node {node_name} missing callable '{function_name}'")
handler = getattr(instance, function_name)
try:
if asyncio.iscoroutinefunction(handler):
result = await handler(**resolved_inputs)
else:
import functools
loop = asyncio.get_running_loop()
result = await loop.run_in_executor(
None, functools.partial(handler, **resolved_inputs)
)
except Exception:
logger.exception(
"%s ISO:child_execute_error ext=%s node=%s",
LOG_PREFIX,
getattr(self, "name", "?"),
node_name,
)
raise
if type(result).__name__ == "NodeOutput":
result = result.args
if self._is_comfy_protocol_return(result):
logger.debug(
"%s ISO:child_execute_done ext=%s node=%s protocol_return=1",
LOG_PREFIX,
getattr(self, "name", "?"),
node_name,
)
return self._wrap_unpicklable_objects(result)
if not isinstance(result, tuple):
result = (result,)
logger.debug(
"%s ISO:child_execute_done ext=%s node=%s protocol_return=0 outputs=%d",
LOG_PREFIX,
getattr(self, "name", "?"),
node_name,
len(result),
)
return self._wrap_unpicklable_objects(result)
async def flush_transport_state(self) -> int:
if os.environ.get("PYISOLATE_ISOLATION_ACTIVE") != "1":
return 0
logger.debug(
"%s ISO:child_flush_start ext=%s", LOG_PREFIX, getattr(self, "name", "?")
)
flushed = _flush_tensor_transport_state("EXT:workflow_end")
try:
from comfy.isolation.model_patcher_proxy_registry import (
ModelPatcherRegistry,
)
registry = ModelPatcherRegistry()
removed = registry.sweep_pending_cleanup()
if removed > 0:
logger.debug(
"%s EXT:workflow_end registry sweep removed=%d", LOG_PREFIX, removed
)
except Exception:
logger.debug(
"%s EXT:workflow_end registry sweep failed", LOG_PREFIX, exc_info=True
)
logger.debug(
"%s ISO:child_flush_done ext=%s flushed=%d",
LOG_PREFIX,
getattr(self, "name", "?"),
flushed,
)
return flushed
async def get_remote_object(self, object_id: str) -> Any:
"""Retrieve a remote object by ID for host-side deserialization."""
if object_id not in self.remote_objects:
raise KeyError(f"Remote object {object_id} not found")
return self.remote_objects[object_id]
def _wrap_unpicklable_objects(self, data: Any) -> Any:
if isinstance(data, (str, int, float, bool, type(None))):
return data
if isinstance(data, torch.Tensor):
return data.detach() if data.requires_grad else data
# Special-case clip vision outputs: preserve attribute access by packing fields
if hasattr(data, "penultimate_hidden_states") or hasattr(
data, "last_hidden_state"
):
fields = {}
for attr in (
"penultimate_hidden_states",
"last_hidden_state",
"image_embeds",
"text_embeds",
):
if hasattr(data, attr):
try:
fields[attr] = self._wrap_unpicklable_objects(
getattr(data, attr)
)
except Exception:
pass
if fields:
return {"__pyisolate_attribute_container__": True, "data": fields}
# Avoid converting arbitrary objects with stateful methods (models, etc.)
# They will be handled via RemoteObjectHandle below.
type_name = type(data).__name__
if type_name == "ModelPatcherProxy":
return {"__type__": "ModelPatcherRef", "model_id": data._instance_id}
if type_name == "CLIPProxy":
return {"__type__": "CLIPRef", "clip_id": data._instance_id}
if type_name == "VAEProxy":
return {"__type__": "VAERef", "vae_id": data._instance_id}
if type_name == "ModelSamplingProxy":
return {"__type__": "ModelSamplingRef", "ms_id": data._instance_id}
if isinstance(data, (list, tuple)):
wrapped = [self._wrap_unpicklable_objects(item) for item in data]
return tuple(wrapped) if isinstance(data, tuple) else wrapped
if isinstance(data, dict):
converted_dict = {
k: self._wrap_unpicklable_objects(v) for k, v in data.items()
}
return {"__pyisolate_attrdict__": True, "data": converted_dict}
object_id = str(uuid.uuid4())
self.remote_objects[object_id] = data
return RemoteObjectHandle(object_id, type(data).__name__)
def _resolve_remote_objects(self, data: Any) -> Any:
if isinstance(data, RemoteObjectHandle):
if data.object_id not in self.remote_objects:
raise KeyError(f"Remote object {data.object_id} not found")
return self.remote_objects[data.object_id]
if isinstance(data, dict):
ref_type = data.get("__type__")
if ref_type in ("CLIPRef", "ModelPatcherRef", "VAERef"):
from pyisolate._internal.model_serialization import (
deserialize_proxy_result,
)
return deserialize_proxy_result(data)
if ref_type == "ModelSamplingRef":
from pyisolate._internal.model_serialization import (
deserialize_proxy_result,
)
return deserialize_proxy_result(data)
return {k: self._resolve_remote_objects(v) for k, v in data.items()}
if isinstance(data, (list, tuple)):
resolved = [self._resolve_remote_objects(item) for item in data]
return tuple(resolved) if isinstance(data, tuple) else resolved
return data
def _get_node_class(self, node_name: str) -> type:
if node_name not in self.node_classes:
raise KeyError(f"Unknown node: {node_name}")
return self.node_classes[node_name]
def _get_node_instance(self, node_name: str) -> Any:
if node_name not in self.node_instances:
if node_name not in self.node_classes:
raise KeyError(f"Unknown node: {node_name}")
self.node_instances[node_name] = self.node_classes[node_name]()
return self.node_instances[node_name]
async def before_module_loaded(self) -> None:
# Inject initialization here if we think this is the child
try:
from comfy.isolation import initialize_proxies
initialize_proxies()
except Exception as e:
logging.getLogger(__name__).error(
f"Failed to call initialize_proxies in before_module_loaded: {e}"
)
await super().before_module_loaded()
try:
from comfy_api.latest import ComfyAPI_latest
from .proxies.progress_proxy import ProgressProxy
ComfyAPI_latest.Execution = ProgressProxy
# ComfyAPI_latest.execution = ProgressProxy() # Eliminated to avoid Singleton collision
# fp_proxy = FolderPathsProxy() # Eliminated to avoid Singleton collision
# latest_ui.folder_paths = fp_proxy
# latest_resources.folder_paths = fp_proxy
except Exception:
pass
async def call_route_handler(
self,
handler_module: str,
handler_func: str,
request_data: Dict[str, Any],
) -> Any:
cache_key = f"{handler_module}.{handler_func}"
if cache_key not in self._route_handlers:
if self._module is not None and hasattr(self._module, "__file__"):
node_dir = os.path.dirname(self._module.__file__)
if node_dir not in sys.path:
sys.path.insert(0, node_dir)
try:
module = importlib.import_module(handler_module)
self._route_handlers[cache_key] = getattr(module, handler_func)
except (ImportError, AttributeError) as e:
raise ValueError(f"Route handler not found: {cache_key}") from e
handler = self._route_handlers[cache_key]
mock_request = MockRequest(request_data)
if asyncio.iscoroutinefunction(handler):
result = await handler(mock_request)
else:
result = handler(mock_request)
return self._serialize_response(result)
def _is_comfy_protocol_return(self, result: Any) -> bool:
"""
Check if the result matches the ComfyUI 'Protocol Return' schema.
A Protocol Return is a dictionary containing specific reserved keys that
ComfyUI's execution engine interprets as instructions (UI updates,
Workflow expansion, etc.) rather than purely data outputs.
Schema:
- Must be a dict
- Must contain at least one of: 'ui', 'result', 'expand'
"""
if not isinstance(result, dict):
return False
return any(key in result for key in ("ui", "result", "expand"))
def _serialize_response(self, response: Any) -> Dict[str, Any]:
if response is None:
return {"type": "text", "body": "", "status": 204}
if isinstance(response, dict):
return {"type": "json", "body": response, "status": 200}
if isinstance(response, str):
return {"type": "text", "body": response, "status": 200}
if hasattr(response, "text") and hasattr(response, "status"):
return {
"type": "text",
"body": response.text
if hasattr(response, "text")
else str(response.body),
"status": response.status,
"headers": dict(response.headers)
if hasattr(response, "headers")
else {},
}
if hasattr(response, "body") and hasattr(response, "status"):
body = response.body
if isinstance(body, bytes):
try:
return {
"type": "text",
"body": body.decode("utf-8"),
"status": response.status,
}
except UnicodeDecodeError:
return {
"type": "binary",
"body": body.hex(),
"status": response.status,
}
return {"type": "json", "body": body, "status": response.status}
return {"type": "text", "body": str(response), "status": 200}
class MockRequest:
def __init__(self, data: Dict[str, Any]):
self.method = data.get("method", "GET")
self.path = data.get("path", "/")
self.query = data.get("query", {})
self._body = data.get("body", {})
self._text = data.get("text", "")
self.headers = data.get("headers", {})
self.content_type = data.get(
"content_type", self.headers.get("Content-Type", "application/json")
)
self.match_info = data.get("match_info", {})
async def json(self) -> Any:
if isinstance(self._body, dict):
return self._body
if isinstance(self._body, str):
return json.loads(self._body)
return {}
async def post(self) -> Dict[str, Any]:
if isinstance(self._body, dict):
return self._body
return {}
async def text(self) -> str:
if self._text:
return self._text
if isinstance(self._body, str):
return self._body
if isinstance(self._body, dict):
return json.dumps(self._body)
return ""
async def read(self) -> bytes:
return (await self.text()).encode("utf-8")

View File

@@ -1,26 +0,0 @@
# pylint: disable=import-outside-toplevel
# Host process initialization for PyIsolate
import logging
logger = logging.getLogger(__name__)
def initialize_host_process() -> None:
root = logging.getLogger()
for handler in root.handlers[:]:
root.removeHandler(handler)
root.addHandler(logging.NullHandler())
from .proxies.folder_paths_proxy import FolderPathsProxy
from .proxies.model_management_proxy import ModelManagementProxy
from .proxies.progress_proxy import ProgressProxy
from .proxies.prompt_server_impl import PromptServerService
from .proxies.utils_proxy import UtilsProxy
from .vae_proxy import VAERegistry
FolderPathsProxy()
ModelManagementProxy()
ProgressProxy()
PromptServerService()
UtilsProxy()
VAERegistry()

View File

@@ -1,83 +0,0 @@
# pylint: disable=logging-fstring-interpolation
from __future__ import annotations
import logging
from pathlib import Path
from typing import Dict, List, TypedDict
try:
import tomllib
except ImportError:
import tomli as tomllib # type: ignore[no-redef]
logger = logging.getLogger(__name__)
class HostSecurityPolicy(TypedDict):
allow_network: bool
writable_paths: List[str]
readonly_paths: List[str]
whitelist: Dict[str, str]
DEFAULT_POLICY: HostSecurityPolicy = {
"allow_network": False,
"writable_paths": ["/dev/shm", "/tmp"],
"readonly_paths": [],
"whitelist": {},
}
def _default_policy() -> HostSecurityPolicy:
return {
"allow_network": DEFAULT_POLICY["allow_network"],
"writable_paths": list(DEFAULT_POLICY["writable_paths"]),
"readonly_paths": list(DEFAULT_POLICY["readonly_paths"]),
"whitelist": dict(DEFAULT_POLICY["whitelist"]),
}
def load_host_policy(comfy_root: Path) -> HostSecurityPolicy:
config_path = comfy_root / "pyproject.toml"
policy = _default_policy()
if not config_path.exists():
logger.debug("Host policy file missing at %s, using defaults.", config_path)
return policy
try:
with config_path.open("rb") as f:
data = tomllib.load(f)
except Exception:
logger.warning(
"Failed to parse host policy from %s, using defaults.",
config_path,
exc_info=True,
)
return policy
tool_config = data.get("tool", {}).get("comfy", {}).get("host", {})
if not isinstance(tool_config, dict):
logger.debug("No [tool.comfy.host] section found, using defaults.")
return policy
if "allow_network" in tool_config:
policy["allow_network"] = bool(tool_config["allow_network"])
if "writable_paths" in tool_config:
policy["writable_paths"] = [str(p) for p in tool_config["writable_paths"]]
if "readonly_paths" in tool_config:
policy["readonly_paths"] = [str(p) for p in tool_config["readonly_paths"]]
whitelist_raw = tool_config.get("whitelist")
if isinstance(whitelist_raw, dict):
policy["whitelist"] = {str(k): str(v) for k, v in whitelist_raw.items()}
logger.debug(
f"Loaded Host Policy: {len(policy['whitelist'])} whitelisted nodes, Network={policy['allow_network']}"
)
return policy
__all__ = ["HostSecurityPolicy", "load_host_policy", "DEFAULT_POLICY"]

View File

@@ -1,186 +0,0 @@
# pylint: disable=import-outside-toplevel
from __future__ import annotations
import hashlib
import json
import logging
import os
import sys
import tempfile
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
import folder_paths
try:
import tomllib
except ImportError:
import tomli as tomllib # type: ignore[no-redef]
LOG_PREFIX = "]["
logger = logging.getLogger(__name__)
CACHE_SUBDIR = "cache"
CACHE_KEY_FILE = "cache_key"
CACHE_DATA_FILE = "node_info.json"
CACHE_KEY_LENGTH = 16
def find_manifest_directories() -> List[Tuple[Path, Path]]:
"""Find custom node directories containing a valid pyproject.toml with [tool.comfy.isolation]."""
manifest_dirs: List[Tuple[Path, Path]] = []
# Standard custom_nodes paths
for base_path in folder_paths.get_folder_paths("custom_nodes"):
base = Path(base_path)
if not base.exists() or not base.is_dir():
continue
for entry in base.iterdir():
if not entry.is_dir():
continue
# Look for pyproject.toml
manifest = entry / "pyproject.toml"
if not manifest.exists():
continue
# Validate [tool.comfy.isolation] section existence
try:
with manifest.open("rb") as f:
data = tomllib.load(f)
if (
"tool" in data
and "comfy" in data["tool"]
and "isolation" in data["tool"]["comfy"]
):
manifest_dirs.append((entry, manifest))
except Exception:
continue
return manifest_dirs
def compute_cache_key(node_dir: Path, manifest_path: Path) -> str:
"""Hash manifest + .py mtimes + Python version + PyIsolate version."""
hasher = hashlib.sha256()
try:
# Hashing the manifest content ensures config changes invalidate cache
hasher.update(manifest_path.read_bytes())
except OSError:
hasher.update(b"__manifest_read_error__")
try:
py_files = sorted(node_dir.rglob("*.py"))
for py_file in py_files:
rel_path = py_file.relative_to(node_dir)
if "__pycache__" in str(rel_path) or ".venv" in str(rel_path):
continue
hasher.update(str(rel_path).encode("utf-8"))
try:
hasher.update(str(py_file.stat().st_mtime).encode("utf-8"))
except OSError:
hasher.update(b"__file_stat_error__")
except OSError:
hasher.update(b"__dir_scan_error__")
hasher.update(sys.version.encode("utf-8"))
try:
import pyisolate
hasher.update(pyisolate.__version__.encode("utf-8"))
except (ImportError, AttributeError):
hasher.update(b"__pyisolate_unknown__")
return hasher.hexdigest()[:CACHE_KEY_LENGTH]
def get_cache_path(node_dir: Path, venv_root: Path) -> Tuple[Path, Path]:
"""Return (cache_key_file, cache_data_file) in venv_root/{node}/cache/."""
cache_dir = venv_root / node_dir.name / CACHE_SUBDIR
return (cache_dir / CACHE_KEY_FILE, cache_dir / CACHE_DATA_FILE)
def is_cache_valid(node_dir: Path, manifest_path: Path, venv_root: Path) -> bool:
"""Return True only if stored cache key matches current computed key."""
try:
cache_key_file, cache_data_file = get_cache_path(node_dir, venv_root)
if not cache_key_file.exists() or not cache_data_file.exists():
return False
current_key = compute_cache_key(node_dir, manifest_path)
stored_key = cache_key_file.read_text(encoding="utf-8").strip()
return current_key == stored_key
except Exception as e:
logger.debug(
"%s Cache validation error for %s: %s", LOG_PREFIX, node_dir.name, e
)
return False
def load_from_cache(node_dir: Path, venv_root: Path) -> Optional[Dict[str, Any]]:
"""Load node metadata from cache, return None on any error."""
try:
_, cache_data_file = get_cache_path(node_dir, venv_root)
if not cache_data_file.exists():
return None
data = json.loads(cache_data_file.read_text(encoding="utf-8"))
if not isinstance(data, dict):
return None
return data
except Exception:
return None
def save_to_cache(
node_dir: Path, venv_root: Path, node_data: Dict[str, Any], manifest_path: Path
) -> None:
"""Save node metadata and cache key atomically."""
try:
cache_key_file, cache_data_file = get_cache_path(node_dir, venv_root)
cache_dir = cache_key_file.parent
cache_dir.mkdir(parents=True, exist_ok=True)
cache_key = compute_cache_key(node_dir, manifest_path)
# Atomic write: data
tmp_data_fd, tmp_data_path = tempfile.mkstemp(dir=str(cache_dir), suffix=".tmp")
try:
with os.fdopen(tmp_data_fd, "w", encoding="utf-8") as f:
json.dump(node_data, f, indent=2)
os.replace(tmp_data_path, cache_data_file)
except Exception:
try:
os.unlink(tmp_data_path)
except OSError:
pass
raise
# Atomic write: key
tmp_key_fd, tmp_key_path = tempfile.mkstemp(dir=str(cache_dir), suffix=".tmp")
try:
with os.fdopen(tmp_key_fd, "w", encoding="utf-8") as f:
f.write(cache_key)
os.replace(tmp_key_path, cache_key_file)
except Exception:
try:
os.unlink(tmp_key_path)
except OSError:
pass
raise
except Exception as e:
logger.warning("%s Cache save failed for %s: %s", LOG_PREFIX, node_dir.name, e)
__all__ = [
"LOG_PREFIX",
"find_manifest_directories",
"compute_cache_key",
"get_cache_path",
"is_cache_valid",
"load_from_cache",
"save_to_cache",
]

View File

@@ -1,820 +0,0 @@
# pylint: disable=bare-except,consider-using-from-import,import-outside-toplevel,protected-access
# RPC proxy for ModelPatcher (parent process)
from __future__ import annotations
import logging
from typing import Any, Optional, List, Set, Dict, Callable
from comfy.isolation.proxies.base import (
IS_CHILD_PROCESS,
BaseProxy,
)
from comfy.isolation.model_patcher_proxy_registry import (
ModelPatcherRegistry,
AutoPatcherEjector,
)
logger = logging.getLogger(__name__)
class ModelPatcherProxy(BaseProxy[ModelPatcherRegistry]):
_registry_class = ModelPatcherRegistry
__module__ = "comfy.model_patcher"
_APPLY_MODEL_GUARD_PADDING_BYTES = 32 * 1024 * 1024
def _get_rpc(self) -> Any:
if self._rpc_caller is None:
from pyisolate._internal.rpc_protocol import get_child_rpc_instance
rpc = get_child_rpc_instance()
if rpc is not None:
self._rpc_caller = rpc.create_caller(
self._registry_class, self._registry_class.get_remote_id()
)
else:
self._rpc_caller = self._registry
return self._rpc_caller
def get_all_callbacks(self, call_type: str = None) -> Any:
return self._call_rpc("get_all_callbacks", call_type)
def get_all_wrappers(self, wrapper_type: str = None) -> Any:
return self._call_rpc("get_all_wrappers", wrapper_type)
def _load_list(self, *args, **kwargs) -> Any:
return self._call_rpc("load_list_internal", *args, **kwargs)
def prepare_hook_patches_current_keyframe(
self, t: Any, hook_group: Any, model_options: Any
) -> None:
self._call_rpc(
"prepare_hook_patches_current_keyframe", t, hook_group, model_options
)
def add_hook_patches(
self,
hook: Any,
patches: Any,
strength_patch: float = 1.0,
strength_model: float = 1.0,
) -> None:
self._call_rpc(
"add_hook_patches", hook, patches, strength_patch, strength_model
)
def clear_cached_hook_weights(self) -> None:
self._call_rpc("clear_cached_hook_weights")
def get_combined_hook_patches(self, hooks: Any) -> Any:
return self._call_rpc("get_combined_hook_patches", hooks)
def get_additional_models_with_key(self, key: str) -> Any:
return self._call_rpc("get_additional_models_with_key", key)
@property
def object_patches(self) -> Any:
return self._call_rpc("get_object_patches")
@property
def patches(self) -> Any:
res = self._call_rpc("get_patches")
if isinstance(res, dict):
new_res = {}
for k, v in res.items():
new_list = []
for item in v:
if isinstance(item, list):
new_list.append(tuple(item))
else:
new_list.append(item)
new_res[k] = new_list
return new_res
return res
@property
def pinned(self) -> Set:
val = self._call_rpc("get_patcher_attr", "pinned")
return set(val) if val is not None else set()
@property
def hook_patches(self) -> Dict:
val = self._call_rpc("get_patcher_attr", "hook_patches")
if val is None:
return {}
try:
from comfy.hooks import _HookRef
import json
new_val = {}
for k, v in val.items():
if isinstance(k, str):
if k.startswith("PYISOLATE_HOOKREF:"):
ref_id = k.split(":", 1)[1]
h = _HookRef()
h._pyisolate_id = ref_id
new_val[h] = v
elif k.startswith("__pyisolate_key__"):
try:
json_str = k[len("__pyisolate_key__") :]
data = json.loads(json_str)
ref_id = None
if isinstance(data, list):
for item in data:
if (
isinstance(item, list)
and len(item) == 2
and item[0] == "id"
):
ref_id = item[1]
break
if ref_id:
h = _HookRef()
h._pyisolate_id = ref_id
new_val[h] = v
else:
new_val[k] = v
except Exception:
new_val[k] = v
else:
new_val[k] = v
else:
new_val[k] = v
return new_val
except ImportError:
return val
def set_hook_mode(self, hook_mode: Any) -> None:
self._call_rpc("set_hook_mode", hook_mode)
def register_all_hook_patches(
self,
hooks: Any,
target_dict: Any,
model_options: Any = None,
registered: Any = None,
) -> None:
self._call_rpc(
"register_all_hook_patches", hooks, target_dict, model_options, registered
)
def is_clone(self, other: Any) -> bool:
if isinstance(other, ModelPatcherProxy):
return self._call_rpc("is_clone_by_id", other._instance_id)
return False
def clone(self) -> ModelPatcherProxy:
new_id = self._call_rpc("clone")
return ModelPatcherProxy(
new_id, self._registry, manage_lifecycle=not IS_CHILD_PROCESS
)
def clone_has_same_weights(self, clone: Any) -> bool:
if isinstance(clone, ModelPatcherProxy):
return self._call_rpc("clone_has_same_weights_by_id", clone._instance_id)
if not IS_CHILD_PROCESS:
return self._call_rpc("is_clone", clone)
return False
def get_model_object(self, name: str) -> Any:
return self._call_rpc("get_model_object", name)
@property
def model_options(self) -> dict:
data = self._call_rpc("get_model_options")
import json
def _decode_keys(obj):
if isinstance(obj, dict):
new_d = {}
for k, v in obj.items():
if isinstance(k, str) and k.startswith("__pyisolate_key__"):
try:
json_str = k[17:]
val = json.loads(json_str)
if isinstance(val, list):
val = tuple(val)
new_d[val] = _decode_keys(v)
except:
new_d[k] = _decode_keys(v)
else:
new_d[k] = _decode_keys(v)
return new_d
if isinstance(obj, list):
return [_decode_keys(x) for x in obj]
return obj
return _decode_keys(data)
@model_options.setter
def model_options(self, value: dict) -> None:
self._call_rpc("set_model_options", value)
def apply_hooks(self, hooks: Any) -> Any:
return self._call_rpc("apply_hooks", hooks)
def prepare_state(self, timestep: Any) -> Any:
return self._call_rpc("prepare_state", timestep)
def restore_hook_patches(self) -> None:
self._call_rpc("restore_hook_patches")
def unpatch_hooks(self, whitelist_keys_set: Optional[Set[str]] = None) -> None:
self._call_rpc("unpatch_hooks", whitelist_keys_set)
def model_patches_to(self, device: Any) -> Any:
return self._call_rpc("model_patches_to", device)
def partially_load(
self, device: Any, extra_memory: Any, force_patch_weights: bool = False
) -> Any:
return self._call_rpc(
"partially_load", device, extra_memory, force_patch_weights
)
def partially_unload(
self, device_to: Any, memory_to_free: int = 0, force_patch_weights: bool = False
) -> int:
return self._call_rpc(
"partially_unload", device_to, memory_to_free, force_patch_weights
)
def load(
self,
device_to: Any = None,
lowvram_model_memory: int = 0,
force_patch_weights: bool = False,
full_load: bool = False,
) -> None:
self._call_rpc(
"load", device_to, lowvram_model_memory, force_patch_weights, full_load
)
def patch_model(
self,
device_to: Any = None,
lowvram_model_memory: int = 0,
load_weights: bool = True,
force_patch_weights: bool = False,
) -> Any:
self._call_rpc(
"patch_model",
device_to,
lowvram_model_memory,
load_weights,
force_patch_weights,
)
return self
def unpatch_model(
self, device_to: Any = None, unpatch_weights: bool = True
) -> None:
self._call_rpc("unpatch_model", device_to, unpatch_weights)
def detach(self, unpatch_all: bool = True) -> Any:
self._call_rpc("detach", unpatch_all)
return self.model
def _cpu_tensor_bytes(self, obj: Any) -> int:
import torch
if isinstance(obj, torch.Tensor):
if obj.device.type == "cpu":
return obj.nbytes
return 0
if isinstance(obj, dict):
return sum(self._cpu_tensor_bytes(v) for v in obj.values())
if isinstance(obj, (list, tuple)):
return sum(self._cpu_tensor_bytes(v) for v in obj)
return 0
def _ensure_apply_model_headroom(self, required_bytes: int) -> bool:
if required_bytes <= 0:
return True
import torch
import comfy.model_management as model_management
target_raw = self.load_device
try:
if isinstance(target_raw, torch.device):
target = target_raw
elif isinstance(target_raw, str):
target = torch.device(target_raw)
elif isinstance(target_raw, int):
target = torch.device(f"cuda:{target_raw}")
else:
target = torch.device(target_raw)
except Exception:
return True
if target.type != "cuda":
return True
required = required_bytes + self._APPLY_MODEL_GUARD_PADDING_BYTES
if model_management.get_free_memory(target) >= required:
return True
model_management.cleanup_models_gc()
model_management.cleanup_models()
model_management.soft_empty_cache()
if model_management.get_free_memory(target) < required:
model_management.free_memory(required, target, for_dynamic=True)
model_management.soft_empty_cache()
if model_management.get_free_memory(target) < required:
# Escalate to non-dynamic unloading before dispatching CUDA transfer.
model_management.free_memory(required, target, for_dynamic=False)
model_management.soft_empty_cache()
if model_management.get_free_memory(target) < required:
model_management.load_models_gpu(
[self],
minimum_memory_required=required,
)
return model_management.get_free_memory(target) >= required
def apply_model(self, *args, **kwargs) -> Any:
import torch
required_bytes = self._cpu_tensor_bytes(args) + self._cpu_tensor_bytes(kwargs)
self._ensure_apply_model_headroom(required_bytes)
def _to_cuda(obj: Any) -> Any:
if isinstance(obj, torch.Tensor) and obj.device.type == "cpu":
return obj.to("cuda")
if isinstance(obj, dict):
return {k: _to_cuda(v) for k, v in obj.items()}
if isinstance(obj, list):
return [_to_cuda(v) for v in obj]
if isinstance(obj, tuple):
return tuple(_to_cuda(v) for v in obj)
return obj
try:
args_cuda = _to_cuda(args)
kwargs_cuda = _to_cuda(kwargs)
except torch.OutOfMemoryError:
self._ensure_apply_model_headroom(required_bytes)
args_cuda = _to_cuda(args)
kwargs_cuda = _to_cuda(kwargs)
return self._call_rpc("inner_model_apply_model", args_cuda, kwargs_cuda)
def model_state_dict(self, filter_prefix: Optional[str] = None) -> Any:
keys = self._call_rpc("model_state_dict", filter_prefix)
return dict.fromkeys(keys, None)
def add_patches(self, *args: Any, **kwargs: Any) -> Any:
res = self._call_rpc("add_patches", *args, **kwargs)
if isinstance(res, list):
return [tuple(x) if isinstance(x, list) else x for x in res]
return res
def get_key_patches(self, filter_prefix: Optional[str] = None) -> Any:
return self._call_rpc("get_key_patches", filter_prefix)
def patch_weight_to_device(self, key, device_to=None, inplace_update=False):
self._call_rpc("patch_weight_to_device", key, device_to, inplace_update)
def pin_weight_to_device(self, key):
self._call_rpc("pin_weight_to_device", key)
def unpin_weight(self, key):
self._call_rpc("unpin_weight", key)
def unpin_all_weights(self):
self._call_rpc("unpin_all_weights")
def calculate_weight(self, patches, weight, key, intermediate_dtype=None):
return self._call_rpc(
"calculate_weight", patches, weight, key, intermediate_dtype
)
def inject_model(self) -> None:
self._call_rpc("inject_model")
def eject_model(self) -> None:
self._call_rpc("eject_model")
def use_ejected(self, skip_and_inject_on_exit_only: bool = False) -> Any:
return AutoPatcherEjector(
self, skip_and_inject_on_exit_only=skip_and_inject_on_exit_only
)
@property
def is_injected(self) -> bool:
return self._call_rpc("get_is_injected")
@property
def skip_injection(self) -> bool:
return self._call_rpc("get_skip_injection")
@skip_injection.setter
def skip_injection(self, value: bool) -> None:
self._call_rpc("set_skip_injection", value)
def clean_hooks(self) -> None:
self._call_rpc("clean_hooks")
def pre_run(self) -> None:
self._call_rpc("pre_run")
def cleanup(self) -> None:
try:
self._call_rpc("cleanup")
except Exception:
logger.debug(
"ModelPatcherProxy cleanup RPC failed for %s",
self._instance_id,
exc_info=True,
)
finally:
super().cleanup()
@property
def model(self) -> _InnerModelProxy:
return _InnerModelProxy(self)
def __getattr__(self, name: str) -> Any:
_whitelisted_attrs = {
"hook_patches_backup",
"hook_backup",
"cached_hook_patches",
"current_hooks",
"forced_hooks",
"is_clip",
"patches_uuid",
"pinned",
"attachments",
"additional_models",
"injections",
"hook_patches",
"model_lowvram",
"model_loaded_weight_memory",
"backup",
"object_patches_backup",
"weight_wrapper_patches",
"weight_inplace_update",
"force_cast_weights",
}
if name in _whitelisted_attrs:
return self._call_rpc("get_patcher_attr", name)
raise AttributeError(
f"'{type(self).__name__}' object has no attribute '{name}'"
)
def load_lora(
self,
lora_path: str,
strength_model: float,
clip: Optional[Any] = None,
strength_clip: float = 1.0,
) -> tuple:
clip_id = None
if clip is not None:
clip_id = getattr(clip, "_instance_id", getattr(clip, "_clip_id", None))
result = self._call_rpc(
"load_lora", lora_path, strength_model, clip_id, strength_clip
)
new_model = None
if result.get("model_id"):
new_model = ModelPatcherProxy(
result["model_id"],
self._registry,
manage_lifecycle=not IS_CHILD_PROCESS,
)
new_clip = None
if result.get("clip_id"):
from comfy.isolation.clip_proxy import CLIPProxy
new_clip = CLIPProxy(result["clip_id"])
return (new_model, new_clip)
@property
def load_device(self) -> Any:
return self._call_rpc("get_load_device")
@property
def offload_device(self) -> Any:
return self._call_rpc("get_offload_device")
@property
def device(self) -> Any:
return self.load_device
def current_loaded_device(self) -> Any:
return self._call_rpc("current_loaded_device")
@property
def size(self) -> int:
return self._call_rpc("get_size")
def model_size(self) -> Any:
return self._call_rpc("model_size")
def loaded_size(self) -> Any:
return self._call_rpc("loaded_size")
def get_ram_usage(self) -> int:
return self._call_rpc("get_ram_usage")
def lowvram_patch_counter(self) -> int:
return self._call_rpc("lowvram_patch_counter")
def memory_required(self, input_shape: Any) -> Any:
return self._call_rpc("memory_required", input_shape)
def is_dynamic(self) -> bool:
return bool(self._call_rpc("is_dynamic"))
def get_free_memory(self, device: Any) -> Any:
return self._call_rpc("get_free_memory", device)
def partially_unload_ram(self, ram_to_unload: int) -> Any:
return self._call_rpc("partially_unload_ram", ram_to_unload)
def model_dtype(self) -> Any:
res = self._call_rpc("model_dtype")
if isinstance(res, str) and res.startswith("torch."):
try:
import torch
attr = res.split(".")[-1]
if hasattr(torch, attr):
return getattr(torch, attr)
except ImportError:
pass
return res
@property
def hook_mode(self) -> Any:
return self._call_rpc("get_hook_mode")
@hook_mode.setter
def hook_mode(self, value: Any) -> None:
self._call_rpc("set_hook_mode", value)
def set_model_sampler_cfg_function(
self, sampler_cfg_function: Any, disable_cfg1_optimization: bool = False
) -> None:
self._call_rpc(
"set_model_sampler_cfg_function",
sampler_cfg_function,
disable_cfg1_optimization,
)
def set_model_sampler_post_cfg_function(
self, post_cfg_function: Any, disable_cfg1_optimization: bool = False
) -> None:
self._call_rpc(
"set_model_sampler_post_cfg_function",
post_cfg_function,
disable_cfg1_optimization,
)
def set_model_sampler_pre_cfg_function(
self, pre_cfg_function: Any, disable_cfg1_optimization: bool = False
) -> None:
self._call_rpc(
"set_model_sampler_pre_cfg_function",
pre_cfg_function,
disable_cfg1_optimization,
)
def set_model_sampler_calc_cond_batch_function(self, fn: Any) -> None:
self._call_rpc("set_model_sampler_calc_cond_batch_function", fn)
def set_model_unet_function_wrapper(self, unet_wrapper_function: Any) -> None:
self._call_rpc("set_model_unet_function_wrapper", unet_wrapper_function)
def set_model_denoise_mask_function(self, denoise_mask_function: Any) -> None:
self._call_rpc("set_model_denoise_mask_function", denoise_mask_function)
def set_model_patch(self, patch: Any, name: str) -> None:
self._call_rpc("set_model_patch", patch, name)
def set_model_patch_replace(
self,
patch: Any,
name: str,
block_name: str,
number: int,
transformer_index: Optional[int] = None,
) -> None:
self._call_rpc(
"set_model_patch_replace",
patch,
name,
block_name,
number,
transformer_index,
)
def set_model_attn1_patch(self, patch: Any) -> None:
self.set_model_patch(patch, "attn1_patch")
def set_model_attn2_patch(self, patch: Any) -> None:
self.set_model_patch(patch, "attn2_patch")
def set_model_attn1_replace(
self,
patch: Any,
block_name: str,
number: int,
transformer_index: Optional[int] = None,
) -> None:
self.set_model_patch_replace(
patch, "attn1", block_name, number, transformer_index
)
def set_model_attn2_replace(
self,
patch: Any,
block_name: str,
number: int,
transformer_index: Optional[int] = None,
) -> None:
self.set_model_patch_replace(
patch, "attn2", block_name, number, transformer_index
)
def set_model_attn1_output_patch(self, patch: Any) -> None:
self.set_model_patch(patch, "attn1_output_patch")
def set_model_attn2_output_patch(self, patch: Any) -> None:
self.set_model_patch(patch, "attn2_output_patch")
def set_model_input_block_patch(self, patch: Any) -> None:
self.set_model_patch(patch, "input_block_patch")
def set_model_input_block_patch_after_skip(self, patch: Any) -> None:
self.set_model_patch(patch, "input_block_patch_after_skip")
def set_model_output_block_patch(self, patch: Any) -> None:
self.set_model_patch(patch, "output_block_patch")
def set_model_emb_patch(self, patch: Any) -> None:
self.set_model_patch(patch, "emb_patch")
def set_model_forward_timestep_embed_patch(self, patch: Any) -> None:
self.set_model_patch(patch, "forward_timestep_embed_patch")
def set_model_double_block_patch(self, patch: Any) -> None:
self.set_model_patch(patch, "double_block")
def set_model_post_input_patch(self, patch: Any) -> None:
self.set_model_patch(patch, "post_input")
def set_model_rope_options(
self,
scale_x=1.0,
shift_x=0.0,
scale_y=1.0,
shift_y=0.0,
scale_t=1.0,
shift_t=0.0,
**kwargs: Any,
) -> None:
options = {
"scale_x": scale_x,
"shift_x": shift_x,
"scale_y": scale_y,
"shift_y": shift_y,
"scale_t": scale_t,
"shift_t": shift_t,
}
options.update(kwargs)
self._call_rpc("set_model_rope_options", options)
def set_model_compute_dtype(self, dtype: Any) -> None:
self._call_rpc("set_model_compute_dtype", dtype)
def add_object_patch(self, name: str, obj: Any) -> None:
self._call_rpc("add_object_patch", name, obj)
def add_weight_wrapper(self, name: str, function: Any) -> None:
self._call_rpc("add_weight_wrapper", name, function)
def add_wrapper_with_key(self, wrapper_type: Any, key: str, fn: Any) -> None:
self._call_rpc("add_wrapper_with_key", wrapper_type, key, fn)
def add_wrapper(self, wrapper_type: str, wrapper: Callable) -> None:
self.add_wrapper_with_key(wrapper_type, None, wrapper)
def remove_wrappers_with_key(self, wrapper_type: str, key: str) -> None:
self._call_rpc("remove_wrappers_with_key", wrapper_type, key)
@property
def wrappers(self) -> Any:
return self._call_rpc("get_wrappers")
def add_callback_with_key(self, call_type: str, key: str, callback: Any) -> None:
self._call_rpc("add_callback_with_key", call_type, key, callback)
def add_callback(self, call_type: str, callback: Any) -> None:
self.add_callback_with_key(call_type, None, callback)
def remove_callbacks_with_key(self, call_type: str, key: str) -> None:
self._call_rpc("remove_callbacks_with_key", call_type, key)
@property
def callbacks(self) -> Any:
return self._call_rpc("get_callbacks")
def set_attachments(self, key: str, attachment: Any) -> None:
self._call_rpc("set_attachments", key, attachment)
def get_attachment(self, key: str) -> Any:
return self._call_rpc("get_attachment", key)
def remove_attachments(self, key: str) -> None:
self._call_rpc("remove_attachments", key)
def set_injections(self, key: str, injections: Any) -> None:
self._call_rpc("set_injections", key, injections)
def get_injections(self, key: str) -> Any:
return self._call_rpc("get_injections", key)
def remove_injections(self, key: str) -> None:
self._call_rpc("remove_injections", key)
def set_additional_models(self, key: str, models: Any) -> None:
ids = [m._instance_id for m in models]
self._call_rpc("set_additional_models", key, ids)
def remove_additional_models(self, key: str) -> None:
self._call_rpc("remove_additional_models", key)
def get_nested_additional_models(self) -> Any:
return self._call_rpc("get_nested_additional_models")
def get_additional_models(self) -> List[ModelPatcherProxy]:
ids = self._call_rpc("get_additional_models")
return [
ModelPatcherProxy(
mid, self._registry, manage_lifecycle=not IS_CHILD_PROCESS
)
for mid in ids
]
def model_patches_models(self) -> Any:
return self._call_rpc("model_patches_models")
@property
def parent(self) -> Any:
return self._call_rpc("get_parent")
class _InnerModelProxy:
def __init__(self, parent: ModelPatcherProxy):
self._parent = parent
def __getattr__(self, name: str) -> Any:
if name.startswith("_"):
raise AttributeError(name)
if name in (
"model_config",
"latent_format",
"model_type",
"current_weight_patches_uuid",
):
return self._parent._call_rpc("get_inner_model_attr", name)
if name == "load_device":
return self._parent._call_rpc("get_inner_model_attr", "load_device")
if name == "device":
return self._parent._call_rpc("get_inner_model_attr", "device")
if name == "current_patcher":
return ModelPatcherProxy(
self._parent._instance_id,
self._parent._registry,
manage_lifecycle=False,
)
if name == "model_sampling":
return self._parent._call_rpc("get_model_object", "model_sampling")
if name == "extra_conds_shapes":
return lambda *a, **k: self._parent._call_rpc(
"inner_model_extra_conds_shapes", a, k
)
if name == "extra_conds":
return lambda *a, **k: self._parent._call_rpc(
"inner_model_extra_conds", a, k
)
if name == "memory_required":
return lambda *a, **k: self._parent._call_rpc(
"inner_model_memory_required", a, k
)
if name == "apply_model":
# Delegate to parent's method to get the CPU->CUDA optimization
return self._parent.apply_model
if name == "process_latent_in":
return lambda *a, **k: self._parent._call_rpc("process_latent_in", a, k)
if name == "process_latent_out":
return lambda *a, **k: self._parent._call_rpc("process_latent_out", a, k)
if name == "scale_latent_inpaint":
return lambda *a, **k: self._parent._call_rpc("scale_latent_inpaint", a, k)
if name == "diffusion_model":
return self._parent._call_rpc("get_inner_model_attr", "diffusion_model")
raise AttributeError(f"'{name}' not supported on isolated InnerModel")

View File

@@ -1,875 +0,0 @@
# pylint: disable=import-outside-toplevel,logging-fstring-interpolation,protected-access,unused-import
# RPC server for ModelPatcher isolation (child process)
from __future__ import annotations
import gc
import logging
from typing import Any, Optional, List
try:
from comfy.model_patcher import AutoPatcherEjector
except ImportError:
class AutoPatcherEjector:
def __init__(self, model, skip_and_inject_on_exit_only=False):
self.model = model
self.skip_and_inject_on_exit_only = skip_and_inject_on_exit_only
self.prev_skip_injection = False
self.was_injected = False
def __enter__(self):
self.was_injected = False
self.prev_skip_injection = self.model.skip_injection
if self.skip_and_inject_on_exit_only:
self.model.skip_injection = True
if self.model.is_injected:
self.model.eject_model()
self.was_injected = True
def __exit__(self, *args):
if self.skip_and_inject_on_exit_only:
self.model.skip_injection = self.prev_skip_injection
self.model.inject_model()
if self.was_injected and not self.model.skip_injection:
self.model.inject_model()
self.model.skip_injection = self.prev_skip_injection
from comfy.isolation.proxies.base import (
BaseRegistry,
detach_if_grad,
)
logger = logging.getLogger(__name__)
class ModelPatcherRegistry(BaseRegistry[Any]):
_type_prefix = "model"
def __init__(self) -> None:
super().__init__()
self._pending_cleanup_ids: set[str] = set()
async def clone(self, instance_id: str) -> str:
instance = self._get_instance(instance_id)
new_model = instance.clone()
return self.register(new_model)
async def is_clone(self, instance_id: str, other: Any) -> bool:
instance = self._get_instance(instance_id)
if hasattr(other, "model"):
return instance.is_clone(other)
return False
async def get_model_object(self, instance_id: str, name: str) -> Any:
instance = self._get_instance(instance_id)
if name == "model":
return f"<ModelObject: {type(instance.model).__name__}>"
result = instance.get_model_object(name)
if name == "model_sampling":
from comfy.isolation.model_sampling_proxy import (
ModelSamplingRegistry,
ModelSamplingProxy,
)
registry = ModelSamplingRegistry()
sampling_id = registry.register(result)
return ModelSamplingProxy(sampling_id, registry)
return detach_if_grad(result)
async def get_model_options(self, instance_id: str) -> dict:
instance = self._get_instance(instance_id)
import copy
opts = copy.deepcopy(instance.model_options)
return self._sanitize_rpc_result(opts)
async def set_model_options(self, instance_id: str, options: dict) -> None:
self._get_instance(instance_id).model_options = options
async def get_patcher_attr(self, instance_id: str, name: str) -> Any:
return self._sanitize_rpc_result(
getattr(self._get_instance(instance_id), name, None)
)
async def model_state_dict(self, instance_id: str, filter_prefix=None) -> Any:
instance = self._get_instance(instance_id)
sd_keys = instance.model.state_dict().keys()
return dict.fromkeys(sd_keys, None)
def _sanitize_rpc_result(self, obj, seen=None):
if seen is None:
seen = set()
if obj is None:
return None
if isinstance(obj, (bool, int, float, str)):
if isinstance(obj, str) and len(obj) > 500000:
return f"<Truncated String len={len(obj)}>"
return obj
obj_id = id(obj)
if obj_id in seen:
return None
seen.add(obj_id)
if isinstance(obj, (list, tuple)):
return [self._sanitize_rpc_result(x, seen) for x in obj]
if isinstance(obj, set):
return [self._sanitize_rpc_result(x, seen) for x in obj]
if isinstance(obj, dict):
new_dict = {}
for k, v in obj.items():
if isinstance(k, tuple):
import json
try:
key_str = "__pyisolate_key__" + json.dumps(list(k))
new_dict[key_str] = self._sanitize_rpc_result(v, seen)
except Exception:
new_dict[str(k)] = self._sanitize_rpc_result(v, seen)
else:
new_dict[str(k)] = self._sanitize_rpc_result(v, seen)
return new_dict
if (
hasattr(obj, "__dict__")
and not hasattr(obj, "__get__")
and not hasattr(obj, "__call__")
):
return self._sanitize_rpc_result(obj.__dict__, seen)
if hasattr(obj, "items") and hasattr(obj, "get"):
return {str(k): self._sanitize_rpc_result(v, seen) for k, v in obj.items()}
return None
async def get_load_device(self, instance_id: str) -> Any:
return self._get_instance(instance_id).load_device
async def get_offload_device(self, instance_id: str) -> Any:
return self._get_instance(instance_id).offload_device
async def current_loaded_device(self, instance_id: str) -> Any:
return self._get_instance(instance_id).current_loaded_device()
async def get_size(self, instance_id: str) -> int:
return self._get_instance(instance_id).size
async def model_size(self, instance_id: str) -> Any:
return self._get_instance(instance_id).model_size()
async def loaded_size(self, instance_id: str) -> Any:
return self._get_instance(instance_id).loaded_size()
async def get_ram_usage(self, instance_id: str) -> int:
return self._get_instance(instance_id).get_ram_usage()
async def lowvram_patch_counter(self, instance_id: str) -> int:
return self._get_instance(instance_id).lowvram_patch_counter()
async def memory_required(self, instance_id: str, input_shape: Any) -> Any:
return self._get_instance(instance_id).memory_required(input_shape)
async def is_dynamic(self, instance_id: str) -> bool:
instance = self._get_instance(instance_id)
if hasattr(instance, "is_dynamic"):
return bool(instance.is_dynamic())
return False
async def get_free_memory(self, instance_id: str, device: Any) -> Any:
instance = self._get_instance(instance_id)
if hasattr(instance, "get_free_memory"):
return instance.get_free_memory(device)
import comfy.model_management
return comfy.model_management.get_free_memory(device)
async def partially_unload_ram(self, instance_id: str, ram_to_unload: int) -> Any:
instance = self._get_instance(instance_id)
if hasattr(instance, "partially_unload_ram"):
return instance.partially_unload_ram(ram_to_unload)
return None
async def model_dtype(self, instance_id: str) -> Any:
return self._get_instance(instance_id).model_dtype()
async def model_patches_to(self, instance_id: str, device: Any) -> Any:
return self._get_instance(instance_id).model_patches_to(device)
async def partially_load(
self,
instance_id: str,
device: Any,
extra_memory: Any,
force_patch_weights: bool = False,
) -> Any:
return self._get_instance(instance_id).partially_load(
device, extra_memory, force_patch_weights=force_patch_weights
)
async def partially_unload(
self,
instance_id: str,
device_to: Any,
memory_to_free: int = 0,
force_patch_weights: bool = False,
) -> int:
return self._get_instance(instance_id).partially_unload(
device_to, memory_to_free, force_patch_weights
)
async def load(
self,
instance_id: str,
device_to: Any = None,
lowvram_model_memory: int = 0,
force_patch_weights: bool = False,
full_load: bool = False,
) -> None:
self._get_instance(instance_id).load(
device_to, lowvram_model_memory, force_patch_weights, full_load
)
async def patch_model(
self,
instance_id: str,
device_to: Any = None,
lowvram_model_memory: int = 0,
load_weights: bool = True,
force_patch_weights: bool = False,
) -> None:
try:
self._get_instance(instance_id).patch_model(
device_to, lowvram_model_memory, load_weights, force_patch_weights
)
except AttributeError as e:
logger.error(
f"Isolation Error: Failed to patch model attribute: {e}. Skipping."
)
return
async def unpatch_model(
self, instance_id: str, device_to: Any = None, unpatch_weights: bool = True
) -> None:
self._get_instance(instance_id).unpatch_model(device_to, unpatch_weights)
async def detach(self, instance_id: str, unpatch_all: bool = True) -> None:
self._get_instance(instance_id).detach(unpatch_all)
async def prepare_state(self, instance_id: str, timestep: Any) -> Any:
instance = self._get_instance(instance_id)
cp = getattr(instance.model, "current_patcher", instance)
if cp is None:
cp = instance
return cp.prepare_state(timestep)
async def pre_run(self, instance_id: str) -> None:
self._get_instance(instance_id).pre_run()
async def cleanup(self, instance_id: str) -> None:
try:
instance = self._get_instance(instance_id)
except Exception:
logger.debug(
"ModelPatcher cleanup requested for missing instance %s",
instance_id,
exc_info=True,
)
return
try:
instance.cleanup()
finally:
with self._lock:
self._pending_cleanup_ids.add(instance_id)
gc.collect()
def sweep_pending_cleanup(self) -> int:
removed = 0
with self._lock:
pending_ids = list(self._pending_cleanup_ids)
self._pending_cleanup_ids.clear()
for instance_id in pending_ids:
instance = self._registry.pop(instance_id, None)
if instance is None:
continue
self._id_map.pop(id(instance), None)
removed += 1
gc.collect()
return removed
def purge_all(self) -> int:
with self._lock:
removed = len(self._registry)
self._registry.clear()
self._id_map.clear()
self._pending_cleanup_ids.clear()
gc.collect()
return removed
async def apply_hooks(self, instance_id: str, hooks: Any) -> Any:
instance = self._get_instance(instance_id)
cp = getattr(instance.model, "current_patcher", instance)
if cp is None:
cp = instance
return cp.apply_hooks(hooks=hooks)
async def clean_hooks(self, instance_id: str) -> None:
self._get_instance(instance_id).clean_hooks()
async def restore_hook_patches(self, instance_id: str) -> None:
self._get_instance(instance_id).restore_hook_patches()
async def unpatch_hooks(
self, instance_id: str, whitelist_keys_set: Optional[set] = None
) -> None:
self._get_instance(instance_id).unpatch_hooks(whitelist_keys_set)
async def register_all_hook_patches(
self,
instance_id: str,
hooks: Any,
target_dict: Any,
model_options: Any,
registered: Any,
) -> None:
from types import SimpleNamespace
import comfy.hooks
instance = self._get_instance(instance_id)
if isinstance(hooks, SimpleNamespace) or hasattr(hooks, "__dict__"):
hook_data = hooks.__dict__ if hasattr(hooks, "__dict__") else hooks
new_hooks = comfy.hooks.HookGroup()
if hasattr(hook_data, "hooks"):
new_hooks.hooks = (
hook_data["hooks"]
if isinstance(hook_data, dict)
else hook_data.hooks
)
hooks = new_hooks
instance.register_all_hook_patches(
hooks, target_dict, model_options, registered
)
async def get_hook_mode(self, instance_id: str) -> Any:
return getattr(self._get_instance(instance_id), "hook_mode", None)
async def set_hook_mode(self, instance_id: str, value: Any) -> None:
setattr(self._get_instance(instance_id), "hook_mode", value)
async def inject_model(self, instance_id: str) -> None:
instance = self._get_instance(instance_id)
try:
instance.inject_model()
except AttributeError as e:
if "inject" in str(e):
logger.error(
"Isolation Error: Injector object lost method code during serialization. Cannot inject. Skipping."
)
return
raise e
async def eject_model(self, instance_id: str) -> None:
self._get_instance(instance_id).eject_model()
async def get_is_injected(self, instance_id: str) -> bool:
return self._get_instance(instance_id).is_injected
async def set_skip_injection(self, instance_id: str, value: bool) -> None:
self._get_instance(instance_id).skip_injection = value
async def get_skip_injection(self, instance_id: str) -> bool:
return self._get_instance(instance_id).skip_injection
async def set_model_sampler_cfg_function(
self,
instance_id: str,
sampler_cfg_function: Any,
disable_cfg1_optimization: bool = False,
) -> None:
if not callable(sampler_cfg_function):
logger.error(
f"set_model_sampler_cfg_function: Expected callable, got {type(sampler_cfg_function)}. Skipping."
)
return
self._get_instance(instance_id).set_model_sampler_cfg_function(
sampler_cfg_function, disable_cfg1_optimization
)
async def set_model_sampler_post_cfg_function(
self,
instance_id: str,
post_cfg_function: Any,
disable_cfg1_optimization: bool = False,
) -> None:
self._get_instance(instance_id).set_model_sampler_post_cfg_function(
post_cfg_function, disable_cfg1_optimization
)
async def set_model_sampler_pre_cfg_function(
self,
instance_id: str,
pre_cfg_function: Any,
disable_cfg1_optimization: bool = False,
) -> None:
self._get_instance(instance_id).set_model_sampler_pre_cfg_function(
pre_cfg_function, disable_cfg1_optimization
)
async def set_model_sampler_calc_cond_batch_function(
self, instance_id: str, fn: Any
) -> None:
self._get_instance(instance_id).set_model_sampler_calc_cond_batch_function(fn)
async def set_model_unet_function_wrapper(
self, instance_id: str, unet_wrapper_function: Any
) -> None:
self._get_instance(instance_id).set_model_unet_function_wrapper(
unet_wrapper_function
)
async def set_model_denoise_mask_function(
self, instance_id: str, denoise_mask_function: Any
) -> None:
self._get_instance(instance_id).set_model_denoise_mask_function(
denoise_mask_function
)
async def set_model_patch(self, instance_id: str, patch: Any, name: str) -> None:
self._get_instance(instance_id).set_model_patch(patch, name)
async def set_model_patch_replace(
self,
instance_id: str,
patch: Any,
name: str,
block_name: str,
number: int,
transformer_index: Optional[int] = None,
) -> None:
self._get_instance(instance_id).set_model_patch_replace(
patch, name, block_name, number, transformer_index
)
async def set_model_input_block_patch(self, instance_id: str, patch: Any) -> None:
self._get_instance(instance_id).set_model_input_block_patch(patch)
async def set_model_input_block_patch_after_skip(
self, instance_id: str, patch: Any
) -> None:
self._get_instance(instance_id).set_model_input_block_patch_after_skip(patch)
async def set_model_output_block_patch(self, instance_id: str, patch: Any) -> None:
self._get_instance(instance_id).set_model_output_block_patch(patch)
async def set_model_emb_patch(self, instance_id: str, patch: Any) -> None:
self._get_instance(instance_id).set_model_emb_patch(patch)
async def set_model_forward_timestep_embed_patch(
self, instance_id: str, patch: Any
) -> None:
self._get_instance(instance_id).set_model_forward_timestep_embed_patch(patch)
async def set_model_double_block_patch(self, instance_id: str, patch: Any) -> None:
self._get_instance(instance_id).set_model_double_block_patch(patch)
async def set_model_post_input_patch(self, instance_id: str, patch: Any) -> None:
self._get_instance(instance_id).set_model_post_input_patch(patch)
async def set_model_rope_options(self, instance_id: str, options: dict) -> None:
self._get_instance(instance_id).set_model_rope_options(**options)
async def set_model_compute_dtype(self, instance_id: str, dtype: Any) -> None:
self._get_instance(instance_id).set_model_compute_dtype(dtype)
async def clone_has_same_weights_by_id(
self, instance_id: str, other_id: str
) -> bool:
instance = self._get_instance(instance_id)
other = self._get_instance(other_id)
if not other:
return False
return instance.clone_has_same_weights(other)
async def load_list_internal(self, instance_id: str, *args, **kwargs) -> Any:
return self._get_instance(instance_id)._load_list(*args, **kwargs)
async def is_clone_by_id(self, instance_id: str, other_id: str) -> bool:
instance = self._get_instance(instance_id)
other = self._get_instance(other_id)
if hasattr(instance, "is_clone"):
return instance.is_clone(other)
return False
async def add_object_patch(self, instance_id: str, name: str, obj: Any) -> None:
self._get_instance(instance_id).add_object_patch(name, obj)
async def add_weight_wrapper(
self, instance_id: str, name: str, function: Any
) -> None:
self._get_instance(instance_id).add_weight_wrapper(name, function)
async def add_wrapper_with_key(
self, instance_id: str, wrapper_type: Any, key: str, fn: Any
) -> None:
self._get_instance(instance_id).add_wrapper_with_key(wrapper_type, key, fn)
async def remove_wrappers_with_key(
self, instance_id: str, wrapper_type: str, key: str
) -> None:
self._get_instance(instance_id).remove_wrappers_with_key(wrapper_type, key)
async def get_wrappers(
self, instance_id: str, wrapper_type: str = None, key: str = None
) -> Any:
if wrapper_type is None and key is None:
return self._sanitize_rpc_result(
getattr(self._get_instance(instance_id), "wrappers", {})
)
return self._sanitize_rpc_result(
self._get_instance(instance_id).get_wrappers(wrapper_type, key)
)
async def get_all_wrappers(self, instance_id: str, wrapper_type: str = None) -> Any:
return self._sanitize_rpc_result(
getattr(self._get_instance(instance_id), "get_all_wrappers", lambda x: [])(
wrapper_type
)
)
async def add_callback_with_key(
self, instance_id: str, call_type: str, key: str, callback: Any
) -> None:
self._get_instance(instance_id).add_callback_with_key(call_type, key, callback)
async def remove_callbacks_with_key(
self, instance_id: str, call_type: str, key: str
) -> None:
self._get_instance(instance_id).remove_callbacks_with_key(call_type, key)
async def get_callbacks(
self, instance_id: str, call_type: str = None, key: str = None
) -> Any:
if call_type is None and key is None:
return self._sanitize_rpc_result(
getattr(self._get_instance(instance_id), "callbacks", {})
)
return self._sanitize_rpc_result(
self._get_instance(instance_id).get_callbacks(call_type, key)
)
async def get_all_callbacks(self, instance_id: str, call_type: str = None) -> Any:
return self._sanitize_rpc_result(
getattr(self._get_instance(instance_id), "get_all_callbacks", lambda x: [])(
call_type
)
)
async def set_attachments(
self, instance_id: str, key: str, attachment: Any
) -> None:
self._get_instance(instance_id).set_attachments(key, attachment)
async def get_attachment(self, instance_id: str, key: str) -> Any:
return self._sanitize_rpc_result(
self._get_instance(instance_id).get_attachment(key)
)
async def remove_attachments(self, instance_id: str, key: str) -> None:
self._get_instance(instance_id).remove_attachments(key)
async def set_injections(self, instance_id: str, key: str, injections: Any) -> None:
self._get_instance(instance_id).set_injections(key, injections)
async def get_injections(self, instance_id: str, key: str) -> Any:
return self._sanitize_rpc_result(
self._get_instance(instance_id).get_injections(key)
)
async def remove_injections(self, instance_id: str, key: str) -> None:
self._get_instance(instance_id).remove_injections(key)
async def set_additional_models(
self, instance_id: str, key: str, models: Any
) -> None:
self._get_instance(instance_id).set_additional_models(key, models)
async def remove_additional_models(self, instance_id: str, key: str) -> None:
self._get_instance(instance_id).remove_additional_models(key)
async def get_nested_additional_models(self, instance_id: str) -> Any:
return self._sanitize_rpc_result(
self._get_instance(instance_id).get_nested_additional_models()
)
async def get_additional_models(self, instance_id: str) -> List[str]:
models = self._get_instance(instance_id).get_additional_models()
return [self.register(m) for m in models]
async def get_additional_models_with_key(self, instance_id: str, key: str) -> Any:
return self._sanitize_rpc_result(
self._get_instance(instance_id).get_additional_models_with_key(key)
)
async def model_patches_models(self, instance_id: str) -> Any:
return self._sanitize_rpc_result(
self._get_instance(instance_id).model_patches_models()
)
async def get_patches(self, instance_id: str) -> Any:
return self._sanitize_rpc_result(self._get_instance(instance_id).patches.copy())
async def get_object_patches(self, instance_id: str) -> Any:
return self._sanitize_rpc_result(
self._get_instance(instance_id).object_patches.copy()
)
async def add_patches(
self,
instance_id: str,
patches: Any,
strength_patch: float = 1.0,
strength_model: float = 1.0,
) -> Any:
return self._get_instance(instance_id).add_patches(
patches, strength_patch, strength_model
)
async def get_key_patches(
self, instance_id: str, filter_prefix: Optional[str] = None
) -> Any:
res = self._get_instance(instance_id).get_key_patches()
if filter_prefix:
res = {k: v for k, v in res.items() if k.startswith(filter_prefix)}
safe_res = {}
for k, v in res.items():
safe_res[k] = [
f"<Tensor shape={t.shape} dtype={t.dtype}>"
if hasattr(t, "shape")
else str(t)
for t in v
]
return safe_res
async def add_hook_patches(
self,
instance_id: str,
hook: Any,
patches: Any,
strength_patch: float = 1.0,
strength_model: float = 1.0,
) -> None:
if hasattr(hook, "hook_ref") and isinstance(hook.hook_ref, dict):
try:
hook.hook_ref = tuple(sorted(hook.hook_ref.items()))
except Exception:
hook.hook_ref = None
self._get_instance(instance_id).add_hook_patches(
hook, patches, strength_patch, strength_model
)
async def get_combined_hook_patches(self, instance_id: str, hooks: Any) -> Any:
if hooks is not None and hasattr(hooks, "hooks"):
for hook in getattr(hooks, "hooks", []):
hook_ref = getattr(hook, "hook_ref", None)
if isinstance(hook_ref, dict):
try:
hook.hook_ref = tuple(sorted(hook_ref.items()))
except Exception:
hook.hook_ref = None
res = self._get_instance(instance_id).get_combined_hook_patches(hooks)
return self._sanitize_rpc_result(res)
async def clear_cached_hook_weights(self, instance_id: str) -> None:
self._get_instance(instance_id).clear_cached_hook_weights()
async def prepare_hook_patches_current_keyframe(
self, instance_id: str, t: Any, hook_group: Any, model_options: Any
) -> None:
self._get_instance(instance_id).prepare_hook_patches_current_keyframe(
t, hook_group, model_options
)
async def get_parent(self, instance_id: str) -> Any:
return getattr(self._get_instance(instance_id), "parent", None)
async def patch_weight_to_device(
self,
instance_id: str,
key: str,
device_to: Any = None,
inplace_update: bool = False,
) -> None:
self._get_instance(instance_id).patch_weight_to_device(
key, device_to, inplace_update
)
async def pin_weight_to_device(self, instance_id: str, key: str) -> None:
instance = self._get_instance(instance_id)
if hasattr(instance, "pinned") and isinstance(instance.pinned, list):
instance.pinned = set(instance.pinned)
instance.pin_weight_to_device(key)
async def unpin_weight(self, instance_id: str, key: str) -> None:
instance = self._get_instance(instance_id)
if hasattr(instance, "pinned") and isinstance(instance.pinned, list):
instance.pinned = set(instance.pinned)
instance.unpin_weight(key)
async def unpin_all_weights(self, instance_id: str) -> None:
instance = self._get_instance(instance_id)
if hasattr(instance, "pinned") and isinstance(instance.pinned, list):
instance.pinned = set(instance.pinned)
instance.unpin_all_weights()
async def calculate_weight(
self,
instance_id: str,
patches: Any,
weight: Any,
key: str,
intermediate_dtype: Any = float,
) -> Any:
return detach_if_grad(
self._get_instance(instance_id).calculate_weight(
patches, weight, key, intermediate_dtype
)
)
async def get_inner_model_attr(self, instance_id: str, name: str) -> Any:
try:
return self._sanitize_rpc_result(
getattr(self._get_instance(instance_id).model, name)
)
except AttributeError:
return None
async def inner_model_memory_required(
self, instance_id: str, args: tuple, kwargs: dict
) -> Any:
return self._get_instance(instance_id).model.memory_required(*args, **kwargs)
async def inner_model_extra_conds_shapes(
self, instance_id: str, args: tuple, kwargs: dict
) -> Any:
return self._get_instance(instance_id).model.extra_conds_shapes(*args, **kwargs)
async def inner_model_extra_conds(
self, instance_id: str, args: tuple, kwargs: dict
) -> Any:
return self._get_instance(instance_id).model.extra_conds(*args, **kwargs)
async def inner_model_state_dict(
self, instance_id: str, args: tuple, kwargs: dict
) -> Any:
sd = self._get_instance(instance_id).model.state_dict(*args, **kwargs)
return {
k: {"numel": v.numel(), "element_size": v.element_size()}
for k, v in sd.items()
}
async def inner_model_apply_model(
self, instance_id: str, args: tuple, kwargs: dict
) -> Any:
instance = self._get_instance(instance_id)
target = getattr(instance, "load_device", None)
if target is None and args and hasattr(args[0], "device"):
target = args[0].device
elif target is None:
for v in kwargs.values():
if hasattr(v, "device"):
target = v.device
break
def _move(obj):
if target is None:
return obj
if isinstance(obj, (tuple, list)):
return type(obj)(_move(o) for o in obj)
if hasattr(obj, "to"):
return obj.to(target)
return obj
moved_args = tuple(_move(a) for a in args)
moved_kwargs = {k: _move(v) for k, v in kwargs.items()}
result = instance.model.apply_model(*moved_args, **moved_kwargs)
return detach_if_grad(_move(result))
async def process_latent_in(
self, instance_id: str, args: tuple, kwargs: dict
) -> Any:
return detach_if_grad(
self._get_instance(instance_id).model.process_latent_in(*args, **kwargs)
)
async def process_latent_out(
self, instance_id: str, args: tuple, kwargs: dict
) -> Any:
instance = self._get_instance(instance_id)
result = instance.model.process_latent_out(*args, **kwargs)
try:
target = None
if args and hasattr(args[0], "device"):
target = args[0].device
elif kwargs:
for v in kwargs.values():
if hasattr(v, "device"):
target = v.device
break
if target is not None and hasattr(result, "to"):
return detach_if_grad(result.to(target))
except Exception:
logger.debug(
"process_latent_out: failed to move result to target device",
exc_info=True,
)
return detach_if_grad(result)
async def scale_latent_inpaint(
self, instance_id: str, args: tuple, kwargs: dict
) -> Any:
instance = self._get_instance(instance_id)
result = instance.model.scale_latent_inpaint(*args, **kwargs)
try:
target = None
if args and hasattr(args[0], "device"):
target = args[0].device
elif kwargs:
for v in kwargs.values():
if hasattr(v, "device"):
target = v.device
break
if target is not None and hasattr(result, "to"):
return detach_if_grad(result.to(target))
except Exception:
logger.debug(
"scale_latent_inpaint: failed to move result to target device",
exc_info=True,
)
return detach_if_grad(result)
async def load_lora(
self,
instance_id: str,
lora_path: str,
strength_model: float,
clip_id: Optional[str] = None,
strength_clip: float = 1.0,
) -> dict:
import comfy.utils
import comfy.sd
import folder_paths
from comfy.isolation.clip_proxy import CLIPRegistry
model = self._get_instance(instance_id)
clip = None
if clip_id:
clip = CLIPRegistry()._get_instance(clip_id)
lora_full_path = folder_paths.get_full_path("loras", lora_path)
if lora_full_path is None:
raise ValueError(f"LoRA file not found: {lora_path}")
lora = comfy.utils.load_torch_file(lora_full_path)
new_model, new_clip = comfy.sd.load_lora_for_models(
model, clip, lora, strength_model, strength_clip
)
new_model_id = self.register(new_model) if new_model else None
new_clip_id = (
CLIPRegistry().register(new_clip) if (new_clip and clip_id) else None
)
return {"model_id": new_model_id, "clip_id": new_clip_id}

View File

@@ -1,154 +0,0 @@
# pylint: disable=import-outside-toplevel,logging-fstring-interpolation,protected-access
# Isolation utilities and serializers for ModelPatcherProxy
from __future__ import annotations
import logging
import os
from typing import Any
logger = logging.getLogger(__name__)
def maybe_wrap_model_for_isolation(model_patcher: Any) -> Any:
from comfy.isolation.model_patcher_proxy_registry import ModelPatcherRegistry
from comfy.isolation.model_patcher_proxy import ModelPatcherProxy
isolation_active = os.environ.get("PYISOLATE_ISOLATION_ACTIVE") == "1"
is_child = os.environ.get("PYISOLATE_CHILD") == "1"
if not isolation_active:
return model_patcher
if is_child:
return model_patcher
if isinstance(model_patcher, ModelPatcherProxy):
return model_patcher
registry = ModelPatcherRegistry()
model_id = registry.register(model_patcher)
logger.debug(f"Isolated ModelPatcher: {model_id}")
return ModelPatcherProxy(model_id, registry, manage_lifecycle=True)
def register_hooks_serializers(registry=None):
from pyisolate._internal.serialization_registry import SerializerRegistry
import comfy.hooks
if registry is None:
registry = SerializerRegistry.get_instance()
def serialize_enum(obj):
return {"__enum__": f"{type(obj).__name__}.{obj.name}"}
def deserialize_enum(data):
cls_name, val_name = data["__enum__"].split(".")
cls = getattr(comfy.hooks, cls_name)
return cls[val_name]
registry.register("EnumHookType", serialize_enum, deserialize_enum)
registry.register("EnumHookScope", serialize_enum, deserialize_enum)
registry.register("EnumHookMode", serialize_enum, deserialize_enum)
registry.register("EnumWeightTarget", serialize_enum, deserialize_enum)
def serialize_hook_group(obj):
return {"__type__": "HookGroup", "hooks": obj.hooks}
def deserialize_hook_group(data):
hg = comfy.hooks.HookGroup()
for h in data["hooks"]:
hg.add(h)
return hg
registry.register("HookGroup", serialize_hook_group, deserialize_hook_group)
def serialize_dict_state(obj):
d = obj.__dict__.copy()
d["__type__"] = type(obj).__name__
if "custom_should_register" in d:
del d["custom_should_register"]
return d
def deserialize_dict_state_generic(cls):
def _deserialize(data):
h = cls()
h.__dict__.update(data)
return h
return _deserialize
def deserialize_hook_keyframe(data):
h = comfy.hooks.HookKeyframe(strength=data.get("strength", 1.0))
h.__dict__.update(data)
return h
registry.register("HookKeyframe", serialize_dict_state, deserialize_hook_keyframe)
def deserialize_hook_keyframe_group(data):
h = comfy.hooks.HookKeyframeGroup()
h.__dict__.update(data)
return h
registry.register(
"HookKeyframeGroup", serialize_dict_state, deserialize_hook_keyframe_group
)
def deserialize_hook(data):
h = comfy.hooks.Hook()
h.__dict__.update(data)
return h
registry.register("Hook", serialize_dict_state, deserialize_hook)
def deserialize_weight_hook(data):
h = comfy.hooks.WeightHook()
h.__dict__.update(data)
return h
registry.register("WeightHook", serialize_dict_state, deserialize_weight_hook)
def serialize_set(obj):
return {"__set__": list(obj)}
def deserialize_set(data):
return set(data["__set__"])
registry.register("set", serialize_set, deserialize_set)
try:
from comfy.weight_adapter.lora import LoRAAdapter
def serialize_lora(obj):
return {"weights": {}, "loaded_keys": list(obj.loaded_keys)}
def deserialize_lora(data):
return LoRAAdapter(set(data["loaded_keys"]), data["weights"])
registry.register("LoRAAdapter", serialize_lora, deserialize_lora)
except Exception:
pass
try:
from comfy.hooks import _HookRef
import uuid
def serialize_hook_ref(obj):
return {
"__hook_ref__": True,
"id": getattr(obj, "_pyisolate_id", str(uuid.uuid4())),
}
def deserialize_hook_ref(data):
h = _HookRef()
h._pyisolate_id = data.get("id", str(uuid.uuid4()))
return h
registry.register("_HookRef", serialize_hook_ref, deserialize_hook_ref)
except ImportError:
pass
except Exception as e:
logger.warning(f"Failed to register _HookRef: {e}")
try:
register_hooks_serializers()
except Exception as e:
logger.error(f"Failed to initialize hook serializers: {e}")

View File

@@ -1,253 +0,0 @@
# pylint: disable=import-outside-toplevel
from __future__ import annotations
import asyncio
import logging
from typing import Any
from comfy.isolation.proxies.base import (
BaseProxy,
BaseRegistry,
detach_if_grad,
get_thread_loop,
run_coro_in_new_loop,
)
logger = logging.getLogger(__name__)
def _prefer_device(*tensors: Any) -> Any:
try:
import torch
except Exception:
return None
for t in tensors:
if isinstance(t, torch.Tensor) and t.is_cuda:
return t.device
for t in tensors:
if isinstance(t, torch.Tensor):
return t.device
return None
def _to_device(obj: Any, device: Any) -> Any:
try:
import torch
except Exception:
return obj
if device is None:
return obj
if isinstance(obj, torch.Tensor):
if obj.device != device:
return obj.to(device)
return obj
if isinstance(obj, (list, tuple)):
converted = [_to_device(x, device) for x in obj]
return type(obj)(converted) if isinstance(obj, tuple) else converted
if isinstance(obj, dict):
return {k: _to_device(v, device) for k, v in obj.items()}
return obj
class ModelSamplingRegistry(BaseRegistry[Any]):
_type_prefix = "modelsampling"
async def calculate_input(self, instance_id: str, sigma: Any, noise: Any) -> Any:
sampling = self._get_instance(instance_id)
return detach_if_grad(sampling.calculate_input(sigma, noise))
async def calculate_denoised(
self, instance_id: str, sigma: Any, model_output: Any, model_input: Any
) -> Any:
sampling = self._get_instance(instance_id)
return detach_if_grad(
sampling.calculate_denoised(sigma, model_output, model_input)
)
async def noise_scaling(
self,
instance_id: str,
sigma: Any,
noise: Any,
latent_image: Any,
max_denoise: bool = False,
) -> Any:
sampling = self._get_instance(instance_id)
return detach_if_grad(
sampling.noise_scaling(sigma, noise, latent_image, max_denoise=max_denoise)
)
async def inverse_noise_scaling(
self, instance_id: str, sigma: Any, latent: Any
) -> Any:
sampling = self._get_instance(instance_id)
return detach_if_grad(sampling.inverse_noise_scaling(sigma, latent))
async def timestep(self, instance_id: str, sigma: Any) -> Any:
sampling = self._get_instance(instance_id)
return sampling.timestep(sigma)
async def sigma(self, instance_id: str, timestep: Any) -> Any:
sampling = self._get_instance(instance_id)
return sampling.sigma(timestep)
async def percent_to_sigma(self, instance_id: str, percent: float) -> Any:
sampling = self._get_instance(instance_id)
return sampling.percent_to_sigma(percent)
async def get_sigma_min(self, instance_id: str) -> Any:
sampling = self._get_instance(instance_id)
return detach_if_grad(sampling.sigma_min)
async def get_sigma_max(self, instance_id: str) -> Any:
sampling = self._get_instance(instance_id)
return detach_if_grad(sampling.sigma_max)
async def get_sigma_data(self, instance_id: str) -> Any:
sampling = self._get_instance(instance_id)
return detach_if_grad(sampling.sigma_data)
async def get_sigmas(self, instance_id: str) -> Any:
sampling = self._get_instance(instance_id)
return detach_if_grad(sampling.sigmas)
async def set_sigmas(self, instance_id: str, sigmas: Any) -> None:
sampling = self._get_instance(instance_id)
sampling.set_sigmas(sigmas)
class ModelSamplingProxy(BaseProxy[ModelSamplingRegistry]):
_registry_class = ModelSamplingRegistry
__module__ = "comfy.isolation.model_sampling_proxy"
def _get_rpc(self) -> Any:
if self._rpc_caller is None:
from pyisolate._internal.rpc_protocol import get_child_rpc_instance
rpc = get_child_rpc_instance()
if rpc is not None:
self._rpc_caller = rpc.create_caller(
ModelSamplingRegistry, ModelSamplingRegistry.get_remote_id()
)
else:
registry = ModelSamplingRegistry()
class _LocalCaller:
def calculate_input(
self, instance_id: str, sigma: Any, noise: Any
) -> Any:
return registry.calculate_input(instance_id, sigma, noise)
def calculate_denoised(
self,
instance_id: str,
sigma: Any,
model_output: Any,
model_input: Any,
) -> Any:
return registry.calculate_denoised(
instance_id, sigma, model_output, model_input
)
def noise_scaling(
self,
instance_id: str,
sigma: Any,
noise: Any,
latent_image: Any,
max_denoise: bool = False,
) -> Any:
return registry.noise_scaling(
instance_id, sigma, noise, latent_image, max_denoise
)
def inverse_noise_scaling(
self, instance_id: str, sigma: Any, latent: Any
) -> Any:
return registry.inverse_noise_scaling(
instance_id, sigma, latent
)
def timestep(self, instance_id: str, sigma: Any) -> Any:
return registry.timestep(instance_id, sigma)
def sigma(self, instance_id: str, timestep: Any) -> Any:
return registry.sigma(instance_id, timestep)
def percent_to_sigma(self, instance_id: str, percent: float) -> Any:
return registry.percent_to_sigma(instance_id, percent)
def get_sigma_min(self, instance_id: str) -> Any:
return registry.get_sigma_min(instance_id)
def get_sigma_max(self, instance_id: str) -> Any:
return registry.get_sigma_max(instance_id)
def get_sigma_data(self, instance_id: str) -> Any:
return registry.get_sigma_data(instance_id)
def get_sigmas(self, instance_id: str) -> Any:
return registry.get_sigmas(instance_id)
def set_sigmas(self, instance_id: str, sigmas: Any) -> None:
return registry.set_sigmas(instance_id, sigmas)
self._rpc_caller = _LocalCaller()
return self._rpc_caller
def _call(self, method_name: str, *args: Any) -> Any:
rpc = self._get_rpc()
method = getattr(rpc, method_name)
result = method(self._instance_id, *args)
if asyncio.iscoroutine(result):
try:
asyncio.get_running_loop()
return run_coro_in_new_loop(result)
except RuntimeError:
loop = get_thread_loop()
return loop.run_until_complete(result)
return result
@property
def sigma_min(self) -> Any:
return self._call("get_sigma_min")
@property
def sigma_max(self) -> Any:
return self._call("get_sigma_max")
@property
def sigma_data(self) -> Any:
return self._call("get_sigma_data")
@property
def sigmas(self) -> Any:
return self._call("get_sigmas")
def calculate_input(self, sigma: Any, noise: Any) -> Any:
return self._call("calculate_input", sigma, noise)
def calculate_denoised(
self, sigma: Any, model_output: Any, model_input: Any
) -> Any:
return self._call("calculate_denoised", sigma, model_output, model_input)
def noise_scaling(
self, sigma: Any, noise: Any, latent_image: Any, max_denoise: bool = False
) -> Any:
return self._call("noise_scaling", sigma, noise, latent_image, max_denoise)
def inverse_noise_scaling(self, sigma: Any, latent: Any) -> Any:
return self._call("inverse_noise_scaling", sigma, latent)
def timestep(self, sigma: Any) -> Any:
return self._call("timestep", sigma)
def sigma(self, timestep: Any) -> Any:
return self._call("sigma", timestep)
def percent_to_sigma(self, percent: float) -> Any:
return self._call("percent_to_sigma", percent)
def set_sigmas(self, sigmas: Any) -> None:
return self._call("set_sigmas", sigmas)

View File

@@ -1,17 +0,0 @@
from .base import (
IS_CHILD_PROCESS,
BaseProxy,
BaseRegistry,
detach_if_grad,
get_thread_loop,
run_coro_in_new_loop,
)
__all__ = [
"IS_CHILD_PROCESS",
"BaseRegistry",
"BaseProxy",
"get_thread_loop",
"run_coro_in_new_loop",
"detach_if_grad",
]

View File

@@ -1,213 +0,0 @@
# pylint: disable=global-statement,import-outside-toplevel,protected-access
from __future__ import annotations
import asyncio
import logging
import os
import threading
import weakref
from typing import Any, Callable, Dict, Generic, Optional, TypeVar
try:
from pyisolate import ProxiedSingleton
except ImportError:
class ProxiedSingleton: # type: ignore[no-redef]
pass
logger = logging.getLogger(__name__)
IS_CHILD_PROCESS = os.environ.get("PYISOLATE_CHILD") == "1"
_thread_local = threading.local()
T = TypeVar("T")
def get_thread_loop() -> asyncio.AbstractEventLoop:
loop = getattr(_thread_local, "loop", None)
if loop is None or loop.is_closed():
loop = asyncio.new_event_loop()
_thread_local.loop = loop
return loop
def run_coro_in_new_loop(coro: Any) -> Any:
result_box: Dict[str, Any] = {}
exc_box: Dict[str, BaseException] = {}
def runner() -> None:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
result_box["value"] = loop.run_until_complete(coro)
except Exception as exc: # noqa: BLE001
exc_box["exc"] = exc
finally:
loop.close()
t = threading.Thread(target=runner, daemon=True)
t.start()
t.join()
if "exc" in exc_box:
raise exc_box["exc"]
return result_box.get("value")
def detach_if_grad(obj: Any) -> Any:
try:
import torch
except Exception:
return obj
if isinstance(obj, torch.Tensor):
return obj.detach() if obj.requires_grad else obj
if isinstance(obj, (list, tuple)):
return type(obj)(detach_if_grad(x) for x in obj)
if isinstance(obj, dict):
return {k: detach_if_grad(v) for k, v in obj.items()}
return obj
class BaseRegistry(ProxiedSingleton, Generic[T]):
_type_prefix: str = "base"
def __init__(self) -> None:
if hasattr(ProxiedSingleton, "__init__") and ProxiedSingleton is not object:
super().__init__()
self._registry: Dict[str, T] = {}
self._id_map: Dict[int, str] = {}
self._counter = 0
self._lock = threading.Lock()
def register(self, instance: T) -> str:
with self._lock:
obj_id = id(instance)
if obj_id in self._id_map:
return self._id_map[obj_id]
instance_id = f"{self._type_prefix}_{self._counter}"
self._counter += 1
self._registry[instance_id] = instance
self._id_map[obj_id] = instance_id
return instance_id
def unregister_sync(self, instance_id: str) -> None:
with self._lock:
instance = self._registry.pop(instance_id, None)
if instance:
self._id_map.pop(id(instance), None)
def _get_instance(self, instance_id: str) -> T:
if IS_CHILD_PROCESS:
raise RuntimeError(
f"[{self.__class__.__name__}] _get_instance called in child"
)
with self._lock:
instance = self._registry.get(instance_id)
if instance is None:
raise ValueError(f"{instance_id} not found")
return instance
_GLOBAL_LOOP: Optional[asyncio.AbstractEventLoop] = None
def set_global_loop(loop: asyncio.AbstractEventLoop) -> None:
global _GLOBAL_LOOP
_GLOBAL_LOOP = loop
class BaseProxy(Generic[T]):
_registry_class: type = BaseRegistry # type: ignore[type-arg]
__module__: str = "comfy.isolation.proxies.base"
def __init__(
self,
instance_id: str,
registry: Optional[Any] = None,
manage_lifecycle: bool = False,
) -> None:
self._instance_id = instance_id
self._rpc_caller: Optional[Any] = None
self._registry = registry if registry is not None else self._registry_class()
self._manage_lifecycle = manage_lifecycle
self._cleaned_up = False
if manage_lifecycle and not IS_CHILD_PROCESS:
self._finalizer = weakref.finalize(
self, self._registry.unregister_sync, instance_id
)
def _get_rpc(self) -> Any:
if self._rpc_caller is None:
from pyisolate._internal.rpc_protocol import get_child_rpc_instance
rpc = get_child_rpc_instance()
if rpc is None:
raise RuntimeError(f"[{self.__class__.__name__}] No RPC in child")
self._rpc_caller = rpc.create_caller(
self._registry_class, self._registry_class.get_remote_id()
)
return self._rpc_caller
def _call_rpc(self, method_name: str, *args: Any, **kwargs: Any) -> Any:
rpc = self._get_rpc()
method = getattr(rpc, method_name)
coro = method(self._instance_id, *args, **kwargs)
# If we have a global loop (Main Thread Loop), use it for dispatch from worker threads
if _GLOBAL_LOOP is not None and _GLOBAL_LOOP.is_running():
try:
# If we are already in the global loop, we can't block on it?
# Actually, this method is synchronous (__getattr__ -> lambda).
# If called from async context in main loop, we need to handle that.
curr_loop = asyncio.get_running_loop()
if curr_loop is _GLOBAL_LOOP:
# We are in the main loop. We cannot await/block here if we are just a sync function.
# But proxies are often called from sync code.
# If called from sync code in main loop, creating a new loop is bad.
# But we can't await `coro`.
# This implies proxies MUST be awaited if called from async context?
# Existing code used `run_coro_in_new_loop` which is weird.
# Let's trust that if we are in a thread (RuntimeError on get_running_loop),
# we use run_coroutine_threadsafe.
pass
except RuntimeError:
# No running loop - we are in a worker thread.
future = asyncio.run_coroutine_threadsafe(coro, _GLOBAL_LOOP)
return future.result()
try:
asyncio.get_running_loop()
return run_coro_in_new_loop(coro)
except RuntimeError:
loop = get_thread_loop()
return loop.run_until_complete(coro)
def __getstate__(self) -> Dict[str, Any]:
return {"_instance_id": self._instance_id}
def __setstate__(self, state: Dict[str, Any]) -> None:
self._instance_id = state["_instance_id"]
self._rpc_caller = None
self._registry = self._registry_class()
self._manage_lifecycle = False
self._cleaned_up = False
def cleanup(self) -> None:
if self._cleaned_up or IS_CHILD_PROCESS:
return
self._cleaned_up = True
finalizer = getattr(self, "_finalizer", None)
if finalizer is not None:
finalizer.detach()
self._registry.unregister_sync(self._instance_id)
def __repr__(self) -> str:
return f"<{self.__class__.__name__} {self._instance_id}>"
def create_rpc_method(method_name: str) -> Callable[..., Any]:
def method(self: BaseProxy[Any], *args: Any, **kwargs: Any) -> Any:
return self._call_rpc(method_name, *args, **kwargs)
method.__name__ = method_name
return method

View File

@@ -1,29 +0,0 @@
from __future__ import annotations
from typing import Dict
import folder_paths
from pyisolate import ProxiedSingleton
class FolderPathsProxy(ProxiedSingleton):
"""
Dynamic proxy for folder_paths.
Uses __getattr__ for most lookups, with explicit handling for
mutable collections to ensure efficient by-value transfer.
"""
def __getattr__(self, name):
return getattr(folder_paths, name)
# Return dict snapshots (avoid RPC chatter)
@property
def folder_names_and_paths(self) -> Dict:
return dict(folder_paths.folder_names_and_paths)
@property
def extension_mimetypes_cache(self) -> Dict:
return dict(folder_paths.extension_mimetypes_cache)
@property
def filename_list_cache(self) -> Dict:
return dict(folder_paths.filename_list_cache)

View File

@@ -1,98 +0,0 @@
from __future__ import annotations
from typing import Any, Dict, Optional
class AnyTypeProxy(str):
"""Replacement for custom AnyType objects used by some nodes."""
def __new__(cls, value: str = "*"):
return super().__new__(cls, value)
def __ne__(self, other): # type: ignore[override]
return False
class FlexibleOptionalInputProxy(dict):
"""Replacement for FlexibleOptionalInputType to allow dynamic inputs."""
def __init__(self, flex_type, data: Optional[Dict[str, object]] = None):
super().__init__()
self.type = flex_type
if data:
self.update(data)
def __getitem__(self, key): # type: ignore[override]
return (self.type,)
def __contains__(self, key): # type: ignore[override]
return True
class ByPassTypeTupleProxy(tuple):
"""Replacement for ByPassTypeTuple to mirror wildcard fallback behavior."""
def __new__(cls, values):
return super().__new__(cls, values)
def __getitem__(self, index): # type: ignore[override]
if index >= len(self):
return AnyTypeProxy("*")
return super().__getitem__(index)
def _restore_special_value(value: Any) -> Any:
if isinstance(value, dict):
if value.get("__pyisolate_any_type__"):
return AnyTypeProxy(value.get("value", "*"))
if value.get("__pyisolate_flexible_optional__"):
flex_type = _restore_special_value(value.get("type"))
data_raw = value.get("data")
data = (
{k: _restore_special_value(v) for k, v in data_raw.items()}
if isinstance(data_raw, dict)
else {}
)
return FlexibleOptionalInputProxy(flex_type, data)
if value.get("__pyisolate_tuple__") is not None:
return tuple(
_restore_special_value(v) for v in value["__pyisolate_tuple__"]
)
if value.get("__pyisolate_bypass_tuple__") is not None:
return ByPassTypeTupleProxy(
tuple(
_restore_special_value(v)
for v in value["__pyisolate_bypass_tuple__"]
)
)
return {k: _restore_special_value(v) for k, v in value.items()}
if isinstance(value, list):
return [_restore_special_value(v) for v in value]
return value
def restore_input_types(raw: Dict[str, object]) -> Dict[str, object]:
"""Restore serialized INPUT_TYPES payload back into ComfyUI-compatible objects."""
if not isinstance(raw, dict):
return raw # type: ignore[return-value]
restored: Dict[str, object] = {}
for section, entries in raw.items():
if isinstance(entries, dict) and entries.get("__pyisolate_flexible_optional__"):
restored[section] = _restore_special_value(entries)
elif isinstance(entries, dict):
restored[section] = {
k: _restore_special_value(v) for k, v in entries.items()
}
else:
restored[section] = _restore_special_value(entries)
return restored
__all__ = [
"AnyTypeProxy",
"FlexibleOptionalInputProxy",
"ByPassTypeTupleProxy",
"restore_input_types",
]

View File

@@ -1,27 +0,0 @@
import comfy.model_management as mm
from pyisolate import ProxiedSingleton
class ModelManagementProxy(ProxiedSingleton):
"""
Dynamic proxy for comfy.model_management.
Uses __getattr__ to forward all calls to the underlying module,
reducing maintenance burden.
"""
# Explicitly expose Enums/Classes as properties
@property
def VRAMState(self):
return mm.VRAMState
@property
def CPUState(self):
return mm.CPUState
@property
def OOM_EXCEPTION(self):
return mm.OOM_EXCEPTION
def __getattr__(self, name):
"""Forward all other attribute access to the module."""
return getattr(mm, name)

View File

@@ -1,35 +0,0 @@
from __future__ import annotations
import logging
from typing import Any, Optional
try:
from pyisolate import ProxiedSingleton
except ImportError:
class ProxiedSingleton:
pass
from comfy_execution.progress import get_progress_state
logger = logging.getLogger(__name__)
class ProgressProxy(ProxiedSingleton):
def set_progress(
self,
value: float,
max_value: float,
node_id: Optional[str] = None,
image: Any = None,
) -> None:
get_progress_state().update_progress(
node_id=node_id,
value=value,
max_value=max_value,
image=image,
)
__all__ = ["ProgressProxy"]

View File

@@ -1,265 +0,0 @@
# pylint: disable=import-outside-toplevel,logging-fstring-interpolation,redefined-outer-name,reimported,super-init-not-called
"""Stateless RPC Implementation for PromptServer.
Replaces the legacy PromptServerProxy (Singleton) with a clean Service/Stub architecture.
- Host: PromptServerService (RPC Handler)
- Child: PromptServerStub (Interface Implementation)
"""
from __future__ import annotations
import asyncio
import os
from typing import Any, Dict, Optional, Callable
import logging
from aiohttp import web
# IMPORTS
from pyisolate import ProxiedSingleton
logger = logging.getLogger(__name__)
LOG_PREFIX = "[Isolation:C<->H]"
# ...
# =============================================================================
# CHILD SIDE: PromptServerStub
# =============================================================================
class PromptServerStub:
"""Stateless Stub for PromptServer."""
# Masquerade as the real server module
__module__ = "server"
_instance: Optional["PromptServerStub"] = None
_rpc: Optional[Any] = None # This will be the Caller object
_source_file: Optional[str] = None
def __init__(self):
self.routes = RouteStub(self)
@classmethod
def set_rpc(cls, rpc: Any) -> None:
"""Inject RPC client (called by adapter.py or manually)."""
# Create caller for HOST Service
# Assuming Host Service is registered as "PromptServerService" (class name)
# We target the Host Service Class
target_id = "PromptServerService"
# We need to pass a class to create_caller? Usually yes.
# But we don't have the Service class imported here necessarily (if running on child).
# pyisolate check verify_service type?
# If we pass PromptServerStub as the 'class', it might mismatch if checking types.
# But we can try passing PromptServerStub if it mirrors the service name? No, stub is PromptServerStub.
# We need a dummy class with right name?
# Or just rely on string ID if create_caller supports it?
# Standard: rpc.create_caller(PromptServerStub, target_id)
# But wait, PromptServerStub is the *Local* class.
# We want to call *Remote* class.
# If we use PromptServerStub as the type, returning object will be typed as PromptServerStub?
# The first arg is 'service_cls'.
cls._rpc = rpc.create_caller(
PromptServerService, target_id
) # We import Service below?
# We need PromptServerService available for the create_caller call?
# Or just use the Stub class if ID matches?
# prompt_server_impl.py defines BOTH. So PromptServerService IS available!
@property
def instance(self) -> "PromptServerStub":
return self
# ... Compatibility ...
@classmethod
def _get_source_file(cls) -> str:
if cls._source_file is None:
import folder_paths
cls._source_file = os.path.join(folder_paths.base_path, "server.py")
return cls._source_file
@property
def __file__(self) -> str:
return self._get_source_file()
# --- Properties ---
@property
def client_id(self) -> Optional[str]:
return "isolated_client"
def supports(self, feature: str) -> bool:
return True
@property
def app(self):
raise RuntimeError(
"PromptServer.app is not accessible in isolated nodes. Use RPC routes instead."
)
@property
def prompt_queue(self):
raise RuntimeError(
"PromptServer.prompt_queue is not accessible in isolated nodes."
)
# --- UI Communication (RPC Delegates) ---
async def send_sync(
self, event: str, data: Dict[str, Any], sid: Optional[str] = None
) -> None:
if self._rpc:
await self._rpc.ui_send_sync(event, data, sid)
async def send(
self, event: str, data: Dict[str, Any], sid: Optional[str] = None
) -> None:
if self._rpc:
await self._rpc.ui_send(event, data, sid)
def send_progress_text(self, text: str, node_id: str, sid=None) -> None:
if self._rpc:
# Fire and forget likely needed. If method is async on host, caller invocation returns coroutine.
# We must schedule it?
# Or use fire_remote equivalent?
# Caller object usually proxies calls. If host method is async, it returns coro.
# If we are sync here (send_progress_text checks imply sync usage), we must background it.
# But UtilsProxy hook wrapper creates task.
# Does send_progress_text need to be sync? Yes, node code calls it sync.
import asyncio
try:
loop = asyncio.get_running_loop()
loop.create_task(self._rpc.ui_send_progress_text(text, node_id, sid))
except RuntimeError:
pass # Sync context without loop?
# --- Route Registration Logic ---
def register_route(self, method: str, path: str, handler: Callable):
"""Register a route handler via RPC."""
if not self._rpc:
logger.error("RPC not initialized in PromptServerStub")
return
# Fire registration async
try:
loop = asyncio.get_running_loop()
loop.create_task(self._rpc.register_route_rpc(method, path, handler))
except RuntimeError:
pass
class RouteStub:
"""Simulates aiohttp.web.RouteTableDef."""
def __init__(self, stub: PromptServerStub):
self._stub = stub
def get(self, path: str):
def decorator(handler):
self._stub.register_route("GET", path, handler)
return handler
return decorator
def post(self, path: str):
def decorator(handler):
self._stub.register_route("POST", path, handler)
return handler
return decorator
def patch(self, path: str):
def decorator(handler):
self._stub.register_route("PATCH", path, handler)
return handler
return decorator
def put(self, path: str):
def decorator(handler):
self._stub.register_route("PUT", path, handler)
return handler
return decorator
def delete(self, path: str):
def decorator(handler):
self._stub.register_route("DELETE", path, handler)
return handler
return decorator
# =============================================================================
# HOST SIDE: PromptServerService
# =============================================================================
class PromptServerService(ProxiedSingleton):
"""Host-side RPC Service for PromptServer."""
def __init__(self):
# We will bind to the real server instance lazily or via global import
pass
@property
def server(self):
from server import PromptServer
return PromptServer.instance
async def ui_send_sync(
self, event: str, data: Dict[str, Any], sid: Optional[str] = None
):
await self.server.send_sync(event, data, sid)
async def ui_send(
self, event: str, data: Dict[str, Any], sid: Optional[str] = None
):
await self.server.send(event, data, sid)
async def ui_send_progress_text(self, text: str, node_id: str, sid=None):
# Made async to be awaitable by RPC layer
self.server.send_progress_text(text, node_id, sid)
async def register_route_rpc(self, method: str, path: str, child_handler_proxy):
"""RPC Target: Register a route that forwards to the Child."""
logger.debug(f"{LOG_PREFIX} Registering Isolated Route {method} {path}")
async def route_wrapper(request: web.Request) -> web.Response:
# 1. Capture request data
req_data = {
"method": request.method,
"path": request.path,
"query": dict(request.query),
}
if request.can_read_body:
req_data["text"] = await request.text()
try:
# 2. Call Child Handler via RPC (child_handler_proxy is async callable)
result = await child_handler_proxy(req_data)
# 3. Serialize Response
return self._serialize_response(result)
except Exception as e:
logger.error(f"{LOG_PREFIX} Isolated Route Error: {e}")
return web.Response(status=500, text=str(e))
# Register loop
self.server.app.router.add_route(method, path, route_wrapper)
def _serialize_response(self, result: Any) -> web.Response:
"""Helper to convert Child result -> web.Response"""
if isinstance(result, web.Response):
return result
# Handle dict (json)
if isinstance(result, dict):
return web.json_response(result)
# Handle string
if isinstance(result, str):
return web.Response(text=result)
# Fallback
return web.Response(text=str(result))

View File

@@ -1,64 +0,0 @@
# pylint: disable=cyclic-import,import-outside-toplevel
from __future__ import annotations
from typing import Optional, Any
import comfy.utils
from pyisolate import ProxiedSingleton
import os
class UtilsProxy(ProxiedSingleton):
"""
Proxy for comfy.utils.
Primarily handles the PROGRESS_BAR_HOOK to ensure progress updates
from isolated nodes reach the host.
"""
# _instance and __new__ removed to rely on SingletonMetaclass
_rpc: Optional[Any] = None
@classmethod
def set_rpc(cls, rpc: Any) -> None:
# Create caller using class name as ID (standard for Singletons)
cls._rpc = rpc.create_caller(cls, "UtilsProxy")
async def progress_bar_hook(
self,
value: int,
total: int,
preview: Optional[bytes] = None,
node_id: Optional[str] = None,
) -> Any:
"""
Host-side implementation: forwards the call to the real global hook.
Child-side: this method call is intercepted by RPC and sent to host.
"""
if os.environ.get("PYISOLATE_CHILD") == "1":
# Manual RPC dispatch for Child process
# Use class-level RPC storage (Static Injection)
if UtilsProxy._rpc:
return await UtilsProxy._rpc.progress_bar_hook(
value, total, preview, node_id
)
# Fallback channel: global child rpc
try:
from pyisolate._internal.rpc_protocol import get_child_rpc_instance
get_child_rpc_instance()
# If we have an RPC instance but no UtilsProxy._rpc, we *could* try to use it,
# but we need a caller. For now, just pass to avoid crashing.
pass
except (ImportError, LookupError):
pass
return None
# Host Execution
if comfy.utils.PROGRESS_BAR_HOOK is not None:
comfy.utils.PROGRESS_BAR_HOOK(value, total, preview, node_id)
def set_progress_bar_global_hook(self, hook: Any) -> None:
"""Forward hook registration (though usually not needed from child)."""
comfy.utils.set_progress_bar_global_hook(hook)

View File

@@ -1,49 +0,0 @@
import asyncio
import logging
import threading
logger = logging.getLogger(__name__)
class RpcBridge:
"""Minimal helper to run coroutines synchronously inside isolated processes.
If an event loop is already running, the coroutine is executed on a fresh
thread with its own loop to avoid nested run_until_complete errors.
"""
def run_sync(self, maybe_coro):
if not asyncio.iscoroutine(maybe_coro):
return maybe_coro
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = None
if loop and loop.is_running():
result_container = {}
exc_container = {}
def _runner():
try:
new_loop = asyncio.new_event_loop()
asyncio.set_event_loop(new_loop)
result_container["value"] = new_loop.run_until_complete(maybe_coro)
except Exception as exc: # pragma: no cover
exc_container["error"] = exc
finally:
try:
new_loop.close()
except Exception:
pass
t = threading.Thread(target=_runner, daemon=True)
t.start()
t.join()
if "error" in exc_container:
raise exc_container["error"]
return result_container.get("value")
return asyncio.run(maybe_coro)

View File

@@ -1,343 +0,0 @@
# pylint: disable=consider-using-from-import,import-outside-toplevel,no-member
from __future__ import annotations
import copy
import logging
import os
from pathlib import Path
from typing import Any, Dict, List, Set, TYPE_CHECKING
from .proxies.helper_proxies import restore_input_types
from comfy_api.internal import _ComfyNodeInternal
from comfy_api.latest import _io as latest_io
from .shm_forensics import scan_shm_forensics
if TYPE_CHECKING:
from .extension_wrapper import ComfyNodeExtension
LOG_PREFIX = "]["
_PRE_EXEC_MIN_FREE_VRAM_BYTES = 2 * 1024 * 1024 * 1024
def _resource_snapshot() -> Dict[str, int]:
fd_count = -1
shm_sender_files = 0
try:
fd_count = len(os.listdir("/proc/self/fd"))
except Exception:
pass
try:
shm_root = Path("/dev/shm")
if shm_root.exists():
prefix = f"torch_{os.getpid()}_"
shm_sender_files = sum(1 for _ in shm_root.glob(f"{prefix}*"))
except Exception:
pass
return {"fd_count": fd_count, "shm_sender_files": shm_sender_files}
def _tensor_transport_summary(value: Any) -> Dict[str, int]:
summary: Dict[str, int] = {
"tensor_count": 0,
"cpu_tensors": 0,
"cuda_tensors": 0,
"shared_cpu_tensors": 0,
"tensor_bytes": 0,
}
try:
import torch
except Exception:
return summary
def visit(node: Any) -> None:
if isinstance(node, torch.Tensor):
summary["tensor_count"] += 1
summary["tensor_bytes"] += int(node.numel() * node.element_size())
if node.device.type == "cpu":
summary["cpu_tensors"] += 1
if node.is_shared():
summary["shared_cpu_tensors"] += 1
elif node.device.type == "cuda":
summary["cuda_tensors"] += 1
return
if isinstance(node, dict):
for v in node.values():
visit(v)
return
if isinstance(node, (list, tuple)):
for v in node:
visit(v)
visit(value)
return summary
def _extract_hidden_unique_id(inputs: Dict[str, Any]) -> str | None:
for key, value in inputs.items():
key_text = str(key)
if "unique_id" in key_text:
return str(value)
return None
def _flush_tensor_transport_state(marker: str, logger: logging.Logger) -> None:
try:
from pyisolate import flush_tensor_keeper # type: ignore[attr-defined]
except Exception:
return
if not callable(flush_tensor_keeper):
return
flushed = flush_tensor_keeper()
if flushed > 0:
logger.debug(
"%s %s flush_tensor_keeper released=%d", LOG_PREFIX, marker, flushed
)
def _relieve_host_vram_pressure(marker: str, logger: logging.Logger) -> None:
import comfy.model_management as model_management
model_management.cleanup_models_gc()
model_management.cleanup_models()
device = model_management.get_torch_device()
if not hasattr(device, "type") or device.type == "cpu":
return
required = max(
model_management.minimum_inference_memory(),
_PRE_EXEC_MIN_FREE_VRAM_BYTES,
)
if model_management.get_free_memory(device) < required:
model_management.free_memory(required, device, for_dynamic=True)
if model_management.get_free_memory(device) < required:
model_management.free_memory(required, device, for_dynamic=False)
model_management.cleanup_models()
model_management.soft_empty_cache()
logger.debug("%s %s free_memory target=%d", LOG_PREFIX, marker, required)
def _detach_shared_cpu_tensors(value: Any) -> Any:
try:
import torch
except Exception:
return value
if isinstance(value, torch.Tensor):
if value.device.type == "cpu" and value.is_shared():
clone = value.clone()
if value.requires_grad:
clone.requires_grad_(True)
return clone
return value
if isinstance(value, list):
return [_detach_shared_cpu_tensors(v) for v in value]
if isinstance(value, tuple):
return tuple(_detach_shared_cpu_tensors(v) for v in value)
if isinstance(value, dict):
return {k: _detach_shared_cpu_tensors(v) for k, v in value.items()}
return value
def build_stub_class(
node_name: str,
info: Dict[str, object],
extension: "ComfyNodeExtension",
running_extensions: Dict[str, "ComfyNodeExtension"],
logger: logging.Logger,
) -> type:
is_v3 = bool(info.get("is_v3", False))
function_name = "_pyisolate_execute"
restored_input_types = restore_input_types(info.get("input_types", {}))
async def _execute(self, **inputs):
from comfy.isolation import _RUNNING_EXTENSIONS
# Update BOTH the local dict AND the module-level dict
running_extensions[extension.name] = extension
_RUNNING_EXTENSIONS[extension.name] = extension
prev_child = None
node_unique_id = _extract_hidden_unique_id(inputs)
summary = _tensor_transport_summary(inputs)
resources = _resource_snapshot()
logger.debug(
"%s ISO:execute_start ext=%s node=%s uid=%s tensors=%d cpu=%d cuda=%d shared_cpu=%d bytes=%d fds=%d sender_shm=%d",
LOG_PREFIX,
extension.name,
node_name,
node_unique_id or "-",
summary["tensor_count"],
summary["cpu_tensors"],
summary["cuda_tensors"],
summary["shared_cpu_tensors"],
summary["tensor_bytes"],
resources["fd_count"],
resources["shm_sender_files"],
)
scan_shm_forensics("RUNTIME:execute_start", refresh_model_context=True)
try:
if os.environ.get("PYISOLATE_ISOLATION_ACTIVE") == "1":
_relieve_host_vram_pressure("RUNTIME:pre_execute", logger)
scan_shm_forensics("RUNTIME:pre_execute", refresh_model_context=True)
from pyisolate._internal.model_serialization import (
serialize_for_isolation,
deserialize_from_isolation,
)
prev_child = os.environ.pop("PYISOLATE_CHILD", None)
logger.debug(
"%s ISO:serialize_start ext=%s node=%s uid=%s",
LOG_PREFIX,
extension.name,
node_name,
node_unique_id or "-",
)
serialized = serialize_for_isolation(inputs)
logger.debug(
"%s ISO:serialize_done ext=%s node=%s uid=%s",
LOG_PREFIX,
extension.name,
node_name,
node_unique_id or "-",
)
logger.debug(
"%s ISO:dispatch_start ext=%s node=%s uid=%s",
LOG_PREFIX,
extension.name,
node_name,
node_unique_id or "-",
)
result = await extension.execute_node(node_name, **serialized)
logger.debug(
"%s ISO:dispatch_done ext=%s node=%s uid=%s",
LOG_PREFIX,
extension.name,
node_name,
node_unique_id or "-",
)
deserialized = await deserialize_from_isolation(result, extension)
scan_shm_forensics("RUNTIME:post_execute", refresh_model_context=True)
return _detach_shared_cpu_tensors(deserialized)
except ImportError:
return await extension.execute_node(node_name, **inputs)
except Exception:
logger.exception(
"%s ISO:execute_error ext=%s node=%s uid=%s",
LOG_PREFIX,
extension.name,
node_name,
node_unique_id or "-",
)
raise
finally:
if prev_child is not None:
os.environ["PYISOLATE_CHILD"] = prev_child
logger.debug(
"%s ISO:execute_end ext=%s node=%s uid=%s",
LOG_PREFIX,
extension.name,
node_name,
node_unique_id or "-",
)
scan_shm_forensics("RUNTIME:execute_end", refresh_model_context=True)
def _input_types(
cls,
include_hidden: bool = True,
return_schema: bool = False,
live_inputs: Any = None,
):
if not is_v3:
return restored_input_types
inputs_copy = copy.deepcopy(restored_input_types)
if not include_hidden:
inputs_copy.pop("hidden", None)
v3_data: Dict[str, Any] = {"hidden_inputs": {}}
dynamic = inputs_copy.pop("dynamic_paths", None)
if dynamic is not None:
v3_data["dynamic_paths"] = dynamic
if return_schema:
hidden_vals = info.get("hidden", []) or []
hidden_enums = []
for h in hidden_vals:
try:
hidden_enums.append(latest_io.Hidden(h))
except Exception:
hidden_enums.append(h)
class SchemaProxy:
hidden = hidden_enums
return inputs_copy, SchemaProxy, v3_data
return inputs_copy
def _validate_class(cls):
return True
def _get_node_info_v1(cls):
return info.get("schema_v1", {})
def _get_base_class(cls):
return latest_io.ComfyNode
attributes: Dict[str, object] = {
"FUNCTION": function_name,
"CATEGORY": info.get("category", ""),
"OUTPUT_NODE": info.get("output_node", False),
"RETURN_TYPES": tuple(info.get("return_types", ()) or ()),
"RETURN_NAMES": info.get("return_names"),
function_name: _execute,
"_pyisolate_extension": extension,
"_pyisolate_node_name": node_name,
"INPUT_TYPES": classmethod(_input_types),
}
output_is_list = info.get("output_is_list")
if output_is_list is not None:
attributes["OUTPUT_IS_LIST"] = tuple(output_is_list)
if is_v3:
attributes["VALIDATE_CLASS"] = classmethod(_validate_class)
attributes["GET_NODE_INFO_V1"] = classmethod(_get_node_info_v1)
attributes["GET_BASE_CLASS"] = classmethod(_get_base_class)
attributes["DESCRIPTION"] = info.get("description", "")
attributes["EXPERIMENTAL"] = info.get("experimental", False)
attributes["DEPRECATED"] = info.get("deprecated", False)
attributes["API_NODE"] = info.get("api_node", False)
attributes["NOT_IDEMPOTENT"] = info.get("not_idempotent", False)
attributes["INPUT_IS_LIST"] = info.get("input_is_list", False)
class_name = f"PyIsolate_{node_name}".replace(" ", "_")
bases = (_ComfyNodeInternal,) if is_v3 else ()
stub_cls = type(class_name, bases, attributes)
if is_v3:
try:
stub_cls.VALIDATE_CLASS()
except Exception as e:
logger.error("%s VALIDATE_CLASS failed: %s - %s", LOG_PREFIX, node_name, e)
return stub_cls
def get_class_types_for_extension(
extension_name: str,
running_extensions: Dict[str, "ComfyNodeExtension"],
specs: List[Any],
) -> Set[str]:
extension = running_extensions.get(extension_name)
if not extension:
return set()
ext_path = Path(extension.module_path)
class_types = set()
for spec in specs:
if spec.module_path.resolve() == ext_path.resolve():
class_types.add(spec.node_name)
return class_types
__all__ = ["build_stub_class", "get_class_types_for_extension"]

View File

@@ -1,217 +0,0 @@
# pylint: disable=consider-using-from-import,import-outside-toplevel
from __future__ import annotations
import atexit
import hashlib
import logging
import os
from pathlib import Path
from typing import Any, Dict, List, Set
LOG_PREFIX = "]["
logger = logging.getLogger(__name__)
def _shm_debug_enabled() -> bool:
return os.environ.get("COMFY_ISO_SHM_DEBUG") == "1"
class _SHMForensicsTracker:
def __init__(self) -> None:
self._started = False
self._tracked_files: Set[str] = set()
self._current_model_context: Dict[str, str] = {
"id": "unknown",
"name": "unknown",
"hash": "????",
}
@staticmethod
def _snapshot_shm() -> Set[str]:
shm_path = Path("/dev/shm")
if not shm_path.exists():
return set()
return {f.name for f in shm_path.glob("torch_*")}
def start(self) -> None:
if self._started or not _shm_debug_enabled():
return
self._tracked_files = self._snapshot_shm()
self._started = True
logger.debug(
"%s SHM:forensics_enabled tracked=%d", LOG_PREFIX, len(self._tracked_files)
)
def stop(self) -> None:
if not self._started:
return
self.scan("shutdown", refresh_model_context=True)
self._started = False
logger.debug("%s SHM:forensics_disabled", LOG_PREFIX)
def _compute_model_hash(self, model_patcher: Any) -> str:
try:
model_instance_id = getattr(model_patcher, "_instance_id", None)
if model_instance_id is not None:
model_id_text = str(model_instance_id)
return model_id_text[-4:] if len(model_id_text) >= 4 else model_id_text
import torch
real_model = (
model_patcher.model
if hasattr(model_patcher, "model")
else model_patcher
)
tensor = None
if hasattr(real_model, "parameters"):
for p in real_model.parameters():
if torch.is_tensor(p) and p.numel() > 0:
tensor = p
break
if tensor is None:
return "0000"
flat = tensor.flatten()
values = []
indices = [0, flat.shape[0] // 2, flat.shape[0] - 1]
for i in indices:
if i < flat.shape[0]:
values.append(flat[i].item())
size = 0
if hasattr(model_patcher, "model_size"):
size = model_patcher.model_size()
sample_str = f"{values}_{id(model_patcher):016x}_{size}"
return hashlib.sha256(sample_str.encode()).hexdigest()[-4:]
except Exception:
return "err!"
def _get_models_snapshot(self) -> List[Dict[str, Any]]:
try:
import comfy.model_management as model_management
except Exception:
return []
snapshot: List[Dict[str, Any]] = []
try:
for loaded_model in model_management.current_loaded_models:
model = loaded_model.model
if model is None:
continue
if str(getattr(loaded_model, "device", "")) != "cuda:0":
continue
name = (
model.model.__class__.__name__
if hasattr(model, "model")
else type(model).__name__
)
model_hash = self._compute_model_hash(model)
model_instance_id = getattr(model, "_instance_id", None)
if model_instance_id is None:
model_instance_id = model_hash
snapshot.append(
{
"name": str(name),
"id": str(model_instance_id),
"hash": str(model_hash or "????"),
"used": bool(getattr(loaded_model, "currently_used", False)),
}
)
except Exception:
return []
return snapshot
def _update_model_context(self) -> None:
snapshot = self._get_models_snapshot()
selected = None
used_models = [m for m in snapshot if m.get("used") and m.get("id")]
if used_models:
selected = used_models[-1]
else:
live_models = [m for m in snapshot if m.get("id")]
if live_models:
selected = live_models[-1]
if selected is None:
self._current_model_context = {
"id": "unknown",
"name": "unknown",
"hash": "????",
}
return
self._current_model_context = {
"id": str(selected.get("id", "unknown")),
"name": str(selected.get("name", "unknown")),
"hash": str(selected.get("hash", "????") or "????"),
}
def scan(self, marker: str, refresh_model_context: bool = True) -> None:
if not self._started or not _shm_debug_enabled():
return
if refresh_model_context:
self._update_model_context()
current = self._snapshot_shm()
added = current - self._tracked_files
removed = self._tracked_files - current
self._tracked_files = current
if not added and not removed:
logger.debug("%s SHM:scan marker=%s changes=0", LOG_PREFIX, marker)
return
for filename in sorted(added):
logger.info("%s SHM:created | %s", LOG_PREFIX, filename)
model_id = self._current_model_context["id"]
if model_id == "unknown":
logger.error(
"%s SHM:model_association_missing | file=%s | reason=no_active_model_context",
LOG_PREFIX,
filename,
)
else:
logger.info(
"%s SHM:model_association | model=%s | file=%s | name=%s | hash=%s",
LOG_PREFIX,
model_id,
filename,
self._current_model_context["name"],
self._current_model_context["hash"],
)
for filename in sorted(removed):
logger.info("%s SHM:deleted | %s", LOG_PREFIX, filename)
logger.debug(
"%s SHM:scan marker=%s created=%d deleted=%d active=%d",
LOG_PREFIX,
marker,
len(added),
len(removed),
len(self._tracked_files),
)
_TRACKER = _SHMForensicsTracker()
def start_shm_forensics() -> None:
_TRACKER.start()
def scan_shm_forensics(marker: str, refresh_model_context: bool = True) -> None:
_TRACKER.scan(marker, refresh_model_context=refresh_model_context)
def stop_shm_forensics() -> None:
_TRACKER.stop()
atexit.register(stop_shm_forensics)

View File

@@ -1,214 +0,0 @@
# pylint: disable=attribute-defined-outside-init
import logging
from typing import Any
from comfy.isolation.proxies.base import (
IS_CHILD_PROCESS,
BaseProxy,
BaseRegistry,
detach_if_grad,
)
from comfy.isolation.model_patcher_proxy import ModelPatcherProxy, ModelPatcherRegistry
logger = logging.getLogger(__name__)
class FirstStageModelRegistry(BaseRegistry[Any]):
_type_prefix = "first_stage_model"
async def get_property(self, instance_id: str, name: str) -> Any:
obj = self._get_instance(instance_id)
return getattr(obj, name)
async def has_property(self, instance_id: str, name: str) -> bool:
obj = self._get_instance(instance_id)
return hasattr(obj, name)
class FirstStageModelProxy(BaseProxy[FirstStageModelRegistry]):
_registry_class = FirstStageModelRegistry
__module__ = "comfy.ldm.models.autoencoder"
def __getattr__(self, name: str) -> Any:
try:
return self._call_rpc("get_property", name)
except Exception as e:
raise AttributeError(
f"'{self.__class__.__name__}' object has no attribute '{name}'"
) from e
def __repr__(self) -> str:
return f"<FirstStageModelProxy {self._instance_id}>"
class VAERegistry(BaseRegistry[Any]):
_type_prefix = "vae"
async def get_patcher_id(self, instance_id: str) -> str:
vae = self._get_instance(instance_id)
return ModelPatcherRegistry().register(vae.patcher)
async def get_first_stage_model_id(self, instance_id: str) -> str:
vae = self._get_instance(instance_id)
return FirstStageModelRegistry().register(vae.first_stage_model)
async def encode(self, instance_id: str, pixels: Any) -> Any:
return detach_if_grad(self._get_instance(instance_id).encode(pixels))
async def encode_tiled(
self,
instance_id: str,
pixels: Any,
tile_x: int = 512,
tile_y: int = 512,
overlap: int = 64,
) -> Any:
return detach_if_grad(
self._get_instance(instance_id).encode_tiled(
pixels, tile_x=tile_x, tile_y=tile_y, overlap=overlap
)
)
async def decode(self, instance_id: str, samples: Any, **kwargs: Any) -> Any:
return detach_if_grad(self._get_instance(instance_id).decode(samples, **kwargs))
async def decode_tiled(
self,
instance_id: str,
samples: Any,
tile_x: int = 64,
tile_y: int = 64,
overlap: int = 16,
**kwargs: Any,
) -> Any:
return detach_if_grad(
self._get_instance(instance_id).decode_tiled(
samples, tile_x=tile_x, tile_y=tile_y, overlap=overlap, **kwargs
)
)
async def get_property(self, instance_id: str, name: str) -> Any:
return getattr(self._get_instance(instance_id), name)
async def memory_used_encode(self, instance_id: str, shape: Any, dtype: Any) -> int:
return self._get_instance(instance_id).memory_used_encode(shape, dtype)
async def memory_used_decode(self, instance_id: str, shape: Any, dtype: Any) -> int:
return self._get_instance(instance_id).memory_used_decode(shape, dtype)
async def process_input(self, instance_id: str, image: Any) -> Any:
return detach_if_grad(self._get_instance(instance_id).process_input(image))
async def process_output(self, instance_id: str, image: Any) -> Any:
return detach_if_grad(self._get_instance(instance_id).process_output(image))
class VAEProxy(BaseProxy[VAERegistry]):
_registry_class = VAERegistry
__module__ = "comfy.sd"
@property
def patcher(self) -> ModelPatcherProxy:
if not hasattr(self, "_patcher_proxy"):
patcher_id = self._call_rpc("get_patcher_id")
self._patcher_proxy = ModelPatcherProxy(patcher_id, manage_lifecycle=False)
return self._patcher_proxy
@property
def first_stage_model(self) -> FirstStageModelProxy:
if not hasattr(self, "_first_stage_model_proxy"):
fsm_id = self._call_rpc("get_first_stage_model_id")
self._first_stage_model_proxy = FirstStageModelProxy(
fsm_id, manage_lifecycle=False
)
return self._first_stage_model_proxy
@property
def vae_dtype(self) -> Any:
return self._get_property("vae_dtype")
def encode(self, pixels: Any) -> Any:
return self._call_rpc("encode", pixels)
def encode_tiled(
self, pixels: Any, tile_x: int = 512, tile_y: int = 512, overlap: int = 64
) -> Any:
return self._call_rpc("encode_tiled", pixels, tile_x, tile_y, overlap)
def decode(self, samples: Any, **kwargs: Any) -> Any:
return self._call_rpc("decode", samples, **kwargs)
def decode_tiled(
self,
samples: Any,
tile_x: int = 64,
tile_y: int = 64,
overlap: int = 16,
**kwargs: Any,
) -> Any:
return self._call_rpc(
"decode_tiled", samples, tile_x, tile_y, overlap, **kwargs
)
def get_sd(self) -> Any:
return self._call_rpc("get_sd")
def _get_property(self, name: str) -> Any:
return self._call_rpc("get_property", name)
@property
def latent_dim(self) -> int:
return self._get_property("latent_dim")
@property
def latent_channels(self) -> int:
return self._get_property("latent_channels")
@property
def downscale_ratio(self) -> Any:
return self._get_property("downscale_ratio")
@property
def upscale_ratio(self) -> Any:
return self._get_property("upscale_ratio")
@property
def output_channels(self) -> int:
return self._get_property("output_channels")
@property
def check_not_vide(self) -> bool:
return self._get_property("not_video")
@property
def device(self) -> Any:
return self._get_property("device")
@property
def working_dtypes(self) -> Any:
return self._get_property("working_dtypes")
@property
def disable_offload(self) -> bool:
return self._get_property("disable_offload")
@property
def size(self) -> Any:
return self._get_property("size")
def memory_used_encode(self, shape: Any, dtype: Any) -> int:
return self._call_rpc("memory_used_encode", shape, dtype)
def memory_used_decode(self, shape: Any, dtype: Any) -> int:
return self._call_rpc("memory_used_decode", shape, dtype)
def process_input(self, image: Any) -> Any:
return self._call_rpc("process_input", image)
def process_output(self, image: Any) -> Any:
return self._call_rpc("process_output", image)
if not IS_CHILD_PROCESS:
_VAE_REGISTRY_SINGLETON = VAERegistry()
_FIRST_STAGE_MODEL_REGISTRY_SINGLETON = FirstStageModelRegistry()

View File

@@ -1,5 +1,4 @@
import math
import os
from functools import partial
from scipy import integrate
@@ -13,8 +12,8 @@ from . import deis
from . import sa_solver
import comfy.model_patcher
import comfy.model_sampling
import comfy.memory_management
from comfy.cli_args import args
from comfy.utils import model_trange as trange
def append_zero(x):
@@ -192,13 +191,6 @@ def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None,
"""Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
isolation_active = args.use_process_isolation or os.environ.get("PYISOLATE_ISOLATION_ACTIVE") == "1"
if isolation_active:
target_device = sigmas.device
if x.device != target_device:
x = x.to(target_device)
s_in = s_in.to(target_device)
for i in trange(len(sigmas) - 1, disable=disable):
if s_churn > 0:
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.

View File

@@ -776,3 +776,10 @@ class ChromaRadiance(LatentFormat):
def process_out(self, latent):
return latent
class ZImagePixelSpace(ChromaRadiance):
"""Pixel-space latent format for ZImage DCT variant.
No VAE encoding/decoding — the model operates directly on RGB pixels.
"""
pass

View File

@@ -14,6 +14,7 @@ from comfy.ldm.flux.layers import EmbedND
from comfy.ldm.flux.math import apply_rope
import comfy.patcher_extension
import comfy.utils
from comfy.ldm.chroma_radiance.layers import NerfEmbedder
def invert_slices(slices, length):
@@ -858,3 +859,267 @@ class NextDiT(nn.Module):
img = self.unpatchify(img, img_size, cap_size, return_tensor=x_is_tensor)[:, :, :h, :w]
return -img
#############################################################################
# Pixel Space Decoder Components #
#############################################################################
def _modulate_shift_scale(x, shift, scale):
return x * (1 + scale) + shift
class PixelResBlock(nn.Module):
"""
Residual block with AdaLN modulation, zero-initialised so it starts as
an identity at the beginning of training.
"""
def __init__(self, channels: int, dtype=None, device=None, operations=None):
super().__init__()
self.in_ln = operations.LayerNorm(channels, eps=1e-6, dtype=dtype, device=device)
self.mlp = nn.Sequential(
operations.Linear(channels, channels, bias=True, dtype=dtype, device=device),
nn.SiLU(),
operations.Linear(channels, channels, bias=True, dtype=dtype, device=device),
)
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
operations.Linear(channels, 3 * channels, bias=True, dtype=dtype, device=device),
)
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
shift, scale, gate = self.adaLN_modulation(y).chunk(3, dim=-1)
h = _modulate_shift_scale(self.in_ln(x), shift, scale)
h = self.mlp(h)
return x + gate * h
class DCTFinalLayer(nn.Module):
"""Zero-initialised output projection (adopted from DiT)."""
def __init__(self, model_channels: int, out_channels: int, dtype=None, device=None, operations=None):
super().__init__()
self.norm_final = operations.LayerNorm(model_channels, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.linear = operations.Linear(model_channels, out_channels, bias=True, dtype=dtype, device=device)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.linear(self.norm_final(x))
class SimpleMLPAdaLN(nn.Module):
"""
Small MLP decoder head for the pixel-space variant.
Takes per-patch pixel values and a per-patch conditioning vector from the
transformer backbone and predicts the denoised pixel values.
x : [B*N, P^2, C] noisy pixel values per patch position
c : [B*N, dim] backbone hidden state per patch (conditioning)
→ [B*N, P^2, C]
"""
def __init__(
self,
in_channels: int,
model_channels: int,
out_channels: int,
z_channels: int,
num_res_blocks: int,
max_freqs: int = 8,
dtype=None,
device=None,
operations=None,
):
super().__init__()
self.dtype = dtype
# Project backbone hidden state → per-patch conditioning
self.cond_embed = operations.Linear(z_channels, model_channels, dtype=dtype, device=device)
# Input projection with DCT positional encoding
self.input_embedder = NerfEmbedder(
in_channels=in_channels,
hidden_size_input=model_channels,
max_freqs=max_freqs,
dtype=dtype,
device=device,
operations=operations,
)
# Residual blocks
self.res_blocks = nn.ModuleList([
PixelResBlock(model_channels, dtype=dtype, device=device, operations=operations) for _ in range(num_res_blocks)
])
# Output projection
self.final_layer = DCTFinalLayer(model_channels, out_channels, dtype=dtype, device=device, operations=operations)
def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
# x: [B*N, 1, P^2*C], c: [B*N, dim]
original_dtype = x.dtype
weight_dtype = self.cond_embed.weight.dtype if hasattr(self.cond_embed, "weight") and self.cond_embed.weight is not None else (self.dtype or x.dtype)
x = self.input_embedder(x) # [B*N, 1, model_channels]
y = self.cond_embed(c.to(weight_dtype)).unsqueeze(1) # [B*N, 1, model_channels]
x = x.to(weight_dtype)
for block in self.res_blocks:
x = block(x, y)
return self.final_layer(x).to(original_dtype) # [B*N, 1, P^2*C]
#############################################################################
# NextDiT Pixel Space #
#############################################################################
class NextDiTPixelSpace(NextDiT):
"""
Pixel-space variant of NextDiT.
Identical transformer backbone to NextDiT, but the output head is replaced
with a small MLP decoder (SimpleMLPAdaLN) that operates on raw pixel values
per patch rather than a single affine projection.
Key differences vs NextDiT:
• ``final_layer`` is removed; ``dec_net`` (SimpleMLPAdaLN) is used instead.
• ``_forward`` stores the raw patchified pixel values before the backbone
embedding and feeds them to ``dec_net`` together with the per-patch
backbone hidden states.
• Supports optional x0 prediction via ``use_x0``.
"""
def __init__(
self,
# decoder-specific
decoder_hidden_size: int = 3840,
decoder_num_res_blocks: int = 4,
decoder_max_freqs: int = 8,
decoder_in_channels: int = None, # full flattened patch size (patch_size^2 * in_channels)
use_x0: bool = False,
# all NextDiT args forwarded unchanged
**kwargs,
):
super().__init__(**kwargs)
# Remove the latent-space final layer not used in pixel space
del self.final_layer
patch_size = kwargs.get("patch_size", 2)
in_channels = kwargs.get("in_channels", 4)
dim = kwargs.get("dim", 4096)
# decoder_in_channels is the full flattened patch: patch_size^2 * in_channels
dec_in_ch = decoder_in_channels if decoder_in_channels is not None else patch_size ** 2 * in_channels
self.dec_net = SimpleMLPAdaLN(
in_channels=dec_in_ch,
model_channels=decoder_hidden_size,
out_channels=dec_in_ch,
z_channels=dim,
num_res_blocks=decoder_num_res_blocks,
max_freqs=decoder_max_freqs,
dtype=kwargs.get("dtype"),
device=kwargs.get("device"),
operations=kwargs.get("operations"),
)
if use_x0:
self.register_buffer("__x0__", torch.tensor([]))
# ------------------------------------------------------------------
# Forward — mirrors NextDiT._forward exactly, replacing final_layer
# with the pixel-space dec_net decoder.
# ------------------------------------------------------------------
def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, ref_latents=[], ref_contexts=[], siglip_feats=[], transformer_options={}, **kwargs):
omni = len(ref_latents) > 0
if omni:
timesteps = torch.cat([timesteps * 0, timesteps], dim=0)
t = 1.0 - timesteps
cap_feats = context
cap_mask = attention_mask
bs, c, h, w = x.shape
x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
t = self.t_embedder(t * self.time_scale, dtype=x.dtype)
adaln_input = t
if self.clip_text_pooled_proj is not None:
pooled = kwargs.get("clip_text_pooled", None)
if pooled is not None:
pooled = self.clip_text_pooled_proj(pooled)
else:
pooled = torch.zeros((x.shape[0], self.clip_text_dim), device=x.device, dtype=x.dtype)
adaln_input = self.time_text_embed(torch.cat((t, pooled), dim=-1))
# ---- capture raw pixel patches before patchify_and_embed embeds them ----
pH = pW = self.patch_size
B, C, H, W = x.shape
pixel_patches = (
x.view(B, C, H // pH, pH, W // pW, pW)
.permute(0, 2, 4, 3, 5, 1) # [B, Ht, Wt, pH, pW, C]
.flatten(3) # [B, Ht, Wt, pH*pW*C]
.flatten(1, 2) # [B, N, pH*pW*C]
)
N = pixel_patches.shape[1]
# decoder sees one token per patch: [B*N, 1, P^2*C]
pixel_values = pixel_patches.reshape(B * N, 1, pH * pW * C)
patches = transformer_options.get("patches", {})
x_is_tensor = isinstance(x, torch.Tensor)
img, mask, img_size, cap_size, freqs_cis, timestep_zero_index = self.patchify_and_embed(
x, cap_feats, cap_mask, adaln_input, num_tokens,
ref_latents=ref_latents, ref_contexts=ref_contexts,
siglip_feats=siglip_feats, transformer_options=transformer_options
)
freqs_cis = freqs_cis.to(img.device)
transformer_options["total_blocks"] = len(self.layers)
transformer_options["block_type"] = "double"
img_input = img
for i, layer in enumerate(self.layers):
transformer_options["block_index"] = i
img = layer(img, mask, freqs_cis, adaln_input, timestep_zero_index=timestep_zero_index, transformer_options=transformer_options)
if "double_block" in patches:
for p in patches["double_block"]:
out = p({"img": img[:, cap_size[0]:], "img_input": img_input[:, cap_size[0]:], "txt": img[:, :cap_size[0]], "pe": freqs_cis[:, cap_size[0]:], "vec": adaln_input, "x": x, "block_index": i, "transformer_options": transformer_options})
if "img" in out:
img[:, cap_size[0]:] = out["img"]
if "txt" in out:
img[:, :cap_size[0]] = out["txt"]
# ---- pixel-space decoder (replaces final_layer + unpatchify) ----
# img may have padding tokens beyond N; only the first N are real image patches
img_hidden = img[:, cap_size[0]:cap_size[0] + N, :] # [B, N, dim]
decoder_cond = img_hidden.reshape(B * N, self.dim) # [B*N, dim]
output = self.dec_net(pixel_values, decoder_cond) # [B*N, 1, P^2*C]
output = output.reshape(B, N, -1) # [B, N, P^2*C]
# prepend zero cap placeholder so unpatchify indexing works unchanged
cap_placeholder = torch.zeros(
B, cap_size[0], output.shape[-1], device=output.device, dtype=output.dtype
)
img_out = self.unpatchify(
torch.cat([cap_placeholder, output], dim=1),
img_size, cap_size, return_tensor=x_is_tensor
)[:, :, :h, :w]
return -img_out
def forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs):
# _forward returns neg_x0 = -x0 (negated decoder output).
#
# Reference inference (working_inference_reference.py):
# out = _forward(img, t) # = -x0
# pred = (img - out) / t # = (img + x0) / t [_apply_x0_residual]
# img += (t_prev - t_curr) * pred # Euler step
#
# ComfyUI's Euler sampler does the same:
# x_next = x + (sigma_next - sigma) * model_output
# So model_output must equal pred = (x - neg_x0) / t = (x - (-x0)) / t = (x + x0) / t
neg_x0 = comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {}))
).execute(x, timesteps, context, num_tokens, attention_mask, **kwargs)
return (x - neg_x0) / timesteps.view(-1, 1, 1, 1)

View File

@@ -1621,3 +1621,118 @@ class HumoWanModel(WanModel):
# unpatchify
x = self.unpatchify(x, grid_sizes)
return x
class SCAILWanModel(WanModel):
def __init__(self, model_type="scail", patch_size=(1, 2, 2), in_dim=20, dim=5120, operations=None, device=None, dtype=None, **kwargs):
super().__init__(model_type='i2v', patch_size=patch_size, in_dim=in_dim, dim=dim, operations=operations, device=device, dtype=dtype, **kwargs)
self.patch_embedding_pose = operations.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size, device=device, dtype=torch.float32)
def forward_orig(self, x, t, context, clip_fea=None, freqs=None, transformer_options={}, pose_latents=None, reference_latent=None, **kwargs):
if reference_latent is not None:
x = torch.cat((reference_latent, x), dim=2)
# embeddings
x = self.patch_embedding(x.float()).to(x.dtype)
grid_sizes = x.shape[2:]
transformer_options["grid_sizes"] = grid_sizes
x = x.flatten(2).transpose(1, 2)
scail_pose_seq_len = 0
if pose_latents is not None:
scail_x = self.patch_embedding_pose(pose_latents.float()).to(x.dtype)
scail_x = scail_x.flatten(2).transpose(1, 2)
scail_pose_seq_len = scail_x.shape[1]
x = torch.cat([x, scail_x], dim=1)
del scail_x
# time embeddings
e = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(dtype=x[0].dtype))
e = e.reshape(t.shape[0], -1, e.shape[-1])
e0 = self.time_projection(e).unflatten(2, (6, self.dim))
# context
context = self.text_embedding(context)
context_img_len = None
if clip_fea is not None:
if self.img_emb is not None:
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
context = torch.cat([context_clip, context], dim=1)
context_img_len = clip_fea.shape[-2]
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
transformer_options["total_blocks"] = len(self.blocks)
transformer_options["block_type"] = "double"
for i, block in enumerate(self.blocks):
transformer_options["block_index"] = i
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, transformer_options=args["transformer_options"])
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})
x = out["img"]
else:
x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options)
# head
x = self.head(x, e)
if scail_pose_seq_len > 0:
x = x[:, :-scail_pose_seq_len]
# unpatchify
x = self.unpatchify(x, grid_sizes)
if reference_latent is not None:
x = x[:, :, reference_latent.shape[2]:]
return x
def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, pose_latents=None, reference_latent=None, transformer_options={}):
main_freqs = super().rope_encode(t, h, w, t_start=t_start, steps_t=steps_t, steps_h=steps_h, steps_w=steps_w, device=device, dtype=dtype, transformer_options=transformer_options)
if pose_latents is None:
return main_freqs
ref_t_patches = 0
if reference_latent is not None:
ref_t_patches = (reference_latent.shape[2] + (self.patch_size[0] // 2)) // self.patch_size[0]
F_pose, H_pose, W_pose = pose_latents.shape[-3], pose_latents.shape[-2], pose_latents.shape[-1]
# if pose is at half resolution, scale_y/scale_x=2 stretches the position range to cover the same RoPE extent as the main frames
h_scale = h / H_pose
w_scale = w / W_pose
# 120 w-offset and shift 0.5 to place positions at midpoints (0.5, 2.5, ...) to match the original code
h_shift = (h_scale - 1) / 2
w_shift = (w_scale - 1) / 2
pose_transformer_options = {"rope_options": {"shift_y": h_shift, "shift_x": 120.0 + w_shift, "scale_y": h_scale, "scale_x": w_scale}}
pose_freqs = super().rope_encode(F_pose, H_pose, W_pose, t_start=t_start+ref_t_patches, device=device, dtype=dtype, transformer_options=pose_transformer_options)
return torch.cat([main_freqs, pose_freqs], dim=1)
def _forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, pose_latents=None, **kwargs):
bs, c, t, h, w = x.shape
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
if pose_latents is not None:
pose_latents = comfy.ldm.common_dit.pad_to_patch_size(pose_latents, self.patch_size)
t_len = t
if time_dim_concat is not None:
time_dim_concat = comfy.ldm.common_dit.pad_to_patch_size(time_dim_concat, self.patch_size)
x = torch.cat([x, time_dim_concat], dim=2)
t_len = x.shape[2]
reference_latent = None
if "reference_latent" in kwargs:
reference_latent = comfy.ldm.common_dit.pad_to_patch_size(kwargs.pop("reference_latent"), self.patch_size)
t_len += reference_latent.shape[2]
freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent)
return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent, **kwargs)[:, :, :t, :h, :w]

View File

@@ -76,6 +76,7 @@ class ModelType(Enum):
FLUX = 8
IMG_TO_IMG = 9
FLOW_COSMOS = 10
IMG_TO_IMG_FLOW = 11
def model_sampling(model_config, model_type):
@@ -108,17 +109,11 @@ def model_sampling(model_config, model_type):
elif model_type == ModelType.FLOW_COSMOS:
c = comfy.model_sampling.COSMOS_RFLOW
s = comfy.model_sampling.ModelSamplingCosmosRFlow
elif model_type == ModelType.IMG_TO_IMG_FLOW:
c = comfy.model_sampling.IMG_TO_IMG_FLOW
class ModelSampling(s, c):
def __reduce__(self):
"""Ensure pickling yields a proxy instead of failing on local class."""
try:
from comfy.isolation.model_sampling_proxy import ModelSamplingRegistry, ModelSamplingProxy
registry = ModelSamplingRegistry()
ms_id = registry.register(self)
return (ModelSamplingProxy, (ms_id,))
except Exception as exc:
raise RuntimeError("Failed to serialize ModelSampling for isolation.") from exc
pass
return ModelSampling(model_config)
@@ -998,6 +993,10 @@ class LTXV(BaseModel):
if keyframe_idxs is not None:
out['keyframe_idxs'] = comfy.conds.CONDRegular(keyframe_idxs)
guide_attention_entries = kwargs.get("guide_attention_entries", None)
if guide_attention_entries is not None:
out['guide_attention_entries'] = comfy.conds.CONDConstant(guide_attention_entries)
return out
def process_timestep(self, timestep, x, denoise_mask=None, **kwargs):
@@ -1050,6 +1049,10 @@ class LTXAV(BaseModel):
if latent_shapes is not None:
out['latent_shapes'] = comfy.conds.CONDConstant(latent_shapes)
guide_attention_entries = kwargs.get("guide_attention_entries", None)
if guide_attention_entries is not None:
out['guide_attention_entries'] = comfy.conds.CONDConstant(guide_attention_entries)
return out
def process_timestep(self, timestep, x, denoise_mask=None, audio_denoise_mask=None, **kwargs):
@@ -1260,6 +1263,11 @@ class Lumina2(BaseModel):
out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()[2:]), ref_latents))])
return out
class ZImagePixelSpace(Lumina2):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
BaseModel.__init__(self, model_config, model_type, device=device, unet_model=comfy.ldm.lumina.model.NextDiTPixelSpace)
self.memory_usage_factor_conds = ("ref_latents",)
class WAN21(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)
@@ -1493,6 +1501,50 @@ class WAN22(WAN21):
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
return latent_image
class WAN21_FlowRVS(WAN21):
def __init__(self, model_config, model_type=ModelType.IMG_TO_IMG_FLOW, image_to_video=False, device=None):
model_config.unet_config["model_type"] = "t2v"
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)
self.image_to_video = image_to_video
class WAN21_SCAIL(WAN21):
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.SCAILWanModel)
self.memory_usage_factor_conds = ("reference_latent", "pose_latents")
self.memory_usage_shape_process = {"pose_latents": lambda shape: [shape[0], shape[1], 1.5, shape[-2], shape[-1]]}
self.image_to_video = image_to_video
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
reference_latents = kwargs.get("reference_latents", None)
if reference_latents is not None:
ref_latent = self.process_latent_in(reference_latents[-1])
ref_mask = torch.ones_like(ref_latent[:, :4])
ref_latent = torch.cat([ref_latent, ref_mask], dim=1)
out['reference_latent'] = comfy.conds.CONDRegular(ref_latent)
pose_latents = kwargs.get("pose_video_latent", None)
if pose_latents is not None:
pose_latents = self.process_latent_in(pose_latents)
pose_mask = torch.ones_like(pose_latents[:, :4])
pose_latents = torch.cat([pose_latents, pose_mask], dim=1)
out['pose_latents'] = comfy.conds.CONDRegular(pose_latents)
return out
def extra_conds_shapes(self, **kwargs):
out = {}
ref_latents = kwargs.get("reference_latents", None)
if ref_latents is not None:
out['reference_latent'] = list([1, 20, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
pose_latents = kwargs.get("pose_video_latent", None)
if pose_latents is not None:
out['pose_latents'] = [pose_latents.shape[0], 20, *pose_latents.shape[2:]]
return out
class Hunyuan3Dv2(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3d.model.Hunyuan3Dv2)

View File

@@ -423,7 +423,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["extra_per_block_abs_pos_emb_type"] = "learnable"
return dit_config
if '{}cap_embedder.1.weight'.format(key_prefix) in state_dict_keys: # Lumina 2
if '{}cap_embedder.1.weight'.format(key_prefix) in state_dict_keys and '{}noise_refiner.0.attention.k_norm.weight'.format(key_prefix) in state_dict_keys: # Lumina 2
dit_config = {}
dit_config["image_model"] = "lumina2"
dit_config["patch_size"] = 2
@@ -464,6 +464,29 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
if sig_weight is not None:
dit_config["siglip_feat_dim"] = sig_weight.shape[0]
dec_cond_key = '{}dec_net.cond_embed.weight'.format(key_prefix)
if dec_cond_key in state_dict_keys: # pixel-space variant
dit_config["image_model"] = "zimage_pixel"
# patch_size and in_channels are derived from x_embedder:
# x_embedder: Linear(patch_size * patch_size * in_channels, dim)
# The decoder also receives the full flat patch, so decoder_in_channels = x_embedder input dim.
x_emb_in = state_dict['{}x_embedder.weight'.format(key_prefix)].shape[1]
dec_out = state_dict['{}dec_net.final_layer.linear.weight'.format(key_prefix)].shape[0]
# patch_size: infer from decoder final layer output matching x_embedder input
# in_channels: infer from dec_net input_embedder (in_features = dec_in_ch + max_freqs^2)
embedder_w = state_dict['{}dec_net.input_embedder.embedder.0.weight'.format(key_prefix)]
dec_in_ch = dec_out # decoder in == decoder out (same pixel space)
dit_config["patch_size"] = round((x_emb_in / 3) ** 0.5) # assume RGB (in_channels=3)
dit_config["in_channels"] = 3
dit_config["decoder_in_channels"] = dec_in_ch
dit_config["decoder_hidden_size"] = state_dict[dec_cond_key].shape[0]
dit_config["decoder_num_res_blocks"] = count_blocks(
state_dict_keys, '{}dec_net.res_blocks.'.format(key_prefix) + '{}.'
)
dit_config["decoder_max_freqs"] = int((embedder_w.shape[1] - dec_in_ch) ** 0.5)
if '{}__x0__'.format(key_prefix) in state_dict_keys:
dit_config["use_x0"] = True
return dit_config
if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1
@@ -498,6 +521,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["model_type"] = "humo"
elif '{}face_adapter.fuser_blocks.0.k_norm.weight'.format(key_prefix) in state_dict_keys:
dit_config["model_type"] = "animate"
elif '{}patch_embedding_pose.weight'.format(key_prefix) in state_dict_keys:
dit_config["model_type"] = "scail"
else:
if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys:
dit_config["model_type"] = "i2v"
@@ -531,8 +556,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
return dit_config
if f"{key_prefix}t_embedder.mlp.2.weight" in state_dict_keys: # Hunyuan 3D 2.1
if f"{key_prefix}t_embedder.mlp.2.weight" in state_dict_keys and f"{key_prefix}blocks.0.attn1.k_norm.weight" in state_dict_keys: # Hunyuan 3D 2.1
dit_config = {}
dit_config["image_model"] = "hunyuan3d2_1"
dit_config["in_channels"] = state_dict[f"{key_prefix}x_embedder.weight"].shape[1]
@@ -1053,6 +1077,13 @@ def convert_diffusers_mmdit(state_dict, output_prefix=""):
elif 'adaln_single.emb.timestep_embedder.linear_1.bias' in state_dict and 'pos_embed.proj.bias' in state_dict: # PixArt
num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.')
sd_map = comfy.utils.pixart_to_diffusers({"depth": num_blocks}, output_prefix=output_prefix)
elif 'noise_refiner.0.attention.norm_k.weight' in state_dict:
n_layers = count_blocks(state_dict, 'layers.{}.')
dim = state_dict['noise_refiner.0.attention.to_k.weight'].shape[0]
sd_map = comfy.utils.z_image_to_diffusers({"n_layers": n_layers, "dim": dim}, output_prefix=output_prefix)
for k in state_dict: # For zeta chroma
if k not in sd_map:
sd_map[k] = k
elif 'x_embedder.weight' in state_dict: #Flux
depth = count_blocks(state_dict, 'transformer_blocks.{}.')
depth_single_blocks = count_blocks(state_dict, 'single_transformer_blocks.{}.')

View File

@@ -180,6 +180,14 @@ def is_ixuca():
return True
return False
def is_wsl():
version = platform.uname().release
if version.endswith("-Microsoft"):
return True
elif version.endswith("microsoft-standard-WSL2"):
return True
return False
def get_torch_device():
global directml_enabled
global cpu_state
@@ -350,7 +358,7 @@ AMD_ENABLE_MIOPEN_ENV = 'COMFYUI_ENABLE_MIOPEN'
try:
if is_amd():
arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName
arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName.split(':')[0]
if not (any((a in arch) for a in AMD_RDNA2_AND_OLDER_ARCH)):
if os.getenv(AMD_ENABLE_MIOPEN_ENV) != '1':
torch.backends.cudnn.enabled = False # Seems to improve things a lot on AMD
@@ -378,7 +386,7 @@ try:
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
if aotriton_supported(arch): # AMD efficient attention implementation depends on aotriton.
if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much
if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950
if any((a in arch) for a in ["gfx90a", "gfx942", "gfx950", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950
ENABLE_PYTORCH_ATTENTION = True
if rocm_version >= (7, 0):
if any((a in arch) for a in ["gfx1200", "gfx1201"]):
@@ -570,13 +578,7 @@ class LoadedModel:
self._patcher_finalizer.detach()
def is_dead(self):
# Model is dead if the weakref to model has been garbage collected
# This can happen with ModelPatcherProxy objects between isolated workflows
if self.model is None:
return True
if self.real_model is None:
return False
return self.real_model() is None
return self.real_model() is not None and self.model is None
def use_more_memory(extra_memory, loaded_models, device):
@@ -622,7 +624,6 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_
unloaded_model = []
can_unload = []
unloaded_models = []
isolation_active = os.environ.get("PYISOLATE_ISOLATION_ACTIVE") == "1"
for i in range(len(current_loaded_models) -1, -1, -1):
shift_model = current_loaded_models[i]
@@ -631,17 +632,6 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_
can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
shift_model.currently_used = False
if can_unload and isolation_active:
try:
from pyisolate import flush_tensor_keeper # type: ignore[attr-defined]
except Exception:
flush_tensor_keeper = None
if callable(flush_tensor_keeper):
flushed = flush_tensor_keeper()
if flushed > 0:
logging.debug("][ MM:tensor_keeper_flush | released=%d", flushed)
gc.collect()
for x in sorted(can_unload):
i = x[-1]
memory_to_free = 1e32
@@ -649,12 +639,11 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_
if not DISABLE_SMART_MEMORY:
memory_to_free = memory_required - get_free_memory(device)
ram_to_free = ram_required - get_free_ram()
if current_loaded_models[i].model.is_dynamic() and for_dynamic:
#don't actually unload dynamic models for the sake of other dynamic models
#as that works on-demand.
memory_required -= current_loaded_models[i].model.loaded_size()
memory_to_free = 0
if current_loaded_models[i].model.is_dynamic() and for_dynamic:
#don't actually unload dynamic models for the sake of other dynamic models
#as that works on-demand.
memory_required -= current_loaded_models[i].model.loaded_size()
memory_to_free = 0
if memory_to_free > 0 and current_loaded_models[i].model_unload(memory_to_free):
logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
unloaded_model.append(i)
@@ -663,13 +652,7 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_
current_loaded_models[i].model.partially_unload_ram(ram_to_free)
for i in sorted(unloaded_model, reverse=True):
unloaded = current_loaded_models.pop(i)
model_obj = unloaded.model
if model_obj is not None:
cleanup = getattr(model_obj, "cleanup", None)
if callable(cleanup):
cleanup()
unloaded_models.append(unloaded)
unloaded_models.append(current_loaded_models.pop(i))
if len(unloaded_model) > 0:
soft_empty_cache()
@@ -791,28 +774,25 @@ def loaded_models(only_currently_used=False):
def cleanup_models_gc():
do_gc = False
reset_cast_buffers()
dead_found = False
for i in range(len(current_loaded_models)):
if current_loaded_models[i].is_dead():
dead_found = True
cur = current_loaded_models[i]
if cur.is_dead():
logging.info("Potential memory leak detected with model {}, doing a full garbage collect, for maximum performance avoid circular references in the model code.".format(cur.real_model().__class__.__name__))
do_gc = True
break
if dead_found:
logging.info("Potential memory leak detected with model NoneType, doing a full garbage collect, for maximum performance avoid circular references in the model code.")
if do_gc:
gc.collect()
soft_empty_cache()
for i in range(len(current_loaded_models) - 1, -1, -1):
for i in range(len(current_loaded_models)):
cur = current_loaded_models[i]
if cur.is_dead():
logging.warning("WARNING, memory leak with model NoneType. Please make sure it is not being referenced from somewhere.")
leaked = current_loaded_models.pop(i)
model_obj = getattr(leaked, "model", None)
if model_obj is not None:
cleanup = getattr(model_obj, "cleanup", None)
if callable(cleanup):
cleanup()
logging.warning("WARNING, memory leak with model {}. Please make sure it is not being referenced from somewhere.".format(cur.real_model().__class__.__name__))
def archive_model_dtypes(model):
@@ -829,11 +809,6 @@ def cleanup_models():
for i in to_delete:
x = current_loaded_models.pop(i)
model_obj = getattr(x, "model", None)
if model_obj is not None:
cleanup = getattr(model_obj, "cleanup", None)
if callable(cleanup):
cleanup()
del x
def dtype_size(dtype):

View File

@@ -308,15 +308,22 @@ class ModelPatcher:
def get_free_memory(self, device):
return comfy.model_management.get_free_memory(device)
def clone(self, disable_dynamic=False):
def get_clone_model_override(self):
return self.model, (self.backup, self.object_patches_backup, self.pinned)
def clone(self, disable_dynamic=False, model_override=None):
class_ = self.__class__
model = self.model
if self.is_dynamic() and disable_dynamic:
class_ = ModelPatcher
temp_model_patcher = self.cached_patcher_init[0](*self.cached_patcher_init[1], disable_dynamic=True)
model = temp_model_patcher.model
if model_override is None:
if self.cached_patcher_init is None:
raise RuntimeError("Cannot create non-dynamic delegate: cached_patcher_init is not initialized.")
temp_model_patcher = self.cached_patcher_init[0](*self.cached_patcher_init[1], disable_dynamic=True)
model_override = temp_model_patcher.get_clone_model_override()
if model_override is None:
model_override = self.get_clone_model_override()
n = class_(model, self.load_device, self.offload_device, self.model_size(), weight_inplace_update=self.weight_inplace_update)
n = class_(model_override[0], self.load_device, self.offload_device, self.model_size(), weight_inplace_update=self.weight_inplace_update)
n.patches = {}
for k in self.patches:
n.patches[k] = self.patches[k][:]
@@ -325,13 +332,12 @@ class ModelPatcher:
n.object_patches = self.object_patches.copy()
n.weight_wrapper_patches = self.weight_wrapper_patches.copy()
n.model_options = comfy.utils.deepcopy_list_dict(self.model_options)
n.backup = self.backup
n.object_patches_backup = self.object_patches_backup
n.parent = self
n.pinned = self.pinned
n.force_cast_weights = self.force_cast_weights
n.backup, n.object_patches_backup, n.pinned = model_override[1]
# attachments
n.attachments = {}
for k in self.attachments:
@@ -1435,6 +1441,7 @@ class ModelPatcherDynamic(ModelPatcher):
del self.model.model_loaded_weight_memory
if not hasattr(self.model, "dynamic_vbars"):
self.model.dynamic_vbars = {}
self.non_dynamic_delegate_model = None
assert load_device is not None
def is_dynamic(self):
@@ -1669,4 +1676,10 @@ class ModelPatcherDynamic(ModelPatcher):
def unpatch_hooks(self, whitelist_keys_set: set[str]=None) -> None:
pass
def get_non_dynamic_delegate(self):
model_patcher = self.clone(disable_dynamic=True, model_override=self.non_dynamic_delegate_model)
self.non_dynamic_delegate_model = model_patcher.get_clone_model_override()
return model_patcher
CoreModelPatcher = ModelPatcher

View File

@@ -66,6 +66,18 @@ def convert_cond(cond):
out.append(temp)
return out
def cond_has_hooks(cond):
for c in cond:
temp = c[1]
if "hooks" in temp:
return True
if "control" in temp:
control = temp["control"]
extra_hooks = control.get_extra_hooks()
if len(extra_hooks) > 0:
return True
return False
def get_additional_models(conds, dtype):
"""loads additional models in conditioning"""
cnets: list[ControlBase] = []

View File

@@ -11,14 +11,12 @@ from functools import partial
import collections
import math
import logging
import os
import comfy.sampler_helpers
import comfy.model_patcher
import comfy.patcher_extension
import comfy.hooks
import comfy.context_windows
import comfy.utils
from comfy.cli_args import args
import scipy.stats
import numpy
@@ -215,7 +213,6 @@ def _calc_cond_batch_outer(model: BaseModel, conds: list[list[dict]], x_in: torc
return executor.execute(model, conds, x_in, timestep, model_options)
def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
isolation_active = args.use_process_isolation or os.environ.get("PYISOLATE_ISOLATION_ACTIVE") == "1"
out_conds = []
out_counts = []
# separate conds by matching hooks
@@ -297,17 +294,9 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
patches = p.patches
batch_chunks = len(cond_or_uncond)
if isolation_active:
target_device = model.load_device if hasattr(model, "load_device") else input_x[0].device
input_x = torch.cat(input_x).to(target_device)
else:
input_x = torch.cat(input_x)
input_x = torch.cat(input_x)
c = cond_cat(c)
if isolation_active:
timestep_ = torch.cat([timestep] * batch_chunks).to(target_device)
mult = [m.to(target_device) if hasattr(m, "to") else m for m in mult]
else:
timestep_ = torch.cat([timestep] * batch_chunks)
timestep_ = torch.cat([timestep] * batch_chunks)
transformer_options = model.current_patcher.apply_hooks(hooks=hooks)
if 'transformer_options' in model_options:
@@ -338,17 +327,9 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
for o in range(batch_chunks):
cond_index = cond_or_uncond[o]
a = area[o]
out_t = output[o]
mult_t = mult[o]
if isolation_active:
target_dev = out_conds[cond_index].device
if hasattr(out_t, "device") and out_t.device != target_dev:
out_t = out_t.to(target_dev)
if hasattr(mult_t, "device") and mult_t.device != target_dev:
mult_t = mult_t.to(target_dev)
if a is None:
out_conds[cond_index] += out_t * mult_t
out_counts[cond_index] += mult_t
out_conds[cond_index] += output[o] * mult[o]
out_counts[cond_index] += mult[o]
else:
out_c = out_conds[cond_index]
out_cts = out_counts[cond_index]
@@ -356,8 +337,8 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
for i in range(dims):
out_c = out_c.narrow(i + 2, a[i + dims], a[i])
out_cts = out_cts.narrow(i + 2, a[i + dims], a[i])
out_c += out_t * mult_t
out_cts += mult_t
out_c += output[o] * mult[o]
out_cts += mult[o]
for i in range(len(out_conds)):
out_conds[i] /= out_counts[i]
@@ -411,31 +392,14 @@ class KSamplerX0Inpaint:
self.inner_model = model
self.sigmas = sigmas
def __call__(self, x, sigma, denoise_mask, model_options={}, seed=None):
isolation_active = args.use_process_isolation or os.environ.get("PYISOLATE_ISOLATION_ACTIVE") == "1"
if denoise_mask is not None:
if isolation_active and denoise_mask.device != x.device:
denoise_mask = denoise_mask.to(x.device)
if "denoise_mask_function" in model_options:
denoise_mask = model_options["denoise_mask_function"](sigma, denoise_mask, extra_options={"model": self.inner_model, "sigmas": self.sigmas})
latent_mask = 1. - denoise_mask
if isolation_active:
latent_image = self.latent_image
if hasattr(latent_image, "device") and latent_image.device != x.device:
latent_image = latent_image.to(x.device)
scaled = self.inner_model.inner_model.scale_latent_inpaint(x=x, sigma=sigma, noise=self.noise, latent_image=latent_image)
if hasattr(scaled, "device") and scaled.device != x.device:
scaled = scaled.to(x.device)
else:
scaled = self.inner_model.inner_model.scale_latent_inpaint(
x=x, sigma=sigma, noise=self.noise, latent_image=self.latent_image
)
x = x * denoise_mask + scaled * latent_mask
x = x * denoise_mask + self.inner_model.inner_model.scale_latent_inpaint(x=x, sigma=sigma, noise=self.noise, latent_image=self.latent_image) * latent_mask
out = self.inner_model(x, sigma, model_options=model_options, seed=seed)
if denoise_mask is not None:
latent_image = self.latent_image
if isolation_active and hasattr(latent_image, "device") and latent_image.device != out.device:
latent_image = latent_image.to(out.device)
out = out * denoise_mask + latent_image * latent_mask
out = out * denoise_mask + self.latent_image * latent_mask
return out
def simple_scheduler(model_sampling, steps):
@@ -982,6 +946,8 @@ class CFGGuider:
def inner_set_conds(self, conds):
for k in conds:
if self.model_patcher.is_dynamic() and comfy.sampler_helpers.cond_has_hooks(conds[k]):
self.model_patcher = self.model_patcher.get_non_dynamic_delegate()
self.original_conds[k] = comfy.sampler_helpers.convert_cond(conds[k])
def __call__(self, *args, **kwargs):

View File

@@ -204,7 +204,7 @@ def load_bypass_lora_for_models(model, clip, lora, strength_model, strength_clip
class CLIP:
def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0, state_dict=[], model_options={}):
def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0, state_dict=[], model_options={}, disable_dynamic=False):
if no_init:
return
params = target.params.copy()
@@ -233,7 +233,8 @@ class CLIP:
model_management.archive_model_dtypes(self.cond_stage_model)
self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
self.patcher = comfy.model_patcher.CoreModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
ModelPatcher = comfy.model_patcher.ModelPatcher if disable_dynamic else comfy.model_patcher.CoreModelPatcher
self.patcher = ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
#Match torch.float32 hardcode upcast in TE implemention
self.patcher.set_model_compute_dtype(torch.float32)
self.patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram
@@ -267,9 +268,9 @@ class CLIP:
logging.info("CLIP/text encoder model load device: {}, offload device: {}, current: {}, dtype: {}".format(load_device, offload_device, params['device'], dtype))
self.tokenizer_options = {}
def clone(self):
def clone(self, disable_dynamic=False):
n = CLIP(no_init=True)
n.patcher = self.patcher.clone()
n.patcher = self.patcher.clone(disable_dynamic=disable_dynamic)
n.cond_stage_model = self.cond_stage_model
n.tokenizer = self.tokenizer
n.layer_idx = self.layer_idx
@@ -1164,14 +1165,21 @@ class CLIPType(Enum):
LONGCAT_IMAGE = 26
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
def load_clip_model_patcher(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}, disable_dynamic=False):
clip = load_clip(ckpt_paths, embedding_directory, clip_type, model_options, disable_dynamic)
return clip.patcher
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}, disable_dynamic=False):
clip_data = []
for p in ckpt_paths:
sd, metadata = comfy.utils.load_torch_file(p, safe_load=True, return_metadata=True)
if model_options.get("custom_operations", None) is None:
sd, metadata = comfy.utils.convert_old_quants(sd, model_prefix="", metadata=metadata)
clip_data.append(sd)
return load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options)
clip = load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options, disable_dynamic=disable_dynamic)
clip.patcher.cached_patcher_init = (load_clip_model_patcher, (ckpt_paths, embedding_directory, clip_type, model_options))
return clip
class TEModel(Enum):
@@ -1276,7 +1284,7 @@ def llama_detect(clip_data):
return {}
def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}, disable_dynamic=False):
clip_data = state_dicts
class EmptyClass:
@@ -1496,7 +1504,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
parameters += comfy.utils.calculate_parameters(c)
tokenizer_data, model_options = comfy.text_encoders.long_clipl.model_options_long_clip(c, tokenizer_data, model_options)
clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, tokenizer_data=tokenizer_data, state_dict=clip_data, model_options=model_options)
clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, tokenizer_data=tokenizer_data, state_dict=clip_data, model_options=model_options, disable_dynamic=disable_dynamic)
return clip
def load_gligen(ckpt_path):
@@ -1541,8 +1549,10 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata, disable_dynamic=disable_dynamic)
if out is None:
raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(ckpt_path, model_detection_error_hint(ckpt_path, sd)))
if output_model:
if output_model and out[0] is not None:
out[0].cached_patcher_init = (load_checkpoint_guess_config_model_only, (ckpt_path, embedding_directory, model_options, te_model_options))
if output_clip and out[1] is not None:
out[1].patcher.cached_patcher_init = (load_checkpoint_guess_config_clip_only, (ckpt_path, embedding_directory, model_options, te_model_options))
return out
def load_checkpoint_guess_config_model_only(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False):
@@ -1553,6 +1563,14 @@ def load_checkpoint_guess_config_model_only(ckpt_path, embedding_directory=None,
disable_dynamic=disable_dynamic)
return model
def load_checkpoint_guess_config_clip_only(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False):
_, clip, *_ = load_checkpoint_guess_config(ckpt_path, False, True, False,
embedding_directory=embedding_directory, output_model=False,
model_options=model_options,
te_model_options=te_model_options,
disable_dynamic=disable_dynamic)
return clip.patcher
def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, metadata=None, disable_dynamic=False):
clip = None
clipvision = None
@@ -1638,7 +1656,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
clip_sd = model_config.process_clip_state_dict(sd)
if len(clip_sd) > 0:
parameters = comfy.utils.calculate_parameters(clip_sd)
clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters, state_dict=clip_sd, model_options=te_model_options)
clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters, state_dict=clip_sd, model_options=te_model_options, disable_dynamic=disable_dynamic)
else:
logging.warning("no CLIP/text encoder weights in checkpoint, the text encoder model will not be loaded.")

View File

@@ -1118,6 +1118,20 @@ class ZImage(Lumina2):
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_4b.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.z_image.ZImageTokenizer, comfy.text_encoders.z_image.te(**hunyuan_detect))
class ZImagePixelSpace(ZImage):
unet_config = {
"image_model": "zimage_pixel",
}
# Pixel-space model: no spatial compression, operates on raw RGB patches.
latent_format = latent_formats.ZImagePixelSpace
# Much lower memory than latent-space models (no VAE, small patches).
memory_usage_factor = 0.05 # TODO: figure out the optimal value for this.
def get_model(self, state_dict, prefix="", device=None):
return model_base.ZImagePixelSpace(self, device=device)
class WAN21_T2V(supported_models_base.BASE):
unet_config = {
"image_model": "wan2.1",
@@ -1268,6 +1282,16 @@ class WAN21_FlowRVS(WAN21_T2V):
out = model_base.WAN21_FlowRVS(self, image_to_video=True, device=device)
return out
class WAN21_SCAIL(WAN21_T2V):
unet_config = {
"image_model": "wan2.1",
"model_type": "scail",
}
def get_model(self, state_dict, prefix="", device=None):
out = model_base.WAN21_SCAIL(self, image_to_video=False, device=device)
return out
class Hunyuan3Dv2(supported_models_base.BASE):
unet_config = {
"image_model": "hunyuan3d2",
@@ -1710,6 +1734,6 @@ class LongCatImage(supported_models_base.BASE):
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.longcat_image.LongCatImageTokenizer, comfy.text_encoders.longcat_image.te(**hunyuan_detect))
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima]
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima]
models += [SVD_img2vid]

View File

@@ -789,8 +789,6 @@ class GeminiImage2(IO.ComfyNode):
validate_string(prompt, strip_whitespace=True, min_length=1)
if model == "Nano Banana 2 (Gemini 3.1 Flash Image)":
model = "gemini-3.1-flash-image-preview"
if response_modalities == "IMAGE+TEXT":
raise ValueError("IMAGE+TEXT is not currently available for the Nano Banana 2 model.")
parts: list[GeminiPart] = [GeminiPart(text=prompt)]
if images is not None:
@@ -895,7 +893,7 @@ class GeminiNanoBanana2(IO.ComfyNode):
),
IO.Combo.Input(
"response_modalities",
options=["IMAGE"],
options=["IMAGE", "IMAGE+TEXT"],
advanced=True,
),
IO.Combo.Input(
@@ -925,6 +923,7 @@ class GeminiNanoBanana2(IO.ComfyNode):
],
outputs=[
IO.Image.Output(),
IO.String.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,

View File

@@ -20,7 +20,7 @@ class JobStatus:
# Media types that can be previewed in the frontend
PREVIEWABLE_MEDIA_TYPES = frozenset({'images', 'video', 'audio', '3d'})
PREVIEWABLE_MEDIA_TYPES = frozenset({'images', 'video', 'audio', '3d', 'text'})
# 3D file extensions for preview fallback (no dedicated media_type exists)
THREE_D_EXTENSIONS = frozenset({'.obj', '.fbx', '.gltf', '.glb', '.usdz'})
@@ -75,6 +75,23 @@ def normalize_outputs(outputs: dict) -> dict:
normalized[node_id] = normalized_node
return normalized
# Text preview truncation limit (1024 characters) to prevent preview_output bloat
TEXT_PREVIEW_MAX_LENGTH = 1024
def _create_text_preview(value: str) -> dict:
"""Create a text preview dict with optional truncation.
Returns:
dict with 'content' and optionally 'truncated' flag
"""
if len(value) <= TEXT_PREVIEW_MAX_LENGTH:
return {'content': value}
return {
'content': value[:TEXT_PREVIEW_MAX_LENGTH],
'truncated': True
}
def _extract_job_metadata(extra_data: dict) -> tuple[Optional[int], Optional[str]]:
"""Extract create_time and workflow_id from extra_data.
@@ -221,23 +238,43 @@ def get_outputs_summary(outputs: dict) -> tuple[int, Optional[dict]]:
continue
for item in items:
normalized = normalize_output_item(item)
if normalized is None:
continue
if not isinstance(item, dict):
# Handle text outputs (non-dict items like strings or tuples)
normalized = normalize_output_item(item)
if normalized is None:
# Not a 3D file string — check for text preview
if media_type == 'text':
count += 1
if preview_output is None:
if isinstance(item, tuple):
text_value = item[0] if item else ''
else:
text_value = str(item)
text_preview = _create_text_preview(text_value)
enriched = {
**text_preview,
'nodeId': node_id,
'mediaType': media_type
}
if fallback_preview is None:
fallback_preview = enriched
continue
# normalize_output_item returned a dict (e.g. 3D file)
item = normalized
count += 1
if preview_output is not None:
continue
if isinstance(normalized, dict) and is_previewable(media_type, normalized):
if is_previewable(media_type, item):
enriched = {
**normalized,
**item,
'nodeId': node_id,
}
if 'mediaType' not in normalized:
if 'mediaType' not in item:
enriched['mediaType'] = media_type
if normalized.get('type') == 'output':
if item.get('type') == 'output':
preview_output = enriched
elif fallback_preview is None:
fallback_preview = enriched

View File

@@ -96,7 +96,7 @@ class VAEEncodeAudio(IO.ComfyNode):
def vae_decode_audio(vae, samples, tile=None, overlap=None):
if tile is not None:
audio = vae.decode_tiled(samples["samples"], tile_y=tile, overlap=overlap).movedim(-1, 1)
audio = vae.decode_tiled(samples["samples"], tile_x=tile, tile_y=tile, overlap=overlap).movedim(-1, 1)
else:
audio = vae.decode(samples["samples"]).movedim(-1, 1)

View File

@@ -248,7 +248,7 @@ class SetClipHooks:
def apply_hooks(self, clip: CLIP, schedule_clip: bool, apply_to_conds: bool, hooks: comfy.hooks.HookGroup=None):
if hooks is not None:
clip = clip.clone()
clip = clip.clone(disable_dynamic=True)
if apply_to_conds:
clip.apply_hooks_to_conds = hooks
clip.patcher.forced_hooks = hooks.clone()

View File

@@ -10,7 +10,7 @@ class Mahiro(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="Mahiro",
display_name="Mahiro CFG",
display_name="Positive-Biased Guidance",
category="_for_testing",
description="Modify the guidance to scale more on the 'direction' of the positive prompt rather than the difference between the negative prompt.",
inputs=[
@@ -20,27 +20,35 @@ class Mahiro(io.ComfyNode):
io.Model.Output(display_name="patched_model"),
],
is_experimental=True,
search_aliases=[
"mahiro",
"mahiro cfg",
"similarity-adaptive guidance",
"positive-biased cfg",
],
)
@classmethod
def execute(cls, model) -> io.NodeOutput:
m = model.clone()
def mahiro_normd(args):
scale: float = args['cond_scale']
cond_p: torch.Tensor = args['cond_denoised']
uncond_p: torch.Tensor = args['uncond_denoised']
#naive leap
scale: float = args["cond_scale"]
cond_p: torch.Tensor = args["cond_denoised"]
uncond_p: torch.Tensor = args["uncond_denoised"]
# naive leap
leap = cond_p * scale
#sim with uncond leap
# sim with uncond leap
u_leap = uncond_p * scale
cfg = args["denoised"]
merge = (leap + cfg) / 2
normu = torch.sqrt(u_leap.abs()) * u_leap.sign()
normm = torch.sqrt(merge.abs()) * merge.sign()
sim = F.cosine_similarity(normu, normm).mean()
simsc = 2 * (sim+1)
wm = (simsc*cfg + (4-simsc)*leap) / 4
simsc = 2 * (sim + 1)
wm = (simsc * cfg + (4 - simsc) * leap) / 4
return wm
m.set_model_sampler_post_cfg_function(mahiro_normd)
return io.NodeOutput(m)

View File

@@ -1456,6 +1456,63 @@ class WanInfiniteTalkToVideo(io.ComfyNode):
return io.NodeOutput(model_patched, positive, negative, out_latent, trim_image)
class WanSCAILToVideo(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="WanSCAILToVideo",
category="conditioning/video_models",
inputs=[
io.Conditioning.Input("positive"),
io.Conditioning.Input("negative"),
io.Vae.Input("vae"),
io.Int.Input("width", default=512, min=32, max=nodes.MAX_RESOLUTION, step=32),
io.Int.Input("height", default=896, min=32, max=nodes.MAX_RESOLUTION, step=32),
io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
io.Int.Input("batch_size", default=1, min=1, max=4096),
io.ClipVisionOutput.Input("clip_vision_output", optional=True),
io.Image.Input("reference_image", optional=True),
io.Image.Input("pose_video", optional=True, tooltip="Video used for pose conditioning. Will be downscaled to half the resolution of the main video."),
io.Float.Input("pose_strength", default=1.0, min=0.0, max=10.0, step=0.01, tooltip="Strength of the pose latent."),
io.Float.Input("pose_start", default=0.0, min=0.0, max=1.0, step=0.01, tooltip="Start step to use pose conditioning."),
io.Float.Input("pose_end", default=1.0, min=0.0, max=1.0, step=0.01, tooltip="End step to use pose conditioning."),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
io.Latent.Output(display_name="latent", tooltip="Empty latent of the generation size."),
],
is_experimental=True,
)
@classmethod
def execute(cls, positive, negative, vae, width, height, length, batch_size, pose_strength, pose_start, pose_end, reference_image=None, clip_vision_output=None, pose_video=None) -> io.NodeOutput:
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
ref_latent = None
if reference_image is not None:
reference_image = comfy.utils.common_upscale(reference_image[:1].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
ref_latent = vae.encode(reference_image[:, :, :, :3])
if ref_latent is not None:
positive = node_helpers.conditioning_set_values(positive, {"reference_latents": [ref_latent]}, append=True)
negative = node_helpers.conditioning_set_values(negative, {"reference_latents": [torch.zeros_like(ref_latent)]}, append=True)
if clip_vision_output is not None:
positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})
if pose_video is not None:
pose_video = comfy.utils.common_upscale(pose_video[:length].movedim(-1, 1), width // 2, height // 2, "area", "center").movedim(1, -1)
pose_video_latent = vae.encode(pose_video[:, :, :, :3]) * pose_strength
positive = node_helpers.conditioning_set_values_with_timestep_range(positive, {"pose_video_latent": pose_video_latent}, pose_start, pose_end)
negative = node_helpers.conditioning_set_values_with_timestep_range(negative, {"pose_video_latent": pose_video_latent}, pose_start, pose_end)
out_latent = {}
out_latent["samples"] = latent
return io.NodeOutput(positive, negative, out_latent)
class WanExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
@@ -1476,6 +1533,7 @@ class WanExtension(ComfyExtension):
WanAnimateToVideo,
Wan22ImageToVideoLatent,
WanInfiniteTalkToVideo,
WanSCAILToVideo,
]
async def comfy_entrypoint() -> WanExtension:

View File

@@ -92,7 +92,7 @@ if args.cuda_malloc:
env_var = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', None)
if env_var is None:
env_var = "backend:cudaMallocAsync"
elif not args.use_process_isolation:
else:
env_var += ",backend:cudaMallocAsync"
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = env_var

View File

@@ -1,9 +1,7 @@
import copy
import gc
import heapq
import inspect
import logging
import os
import sys
import threading
import time
@@ -263,31 +261,20 @@ async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, f
pre_execute_cb(index)
# V3
if isinstance(obj, _ComfyNodeInternal) or (is_class(obj) and issubclass(obj, _ComfyNodeInternal)):
# Check for isolated node - skip validation and class cloning
if hasattr(obj, "_pyisolate_extension"):
# Isolated Node: The stub is just a proxy; real validation happens in child process
if v3_data is not None:
inputs = _io.build_nested_inputs(inputs, v3_data)
# Inject hidden inputs so they're available in the isolated child process
inputs.update(v3_data.get("hidden_inputs", {}))
f = getattr(obj, func)
# Standard V3 Node (Existing Logic)
# if is just a class, then assign no state, just create clone
if is_class(obj):
type_obj = obj
obj.VALIDATE_CLASS()
class_clone = obj.PREPARE_CLASS_CLONE(v3_data)
# otherwise, use class instance to populate/reuse some fields
else:
# if is just a class, then assign no resources or state, just create clone
if is_class(obj):
type_obj = obj
obj.VALIDATE_CLASS()
class_clone = obj.PREPARE_CLASS_CLONE(v3_data)
# otherwise, use class instance to populate/reuse some fields
else:
type_obj = type(obj)
type_obj.VALIDATE_CLASS()
class_clone = type_obj.PREPARE_CLASS_CLONE(v3_data)
f = make_locked_method_func(type_obj, func, class_clone)
# in case of dynamic inputs, restructure inputs to expected nested dict
if v3_data is not None:
inputs = _io.build_nested_inputs(inputs, v3_data)
type_obj = type(obj)
type_obj.VALIDATE_CLASS()
class_clone = type_obj.PREPARE_CLASS_CLONE(v3_data)
f = make_locked_method_func(type_obj, func, class_clone)
# in case of dynamic inputs, restructure inputs to expected nested dict
if v3_data is not None:
inputs = _io.build_nested_inputs(inputs, v3_data)
# V1
else:
f = getattr(obj, func)
@@ -549,14 +536,6 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
tasks = [x for x in output_data if isinstance(x, asyncio.Task)]
await asyncio.gather(*tasks, return_exceptions=True)
unblock()
# Keep isolation node execution deterministic by default, but allow
# opt-out for diagnostics.
isolation_sequential = os.environ.get("COMFY_ISOLATE_SEQUENTIAL", "1").lower() in ("1", "true", "yes")
if args.use_process_isolation and isolation_sequential:
await await_completion()
return await execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_outputs)
asyncio.create_task(await_completion())
return (ExecutionResult.PENDING, None, None)
if len(output_ui) > 0:
@@ -668,22 +647,6 @@ class PromptExecutor:
self.status_messages = []
self.success = True
async def _notify_execution_graph_safe(self, class_types: set[str], *, fail_loud: bool = False) -> None:
try:
from comfy.isolation import notify_execution_graph
await notify_execution_graph(class_types)
except Exception:
if fail_loud:
raise
logging.debug("][ EX:notify_execution_graph failed", exc_info=True)
async def _flush_running_extensions_transport_state_safe(self) -> None:
try:
from comfy.isolation import flush_running_extensions_transport_state
await flush_running_extensions_transport_state()
except Exception:
logging.debug("][ EX:flush_running_extensions_transport_state failed", exc_info=True)
def add_message(self, event, data: dict, broadcast: bool):
data = {
**data,
@@ -725,17 +688,6 @@ class PromptExecutor:
asyncio.run(self.execute_async(prompt, prompt_id, extra_data, execute_outputs))
async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
# Update RPC event loops for all isolated extensions
# This is critical for serial workflow execution - each asyncio.run() creates
# a new event loop, and RPC instances must be updated to use it
try:
from comfy.isolation import update_rpc_event_loops
update_rpc_event_loops()
except ImportError:
pass # Isolation not available
except Exception as e:
logging.getLogger(__name__).warning(f"Failed to update RPC event loops: {e}")
set_preview_method(extra_data.get("preview_method"))
nodes.interrupt_processing(False)
@@ -749,20 +701,6 @@ class PromptExecutor:
self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False)
with torch.inference_mode():
if args.use_process_isolation:
try:
# Boundary cleanup runs at the start of the next workflow in
# isolation mode, matching non-isolated "next prompt" timing.
self.caches = CacheSet(cache_type=self.cache_type, cache_args=self.cache_args)
await self._flush_running_extensions_transport_state_safe()
comfy.model_management.unload_all_models()
comfy.model_management.cleanup_models_gc()
comfy.model_management.cleanup_models()
gc.collect()
comfy.model_management.soft_empty_cache()
except Exception:
logging.debug("][ EX:isolation_boundary_cleanup_start failed", exc_info=True)
dynamic_prompt = DynamicPrompt(prompt)
reset_progress_state(prompt_id, dynamic_prompt)
add_progress_handler(WebUIProgressHandler(self.server))
@@ -789,13 +727,6 @@ class PromptExecutor:
for node_id in list(execute_outputs):
execution_list.add_node(node_id)
if args.use_process_isolation:
pending_class_types = set()
for node_id in execution_list.pendingNodes.keys():
class_type = dynamic_prompt.get_node(node_id)["class_type"]
pending_class_types.add(class_type)
await self._notify_execution_graph_safe(pending_class_types, fail_loud=True)
while not execution_list.is_empty():
node_id, error, ex = await execution_list.stage_node_execution()
if error is not None:
@@ -826,7 +757,6 @@ class PromptExecutor:
"outputs": ui_outputs,
"meta": meta_outputs,
}
comfy.model_management.cleanup_models_gc()
self.server.last_node_id = None
if comfy.model_management.DISABLE_SMART_MEMORY:
comfy.model_management.unload_all_models()

98
main.py
View File

@@ -1,21 +1,7 @@
import os
import sys
IS_PYISOLATE_CHILD = os.environ.get("PYISOLATE_CHILD") == "1"
if __name__ == "__main__" and IS_PYISOLATE_CHILD:
del os.environ["PYISOLATE_CHILD"]
IS_PYISOLATE_CHILD = False
CURRENT_DIR = os.path.dirname(os.path.realpath(__file__))
if CURRENT_DIR not in sys.path:
sys.path.insert(0, CURRENT_DIR)
IS_PRIMARY_PROCESS = (not IS_PYISOLATE_CHILD) and __name__ == "__main__"
import comfy.options
comfy.options.enable_args_parsing()
import os
import importlib.util
import folder_paths
import time
@@ -23,38 +9,24 @@ from comfy.cli_args import args, enables_dynamic_vram
from app.logger import setup_logger
from app.assets.scanner import seed_assets
import itertools
import utils.extra_config
import logging
import sys
from comfy_execution.progress import get_progress_state
from comfy_execution.utils import get_executing_context
from comfy_api import feature_flags
if '--use-process-isolation' in sys.argv:
from comfy.isolation import initialize_proxies
initialize_proxies()
import comfy_aimdo.control
# Explicitly register the ComfyUI adapter for pyisolate (v1.0 architecture)
try:
import pyisolate
from comfy.isolation.adapter import ComfyUIAdapter
pyisolate.register_adapter(ComfyUIAdapter())
logging.info("PyIsolate adapter registered: comfyui")
except ImportError:
logging.warning("PyIsolate not installed or version too old for explicit registration")
except Exception as e:
logging.error(f"Failed to register PyIsolate adapter: {e}")
if enables_dynamic_vram():
comfy_aimdo.control.init()
if not IS_PYISOLATE_CHILD:
if 'PYTORCH_CUDA_ALLOC_CONF' not in os.environ:
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'backend:native'
if not IS_PYISOLATE_CHILD:
from comfy_execution.progress import get_progress_state
from comfy_execution.utils import get_executing_context
from comfy_api import feature_flags
if IS_PRIMARY_PROCESS:
if __name__ == "__main__":
#NOTE: These do not do anything on core ComfyUI, they are for custom nodes.
os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'
os.environ['DO_NOT_TRACK'] = '1'
if not IS_PYISOLATE_CHILD:
setup_logger(log_level=args.verbose, use_stdout=args.log_stdout)
setup_logger(log_level=args.verbose, use_stdout=args.log_stdout)
if os.name == "nt":
os.environ['MIMALLOC_PURGE_DELAY'] = '0'
@@ -106,15 +78,14 @@ if args.enable_manager:
def apply_custom_paths():
from utils import extra_config # Deferred import - spawn re-runs main.py
# extra model paths
extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml")
if os.path.isfile(extra_model_paths_config_path):
extra_config.load_extra_path_config(extra_model_paths_config_path)
utils.extra_config.load_extra_path_config(extra_model_paths_config_path)
if args.extra_model_paths_config:
for config_path in itertools.chain(*args.extra_model_paths_config):
extra_config.load_extra_path_config(config_path)
utils.extra_config.load_extra_path_config(config_path)
# --output-directory, --input-directory, --user-directory
if args.output_directory:
@@ -187,16 +158,14 @@ def execute_prestartup_script():
else:
import_message = " (PRESTARTUP FAILED)"
logging.info("{:6.1f} seconds{}: {}".format(n[0], import_message, n[1]))
logging.info("")
logging.info("")
if not IS_PYISOLATE_CHILD:
apply_custom_paths()
apply_custom_paths()
if args.enable_manager and not IS_PYISOLATE_CHILD:
if args.enable_manager:
comfyui_manager.prestartup()
if not IS_PYISOLATE_CHILD:
execute_prestartup_script()
execute_prestartup_script()
# Main code
@@ -208,27 +177,22 @@ import gc
if 'torch' in sys.modules:
logging.warning("WARNING: Potential Error in code: Torch already imported, torch should never be imported before this point.")
import comfy_aimdo.control
if enables_dynamic_vram():
comfy_aimdo.control.init()
import comfy.utils
if not IS_PYISOLATE_CHILD:
import execution
import server
from protocol import BinaryEventTypes
import nodes
import comfy.model_management
import comfyui_version
import app.logger
import hook_breaker_ac10a0
import execution
import server
from protocol import BinaryEventTypes
import nodes
import comfy.model_management
import comfyui_version
import app.logger
import hook_breaker_ac10a0
import comfy.memory_management
import comfy.model_patcher
if enables_dynamic_vram():
if enables_dynamic_vram() and comfy.model_management.is_nvidia() and not comfy.model_management.is_wsl():
if comfy.model_management.torch_version_numeric < (2, 8):
logging.warning("Unsupported Pytorch detected. DynamicVRAM support requires Pytorch version 2.8 or later. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows")
elif comfy_aimdo.control.init_device(comfy.model_management.get_torch_device().index):
@@ -420,10 +384,6 @@ def start_comfyui(asyncio_loop=None):
asyncio.set_event_loop(asyncio_loop)
prompt_server = server.PromptServer(asyncio_loop)
if args.use_process_isolation:
from comfy.isolation import start_isolation_loading_early
start_isolation_loading_early(asyncio_loop)
if args.enable_manager and not args.disable_manager_ui:
comfyui_manager.start()
@@ -468,9 +428,7 @@ def start_comfyui(asyncio_loop=None):
if __name__ == "__main__":
# Running directly, just start ComfyUI.
logging.info("Python version: {}".format(sys.version))
if not IS_PYISOLATE_CHILD:
import comfyui_version
logging.info("ComfyUI version: {}".format(comfyui_version.__version__))
logging.info("ComfyUI version: {}".format(comfyui_version.__version__))
if sys.version_info.major == 3 and sys.version_info.minor < 10:
logging.warning("WARNING: You are using a python version older than 3.10, please upgrade to a newer one. 3.12 and above is recommended.")

View File

@@ -1,5 +1,6 @@
import hashlib
import torch
import logging
from comfy.cli_args import args
@@ -21,6 +22,36 @@ def conditioning_set_values(conditioning, values={}, append=False):
return c
def conditioning_set_values_with_timestep_range(conditioning, values={}, start_percent=0.0, end_percent=1.0):
"""
Apply values to conditioning only during [start_percent, end_percent], keeping the
original conditioning active outside that range. Respects existing per-entry ranges.
"""
if start_percent > end_percent:
logging.warning(f"start_percent ({start_percent}) must be <= end_percent ({end_percent})")
return conditioning
EPS = 1e-5 # the sampler gates entries with strict > / <, shift boundaries slightly to ensure only one conditioning is active per timestep
c = []
for t in conditioning:
cond_start = t[1].get("start_percent", 0.0)
cond_end = t[1].get("end_percent", 1.0)
intersect_start = max(start_percent, cond_start)
intersect_end = min(end_percent, cond_end)
if intersect_start >= intersect_end: # no overlap: emit unchanged
c.append(t)
continue
if intersect_start > cond_start: # part before the requested range
c.extend(conditioning_set_values([t], {"start_percent": cond_start, "end_percent": intersect_start - EPS}))
c.extend(conditioning_set_values([t], {**values, "start_percent": intersect_start, "end_percent": intersect_end}))
if intersect_end < cond_end: # part after the requested range
c.extend(conditioning_set_values([t], {"start_percent": intersect_end + EPS, "end_percent": cond_end}))
return c
def pillow(fn, arg):
prev_value = None
try:

View File

@@ -1925,7 +1925,6 @@ class ImageInvert:
class ImageBatch:
SEARCH_ALIASES = ["combine images", "merge images", "stack images"]
ESSENTIALS_CATEGORY = "Image Tools"
@classmethod
def INPUT_TYPES(s):
@@ -2307,27 +2306,6 @@ async def init_external_custom_nodes():
Returns:
None
"""
whitelist = set()
isolated_module_paths = set()
if args.use_process_isolation:
from pathlib import Path
from comfy.isolation import await_isolation_loading, get_claimed_paths
from comfy.isolation.host_policy import load_host_policy
# Load Global Host Policy
host_policy = load_host_policy(Path(folder_paths.base_path))
whitelist_dict = host_policy.get("whitelist", {})
# Normalize whitelist keys to lowercase for case-insensitive matching
# (matches ComfyUI-Manager's normalization: project.name.strip().lower())
whitelist = set(k.strip().lower() for k in whitelist_dict.keys())
logging.info(f"][ Loaded Whitelist: {len(whitelist)} nodes allowed.")
isolated_specs = await await_isolation_loading()
for spec in isolated_specs:
NODE_CLASS_MAPPINGS.setdefault(spec.node_name, spec.stub_class)
NODE_DISPLAY_NAME_MAPPINGS.setdefault(spec.node_name, spec.display_name)
isolated_module_paths = get_claimed_paths()
base_node_names = set(NODE_CLASS_MAPPINGS.keys())
node_paths = folder_paths.get_folder_paths("custom_nodes")
node_import_times = []
@@ -2351,16 +2329,6 @@ async def init_external_custom_nodes():
logging.info(f"Blocked by policy: {module_path}")
continue
if args.use_process_isolation:
if Path(module_path).resolve() in isolated_module_paths:
continue
# Tri-State Enforcement: If not Isolated (checked above), MUST be Whitelisted.
# Normalize to lowercase for case-insensitive matching (matches ComfyUI-Manager)
if possible_module.strip().lower() not in whitelist:
logging.warning(f"][ REJECTED: Node '{possible_module}' is blocked by security policy (not whitelisted/isolated).")
continue
time_before = time.perf_counter()
success = await load_custom_node(module_path, base_node_names, module_parent="custom_nodes")
node_import_times.append((time.perf_counter() - time_before, module_path, success))
@@ -2375,14 +2343,6 @@ async def init_external_custom_nodes():
logging.info("{:6.1f} seconds{}: {}".format(n[0], import_message, n[1]))
logging.info("")
if args.use_process_isolation:
from comfy.isolation import isolated_node_timings
if isolated_node_timings:
logging.info("\nImport times for isolated custom nodes:")
for timing, path, count in sorted(isolated_node_timings):
logging.info("{:6.1f} seconds: {} ({})".format(timing, path, count))
logging.info("")
async def init_builtin_extra_nodes():
"""
Initializes the built-in extra nodes in ComfyUI.
@@ -2475,6 +2435,7 @@ async def init_builtin_extra_nodes():
"nodes_audio_encoder.py",
"nodes_rope.py",
"nodes_logic.py",
"nodes_resolution.py",
"nodes_nop.py",
"nodes_kandinsky5.py",
"nodes_wanmove.py",
@@ -2482,10 +2443,12 @@ async def init_builtin_extra_nodes():
"nodes_zimage.py",
"nodes_glsl.py",
"nodes_lora_debug.py",
"nodes_textgen.py",
"nodes_color.py",
"nodes_toolkit.py",
"nodes_replacements.py",
"nodes_nag.py",
"nodes_sdpose.py",
]
import_failed = []

View File

@@ -1,5 +1,5 @@
comfyui-frontend-package==1.39.19
comfyui-workflow-templates==0.9.4
comfyui-workflow-templates==0.9.5
comfyui-embedded-docs==0.4.3
torch
torchsde
@@ -22,7 +22,7 @@ alembic
SQLAlchemy
av>=14.2.0
comfy-kitchen>=0.2.7
comfy-aimdo>=0.2.2
comfy-aimdo>=0.2.4
requests
#non essential dependencies:
@@ -32,5 +32,3 @@ pydantic~=2.0
pydantic-settings~=2.0
PyOpenGL
glfw
pyisolate==0.9.1

View File

@@ -3,6 +3,7 @@ import sys
import asyncio
import traceback
import time
import nodes
import folder_paths
import execution
@@ -195,8 +196,6 @@ def create_block_external_middleware():
class PromptServer():
def __init__(self, loop):
PromptServer.instance = self
if loop is None:
loop = asyncio.get_event_loop()
mimetypes.init()
mimetypes.add_type('application/javascript; charset=utf-8', '.js')

View File

@@ -49,6 +49,12 @@ def mock_provider(mock_releases):
return provider
@pytest.fixture(autouse=True)
def clear_cache():
import utils.install_util
utils.install_util.PACKAGE_VERSIONS = {}
def test_get_release(mock_provider, mock_releases):
version = "1.0.0"
release = mock_provider.get_release(version)

View File

@@ -38,13 +38,13 @@ class TestIsPreviewable:
"""Unit tests for is_previewable()"""
def test_previewable_media_types(self):
"""Images, video, audio, 3d media types should be previewable."""
for media_type in ['images', 'video', 'audio', '3d']:
"""Images, video, audio, 3d, text media types should be previewable."""
for media_type in ['images', 'video', 'audio', '3d', 'text']:
assert is_previewable(media_type, {}) is True
def test_non_previewable_media_types(self):
"""Other media types should not be previewable."""
for media_type in ['latents', 'text', 'metadata', 'files']:
for media_type in ['latents', 'metadata', 'files']:
assert is_previewable(media_type, {}) is False
def test_3d_extensions_previewable(self):

View File

@@ -1,122 +0,0 @@
"""Tests for pyisolate._internal.client import-time snapshot handling."""
import json
import os
import subprocess
import sys
from pathlib import Path
import pytest
# Paths needed for subprocess
PYISOLATE_ROOT = str(Path(__file__).parent.parent)
COMFYUI_ROOT = os.environ.get("COMFYUI_ROOT") or str(Path.home() / "ComfyUI")
SCRIPT = """
import json, sys
import pyisolate._internal.client # noqa: F401 # triggers snapshot logic
print(json.dumps(sys.path[:6]))
"""
def _run_client_process(env):
# Ensure subprocess can find pyisolate and ComfyUI
pythonpath_parts = [PYISOLATE_ROOT, COMFYUI_ROOT]
existing = env.get("PYTHONPATH", "")
if existing:
pythonpath_parts.append(existing)
env["PYTHONPATH"] = ":".join(pythonpath_parts)
result = subprocess.run( # noqa: S603
[sys.executable, "-c", SCRIPT],
capture_output=True,
text=True,
env=env,
check=True,
)
stdout = result.stdout.strip().splitlines()[-1]
return json.loads(stdout)
@pytest.fixture()
def comfy_module_path(tmp_path):
comfy_root = tmp_path / "ComfyUI"
module_path = comfy_root / "custom_nodes" / "TestNode"
module_path.mkdir(parents=True)
return comfy_root, module_path
def test_snapshot_applied_and_comfy_root_prepend(tmp_path, comfy_module_path):
comfy_root, module_path = comfy_module_path
# Must include real ComfyUI path for utils validation to pass
host_paths = [COMFYUI_ROOT, "/host/lib1", "/host/lib2"]
snapshot = {
"sys_path": host_paths,
"sys_executable": sys.executable,
"sys_prefix": sys.prefix,
"environment": {},
}
snapshot_path = tmp_path / "snapshot.json"
snapshot_path.write_text(json.dumps(snapshot), encoding="utf-8")
env = os.environ.copy()
env.update(
{
"PYISOLATE_CHILD": "1",
"PYISOLATE_HOST_SNAPSHOT": str(snapshot_path),
"PYISOLATE_MODULE_PATH": str(module_path),
}
)
path_prefix = _run_client_process(env)
# Current client behavior preserves the runtime bootstrap path order and
# keeps the resolved ComfyUI root available for imports.
assert COMFYUI_ROOT in path_prefix
# Module path should not override runtime root selection.
assert str(comfy_root) not in path_prefix
def test_missing_snapshot_file_does_not_crash(tmp_path, comfy_module_path):
_, module_path = comfy_module_path
missing_snapshot = tmp_path / "missing.json"
env = os.environ.copy()
env.update(
{
"PYISOLATE_CHILD": "1",
"PYISOLATE_HOST_SNAPSHOT": str(missing_snapshot),
"PYISOLATE_MODULE_PATH": str(module_path),
}
)
# Should not raise even though snapshot path is missing
paths = _run_client_process(env)
assert len(paths) > 0
def test_no_comfy_root_when_module_path_absent(tmp_path):
# Must include real ComfyUI path for utils validation to pass
host_paths = [COMFYUI_ROOT, "/alpha", "/beta"]
snapshot = {
"sys_path": host_paths,
"sys_executable": sys.executable,
"sys_prefix": sys.prefix,
"environment": {},
}
snapshot_path = tmp_path / "snapshot.json"
snapshot_path.write_text(json.dumps(snapshot), encoding="utf-8")
env = os.environ.copy()
env.update(
{
"PYISOLATE_CHILD": "1",
"PYISOLATE_HOST_SNAPSHOT": str(snapshot_path),
}
)
paths = _run_client_process(env)
# Runtime path bootstrap keeps ComfyUI importability regardless of host
# snapshot extras.
assert COMFYUI_ROOT in paths
assert "/alpha" not in paths and "/beta" not in paths

View File

@@ -1,111 +0,0 @@
"""Unit tests for FolderPathsProxy."""
import pytest
from pathlib import Path
from comfy.isolation.proxies.folder_paths_proxy import FolderPathsProxy
class TestFolderPathsProxy:
"""Test FolderPathsProxy methods."""
@pytest.fixture
def proxy(self):
"""Create a FolderPathsProxy instance for testing."""
return FolderPathsProxy()
def test_get_temp_directory_returns_string(self, proxy):
"""Verify get_temp_directory returns a non-empty string."""
result = proxy.get_temp_directory()
assert isinstance(result, str), f"Expected str, got {type(result)}"
assert len(result) > 0, "Temp directory path is empty"
def test_get_temp_directory_returns_absolute_path(self, proxy):
"""Verify get_temp_directory returns an absolute path."""
result = proxy.get_temp_directory()
path = Path(result)
assert path.is_absolute(), f"Path is not absolute: {result}"
def test_get_input_directory_returns_string(self, proxy):
"""Verify get_input_directory returns a non-empty string."""
result = proxy.get_input_directory()
assert isinstance(result, str), f"Expected str, got {type(result)}"
assert len(result) > 0, "Input directory path is empty"
def test_get_input_directory_returns_absolute_path(self, proxy):
"""Verify get_input_directory returns an absolute path."""
result = proxy.get_input_directory()
path = Path(result)
assert path.is_absolute(), f"Path is not absolute: {result}"
def test_get_annotated_filepath_plain_name(self, proxy):
"""Verify get_annotated_filepath works with plain filename."""
result = proxy.get_annotated_filepath("test.png")
assert isinstance(result, str), f"Expected str, got {type(result)}"
assert "test.png" in result, f"Filename not in result: {result}"
def test_get_annotated_filepath_with_output_annotation(self, proxy):
"""Verify get_annotated_filepath handles [output] annotation."""
result = proxy.get_annotated_filepath("test.png[output]")
assert isinstance(result, str), f"Expected str, got {type(result)}"
assert "test.pn" in result, f"Filename base not in result: {result}"
# Should resolve to output directory
assert "output" in result.lower() or Path(result).parent.name == "output"
def test_get_annotated_filepath_with_input_annotation(self, proxy):
"""Verify get_annotated_filepath handles [input] annotation."""
result = proxy.get_annotated_filepath("test.png[input]")
assert isinstance(result, str), f"Expected str, got {type(result)}"
assert "test.pn" in result, f"Filename base not in result: {result}"
def test_get_annotated_filepath_with_temp_annotation(self, proxy):
"""Verify get_annotated_filepath handles [temp] annotation."""
result = proxy.get_annotated_filepath("test.png[temp]")
assert isinstance(result, str), f"Expected str, got {type(result)}"
assert "test.pn" in result, f"Filename base not in result: {result}"
def test_exists_annotated_filepath_returns_bool(self, proxy):
"""Verify exists_annotated_filepath returns a boolean."""
result = proxy.exists_annotated_filepath("nonexistent.png")
assert isinstance(result, bool), f"Expected bool, got {type(result)}"
def test_exists_annotated_filepath_nonexistent_file(self, proxy):
"""Verify exists_annotated_filepath returns False for nonexistent file."""
result = proxy.exists_annotated_filepath("definitely_does_not_exist_12345.png")
assert result is False, "Expected False for nonexistent file"
def test_exists_annotated_filepath_with_annotation(self, proxy):
"""Verify exists_annotated_filepath works with annotation suffix."""
# Even for nonexistent files, should return bool without error
result = proxy.exists_annotated_filepath("test.png[output]")
assert isinstance(result, bool), f"Expected bool, got {type(result)}"
def test_models_dir_property_returns_string(self, proxy):
"""Verify models_dir property returns valid path string."""
result = proxy.models_dir
assert isinstance(result, str), f"Expected str, got {type(result)}"
assert len(result) > 0, "Models directory path is empty"
def test_models_dir_is_absolute_path(self, proxy):
"""Verify models_dir returns an absolute path."""
result = proxy.models_dir
path = Path(result)
assert path.is_absolute(), f"Path is not absolute: {result}"
def test_add_model_folder_path_runs_without_error(self, proxy):
"""Verify add_model_folder_path executes without raising."""
test_path = "/tmp/test_models_florence2"
# Should not raise
proxy.add_model_folder_path("TEST_FLORENCE2", test_path)
def test_get_folder_paths_returns_list(self, proxy):
"""Verify get_folder_paths returns a list."""
# Use known folder type that should exist
result = proxy.get_folder_paths("checkpoints")
assert isinstance(result, list), f"Expected list, got {type(result)}"
def test_get_folder_paths_checkpoints_not_empty(self, proxy):
"""Verify checkpoints folder paths list is not empty."""
result = proxy.get_folder_paths("checkpoints")
# Should have at least one checkpoint path registered
assert len(result) > 0, "Checkpoints folder paths is empty"

View File

@@ -1,72 +0,0 @@
from pathlib import Path
def _write_pyproject(path: Path, content: str) -> None:
path.write_text(content, encoding="utf-8")
def test_load_host_policy_defaults_when_pyproject_missing(tmp_path):
from comfy.isolation.host_policy import DEFAULT_POLICY, load_host_policy
policy = load_host_policy(tmp_path)
assert policy["allow_network"] == DEFAULT_POLICY["allow_network"]
assert policy["writable_paths"] == DEFAULT_POLICY["writable_paths"]
assert policy["readonly_paths"] == DEFAULT_POLICY["readonly_paths"]
assert policy["whitelist"] == DEFAULT_POLICY["whitelist"]
def test_load_host_policy_defaults_when_section_missing(tmp_path):
from comfy.isolation.host_policy import DEFAULT_POLICY, load_host_policy
_write_pyproject(
tmp_path / "pyproject.toml",
"""
[project]
name = "ComfyUI"
""".strip(),
)
policy = load_host_policy(tmp_path)
assert policy["allow_network"] == DEFAULT_POLICY["allow_network"]
assert policy["whitelist"] == {}
def test_load_host_policy_reads_values(tmp_path):
from comfy.isolation.host_policy import load_host_policy
_write_pyproject(
tmp_path / "pyproject.toml",
"""
[tool.comfy.host]
allow_network = true
writable_paths = ["/tmp/a", "/tmp/b"]
readonly_paths = ["/opt/readonly"]
[tool.comfy.host.whitelist]
ExampleNode = "*"
""".strip(),
)
policy = load_host_policy(tmp_path)
assert policy["allow_network"] is True
assert policy["writable_paths"] == ["/tmp/a", "/tmp/b"]
assert policy["readonly_paths"] == ["/opt/readonly"]
assert policy["whitelist"] == {"ExampleNode": "*"}
def test_load_host_policy_ignores_invalid_whitelist_type(tmp_path):
from comfy.isolation.host_policy import DEFAULT_POLICY, load_host_policy
_write_pyproject(
tmp_path / "pyproject.toml",
"""
[tool.comfy.host]
allow_network = true
whitelist = ["bad"]
""".strip(),
)
policy = load_host_policy(tmp_path)
assert policy["allow_network"] is True
assert policy["whitelist"] == DEFAULT_POLICY["whitelist"]

View File

@@ -1,56 +0,0 @@
"""Unit tests for PyIsolate isolation system initialization."""
def test_log_prefix():
"""Verify LOG_PREFIX constant is correctly defined."""
from comfy.isolation import LOG_PREFIX
assert LOG_PREFIX == "]["
assert isinstance(LOG_PREFIX, str)
def test_module_initialization():
"""Verify module initializes without errors."""
import comfy.isolation
assert hasattr(comfy.isolation, 'LOG_PREFIX')
assert hasattr(comfy.isolation, 'initialize_proxies')
class TestInitializeProxies:
def test_initialize_proxies_runs_without_error(self):
from comfy.isolation import initialize_proxies
initialize_proxies()
def test_initialize_proxies_registers_folder_paths_proxy(self):
from comfy.isolation import initialize_proxies
from comfy.isolation.proxies.folder_paths_proxy import FolderPathsProxy
initialize_proxies()
proxy = FolderPathsProxy()
assert proxy is not None
assert hasattr(proxy, "get_temp_directory")
def test_initialize_proxies_registers_model_management_proxy(self):
from comfy.isolation import initialize_proxies
from comfy.isolation.proxies.model_management_proxy import ModelManagementProxy
initialize_proxies()
proxy = ModelManagementProxy()
assert proxy is not None
assert hasattr(proxy, "get_torch_device")
def test_initialize_proxies_can_be_called_multiple_times(self):
from comfy.isolation import initialize_proxies
initialize_proxies()
initialize_proxies()
initialize_proxies()
def test_dev_proxies_accessible_when_dev_mode(self, monkeypatch):
"""Verify dev mode does not break core proxy initialization."""
monkeypatch.setenv("PYISOLATE_DEV", "1")
from comfy.isolation import initialize_proxies
from comfy.isolation.proxies.folder_paths_proxy import FolderPathsProxy
from comfy.isolation.proxies.utils_proxy import UtilsProxy
initialize_proxies()
folder_proxy = FolderPathsProxy()
utils_proxy = UtilsProxy()
assert folder_proxy is not None
assert utils_proxy is not None

View File

@@ -1,434 +0,0 @@
"""
Unit tests for manifest_loader.py cache functions.
Phase 1 tests verify:
1. Cache miss on first run (no cache exists)
2. Cache hit when nothing changes
3. Invalidation on .py file touch
4. Invalidation on manifest change
5. Cache location correctness (in venv_root, NOT in custom_nodes)
6. Corrupt cache handling (graceful failure)
These tests verify the cache implementation is correct BEFORE it's activated
in extension_loader.py (Phase 2).
"""
from __future__ import annotations
import json
import sys
import time
from pathlib import Path
from unittest import mock
class TestComputeCacheKey:
"""Tests for compute_cache_key() function."""
def test_key_includes_manifest_content(self, tmp_path: Path) -> None:
"""Cache key changes when manifest content changes."""
from comfy.isolation.manifest_loader import compute_cache_key
node_dir = tmp_path / "test_node"
node_dir.mkdir()
manifest = node_dir / "pyisolate.yaml"
# Initial manifest
manifest.write_text("isolated: true\ndependencies: []\n")
key1 = compute_cache_key(node_dir, manifest)
# Modified manifest
manifest.write_text("isolated: true\ndependencies: [numpy]\n")
key2 = compute_cache_key(node_dir, manifest)
assert key1 != key2, "Key should change when manifest content changes"
def test_key_includes_py_file_mtime(self, tmp_path: Path) -> None:
"""Cache key changes when any .py file is touched."""
from comfy.isolation.manifest_loader import compute_cache_key
node_dir = tmp_path / "test_node"
node_dir.mkdir()
manifest = node_dir / "pyisolate.yaml"
manifest.write_text("isolated: true\n")
py_file = node_dir / "nodes.py"
py_file.write_text("# test code")
key1 = compute_cache_key(node_dir, manifest)
# Wait a moment to ensure mtime changes
time.sleep(0.01)
py_file.write_text("# modified code")
key2 = compute_cache_key(node_dir, manifest)
assert key1 != key2, "Key should change when .py file mtime changes"
def test_key_includes_python_version(self, tmp_path: Path) -> None:
"""Cache key changes when Python version changes."""
from comfy.isolation.manifest_loader import compute_cache_key
node_dir = tmp_path / "test_node"
node_dir.mkdir()
manifest = node_dir / "pyisolate.yaml"
manifest.write_text("isolated: true\n")
key1 = compute_cache_key(node_dir, manifest)
# Mock different Python version
with mock.patch.object(sys, "version", "3.99.0 (fake)"):
key2 = compute_cache_key(node_dir, manifest)
assert key1 != key2, "Key should change when Python version changes"
def test_key_includes_pyisolate_version(self, tmp_path: Path) -> None:
"""Cache key changes when PyIsolate version changes."""
from comfy.isolation.manifest_loader import compute_cache_key
node_dir = tmp_path / "test_node"
node_dir.mkdir()
manifest = node_dir / "pyisolate.yaml"
manifest.write_text("isolated: true\n")
key1 = compute_cache_key(node_dir, manifest)
# Mock different pyisolate version
with mock.patch.dict(sys.modules, {"pyisolate": mock.MagicMock(__version__="99.99.99")}):
# Need to reimport to pick up the mock
import importlib
from comfy.isolation import manifest_loader
importlib.reload(manifest_loader)
key2 = manifest_loader.compute_cache_key(node_dir, manifest)
# Keys should be different (though the mock approach is tricky)
# At minimum, verify key is a valid hex string
assert len(key1) == 16, "Key should be 16 hex characters"
assert all(c in "0123456789abcdef" for c in key1), "Key should be hex"
assert len(key2) == 16, "Key should be 16 hex characters"
assert all(c in "0123456789abcdef" for c in key2), "Key should be hex"
def test_key_excludes_pycache(self, tmp_path: Path) -> None:
"""Cache key ignores __pycache__ directory changes."""
from comfy.isolation.manifest_loader import compute_cache_key
node_dir = tmp_path / "test_node"
node_dir.mkdir()
manifest = node_dir / "pyisolate.yaml"
manifest.write_text("isolated: true\n")
py_file = node_dir / "nodes.py"
py_file.write_text("# test code")
key1 = compute_cache_key(node_dir, manifest)
# Add __pycache__ file
pycache = node_dir / "__pycache__"
pycache.mkdir()
(pycache / "nodes.cpython-310.pyc").write_bytes(b"compiled")
key2 = compute_cache_key(node_dir, manifest)
assert key1 == key2, "Key should NOT change when __pycache__ modified"
def test_key_is_deterministic(self, tmp_path: Path) -> None:
"""Same inputs produce same key."""
from comfy.isolation.manifest_loader import compute_cache_key
node_dir = tmp_path / "test_node"
node_dir.mkdir()
manifest = node_dir / "pyisolate.yaml"
manifest.write_text("isolated: true\n")
(node_dir / "nodes.py").write_text("# code")
key1 = compute_cache_key(node_dir, manifest)
key2 = compute_cache_key(node_dir, manifest)
assert key1 == key2, "Key should be deterministic"
class TestGetCachePath:
"""Tests for get_cache_path() function."""
def test_returns_correct_paths(self, tmp_path: Path) -> None:
"""Cache paths are in venv_root, not in node_dir."""
from comfy.isolation.manifest_loader import get_cache_path
node_dir = tmp_path / "custom_nodes" / "MyNode"
venv_root = tmp_path / ".pyisolate_venvs"
key_file, data_file = get_cache_path(node_dir, venv_root)
assert key_file == venv_root / "MyNode" / "cache" / "cache_key"
assert data_file == venv_root / "MyNode" / "cache" / "node_info.json"
def test_cache_not_in_custom_nodes(self, tmp_path: Path) -> None:
"""Verify cache is NOT stored in custom_nodes directory."""
from comfy.isolation.manifest_loader import get_cache_path
node_dir = tmp_path / "custom_nodes" / "MyNode"
venv_root = tmp_path / ".pyisolate_venvs"
key_file, data_file = get_cache_path(node_dir, venv_root)
# Neither path should be under node_dir
assert not str(key_file).startswith(str(node_dir))
assert not str(data_file).startswith(str(node_dir))
class TestIsCacheValid:
"""Tests for is_cache_valid() function."""
def test_false_when_no_cache_exists(self, tmp_path: Path) -> None:
"""Returns False when cache files don't exist."""
from comfy.isolation.manifest_loader import is_cache_valid
node_dir = tmp_path / "test_node"
node_dir.mkdir()
manifest = node_dir / "pyisolate.yaml"
manifest.write_text("isolated: true\n")
venv_root = tmp_path / ".pyisolate_venvs"
assert is_cache_valid(node_dir, manifest, venv_root) is False
def test_true_when_cache_matches(self, tmp_path: Path) -> None:
"""Returns True when cache key matches current state."""
from comfy.isolation.manifest_loader import (
compute_cache_key,
get_cache_path,
is_cache_valid,
)
node_dir = tmp_path / "test_node"
node_dir.mkdir()
manifest = node_dir / "pyisolate.yaml"
manifest.write_text("isolated: true\n")
(node_dir / "nodes.py").write_text("# code")
venv_root = tmp_path / ".pyisolate_venvs"
# Create valid cache
cache_key = compute_cache_key(node_dir, manifest)
key_file, data_file = get_cache_path(node_dir, venv_root)
key_file.parent.mkdir(parents=True, exist_ok=True)
key_file.write_text(cache_key)
data_file.write_text("{}")
assert is_cache_valid(node_dir, manifest, venv_root) is True
def test_false_when_key_mismatch(self, tmp_path: Path) -> None:
"""Returns False when stored key doesn't match current state."""
from comfy.isolation.manifest_loader import get_cache_path, is_cache_valid
node_dir = tmp_path / "test_node"
node_dir.mkdir()
manifest = node_dir / "pyisolate.yaml"
manifest.write_text("isolated: true\n")
venv_root = tmp_path / ".pyisolate_venvs"
# Create cache with wrong key
key_file, data_file = get_cache_path(node_dir, venv_root)
key_file.parent.mkdir(parents=True, exist_ok=True)
key_file.write_text("wrong_key_12345")
data_file.write_text("{}")
assert is_cache_valid(node_dir, manifest, venv_root) is False
def test_false_when_data_file_missing(self, tmp_path: Path) -> None:
"""Returns False when node_info.json is missing."""
from comfy.isolation.manifest_loader import (
compute_cache_key,
get_cache_path,
is_cache_valid,
)
node_dir = tmp_path / "test_node"
node_dir.mkdir()
manifest = node_dir / "pyisolate.yaml"
manifest.write_text("isolated: true\n")
venv_root = tmp_path / ".pyisolate_venvs"
# Create only key file, not data file
cache_key = compute_cache_key(node_dir, manifest)
key_file, _ = get_cache_path(node_dir, venv_root)
key_file.parent.mkdir(parents=True, exist_ok=True)
key_file.write_text(cache_key)
assert is_cache_valid(node_dir, manifest, venv_root) is False
def test_invalidation_on_py_change(self, tmp_path: Path) -> None:
"""Cache invalidates when .py file is modified."""
from comfy.isolation.manifest_loader import (
compute_cache_key,
get_cache_path,
is_cache_valid,
)
node_dir = tmp_path / "test_node"
node_dir.mkdir()
manifest = node_dir / "pyisolate.yaml"
manifest.write_text("isolated: true\n")
py_file = node_dir / "nodes.py"
py_file.write_text("# original")
venv_root = tmp_path / ".pyisolate_venvs"
# Create valid cache
cache_key = compute_cache_key(node_dir, manifest)
key_file, data_file = get_cache_path(node_dir, venv_root)
key_file.parent.mkdir(parents=True, exist_ok=True)
key_file.write_text(cache_key)
data_file.write_text("{}")
# Verify cache is valid initially
assert is_cache_valid(node_dir, manifest, venv_root) is True
# Modify .py file
time.sleep(0.01) # Ensure mtime changes
py_file.write_text("# modified")
# Cache should now be invalid
assert is_cache_valid(node_dir, manifest, venv_root) is False
class TestLoadFromCache:
"""Tests for load_from_cache() function."""
def test_returns_none_when_no_cache(self, tmp_path: Path) -> None:
"""Returns None when cache doesn't exist."""
from comfy.isolation.manifest_loader import load_from_cache
node_dir = tmp_path / "test_node"
venv_root = tmp_path / ".pyisolate_venvs"
assert load_from_cache(node_dir, venv_root) is None
def test_returns_data_when_valid(self, tmp_path: Path) -> None:
"""Returns cached data when file exists and is valid JSON."""
from comfy.isolation.manifest_loader import get_cache_path, load_from_cache
node_dir = tmp_path / "test_node"
venv_root = tmp_path / ".pyisolate_venvs"
test_data = {"TestNode": {"inputs": [], "outputs": []}}
_, data_file = get_cache_path(node_dir, venv_root)
data_file.parent.mkdir(parents=True, exist_ok=True)
data_file.write_text(json.dumps(test_data))
result = load_from_cache(node_dir, venv_root)
assert result == test_data
def test_returns_none_on_corrupt_json(self, tmp_path: Path) -> None:
"""Returns None when JSON is corrupt."""
from comfy.isolation.manifest_loader import get_cache_path, load_from_cache
node_dir = tmp_path / "test_node"
venv_root = tmp_path / ".pyisolate_venvs"
_, data_file = get_cache_path(node_dir, venv_root)
data_file.parent.mkdir(parents=True, exist_ok=True)
data_file.write_text("{ corrupt json }")
assert load_from_cache(node_dir, venv_root) is None
def test_returns_none_on_invalid_structure(self, tmp_path: Path) -> None:
"""Returns None when data is not a dict."""
from comfy.isolation.manifest_loader import get_cache_path, load_from_cache
node_dir = tmp_path / "test_node"
venv_root = tmp_path / ".pyisolate_venvs"
_, data_file = get_cache_path(node_dir, venv_root)
data_file.parent.mkdir(parents=True, exist_ok=True)
data_file.write_text("[1, 2, 3]") # Array, not dict
assert load_from_cache(node_dir, venv_root) is None
class TestSaveToCache:
"""Tests for save_to_cache() function."""
def test_creates_cache_directory(self, tmp_path: Path) -> None:
"""Creates cache directory if it doesn't exist."""
from comfy.isolation.manifest_loader import get_cache_path, save_to_cache
node_dir = tmp_path / "test_node"
node_dir.mkdir()
manifest = node_dir / "pyisolate.yaml"
manifest.write_text("isolated: true\n")
venv_root = tmp_path / ".pyisolate_venvs"
save_to_cache(node_dir, venv_root, {"TestNode": {}}, manifest)
key_file, data_file = get_cache_path(node_dir, venv_root)
assert key_file.parent.exists()
def test_writes_both_files(self, tmp_path: Path) -> None:
"""Writes both cache_key and node_info.json."""
from comfy.isolation.manifest_loader import get_cache_path, save_to_cache
node_dir = tmp_path / "test_node"
node_dir.mkdir()
manifest = node_dir / "pyisolate.yaml"
manifest.write_text("isolated: true\n")
venv_root = tmp_path / ".pyisolate_venvs"
save_to_cache(node_dir, venv_root, {"TestNode": {"key": "value"}}, manifest)
key_file, data_file = get_cache_path(node_dir, venv_root)
assert key_file.exists()
assert data_file.exists()
def test_data_is_valid_json(self, tmp_path: Path) -> None:
"""Written data can be parsed as JSON."""
from comfy.isolation.manifest_loader import get_cache_path, save_to_cache
node_dir = tmp_path / "test_node"
node_dir.mkdir()
manifest = node_dir / "pyisolate.yaml"
manifest.write_text("isolated: true\n")
venv_root = tmp_path / ".pyisolate_venvs"
test_data = {"TestNode": {"inputs": ["IMAGE"], "outputs": ["IMAGE"]}}
save_to_cache(node_dir, venv_root, test_data, manifest)
_, data_file = get_cache_path(node_dir, venv_root)
loaded = json.loads(data_file.read_text())
assert loaded == test_data
def test_roundtrip_with_validation(self, tmp_path: Path) -> None:
"""Saved cache is immediately valid."""
from comfy.isolation.manifest_loader import (
is_cache_valid,
load_from_cache,
save_to_cache,
)
node_dir = tmp_path / "test_node"
node_dir.mkdir()
manifest = node_dir / "pyisolate.yaml"
manifest.write_text("isolated: true\n")
(node_dir / "nodes.py").write_text("# code")
venv_root = tmp_path / ".pyisolate_venvs"
test_data = {"TestNode": {"foo": "bar"}}
save_to_cache(node_dir, venv_root, test_data, manifest)
assert is_cache_valid(node_dir, manifest, venv_root) is True
assert load_from_cache(node_dir, venv_root) == test_data
def test_cache_not_in_custom_nodes(self, tmp_path: Path) -> None:
"""Verify no files written to custom_nodes directory."""
from comfy.isolation.manifest_loader import save_to_cache
node_dir = tmp_path / "custom_nodes" / "MyNode"
node_dir.mkdir(parents=True)
manifest = node_dir / "pyisolate.yaml"
manifest.write_text("isolated: true\n")
venv_root = tmp_path / ".pyisolate_venvs"
save_to_cache(node_dir, venv_root, {"TestNode": {}}, manifest)
# Check nothing was created under node_dir
for item in node_dir.iterdir():
assert item.name == "pyisolate.yaml", f"Unexpected file in node_dir: {item}"

View File

@@ -1,50 +0,0 @@
"""Unit tests for ModelManagementProxy."""
import pytest
import torch
from comfy.isolation.proxies.model_management_proxy import ModelManagementProxy
class TestModelManagementProxy:
"""Test ModelManagementProxy methods."""
@pytest.fixture
def proxy(self):
"""Create a ModelManagementProxy instance for testing."""
return ModelManagementProxy()
def test_get_torch_device_returns_device(self, proxy):
"""Verify get_torch_device returns a torch.device object."""
result = proxy.get_torch_device()
assert isinstance(result, torch.device), f"Expected torch.device, got {type(result)}"
def test_get_torch_device_is_valid(self, proxy):
"""Verify get_torch_device returns a valid device (cpu or cuda)."""
result = proxy.get_torch_device()
assert result.type in ("cpu", "cuda"), f"Unexpected device type: {result.type}"
def test_get_torch_device_name_returns_string(self, proxy):
"""Verify get_torch_device_name returns a non-empty string."""
device = proxy.get_torch_device()
result = proxy.get_torch_device_name(device)
assert isinstance(result, str), f"Expected str, got {type(result)}"
assert len(result) > 0, "Device name is empty"
def test_get_torch_device_name_with_cpu(self, proxy):
"""Verify get_torch_device_name works with CPU device."""
cpu_device = torch.device("cpu")
result = proxy.get_torch_device_name(cpu_device)
assert isinstance(result, str), f"Expected str, got {type(result)}"
assert "cpu" in result.lower(), f"Expected 'cpu' in device name, got: {result}"
def test_get_torch_device_name_with_cuda_if_available(self, proxy):
"""Verify get_torch_device_name works with CUDA device if available."""
if not torch.cuda.is_available():
pytest.skip("CUDA not available")
cuda_device = torch.device("cuda:0")
result = proxy.get_torch_device_name(cuda_device)
assert isinstance(result, str), f"Expected str, got {type(result)}"
# Should contain device identifier
assert len(result) > 0, "CUDA device name is empty"

View File

@@ -1,93 +0,0 @@
from __future__ import annotations
import json
import os
import sys
from pathlib import Path
import pytest
from pyisolate.path_helpers import build_child_sys_path, serialize_host_snapshot
def test_serialize_host_snapshot_includes_expected_keys(tmp_path: Path, monkeypatch) -> None:
output = tmp_path / "snapshot.json"
monkeypatch.setenv("EXTRA_FLAG", "1")
snapshot = serialize_host_snapshot(output_path=output, extra_env_keys=["EXTRA_FLAG"])
assert "sys_path" in snapshot
assert "sys_executable" in snapshot
assert "sys_prefix" in snapshot
assert "environment" in snapshot
assert output.exists()
assert snapshot["environment"].get("EXTRA_FLAG") == "1"
persisted = json.loads(output.read_text(encoding="utf-8"))
assert persisted["sys_path"] == snapshot["sys_path"]
def test_build_child_sys_path_preserves_host_order() -> None:
host_paths = ["/host/root", "/host/site-packages"]
extra_paths = ["/node/.venv/lib/python3.12/site-packages"]
result = build_child_sys_path(host_paths, extra_paths, preferred_root=None)
assert result == host_paths + extra_paths
def test_build_child_sys_path_inserts_comfy_root_when_missing() -> None:
host_paths = ["/host/site-packages"]
comfy_root = os.environ.get("COMFYUI_ROOT") or str(Path.home() / "ComfyUI")
extra_paths: list[str] = []
result = build_child_sys_path(host_paths, extra_paths, preferred_root=comfy_root)
assert result[0] == comfy_root
assert result[1:] == host_paths
def test_build_child_sys_path_deduplicates_entries(tmp_path: Path) -> None:
path_a = str(tmp_path / "a")
path_b = str(tmp_path / "b")
host_paths = [path_a, path_b]
extra_paths = [path_a, path_b, str(tmp_path / "c")]
result = build_child_sys_path(host_paths, extra_paths)
assert result == [path_a, path_b, str(tmp_path / "c")]
def test_build_child_sys_path_skips_duplicate_comfy_root() -> None:
comfy_root = os.environ.get("COMFYUI_ROOT") or str(Path.home() / "ComfyUI")
host_paths = [comfy_root, "/host/other"]
result = build_child_sys_path(host_paths, extra_paths=[], preferred_root=comfy_root)
assert result == host_paths
def test_child_import_succeeds_after_path_unification(tmp_path: Path, monkeypatch) -> None:
host_root = tmp_path / "host"
utils_pkg = host_root / "utils"
app_pkg = host_root / "app"
utils_pkg.mkdir(parents=True)
app_pkg.mkdir(parents=True)
(utils_pkg / "__init__.py").write_text("from . import install_util\n", encoding="utf-8")
(utils_pkg / "install_util.py").write_text("VALUE = 'hello'\n", encoding="utf-8")
(app_pkg / "__init__.py").write_text("", encoding="utf-8")
(app_pkg / "frontend_management.py").write_text(
"from utils import install_util\nVALUE = install_util.VALUE\n",
encoding="utf-8",
)
child_only = tmp_path / "child_only"
child_only.mkdir()
target_module = "app.frontend_management"
for name in [n for n in list(sys.modules) if n.startswith("app") or n.startswith("utils")]:
sys.modules.pop(name)
monkeypatch.setattr(sys, "path", [str(child_only)])
with pytest.raises(ModuleNotFoundError):
__import__(target_module)
for name in [n for n in list(sys.modules) if n.startswith("app") or n.startswith("utils")]:
sys.modules.pop(name)
unified = build_child_sys_path([], [], preferred_root=str(host_root))
monkeypatch.setattr(sys, "path", unified)
module = __import__(target_module, fromlist=["VALUE"])
assert module.VALUE == "hello"

View File

@@ -1,51 +0,0 @@
import os
import sys
from pathlib import Path
repo_root = Path(__file__).resolve().parents[1]
pyisolate_root = repo_root.parent / "pyisolate"
if pyisolate_root.exists():
sys.path.insert(0, str(pyisolate_root))
from comfy.isolation.adapter import ComfyUIAdapter
from pyisolate._internal.serialization_registry import SerializerRegistry
def test_identifier():
adapter = ComfyUIAdapter()
assert adapter.identifier == "comfyui"
def test_get_path_config_valid():
adapter = ComfyUIAdapter()
path = os.path.join("/opt", "ComfyUI", "custom_nodes", "demo")
cfg = adapter.get_path_config(path)
assert cfg is not None
assert cfg["preferred_root"].endswith("ComfyUI")
assert "custom_nodes" in cfg["additional_paths"][0]
def test_get_path_config_invalid():
adapter = ComfyUIAdapter()
assert adapter.get_path_config("/random/path") is None
def test_provide_rpc_services():
adapter = ComfyUIAdapter()
services = adapter.provide_rpc_services()
names = {s.__name__ for s in services}
assert "PromptServerService" in names
assert "FolderPathsProxy" in names
def test_register_serializers():
adapter = ComfyUIAdapter()
registry = SerializerRegistry.get_instance()
registry.clear()
adapter.register_serializers(registry)
assert registry.has_handler("ModelPatcher")
assert registry.has_handler("CLIP")
assert registry.has_handler("VAE")
registry.clear()

View File

@@ -1,5 +1,7 @@
from pathlib import Path
import sys
import logging
import re
# The path to the requirements.txt file
requirements_path = Path(__file__).parents[1] / "requirements.txt"
@@ -16,3 +18,34 @@ Please install the updated requirements.txt file by running:
{sys.executable} {extra}-m pip install -r {requirements_path}
If you are on the portable package you can run: update\\update_comfyui.bat to solve this problem.
""".strip()
def is_valid_version(version: str) -> bool:
"""Validate if a string is a valid semantic version (X.Y.Z format)."""
pattern = r"^(\d+)\.(\d+)\.(\d+)$"
return bool(re.match(pattern, version))
PACKAGE_VERSIONS = {}
def get_required_packages_versions():
if len(PACKAGE_VERSIONS) > 0:
return PACKAGE_VERSIONS.copy()
out = PACKAGE_VERSIONS
try:
with open(requirements_path, "r", encoding="utf-8") as f:
for line in f:
line = line.strip().replace(">=", "==")
s = line.split("==")
if len(s) == 2:
version_str = s[-1]
if not is_valid_version(version_str):
logging.error(f"Invalid version format in requirements.txt: {version_str}")
continue
out[s[0]] = version_str
return out.copy()
except FileNotFoundError:
logging.error("requirements.txt not found.")
return None
except Exception as e:
logging.error(f"Error reading requirements.txt: {e}")
return None