mirror of
https://github.com/Comfy-Org/ComfyUI.git
synced 2026-03-02 22:29:01 +00:00
Compare commits
227 Commits
feature/fr
...
remove-cac
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
699659c06e | ||
|
|
95e1059661 | ||
|
|
80d49441e5 | ||
|
|
9d0e114ee3 | ||
|
|
ac4412d0fa | ||
|
|
94f1a1cc9d | ||
|
|
e721e24136 | ||
|
|
25ec3d96a3 | ||
|
|
1f1ec377ce | ||
|
|
0a7f8e11b6 | ||
|
|
35e9fce775 | ||
|
|
c7f7d52b68 | ||
|
|
08b26ed7c2 | ||
|
|
b233dbe0bc | ||
|
|
3811780e4f | ||
|
|
3dd10a59c0 | ||
|
|
88d05fe483 | ||
|
|
fd41ec97cc | ||
|
|
420e900f69 | ||
|
|
38ca94599f | ||
|
|
74b5a337dc | ||
|
|
8a4d85c708 | ||
|
|
a4522017c5 | ||
|
|
907e5dcbbf | ||
|
|
7253531670 | ||
|
|
e14b04478c | ||
|
|
eb8737d675 | ||
|
|
0467f690a8 | ||
|
|
4f5b7dbf1f | ||
|
|
3ebe1ac22e | ||
|
|
befa83d434 | ||
|
|
33f83d53ae | ||
|
|
b874bd2b8c | ||
|
|
0aa02453bb | ||
|
|
599f9c5010 | ||
|
|
11fefa58e9 | ||
|
|
d8090013b8 | ||
|
|
048dd2f321 | ||
|
|
84aba95e03 | ||
|
|
9b1c63eb69 | ||
|
|
7a7debcaf1 | ||
|
|
dba2766e53 | ||
|
|
caa43d2395 | ||
|
|
07ca6852e8 | ||
|
|
f266b8d352 | ||
|
|
b6cb30bab5 | ||
|
|
ee72752162 | ||
|
|
7591d781a7 | ||
|
|
0bfb936ab4 | ||
|
|
602b2505a4 | ||
|
|
04a55d5019 | ||
|
|
5fb8f06495 | ||
|
|
5a182bfaf1 | ||
|
|
f394af8d0f | ||
|
|
aeb5bdc8f6 | ||
|
|
64953bda0a | ||
|
|
b254cecd03 | ||
|
|
1bb956fb66 | ||
|
|
96d6bd1a4a | ||
|
|
5f2117528a | ||
|
|
0301ccf745 | ||
|
|
4d172e9ad7 | ||
|
|
5632b2df9d | ||
|
|
2687652530 | ||
|
|
6d11cc7354 | ||
|
|
f262444dd4 | ||
|
|
239ddd3327 | ||
|
|
83dd65f23a | ||
|
|
8ad38d2073 | ||
|
|
6c14f129af | ||
|
|
58dcc97dcf | ||
|
|
19236edfa4 | ||
|
|
73c3f86973 | ||
|
|
262abf437b | ||
|
|
5284e6bf69 | ||
|
|
44f8598521 | ||
|
|
fe52843fe5 | ||
|
|
c39653163d | ||
|
|
18927538a1 | ||
|
|
8a6fbc2dc2 | ||
|
|
b44fc4c589 | ||
|
|
4454fab7f0 | ||
|
|
1978f59ffd | ||
|
|
88e6370527 | ||
|
|
c0370044cd | ||
|
|
ecd2a19661 | ||
|
|
2c1d06a4e3 | ||
|
|
e2c71ceb00 | ||
|
|
596ed68691 | ||
|
|
ce4a1ab48d | ||
|
|
e1ede29d82 | ||
|
|
df1e5e8514 | ||
|
|
dc9822b7df | ||
|
|
712efb466b | ||
|
|
726af73867 | ||
|
|
831351a29e | ||
|
|
e1add563f9 | ||
|
|
8902907d7a | ||
|
|
e03fe8b591 | ||
|
|
ae79e33345 | ||
|
|
117e214354 | ||
|
|
4a93a62371 | ||
|
|
66c18522fb | ||
|
|
e5ae670a40 | ||
|
|
3fe61cedda | ||
|
|
2a4328d639 | ||
|
|
d297a749a2 | ||
|
|
2b7cc7e3b6 | ||
|
|
4993411fd9 | ||
|
|
2c7cef4a23 | ||
|
|
76a7fa96db | ||
|
|
cdcf4119b3 | ||
|
|
dbe70b6821 | ||
|
|
00fff6019e | ||
|
|
123a7874a9 | ||
|
|
f719f9c062 | ||
|
|
fe053ba5eb | ||
|
|
6648ab68bc | ||
|
|
6615db925c | ||
|
|
8ca842a8ed | ||
|
|
c1b63a7e78 | ||
|
|
349a636a2b | ||
|
|
a4be04c5d7 | ||
|
|
baf8c87455 | ||
|
|
62315fbb15 | ||
|
|
a0302cc6a8 | ||
|
|
f350a84261 | ||
|
|
3760d74005 | ||
|
|
9bf5aa54db | ||
|
|
5ff4fdedba | ||
|
|
17e7df43d1 | ||
|
|
039955c527 | ||
|
|
6a26328842 | ||
|
|
204e65b8dc | ||
|
|
a831c19b70 | ||
|
|
eba6c940fd | ||
|
|
a1c101f861 | ||
|
|
c2d7f07dbf | ||
|
|
458292fef0 | ||
|
|
6555dc65b8 | ||
|
|
2b70ab9ad0 | ||
|
|
00efcc6cd0 | ||
|
|
cb459573c8 | ||
|
|
35183543e0 | ||
|
|
a246cc02b2 | ||
|
|
a50c32d63f | ||
|
|
6125b80979 | ||
|
|
c8fcbd66ee | ||
|
|
26dd7eb421 | ||
|
|
e77b34dfea | ||
|
|
ef73070ea4 | ||
|
|
d30c609f5a | ||
|
|
5087f1d497 | ||
|
|
a31681564d | ||
|
|
855849c658 | ||
|
|
fe2511468d | ||
|
|
3be0175166 | ||
|
|
b8315e66cb | ||
|
|
ab1050bec3 | ||
|
|
fb23935c11 | ||
|
|
85fc35e8fa | ||
|
|
223364743c | ||
|
|
affe881354 | ||
|
|
f5030e26fd | ||
|
|
66e1b07402 | ||
|
|
be4345d1c9 | ||
|
|
3c1a1a2df8 | ||
|
|
ba5bf3f1a8 | ||
|
|
c05a08ae66 | ||
|
|
de9ada6a41 | ||
|
|
37f711d4a1 | ||
|
|
dd86b15521 | ||
|
|
021ba20719 | ||
|
|
b60be02aaf | ||
|
|
2b5da3b72e | ||
|
|
794d05bdb1 | ||
|
|
361b9a82a3 | ||
|
|
667a1b8878 | ||
|
|
32621c6a11 | ||
|
|
f8acd9c402 | ||
|
|
873de5f37a | ||
|
|
aa6f7a83bb | ||
|
|
6ea8c128a3 | ||
|
|
6e469a3f35 | ||
|
|
b8f848bfe3 | ||
|
|
4064062e7d | ||
|
|
8aabe2403e | ||
|
|
0167653781 | ||
|
|
0a7993729c | ||
|
|
bbe2c13a70 | ||
|
|
3aace5c8dc | ||
|
|
b0d9708974 | ||
|
|
c9b633d84f | ||
|
|
1711020904 | ||
|
|
d9b8567547 | ||
|
|
6c5f906bf2 | ||
|
|
4f5bd39b1c | ||
|
|
dcff27fe3f | ||
|
|
09725967cf | ||
|
|
5f62440fbb | ||
|
|
ac91c340f4 | ||
|
|
2db3b0ff90 | ||
|
|
6516ab335d | ||
|
|
ad53e78f11 | ||
|
|
29011ba87e | ||
|
|
cd4985e2f3 | ||
|
|
bfe31d0b9d | ||
|
|
2129e7d278 | ||
|
|
7ee77ff038 | ||
|
|
26c5bbb875 | ||
|
|
a97c98068f | ||
|
|
635406e283 | ||
|
|
ed6002cb60 | ||
|
|
bc72d7f8d1 | ||
|
|
aef4e13588 | ||
|
|
4e6a1b66a9 | ||
|
|
9cf299a9f9 | ||
|
|
e89b22993a | ||
|
|
55bd606e92 | ||
|
|
79cdbc81cb | ||
|
|
f443b9f2ca | ||
|
|
4e3038114a | ||
|
|
bbb8864778 | ||
|
|
d7f3241bf6 | ||
|
|
09a2e67151 | ||
|
|
0fd1b78736 | ||
|
|
8490eedadf |
127
.coderabbit.yaml
Normal file
127
.coderabbit.yaml
Normal file
@@ -0,0 +1,127 @@
|
||||
# yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json
|
||||
language: "en-US"
|
||||
early_access: false
|
||||
tone_instructions: "Only comment on issues introduced by this PR's changes. Do not flag pre-existing problems in moved, re-indented, or reformatted code."
|
||||
|
||||
reviews:
|
||||
profile: "chill"
|
||||
request_changes_workflow: false
|
||||
high_level_summary: false
|
||||
poem: false
|
||||
review_status: false
|
||||
review_details: false
|
||||
commit_status: true
|
||||
collapse_walkthrough: true
|
||||
changed_files_summary: false
|
||||
sequence_diagrams: false
|
||||
estimate_code_review_effort: false
|
||||
assess_linked_issues: false
|
||||
related_issues: false
|
||||
related_prs: false
|
||||
suggested_labels: false
|
||||
auto_apply_labels: false
|
||||
suggested_reviewers: false
|
||||
auto_assign_reviewers: false
|
||||
in_progress_fortune: false
|
||||
enable_prompt_for_ai_agents: true
|
||||
|
||||
path_filters:
|
||||
- "!comfy_api_nodes/apis/**"
|
||||
- "!**/generated/*.pyi"
|
||||
- "!.ci/**"
|
||||
- "!script_examples/**"
|
||||
- "!**/__pycache__/**"
|
||||
- "!**/*.ipynb"
|
||||
- "!**/*.png"
|
||||
- "!**/*.bat"
|
||||
|
||||
path_instructions:
|
||||
- path: "**"
|
||||
instructions: |
|
||||
IMPORTANT: Only comment on issues directly introduced by this PR's code changes.
|
||||
Do NOT flag pre-existing issues in code that was merely moved, re-indented,
|
||||
de-indented, or reformatted without logic changes. If code appears in the diff
|
||||
only due to whitespace or structural reformatting (e.g., removing a `with:` block),
|
||||
treat it as unchanged. Contributors should not feel obligated to address
|
||||
pre-existing issues outside the scope of their contribution.
|
||||
- path: "comfy/**"
|
||||
instructions: |
|
||||
Core ML/diffusion engine. Focus on:
|
||||
- Backward compatibility (breaking changes affect all custom nodes)
|
||||
- Memory management and GPU resource handling
|
||||
- Performance implications in hot paths
|
||||
- Thread safety for concurrent execution
|
||||
- path: "comfy_api_nodes/**"
|
||||
instructions: |
|
||||
Third-party API integration nodes. Focus on:
|
||||
- No hardcoded API keys or secrets
|
||||
- Proper error handling for API failures (timeouts, rate limits, auth errors)
|
||||
- Correct Pydantic model usage
|
||||
- Security of user data passed to external APIs
|
||||
- path: "comfy_extras/**"
|
||||
instructions: |
|
||||
Community-contributed extra nodes. Focus on:
|
||||
- Consistency with node patterns (INPUT_TYPES, RETURN_TYPES, FUNCTION, CATEGORY)
|
||||
- No breaking changes to existing node interfaces
|
||||
- path: "comfy_execution/**"
|
||||
instructions: |
|
||||
Execution engine (graph execution, caching, jobs). Focus on:
|
||||
- Caching correctness
|
||||
- Concurrent execution safety
|
||||
- Graph validation edge cases
|
||||
- path: "nodes.py"
|
||||
instructions: |
|
||||
Core node definitions (2500+ lines). Focus on:
|
||||
- Backward compatibility of NODE_CLASS_MAPPINGS
|
||||
- Consistency of INPUT_TYPES return format
|
||||
- path: "alembic_db/**"
|
||||
instructions: |
|
||||
Database migrations. Focus on:
|
||||
- Migration safety and rollback support
|
||||
- Data preservation during schema changes
|
||||
|
||||
auto_review:
|
||||
enabled: true
|
||||
auto_incremental_review: true
|
||||
drafts: false
|
||||
ignore_title_keywords:
|
||||
- "WIP"
|
||||
- "DO NOT REVIEW"
|
||||
- "DO NOT MERGE"
|
||||
|
||||
finishing_touches:
|
||||
docstrings:
|
||||
enabled: false
|
||||
unit_tests:
|
||||
enabled: false
|
||||
|
||||
tools:
|
||||
ruff:
|
||||
enabled: false
|
||||
pylint:
|
||||
enabled: false
|
||||
flake8:
|
||||
enabled: false
|
||||
gitleaks:
|
||||
enabled: true
|
||||
shellcheck:
|
||||
enabled: false
|
||||
markdownlint:
|
||||
enabled: false
|
||||
yamllint:
|
||||
enabled: false
|
||||
languagetool:
|
||||
enabled: false
|
||||
github-checks:
|
||||
enabled: true
|
||||
timeout_ms: 90000
|
||||
ast-grep:
|
||||
essential_rules: true
|
||||
|
||||
chat:
|
||||
auto_reply: true
|
||||
|
||||
knowledge_base:
|
||||
opt_out: false
|
||||
learnings:
|
||||
scope: "auto"
|
||||
2
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
2
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
@@ -16,7 +16,7 @@ body:
|
||||
|
||||
## Very Important
|
||||
|
||||
Please make sure that you post ALL your ComfyUI logs in the bug report. A bug report without logs will likely be ignored.
|
||||
Please make sure that you post ALL your ComfyUI logs in the bug report **even if there is no crash**. Just paste everything. The startup log (everything before "To see the GUI go to: ...") contains critical information to developers trying to help. For a performance issue or crash, paste everything from "got prompt" to the end, including the crash. More is better - always. A bug report without logs will likely be ignored.
|
||||
- type: checkboxes
|
||||
id: custom-nodes-test
|
||||
attributes:
|
||||
|
||||
6
.github/workflows/release-stable-all.yml
vendored
6
.github/workflows/release-stable-all.yml
vendored
@@ -20,7 +20,7 @@ jobs:
|
||||
git_tag: ${{ inputs.git_tag }}
|
||||
cache_tag: "cu130"
|
||||
python_minor: "13"
|
||||
python_patch: "9"
|
||||
python_patch: "11"
|
||||
rel_name: "nvidia"
|
||||
rel_extra_name: ""
|
||||
test_release: true
|
||||
@@ -65,11 +65,11 @@ jobs:
|
||||
contents: "write"
|
||||
packages: "write"
|
||||
pull-requests: "read"
|
||||
name: "Release AMD ROCm 7.1.1"
|
||||
name: "Release AMD ROCm 7.2"
|
||||
uses: ./.github/workflows/stable-release.yml
|
||||
with:
|
||||
git_tag: ${{ inputs.git_tag }}
|
||||
cache_tag: "rocm711"
|
||||
cache_tag: "rocm72"
|
||||
python_minor: "12"
|
||||
python_patch: "10"
|
||||
rel_name: "amd"
|
||||
|
||||
36
.github/workflows/release-webhook.yml
vendored
36
.github/workflows/release-webhook.yml
vendored
@@ -7,6 +7,8 @@ on:
|
||||
jobs:
|
||||
send-webhook:
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
DESKTOP_REPO_DISPATCH_TOKEN: ${{ secrets.DESKTOP_REPO_DISPATCH_TOKEN }}
|
||||
steps:
|
||||
- name: Send release webhook
|
||||
env:
|
||||
@@ -106,3 +108,37 @@ jobs:
|
||||
--fail --silent --show-error
|
||||
|
||||
echo "✅ Release webhook sent successfully"
|
||||
|
||||
- name: Send repository dispatch to desktop
|
||||
env:
|
||||
DISPATCH_TOKEN: ${{ env.DESKTOP_REPO_DISPATCH_TOKEN }}
|
||||
RELEASE_TAG: ${{ github.event.release.tag_name }}
|
||||
RELEASE_URL: ${{ github.event.release.html_url }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
if [ -z "${DISPATCH_TOKEN:-}" ]; then
|
||||
echo "::error::DESKTOP_REPO_DISPATCH_TOKEN is required but not set."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
PAYLOAD="$(jq -n \
|
||||
--arg release_tag "$RELEASE_TAG" \
|
||||
--arg release_url "$RELEASE_URL" \
|
||||
'{
|
||||
event_type: "comfyui_release_published",
|
||||
client_payload: {
|
||||
release_tag: $release_tag,
|
||||
release_url: $release_url
|
||||
}
|
||||
}')"
|
||||
|
||||
curl -fsSL \
|
||||
-X POST \
|
||||
-H "Accept: application/vnd.github+json" \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "Authorization: Bearer ${DISPATCH_TOKEN}" \
|
||||
https://api.github.com/repos/Comfy-Org/desktop/dispatches \
|
||||
-d "$PAYLOAD"
|
||||
|
||||
echo "✅ Dispatched ComfyUI release ${RELEASE_TAG} to Comfy-Org/desktop"
|
||||
|
||||
@@ -29,7 +29,7 @@ on:
|
||||
description: 'python patch version'
|
||||
required: true
|
||||
type: string
|
||||
default: "9"
|
||||
default: "11"
|
||||
# push:
|
||||
# branches:
|
||||
# - master
|
||||
|
||||
2
.gitignore
vendored
2
.gitignore
vendored
@@ -11,7 +11,7 @@ extra_model_paths.yaml
|
||||
/.vs
|
||||
.vscode/
|
||||
.idea/
|
||||
venv/
|
||||
venv*/
|
||||
.venv/
|
||||
/web/extensions/*
|
||||
!/web/extensions/logging.js.example
|
||||
|
||||
10
README.md
10
README.md
@@ -189,8 +189,6 @@ The portable above currently comes with python 3.13 and pytorch cuda 13.0. Updat
|
||||
|
||||
[Experimental portable for AMD GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_amd.7z)
|
||||
|
||||
[Portable with pytorch cuda 12.8 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu128.7z).
|
||||
|
||||
[Portable with pytorch cuda 12.6 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu126.7z) (Supports Nvidia 10 series and older GPUs).
|
||||
|
||||
#### How do I share models between another UI and ComfyUI?
|
||||
@@ -208,7 +206,7 @@ comfy install
|
||||
|
||||
## Manual Install (Windows, Linux)
|
||||
|
||||
Python 3.14 works but you may encounter issues with the torch compile node. The free threaded variant is still missing some dependencies.
|
||||
Python 3.14 works but some custom nodes may have issues. The free threaded variant works but some dependencies will enable the GIL so it's not fully supported.
|
||||
|
||||
Python 3.13 is very well supported. If you have trouble with some custom node dependencies on 3.13 you can try 3.12
|
||||
|
||||
@@ -227,11 +225,11 @@ Put your VAE in: models/vae
|
||||
|
||||
AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:
|
||||
|
||||
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.4```
|
||||
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm7.1```
|
||||
|
||||
This is the command to install the nightly with ROCm 7.1 which might have some performance improvements:
|
||||
This is the command to install the nightly with ROCm 7.2 which might have some performance improvements:
|
||||
|
||||
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm7.1```
|
||||
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm7.2```
|
||||
|
||||
|
||||
### AMD GPUs (Experimental: Windows and Linux), RDNA 3, 3.5 and 4 only.
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
import logging
|
||||
import uuid
|
||||
import urllib.parse
|
||||
import os
|
||||
import contextlib
|
||||
from aiohttp import web
|
||||
|
||||
from pydantic import ValidationError
|
||||
@@ -8,6 +11,9 @@ import app.assets.manager as manager
|
||||
from app import user_manager
|
||||
from app.assets.api import schemas_in
|
||||
from app.assets.helpers import get_query_dict
|
||||
from app.assets.scanner import seed_assets
|
||||
|
||||
import folder_paths
|
||||
|
||||
ROUTES = web.RouteTableDef()
|
||||
USER_MANAGER: user_manager.UserManager | None = None
|
||||
@@ -15,6 +21,9 @@ USER_MANAGER: user_manager.UserManager | None = None
|
||||
# UUID regex (canonical hyphenated form, case-insensitive)
|
||||
UUID_RE = r"[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}"
|
||||
|
||||
# Note to any custom node developers reading this code:
|
||||
# The assets system is not yet fully implemented, do not rely on the code in /app/assets remaining the same.
|
||||
|
||||
def register_assets_system(app: web.Application, user_manager_instance: user_manager.UserManager) -> None:
|
||||
global USER_MANAGER
|
||||
USER_MANAGER = user_manager_instance
|
||||
@@ -28,6 +37,18 @@ def _validation_error_response(code: str, ve: ValidationError) -> web.Response:
|
||||
return _error_response(400, code, "Validation failed.", {"errors": ve.json()})
|
||||
|
||||
|
||||
@ROUTES.head("/api/assets/hash/{hash}")
|
||||
async def head_asset_by_hash(request: web.Request) -> web.Response:
|
||||
hash_str = request.match_info.get("hash", "").strip().lower()
|
||||
if not hash_str or ":" not in hash_str:
|
||||
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
|
||||
algo, digest = hash_str.split(":", 1)
|
||||
if algo != "blake3" or not digest or any(c for c in digest if c not in "0123456789abcdef"):
|
||||
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
|
||||
exists = manager.asset_exists(asset_hash=hash_str)
|
||||
return web.Response(status=200 if exists else 404)
|
||||
|
||||
|
||||
@ROUTES.get("/api/assets")
|
||||
async def list_assets(request: web.Request) -> web.Response:
|
||||
"""
|
||||
@@ -50,7 +71,7 @@ async def list_assets(request: web.Request) -> web.Response:
|
||||
order=q.order,
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
return web.json_response(payload.model_dump(mode="json"))
|
||||
return web.json_response(payload.model_dump(mode="json", exclude_none=True))
|
||||
|
||||
|
||||
@ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}")
|
||||
@@ -76,6 +97,314 @@ async def get_asset(request: web.Request) -> web.Response:
|
||||
return web.json_response(result.model_dump(mode="json"), status=200)
|
||||
|
||||
|
||||
@ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}/content")
|
||||
async def download_asset_content(request: web.Request) -> web.Response:
|
||||
# question: do we need disposition? could we just stick with one of these?
|
||||
disposition = request.query.get("disposition", "attachment").lower().strip()
|
||||
if disposition not in {"inline", "attachment"}:
|
||||
disposition = "attachment"
|
||||
|
||||
try:
|
||||
abs_path, content_type, filename = manager.resolve_asset_content_for_download(
|
||||
asset_info_id=str(uuid.UUID(request.match_info["id"])),
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
except ValueError as ve:
|
||||
return _error_response(404, "ASSET_NOT_FOUND", str(ve))
|
||||
except NotImplementedError as nie:
|
||||
return _error_response(501, "BACKEND_UNSUPPORTED", str(nie))
|
||||
except FileNotFoundError:
|
||||
return _error_response(404, "FILE_NOT_FOUND", "Underlying file not found on disk.")
|
||||
|
||||
quoted = (filename or "").replace("\r", "").replace("\n", "").replace('"', "'")
|
||||
cd = f'{disposition}; filename="{quoted}"; filename*=UTF-8\'\'{urllib.parse.quote(filename)}'
|
||||
|
||||
file_size = os.path.getsize(abs_path)
|
||||
logging.info(
|
||||
"download_asset_content: path=%s, size=%d bytes (%.2f MB), content_type=%s, filename=%s",
|
||||
abs_path,
|
||||
file_size,
|
||||
file_size / (1024 * 1024),
|
||||
content_type,
|
||||
filename,
|
||||
)
|
||||
|
||||
async def file_sender():
|
||||
chunk_size = 64 * 1024
|
||||
with open(abs_path, "rb") as f:
|
||||
while True:
|
||||
chunk = f.read(chunk_size)
|
||||
if not chunk:
|
||||
break
|
||||
yield chunk
|
||||
|
||||
return web.Response(
|
||||
body=file_sender(),
|
||||
content_type=content_type,
|
||||
headers={
|
||||
"Content-Disposition": cd,
|
||||
"Content-Length": str(file_size),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@ROUTES.post("/api/assets/from-hash")
|
||||
async def create_asset_from_hash(request: web.Request) -> web.Response:
|
||||
try:
|
||||
payload = await request.json()
|
||||
body = schemas_in.CreateFromHashBody.model_validate(payload)
|
||||
except ValidationError as ve:
|
||||
return _validation_error_response("INVALID_BODY", ve)
|
||||
except Exception:
|
||||
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
|
||||
|
||||
result = manager.create_asset_from_hash(
|
||||
hash_str=body.hash,
|
||||
name=body.name,
|
||||
tags=body.tags,
|
||||
user_metadata=body.user_metadata,
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
if result is None:
|
||||
return _error_response(404, "ASSET_NOT_FOUND", f"Asset content {body.hash} does not exist")
|
||||
return web.json_response(result.model_dump(mode="json"), status=201)
|
||||
|
||||
|
||||
@ROUTES.post("/api/assets")
|
||||
async def upload_asset(request: web.Request) -> web.Response:
|
||||
"""Multipart/form-data endpoint for Asset uploads."""
|
||||
if not (request.content_type or "").lower().startswith("multipart/"):
|
||||
return _error_response(415, "UNSUPPORTED_MEDIA_TYPE", "Use multipart/form-data for uploads.")
|
||||
|
||||
reader = await request.multipart()
|
||||
|
||||
file_present = False
|
||||
file_client_name: str | None = None
|
||||
tags_raw: list[str] = []
|
||||
provided_name: str | None = None
|
||||
user_metadata_raw: str | None = None
|
||||
provided_hash: str | None = None
|
||||
provided_hash_exists: bool | None = None
|
||||
|
||||
file_written = 0
|
||||
tmp_path: str | None = None
|
||||
while True:
|
||||
field = await reader.next()
|
||||
if field is None:
|
||||
break
|
||||
|
||||
fname = getattr(field, "name", "") or ""
|
||||
|
||||
if fname == "hash":
|
||||
try:
|
||||
s = ((await field.text()) or "").strip().lower()
|
||||
except Exception:
|
||||
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
|
||||
|
||||
if s:
|
||||
if ":" not in s:
|
||||
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
|
||||
algo, digest = s.split(":", 1)
|
||||
if algo != "blake3" or not digest or any(c for c in digest if c not in "0123456789abcdef"):
|
||||
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
|
||||
provided_hash = f"{algo}:{digest}"
|
||||
try:
|
||||
provided_hash_exists = manager.asset_exists(asset_hash=provided_hash)
|
||||
except Exception:
|
||||
provided_hash_exists = None # do not fail the whole request here
|
||||
|
||||
elif fname == "file":
|
||||
file_present = True
|
||||
file_client_name = (field.filename or "").strip()
|
||||
|
||||
if provided_hash and provided_hash_exists is True:
|
||||
# If client supplied a hash that we know exists, drain but do not write to disk
|
||||
try:
|
||||
while True:
|
||||
chunk = await field.read_chunk(8 * 1024 * 1024)
|
||||
if not chunk:
|
||||
break
|
||||
file_written += len(chunk)
|
||||
except Exception:
|
||||
return _error_response(500, "UPLOAD_IO_ERROR", "Failed to receive uploaded file.")
|
||||
continue # Do not create temp file; we will create AssetInfo from the existing content
|
||||
|
||||
# Otherwise, store to temp for hashing/ingest
|
||||
uploads_root = os.path.join(folder_paths.get_temp_directory(), "uploads")
|
||||
unique_dir = os.path.join(uploads_root, uuid.uuid4().hex)
|
||||
os.makedirs(unique_dir, exist_ok=True)
|
||||
tmp_path = os.path.join(unique_dir, ".upload.part")
|
||||
|
||||
try:
|
||||
with open(tmp_path, "wb") as f:
|
||||
while True:
|
||||
chunk = await field.read_chunk(8 * 1024 * 1024)
|
||||
if not chunk:
|
||||
break
|
||||
f.write(chunk)
|
||||
file_written += len(chunk)
|
||||
except Exception:
|
||||
try:
|
||||
if os.path.exists(tmp_path or ""):
|
||||
os.remove(tmp_path)
|
||||
finally:
|
||||
return _error_response(500, "UPLOAD_IO_ERROR", "Failed to receive and store uploaded file.")
|
||||
elif fname == "tags":
|
||||
tags_raw.append((await field.text()) or "")
|
||||
elif fname == "name":
|
||||
provided_name = (await field.text()) or None
|
||||
elif fname == "user_metadata":
|
||||
user_metadata_raw = (await field.text()) or None
|
||||
|
||||
# If client did not send file, and we are not doing a from-hash fast path -> error
|
||||
if not file_present and not (provided_hash and provided_hash_exists):
|
||||
return _error_response(400, "MISSING_FILE", "Form must include a 'file' part or a known 'hash'.")
|
||||
|
||||
if file_present and file_written == 0 and not (provided_hash and provided_hash_exists):
|
||||
# Empty upload is only acceptable if we are fast-pathing from existing hash
|
||||
try:
|
||||
if tmp_path and os.path.exists(tmp_path):
|
||||
os.remove(tmp_path)
|
||||
finally:
|
||||
return _error_response(400, "EMPTY_UPLOAD", "Uploaded file is empty.")
|
||||
|
||||
try:
|
||||
spec = schemas_in.UploadAssetSpec.model_validate({
|
||||
"tags": tags_raw,
|
||||
"name": provided_name,
|
||||
"user_metadata": user_metadata_raw,
|
||||
"hash": provided_hash,
|
||||
})
|
||||
except ValidationError as ve:
|
||||
try:
|
||||
if tmp_path and os.path.exists(tmp_path):
|
||||
os.remove(tmp_path)
|
||||
finally:
|
||||
return _validation_error_response("INVALID_BODY", ve)
|
||||
|
||||
# Validate models category against configured folders (consistent with previous behavior)
|
||||
if spec.tags and spec.tags[0] == "models":
|
||||
if len(spec.tags) < 2 or spec.tags[1] not in folder_paths.folder_names_and_paths:
|
||||
if tmp_path and os.path.exists(tmp_path):
|
||||
os.remove(tmp_path)
|
||||
return _error_response(
|
||||
400, "INVALID_BODY", f"unknown models category '{spec.tags[1] if len(spec.tags) >= 2 else ''}'"
|
||||
)
|
||||
|
||||
owner_id = USER_MANAGER.get_request_user_id(request)
|
||||
|
||||
# Fast path: if a valid provided hash exists, create AssetInfo without writing anything
|
||||
if spec.hash and provided_hash_exists is True:
|
||||
try:
|
||||
result = manager.create_asset_from_hash(
|
||||
hash_str=spec.hash,
|
||||
name=spec.name or (spec.hash.split(":", 1)[1]),
|
||||
tags=spec.tags,
|
||||
user_metadata=spec.user_metadata or {},
|
||||
owner_id=owner_id,
|
||||
)
|
||||
except Exception:
|
||||
logging.exception("create_asset_from_hash failed for hash=%s, owner_id=%s", spec.hash, owner_id)
|
||||
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
|
||||
if result is None:
|
||||
return _error_response(404, "ASSET_NOT_FOUND", f"Asset content {spec.hash} does not exist")
|
||||
|
||||
# Drain temp if we accidentally saved (e.g., hash field came after file)
|
||||
if tmp_path and os.path.exists(tmp_path):
|
||||
with contextlib.suppress(Exception):
|
||||
os.remove(tmp_path)
|
||||
|
||||
status = 200 if (not result.created_new) else 201
|
||||
return web.json_response(result.model_dump(mode="json"), status=status)
|
||||
|
||||
# Otherwise, we must have a temp file path to ingest
|
||||
if not tmp_path or not os.path.exists(tmp_path):
|
||||
# The only case we reach here without a temp file is: client sent a hash that does not exist and no file
|
||||
return _error_response(404, "ASSET_NOT_FOUND", "Provided hash not found and no file uploaded.")
|
||||
|
||||
try:
|
||||
created = manager.upload_asset_from_temp_path(
|
||||
spec,
|
||||
temp_path=tmp_path,
|
||||
client_filename=file_client_name,
|
||||
owner_id=owner_id,
|
||||
expected_asset_hash=spec.hash,
|
||||
)
|
||||
status = 201 if created.created_new else 200
|
||||
return web.json_response(created.model_dump(mode="json"), status=status)
|
||||
except ValueError as e:
|
||||
if tmp_path and os.path.exists(tmp_path):
|
||||
os.remove(tmp_path)
|
||||
msg = str(e)
|
||||
if "HASH_MISMATCH" in msg or msg.strip().upper() == "HASH_MISMATCH":
|
||||
return _error_response(
|
||||
400,
|
||||
"HASH_MISMATCH",
|
||||
"Uploaded file hash does not match provided hash.",
|
||||
)
|
||||
return _error_response(400, "BAD_REQUEST", "Invalid inputs.")
|
||||
except Exception:
|
||||
if tmp_path and os.path.exists(tmp_path):
|
||||
os.remove(tmp_path)
|
||||
logging.exception("upload_asset_from_temp_path failed for tmp_path=%s, owner_id=%s", tmp_path, owner_id)
|
||||
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
|
||||
|
||||
@ROUTES.put(f"/api/assets/{{id:{UUID_RE}}}")
|
||||
async def update_asset(request: web.Request) -> web.Response:
|
||||
asset_info_id = str(uuid.UUID(request.match_info["id"]))
|
||||
try:
|
||||
body = schemas_in.UpdateAssetBody.model_validate(await request.json())
|
||||
except ValidationError as ve:
|
||||
return _validation_error_response("INVALID_BODY", ve)
|
||||
except Exception:
|
||||
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
|
||||
|
||||
try:
|
||||
result = manager.update_asset(
|
||||
asset_info_id=asset_info_id,
|
||||
name=body.name,
|
||||
user_metadata=body.user_metadata,
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
except (ValueError, PermissionError) as ve:
|
||||
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
|
||||
except Exception:
|
||||
logging.exception(
|
||||
"update_asset failed for asset_info_id=%s, owner_id=%s",
|
||||
asset_info_id,
|
||||
USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
return web.json_response(result.model_dump(mode="json"), status=200)
|
||||
|
||||
|
||||
@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}")
|
||||
async def delete_asset(request: web.Request) -> web.Response:
|
||||
asset_info_id = str(uuid.UUID(request.match_info["id"]))
|
||||
delete_content = request.query.get("delete_content")
|
||||
delete_content = True if delete_content is None else delete_content.lower() not in {"0", "false", "no"}
|
||||
|
||||
try:
|
||||
deleted = manager.delete_asset_reference(
|
||||
asset_info_id=asset_info_id,
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
delete_content_if_orphan=delete_content,
|
||||
)
|
||||
except Exception:
|
||||
logging.exception(
|
||||
"delete_asset_reference failed for asset_info_id=%s, owner_id=%s",
|
||||
asset_info_id,
|
||||
USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
|
||||
if not deleted:
|
||||
return _error_response(404, "ASSET_NOT_FOUND", f"AssetInfo {asset_info_id} not found.")
|
||||
return web.Response(status=204)
|
||||
|
||||
|
||||
@ROUTES.get("/api/tags")
|
||||
async def get_tags(request: web.Request) -> web.Response:
|
||||
"""
|
||||
@@ -100,3 +429,86 @@ async def get_tags(request: web.Request) -> web.Response:
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
return web.json_response(result.model_dump(mode="json"))
|
||||
|
||||
|
||||
@ROUTES.post(f"/api/assets/{{id:{UUID_RE}}}/tags")
|
||||
async def add_asset_tags(request: web.Request) -> web.Response:
|
||||
asset_info_id = str(uuid.UUID(request.match_info["id"]))
|
||||
try:
|
||||
payload = await request.json()
|
||||
data = schemas_in.TagsAdd.model_validate(payload)
|
||||
except ValidationError as ve:
|
||||
return _error_response(400, "INVALID_BODY", "Invalid JSON body for tags add.", {"errors": ve.errors()})
|
||||
except Exception:
|
||||
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
|
||||
|
||||
try:
|
||||
result = manager.add_tags_to_asset(
|
||||
asset_info_id=asset_info_id,
|
||||
tags=data.tags,
|
||||
origin="manual",
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
except (ValueError, PermissionError) as ve:
|
||||
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
|
||||
except Exception:
|
||||
logging.exception(
|
||||
"add_tags_to_asset failed for asset_info_id=%s, owner_id=%s",
|
||||
asset_info_id,
|
||||
USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
|
||||
return web.json_response(result.model_dump(mode="json"), status=200)
|
||||
|
||||
|
||||
@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}/tags")
|
||||
async def delete_asset_tags(request: web.Request) -> web.Response:
|
||||
asset_info_id = str(uuid.UUID(request.match_info["id"]))
|
||||
try:
|
||||
payload = await request.json()
|
||||
data = schemas_in.TagsRemove.model_validate(payload)
|
||||
except ValidationError as ve:
|
||||
return _error_response(400, "INVALID_BODY", "Invalid JSON body for tags remove.", {"errors": ve.errors()})
|
||||
except Exception:
|
||||
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
|
||||
|
||||
try:
|
||||
result = manager.remove_tags_from_asset(
|
||||
asset_info_id=asset_info_id,
|
||||
tags=data.tags,
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
except ValueError as ve:
|
||||
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
|
||||
except Exception:
|
||||
logging.exception(
|
||||
"remove_tags_from_asset failed for asset_info_id=%s, owner_id=%s",
|
||||
asset_info_id,
|
||||
USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
|
||||
return web.json_response(result.model_dump(mode="json"), status=200)
|
||||
|
||||
|
||||
@ROUTES.post("/api/assets/seed")
|
||||
async def seed_assets_endpoint(request: web.Request) -> web.Response:
|
||||
"""Trigger asset seeding for specified roots (models, input, output)."""
|
||||
try:
|
||||
payload = await request.json()
|
||||
roots = payload.get("roots", ["models", "input", "output"])
|
||||
except Exception:
|
||||
roots = ["models", "input", "output"]
|
||||
|
||||
valid_roots = [r for r in roots if r in ("models", "input", "output")]
|
||||
if not valid_roots:
|
||||
return _error_response(400, "INVALID_BODY", "No valid roots specified")
|
||||
|
||||
try:
|
||||
seed_assets(tuple(valid_roots))
|
||||
except Exception:
|
||||
logging.exception("seed_assets failed for roots=%s", valid_roots)
|
||||
return _error_response(500, "INTERNAL", "Seed operation failed")
|
||||
|
||||
return web.json_response({"seeded": valid_roots}, status=200)
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import json
|
||||
import uuid
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import (
|
||||
@@ -8,9 +7,9 @@ from pydantic import (
|
||||
Field,
|
||||
conint,
|
||||
field_validator,
|
||||
model_validator,
|
||||
)
|
||||
|
||||
|
||||
class ListAssetsQuery(BaseModel):
|
||||
include_tags: list[str] = Field(default_factory=list)
|
||||
exclude_tags: list[str] = Field(default_factory=list)
|
||||
@@ -57,6 +56,57 @@ class ListAssetsQuery(BaseModel):
|
||||
return None
|
||||
|
||||
|
||||
class UpdateAssetBody(BaseModel):
|
||||
name: str | None = None
|
||||
user_metadata: dict[str, Any] | None = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _at_least_one(self):
|
||||
if self.name is None and self.user_metadata is None:
|
||||
raise ValueError("Provide at least one of: name, user_metadata.")
|
||||
return self
|
||||
|
||||
|
||||
class CreateFromHashBody(BaseModel):
|
||||
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
|
||||
|
||||
hash: str
|
||||
name: str
|
||||
tags: list[str] = Field(default_factory=list)
|
||||
user_metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
@field_validator("hash")
|
||||
@classmethod
|
||||
def _require_blake3(cls, v):
|
||||
s = (v or "").strip().lower()
|
||||
if ":" not in s:
|
||||
raise ValueError("hash must be 'blake3:<hex>'")
|
||||
algo, digest = s.split(":", 1)
|
||||
if algo != "blake3":
|
||||
raise ValueError("only canonical 'blake3:<hex>' is accepted here")
|
||||
if not digest or any(c for c in digest if c not in "0123456789abcdef"):
|
||||
raise ValueError("hash digest must be lowercase hex")
|
||||
return s
|
||||
|
||||
@field_validator("tags", mode="before")
|
||||
@classmethod
|
||||
def _tags_norm(cls, v):
|
||||
if v is None:
|
||||
return []
|
||||
if isinstance(v, list):
|
||||
out = [str(t).strip().lower() for t in v if str(t).strip()]
|
||||
seen = set()
|
||||
dedup = []
|
||||
for t in out:
|
||||
if t not in seen:
|
||||
seen.add(t)
|
||||
dedup.append(t)
|
||||
return dedup
|
||||
if isinstance(v, str):
|
||||
return [t.strip().lower() for t in v.split(",") if t.strip()]
|
||||
return []
|
||||
|
||||
|
||||
class TagsListQuery(BaseModel):
|
||||
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
|
||||
|
||||
@@ -75,20 +125,140 @@ class TagsListQuery(BaseModel):
|
||||
return v.lower() or None
|
||||
|
||||
|
||||
class SetPreviewBody(BaseModel):
|
||||
"""Set or clear the preview for an AssetInfo. Provide an Asset.id or null."""
|
||||
preview_id: str | None = None
|
||||
class TagsAdd(BaseModel):
|
||||
model_config = ConfigDict(extra="ignore")
|
||||
tags: list[str] = Field(..., min_length=1)
|
||||
|
||||
@field_validator("preview_id", mode="before")
|
||||
@field_validator("tags")
|
||||
@classmethod
|
||||
def _norm_uuid(cls, v):
|
||||
def normalize_tags(cls, v: list[str]) -> list[str]:
|
||||
out = []
|
||||
for t in v:
|
||||
if not isinstance(t, str):
|
||||
raise TypeError("tags must be strings")
|
||||
tnorm = t.strip().lower()
|
||||
if tnorm:
|
||||
out.append(tnorm)
|
||||
seen = set()
|
||||
deduplicated = []
|
||||
for x in out:
|
||||
if x not in seen:
|
||||
seen.add(x)
|
||||
deduplicated.append(x)
|
||||
return deduplicated
|
||||
|
||||
|
||||
class TagsRemove(TagsAdd):
|
||||
pass
|
||||
|
||||
|
||||
class UploadAssetSpec(BaseModel):
|
||||
"""Upload Asset operation.
|
||||
- tags: ordered; first is root ('models'|'input'|'output');
|
||||
if root == 'models', second must be a valid category from folder_paths.folder_names_and_paths
|
||||
- name: display name
|
||||
- user_metadata: arbitrary JSON object (optional)
|
||||
- hash: optional canonical 'blake3:<hex>' provided by the client for validation / fast-path
|
||||
|
||||
Files created via this endpoint are stored on disk using the **content hash** as the filename stem
|
||||
and the original extension is preserved when available.
|
||||
"""
|
||||
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
|
||||
|
||||
tags: list[str] = Field(..., min_length=1)
|
||||
name: str | None = Field(default=None, max_length=512, description="Display Name")
|
||||
user_metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
hash: str | None = Field(default=None)
|
||||
|
||||
@field_validator("hash", mode="before")
|
||||
@classmethod
|
||||
def _parse_hash(cls, v):
|
||||
if v is None:
|
||||
return None
|
||||
s = str(v).strip()
|
||||
s = str(v).strip().lower()
|
||||
if not s:
|
||||
return None
|
||||
try:
|
||||
uuid.UUID(s)
|
||||
except Exception:
|
||||
raise ValueError("preview_id must be a UUID")
|
||||
return s
|
||||
if ":" not in s:
|
||||
raise ValueError("hash must be 'blake3:<hex>'")
|
||||
algo, digest = s.split(":", 1)
|
||||
if algo != "blake3":
|
||||
raise ValueError("only canonical 'blake3:<hex>' is accepted here")
|
||||
if not digest or any(c for c in digest if c not in "0123456789abcdef"):
|
||||
raise ValueError("hash digest must be lowercase hex")
|
||||
return f"{algo}:{digest}"
|
||||
|
||||
@field_validator("tags", mode="before")
|
||||
@classmethod
|
||||
def _parse_tags(cls, v):
|
||||
"""
|
||||
Accepts a list of strings (possibly multiple form fields),
|
||||
where each string can be:
|
||||
- JSON array (e.g., '["models","loras","foo"]')
|
||||
- comma-separated ('models, loras, foo')
|
||||
- single token ('models')
|
||||
Returns a normalized, deduplicated, ordered list.
|
||||
"""
|
||||
items: list[str] = []
|
||||
if v is None:
|
||||
return []
|
||||
if isinstance(v, str):
|
||||
v = [v]
|
||||
|
||||
if isinstance(v, list):
|
||||
for item in v:
|
||||
if item is None:
|
||||
continue
|
||||
s = str(item).strip()
|
||||
if not s:
|
||||
continue
|
||||
if s.startswith("["):
|
||||
try:
|
||||
arr = json.loads(s)
|
||||
if isinstance(arr, list):
|
||||
items.extend(str(x) for x in arr)
|
||||
continue
|
||||
except Exception:
|
||||
pass # fallback to CSV parse below
|
||||
items.extend([p for p in s.split(",") if p.strip()])
|
||||
else:
|
||||
return []
|
||||
|
||||
# normalize + dedupe
|
||||
norm = []
|
||||
seen = set()
|
||||
for t in items:
|
||||
tnorm = str(t).strip().lower()
|
||||
if tnorm and tnorm not in seen:
|
||||
seen.add(tnorm)
|
||||
norm.append(tnorm)
|
||||
return norm
|
||||
|
||||
@field_validator("user_metadata", mode="before")
|
||||
@classmethod
|
||||
def _parse_metadata_json(cls, v):
|
||||
if v is None or isinstance(v, dict):
|
||||
return v or {}
|
||||
if isinstance(v, str):
|
||||
s = v.strip()
|
||||
if not s:
|
||||
return {}
|
||||
try:
|
||||
parsed = json.loads(s)
|
||||
except Exception as e:
|
||||
raise ValueError(f"user_metadata must be JSON: {e}") from e
|
||||
if not isinstance(parsed, dict):
|
||||
raise ValueError("user_metadata must be a JSON object")
|
||||
return parsed
|
||||
return {}
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _validate_order(self):
|
||||
if not self.tags:
|
||||
raise ValueError("tags must be provided and non-empty")
|
||||
root = self.tags[0]
|
||||
if root not in {"models", "input", "output"}:
|
||||
raise ValueError("first tag must be one of: models, input, output")
|
||||
if root == "models":
|
||||
if len(self.tags) < 2:
|
||||
raise ValueError("models uploads require a category tag as the second tag")
|
||||
return self
|
||||
|
||||
@@ -29,6 +29,21 @@ class AssetsList(BaseModel):
|
||||
has_more: bool
|
||||
|
||||
|
||||
class AssetUpdated(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
asset_hash: str | None = None
|
||||
tags: list[str] = Field(default_factory=list)
|
||||
user_metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
updated_at: datetime | None = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
@field_serializer("updated_at")
|
||||
def _ser_updated(self, v: datetime | None, _info):
|
||||
return v.isoformat() if v else None
|
||||
|
||||
|
||||
class AssetDetail(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
@@ -48,6 +63,10 @@ class AssetDetail(BaseModel):
|
||||
return v.isoformat() if v else None
|
||||
|
||||
|
||||
class AssetCreated(AssetDetail):
|
||||
created_new: bool
|
||||
|
||||
|
||||
class TagUsage(BaseModel):
|
||||
name: str
|
||||
count: int
|
||||
@@ -58,3 +77,17 @@ class TagsList(BaseModel):
|
||||
tags: list[TagUsage] = Field(default_factory=list)
|
||||
total: int
|
||||
has_more: bool
|
||||
|
||||
|
||||
class TagsAdd(BaseModel):
|
||||
model_config = ConfigDict(str_strip_whitespace=True)
|
||||
added: list[str] = Field(default_factory=list)
|
||||
already_present: list[str] = Field(default_factory=list)
|
||||
total_tags: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class TagsRemove(BaseModel):
|
||||
model_config = ConfigDict(str_strip_whitespace=True)
|
||||
removed: list[str] = Field(default_factory=list)
|
||||
not_present: list[str] = Field(default_factory=list)
|
||||
total_tags: list[str] = Field(default_factory=list)
|
||||
|
||||
@@ -1,9 +1,17 @@
|
||||
import os
|
||||
import logging
|
||||
import sqlalchemy as sa
|
||||
from collections import defaultdict
|
||||
from sqlalchemy import select, exists, func
|
||||
from datetime import datetime
|
||||
from typing import Iterable, Any
|
||||
from sqlalchemy import select, delete, exists, func
|
||||
from sqlalchemy.dialects import sqlite
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session, contains_eager, noload
|
||||
from app.assets.database.models import Asset, AssetInfo, AssetInfoMeta, AssetInfoTag, Tag
|
||||
from app.assets.helpers import escape_like_prefix, normalize_tags
|
||||
from app.assets.database.models import Asset, AssetInfo, AssetCacheState, AssetInfoMeta, AssetInfoTag, Tag
|
||||
from app.assets.helpers import (
|
||||
compute_relative_filename, escape_like_prefix, normalize_tags, project_kv, utcnow
|
||||
)
|
||||
from typing import Sequence
|
||||
|
||||
|
||||
@@ -15,6 +23,22 @@ def visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement:
|
||||
return AssetInfo.owner_id.in_(["", owner_id])
|
||||
|
||||
|
||||
def pick_best_live_path(states: Sequence[AssetCacheState]) -> str:
|
||||
"""
|
||||
Return the best on-disk path among cache states:
|
||||
1) Prefer a path that exists with needs_verify == False (already verified).
|
||||
2) Otherwise, pick the first path that exists.
|
||||
3) Otherwise return empty string.
|
||||
"""
|
||||
alive = [s for s in states if getattr(s, "file_path", None) and os.path.isfile(s.file_path)]
|
||||
if not alive:
|
||||
return ""
|
||||
for s in alive:
|
||||
if not getattr(s, "needs_verify", False):
|
||||
return s.file_path
|
||||
return alive[0].file_path
|
||||
|
||||
|
||||
def apply_tag_filters(
|
||||
stmt: sa.sql.Select,
|
||||
include_tags: Sequence[str] | None = None,
|
||||
@@ -42,6 +66,7 @@ def apply_tag_filters(
|
||||
)
|
||||
return stmt
|
||||
|
||||
|
||||
def apply_metadata_filter(
|
||||
stmt: sa.sql.Select,
|
||||
metadata_filter: dict | None = None,
|
||||
@@ -94,7 +119,11 @@ def apply_metadata_filter(
|
||||
return stmt
|
||||
|
||||
|
||||
def asset_exists_by_hash(session: Session, asset_hash: str) -> bool:
|
||||
def asset_exists_by_hash(
|
||||
session: Session,
|
||||
*,
|
||||
asset_hash: str,
|
||||
) -> bool:
|
||||
"""
|
||||
Check if an asset with a given hash exists in database.
|
||||
"""
|
||||
@@ -105,9 +134,39 @@ def asset_exists_by_hash(session: Session, asset_hash: str) -> bool:
|
||||
).first()
|
||||
return row is not None
|
||||
|
||||
def get_asset_info_by_id(session: Session, asset_info_id: str) -> AssetInfo | None:
|
||||
|
||||
def asset_info_exists_for_asset_id(
|
||||
session: Session,
|
||||
*,
|
||||
asset_id: str,
|
||||
) -> bool:
|
||||
q = (
|
||||
select(sa.literal(True))
|
||||
.select_from(AssetInfo)
|
||||
.where(AssetInfo.asset_id == asset_id)
|
||||
.limit(1)
|
||||
)
|
||||
return (session.execute(q)).first() is not None
|
||||
|
||||
|
||||
def get_asset_by_hash(
|
||||
session: Session,
|
||||
*,
|
||||
asset_hash: str,
|
||||
) -> Asset | None:
|
||||
return (
|
||||
session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))
|
||||
).scalars().first()
|
||||
|
||||
|
||||
def get_asset_info_by_id(
|
||||
session: Session,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
) -> AssetInfo | None:
|
||||
return session.get(AssetInfo, asset_info_id)
|
||||
|
||||
|
||||
def list_asset_infos_page(
|
||||
session: Session,
|
||||
owner_id: str = "",
|
||||
@@ -171,12 +230,14 @@ def list_asset_infos_page(
|
||||
select(AssetInfoTag.asset_info_id, Tag.name)
|
||||
.join(Tag, Tag.name == AssetInfoTag.tag_name)
|
||||
.where(AssetInfoTag.asset_info_id.in_(id_list))
|
||||
.order_by(AssetInfoTag.added_at)
|
||||
)
|
||||
for aid, tag_name in rows.all():
|
||||
tag_map[aid].append(tag_name)
|
||||
|
||||
return infos, tag_map, total
|
||||
|
||||
|
||||
def fetch_asset_info_asset_and_tags(
|
||||
session: Session,
|
||||
asset_info_id: str,
|
||||
@@ -208,6 +269,494 @@ def fetch_asset_info_asset_and_tags(
|
||||
tags.append(tag_name)
|
||||
return first_info, first_asset, tags
|
||||
|
||||
|
||||
def fetch_asset_info_and_asset(
|
||||
session: Session,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
owner_id: str = "",
|
||||
) -> tuple[AssetInfo, Asset] | None:
|
||||
stmt = (
|
||||
select(AssetInfo, Asset)
|
||||
.join(Asset, Asset.id == AssetInfo.asset_id)
|
||||
.where(
|
||||
AssetInfo.id == asset_info_id,
|
||||
visible_owner_clause(owner_id),
|
||||
)
|
||||
.limit(1)
|
||||
.options(noload(AssetInfo.tags))
|
||||
)
|
||||
row = session.execute(stmt)
|
||||
pair = row.first()
|
||||
if not pair:
|
||||
return None
|
||||
return pair[0], pair[1]
|
||||
|
||||
def list_cache_states_by_asset_id(
|
||||
session: Session, *, asset_id: str
|
||||
) -> Sequence[AssetCacheState]:
|
||||
return (
|
||||
session.execute(
|
||||
select(AssetCacheState)
|
||||
.where(AssetCacheState.asset_id == asset_id)
|
||||
.order_by(AssetCacheState.id.asc())
|
||||
)
|
||||
).scalars().all()
|
||||
|
||||
|
||||
def touch_asset_info_by_id(
|
||||
session: Session,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
ts: datetime | None = None,
|
||||
only_if_newer: bool = True,
|
||||
) -> None:
|
||||
ts = ts or utcnow()
|
||||
stmt = sa.update(AssetInfo).where(AssetInfo.id == asset_info_id)
|
||||
if only_if_newer:
|
||||
stmt = stmt.where(
|
||||
sa.or_(AssetInfo.last_access_time.is_(None), AssetInfo.last_access_time < ts)
|
||||
)
|
||||
session.execute(stmt.values(last_access_time=ts))
|
||||
|
||||
|
||||
def create_asset_info_for_existing_asset(
|
||||
session: Session,
|
||||
*,
|
||||
asset_hash: str,
|
||||
name: str,
|
||||
user_metadata: dict | None = None,
|
||||
tags: Sequence[str] | None = None,
|
||||
tag_origin: str = "manual",
|
||||
owner_id: str = "",
|
||||
) -> AssetInfo:
|
||||
"""Create or return an existing AssetInfo for an Asset identified by asset_hash."""
|
||||
now = utcnow()
|
||||
asset = get_asset_by_hash(session, asset_hash=asset_hash)
|
||||
if not asset:
|
||||
raise ValueError(f"Unknown asset hash {asset_hash}")
|
||||
|
||||
info = AssetInfo(
|
||||
owner_id=owner_id,
|
||||
name=name,
|
||||
asset_id=asset.id,
|
||||
preview_id=None,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
last_access_time=now,
|
||||
)
|
||||
try:
|
||||
with session.begin_nested():
|
||||
session.add(info)
|
||||
session.flush()
|
||||
except IntegrityError:
|
||||
existing = (
|
||||
session.execute(
|
||||
select(AssetInfo)
|
||||
.options(noload(AssetInfo.tags))
|
||||
.where(
|
||||
AssetInfo.asset_id == asset.id,
|
||||
AssetInfo.name == name,
|
||||
AssetInfo.owner_id == owner_id,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
).unique().scalars().first()
|
||||
if not existing:
|
||||
raise RuntimeError("AssetInfo upsert failed to find existing row after conflict.")
|
||||
return existing
|
||||
|
||||
# metadata["filename"] hack
|
||||
new_meta = dict(user_metadata or {})
|
||||
computed_filename = None
|
||||
try:
|
||||
p = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=asset.id))
|
||||
if p:
|
||||
computed_filename = compute_relative_filename(p)
|
||||
except Exception:
|
||||
computed_filename = None
|
||||
if computed_filename:
|
||||
new_meta["filename"] = computed_filename
|
||||
if new_meta:
|
||||
replace_asset_info_metadata_projection(
|
||||
session,
|
||||
asset_info_id=info.id,
|
||||
user_metadata=new_meta,
|
||||
)
|
||||
|
||||
if tags is not None:
|
||||
set_asset_info_tags(
|
||||
session,
|
||||
asset_info_id=info.id,
|
||||
tags=tags,
|
||||
origin=tag_origin,
|
||||
)
|
||||
return info
|
||||
|
||||
|
||||
def set_asset_info_tags(
|
||||
session: Session,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
tags: Sequence[str],
|
||||
origin: str = "manual",
|
||||
) -> dict:
|
||||
desired = normalize_tags(tags)
|
||||
|
||||
current = set(
|
||||
tag_name for (tag_name,) in (
|
||||
session.execute(select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id))
|
||||
).all()
|
||||
)
|
||||
|
||||
to_add = [t for t in desired if t not in current]
|
||||
to_remove = [t for t in current if t not in desired]
|
||||
|
||||
if to_add:
|
||||
ensure_tags_exist(session, to_add, tag_type="user")
|
||||
session.add_all([
|
||||
AssetInfoTag(asset_info_id=asset_info_id, tag_name=t, origin=origin, added_at=utcnow())
|
||||
for t in to_add
|
||||
])
|
||||
session.flush()
|
||||
|
||||
if to_remove:
|
||||
session.execute(
|
||||
delete(AssetInfoTag)
|
||||
.where(AssetInfoTag.asset_info_id == asset_info_id, AssetInfoTag.tag_name.in_(to_remove))
|
||||
)
|
||||
session.flush()
|
||||
|
||||
return {"added": to_add, "removed": to_remove, "total": desired}
|
||||
|
||||
|
||||
def replace_asset_info_metadata_projection(
|
||||
session: Session,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
user_metadata: dict | None = None,
|
||||
) -> None:
|
||||
info = session.get(AssetInfo, asset_info_id)
|
||||
if not info:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
|
||||
info.user_metadata = user_metadata or {}
|
||||
info.updated_at = utcnow()
|
||||
session.flush()
|
||||
|
||||
session.execute(delete(AssetInfoMeta).where(AssetInfoMeta.asset_info_id == asset_info_id))
|
||||
session.flush()
|
||||
|
||||
if not user_metadata:
|
||||
return
|
||||
|
||||
rows: list[AssetInfoMeta] = []
|
||||
for k, v in user_metadata.items():
|
||||
for r in project_kv(k, v):
|
||||
rows.append(
|
||||
AssetInfoMeta(
|
||||
asset_info_id=asset_info_id,
|
||||
key=r["key"],
|
||||
ordinal=int(r["ordinal"]),
|
||||
val_str=r.get("val_str"),
|
||||
val_num=r.get("val_num"),
|
||||
val_bool=r.get("val_bool"),
|
||||
val_json=r.get("val_json"),
|
||||
)
|
||||
)
|
||||
if rows:
|
||||
session.add_all(rows)
|
||||
session.flush()
|
||||
|
||||
|
||||
def ingest_fs_asset(
|
||||
session: Session,
|
||||
*,
|
||||
asset_hash: str,
|
||||
abs_path: str,
|
||||
size_bytes: int,
|
||||
mtime_ns: int,
|
||||
mime_type: str | None = None,
|
||||
info_name: str | None = None,
|
||||
owner_id: str = "",
|
||||
preview_id: str | None = None,
|
||||
user_metadata: dict | None = None,
|
||||
tags: Sequence[str] = (),
|
||||
tag_origin: str = "manual",
|
||||
require_existing_tags: bool = False,
|
||||
) -> dict:
|
||||
"""
|
||||
Idempotently upsert:
|
||||
- Asset by content hash (create if missing)
|
||||
- AssetCacheState(file_path) pointing to asset_id
|
||||
- Optionally AssetInfo + tag links and metadata projection
|
||||
Returns flags and ids.
|
||||
"""
|
||||
locator = os.path.abspath(abs_path)
|
||||
now = utcnow()
|
||||
|
||||
if preview_id:
|
||||
if not session.get(Asset, preview_id):
|
||||
preview_id = None
|
||||
|
||||
out: dict[str, Any] = {
|
||||
"asset_created": False,
|
||||
"asset_updated": False,
|
||||
"state_created": False,
|
||||
"state_updated": False,
|
||||
"asset_info_id": None,
|
||||
}
|
||||
|
||||
# 1) Asset by hash
|
||||
asset = (
|
||||
session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))
|
||||
).scalars().first()
|
||||
if not asset:
|
||||
vals = {
|
||||
"hash": asset_hash,
|
||||
"size_bytes": int(size_bytes),
|
||||
"mime_type": mime_type,
|
||||
"created_at": now,
|
||||
}
|
||||
res = session.execute(
|
||||
sqlite.insert(Asset)
|
||||
.values(**vals)
|
||||
.on_conflict_do_nothing(index_elements=[Asset.hash])
|
||||
)
|
||||
if int(res.rowcount or 0) > 0:
|
||||
out["asset_created"] = True
|
||||
asset = (
|
||||
session.execute(
|
||||
select(Asset).where(Asset.hash == asset_hash).limit(1)
|
||||
)
|
||||
).scalars().first()
|
||||
if not asset:
|
||||
raise RuntimeError("Asset row not found after upsert.")
|
||||
else:
|
||||
changed = False
|
||||
if asset.size_bytes != int(size_bytes) and int(size_bytes) > 0:
|
||||
asset.size_bytes = int(size_bytes)
|
||||
changed = True
|
||||
if mime_type and asset.mime_type != mime_type:
|
||||
asset.mime_type = mime_type
|
||||
changed = True
|
||||
if changed:
|
||||
out["asset_updated"] = True
|
||||
|
||||
# 2) AssetCacheState upsert by file_path (unique)
|
||||
vals = {
|
||||
"asset_id": asset.id,
|
||||
"file_path": locator,
|
||||
"mtime_ns": int(mtime_ns),
|
||||
}
|
||||
ins = (
|
||||
sqlite.insert(AssetCacheState)
|
||||
.values(**vals)
|
||||
.on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
|
||||
)
|
||||
|
||||
res = session.execute(ins)
|
||||
if int(res.rowcount or 0) > 0:
|
||||
out["state_created"] = True
|
||||
else:
|
||||
upd = (
|
||||
sa.update(AssetCacheState)
|
||||
.where(AssetCacheState.file_path == locator)
|
||||
.where(
|
||||
sa.or_(
|
||||
AssetCacheState.asset_id != asset.id,
|
||||
AssetCacheState.mtime_ns.is_(None),
|
||||
AssetCacheState.mtime_ns != int(mtime_ns),
|
||||
)
|
||||
)
|
||||
.values(asset_id=asset.id, mtime_ns=int(mtime_ns))
|
||||
)
|
||||
res2 = session.execute(upd)
|
||||
if int(res2.rowcount or 0) > 0:
|
||||
out["state_updated"] = True
|
||||
|
||||
# 3) Optional AssetInfo + tags + metadata
|
||||
if info_name:
|
||||
try:
|
||||
with session.begin_nested():
|
||||
info = AssetInfo(
|
||||
owner_id=owner_id,
|
||||
name=info_name,
|
||||
asset_id=asset.id,
|
||||
preview_id=preview_id,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
last_access_time=now,
|
||||
)
|
||||
session.add(info)
|
||||
session.flush()
|
||||
out["asset_info_id"] = info.id
|
||||
except IntegrityError:
|
||||
pass
|
||||
|
||||
existing_info = (
|
||||
session.execute(
|
||||
select(AssetInfo)
|
||||
.where(
|
||||
AssetInfo.asset_id == asset.id,
|
||||
AssetInfo.name == info_name,
|
||||
(AssetInfo.owner_id == owner_id),
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
).unique().scalar_one_or_none()
|
||||
if not existing_info:
|
||||
raise RuntimeError("Failed to update or insert AssetInfo.")
|
||||
|
||||
if preview_id and existing_info.preview_id != preview_id:
|
||||
existing_info.preview_id = preview_id
|
||||
|
||||
existing_info.updated_at = now
|
||||
if existing_info.last_access_time < now:
|
||||
existing_info.last_access_time = now
|
||||
session.flush()
|
||||
out["asset_info_id"] = existing_info.id
|
||||
|
||||
norm = [t.strip().lower() for t in (tags or []) if (t or "").strip()]
|
||||
if norm and out["asset_info_id"] is not None:
|
||||
if not require_existing_tags:
|
||||
ensure_tags_exist(session, norm, tag_type="user")
|
||||
|
||||
existing_tag_names = set(
|
||||
name for (name,) in (session.execute(select(Tag.name).where(Tag.name.in_(norm)))).all()
|
||||
)
|
||||
missing = [t for t in norm if t not in existing_tag_names]
|
||||
if missing and require_existing_tags:
|
||||
raise ValueError(f"Unknown tags: {missing}")
|
||||
|
||||
existing_links = set(
|
||||
tag_name
|
||||
for (tag_name,) in (
|
||||
session.execute(
|
||||
select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == out["asset_info_id"])
|
||||
)
|
||||
).all()
|
||||
)
|
||||
to_add = [t for t in norm if t in existing_tag_names and t not in existing_links]
|
||||
if to_add:
|
||||
session.add_all(
|
||||
[
|
||||
AssetInfoTag(
|
||||
asset_info_id=out["asset_info_id"],
|
||||
tag_name=t,
|
||||
origin=tag_origin,
|
||||
added_at=now,
|
||||
)
|
||||
for t in to_add
|
||||
]
|
||||
)
|
||||
session.flush()
|
||||
|
||||
# metadata["filename"] hack
|
||||
if out["asset_info_id"] is not None:
|
||||
primary_path = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=asset.id))
|
||||
computed_filename = compute_relative_filename(primary_path) if primary_path else None
|
||||
|
||||
current_meta = existing_info.user_metadata or {}
|
||||
new_meta = dict(current_meta)
|
||||
if user_metadata is not None:
|
||||
for k, v in user_metadata.items():
|
||||
new_meta[k] = v
|
||||
if computed_filename:
|
||||
new_meta["filename"] = computed_filename
|
||||
|
||||
if new_meta != current_meta:
|
||||
replace_asset_info_metadata_projection(
|
||||
session,
|
||||
asset_info_id=out["asset_info_id"],
|
||||
user_metadata=new_meta,
|
||||
)
|
||||
|
||||
try:
|
||||
remove_missing_tag_for_asset_id(session, asset_id=asset.id)
|
||||
except Exception:
|
||||
logging.exception("Failed to clear 'missing' tag for asset %s", asset.id)
|
||||
return out
|
||||
|
||||
|
||||
def update_asset_info_full(
|
||||
session: Session,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
name: str | None = None,
|
||||
tags: Sequence[str] | None = None,
|
||||
user_metadata: dict | None = None,
|
||||
tag_origin: str = "manual",
|
||||
asset_info_row: Any = None,
|
||||
) -> AssetInfo:
|
||||
if not asset_info_row:
|
||||
info = session.get(AssetInfo, asset_info_id)
|
||||
if not info:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
else:
|
||||
info = asset_info_row
|
||||
|
||||
touched = False
|
||||
if name is not None and name != info.name:
|
||||
info.name = name
|
||||
touched = True
|
||||
|
||||
computed_filename = None
|
||||
try:
|
||||
p = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=info.asset_id))
|
||||
if p:
|
||||
computed_filename = compute_relative_filename(p)
|
||||
except Exception:
|
||||
computed_filename = None
|
||||
|
||||
if user_metadata is not None:
|
||||
new_meta = dict(user_metadata)
|
||||
if computed_filename:
|
||||
new_meta["filename"] = computed_filename
|
||||
replace_asset_info_metadata_projection(
|
||||
session, asset_info_id=asset_info_id, user_metadata=new_meta
|
||||
)
|
||||
touched = True
|
||||
else:
|
||||
if computed_filename:
|
||||
current_meta = info.user_metadata or {}
|
||||
if current_meta.get("filename") != computed_filename:
|
||||
new_meta = dict(current_meta)
|
||||
new_meta["filename"] = computed_filename
|
||||
replace_asset_info_metadata_projection(
|
||||
session, asset_info_id=asset_info_id, user_metadata=new_meta
|
||||
)
|
||||
touched = True
|
||||
|
||||
if tags is not None:
|
||||
set_asset_info_tags(
|
||||
session,
|
||||
asset_info_id=asset_info_id,
|
||||
tags=tags,
|
||||
origin=tag_origin,
|
||||
)
|
||||
touched = True
|
||||
|
||||
if touched and user_metadata is None:
|
||||
info.updated_at = utcnow()
|
||||
session.flush()
|
||||
|
||||
return info
|
||||
|
||||
|
||||
def delete_asset_info_by_id(
|
||||
session: Session,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
owner_id: str,
|
||||
) -> bool:
|
||||
stmt = sa.delete(AssetInfo).where(
|
||||
AssetInfo.id == asset_info_id,
|
||||
visible_owner_clause(owner_id),
|
||||
)
|
||||
return int((session.execute(stmt)).rowcount or 0) > 0
|
||||
|
||||
|
||||
def list_tags_with_usage(
|
||||
session: Session,
|
||||
prefix: str | None = None,
|
||||
@@ -265,3 +814,163 @@ def list_tags_with_usage(
|
||||
|
||||
rows_norm = [(name, ttype, int(count or 0)) for (name, ttype, count) in rows]
|
||||
return rows_norm, int(total or 0)
|
||||
|
||||
|
||||
def ensure_tags_exist(session: Session, names: Iterable[str], tag_type: str = "user") -> None:
|
||||
wanted = normalize_tags(list(names))
|
||||
if not wanted:
|
||||
return
|
||||
rows = [{"name": n, "tag_type": tag_type} for n in list(dict.fromkeys(wanted))]
|
||||
ins = (
|
||||
sqlite.insert(Tag)
|
||||
.values(rows)
|
||||
.on_conflict_do_nothing(index_elements=[Tag.name])
|
||||
)
|
||||
session.execute(ins)
|
||||
|
||||
|
||||
def get_asset_tags(session: Session, *, asset_info_id: str) -> list[str]:
|
||||
return [
|
||||
tag_name for (tag_name,) in (
|
||||
session.execute(
|
||||
select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
|
||||
)
|
||||
).all()
|
||||
]
|
||||
|
||||
|
||||
def add_tags_to_asset_info(
|
||||
session: Session,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
tags: Sequence[str],
|
||||
origin: str = "manual",
|
||||
create_if_missing: bool = True,
|
||||
asset_info_row: Any = None,
|
||||
) -> dict:
|
||||
if not asset_info_row:
|
||||
info = session.get(AssetInfo, asset_info_id)
|
||||
if not info:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
|
||||
norm = normalize_tags(tags)
|
||||
if not norm:
|
||||
total = get_asset_tags(session, asset_info_id=asset_info_id)
|
||||
return {"added": [], "already_present": [], "total_tags": total}
|
||||
|
||||
if create_if_missing:
|
||||
ensure_tags_exist(session, norm, tag_type="user")
|
||||
|
||||
current = {
|
||||
tag_name
|
||||
for (tag_name,) in (
|
||||
session.execute(
|
||||
sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
|
||||
)
|
||||
).all()
|
||||
}
|
||||
|
||||
want = set(norm)
|
||||
to_add = sorted(want - current)
|
||||
|
||||
if to_add:
|
||||
with session.begin_nested() as nested:
|
||||
try:
|
||||
session.add_all(
|
||||
[
|
||||
AssetInfoTag(
|
||||
asset_info_id=asset_info_id,
|
||||
tag_name=t,
|
||||
origin=origin,
|
||||
added_at=utcnow(),
|
||||
)
|
||||
for t in to_add
|
||||
]
|
||||
)
|
||||
session.flush()
|
||||
except IntegrityError:
|
||||
nested.rollback()
|
||||
|
||||
after = set(get_asset_tags(session, asset_info_id=asset_info_id))
|
||||
return {
|
||||
"added": sorted(((after - current) & want)),
|
||||
"already_present": sorted(want & current),
|
||||
"total_tags": sorted(after),
|
||||
}
|
||||
|
||||
|
||||
def remove_tags_from_asset_info(
|
||||
session: Session,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
tags: Sequence[str],
|
||||
) -> dict:
|
||||
info = session.get(AssetInfo, asset_info_id)
|
||||
if not info:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
|
||||
norm = normalize_tags(tags)
|
||||
if not norm:
|
||||
total = get_asset_tags(session, asset_info_id=asset_info_id)
|
||||
return {"removed": [], "not_present": [], "total_tags": total}
|
||||
|
||||
existing = {
|
||||
tag_name
|
||||
for (tag_name,) in (
|
||||
session.execute(
|
||||
sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
|
||||
)
|
||||
).all()
|
||||
}
|
||||
|
||||
to_remove = sorted(set(t for t in norm if t in existing))
|
||||
not_present = sorted(set(t for t in norm if t not in existing))
|
||||
|
||||
if to_remove:
|
||||
session.execute(
|
||||
delete(AssetInfoTag)
|
||||
.where(
|
||||
AssetInfoTag.asset_info_id == asset_info_id,
|
||||
AssetInfoTag.tag_name.in_(to_remove),
|
||||
)
|
||||
)
|
||||
session.flush()
|
||||
|
||||
total = get_asset_tags(session, asset_info_id=asset_info_id)
|
||||
return {"removed": to_remove, "not_present": not_present, "total_tags": total}
|
||||
|
||||
|
||||
def remove_missing_tag_for_asset_id(
|
||||
session: Session,
|
||||
*,
|
||||
asset_id: str,
|
||||
) -> None:
|
||||
session.execute(
|
||||
sa.delete(AssetInfoTag).where(
|
||||
AssetInfoTag.asset_info_id.in_(sa.select(AssetInfo.id).where(AssetInfo.asset_id == asset_id)),
|
||||
AssetInfoTag.tag_name == "missing",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def set_asset_info_preview(
|
||||
session: Session,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
preview_asset_id: str | None = None,
|
||||
) -> None:
|
||||
"""Set or clear preview_id and bump updated_at. Raises on unknown IDs."""
|
||||
info = session.get(AssetInfo, asset_info_id)
|
||||
if not info:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
|
||||
if preview_asset_id is None:
|
||||
info.preview_id = None
|
||||
else:
|
||||
# validate preview asset exists
|
||||
if not session.get(Asset, preview_asset_id):
|
||||
raise ValueError(f"Preview Asset {preview_asset_id} not found")
|
||||
info.preview_id = preview_asset_id
|
||||
|
||||
info.updated_at = utcnow()
|
||||
session.flush()
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import contextlib
|
||||
import os
|
||||
from decimal import Decimal
|
||||
from aiohttp import web
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
@@ -87,6 +88,40 @@ def get_comfy_models_folders() -> list[tuple[str, list[str]]]:
|
||||
targets.append((name, paths))
|
||||
return targets
|
||||
|
||||
def resolve_destination_from_tags(tags: list[str]) -> tuple[str, list[str]]:
|
||||
"""Validates and maps tags -> (base_dir, subdirs_for_fs)"""
|
||||
root = tags[0]
|
||||
if root == "models":
|
||||
if len(tags) < 2:
|
||||
raise ValueError("at least two tags required for model asset")
|
||||
try:
|
||||
bases = folder_paths.folder_names_and_paths[tags[1]][0]
|
||||
except KeyError:
|
||||
raise ValueError(f"unknown model category '{tags[1]}'")
|
||||
if not bases:
|
||||
raise ValueError(f"no base path configured for category '{tags[1]}'")
|
||||
base_dir = os.path.abspath(bases[0])
|
||||
raw_subdirs = tags[2:]
|
||||
else:
|
||||
base_dir = os.path.abspath(
|
||||
folder_paths.get_input_directory() if root == "input" else folder_paths.get_output_directory()
|
||||
)
|
||||
raw_subdirs = tags[1:]
|
||||
for i in raw_subdirs:
|
||||
if i in (".", ".."):
|
||||
raise ValueError("invalid path component in tags")
|
||||
|
||||
return base_dir, raw_subdirs if raw_subdirs else []
|
||||
|
||||
def ensure_within_base(candidate: str, base: str) -> None:
|
||||
cand_abs = os.path.abspath(candidate)
|
||||
base_abs = os.path.abspath(base)
|
||||
try:
|
||||
if os.path.commonpath([cand_abs, base_abs]) != base_abs:
|
||||
raise ValueError("destination escapes base directory")
|
||||
except Exception:
|
||||
raise ValueError("invalid destination path")
|
||||
|
||||
def compute_relative_filename(file_path: str) -> str | None:
|
||||
"""
|
||||
Return the model's path relative to the last well-known folder (the model category),
|
||||
@@ -113,7 +148,6 @@ def compute_relative_filename(file_path: str) -> str | None:
|
||||
return "/".join(inside)
|
||||
return "/".join(parts) # input/output: keep all parts
|
||||
|
||||
|
||||
def get_relative_to_root_category_path_of_asset(file_path: str) -> tuple[Literal["input", "output", "models"], str]:
|
||||
"""Given an absolute or relative file path, determine which root category the path belongs to:
|
||||
- 'input' if the file resides under `folder_paths.get_input_directory()`
|
||||
@@ -215,3 +249,64 @@ def collect_models_files() -> list[str]:
|
||||
if allowed:
|
||||
out.append(abs_path)
|
||||
return out
|
||||
|
||||
def is_scalar(v):
|
||||
if v is None:
|
||||
return True
|
||||
if isinstance(v, bool):
|
||||
return True
|
||||
if isinstance(v, (int, float, Decimal, str)):
|
||||
return True
|
||||
return False
|
||||
|
||||
def project_kv(key: str, value):
|
||||
"""
|
||||
Turn a metadata key/value into typed projection rows.
|
||||
Returns list[dict] with keys:
|
||||
key, ordinal, and one of val_str / val_num / val_bool / val_json (others None)
|
||||
"""
|
||||
rows: list[dict] = []
|
||||
|
||||
def _null_row(ordinal: int) -> dict:
|
||||
return {
|
||||
"key": key, "ordinal": ordinal,
|
||||
"val_str": None, "val_num": None, "val_bool": None, "val_json": None
|
||||
}
|
||||
|
||||
if value is None:
|
||||
rows.append(_null_row(0))
|
||||
return rows
|
||||
|
||||
if is_scalar(value):
|
||||
if isinstance(value, bool):
|
||||
rows.append({"key": key, "ordinal": 0, "val_bool": bool(value)})
|
||||
elif isinstance(value, (int, float, Decimal)):
|
||||
num = value if isinstance(value, Decimal) else Decimal(str(value))
|
||||
rows.append({"key": key, "ordinal": 0, "val_num": num})
|
||||
elif isinstance(value, str):
|
||||
rows.append({"key": key, "ordinal": 0, "val_str": value})
|
||||
else:
|
||||
rows.append({"key": key, "ordinal": 0, "val_json": value})
|
||||
return rows
|
||||
|
||||
if isinstance(value, list):
|
||||
if all(is_scalar(x) for x in value):
|
||||
for i, x in enumerate(value):
|
||||
if x is None:
|
||||
rows.append(_null_row(i))
|
||||
elif isinstance(x, bool):
|
||||
rows.append({"key": key, "ordinal": i, "val_bool": bool(x)})
|
||||
elif isinstance(x, (int, float, Decimal)):
|
||||
num = x if isinstance(x, Decimal) else Decimal(str(x))
|
||||
rows.append({"key": key, "ordinal": i, "val_num": num})
|
||||
elif isinstance(x, str):
|
||||
rows.append({"key": key, "ordinal": i, "val_str": x})
|
||||
else:
|
||||
rows.append({"key": key, "ordinal": i, "val_json": x})
|
||||
return rows
|
||||
for i, x in enumerate(value):
|
||||
rows.append({"key": key, "ordinal": i, "val_json": x})
|
||||
return rows
|
||||
|
||||
rows.append({"key": key, "ordinal": 0, "val_json": value})
|
||||
return rows
|
||||
|
||||
@@ -1,13 +1,33 @@
|
||||
import os
|
||||
import mimetypes
|
||||
import contextlib
|
||||
from typing import Sequence
|
||||
|
||||
from app.database.db import create_session
|
||||
from app.assets.api import schemas_out
|
||||
from app.assets.api import schemas_out, schemas_in
|
||||
from app.assets.database.queries import (
|
||||
asset_exists_by_hash,
|
||||
asset_info_exists_for_asset_id,
|
||||
get_asset_by_hash,
|
||||
get_asset_info_by_id,
|
||||
fetch_asset_info_asset_and_tags,
|
||||
fetch_asset_info_and_asset,
|
||||
create_asset_info_for_existing_asset,
|
||||
touch_asset_info_by_id,
|
||||
update_asset_info_full,
|
||||
delete_asset_info_by_id,
|
||||
list_cache_states_by_asset_id,
|
||||
list_asset_infos_page,
|
||||
list_tags_with_usage,
|
||||
get_asset_tags,
|
||||
add_tags_to_asset_info,
|
||||
remove_tags_from_asset_info,
|
||||
pick_best_live_path,
|
||||
ingest_fs_asset,
|
||||
set_asset_info_preview,
|
||||
)
|
||||
from app.assets.helpers import resolve_destination_from_tags, ensure_within_base
|
||||
from app.assets.database.models import Asset
|
||||
|
||||
|
||||
def _safe_sort_field(requested: str | None) -> str:
|
||||
@@ -19,11 +39,28 @@ def _safe_sort_field(requested: str | None) -> str:
|
||||
return "created_at"
|
||||
|
||||
|
||||
def asset_exists(asset_hash: str) -> bool:
|
||||
def _get_size_mtime_ns(path: str) -> tuple[int, int]:
|
||||
st = os.stat(path, follow_symlinks=True)
|
||||
return st.st_size, getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000))
|
||||
|
||||
|
||||
def _safe_filename(name: str | None, fallback: str) -> str:
|
||||
n = os.path.basename((name or "").strip() or fallback)
|
||||
if n:
|
||||
return n
|
||||
return fallback
|
||||
|
||||
|
||||
def asset_exists(*, asset_hash: str) -> bool:
|
||||
"""
|
||||
Check if an asset with a given hash exists in database.
|
||||
"""
|
||||
with create_session() as session:
|
||||
return asset_exists_by_hash(session, asset_hash=asset_hash)
|
||||
|
||||
|
||||
def list_assets(
|
||||
*,
|
||||
include_tags: Sequence[str] | None = None,
|
||||
exclude_tags: Sequence[str] | None = None,
|
||||
name_contains: str | None = None,
|
||||
@@ -63,7 +100,6 @@ def list_assets(
|
||||
size=int(asset.size_bytes) if asset else None,
|
||||
mime_type=asset.mime_type if asset else None,
|
||||
tags=tags,
|
||||
preview_url=f"/api/assets/{info.id}/content",
|
||||
created_at=info.created_at,
|
||||
updated_at=info.updated_at,
|
||||
last_access_time=info.last_access_time,
|
||||
@@ -76,7 +112,12 @@ def list_assets(
|
||||
has_more=(offset + len(summaries)) < total,
|
||||
)
|
||||
|
||||
def get_asset(asset_info_id: str, owner_id: str = "") -> schemas_out.AssetDetail:
|
||||
|
||||
def get_asset(
|
||||
*,
|
||||
asset_info_id: str,
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.AssetDetail:
|
||||
with create_session() as session:
|
||||
res = fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id)
|
||||
if not res:
|
||||
@@ -97,6 +138,358 @@ def get_asset(asset_info_id: str, owner_id: str = "") -> schemas_out.AssetDetail
|
||||
last_access_time=info.last_access_time,
|
||||
)
|
||||
|
||||
|
||||
def resolve_asset_content_for_download(
|
||||
*,
|
||||
asset_info_id: str,
|
||||
owner_id: str = "",
|
||||
) -> tuple[str, str, str]:
|
||||
with create_session() as session:
|
||||
pair = fetch_asset_info_and_asset(session, asset_info_id=asset_info_id, owner_id=owner_id)
|
||||
if not pair:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
|
||||
info, asset = pair
|
||||
states = list_cache_states_by_asset_id(session, asset_id=asset.id)
|
||||
abs_path = pick_best_live_path(states)
|
||||
if not abs_path:
|
||||
raise FileNotFoundError
|
||||
|
||||
touch_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||
session.commit()
|
||||
|
||||
ctype = asset.mime_type or mimetypes.guess_type(info.name or abs_path)[0] or "application/octet-stream"
|
||||
download_name = info.name or os.path.basename(abs_path)
|
||||
return abs_path, ctype, download_name
|
||||
|
||||
|
||||
def upload_asset_from_temp_path(
|
||||
spec: schemas_in.UploadAssetSpec,
|
||||
*,
|
||||
temp_path: str,
|
||||
client_filename: str | None = None,
|
||||
owner_id: str = "",
|
||||
expected_asset_hash: str | None = None,
|
||||
) -> schemas_out.AssetCreated:
|
||||
"""
|
||||
Create new asset or update existing asset from a temporary file path.
|
||||
"""
|
||||
try:
|
||||
# NOTE: blake3 is not required right now, so this will fail if blake3 is not installed in local environment
|
||||
import app.assets.hashing as hashing
|
||||
digest = hashing.blake3_hash(temp_path)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"failed to hash uploaded file: {e}")
|
||||
asset_hash = "blake3:" + digest
|
||||
|
||||
if expected_asset_hash and asset_hash != expected_asset_hash.strip().lower():
|
||||
raise ValueError("HASH_MISMATCH")
|
||||
|
||||
with create_session() as session:
|
||||
existing = get_asset_by_hash(session, asset_hash=asset_hash)
|
||||
if existing is not None:
|
||||
with contextlib.suppress(Exception):
|
||||
if temp_path and os.path.exists(temp_path):
|
||||
os.remove(temp_path)
|
||||
|
||||
display_name = _safe_filename(spec.name or (client_filename or ""), fallback=digest)
|
||||
info = create_asset_info_for_existing_asset(
|
||||
session,
|
||||
asset_hash=asset_hash,
|
||||
name=display_name,
|
||||
user_metadata=spec.user_metadata or {},
|
||||
tags=spec.tags or [],
|
||||
tag_origin="manual",
|
||||
owner_id=owner_id,
|
||||
)
|
||||
tag_names = get_asset_tags(session, asset_info_id=info.id)
|
||||
session.commit()
|
||||
|
||||
return schemas_out.AssetCreated(
|
||||
id=info.id,
|
||||
name=info.name,
|
||||
asset_hash=existing.hash,
|
||||
size=int(existing.size_bytes) if existing.size_bytes is not None else None,
|
||||
mime_type=existing.mime_type,
|
||||
tags=tag_names,
|
||||
user_metadata=info.user_metadata or {},
|
||||
preview_id=info.preview_id,
|
||||
created_at=info.created_at,
|
||||
last_access_time=info.last_access_time,
|
||||
created_new=False,
|
||||
)
|
||||
|
||||
base_dir, subdirs = resolve_destination_from_tags(spec.tags)
|
||||
dest_dir = os.path.join(base_dir, *subdirs) if subdirs else base_dir
|
||||
os.makedirs(dest_dir, exist_ok=True)
|
||||
|
||||
src_for_ext = (client_filename or spec.name or "").strip()
|
||||
_ext = os.path.splitext(os.path.basename(src_for_ext))[1] if src_for_ext else ""
|
||||
ext = _ext if 0 < len(_ext) <= 16 else ""
|
||||
hashed_basename = f"{digest}{ext}"
|
||||
dest_abs = os.path.abspath(os.path.join(dest_dir, hashed_basename))
|
||||
ensure_within_base(dest_abs, base_dir)
|
||||
|
||||
content_type = (
|
||||
mimetypes.guess_type(os.path.basename(src_for_ext), strict=False)[0]
|
||||
or mimetypes.guess_type(hashed_basename, strict=False)[0]
|
||||
or "application/octet-stream"
|
||||
)
|
||||
|
||||
try:
|
||||
os.replace(temp_path, dest_abs)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"failed to move uploaded file into place: {e}")
|
||||
|
||||
try:
|
||||
size_bytes, mtime_ns = _get_size_mtime_ns(dest_abs)
|
||||
except OSError as e:
|
||||
raise RuntimeError(f"failed to stat destination file: {e}")
|
||||
|
||||
with create_session() as session:
|
||||
result = ingest_fs_asset(
|
||||
session,
|
||||
asset_hash=asset_hash,
|
||||
abs_path=dest_abs,
|
||||
size_bytes=size_bytes,
|
||||
mtime_ns=mtime_ns,
|
||||
mime_type=content_type,
|
||||
info_name=_safe_filename(spec.name or (client_filename or ""), fallback=digest),
|
||||
owner_id=owner_id,
|
||||
preview_id=None,
|
||||
user_metadata=spec.user_metadata or {},
|
||||
tags=spec.tags,
|
||||
tag_origin="manual",
|
||||
require_existing_tags=False,
|
||||
)
|
||||
info_id = result["asset_info_id"]
|
||||
if not info_id:
|
||||
raise RuntimeError("failed to create asset metadata")
|
||||
|
||||
pair = fetch_asset_info_and_asset(session, asset_info_id=info_id, owner_id=owner_id)
|
||||
if not pair:
|
||||
raise RuntimeError("inconsistent DB state after ingest")
|
||||
info, asset = pair
|
||||
tag_names = get_asset_tags(session, asset_info_id=info.id)
|
||||
created_result = schemas_out.AssetCreated(
|
||||
id=info.id,
|
||||
name=info.name,
|
||||
asset_hash=asset.hash,
|
||||
size=int(asset.size_bytes),
|
||||
mime_type=asset.mime_type,
|
||||
tags=tag_names,
|
||||
user_metadata=info.user_metadata or {},
|
||||
preview_id=info.preview_id,
|
||||
created_at=info.created_at,
|
||||
last_access_time=info.last_access_time,
|
||||
created_new=result["asset_created"],
|
||||
)
|
||||
session.commit()
|
||||
|
||||
return created_result
|
||||
|
||||
|
||||
def update_asset(
|
||||
*,
|
||||
asset_info_id: str,
|
||||
name: str | None = None,
|
||||
tags: list[str] | None = None,
|
||||
user_metadata: dict | None = None,
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.AssetUpdated:
|
||||
with create_session() as session:
|
||||
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||
if not info_row:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
if info_row.owner_id and info_row.owner_id != owner_id:
|
||||
raise PermissionError("not owner")
|
||||
|
||||
info = update_asset_info_full(
|
||||
session,
|
||||
asset_info_id=asset_info_id,
|
||||
name=name,
|
||||
tags=tags,
|
||||
user_metadata=user_metadata,
|
||||
tag_origin="manual",
|
||||
asset_info_row=info_row,
|
||||
)
|
||||
|
||||
tag_names = get_asset_tags(session, asset_info_id=asset_info_id)
|
||||
result = schemas_out.AssetUpdated(
|
||||
id=info.id,
|
||||
name=info.name,
|
||||
asset_hash=info.asset.hash if info.asset else None,
|
||||
tags=tag_names,
|
||||
user_metadata=info.user_metadata or {},
|
||||
updated_at=info.updated_at,
|
||||
)
|
||||
session.commit()
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def set_asset_preview(
|
||||
*,
|
||||
asset_info_id: str,
|
||||
preview_asset_id: str | None = None,
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.AssetDetail:
|
||||
with create_session() as session:
|
||||
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||
if not info_row:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
if info_row.owner_id and info_row.owner_id != owner_id:
|
||||
raise PermissionError("not owner")
|
||||
|
||||
set_asset_info_preview(
|
||||
session,
|
||||
asset_info_id=asset_info_id,
|
||||
preview_asset_id=preview_asset_id,
|
||||
)
|
||||
|
||||
res = fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id)
|
||||
if not res:
|
||||
raise RuntimeError("State changed during preview update")
|
||||
info, asset, tags = res
|
||||
result = schemas_out.AssetDetail(
|
||||
id=info.id,
|
||||
name=info.name,
|
||||
asset_hash=asset.hash if asset else None,
|
||||
size=int(asset.size_bytes) if asset and asset.size_bytes is not None else None,
|
||||
mime_type=asset.mime_type if asset else None,
|
||||
tags=tags,
|
||||
user_metadata=info.user_metadata or {},
|
||||
preview_id=info.preview_id,
|
||||
created_at=info.created_at,
|
||||
last_access_time=info.last_access_time,
|
||||
)
|
||||
session.commit()
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def delete_asset_reference(*, asset_info_id: str, owner_id: str, delete_content_if_orphan: bool = True) -> bool:
|
||||
with create_session() as session:
|
||||
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||
asset_id = info_row.asset_id if info_row else None
|
||||
deleted = delete_asset_info_by_id(session, asset_info_id=asset_info_id, owner_id=owner_id)
|
||||
if not deleted:
|
||||
session.commit()
|
||||
return False
|
||||
|
||||
if not delete_content_if_orphan or not asset_id:
|
||||
session.commit()
|
||||
return True
|
||||
|
||||
still_exists = asset_info_exists_for_asset_id(session, asset_id=asset_id)
|
||||
if still_exists:
|
||||
session.commit()
|
||||
return True
|
||||
|
||||
states = list_cache_states_by_asset_id(session, asset_id=asset_id)
|
||||
file_paths = [s.file_path for s in (states or []) if getattr(s, "file_path", None)]
|
||||
|
||||
asset_row = session.get(Asset, asset_id)
|
||||
if asset_row is not None:
|
||||
session.delete(asset_row)
|
||||
|
||||
session.commit()
|
||||
for p in file_paths:
|
||||
with contextlib.suppress(Exception):
|
||||
if p and os.path.isfile(p):
|
||||
os.remove(p)
|
||||
return True
|
||||
|
||||
|
||||
def create_asset_from_hash(
|
||||
*,
|
||||
hash_str: str,
|
||||
name: str,
|
||||
tags: list[str] | None = None,
|
||||
user_metadata: dict | None = None,
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.AssetCreated | None:
|
||||
canonical = hash_str.strip().lower()
|
||||
with create_session() as session:
|
||||
asset = get_asset_by_hash(session, asset_hash=canonical)
|
||||
if not asset:
|
||||
return None
|
||||
|
||||
info = create_asset_info_for_existing_asset(
|
||||
session,
|
||||
asset_hash=canonical,
|
||||
name=_safe_filename(name, fallback=canonical.split(":", 1)[1]),
|
||||
user_metadata=user_metadata or {},
|
||||
tags=tags or [],
|
||||
tag_origin="manual",
|
||||
owner_id=owner_id,
|
||||
)
|
||||
tag_names = get_asset_tags(session, asset_info_id=info.id)
|
||||
result = schemas_out.AssetCreated(
|
||||
id=info.id,
|
||||
name=info.name,
|
||||
asset_hash=asset.hash,
|
||||
size=int(asset.size_bytes),
|
||||
mime_type=asset.mime_type,
|
||||
tags=tag_names,
|
||||
user_metadata=info.user_metadata or {},
|
||||
preview_id=info.preview_id,
|
||||
created_at=info.created_at,
|
||||
last_access_time=info.last_access_time,
|
||||
created_new=False,
|
||||
)
|
||||
session.commit()
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def add_tags_to_asset(
|
||||
*,
|
||||
asset_info_id: str,
|
||||
tags: list[str],
|
||||
origin: str = "manual",
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.TagsAdd:
|
||||
with create_session() as session:
|
||||
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||
if not info_row:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
if info_row.owner_id and info_row.owner_id != owner_id:
|
||||
raise PermissionError("not owner")
|
||||
data = add_tags_to_asset_info(
|
||||
session,
|
||||
asset_info_id=asset_info_id,
|
||||
tags=tags,
|
||||
origin=origin,
|
||||
create_if_missing=True,
|
||||
asset_info_row=info_row,
|
||||
)
|
||||
session.commit()
|
||||
return schemas_out.TagsAdd(**data)
|
||||
|
||||
|
||||
def remove_tags_from_asset(
|
||||
*,
|
||||
asset_info_id: str,
|
||||
tags: list[str],
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.TagsRemove:
|
||||
with create_session() as session:
|
||||
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||
if not info_row:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
if info_row.owner_id and info_row.owner_id != owner_id:
|
||||
raise PermissionError("not owner")
|
||||
|
||||
data = remove_tags_from_asset_info(
|
||||
session,
|
||||
asset_info_id=asset_info_id,
|
||||
tags=tags,
|
||||
)
|
||||
session.commit()
|
||||
return schemas_out.TagsRemove(**data)
|
||||
|
||||
|
||||
def list_tags(
|
||||
prefix: str | None = None,
|
||||
limit: int = 100,
|
||||
|
||||
@@ -27,6 +27,7 @@ def seed_assets(roots: tuple[RootType, ...], enable_logging: bool = False) -> No
|
||||
t_start = time.perf_counter()
|
||||
created = 0
|
||||
skipped_existing = 0
|
||||
orphans_pruned = 0
|
||||
paths: list[str] = []
|
||||
try:
|
||||
existing_paths: set[str] = set()
|
||||
@@ -38,6 +39,11 @@ def seed_assets(roots: tuple[RootType, ...], enable_logging: bool = False) -> No
|
||||
except Exception as e:
|
||||
logging.exception("fast DB scan failed for %s: %s", r, e)
|
||||
|
||||
try:
|
||||
orphans_pruned = _prune_orphaned_assets(roots)
|
||||
except Exception as e:
|
||||
logging.exception("orphan pruning failed: %s", e)
|
||||
|
||||
if "models" in roots:
|
||||
paths.extend(collect_models_files())
|
||||
if "input" in roots:
|
||||
@@ -85,15 +91,43 @@ def seed_assets(roots: tuple[RootType, ...], enable_logging: bool = False) -> No
|
||||
finally:
|
||||
if enable_logging:
|
||||
logging.info(
|
||||
"Assets scan(roots=%s) completed in %.3fs (created=%d, skipped_existing=%d, total_seen=%d)",
|
||||
"Assets scan(roots=%s) completed in %.3fs (created=%d, skipped_existing=%d, orphans_pruned=%d, total_seen=%d)",
|
||||
roots,
|
||||
time.perf_counter() - t_start,
|
||||
created,
|
||||
skipped_existing,
|
||||
orphans_pruned,
|
||||
len(paths),
|
||||
)
|
||||
|
||||
|
||||
def _prune_orphaned_assets(roots: tuple[RootType, ...]) -> int:
|
||||
"""Prune cache states outside configured prefixes, then delete orphaned seed assets."""
|
||||
all_prefixes = [os.path.abspath(p) for r in roots for p in prefixes_for_root(r)]
|
||||
if not all_prefixes:
|
||||
return 0
|
||||
|
||||
def make_prefix_condition(prefix: str):
|
||||
base = prefix if prefix.endswith(os.sep) else prefix + os.sep
|
||||
escaped, esc = escape_like_prefix(base)
|
||||
return AssetCacheState.file_path.like(escaped + "%", escape=esc)
|
||||
|
||||
matches_valid_prefix = sqlalchemy.or_(*[make_prefix_condition(p) for p in all_prefixes])
|
||||
|
||||
orphan_subq = (
|
||||
sqlalchemy.select(Asset.id)
|
||||
.outerjoin(AssetCacheState, AssetCacheState.asset_id == Asset.id)
|
||||
.where(Asset.hash.is_(None), AssetCacheState.id.is_(None))
|
||||
).scalar_subquery()
|
||||
|
||||
with create_session() as sess:
|
||||
sess.execute(sqlalchemy.delete(AssetCacheState).where(~matches_valid_prefix))
|
||||
sess.execute(sqlalchemy.delete(AssetInfo).where(AssetInfo.asset_id.in_(orphan_subq)))
|
||||
result = sess.execute(sqlalchemy.delete(Asset).where(Asset.id.in_(orphan_subq)))
|
||||
sess.commit()
|
||||
return result.rowcount
|
||||
|
||||
|
||||
def _fast_db_consistency_pass(
|
||||
root: RootType,
|
||||
*,
|
||||
|
||||
107
app/node_replace_manager.py
Normal file
107
app/node_replace_manager.py
Normal file
@@ -0,0 +1,107 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from aiohttp import web
|
||||
|
||||
from typing import TYPE_CHECKING, TypedDict
|
||||
if TYPE_CHECKING:
|
||||
from comfy_api.latest._io_public import NodeReplace
|
||||
|
||||
from comfy_execution.graph_utils import is_link
|
||||
import nodes
|
||||
|
||||
class NodeStruct(TypedDict):
|
||||
inputs: dict[str, str | int | float | bool | tuple[str, int]]
|
||||
class_type: str
|
||||
_meta: dict[str, str]
|
||||
|
||||
def copy_node_struct(node_struct: NodeStruct, empty_inputs: bool = False) -> NodeStruct:
|
||||
new_node_struct = node_struct.copy()
|
||||
if empty_inputs:
|
||||
new_node_struct["inputs"] = {}
|
||||
else:
|
||||
new_node_struct["inputs"] = node_struct["inputs"].copy()
|
||||
new_node_struct["_meta"] = node_struct["_meta"].copy()
|
||||
return new_node_struct
|
||||
|
||||
|
||||
class NodeReplaceManager:
|
||||
"""Manages node replacement registrations."""
|
||||
|
||||
def __init__(self):
|
||||
self._replacements: dict[str, list[NodeReplace]] = {}
|
||||
|
||||
def register(self, node_replace: NodeReplace):
|
||||
"""Register a node replacement mapping."""
|
||||
self._replacements.setdefault(node_replace.old_node_id, []).append(node_replace)
|
||||
|
||||
def get_replacement(self, old_node_id: str) -> list[NodeReplace] | None:
|
||||
"""Get replacements for an old node ID."""
|
||||
return self._replacements.get(old_node_id)
|
||||
|
||||
def has_replacement(self, old_node_id: str) -> bool:
|
||||
"""Check if a replacement exists for an old node ID."""
|
||||
return old_node_id in self._replacements
|
||||
|
||||
def apply_replacements(self, prompt: dict[str, NodeStruct]):
|
||||
connections: dict[str, list[tuple[str, str, int]]] = {}
|
||||
need_replacement: set[str] = set()
|
||||
for node_number, node_struct in prompt.items():
|
||||
if "class_type" not in node_struct or "inputs" not in node_struct:
|
||||
continue
|
||||
class_type = node_struct["class_type"]
|
||||
# need replacement if not in NODE_CLASS_MAPPINGS and has replacement
|
||||
if class_type not in nodes.NODE_CLASS_MAPPINGS.keys() and self.has_replacement(class_type):
|
||||
need_replacement.add(node_number)
|
||||
# keep track of connections
|
||||
for input_id, input_value in node_struct["inputs"].items():
|
||||
if is_link(input_value):
|
||||
conn_number = input_value[0]
|
||||
connections.setdefault(conn_number, []).append((node_number, input_id, input_value[1]))
|
||||
for node_number in need_replacement:
|
||||
node_struct = prompt[node_number]
|
||||
class_type = node_struct["class_type"]
|
||||
replacements = self.get_replacement(class_type)
|
||||
if replacements is None:
|
||||
continue
|
||||
# just use the first replacement
|
||||
replacement = replacements[0]
|
||||
new_node_id = replacement.new_node_id
|
||||
# if replacement is not a valid node, skip trying to replace it as will only cause confusion
|
||||
if new_node_id not in nodes.NODE_CLASS_MAPPINGS.keys():
|
||||
continue
|
||||
# first, replace node id (class_type)
|
||||
new_node_struct = copy_node_struct(node_struct, empty_inputs=True)
|
||||
new_node_struct["class_type"] = new_node_id
|
||||
# TODO: consider replacing display_name in _meta as well for error reporting purposes; would need to query node schema
|
||||
# second, replace inputs
|
||||
if replacement.input_mapping is not None:
|
||||
for input_map in replacement.input_mapping:
|
||||
if "set_value" in input_map:
|
||||
new_node_struct["inputs"][input_map["new_id"]] = input_map["set_value"]
|
||||
elif "old_id" in input_map:
|
||||
new_node_struct["inputs"][input_map["new_id"]] = node_struct["inputs"][input_map["old_id"]]
|
||||
# finalize input replacement
|
||||
prompt[node_number] = new_node_struct
|
||||
# third, replace outputs
|
||||
if replacement.output_mapping is not None:
|
||||
# re-mapping outputs requires changing the input values of nodes that receive connections from this one
|
||||
if node_number in connections:
|
||||
for conns in connections[node_number]:
|
||||
conn_node_number, conn_input_id, old_output_idx = conns
|
||||
for output_map in replacement.output_mapping:
|
||||
if output_map["old_idx"] == old_output_idx:
|
||||
new_output_idx = output_map["new_idx"]
|
||||
previous_input = prompt[conn_node_number]["inputs"][conn_input_id]
|
||||
previous_input[1] = new_output_idx
|
||||
|
||||
def as_dict(self):
|
||||
"""Serialize all replacements to dict."""
|
||||
return {
|
||||
k: [v.as_dict() for v in v_list]
|
||||
for k, v_list in self._replacements.items()
|
||||
}
|
||||
|
||||
def add_routes(self, routes):
|
||||
@routes.get("/node_replacements")
|
||||
async def get_node_replacements(request):
|
||||
return web.json_response(self.as_dict())
|
||||
@@ -53,7 +53,7 @@ class SubgraphManager:
|
||||
return entry_id, entry
|
||||
|
||||
async def load_entry_data(self, entry: SubgraphEntry):
|
||||
with open(entry['path'], 'r') as f:
|
||||
with open(entry['path'], 'r', encoding='utf-8') as f:
|
||||
entry['data'] = f.read()
|
||||
return entry
|
||||
|
||||
|
||||
44
blueprints/.glsl/Brightness_and_Contrast_1.frag
Normal file
44
blueprints/.glsl/Brightness_and_Contrast_1.frag
Normal file
@@ -0,0 +1,44 @@
|
||||
#version 300 es
|
||||
precision highp float;
|
||||
|
||||
uniform sampler2D u_image0;
|
||||
uniform float u_float0; // Brightness slider -100..100
|
||||
uniform float u_float1; // Contrast slider -100..100
|
||||
|
||||
in vec2 v_texCoord;
|
||||
out vec4 fragColor;
|
||||
|
||||
const float MID_GRAY = 0.18; // 18% reflectance
|
||||
|
||||
// sRGB gamma 2.2 approximation
|
||||
vec3 srgbToLinear(vec3 c) {
|
||||
return pow(max(c, 0.0), vec3(2.2));
|
||||
}
|
||||
|
||||
vec3 linearToSrgb(vec3 c) {
|
||||
return pow(max(c, 0.0), vec3(1.0/2.2));
|
||||
}
|
||||
|
||||
float mapBrightness(float b) {
|
||||
return clamp(b / 100.0, -1.0, 1.0);
|
||||
}
|
||||
|
||||
float mapContrast(float c) {
|
||||
return clamp(c / 100.0 + 1.0, 0.0, 2.0);
|
||||
}
|
||||
|
||||
void main() {
|
||||
vec4 orig = texture(u_image0, v_texCoord);
|
||||
|
||||
float brightness = mapBrightness(u_float0);
|
||||
float contrast = mapContrast(u_float1);
|
||||
|
||||
vec3 lin = srgbToLinear(orig.rgb);
|
||||
|
||||
lin = (lin - MID_GRAY) * contrast + brightness + MID_GRAY;
|
||||
|
||||
// Convert back to sRGB
|
||||
vec3 result = linearToSrgb(clamp(lin, 0.0, 1.0));
|
||||
|
||||
fragColor = vec4(result, orig.a);
|
||||
}
|
||||
72
blueprints/.glsl/Chromatic_Aberration_16.frag
Normal file
72
blueprints/.glsl/Chromatic_Aberration_16.frag
Normal file
@@ -0,0 +1,72 @@
|
||||
#version 300 es
|
||||
precision highp float;
|
||||
|
||||
uniform sampler2D u_image0;
|
||||
uniform vec2 u_resolution;
|
||||
uniform int u_int0; // Mode
|
||||
uniform float u_float0; // Amount (0 to 100)
|
||||
|
||||
in vec2 v_texCoord;
|
||||
out vec4 fragColor;
|
||||
|
||||
const int MODE_LINEAR = 0;
|
||||
const int MODE_RADIAL = 1;
|
||||
const int MODE_BARREL = 2;
|
||||
const int MODE_SWIRL = 3;
|
||||
const int MODE_DIAGONAL = 4;
|
||||
|
||||
const float AMOUNT_SCALE = 0.0005;
|
||||
const float RADIAL_MULT = 4.0;
|
||||
const float BARREL_MULT = 8.0;
|
||||
const float INV_SQRT2 = 0.70710678118;
|
||||
|
||||
void main() {
|
||||
vec2 uv = v_texCoord;
|
||||
vec4 original = texture(u_image0, uv);
|
||||
|
||||
float amount = u_float0 * AMOUNT_SCALE;
|
||||
|
||||
if (amount < 0.000001) {
|
||||
fragColor = original;
|
||||
return;
|
||||
}
|
||||
|
||||
// Aspect-corrected coordinates for circular effects
|
||||
float aspect = u_resolution.x / u_resolution.y;
|
||||
vec2 centered = uv - 0.5;
|
||||
vec2 corrected = vec2(centered.x * aspect, centered.y);
|
||||
float r = length(corrected);
|
||||
vec2 dir = r > 0.0001 ? corrected / r : vec2(0.0);
|
||||
vec2 offset = vec2(0.0);
|
||||
|
||||
if (u_int0 == MODE_LINEAR) {
|
||||
// Horizontal shift (no aspect correction needed)
|
||||
offset = vec2(amount, 0.0);
|
||||
}
|
||||
else if (u_int0 == MODE_RADIAL) {
|
||||
// Outward from center, stronger at edges
|
||||
offset = dir * r * amount * RADIAL_MULT;
|
||||
offset.x /= aspect; // Convert back to UV space
|
||||
}
|
||||
else if (u_int0 == MODE_BARREL) {
|
||||
// Lens distortion simulation (r² falloff)
|
||||
offset = dir * r * r * amount * BARREL_MULT;
|
||||
offset.x /= aspect; // Convert back to UV space
|
||||
}
|
||||
else if (u_int0 == MODE_SWIRL) {
|
||||
// Perpendicular to radial (rotational aberration)
|
||||
vec2 perp = vec2(-dir.y, dir.x);
|
||||
offset = perp * r * amount * RADIAL_MULT;
|
||||
offset.x /= aspect; // Convert back to UV space
|
||||
}
|
||||
else if (u_int0 == MODE_DIAGONAL) {
|
||||
// 45° offset (no aspect correction needed)
|
||||
offset = vec2(amount, amount) * INV_SQRT2;
|
||||
}
|
||||
|
||||
float red = texture(u_image0, uv + offset).r;
|
||||
float green = original.g;
|
||||
float blue = texture(u_image0, uv - offset).b;
|
||||
|
||||
fragColor = vec4(red, green, blue, original.a);
|
||||
}
|
||||
78
blueprints/.glsl/Color_Adjustment_15.frag
Normal file
78
blueprints/.glsl/Color_Adjustment_15.frag
Normal file
@@ -0,0 +1,78 @@
|
||||
#version 300 es
|
||||
precision highp float;
|
||||
|
||||
uniform sampler2D u_image0;
|
||||
uniform float u_float0; // temperature (-100 to 100)
|
||||
uniform float u_float1; // tint (-100 to 100)
|
||||
uniform float u_float2; // vibrance (-100 to 100)
|
||||
uniform float u_float3; // saturation (-100 to 100)
|
||||
|
||||
in vec2 v_texCoord;
|
||||
out vec4 fragColor;
|
||||
|
||||
const float INPUT_SCALE = 0.01;
|
||||
const float TEMP_TINT_PRIMARY = 0.3;
|
||||
const float TEMP_TINT_SECONDARY = 0.15;
|
||||
const float VIBRANCE_BOOST = 2.0;
|
||||
const float SATURATION_BOOST = 2.0;
|
||||
const float SKIN_PROTECTION = 0.5;
|
||||
const float EPSILON = 0.001;
|
||||
const vec3 LUMA_WEIGHTS = vec3(0.299, 0.587, 0.114);
|
||||
|
||||
void main() {
|
||||
vec4 tex = texture(u_image0, v_texCoord);
|
||||
vec3 color = tex.rgb;
|
||||
|
||||
// Scale inputs: -100/100 → -1/1
|
||||
float temperature = u_float0 * INPUT_SCALE;
|
||||
float tint = u_float1 * INPUT_SCALE;
|
||||
float vibrance = u_float2 * INPUT_SCALE;
|
||||
float saturation = u_float3 * INPUT_SCALE;
|
||||
|
||||
// Temperature (warm/cool): positive = warm, negative = cool
|
||||
color.r += temperature * TEMP_TINT_PRIMARY;
|
||||
color.b -= temperature * TEMP_TINT_PRIMARY;
|
||||
|
||||
// Tint (green/magenta): positive = green, negative = magenta
|
||||
color.g += tint * TEMP_TINT_PRIMARY;
|
||||
color.r -= tint * TEMP_TINT_SECONDARY;
|
||||
color.b -= tint * TEMP_TINT_SECONDARY;
|
||||
|
||||
// Single clamp after temperature/tint
|
||||
color = clamp(color, 0.0, 1.0);
|
||||
|
||||
// Vibrance with skin protection
|
||||
if (vibrance != 0.0) {
|
||||
float maxC = max(color.r, max(color.g, color.b));
|
||||
float minC = min(color.r, min(color.g, color.b));
|
||||
float sat = maxC - minC;
|
||||
float gray = dot(color, LUMA_WEIGHTS);
|
||||
|
||||
if (vibrance < 0.0) {
|
||||
// Desaturate: -100 → gray
|
||||
color = mix(vec3(gray), color, 1.0 + vibrance);
|
||||
} else {
|
||||
// Boost less saturated colors more
|
||||
float vibranceAmt = vibrance * (1.0 - sat);
|
||||
|
||||
// Branchless skin tone protection
|
||||
float isWarmTone = step(color.b, color.g) * step(color.g, color.r);
|
||||
float warmth = (color.r - color.b) / max(maxC, EPSILON);
|
||||
float skinTone = isWarmTone * warmth * sat * (1.0 - sat);
|
||||
vibranceAmt *= (1.0 - skinTone * SKIN_PROTECTION);
|
||||
|
||||
color = mix(vec3(gray), color, 1.0 + vibranceAmt * VIBRANCE_BOOST);
|
||||
}
|
||||
}
|
||||
|
||||
// Saturation
|
||||
if (saturation != 0.0) {
|
||||
float gray = dot(color, LUMA_WEIGHTS);
|
||||
float satMix = saturation < 0.0
|
||||
? 1.0 + saturation // -100 → gray
|
||||
: 1.0 + saturation * SATURATION_BOOST; // +100 → 3x boost
|
||||
color = mix(vec3(gray), color, satMix);
|
||||
}
|
||||
|
||||
fragColor = vec4(clamp(color, 0.0, 1.0), tex.a);
|
||||
}
|
||||
94
blueprints/.glsl/Edge-Preserving_Blur_128.frag
Normal file
94
blueprints/.glsl/Edge-Preserving_Blur_128.frag
Normal file
@@ -0,0 +1,94 @@
|
||||
#version 300 es
|
||||
precision highp float;
|
||||
|
||||
uniform sampler2D u_image0;
|
||||
uniform float u_float0; // Blur radius (0–20, default ~5)
|
||||
uniform float u_float1; // Edge threshold (0–100, default ~30)
|
||||
uniform int u_int0; // Step size (0/1 = every pixel, 2+ = skip pixels)
|
||||
|
||||
in vec2 v_texCoord;
|
||||
out vec4 fragColor;
|
||||
|
||||
const int MAX_RADIUS = 20;
|
||||
const float EPSILON = 0.0001;
|
||||
|
||||
// Perceptual luminance
|
||||
float getLuminance(vec3 rgb) {
|
||||
return dot(rgb, vec3(0.299, 0.587, 0.114));
|
||||
}
|
||||
|
||||
vec4 bilateralFilter(vec2 uv, vec2 texelSize, int radius,
|
||||
float sigmaSpatial, float sigmaColor)
|
||||
{
|
||||
vec4 center = texture(u_image0, uv);
|
||||
vec3 centerRGB = center.rgb;
|
||||
|
||||
float invSpatial2 = -0.5 / (sigmaSpatial * sigmaSpatial);
|
||||
float invColor2 = -0.5 / (sigmaColor * sigmaColor + EPSILON);
|
||||
|
||||
vec3 sumRGB = vec3(0.0);
|
||||
float sumWeight = 0.0;
|
||||
|
||||
int step = max(u_int0, 1);
|
||||
float radius2 = float(radius * radius);
|
||||
|
||||
for (int dy = -MAX_RADIUS; dy <= MAX_RADIUS; dy++) {
|
||||
if (dy < -radius || dy > radius) continue;
|
||||
if (abs(dy) % step != 0) continue;
|
||||
|
||||
for (int dx = -MAX_RADIUS; dx <= MAX_RADIUS; dx++) {
|
||||
if (dx < -radius || dx > radius) continue;
|
||||
if (abs(dx) % step != 0) continue;
|
||||
|
||||
vec2 offset = vec2(float(dx), float(dy));
|
||||
float dist2 = dot(offset, offset);
|
||||
if (dist2 > radius2) continue;
|
||||
|
||||
vec3 sampleRGB = texture(u_image0, uv + offset * texelSize).rgb;
|
||||
|
||||
// Spatial Gaussian
|
||||
float spatialWeight = exp(dist2 * invSpatial2);
|
||||
|
||||
// Perceptual color distance (weighted RGB)
|
||||
vec3 diff = sampleRGB - centerRGB;
|
||||
float colorDist = dot(diff * diff, vec3(0.299, 0.587, 0.114));
|
||||
float colorWeight = exp(colorDist * invColor2);
|
||||
|
||||
float w = spatialWeight * colorWeight;
|
||||
sumRGB += sampleRGB * w;
|
||||
sumWeight += w;
|
||||
}
|
||||
}
|
||||
|
||||
vec3 resultRGB = sumRGB / max(sumWeight, EPSILON);
|
||||
return vec4(resultRGB, center.a); // preserve center alpha
|
||||
}
|
||||
|
||||
void main() {
|
||||
vec2 texelSize = 1.0 / vec2(textureSize(u_image0, 0));
|
||||
|
||||
float radiusF = clamp(u_float0, 0.0, float(MAX_RADIUS));
|
||||
int radius = int(radiusF + 0.5);
|
||||
|
||||
if (radius == 0) {
|
||||
fragColor = texture(u_image0, v_texCoord);
|
||||
return;
|
||||
}
|
||||
|
||||
// Edge threshold → color sigma
|
||||
// Squared curve for better low-end control
|
||||
float t = clamp(u_float1, 0.0, 100.0) / 100.0;
|
||||
t *= t;
|
||||
float sigmaColor = mix(0.01, 0.5, t);
|
||||
|
||||
// Spatial sigma tied to radius
|
||||
float sigmaSpatial = max(radiusF * 0.75, 0.5);
|
||||
|
||||
fragColor = bilateralFilter(
|
||||
v_texCoord,
|
||||
texelSize,
|
||||
radius,
|
||||
sigmaSpatial,
|
||||
sigmaColor
|
||||
);
|
||||
}
|
||||
124
blueprints/.glsl/Film_Grain_15.frag
Normal file
124
blueprints/.glsl/Film_Grain_15.frag
Normal file
@@ -0,0 +1,124 @@
|
||||
#version 300 es
|
||||
precision highp float;
|
||||
|
||||
uniform sampler2D u_image0;
|
||||
uniform vec2 u_resolution;
|
||||
uniform float u_float0; // grain amount [0.0 – 1.0] typical: 0.2–0.8
|
||||
uniform float u_float1; // grain size [0.3 – 3.0] lower = finer grain
|
||||
uniform float u_float2; // color amount [0.0 – 1.0] 0 = monochrome, 1 = RGB grain
|
||||
uniform float u_float3; // luminance bias [0.0 – 1.0] 0 = uniform, 1 = shadows only
|
||||
uniform int u_int0; // noise mode [0 or 1] 0 = smooth, 1 = grainy
|
||||
|
||||
in vec2 v_texCoord;
|
||||
layout(location = 0) out vec4 fragColor0;
|
||||
|
||||
// High-quality integer hash (pcg-like)
|
||||
uint pcg(uint v) {
|
||||
uint state = v * 747796405u + 2891336453u;
|
||||
uint word = ((state >> ((state >> 28u) + 4u)) ^ state) * 277803737u;
|
||||
return (word >> 22u) ^ word;
|
||||
}
|
||||
|
||||
// 2D -> 1D hash input
|
||||
uint hash2d(uvec2 p) {
|
||||
return pcg(p.x + pcg(p.y));
|
||||
}
|
||||
|
||||
// Hash to float [0, 1]
|
||||
float hashf(uvec2 p) {
|
||||
return float(hash2d(p)) / float(0xffffffffu);
|
||||
}
|
||||
|
||||
// Hash to float with offset (for RGB channels)
|
||||
float hashf(uvec2 p, uint offset) {
|
||||
return float(pcg(hash2d(p) + offset)) / float(0xffffffffu);
|
||||
}
|
||||
|
||||
// Convert uniform [0,1] to roughly Gaussian distribution
|
||||
// Using simple approximation: average of multiple samples
|
||||
float toGaussian(uvec2 p) {
|
||||
float sum = hashf(p, 0u) + hashf(p, 1u) + hashf(p, 2u) + hashf(p, 3u);
|
||||
return (sum - 2.0) * 0.7; // Centered, scaled
|
||||
}
|
||||
|
||||
float toGaussian(uvec2 p, uint offset) {
|
||||
float sum = hashf(p, offset) + hashf(p, offset + 1u)
|
||||
+ hashf(p, offset + 2u) + hashf(p, offset + 3u);
|
||||
return (sum - 2.0) * 0.7;
|
||||
}
|
||||
|
||||
// Smooth noise with better interpolation
|
||||
float smoothNoise(vec2 p) {
|
||||
vec2 i = floor(p);
|
||||
vec2 f = fract(p);
|
||||
|
||||
// Quintic interpolation (less banding than cubic)
|
||||
f = f * f * f * (f * (f * 6.0 - 15.0) + 10.0);
|
||||
|
||||
uvec2 ui = uvec2(i);
|
||||
float a = toGaussian(ui);
|
||||
float b = toGaussian(ui + uvec2(1u, 0u));
|
||||
float c = toGaussian(ui + uvec2(0u, 1u));
|
||||
float d = toGaussian(ui + uvec2(1u, 1u));
|
||||
|
||||
return mix(mix(a, b, f.x), mix(c, d, f.x), f.y);
|
||||
}
|
||||
|
||||
float smoothNoise(vec2 p, uint offset) {
|
||||
vec2 i = floor(p);
|
||||
vec2 f = fract(p);
|
||||
|
||||
f = f * f * f * (f * (f * 6.0 - 15.0) + 10.0);
|
||||
|
||||
uvec2 ui = uvec2(i);
|
||||
float a = toGaussian(ui, offset);
|
||||
float b = toGaussian(ui + uvec2(1u, 0u), offset);
|
||||
float c = toGaussian(ui + uvec2(0u, 1u), offset);
|
||||
float d = toGaussian(ui + uvec2(1u, 1u), offset);
|
||||
|
||||
return mix(mix(a, b, f.x), mix(c, d, f.x), f.y);
|
||||
}
|
||||
|
||||
void main() {
|
||||
vec4 color = texture(u_image0, v_texCoord);
|
||||
|
||||
// Luminance (Rec.709)
|
||||
float luma = dot(color.rgb, vec3(0.2126, 0.7152, 0.0722));
|
||||
|
||||
// Grain UV (resolution-independent)
|
||||
vec2 grainUV = v_texCoord * u_resolution / max(u_float1, 0.01);
|
||||
uvec2 grainPixel = uvec2(grainUV);
|
||||
|
||||
float g;
|
||||
vec3 grainRGB;
|
||||
|
||||
if (u_int0 == 1) {
|
||||
// Grainy mode: pure hash noise (no interpolation = no banding)
|
||||
g = toGaussian(grainPixel);
|
||||
grainRGB = vec3(
|
||||
toGaussian(grainPixel, 100u),
|
||||
toGaussian(grainPixel, 200u),
|
||||
toGaussian(grainPixel, 300u)
|
||||
);
|
||||
} else {
|
||||
// Smooth mode: interpolated with quintic curve
|
||||
g = smoothNoise(grainUV);
|
||||
grainRGB = vec3(
|
||||
smoothNoise(grainUV, 100u),
|
||||
smoothNoise(grainUV, 200u),
|
||||
smoothNoise(grainUV, 300u)
|
||||
);
|
||||
}
|
||||
|
||||
// Luminance weighting (less grain in highlights)
|
||||
float lumWeight = mix(1.0, 1.0 - luma, clamp(u_float3, 0.0, 1.0));
|
||||
|
||||
// Strength
|
||||
float strength = u_float0 * 0.15;
|
||||
|
||||
// Color vs monochrome grain
|
||||
vec3 grainColor = mix(vec3(g), grainRGB, clamp(u_float2, 0.0, 1.0));
|
||||
|
||||
color.rgb += grainColor * strength * lumWeight;
|
||||
fragColor0 = vec4(clamp(color.rgb, 0.0, 1.0), color.a);
|
||||
}
|
||||
133
blueprints/.glsl/Glow_30.frag
Normal file
133
blueprints/.glsl/Glow_30.frag
Normal file
@@ -0,0 +1,133 @@
|
||||
#version 300 es
|
||||
precision mediump float;
|
||||
|
||||
uniform sampler2D u_image0;
|
||||
uniform vec2 u_resolution;
|
||||
uniform int u_int0; // Blend mode
|
||||
uniform int u_int1; // Color tint
|
||||
uniform float u_float0; // Intensity
|
||||
uniform float u_float1; // Radius
|
||||
uniform float u_float2; // Threshold
|
||||
|
||||
in vec2 v_texCoord;
|
||||
out vec4 fragColor;
|
||||
|
||||
const int BLEND_ADD = 0;
|
||||
const int BLEND_SCREEN = 1;
|
||||
const int BLEND_SOFT = 2;
|
||||
const int BLEND_OVERLAY = 3;
|
||||
const int BLEND_LIGHTEN = 4;
|
||||
|
||||
const float GOLDEN_ANGLE = 2.39996323;
|
||||
const int MAX_SAMPLES = 48;
|
||||
const vec3 LUMA = vec3(0.299, 0.587, 0.114);
|
||||
|
||||
float hash(vec2 p) {
|
||||
p = fract(p * vec2(123.34, 456.21));
|
||||
p += dot(p, p + 45.32);
|
||||
return fract(p.x * p.y);
|
||||
}
|
||||
|
||||
vec3 hexToRgb(int h) {
|
||||
return vec3(
|
||||
float((h >> 16) & 255),
|
||||
float((h >> 8) & 255),
|
||||
float(h & 255)
|
||||
) * (1.0 / 255.0);
|
||||
}
|
||||
|
||||
vec3 blend(vec3 base, vec3 glow, int mode) {
|
||||
if (mode == BLEND_SCREEN) {
|
||||
return 1.0 - (1.0 - base) * (1.0 - glow);
|
||||
}
|
||||
if (mode == BLEND_SOFT) {
|
||||
return mix(
|
||||
base - (1.0 - 2.0 * glow) * base * (1.0 - base),
|
||||
base + (2.0 * glow - 1.0) * (sqrt(base) - base),
|
||||
step(0.5, glow)
|
||||
);
|
||||
}
|
||||
if (mode == BLEND_OVERLAY) {
|
||||
return mix(
|
||||
2.0 * base * glow,
|
||||
1.0 - 2.0 * (1.0 - base) * (1.0 - glow),
|
||||
step(0.5, base)
|
||||
);
|
||||
}
|
||||
if (mode == BLEND_LIGHTEN) {
|
||||
return max(base, glow);
|
||||
}
|
||||
return base + glow;
|
||||
}
|
||||
|
||||
void main() {
|
||||
vec4 original = texture(u_image0, v_texCoord);
|
||||
|
||||
float intensity = u_float0 * 0.05;
|
||||
float radius = u_float1 * u_float1 * 0.012;
|
||||
|
||||
if (intensity < 0.001 || radius < 0.1) {
|
||||
fragColor = original;
|
||||
return;
|
||||
}
|
||||
|
||||
float threshold = 1.0 - u_float2 * 0.01;
|
||||
float t0 = threshold - 0.15;
|
||||
float t1 = threshold + 0.15;
|
||||
|
||||
vec2 texelSize = 1.0 / u_resolution;
|
||||
float radius2 = radius * radius;
|
||||
|
||||
float sampleScale = clamp(radius * 0.75, 0.35, 1.0);
|
||||
int samples = int(float(MAX_SAMPLES) * sampleScale);
|
||||
|
||||
float noise = hash(gl_FragCoord.xy);
|
||||
float angleOffset = noise * GOLDEN_ANGLE;
|
||||
float radiusJitter = 0.85 + noise * 0.3;
|
||||
|
||||
float ca = cos(GOLDEN_ANGLE);
|
||||
float sa = sin(GOLDEN_ANGLE);
|
||||
vec2 dir = vec2(cos(angleOffset), sin(angleOffset));
|
||||
|
||||
vec3 glow = vec3(0.0);
|
||||
float totalWeight = 0.0;
|
||||
|
||||
// Center tap
|
||||
float centerMask = smoothstep(t0, t1, dot(original.rgb, LUMA));
|
||||
glow += original.rgb * centerMask * 2.0;
|
||||
totalWeight += 2.0;
|
||||
|
||||
for (int i = 1; i < MAX_SAMPLES; i++) {
|
||||
if (i >= samples) break;
|
||||
|
||||
float fi = float(i);
|
||||
float dist = sqrt(fi / float(samples)) * radius * radiusJitter;
|
||||
|
||||
vec2 offset = dir * dist * texelSize;
|
||||
vec3 c = texture(u_image0, v_texCoord + offset).rgb;
|
||||
float mask = smoothstep(t0, t1, dot(c, LUMA));
|
||||
|
||||
float w = 1.0 - (dist * dist) / (radius2 * 1.5);
|
||||
w = max(w, 0.0);
|
||||
w *= w;
|
||||
|
||||
glow += c * mask * w;
|
||||
totalWeight += w;
|
||||
|
||||
dir = vec2(
|
||||
dir.x * ca - dir.y * sa,
|
||||
dir.x * sa + dir.y * ca
|
||||
);
|
||||
}
|
||||
|
||||
glow *= intensity / max(totalWeight, 0.001);
|
||||
|
||||
if (u_int1 > 0) {
|
||||
glow *= hexToRgb(u_int1);
|
||||
}
|
||||
|
||||
vec3 result = blend(original.rgb, glow, u_int0);
|
||||
result += (noise - 0.5) * (1.0 / 255.0);
|
||||
|
||||
fragColor = vec4(clamp(result, 0.0, 1.0), original.a);
|
||||
}
|
||||
222
blueprints/.glsl/Hue_and_Saturation_1.frag
Normal file
222
blueprints/.glsl/Hue_and_Saturation_1.frag
Normal file
@@ -0,0 +1,222 @@
|
||||
#version 300 es
|
||||
precision highp float;
|
||||
|
||||
uniform sampler2D u_image0;
|
||||
uniform int u_int0; // Mode: 0=Master, 1=Reds, 2=Yellows, 3=Greens, 4=Cyans, 5=Blues, 6=Magentas, 7=Colorize
|
||||
uniform int u_int1; // Color Space: 0=HSL, 1=HSB/HSV
|
||||
uniform float u_float0; // Hue (-180 to 180)
|
||||
uniform float u_float1; // Saturation (-100 to 100)
|
||||
uniform float u_float2; // Lightness/Brightness (-100 to 100)
|
||||
uniform float u_float3; // Overlap (0 to 100) - feathering between adjacent color ranges
|
||||
|
||||
in vec2 v_texCoord;
|
||||
out vec4 fragColor;
|
||||
|
||||
// Color range modes
|
||||
const int MODE_MASTER = 0;
|
||||
const int MODE_RED = 1;
|
||||
const int MODE_YELLOW = 2;
|
||||
const int MODE_GREEN = 3;
|
||||
const int MODE_CYAN = 4;
|
||||
const int MODE_BLUE = 5;
|
||||
const int MODE_MAGENTA = 6;
|
||||
const int MODE_COLORIZE = 7;
|
||||
|
||||
// Color space modes
|
||||
const int COLORSPACE_HSL = 0;
|
||||
const int COLORSPACE_HSB = 1;
|
||||
|
||||
const float EPSILON = 0.0001;
|
||||
|
||||
//=============================================================================
|
||||
// RGB <-> HSL Conversions
|
||||
//=============================================================================
|
||||
|
||||
vec3 rgb2hsl(vec3 c) {
|
||||
float maxC = max(max(c.r, c.g), c.b);
|
||||
float minC = min(min(c.r, c.g), c.b);
|
||||
float delta = maxC - minC;
|
||||
|
||||
float h = 0.0;
|
||||
float s = 0.0;
|
||||
float l = (maxC + minC) * 0.5;
|
||||
|
||||
if (delta > EPSILON) {
|
||||
s = l < 0.5
|
||||
? delta / (maxC + minC)
|
||||
: delta / (2.0 - maxC - minC);
|
||||
|
||||
if (maxC == c.r) {
|
||||
h = (c.g - c.b) / delta + (c.g < c.b ? 6.0 : 0.0);
|
||||
} else if (maxC == c.g) {
|
||||
h = (c.b - c.r) / delta + 2.0;
|
||||
} else {
|
||||
h = (c.r - c.g) / delta + 4.0;
|
||||
}
|
||||
h /= 6.0;
|
||||
}
|
||||
|
||||
return vec3(h, s, l);
|
||||
}
|
||||
|
||||
float hue2rgb(float p, float q, float t) {
|
||||
t = fract(t);
|
||||
if (t < 1.0/6.0) return p + (q - p) * 6.0 * t;
|
||||
if (t < 0.5) return q;
|
||||
if (t < 2.0/3.0) return p + (q - p) * (2.0/3.0 - t) * 6.0;
|
||||
return p;
|
||||
}
|
||||
|
||||
vec3 hsl2rgb(vec3 hsl) {
|
||||
if (hsl.y < EPSILON) return vec3(hsl.z);
|
||||
|
||||
float q = hsl.z < 0.5
|
||||
? hsl.z * (1.0 + hsl.y)
|
||||
: hsl.z + hsl.y - hsl.z * hsl.y;
|
||||
float p = 2.0 * hsl.z - q;
|
||||
|
||||
return vec3(
|
||||
hue2rgb(p, q, hsl.x + 1.0/3.0),
|
||||
hue2rgb(p, q, hsl.x),
|
||||
hue2rgb(p, q, hsl.x - 1.0/3.0)
|
||||
);
|
||||
}
|
||||
|
||||
vec3 rgb2hsb(vec3 c) {
|
||||
float maxC = max(max(c.r, c.g), c.b);
|
||||
float minC = min(min(c.r, c.g), c.b);
|
||||
float delta = maxC - minC;
|
||||
|
||||
float h = 0.0;
|
||||
float s = (maxC > EPSILON) ? delta / maxC : 0.0;
|
||||
float b = maxC;
|
||||
|
||||
if (delta > EPSILON) {
|
||||
if (maxC == c.r) {
|
||||
h = (c.g - c.b) / delta + (c.g < c.b ? 6.0 : 0.0);
|
||||
} else if (maxC == c.g) {
|
||||
h = (c.b - c.r) / delta + 2.0;
|
||||
} else {
|
||||
h = (c.r - c.g) / delta + 4.0;
|
||||
}
|
||||
h /= 6.0;
|
||||
}
|
||||
|
||||
return vec3(h, s, b);
|
||||
}
|
||||
|
||||
vec3 hsb2rgb(vec3 hsb) {
|
||||
vec3 rgb = clamp(abs(mod(hsb.x * 6.0 + vec3(0.0, 4.0, 2.0), 6.0) - 3.0) - 1.0, 0.0, 1.0);
|
||||
return hsb.z * mix(vec3(1.0), rgb, hsb.y);
|
||||
}
|
||||
|
||||
//=============================================================================
|
||||
// Color Range Weight Calculation
|
||||
//=============================================================================
|
||||
|
||||
float hueDistance(float a, float b) {
|
||||
float d = abs(a - b);
|
||||
return min(d, 1.0 - d);
|
||||
}
|
||||
|
||||
float getHueWeight(float hue, float center, float overlap) {
|
||||
float baseWidth = 1.0 / 6.0;
|
||||
float feather = baseWidth * overlap;
|
||||
|
||||
float d = hueDistance(hue, center);
|
||||
|
||||
float inner = baseWidth * 0.5;
|
||||
float outer = inner + feather;
|
||||
|
||||
return 1.0 - smoothstep(inner, outer, d);
|
||||
}
|
||||
|
||||
float getModeWeight(float hue, int mode, float overlap) {
|
||||
if (mode == MODE_MASTER || mode == MODE_COLORIZE) return 1.0;
|
||||
|
||||
if (mode == MODE_RED) {
|
||||
return max(
|
||||
getHueWeight(hue, 0.0, overlap),
|
||||
getHueWeight(hue, 1.0, overlap)
|
||||
);
|
||||
}
|
||||
|
||||
float center = float(mode - 1) / 6.0;
|
||||
return getHueWeight(hue, center, overlap);
|
||||
}
|
||||
|
||||
//=============================================================================
|
||||
// Adjustment Functions
|
||||
//=============================================================================
|
||||
|
||||
float adjustLightness(float l, float amount) {
|
||||
return amount > 0.0
|
||||
? l + (1.0 - l) * amount
|
||||
: l + l * amount;
|
||||
}
|
||||
|
||||
float adjustBrightness(float b, float amount) {
|
||||
return clamp(b + amount, 0.0, 1.0);
|
||||
}
|
||||
|
||||
float adjustSaturation(float s, float amount) {
|
||||
return amount > 0.0
|
||||
? s + (1.0 - s) * amount
|
||||
: s + s * amount;
|
||||
}
|
||||
|
||||
vec3 colorize(vec3 rgb, float hue, float sat, float light) {
|
||||
float lum = dot(rgb, vec3(0.299, 0.587, 0.114));
|
||||
float l = adjustLightness(lum, light);
|
||||
|
||||
vec3 hsl = vec3(fract(hue), clamp(sat, 0.0, 1.0), clamp(l, 0.0, 1.0));
|
||||
return hsl2rgb(hsl);
|
||||
}
|
||||
|
||||
//=============================================================================
|
||||
// Main
|
||||
//=============================================================================
|
||||
|
||||
void main() {
|
||||
vec4 original = texture(u_image0, v_texCoord);
|
||||
|
||||
float hueShift = u_float0 / 360.0; // -180..180 -> -0.5..0.5
|
||||
float satAmount = u_float1 / 100.0; // -100..100 -> -1..1
|
||||
float lightAmount= u_float2 / 100.0; // -100..100 -> -1..1
|
||||
float overlap = u_float3 / 100.0; // 0..100 -> 0..1
|
||||
|
||||
vec3 result;
|
||||
|
||||
if (u_int0 == MODE_COLORIZE) {
|
||||
result = colorize(original.rgb, hueShift, satAmount, lightAmount);
|
||||
fragColor = vec4(result, original.a);
|
||||
return;
|
||||
}
|
||||
|
||||
vec3 hsx = (u_int1 == COLORSPACE_HSL)
|
||||
? rgb2hsl(original.rgb)
|
||||
: rgb2hsb(original.rgb);
|
||||
|
||||
float weight = getModeWeight(hsx.x, u_int0, overlap);
|
||||
|
||||
if (u_int0 != MODE_MASTER && hsx.y < EPSILON) {
|
||||
weight = 0.0;
|
||||
}
|
||||
|
||||
if (weight > EPSILON) {
|
||||
float h = fract(hsx.x + hueShift * weight);
|
||||
float s = clamp(adjustSaturation(hsx.y, satAmount * weight), 0.0, 1.0);
|
||||
float v = (u_int1 == COLORSPACE_HSL)
|
||||
? clamp(adjustLightness(hsx.z, lightAmount * weight), 0.0, 1.0)
|
||||
: clamp(adjustBrightness(hsx.z, lightAmount * weight), 0.0, 1.0);
|
||||
|
||||
vec3 adjusted = vec3(h, s, v);
|
||||
result = (u_int1 == COLORSPACE_HSL)
|
||||
? hsl2rgb(adjusted)
|
||||
: hsb2rgb(adjusted);
|
||||
} else {
|
||||
result = original.rgb;
|
||||
}
|
||||
|
||||
fragColor = vec4(result, original.a);
|
||||
}
|
||||
111
blueprints/.glsl/Image_Blur_1.frag
Normal file
111
blueprints/.glsl/Image_Blur_1.frag
Normal file
@@ -0,0 +1,111 @@
|
||||
#version 300 es
|
||||
#pragma passes 2
|
||||
precision highp float;
|
||||
|
||||
// Blur type constants
|
||||
const int BLUR_GAUSSIAN = 0;
|
||||
const int BLUR_BOX = 1;
|
||||
const int BLUR_RADIAL = 2;
|
||||
|
||||
// Radial blur config
|
||||
const int RADIAL_SAMPLES = 12;
|
||||
const float RADIAL_STRENGTH = 0.0003;
|
||||
|
||||
uniform sampler2D u_image0;
|
||||
uniform vec2 u_resolution;
|
||||
uniform int u_int0; // Blur type (BLUR_GAUSSIAN, BLUR_BOX, BLUR_RADIAL)
|
||||
uniform float u_float0; // Blur radius/amount
|
||||
uniform int u_pass; // Pass index (0 = horizontal, 1 = vertical)
|
||||
|
||||
in vec2 v_texCoord;
|
||||
layout(location = 0) out vec4 fragColor0;
|
||||
|
||||
float gaussian(float x, float sigma) {
|
||||
return exp(-(x * x) / (2.0 * sigma * sigma));
|
||||
}
|
||||
|
||||
void main() {
|
||||
vec2 texelSize = 1.0 / u_resolution;
|
||||
float radius = max(u_float0, 0.0);
|
||||
|
||||
// Radial (angular) blur - single pass, doesn't use separable
|
||||
if (u_int0 == BLUR_RADIAL) {
|
||||
// Only execute on first pass
|
||||
if (u_pass > 0) {
|
||||
fragColor0 = texture(u_image0, v_texCoord);
|
||||
return;
|
||||
}
|
||||
|
||||
vec2 center = vec2(0.5);
|
||||
vec2 dir = v_texCoord - center;
|
||||
float dist = length(dir);
|
||||
|
||||
if (dist < 1e-4) {
|
||||
fragColor0 = texture(u_image0, v_texCoord);
|
||||
return;
|
||||
}
|
||||
|
||||
vec4 sum = vec4(0.0);
|
||||
float totalWeight = 0.0;
|
||||
float angleStep = radius * RADIAL_STRENGTH;
|
||||
|
||||
dir /= dist;
|
||||
|
||||
float cosStep = cos(angleStep);
|
||||
float sinStep = sin(angleStep);
|
||||
|
||||
float negAngle = -float(RADIAL_SAMPLES) * angleStep;
|
||||
vec2 rotDir = vec2(
|
||||
dir.x * cos(negAngle) - dir.y * sin(negAngle),
|
||||
dir.x * sin(negAngle) + dir.y * cos(negAngle)
|
||||
);
|
||||
|
||||
for (int i = -RADIAL_SAMPLES; i <= RADIAL_SAMPLES; i++) {
|
||||
vec2 uv = center + rotDir * dist;
|
||||
float w = 1.0 - abs(float(i)) / float(RADIAL_SAMPLES);
|
||||
sum += texture(u_image0, uv) * w;
|
||||
totalWeight += w;
|
||||
|
||||
rotDir = vec2(
|
||||
rotDir.x * cosStep - rotDir.y * sinStep,
|
||||
rotDir.x * sinStep + rotDir.y * cosStep
|
||||
);
|
||||
}
|
||||
|
||||
fragColor0 = sum / max(totalWeight, 0.001);
|
||||
return;
|
||||
}
|
||||
|
||||
// Separable Gaussian / Box blur
|
||||
int samples = int(ceil(radius));
|
||||
|
||||
if (samples == 0) {
|
||||
fragColor0 = texture(u_image0, v_texCoord);
|
||||
return;
|
||||
}
|
||||
|
||||
// Direction: pass 0 = horizontal, pass 1 = vertical
|
||||
vec2 dir = (u_pass == 0) ? vec2(1.0, 0.0) : vec2(0.0, 1.0);
|
||||
|
||||
vec4 color = vec4(0.0);
|
||||
float totalWeight = 0.0;
|
||||
float sigma = radius / 2.0;
|
||||
|
||||
for (int i = -samples; i <= samples; i++) {
|
||||
vec2 offset = dir * float(i) * texelSize;
|
||||
vec4 sample_color = texture(u_image0, v_texCoord + offset);
|
||||
|
||||
float weight;
|
||||
if (u_int0 == BLUR_GAUSSIAN) {
|
||||
weight = gaussian(float(i), sigma);
|
||||
} else {
|
||||
// BLUR_BOX
|
||||
weight = 1.0;
|
||||
}
|
||||
|
||||
color += sample_color * weight;
|
||||
totalWeight += weight;
|
||||
}
|
||||
|
||||
fragColor0 = color / totalWeight;
|
||||
}
|
||||
19
blueprints/.glsl/Image_Channels_23.frag
Normal file
19
blueprints/.glsl/Image_Channels_23.frag
Normal file
@@ -0,0 +1,19 @@
|
||||
#version 300 es
|
||||
precision highp float;
|
||||
|
||||
uniform sampler2D u_image0;
|
||||
|
||||
in vec2 v_texCoord;
|
||||
layout(location = 0) out vec4 fragColor0;
|
||||
layout(location = 1) out vec4 fragColor1;
|
||||
layout(location = 2) out vec4 fragColor2;
|
||||
layout(location = 3) out vec4 fragColor3;
|
||||
|
||||
void main() {
|
||||
vec4 color = texture(u_image0, v_texCoord);
|
||||
// Output each channel as grayscale to separate render targets
|
||||
fragColor0 = vec4(vec3(color.r), 1.0); // Red channel
|
||||
fragColor1 = vec4(vec3(color.g), 1.0); // Green channel
|
||||
fragColor2 = vec4(vec3(color.b), 1.0); // Blue channel
|
||||
fragColor3 = vec4(vec3(color.a), 1.0); // Alpha channel
|
||||
}
|
||||
71
blueprints/.glsl/Image_Levels_1.frag
Normal file
71
blueprints/.glsl/Image_Levels_1.frag
Normal file
@@ -0,0 +1,71 @@
|
||||
#version 300 es
|
||||
precision highp float;
|
||||
|
||||
// Levels Adjustment
|
||||
// u_int0: channel (0=RGB, 1=R, 2=G, 3=B) default: 0
|
||||
// u_float0: input black (0-255) default: 0
|
||||
// u_float1: input white (0-255) default: 255
|
||||
// u_float2: gamma (0.01-9.99) default: 1.0
|
||||
// u_float3: output black (0-255) default: 0
|
||||
// u_float4: output white (0-255) default: 255
|
||||
|
||||
uniform sampler2D u_image0;
|
||||
uniform int u_int0;
|
||||
uniform float u_float0;
|
||||
uniform float u_float1;
|
||||
uniform float u_float2;
|
||||
uniform float u_float3;
|
||||
uniform float u_float4;
|
||||
|
||||
in vec2 v_texCoord;
|
||||
out vec4 fragColor;
|
||||
|
||||
vec3 applyLevels(vec3 color, float inBlack, float inWhite, float gamma, float outBlack, float outWhite) {
|
||||
float inRange = max(inWhite - inBlack, 0.0001);
|
||||
vec3 result = clamp((color - inBlack) / inRange, 0.0, 1.0);
|
||||
result = pow(result, vec3(1.0 / gamma));
|
||||
result = mix(vec3(outBlack), vec3(outWhite), result);
|
||||
return result;
|
||||
}
|
||||
|
||||
float applySingleChannel(float value, float inBlack, float inWhite, float gamma, float outBlack, float outWhite) {
|
||||
float inRange = max(inWhite - inBlack, 0.0001);
|
||||
float result = clamp((value - inBlack) / inRange, 0.0, 1.0);
|
||||
result = pow(result, 1.0 / gamma);
|
||||
result = mix(outBlack, outWhite, result);
|
||||
return result;
|
||||
}
|
||||
|
||||
void main() {
|
||||
vec4 texColor = texture(u_image0, v_texCoord);
|
||||
vec3 color = texColor.rgb;
|
||||
|
||||
float inBlack = u_float0 / 255.0;
|
||||
float inWhite = u_float1 / 255.0;
|
||||
float gamma = u_float2;
|
||||
float outBlack = u_float3 / 255.0;
|
||||
float outWhite = u_float4 / 255.0;
|
||||
|
||||
vec3 result;
|
||||
|
||||
if (u_int0 == 0) {
|
||||
result = applyLevels(color, inBlack, inWhite, gamma, outBlack, outWhite);
|
||||
}
|
||||
else if (u_int0 == 1) {
|
||||
result = color;
|
||||
result.r = applySingleChannel(color.r, inBlack, inWhite, gamma, outBlack, outWhite);
|
||||
}
|
||||
else if (u_int0 == 2) {
|
||||
result = color;
|
||||
result.g = applySingleChannel(color.g, inBlack, inWhite, gamma, outBlack, outWhite);
|
||||
}
|
||||
else if (u_int0 == 3) {
|
||||
result = color;
|
||||
result.b = applySingleChannel(color.b, inBlack, inWhite, gamma, outBlack, outWhite);
|
||||
}
|
||||
else {
|
||||
result = color;
|
||||
}
|
||||
|
||||
fragColor = vec4(result, texColor.a);
|
||||
}
|
||||
28
blueprints/.glsl/README.md
Normal file
28
blueprints/.glsl/README.md
Normal file
@@ -0,0 +1,28 @@
|
||||
# GLSL Shader Sources
|
||||
|
||||
This folder contains the GLSL fragment shaders extracted from blueprint JSON files for easier editing and version control.
|
||||
|
||||
## File Naming Convention
|
||||
|
||||
`{Blueprint_Name}_{node_id}.frag`
|
||||
|
||||
- **Blueprint_Name**: The JSON filename with spaces/special chars replaced by underscores
|
||||
- **node_id**: The GLSLShader node ID within the subgraph
|
||||
|
||||
## Usage
|
||||
|
||||
```bash
|
||||
# Extract shaders from blueprint JSONs to this folder
|
||||
python update_blueprints.py extract
|
||||
|
||||
# Patch edited shaders back into blueprint JSONs
|
||||
python update_blueprints.py patch
|
||||
```
|
||||
|
||||
## Workflow
|
||||
|
||||
1. Run `extract` to pull current shaders from JSONs
|
||||
2. Edit `.frag` files
|
||||
3. Run `patch` to update the blueprint JSONs
|
||||
4. Test
|
||||
5. Commit both `.frag` files and updated JSONs
|
||||
28
blueprints/.glsl/Sharpen_23.frag
Normal file
28
blueprints/.glsl/Sharpen_23.frag
Normal file
@@ -0,0 +1,28 @@
|
||||
#version 300 es
|
||||
precision highp float;
|
||||
|
||||
uniform sampler2D u_image0;
|
||||
uniform vec2 u_resolution;
|
||||
uniform float u_float0; // strength [0.0 – 2.0] typical: 0.3–1.0
|
||||
|
||||
in vec2 v_texCoord;
|
||||
layout(location = 0) out vec4 fragColor0;
|
||||
|
||||
void main() {
|
||||
vec2 texel = 1.0 / u_resolution;
|
||||
|
||||
// Sample center and neighbors
|
||||
vec4 center = texture(u_image0, v_texCoord);
|
||||
vec4 top = texture(u_image0, v_texCoord + vec2( 0.0, -texel.y));
|
||||
vec4 bottom = texture(u_image0, v_texCoord + vec2( 0.0, texel.y));
|
||||
vec4 left = texture(u_image0, v_texCoord + vec2(-texel.x, 0.0));
|
||||
vec4 right = texture(u_image0, v_texCoord + vec2( texel.x, 0.0));
|
||||
|
||||
// Edge enhancement (Laplacian)
|
||||
vec4 edges = center * 4.0 - top - bottom - left - right;
|
||||
|
||||
// Add edges back scaled by strength
|
||||
vec4 sharpened = center + edges * u_float0;
|
||||
|
||||
fragColor0 = vec4(clamp(sharpened.rgb, 0.0, 1.0), center.a);
|
||||
}
|
||||
61
blueprints/.glsl/Unsharp_Mask_26.frag
Normal file
61
blueprints/.glsl/Unsharp_Mask_26.frag
Normal file
@@ -0,0 +1,61 @@
|
||||
#version 300 es
|
||||
precision highp float;
|
||||
|
||||
uniform sampler2D u_image0;
|
||||
uniform vec2 u_resolution;
|
||||
uniform float u_float0; // amount [0.0 - 3.0] typical: 0.5-1.5
|
||||
uniform float u_float1; // radius [0.5 - 10.0] blur radius in pixels
|
||||
uniform float u_float2; // threshold [0.0 - 0.1] min difference to sharpen
|
||||
|
||||
in vec2 v_texCoord;
|
||||
layout(location = 0) out vec4 fragColor0;
|
||||
|
||||
float gaussian(float x, float sigma) {
|
||||
return exp(-(x * x) / (2.0 * sigma * sigma));
|
||||
}
|
||||
|
||||
float getLuminance(vec3 color) {
|
||||
return dot(color, vec3(0.2126, 0.7152, 0.0722));
|
||||
}
|
||||
|
||||
void main() {
|
||||
vec2 texel = 1.0 / u_resolution;
|
||||
float radius = max(u_float1, 0.5);
|
||||
float amount = u_float0;
|
||||
float threshold = u_float2;
|
||||
|
||||
vec4 original = texture(u_image0, v_texCoord);
|
||||
|
||||
// Gaussian blur for the "unsharp" mask
|
||||
int samples = int(ceil(radius));
|
||||
float sigma = radius / 2.0;
|
||||
|
||||
vec4 blurred = vec4(0.0);
|
||||
float totalWeight = 0.0;
|
||||
|
||||
for (int x = -samples; x <= samples; x++) {
|
||||
for (int y = -samples; y <= samples; y++) {
|
||||
vec2 offset = vec2(float(x), float(y)) * texel;
|
||||
vec4 sample_color = texture(u_image0, v_texCoord + offset);
|
||||
|
||||
float dist = length(vec2(float(x), float(y)));
|
||||
float weight = gaussian(dist, sigma);
|
||||
blurred += sample_color * weight;
|
||||
totalWeight += weight;
|
||||
}
|
||||
}
|
||||
blurred /= totalWeight;
|
||||
|
||||
// Unsharp mask = original - blurred
|
||||
vec3 mask = original.rgb - blurred.rgb;
|
||||
|
||||
// Luminance-based threshold with smooth falloff
|
||||
float lumaDelta = abs(getLuminance(original.rgb) - getLuminance(blurred.rgb));
|
||||
float thresholdScale = smoothstep(0.0, threshold, lumaDelta);
|
||||
mask *= thresholdScale;
|
||||
|
||||
// Sharpen: original + mask * amount
|
||||
vec3 sharpened = original.rgb + mask * amount;
|
||||
|
||||
fragColor0 = vec4(clamp(sharpened, 0.0, 1.0), original.a);
|
||||
}
|
||||
159
blueprints/.glsl/update_blueprints.py
Normal file
159
blueprints/.glsl/update_blueprints.py
Normal file
@@ -0,0 +1,159 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Shader Blueprint Updater
|
||||
|
||||
Syncs GLSL shader files between this folder and blueprint JSON files.
|
||||
|
||||
File naming convention:
|
||||
{Blueprint Name}_{node_id}.frag
|
||||
|
||||
Usage:
|
||||
python update_blueprints.py extract # Extract shaders from JSONs to here
|
||||
python update_blueprints.py patch # Patch shaders back into JSONs
|
||||
python update_blueprints.py # Same as patch (default)
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format='%(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
GLSL_DIR = Path(__file__).parent
|
||||
BLUEPRINTS_DIR = GLSL_DIR.parent
|
||||
|
||||
|
||||
def get_blueprint_files():
|
||||
"""Get all blueprint JSON files."""
|
||||
return sorted(BLUEPRINTS_DIR.glob("*.json"))
|
||||
|
||||
|
||||
def sanitize_filename(name):
|
||||
"""Convert blueprint name to safe filename."""
|
||||
return re.sub(r'[^\w\-]', '_', name)
|
||||
|
||||
|
||||
def extract_shaders():
|
||||
"""Extract all shaders from blueprint JSONs to this folder."""
|
||||
extracted = 0
|
||||
for json_path in get_blueprint_files():
|
||||
blueprint_name = json_path.stem
|
||||
|
||||
try:
|
||||
with open(json_path, 'r') as f:
|
||||
data = json.load(f)
|
||||
except (json.JSONDecodeError, IOError) as e:
|
||||
logger.warning("Skipping %s: %s", json_path.name, e)
|
||||
continue
|
||||
|
||||
# Find GLSLShader nodes in subgraphs
|
||||
for subgraph in data.get('definitions', {}).get('subgraphs', []):
|
||||
for node in subgraph.get('nodes', []):
|
||||
if node.get('type') == 'GLSLShader':
|
||||
node_id = node.get('id')
|
||||
widgets = node.get('widgets_values', [])
|
||||
|
||||
# Find shader code (first string that looks like GLSL)
|
||||
for widget in widgets:
|
||||
if isinstance(widget, str) and widget.startswith('#version'):
|
||||
safe_name = sanitize_filename(blueprint_name)
|
||||
frag_name = f"{safe_name}_{node_id}.frag"
|
||||
frag_path = GLSL_DIR / frag_name
|
||||
|
||||
with open(frag_path, 'w') as f:
|
||||
f.write(widget)
|
||||
|
||||
logger.info(" Extracted: %s", frag_name)
|
||||
extracted += 1
|
||||
break
|
||||
|
||||
logger.info("\nExtracted %d shader(s)", extracted)
|
||||
|
||||
|
||||
def patch_shaders():
|
||||
"""Patch shaders from this folder back into blueprint JSONs."""
|
||||
# Build lookup: blueprint_name -> [(node_id, shader_code), ...]
|
||||
shader_updates = {}
|
||||
|
||||
for frag_path in sorted(GLSL_DIR.glob("*.frag")):
|
||||
# Parse filename: {blueprint_name}_{node_id}.frag
|
||||
parts = frag_path.stem.rsplit('_', 1)
|
||||
if len(parts) != 2:
|
||||
logger.warning("Skipping %s: invalid filename format", frag_path.name)
|
||||
continue
|
||||
|
||||
blueprint_name, node_id_str = parts
|
||||
|
||||
try:
|
||||
node_id = int(node_id_str)
|
||||
except ValueError:
|
||||
logger.warning("Skipping %s: invalid node_id", frag_path.name)
|
||||
continue
|
||||
|
||||
with open(frag_path, 'r') as f:
|
||||
shader_code = f.read()
|
||||
|
||||
if blueprint_name not in shader_updates:
|
||||
shader_updates[blueprint_name] = []
|
||||
shader_updates[blueprint_name].append((node_id, shader_code))
|
||||
|
||||
# Apply updates to JSON files
|
||||
patched = 0
|
||||
for json_path in get_blueprint_files():
|
||||
blueprint_name = sanitize_filename(json_path.stem)
|
||||
|
||||
if blueprint_name not in shader_updates:
|
||||
continue
|
||||
|
||||
try:
|
||||
with open(json_path, 'r') as f:
|
||||
data = json.load(f)
|
||||
except (json.JSONDecodeError, IOError) as e:
|
||||
logger.error("Error reading %s: %s", json_path.name, e)
|
||||
continue
|
||||
|
||||
modified = False
|
||||
for node_id, shader_code in shader_updates[blueprint_name]:
|
||||
# Find the node and update
|
||||
for subgraph in data.get('definitions', {}).get('subgraphs', []):
|
||||
for node in subgraph.get('nodes', []):
|
||||
if node.get('id') == node_id and node.get('type') == 'GLSLShader':
|
||||
widgets = node.get('widgets_values', [])
|
||||
if len(widgets) > 0 and widgets[0] != shader_code:
|
||||
widgets[0] = shader_code
|
||||
modified = True
|
||||
logger.info(" Patched: %s (node %d)", json_path.name, node_id)
|
||||
patched += 1
|
||||
|
||||
if modified:
|
||||
with open(json_path, 'w') as f:
|
||||
json.dump(data, f)
|
||||
|
||||
if patched == 0:
|
||||
logger.info("No changes to apply.")
|
||||
else:
|
||||
logger.info("\nPatched %d shader(s)", patched)
|
||||
|
||||
|
||||
def main():
|
||||
if len(sys.argv) < 2:
|
||||
command = "patch"
|
||||
else:
|
||||
command = sys.argv[1].lower()
|
||||
|
||||
if command == "extract":
|
||||
logger.info("Extracting shaders from blueprints...")
|
||||
extract_shaders()
|
||||
elif command in ("patch", "update", "apply"):
|
||||
logger.info("Patching shaders into blueprints...")
|
||||
patch_shaders()
|
||||
else:
|
||||
logger.info(__doc__)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
1
blueprints/Brightness and Contrast.json
Normal file
1
blueprints/Brightness and Contrast.json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Canny to Image (Z-Image-Turbo).json
Normal file
1
blueprints/Canny to Image (Z-Image-Turbo).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Canny to Video (LTX 2.0).json
Normal file
1
blueprints/Canny to Video (LTX 2.0).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Chromatic Aberration.json
Normal file
1
blueprints/Chromatic Aberration.json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Color Adjustment.json
Normal file
1
blueprints/Color Adjustment.json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Depth to Image (Z-Image-Turbo).json
Normal file
1
blueprints/Depth to Image (Z-Image-Turbo).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Depth to Video (ltx 2.0).json
Normal file
1
blueprints/Depth to Video (ltx 2.0).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Edge-Preserving Blur.json
Normal file
1
blueprints/Edge-Preserving Blur.json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Film Grain.json
Normal file
1
blueprints/Film Grain.json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Glow.json
Normal file
1
blueprints/Glow.json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Hue and Saturation.json
Normal file
1
blueprints/Hue and Saturation.json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Image Blur.json
Normal file
1
blueprints/Image Blur.json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Image Captioning (gemini).json
Normal file
1
blueprints/Image Captioning (gemini).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Image Channels.json
Normal file
1
blueprints/Image Channels.json
Normal file
@@ -0,0 +1 @@
|
||||
{"revision": 0, "last_node_id": 29, "last_link_id": 0, "nodes": [{"id": 29, "type": "4c9d6ea4-b912-40e5-8766-6793a9758c53", "pos": [1970, -230], "size": [180, 86], "flags": {}, "order": 5, "mode": 0, "inputs": [{"label": "image", "localized_name": "images.image0", "name": "images.image0", "type": "IMAGE", "link": null}], "outputs": [{"label": "R", "localized_name": "IMAGE0", "name": "IMAGE0", "type": "IMAGE", "links": []}, {"label": "G", "localized_name": "IMAGE1", "name": "IMAGE1", "type": "IMAGE", "links": []}, {"label": "B", "localized_name": "IMAGE2", "name": "IMAGE2", "type": "IMAGE", "links": []}, {"label": "A", "localized_name": "IMAGE3", "name": "IMAGE3", "type": "IMAGE", "links": []}], "title": "Image Channels", "properties": {"proxyWidgets": []}, "widgets_values": []}], "links": [], "version": 0.4, "definitions": {"subgraphs": [{"id": "4c9d6ea4-b912-40e5-8766-6793a9758c53", "version": 1, "state": {"lastGroupId": 0, "lastNodeId": 28, "lastLinkId": 39, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "Image Channels", "inputNode": {"id": -10, "bounding": [1820, -185, 120, 60]}, "outputNode": {"id": -20, "bounding": [2460, -215, 120, 120]}, "inputs": [{"id": "3522932b-2d86-4a1f-a02a-cb29f3a9d7fe", "name": "images.image0", "type": "IMAGE", "linkIds": [39], "localized_name": "images.image0", "label": "image", "pos": [1920, -165]}], "outputs": [{"id": "605cb9c3-b065-4d9b-81d2-3ec331889b2b", "name": "IMAGE0", "type": "IMAGE", "linkIds": [26], "localized_name": "IMAGE0", "label": "R", "pos": [2480, -195]}, {"id": "fb44a77e-0522-43e9-9527-82e7465b3596", "name": "IMAGE1", "type": "IMAGE", "linkIds": [27], "localized_name": "IMAGE1", "label": "G", "pos": [2480, -175]}, {"id": "81460ee6-0131-402a-874f-6bf3001fc4ff", "name": "IMAGE2", "type": "IMAGE", "linkIds": [28], "localized_name": "IMAGE2", "label": "B", "pos": [2480, -155]}, {"id": "ae690246-80d4-4951-b1d9-9306d8a77417", "name": "IMAGE3", "type": "IMAGE", "linkIds": [29], "localized_name": "IMAGE3", "label": "A", "pos": [2480, -135]}], "widgets": [], "nodes": [{"id": 23, "type": "GLSLShader", "pos": [2000, -330], "size": [400, 172], "flags": {}, "order": 0, "mode": 0, "inputs": [{"label": "image", "localized_name": "images.image0", "name": "images.image0", "type": "IMAGE", "link": 39}, {"localized_name": "fragment_shader", "name": "fragment_shader", "type": "STRING", "widget": {"name": "fragment_shader"}, "link": null}, {"localized_name": "size_mode", "name": "size_mode", "type": "COMFY_DYNAMICCOMBO_V3", "widget": {"name": "size_mode"}, "link": null}, {"label": "image1", "localized_name": "images.image1", "name": "images.image1", "shape": 7, "type": "IMAGE", "link": null}], "outputs": [{"label": "R", "localized_name": "IMAGE0", "name": "IMAGE0", "type": "IMAGE", "links": [26]}, {"label": "G", "localized_name": "IMAGE1", "name": "IMAGE1", "type": "IMAGE", "links": [27]}, {"label": "B", "localized_name": "IMAGE2", "name": "IMAGE2", "type": "IMAGE", "links": [28]}, {"label": "A", "localized_name": "IMAGE3", "name": "IMAGE3", "type": "IMAGE", "links": [29]}], "properties": {"Node name for S&R": "GLSLShader"}, "widgets_values": ["#version 300 es\nprecision highp float;\n\nuniform sampler2D u_image0;\n\nin vec2 v_texCoord;\nlayout(location = 0) out vec4 fragColor0;\nlayout(location = 1) out vec4 fragColor1;\nlayout(location = 2) out vec4 fragColor2;\nlayout(location = 3) out vec4 fragColor3;\n\nvoid main() {\n vec4 color = texture(u_image0, v_texCoord);\n // Output each channel as grayscale to separate render targets\n fragColor0 = vec4(vec3(color.r), 1.0); // Red channel\n fragColor1 = vec4(vec3(color.g), 1.0); // Green channel\n fragColor2 = vec4(vec3(color.b), 1.0); // Blue channel\n fragColor3 = vec4(vec3(color.a), 1.0); // Alpha channel\n}\n", "from_input"]}], "groups": [], "links": [{"id": 39, "origin_id": -10, "origin_slot": 0, "target_id": 23, "target_slot": 0, "type": "IMAGE"}, {"id": 26, "origin_id": 23, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "IMAGE"}, {"id": 27, "origin_id": 23, "origin_slot": 1, "target_id": -20, "target_slot": 1, "type": "IMAGE"}, {"id": 28, "origin_id": 23, "origin_slot": 2, "target_id": -20, "target_slot": 2, "type": "IMAGE"}, {"id": 29, "origin_id": 23, "origin_slot": 3, "target_id": -20, "target_slot": 3, "type": "IMAGE"}], "extra": {"workflowRendererVersion": "LG"}, "category": "Image Tools/Color adjust"}]}}
|
||||
1
blueprints/Image Edit (Flux.2 Klein 4B).json
Normal file
1
blueprints/Image Edit (Flux.2 Klein 4B).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Image Edit (Qwen 2511).json
Normal file
1
blueprints/Image Edit (Qwen 2511).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Image Inpainting (Qwen-image).json
Normal file
1
blueprints/Image Inpainting (Qwen-image).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Image Levels.json
Normal file
1
blueprints/Image Levels.json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Image Outpainting (Qwen-Image).json
Normal file
1
blueprints/Image Outpainting (Qwen-Image).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Image Upscale(Z-image-Turbo).json
Normal file
1
blueprints/Image Upscale(Z-image-Turbo).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Image to Depth Map (Lotus).json
Normal file
1
blueprints/Image to Depth Map (Lotus).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Image to Layers(Qwen-Image Layered).json
Normal file
1
blueprints/Image to Layers(Qwen-Image Layered).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Image to Model (Hunyuan3d 2.1).json
Normal file
1
blueprints/Image to Model (Hunyuan3d 2.1).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Image to Video (Wan 2.2).json
Normal file
1
blueprints/Image to Video (Wan 2.2).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Pose to Image (Z-Image-Turbo).json
Normal file
1
blueprints/Pose to Image (Z-Image-Turbo).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Pose to Video (LTX 2.0).json
Normal file
1
blueprints/Pose to Video (LTX 2.0).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Prompt Enhance.json
Normal file
1
blueprints/Prompt Enhance.json
Normal file
@@ -0,0 +1 @@
|
||||
{"revision": 0, "last_node_id": 15, "last_link_id": 0, "nodes": [{"id": 15, "type": "24d8bbfd-39d4-4774-bff0-3de40cc7a471", "pos": [-1490, 2040], "size": [400, 260], "flags": {}, "order": 0, "mode": 0, "inputs": [{"name": "prompt", "type": "STRING", "widget": {"name": "prompt"}, "link": null}, {"label": "reference images", "name": "images", "type": "IMAGE", "link": null}], "outputs": [{"name": "STRING", "type": "STRING", "links": null}], "title": "Prompt Enhance", "properties": {"proxyWidgets": [["-1", "prompt"]], "cnr_id": "comfy-core", "ver": "0.14.1"}, "widgets_values": [""]}], "links": [], "version": 0.4, "definitions": {"subgraphs": [{"id": "24d8bbfd-39d4-4774-bff0-3de40cc7a471", "version": 1, "state": {"lastGroupId": 0, "lastNodeId": 15, "lastLinkId": 14, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "Prompt Enhance", "inputNode": {"id": -10, "bounding": [-2170, 2110, 138.876953125, 80]}, "outputNode": {"id": -20, "bounding": [-640, 2110, 120, 60]}, "inputs": [{"id": "aeab7216-00e0-4528-a09b-bba50845c5a6", "name": "prompt", "type": "STRING", "linkIds": [11], "pos": [-2051.123046875, 2130]}, {"id": "7b73fd36-aa31-4771-9066-f6c83879994b", "name": "images", "type": "IMAGE", "linkIds": [14], "label": "reference images", "pos": [-2051.123046875, 2150]}], "outputs": [{"id": "c7b0d930-68a1-48d1-b496-0519e5837064", "name": "STRING", "type": "STRING", "linkIds": [13], "pos": [-620, 2130]}], "widgets": [], "nodes": [{"id": 11, "type": "GeminiNode", "pos": [-1560, 1990], "size": [470, 470], "flags": {}, "order": 0, "mode": 0, "inputs": [{"localized_name": "images", "name": "images", "shape": 7, "type": "IMAGE", "link": 14}, {"localized_name": "audio", "name": "audio", "shape": 7, "type": "AUDIO", "link": null}, {"localized_name": "video", "name": "video", "shape": 7, "type": "VIDEO", "link": null}, {"localized_name": "files", "name": "files", "shape": 7, "type": "GEMINI_INPUT_FILES", "link": null}, {"localized_name": "prompt", "name": "prompt", "type": "STRING", "widget": {"name": "prompt"}, "link": 11}, {"localized_name": "model", "name": "model", "type": "COMBO", "widget": {"name": "model"}, "link": null}, {"localized_name": "seed", "name": "seed", "type": "INT", "widget": {"name": "seed"}, "link": null}, {"localized_name": "system_prompt", "name": "system_prompt", "shape": 7, "type": "STRING", "widget": {"name": "system_prompt"}, "link": null}], "outputs": [{"localized_name": "STRING", "name": "STRING", "type": "STRING", "links": [13]}], "properties": {"cnr_id": "comfy-core", "ver": "0.14.1", "Node name for S&R": "GeminiNode"}, "widgets_values": ["", "gemini-3-pro-preview", 42, "randomize", "You are an expert in prompt writing.\nBased on the input, rewrite the user's input into a detailed prompt.\nincluding camera settings, lighting, composition, and style.\nReturn the prompt only"], "color": "#432", "bgcolor": "#653"}], "groups": [], "links": [{"id": 11, "origin_id": -10, "origin_slot": 0, "target_id": 11, "target_slot": 4, "type": "STRING"}, {"id": 13, "origin_id": 11, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "STRING"}, {"id": 14, "origin_id": -10, "origin_slot": 1, "target_id": 11, "target_slot": 0, "type": "IMAGE"}], "extra": {"workflowRendererVersion": "LG"}, "category": "Text generation/Prompt enhance"}]}, "extra": {}}
|
||||
1
blueprints/Sharpen.json
Normal file
1
blueprints/Sharpen.json
Normal file
@@ -0,0 +1 @@
|
||||
{"revision": 0, "last_node_id": 25, "last_link_id": 0, "nodes": [{"id": 25, "type": "621ba4e2-22a8-482d-a369-023753198b7b", "pos": [4610, -790], "size": [230, 58], "flags": {}, "order": 4, "mode": 0, "inputs": [{"label": "image", "localized_name": "images.image0", "name": "images.image0", "type": "IMAGE", "link": null}], "outputs": [{"label": "IMAGE", "localized_name": "IMAGE0", "name": "IMAGE0", "type": "IMAGE", "links": []}], "title": "Sharpen", "properties": {"proxyWidgets": [["24", "value"]]}, "widgets_values": []}], "links": [], "version": 0.4, "definitions": {"subgraphs": [{"id": "621ba4e2-22a8-482d-a369-023753198b7b", "version": 1, "state": {"lastGroupId": 0, "lastNodeId": 24, "lastLinkId": 36, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "Sharpen", "inputNode": {"id": -10, "bounding": [4090, -825, 120, 60]}, "outputNode": {"id": -20, "bounding": [5150, -825, 120, 60]}, "inputs": [{"id": "37011fb7-14b7-4e0e-b1a0-6a02e8da1fd7", "name": "images.image0", "type": "IMAGE", "linkIds": [34], "localized_name": "images.image0", "label": "image", "pos": [4190, -805]}], "outputs": [{"id": "e9182b3f-635c-4cd4-a152-4b4be17ae4b9", "name": "IMAGE0", "type": "IMAGE", "linkIds": [35], "localized_name": "IMAGE0", "label": "IMAGE", "pos": [5170, -805]}], "widgets": [], "nodes": [{"id": 24, "type": "PrimitiveFloat", "pos": [4280, -1240], "size": [270, 58], "flags": {}, "order": 0, "mode": 0, "inputs": [{"label": "strength", "localized_name": "value", "name": "value", "type": "FLOAT", "widget": {"name": "value"}, "link": null}], "outputs": [{"localized_name": "FLOAT", "name": "FLOAT", "type": "FLOAT", "links": [36]}], "properties": {"Node name for S&R": "PrimitiveFloat", "min": 0, "max": 3, "precision": 2, "step": 0.05}, "widgets_values": [0.5]}, {"id": 23, "type": "GLSLShader", "pos": [4570, -1240], "size": [370, 192], "flags": {}, "order": 1, "mode": 0, "inputs": [{"label": "image0", "localized_name": "images.image0", "name": "images.image0", "type": "IMAGE", "link": 34}, {"label": "image1", "localized_name": "images.image1", "name": "images.image1", "shape": 7, "type": "IMAGE", "link": null}, {"label": "u_float0", "localized_name": "floats.u_float0", "name": "floats.u_float0", "shape": 7, "type": "FLOAT", "link": 36}, {"label": "u_float1", "localized_name": "floats.u_float1", "name": "floats.u_float1", "shape": 7, "type": "FLOAT", "link": null}, {"label": "u_int0", "localized_name": "ints.u_int0", "name": "ints.u_int0", "shape": 7, "type": "INT", "link": null}, {"localized_name": "fragment_shader", "name": "fragment_shader", "type": "STRING", "widget": {"name": "fragment_shader"}, "link": null}, {"localized_name": "size_mode", "name": "size_mode", "type": "COMFY_DYNAMICCOMBO_V3", "widget": {"name": "size_mode"}, "link": null}], "outputs": [{"localized_name": "IMAGE0", "name": "IMAGE0", "type": "IMAGE", "links": [35]}, {"localized_name": "IMAGE1", "name": "IMAGE1", "type": "IMAGE", "links": null}, {"localized_name": "IMAGE2", "name": "IMAGE2", "type": "IMAGE", "links": null}, {"localized_name": "IMAGE3", "name": "IMAGE3", "type": "IMAGE", "links": null}], "properties": {"Node name for S&R": "GLSLShader"}, "widgets_values": ["#version 300 es\nprecision highp float;\n\nuniform sampler2D u_image0;\nuniform vec2 u_resolution;\nuniform float u_float0; // strength [0.0 – 2.0] typical: 0.3–1.0\n\nin vec2 v_texCoord;\nlayout(location = 0) out vec4 fragColor0;\n\nvoid main() {\n vec2 texel = 1.0 / u_resolution;\n \n // Sample center and neighbors\n vec4 center = texture(u_image0, v_texCoord);\n vec4 top = texture(u_image0, v_texCoord + vec2( 0.0, -texel.y));\n vec4 bottom = texture(u_image0, v_texCoord + vec2( 0.0, texel.y));\n vec4 left = texture(u_image0, v_texCoord + vec2(-texel.x, 0.0));\n vec4 right = texture(u_image0, v_texCoord + vec2( texel.x, 0.0));\n \n // Edge enhancement (Laplacian)\n vec4 edges = center * 4.0 - top - bottom - left - right;\n \n // Add edges back scaled by strength\n vec4 sharpened = center + edges * u_float0;\n \n fragColor0 = vec4(clamp(sharpened.rgb, 0.0, 1.0), center.a);\n}", "from_input"]}], "groups": [], "links": [{"id": 36, "origin_id": 24, "origin_slot": 0, "target_id": 23, "target_slot": 2, "type": "FLOAT"}, {"id": 34, "origin_id": -10, "origin_slot": 0, "target_id": 23, "target_slot": 0, "type": "IMAGE"}, {"id": 35, "origin_id": 23, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "IMAGE"}], "extra": {"workflowRendererVersion": "LG"}, "category": "Image Tools/Sharpen"}]}}
|
||||
1
blueprints/Text to Audio (ACE-Step 1.5).json
Normal file
1
blueprints/Text to Audio (ACE-Step 1.5).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Text to Image (Z-Image-Turbo).json
Normal file
1
blueprints/Text to Image (Z-Image-Turbo).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Text to Video (Wan 2.2).json
Normal file
1
blueprints/Text to Video (Wan 2.2).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Unsharp Mask.json
Normal file
1
blueprints/Unsharp Mask.json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Video Captioning (Gemini).json
Normal file
1
blueprints/Video Captioning (Gemini).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Video Inpaint(Wan2.1 VACE).json
Normal file
1
blueprints/Video Inpaint(Wan2.1 VACE).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Video Stitch.json
Normal file
1
blueprints/Video Stitch.json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Video Upscale(GAN x4).json
Normal file
1
blueprints/Video Upscale(GAN x4).json
Normal file
@@ -0,0 +1 @@
|
||||
{"revision": 0, "last_node_id": 13, "last_link_id": 0, "nodes": [{"id": 13, "type": "cf95b747-3e17-46cb-8097-cac60ff9b2e1", "pos": [1120, 330], "size": [240, 58], "flags": {}, "order": 3, "mode": 0, "inputs": [{"localized_name": "video", "name": "video", "type": "VIDEO", "link": null}, {"name": "model_name", "type": "COMBO", "widget": {"name": "model_name"}, "link": null}], "outputs": [{"localized_name": "VIDEO", "name": "VIDEO", "type": "VIDEO", "links": []}], "title": "Video Upscale(GAN x4)", "properties": {"proxyWidgets": [["-1", "model_name"]], "cnr_id": "comfy-core", "ver": "0.14.1"}, "widgets_values": ["RealESRGAN_x4plus.safetensors"]}], "links": [], "version": 0.4, "definitions": {"subgraphs": [{"id": "cf95b747-3e17-46cb-8097-cac60ff9b2e1", "version": 1, "state": {"lastGroupId": 0, "lastNodeId": 13, "lastLinkId": 19, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "Video Upscale(GAN x4)", "inputNode": {"id": -10, "bounding": [550, 460, 120, 80]}, "outputNode": {"id": -20, "bounding": [1490, 460, 120, 60]}, "inputs": [{"id": "666d633e-93e7-42dc-8d11-2b7b99b0f2a6", "name": "video", "type": "VIDEO", "linkIds": [10], "localized_name": "video", "pos": [650, 480]}, {"id": "2e23a087-caa8-4d65-99e6-662761aa905a", "name": "model_name", "type": "COMBO", "linkIds": [19], "pos": [650, 500]}], "outputs": [{"id": "0c1768ea-3ec2-412f-9af6-8e0fa36dae70", "name": "VIDEO", "type": "VIDEO", "linkIds": [15], "localized_name": "VIDEO", "pos": [1510, 480]}], "widgets": [], "nodes": [{"id": 2, "type": "ImageUpscaleWithModel", "pos": [1110, 450], "size": [320, 46], "flags": {}, "order": 1, "mode": 0, "inputs": [{"localized_name": "upscale_model", "name": "upscale_model", "type": "UPSCALE_MODEL", "link": 1}, {"localized_name": "image", "name": "image", "type": "IMAGE", "link": 14}], "outputs": [{"localized_name": "IMAGE", "name": "IMAGE", "type": "IMAGE", "links": [13]}], "properties": {"cnr_id": "comfy-core", "ver": "0.10.0", "Node name for S&R": "ImageUpscaleWithModel"}}, {"id": 11, "type": "CreateVideo", "pos": [1110, 550], "size": [320, 78], "flags": {}, "order": 3, "mode": 0, "inputs": [{"localized_name": "images", "name": "images", "type": "IMAGE", "link": 13}, {"localized_name": "audio", "name": "audio", "shape": 7, "type": "AUDIO", "link": 16}, {"localized_name": "fps", "name": "fps", "type": "FLOAT", "widget": {"name": "fps"}, "link": 12}], "outputs": [{"localized_name": "VIDEO", "name": "VIDEO", "type": "VIDEO", "links": [15]}], "properties": {"cnr_id": "comfy-core", "ver": "0.10.0", "Node name for S&R": "CreateVideo"}, "widgets_values": [30]}, {"id": 10, "type": "GetVideoComponents", "pos": [1110, 330], "size": [320, 70], "flags": {}, "order": 2, "mode": 0, "inputs": [{"localized_name": "video", "name": "video", "type": "VIDEO", "link": 10}], "outputs": [{"localized_name": "images", "name": "images", "type": "IMAGE", "links": [14]}, {"localized_name": "audio", "name": "audio", "type": "AUDIO", "links": [16]}, {"localized_name": "fps", "name": "fps", "type": "FLOAT", "links": [12]}], "properties": {"cnr_id": "comfy-core", "ver": "0.10.0", "Node name for S&R": "GetVideoComponents"}}, {"id": 1, "type": "UpscaleModelLoader", "pos": [750, 450], "size": [280, 60], "flags": {}, "order": 0, "mode": 0, "inputs": [{"localized_name": "model_name", "name": "model_name", "type": "COMBO", "widget": {"name": "model_name"}, "link": 19}], "outputs": [{"localized_name": "UPSCALE_MODEL", "name": "UPSCALE_MODEL", "type": "UPSCALE_MODEL", "links": [1]}], "properties": {"cnr_id": "comfy-core", "ver": "0.10.0", "Node name for S&R": "UpscaleModelLoader", "models": [{"name": "RealESRGAN_x4plus.safetensors", "url": "https://huggingface.co/Comfy-Org/Real-ESRGAN_repackaged/resolve/main/RealESRGAN_x4plus.safetensors", "directory": "upscale_models"}]}, "widgets_values": ["RealESRGAN_x4plus.safetensors"]}], "groups": [], "links": [{"id": 1, "origin_id": 1, "origin_slot": 0, "target_id": 2, "target_slot": 0, "type": "UPSCALE_MODEL"}, {"id": 14, "origin_id": 10, "origin_slot": 0, "target_id": 2, "target_slot": 1, "type": "IMAGE"}, {"id": 13, "origin_id": 2, "origin_slot": 0, "target_id": 11, "target_slot": 0, "type": "IMAGE"}, {"id": 16, "origin_id": 10, "origin_slot": 1, "target_id": 11, "target_slot": 1, "type": "AUDIO"}, {"id": 12, "origin_id": 10, "origin_slot": 2, "target_id": 11, "target_slot": 2, "type": "FLOAT"}, {"id": 10, "origin_id": -10, "origin_slot": 0, "target_id": 10, "target_slot": 0, "type": "VIDEO"}, {"id": 15, "origin_id": 11, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "VIDEO"}, {"id": 19, "origin_id": -10, "origin_slot": 1, "target_id": 1, "target_slot": 0, "type": "COMBO"}], "extra": {"workflowRendererVersion": "LG"}, "category": "Video generation and editing/Enhance video"}]}, "extra": {}}
|
||||
@@ -25,11 +25,11 @@ class AudioEncoderModel():
|
||||
elif model_type == "whisper3":
|
||||
self.model = WhisperLargeV3(**model_config)
|
||||
self.model.eval()
|
||||
self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
|
||||
self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
|
||||
self.model_sample_rate = 16000
|
||||
|
||||
def load_sd(self, sd):
|
||||
return self.model.load_state_dict(sd, strict=False)
|
||||
return self.model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic())
|
||||
|
||||
def get_sd(self):
|
||||
return self.model.state_dict()
|
||||
|
||||
@@ -1,13 +0,0 @@
|
||||
import pickle
|
||||
|
||||
load = pickle.load
|
||||
|
||||
class Empty:
|
||||
pass
|
||||
|
||||
class Unpickler(pickle.Unpickler):
|
||||
def find_class(self, module, name):
|
||||
#TODO: safe unpickle
|
||||
if module.startswith("pytorch_lightning"):
|
||||
return Empty
|
||||
return super().find_class(module, name)
|
||||
@@ -159,6 +159,7 @@ class PerformanceFeature(enum.Enum):
|
||||
Fp8MatrixMultiplication = "fp8_matrix_mult"
|
||||
CublasOps = "cublas_ops"
|
||||
AutoTune = "autotune"
|
||||
DynamicVRAM = "dynamic_vram"
|
||||
|
||||
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. This is used to test new features so using it might crash your comfyui. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature))))
|
||||
|
||||
@@ -257,3 +258,6 @@ elif args.fast == []:
|
||||
# '--fast' is provided with a list of performance features, use that list
|
||||
else:
|
||||
args.fast = set(args.fast)
|
||||
|
||||
def enables_dynamic_vram():
|
||||
return PerformanceFeature.DynamicVRAM in args.fast and not args.highvram and not args.gpu_only
|
||||
|
||||
@@ -47,10 +47,10 @@ class ClipVisionModel():
|
||||
self.model = model_class(config, self.dtype, offload_device, comfy.ops.manual_cast)
|
||||
self.model.eval()
|
||||
|
||||
self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
|
||||
self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
|
||||
|
||||
def load_sd(self, sd):
|
||||
return self.model.load_state_dict(sd, strict=False)
|
||||
return self.model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic())
|
||||
|
||||
def get_sd(self):
|
||||
return self.model.state_dict()
|
||||
|
||||
@@ -176,6 +176,8 @@ class InputTypeOptions(TypedDict):
|
||||
"""COMBO type only. Specifies the configuration for a multi-select widget.
|
||||
Available after ComfyUI frontend v1.13.4
|
||||
https://github.com/Comfy-Org/ComfyUI_frontend/pull/2987"""
|
||||
gradient_stops: NotRequired[list[list[float]]]
|
||||
"""Gradient color stops for gradientslider display mode. Each stop is [offset, r, g, b] (``FLOAT``)."""
|
||||
|
||||
|
||||
class HiddenInputTypeDict(TypedDict):
|
||||
@@ -236,6 +238,8 @@ class ComfyNodeABC(ABC):
|
||||
"""Flags a node as experimental, informing users that it may change or not work as expected."""
|
||||
DEPRECATED: bool
|
||||
"""Flags a node as deprecated, indicating to users that they should find alternatives to this node."""
|
||||
DEV_ONLY: bool
|
||||
"""Flags a node as dev-only, hiding it from search/menus unless dev mode is enabled."""
|
||||
API_NODE: Optional[bool]
|
||||
"""Flags a node as an API node. See: https://docs.comfy.org/tutorials/api-nodes/overview."""
|
||||
|
||||
|
||||
@@ -4,6 +4,25 @@ import comfy.utils
|
||||
import logging
|
||||
|
||||
|
||||
def is_equal(x, y):
|
||||
if torch.is_tensor(x) and torch.is_tensor(y):
|
||||
return torch.equal(x, y)
|
||||
elif isinstance(x, dict) and isinstance(y, dict):
|
||||
if x.keys() != y.keys():
|
||||
return False
|
||||
return all(is_equal(x[k], y[k]) for k in x)
|
||||
elif isinstance(x, (list, tuple)) and isinstance(y, (list, tuple)):
|
||||
if type(x) is not type(y) or len(x) != len(y):
|
||||
return False
|
||||
return all(is_equal(a, b) for a, b in zip(x, y))
|
||||
else:
|
||||
try:
|
||||
return x == y
|
||||
except Exception:
|
||||
logging.warning("comparison issue with COND")
|
||||
return False
|
||||
|
||||
|
||||
class CONDRegular:
|
||||
def __init__(self, cond):
|
||||
self.cond = cond
|
||||
@@ -84,7 +103,7 @@ class CONDConstant(CONDRegular):
|
||||
return self._copy_with(self.cond)
|
||||
|
||||
def can_concat(self, other):
|
||||
if self.cond != other.cond:
|
||||
if not is_equal(self.cond, other.cond):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
@@ -203,7 +203,7 @@ class ControlNet(ControlBase):
|
||||
self.control_model = control_model
|
||||
self.load_device = load_device
|
||||
if control_model is not None:
|
||||
self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device())
|
||||
self.control_model_wrapped = comfy.model_patcher.CoreModelPatcher(self.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device())
|
||||
|
||||
self.compression_ratio = compression_ratio
|
||||
self.global_average_pooling = global_average_pooling
|
||||
@@ -297,6 +297,30 @@ class ControlNet(ControlBase):
|
||||
self.model_sampling_current = None
|
||||
super().cleanup()
|
||||
|
||||
|
||||
class QwenFunControlNet(ControlNet):
|
||||
def get_control(self, x_noisy, t, cond, batched_number, transformer_options):
|
||||
# Fun checkpoints are more sensitive to high strengths in the generic
|
||||
# ControlNet merge path. Use a soft response curve so strength=1.0 stays
|
||||
# unchanged while >1 grows more gently.
|
||||
original_strength = self.strength
|
||||
self.strength = math.sqrt(max(self.strength, 0.0))
|
||||
try:
|
||||
return super().get_control(x_noisy, t, cond, batched_number, transformer_options)
|
||||
finally:
|
||||
self.strength = original_strength
|
||||
|
||||
def pre_run(self, model, percent_to_timestep_function):
|
||||
super().pre_run(model, percent_to_timestep_function)
|
||||
self.set_extra_arg("base_model", model.diffusion_model)
|
||||
|
||||
def copy(self):
|
||||
c = QwenFunControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
|
||||
c.control_model = self.control_model
|
||||
c.control_model_wrapped = self.control_model_wrapped
|
||||
self.copy_to(c)
|
||||
return c
|
||||
|
||||
class ControlLoraOps:
|
||||
class Linear(torch.nn.Module, comfy.ops.CastWeightBiasOp):
|
||||
def __init__(self, in_features: int, out_features: int, bias: bool = True,
|
||||
@@ -560,6 +584,7 @@ def load_controlnet_hunyuandit(controlnet_data, model_options={}):
|
||||
def load_controlnet_flux_xlabs_mistoline(sd, mistoline=False, model_options={}):
|
||||
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd, model_options=model_options)
|
||||
control_model = comfy.ldm.flux.controlnet.ControlNetFlux(mistoline=mistoline, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
|
||||
sd = model_config.process_unet_state_dict(sd)
|
||||
control_model = controlnet_load_state_dict(control_model, sd)
|
||||
extra_conds = ['y', 'guidance']
|
||||
control = ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
|
||||
@@ -605,6 +630,53 @@ def load_controlnet_qwen_instantx(sd, model_options={}):
|
||||
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
|
||||
return control
|
||||
|
||||
|
||||
def load_controlnet_qwen_fun(sd, model_options={}):
|
||||
load_device = comfy.model_management.get_torch_device()
|
||||
weight_dtype = comfy.utils.weight_dtype(sd)
|
||||
unet_dtype = model_options.get("dtype", weight_dtype)
|
||||
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
|
||||
|
||||
operations = model_options.get("custom_operations", None)
|
||||
if operations is None:
|
||||
operations = comfy.ops.pick_operations(unet_dtype, manual_cast_dtype, disable_fast_fp8=True)
|
||||
|
||||
in_features = sd["control_img_in.weight"].shape[1]
|
||||
inner_dim = sd["control_img_in.weight"].shape[0]
|
||||
|
||||
block_weight = sd["control_blocks.0.attn.to_q.weight"]
|
||||
attention_head_dim = sd["control_blocks.0.attn.norm_q.weight"].shape[0]
|
||||
num_attention_heads = max(1, block_weight.shape[0] // max(1, attention_head_dim))
|
||||
|
||||
model = comfy.ldm.qwen_image.controlnet.QwenImageFunControlNetModel(
|
||||
control_in_features=in_features,
|
||||
inner_dim=inner_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
num_control_blocks=5,
|
||||
main_model_double=60,
|
||||
injection_layers=(0, 12, 24, 36, 48),
|
||||
operations=operations,
|
||||
device=comfy.model_management.unet_offload_device(),
|
||||
dtype=unet_dtype,
|
||||
)
|
||||
model = controlnet_load_state_dict(model, sd)
|
||||
|
||||
latent_format = comfy.latent_formats.Wan21()
|
||||
control = QwenFunControlNet(
|
||||
model,
|
||||
compression_ratio=1,
|
||||
latent_format=latent_format,
|
||||
# Fun checkpoints already expect their own 33-channel context handling.
|
||||
# Enabling generic concat_mask injects an extra mask channel at apply-time
|
||||
# and breaks the intended fallback packing path.
|
||||
concat_mask=False,
|
||||
load_device=load_device,
|
||||
manual_cast_dtype=manual_cast_dtype,
|
||||
extra_conds=[],
|
||||
)
|
||||
return control
|
||||
|
||||
def convert_mistoline(sd):
|
||||
return comfy.utils.state_dict_prefix_replace(sd, {"single_controlnet_blocks.": "controlnet_single_blocks."})
|
||||
|
||||
@@ -682,6 +754,8 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}):
|
||||
return load_controlnet_qwen_instantx(controlnet_data, model_options=model_options)
|
||||
elif "controlnet_x_embedder.weight" in controlnet_data:
|
||||
return load_controlnet_flux_instantx(controlnet_data, model_options=model_options)
|
||||
elif "control_blocks.0.after_proj.weight" in controlnet_data and "control_img_in.weight" in controlnet_data:
|
||||
return load_controlnet_qwen_fun(controlnet_data, model_options=model_options)
|
||||
|
||||
elif "controlnet_blocks.0.linear.weight" in controlnet_data: #mistoline flux
|
||||
return load_controlnet_flux_xlabs_mistoline(convert_mistoline(controlnet_data), mistoline=True, model_options=model_options)
|
||||
|
||||
@@ -5,7 +5,7 @@ from scipy import integrate
|
||||
import torch
|
||||
from torch import nn
|
||||
import torchsde
|
||||
from tqdm.auto import trange, tqdm
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from . import utils
|
||||
from . import deis
|
||||
@@ -13,6 +13,9 @@ from . import sa_solver
|
||||
import comfy.model_patcher
|
||||
import comfy.model_sampling
|
||||
|
||||
import comfy.memory_management
|
||||
from comfy.utils import model_trange as trange
|
||||
|
||||
def append_zero(x):
|
||||
return torch.cat([x, x.new_zeros([1])])
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ class LatentFormat:
|
||||
latent_rgb_factors_bias = None
|
||||
latent_rgb_factors_reshape = None
|
||||
taesd_decoder_name = None
|
||||
spacial_downscale_ratio = 8
|
||||
|
||||
def process_in(self, latent):
|
||||
return latent * self.scale_factor
|
||||
@@ -80,6 +81,7 @@ class SD_X4(LatentFormat):
|
||||
|
||||
class SC_Prior(LatentFormat):
|
||||
latent_channels = 16
|
||||
spacial_downscale_ratio = 42
|
||||
def __init__(self):
|
||||
self.scale_factor = 1.0
|
||||
self.latent_rgb_factors = [
|
||||
@@ -102,6 +104,7 @@ class SC_Prior(LatentFormat):
|
||||
]
|
||||
|
||||
class SC_B(LatentFormat):
|
||||
spacial_downscale_ratio = 4
|
||||
def __init__(self):
|
||||
self.scale_factor = 1.0 / 0.43
|
||||
self.latent_rgb_factors = [
|
||||
@@ -181,6 +184,7 @@ class Flux(SD3):
|
||||
|
||||
class Flux2(LatentFormat):
|
||||
latent_channels = 128
|
||||
spacial_downscale_ratio = 16
|
||||
|
||||
def __init__(self):
|
||||
self.latent_rgb_factors =[
|
||||
@@ -272,6 +276,7 @@ class Mochi(LatentFormat):
|
||||
class LTXV(LatentFormat):
|
||||
latent_channels = 128
|
||||
latent_dimensions = 3
|
||||
spacial_downscale_ratio = 32
|
||||
|
||||
def __init__(self):
|
||||
self.latent_rgb_factors = [
|
||||
@@ -515,6 +520,7 @@ class Wan21(LatentFormat):
|
||||
class Wan22(Wan21):
|
||||
latent_channels = 48
|
||||
latent_dimensions = 3
|
||||
spacial_downscale_ratio = 16
|
||||
|
||||
latent_rgb_factors = [
|
||||
[ 0.0119, 0.0103, 0.0046],
|
||||
@@ -592,6 +598,7 @@ class Wan22(Wan21):
|
||||
class HunyuanImage21(LatentFormat):
|
||||
latent_channels = 64
|
||||
latent_dimensions = 2
|
||||
spacial_downscale_ratio = 32
|
||||
scale_factor = 0.75289
|
||||
|
||||
latent_rgb_factors = [
|
||||
@@ -725,6 +732,7 @@ class HunyuanVideo15(LatentFormat):
|
||||
latent_rgb_factors_bias = [ 0.0456, -0.0202, -0.0644]
|
||||
latent_channels = 32
|
||||
latent_dimensions = 3
|
||||
spacial_downscale_ratio = 16
|
||||
scale_factor = 1.03682
|
||||
taesd_decoder_name = "lighttaehy1_5"
|
||||
|
||||
@@ -747,8 +755,13 @@ class ACEAudio(LatentFormat):
|
||||
latent_channels = 8
|
||||
latent_dimensions = 2
|
||||
|
||||
class ACEAudio15(LatentFormat):
|
||||
latent_channels = 64
|
||||
latent_dimensions = 1
|
||||
|
||||
class ChromaRadiance(LatentFormat):
|
||||
latent_channels = 3
|
||||
spacial_downscale_ratio = 1
|
||||
|
||||
def __init__(self):
|
||||
self.latent_rgb_factors = [
|
||||
|
||||
1155
comfy/ldm/ace/ace_step15.py
Normal file
1155
comfy/ldm/ace/ace_step15.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -179,8 +179,8 @@ class LLMAdapter(nn.Module):
|
||||
if source_attention_mask.ndim == 2:
|
||||
source_attention_mask = source_attention_mask.unsqueeze(1).unsqueeze(1)
|
||||
|
||||
x = self.in_proj(self.embed(target_input_ids))
|
||||
context = source_hidden_states
|
||||
x = self.in_proj(self.embed(target_input_ids, out_dtype=context.dtype))
|
||||
position_ids = torch.arange(x.shape[1], device=x.device).unsqueeze(0)
|
||||
position_ids_context = torch.arange(context.shape[1], device=x.device).unsqueeze(0)
|
||||
position_embeddings = self.rotary_emb(x, position_ids)
|
||||
@@ -195,8 +195,20 @@ class Anima(MiniTrainDIT):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.llm_adapter = LLMAdapter(device=kwargs.get("device"), dtype=kwargs.get("dtype"), operations=kwargs.get("operations"))
|
||||
|
||||
def preprocess_text_embeds(self, text_embeds, text_ids):
|
||||
def preprocess_text_embeds(self, text_embeds, text_ids, t5xxl_weights=None):
|
||||
if text_ids is not None:
|
||||
return self.llm_adapter(text_embeds, text_ids)
|
||||
out = self.llm_adapter(text_embeds, text_ids)
|
||||
if t5xxl_weights is not None:
|
||||
out = out * t5xxl_weights
|
||||
|
||||
if out.shape[1] < 512:
|
||||
out = torch.nn.functional.pad(out, (0, 0, 0, 512 - out.shape[1]))
|
||||
return out
|
||||
else:
|
||||
return text_embeds
|
||||
|
||||
def forward(self, x, timesteps, context, **kwargs):
|
||||
t5xxl_ids = kwargs.pop("t5xxl_ids", None)
|
||||
if t5xxl_ids is not None:
|
||||
context = self.preprocess_text_embeds(context, t5xxl_ids, t5xxl_weights=kwargs.pop("t5xxl_weights", None))
|
||||
return super().forward(x, timesteps, context, **kwargs)
|
||||
|
||||
@@ -3,7 +3,6 @@ from torch import Tensor, nn
|
||||
|
||||
from comfy.ldm.flux.layers import (
|
||||
MLPEmbedder,
|
||||
RMSNorm,
|
||||
ModulationOut,
|
||||
)
|
||||
|
||||
@@ -29,7 +28,7 @@ class Approximator(nn.Module):
|
||||
super().__init__()
|
||||
self.in_proj = operations.Linear(in_dim, hidden_dim, bias=True, dtype=dtype, device=device)
|
||||
self.layers = nn.ModuleList([MLPEmbedder(hidden_dim, hidden_dim, dtype=dtype, device=device, operations=operations) for x in range( n_layers)])
|
||||
self.norms = nn.ModuleList([RMSNorm(hidden_dim, dtype=dtype, device=device, operations=operations) for x in range( n_layers)])
|
||||
self.norms = nn.ModuleList([operations.RMSNorm(hidden_dim, dtype=dtype, device=device) for x in range( n_layers)])
|
||||
self.out_proj = operations.Linear(hidden_dim, out_dim, dtype=dtype, device=device)
|
||||
|
||||
@property
|
||||
|
||||
@@ -152,6 +152,7 @@ class Chroma(nn.Module):
|
||||
transformer_options={},
|
||||
attn_mask: Tensor = None,
|
||||
) -> Tensor:
|
||||
transformer_options = transformer_options.copy()
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
|
||||
# running on sequences img
|
||||
@@ -228,6 +229,7 @@ class Chroma(nn.Module):
|
||||
|
||||
transformer_options["total_blocks"] = len(self.single_blocks)
|
||||
transformer_options["block_type"] = "single"
|
||||
transformer_options["img_slice"] = [txt.shape[1], img.shape[1]]
|
||||
for i, block in enumerate(self.single_blocks):
|
||||
transformer_options["block_index"] = i
|
||||
if i not in self.skip_dit:
|
||||
|
||||
@@ -4,8 +4,6 @@ from functools import lru_cache
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from comfy.ldm.flux.layers import RMSNorm
|
||||
|
||||
|
||||
class NerfEmbedder(nn.Module):
|
||||
"""
|
||||
@@ -145,7 +143,7 @@ class NerfGLUBlock(nn.Module):
|
||||
# We now need to generate parameters for 3 matrices.
|
||||
total_params = 3 * hidden_size_x**2 * mlp_ratio
|
||||
self.param_generator = operations.Linear(hidden_size_s, total_params, dtype=dtype, device=device)
|
||||
self.norm = RMSNorm(hidden_size_x, dtype=dtype, device=device, operations=operations)
|
||||
self.norm = operations.RMSNorm(hidden_size_x, dtype=dtype, device=device)
|
||||
self.mlp_ratio = mlp_ratio
|
||||
|
||||
|
||||
@@ -178,7 +176,7 @@ class NerfGLUBlock(nn.Module):
|
||||
class NerfFinalLayer(nn.Module):
|
||||
def __init__(self, hidden_size, out_channels, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations)
|
||||
self.norm = operations.RMSNorm(hidden_size, dtype=dtype, device=device)
|
||||
self.linear = operations.Linear(hidden_size, out_channels, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
@@ -190,7 +188,7 @@ class NerfFinalLayer(nn.Module):
|
||||
class NerfFinalLayerConv(nn.Module):
|
||||
def __init__(self, hidden_size: int, out_channels: int, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations)
|
||||
self.norm = operations.RMSNorm(hidden_size, dtype=dtype, device=device)
|
||||
self.conv = operations.Conv2d(
|
||||
in_channels=hidden_size,
|
||||
out_channels=out_channels,
|
||||
|
||||
@@ -13,6 +13,7 @@ from torchvision import transforms
|
||||
|
||||
import comfy.patcher_extension
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
import comfy.ldm.common_dit
|
||||
|
||||
def apply_rotary_pos_emb(
|
||||
t: torch.Tensor,
|
||||
@@ -334,7 +335,7 @@ class FinalLayer(nn.Module):
|
||||
device=None, dtype=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
self.layer_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.layer_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.linear = operations.Linear(
|
||||
hidden_size, spatial_patch_size * spatial_patch_size * temporal_patch_size * out_channels, bias=False, device=device, dtype=dtype
|
||||
)
|
||||
@@ -462,6 +463,8 @@ class Block(nn.Module):
|
||||
extra_per_block_pos_emb: Optional[torch.Tensor] = None,
|
||||
transformer_options: Optional[dict] = {},
|
||||
) -> torch.Tensor:
|
||||
residual_dtype = x_B_T_H_W_D.dtype
|
||||
compute_dtype = emb_B_T_D.dtype
|
||||
if extra_per_block_pos_emb is not None:
|
||||
x_B_T_H_W_D = x_B_T_H_W_D + extra_per_block_pos_emb
|
||||
|
||||
@@ -511,7 +514,7 @@ class Block(nn.Module):
|
||||
result_B_T_H_W_D = rearrange(
|
||||
self.self_attn(
|
||||
# normalized_x_B_T_HW_D,
|
||||
rearrange(normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"),
|
||||
rearrange(normalized_x_B_T_H_W_D.to(compute_dtype), "b t h w d -> b (t h w) d"),
|
||||
None,
|
||||
rope_emb=rope_emb_L_1_1_D,
|
||||
transformer_options=transformer_options,
|
||||
@@ -521,7 +524,7 @@ class Block(nn.Module):
|
||||
h=H,
|
||||
w=W,
|
||||
)
|
||||
x_B_T_H_W_D = x_B_T_H_W_D + gate_self_attn_B_T_1_1_D * result_B_T_H_W_D
|
||||
x_B_T_H_W_D = x_B_T_H_W_D + gate_self_attn_B_T_1_1_D.to(residual_dtype) * result_B_T_H_W_D.to(residual_dtype)
|
||||
|
||||
def _x_fn(
|
||||
_x_B_T_H_W_D: torch.Tensor,
|
||||
@@ -535,7 +538,7 @@ class Block(nn.Module):
|
||||
)
|
||||
_result_B_T_H_W_D = rearrange(
|
||||
self.cross_attn(
|
||||
rearrange(_normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"),
|
||||
rearrange(_normalized_x_B_T_H_W_D.to(compute_dtype), "b t h w d -> b (t h w) d"),
|
||||
crossattn_emb,
|
||||
rope_emb=rope_emb_L_1_1_D,
|
||||
transformer_options=transformer_options,
|
||||
@@ -554,7 +557,7 @@ class Block(nn.Module):
|
||||
shift_cross_attn_B_T_1_1_D,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
x_B_T_H_W_D = result_B_T_H_W_D * gate_cross_attn_B_T_1_1_D + x_B_T_H_W_D
|
||||
x_B_T_H_W_D = result_B_T_H_W_D.to(residual_dtype) * gate_cross_attn_B_T_1_1_D.to(residual_dtype) + x_B_T_H_W_D
|
||||
|
||||
normalized_x_B_T_H_W_D = _fn(
|
||||
x_B_T_H_W_D,
|
||||
@@ -562,8 +565,8 @@ class Block(nn.Module):
|
||||
scale_mlp_B_T_1_1_D,
|
||||
shift_mlp_B_T_1_1_D,
|
||||
)
|
||||
result_B_T_H_W_D = self.mlp(normalized_x_B_T_H_W_D)
|
||||
x_B_T_H_W_D = x_B_T_H_W_D + gate_mlp_B_T_1_1_D * result_B_T_H_W_D
|
||||
result_B_T_H_W_D = self.mlp(normalized_x_B_T_H_W_D.to(compute_dtype))
|
||||
x_B_T_H_W_D = x_B_T_H_W_D + gate_mlp_B_T_1_1_D.to(residual_dtype) * result_B_T_H_W_D.to(residual_dtype)
|
||||
return x_B_T_H_W_D
|
||||
|
||||
|
||||
@@ -835,6 +838,8 @@ class MiniTrainDIT(nn.Module):
|
||||
padding_mask: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
):
|
||||
orig_shape = list(x.shape)
|
||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_temporal, self.patch_spatial, self.patch_spatial))
|
||||
x_B_C_T_H_W = x
|
||||
timesteps_B_T = timesteps
|
||||
crossattn_emb = context
|
||||
@@ -873,6 +878,14 @@ class MiniTrainDIT(nn.Module):
|
||||
"extra_per_block_pos_emb": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D,
|
||||
"transformer_options": kwargs.get("transformer_options", {}),
|
||||
}
|
||||
|
||||
# The residual stream for this model has large values. To make fp16 compute_dtype work, we keep the residual stream
|
||||
# in fp32, but run attention and MLP modules in fp16.
|
||||
# An alternate method that clamps fp16 values "works" in the sense that it makes coherent images, but there is noticeable
|
||||
# quality degradation and visual artifacts.
|
||||
if x_B_T_H_W_D.dtype == torch.float16:
|
||||
x_B_T_H_W_D = x_B_T_H_W_D.float()
|
||||
|
||||
for block in self.blocks:
|
||||
x_B_T_H_W_D = block(
|
||||
x_B_T_H_W_D,
|
||||
@@ -881,6 +894,6 @@ class MiniTrainDIT(nn.Module):
|
||||
**block_kwargs,
|
||||
)
|
||||
|
||||
x_B_T_H_W_O = self.final_layer(x_B_T_H_W_D, t_embedding_B_T_D, adaln_lora_B_T_3D=adaln_lora_B_T_3D)
|
||||
x_B_C_Tt_Hp_Wp = self.unpatchify(x_B_T_H_W_O)
|
||||
x_B_T_H_W_O = self.final_layer(x_B_T_H_W_D.to(crossattn_emb.dtype), t_embedding_B_T_D, adaln_lora_B_T_3D=adaln_lora_B_T_3D)
|
||||
x_B_C_Tt_Hp_Wp = self.unpatchify(x_B_T_H_W_O)[:, :, :orig_shape[-3], :orig_shape[-2], :orig_shape[-1]]
|
||||
return x_B_C_Tt_Hp_Wp
|
||||
|
||||
@@ -5,9 +5,9 @@ import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from .math import attention, rope
|
||||
import comfy.ops
|
||||
import comfy.ldm.common_dit
|
||||
|
||||
# Fix import for some custom nodes, TODO: delete eventually.
|
||||
RMSNorm = None
|
||||
|
||||
class EmbedND(nn.Module):
|
||||
def __init__(self, dim: int, theta: int, axes_dim: list):
|
||||
@@ -87,20 +87,12 @@ def build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=False, yak_mlp=False, dt
|
||||
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
|
||||
class RMSNorm(torch.nn.Module):
|
||||
def __init__(self, dim: int, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.scale = nn.Parameter(torch.empty((dim), dtype=dtype, device=device))
|
||||
|
||||
def forward(self, x: Tensor):
|
||||
return comfy.ldm.common_dit.rms_norm(x, self.scale, 1e-6)
|
||||
|
||||
|
||||
class QKNorm(torch.nn.Module):
|
||||
def __init__(self, dim: int, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.query_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations)
|
||||
self.key_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations)
|
||||
self.query_norm = operations.RMSNorm(dim, dtype=dtype, device=device)
|
||||
self.key_norm = operations.RMSNorm(dim, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple:
|
||||
q = self.query_norm(q)
|
||||
@@ -169,7 +161,7 @@ class SiLUActivation(nn.Module):
|
||||
|
||||
|
||||
class DoubleStreamBlock(nn.Module):
|
||||
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, mlp_silu_act=False, proj_bias=True, yak_mlp=False, dtype=None, device=None, operations=None):
|
||||
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, modulation=True, mlp_silu_act=False, proj_bias=True, yak_mlp=False, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
|
||||
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
@@ -197,8 +189,6 @@ class DoubleStreamBlock(nn.Module):
|
||||
|
||||
self.txt_mlp = build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=mlp_silu_act, yak_mlp=yak_mlp, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.flipped_img_txt = flipped_img_txt
|
||||
|
||||
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None, transformer_options={}):
|
||||
if self.modulation:
|
||||
img_mod1, img_mod2 = self.img_mod(vec)
|
||||
@@ -206,6 +196,9 @@ class DoubleStreamBlock(nn.Module):
|
||||
else:
|
||||
(img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec
|
||||
|
||||
transformer_patches = transformer_options.get("patches", {})
|
||||
extra_options = transformer_options.copy()
|
||||
|
||||
# prepare image for attention
|
||||
img_modulated = self.img_norm1(img)
|
||||
img_modulated = apply_mod(img_modulated, (1 + img_mod1.scale), img_mod1.shift, modulation_dims_img)
|
||||
@@ -224,32 +217,23 @@ class DoubleStreamBlock(nn.Module):
|
||||
del txt_qkv
|
||||
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
|
||||
|
||||
if self.flipped_img_txt:
|
||||
q = torch.cat((img_q, txt_q), dim=2)
|
||||
del img_q, txt_q
|
||||
k = torch.cat((img_k, txt_k), dim=2)
|
||||
del img_k, txt_k
|
||||
v = torch.cat((img_v, txt_v), dim=2)
|
||||
del img_v, txt_v
|
||||
# run actual attention
|
||||
attn = attention(q, k, v,
|
||||
pe=pe, mask=attn_mask, transformer_options=transformer_options)
|
||||
del q, k, v
|
||||
q = torch.cat((txt_q, img_q), dim=2)
|
||||
del txt_q, img_q
|
||||
k = torch.cat((txt_k, img_k), dim=2)
|
||||
del txt_k, img_k
|
||||
v = torch.cat((txt_v, img_v), dim=2)
|
||||
del txt_v, img_v
|
||||
# run actual attention
|
||||
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
|
||||
del q, k, v
|
||||
|
||||
img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1]:]
|
||||
else:
|
||||
q = torch.cat((txt_q, img_q), dim=2)
|
||||
del txt_q, img_q
|
||||
k = torch.cat((txt_k, img_k), dim=2)
|
||||
del txt_k, img_k
|
||||
v = torch.cat((txt_v, img_v), dim=2)
|
||||
del txt_v, img_v
|
||||
# run actual attention
|
||||
attn = attention(q, k, v,
|
||||
pe=pe, mask=attn_mask, transformer_options=transformer_options)
|
||||
del q, k, v
|
||||
if "attn1_output_patch" in transformer_patches:
|
||||
extra_options["img_slice"] = [txt.shape[1], attn.shape[1]]
|
||||
patch = transformer_patches["attn1_output_patch"]
|
||||
for p in patch:
|
||||
attn = p(attn, extra_options)
|
||||
|
||||
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
|
||||
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
|
||||
|
||||
# calculate the img bloks
|
||||
img += apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img)
|
||||
@@ -328,6 +312,9 @@ class SingleStreamBlock(nn.Module):
|
||||
else:
|
||||
mod = vec
|
||||
|
||||
transformer_patches = transformer_options.get("patches", {})
|
||||
extra_options = transformer_options.copy()
|
||||
|
||||
qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim_first], dim=-1)
|
||||
|
||||
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
@@ -337,6 +324,12 @@ class SingleStreamBlock(nn.Module):
|
||||
# compute attention
|
||||
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
|
||||
del q, k, v
|
||||
|
||||
if "attn1_output_patch" in transformer_patches:
|
||||
patch = transformer_patches["attn1_output_patch"]
|
||||
for p in patch:
|
||||
attn = p(attn, extra_options)
|
||||
|
||||
# compute activation in mlp stream, cat again and run second linear layer
|
||||
if self.yak_mlp:
|
||||
mlp = self.mlp_act(mlp[..., self.mlp_hidden_dim_first // 2:]) * mlp[..., :self.mlp_hidden_dim_first // 2]
|
||||
|
||||
@@ -29,19 +29,34 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
|
||||
return out.to(dtype=torch.float32, device=pos.device)
|
||||
|
||||
|
||||
def _apply_rope1(x: Tensor, freqs_cis: Tensor):
|
||||
x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
|
||||
|
||||
x_out = freqs_cis[..., 0] * x_[..., 0]
|
||||
x_out.addcmul_(freqs_cis[..., 1], x_[..., 1])
|
||||
|
||||
return x_out.reshape(*x.shape).type_as(x)
|
||||
|
||||
|
||||
def _apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
|
||||
return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)
|
||||
|
||||
|
||||
try:
|
||||
import comfy.quant_ops
|
||||
apply_rope = comfy.quant_ops.ck.apply_rope
|
||||
apply_rope1 = comfy.quant_ops.ck.apply_rope1
|
||||
q_apply_rope = comfy.quant_ops.ck.apply_rope
|
||||
q_apply_rope1 = comfy.quant_ops.ck.apply_rope1
|
||||
def apply_rope(xq, xk, freqs_cis):
|
||||
if comfy.model_management.in_training:
|
||||
return _apply_rope(xq, xk, freqs_cis)
|
||||
else:
|
||||
return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)
|
||||
def apply_rope1(x, freqs_cis):
|
||||
if comfy.model_management.in_training:
|
||||
return _apply_rope1(x, freqs_cis)
|
||||
else:
|
||||
return q_apply_rope1(x, freqs_cis)
|
||||
except:
|
||||
logging.warning("No comfy kitchen, using old apply_rope functions.")
|
||||
def apply_rope1(x: Tensor, freqs_cis: Tensor):
|
||||
x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
|
||||
|
||||
x_out = freqs_cis[..., 0] * x_[..., 0]
|
||||
x_out.addcmul_(freqs_cis[..., 1], x_[..., 1])
|
||||
|
||||
return x_out.reshape(*x.shape).type_as(x)
|
||||
|
||||
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
|
||||
return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)
|
||||
apply_rope = _apply_rope
|
||||
apply_rope1 = _apply_rope1
|
||||
|
||||
@@ -16,7 +16,6 @@ from .layers import (
|
||||
SingleStreamBlock,
|
||||
timestep_embedding,
|
||||
Modulation,
|
||||
RMSNorm
|
||||
)
|
||||
|
||||
@dataclass
|
||||
@@ -81,7 +80,7 @@ class Flux(nn.Module):
|
||||
self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device)
|
||||
|
||||
if params.txt_norm:
|
||||
self.txt_norm = RMSNorm(params.context_in_dim, dtype=dtype, device=device, operations=operations)
|
||||
self.txt_norm = operations.RMSNorm(params.context_in_dim, dtype=dtype, device=device)
|
||||
else:
|
||||
self.txt_norm = None
|
||||
|
||||
@@ -143,6 +142,7 @@ class Flux(nn.Module):
|
||||
attn_mask: Tensor = None,
|
||||
) -> Tensor:
|
||||
|
||||
transformer_options = transformer_options.copy()
|
||||
patches = transformer_options.get("patches", {})
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
if img.ndim != 3 or txt.ndim != 3:
|
||||
@@ -232,6 +232,7 @@ class Flux(nn.Module):
|
||||
|
||||
transformer_options["total_blocks"] = len(self.single_blocks)
|
||||
transformer_options["block_type"] = "single"
|
||||
transformer_options["img_slice"] = [txt.shape[1], img.shape[1]]
|
||||
for i, block in enumerate(self.single_blocks):
|
||||
transformer_options["block_index"] = i
|
||||
if ("single_block", i) in blocks_replace:
|
||||
|
||||
@@ -241,7 +241,6 @@ class HunyuanVideo(nn.Module):
|
||||
self.num_heads,
|
||||
mlp_ratio=params.mlp_ratio,
|
||||
qkv_bias=params.qkv_bias,
|
||||
flipped_img_txt=True,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
for _ in range(params.depth)
|
||||
@@ -305,6 +304,7 @@ class HunyuanVideo(nn.Module):
|
||||
control=None,
|
||||
transformer_options={},
|
||||
) -> Tensor:
|
||||
transformer_options = transformer_options.copy()
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
|
||||
initial_shape = list(img.shape)
|
||||
@@ -378,14 +378,14 @@ class HunyuanVideo(nn.Module):
|
||||
extra_txt_ids = torch.zeros((txt_ids.shape[0], txt_vision_states.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype)
|
||||
txt_ids = torch.cat((txt_ids, extra_txt_ids), dim=1)
|
||||
|
||||
ids = torch.cat((img_ids, txt_ids), dim=1)
|
||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||
pe = self.pe_embedder(ids)
|
||||
|
||||
img_len = img.shape[1]
|
||||
if txt_mask is not None:
|
||||
attn_mask_len = img_len + txt.shape[1]
|
||||
attn_mask = torch.zeros((1, 1, attn_mask_len), dtype=img.dtype, device=img.device)
|
||||
attn_mask[:, 0, img_len:] = txt_mask
|
||||
attn_mask[:, 0, :txt.shape[1]] = txt_mask
|
||||
else:
|
||||
attn_mask = None
|
||||
|
||||
@@ -413,10 +413,11 @@ class HunyuanVideo(nn.Module):
|
||||
if add is not None:
|
||||
img += add
|
||||
|
||||
img = torch.cat((img, txt), 1)
|
||||
img = torch.cat((txt, img), 1)
|
||||
|
||||
transformer_options["total_blocks"] = len(self.single_blocks)
|
||||
transformer_options["block_type"] = "single"
|
||||
transformer_options["img_slice"] = [txt.shape[1], img.shape[1]]
|
||||
for i, block in enumerate(self.single_blocks):
|
||||
transformer_options["block_index"] = i
|
||||
if ("single_block", i) in blocks_replace:
|
||||
@@ -435,9 +436,9 @@ class HunyuanVideo(nn.Module):
|
||||
if i < len(control_o):
|
||||
add = control_o[i]
|
||||
if add is not None:
|
||||
img[:, : img_len] += add
|
||||
img[:, txt.shape[1]: img_len + txt.shape[1]] += add
|
||||
|
||||
img = img[:, : img_len]
|
||||
img = img[:, txt.shape[1]: img_len + txt.shape[1]]
|
||||
if ref_latent is not None:
|
||||
img = img[:, ref_latent.shape[1]:]
|
||||
|
||||
|
||||
@@ -109,10 +109,10 @@ class HunyuanVideo15SRModel():
|
||||
self.model_class = UPSAMPLERS.get(model_type)
|
||||
self.model = self.model_class(**config).eval()
|
||||
|
||||
self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
|
||||
self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
|
||||
|
||||
def load_sd(self, sd):
|
||||
return self.model.load_state_dict(sd, strict=True)
|
||||
return self.model.load_state_dict(sd, strict=True, assign=self.patcher.is_dynamic())
|
||||
|
||||
def get_sd(self):
|
||||
return self.model.state_dict()
|
||||
|
||||
@@ -9,6 +9,7 @@ from comfy.ldm.lightricks.model import (
|
||||
LTXVModel,
|
||||
)
|
||||
from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
|
||||
from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector
|
||||
import comfy.ldm.common_dit
|
||||
|
||||
class CompressedTimestep:
|
||||
@@ -18,12 +19,12 @@ class CompressedTimestep:
|
||||
def __init__(self, tensor: torch.Tensor, patches_per_frame: int):
|
||||
"""
|
||||
tensor: [batch_size, num_tokens, feature_dim] tensor where num_tokens = num_frames * patches_per_frame
|
||||
patches_per_frame: Number of spatial patches per frame (height * width in latent space)
|
||||
patches_per_frame: Number of spatial patches per frame (height * width in latent space), or None to disable compression
|
||||
"""
|
||||
self.batch_size, num_tokens, self.feature_dim = tensor.shape
|
||||
|
||||
# Check if compression is valid (num_tokens must be divisible by patches_per_frame)
|
||||
if num_tokens % patches_per_frame == 0 and num_tokens >= patches_per_frame:
|
||||
if patches_per_frame is not None and num_tokens % patches_per_frame == 0 and num_tokens >= patches_per_frame:
|
||||
self.patches_per_frame = patches_per_frame
|
||||
self.num_frames = num_tokens // patches_per_frame
|
||||
|
||||
@@ -215,22 +216,9 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
return (*scale_shift_ada_values, *gate_ada_values)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: Tuple[torch.Tensor, torch.Tensor],
|
||||
v_context=None,
|
||||
a_context=None,
|
||||
attention_mask=None,
|
||||
v_timestep=None,
|
||||
a_timestep=None,
|
||||
v_pe=None,
|
||||
a_pe=None,
|
||||
v_cross_pe=None,
|
||||
a_cross_pe=None,
|
||||
v_cross_scale_shift_timestep=None,
|
||||
a_cross_scale_shift_timestep=None,
|
||||
v_cross_gate_timestep=None,
|
||||
a_cross_gate_timestep=None,
|
||||
transformer_options=None,
|
||||
self, x: Tuple[torch.Tensor, torch.Tensor], v_context=None, a_context=None, attention_mask=None, v_timestep=None, a_timestep=None,
|
||||
v_pe=None, a_pe=None, v_cross_pe=None, a_cross_pe=None, v_cross_scale_shift_timestep=None, a_cross_scale_shift_timestep=None,
|
||||
v_cross_gate_timestep=None, a_cross_gate_timestep=None, transformer_options=None, self_attention_mask=None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
run_vx = transformer_options.get("run_vx", True)
|
||||
run_ax = transformer_options.get("run_ax", True)
|
||||
@@ -240,144 +228,102 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
run_a2v = run_vx and transformer_options.get("a2v_cross_attn", True) and ax.numel() > 0
|
||||
run_v2a = run_ax and transformer_options.get("v2a_cross_attn", True)
|
||||
|
||||
# video
|
||||
if run_vx:
|
||||
vshift_msa, vscale_msa, vgate_msa = (
|
||||
self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(0, 3))
|
||||
)
|
||||
|
||||
# video self-attention
|
||||
vshift_msa, vscale_msa = (self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(0, 2)))
|
||||
norm_vx = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_msa) + vshift_msa
|
||||
vx += self.attn1(norm_vx, pe=v_pe, transformer_options=transformer_options) * vgate_msa
|
||||
vx += self.attn2(
|
||||
comfy.ldm.common_dit.rms_norm(vx),
|
||||
context=v_context,
|
||||
mask=attention_mask,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
del vshift_msa, vscale_msa, vgate_msa
|
||||
del vshift_msa, vscale_msa
|
||||
attn1_out = self.attn1(norm_vx, pe=v_pe, mask=self_attention_mask, transformer_options=transformer_options)
|
||||
del norm_vx
|
||||
# video cross-attention
|
||||
vgate_msa = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(2, 3))[0]
|
||||
vx.addcmul_(attn1_out, vgate_msa)
|
||||
del vgate_msa, attn1_out
|
||||
vx.add_(self.attn2(comfy.ldm.common_dit.rms_norm(vx), context=v_context, mask=attention_mask, transformer_options=transformer_options))
|
||||
|
||||
# audio
|
||||
if run_ax:
|
||||
ashift_msa, ascale_msa, agate_msa = (
|
||||
self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(0, 3))
|
||||
)
|
||||
|
||||
# audio self-attention
|
||||
ashift_msa, ascale_msa = (self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(0, 2)))
|
||||
norm_ax = comfy.ldm.common_dit.rms_norm(ax) * (1 + ascale_msa) + ashift_msa
|
||||
ax += (
|
||||
self.audio_attn1(norm_ax, pe=a_pe, transformer_options=transformer_options)
|
||||
* agate_msa
|
||||
)
|
||||
ax += self.audio_attn2(
|
||||
comfy.ldm.common_dit.rms_norm(ax),
|
||||
context=a_context,
|
||||
mask=attention_mask,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
del ashift_msa, ascale_msa
|
||||
attn1_out = self.audio_attn1(norm_ax, pe=a_pe, transformer_options=transformer_options)
|
||||
del norm_ax
|
||||
# audio cross-attention
|
||||
agate_msa = self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(2, 3))[0]
|
||||
ax.addcmul_(attn1_out, agate_msa)
|
||||
del agate_msa, attn1_out
|
||||
ax.add_(self.audio_attn2(comfy.ldm.common_dit.rms_norm(ax), context=a_context, mask=attention_mask, transformer_options=transformer_options))
|
||||
|
||||
del ashift_msa, ascale_msa, agate_msa
|
||||
|
||||
# Audio - Video cross attention.
|
||||
# video - audio cross attention.
|
||||
if run_a2v or run_v2a:
|
||||
# norm3
|
||||
vx_norm3 = comfy.ldm.common_dit.rms_norm(vx)
|
||||
ax_norm3 = comfy.ldm.common_dit.rms_norm(ax)
|
||||
|
||||
(
|
||||
scale_ca_audio_hidden_states_a2v,
|
||||
shift_ca_audio_hidden_states_a2v,
|
||||
scale_ca_audio_hidden_states_v2a,
|
||||
shift_ca_audio_hidden_states_v2a,
|
||||
gate_out_v2a,
|
||||
) = self.get_av_ca_ada_values(
|
||||
self.scale_shift_table_a2v_ca_audio,
|
||||
ax.shape[0],
|
||||
a_cross_scale_shift_timestep,
|
||||
a_cross_gate_timestep,
|
||||
)
|
||||
|
||||
(
|
||||
scale_ca_video_hidden_states_a2v,
|
||||
shift_ca_video_hidden_states_a2v,
|
||||
scale_ca_video_hidden_states_v2a,
|
||||
shift_ca_video_hidden_states_v2a,
|
||||
gate_out_a2v,
|
||||
) = self.get_av_ca_ada_values(
|
||||
self.scale_shift_table_a2v_ca_video,
|
||||
vx.shape[0],
|
||||
v_cross_scale_shift_timestep,
|
||||
v_cross_gate_timestep,
|
||||
)
|
||||
|
||||
# audio to video cross attention
|
||||
if run_a2v:
|
||||
vx_scaled = (
|
||||
vx_norm3 * (1 + scale_ca_video_hidden_states_a2v)
|
||||
+ shift_ca_video_hidden_states_a2v
|
||||
)
|
||||
ax_scaled = (
|
||||
ax_norm3 * (1 + scale_ca_audio_hidden_states_a2v)
|
||||
+ shift_ca_audio_hidden_states_a2v
|
||||
)
|
||||
vx += (
|
||||
self.audio_to_video_attn(
|
||||
vx_scaled,
|
||||
context=ax_scaled,
|
||||
pe=v_cross_pe,
|
||||
k_pe=a_cross_pe,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
* gate_out_a2v
|
||||
)
|
||||
scale_ca_audio_hidden_states_a2v, shift_ca_audio_hidden_states_a2v = self.get_ada_values(
|
||||
self.scale_shift_table_a2v_ca_audio[:4, :], ax.shape[0], a_cross_scale_shift_timestep)[:2]
|
||||
scale_ca_video_hidden_states_a2v_v, shift_ca_video_hidden_states_a2v_v = self.get_ada_values(
|
||||
self.scale_shift_table_a2v_ca_video[:4, :], vx.shape[0], v_cross_scale_shift_timestep)[:2]
|
||||
|
||||
del gate_out_a2v
|
||||
del scale_ca_video_hidden_states_a2v,\
|
||||
shift_ca_video_hidden_states_a2v,\
|
||||
scale_ca_audio_hidden_states_a2v,\
|
||||
shift_ca_audio_hidden_states_a2v,\
|
||||
vx_scaled = vx_norm3 * (1 + scale_ca_video_hidden_states_a2v_v) + shift_ca_video_hidden_states_a2v_v
|
||||
ax_scaled = ax_norm3 * (1 + scale_ca_audio_hidden_states_a2v) + shift_ca_audio_hidden_states_a2v
|
||||
del scale_ca_video_hidden_states_a2v_v, shift_ca_video_hidden_states_a2v_v, scale_ca_audio_hidden_states_a2v, shift_ca_audio_hidden_states_a2v
|
||||
|
||||
a2v_out = self.audio_to_video_attn(vx_scaled, context=ax_scaled, pe=v_cross_pe, k_pe=a_cross_pe, transformer_options=transformer_options)
|
||||
del vx_scaled, ax_scaled
|
||||
|
||||
gate_out_a2v = self.get_ada_values(self.scale_shift_table_a2v_ca_video[4:, :], vx.shape[0], v_cross_gate_timestep)[0]
|
||||
vx.addcmul_(a2v_out, gate_out_a2v)
|
||||
del gate_out_a2v, a2v_out
|
||||
|
||||
# video to audio cross attention
|
||||
if run_v2a:
|
||||
ax_scaled = (
|
||||
ax_norm3 * (1 + scale_ca_audio_hidden_states_v2a)
|
||||
+ shift_ca_audio_hidden_states_v2a
|
||||
)
|
||||
vx_scaled = (
|
||||
vx_norm3 * (1 + scale_ca_video_hidden_states_v2a)
|
||||
+ shift_ca_video_hidden_states_v2a
|
||||
)
|
||||
ax += (
|
||||
self.video_to_audio_attn(
|
||||
ax_scaled,
|
||||
context=vx_scaled,
|
||||
pe=a_cross_pe,
|
||||
k_pe=v_cross_pe,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
* gate_out_v2a
|
||||
)
|
||||
scale_ca_audio_hidden_states_v2a, shift_ca_audio_hidden_states_v2a = self.get_ada_values(
|
||||
self.scale_shift_table_a2v_ca_audio[:4, :], ax.shape[0], a_cross_scale_shift_timestep)[2:4]
|
||||
scale_ca_video_hidden_states_v2a, shift_ca_video_hidden_states_v2a = self.get_ada_values(
|
||||
self.scale_shift_table_a2v_ca_video[:4, :], vx.shape[0], v_cross_scale_shift_timestep)[2:4]
|
||||
|
||||
del gate_out_v2a
|
||||
del scale_ca_video_hidden_states_v2a,\
|
||||
shift_ca_video_hidden_states_v2a,\
|
||||
scale_ca_audio_hidden_states_v2a,\
|
||||
shift_ca_audio_hidden_states_v2a
|
||||
ax_scaled = ax_norm3 * (1 + scale_ca_audio_hidden_states_v2a) + shift_ca_audio_hidden_states_v2a
|
||||
vx_scaled = vx_norm3 * (1 + scale_ca_video_hidden_states_v2a) + shift_ca_video_hidden_states_v2a
|
||||
del scale_ca_video_hidden_states_v2a, shift_ca_video_hidden_states_v2a, scale_ca_audio_hidden_states_v2a, shift_ca_audio_hidden_states_v2a
|
||||
|
||||
v2a_out = self.video_to_audio_attn(ax_scaled, context=vx_scaled, pe=a_cross_pe, k_pe=v_cross_pe, transformer_options=transformer_options)
|
||||
del ax_scaled, vx_scaled
|
||||
|
||||
gate_out_v2a = self.get_ada_values(self.scale_shift_table_a2v_ca_audio[4:, :], ax.shape[0], a_cross_gate_timestep)[0]
|
||||
ax.addcmul_(v2a_out, gate_out_v2a)
|
||||
del gate_out_v2a, v2a_out
|
||||
|
||||
del vx_norm3, ax_norm3
|
||||
|
||||
# video feedforward
|
||||
if run_vx:
|
||||
vshift_mlp, vscale_mlp, vgate_mlp = (
|
||||
self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(3, None))
|
||||
)
|
||||
|
||||
vshift_mlp, vscale_mlp = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(3, 5))
|
||||
vx_scaled = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_mlp) + vshift_mlp
|
||||
vx += self.ff(vx_scaled) * vgate_mlp
|
||||
del vshift_mlp, vscale_mlp, vgate_mlp
|
||||
del vshift_mlp, vscale_mlp
|
||||
|
||||
ff_out = self.ff(vx_scaled)
|
||||
del vx_scaled
|
||||
|
||||
vgate_mlp = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(5, 6))[0]
|
||||
vx.addcmul_(ff_out, vgate_mlp)
|
||||
del vgate_mlp, ff_out
|
||||
|
||||
# audio feedforward
|
||||
if run_ax:
|
||||
ashift_mlp, ascale_mlp, agate_mlp = (
|
||||
self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(3, None))
|
||||
)
|
||||
|
||||
ashift_mlp, ascale_mlp = self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(3, 5))
|
||||
ax_scaled = comfy.ldm.common_dit.rms_norm(ax) * (1 + ascale_mlp) + ashift_mlp
|
||||
ax += self.audio_ff(ax_scaled) * agate_mlp
|
||||
del ashift_mlp, ascale_mlp
|
||||
|
||||
del ashift_mlp, ascale_mlp, agate_mlp
|
||||
ff_out = self.audio_ff(ax_scaled)
|
||||
del ax_scaled
|
||||
|
||||
agate_mlp = self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(5, 6))[0]
|
||||
ax.addcmul_(ff_out, agate_mlp)
|
||||
del agate_mlp, ff_out
|
||||
|
||||
return vx, ax
|
||||
|
||||
@@ -505,6 +451,29 @@ class LTXAVModel(LTXVModel):
|
||||
operations=self.operations,
|
||||
)
|
||||
|
||||
self.audio_embeddings_connector = Embeddings1DConnector(
|
||||
split_rope=True,
|
||||
double_precision_rope=True,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=self.operations,
|
||||
)
|
||||
|
||||
self.video_embeddings_connector = Embeddings1DConnector(
|
||||
split_rope=True,
|
||||
double_precision_rope=True,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=self.operations,
|
||||
)
|
||||
|
||||
def preprocess_text_embeds(self, context):
|
||||
if context.shape[-1] == self.caption_channels * 2:
|
||||
return context
|
||||
out_vid = self.video_embeddings_connector(context)[0]
|
||||
out_audio = self.audio_embeddings_connector(context)[0]
|
||||
return torch.concat((out_vid, out_audio), dim=-1)
|
||||
|
||||
def _init_transformer_blocks(self, device, dtype, **kwargs):
|
||||
"""Initialize transformer blocks for LTXAV."""
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
@@ -589,9 +558,20 @@ class LTXAVModel(LTXVModel):
|
||||
audio_length = kwargs.get("audio_length", 0)
|
||||
# Separate audio and video latents
|
||||
vx, ax = self.separate_audio_and_video_latents(x, audio_length)
|
||||
|
||||
has_spatial_mask = False
|
||||
if denoise_mask is not None:
|
||||
# check if any frame has spatial variation (inpainting)
|
||||
for frame_idx in range(denoise_mask.shape[2]):
|
||||
frame_mask = denoise_mask[0, 0, frame_idx]
|
||||
if frame_mask.numel() > 0 and frame_mask.min() != frame_mask.max():
|
||||
has_spatial_mask = True
|
||||
break
|
||||
|
||||
[vx, v_pixel_coords, additional_args] = super()._process_input(
|
||||
vx, keyframe_idxs, denoise_mask, **kwargs
|
||||
)
|
||||
additional_args["has_spatial_mask"] = has_spatial_mask
|
||||
|
||||
ax, a_latent_coords = self.a_patchifier.patchify(ax)
|
||||
ax = self.audio_patchify_proj(ax)
|
||||
@@ -618,8 +598,9 @@ class LTXAVModel(LTXVModel):
|
||||
# Calculate patches_per_frame from orig_shape: [batch, channels, frames, height, width]
|
||||
# Video tokens are arranged as (frames * height * width), so patches_per_frame = height * width
|
||||
orig_shape = kwargs.get("orig_shape")
|
||||
has_spatial_mask = kwargs.get("has_spatial_mask", None)
|
||||
v_patches_per_frame = None
|
||||
if orig_shape is not None and len(orig_shape) == 5:
|
||||
if not has_spatial_mask and orig_shape is not None and len(orig_shape) == 5:
|
||||
# orig_shape[3] = height, orig_shape[4] = width (in latent space)
|
||||
v_patches_per_frame = orig_shape[3] * orig_shape[4]
|
||||
|
||||
@@ -662,10 +643,11 @@ class LTXAVModel(LTXVModel):
|
||||
)
|
||||
|
||||
# Compress cross-attention timesteps (only video side, audio is too small to benefit)
|
||||
# v_patches_per_frame is None for spatial masks, set for temporal masks or no mask
|
||||
cross_av_timestep_ss = [
|
||||
av_ca_audio_scale_shift_timestep.view(batch_size, -1, av_ca_audio_scale_shift_timestep.shape[-1]),
|
||||
CompressedTimestep(av_ca_video_scale_shift_timestep.view(batch_size, -1, av_ca_video_scale_shift_timestep.shape[-1]), v_patches_per_frame), # video - compressed
|
||||
CompressedTimestep(av_ca_a2v_gate_noise_timestep.view(batch_size, -1, av_ca_a2v_gate_noise_timestep.shape[-1]), v_patches_per_frame), # video - compressed
|
||||
CompressedTimestep(av_ca_video_scale_shift_timestep.view(batch_size, -1, av_ca_video_scale_shift_timestep.shape[-1]), v_patches_per_frame), # video - compressed if possible
|
||||
CompressedTimestep(av_ca_a2v_gate_noise_timestep.view(batch_size, -1, av_ca_a2v_gate_noise_timestep.shape[-1]), v_patches_per_frame), # video - compressed if possible
|
||||
av_ca_v2a_gate_noise_timestep.view(batch_size, -1, av_ca_v2a_gate_noise_timestep.shape[-1]),
|
||||
]
|
||||
|
||||
@@ -744,7 +726,7 @@ class LTXAVModel(LTXVModel):
|
||||
return [(v_pe, av_cross_video_freq_cis), (a_pe, av_cross_audio_freq_cis)]
|
||||
|
||||
def _process_transformer_blocks(
|
||||
self, x, context, attention_mask, timestep, pe, transformer_options={}, **kwargs
|
||||
self, x, context, attention_mask, timestep, pe, transformer_options={}, self_attention_mask=None, **kwargs
|
||||
):
|
||||
vx = x[0]
|
||||
ax = x[1]
|
||||
@@ -788,6 +770,7 @@ class LTXAVModel(LTXVModel):
|
||||
v_cross_gate_timestep=args["v_cross_gate_timestep"],
|
||||
a_cross_gate_timestep=args["a_cross_gate_timestep"],
|
||||
transformer_options=args["transformer_options"],
|
||||
self_attention_mask=args.get("self_attention_mask"),
|
||||
)
|
||||
return out
|
||||
|
||||
@@ -808,6 +791,7 @@ class LTXAVModel(LTXVModel):
|
||||
"v_cross_gate_timestep": av_ca_a2v_gate_noise_timestep,
|
||||
"a_cross_gate_timestep": av_ca_v2a_gate_noise_timestep,
|
||||
"transformer_options": transformer_options,
|
||||
"self_attention_mask": self_attention_mask,
|
||||
},
|
||||
{"original_block": block_wrap},
|
||||
)
|
||||
@@ -829,6 +813,7 @@ class LTXAVModel(LTXVModel):
|
||||
v_cross_gate_timestep=av_ca_a2v_gate_noise_timestep,
|
||||
a_cross_gate_timestep=av_ca_v2a_gate_noise_timestep,
|
||||
transformer_options=transformer_options,
|
||||
self_attention_mask=self_attention_mask,
|
||||
)
|
||||
|
||||
return [vx, ax]
|
||||
|
||||
@@ -157,11 +157,9 @@ class Embeddings1DConnector(nn.Module):
|
||||
self.num_learnable_registers = num_learnable_registers
|
||||
if self.num_learnable_registers:
|
||||
self.learnable_registers = nn.Parameter(
|
||||
torch.rand(
|
||||
torch.empty(
|
||||
self.num_learnable_registers, inner_dim, dtype=dtype, device=device
|
||||
)
|
||||
* 2.0
|
||||
- 1.0
|
||||
)
|
||||
|
||||
def get_fractional_positions(self, indices_grid):
|
||||
@@ -234,7 +232,7 @@ class Embeddings1DConnector(nn.Module):
|
||||
|
||||
return indices
|
||||
|
||||
def precompute_freqs_cis(self, indices_grid, spacing="exp"):
|
||||
def precompute_freqs_cis(self, indices_grid, spacing="exp", out_dtype=None):
|
||||
dim = self.inner_dim
|
||||
n_elem = 2 # 2 because of cos and sin
|
||||
freqs = self.precompute_freqs(indices_grid, spacing)
|
||||
@@ -247,7 +245,7 @@ class Embeddings1DConnector(nn.Module):
|
||||
)
|
||||
else:
|
||||
cos_freq, sin_freq = interleaved_freqs_cis(freqs, dim % n_elem)
|
||||
return cos_freq.to(self.dtype), sin_freq.to(self.dtype), self.split_rope
|
||||
return cos_freq.to(dtype=out_dtype), sin_freq.to(dtype=out_dtype), self.split_rope
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -288,7 +286,7 @@ class Embeddings1DConnector(nn.Module):
|
||||
hidden_states.shape[1], dtype=torch.float32, device=hidden_states.device
|
||||
)
|
||||
indices_grid = indices_grid[None, None, :]
|
||||
freqs_cis = self.precompute_freqs_cis(indices_grid)
|
||||
freqs_cis = self.precompute_freqs_cis(indices_grid, out_dtype=hidden_states.dtype)
|
||||
|
||||
# 2. Blocks
|
||||
for block_idx, block in enumerate(self.transformer_1d_blocks):
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
import functools
|
||||
import logging
|
||||
import math
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
@@ -14,6 +15,8 @@ import comfy.ldm.common_dit
|
||||
|
||||
from .symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def _log_base(x, base):
|
||||
return np.log(x) / np.log(base)
|
||||
|
||||
@@ -415,12 +418,12 @@ class BasicTransformerBlock(nn.Module):
|
||||
|
||||
self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype))
|
||||
|
||||
def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}):
|
||||
def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}, self_attention_mask=None):
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2)
|
||||
|
||||
attn1_input = comfy.ldm.common_dit.rms_norm(x)
|
||||
attn1_input = torch.addcmul(attn1_input, attn1_input, scale_msa).add_(shift_msa)
|
||||
attn1_input = self.attn1(attn1_input, pe=pe, transformer_options=transformer_options)
|
||||
attn1_input = self.attn1(attn1_input, pe=pe, mask=self_attention_mask, transformer_options=transformer_options)
|
||||
x.addcmul_(attn1_input, gate_msa)
|
||||
del attn1_input
|
||||
|
||||
@@ -638,8 +641,16 @@ class LTXBaseModel(torch.nn.Module, ABC):
|
||||
"""Process input data. Must be implemented by subclasses."""
|
||||
pass
|
||||
|
||||
def _build_guide_self_attention_mask(self, x, transformer_options, merged_args):
|
||||
"""Build self-attention mask for per-guide attention attenuation.
|
||||
|
||||
Base implementation returns None (no attenuation). Subclasses that
|
||||
support guide-based attention control should override this.
|
||||
"""
|
||||
return None
|
||||
|
||||
@abstractmethod
|
||||
def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, **kwargs):
|
||||
def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, self_attention_mask=None, **kwargs):
|
||||
"""Process transformer blocks. Must be implemented by subclasses."""
|
||||
pass
|
||||
|
||||
@@ -788,9 +799,17 @@ class LTXBaseModel(torch.nn.Module, ABC):
|
||||
attention_mask = self._prepare_attention_mask(attention_mask, input_dtype)
|
||||
pe = self._prepare_positional_embeddings(pixel_coords, frame_rate, input_dtype)
|
||||
|
||||
# Build self-attention mask for per-guide attenuation
|
||||
self_attention_mask = self._build_guide_self_attention_mask(
|
||||
x, transformer_options, merged_args
|
||||
)
|
||||
|
||||
# Process transformer blocks
|
||||
x = self._process_transformer_blocks(
|
||||
x, context, attention_mask, timestep, pe, transformer_options=transformer_options, **merged_args
|
||||
x, context, attention_mask, timestep, pe,
|
||||
transformer_options=transformer_options,
|
||||
self_attention_mask=self_attention_mask,
|
||||
**merged_args,
|
||||
)
|
||||
|
||||
# Process output
|
||||
@@ -890,13 +909,243 @@ class LTXVModel(LTXBaseModel):
|
||||
pixel_coords = pixel_coords[:, :, grid_mask, ...]
|
||||
|
||||
kf_grid_mask = grid_mask[-keyframe_idxs.shape[2]:]
|
||||
|
||||
# Compute per-guide surviving token counts from guide_attention_entries.
|
||||
# Each entry tracks one guide reference; they are appended in order and
|
||||
# their pre_filter_counts partition the kf_grid_mask.
|
||||
guide_entries = kwargs.get("guide_attention_entries", None)
|
||||
if guide_entries:
|
||||
total_pfc = sum(e["pre_filter_count"] for e in guide_entries)
|
||||
if total_pfc != len(kf_grid_mask):
|
||||
raise ValueError(
|
||||
f"guide pre_filter_counts ({total_pfc}) != "
|
||||
f"keyframe grid mask length ({len(kf_grid_mask)})"
|
||||
)
|
||||
resolved_entries = []
|
||||
offset = 0
|
||||
for entry in guide_entries:
|
||||
pfc = entry["pre_filter_count"]
|
||||
entry_mask = kf_grid_mask[offset:offset + pfc]
|
||||
surviving = int(entry_mask.sum().item())
|
||||
resolved_entries.append({
|
||||
**entry,
|
||||
"surviving_count": surviving,
|
||||
})
|
||||
offset += pfc
|
||||
additional_args["resolved_guide_entries"] = resolved_entries
|
||||
|
||||
keyframe_idxs = keyframe_idxs[..., kf_grid_mask, :]
|
||||
pixel_coords[:, :, -keyframe_idxs.shape[2]:, :] = keyframe_idxs
|
||||
|
||||
# Total surviving guide tokens (all guides)
|
||||
additional_args["num_guide_tokens"] = keyframe_idxs.shape[2]
|
||||
|
||||
x = self.patchify_proj(x)
|
||||
return x, pixel_coords, additional_args
|
||||
|
||||
def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, transformer_options={}, **kwargs):
|
||||
def _build_guide_self_attention_mask(self, x, transformer_options, merged_args):
|
||||
"""Build self-attention mask for per-guide attention attenuation.
|
||||
|
||||
Reads resolved_guide_entries from merged_args (computed in _process_input)
|
||||
to build a log-space additive bias mask that attenuates noisy ↔ guide
|
||||
attention for each guide reference independently.
|
||||
|
||||
Returns None if no attenuation is needed (all strengths == 1.0 and no
|
||||
spatial masks, or no guide tokens).
|
||||
"""
|
||||
if isinstance(x, list):
|
||||
# AV model: x = [vx, ax]; use vx for token count and device
|
||||
total_tokens = x[0].shape[1]
|
||||
device = x[0].device
|
||||
dtype = x[0].dtype
|
||||
else:
|
||||
total_tokens = x.shape[1]
|
||||
device = x.device
|
||||
dtype = x.dtype
|
||||
|
||||
num_guide_tokens = merged_args.get("num_guide_tokens", 0)
|
||||
if num_guide_tokens == 0:
|
||||
return None
|
||||
|
||||
resolved_entries = merged_args.get("resolved_guide_entries", None)
|
||||
if not resolved_entries:
|
||||
return None
|
||||
|
||||
# Check if any attenuation is actually needed
|
||||
needs_attenuation = any(
|
||||
e["strength"] < 1.0 or e.get("pixel_mask") is not None
|
||||
for e in resolved_entries
|
||||
)
|
||||
if not needs_attenuation:
|
||||
return None
|
||||
|
||||
# Build per-guide-token weights for all tracked guide tokens.
|
||||
# Guides are appended in order at the end of the sequence.
|
||||
guide_start = total_tokens - num_guide_tokens
|
||||
all_weights = []
|
||||
total_tracked = 0
|
||||
|
||||
for entry in resolved_entries:
|
||||
surviving = entry["surviving_count"]
|
||||
if surviving == 0:
|
||||
continue
|
||||
|
||||
strength = entry["strength"]
|
||||
pixel_mask = entry.get("pixel_mask")
|
||||
latent_shape = entry.get("latent_shape")
|
||||
|
||||
if pixel_mask is not None and latent_shape is not None:
|
||||
f_lat, h_lat, w_lat = latent_shape
|
||||
per_token = self._downsample_mask_to_latent(
|
||||
pixel_mask.to(device=device, dtype=dtype),
|
||||
f_lat, h_lat, w_lat,
|
||||
)
|
||||
# per_token shape: (B, f_lat*h_lat*w_lat).
|
||||
# Collapse batch dim — the mask is assumed identical across the
|
||||
# batch; validate and take the first element to get (1, tokens).
|
||||
if per_token.shape[0] > 1:
|
||||
ref = per_token[0]
|
||||
for bi in range(1, per_token.shape[0]):
|
||||
if not torch.equal(ref, per_token[bi]):
|
||||
logger.warning(
|
||||
"pixel_mask differs across batch elements; "
|
||||
"using first element only."
|
||||
)
|
||||
break
|
||||
per_token = per_token[:1]
|
||||
# `surviving` is the post-grid_mask token count.
|
||||
# Clamp to surviving to handle any mismatch safely.
|
||||
n_weights = min(per_token.shape[1], surviving)
|
||||
weights = per_token[:, :n_weights] * strength # (1, n_weights)
|
||||
else:
|
||||
weights = torch.full(
|
||||
(1, surviving), strength, device=device, dtype=dtype
|
||||
)
|
||||
|
||||
all_weights.append(weights)
|
||||
total_tracked += weights.shape[1]
|
||||
|
||||
if not all_weights:
|
||||
return None
|
||||
|
||||
# Concatenate per-token weights for all tracked guides
|
||||
tracked_weights = torch.cat(all_weights, dim=1) # (1, total_tracked)
|
||||
|
||||
# Check if any weight is actually < 1.0 (otherwise no attenuation needed)
|
||||
if (tracked_weights >= 1.0).all():
|
||||
return None
|
||||
|
||||
# Build the mask: guide tokens are at the end of the sequence.
|
||||
# Tracked guides come first (in order), untracked follow.
|
||||
return self._build_self_attention_mask(
|
||||
total_tokens, num_guide_tokens, total_tracked,
|
||||
tracked_weights, guide_start, device, dtype,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _downsample_mask_to_latent(mask, f_lat, h_lat, w_lat):
|
||||
"""Downsample a pixel-space mask to per-token latent weights.
|
||||
|
||||
Args:
|
||||
mask: (B, 1, F_pix, H_pix, W_pix) pixel-space mask with values in [0, 1].
|
||||
f_lat: Number of latent frames (pre-dilation original count).
|
||||
h_lat: Latent height (pre-dilation original height).
|
||||
w_lat: Latent width (pre-dilation original width).
|
||||
|
||||
Returns:
|
||||
(B, F_lat * H_lat * W_lat) flattened per-token weights.
|
||||
"""
|
||||
b = mask.shape[0]
|
||||
f_pix = mask.shape[2]
|
||||
|
||||
# Spatial downsampling: area interpolation per frame
|
||||
spatial_down = torch.nn.functional.interpolate(
|
||||
rearrange(mask, "b 1 f h w -> (b f) 1 h w"),
|
||||
size=(h_lat, w_lat),
|
||||
mode="area",
|
||||
)
|
||||
spatial_down = rearrange(spatial_down, "(b f) 1 h w -> b 1 f h w", b=b)
|
||||
|
||||
# Temporal downsampling: first pixel frame maps to first latent frame,
|
||||
# remaining pixel frames are averaged in groups for causal temporal structure.
|
||||
first_frame = spatial_down[:, :, :1, :, :]
|
||||
if f_pix > 1 and f_lat > 1:
|
||||
remaining_pix = f_pix - 1
|
||||
remaining_lat = f_lat - 1
|
||||
t = remaining_pix // remaining_lat
|
||||
if t < 1:
|
||||
# Fewer pixel frames than latent frames — upsample by repeating
|
||||
# the available pixel frames via nearest interpolation.
|
||||
rest_flat = rearrange(
|
||||
spatial_down[:, :, 1:, :, :],
|
||||
"b 1 f h w -> (b h w) 1 f",
|
||||
)
|
||||
rest_up = torch.nn.functional.interpolate(
|
||||
rest_flat, size=remaining_lat, mode="nearest",
|
||||
)
|
||||
rest = rearrange(
|
||||
rest_up, "(b h w) 1 f -> b 1 f h w",
|
||||
b=b, h=h_lat, w=w_lat,
|
||||
)
|
||||
else:
|
||||
# Trim trailing pixel frames that don't fill a complete group
|
||||
usable = remaining_lat * t
|
||||
rest = rearrange(
|
||||
spatial_down[:, :, 1:1 + usable, :, :],
|
||||
"b 1 (f t) h w -> b 1 f t h w",
|
||||
t=t,
|
||||
)
|
||||
rest = rest.mean(dim=3)
|
||||
latent_mask = torch.cat([first_frame, rest], dim=2)
|
||||
elif f_lat > 1:
|
||||
# Single pixel frame but multiple latent frames — repeat the
|
||||
# single frame across all latent frames.
|
||||
latent_mask = first_frame.expand(-1, -1, f_lat, -1, -1)
|
||||
else:
|
||||
latent_mask = first_frame
|
||||
|
||||
return rearrange(latent_mask, "b 1 f h w -> b (f h w)")
|
||||
|
||||
@staticmethod
|
||||
def _build_self_attention_mask(total_tokens, num_guide_tokens, tracked_count,
|
||||
tracked_weights, guide_start, device, dtype):
|
||||
"""Build a log-space additive self-attention bias mask.
|
||||
|
||||
Attenuates attention between noisy tokens and tracked guide tokens.
|
||||
Untracked guide tokens (at the end of the guide portion) keep full attention.
|
||||
|
||||
Args:
|
||||
total_tokens: Total sequence length.
|
||||
num_guide_tokens: Total guide tokens (all guides) at end of sequence.
|
||||
tracked_count: Number of tracked guide tokens (first in the guide portion).
|
||||
tracked_weights: (1, tracked_count) tensor, values in [0, 1].
|
||||
guide_start: Index where guide tokens begin in the sequence.
|
||||
device: Target device.
|
||||
dtype: Target dtype.
|
||||
|
||||
Returns:
|
||||
(1, 1, total_tokens, total_tokens) additive bias mask.
|
||||
0.0 = full attention, negative = attenuated, finfo.min = effectively fully masked.
|
||||
"""
|
||||
finfo = torch.finfo(dtype)
|
||||
mask = torch.zeros((1, 1, total_tokens, total_tokens), device=device, dtype=dtype)
|
||||
tracked_end = guide_start + tracked_count
|
||||
|
||||
# Convert weights to log-space bias
|
||||
w = tracked_weights.to(device=device, dtype=dtype) # (1, tracked_count)
|
||||
log_w = torch.full_like(w, finfo.min)
|
||||
positive_mask = w > 0
|
||||
if positive_mask.any():
|
||||
log_w[positive_mask] = torch.log(w[positive_mask].clamp(min=finfo.tiny))
|
||||
|
||||
# noisy → tracked guides: each noisy row gets the same per-guide weight
|
||||
mask[:, :, :guide_start, guide_start:tracked_end] = log_w.view(1, 1, 1, -1)
|
||||
# tracked guides → noisy: each guide row broadcasts its weight across noisy cols
|
||||
mask[:, :, guide_start:tracked_end, :guide_start] = log_w.view(1, 1, -1, 1)
|
||||
|
||||
return mask
|
||||
|
||||
def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, transformer_options={}, self_attention_mask=None, **kwargs):
|
||||
"""Process transformer blocks for LTXV."""
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
@@ -906,10 +1155,10 @@ class LTXVModel(LTXBaseModel):
|
||||
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"])
|
||||
out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"], self_attention_mask=args.get("self_attention_mask"))
|
||||
return out
|
||||
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe, "transformer_options": transformer_options}, {"original_block": block_wrap})
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe, "transformer_options": transformer_options, "self_attention_mask": self_attention_mask}, {"original_block": block_wrap})
|
||||
x = out["img"]
|
||||
else:
|
||||
x = block(
|
||||
@@ -919,6 +1168,7 @@ class LTXVModel(LTXBaseModel):
|
||||
timestep=timestep,
|
||||
pe=pe,
|
||||
transformer_options=transformer_options,
|
||||
self_attention_mask=self_attention_mask,
|
||||
)
|
||||
|
||||
return x
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
from typing import Tuple, Union
|
||||
|
||||
import threading
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import comfy.ops
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
|
||||
class CausalConv3d(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -42,23 +42,34 @@ class CausalConv3d(nn.Module):
|
||||
padding_mode=spatial_padding_mode,
|
||||
groups=groups,
|
||||
)
|
||||
self.temporal_cache_state={}
|
||||
|
||||
def forward(self, x, causal: bool = True):
|
||||
if causal:
|
||||
first_frame_pad = x[:, :, :1, :, :].repeat(
|
||||
(1, 1, self.time_kernel_size - 1, 1, 1)
|
||||
)
|
||||
x = torch.concatenate((first_frame_pad, x), dim=2)
|
||||
else:
|
||||
first_frame_pad = x[:, :, :1, :, :].repeat(
|
||||
(1, 1, (self.time_kernel_size - 1) // 2, 1, 1)
|
||||
)
|
||||
last_frame_pad = x[:, :, -1:, :, :].repeat(
|
||||
(1, 1, (self.time_kernel_size - 1) // 2, 1, 1)
|
||||
)
|
||||
x = torch.concatenate((first_frame_pad, x, last_frame_pad), dim=2)
|
||||
x = self.conv(x)
|
||||
return x
|
||||
tid = threading.get_ident()
|
||||
|
||||
cached, is_end = self.temporal_cache_state.get(tid, (None, False))
|
||||
if cached is None:
|
||||
padding_length = self.time_kernel_size - 1
|
||||
if not causal:
|
||||
padding_length = padding_length // 2
|
||||
if x.shape[2] == 0:
|
||||
return x
|
||||
cached = x[:, :, :1, :, :].repeat((1, 1, padding_length, 1, 1))
|
||||
pieces = [ cached, x ]
|
||||
if is_end and not causal:
|
||||
pieces.append(x[:, :, -1:, :, :].repeat((1, 1, (self.time_kernel_size - 1) // 2, 1, 1)))
|
||||
|
||||
needs_caching = not is_end
|
||||
if needs_caching and x.shape[2] >= self.time_kernel_size - 1:
|
||||
needs_caching = False
|
||||
self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):, :, :], False)
|
||||
|
||||
x = torch.cat(pieces, dim=2)
|
||||
|
||||
if needs_caching:
|
||||
self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):, :, :], False)
|
||||
|
||||
return self.conv(x) if x.shape[2] >= self.time_kernel_size else x[:, :, :0, :, :]
|
||||
|
||||
@property
|
||||
def weight(self):
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from __future__ import annotations
|
||||
import threading
|
||||
import torch
|
||||
from torch import nn
|
||||
from functools import partial
|
||||
@@ -6,12 +7,35 @@ import math
|
||||
from einops import rearrange
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from .conv_nd_factory import make_conv_nd, make_linear_nd
|
||||
from .causal_conv3d import CausalConv3d
|
||||
from .pixel_norm import PixelNorm
|
||||
from ..model import PixArtAlphaCombinedTimestepSizeEmbeddings
|
||||
import comfy.ops
|
||||
from comfy.ldm.modules.diffusionmodules.model import torch_cat_if_needed
|
||||
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
def mark_conv3d_ended(module):
|
||||
tid = threading.get_ident()
|
||||
for _, m in module.named_modules():
|
||||
if isinstance(m, CausalConv3d):
|
||||
current = m.temporal_cache_state.get(tid, (None, False))
|
||||
m.temporal_cache_state[tid] = (current[0], True)
|
||||
|
||||
def split2(tensor, split_point, dim=2):
|
||||
return torch.split(tensor, [split_point, tensor.shape[dim] - split_point], dim=dim)
|
||||
|
||||
def add_exchange_cache(dest, cache_in, new_input, dim=2):
|
||||
if dest is not None:
|
||||
if cache_in is not None:
|
||||
cache_to_dest = min(dest.shape[dim], cache_in.shape[dim])
|
||||
lead_in_dest, dest = split2(dest, cache_to_dest, dim=dim)
|
||||
lead_in_source, cache_in = split2(cache_in, cache_to_dest, dim=dim)
|
||||
lead_in_dest.add_(lead_in_source)
|
||||
body, new_input = split2(new_input, dest.shape[dim], dim)
|
||||
dest.add_(body)
|
||||
return torch_cat_if_needed([cache_in, new_input], dim=dim)
|
||||
|
||||
class Encoder(nn.Module):
|
||||
r"""
|
||||
The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation.
|
||||
@@ -205,7 +229,7 @@ class Encoder(nn.Module):
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
||||
def forward_orig(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
||||
r"""The forward method of the `Encoder` class."""
|
||||
|
||||
sample = patchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
|
||||
@@ -254,6 +278,22 @@ class Encoder(nn.Module):
|
||||
|
||||
return sample
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
#No encoder support so just flag the end so it doesnt use the cache.
|
||||
mark_conv3d_ended(self)
|
||||
try:
|
||||
return self.forward_orig(*args, **kwargs)
|
||||
finally:
|
||||
tid = threading.get_ident()
|
||||
for _, module in self.named_modules():
|
||||
# ComfyUI doesn't thread this kind of stuff today, but just in case
|
||||
# we key on the thread to make it thread safe.
|
||||
tid = threading.get_ident()
|
||||
if hasattr(module, "temporal_cache_state"):
|
||||
module.temporal_cache_state.pop(tid, None)
|
||||
|
||||
|
||||
MAX_CHUNK_SIZE=(128 * 1024 ** 2)
|
||||
|
||||
class Decoder(nn.Module):
|
||||
r"""
|
||||
@@ -341,18 +381,6 @@ class Decoder(nn.Module):
|
||||
timestep_conditioning=timestep_conditioning,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "attn_res_x":
|
||||
block = UNetMidBlock3D(
|
||||
dims=dims,
|
||||
in_channels=input_channel,
|
||||
num_layers=block_params["num_layers"],
|
||||
resnet_groups=norm_num_groups,
|
||||
norm_layer=norm_layer,
|
||||
inject_noise=block_params.get("inject_noise", False),
|
||||
timestep_conditioning=timestep_conditioning,
|
||||
attention_head_dim=block_params["attention_head_dim"],
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "res_x_y":
|
||||
output_channel = output_channel // block_params.get("multiplier", 2)
|
||||
block = ResnetBlock3D(
|
||||
@@ -428,8 +456,9 @@ class Decoder(nn.Module):
|
||||
)
|
||||
self.last_scale_shift_table = nn.Parameter(torch.empty(2, output_channel))
|
||||
|
||||
|
||||
# def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor:
|
||||
def forward(
|
||||
def forward_orig(
|
||||
self,
|
||||
sample: torch.FloatTensor,
|
||||
timestep: Optional[torch.Tensor] = None,
|
||||
@@ -437,6 +466,7 @@ class Decoder(nn.Module):
|
||||
r"""The forward method of the `Decoder` class."""
|
||||
batch_size = sample.shape[0]
|
||||
|
||||
mark_conv3d_ended(self.conv_in)
|
||||
sample = self.conv_in(sample, causal=self.causal)
|
||||
|
||||
checkpoint_fn = (
|
||||
@@ -445,24 +475,12 @@ class Decoder(nn.Module):
|
||||
else lambda x: x
|
||||
)
|
||||
|
||||
scaled_timestep = None
|
||||
timestep_shift_scale = None
|
||||
if self.timestep_conditioning:
|
||||
assert (
|
||||
timestep is not None
|
||||
), "should pass timestep with timestep_conditioning=True"
|
||||
scaled_timestep = timestep * self.timestep_scale_multiplier.to(dtype=sample.dtype, device=sample.device)
|
||||
|
||||
for up_block in self.up_blocks:
|
||||
if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D):
|
||||
sample = checkpoint_fn(up_block)(
|
||||
sample, causal=self.causal, timestep=scaled_timestep
|
||||
)
|
||||
else:
|
||||
sample = checkpoint_fn(up_block)(sample, causal=self.causal)
|
||||
|
||||
sample = self.conv_norm_out(sample)
|
||||
|
||||
if self.timestep_conditioning:
|
||||
embedded_timestep = self.last_time_embedder(
|
||||
timestep=scaled_timestep.flatten(),
|
||||
resolution=None,
|
||||
@@ -483,16 +501,62 @@ class Decoder(nn.Module):
|
||||
embedded_timestep.shape[-2],
|
||||
embedded_timestep.shape[-1],
|
||||
)
|
||||
shift, scale = ada_values.unbind(dim=1)
|
||||
sample = sample * (1 + scale) + shift
|
||||
timestep_shift_scale = ada_values.unbind(dim=1)
|
||||
|
||||
sample = self.conv_act(sample)
|
||||
sample = self.conv_out(sample, causal=self.causal)
|
||||
output = []
|
||||
|
||||
def run_up(idx, sample, ended):
|
||||
if idx >= len(self.up_blocks):
|
||||
sample = self.conv_norm_out(sample)
|
||||
if timestep_shift_scale is not None:
|
||||
shift, scale = timestep_shift_scale
|
||||
sample = sample * (1 + scale) + shift
|
||||
sample = self.conv_act(sample)
|
||||
if ended:
|
||||
mark_conv3d_ended(self.conv_out)
|
||||
sample = self.conv_out(sample, causal=self.causal)
|
||||
if sample is not None and sample.shape[2] > 0:
|
||||
output.append(sample)
|
||||
return
|
||||
|
||||
up_block = self.up_blocks[idx]
|
||||
if (ended):
|
||||
mark_conv3d_ended(up_block)
|
||||
if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D):
|
||||
sample = checkpoint_fn(up_block)(
|
||||
sample, causal=self.causal, timestep=scaled_timestep
|
||||
)
|
||||
else:
|
||||
sample = checkpoint_fn(up_block)(sample, causal=self.causal)
|
||||
|
||||
if sample is None or sample.shape[2] == 0:
|
||||
return
|
||||
|
||||
total_bytes = sample.numel() * sample.element_size()
|
||||
num_chunks = (total_bytes + MAX_CHUNK_SIZE - 1) // MAX_CHUNK_SIZE
|
||||
samples = torch.chunk(sample, chunks=num_chunks, dim=2)
|
||||
|
||||
for chunk_idx, sample1 in enumerate(samples):
|
||||
run_up(idx + 1, sample1, ended and chunk_idx == len(samples) - 1)
|
||||
|
||||
run_up(0, sample, True)
|
||||
sample = torch.cat(output, dim=2)
|
||||
|
||||
sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
|
||||
|
||||
return sample
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
try:
|
||||
return self.forward_orig(*args, **kwargs)
|
||||
finally:
|
||||
for _, module in self.named_modules():
|
||||
#ComfyUI doesn't thread this kind of stuff today, but just incase
|
||||
#we key on the thread to make it thread safe.
|
||||
tid = threading.get_ident()
|
||||
if hasattr(module, "temporal_cache_state"):
|
||||
module.temporal_cache_state.pop(tid, None)
|
||||
|
||||
|
||||
class UNetMidBlock3D(nn.Module):
|
||||
"""
|
||||
@@ -663,8 +727,22 @@ class DepthToSpaceUpsample(nn.Module):
|
||||
)
|
||||
self.residual = residual
|
||||
self.out_channels_reduction_factor = out_channels_reduction_factor
|
||||
self.temporal_cache_state = {}
|
||||
|
||||
def forward(self, x, causal: bool = True, timestep: Optional[torch.Tensor] = None):
|
||||
tid = threading.get_ident()
|
||||
cached, drop_first_conv, drop_first_res = self.temporal_cache_state.get(tid, (None, True, True))
|
||||
y = self.conv(x, causal=causal)
|
||||
y = rearrange(
|
||||
y,
|
||||
"b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)",
|
||||
p1=self.stride[0],
|
||||
p2=self.stride[1],
|
||||
p3=self.stride[2],
|
||||
)
|
||||
if self.stride[0] == 2 and y.shape[2] > 0 and drop_first_conv:
|
||||
y = y[:, :, 1:, :, :]
|
||||
drop_first_conv = False
|
||||
if self.residual:
|
||||
# Reshape and duplicate the input to match the output shape
|
||||
x_in = rearrange(
|
||||
@@ -676,21 +754,20 @@ class DepthToSpaceUpsample(nn.Module):
|
||||
)
|
||||
num_repeat = math.prod(self.stride) // self.out_channels_reduction_factor
|
||||
x_in = x_in.repeat(1, num_repeat, 1, 1, 1)
|
||||
if self.stride[0] == 2:
|
||||
if self.stride[0] == 2 and x_in.shape[2] > 0 and drop_first_res:
|
||||
x_in = x_in[:, :, 1:, :, :]
|
||||
x = self.conv(x, causal=causal)
|
||||
x = rearrange(
|
||||
x,
|
||||
"b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)",
|
||||
p1=self.stride[0],
|
||||
p2=self.stride[1],
|
||||
p3=self.stride[2],
|
||||
)
|
||||
if self.stride[0] == 2:
|
||||
x = x[:, :, 1:, :, :]
|
||||
if self.residual:
|
||||
x = x + x_in
|
||||
return x
|
||||
drop_first_res = False
|
||||
|
||||
if y.shape[2] == 0:
|
||||
y = None
|
||||
|
||||
cached = add_exchange_cache(y, cached, x_in, dim=2)
|
||||
self.temporal_cache_state[tid] = (cached, drop_first_conv, drop_first_res)
|
||||
|
||||
else:
|
||||
self.temporal_cache_state[tid] = (None, drop_first_conv, False)
|
||||
|
||||
return y
|
||||
|
||||
class LayerNorm(nn.Module):
|
||||
def __init__(self, dim, eps, elementwise_affine=True) -> None:
|
||||
@@ -807,6 +884,8 @@ class ResnetBlock3D(nn.Module):
|
||||
torch.randn(4, in_channels) / in_channels**0.5
|
||||
)
|
||||
|
||||
self.temporal_cache_state={}
|
||||
|
||||
def _feed_spatial_noise(
|
||||
self, hidden_states: torch.FloatTensor, per_channel_scale: torch.FloatTensor
|
||||
) -> torch.FloatTensor:
|
||||
@@ -880,9 +959,12 @@ class ResnetBlock3D(nn.Module):
|
||||
|
||||
input_tensor = self.conv_shortcut(input_tensor)
|
||||
|
||||
output_tensor = input_tensor + hidden_states
|
||||
tid = threading.get_ident()
|
||||
cached = self.temporal_cache_state.get(tid, None)
|
||||
cached = add_exchange_cache(hidden_states, cached, input_tensor, dim=2)
|
||||
self.temporal_cache_state[tid] = cached
|
||||
|
||||
return output_tensor
|
||||
return hidden_states
|
||||
|
||||
|
||||
def patchify(x, patch_size_hw, patch_size_t=1):
|
||||
|
||||
@@ -451,6 +451,7 @@ class NextDiT(nn.Module):
|
||||
device=None,
|
||||
dtype=None,
|
||||
operations=None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
|
||||
@@ -524,6 +524,9 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
|
||||
|
||||
@wrap_attn
|
||||
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
||||
if kwargs.get("low_precision_attention", True) is False:
|
||||
return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=skip_reshape, skip_output_reshape=skip_output_reshape, **kwargs)
|
||||
|
||||
exception_fallback = False
|
||||
if skip_reshape:
|
||||
b, _, _, dim_head = q.shape
|
||||
|
||||
@@ -14,10 +14,13 @@ if model_management.xformers_enabled_vae():
|
||||
import xformers.ops
|
||||
|
||||
def torch_cat_if_needed(xl, dim):
|
||||
xl = [x for x in xl if x is not None and x.shape[dim] > 0]
|
||||
if len(xl) > 1:
|
||||
return torch.cat(xl, dim)
|
||||
else:
|
||||
elif len(xl) == 1:
|
||||
return xl[0]
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_timestep_embedding(timesteps, embedding_dim):
|
||||
"""
|
||||
@@ -99,19 +102,7 @@ class VideoConv3d(nn.Module):
|
||||
return self.conv(x)
|
||||
|
||||
def interpolate_up(x, scale_factor):
|
||||
try:
|
||||
return torch.nn.functional.interpolate(x, scale_factor=scale_factor, mode="nearest")
|
||||
except: #operation not implemented for bf16
|
||||
orig_shape = list(x.shape)
|
||||
out_shape = orig_shape[:2]
|
||||
for i in range(len(orig_shape) - 2):
|
||||
out_shape.append(round(orig_shape[i + 2] * scale_factor[i]))
|
||||
out = torch.empty(out_shape, dtype=x.dtype, layout=x.layout, device=x.device)
|
||||
split = 8
|
||||
l = out.shape[1] // split
|
||||
for i in range(0, out.shape[1], l):
|
||||
out[:,i:i+l] = torch.nn.functional.interpolate(x[:,i:i+l].to(torch.float32), scale_factor=scale_factor, mode="nearest").to(x.dtype)
|
||||
return out
|
||||
return torch.nn.functional.interpolate(x, scale_factor=scale_factor, mode="nearest")
|
||||
|
||||
class Upsample(nn.Module):
|
||||
def __init__(self, in_channels, with_conv, conv_op=ops.Conv2d, scale_factor=2.0):
|
||||
|
||||
@@ -18,6 +18,8 @@ import comfy.patcher_extension
|
||||
import comfy.ops
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
from ..sdpose import HeatmapHead
|
||||
|
||||
class TimestepBlock(nn.Module):
|
||||
"""
|
||||
Any module where forward() takes timestep embeddings as a second argument.
|
||||
@@ -441,6 +443,7 @@ class UNetModel(nn.Module):
|
||||
disable_temporal_crossattention=False,
|
||||
max_ddpm_temb_period=10000,
|
||||
attn_precision=None,
|
||||
heatmap_head=False,
|
||||
device=None,
|
||||
operations=ops,
|
||||
):
|
||||
@@ -827,6 +830,9 @@ class UNetModel(nn.Module):
|
||||
#nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
|
||||
)
|
||||
|
||||
if heatmap_head:
|
||||
self.heatmap_head = HeatmapHead(device=device, dtype=self.dtype, operations=operations)
|
||||
|
||||
def forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
|
||||
130
comfy/ldm/modules/sdpose.py
Normal file
130
comfy/ldm/modules/sdpose.py
Normal file
@@ -0,0 +1,130 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from scipy.ndimage import gaussian_filter
|
||||
|
||||
class HeatmapHead(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=640,
|
||||
out_channels=133,
|
||||
input_size=(768, 1024),
|
||||
heatmap_scale=4,
|
||||
deconv_out_channels=(640,),
|
||||
deconv_kernel_sizes=(4,),
|
||||
conv_out_channels=(640,),
|
||||
conv_kernel_sizes=(1,),
|
||||
final_layer_kernel_size=1,
|
||||
device=None, dtype=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.heatmap_size = (input_size[0] // heatmap_scale, input_size[1] // heatmap_scale)
|
||||
self.scale_factor = ((np.array(input_size) - 1) / (np.array(self.heatmap_size) - 1)).astype(np.float32)
|
||||
|
||||
# Deconv layers
|
||||
if deconv_out_channels:
|
||||
deconv_layers = []
|
||||
for out_ch, kernel_size in zip(deconv_out_channels, deconv_kernel_sizes):
|
||||
if kernel_size == 4:
|
||||
padding, output_padding = 1, 0
|
||||
elif kernel_size == 3:
|
||||
padding, output_padding = 1, 1
|
||||
elif kernel_size == 2:
|
||||
padding, output_padding = 0, 0
|
||||
else:
|
||||
raise ValueError(f'Unsupported kernel size {kernel_size}')
|
||||
|
||||
deconv_layers.extend([
|
||||
operations.ConvTranspose2d(in_channels, out_ch, kernel_size,
|
||||
stride=2, padding=padding, output_padding=output_padding, bias=False, device=device, dtype=dtype),
|
||||
torch.nn.InstanceNorm2d(out_ch, device=device, dtype=dtype),
|
||||
torch.nn.SiLU(inplace=True)
|
||||
])
|
||||
in_channels = out_ch
|
||||
self.deconv_layers = torch.nn.Sequential(*deconv_layers)
|
||||
else:
|
||||
self.deconv_layers = torch.nn.Identity()
|
||||
|
||||
# Conv layers
|
||||
if conv_out_channels:
|
||||
conv_layers = []
|
||||
for out_ch, kernel_size in zip(conv_out_channels, conv_kernel_sizes):
|
||||
padding = (kernel_size - 1) // 2
|
||||
conv_layers.extend([
|
||||
operations.Conv2d(in_channels, out_ch, kernel_size,
|
||||
stride=1, padding=padding, device=device, dtype=dtype),
|
||||
torch.nn.InstanceNorm2d(out_ch, device=device, dtype=dtype),
|
||||
torch.nn.SiLU(inplace=True)
|
||||
])
|
||||
in_channels = out_ch
|
||||
self.conv_layers = torch.nn.Sequential(*conv_layers)
|
||||
else:
|
||||
self.conv_layers = torch.nn.Identity()
|
||||
|
||||
self.final_layer = operations.Conv2d(in_channels, out_channels, kernel_size=final_layer_kernel_size, padding=final_layer_kernel_size // 2, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x): # Decode heatmaps to keypoints
|
||||
heatmaps = self.final_layer(self.conv_layers(self.deconv_layers(x)))
|
||||
heatmaps_np = heatmaps.float().cpu().numpy() # (B, K, H, W)
|
||||
B, K, H, W = heatmaps_np.shape
|
||||
|
||||
batch_keypoints = []
|
||||
batch_scores = []
|
||||
|
||||
for b in range(B):
|
||||
hm = heatmaps_np[b].copy() # (K, H, W)
|
||||
|
||||
# --- vectorised argmax ---
|
||||
flat = hm.reshape(K, -1)
|
||||
idx = np.argmax(flat, axis=1)
|
||||
scores = flat[np.arange(K), idx].copy()
|
||||
y_locs, x_locs = np.unravel_index(idx, (H, W))
|
||||
keypoints = np.stack([x_locs, y_locs], axis=-1).astype(np.float32) # (K, 2) in heatmap space
|
||||
invalid = scores <= 0.
|
||||
keypoints[invalid] = -1
|
||||
|
||||
# --- DARK sub-pixel refinement (UDP) ---
|
||||
# 1. Gaussian blur with max-preserving normalisation
|
||||
border = 5 # (kernel-1)//2 for kernel=11
|
||||
for k in range(K):
|
||||
origin_max = np.max(hm[k])
|
||||
dr = np.zeros((H + 2 * border, W + 2 * border), dtype=np.float32)
|
||||
dr[border:-border, border:-border] = hm[k].copy()
|
||||
dr = gaussian_filter(dr, sigma=2.0)
|
||||
hm[k] = dr[border:-border, border:-border].copy()
|
||||
cur_max = np.max(hm[k])
|
||||
if cur_max > 0:
|
||||
hm[k] *= origin_max / cur_max
|
||||
# 2. Log-space for Taylor expansion
|
||||
np.clip(hm, 1e-3, 50., hm)
|
||||
np.log(hm, hm)
|
||||
# 3. Hessian-based Newton step
|
||||
hm_pad = np.pad(hm, ((0, 0), (1, 1), (1, 1)), mode='edge').flatten()
|
||||
index = keypoints[:, 0] + 1 + (keypoints[:, 1] + 1) * (W + 2)
|
||||
index += (W + 2) * (H + 2) * np.arange(0, K)
|
||||
index = index.astype(int).reshape(-1, 1)
|
||||
i_ = hm_pad[index]
|
||||
ix1 = hm_pad[index + 1]
|
||||
iy1 = hm_pad[index + W + 2]
|
||||
ix1y1 = hm_pad[index + W + 3]
|
||||
ix1_y1_ = hm_pad[index - W - 3]
|
||||
ix1_ = hm_pad[index - 1]
|
||||
iy1_ = hm_pad[index - 2 - W]
|
||||
dx = 0.5 * (ix1 - ix1_)
|
||||
dy = 0.5 * (iy1 - iy1_)
|
||||
derivative = np.concatenate([dx, dy], axis=1).reshape(K, 2, 1)
|
||||
dxx = ix1 - 2 * i_ + ix1_
|
||||
dyy = iy1 - 2 * i_ + iy1_
|
||||
dxy = 0.5 * (ix1y1 - ix1 - iy1 + i_ + i_ - ix1_ - iy1_ + ix1_y1_)
|
||||
hessian = np.concatenate([dxx, dxy, dxy, dyy], axis=1).reshape(K, 2, 2)
|
||||
hessian = np.linalg.inv(hessian + np.finfo(np.float32).eps * np.eye(2))
|
||||
keypoints -= np.einsum('imn,ink->imk', hessian, derivative).squeeze(axis=-1)
|
||||
|
||||
# --- restore to input image space ---
|
||||
keypoints = keypoints * self.scale_factor
|
||||
keypoints[invalid] = -1
|
||||
|
||||
batch_keypoints.append(keypoints)
|
||||
batch_scores.append(scores)
|
||||
|
||||
return batch_keypoints, batch_scores
|
||||
@@ -2,6 +2,196 @@ import torch
|
||||
import math
|
||||
|
||||
from .model import QwenImageTransformer2DModel
|
||||
from .model import QwenImageTransformerBlock
|
||||
|
||||
|
||||
class QwenImageFunControlBlock(QwenImageTransformerBlock):
|
||||
def __init__(self, dim, num_attention_heads, attention_head_dim, has_before_proj=False, dtype=None, device=None, operations=None):
|
||||
super().__init__(
|
||||
dim=dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
self.has_before_proj = has_before_proj
|
||||
if has_before_proj:
|
||||
self.before_proj = operations.Linear(dim, dim, device=device, dtype=dtype)
|
||||
self.after_proj = operations.Linear(dim, dim, device=device, dtype=dtype)
|
||||
|
||||
|
||||
class QwenImageFunControlNetModel(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
control_in_features=132,
|
||||
inner_dim=3072,
|
||||
num_attention_heads=24,
|
||||
attention_head_dim=128,
|
||||
num_control_blocks=5,
|
||||
main_model_double=60,
|
||||
injection_layers=(0, 12, 24, 36, 48),
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
self.main_model_double = main_model_double
|
||||
self.injection_layers = tuple(injection_layers)
|
||||
# Keep base hint scaling at 1.0 so user-facing strength behaves similarly
|
||||
# to the reference Gen2/VideoX implementation around strength=1.
|
||||
self.hint_scale = 1.0
|
||||
self.control_img_in = operations.Linear(control_in_features, inner_dim, device=device, dtype=dtype)
|
||||
|
||||
self.control_blocks = torch.nn.ModuleList([])
|
||||
for i in range(num_control_blocks):
|
||||
self.control_blocks.append(
|
||||
QwenImageFunControlBlock(
|
||||
dim=inner_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
has_before_proj=(i == 0),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
)
|
||||
|
||||
def _process_hint_tokens(self, hint):
|
||||
if hint is None:
|
||||
return None
|
||||
if hint.ndim == 4:
|
||||
hint = hint.unsqueeze(2)
|
||||
|
||||
# Fun checkpoints are trained with 33 latent channels before 2x2 packing:
|
||||
# [control_latent(16), mask(1), inpaint_latent(16)] -> 132 features.
|
||||
# Default behavior (no inpaint input in stock Apply ControlNet) should use
|
||||
# zeros for mask/inpaint branches, matching VideoX fallback semantics.
|
||||
expected_c = self.control_img_in.weight.shape[1] // 4
|
||||
if hint.shape[1] == 16 and expected_c == 33:
|
||||
zeros_mask = torch.zeros_like(hint[:, :1])
|
||||
zeros_inpaint = torch.zeros_like(hint)
|
||||
hint = torch.cat([hint, zeros_mask, zeros_inpaint], dim=1)
|
||||
|
||||
bs, c, t, h, w = hint.shape
|
||||
hidden_states = torch.nn.functional.pad(hint, (0, w % 2, 0, h % 2))
|
||||
orig_shape = hidden_states.shape
|
||||
hidden_states = hidden_states.view(
|
||||
orig_shape[0],
|
||||
orig_shape[1],
|
||||
orig_shape[-3],
|
||||
orig_shape[-2] // 2,
|
||||
2,
|
||||
orig_shape[-1] // 2,
|
||||
2,
|
||||
)
|
||||
hidden_states = hidden_states.permute(0, 2, 3, 5, 1, 4, 6)
|
||||
hidden_states = hidden_states.reshape(
|
||||
bs,
|
||||
t * ((h + 1) // 2) * ((w + 1) // 2),
|
||||
c * 4,
|
||||
)
|
||||
|
||||
expected_in = self.control_img_in.weight.shape[1]
|
||||
cur_in = hidden_states.shape[-1]
|
||||
if cur_in < expected_in:
|
||||
pad = torch.zeros(
|
||||
(hidden_states.shape[0], hidden_states.shape[1], expected_in - cur_in),
|
||||
device=hidden_states.device,
|
||||
dtype=hidden_states.dtype,
|
||||
)
|
||||
hidden_states = torch.cat([hidden_states, pad], dim=-1)
|
||||
elif cur_in > expected_in:
|
||||
hidden_states = hidden_states[:, :, :expected_in]
|
||||
|
||||
return hidden_states
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
timesteps,
|
||||
context,
|
||||
attention_mask=None,
|
||||
guidance: torch.Tensor = None,
|
||||
hint=None,
|
||||
transformer_options={},
|
||||
base_model=None,
|
||||
**kwargs,
|
||||
):
|
||||
if base_model is None:
|
||||
raise RuntimeError("Qwen Fun ControlNet requires a QwenImage base model at runtime.")
|
||||
|
||||
encoder_hidden_states_mask = attention_mask
|
||||
# Keep attention mask disabled inside Fun control blocks to mirror
|
||||
# VideoX behavior (they rely on seq lengths for RoPE, not masked attention).
|
||||
encoder_hidden_states_mask = None
|
||||
|
||||
hidden_states, img_ids, _ = base_model.process_img(x)
|
||||
hint_tokens = self._process_hint_tokens(hint)
|
||||
if hint_tokens is None:
|
||||
raise RuntimeError("Qwen Fun ControlNet requires a control hint image.")
|
||||
|
||||
if hint_tokens.shape[1] != hidden_states.shape[1]:
|
||||
max_tokens = min(hint_tokens.shape[1], hidden_states.shape[1])
|
||||
hint_tokens = hint_tokens[:, :max_tokens]
|
||||
hidden_states = hidden_states[:, :max_tokens]
|
||||
img_ids = img_ids[:, :max_tokens]
|
||||
|
||||
txt_start = round(
|
||||
max(
|
||||
((x.shape[-1] + (base_model.patch_size // 2)) // base_model.patch_size) // 2,
|
||||
((x.shape[-2] + (base_model.patch_size // 2)) // base_model.patch_size) // 2,
|
||||
)
|
||||
)
|
||||
txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
|
||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||
image_rotary_emb = base_model.pe_embedder(ids).to(x.dtype).contiguous()
|
||||
|
||||
hidden_states = base_model.img_in(hidden_states)
|
||||
encoder_hidden_states = base_model.txt_norm(context)
|
||||
encoder_hidden_states = base_model.txt_in(encoder_hidden_states)
|
||||
|
||||
if guidance is not None:
|
||||
guidance = guidance * 1000
|
||||
|
||||
temb = (
|
||||
base_model.time_text_embed(timesteps, hidden_states)
|
||||
if guidance is None
|
||||
else base_model.time_text_embed(timesteps, guidance, hidden_states)
|
||||
)
|
||||
|
||||
c = self.control_img_in(hint_tokens)
|
||||
|
||||
for i, block in enumerate(self.control_blocks):
|
||||
if i == 0:
|
||||
c_in = block.before_proj(c) + hidden_states
|
||||
all_c = []
|
||||
else:
|
||||
all_c = list(torch.unbind(c, dim=0))
|
||||
c_in = all_c.pop(-1)
|
||||
|
||||
encoder_hidden_states, c_out = block(
|
||||
hidden_states=c_in,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_hidden_states_mask=encoder_hidden_states_mask,
|
||||
temb=temb,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
c_skip = block.after_proj(c_out) * self.hint_scale
|
||||
all_c += [c_skip, c_out]
|
||||
c = torch.stack(all_c, dim=0)
|
||||
|
||||
hints = torch.unbind(c, dim=0)[:-1]
|
||||
|
||||
controlnet_block_samples = [None] * self.main_model_double
|
||||
for local_idx, base_idx in enumerate(self.injection_layers):
|
||||
if local_idx < len(hints) and base_idx < len(controlnet_block_samples):
|
||||
controlnet_block_samples[base_idx] = hints[local_idx]
|
||||
|
||||
return {"input": controlnet_block_samples}
|
||||
|
||||
|
||||
class QwenImageControlNetModel(QwenImageTransformer2DModel):
|
||||
|
||||
@@ -170,8 +170,14 @@ class Attention(nn.Module):
|
||||
joint_query = apply_rope1(joint_query, image_rotary_emb)
|
||||
joint_key = apply_rope1(joint_key, image_rotary_emb)
|
||||
|
||||
if encoder_hidden_states_mask is not None:
|
||||
attn_mask = torch.zeros((batch_size, 1, seq_txt + seq_img), dtype=hidden_states.dtype, device=hidden_states.device)
|
||||
attn_mask[:, 0, :seq_txt] = encoder_hidden_states_mask
|
||||
else:
|
||||
attn_mask = None
|
||||
|
||||
joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads,
|
||||
attention_mask, transformer_options=transformer_options,
|
||||
attn_mask, transformer_options=transformer_options,
|
||||
skip_reshape=True)
|
||||
|
||||
txt_attn_output = joint_hidden_states[:, :seq_txt, :]
|
||||
@@ -430,6 +436,9 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
encoder_hidden_states = context
|
||||
encoder_hidden_states_mask = attention_mask
|
||||
|
||||
if encoder_hidden_states_mask is not None and not torch.is_floating_point(encoder_hidden_states_mask):
|
||||
encoder_hidden_states_mask = (encoder_hidden_states_mask - 1).to(x.dtype) * torch.finfo(x.dtype).max
|
||||
|
||||
hidden_states, img_ids, orig_shape = self.process_img(x)
|
||||
num_embeds = hidden_states.shape[1]
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
from comfy.ldm.modules.diffusionmodules.model import vae_attention
|
||||
from comfy.ldm.modules.diffusionmodules.model import vae_attention, torch_cat_if_needed
|
||||
|
||||
import comfy.ops
|
||||
ops = comfy.ops.disable_weight_init
|
||||
@@ -20,22 +20,29 @@ class CausalConv3d(ops.Conv3d):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._padding = (self.padding[2], self.padding[2], self.padding[1],
|
||||
self.padding[1], 2 * self.padding[0], 0)
|
||||
self.padding = (0, 0, 0)
|
||||
self._padding = 2 * self.padding[0]
|
||||
self.padding = (0, self.padding[1], self.padding[2])
|
||||
|
||||
def forward(self, x, cache_x=None, cache_list=None, cache_idx=None):
|
||||
if cache_list is not None:
|
||||
cache_x = cache_list[cache_idx]
|
||||
cache_list[cache_idx] = None
|
||||
|
||||
padding = list(self._padding)
|
||||
if cache_x is not None and self._padding[4] > 0:
|
||||
cache_x = cache_x.to(x.device)
|
||||
x = torch.cat([cache_x, x], dim=2)
|
||||
padding[4] -= cache_x.shape[2]
|
||||
if cache_x is None and x.shape[2] == 1:
|
||||
#Fast path - the op will pad for use by truncating the weight
|
||||
#and save math on a pile of zeros.
|
||||
return super().forward(x, autopad="causal_zero")
|
||||
|
||||
if self._padding > 0:
|
||||
padding_needed = self._padding
|
||||
if cache_x is not None:
|
||||
cache_x = cache_x.to(x.device)
|
||||
padding_needed = max(0, padding_needed - cache_x.shape[2])
|
||||
padding_shape = list(x.shape)
|
||||
padding_shape[2] = padding_needed
|
||||
padding = torch.zeros(padding_shape, device=x.device, dtype=x.dtype)
|
||||
x = torch_cat_if_needed([padding, cache_x, x], dim=2)
|
||||
del cache_x
|
||||
x = F.pad(x, padding)
|
||||
|
||||
return super().forward(x)
|
||||
|
||||
@@ -452,6 +459,7 @@ class WanVAE(nn.Module):
|
||||
attn_scales=[],
|
||||
temperal_downsample=[True, True, False],
|
||||
image_channels=3,
|
||||
conv_out_channels=3,
|
||||
dropout=0.0):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
@@ -467,15 +475,17 @@ class WanVAE(nn.Module):
|
||||
attn_scales, self.temperal_downsample, dropout)
|
||||
self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
|
||||
self.conv2 = CausalConv3d(z_dim, z_dim, 1)
|
||||
self.decoder = Decoder3d(dim, z_dim, image_channels, dim_mult, num_res_blocks,
|
||||
self.decoder = Decoder3d(dim, z_dim, conv_out_channels, dim_mult, num_res_blocks,
|
||||
attn_scales, self.temperal_upsample, dropout)
|
||||
|
||||
def encode(self, x):
|
||||
conv_idx = [0]
|
||||
feat_map = [None] * count_conv3d(self.decoder)
|
||||
## cache
|
||||
t = x.shape[2]
|
||||
iter_ = 1 + (t - 1) // 4
|
||||
feat_map = None
|
||||
if iter_ > 1:
|
||||
feat_map = [None] * count_conv3d(self.encoder)
|
||||
## 对encode输入的x,按时间拆分为1、4、4、4....
|
||||
for i in range(iter_):
|
||||
conv_idx = [0]
|
||||
@@ -495,10 +505,11 @@ class WanVAE(nn.Module):
|
||||
|
||||
def decode(self, z):
|
||||
conv_idx = [0]
|
||||
feat_map = [None] * count_conv3d(self.decoder)
|
||||
# z: [b,c,t,h,w]
|
||||
|
||||
iter_ = z.shape[2]
|
||||
feat_map = None
|
||||
if iter_ > 1:
|
||||
feat_map = [None] * count_conv3d(self.decoder)
|
||||
x = self.conv2(z)
|
||||
for i in range(iter_):
|
||||
conv_idx = [0]
|
||||
|
||||
@@ -260,6 +260,7 @@ def model_lora_keys_unet(model, key_map={}):
|
||||
key_map["transformer.{}".format(k[:-len(".weight")])] = to #simpletrainer and probably regular diffusers flux lora format
|
||||
key_map["lycoris_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #simpletrainer lycoris
|
||||
key_map["lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #onetrainer
|
||||
key_map[k[:-len(".weight")]] = to #DiffSynth lora format
|
||||
for k in sdk:
|
||||
hidden_size = model.model_config.unet_config.get("hidden_size", 0)
|
||||
if k.endswith(".weight") and ".linear1." in k:
|
||||
@@ -331,6 +332,13 @@ def model_lora_keys_unet(model, key_map={}):
|
||||
key_map["{}".format(key_lora)] = k
|
||||
key_map["transformer.{}".format(key_lora)] = k
|
||||
|
||||
if isinstance(model, comfy.model_base.ACEStep15):
|
||||
for k in sdk:
|
||||
if k.startswith("diffusion_model.decoder.") and k.endswith(".weight"):
|
||||
key_lora = k[len("diffusion_model.decoder."):-len(".weight")]
|
||||
key_map["base_model.model.{}".format(key_lora)] = k # Official base model loras
|
||||
key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = k # LyCORIS/LoKR format
|
||||
|
||||
return key_map
|
||||
|
||||
|
||||
@@ -367,6 +375,31 @@ def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Ten
|
||||
|
||||
return padded_tensor
|
||||
|
||||
def calculate_shape(patches, weight, key, original_weights=None):
|
||||
current_shape = weight.shape
|
||||
|
||||
for p in patches:
|
||||
v = p[1]
|
||||
offset = p[3]
|
||||
|
||||
# Offsets restore the old shape; lists force a diff without metadata
|
||||
if offset is not None or isinstance(v, list):
|
||||
continue
|
||||
|
||||
if isinstance(v, weight_adapter.WeightAdapterBase):
|
||||
adapter_shape = v.calculate_shape(key)
|
||||
if adapter_shape is not None:
|
||||
current_shape = adapter_shape
|
||||
continue
|
||||
|
||||
# Standard diff logic with padding
|
||||
if len(v) == 2:
|
||||
patch_type, patch_data = v[0], v[1]
|
||||
if patch_type == "diff" and len(patch_data) > 1 and patch_data[1]['pad_weight']:
|
||||
current_shape = patch_data[0].shape
|
||||
|
||||
return current_shape
|
||||
|
||||
def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, original_weights=None):
|
||||
for p in patches:
|
||||
strength = p[0]
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user