Mirror of https://github.com/Comfy-Org/ComfyUI.git, synced 2026-03-02 22:29:01 +00:00

Compare commits: 142 commits (v3/model_m... → pyisolate-...)

Commit SHA1s (the author and date columns of the compare table were empty):
a0f8784e9f, 95e1059661, 80d49441e5, 9d0e114ee3, ac4412d0fa, 94f1a1cc9d, e721e24136, 25ec3d96a3,
7962db477a, 3c8ba051b6, a1c3124821, 9ca799362d, 22f5e43c12, 3cfd5e3311, 1f1ec377ce, 0a7f8e11b6,
35e9fce775, c7f7d52b68, 08b26ed7c2, b233dbe0bc, 3811780e4f, 3dd10a59c0, 88d05fe483, fd41ec97cc,
420e900f69, 38ca94599f, 74b5a337dc, 8a4d85c708, a4522017c5, 907e5dcbbf, 7253531670, e14b04478c,
eb8737d675, 0467f690a8, 4f5b7dbf1f, 3ebe1ac22e, befa83d434, 33f83d53ae, b874bd2b8c, 0aa02453bb,
599f9c5010, 11fefa58e9, d8090013b8, 048dd2f321, 84aba95e03, 9b1c63eb69, 7a7debcaf1, dba2766e53,
caa43d2395, 07ca6852e8, f266b8d352, b6cb30bab5, ee72752162, 7591d781a7, 0bfb936ab4, 602b2505a4,
04a55d5019, 5fb8f06495, 5a182bfaf1, f394af8d0f, aeb5bdc8f6, 64953bda0a, b254cecd03, 1bb956fb66,
96d6bd1a4a, 5f2117528a, 0301ccf745, 4d172e9ad7, 5632b2df9d, 2687652530, 6d11cc7354, f262444dd4,
239ddd3327, 83dd65f23a, 8ad38d2073, 6c14f129af, 58dcc97dcf, 19236edfa4, 73c3f86973, 262abf437b,
5284e6bf69, 44f8598521, fe52843fe5, c39653163d, 18927538a1, 8a6fbc2dc2, b44fc4c589, 4454fab7f0,
1978f59ffd, 88e6370527, c0370044cd, ecd2a19661, 2c1d06a4e3, e2c71ceb00, 596ed68691, ce4a1ab48d,
e1ede29d82, df1e5e8514, dc9822b7df, 712efb466b, 726af73867, 831351a29e, e1add563f9, 8902907d7a,
e03fe8b591, ae79e33345, 117e214354, 4a93a62371, 66c18522fb, e5ae670a40, 3fe61cedda, 2a4328d639,
d297a749a2, 2b7cc7e3b6, 4993411fd9, 2c7cef4a23, 76a7fa96db, cdcf4119b3, dbe70b6821, 00fff6019e,
123a7874a9, f719f9c062, fe053ba5eb, 6648ab68bc, 6615db925c, 8ca842a8ed, c1b63a7e78, 349a636a2b,
a4be04c5d7, baf8c87455, 62315fbb15, a0302cc6a8, f350a84261, 3760d74005, 9bf5aa54db, 5ff4fdedba,
17e7df43d1, 039955c527, 6a26328842, 204e65b8dc, a831c19b70, eba6c940fd
.coderabbit.yaml (new file, 127 lines)

```yaml
# yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json
language: "en-US"
early_access: false
tone_instructions: "Only comment on issues introduced by this PR's changes. Do not flag pre-existing problems in moved, re-indented, or reformatted code."

reviews:
  profile: "chill"
  request_changes_workflow: false
  high_level_summary: false
  poem: false
  review_status: false
  review_details: false
  commit_status: true
  collapse_walkthrough: true
  changed_files_summary: false
  sequence_diagrams: false
  estimate_code_review_effort: false
  assess_linked_issues: false
  related_issues: false
  related_prs: false
  suggested_labels: false
  auto_apply_labels: false
  suggested_reviewers: false
  auto_assign_reviewers: false
  in_progress_fortune: false
  enable_prompt_for_ai_agents: true

  path_filters:
    - "!comfy_api_nodes/apis/**"
    - "!**/generated/*.pyi"
    - "!.ci/**"
    - "!script_examples/**"
    - "!**/__pycache__/**"
    - "!**/*.ipynb"
    - "!**/*.png"
    - "!**/*.bat"

  path_instructions:
    - path: "**"
      instructions: |
        IMPORTANT: Only comment on issues directly introduced by this PR's code changes.
        Do NOT flag pre-existing issues in code that was merely moved, re-indented,
        de-indented, or reformatted without logic changes. If code appears in the diff
        only due to whitespace or structural reformatting (e.g., removing a `with:` block),
        treat it as unchanged. Contributors should not feel obligated to address
        pre-existing issues outside the scope of their contribution.
    - path: "comfy/**"
      instructions: |
        Core ML/diffusion engine. Focus on:
        - Backward compatibility (breaking changes affect all custom nodes)
        - Memory management and GPU resource handling
        - Performance implications in hot paths
        - Thread safety for concurrent execution
    - path: "comfy_api_nodes/**"
      instructions: |
        Third-party API integration nodes. Focus on:
        - No hardcoded API keys or secrets
        - Proper error handling for API failures (timeouts, rate limits, auth errors)
        - Correct Pydantic model usage
        - Security of user data passed to external APIs
    - path: "comfy_extras/**"
      instructions: |
        Community-contributed extra nodes. Focus on:
        - Consistency with node patterns (INPUT_TYPES, RETURN_TYPES, FUNCTION, CATEGORY)
        - No breaking changes to existing node interfaces
    - path: "comfy_execution/**"
      instructions: |
        Execution engine (graph execution, caching, jobs). Focus on:
        - Caching correctness
        - Concurrent execution safety
        - Graph validation edge cases
    - path: "nodes.py"
      instructions: |
        Core node definitions (2500+ lines). Focus on:
        - Backward compatibility of NODE_CLASS_MAPPINGS
        - Consistency of INPUT_TYPES return format
    - path: "alembic_db/**"
      instructions: |
        Database migrations. Focus on:
        - Migration safety and rollback support
        - Data preservation during schema changes

  auto_review:
    enabled: true
    auto_incremental_review: true
    drafts: false
    ignore_title_keywords:
      - "WIP"
      - "DO NOT REVIEW"
      - "DO NOT MERGE"

  finishing_touches:
    docstrings:
      enabled: false
    unit_tests:
      enabled: false

  tools:
    ruff:
      enabled: false
    pylint:
      enabled: false
    flake8:
      enabled: false
    gitleaks:
      enabled: true
    shellcheck:
      enabled: false
    markdownlint:
      enabled: false
    yamllint:
      enabled: false
    languagetool:
      enabled: false
    github-checks:
      enabled: true
      timeout_ms: 90000
    ast-grep:
      essential_rules: true

chat:
  auto_reply: true

knowledge_base:
  opt_out: false
  learnings:
    scope: "auto"
```
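As a quick sanity check, the config can be parsed locally before pushing. A minimal sketch, assuming PyYAML is installed (`pip install pyyaml`); CodeRabbit itself validates against the schema referenced at the top of the file, so this only catches syntax errors:

```python
# Parse .coderabbit.yaml and list the per-path review instructions.
import yaml

with open(".coderabbit.yaml", "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)

print("language:", config["language"])
for entry in config["reviews"]["path_instructions"]:
    # Each entry pairs a path glob with free-form reviewer guidance.
    print(entry["path"], "->", entry["instructions"].splitlines()[0])
```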
.github/ISSUE_TEMPLATE/bug-report.yml (2 changes, vendored)

```diff
@@ -16,7 +16,7 @@ body:

        ## Very Important

-       Please make sure that you post ALL your ComfyUI logs in the bug report. A bug report without logs will likely be ignored.
+       Please make sure that you post ALL your ComfyUI logs in the bug report **even if there is no crash**. Just paste everything. The startup log (everything before "To see the GUI go to: ...") contains critical information to developers trying to help. For a performance issue or crash, paste everything from "got prompt" to the end, including the crash. More is better - always. A bug report without logs will likely be ignored.
   - type: checkboxes
     id: custom-nodes-test
     attributes:
```
.github/workflows/release-webhook.yml (36 changes, vendored)

```diff
@@ -7,6 +7,8 @@ on:
 jobs:
   send-webhook:
     runs-on: ubuntu-latest
+    env:
+      DESKTOP_REPO_DISPATCH_TOKEN: ${{ secrets.DESKTOP_REPO_DISPATCH_TOKEN }}
     steps:
       - name: Send release webhook
         env:
@@ -106,3 +108,37 @@ jobs:
           --fail --silent --show-error

           echo "✅ Release webhook sent successfully"
+
+      - name: Send repository dispatch to desktop
+        env:
+          DISPATCH_TOKEN: ${{ env.DESKTOP_REPO_DISPATCH_TOKEN }}
+          RELEASE_TAG: ${{ github.event.release.tag_name }}
+          RELEASE_URL: ${{ github.event.release.html_url }}
+        run: |
+          set -euo pipefail
+
+          if [ -z "${DISPATCH_TOKEN:-}" ]; then
+            echo "::error::DESKTOP_REPO_DISPATCH_TOKEN is required but not set."
+            exit 1
+          fi
+
+          PAYLOAD="$(jq -n \
+            --arg release_tag "$RELEASE_TAG" \
+            --arg release_url "$RELEASE_URL" \
+            '{
+              event_type: "comfyui_release_published",
+              client_payload: {
+                release_tag: $release_tag,
+                release_url: $release_url
+              }
+            }')"
+
+          curl -fsSL \
+            -X POST \
+            -H "Accept: application/vnd.github+json" \
+            -H "Content-Type: application/json" \
+            -H "Authorization: Bearer ${DISPATCH_TOKEN}" \
+            https://api.github.com/repos/Comfy-Org/desktop/dispatches \
+            -d "$PAYLOAD"
+
+          echo "✅ Dispatched ComfyUI release ${RELEASE_TAG} to Comfy-Org/desktop"
```
.gitignore (2 changes, vendored)

```diff
@@ -11,7 +11,7 @@ extra_model_paths.yaml
 /.vs
 .vscode/
 .idea/
-venv/
+venv*/
 .venv/
 /web/extensions/*
 !/web/extensions/logging.js.example
```
README.md

```diff
@@ -189,8 +189,6 @@ The portable above currently comes with python 3.13 and pytorch cuda 13.0. Updat

 [Experimental portable for AMD GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_amd.7z)

-[Portable with pytorch cuda 12.8 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu128.7z).

-[Portable with pytorch cuda 12.6 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu126.7z) (Supports Nvidia 10 series and older GPUs).

 #### How do I share models between another UI and ComfyUI?
@@ -227,11 +225,11 @@ Put your VAE in: models/vae

 AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:

-```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.4```
+```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm7.1```

-This is the command to install the nightly with ROCm 7.1 which might have some performance improvements:
+This is the command to install the nightly with ROCm 7.2 which might have some performance improvements:

-```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm7.1```
+```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm7.2```


 ### AMD GPUs (Experimental: Windows and Linux), RDNA 3, 3.5 and 4 only.
```
app/node_replace_manager.py (new file, 107 lines)

```python
from __future__ import annotations

from aiohttp import web

from typing import TYPE_CHECKING, TypedDict
if TYPE_CHECKING:
    from comfy_api.latest._io_public import NodeReplace

from comfy_execution.graph_utils import is_link
import nodes


class NodeStruct(TypedDict):
    inputs: dict[str, str | int | float | bool | tuple[str, int]]
    class_type: str
    _meta: dict[str, str]


def copy_node_struct(node_struct: NodeStruct, empty_inputs: bool = False) -> NodeStruct:
    new_node_struct = node_struct.copy()
    if empty_inputs:
        new_node_struct["inputs"] = {}
    else:
        new_node_struct["inputs"] = node_struct["inputs"].copy()
    new_node_struct["_meta"] = node_struct["_meta"].copy()
    return new_node_struct


class NodeReplaceManager:
    """Manages node replacement registrations."""

    def __init__(self):
        self._replacements: dict[str, list[NodeReplace]] = {}

    def register(self, node_replace: NodeReplace):
        """Register a node replacement mapping."""
        self._replacements.setdefault(node_replace.old_node_id, []).append(node_replace)

    def get_replacement(self, old_node_id: str) -> list[NodeReplace] | None:
        """Get replacements for an old node ID."""
        return self._replacements.get(old_node_id)

    def has_replacement(self, old_node_id: str) -> bool:
        """Check if a replacement exists for an old node ID."""
        return old_node_id in self._replacements

    def apply_replacements(self, prompt: dict[str, NodeStruct]):
        connections: dict[str, list[tuple[str, str, int]]] = {}
        need_replacement: set[str] = set()
        for node_number, node_struct in prompt.items():
            if "class_type" not in node_struct or "inputs" not in node_struct:
                continue
            class_type = node_struct["class_type"]
            # need replacement if not in NODE_CLASS_MAPPINGS and has replacement
            if class_type not in nodes.NODE_CLASS_MAPPINGS.keys() and self.has_replacement(class_type):
                need_replacement.add(node_number)
            # keep track of connections
            for input_id, input_value in node_struct["inputs"].items():
                if is_link(input_value):
                    conn_number = input_value[0]
                    connections.setdefault(conn_number, []).append((node_number, input_id, input_value[1]))
        for node_number in need_replacement:
            node_struct = prompt[node_number]
            class_type = node_struct["class_type"]
            replacements = self.get_replacement(class_type)
            if replacements is None:
                continue
            # just use the first replacement
            replacement = replacements[0]
            new_node_id = replacement.new_node_id
            # if replacement is not a valid node, skip trying to replace it as it will only cause confusion
            if new_node_id not in nodes.NODE_CLASS_MAPPINGS.keys():
                continue
            # first, replace node id (class_type)
            new_node_struct = copy_node_struct(node_struct, empty_inputs=True)
            new_node_struct["class_type"] = new_node_id
            # TODO: consider replacing display_name in _meta as well for error reporting purposes; would need to query node schema
            # second, replace inputs
            if replacement.input_mapping is not None:
                for input_map in replacement.input_mapping:
                    if "set_value" in input_map:
                        new_node_struct["inputs"][input_map["new_id"]] = input_map["set_value"]
                    elif "old_id" in input_map:
                        new_node_struct["inputs"][input_map["new_id"]] = node_struct["inputs"][input_map["old_id"]]
            # finalize input replacement
            prompt[node_number] = new_node_struct
            # third, replace outputs
            if replacement.output_mapping is not None:
                # re-mapping outputs requires changing the input values of nodes that receive connections from this one
                if node_number in connections:
                    for conns in connections[node_number]:
                        conn_node_number, conn_input_id, old_output_idx = conns
                        for output_map in replacement.output_mapping:
                            if output_map["old_idx"] == old_output_idx:
                                new_output_idx = output_map["new_idx"]
                                previous_input = prompt[conn_node_number]["inputs"][conn_input_id]
                                previous_input[1] = new_output_idx

    def as_dict(self):
        """Serialize all replacements to dict."""
        return {
            k: [v.as_dict() for v in v_list]
            for k, v_list in self._replacements.items()
        }

    def add_routes(self, routes):
        @routes.get("/node_replacements")
        async def get_node_replacements(request):
            return web.json_response(self.as_dict())
```
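To illustrate the intended flow, here is a hypothetical usage sketch. It only runs inside a ComfyUI process (it needs the real `nodes.NODE_CLASS_MAPPINGS`), and `SimpleReplace` is a stand-in for the real `NodeReplace` class from `comfy_api.latest._io_public`, whose exact definition is not shown in this diff; only the `old_node_id`/`new_node_id`/`input_mapping`/`output_mapping` fields are assumed from the manager code above:

```python
# Hypothetical stand-in for NodeReplace; field names assumed from the manager above.
from dataclasses import dataclass

@dataclass
class SimpleReplace:
    old_node_id: str
    new_node_id: str
    input_mapping: list | None = None
    output_mapping: list | None = None

    def as_dict(self):
        return self.__dict__

manager = NodeReplaceManager()
manager.register(SimpleReplace(
    old_node_id="OldLoader",                 # deprecated class_type, hypothetical
    new_node_id="CheckpointLoaderSimple",    # must exist in NODE_CLASS_MAPPINGS
    input_mapping=[{"new_id": "ckpt_name", "old_id": "name"}],
))

# A prompt in API format where node "1" still uses the deprecated class_type.
# The replacement only fires when "OldLoader" is absent from NODE_CLASS_MAPPINGS
# and the new id is present, per apply_replacements above.
prompt = {
    "1": {"class_type": "OldLoader", "inputs": {"name": "sd15.safetensors"}, "_meta": {}},
}
manager.apply_replacements(prompt)
print(prompt["1"]["class_type"])  # -> "CheckpointLoaderSimple"
```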
Elsewhere in the diff (the file header is not shown in the source), SubgraphManager.load_entry_data now opens subgraph files with an explicit encoding:

```diff
@@ -53,7 +53,7 @@ class SubgraphManager:
         return entry_id, entry

     async def load_entry_data(self, entry: SubgraphEntry):
-        with open(entry['path'], 'r') as f:
+        with open(entry['path'], 'r', encoding='utf-8') as f:
            entry['data'] = f.read()
        return entry
```
blueprints/.glsl/Brightness_and_Contrast_1.frag (new file, 44 lines)

```glsl
#version 300 es
precision highp float;

uniform sampler2D u_image0;
uniform float u_float0; // Brightness slider -100..100
uniform float u_float1; // Contrast slider -100..100

in vec2 v_texCoord;
out vec4 fragColor;

const float MID_GRAY = 0.18; // 18% reflectance

// sRGB gamma 2.2 approximation
vec3 srgbToLinear(vec3 c) {
    return pow(max(c, 0.0), vec3(2.2));
}

vec3 linearToSrgb(vec3 c) {
    return pow(max(c, 0.0), vec3(1.0/2.2));
}

float mapBrightness(float b) {
    return clamp(b / 100.0, -1.0, 1.0);
}

float mapContrast(float c) {
    return clamp(c / 100.0 + 1.0, 0.0, 2.0);
}

void main() {
    vec4 orig = texture(u_image0, v_texCoord);

    float brightness = mapBrightness(u_float0);
    float contrast = mapContrast(u_float1);

    vec3 lin = srgbToLinear(orig.rgb);

    lin = (lin - MID_GRAY) * contrast + brightness + MID_GRAY;

    // Convert back to sRGB
    vec3 result = linearToSrgb(clamp(lin, 0.0, 1.0));

    fragColor = vec4(result, orig.a);
}
```
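The same tone math is easy to verify outside the GPU. A minimal NumPy sketch of the shader's pipeline (gamma-2.2 linearization, pivot around 18% gray, re-encode); the function name and the HxWx3 float layout in [0, 1] are assumptions for illustration, not part of the diff:

```python
import numpy as np

MID_GRAY = 0.18

def brightness_contrast(img: np.ndarray, brightness: float, contrast: float) -> np.ndarray:
    """img: HxWx3 float32 in [0,1]; brightness/contrast: sliders in -100..100."""
    b = np.clip(brightness / 100.0, -1.0, 1.0)
    c = np.clip(contrast / 100.0 + 1.0, 0.0, 2.0)
    lin = np.power(np.maximum(img, 0.0), 2.2)           # sRGB -> approx. linear
    lin = (lin - MID_GRAY) * c + b + MID_GRAY           # contrast pivots at mid-gray
    return np.power(np.clip(lin, 0.0, 1.0), 1.0 / 2.2)  # back to sRGB

out = brightness_contrast(np.random.rand(4, 4, 3).astype(np.float32), 10.0, 25.0)
```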
blueprints/.glsl/Chromatic_Aberration_16.frag (new file, 72 lines)

```glsl
#version 300 es
precision highp float;

uniform sampler2D u_image0;
uniform vec2 u_resolution;
uniform int u_int0;     // Mode
uniform float u_float0; // Amount (0 to 100)

in vec2 v_texCoord;
out vec4 fragColor;

const int MODE_LINEAR = 0;
const int MODE_RADIAL = 1;
const int MODE_BARREL = 2;
const int MODE_SWIRL = 3;
const int MODE_DIAGONAL = 4;

const float AMOUNT_SCALE = 0.0005;
const float RADIAL_MULT = 4.0;
const float BARREL_MULT = 8.0;
const float INV_SQRT2 = 0.70710678118;

void main() {
    vec2 uv = v_texCoord;
    vec4 original = texture(u_image0, uv);

    float amount = u_float0 * AMOUNT_SCALE;

    if (amount < 0.000001) {
        fragColor = original;
        return;
    }

    // Aspect-corrected coordinates for circular effects
    float aspect = u_resolution.x / u_resolution.y;
    vec2 centered = uv - 0.5;
    vec2 corrected = vec2(centered.x * aspect, centered.y);
    float r = length(corrected);
    vec2 dir = r > 0.0001 ? corrected / r : vec2(0.0);
    vec2 offset = vec2(0.0);

    if (u_int0 == MODE_LINEAR) {
        // Horizontal shift (no aspect correction needed)
        offset = vec2(amount, 0.0);
    }
    else if (u_int0 == MODE_RADIAL) {
        // Outward from center, stronger at edges
        offset = dir * r * amount * RADIAL_MULT;
        offset.x /= aspect; // Convert back to UV space
    }
    else if (u_int0 == MODE_BARREL) {
        // Lens distortion simulation (r² falloff)
        offset = dir * r * r * amount * BARREL_MULT;
        offset.x /= aspect; // Convert back to UV space
    }
    else if (u_int0 == MODE_SWIRL) {
        // Perpendicular to radial (rotational aberration)
        vec2 perp = vec2(-dir.y, dir.x);
        offset = perp * r * amount * RADIAL_MULT;
        offset.x /= aspect; // Convert back to UV space
    }
    else if (u_int0 == MODE_DIAGONAL) {
        // 45° offset (no aspect correction needed)
        offset = vec2(amount, amount) * INV_SQRT2;
    }

    float red = texture(u_image0, uv + offset).r;
    float green = original.g;
    float blue = texture(u_image0, uv - offset).b;

    fragColor = vec4(red, green, blue, original.a);
}
```
blueprints/.glsl/Color_Adjustment_15.frag (new file, 78 lines)

```glsl
#version 300 es
precision highp float;

uniform sampler2D u_image0;
uniform float u_float0; // temperature (-100 to 100)
uniform float u_float1; // tint (-100 to 100)
uniform float u_float2; // vibrance (-100 to 100)
uniform float u_float3; // saturation (-100 to 100)

in vec2 v_texCoord;
out vec4 fragColor;

const float INPUT_SCALE = 0.01;
const float TEMP_TINT_PRIMARY = 0.3;
const float TEMP_TINT_SECONDARY = 0.15;
const float VIBRANCE_BOOST = 2.0;
const float SATURATION_BOOST = 2.0;
const float SKIN_PROTECTION = 0.5;
const float EPSILON = 0.001;
const vec3 LUMA_WEIGHTS = vec3(0.299, 0.587, 0.114);

void main() {
    vec4 tex = texture(u_image0, v_texCoord);
    vec3 color = tex.rgb;

    // Scale inputs: -100/100 → -1/1
    float temperature = u_float0 * INPUT_SCALE;
    float tint = u_float1 * INPUT_SCALE;
    float vibrance = u_float2 * INPUT_SCALE;
    float saturation = u_float3 * INPUT_SCALE;

    // Temperature (warm/cool): positive = warm, negative = cool
    color.r += temperature * TEMP_TINT_PRIMARY;
    color.b -= temperature * TEMP_TINT_PRIMARY;

    // Tint (green/magenta): positive = green, negative = magenta
    color.g += tint * TEMP_TINT_PRIMARY;
    color.r -= tint * TEMP_TINT_SECONDARY;
    color.b -= tint * TEMP_TINT_SECONDARY;

    // Single clamp after temperature/tint
    color = clamp(color, 0.0, 1.0);

    // Vibrance with skin protection
    if (vibrance != 0.0) {
        float maxC = max(color.r, max(color.g, color.b));
        float minC = min(color.r, min(color.g, color.b));
        float sat = maxC - minC;
        float gray = dot(color, LUMA_WEIGHTS);

        if (vibrance < 0.0) {
            // Desaturate: -100 → gray
            color = mix(vec3(gray), color, 1.0 + vibrance);
        } else {
            // Boost less saturated colors more
            float vibranceAmt = vibrance * (1.0 - sat);

            // Branchless skin tone protection
            float isWarmTone = step(color.b, color.g) * step(color.g, color.r);
            float warmth = (color.r - color.b) / max(maxC, EPSILON);
            float skinTone = isWarmTone * warmth * sat * (1.0 - sat);
            vibranceAmt *= (1.0 - skinTone * SKIN_PROTECTION);

            color = mix(vec3(gray), color, 1.0 + vibranceAmt * VIBRANCE_BOOST);
        }
    }

    // Saturation
    if (saturation != 0.0) {
        float gray = dot(color, LUMA_WEIGHTS);
        float satMix = saturation < 0.0
            ? 1.0 + saturation                     // -100 → gray
            : 1.0 + saturation * SATURATION_BOOST; // +100 → 3x boost
        color = mix(vec3(gray), color, satMix);
    }

    fragColor = vec4(clamp(color, 0.0, 1.0), tex.a);
}
```
blueprints/.glsl/Edge-Preserving_Blur_128.frag (new file, 94 lines)

```glsl
#version 300 es
precision highp float;

uniform sampler2D u_image0;
uniform float u_float0; // Blur radius (0–20, default ~5)
uniform float u_float1; // Edge threshold (0–100, default ~30)
uniform int u_int0;     // Step size (0/1 = every pixel, 2+ = skip pixels)

in vec2 v_texCoord;
out vec4 fragColor;

const int MAX_RADIUS = 20;
const float EPSILON = 0.0001;

// Perceptual luminance
float getLuminance(vec3 rgb) {
    return dot(rgb, vec3(0.299, 0.587, 0.114));
}

vec4 bilateralFilter(vec2 uv, vec2 texelSize, int radius,
                     float sigmaSpatial, float sigmaColor)
{
    vec4 center = texture(u_image0, uv);
    vec3 centerRGB = center.rgb;

    float invSpatial2 = -0.5 / (sigmaSpatial * sigmaSpatial);
    float invColor2 = -0.5 / (sigmaColor * sigmaColor + EPSILON);

    vec3 sumRGB = vec3(0.0);
    float sumWeight = 0.0;

    int step = max(u_int0, 1);
    float radius2 = float(radius * radius);

    for (int dy = -MAX_RADIUS; dy <= MAX_RADIUS; dy++) {
        if (dy < -radius || dy > radius) continue;
        if (abs(dy) % step != 0) continue;

        for (int dx = -MAX_RADIUS; dx <= MAX_RADIUS; dx++) {
            if (dx < -radius || dx > radius) continue;
            if (abs(dx) % step != 0) continue;

            vec2 offset = vec2(float(dx), float(dy));
            float dist2 = dot(offset, offset);
            if (dist2 > radius2) continue;

            vec3 sampleRGB = texture(u_image0, uv + offset * texelSize).rgb;

            // Spatial Gaussian
            float spatialWeight = exp(dist2 * invSpatial2);

            // Perceptual color distance (weighted RGB)
            vec3 diff = sampleRGB - centerRGB;
            float colorDist = dot(diff * diff, vec3(0.299, 0.587, 0.114));
            float colorWeight = exp(colorDist * invColor2);

            float w = spatialWeight * colorWeight;
            sumRGB += sampleRGB * w;
            sumWeight += w;
        }
    }

    vec3 resultRGB = sumRGB / max(sumWeight, EPSILON);
    return vec4(resultRGB, center.a); // preserve center alpha
}

void main() {
    vec2 texelSize = 1.0 / vec2(textureSize(u_image0, 0));

    float radiusF = clamp(u_float0, 0.0, float(MAX_RADIUS));
    int radius = int(radiusF + 0.5);

    if (radius == 0) {
        fragColor = texture(u_image0, v_texCoord);
        return;
    }

    // Edge threshold → color sigma
    // Squared curve for better low-end control
    float t = clamp(u_float1, 0.0, 100.0) / 100.0;
    t *= t;
    float sigmaColor = mix(0.01, 0.5, t);

    // Spatial sigma tied to radius
    float sigmaSpatial = max(radiusF * 0.75, 0.5);

    fragColor = bilateralFilter(
        v_texCoord,
        texelSize,
        radius,
        sigmaSpatial,
        sigmaColor
    );
}
```
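For reference, a bilateral filter of this shape is straightforward (if slow) to reproduce on the CPU: the weight of each neighbor is a spatial Gaussian times a Gaussian of the luma-weighted color distance. A minimal NumPy sketch with the same weighting; the default sigmas mirror the shader's radius-5 case (sigmaSpatial = 0.75 * radius), and everything else here is illustrative, not part of the diff:

```python
import numpy as np

LUMA = np.array([0.299, 0.587, 0.114])

def bilateral(img, radius=5, sigma_spatial=3.75, sigma_color=0.1):
    """img: HxWx3 float in [0,1]. O(H*W*r^2) Python loops; illustration only."""
    h, w, _ = img.shape
    pad = np.pad(img, ((radius, radius), (radius, radius), (0, 0)), mode="edge")
    dy, dx = np.mgrid[-radius:radius + 1, -radius:radius + 1]
    spatial = np.exp(-0.5 * (dx**2 + dy**2) / sigma_spatial**2)  # fixed per-offset weight
    out = np.zeros_like(img)
    for y in range(h):
        for x in range(w):
            center = pad[y + radius, x + radius]
            patch = pad[y:y + 2 * radius + 1, x:x + 2 * radius + 1]
            diff = patch - center
            color_dist = (diff**2) @ LUMA                  # luma-weighted color distance
            colorw = np.exp(-0.5 * color_dist / sigma_color**2)
            wgt = spatial * colorw
            out[y, x] = (patch * wgt[..., None]).sum((0, 1)) / wgt.sum()
    return out

out = bilateral(np.random.rand(16, 16, 3))
```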
blueprints/.glsl/Film_Grain_15.frag (new file, 124 lines)

```glsl
#version 300 es
precision highp float;

uniform sampler2D u_image0;
uniform vec2 u_resolution;
uniform float u_float0; // grain amount   [0.0 – 1.0] typical: 0.2–0.8
uniform float u_float1; // grain size     [0.3 – 3.0] lower = finer grain
uniform float u_float2; // color amount   [0.0 – 1.0] 0 = monochrome, 1 = RGB grain
uniform float u_float3; // luminance bias [0.0 – 1.0] 0 = uniform, 1 = shadows only
uniform int u_int0;     // noise mode     [0 or 1]    0 = smooth, 1 = grainy

in vec2 v_texCoord;
layout(location = 0) out vec4 fragColor0;

// High-quality integer hash (pcg-like)
uint pcg(uint v) {
    uint state = v * 747796405u + 2891336453u;
    uint word = ((state >> ((state >> 28u) + 4u)) ^ state) * 277803737u;
    return (word >> 22u) ^ word;
}

// 2D -> 1D hash input
uint hash2d(uvec2 p) {
    return pcg(p.x + pcg(p.y));
}

// Hash to float [0, 1]
float hashf(uvec2 p) {
    return float(hash2d(p)) / float(0xffffffffu);
}

// Hash to float with offset (for RGB channels)
float hashf(uvec2 p, uint offset) {
    return float(pcg(hash2d(p) + offset)) / float(0xffffffffu);
}

// Convert uniform [0,1] to roughly Gaussian distribution
// Using simple approximation: average of multiple samples
float toGaussian(uvec2 p) {
    float sum = hashf(p, 0u) + hashf(p, 1u) + hashf(p, 2u) + hashf(p, 3u);
    return (sum - 2.0) * 0.7; // Centered, scaled
}

float toGaussian(uvec2 p, uint offset) {
    float sum = hashf(p, offset) + hashf(p, offset + 1u)
              + hashf(p, offset + 2u) + hashf(p, offset + 3u);
    return (sum - 2.0) * 0.7;
}

// Smooth noise with better interpolation
float smoothNoise(vec2 p) {
    vec2 i = floor(p);
    vec2 f = fract(p);

    // Quintic interpolation (less banding than cubic)
    f = f * f * f * (f * (f * 6.0 - 15.0) + 10.0);

    uvec2 ui = uvec2(i);
    float a = toGaussian(ui);
    float b = toGaussian(ui + uvec2(1u, 0u));
    float c = toGaussian(ui + uvec2(0u, 1u));
    float d = toGaussian(ui + uvec2(1u, 1u));

    return mix(mix(a, b, f.x), mix(c, d, f.x), f.y);
}

float smoothNoise(vec2 p, uint offset) {
    vec2 i = floor(p);
    vec2 f = fract(p);

    f = f * f * f * (f * (f * 6.0 - 15.0) + 10.0);

    uvec2 ui = uvec2(i);
    float a = toGaussian(ui, offset);
    float b = toGaussian(ui + uvec2(1u, 0u), offset);
    float c = toGaussian(ui + uvec2(0u, 1u), offset);
    float d = toGaussian(ui + uvec2(1u, 1u), offset);

    return mix(mix(a, b, f.x), mix(c, d, f.x), f.y);
}

void main() {
    vec4 color = texture(u_image0, v_texCoord);

    // Luminance (Rec.709)
    float luma = dot(color.rgb, vec3(0.2126, 0.7152, 0.0722));

    // Grain UV (resolution-independent)
    vec2 grainUV = v_texCoord * u_resolution / max(u_float1, 0.01);
    uvec2 grainPixel = uvec2(grainUV);

    float g;
    vec3 grainRGB;

    if (u_int0 == 1) {
        // Grainy mode: pure hash noise (no interpolation = no banding)
        g = toGaussian(grainPixel);
        grainRGB = vec3(
            toGaussian(grainPixel, 100u),
            toGaussian(grainPixel, 200u),
            toGaussian(grainPixel, 300u)
        );
    } else {
        // Smooth mode: interpolated with quintic curve
        g = smoothNoise(grainUV);
        grainRGB = vec3(
            smoothNoise(grainUV, 100u),
            smoothNoise(grainUV, 200u),
            smoothNoise(grainUV, 300u)
        );
    }

    // Luminance weighting (less grain in highlights)
    float lumWeight = mix(1.0, 1.0 - luma, clamp(u_float3, 0.0, 1.0));

    // Strength
    float strength = u_float0 * 0.15;

    // Color vs monochrome grain
    vec3 grainColor = mix(vec3(g), grainRGB, clamp(u_float2, 0.0, 1.0));

    color.rgb += grainColor * strength * lumWeight;
    fragColor0 = vec4(clamp(color.rgb, 0.0, 1.0), color.a);
}
```
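The grain pattern is fully deterministic per pixel, so the hash chain can be checked on the CPU. A small Python port of the PCG-style hash above, with explicit uint32 wrapping to match GLSL arithmetic (a sketch for verification, not part of the PR):

```python
M = 0xFFFFFFFF  # force 32-bit wraparound, matching GLSL uint

def pcg(v: int) -> int:
    state = (v * 747796405 + 2891336453) & M
    word = (((state >> ((state >> 28) + 4)) ^ state) * 277803737) & M
    return ((word >> 22) ^ word) & M

def hashf(x: int, y: int) -> float:
    """Hash a pixel coordinate to [0, 1], like hashf(uvec2) in the shader."""
    return pcg((x + pcg(y)) & M) / M

print(hashf(10, 20), hashf(10, 21))  # decorrelated values for adjacent pixels
```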
blueprints/.glsl/Glow_30.frag (new file, 133 lines)

```glsl
#version 300 es
precision mediump float;

uniform sampler2D u_image0;
uniform vec2 u_resolution;
uniform int u_int0;     // Blend mode
uniform int u_int1;     // Color tint
uniform float u_float0; // Intensity
uniform float u_float1; // Radius
uniform float u_float2; // Threshold

in vec2 v_texCoord;
out vec4 fragColor;

const int BLEND_ADD = 0;
const int BLEND_SCREEN = 1;
const int BLEND_SOFT = 2;
const int BLEND_OVERLAY = 3;
const int BLEND_LIGHTEN = 4;

const float GOLDEN_ANGLE = 2.39996323;
const int MAX_SAMPLES = 48;
const vec3 LUMA = vec3(0.299, 0.587, 0.114);

float hash(vec2 p) {
    p = fract(p * vec2(123.34, 456.21));
    p += dot(p, p + 45.32);
    return fract(p.x * p.y);
}

vec3 hexToRgb(int h) {
    return vec3(
        float((h >> 16) & 255),
        float((h >> 8) & 255),
        float(h & 255)
    ) * (1.0 / 255.0);
}

vec3 blend(vec3 base, vec3 glow, int mode) {
    if (mode == BLEND_SCREEN) {
        return 1.0 - (1.0 - base) * (1.0 - glow);
    }
    if (mode == BLEND_SOFT) {
        return mix(
            base - (1.0 - 2.0 * glow) * base * (1.0 - base),
            base + (2.0 * glow - 1.0) * (sqrt(base) - base),
            step(0.5, glow)
        );
    }
    if (mode == BLEND_OVERLAY) {
        return mix(
            2.0 * base * glow,
            1.0 - 2.0 * (1.0 - base) * (1.0 - glow),
            step(0.5, base)
        );
    }
    if (mode == BLEND_LIGHTEN) {
        return max(base, glow);
    }
    return base + glow;
}

void main() {
    vec4 original = texture(u_image0, v_texCoord);

    float intensity = u_float0 * 0.05;
    float radius = u_float1 * u_float1 * 0.012;

    if (intensity < 0.001 || radius < 0.1) {
        fragColor = original;
        return;
    }

    float threshold = 1.0 - u_float2 * 0.01;
    float t0 = threshold - 0.15;
    float t1 = threshold + 0.15;

    vec2 texelSize = 1.0 / u_resolution;
    float radius2 = radius * radius;

    float sampleScale = clamp(radius * 0.75, 0.35, 1.0);
    int samples = int(float(MAX_SAMPLES) * sampleScale);

    float noise = hash(gl_FragCoord.xy);
    float angleOffset = noise * GOLDEN_ANGLE;
    float radiusJitter = 0.85 + noise * 0.3;

    float ca = cos(GOLDEN_ANGLE);
    float sa = sin(GOLDEN_ANGLE);
    vec2 dir = vec2(cos(angleOffset), sin(angleOffset));

    vec3 glow = vec3(0.0);
    float totalWeight = 0.0;

    // Center tap
    float centerMask = smoothstep(t0, t1, dot(original.rgb, LUMA));
    glow += original.rgb * centerMask * 2.0;
    totalWeight += 2.0;

    for (int i = 1; i < MAX_SAMPLES; i++) {
        if (i >= samples) break;

        float fi = float(i);
        float dist = sqrt(fi / float(samples)) * radius * radiusJitter;

        vec2 offset = dir * dist * texelSize;
        vec3 c = texture(u_image0, v_texCoord + offset).rgb;
        float mask = smoothstep(t0, t1, dot(c, LUMA));

        float w = 1.0 - (dist * dist) / (radius2 * 1.5);
        w = max(w, 0.0);
        w *= w;

        glow += c * mask * w;
        totalWeight += w;

        dir = vec2(
            dir.x * ca - dir.y * sa,
            dir.x * sa + dir.y * ca
        );
    }

    glow *= intensity / max(totalWeight, 0.001);

    if (u_int1 > 0) {
        glow *= hexToRgb(u_int1);
    }

    vec3 result = blend(original.rgb, glow, u_int0);
    result += (noise - 0.5) * (1.0 / 255.0);

    fragColor = vec4(clamp(result, 0.0, 1.0), original.a);
}
```
blueprints/.glsl/Hue_and_Saturation_1.frag (new file, 222 lines)

```glsl
#version 300 es
precision highp float;

uniform sampler2D u_image0;
uniform int u_int0;     // Mode: 0=Master, 1=Reds, 2=Yellows, 3=Greens, 4=Cyans, 5=Blues, 6=Magentas, 7=Colorize
uniform int u_int1;     // Color Space: 0=HSL, 1=HSB/HSV
uniform float u_float0; // Hue (-180 to 180)
uniform float u_float1; // Saturation (-100 to 100)
uniform float u_float2; // Lightness/Brightness (-100 to 100)
uniform float u_float3; // Overlap (0 to 100) - feathering between adjacent color ranges

in vec2 v_texCoord;
out vec4 fragColor;

// Color range modes
const int MODE_MASTER = 0;
const int MODE_RED = 1;
const int MODE_YELLOW = 2;
const int MODE_GREEN = 3;
const int MODE_CYAN = 4;
const int MODE_BLUE = 5;
const int MODE_MAGENTA = 6;
const int MODE_COLORIZE = 7;

// Color space modes
const int COLORSPACE_HSL = 0;
const int COLORSPACE_HSB = 1;

const float EPSILON = 0.0001;

//=============================================================================
// RGB <-> HSL Conversions
//=============================================================================

vec3 rgb2hsl(vec3 c) {
    float maxC = max(max(c.r, c.g), c.b);
    float minC = min(min(c.r, c.g), c.b);
    float delta = maxC - minC;

    float h = 0.0;
    float s = 0.0;
    float l = (maxC + minC) * 0.5;

    if (delta > EPSILON) {
        s = l < 0.5
            ? delta / (maxC + minC)
            : delta / (2.0 - maxC - minC);

        if (maxC == c.r) {
            h = (c.g - c.b) / delta + (c.g < c.b ? 6.0 : 0.0);
        } else if (maxC == c.g) {
            h = (c.b - c.r) / delta + 2.0;
        } else {
            h = (c.r - c.g) / delta + 4.0;
        }
        h /= 6.0;
    }

    return vec3(h, s, l);
}

float hue2rgb(float p, float q, float t) {
    t = fract(t);
    if (t < 1.0/6.0) return p + (q - p) * 6.0 * t;
    if (t < 0.5) return q;
    if (t < 2.0/3.0) return p + (q - p) * (2.0/3.0 - t) * 6.0;
    return p;
}

vec3 hsl2rgb(vec3 hsl) {
    if (hsl.y < EPSILON) return vec3(hsl.z);

    float q = hsl.z < 0.5
        ? hsl.z * (1.0 + hsl.y)
        : hsl.z + hsl.y - hsl.z * hsl.y;
    float p = 2.0 * hsl.z - q;

    return vec3(
        hue2rgb(p, q, hsl.x + 1.0/3.0),
        hue2rgb(p, q, hsl.x),
        hue2rgb(p, q, hsl.x - 1.0/3.0)
    );
}

vec3 rgb2hsb(vec3 c) {
    float maxC = max(max(c.r, c.g), c.b);
    float minC = min(min(c.r, c.g), c.b);
    float delta = maxC - minC;

    float h = 0.0;
    float s = (maxC > EPSILON) ? delta / maxC : 0.0;
    float b = maxC;

    if (delta > EPSILON) {
        if (maxC == c.r) {
            h = (c.g - c.b) / delta + (c.g < c.b ? 6.0 : 0.0);
        } else if (maxC == c.g) {
            h = (c.b - c.r) / delta + 2.0;
        } else {
            h = (c.r - c.g) / delta + 4.0;
        }
        h /= 6.0;
    }

    return vec3(h, s, b);
}

vec3 hsb2rgb(vec3 hsb) {
    vec3 rgb = clamp(abs(mod(hsb.x * 6.0 + vec3(0.0, 4.0, 2.0), 6.0) - 3.0) - 1.0, 0.0, 1.0);
    return hsb.z * mix(vec3(1.0), rgb, hsb.y);
}

//=============================================================================
// Color Range Weight Calculation
//=============================================================================

float hueDistance(float a, float b) {
    float d = abs(a - b);
    return min(d, 1.0 - d);
}

float getHueWeight(float hue, float center, float overlap) {
    float baseWidth = 1.0 / 6.0;
    float feather = baseWidth * overlap;

    float d = hueDistance(hue, center);

    float inner = baseWidth * 0.5;
    float outer = inner + feather;

    return 1.0 - smoothstep(inner, outer, d);
}

float getModeWeight(float hue, int mode, float overlap) {
    if (mode == MODE_MASTER || mode == MODE_COLORIZE) return 1.0;

    if (mode == MODE_RED) {
        return max(
            getHueWeight(hue, 0.0, overlap),
            getHueWeight(hue, 1.0, overlap)
        );
    }

    float center = float(mode - 1) / 6.0;
    return getHueWeight(hue, center, overlap);
}

//=============================================================================
// Adjustment Functions
//=============================================================================

float adjustLightness(float l, float amount) {
    return amount > 0.0
        ? l + (1.0 - l) * amount
        : l + l * amount;
}

float adjustBrightness(float b, float amount) {
    return clamp(b + amount, 0.0, 1.0);
}

float adjustSaturation(float s, float amount) {
    return amount > 0.0
        ? s + (1.0 - s) * amount
        : s + s * amount;
}

vec3 colorize(vec3 rgb, float hue, float sat, float light) {
    float lum = dot(rgb, vec3(0.299, 0.587, 0.114));
    float l = adjustLightness(lum, light);

    vec3 hsl = vec3(fract(hue), clamp(sat, 0.0, 1.0), clamp(l, 0.0, 1.0));
    return hsl2rgb(hsl);
}

//=============================================================================
// Main
//=============================================================================

void main() {
    vec4 original = texture(u_image0, v_texCoord);

    float hueShift = u_float0 / 360.0;    // -180..180 -> -0.5..0.5
    float satAmount = u_float1 / 100.0;   // -100..100 -> -1..1
    float lightAmount = u_float2 / 100.0; // -100..100 -> -1..1
    float overlap = u_float3 / 100.0;     // 0..100 -> 0..1

    vec3 result;

    if (u_int0 == MODE_COLORIZE) {
        result = colorize(original.rgb, hueShift, satAmount, lightAmount);
        fragColor = vec4(result, original.a);
        return;
    }

    vec3 hsx = (u_int1 == COLORSPACE_HSL)
        ? rgb2hsl(original.rgb)
        : rgb2hsb(original.rgb);

    float weight = getModeWeight(hsx.x, u_int0, overlap);

    if (u_int0 != MODE_MASTER && hsx.y < EPSILON) {
        weight = 0.0;
    }

    if (weight > EPSILON) {
        float h = fract(hsx.x + hueShift * weight);
        float s = clamp(adjustSaturation(hsx.y, satAmount * weight), 0.0, 1.0);
        float v = (u_int1 == COLORSPACE_HSL)
            ? clamp(adjustLightness(hsx.z, lightAmount * weight), 0.0, 1.0)
            : clamp(adjustBrightness(hsx.z, lightAmount * weight), 0.0, 1.0);

        vec3 adjusted = vec3(h, s, v);
        result = (u_int1 == COLORSPACE_HSL)
            ? hsl2rgb(adjusted)
            : hsb2rgb(adjusted);
    } else {
        result = original.rgb;
    }

    fragColor = vec4(result, original.a);
}
```
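The RGB/HSL conversions here follow the standard formulas, so they can be cross-checked against Python's stdlib `colorsys`. Note the ordering difference: `colorsys` uses HLS (hue, lightness, saturation) while the shader packs (h, s, l). A quick check (the sample color is arbitrary):

```python
import colorsys

r, g, b = 0.8, 0.3, 0.2
h, l, s = colorsys.rgb_to_hls(r, g, b)
print(f"h={h:.4f} s={s:.4f} l={l:.4f}")  # expect the same triple as rgb2hsl(vec3(0.8, 0.3, 0.2))

# Round trip, mirroring hsl2rgb:
print(colorsys.hls_to_rgb(h, l, s))      # -> approximately (0.8, 0.3, 0.2)
```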
blueprints/.glsl/Image_Blur_1.frag (new file, 111 lines)

```glsl
#version 300 es
#pragma passes 2
precision highp float;

// Blur type constants
const int BLUR_GAUSSIAN = 0;
const int BLUR_BOX = 1;
const int BLUR_RADIAL = 2;

// Radial blur config
const int RADIAL_SAMPLES = 12;
const float RADIAL_STRENGTH = 0.0003;

uniform sampler2D u_image0;
uniform vec2 u_resolution;
uniform int u_int0;     // Blur type (BLUR_GAUSSIAN, BLUR_BOX, BLUR_RADIAL)
uniform float u_float0; // Blur radius/amount
uniform int u_pass;     // Pass index (0 = horizontal, 1 = vertical)

in vec2 v_texCoord;
layout(location = 0) out vec4 fragColor0;

float gaussian(float x, float sigma) {
    return exp(-(x * x) / (2.0 * sigma * sigma));
}

void main() {
    vec2 texelSize = 1.0 / u_resolution;
    float radius = max(u_float0, 0.0);

    // Radial (angular) blur - single pass, doesn't use separable
    if (u_int0 == BLUR_RADIAL) {
        // Only execute on first pass
        if (u_pass > 0) {
            fragColor0 = texture(u_image0, v_texCoord);
            return;
        }

        vec2 center = vec2(0.5);
        vec2 dir = v_texCoord - center;
        float dist = length(dir);

        if (dist < 1e-4) {
            fragColor0 = texture(u_image0, v_texCoord);
            return;
        }

        vec4 sum = vec4(0.0);
        float totalWeight = 0.0;
        float angleStep = radius * RADIAL_STRENGTH;

        dir /= dist;

        float cosStep = cos(angleStep);
        float sinStep = sin(angleStep);

        float negAngle = -float(RADIAL_SAMPLES) * angleStep;
        vec2 rotDir = vec2(
            dir.x * cos(negAngle) - dir.y * sin(negAngle),
            dir.x * sin(negAngle) + dir.y * cos(negAngle)
        );

        for (int i = -RADIAL_SAMPLES; i <= RADIAL_SAMPLES; i++) {
            vec2 uv = center + rotDir * dist;
            float w = 1.0 - abs(float(i)) / float(RADIAL_SAMPLES);
            sum += texture(u_image0, uv) * w;
            totalWeight += w;

            rotDir = vec2(
                rotDir.x * cosStep - rotDir.y * sinStep,
                rotDir.x * sinStep + rotDir.y * cosStep
            );
        }

        fragColor0 = sum / max(totalWeight, 0.001);
        return;
    }

    // Separable Gaussian / Box blur
    int samples = int(ceil(radius));

    if (samples == 0) {
        fragColor0 = texture(u_image0, v_texCoord);
        return;
    }

    // Direction: pass 0 = horizontal, pass 1 = vertical
    vec2 dir = (u_pass == 0) ? vec2(1.0, 0.0) : vec2(0.0, 1.0);

    vec4 color = vec4(0.0);
    float totalWeight = 0.0;
    float sigma = radius / 2.0;

    for (int i = -samples; i <= samples; i++) {
        vec2 offset = dir * float(i) * texelSize;
        vec4 sample_color = texture(u_image0, v_texCoord + offset);

        float weight;
        if (u_int0 == BLUR_GAUSSIAN) {
            weight = gaussian(float(i), sigma);
        } else {
            // BLUR_BOX
            weight = 1.0;
        }

        color += sample_color * weight;
        totalWeight += weight;
    }

    fragColor0 = color / totalWeight;
}
```
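The two-pass design (`#pragma passes 2`, horizontal then vertical) works because a 2D Gaussian kernel is the outer product of two 1D kernels, so two 1D convolutions equal one 2D convolution at far lower cost. A small SciPy sketch demonstrating the equivalence (purely illustrative, not part of the diff):

```python
import numpy as np
from scipy.ndimage import convolve, convolve1d

radius, sigma = 4, 2.0
x = np.arange(-radius, radius + 1)
k1 = np.exp(-(x * x) / (2 * sigma * sigma))
k1 /= k1.sum()  # normalized 1D Gaussian

img = np.random.rand(32, 32)
two_pass = convolve1d(convolve1d(img, k1, axis=1), k1, axis=0)  # H pass, then V pass
one_pass = convolve(img, np.outer(k1, k1))                      # full 2D kernel
print(np.allclose(two_pass, one_pass))  # True: separable kernels factor exactly
```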
blueprints/.glsl/Image_Channels_23.frag (new file, 19 lines)

```glsl
#version 300 es
precision highp float;

uniform sampler2D u_image0;

in vec2 v_texCoord;
layout(location = 0) out vec4 fragColor0;
layout(location = 1) out vec4 fragColor1;
layout(location = 2) out vec4 fragColor2;
layout(location = 3) out vec4 fragColor3;

void main() {
    vec4 color = texture(u_image0, v_texCoord);
    // Output each channel as grayscale to separate render targets
    fragColor0 = vec4(vec3(color.r), 1.0); // Red channel
    fragColor1 = vec4(vec3(color.g), 1.0); // Green channel
    fragColor2 = vec4(vec3(color.b), 1.0); // Blue channel
    fragColor3 = vec4(vec3(color.a), 1.0); // Alpha channel
}
```
blueprints/.glsl/Image_Levels_1.frag (new file, 71 lines)

```glsl
#version 300 es
precision highp float;

// Levels Adjustment
// u_int0:   channel (0=RGB, 1=R, 2=G, 3=B) default: 0
// u_float0: input black (0-255)            default: 0
// u_float1: input white (0-255)            default: 255
// u_float2: gamma (0.01-9.99)              default: 1.0
// u_float3: output black (0-255)           default: 0
// u_float4: output white (0-255)           default: 255

uniform sampler2D u_image0;
uniform int u_int0;
uniform float u_float0;
uniform float u_float1;
uniform float u_float2;
uniform float u_float3;
uniform float u_float4;

in vec2 v_texCoord;
out vec4 fragColor;

vec3 applyLevels(vec3 color, float inBlack, float inWhite, float gamma, float outBlack, float outWhite) {
    float inRange = max(inWhite - inBlack, 0.0001);
    vec3 result = clamp((color - inBlack) / inRange, 0.0, 1.0);
    result = pow(result, vec3(1.0 / gamma));
    result = mix(vec3(outBlack), vec3(outWhite), result);
    return result;
}

float applySingleChannel(float value, float inBlack, float inWhite, float gamma, float outBlack, float outWhite) {
    float inRange = max(inWhite - inBlack, 0.0001);
    float result = clamp((value - inBlack) / inRange, 0.0, 1.0);
    result = pow(result, 1.0 / gamma);
    result = mix(outBlack, outWhite, result);
    return result;
}

void main() {
    vec4 texColor = texture(u_image0, v_texCoord);
    vec3 color = texColor.rgb;

    float inBlack = u_float0 / 255.0;
    float inWhite = u_float1 / 255.0;
    float gamma = u_float2;
    float outBlack = u_float3 / 255.0;
    float outWhite = u_float4 / 255.0;

    vec3 result;

    if (u_int0 == 0) {
        result = applyLevels(color, inBlack, inWhite, gamma, outBlack, outWhite);
    }
    else if (u_int0 == 1) {
        result = color;
        result.r = applySingleChannel(color.r, inBlack, inWhite, gamma, outBlack, outWhite);
    }
    else if (u_int0 == 2) {
        result = color;
        result.g = applySingleChannel(color.g, inBlack, inWhite, gamma, outBlack, outWhite);
    }
    else if (u_int0 == 3) {
        result = color;
        result.b = applySingleChannel(color.b, inBlack, inWhite, gamma, outBlack, outWhite);
    }
    else {
        result = color;
    }

    fragColor = vec4(result, texColor.a);
}
```
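The levels mapping is three steps: remap the input black/white points, apply gamma, then remap into the output range. A NumPy sketch of the same formula (parameter ranges taken from the header comments above; the function itself is illustrative):

```python
import numpy as np

def levels(img, in_black=0.0, in_white=255.0, gamma=1.0, out_black=0.0, out_white=255.0):
    """img: float array in [0,1]; sliders in 0..255, gamma in 0.01..9.99."""
    ib, iw = in_black / 255.0, in_white / 255.0
    ob, ow = out_black / 255.0, out_white / 255.0
    t = np.clip((img - ib) / max(iw - ib, 1e-4), 0.0, 1.0)  # input range remap
    t = np.power(t, 1.0 / gamma)                            # gamma
    return ob + (ow - ob) * t                               # mix(outBlack, outWhite, t)

print(levels(np.array([0.0, 0.5, 1.0]), in_black=32, gamma=1.8))
```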
blueprints/.glsl/README.md (new file, 28 lines)

# GLSL Shader Sources

This folder contains the GLSL fragment shaders extracted from blueprint JSON files for easier editing and version control.

## File Naming Convention

`{Blueprint_Name}_{node_id}.frag`

- **Blueprint_Name**: The JSON filename with spaces/special chars replaced by underscores
- **node_id**: The GLSLShader node ID within the subgraph

## Usage

```bash
# Extract shaders from blueprint JSONs to this folder
python update_blueprints.py extract

# Patch edited shaders back into blueprint JSONs
python update_blueprints.py patch
```

## Workflow

1. Run `extract` to pull current shaders from JSONs
2. Edit `.frag` files
3. Run `patch` to update the blueprint JSONs
4. Test
5. Commit both `.frag` files and updated JSONs
blueprints/.glsl/Sharpen_23.frag (new file, 28 lines)

```glsl
#version 300 es
precision highp float;

uniform sampler2D u_image0;
uniform vec2 u_resolution;
uniform float u_float0; // strength [0.0 – 2.0] typical: 0.3–1.0

in vec2 v_texCoord;
layout(location = 0) out vec4 fragColor0;

void main() {
    vec2 texel = 1.0 / u_resolution;

    // Sample center and neighbors
    vec4 center = texture(u_image0, v_texCoord);
    vec4 top    = texture(u_image0, v_texCoord + vec2( 0.0, -texel.y));
    vec4 bottom = texture(u_image0, v_texCoord + vec2( 0.0,  texel.y));
    vec4 left   = texture(u_image0, v_texCoord + vec2(-texel.x, 0.0));
    vec4 right  = texture(u_image0, v_texCoord + vec2( texel.x, 0.0));

    // Edge enhancement (Laplacian)
    vec4 edges = center * 4.0 - top - bottom - left - right;

    // Add edges back scaled by strength
    vec4 sharpened = center + edges * u_float0;

    fragColor0 = vec4(clamp(sharpened.rgb, 0.0, 1.0), center.a);
}
```
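The 4-neighbor Laplacian used here is equivalent to a single 3x3 convolution. A SciPy sketch of the same sharpen on one channel (the kernel is exactly "center*4 - top - bottom - left - right"; edge handling via `mode="nearest"` roughly matches clamp-to-edge sampling and is an assumption of this sketch):

```python
import numpy as np
from scipy.ndimage import convolve

LAPLACIAN = np.array([[ 0, -1,  0],
                      [-1,  4, -1],
                      [ 0, -1,  0]], dtype=np.float32)

def sharpen(channel: np.ndarray, strength: float) -> np.ndarray:
    edges = convolve(channel, LAPLACIAN, mode="nearest")  # edge map
    return np.clip(channel + edges * strength, 0.0, 1.0)

out = sharpen(np.random.rand(16, 16).astype(np.float32), strength=0.6)
```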
blueprints/.glsl/Unsharp_Mask_26.frag (new file, 61 lines)

```glsl
#version 300 es
precision highp float;

uniform sampler2D u_image0;
uniform vec2 u_resolution;
uniform float u_float0; // amount    [0.0 - 3.0] typical: 0.5-1.5
uniform float u_float1; // radius    [0.5 - 10.0] blur radius in pixels
uniform float u_float2; // threshold [0.0 - 0.1] min difference to sharpen

in vec2 v_texCoord;
layout(location = 0) out vec4 fragColor0;

float gaussian(float x, float sigma) {
    return exp(-(x * x) / (2.0 * sigma * sigma));
}

float getLuminance(vec3 color) {
    return dot(color, vec3(0.2126, 0.7152, 0.0722));
}

void main() {
    vec2 texel = 1.0 / u_resolution;
    float radius = max(u_float1, 0.5);
    float amount = u_float0;
    float threshold = u_float2;

    vec4 original = texture(u_image0, v_texCoord);

    // Gaussian blur for the "unsharp" mask
    int samples = int(ceil(radius));
    float sigma = radius / 2.0;

    vec4 blurred = vec4(0.0);
    float totalWeight = 0.0;

    for (int x = -samples; x <= samples; x++) {
        for (int y = -samples; y <= samples; y++) {
            vec2 offset = vec2(float(x), float(y)) * texel;
            vec4 sample_color = texture(u_image0, v_texCoord + offset);

            float dist = length(vec2(float(x), float(y)));
            float weight = gaussian(dist, sigma);
            blurred += sample_color * weight;
            totalWeight += weight;
        }
    }
    blurred /= totalWeight;

    // Unsharp mask = original - blurred
    vec3 mask = original.rgb - blurred.rgb;

    // Luminance-based threshold with smooth falloff
    float lumaDelta = abs(getLuminance(original.rgb) - getLuminance(blurred.rgb));
    float thresholdScale = smoothstep(0.0, threshold, lumaDelta);
    mask *= thresholdScale;

    // Sharpen: original + mask * amount
    vec3 sharpened = original.rgb + mask * amount;

    fragColor0 = vec4(clamp(sharpened, 0.0, 1.0), original.a);
}
```
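A CPU sketch of the same unsharp mask: blur, subtract to get the mask, gate it by a luminance-difference smoothstep, then add back scaled by amount. It swaps the shader's explicit sample loop for SciPy's Gaussian (with sigma = radius/2, as in the shader); everything else here is an illustration, not the PR's code:

```python
import numpy as np
from scipy.ndimage import gaussian_filter

LUMA = np.array([0.2126, 0.7152, 0.0722])

def unsharp(img, amount=1.0, radius=3.0, threshold=0.02):
    """img: HxWx3 float in [0,1]."""
    blurred = gaussian_filter(img, sigma=(radius / 2.0, radius / 2.0, 0.0))
    mask = img - blurred
    luma_delta = np.abs((img - blurred) @ LUMA)  # luminance is linear, so this
                                                 # equals |lum(orig) - lum(blur)|
    # smoothstep(0, threshold, luma_delta)
    t = np.clip(luma_delta / max(threshold, 1e-6), 0.0, 1.0)
    gate = t * t * (3.0 - 2.0 * t)
    return np.clip(img + mask * gate[..., None] * amount, 0.0, 1.0)

out = unsharp(np.random.rand(16, 16, 3), amount=1.2)
```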
159
blueprints/.glsl/update_blueprints.py
Normal file
159
blueprints/.glsl/update_blueprints.py
Normal file
@@ -0,0 +1,159 @@
#!/usr/bin/env python3
"""
Shader Blueprint Updater

Syncs GLSL shader files between this folder and blueprint JSON files.

File naming convention:
    {Blueprint Name}_{node_id}.frag

Usage:
    python update_blueprints.py extract  # Extract shaders from JSONs to here
    python update_blueprints.py patch    # Patch shaders back into JSONs
    python update_blueprints.py          # Same as patch (default)
"""

import json
import logging
import sys
import re
from pathlib import Path

logging.basicConfig(level=logging.INFO, format='%(message)s')
logger = logging.getLogger(__name__)

GLSL_DIR = Path(__file__).parent
BLUEPRINTS_DIR = GLSL_DIR.parent


def get_blueprint_files():
    """Get all blueprint JSON files."""
    return sorted(BLUEPRINTS_DIR.glob("*.json"))


def sanitize_filename(name):
    """Convert blueprint name to safe filename."""
    return re.sub(r'[^\w\-]', '_', name)


def extract_shaders():
    """Extract all shaders from blueprint JSONs to this folder."""
    extracted = 0
    for json_path in get_blueprint_files():
        blueprint_name = json_path.stem

        try:
            with open(json_path, 'r') as f:
                data = json.load(f)
        except (json.JSONDecodeError, IOError) as e:
            logger.warning("Skipping %s: %s", json_path.name, e)
            continue

        # Find GLSLShader nodes in subgraphs
        for subgraph in data.get('definitions', {}).get('subgraphs', []):
            for node in subgraph.get('nodes', []):
                if node.get('type') == 'GLSLShader':
                    node_id = node.get('id')
                    widgets = node.get('widgets_values', [])

                    # Find shader code (first string that looks like GLSL)
                    for widget in widgets:
                        if isinstance(widget, str) and widget.startswith('#version'):
                            safe_name = sanitize_filename(blueprint_name)
                            frag_name = f"{safe_name}_{node_id}.frag"
                            frag_path = GLSL_DIR / frag_name

                            with open(frag_path, 'w') as f:
                                f.write(widget)

                            logger.info("  Extracted: %s", frag_name)
                            extracted += 1
                            break

    logger.info("\nExtracted %d shader(s)", extracted)


def patch_shaders():
    """Patch shaders from this folder back into blueprint JSONs."""
    # Build lookup: blueprint_name -> [(node_id, shader_code), ...]
    shader_updates = {}

    for frag_path in sorted(GLSL_DIR.glob("*.frag")):
        # Parse filename: {blueprint_name}_{node_id}.frag
        parts = frag_path.stem.rsplit('_', 1)
        if len(parts) != 2:
            logger.warning("Skipping %s: invalid filename format", frag_path.name)
            continue

        blueprint_name, node_id_str = parts

        try:
            node_id = int(node_id_str)
        except ValueError:
            logger.warning("Skipping %s: invalid node_id", frag_path.name)
            continue

        with open(frag_path, 'r') as f:
            shader_code = f.read()

        if blueprint_name not in shader_updates:
            shader_updates[blueprint_name] = []
        shader_updates[blueprint_name].append((node_id, shader_code))

    # Apply updates to JSON files
    patched = 0
    for json_path in get_blueprint_files():
        blueprint_name = sanitize_filename(json_path.stem)

        if blueprint_name not in shader_updates:
            continue

        try:
            with open(json_path, 'r') as f:
                data = json.load(f)
        except (json.JSONDecodeError, IOError) as e:
            logger.error("Error reading %s: %s", json_path.name, e)
            continue

        modified = False
        for node_id, shader_code in shader_updates[blueprint_name]:
            # Find the node and update
            for subgraph in data.get('definitions', {}).get('subgraphs', []):
                for node in subgraph.get('nodes', []):
                    if node.get('id') == node_id and node.get('type') == 'GLSLShader':
                        widgets = node.get('widgets_values', [])
                        if len(widgets) > 0 and widgets[0] != shader_code:
                            widgets[0] = shader_code
                            modified = True
                            logger.info("  Patched: %s (node %d)", json_path.name, node_id)
                            patched += 1

        if modified:
            with open(json_path, 'w') as f:
                json.dump(data, f)

    if patched == 0:
        logger.info("No changes to apply.")
    else:
        logger.info("\nPatched %d shader(s)", patched)


def main():
    if len(sys.argv) < 2:
        command = "patch"
    else:
        command = sys.argv[1].lower()

    if command == "extract":
        logger.info("Extracting shaders from blueprints...")
        extract_shaders()
    elif command in ("patch", "update", "apply"):
        logger.info("Patching shaders into blueprints...")
        patch_shaders()
    else:
        logger.info(__doc__)
        sys.exit(1)


if __name__ == "__main__":
    main()

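The naming convention above is load-bearing for the round trip: extract writes {sanitized blueprint name}_{node id}.frag, and patch recovers both parts with rsplit. A quick standalone check (a sketch, not part of the script itself):

import re

def sanitize_filename(name):
    return re.sub(r'[^\w\-]', '_', name)

stem = f"{sanitize_filename('Brightness and Contrast')}_23"
assert stem == "Brightness_and_Contrast_23"
blueprint_name, node_id_str = stem.rsplit('_', 1)
assert (blueprint_name, int(node_id_str)) == ("Brightness_and_Contrast", 23)
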
1 blueprints/Brightness and Contrast.json Normal file
File diff suppressed because one or more lines are too long
1 blueprints/Canny to Image (Z-Image-Turbo).json Normal file
File diff suppressed because one or more lines are too long
1 blueprints/Canny to Video (LTX 2.0).json Normal file
File diff suppressed because one or more lines are too long
1 blueprints/Chromatic Aberration.json Normal file
File diff suppressed because one or more lines are too long
1 blueprints/Color Adjustment.json Normal file
File diff suppressed because one or more lines are too long
1 blueprints/Depth to Image (Z-Image-Turbo).json Normal file
File diff suppressed because one or more lines are too long
1 blueprints/Depth to Video (ltx 2.0).json Normal file
File diff suppressed because one or more lines are too long
1 blueprints/Edge-Preserving Blur.json Normal file
File diff suppressed because one or more lines are too long
1 blueprints/Film Grain.json Normal file
File diff suppressed because one or more lines are too long
1 blueprints/Glow.json Normal file
File diff suppressed because one or more lines are too long
1 blueprints/Hue and Saturation.json Normal file
File diff suppressed because one or more lines are too long
1 blueprints/Image Blur.json Normal file
File diff suppressed because one or more lines are too long
1 blueprints/Image Captioning (gemini).json Normal file
File diff suppressed because one or more lines are too long
1 blueprints/Image Channels.json Normal file
@@ -0,0 +1 @@
{"revision": 0, "last_node_id": 29, "last_link_id": 0, "nodes": [{"id": 29, "type": "4c9d6ea4-b912-40e5-8766-6793a9758c53", "pos": [1970, -230], "size": [180, 86], "flags": {}, "order": 5, "mode": 0, "inputs": [{"label": "image", "localized_name": "images.image0", "name": "images.image0", "type": "IMAGE", "link": null}], "outputs": [{"label": "R", "localized_name": "IMAGE0", "name": "IMAGE0", "type": "IMAGE", "links": []}, {"label": "G", "localized_name": "IMAGE1", "name": "IMAGE1", "type": "IMAGE", "links": []}, {"label": "B", "localized_name": "IMAGE2", "name": "IMAGE2", "type": "IMAGE", "links": []}, {"label": "A", "localized_name": "IMAGE3", "name": "IMAGE3", "type": "IMAGE", "links": []}], "title": "Image Channels", "properties": {"proxyWidgets": []}, "widgets_values": []}], "links": [], "version": 0.4, "definitions": {"subgraphs": [{"id": "4c9d6ea4-b912-40e5-8766-6793a9758c53", "version": 1, "state": {"lastGroupId": 0, "lastNodeId": 28, "lastLinkId": 39, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "Image Channels", "inputNode": {"id": -10, "bounding": [1820, -185, 120, 60]}, "outputNode": {"id": -20, "bounding": [2460, -215, 120, 120]}, "inputs": [{"id": "3522932b-2d86-4a1f-a02a-cb29f3a9d7fe", "name": "images.image0", "type": "IMAGE", "linkIds": [39], "localized_name": "images.image0", "label": "image", "pos": [1920, -165]}], "outputs": [{"id": "605cb9c3-b065-4d9b-81d2-3ec331889b2b", "name": "IMAGE0", "type": "IMAGE", "linkIds": [26], "localized_name": "IMAGE0", "label": "R", "pos": [2480, -195]}, {"id": "fb44a77e-0522-43e9-9527-82e7465b3596", "name": "IMAGE1", "type": "IMAGE", "linkIds": [27], "localized_name": "IMAGE1", "label": "G", "pos": [2480, -175]}, {"id": "81460ee6-0131-402a-874f-6bf3001fc4ff", "name": "IMAGE2", "type": "IMAGE", "linkIds": [28], "localized_name": "IMAGE2", "label": "B", "pos": [2480, -155]}, {"id": "ae690246-80d4-4951-b1d9-9306d8a77417", "name": "IMAGE3", "type": "IMAGE", "linkIds": [29], "localized_name": "IMAGE3", "label": "A", "pos": [2480, -135]}], "widgets": [], "nodes": [{"id": 23, "type": "GLSLShader", "pos": [2000, -330], "size": [400, 172], "flags": {}, "order": 0, "mode": 0, "inputs": [{"label": "image", "localized_name": "images.image0", "name": "images.image0", "type": "IMAGE", "link": 39}, {"localized_name": "fragment_shader", "name": "fragment_shader", "type": "STRING", "widget": {"name": "fragment_shader"}, "link": null}, {"localized_name": "size_mode", "name": "size_mode", "type": "COMFY_DYNAMICCOMBO_V3", "widget": {"name": "size_mode"}, "link": null}, {"label": "image1", "localized_name": "images.image1", "name": "images.image1", "shape": 7, "type": "IMAGE", "link": null}], "outputs": [{"label": "R", "localized_name": "IMAGE0", "name": "IMAGE0", "type": "IMAGE", "links": [26]}, {"label": "G", "localized_name": "IMAGE1", "name": "IMAGE1", "type": "IMAGE", "links": [27]}, {"label": "B", "localized_name": "IMAGE2", "name": "IMAGE2", "type": "IMAGE", "links": [28]}, {"label": "A", "localized_name": "IMAGE3", "name": "IMAGE3", "type": "IMAGE", "links": [29]}], "properties": {"Node name for S&R": "GLSLShader"}, "widgets_values": ["#version 300 es\nprecision highp float;\n\nuniform sampler2D u_image0;\n\nin vec2 v_texCoord;\nlayout(location = 0) out vec4 fragColor0;\nlayout(location = 1) out vec4 fragColor1;\nlayout(location = 2) out vec4 fragColor2;\nlayout(location = 3) out vec4 fragColor3;\n\nvoid main() {\n vec4 color = texture(u_image0, v_texCoord);\n // Output each channel as grayscale to separate render targets\n fragColor0 = 
vec4(vec3(color.r), 1.0); // Red channel\n fragColor1 = vec4(vec3(color.g), 1.0); // Green channel\n fragColor2 = vec4(vec3(color.b), 1.0); // Blue channel\n fragColor3 = vec4(vec3(color.a), 1.0); // Alpha channel\n}\n", "from_input"]}], "groups": [], "links": [{"id": 39, "origin_id": -10, "origin_slot": 0, "target_id": 23, "target_slot": 0, "type": "IMAGE"}, {"id": 26, "origin_id": 23, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "IMAGE"}, {"id": 27, "origin_id": 23, "origin_slot": 1, "target_id": -20, "target_slot": 1, "type": "IMAGE"}, {"id": 28, "origin_id": 23, "origin_slot": 2, "target_id": -20, "target_slot": 2, "type": "IMAGE"}, {"id": 29, "origin_id": 23, "origin_slot": 3, "target_id": -20, "target_slot": 3, "type": "IMAGE"}], "extra": {"workflowRendererVersion": "LG"}, "category": "Image Tools/Color adjust"}]}}
|
||||
1 blueprints/Image Edit (Flux.2 Klein 4B).json Normal file
File diff suppressed because one or more lines are too long
1 blueprints/Image Edit (Qwen 2511).json Normal file
File diff suppressed because one or more lines are too long
1 blueprints/Image Inpainting (Qwen-image).json Normal file
File diff suppressed because one or more lines are too long
1 blueprints/Image Levels.json Normal file
File diff suppressed because one or more lines are too long
1 blueprints/Image Outpainting (Qwen-Image).json Normal file
File diff suppressed because one or more lines are too long
1 blueprints/Image Upscale(Z-image-Turbo).json Normal file
File diff suppressed because one or more lines are too long
1 blueprints/Image to Depth Map (Lotus).json Normal file
File diff suppressed because one or more lines are too long
1 blueprints/Image to Layers(Qwen-Image Layered).json Normal file
File diff suppressed because one or more lines are too long
1 blueprints/Image to Model (Hunyuan3d 2.1).json Normal file
File diff suppressed because one or more lines are too long
1 blueprints/Image to Video (Wan 2.2).json Normal file
File diff suppressed because one or more lines are too long
1 blueprints/Pose to Image (Z-Image-Turbo).json Normal file
File diff suppressed because one or more lines are too long
1 blueprints/Pose to Video (LTX 2.0).json Normal file
File diff suppressed because one or more lines are too long
1 blueprints/Prompt Enhance.json Normal file
@@ -0,0 +1 @@
{"revision": 0, "last_node_id": 15, "last_link_id": 0, "nodes": [{"id": 15, "type": "24d8bbfd-39d4-4774-bff0-3de40cc7a471", "pos": [-1490, 2040], "size": [400, 260], "flags": {}, "order": 0, "mode": 0, "inputs": [{"name": "prompt", "type": "STRING", "widget": {"name": "prompt"}, "link": null}, {"label": "reference images", "name": "images", "type": "IMAGE", "link": null}], "outputs": [{"name": "STRING", "type": "STRING", "links": null}], "title": "Prompt Enhance", "properties": {"proxyWidgets": [["-1", "prompt"]], "cnr_id": "comfy-core", "ver": "0.14.1"}, "widgets_values": [""]}], "links": [], "version": 0.4, "definitions": {"subgraphs": [{"id": "24d8bbfd-39d4-4774-bff0-3de40cc7a471", "version": 1, "state": {"lastGroupId": 0, "lastNodeId": 15, "lastLinkId": 14, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "Prompt Enhance", "inputNode": {"id": -10, "bounding": [-2170, 2110, 138.876953125, 80]}, "outputNode": {"id": -20, "bounding": [-640, 2110, 120, 60]}, "inputs": [{"id": "aeab7216-00e0-4528-a09b-bba50845c5a6", "name": "prompt", "type": "STRING", "linkIds": [11], "pos": [-2051.123046875, 2130]}, {"id": "7b73fd36-aa31-4771-9066-f6c83879994b", "name": "images", "type": "IMAGE", "linkIds": [14], "label": "reference images", "pos": [-2051.123046875, 2150]}], "outputs": [{"id": "c7b0d930-68a1-48d1-b496-0519e5837064", "name": "STRING", "type": "STRING", "linkIds": [13], "pos": [-620, 2130]}], "widgets": [], "nodes": [{"id": 11, "type": "GeminiNode", "pos": [-1560, 1990], "size": [470, 470], "flags": {}, "order": 0, "mode": 0, "inputs": [{"localized_name": "images", "name": "images", "shape": 7, "type": "IMAGE", "link": 14}, {"localized_name": "audio", "name": "audio", "shape": 7, "type": "AUDIO", "link": null}, {"localized_name": "video", "name": "video", "shape": 7, "type": "VIDEO", "link": null}, {"localized_name": "files", "name": "files", "shape": 7, "type": "GEMINI_INPUT_FILES", "link": null}, {"localized_name": "prompt", "name": "prompt", "type": "STRING", "widget": {"name": "prompt"}, "link": 11}, {"localized_name": "model", "name": "model", "type": "COMBO", "widget": {"name": "model"}, "link": null}, {"localized_name": "seed", "name": "seed", "type": "INT", "widget": {"name": "seed"}, "link": null}, {"localized_name": "system_prompt", "name": "system_prompt", "shape": 7, "type": "STRING", "widget": {"name": "system_prompt"}, "link": null}], "outputs": [{"localized_name": "STRING", "name": "STRING", "type": "STRING", "links": [13]}], "properties": {"cnr_id": "comfy-core", "ver": "0.14.1", "Node name for S&R": "GeminiNode"}, "widgets_values": ["", "gemini-3-pro-preview", 42, "randomize", "You are an expert in prompt writing.\nBased on the input, rewrite the user's input into a detailed prompt.\nincluding camera settings, lighting, composition, and style.\nReturn the prompt only"], "color": "#432", "bgcolor": "#653"}], "groups": [], "links": [{"id": 11, "origin_id": -10, "origin_slot": 0, "target_id": 11, "target_slot": 4, "type": "STRING"}, {"id": 13, "origin_id": 11, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "STRING"}, {"id": 14, "origin_id": -10, "origin_slot": 1, "target_id": 11, "target_slot": 0, "type": "IMAGE"}], "extra": {"workflowRendererVersion": "LG"}, "category": "Text generation/Prompt enhance"}]}, "extra": {}}
1 blueprints/Sharpen.json Normal file
@@ -0,0 +1 @@
{"revision": 0, "last_node_id": 25, "last_link_id": 0, "nodes": [{"id": 25, "type": "621ba4e2-22a8-482d-a369-023753198b7b", "pos": [4610, -790], "size": [230, 58], "flags": {}, "order": 4, "mode": 0, "inputs": [{"label": "image", "localized_name": "images.image0", "name": "images.image0", "type": "IMAGE", "link": null}], "outputs": [{"label": "IMAGE", "localized_name": "IMAGE0", "name": "IMAGE0", "type": "IMAGE", "links": []}], "title": "Sharpen", "properties": {"proxyWidgets": [["24", "value"]]}, "widgets_values": []}], "links": [], "version": 0.4, "definitions": {"subgraphs": [{"id": "621ba4e2-22a8-482d-a369-023753198b7b", "version": 1, "state": {"lastGroupId": 0, "lastNodeId": 24, "lastLinkId": 36, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "Sharpen", "inputNode": {"id": -10, "bounding": [4090, -825, 120, 60]}, "outputNode": {"id": -20, "bounding": [5150, -825, 120, 60]}, "inputs": [{"id": "37011fb7-14b7-4e0e-b1a0-6a02e8da1fd7", "name": "images.image0", "type": "IMAGE", "linkIds": [34], "localized_name": "images.image0", "label": "image", "pos": [4190, -805]}], "outputs": [{"id": "e9182b3f-635c-4cd4-a152-4b4be17ae4b9", "name": "IMAGE0", "type": "IMAGE", "linkIds": [35], "localized_name": "IMAGE0", "label": "IMAGE", "pos": [5170, -805]}], "widgets": [], "nodes": [{"id": 24, "type": "PrimitiveFloat", "pos": [4280, -1240], "size": [270, 58], "flags": {}, "order": 0, "mode": 0, "inputs": [{"label": "strength", "localized_name": "value", "name": "value", "type": "FLOAT", "widget": {"name": "value"}, "link": null}], "outputs": [{"localized_name": "FLOAT", "name": "FLOAT", "type": "FLOAT", "links": [36]}], "properties": {"Node name for S&R": "PrimitiveFloat", "min": 0, "max": 3, "precision": 2, "step": 0.05}, "widgets_values": [0.5]}, {"id": 23, "type": "GLSLShader", "pos": [4570, -1240], "size": [370, 192], "flags": {}, "order": 1, "mode": 0, "inputs": [{"label": "image0", "localized_name": "images.image0", "name": "images.image0", "type": "IMAGE", "link": 34}, {"label": "image1", "localized_name": "images.image1", "name": "images.image1", "shape": 7, "type": "IMAGE", "link": null}, {"label": "u_float0", "localized_name": "floats.u_float0", "name": "floats.u_float0", "shape": 7, "type": "FLOAT", "link": 36}, {"label": "u_float1", "localized_name": "floats.u_float1", "name": "floats.u_float1", "shape": 7, "type": "FLOAT", "link": null}, {"label": "u_int0", "localized_name": "ints.u_int0", "name": "ints.u_int0", "shape": 7, "type": "INT", "link": null}, {"localized_name": "fragment_shader", "name": "fragment_shader", "type": "STRING", "widget": {"name": "fragment_shader"}, "link": null}, {"localized_name": "size_mode", "name": "size_mode", "type": "COMFY_DYNAMICCOMBO_V3", "widget": {"name": "size_mode"}, "link": null}], "outputs": [{"localized_name": "IMAGE0", "name": "IMAGE0", "type": "IMAGE", "links": [35]}, {"localized_name": "IMAGE1", "name": "IMAGE1", "type": "IMAGE", "links": null}, {"localized_name": "IMAGE2", "name": "IMAGE2", "type": "IMAGE", "links": null}, {"localized_name": "IMAGE3", "name": "IMAGE3", "type": "IMAGE", "links": null}], "properties": {"Node name for S&R": "GLSLShader"}, "widgets_values": ["#version 300 es\nprecision highp float;\n\nuniform sampler2D u_image0;\nuniform vec2 u_resolution;\nuniform float u_float0; // strength [0.0 – 2.0] typical: 0.3–1.0\n\nin vec2 v_texCoord;\nlayout(location = 0) out vec4 fragColor0;\n\nvoid main() {\n vec2 texel = 1.0 / u_resolution;\n \n // Sample center and neighbors\n vec4 center = texture(u_image0, v_texCoord);\n 
vec4 top = texture(u_image0, v_texCoord + vec2( 0.0, -texel.y));\n vec4 bottom = texture(u_image0, v_texCoord + vec2( 0.0, texel.y));\n vec4 left = texture(u_image0, v_texCoord + vec2(-texel.x, 0.0));\n vec4 right = texture(u_image0, v_texCoord + vec2( texel.x, 0.0));\n \n // Edge enhancement (Laplacian)\n vec4 edges = center * 4.0 - top - bottom - left - right;\n \n // Add edges back scaled by strength\n vec4 sharpened = center + edges * u_float0;\n \n fragColor0 = vec4(clamp(sharpened.rgb, 0.0, 1.0), center.a);\n}", "from_input"]}], "groups": [], "links": [{"id": 36, "origin_id": 24, "origin_slot": 0, "target_id": 23, "target_slot": 2, "type": "FLOAT"}, {"id": 34, "origin_id": -10, "origin_slot": 0, "target_id": 23, "target_slot": 0, "type": "IMAGE"}, {"id": 35, "origin_id": 23, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "IMAGE"}], "extra": {"workflowRendererVersion": "LG"}, "category": "Image Tools/Sharpen"}]}}
|
||||
1 blueprints/Text to Audio (ACE-Step 1.5).json Normal file
File diff suppressed because one or more lines are too long
1 blueprints/Text to Image (Z-Image-Turbo).json Normal file
File diff suppressed because one or more lines are too long
1 blueprints/Text to Video (Wan 2.2).json Normal file
File diff suppressed because one or more lines are too long
1 blueprints/Unsharp Mask.json Normal file
File diff suppressed because one or more lines are too long
1 blueprints/Video Captioning (Gemini).json Normal file
File diff suppressed because one or more lines are too long
1 blueprints/Video Inpaint(Wan2.1 VACE).json Normal file
File diff suppressed because one or more lines are too long
1 blueprints/Video Stitch.json Normal file
File diff suppressed because one or more lines are too long
1 blueprints/Video Upscale(GAN x4).json Normal file
@@ -0,0 +1 @@
{"revision": 0, "last_node_id": 13, "last_link_id": 0, "nodes": [{"id": 13, "type": "cf95b747-3e17-46cb-8097-cac60ff9b2e1", "pos": [1120, 330], "size": [240, 58], "flags": {}, "order": 3, "mode": 0, "inputs": [{"localized_name": "video", "name": "video", "type": "VIDEO", "link": null}, {"name": "model_name", "type": "COMBO", "widget": {"name": "model_name"}, "link": null}], "outputs": [{"localized_name": "VIDEO", "name": "VIDEO", "type": "VIDEO", "links": []}], "title": "Video Upscale(GAN x4)", "properties": {"proxyWidgets": [["-1", "model_name"]], "cnr_id": "comfy-core", "ver": "0.14.1"}, "widgets_values": ["RealESRGAN_x4plus.safetensors"]}], "links": [], "version": 0.4, "definitions": {"subgraphs": [{"id": "cf95b747-3e17-46cb-8097-cac60ff9b2e1", "version": 1, "state": {"lastGroupId": 0, "lastNodeId": 13, "lastLinkId": 19, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "Video Upscale(GAN x4)", "inputNode": {"id": -10, "bounding": [550, 460, 120, 80]}, "outputNode": {"id": -20, "bounding": [1490, 460, 120, 60]}, "inputs": [{"id": "666d633e-93e7-42dc-8d11-2b7b99b0f2a6", "name": "video", "type": "VIDEO", "linkIds": [10], "localized_name": "video", "pos": [650, 480]}, {"id": "2e23a087-caa8-4d65-99e6-662761aa905a", "name": "model_name", "type": "COMBO", "linkIds": [19], "pos": [650, 500]}], "outputs": [{"id": "0c1768ea-3ec2-412f-9af6-8e0fa36dae70", "name": "VIDEO", "type": "VIDEO", "linkIds": [15], "localized_name": "VIDEO", "pos": [1510, 480]}], "widgets": [], "nodes": [{"id": 2, "type": "ImageUpscaleWithModel", "pos": [1110, 450], "size": [320, 46], "flags": {}, "order": 1, "mode": 0, "inputs": [{"localized_name": "upscale_model", "name": "upscale_model", "type": "UPSCALE_MODEL", "link": 1}, {"localized_name": "image", "name": "image", "type": "IMAGE", "link": 14}], "outputs": [{"localized_name": "IMAGE", "name": "IMAGE", "type": "IMAGE", "links": [13]}], "properties": {"cnr_id": "comfy-core", "ver": "0.10.0", "Node name for S&R": "ImageUpscaleWithModel"}}, {"id": 11, "type": "CreateVideo", "pos": [1110, 550], "size": [320, 78], "flags": {}, "order": 3, "mode": 0, "inputs": [{"localized_name": "images", "name": "images", "type": "IMAGE", "link": 13}, {"localized_name": "audio", "name": "audio", "shape": 7, "type": "AUDIO", "link": 16}, {"localized_name": "fps", "name": "fps", "type": "FLOAT", "widget": {"name": "fps"}, "link": 12}], "outputs": [{"localized_name": "VIDEO", "name": "VIDEO", "type": "VIDEO", "links": [15]}], "properties": {"cnr_id": "comfy-core", "ver": "0.10.0", "Node name for S&R": "CreateVideo"}, "widgets_values": [30]}, {"id": 10, "type": "GetVideoComponents", "pos": [1110, 330], "size": [320, 70], "flags": {}, "order": 2, "mode": 0, "inputs": [{"localized_name": "video", "name": "video", "type": "VIDEO", "link": 10}], "outputs": [{"localized_name": "images", "name": "images", "type": "IMAGE", "links": [14]}, {"localized_name": "audio", "name": "audio", "type": "AUDIO", "links": [16]}, {"localized_name": "fps", "name": "fps", "type": "FLOAT", "links": [12]}], "properties": {"cnr_id": "comfy-core", "ver": "0.10.0", "Node name for S&R": "GetVideoComponents"}}, {"id": 1, "type": "UpscaleModelLoader", "pos": [750, 450], "size": [280, 60], "flags": {}, "order": 0, "mode": 0, "inputs": [{"localized_name": "model_name", "name": "model_name", "type": "COMBO", "widget": {"name": "model_name"}, "link": 19}], "outputs": [{"localized_name": "UPSCALE_MODEL", "name": "UPSCALE_MODEL", "type": "UPSCALE_MODEL", "links": [1]}], "properties": {"cnr_id": "comfy-core", "ver": 
"0.10.0", "Node name for S&R": "UpscaleModelLoader", "models": [{"name": "RealESRGAN_x4plus.safetensors", "url": "https://huggingface.co/Comfy-Org/Real-ESRGAN_repackaged/resolve/main/RealESRGAN_x4plus.safetensors", "directory": "upscale_models"}]}, "widgets_values": ["RealESRGAN_x4plus.safetensors"]}], "groups": [], "links": [{"id": 1, "origin_id": 1, "origin_slot": 0, "target_id": 2, "target_slot": 0, "type": "UPSCALE_MODEL"}, {"id": 14, "origin_id": 10, "origin_slot": 0, "target_id": 2, "target_slot": 1, "type": "IMAGE"}, {"id": 13, "origin_id": 2, "origin_slot": 0, "target_id": 11, "target_slot": 0, "type": "IMAGE"}, {"id": 16, "origin_id": 10, "origin_slot": 1, "target_id": 11, "target_slot": 1, "type": "AUDIO"}, {"id": 12, "origin_id": 10, "origin_slot": 2, "target_id": 11, "target_slot": 2, "type": "FLOAT"}, {"id": 10, "origin_id": -10, "origin_slot": 0, "target_id": 10, "target_slot": 0, "type": "VIDEO"}, {"id": 15, "origin_id": 11, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "VIDEO"}, {"id": 19, "origin_id": -10, "origin_slot": 1, "target_id": 1, "target_slot": 0, "type": "COMBO"}], "extra": {"workflowRendererVersion": "LG"}, "category": "Video generation and editing/Enhance video"}]}, "extra": {}}
|
||||
@@ -1,13 +0,0 @@
import pickle

load = pickle.load

class Empty:
    pass

class Unpickler(pickle.Unpickler):
    def find_class(self, module, name):
        #TODO: safe unpickle
        if module.startswith("pytorch_lightning"):
            return Empty
        return super().find_class(module, name)

@@ -179,6 +179,8 @@ parser.add_argument("--disable-api-nodes", action="store_true", help="Disable lo

parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.")

parser.add_argument("--use-process-isolation", action="store_true", help="Enable process isolation for custom nodes with pyisolate.yaml manifests.")

parser.add_argument("--verbose", default='INFO', const='DEBUG', nargs="?", choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], help='Set the logging level')
parser.add_argument("--log-stdout", action="store_true", help="Send normal process output to stdout instead of stderr (default).")

@@ -176,6 +176,8 @@ class InputTypeOptions(TypedDict):
    """COMBO type only. Specifies the configuration for a multi-select widget.
    Available after ComfyUI frontend v1.13.4
    https://github.com/Comfy-Org/ComfyUI_frontend/pull/2987"""
    gradient_stops: NotRequired[list[list[float]]]
    """Gradient color stops for gradientslider display mode. Each stop is [offset, r, g, b] (``FLOAT``)."""


class HiddenInputTypeDict(TypedDict):

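For illustration, a hypothetical node input options dict using the new field might look like the sketch below; the "display" key name and the stop values are invented here, and only the [offset, r, g, b] layout comes from the docstring above.

example_options = {
    "display": "gradientslider",  # assumed widget/display name, per the docstring
    "gradient_stops": [
        [0.0, 0.0, 0.0, 0.0],  # offset 0.0 -> black
        [1.0, 1.0, 1.0, 1.0],  # offset 1.0 -> white
    ],
}
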
@@ -4,6 +4,25 @@ import comfy.utils
import logging


def is_equal(x, y):
    if torch.is_tensor(x) and torch.is_tensor(y):
        return torch.equal(x, y)
    elif isinstance(x, dict) and isinstance(y, dict):
        if x.keys() != y.keys():
            return False
        return all(is_equal(x[k], y[k]) for k in x)
    elif isinstance(x, (list, tuple)) and isinstance(y, (list, tuple)):
        if type(x) is not type(y) or len(x) != len(y):
            return False
        return all(is_equal(a, b) for a, b in zip(x, y))
    else:
        try:
            return x == y
        except Exception:
            logging.warning("comparison issue with COND")
            return False


class CONDRegular:
    def __init__(self, cond):
        self.cond = cond

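A short demonstration of why this helper exists (illustrative only; it assumes the is_equal definition above is in scope): comparing tensors with == returns an elementwise tensor rather than a bool, so nested conds cannot be compared with a plain !=.

import torch

a = {"scale": 1.0, "feat": torch.ones(2, 3)}
b = {"scale": 1.0, "feat": torch.ones(2, 3)}
print(is_equal(a, b))                                          # True
print(is_equal(a, {"scale": 2.0, "feat": torch.ones(2, 3)}))   # False
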
@@ -84,7 +103,7 @@ class CONDConstant(CONDRegular):
        return self._copy_with(self.cond)

    def can_concat(self, other):
-        if self.cond != other.cond:
+        if not is_equal(self.cond, other.cond):
            return False
        return True

@@ -297,6 +297,30 @@ class ControlNet(ControlBase):
        self.model_sampling_current = None
        super().cleanup()


class QwenFunControlNet(ControlNet):
    def get_control(self, x_noisy, t, cond, batched_number, transformer_options):
        # Fun checkpoints are more sensitive to high strengths in the generic
        # ControlNet merge path. Use a soft response curve so strength=1.0 stays
        # unchanged while >1 grows more gently.
        original_strength = self.strength
        self.strength = math.sqrt(max(self.strength, 0.0))
        try:
            return super().get_control(x_noisy, t, cond, batched_number, transformer_options)
        finally:
            self.strength = original_strength

    def pre_run(self, model, percent_to_timestep_function):
        super().pre_run(model, percent_to_timestep_function)
        self.set_extra_arg("base_model", model.diffusion_model)

    def copy(self):
        c = QwenFunControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
        c.control_model = self.control_model
        c.control_model_wrapped = self.control_model_wrapped
        self.copy_to(c)
        return c

class ControlLoraOps:
    class Linear(torch.nn.Module, comfy.ops.CastWeightBiasOp):
        def __init__(self, in_features: int, out_features: int, bias: bool = True,

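A numeric sanity check of the soft response curve (illustrative only): sqrt leaves strength 1.0 fixed, compresses values above 1, and lifts values below 1, which matches the comment in get_control above.

import math

for s in (0.0, 0.25, 1.0, 2.0, 4.0):
    print(s, "->", round(math.sqrt(max(s, 0.0)), 3))
# 0.0 -> 0.0, 0.25 -> 0.5, 1.0 -> 1.0, 2.0 -> 1.414, 4.0 -> 2.0
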
@@ -560,6 +584,7 @@ def load_controlnet_hunyuandit(controlnet_data, model_options={}):
def load_controlnet_flux_xlabs_mistoline(sd, mistoline=False, model_options={}):
    model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd, model_options=model_options)
    control_model = comfy.ldm.flux.controlnet.ControlNetFlux(mistoline=mistoline, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
    sd = model_config.process_unet_state_dict(sd)
    control_model = controlnet_load_state_dict(control_model, sd)
    extra_conds = ['y', 'guidance']
    control = ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)

@@ -605,6 +630,53 @@ def load_controlnet_qwen_instantx(sd, model_options={}):
    control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
    return control


def load_controlnet_qwen_fun(sd, model_options={}):
    load_device = comfy.model_management.get_torch_device()
    weight_dtype = comfy.utils.weight_dtype(sd)
    unet_dtype = model_options.get("dtype", weight_dtype)
    manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)

    operations = model_options.get("custom_operations", None)
    if operations is None:
        operations = comfy.ops.pick_operations(unet_dtype, manual_cast_dtype, disable_fast_fp8=True)

    in_features = sd["control_img_in.weight"].shape[1]
    inner_dim = sd["control_img_in.weight"].shape[0]

    block_weight = sd["control_blocks.0.attn.to_q.weight"]
    attention_head_dim = sd["control_blocks.0.attn.norm_q.weight"].shape[0]
    num_attention_heads = max(1, block_weight.shape[0] // max(1, attention_head_dim))

    model = comfy.ldm.qwen_image.controlnet.QwenImageFunControlNetModel(
        control_in_features=in_features,
        inner_dim=inner_dim,
        num_attention_heads=num_attention_heads,
        attention_head_dim=attention_head_dim,
        num_control_blocks=5,
        main_model_double=60,
        injection_layers=(0, 12, 24, 36, 48),
        operations=operations,
        device=comfy.model_management.unet_offload_device(),
        dtype=unet_dtype,
    )
    model = controlnet_load_state_dict(model, sd)

    latent_format = comfy.latent_formats.Wan21()
    control = QwenFunControlNet(
        model,
        compression_ratio=1,
        latent_format=latent_format,
        # Fun checkpoints already expect their own 33-channel context handling.
        # Enabling generic concat_mask injects an extra mask channel at apply-time
        # and breaks the intended fallback packing path.
        concat_mask=False,
        load_device=load_device,
        manual_cast_dtype=manual_cast_dtype,
        extra_conds=[],
    )
    return control

def convert_mistoline(sd):
    return comfy.utils.state_dict_prefix_replace(sd, {"single_controlnet_blocks.": "controlnet_single_blocks."})

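The head-count inference above divides the to_q output width by the per-head dimension read from norm_q. A worked example with invented shapes:

out_features, attention_head_dim = 3072, 128   # hypothetical checkpoint shapes
num_attention_heads = max(1, out_features // max(1, attention_head_dim))
print(num_attention_heads)  # 24
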
@@ -682,6 +754,8 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}):
        return load_controlnet_qwen_instantx(controlnet_data, model_options=model_options)
    elif "controlnet_x_embedder.weight" in controlnet_data:
        return load_controlnet_flux_instantx(controlnet_data, model_options=model_options)
    elif "control_blocks.0.after_proj.weight" in controlnet_data and "control_img_in.weight" in controlnet_data:
        return load_controlnet_qwen_fun(controlnet_data, model_options=model_options)

    elif "controlnet_blocks.0.linear.weight" in controlnet_data: #mistoline flux
        return load_controlnet_flux_xlabs_mistoline(convert_mistoline(controlnet_data), mistoline=True, model_options=model_options)

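The dispatch above is key sniffing: each ControlNet family is recognized by state-dict keys that only its checkpoints contain. A reduced sketch of the idea (the function name and labels are invented for illustration):

def detect_family(sd_keys):
    if "control_blocks.0.after_proj.weight" in sd_keys and "control_img_in.weight" in sd_keys:
        return "qwen_fun"
    if "controlnet_x_embedder.weight" in sd_keys:
        return "flux_instantx"
    return "unknown"

print(detect_family({"control_blocks.0.after_proj.weight", "control_img_in.weight"}))  # qwen_fun
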
@@ -14,6 +14,9 @@ if TYPE_CHECKING:
import comfy.lora
import comfy.model_management
import comfy.patcher_extension
from comfy.cli_args import args
import uuid
import os
from node_helpers import conditioning_set_values

# #######################################################################################################

@@ -61,8 +64,37 @@ class EnumHookScope(enum.Enum):
    HookedOnly = "hooked_only"


_ISOLATION_HOOKREF_MODE = args.use_process_isolation or os.environ.get("PYISOLATE_ISOLATION_ACTIVE") == "1"


class _HookRef:
-    pass
    def __init__(self):
        if _ISOLATION_HOOKREF_MODE:
            self._pyisolate_id = str(uuid.uuid4())

    def _ensure_pyisolate_id(self):
        pyisolate_id = getattr(self, "_pyisolate_id", None)
        if pyisolate_id is None:
            pyisolate_id = str(uuid.uuid4())
            self._pyisolate_id = pyisolate_id
        return pyisolate_id

    def __eq__(self, other):
        if not _ISOLATION_HOOKREF_MODE:
            return self is other
        if not isinstance(other, _HookRef):
            return False
        return self._ensure_pyisolate_id() == other._ensure_pyisolate_id()

    def __hash__(self):
        if not _ISOLATION_HOOKREF_MODE:
            return id(self)
        return hash(self._ensure_pyisolate_id())

    def __str__(self):
        if not _ISOLATION_HOOKREF_MODE:
            return super().__str__()
        return f"PYISOLATE_HOOKREF:{self._ensure_pyisolate_id()}"


def default_should_register(hook: Hook, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):

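The motivation for the uuid-based semantics, as a sketch (it assumes the _HookRef class above; copy.copy stands in for a ref that crossed the RPC boundary and came back): under identity semantics a deserialized copy never compares equal to the original, so hook lookups keyed on refs would silently miss.

import copy

ref = _HookRef()
clone = copy.copy(ref)  # shares _pyisolate_id when isolation mode is active
# Isolation mode off: ref == clone is False (identity comparison).
# Isolation mode on:  ref == clone is True and hash(ref) == hash(clone),
# so the copy still works as a dict key for the same hook.
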
@@ -168,6 +200,8 @@ class WeightHook(Hook):
                key_map = comfy.lora.model_lora_keys_clip(model.model, key_map)
            else:
                key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
            if self.weights is None:
                self.weights = {}
            weights = comfy.lora.load_lora(self.weights, key_map, log_missing=False)
        else:
            if target == EnumWeightTarget.Clip:

327 comfy/isolation/__init__.py Normal file
@@ -0,0 +1,327 @@
# pylint: disable=consider-using-from-import,cyclic-import,global-statement,global-variable-not-assigned,import-outside-toplevel,logging-fstring-interpolation
from __future__ import annotations
import asyncio
import inspect
import logging
import os
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Set, TYPE_CHECKING
import folder_paths
from .extension_loader import load_isolated_node
from .manifest_loader import find_manifest_directories
from .runtime_helpers import build_stub_class, get_class_types_for_extension
from .shm_forensics import scan_shm_forensics, start_shm_forensics

if TYPE_CHECKING:
    from pyisolate import ExtensionManager
    from .extension_wrapper import ComfyNodeExtension

LOG_PREFIX = "]["
isolated_node_timings: List[tuple[float, Path, int]] = []

PYISOLATE_VENV_ROOT = Path(folder_paths.base_path) / ".pyisolate_venvs"
PYISOLATE_VENV_ROOT.mkdir(parents=True, exist_ok=True)

logger = logging.getLogger(__name__)
_WORKFLOW_BOUNDARY_MIN_FREE_VRAM_BYTES = 2 * 1024 * 1024 * 1024


def initialize_proxies() -> None:
    from .child_hooks import is_child_process

    is_child = is_child_process()

    if is_child:
        from .child_hooks import initialize_child_process

        initialize_child_process()
    else:
        from .host_hooks import initialize_host_process

        initialize_host_process()
        start_shm_forensics()


@dataclass(frozen=True)
class IsolatedNodeSpec:
    node_name: str
    display_name: str
    stub_class: type
    module_path: Path


_ISOLATED_NODE_SPECS: List[IsolatedNodeSpec] = []
_CLAIMED_PATHS: Set[Path] = set()
_ISOLATION_SCAN_ATTEMPTED = False
_EXTENSION_MANAGERS: List["ExtensionManager"] = []
_RUNNING_EXTENSIONS: Dict[str, "ComfyNodeExtension"] = {}
_ISOLATION_BACKGROUND_TASK: Optional["asyncio.Task[List[IsolatedNodeSpec]]"] = None
_EARLY_START_TIME: Optional[float] = None


def start_isolation_loading_early(loop: "asyncio.AbstractEventLoop") -> None:
    global _ISOLATION_BACKGROUND_TASK, _EARLY_START_TIME
    if _ISOLATION_BACKGROUND_TASK is not None:
        return
    _EARLY_START_TIME = time.perf_counter()
    _ISOLATION_BACKGROUND_TASK = loop.create_task(initialize_isolation_nodes())


async def await_isolation_loading() -> List[IsolatedNodeSpec]:
    global _ISOLATION_BACKGROUND_TASK, _EARLY_START_TIME
    if _ISOLATION_BACKGROUND_TASK is not None:
        specs = await _ISOLATION_BACKGROUND_TASK
        return specs
    return await initialize_isolation_nodes()


async def initialize_isolation_nodes() -> List[IsolatedNodeSpec]:
    global _ISOLATED_NODE_SPECS, _ISOLATION_SCAN_ATTEMPTED, _CLAIMED_PATHS

    if _ISOLATED_NODE_SPECS:
        return _ISOLATED_NODE_SPECS

    if _ISOLATION_SCAN_ATTEMPTED:
        return []

    _ISOLATION_SCAN_ATTEMPTED = True
    manifest_entries = find_manifest_directories()
    _CLAIMED_PATHS = {entry[0].resolve() for entry in manifest_entries}

    if not manifest_entries:
        return []

    os.environ["PYISOLATE_ISOLATION_ACTIVE"] = "1"
    concurrency_limit = max(1, (os.cpu_count() or 4) // 2)
    semaphore = asyncio.Semaphore(concurrency_limit)

    async def load_with_semaphore(
        node_dir: Path, manifest: Path
    ) -> List[IsolatedNodeSpec]:
        async with semaphore:
            load_start = time.perf_counter()
            spec_list = await load_isolated_node(
                node_dir,
                manifest,
                logger,
                lambda name, info, extension: build_stub_class(
                    name,
                    info,
                    extension,
                    _RUNNING_EXTENSIONS,
                    logger,
                ),
                PYISOLATE_VENV_ROOT,
                _EXTENSION_MANAGERS,
            )
            spec_list = [
                IsolatedNodeSpec(
                    node_name=node_name,
                    display_name=display_name,
                    stub_class=stub_cls,
                    module_path=node_dir,
                )
                for node_name, display_name, stub_cls in spec_list
            ]
            isolated_node_timings.append(
                (time.perf_counter() - load_start, node_dir, len(spec_list))
            )
            return spec_list

    tasks = [
        load_with_semaphore(node_dir, manifest)
        for node_dir, manifest in manifest_entries
    ]
    results = await asyncio.gather(*tasks, return_exceptions=True)

    specs: List[IsolatedNodeSpec] = []
    for result in results:
        if isinstance(result, Exception):
            logger.error(
                "%s Isolated node failed during startup; continuing: %s",
                LOG_PREFIX,
                result,
            )
            continue
        specs.extend(result)

    _ISOLATED_NODE_SPECS = specs
    return list(_ISOLATED_NODE_SPECS)


def _get_class_types_for_extension(extension_name: str) -> Set[str]:
    """Get all node class types (node names) belonging to an extension."""
    extension = _RUNNING_EXTENSIONS.get(extension_name)
    if not extension:
        return set()

    ext_path = Path(extension.module_path)
    class_types = set()
    for spec in _ISOLATED_NODE_SPECS:
        if spec.module_path.resolve() == ext_path.resolve():
            class_types.add(spec.node_name)

    return class_types


async def notify_execution_graph(needed_class_types: Set[str]) -> None:
    """Evict running extensions not needed for current execution."""

    async def _stop_extension(
        ext_name: str, extension: "ComfyNodeExtension", reason: str
    ) -> None:
        logger.info("%s ISO:eject_start ext=%s reason=%s", LOG_PREFIX, ext_name, reason)
        logger.debug("%s ISO:stop_start ext=%s", LOG_PREFIX, ext_name)
        stop_result = extension.stop()
        if inspect.isawaitable(stop_result):
            await stop_result
        _RUNNING_EXTENSIONS.pop(ext_name, None)
        logger.debug("%s ISO:stop_done ext=%s", LOG_PREFIX, ext_name)
        scan_shm_forensics("ISO:stop_extension", refresh_model_context=True)

    scan_shm_forensics("ISO:notify_graph_start", refresh_model_context=True)
    logger.debug(
        "%s ISO:notify_graph_start running=%d needed=%d",
        LOG_PREFIX,
        len(_RUNNING_EXTENSIONS),
        len(needed_class_types),
    )
    for ext_name, extension in list(_RUNNING_EXTENSIONS.items()):
        ext_class_types = _get_class_types_for_extension(ext_name)

        # If NONE of this extension's nodes are in the execution graph → evict
        if not ext_class_types.intersection(needed_class_types):
            await _stop_extension(
                ext_name,
                extension,
                "isolated custom_node not in execution graph, evicting",
            )

    # Isolated child processes add steady VRAM pressure; reclaim host-side models
    # at workflow boundaries so subsequent host nodes (e.g. CLIP encode) keep headroom.
    try:
        import comfy.model_management as model_management

        device = model_management.get_torch_device()
        if getattr(device, "type", None) == "cuda":
            required = max(
                model_management.minimum_inference_memory(),
                _WORKFLOW_BOUNDARY_MIN_FREE_VRAM_BYTES,
            )
            free_before = model_management.get_free_memory(device)
            if free_before < required and _RUNNING_EXTENSIONS:
                for ext_name, extension in list(_RUNNING_EXTENSIONS.items()):
                    await _stop_extension(
                        ext_name,
                        extension,
                        f"boundary low-vram restart (free={int(free_before)} target={int(required)})",
                    )
            if model_management.get_free_memory(device) < required:
                model_management.unload_all_models()
                model_management.cleanup_models_gc()
                model_management.cleanup_models()
                if model_management.get_free_memory(device) < required:
                    model_management.free_memory(required, device, for_dynamic=False)
                model_management.soft_empty_cache()
    except Exception:
        logger.debug(
            "%s workflow-boundary host VRAM relief failed", LOG_PREFIX, exc_info=True
        )
    finally:
        scan_shm_forensics("ISO:notify_graph_done", refresh_model_context=True)
        logger.debug(
            "%s ISO:notify_graph_done running=%d", LOG_PREFIX, len(_RUNNING_EXTENSIONS)
        )


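# Toy illustration (invented data, not part of this module) of the eviction
# rule in notify_execution_graph: an extension is kept only if at least one
# of its node class types appears in the workflow about to run.
#   running = {"ext_a": {"NodeA1", "NodeA2"}, "ext_b": {"NodeB1"}}
#   needed = {"NodeA2", "KSampler"}
#   [name for name, types in running.items() if not types & needed]  -> ["ext_b"]
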
async def flush_running_extensions_transport_state() -> int:
    total_flushed = 0
    for ext_name, extension in list(_RUNNING_EXTENSIONS.items()):
        flush_fn = getattr(extension, "flush_transport_state", None)
        if not callable(flush_fn):
            continue
        try:
            flushed = await flush_fn()
            if isinstance(flushed, int):
                total_flushed += flushed
                if flushed > 0:
                    logger.debug(
                        "%s %s workflow-end flush released=%d",
                        LOG_PREFIX,
                        ext_name,
                        flushed,
                    )
        except Exception:
            logger.debug(
                "%s %s workflow-end flush failed", LOG_PREFIX, ext_name, exc_info=True
            )
    scan_shm_forensics(
        "ISO:flush_running_extensions_transport_state", refresh_model_context=True
    )
    return total_flushed


def get_claimed_paths() -> Set[Path]:
    return _CLAIMED_PATHS


def update_rpc_event_loops(loop: "asyncio.AbstractEventLoop | None" = None) -> None:
    """Update all active RPC instances with the current event loop.

    This MUST be called at the start of each workflow execution to ensure
    RPC calls are scheduled on the correct event loop. This handles the case
    where asyncio.run() creates a new event loop for each workflow.

    Args:
        loop: The event loop to use. If None, uses asyncio.get_running_loop().
    """
    if loop is None:
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            loop = asyncio.get_event_loop()

    update_count = 0

    # Update RPCs from ExtensionManagers
    for manager in _EXTENSION_MANAGERS:
        if not hasattr(manager, "extensions"):
            continue
        for name, extension in manager.extensions.items():
            if hasattr(extension, "rpc") and extension.rpc is not None:
                if hasattr(extension.rpc, "update_event_loop"):
                    extension.rpc.update_event_loop(loop)
                    update_count += 1
                    logger.debug(f"{LOG_PREFIX}Updated loop on extension '{name}'")

    # Also update RPCs from running extensions (they may have direct RPC refs)
    for name, extension in _RUNNING_EXTENSIONS.items():
        if hasattr(extension, "rpc") and extension.rpc is not None:
            if hasattr(extension.rpc, "update_event_loop"):
                extension.rpc.update_event_loop(loop)
                update_count += 1
                logger.debug(f"{LOG_PREFIX}Updated loop on running extension '{name}'")

    if update_count > 0:
        logger.debug(f"{LOG_PREFIX}Updated event loop on {update_count} RPC instances")
    else:
        logger.debug(
            f"{LOG_PREFIX}No RPC instances found to update (managers={len(_EXTENSION_MANAGERS)}, running={len(_RUNNING_EXTENSIONS)})"
        )


__all__ = [
    "LOG_PREFIX",
    "initialize_proxies",
    "initialize_isolation_nodes",
    "start_isolation_loading_early",
    "await_isolation_loading",
    "notify_execution_graph",
    "flush_running_extensions_transport_state",
    "get_claimed_paths",
    "update_rpc_event_loops",
    "IsolatedNodeSpec",
    "get_class_types_for_extension",
]

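The loader in this module bounds startup concurrency with a semaphore and isolates failures with return_exceptions=True. A self-contained sketch of that pattern (names invented, outside ComfyUI):

import asyncio

async def demo():
    semaphore = asyncio.Semaphore(2)  # cap concurrent venv setups

    async def load(name):
        async with semaphore:
            await asyncio.sleep(0.01)  # stands in for venv creation + import
            if name == "bad":
                raise RuntimeError("extension failed")
            return name

    results = await asyncio.gather(*(load(n) for n in ("a", "bad", "c")), return_exceptions=True)
    return [r for r in results if not isinstance(r, Exception)]

print(asyncio.run(demo()))  # ['a', 'c']; the failing extension is skipped, not fatal
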
505 comfy/isolation/adapter.py Normal file
@@ -0,0 +1,505 @@
# pylint: disable=import-outside-toplevel,logging-fstring-interpolation,protected-access,raise-missing-from,useless-return,wrong-import-position
from __future__ import annotations

import logging
import os
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional

from pyisolate.interfaces import IsolationAdapter, SerializerRegistryProtocol  # type: ignore[import-untyped]
from pyisolate._internal.rpc_protocol import AsyncRPC, ProxiedSingleton  # type: ignore[import-untyped]

try:
    from comfy.isolation.clip_proxy import CLIPProxy, CLIPRegistry
    from comfy.isolation.model_patcher_proxy import (
        ModelPatcherProxy,
        ModelPatcherRegistry,
    )
    from comfy.isolation.model_sampling_proxy import (
        ModelSamplingProxy,
        ModelSamplingRegistry,
    )
    from comfy.isolation.vae_proxy import VAEProxy, VAERegistry, FirstStageModelRegistry
    from comfy.isolation.proxies.folder_paths_proxy import FolderPathsProxy
    from comfy.isolation.proxies.model_management_proxy import ModelManagementProxy
    from comfy.isolation.proxies.prompt_server_impl import PromptServerService
    from comfy.isolation.proxies.utils_proxy import UtilsProxy
    from comfy.isolation.proxies.progress_proxy import ProgressProxy
except ImportError as exc:  # Fail loud if Comfy environment is incomplete
    raise ImportError(f"ComfyUI environment incomplete: {exc}")

logger = logging.getLogger(__name__)

# Force /dev/shm for shared memory (bwrap makes /tmp private)
import tempfile

if os.path.exists("/dev/shm"):
    # Only override if not already set or if default is not /dev/shm
    current_tmp = tempfile.gettempdir()
    if not current_tmp.startswith("/dev/shm"):
        logger.debug(
            f"Configuring shared memory: Changing TMPDIR from {current_tmp} to /dev/shm"
        )
        os.environ["TMPDIR"] = "/dev/shm"
        tempfile.tempdir = None  # Clear cache to force re-evaluation


class ComfyUIAdapter(IsolationAdapter):
    # ComfyUI-specific IsolationAdapter implementation

    @property
    def identifier(self) -> str:
        return "comfyui"

    def get_path_config(self, module_path: str) -> Optional[Dict[str, Any]]:
        if "ComfyUI" in module_path and "custom_nodes" in module_path:
            parts = module_path.split("ComfyUI")
            if len(parts) > 1:
                comfy_root = parts[0] + "ComfyUI"
                return {
                    "preferred_root": comfy_root,
                    "additional_paths": [
                        os.path.join(comfy_root, "custom_nodes"),
                        os.path.join(comfy_root, "comfy"),
                    ],
                }
        return None

    def setup_child_environment(self, snapshot: Dict[str, Any]) -> None:
        comfy_root = snapshot.get("preferred_root")
        if not comfy_root:
            return

        requirements_path = Path(comfy_root) / "requirements.txt"
        if requirements_path.exists():
            import re

            for line in requirements_path.read_text().splitlines():
                line = line.strip()
                if not line or line.startswith("#"):
                    continue
                pkg_name = re.split(r"[<>=!~\[]", line)[0].strip()
                if pkg_name:
                    logging.getLogger(pkg_name).setLevel(logging.ERROR)

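    # Illustration (invented path, not part of the adapter) of what
    # get_path_config returns for a custom node inside a ComfyUI checkout:
    #   get_path_config("/home/user/ComfyUI/custom_nodes/my_node/node.py")
    #   -> {"preferred_root": "/home/user/ComfyUI",
    #       "additional_paths": ["/home/user/ComfyUI/custom_nodes",
    #                            "/home/user/ComfyUI/comfy"]}
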
    def register_serializers(self, registry: SerializerRegistryProtocol) -> None:
        def serialize_model_patcher(obj: Any) -> Dict[str, Any]:
            # Child-side: must already have _instance_id (proxy)
            if os.environ.get("PYISOLATE_CHILD") == "1":
                if hasattr(obj, "_instance_id"):
                    return {"__type__": "ModelPatcherRef", "model_id": obj._instance_id}
                raise RuntimeError(
                    f"ModelPatcher in child lacks _instance_id: "
                    f"{type(obj).__module__}.{type(obj).__name__}"
                )
            # Host-side: register with registry
            if hasattr(obj, "_instance_id"):
                return {"__type__": "ModelPatcherRef", "model_id": obj._instance_id}
            model_id = ModelPatcherRegistry().register(obj)
            return {"__type__": "ModelPatcherRef", "model_id": model_id}

        def deserialize_model_patcher(data: Any) -> Any:
            """Deserialize ModelPatcher refs; pass through already-materialized objects."""
            if isinstance(data, dict):
                return ModelPatcherProxy(
                    data["model_id"], registry=None, manage_lifecycle=False
                )
            return data

        def deserialize_model_patcher_ref(data: Dict[str, Any]) -> Any:
            """Context-aware ModelPatcherRef deserializer for both host and child."""
            is_child = os.environ.get("PYISOLATE_CHILD") == "1"
            if is_child:
                return ModelPatcherProxy(
                    data["model_id"], registry=None, manage_lifecycle=False
                )
            else:
                return ModelPatcherRegistry()._get_instance(data["model_id"])

        # Register ModelPatcher type for serialization
        registry.register(
            "ModelPatcher", serialize_model_patcher, deserialize_model_patcher
        )
        # Register ModelPatcherProxy type (already a proxy, just return ref)
        registry.register(
            "ModelPatcherProxy", serialize_model_patcher, deserialize_model_patcher
        )
        # Register ModelPatcherRef for deserialization (context-aware: host or child)
        registry.register("ModelPatcherRef", None, deserialize_model_patcher_ref)

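        # Data-flow sketch (descriptive comment only): on the host,
        # serialize_model_patcher(mp) yields
        # {"__type__": "ModelPatcherRef", "model_id": <registry id>}.
        # A child deserializing that ref receives a ModelPatcherProxy bound
        # to the same id, while the host resolves the ref back to the
        # registered instance via ModelPatcherRegistry._get_instance.
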
        def serialize_clip(obj: Any) -> Dict[str, Any]:
            if hasattr(obj, "_instance_id"):
                return {"__type__": "CLIPRef", "clip_id": obj._instance_id}
            clip_id = CLIPRegistry().register(obj)
            return {"__type__": "CLIPRef", "clip_id": clip_id}

        def deserialize_clip(data: Any) -> Any:
            if isinstance(data, dict):
                return CLIPProxy(data["clip_id"], registry=None, manage_lifecycle=False)
            return data

        def deserialize_clip_ref(data: Dict[str, Any]) -> Any:
            """Context-aware CLIPRef deserializer for both host and child."""
            is_child = os.environ.get("PYISOLATE_CHILD") == "1"
            if is_child:
                return CLIPProxy(data["clip_id"], registry=None, manage_lifecycle=False)
            else:
                return CLIPRegistry()._get_instance(data["clip_id"])

        # Register CLIP type for serialization
        registry.register("CLIP", serialize_clip, deserialize_clip)
        # Register CLIPProxy type (already a proxy, just return ref)
        registry.register("CLIPProxy", serialize_clip, deserialize_clip)
        # Register CLIPRef for deserialization (context-aware: host or child)
        registry.register("CLIPRef", None, deserialize_clip_ref)

        def serialize_vae(obj: Any) -> Dict[str, Any]:
            if hasattr(obj, "_instance_id"):
                return {"__type__": "VAERef", "vae_id": obj._instance_id}
            vae_id = VAERegistry().register(obj)
            return {"__type__": "VAERef", "vae_id": vae_id}

        def deserialize_vae(data: Any) -> Any:
            if isinstance(data, dict):
                return VAEProxy(data["vae_id"])
            return data

        def deserialize_vae_ref(data: Dict[str, Any]) -> Any:
            """Context-aware VAERef deserializer for both host and child."""
            is_child = os.environ.get("PYISOLATE_CHILD") == "1"
            if is_child:
                # Child: create a proxy
                return VAEProxy(data["vae_id"])
            else:
                # Host: lookup real VAE from registry
                return VAERegistry()._get_instance(data["vae_id"])

        # Register VAE type for serialization
        registry.register("VAE", serialize_vae, deserialize_vae)
        # Register VAEProxy type (already a proxy, just return ref)
        registry.register("VAEProxy", serialize_vae, deserialize_vae)
        # Register VAERef for deserialization (context-aware: host or child)
        registry.register("VAERef", None, deserialize_vae_ref)

        # ModelSampling serialization - handles ModelSampling* types
        # copyreg removed - no pickle fallback allowed

        def serialize_model_sampling(obj: Any) -> Dict[str, Any]:
            # Child-side: must already have _instance_id (proxy)
            if os.environ.get("PYISOLATE_CHILD") == "1":
                if hasattr(obj, "_instance_id"):
                    return {"__type__": "ModelSamplingRef", "ms_id": obj._instance_id}
                raise RuntimeError(
                    f"ModelSampling in child lacks _instance_id: "
                    f"{type(obj).__module__}.{type(obj).__name__}"
                )
            # Host-side: register with ModelSamplingRegistry and return JSON-safe dict
            ms_id = ModelSamplingRegistry().register(obj)
            return {"__type__": "ModelSamplingRef", "ms_id": ms_id}

        def deserialize_model_sampling(data: Any) -> Any:
            """Deserialize ModelSampling refs; pass through already-materialized objects."""
            if isinstance(data, dict):
                return ModelSamplingProxy(data["ms_id"])
            return data

        def deserialize_model_sampling_ref(data: Dict[str, Any]) -> Any:
            """Context-aware ModelSamplingRef deserializer for both host and child."""
            is_child = os.environ.get("PYISOLATE_CHILD") == "1"
            if is_child:
                return ModelSamplingProxy(data["ms_id"])
            else:
                return ModelSamplingRegistry()._get_instance(data["ms_id"])

        # Register ModelSampling type and proxy
        registry.register(
            "ModelSamplingDiscrete",
            serialize_model_sampling,
            deserialize_model_sampling,
        )
        registry.register(
            "ModelSamplingContinuousEDM",
            serialize_model_sampling,
            deserialize_model_sampling,
        )
        registry.register(
            "ModelSamplingContinuousV",
            serialize_model_sampling,
            deserialize_model_sampling,
        )
        registry.register(
            "ModelSamplingProxy", serialize_model_sampling, deserialize_model_sampling
        )
        # Register ModelSamplingRef for deserialization (context-aware: host or child)
        registry.register("ModelSamplingRef", None, deserialize_model_sampling_ref)

        def serialize_cond(obj: Any) -> Dict[str, Any]:
            type_key = f"{type(obj).__module__}.{type(obj).__name__}"
            return {
                "__type__": type_key,
                "cond": obj.cond,
            }

        def deserialize_cond(data: Dict[str, Any]) -> Any:
            import importlib

            type_key = data["__type__"]
            module_name, class_name = type_key.rsplit(".", 1)
            module = importlib.import_module(module_name)
            cls = getattr(module, class_name)
            return cls(data["cond"])

        def _serialize_public_state(obj: Any) -> Dict[str, Any]:
            state: Dict[str, Any] = {}
            for key, value in obj.__dict__.items():
                if key.startswith("_"):
                    continue
                if callable(value):
                    continue
                state[key] = value
            return state

        def serialize_latent_format(obj: Any) -> Dict[str, Any]:
            type_key = f"{type(obj).__module__}.{type(obj).__name__}"
            return {
                "__type__": type_key,
                "state": _serialize_public_state(obj),
            }

        def deserialize_latent_format(data: Dict[str, Any]) -> Any:
            import importlib

            type_key = data["__type__"]
            module_name, class_name = type_key.rsplit(".", 1)
            module = importlib.import_module(module_name)
            cls = getattr(module, class_name)
            obj = cls()
            for key, value in data.get("state", {}).items():
                prop = getattr(type(obj), key, None)
                if isinstance(prop, property) and prop.fset is None:
                    continue
                setattr(obj, key, value)
            return obj

        import comfy.conds

        for cond_cls in vars(comfy.conds).values():
            if not isinstance(cond_cls, type):
                continue
            if not issubclass(cond_cls, comfy.conds.CONDRegular):
                continue
            type_key = f"{cond_cls.__module__}.{cond_cls.__name__}"
            registry.register(type_key, serialize_cond, deserialize_cond)
            registry.register(cond_cls.__name__, serialize_cond, deserialize_cond)

        import comfy.latent_formats

        for latent_cls in vars(comfy.latent_formats).values():
            if not isinstance(latent_cls, type):
                continue
            if not issubclass(latent_cls, comfy.latent_formats.LatentFormat):
                continue
            type_key = f"{latent_cls.__module__}.{latent_cls.__name__}"
            registry.register(
                type_key, serialize_latent_format, deserialize_latent_format
            )
            registry.register(
                latent_cls.__name__, serialize_latent_format, deserialize_latent_format
            )

        # V3 API: unwrap NodeOutput.args
        def deserialize_node_output(data: Any) -> Any:
            return getattr(data, "args", data)

        registry.register("NodeOutput", None, deserialize_node_output)

        # KSAMPLER serializer: stores sampler name instead of function object
        # sampler_function is a callable which gets filtered out by JSONSocketTransport
        def serialize_ksampler(obj: Any) -> Dict[str, Any]:
            func_name = obj.sampler_function.__name__
            # Map function name back to sampler name
            if func_name == "sample_unipc":
                sampler_name = "uni_pc"
            elif func_name == "sample_unipc_bh2":
                sampler_name = "uni_pc_bh2"
            elif func_name == "dpm_fast_function":
                sampler_name = "dpm_fast"
            elif func_name == "dpm_adaptive_function":
|
||||
sampler_name = "dpm_adaptive"
|
||||
elif func_name.startswith("sample_"):
|
||||
sampler_name = func_name[7:] # Remove "sample_" prefix
|
||||
else:
|
||||
sampler_name = func_name
|
||||
return {
|
||||
"__type__": "KSAMPLER",
|
||||
"sampler_name": sampler_name,
|
||||
"extra_options": obj.extra_options,
|
||||
"inpaint_options": obj.inpaint_options,
|
||||
}
|
||||
|
||||
def deserialize_ksampler(data: Dict[str, Any]) -> Any:
|
||||
import comfy.samplers
|
||||
|
||||
return comfy.samplers.ksampler(
|
||||
data["sampler_name"],
|
||||
data.get("extra_options", {}),
|
||||
data.get("inpaint_options", {}),
|
||||
)
|
||||
|
||||
registry.register("KSAMPLER", serialize_ksampler, deserialize_ksampler)

from comfy.isolation.model_patcher_proxy_utils import register_hooks_serializers

register_hooks_serializers(registry)

# Generic Numpy Serializer
def serialize_numpy(obj: Any) -> Any:
    import torch

    try:
        # Attempt zero-copy conversion to Tensor
        return torch.from_numpy(obj)
    except Exception:
        # Fallback for non-numeric arrays (strings, objects, mixes)
        return obj.tolist()

registry.register("ndarray", serialize_numpy, None)

def provide_rpc_services(self) -> List[type[ProxiedSingleton]]:
    return [
        PromptServerService,
        FolderPathsProxy,
        ModelManagementProxy,
        UtilsProxy,
        ProgressProxy,
        VAERegistry,
        CLIPRegistry,
        ModelPatcherRegistry,
        ModelSamplingRegistry,
        FirstStageModelRegistry,
    ]

def handle_api_registration(self, api: ProxiedSingleton, rpc: AsyncRPC) -> None:
    # Resolve the real name whether it's an instance or the Singleton class itself
    api_name = api.__name__ if isinstance(api, type) else api.__class__.__name__

    if api_name == "FolderPathsProxy":
        import folder_paths

        # Replace module-level functions with proxy methods
        # This is aggressive but necessary for transparent proxying
        # Handle both instance and class cases
        instance = api() if isinstance(api, type) else api
        for name in dir(instance):
            if not name.startswith("_"):
                setattr(folder_paths, name, getattr(instance, name))
        return

    if api_name == "ModelManagementProxy":
        import comfy.model_management

        instance = api() if isinstance(api, type) else api
        # Replace module-level functions with proxy methods
        for name in dir(instance):
            if not name.startswith("_"):
                setattr(comfy.model_management, name, getattr(instance, name))
        return

    if api_name == "UtilsProxy":
        import comfy.utils

        # Static injection of the RPC mechanism so the child can access it
        # independent of instance lifecycle.
        api.set_rpc(rpc)

        # Don't overwrite host hook (infinite recursion)
        return

    if api_name == "PromptServerProxy":
        # Defer heavy import to child context
        import server

        instance = api() if isinstance(api, type) else api
        proxy = (
            instance.instance
        )  # PromptServerProxy instance has .instance property returning self

        original_register_route = proxy.register_route

        def register_route_wrapper(
            method: str, path: str, handler: Callable[..., Any]
        ) -> None:
            callback_id = rpc.register_callback(handler)
            loop = getattr(rpc, "loop", None)
            if loop and loop.is_running():
                import asyncio

                asyncio.create_task(
                    original_register_route(
                        method, path, handler=callback_id, is_callback=True
                    )
                )
            else:
                original_register_route(
                    method, path, handler=callback_id, is_callback=True
                )
            return None

        proxy.register_route = register_route_wrapper

        class RouteTableDefProxy:
            def __init__(self, proxy_instance: Any):
                self.proxy = proxy_instance

            def get(
                self, path: str, **kwargs: Any
            ) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
                def decorator(handler: Callable[..., Any]) -> Callable[..., Any]:
                    self.proxy.register_route("GET", path, handler)
                    return handler

                return decorator

            def post(
                self, path: str, **kwargs: Any
            ) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
                def decorator(handler: Callable[..., Any]) -> Callable[..., Any]:
                    self.proxy.register_route("POST", path, handler)
                    return handler

                return decorator

            def patch(
                self, path: str, **kwargs: Any
            ) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
                def decorator(handler: Callable[..., Any]) -> Callable[..., Any]:
                    self.proxy.register_route("PATCH", path, handler)
                    return handler

                return decorator

            def put(
                self, path: str, **kwargs: Any
            ) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
                def decorator(handler: Callable[..., Any]) -> Callable[..., Any]:
                    self.proxy.register_route("PUT", path, handler)
                    return handler

                return decorator

            def delete(
                self, path: str, **kwargs: Any
            ) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
                def decorator(handler: Callable[..., Any]) -> Callable[..., Any]:
                    self.proxy.register_route("DELETE", path, handler)
                    return handler

                return decorator

        proxy.routes = RouteTableDefProxy(proxy)

        if (
            hasattr(server, "PromptServer")
            and getattr(server.PromptServer, "instance", None) != proxy
        ):
            server.PromptServer.instance = proxy
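
# With the proxy installed, child-side node code can keep using the familiar
# aiohttp-style decorator surface; a hedged sketch (the route path and handler
# are made up for illustration):
#
#     from server import PromptServer
#
#     @PromptServer.instance.routes.get("/my_extension/status")
#     async def status(request):
#         return {"ok": True}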
comfy/isolation/child_hooks.py (new file, 141 lines)
@@ -0,0 +1,141 @@
# pylint: disable=import-outside-toplevel,logging-fstring-interpolation
# Child process initialization for PyIsolate
import logging
import os
from typing import Any

logger = logging.getLogger(__name__)


def is_child_process() -> bool:
    return os.environ.get("PYISOLATE_CHILD") == "1"


def initialize_child_process() -> None:
    # Manual RPC injection
    try:
        from pyisolate._internal.rpc_protocol import get_child_rpc_instance

        rpc = get_child_rpc_instance()
        if rpc:
            _setup_prompt_server_stub(rpc)
            _setup_utils_proxy(rpc)
        else:
            logger.warning("Could not get child RPC instance for manual injection")
            _setup_prompt_server_stub()
            _setup_utils_proxy()
    except Exception as e:
        logger.error(f"Manual RPC injection failed: {e}")
        _setup_prompt_server_stub()
        _setup_utils_proxy()

    _setup_logging()


def _setup_prompt_server_stub(rpc=None) -> None:
    try:
        from .proxies.prompt_server_impl import PromptServerStub
        import sys
        import types

        # Mock server module
        if "server" not in sys.modules:
            mock_server = types.ModuleType("server")
            sys.modules["server"] = mock_server

        server = sys.modules["server"]

        if not hasattr(server, "PromptServer"):

            class MockPromptServer:
                pass

            server.PromptServer = MockPromptServer

        stub = PromptServerStub()

        if rpc:
            PromptServerStub.set_rpc(rpc)
            if hasattr(stub, "set_rpc"):
                stub.set_rpc(rpc)

        server.PromptServer.instance = stub

    except Exception as e:
        logger.error(f"Failed to setup PromptServerStub: {e}")


def _setup_utils_proxy(rpc=None) -> None:
    try:
        import comfy.utils
        import asyncio

        # Capture main loop during initialization (safe context)
        main_loop = None
        try:
            main_loop = asyncio.get_running_loop()
        except RuntimeError:
            try:
                main_loop = asyncio.get_event_loop()
            except RuntimeError:
                pass

        try:
            from .proxies.base import set_global_loop

            if main_loop:
                set_global_loop(main_loop)
        except ImportError:
            pass

        # Sync hook wrapper for progress updates
        def sync_hook_wrapper(
            value: int, total: int, preview: Any = None, node_id: Any = None
        ) -> None:
            if node_id is None:
                try:
                    from comfy_execution.utils import get_executing_context

                    ctx = get_executing_context()
                    if ctx:
                        node_id = ctx.node_id
                except Exception:
                    pass

            # Bypass blocked event loop by direct outbox injection
            if rpc:
                try:
                    # Use captured main loop if available (for threaded execution), or current loop
                    loop = main_loop
                    if loop is None:
                        loop = asyncio.get_event_loop()

                    rpc.outbox.put(
                        {
                            "kind": "call",
                            "object_id": "UtilsProxy",
                            "parent_call_id": None,  # We are root here usually
                            "calling_loop": loop,
                            "future": loop.create_future(),  # Dummy future
                            "method": "progress_bar_hook",
                            "args": (value, total, preview, node_id),
                            "kwargs": {},
                        }
                    )

                except Exception as e:
                    logging.getLogger(__name__).error(f"Manual inject failed: {e}")
            else:
                logging.getLogger(__name__).warning(
                    "No RPC instance available for progress update"
                )

        comfy.utils.PROGRESS_BAR_HOOK = sync_hook_wrapper

    except Exception as e:
        logger.error(f"Failed to setup UtilsProxy hook: {e}")


def _setup_logging() -> None:
    logging.getLogger().setLevel(logging.INFO)
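
# Once the hook is installed, ordinary progress reporting in the child is
# forwarded to the host; a minimal sketch (assumes comfy.utils.ProgressBar
# drives PROGRESS_BAR_HOOK, as it does on the host side):
#
#     pbar = comfy.utils.ProgressBar(10)
#     pbar.update(1)  # ends up as an outbox "progress_bar_hook" call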
comfy/isolation/clip_proxy.py (new file, 327 lines)
@@ -0,0 +1,327 @@
# pylint: disable=attribute-defined-outside-init,import-outside-toplevel,logging-fstring-interpolation
# CLIP Proxy implementation
from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Any, Optional

from comfy.isolation.proxies.base import (
    IS_CHILD_PROCESS,
    BaseProxy,
    BaseRegistry,
    detach_if_grad,
)

if TYPE_CHECKING:
    from comfy.isolation.model_patcher_proxy import ModelPatcherProxy


class CondStageModelRegistry(BaseRegistry[Any]):
    _type_prefix = "cond_stage_model"

    async def get_property(self, instance_id: str, name: str) -> Any:
        obj = self._get_instance(instance_id)
        return getattr(obj, name)


class CondStageModelProxy(BaseProxy[CondStageModelRegistry]):
    _registry_class = CondStageModelRegistry
    __module__ = "comfy.sd"

    def __getattr__(self, name: str) -> Any:
        try:
            return self._call_rpc("get_property", name)
        except Exception as e:
            raise AttributeError(
                f"'{self.__class__.__name__}' object has no attribute '{name}'"
            ) from e

    def __repr__(self) -> str:
        return f"<CondStageModelProxy {self._instance_id}>"


class TokenizerRegistry(BaseRegistry[Any]):
    _type_prefix = "tokenizer"

    async def get_property(self, instance_id: str, name: str) -> Any:
        obj = self._get_instance(instance_id)
        return getattr(obj, name)


class TokenizerProxy(BaseProxy[TokenizerRegistry]):
    _registry_class = TokenizerRegistry
    __module__ = "comfy.sd"

    def __getattr__(self, name: str) -> Any:
        try:
            return self._call_rpc("get_property", name)
        except Exception as e:
            raise AttributeError(
                f"'{self.__class__.__name__}' object has no attribute '{name}'"
            ) from e

    def __repr__(self) -> str:
        return f"<TokenizerProxy {self._instance_id}>"


logger = logging.getLogger(__name__)


class CLIPRegistry(BaseRegistry[Any]):
    _type_prefix = "clip"
    _allowed_setters = {
        "layer_idx",
        "tokenizer_options",
        "use_clip_schedule",
        "apply_hooks_to_conds",
    }

    async def get_ram_usage(self, instance_id: str) -> int:
        return self._get_instance(instance_id).get_ram_usage()

    async def get_patcher_id(self, instance_id: str) -> str:
        from comfy.isolation.model_patcher_proxy import ModelPatcherRegistry

        return ModelPatcherRegistry().register(self._get_instance(instance_id).patcher)

    async def get_cond_stage_model_id(self, instance_id: str) -> str:
        return CondStageModelRegistry().register(
            self._get_instance(instance_id).cond_stage_model
        )

    async def get_tokenizer_id(self, instance_id: str) -> str:
        return TokenizerRegistry().register(self._get_instance(instance_id).tokenizer)

    async def load_model(self, instance_id: str) -> None:
        self._get_instance(instance_id).load_model()

    async def clip_layer(self, instance_id: str, layer_idx: int) -> None:
        self._get_instance(instance_id).clip_layer(layer_idx)

    async def set_tokenizer_option(
        self, instance_id: str, option_name: str, value: Any
    ) -> None:
        self._get_instance(instance_id).set_tokenizer_option(option_name, value)

    async def get_property(self, instance_id: str, name: str) -> Any:
        return getattr(self._get_instance(instance_id), name)

    async def set_property(self, instance_id: str, name: str, value: Any) -> None:
        if name not in self._allowed_setters:
            raise PermissionError(f"Setting '{name}' is not allowed via RPC")
        setattr(self._get_instance(instance_id), name, value)

    async def tokenize(
        self, instance_id: str, text: str, return_word_ids: bool = False, **kwargs: Any
    ) -> Any:
        return self._get_instance(instance_id).tokenize(
            text, return_word_ids=return_word_ids, **kwargs
        )

    async def encode(self, instance_id: str, text: str) -> Any:
        return detach_if_grad(self._get_instance(instance_id).encode(text))

    async def encode_from_tokens(
        self,
        instance_id: str,
        tokens: Any,
        return_pooled: bool = False,
        return_dict: bool = False,
    ) -> Any:
        return detach_if_grad(
            self._get_instance(instance_id).encode_from_tokens(
                tokens, return_pooled=return_pooled, return_dict=return_dict
            )
        )

    async def encode_from_tokens_scheduled(
        self,
        instance_id: str,
        tokens: Any,
        unprojected: bool = False,
        add_dict: Optional[dict] = None,
        show_pbar: bool = True,
    ) -> Any:
        add_dict = add_dict or {}
        return detach_if_grad(
            self._get_instance(instance_id).encode_from_tokens_scheduled(
                tokens, unprojected=unprojected, add_dict=add_dict, show_pbar=show_pbar
            )
        )

    async def add_patches(
        self,
        instance_id: str,
        patches: Any,
        strength_patch: float = 1.0,
        strength_model: float = 1.0,
    ) -> Any:
        return self._get_instance(instance_id).add_patches(
            patches, strength_patch=strength_patch, strength_model=strength_model
        )

    async def get_key_patches(self, instance_id: str) -> Any:
        return self._get_instance(instance_id).get_key_patches()

    async def load_sd(
        self, instance_id: str, sd: dict, full_model: bool = False
    ) -> Any:
        return self._get_instance(instance_id).load_sd(sd, full_model=full_model)

    async def get_sd(self, instance_id: str) -> Any:
        return self._get_instance(instance_id).get_sd()

    async def clone(self, instance_id: str) -> str:
        return self.register(self._get_instance(instance_id).clone())


class CLIPProxy(BaseProxy[CLIPRegistry]):
    _registry_class = CLIPRegistry
    __module__ = "comfy.sd"

    def get_ram_usage(self) -> int:
        return self._call_rpc("get_ram_usage")

    @property
    def patcher(self) -> "ModelPatcherProxy":
        from comfy.isolation.model_patcher_proxy import ModelPatcherProxy

        if not hasattr(self, "_patcher_proxy"):
            patcher_id = self._call_rpc("get_patcher_id")
            self._patcher_proxy = ModelPatcherProxy(patcher_id, manage_lifecycle=False)
        return self._patcher_proxy

    @patcher.setter
    def patcher(self, value: Any) -> None:
        from comfy.isolation.model_patcher_proxy import ModelPatcherProxy

        if isinstance(value, ModelPatcherProxy):
            self._patcher_proxy = value
        else:
            logger.warning(
                f"Attempted to set CLIPProxy.patcher to non-proxy object: {value}"
            )

    @property
    def cond_stage_model(self) -> CondStageModelProxy:
        if not hasattr(self, "_cond_stage_model_proxy"):
            csm_id = self._call_rpc("get_cond_stage_model_id")
            self._cond_stage_model_proxy = CondStageModelProxy(
                csm_id, manage_lifecycle=False
            )
        return self._cond_stage_model_proxy

    @property
    def tokenizer(self) -> TokenizerProxy:
        if not hasattr(self, "_tokenizer_proxy"):
            tok_id = self._call_rpc("get_tokenizer_id")
            self._tokenizer_proxy = TokenizerProxy(tok_id, manage_lifecycle=False)
        return self._tokenizer_proxy

    def load_model(self) -> ModelPatcherProxy:
        self._call_rpc("load_model")
        return self.patcher

    @property
    def layer_idx(self) -> Optional[int]:
        return self._call_rpc("get_property", "layer_idx")

    @layer_idx.setter
    def layer_idx(self, value: Optional[int]) -> None:
        self._call_rpc("set_property", "layer_idx", value)

    @property
    def tokenizer_options(self) -> dict:
        return self._call_rpc("get_property", "tokenizer_options")

    @tokenizer_options.setter
    def tokenizer_options(self, value: dict) -> None:
        self._call_rpc("set_property", "tokenizer_options", value)

    @property
    def use_clip_schedule(self) -> bool:
        return self._call_rpc("get_property", "use_clip_schedule")

    @use_clip_schedule.setter
    def use_clip_schedule(self, value: bool) -> None:
        self._call_rpc("set_property", "use_clip_schedule", value)

    @property
    def apply_hooks_to_conds(self) -> Any:
        return self._call_rpc("get_property", "apply_hooks_to_conds")

    @apply_hooks_to_conds.setter
    def apply_hooks_to_conds(self, value: Any) -> None:
        self._call_rpc("set_property", "apply_hooks_to_conds", value)

    def clip_layer(self, layer_idx: int) -> None:
        return self._call_rpc("clip_layer", layer_idx)

    def set_tokenizer_option(self, option_name: str, value: Any) -> None:
        return self._call_rpc("set_tokenizer_option", option_name, value)

    def tokenize(self, text: str, return_word_ids: bool = False, **kwargs: Any) -> Any:
        return self._call_rpc(
            "tokenize", text, return_word_ids=return_word_ids, **kwargs
        )

    def encode(self, text: str) -> Any:
        return self._call_rpc("encode", text)

    def encode_from_tokens(
        self, tokens: Any, return_pooled: bool = False, return_dict: bool = False
    ) -> Any:
        res = self._call_rpc(
            "encode_from_tokens",
            tokens,
            return_pooled=return_pooled,
            return_dict=return_dict,
        )
        if return_pooled and isinstance(res, list) and not return_dict:
            return tuple(res)
        return res

    def encode_from_tokens_scheduled(
        self,
        tokens: Any,
        unprojected: bool = False,
        add_dict: Optional[dict] = None,
        show_pbar: bool = True,
    ) -> Any:
        add_dict = add_dict or {}
        return self._call_rpc(
            "encode_from_tokens_scheduled",
            tokens,
            unprojected=unprojected,
            add_dict=add_dict,
            show_pbar=show_pbar,
        )

    def add_patches(
        self, patches: Any, strength_patch: float = 1.0, strength_model: float = 1.0
    ) -> Any:
        return self._call_rpc(
            "add_patches",
            patches,
            strength_patch=strength_patch,
            strength_model=strength_model,
        )

    def get_key_patches(self) -> Any:
        return self._call_rpc("get_key_patches")

    def load_sd(self, sd: dict, full_model: bool = False) -> Any:
        return self._call_rpc("load_sd", sd, full_model=full_model)

    def get_sd(self) -> Any:
        return self._call_rpc("get_sd")

    def clone(self) -> CLIPProxy:
        new_id = self._call_rpc("clone")
        return CLIPProxy(new_id, self._registry, manage_lifecycle=not IS_CHILD_PROCESS)


if not IS_CHILD_PROCESS:
    _CLIP_REGISTRY_SINGLETON = CLIPRegistry()
    _COND_STAGE_MODEL_REGISTRY_SINGLETON = CondStageModelRegistry()
    _TOKENIZER_REGISTRY_SINGLETON = TokenizerRegistry()
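
# Child-side usage mirrors comfy.sd.CLIP; a hedged sketch (the prompt text is
# arbitrary and `clip` is a CLIPProxy handed to an isolated node):
#
#     tokens = clip.tokenize("a photo of a cat")
#     cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True)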
comfy/isolation/extension_loader.py (new file, 248 lines)
@@ -0,0 +1,248 @@
# pylint: disable=cyclic-import,import-outside-toplevel,redefined-outer-name
from __future__ import annotations

import logging
import os
import inspect
import sys
import types
import platform
from pathlib import Path
from typing import Callable, Dict, List, Tuple

import pyisolate
from pyisolate import ExtensionManager, ExtensionManagerConfig

from .extension_wrapper import ComfyNodeExtension
from .manifest_loader import is_cache_valid, load_from_cache, save_to_cache
from .host_policy import load_host_policy

try:
    import tomllib
except ImportError:
    import tomli as tomllib  # type: ignore[no-redef]

logger = logging.getLogger(__name__)


async def _stop_extension_safe(
    extension: ComfyNodeExtension, extension_name: str
) -> None:
    try:
        stop_result = extension.stop()
        if inspect.isawaitable(stop_result):
            await stop_result
    except Exception:
        logger.debug("][ %s stop failed", extension_name, exc_info=True)


def _normalize_dependency_spec(dep: str, base_paths: list[Path]) -> str:
    req, sep, marker = dep.partition(";")
    req = req.strip()
    marker_suffix = f";{marker}" if sep else ""

    def _resolve_local_path(local_path: str) -> Path | None:
        for base in base_paths:
            candidate = (base / local_path).resolve()
            if candidate.exists():
                return candidate
        return None

    if req.startswith("./") or req.startswith("../"):
        resolved = _resolve_local_path(req)
        if resolved is not None:
            return f"{resolved}{marker_suffix}"

    if req.startswith("file://"):
        raw = req[len("file://") :]
        if raw.startswith("./") or raw.startswith("../"):
            resolved = _resolve_local_path(raw)
            if resolved is not None:
                return f"file://{resolved}{marker_suffix}"

    return dep
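
# Example (illustrative paths): with base_paths = [comfy root, node dir],
#     "./wheels/mylib; python_version >= '3.10'"
# resolves to something like
#     "/abs/path/to/node/wheels/mylib; python_version >= '3.10'"
# while specs that are not relative local paths pass through unchanged.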

def get_enforcement_policy() -> Dict[str, bool]:
    return {
        "force_isolated": os.environ.get("PYISOLATE_ENFORCE_ISOLATED") == "1",
        "force_sandbox": os.environ.get("PYISOLATE_ENFORCE_SANDBOX") == "1",
    }


class ExtensionLoadError(RuntimeError):
    pass


def register_dummy_module(extension_name: str, node_dir: Path) -> None:
    normalized_name = extension_name.replace("-", "_").replace(".", "_")
    if normalized_name not in sys.modules:
        dummy_module = types.ModuleType(normalized_name)
        dummy_module.__file__ = str(node_dir / "__init__.py")
        dummy_module.__path__ = [str(node_dir)]
        dummy_module.__package__ = normalized_name
        sys.modules[normalized_name] = dummy_module


def _is_stale_node_cache(cached_data: Dict[str, Dict]) -> bool:
    for details in cached_data.values():
        if not isinstance(details, dict):
            return True
        if details.get("is_v3") and "schema_v1" not in details:
            return True
    return False


async def load_isolated_node(
    node_dir: Path,
    manifest_path: Path,
    logger: logging.Logger,
    build_stub_class: Callable[[str, Dict[str, object], ComfyNodeExtension], type],
    venv_root: Path,
    extension_managers: List[ExtensionManager],
) -> List[Tuple[str, str, type]]:
    try:
        with manifest_path.open("rb") as handle:
            manifest_data = tomllib.load(handle)
    except Exception as e:
        logger.warning(f"][ Failed to parse {manifest_path}: {e}")
        return []

    # Parse [tool.comfy.isolation]
    tool_config = manifest_data.get("tool", {}).get("comfy", {}).get("isolation", {})
    can_isolate = tool_config.get("can_isolate", False)
    share_torch = tool_config.get("share_torch", False)

    # Parse [project] dependencies
    project_config = manifest_data.get("project", {})
    dependencies = project_config.get("dependencies", [])
    if not isinstance(dependencies, list):
        dependencies = []

    # Get extension name (default to folder name if not in project.name)
    extension_name = project_config.get("name", node_dir.name)
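
    # For reference, a manifest that opts in would look roughly like this
    # (field names follow the parsing above; the values are made up):
    #
    #     [project]
    #     name = "my-isolated-node"
    #     dependencies = ["numpy>=1.26", "./wheels/mylib"]
    #
    #     [tool.comfy.isolation]
    #     can_isolate = true
    #     share_torch = true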

    # LOGIC: Isolation Decision
    policy = get_enforcement_policy()
    isolated = can_isolate or policy["force_isolated"]

    if not isolated:
        return []

    logger.info(f"][ Loading isolated node: {extension_name}")

    import folder_paths

    base_paths = [Path(folder_paths.base_path), node_dir]
    dependencies = [
        _normalize_dependency_spec(dep, base_paths) if isinstance(dep, str) else dep
        for dep in dependencies
    ]

    manager_config = ExtensionManagerConfig(venv_root_path=str(venv_root))
    manager: ExtensionManager = pyisolate.ExtensionManager(
        ComfyNodeExtension, manager_config
    )
    extension_managers.append(manager)

    host_policy = load_host_policy(Path(folder_paths.base_path))

    sandbox_config = {}
    is_linux = platform.system() == "Linux"
    if is_linux and isolated:
        sandbox_config = {
            "network": host_policy["allow_network"],
            "writable_paths": host_policy["writable_paths"],
            "readonly_paths": host_policy["readonly_paths"],
        }
    share_cuda_ipc = share_torch and is_linux

    extension_config = {
        "name": extension_name,
        "module_path": str(node_dir),
        "isolated": True,
        "dependencies": dependencies,
        "share_torch": share_torch,
        "share_cuda_ipc": share_cuda_ipc,
        "sandbox": sandbox_config,
    }

    extension = manager.load_extension(extension_config)
    register_dummy_module(extension_name, node_dir)

    # Try cache first (lazy spawn)
    if is_cache_valid(node_dir, manifest_path, venv_root):
        cached_data = load_from_cache(node_dir, venv_root)
        if cached_data:
            if _is_stale_node_cache(cached_data):
                logger.debug(
                    "][ %s cache is stale/incompatible; rebuilding metadata",
                    extension_name,
                )
            else:
                logger.debug(f"][ {extension_name} loaded from cache")
                specs: List[Tuple[str, str, type]] = []
                for node_name, details in cached_data.items():
                    stub_cls = build_stub_class(node_name, details, extension)
                    specs.append(
                        (node_name, details.get("display_name", node_name), stub_cls)
                    )
                return specs

    # Cache miss - spawn process and get metadata
    logger.debug(f"][ {extension_name} cache miss, spawning process for metadata")

    try:
        remote_nodes: Dict[str, str] = await extension.list_nodes()
    except Exception as exc:
        logger.warning(
            "][ %s metadata discovery failed, skipping isolated load: %s",
            extension_name,
            exc,
        )
        await _stop_extension_safe(extension, extension_name)
        return []

    if not remote_nodes:
        logger.debug("][ %s exposed no isolated nodes; skipping", extension_name)
        await _stop_extension_safe(extension, extension_name)
        return []

    specs: List[Tuple[str, str, type]] = []
    cache_data: Dict[str, Dict] = {}

    for node_name, display_name in remote_nodes.items():
        try:
            details = await extension.get_node_details(node_name)
        except Exception as exc:
            logger.warning(
                "][ %s failed to load metadata for %s, skipping node: %s",
                extension_name,
                node_name,
                exc,
            )
            continue
        details["display_name"] = display_name
        cache_data[node_name] = details
        stub_cls = build_stub_class(node_name, details, extension)
        specs.append((node_name, display_name, stub_cls))

    if not specs:
        logger.warning(
            "][ %s produced no usable nodes after metadata scan; skipping",
            extension_name,
        )
        await _stop_extension_safe(extension, extension_name)
        return []

    # Save metadata to cache for future runs
    save_to_cache(node_dir, venv_root, cache_data, manifest_path)
    logger.debug(f"][ {extension_name} metadata cached")

    # EJECT: Kill process after getting metadata (will respawn on first execution)
    await _stop_extension_safe(extension, extension_name)

    return specs


__all__ = ["ExtensionLoadError", "register_dummy_module", "load_isolated_node"]
comfy/isolation/extension_wrapper.py (new file, 673 lines)
@@ -0,0 +1,673 @@
# pylint: disable=consider-using-from-import,cyclic-import,import-outside-toplevel,logging-fstring-interpolation,protected-access,wrong-import-position
from __future__ import annotations

import asyncio
import torch


class AttrDict(dict):
    def __getattr__(self, item):
        try:
            return self[item]
        except KeyError as e:
            raise AttributeError(item) from e

    def copy(self):
        return AttrDict(super().copy())


import importlib
import inspect
import json
import logging
import os
import sys
import uuid
from dataclasses import asdict
from typing import Any, Dict, List, Tuple

from pyisolate import ExtensionBase

from comfy_api.internal import _ComfyNodeInternal

LOG_PREFIX = "]["
V3_DISCOVERY_TIMEOUT = 30
_PRE_EXEC_MIN_FREE_VRAM_BYTES = 2 * 1024 * 1024 * 1024

logger = logging.getLogger(__name__)


def _flush_tensor_transport_state(marker: str) -> int:
    try:
        from pyisolate import flush_tensor_keeper  # type: ignore[attr-defined]
    except Exception:
        return 0
    if not callable(flush_tensor_keeper):
        return 0
    flushed = flush_tensor_keeper()
    if flushed > 0:
        logger.debug(
            "%s %s flush_tensor_keeper released=%d", LOG_PREFIX, marker, flushed
        )
    return flushed


def _relieve_child_vram_pressure(marker: str) -> None:
    import comfy.model_management as model_management

    model_management.cleanup_models_gc()
    model_management.cleanup_models()

    device = model_management.get_torch_device()
    if not hasattr(device, "type") or device.type == "cpu":
        return

    required = max(
        model_management.minimum_inference_memory(),
        _PRE_EXEC_MIN_FREE_VRAM_BYTES,
    )
    if model_management.get_free_memory(device) < required:
        model_management.free_memory(required, device, for_dynamic=True)
        if model_management.get_free_memory(device) < required:
            model_management.free_memory(required, device, for_dynamic=False)
            model_management.cleanup_models()
    model_management.soft_empty_cache()
    logger.debug("%s %s free_memory target=%d", LOG_PREFIX, marker, required)


def _sanitize_for_transport(value):
    primitives = (str, int, float, bool, type(None))
    if isinstance(value, primitives):
        return value

    cls_name = value.__class__.__name__
    if cls_name == "FlexibleOptionalInputType":
        return {
            "__pyisolate_flexible_optional__": True,
            "type": _sanitize_for_transport(getattr(value, "type", "*")),
        }
    if cls_name == "AnyType":
        return {"__pyisolate_any_type__": True, "value": str(value)}
    if cls_name == "ByPassTypeTuple":
        return {
            "__pyisolate_bypass_tuple__": [
                _sanitize_for_transport(v) for v in tuple(value)
            ]
        }

    if isinstance(value, dict):
        return {k: _sanitize_for_transport(v) for k, v in value.items()}
    if isinstance(value, tuple):
        return {"__pyisolate_tuple__": [_sanitize_for_transport(v) for v in value]}
    if isinstance(value, list):
        return [_sanitize_for_transport(v) for v in value]

    return str(value)
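
# Example (illustrative): containers become JSON-safe, with tuples tagged so
# they can be reconstructed on the other side:
#
#     _sanitize_for_transport(("a", 1))
#     # -> {"__pyisolate_tuple__": ["a", 1]}
#     _sanitize_for_transport({"k": object()})
#     # -> {"k": "<object object at 0x...>"}  (str() fallback)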

# Re-export RemoteObjectHandle from pyisolate for backward compatibility
# The canonical definition is now in pyisolate._internal.remote_handle
from pyisolate._internal.remote_handle import RemoteObjectHandle  # noqa: E402,F401


class ComfyNodeExtension(ExtensionBase):
    def __init__(self) -> None:
        super().__init__()
        self.node_classes: Dict[str, type] = {}
        self.display_names: Dict[str, str] = {}
        self.node_instances: Dict[str, Any] = {}
        self.remote_objects: Dict[str, Any] = {}
        self._route_handlers: Dict[str, Any] = {}
        self._module: Any = None

    async def on_module_loaded(self, module: Any) -> None:
        self._module = module

        # Registries are initialized in host_hooks.py initialize_host_process()
        # They auto-register via ProxiedSingleton when instantiated
        # NO additional setup required here - if a registry is missing from host_hooks, it WILL fail

        self.node_classes = getattr(module, "NODE_CLASS_MAPPINGS", {}) or {}
        self.display_names = getattr(module, "NODE_DISPLAY_NAME_MAPPINGS", {}) or {}

        try:
            from comfy_api.latest import ComfyExtension

            for name, obj in inspect.getmembers(module):
                if not (
                    inspect.isclass(obj)
                    and issubclass(obj, ComfyExtension)
                    and obj is not ComfyExtension
                ):
                    continue
                if not obj.__module__.startswith(module.__name__):
                    continue
                try:
                    ext_instance = obj()
                    try:
                        await asyncio.wait_for(
                            ext_instance.on_load(), timeout=V3_DISCOVERY_TIMEOUT
                        )
                    except asyncio.TimeoutError:
                        logger.error(
                            "%s V3 Extension %s timed out in on_load()",
                            LOG_PREFIX,
                            name,
                        )
                        continue
                    try:
                        v3_nodes = await asyncio.wait_for(
                            ext_instance.get_node_list(), timeout=V3_DISCOVERY_TIMEOUT
                        )
                    except asyncio.TimeoutError:
                        logger.error(
                            "%s V3 Extension %s timed out in get_node_list()",
                            LOG_PREFIX,
                            name,
                        )
                        continue
                    for node_cls in v3_nodes:
                        if hasattr(node_cls, "GET_SCHEMA"):
                            schema = node_cls.GET_SCHEMA()
                            self.node_classes[schema.node_id] = node_cls
                            if schema.display_name:
                                self.display_names[schema.node_id] = schema.display_name
                except Exception as e:
                    logger.error("%s V3 Extension %s failed: %s", LOG_PREFIX, name, e)
        except ImportError:
            pass

        module_name = getattr(module, "__name__", "isolated_nodes")
        for node_cls in self.node_classes.values():
            if hasattr(node_cls, "__module__") and "/" in str(node_cls.__module__):
                node_cls.__module__ = module_name

        self.node_instances = {}

    async def list_nodes(self) -> Dict[str, str]:
        return {name: self.display_names.get(name, name) for name in self.node_classes}

    async def get_node_info(self, node_name: str) -> Dict[str, Any]:
        return await self.get_node_details(node_name)

    async def get_node_details(self, node_name: str) -> Dict[str, Any]:
        node_cls = self._get_node_class(node_name)
        is_v3 = issubclass(node_cls, _ComfyNodeInternal)

        input_types_raw = (
            node_cls.INPUT_TYPES() if hasattr(node_cls, "INPUT_TYPES") else {}
        )
        output_is_list = getattr(node_cls, "OUTPUT_IS_LIST", None)
        if output_is_list is not None:
            output_is_list = tuple(bool(x) for x in output_is_list)

        details: Dict[str, Any] = {
            "input_types": _sanitize_for_transport(input_types_raw),
            "return_types": tuple(
                str(t) for t in getattr(node_cls, "RETURN_TYPES", ())
            ),
            "return_names": getattr(node_cls, "RETURN_NAMES", None),
            "function": str(getattr(node_cls, "FUNCTION", "execute")),
            "category": str(getattr(node_cls, "CATEGORY", "")),
            "output_node": bool(getattr(node_cls, "OUTPUT_NODE", False)),
            "output_is_list": output_is_list,
            "is_v3": is_v3,
        }

        if is_v3:
            try:
                schema = node_cls.GET_SCHEMA()
                schema_v1 = asdict(schema.get_v1_info(node_cls))
                try:
                    schema_v3 = asdict(schema.get_v3_info(node_cls))
                except (AttributeError, TypeError):
                    schema_v3 = self._build_schema_v3_fallback(schema)
                details.update(
                    {
                        "schema_v1": schema_v1,
                        "schema_v3": schema_v3,
                        "hidden": [h.value for h in (schema.hidden or [])],
                        "description": getattr(schema, "description", ""),
                        "deprecated": bool(getattr(node_cls, "DEPRECATED", False)),
                        "experimental": bool(getattr(node_cls, "EXPERIMENTAL", False)),
                        "api_node": bool(getattr(node_cls, "API_NODE", False)),
                        "input_is_list": bool(
                            getattr(node_cls, "INPUT_IS_LIST", False)
                        ),
                        "not_idempotent": bool(
                            getattr(node_cls, "NOT_IDEMPOTENT", False)
                        ),
                    }
                )
            except Exception as exc:
                logger.warning(
                    "%s V3 schema serialization failed for %s: %s",
                    LOG_PREFIX,
                    node_name,
                    exc,
                )
        return details

    def _build_schema_v3_fallback(self, schema) -> Dict[str, Any]:
        input_dict: Dict[str, Any] = {}
        output_dict: Dict[str, Any] = {}
        hidden_list: List[str] = []

        if getattr(schema, "inputs", None):
            for inp in schema.inputs:
                self._add_schema_io_v3(inp, input_dict)
        if getattr(schema, "outputs", None):
            for out in schema.outputs:
                self._add_schema_io_v3(out, output_dict)
        if getattr(schema, "hidden", None):
            for h in schema.hidden:
                hidden_list.append(getattr(h, "value", str(h)))

        return {
            "input": input_dict,
            "output": output_dict,
            "hidden": hidden_list,
            "name": getattr(schema, "node_id", None),
            "display_name": getattr(schema, "display_name", None),
            "description": getattr(schema, "description", None),
            "category": getattr(schema, "category", None),
            "output_node": getattr(schema, "is_output_node", False),
            "deprecated": getattr(schema, "is_deprecated", False),
            "experimental": getattr(schema, "is_experimental", False),
            "api_node": getattr(schema, "is_api_node", False),
        }

    def _add_schema_io_v3(self, io_obj: Any, target: Dict[str, Any]) -> None:
        io_id = getattr(io_obj, "id", None)
        if io_id is None:
            return

        io_type_fn = getattr(io_obj, "get_io_type", None)
        io_type = (
            io_type_fn() if callable(io_type_fn) else getattr(io_obj, "io_type", None)
        )

        as_dict_fn = getattr(io_obj, "as_dict", None)
        payload = as_dict_fn() if callable(as_dict_fn) else {}

        target[str(io_id)] = (io_type, payload)

    async def get_input_types(self, node_name: str) -> Dict[str, Any]:
        node_cls = self._get_node_class(node_name)
        if hasattr(node_cls, "INPUT_TYPES"):
            return node_cls.INPUT_TYPES()
        return {}

    async def execute_node(self, node_name: str, **inputs: Any) -> Tuple[Any, ...]:
        logger.debug(
            "%s ISO:child_execute_start ext=%s node=%s input_keys=%d",
            LOG_PREFIX,
            getattr(self, "name", "?"),
            node_name,
            len(inputs),
        )
        if os.environ.get("PYISOLATE_ISOLATION_ACTIVE") == "1":
            _relieve_child_vram_pressure("EXT:pre_execute")

        resolved_inputs = self._resolve_remote_objects(inputs)

        instance = self._get_node_instance(node_name)
        node_cls = self._get_node_class(node_name)

        # V3 API nodes expect hidden parameters in cls.hidden, not as kwargs
        # Hidden params come through RPC as string keys like "Hidden.prompt"
        from comfy_api.latest._io import Hidden, HiddenHolder

        # Map string representations back to Hidden enum keys
        hidden_string_map = {
            "Hidden.unique_id": Hidden.unique_id,
            "Hidden.prompt": Hidden.prompt,
            "Hidden.extra_pnginfo": Hidden.extra_pnginfo,
            "Hidden.dynprompt": Hidden.dynprompt,
            "Hidden.auth_token_comfy_org": Hidden.auth_token_comfy_org,
            "Hidden.api_key_comfy_org": Hidden.api_key_comfy_org,
        }

        # Find and extract hidden parameters (both enum and string form)
        hidden_found = {}
        keys_to_remove = []

        for key in list(resolved_inputs.keys()):
            # Check string form first (from RPC serialization)
            if key in hidden_string_map:
                hidden_found[hidden_string_map[key]] = resolved_inputs[key]
                keys_to_remove.append(key)
            # Also check enum form (direct calls)
            elif isinstance(key, Hidden):
                hidden_found[key] = resolved_inputs[key]
                keys_to_remove.append(key)

        # Remove hidden params from kwargs
        for key in keys_to_remove:
            resolved_inputs.pop(key)

        # Set hidden on node class if any hidden params found
        if hidden_found:
            if not hasattr(node_cls, "hidden") or node_cls.hidden is None:
                node_cls.hidden = HiddenHolder.from_dict(hidden_found)
            else:
                # Update existing hidden holder
                for key, value in hidden_found.items():
                    setattr(node_cls.hidden, key.value.lower(), value)

        function_name = getattr(node_cls, "FUNCTION", "execute")
        if not hasattr(instance, function_name):
            raise AttributeError(f"Node {node_name} missing callable '{function_name}'")

        handler = getattr(instance, function_name)

        try:
            if asyncio.iscoroutinefunction(handler):
                result = await handler(**resolved_inputs)
            else:
                import functools

                loop = asyncio.get_running_loop()
                result = await loop.run_in_executor(
                    None, functools.partial(handler, **resolved_inputs)
                )
        except Exception:
            logger.exception(
                "%s ISO:child_execute_error ext=%s node=%s",
                LOG_PREFIX,
                getattr(self, "name", "?"),
                node_name,
            )
            raise

        if type(result).__name__ == "NodeOutput":
            result = result.args
        if self._is_comfy_protocol_return(result):
            logger.debug(
                "%s ISO:child_execute_done ext=%s node=%s protocol_return=1",
                LOG_PREFIX,
                getattr(self, "name", "?"),
                node_name,
            )
            return self._wrap_unpicklable_objects(result)

        if not isinstance(result, tuple):
            result = (result,)
        logger.debug(
            "%s ISO:child_execute_done ext=%s node=%s protocol_return=0 outputs=%d",
            LOG_PREFIX,
            getattr(self, "name", "?"),
            node_name,
            len(result),
        )
        return self._wrap_unpicklable_objects(result)

    async def flush_transport_state(self) -> int:
        if os.environ.get("PYISOLATE_ISOLATION_ACTIVE") != "1":
            return 0
        logger.debug(
            "%s ISO:child_flush_start ext=%s", LOG_PREFIX, getattr(self, "name", "?")
        )
        flushed = _flush_tensor_transport_state("EXT:workflow_end")
        try:
            from comfy.isolation.model_patcher_proxy_registry import (
                ModelPatcherRegistry,
            )

            registry = ModelPatcherRegistry()
            removed = registry.sweep_pending_cleanup()
            if removed > 0:
                logger.debug(
                    "%s EXT:workflow_end registry sweep removed=%d", LOG_PREFIX, removed
                )
        except Exception:
            logger.debug(
                "%s EXT:workflow_end registry sweep failed", LOG_PREFIX, exc_info=True
            )
        logger.debug(
            "%s ISO:child_flush_done ext=%s flushed=%d",
            LOG_PREFIX,
            getattr(self, "name", "?"),
            flushed,
        )
        return flushed

    async def get_remote_object(self, object_id: str) -> Any:
        """Retrieve a remote object by ID for host-side deserialization."""
        if object_id not in self.remote_objects:
            raise KeyError(f"Remote object {object_id} not found")

        return self.remote_objects[object_id]

    def _wrap_unpicklable_objects(self, data: Any) -> Any:
        if isinstance(data, (str, int, float, bool, type(None))):
            return data
        if isinstance(data, torch.Tensor):
            return data.detach() if data.requires_grad else data

        # Special-case clip vision outputs: preserve attribute access by packing fields
        if hasattr(data, "penultimate_hidden_states") or hasattr(
            data, "last_hidden_state"
        ):
            fields = {}
            for attr in (
                "penultimate_hidden_states",
                "last_hidden_state",
                "image_embeds",
                "text_embeds",
            ):
                if hasattr(data, attr):
                    try:
                        fields[attr] = self._wrap_unpicklable_objects(
                            getattr(data, attr)
                        )
                    except Exception:
                        pass
            if fields:
                return {"__pyisolate_attribute_container__": True, "data": fields}

        # Avoid converting arbitrary objects with stateful methods (models, etc.)
        # They will be handled via RemoteObjectHandle below.

        type_name = type(data).__name__
        if type_name == "ModelPatcherProxy":
            return {"__type__": "ModelPatcherRef", "model_id": data._instance_id}
        if type_name == "CLIPProxy":
            return {"__type__": "CLIPRef", "clip_id": data._instance_id}
        if type_name == "VAEProxy":
            return {"__type__": "VAERef", "vae_id": data._instance_id}
        if type_name == "ModelSamplingProxy":
            return {"__type__": "ModelSamplingRef", "ms_id": data._instance_id}

        if isinstance(data, (list, tuple)):
            wrapped = [self._wrap_unpicklable_objects(item) for item in data]
            return tuple(wrapped) if isinstance(data, tuple) else wrapped
        if isinstance(data, dict):
            converted_dict = {
                k: self._wrap_unpicklable_objects(v) for k, v in data.items()
            }
            return {"__pyisolate_attrdict__": True, "data": converted_dict}

        object_id = str(uuid.uuid4())
        self.remote_objects[object_id] = data
        return RemoteObjectHandle(object_id, type(data).__name__)

    def _resolve_remote_objects(self, data: Any) -> Any:
        if isinstance(data, RemoteObjectHandle):
            if data.object_id not in self.remote_objects:
                raise KeyError(f"Remote object {data.object_id} not found")
            return self.remote_objects[data.object_id]

        if isinstance(data, dict):
            ref_type = data.get("__type__")
            if ref_type in ("CLIPRef", "ModelPatcherRef", "VAERef"):
                from pyisolate._internal.model_serialization import (
                    deserialize_proxy_result,
                )

                return deserialize_proxy_result(data)
            if ref_type == "ModelSamplingRef":
                from pyisolate._internal.model_serialization import (
                    deserialize_proxy_result,
                )

                return deserialize_proxy_result(data)
            return {k: self._resolve_remote_objects(v) for k, v in data.items()}

        if isinstance(data, (list, tuple)):
            resolved = [self._resolve_remote_objects(item) for item in data]
            return tuple(resolved) if isinstance(data, tuple) else resolved
        return data

    def _get_node_class(self, node_name: str) -> type:
        if node_name not in self.node_classes:
            raise KeyError(f"Unknown node: {node_name}")
        return self.node_classes[node_name]

    def _get_node_instance(self, node_name: str) -> Any:
        if node_name not in self.node_instances:
            if node_name not in self.node_classes:
                raise KeyError(f"Unknown node: {node_name}")
            self.node_instances[node_name] = self.node_classes[node_name]()
        return self.node_instances[node_name]

    async def before_module_loaded(self) -> None:
        # Inject initialization here if we think this is the child
        try:
            from comfy.isolation import initialize_proxies

            initialize_proxies()
        except Exception as e:
            logging.getLogger(__name__).error(
                f"Failed to call initialize_proxies in before_module_loaded: {e}"
            )

        await super().before_module_loaded()
        try:
            from comfy_api.latest import ComfyAPI_latest
            from .proxies.progress_proxy import ProgressProxy

            ComfyAPI_latest.Execution = ProgressProxy
            # ComfyAPI_latest.execution = ProgressProxy()  # Eliminated to avoid Singleton collision
            # fp_proxy = FolderPathsProxy()  # Eliminated to avoid Singleton collision
            # latest_ui.folder_paths = fp_proxy
            # latest_resources.folder_paths = fp_proxy
        except Exception:
            pass

    async def call_route_handler(
        self,
        handler_module: str,
        handler_func: str,
        request_data: Dict[str, Any],
    ) -> Any:
        cache_key = f"{handler_module}.{handler_func}"
        if cache_key not in self._route_handlers:
            if self._module is not None and hasattr(self._module, "__file__"):
                node_dir = os.path.dirname(self._module.__file__)
                if node_dir not in sys.path:
                    sys.path.insert(0, node_dir)
            try:
                module = importlib.import_module(handler_module)
                self._route_handlers[cache_key] = getattr(module, handler_func)
            except (ImportError, AttributeError) as e:
                raise ValueError(f"Route handler not found: {cache_key}") from e

        handler = self._route_handlers[cache_key]
        mock_request = MockRequest(request_data)

        if asyncio.iscoroutinefunction(handler):
            result = await handler(mock_request)
        else:
            result = handler(mock_request)
        return self._serialize_response(result)

    def _is_comfy_protocol_return(self, result: Any) -> bool:
        """
        Check if the result matches the ComfyUI 'Protocol Return' schema.

        A Protocol Return is a dictionary containing specific reserved keys that
        ComfyUI's execution engine interprets as instructions (UI updates,
        workflow expansion, etc.) rather than purely data outputs.

        Schema:
        - Must be a dict
        - Must contain at least one of: 'ui', 'result', 'expand'
        """
        if not isinstance(result, dict):
            return False
        return any(key in result for key in ("ui", "result", "expand"))
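
    # Example (illustrative): a node returning UI data plus outputs is treated
    # as a protocol return and forwarded as-is:
    #
    #     {"ui": {"images": [...]}, "result": (image_tensor,)}  # -> True
    #     (image_tensor,)                                       # -> False (plain tuple output)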
|
||||
|
||||
def _serialize_response(self, response: Any) -> Dict[str, Any]:
|
||||
if response is None:
|
||||
return {"type": "text", "body": "", "status": 204}
|
||||
if isinstance(response, dict):
|
||||
return {"type": "json", "body": response, "status": 200}
|
||||
if isinstance(response, str):
|
||||
return {"type": "text", "body": response, "status": 200}
|
||||
if hasattr(response, "text") and hasattr(response, "status"):
|
||||
return {
|
||||
"type": "text",
|
||||
"body": response.text
|
||||
if hasattr(response, "text")
|
||||
else str(response.body),
|
||||
"status": response.status,
|
||||
"headers": dict(response.headers)
|
||||
if hasattr(response, "headers")
|
||||
else {},
|
||||
}
|
||||
if hasattr(response, "body") and hasattr(response, "status"):
|
||||
body = response.body
|
||||
if isinstance(body, bytes):
|
||||
try:
|
||||
return {
|
||||
"type": "text",
|
||||
"body": body.decode("utf-8"),
|
||||
"status": response.status,
|
||||
}
|
||||
except UnicodeDecodeError:
|
||||
return {
|
||||
"type": "binary",
|
||||
"body": body.hex(),
|
||||
"status": response.status,
|
||||
}
|
||||
return {"type": "json", "body": body, "status": response.status}
|
||||
return {"type": "text", "body": str(response), "status": 200}
|
||||
|
||||
|
||||
class MockRequest:
|
||||
def __init__(self, data: Dict[str, Any]):
|
||||
self.method = data.get("method", "GET")
|
||||
self.path = data.get("path", "/")
|
||||
self.query = data.get("query", {})
|
||||
self._body = data.get("body", {})
|
||||
self._text = data.get("text", "")
|
||||
self.headers = data.get("headers", {})
|
||||
self.content_type = data.get(
|
||||
"content_type", self.headers.get("Content-Type", "application/json")
|
||||
)
|
||||
self.match_info = data.get("match_info", {})
|
||||
|
||||
async def json(self) -> Any:
|
||||
if isinstance(self._body, dict):
|
||||
return self._body
|
||||
if isinstance(self._body, str):
|
||||
return json.loads(self._body)
|
||||
return {}
|
||||
|
||||
async def post(self) -> Dict[str, Any]:
|
||||
if isinstance(self._body, dict):
|
||||
return self._body
|
||||
return {}
|
||||
|
||||
async def text(self) -> str:
|
||||
if self._text:
|
||||
return self._text
|
||||
if isinstance(self._body, str):
|
||||
return self._body
|
||||
if isinstance(self._body, dict):
|
||||
return json.dumps(self._body)
|
||||
return ""
|
||||
|
||||
async def read(self) -> bytes:
|
||||
return (await self.text()).encode("utf-8")
|
||||
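A minimal sketch of how the pieces above fit together: a toy route handler (hypothetical, not part of this diff) is driven through MockRequest the same way call_route_handler would drive a real one, and its dict return is what _serialize_response would wrap as {"type": "json", ...}. Only the MockRequest class defined above is assumed.

# Hedged usage sketch; `demo_handler` is hypothetical.
import asyncio

async def demo_handler(request):
    payload = await request.json()  # MockRequest mirrors the aiohttp request API
    return {"echo": payload, "path": request.path}

async def main():
    req = MockRequest({"method": "POST", "path": "/demo", "body": {"x": 1}})
    print(await demo_handler(req))  # -> {'echo': {'x': 1}, 'path': '/demo'}

asyncio.run(main())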
26
comfy/isolation/host_hooks.py
Normal file
@@ -0,0 +1,26 @@
# pylint: disable=import-outside-toplevel
# Host process initialization for PyIsolate
import logging

logger = logging.getLogger(__name__)


def initialize_host_process() -> None:
    root = logging.getLogger()
    for handler in root.handlers[:]:
        root.removeHandler(handler)
    root.addHandler(logging.NullHandler())

    from .proxies.folder_paths_proxy import FolderPathsProxy
    from .proxies.model_management_proxy import ModelManagementProxy
    from .proxies.progress_proxy import ProgressProxy
    from .proxies.prompt_server_impl import PromptServerService
    from .proxies.utils_proxy import UtilsProxy
    from .vae_proxy import VAERegistry

    FolderPathsProxy()
    ModelManagementProxy()
    ProgressProxy()
    PromptServerService()
    UtilsProxy()
    VAERegistry()
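The point of instantiating each service once here is, presumably, to have every host-side singleton registered before any isolated child connects. A hedged sketch of the expected call order; the extension-loading call is an assumption, not from this diff:

from comfy.isolation.host_hooks import initialize_host_process

initialize_host_process()      # register host-side singleton services first
# load_isolated_extensions()   # hypothetical: children started afterwards
#                              # can then resolve these services over RPC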
83
comfy/isolation/host_policy.py
Normal file
@@ -0,0 +1,83 @@
# pylint: disable=logging-fstring-interpolation
from __future__ import annotations

import logging
from pathlib import Path
from typing import Dict, List, TypedDict

try:
    import tomllib
except ImportError:
    import tomli as tomllib  # type: ignore[no-redef]

logger = logging.getLogger(__name__)


class HostSecurityPolicy(TypedDict):
    allow_network: bool
    writable_paths: List[str]
    readonly_paths: List[str]
    whitelist: Dict[str, str]


DEFAULT_POLICY: HostSecurityPolicy = {
    "allow_network": False,
    "writable_paths": ["/dev/shm", "/tmp"],
    "readonly_paths": [],
    "whitelist": {},
}


def _default_policy() -> HostSecurityPolicy:
    return {
        "allow_network": DEFAULT_POLICY["allow_network"],
        "writable_paths": list(DEFAULT_POLICY["writable_paths"]),
        "readonly_paths": list(DEFAULT_POLICY["readonly_paths"]),
        "whitelist": dict(DEFAULT_POLICY["whitelist"]),
    }


def load_host_policy(comfy_root: Path) -> HostSecurityPolicy:
    config_path = comfy_root / "pyproject.toml"
    policy = _default_policy()

    if not config_path.exists():
        logger.debug("Host policy file missing at %s, using defaults.", config_path)
        return policy

    try:
        with config_path.open("rb") as f:
            data = tomllib.load(f)
    except Exception:
        logger.warning(
            "Failed to parse host policy from %s, using defaults.",
            config_path,
            exc_info=True,
        )
        return policy

    tool_config = data.get("tool", {}).get("comfy", {}).get("host", {})
    if not isinstance(tool_config, dict):
        logger.debug("No [tool.comfy.host] section found, using defaults.")
        return policy

    if "allow_network" in tool_config:
        policy["allow_network"] = bool(tool_config["allow_network"])

    if "writable_paths" in tool_config:
        policy["writable_paths"] = [str(p) for p in tool_config["writable_paths"]]

    if "readonly_paths" in tool_config:
        policy["readonly_paths"] = [str(p) for p in tool_config["readonly_paths"]]

    whitelist_raw = tool_config.get("whitelist")
    if isinstance(whitelist_raw, dict):
        policy["whitelist"] = {str(k): str(v) for k, v in whitelist_raw.items()}

    logger.debug(
        f"Loaded Host Policy: {len(policy['whitelist'])} whitelisted nodes, Network={policy['allow_network']}"
    )
    return policy


__all__ = ["HostSecurityPolicy", "load_host_policy", "DEFAULT_POLICY"]
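A hedged, self-contained sketch of the [tool.comfy.host] table this parser consumes. The values (and the meaning of the whitelist entries) are illustrative assumptions; only the key names come from the code above.

import tempfile
from pathlib import Path

from comfy.isolation.host_policy import load_host_policy

SAMPLE = b"""
[tool.comfy.host]
allow_network = true
writable_paths = ["/tmp/comfy-scratch"]
whitelist = { "some-node-pack" = "1.2.3" }
"""

with tempfile.TemporaryDirectory() as d:
    (Path(d) / "pyproject.toml").write_bytes(SAMPLE)
    policy = load_host_policy(Path(d))
    assert policy["allow_network"] is True
    assert policy["writable_paths"] == ["/tmp/comfy-scratch"]
    assert policy["whitelist"] == {"some-node-pack": "1.2.3"}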
186
comfy/isolation/manifest_loader.py
Normal file
@@ -0,0 +1,186 @@
# pylint: disable=import-outside-toplevel
from __future__ import annotations

import hashlib
import json
import logging
import os
import sys
import tempfile
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import folder_paths

try:
    import tomllib
except ImportError:
    import tomli as tomllib  # type: ignore[no-redef]

LOG_PREFIX = "]["
logger = logging.getLogger(__name__)

CACHE_SUBDIR = "cache"
CACHE_KEY_FILE = "cache_key"
CACHE_DATA_FILE = "node_info.json"
CACHE_KEY_LENGTH = 16


def find_manifest_directories() -> List[Tuple[Path, Path]]:
    """Find custom node directories containing a valid pyproject.toml with [tool.comfy.isolation]."""
    manifest_dirs: List[Tuple[Path, Path]] = []

    # Standard custom_nodes paths
    for base_path in folder_paths.get_folder_paths("custom_nodes"):
        base = Path(base_path)
        if not base.exists() or not base.is_dir():
            continue

        for entry in base.iterdir():
            if not entry.is_dir():
                continue

            # Look for pyproject.toml
            manifest = entry / "pyproject.toml"
            if not manifest.exists():
                continue

            # Validate [tool.comfy.isolation] section existence
            try:
                with manifest.open("rb") as f:
                    data = tomllib.load(f)

                if (
                    "tool" in data
                    and "comfy" in data["tool"]
                    and "isolation" in data["tool"]["comfy"]
                ):
                    manifest_dirs.append((entry, manifest))

            except Exception:
                continue

    return manifest_dirs


def compute_cache_key(node_dir: Path, manifest_path: Path) -> str:
    """Hash manifest + .py mtimes + Python version + PyIsolate version."""
    hasher = hashlib.sha256()

    try:
        # Hashing the manifest content ensures config changes invalidate cache
        hasher.update(manifest_path.read_bytes())
    except OSError:
        hasher.update(b"__manifest_read_error__")

    try:
        py_files = sorted(node_dir.rglob("*.py"))
        for py_file in py_files:
            rel_path = py_file.relative_to(node_dir)
            if "__pycache__" in str(rel_path) or ".venv" in str(rel_path):
                continue
            hasher.update(str(rel_path).encode("utf-8"))
            try:
                hasher.update(str(py_file.stat().st_mtime).encode("utf-8"))
            except OSError:
                hasher.update(b"__file_stat_error__")
    except OSError:
        hasher.update(b"__dir_scan_error__")

    hasher.update(sys.version.encode("utf-8"))

    try:
        import pyisolate

        hasher.update(pyisolate.__version__.encode("utf-8"))
    except (ImportError, AttributeError):
        hasher.update(b"__pyisolate_unknown__")

    return hasher.hexdigest()[:CACHE_KEY_LENGTH]


def get_cache_path(node_dir: Path, venv_root: Path) -> Tuple[Path, Path]:
    """Return (cache_key_file, cache_data_file) in venv_root/{node}/cache/."""
    cache_dir = venv_root / node_dir.name / CACHE_SUBDIR
    return (cache_dir / CACHE_KEY_FILE, cache_dir / CACHE_DATA_FILE)


def is_cache_valid(node_dir: Path, manifest_path: Path, venv_root: Path) -> bool:
    """Return True only if stored cache key matches current computed key."""
    try:
        cache_key_file, cache_data_file = get_cache_path(node_dir, venv_root)
        if not cache_key_file.exists() or not cache_data_file.exists():
            return False
        current_key = compute_cache_key(node_dir, manifest_path)
        stored_key = cache_key_file.read_text(encoding="utf-8").strip()
        return current_key == stored_key
    except Exception as e:
        logger.debug(
            "%s Cache validation error for %s: %s", LOG_PREFIX, node_dir.name, e
        )
        return False


def load_from_cache(node_dir: Path, venv_root: Path) -> Optional[Dict[str, Any]]:
    """Load node metadata from cache, return None on any error."""
    try:
        _, cache_data_file = get_cache_path(node_dir, venv_root)
        if not cache_data_file.exists():
            return None
        data = json.loads(cache_data_file.read_text(encoding="utf-8"))
        if not isinstance(data, dict):
            return None
        return data
    except Exception:
        return None


def save_to_cache(
    node_dir: Path, venv_root: Path, node_data: Dict[str, Any], manifest_path: Path
) -> None:
    """Save node metadata and cache key atomically."""
    try:
        cache_key_file, cache_data_file = get_cache_path(node_dir, venv_root)
        cache_dir = cache_key_file.parent
        cache_dir.mkdir(parents=True, exist_ok=True)
        cache_key = compute_cache_key(node_dir, manifest_path)

        # Atomic write: data
        tmp_data_fd, tmp_data_path = tempfile.mkstemp(dir=str(cache_dir), suffix=".tmp")
        try:
            with os.fdopen(tmp_data_fd, "w", encoding="utf-8") as f:
                json.dump(node_data, f, indent=2)
            os.replace(tmp_data_path, cache_data_file)
        except Exception:
            try:
                os.unlink(tmp_data_path)
            except OSError:
                pass
            raise

        # Atomic write: key
        tmp_key_fd, tmp_key_path = tempfile.mkstemp(dir=str(cache_dir), suffix=".tmp")
        try:
            with os.fdopen(tmp_key_fd, "w", encoding="utf-8") as f:
                f.write(cache_key)
            os.replace(tmp_key_path, cache_key_file)
        except Exception:
            try:
                os.unlink(tmp_key_path)
            except OSError:
                pass
            raise

    except Exception as e:
        logger.warning("%s Cache save failed for %s: %s", LOG_PREFIX, node_dir.name, e)


__all__ = [
    "LOG_PREFIX",
    "find_manifest_directories",
    "compute_cache_key",
    "get_cache_path",
    "is_cache_valid",
    "load_from_cache",
    "save_to_cache",
]
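A hedged sketch of the intended consumer loop: discover isolated node packs, reuse cached metadata while the key still matches, and rebuild otherwise. scan_node_metadata and the venv_root location are assumptions standing in for whatever expensive introspection this cache protects.

from pathlib import Path

from comfy.isolation.manifest_loader import (
    find_manifest_directories,
    is_cache_valid,
    load_from_cache,
    save_to_cache,
)

def scan_node_metadata(node_dir: Path) -> dict:
    return {"name": node_dir.name}  # hypothetical stub for the real scan

venv_root = Path("isolation_venvs")  # assumed location, not from this diff
for node_dir, manifest in find_manifest_directories():
    if is_cache_valid(node_dir, manifest, venv_root):
        info = load_from_cache(node_dir, venv_root) or scan_node_metadata(node_dir)
    else:
        info = scan_node_metadata(node_dir)
        save_to_cache(node_dir, venv_root, info, manifest)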
820
comfy/isolation/model_patcher_proxy.py
Normal file
@@ -0,0 +1,820 @@
# pylint: disable=bare-except,consider-using-from-import,import-outside-toplevel,protected-access
# RPC proxy for ModelPatcher (parent process)
from __future__ import annotations

import logging
from typing import Any, Optional, List, Set, Dict, Callable

from comfy.isolation.proxies.base import (
    IS_CHILD_PROCESS,
    BaseProxy,
)
from comfy.isolation.model_patcher_proxy_registry import (
    ModelPatcherRegistry,
    AutoPatcherEjector,
)

logger = logging.getLogger(__name__)


class ModelPatcherProxy(BaseProxy[ModelPatcherRegistry]):
    _registry_class = ModelPatcherRegistry
    __module__ = "comfy.model_patcher"
    _APPLY_MODEL_GUARD_PADDING_BYTES = 32 * 1024 * 1024

    def _get_rpc(self) -> Any:
        if self._rpc_caller is None:
            from pyisolate._internal.rpc_protocol import get_child_rpc_instance

            rpc = get_child_rpc_instance()
            if rpc is not None:
                self._rpc_caller = rpc.create_caller(
                    self._registry_class, self._registry_class.get_remote_id()
                )
            else:
                self._rpc_caller = self._registry
        return self._rpc_caller

    def get_all_callbacks(self, call_type: Optional[str] = None) -> Any:
        return self._call_rpc("get_all_callbacks", call_type)

    def get_all_wrappers(self, wrapper_type: Optional[str] = None) -> Any:
        return self._call_rpc("get_all_wrappers", wrapper_type)

    def _load_list(self, *args, **kwargs) -> Any:
        return self._call_rpc("load_list_internal", *args, **kwargs)

    def prepare_hook_patches_current_keyframe(
        self, t: Any, hook_group: Any, model_options: Any
    ) -> None:
        self._call_rpc(
            "prepare_hook_patches_current_keyframe", t, hook_group, model_options
        )

    def add_hook_patches(
        self,
        hook: Any,
        patches: Any,
        strength_patch: float = 1.0,
        strength_model: float = 1.0,
    ) -> None:
        self._call_rpc(
            "add_hook_patches", hook, patches, strength_patch, strength_model
        )

    def clear_cached_hook_weights(self) -> None:
        self._call_rpc("clear_cached_hook_weights")

    def get_combined_hook_patches(self, hooks: Any) -> Any:
        return self._call_rpc("get_combined_hook_patches", hooks)

    def get_additional_models_with_key(self, key: str) -> Any:
        return self._call_rpc("get_additional_models_with_key", key)

    @property
    def object_patches(self) -> Any:
        return self._call_rpc("get_object_patches")

    @property
    def patches(self) -> Any:
        res = self._call_rpc("get_patches")
        if isinstance(res, dict):
            new_res = {}
            for k, v in res.items():
                new_list = []
                for item in v:
                    if isinstance(item, list):
                        new_list.append(tuple(item))
                    else:
                        new_list.append(item)
                new_res[k] = new_list
            return new_res
        return res

    @property
    def pinned(self) -> Set:
        val = self._call_rpc("get_patcher_attr", "pinned")
        return set(val) if val is not None else set()

    @property
    def hook_patches(self) -> Dict:
        val = self._call_rpc("get_patcher_attr", "hook_patches")
        if val is None:
            return {}
        try:
            from comfy.hooks import _HookRef
            import json

            new_val = {}
            for k, v in val.items():
                if isinstance(k, str):
                    if k.startswith("PYISOLATE_HOOKREF:"):
                        ref_id = k.split(":", 1)[1]
                        h = _HookRef()
                        h._pyisolate_id = ref_id
                        new_val[h] = v
                    elif k.startswith("__pyisolate_key__"):
                        try:
                            json_str = k[len("__pyisolate_key__") :]
                            data = json.loads(json_str)
                            ref_id = None
                            if isinstance(data, list):
                                for item in data:
                                    if (
                                        isinstance(item, list)
                                        and len(item) == 2
                                        and item[0] == "id"
                                    ):
                                        ref_id = item[1]
                                        break
                            if ref_id:
                                h = _HookRef()
                                h._pyisolate_id = ref_id
                                new_val[h] = v
                            else:
                                new_val[k] = v
                        except Exception:
                            new_val[k] = v
                    else:
                        new_val[k] = v
                else:
                    new_val[k] = v
            return new_val
        except ImportError:
            return val

    def set_hook_mode(self, hook_mode: Any) -> None:
        self._call_rpc("set_hook_mode", hook_mode)

    def register_all_hook_patches(
        self,
        hooks: Any,
        target_dict: Any,
        model_options: Any = None,
        registered: Any = None,
    ) -> None:
        self._call_rpc(
            "register_all_hook_patches", hooks, target_dict, model_options, registered
        )

    def is_clone(self, other: Any) -> bool:
        if isinstance(other, ModelPatcherProxy):
            return self._call_rpc("is_clone_by_id", other._instance_id)
        return False

    def clone(self) -> ModelPatcherProxy:
        new_id = self._call_rpc("clone")
        return ModelPatcherProxy(
            new_id, self._registry, manage_lifecycle=not IS_CHILD_PROCESS
        )

    def clone_has_same_weights(self, clone: Any) -> bool:
        if isinstance(clone, ModelPatcherProxy):
            return self._call_rpc("clone_has_same_weights_by_id", clone._instance_id)
        if not IS_CHILD_PROCESS:
            return self._call_rpc("is_clone", clone)
        return False

    def get_model_object(self, name: str) -> Any:
        return self._call_rpc("get_model_object", name)

    @property
    def model_options(self) -> dict:
        data = self._call_rpc("get_model_options")
        import json

        def _decode_keys(obj):
            if isinstance(obj, dict):
                new_d = {}
                for k, v in obj.items():
                    if isinstance(k, str) and k.startswith("__pyisolate_key__"):
                        try:
                            json_str = k[17:]
                            val = json.loads(json_str)
                            if isinstance(val, list):
                                val = tuple(val)
                            new_d[val] = _decode_keys(v)
                        except:
                            new_d[k] = _decode_keys(v)
                    else:
                        new_d[k] = _decode_keys(v)
                return new_d
            if isinstance(obj, list):
                return [_decode_keys(x) for x in obj]
            return obj

        return _decode_keys(data)

    @model_options.setter
    def model_options(self, value: dict) -> None:
        self._call_rpc("set_model_options", value)

    def apply_hooks(self, hooks: Any) -> Any:
        return self._call_rpc("apply_hooks", hooks)

    def prepare_state(self, timestep: Any) -> Any:
        return self._call_rpc("prepare_state", timestep)

    def restore_hook_patches(self) -> None:
        self._call_rpc("restore_hook_patches")

    def unpatch_hooks(self, whitelist_keys_set: Optional[Set[str]] = None) -> None:
        self._call_rpc("unpatch_hooks", whitelist_keys_set)

    def model_patches_to(self, device: Any) -> Any:
        return self._call_rpc("model_patches_to", device)

    def partially_load(
        self, device: Any, extra_memory: Any, force_patch_weights: bool = False
    ) -> Any:
        return self._call_rpc(
            "partially_load", device, extra_memory, force_patch_weights
        )

    def partially_unload(
        self, device_to: Any, memory_to_free: int = 0, force_patch_weights: bool = False
    ) -> int:
        return self._call_rpc(
            "partially_unload", device_to, memory_to_free, force_patch_weights
        )

    def load(
        self,
        device_to: Any = None,
        lowvram_model_memory: int = 0,
        force_patch_weights: bool = False,
        full_load: bool = False,
    ) -> None:
        self._call_rpc(
            "load", device_to, lowvram_model_memory, force_patch_weights, full_load
        )

    def patch_model(
        self,
        device_to: Any = None,
        lowvram_model_memory: int = 0,
        load_weights: bool = True,
        force_patch_weights: bool = False,
    ) -> Any:
        self._call_rpc(
            "patch_model",
            device_to,
            lowvram_model_memory,
            load_weights,
            force_patch_weights,
        )
        return self

    def unpatch_model(
        self, device_to: Any = None, unpatch_weights: bool = True
    ) -> None:
        self._call_rpc("unpatch_model", device_to, unpatch_weights)

    def detach(self, unpatch_all: bool = True) -> Any:
        self._call_rpc("detach", unpatch_all)
        return self.model

    def _cpu_tensor_bytes(self, obj: Any) -> int:
        import torch

        if isinstance(obj, torch.Tensor):
            if obj.device.type == "cpu":
                return obj.nbytes
            return 0
        if isinstance(obj, dict):
            return sum(self._cpu_tensor_bytes(v) for v in obj.values())
        if isinstance(obj, (list, tuple)):
            return sum(self._cpu_tensor_bytes(v) for v in obj)
        return 0

    def _ensure_apply_model_headroom(self, required_bytes: int) -> bool:
        if required_bytes <= 0:
            return True

        import torch
        import comfy.model_management as model_management

        target_raw = self.load_device
        try:
            if isinstance(target_raw, torch.device):
                target = target_raw
            elif isinstance(target_raw, str):
                target = torch.device(target_raw)
            elif isinstance(target_raw, int):
                target = torch.device(f"cuda:{target_raw}")
            else:
                target = torch.device(target_raw)
        except Exception:
            return True

        if target.type != "cuda":
            return True

        required = required_bytes + self._APPLY_MODEL_GUARD_PADDING_BYTES
        if model_management.get_free_memory(target) >= required:
            return True

        model_management.cleanup_models_gc()
        model_management.cleanup_models()
        model_management.soft_empty_cache()

        if model_management.get_free_memory(target) < required:
            model_management.free_memory(required, target, for_dynamic=True)
            model_management.soft_empty_cache()

        if model_management.get_free_memory(target) < required:
            # Escalate to non-dynamic unloading before dispatching CUDA transfer.
            model_management.free_memory(required, target, for_dynamic=False)
            model_management.soft_empty_cache()

        if model_management.get_free_memory(target) < required:
            model_management.load_models_gpu(
                [self],
                minimum_memory_required=required,
            )

        return model_management.get_free_memory(target) >= required

    def apply_model(self, *args, **kwargs) -> Any:
        import torch

        required_bytes = self._cpu_tensor_bytes(args) + self._cpu_tensor_bytes(kwargs)
        self._ensure_apply_model_headroom(required_bytes)

        def _to_cuda(obj: Any) -> Any:
            if isinstance(obj, torch.Tensor) and obj.device.type == "cpu":
                return obj.to("cuda")
            if isinstance(obj, dict):
                return {k: _to_cuda(v) for k, v in obj.items()}
            if isinstance(obj, list):
                return [_to_cuda(v) for v in obj]
            if isinstance(obj, tuple):
                return tuple(_to_cuda(v) for v in obj)
            return obj

        try:
            args_cuda = _to_cuda(args)
            kwargs_cuda = _to_cuda(kwargs)
        except torch.OutOfMemoryError:
            self._ensure_apply_model_headroom(required_bytes)
            args_cuda = _to_cuda(args)
            kwargs_cuda = _to_cuda(kwargs)

        return self._call_rpc("inner_model_apply_model", args_cuda, kwargs_cuda)

    def model_state_dict(self, filter_prefix: Optional[str] = None) -> Any:
        keys = self._call_rpc("model_state_dict", filter_prefix)
        return dict.fromkeys(keys, None)

    def add_patches(self, *args: Any, **kwargs: Any) -> Any:
        res = self._call_rpc("add_patches", *args, **kwargs)
        if isinstance(res, list):
            return [tuple(x) if isinstance(x, list) else x for x in res]
        return res

    def get_key_patches(self, filter_prefix: Optional[str] = None) -> Any:
        return self._call_rpc("get_key_patches", filter_prefix)

    def patch_weight_to_device(self, key, device_to=None, inplace_update=False):
        self._call_rpc("patch_weight_to_device", key, device_to, inplace_update)

    def pin_weight_to_device(self, key):
        self._call_rpc("pin_weight_to_device", key)

    def unpin_weight(self, key):
        self._call_rpc("unpin_weight", key)

    def unpin_all_weights(self):
        self._call_rpc("unpin_all_weights")

    def calculate_weight(self, patches, weight, key, intermediate_dtype=None):
        return self._call_rpc(
            "calculate_weight", patches, weight, key, intermediate_dtype
        )

    def inject_model(self) -> None:
        self._call_rpc("inject_model")

    def eject_model(self) -> None:
        self._call_rpc("eject_model")

    def use_ejected(self, skip_and_inject_on_exit_only: bool = False) -> Any:
        return AutoPatcherEjector(
            self, skip_and_inject_on_exit_only=skip_and_inject_on_exit_only
        )

    @property
    def is_injected(self) -> bool:
        return self._call_rpc("get_is_injected")

    @property
    def skip_injection(self) -> bool:
        return self._call_rpc("get_skip_injection")

    @skip_injection.setter
    def skip_injection(self, value: bool) -> None:
        self._call_rpc("set_skip_injection", value)

    def clean_hooks(self) -> None:
        self._call_rpc("clean_hooks")

    def pre_run(self) -> None:
        self._call_rpc("pre_run")

    def cleanup(self) -> None:
        try:
            self._call_rpc("cleanup")
        except Exception:
            logger.debug(
                "ModelPatcherProxy cleanup RPC failed for %s",
                self._instance_id,
                exc_info=True,
            )
        finally:
            super().cleanup()

    @property
    def model(self) -> _InnerModelProxy:
        return _InnerModelProxy(self)

    def __getattr__(self, name: str) -> Any:
        _whitelisted_attrs = {
            "hook_patches_backup",
            "hook_backup",
            "cached_hook_patches",
            "current_hooks",
            "forced_hooks",
            "is_clip",
            "patches_uuid",
            "pinned",
            "attachments",
            "additional_models",
            "injections",
            "hook_patches",
            "model_lowvram",
            "model_loaded_weight_memory",
            "backup",
            "object_patches_backup",
            "weight_wrapper_patches",
            "weight_inplace_update",
            "force_cast_weights",
        }
        if name in _whitelisted_attrs:
            return self._call_rpc("get_patcher_attr", name)
        raise AttributeError(
            f"'{type(self).__name__}' object has no attribute '{name}'"
        )

    def load_lora(
        self,
        lora_path: str,
        strength_model: float,
        clip: Optional[Any] = None,
        strength_clip: float = 1.0,
    ) -> tuple:
        clip_id = None
        if clip is not None:
            clip_id = getattr(clip, "_instance_id", getattr(clip, "_clip_id", None))
        result = self._call_rpc(
            "load_lora", lora_path, strength_model, clip_id, strength_clip
        )
        new_model = None
        if result.get("model_id"):
            new_model = ModelPatcherProxy(
                result["model_id"],
                self._registry,
                manage_lifecycle=not IS_CHILD_PROCESS,
            )
        new_clip = None
        if result.get("clip_id"):
            from comfy.isolation.clip_proxy import CLIPProxy

            new_clip = CLIPProxy(result["clip_id"])
        return (new_model, new_clip)

    @property
    def load_device(self) -> Any:
        return self._call_rpc("get_load_device")

    @property
    def offload_device(self) -> Any:
        return self._call_rpc("get_offload_device")

    @property
    def device(self) -> Any:
        return self.load_device

    def current_loaded_device(self) -> Any:
        return self._call_rpc("current_loaded_device")

    @property
    def size(self) -> int:
        return self._call_rpc("get_size")

    def model_size(self) -> Any:
        return self._call_rpc("model_size")

    def loaded_size(self) -> Any:
        return self._call_rpc("loaded_size")

    def get_ram_usage(self) -> int:
        return self._call_rpc("get_ram_usage")

    def lowvram_patch_counter(self) -> int:
        return self._call_rpc("lowvram_patch_counter")

    def memory_required(self, input_shape: Any) -> Any:
        return self._call_rpc("memory_required", input_shape)

    def is_dynamic(self) -> bool:
        return bool(self._call_rpc("is_dynamic"))

    def get_free_memory(self, device: Any) -> Any:
        return self._call_rpc("get_free_memory", device)

    def partially_unload_ram(self, ram_to_unload: int) -> Any:
        return self._call_rpc("partially_unload_ram", ram_to_unload)

    def model_dtype(self) -> Any:
        res = self._call_rpc("model_dtype")
        if isinstance(res, str) and res.startswith("torch."):
            try:
                import torch

                attr = res.split(".")[-1]
                if hasattr(torch, attr):
                    return getattr(torch, attr)
            except ImportError:
                pass
        return res

    @property
    def hook_mode(self) -> Any:
        return self._call_rpc("get_hook_mode")

    @hook_mode.setter
    def hook_mode(self, value: Any) -> None:
        self._call_rpc("set_hook_mode", value)

    def set_model_sampler_cfg_function(
        self, sampler_cfg_function: Any, disable_cfg1_optimization: bool = False
    ) -> None:
        self._call_rpc(
            "set_model_sampler_cfg_function",
            sampler_cfg_function,
            disable_cfg1_optimization,
        )

    def set_model_sampler_post_cfg_function(
        self, post_cfg_function: Any, disable_cfg1_optimization: bool = False
    ) -> None:
        self._call_rpc(
            "set_model_sampler_post_cfg_function",
            post_cfg_function,
            disable_cfg1_optimization,
        )

    def set_model_sampler_pre_cfg_function(
        self, pre_cfg_function: Any, disable_cfg1_optimization: bool = False
    ) -> None:
        self._call_rpc(
            "set_model_sampler_pre_cfg_function",
            pre_cfg_function,
            disable_cfg1_optimization,
        )

    def set_model_sampler_calc_cond_batch_function(self, fn: Any) -> None:
        self._call_rpc("set_model_sampler_calc_cond_batch_function", fn)

    def set_model_unet_function_wrapper(self, unet_wrapper_function: Any) -> None:
        self._call_rpc("set_model_unet_function_wrapper", unet_wrapper_function)

    def set_model_denoise_mask_function(self, denoise_mask_function: Any) -> None:
        self._call_rpc("set_model_denoise_mask_function", denoise_mask_function)

    def set_model_patch(self, patch: Any, name: str) -> None:
        self._call_rpc("set_model_patch", patch, name)

    def set_model_patch_replace(
        self,
        patch: Any,
        name: str,
        block_name: str,
        number: int,
        transformer_index: Optional[int] = None,
    ) -> None:
        self._call_rpc(
            "set_model_patch_replace",
            patch,
            name,
            block_name,
            number,
            transformer_index,
        )

    def set_model_attn1_patch(self, patch: Any) -> None:
        self.set_model_patch(patch, "attn1_patch")

    def set_model_attn2_patch(self, patch: Any) -> None:
        self.set_model_patch(patch, "attn2_patch")

    def set_model_attn1_replace(
        self,
        patch: Any,
        block_name: str,
        number: int,
        transformer_index: Optional[int] = None,
    ) -> None:
        self.set_model_patch_replace(
            patch, "attn1", block_name, number, transformer_index
        )

    def set_model_attn2_replace(
        self,
        patch: Any,
        block_name: str,
        number: int,
        transformer_index: Optional[int] = None,
    ) -> None:
        self.set_model_patch_replace(
            patch, "attn2", block_name, number, transformer_index
        )

    def set_model_attn1_output_patch(self, patch: Any) -> None:
        self.set_model_patch(patch, "attn1_output_patch")

    def set_model_attn2_output_patch(self, patch: Any) -> None:
        self.set_model_patch(patch, "attn2_output_patch")

    def set_model_input_block_patch(self, patch: Any) -> None:
        self.set_model_patch(patch, "input_block_patch")

    def set_model_input_block_patch_after_skip(self, patch: Any) -> None:
        self.set_model_patch(patch, "input_block_patch_after_skip")

    def set_model_output_block_patch(self, patch: Any) -> None:
        self.set_model_patch(patch, "output_block_patch")

    def set_model_emb_patch(self, patch: Any) -> None:
        self.set_model_patch(patch, "emb_patch")

    def set_model_forward_timestep_embed_patch(self, patch: Any) -> None:
        self.set_model_patch(patch, "forward_timestep_embed_patch")

    def set_model_double_block_patch(self, patch: Any) -> None:
        self.set_model_patch(patch, "double_block")

    def set_model_post_input_patch(self, patch: Any) -> None:
        self.set_model_patch(patch, "post_input")

    def set_model_rope_options(
        self,
        scale_x=1.0,
        shift_x=0.0,
        scale_y=1.0,
        shift_y=0.0,
        scale_t=1.0,
        shift_t=0.0,
        **kwargs: Any,
    ) -> None:
        options = {
            "scale_x": scale_x,
            "shift_x": shift_x,
            "scale_y": scale_y,
            "shift_y": shift_y,
            "scale_t": scale_t,
            "shift_t": shift_t,
        }
        options.update(kwargs)
        self._call_rpc("set_model_rope_options", options)

    def set_model_compute_dtype(self, dtype: Any) -> None:
        self._call_rpc("set_model_compute_dtype", dtype)

    def add_object_patch(self, name: str, obj: Any) -> None:
        self._call_rpc("add_object_patch", name, obj)

    def add_weight_wrapper(self, name: str, function: Any) -> None:
        self._call_rpc("add_weight_wrapper", name, function)

    def add_wrapper_with_key(self, wrapper_type: Any, key: str, fn: Any) -> None:
        self._call_rpc("add_wrapper_with_key", wrapper_type, key, fn)

    def add_wrapper(self, wrapper_type: str, wrapper: Callable) -> None:
        self.add_wrapper_with_key(wrapper_type, None, wrapper)

    def remove_wrappers_with_key(self, wrapper_type: str, key: str) -> None:
        self._call_rpc("remove_wrappers_with_key", wrapper_type, key)

    @property
    def wrappers(self) -> Any:
        return self._call_rpc("get_wrappers")

    def add_callback_with_key(self, call_type: str, key: str, callback: Any) -> None:
        self._call_rpc("add_callback_with_key", call_type, key, callback)

    def add_callback(self, call_type: str, callback: Any) -> None:
        self.add_callback_with_key(call_type, None, callback)

    def remove_callbacks_with_key(self, call_type: str, key: str) -> None:
        self._call_rpc("remove_callbacks_with_key", call_type, key)

    @property
    def callbacks(self) -> Any:
        return self._call_rpc("get_callbacks")

    def set_attachments(self, key: str, attachment: Any) -> None:
        self._call_rpc("set_attachments", key, attachment)

    def get_attachment(self, key: str) -> Any:
        return self._call_rpc("get_attachment", key)

    def remove_attachments(self, key: str) -> None:
        self._call_rpc("remove_attachments", key)

    def set_injections(self, key: str, injections: Any) -> None:
        self._call_rpc("set_injections", key, injections)

    def get_injections(self, key: str) -> Any:
        return self._call_rpc("get_injections", key)

    def remove_injections(self, key: str) -> None:
        self._call_rpc("remove_injections", key)

    def set_additional_models(self, key: str, models: Any) -> None:
        ids = [m._instance_id for m in models]
        self._call_rpc("set_additional_models", key, ids)

    def remove_additional_models(self, key: str) -> None:
        self._call_rpc("remove_additional_models", key)

    def get_nested_additional_models(self) -> Any:
        return self._call_rpc("get_nested_additional_models")

    def get_additional_models(self) -> List[ModelPatcherProxy]:
        ids = self._call_rpc("get_additional_models")
        return [
            ModelPatcherProxy(
                mid, self._registry, manage_lifecycle=not IS_CHILD_PROCESS
            )
            for mid in ids
        ]

    def model_patches_models(self) -> Any:
        return self._call_rpc("model_patches_models")

    @property
    def parent(self) -> Any:
        return self._call_rpc("get_parent")


class _InnerModelProxy:
    def __init__(self, parent: ModelPatcherProxy):
        self._parent = parent

    def __getattr__(self, name: str) -> Any:
        if name.startswith("_"):
            raise AttributeError(name)
        if name in (
            "model_config",
            "latent_format",
            "model_type",
            "current_weight_patches_uuid",
        ):
            return self._parent._call_rpc("get_inner_model_attr", name)
        if name == "load_device":
            return self._parent._call_rpc("get_inner_model_attr", "load_device")
        if name == "device":
            return self._parent._call_rpc("get_inner_model_attr", "device")
        if name == "current_patcher":
            return ModelPatcherProxy(
                self._parent._instance_id,
                self._parent._registry,
                manage_lifecycle=False,
            )
        if name == "model_sampling":
            return self._parent._call_rpc("get_model_object", "model_sampling")
        if name == "extra_conds_shapes":
            return lambda *a, **k: self._parent._call_rpc(
                "inner_model_extra_conds_shapes", a, k
            )
        if name == "extra_conds":
            return lambda *a, **k: self._parent._call_rpc(
                "inner_model_extra_conds", a, k
            )
        if name == "memory_required":
            return lambda *a, **k: self._parent._call_rpc(
                "inner_model_memory_required", a, k
            )
        if name == "apply_model":
            # Delegate to parent's method to get the CPU->CUDA optimization
            return self._parent.apply_model
        if name == "process_latent_in":
            return lambda *a, **k: self._parent._call_rpc("process_latent_in", a, k)
        if name == "process_latent_out":
            return lambda *a, **k: self._parent._call_rpc("process_latent_out", a, k)
        if name == "scale_latent_inpaint":
            return lambda *a, **k: self._parent._call_rpc("scale_latent_inpaint", a, k)
        if name == "diffusion_model":
            return self._parent._call_rpc("get_inner_model_attr", "diffusion_model")
        raise AttributeError(f"'{name}' not supported on isolated InnerModel")
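The tuple-key encoding that _decode_keys (proxy side, above) and _sanitize_rpc_result (registry side, below) agree on is worth seeing in isolation: tuple dict keys cannot survive the JSON-based RPC hop, so they travel as "__pyisolate_key__" plus a JSON list. A hedged, standalone round-trip sketch (top-level tuples only; nested tuples would come back as lists):

import json

PREFIX = "__pyisolate_key__"

def encode_key(k):
    return PREFIX + json.dumps(list(k)) if isinstance(k, tuple) else str(k)

def decode_key(k):
    if isinstance(k, str) and k.startswith(PREFIX):
        return tuple(json.loads(k[len(PREFIX):]))
    return k

original = {("attn1", 3): "patch"}
wire = {encode_key(k): v for k, v in original.items()}
assert {decode_key(k): v for k, v in wire.items()} == original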
875
comfy/isolation/model_patcher_proxy_registry.py
Normal file
@@ -0,0 +1,875 @@
# pylint: disable=import-outside-toplevel,logging-fstring-interpolation,protected-access,unused-import
# RPC server for ModelPatcher isolation (child process)
from __future__ import annotations

import gc
import logging
from typing import Any, Optional, List

try:
    from comfy.model_patcher import AutoPatcherEjector
except ImportError:

    class AutoPatcherEjector:
        def __init__(self, model, skip_and_inject_on_exit_only=False):
            self.model = model
            self.skip_and_inject_on_exit_only = skip_and_inject_on_exit_only
            self.prev_skip_injection = False
            self.was_injected = False

        def __enter__(self):
            self.was_injected = False
            self.prev_skip_injection = self.model.skip_injection
            if self.skip_and_inject_on_exit_only:
                self.model.skip_injection = True
            if self.model.is_injected:
                self.model.eject_model()
                self.was_injected = True

        def __exit__(self, *args):
            if self.skip_and_inject_on_exit_only:
                self.model.skip_injection = self.prev_skip_injection
                self.model.inject_model()
            if self.was_injected and not self.model.skip_injection:
                self.model.inject_model()
            self.model.skip_injection = self.prev_skip_injection


from comfy.isolation.proxies.base import (
    BaseRegistry,
    detach_if_grad,
)

logger = logging.getLogger(__name__)


class ModelPatcherRegistry(BaseRegistry[Any]):
    _type_prefix = "model"

    def __init__(self) -> None:
        super().__init__()
        self._pending_cleanup_ids: set[str] = set()

    async def clone(self, instance_id: str) -> str:
        instance = self._get_instance(instance_id)
        new_model = instance.clone()
        return self.register(new_model)

    async def is_clone(self, instance_id: str, other: Any) -> bool:
        instance = self._get_instance(instance_id)
        if hasattr(other, "model"):
            return instance.is_clone(other)
        return False

    async def get_model_object(self, instance_id: str, name: str) -> Any:
        instance = self._get_instance(instance_id)
        if name == "model":
            return f"<ModelObject: {type(instance.model).__name__}>"
        result = instance.get_model_object(name)
        if name == "model_sampling":
            from comfy.isolation.model_sampling_proxy import (
                ModelSamplingRegistry,
                ModelSamplingProxy,
            )

            registry = ModelSamplingRegistry()
            sampling_id = registry.register(result)
            return ModelSamplingProxy(sampling_id, registry)
        return detach_if_grad(result)

    async def get_model_options(self, instance_id: str) -> dict:
        instance = self._get_instance(instance_id)
        import copy

        opts = copy.deepcopy(instance.model_options)
        return self._sanitize_rpc_result(opts)

    async def set_model_options(self, instance_id: str, options: dict) -> None:
        self._get_instance(instance_id).model_options = options

    async def get_patcher_attr(self, instance_id: str, name: str) -> Any:
        return self._sanitize_rpc_result(
            getattr(self._get_instance(instance_id), name, None)
        )

    async def model_state_dict(self, instance_id: str, filter_prefix=None) -> Any:
        instance = self._get_instance(instance_id)
        sd_keys = instance.model.state_dict().keys()
        return dict.fromkeys(sd_keys, None)

    def _sanitize_rpc_result(self, obj, seen=None):
        if seen is None:
            seen = set()
        if obj is None:
            return None
        if isinstance(obj, (bool, int, float, str)):
            if isinstance(obj, str) and len(obj) > 500000:
                return f"<Truncated String len={len(obj)}>"
            return obj
        obj_id = id(obj)
        if obj_id in seen:
            return None
        seen.add(obj_id)
        if isinstance(obj, (list, tuple)):
            return [self._sanitize_rpc_result(x, seen) for x in obj]
        if isinstance(obj, set):
            return [self._sanitize_rpc_result(x, seen) for x in obj]
        if isinstance(obj, dict):
            new_dict = {}
            for k, v in obj.items():
                if isinstance(k, tuple):
                    import json

                    try:
                        key_str = "__pyisolate_key__" + json.dumps(list(k))
                        new_dict[key_str] = self._sanitize_rpc_result(v, seen)
                    except Exception:
                        new_dict[str(k)] = self._sanitize_rpc_result(v, seen)
                else:
                    new_dict[str(k)] = self._sanitize_rpc_result(v, seen)
            return new_dict
        if (
            hasattr(obj, "__dict__")
            and not hasattr(obj, "__get__")
            and not hasattr(obj, "__call__")
        ):
            return self._sanitize_rpc_result(obj.__dict__, seen)
        if hasattr(obj, "items") and hasattr(obj, "get"):
            return {str(k): self._sanitize_rpc_result(v, seen) for k, v in obj.items()}
        return None

    async def get_load_device(self, instance_id: str) -> Any:
        return self._get_instance(instance_id).load_device

    async def get_offload_device(self, instance_id: str) -> Any:
        return self._get_instance(instance_id).offload_device

    async def current_loaded_device(self, instance_id: str) -> Any:
        return self._get_instance(instance_id).current_loaded_device()

    async def get_size(self, instance_id: str) -> int:
        return self._get_instance(instance_id).size

    async def model_size(self, instance_id: str) -> Any:
        return self._get_instance(instance_id).model_size()

    async def loaded_size(self, instance_id: str) -> Any:
        return self._get_instance(instance_id).loaded_size()

    async def get_ram_usage(self, instance_id: str) -> int:
        return self._get_instance(instance_id).get_ram_usage()

    async def lowvram_patch_counter(self, instance_id: str) -> int:
        return self._get_instance(instance_id).lowvram_patch_counter()

    async def memory_required(self, instance_id: str, input_shape: Any) -> Any:
        return self._get_instance(instance_id).memory_required(input_shape)

    async def is_dynamic(self, instance_id: str) -> bool:
        instance = self._get_instance(instance_id)
        if hasattr(instance, "is_dynamic"):
            return bool(instance.is_dynamic())
        return False

    async def get_free_memory(self, instance_id: str, device: Any) -> Any:
        instance = self._get_instance(instance_id)
        if hasattr(instance, "get_free_memory"):
            return instance.get_free_memory(device)
        import comfy.model_management

        return comfy.model_management.get_free_memory(device)

    async def partially_unload_ram(self, instance_id: str, ram_to_unload: int) -> Any:
        instance = self._get_instance(instance_id)
        if hasattr(instance, "partially_unload_ram"):
            return instance.partially_unload_ram(ram_to_unload)
        return None

    async def model_dtype(self, instance_id: str) -> Any:
        return self._get_instance(instance_id).model_dtype()

    async def model_patches_to(self, instance_id: str, device: Any) -> Any:
        return self._get_instance(instance_id).model_patches_to(device)

    async def partially_load(
        self,
        instance_id: str,
        device: Any,
        extra_memory: Any,
        force_patch_weights: bool = False,
    ) -> Any:
        return self._get_instance(instance_id).partially_load(
            device, extra_memory, force_patch_weights=force_patch_weights
        )

    async def partially_unload(
        self,
        instance_id: str,
        device_to: Any,
        memory_to_free: int = 0,
        force_patch_weights: bool = False,
    ) -> int:
        return self._get_instance(instance_id).partially_unload(
            device_to, memory_to_free, force_patch_weights
        )

    async def load(
        self,
        instance_id: str,
        device_to: Any = None,
        lowvram_model_memory: int = 0,
        force_patch_weights: bool = False,
        full_load: bool = False,
    ) -> None:
        self._get_instance(instance_id).load(
            device_to, lowvram_model_memory, force_patch_weights, full_load
        )

    async def patch_model(
        self,
        instance_id: str,
        device_to: Any = None,
        lowvram_model_memory: int = 0,
        load_weights: bool = True,
        force_patch_weights: bool = False,
    ) -> None:
        try:
            self._get_instance(instance_id).patch_model(
                device_to, lowvram_model_memory, load_weights, force_patch_weights
            )
        except AttributeError as e:
            logger.error(
                f"Isolation Error: Failed to patch model attribute: {e}. Skipping."
            )
            return

    async def unpatch_model(
        self, instance_id: str, device_to: Any = None, unpatch_weights: bool = True
    ) -> None:
        self._get_instance(instance_id).unpatch_model(device_to, unpatch_weights)

    async def detach(self, instance_id: str, unpatch_all: bool = True) -> None:
        self._get_instance(instance_id).detach(unpatch_all)

    async def prepare_state(self, instance_id: str, timestep: Any) -> Any:
        instance = self._get_instance(instance_id)
        cp = getattr(instance.model, "current_patcher", instance)
        if cp is None:
            cp = instance
        return cp.prepare_state(timestep)

    async def pre_run(self, instance_id: str) -> None:
        self._get_instance(instance_id).pre_run()

    async def cleanup(self, instance_id: str) -> None:
        try:
            instance = self._get_instance(instance_id)
        except Exception:
            logger.debug(
                "ModelPatcher cleanup requested for missing instance %s",
                instance_id,
                exc_info=True,
            )
            return

        try:
            instance.cleanup()
        finally:
            with self._lock:
                self._pending_cleanup_ids.add(instance_id)
            gc.collect()

    def sweep_pending_cleanup(self) -> int:
        removed = 0
        with self._lock:
            pending_ids = list(self._pending_cleanup_ids)
            self._pending_cleanup_ids.clear()
            for instance_id in pending_ids:
                instance = self._registry.pop(instance_id, None)
                if instance is None:
                    continue
                self._id_map.pop(id(instance), None)
                removed += 1

        gc.collect()
        return removed

    def purge_all(self) -> int:
        with self._lock:
            removed = len(self._registry)
            self._registry.clear()
            self._id_map.clear()
            self._pending_cleanup_ids.clear()
        gc.collect()
        return removed

    async def apply_hooks(self, instance_id: str, hooks: Any) -> Any:
        instance = self._get_instance(instance_id)
        cp = getattr(instance.model, "current_patcher", instance)
        if cp is None:
            cp = instance
        return cp.apply_hooks(hooks=hooks)

    async def clean_hooks(self, instance_id: str) -> None:
        self._get_instance(instance_id).clean_hooks()

    async def restore_hook_patches(self, instance_id: str) -> None:
        self._get_instance(instance_id).restore_hook_patches()

    async def unpatch_hooks(
        self, instance_id: str, whitelist_keys_set: Optional[set] = None
    ) -> None:
        self._get_instance(instance_id).unpatch_hooks(whitelist_keys_set)

    async def register_all_hook_patches(
        self,
        instance_id: str,
        hooks: Any,
        target_dict: Any,
        model_options: Any,
        registered: Any,
    ) -> None:
        from types import SimpleNamespace
        import comfy.hooks

        instance = self._get_instance(instance_id)
        if isinstance(hooks, SimpleNamespace) or hasattr(hooks, "__dict__"):
            hook_data = hooks.__dict__ if hasattr(hooks, "__dict__") else hooks
            new_hooks = comfy.hooks.HookGroup()
            # hook_data may be a plain dict (from __dict__) or an object; check both.
            if (isinstance(hook_data, dict) and "hooks" in hook_data) or hasattr(
                hook_data, "hooks"
            ):
                new_hooks.hooks = (
                    hook_data["hooks"]
                    if isinstance(hook_data, dict)
                    else hook_data.hooks
                )
            hooks = new_hooks
        instance.register_all_hook_patches(
            hooks, target_dict, model_options, registered
        )

    async def get_hook_mode(self, instance_id: str) -> Any:
        return getattr(self._get_instance(instance_id), "hook_mode", None)

    async def set_hook_mode(self, instance_id: str, value: Any) -> None:
        setattr(self._get_instance(instance_id), "hook_mode", value)

    async def inject_model(self, instance_id: str) -> None:
        instance = self._get_instance(instance_id)
        try:
            instance.inject_model()
        except AttributeError as e:
            if "inject" in str(e):
                logger.error(
                    "Isolation Error: Injector object lost method code during serialization. Cannot inject. Skipping."
                )
                return
            raise e

    async def eject_model(self, instance_id: str) -> None:
        self._get_instance(instance_id).eject_model()

    async def get_is_injected(self, instance_id: str) -> bool:
        return self._get_instance(instance_id).is_injected

    async def set_skip_injection(self, instance_id: str, value: bool) -> None:
        self._get_instance(instance_id).skip_injection = value

    async def get_skip_injection(self, instance_id: str) -> bool:
        return self._get_instance(instance_id).skip_injection

    async def set_model_sampler_cfg_function(
        self,
        instance_id: str,
        sampler_cfg_function: Any,
        disable_cfg1_optimization: bool = False,
    ) -> None:
        if not callable(sampler_cfg_function):
            logger.error(
                f"set_model_sampler_cfg_function: Expected callable, got {type(sampler_cfg_function)}. Skipping."
            )
            return
        self._get_instance(instance_id).set_model_sampler_cfg_function(
            sampler_cfg_function, disable_cfg1_optimization
        )

    async def set_model_sampler_post_cfg_function(
        self,
        instance_id: str,
        post_cfg_function: Any,
        disable_cfg1_optimization: bool = False,
    ) -> None:
        self._get_instance(instance_id).set_model_sampler_post_cfg_function(
            post_cfg_function, disable_cfg1_optimization
        )

    async def set_model_sampler_pre_cfg_function(
        self,
        instance_id: str,
        pre_cfg_function: Any,
        disable_cfg1_optimization: bool = False,
    ) -> None:
        self._get_instance(instance_id).set_model_sampler_pre_cfg_function(
            pre_cfg_function, disable_cfg1_optimization
        )

    async def set_model_sampler_calc_cond_batch_function(
        self, instance_id: str, fn: Any
    ) -> None:
        self._get_instance(instance_id).set_model_sampler_calc_cond_batch_function(fn)

    async def set_model_unet_function_wrapper(
        self, instance_id: str, unet_wrapper_function: Any
    ) -> None:
        self._get_instance(instance_id).set_model_unet_function_wrapper(
            unet_wrapper_function
        )

    async def set_model_denoise_mask_function(
        self, instance_id: str, denoise_mask_function: Any
    ) -> None:
        self._get_instance(instance_id).set_model_denoise_mask_function(
            denoise_mask_function
        )

    async def set_model_patch(self, instance_id: str, patch: Any, name: str) -> None:
        self._get_instance(instance_id).set_model_patch(patch, name)

    async def set_model_patch_replace(
        self,
        instance_id: str,
        patch: Any,
        name: str,
        block_name: str,
        number: int,
        transformer_index: Optional[int] = None,
    ) -> None:
        self._get_instance(instance_id).set_model_patch_replace(
            patch, name, block_name, number, transformer_index
        )

    async def set_model_input_block_patch(self, instance_id: str, patch: Any) -> None:
        self._get_instance(instance_id).set_model_input_block_patch(patch)

    async def set_model_input_block_patch_after_skip(
        self, instance_id: str, patch: Any
    ) -> None:
        self._get_instance(instance_id).set_model_input_block_patch_after_skip(patch)

    async def set_model_output_block_patch(self, instance_id: str, patch: Any) -> None:
        self._get_instance(instance_id).set_model_output_block_patch(patch)

    async def set_model_emb_patch(self, instance_id: str, patch: Any) -> None:
        self._get_instance(instance_id).set_model_emb_patch(patch)

    async def set_model_forward_timestep_embed_patch(
        self, instance_id: str, patch: Any
    ) -> None:
        self._get_instance(instance_id).set_model_forward_timestep_embed_patch(patch)

    async def set_model_double_block_patch(self, instance_id: str, patch: Any) -> None:
        self._get_instance(instance_id).set_model_double_block_patch(patch)

    async def set_model_post_input_patch(self, instance_id: str, patch: Any) -> None:
        self._get_instance(instance_id).set_model_post_input_patch(patch)

    async def set_model_rope_options(self, instance_id: str, options: dict) -> None:
        self._get_instance(instance_id).set_model_rope_options(**options)

    async def set_model_compute_dtype(self, instance_id: str, dtype: Any) -> None:
        self._get_instance(instance_id).set_model_compute_dtype(dtype)

    async def clone_has_same_weights_by_id(
        self, instance_id: str, other_id: str
    ) -> bool:
        instance = self._get_instance(instance_id)
        other = self._get_instance(other_id)
        if not other:
            return False
        return instance.clone_has_same_weights(other)

    async def load_list_internal(self, instance_id: str, *args, **kwargs) -> Any:
        return self._get_instance(instance_id)._load_list(*args, **kwargs)

    async def is_clone_by_id(self, instance_id: str, other_id: str) -> bool:
        instance = self._get_instance(instance_id)
        other = self._get_instance(other_id)
        if hasattr(instance, "is_clone"):
            return instance.is_clone(other)
        return False

    async def add_object_patch(self, instance_id: str, name: str, obj: Any) -> None:
        self._get_instance(instance_id).add_object_patch(name, obj)

    async def add_weight_wrapper(
        self, instance_id: str, name: str, function: Any
    ) -> None:
        self._get_instance(instance_id).add_weight_wrapper(name, function)

    async def add_wrapper_with_key(
        self, instance_id: str, wrapper_type: Any, key: str, fn: Any
    ) -> None:
        self._get_instance(instance_id).add_wrapper_with_key(wrapper_type, key, fn)

    async def remove_wrappers_with_key(
        self, instance_id: str, wrapper_type: str, key: str
    ) -> None:
        self._get_instance(instance_id).remove_wrappers_with_key(wrapper_type, key)

    async def get_wrappers(
        self, instance_id: str, wrapper_type: Optional[str] = None, key: Optional[str] = None
    ) -> Any:
        if wrapper_type is None and key is None:
            return self._sanitize_rpc_result(
                getattr(self._get_instance(instance_id), "wrappers", {})
            )
        return self._sanitize_rpc_result(
            self._get_instance(instance_id).get_wrappers(wrapper_type, key)
        )

    async def get_all_wrappers(
        self, instance_id: str, wrapper_type: Optional[str] = None
    ) -> Any:
        return self._sanitize_rpc_result(
            getattr(self._get_instance(instance_id), "get_all_wrappers", lambda x: [])(
                wrapper_type
            )
        )

    async def add_callback_with_key(
        self, instance_id: str, call_type: str, key: str, callback: Any
    ) -> None:
        self._get_instance(instance_id).add_callback_with_key(call_type, key, callback)

    async def remove_callbacks_with_key(
        self, instance_id: str, call_type: str, key: str
    ) -> None:
        self._get_instance(instance_id).remove_callbacks_with_key(call_type, key)

    async def get_callbacks(
        self, instance_id: str, call_type: Optional[str] = None, key: Optional[str] = None
    ) -> Any:
        if call_type is None and key is None:
            return self._sanitize_rpc_result(
                getattr(self._get_instance(instance_id), "callbacks", {})
            )
        return self._sanitize_rpc_result(
            self._get_instance(instance_id).get_callbacks(call_type, key)
        )

    async def get_all_callbacks(
        self, instance_id: str, call_type: Optional[str] = None
    ) -> Any:
        return self._sanitize_rpc_result(
            getattr(self._get_instance(instance_id), "get_all_callbacks", lambda x: [])(
                call_type
            )
        )

    async def set_attachments(
        self, instance_id: str, key: str, attachment: Any
    ) -> None:
        self._get_instance(instance_id).set_attachments(key, attachment)

    async def get_attachment(self, instance_id: str, key: str) -> Any:
        return self._sanitize_rpc_result(
            self._get_instance(instance_id).get_attachment(key)
        )

    async def remove_attachments(self, instance_id: str, key: str) -> None:
        self._get_instance(instance_id).remove_attachments(key)

    async def set_injections(self, instance_id: str, key: str, injections: Any) -> None:
|
||||
self._get_instance(instance_id).set_injections(key, injections)
|
||||
|
||||
async def get_injections(self, instance_id: str, key: str) -> Any:
|
||||
return self._sanitize_rpc_result(
|
||||
self._get_instance(instance_id).get_injections(key)
|
||||
)
|
||||
|
||||
async def remove_injections(self, instance_id: str, key: str) -> None:
|
||||
self._get_instance(instance_id).remove_injections(key)
|
||||
|
||||
async def set_additional_models(
|
||||
self, instance_id: str, key: str, models: Any
|
||||
) -> None:
|
||||
self._get_instance(instance_id).set_additional_models(key, models)
|
||||
|
||||
async def remove_additional_models(self, instance_id: str, key: str) -> None:
|
||||
self._get_instance(instance_id).remove_additional_models(key)
|
||||
|
||||
async def get_nested_additional_models(self, instance_id: str) -> Any:
|
||||
return self._sanitize_rpc_result(
|
||||
self._get_instance(instance_id).get_nested_additional_models()
|
||||
)
|
||||
|
||||
async def get_additional_models(self, instance_id: str) -> List[str]:
|
||||
models = self._get_instance(instance_id).get_additional_models()
|
||||
return [self.register(m) for m in models]
|
||||
|
||||
async def get_additional_models_with_key(self, instance_id: str, key: str) -> Any:
|
||||
return self._sanitize_rpc_result(
|
||||
self._get_instance(instance_id).get_additional_models_with_key(key)
|
||||
)
|
||||
|
||||
async def model_patches_models(self, instance_id: str) -> Any:
|
||||
return self._sanitize_rpc_result(
|
||||
self._get_instance(instance_id).model_patches_models()
|
||||
)
|
||||
|
||||
async def get_patches(self, instance_id: str) -> Any:
|
||||
return self._sanitize_rpc_result(self._get_instance(instance_id).patches.copy())
|
||||
|
||||
async def get_object_patches(self, instance_id: str) -> Any:
|
||||
return self._sanitize_rpc_result(
|
||||
self._get_instance(instance_id).object_patches.copy()
|
||||
)
|
||||
|
||||
async def add_patches(
|
||||
self,
|
||||
instance_id: str,
|
||||
patches: Any,
|
||||
strength_patch: float = 1.0,
|
||||
strength_model: float = 1.0,
|
||||
) -> Any:
|
||||
return self._get_instance(instance_id).add_patches(
|
||||
patches, strength_patch, strength_model
|
||||
)
|
||||
|
||||
async def get_key_patches(
|
||||
self, instance_id: str, filter_prefix: Optional[str] = None
|
||||
) -> Any:
|
||||
res = self._get_instance(instance_id).get_key_patches()
|
||||
if filter_prefix:
|
||||
res = {k: v for k, v in res.items() if k.startswith(filter_prefix)}
|
||||
safe_res = {}
|
||||
for k, v in res.items():
|
||||
safe_res[k] = [
|
||||
f"<Tensor shape={t.shape} dtype={t.dtype}>"
|
||||
if hasattr(t, "shape")
|
||||
else str(t)
|
||||
for t in v
|
||||
]
|
||||
return safe_res
|
||||
|
||||
async def add_hook_patches(
|
||||
self,
|
||||
instance_id: str,
|
||||
hook: Any,
|
||||
patches: Any,
|
||||
strength_patch: float = 1.0,
|
||||
strength_model: float = 1.0,
|
||||
) -> None:
|
||||
if hasattr(hook, "hook_ref") and isinstance(hook.hook_ref, dict):
|
||||
try:
|
||||
hook.hook_ref = tuple(sorted(hook.hook_ref.items()))
|
||||
except Exception:
|
||||
hook.hook_ref = None
|
||||
self._get_instance(instance_id).add_hook_patches(
|
||||
hook, patches, strength_patch, strength_model
|
||||
)
|
||||
|
||||
async def get_combined_hook_patches(self, instance_id: str, hooks: Any) -> Any:
|
||||
if hooks is not None and hasattr(hooks, "hooks"):
|
||||
for hook in getattr(hooks, "hooks", []):
|
||||
hook_ref = getattr(hook, "hook_ref", None)
|
||||
if isinstance(hook_ref, dict):
|
||||
try:
|
||||
hook.hook_ref = tuple(sorted(hook_ref.items()))
|
||||
except Exception:
|
||||
hook.hook_ref = None
|
||||
res = self._get_instance(instance_id).get_combined_hook_patches(hooks)
|
||||
return self._sanitize_rpc_result(res)
|
||||
|
||||
async def clear_cached_hook_weights(self, instance_id: str) -> None:
|
||||
self._get_instance(instance_id).clear_cached_hook_weights()
|
||||
|
||||
async def prepare_hook_patches_current_keyframe(
|
||||
self, instance_id: str, t: Any, hook_group: Any, model_options: Any
|
||||
) -> None:
|
||||
self._get_instance(instance_id).prepare_hook_patches_current_keyframe(
|
||||
t, hook_group, model_options
|
||||
)
|
||||
|
||||
async def get_parent(self, instance_id: str) -> Any:
|
||||
return getattr(self._get_instance(instance_id), "parent", None)
|
||||
|
||||
async def patch_weight_to_device(
|
||||
self,
|
||||
instance_id: str,
|
||||
key: str,
|
||||
device_to: Any = None,
|
||||
inplace_update: bool = False,
|
||||
) -> None:
|
||||
self._get_instance(instance_id).patch_weight_to_device(
|
||||
key, device_to, inplace_update
|
||||
)
|
||||
|
||||
async def pin_weight_to_device(self, instance_id: str, key: str) -> None:
|
||||
instance = self._get_instance(instance_id)
|
||||
if hasattr(instance, "pinned") and isinstance(instance.pinned, list):
|
||||
instance.pinned = set(instance.pinned)
|
||||
instance.pin_weight_to_device(key)
|
||||
|
||||
async def unpin_weight(self, instance_id: str, key: str) -> None:
|
||||
instance = self._get_instance(instance_id)
|
||||
if hasattr(instance, "pinned") and isinstance(instance.pinned, list):
|
||||
instance.pinned = set(instance.pinned)
|
||||
instance.unpin_weight(key)
|
||||
|
||||
async def unpin_all_weights(self, instance_id: str) -> None:
|
||||
instance = self._get_instance(instance_id)
|
||||
if hasattr(instance, "pinned") and isinstance(instance.pinned, list):
|
||||
instance.pinned = set(instance.pinned)
|
||||
instance.unpin_all_weights()
|
||||
|
||||
async def calculate_weight(
|
||||
self,
|
||||
instance_id: str,
|
||||
patches: Any,
|
||||
weight: Any,
|
||||
key: str,
|
||||
intermediate_dtype: Any = float,
|
||||
) -> Any:
|
||||
return detach_if_grad(
|
||||
self._get_instance(instance_id).calculate_weight(
|
||||
patches, weight, key, intermediate_dtype
|
||||
)
|
||||
)
|
||||
|
||||
async def get_inner_model_attr(self, instance_id: str, name: str) -> Any:
|
||||
try:
|
||||
return self._sanitize_rpc_result(
|
||||
getattr(self._get_instance(instance_id).model, name)
|
||||
)
|
||||
except AttributeError:
|
||||
return None
|
||||
|
||||
async def inner_model_memory_required(
|
||||
self, instance_id: str, args: tuple, kwargs: dict
|
||||
) -> Any:
|
||||
return self._get_instance(instance_id).model.memory_required(*args, **kwargs)
|
||||
|
||||
async def inner_model_extra_conds_shapes(
|
||||
self, instance_id: str, args: tuple, kwargs: dict
|
||||
) -> Any:
|
||||
return self._get_instance(instance_id).model.extra_conds_shapes(*args, **kwargs)
|
||||
|
||||
async def inner_model_extra_conds(
|
||||
self, instance_id: str, args: tuple, kwargs: dict
|
||||
) -> Any:
|
||||
return self._get_instance(instance_id).model.extra_conds(*args, **kwargs)
|
||||
|
||||
async def inner_model_state_dict(
|
||||
self, instance_id: str, args: tuple, kwargs: dict
|
||||
) -> Any:
|
||||
sd = self._get_instance(instance_id).model.state_dict(*args, **kwargs)
|
||||
return {
|
||||
k: {"numel": v.numel(), "element_size": v.element_size()}
|
||||
for k, v in sd.items()
|
||||
}
|
||||
|
||||
async def inner_model_apply_model(
|
||||
self, instance_id: str, args: tuple, kwargs: dict
|
||||
) -> Any:
|
||||
instance = self._get_instance(instance_id)
|
||||
target = getattr(instance, "load_device", None)
|
||||
if target is None and args and hasattr(args[0], "device"):
|
||||
target = args[0].device
|
||||
elif target is None:
|
||||
for v in kwargs.values():
|
||||
if hasattr(v, "device"):
|
||||
target = v.device
|
||||
break
|
||||
|
||||
def _move(obj):
|
||||
if target is None:
|
||||
return obj
|
||||
if isinstance(obj, (tuple, list)):
|
||||
return type(obj)(_move(o) for o in obj)
|
||||
if hasattr(obj, "to"):
|
||||
return obj.to(target)
|
||||
return obj
|
||||
|
||||
moved_args = tuple(_move(a) for a in args)
|
||||
moved_kwargs = {k: _move(v) for k, v in kwargs.items()}
|
||||
result = instance.model.apply_model(*moved_args, **moved_kwargs)
|
||||
return detach_if_grad(_move(result))
|
||||
|
||||
async def process_latent_in(
|
||||
self, instance_id: str, args: tuple, kwargs: dict
|
||||
) -> Any:
|
||||
return detach_if_grad(
|
||||
self._get_instance(instance_id).model.process_latent_in(*args, **kwargs)
|
||||
)
|
||||
|
||||
async def process_latent_out(
|
||||
self, instance_id: str, args: tuple, kwargs: dict
|
||||
) -> Any:
|
||||
instance = self._get_instance(instance_id)
|
||||
result = instance.model.process_latent_out(*args, **kwargs)
|
||||
try:
|
||||
target = None
|
||||
if args and hasattr(args[0], "device"):
|
||||
target = args[0].device
|
||||
elif kwargs:
|
||||
for v in kwargs.values():
|
||||
if hasattr(v, "device"):
|
||||
target = v.device
|
||||
break
|
||||
if target is not None and hasattr(result, "to"):
|
||||
return detach_if_grad(result.to(target))
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"process_latent_out: failed to move result to target device",
|
||||
exc_info=True,
|
||||
)
|
||||
return detach_if_grad(result)
|
||||
|
||||
async def scale_latent_inpaint(
|
||||
self, instance_id: str, args: tuple, kwargs: dict
|
||||
) -> Any:
|
||||
instance = self._get_instance(instance_id)
|
||||
result = instance.model.scale_latent_inpaint(*args, **kwargs)
|
||||
try:
|
||||
target = None
|
||||
if args and hasattr(args[0], "device"):
|
||||
target = args[0].device
|
||||
elif kwargs:
|
||||
for v in kwargs.values():
|
||||
if hasattr(v, "device"):
|
||||
target = v.device
|
||||
break
|
||||
if target is not None and hasattr(result, "to"):
|
||||
return detach_if_grad(result.to(target))
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"scale_latent_inpaint: failed to move result to target device",
|
||||
exc_info=True,
|
||||
)
|
||||
return detach_if_grad(result)
|
||||
|
||||
async def load_lora(
|
||||
self,
|
||||
instance_id: str,
|
||||
lora_path: str,
|
||||
strength_model: float,
|
||||
clip_id: Optional[str] = None,
|
||||
strength_clip: float = 1.0,
|
||||
) -> dict:
|
||||
import comfy.utils
|
||||
import comfy.sd
|
||||
import folder_paths
|
||||
from comfy.isolation.clip_proxy import CLIPRegistry
|
||||
|
||||
model = self._get_instance(instance_id)
|
||||
clip = None
|
||||
if clip_id:
|
||||
clip = CLIPRegistry()._get_instance(clip_id)
|
||||
lora_full_path = folder_paths.get_full_path("loras", lora_path)
|
||||
if lora_full_path is None:
|
||||
raise ValueError(f"LoRA file not found: {lora_path}")
|
||||
lora = comfy.utils.load_torch_file(lora_full_path)
|
||||
new_model, new_clip = comfy.sd.load_lora_for_models(
|
||||
model, clip, lora, strength_model, strength_clip
|
||||
)
|
||||
new_model_id = self.register(new_model) if new_model else None
|
||||
new_clip_id = (
|
||||
CLIPRegistry().register(new_clip) if (new_clip and clip_id) else None
|
||||
)
|
||||
return {"model_id": new_model_id, "clip_id": new_clip_id}
|
||||
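A minimal host-side sketch of driving this registry, assuming it is reachable in-process; `patcher` stands in for a real ModelPatcher and the LoRA filename is hypothetical and must exist under the configured "loras" folder:

import asyncio

from comfy.isolation.model_patcher_proxy_registry import ModelPatcherRegistry

async def apply_lora(patcher, lora_name: str) -> str:
    registry = ModelPatcherRegistry()
    model_id = registry.register(patcher)
    # load_lora returns ids, so the patched model never crosses the RPC boundary
    result = await registry.load_lora(model_id, lora_name, strength_model=1.0)
    return result["model_id"]

# asyncio.run(apply_lora(my_patcher, "example.safetensors"))  # hypothetical inputs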
154  comfy/isolation/model_patcher_proxy_utils.py  Normal file
@@ -0,0 +1,154 @@
# pylint: disable=import-outside-toplevel,logging-fstring-interpolation,protected-access
# Isolation utilities and serializers for ModelPatcherProxy
from __future__ import annotations

import logging
import os
from typing import Any

logger = logging.getLogger(__name__)


def maybe_wrap_model_for_isolation(model_patcher: Any) -> Any:
    from comfy.isolation.model_patcher_proxy_registry import ModelPatcherRegistry
    from comfy.isolation.model_patcher_proxy import ModelPatcherProxy

    isolation_active = os.environ.get("PYISOLATE_ISOLATION_ACTIVE") == "1"
    is_child = os.environ.get("PYISOLATE_CHILD") == "1"

    if not isolation_active:
        return model_patcher
    if is_child:
        return model_patcher
    if isinstance(model_patcher, ModelPatcherProxy):
        return model_patcher

    registry = ModelPatcherRegistry()
    model_id = registry.register(model_patcher)
    logger.debug(f"Isolated ModelPatcher: {model_id}")
    return ModelPatcherProxy(model_id, registry, manage_lifecycle=True)


def register_hooks_serializers(registry=None):
    from pyisolate._internal.serialization_registry import SerializerRegistry
    import comfy.hooks

    if registry is None:
        registry = SerializerRegistry.get_instance()

    def serialize_enum(obj):
        return {"__enum__": f"{type(obj).__name__}.{obj.name}"}

    def deserialize_enum(data):
        cls_name, val_name = data["__enum__"].split(".")
        cls = getattr(comfy.hooks, cls_name)
        return cls[val_name]

    registry.register("EnumHookType", serialize_enum, deserialize_enum)
    registry.register("EnumHookScope", serialize_enum, deserialize_enum)
    registry.register("EnumHookMode", serialize_enum, deserialize_enum)
    registry.register("EnumWeightTarget", serialize_enum, deserialize_enum)

    def serialize_hook_group(obj):
        return {"__type__": "HookGroup", "hooks": obj.hooks}

    def deserialize_hook_group(data):
        hg = comfy.hooks.HookGroup()
        for h in data["hooks"]:
            hg.add(h)
        return hg

    registry.register("HookGroup", serialize_hook_group, deserialize_hook_group)

    def serialize_dict_state(obj):
        d = obj.__dict__.copy()
        d["__type__"] = type(obj).__name__
        if "custom_should_register" in d:
            del d["custom_should_register"]
        return d

    def deserialize_dict_state_generic(cls):
        def _deserialize(data):
            h = cls()
            h.__dict__.update(data)
            return h

        return _deserialize

    def deserialize_hook_keyframe(data):
        h = comfy.hooks.HookKeyframe(strength=data.get("strength", 1.0))
        h.__dict__.update(data)
        return h

    registry.register("HookKeyframe", serialize_dict_state, deserialize_hook_keyframe)

    def deserialize_hook_keyframe_group(data):
        h = comfy.hooks.HookKeyframeGroup()
        h.__dict__.update(data)
        return h

    registry.register(
        "HookKeyframeGroup", serialize_dict_state, deserialize_hook_keyframe_group
    )

    def deserialize_hook(data):
        h = comfy.hooks.Hook()
        h.__dict__.update(data)
        return h

    registry.register("Hook", serialize_dict_state, deserialize_hook)

    def deserialize_weight_hook(data):
        h = comfy.hooks.WeightHook()
        h.__dict__.update(data)
        return h

    registry.register("WeightHook", serialize_dict_state, deserialize_weight_hook)

    def serialize_set(obj):
        return {"__set__": list(obj)}

    def deserialize_set(data):
        return set(data["__set__"])

    registry.register("set", serialize_set, deserialize_set)

    try:
        from comfy.weight_adapter.lora import LoRAAdapter

        def serialize_lora(obj):
            return {"weights": {}, "loaded_keys": list(obj.loaded_keys)}

        def deserialize_lora(data):
            return LoRAAdapter(set(data["loaded_keys"]), data["weights"])

        registry.register("LoRAAdapter", serialize_lora, deserialize_lora)
    except Exception:
        pass

    try:
        from comfy.hooks import _HookRef
        import uuid

        def serialize_hook_ref(obj):
            return {
                "__hook_ref__": True,
                "id": getattr(obj, "_pyisolate_id", str(uuid.uuid4())),
            }

        def deserialize_hook_ref(data):
            h = _HookRef()
            h._pyisolate_id = data.get("id", str(uuid.uuid4()))
            return h

        registry.register("_HookRef", serialize_hook_ref, deserialize_hook_ref)
    except ImportError:
        pass
    except Exception as e:
        logger.warning(f"Failed to register _HookRef: {e}")


try:
    register_hooks_serializers()
except Exception as e:
    logger.error(f"Failed to initialize hook serializers: {e}")
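An illustrative round-trip through the enum serializer pair defined above; the `MaxSpeed` member of comfy.hooks.EnumHookMode is an assumption here, hedged accordingly:

import comfy.hooks

def serialize_enum(obj):
    return {"__enum__": f"{type(obj).__name__}.{obj.name}"}

def deserialize_enum(data):
    cls_name, val_name = data["__enum__"].split(".")
    return getattr(comfy.hooks, cls_name)[val_name]

mode = comfy.hooks.EnumHookMode.MaxSpeed  # assumed member name
assert deserialize_enum(serialize_enum(mode)) is mode  # enum identity survives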
253  comfy/isolation/model_sampling_proxy.py  Normal file
@@ -0,0 +1,253 @@
# pylint: disable=import-outside-toplevel
from __future__ import annotations

import asyncio
import logging
from typing import Any

from comfy.isolation.proxies.base import (
    BaseProxy,
    BaseRegistry,
    detach_if_grad,
    get_thread_loop,
    run_coro_in_new_loop,
)

logger = logging.getLogger(__name__)


def _prefer_device(*tensors: Any) -> Any:
    try:
        import torch
    except Exception:
        return None
    for t in tensors:
        if isinstance(t, torch.Tensor) and t.is_cuda:
            return t.device
    for t in tensors:
        if isinstance(t, torch.Tensor):
            return t.device
    return None


def _to_device(obj: Any, device: Any) -> Any:
    try:
        import torch
    except Exception:
        return obj
    if device is None:
        return obj
    if isinstance(obj, torch.Tensor):
        if obj.device != device:
            return obj.to(device)
        return obj
    if isinstance(obj, (list, tuple)):
        converted = [_to_device(x, device) for x in obj]
        return type(obj)(converted) if isinstance(obj, tuple) else converted
    if isinstance(obj, dict):
        return {k: _to_device(v, device) for k, v in obj.items()}
    return obj


class ModelSamplingRegistry(BaseRegistry[Any]):
    _type_prefix = "modelsampling"

    async def calculate_input(self, instance_id: str, sigma: Any, noise: Any) -> Any:
        sampling = self._get_instance(instance_id)
        return detach_if_grad(sampling.calculate_input(sigma, noise))

    async def calculate_denoised(
        self, instance_id: str, sigma: Any, model_output: Any, model_input: Any
    ) -> Any:
        sampling = self._get_instance(instance_id)
        return detach_if_grad(
            sampling.calculate_denoised(sigma, model_output, model_input)
        )

    async def noise_scaling(
        self,
        instance_id: str,
        sigma: Any,
        noise: Any,
        latent_image: Any,
        max_denoise: bool = False,
    ) -> Any:
        sampling = self._get_instance(instance_id)
        return detach_if_grad(
            sampling.noise_scaling(sigma, noise, latent_image, max_denoise=max_denoise)
        )

    async def inverse_noise_scaling(
        self, instance_id: str, sigma: Any, latent: Any
    ) -> Any:
        sampling = self._get_instance(instance_id)
        return detach_if_grad(sampling.inverse_noise_scaling(sigma, latent))

    async def timestep(self, instance_id: str, sigma: Any) -> Any:
        sampling = self._get_instance(instance_id)
        return sampling.timestep(sigma)

    async def sigma(self, instance_id: str, timestep: Any) -> Any:
        sampling = self._get_instance(instance_id)
        return sampling.sigma(timestep)

    async def percent_to_sigma(self, instance_id: str, percent: float) -> Any:
        sampling = self._get_instance(instance_id)
        return sampling.percent_to_sigma(percent)

    async def get_sigma_min(self, instance_id: str) -> Any:
        sampling = self._get_instance(instance_id)
        return detach_if_grad(sampling.sigma_min)

    async def get_sigma_max(self, instance_id: str) -> Any:
        sampling = self._get_instance(instance_id)
        return detach_if_grad(sampling.sigma_max)

    async def get_sigma_data(self, instance_id: str) -> Any:
        sampling = self._get_instance(instance_id)
        return detach_if_grad(sampling.sigma_data)

    async def get_sigmas(self, instance_id: str) -> Any:
        sampling = self._get_instance(instance_id)
        return detach_if_grad(sampling.sigmas)

    async def set_sigmas(self, instance_id: str, sigmas: Any) -> None:
        sampling = self._get_instance(instance_id)
        sampling.set_sigmas(sigmas)


class ModelSamplingProxy(BaseProxy[ModelSamplingRegistry]):
    _registry_class = ModelSamplingRegistry
    __module__ = "comfy.isolation.model_sampling_proxy"

    def _get_rpc(self) -> Any:
        if self._rpc_caller is None:
            from pyisolate._internal.rpc_protocol import get_child_rpc_instance

            rpc = get_child_rpc_instance()
            if rpc is not None:
                self._rpc_caller = rpc.create_caller(
                    ModelSamplingRegistry, ModelSamplingRegistry.get_remote_id()
                )
            else:
                registry = ModelSamplingRegistry()

                class _LocalCaller:
                    # Each method returns the registry coroutine unchanged;
                    # _call detects coroutines and runs them to completion.
                    def calculate_input(
                        self, instance_id: str, sigma: Any, noise: Any
                    ) -> Any:
                        return registry.calculate_input(instance_id, sigma, noise)

                    def calculate_denoised(
                        self,
                        instance_id: str,
                        sigma: Any,
                        model_output: Any,
                        model_input: Any,
                    ) -> Any:
                        return registry.calculate_denoised(
                            instance_id, sigma, model_output, model_input
                        )

                    def noise_scaling(
                        self,
                        instance_id: str,
                        sigma: Any,
                        noise: Any,
                        latent_image: Any,
                        max_denoise: bool = False,
                    ) -> Any:
                        return registry.noise_scaling(
                            instance_id, sigma, noise, latent_image, max_denoise
                        )

                    def inverse_noise_scaling(
                        self, instance_id: str, sigma: Any, latent: Any
                    ) -> Any:
                        return registry.inverse_noise_scaling(
                            instance_id, sigma, latent
                        )

                    def timestep(self, instance_id: str, sigma: Any) -> Any:
                        return registry.timestep(instance_id, sigma)

                    def sigma(self, instance_id: str, timestep: Any) -> Any:
                        return registry.sigma(instance_id, timestep)

                    def percent_to_sigma(self, instance_id: str, percent: float) -> Any:
                        return registry.percent_to_sigma(instance_id, percent)

                    def get_sigma_min(self, instance_id: str) -> Any:
                        return registry.get_sigma_min(instance_id)

                    def get_sigma_max(self, instance_id: str) -> Any:
                        return registry.get_sigma_max(instance_id)

                    def get_sigma_data(self, instance_id: str) -> Any:
                        return registry.get_sigma_data(instance_id)

                    def get_sigmas(self, instance_id: str) -> Any:
                        return registry.get_sigmas(instance_id)

                    def set_sigmas(self, instance_id: str, sigmas: Any) -> None:
                        return registry.set_sigmas(instance_id, sigmas)

                self._rpc_caller = _LocalCaller()
        return self._rpc_caller

    def _call(self, method_name: str, *args: Any) -> Any:
        rpc = self._get_rpc()
        method = getattr(rpc, method_name)
        result = method(self._instance_id, *args)
        if asyncio.iscoroutine(result):
            try:
                asyncio.get_running_loop()
                return run_coro_in_new_loop(result)
            except RuntimeError:
                loop = get_thread_loop()
                return loop.run_until_complete(result)
        return result

    @property
    def sigma_min(self) -> Any:
        return self._call("get_sigma_min")

    @property
    def sigma_max(self) -> Any:
        return self._call("get_sigma_max")

    @property
    def sigma_data(self) -> Any:
        return self._call("get_sigma_data")

    @property
    def sigmas(self) -> Any:
        return self._call("get_sigmas")

    def calculate_input(self, sigma: Any, noise: Any) -> Any:
        return self._call("calculate_input", sigma, noise)

    def calculate_denoised(
        self, sigma: Any, model_output: Any, model_input: Any
    ) -> Any:
        return self._call("calculate_denoised", sigma, model_output, model_input)

    def noise_scaling(
        self, sigma: Any, noise: Any, latent_image: Any, max_denoise: bool = False
    ) -> Any:
        return self._call("noise_scaling", sigma, noise, latent_image, max_denoise)

    def inverse_noise_scaling(self, sigma: Any, latent: Any) -> Any:
        return self._call("inverse_noise_scaling", sigma, latent)

    def timestep(self, sigma: Any) -> Any:
        return self._call("timestep", sigma)

    def sigma(self, timestep: Any) -> Any:
        return self._call("sigma", timestep)

    def percent_to_sigma(self, percent: float) -> Any:
        return self._call("percent_to_sigma", percent)

    def set_sigmas(self, sigmas: Any) -> None:
        return self._call("set_sigmas", sigmas)
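A child-side usage sketch: the proxy mirrors the sampling object's synchronous API while the real instance lives on the host. The instance id shown is hypothetical but follows the `_type_prefix` format:

from comfy.isolation.model_sampling_proxy import ModelSamplingProxy

proxy = ModelSamplingProxy("modelsampling_0")  # id produced by the registry
sigma = proxy.percent_to_sigma(0.5)            # sync call; _call bridges the RPC
bounds = (proxy.sigma_min, proxy.sigma_max)    # properties proxy the same way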
17  comfy/isolation/proxies/__init__.py  Normal file
@@ -0,0 +1,17 @@
from .base import (
    IS_CHILD_PROCESS,
    BaseProxy,
    BaseRegistry,
    detach_if_grad,
    get_thread_loop,
    run_coro_in_new_loop,
)

__all__ = [
    "IS_CHILD_PROCESS",
    "BaseRegistry",
    "BaseProxy",
    "get_thread_loop",
    "run_coro_in_new_loop",
    "detach_if_grad",
]
213  comfy/isolation/proxies/base.py  Normal file
@@ -0,0 +1,213 @@
# pylint: disable=global-statement,import-outside-toplevel,protected-access
from __future__ import annotations

import asyncio
import logging
import os
import threading
import weakref
from typing import Any, Callable, Dict, Generic, Optional, TypeVar

try:
    from pyisolate import ProxiedSingleton
except ImportError:

    class ProxiedSingleton:  # type: ignore[no-redef]
        pass


logger = logging.getLogger(__name__)

IS_CHILD_PROCESS = os.environ.get("PYISOLATE_CHILD") == "1"
_thread_local = threading.local()
T = TypeVar("T")


def get_thread_loop() -> asyncio.AbstractEventLoop:
    loop = getattr(_thread_local, "loop", None)
    if loop is None or loop.is_closed():
        loop = asyncio.new_event_loop()
        _thread_local.loop = loop
    return loop


def run_coro_in_new_loop(coro: Any) -> Any:
    result_box: Dict[str, Any] = {}
    exc_box: Dict[str, BaseException] = {}

    def runner() -> None:
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            result_box["value"] = loop.run_until_complete(coro)
        except Exception as exc:  # noqa: BLE001
            exc_box["exc"] = exc
        finally:
            loop.close()

    t = threading.Thread(target=runner, daemon=True)
    t.start()
    t.join()
    if "exc" in exc_box:
        raise exc_box["exc"]
    return result_box.get("value")


def detach_if_grad(obj: Any) -> Any:
    try:
        import torch
    except Exception:
        return obj

    if isinstance(obj, torch.Tensor):
        return obj.detach() if obj.requires_grad else obj
    if isinstance(obj, (list, tuple)):
        return type(obj)(detach_if_grad(x) for x in obj)
    if isinstance(obj, dict):
        return {k: detach_if_grad(v) for k, v in obj.items()}
    return obj


class BaseRegistry(ProxiedSingleton, Generic[T]):
    _type_prefix: str = "base"

    def __init__(self) -> None:
        if hasattr(ProxiedSingleton, "__init__") and ProxiedSingleton is not object:
            super().__init__()
        self._registry: Dict[str, T] = {}
        self._id_map: Dict[int, str] = {}
        self._counter = 0
        self._lock = threading.Lock()

    def register(self, instance: T) -> str:
        with self._lock:
            obj_id = id(instance)
            if obj_id in self._id_map:
                return self._id_map[obj_id]
            instance_id = f"{self._type_prefix}_{self._counter}"
            self._counter += 1
            self._registry[instance_id] = instance
            self._id_map[obj_id] = instance_id
            return instance_id

    def unregister_sync(self, instance_id: str) -> None:
        with self._lock:
            instance = self._registry.pop(instance_id, None)
            if instance:
                self._id_map.pop(id(instance), None)

    def _get_instance(self, instance_id: str) -> T:
        if IS_CHILD_PROCESS:
            raise RuntimeError(
                f"[{self.__class__.__name__}] _get_instance called in child"
            )
        with self._lock:
            instance = self._registry.get(instance_id)
            if instance is None:
                raise ValueError(f"{instance_id} not found")
            return instance


_GLOBAL_LOOP: Optional[asyncio.AbstractEventLoop] = None


def set_global_loop(loop: asyncio.AbstractEventLoop) -> None:
    global _GLOBAL_LOOP
    _GLOBAL_LOOP = loop


class BaseProxy(Generic[T]):
    _registry_class: type = BaseRegistry  # type: ignore[type-arg]
    __module__: str = "comfy.isolation.proxies.base"

    def __init__(
        self,
        instance_id: str,
        registry: Optional[Any] = None,
        manage_lifecycle: bool = False,
    ) -> None:
        self._instance_id = instance_id
        self._rpc_caller: Optional[Any] = None
        self._registry = registry if registry is not None else self._registry_class()
        self._manage_lifecycle = manage_lifecycle
        self._cleaned_up = False
        if manage_lifecycle and not IS_CHILD_PROCESS:
            self._finalizer = weakref.finalize(
                self, self._registry.unregister_sync, instance_id
            )

    def _get_rpc(self) -> Any:
        if self._rpc_caller is None:
            from pyisolate._internal.rpc_protocol import get_child_rpc_instance

            rpc = get_child_rpc_instance()
            if rpc is None:
                raise RuntimeError(f"[{self.__class__.__name__}] No RPC in child")
            self._rpc_caller = rpc.create_caller(
                self._registry_class, self._registry_class.get_remote_id()
            )
        return self._rpc_caller

    def _call_rpc(self, method_name: str, *args: Any, **kwargs: Any) -> Any:
        rpc = self._get_rpc()
        method = getattr(rpc, method_name)
        coro = method(self._instance_id, *args, **kwargs)

        # If a global (main-thread) loop is available, dispatch through it.
        # This method is synchronous and cannot block on the loop it is
        # currently running inside, so run_coroutine_threadsafe is only used
        # from worker threads that have no running loop of their own;
        # otherwise we fall through to the thread-local loop handling below.
        if _GLOBAL_LOOP is not None and _GLOBAL_LOOP.is_running():
            try:
                curr_loop = asyncio.get_running_loop()
                if curr_loop is _GLOBAL_LOOP:
                    pass  # Already inside the global loop; handled below.
            except RuntimeError:
                # No running loop: we are in a worker thread.
                future = asyncio.run_coroutine_threadsafe(coro, _GLOBAL_LOOP)
                return future.result()

        try:
            asyncio.get_running_loop()
            return run_coro_in_new_loop(coro)
        except RuntimeError:
            loop = get_thread_loop()
            return loop.run_until_complete(coro)

    def __getstate__(self) -> Dict[str, Any]:
        return {"_instance_id": self._instance_id}

    def __setstate__(self, state: Dict[str, Any]) -> None:
        self._instance_id = state["_instance_id"]
        self._rpc_caller = None
        self._registry = self._registry_class()
        self._manage_lifecycle = False
        self._cleaned_up = False

    def cleanup(self) -> None:
        if self._cleaned_up or IS_CHILD_PROCESS:
            return
        self._cleaned_up = True
        finalizer = getattr(self, "_finalizer", None)
        if finalizer is not None:
            finalizer.detach()
        self._registry.unregister_sync(self._instance_id)

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__} {self._instance_id}>"


def create_rpc_method(method_name: str) -> Callable[..., Any]:
    def method(self: BaseProxy[Any], *args: Any, **kwargs: Any) -> Any:
        return self._call_rpc(method_name, *args, **kwargs)

    method.__name__ = method_name
    return method
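A sketch of how `create_rpc_method` can assemble a concrete proxy without hand-written forwarders; `FooRegistry` and `do_thing` are hypothetical names:

from typing import Any

from comfy.isolation.proxies.base import BaseProxy, BaseRegistry, create_rpc_method

class FooRegistry(BaseRegistry[Any]):
    _type_prefix = "foo"

    async def do_thing(self, instance_id: str, x: int) -> int:
        return x + 1

class FooProxy(BaseProxy[FooRegistry]):
    _registry_class = FooRegistry
    do_thing = create_rpc_method("do_thing")  # forwards through _call_rpc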
29  comfy/isolation/proxies/folder_paths_proxy.py  Normal file
@@ -0,0 +1,29 @@
from __future__ import annotations
from typing import Dict

import folder_paths
from pyisolate import ProxiedSingleton


class FolderPathsProxy(ProxiedSingleton):
    """
    Dynamic proxy for folder_paths.
    Uses __getattr__ for most lookups, with explicit handling for
    mutable collections to ensure efficient by-value transfer.
    """

    def __getattr__(self, name):
        return getattr(folder_paths, name)

    # Return dict snapshots (avoid RPC chatter)
    @property
    def folder_names_and_paths(self) -> Dict:
        return dict(folder_paths.folder_names_and_paths)

    @property
    def extension_mimetypes_cache(self) -> Dict:
        return dict(folder_paths.extension_mimetypes_cache)

    @property
    def filename_list_cache(self) -> Dict:
        return dict(folder_paths.filename_list_cache)
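A child-side sketch, assuming pyisolate injects this singleton so lookups transparently hit the host's folder_paths module:

from comfy.isolation.proxies.folder_paths_proxy import FolderPathsProxy

fp = FolderPathsProxy()
checkpoints = fp.get_folder_paths("checkpoints")  # forwarded via __getattr__
snapshot = fp.folder_names_and_paths              # dict copied by value, one RPC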
98  comfy/isolation/proxies/helper_proxies.py  Normal file
@@ -0,0 +1,98 @@
from __future__ import annotations

from typing import Any, Dict, Optional


class AnyTypeProxy(str):
    """Replacement for custom AnyType objects used by some nodes."""

    def __new__(cls, value: str = "*"):
        return super().__new__(cls, value)

    def __ne__(self, other):  # type: ignore[override]
        return False


class FlexibleOptionalInputProxy(dict):
    """Replacement for FlexibleOptionalInputType to allow dynamic inputs."""

    def __init__(self, flex_type, data: Optional[Dict[str, object]] = None):
        super().__init__()
        self.type = flex_type
        if data:
            self.update(data)

    def __getitem__(self, key):  # type: ignore[override]
        return (self.type,)

    def __contains__(self, key):  # type: ignore[override]
        return True


class ByPassTypeTupleProxy(tuple):
    """Replacement for ByPassTypeTuple to mirror wildcard fallback behavior."""

    def __new__(cls, values):
        return super().__new__(cls, values)

    def __getitem__(self, index):  # type: ignore[override]
        if index >= len(self):
            return AnyTypeProxy("*")
        return super().__getitem__(index)


def _restore_special_value(value: Any) -> Any:
    if isinstance(value, dict):
        if value.get("__pyisolate_any_type__"):
            return AnyTypeProxy(value.get("value", "*"))
        if value.get("__pyisolate_flexible_optional__"):
            flex_type = _restore_special_value(value.get("type"))
            data_raw = value.get("data")
            data = (
                {k: _restore_special_value(v) for k, v in data_raw.items()}
                if isinstance(data_raw, dict)
                else {}
            )
            return FlexibleOptionalInputProxy(flex_type, data)
        if value.get("__pyisolate_tuple__") is not None:
            return tuple(
                _restore_special_value(v) for v in value["__pyisolate_tuple__"]
            )
        if value.get("__pyisolate_bypass_tuple__") is not None:
            return ByPassTypeTupleProxy(
                tuple(
                    _restore_special_value(v)
                    for v in value["__pyisolate_bypass_tuple__"]
                )
            )
        return {k: _restore_special_value(v) for k, v in value.items()}
    if isinstance(value, list):
        return [_restore_special_value(v) for v in value]
    return value


def restore_input_types(raw: Dict[str, object]) -> Dict[str, object]:
    """Restore serialized INPUT_TYPES payload back into ComfyUI-compatible objects."""

    if not isinstance(raw, dict):
        return raw  # type: ignore[return-value]

    restored: Dict[str, object] = {}
    for section, entries in raw.items():
        if isinstance(entries, dict) and entries.get("__pyisolate_flexible_optional__"):
            restored[section] = _restore_special_value(entries)
        elif isinstance(entries, dict):
            restored[section] = {
                k: _restore_special_value(v) for k, v in entries.items()
            }
        else:
            restored[section] = _restore_special_value(entries)
    return restored


__all__ = [
    "AnyTypeProxy",
    "FlexibleOptionalInputProxy",
    "ByPassTypeTupleProxy",
    "restore_input_types",
]
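A round-trip sketch for the wildcard markers handled above; the payload shape mirrors what `_restore_special_value` expects:

from comfy.isolation.proxies.helper_proxies import restore_input_types

raw = {
    "required": {"image": {"__pyisolate_tuple__": ["IMAGE"]}},
    "optional": {"__pyisolate_flexible_optional__": True, "type": "*", "data": {}},
}
restored = restore_input_types(raw)
assert restored["required"]["image"] == ("IMAGE",)
assert "anything" in restored["optional"]  # FlexibleOptionalInputProxy accepts any key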
27  comfy/isolation/proxies/model_management_proxy.py  Normal file
@@ -0,0 +1,27 @@
import comfy.model_management as mm
from pyisolate import ProxiedSingleton


class ModelManagementProxy(ProxiedSingleton):
    """
    Dynamic proxy for comfy.model_management.
    Uses __getattr__ to forward all calls to the underlying module,
    reducing maintenance burden.
    """

    # Explicitly expose Enums/Classes as properties
    @property
    def VRAMState(self):
        return mm.VRAMState

    @property
    def CPUState(self):
        return mm.CPUState

    @property
    def OOM_EXCEPTION(self):
        return mm.OOM_EXCEPTION

    def __getattr__(self, name):
        """Forward all other attribute access to the module."""
        return getattr(mm, name)
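A child-side usage sketch; every lookup below resolves against the host's module:

from comfy.isolation.proxies.model_management_proxy import ModelManagementProxy

mmp = ModelManagementProxy()
device = mmp.get_torch_device()      # plain function, forwarded via __getattr__
free = mmp.get_free_memory(device)
state_enum = mmp.VRAMState           # class exposed explicitly as a property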
35  comfy/isolation/proxies/progress_proxy.py  Normal file
@@ -0,0 +1,35 @@
from __future__ import annotations

import logging
from typing import Any, Optional

try:
    from pyisolate import ProxiedSingleton
except ImportError:

    class ProxiedSingleton:  # type: ignore[no-redef]
        pass


from comfy_execution.progress import get_progress_state

logger = logging.getLogger(__name__)


class ProgressProxy(ProxiedSingleton):
    def set_progress(
        self,
        value: float,
        max_value: float,
        node_id: Optional[str] = None,
        image: Any = None,
    ) -> None:
        get_progress_state().update_progress(
            node_id=node_id,
            value=value,
            max_value=max_value,
            image=image,
        )


__all__ = ["ProgressProxy"]
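A minimal sketch of reporting progress from an isolated node; in practice node_id would come from the hidden unique-id input, and the call is RPC-intercepted in the child:

from comfy.isolation.proxies.progress_proxy import ProgressProxy

proxy = ProgressProxy()
proxy.set_progress(value=3, max_value=10, node_id="7")  # forwarded to the host UI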
265  comfy/isolation/proxies/prompt_server_impl.py  Normal file
@@ -0,0 +1,265 @@
# pylint: disable=import-outside-toplevel,logging-fstring-interpolation,redefined-outer-name,reimported,super-init-not-called
"""Stateless RPC Implementation for PromptServer.

Replaces the legacy PromptServerProxy (Singleton) with a clean Service/Stub architecture.
- Host: PromptServerService (RPC Handler)
- Child: PromptServerStub (Interface Implementation)
"""

from __future__ import annotations

import asyncio
import os
from typing import Any, Dict, Optional, Callable

import logging
from aiohttp import web

# IMPORTS
from pyisolate import ProxiedSingleton

logger = logging.getLogger(__name__)
LOG_PREFIX = "[Isolation:C<->H]"

# ...

# =============================================================================
# CHILD SIDE: PromptServerStub
# =============================================================================


class PromptServerStub:
    """Stateless Stub for PromptServer."""

    # Masquerade as the real server module
    __module__ = "server"

    _instance: Optional["PromptServerStub"] = None
    _rpc: Optional[Any] = None  # This will be the Caller object
    _source_file: Optional[str] = None

    def __init__(self):
        self.routes = RouteStub(self)

    @classmethod
    def set_rpc(cls, rpc: Any) -> None:
        """Inject RPC client (called by adapter.py or manually)."""
        # Create a caller that targets the host-side PromptServerService.
        # Both the stub and the service are defined in this module, so the
        # real service class is importable here even in the child process.
        target_id = "PromptServerService"
        cls._rpc = rpc.create_caller(PromptServerService, target_id)

    @property
    def instance(self) -> "PromptServerStub":
        return self

    # ... Compatibility ...
    @classmethod
    def _get_source_file(cls) -> str:
        if cls._source_file is None:
            import folder_paths

            cls._source_file = os.path.join(folder_paths.base_path, "server.py")
        return cls._source_file

    @property
    def __file__(self) -> str:
        return self._get_source_file()

    # --- Properties ---
    @property
    def client_id(self) -> Optional[str]:
        return "isolated_client"

    def supports(self, feature: str) -> bool:
        return True

    @property
    def app(self):
        raise RuntimeError(
            "PromptServer.app is not accessible in isolated nodes. Use RPC routes instead."
        )

    @property
    def prompt_queue(self):
        raise RuntimeError(
            "PromptServer.prompt_queue is not accessible in isolated nodes."
        )

    # --- UI Communication (RPC Delegates) ---
    async def send_sync(
        self, event: str, data: Dict[str, Any], sid: Optional[str] = None
    ) -> None:
        if self._rpc:
            await self._rpc.ui_send_sync(event, data, sid)

    async def send(
        self, event: str, data: Dict[str, Any], sid: Optional[str] = None
    ) -> None:
        if self._rpc:
            await self._rpc.ui_send(event, data, sid)

    def send_progress_text(self, text: str, node_id: str, sid=None) -> None:
        if self._rpc:
            # Node code calls this synchronously, but the host method is
            # async, so schedule it fire-and-forget on the running loop.
            import asyncio

            try:
                loop = asyncio.get_running_loop()
                loop.create_task(self._rpc.ui_send_progress_text(text, node_id, sid))
            except RuntimeError:
                pass  # No running loop in this sync context; drop the update.

    # --- Route Registration Logic ---
    def register_route(self, method: str, path: str, handler: Callable):
        """Register a route handler via RPC."""
        if not self._rpc:
            logger.error("RPC not initialized in PromptServerStub")
            return

        # Fire registration async
        try:
            loop = asyncio.get_running_loop()
            loop.create_task(self._rpc.register_route_rpc(method, path, handler))
        except RuntimeError:
            pass


class RouteStub:
    """Simulates aiohttp.web.RouteTableDef."""

    def __init__(self, stub: PromptServerStub):
        self._stub = stub

    def get(self, path: str):
        def decorator(handler):
            self._stub.register_route("GET", path, handler)
            return handler

        return decorator

    def post(self, path: str):
        def decorator(handler):
            self._stub.register_route("POST", path, handler)
            return handler

        return decorator

    def patch(self, path: str):
        def decorator(handler):
            self._stub.register_route("PATCH", path, handler)
            return handler

        return decorator

    def put(self, path: str):
        def decorator(handler):
            self._stub.register_route("PUT", path, handler)
            return handler

        return decorator

    def delete(self, path: str):
        def decorator(handler):
            self._stub.register_route("DELETE", path, handler)
            return handler

        return decorator


# =============================================================================
# HOST SIDE: PromptServerService
# =============================================================================


class PromptServerService(ProxiedSingleton):
    """Host-side RPC Service for PromptServer."""

    def __init__(self):
        # We will bind to the real server instance lazily or via global import
        pass

    @property
    def server(self):
        from server import PromptServer

        return PromptServer.instance

    async def ui_send_sync(
        self, event: str, data: Dict[str, Any], sid: Optional[str] = None
    ):
        await self.server.send_sync(event, data, sid)

    async def ui_send(
        self, event: str, data: Dict[str, Any], sid: Optional[str] = None
    ):
        await self.server.send(event, data, sid)

    async def ui_send_progress_text(self, text: str, node_id: str, sid=None):
        # Made async to be awaitable by RPC layer
        self.server.send_progress_text(text, node_id, sid)

    async def register_route_rpc(self, method: str, path: str, child_handler_proxy):
        """RPC Target: Register a route that forwards to the Child."""
        logger.debug(f"{LOG_PREFIX} Registering Isolated Route {method} {path}")

        async def route_wrapper(request: web.Request) -> web.Response:
            # 1. Capture request data
            req_data = {
                "method": request.method,
                "path": request.path,
                "query": dict(request.query),
            }
            if request.can_read_body:
                req_data["text"] = await request.text()

            try:
                # 2. Call Child Handler via RPC (child_handler_proxy is async callable)
                result = await child_handler_proxy(req_data)

                # 3. Serialize Response
                return self._serialize_response(result)
            except Exception as e:
                logger.error(f"{LOG_PREFIX} Isolated Route Error: {e}")
                return web.Response(status=500, text=str(e))

        # Register the wrapper with the host router
        self.server.app.router.add_route(method, path, route_wrapper)

    def _serialize_response(self, result: Any) -> web.Response:
        """Helper to convert Child result -> web.Response"""
        if isinstance(result, web.Response):
            return result
        # Handle dict (json)
        if isinstance(result, dict):
            return web.json_response(result)
        # Handle string
        if isinstance(result, str):
            return web.Response(text=result)
        # Fallback
        return web.Response(text=str(result))
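A child-side sketch of registering a custom endpoint through the stub. The handler receives the request dict built by route_wrapper and may return a dict (JSON), a string, or an aiohttp Response; the path is hypothetical:

from comfy.isolation.proxies.prompt_server_impl import PromptServerStub

stub = PromptServerStub()  # assumes set_rpc() was already called by the adapter

@stub.routes.get("/my_extension/status")
async def status(req_data):
    return {"ok": True, "path": req_data["path"]}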
64  comfy/isolation/proxies/utils_proxy.py  Normal file
@@ -0,0 +1,64 @@
# pylint: disable=cyclic-import,import-outside-toplevel
from __future__ import annotations

import os
from typing import Optional, Any

import comfy.utils
from pyisolate import ProxiedSingleton


class UtilsProxy(ProxiedSingleton):
    """
    Proxy for comfy.utils.
    Primarily handles the PROGRESS_BAR_HOOK to ensure progress updates
    from isolated nodes reach the host.
    """

    # _instance and __new__ removed to rely on SingletonMetaclass
    _rpc: Optional[Any] = None

    @classmethod
    def set_rpc(cls, rpc: Any) -> None:
        # Create caller using class name as ID (standard for Singletons)
        cls._rpc = rpc.create_caller(cls, "UtilsProxy")

    async def progress_bar_hook(
        self,
        value: int,
        total: int,
        preview: Optional[bytes] = None,
        node_id: Optional[str] = None,
    ) -> Any:
        """
        Host-side implementation: forwards the call to the real global hook.
        Child-side: this method call is intercepted by RPC and sent to host.
        """
        if os.environ.get("PYISOLATE_CHILD") == "1":
            # Manual RPC dispatch for the child process, using the
            # class-level caller injected via set_rpc().
            if UtilsProxy._rpc:
                return await UtilsProxy._rpc.progress_bar_hook(
                    value, total, preview, node_id
                )

            # Fallback channel: a global child RPC instance may exist, but
            # without a caller there is nothing to dispatch through, so the
            # update is silently dropped instead of crashing.
            try:
                from pyisolate._internal.rpc_protocol import get_child_rpc_instance

                get_child_rpc_instance()
            except (ImportError, LookupError):
                pass

            return None

        # Host Execution
        if comfy.utils.PROGRESS_BAR_HOOK is not None:
            comfy.utils.PROGRESS_BAR_HOOK(value, total, preview, node_id)

    def set_progress_bar_global_hook(self, hook: Any) -> None:
        """Forward hook registration (though usually not needed from child)."""
        comfy.utils.set_progress_bar_global_hook(hook)
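A child-side sketch of pushing one progress tick through the proxy, assuming set_rpc() has already wired the caller; in the child the coroutine dispatches over RPC, so running it to completion is enough:

import asyncio

from comfy.isolation.proxies.utils_proxy import UtilsProxy

proxy = UtilsProxy()
asyncio.run(proxy.progress_bar_hook(value=5, total=20, preview=None, node_id="3"))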
49  comfy/isolation/rpc_bridge.py  Normal file
@@ -0,0 +1,49 @@
import asyncio
import logging
import threading

logger = logging.getLogger(__name__)


class RpcBridge:
    """Minimal helper to run coroutines synchronously inside isolated processes.

    If an event loop is already running, the coroutine is executed on a fresh
    thread with its own loop to avoid nested run_until_complete errors.
    """

    def run_sync(self, maybe_coro):
        if not asyncio.iscoroutine(maybe_coro):
            return maybe_coro

        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            loop = None

        if loop and loop.is_running():
            result_container = {}
            exc_container = {}

            def _runner():
                # Create the loop before the try block so the finally clause
                # can always close it safely.
                new_loop = asyncio.new_event_loop()
                asyncio.set_event_loop(new_loop)
                try:
                    result_container["value"] = new_loop.run_until_complete(maybe_coro)
                except Exception as exc:  # pragma: no cover
                    exc_container["error"] = exc
                finally:
                    try:
                        new_loop.close()
                    except Exception:
                        pass

            t = threading.Thread(target=_runner, daemon=True)
            t.start()
            t.join()

            if "error" in exc_container:
                raise exc_container["error"]
            return result_container.get("value")

        return asyncio.run(maybe_coro)
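A usage sketch: run_sync lets synchronous node code consume awaitables regardless of whether a loop is already running:

from comfy.isolation.rpc_bridge import RpcBridge

bridge = RpcBridge()

async def fetch() -> int:
    return 42

assert bridge.run_sync(fetch()) == 42
assert bridge.run_sync("plain") == "plain"  # non-coroutines pass through unchanged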
343  comfy/isolation/runtime_helpers.py  Normal file
@@ -0,0 +1,343 @@
# pylint: disable=consider-using-from-import,import-outside-toplevel,no-member
from __future__ import annotations

import copy
import logging
import os
from pathlib import Path
from typing import Any, Dict, List, Set, TYPE_CHECKING

from .proxies.helper_proxies import restore_input_types
from comfy_api.internal import _ComfyNodeInternal
from comfy_api.latest import _io as latest_io
from .shm_forensics import scan_shm_forensics

if TYPE_CHECKING:
    from .extension_wrapper import ComfyNodeExtension

LOG_PREFIX = "]["
_PRE_EXEC_MIN_FREE_VRAM_BYTES = 2 * 1024 * 1024 * 1024


def _resource_snapshot() -> Dict[str, int]:
    fd_count = -1
    shm_sender_files = 0
    try:
        fd_count = len(os.listdir("/proc/self/fd"))
    except Exception:
        pass
    try:
        shm_root = Path("/dev/shm")
        if shm_root.exists():
            prefix = f"torch_{os.getpid()}_"
            shm_sender_files = sum(1 for _ in shm_root.glob(f"{prefix}*"))
    except Exception:
        pass
    return {"fd_count": fd_count, "shm_sender_files": shm_sender_files}


def _tensor_transport_summary(value: Any) -> Dict[str, int]:
    summary: Dict[str, int] = {
        "tensor_count": 0,
        "cpu_tensors": 0,
        "cuda_tensors": 0,
        "shared_cpu_tensors": 0,
        "tensor_bytes": 0,
    }
    try:
        import torch
    except Exception:
        return summary

    def visit(node: Any) -> None:
        if isinstance(node, torch.Tensor):
            summary["tensor_count"] += 1
            summary["tensor_bytes"] += int(node.numel() * node.element_size())
            if node.device.type == "cpu":
                summary["cpu_tensors"] += 1
                if node.is_shared():
                    summary["shared_cpu_tensors"] += 1
            elif node.device.type == "cuda":
                summary["cuda_tensors"] += 1
            return
        if isinstance(node, dict):
            for v in node.values():
                visit(v)
            return
        if isinstance(node, (list, tuple)):
            for v in node:
                visit(v)

    visit(value)
    return summary


def _extract_hidden_unique_id(inputs: Dict[str, Any]) -> str | None:
    for key, value in inputs.items():
        key_text = str(key)
        if "unique_id" in key_text:
            return str(value)
    return None


def _flush_tensor_transport_state(marker: str, logger: logging.Logger) -> None:
    try:
        from pyisolate import flush_tensor_keeper  # type: ignore[attr-defined]
    except Exception:
        return
    if not callable(flush_tensor_keeper):
        return
    flushed = flush_tensor_keeper()
    if flushed > 0:
        logger.debug(
            "%s %s flush_tensor_keeper released=%d", LOG_PREFIX, marker, flushed
        )


def _relieve_host_vram_pressure(marker: str, logger: logging.Logger) -> None:
    import comfy.model_management as model_management

    model_management.cleanup_models_gc()
    model_management.cleanup_models()

    device = model_management.get_torch_device()
    if not hasattr(device, "type") or device.type == "cpu":
        return

    required = max(
        model_management.minimum_inference_memory(),
        _PRE_EXEC_MIN_FREE_VRAM_BYTES,
    )
    if model_management.get_free_memory(device) < required:
        model_management.free_memory(required, device, for_dynamic=True)
        if model_management.get_free_memory(device) < required:
            model_management.free_memory(required, device, for_dynamic=False)
            model_management.cleanup_models()
        model_management.soft_empty_cache()
        logger.debug("%s %s free_memory target=%d", LOG_PREFIX, marker, required)


def _detach_shared_cpu_tensors(value: Any) -> Any:
    try:
        import torch
    except Exception:
        return value

    if isinstance(value, torch.Tensor):
        if value.device.type == "cpu" and value.is_shared():
            clone = value.clone()
            if value.requires_grad:
                clone.requires_grad_(True)
            return clone
        return value
    if isinstance(value, list):
        return [_detach_shared_cpu_tensors(v) for v in value]
    if isinstance(value, tuple):
        return tuple(_detach_shared_cpu_tensors(v) for v in value)
    if isinstance(value, dict):
        return {k: _detach_shared_cpu_tensors(v) for k, v in value.items()}
    return value


def build_stub_class(
    node_name: str,
    info: Dict[str, object],
    extension: "ComfyNodeExtension",
    running_extensions: Dict[str, "ComfyNodeExtension"],
    logger: logging.Logger,
) -> type:
    is_v3 = bool(info.get("is_v3", False))
    function_name = "_pyisolate_execute"
    restored_input_types = restore_input_types(info.get("input_types", {}))

    async def _execute(self, **inputs):
        from comfy.isolation import _RUNNING_EXTENSIONS

        # Update BOTH the local dict AND the module-level dict
        running_extensions[extension.name] = extension
        _RUNNING_EXTENSIONS[extension.name] = extension
        prev_child = None
        node_unique_id = _extract_hidden_unique_id(inputs)
        summary = _tensor_transport_summary(inputs)
        resources = _resource_snapshot()
        logger.debug(
            "%s ISO:execute_start ext=%s node=%s uid=%s tensors=%d cpu=%d cuda=%d shared_cpu=%d bytes=%d fds=%d sender_shm=%d",
            LOG_PREFIX,
            extension.name,
            node_name,
            node_unique_id or "-",
            summary["tensor_count"],
            summary["cpu_tensors"],
            summary["cuda_tensors"],
            summary["shared_cpu_tensors"],
            summary["tensor_bytes"],
            resources["fd_count"],
            resources["shm_sender_files"],
        )
        scan_shm_forensics("RUNTIME:execute_start", refresh_model_context=True)
        try:
            if os.environ.get("PYISOLATE_ISOLATION_ACTIVE") == "1":
                _relieve_host_vram_pressure("RUNTIME:pre_execute", logger)
                scan_shm_forensics("RUNTIME:pre_execute", refresh_model_context=True)
            from pyisolate._internal.model_serialization import (
                serialize_for_isolation,
                deserialize_from_isolation,
            )

            prev_child = os.environ.pop("PYISOLATE_CHILD", None)
            logger.debug(
                "%s ISO:serialize_start ext=%s node=%s uid=%s",
                LOG_PREFIX,
                extension.name,
                node_name,
                node_unique_id or "-",
            )
            serialized = serialize_for_isolation(inputs)
            logger.debug(
                "%s ISO:serialize_done ext=%s node=%s uid=%s",
                LOG_PREFIX,
                extension.name,
                node_name,
                node_unique_id or "-",
            )
            logger.debug(
                "%s ISO:dispatch_start ext=%s node=%s uid=%s",
                LOG_PREFIX,
                extension.name,
                node_name,
                node_unique_id or "-",
            )
            result = await extension.execute_node(node_name, **serialized)
            logger.debug(
                "%s ISO:dispatch_done ext=%s node=%s uid=%s",
                LOG_PREFIX,
                extension.name,
                node_name,
                node_unique_id or "-",
            )
            deserialized = await deserialize_from_isolation(result, extension)
            scan_shm_forensics("RUNTIME:post_execute", refresh_model_context=True)
            return _detach_shared_cpu_tensors(deserialized)
        except ImportError:
            return await extension.execute_node(node_name, **inputs)
        except Exception:
            logger.exception(
                "%s ISO:execute_error ext=%s node=%s uid=%s",
                LOG_PREFIX,
                extension.name,
                node_name,
                node_unique_id or "-",
            )
            raise
        finally:
            if prev_child is not None:
                os.environ["PYISOLATE_CHILD"] = prev_child
            logger.debug(
                "%s ISO:execute_end ext=%s node=%s uid=%s",
                LOG_PREFIX,
                extension.name,
                node_name,
                node_unique_id or "-",
            )
            scan_shm_forensics("RUNTIME:execute_end", refresh_model_context=True)

    def _input_types(
        cls,
        include_hidden: bool = True,
        return_schema: bool = False,
        live_inputs: Any = None,
    ):
        if not is_v3:
            return restored_input_types

        inputs_copy = copy.deepcopy(restored_input_types)
        if not include_hidden:
            inputs_copy.pop("hidden", None)

        v3_data: Dict[str, Any] = {"hidden_inputs": {}}
        dynamic = inputs_copy.pop("dynamic_paths", None)
        if dynamic is not None:
            v3_data["dynamic_paths"] = dynamic

        if return_schema:
            hidden_vals = info.get("hidden", []) or []
            hidden_enums = []
            for h in hidden_vals:
                try:
                    hidden_enums.append(latest_io.Hidden(h))
                except Exception:
                    hidden_enums.append(h)

            class SchemaProxy:
                hidden = hidden_enums

            return inputs_copy, SchemaProxy, v3_data
        return inputs_copy

    def _validate_class(cls):
        return True

    def _get_node_info_v1(cls):
        return info.get("schema_v1", {})

    def _get_base_class(cls):
        return latest_io.ComfyNode

    attributes: Dict[str, object] = {
        "FUNCTION": function_name,
        "CATEGORY": info.get("category", ""),
        "OUTPUT_NODE": info.get("output_node", False),
        "RETURN_TYPES": tuple(info.get("return_types", ()) or ()),
        "RETURN_NAMES": info.get("return_names"),
        function_name: _execute,
        "_pyisolate_extension": extension,
        "_pyisolate_node_name": node_name,
        "INPUT_TYPES": classmethod(_input_types),
    }

    output_is_list = info.get("output_is_list")
    if output_is_list is not None:
        attributes["OUTPUT_IS_LIST"] = tuple(output_is_list)

    if is_v3:
        attributes["VALIDATE_CLASS"] = classmethod(_validate_class)
        attributes["GET_NODE_INFO_V1"] = classmethod(_get_node_info_v1)
        attributes["GET_BASE_CLASS"] = classmethod(_get_base_class)
        attributes["DESCRIPTION"] = info.get("description", "")
        attributes["EXPERIMENTAL"] = info.get("experimental", False)
        attributes["DEPRECATED"] = info.get("deprecated", False)
        attributes["API_NODE"] = info.get("api_node", False)
        attributes["NOT_IDEMPOTENT"] = info.get("not_idempotent", False)
        attributes["INPUT_IS_LIST"] = info.get("input_is_list", False)

    class_name = f"PyIsolate_{node_name}".replace(" ", "_")
    bases = (_ComfyNodeInternal,) if is_v3 else ()
    stub_cls = type(class_name, bases, attributes)

    if is_v3:
        try:
            stub_cls.VALIDATE_CLASS()
        except Exception as e:
            logger.error("%s VALIDATE_CLASS failed: %s - %s", LOG_PREFIX, node_name, e)

    return stub_cls


def get_class_types_for_extension(
    extension_name: str,
    running_extensions: Dict[str, "ComfyNodeExtension"],
    specs: List[Any],
) -> Set[str]:
    extension = running_extensions.get(extension_name)
    if not extension:
        return set()

    ext_path = Path(extension.module_path)
    class_types = set()
    for spec in specs:
        if spec.module_path.resolve() == ext_path.resolve():
            class_types.add(spec.node_name)
    return class_types


__all__ = ["build_stub_class", "get_class_types_for_extension"]
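Editor's aside (illustrative, not part of the diff): _tensor_transport_summary walks nested containers and counts tensors; for a small CPU-only input dict it would report something like the following (assumes torch is installed).

# Illustrative output of _tensor_transport_summary.
import torch
from comfy.isolation.runtime_helpers import _tensor_transport_summary

inputs = {"image": torch.zeros(1, 3, 8, 8), "mask": torch.zeros(8, 8).share_memory_()}
print(_tensor_transport_summary(inputs))
# {'tensor_count': 2, 'cpu_tensors': 2, 'cuda_tensors': 0,
#  'shared_cpu_tensors': 1, 'tensor_bytes': 1024}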
217  comfy/isolation/shm_forensics.py  Normal file
@@ -0,0 +1,217 @@
# pylint: disable=consider-using-from-import,import-outside-toplevel
from __future__ import annotations

import atexit
import hashlib
import logging
import os
from pathlib import Path
from typing import Any, Dict, List, Set

LOG_PREFIX = "]["
logger = logging.getLogger(__name__)


def _shm_debug_enabled() -> bool:
    return os.environ.get("COMFY_ISO_SHM_DEBUG") == "1"


class _SHMForensicsTracker:
    def __init__(self) -> None:
        self._started = False
        self._tracked_files: Set[str] = set()
        self._current_model_context: Dict[str, str] = {
            "id": "unknown",
            "name": "unknown",
            "hash": "????",
        }

    @staticmethod
    def _snapshot_shm() -> Set[str]:
        shm_path = Path("/dev/shm")
        if not shm_path.exists():
            return set()
        return {f.name for f in shm_path.glob("torch_*")}

    def start(self) -> None:
        if self._started or not _shm_debug_enabled():
            return
        self._tracked_files = self._snapshot_shm()
        self._started = True
        logger.debug(
            "%s SHM:forensics_enabled tracked=%d", LOG_PREFIX, len(self._tracked_files)
        )

    def stop(self) -> None:
        if not self._started:
            return
        self.scan("shutdown", refresh_model_context=True)
        self._started = False
        logger.debug("%s SHM:forensics_disabled", LOG_PREFIX)

    def _compute_model_hash(self, model_patcher: Any) -> str:
        try:
            model_instance_id = getattr(model_patcher, "_instance_id", None)
            if model_instance_id is not None:
                model_id_text = str(model_instance_id)
                return model_id_text[-4:] if len(model_id_text) >= 4 else model_id_text

            import torch

            real_model = (
                model_patcher.model
                if hasattr(model_patcher, "model")
                else model_patcher
            )
            tensor = None
            if hasattr(real_model, "parameters"):
                for p in real_model.parameters():
                    if torch.is_tensor(p) and p.numel() > 0:
                        tensor = p
                        break

            if tensor is None:
                return "0000"

            flat = tensor.flatten()
            values = []
            indices = [0, flat.shape[0] // 2, flat.shape[0] - 1]
            for i in indices:
                if i < flat.shape[0]:
                    values.append(flat[i].item())

            size = 0
            if hasattr(model_patcher, "model_size"):
                size = model_patcher.model_size()
            sample_str = f"{values}_{id(model_patcher):016x}_{size}"
            return hashlib.sha256(sample_str.encode()).hexdigest()[-4:]
        except Exception:
            return "err!"

    def _get_models_snapshot(self) -> List[Dict[str, Any]]:
        try:
            import comfy.model_management as model_management
        except Exception:
            return []

        snapshot: List[Dict[str, Any]] = []
        try:
            for loaded_model in model_management.current_loaded_models:
                model = loaded_model.model
                if model is None:
                    continue
                if str(getattr(loaded_model, "device", "")) != "cuda:0":
                    continue

                name = (
                    model.model.__class__.__name__
                    if hasattr(model, "model")
                    else type(model).__name__
                )
                model_hash = self._compute_model_hash(model)
                model_instance_id = getattr(model, "_instance_id", None)
                if model_instance_id is None:
                    model_instance_id = model_hash
                snapshot.append(
                    {
                        "name": str(name),
                        "id": str(model_instance_id),
                        "hash": str(model_hash or "????"),
                        "used": bool(getattr(loaded_model, "currently_used", False)),
                    }
                )
        except Exception:
            return []

        return snapshot

    def _update_model_context(self) -> None:
        snapshot = self._get_models_snapshot()
        selected = None

        used_models = [m for m in snapshot if m.get("used") and m.get("id")]
        if used_models:
            selected = used_models[-1]
        else:
            live_models = [m for m in snapshot if m.get("id")]
            if live_models:
                selected = live_models[-1]

        if selected is None:
            self._current_model_context = {
                "id": "unknown",
                "name": "unknown",
                "hash": "????",
            }
            return

        self._current_model_context = {
            "id": str(selected.get("id", "unknown")),
            "name": str(selected.get("name", "unknown")),
            "hash": str(selected.get("hash", "????") or "????"),
        }

    def scan(self, marker: str, refresh_model_context: bool = True) -> None:
        if not self._started or not _shm_debug_enabled():
            return

        if refresh_model_context:
            self._update_model_context()

        current = self._snapshot_shm()
        added = current - self._tracked_files
        removed = self._tracked_files - current
        self._tracked_files = current

        if not added and not removed:
            logger.debug("%s SHM:scan marker=%s changes=0", LOG_PREFIX, marker)
            return

        for filename in sorted(added):
            logger.info("%s SHM:created | %s", LOG_PREFIX, filename)
            model_id = self._current_model_context["id"]
            if model_id == "unknown":
                logger.error(
                    "%s SHM:model_association_missing | file=%s | reason=no_active_model_context",
                    LOG_PREFIX,
                    filename,
                )
            else:
                logger.info(
                    "%s SHM:model_association | model=%s | file=%s | name=%s | hash=%s",
                    LOG_PREFIX,
                    model_id,
                    filename,
                    self._current_model_context["name"],
                    self._current_model_context["hash"],
                )

        for filename in sorted(removed):
            logger.info("%s SHM:deleted | %s", LOG_PREFIX, filename)

        logger.debug(
            "%s SHM:scan marker=%s created=%d deleted=%d active=%d",
            LOG_PREFIX,
            marker,
            len(added),
            len(removed),
            len(self._tracked_files),
        )


_TRACKER = _SHMForensicsTracker()


def start_shm_forensics() -> None:
    _TRACKER.start()


def scan_shm_forensics(marker: str, refresh_model_context: bool = True) -> None:
    _TRACKER.scan(marker, refresh_model_context=refresh_model_context)


def stop_shm_forensics() -> None:
    _TRACKER.stop()


atexit.register(stop_shm_forensics)
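Editor's aside (illustrative, not part of the diff): the tracker is opt-in via COMFY_ISO_SHM_DEBUG and only logs deltas between scans.

# Enabling /dev/shm forensics around a suspect operation.
import os
os.environ["COMFY_ISO_SHM_DEBUG"] = "1"  # must be set before start()

from comfy.isolation.shm_forensics import start_shm_forensics, scan_shm_forensics

start_shm_forensics()
# ... run the operation under investigation ...
scan_shm_forensics("after_suspect_op")  # logs SHM:created / SHM:deleted deltas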
214  comfy/isolation/vae_proxy.py  Normal file
@@ -0,0 +1,214 @@
# pylint: disable=attribute-defined-outside-init
import logging
from typing import Any

from comfy.isolation.proxies.base import (
    IS_CHILD_PROCESS,
    BaseProxy,
    BaseRegistry,
    detach_if_grad,
)
from comfy.isolation.model_patcher_proxy import ModelPatcherProxy, ModelPatcherRegistry

logger = logging.getLogger(__name__)


class FirstStageModelRegistry(BaseRegistry[Any]):
    _type_prefix = "first_stage_model"

    async def get_property(self, instance_id: str, name: str) -> Any:
        obj = self._get_instance(instance_id)
        return getattr(obj, name)

    async def has_property(self, instance_id: str, name: str) -> bool:
        obj = self._get_instance(instance_id)
        return hasattr(obj, name)


class FirstStageModelProxy(BaseProxy[FirstStageModelRegistry]):
    _registry_class = FirstStageModelRegistry
    __module__ = "comfy.ldm.models.autoencoder"

    def __getattr__(self, name: str) -> Any:
        try:
            return self._call_rpc("get_property", name)
        except Exception as e:
            raise AttributeError(
                f"'{self.__class__.__name__}' object has no attribute '{name}'"
            ) from e

    def __repr__(self) -> str:
        return f"<FirstStageModelProxy {self._instance_id}>"


class VAERegistry(BaseRegistry[Any]):
    _type_prefix = "vae"

    async def get_patcher_id(self, instance_id: str) -> str:
        vae = self._get_instance(instance_id)
        return ModelPatcherRegistry().register(vae.patcher)

    async def get_first_stage_model_id(self, instance_id: str) -> str:
        vae = self._get_instance(instance_id)
        return FirstStageModelRegistry().register(vae.first_stage_model)

    async def encode(self, instance_id: str, pixels: Any) -> Any:
        return detach_if_grad(self._get_instance(instance_id).encode(pixels))

    async def encode_tiled(
        self,
        instance_id: str,
        pixels: Any,
        tile_x: int = 512,
        tile_y: int = 512,
        overlap: int = 64,
    ) -> Any:
        return detach_if_grad(
            self._get_instance(instance_id).encode_tiled(
                pixels, tile_x=tile_x, tile_y=tile_y, overlap=overlap
            )
        )

    async def decode(self, instance_id: str, samples: Any, **kwargs: Any) -> Any:
        return detach_if_grad(self._get_instance(instance_id).decode(samples, **kwargs))

    async def decode_tiled(
        self,
        instance_id: str,
        samples: Any,
        tile_x: int = 64,
        tile_y: int = 64,
        overlap: int = 16,
        **kwargs: Any,
    ) -> Any:
        return detach_if_grad(
            self._get_instance(instance_id).decode_tiled(
                samples, tile_x=tile_x, tile_y=tile_y, overlap=overlap, **kwargs
            )
        )

    async def get_sd(self, instance_id: str) -> Any:
        # Handler for VAEProxy.get_sd(); without it that RPC call has no target.
        return self._get_instance(instance_id).get_sd()

    async def get_property(self, instance_id: str, name: str) -> Any:
        return getattr(self._get_instance(instance_id), name)

    async def memory_used_encode(self, instance_id: str, shape: Any, dtype: Any) -> int:
        return self._get_instance(instance_id).memory_used_encode(shape, dtype)

    async def memory_used_decode(self, instance_id: str, shape: Any, dtype: Any) -> int:
        return self._get_instance(instance_id).memory_used_decode(shape, dtype)

    async def process_input(self, instance_id: str, image: Any) -> Any:
        return detach_if_grad(self._get_instance(instance_id).process_input(image))

    async def process_output(self, instance_id: str, image: Any) -> Any:
        return detach_if_grad(self._get_instance(instance_id).process_output(image))


class VAEProxy(BaseProxy[VAERegistry]):
    _registry_class = VAERegistry
    __module__ = "comfy.sd"

    @property
    def patcher(self) -> ModelPatcherProxy:
        if not hasattr(self, "_patcher_proxy"):
            patcher_id = self._call_rpc("get_patcher_id")
            self._patcher_proxy = ModelPatcherProxy(patcher_id, manage_lifecycle=False)
        return self._patcher_proxy

    @property
    def first_stage_model(self) -> FirstStageModelProxy:
        if not hasattr(self, "_first_stage_model_proxy"):
            fsm_id = self._call_rpc("get_first_stage_model_id")
            self._first_stage_model_proxy = FirstStageModelProxy(
                fsm_id, manage_lifecycle=False
            )
        return self._first_stage_model_proxy

    @property
    def vae_dtype(self) -> Any:
        return self._get_property("vae_dtype")

    def encode(self, pixels: Any) -> Any:
        return self._call_rpc("encode", pixels)

    def encode_tiled(
        self, pixels: Any, tile_x: int = 512, tile_y: int = 512, overlap: int = 64
    ) -> Any:
        return self._call_rpc("encode_tiled", pixels, tile_x, tile_y, overlap)

    def decode(self, samples: Any, **kwargs: Any) -> Any:
        return self._call_rpc("decode", samples, **kwargs)

    def decode_tiled(
        self,
        samples: Any,
        tile_x: int = 64,
        tile_y: int = 64,
        overlap: int = 16,
        **kwargs: Any,
    ) -> Any:
        return self._call_rpc(
            "decode_tiled", samples, tile_x, tile_y, overlap, **kwargs
        )

    def get_sd(self) -> Any:
        return self._call_rpc("get_sd")

    def _get_property(self, name: str) -> Any:
        return self._call_rpc("get_property", name)

    @property
    def latent_dim(self) -> int:
        return self._get_property("latent_dim")

    @property
    def latent_channels(self) -> int:
        return self._get_property("latent_channels")

    @property
    def downscale_ratio(self) -> Any:
        return self._get_property("downscale_ratio")

    @property
    def upscale_ratio(self) -> Any:
        return self._get_property("upscale_ratio")

    @property
    def output_channels(self) -> int:
        return self._get_property("output_channels")

    @property
    def not_video(self) -> bool:
        return self._get_property("not_video")

    @property
    def device(self) -> Any:
        return self._get_property("device")

    @property
    def working_dtypes(self) -> Any:
        return self._get_property("working_dtypes")

    @property
    def disable_offload(self) -> bool:
        return self._get_property("disable_offload")

    @property
    def size(self) -> Any:
        return self._get_property("size")

    def memory_used_encode(self, shape: Any, dtype: Any) -> int:
        return self._call_rpc("memory_used_encode", shape, dtype)

    def memory_used_decode(self, shape: Any, dtype: Any) -> int:
        return self._call_rpc("memory_used_decode", shape, dtype)

    def process_input(self, image: Any) -> Any:
        return self._call_rpc("process_input", image)

    def process_output(self, image: Any) -> Any:
        return self._call_rpc("process_output", image)


if not IS_CHILD_PROCESS:
    _VAE_REGISTRY_SINGLETON = VAERegistry()
    _FIRST_STAGE_MODEL_REGISTRY_SINGLETON = FirstStageModelRegistry()
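Editor's aside (illustrative, not part of the diff): in a child process a VAE input arrives as a VAEProxy, so node code keeps calling the familiar surface while every call is serviced host-side by VAERegistry. A hypothetical node body:

def decode_latent(vae, latent_samples):
    images = vae.decode(latent_samples)     # RPC -> VAERegistry.decode on the host
    print(vae.latent_channels, vae.device)  # properties resolve via get_property
    return images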
@@ -1,48 +1,21 @@
import math
import time
import os
from functools import partial

from scipy import integrate
import torch
from torch import nn
import torchsde
from tqdm.auto import trange as trange_, tqdm
from tqdm.auto import tqdm

from . import utils
from . import deis
from . import sa_solver
import comfy.model_patcher
import comfy.model_sampling

import comfy.memory_management


def trange(*args, **kwargs):
    if comfy.memory_management.aimdo_allocator is None:
        return trange_(*args, **kwargs)

    pbar = trange_(*args, **kwargs, smoothing=1.0)
    pbar._i = 0
    pbar.set_postfix_str(" Model Initializing ... ")

    _update = pbar.update

    def warmup_update(n=1):
        pbar._i += 1
        if pbar._i == 1:
            pbar.i1_time = time.time()
            pbar.set_postfix_str(" Model Initialization complete! ")
        elif pbar._i == 2:
            # Bring forward the effective start time based on the diff between the
            # first and second iterations, to attempt to remove load overhead from
            # the final step-rate estimate.
            pbar.start_t = pbar.i1_time - (time.time() - pbar.i1_time)
            pbar.set_postfix_str("")

        _update(n)

    pbar.update = warmup_update
    return pbar

from comfy.cli_args import args
from comfy.utils import model_trange as trange

def append_zero(x):
    return torch.cat([x, x.new_zeros([1])])
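Editor's aside (illustrative, not part of the diff): the warmup wrapper removed above rewinds tqdm's start time so the reported rate reflects steady-state speed rather than first-step model-load overhead.

# The arithmetic behind pbar.start_t = i1_time - (now - i1_time):
i1_time, now = 10.0, 12.5            # step 1 ended at 10.0s, step 2 at 12.5s
start_t = i1_time - (now - i1_time)  # 7.5: pretend step 1 also took 2.5s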
@@ -219,6 +192,13 @@ def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None,
    """Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    isolation_active = args.use_process_isolation or os.environ.get("PYISOLATE_ISOLATION_ACTIVE") == "1"
    if isolation_active:
        target_device = sigmas.device
        if x.device != target_device:
            x = x.to(target_device)
            s_in = s_in.to(target_device)

    for i in trange(len(sigmas) - 1, disable=disable):
        if s_churn > 0:
            gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
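Editor's aside (illustrative, not part of the diff): the hunk above pins the sampler tensors to sigmas' device when isolation is active, since inputs may land on a different device after crossing the process boundary.

# The device-alignment pattern in miniature.
import torch
sigmas = torch.linspace(1.0, 0.0, 5)   # lives on the sampling device
x = torch.randn(1, 4, 8, 8)            # may arrive elsewhere (e.g. CPU)
if x.device != sigmas.device:
    x = x.to(sigmas.device)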
@@ -1110,7 +1110,7 @@ class AceStepConditionGenerationModel(nn.Module):

        return encoder_hidden, encoder_mask, context_latents

    def forward(self, x, timestep, context, lyric_embed=None, refer_audio=None, audio_codes=None, is_covers=None, **kwargs):
    def forward(self, x, timestep, context, lyric_embed=None, refer_audio=None, audio_codes=None, is_covers=None, replace_with_null_embeds=False, **kwargs):
        text_attention_mask = None
        lyric_attention_mask = None
        refer_audio_order_mask = None
@@ -1140,6 +1140,9 @@ class AceStepConditionGenerationModel(nn.Module):
            src_latents, chunk_masks, is_covers, precomputed_lm_hints_25Hz=precomputed_lm_hints_25Hz, audio_codes=audio_codes
        )

        if replace_with_null_embeds:
            enc_hidden[:] = self.null_condition_emb.to(enc_hidden)

        out = self.decoder(hidden_states=x,
                           timestep=timestep,
                           timestep_r=timestep,
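Editor's aside (illustrative, not part of the diff): replace_with_null_embeds overwrites every conditioning row with the model's learned null embedding, broadcast over the batch. A shape-level sketch with made-up sizes:

import torch
enc_hidden = torch.randn(2, 77, 1024)   # hypothetical conditioning batch
null_condition_emb = torch.zeros(1024)  # stands in for the learned parameter
enc_hidden[:] = null_condition_emb.to(enc_hidden)  # broadcast over batch and sequence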
@@ -179,8 +179,8 @@ class LLMAdapter(nn.Module):
        if source_attention_mask.ndim == 2:
            source_attention_mask = source_attention_mask.unsqueeze(1).unsqueeze(1)

        x = self.in_proj(self.embed(target_input_ids))
        context = source_hidden_states
        x = self.in_proj(self.embed(target_input_ids, out_dtype=context.dtype))
        position_ids = torch.arange(x.shape[1], device=x.device).unsqueeze(0)
        position_ids_context = torch.arange(context.shape[1], device=x.device).unsqueeze(0)
        position_embeddings = self.rotary_emb(x, position_ids)
@@ -195,8 +195,20 @@ class Anima(MiniTrainDIT):
        super().__init__(*args, **kwargs)
        self.llm_adapter = LLMAdapter(device=kwargs.get("device"), dtype=kwargs.get("dtype"), operations=kwargs.get("operations"))

    def preprocess_text_embeds(self, text_embeds, text_ids):
    def preprocess_text_embeds(self, text_embeds, text_ids, t5xxl_weights=None):
        if text_ids is not None:
            return self.llm_adapter(text_embeds, text_ids)
            out = self.llm_adapter(text_embeds, text_ids)
            if t5xxl_weights is not None:
                out = out * t5xxl_weights

            if out.shape[1] < 512:
                out = torch.nn.functional.pad(out, (0, 0, 0, 512 - out.shape[1]))
            return out
        else:
            return text_embeds

    def forward(self, x, timesteps, context, **kwargs):
        t5xxl_ids = kwargs.pop("t5xxl_ids", None)
        if t5xxl_ids is not None:
            context = self.preprocess_text_embeds(context, t5xxl_ids, t5xxl_weights=kwargs.pop("t5xxl_weights", None))
        return super().forward(x, timesteps, context, **kwargs)
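Editor's aside (illustrative, not part of the diff): the padding branch added above right-pads the adapter output along the sequence dimension to 512 tokens; in F.pad the last pair of the tuple pads the second-to-last dimension:

import torch
out = torch.randn(1, 300, 4096)  # made-up adapter output
if out.shape[1] < 512:
    out = torch.nn.functional.pad(out, (0, 0, 0, 512 - out.shape[1]))
print(out.shape)                 # torch.Size([1, 512, 4096])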
@@ -3,7 +3,6 @@ from torch import Tensor, nn

from comfy.ldm.flux.layers import (
    MLPEmbedder,
    RMSNorm,
    ModulationOut,
)

@@ -29,7 +28,7 @@ class Approximator(nn.Module):
        super().__init__()
        self.in_proj = operations.Linear(in_dim, hidden_dim, bias=True, dtype=dtype, device=device)
        self.layers = nn.ModuleList([MLPEmbedder(hidden_dim, hidden_dim, dtype=dtype, device=device, operations=operations) for x in range(n_layers)])
        self.norms = nn.ModuleList([RMSNorm(hidden_dim, dtype=dtype, device=device, operations=operations) for x in range(n_layers)])
        self.norms = nn.ModuleList([operations.RMSNorm(hidden_dim, dtype=dtype, device=device) for x in range(n_layers)])
        self.out_proj = operations.Linear(hidden_dim, out_dim, dtype=dtype, device=device)

    @property
@@ -152,6 +152,7 @@ class Chroma(nn.Module):
        transformer_options={},
        attn_mask: Tensor = None,
    ) -> Tensor:
        transformer_options = transformer_options.copy()
        patches_replace = transformer_options.get("patches_replace", {})

        # running on sequences img
@@ -228,6 +229,7 @@ class Chroma(nn.Module):

        transformer_options["total_blocks"] = len(self.single_blocks)
        transformer_options["block_type"] = "single"
        transformer_options["img_slice"] = [txt.shape[1], img.shape[1]]
        for i, block in enumerate(self.single_blocks):
            transformer_options["block_index"] = i
            if i not in self.skip_dit:
@@ -4,8 +4,6 @@ from functools import lru_cache
import torch
from torch import nn

from comfy.ldm.flux.layers import RMSNorm


class NerfEmbedder(nn.Module):
    """
@@ -145,7 +143,7 @@ class NerfGLUBlock(nn.Module):
        # We now need to generate parameters for 3 matrices.
        total_params = 3 * hidden_size_x**2 * mlp_ratio
        self.param_generator = operations.Linear(hidden_size_s, total_params, dtype=dtype, device=device)
        self.norm = RMSNorm(hidden_size_x, dtype=dtype, device=device, operations=operations)
        self.norm = operations.RMSNorm(hidden_size_x, dtype=dtype, device=device)
        self.mlp_ratio = mlp_ratio


@@ -178,7 +176,7 @@ class NerfGLUBlock(nn.Module):
class NerfFinalLayer(nn.Module):
    def __init__(self, hidden_size, out_channels, dtype=None, device=None, operations=None):
        super().__init__()
        self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations)
        self.norm = operations.RMSNorm(hidden_size, dtype=dtype, device=device)
        self.linear = operations.Linear(hidden_size, out_channels, dtype=dtype, device=device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -190,7 +188,7 @@ class NerfFinalLayerConv(nn.Module):
class NerfFinalLayerConv(nn.Module):
    def __init__(self, hidden_size: int, out_channels: int, dtype=None, device=None, operations=None):
        super().__init__()
        self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations)
        self.norm = operations.RMSNorm(hidden_size, dtype=dtype, device=device)
        self.conv = operations.Conv2d(
            in_channels=hidden_size,
            out_channels=out_channels,
@@ -335,7 +335,7 @@ class FinalLayer(nn.Module):
        device=None, dtype=None, operations=None
    ):
        super().__init__()
        self.layer_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.layer_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = operations.Linear(
            hidden_size, spatial_patch_size * spatial_patch_size * temporal_patch_size * out_channels, bias=False, device=device, dtype=dtype
        )
@@ -463,6 +463,8 @@ class Block(nn.Module):
        extra_per_block_pos_emb: Optional[torch.Tensor] = None,
        transformer_options: Optional[dict] = {},
    ) -> torch.Tensor:
        residual_dtype = x_B_T_H_W_D.dtype
        compute_dtype = emb_B_T_D.dtype
        if extra_per_block_pos_emb is not None:
            x_B_T_H_W_D = x_B_T_H_W_D + extra_per_block_pos_emb

@@ -512,7 +514,7 @@ class Block(nn.Module):
        result_B_T_H_W_D = rearrange(
            self.self_attn(
                # normalized_x_B_T_HW_D,
                rearrange(normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"),
                rearrange(normalized_x_B_T_H_W_D.to(compute_dtype), "b t h w d -> b (t h w) d"),
                None,
                rope_emb=rope_emb_L_1_1_D,
                transformer_options=transformer_options,
@@ -522,7 +524,7 @@ class Block(nn.Module):
            h=H,
            w=W,
        )
        x_B_T_H_W_D = x_B_T_H_W_D + gate_self_attn_B_T_1_1_D * result_B_T_H_W_D
        x_B_T_H_W_D = x_B_T_H_W_D + gate_self_attn_B_T_1_1_D.to(residual_dtype) * result_B_T_H_W_D.to(residual_dtype)

        def _x_fn(
            _x_B_T_H_W_D: torch.Tensor,
@@ -536,7 +538,7 @@ class Block(nn.Module):
            )
            _result_B_T_H_W_D = rearrange(
                self.cross_attn(
                    rearrange(_normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"),
                    rearrange(_normalized_x_B_T_H_W_D.to(compute_dtype), "b t h w d -> b (t h w) d"),
                    crossattn_emb,
                    rope_emb=rope_emb_L_1_1_D,
                    transformer_options=transformer_options,
@@ -555,7 +557,7 @@ class Block(nn.Module):
            shift_cross_attn_B_T_1_1_D,
            transformer_options=transformer_options,
        )
        x_B_T_H_W_D = result_B_T_H_W_D * gate_cross_attn_B_T_1_1_D + x_B_T_H_W_D
        x_B_T_H_W_D = result_B_T_H_W_D.to(residual_dtype) * gate_cross_attn_B_T_1_1_D.to(residual_dtype) + x_B_T_H_W_D

        normalized_x_B_T_H_W_D = _fn(
            x_B_T_H_W_D,
@@ -563,8 +565,8 @@ class Block(nn.Module):
            scale_mlp_B_T_1_1_D,
            shift_mlp_B_T_1_1_D,
        )
        result_B_T_H_W_D = self.mlp(normalized_x_B_T_H_W_D)
        x_B_T_H_W_D = x_B_T_H_W_D + gate_mlp_B_T_1_1_D * result_B_T_H_W_D
        result_B_T_H_W_D = self.mlp(normalized_x_B_T_H_W_D.to(compute_dtype))
        x_B_T_H_W_D = x_B_T_H_W_D + gate_mlp_B_T_1_1_D.to(residual_dtype) * result_B_T_H_W_D.to(residual_dtype)
        return x_B_T_H_W_D


@@ -876,6 +878,14 @@ class MiniTrainDIT(nn.Module):
            "extra_per_block_pos_emb": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D,
            "transformer_options": kwargs.get("transformer_options", {}),
        }

        # The residual stream for this model has large values. To make fp16 compute_dtype work, we keep the residual stream
        # in fp32, but run attention and MLP modules in fp16.
        # An alternate method that clamps fp16 values "works" in the sense that it makes coherent images, but there is noticeable
        # quality degradation and visual artifacts.
        if x_B_T_H_W_D.dtype == torch.float16:
            x_B_T_H_W_D = x_B_T_H_W_D.float()

        for block in self.blocks:
            x_B_T_H_W_D = block(
                x_B_T_H_W_D,
@@ -884,6 +894,6 @@ class MiniTrainDIT(nn.Module):
            **block_kwargs,
        )

        x_B_T_H_W_O = self.final_layer(x_B_T_H_W_D, t_embedding_B_T_D, adaln_lora_B_T_3D=adaln_lora_B_T_3D)
        x_B_T_H_W_O = self.final_layer(x_B_T_H_W_D.to(crossattn_emb.dtype), t_embedding_B_T_D, adaln_lora_B_T_3D=adaln_lora_B_T_3D)
        x_B_C_Tt_Hp_Wp = self.unpatchify(x_B_T_H_W_O)[:, :, :orig_shape[-3], :orig_shape[-2], :orig_shape[-1]]
        return x_B_C_Tt_Hp_Wp
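Editor's aside (illustrative, not part of the diff): the pattern these hunks implement keeps the large-valued residual stream in fp32 while attention and MLP run in fp16, casting at each boundary:

import torch
residual = torch.zeros(2, 16, dtype=torch.float32)    # residual stream kept in fp32
module_out = torch.randn(2, 16, dtype=torch.float16)  # attn/MLP output in fp16
gate = torch.ones(2, 16, dtype=torch.float16)
residual = residual + gate.to(residual.dtype) * module_out.to(residual.dtype)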
@@ -5,9 +5,9 @@ import torch
from torch import Tensor, nn

from .math import attention, rope
import comfy.ops
import comfy.ldm.common_dit

# Fix import for some custom nodes, TODO: delete eventually.
RMSNorm = None

class EmbedND(nn.Module):
    def __init__(self, dim: int, theta: int, axes_dim: list):
@@ -87,20 +87,12 @@ def build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=False, yak_mlp=False, dt
        operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
    )

class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, dtype=None, device=None, operations=None):
        super().__init__()
        self.scale = nn.Parameter(torch.empty((dim), dtype=dtype, device=device))

    def forward(self, x: Tensor):
        return comfy.ldm.common_dit.rms_norm(x, self.scale, 1e-6)


class QKNorm(torch.nn.Module):
    def __init__(self, dim: int, dtype=None, device=None, operations=None):
        super().__init__()
        self.query_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations)
        self.key_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations)
        self.query_norm = operations.RMSNorm(dim, dtype=dtype, device=device)
        self.key_norm = operations.RMSNorm(dim, dtype=dtype, device=device)

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple:
        q = self.query_norm(q)
@@ -169,7 +161,7 @@ class SiLUActivation(nn.Module):


class DoubleStreamBlock(nn.Module):
    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, mlp_silu_act=False, proj_bias=True, yak_mlp=False, dtype=None, device=None, operations=None):
    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, modulation=True, mlp_silu_act=False, proj_bias=True, yak_mlp=False, dtype=None, device=None, operations=None):
        super().__init__()

        mlp_hidden_dim = int(hidden_size * mlp_ratio)
@@ -197,8 +189,6 @@ class DoubleStreamBlock(nn.Module):

        self.txt_mlp = build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=mlp_silu_act, yak_mlp=yak_mlp, dtype=dtype, device=device, operations=operations)

        self.flipped_img_txt = flipped_img_txt

    def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None, transformer_options={}):
        if self.modulation:
            img_mod1, img_mod2 = self.img_mod(vec)
@@ -206,6 +196,9 @@ class DoubleStreamBlock(nn.Module):
        else:
            (img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec

        transformer_patches = transformer_options.get("patches", {})
        extra_options = transformer_options.copy()

        # prepare image for attention
        img_modulated = self.img_norm1(img)
        img_modulated = apply_mod(img_modulated, (1 + img_mod1.scale), img_mod1.shift, modulation_dims_img)
@@ -224,32 +217,23 @@ class DoubleStreamBlock(nn.Module):
        del txt_qkv
        txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)

        if self.flipped_img_txt:
            q = torch.cat((img_q, txt_q), dim=2)
            del img_q, txt_q
            k = torch.cat((img_k, txt_k), dim=2)
            del img_k, txt_k
            v = torch.cat((img_v, txt_v), dim=2)
            del img_v, txt_v
            # run actual attention
            attn = attention(q, k, v,
                             pe=pe, mask=attn_mask, transformer_options=transformer_options)
            del q, k, v
        q = torch.cat((txt_q, img_q), dim=2)
        del txt_q, img_q
        k = torch.cat((txt_k, img_k), dim=2)
        del txt_k, img_k
        v = torch.cat((txt_v, img_v), dim=2)
        del txt_v, img_v
        # run actual attention
        attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
        del q, k, v

            img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1]:]
        else:
            q = torch.cat((txt_q, img_q), dim=2)
            del txt_q, img_q
            k = torch.cat((txt_k, img_k), dim=2)
            del txt_k, img_k
            v = torch.cat((txt_v, img_v), dim=2)
            del txt_v, img_v
            # run actual attention
            attn = attention(q, k, v,
                             pe=pe, mask=attn_mask, transformer_options=transformer_options)
            del q, k, v
        if "attn1_output_patch" in transformer_patches:
            extra_options["img_slice"] = [txt.shape[1], attn.shape[1]]
            patch = transformer_patches["attn1_output_patch"]
            for p in patch:
                attn = p(attn, extra_options)

            txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
        txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]

        # calculate the img blocks
        img += apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img)
@@ -328,6 +312,9 @@ class SingleStreamBlock(nn.Module):
        else:
            mod = vec

        transformer_patches = transformer_options.get("patches", {})
        extra_options = transformer_options.copy()

        qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim_first], dim=-1)

        q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
@@ -337,6 +324,12 @@ class SingleStreamBlock(nn.Module):
        # compute attention
        attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
        del q, k, v

        if "attn1_output_patch" in transformer_patches:
            patch = transformer_patches["attn1_output_patch"]
            for p in patch:
                attn = p(attn, extra_options)

        # compute activation in mlp stream, cat again and run second linear layer
        if self.yak_mlp:
            mlp = self.mlp_act(mlp[..., self.mlp_hidden_dim_first // 2:]) * mlp[..., :self.mlp_hidden_dim_first // 2]

@@ -29,19 +29,34 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
    return out.to(dtype=torch.float32, device=pos.device)


def _apply_rope1(x: Tensor, freqs_cis: Tensor):
    x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)

    x_out = freqs_cis[..., 0] * x_[..., 0]
    x_out.addcmul_(freqs_cis[..., 1], x_[..., 1])

    return x_out.reshape(*x.shape).type_as(x)


def _apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
    return _apply_rope1(xq, freqs_cis), _apply_rope1(xk, freqs_cis)


try:
    import comfy.quant_ops
    apply_rope = comfy.quant_ops.ck.apply_rope
    apply_rope1 = comfy.quant_ops.ck.apply_rope1
    q_apply_rope = comfy.quant_ops.ck.apply_rope
    q_apply_rope1 = comfy.quant_ops.ck.apply_rope1
    def apply_rope(xq, xk, freqs_cis):
        if comfy.model_management.in_training:
            return _apply_rope(xq, xk, freqs_cis)
        else:
            return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)
    def apply_rope1(x, freqs_cis):
        if comfy.model_management.in_training:
            return _apply_rope1(x, freqs_cis)
        else:
            return q_apply_rope1(x, freqs_cis)
except Exception:
    logging.warning("No comfy kitchen, using old apply_rope functions.")
    def apply_rope1(x: Tensor, freqs_cis: Tensor):
        x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)

        x_out = freqs_cis[..., 0] * x_[..., 0]
        x_out.addcmul_(freqs_cis[..., 1], x_[..., 1])

        return x_out.reshape(*x.shape).type_as(x)

    def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
        return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)
    apply_rope = _apply_rope
    apply_rope1 = _apply_rope1
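Editor's aside (illustrative, not part of the diff): the fallback rope path treats each pair of channels as 2-D coordinates and applies the 2x2 rotation entries stored in freqs_cis via an in-place addcmul_; a shape check with made-up sizes:

import torch
x = torch.randn(1, 1, 4, 8)            # (..., dim) with dim even
freqs = torch.randn(1, 1, 4, 4, 2, 2)  # (..., dim/2, 2, 2) rotation entries
x_ = x.float().reshape(*x.shape[:-1], -1, 1, 2)
out = freqs[..., 0] * x_[..., 0]
out.addcmul_(freqs[..., 1], x_[..., 1])
print(out.reshape(*x.shape).shape)     # torch.Size([1, 1, 4, 8])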
@@ -16,7 +16,6 @@ from .layers import (
    SingleStreamBlock,
    timestep_embedding,
    Modulation,
    RMSNorm
)

@dataclass
@@ -81,7 +80,7 @@ class Flux(nn.Module):
        self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device)

        if params.txt_norm:
            self.txt_norm = RMSNorm(params.context_in_dim, dtype=dtype, device=device, operations=operations)
            self.txt_norm = operations.RMSNorm(params.context_in_dim, dtype=dtype, device=device)
        else:
            self.txt_norm = None

@@ -143,6 +142,7 @@ class Flux(nn.Module):
        attn_mask: Tensor = None,
    ) -> Tensor:

        transformer_options = transformer_options.copy()
        patches = transformer_options.get("patches", {})
        patches_replace = transformer_options.get("patches_replace", {})
        if img.ndim != 3 or txt.ndim != 3:
@@ -232,6 +232,7 @@ class Flux(nn.Module):

        transformer_options["total_blocks"] = len(self.single_blocks)
        transformer_options["block_type"] = "single"
        transformer_options["img_slice"] = [txt.shape[1], img.shape[1]]
        for i, block in enumerate(self.single_blocks):
            transformer_options["block_index"] = i
            if ("single_block", i) in blocks_replace:
@@ -241,7 +241,6 @@ class HunyuanVideo(nn.Module):
                self.num_heads,
                mlp_ratio=params.mlp_ratio,
                qkv_bias=params.qkv_bias,
                flipped_img_txt=True,
                dtype=dtype, device=device, operations=operations
            )
            for _ in range(params.depth)
@@ -305,6 +304,7 @@ class HunyuanVideo(nn.Module):
        control=None,
        transformer_options={},
    ) -> Tensor:
        transformer_options = transformer_options.copy()
        patches_replace = transformer_options.get("patches_replace", {})

        initial_shape = list(img.shape)
@@ -378,14 +378,14 @@ class HunyuanVideo(nn.Module):
            extra_txt_ids = torch.zeros((txt_ids.shape[0], txt_vision_states.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype)
            txt_ids = torch.cat((txt_ids, extra_txt_ids), dim=1)

        ids = torch.cat((img_ids, txt_ids), dim=1)
        ids = torch.cat((txt_ids, img_ids), dim=1)
        pe = self.pe_embedder(ids)

        img_len = img.shape[1]
        if txt_mask is not None:
            attn_mask_len = img_len + txt.shape[1]
            attn_mask = torch.zeros((1, 1, attn_mask_len), dtype=img.dtype, device=img.device)
            attn_mask[:, 0, img_len:] = txt_mask
            attn_mask[:, 0, :txt.shape[1]] = txt_mask
        else:
            attn_mask = None

@@ -413,10 +413,11 @@ class HunyuanVideo(nn.Module):
                if add is not None:
                    img += add

        img = torch.cat((img, txt), 1)
        img = torch.cat((txt, img), 1)

        transformer_options["total_blocks"] = len(self.single_blocks)
        transformer_options["block_type"] = "single"
        transformer_options["img_slice"] = [txt.shape[1], img.shape[1]]
        for i, block in enumerate(self.single_blocks):
            transformer_options["block_index"] = i
            if ("single_block", i) in blocks_replace:
@@ -435,9 +436,9 @@ class HunyuanVideo(nn.Module):
                if i < len(control_o):
                    add = control_o[i]
                    if add is not None:
                        img[:, : img_len] += add
                        img[:, txt.shape[1]: img_len + txt.shape[1]] += add

        img = img[:, : img_len]
        img = img[:, txt.shape[1]: img_len + txt.shape[1]]
        if ref_latent is not None:
            img = img[:, ref_latent.shape[1]:]
@@ -9,6 +9,7 @@ from comfy.ldm.lightricks.model import (
    LTXVModel,
)
from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector
import comfy.ldm.common_dit

class CompressedTimestep:
@@ -217,7 +218,7 @@ class BasicAVTransformerBlock(nn.Module):
    def forward(
        self, x: Tuple[torch.Tensor, torch.Tensor], v_context=None, a_context=None, attention_mask=None, v_timestep=None, a_timestep=None,
        v_pe=None, a_pe=None, v_cross_pe=None, a_cross_pe=None, v_cross_scale_shift_timestep=None, a_cross_scale_shift_timestep=None,
        v_cross_gate_timestep=None, a_cross_gate_timestep=None, transformer_options=None,
        v_cross_gate_timestep=None, a_cross_gate_timestep=None, transformer_options=None, self_attention_mask=None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        run_vx = transformer_options.get("run_vx", True)
        run_ax = transformer_options.get("run_ax", True)
@@ -233,7 +234,7 @@ class BasicAVTransformerBlock(nn.Module):
            vshift_msa, vscale_msa = (self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(0, 2)))
            norm_vx = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_msa) + vshift_msa
            del vshift_msa, vscale_msa
            attn1_out = self.attn1(norm_vx, pe=v_pe, transformer_options=transformer_options)
            attn1_out = self.attn1(norm_vx, pe=v_pe, mask=self_attention_mask, transformer_options=transformer_options)
            del norm_vx
            # video cross-attention
            vgate_msa = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(2, 3))[0]
@@ -450,6 +451,29 @@ class LTXAVModel(LTXVModel):
            operations=self.operations,
        )

        self.audio_embeddings_connector = Embeddings1DConnector(
            split_rope=True,
            double_precision_rope=True,
            dtype=dtype,
            device=device,
            operations=self.operations,
        )

        self.video_embeddings_connector = Embeddings1DConnector(
            split_rope=True,
            double_precision_rope=True,
            dtype=dtype,
            device=device,
            operations=self.operations,
        )

    def preprocess_text_embeds(self, context):
        if context.shape[-1] == self.caption_channels * 2:
            return context
        out_vid = self.video_embeddings_connector(context)[0]
        out_audio = self.audio_embeddings_connector(context)[0]
        return torch.concat((out_vid, out_audio), dim=-1)

    def _init_transformer_blocks(self, device, dtype, **kwargs):
        """Initialize transformer blocks for LTXAV."""
        self.transformer_blocks = nn.ModuleList(
@@ -702,7 +726,7 @@ class LTXAVModel(LTXVModel):
        return [(v_pe, av_cross_video_freq_cis), (a_pe, av_cross_audio_freq_cis)]

    def _process_transformer_blocks(
        self, x, context, attention_mask, timestep, pe, transformer_options={}, **kwargs
        self, x, context, attention_mask, timestep, pe, transformer_options={}, self_attention_mask=None, **kwargs
    ):
        vx = x[0]
        ax = x[1]
@@ -746,6 +770,7 @@ class LTXAVModel(LTXVModel):
                    v_cross_gate_timestep=args["v_cross_gate_timestep"],
                    a_cross_gate_timestep=args["a_cross_gate_timestep"],
                    transformer_options=args["transformer_options"],
                    self_attention_mask=args.get("self_attention_mask"),
                )
                return out

@@ -766,6 +791,7 @@ class LTXAVModel(LTXVModel):
                        "v_cross_gate_timestep": av_ca_a2v_gate_noise_timestep,
                        "a_cross_gate_timestep": av_ca_v2a_gate_noise_timestep,
                        "transformer_options": transformer_options,
                        "self_attention_mask": self_attention_mask,
                    },
                    {"original_block": block_wrap},
                )
@@ -787,6 +813,7 @@ class LTXAVModel(LTXVModel):
                    v_cross_gate_timestep=av_ca_a2v_gate_noise_timestep,
                    a_cross_gate_timestep=av_ca_v2a_gate_noise_timestep,
                    transformer_options=transformer_options,
                    self_attention_mask=self_attention_mask,
                )

        return [vx, ax]
Some files were not shown because too many files have changed in this diff.