mirror of
https://github.com/Comfy-Org/ComfyUI.git
synced 2026-03-02 14:19:17 +00:00
Compare commits
293 Commits
v0.9.0
...
pyisolate-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a0f8784e9f | ||
|
|
95e1059661 | ||
|
|
80d49441e5 | ||
|
|
9d0e114ee3 | ||
|
|
ac4412d0fa | ||
|
|
94f1a1cc9d | ||
|
|
e721e24136 | ||
|
|
25ec3d96a3 | ||
|
|
7962db477a | ||
|
|
3c8ba051b6 | ||
|
|
a1c3124821 | ||
|
|
9ca799362d | ||
|
|
22f5e43c12 | ||
|
|
3cfd5e3311 | ||
|
|
1f1ec377ce | ||
|
|
0a7f8e11b6 | ||
|
|
35e9fce775 | ||
|
|
c7f7d52b68 | ||
|
|
08b26ed7c2 | ||
|
|
b233dbe0bc | ||
|
|
3811780e4f | ||
|
|
3dd10a59c0 | ||
|
|
88d05fe483 | ||
|
|
fd41ec97cc | ||
|
|
420e900f69 | ||
|
|
38ca94599f | ||
|
|
74b5a337dc | ||
|
|
8a4d85c708 | ||
|
|
a4522017c5 | ||
|
|
907e5dcbbf | ||
|
|
7253531670 | ||
|
|
e14b04478c | ||
|
|
eb8737d675 | ||
|
|
0467f690a8 | ||
|
|
4f5b7dbf1f | ||
|
|
3ebe1ac22e | ||
|
|
befa83d434 | ||
|
|
33f83d53ae | ||
|
|
b874bd2b8c | ||
|
|
0aa02453bb | ||
|
|
599f9c5010 | ||
|
|
11fefa58e9 | ||
|
|
d8090013b8 | ||
|
|
048dd2f321 | ||
|
|
84aba95e03 | ||
|
|
9b1c63eb69 | ||
|
|
7a7debcaf1 | ||
|
|
dba2766e53 | ||
|
|
caa43d2395 | ||
|
|
07ca6852e8 | ||
|
|
f266b8d352 | ||
|
|
b6cb30bab5 | ||
|
|
ee72752162 | ||
|
|
7591d781a7 | ||
|
|
0bfb936ab4 | ||
|
|
602b2505a4 | ||
|
|
04a55d5019 | ||
|
|
5fb8f06495 | ||
|
|
5a182bfaf1 | ||
|
|
f394af8d0f | ||
|
|
aeb5bdc8f6 | ||
|
|
64953bda0a | ||
|
|
b254cecd03 | ||
|
|
1bb956fb66 | ||
|
|
96d6bd1a4a | ||
|
|
5f2117528a | ||
|
|
0301ccf745 | ||
|
|
4d172e9ad7 | ||
|
|
5632b2df9d | ||
|
|
2687652530 | ||
|
|
6d11cc7354 | ||
|
|
f262444dd4 | ||
|
|
239ddd3327 | ||
|
|
83dd65f23a | ||
|
|
8ad38d2073 | ||
|
|
6c14f129af | ||
|
|
58dcc97dcf | ||
|
|
19236edfa4 | ||
|
|
73c3f86973 | ||
|
|
262abf437b | ||
|
|
5284e6bf69 | ||
|
|
44f8598521 | ||
|
|
fe52843fe5 | ||
|
|
c39653163d | ||
|
|
18927538a1 | ||
|
|
8a6fbc2dc2 | ||
|
|
b44fc4c589 | ||
|
|
4454fab7f0 | ||
|
|
1978f59ffd | ||
|
|
88e6370527 | ||
|
|
c0370044cd | ||
|
|
ecd2a19661 | ||
|
|
2c1d06a4e3 | ||
|
|
e2c71ceb00 | ||
|
|
596ed68691 | ||
|
|
ce4a1ab48d | ||
|
|
e1ede29d82 | ||
|
|
df1e5e8514 | ||
|
|
dc9822b7df | ||
|
|
712efb466b | ||
|
|
726af73867 | ||
|
|
831351a29e | ||
|
|
e1add563f9 | ||
|
|
8902907d7a | ||
|
|
e03fe8b591 | ||
|
|
ae79e33345 | ||
|
|
117e214354 | ||
|
|
4a93a62371 | ||
|
|
66c18522fb | ||
|
|
e5ae670a40 | ||
|
|
3fe61cedda | ||
|
|
2a4328d639 | ||
|
|
d297a749a2 | ||
|
|
2b7cc7e3b6 | ||
|
|
4993411fd9 | ||
|
|
2c7cef4a23 | ||
|
|
76a7fa96db | ||
|
|
cdcf4119b3 | ||
|
|
dbe70b6821 | ||
|
|
00fff6019e | ||
|
|
123a7874a9 | ||
|
|
f719f9c062 | ||
|
|
fe053ba5eb | ||
|
|
6648ab68bc | ||
|
|
6615db925c | ||
|
|
8ca842a8ed | ||
|
|
c1b63a7e78 | ||
|
|
349a636a2b | ||
|
|
a4be04c5d7 | ||
|
|
baf8c87455 | ||
|
|
62315fbb15 | ||
|
|
a0302cc6a8 | ||
|
|
f350a84261 | ||
|
|
3760d74005 | ||
|
|
9bf5aa54db | ||
|
|
5ff4fdedba | ||
|
|
17e7df43d1 | ||
|
|
039955c527 | ||
|
|
6a26328842 | ||
|
|
204e65b8dc | ||
|
|
a831c19b70 | ||
|
|
eba6c940fd | ||
|
|
a1c101f861 | ||
|
|
c2d7f07dbf | ||
|
|
458292fef0 | ||
|
|
6555dc65b8 | ||
|
|
2b70ab9ad0 | ||
|
|
00efcc6cd0 | ||
|
|
cb459573c8 | ||
|
|
35183543e0 | ||
|
|
a246cc02b2 | ||
|
|
a50c32d63f | ||
|
|
6125b80979 | ||
|
|
c8fcbd66ee | ||
|
|
26dd7eb421 | ||
|
|
e77b34dfea | ||
|
|
ef73070ea4 | ||
|
|
d30c609f5a | ||
|
|
5087f1d497 | ||
|
|
a31681564d | ||
|
|
855849c658 | ||
|
|
fe2511468d | ||
|
|
3be0175166 | ||
|
|
b8315e66cb | ||
|
|
ab1050bec3 | ||
|
|
fb23935c11 | ||
|
|
85fc35e8fa | ||
|
|
223364743c | ||
|
|
affe881354 | ||
|
|
f5030e26fd | ||
|
|
66e1b07402 | ||
|
|
be4345d1c9 | ||
|
|
3c1a1a2df8 | ||
|
|
ba5bf3f1a8 | ||
|
|
c05a08ae66 | ||
|
|
de9ada6a41 | ||
|
|
37f711d4a1 | ||
|
|
dd86b15521 | ||
|
|
021ba20719 | ||
|
|
b60be02aaf | ||
|
|
2b5da3b72e | ||
|
|
794d05bdb1 | ||
|
|
361b9a82a3 | ||
|
|
667a1b8878 | ||
|
|
32621c6a11 | ||
|
|
f8acd9c402 | ||
|
|
873de5f37a | ||
|
|
aa6f7a83bb | ||
|
|
6ea8c128a3 | ||
|
|
6e469a3f35 | ||
|
|
b8f848bfe3 | ||
|
|
4064062e7d | ||
|
|
8aabe2403e | ||
|
|
0167653781 | ||
|
|
0a7993729c | ||
|
|
bbe2c13a70 | ||
|
|
3aace5c8dc | ||
|
|
b0d9708974 | ||
|
|
c9b633d84f | ||
|
|
1711020904 | ||
|
|
d9b8567547 | ||
|
|
6c5f906bf2 | ||
|
|
4f5bd39b1c | ||
|
|
dcff27fe3f | ||
|
|
09725967cf | ||
|
|
5f62440fbb | ||
|
|
ac91c340f4 | ||
|
|
2db3b0ff90 | ||
|
|
6516ab335d | ||
|
|
ad53e78f11 | ||
|
|
29011ba87e | ||
|
|
cd4985e2f3 | ||
|
|
bfe31d0b9d | ||
|
|
2129e7d278 | ||
|
|
7ee77ff038 | ||
|
|
26c5bbb875 | ||
|
|
a97c98068f | ||
|
|
635406e283 | ||
|
|
ed6002cb60 | ||
|
|
bc72d7f8d1 | ||
|
|
aef4e13588 | ||
|
|
4e6a1b66a9 | ||
|
|
9cf299a9f9 | ||
|
|
e89b22993a | ||
|
|
55bd606e92 | ||
|
|
79cdbc81cb | ||
|
|
f443b9f2ca | ||
|
|
4e3038114a | ||
|
|
bbb8864778 | ||
|
|
d7f3241bf6 | ||
|
|
09a2e67151 | ||
|
|
0fd1b78736 | ||
|
|
8490eedadf | ||
|
|
72f6be1690 | ||
|
|
16b9aabd52 | ||
|
|
245f6139b6 | ||
|
|
3365ad18a5 | ||
|
|
f09904720d | ||
|
|
abe2ec26a6 | ||
|
|
bdeac8897e | ||
|
|
451af70154 | ||
|
|
0fc15700be | ||
|
|
e755268e7b | ||
|
|
c4a14df9a3 | ||
|
|
965d0ed509 | ||
|
|
ddc541ffda | ||
|
|
8ccc0c94fa | ||
|
|
4edb87aa50 | ||
|
|
0fc3b6e3a6 | ||
|
|
2108167f9f | ||
|
|
9d273d3ab1 | ||
|
|
70c91b8248 | ||
|
|
0da5a0fe58 | ||
|
|
e0eacb0688 | ||
|
|
7458e20465 | ||
|
|
b931b37e30 | ||
|
|
866a4619db | ||
|
|
1a72bf2046 | ||
|
|
034fac7054 | ||
|
|
a498556d0d | ||
|
|
f7ca41ff62 | ||
|
|
ac26065e61 | ||
|
|
190c4416cc | ||
|
|
0fd10ffa09 | ||
|
|
00c775950a | ||
|
|
7ac999bf30 | ||
|
|
0c6b36c6ac | ||
|
|
9125613b53 | ||
|
|
732b707397 | ||
|
|
4c816d5c69 | ||
|
|
6125b3a5e7 | ||
|
|
12918a5f78 | ||
|
|
8f40b43e02 | ||
|
|
3b832231bb | ||
|
|
be518db5a7 | ||
|
|
80441eb15e | ||
|
|
07f2462eae | ||
|
|
d150440466 | ||
|
|
6165c38cb5 | ||
|
|
712cca36a1 | ||
|
|
ac4d8ea9b3 | ||
|
|
c9196f355e | ||
|
|
7eb959ce93 | ||
|
|
469dd9c16a | ||
|
|
eff2b9d412 | ||
|
|
15b312de7a | ||
|
|
1419047fdb | ||
|
|
79f6bb5e4f | ||
|
|
e4b4fb3479 | ||
|
|
d9dc02a7d6 | ||
|
|
c543ad81c3 | ||
|
|
5ac1372533 | ||
|
|
1dcbd9efaf |
127
.coderabbit.yaml
Normal file
127
.coderabbit.yaml
Normal file
@@ -0,0 +1,127 @@
|
||||
# yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json
|
||||
language: "en-US"
|
||||
early_access: false
|
||||
tone_instructions: "Only comment on issues introduced by this PR's changes. Do not flag pre-existing problems in moved, re-indented, or reformatted code."
|
||||
|
||||
reviews:
|
||||
profile: "chill"
|
||||
request_changes_workflow: false
|
||||
high_level_summary: false
|
||||
poem: false
|
||||
review_status: false
|
||||
review_details: false
|
||||
commit_status: true
|
||||
collapse_walkthrough: true
|
||||
changed_files_summary: false
|
||||
sequence_diagrams: false
|
||||
estimate_code_review_effort: false
|
||||
assess_linked_issues: false
|
||||
related_issues: false
|
||||
related_prs: false
|
||||
suggested_labels: false
|
||||
auto_apply_labels: false
|
||||
suggested_reviewers: false
|
||||
auto_assign_reviewers: false
|
||||
in_progress_fortune: false
|
||||
enable_prompt_for_ai_agents: true
|
||||
|
||||
path_filters:
|
||||
- "!comfy_api_nodes/apis/**"
|
||||
- "!**/generated/*.pyi"
|
||||
- "!.ci/**"
|
||||
- "!script_examples/**"
|
||||
- "!**/__pycache__/**"
|
||||
- "!**/*.ipynb"
|
||||
- "!**/*.png"
|
||||
- "!**/*.bat"
|
||||
|
||||
path_instructions:
|
||||
- path: "**"
|
||||
instructions: |
|
||||
IMPORTANT: Only comment on issues directly introduced by this PR's code changes.
|
||||
Do NOT flag pre-existing issues in code that was merely moved, re-indented,
|
||||
de-indented, or reformatted without logic changes. If code appears in the diff
|
||||
only due to whitespace or structural reformatting (e.g., removing a `with:` block),
|
||||
treat it as unchanged. Contributors should not feel obligated to address
|
||||
pre-existing issues outside the scope of their contribution.
|
||||
- path: "comfy/**"
|
||||
instructions: |
|
||||
Core ML/diffusion engine. Focus on:
|
||||
- Backward compatibility (breaking changes affect all custom nodes)
|
||||
- Memory management and GPU resource handling
|
||||
- Performance implications in hot paths
|
||||
- Thread safety for concurrent execution
|
||||
- path: "comfy_api_nodes/**"
|
||||
instructions: |
|
||||
Third-party API integration nodes. Focus on:
|
||||
- No hardcoded API keys or secrets
|
||||
- Proper error handling for API failures (timeouts, rate limits, auth errors)
|
||||
- Correct Pydantic model usage
|
||||
- Security of user data passed to external APIs
|
||||
- path: "comfy_extras/**"
|
||||
instructions: |
|
||||
Community-contributed extra nodes. Focus on:
|
||||
- Consistency with node patterns (INPUT_TYPES, RETURN_TYPES, FUNCTION, CATEGORY)
|
||||
- No breaking changes to existing node interfaces
|
||||
- path: "comfy_execution/**"
|
||||
instructions: |
|
||||
Execution engine (graph execution, caching, jobs). Focus on:
|
||||
- Caching correctness
|
||||
- Concurrent execution safety
|
||||
- Graph validation edge cases
|
||||
- path: "nodes.py"
|
||||
instructions: |
|
||||
Core node definitions (2500+ lines). Focus on:
|
||||
- Backward compatibility of NODE_CLASS_MAPPINGS
|
||||
- Consistency of INPUT_TYPES return format
|
||||
- path: "alembic_db/**"
|
||||
instructions: |
|
||||
Database migrations. Focus on:
|
||||
- Migration safety and rollback support
|
||||
- Data preservation during schema changes
|
||||
|
||||
auto_review:
|
||||
enabled: true
|
||||
auto_incremental_review: true
|
||||
drafts: false
|
||||
ignore_title_keywords:
|
||||
- "WIP"
|
||||
- "DO NOT REVIEW"
|
||||
- "DO NOT MERGE"
|
||||
|
||||
finishing_touches:
|
||||
docstrings:
|
||||
enabled: false
|
||||
unit_tests:
|
||||
enabled: false
|
||||
|
||||
tools:
|
||||
ruff:
|
||||
enabled: false
|
||||
pylint:
|
||||
enabled: false
|
||||
flake8:
|
||||
enabled: false
|
||||
gitleaks:
|
||||
enabled: true
|
||||
shellcheck:
|
||||
enabled: false
|
||||
markdownlint:
|
||||
enabled: false
|
||||
yamllint:
|
||||
enabled: false
|
||||
languagetool:
|
||||
enabled: false
|
||||
github-checks:
|
||||
enabled: true
|
||||
timeout_ms: 90000
|
||||
ast-grep:
|
||||
essential_rules: true
|
||||
|
||||
chat:
|
||||
auto_reply: true
|
||||
|
||||
knowledge_base:
|
||||
opt_out: false
|
||||
learnings:
|
||||
scope: "auto"
|
||||
2
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
2
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
@@ -16,7 +16,7 @@ body:
|
||||
|
||||
## Very Important
|
||||
|
||||
Please make sure that you post ALL your ComfyUI logs in the bug report. A bug report without logs will likely be ignored.
|
||||
Please make sure that you post ALL your ComfyUI logs in the bug report **even if there is no crash**. Just paste everything. The startup log (everything before "To see the GUI go to: ...") contains critical information to developers trying to help. For a performance issue or crash, paste everything from "got prompt" to the end, including the crash. More is better - always. A bug report without logs will likely be ignored.
|
||||
- type: checkboxes
|
||||
id: custom-nodes-test
|
||||
attributes:
|
||||
|
||||
6
.github/workflows/release-stable-all.yml
vendored
6
.github/workflows/release-stable-all.yml
vendored
@@ -20,7 +20,7 @@ jobs:
|
||||
git_tag: ${{ inputs.git_tag }}
|
||||
cache_tag: "cu130"
|
||||
python_minor: "13"
|
||||
python_patch: "9"
|
||||
python_patch: "11"
|
||||
rel_name: "nvidia"
|
||||
rel_extra_name: ""
|
||||
test_release: true
|
||||
@@ -65,11 +65,11 @@ jobs:
|
||||
contents: "write"
|
||||
packages: "write"
|
||||
pull-requests: "read"
|
||||
name: "Release AMD ROCm 7.1.1"
|
||||
name: "Release AMD ROCm 7.2"
|
||||
uses: ./.github/workflows/stable-release.yml
|
||||
with:
|
||||
git_tag: ${{ inputs.git_tag }}
|
||||
cache_tag: "rocm711"
|
||||
cache_tag: "rocm72"
|
||||
python_minor: "12"
|
||||
python_patch: "10"
|
||||
rel_name: "amd"
|
||||
|
||||
36
.github/workflows/release-webhook.yml
vendored
36
.github/workflows/release-webhook.yml
vendored
@@ -7,6 +7,8 @@ on:
|
||||
jobs:
|
||||
send-webhook:
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
DESKTOP_REPO_DISPATCH_TOKEN: ${{ secrets.DESKTOP_REPO_DISPATCH_TOKEN }}
|
||||
steps:
|
||||
- name: Send release webhook
|
||||
env:
|
||||
@@ -106,3 +108,37 @@ jobs:
|
||||
--fail --silent --show-error
|
||||
|
||||
echo "✅ Release webhook sent successfully"
|
||||
|
||||
- name: Send repository dispatch to desktop
|
||||
env:
|
||||
DISPATCH_TOKEN: ${{ env.DESKTOP_REPO_DISPATCH_TOKEN }}
|
||||
RELEASE_TAG: ${{ github.event.release.tag_name }}
|
||||
RELEASE_URL: ${{ github.event.release.html_url }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
if [ -z "${DISPATCH_TOKEN:-}" ]; then
|
||||
echo "::error::DESKTOP_REPO_DISPATCH_TOKEN is required but not set."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
PAYLOAD="$(jq -n \
|
||||
--arg release_tag "$RELEASE_TAG" \
|
||||
--arg release_url "$RELEASE_URL" \
|
||||
'{
|
||||
event_type: "comfyui_release_published",
|
||||
client_payload: {
|
||||
release_tag: $release_tag,
|
||||
release_url: $release_url
|
||||
}
|
||||
}')"
|
||||
|
||||
curl -fsSL \
|
||||
-X POST \
|
||||
-H "Accept: application/vnd.github+json" \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "Authorization: Bearer ${DISPATCH_TOKEN}" \
|
||||
https://api.github.com/repos/Comfy-Org/desktop/dispatches \
|
||||
-d "$PAYLOAD"
|
||||
|
||||
echo "✅ Dispatched ComfyUI release ${RELEASE_TAG} to Comfy-Org/desktop"
|
||||
|
||||
2
.github/workflows/test-launch.yml
vendored
2
.github/workflows/test-launch.yml
vendored
@@ -13,7 +13,7 @@ jobs:
|
||||
- name: Checkout ComfyUI
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
repository: "comfyanonymous/ComfyUI"
|
||||
repository: "Comfy-Org/ComfyUI"
|
||||
path: "ComfyUI"
|
||||
- uses: actions/setup-python@v4
|
||||
with:
|
||||
|
||||
59
.github/workflows/update-ci-container.yml
vendored
Normal file
59
.github/workflows/update-ci-container.yml
vendored
Normal file
@@ -0,0 +1,59 @@
|
||||
name: "CI: Update CI Container"
|
||||
|
||||
on:
|
||||
release:
|
||||
types: [published]
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
version:
|
||||
description: 'ComfyUI version (e.g., v0.7.0)'
|
||||
required: true
|
||||
type: string
|
||||
|
||||
jobs:
|
||||
update-ci-container:
|
||||
runs-on: ubuntu-latest
|
||||
# Skip pre-releases unless manually triggered
|
||||
if: github.event_name == 'workflow_dispatch' || !github.event.release.prerelease
|
||||
steps:
|
||||
- name: Get version
|
||||
id: version
|
||||
run: |
|
||||
if [ "${{ github.event_name }}" = "release" ]; then
|
||||
VERSION="${{ github.event.release.tag_name }}"
|
||||
else
|
||||
VERSION="${{ inputs.version }}"
|
||||
fi
|
||||
echo "version=$VERSION" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Checkout comfyui-ci-container
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
repository: comfy-org/comfyui-ci-container
|
||||
token: ${{ secrets.CI_CONTAINER_PAT }}
|
||||
|
||||
- name: Check current version
|
||||
id: current
|
||||
run: |
|
||||
CURRENT=$(grep -oP 'ARG COMFYUI_VERSION=\K.*' Dockerfile || echo "unknown")
|
||||
echo "current_version=$CURRENT" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Update Dockerfile
|
||||
run: |
|
||||
VERSION="${{ steps.version.outputs.version }}"
|
||||
sed -i "s/^ARG COMFYUI_VERSION=.*/ARG COMFYUI_VERSION=${VERSION}/" Dockerfile
|
||||
|
||||
- name: Create Pull Request
|
||||
id: create-pr
|
||||
uses: peter-evans/create-pull-request@v7
|
||||
with:
|
||||
token: ${{ secrets.CI_CONTAINER_PAT }}
|
||||
branch: automation/comfyui-${{ steps.version.outputs.version }}
|
||||
title: "chore: bump ComfyUI to ${{ steps.version.outputs.version }}"
|
||||
body: |
|
||||
Updates ComfyUI version from `${{ steps.current.outputs.current_version }}` to `${{ steps.version.outputs.version }}`
|
||||
|
||||
**Triggered by:** ${{ github.event_name == 'release' && format('[Release {0}]({1})', github.event.release.tag_name, github.event.release.html_url) || 'Manual workflow dispatch' }}
|
||||
|
||||
labels: automation
|
||||
commit-message: "chore: bump ComfyUI to ${{ steps.version.outputs.version }}"
|
||||
@@ -29,7 +29,7 @@ on:
|
||||
description: 'python patch version'
|
||||
required: true
|
||||
type: string
|
||||
default: "9"
|
||||
default: "11"
|
||||
# push:
|
||||
# branches:
|
||||
# - master
|
||||
|
||||
2
.gitignore
vendored
2
.gitignore
vendored
@@ -11,7 +11,7 @@ extra_model_paths.yaml
|
||||
/.vs
|
||||
.vscode/
|
||||
.idea/
|
||||
venv/
|
||||
venv*/
|
||||
.venv/
|
||||
/web/extensions/*
|
||||
!/web/extensions/logging.js.example
|
||||
|
||||
16
README.md
16
README.md
@@ -108,7 +108,7 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
|
||||
- [LCM models and Loras](https://comfyanonymous.github.io/ComfyUI_examples/lcm/)
|
||||
- Latent previews with [TAESD](#how-to-show-high-quality-previews)
|
||||
- Works fully offline: core will never download anything unless you want to.
|
||||
- Optional API nodes to use paid models from external providers through the online [Comfy API](https://docs.comfy.org/tutorials/api-nodes/overview).
|
||||
- Optional API nodes to use paid models from external providers through the online [Comfy API](https://docs.comfy.org/tutorials/api-nodes/overview) disable with: `--disable-api-nodes`
|
||||
- [Config file](extra_model_paths.yaml.example) to set the search paths for models.
|
||||
|
||||
Workflow examples can be found on the [Examples page](https://comfyanonymous.github.io/ComfyUI_examples/)
|
||||
@@ -189,8 +189,6 @@ The portable above currently comes with python 3.13 and pytorch cuda 13.0. Updat
|
||||
|
||||
[Experimental portable for AMD GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_amd.7z)
|
||||
|
||||
[Portable with pytorch cuda 12.8 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu128.7z).
|
||||
|
||||
[Portable with pytorch cuda 12.6 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu126.7z) (Supports Nvidia 10 series and older GPUs).
|
||||
|
||||
#### How do I share models between another UI and ComfyUI?
|
||||
@@ -208,11 +206,11 @@ comfy install
|
||||
|
||||
## Manual Install (Windows, Linux)
|
||||
|
||||
Python 3.14 works but you may encounter issues with the torch compile node. The free threaded variant is still missing some dependencies.
|
||||
Python 3.14 works but some custom nodes may have issues. The free threaded variant works but some dependencies will enable the GIL so it's not fully supported.
|
||||
|
||||
Python 3.13 is very well supported. If you have trouble with some custom node dependencies on 3.13 you can try 3.12
|
||||
|
||||
torch 2.4 and above is supported but some features might only work on newer versions. We generally recommend using the latest major version of pytorch with the latest cuda version unless it is less than 2 weeks old.
|
||||
torch 2.4 and above is supported but some features and optimizations might only work on newer versions. We generally recommend using the latest major version of pytorch with the latest cuda version unless it is less than 2 weeks old.
|
||||
|
||||
### Instructions:
|
||||
|
||||
@@ -227,11 +225,11 @@ Put your VAE in: models/vae
|
||||
|
||||
AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:
|
||||
|
||||
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.4```
|
||||
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm7.1```
|
||||
|
||||
This is the command to install the nightly with ROCm 7.0 which might have some performance improvements:
|
||||
This is the command to install the nightly with ROCm 7.2 which might have some performance improvements:
|
||||
|
||||
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm7.1```
|
||||
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm7.2```
|
||||
|
||||
|
||||
### AMD GPUs (Experimental: Windows and Linux), RDNA 3, 3.5 and 4 only.
|
||||
@@ -240,7 +238,7 @@ These have less hardware support than the builds above but they work on windows.
|
||||
|
||||
RDNA 3 (RX 7000 series):
|
||||
|
||||
```pip install --pre torch torchvision torchaudio --index-url https://rocm.nightlies.amd.com/v2/gfx110X-dgpu/```
|
||||
```pip install --pre torch torchvision torchaudio --index-url https://rocm.nightlies.amd.com/v2/gfx110X-all/```
|
||||
|
||||
RDNA 3.5 (Strix halo/Ryzen AI Max+ 365):
|
||||
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
import logging
|
||||
import uuid
|
||||
import urllib.parse
|
||||
import os
|
||||
import contextlib
|
||||
from aiohttp import web
|
||||
|
||||
from pydantic import ValidationError
|
||||
@@ -8,6 +11,9 @@ import app.assets.manager as manager
|
||||
from app import user_manager
|
||||
from app.assets.api import schemas_in
|
||||
from app.assets.helpers import get_query_dict
|
||||
from app.assets.scanner import seed_assets
|
||||
|
||||
import folder_paths
|
||||
|
||||
ROUTES = web.RouteTableDef()
|
||||
USER_MANAGER: user_manager.UserManager | None = None
|
||||
@@ -15,6 +21,9 @@ USER_MANAGER: user_manager.UserManager | None = None
|
||||
# UUID regex (canonical hyphenated form, case-insensitive)
|
||||
UUID_RE = r"[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}"
|
||||
|
||||
# Note to any custom node developers reading this code:
|
||||
# The assets system is not yet fully implemented, do not rely on the code in /app/assets remaining the same.
|
||||
|
||||
def register_assets_system(app: web.Application, user_manager_instance: user_manager.UserManager) -> None:
|
||||
global USER_MANAGER
|
||||
USER_MANAGER = user_manager_instance
|
||||
@@ -28,6 +37,18 @@ def _validation_error_response(code: str, ve: ValidationError) -> web.Response:
|
||||
return _error_response(400, code, "Validation failed.", {"errors": ve.json()})
|
||||
|
||||
|
||||
@ROUTES.head("/api/assets/hash/{hash}")
|
||||
async def head_asset_by_hash(request: web.Request) -> web.Response:
|
||||
hash_str = request.match_info.get("hash", "").strip().lower()
|
||||
if not hash_str or ":" not in hash_str:
|
||||
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
|
||||
algo, digest = hash_str.split(":", 1)
|
||||
if algo != "blake3" or not digest or any(c for c in digest if c not in "0123456789abcdef"):
|
||||
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
|
||||
exists = manager.asset_exists(asset_hash=hash_str)
|
||||
return web.Response(status=200 if exists else 404)
|
||||
|
||||
|
||||
@ROUTES.get("/api/assets")
|
||||
async def list_assets(request: web.Request) -> web.Response:
|
||||
"""
|
||||
@@ -50,7 +71,7 @@ async def list_assets(request: web.Request) -> web.Response:
|
||||
order=q.order,
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
return web.json_response(payload.model_dump(mode="json"))
|
||||
return web.json_response(payload.model_dump(mode="json", exclude_none=True))
|
||||
|
||||
|
||||
@ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}")
|
||||
@@ -76,6 +97,314 @@ async def get_asset(request: web.Request) -> web.Response:
|
||||
return web.json_response(result.model_dump(mode="json"), status=200)
|
||||
|
||||
|
||||
@ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}/content")
|
||||
async def download_asset_content(request: web.Request) -> web.Response:
|
||||
# question: do we need disposition? could we just stick with one of these?
|
||||
disposition = request.query.get("disposition", "attachment").lower().strip()
|
||||
if disposition not in {"inline", "attachment"}:
|
||||
disposition = "attachment"
|
||||
|
||||
try:
|
||||
abs_path, content_type, filename = manager.resolve_asset_content_for_download(
|
||||
asset_info_id=str(uuid.UUID(request.match_info["id"])),
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
except ValueError as ve:
|
||||
return _error_response(404, "ASSET_NOT_FOUND", str(ve))
|
||||
except NotImplementedError as nie:
|
||||
return _error_response(501, "BACKEND_UNSUPPORTED", str(nie))
|
||||
except FileNotFoundError:
|
||||
return _error_response(404, "FILE_NOT_FOUND", "Underlying file not found on disk.")
|
||||
|
||||
quoted = (filename or "").replace("\r", "").replace("\n", "").replace('"', "'")
|
||||
cd = f'{disposition}; filename="{quoted}"; filename*=UTF-8\'\'{urllib.parse.quote(filename)}'
|
||||
|
||||
file_size = os.path.getsize(abs_path)
|
||||
logging.info(
|
||||
"download_asset_content: path=%s, size=%d bytes (%.2f MB), content_type=%s, filename=%s",
|
||||
abs_path,
|
||||
file_size,
|
||||
file_size / (1024 * 1024),
|
||||
content_type,
|
||||
filename,
|
||||
)
|
||||
|
||||
async def file_sender():
|
||||
chunk_size = 64 * 1024
|
||||
with open(abs_path, "rb") as f:
|
||||
while True:
|
||||
chunk = f.read(chunk_size)
|
||||
if not chunk:
|
||||
break
|
||||
yield chunk
|
||||
|
||||
return web.Response(
|
||||
body=file_sender(),
|
||||
content_type=content_type,
|
||||
headers={
|
||||
"Content-Disposition": cd,
|
||||
"Content-Length": str(file_size),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@ROUTES.post("/api/assets/from-hash")
|
||||
async def create_asset_from_hash(request: web.Request) -> web.Response:
|
||||
try:
|
||||
payload = await request.json()
|
||||
body = schemas_in.CreateFromHashBody.model_validate(payload)
|
||||
except ValidationError as ve:
|
||||
return _validation_error_response("INVALID_BODY", ve)
|
||||
except Exception:
|
||||
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
|
||||
|
||||
result = manager.create_asset_from_hash(
|
||||
hash_str=body.hash,
|
||||
name=body.name,
|
||||
tags=body.tags,
|
||||
user_metadata=body.user_metadata,
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
if result is None:
|
||||
return _error_response(404, "ASSET_NOT_FOUND", f"Asset content {body.hash} does not exist")
|
||||
return web.json_response(result.model_dump(mode="json"), status=201)
|
||||
|
||||
|
||||
@ROUTES.post("/api/assets")
|
||||
async def upload_asset(request: web.Request) -> web.Response:
|
||||
"""Multipart/form-data endpoint for Asset uploads."""
|
||||
if not (request.content_type or "").lower().startswith("multipart/"):
|
||||
return _error_response(415, "UNSUPPORTED_MEDIA_TYPE", "Use multipart/form-data for uploads.")
|
||||
|
||||
reader = await request.multipart()
|
||||
|
||||
file_present = False
|
||||
file_client_name: str | None = None
|
||||
tags_raw: list[str] = []
|
||||
provided_name: str | None = None
|
||||
user_metadata_raw: str | None = None
|
||||
provided_hash: str | None = None
|
||||
provided_hash_exists: bool | None = None
|
||||
|
||||
file_written = 0
|
||||
tmp_path: str | None = None
|
||||
while True:
|
||||
field = await reader.next()
|
||||
if field is None:
|
||||
break
|
||||
|
||||
fname = getattr(field, "name", "") or ""
|
||||
|
||||
if fname == "hash":
|
||||
try:
|
||||
s = ((await field.text()) or "").strip().lower()
|
||||
except Exception:
|
||||
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
|
||||
|
||||
if s:
|
||||
if ":" not in s:
|
||||
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
|
||||
algo, digest = s.split(":", 1)
|
||||
if algo != "blake3" or not digest or any(c for c in digest if c not in "0123456789abcdef"):
|
||||
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
|
||||
provided_hash = f"{algo}:{digest}"
|
||||
try:
|
||||
provided_hash_exists = manager.asset_exists(asset_hash=provided_hash)
|
||||
except Exception:
|
||||
provided_hash_exists = None # do not fail the whole request here
|
||||
|
||||
elif fname == "file":
|
||||
file_present = True
|
||||
file_client_name = (field.filename or "").strip()
|
||||
|
||||
if provided_hash and provided_hash_exists is True:
|
||||
# If client supplied a hash that we know exists, drain but do not write to disk
|
||||
try:
|
||||
while True:
|
||||
chunk = await field.read_chunk(8 * 1024 * 1024)
|
||||
if not chunk:
|
||||
break
|
||||
file_written += len(chunk)
|
||||
except Exception:
|
||||
return _error_response(500, "UPLOAD_IO_ERROR", "Failed to receive uploaded file.")
|
||||
continue # Do not create temp file; we will create AssetInfo from the existing content
|
||||
|
||||
# Otherwise, store to temp for hashing/ingest
|
||||
uploads_root = os.path.join(folder_paths.get_temp_directory(), "uploads")
|
||||
unique_dir = os.path.join(uploads_root, uuid.uuid4().hex)
|
||||
os.makedirs(unique_dir, exist_ok=True)
|
||||
tmp_path = os.path.join(unique_dir, ".upload.part")
|
||||
|
||||
try:
|
||||
with open(tmp_path, "wb") as f:
|
||||
while True:
|
||||
chunk = await field.read_chunk(8 * 1024 * 1024)
|
||||
if not chunk:
|
||||
break
|
||||
f.write(chunk)
|
||||
file_written += len(chunk)
|
||||
except Exception:
|
||||
try:
|
||||
if os.path.exists(tmp_path or ""):
|
||||
os.remove(tmp_path)
|
||||
finally:
|
||||
return _error_response(500, "UPLOAD_IO_ERROR", "Failed to receive and store uploaded file.")
|
||||
elif fname == "tags":
|
||||
tags_raw.append((await field.text()) or "")
|
||||
elif fname == "name":
|
||||
provided_name = (await field.text()) or None
|
||||
elif fname == "user_metadata":
|
||||
user_metadata_raw = (await field.text()) or None
|
||||
|
||||
# If client did not send file, and we are not doing a from-hash fast path -> error
|
||||
if not file_present and not (provided_hash and provided_hash_exists):
|
||||
return _error_response(400, "MISSING_FILE", "Form must include a 'file' part or a known 'hash'.")
|
||||
|
||||
if file_present and file_written == 0 and not (provided_hash and provided_hash_exists):
|
||||
# Empty upload is only acceptable if we are fast-pathing from existing hash
|
||||
try:
|
||||
if tmp_path and os.path.exists(tmp_path):
|
||||
os.remove(tmp_path)
|
||||
finally:
|
||||
return _error_response(400, "EMPTY_UPLOAD", "Uploaded file is empty.")
|
||||
|
||||
try:
|
||||
spec = schemas_in.UploadAssetSpec.model_validate({
|
||||
"tags": tags_raw,
|
||||
"name": provided_name,
|
||||
"user_metadata": user_metadata_raw,
|
||||
"hash": provided_hash,
|
||||
})
|
||||
except ValidationError as ve:
|
||||
try:
|
||||
if tmp_path and os.path.exists(tmp_path):
|
||||
os.remove(tmp_path)
|
||||
finally:
|
||||
return _validation_error_response("INVALID_BODY", ve)
|
||||
|
||||
# Validate models category against configured folders (consistent with previous behavior)
|
||||
if spec.tags and spec.tags[0] == "models":
|
||||
if len(spec.tags) < 2 or spec.tags[1] not in folder_paths.folder_names_and_paths:
|
||||
if tmp_path and os.path.exists(tmp_path):
|
||||
os.remove(tmp_path)
|
||||
return _error_response(
|
||||
400, "INVALID_BODY", f"unknown models category '{spec.tags[1] if len(spec.tags) >= 2 else ''}'"
|
||||
)
|
||||
|
||||
owner_id = USER_MANAGER.get_request_user_id(request)
|
||||
|
||||
# Fast path: if a valid provided hash exists, create AssetInfo without writing anything
|
||||
if spec.hash and provided_hash_exists is True:
|
||||
try:
|
||||
result = manager.create_asset_from_hash(
|
||||
hash_str=spec.hash,
|
||||
name=spec.name or (spec.hash.split(":", 1)[1]),
|
||||
tags=spec.tags,
|
||||
user_metadata=spec.user_metadata or {},
|
||||
owner_id=owner_id,
|
||||
)
|
||||
except Exception:
|
||||
logging.exception("create_asset_from_hash failed for hash=%s, owner_id=%s", spec.hash, owner_id)
|
||||
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
|
||||
if result is None:
|
||||
return _error_response(404, "ASSET_NOT_FOUND", f"Asset content {spec.hash} does not exist")
|
||||
|
||||
# Drain temp if we accidentally saved (e.g., hash field came after file)
|
||||
if tmp_path and os.path.exists(tmp_path):
|
||||
with contextlib.suppress(Exception):
|
||||
os.remove(tmp_path)
|
||||
|
||||
status = 200 if (not result.created_new) else 201
|
||||
return web.json_response(result.model_dump(mode="json"), status=status)
|
||||
|
||||
# Otherwise, we must have a temp file path to ingest
|
||||
if not tmp_path or not os.path.exists(tmp_path):
|
||||
# The only case we reach here without a temp file is: client sent a hash that does not exist and no file
|
||||
return _error_response(404, "ASSET_NOT_FOUND", "Provided hash not found and no file uploaded.")
|
||||
|
||||
try:
|
||||
created = manager.upload_asset_from_temp_path(
|
||||
spec,
|
||||
temp_path=tmp_path,
|
||||
client_filename=file_client_name,
|
||||
owner_id=owner_id,
|
||||
expected_asset_hash=spec.hash,
|
||||
)
|
||||
status = 201 if created.created_new else 200
|
||||
return web.json_response(created.model_dump(mode="json"), status=status)
|
||||
except ValueError as e:
|
||||
if tmp_path and os.path.exists(tmp_path):
|
||||
os.remove(tmp_path)
|
||||
msg = str(e)
|
||||
if "HASH_MISMATCH" in msg or msg.strip().upper() == "HASH_MISMATCH":
|
||||
return _error_response(
|
||||
400,
|
||||
"HASH_MISMATCH",
|
||||
"Uploaded file hash does not match provided hash.",
|
||||
)
|
||||
return _error_response(400, "BAD_REQUEST", "Invalid inputs.")
|
||||
except Exception:
|
||||
if tmp_path and os.path.exists(tmp_path):
|
||||
os.remove(tmp_path)
|
||||
logging.exception("upload_asset_from_temp_path failed for tmp_path=%s, owner_id=%s", tmp_path, owner_id)
|
||||
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
|
||||
|
||||
@ROUTES.put(f"/api/assets/{{id:{UUID_RE}}}")
|
||||
async def update_asset(request: web.Request) -> web.Response:
|
||||
asset_info_id = str(uuid.UUID(request.match_info["id"]))
|
||||
try:
|
||||
body = schemas_in.UpdateAssetBody.model_validate(await request.json())
|
||||
except ValidationError as ve:
|
||||
return _validation_error_response("INVALID_BODY", ve)
|
||||
except Exception:
|
||||
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
|
||||
|
||||
try:
|
||||
result = manager.update_asset(
|
||||
asset_info_id=asset_info_id,
|
||||
name=body.name,
|
||||
user_metadata=body.user_metadata,
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
except (ValueError, PermissionError) as ve:
|
||||
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
|
||||
except Exception:
|
||||
logging.exception(
|
||||
"update_asset failed for asset_info_id=%s, owner_id=%s",
|
||||
asset_info_id,
|
||||
USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
return web.json_response(result.model_dump(mode="json"), status=200)
|
||||
|
||||
|
||||
@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}")
|
||||
async def delete_asset(request: web.Request) -> web.Response:
|
||||
asset_info_id = str(uuid.UUID(request.match_info["id"]))
|
||||
delete_content = request.query.get("delete_content")
|
||||
delete_content = True if delete_content is None else delete_content.lower() not in {"0", "false", "no"}
|
||||
|
||||
try:
|
||||
deleted = manager.delete_asset_reference(
|
||||
asset_info_id=asset_info_id,
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
delete_content_if_orphan=delete_content,
|
||||
)
|
||||
except Exception:
|
||||
logging.exception(
|
||||
"delete_asset_reference failed for asset_info_id=%s, owner_id=%s",
|
||||
asset_info_id,
|
||||
USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
|
||||
if not deleted:
|
||||
return _error_response(404, "ASSET_NOT_FOUND", f"AssetInfo {asset_info_id} not found.")
|
||||
return web.Response(status=204)
|
||||
|
||||
|
||||
@ROUTES.get("/api/tags")
|
||||
async def get_tags(request: web.Request) -> web.Response:
|
||||
"""
|
||||
@@ -100,3 +429,86 @@ async def get_tags(request: web.Request) -> web.Response:
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
return web.json_response(result.model_dump(mode="json"))
|
||||
|
||||
|
||||
@ROUTES.post(f"/api/assets/{{id:{UUID_RE}}}/tags")
|
||||
async def add_asset_tags(request: web.Request) -> web.Response:
|
||||
asset_info_id = str(uuid.UUID(request.match_info["id"]))
|
||||
try:
|
||||
payload = await request.json()
|
||||
data = schemas_in.TagsAdd.model_validate(payload)
|
||||
except ValidationError as ve:
|
||||
return _error_response(400, "INVALID_BODY", "Invalid JSON body for tags add.", {"errors": ve.errors()})
|
||||
except Exception:
|
||||
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
|
||||
|
||||
try:
|
||||
result = manager.add_tags_to_asset(
|
||||
asset_info_id=asset_info_id,
|
||||
tags=data.tags,
|
||||
origin="manual",
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
except (ValueError, PermissionError) as ve:
|
||||
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
|
||||
except Exception:
|
||||
logging.exception(
|
||||
"add_tags_to_asset failed for asset_info_id=%s, owner_id=%s",
|
||||
asset_info_id,
|
||||
USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
|
||||
return web.json_response(result.model_dump(mode="json"), status=200)
|
||||
|
||||
|
||||
@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}/tags")
|
||||
async def delete_asset_tags(request: web.Request) -> web.Response:
|
||||
asset_info_id = str(uuid.UUID(request.match_info["id"]))
|
||||
try:
|
||||
payload = await request.json()
|
||||
data = schemas_in.TagsRemove.model_validate(payload)
|
||||
except ValidationError as ve:
|
||||
return _error_response(400, "INVALID_BODY", "Invalid JSON body for tags remove.", {"errors": ve.errors()})
|
||||
except Exception:
|
||||
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
|
||||
|
||||
try:
|
||||
result = manager.remove_tags_from_asset(
|
||||
asset_info_id=asset_info_id,
|
||||
tags=data.tags,
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
except ValueError as ve:
|
||||
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
|
||||
except Exception:
|
||||
logging.exception(
|
||||
"remove_tags_from_asset failed for asset_info_id=%s, owner_id=%s",
|
||||
asset_info_id,
|
||||
USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
|
||||
return web.json_response(result.model_dump(mode="json"), status=200)
|
||||
|
||||
|
||||
@ROUTES.post("/api/assets/seed")
|
||||
async def seed_assets_endpoint(request: web.Request) -> web.Response:
|
||||
"""Trigger asset seeding for specified roots (models, input, output)."""
|
||||
try:
|
||||
payload = await request.json()
|
||||
roots = payload.get("roots", ["models", "input", "output"])
|
||||
except Exception:
|
||||
roots = ["models", "input", "output"]
|
||||
|
||||
valid_roots = [r for r in roots if r in ("models", "input", "output")]
|
||||
if not valid_roots:
|
||||
return _error_response(400, "INVALID_BODY", "No valid roots specified")
|
||||
|
||||
try:
|
||||
seed_assets(tuple(valid_roots))
|
||||
except Exception:
|
||||
logging.exception("seed_assets failed for roots=%s", valid_roots)
|
||||
return _error_response(500, "INTERNAL", "Seed operation failed")
|
||||
|
||||
return web.json_response({"seeded": valid_roots}, status=200)
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import json
|
||||
import uuid
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import (
|
||||
@@ -8,9 +7,9 @@ from pydantic import (
|
||||
Field,
|
||||
conint,
|
||||
field_validator,
|
||||
model_validator,
|
||||
)
|
||||
|
||||
|
||||
class ListAssetsQuery(BaseModel):
|
||||
include_tags: list[str] = Field(default_factory=list)
|
||||
exclude_tags: list[str] = Field(default_factory=list)
|
||||
@@ -57,6 +56,57 @@ class ListAssetsQuery(BaseModel):
|
||||
return None
|
||||
|
||||
|
||||
class UpdateAssetBody(BaseModel):
|
||||
name: str | None = None
|
||||
user_metadata: dict[str, Any] | None = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _at_least_one(self):
|
||||
if self.name is None and self.user_metadata is None:
|
||||
raise ValueError("Provide at least one of: name, user_metadata.")
|
||||
return self
|
||||
|
||||
|
||||
class CreateFromHashBody(BaseModel):
|
||||
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
|
||||
|
||||
hash: str
|
||||
name: str
|
||||
tags: list[str] = Field(default_factory=list)
|
||||
user_metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
@field_validator("hash")
|
||||
@classmethod
|
||||
def _require_blake3(cls, v):
|
||||
s = (v or "").strip().lower()
|
||||
if ":" not in s:
|
||||
raise ValueError("hash must be 'blake3:<hex>'")
|
||||
algo, digest = s.split(":", 1)
|
||||
if algo != "blake3":
|
||||
raise ValueError("only canonical 'blake3:<hex>' is accepted here")
|
||||
if not digest or any(c for c in digest if c not in "0123456789abcdef"):
|
||||
raise ValueError("hash digest must be lowercase hex")
|
||||
return s
|
||||
|
||||
@field_validator("tags", mode="before")
|
||||
@classmethod
|
||||
def _tags_norm(cls, v):
|
||||
if v is None:
|
||||
return []
|
||||
if isinstance(v, list):
|
||||
out = [str(t).strip().lower() for t in v if str(t).strip()]
|
||||
seen = set()
|
||||
dedup = []
|
||||
for t in out:
|
||||
if t not in seen:
|
||||
seen.add(t)
|
||||
dedup.append(t)
|
||||
return dedup
|
||||
if isinstance(v, str):
|
||||
return [t.strip().lower() for t in v.split(",") if t.strip()]
|
||||
return []
|
||||
|
||||
|
||||
class TagsListQuery(BaseModel):
|
||||
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
|
||||
|
||||
@@ -75,20 +125,140 @@ class TagsListQuery(BaseModel):
|
||||
return v.lower() or None
|
||||
|
||||
|
||||
class SetPreviewBody(BaseModel):
|
||||
"""Set or clear the preview for an AssetInfo. Provide an Asset.id or null."""
|
||||
preview_id: str | None = None
|
||||
class TagsAdd(BaseModel):
|
||||
model_config = ConfigDict(extra="ignore")
|
||||
tags: list[str] = Field(..., min_length=1)
|
||||
|
||||
@field_validator("preview_id", mode="before")
|
||||
@field_validator("tags")
|
||||
@classmethod
|
||||
def _norm_uuid(cls, v):
|
||||
def normalize_tags(cls, v: list[str]) -> list[str]:
|
||||
out = []
|
||||
for t in v:
|
||||
if not isinstance(t, str):
|
||||
raise TypeError("tags must be strings")
|
||||
tnorm = t.strip().lower()
|
||||
if tnorm:
|
||||
out.append(tnorm)
|
||||
seen = set()
|
||||
deduplicated = []
|
||||
for x in out:
|
||||
if x not in seen:
|
||||
seen.add(x)
|
||||
deduplicated.append(x)
|
||||
return deduplicated
|
||||
|
||||
|
||||
class TagsRemove(TagsAdd):
|
||||
pass
|
||||
|
||||
|
||||
class UploadAssetSpec(BaseModel):
|
||||
"""Upload Asset operation.
|
||||
- tags: ordered; first is root ('models'|'input'|'output');
|
||||
if root == 'models', second must be a valid category from folder_paths.folder_names_and_paths
|
||||
- name: display name
|
||||
- user_metadata: arbitrary JSON object (optional)
|
||||
- hash: optional canonical 'blake3:<hex>' provided by the client for validation / fast-path
|
||||
|
||||
Files created via this endpoint are stored on disk using the **content hash** as the filename stem
|
||||
and the original extension is preserved when available.
|
||||
"""
|
||||
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
|
||||
|
||||
tags: list[str] = Field(..., min_length=1)
|
||||
name: str | None = Field(default=None, max_length=512, description="Display Name")
|
||||
user_metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
hash: str | None = Field(default=None)
|
||||
|
||||
@field_validator("hash", mode="before")
|
||||
@classmethod
|
||||
def _parse_hash(cls, v):
|
||||
if v is None:
|
||||
return None
|
||||
s = str(v).strip()
|
||||
s = str(v).strip().lower()
|
||||
if not s:
|
||||
return None
|
||||
try:
|
||||
uuid.UUID(s)
|
||||
except Exception:
|
||||
raise ValueError("preview_id must be a UUID")
|
||||
return s
|
||||
if ":" not in s:
|
||||
raise ValueError("hash must be 'blake3:<hex>'")
|
||||
algo, digest = s.split(":", 1)
|
||||
if algo != "blake3":
|
||||
raise ValueError("only canonical 'blake3:<hex>' is accepted here")
|
||||
if not digest or any(c for c in digest if c not in "0123456789abcdef"):
|
||||
raise ValueError("hash digest must be lowercase hex")
|
||||
return f"{algo}:{digest}"
|
||||
|
||||
@field_validator("tags", mode="before")
|
||||
@classmethod
|
||||
def _parse_tags(cls, v):
|
||||
"""
|
||||
Accepts a list of strings (possibly multiple form fields),
|
||||
where each string can be:
|
||||
- JSON array (e.g., '["models","loras","foo"]')
|
||||
- comma-separated ('models, loras, foo')
|
||||
- single token ('models')
|
||||
Returns a normalized, deduplicated, ordered list.
|
||||
"""
|
||||
items: list[str] = []
|
||||
if v is None:
|
||||
return []
|
||||
if isinstance(v, str):
|
||||
v = [v]
|
||||
|
||||
if isinstance(v, list):
|
||||
for item in v:
|
||||
if item is None:
|
||||
continue
|
||||
s = str(item).strip()
|
||||
if not s:
|
||||
continue
|
||||
if s.startswith("["):
|
||||
try:
|
||||
arr = json.loads(s)
|
||||
if isinstance(arr, list):
|
||||
items.extend(str(x) for x in arr)
|
||||
continue
|
||||
except Exception:
|
||||
pass # fallback to CSV parse below
|
||||
items.extend([p for p in s.split(",") if p.strip()])
|
||||
else:
|
||||
return []
|
||||
|
||||
# normalize + dedupe
|
||||
norm = []
|
||||
seen = set()
|
||||
for t in items:
|
||||
tnorm = str(t).strip().lower()
|
||||
if tnorm and tnorm not in seen:
|
||||
seen.add(tnorm)
|
||||
norm.append(tnorm)
|
||||
return norm
|
||||
|
||||
@field_validator("user_metadata", mode="before")
|
||||
@classmethod
|
||||
def _parse_metadata_json(cls, v):
|
||||
if v is None or isinstance(v, dict):
|
||||
return v or {}
|
||||
if isinstance(v, str):
|
||||
s = v.strip()
|
||||
if not s:
|
||||
return {}
|
||||
try:
|
||||
parsed = json.loads(s)
|
||||
except Exception as e:
|
||||
raise ValueError(f"user_metadata must be JSON: {e}") from e
|
||||
if not isinstance(parsed, dict):
|
||||
raise ValueError("user_metadata must be a JSON object")
|
||||
return parsed
|
||||
return {}
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _validate_order(self):
|
||||
if not self.tags:
|
||||
raise ValueError("tags must be provided and non-empty")
|
||||
root = self.tags[0]
|
||||
if root not in {"models", "input", "output"}:
|
||||
raise ValueError("first tag must be one of: models, input, output")
|
||||
if root == "models":
|
||||
if len(self.tags) < 2:
|
||||
raise ValueError("models uploads require a category tag as the second tag")
|
||||
return self
|
||||
|
||||
@@ -29,6 +29,21 @@ class AssetsList(BaseModel):
|
||||
has_more: bool
|
||||
|
||||
|
||||
class AssetUpdated(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
asset_hash: str | None = None
|
||||
tags: list[str] = Field(default_factory=list)
|
||||
user_metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
updated_at: datetime | None = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
@field_serializer("updated_at")
|
||||
def _ser_updated(self, v: datetime | None, _info):
|
||||
return v.isoformat() if v else None
|
||||
|
||||
|
||||
class AssetDetail(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
@@ -48,6 +63,10 @@ class AssetDetail(BaseModel):
|
||||
return v.isoformat() if v else None
|
||||
|
||||
|
||||
class AssetCreated(AssetDetail):
|
||||
created_new: bool
|
||||
|
||||
|
||||
class TagUsage(BaseModel):
|
||||
name: str
|
||||
count: int
|
||||
@@ -58,3 +77,17 @@ class TagsList(BaseModel):
|
||||
tags: list[TagUsage] = Field(default_factory=list)
|
||||
total: int
|
||||
has_more: bool
|
||||
|
||||
|
||||
class TagsAdd(BaseModel):
|
||||
model_config = ConfigDict(str_strip_whitespace=True)
|
||||
added: list[str] = Field(default_factory=list)
|
||||
already_present: list[str] = Field(default_factory=list)
|
||||
total_tags: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class TagsRemove(BaseModel):
|
||||
model_config = ConfigDict(str_strip_whitespace=True)
|
||||
removed: list[str] = Field(default_factory=list)
|
||||
not_present: list[str] = Field(default_factory=list)
|
||||
total_tags: list[str] = Field(default_factory=list)
|
||||
|
||||
@@ -1,9 +1,17 @@
|
||||
import os
|
||||
import logging
|
||||
import sqlalchemy as sa
|
||||
from collections import defaultdict
|
||||
from sqlalchemy import select, exists, func
|
||||
from datetime import datetime
|
||||
from typing import Iterable, Any
|
||||
from sqlalchemy import select, delete, exists, func
|
||||
from sqlalchemy.dialects import sqlite
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session, contains_eager, noload
|
||||
from app.assets.database.models import Asset, AssetInfo, AssetInfoMeta, AssetInfoTag, Tag
|
||||
from app.assets.helpers import escape_like_prefix, normalize_tags
|
||||
from app.assets.database.models import Asset, AssetInfo, AssetCacheState, AssetInfoMeta, AssetInfoTag, Tag
|
||||
from app.assets.helpers import (
|
||||
compute_relative_filename, escape_like_prefix, normalize_tags, project_kv, utcnow
|
||||
)
|
||||
from typing import Sequence
|
||||
|
||||
|
||||
@@ -15,6 +23,22 @@ def visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement:
|
||||
return AssetInfo.owner_id.in_(["", owner_id])
|
||||
|
||||
|
||||
def pick_best_live_path(states: Sequence[AssetCacheState]) -> str:
|
||||
"""
|
||||
Return the best on-disk path among cache states:
|
||||
1) Prefer a path that exists with needs_verify == False (already verified).
|
||||
2) Otherwise, pick the first path that exists.
|
||||
3) Otherwise return empty string.
|
||||
"""
|
||||
alive = [s for s in states if getattr(s, "file_path", None) and os.path.isfile(s.file_path)]
|
||||
if not alive:
|
||||
return ""
|
||||
for s in alive:
|
||||
if not getattr(s, "needs_verify", False):
|
||||
return s.file_path
|
||||
return alive[0].file_path
|
||||
|
||||
|
||||
def apply_tag_filters(
|
||||
stmt: sa.sql.Select,
|
||||
include_tags: Sequence[str] | None = None,
|
||||
@@ -42,6 +66,7 @@ def apply_tag_filters(
|
||||
)
|
||||
return stmt
|
||||
|
||||
|
||||
def apply_metadata_filter(
|
||||
stmt: sa.sql.Select,
|
||||
metadata_filter: dict | None = None,
|
||||
@@ -94,7 +119,11 @@ def apply_metadata_filter(
|
||||
return stmt
|
||||
|
||||
|
||||
def asset_exists_by_hash(session: Session, asset_hash: str) -> bool:
|
||||
def asset_exists_by_hash(
|
||||
session: Session,
|
||||
*,
|
||||
asset_hash: str,
|
||||
) -> bool:
|
||||
"""
|
||||
Check if an asset with a given hash exists in database.
|
||||
"""
|
||||
@@ -105,9 +134,39 @@ def asset_exists_by_hash(session: Session, asset_hash: str) -> bool:
|
||||
).first()
|
||||
return row is not None
|
||||
|
||||
def get_asset_info_by_id(session: Session, asset_info_id: str) -> AssetInfo | None:
|
||||
|
||||
def asset_info_exists_for_asset_id(
|
||||
session: Session,
|
||||
*,
|
||||
asset_id: str,
|
||||
) -> bool:
|
||||
q = (
|
||||
select(sa.literal(True))
|
||||
.select_from(AssetInfo)
|
||||
.where(AssetInfo.asset_id == asset_id)
|
||||
.limit(1)
|
||||
)
|
||||
return (session.execute(q)).first() is not None
|
||||
|
||||
|
||||
def get_asset_by_hash(
|
||||
session: Session,
|
||||
*,
|
||||
asset_hash: str,
|
||||
) -> Asset | None:
|
||||
return (
|
||||
session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))
|
||||
).scalars().first()
|
||||
|
||||
|
||||
def get_asset_info_by_id(
|
||||
session: Session,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
) -> AssetInfo | None:
|
||||
return session.get(AssetInfo, asset_info_id)
|
||||
|
||||
|
||||
def list_asset_infos_page(
|
||||
session: Session,
|
||||
owner_id: str = "",
|
||||
@@ -171,12 +230,14 @@ def list_asset_infos_page(
|
||||
select(AssetInfoTag.asset_info_id, Tag.name)
|
||||
.join(Tag, Tag.name == AssetInfoTag.tag_name)
|
||||
.where(AssetInfoTag.asset_info_id.in_(id_list))
|
||||
.order_by(AssetInfoTag.added_at)
|
||||
)
|
||||
for aid, tag_name in rows.all():
|
||||
tag_map[aid].append(tag_name)
|
||||
|
||||
return infos, tag_map, total
|
||||
|
||||
|
||||
def fetch_asset_info_asset_and_tags(
|
||||
session: Session,
|
||||
asset_info_id: str,
|
||||
@@ -208,6 +269,494 @@ def fetch_asset_info_asset_and_tags(
|
||||
tags.append(tag_name)
|
||||
return first_info, first_asset, tags
|
||||
|
||||
|
||||
def fetch_asset_info_and_asset(
|
||||
session: Session,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
owner_id: str = "",
|
||||
) -> tuple[AssetInfo, Asset] | None:
|
||||
stmt = (
|
||||
select(AssetInfo, Asset)
|
||||
.join(Asset, Asset.id == AssetInfo.asset_id)
|
||||
.where(
|
||||
AssetInfo.id == asset_info_id,
|
||||
visible_owner_clause(owner_id),
|
||||
)
|
||||
.limit(1)
|
||||
.options(noload(AssetInfo.tags))
|
||||
)
|
||||
row = session.execute(stmt)
|
||||
pair = row.first()
|
||||
if not pair:
|
||||
return None
|
||||
return pair[0], pair[1]
|
||||
|
||||
def list_cache_states_by_asset_id(
|
||||
session: Session, *, asset_id: str
|
||||
) -> Sequence[AssetCacheState]:
|
||||
return (
|
||||
session.execute(
|
||||
select(AssetCacheState)
|
||||
.where(AssetCacheState.asset_id == asset_id)
|
||||
.order_by(AssetCacheState.id.asc())
|
||||
)
|
||||
).scalars().all()
|
||||
|
||||
|
||||
def touch_asset_info_by_id(
|
||||
session: Session,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
ts: datetime | None = None,
|
||||
only_if_newer: bool = True,
|
||||
) -> None:
|
||||
ts = ts or utcnow()
|
||||
stmt = sa.update(AssetInfo).where(AssetInfo.id == asset_info_id)
|
||||
if only_if_newer:
|
||||
stmt = stmt.where(
|
||||
sa.or_(AssetInfo.last_access_time.is_(None), AssetInfo.last_access_time < ts)
|
||||
)
|
||||
session.execute(stmt.values(last_access_time=ts))
|
||||
|
||||
|
||||
def create_asset_info_for_existing_asset(
|
||||
session: Session,
|
||||
*,
|
||||
asset_hash: str,
|
||||
name: str,
|
||||
user_metadata: dict | None = None,
|
||||
tags: Sequence[str] | None = None,
|
||||
tag_origin: str = "manual",
|
||||
owner_id: str = "",
|
||||
) -> AssetInfo:
|
||||
"""Create or return an existing AssetInfo for an Asset identified by asset_hash."""
|
||||
now = utcnow()
|
||||
asset = get_asset_by_hash(session, asset_hash=asset_hash)
|
||||
if not asset:
|
||||
raise ValueError(f"Unknown asset hash {asset_hash}")
|
||||
|
||||
info = AssetInfo(
|
||||
owner_id=owner_id,
|
||||
name=name,
|
||||
asset_id=asset.id,
|
||||
preview_id=None,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
last_access_time=now,
|
||||
)
|
||||
try:
|
||||
with session.begin_nested():
|
||||
session.add(info)
|
||||
session.flush()
|
||||
except IntegrityError:
|
||||
existing = (
|
||||
session.execute(
|
||||
select(AssetInfo)
|
||||
.options(noload(AssetInfo.tags))
|
||||
.where(
|
||||
AssetInfo.asset_id == asset.id,
|
||||
AssetInfo.name == name,
|
||||
AssetInfo.owner_id == owner_id,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
).unique().scalars().first()
|
||||
if not existing:
|
||||
raise RuntimeError("AssetInfo upsert failed to find existing row after conflict.")
|
||||
return existing
|
||||
|
||||
# metadata["filename"] hack
|
||||
new_meta = dict(user_metadata or {})
|
||||
computed_filename = None
|
||||
try:
|
||||
p = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=asset.id))
|
||||
if p:
|
||||
computed_filename = compute_relative_filename(p)
|
||||
except Exception:
|
||||
computed_filename = None
|
||||
if computed_filename:
|
||||
new_meta["filename"] = computed_filename
|
||||
if new_meta:
|
||||
replace_asset_info_metadata_projection(
|
||||
session,
|
||||
asset_info_id=info.id,
|
||||
user_metadata=new_meta,
|
||||
)
|
||||
|
||||
if tags is not None:
|
||||
set_asset_info_tags(
|
||||
session,
|
||||
asset_info_id=info.id,
|
||||
tags=tags,
|
||||
origin=tag_origin,
|
||||
)
|
||||
return info
|
||||
|
||||
|
||||
def set_asset_info_tags(
|
||||
session: Session,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
tags: Sequence[str],
|
||||
origin: str = "manual",
|
||||
) -> dict:
|
||||
desired = normalize_tags(tags)
|
||||
|
||||
current = set(
|
||||
tag_name for (tag_name,) in (
|
||||
session.execute(select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id))
|
||||
).all()
|
||||
)
|
||||
|
||||
to_add = [t for t in desired if t not in current]
|
||||
to_remove = [t for t in current if t not in desired]
|
||||
|
||||
if to_add:
|
||||
ensure_tags_exist(session, to_add, tag_type="user")
|
||||
session.add_all([
|
||||
AssetInfoTag(asset_info_id=asset_info_id, tag_name=t, origin=origin, added_at=utcnow())
|
||||
for t in to_add
|
||||
])
|
||||
session.flush()
|
||||
|
||||
if to_remove:
|
||||
session.execute(
|
||||
delete(AssetInfoTag)
|
||||
.where(AssetInfoTag.asset_info_id == asset_info_id, AssetInfoTag.tag_name.in_(to_remove))
|
||||
)
|
||||
session.flush()
|
||||
|
||||
return {"added": to_add, "removed": to_remove, "total": desired}
|
||||
|
||||
|
||||
def replace_asset_info_metadata_projection(
|
||||
session: Session,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
user_metadata: dict | None = None,
|
||||
) -> None:
|
||||
info = session.get(AssetInfo, asset_info_id)
|
||||
if not info:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
|
||||
info.user_metadata = user_metadata or {}
|
||||
info.updated_at = utcnow()
|
||||
session.flush()
|
||||
|
||||
session.execute(delete(AssetInfoMeta).where(AssetInfoMeta.asset_info_id == asset_info_id))
|
||||
session.flush()
|
||||
|
||||
if not user_metadata:
|
||||
return
|
||||
|
||||
rows: list[AssetInfoMeta] = []
|
||||
for k, v in user_metadata.items():
|
||||
for r in project_kv(k, v):
|
||||
rows.append(
|
||||
AssetInfoMeta(
|
||||
asset_info_id=asset_info_id,
|
||||
key=r["key"],
|
||||
ordinal=int(r["ordinal"]),
|
||||
val_str=r.get("val_str"),
|
||||
val_num=r.get("val_num"),
|
||||
val_bool=r.get("val_bool"),
|
||||
val_json=r.get("val_json"),
|
||||
)
|
||||
)
|
||||
if rows:
|
||||
session.add_all(rows)
|
||||
session.flush()
|
||||
|
||||
|
||||
def ingest_fs_asset(
|
||||
session: Session,
|
||||
*,
|
||||
asset_hash: str,
|
||||
abs_path: str,
|
||||
size_bytes: int,
|
||||
mtime_ns: int,
|
||||
mime_type: str | None = None,
|
||||
info_name: str | None = None,
|
||||
owner_id: str = "",
|
||||
preview_id: str | None = None,
|
||||
user_metadata: dict | None = None,
|
||||
tags: Sequence[str] = (),
|
||||
tag_origin: str = "manual",
|
||||
require_existing_tags: bool = False,
|
||||
) -> dict:
|
||||
"""
|
||||
Idempotently upsert:
|
||||
- Asset by content hash (create if missing)
|
||||
- AssetCacheState(file_path) pointing to asset_id
|
||||
- Optionally AssetInfo + tag links and metadata projection
|
||||
Returns flags and ids.
|
||||
"""
|
||||
locator = os.path.abspath(abs_path)
|
||||
now = utcnow()
|
||||
|
||||
if preview_id:
|
||||
if not session.get(Asset, preview_id):
|
||||
preview_id = None
|
||||
|
||||
out: dict[str, Any] = {
|
||||
"asset_created": False,
|
||||
"asset_updated": False,
|
||||
"state_created": False,
|
||||
"state_updated": False,
|
||||
"asset_info_id": None,
|
||||
}
|
||||
|
||||
# 1) Asset by hash
|
||||
asset = (
|
||||
session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))
|
||||
).scalars().first()
|
||||
if not asset:
|
||||
vals = {
|
||||
"hash": asset_hash,
|
||||
"size_bytes": int(size_bytes),
|
||||
"mime_type": mime_type,
|
||||
"created_at": now,
|
||||
}
|
||||
res = session.execute(
|
||||
sqlite.insert(Asset)
|
||||
.values(**vals)
|
||||
.on_conflict_do_nothing(index_elements=[Asset.hash])
|
||||
)
|
||||
if int(res.rowcount or 0) > 0:
|
||||
out["asset_created"] = True
|
||||
asset = (
|
||||
session.execute(
|
||||
select(Asset).where(Asset.hash == asset_hash).limit(1)
|
||||
)
|
||||
).scalars().first()
|
||||
if not asset:
|
||||
raise RuntimeError("Asset row not found after upsert.")
|
||||
else:
|
||||
changed = False
|
||||
if asset.size_bytes != int(size_bytes) and int(size_bytes) > 0:
|
||||
asset.size_bytes = int(size_bytes)
|
||||
changed = True
|
||||
if mime_type and asset.mime_type != mime_type:
|
||||
asset.mime_type = mime_type
|
||||
changed = True
|
||||
if changed:
|
||||
out["asset_updated"] = True
|
||||
|
||||
# 2) AssetCacheState upsert by file_path (unique)
|
||||
vals = {
|
||||
"asset_id": asset.id,
|
||||
"file_path": locator,
|
||||
"mtime_ns": int(mtime_ns),
|
||||
}
|
||||
ins = (
|
||||
sqlite.insert(AssetCacheState)
|
||||
.values(**vals)
|
||||
.on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
|
||||
)
|
||||
|
||||
res = session.execute(ins)
|
||||
if int(res.rowcount or 0) > 0:
|
||||
out["state_created"] = True
|
||||
else:
|
||||
upd = (
|
||||
sa.update(AssetCacheState)
|
||||
.where(AssetCacheState.file_path == locator)
|
||||
.where(
|
||||
sa.or_(
|
||||
AssetCacheState.asset_id != asset.id,
|
||||
AssetCacheState.mtime_ns.is_(None),
|
||||
AssetCacheState.mtime_ns != int(mtime_ns),
|
||||
)
|
||||
)
|
||||
.values(asset_id=asset.id, mtime_ns=int(mtime_ns))
|
||||
)
|
||||
res2 = session.execute(upd)
|
||||
if int(res2.rowcount or 0) > 0:
|
||||
out["state_updated"] = True
|
||||
|
||||
# 3) Optional AssetInfo + tags + metadata
|
||||
if info_name:
|
||||
try:
|
||||
with session.begin_nested():
|
||||
info = AssetInfo(
|
||||
owner_id=owner_id,
|
||||
name=info_name,
|
||||
asset_id=asset.id,
|
||||
preview_id=preview_id,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
last_access_time=now,
|
||||
)
|
||||
session.add(info)
|
||||
session.flush()
|
||||
out["asset_info_id"] = info.id
|
||||
except IntegrityError:
|
||||
pass
|
||||
|
||||
existing_info = (
|
||||
session.execute(
|
||||
select(AssetInfo)
|
||||
.where(
|
||||
AssetInfo.asset_id == asset.id,
|
||||
AssetInfo.name == info_name,
|
||||
(AssetInfo.owner_id == owner_id),
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
).unique().scalar_one_or_none()
|
||||
if not existing_info:
|
||||
raise RuntimeError("Failed to update or insert AssetInfo.")
|
||||
|
||||
if preview_id and existing_info.preview_id != preview_id:
|
||||
existing_info.preview_id = preview_id
|
||||
|
||||
existing_info.updated_at = now
|
||||
if existing_info.last_access_time < now:
|
||||
existing_info.last_access_time = now
|
||||
session.flush()
|
||||
out["asset_info_id"] = existing_info.id
|
||||
|
||||
norm = [t.strip().lower() for t in (tags or []) if (t or "").strip()]
|
||||
if norm and out["asset_info_id"] is not None:
|
||||
if not require_existing_tags:
|
||||
ensure_tags_exist(session, norm, tag_type="user")
|
||||
|
||||
existing_tag_names = set(
|
||||
name for (name,) in (session.execute(select(Tag.name).where(Tag.name.in_(norm)))).all()
|
||||
)
|
||||
missing = [t for t in norm if t not in existing_tag_names]
|
||||
if missing and require_existing_tags:
|
||||
raise ValueError(f"Unknown tags: {missing}")
|
||||
|
||||
existing_links = set(
|
||||
tag_name
|
||||
for (tag_name,) in (
|
||||
session.execute(
|
||||
select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == out["asset_info_id"])
|
||||
)
|
||||
).all()
|
||||
)
|
||||
to_add = [t for t in norm if t in existing_tag_names and t not in existing_links]
|
||||
if to_add:
|
||||
session.add_all(
|
||||
[
|
||||
AssetInfoTag(
|
||||
asset_info_id=out["asset_info_id"],
|
||||
tag_name=t,
|
||||
origin=tag_origin,
|
||||
added_at=now,
|
||||
)
|
||||
for t in to_add
|
||||
]
|
||||
)
|
||||
session.flush()
|
||||
|
||||
# metadata["filename"] hack
|
||||
if out["asset_info_id"] is not None:
|
||||
primary_path = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=asset.id))
|
||||
computed_filename = compute_relative_filename(primary_path) if primary_path else None
|
||||
|
||||
current_meta = existing_info.user_metadata or {}
|
||||
new_meta = dict(current_meta)
|
||||
if user_metadata is not None:
|
||||
for k, v in user_metadata.items():
|
||||
new_meta[k] = v
|
||||
if computed_filename:
|
||||
new_meta["filename"] = computed_filename
|
||||
|
||||
if new_meta != current_meta:
|
||||
replace_asset_info_metadata_projection(
|
||||
session,
|
||||
asset_info_id=out["asset_info_id"],
|
||||
user_metadata=new_meta,
|
||||
)
|
||||
|
||||
try:
|
||||
remove_missing_tag_for_asset_id(session, asset_id=asset.id)
|
||||
except Exception:
|
||||
logging.exception("Failed to clear 'missing' tag for asset %s", asset.id)
|
||||
return out
|
||||
|
||||
|
||||
def update_asset_info_full(
|
||||
session: Session,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
name: str | None = None,
|
||||
tags: Sequence[str] | None = None,
|
||||
user_metadata: dict | None = None,
|
||||
tag_origin: str = "manual",
|
||||
asset_info_row: Any = None,
|
||||
) -> AssetInfo:
|
||||
if not asset_info_row:
|
||||
info = session.get(AssetInfo, asset_info_id)
|
||||
if not info:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
else:
|
||||
info = asset_info_row
|
||||
|
||||
touched = False
|
||||
if name is not None and name != info.name:
|
||||
info.name = name
|
||||
touched = True
|
||||
|
||||
computed_filename = None
|
||||
try:
|
||||
p = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=info.asset_id))
|
||||
if p:
|
||||
computed_filename = compute_relative_filename(p)
|
||||
except Exception:
|
||||
computed_filename = None
|
||||
|
||||
if user_metadata is not None:
|
||||
new_meta = dict(user_metadata)
|
||||
if computed_filename:
|
||||
new_meta["filename"] = computed_filename
|
||||
replace_asset_info_metadata_projection(
|
||||
session, asset_info_id=asset_info_id, user_metadata=new_meta
|
||||
)
|
||||
touched = True
|
||||
else:
|
||||
if computed_filename:
|
||||
current_meta = info.user_metadata or {}
|
||||
if current_meta.get("filename") != computed_filename:
|
||||
new_meta = dict(current_meta)
|
||||
new_meta["filename"] = computed_filename
|
||||
replace_asset_info_metadata_projection(
|
||||
session, asset_info_id=asset_info_id, user_metadata=new_meta
|
||||
)
|
||||
touched = True
|
||||
|
||||
if tags is not None:
|
||||
set_asset_info_tags(
|
||||
session,
|
||||
asset_info_id=asset_info_id,
|
||||
tags=tags,
|
||||
origin=tag_origin,
|
||||
)
|
||||
touched = True
|
||||
|
||||
if touched and user_metadata is None:
|
||||
info.updated_at = utcnow()
|
||||
session.flush()
|
||||
|
||||
return info
|
||||
|
||||
|
||||
def delete_asset_info_by_id(
|
||||
session: Session,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
owner_id: str,
|
||||
) -> bool:
|
||||
stmt = sa.delete(AssetInfo).where(
|
||||
AssetInfo.id == asset_info_id,
|
||||
visible_owner_clause(owner_id),
|
||||
)
|
||||
return int((session.execute(stmt)).rowcount or 0) > 0
|
||||
|
||||
|
||||
def list_tags_with_usage(
|
||||
session: Session,
|
||||
prefix: str | None = None,
|
||||
@@ -265,3 +814,163 @@ def list_tags_with_usage(
|
||||
|
||||
rows_norm = [(name, ttype, int(count or 0)) for (name, ttype, count) in rows]
|
||||
return rows_norm, int(total or 0)
|
||||
|
||||
|
||||
def ensure_tags_exist(session: Session, names: Iterable[str], tag_type: str = "user") -> None:
|
||||
wanted = normalize_tags(list(names))
|
||||
if not wanted:
|
||||
return
|
||||
rows = [{"name": n, "tag_type": tag_type} for n in list(dict.fromkeys(wanted))]
|
||||
ins = (
|
||||
sqlite.insert(Tag)
|
||||
.values(rows)
|
||||
.on_conflict_do_nothing(index_elements=[Tag.name])
|
||||
)
|
||||
session.execute(ins)
|
||||
|
||||
|
||||
def get_asset_tags(session: Session, *, asset_info_id: str) -> list[str]:
|
||||
return [
|
||||
tag_name for (tag_name,) in (
|
||||
session.execute(
|
||||
select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
|
||||
)
|
||||
).all()
|
||||
]
|
||||
|
||||
|
||||
def add_tags_to_asset_info(
|
||||
session: Session,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
tags: Sequence[str],
|
||||
origin: str = "manual",
|
||||
create_if_missing: bool = True,
|
||||
asset_info_row: Any = None,
|
||||
) -> dict:
|
||||
if not asset_info_row:
|
||||
info = session.get(AssetInfo, asset_info_id)
|
||||
if not info:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
|
||||
norm = normalize_tags(tags)
|
||||
if not norm:
|
||||
total = get_asset_tags(session, asset_info_id=asset_info_id)
|
||||
return {"added": [], "already_present": [], "total_tags": total}
|
||||
|
||||
if create_if_missing:
|
||||
ensure_tags_exist(session, norm, tag_type="user")
|
||||
|
||||
current = {
|
||||
tag_name
|
||||
for (tag_name,) in (
|
||||
session.execute(
|
||||
sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
|
||||
)
|
||||
).all()
|
||||
}
|
||||
|
||||
want = set(norm)
|
||||
to_add = sorted(want - current)
|
||||
|
||||
if to_add:
|
||||
with session.begin_nested() as nested:
|
||||
try:
|
||||
session.add_all(
|
||||
[
|
||||
AssetInfoTag(
|
||||
asset_info_id=asset_info_id,
|
||||
tag_name=t,
|
||||
origin=origin,
|
||||
added_at=utcnow(),
|
||||
)
|
||||
for t in to_add
|
||||
]
|
||||
)
|
||||
session.flush()
|
||||
except IntegrityError:
|
||||
nested.rollback()
|
||||
|
||||
after = set(get_asset_tags(session, asset_info_id=asset_info_id))
|
||||
return {
|
||||
"added": sorted(((after - current) & want)),
|
||||
"already_present": sorted(want & current),
|
||||
"total_tags": sorted(after),
|
||||
}
|
||||
|
||||
|
||||
def remove_tags_from_asset_info(
|
||||
session: Session,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
tags: Sequence[str],
|
||||
) -> dict:
|
||||
info = session.get(AssetInfo, asset_info_id)
|
||||
if not info:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
|
||||
norm = normalize_tags(tags)
|
||||
if not norm:
|
||||
total = get_asset_tags(session, asset_info_id=asset_info_id)
|
||||
return {"removed": [], "not_present": [], "total_tags": total}
|
||||
|
||||
existing = {
|
||||
tag_name
|
||||
for (tag_name,) in (
|
||||
session.execute(
|
||||
sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
|
||||
)
|
||||
).all()
|
||||
}
|
||||
|
||||
to_remove = sorted(set(t for t in norm if t in existing))
|
||||
not_present = sorted(set(t for t in norm if t not in existing))
|
||||
|
||||
if to_remove:
|
||||
session.execute(
|
||||
delete(AssetInfoTag)
|
||||
.where(
|
||||
AssetInfoTag.asset_info_id == asset_info_id,
|
||||
AssetInfoTag.tag_name.in_(to_remove),
|
||||
)
|
||||
)
|
||||
session.flush()
|
||||
|
||||
total = get_asset_tags(session, asset_info_id=asset_info_id)
|
||||
return {"removed": to_remove, "not_present": not_present, "total_tags": total}
|
||||
|
||||
|
||||
def remove_missing_tag_for_asset_id(
|
||||
session: Session,
|
||||
*,
|
||||
asset_id: str,
|
||||
) -> None:
|
||||
session.execute(
|
||||
sa.delete(AssetInfoTag).where(
|
||||
AssetInfoTag.asset_info_id.in_(sa.select(AssetInfo.id).where(AssetInfo.asset_id == asset_id)),
|
||||
AssetInfoTag.tag_name == "missing",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def set_asset_info_preview(
|
||||
session: Session,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
preview_asset_id: str | None = None,
|
||||
) -> None:
|
||||
"""Set or clear preview_id and bump updated_at. Raises on unknown IDs."""
|
||||
info = session.get(AssetInfo, asset_info_id)
|
||||
if not info:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
|
||||
if preview_asset_id is None:
|
||||
info.preview_id = None
|
||||
else:
|
||||
# validate preview asset exists
|
||||
if not session.get(Asset, preview_asset_id):
|
||||
raise ValueError(f"Preview Asset {preview_asset_id} not found")
|
||||
info.preview_id = preview_asset_id
|
||||
|
||||
info.updated_at = utcnow()
|
||||
session.flush()
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import contextlib
|
||||
import os
|
||||
from decimal import Decimal
|
||||
from aiohttp import web
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
@@ -87,6 +88,40 @@ def get_comfy_models_folders() -> list[tuple[str, list[str]]]:
|
||||
targets.append((name, paths))
|
||||
return targets
|
||||
|
||||
def resolve_destination_from_tags(tags: list[str]) -> tuple[str, list[str]]:
|
||||
"""Validates and maps tags -> (base_dir, subdirs_for_fs)"""
|
||||
root = tags[0]
|
||||
if root == "models":
|
||||
if len(tags) < 2:
|
||||
raise ValueError("at least two tags required for model asset")
|
||||
try:
|
||||
bases = folder_paths.folder_names_and_paths[tags[1]][0]
|
||||
except KeyError:
|
||||
raise ValueError(f"unknown model category '{tags[1]}'")
|
||||
if not bases:
|
||||
raise ValueError(f"no base path configured for category '{tags[1]}'")
|
||||
base_dir = os.path.abspath(bases[0])
|
||||
raw_subdirs = tags[2:]
|
||||
else:
|
||||
base_dir = os.path.abspath(
|
||||
folder_paths.get_input_directory() if root == "input" else folder_paths.get_output_directory()
|
||||
)
|
||||
raw_subdirs = tags[1:]
|
||||
for i in raw_subdirs:
|
||||
if i in (".", ".."):
|
||||
raise ValueError("invalid path component in tags")
|
||||
|
||||
return base_dir, raw_subdirs if raw_subdirs else []
|
||||
|
||||
def ensure_within_base(candidate: str, base: str) -> None:
|
||||
cand_abs = os.path.abspath(candidate)
|
||||
base_abs = os.path.abspath(base)
|
||||
try:
|
||||
if os.path.commonpath([cand_abs, base_abs]) != base_abs:
|
||||
raise ValueError("destination escapes base directory")
|
||||
except Exception:
|
||||
raise ValueError("invalid destination path")
|
||||
|
||||
def compute_relative_filename(file_path: str) -> str | None:
|
||||
"""
|
||||
Return the model's path relative to the last well-known folder (the model category),
|
||||
@@ -113,7 +148,6 @@ def compute_relative_filename(file_path: str) -> str | None:
|
||||
return "/".join(inside)
|
||||
return "/".join(parts) # input/output: keep all parts
|
||||
|
||||
|
||||
def get_relative_to_root_category_path_of_asset(file_path: str) -> tuple[Literal["input", "output", "models"], str]:
|
||||
"""Given an absolute or relative file path, determine which root category the path belongs to:
|
||||
- 'input' if the file resides under `folder_paths.get_input_directory()`
|
||||
@@ -215,3 +249,64 @@ def collect_models_files() -> list[str]:
|
||||
if allowed:
|
||||
out.append(abs_path)
|
||||
return out
|
||||
|
||||
def is_scalar(v):
|
||||
if v is None:
|
||||
return True
|
||||
if isinstance(v, bool):
|
||||
return True
|
||||
if isinstance(v, (int, float, Decimal, str)):
|
||||
return True
|
||||
return False
|
||||
|
||||
def project_kv(key: str, value):
|
||||
"""
|
||||
Turn a metadata key/value into typed projection rows.
|
||||
Returns list[dict] with keys:
|
||||
key, ordinal, and one of val_str / val_num / val_bool / val_json (others None)
|
||||
"""
|
||||
rows: list[dict] = []
|
||||
|
||||
def _null_row(ordinal: int) -> dict:
|
||||
return {
|
||||
"key": key, "ordinal": ordinal,
|
||||
"val_str": None, "val_num": None, "val_bool": None, "val_json": None
|
||||
}
|
||||
|
||||
if value is None:
|
||||
rows.append(_null_row(0))
|
||||
return rows
|
||||
|
||||
if is_scalar(value):
|
||||
if isinstance(value, bool):
|
||||
rows.append({"key": key, "ordinal": 0, "val_bool": bool(value)})
|
||||
elif isinstance(value, (int, float, Decimal)):
|
||||
num = value if isinstance(value, Decimal) else Decimal(str(value))
|
||||
rows.append({"key": key, "ordinal": 0, "val_num": num})
|
||||
elif isinstance(value, str):
|
||||
rows.append({"key": key, "ordinal": 0, "val_str": value})
|
||||
else:
|
||||
rows.append({"key": key, "ordinal": 0, "val_json": value})
|
||||
return rows
|
||||
|
||||
if isinstance(value, list):
|
||||
if all(is_scalar(x) for x in value):
|
||||
for i, x in enumerate(value):
|
||||
if x is None:
|
||||
rows.append(_null_row(i))
|
||||
elif isinstance(x, bool):
|
||||
rows.append({"key": key, "ordinal": i, "val_bool": bool(x)})
|
||||
elif isinstance(x, (int, float, Decimal)):
|
||||
num = x if isinstance(x, Decimal) else Decimal(str(x))
|
||||
rows.append({"key": key, "ordinal": i, "val_num": num})
|
||||
elif isinstance(x, str):
|
||||
rows.append({"key": key, "ordinal": i, "val_str": x})
|
||||
else:
|
||||
rows.append({"key": key, "ordinal": i, "val_json": x})
|
||||
return rows
|
||||
for i, x in enumerate(value):
|
||||
rows.append({"key": key, "ordinal": i, "val_json": x})
|
||||
return rows
|
||||
|
||||
rows.append({"key": key, "ordinal": 0, "val_json": value})
|
||||
return rows
|
||||
|
||||
@@ -1,13 +1,33 @@
|
||||
import os
|
||||
import mimetypes
|
||||
import contextlib
|
||||
from typing import Sequence
|
||||
|
||||
from app.database.db import create_session
|
||||
from app.assets.api import schemas_out
|
||||
from app.assets.api import schemas_out, schemas_in
|
||||
from app.assets.database.queries import (
|
||||
asset_exists_by_hash,
|
||||
asset_info_exists_for_asset_id,
|
||||
get_asset_by_hash,
|
||||
get_asset_info_by_id,
|
||||
fetch_asset_info_asset_and_tags,
|
||||
fetch_asset_info_and_asset,
|
||||
create_asset_info_for_existing_asset,
|
||||
touch_asset_info_by_id,
|
||||
update_asset_info_full,
|
||||
delete_asset_info_by_id,
|
||||
list_cache_states_by_asset_id,
|
||||
list_asset_infos_page,
|
||||
list_tags_with_usage,
|
||||
get_asset_tags,
|
||||
add_tags_to_asset_info,
|
||||
remove_tags_from_asset_info,
|
||||
pick_best_live_path,
|
||||
ingest_fs_asset,
|
||||
set_asset_info_preview,
|
||||
)
|
||||
from app.assets.helpers import resolve_destination_from_tags, ensure_within_base
|
||||
from app.assets.database.models import Asset
|
||||
|
||||
|
||||
def _safe_sort_field(requested: str | None) -> str:
|
||||
@@ -19,11 +39,28 @@ def _safe_sort_field(requested: str | None) -> str:
|
||||
return "created_at"
|
||||
|
||||
|
||||
def asset_exists(asset_hash: str) -> bool:
|
||||
def _get_size_mtime_ns(path: str) -> tuple[int, int]:
|
||||
st = os.stat(path, follow_symlinks=True)
|
||||
return st.st_size, getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000))
|
||||
|
||||
|
||||
def _safe_filename(name: str | None, fallback: str) -> str:
|
||||
n = os.path.basename((name or "").strip() or fallback)
|
||||
if n:
|
||||
return n
|
||||
return fallback
|
||||
|
||||
|
||||
def asset_exists(*, asset_hash: str) -> bool:
|
||||
"""
|
||||
Check if an asset with a given hash exists in database.
|
||||
"""
|
||||
with create_session() as session:
|
||||
return asset_exists_by_hash(session, asset_hash=asset_hash)
|
||||
|
||||
|
||||
def list_assets(
|
||||
*,
|
||||
include_tags: Sequence[str] | None = None,
|
||||
exclude_tags: Sequence[str] | None = None,
|
||||
name_contains: str | None = None,
|
||||
@@ -63,7 +100,6 @@ def list_assets(
|
||||
size=int(asset.size_bytes) if asset else None,
|
||||
mime_type=asset.mime_type if asset else None,
|
||||
tags=tags,
|
||||
preview_url=f"/api/assets/{info.id}/content",
|
||||
created_at=info.created_at,
|
||||
updated_at=info.updated_at,
|
||||
last_access_time=info.last_access_time,
|
||||
@@ -76,7 +112,12 @@ def list_assets(
|
||||
has_more=(offset + len(summaries)) < total,
|
||||
)
|
||||
|
||||
def get_asset(asset_info_id: str, owner_id: str = "") -> schemas_out.AssetDetail:
|
||||
|
||||
def get_asset(
|
||||
*,
|
||||
asset_info_id: str,
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.AssetDetail:
|
||||
with create_session() as session:
|
||||
res = fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id)
|
||||
if not res:
|
||||
@@ -97,6 +138,358 @@ def get_asset(asset_info_id: str, owner_id: str = "") -> schemas_out.AssetDetail
|
||||
last_access_time=info.last_access_time,
|
||||
)
|
||||
|
||||
|
||||
def resolve_asset_content_for_download(
|
||||
*,
|
||||
asset_info_id: str,
|
||||
owner_id: str = "",
|
||||
) -> tuple[str, str, str]:
|
||||
with create_session() as session:
|
||||
pair = fetch_asset_info_and_asset(session, asset_info_id=asset_info_id, owner_id=owner_id)
|
||||
if not pair:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
|
||||
info, asset = pair
|
||||
states = list_cache_states_by_asset_id(session, asset_id=asset.id)
|
||||
abs_path = pick_best_live_path(states)
|
||||
if not abs_path:
|
||||
raise FileNotFoundError
|
||||
|
||||
touch_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||
session.commit()
|
||||
|
||||
ctype = asset.mime_type or mimetypes.guess_type(info.name or abs_path)[0] or "application/octet-stream"
|
||||
download_name = info.name or os.path.basename(abs_path)
|
||||
return abs_path, ctype, download_name
|
||||
|
||||
|
||||
def upload_asset_from_temp_path(
|
||||
spec: schemas_in.UploadAssetSpec,
|
||||
*,
|
||||
temp_path: str,
|
||||
client_filename: str | None = None,
|
||||
owner_id: str = "",
|
||||
expected_asset_hash: str | None = None,
|
||||
) -> schemas_out.AssetCreated:
|
||||
"""
|
||||
Create new asset or update existing asset from a temporary file path.
|
||||
"""
|
||||
try:
|
||||
# NOTE: blake3 is not required right now, so this will fail if blake3 is not installed in local environment
|
||||
import app.assets.hashing as hashing
|
||||
digest = hashing.blake3_hash(temp_path)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"failed to hash uploaded file: {e}")
|
||||
asset_hash = "blake3:" + digest
|
||||
|
||||
if expected_asset_hash and asset_hash != expected_asset_hash.strip().lower():
|
||||
raise ValueError("HASH_MISMATCH")
|
||||
|
||||
with create_session() as session:
|
||||
existing = get_asset_by_hash(session, asset_hash=asset_hash)
|
||||
if existing is not None:
|
||||
with contextlib.suppress(Exception):
|
||||
if temp_path and os.path.exists(temp_path):
|
||||
os.remove(temp_path)
|
||||
|
||||
display_name = _safe_filename(spec.name or (client_filename or ""), fallback=digest)
|
||||
info = create_asset_info_for_existing_asset(
|
||||
session,
|
||||
asset_hash=asset_hash,
|
||||
name=display_name,
|
||||
user_metadata=spec.user_metadata or {},
|
||||
tags=spec.tags or [],
|
||||
tag_origin="manual",
|
||||
owner_id=owner_id,
|
||||
)
|
||||
tag_names = get_asset_tags(session, asset_info_id=info.id)
|
||||
session.commit()
|
||||
|
||||
return schemas_out.AssetCreated(
|
||||
id=info.id,
|
||||
name=info.name,
|
||||
asset_hash=existing.hash,
|
||||
size=int(existing.size_bytes) if existing.size_bytes is not None else None,
|
||||
mime_type=existing.mime_type,
|
||||
tags=tag_names,
|
||||
user_metadata=info.user_metadata or {},
|
||||
preview_id=info.preview_id,
|
||||
created_at=info.created_at,
|
||||
last_access_time=info.last_access_time,
|
||||
created_new=False,
|
||||
)
|
||||
|
||||
base_dir, subdirs = resolve_destination_from_tags(spec.tags)
|
||||
dest_dir = os.path.join(base_dir, *subdirs) if subdirs else base_dir
|
||||
os.makedirs(dest_dir, exist_ok=True)
|
||||
|
||||
src_for_ext = (client_filename or spec.name or "").strip()
|
||||
_ext = os.path.splitext(os.path.basename(src_for_ext))[1] if src_for_ext else ""
|
||||
ext = _ext if 0 < len(_ext) <= 16 else ""
|
||||
hashed_basename = f"{digest}{ext}"
|
||||
dest_abs = os.path.abspath(os.path.join(dest_dir, hashed_basename))
|
||||
ensure_within_base(dest_abs, base_dir)
|
||||
|
||||
content_type = (
|
||||
mimetypes.guess_type(os.path.basename(src_for_ext), strict=False)[0]
|
||||
or mimetypes.guess_type(hashed_basename, strict=False)[0]
|
||||
or "application/octet-stream"
|
||||
)
|
||||
|
||||
try:
|
||||
os.replace(temp_path, dest_abs)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"failed to move uploaded file into place: {e}")
|
||||
|
||||
try:
|
||||
size_bytes, mtime_ns = _get_size_mtime_ns(dest_abs)
|
||||
except OSError as e:
|
||||
raise RuntimeError(f"failed to stat destination file: {e}")
|
||||
|
||||
with create_session() as session:
|
||||
result = ingest_fs_asset(
|
||||
session,
|
||||
asset_hash=asset_hash,
|
||||
abs_path=dest_abs,
|
||||
size_bytes=size_bytes,
|
||||
mtime_ns=mtime_ns,
|
||||
mime_type=content_type,
|
||||
info_name=_safe_filename(spec.name or (client_filename or ""), fallback=digest),
|
||||
owner_id=owner_id,
|
||||
preview_id=None,
|
||||
user_metadata=spec.user_metadata or {},
|
||||
tags=spec.tags,
|
||||
tag_origin="manual",
|
||||
require_existing_tags=False,
|
||||
)
|
||||
info_id = result["asset_info_id"]
|
||||
if not info_id:
|
||||
raise RuntimeError("failed to create asset metadata")
|
||||
|
||||
pair = fetch_asset_info_and_asset(session, asset_info_id=info_id, owner_id=owner_id)
|
||||
if not pair:
|
||||
raise RuntimeError("inconsistent DB state after ingest")
|
||||
info, asset = pair
|
||||
tag_names = get_asset_tags(session, asset_info_id=info.id)
|
||||
created_result = schemas_out.AssetCreated(
|
||||
id=info.id,
|
||||
name=info.name,
|
||||
asset_hash=asset.hash,
|
||||
size=int(asset.size_bytes),
|
||||
mime_type=asset.mime_type,
|
||||
tags=tag_names,
|
||||
user_metadata=info.user_metadata or {},
|
||||
preview_id=info.preview_id,
|
||||
created_at=info.created_at,
|
||||
last_access_time=info.last_access_time,
|
||||
created_new=result["asset_created"],
|
||||
)
|
||||
session.commit()
|
||||
|
||||
return created_result
|
||||
|
||||
|
||||
def update_asset(
|
||||
*,
|
||||
asset_info_id: str,
|
||||
name: str | None = None,
|
||||
tags: list[str] | None = None,
|
||||
user_metadata: dict | None = None,
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.AssetUpdated:
|
||||
with create_session() as session:
|
||||
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||
if not info_row:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
if info_row.owner_id and info_row.owner_id != owner_id:
|
||||
raise PermissionError("not owner")
|
||||
|
||||
info = update_asset_info_full(
|
||||
session,
|
||||
asset_info_id=asset_info_id,
|
||||
name=name,
|
||||
tags=tags,
|
||||
user_metadata=user_metadata,
|
||||
tag_origin="manual",
|
||||
asset_info_row=info_row,
|
||||
)
|
||||
|
||||
tag_names = get_asset_tags(session, asset_info_id=asset_info_id)
|
||||
result = schemas_out.AssetUpdated(
|
||||
id=info.id,
|
||||
name=info.name,
|
||||
asset_hash=info.asset.hash if info.asset else None,
|
||||
tags=tag_names,
|
||||
user_metadata=info.user_metadata or {},
|
||||
updated_at=info.updated_at,
|
||||
)
|
||||
session.commit()
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def set_asset_preview(
|
||||
*,
|
||||
asset_info_id: str,
|
||||
preview_asset_id: str | None = None,
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.AssetDetail:
|
||||
with create_session() as session:
|
||||
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||
if not info_row:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
if info_row.owner_id and info_row.owner_id != owner_id:
|
||||
raise PermissionError("not owner")
|
||||
|
||||
set_asset_info_preview(
|
||||
session,
|
||||
asset_info_id=asset_info_id,
|
||||
preview_asset_id=preview_asset_id,
|
||||
)
|
||||
|
||||
res = fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id)
|
||||
if not res:
|
||||
raise RuntimeError("State changed during preview update")
|
||||
info, asset, tags = res
|
||||
result = schemas_out.AssetDetail(
|
||||
id=info.id,
|
||||
name=info.name,
|
||||
asset_hash=asset.hash if asset else None,
|
||||
size=int(asset.size_bytes) if asset and asset.size_bytes is not None else None,
|
||||
mime_type=asset.mime_type if asset else None,
|
||||
tags=tags,
|
||||
user_metadata=info.user_metadata or {},
|
||||
preview_id=info.preview_id,
|
||||
created_at=info.created_at,
|
||||
last_access_time=info.last_access_time,
|
||||
)
|
||||
session.commit()
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def delete_asset_reference(*, asset_info_id: str, owner_id: str, delete_content_if_orphan: bool = True) -> bool:
|
||||
with create_session() as session:
|
||||
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||
asset_id = info_row.asset_id if info_row else None
|
||||
deleted = delete_asset_info_by_id(session, asset_info_id=asset_info_id, owner_id=owner_id)
|
||||
if not deleted:
|
||||
session.commit()
|
||||
return False
|
||||
|
||||
if not delete_content_if_orphan or not asset_id:
|
||||
session.commit()
|
||||
return True
|
||||
|
||||
still_exists = asset_info_exists_for_asset_id(session, asset_id=asset_id)
|
||||
if still_exists:
|
||||
session.commit()
|
||||
return True
|
||||
|
||||
states = list_cache_states_by_asset_id(session, asset_id=asset_id)
|
||||
file_paths = [s.file_path for s in (states or []) if getattr(s, "file_path", None)]
|
||||
|
||||
asset_row = session.get(Asset, asset_id)
|
||||
if asset_row is not None:
|
||||
session.delete(asset_row)
|
||||
|
||||
session.commit()
|
||||
for p in file_paths:
|
||||
with contextlib.suppress(Exception):
|
||||
if p and os.path.isfile(p):
|
||||
os.remove(p)
|
||||
return True
|
||||
|
||||
|
||||
def create_asset_from_hash(
|
||||
*,
|
||||
hash_str: str,
|
||||
name: str,
|
||||
tags: list[str] | None = None,
|
||||
user_metadata: dict | None = None,
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.AssetCreated | None:
|
||||
canonical = hash_str.strip().lower()
|
||||
with create_session() as session:
|
||||
asset = get_asset_by_hash(session, asset_hash=canonical)
|
||||
if not asset:
|
||||
return None
|
||||
|
||||
info = create_asset_info_for_existing_asset(
|
||||
session,
|
||||
asset_hash=canonical,
|
||||
name=_safe_filename(name, fallback=canonical.split(":", 1)[1]),
|
||||
user_metadata=user_metadata or {},
|
||||
tags=tags or [],
|
||||
tag_origin="manual",
|
||||
owner_id=owner_id,
|
||||
)
|
||||
tag_names = get_asset_tags(session, asset_info_id=info.id)
|
||||
result = schemas_out.AssetCreated(
|
||||
id=info.id,
|
||||
name=info.name,
|
||||
asset_hash=asset.hash,
|
||||
size=int(asset.size_bytes),
|
||||
mime_type=asset.mime_type,
|
||||
tags=tag_names,
|
||||
user_metadata=info.user_metadata or {},
|
||||
preview_id=info.preview_id,
|
||||
created_at=info.created_at,
|
||||
last_access_time=info.last_access_time,
|
||||
created_new=False,
|
||||
)
|
||||
session.commit()
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def add_tags_to_asset(
|
||||
*,
|
||||
asset_info_id: str,
|
||||
tags: list[str],
|
||||
origin: str = "manual",
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.TagsAdd:
|
||||
with create_session() as session:
|
||||
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||
if not info_row:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
if info_row.owner_id and info_row.owner_id != owner_id:
|
||||
raise PermissionError("not owner")
|
||||
data = add_tags_to_asset_info(
|
||||
session,
|
||||
asset_info_id=asset_info_id,
|
||||
tags=tags,
|
||||
origin=origin,
|
||||
create_if_missing=True,
|
||||
asset_info_row=info_row,
|
||||
)
|
||||
session.commit()
|
||||
return schemas_out.TagsAdd(**data)
|
||||
|
||||
|
||||
def remove_tags_from_asset(
|
||||
*,
|
||||
asset_info_id: str,
|
||||
tags: list[str],
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.TagsRemove:
|
||||
with create_session() as session:
|
||||
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||
if not info_row:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
if info_row.owner_id and info_row.owner_id != owner_id:
|
||||
raise PermissionError("not owner")
|
||||
|
||||
data = remove_tags_from_asset_info(
|
||||
session,
|
||||
asset_info_id=asset_info_id,
|
||||
tags=tags,
|
||||
)
|
||||
session.commit()
|
||||
return schemas_out.TagsRemove(**data)
|
||||
|
||||
|
||||
def list_tags(
|
||||
prefix: str | None = None,
|
||||
limit: int = 100,
|
||||
|
||||
@@ -27,6 +27,7 @@ def seed_assets(roots: tuple[RootType, ...], enable_logging: bool = False) -> No
|
||||
t_start = time.perf_counter()
|
||||
created = 0
|
||||
skipped_existing = 0
|
||||
orphans_pruned = 0
|
||||
paths: list[str] = []
|
||||
try:
|
||||
existing_paths: set[str] = set()
|
||||
@@ -38,6 +39,11 @@ def seed_assets(roots: tuple[RootType, ...], enable_logging: bool = False) -> No
|
||||
except Exception as e:
|
||||
logging.exception("fast DB scan failed for %s: %s", r, e)
|
||||
|
||||
try:
|
||||
orphans_pruned = _prune_orphaned_assets(roots)
|
||||
except Exception as e:
|
||||
logging.exception("orphan pruning failed: %s", e)
|
||||
|
||||
if "models" in roots:
|
||||
paths.extend(collect_models_files())
|
||||
if "input" in roots:
|
||||
@@ -85,15 +91,43 @@ def seed_assets(roots: tuple[RootType, ...], enable_logging: bool = False) -> No
|
||||
finally:
|
||||
if enable_logging:
|
||||
logging.info(
|
||||
"Assets scan(roots=%s) completed in %.3fs (created=%d, skipped_existing=%d, total_seen=%d)",
|
||||
"Assets scan(roots=%s) completed in %.3fs (created=%d, skipped_existing=%d, orphans_pruned=%d, total_seen=%d)",
|
||||
roots,
|
||||
time.perf_counter() - t_start,
|
||||
created,
|
||||
skipped_existing,
|
||||
orphans_pruned,
|
||||
len(paths),
|
||||
)
|
||||
|
||||
|
||||
def _prune_orphaned_assets(roots: tuple[RootType, ...]) -> int:
|
||||
"""Prune cache states outside configured prefixes, then delete orphaned seed assets."""
|
||||
all_prefixes = [os.path.abspath(p) for r in roots for p in prefixes_for_root(r)]
|
||||
if not all_prefixes:
|
||||
return 0
|
||||
|
||||
def make_prefix_condition(prefix: str):
|
||||
base = prefix if prefix.endswith(os.sep) else prefix + os.sep
|
||||
escaped, esc = escape_like_prefix(base)
|
||||
return AssetCacheState.file_path.like(escaped + "%", escape=esc)
|
||||
|
||||
matches_valid_prefix = sqlalchemy.or_(*[make_prefix_condition(p) for p in all_prefixes])
|
||||
|
||||
orphan_subq = (
|
||||
sqlalchemy.select(Asset.id)
|
||||
.outerjoin(AssetCacheState, AssetCacheState.asset_id == Asset.id)
|
||||
.where(Asset.hash.is_(None), AssetCacheState.id.is_(None))
|
||||
).scalar_subquery()
|
||||
|
||||
with create_session() as sess:
|
||||
sess.execute(sqlalchemy.delete(AssetCacheState).where(~matches_valid_prefix))
|
||||
sess.execute(sqlalchemy.delete(AssetInfo).where(AssetInfo.asset_id.in_(orphan_subq)))
|
||||
result = sess.execute(sqlalchemy.delete(Asset).where(Asset.id.in_(orphan_subq)))
|
||||
sess.commit()
|
||||
return result.rowcount
|
||||
|
||||
|
||||
def _fast_db_consistency_pass(
|
||||
root: RootType,
|
||||
*,
|
||||
|
||||
107
app/node_replace_manager.py
Normal file
107
app/node_replace_manager.py
Normal file
@@ -0,0 +1,107 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from aiohttp import web
|
||||
|
||||
from typing import TYPE_CHECKING, TypedDict
|
||||
if TYPE_CHECKING:
|
||||
from comfy_api.latest._io_public import NodeReplace
|
||||
|
||||
from comfy_execution.graph_utils import is_link
|
||||
import nodes
|
||||
|
||||
class NodeStruct(TypedDict):
|
||||
inputs: dict[str, str | int | float | bool | tuple[str, int]]
|
||||
class_type: str
|
||||
_meta: dict[str, str]
|
||||
|
||||
def copy_node_struct(node_struct: NodeStruct, empty_inputs: bool = False) -> NodeStruct:
|
||||
new_node_struct = node_struct.copy()
|
||||
if empty_inputs:
|
||||
new_node_struct["inputs"] = {}
|
||||
else:
|
||||
new_node_struct["inputs"] = node_struct["inputs"].copy()
|
||||
new_node_struct["_meta"] = node_struct["_meta"].copy()
|
||||
return new_node_struct
|
||||
|
||||
|
||||
class NodeReplaceManager:
|
||||
"""Manages node replacement registrations."""
|
||||
|
||||
def __init__(self):
|
||||
self._replacements: dict[str, list[NodeReplace]] = {}
|
||||
|
||||
def register(self, node_replace: NodeReplace):
|
||||
"""Register a node replacement mapping."""
|
||||
self._replacements.setdefault(node_replace.old_node_id, []).append(node_replace)
|
||||
|
||||
def get_replacement(self, old_node_id: str) -> list[NodeReplace] | None:
|
||||
"""Get replacements for an old node ID."""
|
||||
return self._replacements.get(old_node_id)
|
||||
|
||||
def has_replacement(self, old_node_id: str) -> bool:
|
||||
"""Check if a replacement exists for an old node ID."""
|
||||
return old_node_id in self._replacements
|
||||
|
||||
def apply_replacements(self, prompt: dict[str, NodeStruct]):
|
||||
connections: dict[str, list[tuple[str, str, int]]] = {}
|
||||
need_replacement: set[str] = set()
|
||||
for node_number, node_struct in prompt.items():
|
||||
if "class_type" not in node_struct or "inputs" not in node_struct:
|
||||
continue
|
||||
class_type = node_struct["class_type"]
|
||||
# need replacement if not in NODE_CLASS_MAPPINGS and has replacement
|
||||
if class_type not in nodes.NODE_CLASS_MAPPINGS.keys() and self.has_replacement(class_type):
|
||||
need_replacement.add(node_number)
|
||||
# keep track of connections
|
||||
for input_id, input_value in node_struct["inputs"].items():
|
||||
if is_link(input_value):
|
||||
conn_number = input_value[0]
|
||||
connections.setdefault(conn_number, []).append((node_number, input_id, input_value[1]))
|
||||
for node_number in need_replacement:
|
||||
node_struct = prompt[node_number]
|
||||
class_type = node_struct["class_type"]
|
||||
replacements = self.get_replacement(class_type)
|
||||
if replacements is None:
|
||||
continue
|
||||
# just use the first replacement
|
||||
replacement = replacements[0]
|
||||
new_node_id = replacement.new_node_id
|
||||
# if replacement is not a valid node, skip trying to replace it as will only cause confusion
|
||||
if new_node_id not in nodes.NODE_CLASS_MAPPINGS.keys():
|
||||
continue
|
||||
# first, replace node id (class_type)
|
||||
new_node_struct = copy_node_struct(node_struct, empty_inputs=True)
|
||||
new_node_struct["class_type"] = new_node_id
|
||||
# TODO: consider replacing display_name in _meta as well for error reporting purposes; would need to query node schema
|
||||
# second, replace inputs
|
||||
if replacement.input_mapping is not None:
|
||||
for input_map in replacement.input_mapping:
|
||||
if "set_value" in input_map:
|
||||
new_node_struct["inputs"][input_map["new_id"]] = input_map["set_value"]
|
||||
elif "old_id" in input_map:
|
||||
new_node_struct["inputs"][input_map["new_id"]] = node_struct["inputs"][input_map["old_id"]]
|
||||
# finalize input replacement
|
||||
prompt[node_number] = new_node_struct
|
||||
# third, replace outputs
|
||||
if replacement.output_mapping is not None:
|
||||
# re-mapping outputs requires changing the input values of nodes that receive connections from this one
|
||||
if node_number in connections:
|
||||
for conns in connections[node_number]:
|
||||
conn_node_number, conn_input_id, old_output_idx = conns
|
||||
for output_map in replacement.output_mapping:
|
||||
if output_map["old_idx"] == old_output_idx:
|
||||
new_output_idx = output_map["new_idx"]
|
||||
previous_input = prompt[conn_node_number]["inputs"][conn_input_id]
|
||||
previous_input[1] = new_output_idx
|
||||
|
||||
def as_dict(self):
|
||||
"""Serialize all replacements to dict."""
|
||||
return {
|
||||
k: [v.as_dict() for v in v_list]
|
||||
for k, v_list in self._replacements.items()
|
||||
}
|
||||
|
||||
def add_routes(self, routes):
|
||||
@routes.get("/node_replacements")
|
||||
async def get_node_replacements(request):
|
||||
return web.json_response(self.as_dict())
|
||||
@@ -10,6 +10,7 @@ import hashlib
|
||||
|
||||
class Source:
|
||||
custom_node = "custom_node"
|
||||
templates = "templates"
|
||||
|
||||
class SubgraphEntry(TypedDict):
|
||||
source: str
|
||||
@@ -38,9 +39,21 @@ class CustomNodeSubgraphEntryInfo(TypedDict):
|
||||
class SubgraphManager:
|
||||
def __init__(self):
|
||||
self.cached_custom_node_subgraphs: dict[SubgraphEntry] | None = None
|
||||
self.cached_blueprint_subgraphs: dict[SubgraphEntry] | None = None
|
||||
|
||||
def _create_entry(self, file: str, source: str, node_pack: str) -> tuple[str, SubgraphEntry]:
|
||||
"""Create a subgraph entry from a file path. Expects normalized path (forward slashes)."""
|
||||
entry_id = hashlib.sha256(f"{source}{file}".encode()).hexdigest()
|
||||
entry: SubgraphEntry = {
|
||||
"source": source,
|
||||
"name": os.path.splitext(os.path.basename(file))[0],
|
||||
"path": file,
|
||||
"info": {"node_pack": node_pack},
|
||||
}
|
||||
return entry_id, entry
|
||||
|
||||
async def load_entry_data(self, entry: SubgraphEntry):
|
||||
with open(entry['path'], 'r') as f:
|
||||
with open(entry['path'], 'r', encoding='utf-8') as f:
|
||||
entry['data'] = f.read()
|
||||
return entry
|
||||
|
||||
@@ -60,53 +73,60 @@ class SubgraphManager:
|
||||
return entries
|
||||
|
||||
async def get_custom_node_subgraphs(self, loadedModules, force_reload=False):
|
||||
# if not forced to reload and cached, return cache
|
||||
"""Load subgraphs from custom nodes."""
|
||||
if not force_reload and self.cached_custom_node_subgraphs is not None:
|
||||
return self.cached_custom_node_subgraphs
|
||||
# Load subgraphs from custom nodes
|
||||
subfolder = "subgraphs"
|
||||
subgraphs_dict: dict[SubgraphEntry] = {}
|
||||
|
||||
subgraphs_dict: dict[SubgraphEntry] = {}
|
||||
for folder in folder_paths.get_folder_paths("custom_nodes"):
|
||||
pattern = os.path.join(folder, f"*/{subfolder}/*.json")
|
||||
matched_files = glob.glob(pattern)
|
||||
for file in matched_files:
|
||||
# replace backslashes with forward slashes
|
||||
pattern = os.path.join(folder, "*/subgraphs/*.json")
|
||||
for file in glob.glob(pattern):
|
||||
file = file.replace('\\', '/')
|
||||
info: CustomNodeSubgraphEntryInfo = {
|
||||
"node_pack": "custom_nodes." + file.split('/')[-3]
|
||||
}
|
||||
source = Source.custom_node
|
||||
# hash source + path to make sure id will be as unique as possible, but
|
||||
# reproducible across backend reloads
|
||||
id = hashlib.sha256(f"{source}{file}".encode()).hexdigest()
|
||||
entry: SubgraphEntry = {
|
||||
"source": Source.custom_node,
|
||||
"name": os.path.splitext(os.path.basename(file))[0],
|
||||
"path": file,
|
||||
"info": info,
|
||||
}
|
||||
subgraphs_dict[id] = entry
|
||||
node_pack = "custom_nodes." + file.split('/')[-3]
|
||||
entry_id, entry = self._create_entry(file, Source.custom_node, node_pack)
|
||||
subgraphs_dict[entry_id] = entry
|
||||
|
||||
self.cached_custom_node_subgraphs = subgraphs_dict
|
||||
return subgraphs_dict
|
||||
|
||||
async def get_custom_node_subgraph(self, id: str, loadedModules):
|
||||
subgraphs = await self.get_custom_node_subgraphs(loadedModules)
|
||||
entry: SubgraphEntry = subgraphs.get(id, None)
|
||||
if entry is not None and entry.get('data', None) is None:
|
||||
async def get_blueprint_subgraphs(self, force_reload=False):
|
||||
"""Load subgraphs from the blueprints directory."""
|
||||
if not force_reload and self.cached_blueprint_subgraphs is not None:
|
||||
return self.cached_blueprint_subgraphs
|
||||
|
||||
subgraphs_dict: dict[SubgraphEntry] = {}
|
||||
blueprints_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'blueprints')
|
||||
|
||||
if os.path.exists(blueprints_dir):
|
||||
for file in glob.glob(os.path.join(blueprints_dir, "*.json")):
|
||||
file = file.replace('\\', '/')
|
||||
entry_id, entry = self._create_entry(file, Source.templates, "comfyui")
|
||||
subgraphs_dict[entry_id] = entry
|
||||
|
||||
self.cached_blueprint_subgraphs = subgraphs_dict
|
||||
return subgraphs_dict
|
||||
|
||||
async def get_all_subgraphs(self, loadedModules, force_reload=False):
|
||||
"""Get all subgraphs from all sources (custom nodes and blueprints)."""
|
||||
custom_node_subgraphs = await self.get_custom_node_subgraphs(loadedModules, force_reload)
|
||||
blueprint_subgraphs = await self.get_blueprint_subgraphs(force_reload)
|
||||
return {**custom_node_subgraphs, **blueprint_subgraphs}
|
||||
|
||||
async def get_subgraph(self, id: str, loadedModules):
|
||||
"""Get a specific subgraph by ID from any source."""
|
||||
entry = (await self.get_all_subgraphs(loadedModules)).get(id)
|
||||
if entry is not None and entry.get('data') is None:
|
||||
await self.load_entry_data(entry)
|
||||
return entry
|
||||
|
||||
def add_routes(self, routes, loadedModules):
|
||||
@routes.get("/global_subgraphs")
|
||||
async def get_global_subgraphs(request):
|
||||
subgraphs_dict = await self.get_custom_node_subgraphs(loadedModules)
|
||||
# NOTE: we may want to include other sources of global subgraphs such as templates in the future;
|
||||
# that's the reasoning for the current implementation
|
||||
subgraphs_dict = await self.get_all_subgraphs(loadedModules)
|
||||
return web.json_response(await self.sanitize_entries(subgraphs_dict, remove_data=True))
|
||||
|
||||
@routes.get("/global_subgraphs/{id}")
|
||||
async def get_global_subgraph(request):
|
||||
id = request.match_info.get("id", None)
|
||||
subgraph = await self.get_custom_node_subgraph(id, loadedModules)
|
||||
subgraph = await self.get_subgraph(id, loadedModules)
|
||||
return web.json_response(await self.sanitize_entry(subgraph))
|
||||
|
||||
44
blueprints/.glsl/Brightness_and_Contrast_1.frag
Normal file
44
blueprints/.glsl/Brightness_and_Contrast_1.frag
Normal file
@@ -0,0 +1,44 @@
|
||||
#version 300 es
|
||||
precision highp float;
|
||||
|
||||
uniform sampler2D u_image0;
|
||||
uniform float u_float0; // Brightness slider -100..100
|
||||
uniform float u_float1; // Contrast slider -100..100
|
||||
|
||||
in vec2 v_texCoord;
|
||||
out vec4 fragColor;
|
||||
|
||||
const float MID_GRAY = 0.18; // 18% reflectance
|
||||
|
||||
// sRGB gamma 2.2 approximation
|
||||
vec3 srgbToLinear(vec3 c) {
|
||||
return pow(max(c, 0.0), vec3(2.2));
|
||||
}
|
||||
|
||||
vec3 linearToSrgb(vec3 c) {
|
||||
return pow(max(c, 0.0), vec3(1.0/2.2));
|
||||
}
|
||||
|
||||
float mapBrightness(float b) {
|
||||
return clamp(b / 100.0, -1.0, 1.0);
|
||||
}
|
||||
|
||||
float mapContrast(float c) {
|
||||
return clamp(c / 100.0 + 1.0, 0.0, 2.0);
|
||||
}
|
||||
|
||||
void main() {
|
||||
vec4 orig = texture(u_image0, v_texCoord);
|
||||
|
||||
float brightness = mapBrightness(u_float0);
|
||||
float contrast = mapContrast(u_float1);
|
||||
|
||||
vec3 lin = srgbToLinear(orig.rgb);
|
||||
|
||||
lin = (lin - MID_GRAY) * contrast + brightness + MID_GRAY;
|
||||
|
||||
// Convert back to sRGB
|
||||
vec3 result = linearToSrgb(clamp(lin, 0.0, 1.0));
|
||||
|
||||
fragColor = vec4(result, orig.a);
|
||||
}
|
||||
72
blueprints/.glsl/Chromatic_Aberration_16.frag
Normal file
72
blueprints/.glsl/Chromatic_Aberration_16.frag
Normal file
@@ -0,0 +1,72 @@
|
||||
#version 300 es
|
||||
precision highp float;
|
||||
|
||||
uniform sampler2D u_image0;
|
||||
uniform vec2 u_resolution;
|
||||
uniform int u_int0; // Mode
|
||||
uniform float u_float0; // Amount (0 to 100)
|
||||
|
||||
in vec2 v_texCoord;
|
||||
out vec4 fragColor;
|
||||
|
||||
const int MODE_LINEAR = 0;
|
||||
const int MODE_RADIAL = 1;
|
||||
const int MODE_BARREL = 2;
|
||||
const int MODE_SWIRL = 3;
|
||||
const int MODE_DIAGONAL = 4;
|
||||
|
||||
const float AMOUNT_SCALE = 0.0005;
|
||||
const float RADIAL_MULT = 4.0;
|
||||
const float BARREL_MULT = 8.0;
|
||||
const float INV_SQRT2 = 0.70710678118;
|
||||
|
||||
void main() {
|
||||
vec2 uv = v_texCoord;
|
||||
vec4 original = texture(u_image0, uv);
|
||||
|
||||
float amount = u_float0 * AMOUNT_SCALE;
|
||||
|
||||
if (amount < 0.000001) {
|
||||
fragColor = original;
|
||||
return;
|
||||
}
|
||||
|
||||
// Aspect-corrected coordinates for circular effects
|
||||
float aspect = u_resolution.x / u_resolution.y;
|
||||
vec2 centered = uv - 0.5;
|
||||
vec2 corrected = vec2(centered.x * aspect, centered.y);
|
||||
float r = length(corrected);
|
||||
vec2 dir = r > 0.0001 ? corrected / r : vec2(0.0);
|
||||
vec2 offset = vec2(0.0);
|
||||
|
||||
if (u_int0 == MODE_LINEAR) {
|
||||
// Horizontal shift (no aspect correction needed)
|
||||
offset = vec2(amount, 0.0);
|
||||
}
|
||||
else if (u_int0 == MODE_RADIAL) {
|
||||
// Outward from center, stronger at edges
|
||||
offset = dir * r * amount * RADIAL_MULT;
|
||||
offset.x /= aspect; // Convert back to UV space
|
||||
}
|
||||
else if (u_int0 == MODE_BARREL) {
|
||||
// Lens distortion simulation (r² falloff)
|
||||
offset = dir * r * r * amount * BARREL_MULT;
|
||||
offset.x /= aspect; // Convert back to UV space
|
||||
}
|
||||
else if (u_int0 == MODE_SWIRL) {
|
||||
// Perpendicular to radial (rotational aberration)
|
||||
vec2 perp = vec2(-dir.y, dir.x);
|
||||
offset = perp * r * amount * RADIAL_MULT;
|
||||
offset.x /= aspect; // Convert back to UV space
|
||||
}
|
||||
else if (u_int0 == MODE_DIAGONAL) {
|
||||
// 45° offset (no aspect correction needed)
|
||||
offset = vec2(amount, amount) * INV_SQRT2;
|
||||
}
|
||||
|
||||
float red = texture(u_image0, uv + offset).r;
|
||||
float green = original.g;
|
||||
float blue = texture(u_image0, uv - offset).b;
|
||||
|
||||
fragColor = vec4(red, green, blue, original.a);
|
||||
}
|
||||
78
blueprints/.glsl/Color_Adjustment_15.frag
Normal file
78
blueprints/.glsl/Color_Adjustment_15.frag
Normal file
@@ -0,0 +1,78 @@
|
||||
#version 300 es
|
||||
precision highp float;
|
||||
|
||||
uniform sampler2D u_image0;
|
||||
uniform float u_float0; // temperature (-100 to 100)
|
||||
uniform float u_float1; // tint (-100 to 100)
|
||||
uniform float u_float2; // vibrance (-100 to 100)
|
||||
uniform float u_float3; // saturation (-100 to 100)
|
||||
|
||||
in vec2 v_texCoord;
|
||||
out vec4 fragColor;
|
||||
|
||||
const float INPUT_SCALE = 0.01;
|
||||
const float TEMP_TINT_PRIMARY = 0.3;
|
||||
const float TEMP_TINT_SECONDARY = 0.15;
|
||||
const float VIBRANCE_BOOST = 2.0;
|
||||
const float SATURATION_BOOST = 2.0;
|
||||
const float SKIN_PROTECTION = 0.5;
|
||||
const float EPSILON = 0.001;
|
||||
const vec3 LUMA_WEIGHTS = vec3(0.299, 0.587, 0.114);
|
||||
|
||||
void main() {
|
||||
vec4 tex = texture(u_image0, v_texCoord);
|
||||
vec3 color = tex.rgb;
|
||||
|
||||
// Scale inputs: -100/100 → -1/1
|
||||
float temperature = u_float0 * INPUT_SCALE;
|
||||
float tint = u_float1 * INPUT_SCALE;
|
||||
float vibrance = u_float2 * INPUT_SCALE;
|
||||
float saturation = u_float3 * INPUT_SCALE;
|
||||
|
||||
// Temperature (warm/cool): positive = warm, negative = cool
|
||||
color.r += temperature * TEMP_TINT_PRIMARY;
|
||||
color.b -= temperature * TEMP_TINT_PRIMARY;
|
||||
|
||||
// Tint (green/magenta): positive = green, negative = magenta
|
||||
color.g += tint * TEMP_TINT_PRIMARY;
|
||||
color.r -= tint * TEMP_TINT_SECONDARY;
|
||||
color.b -= tint * TEMP_TINT_SECONDARY;
|
||||
|
||||
// Single clamp after temperature/tint
|
||||
color = clamp(color, 0.0, 1.0);
|
||||
|
||||
// Vibrance with skin protection
|
||||
if (vibrance != 0.0) {
|
||||
float maxC = max(color.r, max(color.g, color.b));
|
||||
float minC = min(color.r, min(color.g, color.b));
|
||||
float sat = maxC - minC;
|
||||
float gray = dot(color, LUMA_WEIGHTS);
|
||||
|
||||
if (vibrance < 0.0) {
|
||||
// Desaturate: -100 → gray
|
||||
color = mix(vec3(gray), color, 1.0 + vibrance);
|
||||
} else {
|
||||
// Boost less saturated colors more
|
||||
float vibranceAmt = vibrance * (1.0 - sat);
|
||||
|
||||
// Branchless skin tone protection
|
||||
float isWarmTone = step(color.b, color.g) * step(color.g, color.r);
|
||||
float warmth = (color.r - color.b) / max(maxC, EPSILON);
|
||||
float skinTone = isWarmTone * warmth * sat * (1.0 - sat);
|
||||
vibranceAmt *= (1.0 - skinTone * SKIN_PROTECTION);
|
||||
|
||||
color = mix(vec3(gray), color, 1.0 + vibranceAmt * VIBRANCE_BOOST);
|
||||
}
|
||||
}
|
||||
|
||||
// Saturation
|
||||
if (saturation != 0.0) {
|
||||
float gray = dot(color, LUMA_WEIGHTS);
|
||||
float satMix = saturation < 0.0
|
||||
? 1.0 + saturation // -100 → gray
|
||||
: 1.0 + saturation * SATURATION_BOOST; // +100 → 3x boost
|
||||
color = mix(vec3(gray), color, satMix);
|
||||
}
|
||||
|
||||
fragColor = vec4(clamp(color, 0.0, 1.0), tex.a);
|
||||
}
|
||||
94
blueprints/.glsl/Edge-Preserving_Blur_128.frag
Normal file
94
blueprints/.glsl/Edge-Preserving_Blur_128.frag
Normal file
@@ -0,0 +1,94 @@
|
||||
#version 300 es
|
||||
precision highp float;
|
||||
|
||||
uniform sampler2D u_image0;
|
||||
uniform float u_float0; // Blur radius (0–20, default ~5)
|
||||
uniform float u_float1; // Edge threshold (0–100, default ~30)
|
||||
uniform int u_int0; // Step size (0/1 = every pixel, 2+ = skip pixels)
|
||||
|
||||
in vec2 v_texCoord;
|
||||
out vec4 fragColor;
|
||||
|
||||
const int MAX_RADIUS = 20;
|
||||
const float EPSILON = 0.0001;
|
||||
|
||||
// Perceptual luminance
|
||||
float getLuminance(vec3 rgb) {
|
||||
return dot(rgb, vec3(0.299, 0.587, 0.114));
|
||||
}
|
||||
|
||||
vec4 bilateralFilter(vec2 uv, vec2 texelSize, int radius,
|
||||
float sigmaSpatial, float sigmaColor)
|
||||
{
|
||||
vec4 center = texture(u_image0, uv);
|
||||
vec3 centerRGB = center.rgb;
|
||||
|
||||
float invSpatial2 = -0.5 / (sigmaSpatial * sigmaSpatial);
|
||||
float invColor2 = -0.5 / (sigmaColor * sigmaColor + EPSILON);
|
||||
|
||||
vec3 sumRGB = vec3(0.0);
|
||||
float sumWeight = 0.0;
|
||||
|
||||
int step = max(u_int0, 1);
|
||||
float radius2 = float(radius * radius);
|
||||
|
||||
for (int dy = -MAX_RADIUS; dy <= MAX_RADIUS; dy++) {
|
||||
if (dy < -radius || dy > radius) continue;
|
||||
if (abs(dy) % step != 0) continue;
|
||||
|
||||
for (int dx = -MAX_RADIUS; dx <= MAX_RADIUS; dx++) {
|
||||
if (dx < -radius || dx > radius) continue;
|
||||
if (abs(dx) % step != 0) continue;
|
||||
|
||||
vec2 offset = vec2(float(dx), float(dy));
|
||||
float dist2 = dot(offset, offset);
|
||||
if (dist2 > radius2) continue;
|
||||
|
||||
vec3 sampleRGB = texture(u_image0, uv + offset * texelSize).rgb;
|
||||
|
||||
// Spatial Gaussian
|
||||
float spatialWeight = exp(dist2 * invSpatial2);
|
||||
|
||||
// Perceptual color distance (weighted RGB)
|
||||
vec3 diff = sampleRGB - centerRGB;
|
||||
float colorDist = dot(diff * diff, vec3(0.299, 0.587, 0.114));
|
||||
float colorWeight = exp(colorDist * invColor2);
|
||||
|
||||
float w = spatialWeight * colorWeight;
|
||||
sumRGB += sampleRGB * w;
|
||||
sumWeight += w;
|
||||
}
|
||||
}
|
||||
|
||||
vec3 resultRGB = sumRGB / max(sumWeight, EPSILON);
|
||||
return vec4(resultRGB, center.a); // preserve center alpha
|
||||
}
|
||||
|
||||
void main() {
|
||||
vec2 texelSize = 1.0 / vec2(textureSize(u_image0, 0));
|
||||
|
||||
float radiusF = clamp(u_float0, 0.0, float(MAX_RADIUS));
|
||||
int radius = int(radiusF + 0.5);
|
||||
|
||||
if (radius == 0) {
|
||||
fragColor = texture(u_image0, v_texCoord);
|
||||
return;
|
||||
}
|
||||
|
||||
// Edge threshold → color sigma
|
||||
// Squared curve for better low-end control
|
||||
float t = clamp(u_float1, 0.0, 100.0) / 100.0;
|
||||
t *= t;
|
||||
float sigmaColor = mix(0.01, 0.5, t);
|
||||
|
||||
// Spatial sigma tied to radius
|
||||
float sigmaSpatial = max(radiusF * 0.75, 0.5);
|
||||
|
||||
fragColor = bilateralFilter(
|
||||
v_texCoord,
|
||||
texelSize,
|
||||
radius,
|
||||
sigmaSpatial,
|
||||
sigmaColor
|
||||
);
|
||||
}
|
||||
124
blueprints/.glsl/Film_Grain_15.frag
Normal file
124
blueprints/.glsl/Film_Grain_15.frag
Normal file
@@ -0,0 +1,124 @@
|
||||
#version 300 es
|
||||
precision highp float;
|
||||
|
||||
uniform sampler2D u_image0;
|
||||
uniform vec2 u_resolution;
|
||||
uniform float u_float0; // grain amount [0.0 – 1.0] typical: 0.2–0.8
|
||||
uniform float u_float1; // grain size [0.3 – 3.0] lower = finer grain
|
||||
uniform float u_float2; // color amount [0.0 – 1.0] 0 = monochrome, 1 = RGB grain
|
||||
uniform float u_float3; // luminance bias [0.0 – 1.0] 0 = uniform, 1 = shadows only
|
||||
uniform int u_int0; // noise mode [0 or 1] 0 = smooth, 1 = grainy
|
||||
|
||||
in vec2 v_texCoord;
|
||||
layout(location = 0) out vec4 fragColor0;
|
||||
|
||||
// High-quality integer hash (pcg-like)
|
||||
uint pcg(uint v) {
|
||||
uint state = v * 747796405u + 2891336453u;
|
||||
uint word = ((state >> ((state >> 28u) + 4u)) ^ state) * 277803737u;
|
||||
return (word >> 22u) ^ word;
|
||||
}
|
||||
|
||||
// 2D -> 1D hash input
|
||||
uint hash2d(uvec2 p) {
|
||||
return pcg(p.x + pcg(p.y));
|
||||
}
|
||||
|
||||
// Hash to float [0, 1]
|
||||
float hashf(uvec2 p) {
|
||||
return float(hash2d(p)) / float(0xffffffffu);
|
||||
}
|
||||
|
||||
// Hash to float with offset (for RGB channels)
|
||||
float hashf(uvec2 p, uint offset) {
|
||||
return float(pcg(hash2d(p) + offset)) / float(0xffffffffu);
|
||||
}
|
||||
|
||||
// Convert uniform [0,1] to roughly Gaussian distribution
|
||||
// Using simple approximation: average of multiple samples
|
||||
float toGaussian(uvec2 p) {
|
||||
float sum = hashf(p, 0u) + hashf(p, 1u) + hashf(p, 2u) + hashf(p, 3u);
|
||||
return (sum - 2.0) * 0.7; // Centered, scaled
|
||||
}
|
||||
|
||||
float toGaussian(uvec2 p, uint offset) {
|
||||
float sum = hashf(p, offset) + hashf(p, offset + 1u)
|
||||
+ hashf(p, offset + 2u) + hashf(p, offset + 3u);
|
||||
return (sum - 2.0) * 0.7;
|
||||
}
|
||||
|
||||
// Smooth noise with better interpolation
|
||||
float smoothNoise(vec2 p) {
|
||||
vec2 i = floor(p);
|
||||
vec2 f = fract(p);
|
||||
|
||||
// Quintic interpolation (less banding than cubic)
|
||||
f = f * f * f * (f * (f * 6.0 - 15.0) + 10.0);
|
||||
|
||||
uvec2 ui = uvec2(i);
|
||||
float a = toGaussian(ui);
|
||||
float b = toGaussian(ui + uvec2(1u, 0u));
|
||||
float c = toGaussian(ui + uvec2(0u, 1u));
|
||||
float d = toGaussian(ui + uvec2(1u, 1u));
|
||||
|
||||
return mix(mix(a, b, f.x), mix(c, d, f.x), f.y);
|
||||
}
|
||||
|
||||
float smoothNoise(vec2 p, uint offset) {
|
||||
vec2 i = floor(p);
|
||||
vec2 f = fract(p);
|
||||
|
||||
f = f * f * f * (f * (f * 6.0 - 15.0) + 10.0);
|
||||
|
||||
uvec2 ui = uvec2(i);
|
||||
float a = toGaussian(ui, offset);
|
||||
float b = toGaussian(ui + uvec2(1u, 0u), offset);
|
||||
float c = toGaussian(ui + uvec2(0u, 1u), offset);
|
||||
float d = toGaussian(ui + uvec2(1u, 1u), offset);
|
||||
|
||||
return mix(mix(a, b, f.x), mix(c, d, f.x), f.y);
|
||||
}
|
||||
|
||||
void main() {
|
||||
vec4 color = texture(u_image0, v_texCoord);
|
||||
|
||||
// Luminance (Rec.709)
|
||||
float luma = dot(color.rgb, vec3(0.2126, 0.7152, 0.0722));
|
||||
|
||||
// Grain UV (resolution-independent)
|
||||
vec2 grainUV = v_texCoord * u_resolution / max(u_float1, 0.01);
|
||||
uvec2 grainPixel = uvec2(grainUV);
|
||||
|
||||
float g;
|
||||
vec3 grainRGB;
|
||||
|
||||
if (u_int0 == 1) {
|
||||
// Grainy mode: pure hash noise (no interpolation = no banding)
|
||||
g = toGaussian(grainPixel);
|
||||
grainRGB = vec3(
|
||||
toGaussian(grainPixel, 100u),
|
||||
toGaussian(grainPixel, 200u),
|
||||
toGaussian(grainPixel, 300u)
|
||||
);
|
||||
} else {
|
||||
// Smooth mode: interpolated with quintic curve
|
||||
g = smoothNoise(grainUV);
|
||||
grainRGB = vec3(
|
||||
smoothNoise(grainUV, 100u),
|
||||
smoothNoise(grainUV, 200u),
|
||||
smoothNoise(grainUV, 300u)
|
||||
);
|
||||
}
|
||||
|
||||
// Luminance weighting (less grain in highlights)
|
||||
float lumWeight = mix(1.0, 1.0 - luma, clamp(u_float3, 0.0, 1.0));
|
||||
|
||||
// Strength
|
||||
float strength = u_float0 * 0.15;
|
||||
|
||||
// Color vs monochrome grain
|
||||
vec3 grainColor = mix(vec3(g), grainRGB, clamp(u_float2, 0.0, 1.0));
|
||||
|
||||
color.rgb += grainColor * strength * lumWeight;
|
||||
fragColor0 = vec4(clamp(color.rgb, 0.0, 1.0), color.a);
|
||||
}
|
||||
133
blueprints/.glsl/Glow_30.frag
Normal file
133
blueprints/.glsl/Glow_30.frag
Normal file
@@ -0,0 +1,133 @@
|
||||
#version 300 es
|
||||
precision mediump float;
|
||||
|
||||
uniform sampler2D u_image0;
|
||||
uniform vec2 u_resolution;
|
||||
uniform int u_int0; // Blend mode
|
||||
uniform int u_int1; // Color tint
|
||||
uniform float u_float0; // Intensity
|
||||
uniform float u_float1; // Radius
|
||||
uniform float u_float2; // Threshold
|
||||
|
||||
in vec2 v_texCoord;
|
||||
out vec4 fragColor;
|
||||
|
||||
const int BLEND_ADD = 0;
|
||||
const int BLEND_SCREEN = 1;
|
||||
const int BLEND_SOFT = 2;
|
||||
const int BLEND_OVERLAY = 3;
|
||||
const int BLEND_LIGHTEN = 4;
|
||||
|
||||
const float GOLDEN_ANGLE = 2.39996323;
|
||||
const int MAX_SAMPLES = 48;
|
||||
const vec3 LUMA = vec3(0.299, 0.587, 0.114);
|
||||
|
||||
float hash(vec2 p) {
|
||||
p = fract(p * vec2(123.34, 456.21));
|
||||
p += dot(p, p + 45.32);
|
||||
return fract(p.x * p.y);
|
||||
}
|
||||
|
||||
vec3 hexToRgb(int h) {
|
||||
return vec3(
|
||||
float((h >> 16) & 255),
|
||||
float((h >> 8) & 255),
|
||||
float(h & 255)
|
||||
) * (1.0 / 255.0);
|
||||
}
|
||||
|
||||
vec3 blend(vec3 base, vec3 glow, int mode) {
|
||||
if (mode == BLEND_SCREEN) {
|
||||
return 1.0 - (1.0 - base) * (1.0 - glow);
|
||||
}
|
||||
if (mode == BLEND_SOFT) {
|
||||
return mix(
|
||||
base - (1.0 - 2.0 * glow) * base * (1.0 - base),
|
||||
base + (2.0 * glow - 1.0) * (sqrt(base) - base),
|
||||
step(0.5, glow)
|
||||
);
|
||||
}
|
||||
if (mode == BLEND_OVERLAY) {
|
||||
return mix(
|
||||
2.0 * base * glow,
|
||||
1.0 - 2.0 * (1.0 - base) * (1.0 - glow),
|
||||
step(0.5, base)
|
||||
);
|
||||
}
|
||||
if (mode == BLEND_LIGHTEN) {
|
||||
return max(base, glow);
|
||||
}
|
||||
return base + glow;
|
||||
}
|
||||
|
||||
void main() {
|
||||
vec4 original = texture(u_image0, v_texCoord);
|
||||
|
||||
float intensity = u_float0 * 0.05;
|
||||
float radius = u_float1 * u_float1 * 0.012;
|
||||
|
||||
if (intensity < 0.001 || radius < 0.1) {
|
||||
fragColor = original;
|
||||
return;
|
||||
}
|
||||
|
||||
float threshold = 1.0 - u_float2 * 0.01;
|
||||
float t0 = threshold - 0.15;
|
||||
float t1 = threshold + 0.15;
|
||||
|
||||
vec2 texelSize = 1.0 / u_resolution;
|
||||
float radius2 = radius * radius;
|
||||
|
||||
float sampleScale = clamp(radius * 0.75, 0.35, 1.0);
|
||||
int samples = int(float(MAX_SAMPLES) * sampleScale);
|
||||
|
||||
float noise = hash(gl_FragCoord.xy);
|
||||
float angleOffset = noise * GOLDEN_ANGLE;
|
||||
float radiusJitter = 0.85 + noise * 0.3;
|
||||
|
||||
float ca = cos(GOLDEN_ANGLE);
|
||||
float sa = sin(GOLDEN_ANGLE);
|
||||
vec2 dir = vec2(cos(angleOffset), sin(angleOffset));
|
||||
|
||||
vec3 glow = vec3(0.0);
|
||||
float totalWeight = 0.0;
|
||||
|
||||
// Center tap
|
||||
float centerMask = smoothstep(t0, t1, dot(original.rgb, LUMA));
|
||||
glow += original.rgb * centerMask * 2.0;
|
||||
totalWeight += 2.0;
|
||||
|
||||
for (int i = 1; i < MAX_SAMPLES; i++) {
|
||||
if (i >= samples) break;
|
||||
|
||||
float fi = float(i);
|
||||
float dist = sqrt(fi / float(samples)) * radius * radiusJitter;
|
||||
|
||||
vec2 offset = dir * dist * texelSize;
|
||||
vec3 c = texture(u_image0, v_texCoord + offset).rgb;
|
||||
float mask = smoothstep(t0, t1, dot(c, LUMA));
|
||||
|
||||
float w = 1.0 - (dist * dist) / (radius2 * 1.5);
|
||||
w = max(w, 0.0);
|
||||
w *= w;
|
||||
|
||||
glow += c * mask * w;
|
||||
totalWeight += w;
|
||||
|
||||
dir = vec2(
|
||||
dir.x * ca - dir.y * sa,
|
||||
dir.x * sa + dir.y * ca
|
||||
);
|
||||
}
|
||||
|
||||
glow *= intensity / max(totalWeight, 0.001);
|
||||
|
||||
if (u_int1 > 0) {
|
||||
glow *= hexToRgb(u_int1);
|
||||
}
|
||||
|
||||
vec3 result = blend(original.rgb, glow, u_int0);
|
||||
result += (noise - 0.5) * (1.0 / 255.0);
|
||||
|
||||
fragColor = vec4(clamp(result, 0.0, 1.0), original.a);
|
||||
}
|
||||
222
blueprints/.glsl/Hue_and_Saturation_1.frag
Normal file
222
blueprints/.glsl/Hue_and_Saturation_1.frag
Normal file
@@ -0,0 +1,222 @@
|
||||
#version 300 es
|
||||
precision highp float;
|
||||
|
||||
uniform sampler2D u_image0;
|
||||
uniform int u_int0; // Mode: 0=Master, 1=Reds, 2=Yellows, 3=Greens, 4=Cyans, 5=Blues, 6=Magentas, 7=Colorize
|
||||
uniform int u_int1; // Color Space: 0=HSL, 1=HSB/HSV
|
||||
uniform float u_float0; // Hue (-180 to 180)
|
||||
uniform float u_float1; // Saturation (-100 to 100)
|
||||
uniform float u_float2; // Lightness/Brightness (-100 to 100)
|
||||
uniform float u_float3; // Overlap (0 to 100) - feathering between adjacent color ranges
|
||||
|
||||
in vec2 v_texCoord;
|
||||
out vec4 fragColor;
|
||||
|
||||
// Color range modes
|
||||
const int MODE_MASTER = 0;
|
||||
const int MODE_RED = 1;
|
||||
const int MODE_YELLOW = 2;
|
||||
const int MODE_GREEN = 3;
|
||||
const int MODE_CYAN = 4;
|
||||
const int MODE_BLUE = 5;
|
||||
const int MODE_MAGENTA = 6;
|
||||
const int MODE_COLORIZE = 7;
|
||||
|
||||
// Color space modes
|
||||
const int COLORSPACE_HSL = 0;
|
||||
const int COLORSPACE_HSB = 1;
|
||||
|
||||
const float EPSILON = 0.0001;
|
||||
|
||||
//=============================================================================
|
||||
// RGB <-> HSL Conversions
|
||||
//=============================================================================
|
||||
|
||||
vec3 rgb2hsl(vec3 c) {
|
||||
float maxC = max(max(c.r, c.g), c.b);
|
||||
float minC = min(min(c.r, c.g), c.b);
|
||||
float delta = maxC - minC;
|
||||
|
||||
float h = 0.0;
|
||||
float s = 0.0;
|
||||
float l = (maxC + minC) * 0.5;
|
||||
|
||||
if (delta > EPSILON) {
|
||||
s = l < 0.5
|
||||
? delta / (maxC + minC)
|
||||
: delta / (2.0 - maxC - minC);
|
||||
|
||||
if (maxC == c.r) {
|
||||
h = (c.g - c.b) / delta + (c.g < c.b ? 6.0 : 0.0);
|
||||
} else if (maxC == c.g) {
|
||||
h = (c.b - c.r) / delta + 2.0;
|
||||
} else {
|
||||
h = (c.r - c.g) / delta + 4.0;
|
||||
}
|
||||
h /= 6.0;
|
||||
}
|
||||
|
||||
return vec3(h, s, l);
|
||||
}
|
||||
|
||||
float hue2rgb(float p, float q, float t) {
|
||||
t = fract(t);
|
||||
if (t < 1.0/6.0) return p + (q - p) * 6.0 * t;
|
||||
if (t < 0.5) return q;
|
||||
if (t < 2.0/3.0) return p + (q - p) * (2.0/3.0 - t) * 6.0;
|
||||
return p;
|
||||
}
|
||||
|
||||
vec3 hsl2rgb(vec3 hsl) {
|
||||
if (hsl.y < EPSILON) return vec3(hsl.z);
|
||||
|
||||
float q = hsl.z < 0.5
|
||||
? hsl.z * (1.0 + hsl.y)
|
||||
: hsl.z + hsl.y - hsl.z * hsl.y;
|
||||
float p = 2.0 * hsl.z - q;
|
||||
|
||||
return vec3(
|
||||
hue2rgb(p, q, hsl.x + 1.0/3.0),
|
||||
hue2rgb(p, q, hsl.x),
|
||||
hue2rgb(p, q, hsl.x - 1.0/3.0)
|
||||
);
|
||||
}
|
||||
|
||||
vec3 rgb2hsb(vec3 c) {
|
||||
float maxC = max(max(c.r, c.g), c.b);
|
||||
float minC = min(min(c.r, c.g), c.b);
|
||||
float delta = maxC - minC;
|
||||
|
||||
float h = 0.0;
|
||||
float s = (maxC > EPSILON) ? delta / maxC : 0.0;
|
||||
float b = maxC;
|
||||
|
||||
if (delta > EPSILON) {
|
||||
if (maxC == c.r) {
|
||||
h = (c.g - c.b) / delta + (c.g < c.b ? 6.0 : 0.0);
|
||||
} else if (maxC == c.g) {
|
||||
h = (c.b - c.r) / delta + 2.0;
|
||||
} else {
|
||||
h = (c.r - c.g) / delta + 4.0;
|
||||
}
|
||||
h /= 6.0;
|
||||
}
|
||||
|
||||
return vec3(h, s, b);
|
||||
}
|
||||
|
||||
vec3 hsb2rgb(vec3 hsb) {
|
||||
vec3 rgb = clamp(abs(mod(hsb.x * 6.0 + vec3(0.0, 4.0, 2.0), 6.0) - 3.0) - 1.0, 0.0, 1.0);
|
||||
return hsb.z * mix(vec3(1.0), rgb, hsb.y);
|
||||
}
|
||||
|
||||
//=============================================================================
|
||||
// Color Range Weight Calculation
|
||||
//=============================================================================
|
||||
|
||||
float hueDistance(float a, float b) {
|
||||
float d = abs(a - b);
|
||||
return min(d, 1.0 - d);
|
||||
}
|
||||
|
||||
float getHueWeight(float hue, float center, float overlap) {
|
||||
float baseWidth = 1.0 / 6.0;
|
||||
float feather = baseWidth * overlap;
|
||||
|
||||
float d = hueDistance(hue, center);
|
||||
|
||||
float inner = baseWidth * 0.5;
|
||||
float outer = inner + feather;
|
||||
|
||||
return 1.0 - smoothstep(inner, outer, d);
|
||||
}
|
||||
|
||||
float getModeWeight(float hue, int mode, float overlap) {
|
||||
if (mode == MODE_MASTER || mode == MODE_COLORIZE) return 1.0;
|
||||
|
||||
if (mode == MODE_RED) {
|
||||
return max(
|
||||
getHueWeight(hue, 0.0, overlap),
|
||||
getHueWeight(hue, 1.0, overlap)
|
||||
);
|
||||
}
|
||||
|
||||
float center = float(mode - 1) / 6.0;
|
||||
return getHueWeight(hue, center, overlap);
|
||||
}
|
||||
|
||||
//=============================================================================
|
||||
// Adjustment Functions
|
||||
//=============================================================================
|
||||
|
||||
float adjustLightness(float l, float amount) {
|
||||
return amount > 0.0
|
||||
? l + (1.0 - l) * amount
|
||||
: l + l * amount;
|
||||
}
|
||||
|
||||
float adjustBrightness(float b, float amount) {
|
||||
return clamp(b + amount, 0.0, 1.0);
|
||||
}
|
||||
|
||||
float adjustSaturation(float s, float amount) {
|
||||
return amount > 0.0
|
||||
? s + (1.0 - s) * amount
|
||||
: s + s * amount;
|
||||
}
|
||||
|
||||
vec3 colorize(vec3 rgb, float hue, float sat, float light) {
|
||||
float lum = dot(rgb, vec3(0.299, 0.587, 0.114));
|
||||
float l = adjustLightness(lum, light);
|
||||
|
||||
vec3 hsl = vec3(fract(hue), clamp(sat, 0.0, 1.0), clamp(l, 0.0, 1.0));
|
||||
return hsl2rgb(hsl);
|
||||
}
|
||||
|
||||
//=============================================================================
|
||||
// Main
|
||||
//=============================================================================
|
||||
|
||||
void main() {
|
||||
vec4 original = texture(u_image0, v_texCoord);
|
||||
|
||||
float hueShift = u_float0 / 360.0; // -180..180 -> -0.5..0.5
|
||||
float satAmount = u_float1 / 100.0; // -100..100 -> -1..1
|
||||
float lightAmount= u_float2 / 100.0; // -100..100 -> -1..1
|
||||
float overlap = u_float3 / 100.0; // 0..100 -> 0..1
|
||||
|
||||
vec3 result;
|
||||
|
||||
if (u_int0 == MODE_COLORIZE) {
|
||||
result = colorize(original.rgb, hueShift, satAmount, lightAmount);
|
||||
fragColor = vec4(result, original.a);
|
||||
return;
|
||||
}
|
||||
|
||||
vec3 hsx = (u_int1 == COLORSPACE_HSL)
|
||||
? rgb2hsl(original.rgb)
|
||||
: rgb2hsb(original.rgb);
|
||||
|
||||
float weight = getModeWeight(hsx.x, u_int0, overlap);
|
||||
|
||||
if (u_int0 != MODE_MASTER && hsx.y < EPSILON) {
|
||||
weight = 0.0;
|
||||
}
|
||||
|
||||
if (weight > EPSILON) {
|
||||
float h = fract(hsx.x + hueShift * weight);
|
||||
float s = clamp(adjustSaturation(hsx.y, satAmount * weight), 0.0, 1.0);
|
||||
float v = (u_int1 == COLORSPACE_HSL)
|
||||
? clamp(adjustLightness(hsx.z, lightAmount * weight), 0.0, 1.0)
|
||||
: clamp(adjustBrightness(hsx.z, lightAmount * weight), 0.0, 1.0);
|
||||
|
||||
vec3 adjusted = vec3(h, s, v);
|
||||
result = (u_int1 == COLORSPACE_HSL)
|
||||
? hsl2rgb(adjusted)
|
||||
: hsb2rgb(adjusted);
|
||||
} else {
|
||||
result = original.rgb;
|
||||
}
|
||||
|
||||
fragColor = vec4(result, original.a);
|
||||
}
|
||||
111
blueprints/.glsl/Image_Blur_1.frag
Normal file
111
blueprints/.glsl/Image_Blur_1.frag
Normal file
@@ -0,0 +1,111 @@
|
||||
#version 300 es
|
||||
#pragma passes 2
|
||||
precision highp float;
|
||||
|
||||
// Blur type constants
|
||||
const int BLUR_GAUSSIAN = 0;
|
||||
const int BLUR_BOX = 1;
|
||||
const int BLUR_RADIAL = 2;
|
||||
|
||||
// Radial blur config
|
||||
const int RADIAL_SAMPLES = 12;
|
||||
const float RADIAL_STRENGTH = 0.0003;
|
||||
|
||||
uniform sampler2D u_image0;
|
||||
uniform vec2 u_resolution;
|
||||
uniform int u_int0; // Blur type (BLUR_GAUSSIAN, BLUR_BOX, BLUR_RADIAL)
|
||||
uniform float u_float0; // Blur radius/amount
|
||||
uniform int u_pass; // Pass index (0 = horizontal, 1 = vertical)
|
||||
|
||||
in vec2 v_texCoord;
|
||||
layout(location = 0) out vec4 fragColor0;
|
||||
|
||||
float gaussian(float x, float sigma) {
|
||||
return exp(-(x * x) / (2.0 * sigma * sigma));
|
||||
}
|
||||
|
||||
void main() {
|
||||
vec2 texelSize = 1.0 / u_resolution;
|
||||
float radius = max(u_float0, 0.0);
|
||||
|
||||
// Radial (angular) blur - single pass, doesn't use separable
|
||||
if (u_int0 == BLUR_RADIAL) {
|
||||
// Only execute on first pass
|
||||
if (u_pass > 0) {
|
||||
fragColor0 = texture(u_image0, v_texCoord);
|
||||
return;
|
||||
}
|
||||
|
||||
vec2 center = vec2(0.5);
|
||||
vec2 dir = v_texCoord - center;
|
||||
float dist = length(dir);
|
||||
|
||||
if (dist < 1e-4) {
|
||||
fragColor0 = texture(u_image0, v_texCoord);
|
||||
return;
|
||||
}
|
||||
|
||||
vec4 sum = vec4(0.0);
|
||||
float totalWeight = 0.0;
|
||||
float angleStep = radius * RADIAL_STRENGTH;
|
||||
|
||||
dir /= dist;
|
||||
|
||||
float cosStep = cos(angleStep);
|
||||
float sinStep = sin(angleStep);
|
||||
|
||||
float negAngle = -float(RADIAL_SAMPLES) * angleStep;
|
||||
vec2 rotDir = vec2(
|
||||
dir.x * cos(negAngle) - dir.y * sin(negAngle),
|
||||
dir.x * sin(negAngle) + dir.y * cos(negAngle)
|
||||
);
|
||||
|
||||
for (int i = -RADIAL_SAMPLES; i <= RADIAL_SAMPLES; i++) {
|
||||
vec2 uv = center + rotDir * dist;
|
||||
float w = 1.0 - abs(float(i)) / float(RADIAL_SAMPLES);
|
||||
sum += texture(u_image0, uv) * w;
|
||||
totalWeight += w;
|
||||
|
||||
rotDir = vec2(
|
||||
rotDir.x * cosStep - rotDir.y * sinStep,
|
||||
rotDir.x * sinStep + rotDir.y * cosStep
|
||||
);
|
||||
}
|
||||
|
||||
fragColor0 = sum / max(totalWeight, 0.001);
|
||||
return;
|
||||
}
|
||||
|
||||
// Separable Gaussian / Box blur
|
||||
int samples = int(ceil(radius));
|
||||
|
||||
if (samples == 0) {
|
||||
fragColor0 = texture(u_image0, v_texCoord);
|
||||
return;
|
||||
}
|
||||
|
||||
// Direction: pass 0 = horizontal, pass 1 = vertical
|
||||
vec2 dir = (u_pass == 0) ? vec2(1.0, 0.0) : vec2(0.0, 1.0);
|
||||
|
||||
vec4 color = vec4(0.0);
|
||||
float totalWeight = 0.0;
|
||||
float sigma = radius / 2.0;
|
||||
|
||||
for (int i = -samples; i <= samples; i++) {
|
||||
vec2 offset = dir * float(i) * texelSize;
|
||||
vec4 sample_color = texture(u_image0, v_texCoord + offset);
|
||||
|
||||
float weight;
|
||||
if (u_int0 == BLUR_GAUSSIAN) {
|
||||
weight = gaussian(float(i), sigma);
|
||||
} else {
|
||||
// BLUR_BOX
|
||||
weight = 1.0;
|
||||
}
|
||||
|
||||
color += sample_color * weight;
|
||||
totalWeight += weight;
|
||||
}
|
||||
|
||||
fragColor0 = color / totalWeight;
|
||||
}
|
||||
19
blueprints/.glsl/Image_Channels_23.frag
Normal file
19
blueprints/.glsl/Image_Channels_23.frag
Normal file
@@ -0,0 +1,19 @@
|
||||
#version 300 es
|
||||
precision highp float;
|
||||
|
||||
uniform sampler2D u_image0;
|
||||
|
||||
in vec2 v_texCoord;
|
||||
layout(location = 0) out vec4 fragColor0;
|
||||
layout(location = 1) out vec4 fragColor1;
|
||||
layout(location = 2) out vec4 fragColor2;
|
||||
layout(location = 3) out vec4 fragColor3;
|
||||
|
||||
void main() {
|
||||
vec4 color = texture(u_image0, v_texCoord);
|
||||
// Output each channel as grayscale to separate render targets
|
||||
fragColor0 = vec4(vec3(color.r), 1.0); // Red channel
|
||||
fragColor1 = vec4(vec3(color.g), 1.0); // Green channel
|
||||
fragColor2 = vec4(vec3(color.b), 1.0); // Blue channel
|
||||
fragColor3 = vec4(vec3(color.a), 1.0); // Alpha channel
|
||||
}
|
||||
71
blueprints/.glsl/Image_Levels_1.frag
Normal file
71
blueprints/.glsl/Image_Levels_1.frag
Normal file
@@ -0,0 +1,71 @@
|
||||
#version 300 es
|
||||
precision highp float;
|
||||
|
||||
// Levels Adjustment
|
||||
// u_int0: channel (0=RGB, 1=R, 2=G, 3=B) default: 0
|
||||
// u_float0: input black (0-255) default: 0
|
||||
// u_float1: input white (0-255) default: 255
|
||||
// u_float2: gamma (0.01-9.99) default: 1.0
|
||||
// u_float3: output black (0-255) default: 0
|
||||
// u_float4: output white (0-255) default: 255
|
||||
|
||||
uniform sampler2D u_image0;
|
||||
uniform int u_int0;
|
||||
uniform float u_float0;
|
||||
uniform float u_float1;
|
||||
uniform float u_float2;
|
||||
uniform float u_float3;
|
||||
uniform float u_float4;
|
||||
|
||||
in vec2 v_texCoord;
|
||||
out vec4 fragColor;
|
||||
|
||||
vec3 applyLevels(vec3 color, float inBlack, float inWhite, float gamma, float outBlack, float outWhite) {
|
||||
float inRange = max(inWhite - inBlack, 0.0001);
|
||||
vec3 result = clamp((color - inBlack) / inRange, 0.0, 1.0);
|
||||
result = pow(result, vec3(1.0 / gamma));
|
||||
result = mix(vec3(outBlack), vec3(outWhite), result);
|
||||
return result;
|
||||
}
|
||||
|
||||
float applySingleChannel(float value, float inBlack, float inWhite, float gamma, float outBlack, float outWhite) {
|
||||
float inRange = max(inWhite - inBlack, 0.0001);
|
||||
float result = clamp((value - inBlack) / inRange, 0.0, 1.0);
|
||||
result = pow(result, 1.0 / gamma);
|
||||
result = mix(outBlack, outWhite, result);
|
||||
return result;
|
||||
}
|
||||
|
||||
void main() {
|
||||
vec4 texColor = texture(u_image0, v_texCoord);
|
||||
vec3 color = texColor.rgb;
|
||||
|
||||
float inBlack = u_float0 / 255.0;
|
||||
float inWhite = u_float1 / 255.0;
|
||||
float gamma = u_float2;
|
||||
float outBlack = u_float3 / 255.0;
|
||||
float outWhite = u_float4 / 255.0;
|
||||
|
||||
vec3 result;
|
||||
|
||||
if (u_int0 == 0) {
|
||||
result = applyLevels(color, inBlack, inWhite, gamma, outBlack, outWhite);
|
||||
}
|
||||
else if (u_int0 == 1) {
|
||||
result = color;
|
||||
result.r = applySingleChannel(color.r, inBlack, inWhite, gamma, outBlack, outWhite);
|
||||
}
|
||||
else if (u_int0 == 2) {
|
||||
result = color;
|
||||
result.g = applySingleChannel(color.g, inBlack, inWhite, gamma, outBlack, outWhite);
|
||||
}
|
||||
else if (u_int0 == 3) {
|
||||
result = color;
|
||||
result.b = applySingleChannel(color.b, inBlack, inWhite, gamma, outBlack, outWhite);
|
||||
}
|
||||
else {
|
||||
result = color;
|
||||
}
|
||||
|
||||
fragColor = vec4(result, texColor.a);
|
||||
}
|
||||
28
blueprints/.glsl/README.md
Normal file
28
blueprints/.glsl/README.md
Normal file
@@ -0,0 +1,28 @@
|
||||
# GLSL Shader Sources
|
||||
|
||||
This folder contains the GLSL fragment shaders extracted from blueprint JSON files for easier editing and version control.
|
||||
|
||||
## File Naming Convention
|
||||
|
||||
`{Blueprint_Name}_{node_id}.frag`
|
||||
|
||||
- **Blueprint_Name**: The JSON filename with spaces/special chars replaced by underscores
|
||||
- **node_id**: The GLSLShader node ID within the subgraph
|
||||
|
||||
## Usage
|
||||
|
||||
```bash
|
||||
# Extract shaders from blueprint JSONs to this folder
|
||||
python update_blueprints.py extract
|
||||
|
||||
# Patch edited shaders back into blueprint JSONs
|
||||
python update_blueprints.py patch
|
||||
```
|
||||
|
||||
## Workflow
|
||||
|
||||
1. Run `extract` to pull current shaders from JSONs
|
||||
2. Edit `.frag` files
|
||||
3. Run `patch` to update the blueprint JSONs
|
||||
4. Test
|
||||
5. Commit both `.frag` files and updated JSONs
|
||||
28
blueprints/.glsl/Sharpen_23.frag
Normal file
28
blueprints/.glsl/Sharpen_23.frag
Normal file
@@ -0,0 +1,28 @@
|
||||
#version 300 es
|
||||
precision highp float;
|
||||
|
||||
uniform sampler2D u_image0;
|
||||
uniform vec2 u_resolution;
|
||||
uniform float u_float0; // strength [0.0 – 2.0] typical: 0.3–1.0
|
||||
|
||||
in vec2 v_texCoord;
|
||||
layout(location = 0) out vec4 fragColor0;
|
||||
|
||||
void main() {
|
||||
vec2 texel = 1.0 / u_resolution;
|
||||
|
||||
// Sample center and neighbors
|
||||
vec4 center = texture(u_image0, v_texCoord);
|
||||
vec4 top = texture(u_image0, v_texCoord + vec2( 0.0, -texel.y));
|
||||
vec4 bottom = texture(u_image0, v_texCoord + vec2( 0.0, texel.y));
|
||||
vec4 left = texture(u_image0, v_texCoord + vec2(-texel.x, 0.0));
|
||||
vec4 right = texture(u_image0, v_texCoord + vec2( texel.x, 0.0));
|
||||
|
||||
// Edge enhancement (Laplacian)
|
||||
vec4 edges = center * 4.0 - top - bottom - left - right;
|
||||
|
||||
// Add edges back scaled by strength
|
||||
vec4 sharpened = center + edges * u_float0;
|
||||
|
||||
fragColor0 = vec4(clamp(sharpened.rgb, 0.0, 1.0), center.a);
|
||||
}
|
||||
61
blueprints/.glsl/Unsharp_Mask_26.frag
Normal file
61
blueprints/.glsl/Unsharp_Mask_26.frag
Normal file
@@ -0,0 +1,61 @@
|
||||
#version 300 es
|
||||
precision highp float;
|
||||
|
||||
uniform sampler2D u_image0;
|
||||
uniform vec2 u_resolution;
|
||||
uniform float u_float0; // amount [0.0 - 3.0] typical: 0.5-1.5
|
||||
uniform float u_float1; // radius [0.5 - 10.0] blur radius in pixels
|
||||
uniform float u_float2; // threshold [0.0 - 0.1] min difference to sharpen
|
||||
|
||||
in vec2 v_texCoord;
|
||||
layout(location = 0) out vec4 fragColor0;
|
||||
|
||||
float gaussian(float x, float sigma) {
|
||||
return exp(-(x * x) / (2.0 * sigma * sigma));
|
||||
}
|
||||
|
||||
float getLuminance(vec3 color) {
|
||||
return dot(color, vec3(0.2126, 0.7152, 0.0722));
|
||||
}
|
||||
|
||||
void main() {
|
||||
vec2 texel = 1.0 / u_resolution;
|
||||
float radius = max(u_float1, 0.5);
|
||||
float amount = u_float0;
|
||||
float threshold = u_float2;
|
||||
|
||||
vec4 original = texture(u_image0, v_texCoord);
|
||||
|
||||
// Gaussian blur for the "unsharp" mask
|
||||
int samples = int(ceil(radius));
|
||||
float sigma = radius / 2.0;
|
||||
|
||||
vec4 blurred = vec4(0.0);
|
||||
float totalWeight = 0.0;
|
||||
|
||||
for (int x = -samples; x <= samples; x++) {
|
||||
for (int y = -samples; y <= samples; y++) {
|
||||
vec2 offset = vec2(float(x), float(y)) * texel;
|
||||
vec4 sample_color = texture(u_image0, v_texCoord + offset);
|
||||
|
||||
float dist = length(vec2(float(x), float(y)));
|
||||
float weight = gaussian(dist, sigma);
|
||||
blurred += sample_color * weight;
|
||||
totalWeight += weight;
|
||||
}
|
||||
}
|
||||
blurred /= totalWeight;
|
||||
|
||||
// Unsharp mask = original - blurred
|
||||
vec3 mask = original.rgb - blurred.rgb;
|
||||
|
||||
// Luminance-based threshold with smooth falloff
|
||||
float lumaDelta = abs(getLuminance(original.rgb) - getLuminance(blurred.rgb));
|
||||
float thresholdScale = smoothstep(0.0, threshold, lumaDelta);
|
||||
mask *= thresholdScale;
|
||||
|
||||
// Sharpen: original + mask * amount
|
||||
vec3 sharpened = original.rgb + mask * amount;
|
||||
|
||||
fragColor0 = vec4(clamp(sharpened, 0.0, 1.0), original.a);
|
||||
}
|
||||
159
blueprints/.glsl/update_blueprints.py
Normal file
159
blueprints/.glsl/update_blueprints.py
Normal file
@@ -0,0 +1,159 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Shader Blueprint Updater
|
||||
|
||||
Syncs GLSL shader files between this folder and blueprint JSON files.
|
||||
|
||||
File naming convention:
|
||||
{Blueprint Name}_{node_id}.frag
|
||||
|
||||
Usage:
|
||||
python update_blueprints.py extract # Extract shaders from JSONs to here
|
||||
python update_blueprints.py patch # Patch shaders back into JSONs
|
||||
python update_blueprints.py # Same as patch (default)
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format='%(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
GLSL_DIR = Path(__file__).parent
|
||||
BLUEPRINTS_DIR = GLSL_DIR.parent
|
||||
|
||||
|
||||
def get_blueprint_files():
|
||||
"""Get all blueprint JSON files."""
|
||||
return sorted(BLUEPRINTS_DIR.glob("*.json"))
|
||||
|
||||
|
||||
def sanitize_filename(name):
|
||||
"""Convert blueprint name to safe filename."""
|
||||
return re.sub(r'[^\w\-]', '_', name)
|
||||
|
||||
|
||||
def extract_shaders():
|
||||
"""Extract all shaders from blueprint JSONs to this folder."""
|
||||
extracted = 0
|
||||
for json_path in get_blueprint_files():
|
||||
blueprint_name = json_path.stem
|
||||
|
||||
try:
|
||||
with open(json_path, 'r') as f:
|
||||
data = json.load(f)
|
||||
except (json.JSONDecodeError, IOError) as e:
|
||||
logger.warning("Skipping %s: %s", json_path.name, e)
|
||||
continue
|
||||
|
||||
# Find GLSLShader nodes in subgraphs
|
||||
for subgraph in data.get('definitions', {}).get('subgraphs', []):
|
||||
for node in subgraph.get('nodes', []):
|
||||
if node.get('type') == 'GLSLShader':
|
||||
node_id = node.get('id')
|
||||
widgets = node.get('widgets_values', [])
|
||||
|
||||
# Find shader code (first string that looks like GLSL)
|
||||
for widget in widgets:
|
||||
if isinstance(widget, str) and widget.startswith('#version'):
|
||||
safe_name = sanitize_filename(blueprint_name)
|
||||
frag_name = f"{safe_name}_{node_id}.frag"
|
||||
frag_path = GLSL_DIR / frag_name
|
||||
|
||||
with open(frag_path, 'w') as f:
|
||||
f.write(widget)
|
||||
|
||||
logger.info(" Extracted: %s", frag_name)
|
||||
extracted += 1
|
||||
break
|
||||
|
||||
logger.info("\nExtracted %d shader(s)", extracted)
|
||||
|
||||
|
||||
def patch_shaders():
|
||||
"""Patch shaders from this folder back into blueprint JSONs."""
|
||||
# Build lookup: blueprint_name -> [(node_id, shader_code), ...]
|
||||
shader_updates = {}
|
||||
|
||||
for frag_path in sorted(GLSL_DIR.glob("*.frag")):
|
||||
# Parse filename: {blueprint_name}_{node_id}.frag
|
||||
parts = frag_path.stem.rsplit('_', 1)
|
||||
if len(parts) != 2:
|
||||
logger.warning("Skipping %s: invalid filename format", frag_path.name)
|
||||
continue
|
||||
|
||||
blueprint_name, node_id_str = parts
|
||||
|
||||
try:
|
||||
node_id = int(node_id_str)
|
||||
except ValueError:
|
||||
logger.warning("Skipping %s: invalid node_id", frag_path.name)
|
||||
continue
|
||||
|
||||
with open(frag_path, 'r') as f:
|
||||
shader_code = f.read()
|
||||
|
||||
if blueprint_name not in shader_updates:
|
||||
shader_updates[blueprint_name] = []
|
||||
shader_updates[blueprint_name].append((node_id, shader_code))
|
||||
|
||||
# Apply updates to JSON files
|
||||
patched = 0
|
||||
for json_path in get_blueprint_files():
|
||||
blueprint_name = sanitize_filename(json_path.stem)
|
||||
|
||||
if blueprint_name not in shader_updates:
|
||||
continue
|
||||
|
||||
try:
|
||||
with open(json_path, 'r') as f:
|
||||
data = json.load(f)
|
||||
except (json.JSONDecodeError, IOError) as e:
|
||||
logger.error("Error reading %s: %s", json_path.name, e)
|
||||
continue
|
||||
|
||||
modified = False
|
||||
for node_id, shader_code in shader_updates[blueprint_name]:
|
||||
# Find the node and update
|
||||
for subgraph in data.get('definitions', {}).get('subgraphs', []):
|
||||
for node in subgraph.get('nodes', []):
|
||||
if node.get('id') == node_id and node.get('type') == 'GLSLShader':
|
||||
widgets = node.get('widgets_values', [])
|
||||
if len(widgets) > 0 and widgets[0] != shader_code:
|
||||
widgets[0] = shader_code
|
||||
modified = True
|
||||
logger.info(" Patched: %s (node %d)", json_path.name, node_id)
|
||||
patched += 1
|
||||
|
||||
if modified:
|
||||
with open(json_path, 'w') as f:
|
||||
json.dump(data, f)
|
||||
|
||||
if patched == 0:
|
||||
logger.info("No changes to apply.")
|
||||
else:
|
||||
logger.info("\nPatched %d shader(s)", patched)
|
||||
|
||||
|
||||
def main():
|
||||
if len(sys.argv) < 2:
|
||||
command = "patch"
|
||||
else:
|
||||
command = sys.argv[1].lower()
|
||||
|
||||
if command == "extract":
|
||||
logger.info("Extracting shaders from blueprints...")
|
||||
extract_shaders()
|
||||
elif command in ("patch", "update", "apply"):
|
||||
logger.info("Patching shaders into blueprints...")
|
||||
patch_shaders()
|
||||
else:
|
||||
logger.info(__doc__)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
1
blueprints/Brightness and Contrast.json
Normal file
1
blueprints/Brightness and Contrast.json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Canny to Image (Z-Image-Turbo).json
Normal file
1
blueprints/Canny to Image (Z-Image-Turbo).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Canny to Video (LTX 2.0).json
Normal file
1
blueprints/Canny to Video (LTX 2.0).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Chromatic Aberration.json
Normal file
1
blueprints/Chromatic Aberration.json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Color Adjustment.json
Normal file
1
blueprints/Color Adjustment.json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Depth to Image (Z-Image-Turbo).json
Normal file
1
blueprints/Depth to Image (Z-Image-Turbo).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Depth to Video (ltx 2.0).json
Normal file
1
blueprints/Depth to Video (ltx 2.0).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Edge-Preserving Blur.json
Normal file
1
blueprints/Edge-Preserving Blur.json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Film Grain.json
Normal file
1
blueprints/Film Grain.json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Glow.json
Normal file
1
blueprints/Glow.json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Hue and Saturation.json
Normal file
1
blueprints/Hue and Saturation.json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Image Blur.json
Normal file
1
blueprints/Image Blur.json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Image Captioning (gemini).json
Normal file
1
blueprints/Image Captioning (gemini).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Image Channels.json
Normal file
1
blueprints/Image Channels.json
Normal file
@@ -0,0 +1 @@
|
||||
{"revision": 0, "last_node_id": 29, "last_link_id": 0, "nodes": [{"id": 29, "type": "4c9d6ea4-b912-40e5-8766-6793a9758c53", "pos": [1970, -230], "size": [180, 86], "flags": {}, "order": 5, "mode": 0, "inputs": [{"label": "image", "localized_name": "images.image0", "name": "images.image0", "type": "IMAGE", "link": null}], "outputs": [{"label": "R", "localized_name": "IMAGE0", "name": "IMAGE0", "type": "IMAGE", "links": []}, {"label": "G", "localized_name": "IMAGE1", "name": "IMAGE1", "type": "IMAGE", "links": []}, {"label": "B", "localized_name": "IMAGE2", "name": "IMAGE2", "type": "IMAGE", "links": []}, {"label": "A", "localized_name": "IMAGE3", "name": "IMAGE3", "type": "IMAGE", "links": []}], "title": "Image Channels", "properties": {"proxyWidgets": []}, "widgets_values": []}], "links": [], "version": 0.4, "definitions": {"subgraphs": [{"id": "4c9d6ea4-b912-40e5-8766-6793a9758c53", "version": 1, "state": {"lastGroupId": 0, "lastNodeId": 28, "lastLinkId": 39, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "Image Channels", "inputNode": {"id": -10, "bounding": [1820, -185, 120, 60]}, "outputNode": {"id": -20, "bounding": [2460, -215, 120, 120]}, "inputs": [{"id": "3522932b-2d86-4a1f-a02a-cb29f3a9d7fe", "name": "images.image0", "type": "IMAGE", "linkIds": [39], "localized_name": "images.image0", "label": "image", "pos": [1920, -165]}], "outputs": [{"id": "605cb9c3-b065-4d9b-81d2-3ec331889b2b", "name": "IMAGE0", "type": "IMAGE", "linkIds": [26], "localized_name": "IMAGE0", "label": "R", "pos": [2480, -195]}, {"id": "fb44a77e-0522-43e9-9527-82e7465b3596", "name": "IMAGE1", "type": "IMAGE", "linkIds": [27], "localized_name": "IMAGE1", "label": "G", "pos": [2480, -175]}, {"id": "81460ee6-0131-402a-874f-6bf3001fc4ff", "name": "IMAGE2", "type": "IMAGE", "linkIds": [28], "localized_name": "IMAGE2", "label": "B", "pos": [2480, -155]}, {"id": "ae690246-80d4-4951-b1d9-9306d8a77417", "name": "IMAGE3", "type": "IMAGE", "linkIds": [29], "localized_name": "IMAGE3", "label": "A", "pos": [2480, -135]}], "widgets": [], "nodes": [{"id": 23, "type": "GLSLShader", "pos": [2000, -330], "size": [400, 172], "flags": {}, "order": 0, "mode": 0, "inputs": [{"label": "image", "localized_name": "images.image0", "name": "images.image0", "type": "IMAGE", "link": 39}, {"localized_name": "fragment_shader", "name": "fragment_shader", "type": "STRING", "widget": {"name": "fragment_shader"}, "link": null}, {"localized_name": "size_mode", "name": "size_mode", "type": "COMFY_DYNAMICCOMBO_V3", "widget": {"name": "size_mode"}, "link": null}, {"label": "image1", "localized_name": "images.image1", "name": "images.image1", "shape": 7, "type": "IMAGE", "link": null}], "outputs": [{"label": "R", "localized_name": "IMAGE0", "name": "IMAGE0", "type": "IMAGE", "links": [26]}, {"label": "G", "localized_name": "IMAGE1", "name": "IMAGE1", "type": "IMAGE", "links": [27]}, {"label": "B", "localized_name": "IMAGE2", "name": "IMAGE2", "type": "IMAGE", "links": [28]}, {"label": "A", "localized_name": "IMAGE3", "name": "IMAGE3", "type": "IMAGE", "links": [29]}], "properties": {"Node name for S&R": "GLSLShader"}, "widgets_values": ["#version 300 es\nprecision highp float;\n\nuniform sampler2D u_image0;\n\nin vec2 v_texCoord;\nlayout(location = 0) out vec4 fragColor0;\nlayout(location = 1) out vec4 fragColor1;\nlayout(location = 2) out vec4 fragColor2;\nlayout(location = 3) out vec4 fragColor3;\n\nvoid main() {\n vec4 color = texture(u_image0, v_texCoord);\n // Output each channel as grayscale to separate render targets\n fragColor0 = vec4(vec3(color.r), 1.0); // Red channel\n fragColor1 = vec4(vec3(color.g), 1.0); // Green channel\n fragColor2 = vec4(vec3(color.b), 1.0); // Blue channel\n fragColor3 = vec4(vec3(color.a), 1.0); // Alpha channel\n}\n", "from_input"]}], "groups": [], "links": [{"id": 39, "origin_id": -10, "origin_slot": 0, "target_id": 23, "target_slot": 0, "type": "IMAGE"}, {"id": 26, "origin_id": 23, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "IMAGE"}, {"id": 27, "origin_id": 23, "origin_slot": 1, "target_id": -20, "target_slot": 1, "type": "IMAGE"}, {"id": 28, "origin_id": 23, "origin_slot": 2, "target_id": -20, "target_slot": 2, "type": "IMAGE"}, {"id": 29, "origin_id": 23, "origin_slot": 3, "target_id": -20, "target_slot": 3, "type": "IMAGE"}], "extra": {"workflowRendererVersion": "LG"}, "category": "Image Tools/Color adjust"}]}}
|
||||
1
blueprints/Image Edit (Flux.2 Klein 4B).json
Normal file
1
blueprints/Image Edit (Flux.2 Klein 4B).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Image Edit (Qwen 2511).json
Normal file
1
blueprints/Image Edit (Qwen 2511).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Image Inpainting (Qwen-image).json
Normal file
1
blueprints/Image Inpainting (Qwen-image).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Image Levels.json
Normal file
1
blueprints/Image Levels.json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Image Outpainting (Qwen-Image).json
Normal file
1
blueprints/Image Outpainting (Qwen-Image).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Image Upscale(Z-image-Turbo).json
Normal file
1
blueprints/Image Upscale(Z-image-Turbo).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Image to Depth Map (Lotus).json
Normal file
1
blueprints/Image to Depth Map (Lotus).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Image to Layers(Qwen-Image Layered).json
Normal file
1
blueprints/Image to Layers(Qwen-Image Layered).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Image to Model (Hunyuan3d 2.1).json
Normal file
1
blueprints/Image to Model (Hunyuan3d 2.1).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Image to Video (Wan 2.2).json
Normal file
1
blueprints/Image to Video (Wan 2.2).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Pose to Image (Z-Image-Turbo).json
Normal file
1
blueprints/Pose to Image (Z-Image-Turbo).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Pose to Video (LTX 2.0).json
Normal file
1
blueprints/Pose to Video (LTX 2.0).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Prompt Enhance.json
Normal file
1
blueprints/Prompt Enhance.json
Normal file
@@ -0,0 +1 @@
|
||||
{"revision": 0, "last_node_id": 15, "last_link_id": 0, "nodes": [{"id": 15, "type": "24d8bbfd-39d4-4774-bff0-3de40cc7a471", "pos": [-1490, 2040], "size": [400, 260], "flags": {}, "order": 0, "mode": 0, "inputs": [{"name": "prompt", "type": "STRING", "widget": {"name": "prompt"}, "link": null}, {"label": "reference images", "name": "images", "type": "IMAGE", "link": null}], "outputs": [{"name": "STRING", "type": "STRING", "links": null}], "title": "Prompt Enhance", "properties": {"proxyWidgets": [["-1", "prompt"]], "cnr_id": "comfy-core", "ver": "0.14.1"}, "widgets_values": [""]}], "links": [], "version": 0.4, "definitions": {"subgraphs": [{"id": "24d8bbfd-39d4-4774-bff0-3de40cc7a471", "version": 1, "state": {"lastGroupId": 0, "lastNodeId": 15, "lastLinkId": 14, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "Prompt Enhance", "inputNode": {"id": -10, "bounding": [-2170, 2110, 138.876953125, 80]}, "outputNode": {"id": -20, "bounding": [-640, 2110, 120, 60]}, "inputs": [{"id": "aeab7216-00e0-4528-a09b-bba50845c5a6", "name": "prompt", "type": "STRING", "linkIds": [11], "pos": [-2051.123046875, 2130]}, {"id": "7b73fd36-aa31-4771-9066-f6c83879994b", "name": "images", "type": "IMAGE", "linkIds": [14], "label": "reference images", "pos": [-2051.123046875, 2150]}], "outputs": [{"id": "c7b0d930-68a1-48d1-b496-0519e5837064", "name": "STRING", "type": "STRING", "linkIds": [13], "pos": [-620, 2130]}], "widgets": [], "nodes": [{"id": 11, "type": "GeminiNode", "pos": [-1560, 1990], "size": [470, 470], "flags": {}, "order": 0, "mode": 0, "inputs": [{"localized_name": "images", "name": "images", "shape": 7, "type": "IMAGE", "link": 14}, {"localized_name": "audio", "name": "audio", "shape": 7, "type": "AUDIO", "link": null}, {"localized_name": "video", "name": "video", "shape": 7, "type": "VIDEO", "link": null}, {"localized_name": "files", "name": "files", "shape": 7, "type": "GEMINI_INPUT_FILES", "link": null}, {"localized_name": "prompt", "name": "prompt", "type": "STRING", "widget": {"name": "prompt"}, "link": 11}, {"localized_name": "model", "name": "model", "type": "COMBO", "widget": {"name": "model"}, "link": null}, {"localized_name": "seed", "name": "seed", "type": "INT", "widget": {"name": "seed"}, "link": null}, {"localized_name": "system_prompt", "name": "system_prompt", "shape": 7, "type": "STRING", "widget": {"name": "system_prompt"}, "link": null}], "outputs": [{"localized_name": "STRING", "name": "STRING", "type": "STRING", "links": [13]}], "properties": {"cnr_id": "comfy-core", "ver": "0.14.1", "Node name for S&R": "GeminiNode"}, "widgets_values": ["", "gemini-3-pro-preview", 42, "randomize", "You are an expert in prompt writing.\nBased on the input, rewrite the user's input into a detailed prompt.\nincluding camera settings, lighting, composition, and style.\nReturn the prompt only"], "color": "#432", "bgcolor": "#653"}], "groups": [], "links": [{"id": 11, "origin_id": -10, "origin_slot": 0, "target_id": 11, "target_slot": 4, "type": "STRING"}, {"id": 13, "origin_id": 11, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "STRING"}, {"id": 14, "origin_id": -10, "origin_slot": 1, "target_id": 11, "target_slot": 0, "type": "IMAGE"}], "extra": {"workflowRendererVersion": "LG"}, "category": "Text generation/Prompt enhance"}]}, "extra": {}}
|
||||
1
blueprints/Sharpen.json
Normal file
1
blueprints/Sharpen.json
Normal file
@@ -0,0 +1 @@
|
||||
{"revision": 0, "last_node_id": 25, "last_link_id": 0, "nodes": [{"id": 25, "type": "621ba4e2-22a8-482d-a369-023753198b7b", "pos": [4610, -790], "size": [230, 58], "flags": {}, "order": 4, "mode": 0, "inputs": [{"label": "image", "localized_name": "images.image0", "name": "images.image0", "type": "IMAGE", "link": null}], "outputs": [{"label": "IMAGE", "localized_name": "IMAGE0", "name": "IMAGE0", "type": "IMAGE", "links": []}], "title": "Sharpen", "properties": {"proxyWidgets": [["24", "value"]]}, "widgets_values": []}], "links": [], "version": 0.4, "definitions": {"subgraphs": [{"id": "621ba4e2-22a8-482d-a369-023753198b7b", "version": 1, "state": {"lastGroupId": 0, "lastNodeId": 24, "lastLinkId": 36, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "Sharpen", "inputNode": {"id": -10, "bounding": [4090, -825, 120, 60]}, "outputNode": {"id": -20, "bounding": [5150, -825, 120, 60]}, "inputs": [{"id": "37011fb7-14b7-4e0e-b1a0-6a02e8da1fd7", "name": "images.image0", "type": "IMAGE", "linkIds": [34], "localized_name": "images.image0", "label": "image", "pos": [4190, -805]}], "outputs": [{"id": "e9182b3f-635c-4cd4-a152-4b4be17ae4b9", "name": "IMAGE0", "type": "IMAGE", "linkIds": [35], "localized_name": "IMAGE0", "label": "IMAGE", "pos": [5170, -805]}], "widgets": [], "nodes": [{"id": 24, "type": "PrimitiveFloat", "pos": [4280, -1240], "size": [270, 58], "flags": {}, "order": 0, "mode": 0, "inputs": [{"label": "strength", "localized_name": "value", "name": "value", "type": "FLOAT", "widget": {"name": "value"}, "link": null}], "outputs": [{"localized_name": "FLOAT", "name": "FLOAT", "type": "FLOAT", "links": [36]}], "properties": {"Node name for S&R": "PrimitiveFloat", "min": 0, "max": 3, "precision": 2, "step": 0.05}, "widgets_values": [0.5]}, {"id": 23, "type": "GLSLShader", "pos": [4570, -1240], "size": [370, 192], "flags": {}, "order": 1, "mode": 0, "inputs": [{"label": "image0", "localized_name": "images.image0", "name": "images.image0", "type": "IMAGE", "link": 34}, {"label": "image1", "localized_name": "images.image1", "name": "images.image1", "shape": 7, "type": "IMAGE", "link": null}, {"label": "u_float0", "localized_name": "floats.u_float0", "name": "floats.u_float0", "shape": 7, "type": "FLOAT", "link": 36}, {"label": "u_float1", "localized_name": "floats.u_float1", "name": "floats.u_float1", "shape": 7, "type": "FLOAT", "link": null}, {"label": "u_int0", "localized_name": "ints.u_int0", "name": "ints.u_int0", "shape": 7, "type": "INT", "link": null}, {"localized_name": "fragment_shader", "name": "fragment_shader", "type": "STRING", "widget": {"name": "fragment_shader"}, "link": null}, {"localized_name": "size_mode", "name": "size_mode", "type": "COMFY_DYNAMICCOMBO_V3", "widget": {"name": "size_mode"}, "link": null}], "outputs": [{"localized_name": "IMAGE0", "name": "IMAGE0", "type": "IMAGE", "links": [35]}, {"localized_name": "IMAGE1", "name": "IMAGE1", "type": "IMAGE", "links": null}, {"localized_name": "IMAGE2", "name": "IMAGE2", "type": "IMAGE", "links": null}, {"localized_name": "IMAGE3", "name": "IMAGE3", "type": "IMAGE", "links": null}], "properties": {"Node name for S&R": "GLSLShader"}, "widgets_values": ["#version 300 es\nprecision highp float;\n\nuniform sampler2D u_image0;\nuniform vec2 u_resolution;\nuniform float u_float0; // strength [0.0 – 2.0] typical: 0.3–1.0\n\nin vec2 v_texCoord;\nlayout(location = 0) out vec4 fragColor0;\n\nvoid main() {\n vec2 texel = 1.0 / u_resolution;\n \n // Sample center and neighbors\n vec4 center = texture(u_image0, v_texCoord);\n vec4 top = texture(u_image0, v_texCoord + vec2( 0.0, -texel.y));\n vec4 bottom = texture(u_image0, v_texCoord + vec2( 0.0, texel.y));\n vec4 left = texture(u_image0, v_texCoord + vec2(-texel.x, 0.0));\n vec4 right = texture(u_image0, v_texCoord + vec2( texel.x, 0.0));\n \n // Edge enhancement (Laplacian)\n vec4 edges = center * 4.0 - top - bottom - left - right;\n \n // Add edges back scaled by strength\n vec4 sharpened = center + edges * u_float0;\n \n fragColor0 = vec4(clamp(sharpened.rgb, 0.0, 1.0), center.a);\n}", "from_input"]}], "groups": [], "links": [{"id": 36, "origin_id": 24, "origin_slot": 0, "target_id": 23, "target_slot": 2, "type": "FLOAT"}, {"id": 34, "origin_id": -10, "origin_slot": 0, "target_id": 23, "target_slot": 0, "type": "IMAGE"}, {"id": 35, "origin_id": 23, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "IMAGE"}], "extra": {"workflowRendererVersion": "LG"}, "category": "Image Tools/Sharpen"}]}}
|
||||
1
blueprints/Text to Audio (ACE-Step 1.5).json
Normal file
1
blueprints/Text to Audio (ACE-Step 1.5).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Text to Image (Z-Image-Turbo).json
Normal file
1
blueprints/Text to Image (Z-Image-Turbo).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Text to Video (Wan 2.2).json
Normal file
1
blueprints/Text to Video (Wan 2.2).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Unsharp Mask.json
Normal file
1
blueprints/Unsharp Mask.json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Video Captioning (Gemini).json
Normal file
1
blueprints/Video Captioning (Gemini).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Video Inpaint(Wan2.1 VACE).json
Normal file
1
blueprints/Video Inpaint(Wan2.1 VACE).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Video Stitch.json
Normal file
1
blueprints/Video Stitch.json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Video Upscale(GAN x4).json
Normal file
1
blueprints/Video Upscale(GAN x4).json
Normal file
@@ -0,0 +1 @@
|
||||
{"revision": 0, "last_node_id": 13, "last_link_id": 0, "nodes": [{"id": 13, "type": "cf95b747-3e17-46cb-8097-cac60ff9b2e1", "pos": [1120, 330], "size": [240, 58], "flags": {}, "order": 3, "mode": 0, "inputs": [{"localized_name": "video", "name": "video", "type": "VIDEO", "link": null}, {"name": "model_name", "type": "COMBO", "widget": {"name": "model_name"}, "link": null}], "outputs": [{"localized_name": "VIDEO", "name": "VIDEO", "type": "VIDEO", "links": []}], "title": "Video Upscale(GAN x4)", "properties": {"proxyWidgets": [["-1", "model_name"]], "cnr_id": "comfy-core", "ver": "0.14.1"}, "widgets_values": ["RealESRGAN_x4plus.safetensors"]}], "links": [], "version": 0.4, "definitions": {"subgraphs": [{"id": "cf95b747-3e17-46cb-8097-cac60ff9b2e1", "version": 1, "state": {"lastGroupId": 0, "lastNodeId": 13, "lastLinkId": 19, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "Video Upscale(GAN x4)", "inputNode": {"id": -10, "bounding": [550, 460, 120, 80]}, "outputNode": {"id": -20, "bounding": [1490, 460, 120, 60]}, "inputs": [{"id": "666d633e-93e7-42dc-8d11-2b7b99b0f2a6", "name": "video", "type": "VIDEO", "linkIds": [10], "localized_name": "video", "pos": [650, 480]}, {"id": "2e23a087-caa8-4d65-99e6-662761aa905a", "name": "model_name", "type": "COMBO", "linkIds": [19], "pos": [650, 500]}], "outputs": [{"id": "0c1768ea-3ec2-412f-9af6-8e0fa36dae70", "name": "VIDEO", "type": "VIDEO", "linkIds": [15], "localized_name": "VIDEO", "pos": [1510, 480]}], "widgets": [], "nodes": [{"id": 2, "type": "ImageUpscaleWithModel", "pos": [1110, 450], "size": [320, 46], "flags": {}, "order": 1, "mode": 0, "inputs": [{"localized_name": "upscale_model", "name": "upscale_model", "type": "UPSCALE_MODEL", "link": 1}, {"localized_name": "image", "name": "image", "type": "IMAGE", "link": 14}], "outputs": [{"localized_name": "IMAGE", "name": "IMAGE", "type": "IMAGE", "links": [13]}], "properties": {"cnr_id": "comfy-core", "ver": "0.10.0", "Node name for S&R": "ImageUpscaleWithModel"}}, {"id": 11, "type": "CreateVideo", "pos": [1110, 550], "size": [320, 78], "flags": {}, "order": 3, "mode": 0, "inputs": [{"localized_name": "images", "name": "images", "type": "IMAGE", "link": 13}, {"localized_name": "audio", "name": "audio", "shape": 7, "type": "AUDIO", "link": 16}, {"localized_name": "fps", "name": "fps", "type": "FLOAT", "widget": {"name": "fps"}, "link": 12}], "outputs": [{"localized_name": "VIDEO", "name": "VIDEO", "type": "VIDEO", "links": [15]}], "properties": {"cnr_id": "comfy-core", "ver": "0.10.0", "Node name for S&R": "CreateVideo"}, "widgets_values": [30]}, {"id": 10, "type": "GetVideoComponents", "pos": [1110, 330], "size": [320, 70], "flags": {}, "order": 2, "mode": 0, "inputs": [{"localized_name": "video", "name": "video", "type": "VIDEO", "link": 10}], "outputs": [{"localized_name": "images", "name": "images", "type": "IMAGE", "links": [14]}, {"localized_name": "audio", "name": "audio", "type": "AUDIO", "links": [16]}, {"localized_name": "fps", "name": "fps", "type": "FLOAT", "links": [12]}], "properties": {"cnr_id": "comfy-core", "ver": "0.10.0", "Node name for S&R": "GetVideoComponents"}}, {"id": 1, "type": "UpscaleModelLoader", "pos": [750, 450], "size": [280, 60], "flags": {}, "order": 0, "mode": 0, "inputs": [{"localized_name": "model_name", "name": "model_name", "type": "COMBO", "widget": {"name": "model_name"}, "link": 19}], "outputs": [{"localized_name": "UPSCALE_MODEL", "name": "UPSCALE_MODEL", "type": "UPSCALE_MODEL", "links": [1]}], "properties": {"cnr_id": "comfy-core", "ver": "0.10.0", "Node name for S&R": "UpscaleModelLoader", "models": [{"name": "RealESRGAN_x4plus.safetensors", "url": "https://huggingface.co/Comfy-Org/Real-ESRGAN_repackaged/resolve/main/RealESRGAN_x4plus.safetensors", "directory": "upscale_models"}]}, "widgets_values": ["RealESRGAN_x4plus.safetensors"]}], "groups": [], "links": [{"id": 1, "origin_id": 1, "origin_slot": 0, "target_id": 2, "target_slot": 0, "type": "UPSCALE_MODEL"}, {"id": 14, "origin_id": 10, "origin_slot": 0, "target_id": 2, "target_slot": 1, "type": "IMAGE"}, {"id": 13, "origin_id": 2, "origin_slot": 0, "target_id": 11, "target_slot": 0, "type": "IMAGE"}, {"id": 16, "origin_id": 10, "origin_slot": 1, "target_id": 11, "target_slot": 1, "type": "AUDIO"}, {"id": 12, "origin_id": 10, "origin_slot": 2, "target_id": 11, "target_slot": 2, "type": "FLOAT"}, {"id": 10, "origin_id": -10, "origin_slot": 0, "target_id": 10, "target_slot": 0, "type": "VIDEO"}, {"id": 15, "origin_id": 11, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "VIDEO"}, {"id": 19, "origin_id": -10, "origin_slot": 1, "target_id": 1, "target_slot": 0, "type": "COMBO"}], "extra": {"workflowRendererVersion": "LG"}, "category": "Video generation and editing/Enhance video"}]}, "extra": {}}
|
||||
0
blueprints/put_blueprints_here
Normal file
0
blueprints/put_blueprints_here
Normal file
@@ -25,11 +25,11 @@ class AudioEncoderModel():
|
||||
elif model_type == "whisper3":
|
||||
self.model = WhisperLargeV3(**model_config)
|
||||
self.model.eval()
|
||||
self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
|
||||
self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
|
||||
self.model_sample_rate = 16000
|
||||
|
||||
def load_sd(self, sd):
|
||||
return self.model.load_state_dict(sd, strict=False)
|
||||
return self.model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic())
|
||||
|
||||
def get_sd(self):
|
||||
return self.model.state_dict()
|
||||
|
||||
@@ -1,13 +0,0 @@
|
||||
import pickle
|
||||
|
||||
load = pickle.load
|
||||
|
||||
class Empty:
|
||||
pass
|
||||
|
||||
class Unpickler(pickle.Unpickler):
|
||||
def find_class(self, module, name):
|
||||
#TODO: safe unpickle
|
||||
if module.startswith("pytorch_lightning"):
|
||||
return Empty
|
||||
return super().find_class(module, name)
|
||||
@@ -159,6 +159,7 @@ class PerformanceFeature(enum.Enum):
|
||||
Fp8MatrixMultiplication = "fp8_matrix_mult"
|
||||
CublasOps = "cublas_ops"
|
||||
AutoTune = "autotune"
|
||||
DynamicVRAM = "dynamic_vram"
|
||||
|
||||
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. This is used to test new features so using it might crash your comfyui. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature))))
|
||||
|
||||
@@ -178,6 +179,8 @@ parser.add_argument("--disable-api-nodes", action="store_true", help="Disable lo
|
||||
|
||||
parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.")
|
||||
|
||||
parser.add_argument("--use-process-isolation", action="store_true", help="Enable process isolation for custom nodes with pyisolate.yaml manifests.")
|
||||
|
||||
parser.add_argument("--verbose", default='INFO', const='DEBUG', nargs="?", choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], help='Set the logging level')
|
||||
parser.add_argument("--log-stdout", action="store_true", help="Send normal process output to stdout instead of stderr (default).")
|
||||
|
||||
@@ -257,3 +260,6 @@ elif args.fast == []:
|
||||
# '--fast' is provided with a list of performance features, use that list
|
||||
else:
|
||||
args.fast = set(args.fast)
|
||||
|
||||
def enables_dynamic_vram():
|
||||
return PerformanceFeature.DynamicVRAM in args.fast and not args.highvram and not args.gpu_only
|
||||
|
||||
@@ -47,10 +47,10 @@ class ClipVisionModel():
|
||||
self.model = model_class(config, self.dtype, offload_device, comfy.ops.manual_cast)
|
||||
self.model.eval()
|
||||
|
||||
self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
|
||||
self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
|
||||
|
||||
def load_sd(self, sd):
|
||||
return self.model.load_state_dict(sd, strict=False)
|
||||
return self.model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic())
|
||||
|
||||
def get_sd(self):
|
||||
return self.model.state_dict()
|
||||
@@ -66,6 +66,7 @@ class ClipVisionModel():
|
||||
outputs = Output()
|
||||
outputs["last_hidden_state"] = out[0].to(comfy.model_management.intermediate_device())
|
||||
outputs["image_embeds"] = out[2].to(comfy.model_management.intermediate_device())
|
||||
outputs["image_sizes"] = [pixel_values.shape[1:]] * pixel_values.shape[0]
|
||||
if self.return_all_hidden_states:
|
||||
all_hs = out[1].to(comfy.model_management.intermediate_device())
|
||||
outputs["penultimate_hidden_states"] = all_hs[:, -2]
|
||||
|
||||
@@ -176,6 +176,8 @@ class InputTypeOptions(TypedDict):
|
||||
"""COMBO type only. Specifies the configuration for a multi-select widget.
|
||||
Available after ComfyUI frontend v1.13.4
|
||||
https://github.com/Comfy-Org/ComfyUI_frontend/pull/2987"""
|
||||
gradient_stops: NotRequired[list[list[float]]]
|
||||
"""Gradient color stops for gradientslider display mode. Each stop is [offset, r, g, b] (``FLOAT``)."""
|
||||
|
||||
|
||||
class HiddenInputTypeDict(TypedDict):
|
||||
@@ -236,6 +238,8 @@ class ComfyNodeABC(ABC):
|
||||
"""Flags a node as experimental, informing users that it may change or not work as expected."""
|
||||
DEPRECATED: bool
|
||||
"""Flags a node as deprecated, indicating to users that they should find alternatives to this node."""
|
||||
DEV_ONLY: bool
|
||||
"""Flags a node as dev-only, hiding it from search/menus unless dev mode is enabled."""
|
||||
API_NODE: Optional[bool]
|
||||
"""Flags a node as an API node. See: https://docs.comfy.org/tutorials/api-nodes/overview."""
|
||||
|
||||
|
||||
@@ -4,6 +4,25 @@ import comfy.utils
|
||||
import logging
|
||||
|
||||
|
||||
def is_equal(x, y):
|
||||
if torch.is_tensor(x) and torch.is_tensor(y):
|
||||
return torch.equal(x, y)
|
||||
elif isinstance(x, dict) and isinstance(y, dict):
|
||||
if x.keys() != y.keys():
|
||||
return False
|
||||
return all(is_equal(x[k], y[k]) for k in x)
|
||||
elif isinstance(x, (list, tuple)) and isinstance(y, (list, tuple)):
|
||||
if type(x) is not type(y) or len(x) != len(y):
|
||||
return False
|
||||
return all(is_equal(a, b) for a, b in zip(x, y))
|
||||
else:
|
||||
try:
|
||||
return x == y
|
||||
except Exception:
|
||||
logging.warning("comparison issue with COND")
|
||||
return False
|
||||
|
||||
|
||||
class CONDRegular:
|
||||
def __init__(self, cond):
|
||||
self.cond = cond
|
||||
@@ -84,7 +103,7 @@ class CONDConstant(CONDRegular):
|
||||
return self._copy_with(self.cond)
|
||||
|
||||
def can_concat(self, other):
|
||||
if self.cond != other.cond:
|
||||
if not is_equal(self.cond, other.cond):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
@@ -203,7 +203,7 @@ class ControlNet(ControlBase):
|
||||
self.control_model = control_model
|
||||
self.load_device = load_device
|
||||
if control_model is not None:
|
||||
self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device())
|
||||
self.control_model_wrapped = comfy.model_patcher.CoreModelPatcher(self.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device())
|
||||
|
||||
self.compression_ratio = compression_ratio
|
||||
self.global_average_pooling = global_average_pooling
|
||||
@@ -297,6 +297,30 @@ class ControlNet(ControlBase):
|
||||
self.model_sampling_current = None
|
||||
super().cleanup()
|
||||
|
||||
|
||||
class QwenFunControlNet(ControlNet):
|
||||
def get_control(self, x_noisy, t, cond, batched_number, transformer_options):
|
||||
# Fun checkpoints are more sensitive to high strengths in the generic
|
||||
# ControlNet merge path. Use a soft response curve so strength=1.0 stays
|
||||
# unchanged while >1 grows more gently.
|
||||
original_strength = self.strength
|
||||
self.strength = math.sqrt(max(self.strength, 0.0))
|
||||
try:
|
||||
return super().get_control(x_noisy, t, cond, batched_number, transformer_options)
|
||||
finally:
|
||||
self.strength = original_strength
|
||||
|
||||
def pre_run(self, model, percent_to_timestep_function):
|
||||
super().pre_run(model, percent_to_timestep_function)
|
||||
self.set_extra_arg("base_model", model.diffusion_model)
|
||||
|
||||
def copy(self):
|
||||
c = QwenFunControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
|
||||
c.control_model = self.control_model
|
||||
c.control_model_wrapped = self.control_model_wrapped
|
||||
self.copy_to(c)
|
||||
return c
|
||||
|
||||
class ControlLoraOps:
|
||||
class Linear(torch.nn.Module, comfy.ops.CastWeightBiasOp):
|
||||
def __init__(self, in_features: int, out_features: int, bias: bool = True,
|
||||
@@ -560,6 +584,7 @@ def load_controlnet_hunyuandit(controlnet_data, model_options={}):
|
||||
def load_controlnet_flux_xlabs_mistoline(sd, mistoline=False, model_options={}):
|
||||
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd, model_options=model_options)
|
||||
control_model = comfy.ldm.flux.controlnet.ControlNetFlux(mistoline=mistoline, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
|
||||
sd = model_config.process_unet_state_dict(sd)
|
||||
control_model = controlnet_load_state_dict(control_model, sd)
|
||||
extra_conds = ['y', 'guidance']
|
||||
control = ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
|
||||
@@ -605,6 +630,53 @@ def load_controlnet_qwen_instantx(sd, model_options={}):
|
||||
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
|
||||
return control
|
||||
|
||||
|
||||
def load_controlnet_qwen_fun(sd, model_options={}):
|
||||
load_device = comfy.model_management.get_torch_device()
|
||||
weight_dtype = comfy.utils.weight_dtype(sd)
|
||||
unet_dtype = model_options.get("dtype", weight_dtype)
|
||||
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
|
||||
|
||||
operations = model_options.get("custom_operations", None)
|
||||
if operations is None:
|
||||
operations = comfy.ops.pick_operations(unet_dtype, manual_cast_dtype, disable_fast_fp8=True)
|
||||
|
||||
in_features = sd["control_img_in.weight"].shape[1]
|
||||
inner_dim = sd["control_img_in.weight"].shape[0]
|
||||
|
||||
block_weight = sd["control_blocks.0.attn.to_q.weight"]
|
||||
attention_head_dim = sd["control_blocks.0.attn.norm_q.weight"].shape[0]
|
||||
num_attention_heads = max(1, block_weight.shape[0] // max(1, attention_head_dim))
|
||||
|
||||
model = comfy.ldm.qwen_image.controlnet.QwenImageFunControlNetModel(
|
||||
control_in_features=in_features,
|
||||
inner_dim=inner_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
num_control_blocks=5,
|
||||
main_model_double=60,
|
||||
injection_layers=(0, 12, 24, 36, 48),
|
||||
operations=operations,
|
||||
device=comfy.model_management.unet_offload_device(),
|
||||
dtype=unet_dtype,
|
||||
)
|
||||
model = controlnet_load_state_dict(model, sd)
|
||||
|
||||
latent_format = comfy.latent_formats.Wan21()
|
||||
control = QwenFunControlNet(
|
||||
model,
|
||||
compression_ratio=1,
|
||||
latent_format=latent_format,
|
||||
# Fun checkpoints already expect their own 33-channel context handling.
|
||||
# Enabling generic concat_mask injects an extra mask channel at apply-time
|
||||
# and breaks the intended fallback packing path.
|
||||
concat_mask=False,
|
||||
load_device=load_device,
|
||||
manual_cast_dtype=manual_cast_dtype,
|
||||
extra_conds=[],
|
||||
)
|
||||
return control
|
||||
|
||||
def convert_mistoline(sd):
|
||||
return comfy.utils.state_dict_prefix_replace(sd, {"single_controlnet_blocks.": "controlnet_single_blocks."})
|
||||
|
||||
@@ -682,6 +754,8 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}):
|
||||
return load_controlnet_qwen_instantx(controlnet_data, model_options=model_options)
|
||||
elif "controlnet_x_embedder.weight" in controlnet_data:
|
||||
return load_controlnet_flux_instantx(controlnet_data, model_options=model_options)
|
||||
elif "control_blocks.0.after_proj.weight" in controlnet_data and "control_img_in.weight" in controlnet_data:
|
||||
return load_controlnet_qwen_fun(controlnet_data, model_options=model_options)
|
||||
|
||||
elif "controlnet_blocks.0.linear.weight" in controlnet_data: #mistoline flux
|
||||
return load_controlnet_flux_xlabs_mistoline(convert_mistoline(controlnet_data), mistoline=True, model_options=model_options)
|
||||
|
||||
@@ -137,10 +137,44 @@ def to_blocked(input_matrix, flatten: bool = True) -> torch.Tensor:
|
||||
return rearranged.reshape(padded_rows, padded_cols)
|
||||
|
||||
|
||||
def stochastic_round_quantize_nvfp4(x, per_tensor_scale, pad_16x, seed=0):
|
||||
def stochastic_round_quantize_nvfp4_block(x, per_tensor_scale, generator):
|
||||
F4_E2M1_MAX = 6.0
|
||||
F8_E4M3_MAX = 448.0
|
||||
|
||||
orig_shape = x.shape
|
||||
|
||||
block_size = 16
|
||||
|
||||
x = x.reshape(orig_shape[0], -1, block_size)
|
||||
scaled_block_scales_fp8 = torch.clamp(((torch.amax(torch.abs(x), dim=-1)) / F4_E2M1_MAX) / per_tensor_scale.to(x.dtype), max=F8_E4M3_MAX).to(torch.float8_e4m3fn)
|
||||
x = x / (per_tensor_scale.to(x.dtype) * scaled_block_scales_fp8.to(x.dtype)).unsqueeze(-1)
|
||||
|
||||
x = x.view(orig_shape).nan_to_num()
|
||||
data_lp = stochastic_float_to_fp4_e2m1(x, generator=generator)
|
||||
return data_lp, scaled_block_scales_fp8
|
||||
|
||||
|
||||
def stochastic_round_quantize_nvfp4(x, per_tensor_scale, pad_16x, seed=0):
|
||||
def roundup(x: int, multiple: int) -> int:
|
||||
"""Round up x to the nearest multiple."""
|
||||
return ((x + multiple - 1) // multiple) * multiple
|
||||
|
||||
generator = torch.Generator(device=x.device)
|
||||
generator.manual_seed(seed)
|
||||
|
||||
# Handle padding
|
||||
if pad_16x:
|
||||
rows, cols = x.shape
|
||||
padded_rows = roundup(rows, 16)
|
||||
padded_cols = roundup(cols, 16)
|
||||
if padded_rows != rows or padded_cols != cols:
|
||||
x = torch.nn.functional.pad(x, (0, padded_cols - cols, 0, padded_rows - rows))
|
||||
|
||||
x, blocked_scaled = stochastic_round_quantize_nvfp4_block(x, per_tensor_scale, generator)
|
||||
return x, to_blocked(blocked_scaled, flatten=False)
|
||||
|
||||
|
||||
def stochastic_round_quantize_nvfp4_by_block(x, per_tensor_scale, pad_16x, seed=0, block_size=4096 * 4096):
|
||||
def roundup(x: int, multiple: int) -> int:
|
||||
"""Round up x to the nearest multiple."""
|
||||
return ((x + multiple - 1) // multiple) * multiple
|
||||
@@ -158,28 +192,20 @@ def stochastic_round_quantize_nvfp4(x, per_tensor_scale, pad_16x, seed=0):
|
||||
# what we want to produce. If we pad here, we want the padded output.
|
||||
orig_shape = x.shape
|
||||
|
||||
block_size = 16
|
||||
orig_shape = list(orig_shape)
|
||||
|
||||
x = x.reshape(orig_shape[0], -1, block_size)
|
||||
max_abs = torch.amax(torch.abs(x), dim=-1)
|
||||
block_scale = max_abs / F4_E2M1_MAX
|
||||
scaled_block_scales = block_scale / per_tensor_scale.to(block_scale.dtype)
|
||||
scaled_block_scales_fp8 = torch.clamp(scaled_block_scales, max=F8_E4M3_MAX).to(torch.float8_e4m3fn)
|
||||
total_scale = per_tensor_scale.to(x.dtype) * scaled_block_scales_fp8.to(x.dtype)
|
||||
|
||||
# Handle zero blocks (from padding): avoid 0/0 NaN
|
||||
zero_scale_mask = (total_scale == 0)
|
||||
total_scale_safe = torch.where(zero_scale_mask, torch.ones_like(total_scale), total_scale)
|
||||
|
||||
x = x / total_scale_safe.unsqueeze(-1)
|
||||
output_fp4 = torch.empty(orig_shape[:-1] + [orig_shape[-1] // 2], dtype=torch.uint8, device=x.device)
|
||||
output_block = torch.empty(orig_shape[:-1] + [orig_shape[-1] // 16], dtype=torch.float8_e4m3fn, device=x.device)
|
||||
|
||||
generator = torch.Generator(device=x.device)
|
||||
generator.manual_seed(seed)
|
||||
|
||||
x = torch.where(zero_scale_mask.unsqueeze(-1), torch.zeros_like(x), x)
|
||||
num_slices = max(1, (x.numel() / block_size))
|
||||
slice_size = max(1, (round(x.shape[0] / num_slices)))
|
||||
|
||||
x = x.view(orig_shape)
|
||||
data_lp = stochastic_float_to_fp4_e2m1(x, generator=generator)
|
||||
for i in range(0, x.shape[0], slice_size):
|
||||
fp4, block = stochastic_round_quantize_nvfp4_block(x[i: i + slice_size], per_tensor_scale, generator=generator)
|
||||
output_fp4[i:i + slice_size].copy_(fp4)
|
||||
output_block[i:i + slice_size].copy_(block)
|
||||
|
||||
blocked_scales = to_blocked(scaled_block_scales_fp8, flatten=False)
|
||||
return data_lp, blocked_scales
|
||||
return output_fp4, to_blocked(output_block, flatten=False)
|
||||
|
||||
@@ -14,6 +14,9 @@ if TYPE_CHECKING:
|
||||
import comfy.lora
|
||||
import comfy.model_management
|
||||
import comfy.patcher_extension
|
||||
from comfy.cli_args import args
|
||||
import uuid
|
||||
import os
|
||||
from node_helpers import conditioning_set_values
|
||||
|
||||
# #######################################################################################################
|
||||
@@ -61,8 +64,37 @@ class EnumHookScope(enum.Enum):
|
||||
HookedOnly = "hooked_only"
|
||||
|
||||
|
||||
_ISOLATION_HOOKREF_MODE = args.use_process_isolation or os.environ.get("PYISOLATE_ISOLATION_ACTIVE") == "1"
|
||||
|
||||
|
||||
class _HookRef:
|
||||
pass
|
||||
def __init__(self):
|
||||
if _ISOLATION_HOOKREF_MODE:
|
||||
self._pyisolate_id = str(uuid.uuid4())
|
||||
|
||||
def _ensure_pyisolate_id(self):
|
||||
pyisolate_id = getattr(self, "_pyisolate_id", None)
|
||||
if pyisolate_id is None:
|
||||
pyisolate_id = str(uuid.uuid4())
|
||||
self._pyisolate_id = pyisolate_id
|
||||
return pyisolate_id
|
||||
|
||||
def __eq__(self, other):
|
||||
if not _ISOLATION_HOOKREF_MODE:
|
||||
return self is other
|
||||
if not isinstance(other, _HookRef):
|
||||
return False
|
||||
return self._ensure_pyisolate_id() == other._ensure_pyisolate_id()
|
||||
|
||||
def __hash__(self):
|
||||
if not _ISOLATION_HOOKREF_MODE:
|
||||
return id(self)
|
||||
return hash(self._ensure_pyisolate_id())
|
||||
|
||||
def __str__(self):
|
||||
if not _ISOLATION_HOOKREF_MODE:
|
||||
return super().__str__()
|
||||
return f"PYISOLATE_HOOKREF:{self._ensure_pyisolate_id()}"
|
||||
|
||||
|
||||
def default_should_register(hook: Hook, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
|
||||
@@ -168,6 +200,8 @@ class WeightHook(Hook):
|
||||
key_map = comfy.lora.model_lora_keys_clip(model.model, key_map)
|
||||
else:
|
||||
key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
|
||||
if self.weights is None:
|
||||
self.weights = {}
|
||||
weights = comfy.lora.load_lora(self.weights, key_map, log_missing=False)
|
||||
else:
|
||||
if target == EnumWeightTarget.Clip:
|
||||
|
||||
327
comfy/isolation/__init__.py
Normal file
327
comfy/isolation/__init__.py
Normal file
@@ -0,0 +1,327 @@
|
||||
# pylint: disable=consider-using-from-import,cyclic-import,global-statement,global-variable-not-assigned,import-outside-toplevel,logging-fstring-interpolation
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
import inspect
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Set, TYPE_CHECKING
|
||||
import folder_paths
|
||||
from .extension_loader import load_isolated_node
|
||||
from .manifest_loader import find_manifest_directories
|
||||
from .runtime_helpers import build_stub_class, get_class_types_for_extension
|
||||
from .shm_forensics import scan_shm_forensics, start_shm_forensics
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pyisolate import ExtensionManager
|
||||
from .extension_wrapper import ComfyNodeExtension
|
||||
|
||||
LOG_PREFIX = "]["
|
||||
isolated_node_timings: List[tuple[float, Path, int]] = []
|
||||
|
||||
PYISOLATE_VENV_ROOT = Path(folder_paths.base_path) / ".pyisolate_venvs"
|
||||
PYISOLATE_VENV_ROOT.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
_WORKFLOW_BOUNDARY_MIN_FREE_VRAM_BYTES = 2 * 1024 * 1024 * 1024
|
||||
|
||||
|
||||
def initialize_proxies() -> None:
|
||||
from .child_hooks import is_child_process
|
||||
|
||||
is_child = is_child_process()
|
||||
|
||||
if is_child:
|
||||
from .child_hooks import initialize_child_process
|
||||
|
||||
initialize_child_process()
|
||||
else:
|
||||
from .host_hooks import initialize_host_process
|
||||
|
||||
initialize_host_process()
|
||||
start_shm_forensics()
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class IsolatedNodeSpec:
|
||||
node_name: str
|
||||
display_name: str
|
||||
stub_class: type
|
||||
module_path: Path
|
||||
|
||||
|
||||
_ISOLATED_NODE_SPECS: List[IsolatedNodeSpec] = []
|
||||
_CLAIMED_PATHS: Set[Path] = set()
|
||||
_ISOLATION_SCAN_ATTEMPTED = False
|
||||
_EXTENSION_MANAGERS: List["ExtensionManager"] = []
|
||||
_RUNNING_EXTENSIONS: Dict[str, "ComfyNodeExtension"] = {}
|
||||
_ISOLATION_BACKGROUND_TASK: Optional["asyncio.Task[List[IsolatedNodeSpec]]"] = None
|
||||
_EARLY_START_TIME: Optional[float] = None
|
||||
|
||||
|
||||
def start_isolation_loading_early(loop: "asyncio.AbstractEventLoop") -> None:
|
||||
global _ISOLATION_BACKGROUND_TASK, _EARLY_START_TIME
|
||||
if _ISOLATION_BACKGROUND_TASK is not None:
|
||||
return
|
||||
_EARLY_START_TIME = time.perf_counter()
|
||||
_ISOLATION_BACKGROUND_TASK = loop.create_task(initialize_isolation_nodes())
|
||||
|
||||
|
||||
async def await_isolation_loading() -> List[IsolatedNodeSpec]:
|
||||
global _ISOLATION_BACKGROUND_TASK, _EARLY_START_TIME
|
||||
if _ISOLATION_BACKGROUND_TASK is not None:
|
||||
specs = await _ISOLATION_BACKGROUND_TASK
|
||||
return specs
|
||||
return await initialize_isolation_nodes()
|
||||
|
||||
|
||||
async def initialize_isolation_nodes() -> List[IsolatedNodeSpec]:
|
||||
global _ISOLATED_NODE_SPECS, _ISOLATION_SCAN_ATTEMPTED, _CLAIMED_PATHS
|
||||
|
||||
if _ISOLATED_NODE_SPECS:
|
||||
return _ISOLATED_NODE_SPECS
|
||||
|
||||
if _ISOLATION_SCAN_ATTEMPTED:
|
||||
return []
|
||||
|
||||
_ISOLATION_SCAN_ATTEMPTED = True
|
||||
manifest_entries = find_manifest_directories()
|
||||
_CLAIMED_PATHS = {entry[0].resolve() for entry in manifest_entries}
|
||||
|
||||
if not manifest_entries:
|
||||
return []
|
||||
|
||||
os.environ["PYISOLATE_ISOLATION_ACTIVE"] = "1"
|
||||
concurrency_limit = max(1, (os.cpu_count() or 4) // 2)
|
||||
semaphore = asyncio.Semaphore(concurrency_limit)
|
||||
|
||||
async def load_with_semaphore(
|
||||
node_dir: Path, manifest: Path
|
||||
) -> List[IsolatedNodeSpec]:
|
||||
async with semaphore:
|
||||
load_start = time.perf_counter()
|
||||
spec_list = await load_isolated_node(
|
||||
node_dir,
|
||||
manifest,
|
||||
logger,
|
||||
lambda name, info, extension: build_stub_class(
|
||||
name,
|
||||
info,
|
||||
extension,
|
||||
_RUNNING_EXTENSIONS,
|
||||
logger,
|
||||
),
|
||||
PYISOLATE_VENV_ROOT,
|
||||
_EXTENSION_MANAGERS,
|
||||
)
|
||||
spec_list = [
|
||||
IsolatedNodeSpec(
|
||||
node_name=node_name,
|
||||
display_name=display_name,
|
||||
stub_class=stub_cls,
|
||||
module_path=node_dir,
|
||||
)
|
||||
for node_name, display_name, stub_cls in spec_list
|
||||
]
|
||||
isolated_node_timings.append(
|
||||
(time.perf_counter() - load_start, node_dir, len(spec_list))
|
||||
)
|
||||
return spec_list
|
||||
|
||||
tasks = [
|
||||
load_with_semaphore(node_dir, manifest)
|
||||
for node_dir, manifest in manifest_entries
|
||||
]
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
specs: List[IsolatedNodeSpec] = []
|
||||
for result in results:
|
||||
if isinstance(result, Exception):
|
||||
logger.error(
|
||||
"%s Isolated node failed during startup; continuing: %s",
|
||||
LOG_PREFIX,
|
||||
result,
|
||||
)
|
||||
continue
|
||||
specs.extend(result)
|
||||
|
||||
_ISOLATED_NODE_SPECS = specs
|
||||
return list(_ISOLATED_NODE_SPECS)
|
||||
|
||||
|
||||
def _get_class_types_for_extension(extension_name: str) -> Set[str]:
|
||||
"""Get all node class types (node names) belonging to an extension."""
|
||||
extension = _RUNNING_EXTENSIONS.get(extension_name)
|
||||
if not extension:
|
||||
return set()
|
||||
|
||||
ext_path = Path(extension.module_path)
|
||||
class_types = set()
|
||||
for spec in _ISOLATED_NODE_SPECS:
|
||||
if spec.module_path.resolve() == ext_path.resolve():
|
||||
class_types.add(spec.node_name)
|
||||
|
||||
return class_types
|
||||
|
||||
|
||||
async def notify_execution_graph(needed_class_types: Set[str]) -> None:
|
||||
"""Evict running extensions not needed for current execution."""
|
||||
|
||||
async def _stop_extension(
|
||||
ext_name: str, extension: "ComfyNodeExtension", reason: str
|
||||
) -> None:
|
||||
logger.info("%s ISO:eject_start ext=%s reason=%s", LOG_PREFIX, ext_name, reason)
|
||||
logger.debug("%s ISO:stop_start ext=%s", LOG_PREFIX, ext_name)
|
||||
stop_result = extension.stop()
|
||||
if inspect.isawaitable(stop_result):
|
||||
await stop_result
|
||||
_RUNNING_EXTENSIONS.pop(ext_name, None)
|
||||
logger.debug("%s ISO:stop_done ext=%s", LOG_PREFIX, ext_name)
|
||||
scan_shm_forensics("ISO:stop_extension", refresh_model_context=True)
|
||||
|
||||
scan_shm_forensics("ISO:notify_graph_start", refresh_model_context=True)
|
||||
logger.debug(
|
||||
"%s ISO:notify_graph_start running=%d needed=%d",
|
||||
LOG_PREFIX,
|
||||
len(_RUNNING_EXTENSIONS),
|
||||
len(needed_class_types),
|
||||
)
|
||||
for ext_name, extension in list(_RUNNING_EXTENSIONS.items()):
|
||||
ext_class_types = _get_class_types_for_extension(ext_name)
|
||||
|
||||
# If NONE of this extension's nodes are in the execution graph → evict
|
||||
if not ext_class_types.intersection(needed_class_types):
|
||||
await _stop_extension(
|
||||
ext_name,
|
||||
extension,
|
||||
"isolated custom_node not in execution graph, evicting",
|
||||
)
|
||||
|
||||
# Isolated child processes add steady VRAM pressure; reclaim host-side models
|
||||
# at workflow boundaries so subsequent host nodes (e.g. CLIP encode) keep headroom.
|
||||
try:
|
||||
import comfy.model_management as model_management
|
||||
|
||||
device = model_management.get_torch_device()
|
||||
if getattr(device, "type", None) == "cuda":
|
||||
required = max(
|
||||
model_management.minimum_inference_memory(),
|
||||
_WORKFLOW_BOUNDARY_MIN_FREE_VRAM_BYTES,
|
||||
)
|
||||
free_before = model_management.get_free_memory(device)
|
||||
if free_before < required and _RUNNING_EXTENSIONS:
|
||||
for ext_name, extension in list(_RUNNING_EXTENSIONS.items()):
|
||||
await _stop_extension(
|
||||
ext_name,
|
||||
extension,
|
||||
f"boundary low-vram restart (free={int(free_before)} target={int(required)})",
|
||||
)
|
||||
if model_management.get_free_memory(device) < required:
|
||||
model_management.unload_all_models()
|
||||
model_management.cleanup_models_gc()
|
||||
model_management.cleanup_models()
|
||||
if model_management.get_free_memory(device) < required:
|
||||
model_management.free_memory(required, device, for_dynamic=False)
|
||||
model_management.soft_empty_cache()
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"%s workflow-boundary host VRAM relief failed", LOG_PREFIX, exc_info=True
|
||||
)
|
||||
finally:
|
||||
scan_shm_forensics("ISO:notify_graph_done", refresh_model_context=True)
|
||||
logger.debug(
|
||||
"%s ISO:notify_graph_done running=%d", LOG_PREFIX, len(_RUNNING_EXTENSIONS)
|
||||
)
|
||||
|
||||
|
||||
async def flush_running_extensions_transport_state() -> int:
|
||||
total_flushed = 0
|
||||
for ext_name, extension in list(_RUNNING_EXTENSIONS.items()):
|
||||
flush_fn = getattr(extension, "flush_transport_state", None)
|
||||
if not callable(flush_fn):
|
||||
continue
|
||||
try:
|
||||
flushed = await flush_fn()
|
||||
if isinstance(flushed, int):
|
||||
total_flushed += flushed
|
||||
if flushed > 0:
|
||||
logger.debug(
|
||||
"%s %s workflow-end flush released=%d",
|
||||
LOG_PREFIX,
|
||||
ext_name,
|
||||
flushed,
|
||||
)
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"%s %s workflow-end flush failed", LOG_PREFIX, ext_name, exc_info=True
|
||||
)
|
||||
scan_shm_forensics(
|
||||
"ISO:flush_running_extensions_transport_state", refresh_model_context=True
|
||||
)
|
||||
return total_flushed
|
||||
|
||||
|
||||
def get_claimed_paths() -> Set[Path]:
|
||||
return _CLAIMED_PATHS
|
||||
|
||||
|
||||
def update_rpc_event_loops(loop: "asyncio.AbstractEventLoop | None" = None) -> None:
|
||||
"""Update all active RPC instances with the current event loop.
|
||||
|
||||
This MUST be called at the start of each workflow execution to ensure
|
||||
RPC calls are scheduled on the correct event loop. This handles the case
|
||||
where asyncio.run() creates a new event loop for each workflow.
|
||||
|
||||
Args:
|
||||
loop: The event loop to use. If None, uses asyncio.get_running_loop().
|
||||
"""
|
||||
if loop is None:
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
update_count = 0
|
||||
|
||||
# Update RPCs from ExtensionManagers
|
||||
for manager in _EXTENSION_MANAGERS:
|
||||
if not hasattr(manager, "extensions"):
|
||||
continue
|
||||
for name, extension in manager.extensions.items():
|
||||
if hasattr(extension, "rpc") and extension.rpc is not None:
|
||||
if hasattr(extension.rpc, "update_event_loop"):
|
||||
extension.rpc.update_event_loop(loop)
|
||||
update_count += 1
|
||||
logger.debug(f"{LOG_PREFIX}Updated loop on extension '{name}'")
|
||||
|
||||
# Also update RPCs from running extensions (they may have direct RPC refs)
|
||||
for name, extension in _RUNNING_EXTENSIONS.items():
|
||||
if hasattr(extension, "rpc") and extension.rpc is not None:
|
||||
if hasattr(extension.rpc, "update_event_loop"):
|
||||
extension.rpc.update_event_loop(loop)
|
||||
update_count += 1
|
||||
logger.debug(f"{LOG_PREFIX}Updated loop on running extension '{name}'")
|
||||
|
||||
if update_count > 0:
|
||||
logger.debug(f"{LOG_PREFIX}Updated event loop on {update_count} RPC instances")
|
||||
else:
|
||||
logger.debug(
|
||||
f"{LOG_PREFIX}No RPC instances found to update (managers={len(_EXTENSION_MANAGERS)}, running={len(_RUNNING_EXTENSIONS)})"
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"LOG_PREFIX",
|
||||
"initialize_proxies",
|
||||
"initialize_isolation_nodes",
|
||||
"start_isolation_loading_early",
|
||||
"await_isolation_loading",
|
||||
"notify_execution_graph",
|
||||
"flush_running_extensions_transport_state",
|
||||
"get_claimed_paths",
|
||||
"update_rpc_event_loops",
|
||||
"IsolatedNodeSpec",
|
||||
"get_class_types_for_extension",
|
||||
]
|
||||
505
comfy/isolation/adapter.py
Normal file
505
comfy/isolation/adapter.py
Normal file
@@ -0,0 +1,505 @@
|
||||
# pylint: disable=import-outside-toplevel,logging-fstring-interpolation,protected-access,raise-missing-from,useless-return,wrong-import-position
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
from pyisolate.interfaces import IsolationAdapter, SerializerRegistryProtocol # type: ignore[import-untyped]
|
||||
from pyisolate._internal.rpc_protocol import AsyncRPC, ProxiedSingleton # type: ignore[import-untyped]
|
||||
|
||||
try:
|
||||
from comfy.isolation.clip_proxy import CLIPProxy, CLIPRegistry
|
||||
from comfy.isolation.model_patcher_proxy import (
|
||||
ModelPatcherProxy,
|
||||
ModelPatcherRegistry,
|
||||
)
|
||||
from comfy.isolation.model_sampling_proxy import (
|
||||
ModelSamplingProxy,
|
||||
ModelSamplingRegistry,
|
||||
)
|
||||
from comfy.isolation.vae_proxy import VAEProxy, VAERegistry, FirstStageModelRegistry
|
||||
from comfy.isolation.proxies.folder_paths_proxy import FolderPathsProxy
|
||||
from comfy.isolation.proxies.model_management_proxy import ModelManagementProxy
|
||||
from comfy.isolation.proxies.prompt_server_impl import PromptServerService
|
||||
from comfy.isolation.proxies.utils_proxy import UtilsProxy
|
||||
from comfy.isolation.proxies.progress_proxy import ProgressProxy
|
||||
except ImportError as exc: # Fail loud if Comfy environment is incomplete
|
||||
raise ImportError(f"ComfyUI environment incomplete: {exc}")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Force /dev/shm for shared memory (bwrap makes /tmp private)
|
||||
import tempfile
|
||||
|
||||
if os.path.exists("/dev/shm"):
|
||||
# Only override if not already set or if default is not /dev/shm
|
||||
current_tmp = tempfile.gettempdir()
|
||||
if not current_tmp.startswith("/dev/shm"):
|
||||
logger.debug(
|
||||
f"Configuring shared memory: Changing TMPDIR from {current_tmp} to /dev/shm"
|
||||
)
|
||||
os.environ["TMPDIR"] = "/dev/shm"
|
||||
tempfile.tempdir = None # Clear cache to force re-evaluation
|
||||
|
||||
|
||||
class ComfyUIAdapter(IsolationAdapter):
|
||||
# ComfyUI-specific IsolationAdapter implementation
|
||||
|
||||
@property
|
||||
def identifier(self) -> str:
|
||||
return "comfyui"
|
||||
|
||||
def get_path_config(self, module_path: str) -> Optional[Dict[str, Any]]:
|
||||
if "ComfyUI" in module_path and "custom_nodes" in module_path:
|
||||
parts = module_path.split("ComfyUI")
|
||||
if len(parts) > 1:
|
||||
comfy_root = parts[0] + "ComfyUI"
|
||||
return {
|
||||
"preferred_root": comfy_root,
|
||||
"additional_paths": [
|
||||
os.path.join(comfy_root, "custom_nodes"),
|
||||
os.path.join(comfy_root, "comfy"),
|
||||
],
|
||||
}
|
||||
return None
|
||||
|
||||
def setup_child_environment(self, snapshot: Dict[str, Any]) -> None:
|
||||
comfy_root = snapshot.get("preferred_root")
|
||||
if not comfy_root:
|
||||
return
|
||||
|
||||
requirements_path = Path(comfy_root) / "requirements.txt"
|
||||
if requirements_path.exists():
|
||||
import re
|
||||
|
||||
for line in requirements_path.read_text().splitlines():
|
||||
line = line.strip()
|
||||
if not line or line.startswith("#"):
|
||||
continue
|
||||
pkg_name = re.split(r"[<>=!~\[]", line)[0].strip()
|
||||
if pkg_name:
|
||||
logging.getLogger(pkg_name).setLevel(logging.ERROR)
|
||||
|
||||
def register_serializers(self, registry: SerializerRegistryProtocol) -> None:
|
||||
def serialize_model_patcher(obj: Any) -> Dict[str, Any]:
|
||||
# Child-side: must already have _instance_id (proxy)
|
||||
if os.environ.get("PYISOLATE_CHILD") == "1":
|
||||
if hasattr(obj, "_instance_id"):
|
||||
return {"__type__": "ModelPatcherRef", "model_id": obj._instance_id}
|
||||
raise RuntimeError(
|
||||
f"ModelPatcher in child lacks _instance_id: "
|
||||
f"{type(obj).__module__}.{type(obj).__name__}"
|
||||
)
|
||||
# Host-side: register with registry
|
||||
if hasattr(obj, "_instance_id"):
|
||||
return {"__type__": "ModelPatcherRef", "model_id": obj._instance_id}
|
||||
model_id = ModelPatcherRegistry().register(obj)
|
||||
return {"__type__": "ModelPatcherRef", "model_id": model_id}
|
||||
|
||||
def deserialize_model_patcher(data: Any) -> Any:
|
||||
"""Deserialize ModelPatcher refs; pass through already-materialized objects."""
|
||||
if isinstance(data, dict):
|
||||
return ModelPatcherProxy(
|
||||
data["model_id"], registry=None, manage_lifecycle=False
|
||||
)
|
||||
return data
|
||||
|
||||
def deserialize_model_patcher_ref(data: Dict[str, Any]) -> Any:
|
||||
"""Context-aware ModelPatcherRef deserializer for both host and child."""
|
||||
is_child = os.environ.get("PYISOLATE_CHILD") == "1"
|
||||
if is_child:
|
||||
return ModelPatcherProxy(
|
||||
data["model_id"], registry=None, manage_lifecycle=False
|
||||
)
|
||||
else:
|
||||
return ModelPatcherRegistry()._get_instance(data["model_id"])
|
||||
|
||||
# Register ModelPatcher type for serialization
|
||||
registry.register(
|
||||
"ModelPatcher", serialize_model_patcher, deserialize_model_patcher
|
||||
)
|
||||
# Register ModelPatcherProxy type (already a proxy, just return ref)
|
||||
registry.register(
|
||||
"ModelPatcherProxy", serialize_model_patcher, deserialize_model_patcher
|
||||
)
|
||||
# Register ModelPatcherRef for deserialization (context-aware: host or child)
|
||||
registry.register("ModelPatcherRef", None, deserialize_model_patcher_ref)
|
||||
|
||||
def serialize_clip(obj: Any) -> Dict[str, Any]:
|
||||
if hasattr(obj, "_instance_id"):
|
||||
return {"__type__": "CLIPRef", "clip_id": obj._instance_id}
|
||||
clip_id = CLIPRegistry().register(obj)
|
||||
return {"__type__": "CLIPRef", "clip_id": clip_id}
|
||||
|
||||
def deserialize_clip(data: Any) -> Any:
|
||||
if isinstance(data, dict):
|
||||
return CLIPProxy(data["clip_id"], registry=None, manage_lifecycle=False)
|
||||
return data
|
||||
|
||||
def deserialize_clip_ref(data: Dict[str, Any]) -> Any:
|
||||
"""Context-aware CLIPRef deserializer for both host and child."""
|
||||
is_child = os.environ.get("PYISOLATE_CHILD") == "1"
|
||||
if is_child:
|
||||
return CLIPProxy(data["clip_id"], registry=None, manage_lifecycle=False)
|
||||
else:
|
||||
return CLIPRegistry()._get_instance(data["clip_id"])
|
||||
|
||||
# Register CLIP type for serialization
|
||||
registry.register("CLIP", serialize_clip, deserialize_clip)
|
||||
# Register CLIPProxy type (already a proxy, just return ref)
|
||||
registry.register("CLIPProxy", serialize_clip, deserialize_clip)
|
||||
# Register CLIPRef for deserialization (context-aware: host or child)
|
||||
registry.register("CLIPRef", None, deserialize_clip_ref)
|
||||
|
||||
def serialize_vae(obj: Any) -> Dict[str, Any]:
|
||||
if hasattr(obj, "_instance_id"):
|
||||
return {"__type__": "VAERef", "vae_id": obj._instance_id}
|
||||
vae_id = VAERegistry().register(obj)
|
||||
return {"__type__": "VAERef", "vae_id": vae_id}
|
||||
|
||||
def deserialize_vae(data: Any) -> Any:
|
||||
if isinstance(data, dict):
|
||||
return VAEProxy(data["vae_id"])
|
||||
return data
|
||||
|
||||
def deserialize_vae_ref(data: Dict[str, Any]) -> Any:
|
||||
"""Context-aware VAERef deserializer for both host and child."""
|
||||
is_child = os.environ.get("PYISOLATE_CHILD") == "1"
|
||||
if is_child:
|
||||
# Child: create a proxy
|
||||
return VAEProxy(data["vae_id"])
|
||||
else:
|
||||
# Host: lookup real VAE from registry
|
||||
return VAERegistry()._get_instance(data["vae_id"])
|
||||
|
||||
# Register VAE type for serialization
|
||||
registry.register("VAE", serialize_vae, deserialize_vae)
|
||||
# Register VAEProxy type (already a proxy, just return ref)
|
||||
registry.register("VAEProxy", serialize_vae, deserialize_vae)
|
||||
# Register VAERef for deserialization (context-aware: host or child)
|
||||
registry.register("VAERef", None, deserialize_vae_ref)
|
||||
|
||||
# ModelSampling serialization - handles ModelSampling* types
|
||||
# copyreg removed - no pickle fallback allowed
|
||||
|
||||
def serialize_model_sampling(obj: Any) -> Dict[str, Any]:
|
||||
# Child-side: must already have _instance_id (proxy)
|
||||
if os.environ.get("PYISOLATE_CHILD") == "1":
|
||||
if hasattr(obj, "_instance_id"):
|
||||
return {"__type__": "ModelSamplingRef", "ms_id": obj._instance_id}
|
||||
raise RuntimeError(
|
||||
f"ModelSampling in child lacks _instance_id: "
|
||||
f"{type(obj).__module__}.{type(obj).__name__}"
|
||||
)
|
||||
# Host-side: register with ModelSamplingRegistry and return JSON-safe dict
|
||||
ms_id = ModelSamplingRegistry().register(obj)
|
||||
return {"__type__": "ModelSamplingRef", "ms_id": ms_id}
|
||||
|
||||
def deserialize_model_sampling(data: Any) -> Any:
|
||||
"""Deserialize ModelSampling refs; pass through already-materialized objects."""
|
||||
if isinstance(data, dict):
|
||||
return ModelSamplingProxy(data["ms_id"])
|
||||
return data
|
||||
|
||||
def deserialize_model_sampling_ref(data: Dict[str, Any]) -> Any:
|
||||
"""Context-aware ModelSamplingRef deserializer for both host and child."""
|
||||
is_child = os.environ.get("PYISOLATE_CHILD") == "1"
|
||||
if is_child:
|
||||
return ModelSamplingProxy(data["ms_id"])
|
||||
else:
|
||||
return ModelSamplingRegistry()._get_instance(data["ms_id"])
|
||||
|
||||
# Register ModelSampling type and proxy
|
||||
registry.register(
|
||||
"ModelSamplingDiscrete",
|
||||
serialize_model_sampling,
|
||||
deserialize_model_sampling,
|
||||
)
|
||||
registry.register(
|
||||
"ModelSamplingContinuousEDM",
|
||||
serialize_model_sampling,
|
||||
deserialize_model_sampling,
|
||||
)
|
||||
registry.register(
|
||||
"ModelSamplingContinuousV",
|
||||
serialize_model_sampling,
|
||||
deserialize_model_sampling,
|
||||
)
|
||||
registry.register(
|
||||
"ModelSamplingProxy", serialize_model_sampling, deserialize_model_sampling
|
||||
)
|
||||
# Register ModelSamplingRef for deserialization (context-aware: host or child)
|
||||
registry.register("ModelSamplingRef", None, deserialize_model_sampling_ref)
|
||||
|
||||
def serialize_cond(obj: Any) -> Dict[str, Any]:
|
||||
type_key = f"{type(obj).__module__}.{type(obj).__name__}"
|
||||
return {
|
||||
"__type__": type_key,
|
||||
"cond": obj.cond,
|
||||
}
|
||||
|
||||
def deserialize_cond(data: Dict[str, Any]) -> Any:
|
||||
import importlib
|
||||
|
||||
type_key = data["__type__"]
|
||||
module_name, class_name = type_key.rsplit(".", 1)
|
||||
module = importlib.import_module(module_name)
|
||||
cls = getattr(module, class_name)
|
||||
return cls(data["cond"])
|
||||
|
||||
def _serialize_public_state(obj: Any) -> Dict[str, Any]:
|
||||
state: Dict[str, Any] = {}
|
||||
for key, value in obj.__dict__.items():
|
||||
if key.startswith("_"):
|
||||
continue
|
||||
if callable(value):
|
||||
continue
|
||||
state[key] = value
|
||||
return state
|
||||
|
||||
def serialize_latent_format(obj: Any) -> Dict[str, Any]:
|
||||
type_key = f"{type(obj).__module__}.{type(obj).__name__}"
|
||||
return {
|
||||
"__type__": type_key,
|
||||
"state": _serialize_public_state(obj),
|
||||
}
|
||||
|
||||
def deserialize_latent_format(data: Dict[str, Any]) -> Any:
|
||||
import importlib
|
||||
|
||||
type_key = data["__type__"]
|
||||
module_name, class_name = type_key.rsplit(".", 1)
|
||||
module = importlib.import_module(module_name)
|
||||
cls = getattr(module, class_name)
|
||||
obj = cls()
|
||||
for key, value in data.get("state", {}).items():
|
||||
prop = getattr(type(obj), key, None)
|
||||
if isinstance(prop, property) and prop.fset is None:
|
||||
continue
|
||||
setattr(obj, key, value)
|
||||
return obj
|
||||
|
||||
import comfy.conds
|
||||
|
||||
for cond_cls in vars(comfy.conds).values():
|
||||
if not isinstance(cond_cls, type):
|
||||
continue
|
||||
if not issubclass(cond_cls, comfy.conds.CONDRegular):
|
||||
continue
|
||||
type_key = f"{cond_cls.__module__}.{cond_cls.__name__}"
|
||||
registry.register(type_key, serialize_cond, deserialize_cond)
|
||||
registry.register(cond_cls.__name__, serialize_cond, deserialize_cond)
|
||||
|
||||
import comfy.latent_formats
|
||||
|
||||
for latent_cls in vars(comfy.latent_formats).values():
|
||||
if not isinstance(latent_cls, type):
|
||||
continue
|
||||
if not issubclass(latent_cls, comfy.latent_formats.LatentFormat):
|
||||
continue
|
||||
type_key = f"{latent_cls.__module__}.{latent_cls.__name__}"
|
||||
registry.register(
|
||||
type_key, serialize_latent_format, deserialize_latent_format
|
||||
)
|
||||
registry.register(
|
||||
latent_cls.__name__, serialize_latent_format, deserialize_latent_format
|
||||
)
|
||||
|
||||
# V3 API: unwrap NodeOutput.args
|
||||
def deserialize_node_output(data: Any) -> Any:
|
||||
return getattr(data, "args", data)
|
||||
|
||||
registry.register("NodeOutput", None, deserialize_node_output)
|
||||
|
||||
# KSAMPLER serializer: stores sampler name instead of function object
|
||||
# sampler_function is a callable which gets filtered out by JSONSocketTransport
|
||||
def serialize_ksampler(obj: Any) -> Dict[str, Any]:
|
||||
func_name = obj.sampler_function.__name__
|
||||
# Map function name back to sampler name
|
||||
if func_name == "sample_unipc":
|
||||
sampler_name = "uni_pc"
|
||||
elif func_name == "sample_unipc_bh2":
|
||||
sampler_name = "uni_pc_bh2"
|
||||
elif func_name == "dpm_fast_function":
|
||||
sampler_name = "dpm_fast"
|
||||
elif func_name == "dpm_adaptive_function":
|
||||
sampler_name = "dpm_adaptive"
|
||||
elif func_name.startswith("sample_"):
|
||||
sampler_name = func_name[7:] # Remove "sample_" prefix
|
||||
else:
|
||||
sampler_name = func_name
|
||||
return {
|
||||
"__type__": "KSAMPLER",
|
||||
"sampler_name": sampler_name,
|
||||
"extra_options": obj.extra_options,
|
||||
"inpaint_options": obj.inpaint_options,
|
||||
}
|
||||
|
||||
def deserialize_ksampler(data: Dict[str, Any]) -> Any:
|
||||
import comfy.samplers
|
||||
|
||||
return comfy.samplers.ksampler(
|
||||
data["sampler_name"],
|
||||
data.get("extra_options", {}),
|
||||
data.get("inpaint_options", {}),
|
||||
)
|
||||
|
||||
registry.register("KSAMPLER", serialize_ksampler, deserialize_ksampler)
|
||||
|
||||
from comfy.isolation.model_patcher_proxy_utils import register_hooks_serializers
|
||||
|
||||
register_hooks_serializers(registry)
|
||||
|
||||
# Generic Numpy Serializer
|
||||
def serialize_numpy(obj: Any) -> Any:
|
||||
import torch
|
||||
|
||||
try:
|
||||
# Attempt zero-copy conversion to Tensor
|
||||
return torch.from_numpy(obj)
|
||||
except Exception:
|
||||
# Fallback for non-numeric arrays (strings, objects, mixes)
|
||||
return obj.tolist()
|
||||
|
||||
registry.register("ndarray", serialize_numpy, None)
|
||||
|
||||
def provide_rpc_services(self) -> List[type[ProxiedSingleton]]:
|
||||
return [
|
||||
PromptServerService,
|
||||
FolderPathsProxy,
|
||||
ModelManagementProxy,
|
||||
UtilsProxy,
|
||||
ProgressProxy,
|
||||
VAERegistry,
|
||||
CLIPRegistry,
|
||||
ModelPatcherRegistry,
|
||||
ModelSamplingRegistry,
|
||||
FirstStageModelRegistry,
|
||||
]
|
||||
|
||||
def handle_api_registration(self, api: ProxiedSingleton, rpc: AsyncRPC) -> None:
|
||||
# Resolve the real name whether it's an instance or the Singleton class itself
|
||||
api_name = api.__name__ if isinstance(api, type) else api.__class__.__name__
|
||||
|
||||
if api_name == "FolderPathsProxy":
|
||||
import folder_paths
|
||||
|
||||
# Replace module-level functions with proxy methods
|
||||
# This is aggressive but necessary for transparent proxying
|
||||
# Handle both instance and class cases
|
||||
instance = api() if isinstance(api, type) else api
|
||||
for name in dir(instance):
|
||||
if not name.startswith("_"):
|
||||
setattr(folder_paths, name, getattr(instance, name))
|
||||
return
|
||||
|
||||
if api_name == "ModelManagementProxy":
|
||||
import comfy.model_management
|
||||
|
||||
instance = api() if isinstance(api, type) else api
|
||||
# Replace module-level functions with proxy methods
|
||||
for name in dir(instance):
|
||||
if not name.startswith("_"):
|
||||
setattr(comfy.model_management, name, getattr(instance, name))
|
||||
return
|
||||
|
||||
if api_name == "UtilsProxy":
|
||||
import comfy.utils
|
||||
|
||||
# Static Injection of RPC mechanism to ensure Child can access it
|
||||
# independent of instance lifecycle.
|
||||
api.set_rpc(rpc)
|
||||
|
||||
# Don't overwrite host hook (infinite recursion)
|
||||
return
|
||||
|
||||
if api_name == "PromptServerProxy":
|
||||
# Defer heavy import to child context
|
||||
import server
|
||||
|
||||
instance = api() if isinstance(api, type) else api
|
||||
proxy = (
|
||||
instance.instance
|
||||
) # PromptServerProxy instance has .instance property returning self
|
||||
|
||||
original_register_route = proxy.register_route
|
||||
|
||||
def register_route_wrapper(
|
||||
method: str, path: str, handler: Callable[..., Any]
|
||||
) -> None:
|
||||
callback_id = rpc.register_callback(handler)
|
||||
loop = getattr(rpc, "loop", None)
|
||||
if loop and loop.is_running():
|
||||
import asyncio
|
||||
|
||||
asyncio.create_task(
|
||||
original_register_route(
|
||||
method, path, handler=callback_id, is_callback=True
|
||||
)
|
||||
)
|
||||
else:
|
||||
original_register_route(
|
||||
method, path, handler=callback_id, is_callback=True
|
||||
)
|
||||
return None
|
||||
|
||||
proxy.register_route = register_route_wrapper
|
||||
|
||||
class RouteTableDefProxy:
|
||||
def __init__(self, proxy_instance: Any):
|
||||
self.proxy = proxy_instance
|
||||
|
||||
def get(
|
||||
self, path: str, **kwargs: Any
|
||||
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
|
||||
def decorator(handler: Callable[..., Any]) -> Callable[..., Any]:
|
||||
self.proxy.register_route("GET", path, handler)
|
||||
return handler
|
||||
|
||||
return decorator
|
||||
|
||||
def post(
|
||||
self, path: str, **kwargs: Any
|
||||
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
|
||||
def decorator(handler: Callable[..., Any]) -> Callable[..., Any]:
|
||||
self.proxy.register_route("POST", path, handler)
|
||||
return handler
|
||||
|
||||
return decorator
|
||||
|
||||
def patch(
|
||||
self, path: str, **kwargs: Any
|
||||
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
|
||||
def decorator(handler: Callable[..., Any]) -> Callable[..., Any]:
|
||||
self.proxy.register_route("PATCH", path, handler)
|
||||
return handler
|
||||
|
||||
return decorator
|
||||
|
||||
def put(
|
||||
self, path: str, **kwargs: Any
|
||||
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
|
||||
def decorator(handler: Callable[..., Any]) -> Callable[..., Any]:
|
||||
self.proxy.register_route("PUT", path, handler)
|
||||
return handler
|
||||
|
||||
return decorator
|
||||
|
||||
def delete(
|
||||
self, path: str, **kwargs: Any
|
||||
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
|
||||
def decorator(handler: Callable[..., Any]) -> Callable[..., Any]:
|
||||
self.proxy.register_route("DELETE", path, handler)
|
||||
return handler
|
||||
|
||||
return decorator
|
||||
|
||||
proxy.routes = RouteTableDefProxy(proxy)
|
||||
|
||||
if (
|
||||
hasattr(server, "PromptServer")
|
||||
and getattr(server.PromptServer, "instance", None) != proxy
|
||||
):
|
||||
server.PromptServer.instance = proxy
|
||||
141
comfy/isolation/child_hooks.py
Normal file
141
comfy/isolation/child_hooks.py
Normal file
@@ -0,0 +1,141 @@
|
||||
# pylint: disable=import-outside-toplevel,logging-fstring-interpolation
|
||||
# Child process initialization for PyIsolate
|
||||
import logging
|
||||
import os
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def is_child_process() -> bool:
|
||||
return os.environ.get("PYISOLATE_CHILD") == "1"
|
||||
|
||||
|
||||
def initialize_child_process() -> None:
|
||||
# Manual RPC injection
|
||||
try:
|
||||
from pyisolate._internal.rpc_protocol import get_child_rpc_instance
|
||||
|
||||
rpc = get_child_rpc_instance()
|
||||
if rpc:
|
||||
_setup_prompt_server_stub(rpc)
|
||||
_setup_utils_proxy(rpc)
|
||||
else:
|
||||
logger.warning("Could not get child RPC instance for manual injection")
|
||||
_setup_prompt_server_stub()
|
||||
_setup_utils_proxy()
|
||||
except Exception as e:
|
||||
logger.error(f"Manual RPC Injection failed: {e}")
|
||||
_setup_prompt_server_stub()
|
||||
_setup_utils_proxy()
|
||||
|
||||
_setup_logging()
|
||||
|
||||
|
||||
def _setup_prompt_server_stub(rpc=None) -> None:
|
||||
try:
|
||||
from .proxies.prompt_server_impl import PromptServerStub
|
||||
import sys
|
||||
import types
|
||||
|
||||
# Mock server module
|
||||
if "server" not in sys.modules:
|
||||
mock_server = types.ModuleType("server")
|
||||
sys.modules["server"] = mock_server
|
||||
|
||||
server = sys.modules["server"]
|
||||
|
||||
if not hasattr(server, "PromptServer"):
|
||||
|
||||
class MockPromptServer:
|
||||
pass
|
||||
|
||||
server.PromptServer = MockPromptServer
|
||||
|
||||
stub = PromptServerStub()
|
||||
|
||||
if rpc:
|
||||
PromptServerStub.set_rpc(rpc)
|
||||
if hasattr(stub, "set_rpc"):
|
||||
stub.set_rpc(rpc)
|
||||
|
||||
server.PromptServer.instance = stub
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to setup PromptServerStub: {e}")
|
||||
|
||||
|
||||
def _setup_utils_proxy(rpc=None) -> None:
|
||||
try:
|
||||
import comfy.utils
|
||||
import asyncio
|
||||
|
||||
# Capture main loop during initialization (safe context)
|
||||
main_loop = None
|
||||
try:
|
||||
main_loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
try:
|
||||
main_loop = asyncio.get_event_loop()
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
try:
|
||||
from .proxies.base import set_global_loop
|
||||
|
||||
if main_loop:
|
||||
set_global_loop(main_loop)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# Sync hook wrapper for progress updates
|
||||
def sync_hook_wrapper(
|
||||
value: int, total: int, preview: None = None, node_id: None = None
|
||||
) -> None:
|
||||
if node_id is None:
|
||||
try:
|
||||
from comfy_execution.utils import get_executing_context
|
||||
|
||||
ctx = get_executing_context()
|
||||
if ctx:
|
||||
node_id = ctx.node_id
|
||||
else:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Bypass blocked event loop by direct outbox injection
|
||||
if rpc:
|
||||
try:
|
||||
# Use captured main loop if available (for threaded execution), or current loop
|
||||
loop = main_loop
|
||||
if loop is None:
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
rpc.outbox.put(
|
||||
{
|
||||
"kind": "call",
|
||||
"object_id": "UtilsProxy",
|
||||
"parent_call_id": None, # We are root here usually
|
||||
"calling_loop": loop,
|
||||
"future": loop.create_future(), # Dummy future
|
||||
"method": "progress_bar_hook",
|
||||
"args": (value, total, preview, node_id),
|
||||
"kwargs": {},
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logging.getLogger(__name__).error(f"Manual Inject Failed: {e}")
|
||||
else:
|
||||
logging.getLogger(__name__).warning(
|
||||
"No RPC instance available for progress update"
|
||||
)
|
||||
|
||||
comfy.utils.PROGRESS_BAR_HOOK = sync_hook_wrapper
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to setup UtilsProxy hook: {e}")
|
||||
|
||||
|
||||
def _setup_logging() -> None:
|
||||
logging.getLogger().setLevel(logging.INFO)
|
||||
327
comfy/isolation/clip_proxy.py
Normal file
327
comfy/isolation/clip_proxy.py
Normal file
@@ -0,0 +1,327 @@
|
||||
# pylint: disable=attribute-defined-outside-init,import-outside-toplevel,logging-fstring-interpolation
|
||||
# CLIP Proxy implementation
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from comfy.isolation.proxies.base import (
|
||||
IS_CHILD_PROCESS,
|
||||
BaseProxy,
|
||||
BaseRegistry,
|
||||
detach_if_grad,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from comfy.isolation.model_patcher_proxy import ModelPatcherProxy
|
||||
|
||||
|
||||
class CondStageModelRegistry(BaseRegistry[Any]):
|
||||
_type_prefix = "cond_stage_model"
|
||||
|
||||
async def get_property(self, instance_id: str, name: str) -> Any:
|
||||
obj = self._get_instance(instance_id)
|
||||
return getattr(obj, name)
|
||||
|
||||
|
||||
class CondStageModelProxy(BaseProxy[CondStageModelRegistry]):
|
||||
_registry_class = CondStageModelRegistry
|
||||
__module__ = "comfy.sd"
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
try:
|
||||
return self._call_rpc("get_property", name)
|
||||
except Exception as e:
|
||||
raise AttributeError(
|
||||
f"'{self.__class__.__name__}' object has no attribute '{name}'"
|
||||
) from e
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<CondStageModelProxy {self._instance_id}>"
|
||||
|
||||
|
||||
class TokenizerRegistry(BaseRegistry[Any]):
|
||||
_type_prefix = "tokenizer"
|
||||
|
||||
async def get_property(self, instance_id: str, name: str) -> Any:
|
||||
obj = self._get_instance(instance_id)
|
||||
return getattr(obj, name)
|
||||
|
||||
|
||||
class TokenizerProxy(BaseProxy[TokenizerRegistry]):
|
||||
_registry_class = TokenizerRegistry
|
||||
__module__ = "comfy.sd"
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
try:
|
||||
return self._call_rpc("get_property", name)
|
||||
except Exception as e:
|
||||
raise AttributeError(
|
||||
f"'{self.__class__.__name__}' object has no attribute '{name}'"
|
||||
) from e
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<TokenizerProxy {self._instance_id}>"
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CLIPRegistry(BaseRegistry[Any]):
|
||||
_type_prefix = "clip"
|
||||
_allowed_setters = {
|
||||
"layer_idx",
|
||||
"tokenizer_options",
|
||||
"use_clip_schedule",
|
||||
"apply_hooks_to_conds",
|
||||
}
|
||||
|
||||
async def get_ram_usage(self, instance_id: str) -> int:
|
||||
return self._get_instance(instance_id).get_ram_usage()
|
||||
|
||||
async def get_patcher_id(self, instance_id: str) -> str:
|
||||
from comfy.isolation.model_patcher_proxy import ModelPatcherRegistry
|
||||
|
||||
return ModelPatcherRegistry().register(self._get_instance(instance_id).patcher)
|
||||
|
||||
async def get_cond_stage_model_id(self, instance_id: str) -> str:
|
||||
return CondStageModelRegistry().register(
|
||||
self._get_instance(instance_id).cond_stage_model
|
||||
)
|
||||
|
||||
async def get_tokenizer_id(self, instance_id: str) -> str:
|
||||
return TokenizerRegistry().register(self._get_instance(instance_id).tokenizer)
|
||||
|
||||
async def load_model(self, instance_id: str) -> None:
|
||||
self._get_instance(instance_id).load_model()
|
||||
|
||||
async def clip_layer(self, instance_id: str, layer_idx: int) -> None:
|
||||
self._get_instance(instance_id).clip_layer(layer_idx)
|
||||
|
||||
async def set_tokenizer_option(
|
||||
self, instance_id: str, option_name: str, value: Any
|
||||
) -> None:
|
||||
self._get_instance(instance_id).set_tokenizer_option(option_name, value)
|
||||
|
||||
async def get_property(self, instance_id: str, name: str) -> Any:
|
||||
return getattr(self._get_instance(instance_id), name)
|
||||
|
||||
async def set_property(self, instance_id: str, name: str, value: Any) -> None:
|
||||
if name not in self._allowed_setters:
|
||||
raise PermissionError(f"Setting '{name}' is not allowed via RPC")
|
||||
setattr(self._get_instance(instance_id), name, value)
|
||||
|
||||
async def tokenize(
|
||||
self, instance_id: str, text: str, return_word_ids: bool = False, **kwargs: Any
|
||||
) -> Any:
|
||||
return self._get_instance(instance_id).tokenize(
|
||||
text, return_word_ids=return_word_ids, **kwargs
|
||||
)
|
||||
|
||||
async def encode(self, instance_id: str, text: str) -> Any:
|
||||
return detach_if_grad(self._get_instance(instance_id).encode(text))
|
||||
|
||||
async def encode_from_tokens(
|
||||
self,
|
||||
instance_id: str,
|
||||
tokens: Any,
|
||||
return_pooled: bool = False,
|
||||
return_dict: bool = False,
|
||||
) -> Any:
|
||||
return detach_if_grad(
|
||||
self._get_instance(instance_id).encode_from_tokens(
|
||||
tokens, return_pooled=return_pooled, return_dict=return_dict
|
||||
)
|
||||
)
|
||||
|
||||
async def encode_from_tokens_scheduled(
|
||||
self,
|
||||
instance_id: str,
|
||||
tokens: Any,
|
||||
unprojected: bool = False,
|
||||
add_dict: Optional[dict] = None,
|
||||
show_pbar: bool = True,
|
||||
) -> Any:
|
||||
add_dict = add_dict or {}
|
||||
return detach_if_grad(
|
||||
self._get_instance(instance_id).encode_from_tokens_scheduled(
|
||||
tokens, unprojected=unprojected, add_dict=add_dict, show_pbar=show_pbar
|
||||
)
|
||||
)
|
||||
|
||||
async def add_patches(
|
||||
self,
|
||||
instance_id: str,
|
||||
patches: Any,
|
||||
strength_patch: float = 1.0,
|
||||
strength_model: float = 1.0,
|
||||
) -> Any:
|
||||
return self._get_instance(instance_id).add_patches(
|
||||
patches, strength_patch=strength_patch, strength_model=strength_model
|
||||
)
|
||||
|
||||
async def get_key_patches(self, instance_id: str) -> Any:
|
||||
return self._get_instance(instance_id).get_key_patches()
|
||||
|
||||
async def load_sd(
|
||||
self, instance_id: str, sd: dict, full_model: bool = False
|
||||
) -> Any:
|
||||
return self._get_instance(instance_id).load_sd(sd, full_model=full_model)
|
||||
|
||||
async def get_sd(self, instance_id: str) -> Any:
|
||||
return self._get_instance(instance_id).get_sd()
|
||||
|
||||
async def clone(self, instance_id: str) -> str:
|
||||
return self.register(self._get_instance(instance_id).clone())
|
||||
|
||||
|
||||
class CLIPProxy(BaseProxy[CLIPRegistry]):
|
||||
_registry_class = CLIPRegistry
|
||||
__module__ = "comfy.sd"
|
||||
|
||||
def get_ram_usage(self) -> int:
|
||||
return self._call_rpc("get_ram_usage")
|
||||
|
||||
@property
|
||||
def patcher(self) -> "ModelPatcherProxy":
|
||||
from comfy.isolation.model_patcher_proxy import ModelPatcherProxy
|
||||
|
||||
if not hasattr(self, "_patcher_proxy"):
|
||||
patcher_id = self._call_rpc("get_patcher_id")
|
||||
self._patcher_proxy = ModelPatcherProxy(patcher_id, manage_lifecycle=False)
|
||||
return self._patcher_proxy
|
||||
|
||||
@patcher.setter
|
||||
def patcher(self, value: Any) -> None:
|
||||
from comfy.isolation.model_patcher_proxy import ModelPatcherProxy
|
||||
|
||||
if isinstance(value, ModelPatcherProxy):
|
||||
self._patcher_proxy = value
|
||||
else:
|
||||
logger.warning(
|
||||
f"Attempted to set CLIPProxy.patcher to non-proxy object: {value}"
|
||||
)
|
||||
|
||||
@property
|
||||
def cond_stage_model(self) -> CondStageModelProxy:
|
||||
if not hasattr(self, "_cond_stage_model_proxy"):
|
||||
csm_id = self._call_rpc("get_cond_stage_model_id")
|
||||
self._cond_stage_model_proxy = CondStageModelProxy(
|
||||
csm_id, manage_lifecycle=False
|
||||
)
|
||||
return self._cond_stage_model_proxy
|
||||
|
||||
@property
|
||||
def tokenizer(self) -> TokenizerProxy:
|
||||
if not hasattr(self, "_tokenizer_proxy"):
|
||||
tok_id = self._call_rpc("get_tokenizer_id")
|
||||
self._tokenizer_proxy = TokenizerProxy(tok_id, manage_lifecycle=False)
|
||||
return self._tokenizer_proxy
|
||||
|
||||
def load_model(self) -> ModelPatcherProxy:
|
||||
self._call_rpc("load_model")
|
||||
return self.patcher
|
||||
|
||||
@property
|
||||
def layer_idx(self) -> Optional[int]:
|
||||
return self._call_rpc("get_property", "layer_idx")
|
||||
|
||||
@layer_idx.setter
|
||||
def layer_idx(self, value: Optional[int]) -> None:
|
||||
self._call_rpc("set_property", "layer_idx", value)
|
||||
|
||||
@property
|
||||
def tokenizer_options(self) -> dict:
|
||||
return self._call_rpc("get_property", "tokenizer_options")
|
||||
|
||||
@tokenizer_options.setter
|
||||
def tokenizer_options(self, value: dict) -> None:
|
||||
self._call_rpc("set_property", "tokenizer_options", value)
|
||||
|
||||
@property
|
||||
def use_clip_schedule(self) -> bool:
|
||||
return self._call_rpc("get_property", "use_clip_schedule")
|
||||
|
||||
@use_clip_schedule.setter
|
||||
def use_clip_schedule(self, value: bool) -> None:
|
||||
self._call_rpc("set_property", "use_clip_schedule", value)
|
||||
|
||||
@property
|
||||
def apply_hooks_to_conds(self) -> Any:
|
||||
return self._call_rpc("get_property", "apply_hooks_to_conds")
|
||||
|
||||
@apply_hooks_to_conds.setter
|
||||
def apply_hooks_to_conds(self, value: Any) -> None:
|
||||
self._call_rpc("set_property", "apply_hooks_to_conds", value)
|
||||
|
||||
def clip_layer(self, layer_idx: int) -> None:
|
||||
return self._call_rpc("clip_layer", layer_idx)
|
||||
|
||||
def set_tokenizer_option(self, option_name: str, value: Any) -> None:
|
||||
return self._call_rpc("set_tokenizer_option", option_name, value)
|
||||
|
||||
def tokenize(self, text: str, return_word_ids: bool = False, **kwargs: Any) -> Any:
|
||||
return self._call_rpc(
|
||||
"tokenize", text, return_word_ids=return_word_ids, **kwargs
|
||||
)
|
||||
|
||||
def encode(self, text: str) -> Any:
|
||||
return self._call_rpc("encode", text)
|
||||
|
||||
def encode_from_tokens(
|
||||
self, tokens: Any, return_pooled: bool = False, return_dict: bool = False
|
||||
) -> Any:
|
||||
res = self._call_rpc(
|
||||
"encode_from_tokens",
|
||||
tokens,
|
||||
return_pooled=return_pooled,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
if return_pooled and isinstance(res, list) and not return_dict:
|
||||
return tuple(res)
|
||||
return res
|
||||
|
||||
def encode_from_tokens_scheduled(
|
||||
self,
|
||||
tokens: Any,
|
||||
unprojected: bool = False,
|
||||
add_dict: Optional[dict] = None,
|
||||
show_pbar: bool = True,
|
||||
) -> Any:
|
||||
add_dict = add_dict or {}
|
||||
return self._call_rpc(
|
||||
"encode_from_tokens_scheduled",
|
||||
tokens,
|
||||
unprojected=unprojected,
|
||||
add_dict=add_dict,
|
||||
show_pbar=show_pbar,
|
||||
)
|
||||
|
||||
def add_patches(
|
||||
self, patches: Any, strength_patch: float = 1.0, strength_model: float = 1.0
|
||||
) -> Any:
|
||||
return self._call_rpc(
|
||||
"add_patches",
|
||||
patches,
|
||||
strength_patch=strength_patch,
|
||||
strength_model=strength_model,
|
||||
)
|
||||
|
||||
def get_key_patches(self) -> Any:
|
||||
return self._call_rpc("get_key_patches")
|
||||
|
||||
def load_sd(self, sd: dict, full_model: bool = False) -> Any:
|
||||
return self._call_rpc("load_sd", sd, full_model=full_model)
|
||||
|
||||
def get_sd(self) -> Any:
|
||||
return self._call_rpc("get_sd")
|
||||
|
||||
def clone(self) -> CLIPProxy:
|
||||
new_id = self._call_rpc("clone")
|
||||
return CLIPProxy(new_id, self._registry, manage_lifecycle=not IS_CHILD_PROCESS)
|
||||
|
||||
|
||||
if not IS_CHILD_PROCESS:
|
||||
_CLIP_REGISTRY_SINGLETON = CLIPRegistry()
|
||||
_COND_STAGE_MODEL_REGISTRY_SINGLETON = CondStageModelRegistry()
|
||||
_TOKENIZER_REGISTRY_SINGLETON = TokenizerRegistry()
|
||||
248
comfy/isolation/extension_loader.py
Normal file
248
comfy/isolation/extension_loader.py
Normal file
@@ -0,0 +1,248 @@
|
||||
# pylint: disable=cyclic-import,import-outside-toplevel,redefined-outer-name
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import inspect
|
||||
import sys
|
||||
import types
|
||||
import platform
|
||||
from pathlib import Path
|
||||
from typing import Callable, Dict, List, Tuple
|
||||
|
||||
import pyisolate
|
||||
from pyisolate import ExtensionManager, ExtensionManagerConfig
|
||||
|
||||
from .extension_wrapper import ComfyNodeExtension
|
||||
from .manifest_loader import is_cache_valid, load_from_cache, save_to_cache
|
||||
from .host_policy import load_host_policy
|
||||
|
||||
try:
|
||||
import tomllib
|
||||
except ImportError:
|
||||
import tomli as tomllib # type: ignore[no-redef]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def _stop_extension_safe(
|
||||
extension: ComfyNodeExtension, extension_name: str
|
||||
) -> None:
|
||||
try:
|
||||
stop_result = extension.stop()
|
||||
if inspect.isawaitable(stop_result):
|
||||
await stop_result
|
||||
except Exception:
|
||||
logger.debug("][ %s stop failed", extension_name, exc_info=True)
|
||||
|
||||
|
||||
def _normalize_dependency_spec(dep: str, base_paths: list[Path]) -> str:
|
||||
req, sep, marker = dep.partition(";")
|
||||
req = req.strip()
|
||||
marker_suffix = f";{marker}" if sep else ""
|
||||
|
||||
def _resolve_local_path(local_path: str) -> Path | None:
|
||||
for base in base_paths:
|
||||
candidate = (base / local_path).resolve()
|
||||
if candidate.exists():
|
||||
return candidate
|
||||
return None
|
||||
|
||||
if req.startswith("./") or req.startswith("../"):
|
||||
resolved = _resolve_local_path(req)
|
||||
if resolved is not None:
|
||||
return f"{resolved}{marker_suffix}"
|
||||
|
||||
if req.startswith("file://"):
|
||||
raw = req[len("file://") :]
|
||||
if raw.startswith("./") or raw.startswith("../"):
|
||||
resolved = _resolve_local_path(raw)
|
||||
if resolved is not None:
|
||||
return f"file://{resolved}{marker_suffix}"
|
||||
|
||||
return dep
|
||||
|
||||
|
||||
def get_enforcement_policy() -> Dict[str, bool]:
|
||||
return {
|
||||
"force_isolated": os.environ.get("PYISOLATE_ENFORCE_ISOLATED") == "1",
|
||||
"force_sandbox": os.environ.get("PYISOLATE_ENFORCE_SANDBOX") == "1",
|
||||
}
|
||||
|
||||
|
||||
class ExtensionLoadError(RuntimeError):
|
||||
pass
|
||||
|
||||
|
||||
def register_dummy_module(extension_name: str, node_dir: Path) -> None:
|
||||
normalized_name = extension_name.replace("-", "_").replace(".", "_")
|
||||
if normalized_name not in sys.modules:
|
||||
dummy_module = types.ModuleType(normalized_name)
|
||||
dummy_module.__file__ = str(node_dir / "__init__.py")
|
||||
dummy_module.__path__ = [str(node_dir)]
|
||||
dummy_module.__package__ = normalized_name
|
||||
sys.modules[normalized_name] = dummy_module
|
||||
|
||||
|
||||
def _is_stale_node_cache(cached_data: Dict[str, Dict]) -> bool:
|
||||
for details in cached_data.values():
|
||||
if not isinstance(details, dict):
|
||||
return True
|
||||
if details.get("is_v3") and "schema_v1" not in details:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
async def load_isolated_node(
|
||||
node_dir: Path,
|
||||
manifest_path: Path,
|
||||
logger: logging.Logger,
|
||||
build_stub_class: Callable[[str, Dict[str, object], ComfyNodeExtension], type],
|
||||
venv_root: Path,
|
||||
extension_managers: List[ExtensionManager],
|
||||
) -> List[Tuple[str, str, type]]:
|
||||
try:
|
||||
with manifest_path.open("rb") as handle:
|
||||
manifest_data = tomllib.load(handle)
|
||||
except Exception as e:
|
||||
logger.warning(f"][ Failed to parse {manifest_path}: {e}")
|
||||
return []
|
||||
|
||||
# Parse [tool.comfy.isolation]
|
||||
tool_config = manifest_data.get("tool", {}).get("comfy", {}).get("isolation", {})
|
||||
can_isolate = tool_config.get("can_isolate", False)
|
||||
share_torch = tool_config.get("share_torch", False)
|
||||
|
||||
# Parse [project] dependencies
|
||||
project_config = manifest_data.get("project", {})
|
||||
dependencies = project_config.get("dependencies", [])
|
||||
if not isinstance(dependencies, list):
|
||||
dependencies = []
|
||||
|
||||
# Get extension name (default to folder name if not in project.name)
|
||||
extension_name = project_config.get("name", node_dir.name)
|
||||
|
||||
# LOGIC: Isolation Decision
|
||||
policy = get_enforcement_policy()
|
||||
isolated = can_isolate or policy["force_isolated"]
|
||||
|
||||
if not isolated:
|
||||
return []
|
||||
|
||||
logger.info(f"][ Loading isolated node: {extension_name}")
|
||||
|
||||
import folder_paths
|
||||
|
||||
base_paths = [Path(folder_paths.base_path), node_dir]
|
||||
dependencies = [
|
||||
_normalize_dependency_spec(dep, base_paths) if isinstance(dep, str) else dep
|
||||
for dep in dependencies
|
||||
]
|
||||
|
||||
manager_config = ExtensionManagerConfig(venv_root_path=str(venv_root))
|
||||
manager: ExtensionManager = pyisolate.ExtensionManager(
|
||||
ComfyNodeExtension, manager_config
|
||||
)
|
||||
extension_managers.append(manager)
|
||||
|
||||
host_policy = load_host_policy(Path(folder_paths.base_path))
|
||||
|
||||
sandbox_config = {}
|
||||
is_linux = platform.system() == "Linux"
|
||||
if is_linux and isolated:
|
||||
sandbox_config = {
|
||||
"network": host_policy["allow_network"],
|
||||
"writable_paths": host_policy["writable_paths"],
|
||||
"readonly_paths": host_policy["readonly_paths"],
|
||||
}
|
||||
share_cuda_ipc = share_torch and is_linux
|
||||
|
||||
extension_config = {
|
||||
"name": extension_name,
|
||||
"module_path": str(node_dir),
|
||||
"isolated": True,
|
||||
"dependencies": dependencies,
|
||||
"share_torch": share_torch,
|
||||
"share_cuda_ipc": share_cuda_ipc,
|
||||
"sandbox": sandbox_config,
|
||||
}
|
||||
|
||||
extension = manager.load_extension(extension_config)
|
||||
register_dummy_module(extension_name, node_dir)
|
||||
|
||||
# Try cache first (lazy spawn)
|
||||
if is_cache_valid(node_dir, manifest_path, venv_root):
|
||||
cached_data = load_from_cache(node_dir, venv_root)
|
||||
if cached_data:
|
||||
if _is_stale_node_cache(cached_data):
|
||||
logger.debug(
|
||||
"][ %s cache is stale/incompatible; rebuilding metadata",
|
||||
extension_name,
|
||||
)
|
||||
else:
|
||||
logger.debug(f"][ {extension_name} loaded from cache")
|
||||
specs: List[Tuple[str, str, type]] = []
|
||||
for node_name, details in cached_data.items():
|
||||
stub_cls = build_stub_class(node_name, details, extension)
|
||||
specs.append(
|
||||
(node_name, details.get("display_name", node_name), stub_cls)
|
||||
)
|
||||
return specs
|
||||
|
||||
# Cache miss - spawn process and get metadata
|
||||
logger.debug(f"][ {extension_name} cache miss, spawning process for metadata")
|
||||
|
||||
try:
|
||||
remote_nodes: Dict[str, str] = await extension.list_nodes()
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"][ %s metadata discovery failed, skipping isolated load: %s",
|
||||
extension_name,
|
||||
exc,
|
||||
)
|
||||
await _stop_extension_safe(extension, extension_name)
|
||||
return []
|
||||
|
||||
if not remote_nodes:
|
||||
logger.debug("][ %s exposed no isolated nodes; skipping", extension_name)
|
||||
await _stop_extension_safe(extension, extension_name)
|
||||
return []
|
||||
|
||||
specs: List[Tuple[str, str, type]] = []
|
||||
cache_data: Dict[str, Dict] = {}
|
||||
|
||||
for node_name, display_name in remote_nodes.items():
|
||||
try:
|
||||
details = await extension.get_node_details(node_name)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"][ %s failed to load metadata for %s, skipping node: %s",
|
||||
extension_name,
|
||||
node_name,
|
||||
exc,
|
||||
)
|
||||
continue
|
||||
details["display_name"] = display_name
|
||||
cache_data[node_name] = details
|
||||
stub_cls = build_stub_class(node_name, details, extension)
|
||||
specs.append((node_name, display_name, stub_cls))
|
||||
|
||||
if not specs:
|
||||
logger.warning(
|
||||
"][ %s produced no usable nodes after metadata scan; skipping",
|
||||
extension_name,
|
||||
)
|
||||
await _stop_extension_safe(extension, extension_name)
|
||||
return []
|
||||
|
||||
# Save metadata to cache for future runs
|
||||
save_to_cache(node_dir, venv_root, cache_data, manifest_path)
|
||||
logger.debug(f"][ {extension_name} metadata cached")
|
||||
|
||||
# EJECT: Kill process after getting metadata (will respawn on first execution)
|
||||
await _stop_extension_safe(extension, extension_name)
|
||||
|
||||
return specs
|
||||
|
||||
|
||||
__all__ = ["ExtensionLoadError", "register_dummy_module", "load_isolated_node"]
|
||||
673
comfy/isolation/extension_wrapper.py
Normal file
673
comfy/isolation/extension_wrapper.py
Normal file
@@ -0,0 +1,673 @@
|
||||
# pylint: disable=consider-using-from-import,cyclic-import,import-outside-toplevel,logging-fstring-interpolation,protected-access,wrong-import-position
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import torch
|
||||
|
||||
|
||||
class AttrDict(dict):
|
||||
def __getattr__(self, item):
|
||||
try:
|
||||
return self[item]
|
||||
except KeyError as e:
|
||||
raise AttributeError(item) from e
|
||||
|
||||
def copy(self):
|
||||
return AttrDict(super().copy())
|
||||
|
||||
|
||||
import importlib
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import uuid
|
||||
from dataclasses import asdict
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
from pyisolate import ExtensionBase
|
||||
|
||||
from comfy_api.internal import _ComfyNodeInternal
|
||||
|
||||
LOG_PREFIX = "]["
|
||||
V3_DISCOVERY_TIMEOUT = 30
|
||||
_PRE_EXEC_MIN_FREE_VRAM_BYTES = 2 * 1024 * 1024 * 1024
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _flush_tensor_transport_state(marker: str) -> int:
|
||||
try:
|
||||
from pyisolate import flush_tensor_keeper # type: ignore[attr-defined]
|
||||
except Exception:
|
||||
return 0
|
||||
if not callable(flush_tensor_keeper):
|
||||
return 0
|
||||
flushed = flush_tensor_keeper()
|
||||
if flushed > 0:
|
||||
logger.debug(
|
||||
"%s %s flush_tensor_keeper released=%d", LOG_PREFIX, marker, flushed
|
||||
)
|
||||
return flushed
|
||||
|
||||
|
||||
def _relieve_child_vram_pressure(marker: str) -> None:
|
||||
import comfy.model_management as model_management
|
||||
|
||||
model_management.cleanup_models_gc()
|
||||
model_management.cleanup_models()
|
||||
|
||||
device = model_management.get_torch_device()
|
||||
if not hasattr(device, "type") or device.type == "cpu":
|
||||
return
|
||||
|
||||
required = max(
|
||||
model_management.minimum_inference_memory(),
|
||||
_PRE_EXEC_MIN_FREE_VRAM_BYTES,
|
||||
)
|
||||
if model_management.get_free_memory(device) < required:
|
||||
model_management.free_memory(required, device, for_dynamic=True)
|
||||
if model_management.get_free_memory(device) < required:
|
||||
model_management.free_memory(required, device, for_dynamic=False)
|
||||
model_management.cleanup_models()
|
||||
model_management.soft_empty_cache()
|
||||
logger.debug("%s %s free_memory target=%d", LOG_PREFIX, marker, required)
|
||||
|
||||
|
||||
def _sanitize_for_transport(value):
|
||||
primitives = (str, int, float, bool, type(None))
|
||||
if isinstance(value, primitives):
|
||||
return value
|
||||
|
||||
cls_name = value.__class__.__name__
|
||||
if cls_name == "FlexibleOptionalInputType":
|
||||
return {
|
||||
"__pyisolate_flexible_optional__": True,
|
||||
"type": _sanitize_for_transport(getattr(value, "type", "*")),
|
||||
}
|
||||
if cls_name == "AnyType":
|
||||
return {"__pyisolate_any_type__": True, "value": str(value)}
|
||||
if cls_name == "ByPassTypeTuple":
|
||||
return {
|
||||
"__pyisolate_bypass_tuple__": [
|
||||
_sanitize_for_transport(v) for v in tuple(value)
|
||||
]
|
||||
}
|
||||
|
||||
if isinstance(value, dict):
|
||||
return {k: _sanitize_for_transport(v) for k, v in value.items()}
|
||||
if isinstance(value, tuple):
|
||||
return {"__pyisolate_tuple__": [_sanitize_for_transport(v) for v in value]}
|
||||
if isinstance(value, list):
|
||||
return [_sanitize_for_transport(v) for v in value]
|
||||
|
||||
return str(value)
|
||||
|
||||
|
||||
# Re-export RemoteObjectHandle from pyisolate for backward compatibility
|
||||
# The canonical definition is now in pyisolate._internal.remote_handle
|
||||
from pyisolate._internal.remote_handle import RemoteObjectHandle # noqa: E402,F401
|
||||
|
||||
|
||||
class ComfyNodeExtension(ExtensionBase):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.node_classes: Dict[str, type] = {}
|
||||
self.display_names: Dict[str, str] = {}
|
||||
self.node_instances: Dict[str, Any] = {}
|
||||
self.remote_objects: Dict[str, Any] = {}
|
||||
self._route_handlers: Dict[str, Any] = {}
|
||||
self._module: Any = None
|
||||
|
||||
async def on_module_loaded(self, module: Any) -> None:
|
||||
self._module = module
|
||||
|
||||
# Registries are initialized in host_hooks.py initialize_host_process()
|
||||
# They auto-register via ProxiedSingleton when instantiated
|
||||
# NO additional setup required here - if a registry is missing from host_hooks, it WILL fail
|
||||
|
||||
self.node_classes = getattr(module, "NODE_CLASS_MAPPINGS", {}) or {}
|
||||
self.display_names = getattr(module, "NODE_DISPLAY_NAME_MAPPINGS", {}) or {}
|
||||
|
||||
try:
|
||||
from comfy_api.latest import ComfyExtension
|
||||
|
||||
for name, obj in inspect.getmembers(module):
|
||||
if not (
|
||||
inspect.isclass(obj)
|
||||
and issubclass(obj, ComfyExtension)
|
||||
and obj is not ComfyExtension
|
||||
):
|
||||
continue
|
||||
if not obj.__module__.startswith(module.__name__):
|
||||
continue
|
||||
try:
|
||||
ext_instance = obj()
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
ext_instance.on_load(), timeout=V3_DISCOVERY_TIMEOUT
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(
|
||||
"%s V3 Extension %s timed out in on_load()",
|
||||
LOG_PREFIX,
|
||||
name,
|
||||
)
|
||||
continue
|
||||
try:
|
||||
v3_nodes = await asyncio.wait_for(
|
||||
ext_instance.get_node_list(), timeout=V3_DISCOVERY_TIMEOUT
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(
|
||||
"%s V3 Extension %s timed out in get_node_list()",
|
||||
LOG_PREFIX,
|
||||
name,
|
||||
)
|
||||
continue
|
||||
for node_cls in v3_nodes:
|
||||
if hasattr(node_cls, "GET_SCHEMA"):
|
||||
schema = node_cls.GET_SCHEMA()
|
||||
self.node_classes[schema.node_id] = node_cls
|
||||
if schema.display_name:
|
||||
self.display_names[schema.node_id] = schema.display_name
|
||||
except Exception as e:
|
||||
logger.error("%s V3 Extension %s failed: %s", LOG_PREFIX, name, e)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
module_name = getattr(module, "__name__", "isolated_nodes")
|
||||
for node_cls in self.node_classes.values():
|
||||
if hasattr(node_cls, "__module__") and "/" in str(node_cls.__module__):
|
||||
node_cls.__module__ = module_name
|
||||
|
||||
self.node_instances = {}
|
||||
|
||||
async def list_nodes(self) -> Dict[str, str]:
|
||||
return {name: self.display_names.get(name, name) for name in self.node_classes}
|
||||
|
||||
async def get_node_info(self, node_name: str) -> Dict[str, Any]:
|
||||
return await self.get_node_details(node_name)
|
||||
|
||||
async def get_node_details(self, node_name: str) -> Dict[str, Any]:
|
||||
node_cls = self._get_node_class(node_name)
|
||||
is_v3 = issubclass(node_cls, _ComfyNodeInternal)
|
||||
|
||||
input_types_raw = (
|
||||
node_cls.INPUT_TYPES() if hasattr(node_cls, "INPUT_TYPES") else {}
|
||||
)
|
||||
output_is_list = getattr(node_cls, "OUTPUT_IS_LIST", None)
|
||||
if output_is_list is not None:
|
||||
output_is_list = tuple(bool(x) for x in output_is_list)
|
||||
|
||||
details: Dict[str, Any] = {
|
||||
"input_types": _sanitize_for_transport(input_types_raw),
|
||||
"return_types": tuple(
|
||||
str(t) for t in getattr(node_cls, "RETURN_TYPES", ())
|
||||
),
|
||||
"return_names": getattr(node_cls, "RETURN_NAMES", None),
|
||||
"function": str(getattr(node_cls, "FUNCTION", "execute")),
|
||||
"category": str(getattr(node_cls, "CATEGORY", "")),
|
||||
"output_node": bool(getattr(node_cls, "OUTPUT_NODE", False)),
|
||||
"output_is_list": output_is_list,
|
||||
"is_v3": is_v3,
|
||||
}
|
||||
|
||||
if is_v3:
|
||||
try:
|
||||
schema = node_cls.GET_SCHEMA()
|
||||
schema_v1 = asdict(schema.get_v1_info(node_cls))
|
||||
try:
|
||||
schema_v3 = asdict(schema.get_v3_info(node_cls))
|
||||
except (AttributeError, TypeError):
|
||||
schema_v3 = self._build_schema_v3_fallback(schema)
|
||||
details.update(
|
||||
{
|
||||
"schema_v1": schema_v1,
|
||||
"schema_v3": schema_v3,
|
||||
"hidden": [h.value for h in (schema.hidden or [])],
|
||||
"description": getattr(schema, "description", ""),
|
||||
"deprecated": bool(getattr(node_cls, "DEPRECATED", False)),
|
||||
"experimental": bool(getattr(node_cls, "EXPERIMENTAL", False)),
|
||||
"api_node": bool(getattr(node_cls, "API_NODE", False)),
|
||||
"input_is_list": bool(
|
||||
getattr(node_cls, "INPUT_IS_LIST", False)
|
||||
),
|
||||
"not_idempotent": bool(
|
||||
getattr(node_cls, "NOT_IDEMPOTENT", False)
|
||||
),
|
||||
}
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"%s V3 schema serialization failed for %s: %s",
|
||||
LOG_PREFIX,
|
||||
node_name,
|
||||
exc,
|
||||
)
|
||||
return details
|
||||
|
||||
def _build_schema_v3_fallback(self, schema) -> Dict[str, Any]:
|
||||
input_dict: Dict[str, Any] = {}
|
||||
output_dict: Dict[str, Any] = {}
|
||||
hidden_list: List[str] = []
|
||||
|
||||
if getattr(schema, "inputs", None):
|
||||
for inp in schema.inputs:
|
||||
self._add_schema_io_v3(inp, input_dict)
|
||||
if getattr(schema, "outputs", None):
|
||||
for out in schema.outputs:
|
||||
self._add_schema_io_v3(out, output_dict)
|
||||
if getattr(schema, "hidden", None):
|
||||
for h in schema.hidden:
|
||||
hidden_list.append(getattr(h, "value", str(h)))
|
||||
|
||||
return {
|
||||
"input": input_dict,
|
||||
"output": output_dict,
|
||||
"hidden": hidden_list,
|
||||
"name": getattr(schema, "node_id", None),
|
||||
"display_name": getattr(schema, "display_name", None),
|
||||
"description": getattr(schema, "description", None),
|
||||
"category": getattr(schema, "category", None),
|
||||
"output_node": getattr(schema, "is_output_node", False),
|
||||
"deprecated": getattr(schema, "is_deprecated", False),
|
||||
"experimental": getattr(schema, "is_experimental", False),
|
||||
"api_node": getattr(schema, "is_api_node", False),
|
||||
}
|
||||
|
||||
def _add_schema_io_v3(self, io_obj: Any, target: Dict[str, Any]) -> None:
|
||||
io_id = getattr(io_obj, "id", None)
|
||||
if io_id is None:
|
||||
return
|
||||
|
||||
io_type_fn = getattr(io_obj, "get_io_type", None)
|
||||
io_type = (
|
||||
io_type_fn() if callable(io_type_fn) else getattr(io_obj, "io_type", None)
|
||||
)
|
||||
|
||||
as_dict_fn = getattr(io_obj, "as_dict", None)
|
||||
payload = as_dict_fn() if callable(as_dict_fn) else {}
|
||||
|
||||
target[str(io_id)] = (io_type, payload)
|
||||
|
||||
async def get_input_types(self, node_name: str) -> Dict[str, Any]:
|
||||
node_cls = self._get_node_class(node_name)
|
||||
if hasattr(node_cls, "INPUT_TYPES"):
|
||||
return node_cls.INPUT_TYPES()
|
||||
return {}
|
||||
|
||||
async def execute_node(self, node_name: str, **inputs: Any) -> Tuple[Any, ...]:
|
||||
logger.debug(
|
||||
"%s ISO:child_execute_start ext=%s node=%s input_keys=%d",
|
||||
LOG_PREFIX,
|
||||
getattr(self, "name", "?"),
|
||||
node_name,
|
||||
len(inputs),
|
||||
)
|
||||
if os.environ.get("PYISOLATE_ISOLATION_ACTIVE") == "1":
|
||||
_relieve_child_vram_pressure("EXT:pre_execute")
|
||||
|
||||
resolved_inputs = self._resolve_remote_objects(inputs)
|
||||
|
||||
instance = self._get_node_instance(node_name)
|
||||
node_cls = self._get_node_class(node_name)
|
||||
|
||||
# V3 API nodes expect hidden parameters in cls.hidden, not as kwargs
|
||||
# Hidden params come through RPC as string keys like "Hidden.prompt"
|
||||
from comfy_api.latest._io import Hidden, HiddenHolder
|
||||
|
||||
# Map string representations back to Hidden enum keys
|
||||
hidden_string_map = {
|
||||
"Hidden.unique_id": Hidden.unique_id,
|
||||
"Hidden.prompt": Hidden.prompt,
|
||||
"Hidden.extra_pnginfo": Hidden.extra_pnginfo,
|
||||
"Hidden.dynprompt": Hidden.dynprompt,
|
||||
"Hidden.auth_token_comfy_org": Hidden.auth_token_comfy_org,
|
||||
"Hidden.api_key_comfy_org": Hidden.api_key_comfy_org,
|
||||
}
|
||||
|
||||
# Find and extract hidden parameters (both enum and string form)
|
||||
hidden_found = {}
|
||||
keys_to_remove = []
|
||||
|
||||
for key in list(resolved_inputs.keys()):
|
||||
# Check string form first (from RPC serialization)
|
||||
if key in hidden_string_map:
|
||||
hidden_found[hidden_string_map[key]] = resolved_inputs[key]
|
||||
keys_to_remove.append(key)
|
||||
# Also check enum form (direct calls)
|
||||
elif isinstance(key, Hidden):
|
||||
hidden_found[key] = resolved_inputs[key]
|
||||
keys_to_remove.append(key)
|
||||
|
||||
# Remove hidden params from kwargs
|
||||
for key in keys_to_remove:
|
||||
resolved_inputs.pop(key)
|
||||
|
||||
# Set hidden on node class if any hidden params found
|
||||
if hidden_found:
|
||||
if not hasattr(node_cls, "hidden") or node_cls.hidden is None:
|
||||
node_cls.hidden = HiddenHolder.from_dict(hidden_found)
|
||||
else:
|
||||
# Update existing hidden holder
|
||||
for key, value in hidden_found.items():
|
||||
setattr(node_cls.hidden, key.value.lower(), value)
|
||||
|
||||
function_name = getattr(node_cls, "FUNCTION", "execute")
|
||||
if not hasattr(instance, function_name):
|
||||
raise AttributeError(f"Node {node_name} missing callable '{function_name}'")
|
||||
|
||||
handler = getattr(instance, function_name)
|
||||
|
||||
try:
|
||||
if asyncio.iscoroutinefunction(handler):
|
||||
result = await handler(**resolved_inputs)
|
||||
else:
|
||||
import functools
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
result = await loop.run_in_executor(
|
||||
None, functools.partial(handler, **resolved_inputs)
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"%s ISO:child_execute_error ext=%s node=%s",
|
||||
LOG_PREFIX,
|
||||
getattr(self, "name", "?"),
|
||||
node_name,
|
||||
)
|
||||
raise
|
||||
|
||||
if type(result).__name__ == "NodeOutput":
|
||||
result = result.args
|
||||
if self._is_comfy_protocol_return(result):
|
||||
logger.debug(
|
||||
"%s ISO:child_execute_done ext=%s node=%s protocol_return=1",
|
||||
LOG_PREFIX,
|
||||
getattr(self, "name", "?"),
|
||||
node_name,
|
||||
)
|
||||
return self._wrap_unpicklable_objects(result)
|
||||
|
||||
if not isinstance(result, tuple):
|
||||
result = (result,)
|
||||
logger.debug(
|
||||
"%s ISO:child_execute_done ext=%s node=%s protocol_return=0 outputs=%d",
|
||||
LOG_PREFIX,
|
||||
getattr(self, "name", "?"),
|
||||
node_name,
|
||||
len(result),
|
||||
)
|
||||
return self._wrap_unpicklable_objects(result)
|
||||
|
||||
async def flush_transport_state(self) -> int:
|
||||
if os.environ.get("PYISOLATE_ISOLATION_ACTIVE") != "1":
|
||||
return 0
|
||||
logger.debug(
|
||||
"%s ISO:child_flush_start ext=%s", LOG_PREFIX, getattr(self, "name", "?")
|
||||
)
|
||||
flushed = _flush_tensor_transport_state("EXT:workflow_end")
|
||||
try:
|
||||
from comfy.isolation.model_patcher_proxy_registry import (
|
||||
ModelPatcherRegistry,
|
||||
)
|
||||
|
||||
registry = ModelPatcherRegistry()
|
||||
removed = registry.sweep_pending_cleanup()
|
||||
if removed > 0:
|
||||
logger.debug(
|
||||
"%s EXT:workflow_end registry sweep removed=%d", LOG_PREFIX, removed
|
||||
)
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"%s EXT:workflow_end registry sweep failed", LOG_PREFIX, exc_info=True
|
||||
)
|
||||
logger.debug(
|
||||
"%s ISO:child_flush_done ext=%s flushed=%d",
|
||||
LOG_PREFIX,
|
||||
getattr(self, "name", "?"),
|
||||
flushed,
|
||||
)
|
||||
return flushed
|
||||
|
||||
async def get_remote_object(self, object_id: str) -> Any:
|
||||
"""Retrieve a remote object by ID for host-side deserialization."""
|
||||
if object_id not in self.remote_objects:
|
||||
raise KeyError(f"Remote object {object_id} not found")
|
||||
|
||||
return self.remote_objects[object_id]
|
||||
|
||||
def _wrap_unpicklable_objects(self, data: Any) -> Any:
|
||||
if isinstance(data, (str, int, float, bool, type(None))):
|
||||
return data
|
||||
if isinstance(data, torch.Tensor):
|
||||
return data.detach() if data.requires_grad else data
|
||||
|
||||
# Special-case clip vision outputs: preserve attribute access by packing fields
|
||||
if hasattr(data, "penultimate_hidden_states") or hasattr(
|
||||
data, "last_hidden_state"
|
||||
):
|
||||
fields = {}
|
||||
for attr in (
|
||||
"penultimate_hidden_states",
|
||||
"last_hidden_state",
|
||||
"image_embeds",
|
||||
"text_embeds",
|
||||
):
|
||||
if hasattr(data, attr):
|
||||
try:
|
||||
fields[attr] = self._wrap_unpicklable_objects(
|
||||
getattr(data, attr)
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
if fields:
|
||||
return {"__pyisolate_attribute_container__": True, "data": fields}
|
||||
|
||||
# Avoid converting arbitrary objects with stateful methods (models, etc.)
|
||||
# They will be handled via RemoteObjectHandle below.
|
||||
|
||||
type_name = type(data).__name__
|
||||
if type_name == "ModelPatcherProxy":
|
||||
return {"__type__": "ModelPatcherRef", "model_id": data._instance_id}
|
||||
if type_name == "CLIPProxy":
|
||||
return {"__type__": "CLIPRef", "clip_id": data._instance_id}
|
||||
if type_name == "VAEProxy":
|
||||
return {"__type__": "VAERef", "vae_id": data._instance_id}
|
||||
if type_name == "ModelSamplingProxy":
|
||||
return {"__type__": "ModelSamplingRef", "ms_id": data._instance_id}
|
||||
|
||||
if isinstance(data, (list, tuple)):
|
||||
wrapped = [self._wrap_unpicklable_objects(item) for item in data]
|
||||
return tuple(wrapped) if isinstance(data, tuple) else wrapped
|
||||
if isinstance(data, dict):
|
||||
converted_dict = {
|
||||
k: self._wrap_unpicklable_objects(v) for k, v in data.items()
|
||||
}
|
||||
return {"__pyisolate_attrdict__": True, "data": converted_dict}
|
||||
|
||||
object_id = str(uuid.uuid4())
|
||||
self.remote_objects[object_id] = data
|
||||
return RemoteObjectHandle(object_id, type(data).__name__)
|
||||
|
||||
def _resolve_remote_objects(self, data: Any) -> Any:
|
||||
if isinstance(data, RemoteObjectHandle):
|
||||
if data.object_id not in self.remote_objects:
|
||||
raise KeyError(f"Remote object {data.object_id} not found")
|
||||
return self.remote_objects[data.object_id]
|
||||
|
||||
if isinstance(data, dict):
|
||||
ref_type = data.get("__type__")
|
||||
if ref_type in ("CLIPRef", "ModelPatcherRef", "VAERef"):
|
||||
from pyisolate._internal.model_serialization import (
|
||||
deserialize_proxy_result,
|
||||
)
|
||||
|
||||
return deserialize_proxy_result(data)
|
||||
if ref_type == "ModelSamplingRef":
|
||||
from pyisolate._internal.model_serialization import (
|
||||
deserialize_proxy_result,
|
||||
)
|
||||
|
||||
return deserialize_proxy_result(data)
|
||||
return {k: self._resolve_remote_objects(v) for k, v in data.items()}
|
||||
|
||||
if isinstance(data, (list, tuple)):
|
||||
resolved = [self._resolve_remote_objects(item) for item in data]
|
||||
return tuple(resolved) if isinstance(data, tuple) else resolved
|
||||
return data
|
||||
|
||||
def _get_node_class(self, node_name: str) -> type:
|
||||
if node_name not in self.node_classes:
|
||||
raise KeyError(f"Unknown node: {node_name}")
|
||||
return self.node_classes[node_name]
|
||||
|
||||
def _get_node_instance(self, node_name: str) -> Any:
|
||||
if node_name not in self.node_instances:
|
||||
if node_name not in self.node_classes:
|
||||
raise KeyError(f"Unknown node: {node_name}")
|
||||
self.node_instances[node_name] = self.node_classes[node_name]()
|
||||
return self.node_instances[node_name]
|
||||
|
||||
async def before_module_loaded(self) -> None:
|
||||
# Inject initialization here if we think this is the child
|
||||
try:
|
||||
from comfy.isolation import initialize_proxies
|
||||
|
||||
initialize_proxies()
|
||||
except Exception as e:
|
||||
logging.getLogger(__name__).error(
|
||||
f"Failed to call initialize_proxies in before_module_loaded: {e}"
|
||||
)
|
||||
|
||||
await super().before_module_loaded()
|
||||
try:
|
||||
from comfy_api.latest import ComfyAPI_latest
|
||||
from .proxies.progress_proxy import ProgressProxy
|
||||
|
||||
ComfyAPI_latest.Execution = ProgressProxy
|
||||
# ComfyAPI_latest.execution = ProgressProxy() # Eliminated to avoid Singleton collision
|
||||
# fp_proxy = FolderPathsProxy() # Eliminated to avoid Singleton collision
|
||||
# latest_ui.folder_paths = fp_proxy
|
||||
# latest_resources.folder_paths = fp_proxy
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def call_route_handler(
|
||||
self,
|
||||
handler_module: str,
|
||||
handler_func: str,
|
||||
request_data: Dict[str, Any],
|
||||
) -> Any:
|
||||
cache_key = f"{handler_module}.{handler_func}"
|
||||
if cache_key not in self._route_handlers:
|
||||
if self._module is not None and hasattr(self._module, "__file__"):
|
||||
node_dir = os.path.dirname(self._module.__file__)
|
||||
if node_dir not in sys.path:
|
||||
sys.path.insert(0, node_dir)
|
||||
try:
|
||||
module = importlib.import_module(handler_module)
|
||||
self._route_handlers[cache_key] = getattr(module, handler_func)
|
||||
except (ImportError, AttributeError) as e:
|
||||
raise ValueError(f"Route handler not found: {cache_key}") from e
|
||||
|
||||
handler = self._route_handlers[cache_key]
|
||||
mock_request = MockRequest(request_data)
|
||||
|
||||
if asyncio.iscoroutinefunction(handler):
|
||||
result = await handler(mock_request)
|
||||
else:
|
||||
result = handler(mock_request)
|
||||
return self._serialize_response(result)
|
||||
|
||||
def _is_comfy_protocol_return(self, result: Any) -> bool:
|
||||
"""
|
||||
Check if the result matches the ComfyUI 'Protocol Return' schema.
|
||||
|
||||
A Protocol Return is a dictionary containing specific reserved keys that
|
||||
ComfyUI's execution engine interprets as instructions (UI updates,
|
||||
Workflow expansion, etc.) rather than purely data outputs.
|
||||
|
||||
Schema:
|
||||
- Must be a dict
|
||||
- Must contain at least one of: 'ui', 'result', 'expand'
|
||||
"""
|
||||
if not isinstance(result, dict):
|
||||
return False
|
||||
return any(key in result for key in ("ui", "result", "expand"))
|
||||
|
||||
def _serialize_response(self, response: Any) -> Dict[str, Any]:
|
||||
if response is None:
|
||||
return {"type": "text", "body": "", "status": 204}
|
||||
if isinstance(response, dict):
|
||||
return {"type": "json", "body": response, "status": 200}
|
||||
if isinstance(response, str):
|
||||
return {"type": "text", "body": response, "status": 200}
|
||||
if hasattr(response, "text") and hasattr(response, "status"):
|
||||
return {
|
||||
"type": "text",
|
||||
"body": response.text
|
||||
if hasattr(response, "text")
|
||||
else str(response.body),
|
||||
"status": response.status,
|
||||
"headers": dict(response.headers)
|
||||
if hasattr(response, "headers")
|
||||
else {},
|
||||
}
|
||||
if hasattr(response, "body") and hasattr(response, "status"):
|
||||
body = response.body
|
||||
if isinstance(body, bytes):
|
||||
try:
|
||||
return {
|
||||
"type": "text",
|
||||
"body": body.decode("utf-8"),
|
||||
"status": response.status,
|
||||
}
|
||||
except UnicodeDecodeError:
|
||||
return {
|
||||
"type": "binary",
|
||||
"body": body.hex(),
|
||||
"status": response.status,
|
||||
}
|
||||
return {"type": "json", "body": body, "status": response.status}
|
||||
return {"type": "text", "body": str(response), "status": 200}
|
||||
|
||||
|
||||
class MockRequest:
|
||||
def __init__(self, data: Dict[str, Any]):
|
||||
self.method = data.get("method", "GET")
|
||||
self.path = data.get("path", "/")
|
||||
self.query = data.get("query", {})
|
||||
self._body = data.get("body", {})
|
||||
self._text = data.get("text", "")
|
||||
self.headers = data.get("headers", {})
|
||||
self.content_type = data.get(
|
||||
"content_type", self.headers.get("Content-Type", "application/json")
|
||||
)
|
||||
self.match_info = data.get("match_info", {})
|
||||
|
||||
async def json(self) -> Any:
|
||||
if isinstance(self._body, dict):
|
||||
return self._body
|
||||
if isinstance(self._body, str):
|
||||
return json.loads(self._body)
|
||||
return {}
|
||||
|
||||
async def post(self) -> Dict[str, Any]:
|
||||
if isinstance(self._body, dict):
|
||||
return self._body
|
||||
return {}
|
||||
|
||||
async def text(self) -> str:
|
||||
if self._text:
|
||||
return self._text
|
||||
if isinstance(self._body, str):
|
||||
return self._body
|
||||
if isinstance(self._body, dict):
|
||||
return json.dumps(self._body)
|
||||
return ""
|
||||
|
||||
async def read(self) -> bytes:
|
||||
return (await self.text()).encode("utf-8")
|
||||
26
comfy/isolation/host_hooks.py
Normal file
26
comfy/isolation/host_hooks.py
Normal file
@@ -0,0 +1,26 @@
|
||||
# pylint: disable=import-outside-toplevel
|
||||
# Host process initialization for PyIsolate
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def initialize_host_process() -> None:
|
||||
root = logging.getLogger()
|
||||
for handler in root.handlers[:]:
|
||||
root.removeHandler(handler)
|
||||
root.addHandler(logging.NullHandler())
|
||||
|
||||
from .proxies.folder_paths_proxy import FolderPathsProxy
|
||||
from .proxies.model_management_proxy import ModelManagementProxy
|
||||
from .proxies.progress_proxy import ProgressProxy
|
||||
from .proxies.prompt_server_impl import PromptServerService
|
||||
from .proxies.utils_proxy import UtilsProxy
|
||||
from .vae_proxy import VAERegistry
|
||||
|
||||
FolderPathsProxy()
|
||||
ModelManagementProxy()
|
||||
ProgressProxy()
|
||||
PromptServerService()
|
||||
UtilsProxy()
|
||||
VAERegistry()
|
||||
83
comfy/isolation/host_policy.py
Normal file
83
comfy/isolation/host_policy.py
Normal file
@@ -0,0 +1,83 @@
|
||||
# pylint: disable=logging-fstring-interpolation
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, TypedDict
|
||||
|
||||
try:
|
||||
import tomllib
|
||||
except ImportError:
|
||||
import tomli as tomllib # type: ignore[no-redef]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HostSecurityPolicy(TypedDict):
|
||||
allow_network: bool
|
||||
writable_paths: List[str]
|
||||
readonly_paths: List[str]
|
||||
whitelist: Dict[str, str]
|
||||
|
||||
|
||||
DEFAULT_POLICY: HostSecurityPolicy = {
|
||||
"allow_network": False,
|
||||
"writable_paths": ["/dev/shm", "/tmp"],
|
||||
"readonly_paths": [],
|
||||
"whitelist": {},
|
||||
}
|
||||
|
||||
|
||||
def _default_policy() -> HostSecurityPolicy:
|
||||
return {
|
||||
"allow_network": DEFAULT_POLICY["allow_network"],
|
||||
"writable_paths": list(DEFAULT_POLICY["writable_paths"]),
|
||||
"readonly_paths": list(DEFAULT_POLICY["readonly_paths"]),
|
||||
"whitelist": dict(DEFAULT_POLICY["whitelist"]),
|
||||
}
|
||||
|
||||
|
||||
def load_host_policy(comfy_root: Path) -> HostSecurityPolicy:
|
||||
config_path = comfy_root / "pyproject.toml"
|
||||
policy = _default_policy()
|
||||
|
||||
if not config_path.exists():
|
||||
logger.debug("Host policy file missing at %s, using defaults.", config_path)
|
||||
return policy
|
||||
|
||||
try:
|
||||
with config_path.open("rb") as f:
|
||||
data = tomllib.load(f)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to parse host policy from %s, using defaults.",
|
||||
config_path,
|
||||
exc_info=True,
|
||||
)
|
||||
return policy
|
||||
|
||||
tool_config = data.get("tool", {}).get("comfy", {}).get("host", {})
|
||||
if not isinstance(tool_config, dict):
|
||||
logger.debug("No [tool.comfy.host] section found, using defaults.")
|
||||
return policy
|
||||
|
||||
if "allow_network" in tool_config:
|
||||
policy["allow_network"] = bool(tool_config["allow_network"])
|
||||
|
||||
if "writable_paths" in tool_config:
|
||||
policy["writable_paths"] = [str(p) for p in tool_config["writable_paths"]]
|
||||
|
||||
if "readonly_paths" in tool_config:
|
||||
policy["readonly_paths"] = [str(p) for p in tool_config["readonly_paths"]]
|
||||
|
||||
whitelist_raw = tool_config.get("whitelist")
|
||||
if isinstance(whitelist_raw, dict):
|
||||
policy["whitelist"] = {str(k): str(v) for k, v in whitelist_raw.items()}
|
||||
|
||||
logger.debug(
|
||||
f"Loaded Host Policy: {len(policy['whitelist'])} whitelisted nodes, Network={policy['allow_network']}"
|
||||
)
|
||||
return policy
|
||||
|
||||
|
||||
__all__ = ["HostSecurityPolicy", "load_host_policy", "DEFAULT_POLICY"]
|
||||
186
comfy/isolation/manifest_loader.py
Normal file
186
comfy/isolation/manifest_loader.py
Normal file
@@ -0,0 +1,186 @@
|
||||
# pylint: disable=import-outside-toplevel
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import folder_paths
|
||||
|
||||
try:
|
||||
import tomllib
|
||||
except ImportError:
|
||||
import tomli as tomllib # type: ignore[no-redef]
|
||||
|
||||
LOG_PREFIX = "]["
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
CACHE_SUBDIR = "cache"
|
||||
CACHE_KEY_FILE = "cache_key"
|
||||
CACHE_DATA_FILE = "node_info.json"
|
||||
CACHE_KEY_LENGTH = 16
|
||||
|
||||
|
||||
def find_manifest_directories() -> List[Tuple[Path, Path]]:
|
||||
"""Find custom node directories containing a valid pyproject.toml with [tool.comfy.isolation]."""
|
||||
manifest_dirs: List[Tuple[Path, Path]] = []
|
||||
|
||||
# Standard custom_nodes paths
|
||||
for base_path in folder_paths.get_folder_paths("custom_nodes"):
|
||||
base = Path(base_path)
|
||||
if not base.exists() or not base.is_dir():
|
||||
continue
|
||||
|
||||
for entry in base.iterdir():
|
||||
if not entry.is_dir():
|
||||
continue
|
||||
|
||||
# Look for pyproject.toml
|
||||
manifest = entry / "pyproject.toml"
|
||||
if not manifest.exists():
|
||||
continue
|
||||
|
||||
# Validate [tool.comfy.isolation] section existence
|
||||
try:
|
||||
with manifest.open("rb") as f:
|
||||
data = tomllib.load(f)
|
||||
|
||||
if (
|
||||
"tool" in data
|
||||
and "comfy" in data["tool"]
|
||||
and "isolation" in data["tool"]["comfy"]
|
||||
):
|
||||
manifest_dirs.append((entry, manifest))
|
||||
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
return manifest_dirs
|
||||
|
||||
|
||||
def compute_cache_key(node_dir: Path, manifest_path: Path) -> str:
|
||||
"""Hash manifest + .py mtimes + Python version + PyIsolate version."""
|
||||
hasher = hashlib.sha256()
|
||||
|
||||
try:
|
||||
# Hashing the manifest content ensures config changes invalidate cache
|
||||
hasher.update(manifest_path.read_bytes())
|
||||
except OSError:
|
||||
hasher.update(b"__manifest_read_error__")
|
||||
|
||||
try:
|
||||
py_files = sorted(node_dir.rglob("*.py"))
|
||||
for py_file in py_files:
|
||||
rel_path = py_file.relative_to(node_dir)
|
||||
if "__pycache__" in str(rel_path) or ".venv" in str(rel_path):
|
||||
continue
|
||||
hasher.update(str(rel_path).encode("utf-8"))
|
||||
try:
|
||||
hasher.update(str(py_file.stat().st_mtime).encode("utf-8"))
|
||||
except OSError:
|
||||
hasher.update(b"__file_stat_error__")
|
||||
except OSError:
|
||||
hasher.update(b"__dir_scan_error__")
|
||||
|
||||
hasher.update(sys.version.encode("utf-8"))
|
||||
|
||||
try:
|
||||
import pyisolate
|
||||
|
||||
hasher.update(pyisolate.__version__.encode("utf-8"))
|
||||
except (ImportError, AttributeError):
|
||||
hasher.update(b"__pyisolate_unknown__")
|
||||
|
||||
return hasher.hexdigest()[:CACHE_KEY_LENGTH]
|
||||
|
||||
|
||||
def get_cache_path(node_dir: Path, venv_root: Path) -> Tuple[Path, Path]:
|
||||
"""Return (cache_key_file, cache_data_file) in venv_root/{node}/cache/."""
|
||||
cache_dir = venv_root / node_dir.name / CACHE_SUBDIR
|
||||
return (cache_dir / CACHE_KEY_FILE, cache_dir / CACHE_DATA_FILE)
|
||||
|
||||
|
||||
def is_cache_valid(node_dir: Path, manifest_path: Path, venv_root: Path) -> bool:
|
||||
"""Return True only if stored cache key matches current computed key."""
|
||||
try:
|
||||
cache_key_file, cache_data_file = get_cache_path(node_dir, venv_root)
|
||||
if not cache_key_file.exists() or not cache_data_file.exists():
|
||||
return False
|
||||
current_key = compute_cache_key(node_dir, manifest_path)
|
||||
stored_key = cache_key_file.read_text(encoding="utf-8").strip()
|
||||
return current_key == stored_key
|
||||
except Exception as e:
|
||||
logger.debug(
|
||||
"%s Cache validation error for %s: %s", LOG_PREFIX, node_dir.name, e
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
def load_from_cache(node_dir: Path, venv_root: Path) -> Optional[Dict[str, Any]]:
|
||||
"""Load node metadata from cache, return None on any error."""
|
||||
try:
|
||||
_, cache_data_file = get_cache_path(node_dir, venv_root)
|
||||
if not cache_data_file.exists():
|
||||
return None
|
||||
data = json.loads(cache_data_file.read_text(encoding="utf-8"))
|
||||
if not isinstance(data, dict):
|
||||
return None
|
||||
return data
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def save_to_cache(
|
||||
node_dir: Path, venv_root: Path, node_data: Dict[str, Any], manifest_path: Path
|
||||
) -> None:
|
||||
"""Save node metadata and cache key atomically."""
|
||||
try:
|
||||
cache_key_file, cache_data_file = get_cache_path(node_dir, venv_root)
|
||||
cache_dir = cache_key_file.parent
|
||||
cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
cache_key = compute_cache_key(node_dir, manifest_path)
|
||||
|
||||
# Atomic write: data
|
||||
tmp_data_fd, tmp_data_path = tempfile.mkstemp(dir=str(cache_dir), suffix=".tmp")
|
||||
try:
|
||||
with os.fdopen(tmp_data_fd, "w", encoding="utf-8") as f:
|
||||
json.dump(node_data, f, indent=2)
|
||||
os.replace(tmp_data_path, cache_data_file)
|
||||
except Exception:
|
||||
try:
|
||||
os.unlink(tmp_data_path)
|
||||
except OSError:
|
||||
pass
|
||||
raise
|
||||
|
||||
# Atomic write: key
|
||||
tmp_key_fd, tmp_key_path = tempfile.mkstemp(dir=str(cache_dir), suffix=".tmp")
|
||||
try:
|
||||
with os.fdopen(tmp_key_fd, "w", encoding="utf-8") as f:
|
||||
f.write(cache_key)
|
||||
os.replace(tmp_key_path, cache_key_file)
|
||||
except Exception:
|
||||
try:
|
||||
os.unlink(tmp_key_path)
|
||||
except OSError:
|
||||
pass
|
||||
raise
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("%s Cache save failed for %s: %s", LOG_PREFIX, node_dir.name, e)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"LOG_PREFIX",
|
||||
"find_manifest_directories",
|
||||
"compute_cache_key",
|
||||
"get_cache_path",
|
||||
"is_cache_valid",
|
||||
"load_from_cache",
|
||||
"save_to_cache",
|
||||
]
|
||||
820
comfy/isolation/model_patcher_proxy.py
Normal file
820
comfy/isolation/model_patcher_proxy.py
Normal file
@@ -0,0 +1,820 @@
|
||||
# pylint: disable=bare-except,consider-using-from-import,import-outside-toplevel,protected-access
|
||||
# RPC proxy for ModelPatcher (parent process)
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Optional, List, Set, Dict, Callable
|
||||
|
||||
from comfy.isolation.proxies.base import (
|
||||
IS_CHILD_PROCESS,
|
||||
BaseProxy,
|
||||
)
|
||||
from comfy.isolation.model_patcher_proxy_registry import (
|
||||
ModelPatcherRegistry,
|
||||
AutoPatcherEjector,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ModelPatcherProxy(BaseProxy[ModelPatcherRegistry]):
|
||||
_registry_class = ModelPatcherRegistry
|
||||
__module__ = "comfy.model_patcher"
|
||||
_APPLY_MODEL_GUARD_PADDING_BYTES = 32 * 1024 * 1024
|
||||
|
||||
def _get_rpc(self) -> Any:
|
||||
if self._rpc_caller is None:
|
||||
from pyisolate._internal.rpc_protocol import get_child_rpc_instance
|
||||
|
||||
rpc = get_child_rpc_instance()
|
||||
if rpc is not None:
|
||||
self._rpc_caller = rpc.create_caller(
|
||||
self._registry_class, self._registry_class.get_remote_id()
|
||||
)
|
||||
else:
|
||||
self._rpc_caller = self._registry
|
||||
return self._rpc_caller
|
||||
|
||||
def get_all_callbacks(self, call_type: str = None) -> Any:
|
||||
return self._call_rpc("get_all_callbacks", call_type)
|
||||
|
||||
def get_all_wrappers(self, wrapper_type: str = None) -> Any:
|
||||
return self._call_rpc("get_all_wrappers", wrapper_type)
|
||||
|
||||
def _load_list(self, *args, **kwargs) -> Any:
|
||||
return self._call_rpc("load_list_internal", *args, **kwargs)
|
||||
|
||||
def prepare_hook_patches_current_keyframe(
|
||||
self, t: Any, hook_group: Any, model_options: Any
|
||||
) -> None:
|
||||
self._call_rpc(
|
||||
"prepare_hook_patches_current_keyframe", t, hook_group, model_options
|
||||
)
|
||||
|
||||
def add_hook_patches(
|
||||
self,
|
||||
hook: Any,
|
||||
patches: Any,
|
||||
strength_patch: float = 1.0,
|
||||
strength_model: float = 1.0,
|
||||
) -> None:
|
||||
self._call_rpc(
|
||||
"add_hook_patches", hook, patches, strength_patch, strength_model
|
||||
)
|
||||
|
||||
def clear_cached_hook_weights(self) -> None:
|
||||
self._call_rpc("clear_cached_hook_weights")
|
||||
|
||||
def get_combined_hook_patches(self, hooks: Any) -> Any:
|
||||
return self._call_rpc("get_combined_hook_patches", hooks)
|
||||
|
||||
def get_additional_models_with_key(self, key: str) -> Any:
|
||||
return self._call_rpc("get_additional_models_with_key", key)
|
||||
|
||||
@property
|
||||
def object_patches(self) -> Any:
|
||||
return self._call_rpc("get_object_patches")
|
||||
|
||||
@property
|
||||
def patches(self) -> Any:
|
||||
res = self._call_rpc("get_patches")
|
||||
if isinstance(res, dict):
|
||||
new_res = {}
|
||||
for k, v in res.items():
|
||||
new_list = []
|
||||
for item in v:
|
||||
if isinstance(item, list):
|
||||
new_list.append(tuple(item))
|
||||
else:
|
||||
new_list.append(item)
|
||||
new_res[k] = new_list
|
||||
return new_res
|
||||
return res
|
||||
|
||||
@property
|
||||
def pinned(self) -> Set:
|
||||
val = self._call_rpc("get_patcher_attr", "pinned")
|
||||
return set(val) if val is not None else set()
|
||||
|
||||
@property
|
||||
def hook_patches(self) -> Dict:
|
||||
val = self._call_rpc("get_patcher_attr", "hook_patches")
|
||||
if val is None:
|
||||
return {}
|
||||
try:
|
||||
from comfy.hooks import _HookRef
|
||||
import json
|
||||
|
||||
new_val = {}
|
||||
for k, v in val.items():
|
||||
if isinstance(k, str):
|
||||
if k.startswith("PYISOLATE_HOOKREF:"):
|
||||
ref_id = k.split(":", 1)[1]
|
||||
h = _HookRef()
|
||||
h._pyisolate_id = ref_id
|
||||
new_val[h] = v
|
||||
elif k.startswith("__pyisolate_key__"):
|
||||
try:
|
||||
json_str = k[len("__pyisolate_key__") :]
|
||||
data = json.loads(json_str)
|
||||
ref_id = None
|
||||
if isinstance(data, list):
|
||||
for item in data:
|
||||
if (
|
||||
isinstance(item, list)
|
||||
and len(item) == 2
|
||||
and item[0] == "id"
|
||||
):
|
||||
ref_id = item[1]
|
||||
break
|
||||
if ref_id:
|
||||
h = _HookRef()
|
||||
h._pyisolate_id = ref_id
|
||||
new_val[h] = v
|
||||
else:
|
||||
new_val[k] = v
|
||||
except Exception:
|
||||
new_val[k] = v
|
||||
else:
|
||||
new_val[k] = v
|
||||
else:
|
||||
new_val[k] = v
|
||||
return new_val
|
||||
except ImportError:
|
||||
return val
|
||||
|
||||
def set_hook_mode(self, hook_mode: Any) -> None:
|
||||
self._call_rpc("set_hook_mode", hook_mode)
|
||||
|
||||
def register_all_hook_patches(
|
||||
self,
|
||||
hooks: Any,
|
||||
target_dict: Any,
|
||||
model_options: Any = None,
|
||||
registered: Any = None,
|
||||
) -> None:
|
||||
self._call_rpc(
|
||||
"register_all_hook_patches", hooks, target_dict, model_options, registered
|
||||
)
|
||||
|
||||
def is_clone(self, other: Any) -> bool:
|
||||
if isinstance(other, ModelPatcherProxy):
|
||||
return self._call_rpc("is_clone_by_id", other._instance_id)
|
||||
return False
|
||||
|
||||
def clone(self) -> ModelPatcherProxy:
|
||||
new_id = self._call_rpc("clone")
|
||||
return ModelPatcherProxy(
|
||||
new_id, self._registry, manage_lifecycle=not IS_CHILD_PROCESS
|
||||
)
|
||||
|
||||
def clone_has_same_weights(self, clone: Any) -> bool:
|
||||
if isinstance(clone, ModelPatcherProxy):
|
||||
return self._call_rpc("clone_has_same_weights_by_id", clone._instance_id)
|
||||
if not IS_CHILD_PROCESS:
|
||||
return self._call_rpc("is_clone", clone)
|
||||
return False
|
||||
|
||||
def get_model_object(self, name: str) -> Any:
|
||||
return self._call_rpc("get_model_object", name)
|
||||
|
||||
@property
|
||||
def model_options(self) -> dict:
|
||||
data = self._call_rpc("get_model_options")
|
||||
import json
|
||||
|
||||
def _decode_keys(obj):
|
||||
if isinstance(obj, dict):
|
||||
new_d = {}
|
||||
for k, v in obj.items():
|
||||
if isinstance(k, str) and k.startswith("__pyisolate_key__"):
|
||||
try:
|
||||
json_str = k[17:]
|
||||
val = json.loads(json_str)
|
||||
if isinstance(val, list):
|
||||
val = tuple(val)
|
||||
new_d[val] = _decode_keys(v)
|
||||
except:
|
||||
new_d[k] = _decode_keys(v)
|
||||
else:
|
||||
new_d[k] = _decode_keys(v)
|
||||
return new_d
|
||||
if isinstance(obj, list):
|
||||
return [_decode_keys(x) for x in obj]
|
||||
return obj
|
||||
|
||||
return _decode_keys(data)
|
||||
|
||||
@model_options.setter
|
||||
def model_options(self, value: dict) -> None:
|
||||
self._call_rpc("set_model_options", value)
|
||||
|
||||
def apply_hooks(self, hooks: Any) -> Any:
|
||||
return self._call_rpc("apply_hooks", hooks)
|
||||
|
||||
def prepare_state(self, timestep: Any) -> Any:
|
||||
return self._call_rpc("prepare_state", timestep)
|
||||
|
||||
def restore_hook_patches(self) -> None:
|
||||
self._call_rpc("restore_hook_patches")
|
||||
|
||||
def unpatch_hooks(self, whitelist_keys_set: Optional[Set[str]] = None) -> None:
|
||||
self._call_rpc("unpatch_hooks", whitelist_keys_set)
|
||||
|
||||
def model_patches_to(self, device: Any) -> Any:
|
||||
return self._call_rpc("model_patches_to", device)
|
||||
|
||||
def partially_load(
|
||||
self, device: Any, extra_memory: Any, force_patch_weights: bool = False
|
||||
) -> Any:
|
||||
return self._call_rpc(
|
||||
"partially_load", device, extra_memory, force_patch_weights
|
||||
)
|
||||
|
||||
def partially_unload(
|
||||
self, device_to: Any, memory_to_free: int = 0, force_patch_weights: bool = False
|
||||
) -> int:
|
||||
return self._call_rpc(
|
||||
"partially_unload", device_to, memory_to_free, force_patch_weights
|
||||
)
|
||||
|
||||
def load(
|
||||
self,
|
||||
device_to: Any = None,
|
||||
lowvram_model_memory: int = 0,
|
||||
force_patch_weights: bool = False,
|
||||
full_load: bool = False,
|
||||
) -> None:
|
||||
self._call_rpc(
|
||||
"load", device_to, lowvram_model_memory, force_patch_weights, full_load
|
||||
)
|
||||
|
||||
def patch_model(
|
||||
self,
|
||||
device_to: Any = None,
|
||||
lowvram_model_memory: int = 0,
|
||||
load_weights: bool = True,
|
||||
force_patch_weights: bool = False,
|
||||
) -> Any:
|
||||
self._call_rpc(
|
||||
"patch_model",
|
||||
device_to,
|
||||
lowvram_model_memory,
|
||||
load_weights,
|
||||
force_patch_weights,
|
||||
)
|
||||
return self
|
||||
|
||||
def unpatch_model(
|
||||
self, device_to: Any = None, unpatch_weights: bool = True
|
||||
) -> None:
|
||||
self._call_rpc("unpatch_model", device_to, unpatch_weights)
|
||||
|
||||
def detach(self, unpatch_all: bool = True) -> Any:
|
||||
self._call_rpc("detach", unpatch_all)
|
||||
return self.model
|
||||
|
||||
def _cpu_tensor_bytes(self, obj: Any) -> int:
|
||||
import torch
|
||||
|
||||
if isinstance(obj, torch.Tensor):
|
||||
if obj.device.type == "cpu":
|
||||
return obj.nbytes
|
||||
return 0
|
||||
if isinstance(obj, dict):
|
||||
return sum(self._cpu_tensor_bytes(v) for v in obj.values())
|
||||
if isinstance(obj, (list, tuple)):
|
||||
return sum(self._cpu_tensor_bytes(v) for v in obj)
|
||||
return 0
|
||||
|
||||
def _ensure_apply_model_headroom(self, required_bytes: int) -> bool:
|
||||
if required_bytes <= 0:
|
||||
return True
|
||||
|
||||
import torch
|
||||
import comfy.model_management as model_management
|
||||
|
||||
target_raw = self.load_device
|
||||
try:
|
||||
if isinstance(target_raw, torch.device):
|
||||
target = target_raw
|
||||
elif isinstance(target_raw, str):
|
||||
target = torch.device(target_raw)
|
||||
elif isinstance(target_raw, int):
|
||||
target = torch.device(f"cuda:{target_raw}")
|
||||
else:
|
||||
target = torch.device(target_raw)
|
||||
except Exception:
|
||||
return True
|
||||
|
||||
if target.type != "cuda":
|
||||
return True
|
||||
|
||||
required = required_bytes + self._APPLY_MODEL_GUARD_PADDING_BYTES
|
||||
if model_management.get_free_memory(target) >= required:
|
||||
return True
|
||||
|
||||
model_management.cleanup_models_gc()
|
||||
model_management.cleanup_models()
|
||||
model_management.soft_empty_cache()
|
||||
|
||||
if model_management.get_free_memory(target) < required:
|
||||
model_management.free_memory(required, target, for_dynamic=True)
|
||||
model_management.soft_empty_cache()
|
||||
|
||||
if model_management.get_free_memory(target) < required:
|
||||
# Escalate to non-dynamic unloading before dispatching CUDA transfer.
|
||||
model_management.free_memory(required, target, for_dynamic=False)
|
||||
model_management.soft_empty_cache()
|
||||
|
||||
if model_management.get_free_memory(target) < required:
|
||||
model_management.load_models_gpu(
|
||||
[self],
|
||||
minimum_memory_required=required,
|
||||
)
|
||||
|
||||
return model_management.get_free_memory(target) >= required
|
||||
|
||||
def apply_model(self, *args, **kwargs) -> Any:
|
||||
import torch
|
||||
|
||||
required_bytes = self._cpu_tensor_bytes(args) + self._cpu_tensor_bytes(kwargs)
|
||||
self._ensure_apply_model_headroom(required_bytes)
|
||||
|
||||
def _to_cuda(obj: Any) -> Any:
|
||||
if isinstance(obj, torch.Tensor) and obj.device.type == "cpu":
|
||||
return obj.to("cuda")
|
||||
if isinstance(obj, dict):
|
||||
return {k: _to_cuda(v) for k, v in obj.items()}
|
||||
if isinstance(obj, list):
|
||||
return [_to_cuda(v) for v in obj]
|
||||
if isinstance(obj, tuple):
|
||||
return tuple(_to_cuda(v) for v in obj)
|
||||
return obj
|
||||
|
||||
try:
|
||||
args_cuda = _to_cuda(args)
|
||||
kwargs_cuda = _to_cuda(kwargs)
|
||||
except torch.OutOfMemoryError:
|
||||
self._ensure_apply_model_headroom(required_bytes)
|
||||
args_cuda = _to_cuda(args)
|
||||
kwargs_cuda = _to_cuda(kwargs)
|
||||
|
||||
return self._call_rpc("inner_model_apply_model", args_cuda, kwargs_cuda)
|
||||
|
||||
def model_state_dict(self, filter_prefix: Optional[str] = None) -> Any:
|
||||
keys = self._call_rpc("model_state_dict", filter_prefix)
|
||||
return dict.fromkeys(keys, None)
|
||||
|
||||
def add_patches(self, *args: Any, **kwargs: Any) -> Any:
|
||||
res = self._call_rpc("add_patches", *args, **kwargs)
|
||||
if isinstance(res, list):
|
||||
return [tuple(x) if isinstance(x, list) else x for x in res]
|
||||
return res
|
||||
|
||||
def get_key_patches(self, filter_prefix: Optional[str] = None) -> Any:
|
||||
return self._call_rpc("get_key_patches", filter_prefix)
|
||||
|
||||
def patch_weight_to_device(self, key, device_to=None, inplace_update=False):
|
||||
self._call_rpc("patch_weight_to_device", key, device_to, inplace_update)
|
||||
|
||||
def pin_weight_to_device(self, key):
|
||||
self._call_rpc("pin_weight_to_device", key)
|
||||
|
||||
def unpin_weight(self, key):
|
||||
self._call_rpc("unpin_weight", key)
|
||||
|
||||
def unpin_all_weights(self):
|
||||
self._call_rpc("unpin_all_weights")
|
||||
|
||||
def calculate_weight(self, patches, weight, key, intermediate_dtype=None):
|
||||
return self._call_rpc(
|
||||
"calculate_weight", patches, weight, key, intermediate_dtype
|
||||
)
|
||||
|
||||
def inject_model(self) -> None:
|
||||
self._call_rpc("inject_model")
|
||||
|
||||
def eject_model(self) -> None:
|
||||
self._call_rpc("eject_model")
|
||||
|
||||
def use_ejected(self, skip_and_inject_on_exit_only: bool = False) -> Any:
|
||||
return AutoPatcherEjector(
|
||||
self, skip_and_inject_on_exit_only=skip_and_inject_on_exit_only
|
||||
)
|
||||
|
||||
@property
|
||||
def is_injected(self) -> bool:
|
||||
return self._call_rpc("get_is_injected")
|
||||
|
||||
@property
|
||||
def skip_injection(self) -> bool:
|
||||
return self._call_rpc("get_skip_injection")
|
||||
|
||||
@skip_injection.setter
|
||||
def skip_injection(self, value: bool) -> None:
|
||||
self._call_rpc("set_skip_injection", value)
|
||||
|
||||
def clean_hooks(self) -> None:
|
||||
self._call_rpc("clean_hooks")
|
||||
|
||||
def pre_run(self) -> None:
|
||||
self._call_rpc("pre_run")
|
||||
|
||||
def cleanup(self) -> None:
|
||||
try:
|
||||
self._call_rpc("cleanup")
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"ModelPatcherProxy cleanup RPC failed for %s",
|
||||
self._instance_id,
|
||||
exc_info=True,
|
||||
)
|
||||
finally:
|
||||
super().cleanup()
|
||||
|
||||
@property
|
||||
def model(self) -> _InnerModelProxy:
|
||||
return _InnerModelProxy(self)
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
_whitelisted_attrs = {
|
||||
"hook_patches_backup",
|
||||
"hook_backup",
|
||||
"cached_hook_patches",
|
||||
"current_hooks",
|
||||
"forced_hooks",
|
||||
"is_clip",
|
||||
"patches_uuid",
|
||||
"pinned",
|
||||
"attachments",
|
||||
"additional_models",
|
||||
"injections",
|
||||
"hook_patches",
|
||||
"model_lowvram",
|
||||
"model_loaded_weight_memory",
|
||||
"backup",
|
||||
"object_patches_backup",
|
||||
"weight_wrapper_patches",
|
||||
"weight_inplace_update",
|
||||
"force_cast_weights",
|
||||
}
|
||||
if name in _whitelisted_attrs:
|
||||
return self._call_rpc("get_patcher_attr", name)
|
||||
raise AttributeError(
|
||||
f"'{type(self).__name__}' object has no attribute '{name}'"
|
||||
)
|
||||
|
||||
def load_lora(
|
||||
self,
|
||||
lora_path: str,
|
||||
strength_model: float,
|
||||
clip: Optional[Any] = None,
|
||||
strength_clip: float = 1.0,
|
||||
) -> tuple:
|
||||
clip_id = None
|
||||
if clip is not None:
|
||||
clip_id = getattr(clip, "_instance_id", getattr(clip, "_clip_id", None))
|
||||
result = self._call_rpc(
|
||||
"load_lora", lora_path, strength_model, clip_id, strength_clip
|
||||
)
|
||||
new_model = None
|
||||
if result.get("model_id"):
|
||||
new_model = ModelPatcherProxy(
|
||||
result["model_id"],
|
||||
self._registry,
|
||||
manage_lifecycle=not IS_CHILD_PROCESS,
|
||||
)
|
||||
new_clip = None
|
||||
if result.get("clip_id"):
|
||||
from comfy.isolation.clip_proxy import CLIPProxy
|
||||
|
||||
new_clip = CLIPProxy(result["clip_id"])
|
||||
return (new_model, new_clip)
|
||||
|
||||
@property
|
||||
def load_device(self) -> Any:
|
||||
return self._call_rpc("get_load_device")
|
||||
|
||||
@property
|
||||
def offload_device(self) -> Any:
|
||||
return self._call_rpc("get_offload_device")
|
||||
|
||||
@property
|
||||
def device(self) -> Any:
|
||||
return self.load_device
|
||||
|
||||
def current_loaded_device(self) -> Any:
|
||||
return self._call_rpc("current_loaded_device")
|
||||
|
||||
@property
|
||||
def size(self) -> int:
|
||||
return self._call_rpc("get_size")
|
||||
|
||||
def model_size(self) -> Any:
|
||||
return self._call_rpc("model_size")
|
||||
|
||||
def loaded_size(self) -> Any:
|
||||
return self._call_rpc("loaded_size")
|
||||
|
||||
def get_ram_usage(self) -> int:
|
||||
return self._call_rpc("get_ram_usage")
|
||||
|
||||
def lowvram_patch_counter(self) -> int:
|
||||
return self._call_rpc("lowvram_patch_counter")
|
||||
|
||||
def memory_required(self, input_shape: Any) -> Any:
|
||||
return self._call_rpc("memory_required", input_shape)
|
||||
|
||||
def is_dynamic(self) -> bool:
|
||||
return bool(self._call_rpc("is_dynamic"))
|
||||
|
||||
def get_free_memory(self, device: Any) -> Any:
|
||||
return self._call_rpc("get_free_memory", device)
|
||||
|
||||
def partially_unload_ram(self, ram_to_unload: int) -> Any:
|
||||
return self._call_rpc("partially_unload_ram", ram_to_unload)
|
||||
|
||||
def model_dtype(self) -> Any:
|
||||
res = self._call_rpc("model_dtype")
|
||||
if isinstance(res, str) and res.startswith("torch."):
|
||||
try:
|
||||
import torch
|
||||
|
||||
attr = res.split(".")[-1]
|
||||
if hasattr(torch, attr):
|
||||
return getattr(torch, attr)
|
||||
except ImportError:
|
||||
pass
|
||||
return res
|
||||
|
||||
@property
|
||||
def hook_mode(self) -> Any:
|
||||
return self._call_rpc("get_hook_mode")
|
||||
|
||||
@hook_mode.setter
|
||||
def hook_mode(self, value: Any) -> None:
|
||||
self._call_rpc("set_hook_mode", value)
|
||||
|
||||
def set_model_sampler_cfg_function(
|
||||
self, sampler_cfg_function: Any, disable_cfg1_optimization: bool = False
|
||||
) -> None:
|
||||
self._call_rpc(
|
||||
"set_model_sampler_cfg_function",
|
||||
sampler_cfg_function,
|
||||
disable_cfg1_optimization,
|
||||
)
|
||||
|
||||
def set_model_sampler_post_cfg_function(
|
||||
self, post_cfg_function: Any, disable_cfg1_optimization: bool = False
|
||||
) -> None:
|
||||
self._call_rpc(
|
||||
"set_model_sampler_post_cfg_function",
|
||||
post_cfg_function,
|
||||
disable_cfg1_optimization,
|
||||
)
|
||||
|
||||
def set_model_sampler_pre_cfg_function(
|
||||
self, pre_cfg_function: Any, disable_cfg1_optimization: bool = False
|
||||
) -> None:
|
||||
self._call_rpc(
|
||||
"set_model_sampler_pre_cfg_function",
|
||||
pre_cfg_function,
|
||||
disable_cfg1_optimization,
|
||||
)
|
||||
|
||||
def set_model_sampler_calc_cond_batch_function(self, fn: Any) -> None:
|
||||
self._call_rpc("set_model_sampler_calc_cond_batch_function", fn)
|
||||
|
||||
def set_model_unet_function_wrapper(self, unet_wrapper_function: Any) -> None:
|
||||
self._call_rpc("set_model_unet_function_wrapper", unet_wrapper_function)
|
||||
|
||||
def set_model_denoise_mask_function(self, denoise_mask_function: Any) -> None:
|
||||
self._call_rpc("set_model_denoise_mask_function", denoise_mask_function)
|
||||
|
||||
def set_model_patch(self, patch: Any, name: str) -> None:
|
||||
self._call_rpc("set_model_patch", patch, name)
|
||||
|
||||
def set_model_patch_replace(
|
||||
self,
|
||||
patch: Any,
|
||||
name: str,
|
||||
block_name: str,
|
||||
number: int,
|
||||
transformer_index: Optional[int] = None,
|
||||
) -> None:
|
||||
self._call_rpc(
|
||||
"set_model_patch_replace",
|
||||
patch,
|
||||
name,
|
||||
block_name,
|
||||
number,
|
||||
transformer_index,
|
||||
)
|
||||
|
||||
def set_model_attn1_patch(self, patch: Any) -> None:
|
||||
self.set_model_patch(patch, "attn1_patch")
|
||||
|
||||
def set_model_attn2_patch(self, patch: Any) -> None:
|
||||
self.set_model_patch(patch, "attn2_patch")
|
||||
|
||||
def set_model_attn1_replace(
|
||||
self,
|
||||
patch: Any,
|
||||
block_name: str,
|
||||
number: int,
|
||||
transformer_index: Optional[int] = None,
|
||||
) -> None:
|
||||
self.set_model_patch_replace(
|
||||
patch, "attn1", block_name, number, transformer_index
|
||||
)
|
||||
|
||||
def set_model_attn2_replace(
|
||||
self,
|
||||
patch: Any,
|
||||
block_name: str,
|
||||
number: int,
|
||||
transformer_index: Optional[int] = None,
|
||||
) -> None:
|
||||
self.set_model_patch_replace(
|
||||
patch, "attn2", block_name, number, transformer_index
|
||||
)
|
||||
|
||||
def set_model_attn1_output_patch(self, patch: Any) -> None:
|
||||
self.set_model_patch(patch, "attn1_output_patch")
|
||||
|
||||
def set_model_attn2_output_patch(self, patch: Any) -> None:
|
||||
self.set_model_patch(patch, "attn2_output_patch")
|
||||
|
||||
def set_model_input_block_patch(self, patch: Any) -> None:
|
||||
self.set_model_patch(patch, "input_block_patch")
|
||||
|
||||
def set_model_input_block_patch_after_skip(self, patch: Any) -> None:
|
||||
self.set_model_patch(patch, "input_block_patch_after_skip")
|
||||
|
||||
def set_model_output_block_patch(self, patch: Any) -> None:
|
||||
self.set_model_patch(patch, "output_block_patch")
|
||||
|
||||
def set_model_emb_patch(self, patch: Any) -> None:
|
||||
self.set_model_patch(patch, "emb_patch")
|
||||
|
||||
def set_model_forward_timestep_embed_patch(self, patch: Any) -> None:
|
||||
self.set_model_patch(patch, "forward_timestep_embed_patch")
|
||||
|
||||
def set_model_double_block_patch(self, patch: Any) -> None:
|
||||
self.set_model_patch(patch, "double_block")
|
||||
|
||||
def set_model_post_input_patch(self, patch: Any) -> None:
|
||||
self.set_model_patch(patch, "post_input")
|
||||
|
||||
def set_model_rope_options(
|
||||
self,
|
||||
scale_x=1.0,
|
||||
shift_x=0.0,
|
||||
scale_y=1.0,
|
||||
shift_y=0.0,
|
||||
scale_t=1.0,
|
||||
shift_t=0.0,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
options = {
|
||||
"scale_x": scale_x,
|
||||
"shift_x": shift_x,
|
||||
"scale_y": scale_y,
|
||||
"shift_y": shift_y,
|
||||
"scale_t": scale_t,
|
||||
"shift_t": shift_t,
|
||||
}
|
||||
options.update(kwargs)
|
||||
self._call_rpc("set_model_rope_options", options)
|
||||
|
||||
def set_model_compute_dtype(self, dtype: Any) -> None:
|
||||
self._call_rpc("set_model_compute_dtype", dtype)
|
||||
|
||||
def add_object_patch(self, name: str, obj: Any) -> None:
|
||||
self._call_rpc("add_object_patch", name, obj)
|
||||
|
||||
def add_weight_wrapper(self, name: str, function: Any) -> None:
|
||||
self._call_rpc("add_weight_wrapper", name, function)
|
||||
|
||||
def add_wrapper_with_key(self, wrapper_type: Any, key: str, fn: Any) -> None:
|
||||
self._call_rpc("add_wrapper_with_key", wrapper_type, key, fn)
|
||||
|
||||
def add_wrapper(self, wrapper_type: str, wrapper: Callable) -> None:
|
||||
self.add_wrapper_with_key(wrapper_type, None, wrapper)
|
||||
|
||||
def remove_wrappers_with_key(self, wrapper_type: str, key: str) -> None:
|
||||
self._call_rpc("remove_wrappers_with_key", wrapper_type, key)
|
||||
|
||||
@property
|
||||
def wrappers(self) -> Any:
|
||||
return self._call_rpc("get_wrappers")
|
||||
|
||||
def add_callback_with_key(self, call_type: str, key: str, callback: Any) -> None:
|
||||
self._call_rpc("add_callback_with_key", call_type, key, callback)
|
||||
|
||||
def add_callback(self, call_type: str, callback: Any) -> None:
|
||||
self.add_callback_with_key(call_type, None, callback)
|
||||
|
||||
def remove_callbacks_with_key(self, call_type: str, key: str) -> None:
|
||||
self._call_rpc("remove_callbacks_with_key", call_type, key)
|
||||
|
||||
@property
|
||||
def callbacks(self) -> Any:
|
||||
return self._call_rpc("get_callbacks")
|
||||
|
||||
def set_attachments(self, key: str, attachment: Any) -> None:
|
||||
self._call_rpc("set_attachments", key, attachment)
|
||||
|
||||
def get_attachment(self, key: str) -> Any:
|
||||
return self._call_rpc("get_attachment", key)
|
||||
|
||||
def remove_attachments(self, key: str) -> None:
|
||||
self._call_rpc("remove_attachments", key)
|
||||
|
||||
def set_injections(self, key: str, injections: Any) -> None:
|
||||
self._call_rpc("set_injections", key, injections)
|
||||
|
||||
def get_injections(self, key: str) -> Any:
|
||||
return self._call_rpc("get_injections", key)
|
||||
|
||||
def remove_injections(self, key: str) -> None:
|
||||
self._call_rpc("remove_injections", key)
|
||||
|
||||
def set_additional_models(self, key: str, models: Any) -> None:
|
||||
ids = [m._instance_id for m in models]
|
||||
self._call_rpc("set_additional_models", key, ids)
|
||||
|
||||
def remove_additional_models(self, key: str) -> None:
|
||||
self._call_rpc("remove_additional_models", key)
|
||||
|
||||
def get_nested_additional_models(self) -> Any:
|
||||
return self._call_rpc("get_nested_additional_models")
|
||||
|
||||
def get_additional_models(self) -> List[ModelPatcherProxy]:
|
||||
ids = self._call_rpc("get_additional_models")
|
||||
return [
|
||||
ModelPatcherProxy(
|
||||
mid, self._registry, manage_lifecycle=not IS_CHILD_PROCESS
|
||||
)
|
||||
for mid in ids
|
||||
]
|
||||
|
||||
def model_patches_models(self) -> Any:
|
||||
return self._call_rpc("model_patches_models")
|
||||
|
||||
@property
|
||||
def parent(self) -> Any:
|
||||
return self._call_rpc("get_parent")
|
||||
|
||||
|
||||
class _InnerModelProxy:
|
||||
def __init__(self, parent: ModelPatcherProxy):
|
||||
self._parent = parent
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
if name.startswith("_"):
|
||||
raise AttributeError(name)
|
||||
if name in (
|
||||
"model_config",
|
||||
"latent_format",
|
||||
"model_type",
|
||||
"current_weight_patches_uuid",
|
||||
):
|
||||
return self._parent._call_rpc("get_inner_model_attr", name)
|
||||
if name == "load_device":
|
||||
return self._parent._call_rpc("get_inner_model_attr", "load_device")
|
||||
if name == "device":
|
||||
return self._parent._call_rpc("get_inner_model_attr", "device")
|
||||
if name == "current_patcher":
|
||||
return ModelPatcherProxy(
|
||||
self._parent._instance_id,
|
||||
self._parent._registry,
|
||||
manage_lifecycle=False,
|
||||
)
|
||||
if name == "model_sampling":
|
||||
return self._parent._call_rpc("get_model_object", "model_sampling")
|
||||
if name == "extra_conds_shapes":
|
||||
return lambda *a, **k: self._parent._call_rpc(
|
||||
"inner_model_extra_conds_shapes", a, k
|
||||
)
|
||||
if name == "extra_conds":
|
||||
return lambda *a, **k: self._parent._call_rpc(
|
||||
"inner_model_extra_conds", a, k
|
||||
)
|
||||
if name == "memory_required":
|
||||
return lambda *a, **k: self._parent._call_rpc(
|
||||
"inner_model_memory_required", a, k
|
||||
)
|
||||
if name == "apply_model":
|
||||
# Delegate to parent's method to get the CPU->CUDA optimization
|
||||
return self._parent.apply_model
|
||||
if name == "process_latent_in":
|
||||
return lambda *a, **k: self._parent._call_rpc("process_latent_in", a, k)
|
||||
if name == "process_latent_out":
|
||||
return lambda *a, **k: self._parent._call_rpc("process_latent_out", a, k)
|
||||
if name == "scale_latent_inpaint":
|
||||
return lambda *a, **k: self._parent._call_rpc("scale_latent_inpaint", a, k)
|
||||
if name == "diffusion_model":
|
||||
return self._parent._call_rpc("get_inner_model_attr", "diffusion_model")
|
||||
raise AttributeError(f"'{name}' not supported on isolated InnerModel")
|
||||
875
comfy/isolation/model_patcher_proxy_registry.py
Normal file
875
comfy/isolation/model_patcher_proxy_registry.py
Normal file
@@ -0,0 +1,875 @@
|
||||
# pylint: disable=import-outside-toplevel,logging-fstring-interpolation,protected-access,unused-import
|
||||
# RPC server for ModelPatcher isolation (child process)
|
||||
from __future__ import annotations
|
||||
|
||||
import gc
|
||||
import logging
|
||||
from typing import Any, Optional, List
|
||||
|
||||
try:
|
||||
from comfy.model_patcher import AutoPatcherEjector
|
||||
except ImportError:
|
||||
|
||||
class AutoPatcherEjector:
|
||||
def __init__(self, model, skip_and_inject_on_exit_only=False):
|
||||
self.model = model
|
||||
self.skip_and_inject_on_exit_only = skip_and_inject_on_exit_only
|
||||
self.prev_skip_injection = False
|
||||
self.was_injected = False
|
||||
|
||||
def __enter__(self):
|
||||
self.was_injected = False
|
||||
self.prev_skip_injection = self.model.skip_injection
|
||||
if self.skip_and_inject_on_exit_only:
|
||||
self.model.skip_injection = True
|
||||
if self.model.is_injected:
|
||||
self.model.eject_model()
|
||||
self.was_injected = True
|
||||
|
||||
def __exit__(self, *args):
|
||||
if self.skip_and_inject_on_exit_only:
|
||||
self.model.skip_injection = self.prev_skip_injection
|
||||
self.model.inject_model()
|
||||
if self.was_injected and not self.model.skip_injection:
|
||||
self.model.inject_model()
|
||||
self.model.skip_injection = self.prev_skip_injection
|
||||
|
||||
|
||||
from comfy.isolation.proxies.base import (
|
||||
BaseRegistry,
|
||||
detach_if_grad,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ModelPatcherRegistry(BaseRegistry[Any]):
|
||||
_type_prefix = "model"
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self._pending_cleanup_ids: set[str] = set()
|
||||
|
||||
async def clone(self, instance_id: str) -> str:
|
||||
instance = self._get_instance(instance_id)
|
||||
new_model = instance.clone()
|
||||
return self.register(new_model)
|
||||
|
||||
async def is_clone(self, instance_id: str, other: Any) -> bool:
|
||||
instance = self._get_instance(instance_id)
|
||||
if hasattr(other, "model"):
|
||||
return instance.is_clone(other)
|
||||
return False
|
||||
|
||||
async def get_model_object(self, instance_id: str, name: str) -> Any:
|
||||
instance = self._get_instance(instance_id)
|
||||
if name == "model":
|
||||
return f"<ModelObject: {type(instance.model).__name__}>"
|
||||
result = instance.get_model_object(name)
|
||||
if name == "model_sampling":
|
||||
from comfy.isolation.model_sampling_proxy import (
|
||||
ModelSamplingRegistry,
|
||||
ModelSamplingProxy,
|
||||
)
|
||||
|
||||
registry = ModelSamplingRegistry()
|
||||
sampling_id = registry.register(result)
|
||||
return ModelSamplingProxy(sampling_id, registry)
|
||||
return detach_if_grad(result)
|
||||
|
||||
async def get_model_options(self, instance_id: str) -> dict:
|
||||
instance = self._get_instance(instance_id)
|
||||
import copy
|
||||
|
||||
opts = copy.deepcopy(instance.model_options)
|
||||
return self._sanitize_rpc_result(opts)
|
||||
|
||||
async def set_model_options(self, instance_id: str, options: dict) -> None:
|
||||
self._get_instance(instance_id).model_options = options
|
||||
|
||||
async def get_patcher_attr(self, instance_id: str, name: str) -> Any:
|
||||
return self._sanitize_rpc_result(
|
||||
getattr(self._get_instance(instance_id), name, None)
|
||||
)
|
||||
|
||||
async def model_state_dict(self, instance_id: str, filter_prefix=None) -> Any:
|
||||
instance = self._get_instance(instance_id)
|
||||
sd_keys = instance.model.state_dict().keys()
|
||||
return dict.fromkeys(sd_keys, None)
|
||||
|
||||
def _sanitize_rpc_result(self, obj, seen=None):
|
||||
if seen is None:
|
||||
seen = set()
|
||||
if obj is None:
|
||||
return None
|
||||
if isinstance(obj, (bool, int, float, str)):
|
||||
if isinstance(obj, str) and len(obj) > 500000:
|
||||
return f"<Truncated String len={len(obj)}>"
|
||||
return obj
|
||||
obj_id = id(obj)
|
||||
if obj_id in seen:
|
||||
return None
|
||||
seen.add(obj_id)
|
||||
if isinstance(obj, (list, tuple)):
|
||||
return [self._sanitize_rpc_result(x, seen) for x in obj]
|
||||
if isinstance(obj, set):
|
||||
return [self._sanitize_rpc_result(x, seen) for x in obj]
|
||||
if isinstance(obj, dict):
|
||||
new_dict = {}
|
||||
for k, v in obj.items():
|
||||
if isinstance(k, tuple):
|
||||
import json
|
||||
|
||||
try:
|
||||
key_str = "__pyisolate_key__" + json.dumps(list(k))
|
||||
new_dict[key_str] = self._sanitize_rpc_result(v, seen)
|
||||
except Exception:
|
||||
new_dict[str(k)] = self._sanitize_rpc_result(v, seen)
|
||||
else:
|
||||
new_dict[str(k)] = self._sanitize_rpc_result(v, seen)
|
||||
return new_dict
|
||||
if (
|
||||
hasattr(obj, "__dict__")
|
||||
and not hasattr(obj, "__get__")
|
||||
and not hasattr(obj, "__call__")
|
||||
):
|
||||
return self._sanitize_rpc_result(obj.__dict__, seen)
|
||||
if hasattr(obj, "items") and hasattr(obj, "get"):
|
||||
return {str(k): self._sanitize_rpc_result(v, seen) for k, v in obj.items()}
|
||||
return None
|
||||
|
||||
async def get_load_device(self, instance_id: str) -> Any:
|
||||
return self._get_instance(instance_id).load_device
|
||||
|
||||
async def get_offload_device(self, instance_id: str) -> Any:
|
||||
return self._get_instance(instance_id).offload_device
|
||||
|
||||
async def current_loaded_device(self, instance_id: str) -> Any:
|
||||
return self._get_instance(instance_id).current_loaded_device()
|
||||
|
||||
async def get_size(self, instance_id: str) -> int:
|
||||
return self._get_instance(instance_id).size
|
||||
|
||||
async def model_size(self, instance_id: str) -> Any:
|
||||
return self._get_instance(instance_id).model_size()
|
||||
|
||||
async def loaded_size(self, instance_id: str) -> Any:
|
||||
return self._get_instance(instance_id).loaded_size()
|
||||
|
||||
async def get_ram_usage(self, instance_id: str) -> int:
|
||||
return self._get_instance(instance_id).get_ram_usage()
|
||||
|
||||
async def lowvram_patch_counter(self, instance_id: str) -> int:
|
||||
return self._get_instance(instance_id).lowvram_patch_counter()
|
||||
|
||||
async def memory_required(self, instance_id: str, input_shape: Any) -> Any:
|
||||
return self._get_instance(instance_id).memory_required(input_shape)
|
||||
|
||||
async def is_dynamic(self, instance_id: str) -> bool:
|
||||
instance = self._get_instance(instance_id)
|
||||
if hasattr(instance, "is_dynamic"):
|
||||
return bool(instance.is_dynamic())
|
||||
return False
|
||||
|
||||
async def get_free_memory(self, instance_id: str, device: Any) -> Any:
|
||||
instance = self._get_instance(instance_id)
|
||||
if hasattr(instance, "get_free_memory"):
|
||||
return instance.get_free_memory(device)
|
||||
import comfy.model_management
|
||||
|
||||
return comfy.model_management.get_free_memory(device)
|
||||
|
||||
async def partially_unload_ram(self, instance_id: str, ram_to_unload: int) -> Any:
|
||||
instance = self._get_instance(instance_id)
|
||||
if hasattr(instance, "partially_unload_ram"):
|
||||
return instance.partially_unload_ram(ram_to_unload)
|
||||
return None
|
||||
|
||||
async def model_dtype(self, instance_id: str) -> Any:
|
||||
return self._get_instance(instance_id).model_dtype()
|
||||
|
||||
async def model_patches_to(self, instance_id: str, device: Any) -> Any:
|
||||
return self._get_instance(instance_id).model_patches_to(device)
|
||||
|
||||
async def partially_load(
|
||||
self,
|
||||
instance_id: str,
|
||||
device: Any,
|
||||
extra_memory: Any,
|
||||
force_patch_weights: bool = False,
|
||||
) -> Any:
|
||||
return self._get_instance(instance_id).partially_load(
|
||||
device, extra_memory, force_patch_weights=force_patch_weights
|
||||
)
|
||||
|
||||
async def partially_unload(
|
||||
self,
|
||||
instance_id: str,
|
||||
device_to: Any,
|
||||
memory_to_free: int = 0,
|
||||
force_patch_weights: bool = False,
|
||||
) -> int:
|
||||
return self._get_instance(instance_id).partially_unload(
|
||||
device_to, memory_to_free, force_patch_weights
|
||||
)
|
||||
|
||||
async def load(
|
||||
self,
|
||||
instance_id: str,
|
||||
device_to: Any = None,
|
||||
lowvram_model_memory: int = 0,
|
||||
force_patch_weights: bool = False,
|
||||
full_load: bool = False,
|
||||
) -> None:
|
||||
self._get_instance(instance_id).load(
|
||||
device_to, lowvram_model_memory, force_patch_weights, full_load
|
||||
)
|
||||
|
||||
async def patch_model(
|
||||
self,
|
||||
instance_id: str,
|
||||
device_to: Any = None,
|
||||
lowvram_model_memory: int = 0,
|
||||
load_weights: bool = True,
|
||||
force_patch_weights: bool = False,
|
||||
) -> None:
|
||||
try:
|
||||
self._get_instance(instance_id).patch_model(
|
||||
device_to, lowvram_model_memory, load_weights, force_patch_weights
|
||||
)
|
||||
except AttributeError as e:
|
||||
logger.error(
|
||||
f"Isolation Error: Failed to patch model attribute: {e}. Skipping."
|
||||
)
|
||||
return
|
||||
|
||||
async def unpatch_model(
|
||||
self, instance_id: str, device_to: Any = None, unpatch_weights: bool = True
|
||||
) -> None:
|
||||
self._get_instance(instance_id).unpatch_model(device_to, unpatch_weights)
|
||||
|
||||
async def detach(self, instance_id: str, unpatch_all: bool = True) -> None:
|
||||
self._get_instance(instance_id).detach(unpatch_all)
|
||||
|
||||
async def prepare_state(self, instance_id: str, timestep: Any) -> Any:
|
||||
instance = self._get_instance(instance_id)
|
||||
cp = getattr(instance.model, "current_patcher", instance)
|
||||
if cp is None:
|
||||
cp = instance
|
||||
return cp.prepare_state(timestep)
|
||||
|
||||
async def pre_run(self, instance_id: str) -> None:
|
||||
self._get_instance(instance_id).pre_run()
|
||||
|
||||
async def cleanup(self, instance_id: str) -> None:
|
||||
try:
|
||||
instance = self._get_instance(instance_id)
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"ModelPatcher cleanup requested for missing instance %s",
|
||||
instance_id,
|
||||
exc_info=True,
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
instance.cleanup()
|
||||
finally:
|
||||
with self._lock:
|
||||
self._pending_cleanup_ids.add(instance_id)
|
||||
gc.collect()
|
||||
|
||||
def sweep_pending_cleanup(self) -> int:
|
||||
removed = 0
|
||||
with self._lock:
|
||||
pending_ids = list(self._pending_cleanup_ids)
|
||||
self._pending_cleanup_ids.clear()
|
||||
for instance_id in pending_ids:
|
||||
instance = self._registry.pop(instance_id, None)
|
||||
if instance is None:
|
||||
continue
|
||||
self._id_map.pop(id(instance), None)
|
||||
removed += 1
|
||||
|
||||
gc.collect()
|
||||
return removed
|
||||
|
||||
def purge_all(self) -> int:
|
||||
with self._lock:
|
||||
removed = len(self._registry)
|
||||
self._registry.clear()
|
||||
self._id_map.clear()
|
||||
self._pending_cleanup_ids.clear()
|
||||
gc.collect()
|
||||
return removed
|
||||
|
||||
async def apply_hooks(self, instance_id: str, hooks: Any) -> Any:
|
||||
instance = self._get_instance(instance_id)
|
||||
cp = getattr(instance.model, "current_patcher", instance)
|
||||
if cp is None:
|
||||
cp = instance
|
||||
return cp.apply_hooks(hooks=hooks)
|
||||
|
||||
async def clean_hooks(self, instance_id: str) -> None:
|
||||
self._get_instance(instance_id).clean_hooks()
|
||||
|
||||
async def restore_hook_patches(self, instance_id: str) -> None:
|
||||
self._get_instance(instance_id).restore_hook_patches()
|
||||
|
||||
async def unpatch_hooks(
|
||||
self, instance_id: str, whitelist_keys_set: Optional[set] = None
|
||||
) -> None:
|
||||
self._get_instance(instance_id).unpatch_hooks(whitelist_keys_set)
|
||||
|
||||
async def register_all_hook_patches(
|
||||
self,
|
||||
instance_id: str,
|
||||
hooks: Any,
|
||||
target_dict: Any,
|
||||
model_options: Any,
|
||||
registered: Any,
|
||||
) -> None:
|
||||
from types import SimpleNamespace
|
||||
import comfy.hooks
|
||||
|
||||
instance = self._get_instance(instance_id)
|
||||
if isinstance(hooks, SimpleNamespace) or hasattr(hooks, "__dict__"):
|
||||
hook_data = hooks.__dict__ if hasattr(hooks, "__dict__") else hooks
|
||||
new_hooks = comfy.hooks.HookGroup()
|
||||
if hasattr(hook_data, "hooks"):
|
||||
new_hooks.hooks = (
|
||||
hook_data["hooks"]
|
||||
if isinstance(hook_data, dict)
|
||||
else hook_data.hooks
|
||||
)
|
||||
hooks = new_hooks
|
||||
instance.register_all_hook_patches(
|
||||
hooks, target_dict, model_options, registered
|
||||
)
|
||||
|
||||
async def get_hook_mode(self, instance_id: str) -> Any:
|
||||
return getattr(self._get_instance(instance_id), "hook_mode", None)
|
||||
|
||||
async def set_hook_mode(self, instance_id: str, value: Any) -> None:
|
||||
setattr(self._get_instance(instance_id), "hook_mode", value)
|
||||
|
||||
async def inject_model(self, instance_id: str) -> None:
|
||||
instance = self._get_instance(instance_id)
|
||||
try:
|
||||
instance.inject_model()
|
||||
except AttributeError as e:
|
||||
if "inject" in str(e):
|
||||
logger.error(
|
||||
"Isolation Error: Injector object lost method code during serialization. Cannot inject. Skipping."
|
||||
)
|
||||
return
|
||||
raise e
|
||||
|
||||
async def eject_model(self, instance_id: str) -> None:
|
||||
self._get_instance(instance_id).eject_model()
|
||||
|
||||
async def get_is_injected(self, instance_id: str) -> bool:
|
||||
return self._get_instance(instance_id).is_injected
|
||||
|
||||
async def set_skip_injection(self, instance_id: str, value: bool) -> None:
|
||||
self._get_instance(instance_id).skip_injection = value
|
||||
|
||||
async def get_skip_injection(self, instance_id: str) -> bool:
|
||||
return self._get_instance(instance_id).skip_injection
|
||||
|
||||
async def set_model_sampler_cfg_function(
|
||||
self,
|
||||
instance_id: str,
|
||||
sampler_cfg_function: Any,
|
||||
disable_cfg1_optimization: bool = False,
|
||||
) -> None:
|
||||
if not callable(sampler_cfg_function):
|
||||
logger.error(
|
||||
f"set_model_sampler_cfg_function: Expected callable, got {type(sampler_cfg_function)}. Skipping."
|
||||
)
|
||||
return
|
||||
self._get_instance(instance_id).set_model_sampler_cfg_function(
|
||||
sampler_cfg_function, disable_cfg1_optimization
|
||||
)
|
||||
|
||||
async def set_model_sampler_post_cfg_function(
|
||||
self,
|
||||
instance_id: str,
|
||||
post_cfg_function: Any,
|
||||
disable_cfg1_optimization: bool = False,
|
||||
) -> None:
|
||||
self._get_instance(instance_id).set_model_sampler_post_cfg_function(
|
||||
post_cfg_function, disable_cfg1_optimization
|
||||
)
|
||||
|
||||
async def set_model_sampler_pre_cfg_function(
|
||||
self,
|
||||
instance_id: str,
|
||||
pre_cfg_function: Any,
|
||||
disable_cfg1_optimization: bool = False,
|
||||
) -> None:
|
||||
self._get_instance(instance_id).set_model_sampler_pre_cfg_function(
|
||||
pre_cfg_function, disable_cfg1_optimization
|
||||
)
|
||||
|
||||
async def set_model_sampler_calc_cond_batch_function(
|
||||
self, instance_id: str, fn: Any
|
||||
) -> None:
|
||||
self._get_instance(instance_id).set_model_sampler_calc_cond_batch_function(fn)
|
||||
|
||||
async def set_model_unet_function_wrapper(
|
||||
self, instance_id: str, unet_wrapper_function: Any
|
||||
) -> None:
|
||||
self._get_instance(instance_id).set_model_unet_function_wrapper(
|
||||
unet_wrapper_function
|
||||
)
|
||||
|
||||
async def set_model_denoise_mask_function(
|
||||
self, instance_id: str, denoise_mask_function: Any
|
||||
) -> None:
|
||||
self._get_instance(instance_id).set_model_denoise_mask_function(
|
||||
denoise_mask_function
|
||||
)
|
||||
|
||||
async def set_model_patch(self, instance_id: str, patch: Any, name: str) -> None:
|
||||
self._get_instance(instance_id).set_model_patch(patch, name)
|
||||
|
||||
async def set_model_patch_replace(
|
||||
self,
|
||||
instance_id: str,
|
||||
patch: Any,
|
||||
name: str,
|
||||
block_name: str,
|
||||
number: int,
|
||||
transformer_index: Optional[int] = None,
|
||||
) -> None:
|
||||
self._get_instance(instance_id).set_model_patch_replace(
|
||||
patch, name, block_name, number, transformer_index
|
||||
)
|
||||
|
||||
async def set_model_input_block_patch(self, instance_id: str, patch: Any) -> None:
|
||||
self._get_instance(instance_id).set_model_input_block_patch(patch)
|
||||
|
||||
async def set_model_input_block_patch_after_skip(
|
||||
self, instance_id: str, patch: Any
|
||||
) -> None:
|
||||
self._get_instance(instance_id).set_model_input_block_patch_after_skip(patch)
|
||||
|
||||
async def set_model_output_block_patch(self, instance_id: str, patch: Any) -> None:
|
||||
self._get_instance(instance_id).set_model_output_block_patch(patch)
|
||||
|
||||
async def set_model_emb_patch(self, instance_id: str, patch: Any) -> None:
|
||||
self._get_instance(instance_id).set_model_emb_patch(patch)
|
||||
|
||||
async def set_model_forward_timestep_embed_patch(
|
||||
self, instance_id: str, patch: Any
|
||||
) -> None:
|
||||
self._get_instance(instance_id).set_model_forward_timestep_embed_patch(patch)
|
||||
|
||||
async def set_model_double_block_patch(self, instance_id: str, patch: Any) -> None:
|
||||
self._get_instance(instance_id).set_model_double_block_patch(patch)
|
||||
|
||||
async def set_model_post_input_patch(self, instance_id: str, patch: Any) -> None:
|
||||
self._get_instance(instance_id).set_model_post_input_patch(patch)
|
||||
|
||||
async def set_model_rope_options(self, instance_id: str, options: dict) -> None:
|
||||
self._get_instance(instance_id).set_model_rope_options(**options)
|
||||
|
||||
async def set_model_compute_dtype(self, instance_id: str, dtype: Any) -> None:
|
||||
self._get_instance(instance_id).set_model_compute_dtype(dtype)
|
||||
|
||||
async def clone_has_same_weights_by_id(
|
||||
self, instance_id: str, other_id: str
|
||||
) -> bool:
|
||||
instance = self._get_instance(instance_id)
|
||||
other = self._get_instance(other_id)
|
||||
if not other:
|
||||
return False
|
||||
return instance.clone_has_same_weights(other)
|
||||
|
||||
async def load_list_internal(self, instance_id: str, *args, **kwargs) -> Any:
|
||||
return self._get_instance(instance_id)._load_list(*args, **kwargs)
|
||||
|
||||
async def is_clone_by_id(self, instance_id: str, other_id: str) -> bool:
|
||||
instance = self._get_instance(instance_id)
|
||||
other = self._get_instance(other_id)
|
||||
if hasattr(instance, "is_clone"):
|
||||
return instance.is_clone(other)
|
||||
return False
|
||||
|
||||
async def add_object_patch(self, instance_id: str, name: str, obj: Any) -> None:
|
||||
self._get_instance(instance_id).add_object_patch(name, obj)
|
||||
|
||||
async def add_weight_wrapper(
|
||||
self, instance_id: str, name: str, function: Any
|
||||
) -> None:
|
||||
self._get_instance(instance_id).add_weight_wrapper(name, function)
|
||||
|
||||
async def add_wrapper_with_key(
|
||||
self, instance_id: str, wrapper_type: Any, key: str, fn: Any
|
||||
) -> None:
|
||||
self._get_instance(instance_id).add_wrapper_with_key(wrapper_type, key, fn)
|
||||
|
||||
async def remove_wrappers_with_key(
|
||||
self, instance_id: str, wrapper_type: str, key: str
|
||||
) -> None:
|
||||
self._get_instance(instance_id).remove_wrappers_with_key(wrapper_type, key)
|
||||
|
||||
async def get_wrappers(
|
||||
self, instance_id: str, wrapper_type: str = None, key: str = None
|
||||
) -> Any:
|
||||
if wrapper_type is None and key is None:
|
||||
return self._sanitize_rpc_result(
|
||||
getattr(self._get_instance(instance_id), "wrappers", {})
|
||||
)
|
||||
return self._sanitize_rpc_result(
|
||||
self._get_instance(instance_id).get_wrappers(wrapper_type, key)
|
||||
)
|
||||
|
||||
async def get_all_wrappers(self, instance_id: str, wrapper_type: str = None) -> Any:
|
||||
return self._sanitize_rpc_result(
|
||||
getattr(self._get_instance(instance_id), "get_all_wrappers", lambda x: [])(
|
||||
wrapper_type
|
||||
)
|
||||
)
|
||||
|
||||
async def add_callback_with_key(
|
||||
self, instance_id: str, call_type: str, key: str, callback: Any
|
||||
) -> None:
|
||||
self._get_instance(instance_id).add_callback_with_key(call_type, key, callback)
|
||||
|
||||
async def remove_callbacks_with_key(
|
||||
self, instance_id: str, call_type: str, key: str
|
||||
) -> None:
|
||||
self._get_instance(instance_id).remove_callbacks_with_key(call_type, key)
|
||||
|
||||
async def get_callbacks(
|
||||
self, instance_id: str, call_type: str = None, key: str = None
|
||||
) -> Any:
|
||||
if call_type is None and key is None:
|
||||
return self._sanitize_rpc_result(
|
||||
getattr(self._get_instance(instance_id), "callbacks", {})
|
||||
)
|
||||
return self._sanitize_rpc_result(
|
||||
self._get_instance(instance_id).get_callbacks(call_type, key)
|
||||
)
|
||||
|
||||
async def get_all_callbacks(self, instance_id: str, call_type: str = None) -> Any:
|
||||
return self._sanitize_rpc_result(
|
||||
getattr(self._get_instance(instance_id), "get_all_callbacks", lambda x: [])(
|
||||
call_type
|
||||
)
|
||||
)
|
||||
|
||||
async def set_attachments(
|
||||
self, instance_id: str, key: str, attachment: Any
|
||||
) -> None:
|
||||
self._get_instance(instance_id).set_attachments(key, attachment)
|
||||
|
||||
async def get_attachment(self, instance_id: str, key: str) -> Any:
|
||||
return self._sanitize_rpc_result(
|
||||
self._get_instance(instance_id).get_attachment(key)
|
||||
)
|
||||
|
||||
async def remove_attachments(self, instance_id: str, key: str) -> None:
|
||||
self._get_instance(instance_id).remove_attachments(key)
|
||||
|
||||
async def set_injections(self, instance_id: str, key: str, injections: Any) -> None:
|
||||
self._get_instance(instance_id).set_injections(key, injections)
|
||||
|
||||
async def get_injections(self, instance_id: str, key: str) -> Any:
|
||||
return self._sanitize_rpc_result(
|
||||
self._get_instance(instance_id).get_injections(key)
|
||||
)
|
||||
|
||||
async def remove_injections(self, instance_id: str, key: str) -> None:
|
||||
self._get_instance(instance_id).remove_injections(key)
|
||||
|
||||
async def set_additional_models(
|
||||
self, instance_id: str, key: str, models: Any
|
||||
) -> None:
|
||||
self._get_instance(instance_id).set_additional_models(key, models)
|
||||
|
||||
async def remove_additional_models(self, instance_id: str, key: str) -> None:
|
||||
self._get_instance(instance_id).remove_additional_models(key)
|
||||
|
||||
async def get_nested_additional_models(self, instance_id: str) -> Any:
|
||||
return self._sanitize_rpc_result(
|
||||
self._get_instance(instance_id).get_nested_additional_models()
|
||||
)
|
||||
|
||||
async def get_additional_models(self, instance_id: str) -> List[str]:
|
||||
models = self._get_instance(instance_id).get_additional_models()
|
||||
return [self.register(m) for m in models]
|
||||
|
||||
async def get_additional_models_with_key(self, instance_id: str, key: str) -> Any:
|
||||
return self._sanitize_rpc_result(
|
||||
self._get_instance(instance_id).get_additional_models_with_key(key)
|
||||
)
|
||||
|
||||
async def model_patches_models(self, instance_id: str) -> Any:
|
||||
return self._sanitize_rpc_result(
|
||||
self._get_instance(instance_id).model_patches_models()
|
||||
)
|
||||
|
||||
async def get_patches(self, instance_id: str) -> Any:
|
||||
return self._sanitize_rpc_result(self._get_instance(instance_id).patches.copy())
|
||||
|
||||
async def get_object_patches(self, instance_id: str) -> Any:
|
||||
return self._sanitize_rpc_result(
|
||||
self._get_instance(instance_id).object_patches.copy()
|
||||
)
|
||||
|
||||
async def add_patches(
|
||||
self,
|
||||
instance_id: str,
|
||||
patches: Any,
|
||||
strength_patch: float = 1.0,
|
||||
strength_model: float = 1.0,
|
||||
) -> Any:
|
||||
return self._get_instance(instance_id).add_patches(
|
||||
patches, strength_patch, strength_model
|
||||
)
|
||||
|
||||
async def get_key_patches(
|
||||
self, instance_id: str, filter_prefix: Optional[str] = None
|
||||
) -> Any:
|
||||
res = self._get_instance(instance_id).get_key_patches()
|
||||
if filter_prefix:
|
||||
res = {k: v for k, v in res.items() if k.startswith(filter_prefix)}
|
||||
safe_res = {}
|
||||
for k, v in res.items():
|
||||
safe_res[k] = [
|
||||
f"<Tensor shape={t.shape} dtype={t.dtype}>"
|
||||
if hasattr(t, "shape")
|
||||
else str(t)
|
||||
for t in v
|
||||
]
|
||||
return safe_res
|
||||
|
||||
async def add_hook_patches(
|
||||
self,
|
||||
instance_id: str,
|
||||
hook: Any,
|
||||
patches: Any,
|
||||
strength_patch: float = 1.0,
|
||||
strength_model: float = 1.0,
|
||||
) -> None:
|
||||
if hasattr(hook, "hook_ref") and isinstance(hook.hook_ref, dict):
|
||||
try:
|
||||
hook.hook_ref = tuple(sorted(hook.hook_ref.items()))
|
||||
except Exception:
|
||||
hook.hook_ref = None
|
||||
self._get_instance(instance_id).add_hook_patches(
|
||||
hook, patches, strength_patch, strength_model
|
||||
)
|
||||
|
||||
async def get_combined_hook_patches(self, instance_id: str, hooks: Any) -> Any:
|
||||
if hooks is not None and hasattr(hooks, "hooks"):
|
||||
for hook in getattr(hooks, "hooks", []):
|
||||
hook_ref = getattr(hook, "hook_ref", None)
|
||||
if isinstance(hook_ref, dict):
|
||||
try:
|
||||
hook.hook_ref = tuple(sorted(hook_ref.items()))
|
||||
except Exception:
|
||||
hook.hook_ref = None
|
||||
res = self._get_instance(instance_id).get_combined_hook_patches(hooks)
|
||||
return self._sanitize_rpc_result(res)
|
||||
|
||||
async def clear_cached_hook_weights(self, instance_id: str) -> None:
|
||||
self._get_instance(instance_id).clear_cached_hook_weights()
|
||||
|
||||
async def prepare_hook_patches_current_keyframe(
|
||||
self, instance_id: str, t: Any, hook_group: Any, model_options: Any
|
||||
) -> None:
|
||||
self._get_instance(instance_id).prepare_hook_patches_current_keyframe(
|
||||
t, hook_group, model_options
|
||||
)
|
||||
|
||||
async def get_parent(self, instance_id: str) -> Any:
|
||||
return getattr(self._get_instance(instance_id), "parent", None)
|
||||
|
||||
async def patch_weight_to_device(
|
||||
self,
|
||||
instance_id: str,
|
||||
key: str,
|
||||
device_to: Any = None,
|
||||
inplace_update: bool = False,
|
||||
) -> None:
|
||||
self._get_instance(instance_id).patch_weight_to_device(
|
||||
key, device_to, inplace_update
|
||||
)
|
||||
|
||||
async def pin_weight_to_device(self, instance_id: str, key: str) -> None:
|
||||
instance = self._get_instance(instance_id)
|
||||
if hasattr(instance, "pinned") and isinstance(instance.pinned, list):
|
||||
instance.pinned = set(instance.pinned)
|
||||
instance.pin_weight_to_device(key)
|
||||
|
||||
async def unpin_weight(self, instance_id: str, key: str) -> None:
|
||||
instance = self._get_instance(instance_id)
|
||||
if hasattr(instance, "pinned") and isinstance(instance.pinned, list):
|
||||
instance.pinned = set(instance.pinned)
|
||||
instance.unpin_weight(key)
|
||||
|
||||
async def unpin_all_weights(self, instance_id: str) -> None:
|
||||
instance = self._get_instance(instance_id)
|
||||
if hasattr(instance, "pinned") and isinstance(instance.pinned, list):
|
||||
instance.pinned = set(instance.pinned)
|
||||
instance.unpin_all_weights()
|
||||
|
||||
async def calculate_weight(
|
||||
self,
|
||||
instance_id: str,
|
||||
patches: Any,
|
||||
weight: Any,
|
||||
key: str,
|
||||
intermediate_dtype: Any = float,
|
||||
) -> Any:
|
||||
return detach_if_grad(
|
||||
self._get_instance(instance_id).calculate_weight(
|
||||
patches, weight, key, intermediate_dtype
|
||||
)
|
||||
)
|
||||
|
||||
async def get_inner_model_attr(self, instance_id: str, name: str) -> Any:
|
||||
try:
|
||||
return self._sanitize_rpc_result(
|
||||
getattr(self._get_instance(instance_id).model, name)
|
||||
)
|
||||
except AttributeError:
|
||||
return None
|
||||
|
||||
async def inner_model_memory_required(
|
||||
self, instance_id: str, args: tuple, kwargs: dict
|
||||
) -> Any:
|
||||
return self._get_instance(instance_id).model.memory_required(*args, **kwargs)
|
||||
|
||||
async def inner_model_extra_conds_shapes(
|
||||
self, instance_id: str, args: tuple, kwargs: dict
|
||||
) -> Any:
|
||||
return self._get_instance(instance_id).model.extra_conds_shapes(*args, **kwargs)
|
||||
|
||||
async def inner_model_extra_conds(
|
||||
self, instance_id: str, args: tuple, kwargs: dict
|
||||
) -> Any:
|
||||
return self._get_instance(instance_id).model.extra_conds(*args, **kwargs)
|
||||
|
||||
async def inner_model_state_dict(
|
||||
self, instance_id: str, args: tuple, kwargs: dict
|
||||
) -> Any:
|
||||
sd = self._get_instance(instance_id).model.state_dict(*args, **kwargs)
|
||||
return {
|
||||
k: {"numel": v.numel(), "element_size": v.element_size()}
|
||||
for k, v in sd.items()
|
||||
}
|
||||
|
||||
async def inner_model_apply_model(
|
||||
self, instance_id: str, args: tuple, kwargs: dict
|
||||
) -> Any:
|
||||
instance = self._get_instance(instance_id)
|
||||
target = getattr(instance, "load_device", None)
|
||||
if target is None and args and hasattr(args[0], "device"):
|
||||
target = args[0].device
|
||||
elif target is None:
|
||||
for v in kwargs.values():
|
||||
if hasattr(v, "device"):
|
||||
target = v.device
|
||||
break
|
||||
|
||||
def _move(obj):
|
||||
if target is None:
|
||||
return obj
|
||||
if isinstance(obj, (tuple, list)):
|
||||
return type(obj)(_move(o) for o in obj)
|
||||
if hasattr(obj, "to"):
|
||||
return obj.to(target)
|
||||
return obj
|
||||
|
||||
moved_args = tuple(_move(a) for a in args)
|
||||
moved_kwargs = {k: _move(v) for k, v in kwargs.items()}
|
||||
result = instance.model.apply_model(*moved_args, **moved_kwargs)
|
||||
return detach_if_grad(_move(result))
|
||||
|
||||
async def process_latent_in(
|
||||
self, instance_id: str, args: tuple, kwargs: dict
|
||||
) -> Any:
|
||||
return detach_if_grad(
|
||||
self._get_instance(instance_id).model.process_latent_in(*args, **kwargs)
|
||||
)
|
||||
|
||||
async def process_latent_out(
|
||||
self, instance_id: str, args: tuple, kwargs: dict
|
||||
) -> Any:
|
||||
instance = self._get_instance(instance_id)
|
||||
result = instance.model.process_latent_out(*args, **kwargs)
|
||||
try:
|
||||
target = None
|
||||
if args and hasattr(args[0], "device"):
|
||||
target = args[0].device
|
||||
elif kwargs:
|
||||
for v in kwargs.values():
|
||||
if hasattr(v, "device"):
|
||||
target = v.device
|
||||
break
|
||||
if target is not None and hasattr(result, "to"):
|
||||
return detach_if_grad(result.to(target))
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"process_latent_out: failed to move result to target device",
|
||||
exc_info=True,
|
||||
)
|
||||
return detach_if_grad(result)
|
||||
|
||||
async def scale_latent_inpaint(
|
||||
self, instance_id: str, args: tuple, kwargs: dict
|
||||
) -> Any:
|
||||
instance = self._get_instance(instance_id)
|
||||
result = instance.model.scale_latent_inpaint(*args, **kwargs)
|
||||
try:
|
||||
target = None
|
||||
if args and hasattr(args[0], "device"):
|
||||
target = args[0].device
|
||||
elif kwargs:
|
||||
for v in kwargs.values():
|
||||
if hasattr(v, "device"):
|
||||
target = v.device
|
||||
break
|
||||
if target is not None and hasattr(result, "to"):
|
||||
return detach_if_grad(result.to(target))
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"scale_latent_inpaint: failed to move result to target device",
|
||||
exc_info=True,
|
||||
)
|
||||
return detach_if_grad(result)
|
||||
|
||||
async def load_lora(
|
||||
self,
|
||||
instance_id: str,
|
||||
lora_path: str,
|
||||
strength_model: float,
|
||||
clip_id: Optional[str] = None,
|
||||
strength_clip: float = 1.0,
|
||||
) -> dict:
|
||||
import comfy.utils
|
||||
import comfy.sd
|
||||
import folder_paths
|
||||
from comfy.isolation.clip_proxy import CLIPRegistry
|
||||
|
||||
model = self._get_instance(instance_id)
|
||||
clip = None
|
||||
if clip_id:
|
||||
clip = CLIPRegistry()._get_instance(clip_id)
|
||||
lora_full_path = folder_paths.get_full_path("loras", lora_path)
|
||||
if lora_full_path is None:
|
||||
raise ValueError(f"LoRA file not found: {lora_path}")
|
||||
lora = comfy.utils.load_torch_file(lora_full_path)
|
||||
new_model, new_clip = comfy.sd.load_lora_for_models(
|
||||
model, clip, lora, strength_model, strength_clip
|
||||
)
|
||||
new_model_id = self.register(new_model) if new_model else None
|
||||
new_clip_id = (
|
||||
CLIPRegistry().register(new_clip) if (new_clip and clip_id) else None
|
||||
)
|
||||
return {"model_id": new_model_id, "clip_id": new_clip_id}
|
||||
154
comfy/isolation/model_patcher_proxy_utils.py
Normal file
154
comfy/isolation/model_patcher_proxy_utils.py
Normal file
@@ -0,0 +1,154 @@
|
||||
# pylint: disable=import-outside-toplevel,logging-fstring-interpolation,protected-access
|
||||
# Isolation utilities and serializers for ModelPatcherProxy
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def maybe_wrap_model_for_isolation(model_patcher: Any) -> Any:
|
||||
from comfy.isolation.model_patcher_proxy_registry import ModelPatcherRegistry
|
||||
from comfy.isolation.model_patcher_proxy import ModelPatcherProxy
|
||||
|
||||
isolation_active = os.environ.get("PYISOLATE_ISOLATION_ACTIVE") == "1"
|
||||
is_child = os.environ.get("PYISOLATE_CHILD") == "1"
|
||||
|
||||
if not isolation_active:
|
||||
return model_patcher
|
||||
if is_child:
|
||||
return model_patcher
|
||||
if isinstance(model_patcher, ModelPatcherProxy):
|
||||
return model_patcher
|
||||
|
||||
registry = ModelPatcherRegistry()
|
||||
model_id = registry.register(model_patcher)
|
||||
logger.debug(f"Isolated ModelPatcher: {model_id}")
|
||||
return ModelPatcherProxy(model_id, registry, manage_lifecycle=True)
|
||||
|
||||
|
||||
def register_hooks_serializers(registry=None):
|
||||
from pyisolate._internal.serialization_registry import SerializerRegistry
|
||||
import comfy.hooks
|
||||
|
||||
if registry is None:
|
||||
registry = SerializerRegistry.get_instance()
|
||||
|
||||
def serialize_enum(obj):
|
||||
return {"__enum__": f"{type(obj).__name__}.{obj.name}"}
|
||||
|
||||
def deserialize_enum(data):
|
||||
cls_name, val_name = data["__enum__"].split(".")
|
||||
cls = getattr(comfy.hooks, cls_name)
|
||||
return cls[val_name]
|
||||
|
||||
registry.register("EnumHookType", serialize_enum, deserialize_enum)
|
||||
registry.register("EnumHookScope", serialize_enum, deserialize_enum)
|
||||
registry.register("EnumHookMode", serialize_enum, deserialize_enum)
|
||||
registry.register("EnumWeightTarget", serialize_enum, deserialize_enum)
|
||||
|
||||
def serialize_hook_group(obj):
|
||||
return {"__type__": "HookGroup", "hooks": obj.hooks}
|
||||
|
||||
def deserialize_hook_group(data):
|
||||
hg = comfy.hooks.HookGroup()
|
||||
for h in data["hooks"]:
|
||||
hg.add(h)
|
||||
return hg
|
||||
|
||||
registry.register("HookGroup", serialize_hook_group, deserialize_hook_group)
|
||||
|
||||
def serialize_dict_state(obj):
|
||||
d = obj.__dict__.copy()
|
||||
d["__type__"] = type(obj).__name__
|
||||
if "custom_should_register" in d:
|
||||
del d["custom_should_register"]
|
||||
return d
|
||||
|
||||
def deserialize_dict_state_generic(cls):
|
||||
def _deserialize(data):
|
||||
h = cls()
|
||||
h.__dict__.update(data)
|
||||
return h
|
||||
|
||||
return _deserialize
|
||||
|
||||
def deserialize_hook_keyframe(data):
|
||||
h = comfy.hooks.HookKeyframe(strength=data.get("strength", 1.0))
|
||||
h.__dict__.update(data)
|
||||
return h
|
||||
|
||||
registry.register("HookKeyframe", serialize_dict_state, deserialize_hook_keyframe)
|
||||
|
||||
def deserialize_hook_keyframe_group(data):
|
||||
h = comfy.hooks.HookKeyframeGroup()
|
||||
h.__dict__.update(data)
|
||||
return h
|
||||
|
||||
registry.register(
|
||||
"HookKeyframeGroup", serialize_dict_state, deserialize_hook_keyframe_group
|
||||
)
|
||||
|
||||
def deserialize_hook(data):
|
||||
h = comfy.hooks.Hook()
|
||||
h.__dict__.update(data)
|
||||
return h
|
||||
|
||||
registry.register("Hook", serialize_dict_state, deserialize_hook)
|
||||
|
||||
def deserialize_weight_hook(data):
|
||||
h = comfy.hooks.WeightHook()
|
||||
h.__dict__.update(data)
|
||||
return h
|
||||
|
||||
registry.register("WeightHook", serialize_dict_state, deserialize_weight_hook)
|
||||
|
||||
def serialize_set(obj):
|
||||
return {"__set__": list(obj)}
|
||||
|
||||
def deserialize_set(data):
|
||||
return set(data["__set__"])
|
||||
|
||||
registry.register("set", serialize_set, deserialize_set)
|
||||
|
||||
try:
|
||||
from comfy.weight_adapter.lora import LoRAAdapter
|
||||
|
||||
def serialize_lora(obj):
|
||||
return {"weights": {}, "loaded_keys": list(obj.loaded_keys)}
|
||||
|
||||
def deserialize_lora(data):
|
||||
return LoRAAdapter(set(data["loaded_keys"]), data["weights"])
|
||||
|
||||
registry.register("LoRAAdapter", serialize_lora, deserialize_lora)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
from comfy.hooks import _HookRef
|
||||
import uuid
|
||||
|
||||
def serialize_hook_ref(obj):
|
||||
return {
|
||||
"__hook_ref__": True,
|
||||
"id": getattr(obj, "_pyisolate_id", str(uuid.uuid4())),
|
||||
}
|
||||
|
||||
def deserialize_hook_ref(data):
|
||||
h = _HookRef()
|
||||
h._pyisolate_id = data.get("id", str(uuid.uuid4()))
|
||||
return h
|
||||
|
||||
registry.register("_HookRef", serialize_hook_ref, deserialize_hook_ref)
|
||||
except ImportError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to register _HookRef: {e}")
|
||||
|
||||
|
||||
try:
|
||||
register_hooks_serializers()
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize hook serializers: {e}")
|
||||
253
comfy/isolation/model_sampling_proxy.py
Normal file
253
comfy/isolation/model_sampling_proxy.py
Normal file
@@ -0,0 +1,253 @@
|
||||
# pylint: disable=import-outside-toplevel
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from comfy.isolation.proxies.base import (
|
||||
BaseProxy,
|
||||
BaseRegistry,
|
||||
detach_if_grad,
|
||||
get_thread_loop,
|
||||
run_coro_in_new_loop,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _prefer_device(*tensors: Any) -> Any:
|
||||
try:
|
||||
import torch
|
||||
except Exception:
|
||||
return None
|
||||
for t in tensors:
|
||||
if isinstance(t, torch.Tensor) and t.is_cuda:
|
||||
return t.device
|
||||
for t in tensors:
|
||||
if isinstance(t, torch.Tensor):
|
||||
return t.device
|
||||
return None
|
||||
|
||||
|
||||
def _to_device(obj: Any, device: Any) -> Any:
|
||||
try:
|
||||
import torch
|
||||
except Exception:
|
||||
return obj
|
||||
if device is None:
|
||||
return obj
|
||||
if isinstance(obj, torch.Tensor):
|
||||
if obj.device != device:
|
||||
return obj.to(device)
|
||||
return obj
|
||||
if isinstance(obj, (list, tuple)):
|
||||
converted = [_to_device(x, device) for x in obj]
|
||||
return type(obj)(converted) if isinstance(obj, tuple) else converted
|
||||
if isinstance(obj, dict):
|
||||
return {k: _to_device(v, device) for k, v in obj.items()}
|
||||
return obj
|
||||
|
||||
|
||||
class ModelSamplingRegistry(BaseRegistry[Any]):
|
||||
_type_prefix = "modelsampling"
|
||||
|
||||
async def calculate_input(self, instance_id: str, sigma: Any, noise: Any) -> Any:
|
||||
sampling = self._get_instance(instance_id)
|
||||
return detach_if_grad(sampling.calculate_input(sigma, noise))
|
||||
|
||||
async def calculate_denoised(
|
||||
self, instance_id: str, sigma: Any, model_output: Any, model_input: Any
|
||||
) -> Any:
|
||||
sampling = self._get_instance(instance_id)
|
||||
return detach_if_grad(
|
||||
sampling.calculate_denoised(sigma, model_output, model_input)
|
||||
)
|
||||
|
||||
async def noise_scaling(
|
||||
self,
|
||||
instance_id: str,
|
||||
sigma: Any,
|
||||
noise: Any,
|
||||
latent_image: Any,
|
||||
max_denoise: bool = False,
|
||||
) -> Any:
|
||||
sampling = self._get_instance(instance_id)
|
||||
return detach_if_grad(
|
||||
sampling.noise_scaling(sigma, noise, latent_image, max_denoise=max_denoise)
|
||||
)
|
||||
|
||||
async def inverse_noise_scaling(
|
||||
self, instance_id: str, sigma: Any, latent: Any
|
||||
) -> Any:
|
||||
sampling = self._get_instance(instance_id)
|
||||
return detach_if_grad(sampling.inverse_noise_scaling(sigma, latent))
|
||||
|
||||
async def timestep(self, instance_id: str, sigma: Any) -> Any:
|
||||
sampling = self._get_instance(instance_id)
|
||||
return sampling.timestep(sigma)
|
||||
|
||||
async def sigma(self, instance_id: str, timestep: Any) -> Any:
|
||||
sampling = self._get_instance(instance_id)
|
||||
return sampling.sigma(timestep)
|
||||
|
||||
async def percent_to_sigma(self, instance_id: str, percent: float) -> Any:
|
||||
sampling = self._get_instance(instance_id)
|
||||
return sampling.percent_to_sigma(percent)
|
||||
|
||||
async def get_sigma_min(self, instance_id: str) -> Any:
|
||||
sampling = self._get_instance(instance_id)
|
||||
return detach_if_grad(sampling.sigma_min)
|
||||
|
||||
async def get_sigma_max(self, instance_id: str) -> Any:
|
||||
sampling = self._get_instance(instance_id)
|
||||
return detach_if_grad(sampling.sigma_max)
|
||||
|
||||
async def get_sigma_data(self, instance_id: str) -> Any:
|
||||
sampling = self._get_instance(instance_id)
|
||||
return detach_if_grad(sampling.sigma_data)
|
||||
|
||||
async def get_sigmas(self, instance_id: str) -> Any:
|
||||
sampling = self._get_instance(instance_id)
|
||||
return detach_if_grad(sampling.sigmas)
|
||||
|
||||
async def set_sigmas(self, instance_id: str, sigmas: Any) -> None:
|
||||
sampling = self._get_instance(instance_id)
|
||||
sampling.set_sigmas(sigmas)
|
||||
|
||||
|
||||
class ModelSamplingProxy(BaseProxy[ModelSamplingRegistry]):
|
||||
_registry_class = ModelSamplingRegistry
|
||||
__module__ = "comfy.isolation.model_sampling_proxy"
|
||||
|
||||
def _get_rpc(self) -> Any:
|
||||
if self._rpc_caller is None:
|
||||
from pyisolate._internal.rpc_protocol import get_child_rpc_instance
|
||||
|
||||
rpc = get_child_rpc_instance()
|
||||
if rpc is not None:
|
||||
self._rpc_caller = rpc.create_caller(
|
||||
ModelSamplingRegistry, ModelSamplingRegistry.get_remote_id()
|
||||
)
|
||||
else:
|
||||
registry = ModelSamplingRegistry()
|
||||
|
||||
class _LocalCaller:
|
||||
def calculate_input(
|
||||
self, instance_id: str, sigma: Any, noise: Any
|
||||
) -> Any:
|
||||
return registry.calculate_input(instance_id, sigma, noise)
|
||||
|
||||
def calculate_denoised(
|
||||
self,
|
||||
instance_id: str,
|
||||
sigma: Any,
|
||||
model_output: Any,
|
||||
model_input: Any,
|
||||
) -> Any:
|
||||
return registry.calculate_denoised(
|
||||
instance_id, sigma, model_output, model_input
|
||||
)
|
||||
|
||||
def noise_scaling(
|
||||
self,
|
||||
instance_id: str,
|
||||
sigma: Any,
|
||||
noise: Any,
|
||||
latent_image: Any,
|
||||
max_denoise: bool = False,
|
||||
) -> Any:
|
||||
return registry.noise_scaling(
|
||||
instance_id, sigma, noise, latent_image, max_denoise
|
||||
)
|
||||
|
||||
def inverse_noise_scaling(
|
||||
self, instance_id: str, sigma: Any, latent: Any
|
||||
) -> Any:
|
||||
return registry.inverse_noise_scaling(
|
||||
instance_id, sigma, latent
|
||||
)
|
||||
|
||||
def timestep(self, instance_id: str, sigma: Any) -> Any:
|
||||
return registry.timestep(instance_id, sigma)
|
||||
|
||||
def sigma(self, instance_id: str, timestep: Any) -> Any:
|
||||
return registry.sigma(instance_id, timestep)
|
||||
|
||||
def percent_to_sigma(self, instance_id: str, percent: float) -> Any:
|
||||
return registry.percent_to_sigma(instance_id, percent)
|
||||
|
||||
def get_sigma_min(self, instance_id: str) -> Any:
|
||||
return registry.get_sigma_min(instance_id)
|
||||
|
||||
def get_sigma_max(self, instance_id: str) -> Any:
|
||||
return registry.get_sigma_max(instance_id)
|
||||
|
||||
def get_sigma_data(self, instance_id: str) -> Any:
|
||||
return registry.get_sigma_data(instance_id)
|
||||
|
||||
def get_sigmas(self, instance_id: str) -> Any:
|
||||
return registry.get_sigmas(instance_id)
|
||||
|
||||
def set_sigmas(self, instance_id: str, sigmas: Any) -> None:
|
||||
return registry.set_sigmas(instance_id, sigmas)
|
||||
|
||||
self._rpc_caller = _LocalCaller()
|
||||
return self._rpc_caller
|
||||
|
||||
def _call(self, method_name: str, *args: Any) -> Any:
|
||||
rpc = self._get_rpc()
|
||||
method = getattr(rpc, method_name)
|
||||
result = method(self._instance_id, *args)
|
||||
if asyncio.iscoroutine(result):
|
||||
try:
|
||||
asyncio.get_running_loop()
|
||||
return run_coro_in_new_loop(result)
|
||||
except RuntimeError:
|
||||
loop = get_thread_loop()
|
||||
return loop.run_until_complete(result)
|
||||
return result
|
||||
|
||||
@property
|
||||
def sigma_min(self) -> Any:
|
||||
return self._call("get_sigma_min")
|
||||
|
||||
@property
|
||||
def sigma_max(self) -> Any:
|
||||
return self._call("get_sigma_max")
|
||||
|
||||
@property
|
||||
def sigma_data(self) -> Any:
|
||||
return self._call("get_sigma_data")
|
||||
|
||||
@property
|
||||
def sigmas(self) -> Any:
|
||||
return self._call("get_sigmas")
|
||||
|
||||
def calculate_input(self, sigma: Any, noise: Any) -> Any:
|
||||
return self._call("calculate_input", sigma, noise)
|
||||
|
||||
def calculate_denoised(
|
||||
self, sigma: Any, model_output: Any, model_input: Any
|
||||
) -> Any:
|
||||
return self._call("calculate_denoised", sigma, model_output, model_input)
|
||||
|
||||
def noise_scaling(
|
||||
self, sigma: Any, noise: Any, latent_image: Any, max_denoise: bool = False
|
||||
) -> Any:
|
||||
return self._call("noise_scaling", sigma, noise, latent_image, max_denoise)
|
||||
|
||||
def inverse_noise_scaling(self, sigma: Any, latent: Any) -> Any:
|
||||
return self._call("inverse_noise_scaling", sigma, latent)
|
||||
|
||||
def timestep(self, sigma: Any) -> Any:
|
||||
return self._call("timestep", sigma)
|
||||
|
||||
def sigma(self, timestep: Any) -> Any:
|
||||
return self._call("sigma", timestep)
|
||||
|
||||
def percent_to_sigma(self, percent: float) -> Any:
|
||||
return self._call("percent_to_sigma", percent)
|
||||
|
||||
def set_sigmas(self, sigmas: Any) -> None:
|
||||
return self._call("set_sigmas", sigmas)
|
||||
17
comfy/isolation/proxies/__init__.py
Normal file
17
comfy/isolation/proxies/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from .base import (
|
||||
IS_CHILD_PROCESS,
|
||||
BaseProxy,
|
||||
BaseRegistry,
|
||||
detach_if_grad,
|
||||
get_thread_loop,
|
||||
run_coro_in_new_loop,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"IS_CHILD_PROCESS",
|
||||
"BaseRegistry",
|
||||
"BaseProxy",
|
||||
"get_thread_loop",
|
||||
"run_coro_in_new_loop",
|
||||
"detach_if_grad",
|
||||
]
|
||||
213
comfy/isolation/proxies/base.py
Normal file
213
comfy/isolation/proxies/base.py
Normal file
@@ -0,0 +1,213 @@
|
||||
# pylint: disable=global-statement,import-outside-toplevel,protected-access
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
import weakref
|
||||
from typing import Any, Callable, Dict, Generic, Optional, TypeVar
|
||||
|
||||
try:
|
||||
from pyisolate import ProxiedSingleton
|
||||
except ImportError:
|
||||
|
||||
class ProxiedSingleton: # type: ignore[no-redef]
|
||||
pass
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
IS_CHILD_PROCESS = os.environ.get("PYISOLATE_CHILD") == "1"
|
||||
_thread_local = threading.local()
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def get_thread_loop() -> asyncio.AbstractEventLoop:
|
||||
loop = getattr(_thread_local, "loop", None)
|
||||
if loop is None or loop.is_closed():
|
||||
loop = asyncio.new_event_loop()
|
||||
_thread_local.loop = loop
|
||||
return loop
|
||||
|
||||
|
||||
def run_coro_in_new_loop(coro: Any) -> Any:
|
||||
result_box: Dict[str, Any] = {}
|
||||
exc_box: Dict[str, BaseException] = {}
|
||||
|
||||
def runner() -> None:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
result_box["value"] = loop.run_until_complete(coro)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
exc_box["exc"] = exc
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
t = threading.Thread(target=runner, daemon=True)
|
||||
t.start()
|
||||
t.join()
|
||||
if "exc" in exc_box:
|
||||
raise exc_box["exc"]
|
||||
return result_box.get("value")
|
||||
|
||||
|
||||
def detach_if_grad(obj: Any) -> Any:
|
||||
try:
|
||||
import torch
|
||||
except Exception:
|
||||
return obj
|
||||
|
||||
if isinstance(obj, torch.Tensor):
|
||||
return obj.detach() if obj.requires_grad else obj
|
||||
if isinstance(obj, (list, tuple)):
|
||||
return type(obj)(detach_if_grad(x) for x in obj)
|
||||
if isinstance(obj, dict):
|
||||
return {k: detach_if_grad(v) for k, v in obj.items()}
|
||||
return obj
|
||||
|
||||
|
||||
class BaseRegistry(ProxiedSingleton, Generic[T]):
|
||||
_type_prefix: str = "base"
|
||||
|
||||
def __init__(self) -> None:
|
||||
if hasattr(ProxiedSingleton, "__init__") and ProxiedSingleton is not object:
|
||||
super().__init__()
|
||||
self._registry: Dict[str, T] = {}
|
||||
self._id_map: Dict[int, str] = {}
|
||||
self._counter = 0
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def register(self, instance: T) -> str:
|
||||
with self._lock:
|
||||
obj_id = id(instance)
|
||||
if obj_id in self._id_map:
|
||||
return self._id_map[obj_id]
|
||||
instance_id = f"{self._type_prefix}_{self._counter}"
|
||||
self._counter += 1
|
||||
self._registry[instance_id] = instance
|
||||
self._id_map[obj_id] = instance_id
|
||||
return instance_id
|
||||
|
||||
def unregister_sync(self, instance_id: str) -> None:
|
||||
with self._lock:
|
||||
instance = self._registry.pop(instance_id, None)
|
||||
if instance:
|
||||
self._id_map.pop(id(instance), None)
|
||||
|
||||
def _get_instance(self, instance_id: str) -> T:
|
||||
if IS_CHILD_PROCESS:
|
||||
raise RuntimeError(
|
||||
f"[{self.__class__.__name__}] _get_instance called in child"
|
||||
)
|
||||
with self._lock:
|
||||
instance = self._registry.get(instance_id)
|
||||
if instance is None:
|
||||
raise ValueError(f"{instance_id} not found")
|
||||
return instance
|
||||
|
||||
|
||||
_GLOBAL_LOOP: Optional[asyncio.AbstractEventLoop] = None
|
||||
|
||||
|
||||
def set_global_loop(loop: asyncio.AbstractEventLoop) -> None:
|
||||
global _GLOBAL_LOOP
|
||||
_GLOBAL_LOOP = loop
|
||||
|
||||
|
||||
class BaseProxy(Generic[T]):
|
||||
_registry_class: type = BaseRegistry # type: ignore[type-arg]
|
||||
__module__: str = "comfy.isolation.proxies.base"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
instance_id: str,
|
||||
registry: Optional[Any] = None,
|
||||
manage_lifecycle: bool = False,
|
||||
) -> None:
|
||||
self._instance_id = instance_id
|
||||
self._rpc_caller: Optional[Any] = None
|
||||
self._registry = registry if registry is not None else self._registry_class()
|
||||
self._manage_lifecycle = manage_lifecycle
|
||||
self._cleaned_up = False
|
||||
if manage_lifecycle and not IS_CHILD_PROCESS:
|
||||
self._finalizer = weakref.finalize(
|
||||
self, self._registry.unregister_sync, instance_id
|
||||
)
|
||||
|
||||
def _get_rpc(self) -> Any:
|
||||
if self._rpc_caller is None:
|
||||
from pyisolate._internal.rpc_protocol import get_child_rpc_instance
|
||||
|
||||
rpc = get_child_rpc_instance()
|
||||
if rpc is None:
|
||||
raise RuntimeError(f"[{self.__class__.__name__}] No RPC in child")
|
||||
self._rpc_caller = rpc.create_caller(
|
||||
self._registry_class, self._registry_class.get_remote_id()
|
||||
)
|
||||
return self._rpc_caller
|
||||
|
||||
def _call_rpc(self, method_name: str, *args: Any, **kwargs: Any) -> Any:
|
||||
rpc = self._get_rpc()
|
||||
method = getattr(rpc, method_name)
|
||||
coro = method(self._instance_id, *args, **kwargs)
|
||||
|
||||
# If we have a global loop (Main Thread Loop), use it for dispatch from worker threads
|
||||
if _GLOBAL_LOOP is not None and _GLOBAL_LOOP.is_running():
|
||||
try:
|
||||
# If we are already in the global loop, we can't block on it?
|
||||
# Actually, this method is synchronous (__getattr__ -> lambda).
|
||||
# If called from async context in main loop, we need to handle that.
|
||||
curr_loop = asyncio.get_running_loop()
|
||||
if curr_loop is _GLOBAL_LOOP:
|
||||
# We are in the main loop. We cannot await/block here if we are just a sync function.
|
||||
# But proxies are often called from sync code.
|
||||
# If called from sync code in main loop, creating a new loop is bad.
|
||||
# But we can't await `coro`.
|
||||
# This implies proxies MUST be awaited if called from async context?
|
||||
# Existing code used `run_coro_in_new_loop` which is weird.
|
||||
# Let's trust that if we are in a thread (RuntimeError on get_running_loop),
|
||||
# we use run_coroutine_threadsafe.
|
||||
pass
|
||||
except RuntimeError:
|
||||
# No running loop - we are in a worker thread.
|
||||
future = asyncio.run_coroutine_threadsafe(coro, _GLOBAL_LOOP)
|
||||
return future.result()
|
||||
|
||||
try:
|
||||
asyncio.get_running_loop()
|
||||
return run_coro_in_new_loop(coro)
|
||||
except RuntimeError:
|
||||
loop = get_thread_loop()
|
||||
return loop.run_until_complete(coro)
|
||||
|
||||
def __getstate__(self) -> Dict[str, Any]:
|
||||
return {"_instance_id": self._instance_id}
|
||||
|
||||
def __setstate__(self, state: Dict[str, Any]) -> None:
|
||||
self._instance_id = state["_instance_id"]
|
||||
self._rpc_caller = None
|
||||
self._registry = self._registry_class()
|
||||
self._manage_lifecycle = False
|
||||
self._cleaned_up = False
|
||||
|
||||
def cleanup(self) -> None:
|
||||
if self._cleaned_up or IS_CHILD_PROCESS:
|
||||
return
|
||||
self._cleaned_up = True
|
||||
finalizer = getattr(self, "_finalizer", None)
|
||||
if finalizer is not None:
|
||||
finalizer.detach()
|
||||
self._registry.unregister_sync(self._instance_id)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<{self.__class__.__name__} {self._instance_id}>"
|
||||
|
||||
|
||||
def create_rpc_method(method_name: str) -> Callable[..., Any]:
|
||||
def method(self: BaseProxy[Any], *args: Any, **kwargs: Any) -> Any:
|
||||
return self._call_rpc(method_name, *args, **kwargs)
|
||||
|
||||
method.__name__ = method_name
|
||||
return method
|
||||
29
comfy/isolation/proxies/folder_paths_proxy.py
Normal file
29
comfy/isolation/proxies/folder_paths_proxy.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from __future__ import annotations
|
||||
from typing import Dict
|
||||
|
||||
import folder_paths
|
||||
from pyisolate import ProxiedSingleton
|
||||
|
||||
|
||||
class FolderPathsProxy(ProxiedSingleton):
|
||||
"""
|
||||
Dynamic proxy for folder_paths.
|
||||
Uses __getattr__ for most lookups, with explicit handling for
|
||||
mutable collections to ensure efficient by-value transfer.
|
||||
"""
|
||||
|
||||
def __getattr__(self, name):
|
||||
return getattr(folder_paths, name)
|
||||
|
||||
# Return dict snapshots (avoid RPC chatter)
|
||||
@property
|
||||
def folder_names_and_paths(self) -> Dict:
|
||||
return dict(folder_paths.folder_names_and_paths)
|
||||
|
||||
@property
|
||||
def extension_mimetypes_cache(self) -> Dict:
|
||||
return dict(folder_paths.extension_mimetypes_cache)
|
||||
|
||||
@property
|
||||
def filename_list_cache(self) -> Dict:
|
||||
return dict(folder_paths.filename_list_cache)
|
||||
98
comfy/isolation/proxies/helper_proxies.py
Normal file
98
comfy/isolation/proxies/helper_proxies.py
Normal file
@@ -0,0 +1,98 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
|
||||
class AnyTypeProxy(str):
|
||||
"""Replacement for custom AnyType objects used by some nodes."""
|
||||
|
||||
def __new__(cls, value: str = "*"):
|
||||
return super().__new__(cls, value)
|
||||
|
||||
def __ne__(self, other): # type: ignore[override]
|
||||
return False
|
||||
|
||||
|
||||
class FlexibleOptionalInputProxy(dict):
|
||||
"""Replacement for FlexibleOptionalInputType to allow dynamic inputs."""
|
||||
|
||||
def __init__(self, flex_type, data: Optional[Dict[str, object]] = None):
|
||||
super().__init__()
|
||||
self.type = flex_type
|
||||
if data:
|
||||
self.update(data)
|
||||
|
||||
def __getitem__(self, key): # type: ignore[override]
|
||||
return (self.type,)
|
||||
|
||||
def __contains__(self, key): # type: ignore[override]
|
||||
return True
|
||||
|
||||
|
||||
class ByPassTypeTupleProxy(tuple):
|
||||
"""Replacement for ByPassTypeTuple to mirror wildcard fallback behavior."""
|
||||
|
||||
def __new__(cls, values):
|
||||
return super().__new__(cls, values)
|
||||
|
||||
def __getitem__(self, index): # type: ignore[override]
|
||||
if index >= len(self):
|
||||
return AnyTypeProxy("*")
|
||||
return super().__getitem__(index)
|
||||
|
||||
|
||||
def _restore_special_value(value: Any) -> Any:
|
||||
if isinstance(value, dict):
|
||||
if value.get("__pyisolate_any_type__"):
|
||||
return AnyTypeProxy(value.get("value", "*"))
|
||||
if value.get("__pyisolate_flexible_optional__"):
|
||||
flex_type = _restore_special_value(value.get("type"))
|
||||
data_raw = value.get("data")
|
||||
data = (
|
||||
{k: _restore_special_value(v) for k, v in data_raw.items()}
|
||||
if isinstance(data_raw, dict)
|
||||
else {}
|
||||
)
|
||||
return FlexibleOptionalInputProxy(flex_type, data)
|
||||
if value.get("__pyisolate_tuple__") is not None:
|
||||
return tuple(
|
||||
_restore_special_value(v) for v in value["__pyisolate_tuple__"]
|
||||
)
|
||||
if value.get("__pyisolate_bypass_tuple__") is not None:
|
||||
return ByPassTypeTupleProxy(
|
||||
tuple(
|
||||
_restore_special_value(v)
|
||||
for v in value["__pyisolate_bypass_tuple__"]
|
||||
)
|
||||
)
|
||||
return {k: _restore_special_value(v) for k, v in value.items()}
|
||||
if isinstance(value, list):
|
||||
return [_restore_special_value(v) for v in value]
|
||||
return value
|
||||
|
||||
|
||||
def restore_input_types(raw: Dict[str, object]) -> Dict[str, object]:
|
||||
"""Restore serialized INPUT_TYPES payload back into ComfyUI-compatible objects."""
|
||||
|
||||
if not isinstance(raw, dict):
|
||||
return raw # type: ignore[return-value]
|
||||
|
||||
restored: Dict[str, object] = {}
|
||||
for section, entries in raw.items():
|
||||
if isinstance(entries, dict) and entries.get("__pyisolate_flexible_optional__"):
|
||||
restored[section] = _restore_special_value(entries)
|
||||
elif isinstance(entries, dict):
|
||||
restored[section] = {
|
||||
k: _restore_special_value(v) for k, v in entries.items()
|
||||
}
|
||||
else:
|
||||
restored[section] = _restore_special_value(entries)
|
||||
return restored
|
||||
|
||||
|
||||
__all__ = [
|
||||
"AnyTypeProxy",
|
||||
"FlexibleOptionalInputProxy",
|
||||
"ByPassTypeTupleProxy",
|
||||
"restore_input_types",
|
||||
]
|
||||
27
comfy/isolation/proxies/model_management_proxy.py
Normal file
27
comfy/isolation/proxies/model_management_proxy.py
Normal file
@@ -0,0 +1,27 @@
|
||||
import comfy.model_management as mm
|
||||
from pyisolate import ProxiedSingleton
|
||||
|
||||
|
||||
class ModelManagementProxy(ProxiedSingleton):
|
||||
"""
|
||||
Dynamic proxy for comfy.model_management.
|
||||
Uses __getattr__ to forward all calls to the underlying module,
|
||||
reducing maintenance burden.
|
||||
"""
|
||||
|
||||
# Explicitly expose Enums/Classes as properties
|
||||
@property
|
||||
def VRAMState(self):
|
||||
return mm.VRAMState
|
||||
|
||||
@property
|
||||
def CPUState(self):
|
||||
return mm.CPUState
|
||||
|
||||
@property
|
||||
def OOM_EXCEPTION(self):
|
||||
return mm.OOM_EXCEPTION
|
||||
|
||||
def __getattr__(self, name):
|
||||
"""Forward all other attribute access to the module."""
|
||||
return getattr(mm, name)
|
||||
35
comfy/isolation/proxies/progress_proxy.py
Normal file
35
comfy/isolation/proxies/progress_proxy.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
|
||||
try:
|
||||
from pyisolate import ProxiedSingleton
|
||||
except ImportError:
|
||||
|
||||
class ProxiedSingleton:
|
||||
pass
|
||||
|
||||
|
||||
from comfy_execution.progress import get_progress_state
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ProgressProxy(ProxiedSingleton):
|
||||
def set_progress(
|
||||
self,
|
||||
value: float,
|
||||
max_value: float,
|
||||
node_id: Optional[str] = None,
|
||||
image: Any = None,
|
||||
) -> None:
|
||||
get_progress_state().update_progress(
|
||||
node_id=node_id,
|
||||
value=value,
|
||||
max_value=max_value,
|
||||
image=image,
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["ProgressProxy"]
|
||||
265
comfy/isolation/proxies/prompt_server_impl.py
Normal file
265
comfy/isolation/proxies/prompt_server_impl.py
Normal file
@@ -0,0 +1,265 @@
|
||||
# pylint: disable=import-outside-toplevel,logging-fstring-interpolation,redefined-outer-name,reimported,super-init-not-called
|
||||
"""Stateless RPC Implementation for PromptServer.
|
||||
|
||||
Replaces the legacy PromptServerProxy (Singleton) with a clean Service/Stub architecture.
|
||||
- Host: PromptServerService (RPC Handler)
|
||||
- Child: PromptServerStub (Interface Implementation)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from typing import Any, Dict, Optional, Callable
|
||||
|
||||
import logging
|
||||
from aiohttp import web
|
||||
|
||||
# IMPORTS
|
||||
from pyisolate import ProxiedSingleton
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
LOG_PREFIX = "[Isolation:C<->H]"
|
||||
|
||||
# ...
|
||||
|
||||
# =============================================================================
|
||||
# CHILD SIDE: PromptServerStub
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class PromptServerStub:
|
||||
"""Stateless Stub for PromptServer."""
|
||||
|
||||
# Masquerade as the real server module
|
||||
__module__ = "server"
|
||||
|
||||
_instance: Optional["PromptServerStub"] = None
|
||||
_rpc: Optional[Any] = None # This will be the Caller object
|
||||
_source_file: Optional[str] = None
|
||||
|
||||
def __init__(self):
|
||||
self.routes = RouteStub(self)
|
||||
|
||||
@classmethod
|
||||
def set_rpc(cls, rpc: Any) -> None:
|
||||
"""Inject RPC client (called by adapter.py or manually)."""
|
||||
# Create caller for HOST Service
|
||||
# Assuming Host Service is registered as "PromptServerService" (class name)
|
||||
# We target the Host Service Class
|
||||
target_id = "PromptServerService"
|
||||
# We need to pass a class to create_caller? Usually yes.
|
||||
# But we don't have the Service class imported here necessarily (if running on child).
|
||||
# pyisolate check verify_service type?
|
||||
# If we pass PromptServerStub as the 'class', it might mismatch if checking types.
|
||||
# But we can try passing PromptServerStub if it mirrors the service name? No, stub is PromptServerStub.
|
||||
# We need a dummy class with right name?
|
||||
# Or just rely on string ID if create_caller supports it?
|
||||
# Standard: rpc.create_caller(PromptServerStub, target_id)
|
||||
# But wait, PromptServerStub is the *Local* class.
|
||||
# We want to call *Remote* class.
|
||||
# If we use PromptServerStub as the type, returning object will be typed as PromptServerStub?
|
||||
# The first arg is 'service_cls'.
|
||||
cls._rpc = rpc.create_caller(
|
||||
PromptServerService, target_id
|
||||
) # We import Service below?
|
||||
|
||||
# We need PromptServerService available for the create_caller call?
|
||||
# Or just use the Stub class if ID matches?
|
||||
# prompt_server_impl.py defines BOTH. So PromptServerService IS available!
|
||||
|
||||
@property
|
||||
def instance(self) -> "PromptServerStub":
|
||||
return self
|
||||
|
||||
# ... Compatibility ...
|
||||
@classmethod
|
||||
def _get_source_file(cls) -> str:
|
||||
if cls._source_file is None:
|
||||
import folder_paths
|
||||
|
||||
cls._source_file = os.path.join(folder_paths.base_path, "server.py")
|
||||
return cls._source_file
|
||||
|
||||
@property
|
||||
def __file__(self) -> str:
|
||||
return self._get_source_file()
|
||||
|
||||
# --- Properties ---
|
||||
@property
|
||||
def client_id(self) -> Optional[str]:
|
||||
return "isolated_client"
|
||||
|
||||
def supports(self, feature: str) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def app(self):
|
||||
raise RuntimeError(
|
||||
"PromptServer.app is not accessible in isolated nodes. Use RPC routes instead."
|
||||
)
|
||||
|
||||
@property
|
||||
def prompt_queue(self):
|
||||
raise RuntimeError(
|
||||
"PromptServer.prompt_queue is not accessible in isolated nodes."
|
||||
)
|
||||
|
||||
# --- UI Communication (RPC Delegates) ---
|
||||
async def send_sync(
|
||||
self, event: str, data: Dict[str, Any], sid: Optional[str] = None
|
||||
) -> None:
|
||||
if self._rpc:
|
||||
await self._rpc.ui_send_sync(event, data, sid)
|
||||
|
||||
async def send(
|
||||
self, event: str, data: Dict[str, Any], sid: Optional[str] = None
|
||||
) -> None:
|
||||
if self._rpc:
|
||||
await self._rpc.ui_send(event, data, sid)
|
||||
|
||||
def send_progress_text(self, text: str, node_id: str, sid=None) -> None:
|
||||
if self._rpc:
|
||||
# Fire and forget likely needed. If method is async on host, caller invocation returns coroutine.
|
||||
# We must schedule it?
|
||||
# Or use fire_remote equivalent?
|
||||
# Caller object usually proxies calls. If host method is async, it returns coro.
|
||||
# If we are sync here (send_progress_text checks imply sync usage), we must background it.
|
||||
# But UtilsProxy hook wrapper creates task.
|
||||
# Does send_progress_text need to be sync? Yes, node code calls it sync.
|
||||
import asyncio
|
||||
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
loop.create_task(self._rpc.ui_send_progress_text(text, node_id, sid))
|
||||
except RuntimeError:
|
||||
pass # Sync context without loop?
|
||||
|
||||
# --- Route Registration Logic ---
|
||||
def register_route(self, method: str, path: str, handler: Callable):
|
||||
"""Register a route handler via RPC."""
|
||||
if not self._rpc:
|
||||
logger.error("RPC not initialized in PromptServerStub")
|
||||
return
|
||||
|
||||
# Fire registration async
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
loop.create_task(self._rpc.register_route_rpc(method, path, handler))
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
|
||||
class RouteStub:
|
||||
"""Simulates aiohttp.web.RouteTableDef."""
|
||||
|
||||
def __init__(self, stub: PromptServerStub):
|
||||
self._stub = stub
|
||||
|
||||
def get(self, path: str):
|
||||
def decorator(handler):
|
||||
self._stub.register_route("GET", path, handler)
|
||||
return handler
|
||||
|
||||
return decorator
|
||||
|
||||
def post(self, path: str):
|
||||
def decorator(handler):
|
||||
self._stub.register_route("POST", path, handler)
|
||||
return handler
|
||||
|
||||
return decorator
|
||||
|
||||
def patch(self, path: str):
|
||||
def decorator(handler):
|
||||
self._stub.register_route("PATCH", path, handler)
|
||||
return handler
|
||||
|
||||
return decorator
|
||||
|
||||
def put(self, path: str):
|
||||
def decorator(handler):
|
||||
self._stub.register_route("PUT", path, handler)
|
||||
return handler
|
||||
|
||||
return decorator
|
||||
|
||||
def delete(self, path: str):
|
||||
def decorator(handler):
|
||||
self._stub.register_route("DELETE", path, handler)
|
||||
return handler
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# HOST SIDE: PromptServerService
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class PromptServerService(ProxiedSingleton):
|
||||
"""Host-side RPC Service for PromptServer."""
|
||||
|
||||
def __init__(self):
|
||||
# We will bind to the real server instance lazily or via global import
|
||||
pass
|
||||
|
||||
@property
|
||||
def server(self):
|
||||
from server import PromptServer
|
||||
|
||||
return PromptServer.instance
|
||||
|
||||
async def ui_send_sync(
|
||||
self, event: str, data: Dict[str, Any], sid: Optional[str] = None
|
||||
):
|
||||
await self.server.send_sync(event, data, sid)
|
||||
|
||||
async def ui_send(
|
||||
self, event: str, data: Dict[str, Any], sid: Optional[str] = None
|
||||
):
|
||||
await self.server.send(event, data, sid)
|
||||
|
||||
async def ui_send_progress_text(self, text: str, node_id: str, sid=None):
|
||||
# Made async to be awaitable by RPC layer
|
||||
self.server.send_progress_text(text, node_id, sid)
|
||||
|
||||
async def register_route_rpc(self, method: str, path: str, child_handler_proxy):
|
||||
"""RPC Target: Register a route that forwards to the Child."""
|
||||
logger.debug(f"{LOG_PREFIX} Registering Isolated Route {method} {path}")
|
||||
|
||||
async def route_wrapper(request: web.Request) -> web.Response:
|
||||
# 1. Capture request data
|
||||
req_data = {
|
||||
"method": request.method,
|
||||
"path": request.path,
|
||||
"query": dict(request.query),
|
||||
}
|
||||
if request.can_read_body:
|
||||
req_data["text"] = await request.text()
|
||||
|
||||
try:
|
||||
# 2. Call Child Handler via RPC (child_handler_proxy is async callable)
|
||||
result = await child_handler_proxy(req_data)
|
||||
|
||||
# 3. Serialize Response
|
||||
return self._serialize_response(result)
|
||||
except Exception as e:
|
||||
logger.error(f"{LOG_PREFIX} Isolated Route Error: {e}")
|
||||
return web.Response(status=500, text=str(e))
|
||||
|
||||
# Register loop
|
||||
self.server.app.router.add_route(method, path, route_wrapper)
|
||||
|
||||
def _serialize_response(self, result: Any) -> web.Response:
|
||||
"""Helper to convert Child result -> web.Response"""
|
||||
if isinstance(result, web.Response):
|
||||
return result
|
||||
# Handle dict (json)
|
||||
if isinstance(result, dict):
|
||||
return web.json_response(result)
|
||||
# Handle string
|
||||
if isinstance(result, str):
|
||||
return web.Response(text=result)
|
||||
# Fallback
|
||||
return web.Response(text=str(result))
|
||||
64
comfy/isolation/proxies/utils_proxy.py
Normal file
64
comfy/isolation/proxies/utils_proxy.py
Normal file
@@ -0,0 +1,64 @@
|
||||
# pylint: disable=cyclic-import,import-outside-toplevel
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional, Any
|
||||
import comfy.utils
|
||||
from pyisolate import ProxiedSingleton
|
||||
|
||||
import os
|
||||
|
||||
|
||||
class UtilsProxy(ProxiedSingleton):
|
||||
"""
|
||||
Proxy for comfy.utils.
|
||||
Primarily handles the PROGRESS_BAR_HOOK to ensure progress updates
|
||||
from isolated nodes reach the host.
|
||||
"""
|
||||
|
||||
# _instance and __new__ removed to rely on SingletonMetaclass
|
||||
_rpc: Optional[Any] = None
|
||||
|
||||
@classmethod
|
||||
def set_rpc(cls, rpc: Any) -> None:
|
||||
# Create caller using class name as ID (standard for Singletons)
|
||||
cls._rpc = rpc.create_caller(cls, "UtilsProxy")
|
||||
|
||||
async def progress_bar_hook(
|
||||
self,
|
||||
value: int,
|
||||
total: int,
|
||||
preview: Optional[bytes] = None,
|
||||
node_id: Optional[str] = None,
|
||||
) -> Any:
|
||||
"""
|
||||
Host-side implementation: forwards the call to the real global hook.
|
||||
Child-side: this method call is intercepted by RPC and sent to host.
|
||||
"""
|
||||
if os.environ.get("PYISOLATE_CHILD") == "1":
|
||||
# Manual RPC dispatch for Child process
|
||||
# Use class-level RPC storage (Static Injection)
|
||||
if UtilsProxy._rpc:
|
||||
return await UtilsProxy._rpc.progress_bar_hook(
|
||||
value, total, preview, node_id
|
||||
)
|
||||
|
||||
# Fallback channel: global child rpc
|
||||
try:
|
||||
from pyisolate._internal.rpc_protocol import get_child_rpc_instance
|
||||
|
||||
get_child_rpc_instance()
|
||||
# If we have an RPC instance but no UtilsProxy._rpc, we *could* try to use it,
|
||||
# but we need a caller. For now, just pass to avoid crashing.
|
||||
pass
|
||||
except (ImportError, LookupError):
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
# Host Execution
|
||||
if comfy.utils.PROGRESS_BAR_HOOK is not None:
|
||||
comfy.utils.PROGRESS_BAR_HOOK(value, total, preview, node_id)
|
||||
|
||||
def set_progress_bar_global_hook(self, hook: Any) -> None:
|
||||
"""Forward hook registration (though usually not needed from child)."""
|
||||
comfy.utils.set_progress_bar_global_hook(hook)
|
||||
49
comfy/isolation/rpc_bridge.py
Normal file
49
comfy/isolation/rpc_bridge.py
Normal file
@@ -0,0 +1,49 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import threading
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RpcBridge:
|
||||
"""Minimal helper to run coroutines synchronously inside isolated processes.
|
||||
|
||||
If an event loop is already running, the coroutine is executed on a fresh
|
||||
thread with its own loop to avoid nested run_until_complete errors.
|
||||
"""
|
||||
|
||||
def run_sync(self, maybe_coro):
|
||||
if not asyncio.iscoroutine(maybe_coro):
|
||||
return maybe_coro
|
||||
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
loop = None
|
||||
|
||||
if loop and loop.is_running():
|
||||
result_container = {}
|
||||
exc_container = {}
|
||||
|
||||
def _runner():
|
||||
try:
|
||||
new_loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(new_loop)
|
||||
result_container["value"] = new_loop.run_until_complete(maybe_coro)
|
||||
except Exception as exc: # pragma: no cover
|
||||
exc_container["error"] = exc
|
||||
finally:
|
||||
try:
|
||||
new_loop.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
t = threading.Thread(target=_runner, daemon=True)
|
||||
t.start()
|
||||
t.join()
|
||||
|
||||
if "error" in exc_container:
|
||||
raise exc_container["error"]
|
||||
return result_container.get("value")
|
||||
|
||||
return asyncio.run(maybe_coro)
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user