mirror of
https://github.com/Comfy-Org/ComfyUI.git
synced 2026-03-10 02:19:31 +00:00
Compare commits
117 Commits
add-codera
...
master
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
814dab9f46 | ||
|
|
06f85e2c79 | ||
|
|
e4b0bb8305 | ||
|
|
7723f20bbe | ||
|
|
29b24cb517 | ||
|
|
a7a6335be5 | ||
|
|
bcf1a1fab1 | ||
|
|
6ac8152fc8 | ||
|
|
afc00f0055 | ||
|
|
d69d30819b | ||
|
|
f466b06601 | ||
|
|
34e55f0061 | ||
|
|
3b93d5d571 | ||
|
|
e544c65db9 | ||
|
|
1c21828236 | ||
|
|
58017e8726 | ||
|
|
17b43c2b87 | ||
|
|
8befce5c7b | ||
|
|
50549aa252 | ||
|
|
1c3b651c0a | ||
|
|
5073da57ad | ||
|
|
42e0e023ee | ||
|
|
6481569ad4 | ||
|
|
6ef82a89b8 | ||
|
|
da29b797ce | ||
|
|
9cdfd7403b | ||
|
|
bd21363563 | ||
|
|
e04d0dbeb8 | ||
|
|
c8428541a6 | ||
|
|
4941671b5a | ||
|
|
c5fe8ace68 | ||
|
|
f2ee7f2d36 | ||
|
|
43c64b6308 | ||
|
|
ac4a943ff3 | ||
|
|
8811db52db | ||
|
|
0a7446ade4 | ||
|
|
9b85cf9558 | ||
|
|
d531e3fb2a | ||
|
|
eb011733b6 | ||
|
|
ac6513e142 | ||
|
|
b6ddc590ed | ||
|
|
f719a9d928 | ||
|
|
174fd6759d | ||
|
|
09bcbddfcf | ||
|
|
dff0a4a158 | ||
|
|
9ebee0a217 | ||
|
|
57dd6c1aad | ||
|
|
f1f8996e15 | ||
|
|
afb54219fa | ||
|
|
7175c11a4e | ||
|
|
dfbf99a061 | ||
|
|
602f6bd82c | ||
|
|
c0d472e5b9 | ||
|
|
4d79f4f028 | ||
|
|
850e8b42ff | ||
|
|
d159142615 | ||
|
|
1080bd442a | ||
|
|
17106cb124 | ||
|
|
48bb0bd18a | ||
|
|
5f41584e96 | ||
|
|
1f6744162f | ||
|
|
95e1059661 | ||
|
|
80d49441e5 | ||
|
|
9d0e114ee3 | ||
|
|
ac4412d0fa | ||
|
|
94f1a1cc9d | ||
|
|
e721e24136 | ||
|
|
25ec3d96a3 | ||
|
|
1f1ec377ce | ||
|
|
0a7f8e11b6 | ||
|
|
35e9fce775 | ||
|
|
c7f7d52b68 | ||
|
|
08b26ed7c2 | ||
|
|
b233dbe0bc | ||
|
|
3811780e4f | ||
|
|
3dd10a59c0 | ||
|
|
88d05fe483 | ||
|
|
fd41ec97cc | ||
|
|
420e900f69 | ||
|
|
38ca94599f | ||
|
|
74b5a337dc | ||
|
|
8a4d85c708 | ||
|
|
a4522017c5 | ||
|
|
907e5dcbbf | ||
|
|
7253531670 | ||
|
|
e14b04478c | ||
|
|
eb8737d675 | ||
|
|
0467f690a8 | ||
|
|
4f5b7dbf1f | ||
|
|
3ebe1ac22e | ||
|
|
befa83d434 | ||
|
|
33f83d53ae | ||
|
|
b874bd2b8c | ||
|
|
0aa02453bb | ||
|
|
599f9c5010 | ||
|
|
11fefa58e9 | ||
|
|
d8090013b8 | ||
|
|
048dd2f321 | ||
|
|
84aba95e03 | ||
|
|
9b1c63eb69 | ||
|
|
7a7debcaf1 | ||
|
|
dba2766e53 | ||
|
|
caa43d2395 | ||
|
|
07ca6852e8 | ||
|
|
f266b8d352 | ||
|
|
b6cb30bab5 | ||
|
|
ee72752162 | ||
|
|
7591d781a7 | ||
|
|
0bfb936ab4 | ||
|
|
602b2505a4 | ||
|
|
04a55d5019 | ||
|
|
5fb8f06495 | ||
|
|
5a182bfaf1 | ||
|
|
f394af8d0f | ||
|
|
aeb5bdc8f6 | ||
|
|
64953bda0a | ||
|
|
b254cecd03 |
@@ -1,6 +1,7 @@
|
||||
# yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json
|
||||
language: "en-US"
|
||||
early_access: false
|
||||
tone_instructions: "Only comment on issues introduced by this PR's changes. Do not flag pre-existing problems in moved, re-indented, or reformatted code."
|
||||
|
||||
reviews:
|
||||
profile: "chill"
|
||||
@@ -35,6 +36,14 @@ reviews:
|
||||
- "!**/*.bat"
|
||||
|
||||
path_instructions:
|
||||
- path: "**"
|
||||
instructions: |
|
||||
IMPORTANT: Only comment on issues directly introduced by this PR's code changes.
|
||||
Do NOT flag pre-existing issues in code that was merely moved, re-indented,
|
||||
de-indented, or reformatted without logic changes. If code appears in the diff
|
||||
only due to whitespace or structural reformatting (e.g., removing a `with:` block),
|
||||
treat it as unchanged. Contributors should not feel obligated to address
|
||||
pre-existing issues outside the scope of their contribution.
|
||||
- path: "comfy/**"
|
||||
instructions: |
|
||||
Core ML/diffusion engine. Focus on:
|
||||
@@ -74,7 +83,11 @@ reviews:
|
||||
auto_review:
|
||||
enabled: true
|
||||
auto_incremental_review: true
|
||||
drafts: true
|
||||
drafts: false
|
||||
ignore_title_keywords:
|
||||
- "WIP"
|
||||
- "DO NOT REVIEW"
|
||||
- "DO NOT MERGE"
|
||||
|
||||
finishing_touches:
|
||||
docstrings:
|
||||
@@ -84,7 +97,7 @@ reviews:
|
||||
|
||||
tools:
|
||||
ruff:
|
||||
enabled: true
|
||||
enabled: false
|
||||
pylint:
|
||||
enabled: false
|
||||
flake8:
|
||||
|
||||
2
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
2
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
@@ -16,7 +16,7 @@ body:
|
||||
|
||||
## Very Important
|
||||
|
||||
Please make sure that you post ALL your ComfyUI logs in the bug report. A bug report without logs will likely be ignored.
|
||||
Please make sure that you post ALL your ComfyUI logs in the bug report **even if there is no crash**. Just paste everything. The startup log (everything before "To see the GUI go to: ...") contains critical information to developers trying to help. For a performance issue or crash, paste everything from "got prompt" to the end, including the crash. More is better - always. A bug report without logs will likely be ignored.
|
||||
- type: checkboxes
|
||||
id: custom-nodes-test
|
||||
attributes:
|
||||
|
||||
@@ -189,8 +189,6 @@ The portable above currently comes with python 3.13 and pytorch cuda 13.0. Updat
|
||||
|
||||
[Experimental portable for AMD GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_amd.7z)
|
||||
|
||||
[Portable with pytorch cuda 12.8 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu128.7z).
|
||||
|
||||
[Portable with pytorch cuda 12.6 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu126.7z) (Supports Nvidia 10 series and older GPUs).
|
||||
|
||||
#### How do I share models between another UI and ComfyUI?
|
||||
@@ -229,9 +227,9 @@ AMD users can install rocm and pytorch with pip if you don't have it already ins
|
||||
|
||||
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm7.1```
|
||||
|
||||
This is the command to install the nightly with ROCm 7.1 which might have some performance improvements:
|
||||
This is the command to install the nightly with ROCm 7.2 which might have some performance improvements:
|
||||
|
||||
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm7.1```
|
||||
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm7.2```
|
||||
|
||||
|
||||
### AMD GPUs (Experimental: Windows and Linux), RDNA 3, 3.5 and 4 only.
|
||||
|
||||
267
alembic_db/versions/0002_merge_to_asset_references.py
Normal file
267
alembic_db/versions/0002_merge_to_asset_references.py
Normal file
@@ -0,0 +1,267 @@
|
||||
"""
|
||||
Merge AssetInfo and AssetCacheState into unified asset_references table.
|
||||
|
||||
This migration drops old tables and creates the new unified schema.
|
||||
All existing data is discarded.
|
||||
|
||||
Revision ID: 0002_merge_to_asset_references
|
||||
Revises: 0001_assets
|
||||
Create Date: 2025-02-11
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision = "0002_merge_to_asset_references"
|
||||
down_revision = "0001_assets"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Drop old tables (order matters due to FK constraints)
|
||||
op.drop_index("ix_asset_info_meta_key_val_bool", table_name="asset_info_meta")
|
||||
op.drop_index("ix_asset_info_meta_key_val_num", table_name="asset_info_meta")
|
||||
op.drop_index("ix_asset_info_meta_key_val_str", table_name="asset_info_meta")
|
||||
op.drop_index("ix_asset_info_meta_key", table_name="asset_info_meta")
|
||||
op.drop_table("asset_info_meta")
|
||||
|
||||
op.drop_index("ix_asset_info_tags_asset_info_id", table_name="asset_info_tags")
|
||||
op.drop_index("ix_asset_info_tags_tag_name", table_name="asset_info_tags")
|
||||
op.drop_table("asset_info_tags")
|
||||
|
||||
op.drop_index("ix_asset_cache_state_asset_id", table_name="asset_cache_state")
|
||||
op.drop_index("ix_asset_cache_state_file_path", table_name="asset_cache_state")
|
||||
op.drop_table("asset_cache_state")
|
||||
|
||||
op.drop_index("ix_assets_info_owner_name", table_name="assets_info")
|
||||
op.drop_index("ix_assets_info_last_access_time", table_name="assets_info")
|
||||
op.drop_index("ix_assets_info_created_at", table_name="assets_info")
|
||||
op.drop_index("ix_assets_info_name", table_name="assets_info")
|
||||
op.drop_index("ix_assets_info_asset_id", table_name="assets_info")
|
||||
op.drop_index("ix_assets_info_owner_id", table_name="assets_info")
|
||||
op.drop_table("assets_info")
|
||||
|
||||
# Truncate assets table (cascades handled by dropping dependent tables first)
|
||||
op.execute("DELETE FROM assets")
|
||||
|
||||
# Create asset_references table
|
||||
op.create_table(
|
||||
"asset_references",
|
||||
sa.Column("id", sa.String(length=36), primary_key=True),
|
||||
sa.Column(
|
||||
"asset_id",
|
||||
sa.String(length=36),
|
||||
sa.ForeignKey("assets.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("file_path", sa.Text(), nullable=True),
|
||||
sa.Column("mtime_ns", sa.BigInteger(), nullable=True),
|
||||
sa.Column(
|
||||
"needs_verify",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default=sa.text("false"),
|
||||
),
|
||||
sa.Column(
|
||||
"is_missing", sa.Boolean(), nullable=False, server_default=sa.text("false")
|
||||
),
|
||||
sa.Column("enrichment_level", sa.Integer(), nullable=False, server_default="0"),
|
||||
sa.Column("owner_id", sa.String(length=128), nullable=False, server_default=""),
|
||||
sa.Column("name", sa.String(length=512), nullable=False),
|
||||
sa.Column(
|
||||
"preview_id",
|
||||
sa.String(length=36),
|
||||
sa.ForeignKey("assets.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column("user_metadata", sa.JSON(), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=False), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=False), nullable=False),
|
||||
sa.Column("last_access_time", sa.DateTime(timezone=False), nullable=False),
|
||||
sa.Column("deleted_at", sa.DateTime(timezone=False), nullable=True),
|
||||
sa.CheckConstraint(
|
||||
"(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_ar_mtime_nonneg"
|
||||
),
|
||||
sa.CheckConstraint(
|
||||
"enrichment_level >= 0 AND enrichment_level <= 2",
|
||||
name="ck_ar_enrichment_level_range",
|
||||
),
|
||||
)
|
||||
op.create_index(
|
||||
"uq_asset_references_file_path", "asset_references", ["file_path"], unique=True
|
||||
)
|
||||
op.create_index("ix_asset_references_asset_id", "asset_references", ["asset_id"])
|
||||
op.create_index("ix_asset_references_owner_id", "asset_references", ["owner_id"])
|
||||
op.create_index("ix_asset_references_name", "asset_references", ["name"])
|
||||
op.create_index("ix_asset_references_is_missing", "asset_references", ["is_missing"])
|
||||
op.create_index(
|
||||
"ix_asset_references_enrichment_level", "asset_references", ["enrichment_level"]
|
||||
)
|
||||
op.create_index("ix_asset_references_created_at", "asset_references", ["created_at"])
|
||||
op.create_index(
|
||||
"ix_asset_references_last_access_time", "asset_references", ["last_access_time"]
|
||||
)
|
||||
op.create_index(
|
||||
"ix_asset_references_owner_name", "asset_references", ["owner_id", "name"]
|
||||
)
|
||||
op.create_index("ix_asset_references_deleted_at", "asset_references", ["deleted_at"])
|
||||
|
||||
# Create asset_reference_tags table
|
||||
op.create_table(
|
||||
"asset_reference_tags",
|
||||
sa.Column(
|
||||
"asset_reference_id",
|
||||
sa.String(length=36),
|
||||
sa.ForeignKey("asset_references.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"tag_name",
|
||||
sa.String(length=512),
|
||||
sa.ForeignKey("tags.name", ondelete="RESTRICT"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"origin", sa.String(length=32), nullable=False, server_default="manual"
|
||||
),
|
||||
sa.Column("added_at", sa.DateTime(timezone=False), nullable=False),
|
||||
sa.PrimaryKeyConstraint(
|
||||
"asset_reference_id", "tag_name", name="pk_asset_reference_tags"
|
||||
),
|
||||
)
|
||||
op.create_index(
|
||||
"ix_asset_reference_tags_tag_name", "asset_reference_tags", ["tag_name"]
|
||||
)
|
||||
op.create_index(
|
||||
"ix_asset_reference_tags_asset_reference_id",
|
||||
"asset_reference_tags",
|
||||
["asset_reference_id"],
|
||||
)
|
||||
|
||||
# Create asset_reference_meta table
|
||||
op.create_table(
|
||||
"asset_reference_meta",
|
||||
sa.Column(
|
||||
"asset_reference_id",
|
||||
sa.String(length=36),
|
||||
sa.ForeignKey("asset_references.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("key", sa.String(length=256), nullable=False),
|
||||
sa.Column("ordinal", sa.Integer(), nullable=False, server_default="0"),
|
||||
sa.Column("val_str", sa.String(length=2048), nullable=True),
|
||||
sa.Column("val_num", sa.Numeric(38, 10), nullable=True),
|
||||
sa.Column("val_bool", sa.Boolean(), nullable=True),
|
||||
sa.Column("val_json", sa.JSON(), nullable=True),
|
||||
sa.PrimaryKeyConstraint(
|
||||
"asset_reference_id", "key", "ordinal", name="pk_asset_reference_meta"
|
||||
),
|
||||
)
|
||||
op.create_index("ix_asset_reference_meta_key", "asset_reference_meta", ["key"])
|
||||
op.create_index(
|
||||
"ix_asset_reference_meta_key_val_str", "asset_reference_meta", ["key", "val_str"]
|
||||
)
|
||||
op.create_index(
|
||||
"ix_asset_reference_meta_key_val_num", "asset_reference_meta", ["key", "val_num"]
|
||||
)
|
||||
op.create_index(
|
||||
"ix_asset_reference_meta_key_val_bool",
|
||||
"asset_reference_meta",
|
||||
["key", "val_bool"],
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Reverse 0002_merge_to_asset_references: drop new tables, recreate old schema.
|
||||
|
||||
NOTE: Data is not recoverable. The upgrade discards all rows from the old
|
||||
tables and truncates assets. After downgrade the old schema will be empty.
|
||||
A filesystem rescan will repopulate data once the older code is running.
|
||||
"""
|
||||
# Drop new tables (order matters due to FK constraints)
|
||||
op.drop_index("ix_asset_reference_meta_key_val_bool", table_name="asset_reference_meta")
|
||||
op.drop_index("ix_asset_reference_meta_key_val_num", table_name="asset_reference_meta")
|
||||
op.drop_index("ix_asset_reference_meta_key_val_str", table_name="asset_reference_meta")
|
||||
op.drop_index("ix_asset_reference_meta_key", table_name="asset_reference_meta")
|
||||
op.drop_table("asset_reference_meta")
|
||||
|
||||
op.drop_index("ix_asset_reference_tags_asset_reference_id", table_name="asset_reference_tags")
|
||||
op.drop_index("ix_asset_reference_tags_tag_name", table_name="asset_reference_tags")
|
||||
op.drop_table("asset_reference_tags")
|
||||
|
||||
op.drop_index("ix_asset_references_deleted_at", table_name="asset_references")
|
||||
op.drop_index("ix_asset_references_owner_name", table_name="asset_references")
|
||||
op.drop_index("ix_asset_references_last_access_time", table_name="asset_references")
|
||||
op.drop_index("ix_asset_references_created_at", table_name="asset_references")
|
||||
op.drop_index("ix_asset_references_enrichment_level", table_name="asset_references")
|
||||
op.drop_index("ix_asset_references_is_missing", table_name="asset_references")
|
||||
op.drop_index("ix_asset_references_name", table_name="asset_references")
|
||||
op.drop_index("ix_asset_references_owner_id", table_name="asset_references")
|
||||
op.drop_index("ix_asset_references_asset_id", table_name="asset_references")
|
||||
op.drop_index("uq_asset_references_file_path", table_name="asset_references")
|
||||
op.drop_table("asset_references")
|
||||
|
||||
# Truncate assets (upgrade deleted all rows; downgrade starts fresh too)
|
||||
op.execute("DELETE FROM assets")
|
||||
|
||||
# Recreate old tables from 0001_assets schema
|
||||
op.create_table(
|
||||
"assets_info",
|
||||
sa.Column("id", sa.String(length=36), primary_key=True),
|
||||
sa.Column("owner_id", sa.String(length=128), nullable=False, server_default=""),
|
||||
sa.Column("name", sa.String(length=512), nullable=False),
|
||||
sa.Column("asset_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="RESTRICT"), nullable=False),
|
||||
sa.Column("preview_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="SET NULL"), nullable=True),
|
||||
sa.Column("user_metadata", sa.JSON(), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=False), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=False), nullable=False),
|
||||
sa.Column("last_access_time", sa.DateTime(timezone=False), nullable=False),
|
||||
sa.UniqueConstraint("asset_id", "owner_id", "name", name="uq_assets_info_asset_owner_name"),
|
||||
)
|
||||
op.create_index("ix_assets_info_owner_id", "assets_info", ["owner_id"])
|
||||
op.create_index("ix_assets_info_asset_id", "assets_info", ["asset_id"])
|
||||
op.create_index("ix_assets_info_name", "assets_info", ["name"])
|
||||
op.create_index("ix_assets_info_created_at", "assets_info", ["created_at"])
|
||||
op.create_index("ix_assets_info_last_access_time", "assets_info", ["last_access_time"])
|
||||
op.create_index("ix_assets_info_owner_name", "assets_info", ["owner_id", "name"])
|
||||
|
||||
op.create_table(
|
||||
"asset_cache_state",
|
||||
sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
|
||||
sa.Column("asset_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="CASCADE"), nullable=False),
|
||||
sa.Column("file_path", sa.Text(), nullable=False),
|
||||
sa.Column("mtime_ns", sa.BigInteger(), nullable=True),
|
||||
sa.Column("needs_verify", sa.Boolean(), nullable=False, server_default=sa.text("false")),
|
||||
sa.CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"),
|
||||
sa.UniqueConstraint("file_path", name="uq_asset_cache_state_file_path"),
|
||||
)
|
||||
op.create_index("ix_asset_cache_state_file_path", "asset_cache_state", ["file_path"])
|
||||
op.create_index("ix_asset_cache_state_asset_id", "asset_cache_state", ["asset_id"])
|
||||
|
||||
op.create_table(
|
||||
"asset_info_tags",
|
||||
sa.Column("asset_info_id", sa.String(length=36), sa.ForeignKey("assets_info.id", ondelete="CASCADE"), nullable=False),
|
||||
sa.Column("tag_name", sa.String(length=512), sa.ForeignKey("tags.name", ondelete="RESTRICT"), nullable=False),
|
||||
sa.Column("origin", sa.String(length=32), nullable=False, server_default="manual"),
|
||||
sa.Column("added_at", sa.DateTime(timezone=False), nullable=False),
|
||||
sa.PrimaryKeyConstraint("asset_info_id", "tag_name", name="pk_asset_info_tags"),
|
||||
)
|
||||
op.create_index("ix_asset_info_tags_tag_name", "asset_info_tags", ["tag_name"])
|
||||
op.create_index("ix_asset_info_tags_asset_info_id", "asset_info_tags", ["asset_info_id"])
|
||||
|
||||
op.create_table(
|
||||
"asset_info_meta",
|
||||
sa.Column("asset_info_id", sa.String(length=36), sa.ForeignKey("assets_info.id", ondelete="CASCADE"), nullable=False),
|
||||
sa.Column("key", sa.String(length=256), nullable=False),
|
||||
sa.Column("ordinal", sa.Integer(), nullable=False, server_default="0"),
|
||||
sa.Column("val_str", sa.String(length=2048), nullable=True),
|
||||
sa.Column("val_num", sa.Numeric(38, 10), nullable=True),
|
||||
sa.Column("val_bool", sa.Boolean(), nullable=True),
|
||||
sa.Column("val_json", sa.JSON(), nullable=True),
|
||||
sa.PrimaryKeyConstraint("asset_info_id", "key", "ordinal", name="pk_asset_info_meta"),
|
||||
)
|
||||
op.create_index("ix_asset_info_meta_key", "asset_info_meta", ["key"])
|
||||
op.create_index("ix_asset_info_meta_key_val_str", "asset_info_meta", ["key", "val_str"])
|
||||
op.create_index("ix_asset_info_meta_key_val_num", "asset_info_meta", ["key", "val_num"])
|
||||
op.create_index("ix_asset_info_meta_key_val_bool", "asset_info_meta", ["key", "val_bool"])
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,6 +1,8 @@
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Literal
|
||||
|
||||
from app.assets.helpers import validate_blake3_hash
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
@@ -10,6 +12,41 @@ from pydantic import (
|
||||
model_validator,
|
||||
)
|
||||
|
||||
|
||||
class UploadError(Exception):
|
||||
"""Error during upload parsing with HTTP status and code."""
|
||||
|
||||
def __init__(self, status: int, code: str, message: str):
|
||||
super().__init__(message)
|
||||
self.status = status
|
||||
self.code = code
|
||||
self.message = message
|
||||
|
||||
|
||||
class AssetValidationError(Exception):
|
||||
"""Validation error in asset processing (invalid tags, metadata, etc.)."""
|
||||
|
||||
def __init__(self, code: str, message: str):
|
||||
super().__init__(message)
|
||||
self.code = code
|
||||
self.message = message
|
||||
|
||||
|
||||
@dataclass
|
||||
class ParsedUpload:
|
||||
"""Result of parsing a multipart upload request."""
|
||||
|
||||
file_present: bool
|
||||
file_written: int
|
||||
file_client_name: str | None
|
||||
tmp_path: str | None
|
||||
tags_raw: list[str]
|
||||
provided_name: str | None
|
||||
user_metadata_raw: str | None
|
||||
provided_hash: str | None
|
||||
provided_hash_exists: bool | None
|
||||
|
||||
|
||||
class ListAssetsQuery(BaseModel):
|
||||
include_tags: list[str] = Field(default_factory=list)
|
||||
exclude_tags: list[str] = Field(default_factory=list)
|
||||
@@ -21,7 +58,9 @@ class ListAssetsQuery(BaseModel):
|
||||
limit: conint(ge=1, le=500) = 20
|
||||
offset: conint(ge=0) = 0
|
||||
|
||||
sort: Literal["name", "created_at", "updated_at", "size", "last_access_time"] = "created_at"
|
||||
sort: Literal["name", "created_at", "updated_at", "size", "last_access_time"] = (
|
||||
"created_at"
|
||||
)
|
||||
order: Literal["asc", "desc"] = "desc"
|
||||
|
||||
@field_validator("include_tags", "exclude_tags", mode="before")
|
||||
@@ -61,7 +100,7 @@ class UpdateAssetBody(BaseModel):
|
||||
user_metadata: dict[str, Any] | None = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _at_least_one(self):
|
||||
def _validate_at_least_one_field(self):
|
||||
if self.name is None and self.user_metadata is None:
|
||||
raise ValueError("Provide at least one of: name, user_metadata.")
|
||||
return self
|
||||
@@ -78,19 +117,11 @@ class CreateFromHashBody(BaseModel):
|
||||
@field_validator("hash")
|
||||
@classmethod
|
||||
def _require_blake3(cls, v):
|
||||
s = (v or "").strip().lower()
|
||||
if ":" not in s:
|
||||
raise ValueError("hash must be 'blake3:<hex>'")
|
||||
algo, digest = s.split(":", 1)
|
||||
if algo != "blake3":
|
||||
raise ValueError("only canonical 'blake3:<hex>' is accepted here")
|
||||
if not digest or any(c for c in digest if c not in "0123456789abcdef"):
|
||||
raise ValueError("hash digest must be lowercase hex")
|
||||
return s
|
||||
return validate_blake3_hash(v or "")
|
||||
|
||||
@field_validator("tags", mode="before")
|
||||
@classmethod
|
||||
def _tags_norm(cls, v):
|
||||
def _normalize_tags_field(cls, v):
|
||||
if v is None:
|
||||
return []
|
||||
if isinstance(v, list):
|
||||
@@ -154,15 +185,16 @@ class TagsRemove(TagsAdd):
|
||||
|
||||
class UploadAssetSpec(BaseModel):
|
||||
"""Upload Asset operation.
|
||||
|
||||
- tags: ordered; first is root ('models'|'input'|'output');
|
||||
if root == 'models', second must be a valid category from folder_paths.folder_names_and_paths
|
||||
if root == 'models', second must be a valid category
|
||||
- name: display name
|
||||
- user_metadata: arbitrary JSON object (optional)
|
||||
- hash: optional canonical 'blake3:<hex>' provided by the client for validation / fast-path
|
||||
- hash: optional canonical 'blake3:<hex>' for validation / fast-path
|
||||
|
||||
Files created via this endpoint are stored on disk using the **content hash** as the filename stem
|
||||
and the original extension is preserved when available.
|
||||
Files are stored using the content hash as filename stem.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
|
||||
|
||||
tags: list[str] = Field(..., min_length=1)
|
||||
@@ -175,17 +207,10 @@ class UploadAssetSpec(BaseModel):
|
||||
def _parse_hash(cls, v):
|
||||
if v is None:
|
||||
return None
|
||||
s = str(v).strip().lower()
|
||||
s = str(v).strip()
|
||||
if not s:
|
||||
return None
|
||||
if ":" not in s:
|
||||
raise ValueError("hash must be 'blake3:<hex>'")
|
||||
algo, digest = s.split(":", 1)
|
||||
if algo != "blake3":
|
||||
raise ValueError("only canonical 'blake3:<hex>' is accepted here")
|
||||
if not digest or any(c for c in digest if c not in "0123456789abcdef"):
|
||||
raise ValueError("hash digest must be lowercase hex")
|
||||
return f"{algo}:{digest}"
|
||||
return validate_blake3_hash(s)
|
||||
|
||||
@field_validator("tags", mode="before")
|
||||
@classmethod
|
||||
@@ -260,5 +285,7 @@ class UploadAssetSpec(BaseModel):
|
||||
raise ValueError("first tag must be one of: models, input, output")
|
||||
if root == "models":
|
||||
if len(self.tags) < 2:
|
||||
raise ValueError("models uploads require a category tag as the second tag")
|
||||
raise ValueError(
|
||||
"models uploads require a category tag as the second tag"
|
||||
)
|
||||
return self
|
||||
|
||||
@@ -19,7 +19,7 @@ class AssetSummary(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
@field_serializer("created_at", "updated_at", "last_access_time")
|
||||
def _ser_dt(self, v: datetime | None, _info):
|
||||
def _serialize_datetime(self, v: datetime | None, _info):
|
||||
return v.isoformat() if v else None
|
||||
|
||||
|
||||
@@ -40,7 +40,7 @@ class AssetUpdated(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
@field_serializer("updated_at")
|
||||
def _ser_updated(self, v: datetime | None, _info):
|
||||
def _serialize_updated_at(self, v: datetime | None, _info):
|
||||
return v.isoformat() if v else None
|
||||
|
||||
|
||||
@@ -59,7 +59,7 @@ class AssetDetail(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
@field_serializer("created_at", "last_access_time")
|
||||
def _ser_dt(self, v: datetime | None, _info):
|
||||
def _serialize_datetime(self, v: datetime | None, _info):
|
||||
return v.isoformat() if v else None
|
||||
|
||||
|
||||
|
||||
171
app/assets/api/upload.py
Normal file
171
app/assets/api/upload.py
Normal file
@@ -0,0 +1,171 @@
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from typing import Callable
|
||||
|
||||
from aiohttp import web
|
||||
|
||||
import folder_paths
|
||||
from app.assets.api.schemas_in import ParsedUpload, UploadError
|
||||
from app.assets.helpers import validate_blake3_hash
|
||||
|
||||
|
||||
def normalize_and_validate_hash(s: str) -> str:
|
||||
"""Validate and normalize a hash string.
|
||||
|
||||
Returns canonical 'blake3:<hex>' or raises UploadError.
|
||||
"""
|
||||
try:
|
||||
return validate_blake3_hash(s)
|
||||
except ValueError:
|
||||
raise UploadError(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
|
||||
|
||||
|
||||
async def parse_multipart_upload(
|
||||
request: web.Request,
|
||||
check_hash_exists: Callable[[str], bool],
|
||||
) -> ParsedUpload:
|
||||
"""
|
||||
Parse a multipart/form-data upload request.
|
||||
|
||||
Args:
|
||||
request: The aiohttp request
|
||||
check_hash_exists: Callable(hash_str) -> bool to check if a hash exists
|
||||
|
||||
Returns:
|
||||
ParsedUpload with parsed fields and temp file path
|
||||
|
||||
Raises:
|
||||
UploadError: On validation or I/O errors
|
||||
"""
|
||||
if not (request.content_type or "").lower().startswith("multipart/"):
|
||||
raise UploadError(
|
||||
415, "UNSUPPORTED_MEDIA_TYPE", "Use multipart/form-data for uploads."
|
||||
)
|
||||
|
||||
reader = await request.multipart()
|
||||
|
||||
file_present = False
|
||||
file_client_name: str | None = None
|
||||
tags_raw: list[str] = []
|
||||
provided_name: str | None = None
|
||||
user_metadata_raw: str | None = None
|
||||
provided_hash: str | None = None
|
||||
provided_hash_exists: bool | None = None
|
||||
|
||||
file_written = 0
|
||||
tmp_path: str | None = None
|
||||
|
||||
while True:
|
||||
field = await reader.next()
|
||||
if field is None:
|
||||
break
|
||||
|
||||
fname = getattr(field, "name", "") or ""
|
||||
|
||||
if fname == "hash":
|
||||
try:
|
||||
s = ((await field.text()) or "").strip().lower()
|
||||
except Exception:
|
||||
raise UploadError(
|
||||
400, "INVALID_HASH", "hash must be like 'blake3:<hex>'"
|
||||
)
|
||||
|
||||
if s:
|
||||
provided_hash = normalize_and_validate_hash(s)
|
||||
try:
|
||||
provided_hash_exists = check_hash_exists(provided_hash)
|
||||
except Exception as e:
|
||||
logging.exception(
|
||||
"check_hash_exists failed for hash=%s: %s", provided_hash, e
|
||||
)
|
||||
raise UploadError(
|
||||
500,
|
||||
"HASH_CHECK_FAILED",
|
||||
"Backend error while checking asset hash.",
|
||||
)
|
||||
|
||||
elif fname == "file":
|
||||
file_present = True
|
||||
file_client_name = (field.filename or "").strip()
|
||||
|
||||
if provided_hash and provided_hash_exists is True:
|
||||
# Hash exists - drain file but don't write to disk
|
||||
try:
|
||||
while True:
|
||||
chunk = await field.read_chunk(8 * 1024 * 1024)
|
||||
if not chunk:
|
||||
break
|
||||
file_written += len(chunk)
|
||||
except Exception:
|
||||
raise UploadError(
|
||||
500, "UPLOAD_IO_ERROR", "Failed to receive uploaded file."
|
||||
)
|
||||
continue
|
||||
|
||||
uploads_root = os.path.join(folder_paths.get_temp_directory(), "uploads")
|
||||
unique_dir = os.path.join(uploads_root, uuid.uuid4().hex)
|
||||
os.makedirs(unique_dir, exist_ok=True)
|
||||
tmp_path = os.path.join(unique_dir, ".upload.part")
|
||||
|
||||
try:
|
||||
with open(tmp_path, "wb") as f:
|
||||
while True:
|
||||
chunk = await field.read_chunk(8 * 1024 * 1024)
|
||||
if not chunk:
|
||||
break
|
||||
f.write(chunk)
|
||||
file_written += len(chunk)
|
||||
except Exception:
|
||||
delete_temp_file_if_exists(tmp_path)
|
||||
raise UploadError(
|
||||
500, "UPLOAD_IO_ERROR", "Failed to receive and store uploaded file."
|
||||
)
|
||||
|
||||
elif fname == "tags":
|
||||
tags_raw.append((await field.text()) or "")
|
||||
elif fname == "name":
|
||||
provided_name = (await field.text()) or None
|
||||
elif fname == "user_metadata":
|
||||
user_metadata_raw = (await field.text()) or None
|
||||
|
||||
if not file_present and not (provided_hash and provided_hash_exists):
|
||||
raise UploadError(
|
||||
400, "MISSING_FILE", "Form must include a 'file' part or a known 'hash'."
|
||||
)
|
||||
|
||||
if (
|
||||
file_present
|
||||
and file_written == 0
|
||||
and not (provided_hash and provided_hash_exists)
|
||||
):
|
||||
delete_temp_file_if_exists(tmp_path)
|
||||
raise UploadError(400, "EMPTY_UPLOAD", "Uploaded file is empty.")
|
||||
|
||||
return ParsedUpload(
|
||||
file_present=file_present,
|
||||
file_written=file_written,
|
||||
file_client_name=file_client_name,
|
||||
tmp_path=tmp_path,
|
||||
tags_raw=tags_raw,
|
||||
provided_name=provided_name,
|
||||
user_metadata_raw=user_metadata_raw,
|
||||
provided_hash=provided_hash,
|
||||
provided_hash_exists=provided_hash_exists,
|
||||
)
|
||||
|
||||
|
||||
def delete_temp_file_if_exists(tmp_path: str | None) -> None:
|
||||
"""Safely remove a temp file and its parent directory if empty."""
|
||||
if tmp_path:
|
||||
try:
|
||||
if os.path.exists(tmp_path):
|
||||
os.remove(tmp_path)
|
||||
except OSError as e:
|
||||
logging.debug("Failed to delete temp file %s: %s", tmp_path, e)
|
||||
try:
|
||||
parent = os.path.dirname(tmp_path)
|
||||
if parent and os.path.isdir(parent):
|
||||
os.rmdir(parent) # only succeeds if empty
|
||||
except OSError:
|
||||
pass
|
||||
@@ -1,204 +0,0 @@
|
||||
import os
|
||||
import uuid
|
||||
import sqlalchemy
|
||||
from typing import Iterable
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.dialects import sqlite
|
||||
|
||||
from app.assets.helpers import utcnow
|
||||
from app.assets.database.models import Asset, AssetCacheState, AssetInfo, AssetInfoTag, AssetInfoMeta
|
||||
|
||||
MAX_BIND_PARAMS = 800
|
||||
|
||||
def _chunk_rows(rows: list[dict], cols_per_row: int, max_bind_params: int) -> Iterable[list[dict]]:
|
||||
if not rows:
|
||||
return []
|
||||
rows_per_stmt = max(1, max_bind_params // max(1, cols_per_row))
|
||||
for i in range(0, len(rows), rows_per_stmt):
|
||||
yield rows[i:i + rows_per_stmt]
|
||||
|
||||
def _iter_chunks(seq, n: int):
|
||||
for i in range(0, len(seq), n):
|
||||
yield seq[i:i + n]
|
||||
|
||||
def _rows_per_stmt(cols: int) -> int:
|
||||
return max(1, MAX_BIND_PARAMS // max(1, cols))
|
||||
|
||||
|
||||
def seed_from_paths_batch(
|
||||
session: Session,
|
||||
*,
|
||||
specs: list[dict],
|
||||
owner_id: str = "",
|
||||
) -> dict:
|
||||
"""Each spec is a dict with keys:
|
||||
- abs_path: str
|
||||
- size_bytes: int
|
||||
- mtime_ns: int
|
||||
- info_name: str
|
||||
- tags: list[str]
|
||||
- fname: Optional[str]
|
||||
"""
|
||||
if not specs:
|
||||
return {"inserted_infos": 0, "won_states": 0, "lost_states": 0}
|
||||
|
||||
now = utcnow()
|
||||
asset_rows: list[dict] = []
|
||||
state_rows: list[dict] = []
|
||||
path_to_asset: dict[str, str] = {}
|
||||
asset_to_info: dict[str, dict] = {} # asset_id -> prepared info row
|
||||
path_list: list[str] = []
|
||||
|
||||
for sp in specs:
|
||||
ap = os.path.abspath(sp["abs_path"])
|
||||
aid = str(uuid.uuid4())
|
||||
iid = str(uuid.uuid4())
|
||||
path_list.append(ap)
|
||||
path_to_asset[ap] = aid
|
||||
|
||||
asset_rows.append(
|
||||
{
|
||||
"id": aid,
|
||||
"hash": None,
|
||||
"size_bytes": sp["size_bytes"],
|
||||
"mime_type": None,
|
||||
"created_at": now,
|
||||
}
|
||||
)
|
||||
state_rows.append(
|
||||
{
|
||||
"asset_id": aid,
|
||||
"file_path": ap,
|
||||
"mtime_ns": sp["mtime_ns"],
|
||||
}
|
||||
)
|
||||
asset_to_info[aid] = {
|
||||
"id": iid,
|
||||
"owner_id": owner_id,
|
||||
"name": sp["info_name"],
|
||||
"asset_id": aid,
|
||||
"preview_id": None,
|
||||
"user_metadata": {"filename": sp["fname"]} if sp["fname"] else None,
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
"last_access_time": now,
|
||||
"_tags": sp["tags"],
|
||||
"_filename": sp["fname"],
|
||||
}
|
||||
|
||||
# insert all seed Assets (hash=NULL)
|
||||
ins_asset = sqlite.insert(Asset)
|
||||
for chunk in _iter_chunks(asset_rows, _rows_per_stmt(5)):
|
||||
session.execute(ins_asset, chunk)
|
||||
|
||||
# try to claim AssetCacheState (file_path)
|
||||
# Insert with ON CONFLICT DO NOTHING, then query to find which paths were actually inserted
|
||||
ins_state = (
|
||||
sqlite.insert(AssetCacheState)
|
||||
.on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
|
||||
)
|
||||
for chunk in _iter_chunks(state_rows, _rows_per_stmt(3)):
|
||||
session.execute(ins_state, chunk)
|
||||
|
||||
# Query to find which of our paths won (were actually inserted)
|
||||
winners_by_path: set[str] = set()
|
||||
for chunk in _iter_chunks(path_list, MAX_BIND_PARAMS):
|
||||
result = session.execute(
|
||||
sqlalchemy.select(AssetCacheState.file_path)
|
||||
.where(AssetCacheState.file_path.in_(chunk))
|
||||
.where(AssetCacheState.asset_id.in_([path_to_asset[p] for p in chunk]))
|
||||
)
|
||||
winners_by_path.update(result.scalars().all())
|
||||
|
||||
all_paths_set = set(path_list)
|
||||
losers_by_path = all_paths_set - winners_by_path
|
||||
lost_assets = [path_to_asset[p] for p in losers_by_path]
|
||||
if lost_assets: # losers get their Asset removed
|
||||
for id_chunk in _iter_chunks(lost_assets, MAX_BIND_PARAMS):
|
||||
session.execute(sqlalchemy.delete(Asset).where(Asset.id.in_(id_chunk)))
|
||||
|
||||
if not winners_by_path:
|
||||
return {"inserted_infos": 0, "won_states": 0, "lost_states": len(losers_by_path)}
|
||||
|
||||
# insert AssetInfo only for winners
|
||||
# Insert with ON CONFLICT DO NOTHING, then query to find which were actually inserted
|
||||
winner_info_rows = [asset_to_info[path_to_asset[p]] for p in winners_by_path]
|
||||
ins_info = (
|
||||
sqlite.insert(AssetInfo)
|
||||
.on_conflict_do_nothing(index_elements=[AssetInfo.asset_id, AssetInfo.owner_id, AssetInfo.name])
|
||||
)
|
||||
for chunk in _iter_chunks(winner_info_rows, _rows_per_stmt(9)):
|
||||
session.execute(ins_info, chunk)
|
||||
|
||||
# Query to find which info rows were actually inserted (by matching our generated IDs)
|
||||
all_info_ids = [row["id"] for row in winner_info_rows]
|
||||
inserted_info_ids: set[str] = set()
|
||||
for chunk in _iter_chunks(all_info_ids, MAX_BIND_PARAMS):
|
||||
result = session.execute(
|
||||
sqlalchemy.select(AssetInfo.id).where(AssetInfo.id.in_(chunk))
|
||||
)
|
||||
inserted_info_ids.update(result.scalars().all())
|
||||
|
||||
# build and insert tag + meta rows for the AssetInfo
|
||||
tag_rows: list[dict] = []
|
||||
meta_rows: list[dict] = []
|
||||
if inserted_info_ids:
|
||||
for row in winner_info_rows:
|
||||
iid = row["id"]
|
||||
if iid not in inserted_info_ids:
|
||||
continue
|
||||
for t in row["_tags"]:
|
||||
tag_rows.append({
|
||||
"asset_info_id": iid,
|
||||
"tag_name": t,
|
||||
"origin": "automatic",
|
||||
"added_at": now,
|
||||
})
|
||||
if row["_filename"]:
|
||||
meta_rows.append(
|
||||
{
|
||||
"asset_info_id": iid,
|
||||
"key": "filename",
|
||||
"ordinal": 0,
|
||||
"val_str": row["_filename"],
|
||||
"val_num": None,
|
||||
"val_bool": None,
|
||||
"val_json": None,
|
||||
}
|
||||
)
|
||||
|
||||
bulk_insert_tags_and_meta(session, tag_rows=tag_rows, meta_rows=meta_rows, max_bind_params=MAX_BIND_PARAMS)
|
||||
return {
|
||||
"inserted_infos": len(inserted_info_ids),
|
||||
"won_states": len(winners_by_path),
|
||||
"lost_states": len(losers_by_path),
|
||||
}
|
||||
|
||||
|
||||
def bulk_insert_tags_and_meta(
|
||||
session: Session,
|
||||
*,
|
||||
tag_rows: list[dict],
|
||||
meta_rows: list[dict],
|
||||
max_bind_params: int,
|
||||
) -> None:
|
||||
"""Batch insert into asset_info_tags and asset_info_meta with ON CONFLICT DO NOTHING.
|
||||
- tag_rows keys: asset_info_id, tag_name, origin, added_at
|
||||
- meta_rows keys: asset_info_id, key, ordinal, val_str, val_num, val_bool, val_json
|
||||
"""
|
||||
if tag_rows:
|
||||
ins_links = (
|
||||
sqlite.insert(AssetInfoTag)
|
||||
.on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name])
|
||||
)
|
||||
for chunk in _chunk_rows(tag_rows, cols_per_row=4, max_bind_params=max_bind_params):
|
||||
session.execute(ins_links, chunk)
|
||||
if meta_rows:
|
||||
ins_meta = (
|
||||
sqlite.insert(AssetInfoMeta)
|
||||
.on_conflict_do_nothing(
|
||||
index_elements=[AssetInfoMeta.asset_info_id, AssetInfoMeta.key, AssetInfoMeta.ordinal]
|
||||
)
|
||||
)
|
||||
for chunk in _chunk_rows(meta_rows, cols_per_row=7, max_bind_params=max_bind_params):
|
||||
session.execute(ins_meta, chunk)
|
||||
@@ -2,8 +2,8 @@ from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import (
|
||||
JSON,
|
||||
BigInteger,
|
||||
@@ -16,102 +16,102 @@ from sqlalchemy import (
|
||||
Numeric,
|
||||
String,
|
||||
Text,
|
||||
UniqueConstraint,
|
||||
)
|
||||
from sqlalchemy.orm import Mapped, foreign, mapped_column, relationship
|
||||
|
||||
from app.assets.helpers import utcnow
|
||||
from app.database.models import to_dict, Base
|
||||
from app.assets.helpers import get_utc_now
|
||||
from app.database.models import Base
|
||||
|
||||
|
||||
class Asset(Base):
|
||||
__tablename__ = "assets"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||
id: Mapped[str] = mapped_column(
|
||||
String(36), primary_key=True, default=lambda: str(uuid.uuid4())
|
||||
)
|
||||
hash: Mapped[str | None] = mapped_column(String(256), nullable=True)
|
||||
size_bytes: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0)
|
||||
mime_type: Mapped[str | None] = mapped_column(String(255))
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=False), nullable=False, default=utcnow
|
||||
DateTime(timezone=False), nullable=False, default=get_utc_now
|
||||
)
|
||||
|
||||
infos: Mapped[list[AssetInfo]] = relationship(
|
||||
"AssetInfo",
|
||||
references: Mapped[list[AssetReference]] = relationship(
|
||||
"AssetReference",
|
||||
back_populates="asset",
|
||||
primaryjoin=lambda: Asset.id == foreign(AssetInfo.asset_id),
|
||||
foreign_keys=lambda: [AssetInfo.asset_id],
|
||||
primaryjoin=lambda: Asset.id == foreign(AssetReference.asset_id),
|
||||
foreign_keys=lambda: [AssetReference.asset_id],
|
||||
cascade="all,delete-orphan",
|
||||
passive_deletes=True,
|
||||
)
|
||||
|
||||
preview_of: Mapped[list[AssetInfo]] = relationship(
|
||||
"AssetInfo",
|
||||
preview_of: Mapped[list[AssetReference]] = relationship(
|
||||
"AssetReference",
|
||||
back_populates="preview_asset",
|
||||
primaryjoin=lambda: Asset.id == foreign(AssetInfo.preview_id),
|
||||
foreign_keys=lambda: [AssetInfo.preview_id],
|
||||
primaryjoin=lambda: Asset.id == foreign(AssetReference.preview_id),
|
||||
foreign_keys=lambda: [AssetReference.preview_id],
|
||||
viewonly=True,
|
||||
)
|
||||
|
||||
cache_states: Mapped[list[AssetCacheState]] = relationship(
|
||||
back_populates="asset",
|
||||
cascade="all, delete-orphan",
|
||||
passive_deletes=True,
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
Index("uq_assets_hash", "hash", unique=True),
|
||||
Index("ix_assets_mime_type", "mime_type"),
|
||||
CheckConstraint("size_bytes >= 0", name="ck_assets_size_nonneg"),
|
||||
)
|
||||
|
||||
def to_dict(self, include_none: bool = False) -> dict[str, Any]:
|
||||
return to_dict(self, include_none=include_none)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Asset id={self.id} hash={(self.hash or '')[:12]}>"
|
||||
|
||||
|
||||
class AssetCacheState(Base):
|
||||
__tablename__ = "asset_cache_state"
|
||||
class AssetReference(Base):
|
||||
"""Unified model combining file cache state and user-facing metadata.
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
asset_id: Mapped[str] = mapped_column(String(36), ForeignKey("assets.id", ondelete="CASCADE"), nullable=False)
|
||||
file_path: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
mtime_ns: Mapped[int | None] = mapped_column(BigInteger, nullable=True)
|
||||
needs_verify: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||
Each row represents either:
|
||||
- A filesystem reference (file_path is set) with cache state
|
||||
- An API-created reference (file_path is NULL) without cache state
|
||||
"""
|
||||
|
||||
asset: Mapped[Asset] = relationship(back_populates="cache_states")
|
||||
__tablename__ = "asset_references"
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_asset_cache_state_file_path", "file_path"),
|
||||
Index("ix_asset_cache_state_asset_id", "asset_id"),
|
||||
CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"),
|
||||
UniqueConstraint("file_path", name="uq_asset_cache_state_file_path"),
|
||||
id: Mapped[str] = mapped_column(
|
||||
String(36), primary_key=True, default=lambda: str(uuid.uuid4())
|
||||
)
|
||||
asset_id: Mapped[str] = mapped_column(
|
||||
String(36), ForeignKey("assets.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
|
||||
def to_dict(self, include_none: bool = False) -> dict[str, Any]:
|
||||
return to_dict(self, include_none=include_none)
|
||||
# Cache state fields (from former AssetCacheState)
|
||||
file_path: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
mtime_ns: Mapped[int | None] = mapped_column(BigInteger, nullable=True)
|
||||
needs_verify: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||
is_missing: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||
enrichment_level: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<AssetCacheState id={self.id} asset_id={self.asset_id} path={self.file_path!r}>"
|
||||
|
||||
|
||||
class AssetInfo(Base):
|
||||
__tablename__ = "assets_info"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||
# Info fields (from former AssetInfo)
|
||||
owner_id: Mapped[str] = mapped_column(String(128), nullable=False, default="")
|
||||
name: Mapped[str] = mapped_column(String(512), nullable=False)
|
||||
asset_id: Mapped[str] = mapped_column(String(36), ForeignKey("assets.id", ondelete="RESTRICT"), nullable=False)
|
||||
preview_id: Mapped[str | None] = mapped_column(String(36), ForeignKey("assets.id", ondelete="SET NULL"))
|
||||
user_metadata: Mapped[dict[str, Any] | None] = mapped_column(JSON(none_as_null=True))
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow)
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow)
|
||||
last_access_time: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow)
|
||||
preview_id: Mapped[str | None] = mapped_column(
|
||||
String(36), ForeignKey("assets.id", ondelete="SET NULL")
|
||||
)
|
||||
user_metadata: Mapped[dict[str, Any] | None] = mapped_column(
|
||||
JSON(none_as_null=True)
|
||||
)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=False), nullable=False, default=get_utc_now
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=False), nullable=False, default=get_utc_now
|
||||
)
|
||||
last_access_time: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=False), nullable=False, default=get_utc_now
|
||||
)
|
||||
deleted_at: Mapped[datetime | None] = mapped_column(
|
||||
DateTime(timezone=False), nullable=True, default=None
|
||||
)
|
||||
|
||||
asset: Mapped[Asset] = relationship(
|
||||
"Asset",
|
||||
back_populates="infos",
|
||||
back_populates="references",
|
||||
foreign_keys=[asset_id],
|
||||
lazy="selectin",
|
||||
)
|
||||
@@ -121,51 +121,59 @@ class AssetInfo(Base):
|
||||
foreign_keys=[preview_id],
|
||||
)
|
||||
|
||||
metadata_entries: Mapped[list[AssetInfoMeta]] = relationship(
|
||||
back_populates="asset_info",
|
||||
metadata_entries: Mapped[list[AssetReferenceMeta]] = relationship(
|
||||
back_populates="asset_reference",
|
||||
cascade="all,delete-orphan",
|
||||
passive_deletes=True,
|
||||
)
|
||||
|
||||
tag_links: Mapped[list[AssetInfoTag]] = relationship(
|
||||
back_populates="asset_info",
|
||||
tag_links: Mapped[list[AssetReferenceTag]] = relationship(
|
||||
back_populates="asset_reference",
|
||||
cascade="all,delete-orphan",
|
||||
passive_deletes=True,
|
||||
overlaps="tags,asset_infos",
|
||||
overlaps="tags,asset_references",
|
||||
)
|
||||
|
||||
tags: Mapped[list[Tag]] = relationship(
|
||||
secondary="asset_info_tags",
|
||||
back_populates="asset_infos",
|
||||
secondary="asset_reference_tags",
|
||||
back_populates="asset_references",
|
||||
lazy="selectin",
|
||||
viewonly=True,
|
||||
overlaps="tag_links,asset_info_links,asset_infos,tag",
|
||||
overlaps="tag_links,asset_reference_links,asset_references,tag",
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint("asset_id", "owner_id", "name", name="uq_assets_info_asset_owner_name"),
|
||||
Index("ix_assets_info_owner_name", "owner_id", "name"),
|
||||
Index("ix_assets_info_owner_id", "owner_id"),
|
||||
Index("ix_assets_info_asset_id", "asset_id"),
|
||||
Index("ix_assets_info_name", "name"),
|
||||
Index("ix_assets_info_created_at", "created_at"),
|
||||
Index("ix_assets_info_last_access_time", "last_access_time"),
|
||||
Index("uq_asset_references_file_path", "file_path", unique=True),
|
||||
Index("ix_asset_references_asset_id", "asset_id"),
|
||||
Index("ix_asset_references_owner_id", "owner_id"),
|
||||
Index("ix_asset_references_name", "name"),
|
||||
Index("ix_asset_references_is_missing", "is_missing"),
|
||||
Index("ix_asset_references_enrichment_level", "enrichment_level"),
|
||||
Index("ix_asset_references_created_at", "created_at"),
|
||||
Index("ix_asset_references_last_access_time", "last_access_time"),
|
||||
Index("ix_asset_references_deleted_at", "deleted_at"),
|
||||
Index("ix_asset_references_owner_name", "owner_id", "name"),
|
||||
CheckConstraint(
|
||||
"(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_ar_mtime_nonneg"
|
||||
),
|
||||
CheckConstraint(
|
||||
"enrichment_level >= 0 AND enrichment_level <= 2",
|
||||
name="ck_ar_enrichment_level_range",
|
||||
),
|
||||
)
|
||||
|
||||
def to_dict(self, include_none: bool = False) -> dict[str, Any]:
|
||||
data = to_dict(self, include_none=include_none)
|
||||
data["tags"] = [t.name for t in self.tags]
|
||||
return data
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<AssetInfo id={self.id} name={self.name!r} asset_id={self.asset_id}>"
|
||||
path_part = f" path={self.file_path!r}" if self.file_path else ""
|
||||
return f"<AssetReference id={self.id} name={self.name!r}{path_part}>"
|
||||
|
||||
|
||||
class AssetInfoMeta(Base):
|
||||
__tablename__ = "asset_info_meta"
|
||||
class AssetReferenceMeta(Base):
|
||||
__tablename__ = "asset_reference_meta"
|
||||
|
||||
asset_info_id: Mapped[str] = mapped_column(
|
||||
String(36), ForeignKey("assets_info.id", ondelete="CASCADE"), primary_key=True
|
||||
asset_reference_id: Mapped[str] = mapped_column(
|
||||
String(36),
|
||||
ForeignKey("asset_references.id", ondelete="CASCADE"),
|
||||
primary_key=True,
|
||||
)
|
||||
key: Mapped[str] = mapped_column(String(256), primary_key=True)
|
||||
ordinal: Mapped[int] = mapped_column(Integer, primary_key=True, default=0)
|
||||
@@ -175,36 +183,40 @@ class AssetInfoMeta(Base):
|
||||
val_bool: Mapped[bool | None] = mapped_column(Boolean, nullable=True)
|
||||
val_json: Mapped[Any | None] = mapped_column(JSON(none_as_null=True), nullable=True)
|
||||
|
||||
asset_info: Mapped[AssetInfo] = relationship(back_populates="metadata_entries")
|
||||
asset_reference: Mapped[AssetReference] = relationship(
|
||||
back_populates="metadata_entries"
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_asset_info_meta_key", "key"),
|
||||
Index("ix_asset_info_meta_key_val_str", "key", "val_str"),
|
||||
Index("ix_asset_info_meta_key_val_num", "key", "val_num"),
|
||||
Index("ix_asset_info_meta_key_val_bool", "key", "val_bool"),
|
||||
Index("ix_asset_reference_meta_key", "key"),
|
||||
Index("ix_asset_reference_meta_key_val_str", "key", "val_str"),
|
||||
Index("ix_asset_reference_meta_key_val_num", "key", "val_num"),
|
||||
Index("ix_asset_reference_meta_key_val_bool", "key", "val_bool"),
|
||||
)
|
||||
|
||||
|
||||
class AssetInfoTag(Base):
|
||||
__tablename__ = "asset_info_tags"
|
||||
class AssetReferenceTag(Base):
|
||||
__tablename__ = "asset_reference_tags"
|
||||
|
||||
asset_info_id: Mapped[str] = mapped_column(
|
||||
String(36), ForeignKey("assets_info.id", ondelete="CASCADE"), primary_key=True
|
||||
asset_reference_id: Mapped[str] = mapped_column(
|
||||
String(36),
|
||||
ForeignKey("asset_references.id", ondelete="CASCADE"),
|
||||
primary_key=True,
|
||||
)
|
||||
tag_name: Mapped[str] = mapped_column(
|
||||
String(512), ForeignKey("tags.name", ondelete="RESTRICT"), primary_key=True
|
||||
)
|
||||
origin: Mapped[str] = mapped_column(String(32), nullable=False, default="manual")
|
||||
added_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=False), nullable=False, default=utcnow
|
||||
DateTime(timezone=False), nullable=False, default=get_utc_now
|
||||
)
|
||||
|
||||
asset_info: Mapped[AssetInfo] = relationship(back_populates="tag_links")
|
||||
tag: Mapped[Tag] = relationship(back_populates="asset_info_links")
|
||||
asset_reference: Mapped[AssetReference] = relationship(back_populates="tag_links")
|
||||
tag: Mapped[Tag] = relationship(back_populates="asset_reference_links")
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_asset_info_tags_tag_name", "tag_name"),
|
||||
Index("ix_asset_info_tags_asset_info_id", "asset_info_id"),
|
||||
Index("ix_asset_reference_tags_tag_name", "tag_name"),
|
||||
Index("ix_asset_reference_tags_asset_reference_id", "asset_reference_id"),
|
||||
)
|
||||
|
||||
|
||||
@@ -214,20 +226,18 @@ class Tag(Base):
|
||||
name: Mapped[str] = mapped_column(String(512), primary_key=True)
|
||||
tag_type: Mapped[str] = mapped_column(String(32), nullable=False, default="user")
|
||||
|
||||
asset_info_links: Mapped[list[AssetInfoTag]] = relationship(
|
||||
asset_reference_links: Mapped[list[AssetReferenceTag]] = relationship(
|
||||
back_populates="tag",
|
||||
overlaps="asset_infos,tags",
|
||||
overlaps="asset_references,tags",
|
||||
)
|
||||
asset_infos: Mapped[list[AssetInfo]] = relationship(
|
||||
secondary="asset_info_tags",
|
||||
asset_references: Mapped[list[AssetReference]] = relationship(
|
||||
secondary="asset_reference_tags",
|
||||
back_populates="tags",
|
||||
viewonly=True,
|
||||
overlaps="asset_info_links,tag_links,tags,asset_info",
|
||||
overlaps="asset_reference_links,tag_links,tags,asset_reference",
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_tags_tag_type", "tag_type"),
|
||||
)
|
||||
__table_args__ = (Index("ix_tags_tag_type", "tag_type"),)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Tag {self.name}>"
|
||||
|
||||
@@ -1,976 +0,0 @@
|
||||
import os
|
||||
import logging
|
||||
import sqlalchemy as sa
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from typing import Iterable, Any
|
||||
from sqlalchemy import select, delete, exists, func
|
||||
from sqlalchemy.dialects import sqlite
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session, contains_eager, noload
|
||||
from app.assets.database.models import Asset, AssetInfo, AssetCacheState, AssetInfoMeta, AssetInfoTag, Tag
|
||||
from app.assets.helpers import (
|
||||
compute_relative_filename, escape_like_prefix, normalize_tags, project_kv, utcnow
|
||||
)
|
||||
from typing import Sequence
|
||||
|
||||
|
||||
def visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement:
|
||||
"""Build owner visibility predicate for reads. Owner-less rows are visible to everyone."""
|
||||
owner_id = (owner_id or "").strip()
|
||||
if owner_id == "":
|
||||
return AssetInfo.owner_id == ""
|
||||
return AssetInfo.owner_id.in_(["", owner_id])
|
||||
|
||||
|
||||
def pick_best_live_path(states: Sequence[AssetCacheState]) -> str:
|
||||
"""
|
||||
Return the best on-disk path among cache states:
|
||||
1) Prefer a path that exists with needs_verify == False (already verified).
|
||||
2) Otherwise, pick the first path that exists.
|
||||
3) Otherwise return empty string.
|
||||
"""
|
||||
alive = [s for s in states if getattr(s, "file_path", None) and os.path.isfile(s.file_path)]
|
||||
if not alive:
|
||||
return ""
|
||||
for s in alive:
|
||||
if not getattr(s, "needs_verify", False):
|
||||
return s.file_path
|
||||
return alive[0].file_path
|
||||
|
||||
|
||||
def apply_tag_filters(
|
||||
stmt: sa.sql.Select,
|
||||
include_tags: Sequence[str] | None = None,
|
||||
exclude_tags: Sequence[str] | None = None,
|
||||
) -> sa.sql.Select:
|
||||
"""include_tags: every tag must be present; exclude_tags: none may be present."""
|
||||
include_tags = normalize_tags(include_tags)
|
||||
exclude_tags = normalize_tags(exclude_tags)
|
||||
|
||||
if include_tags:
|
||||
for tag_name in include_tags:
|
||||
stmt = stmt.where(
|
||||
exists().where(
|
||||
(AssetInfoTag.asset_info_id == AssetInfo.id)
|
||||
& (AssetInfoTag.tag_name == tag_name)
|
||||
)
|
||||
)
|
||||
|
||||
if exclude_tags:
|
||||
stmt = stmt.where(
|
||||
~exists().where(
|
||||
(AssetInfoTag.asset_info_id == AssetInfo.id)
|
||||
& (AssetInfoTag.tag_name.in_(exclude_tags))
|
||||
)
|
||||
)
|
||||
return stmt
|
||||
|
||||
|
||||
def apply_metadata_filter(
|
||||
stmt: sa.sql.Select,
|
||||
metadata_filter: dict | None = None,
|
||||
) -> sa.sql.Select:
|
||||
"""Apply filters using asset_info_meta projection table."""
|
||||
if not metadata_filter:
|
||||
return stmt
|
||||
|
||||
def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement:
|
||||
return sa.exists().where(
|
||||
AssetInfoMeta.asset_info_id == AssetInfo.id,
|
||||
AssetInfoMeta.key == key,
|
||||
*preds,
|
||||
)
|
||||
|
||||
def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement:
|
||||
if value is None:
|
||||
no_row_for_key = sa.not_(
|
||||
sa.exists().where(
|
||||
AssetInfoMeta.asset_info_id == AssetInfo.id,
|
||||
AssetInfoMeta.key == key,
|
||||
)
|
||||
)
|
||||
null_row = _exists_for_pred(
|
||||
key,
|
||||
AssetInfoMeta.val_json.is_(None),
|
||||
AssetInfoMeta.val_str.is_(None),
|
||||
AssetInfoMeta.val_num.is_(None),
|
||||
AssetInfoMeta.val_bool.is_(None),
|
||||
)
|
||||
return sa.or_(no_row_for_key, null_row)
|
||||
|
||||
if isinstance(value, bool):
|
||||
return _exists_for_pred(key, AssetInfoMeta.val_bool == bool(value))
|
||||
if isinstance(value, (int, float)):
|
||||
from decimal import Decimal
|
||||
num = value if isinstance(value, Decimal) else Decimal(str(value))
|
||||
return _exists_for_pred(key, AssetInfoMeta.val_num == num)
|
||||
if isinstance(value, str):
|
||||
return _exists_for_pred(key, AssetInfoMeta.val_str == value)
|
||||
return _exists_for_pred(key, AssetInfoMeta.val_json == value)
|
||||
|
||||
for k, v in metadata_filter.items():
|
||||
if isinstance(v, list):
|
||||
ors = [_exists_clause_for_value(k, elem) for elem in v]
|
||||
if ors:
|
||||
stmt = stmt.where(sa.or_(*ors))
|
||||
else:
|
||||
stmt = stmt.where(_exists_clause_for_value(k, v))
|
||||
return stmt
|
||||
|
||||
|
||||
def asset_exists_by_hash(
|
||||
session: Session,
|
||||
*,
|
||||
asset_hash: str,
|
||||
) -> bool:
|
||||
"""
|
||||
Check if an asset with a given hash exists in database.
|
||||
"""
|
||||
row = (
|
||||
session.execute(
|
||||
select(sa.literal(True)).select_from(Asset).where(Asset.hash == asset_hash).limit(1)
|
||||
)
|
||||
).first()
|
||||
return row is not None
|
||||
|
||||
|
||||
def asset_info_exists_for_asset_id(
|
||||
session: Session,
|
||||
*,
|
||||
asset_id: str,
|
||||
) -> bool:
|
||||
q = (
|
||||
select(sa.literal(True))
|
||||
.select_from(AssetInfo)
|
||||
.where(AssetInfo.asset_id == asset_id)
|
||||
.limit(1)
|
||||
)
|
||||
return (session.execute(q)).first() is not None
|
||||
|
||||
|
||||
def get_asset_by_hash(
|
||||
session: Session,
|
||||
*,
|
||||
asset_hash: str,
|
||||
) -> Asset | None:
|
||||
return (
|
||||
session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))
|
||||
).scalars().first()
|
||||
|
||||
|
||||
def get_asset_info_by_id(
|
||||
session: Session,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
) -> AssetInfo | None:
|
||||
return session.get(AssetInfo, asset_info_id)
|
||||
|
||||
|
||||
def list_asset_infos_page(
|
||||
session: Session,
|
||||
owner_id: str = "",
|
||||
include_tags: Sequence[str] | None = None,
|
||||
exclude_tags: Sequence[str] | None = None,
|
||||
name_contains: str | None = None,
|
||||
metadata_filter: dict | None = None,
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
sort: str = "created_at",
|
||||
order: str = "desc",
|
||||
) -> tuple[list[AssetInfo], dict[str, list[str]], int]:
|
||||
base = (
|
||||
select(AssetInfo)
|
||||
.join(Asset, Asset.id == AssetInfo.asset_id)
|
||||
.options(contains_eager(AssetInfo.asset), noload(AssetInfo.tags))
|
||||
.where(visible_owner_clause(owner_id))
|
||||
)
|
||||
|
||||
if name_contains:
|
||||
escaped, esc = escape_like_prefix(name_contains)
|
||||
base = base.where(AssetInfo.name.ilike(f"%{escaped}%", escape=esc))
|
||||
|
||||
base = apply_tag_filters(base, include_tags, exclude_tags)
|
||||
base = apply_metadata_filter(base, metadata_filter)
|
||||
|
||||
sort = (sort or "created_at").lower()
|
||||
order = (order or "desc").lower()
|
||||
sort_map = {
|
||||
"name": AssetInfo.name,
|
||||
"created_at": AssetInfo.created_at,
|
||||
"updated_at": AssetInfo.updated_at,
|
||||
"last_access_time": AssetInfo.last_access_time,
|
||||
"size": Asset.size_bytes,
|
||||
}
|
||||
sort_col = sort_map.get(sort, AssetInfo.created_at)
|
||||
sort_exp = sort_col.desc() if order == "desc" else sort_col.asc()
|
||||
|
||||
base = base.order_by(sort_exp).limit(limit).offset(offset)
|
||||
|
||||
count_stmt = (
|
||||
select(sa.func.count())
|
||||
.select_from(AssetInfo)
|
||||
.join(Asset, Asset.id == AssetInfo.asset_id)
|
||||
.where(visible_owner_clause(owner_id))
|
||||
)
|
||||
if name_contains:
|
||||
escaped, esc = escape_like_prefix(name_contains)
|
||||
count_stmt = count_stmt.where(AssetInfo.name.ilike(f"%{escaped}%", escape=esc))
|
||||
count_stmt = apply_tag_filters(count_stmt, include_tags, exclude_tags)
|
||||
count_stmt = apply_metadata_filter(count_stmt, metadata_filter)
|
||||
|
||||
total = int((session.execute(count_stmt)).scalar_one() or 0)
|
||||
|
||||
infos = (session.execute(base)).unique().scalars().all()
|
||||
|
||||
id_list: list[str] = [i.id for i in infos]
|
||||
tag_map: dict[str, list[str]] = defaultdict(list)
|
||||
if id_list:
|
||||
rows = session.execute(
|
||||
select(AssetInfoTag.asset_info_id, Tag.name)
|
||||
.join(Tag, Tag.name == AssetInfoTag.tag_name)
|
||||
.where(AssetInfoTag.asset_info_id.in_(id_list))
|
||||
.order_by(AssetInfoTag.added_at)
|
||||
)
|
||||
for aid, tag_name in rows.all():
|
||||
tag_map[aid].append(tag_name)
|
||||
|
||||
return infos, tag_map, total
|
||||
|
||||
|
||||
def fetch_asset_info_asset_and_tags(
|
||||
session: Session,
|
||||
asset_info_id: str,
|
||||
owner_id: str = "",
|
||||
) -> tuple[AssetInfo, Asset, list[str]] | None:
|
||||
stmt = (
|
||||
select(AssetInfo, Asset, Tag.name)
|
||||
.join(Asset, Asset.id == AssetInfo.asset_id)
|
||||
.join(AssetInfoTag, AssetInfoTag.asset_info_id == AssetInfo.id, isouter=True)
|
||||
.join(Tag, Tag.name == AssetInfoTag.tag_name, isouter=True)
|
||||
.where(
|
||||
AssetInfo.id == asset_info_id,
|
||||
visible_owner_clause(owner_id),
|
||||
)
|
||||
.options(noload(AssetInfo.tags))
|
||||
.order_by(Tag.name.asc())
|
||||
)
|
||||
|
||||
rows = (session.execute(stmt)).all()
|
||||
if not rows:
|
||||
return None
|
||||
|
||||
first_info, first_asset, _ = rows[0]
|
||||
tags: list[str] = []
|
||||
seen: set[str] = set()
|
||||
for _info, _asset, tag_name in rows:
|
||||
if tag_name and tag_name not in seen:
|
||||
seen.add(tag_name)
|
||||
tags.append(tag_name)
|
||||
return first_info, first_asset, tags
|
||||
|
||||
|
||||
def fetch_asset_info_and_asset(
|
||||
session: Session,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
owner_id: str = "",
|
||||
) -> tuple[AssetInfo, Asset] | None:
|
||||
stmt = (
|
||||
select(AssetInfo, Asset)
|
||||
.join(Asset, Asset.id == AssetInfo.asset_id)
|
||||
.where(
|
||||
AssetInfo.id == asset_info_id,
|
||||
visible_owner_clause(owner_id),
|
||||
)
|
||||
.limit(1)
|
||||
.options(noload(AssetInfo.tags))
|
||||
)
|
||||
row = session.execute(stmt)
|
||||
pair = row.first()
|
||||
if not pair:
|
||||
return None
|
||||
return pair[0], pair[1]
|
||||
|
||||
def list_cache_states_by_asset_id(
|
||||
session: Session, *, asset_id: str
|
||||
) -> Sequence[AssetCacheState]:
|
||||
return (
|
||||
session.execute(
|
||||
select(AssetCacheState)
|
||||
.where(AssetCacheState.asset_id == asset_id)
|
||||
.order_by(AssetCacheState.id.asc())
|
||||
)
|
||||
).scalars().all()
|
||||
|
||||
|
||||
def touch_asset_info_by_id(
|
||||
session: Session,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
ts: datetime | None = None,
|
||||
only_if_newer: bool = True,
|
||||
) -> None:
|
||||
ts = ts or utcnow()
|
||||
stmt = sa.update(AssetInfo).where(AssetInfo.id == asset_info_id)
|
||||
if only_if_newer:
|
||||
stmt = stmt.where(
|
||||
sa.or_(AssetInfo.last_access_time.is_(None), AssetInfo.last_access_time < ts)
|
||||
)
|
||||
session.execute(stmt.values(last_access_time=ts))
|
||||
|
||||
|
||||
def create_asset_info_for_existing_asset(
|
||||
session: Session,
|
||||
*,
|
||||
asset_hash: str,
|
||||
name: str,
|
||||
user_metadata: dict | None = None,
|
||||
tags: Sequence[str] | None = None,
|
||||
tag_origin: str = "manual",
|
||||
owner_id: str = "",
|
||||
) -> AssetInfo:
|
||||
"""Create or return an existing AssetInfo for an Asset identified by asset_hash."""
|
||||
now = utcnow()
|
||||
asset = get_asset_by_hash(session, asset_hash=asset_hash)
|
||||
if not asset:
|
||||
raise ValueError(f"Unknown asset hash {asset_hash}")
|
||||
|
||||
info = AssetInfo(
|
||||
owner_id=owner_id,
|
||||
name=name,
|
||||
asset_id=asset.id,
|
||||
preview_id=None,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
last_access_time=now,
|
||||
)
|
||||
try:
|
||||
with session.begin_nested():
|
||||
session.add(info)
|
||||
session.flush()
|
||||
except IntegrityError:
|
||||
existing = (
|
||||
session.execute(
|
||||
select(AssetInfo)
|
||||
.options(noload(AssetInfo.tags))
|
||||
.where(
|
||||
AssetInfo.asset_id == asset.id,
|
||||
AssetInfo.name == name,
|
||||
AssetInfo.owner_id == owner_id,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
).unique().scalars().first()
|
||||
if not existing:
|
||||
raise RuntimeError("AssetInfo upsert failed to find existing row after conflict.")
|
||||
return existing
|
||||
|
||||
# metadata["filename"] hack
|
||||
new_meta = dict(user_metadata or {})
|
||||
computed_filename = None
|
||||
try:
|
||||
p = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=asset.id))
|
||||
if p:
|
||||
computed_filename = compute_relative_filename(p)
|
||||
except Exception:
|
||||
computed_filename = None
|
||||
if computed_filename:
|
||||
new_meta["filename"] = computed_filename
|
||||
if new_meta:
|
||||
replace_asset_info_metadata_projection(
|
||||
session,
|
||||
asset_info_id=info.id,
|
||||
user_metadata=new_meta,
|
||||
)
|
||||
|
||||
if tags is not None:
|
||||
set_asset_info_tags(
|
||||
session,
|
||||
asset_info_id=info.id,
|
||||
tags=tags,
|
||||
origin=tag_origin,
|
||||
)
|
||||
return info
|
||||
|
||||
|
||||
def set_asset_info_tags(
|
||||
session: Session,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
tags: Sequence[str],
|
||||
origin: str = "manual",
|
||||
) -> dict:
|
||||
desired = normalize_tags(tags)
|
||||
|
||||
current = set(
|
||||
tag_name for (tag_name,) in (
|
||||
session.execute(select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id))
|
||||
).all()
|
||||
)
|
||||
|
||||
to_add = [t for t in desired if t not in current]
|
||||
to_remove = [t for t in current if t not in desired]
|
||||
|
||||
if to_add:
|
||||
ensure_tags_exist(session, to_add, tag_type="user")
|
||||
session.add_all([
|
||||
AssetInfoTag(asset_info_id=asset_info_id, tag_name=t, origin=origin, added_at=utcnow())
|
||||
for t in to_add
|
||||
])
|
||||
session.flush()
|
||||
|
||||
if to_remove:
|
||||
session.execute(
|
||||
delete(AssetInfoTag)
|
||||
.where(AssetInfoTag.asset_info_id == asset_info_id, AssetInfoTag.tag_name.in_(to_remove))
|
||||
)
|
||||
session.flush()
|
||||
|
||||
return {"added": to_add, "removed": to_remove, "total": desired}
|
||||
|
||||
|
||||
def replace_asset_info_metadata_projection(
|
||||
session: Session,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
user_metadata: dict | None = None,
|
||||
) -> None:
|
||||
info = session.get(AssetInfo, asset_info_id)
|
||||
if not info:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
|
||||
info.user_metadata = user_metadata or {}
|
||||
info.updated_at = utcnow()
|
||||
session.flush()
|
||||
|
||||
session.execute(delete(AssetInfoMeta).where(AssetInfoMeta.asset_info_id == asset_info_id))
|
||||
session.flush()
|
||||
|
||||
if not user_metadata:
|
||||
return
|
||||
|
||||
rows: list[AssetInfoMeta] = []
|
||||
for k, v in user_metadata.items():
|
||||
for r in project_kv(k, v):
|
||||
rows.append(
|
||||
AssetInfoMeta(
|
||||
asset_info_id=asset_info_id,
|
||||
key=r["key"],
|
||||
ordinal=int(r["ordinal"]),
|
||||
val_str=r.get("val_str"),
|
||||
val_num=r.get("val_num"),
|
||||
val_bool=r.get("val_bool"),
|
||||
val_json=r.get("val_json"),
|
||||
)
|
||||
)
|
||||
if rows:
|
||||
session.add_all(rows)
|
||||
session.flush()
|
||||
|
||||
|
||||
def ingest_fs_asset(
|
||||
session: Session,
|
||||
*,
|
||||
asset_hash: str,
|
||||
abs_path: str,
|
||||
size_bytes: int,
|
||||
mtime_ns: int,
|
||||
mime_type: str | None = None,
|
||||
info_name: str | None = None,
|
||||
owner_id: str = "",
|
||||
preview_id: str | None = None,
|
||||
user_metadata: dict | None = None,
|
||||
tags: Sequence[str] = (),
|
||||
tag_origin: str = "manual",
|
||||
require_existing_tags: bool = False,
|
||||
) -> dict:
|
||||
"""
|
||||
Idempotently upsert:
|
||||
- Asset by content hash (create if missing)
|
||||
- AssetCacheState(file_path) pointing to asset_id
|
||||
- Optionally AssetInfo + tag links and metadata projection
|
||||
Returns flags and ids.
|
||||
"""
|
||||
locator = os.path.abspath(abs_path)
|
||||
now = utcnow()
|
||||
|
||||
if preview_id:
|
||||
if not session.get(Asset, preview_id):
|
||||
preview_id = None
|
||||
|
||||
out: dict[str, Any] = {
|
||||
"asset_created": False,
|
||||
"asset_updated": False,
|
||||
"state_created": False,
|
||||
"state_updated": False,
|
||||
"asset_info_id": None,
|
||||
}
|
||||
|
||||
# 1) Asset by hash
|
||||
asset = (
|
||||
session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))
|
||||
).scalars().first()
|
||||
if not asset:
|
||||
vals = {
|
||||
"hash": asset_hash,
|
||||
"size_bytes": int(size_bytes),
|
||||
"mime_type": mime_type,
|
||||
"created_at": now,
|
||||
}
|
||||
res = session.execute(
|
||||
sqlite.insert(Asset)
|
||||
.values(**vals)
|
||||
.on_conflict_do_nothing(index_elements=[Asset.hash])
|
||||
)
|
||||
if int(res.rowcount or 0) > 0:
|
||||
out["asset_created"] = True
|
||||
asset = (
|
||||
session.execute(
|
||||
select(Asset).where(Asset.hash == asset_hash).limit(1)
|
||||
)
|
||||
).scalars().first()
|
||||
if not asset:
|
||||
raise RuntimeError("Asset row not found after upsert.")
|
||||
else:
|
||||
changed = False
|
||||
if asset.size_bytes != int(size_bytes) and int(size_bytes) > 0:
|
||||
asset.size_bytes = int(size_bytes)
|
||||
changed = True
|
||||
if mime_type and asset.mime_type != mime_type:
|
||||
asset.mime_type = mime_type
|
||||
changed = True
|
||||
if changed:
|
||||
out["asset_updated"] = True
|
||||
|
||||
# 2) AssetCacheState upsert by file_path (unique)
|
||||
vals = {
|
||||
"asset_id": asset.id,
|
||||
"file_path": locator,
|
||||
"mtime_ns": int(mtime_ns),
|
||||
}
|
||||
ins = (
|
||||
sqlite.insert(AssetCacheState)
|
||||
.values(**vals)
|
||||
.on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
|
||||
)
|
||||
|
||||
res = session.execute(ins)
|
||||
if int(res.rowcount or 0) > 0:
|
||||
out["state_created"] = True
|
||||
else:
|
||||
upd = (
|
||||
sa.update(AssetCacheState)
|
||||
.where(AssetCacheState.file_path == locator)
|
||||
.where(
|
||||
sa.or_(
|
||||
AssetCacheState.asset_id != asset.id,
|
||||
AssetCacheState.mtime_ns.is_(None),
|
||||
AssetCacheState.mtime_ns != int(mtime_ns),
|
||||
)
|
||||
)
|
||||
.values(asset_id=asset.id, mtime_ns=int(mtime_ns))
|
||||
)
|
||||
res2 = session.execute(upd)
|
||||
if int(res2.rowcount or 0) > 0:
|
||||
out["state_updated"] = True
|
||||
|
||||
# 3) Optional AssetInfo + tags + metadata
|
||||
if info_name:
|
||||
try:
|
||||
with session.begin_nested():
|
||||
info = AssetInfo(
|
||||
owner_id=owner_id,
|
||||
name=info_name,
|
||||
asset_id=asset.id,
|
||||
preview_id=preview_id,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
last_access_time=now,
|
||||
)
|
||||
session.add(info)
|
||||
session.flush()
|
||||
out["asset_info_id"] = info.id
|
||||
except IntegrityError:
|
||||
pass
|
||||
|
||||
existing_info = (
|
||||
session.execute(
|
||||
select(AssetInfo)
|
||||
.where(
|
||||
AssetInfo.asset_id == asset.id,
|
||||
AssetInfo.name == info_name,
|
||||
(AssetInfo.owner_id == owner_id),
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
).unique().scalar_one_or_none()
|
||||
if not existing_info:
|
||||
raise RuntimeError("Failed to update or insert AssetInfo.")
|
||||
|
||||
if preview_id and existing_info.preview_id != preview_id:
|
||||
existing_info.preview_id = preview_id
|
||||
|
||||
existing_info.updated_at = now
|
||||
if existing_info.last_access_time < now:
|
||||
existing_info.last_access_time = now
|
||||
session.flush()
|
||||
out["asset_info_id"] = existing_info.id
|
||||
|
||||
norm = [t.strip().lower() for t in (tags or []) if (t or "").strip()]
|
||||
if norm and out["asset_info_id"] is not None:
|
||||
if not require_existing_tags:
|
||||
ensure_tags_exist(session, norm, tag_type="user")
|
||||
|
||||
existing_tag_names = set(
|
||||
name for (name,) in (session.execute(select(Tag.name).where(Tag.name.in_(norm)))).all()
|
||||
)
|
||||
missing = [t for t in norm if t not in existing_tag_names]
|
||||
if missing and require_existing_tags:
|
||||
raise ValueError(f"Unknown tags: {missing}")
|
||||
|
||||
existing_links = set(
|
||||
tag_name
|
||||
for (tag_name,) in (
|
||||
session.execute(
|
||||
select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == out["asset_info_id"])
|
||||
)
|
||||
).all()
|
||||
)
|
||||
to_add = [t for t in norm if t in existing_tag_names and t not in existing_links]
|
||||
if to_add:
|
||||
session.add_all(
|
||||
[
|
||||
AssetInfoTag(
|
||||
asset_info_id=out["asset_info_id"],
|
||||
tag_name=t,
|
||||
origin=tag_origin,
|
||||
added_at=now,
|
||||
)
|
||||
for t in to_add
|
||||
]
|
||||
)
|
||||
session.flush()
|
||||
|
||||
# metadata["filename"] hack
|
||||
if out["asset_info_id"] is not None:
|
||||
primary_path = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=asset.id))
|
||||
computed_filename = compute_relative_filename(primary_path) if primary_path else None
|
||||
|
||||
current_meta = existing_info.user_metadata or {}
|
||||
new_meta = dict(current_meta)
|
||||
if user_metadata is not None:
|
||||
for k, v in user_metadata.items():
|
||||
new_meta[k] = v
|
||||
if computed_filename:
|
||||
new_meta["filename"] = computed_filename
|
||||
|
||||
if new_meta != current_meta:
|
||||
replace_asset_info_metadata_projection(
|
||||
session,
|
||||
asset_info_id=out["asset_info_id"],
|
||||
user_metadata=new_meta,
|
||||
)
|
||||
|
||||
try:
|
||||
remove_missing_tag_for_asset_id(session, asset_id=asset.id)
|
||||
except Exception:
|
||||
logging.exception("Failed to clear 'missing' tag for asset %s", asset.id)
|
||||
return out
|
||||
|
||||
|
||||
def update_asset_info_full(
|
||||
session: Session,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
name: str | None = None,
|
||||
tags: Sequence[str] | None = None,
|
||||
user_metadata: dict | None = None,
|
||||
tag_origin: str = "manual",
|
||||
asset_info_row: Any = None,
|
||||
) -> AssetInfo:
|
||||
if not asset_info_row:
|
||||
info = session.get(AssetInfo, asset_info_id)
|
||||
if not info:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
else:
|
||||
info = asset_info_row
|
||||
|
||||
touched = False
|
||||
if name is not None and name != info.name:
|
||||
info.name = name
|
||||
touched = True
|
||||
|
||||
computed_filename = None
|
||||
try:
|
||||
p = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=info.asset_id))
|
||||
if p:
|
||||
computed_filename = compute_relative_filename(p)
|
||||
except Exception:
|
||||
computed_filename = None
|
||||
|
||||
if user_metadata is not None:
|
||||
new_meta = dict(user_metadata)
|
||||
if computed_filename:
|
||||
new_meta["filename"] = computed_filename
|
||||
replace_asset_info_metadata_projection(
|
||||
session, asset_info_id=asset_info_id, user_metadata=new_meta
|
||||
)
|
||||
touched = True
|
||||
else:
|
||||
if computed_filename:
|
||||
current_meta = info.user_metadata or {}
|
||||
if current_meta.get("filename") != computed_filename:
|
||||
new_meta = dict(current_meta)
|
||||
new_meta["filename"] = computed_filename
|
||||
replace_asset_info_metadata_projection(
|
||||
session, asset_info_id=asset_info_id, user_metadata=new_meta
|
||||
)
|
||||
touched = True
|
||||
|
||||
if tags is not None:
|
||||
set_asset_info_tags(
|
||||
session,
|
||||
asset_info_id=asset_info_id,
|
||||
tags=tags,
|
||||
origin=tag_origin,
|
||||
)
|
||||
touched = True
|
||||
|
||||
if touched and user_metadata is None:
|
||||
info.updated_at = utcnow()
|
||||
session.flush()
|
||||
|
||||
return info
|
||||
|
||||
|
||||
def delete_asset_info_by_id(
|
||||
session: Session,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
owner_id: str,
|
||||
) -> bool:
|
||||
stmt = sa.delete(AssetInfo).where(
|
||||
AssetInfo.id == asset_info_id,
|
||||
visible_owner_clause(owner_id),
|
||||
)
|
||||
return int((session.execute(stmt)).rowcount or 0) > 0
|
||||
|
||||
|
||||
def list_tags_with_usage(
|
||||
session: Session,
|
||||
prefix: str | None = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
include_zero: bool = True,
|
||||
order: str = "count_desc",
|
||||
owner_id: str = "",
|
||||
) -> tuple[list[tuple[str, str, int]], int]:
|
||||
counts_sq = (
|
||||
select(
|
||||
AssetInfoTag.tag_name.label("tag_name"),
|
||||
func.count(AssetInfoTag.asset_info_id).label("cnt"),
|
||||
)
|
||||
.select_from(AssetInfoTag)
|
||||
.join(AssetInfo, AssetInfo.id == AssetInfoTag.asset_info_id)
|
||||
.where(visible_owner_clause(owner_id))
|
||||
.group_by(AssetInfoTag.tag_name)
|
||||
.subquery()
|
||||
)
|
||||
|
||||
q = (
|
||||
select(
|
||||
Tag.name,
|
||||
Tag.tag_type,
|
||||
func.coalesce(counts_sq.c.cnt, 0).label("count"),
|
||||
)
|
||||
.select_from(Tag)
|
||||
.join(counts_sq, counts_sq.c.tag_name == Tag.name, isouter=True)
|
||||
)
|
||||
|
||||
if prefix:
|
||||
escaped, esc = escape_like_prefix(prefix.strip().lower())
|
||||
q = q.where(Tag.name.like(escaped + "%", escape=esc))
|
||||
|
||||
if not include_zero:
|
||||
q = q.where(func.coalesce(counts_sq.c.cnt, 0) > 0)
|
||||
|
||||
if order == "name_asc":
|
||||
q = q.order_by(Tag.name.asc())
|
||||
else:
|
||||
q = q.order_by(func.coalesce(counts_sq.c.cnt, 0).desc(), Tag.name.asc())
|
||||
|
||||
total_q = select(func.count()).select_from(Tag)
|
||||
if prefix:
|
||||
escaped, esc = escape_like_prefix(prefix.strip().lower())
|
||||
total_q = total_q.where(Tag.name.like(escaped + "%", escape=esc))
|
||||
if not include_zero:
|
||||
total_q = total_q.where(
|
||||
Tag.name.in_(select(AssetInfoTag.tag_name).group_by(AssetInfoTag.tag_name))
|
||||
)
|
||||
|
||||
rows = (session.execute(q.limit(limit).offset(offset))).all()
|
||||
total = (session.execute(total_q)).scalar_one()
|
||||
|
||||
rows_norm = [(name, ttype, int(count or 0)) for (name, ttype, count) in rows]
|
||||
return rows_norm, int(total or 0)
|
||||
|
||||
|
||||
def ensure_tags_exist(session: Session, names: Iterable[str], tag_type: str = "user") -> None:
|
||||
wanted = normalize_tags(list(names))
|
||||
if not wanted:
|
||||
return
|
||||
rows = [{"name": n, "tag_type": tag_type} for n in list(dict.fromkeys(wanted))]
|
||||
ins = (
|
||||
sqlite.insert(Tag)
|
||||
.values(rows)
|
||||
.on_conflict_do_nothing(index_elements=[Tag.name])
|
||||
)
|
||||
session.execute(ins)
|
||||
|
||||
|
||||
def get_asset_tags(session: Session, *, asset_info_id: str) -> list[str]:
|
||||
return [
|
||||
tag_name for (tag_name,) in (
|
||||
session.execute(
|
||||
select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
|
||||
)
|
||||
).all()
|
||||
]
|
||||
|
||||
|
||||
def add_tags_to_asset_info(
|
||||
session: Session,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
tags: Sequence[str],
|
||||
origin: str = "manual",
|
||||
create_if_missing: bool = True,
|
||||
asset_info_row: Any = None,
|
||||
) -> dict:
|
||||
if not asset_info_row:
|
||||
info = session.get(AssetInfo, asset_info_id)
|
||||
if not info:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
|
||||
norm = normalize_tags(tags)
|
||||
if not norm:
|
||||
total = get_asset_tags(session, asset_info_id=asset_info_id)
|
||||
return {"added": [], "already_present": [], "total_tags": total}
|
||||
|
||||
if create_if_missing:
|
||||
ensure_tags_exist(session, norm, tag_type="user")
|
||||
|
||||
current = {
|
||||
tag_name
|
||||
for (tag_name,) in (
|
||||
session.execute(
|
||||
sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
|
||||
)
|
||||
).all()
|
||||
}
|
||||
|
||||
want = set(norm)
|
||||
to_add = sorted(want - current)
|
||||
|
||||
if to_add:
|
||||
with session.begin_nested() as nested:
|
||||
try:
|
||||
session.add_all(
|
||||
[
|
||||
AssetInfoTag(
|
||||
asset_info_id=asset_info_id,
|
||||
tag_name=t,
|
||||
origin=origin,
|
||||
added_at=utcnow(),
|
||||
)
|
||||
for t in to_add
|
||||
]
|
||||
)
|
||||
session.flush()
|
||||
except IntegrityError:
|
||||
nested.rollback()
|
||||
|
||||
after = set(get_asset_tags(session, asset_info_id=asset_info_id))
|
||||
return {
|
||||
"added": sorted(((after - current) & want)),
|
||||
"already_present": sorted(want & current),
|
||||
"total_tags": sorted(after),
|
||||
}
|
||||
|
||||
|
||||
def remove_tags_from_asset_info(
|
||||
session: Session,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
tags: Sequence[str],
|
||||
) -> dict:
|
||||
info = session.get(AssetInfo, asset_info_id)
|
||||
if not info:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
|
||||
norm = normalize_tags(tags)
|
||||
if not norm:
|
||||
total = get_asset_tags(session, asset_info_id=asset_info_id)
|
||||
return {"removed": [], "not_present": [], "total_tags": total}
|
||||
|
||||
existing = {
|
||||
tag_name
|
||||
for (tag_name,) in (
|
||||
session.execute(
|
||||
sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
|
||||
)
|
||||
).all()
|
||||
}
|
||||
|
||||
to_remove = sorted(set(t for t in norm if t in existing))
|
||||
not_present = sorted(set(t for t in norm if t not in existing))
|
||||
|
||||
if to_remove:
|
||||
session.execute(
|
||||
delete(AssetInfoTag)
|
||||
.where(
|
||||
AssetInfoTag.asset_info_id == asset_info_id,
|
||||
AssetInfoTag.tag_name.in_(to_remove),
|
||||
)
|
||||
)
|
||||
session.flush()
|
||||
|
||||
total = get_asset_tags(session, asset_info_id=asset_info_id)
|
||||
return {"removed": to_remove, "not_present": not_present, "total_tags": total}
|
||||
|
||||
|
||||
def remove_missing_tag_for_asset_id(
|
||||
session: Session,
|
||||
*,
|
||||
asset_id: str,
|
||||
) -> None:
|
||||
session.execute(
|
||||
sa.delete(AssetInfoTag).where(
|
||||
AssetInfoTag.asset_info_id.in_(sa.select(AssetInfo.id).where(AssetInfo.asset_id == asset_id)),
|
||||
AssetInfoTag.tag_name == "missing",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def set_asset_info_preview(
|
||||
session: Session,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
preview_asset_id: str | None = None,
|
||||
) -> None:
|
||||
"""Set or clear preview_id and bump updated_at. Raises on unknown IDs."""
|
||||
info = session.get(AssetInfo, asset_info_id)
|
||||
if not info:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
|
||||
if preview_asset_id is None:
|
||||
info.preview_id = None
|
||||
else:
|
||||
# validate preview asset exists
|
||||
if not session.get(Asset, preview_asset_id):
|
||||
raise ValueError(f"Preview Asset {preview_asset_id} not found")
|
||||
info.preview_id = preview_asset_id
|
||||
|
||||
info.updated_at = utcnow()
|
||||
session.flush()
|
||||
121
app/assets/database/queries/__init__.py
Normal file
121
app/assets/database/queries/__init__.py
Normal file
@@ -0,0 +1,121 @@
|
||||
from app.assets.database.queries.asset import (
|
||||
asset_exists_by_hash,
|
||||
bulk_insert_assets,
|
||||
get_asset_by_hash,
|
||||
get_existing_asset_ids,
|
||||
reassign_asset_references,
|
||||
update_asset_hash_and_mime,
|
||||
upsert_asset,
|
||||
)
|
||||
from app.assets.database.queries.asset_reference import (
|
||||
CacheStateRow,
|
||||
UnenrichedReferenceRow,
|
||||
bulk_insert_references_ignore_conflicts,
|
||||
bulk_update_enrichment_level,
|
||||
bulk_update_is_missing,
|
||||
bulk_update_needs_verify,
|
||||
convert_metadata_to_rows,
|
||||
delete_assets_by_ids,
|
||||
delete_orphaned_seed_asset,
|
||||
delete_reference_by_id,
|
||||
delete_references_by_ids,
|
||||
fetch_reference_and_asset,
|
||||
fetch_reference_asset_and_tags,
|
||||
get_or_create_reference,
|
||||
get_reference_by_file_path,
|
||||
get_reference_by_id,
|
||||
get_reference_with_owner_check,
|
||||
get_reference_ids_by_ids,
|
||||
get_references_by_paths_and_asset_ids,
|
||||
get_references_for_prefixes,
|
||||
get_unenriched_references,
|
||||
get_unreferenced_unhashed_asset_ids,
|
||||
insert_reference,
|
||||
list_references_by_asset_id,
|
||||
list_references_page,
|
||||
mark_references_missing_outside_prefixes,
|
||||
reference_exists_for_asset_id,
|
||||
restore_references_by_paths,
|
||||
set_reference_metadata,
|
||||
set_reference_preview,
|
||||
soft_delete_reference_by_id,
|
||||
update_reference_access_time,
|
||||
update_reference_name,
|
||||
update_reference_timestamps,
|
||||
update_reference_updated_at,
|
||||
upsert_reference,
|
||||
)
|
||||
from app.assets.database.queries.tags import (
|
||||
AddTagsResult,
|
||||
RemoveTagsResult,
|
||||
SetTagsResult,
|
||||
add_missing_tag_for_asset_id,
|
||||
add_tags_to_reference,
|
||||
bulk_insert_tags_and_meta,
|
||||
ensure_tags_exist,
|
||||
get_reference_tags,
|
||||
list_tags_with_usage,
|
||||
remove_missing_tag_for_asset_id,
|
||||
remove_tags_from_reference,
|
||||
set_reference_tags,
|
||||
validate_tags_exist,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AddTagsResult",
|
||||
"CacheStateRow",
|
||||
"RemoveTagsResult",
|
||||
"SetTagsResult",
|
||||
"UnenrichedReferenceRow",
|
||||
"add_missing_tag_for_asset_id",
|
||||
"add_tags_to_reference",
|
||||
"asset_exists_by_hash",
|
||||
"bulk_insert_assets",
|
||||
"bulk_insert_references_ignore_conflicts",
|
||||
"bulk_insert_tags_and_meta",
|
||||
"bulk_update_enrichment_level",
|
||||
"bulk_update_is_missing",
|
||||
"bulk_update_needs_verify",
|
||||
"convert_metadata_to_rows",
|
||||
"delete_assets_by_ids",
|
||||
"delete_orphaned_seed_asset",
|
||||
"delete_reference_by_id",
|
||||
"delete_references_by_ids",
|
||||
"ensure_tags_exist",
|
||||
"fetch_reference_and_asset",
|
||||
"fetch_reference_asset_and_tags",
|
||||
"get_asset_by_hash",
|
||||
"get_existing_asset_ids",
|
||||
"get_or_create_reference",
|
||||
"get_reference_by_file_path",
|
||||
"get_reference_by_id",
|
||||
"get_reference_with_owner_check",
|
||||
"get_reference_ids_by_ids",
|
||||
"get_reference_tags",
|
||||
"get_references_by_paths_and_asset_ids",
|
||||
"get_references_for_prefixes",
|
||||
"get_unenriched_references",
|
||||
"get_unreferenced_unhashed_asset_ids",
|
||||
"insert_reference",
|
||||
"list_references_by_asset_id",
|
||||
"list_references_page",
|
||||
"list_tags_with_usage",
|
||||
"mark_references_missing_outside_prefixes",
|
||||
"reassign_asset_references",
|
||||
"reference_exists_for_asset_id",
|
||||
"remove_missing_tag_for_asset_id",
|
||||
"remove_tags_from_reference",
|
||||
"restore_references_by_paths",
|
||||
"set_reference_metadata",
|
||||
"set_reference_preview",
|
||||
"soft_delete_reference_by_id",
|
||||
"set_reference_tags",
|
||||
"update_asset_hash_and_mime",
|
||||
"update_reference_access_time",
|
||||
"update_reference_name",
|
||||
"update_reference_timestamps",
|
||||
"update_reference_updated_at",
|
||||
"upsert_asset",
|
||||
"upsert_reference",
|
||||
"validate_tags_exist",
|
||||
]
|
||||
140
app/assets/database/queries/asset.py
Normal file
140
app/assets/database/queries/asset.py
Normal file
@@ -0,0 +1,140 @@
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.dialects import sqlite
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.assets.database.models import Asset, AssetReference
|
||||
from app.assets.database.queries.common import MAX_BIND_PARAMS, calculate_rows_per_statement, iter_chunks
|
||||
|
||||
|
||||
def asset_exists_by_hash(
|
||||
session: Session,
|
||||
asset_hash: str,
|
||||
) -> bool:
|
||||
"""
|
||||
Check if an asset with a given hash exists in database.
|
||||
"""
|
||||
row = (
|
||||
session.execute(
|
||||
select(sa.literal(True))
|
||||
.select_from(Asset)
|
||||
.where(Asset.hash == asset_hash)
|
||||
.limit(1)
|
||||
)
|
||||
).first()
|
||||
return row is not None
|
||||
|
||||
|
||||
def get_asset_by_hash(
|
||||
session: Session,
|
||||
asset_hash: str,
|
||||
) -> Asset | None:
|
||||
return (
|
||||
(session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1)))
|
||||
.scalars()
|
||||
.first()
|
||||
)
|
||||
|
||||
|
||||
def upsert_asset(
|
||||
session: Session,
|
||||
asset_hash: str,
|
||||
size_bytes: int,
|
||||
mime_type: str | None = None,
|
||||
) -> tuple[Asset, bool, bool]:
|
||||
"""Upsert an Asset by hash. Returns (asset, created, updated)."""
|
||||
vals = {"hash": asset_hash, "size_bytes": int(size_bytes)}
|
||||
if mime_type:
|
||||
vals["mime_type"] = mime_type
|
||||
|
||||
ins = (
|
||||
sqlite.insert(Asset)
|
||||
.values(**vals)
|
||||
.on_conflict_do_nothing(index_elements=[Asset.hash])
|
||||
)
|
||||
res = session.execute(ins)
|
||||
created = int(res.rowcount or 0) > 0
|
||||
|
||||
asset = (
|
||||
session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))
|
||||
.scalars()
|
||||
.first()
|
||||
)
|
||||
if not asset:
|
||||
raise RuntimeError("Asset row not found after upsert.")
|
||||
|
||||
updated = False
|
||||
if not created:
|
||||
changed = False
|
||||
if asset.size_bytes != int(size_bytes) and int(size_bytes) > 0:
|
||||
asset.size_bytes = int(size_bytes)
|
||||
changed = True
|
||||
if mime_type and asset.mime_type != mime_type:
|
||||
asset.mime_type = mime_type
|
||||
changed = True
|
||||
if changed:
|
||||
updated = True
|
||||
|
||||
return asset, created, updated
|
||||
|
||||
|
||||
def bulk_insert_assets(
|
||||
session: Session,
|
||||
rows: list[dict],
|
||||
) -> None:
|
||||
"""Bulk insert Asset rows with ON CONFLICT DO NOTHING on hash."""
|
||||
if not rows:
|
||||
return
|
||||
ins = sqlite.insert(Asset).on_conflict_do_nothing(index_elements=[Asset.hash])
|
||||
for chunk in iter_chunks(rows, calculate_rows_per_statement(5)):
|
||||
session.execute(ins, chunk)
|
||||
|
||||
|
||||
def get_existing_asset_ids(
|
||||
session: Session,
|
||||
asset_ids: list[str],
|
||||
) -> set[str]:
|
||||
"""Return the subset of asset_ids that exist in the database."""
|
||||
if not asset_ids:
|
||||
return set()
|
||||
found: set[str] = set()
|
||||
for chunk in iter_chunks(asset_ids, MAX_BIND_PARAMS):
|
||||
rows = session.execute(
|
||||
select(Asset.id).where(Asset.id.in_(chunk))
|
||||
).fetchall()
|
||||
found.update(row[0] for row in rows)
|
||||
return found
|
||||
|
||||
|
||||
def update_asset_hash_and_mime(
|
||||
session: Session,
|
||||
asset_id: str,
|
||||
asset_hash: str | None = None,
|
||||
mime_type: str | None = None,
|
||||
) -> bool:
|
||||
"""Update asset hash and/or mime_type. Returns True if asset was found."""
|
||||
asset = session.get(Asset, asset_id)
|
||||
if not asset:
|
||||
return False
|
||||
if asset_hash is not None:
|
||||
asset.hash = asset_hash
|
||||
if mime_type is not None:
|
||||
asset.mime_type = mime_type
|
||||
return True
|
||||
|
||||
|
||||
def reassign_asset_references(
|
||||
session: Session,
|
||||
from_asset_id: str,
|
||||
to_asset_id: str,
|
||||
reference_id: str,
|
||||
) -> None:
|
||||
"""Reassign a reference from one asset to another.
|
||||
|
||||
Used when merging a stub asset into an existing asset with the same hash.
|
||||
"""
|
||||
ref = session.get(AssetReference, reference_id)
|
||||
if ref and ref.asset_id == from_asset_id:
|
||||
ref.asset_id = to_asset_id
|
||||
|
||||
session.flush()
|
||||
1033
app/assets/database/queries/asset_reference.py
Normal file
1033
app/assets/database/queries/asset_reference.py
Normal file
File diff suppressed because it is too large
Load Diff
54
app/assets/database/queries/common.py
Normal file
54
app/assets/database/queries/common.py
Normal file
@@ -0,0 +1,54 @@
|
||||
"""Shared utilities for database query modules."""
|
||||
|
||||
import os
|
||||
from typing import Iterable
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from app.assets.database.models import AssetReference
|
||||
from app.assets.helpers import escape_sql_like_string
|
||||
|
||||
MAX_BIND_PARAMS = 800
|
||||
|
||||
|
||||
def calculate_rows_per_statement(cols: int) -> int:
|
||||
"""Calculate how many rows can fit in one statement given column count."""
|
||||
return max(1, MAX_BIND_PARAMS // max(1, cols))
|
||||
|
||||
|
||||
def iter_chunks(seq, n: int):
|
||||
"""Yield successive n-sized chunks from seq."""
|
||||
for i in range(0, len(seq), n):
|
||||
yield seq[i : i + n]
|
||||
|
||||
|
||||
def iter_row_chunks(rows: list[dict], cols_per_row: int) -> Iterable[list[dict]]:
|
||||
"""Yield chunks of rows sized to fit within bind param limits."""
|
||||
if not rows:
|
||||
return
|
||||
yield from iter_chunks(rows, calculate_rows_per_statement(cols_per_row))
|
||||
|
||||
|
||||
def build_visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement:
|
||||
"""Build owner visibility predicate for reads.
|
||||
|
||||
Owner-less rows are visible to everyone.
|
||||
"""
|
||||
owner_id = (owner_id or "").strip()
|
||||
if owner_id == "":
|
||||
return AssetReference.owner_id == ""
|
||||
return AssetReference.owner_id.in_(["", owner_id])
|
||||
|
||||
|
||||
def build_prefix_like_conditions(
|
||||
prefixes: list[str],
|
||||
) -> list[sa.sql.ColumnElement]:
|
||||
"""Build LIKE conditions for matching file paths under directory prefixes."""
|
||||
conds = []
|
||||
for p in prefixes:
|
||||
base = os.path.abspath(p)
|
||||
if not base.endswith(os.sep):
|
||||
base += os.sep
|
||||
escaped, esc = escape_sql_like_string(base)
|
||||
conds.append(AssetReference.file_path.like(escaped + "%", escape=esc))
|
||||
return conds
|
||||
356
app/assets/database/queries/tags.py
Normal file
356
app/assets/database/queries/tags.py
Normal file
@@ -0,0 +1,356 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Iterable, Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import delete, func, select
|
||||
from sqlalchemy.dialects import sqlite
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.assets.database.models import (
|
||||
AssetReference,
|
||||
AssetReferenceMeta,
|
||||
AssetReferenceTag,
|
||||
Tag,
|
||||
)
|
||||
from app.assets.database.queries.common import (
|
||||
build_visible_owner_clause,
|
||||
iter_row_chunks,
|
||||
)
|
||||
from app.assets.helpers import escape_sql_like_string, get_utc_now, normalize_tags
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AddTagsResult:
|
||||
added: list[str]
|
||||
already_present: list[str]
|
||||
total_tags: list[str]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RemoveTagsResult:
|
||||
removed: list[str]
|
||||
not_present: list[str]
|
||||
total_tags: list[str]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SetTagsResult:
|
||||
added: list[str]
|
||||
removed: list[str]
|
||||
total: list[str]
|
||||
|
||||
|
||||
def validate_tags_exist(session: Session, tags: list[str]) -> None:
|
||||
"""Raise ValueError if any of the given tag names do not exist."""
|
||||
existing_tag_names = set(
|
||||
name
|
||||
for (name,) in session.execute(select(Tag.name).where(Tag.name.in_(tags))).all()
|
||||
)
|
||||
missing = [t for t in tags if t not in existing_tag_names]
|
||||
if missing:
|
||||
raise ValueError(f"Unknown tags: {missing}")
|
||||
|
||||
|
||||
def ensure_tags_exist(
|
||||
session: Session, names: Iterable[str], tag_type: str = "user"
|
||||
) -> None:
|
||||
wanted = normalize_tags(list(names))
|
||||
if not wanted:
|
||||
return
|
||||
rows = [{"name": n, "tag_type": tag_type} for n in list(dict.fromkeys(wanted))]
|
||||
ins = (
|
||||
sqlite.insert(Tag)
|
||||
.values(rows)
|
||||
.on_conflict_do_nothing(index_elements=[Tag.name])
|
||||
)
|
||||
session.execute(ins)
|
||||
|
||||
|
||||
def get_reference_tags(session: Session, reference_id: str) -> list[str]:
|
||||
return [
|
||||
tag_name
|
||||
for (tag_name,) in (
|
||||
session.execute(
|
||||
select(AssetReferenceTag.tag_name).where(
|
||||
AssetReferenceTag.asset_reference_id == reference_id
|
||||
)
|
||||
)
|
||||
).all()
|
||||
]
|
||||
|
||||
|
||||
def set_reference_tags(
|
||||
session: Session,
|
||||
reference_id: str,
|
||||
tags: Sequence[str],
|
||||
origin: str = "manual",
|
||||
) -> SetTagsResult:
|
||||
desired = normalize_tags(tags)
|
||||
|
||||
current = set(get_reference_tags(session, reference_id))
|
||||
|
||||
to_add = [t for t in desired if t not in current]
|
||||
to_remove = [t for t in current if t not in desired]
|
||||
|
||||
if to_add:
|
||||
ensure_tags_exist(session, to_add, tag_type="user")
|
||||
session.add_all(
|
||||
[
|
||||
AssetReferenceTag(
|
||||
asset_reference_id=reference_id,
|
||||
tag_name=t,
|
||||
origin=origin,
|
||||
added_at=get_utc_now(),
|
||||
)
|
||||
for t in to_add
|
||||
]
|
||||
)
|
||||
session.flush()
|
||||
|
||||
if to_remove:
|
||||
session.execute(
|
||||
delete(AssetReferenceTag).where(
|
||||
AssetReferenceTag.asset_reference_id == reference_id,
|
||||
AssetReferenceTag.tag_name.in_(to_remove),
|
||||
)
|
||||
)
|
||||
session.flush()
|
||||
|
||||
return SetTagsResult(added=to_add, removed=to_remove, total=desired)
|
||||
|
||||
|
||||
def add_tags_to_reference(
|
||||
session: Session,
|
||||
reference_id: str,
|
||||
tags: Sequence[str],
|
||||
origin: str = "manual",
|
||||
create_if_missing: bool = True,
|
||||
reference_row: AssetReference | None = None,
|
||||
) -> AddTagsResult:
|
||||
if not reference_row:
|
||||
ref = session.get(AssetReference, reference_id)
|
||||
if not ref:
|
||||
raise ValueError(f"AssetReference {reference_id} not found")
|
||||
|
||||
norm = normalize_tags(tags)
|
||||
if not norm:
|
||||
total = get_reference_tags(session, reference_id=reference_id)
|
||||
return AddTagsResult(added=[], already_present=[], total_tags=total)
|
||||
|
||||
if create_if_missing:
|
||||
ensure_tags_exist(session, norm, tag_type="user")
|
||||
|
||||
current = set(get_reference_tags(session, reference_id))
|
||||
|
||||
want = set(norm)
|
||||
to_add = sorted(want - current)
|
||||
|
||||
if to_add:
|
||||
with session.begin_nested() as nested:
|
||||
try:
|
||||
session.add_all(
|
||||
[
|
||||
AssetReferenceTag(
|
||||
asset_reference_id=reference_id,
|
||||
tag_name=t,
|
||||
origin=origin,
|
||||
added_at=get_utc_now(),
|
||||
)
|
||||
for t in to_add
|
||||
]
|
||||
)
|
||||
session.flush()
|
||||
except IntegrityError:
|
||||
nested.rollback()
|
||||
|
||||
after = set(get_reference_tags(session, reference_id=reference_id))
|
||||
return AddTagsResult(
|
||||
added=sorted(((after - current) & want)),
|
||||
already_present=sorted(want & current),
|
||||
total_tags=sorted(after),
|
||||
)
|
||||
|
||||
|
||||
def remove_tags_from_reference(
|
||||
session: Session,
|
||||
reference_id: str,
|
||||
tags: Sequence[str],
|
||||
) -> RemoveTagsResult:
|
||||
ref = session.get(AssetReference, reference_id)
|
||||
if not ref:
|
||||
raise ValueError(f"AssetReference {reference_id} not found")
|
||||
|
||||
norm = normalize_tags(tags)
|
||||
if not norm:
|
||||
total = get_reference_tags(session, reference_id=reference_id)
|
||||
return RemoveTagsResult(removed=[], not_present=[], total_tags=total)
|
||||
|
||||
existing = set(get_reference_tags(session, reference_id))
|
||||
|
||||
to_remove = sorted(set(t for t in norm if t in existing))
|
||||
not_present = sorted(set(t for t in norm if t not in existing))
|
||||
|
||||
if to_remove:
|
||||
session.execute(
|
||||
delete(AssetReferenceTag).where(
|
||||
AssetReferenceTag.asset_reference_id == reference_id,
|
||||
AssetReferenceTag.tag_name.in_(to_remove),
|
||||
)
|
||||
)
|
||||
session.flush()
|
||||
|
||||
total = get_reference_tags(session, reference_id=reference_id)
|
||||
return RemoveTagsResult(removed=to_remove, not_present=not_present, total_tags=total)
|
||||
|
||||
|
||||
def add_missing_tag_for_asset_id(
|
||||
session: Session,
|
||||
asset_id: str,
|
||||
origin: str = "automatic",
|
||||
) -> None:
|
||||
select_rows = (
|
||||
sa.select(
|
||||
AssetReference.id.label("asset_reference_id"),
|
||||
sa.literal("missing").label("tag_name"),
|
||||
sa.literal(origin).label("origin"),
|
||||
sa.literal(get_utc_now()).label("added_at"),
|
||||
)
|
||||
.where(AssetReference.asset_id == asset_id)
|
||||
.where(
|
||||
sa.not_(
|
||||
sa.exists().where(
|
||||
(AssetReferenceTag.asset_reference_id == AssetReference.id)
|
||||
& (AssetReferenceTag.tag_name == "missing")
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
session.execute(
|
||||
sqlite.insert(AssetReferenceTag)
|
||||
.from_select(
|
||||
["asset_reference_id", "tag_name", "origin", "added_at"],
|
||||
select_rows,
|
||||
)
|
||||
.on_conflict_do_nothing(
|
||||
index_elements=[
|
||||
AssetReferenceTag.asset_reference_id,
|
||||
AssetReferenceTag.tag_name,
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def remove_missing_tag_for_asset_id(
|
||||
session: Session,
|
||||
asset_id: str,
|
||||
) -> None:
|
||||
session.execute(
|
||||
sa.delete(AssetReferenceTag).where(
|
||||
AssetReferenceTag.asset_reference_id.in_(
|
||||
sa.select(AssetReference.id).where(AssetReference.asset_id == asset_id)
|
||||
),
|
||||
AssetReferenceTag.tag_name == "missing",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def list_tags_with_usage(
|
||||
session: Session,
|
||||
prefix: str | None = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
include_zero: bool = True,
|
||||
order: str = "count_desc",
|
||||
owner_id: str = "",
|
||||
) -> tuple[list[tuple[str, str, int]], int]:
|
||||
counts_sq = (
|
||||
select(
|
||||
AssetReferenceTag.tag_name.label("tag_name"),
|
||||
func.count(AssetReferenceTag.asset_reference_id).label("cnt"),
|
||||
)
|
||||
.select_from(AssetReferenceTag)
|
||||
.join(AssetReference, AssetReference.id == AssetReferenceTag.asset_reference_id)
|
||||
.where(build_visible_owner_clause(owner_id))
|
||||
.where(AssetReference.deleted_at.is_(None))
|
||||
.group_by(AssetReferenceTag.tag_name)
|
||||
.subquery()
|
||||
)
|
||||
|
||||
q = (
|
||||
select(
|
||||
Tag.name,
|
||||
Tag.tag_type,
|
||||
func.coalesce(counts_sq.c.cnt, 0).label("count"),
|
||||
)
|
||||
.select_from(Tag)
|
||||
.join(counts_sq, counts_sq.c.tag_name == Tag.name, isouter=True)
|
||||
)
|
||||
|
||||
if prefix:
|
||||
escaped, esc = escape_sql_like_string(prefix.strip().lower())
|
||||
q = q.where(Tag.name.like(escaped + "%", escape=esc))
|
||||
|
||||
if not include_zero:
|
||||
q = q.where(func.coalesce(counts_sq.c.cnt, 0) > 0)
|
||||
|
||||
if order == "name_asc":
|
||||
q = q.order_by(Tag.name.asc())
|
||||
else:
|
||||
q = q.order_by(func.coalesce(counts_sq.c.cnt, 0).desc(), Tag.name.asc())
|
||||
|
||||
total_q = select(func.count()).select_from(Tag)
|
||||
if prefix:
|
||||
escaped, esc = escape_sql_like_string(prefix.strip().lower())
|
||||
total_q = total_q.where(Tag.name.like(escaped + "%", escape=esc))
|
||||
if not include_zero:
|
||||
visible_tags_sq = (
|
||||
select(AssetReferenceTag.tag_name)
|
||||
.join(AssetReference, AssetReference.id == AssetReferenceTag.asset_reference_id)
|
||||
.where(build_visible_owner_clause(owner_id))
|
||||
.where(AssetReference.deleted_at.is_(None))
|
||||
.group_by(AssetReferenceTag.tag_name)
|
||||
)
|
||||
total_q = total_q.where(Tag.name.in_(visible_tags_sq))
|
||||
|
||||
rows = (session.execute(q.limit(limit).offset(offset))).all()
|
||||
total = (session.execute(total_q)).scalar_one()
|
||||
|
||||
rows_norm = [(name, ttype, int(count or 0)) for (name, ttype, count) in rows]
|
||||
return rows_norm, int(total or 0)
|
||||
|
||||
|
||||
def bulk_insert_tags_and_meta(
|
||||
session: Session,
|
||||
tag_rows: list[dict],
|
||||
meta_rows: list[dict],
|
||||
) -> None:
|
||||
"""Batch insert into asset_reference_tags and asset_reference_meta.
|
||||
|
||||
Uses ON CONFLICT DO NOTHING.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
tag_rows: Dicts with: asset_reference_id, tag_name, origin, added_at
|
||||
meta_rows: Dicts with: asset_reference_id, key, ordinal, val_*
|
||||
"""
|
||||
if tag_rows:
|
||||
ins_tags = sqlite.insert(AssetReferenceTag).on_conflict_do_nothing(
|
||||
index_elements=[
|
||||
AssetReferenceTag.asset_reference_id,
|
||||
AssetReferenceTag.tag_name,
|
||||
]
|
||||
)
|
||||
for chunk in iter_row_chunks(tag_rows, cols_per_row=4):
|
||||
session.execute(ins_tags, chunk)
|
||||
|
||||
if meta_rows:
|
||||
ins_meta = sqlite.insert(AssetReferenceMeta).on_conflict_do_nothing(
|
||||
index_elements=[
|
||||
AssetReferenceMeta.asset_reference_id,
|
||||
AssetReferenceMeta.key,
|
||||
AssetReferenceMeta.ordinal,
|
||||
]
|
||||
)
|
||||
for chunk in iter_row_chunks(meta_rows, cols_per_row=7):
|
||||
session.execute(ins_meta, chunk)
|
||||
@@ -1,62 +0,0 @@
|
||||
from typing import Iterable
|
||||
|
||||
import sqlalchemy
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.dialects import sqlite
|
||||
|
||||
from app.assets.helpers import normalize_tags, utcnow
|
||||
from app.assets.database.models import Tag, AssetInfoTag, AssetInfo
|
||||
|
||||
|
||||
def ensure_tags_exist(session: Session, names: Iterable[str], tag_type: str = "user") -> None:
|
||||
wanted = normalize_tags(list(names))
|
||||
if not wanted:
|
||||
return
|
||||
rows = [{"name": n, "tag_type": tag_type} for n in list(dict.fromkeys(wanted))]
|
||||
ins = (
|
||||
sqlite.insert(Tag)
|
||||
.values(rows)
|
||||
.on_conflict_do_nothing(index_elements=[Tag.name])
|
||||
)
|
||||
return session.execute(ins)
|
||||
|
||||
def add_missing_tag_for_asset_id(
|
||||
session: Session,
|
||||
*,
|
||||
asset_id: str,
|
||||
origin: str = "automatic",
|
||||
) -> None:
|
||||
select_rows = (
|
||||
sqlalchemy.select(
|
||||
AssetInfo.id.label("asset_info_id"),
|
||||
sqlalchemy.literal("missing").label("tag_name"),
|
||||
sqlalchemy.literal(origin).label("origin"),
|
||||
sqlalchemy.literal(utcnow()).label("added_at"),
|
||||
)
|
||||
.where(AssetInfo.asset_id == asset_id)
|
||||
.where(
|
||||
sqlalchemy.not_(
|
||||
sqlalchemy.exists().where((AssetInfoTag.asset_info_id == AssetInfo.id) & (AssetInfoTag.tag_name == "missing"))
|
||||
)
|
||||
)
|
||||
)
|
||||
session.execute(
|
||||
sqlite.insert(AssetInfoTag)
|
||||
.from_select(
|
||||
["asset_info_id", "tag_name", "origin", "added_at"],
|
||||
select_rows,
|
||||
)
|
||||
.on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name])
|
||||
)
|
||||
|
||||
def remove_missing_tag_for_asset_id(
|
||||
session: Session,
|
||||
*,
|
||||
asset_id: str,
|
||||
) -> None:
|
||||
session.execute(
|
||||
sqlalchemy.delete(AssetInfoTag).where(
|
||||
AssetInfoTag.asset_info_id.in_(sqlalchemy.select(AssetInfo.id).where(AssetInfo.asset_id == asset_id)),
|
||||
AssetInfoTag.tag_name == "missing",
|
||||
)
|
||||
)
|
||||
@@ -1,75 +0,0 @@
|
||||
from blake3 import blake3
|
||||
from typing import IO
|
||||
import os
|
||||
import asyncio
|
||||
|
||||
|
||||
DEFAULT_CHUNK = 8 * 1024 *1024 # 8MB
|
||||
|
||||
# NOTE: this allows hashing different representations of a file-like object
|
||||
def blake3_hash(
|
||||
fp: str | IO[bytes],
|
||||
chunk_size: int = DEFAULT_CHUNK,
|
||||
) -> str:
|
||||
"""
|
||||
Returns a BLAKE3 hex digest for ``fp``, which may be:
|
||||
- a filename (str/bytes) or PathLike
|
||||
- an open binary file object
|
||||
If ``fp`` is a file object, it must be opened in **binary** mode and support
|
||||
``read``, ``seek``, and ``tell``. The function will seek to the start before
|
||||
reading and will attempt to restore the original position afterward.
|
||||
"""
|
||||
# duck typing to check if input is a file-like object
|
||||
if hasattr(fp, "read"):
|
||||
return _hash_file_obj(fp, chunk_size)
|
||||
|
||||
with open(os.fspath(fp), "rb") as f:
|
||||
return _hash_file_obj(f, chunk_size)
|
||||
|
||||
|
||||
async def blake3_hash_async(
|
||||
fp: str | IO[bytes],
|
||||
chunk_size: int = DEFAULT_CHUNK,
|
||||
) -> str:
|
||||
"""Async wrapper for ``blake3_hash_sync``.
|
||||
Uses a worker thread so the event loop remains responsive.
|
||||
"""
|
||||
# If it is a path, open inside the worker thread to keep I/O off the loop.
|
||||
if hasattr(fp, "read"):
|
||||
return await asyncio.to_thread(blake3_hash, fp, chunk_size)
|
||||
|
||||
def _worker() -> str:
|
||||
with open(os.fspath(fp), "rb") as f:
|
||||
return _hash_file_obj(f, chunk_size)
|
||||
|
||||
return await asyncio.to_thread(_worker)
|
||||
|
||||
|
||||
def _hash_file_obj(file_obj: IO, chunk_size: int = DEFAULT_CHUNK) -> str:
|
||||
"""
|
||||
Hash an already-open binary file object by streaming in chunks.
|
||||
- Seeks to the beginning before reading (if supported).
|
||||
- Restores the original position afterward (if tell/seek are supported).
|
||||
"""
|
||||
if chunk_size <= 0:
|
||||
chunk_size = DEFAULT_CHUNK
|
||||
|
||||
# in case file object is already open and not at the beginning, track so can be restored after hashing
|
||||
orig_pos = file_obj.tell()
|
||||
|
||||
try:
|
||||
# seek to the beginning before reading
|
||||
if orig_pos != 0:
|
||||
file_obj.seek(0)
|
||||
|
||||
h = blake3()
|
||||
while True:
|
||||
chunk = file_obj.read(chunk_size)
|
||||
if not chunk:
|
||||
break
|
||||
h.update(chunk)
|
||||
return h.hexdigest()
|
||||
finally:
|
||||
# restore original position in file object, if needed
|
||||
if orig_pos != 0:
|
||||
file_obj.seek(orig_pos)
|
||||
@@ -1,226 +1,42 @@
|
||||
import contextlib
|
||||
import os
|
||||
from decimal import Decimal
|
||||
from aiohttp import web
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Literal, Any
|
||||
|
||||
import folder_paths
|
||||
from typing import Sequence
|
||||
|
||||
|
||||
RootType = Literal["models", "input", "output"]
|
||||
ALLOWED_ROOTS: tuple[RootType, ...] = ("models", "input", "output")
|
||||
|
||||
def get_query_dict(request: web.Request) -> dict[str, Any]:
|
||||
def select_best_live_path(states: Sequence) -> str:
|
||||
"""
|
||||
Gets a dictionary of query parameters from the request.
|
||||
|
||||
'request.query' is a MultiMapping[str], needs to be converted to a dictionary to be validated by Pydantic.
|
||||
Return the best on-disk path among cache states:
|
||||
1) Prefer a path that exists with needs_verify == False (already verified).
|
||||
2) Otherwise, pick the first path that exists.
|
||||
3) Otherwise return empty string.
|
||||
"""
|
||||
query_dict = {
|
||||
key: request.query.getall(key) if len(request.query.getall(key)) > 1 else request.query.get(key)
|
||||
for key in request.query.keys()
|
||||
}
|
||||
return query_dict
|
||||
alive = [
|
||||
s
|
||||
for s in states
|
||||
if getattr(s, "file_path", None) and os.path.isfile(s.file_path)
|
||||
]
|
||||
if not alive:
|
||||
return ""
|
||||
for s in alive:
|
||||
if not getattr(s, "needs_verify", False):
|
||||
return s.file_path
|
||||
return alive[0].file_path
|
||||
|
||||
def list_tree(base_dir: str) -> list[str]:
|
||||
out: list[str] = []
|
||||
base_abs = os.path.abspath(base_dir)
|
||||
if not os.path.isdir(base_abs):
|
||||
return out
|
||||
for dirpath, _subdirs, filenames in os.walk(base_abs, topdown=True, followlinks=False):
|
||||
for name in filenames:
|
||||
out.append(os.path.abspath(os.path.join(dirpath, name)))
|
||||
return out
|
||||
|
||||
def prefixes_for_root(root: RootType) -> list[str]:
|
||||
if root == "models":
|
||||
bases: list[str] = []
|
||||
for _bucket, paths in get_comfy_models_folders():
|
||||
bases.extend(paths)
|
||||
return [os.path.abspath(p) for p in bases]
|
||||
if root == "input":
|
||||
return [os.path.abspath(folder_paths.get_input_directory())]
|
||||
if root == "output":
|
||||
return [os.path.abspath(folder_paths.get_output_directory())]
|
||||
return []
|
||||
def escape_sql_like_string(s: str, escape: str = "!") -> tuple[str, str]:
|
||||
"""Escapes %, _ and the escape char in a LIKE prefix.
|
||||
|
||||
def escape_like_prefix(s: str, escape: str = "!") -> tuple[str, str]:
|
||||
"""Escapes %, _ and the escape char itself in a LIKE prefix.
|
||||
Returns (escaped_prefix, escape_char). Caller should append '%' and pass escape=escape_char to .like().
|
||||
Returns (escaped_prefix, escape_char).
|
||||
"""
|
||||
s = s.replace(escape, escape + escape) # escape the escape char first
|
||||
s = s.replace("%", escape + "%").replace("_", escape + "_") # escape LIKE wildcards
|
||||
return s, escape
|
||||
|
||||
def fast_asset_file_check(
|
||||
*,
|
||||
mtime_db: int | None,
|
||||
size_db: int | None,
|
||||
stat_result: os.stat_result,
|
||||
) -> bool:
|
||||
if mtime_db is None:
|
||||
return False
|
||||
actual_mtime_ns = getattr(stat_result, "st_mtime_ns", int(stat_result.st_mtime * 1_000_000_000))
|
||||
if int(mtime_db) != int(actual_mtime_ns):
|
||||
return False
|
||||
sz = int(size_db or 0)
|
||||
if sz > 0:
|
||||
return int(stat_result.st_size) == sz
|
||||
return True
|
||||
|
||||
def utcnow() -> datetime:
|
||||
def get_utc_now() -> datetime:
|
||||
"""Naive UTC timestamp (no tzinfo). We always treat DB datetimes as UTC."""
|
||||
return datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
|
||||
def get_comfy_models_folders() -> list[tuple[str, list[str]]]:
|
||||
"""Build a list of (folder_name, base_paths[]) categories that are configured for model locations.
|
||||
|
||||
We trust `folder_paths.folder_names_and_paths` and include a category if
|
||||
*any* of its base paths lies under the Comfy `models_dir`.
|
||||
"""
|
||||
targets: list[tuple[str, list[str]]] = []
|
||||
models_root = os.path.abspath(folder_paths.models_dir)
|
||||
for name, values in folder_paths.folder_names_and_paths.items():
|
||||
paths, _exts = values[0], values[1] # NOTE: this prevents nodepacks that hackily edit folder_... from breaking ComfyUI
|
||||
if any(os.path.abspath(p).startswith(models_root + os.sep) for p in paths):
|
||||
targets.append((name, paths))
|
||||
return targets
|
||||
|
||||
def resolve_destination_from_tags(tags: list[str]) -> tuple[str, list[str]]:
|
||||
"""Validates and maps tags -> (base_dir, subdirs_for_fs)"""
|
||||
root = tags[0]
|
||||
if root == "models":
|
||||
if len(tags) < 2:
|
||||
raise ValueError("at least two tags required for model asset")
|
||||
try:
|
||||
bases = folder_paths.folder_names_and_paths[tags[1]][0]
|
||||
except KeyError:
|
||||
raise ValueError(f"unknown model category '{tags[1]}'")
|
||||
if not bases:
|
||||
raise ValueError(f"no base path configured for category '{tags[1]}'")
|
||||
base_dir = os.path.abspath(bases[0])
|
||||
raw_subdirs = tags[2:]
|
||||
else:
|
||||
base_dir = os.path.abspath(
|
||||
folder_paths.get_input_directory() if root == "input" else folder_paths.get_output_directory()
|
||||
)
|
||||
raw_subdirs = tags[1:]
|
||||
for i in raw_subdirs:
|
||||
if i in (".", ".."):
|
||||
raise ValueError("invalid path component in tags")
|
||||
|
||||
return base_dir, raw_subdirs if raw_subdirs else []
|
||||
|
||||
def ensure_within_base(candidate: str, base: str) -> None:
|
||||
cand_abs = os.path.abspath(candidate)
|
||||
base_abs = os.path.abspath(base)
|
||||
try:
|
||||
if os.path.commonpath([cand_abs, base_abs]) != base_abs:
|
||||
raise ValueError("destination escapes base directory")
|
||||
except Exception:
|
||||
raise ValueError("invalid destination path")
|
||||
|
||||
def compute_relative_filename(file_path: str) -> str | None:
|
||||
"""
|
||||
Return the model's path relative to the last well-known folder (the model category),
|
||||
using forward slashes, eg:
|
||||
/.../models/checkpoints/flux/123/flux.safetensors -> "flux/123/flux.safetensors"
|
||||
/.../models/text_encoders/clip_g.safetensors -> "clip_g.safetensors"
|
||||
|
||||
For non-model paths, returns None.
|
||||
NOTE: this is a temporary helper, used only for initializing metadata["filename"] field.
|
||||
"""
|
||||
try:
|
||||
root_category, rel_path = get_relative_to_root_category_path_of_asset(file_path)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
p = Path(rel_path)
|
||||
parts = [seg for seg in p.parts if seg not in (".", "..", p.anchor)]
|
||||
if not parts:
|
||||
return None
|
||||
|
||||
if root_category == "models":
|
||||
# parts[0] is the category ("checkpoints", "vae", etc) – drop it
|
||||
inside = parts[1:] if len(parts) > 1 else [parts[0]]
|
||||
return "/".join(inside)
|
||||
return "/".join(parts) # input/output: keep all parts
|
||||
|
||||
def get_relative_to_root_category_path_of_asset(file_path: str) -> tuple[Literal["input", "output", "models"], str]:
|
||||
"""Given an absolute or relative file path, determine which root category the path belongs to:
|
||||
- 'input' if the file resides under `folder_paths.get_input_directory()`
|
||||
- 'output' if the file resides under `folder_paths.get_output_directory()`
|
||||
- 'models' if the file resides under any base path of categories returned by `get_comfy_models_folders()`
|
||||
|
||||
Returns:
|
||||
(root_category, relative_path_inside_that_root)
|
||||
For 'models', the relative path is prefixed with the category name:
|
||||
e.g. ('models', 'vae/test/sub/ae.safetensors')
|
||||
|
||||
Raises:
|
||||
ValueError: if the path does not belong to input, output, or configured model bases.
|
||||
"""
|
||||
fp_abs = os.path.abspath(file_path)
|
||||
|
||||
def _is_within(child: str, parent: str) -> bool:
|
||||
try:
|
||||
return os.path.commonpath([child, parent]) == parent
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def _rel(child: str, parent: str) -> str:
|
||||
return os.path.relpath(os.path.join(os.sep, os.path.relpath(child, parent)), os.sep)
|
||||
|
||||
# 1) input
|
||||
input_base = os.path.abspath(folder_paths.get_input_directory())
|
||||
if _is_within(fp_abs, input_base):
|
||||
return "input", _rel(fp_abs, input_base)
|
||||
|
||||
# 2) output
|
||||
output_base = os.path.abspath(folder_paths.get_output_directory())
|
||||
if _is_within(fp_abs, output_base):
|
||||
return "output", _rel(fp_abs, output_base)
|
||||
|
||||
# 3) models (check deepest matching base to avoid ambiguity)
|
||||
best: tuple[int, str, str] | None = None # (base_len, bucket, rel_inside_bucket)
|
||||
for bucket, bases in get_comfy_models_folders():
|
||||
for b in bases:
|
||||
base_abs = os.path.abspath(b)
|
||||
if not _is_within(fp_abs, base_abs):
|
||||
continue
|
||||
cand = (len(base_abs), bucket, _rel(fp_abs, base_abs))
|
||||
if best is None or cand[0] > best[0]:
|
||||
best = cand
|
||||
|
||||
if best is not None:
|
||||
_, bucket, rel_inside = best
|
||||
combined = os.path.join(bucket, rel_inside)
|
||||
return "models", os.path.relpath(os.path.join(os.sep, combined), os.sep)
|
||||
|
||||
raise ValueError(f"Path is not within input, output, or configured model bases: {file_path}")
|
||||
|
||||
def get_name_and_tags_from_asset_path(file_path: str) -> tuple[str, list[str]]:
|
||||
"""Return a tuple (name, tags) derived from a filesystem path.
|
||||
|
||||
Semantics:
|
||||
- Root category is determined by `get_relative_to_root_category_path_of_asset`.
|
||||
- The returned `name` is the base filename with extension from the relative path.
|
||||
- The returned `tags` are:
|
||||
[root_category] + parent folders of the relative path (in order)
|
||||
For 'models', this means:
|
||||
file '/.../ModelsDir/vae/test_tag/ae.safetensors'
|
||||
-> root_category='models', some_path='vae/test_tag/ae.safetensors'
|
||||
-> name='ae.safetensors', tags=['models', 'vae', 'test_tag']
|
||||
|
||||
Raises:
|
||||
ValueError: if the path does not belong to input, output, or configured model bases.
|
||||
"""
|
||||
root_category, some_path = get_relative_to_root_category_path_of_asset(file_path)
|
||||
p = Path(some_path)
|
||||
parent_parts = [part for part in p.parent.parts if part not in (".", "..", p.anchor)]
|
||||
return p.name, list(dict.fromkeys(normalize_tags([root_category, *parent_parts])))
|
||||
|
||||
def normalize_tags(tags: list[str] | None) -> list[str]:
|
||||
"""
|
||||
@@ -228,85 +44,22 @@ def normalize_tags(tags: list[str] | None) -> list[str]:
|
||||
- Stripping whitespace and converting to lowercase.
|
||||
- Removing duplicates.
|
||||
"""
|
||||
return [t.strip().lower() for t in (tags or []) if (t or "").strip()]
|
||||
return list(dict.fromkeys(t.strip().lower() for t in (tags or []) if (t or "").strip()))
|
||||
|
||||
def collect_models_files() -> list[str]:
|
||||
out: list[str] = []
|
||||
for folder_name, bases in get_comfy_models_folders():
|
||||
rel_files = folder_paths.get_filename_list(folder_name) or []
|
||||
for rel_path in rel_files:
|
||||
abs_path = folder_paths.get_full_path(folder_name, rel_path)
|
||||
if not abs_path:
|
||||
continue
|
||||
abs_path = os.path.abspath(abs_path)
|
||||
allowed = False
|
||||
for b in bases:
|
||||
base_abs = os.path.abspath(b)
|
||||
with contextlib.suppress(Exception):
|
||||
if os.path.commonpath([abs_path, base_abs]) == base_abs:
|
||||
allowed = True
|
||||
break
|
||||
if allowed:
|
||||
out.append(abs_path)
|
||||
return out
|
||||
|
||||
def is_scalar(v):
|
||||
if v is None:
|
||||
return True
|
||||
if isinstance(v, bool):
|
||||
return True
|
||||
if isinstance(v, (int, float, Decimal, str)):
|
||||
return True
|
||||
return False
|
||||
def validate_blake3_hash(s: str) -> str:
|
||||
"""Validate and normalize a blake3 hash string.
|
||||
|
||||
def project_kv(key: str, value):
|
||||
Returns canonical 'blake3:<hex>' or raises ValueError.
|
||||
"""
|
||||
Turn a metadata key/value into typed projection rows.
|
||||
Returns list[dict] with keys:
|
||||
key, ordinal, and one of val_str / val_num / val_bool / val_json (others None)
|
||||
"""
|
||||
rows: list[dict] = []
|
||||
|
||||
def _null_row(ordinal: int) -> dict:
|
||||
return {
|
||||
"key": key, "ordinal": ordinal,
|
||||
"val_str": None, "val_num": None, "val_bool": None, "val_json": None
|
||||
}
|
||||
|
||||
if value is None:
|
||||
rows.append(_null_row(0))
|
||||
return rows
|
||||
|
||||
if is_scalar(value):
|
||||
if isinstance(value, bool):
|
||||
rows.append({"key": key, "ordinal": 0, "val_bool": bool(value)})
|
||||
elif isinstance(value, (int, float, Decimal)):
|
||||
num = value if isinstance(value, Decimal) else Decimal(str(value))
|
||||
rows.append({"key": key, "ordinal": 0, "val_num": num})
|
||||
elif isinstance(value, str):
|
||||
rows.append({"key": key, "ordinal": 0, "val_str": value})
|
||||
else:
|
||||
rows.append({"key": key, "ordinal": 0, "val_json": value})
|
||||
return rows
|
||||
|
||||
if isinstance(value, list):
|
||||
if all(is_scalar(x) for x in value):
|
||||
for i, x in enumerate(value):
|
||||
if x is None:
|
||||
rows.append(_null_row(i))
|
||||
elif isinstance(x, bool):
|
||||
rows.append({"key": key, "ordinal": i, "val_bool": bool(x)})
|
||||
elif isinstance(x, (int, float, Decimal)):
|
||||
num = x if isinstance(x, Decimal) else Decimal(str(x))
|
||||
rows.append({"key": key, "ordinal": i, "val_num": num})
|
||||
elif isinstance(x, str):
|
||||
rows.append({"key": key, "ordinal": i, "val_str": x})
|
||||
else:
|
||||
rows.append({"key": key, "ordinal": i, "val_json": x})
|
||||
return rows
|
||||
for i, x in enumerate(value):
|
||||
rows.append({"key": key, "ordinal": i, "val_json": x})
|
||||
return rows
|
||||
|
||||
rows.append({"key": key, "ordinal": 0, "val_json": value})
|
||||
return rows
|
||||
s = s.strip().lower()
|
||||
if not s or ":" not in s:
|
||||
raise ValueError("hash must be 'blake3:<hex>'")
|
||||
algo, digest = s.split(":", 1)
|
||||
if (
|
||||
algo != "blake3"
|
||||
or len(digest) != 64
|
||||
or any(c for c in digest if c not in "0123456789abcdef")
|
||||
):
|
||||
raise ValueError("hash must be 'blake3:<hex>'")
|
||||
return f"{algo}:{digest}"
|
||||
|
||||
@@ -1,516 +0,0 @@
|
||||
import os
|
||||
import mimetypes
|
||||
import contextlib
|
||||
from typing import Sequence
|
||||
|
||||
from app.database.db import create_session
|
||||
from app.assets.api import schemas_out, schemas_in
|
||||
from app.assets.database.queries import (
|
||||
asset_exists_by_hash,
|
||||
asset_info_exists_for_asset_id,
|
||||
get_asset_by_hash,
|
||||
get_asset_info_by_id,
|
||||
fetch_asset_info_asset_and_tags,
|
||||
fetch_asset_info_and_asset,
|
||||
create_asset_info_for_existing_asset,
|
||||
touch_asset_info_by_id,
|
||||
update_asset_info_full,
|
||||
delete_asset_info_by_id,
|
||||
list_cache_states_by_asset_id,
|
||||
list_asset_infos_page,
|
||||
list_tags_with_usage,
|
||||
get_asset_tags,
|
||||
add_tags_to_asset_info,
|
||||
remove_tags_from_asset_info,
|
||||
pick_best_live_path,
|
||||
ingest_fs_asset,
|
||||
set_asset_info_preview,
|
||||
)
|
||||
from app.assets.helpers import resolve_destination_from_tags, ensure_within_base
|
||||
from app.assets.database.models import Asset
|
||||
|
||||
|
||||
def _safe_sort_field(requested: str | None) -> str:
|
||||
if not requested:
|
||||
return "created_at"
|
||||
v = requested.lower()
|
||||
if v in {"name", "created_at", "updated_at", "size", "last_access_time"}:
|
||||
return v
|
||||
return "created_at"
|
||||
|
||||
|
||||
def _get_size_mtime_ns(path: str) -> tuple[int, int]:
|
||||
st = os.stat(path, follow_symlinks=True)
|
||||
return st.st_size, getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000))
|
||||
|
||||
|
||||
def _safe_filename(name: str | None, fallback: str) -> str:
|
||||
n = os.path.basename((name or "").strip() or fallback)
|
||||
if n:
|
||||
return n
|
||||
return fallback
|
||||
|
||||
|
||||
def asset_exists(*, asset_hash: str) -> bool:
|
||||
"""
|
||||
Check if an asset with a given hash exists in database.
|
||||
"""
|
||||
with create_session() as session:
|
||||
return asset_exists_by_hash(session, asset_hash=asset_hash)
|
||||
|
||||
|
||||
def list_assets(
|
||||
*,
|
||||
include_tags: Sequence[str] | None = None,
|
||||
exclude_tags: Sequence[str] | None = None,
|
||||
name_contains: str | None = None,
|
||||
metadata_filter: dict | None = None,
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
sort: str = "created_at",
|
||||
order: str = "desc",
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.AssetsList:
|
||||
sort = _safe_sort_field(sort)
|
||||
order = "desc" if (order or "desc").lower() not in {"asc", "desc"} else order.lower()
|
||||
|
||||
with create_session() as session:
|
||||
infos, tag_map, total = list_asset_infos_page(
|
||||
session,
|
||||
owner_id=owner_id,
|
||||
include_tags=include_tags,
|
||||
exclude_tags=exclude_tags,
|
||||
name_contains=name_contains,
|
||||
metadata_filter=metadata_filter,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
sort=sort,
|
||||
order=order,
|
||||
)
|
||||
|
||||
summaries: list[schemas_out.AssetSummary] = []
|
||||
for info in infos:
|
||||
asset = info.asset
|
||||
tags = tag_map.get(info.id, [])
|
||||
summaries.append(
|
||||
schemas_out.AssetSummary(
|
||||
id=info.id,
|
||||
name=info.name,
|
||||
asset_hash=asset.hash if asset else None,
|
||||
size=int(asset.size_bytes) if asset else None,
|
||||
mime_type=asset.mime_type if asset else None,
|
||||
tags=tags,
|
||||
created_at=info.created_at,
|
||||
updated_at=info.updated_at,
|
||||
last_access_time=info.last_access_time,
|
||||
)
|
||||
)
|
||||
|
||||
return schemas_out.AssetsList(
|
||||
assets=summaries,
|
||||
total=total,
|
||||
has_more=(offset + len(summaries)) < total,
|
||||
)
|
||||
|
||||
|
||||
def get_asset(
|
||||
*,
|
||||
asset_info_id: str,
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.AssetDetail:
|
||||
with create_session() as session:
|
||||
res = fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id)
|
||||
if not res:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
info, asset, tag_names = res
|
||||
preview_id = info.preview_id
|
||||
|
||||
return schemas_out.AssetDetail(
|
||||
id=info.id,
|
||||
name=info.name,
|
||||
asset_hash=asset.hash if asset else None,
|
||||
size=int(asset.size_bytes) if asset and asset.size_bytes is not None else None,
|
||||
mime_type=asset.mime_type if asset else None,
|
||||
tags=tag_names,
|
||||
user_metadata=info.user_metadata or {},
|
||||
preview_id=preview_id,
|
||||
created_at=info.created_at,
|
||||
last_access_time=info.last_access_time,
|
||||
)
|
||||
|
||||
|
||||
def resolve_asset_content_for_download(
|
||||
*,
|
||||
asset_info_id: str,
|
||||
owner_id: str = "",
|
||||
) -> tuple[str, str, str]:
|
||||
with create_session() as session:
|
||||
pair = fetch_asset_info_and_asset(session, asset_info_id=asset_info_id, owner_id=owner_id)
|
||||
if not pair:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
|
||||
info, asset = pair
|
||||
states = list_cache_states_by_asset_id(session, asset_id=asset.id)
|
||||
abs_path = pick_best_live_path(states)
|
||||
if not abs_path:
|
||||
raise FileNotFoundError
|
||||
|
||||
touch_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||
session.commit()
|
||||
|
||||
ctype = asset.mime_type or mimetypes.guess_type(info.name or abs_path)[0] or "application/octet-stream"
|
||||
download_name = info.name or os.path.basename(abs_path)
|
||||
return abs_path, ctype, download_name
|
||||
|
||||
|
||||
def upload_asset_from_temp_path(
|
||||
spec: schemas_in.UploadAssetSpec,
|
||||
*,
|
||||
temp_path: str,
|
||||
client_filename: str | None = None,
|
||||
owner_id: str = "",
|
||||
expected_asset_hash: str | None = None,
|
||||
) -> schemas_out.AssetCreated:
|
||||
"""
|
||||
Create new asset or update existing asset from a temporary file path.
|
||||
"""
|
||||
try:
|
||||
# NOTE: blake3 is not required right now, so this will fail if blake3 is not installed in local environment
|
||||
import app.assets.hashing as hashing
|
||||
digest = hashing.blake3_hash(temp_path)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"failed to hash uploaded file: {e}")
|
||||
asset_hash = "blake3:" + digest
|
||||
|
||||
if expected_asset_hash and asset_hash != expected_asset_hash.strip().lower():
|
||||
raise ValueError("HASH_MISMATCH")
|
||||
|
||||
with create_session() as session:
|
||||
existing = get_asset_by_hash(session, asset_hash=asset_hash)
|
||||
if existing is not None:
|
||||
with contextlib.suppress(Exception):
|
||||
if temp_path and os.path.exists(temp_path):
|
||||
os.remove(temp_path)
|
||||
|
||||
display_name = _safe_filename(spec.name or (client_filename or ""), fallback=digest)
|
||||
info = create_asset_info_for_existing_asset(
|
||||
session,
|
||||
asset_hash=asset_hash,
|
||||
name=display_name,
|
||||
user_metadata=spec.user_metadata or {},
|
||||
tags=spec.tags or [],
|
||||
tag_origin="manual",
|
||||
owner_id=owner_id,
|
||||
)
|
||||
tag_names = get_asset_tags(session, asset_info_id=info.id)
|
||||
session.commit()
|
||||
|
||||
return schemas_out.AssetCreated(
|
||||
id=info.id,
|
||||
name=info.name,
|
||||
asset_hash=existing.hash,
|
||||
size=int(existing.size_bytes) if existing.size_bytes is not None else None,
|
||||
mime_type=existing.mime_type,
|
||||
tags=tag_names,
|
||||
user_metadata=info.user_metadata or {},
|
||||
preview_id=info.preview_id,
|
||||
created_at=info.created_at,
|
||||
last_access_time=info.last_access_time,
|
||||
created_new=False,
|
||||
)
|
||||
|
||||
base_dir, subdirs = resolve_destination_from_tags(spec.tags)
|
||||
dest_dir = os.path.join(base_dir, *subdirs) if subdirs else base_dir
|
||||
os.makedirs(dest_dir, exist_ok=True)
|
||||
|
||||
src_for_ext = (client_filename or spec.name or "").strip()
|
||||
_ext = os.path.splitext(os.path.basename(src_for_ext))[1] if src_for_ext else ""
|
||||
ext = _ext if 0 < len(_ext) <= 16 else ""
|
||||
hashed_basename = f"{digest}{ext}"
|
||||
dest_abs = os.path.abspath(os.path.join(dest_dir, hashed_basename))
|
||||
ensure_within_base(dest_abs, base_dir)
|
||||
|
||||
content_type = (
|
||||
mimetypes.guess_type(os.path.basename(src_for_ext), strict=False)[0]
|
||||
or mimetypes.guess_type(hashed_basename, strict=False)[0]
|
||||
or "application/octet-stream"
|
||||
)
|
||||
|
||||
try:
|
||||
os.replace(temp_path, dest_abs)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"failed to move uploaded file into place: {e}")
|
||||
|
||||
try:
|
||||
size_bytes, mtime_ns = _get_size_mtime_ns(dest_abs)
|
||||
except OSError as e:
|
||||
raise RuntimeError(f"failed to stat destination file: {e}")
|
||||
|
||||
with create_session() as session:
|
||||
result = ingest_fs_asset(
|
||||
session,
|
||||
asset_hash=asset_hash,
|
||||
abs_path=dest_abs,
|
||||
size_bytes=size_bytes,
|
||||
mtime_ns=mtime_ns,
|
||||
mime_type=content_type,
|
||||
info_name=_safe_filename(spec.name or (client_filename or ""), fallback=digest),
|
||||
owner_id=owner_id,
|
||||
preview_id=None,
|
||||
user_metadata=spec.user_metadata or {},
|
||||
tags=spec.tags,
|
||||
tag_origin="manual",
|
||||
require_existing_tags=False,
|
||||
)
|
||||
info_id = result["asset_info_id"]
|
||||
if not info_id:
|
||||
raise RuntimeError("failed to create asset metadata")
|
||||
|
||||
pair = fetch_asset_info_and_asset(session, asset_info_id=info_id, owner_id=owner_id)
|
||||
if not pair:
|
||||
raise RuntimeError("inconsistent DB state after ingest")
|
||||
info, asset = pair
|
||||
tag_names = get_asset_tags(session, asset_info_id=info.id)
|
||||
created_result = schemas_out.AssetCreated(
|
||||
id=info.id,
|
||||
name=info.name,
|
||||
asset_hash=asset.hash,
|
||||
size=int(asset.size_bytes),
|
||||
mime_type=asset.mime_type,
|
||||
tags=tag_names,
|
||||
user_metadata=info.user_metadata or {},
|
||||
preview_id=info.preview_id,
|
||||
created_at=info.created_at,
|
||||
last_access_time=info.last_access_time,
|
||||
created_new=result["asset_created"],
|
||||
)
|
||||
session.commit()
|
||||
|
||||
return created_result
|
||||
|
||||
|
||||
def update_asset(
|
||||
*,
|
||||
asset_info_id: str,
|
||||
name: str | None = None,
|
||||
tags: list[str] | None = None,
|
||||
user_metadata: dict | None = None,
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.AssetUpdated:
|
||||
with create_session() as session:
|
||||
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||
if not info_row:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
if info_row.owner_id and info_row.owner_id != owner_id:
|
||||
raise PermissionError("not owner")
|
||||
|
||||
info = update_asset_info_full(
|
||||
session,
|
||||
asset_info_id=asset_info_id,
|
||||
name=name,
|
||||
tags=tags,
|
||||
user_metadata=user_metadata,
|
||||
tag_origin="manual",
|
||||
asset_info_row=info_row,
|
||||
)
|
||||
|
||||
tag_names = get_asset_tags(session, asset_info_id=asset_info_id)
|
||||
result = schemas_out.AssetUpdated(
|
||||
id=info.id,
|
||||
name=info.name,
|
||||
asset_hash=info.asset.hash if info.asset else None,
|
||||
tags=tag_names,
|
||||
user_metadata=info.user_metadata or {},
|
||||
updated_at=info.updated_at,
|
||||
)
|
||||
session.commit()
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def set_asset_preview(
|
||||
*,
|
||||
asset_info_id: str,
|
||||
preview_asset_id: str | None = None,
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.AssetDetail:
|
||||
with create_session() as session:
|
||||
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||
if not info_row:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
if info_row.owner_id and info_row.owner_id != owner_id:
|
||||
raise PermissionError("not owner")
|
||||
|
||||
set_asset_info_preview(
|
||||
session,
|
||||
asset_info_id=asset_info_id,
|
||||
preview_asset_id=preview_asset_id,
|
||||
)
|
||||
|
||||
res = fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id)
|
||||
if not res:
|
||||
raise RuntimeError("State changed during preview update")
|
||||
info, asset, tags = res
|
||||
result = schemas_out.AssetDetail(
|
||||
id=info.id,
|
||||
name=info.name,
|
||||
asset_hash=asset.hash if asset else None,
|
||||
size=int(asset.size_bytes) if asset and asset.size_bytes is not None else None,
|
||||
mime_type=asset.mime_type if asset else None,
|
||||
tags=tags,
|
||||
user_metadata=info.user_metadata or {},
|
||||
preview_id=info.preview_id,
|
||||
created_at=info.created_at,
|
||||
last_access_time=info.last_access_time,
|
||||
)
|
||||
session.commit()
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def delete_asset_reference(*, asset_info_id: str, owner_id: str, delete_content_if_orphan: bool = True) -> bool:
|
||||
with create_session() as session:
|
||||
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||
asset_id = info_row.asset_id if info_row else None
|
||||
deleted = delete_asset_info_by_id(session, asset_info_id=asset_info_id, owner_id=owner_id)
|
||||
if not deleted:
|
||||
session.commit()
|
||||
return False
|
||||
|
||||
if not delete_content_if_orphan or not asset_id:
|
||||
session.commit()
|
||||
return True
|
||||
|
||||
still_exists = asset_info_exists_for_asset_id(session, asset_id=asset_id)
|
||||
if still_exists:
|
||||
session.commit()
|
||||
return True
|
||||
|
||||
states = list_cache_states_by_asset_id(session, asset_id=asset_id)
|
||||
file_paths = [s.file_path for s in (states or []) if getattr(s, "file_path", None)]
|
||||
|
||||
asset_row = session.get(Asset, asset_id)
|
||||
if asset_row is not None:
|
||||
session.delete(asset_row)
|
||||
|
||||
session.commit()
|
||||
for p in file_paths:
|
||||
with contextlib.suppress(Exception):
|
||||
if p and os.path.isfile(p):
|
||||
os.remove(p)
|
||||
return True
|
||||
|
||||
|
||||
def create_asset_from_hash(
|
||||
*,
|
||||
hash_str: str,
|
||||
name: str,
|
||||
tags: list[str] | None = None,
|
||||
user_metadata: dict | None = None,
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.AssetCreated | None:
|
||||
canonical = hash_str.strip().lower()
|
||||
with create_session() as session:
|
||||
asset = get_asset_by_hash(session, asset_hash=canonical)
|
||||
if not asset:
|
||||
return None
|
||||
|
||||
info = create_asset_info_for_existing_asset(
|
||||
session,
|
||||
asset_hash=canonical,
|
||||
name=_safe_filename(name, fallback=canonical.split(":", 1)[1]),
|
||||
user_metadata=user_metadata or {},
|
||||
tags=tags or [],
|
||||
tag_origin="manual",
|
||||
owner_id=owner_id,
|
||||
)
|
||||
tag_names = get_asset_tags(session, asset_info_id=info.id)
|
||||
result = schemas_out.AssetCreated(
|
||||
id=info.id,
|
||||
name=info.name,
|
||||
asset_hash=asset.hash,
|
||||
size=int(asset.size_bytes),
|
||||
mime_type=asset.mime_type,
|
||||
tags=tag_names,
|
||||
user_metadata=info.user_metadata or {},
|
||||
preview_id=info.preview_id,
|
||||
created_at=info.created_at,
|
||||
last_access_time=info.last_access_time,
|
||||
created_new=False,
|
||||
)
|
||||
session.commit()
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def add_tags_to_asset(
|
||||
*,
|
||||
asset_info_id: str,
|
||||
tags: list[str],
|
||||
origin: str = "manual",
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.TagsAdd:
|
||||
with create_session() as session:
|
||||
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||
if not info_row:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
if info_row.owner_id and info_row.owner_id != owner_id:
|
||||
raise PermissionError("not owner")
|
||||
data = add_tags_to_asset_info(
|
||||
session,
|
||||
asset_info_id=asset_info_id,
|
||||
tags=tags,
|
||||
origin=origin,
|
||||
create_if_missing=True,
|
||||
asset_info_row=info_row,
|
||||
)
|
||||
session.commit()
|
||||
return schemas_out.TagsAdd(**data)
|
||||
|
||||
|
||||
def remove_tags_from_asset(
|
||||
*,
|
||||
asset_info_id: str,
|
||||
tags: list[str],
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.TagsRemove:
|
||||
with create_session() as session:
|
||||
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||
if not info_row:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
if info_row.owner_id and info_row.owner_id != owner_id:
|
||||
raise PermissionError("not owner")
|
||||
|
||||
data = remove_tags_from_asset_info(
|
||||
session,
|
||||
asset_info_id=asset_info_id,
|
||||
tags=tags,
|
||||
)
|
||||
session.commit()
|
||||
return schemas_out.TagsRemove(**data)
|
||||
|
||||
|
||||
def list_tags(
|
||||
prefix: str | None = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
order: str = "count_desc",
|
||||
include_zero: bool = True,
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.TagsList:
|
||||
limit = max(1, min(1000, limit))
|
||||
offset = max(0, offset)
|
||||
|
||||
with create_session() as session:
|
||||
rows, total = list_tags_with_usage(
|
||||
session,
|
||||
prefix=prefix,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
include_zero=include_zero,
|
||||
order=order,
|
||||
owner_id=owner_id,
|
||||
)
|
||||
|
||||
tags = [schemas_out.TagUsage(name=name, count=count, type=tag_type) for (name, tag_type, count) in rows]
|
||||
return schemas_out.TagsList(tags=tags, total=total, has_more=(offset + len(tags)) < total)
|
||||
@@ -1,263 +1,567 @@
|
||||
import contextlib
|
||||
import time
|
||||
import logging
|
||||
import os
|
||||
import sqlalchemy
|
||||
from pathlib import Path
|
||||
from typing import Callable, Literal, TypedDict
|
||||
|
||||
import folder_paths
|
||||
from app.database.db import create_session, dependencies_available
|
||||
from app.assets.helpers import (
|
||||
collect_models_files, compute_relative_filename, fast_asset_file_check, get_name_and_tags_from_asset_path,
|
||||
list_tree,prefixes_for_root, escape_like_prefix,
|
||||
RootType
|
||||
from app.assets.database.queries import (
|
||||
add_missing_tag_for_asset_id,
|
||||
bulk_update_enrichment_level,
|
||||
bulk_update_is_missing,
|
||||
bulk_update_needs_verify,
|
||||
delete_orphaned_seed_asset,
|
||||
delete_references_by_ids,
|
||||
ensure_tags_exist,
|
||||
get_asset_by_hash,
|
||||
get_references_for_prefixes,
|
||||
get_unenriched_references,
|
||||
mark_references_missing_outside_prefixes,
|
||||
reassign_asset_references,
|
||||
remove_missing_tag_for_asset_id,
|
||||
set_reference_metadata,
|
||||
update_asset_hash_and_mime,
|
||||
)
|
||||
from app.assets.database.tags import add_missing_tag_for_asset_id, ensure_tags_exist, remove_missing_tag_for_asset_id
|
||||
from app.assets.database.bulk_ops import seed_from_paths_batch
|
||||
from app.assets.database.models import Asset, AssetCacheState, AssetInfo
|
||||
from app.assets.services.bulk_ingest import (
|
||||
SeedAssetSpec,
|
||||
batch_insert_seed_assets,
|
||||
)
|
||||
from app.assets.services.file_utils import (
|
||||
get_mtime_ns,
|
||||
is_visible,
|
||||
list_files_recursively,
|
||||
verify_file_unchanged,
|
||||
)
|
||||
from app.assets.services.hashing import HashCheckpoint, compute_blake3_hash
|
||||
from app.assets.services.metadata_extract import extract_file_metadata
|
||||
from app.assets.services.path_utils import (
|
||||
compute_relative_filename,
|
||||
get_comfy_models_folders,
|
||||
get_name_and_tags_from_asset_path,
|
||||
)
|
||||
from app.database.db import create_session
|
||||
|
||||
|
||||
def seed_assets(roots: tuple[RootType, ...], enable_logging: bool = False) -> None:
|
||||
"""
|
||||
Scan the given roots and seed the assets into the database.
|
||||
"""
|
||||
if not dependencies_available():
|
||||
if enable_logging:
|
||||
logging.warning("Database dependencies not available, skipping assets scan")
|
||||
return
|
||||
t_start = time.perf_counter()
|
||||
created = 0
|
||||
skipped_existing = 0
|
||||
orphans_pruned = 0
|
||||
paths: list[str] = []
|
||||
try:
|
||||
existing_paths: set[str] = set()
|
||||
for r in roots:
|
||||
try:
|
||||
survivors: set[str] = _fast_db_consistency_pass(r, collect_existing_paths=True, update_missing_tags=True)
|
||||
if survivors:
|
||||
existing_paths.update(survivors)
|
||||
except Exception as e:
|
||||
logging.exception("fast DB scan failed for %s: %s", r, e)
|
||||
class _RefInfo(TypedDict):
|
||||
ref_id: str
|
||||
file_path: str
|
||||
exists: bool
|
||||
stat_unchanged: bool
|
||||
needs_verify: bool
|
||||
|
||||
try:
|
||||
orphans_pruned = _prune_orphaned_assets(roots)
|
||||
except Exception as e:
|
||||
logging.exception("orphan pruning failed: %s", e)
|
||||
|
||||
if "models" in roots:
|
||||
paths.extend(collect_models_files())
|
||||
if "input" in roots:
|
||||
paths.extend(list_tree(folder_paths.get_input_directory()))
|
||||
if "output" in roots:
|
||||
paths.extend(list_tree(folder_paths.get_output_directory()))
|
||||
class _AssetAccumulator(TypedDict):
|
||||
hash: str | None
|
||||
size_db: int
|
||||
refs: list[_RefInfo]
|
||||
|
||||
specs: list[dict] = []
|
||||
tag_pool: set[str] = set()
|
||||
for p in paths:
|
||||
abs_p = os.path.abspath(p)
|
||||
if abs_p in existing_paths:
|
||||
skipped_existing += 1
|
||||
|
||||
RootType = Literal["models", "input", "output"]
|
||||
|
||||
|
||||
def get_prefixes_for_root(root: RootType) -> list[str]:
|
||||
if root == "models":
|
||||
bases: list[str] = []
|
||||
for _bucket, paths in get_comfy_models_folders():
|
||||
bases.extend(paths)
|
||||
return [os.path.abspath(p) for p in bases]
|
||||
if root == "input":
|
||||
return [os.path.abspath(folder_paths.get_input_directory())]
|
||||
if root == "output":
|
||||
return [os.path.abspath(folder_paths.get_output_directory())]
|
||||
return []
|
||||
|
||||
|
||||
def get_all_known_prefixes() -> list[str]:
|
||||
"""Get all known asset prefixes across all root types."""
|
||||
all_roots: tuple[RootType, ...] = ("models", "input", "output")
|
||||
return [p for root in all_roots for p in get_prefixes_for_root(root)]
|
||||
|
||||
|
||||
def collect_models_files() -> list[str]:
|
||||
out: list[str] = []
|
||||
for folder_name, bases in get_comfy_models_folders():
|
||||
rel_files = folder_paths.get_filename_list(folder_name) or []
|
||||
for rel_path in rel_files:
|
||||
if not all(is_visible(part) for part in Path(rel_path).parts):
|
||||
continue
|
||||
try:
|
||||
stat_p = os.stat(abs_p, follow_symlinks=False)
|
||||
except OSError:
|
||||
abs_path = folder_paths.get_full_path(folder_name, rel_path)
|
||||
if not abs_path:
|
||||
continue
|
||||
# skip empty files
|
||||
if not stat_p.st_size:
|
||||
continue
|
||||
name, tags = get_name_and_tags_from_asset_path(abs_p)
|
||||
specs.append(
|
||||
{
|
||||
"abs_path": abs_p,
|
||||
"size_bytes": stat_p.st_size,
|
||||
"mtime_ns": getattr(stat_p, "st_mtime_ns", int(stat_p.st_mtime * 1_000_000_000)),
|
||||
"info_name": name,
|
||||
"tags": tags,
|
||||
"fname": compute_relative_filename(abs_p),
|
||||
}
|
||||
)
|
||||
for t in tags:
|
||||
tag_pool.add(t)
|
||||
# if no file specs, nothing to do
|
||||
if not specs:
|
||||
return
|
||||
with create_session() as sess:
|
||||
if tag_pool:
|
||||
ensure_tags_exist(sess, tag_pool, tag_type="user")
|
||||
|
||||
result = seed_from_paths_batch(sess, specs=specs, owner_id="")
|
||||
created += result["inserted_infos"]
|
||||
sess.commit()
|
||||
finally:
|
||||
if enable_logging:
|
||||
logging.info(
|
||||
"Assets scan(roots=%s) completed in %.3fs (created=%d, skipped_existing=%d, orphans_pruned=%d, total_seen=%d)",
|
||||
roots,
|
||||
time.perf_counter() - t_start,
|
||||
created,
|
||||
skipped_existing,
|
||||
orphans_pruned,
|
||||
len(paths),
|
||||
)
|
||||
abs_path = os.path.abspath(abs_path)
|
||||
allowed = False
|
||||
abs_p = Path(abs_path)
|
||||
for b in bases:
|
||||
if abs_p.is_relative_to(os.path.abspath(b)):
|
||||
allowed = True
|
||||
break
|
||||
if allowed:
|
||||
out.append(abs_path)
|
||||
return out
|
||||
|
||||
|
||||
def _prune_orphaned_assets(roots: tuple[RootType, ...]) -> int:
|
||||
"""Prune cache states outside configured prefixes, then delete orphaned seed assets."""
|
||||
all_prefixes = [os.path.abspath(p) for r in roots for p in prefixes_for_root(r)]
|
||||
if not all_prefixes:
|
||||
return 0
|
||||
|
||||
def make_prefix_condition(prefix: str):
|
||||
base = prefix if prefix.endswith(os.sep) else prefix + os.sep
|
||||
escaped, esc = escape_like_prefix(base)
|
||||
return AssetCacheState.file_path.like(escaped + "%", escape=esc)
|
||||
|
||||
matches_valid_prefix = sqlalchemy.or_(*[make_prefix_condition(p) for p in all_prefixes])
|
||||
|
||||
orphan_subq = (
|
||||
sqlalchemy.select(Asset.id)
|
||||
.outerjoin(AssetCacheState, AssetCacheState.asset_id == Asset.id)
|
||||
.where(Asset.hash.is_(None), AssetCacheState.id.is_(None))
|
||||
).scalar_subquery()
|
||||
|
||||
with create_session() as sess:
|
||||
sess.execute(sqlalchemy.delete(AssetCacheState).where(~matches_valid_prefix))
|
||||
sess.execute(sqlalchemy.delete(AssetInfo).where(AssetInfo.asset_id.in_(orphan_subq)))
|
||||
result = sess.execute(sqlalchemy.delete(Asset).where(Asset.id.in_(orphan_subq)))
|
||||
sess.commit()
|
||||
return result.rowcount
|
||||
|
||||
|
||||
def _fast_db_consistency_pass(
|
||||
def sync_references_with_filesystem(
|
||||
session,
|
||||
root: RootType,
|
||||
*,
|
||||
collect_existing_paths: bool = False,
|
||||
update_missing_tags: bool = False,
|
||||
) -> set[str] | None:
|
||||
"""Fast DB+FS pass for a root:
|
||||
- Toggle needs_verify per state using fast check
|
||||
- For hashed assets with at least one fast-ok state in this root: delete stale missing states
|
||||
- For seed assets with all states missing: delete Asset and its AssetInfos
|
||||
- Optionally add/remove 'missing' tags based on fast-ok in this root
|
||||
- Optionally return surviving absolute paths
|
||||
"""Reconcile asset references with filesystem for a root.
|
||||
|
||||
- Toggle needs_verify per reference using mtime/size stat check
|
||||
- For hashed assets with at least one stat-unchanged ref: delete stale missing refs
|
||||
- For seed assets with all refs missing: delete Asset and its references
|
||||
- Optionally add/remove 'missing' tags based on stat check in this root
|
||||
- Optionally return surviving absolute paths
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
root: Root type to scan
|
||||
collect_existing_paths: If True, return set of surviving file paths
|
||||
update_missing_tags: If True, update 'missing' tags based on file status
|
||||
|
||||
Returns:
|
||||
Set of surviving absolute paths if collect_existing_paths=True, else None
|
||||
"""
|
||||
prefixes = prefixes_for_root(root)
|
||||
prefixes = get_prefixes_for_root(root)
|
||||
if not prefixes:
|
||||
return set() if collect_existing_paths else None
|
||||
|
||||
conds = []
|
||||
for p in prefixes:
|
||||
base = os.path.abspath(p)
|
||||
if not base.endswith(os.sep):
|
||||
base += os.sep
|
||||
escaped, esc = escape_like_prefix(base)
|
||||
conds.append(AssetCacheState.file_path.like(escaped + "%", escape=esc))
|
||||
rows = get_references_for_prefixes(
|
||||
session, prefixes, include_missing=update_missing_tags
|
||||
)
|
||||
|
||||
by_asset: dict[str, _AssetAccumulator] = {}
|
||||
for row in rows:
|
||||
acc = by_asset.get(row.asset_id)
|
||||
if acc is None:
|
||||
acc = {"hash": row.asset_hash, "size_db": row.size_bytes, "refs": []}
|
||||
by_asset[row.asset_id] = acc
|
||||
|
||||
stat_unchanged = False
|
||||
try:
|
||||
exists = True
|
||||
stat_unchanged = verify_file_unchanged(
|
||||
mtime_db=row.mtime_ns,
|
||||
size_db=acc["size_db"],
|
||||
stat_result=os.stat(row.file_path, follow_symlinks=True),
|
||||
)
|
||||
except FileNotFoundError:
|
||||
exists = False
|
||||
except PermissionError:
|
||||
exists = True
|
||||
logging.debug("Permission denied accessing %s", row.file_path)
|
||||
except OSError as e:
|
||||
exists = False
|
||||
logging.debug("OSError checking %s: %s", row.file_path, e)
|
||||
|
||||
acc["refs"].append(
|
||||
{
|
||||
"ref_id": row.reference_id,
|
||||
"file_path": row.file_path,
|
||||
"exists": exists,
|
||||
"stat_unchanged": stat_unchanged,
|
||||
"needs_verify": row.needs_verify,
|
||||
}
|
||||
)
|
||||
|
||||
to_set_verify: list[str] = []
|
||||
to_clear_verify: list[str] = []
|
||||
stale_ref_ids: list[str] = []
|
||||
to_mark_missing: list[str] = []
|
||||
to_clear_missing: list[str] = []
|
||||
survivors: set[str] = set()
|
||||
|
||||
for aid, acc in by_asset.items():
|
||||
a_hash = acc["hash"]
|
||||
refs = acc["refs"]
|
||||
any_unchanged = any(r["stat_unchanged"] for r in refs)
|
||||
all_missing = all(not r["exists"] for r in refs)
|
||||
|
||||
for r in refs:
|
||||
if not r["exists"]:
|
||||
to_mark_missing.append(r["ref_id"])
|
||||
continue
|
||||
if r["stat_unchanged"]:
|
||||
to_clear_missing.append(r["ref_id"])
|
||||
if r["needs_verify"]:
|
||||
to_clear_verify.append(r["ref_id"])
|
||||
if not r["stat_unchanged"] and not r["needs_verify"]:
|
||||
to_set_verify.append(r["ref_id"])
|
||||
|
||||
if a_hash is None:
|
||||
if refs and all_missing:
|
||||
delete_orphaned_seed_asset(session, aid)
|
||||
else:
|
||||
for r in refs:
|
||||
if r["exists"]:
|
||||
survivors.add(os.path.abspath(r["file_path"]))
|
||||
continue
|
||||
|
||||
if any_unchanged:
|
||||
for r in refs:
|
||||
if not r["exists"]:
|
||||
stale_ref_ids.append(r["ref_id"])
|
||||
if update_missing_tags:
|
||||
try:
|
||||
remove_missing_tag_for_asset_id(session, asset_id=aid)
|
||||
except Exception as e:
|
||||
logging.warning(
|
||||
"Failed to remove missing tag for asset %s: %s", aid, e
|
||||
)
|
||||
elif update_missing_tags:
|
||||
try:
|
||||
add_missing_tag_for_asset_id(session, asset_id=aid, origin="automatic")
|
||||
except Exception as e:
|
||||
logging.warning("Failed to add missing tag for asset %s: %s", aid, e)
|
||||
|
||||
for r in refs:
|
||||
if r["exists"]:
|
||||
survivors.add(os.path.abspath(r["file_path"]))
|
||||
|
||||
delete_references_by_ids(session, stale_ref_ids)
|
||||
stale_set = set(stale_ref_ids)
|
||||
to_mark_missing = [ref_id for ref_id in to_mark_missing if ref_id not in stale_set]
|
||||
bulk_update_is_missing(session, to_mark_missing, value=True)
|
||||
bulk_update_is_missing(session, to_clear_missing, value=False)
|
||||
bulk_update_needs_verify(session, to_set_verify, value=True)
|
||||
bulk_update_needs_verify(session, to_clear_verify, value=False)
|
||||
|
||||
return survivors if collect_existing_paths else None
|
||||
|
||||
|
||||
def sync_root_safely(root: RootType) -> set[str]:
|
||||
"""Sync a single root's references with the filesystem.
|
||||
|
||||
Returns survivors (existing paths) or empty set on failure.
|
||||
"""
|
||||
try:
|
||||
with create_session() as sess:
|
||||
survivors = sync_references_with_filesystem(
|
||||
sess,
|
||||
root,
|
||||
collect_existing_paths=True,
|
||||
update_missing_tags=True,
|
||||
)
|
||||
sess.commit()
|
||||
return survivors or set()
|
||||
except Exception as e:
|
||||
logging.exception("fast DB scan failed for %s: %s", root, e)
|
||||
return set()
|
||||
|
||||
|
||||
def mark_missing_outside_prefixes_safely(prefixes: list[str]) -> int:
|
||||
"""Mark references as missing when outside the given prefixes.
|
||||
|
||||
This is a non-destructive soft-delete. Returns count marked or 0 on failure.
|
||||
"""
|
||||
try:
|
||||
with create_session() as sess:
|
||||
count = mark_references_missing_outside_prefixes(sess, prefixes)
|
||||
sess.commit()
|
||||
return count
|
||||
except Exception as e:
|
||||
logging.exception("marking missing assets failed: %s", e)
|
||||
return 0
|
||||
|
||||
|
||||
def collect_paths_for_roots(roots: tuple[RootType, ...]) -> list[str]:
|
||||
"""Collect all file paths for the given roots."""
|
||||
paths: list[str] = []
|
||||
if "models" in roots:
|
||||
paths.extend(collect_models_files())
|
||||
if "input" in roots:
|
||||
paths.extend(list_files_recursively(folder_paths.get_input_directory()))
|
||||
if "output" in roots:
|
||||
paths.extend(list_files_recursively(folder_paths.get_output_directory()))
|
||||
return paths
|
||||
|
||||
|
||||
def build_asset_specs(
|
||||
paths: list[str],
|
||||
existing_paths: set[str],
|
||||
enable_metadata_extraction: bool = True,
|
||||
compute_hashes: bool = False,
|
||||
) -> tuple[list[SeedAssetSpec], set[str], int]:
|
||||
"""Build asset specs from paths, returning (specs, tag_pool, skipped_count).
|
||||
|
||||
Args:
|
||||
paths: List of file paths to process
|
||||
existing_paths: Set of paths that already exist in the database
|
||||
enable_metadata_extraction: If True, extract tier 1 & 2 metadata
|
||||
compute_hashes: If True, compute blake3 hashes (slow for large files)
|
||||
"""
|
||||
specs: list[SeedAssetSpec] = []
|
||||
tag_pool: set[str] = set()
|
||||
skipped = 0
|
||||
|
||||
for p in paths:
|
||||
abs_p = os.path.abspath(p)
|
||||
if abs_p in existing_paths:
|
||||
skipped += 1
|
||||
continue
|
||||
try:
|
||||
stat_p = os.stat(abs_p, follow_symlinks=True)
|
||||
except OSError:
|
||||
continue
|
||||
if not stat_p.st_size:
|
||||
continue
|
||||
name, tags = get_name_and_tags_from_asset_path(abs_p)
|
||||
rel_fname = compute_relative_filename(abs_p)
|
||||
|
||||
# Extract metadata (tier 1: filesystem, tier 2: safetensors header)
|
||||
metadata = None
|
||||
if enable_metadata_extraction:
|
||||
metadata = extract_file_metadata(
|
||||
abs_p,
|
||||
stat_result=stat_p,
|
||||
relative_filename=rel_fname,
|
||||
)
|
||||
|
||||
# Compute hash if requested
|
||||
asset_hash: str | None = None
|
||||
if compute_hashes:
|
||||
try:
|
||||
digest, _ = compute_blake3_hash(abs_p)
|
||||
asset_hash = "blake3:" + digest
|
||||
except Exception as e:
|
||||
logging.warning("Failed to hash %s: %s", abs_p, e)
|
||||
|
||||
mime_type = metadata.content_type if metadata else None
|
||||
specs.append(
|
||||
{
|
||||
"abs_path": abs_p,
|
||||
"size_bytes": stat_p.st_size,
|
||||
"mtime_ns": get_mtime_ns(stat_p),
|
||||
"info_name": name,
|
||||
"tags": tags,
|
||||
"fname": rel_fname,
|
||||
"metadata": metadata,
|
||||
"hash": asset_hash,
|
||||
"mime_type": mime_type,
|
||||
}
|
||||
)
|
||||
tag_pool.update(tags)
|
||||
|
||||
return specs, tag_pool, skipped
|
||||
|
||||
|
||||
|
||||
def insert_asset_specs(specs: list[SeedAssetSpec], tag_pool: set[str]) -> int:
|
||||
"""Insert asset specs into database, returning count of created refs."""
|
||||
if not specs:
|
||||
return 0
|
||||
with create_session() as sess:
|
||||
if tag_pool:
|
||||
ensure_tags_exist(sess, tag_pool, tag_type="user")
|
||||
result = batch_insert_seed_assets(sess, specs=specs, owner_id="")
|
||||
sess.commit()
|
||||
return result.inserted_refs
|
||||
|
||||
|
||||
# Enrichment level constants
|
||||
ENRICHMENT_STUB = 0 # Fast scan: path, size, mtime only
|
||||
ENRICHMENT_METADATA = 1 # Metadata extracted (safetensors header, mime type)
|
||||
ENRICHMENT_HASHED = 2 # Hash computed (blake3)
|
||||
|
||||
|
||||
def get_unenriched_assets_for_roots(
|
||||
roots: tuple[RootType, ...],
|
||||
max_level: int = ENRICHMENT_STUB,
|
||||
limit: int = 1000,
|
||||
) -> list:
|
||||
"""Get assets that need enrichment for the given roots.
|
||||
|
||||
Args:
|
||||
roots: Tuple of root types to scan
|
||||
max_level: Maximum enrichment level to include
|
||||
limit: Maximum number of rows to return
|
||||
|
||||
Returns:
|
||||
List of UnenrichedReferenceRow
|
||||
"""
|
||||
prefixes: list[str] = []
|
||||
for root in roots:
|
||||
prefixes.extend(get_prefixes_for_root(root))
|
||||
|
||||
if not prefixes:
|
||||
return []
|
||||
|
||||
with create_session() as sess:
|
||||
rows = (
|
||||
sess.execute(
|
||||
sqlalchemy.select(
|
||||
AssetCacheState.id,
|
||||
AssetCacheState.file_path,
|
||||
AssetCacheState.mtime_ns,
|
||||
AssetCacheState.needs_verify,
|
||||
AssetCacheState.asset_id,
|
||||
Asset.hash,
|
||||
Asset.size_bytes,
|
||||
)
|
||||
.join(Asset, Asset.id == AssetCacheState.asset_id)
|
||||
.where(sqlalchemy.or_(*conds))
|
||||
.order_by(AssetCacheState.asset_id.asc(), AssetCacheState.id.asc())
|
||||
return get_unenriched_references(
|
||||
sess, prefixes, max_level=max_level, limit=limit
|
||||
)
|
||||
|
||||
|
||||
def enrich_asset(
|
||||
session,
|
||||
file_path: str,
|
||||
reference_id: str,
|
||||
asset_id: str,
|
||||
extract_metadata: bool = True,
|
||||
compute_hash: bool = False,
|
||||
interrupt_check: Callable[[], bool] | None = None,
|
||||
hash_checkpoints: dict[str, HashCheckpoint] | None = None,
|
||||
) -> int:
|
||||
"""Enrich a single asset with metadata and/or hash.
|
||||
|
||||
Args:
|
||||
session: Database session (caller manages lifecycle)
|
||||
file_path: Absolute path to the file
|
||||
reference_id: ID of the reference to update
|
||||
asset_id: ID of the asset to update (for mime_type and hash)
|
||||
extract_metadata: If True, extract safetensors header and mime type
|
||||
compute_hash: If True, compute blake3 hash
|
||||
interrupt_check: Optional non-blocking callable that returns True if
|
||||
the operation should be interrupted (e.g. paused or cancelled)
|
||||
hash_checkpoints: Optional dict for saving/restoring hash progress
|
||||
across interruptions, keyed by file path
|
||||
|
||||
Returns:
|
||||
New enrichment level achieved
|
||||
"""
|
||||
new_level = ENRICHMENT_STUB
|
||||
|
||||
try:
|
||||
stat_p = os.stat(file_path, follow_symlinks=True)
|
||||
except OSError:
|
||||
return new_level
|
||||
|
||||
rel_fname = compute_relative_filename(file_path)
|
||||
mime_type: str | None = None
|
||||
metadata = None
|
||||
|
||||
if extract_metadata:
|
||||
metadata = extract_file_metadata(
|
||||
file_path,
|
||||
stat_result=stat_p,
|
||||
relative_filename=rel_fname,
|
||||
)
|
||||
if metadata:
|
||||
mime_type = metadata.content_type
|
||||
new_level = ENRICHMENT_METADATA
|
||||
|
||||
full_hash: str | None = None
|
||||
if compute_hash:
|
||||
try:
|
||||
mtime_before = get_mtime_ns(stat_p)
|
||||
size_before = stat_p.st_size
|
||||
|
||||
# Restore checkpoint if available and file unchanged
|
||||
checkpoint = None
|
||||
if hash_checkpoints is not None:
|
||||
checkpoint = hash_checkpoints.get(file_path)
|
||||
if checkpoint is not None:
|
||||
cur_stat = os.stat(file_path, follow_symlinks=True)
|
||||
if (checkpoint.mtime_ns != get_mtime_ns(cur_stat)
|
||||
or checkpoint.file_size != cur_stat.st_size):
|
||||
checkpoint = None
|
||||
hash_checkpoints.pop(file_path, None)
|
||||
else:
|
||||
mtime_before = get_mtime_ns(cur_stat)
|
||||
|
||||
digest, new_checkpoint = compute_blake3_hash(
|
||||
file_path,
|
||||
interrupt_check=interrupt_check,
|
||||
checkpoint=checkpoint,
|
||||
)
|
||||
).all()
|
||||
|
||||
by_asset: dict[str, dict] = {}
|
||||
for sid, fp, mtime_db, needs_verify, aid, a_hash, a_size in rows:
|
||||
acc = by_asset.get(aid)
|
||||
if acc is None:
|
||||
acc = {"hash": a_hash, "size_db": int(a_size or 0), "states": []}
|
||||
by_asset[aid] = acc
|
||||
if digest is None:
|
||||
# Interrupted — save checkpoint for later resumption
|
||||
if hash_checkpoints is not None and new_checkpoint is not None:
|
||||
new_checkpoint.mtime_ns = mtime_before
|
||||
new_checkpoint.file_size = size_before
|
||||
hash_checkpoints[file_path] = new_checkpoint
|
||||
return new_level
|
||||
|
||||
# Completed — clear any saved checkpoint
|
||||
if hash_checkpoints is not None:
|
||||
hash_checkpoints.pop(file_path, None)
|
||||
|
||||
stat_after = os.stat(file_path, follow_symlinks=True)
|
||||
mtime_after = get_mtime_ns(stat_after)
|
||||
if mtime_before != mtime_after:
|
||||
logging.warning("File modified during hashing, discarding hash: %s", file_path)
|
||||
else:
|
||||
full_hash = f"blake3:{digest}"
|
||||
metadata_ok = not extract_metadata or metadata is not None
|
||||
if metadata_ok:
|
||||
new_level = ENRICHMENT_HASHED
|
||||
except Exception as e:
|
||||
logging.warning("Failed to hash %s: %s", file_path, e)
|
||||
|
||||
if extract_metadata and metadata:
|
||||
user_metadata = metadata.to_user_metadata()
|
||||
set_reference_metadata(session, reference_id, user_metadata)
|
||||
|
||||
if full_hash:
|
||||
existing = get_asset_by_hash(session, full_hash)
|
||||
if existing and existing.id != asset_id:
|
||||
reassign_asset_references(session, asset_id, existing.id, reference_id)
|
||||
delete_orphaned_seed_asset(session, asset_id)
|
||||
if mime_type:
|
||||
update_asset_hash_and_mime(session, existing.id, mime_type=mime_type)
|
||||
else:
|
||||
update_asset_hash_and_mime(session, asset_id, full_hash, mime_type)
|
||||
elif mime_type:
|
||||
update_asset_hash_and_mime(session, asset_id, mime_type=mime_type)
|
||||
|
||||
bulk_update_enrichment_level(session, [reference_id], new_level)
|
||||
session.commit()
|
||||
|
||||
return new_level
|
||||
|
||||
|
||||
def enrich_assets_batch(
|
||||
rows: list,
|
||||
extract_metadata: bool = True,
|
||||
compute_hash: bool = False,
|
||||
interrupt_check: Callable[[], bool] | None = None,
|
||||
hash_checkpoints: dict[str, HashCheckpoint] | None = None,
|
||||
) -> tuple[int, list[str]]:
|
||||
"""Enrich a batch of assets.
|
||||
|
||||
Uses a single DB session for the entire batch, committing after each
|
||||
individual asset to avoid long-held transactions while eliminating
|
||||
per-asset session creation overhead.
|
||||
|
||||
Args:
|
||||
rows: List of UnenrichedReferenceRow from get_unenriched_assets_for_roots
|
||||
extract_metadata: If True, extract metadata for each asset
|
||||
compute_hash: If True, compute hash for each asset
|
||||
interrupt_check: Optional non-blocking callable that returns True if
|
||||
the operation should be interrupted (e.g. paused or cancelled)
|
||||
hash_checkpoints: Optional dict for saving/restoring hash progress
|
||||
across interruptions, keyed by file path
|
||||
|
||||
Returns:
|
||||
Tuple of (enriched_count, failed_reference_ids)
|
||||
"""
|
||||
enriched = 0
|
||||
failed_ids: list[str] = []
|
||||
|
||||
with create_session() as sess:
|
||||
for row in rows:
|
||||
if interrupt_check is not None and interrupt_check():
|
||||
break
|
||||
|
||||
fast_ok = False
|
||||
try:
|
||||
exists = True
|
||||
fast_ok = fast_asset_file_check(
|
||||
mtime_db=mtime_db,
|
||||
size_db=acc["size_db"],
|
||||
stat_result=os.stat(fp, follow_symlinks=True),
|
||||
new_level = enrich_asset(
|
||||
sess,
|
||||
file_path=row.file_path,
|
||||
reference_id=row.reference_id,
|
||||
asset_id=row.asset_id,
|
||||
extract_metadata=extract_metadata,
|
||||
compute_hash=compute_hash,
|
||||
interrupt_check=interrupt_check,
|
||||
hash_checkpoints=hash_checkpoints,
|
||||
)
|
||||
except FileNotFoundError:
|
||||
exists = False
|
||||
except OSError:
|
||||
exists = False
|
||||
|
||||
acc["states"].append({
|
||||
"sid": sid,
|
||||
"fp": fp,
|
||||
"exists": exists,
|
||||
"fast_ok": fast_ok,
|
||||
"needs_verify": bool(needs_verify),
|
||||
})
|
||||
|
||||
to_set_verify: list[int] = []
|
||||
to_clear_verify: list[int] = []
|
||||
stale_state_ids: list[int] = []
|
||||
survivors: set[str] = set()
|
||||
|
||||
for aid, acc in by_asset.items():
|
||||
a_hash = acc["hash"]
|
||||
states = acc["states"]
|
||||
any_fast_ok = any(s["fast_ok"] for s in states)
|
||||
all_missing = all(not s["exists"] for s in states)
|
||||
|
||||
for s in states:
|
||||
if not s["exists"]:
|
||||
continue
|
||||
if s["fast_ok"] and s["needs_verify"]:
|
||||
to_clear_verify.append(s["sid"])
|
||||
if not s["fast_ok"] and not s["needs_verify"]:
|
||||
to_set_verify.append(s["sid"])
|
||||
|
||||
if a_hash is None:
|
||||
if states and all_missing: # remove seed Asset completely, if no valid AssetCache exists
|
||||
sess.execute(sqlalchemy.delete(AssetInfo).where(AssetInfo.asset_id == aid))
|
||||
asset = sess.get(Asset, aid)
|
||||
if asset:
|
||||
sess.delete(asset)
|
||||
if new_level > row.enrichment_level:
|
||||
enriched += 1
|
||||
else:
|
||||
for s in states:
|
||||
if s["exists"]:
|
||||
survivors.add(os.path.abspath(s["fp"]))
|
||||
continue
|
||||
failed_ids.append(row.reference_id)
|
||||
except Exception as e:
|
||||
logging.warning("Failed to enrich %s: %s", row.file_path, e)
|
||||
sess.rollback()
|
||||
failed_ids.append(row.reference_id)
|
||||
|
||||
if any_fast_ok: # if Asset has at least one valid AssetCache record, remove any invalid AssetCache records
|
||||
for s in states:
|
||||
if not s["exists"]:
|
||||
stale_state_ids.append(s["sid"])
|
||||
if update_missing_tags:
|
||||
with contextlib.suppress(Exception):
|
||||
remove_missing_tag_for_asset_id(sess, asset_id=aid)
|
||||
elif update_missing_tags:
|
||||
with contextlib.suppress(Exception):
|
||||
add_missing_tag_for_asset_id(sess, asset_id=aid, origin="automatic")
|
||||
|
||||
for s in states:
|
||||
if s["exists"]:
|
||||
survivors.add(os.path.abspath(s["fp"]))
|
||||
|
||||
if stale_state_ids:
|
||||
sess.execute(sqlalchemy.delete(AssetCacheState).where(AssetCacheState.id.in_(stale_state_ids)))
|
||||
if to_set_verify:
|
||||
sess.execute(
|
||||
sqlalchemy.update(AssetCacheState)
|
||||
.where(AssetCacheState.id.in_(to_set_verify))
|
||||
.values(needs_verify=True)
|
||||
)
|
||||
if to_clear_verify:
|
||||
sess.execute(
|
||||
sqlalchemy.update(AssetCacheState)
|
||||
.where(AssetCacheState.id.in_(to_clear_verify))
|
||||
.values(needs_verify=False)
|
||||
)
|
||||
sess.commit()
|
||||
return survivors if collect_existing_paths else None
|
||||
return enriched, failed_ids
|
||||
|
||||
794
app/assets/seeder.py
Normal file
794
app/assets/seeder.py
Normal file
@@ -0,0 +1,794 @@
|
||||
"""Background asset seeder with thread management and cancellation support."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Callable
|
||||
|
||||
from app.assets.scanner import (
|
||||
ENRICHMENT_METADATA,
|
||||
ENRICHMENT_STUB,
|
||||
RootType,
|
||||
build_asset_specs,
|
||||
collect_paths_for_roots,
|
||||
enrich_assets_batch,
|
||||
get_all_known_prefixes,
|
||||
get_prefixes_for_root,
|
||||
get_unenriched_assets_for_roots,
|
||||
insert_asset_specs,
|
||||
mark_missing_outside_prefixes_safely,
|
||||
sync_root_safely,
|
||||
)
|
||||
from app.database.db import dependencies_available
|
||||
|
||||
|
||||
class ScanInProgressError(Exception):
|
||||
"""Raised when an operation cannot proceed because a scan is running."""
|
||||
|
||||
|
||||
class State(Enum):
|
||||
"""Seeder state machine states."""
|
||||
|
||||
IDLE = "IDLE"
|
||||
RUNNING = "RUNNING"
|
||||
PAUSED = "PAUSED"
|
||||
CANCELLING = "CANCELLING"
|
||||
|
||||
|
||||
class ScanPhase(Enum):
|
||||
"""Scan phase options."""
|
||||
|
||||
FAST = "fast" # Phase 1: filesystem only (stubs)
|
||||
ENRICH = "enrich" # Phase 2: metadata + hash
|
||||
FULL = "full" # Both phases sequentially
|
||||
|
||||
|
||||
@dataclass
|
||||
class Progress:
|
||||
"""Progress information for a scan operation."""
|
||||
|
||||
scanned: int = 0
|
||||
total: int = 0
|
||||
created: int = 0
|
||||
skipped: int = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScanStatus:
|
||||
"""Current status of the asset seeder."""
|
||||
|
||||
state: State
|
||||
progress: Progress | None
|
||||
errors: list[str] = field(default_factory=list)
|
||||
|
||||
|
||||
ProgressCallback = Callable[[Progress], None]
|
||||
|
||||
|
||||
class _AssetSeeder:
|
||||
"""Background asset scanning manager.
|
||||
|
||||
Spawns ephemeral daemon threads for scanning.
|
||||
Each scan creates a new thread that exits when complete.
|
||||
Use the module-level ``asset_seeder`` instance.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._lock = threading.Lock()
|
||||
self._state = State.IDLE
|
||||
self._progress: Progress | None = None
|
||||
self._last_progress: Progress | None = None
|
||||
self._errors: list[str] = []
|
||||
self._thread: threading.Thread | None = None
|
||||
self._cancel_event = threading.Event()
|
||||
self._run_gate = threading.Event()
|
||||
self._run_gate.set() # Start unpaused (set = running, clear = paused)
|
||||
self._roots: tuple[RootType, ...] = ()
|
||||
self._phase: ScanPhase = ScanPhase.FULL
|
||||
self._compute_hashes: bool = False
|
||||
self._prune_first: bool = False
|
||||
self._progress_callback: ProgressCallback | None = None
|
||||
self._disabled: bool = False
|
||||
|
||||
def disable(self) -> None:
|
||||
"""Disable the asset seeder, preventing any scans from starting."""
|
||||
self._disabled = True
|
||||
logging.info("Asset seeder disabled")
|
||||
|
||||
def is_disabled(self) -> bool:
|
||||
"""Check if the asset seeder is disabled."""
|
||||
return self._disabled
|
||||
|
||||
def start(
|
||||
self,
|
||||
roots: tuple[RootType, ...] = ("models", "input", "output"),
|
||||
phase: ScanPhase = ScanPhase.FULL,
|
||||
progress_callback: ProgressCallback | None = None,
|
||||
prune_first: bool = False,
|
||||
compute_hashes: bool = False,
|
||||
) -> bool:
|
||||
"""Start a background scan for the given roots.
|
||||
|
||||
Args:
|
||||
roots: Tuple of root types to scan (models, input, output)
|
||||
phase: Scan phase to run (FAST, ENRICH, or FULL for both)
|
||||
progress_callback: Optional callback called with progress updates
|
||||
prune_first: If True, prune orphaned assets before scanning
|
||||
compute_hashes: If True, compute blake3 hashes (slow)
|
||||
|
||||
Returns:
|
||||
True if scan was started, False if already running
|
||||
"""
|
||||
if self._disabled:
|
||||
logging.debug("Asset seeder is disabled, skipping start")
|
||||
return False
|
||||
logging.info("Seeder start (roots=%s, phase=%s)", roots, phase.value)
|
||||
with self._lock:
|
||||
if self._state != State.IDLE:
|
||||
logging.info("Asset seeder already running, skipping start")
|
||||
return False
|
||||
self._state = State.RUNNING
|
||||
self._progress = Progress()
|
||||
self._errors = []
|
||||
self._roots = roots
|
||||
self._phase = phase
|
||||
self._prune_first = prune_first
|
||||
self._compute_hashes = compute_hashes
|
||||
self._progress_callback = progress_callback
|
||||
self._cancel_event.clear()
|
||||
self._run_gate.set() # Ensure unpaused when starting
|
||||
self._thread = threading.Thread(
|
||||
target=self._run_scan,
|
||||
name="_AssetSeeder",
|
||||
daemon=True,
|
||||
)
|
||||
self._thread.start()
|
||||
return True
|
||||
|
||||
def start_fast(
|
||||
self,
|
||||
roots: tuple[RootType, ...] = ("models", "input", "output"),
|
||||
progress_callback: ProgressCallback | None = None,
|
||||
prune_first: bool = False,
|
||||
) -> bool:
|
||||
"""Start a fast scan (phase 1 only) - creates stub records.
|
||||
|
||||
Args:
|
||||
roots: Tuple of root types to scan
|
||||
progress_callback: Optional callback for progress updates
|
||||
prune_first: If True, prune orphaned assets before scanning
|
||||
|
||||
Returns:
|
||||
True if scan was started, False if already running
|
||||
"""
|
||||
return self.start(
|
||||
roots=roots,
|
||||
phase=ScanPhase.FAST,
|
||||
progress_callback=progress_callback,
|
||||
prune_first=prune_first,
|
||||
compute_hashes=False,
|
||||
)
|
||||
|
||||
def start_enrich(
|
||||
self,
|
||||
roots: tuple[RootType, ...] = ("models", "input", "output"),
|
||||
progress_callback: ProgressCallback | None = None,
|
||||
compute_hashes: bool = False,
|
||||
) -> bool:
|
||||
"""Start an enrichment scan (phase 2 only) - extracts metadata and hashes.
|
||||
|
||||
Args:
|
||||
roots: Tuple of root types to scan
|
||||
progress_callback: Optional callback for progress updates
|
||||
compute_hashes: If True, compute blake3 hashes
|
||||
|
||||
Returns:
|
||||
True if scan was started, False if already running
|
||||
"""
|
||||
return self.start(
|
||||
roots=roots,
|
||||
phase=ScanPhase.ENRICH,
|
||||
progress_callback=progress_callback,
|
||||
prune_first=False,
|
||||
compute_hashes=compute_hashes,
|
||||
)
|
||||
|
||||
def cancel(self) -> bool:
|
||||
"""Request cancellation of the current scan.
|
||||
|
||||
Returns:
|
||||
True if cancellation was requested, False if not running or paused
|
||||
"""
|
||||
with self._lock:
|
||||
if self._state not in (State.RUNNING, State.PAUSED):
|
||||
return False
|
||||
logging.info("Asset seeder cancelling (was %s)", self._state.value)
|
||||
self._state = State.CANCELLING
|
||||
self._cancel_event.set()
|
||||
self._run_gate.set() # Unblock if paused so thread can exit
|
||||
return True
|
||||
|
||||
def stop(self) -> bool:
|
||||
"""Stop the current scan (alias for cancel).
|
||||
|
||||
Returns:
|
||||
True if stop was requested, False if not running
|
||||
"""
|
||||
return self.cancel()
|
||||
|
||||
def pause(self) -> bool:
|
||||
"""Pause the current scan.
|
||||
|
||||
The scan will complete its current batch before pausing.
|
||||
|
||||
Returns:
|
||||
True if pause was requested, False if not running
|
||||
"""
|
||||
with self._lock:
|
||||
if self._state != State.RUNNING:
|
||||
return False
|
||||
logging.info("Asset seeder pausing")
|
||||
self._state = State.PAUSED
|
||||
self._run_gate.clear()
|
||||
return True
|
||||
|
||||
def resume(self) -> bool:
|
||||
"""Resume a paused scan.
|
||||
|
||||
This is a noop if the scan is not in the PAUSED state
|
||||
|
||||
Returns:
|
||||
True if resumed, False if not paused
|
||||
"""
|
||||
with self._lock:
|
||||
if self._state != State.PAUSED:
|
||||
return False
|
||||
logging.info("Asset seeder resuming")
|
||||
self._state = State.RUNNING
|
||||
self._run_gate.set()
|
||||
self._emit_event("assets.seed.resumed", {})
|
||||
return True
|
||||
|
||||
def restart(
|
||||
self,
|
||||
roots: tuple[RootType, ...] | None = None,
|
||||
phase: ScanPhase | None = None,
|
||||
progress_callback: ProgressCallback | None = None,
|
||||
prune_first: bool | None = None,
|
||||
compute_hashes: bool | None = None,
|
||||
timeout: float = 5.0,
|
||||
) -> bool:
|
||||
"""Cancel any running scan and start a new one.
|
||||
|
||||
Args:
|
||||
roots: Roots to scan (defaults to previous roots)
|
||||
phase: Scan phase (defaults to previous phase)
|
||||
progress_callback: Progress callback (defaults to previous)
|
||||
prune_first: Prune before scan (defaults to previous)
|
||||
compute_hashes: Compute hashes (defaults to previous)
|
||||
timeout: Max seconds to wait for current scan to stop
|
||||
|
||||
Returns:
|
||||
True if new scan was started, False if failed to stop previous
|
||||
"""
|
||||
logging.info("Asset seeder restart requested")
|
||||
with self._lock:
|
||||
prev_roots = self._roots
|
||||
prev_phase = self._phase
|
||||
prev_callback = self._progress_callback
|
||||
prev_prune = self._prune_first
|
||||
prev_hashes = self._compute_hashes
|
||||
|
||||
self.cancel()
|
||||
if not self.wait(timeout=timeout):
|
||||
return False
|
||||
|
||||
cb = progress_callback if progress_callback is not None else prev_callback
|
||||
return self.start(
|
||||
roots=roots if roots is not None else prev_roots,
|
||||
phase=phase if phase is not None else prev_phase,
|
||||
progress_callback=cb,
|
||||
prune_first=prune_first if prune_first is not None else prev_prune,
|
||||
compute_hashes=(
|
||||
compute_hashes if compute_hashes is not None else prev_hashes
|
||||
),
|
||||
)
|
||||
|
||||
def wait(self, timeout: float | None = None) -> bool:
|
||||
"""Wait for the current scan to complete.
|
||||
|
||||
Args:
|
||||
timeout: Maximum seconds to wait, or None for no timeout
|
||||
|
||||
Returns:
|
||||
True if scan completed, False if timeout expired or no scan running
|
||||
"""
|
||||
with self._lock:
|
||||
thread = self._thread
|
||||
if thread is None:
|
||||
return True
|
||||
thread.join(timeout=timeout)
|
||||
return not thread.is_alive()
|
||||
|
||||
def get_status(self) -> ScanStatus:
|
||||
"""Get the current status and progress of the seeder."""
|
||||
with self._lock:
|
||||
src = self._progress or self._last_progress
|
||||
return ScanStatus(
|
||||
state=self._state,
|
||||
progress=Progress(
|
||||
scanned=src.scanned,
|
||||
total=src.total,
|
||||
created=src.created,
|
||||
skipped=src.skipped,
|
||||
)
|
||||
if src
|
||||
else None,
|
||||
errors=list(self._errors),
|
||||
)
|
||||
|
||||
def shutdown(self, timeout: float = 5.0) -> None:
|
||||
"""Gracefully shutdown: cancel any running scan and wait for thread.
|
||||
|
||||
Args:
|
||||
timeout: Maximum seconds to wait for thread to exit
|
||||
"""
|
||||
self.cancel()
|
||||
self.wait(timeout=timeout)
|
||||
with self._lock:
|
||||
self._thread = None
|
||||
|
||||
def mark_missing_outside_prefixes(self) -> int:
|
||||
"""Mark references as missing when outside all known root prefixes.
|
||||
|
||||
This is a non-destructive soft-delete operation. Assets and their
|
||||
metadata are preserved, but references are flagged as missing.
|
||||
They can be restored if the file reappears in a future scan.
|
||||
|
||||
This operation is decoupled from scanning to prevent partial scans
|
||||
from accidentally marking assets belonging to other roots.
|
||||
|
||||
Should be called explicitly when cleanup is desired, typically after
|
||||
a full scan of all roots or during maintenance.
|
||||
|
||||
Returns:
|
||||
Number of references marked as missing
|
||||
|
||||
Raises:
|
||||
ScanInProgressError: If a scan is currently running
|
||||
"""
|
||||
with self._lock:
|
||||
if self._state != State.IDLE:
|
||||
raise ScanInProgressError(
|
||||
"Cannot mark missing assets while scan is running"
|
||||
)
|
||||
self._state = State.RUNNING
|
||||
|
||||
try:
|
||||
if not dependencies_available():
|
||||
logging.warning(
|
||||
"Database dependencies not available, skipping mark missing"
|
||||
)
|
||||
return 0
|
||||
|
||||
all_prefixes = get_all_known_prefixes()
|
||||
marked = mark_missing_outside_prefixes_safely(all_prefixes)
|
||||
if marked > 0:
|
||||
logging.info("Marked %d references as missing", marked)
|
||||
return marked
|
||||
finally:
|
||||
with self._lock:
|
||||
self._last_progress = self._progress
|
||||
self._state = State.IDLE
|
||||
self._progress = None
|
||||
|
||||
def _is_cancelled(self) -> bool:
|
||||
"""Check if cancellation has been requested."""
|
||||
return self._cancel_event.is_set()
|
||||
|
||||
def _is_paused_or_cancelled(self) -> bool:
|
||||
"""Non-blocking check: True if paused or cancelled.
|
||||
|
||||
Use as interrupt_check for I/O-bound work (e.g. hashing) so that
|
||||
file handles are released immediately on pause rather than held
|
||||
open while blocked. The caller is responsible for blocking on
|
||||
_check_pause_and_cancel() afterward.
|
||||
"""
|
||||
return not self._run_gate.is_set() or self._cancel_event.is_set()
|
||||
|
||||
def _check_pause_and_cancel(self) -> bool:
|
||||
"""Block while paused, then check if cancelled.
|
||||
|
||||
Call this at checkpoint locations in scan loops. It will:
|
||||
1. Block indefinitely while paused (until resume or cancel)
|
||||
2. Return True if cancelled, False to continue
|
||||
|
||||
Returns:
|
||||
True if scan should stop, False to continue
|
||||
"""
|
||||
if not self._run_gate.is_set():
|
||||
self._emit_event("assets.seed.paused", {})
|
||||
self._run_gate.wait() # Blocks if paused
|
||||
return self._is_cancelled()
|
||||
|
||||
def _emit_event(self, event_type: str, data: dict) -> None:
|
||||
"""Emit a WebSocket event if server is available."""
|
||||
try:
|
||||
from server import PromptServer
|
||||
|
||||
if hasattr(PromptServer, "instance") and PromptServer.instance:
|
||||
PromptServer.instance.send_sync(event_type, data)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _update_progress(
|
||||
self,
|
||||
scanned: int | None = None,
|
||||
total: int | None = None,
|
||||
created: int | None = None,
|
||||
skipped: int | None = None,
|
||||
) -> None:
|
||||
"""Update progress counters (thread-safe)."""
|
||||
callback: ProgressCallback | None = None
|
||||
progress: Progress | None = None
|
||||
|
||||
with self._lock:
|
||||
if self._progress is None:
|
||||
return
|
||||
if scanned is not None:
|
||||
self._progress.scanned = scanned
|
||||
if total is not None:
|
||||
self._progress.total = total
|
||||
if created is not None:
|
||||
self._progress.created = created
|
||||
if skipped is not None:
|
||||
self._progress.skipped = skipped
|
||||
if self._progress_callback:
|
||||
callback = self._progress_callback
|
||||
progress = Progress(
|
||||
scanned=self._progress.scanned,
|
||||
total=self._progress.total,
|
||||
created=self._progress.created,
|
||||
skipped=self._progress.skipped,
|
||||
)
|
||||
|
||||
if callback and progress:
|
||||
try:
|
||||
callback(progress)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
_MAX_ERRORS = 200
|
||||
|
||||
def _add_error(self, message: str) -> None:
|
||||
"""Add an error message (thread-safe), capped at _MAX_ERRORS."""
|
||||
with self._lock:
|
||||
if len(self._errors) < self._MAX_ERRORS:
|
||||
self._errors.append(message)
|
||||
|
||||
def _log_scan_config(self, roots: tuple[RootType, ...]) -> None:
|
||||
"""Log the directories that will be scanned."""
|
||||
import folder_paths
|
||||
|
||||
for root in roots:
|
||||
if root == "models":
|
||||
logging.info(
|
||||
"Asset scan [models] directory: %s",
|
||||
os.path.abspath(folder_paths.models_dir),
|
||||
)
|
||||
else:
|
||||
prefixes = get_prefixes_for_root(root)
|
||||
if prefixes:
|
||||
logging.info("Asset scan [%s] directories: %s", root, prefixes)
|
||||
|
||||
def _run_scan(self) -> None:
|
||||
"""Main scan loop running in background thread."""
|
||||
t_start = time.perf_counter()
|
||||
roots = self._roots
|
||||
phase = self._phase
|
||||
cancelled = False
|
||||
total_created = 0
|
||||
total_enriched = 0
|
||||
skipped_existing = 0
|
||||
total_paths = 0
|
||||
|
||||
try:
|
||||
if not dependencies_available():
|
||||
self._add_error("Database dependencies not available")
|
||||
self._emit_event(
|
||||
"assets.seed.error",
|
||||
{"message": "Database dependencies not available"},
|
||||
)
|
||||
return
|
||||
|
||||
if self._prune_first:
|
||||
all_prefixes = get_all_known_prefixes()
|
||||
marked = mark_missing_outside_prefixes_safely(all_prefixes)
|
||||
if marked > 0:
|
||||
logging.info("Marked %d refs as missing before scan", marked)
|
||||
|
||||
if self._check_pause_and_cancel():
|
||||
logging.info("Asset scan cancelled after pruning phase")
|
||||
cancelled = True
|
||||
return
|
||||
|
||||
self._log_scan_config(roots)
|
||||
|
||||
# Phase 1: Fast scan (stub records)
|
||||
if phase in (ScanPhase.FAST, ScanPhase.FULL):
|
||||
created, skipped, paths = self._run_fast_phase(roots)
|
||||
total_created, skipped_existing, total_paths = created, skipped, paths
|
||||
|
||||
if self._check_pause_and_cancel():
|
||||
cancelled = True
|
||||
return
|
||||
|
||||
self._emit_event(
|
||||
"assets.seed.fast_complete",
|
||||
{
|
||||
"roots": list(roots),
|
||||
"created": total_created,
|
||||
"skipped": skipped_existing,
|
||||
"total": total_paths,
|
||||
},
|
||||
)
|
||||
|
||||
# Phase 2: Enrichment scan (metadata + hashes)
|
||||
if phase in (ScanPhase.ENRICH, ScanPhase.FULL):
|
||||
if self._check_pause_and_cancel():
|
||||
cancelled = True
|
||||
return
|
||||
|
||||
enrich_cancelled, total_enriched = self._run_enrich_phase(roots)
|
||||
|
||||
if enrich_cancelled:
|
||||
cancelled = True
|
||||
return
|
||||
|
||||
self._emit_event(
|
||||
"assets.seed.enrich_complete",
|
||||
{
|
||||
"roots": list(roots),
|
||||
"enriched": total_enriched,
|
||||
},
|
||||
)
|
||||
|
||||
elapsed = time.perf_counter() - t_start
|
||||
logging.info(
|
||||
"Scan(%s, %s) done %.3fs: created=%d enriched=%d skipped=%d",
|
||||
roots,
|
||||
phase.value,
|
||||
elapsed,
|
||||
total_created,
|
||||
total_enriched,
|
||||
skipped_existing,
|
||||
)
|
||||
|
||||
self._emit_event(
|
||||
"assets.seed.completed",
|
||||
{
|
||||
"phase": phase.value,
|
||||
"total": total_paths,
|
||||
"created": total_created,
|
||||
"enriched": total_enriched,
|
||||
"skipped": skipped_existing,
|
||||
"elapsed": round(elapsed, 3),
|
||||
},
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
self._add_error(f"Scan failed: {e}")
|
||||
logging.exception("Asset scan failed")
|
||||
self._emit_event("assets.seed.error", {"message": str(e)})
|
||||
finally:
|
||||
if cancelled:
|
||||
self._emit_event(
|
||||
"assets.seed.cancelled",
|
||||
{
|
||||
"scanned": self._progress.scanned if self._progress else 0,
|
||||
"total": total_paths,
|
||||
"created": total_created,
|
||||
},
|
||||
)
|
||||
with self._lock:
|
||||
self._last_progress = self._progress
|
||||
self._state = State.IDLE
|
||||
self._progress = None
|
||||
|
||||
def _run_fast_phase(self, roots: tuple[RootType, ...]) -> tuple[int, int, int]:
|
||||
"""Run phase 1: fast scan to create stub records.
|
||||
|
||||
Returns:
|
||||
Tuple of (total_created, skipped_existing, total_paths)
|
||||
"""
|
||||
t_fast_start = time.perf_counter()
|
||||
total_created = 0
|
||||
skipped_existing = 0
|
||||
|
||||
existing_paths: set[str] = set()
|
||||
t_sync = time.perf_counter()
|
||||
for r in roots:
|
||||
if self._check_pause_and_cancel():
|
||||
return total_created, skipped_existing, 0
|
||||
existing_paths.update(sync_root_safely(r))
|
||||
logging.debug(
|
||||
"Fast scan: sync_root phase took %.3fs (%d existing paths)",
|
||||
time.perf_counter() - t_sync,
|
||||
len(existing_paths),
|
||||
)
|
||||
|
||||
if self._check_pause_and_cancel():
|
||||
return total_created, skipped_existing, 0
|
||||
|
||||
t_collect = time.perf_counter()
|
||||
paths = collect_paths_for_roots(roots)
|
||||
logging.debug(
|
||||
"Fast scan: collect_paths took %.3fs (%d paths found)",
|
||||
time.perf_counter() - t_collect,
|
||||
len(paths),
|
||||
)
|
||||
total_paths = len(paths)
|
||||
self._update_progress(total=total_paths)
|
||||
|
||||
self._emit_event(
|
||||
"assets.seed.started",
|
||||
{"roots": list(roots), "total": total_paths, "phase": "fast"},
|
||||
)
|
||||
|
||||
# Use stub specs (no metadata extraction, no hashing)
|
||||
t_specs = time.perf_counter()
|
||||
specs, tag_pool, skipped_existing = build_asset_specs(
|
||||
paths,
|
||||
existing_paths,
|
||||
enable_metadata_extraction=False,
|
||||
compute_hashes=False,
|
||||
)
|
||||
logging.debug(
|
||||
"Fast scan: build_asset_specs took %.3fs (%d specs, %d skipped)",
|
||||
time.perf_counter() - t_specs,
|
||||
len(specs),
|
||||
skipped_existing,
|
||||
)
|
||||
self._update_progress(skipped=skipped_existing)
|
||||
|
||||
if self._check_pause_and_cancel():
|
||||
return total_created, skipped_existing, total_paths
|
||||
|
||||
batch_size = 500
|
||||
last_progress_time = time.perf_counter()
|
||||
progress_interval = 1.0
|
||||
|
||||
for i in range(0, len(specs), batch_size):
|
||||
if self._check_pause_and_cancel():
|
||||
logging.info(
|
||||
"Fast scan cancelled after %d/%d files (created=%d)",
|
||||
i,
|
||||
len(specs),
|
||||
total_created,
|
||||
)
|
||||
return total_created, skipped_existing, total_paths
|
||||
|
||||
batch = specs[i : i + batch_size]
|
||||
batch_tags = {t for spec in batch for t in spec["tags"]}
|
||||
try:
|
||||
created = insert_asset_specs(batch, batch_tags)
|
||||
total_created += created
|
||||
except Exception as e:
|
||||
self._add_error(f"Batch insert failed at offset {i}: {e}")
|
||||
logging.exception("Batch insert failed at offset %d", i)
|
||||
|
||||
scanned = i + len(batch)
|
||||
now = time.perf_counter()
|
||||
self._update_progress(scanned=scanned, created=total_created)
|
||||
|
||||
if now - last_progress_time >= progress_interval:
|
||||
self._emit_event(
|
||||
"assets.seed.progress",
|
||||
{
|
||||
"phase": "fast",
|
||||
"scanned": scanned,
|
||||
"total": len(specs),
|
||||
"created": total_created,
|
||||
},
|
||||
)
|
||||
last_progress_time = now
|
||||
|
||||
self._update_progress(scanned=len(specs), created=total_created)
|
||||
logging.info(
|
||||
"Fast scan complete: %.3fs total (created=%d, skipped=%d, total_paths=%d)",
|
||||
time.perf_counter() - t_fast_start,
|
||||
total_created,
|
||||
skipped_existing,
|
||||
total_paths,
|
||||
)
|
||||
return total_created, skipped_existing, total_paths
|
||||
|
||||
def _run_enrich_phase(self, roots: tuple[RootType, ...]) -> tuple[bool, int]:
|
||||
"""Run phase 2: enrich existing records with metadata and hashes.
|
||||
|
||||
Returns:
|
||||
Tuple of (cancelled, total_enriched)
|
||||
"""
|
||||
total_enriched = 0
|
||||
batch_size = 100
|
||||
last_progress_time = time.perf_counter()
|
||||
progress_interval = 1.0
|
||||
|
||||
# Get the target enrichment level based on compute_hashes
|
||||
if not self._compute_hashes:
|
||||
target_max_level = ENRICHMENT_STUB
|
||||
else:
|
||||
target_max_level = ENRICHMENT_METADATA
|
||||
|
||||
self._emit_event(
|
||||
"assets.seed.started",
|
||||
{"roots": list(roots), "phase": "enrich"},
|
||||
)
|
||||
|
||||
skip_ids: set[str] = set()
|
||||
consecutive_empty = 0
|
||||
max_consecutive_empty = 3
|
||||
|
||||
# Hash checkpoints survive across batches so interrupted hashes
|
||||
# can be resumed without re-reading the entire file.
|
||||
hash_checkpoints: dict[str, object] = {}
|
||||
|
||||
while True:
|
||||
if self._check_pause_and_cancel():
|
||||
logging.info("Enrich scan cancelled after %d assets", total_enriched)
|
||||
return True, total_enriched
|
||||
|
||||
# Fetch next batch of unenriched assets
|
||||
unenriched = get_unenriched_assets_for_roots(
|
||||
roots,
|
||||
max_level=target_max_level,
|
||||
limit=batch_size,
|
||||
)
|
||||
|
||||
# Filter out previously failed references
|
||||
if skip_ids:
|
||||
unenriched = [r for r in unenriched if r.reference_id not in skip_ids]
|
||||
|
||||
if not unenriched:
|
||||
break
|
||||
|
||||
enriched, failed_ids = enrich_assets_batch(
|
||||
unenriched,
|
||||
extract_metadata=True,
|
||||
compute_hash=self._compute_hashes,
|
||||
interrupt_check=self._is_paused_or_cancelled,
|
||||
hash_checkpoints=hash_checkpoints,
|
||||
)
|
||||
total_enriched += enriched
|
||||
skip_ids.update(failed_ids)
|
||||
|
||||
if enriched == 0:
|
||||
consecutive_empty += 1
|
||||
if consecutive_empty >= max_consecutive_empty:
|
||||
logging.warning(
|
||||
"Enrich phase stopping: %d consecutive batches with no progress (%d skipped)",
|
||||
consecutive_empty,
|
||||
len(skip_ids),
|
||||
)
|
||||
break
|
||||
else:
|
||||
consecutive_empty = 0
|
||||
|
||||
now = time.perf_counter()
|
||||
if now - last_progress_time >= progress_interval:
|
||||
self._emit_event(
|
||||
"assets.seed.progress",
|
||||
{
|
||||
"phase": "enrich",
|
||||
"enriched": total_enriched,
|
||||
},
|
||||
)
|
||||
last_progress_time = now
|
||||
|
||||
return False, total_enriched
|
||||
|
||||
|
||||
asset_seeder = _AssetSeeder()
|
||||
87
app/assets/services/__init__.py
Normal file
87
app/assets/services/__init__.py
Normal file
@@ -0,0 +1,87 @@
|
||||
from app.assets.services.asset_management import (
|
||||
asset_exists,
|
||||
delete_asset_reference,
|
||||
get_asset_by_hash,
|
||||
get_asset_detail,
|
||||
list_assets_page,
|
||||
resolve_asset_for_download,
|
||||
set_asset_preview,
|
||||
update_asset_metadata,
|
||||
)
|
||||
from app.assets.services.bulk_ingest import (
|
||||
BulkInsertResult,
|
||||
batch_insert_seed_assets,
|
||||
cleanup_unreferenced_assets,
|
||||
)
|
||||
from app.assets.services.file_utils import (
|
||||
get_mtime_ns,
|
||||
get_size_and_mtime_ns,
|
||||
list_files_recursively,
|
||||
verify_file_unchanged,
|
||||
)
|
||||
from app.assets.services.ingest import (
|
||||
DependencyMissingError,
|
||||
HashMismatchError,
|
||||
create_from_hash,
|
||||
upload_from_temp_path,
|
||||
)
|
||||
from app.assets.database.queries import (
|
||||
AddTagsResult,
|
||||
RemoveTagsResult,
|
||||
)
|
||||
from app.assets.services.schemas import (
|
||||
AssetData,
|
||||
AssetDetailResult,
|
||||
AssetSummaryData,
|
||||
DownloadResolutionResult,
|
||||
IngestResult,
|
||||
ListAssetsResult,
|
||||
ReferenceData,
|
||||
RegisterAssetResult,
|
||||
TagUsage,
|
||||
UploadResult,
|
||||
UserMetadata,
|
||||
)
|
||||
from app.assets.services.tagging import (
|
||||
apply_tags,
|
||||
list_tags,
|
||||
remove_tags,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AddTagsResult",
|
||||
"AssetData",
|
||||
"AssetDetailResult",
|
||||
"AssetSummaryData",
|
||||
"ReferenceData",
|
||||
"BulkInsertResult",
|
||||
"DependencyMissingError",
|
||||
"DownloadResolutionResult",
|
||||
"HashMismatchError",
|
||||
"IngestResult",
|
||||
"ListAssetsResult",
|
||||
"RegisterAssetResult",
|
||||
"RemoveTagsResult",
|
||||
"TagUsage",
|
||||
"UploadResult",
|
||||
"UserMetadata",
|
||||
"apply_tags",
|
||||
"asset_exists",
|
||||
"batch_insert_seed_assets",
|
||||
"create_from_hash",
|
||||
"delete_asset_reference",
|
||||
"get_asset_by_hash",
|
||||
"get_asset_detail",
|
||||
"get_mtime_ns",
|
||||
"get_size_and_mtime_ns",
|
||||
"list_assets_page",
|
||||
"list_files_recursively",
|
||||
"list_tags",
|
||||
"cleanup_unreferenced_assets",
|
||||
"remove_tags",
|
||||
"resolve_asset_for_download",
|
||||
"set_asset_preview",
|
||||
"update_asset_metadata",
|
||||
"upload_from_temp_path",
|
||||
"verify_file_unchanged",
|
||||
]
|
||||
309
app/assets/services/asset_management.py
Normal file
309
app/assets/services/asset_management.py
Normal file
@@ -0,0 +1,309 @@
|
||||
import contextlib
|
||||
import mimetypes
|
||||
import os
|
||||
from typing import Sequence
|
||||
|
||||
|
||||
from app.assets.database.models import Asset
|
||||
from app.assets.database.queries import (
|
||||
asset_exists_by_hash,
|
||||
reference_exists_for_asset_id,
|
||||
delete_reference_by_id,
|
||||
fetch_reference_and_asset,
|
||||
soft_delete_reference_by_id,
|
||||
fetch_reference_asset_and_tags,
|
||||
get_asset_by_hash as queries_get_asset_by_hash,
|
||||
get_reference_by_id,
|
||||
get_reference_with_owner_check,
|
||||
list_references_page,
|
||||
list_references_by_asset_id,
|
||||
set_reference_metadata,
|
||||
set_reference_preview,
|
||||
set_reference_tags,
|
||||
update_reference_access_time,
|
||||
update_reference_name,
|
||||
update_reference_updated_at,
|
||||
)
|
||||
from app.assets.helpers import select_best_live_path
|
||||
from app.assets.services.path_utils import compute_relative_filename
|
||||
from app.assets.services.schemas import (
|
||||
AssetData,
|
||||
AssetDetailResult,
|
||||
AssetSummaryData,
|
||||
DownloadResolutionResult,
|
||||
ListAssetsResult,
|
||||
UserMetadata,
|
||||
extract_asset_data,
|
||||
extract_reference_data,
|
||||
)
|
||||
from app.database.db import create_session
|
||||
|
||||
|
||||
def get_asset_detail(
|
||||
reference_id: str,
|
||||
owner_id: str = "",
|
||||
) -> AssetDetailResult | None:
|
||||
with create_session() as session:
|
||||
result = fetch_reference_asset_and_tags(
|
||||
session,
|
||||
reference_id=reference_id,
|
||||
owner_id=owner_id,
|
||||
)
|
||||
if not result:
|
||||
return None
|
||||
|
||||
ref, asset, tags = result
|
||||
return AssetDetailResult(
|
||||
ref=extract_reference_data(ref),
|
||||
asset=extract_asset_data(asset),
|
||||
tags=tags,
|
||||
)
|
||||
|
||||
|
||||
def update_asset_metadata(
|
||||
reference_id: str,
|
||||
name: str | None = None,
|
||||
tags: Sequence[str] | None = None,
|
||||
user_metadata: UserMetadata = None,
|
||||
tag_origin: str = "manual",
|
||||
owner_id: str = "",
|
||||
) -> AssetDetailResult:
|
||||
with create_session() as session:
|
||||
ref = get_reference_with_owner_check(session, reference_id, owner_id)
|
||||
|
||||
touched = False
|
||||
if name is not None and name != ref.name:
|
||||
update_reference_name(session, reference_id=reference_id, name=name)
|
||||
touched = True
|
||||
|
||||
computed_filename = compute_relative_filename(ref.file_path) if ref.file_path else None
|
||||
|
||||
new_meta: dict | None = None
|
||||
if user_metadata is not None:
|
||||
new_meta = dict(user_metadata)
|
||||
elif computed_filename:
|
||||
current_meta = ref.user_metadata or {}
|
||||
if current_meta.get("filename") != computed_filename:
|
||||
new_meta = dict(current_meta)
|
||||
|
||||
if new_meta is not None:
|
||||
if computed_filename:
|
||||
new_meta["filename"] = computed_filename
|
||||
set_reference_metadata(
|
||||
session, reference_id=reference_id, user_metadata=new_meta
|
||||
)
|
||||
touched = True
|
||||
|
||||
if tags is not None:
|
||||
set_reference_tags(
|
||||
session,
|
||||
reference_id=reference_id,
|
||||
tags=tags,
|
||||
origin=tag_origin,
|
||||
)
|
||||
touched = True
|
||||
|
||||
if touched and user_metadata is None:
|
||||
update_reference_updated_at(session, reference_id=reference_id)
|
||||
|
||||
result = fetch_reference_asset_and_tags(
|
||||
session,
|
||||
reference_id=reference_id,
|
||||
owner_id=owner_id,
|
||||
)
|
||||
if not result:
|
||||
raise RuntimeError("State changed during update")
|
||||
|
||||
ref, asset, tag_list = result
|
||||
detail = AssetDetailResult(
|
||||
ref=extract_reference_data(ref),
|
||||
asset=extract_asset_data(asset),
|
||||
tags=tag_list,
|
||||
)
|
||||
session.commit()
|
||||
|
||||
return detail
|
||||
|
||||
|
||||
def delete_asset_reference(
|
||||
reference_id: str,
|
||||
owner_id: str,
|
||||
delete_content_if_orphan: bool = True,
|
||||
) -> bool:
|
||||
with create_session() as session:
|
||||
if not delete_content_if_orphan:
|
||||
# Soft delete: mark the reference as deleted but keep everything
|
||||
deleted = soft_delete_reference_by_id(
|
||||
session, reference_id=reference_id, owner_id=owner_id
|
||||
)
|
||||
session.commit()
|
||||
return deleted
|
||||
|
||||
ref_row = get_reference_by_id(session, reference_id=reference_id)
|
||||
asset_id = ref_row.asset_id if ref_row else None
|
||||
file_path = ref_row.file_path if ref_row else None
|
||||
|
||||
deleted = delete_reference_by_id(
|
||||
session, reference_id=reference_id, owner_id=owner_id
|
||||
)
|
||||
if not deleted:
|
||||
session.commit()
|
||||
return False
|
||||
|
||||
if not asset_id:
|
||||
session.commit()
|
||||
return True
|
||||
|
||||
still_exists = reference_exists_for_asset_id(session, asset_id=asset_id)
|
||||
if still_exists:
|
||||
session.commit()
|
||||
return True
|
||||
|
||||
# Orphaned asset - delete it and its files
|
||||
refs = list_references_by_asset_id(session, asset_id=asset_id)
|
||||
file_paths = [
|
||||
r.file_path for r in (refs or []) if getattr(r, "file_path", None)
|
||||
]
|
||||
# Also include the just-deleted file path
|
||||
if file_path:
|
||||
file_paths.append(file_path)
|
||||
|
||||
asset_row = session.get(Asset, asset_id)
|
||||
if asset_row is not None:
|
||||
session.delete(asset_row)
|
||||
|
||||
session.commit()
|
||||
|
||||
# Delete files after commit
|
||||
for p in file_paths:
|
||||
with contextlib.suppress(Exception):
|
||||
if p and os.path.isfile(p):
|
||||
os.remove(p)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def set_asset_preview(
|
||||
reference_id: str,
|
||||
preview_asset_id: str | None = None,
|
||||
owner_id: str = "",
|
||||
) -> AssetDetailResult:
|
||||
with create_session() as session:
|
||||
get_reference_with_owner_check(session, reference_id, owner_id)
|
||||
|
||||
set_reference_preview(
|
||||
session,
|
||||
reference_id=reference_id,
|
||||
preview_asset_id=preview_asset_id,
|
||||
)
|
||||
|
||||
result = fetch_reference_asset_and_tags(
|
||||
session, reference_id=reference_id, owner_id=owner_id
|
||||
)
|
||||
if not result:
|
||||
raise RuntimeError("State changed during preview update")
|
||||
|
||||
ref, asset, tags = result
|
||||
detail = AssetDetailResult(
|
||||
ref=extract_reference_data(ref),
|
||||
asset=extract_asset_data(asset),
|
||||
tags=tags,
|
||||
)
|
||||
session.commit()
|
||||
|
||||
return detail
|
||||
|
||||
|
||||
def asset_exists(asset_hash: str) -> bool:
|
||||
with create_session() as session:
|
||||
return asset_exists_by_hash(session, asset_hash=asset_hash)
|
||||
|
||||
|
||||
def get_asset_by_hash(asset_hash: str) -> AssetData | None:
|
||||
with create_session() as session:
|
||||
asset = queries_get_asset_by_hash(session, asset_hash=asset_hash)
|
||||
return extract_asset_data(asset)
|
||||
|
||||
|
||||
def list_assets_page(
|
||||
owner_id: str = "",
|
||||
include_tags: Sequence[str] | None = None,
|
||||
exclude_tags: Sequence[str] | None = None,
|
||||
name_contains: str | None = None,
|
||||
metadata_filter: dict | None = None,
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
sort: str = "created_at",
|
||||
order: str = "desc",
|
||||
) -> ListAssetsResult:
|
||||
with create_session() as session:
|
||||
refs, tag_map, total = list_references_page(
|
||||
session,
|
||||
owner_id=owner_id,
|
||||
include_tags=include_tags,
|
||||
exclude_tags=exclude_tags,
|
||||
name_contains=name_contains,
|
||||
metadata_filter=metadata_filter,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
sort=sort,
|
||||
order=order,
|
||||
)
|
||||
|
||||
items: list[AssetSummaryData] = []
|
||||
for ref in refs:
|
||||
items.append(
|
||||
AssetSummaryData(
|
||||
ref=extract_reference_data(ref),
|
||||
asset=extract_asset_data(ref.asset),
|
||||
tags=tag_map.get(ref.id, []),
|
||||
)
|
||||
)
|
||||
|
||||
return ListAssetsResult(items=items, total=total)
|
||||
|
||||
|
||||
def resolve_asset_for_download(
|
||||
reference_id: str,
|
||||
owner_id: str = "",
|
||||
) -> DownloadResolutionResult:
|
||||
with create_session() as session:
|
||||
pair = fetch_reference_and_asset(
|
||||
session, reference_id=reference_id, owner_id=owner_id
|
||||
)
|
||||
if not pair:
|
||||
raise ValueError(f"AssetReference {reference_id} not found")
|
||||
|
||||
ref, asset = pair
|
||||
|
||||
# For references with file_path, use that directly
|
||||
if ref.file_path and os.path.isfile(ref.file_path):
|
||||
abs_path = ref.file_path
|
||||
else:
|
||||
# For API-created refs without file_path, find a path from other refs
|
||||
refs = list_references_by_asset_id(session, asset_id=asset.id)
|
||||
abs_path = select_best_live_path(refs)
|
||||
if not abs_path:
|
||||
raise FileNotFoundError(
|
||||
f"No live path for AssetReference {reference_id} "
|
||||
f"(asset id={asset.id}, name={ref.name})"
|
||||
)
|
||||
|
||||
# Capture ORM attributes before commit (commit expires loaded objects)
|
||||
ref_name = ref.name
|
||||
asset_mime = asset.mime_type
|
||||
|
||||
update_reference_access_time(session, reference_id=reference_id)
|
||||
session.commit()
|
||||
|
||||
ctype = (
|
||||
asset_mime
|
||||
or mimetypes.guess_type(ref_name or abs_path)[0]
|
||||
or "application/octet-stream"
|
||||
)
|
||||
download_name = ref_name or os.path.basename(abs_path)
|
||||
return DownloadResolutionResult(
|
||||
abs_path=abs_path,
|
||||
content_type=ctype,
|
||||
download_name=download_name,
|
||||
)
|
||||
280
app/assets/services/bulk_ingest.py
Normal file
280
app/assets/services/bulk_ingest.py
Normal file
@@ -0,0 +1,280 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Any, TypedDict
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.assets.database.queries import (
|
||||
bulk_insert_assets,
|
||||
bulk_insert_references_ignore_conflicts,
|
||||
bulk_insert_tags_and_meta,
|
||||
delete_assets_by_ids,
|
||||
get_existing_asset_ids,
|
||||
get_reference_ids_by_ids,
|
||||
get_references_by_paths_and_asset_ids,
|
||||
get_unreferenced_unhashed_asset_ids,
|
||||
restore_references_by_paths,
|
||||
)
|
||||
from app.assets.helpers import get_utc_now
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.assets.services.metadata_extract import ExtractedMetadata
|
||||
|
||||
|
||||
class SeedAssetSpec(TypedDict):
|
||||
"""Spec for seeding an asset from filesystem."""
|
||||
|
||||
abs_path: str
|
||||
size_bytes: int
|
||||
mtime_ns: int
|
||||
info_name: str
|
||||
tags: list[str]
|
||||
fname: str
|
||||
metadata: ExtractedMetadata | None
|
||||
hash: str | None
|
||||
mime_type: str | None
|
||||
|
||||
|
||||
class AssetRow(TypedDict):
|
||||
"""Row data for inserting an Asset."""
|
||||
|
||||
id: str
|
||||
hash: str | None
|
||||
size_bytes: int
|
||||
mime_type: str | None
|
||||
created_at: datetime
|
||||
|
||||
|
||||
class ReferenceRow(TypedDict):
|
||||
"""Row data for inserting an AssetReference."""
|
||||
|
||||
id: str
|
||||
asset_id: str
|
||||
file_path: str
|
||||
mtime_ns: int
|
||||
owner_id: str
|
||||
name: str
|
||||
preview_id: str | None
|
||||
user_metadata: dict[str, Any] | None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
last_access_time: datetime
|
||||
|
||||
|
||||
class TagRow(TypedDict):
|
||||
"""Row data for inserting a Tag."""
|
||||
|
||||
asset_reference_id: str
|
||||
tag_name: str
|
||||
origin: str
|
||||
added_at: datetime
|
||||
|
||||
|
||||
class MetadataRow(TypedDict):
|
||||
"""Row data for inserting asset metadata."""
|
||||
|
||||
asset_reference_id: str
|
||||
key: str
|
||||
ordinal: int
|
||||
val_str: str | None
|
||||
val_num: float | None
|
||||
val_bool: bool | None
|
||||
val_json: dict[str, Any] | None
|
||||
|
||||
|
||||
@dataclass
|
||||
class BulkInsertResult:
|
||||
"""Result of bulk asset insertion."""
|
||||
|
||||
inserted_refs: int
|
||||
won_paths: int
|
||||
lost_paths: int
|
||||
|
||||
|
||||
def batch_insert_seed_assets(
|
||||
session: Session,
|
||||
specs: list[SeedAssetSpec],
|
||||
owner_id: str = "",
|
||||
) -> BulkInsertResult:
|
||||
"""Seed assets from filesystem specs in batch.
|
||||
|
||||
Each spec is a dict with keys:
|
||||
- abs_path: str
|
||||
- size_bytes: int
|
||||
- mtime_ns: int
|
||||
- info_name: str
|
||||
- tags: list[str]
|
||||
- fname: Optional[str]
|
||||
|
||||
This function orchestrates:
|
||||
1. Insert seed Assets (hash=NULL)
|
||||
2. Claim references with ON CONFLICT DO NOTHING on file_path
|
||||
3. Query to find winners (paths where our asset_id was inserted)
|
||||
4. Delete Assets for losers (path already claimed by another asset)
|
||||
5. Insert tags and metadata for successfully inserted references
|
||||
|
||||
Returns:
|
||||
BulkInsertResult with inserted_refs, won_paths, lost_paths
|
||||
"""
|
||||
if not specs:
|
||||
return BulkInsertResult(inserted_refs=0, won_paths=0, lost_paths=0)
|
||||
|
||||
current_time = get_utc_now()
|
||||
asset_rows: list[AssetRow] = []
|
||||
reference_rows: list[ReferenceRow] = []
|
||||
path_to_asset_id: dict[str, str] = {}
|
||||
asset_id_to_ref_data: dict[str, dict] = {}
|
||||
absolute_path_list: list[str] = []
|
||||
|
||||
for spec in specs:
|
||||
absolute_path = os.path.abspath(spec["abs_path"])
|
||||
asset_id = str(uuid.uuid4())
|
||||
reference_id = str(uuid.uuid4())
|
||||
absolute_path_list.append(absolute_path)
|
||||
path_to_asset_id[absolute_path] = asset_id
|
||||
|
||||
mime_type = spec.get("mime_type")
|
||||
asset_rows.append(
|
||||
{
|
||||
"id": asset_id,
|
||||
"hash": spec.get("hash"),
|
||||
"size_bytes": spec["size_bytes"],
|
||||
"mime_type": mime_type,
|
||||
"created_at": current_time,
|
||||
}
|
||||
)
|
||||
|
||||
# Build user_metadata from extracted metadata or fallback to filename
|
||||
extracted_metadata = spec.get("metadata")
|
||||
if extracted_metadata:
|
||||
user_metadata: dict[str, Any] | None = extracted_metadata.to_user_metadata()
|
||||
elif spec["fname"]:
|
||||
user_metadata = {"filename": spec["fname"]}
|
||||
else:
|
||||
user_metadata = None
|
||||
|
||||
reference_rows.append(
|
||||
{
|
||||
"id": reference_id,
|
||||
"asset_id": asset_id,
|
||||
"file_path": absolute_path,
|
||||
"mtime_ns": spec["mtime_ns"],
|
||||
"owner_id": owner_id,
|
||||
"name": spec["info_name"],
|
||||
"preview_id": None,
|
||||
"user_metadata": user_metadata,
|
||||
"created_at": current_time,
|
||||
"updated_at": current_time,
|
||||
"last_access_time": current_time,
|
||||
}
|
||||
)
|
||||
|
||||
asset_id_to_ref_data[asset_id] = {
|
||||
"reference_id": reference_id,
|
||||
"tags": spec["tags"],
|
||||
"filename": spec["fname"],
|
||||
"extracted_metadata": extracted_metadata,
|
||||
}
|
||||
|
||||
bulk_insert_assets(session, asset_rows)
|
||||
|
||||
# Filter reference rows to only those whose assets were actually inserted
|
||||
# (assets with duplicate hashes are silently dropped by ON CONFLICT DO NOTHING)
|
||||
inserted_asset_ids = get_existing_asset_ids(
|
||||
session, [r["asset_id"] for r in reference_rows]
|
||||
)
|
||||
reference_rows = [r for r in reference_rows if r["asset_id"] in inserted_asset_ids]
|
||||
|
||||
bulk_insert_references_ignore_conflicts(session, reference_rows)
|
||||
restore_references_by_paths(session, absolute_path_list)
|
||||
winning_paths = get_references_by_paths_and_asset_ids(session, path_to_asset_id)
|
||||
|
||||
inserted_paths = {
|
||||
path
|
||||
for path in absolute_path_list
|
||||
if path_to_asset_id[path] in inserted_asset_ids
|
||||
}
|
||||
losing_paths = inserted_paths - winning_paths
|
||||
lost_asset_ids = [path_to_asset_id[path] for path in losing_paths]
|
||||
|
||||
if lost_asset_ids:
|
||||
delete_assets_by_ids(session, lost_asset_ids)
|
||||
|
||||
if not winning_paths:
|
||||
return BulkInsertResult(
|
||||
inserted_refs=0,
|
||||
won_paths=0,
|
||||
lost_paths=len(losing_paths),
|
||||
)
|
||||
|
||||
# Get reference IDs for winners
|
||||
winning_ref_ids = [
|
||||
asset_id_to_ref_data[path_to_asset_id[path]]["reference_id"]
|
||||
for path in winning_paths
|
||||
]
|
||||
inserted_ref_ids = get_reference_ids_by_ids(session, winning_ref_ids)
|
||||
|
||||
tag_rows: list[TagRow] = []
|
||||
metadata_rows: list[MetadataRow] = []
|
||||
|
||||
if inserted_ref_ids:
|
||||
for path in winning_paths:
|
||||
asset_id = path_to_asset_id[path]
|
||||
ref_data = asset_id_to_ref_data[asset_id]
|
||||
ref_id = ref_data["reference_id"]
|
||||
|
||||
if ref_id not in inserted_ref_ids:
|
||||
continue
|
||||
|
||||
for tag in ref_data["tags"]:
|
||||
tag_rows.append(
|
||||
{
|
||||
"asset_reference_id": ref_id,
|
||||
"tag_name": tag,
|
||||
"origin": "automatic",
|
||||
"added_at": current_time,
|
||||
}
|
||||
)
|
||||
|
||||
# Use extracted metadata for meta rows if available
|
||||
extracted_metadata = ref_data.get("extracted_metadata")
|
||||
if extracted_metadata:
|
||||
metadata_rows.extend(extracted_metadata.to_meta_rows(ref_id))
|
||||
elif ref_data["filename"]:
|
||||
# Fallback: just store filename
|
||||
metadata_rows.append(
|
||||
{
|
||||
"asset_reference_id": ref_id,
|
||||
"key": "filename",
|
||||
"ordinal": 0,
|
||||
"val_str": ref_data["filename"],
|
||||
"val_num": None,
|
||||
"val_bool": None,
|
||||
"val_json": None,
|
||||
}
|
||||
)
|
||||
|
||||
bulk_insert_tags_and_meta(session, tag_rows=tag_rows, meta_rows=metadata_rows)
|
||||
|
||||
return BulkInsertResult(
|
||||
inserted_refs=len(inserted_ref_ids),
|
||||
won_paths=len(winning_paths),
|
||||
lost_paths=len(losing_paths),
|
||||
)
|
||||
|
||||
|
||||
def cleanup_unreferenced_assets(session: Session) -> int:
|
||||
"""Hard-delete unhashed assets with no active references.
|
||||
|
||||
This is a destructive operation intended for explicit cleanup.
|
||||
Only deletes assets where hash=None and all references are missing.
|
||||
|
||||
Returns:
|
||||
Number of assets deleted
|
||||
"""
|
||||
unreferenced_ids = get_unreferenced_unhashed_asset_ids(session)
|
||||
return delete_assets_by_ids(session, unreferenced_ids)
|
||||
70
app/assets/services/file_utils.py
Normal file
70
app/assets/services/file_utils.py
Normal file
@@ -0,0 +1,70 @@
|
||||
import os
|
||||
|
||||
|
||||
def get_mtime_ns(stat_result: os.stat_result) -> int:
|
||||
"""Extract mtime in nanoseconds from a stat result."""
|
||||
return getattr(
|
||||
stat_result, "st_mtime_ns", int(stat_result.st_mtime * 1_000_000_000)
|
||||
)
|
||||
|
||||
|
||||
def get_size_and_mtime_ns(path: str, follow_symlinks: bool = True) -> tuple[int, int]:
|
||||
"""Get file size in bytes and mtime in nanoseconds."""
|
||||
st = os.stat(path, follow_symlinks=follow_symlinks)
|
||||
return st.st_size, get_mtime_ns(st)
|
||||
|
||||
|
||||
def verify_file_unchanged(
|
||||
mtime_db: int | None,
|
||||
size_db: int | None,
|
||||
stat_result: os.stat_result,
|
||||
) -> bool:
|
||||
"""Check if a file is unchanged based on mtime and size.
|
||||
|
||||
Returns True if the file's mtime and size match the database values.
|
||||
Returns False if mtime_db is None or values don't match.
|
||||
|
||||
size_db=None means don't check size; 0 is a valid recorded size.
|
||||
"""
|
||||
if mtime_db is None:
|
||||
return False
|
||||
actual_mtime_ns = get_mtime_ns(stat_result)
|
||||
if int(mtime_db) != int(actual_mtime_ns):
|
||||
return False
|
||||
if size_db is not None:
|
||||
return int(stat_result.st_size) == int(size_db)
|
||||
return True
|
||||
|
||||
|
||||
def is_visible(name: str) -> bool:
|
||||
"""Return True if a file or directory name is visible (not hidden)."""
|
||||
return not name.startswith(".")
|
||||
|
||||
|
||||
def list_files_recursively(base_dir: str) -> list[str]:
|
||||
"""Recursively list all files in a directory, following symlinks."""
|
||||
out: list[str] = []
|
||||
base_abs = os.path.abspath(base_dir)
|
||||
if not os.path.isdir(base_abs):
|
||||
return out
|
||||
# Track seen real directory identities to prevent circular symlink loops
|
||||
seen_dirs: set[tuple[int, int]] = set()
|
||||
for dirpath, subdirs, filenames in os.walk(
|
||||
base_abs, topdown=True, followlinks=True
|
||||
):
|
||||
try:
|
||||
st = os.stat(dirpath)
|
||||
dir_id = (st.st_dev, st.st_ino)
|
||||
except OSError:
|
||||
subdirs.clear()
|
||||
continue
|
||||
if dir_id in seen_dirs:
|
||||
subdirs.clear()
|
||||
continue
|
||||
seen_dirs.add(dir_id)
|
||||
subdirs[:] = [d for d in subdirs if is_visible(d)]
|
||||
for name in filenames:
|
||||
if not is_visible(name):
|
||||
continue
|
||||
out.append(os.path.abspath(os.path.join(dirpath, name)))
|
||||
return out
|
||||
99
app/assets/services/hashing.py
Normal file
99
app/assets/services/hashing.py
Normal file
@@ -0,0 +1,99 @@
|
||||
import io
|
||||
import os
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from typing import IO, Any, Callable, Iterator
|
||||
import logging
|
||||
|
||||
try:
|
||||
from blake3 import blake3
|
||||
except ModuleNotFoundError:
|
||||
logging.warning("WARNING: blake3 package not installed")
|
||||
|
||||
DEFAULT_CHUNK = 8 * 1024 * 1024
|
||||
|
||||
InterruptCheck = Callable[[], bool]
|
||||
|
||||
|
||||
@dataclass
|
||||
class HashCheckpoint:
|
||||
"""Saved state for resuming an interrupted hash computation."""
|
||||
|
||||
bytes_processed: int
|
||||
hasher: Any # blake3 hasher instance
|
||||
mtime_ns: int = 0
|
||||
file_size: int = 0
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _open_for_hashing(fp: str | IO[bytes]) -> Iterator[tuple[IO[bytes], bool]]:
|
||||
"""Yield (file_object, is_path) with appropriate setup/teardown."""
|
||||
if hasattr(fp, "read"):
|
||||
seekable = getattr(fp, "seekable", lambda: False)()
|
||||
orig_pos = None
|
||||
if seekable:
|
||||
try:
|
||||
orig_pos = fp.tell()
|
||||
if orig_pos != 0:
|
||||
fp.seek(0)
|
||||
except io.UnsupportedOperation:
|
||||
orig_pos = None
|
||||
try:
|
||||
yield fp, False
|
||||
finally:
|
||||
if orig_pos is not None:
|
||||
fp.seek(orig_pos)
|
||||
else:
|
||||
with open(os.fspath(fp), "rb") as f:
|
||||
yield f, True
|
||||
|
||||
|
||||
def compute_blake3_hash(
|
||||
fp: str | IO[bytes],
|
||||
chunk_size: int = DEFAULT_CHUNK,
|
||||
interrupt_check: InterruptCheck | None = None,
|
||||
checkpoint: HashCheckpoint | None = None,
|
||||
) -> tuple[str | None, HashCheckpoint | None]:
|
||||
"""Compute BLAKE3 hash of a file, with optional checkpoint support.
|
||||
|
||||
Args:
|
||||
fp: File path or file-like object
|
||||
chunk_size: Size of chunks to read at a time
|
||||
interrupt_check: Optional callable that returns True if the operation
|
||||
should be interrupted (e.g. paused or cancelled). Must be
|
||||
non-blocking so file handles are released immediately. Checked
|
||||
between chunk reads.
|
||||
checkpoint: Optional checkpoint to resume from (file paths only)
|
||||
|
||||
Returns:
|
||||
Tuple of (hex_digest, None) on completion, or
|
||||
(None, checkpoint) on interruption (file paths only), or
|
||||
(None, None) on interruption of a file object
|
||||
"""
|
||||
if chunk_size <= 0:
|
||||
chunk_size = DEFAULT_CHUNK
|
||||
|
||||
with _open_for_hashing(fp) as (f, is_path):
|
||||
if checkpoint is not None and is_path:
|
||||
f.seek(checkpoint.bytes_processed)
|
||||
h = checkpoint.hasher
|
||||
bytes_processed = checkpoint.bytes_processed
|
||||
else:
|
||||
h = blake3()
|
||||
bytes_processed = 0
|
||||
|
||||
while True:
|
||||
if interrupt_check is not None and interrupt_check():
|
||||
if is_path:
|
||||
return None, HashCheckpoint(
|
||||
bytes_processed=bytes_processed,
|
||||
hasher=h,
|
||||
)
|
||||
return None, None
|
||||
chunk = f.read(chunk_size)
|
||||
if not chunk:
|
||||
break
|
||||
h.update(chunk)
|
||||
bytes_processed += len(chunk)
|
||||
|
||||
return h.hexdigest(), None
|
||||
375
app/assets/services/ingest.py
Normal file
375
app/assets/services/ingest.py
Normal file
@@ -0,0 +1,375 @@
|
||||
import contextlib
|
||||
import logging
|
||||
import mimetypes
|
||||
import os
|
||||
from typing import Any, Sequence
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
import app.assets.services.hashing as hashing
|
||||
from app.assets.database.queries import (
|
||||
add_tags_to_reference,
|
||||
fetch_reference_and_asset,
|
||||
get_asset_by_hash,
|
||||
get_existing_asset_ids,
|
||||
get_reference_by_file_path,
|
||||
get_reference_tags,
|
||||
get_or_create_reference,
|
||||
remove_missing_tag_for_asset_id,
|
||||
set_reference_metadata,
|
||||
set_reference_tags,
|
||||
upsert_asset,
|
||||
upsert_reference,
|
||||
validate_tags_exist,
|
||||
)
|
||||
from app.assets.helpers import normalize_tags
|
||||
from app.assets.services.file_utils import get_size_and_mtime_ns
|
||||
from app.assets.services.path_utils import (
|
||||
compute_relative_filename,
|
||||
resolve_destination_from_tags,
|
||||
validate_path_within_base,
|
||||
)
|
||||
from app.assets.services.schemas import (
|
||||
IngestResult,
|
||||
RegisterAssetResult,
|
||||
UploadResult,
|
||||
UserMetadata,
|
||||
extract_asset_data,
|
||||
extract_reference_data,
|
||||
)
|
||||
from app.database.db import create_session
|
||||
|
||||
|
||||
def _ingest_file_from_path(
|
||||
abs_path: str,
|
||||
asset_hash: str,
|
||||
size_bytes: int,
|
||||
mtime_ns: int,
|
||||
mime_type: str | None = None,
|
||||
info_name: str | None = None,
|
||||
owner_id: str = "",
|
||||
preview_id: str | None = None,
|
||||
user_metadata: UserMetadata = None,
|
||||
tags: Sequence[str] = (),
|
||||
tag_origin: str = "manual",
|
||||
require_existing_tags: bool = False,
|
||||
) -> IngestResult:
|
||||
locator = os.path.abspath(abs_path)
|
||||
user_metadata = user_metadata or {}
|
||||
|
||||
asset_created = False
|
||||
asset_updated = False
|
||||
ref_created = False
|
||||
ref_updated = False
|
||||
reference_id: str | None = None
|
||||
|
||||
with create_session() as session:
|
||||
if preview_id:
|
||||
if preview_id not in get_existing_asset_ids(session, [preview_id]):
|
||||
preview_id = None
|
||||
|
||||
asset, asset_created, asset_updated = upsert_asset(
|
||||
session,
|
||||
asset_hash=asset_hash,
|
||||
size_bytes=size_bytes,
|
||||
mime_type=mime_type,
|
||||
)
|
||||
|
||||
ref_created, ref_updated = upsert_reference(
|
||||
session,
|
||||
asset_id=asset.id,
|
||||
file_path=locator,
|
||||
name=info_name or os.path.basename(locator),
|
||||
mtime_ns=mtime_ns,
|
||||
owner_id=owner_id,
|
||||
)
|
||||
|
||||
# Get the reference we just created/updated
|
||||
ref = get_reference_by_file_path(session, locator)
|
||||
if ref:
|
||||
reference_id = ref.id
|
||||
|
||||
if preview_id and ref.preview_id != preview_id:
|
||||
ref.preview_id = preview_id
|
||||
|
||||
norm = normalize_tags(list(tags))
|
||||
if norm:
|
||||
if require_existing_tags:
|
||||
validate_tags_exist(session, norm)
|
||||
add_tags_to_reference(
|
||||
session,
|
||||
reference_id=reference_id,
|
||||
tags=norm,
|
||||
origin=tag_origin,
|
||||
create_if_missing=not require_existing_tags,
|
||||
)
|
||||
|
||||
_update_metadata_with_filename(
|
||||
session,
|
||||
reference_id=reference_id,
|
||||
file_path=ref.file_path,
|
||||
current_metadata=ref.user_metadata,
|
||||
user_metadata=user_metadata,
|
||||
)
|
||||
|
||||
try:
|
||||
remove_missing_tag_for_asset_id(session, asset_id=asset.id)
|
||||
except Exception:
|
||||
logging.exception("Failed to clear 'missing' tag for asset %s", asset.id)
|
||||
|
||||
session.commit()
|
||||
|
||||
return IngestResult(
|
||||
asset_created=asset_created,
|
||||
asset_updated=asset_updated,
|
||||
ref_created=ref_created,
|
||||
ref_updated=ref_updated,
|
||||
reference_id=reference_id,
|
||||
)
|
||||
|
||||
|
||||
def _register_existing_asset(
|
||||
asset_hash: str,
|
||||
name: str,
|
||||
user_metadata: UserMetadata = None,
|
||||
tags: list[str] | None = None,
|
||||
tag_origin: str = "manual",
|
||||
owner_id: str = "",
|
||||
) -> RegisterAssetResult:
|
||||
user_metadata = user_metadata or {}
|
||||
|
||||
with create_session() as session:
|
||||
asset = get_asset_by_hash(session, asset_hash=asset_hash)
|
||||
if not asset:
|
||||
raise ValueError(f"No asset with hash {asset_hash}")
|
||||
|
||||
ref, ref_created = get_or_create_reference(
|
||||
session,
|
||||
asset_id=asset.id,
|
||||
owner_id=owner_id,
|
||||
name=name,
|
||||
)
|
||||
|
||||
if not ref_created:
|
||||
tag_names = get_reference_tags(session, reference_id=ref.id)
|
||||
result = RegisterAssetResult(
|
||||
ref=extract_reference_data(ref),
|
||||
asset=extract_asset_data(asset),
|
||||
tags=tag_names,
|
||||
created=False,
|
||||
)
|
||||
session.commit()
|
||||
return result
|
||||
|
||||
new_meta = dict(user_metadata)
|
||||
computed_filename = compute_relative_filename(ref.file_path) if ref.file_path else None
|
||||
if computed_filename:
|
||||
new_meta["filename"] = computed_filename
|
||||
|
||||
if new_meta:
|
||||
set_reference_metadata(
|
||||
session,
|
||||
reference_id=ref.id,
|
||||
user_metadata=new_meta,
|
||||
)
|
||||
|
||||
if tags is not None:
|
||||
set_reference_tags(
|
||||
session,
|
||||
reference_id=ref.id,
|
||||
tags=tags,
|
||||
origin=tag_origin,
|
||||
)
|
||||
|
||||
tag_names = get_reference_tags(session, reference_id=ref.id)
|
||||
session.refresh(ref)
|
||||
result = RegisterAssetResult(
|
||||
ref=extract_reference_data(ref),
|
||||
asset=extract_asset_data(asset),
|
||||
tags=tag_names,
|
||||
created=True,
|
||||
)
|
||||
session.commit()
|
||||
|
||||
return result
|
||||
|
||||
|
||||
|
||||
def _update_metadata_with_filename(
|
||||
session: Session,
|
||||
reference_id: str,
|
||||
file_path: str | None,
|
||||
current_metadata: dict | None,
|
||||
user_metadata: dict[str, Any],
|
||||
) -> None:
|
||||
computed_filename = compute_relative_filename(file_path) if file_path else None
|
||||
|
||||
current_meta = current_metadata or {}
|
||||
new_meta = dict(current_meta)
|
||||
for k, v in user_metadata.items():
|
||||
new_meta[k] = v
|
||||
if computed_filename:
|
||||
new_meta["filename"] = computed_filename
|
||||
|
||||
if new_meta != current_meta:
|
||||
set_reference_metadata(
|
||||
session,
|
||||
reference_id=reference_id,
|
||||
user_metadata=new_meta,
|
||||
)
|
||||
|
||||
|
||||
def _sanitize_filename(name: str | None, fallback: str) -> str:
|
||||
n = os.path.basename((name or "").strip() or fallback)
|
||||
return n if n else fallback
|
||||
|
||||
|
||||
class HashMismatchError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class DependencyMissingError(Exception):
|
||||
def __init__(self, message: str):
|
||||
self.message = message
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
def upload_from_temp_path(
|
||||
temp_path: str,
|
||||
name: str | None = None,
|
||||
tags: list[str] | None = None,
|
||||
user_metadata: dict | None = None,
|
||||
client_filename: str | None = None,
|
||||
owner_id: str = "",
|
||||
expected_hash: str | None = None,
|
||||
) -> UploadResult:
|
||||
try:
|
||||
digest, _ = hashing.compute_blake3_hash(temp_path)
|
||||
except ImportError as e:
|
||||
raise DependencyMissingError(str(e))
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"failed to hash uploaded file: {e}")
|
||||
asset_hash = "blake3:" + digest
|
||||
|
||||
if expected_hash and asset_hash != expected_hash.strip().lower():
|
||||
raise HashMismatchError("Uploaded file hash does not match provided hash.")
|
||||
|
||||
with create_session() as session:
|
||||
existing = get_asset_by_hash(session, asset_hash=asset_hash)
|
||||
|
||||
if existing is not None:
|
||||
with contextlib.suppress(Exception):
|
||||
if temp_path and os.path.exists(temp_path):
|
||||
os.remove(temp_path)
|
||||
|
||||
display_name = _sanitize_filename(name or client_filename, fallback=digest)
|
||||
result = _register_existing_asset(
|
||||
asset_hash=asset_hash,
|
||||
name=display_name,
|
||||
user_metadata=user_metadata or {},
|
||||
tags=tags or [],
|
||||
tag_origin="manual",
|
||||
owner_id=owner_id,
|
||||
)
|
||||
return UploadResult(
|
||||
ref=result.ref,
|
||||
asset=result.asset,
|
||||
tags=result.tags,
|
||||
created_new=False,
|
||||
)
|
||||
|
||||
if not tags:
|
||||
raise ValueError("tags are required for new asset uploads")
|
||||
base_dir, subdirs = resolve_destination_from_tags(tags)
|
||||
dest_dir = os.path.join(base_dir, *subdirs) if subdirs else base_dir
|
||||
os.makedirs(dest_dir, exist_ok=True)
|
||||
|
||||
src_for_ext = (client_filename or name or "").strip()
|
||||
_ext = os.path.splitext(os.path.basename(src_for_ext))[1] if src_for_ext else ""
|
||||
ext = _ext if 0 < len(_ext) <= 16 else ""
|
||||
hashed_basename = f"{digest}{ext}"
|
||||
dest_abs = os.path.abspath(os.path.join(dest_dir, hashed_basename))
|
||||
validate_path_within_base(dest_abs, base_dir)
|
||||
|
||||
content_type = (
|
||||
mimetypes.guess_type(os.path.basename(src_for_ext), strict=False)[0]
|
||||
or mimetypes.guess_type(hashed_basename, strict=False)[0]
|
||||
or "application/octet-stream"
|
||||
)
|
||||
|
||||
try:
|
||||
os.replace(temp_path, dest_abs)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"failed to move uploaded file into place: {e}")
|
||||
|
||||
try:
|
||||
size_bytes, mtime_ns = get_size_and_mtime_ns(dest_abs)
|
||||
except OSError as e:
|
||||
raise RuntimeError(f"failed to stat destination file: {e}")
|
||||
|
||||
ingest_result = _ingest_file_from_path(
|
||||
asset_hash=asset_hash,
|
||||
abs_path=dest_abs,
|
||||
size_bytes=size_bytes,
|
||||
mtime_ns=mtime_ns,
|
||||
mime_type=content_type,
|
||||
info_name=_sanitize_filename(name or client_filename, fallback=digest),
|
||||
owner_id=owner_id,
|
||||
preview_id=None,
|
||||
user_metadata=user_metadata or {},
|
||||
tags=tags,
|
||||
tag_origin="manual",
|
||||
require_existing_tags=False,
|
||||
)
|
||||
reference_id = ingest_result.reference_id
|
||||
if not reference_id:
|
||||
raise RuntimeError("failed to create asset reference")
|
||||
|
||||
with create_session() as session:
|
||||
pair = fetch_reference_and_asset(
|
||||
session, reference_id=reference_id, owner_id=owner_id
|
||||
)
|
||||
if not pair:
|
||||
raise RuntimeError("inconsistent DB state after ingest")
|
||||
ref, asset = pair
|
||||
tag_names = get_reference_tags(session, reference_id=ref.id)
|
||||
|
||||
return UploadResult(
|
||||
ref=extract_reference_data(ref),
|
||||
asset=extract_asset_data(asset),
|
||||
tags=tag_names,
|
||||
created_new=ingest_result.asset_created,
|
||||
)
|
||||
|
||||
|
||||
def create_from_hash(
|
||||
hash_str: str,
|
||||
name: str,
|
||||
tags: list[str] | None = None,
|
||||
user_metadata: dict | None = None,
|
||||
owner_id: str = "",
|
||||
) -> UploadResult | None:
|
||||
canonical = hash_str.strip().lower()
|
||||
|
||||
with create_session() as session:
|
||||
asset = get_asset_by_hash(session, asset_hash=canonical)
|
||||
if not asset:
|
||||
return None
|
||||
|
||||
result = _register_existing_asset(
|
||||
asset_hash=canonical,
|
||||
name=_sanitize_filename(
|
||||
name, fallback=canonical.split(":", 1)[1] if ":" in canonical else canonical
|
||||
),
|
||||
user_metadata=user_metadata or {},
|
||||
tags=tags or [],
|
||||
tag_origin="manual",
|
||||
owner_id=owner_id,
|
||||
)
|
||||
|
||||
return UploadResult(
|
||||
ref=result.ref,
|
||||
asset=result.asset,
|
||||
tags=result.tags,
|
||||
created_new=False,
|
||||
)
|
||||
327
app/assets/services/metadata_extract.py
Normal file
327
app/assets/services/metadata_extract.py
Normal file
@@ -0,0 +1,327 @@
|
||||
"""Metadata extraction for asset scanning.
|
||||
|
||||
Tier 1: Filesystem metadata (zero parsing)
|
||||
Tier 2: Safetensors header metadata (fast JSON read only)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import mimetypes
|
||||
import os
|
||||
import struct
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from utils.mime_types import init_mime_types
|
||||
|
||||
init_mime_types()
|
||||
|
||||
# Supported safetensors extensions
|
||||
SAFETENSORS_EXTENSIONS = frozenset({".safetensors", ".sft"})
|
||||
|
||||
# Maximum safetensors header size to read (8MB)
|
||||
MAX_SAFETENSORS_HEADER_SIZE = 8 * 1024 * 1024
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExtractedMetadata:
|
||||
"""Metadata extracted from a file during scanning."""
|
||||
|
||||
# Tier 1: Filesystem (always available)
|
||||
filename: str = ""
|
||||
file_path: str = "" # Full absolute path to the file
|
||||
content_length: int = 0
|
||||
content_type: str | None = None
|
||||
format: str = "" # file extension without dot
|
||||
|
||||
# Tier 2: Safetensors header (if available)
|
||||
base_model: str | None = None
|
||||
trained_words: list[str] | None = None
|
||||
air: str | None = None # CivitAI AIR identifier
|
||||
has_preview_images: bool = False
|
||||
|
||||
# Source provenance (populated if embedded in safetensors)
|
||||
source_url: str | None = None
|
||||
source_arn: str | None = None
|
||||
repo_url: str | None = None
|
||||
preview_url: str | None = None
|
||||
source_hash: str | None = None
|
||||
|
||||
# HuggingFace specific
|
||||
repo_id: str | None = None
|
||||
revision: str | None = None
|
||||
filepath: str | None = None
|
||||
resolve_url: str | None = None
|
||||
|
||||
def to_user_metadata(self) -> dict[str, Any]:
|
||||
"""Convert to user_metadata dict for AssetReference.user_metadata JSON field."""
|
||||
data: dict[str, Any] = {
|
||||
"filename": self.filename,
|
||||
"content_length": self.content_length,
|
||||
"format": self.format,
|
||||
}
|
||||
if self.file_path:
|
||||
data["file_path"] = self.file_path
|
||||
if self.content_type:
|
||||
data["content_type"] = self.content_type
|
||||
|
||||
# Tier 2 fields
|
||||
if self.base_model:
|
||||
data["base_model"] = self.base_model
|
||||
if self.trained_words:
|
||||
data["trained_words"] = self.trained_words
|
||||
if self.air:
|
||||
data["air"] = self.air
|
||||
if self.has_preview_images:
|
||||
data["has_preview_images"] = True
|
||||
|
||||
# Source provenance
|
||||
if self.source_url:
|
||||
data["source_url"] = self.source_url
|
||||
if self.source_arn:
|
||||
data["source_arn"] = self.source_arn
|
||||
if self.repo_url:
|
||||
data["repo_url"] = self.repo_url
|
||||
if self.preview_url:
|
||||
data["preview_url"] = self.preview_url
|
||||
if self.source_hash:
|
||||
data["source_hash"] = self.source_hash
|
||||
|
||||
# HuggingFace
|
||||
if self.repo_id:
|
||||
data["repo_id"] = self.repo_id
|
||||
if self.revision:
|
||||
data["revision"] = self.revision
|
||||
if self.filepath:
|
||||
data["filepath"] = self.filepath
|
||||
if self.resolve_url:
|
||||
data["resolve_url"] = self.resolve_url
|
||||
|
||||
return data
|
||||
|
||||
def to_meta_rows(self, reference_id: str) -> list[dict]:
|
||||
"""Convert to asset_reference_meta rows for typed/indexed querying."""
|
||||
rows: list[dict] = []
|
||||
|
||||
def add_str(key: str, val: str | None, ordinal: int = 0) -> None:
|
||||
if val:
|
||||
rows.append({
|
||||
"asset_reference_id": reference_id,
|
||||
"key": key,
|
||||
"ordinal": ordinal,
|
||||
"val_str": val[:2048] if len(val) > 2048 else val,
|
||||
"val_num": None,
|
||||
"val_bool": None,
|
||||
"val_json": None,
|
||||
})
|
||||
|
||||
def add_num(key: str, val: int | float | None) -> None:
|
||||
if val is not None:
|
||||
rows.append({
|
||||
"asset_reference_id": reference_id,
|
||||
"key": key,
|
||||
"ordinal": 0,
|
||||
"val_str": None,
|
||||
"val_num": val,
|
||||
"val_bool": None,
|
||||
"val_json": None,
|
||||
})
|
||||
|
||||
def add_bool(key: str, val: bool | None) -> None:
|
||||
if val is not None:
|
||||
rows.append({
|
||||
"asset_reference_id": reference_id,
|
||||
"key": key,
|
||||
"ordinal": 0,
|
||||
"val_str": None,
|
||||
"val_num": None,
|
||||
"val_bool": val,
|
||||
"val_json": None,
|
||||
})
|
||||
|
||||
# Tier 1
|
||||
add_str("filename", self.filename)
|
||||
add_num("content_length", self.content_length)
|
||||
add_str("content_type", self.content_type)
|
||||
add_str("format", self.format)
|
||||
|
||||
# Tier 2
|
||||
add_str("base_model", self.base_model)
|
||||
add_str("air", self.air)
|
||||
has_previews = self.has_preview_images if self.has_preview_images else None
|
||||
add_bool("has_preview_images", has_previews)
|
||||
|
||||
# trained_words as multiple rows with ordinals
|
||||
if self.trained_words:
|
||||
for i, word in enumerate(self.trained_words[:100]): # limit to 100 words
|
||||
add_str("trained_words", word, ordinal=i)
|
||||
|
||||
# Source provenance
|
||||
add_str("source_url", self.source_url)
|
||||
add_str("source_arn", self.source_arn)
|
||||
add_str("repo_url", self.repo_url)
|
||||
add_str("preview_url", self.preview_url)
|
||||
add_str("source_hash", self.source_hash)
|
||||
|
||||
# HuggingFace
|
||||
add_str("repo_id", self.repo_id)
|
||||
add_str("revision", self.revision)
|
||||
add_str("filepath", self.filepath)
|
||||
add_str("resolve_url", self.resolve_url)
|
||||
|
||||
return rows
|
||||
|
||||
|
||||
def _read_safetensors_header(
|
||||
path: str, max_size: int = MAX_SAFETENSORS_HEADER_SIZE
|
||||
) -> dict[str, Any] | None:
|
||||
"""Read only the JSON header from a safetensors file.
|
||||
|
||||
This is very fast - reads 8 bytes for header length, then the JSON header.
|
||||
No tensor data is loaded.
|
||||
|
||||
Args:
|
||||
path: Absolute path to safetensors file
|
||||
max_size: Maximum header size to read (default 8MB)
|
||||
|
||||
Returns:
|
||||
Parsed header dict or None if failed
|
||||
"""
|
||||
try:
|
||||
with open(path, "rb") as f:
|
||||
header_bytes = f.read(8)
|
||||
if len(header_bytes) < 8:
|
||||
return None
|
||||
length_of_header = struct.unpack("<Q", header_bytes)[0]
|
||||
if length_of_header > max_size:
|
||||
return None
|
||||
header_data = f.read(length_of_header)
|
||||
if len(header_data) < length_of_header:
|
||||
return None
|
||||
return json.loads(header_data.decode("utf-8"))
|
||||
except (OSError, json.JSONDecodeError, UnicodeDecodeError, struct.error):
|
||||
return None
|
||||
|
||||
|
||||
def _extract_safetensors_metadata(
|
||||
header: dict[str, Any], meta: ExtractedMetadata
|
||||
) -> None:
|
||||
"""Extract metadata from safetensors header __metadata__ section.
|
||||
|
||||
Modifies meta in-place.
|
||||
"""
|
||||
st_meta = header.get("__metadata__", {})
|
||||
if not isinstance(st_meta, dict):
|
||||
return
|
||||
|
||||
# Common model metadata
|
||||
meta.base_model = (
|
||||
st_meta.get("ss_base_model_version")
|
||||
or st_meta.get("modelspec.base_model")
|
||||
or st_meta.get("base_model")
|
||||
)
|
||||
|
||||
# Trained words / trigger words
|
||||
trained_words = st_meta.get("ss_tag_frequency")
|
||||
if trained_words and isinstance(trained_words, str):
|
||||
try:
|
||||
tag_freq = json.loads(trained_words)
|
||||
# Extract unique tags from all datasets
|
||||
all_tags: set[str] = set()
|
||||
for dataset_tags in tag_freq.values():
|
||||
if isinstance(dataset_tags, dict):
|
||||
all_tags.update(dataset_tags.keys())
|
||||
if all_tags:
|
||||
meta.trained_words = sorted(all_tags)[:100]
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# Direct trained_words field (some formats)
|
||||
if not meta.trained_words:
|
||||
tw = st_meta.get("trained_words")
|
||||
if isinstance(tw, str):
|
||||
try:
|
||||
parsed = json.loads(tw)
|
||||
if isinstance(parsed, list):
|
||||
meta.trained_words = [str(x) for x in parsed]
|
||||
else:
|
||||
meta.trained_words = [w.strip() for w in tw.split(",") if w.strip()]
|
||||
except json.JSONDecodeError:
|
||||
meta.trained_words = [w.strip() for w in tw.split(",") if w.strip()]
|
||||
elif isinstance(tw, list):
|
||||
meta.trained_words = [str(x) for x in tw]
|
||||
|
||||
# CivitAI AIR
|
||||
meta.air = st_meta.get("air") or st_meta.get("modelspec.air")
|
||||
|
||||
# Preview images (ssmd_cover_images)
|
||||
cover_images = st_meta.get("ssmd_cover_images")
|
||||
if cover_images:
|
||||
meta.has_preview_images = True
|
||||
|
||||
# Source provenance fields
|
||||
meta.source_url = st_meta.get("source_url")
|
||||
meta.source_arn = st_meta.get("source_arn")
|
||||
meta.repo_url = st_meta.get("repo_url")
|
||||
meta.preview_url = st_meta.get("preview_url")
|
||||
meta.source_hash = st_meta.get("source_hash") or st_meta.get("sshs_model_hash")
|
||||
|
||||
# HuggingFace fields
|
||||
meta.repo_id = st_meta.get("repo_id") or st_meta.get("hf_repo_id")
|
||||
meta.revision = st_meta.get("revision") or st_meta.get("hf_revision")
|
||||
meta.filepath = st_meta.get("filepath") or st_meta.get("hf_filepath")
|
||||
meta.resolve_url = st_meta.get("resolve_url") or st_meta.get("hf_url")
|
||||
|
||||
|
||||
def extract_file_metadata(
|
||||
abs_path: str,
|
||||
stat_result: os.stat_result | None = None,
|
||||
relative_filename: str | None = None,
|
||||
) -> ExtractedMetadata:
|
||||
"""Extract metadata from a file using tier 1 and tier 2 methods.
|
||||
|
||||
Tier 1: Filesystem metadata from path and stat
|
||||
Tier 2: Safetensors header parsing if applicable
|
||||
|
||||
Args:
|
||||
abs_path: Absolute path to the file
|
||||
stat_result: Optional pre-fetched stat result (saves a syscall)
|
||||
relative_filename: Optional relative filename to use instead of basename
|
||||
(e.g., "flux/123/model.safetensors" for model paths)
|
||||
|
||||
Returns:
|
||||
ExtractedMetadata with all available fields populated
|
||||
"""
|
||||
meta = ExtractedMetadata()
|
||||
|
||||
# Tier 1: Filesystem metadata
|
||||
meta.filename = relative_filename or os.path.basename(abs_path)
|
||||
meta.file_path = abs_path
|
||||
_, ext = os.path.splitext(abs_path)
|
||||
meta.format = ext.lstrip(".").lower() if ext else ""
|
||||
|
||||
mime_type, _ = mimetypes.guess_type(abs_path)
|
||||
meta.content_type = mime_type
|
||||
|
||||
# Size from stat
|
||||
if stat_result is None:
|
||||
try:
|
||||
stat_result = os.stat(abs_path, follow_symlinks=True)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
if stat_result:
|
||||
meta.content_length = stat_result.st_size
|
||||
|
||||
# Tier 2: Safetensors header (if applicable and enabled)
|
||||
if ext.lower() in SAFETENSORS_EXTENSIONS:
|
||||
header = _read_safetensors_header(abs_path)
|
||||
if header:
|
||||
try:
|
||||
_extract_safetensors_metadata(header, meta)
|
||||
except Exception as e:
|
||||
logging.debug("Safetensors meta extract failed %s: %s", abs_path, e)
|
||||
|
||||
return meta
|
||||
167
app/assets/services/path_utils.py
Normal file
167
app/assets/services/path_utils.py
Normal file
@@ -0,0 +1,167 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
import folder_paths
|
||||
from app.assets.helpers import normalize_tags
|
||||
|
||||
|
||||
_NON_MODEL_FOLDER_NAMES = frozenset({"custom_nodes"})
|
||||
|
||||
|
||||
def get_comfy_models_folders() -> list[tuple[str, list[str]]]:
|
||||
"""Build list of (folder_name, base_paths[]) for all model locations.
|
||||
|
||||
Includes every category registered in folder_names_and_paths,
|
||||
regardless of whether its paths are under the main models_dir,
|
||||
but excludes non-model entries like custom_nodes.
|
||||
"""
|
||||
targets: list[tuple[str, list[str]]] = []
|
||||
for name, values in folder_paths.folder_names_and_paths.items():
|
||||
if name in _NON_MODEL_FOLDER_NAMES:
|
||||
continue
|
||||
paths, _exts = values[0], values[1]
|
||||
if paths:
|
||||
targets.append((name, paths))
|
||||
return targets
|
||||
|
||||
|
||||
def resolve_destination_from_tags(tags: list[str]) -> tuple[str, list[str]]:
|
||||
"""Validates and maps tags -> (base_dir, subdirs_for_fs)"""
|
||||
if not tags:
|
||||
raise ValueError("tags must not be empty")
|
||||
root = tags[0].lower()
|
||||
if root == "models":
|
||||
if len(tags) < 2:
|
||||
raise ValueError("at least two tags required for model asset")
|
||||
try:
|
||||
bases = folder_paths.folder_names_and_paths[tags[1]][0]
|
||||
except KeyError:
|
||||
raise ValueError(f"unknown model category '{tags[1]}'")
|
||||
if not bases:
|
||||
raise ValueError(f"no base path configured for category '{tags[1]}'")
|
||||
base_dir = os.path.abspath(bases[0])
|
||||
raw_subdirs = tags[2:]
|
||||
elif root == "input":
|
||||
base_dir = os.path.abspath(folder_paths.get_input_directory())
|
||||
raw_subdirs = tags[1:]
|
||||
elif root == "output":
|
||||
base_dir = os.path.abspath(folder_paths.get_output_directory())
|
||||
raw_subdirs = tags[1:]
|
||||
else:
|
||||
raise ValueError(f"unknown root tag '{tags[0]}'; expected 'models', 'input', or 'output'")
|
||||
_sep_chars = frozenset(("/", "\\", os.sep))
|
||||
for i in raw_subdirs:
|
||||
if i in (".", "..") or _sep_chars & set(i):
|
||||
raise ValueError("invalid path component in tags")
|
||||
|
||||
return base_dir, raw_subdirs if raw_subdirs else []
|
||||
|
||||
|
||||
def validate_path_within_base(candidate: str, base: str) -> None:
|
||||
cand_abs = Path(os.path.abspath(candidate))
|
||||
base_abs = Path(os.path.abspath(base))
|
||||
if not cand_abs.is_relative_to(base_abs):
|
||||
raise ValueError("destination escapes base directory")
|
||||
|
||||
|
||||
def compute_relative_filename(file_path: str) -> str | None:
|
||||
"""
|
||||
Return the model's path relative to the last well-known folder (the model category),
|
||||
using forward slashes, eg:
|
||||
/.../models/checkpoints/flux/123/flux.safetensors -> "flux/123/flux.safetensors"
|
||||
/.../models/text_encoders/clip_g.safetensors -> "clip_g.safetensors"
|
||||
|
||||
For non-model paths, returns None.
|
||||
"""
|
||||
try:
|
||||
root_category, rel_path = get_asset_category_and_relative_path(file_path)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
p = Path(rel_path)
|
||||
parts = [seg for seg in p.parts if seg not in (".", "..", p.anchor)]
|
||||
if not parts:
|
||||
return None
|
||||
|
||||
if root_category == "models":
|
||||
# parts[0] is the category ("checkpoints", "vae", etc) – drop it
|
||||
inside = parts[1:] if len(parts) > 1 else [parts[0]]
|
||||
return "/".join(inside)
|
||||
return "/".join(parts) # input/output: keep all parts
|
||||
|
||||
|
||||
def get_asset_category_and_relative_path(
|
||||
file_path: str,
|
||||
) -> tuple[Literal["input", "output", "models"], str]:
|
||||
"""Determine which root category a file path belongs to.
|
||||
|
||||
Categories:
|
||||
- 'input': under folder_paths.get_input_directory()
|
||||
- 'output': under folder_paths.get_output_directory()
|
||||
- 'models': under any base path from get_comfy_models_folders()
|
||||
|
||||
Returns:
|
||||
(root_category, relative_path_inside_that_root)
|
||||
|
||||
Raises:
|
||||
ValueError: path does not belong to any known root.
|
||||
"""
|
||||
fp_abs = os.path.abspath(file_path)
|
||||
|
||||
def _check_is_within(child: str, parent: str) -> bool:
|
||||
return Path(child).is_relative_to(parent)
|
||||
|
||||
def _compute_relative(child: str, parent: str) -> str:
|
||||
# Normalize relative path, stripping any leading ".." components
|
||||
# by anchoring to root (os.sep) then computing relpath back from it.
|
||||
return os.path.relpath(
|
||||
os.path.join(os.sep, os.path.relpath(child, parent)), os.sep
|
||||
)
|
||||
|
||||
# 1) input
|
||||
input_base = os.path.abspath(folder_paths.get_input_directory())
|
||||
if _check_is_within(fp_abs, input_base):
|
||||
return "input", _compute_relative(fp_abs, input_base)
|
||||
|
||||
# 2) output
|
||||
output_base = os.path.abspath(folder_paths.get_output_directory())
|
||||
if _check_is_within(fp_abs, output_base):
|
||||
return "output", _compute_relative(fp_abs, output_base)
|
||||
|
||||
# 3) models (check deepest matching base to avoid ambiguity)
|
||||
best: tuple[int, str, str] | None = None # (base_len, bucket, rel_inside_bucket)
|
||||
for bucket, bases in get_comfy_models_folders():
|
||||
for b in bases:
|
||||
base_abs = os.path.abspath(b)
|
||||
if not _check_is_within(fp_abs, base_abs):
|
||||
continue
|
||||
cand = (len(base_abs), bucket, _compute_relative(fp_abs, base_abs))
|
||||
if best is None or cand[0] > best[0]:
|
||||
best = cand
|
||||
|
||||
if best is not None:
|
||||
_, bucket, rel_inside = best
|
||||
combined = os.path.join(bucket, rel_inside)
|
||||
return "models", os.path.relpath(os.path.join(os.sep, combined), os.sep)
|
||||
|
||||
raise ValueError(
|
||||
f"Path is not within input, output, or configured model bases: {file_path}"
|
||||
)
|
||||
|
||||
|
||||
def get_name_and_tags_from_asset_path(file_path: str) -> tuple[str, list[str]]:
|
||||
"""Return (name, tags) derived from a filesystem path.
|
||||
|
||||
- name: base filename with extension
|
||||
- tags: [root_category] + parent folder names in order
|
||||
|
||||
Raises:
|
||||
ValueError: path does not belong to any known root.
|
||||
"""
|
||||
root_category, some_path = get_asset_category_and_relative_path(file_path)
|
||||
p = Path(some_path)
|
||||
parent_parts = [
|
||||
part for part in p.parent.parts if part not in (".", "..", p.anchor)
|
||||
]
|
||||
return p.name, list(dict.fromkeys(normalize_tags([root_category, *parent_parts])))
|
||||
109
app/assets/services/schemas.py
Normal file
109
app/assets/services/schemas.py
Normal file
@@ -0,0 +1,109 @@
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any, NamedTuple
|
||||
|
||||
from app.assets.database.models import Asset, AssetReference
|
||||
|
||||
UserMetadata = dict[str, Any] | None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AssetData:
|
||||
hash: str | None
|
||||
size_bytes: int | None
|
||||
mime_type: str | None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ReferenceData:
|
||||
"""Data transfer object for AssetReference."""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
file_path: str | None
|
||||
user_metadata: UserMetadata
|
||||
preview_id: str | None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
last_access_time: datetime | None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AssetDetailResult:
|
||||
ref: ReferenceData
|
||||
asset: AssetData | None
|
||||
tags: list[str]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RegisterAssetResult:
|
||||
ref: ReferenceData
|
||||
asset: AssetData
|
||||
tags: list[str]
|
||||
created: bool
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class IngestResult:
|
||||
asset_created: bool
|
||||
asset_updated: bool
|
||||
ref_created: bool
|
||||
ref_updated: bool
|
||||
reference_id: str | None
|
||||
|
||||
|
||||
class TagUsage(NamedTuple):
|
||||
name: str
|
||||
tag_type: str
|
||||
count: int
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AssetSummaryData:
|
||||
ref: ReferenceData
|
||||
asset: AssetData | None
|
||||
tags: list[str]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ListAssetsResult:
|
||||
items: list[AssetSummaryData]
|
||||
total: int
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DownloadResolutionResult:
|
||||
abs_path: str
|
||||
content_type: str
|
||||
download_name: str
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class UploadResult:
|
||||
ref: ReferenceData
|
||||
asset: AssetData
|
||||
tags: list[str]
|
||||
created_new: bool
|
||||
|
||||
|
||||
def extract_reference_data(ref: AssetReference) -> ReferenceData:
|
||||
return ReferenceData(
|
||||
id=ref.id,
|
||||
name=ref.name,
|
||||
file_path=ref.file_path,
|
||||
user_metadata=ref.user_metadata,
|
||||
preview_id=ref.preview_id,
|
||||
created_at=ref.created_at,
|
||||
updated_at=ref.updated_at,
|
||||
last_access_time=ref.last_access_time,
|
||||
)
|
||||
|
||||
|
||||
def extract_asset_data(asset: Asset | None) -> AssetData | None:
|
||||
if asset is None:
|
||||
return None
|
||||
return AssetData(
|
||||
hash=asset.hash,
|
||||
size_bytes=asset.size_bytes,
|
||||
mime_type=asset.mime_type,
|
||||
)
|
||||
75
app/assets/services/tagging.py
Normal file
75
app/assets/services/tagging.py
Normal file
@@ -0,0 +1,75 @@
|
||||
from app.assets.database.queries import (
|
||||
AddTagsResult,
|
||||
RemoveTagsResult,
|
||||
add_tags_to_reference,
|
||||
get_reference_with_owner_check,
|
||||
list_tags_with_usage,
|
||||
remove_tags_from_reference,
|
||||
)
|
||||
from app.assets.services.schemas import TagUsage
|
||||
from app.database.db import create_session
|
||||
|
||||
|
||||
def apply_tags(
|
||||
reference_id: str,
|
||||
tags: list[str],
|
||||
origin: str = "manual",
|
||||
owner_id: str = "",
|
||||
) -> AddTagsResult:
|
||||
with create_session() as session:
|
||||
ref_row = get_reference_with_owner_check(session, reference_id, owner_id)
|
||||
|
||||
result = add_tags_to_reference(
|
||||
session,
|
||||
reference_id=reference_id,
|
||||
tags=tags,
|
||||
origin=origin,
|
||||
create_if_missing=True,
|
||||
reference_row=ref_row,
|
||||
)
|
||||
session.commit()
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def remove_tags(
|
||||
reference_id: str,
|
||||
tags: list[str],
|
||||
owner_id: str = "",
|
||||
) -> RemoveTagsResult:
|
||||
with create_session() as session:
|
||||
get_reference_with_owner_check(session, reference_id, owner_id)
|
||||
|
||||
result = remove_tags_from_reference(
|
||||
session,
|
||||
reference_id=reference_id,
|
||||
tags=tags,
|
||||
)
|
||||
session.commit()
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def list_tags(
|
||||
prefix: str | None = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
order: str = "count_desc",
|
||||
include_zero: bool = True,
|
||||
owner_id: str = "",
|
||||
) -> tuple[list[TagUsage], int]:
|
||||
limit = max(1, min(1000, limit))
|
||||
offset = max(0, offset)
|
||||
|
||||
with create_session() as session:
|
||||
rows, total = list_tags_with_usage(
|
||||
session,
|
||||
prefix=prefix,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
include_zero=include_zero,
|
||||
order=order,
|
||||
owner_id=owner_id,
|
||||
)
|
||||
|
||||
return [TagUsage(name, tag_type, count) for name, tag_type, count in rows], total
|
||||
@@ -3,6 +3,7 @@ import os
|
||||
import shutil
|
||||
from app.logger import log_startup_warning
|
||||
from utils.install_util import get_missing_requirements_message
|
||||
from filelock import FileLock, Timeout
|
||||
from comfy.cli_args import args
|
||||
|
||||
_DB_AVAILABLE = False
|
||||
@@ -14,8 +15,12 @@ try:
|
||||
from alembic.config import Config
|
||||
from alembic.runtime.migration import MigrationContext
|
||||
from alembic.script import ScriptDirectory
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy import create_engine, event
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.pool import StaticPool
|
||||
|
||||
from app.database.models import Base
|
||||
import app.assets.database.models # noqa: F401 — register models with Base.metadata
|
||||
|
||||
_DB_AVAILABLE = True
|
||||
except ImportError as e:
|
||||
@@ -65,9 +70,69 @@ def get_db_path():
|
||||
raise ValueError(f"Unsupported database URL '{url}'.")
|
||||
|
||||
|
||||
_db_lock = None
|
||||
|
||||
def _acquire_file_lock(db_path):
|
||||
"""Acquire an OS-level file lock to prevent multi-process access.
|
||||
|
||||
Uses filelock for cross-platform support (macOS, Linux, Windows).
|
||||
The OS automatically releases the lock when the process exits, even on crashes.
|
||||
"""
|
||||
global _db_lock
|
||||
lock_path = db_path + ".lock"
|
||||
_db_lock = FileLock(lock_path)
|
||||
try:
|
||||
_db_lock.acquire(timeout=0)
|
||||
except Timeout:
|
||||
raise RuntimeError(
|
||||
f"Could not acquire lock on database '{db_path}'. "
|
||||
"Another ComfyUI process may already be using it. "
|
||||
"Use --database-url to specify a separate database file."
|
||||
)
|
||||
|
||||
|
||||
def _is_memory_db(db_url):
|
||||
"""Check if the database URL refers to an in-memory SQLite database."""
|
||||
return db_url in ("sqlite:///:memory:", "sqlite://")
|
||||
|
||||
|
||||
def init_db():
|
||||
db_url = args.database_url
|
||||
logging.debug(f"Database URL: {db_url}")
|
||||
|
||||
if _is_memory_db(db_url):
|
||||
_init_memory_db(db_url)
|
||||
else:
|
||||
_init_file_db(db_url)
|
||||
|
||||
|
||||
def _init_memory_db(db_url):
|
||||
"""Initialize an in-memory SQLite database using metadata.create_all.
|
||||
|
||||
Alembic migrations don't work with in-memory SQLite because each
|
||||
connection gets its own separate database — tables created by Alembic's
|
||||
internal connection are lost immediately.
|
||||
"""
|
||||
engine = create_engine(
|
||||
db_url,
|
||||
poolclass=StaticPool,
|
||||
connect_args={"check_same_thread": False},
|
||||
)
|
||||
|
||||
@event.listens_for(engine, "connect")
|
||||
def set_sqlite_pragma(dbapi_connection, connection_record):
|
||||
cursor = dbapi_connection.cursor()
|
||||
cursor.execute("PRAGMA foreign_keys=ON")
|
||||
cursor.close()
|
||||
|
||||
Base.metadata.create_all(engine)
|
||||
|
||||
global Session
|
||||
Session = sessionmaker(bind=engine)
|
||||
|
||||
|
||||
def _init_file_db(db_url):
|
||||
"""Initialize a file-backed SQLite database using Alembic migrations."""
|
||||
db_path = get_db_path()
|
||||
db_exists = os.path.exists(db_path)
|
||||
|
||||
@@ -75,6 +140,14 @@ def init_db():
|
||||
|
||||
# Check if we need to upgrade
|
||||
engine = create_engine(db_url)
|
||||
|
||||
# Enable foreign key enforcement for SQLite
|
||||
@event.listens_for(engine, "connect")
|
||||
def set_sqlite_pragma(dbapi_connection, connection_record):
|
||||
cursor = dbapi_connection.cursor()
|
||||
cursor.execute("PRAGMA foreign_keys=ON")
|
||||
cursor.close()
|
||||
|
||||
conn = engine.connect()
|
||||
|
||||
context = MigrationContext.configure(conn)
|
||||
@@ -104,6 +177,12 @@ def init_db():
|
||||
logging.exception("Error upgrading database: ")
|
||||
raise e
|
||||
|
||||
# Acquire an OS-level file lock after migrations are complete.
|
||||
# Alembic uses its own connection, so we must wait until it's done
|
||||
# before locking — otherwise our own lock blocks the migration.
|
||||
conn.close()
|
||||
_acquire_file_lock(db_path)
|
||||
|
||||
global Session
|
||||
Session = sessionmaker(bind=engine)
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ from importlib.metadata import version
|
||||
import requests
|
||||
from typing_extensions import NotRequired
|
||||
|
||||
from utils.install_util import get_missing_requirements_message, requirements_path
|
||||
from utils.install_util import get_missing_requirements_message, get_required_packages_versions
|
||||
|
||||
from comfy.cli_args import DEFAULT_VERSION_STRING
|
||||
import app.logger
|
||||
@@ -45,25 +45,7 @@ def get_installed_frontend_version():
|
||||
|
||||
|
||||
def get_required_frontend_version():
|
||||
"""Get the required frontend version from requirements.txt."""
|
||||
try:
|
||||
with open(requirements_path, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line.startswith("comfyui-frontend-package=="):
|
||||
version_str = line.split("==")[-1]
|
||||
if not is_valid_version(version_str):
|
||||
logging.error(f"Invalid version format in requirements.txt: {version_str}")
|
||||
return None
|
||||
return version_str
|
||||
logging.error("comfyui-frontend-package not found in requirements.txt")
|
||||
return None
|
||||
except FileNotFoundError:
|
||||
logging.error("requirements.txt not found. Cannot determine required frontend version.")
|
||||
return None
|
||||
except Exception as e:
|
||||
logging.error(f"Error reading requirements.txt: {e}")
|
||||
return None
|
||||
return get_required_packages_versions().get("comfyui-frontend-package", None)
|
||||
|
||||
|
||||
def check_frontend_version():
|
||||
@@ -217,25 +199,7 @@ class FrontendManager:
|
||||
|
||||
@classmethod
|
||||
def get_required_templates_version(cls) -> str:
|
||||
"""Get the required workflow templates version from requirements.txt."""
|
||||
try:
|
||||
with open(requirements_path, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line.startswith("comfyui-workflow-templates=="):
|
||||
version_str = line.split("==")[-1]
|
||||
if not is_valid_version(version_str):
|
||||
logging.error(f"Invalid templates version format in requirements.txt: {version_str}")
|
||||
return None
|
||||
return version_str
|
||||
logging.error("comfyui-workflow-templates not found in requirements.txt")
|
||||
return None
|
||||
except FileNotFoundError:
|
||||
logging.error("requirements.txt not found. Cannot determine required templates version.")
|
||||
return None
|
||||
except Exception as e:
|
||||
logging.error(f"Error reading requirements.txt: {e}")
|
||||
return None
|
||||
return get_required_packages_versions().get("comfyui-workflow-templates", None)
|
||||
|
||||
@classmethod
|
||||
def default_frontend_path(cls) -> str:
|
||||
|
||||
@@ -46,6 +46,8 @@ class NodeReplaceManager:
|
||||
connections: dict[str, list[tuple[str, str, int]]] = {}
|
||||
need_replacement: set[str] = set()
|
||||
for node_number, node_struct in prompt.items():
|
||||
if "class_type" not in node_struct or "inputs" not in node_struct:
|
||||
continue
|
||||
class_type = node_struct["class_type"]
|
||||
# need replacement if not in NODE_CLASS_MAPPINGS and has replacement
|
||||
if class_type not in nodes.NODE_CLASS_MAPPINGS.keys() and self.has_replacement(class_type):
|
||||
|
||||
@@ -53,7 +53,7 @@ class SubgraphManager:
|
||||
return entry_id, entry
|
||||
|
||||
async def load_entry_data(self, entry: SubgraphEntry):
|
||||
with open(entry['path'], 'r') as f:
|
||||
with open(entry['path'], 'r', encoding='utf-8') as f:
|
||||
entry['data'] = f.read()
|
||||
return entry
|
||||
|
||||
|
||||
File diff suppressed because one or more lines are too long
1
blueprints/Canny to Image (Z-Image-Turbo).json
Normal file
1
blueprints/Canny to Image (Z-Image-Turbo).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Canny to Video (LTX 2.0).json
Normal file
1
blueprints/Canny to Video (LTX 2.0).json
Normal file
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
1
blueprints/Depth to Image (Z-Image-Turbo).json
Normal file
1
blueprints/Depth to Image (Z-Image-Turbo).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Depth to Video (ltx 2.0).json
Normal file
1
blueprints/Depth to Video (ltx 2.0).json
Normal file
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
1
blueprints/Image Captioning (gemini).json
Normal file
1
blueprints/Image Captioning (gemini).json
Normal file
File diff suppressed because one or more lines are too long
@@ -1 +1 @@
|
||||
{"revision": 0, "last_node_id": 29, "last_link_id": 0, "nodes": [{"id": 29, "type": "4c9d6ea4-b912-40e5-8766-6793a9758c53", "pos": [1970, -230], "size": [180, 86], "flags": {}, "order": 5, "mode": 0, "inputs": [{"label": "image", "localized_name": "images.image0", "name": "images.image0", "type": "IMAGE", "link": null}], "outputs": [{"label": "R", "localized_name": "IMAGE0", "name": "IMAGE0", "type": "IMAGE", "links": []}, {"label": "G", "localized_name": "IMAGE1", "name": "IMAGE1", "type": "IMAGE", "links": []}, {"label": "B", "localized_name": "IMAGE2", "name": "IMAGE2", "type": "IMAGE", "links": []}, {"label": "A", "localized_name": "IMAGE3", "name": "IMAGE3", "type": "IMAGE", "links": []}], "title": "Image Channels", "properties": {"proxyWidgets": []}, "widgets_values": []}], "links": [], "version": 0.4, "definitions": {"subgraphs": [{"id": "4c9d6ea4-b912-40e5-8766-6793a9758c53", "version": 1, "state": {"lastGroupId": 0, "lastNodeId": 28, "lastLinkId": 39, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "Image Channels", "inputNode": {"id": -10, "bounding": [1820, -185, 120, 60]}, "outputNode": {"id": -20, "bounding": [2460, -215, 120, 120]}, "inputs": [{"id": "3522932b-2d86-4a1f-a02a-cb29f3a9d7fe", "name": "images.image0", "type": "IMAGE", "linkIds": [39], "localized_name": "images.image0", "label": "image", "pos": [1920, -165]}], "outputs": [{"id": "605cb9c3-b065-4d9b-81d2-3ec331889b2b", "name": "IMAGE0", "type": "IMAGE", "linkIds": [26], "localized_name": "IMAGE0", "label": "R", "pos": [2480, -195]}, {"id": "fb44a77e-0522-43e9-9527-82e7465b3596", "name": "IMAGE1", "type": "IMAGE", "linkIds": [27], "localized_name": "IMAGE1", "label": "G", "pos": [2480, -175]}, {"id": "81460ee6-0131-402a-874f-6bf3001fc4ff", "name": "IMAGE2", "type": "IMAGE", "linkIds": [28], "localized_name": "IMAGE2", "label": "B", "pos": [2480, -155]}, {"id": "ae690246-80d4-4951-b1d9-9306d8a77417", "name": "IMAGE3", "type": "IMAGE", "linkIds": [29], "localized_name": "IMAGE3", "label": "A", "pos": [2480, -135]}], "widgets": [], "nodes": [{"id": 23, "type": "GLSLShader", "pos": [2000, -330], "size": [400, 172], "flags": {}, "order": 0, "mode": 0, "inputs": [{"label": "image", "localized_name": "images.image0", "name": "images.image0", "type": "IMAGE", "link": 39}, {"localized_name": "fragment_shader", "name": "fragment_shader", "type": "STRING", "widget": {"name": "fragment_shader"}, "link": null}, {"localized_name": "size_mode", "name": "size_mode", "type": "COMFY_DYNAMICCOMBO_V3", "widget": {"name": "size_mode"}, "link": null}, {"label": "image1", "localized_name": "images.image1", "name": "images.image1", "shape": 7, "type": "IMAGE", "link": null}], "outputs": [{"label": "R", "localized_name": "IMAGE0", "name": "IMAGE0", "type": "IMAGE", "links": [26]}, {"label": "G", "localized_name": "IMAGE1", "name": "IMAGE1", "type": "IMAGE", "links": [27]}, {"label": "B", "localized_name": "IMAGE2", "name": "IMAGE2", "type": "IMAGE", "links": [28]}, {"label": "A", "localized_name": "IMAGE3", "name": "IMAGE3", "type": "IMAGE", "links": [29]}], "properties": {"Node name for S&R": "GLSLShader"}, "widgets_values": ["#version 300 es\nprecision highp float;\n\nuniform sampler2D u_image0;\n\nin vec2 v_texCoord;\nlayout(location = 0) out vec4 fragColor0;\nlayout(location = 1) out vec4 fragColor1;\nlayout(location = 2) out vec4 fragColor2;\nlayout(location = 3) out vec4 fragColor3;\n\nvoid main() {\n vec4 color = texture(u_image0, v_texCoord);\n // Output each channel as grayscale to separate render targets\n fragColor0 = vec4(vec3(color.r), 1.0); // Red channel\n fragColor1 = vec4(vec3(color.g), 1.0); // Green channel\n fragColor2 = vec4(vec3(color.b), 1.0); // Blue channel\n fragColor3 = vec4(vec3(color.a), 1.0); // Alpha channel\n}\n", "from_input"]}], "groups": [], "links": [{"id": 39, "origin_id": -10, "origin_slot": 0, "target_id": 23, "target_slot": 0, "type": "IMAGE"}, {"id": 26, "origin_id": 23, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "IMAGE"}, {"id": 27, "origin_id": 23, "origin_slot": 1, "target_id": -20, "target_slot": 1, "type": "IMAGE"}, {"id": 28, "origin_id": 23, "origin_slot": 2, "target_id": -20, "target_slot": 2, "type": "IMAGE"}, {"id": 29, "origin_id": 23, "origin_slot": 3, "target_id": -20, "target_slot": 3, "type": "IMAGE"}], "extra": {"workflowRendererVersion": "LG"}}]}}
|
||||
{"revision": 0, "last_node_id": 29, "last_link_id": 0, "nodes": [{"id": 29, "type": "4c9d6ea4-b912-40e5-8766-6793a9758c53", "pos": [1970, -230], "size": [180, 86], "flags": {}, "order": 5, "mode": 0, "inputs": [{"label": "image", "localized_name": "images.image0", "name": "images.image0", "type": "IMAGE", "link": null}], "outputs": [{"label": "R", "localized_name": "IMAGE0", "name": "IMAGE0", "type": "IMAGE", "links": []}, {"label": "G", "localized_name": "IMAGE1", "name": "IMAGE1", "type": "IMAGE", "links": []}, {"label": "B", "localized_name": "IMAGE2", "name": "IMAGE2", "type": "IMAGE", "links": []}, {"label": "A", "localized_name": "IMAGE3", "name": "IMAGE3", "type": "IMAGE", "links": []}], "title": "Image Channels", "properties": {"proxyWidgets": []}, "widgets_values": []}], "links": [], "version": 0.4, "definitions": {"subgraphs": [{"id": "4c9d6ea4-b912-40e5-8766-6793a9758c53", "version": 1, "state": {"lastGroupId": 0, "lastNodeId": 28, "lastLinkId": 39, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "Image Channels", "inputNode": {"id": -10, "bounding": [1820, -185, 120, 60]}, "outputNode": {"id": -20, "bounding": [2460, -215, 120, 120]}, "inputs": [{"id": "3522932b-2d86-4a1f-a02a-cb29f3a9d7fe", "name": "images.image0", "type": "IMAGE", "linkIds": [39], "localized_name": "images.image0", "label": "image", "pos": [1920, -165]}], "outputs": [{"id": "605cb9c3-b065-4d9b-81d2-3ec331889b2b", "name": "IMAGE0", "type": "IMAGE", "linkIds": [26], "localized_name": "IMAGE0", "label": "R", "pos": [2480, -195]}, {"id": "fb44a77e-0522-43e9-9527-82e7465b3596", "name": "IMAGE1", "type": "IMAGE", "linkIds": [27], "localized_name": "IMAGE1", "label": "G", "pos": [2480, -175]}, {"id": "81460ee6-0131-402a-874f-6bf3001fc4ff", "name": "IMAGE2", "type": "IMAGE", "linkIds": [28], "localized_name": "IMAGE2", "label": "B", "pos": [2480, -155]}, {"id": "ae690246-80d4-4951-b1d9-9306d8a77417", "name": "IMAGE3", "type": "IMAGE", "linkIds": [29], "localized_name": "IMAGE3", "label": "A", "pos": [2480, -135]}], "widgets": [], "nodes": [{"id": 23, "type": "GLSLShader", "pos": [2000, -330], "size": [400, 172], "flags": {}, "order": 0, "mode": 0, "inputs": [{"label": "image", "localized_name": "images.image0", "name": "images.image0", "type": "IMAGE", "link": 39}, {"localized_name": "fragment_shader", "name": "fragment_shader", "type": "STRING", "widget": {"name": "fragment_shader"}, "link": null}, {"localized_name": "size_mode", "name": "size_mode", "type": "COMFY_DYNAMICCOMBO_V3", "widget": {"name": "size_mode"}, "link": null}, {"label": "image1", "localized_name": "images.image1", "name": "images.image1", "shape": 7, "type": "IMAGE", "link": null}], "outputs": [{"label": "R", "localized_name": "IMAGE0", "name": "IMAGE0", "type": "IMAGE", "links": [26]}, {"label": "G", "localized_name": "IMAGE1", "name": "IMAGE1", "type": "IMAGE", "links": [27]}, {"label": "B", "localized_name": "IMAGE2", "name": "IMAGE2", "type": "IMAGE", "links": [28]}, {"label": "A", "localized_name": "IMAGE3", "name": "IMAGE3", "type": "IMAGE", "links": [29]}], "properties": {"Node name for S&R": "GLSLShader"}, "widgets_values": ["#version 300 es\nprecision highp float;\n\nuniform sampler2D u_image0;\n\nin vec2 v_texCoord;\nlayout(location = 0) out vec4 fragColor0;\nlayout(location = 1) out vec4 fragColor1;\nlayout(location = 2) out vec4 fragColor2;\nlayout(location = 3) out vec4 fragColor3;\n\nvoid main() {\n vec4 color = texture(u_image0, v_texCoord);\n // Output each channel as grayscale to separate render targets\n fragColor0 = vec4(vec3(color.r), 1.0); // Red channel\n fragColor1 = vec4(vec3(color.g), 1.0); // Green channel\n fragColor2 = vec4(vec3(color.b), 1.0); // Blue channel\n fragColor3 = vec4(vec3(color.a), 1.0); // Alpha channel\n}\n", "from_input"]}], "groups": [], "links": [{"id": 39, "origin_id": -10, "origin_slot": 0, "target_id": 23, "target_slot": 0, "type": "IMAGE"}, {"id": 26, "origin_id": 23, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "IMAGE"}, {"id": 27, "origin_id": 23, "origin_slot": 1, "target_id": -20, "target_slot": 1, "type": "IMAGE"}, {"id": 28, "origin_id": 23, "origin_slot": 2, "target_id": -20, "target_slot": 2, "type": "IMAGE"}, {"id": 29, "origin_id": 23, "origin_slot": 3, "target_id": -20, "target_slot": 3, "type": "IMAGE"}], "extra": {"workflowRendererVersion": "LG"}, "category": "Image Tools/Color adjust"}]}}
|
||||
|
||||
1
blueprints/Image Edit (Flux.2 Klein 4B).json
Normal file
1
blueprints/Image Edit (Flux.2 Klein 4B).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Image Edit (Qwen 2511).json
Normal file
1
blueprints/Image Edit (Qwen 2511).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Image Inpainting (Qwen-image).json
Normal file
1
blueprints/Image Inpainting (Qwen-image).json
Normal file
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
1
blueprints/Image Outpainting (Qwen-Image).json
Normal file
1
blueprints/Image Outpainting (Qwen-Image).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Image Upscale(Z-image-Turbo).json
Normal file
1
blueprints/Image Upscale(Z-image-Turbo).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Image to Depth Map (Lotus).json
Normal file
1
blueprints/Image to Depth Map (Lotus).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Image to Layers(Qwen-Image Layered).json
Normal file
1
blueprints/Image to Layers(Qwen-Image Layered).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Image to Model (Hunyuan3d 2.1).json
Normal file
1
blueprints/Image to Model (Hunyuan3d 2.1).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Image to Video (Wan 2.2).json
Normal file
1
blueprints/Image to Video (Wan 2.2).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Pose to Image (Z-Image-Turbo).json
Normal file
1
blueprints/Pose to Image (Z-Image-Turbo).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Pose to Video (LTX 2.0).json
Normal file
1
blueprints/Pose to Video (LTX 2.0).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Prompt Enhance.json
Normal file
1
blueprints/Prompt Enhance.json
Normal file
@@ -0,0 +1 @@
|
||||
{"revision": 0, "last_node_id": 15, "last_link_id": 0, "nodes": [{"id": 15, "type": "24d8bbfd-39d4-4774-bff0-3de40cc7a471", "pos": [-1490, 2040], "size": [400, 260], "flags": {}, "order": 0, "mode": 0, "inputs": [{"name": "prompt", "type": "STRING", "widget": {"name": "prompt"}, "link": null}, {"label": "reference images", "name": "images", "type": "IMAGE", "link": null}], "outputs": [{"name": "STRING", "type": "STRING", "links": null}], "title": "Prompt Enhance", "properties": {"proxyWidgets": [["-1", "prompt"]], "cnr_id": "comfy-core", "ver": "0.14.1"}, "widgets_values": [""]}], "links": [], "version": 0.4, "definitions": {"subgraphs": [{"id": "24d8bbfd-39d4-4774-bff0-3de40cc7a471", "version": 1, "state": {"lastGroupId": 0, "lastNodeId": 15, "lastLinkId": 14, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "Prompt Enhance", "inputNode": {"id": -10, "bounding": [-2170, 2110, 138.876953125, 80]}, "outputNode": {"id": -20, "bounding": [-640, 2110, 120, 60]}, "inputs": [{"id": "aeab7216-00e0-4528-a09b-bba50845c5a6", "name": "prompt", "type": "STRING", "linkIds": [11], "pos": [-2051.123046875, 2130]}, {"id": "7b73fd36-aa31-4771-9066-f6c83879994b", "name": "images", "type": "IMAGE", "linkIds": [14], "label": "reference images", "pos": [-2051.123046875, 2150]}], "outputs": [{"id": "c7b0d930-68a1-48d1-b496-0519e5837064", "name": "STRING", "type": "STRING", "linkIds": [13], "pos": [-620, 2130]}], "widgets": [], "nodes": [{"id": 11, "type": "GeminiNode", "pos": [-1560, 1990], "size": [470, 470], "flags": {}, "order": 0, "mode": 0, "inputs": [{"localized_name": "images", "name": "images", "shape": 7, "type": "IMAGE", "link": 14}, {"localized_name": "audio", "name": "audio", "shape": 7, "type": "AUDIO", "link": null}, {"localized_name": "video", "name": "video", "shape": 7, "type": "VIDEO", "link": null}, {"localized_name": "files", "name": "files", "shape": 7, "type": "GEMINI_INPUT_FILES", "link": null}, {"localized_name": "prompt", "name": "prompt", "type": "STRING", "widget": {"name": "prompt"}, "link": 11}, {"localized_name": "model", "name": "model", "type": "COMBO", "widget": {"name": "model"}, "link": null}, {"localized_name": "seed", "name": "seed", "type": "INT", "widget": {"name": "seed"}, "link": null}, {"localized_name": "system_prompt", "name": "system_prompt", "shape": 7, "type": "STRING", "widget": {"name": "system_prompt"}, "link": null}], "outputs": [{"localized_name": "STRING", "name": "STRING", "type": "STRING", "links": [13]}], "properties": {"cnr_id": "comfy-core", "ver": "0.14.1", "Node name for S&R": "GeminiNode"}, "widgets_values": ["", "gemini-3-pro-preview", 42, "randomize", "You are an expert in prompt writing.\nBased on the input, rewrite the user's input into a detailed prompt.\nincluding camera settings, lighting, composition, and style.\nReturn the prompt only"], "color": "#432", "bgcolor": "#653"}], "groups": [], "links": [{"id": 11, "origin_id": -10, "origin_slot": 0, "target_id": 11, "target_slot": 4, "type": "STRING"}, {"id": 13, "origin_id": 11, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "STRING"}, {"id": 14, "origin_id": -10, "origin_slot": 1, "target_id": 11, "target_slot": 0, "type": "IMAGE"}], "extra": {"workflowRendererVersion": "LG"}, "category": "Text generation/Prompt enhance"}]}, "extra": {}}
|
||||
@@ -1 +1 @@
|
||||
{"revision":0,"last_node_id":25,"last_link_id":0,"nodes":[{"id":25,"type":"621ba4e2-22a8-482d-a369-023753198b7b","pos":[4610,-790],"size":[230,58],"flags":{},"order":4,"mode":0,"inputs":[{"label":"image","localized_name":"images.image0","name":"images.image0","type":"IMAGE","link":null}],"outputs":[{"label":"IMAGE","localized_name":"IMAGE0","name":"IMAGE0","type":"IMAGE","links":[]}],"title":"Sharpen","properties":{"proxyWidgets":[["24","value"]]},"widgets_values":[]}],"links":[],"version":0.4,"definitions":{"subgraphs":[{"id":"621ba4e2-22a8-482d-a369-023753198b7b","version":1,"state":{"lastGroupId":0,"lastNodeId":24,"lastLinkId":36,"lastRerouteId":0},"revision":0,"config":{},"name":"Sharpen","inputNode":{"id":-10,"bounding":[4090,-825,120,60]},"outputNode":{"id":-20,"bounding":[5150,-825,120,60]},"inputs":[{"id":"37011fb7-14b7-4e0e-b1a0-6a02e8da1fd7","name":"images.image0","type":"IMAGE","linkIds":[34],"localized_name":"images.image0","label":"image","pos":[4190,-805]}],"outputs":[{"id":"e9182b3f-635c-4cd4-a152-4b4be17ae4b9","name":"IMAGE0","type":"IMAGE","linkIds":[35],"localized_name":"IMAGE0","label":"IMAGE","pos":[5170,-805]}],"widgets":[],"nodes":[{"id":24,"type":"PrimitiveFloat","pos":[4280,-1240],"size":[270,58],"flags":{},"order":0,"mode":0,"inputs":[{"label":"strength","localized_name":"value","name":"value","type":"FLOAT","widget":{"name":"value"},"link":null}],"outputs":[{"localized_name":"FLOAT","name":"FLOAT","type":"FLOAT","links":[36]}],"properties":{"Node name for S&R":"PrimitiveFloat","min":0,"max":3,"precision":2,"step":0.05},"widgets_values":[0.5]},{"id":23,"type":"GLSLShader","pos":[4570,-1240],"size":[370,192],"flags":{},"order":1,"mode":0,"inputs":[{"label":"image0","localized_name":"images.image0","name":"images.image0","type":"IMAGE","link":34},{"label":"image1","localized_name":"images.image1","name":"images.image1","shape":7,"type":"IMAGE","link":null},{"label":"u_float0","localized_name":"floats.u_float0","name":"floats.u_float0","shape":7,"type":"FLOAT","link":36},{"label":"u_float1","localized_name":"floats.u_float1","name":"floats.u_float1","shape":7,"type":"FLOAT","link":null},{"label":"u_int0","localized_name":"ints.u_int0","name":"ints.u_int0","shape":7,"type":"INT","link":null},{"localized_name":"fragment_shader","name":"fragment_shader","type":"STRING","widget":{"name":"fragment_shader"},"link":null},{"localized_name":"size_mode","name":"size_mode","type":"COMFY_DYNAMICCOMBO_V3","widget":{"name":"size_mode"},"link":null}],"outputs":[{"localized_name":"IMAGE0","name":"IMAGE0","type":"IMAGE","links":[35]},{"localized_name":"IMAGE1","name":"IMAGE1","type":"IMAGE","links":null},{"localized_name":"IMAGE2","name":"IMAGE2","type":"IMAGE","links":null},{"localized_name":"IMAGE3","name":"IMAGE3","type":"IMAGE","links":null}],"properties":{"Node name for S&R":"GLSLShader"},"widgets_values":["#version 300 es\nprecision highp float;\n\nuniform sampler2D u_image0;\nuniform vec2 u_resolution;\nuniform float u_float0; // strength [0.0 – 2.0] typical: 0.3–1.0\n\nin vec2 v_texCoord;\nlayout(location = 0) out vec4 fragColor0;\n\nvoid main() {\n vec2 texel = 1.0 / u_resolution;\n \n // Sample center and neighbors\n vec4 center = texture(u_image0, v_texCoord);\n vec4 top = texture(u_image0, v_texCoord + vec2( 0.0, -texel.y));\n vec4 bottom = texture(u_image0, v_texCoord + vec2( 0.0, texel.y));\n vec4 left = texture(u_image0, v_texCoord + vec2(-texel.x, 0.0));\n vec4 right = texture(u_image0, v_texCoord + vec2( texel.x, 0.0));\n \n // Edge enhancement (Laplacian)\n vec4 edges = center * 4.0 - top - bottom - left - right;\n \n // Add edges back scaled by strength\n vec4 sharpened = center + edges * u_float0;\n \n fragColor0 = vec4(clamp(sharpened.rgb, 0.0, 1.0), center.a);\n}","from_input"]}],"groups":[],"links":[{"id":36,"origin_id":24,"origin_slot":0,"target_id":23,"target_slot":2,"type":"FLOAT"},{"id":34,"origin_id":-10,"origin_slot":0,"target_id":23,"target_slot":0,"type":"IMAGE"},{"id":35,"origin_id":23,"origin_slot":0,"target_id":-20,"target_slot":0,"type":"IMAGE"}],"extra":{"workflowRendererVersion":"LG"}}]}}
|
||||
{"revision": 0, "last_node_id": 25, "last_link_id": 0, "nodes": [{"id": 25, "type": "621ba4e2-22a8-482d-a369-023753198b7b", "pos": [4610, -790], "size": [230, 58], "flags": {}, "order": 4, "mode": 0, "inputs": [{"label": "image", "localized_name": "images.image0", "name": "images.image0", "type": "IMAGE", "link": null}], "outputs": [{"label": "IMAGE", "localized_name": "IMAGE0", "name": "IMAGE0", "type": "IMAGE", "links": []}], "title": "Sharpen", "properties": {"proxyWidgets": [["24", "value"]]}, "widgets_values": []}], "links": [], "version": 0.4, "definitions": {"subgraphs": [{"id": "621ba4e2-22a8-482d-a369-023753198b7b", "version": 1, "state": {"lastGroupId": 0, "lastNodeId": 24, "lastLinkId": 36, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "Sharpen", "inputNode": {"id": -10, "bounding": [4090, -825, 120, 60]}, "outputNode": {"id": -20, "bounding": [5150, -825, 120, 60]}, "inputs": [{"id": "37011fb7-14b7-4e0e-b1a0-6a02e8da1fd7", "name": "images.image0", "type": "IMAGE", "linkIds": [34], "localized_name": "images.image0", "label": "image", "pos": [4190, -805]}], "outputs": [{"id": "e9182b3f-635c-4cd4-a152-4b4be17ae4b9", "name": "IMAGE0", "type": "IMAGE", "linkIds": [35], "localized_name": "IMAGE0", "label": "IMAGE", "pos": [5170, -805]}], "widgets": [], "nodes": [{"id": 24, "type": "PrimitiveFloat", "pos": [4280, -1240], "size": [270, 58], "flags": {}, "order": 0, "mode": 0, "inputs": [{"label": "strength", "localized_name": "value", "name": "value", "type": "FLOAT", "widget": {"name": "value"}, "link": null}], "outputs": [{"localized_name": "FLOAT", "name": "FLOAT", "type": "FLOAT", "links": [36]}], "properties": {"Node name for S&R": "PrimitiveFloat", "min": 0, "max": 3, "precision": 2, "step": 0.05}, "widgets_values": [0.5]}, {"id": 23, "type": "GLSLShader", "pos": [4570, -1240], "size": [370, 192], "flags": {}, "order": 1, "mode": 0, "inputs": [{"label": "image0", "localized_name": "images.image0", "name": "images.image0", "type": "IMAGE", "link": 34}, {"label": "image1", "localized_name": "images.image1", "name": "images.image1", "shape": 7, "type": "IMAGE", "link": null}, {"label": "u_float0", "localized_name": "floats.u_float0", "name": "floats.u_float0", "shape": 7, "type": "FLOAT", "link": 36}, {"label": "u_float1", "localized_name": "floats.u_float1", "name": "floats.u_float1", "shape": 7, "type": "FLOAT", "link": null}, {"label": "u_int0", "localized_name": "ints.u_int0", "name": "ints.u_int0", "shape": 7, "type": "INT", "link": null}, {"localized_name": "fragment_shader", "name": "fragment_shader", "type": "STRING", "widget": {"name": "fragment_shader"}, "link": null}, {"localized_name": "size_mode", "name": "size_mode", "type": "COMFY_DYNAMICCOMBO_V3", "widget": {"name": "size_mode"}, "link": null}], "outputs": [{"localized_name": "IMAGE0", "name": "IMAGE0", "type": "IMAGE", "links": [35]}, {"localized_name": "IMAGE1", "name": "IMAGE1", "type": "IMAGE", "links": null}, {"localized_name": "IMAGE2", "name": "IMAGE2", "type": "IMAGE", "links": null}, {"localized_name": "IMAGE3", "name": "IMAGE3", "type": "IMAGE", "links": null}], "properties": {"Node name for S&R": "GLSLShader"}, "widgets_values": ["#version 300 es\nprecision highp float;\n\nuniform sampler2D u_image0;\nuniform vec2 u_resolution;\nuniform float u_float0; // strength [0.0 – 2.0] typical: 0.3–1.0\n\nin vec2 v_texCoord;\nlayout(location = 0) out vec4 fragColor0;\n\nvoid main() {\n vec2 texel = 1.0 / u_resolution;\n \n // Sample center and neighbors\n vec4 center = texture(u_image0, v_texCoord);\n vec4 top = texture(u_image0, v_texCoord + vec2( 0.0, -texel.y));\n vec4 bottom = texture(u_image0, v_texCoord + vec2( 0.0, texel.y));\n vec4 left = texture(u_image0, v_texCoord + vec2(-texel.x, 0.0));\n vec4 right = texture(u_image0, v_texCoord + vec2( texel.x, 0.0));\n \n // Edge enhancement (Laplacian)\n vec4 edges = center * 4.0 - top - bottom - left - right;\n \n // Add edges back scaled by strength\n vec4 sharpened = center + edges * u_float0;\n \n fragColor0 = vec4(clamp(sharpened.rgb, 0.0, 1.0), center.a);\n}", "from_input"]}], "groups": [], "links": [{"id": 36, "origin_id": 24, "origin_slot": 0, "target_id": 23, "target_slot": 2, "type": "FLOAT"}, {"id": 34, "origin_id": -10, "origin_slot": 0, "target_id": 23, "target_slot": 0, "type": "IMAGE"}, {"id": 35, "origin_id": 23, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "IMAGE"}], "extra": {"workflowRendererVersion": "LG"}, "category": "Image Tools/Sharpen"}]}}
|
||||
|
||||
1
blueprints/Text to Audio (ACE-Step 1.5).json
Normal file
1
blueprints/Text to Audio (ACE-Step 1.5).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Text to Image (Z-Image-Turbo).json
Normal file
1
blueprints/Text to Image (Z-Image-Turbo).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Text to Video (Wan 2.2).json
Normal file
1
blueprints/Text to Video (Wan 2.2).json
Normal file
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
1
blueprints/Video Captioning (Gemini).json
Normal file
1
blueprints/Video Captioning (Gemini).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Video Inpaint(Wan2.1 VACE).json
Normal file
1
blueprints/Video Inpaint(Wan2.1 VACE).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Video Stitch.json
Normal file
1
blueprints/Video Stitch.json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Video Upscale(GAN x4).json
Normal file
1
blueprints/Video Upscale(GAN x4).json
Normal file
@@ -0,0 +1 @@
|
||||
{"revision": 0, "last_node_id": 13, "last_link_id": 0, "nodes": [{"id": 13, "type": "cf95b747-3e17-46cb-8097-cac60ff9b2e1", "pos": [1120, 330], "size": [240, 58], "flags": {}, "order": 3, "mode": 0, "inputs": [{"localized_name": "video", "name": "video", "type": "VIDEO", "link": null}, {"name": "model_name", "type": "COMBO", "widget": {"name": "model_name"}, "link": null}], "outputs": [{"localized_name": "VIDEO", "name": "VIDEO", "type": "VIDEO", "links": []}], "title": "Video Upscale(GAN x4)", "properties": {"proxyWidgets": [["-1", "model_name"]], "cnr_id": "comfy-core", "ver": "0.14.1"}, "widgets_values": ["RealESRGAN_x4plus.safetensors"]}], "links": [], "version": 0.4, "definitions": {"subgraphs": [{"id": "cf95b747-3e17-46cb-8097-cac60ff9b2e1", "version": 1, "state": {"lastGroupId": 0, "lastNodeId": 13, "lastLinkId": 19, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "Video Upscale(GAN x4)", "inputNode": {"id": -10, "bounding": [550, 460, 120, 80]}, "outputNode": {"id": -20, "bounding": [1490, 460, 120, 60]}, "inputs": [{"id": "666d633e-93e7-42dc-8d11-2b7b99b0f2a6", "name": "video", "type": "VIDEO", "linkIds": [10], "localized_name": "video", "pos": [650, 480]}, {"id": "2e23a087-caa8-4d65-99e6-662761aa905a", "name": "model_name", "type": "COMBO", "linkIds": [19], "pos": [650, 500]}], "outputs": [{"id": "0c1768ea-3ec2-412f-9af6-8e0fa36dae70", "name": "VIDEO", "type": "VIDEO", "linkIds": [15], "localized_name": "VIDEO", "pos": [1510, 480]}], "widgets": [], "nodes": [{"id": 2, "type": "ImageUpscaleWithModel", "pos": [1110, 450], "size": [320, 46], "flags": {}, "order": 1, "mode": 0, "inputs": [{"localized_name": "upscale_model", "name": "upscale_model", "type": "UPSCALE_MODEL", "link": 1}, {"localized_name": "image", "name": "image", "type": "IMAGE", "link": 14}], "outputs": [{"localized_name": "IMAGE", "name": "IMAGE", "type": "IMAGE", "links": [13]}], "properties": {"cnr_id": "comfy-core", "ver": "0.10.0", "Node name for S&R": "ImageUpscaleWithModel"}}, {"id": 11, "type": "CreateVideo", "pos": [1110, 550], "size": [320, 78], "flags": {}, "order": 3, "mode": 0, "inputs": [{"localized_name": "images", "name": "images", "type": "IMAGE", "link": 13}, {"localized_name": "audio", "name": "audio", "shape": 7, "type": "AUDIO", "link": 16}, {"localized_name": "fps", "name": "fps", "type": "FLOAT", "widget": {"name": "fps"}, "link": 12}], "outputs": [{"localized_name": "VIDEO", "name": "VIDEO", "type": "VIDEO", "links": [15]}], "properties": {"cnr_id": "comfy-core", "ver": "0.10.0", "Node name for S&R": "CreateVideo"}, "widgets_values": [30]}, {"id": 10, "type": "GetVideoComponents", "pos": [1110, 330], "size": [320, 70], "flags": {}, "order": 2, "mode": 0, "inputs": [{"localized_name": "video", "name": "video", "type": "VIDEO", "link": 10}], "outputs": [{"localized_name": "images", "name": "images", "type": "IMAGE", "links": [14]}, {"localized_name": "audio", "name": "audio", "type": "AUDIO", "links": [16]}, {"localized_name": "fps", "name": "fps", "type": "FLOAT", "links": [12]}], "properties": {"cnr_id": "comfy-core", "ver": "0.10.0", "Node name for S&R": "GetVideoComponents"}}, {"id": 1, "type": "UpscaleModelLoader", "pos": [750, 450], "size": [280, 60], "flags": {}, "order": 0, "mode": 0, "inputs": [{"localized_name": "model_name", "name": "model_name", "type": "COMBO", "widget": {"name": "model_name"}, "link": 19}], "outputs": [{"localized_name": "UPSCALE_MODEL", "name": "UPSCALE_MODEL", "type": "UPSCALE_MODEL", "links": [1]}], "properties": {"cnr_id": "comfy-core", "ver": "0.10.0", "Node name for S&R": "UpscaleModelLoader", "models": [{"name": "RealESRGAN_x4plus.safetensors", "url": "https://huggingface.co/Comfy-Org/Real-ESRGAN_repackaged/resolve/main/RealESRGAN_x4plus.safetensors", "directory": "upscale_models"}]}, "widgets_values": ["RealESRGAN_x4plus.safetensors"]}], "groups": [], "links": [{"id": 1, "origin_id": 1, "origin_slot": 0, "target_id": 2, "target_slot": 0, "type": "UPSCALE_MODEL"}, {"id": 14, "origin_id": 10, "origin_slot": 0, "target_id": 2, "target_slot": 1, "type": "IMAGE"}, {"id": 13, "origin_id": 2, "origin_slot": 0, "target_id": 11, "target_slot": 0, "type": "IMAGE"}, {"id": 16, "origin_id": 10, "origin_slot": 1, "target_id": 11, "target_slot": 1, "type": "AUDIO"}, {"id": 12, "origin_id": 10, "origin_slot": 2, "target_id": 11, "target_slot": 2, "type": "FLOAT"}, {"id": 10, "origin_id": -10, "origin_slot": 0, "target_id": 10, "target_slot": 0, "type": "VIDEO"}, {"id": 15, "origin_id": 11, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "VIDEO"}, {"id": 19, "origin_id": -10, "origin_slot": 1, "target_id": 1, "target_slot": 0, "type": "COMBO"}], "extra": {"workflowRendererVersion": "LG"}, "category": "Video generation and editing/Enhance video"}]}, "extra": {}}
|
||||
@@ -27,6 +27,7 @@ class AudioEncoderModel():
|
||||
self.model.eval()
|
||||
self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
|
||||
self.model_sample_rate = 16000
|
||||
comfy.model_management.archive_model_dtypes(self.model)
|
||||
|
||||
def load_sd(self, sd):
|
||||
return self.model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic())
|
||||
|
||||
@@ -146,6 +146,7 @@ parser.add_argument("--reserve-vram", type=float, default=None, help="Set the am
|
||||
|
||||
parser.add_argument("--async-offload", nargs='?', const=2, type=int, default=None, metavar="NUM_STREAMS", help="Use async weight offloading. An optional argument controls the amount of offload streams. Default is 2. Enabled by default on Nvidia.")
|
||||
parser.add_argument("--disable-async-offload", action="store_true", help="Disable async weight offloading.")
|
||||
parser.add_argument("--disable-dynamic-vram", action="store_true", help="Disable dynamic VRAM and use estimate based model loading.")
|
||||
|
||||
parser.add_argument("--force-non-blocking", action="store_true", help="Force ComfyUI to use non-blocking operations for all applicable tensors. This may improve performance on some non-Nvidia systems but can cause issues with some workflows.")
|
||||
|
||||
@@ -159,7 +160,6 @@ class PerformanceFeature(enum.Enum):
|
||||
Fp8MatrixMultiplication = "fp8_matrix_mult"
|
||||
CublasOps = "cublas_ops"
|
||||
AutoTune = "autotune"
|
||||
DynamicVRAM = "dynamic_vram"
|
||||
|
||||
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. This is used to test new features so using it might crash your comfyui. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature))))
|
||||
|
||||
@@ -232,7 +232,7 @@ database_default_path = os.path.abspath(
|
||||
os.path.join(os.path.dirname(__file__), "..", "user", "comfyui.db")
|
||||
)
|
||||
parser.add_argument("--database-url", type=str, default=f"sqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.")
|
||||
parser.add_argument("--disable-assets-autoscan", action="store_true", help="Disable asset scanning on startup for database synchronization.")
|
||||
parser.add_argument("--enable-assets", action="store_true", help="Enable the assets system (API routes, database synchronization, and background scanning).")
|
||||
|
||||
if comfy.options.args_parsing:
|
||||
args = parser.parse_args()
|
||||
@@ -260,4 +260,4 @@ else:
|
||||
args.fast = set(args.fast)
|
||||
|
||||
def enables_dynamic_vram():
|
||||
return PerformanceFeature.DynamicVRAM in args.fast and not args.highvram and not args.gpu_only
|
||||
return not args.disable_dynamic_vram and not args.highvram and not args.gpu_only and not args.novram and not args.cpu
|
||||
|
||||
@@ -176,6 +176,8 @@ class InputTypeOptions(TypedDict):
|
||||
"""COMBO type only. Specifies the configuration for a multi-select widget.
|
||||
Available after ComfyUI frontend v1.13.4
|
||||
https://github.com/Comfy-Org/ComfyUI_frontend/pull/2987"""
|
||||
gradient_stops: NotRequired[list[list[float]]]
|
||||
"""Gradient color stops for gradientslider display mode. Each stop is [offset, r, g, b] (``FLOAT``)."""
|
||||
|
||||
|
||||
class HiddenInputTypeDict(TypedDict):
|
||||
|
||||
@@ -4,6 +4,25 @@ import comfy.utils
|
||||
import logging
|
||||
|
||||
|
||||
def is_equal(x, y):
|
||||
if torch.is_tensor(x) and torch.is_tensor(y):
|
||||
return torch.equal(x, y)
|
||||
elif isinstance(x, dict) and isinstance(y, dict):
|
||||
if x.keys() != y.keys():
|
||||
return False
|
||||
return all(is_equal(x[k], y[k]) for k in x)
|
||||
elif isinstance(x, (list, tuple)) and isinstance(y, (list, tuple)):
|
||||
if type(x) is not type(y) or len(x) != len(y):
|
||||
return False
|
||||
return all(is_equal(a, b) for a, b in zip(x, y))
|
||||
else:
|
||||
try:
|
||||
return x == y
|
||||
except Exception:
|
||||
logging.warning("comparison issue with COND")
|
||||
return False
|
||||
|
||||
|
||||
class CONDRegular:
|
||||
def __init__(self, cond):
|
||||
self.cond = cond
|
||||
@@ -84,7 +103,7 @@ class CONDConstant(CONDRegular):
|
||||
return self._copy_with(self.cond)
|
||||
|
||||
def can_concat(self, other):
|
||||
if self.cond != other.cond:
|
||||
if not is_equal(self.cond, other.cond):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
@@ -214,7 +214,7 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep[0], rtol=0.0001)
|
||||
matches = torch.nonzero(mask)
|
||||
if torch.numel(matches) == 0:
|
||||
raise Exception("No sample_sigmas matched current timestep; something went wrong.")
|
||||
return # substep from multi-step sampler: keep self._step from the last full step
|
||||
self._step = int(matches[0].item())
|
||||
|
||||
def get_context_windows(self, model: BaseModel, x_in: torch.Tensor, model_options: dict[str]) -> list[IndexListContextWindow]:
|
||||
|
||||
@@ -776,3 +776,10 @@ class ChromaRadiance(LatentFormat):
|
||||
|
||||
def process_out(self, latent):
|
||||
return latent
|
||||
|
||||
|
||||
class ZImagePixelSpace(ChromaRadiance):
|
||||
"""Pixel-space latent format for ZImage DCT variant.
|
||||
No VAE encoding/decoding — the model operates directly on RGB pixels.
|
||||
"""
|
||||
pass
|
||||
|
||||
@@ -2,13 +2,19 @@ from typing import Tuple
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from comfy.ldm.lightricks.model import (
|
||||
ADALN_BASE_PARAMS_COUNT,
|
||||
ADALN_CROSS_ATTN_PARAMS_COUNT,
|
||||
CrossAttention,
|
||||
FeedForward,
|
||||
AdaLayerNormSingle,
|
||||
PixArtAlphaTextProjection,
|
||||
NormSingleLinearTextProjection,
|
||||
LTXVModel,
|
||||
apply_cross_attention_adaln,
|
||||
compute_prompt_timestep,
|
||||
)
|
||||
from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
|
||||
from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector
|
||||
import comfy.ldm.common_dit
|
||||
|
||||
class CompressedTimestep:
|
||||
@@ -86,6 +92,8 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
v_context_dim=None,
|
||||
a_context_dim=None,
|
||||
attn_precision=None,
|
||||
apply_gated_attention=False,
|
||||
cross_attention_adaln=False,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
@@ -93,6 +101,7 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
super().__init__()
|
||||
|
||||
self.attn_precision = attn_precision
|
||||
self.cross_attention_adaln = cross_attention_adaln
|
||||
|
||||
self.attn1 = CrossAttention(
|
||||
query_dim=v_dim,
|
||||
@@ -100,6 +109,7 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
dim_head=vd_head,
|
||||
context_dim=None,
|
||||
attn_precision=self.attn_precision,
|
||||
apply_gated_attention=apply_gated_attention,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
@@ -110,6 +120,7 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
dim_head=ad_head,
|
||||
context_dim=None,
|
||||
attn_precision=self.attn_precision,
|
||||
apply_gated_attention=apply_gated_attention,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
@@ -121,6 +132,7 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
heads=v_heads,
|
||||
dim_head=vd_head,
|
||||
attn_precision=self.attn_precision,
|
||||
apply_gated_attention=apply_gated_attention,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
@@ -131,6 +143,7 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
heads=a_heads,
|
||||
dim_head=ad_head,
|
||||
attn_precision=self.attn_precision,
|
||||
apply_gated_attention=apply_gated_attention,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
@@ -143,6 +156,7 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
heads=a_heads,
|
||||
dim_head=ad_head,
|
||||
attn_precision=self.attn_precision,
|
||||
apply_gated_attention=apply_gated_attention,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
@@ -155,6 +169,7 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
heads=a_heads,
|
||||
dim_head=ad_head,
|
||||
attn_precision=self.attn_precision,
|
||||
apply_gated_attention=apply_gated_attention,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
@@ -167,11 +182,16 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
a_dim, dim_out=a_dim, glu=True, dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
|
||||
self.scale_shift_table = nn.Parameter(torch.empty(6, v_dim, device=device, dtype=dtype))
|
||||
num_ada_params = ADALN_CROSS_ATTN_PARAMS_COUNT if cross_attention_adaln else ADALN_BASE_PARAMS_COUNT
|
||||
self.scale_shift_table = nn.Parameter(torch.empty(num_ada_params, v_dim, device=device, dtype=dtype))
|
||||
self.audio_scale_shift_table = nn.Parameter(
|
||||
torch.empty(6, a_dim, device=device, dtype=dtype)
|
||||
torch.empty(num_ada_params, a_dim, device=device, dtype=dtype)
|
||||
)
|
||||
|
||||
if cross_attention_adaln:
|
||||
self.prompt_scale_shift_table = nn.Parameter(torch.empty(2, v_dim, device=device, dtype=dtype))
|
||||
self.audio_prompt_scale_shift_table = nn.Parameter(torch.empty(2, a_dim, device=device, dtype=dtype))
|
||||
|
||||
self.scale_shift_table_a2v_ca_audio = nn.Parameter(
|
||||
torch.empty(5, a_dim, device=device, dtype=dtype)
|
||||
)
|
||||
@@ -214,10 +234,30 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
|
||||
return (*scale_shift_ada_values, *gate_ada_values)
|
||||
|
||||
def _apply_text_cross_attention(
|
||||
self, x, context, attn, scale_shift_table, prompt_scale_shift_table,
|
||||
timestep, prompt_timestep, attention_mask, transformer_options,
|
||||
):
|
||||
"""Apply text cross-attention, with optional ADaLN modulation."""
|
||||
if self.cross_attention_adaln:
|
||||
shift_q, scale_q, gate = self.get_ada_values(
|
||||
scale_shift_table, x.shape[0], timestep, slice(6, 9)
|
||||
)
|
||||
return apply_cross_attention_adaln(
|
||||
x, context, attn, shift_q, scale_q, gate,
|
||||
prompt_scale_shift_table, prompt_timestep,
|
||||
attention_mask, transformer_options,
|
||||
)
|
||||
return attn(
|
||||
comfy.ldm.common_dit.rms_norm(x), context=context,
|
||||
mask=attention_mask, transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, x: Tuple[torch.Tensor, torch.Tensor], v_context=None, a_context=None, attention_mask=None, v_timestep=None, a_timestep=None,
|
||||
v_pe=None, a_pe=None, v_cross_pe=None, a_cross_pe=None, v_cross_scale_shift_timestep=None, a_cross_scale_shift_timestep=None,
|
||||
v_cross_gate_timestep=None, a_cross_gate_timestep=None, transformer_options=None,
|
||||
v_cross_gate_timestep=None, a_cross_gate_timestep=None, transformer_options=None, self_attention_mask=None,
|
||||
v_prompt_timestep=None, a_prompt_timestep=None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
run_vx = transformer_options.get("run_vx", True)
|
||||
run_ax = transformer_options.get("run_ax", True)
|
||||
@@ -233,13 +273,17 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
vshift_msa, vscale_msa = (self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(0, 2)))
|
||||
norm_vx = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_msa) + vshift_msa
|
||||
del vshift_msa, vscale_msa
|
||||
attn1_out = self.attn1(norm_vx, pe=v_pe, transformer_options=transformer_options)
|
||||
attn1_out = self.attn1(norm_vx, pe=v_pe, mask=self_attention_mask, transformer_options=transformer_options)
|
||||
del norm_vx
|
||||
# video cross-attention
|
||||
vgate_msa = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(2, 3))[0]
|
||||
vx.addcmul_(attn1_out, vgate_msa)
|
||||
del vgate_msa, attn1_out
|
||||
vx.add_(self.attn2(comfy.ldm.common_dit.rms_norm(vx), context=v_context, mask=attention_mask, transformer_options=transformer_options))
|
||||
vx.add_(self._apply_text_cross_attention(
|
||||
vx, v_context, self.attn2, self.scale_shift_table,
|
||||
getattr(self, 'prompt_scale_shift_table', None),
|
||||
v_timestep, v_prompt_timestep, attention_mask, transformer_options,)
|
||||
)
|
||||
|
||||
# audio
|
||||
if run_ax:
|
||||
@@ -253,7 +297,11 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
agate_msa = self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(2, 3))[0]
|
||||
ax.addcmul_(attn1_out, agate_msa)
|
||||
del agate_msa, attn1_out
|
||||
ax.add_(self.audio_attn2(comfy.ldm.common_dit.rms_norm(ax), context=a_context, mask=attention_mask, transformer_options=transformer_options))
|
||||
ax.add_(self._apply_text_cross_attention(
|
||||
ax, a_context, self.audio_attn2, self.audio_scale_shift_table,
|
||||
getattr(self, 'audio_prompt_scale_shift_table', None),
|
||||
a_timestep, a_prompt_timestep, attention_mask, transformer_options,)
|
||||
)
|
||||
|
||||
# video - audio cross attention.
|
||||
if run_a2v or run_v2a:
|
||||
@@ -350,6 +398,9 @@ class LTXAVModel(LTXVModel):
|
||||
use_middle_indices_grid=False,
|
||||
timestep_scale_multiplier=1000.0,
|
||||
av_ca_timestep_scale_multiplier=1.0,
|
||||
apply_gated_attention=False,
|
||||
caption_proj_before_connector=False,
|
||||
cross_attention_adaln=False,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
@@ -361,6 +412,7 @@ class LTXAVModel(LTXVModel):
|
||||
self.audio_attention_head_dim = audio_attention_head_dim
|
||||
self.audio_num_attention_heads = audio_num_attention_heads
|
||||
self.audio_positional_embedding_max_pos = audio_positional_embedding_max_pos
|
||||
self.apply_gated_attention = apply_gated_attention
|
||||
|
||||
# Calculate audio dimensions
|
||||
self.audio_inner_dim = audio_num_attention_heads * audio_attention_head_dim
|
||||
@@ -385,6 +437,8 @@ class LTXAVModel(LTXVModel):
|
||||
vae_scale_factors=vae_scale_factors,
|
||||
use_middle_indices_grid=use_middle_indices_grid,
|
||||
timestep_scale_multiplier=timestep_scale_multiplier,
|
||||
caption_proj_before_connector=caption_proj_before_connector,
|
||||
cross_attention_adaln=cross_attention_adaln,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
@@ -399,14 +453,28 @@ class LTXAVModel(LTXVModel):
|
||||
)
|
||||
|
||||
# Audio-specific AdaLN
|
||||
audio_embedding_coefficient = ADALN_CROSS_ATTN_PARAMS_COUNT if self.cross_attention_adaln else ADALN_BASE_PARAMS_COUNT
|
||||
self.audio_adaln_single = AdaLayerNormSingle(
|
||||
self.audio_inner_dim,
|
||||
embedding_coefficient=audio_embedding_coefficient,
|
||||
use_additional_conditions=False,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=self.operations,
|
||||
)
|
||||
|
||||
if self.cross_attention_adaln:
|
||||
self.audio_prompt_adaln_single = AdaLayerNormSingle(
|
||||
self.audio_inner_dim,
|
||||
embedding_coefficient=2,
|
||||
use_additional_conditions=False,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=self.operations,
|
||||
)
|
||||
else:
|
||||
self.audio_prompt_adaln_single = None
|
||||
|
||||
num_scale_shift_values = 4
|
||||
self.av_ca_video_scale_shift_adaln_single = AdaLayerNormSingle(
|
||||
self.inner_dim,
|
||||
@@ -442,14 +510,75 @@ class LTXAVModel(LTXVModel):
|
||||
)
|
||||
|
||||
# Audio caption projection
|
||||
self.audio_caption_projection = PixArtAlphaTextProjection(
|
||||
in_features=self.caption_channels,
|
||||
hidden_size=self.audio_inner_dim,
|
||||
if self.caption_proj_before_connector:
|
||||
if self.caption_projection_first_linear:
|
||||
self.audio_caption_projection = NormSingleLinearTextProjection(
|
||||
in_features=self.caption_channels,
|
||||
hidden_size=self.audio_inner_dim,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=self.operations,
|
||||
)
|
||||
else:
|
||||
self.audio_caption_projection = lambda a: a
|
||||
else:
|
||||
self.audio_caption_projection = PixArtAlphaTextProjection(
|
||||
in_features=self.caption_channels,
|
||||
hidden_size=self.audio_inner_dim,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=self.operations,
|
||||
)
|
||||
|
||||
connector_split_rope = kwargs.get("rope_type", "split") == "split"
|
||||
connector_gated_attention = kwargs.get("connector_apply_gated_attention", False)
|
||||
attention_head_dim = kwargs.get("connector_attention_head_dim", 128)
|
||||
num_attention_heads = kwargs.get("connector_num_attention_heads", 30)
|
||||
num_layers = kwargs.get("connector_num_layers", 2)
|
||||
|
||||
self.audio_embeddings_connector = Embeddings1DConnector(
|
||||
attention_head_dim=kwargs.get("audio_connector_attention_head_dim", attention_head_dim),
|
||||
num_attention_heads=kwargs.get("audio_connector_num_attention_heads", num_attention_heads),
|
||||
num_layers=num_layers,
|
||||
split_rope=connector_split_rope,
|
||||
double_precision_rope=True,
|
||||
apply_gated_attention=connector_gated_attention,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=self.operations,
|
||||
)
|
||||
|
||||
self.video_embeddings_connector = Embeddings1DConnector(
|
||||
attention_head_dim=attention_head_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
num_layers=num_layers,
|
||||
split_rope=connector_split_rope,
|
||||
double_precision_rope=True,
|
||||
apply_gated_attention=connector_gated_attention,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=self.operations,
|
||||
)
|
||||
|
||||
def preprocess_text_embeds(self, context, unprocessed=False):
|
||||
# LTXv2 fully processed context has dimension of self.caption_channels * 2
|
||||
# LTXv2.3 fully processed context has dimension of self.cross_attention_dim + self.audio_cross_attention_dim
|
||||
if not unprocessed:
|
||||
if context.shape[-1] in (self.cross_attention_dim + self.audio_cross_attention_dim, self.caption_channels * 2):
|
||||
return context
|
||||
if context.shape[-1] == self.cross_attention_dim + self.audio_cross_attention_dim:
|
||||
context_vid = context[:, :, :self.cross_attention_dim]
|
||||
context_audio = context[:, :, self.cross_attention_dim:]
|
||||
else:
|
||||
context_vid = context
|
||||
context_audio = context
|
||||
if self.caption_proj_before_connector:
|
||||
context_vid = self.caption_projection(context_vid)
|
||||
context_audio = self.audio_caption_projection(context_audio)
|
||||
out_vid = self.video_embeddings_connector(context_vid)[0]
|
||||
out_audio = self.audio_embeddings_connector(context_audio)[0]
|
||||
return torch.concat((out_vid, out_audio), dim=-1)
|
||||
|
||||
def _init_transformer_blocks(self, device, dtype, **kwargs):
|
||||
"""Initialize transformer blocks for LTXAV."""
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
@@ -463,6 +592,8 @@ class LTXAVModel(LTXVModel):
|
||||
ad_head=self.audio_attention_head_dim,
|
||||
v_context_dim=self.cross_attention_dim,
|
||||
a_context_dim=self.audio_cross_attention_dim,
|
||||
apply_gated_attention=self.apply_gated_attention,
|
||||
cross_attention_adaln=self.cross_attention_adaln,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=self.operations,
|
||||
@@ -584,6 +715,10 @@ class LTXAVModel(LTXVModel):
|
||||
v_timestep = CompressedTimestep(v_timestep.view(batch_size, -1, v_timestep.shape[-1]), v_patches_per_frame)
|
||||
v_embedded_timestep = CompressedTimestep(v_embedded_timestep.view(batch_size, -1, v_embedded_timestep.shape[-1]), v_patches_per_frame)
|
||||
|
||||
v_prompt_timestep = compute_prompt_timestep(
|
||||
self.prompt_adaln_single, timestep_scaled, batch_size, hidden_dtype
|
||||
)
|
||||
|
||||
# Prepare audio timestep
|
||||
a_timestep = kwargs.get("a_timestep")
|
||||
if a_timestep is not None:
|
||||
@@ -594,25 +729,25 @@ class LTXAVModel(LTXVModel):
|
||||
|
||||
# Cross-attention timesteps - compress these too
|
||||
av_ca_audio_scale_shift_timestep, _ = self.av_ca_audio_scale_shift_adaln_single(
|
||||
a_timestep_flat,
|
||||
timestep.max().expand_as(a_timestep_flat),
|
||||
{"resolution": None, "aspect_ratio": None},
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_dtype,
|
||||
)
|
||||
av_ca_video_scale_shift_timestep, _ = self.av_ca_video_scale_shift_adaln_single(
|
||||
timestep_flat,
|
||||
a_timestep.max().expand_as(timestep_flat),
|
||||
{"resolution": None, "aspect_ratio": None},
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_dtype,
|
||||
)
|
||||
av_ca_a2v_gate_noise_timestep, _ = self.av_ca_a2v_gate_adaln_single(
|
||||
timestep_flat * av_ca_factor,
|
||||
a_timestep.max().expand_as(timestep_flat) * av_ca_factor,
|
||||
{"resolution": None, "aspect_ratio": None},
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_dtype,
|
||||
)
|
||||
av_ca_v2a_gate_noise_timestep, _ = self.av_ca_v2a_gate_adaln_single(
|
||||
a_timestep_flat * av_ca_factor,
|
||||
timestep.max().expand_as(a_timestep_flat) * av_ca_factor,
|
||||
{"resolution": None, "aspect_ratio": None},
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_dtype,
|
||||
@@ -636,29 +771,40 @@ class LTXAVModel(LTXVModel):
|
||||
# Audio timesteps
|
||||
a_timestep = a_timestep.view(batch_size, -1, a_timestep.shape[-1])
|
||||
a_embedded_timestep = a_embedded_timestep.view(batch_size, -1, a_embedded_timestep.shape[-1])
|
||||
|
||||
a_prompt_timestep = compute_prompt_timestep(
|
||||
self.audio_prompt_adaln_single, a_timestep_scaled, batch_size, hidden_dtype
|
||||
)
|
||||
else:
|
||||
a_timestep = timestep_scaled
|
||||
a_embedded_timestep = kwargs.get("embedded_timestep")
|
||||
cross_av_timestep_ss = []
|
||||
a_prompt_timestep = None
|
||||
|
||||
return [v_timestep, a_timestep, cross_av_timestep_ss], [
|
||||
return [v_timestep, a_timestep, cross_av_timestep_ss, v_prompt_timestep, a_prompt_timestep], [
|
||||
v_embedded_timestep,
|
||||
a_embedded_timestep,
|
||||
]
|
||||
], None
|
||||
|
||||
def _prepare_context(self, context, batch_size, x, attention_mask=None):
|
||||
vx = x[0]
|
||||
ax = x[1]
|
||||
video_dim = vx.shape[-1]
|
||||
audio_dim = ax.shape[-1]
|
||||
|
||||
v_context_dim = self.caption_channels if self.caption_proj_before_connector is False else video_dim
|
||||
a_context_dim = self.caption_channels if self.caption_proj_before_connector is False else audio_dim
|
||||
|
||||
v_context, a_context = torch.split(
|
||||
context, int(context.shape[-1] / 2), len(context.shape) - 1
|
||||
context, [v_context_dim, a_context_dim], len(context.shape) - 1
|
||||
)
|
||||
|
||||
v_context, attention_mask = super()._prepare_context(
|
||||
v_context, batch_size, vx, attention_mask
|
||||
)
|
||||
if self.audio_caption_projection is not None:
|
||||
if self.caption_proj_before_connector is False:
|
||||
a_context = self.audio_caption_projection(a_context)
|
||||
a_context = a_context.view(batch_size, -1, ax.shape[-1])
|
||||
a_context = a_context.view(batch_size, -1, audio_dim)
|
||||
|
||||
return [v_context, a_context], attention_mask
|
||||
|
||||
@@ -702,7 +848,7 @@ class LTXAVModel(LTXVModel):
|
||||
return [(v_pe, av_cross_video_freq_cis), (a_pe, av_cross_audio_freq_cis)]
|
||||
|
||||
def _process_transformer_blocks(
|
||||
self, x, context, attention_mask, timestep, pe, transformer_options={}, **kwargs
|
||||
self, x, context, attention_mask, timestep, pe, transformer_options={}, self_attention_mask=None, **kwargs
|
||||
):
|
||||
vx = x[0]
|
||||
ax = x[1]
|
||||
@@ -720,6 +866,9 @@ class LTXAVModel(LTXVModel):
|
||||
av_ca_v2a_gate_noise_timestep,
|
||||
) = timestep[2]
|
||||
|
||||
v_prompt_timestep = timestep[3]
|
||||
a_prompt_timestep = timestep[4]
|
||||
|
||||
"""Process transformer blocks for LTXAV."""
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
@@ -746,6 +895,9 @@ class LTXAVModel(LTXVModel):
|
||||
v_cross_gate_timestep=args["v_cross_gate_timestep"],
|
||||
a_cross_gate_timestep=args["a_cross_gate_timestep"],
|
||||
transformer_options=args["transformer_options"],
|
||||
self_attention_mask=args.get("self_attention_mask"),
|
||||
v_prompt_timestep=args.get("v_prompt_timestep"),
|
||||
a_prompt_timestep=args.get("a_prompt_timestep"),
|
||||
)
|
||||
return out
|
||||
|
||||
@@ -766,6 +918,9 @@ class LTXAVModel(LTXVModel):
|
||||
"v_cross_gate_timestep": av_ca_a2v_gate_noise_timestep,
|
||||
"a_cross_gate_timestep": av_ca_v2a_gate_noise_timestep,
|
||||
"transformer_options": transformer_options,
|
||||
"self_attention_mask": self_attention_mask,
|
||||
"v_prompt_timestep": v_prompt_timestep,
|
||||
"a_prompt_timestep": a_prompt_timestep,
|
||||
},
|
||||
{"original_block": block_wrap},
|
||||
)
|
||||
@@ -787,6 +942,9 @@ class LTXAVModel(LTXVModel):
|
||||
v_cross_gate_timestep=av_ca_a2v_gate_noise_timestep,
|
||||
a_cross_gate_timestep=av_ca_v2a_gate_noise_timestep,
|
||||
transformer_options=transformer_options,
|
||||
self_attention_mask=self_attention_mask,
|
||||
v_prompt_timestep=v_prompt_timestep,
|
||||
a_prompt_timestep=a_prompt_timestep,
|
||||
)
|
||||
|
||||
return [vx, ax]
|
||||
|
||||
@@ -50,6 +50,7 @@ class BasicTransformerBlock1D(nn.Module):
|
||||
d_head,
|
||||
context_dim=None,
|
||||
attn_precision=None,
|
||||
apply_gated_attention=False,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
@@ -63,6 +64,7 @@ class BasicTransformerBlock1D(nn.Module):
|
||||
heads=n_heads,
|
||||
dim_head=d_head,
|
||||
context_dim=None,
|
||||
apply_gated_attention=apply_gated_attention,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
@@ -121,6 +123,7 @@ class Embeddings1DConnector(nn.Module):
|
||||
positional_embedding_max_pos=[4096],
|
||||
causal_temporal_positioning=False,
|
||||
num_learnable_registers: Optional[int] = 128,
|
||||
apply_gated_attention=False,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
@@ -145,6 +148,7 @@ class Embeddings1DConnector(nn.Module):
|
||||
num_attention_heads,
|
||||
attention_head_dim,
|
||||
context_dim=cross_attention_dim,
|
||||
apply_gated_attention=apply_gated_attention,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
@@ -157,11 +161,9 @@ class Embeddings1DConnector(nn.Module):
|
||||
self.num_learnable_registers = num_learnable_registers
|
||||
if self.num_learnable_registers:
|
||||
self.learnable_registers = nn.Parameter(
|
||||
torch.rand(
|
||||
torch.empty(
|
||||
self.num_learnable_registers, inner_dim, dtype=dtype, device=device
|
||||
)
|
||||
* 2.0
|
||||
- 1.0
|
||||
)
|
||||
|
||||
def get_fractional_positions(self, indices_grid):
|
||||
@@ -234,7 +236,7 @@ class Embeddings1DConnector(nn.Module):
|
||||
|
||||
return indices
|
||||
|
||||
def precompute_freqs_cis(self, indices_grid, spacing="exp"):
|
||||
def precompute_freqs_cis(self, indices_grid, spacing="exp", out_dtype=None):
|
||||
dim = self.inner_dim
|
||||
n_elem = 2 # 2 because of cos and sin
|
||||
freqs = self.precompute_freqs(indices_grid, spacing)
|
||||
@@ -247,7 +249,7 @@ class Embeddings1DConnector(nn.Module):
|
||||
)
|
||||
else:
|
||||
cos_freq, sin_freq = interleaved_freqs_cis(freqs, dim % n_elem)
|
||||
return cos_freq.to(self.dtype), sin_freq.to(self.dtype), self.split_rope
|
||||
return cos_freq.to(dtype=out_dtype), sin_freq.to(dtype=out_dtype), self.split_rope
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -288,7 +290,7 @@ class Embeddings1DConnector(nn.Module):
|
||||
hidden_states.shape[1], dtype=torch.float32, device=hidden_states.device
|
||||
)
|
||||
indices_grid = indices_grid[None, None, :]
|
||||
freqs_cis = self.precompute_freqs_cis(indices_grid)
|
||||
freqs_cis = self.precompute_freqs_cis(indices_grid, out_dtype=hidden_states.dtype)
|
||||
|
||||
# 2. Blocks
|
||||
for block_idx, block in enumerate(self.transformer_1d_blocks):
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
import functools
|
||||
import logging
|
||||
import math
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
@@ -14,6 +15,8 @@ import comfy.ldm.common_dit
|
||||
|
||||
from .symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def _log_base(x, base):
|
||||
return np.log(x) / np.log(base)
|
||||
|
||||
@@ -272,6 +275,30 @@ class PixArtAlphaTextProjection(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class NormSingleLinearTextProjection(nn.Module):
|
||||
"""Text projection for 20B models - single linear with RMSNorm (no activation)."""
|
||||
|
||||
def __init__(
|
||||
self, in_features, hidden_size, dtype=None, device=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
if operations is None:
|
||||
operations = comfy.ops.disable_weight_init
|
||||
self.in_norm = operations.RMSNorm(
|
||||
in_features, eps=1e-6, elementwise_affine=False
|
||||
)
|
||||
self.linear_1 = operations.Linear(
|
||||
in_features, hidden_size, bias=True, dtype=dtype, device=device
|
||||
)
|
||||
self.hidden_size = hidden_size
|
||||
self.in_features = in_features
|
||||
|
||||
def forward(self, caption):
|
||||
caption = self.in_norm(caption)
|
||||
caption = caption * (self.hidden_size / self.in_features) ** 0.5
|
||||
return self.linear_1(caption)
|
||||
|
||||
|
||||
class GELU_approx(nn.Module):
|
||||
def __init__(self, dim_in, dim_out, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
@@ -340,6 +367,7 @@ class CrossAttention(nn.Module):
|
||||
dim_head=64,
|
||||
dropout=0.0,
|
||||
attn_precision=None,
|
||||
apply_gated_attention=False,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
@@ -359,6 +387,12 @@ class CrossAttention(nn.Module):
|
||||
self.to_k = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)
|
||||
self.to_v = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)
|
||||
|
||||
# Optional per-head gating
|
||||
if apply_gated_attention:
|
||||
self.to_gate_logits = operations.Linear(query_dim, heads, bias=True, dtype=dtype, device=device)
|
||||
else:
|
||||
self.to_gate_logits = None
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout)
|
||||
)
|
||||
@@ -380,16 +414,30 @@ class CrossAttention(nn.Module):
|
||||
out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options)
|
||||
else:
|
||||
out = comfy.ldm.modules.attention.optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision, transformer_options=transformer_options)
|
||||
|
||||
# Apply per-head gating if enabled
|
||||
if self.to_gate_logits is not None:
|
||||
gate_logits = self.to_gate_logits(x) # (B, T, H)
|
||||
b, t, _ = out.shape
|
||||
out = out.view(b, t, self.heads, self.dim_head)
|
||||
gates = 2.0 * torch.sigmoid(gate_logits) # zero-init -> identity
|
||||
out = out * gates.unsqueeze(-1)
|
||||
out = out.view(b, t, self.heads * self.dim_head)
|
||||
|
||||
return self.to_out(out)
|
||||
|
||||
# 6 base ADaLN params (shift/scale/gate for MSA + MLP), +3 for cross-attention Q (shift/scale/gate)
|
||||
ADALN_BASE_PARAMS_COUNT = 6
|
||||
ADALN_CROSS_ATTN_PARAMS_COUNT = 9
|
||||
|
||||
class BasicTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self, dim, n_heads, d_head, context_dim=None, attn_precision=None, dtype=None, device=None, operations=None
|
||||
self, dim, n_heads, d_head, context_dim=None, attn_precision=None, cross_attention_adaln=False, dtype=None, device=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.attn_precision = attn_precision
|
||||
self.cross_attention_adaln = cross_attention_adaln
|
||||
self.attn1 = CrossAttention(
|
||||
query_dim=dim,
|
||||
heads=n_heads,
|
||||
@@ -413,18 +461,25 @@ class BasicTransformerBlock(nn.Module):
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype))
|
||||
num_ada_params = ADALN_CROSS_ATTN_PARAMS_COUNT if cross_attention_adaln else ADALN_BASE_PARAMS_COUNT
|
||||
self.scale_shift_table = nn.Parameter(torch.empty(num_ada_params, dim, device=device, dtype=dtype))
|
||||
|
||||
def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}):
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2)
|
||||
if cross_attention_adaln:
|
||||
self.prompt_scale_shift_table = nn.Parameter(torch.empty(2, dim, device=device, dtype=dtype))
|
||||
|
||||
attn1_input = comfy.ldm.common_dit.rms_norm(x)
|
||||
attn1_input = torch.addcmul(attn1_input, attn1_input, scale_msa).add_(shift_msa)
|
||||
attn1_input = self.attn1(attn1_input, pe=pe, transformer_options=transformer_options)
|
||||
x.addcmul_(attn1_input, gate_msa)
|
||||
del attn1_input
|
||||
def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}, self_attention_mask=None, prompt_timestep=None):
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None, :6].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)[:, :, :6, :]).unbind(dim=2)
|
||||
|
||||
x += self.attn2(x, context=context, mask=attention_mask, transformer_options=transformer_options)
|
||||
x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe, mask=self_attention_mask, transformer_options=transformer_options) * gate_msa
|
||||
|
||||
if self.cross_attention_adaln:
|
||||
shift_q_mca, scale_q_mca, gate_mca = (self.scale_shift_table[None, None, 6:9].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)[:, :, 6:9, :]).unbind(dim=2)
|
||||
x += apply_cross_attention_adaln(
|
||||
x, context, self.attn2, shift_q_mca, scale_q_mca, gate_mca,
|
||||
self.prompt_scale_shift_table, prompt_timestep, attention_mask, transformer_options,
|
||||
)
|
||||
else:
|
||||
x += self.attn2(x, context=context, mask=attention_mask, transformer_options=transformer_options)
|
||||
|
||||
y = comfy.ldm.common_dit.rms_norm(x)
|
||||
y = torch.addcmul(y, y, scale_mlp).add_(shift_mlp)
|
||||
@@ -432,6 +487,47 @@ class BasicTransformerBlock(nn.Module):
|
||||
|
||||
return x
|
||||
|
||||
def compute_prompt_timestep(adaln_module, timestep_scaled, batch_size, hidden_dtype):
|
||||
"""Compute a single global prompt timestep for cross-attention ADaLN.
|
||||
|
||||
Uses the max across tokens (matching JAX max_per_segment) and broadcasts
|
||||
over text tokens. Returns None when *adaln_module* is None.
|
||||
"""
|
||||
if adaln_module is None:
|
||||
return None
|
||||
ts_input = (
|
||||
timestep_scaled.max(dim=1, keepdim=True).values.flatten()
|
||||
if timestep_scaled.dim() > 1
|
||||
else timestep_scaled.flatten()
|
||||
)
|
||||
prompt_ts, _ = adaln_module(
|
||||
ts_input,
|
||||
{"resolution": None, "aspect_ratio": None},
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_dtype,
|
||||
)
|
||||
return prompt_ts.view(batch_size, 1, prompt_ts.shape[-1])
|
||||
|
||||
|
||||
def apply_cross_attention_adaln(
|
||||
x, context, attn, q_shift, q_scale, q_gate,
|
||||
prompt_scale_shift_table, prompt_timestep,
|
||||
attention_mask=None, transformer_options={},
|
||||
):
|
||||
"""Apply cross-attention with ADaLN modulation (shift/scale/gate on Q and KV).
|
||||
|
||||
Q params (q_shift, q_scale, q_gate) are pre-extracted by the caller so
|
||||
that both regular tensors and CompressedTimestep are supported.
|
||||
"""
|
||||
batch_size = x.shape[0]
|
||||
shift_kv, scale_kv = (
|
||||
prompt_scale_shift_table[None, None].to(device=x.device, dtype=x.dtype)
|
||||
+ prompt_timestep.reshape(batch_size, prompt_timestep.shape[1], 2, -1)
|
||||
).unbind(dim=2)
|
||||
attn_input = comfy.ldm.common_dit.rms_norm(x) * (1 + q_scale) + q_shift
|
||||
encoder_hidden_states = context * (1 + scale_kv) + shift_kv
|
||||
return attn(attn_input, context=encoder_hidden_states, mask=attention_mask, transformer_options=transformer_options) * q_gate
|
||||
|
||||
def get_fractional_positions(indices_grid, max_pos):
|
||||
n_pos_dims = indices_grid.shape[1]
|
||||
assert n_pos_dims == len(max_pos), f'Number of position dimensions ({n_pos_dims}) must match max_pos length ({len(max_pos)})'
|
||||
@@ -553,6 +649,9 @@ class LTXBaseModel(torch.nn.Module, ABC):
|
||||
vae_scale_factors: tuple = (8, 32, 32),
|
||||
use_middle_indices_grid=False,
|
||||
timestep_scale_multiplier = 1000.0,
|
||||
caption_proj_before_connector=False,
|
||||
cross_attention_adaln=False,
|
||||
caption_projection_first_linear=True,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
@@ -579,6 +678,9 @@ class LTXBaseModel(torch.nn.Module, ABC):
|
||||
self.causal_temporal_positioning = causal_temporal_positioning
|
||||
self.operations = operations
|
||||
self.timestep_scale_multiplier = timestep_scale_multiplier
|
||||
self.caption_proj_before_connector = caption_proj_before_connector
|
||||
self.cross_attention_adaln = cross_attention_adaln
|
||||
self.caption_projection_first_linear = caption_projection_first_linear
|
||||
|
||||
# Common dimensions
|
||||
self.inner_dim = num_attention_heads * attention_head_dim
|
||||
@@ -606,17 +708,37 @@ class LTXBaseModel(torch.nn.Module, ABC):
|
||||
self.in_channels, self.inner_dim, bias=True, dtype=dtype, device=device
|
||||
)
|
||||
|
||||
embedding_coefficient = ADALN_CROSS_ATTN_PARAMS_COUNT if self.cross_attention_adaln else ADALN_BASE_PARAMS_COUNT
|
||||
self.adaln_single = AdaLayerNormSingle(
|
||||
self.inner_dim, use_additional_conditions=False, dtype=dtype, device=device, operations=self.operations
|
||||
self.inner_dim, embedding_coefficient=embedding_coefficient, use_additional_conditions=False, dtype=dtype, device=device, operations=self.operations
|
||||
)
|
||||
|
||||
self.caption_projection = PixArtAlphaTextProjection(
|
||||
in_features=self.caption_channels,
|
||||
hidden_size=self.inner_dim,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=self.operations,
|
||||
)
|
||||
if self.cross_attention_adaln:
|
||||
self.prompt_adaln_single = AdaLayerNormSingle(
|
||||
self.inner_dim, embedding_coefficient=2, use_additional_conditions=False, dtype=dtype, device=device, operations=self.operations
|
||||
)
|
||||
else:
|
||||
self.prompt_adaln_single = None
|
||||
|
||||
if self.caption_proj_before_connector:
|
||||
if self.caption_projection_first_linear:
|
||||
self.caption_projection = NormSingleLinearTextProjection(
|
||||
in_features=self.caption_channels,
|
||||
hidden_size=self.inner_dim,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=self.operations,
|
||||
)
|
||||
else:
|
||||
self.caption_projection = lambda a: a
|
||||
else:
|
||||
self.caption_projection = PixArtAlphaTextProjection(
|
||||
in_features=self.caption_channels,
|
||||
hidden_size=self.inner_dim,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=self.operations,
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def _init_model_components(self, device, dtype, **kwargs):
|
||||
@@ -638,8 +760,16 @@ class LTXBaseModel(torch.nn.Module, ABC):
|
||||
"""Process input data. Must be implemented by subclasses."""
|
||||
pass
|
||||
|
||||
def _build_guide_self_attention_mask(self, x, transformer_options, merged_args):
|
||||
"""Build self-attention mask for per-guide attention attenuation.
|
||||
|
||||
Base implementation returns None (no attenuation). Subclasses that
|
||||
support guide-based attention control should override this.
|
||||
"""
|
||||
return None
|
||||
|
||||
@abstractmethod
|
||||
def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, **kwargs):
|
||||
def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, self_attention_mask=None, **kwargs):
|
||||
"""Process transformer blocks. Must be implemented by subclasses."""
|
||||
pass
|
||||
|
||||
@@ -654,9 +784,9 @@ class LTXBaseModel(torch.nn.Module, ABC):
|
||||
if grid_mask is not None:
|
||||
timestep = timestep[:, grid_mask]
|
||||
|
||||
timestep = timestep * self.timestep_scale_multiplier
|
||||
timestep_scaled = timestep * self.timestep_scale_multiplier
|
||||
timestep, embedded_timestep = self.adaln_single(
|
||||
timestep.flatten(),
|
||||
timestep_scaled.flatten(),
|
||||
{"resolution": None, "aspect_ratio": None},
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_dtype,
|
||||
@@ -666,14 +796,18 @@ class LTXBaseModel(torch.nn.Module, ABC):
|
||||
timestep = timestep.view(batch_size, -1, timestep.shape[-1])
|
||||
embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.shape[-1])
|
||||
|
||||
return timestep, embedded_timestep
|
||||
prompt_timestep = compute_prompt_timestep(
|
||||
self.prompt_adaln_single, timestep_scaled, batch_size, hidden_dtype
|
||||
)
|
||||
|
||||
return timestep, embedded_timestep, prompt_timestep
|
||||
|
||||
def _prepare_context(self, context, batch_size, x, attention_mask=None):
|
||||
"""Prepare context for transformer blocks."""
|
||||
if self.caption_projection is not None:
|
||||
if self.caption_proj_before_connector is False:
|
||||
context = self.caption_projection(context)
|
||||
context = context.view(batch_size, -1, x.shape[-1])
|
||||
|
||||
context = context.view(batch_size, -1, x.shape[-1])
|
||||
return context, attention_mask
|
||||
|
||||
def _precompute_freqs_cis(
|
||||
@@ -781,16 +915,25 @@ class LTXBaseModel(torch.nn.Module, ABC):
|
||||
merged_args.update(additional_args)
|
||||
|
||||
# Prepare timestep and context
|
||||
timestep, embedded_timestep = self._prepare_timestep(timestep, batch_size, input_dtype, **merged_args)
|
||||
timestep, embedded_timestep, prompt_timestep = self._prepare_timestep(timestep, batch_size, input_dtype, **merged_args)
|
||||
merged_args["prompt_timestep"] = prompt_timestep
|
||||
context, attention_mask = self._prepare_context(context, batch_size, x, attention_mask)
|
||||
|
||||
# Prepare attention mask and positional embeddings
|
||||
attention_mask = self._prepare_attention_mask(attention_mask, input_dtype)
|
||||
pe = self._prepare_positional_embeddings(pixel_coords, frame_rate, input_dtype)
|
||||
|
||||
# Build self-attention mask for per-guide attenuation
|
||||
self_attention_mask = self._build_guide_self_attention_mask(
|
||||
x, transformer_options, merged_args
|
||||
)
|
||||
|
||||
# Process transformer blocks
|
||||
x = self._process_transformer_blocks(
|
||||
x, context, attention_mask, timestep, pe, transformer_options=transformer_options, **merged_args
|
||||
x, context, attention_mask, timestep, pe,
|
||||
transformer_options=transformer_options,
|
||||
self_attention_mask=self_attention_mask,
|
||||
**merged_args,
|
||||
)
|
||||
|
||||
# Process output
|
||||
@@ -814,7 +957,9 @@ class LTXVModel(LTXBaseModel):
|
||||
causal_temporal_positioning=False,
|
||||
vae_scale_factors=(8, 32, 32),
|
||||
use_middle_indices_grid=False,
|
||||
timestep_scale_multiplier = 1000.0,
|
||||
timestep_scale_multiplier=1000.0,
|
||||
caption_proj_before_connector=False,
|
||||
cross_attention_adaln=False,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
@@ -833,6 +978,8 @@ class LTXVModel(LTXBaseModel):
|
||||
vae_scale_factors=vae_scale_factors,
|
||||
use_middle_indices_grid=use_middle_indices_grid,
|
||||
timestep_scale_multiplier=timestep_scale_multiplier,
|
||||
caption_proj_before_connector=caption_proj_before_connector,
|
||||
cross_attention_adaln=cross_attention_adaln,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
@@ -841,7 +988,6 @@ class LTXVModel(LTXBaseModel):
|
||||
|
||||
def _init_model_components(self, device, dtype, **kwargs):
|
||||
"""Initialize LTXV-specific components."""
|
||||
# No additional components needed for LTXV beyond base class
|
||||
pass
|
||||
|
||||
def _init_transformer_blocks(self, device, dtype, **kwargs):
|
||||
@@ -853,6 +999,7 @@ class LTXVModel(LTXBaseModel):
|
||||
self.num_attention_heads,
|
||||
self.attention_head_dim,
|
||||
context_dim=self.cross_attention_dim,
|
||||
cross_attention_adaln=self.cross_attention_adaln,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=self.operations,
|
||||
@@ -890,26 +1037,257 @@ class LTXVModel(LTXBaseModel):
|
||||
pixel_coords = pixel_coords[:, :, grid_mask, ...]
|
||||
|
||||
kf_grid_mask = grid_mask[-keyframe_idxs.shape[2]:]
|
||||
|
||||
# Compute per-guide surviving token counts from guide_attention_entries.
|
||||
# Each entry tracks one guide reference; they are appended in order and
|
||||
# their pre_filter_counts partition the kf_grid_mask.
|
||||
guide_entries = kwargs.get("guide_attention_entries", None)
|
||||
if guide_entries:
|
||||
total_pfc = sum(e["pre_filter_count"] for e in guide_entries)
|
||||
if total_pfc != len(kf_grid_mask):
|
||||
raise ValueError(
|
||||
f"guide pre_filter_counts ({total_pfc}) != "
|
||||
f"keyframe grid mask length ({len(kf_grid_mask)})"
|
||||
)
|
||||
resolved_entries = []
|
||||
offset = 0
|
||||
for entry in guide_entries:
|
||||
pfc = entry["pre_filter_count"]
|
||||
entry_mask = kf_grid_mask[offset:offset + pfc]
|
||||
surviving = int(entry_mask.sum().item())
|
||||
resolved_entries.append({
|
||||
**entry,
|
||||
"surviving_count": surviving,
|
||||
})
|
||||
offset += pfc
|
||||
additional_args["resolved_guide_entries"] = resolved_entries
|
||||
|
||||
keyframe_idxs = keyframe_idxs[..., kf_grid_mask, :]
|
||||
pixel_coords[:, :, -keyframe_idxs.shape[2]:, :] = keyframe_idxs
|
||||
|
||||
# Total surviving guide tokens (all guides)
|
||||
additional_args["num_guide_tokens"] = keyframe_idxs.shape[2]
|
||||
|
||||
x = self.patchify_proj(x)
|
||||
return x, pixel_coords, additional_args
|
||||
|
||||
def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, transformer_options={}, **kwargs):
|
||||
def _build_guide_self_attention_mask(self, x, transformer_options, merged_args):
|
||||
"""Build self-attention mask for per-guide attention attenuation.
|
||||
|
||||
Reads resolved_guide_entries from merged_args (computed in _process_input)
|
||||
to build a log-space additive bias mask that attenuates noisy ↔ guide
|
||||
attention for each guide reference independently.
|
||||
|
||||
Returns None if no attenuation is needed (all strengths == 1.0 and no
|
||||
spatial masks, or no guide tokens).
|
||||
"""
|
||||
if isinstance(x, list):
|
||||
# AV model: x = [vx, ax]; use vx for token count and device
|
||||
total_tokens = x[0].shape[1]
|
||||
device = x[0].device
|
||||
dtype = x[0].dtype
|
||||
else:
|
||||
total_tokens = x.shape[1]
|
||||
device = x.device
|
||||
dtype = x.dtype
|
||||
|
||||
num_guide_tokens = merged_args.get("num_guide_tokens", 0)
|
||||
if num_guide_tokens == 0:
|
||||
return None
|
||||
|
||||
resolved_entries = merged_args.get("resolved_guide_entries", None)
|
||||
if not resolved_entries:
|
||||
return None
|
||||
|
||||
# Check if any attenuation is actually needed
|
||||
needs_attenuation = any(
|
||||
e["strength"] < 1.0 or e.get("pixel_mask") is not None
|
||||
for e in resolved_entries
|
||||
)
|
||||
if not needs_attenuation:
|
||||
return None
|
||||
|
||||
# Build per-guide-token weights for all tracked guide tokens.
|
||||
# Guides are appended in order at the end of the sequence.
|
||||
guide_start = total_tokens - num_guide_tokens
|
||||
all_weights = []
|
||||
total_tracked = 0
|
||||
|
||||
for entry in resolved_entries:
|
||||
surviving = entry["surviving_count"]
|
||||
if surviving == 0:
|
||||
continue
|
||||
|
||||
strength = entry["strength"]
|
||||
pixel_mask = entry.get("pixel_mask")
|
||||
latent_shape = entry.get("latent_shape")
|
||||
|
||||
if pixel_mask is not None and latent_shape is not None:
|
||||
f_lat, h_lat, w_lat = latent_shape
|
||||
per_token = self._downsample_mask_to_latent(
|
||||
pixel_mask.to(device=device, dtype=dtype),
|
||||
f_lat, h_lat, w_lat,
|
||||
)
|
||||
# per_token shape: (B, f_lat*h_lat*w_lat).
|
||||
# Collapse batch dim — the mask is assumed identical across the
|
||||
# batch; validate and take the first element to get (1, tokens).
|
||||
if per_token.shape[0] > 1:
|
||||
ref = per_token[0]
|
||||
for bi in range(1, per_token.shape[0]):
|
||||
if not torch.equal(ref, per_token[bi]):
|
||||
logger.warning(
|
||||
"pixel_mask differs across batch elements; "
|
||||
"using first element only."
|
||||
)
|
||||
break
|
||||
per_token = per_token[:1]
|
||||
# `surviving` is the post-grid_mask token count.
|
||||
# Clamp to surviving to handle any mismatch safely.
|
||||
n_weights = min(per_token.shape[1], surviving)
|
||||
weights = per_token[:, :n_weights] * strength # (1, n_weights)
|
||||
else:
|
||||
weights = torch.full(
|
||||
(1, surviving), strength, device=device, dtype=dtype
|
||||
)
|
||||
|
||||
all_weights.append(weights)
|
||||
total_tracked += weights.shape[1]
|
||||
|
||||
if not all_weights:
|
||||
return None
|
||||
|
||||
# Concatenate per-token weights for all tracked guides
|
||||
tracked_weights = torch.cat(all_weights, dim=1) # (1, total_tracked)
|
||||
|
||||
# Check if any weight is actually < 1.0 (otherwise no attenuation needed)
|
||||
if (tracked_weights >= 1.0).all():
|
||||
return None
|
||||
|
||||
# Build the mask: guide tokens are at the end of the sequence.
|
||||
# Tracked guides come first (in order), untracked follow.
|
||||
return self._build_self_attention_mask(
|
||||
total_tokens, num_guide_tokens, total_tracked,
|
||||
tracked_weights, guide_start, device, dtype,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _downsample_mask_to_latent(mask, f_lat, h_lat, w_lat):
|
||||
"""Downsample a pixel-space mask to per-token latent weights.
|
||||
|
||||
Args:
|
||||
mask: (B, 1, F_pix, H_pix, W_pix) pixel-space mask with values in [0, 1].
|
||||
f_lat: Number of latent frames (pre-dilation original count).
|
||||
h_lat: Latent height (pre-dilation original height).
|
||||
w_lat: Latent width (pre-dilation original width).
|
||||
|
||||
Returns:
|
||||
(B, F_lat * H_lat * W_lat) flattened per-token weights.
|
||||
"""
|
||||
b = mask.shape[0]
|
||||
f_pix = mask.shape[2]
|
||||
|
||||
# Spatial downsampling: area interpolation per frame
|
||||
spatial_down = torch.nn.functional.interpolate(
|
||||
rearrange(mask, "b 1 f h w -> (b f) 1 h w"),
|
||||
size=(h_lat, w_lat),
|
||||
mode="area",
|
||||
)
|
||||
spatial_down = rearrange(spatial_down, "(b f) 1 h w -> b 1 f h w", b=b)
|
||||
|
||||
# Temporal downsampling: first pixel frame maps to first latent frame,
|
||||
# remaining pixel frames are averaged in groups for causal temporal structure.
|
||||
first_frame = spatial_down[:, :, :1, :, :]
|
||||
if f_pix > 1 and f_lat > 1:
|
||||
remaining_pix = f_pix - 1
|
||||
remaining_lat = f_lat - 1
|
||||
t = remaining_pix // remaining_lat
|
||||
if t < 1:
|
||||
# Fewer pixel frames than latent frames — upsample by repeating
|
||||
# the available pixel frames via nearest interpolation.
|
||||
rest_flat = rearrange(
|
||||
spatial_down[:, :, 1:, :, :],
|
||||
"b 1 f h w -> (b h w) 1 f",
|
||||
)
|
||||
rest_up = torch.nn.functional.interpolate(
|
||||
rest_flat, size=remaining_lat, mode="nearest",
|
||||
)
|
||||
rest = rearrange(
|
||||
rest_up, "(b h w) 1 f -> b 1 f h w",
|
||||
b=b, h=h_lat, w=w_lat,
|
||||
)
|
||||
else:
|
||||
# Trim trailing pixel frames that don't fill a complete group
|
||||
usable = remaining_lat * t
|
||||
rest = rearrange(
|
||||
spatial_down[:, :, 1:1 + usable, :, :],
|
||||
"b 1 (f t) h w -> b 1 f t h w",
|
||||
t=t,
|
||||
)
|
||||
rest = rest.mean(dim=3)
|
||||
latent_mask = torch.cat([first_frame, rest], dim=2)
|
||||
elif f_lat > 1:
|
||||
# Single pixel frame but multiple latent frames — repeat the
|
||||
# single frame across all latent frames.
|
||||
latent_mask = first_frame.expand(-1, -1, f_lat, -1, -1)
|
||||
else:
|
||||
latent_mask = first_frame
|
||||
|
||||
return rearrange(latent_mask, "b 1 f h w -> b (f h w)")
|
||||
|
||||
@staticmethod
|
||||
def _build_self_attention_mask(total_tokens, num_guide_tokens, tracked_count,
|
||||
tracked_weights, guide_start, device, dtype):
|
||||
"""Build a log-space additive self-attention bias mask.
|
||||
|
||||
Attenuates attention between noisy tokens and tracked guide tokens.
|
||||
Untracked guide tokens (at the end of the guide portion) keep full attention.
|
||||
|
||||
Args:
|
||||
total_tokens: Total sequence length.
|
||||
num_guide_tokens: Total guide tokens (all guides) at end of sequence.
|
||||
tracked_count: Number of tracked guide tokens (first in the guide portion).
|
||||
tracked_weights: (1, tracked_count) tensor, values in [0, 1].
|
||||
guide_start: Index where guide tokens begin in the sequence.
|
||||
device: Target device.
|
||||
dtype: Target dtype.
|
||||
|
||||
Returns:
|
||||
(1, 1, total_tokens, total_tokens) additive bias mask.
|
||||
0.0 = full attention, negative = attenuated, finfo.min = effectively fully masked.
|
||||
"""
|
||||
finfo = torch.finfo(dtype)
|
||||
mask = torch.zeros((1, 1, total_tokens, total_tokens), device=device, dtype=dtype)
|
||||
tracked_end = guide_start + tracked_count
|
||||
|
||||
# Convert weights to log-space bias
|
||||
w = tracked_weights.to(device=device, dtype=dtype) # (1, tracked_count)
|
||||
log_w = torch.full_like(w, finfo.min)
|
||||
positive_mask = w > 0
|
||||
if positive_mask.any():
|
||||
log_w[positive_mask] = torch.log(w[positive_mask].clamp(min=finfo.tiny))
|
||||
|
||||
# noisy → tracked guides: each noisy row gets the same per-guide weight
|
||||
mask[:, :, :guide_start, guide_start:tracked_end] = log_w.view(1, 1, 1, -1)
|
||||
# tracked guides → noisy: each guide row broadcasts its weight across noisy cols
|
||||
mask[:, :, guide_start:tracked_end, :guide_start] = log_w.view(1, 1, -1, 1)
|
||||
|
||||
return mask
|
||||
|
||||
def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, transformer_options={}, self_attention_mask=None, **kwargs):
|
||||
"""Process transformer blocks for LTXV."""
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
prompt_timestep = kwargs.get("prompt_timestep", None)
|
||||
|
||||
for i, block in enumerate(self.transformer_blocks):
|
||||
if ("double_block", i) in blocks_replace:
|
||||
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"])
|
||||
out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"], self_attention_mask=args.get("self_attention_mask"), prompt_timestep=args.get("prompt_timestep"))
|
||||
return out
|
||||
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe, "transformer_options": transformer_options}, {"original_block": block_wrap})
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe, "transformer_options": transformer_options, "self_attention_mask": self_attention_mask, "prompt_timestep": prompt_timestep}, {"original_block": block_wrap})
|
||||
x = out["img"]
|
||||
else:
|
||||
x = block(
|
||||
@@ -919,6 +1297,8 @@ class LTXVModel(LTXBaseModel):
|
||||
timestep=timestep,
|
||||
pe=pe,
|
||||
transformer_options=transformer_options,
|
||||
self_attention_mask=self_attention_mask,
|
||||
prompt_timestep=prompt_timestep,
|
||||
)
|
||||
|
||||
return x
|
||||
|
||||
@@ -13,7 +13,7 @@ from comfy.ldm.lightricks.vae.causal_audio_autoencoder import (
|
||||
CausalityAxis,
|
||||
CausalAudioAutoencoder,
|
||||
)
|
||||
from comfy.ldm.lightricks.vocoders.vocoder import Vocoder
|
||||
from comfy.ldm.lightricks.vocoders.vocoder import Vocoder, VocoderWithBWE
|
||||
|
||||
LATENT_DOWNSAMPLE_FACTOR = 4
|
||||
|
||||
@@ -141,7 +141,10 @@ class AudioVAE(torch.nn.Module):
|
||||
vocoder_sd = utils.state_dict_prefix_replace(state_dict, {"vocoder.": ""}, filter_keys=True)
|
||||
|
||||
self.autoencoder = CausalAudioAutoencoder(config=component_config.autoencoder)
|
||||
self.vocoder = Vocoder(config=component_config.vocoder)
|
||||
if "bwe" in component_config.vocoder:
|
||||
self.vocoder = VocoderWithBWE(config=component_config.vocoder)
|
||||
else:
|
||||
self.vocoder = Vocoder(config=component_config.vocoder)
|
||||
|
||||
self.autoencoder.load_state_dict(vae_sd, strict=False)
|
||||
self.vocoder.load_state_dict(vocoder_sd, strict=False)
|
||||
|
||||
@@ -822,26 +822,23 @@ class CausalAudioAutoencoder(nn.Module):
|
||||
super().__init__()
|
||||
|
||||
if config is None:
|
||||
config = self._guess_config()
|
||||
config = self.get_default_config()
|
||||
|
||||
# Extract encoder and decoder configs from the new format
|
||||
model_config = config.get("model", {}).get("params", {})
|
||||
variables_config = config.get("variables", {})
|
||||
|
||||
self.sampling_rate = variables_config.get(
|
||||
"sampling_rate",
|
||||
model_config.get("sampling_rate", config.get("sampling_rate", 16000)),
|
||||
self.sampling_rate = model_config.get(
|
||||
"sampling_rate", config.get("sampling_rate", 16000)
|
||||
)
|
||||
encoder_config = model_config.get("encoder", model_config.get("ddconfig", {}))
|
||||
decoder_config = model_config.get("decoder", encoder_config)
|
||||
|
||||
# Load mel spectrogram parameters
|
||||
self.mel_bins = encoder_config.get("mel_bins", 64)
|
||||
self.mel_hop_length = model_config.get("preprocessing", {}).get("stft", {}).get("hop_length", 160)
|
||||
self.n_fft = model_config.get("preprocessing", {}).get("stft", {}).get("filter_length", 1024)
|
||||
self.mel_hop_length = config.get("preprocessing", {}).get("stft", {}).get("hop_length", 160)
|
||||
self.n_fft = config.get("preprocessing", {}).get("stft", {}).get("filter_length", 1024)
|
||||
|
||||
# Store causality configuration at VAE level (not just in encoder internals)
|
||||
causality_axis_value = encoder_config.get("causality_axis", CausalityAxis.WIDTH.value)
|
||||
causality_axis_value = encoder_config.get("causality_axis", CausalityAxis.HEIGHT.value)
|
||||
self.causality_axis = CausalityAxis.str_to_enum(causality_axis_value)
|
||||
self.is_causal = self.causality_axis == CausalityAxis.HEIGHT
|
||||
|
||||
@@ -850,44 +847,38 @@ class CausalAudioAutoencoder(nn.Module):
|
||||
|
||||
self.per_channel_statistics = processor()
|
||||
|
||||
def _guess_config(self):
|
||||
encoder_config = {
|
||||
# Required parameters - based on ltx-video-av-1679000 model metadata
|
||||
"ch": 128,
|
||||
"out_ch": 8,
|
||||
"ch_mult": [1, 2, 4], # Based on metadata: [1, 2, 4] not [1, 2, 4, 8]
|
||||
"num_res_blocks": 2,
|
||||
"attn_resolutions": [], # Based on metadata: empty list, no attention
|
||||
"dropout": 0.0,
|
||||
"resamp_with_conv": True,
|
||||
"in_channels": 2, # stereo
|
||||
"resolution": 256,
|
||||
"z_channels": 8,
|
||||
def get_default_config(self):
|
||||
ddconfig = {
|
||||
"double_z": True,
|
||||
"attn_type": "vanilla",
|
||||
"mid_block_add_attention": False, # Based on metadata: false
|
||||
"mel_bins": 64,
|
||||
"z_channels": 8,
|
||||
"resolution": 256,
|
||||
"downsample_time": False,
|
||||
"in_channels": 2,
|
||||
"out_ch": 2,
|
||||
"ch": 128,
|
||||
"ch_mult": [1, 2, 4],
|
||||
"num_res_blocks": 2,
|
||||
"attn_resolutions": [],
|
||||
"dropout": 0.0,
|
||||
"mid_block_add_attention": False,
|
||||
"norm_type": "pixel",
|
||||
"causality_axis": "height", # Based on metadata
|
||||
"mel_bins": 64, # Based on metadata: mel_bins = 64
|
||||
}
|
||||
|
||||
decoder_config = {
|
||||
# Inherits encoder config, can override specific params
|
||||
**encoder_config,
|
||||
"out_ch": 2, # Stereo audio output (2 channels)
|
||||
"give_pre_end": False,
|
||||
"tanh_out": False,
|
||||
"causality_axis": "height",
|
||||
}
|
||||
|
||||
config = {
|
||||
"_class_name": "CausalAudioAutoencoder",
|
||||
"sampling_rate": 16000,
|
||||
"model": {
|
||||
"params": {
|
||||
"encoder": encoder_config,
|
||||
"decoder": decoder_config,
|
||||
"ddconfig": ddconfig,
|
||||
"sampling_rate": 16000,
|
||||
}
|
||||
},
|
||||
"preprocessing": {
|
||||
"stft": {
|
||||
"filter_length": 1024,
|
||||
"hop_length": 160,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
return config
|
||||
|
||||
@@ -15,6 +15,9 @@ from comfy.ldm.modules.diffusionmodules.model import torch_cat_if_needed
|
||||
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
def in_meta_context():
|
||||
return torch.device("meta") == torch.empty(0).device
|
||||
|
||||
def mark_conv3d_ended(module):
|
||||
tid = threading.get_ident()
|
||||
for _, m in module.named_modules():
|
||||
@@ -350,6 +353,10 @@ class Decoder(nn.Module):
|
||||
output_channel = output_channel * block_params.get("multiplier", 2)
|
||||
if block_name == "compress_all":
|
||||
output_channel = output_channel * block_params.get("multiplier", 1)
|
||||
if block_name == "compress_space":
|
||||
output_channel = output_channel * block_params.get("multiplier", 1)
|
||||
if block_name == "compress_time":
|
||||
output_channel = output_channel * block_params.get("multiplier", 1)
|
||||
|
||||
self.conv_in = make_conv_nd(
|
||||
dims,
|
||||
@@ -395,17 +402,21 @@ class Decoder(nn.Module):
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "compress_time":
|
||||
output_channel = output_channel // block_params.get("multiplier", 1)
|
||||
block = DepthToSpaceUpsample(
|
||||
dims=dims,
|
||||
in_channels=input_channel,
|
||||
stride=(2, 1, 1),
|
||||
out_channels_reduction_factor=block_params.get("multiplier", 1),
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "compress_space":
|
||||
output_channel = output_channel // block_params.get("multiplier", 1)
|
||||
block = DepthToSpaceUpsample(
|
||||
dims=dims,
|
||||
in_channels=input_channel,
|
||||
stride=(1, 2, 2),
|
||||
out_channels_reduction_factor=block_params.get("multiplier", 1),
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "compress_all":
|
||||
@@ -455,6 +466,15 @@ class Decoder(nn.Module):
|
||||
output_channel * 2, 0, operations=ops,
|
||||
)
|
||||
self.last_scale_shift_table = nn.Parameter(torch.empty(2, output_channel))
|
||||
else:
|
||||
self.register_buffer(
|
||||
"last_scale_shift_table",
|
||||
torch.tensor(
|
||||
[0.0, 0.0],
|
||||
device="cpu" if in_meta_context() else None
|
||||
).unsqueeze(1).expand(2, output_channel),
|
||||
persistent=False,
|
||||
)
|
||||
|
||||
|
||||
# def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor:
|
||||
@@ -883,6 +903,15 @@ class ResnetBlock3D(nn.Module):
|
||||
self.scale_shift_table = nn.Parameter(
|
||||
torch.randn(4, in_channels) / in_channels**0.5
|
||||
)
|
||||
else:
|
||||
self.register_buffer(
|
||||
"scale_shift_table",
|
||||
torch.tensor(
|
||||
[0.0, 0.0, 0.0, 0.0],
|
||||
device="cpu" if in_meta_context() else None
|
||||
).unsqueeze(1).expand(4, in_channels),
|
||||
persistent=False,
|
||||
)
|
||||
|
||||
self.temporal_cache_state={}
|
||||
|
||||
@@ -1012,9 +1041,6 @@ class processor(nn.Module):
|
||||
super().__init__()
|
||||
self.register_buffer("std-of-means", torch.empty(128))
|
||||
self.register_buffer("mean-of-means", torch.empty(128))
|
||||
self.register_buffer("mean-of-stds", torch.empty(128))
|
||||
self.register_buffer("mean-of-stds_over_std-of-means", torch.empty(128))
|
||||
self.register_buffer("channel", torch.empty(128))
|
||||
|
||||
def un_normalize(self, x):
|
||||
return (x * self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)) + self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)
|
||||
@@ -1027,9 +1053,12 @@ class VideoVAE(nn.Module):
|
||||
super().__init__()
|
||||
|
||||
if config is None:
|
||||
config = self.guess_config(version)
|
||||
config = self.get_default_config(version)
|
||||
|
||||
self.config = config
|
||||
self.timestep_conditioning = config.get("timestep_conditioning", False)
|
||||
self.decode_noise_scale = config.get("decode_noise_scale", 0.025)
|
||||
self.decode_timestep = config.get("decode_timestep", 0.05)
|
||||
double_z = config.get("double_z", True)
|
||||
latent_log_var = config.get(
|
||||
"latent_log_var", "per_channel" if double_z else "none"
|
||||
@@ -1044,6 +1073,7 @@ class VideoVAE(nn.Module):
|
||||
latent_log_var=latent_log_var,
|
||||
norm_layer=config.get("norm_layer", "group_norm"),
|
||||
spatial_padding_mode=config.get("spatial_padding_mode", "zeros"),
|
||||
base_channels=config.get("encoder_base_channels", 128),
|
||||
)
|
||||
|
||||
self.decoder = Decoder(
|
||||
@@ -1051,6 +1081,7 @@ class VideoVAE(nn.Module):
|
||||
in_channels=config["latent_channels"],
|
||||
out_channels=config.get("out_channels", 3),
|
||||
blocks=config.get("decoder_blocks", config.get("decoder_blocks", config.get("blocks"))),
|
||||
base_channels=config.get("decoder_base_channels", 128),
|
||||
patch_size=config.get("patch_size", 1),
|
||||
norm_layer=config.get("norm_layer", "group_norm"),
|
||||
causal=config.get("causal_decoder", False),
|
||||
@@ -1060,7 +1091,7 @@ class VideoVAE(nn.Module):
|
||||
|
||||
self.per_channel_statistics = processor()
|
||||
|
||||
def guess_config(self, version):
|
||||
def get_default_config(self, version):
|
||||
if version == 0:
|
||||
config = {
|
||||
"_class_name": "CausalVideoAutoencoder",
|
||||
@@ -1167,8 +1198,7 @@ class VideoVAE(nn.Module):
|
||||
means, logvar = torch.chunk(self.encoder(x), 2, dim=1)
|
||||
return self.per_channel_statistics.normalize(means)
|
||||
|
||||
def decode(self, x, timestep=0.05, noise_scale=0.025):
|
||||
def decode(self, x):
|
||||
if self.timestep_conditioning: #TODO: seed
|
||||
x = torch.randn_like(x) * noise_scale + (1.0 - noise_scale) * x
|
||||
return self.decoder(self.per_channel_statistics.un_normalize(x), timestep=timestep)
|
||||
|
||||
x = torch.randn_like(x) * self.decode_noise_scale + (1.0 - self.decode_noise_scale) * x
|
||||
return self.decoder(self.per_channel_statistics.un_normalize(x), timestep=self.decode_timestep)
|
||||
|
||||
@@ -2,7 +2,9 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.nn as nn
|
||||
import comfy.ops
|
||||
import comfy.model_management
|
||||
import numpy as np
|
||||
import math
|
||||
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
@@ -12,6 +14,307 @@ def get_padding(kernel_size, dilation=1):
|
||||
return int((kernel_size * dilation - dilation) / 2)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Anti-aliased resampling helpers (kaiser-sinc filters) for BigVGAN v2
|
||||
# Adopted from https://github.com/NVIDIA/BigVGAN
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _sinc(x: torch.Tensor):
|
||||
return torch.where(
|
||||
x == 0,
|
||||
torch.tensor(1.0, device=x.device, dtype=x.dtype),
|
||||
torch.sin(math.pi * x) / math.pi / x,
|
||||
)
|
||||
|
||||
|
||||
def kaiser_sinc_filter1d(cutoff, half_width, kernel_size):
|
||||
even = kernel_size % 2 == 0
|
||||
half_size = kernel_size // 2
|
||||
delta_f = 4 * half_width
|
||||
A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
|
||||
if A > 50.0:
|
||||
beta = 0.1102 * (A - 8.7)
|
||||
elif A >= 21.0:
|
||||
beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0)
|
||||
else:
|
||||
beta = 0.0
|
||||
window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
|
||||
if even:
|
||||
time = torch.arange(-half_size, half_size) + 0.5
|
||||
else:
|
||||
time = torch.arange(kernel_size) - half_size
|
||||
if cutoff == 0:
|
||||
filter_ = torch.zeros_like(time)
|
||||
else:
|
||||
filter_ = 2 * cutoff * window * _sinc(2 * cutoff * time)
|
||||
filter_ /= filter_.sum()
|
||||
filter = filter_.view(1, 1, kernel_size)
|
||||
return filter
|
||||
|
||||
|
||||
class LowPassFilter1d(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
cutoff=0.5,
|
||||
half_width=0.6,
|
||||
stride=1,
|
||||
padding=True,
|
||||
padding_mode="replicate",
|
||||
kernel_size=12,
|
||||
):
|
||||
super().__init__()
|
||||
if cutoff < -0.0:
|
||||
raise ValueError("Minimum cutoff must be larger than zero.")
|
||||
if cutoff > 0.5:
|
||||
raise ValueError("A cutoff above 0.5 does not make sense.")
|
||||
self.kernel_size = kernel_size
|
||||
self.even = kernel_size % 2 == 0
|
||||
self.pad_left = kernel_size // 2 - int(self.even)
|
||||
self.pad_right = kernel_size // 2
|
||||
self.stride = stride
|
||||
self.padding = padding
|
||||
self.padding_mode = padding_mode
|
||||
filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
|
||||
self.register_buffer("filter", filter)
|
||||
|
||||
def forward(self, x):
|
||||
_, C, _ = x.shape
|
||||
if self.padding:
|
||||
x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
|
||||
return F.conv1d(x, comfy.model_management.cast_to(self.filter.expand(C, -1, -1), dtype=x.dtype, device=x.device), stride=self.stride, groups=C)
|
||||
|
||||
|
||||
class UpSample1d(nn.Module):
|
||||
def __init__(self, ratio=2, kernel_size=None, persistent=True, window_type="kaiser"):
|
||||
super().__init__()
|
||||
self.ratio = ratio
|
||||
self.stride = ratio
|
||||
|
||||
if window_type == "hann":
|
||||
# Hann-windowed sinc filter — identical to torchaudio.functional.resample
|
||||
# with its default parameters (rolloff=0.99, lowpass_filter_width=6).
|
||||
# Uses replicate boundary padding, matching the reference resampler exactly.
|
||||
rolloff = 0.99
|
||||
lowpass_filter_width = 6
|
||||
width = math.ceil(lowpass_filter_width / rolloff)
|
||||
self.kernel_size = 2 * width * ratio + 1
|
||||
self.pad = width
|
||||
self.pad_left = 2 * width * ratio
|
||||
self.pad_right = self.kernel_size - ratio
|
||||
t = (torch.arange(self.kernel_size) / ratio - width) * rolloff
|
||||
t_clamped = t.clamp(-lowpass_filter_width, lowpass_filter_width)
|
||||
window = torch.cos(t_clamped * math.pi / lowpass_filter_width / 2) ** 2
|
||||
filter = (torch.sinc(t) * window * rolloff / ratio).view(1, 1, -1)
|
||||
else:
|
||||
# Kaiser-windowed sinc filter (BigVGAN default).
|
||||
self.kernel_size = (
|
||||
int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
|
||||
)
|
||||
self.pad = self.kernel_size // ratio - 1
|
||||
self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
|
||||
self.pad_right = (
|
||||
self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
|
||||
)
|
||||
filter = kaiser_sinc_filter1d(
|
||||
cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size
|
||||
)
|
||||
|
||||
self.register_buffer("filter", filter, persistent=persistent)
|
||||
|
||||
def forward(self, x):
|
||||
_, C, _ = x.shape
|
||||
x = F.pad(x, (self.pad, self.pad), mode="replicate")
|
||||
x = self.ratio * F.conv_transpose1d(
|
||||
x, comfy.model_management.cast_to(self.filter.expand(C, -1, -1), dtype=x.dtype, device=x.device), stride=self.stride, groups=C
|
||||
)
|
||||
x = x[..., self.pad_left : -self.pad_right]
|
||||
return x
|
||||
|
||||
|
||||
class DownSample1d(nn.Module):
|
||||
def __init__(self, ratio=2, kernel_size=None):
|
||||
super().__init__()
|
||||
self.ratio = ratio
|
||||
self.kernel_size = (
|
||||
int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
|
||||
)
|
||||
self.lowpass = LowPassFilter1d(
|
||||
cutoff=0.5 / ratio,
|
||||
half_width=0.6 / ratio,
|
||||
stride=ratio,
|
||||
kernel_size=self.kernel_size,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.lowpass(x)
|
||||
|
||||
|
||||
class Activation1d(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
activation,
|
||||
up_ratio=2,
|
||||
down_ratio=2,
|
||||
up_kernel_size=12,
|
||||
down_kernel_size=12,
|
||||
):
|
||||
super().__init__()
|
||||
self.act = activation
|
||||
self.upsample = UpSample1d(up_ratio, up_kernel_size)
|
||||
self.downsample = DownSample1d(down_ratio, down_kernel_size)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.upsample(x)
|
||||
x = self.act(x)
|
||||
x = self.downsample(x)
|
||||
return x
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# BigVGAN v2 activations (Snake / SnakeBeta)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class Snake(nn.Module):
|
||||
def __init__(
|
||||
self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True
|
||||
):
|
||||
super().__init__()
|
||||
self.alpha_logscale = alpha_logscale
|
||||
self.alpha = nn.Parameter(
|
||||
torch.zeros(in_features)
|
||||
if alpha_logscale
|
||||
else torch.ones(in_features) * alpha
|
||||
)
|
||||
self.alpha.requires_grad = alpha_trainable
|
||||
self.eps = 1e-9
|
||||
|
||||
def forward(self, x):
|
||||
a = comfy.model_management.cast_to(self.alpha.unsqueeze(0).unsqueeze(-1), dtype=x.dtype, device=x.device)
|
||||
if self.alpha_logscale:
|
||||
a = torch.exp(a)
|
||||
return x + (1.0 / (a + self.eps)) * torch.sin(x * a).pow(2)
|
||||
|
||||
|
||||
class SnakeBeta(nn.Module):
|
||||
def __init__(
|
||||
self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True
|
||||
):
|
||||
super().__init__()
|
||||
self.alpha_logscale = alpha_logscale
|
||||
self.alpha = nn.Parameter(
|
||||
torch.zeros(in_features)
|
||||
if alpha_logscale
|
||||
else torch.ones(in_features) * alpha
|
||||
)
|
||||
self.alpha.requires_grad = alpha_trainable
|
||||
self.beta = nn.Parameter(
|
||||
torch.zeros(in_features)
|
||||
if alpha_logscale
|
||||
else torch.ones(in_features) * alpha
|
||||
)
|
||||
self.beta.requires_grad = alpha_trainable
|
||||
self.eps = 1e-9
|
||||
|
||||
def forward(self, x):
|
||||
a = comfy.model_management.cast_to(self.alpha.unsqueeze(0).unsqueeze(-1), dtype=x.dtype, device=x.device)
|
||||
b = comfy.model_management.cast_to(self.beta.unsqueeze(0).unsqueeze(-1), dtype=x.dtype, device=x.device)
|
||||
if self.alpha_logscale:
|
||||
a = torch.exp(a)
|
||||
b = torch.exp(b)
|
||||
return x + (1.0 / (b + self.eps)) * torch.sin(x * a).pow(2)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# BigVGAN v2 AMPBlock (Anti-aliased Multi-Periodicity)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class AMPBlock1(torch.nn.Module):
|
||||
def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), activation="snake"):
|
||||
super().__init__()
|
||||
act_cls = SnakeBeta if activation == "snakebeta" else Snake
|
||||
self.convs1 = nn.ModuleList(
|
||||
[
|
||||
ops.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[0],
|
||||
padding=get_padding(kernel_size, dilation[0]),
|
||||
),
|
||||
ops.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[1],
|
||||
padding=get_padding(kernel_size, dilation[1]),
|
||||
),
|
||||
ops.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[2],
|
||||
padding=get_padding(kernel_size, dilation[2]),
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
self.convs2 = nn.ModuleList(
|
||||
[
|
||||
ops.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=1,
|
||||
padding=get_padding(kernel_size, 1),
|
||||
),
|
||||
ops.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=1,
|
||||
padding=get_padding(kernel_size, 1),
|
||||
),
|
||||
ops.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=1,
|
||||
padding=get_padding(kernel_size, 1),
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
self.acts1 = nn.ModuleList(
|
||||
[Activation1d(act_cls(channels)) for _ in range(len(self.convs1))]
|
||||
)
|
||||
self.acts2 = nn.ModuleList(
|
||||
[Activation1d(act_cls(channels)) for _ in range(len(self.convs2))]
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
for c1, c2, a1, a2 in zip(self.convs1, self.convs2, self.acts1, self.acts2):
|
||||
xt = a1(x)
|
||||
xt = c1(xt)
|
||||
xt = a2(xt)
|
||||
xt = c2(xt)
|
||||
x = x + xt
|
||||
return x
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# HiFi-GAN residual blocks
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ResBlock1(torch.nn.Module):
|
||||
def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
|
||||
super(ResBlock1, self).__init__()
|
||||
@@ -119,6 +422,7 @@ class Vocoder(torch.nn.Module):
|
||||
"""
|
||||
Vocoder model for synthesizing audio from spectrograms, based on: https://github.com/jik876/hifi-gan.
|
||||
|
||||
Supports both HiFi-GAN (resblock "1"/"2") and BigVGAN v2 (resblock "AMP1").
|
||||
"""
|
||||
|
||||
def __init__(self, config=None):
|
||||
@@ -128,19 +432,39 @@ class Vocoder(torch.nn.Module):
|
||||
config = self.get_default_config()
|
||||
|
||||
resblock_kernel_sizes = config.get("resblock_kernel_sizes", [3, 7, 11])
|
||||
upsample_rates = config.get("upsample_rates", [6, 5, 2, 2, 2])
|
||||
upsample_kernel_sizes = config.get("upsample_kernel_sizes", [16, 15, 8, 4, 4])
|
||||
upsample_rates = config.get("upsample_rates", [5, 4, 2, 2, 2])
|
||||
upsample_kernel_sizes = config.get("upsample_kernel_sizes", [16, 16, 8, 4, 4])
|
||||
resblock_dilation_sizes = config.get("resblock_dilation_sizes", [[1, 3, 5], [1, 3, 5], [1, 3, 5]])
|
||||
upsample_initial_channel = config.get("upsample_initial_channel", 1024)
|
||||
stereo = config.get("stereo", True)
|
||||
resblock = config.get("resblock", "1")
|
||||
activation = config.get("activation", "snake")
|
||||
use_bias_at_final = config.get("use_bias_at_final", True)
|
||||
|
||||
|
||||
# "output_sample_rate" is not present in recent checkpoint configs.
|
||||
# When absent (None), AudioVAE.output_sample_rate computes it as:
|
||||
# sample_rate * vocoder.upsample_factor / mel_hop_length
|
||||
# where upsample_factor = product of all upsample stride lengths,
|
||||
# and mel_hop_length is loaded from the autoencoder config at
|
||||
# preprocessing.stft.hop_length (see CausalAudioAutoencoder).
|
||||
self.output_sample_rate = config.get("output_sample_rate")
|
||||
self.resblock = config.get("resblock", "1")
|
||||
self.use_tanh_at_final = config.get("use_tanh_at_final", True)
|
||||
self.apply_final_activation = config.get("apply_final_activation", True)
|
||||
self.num_kernels = len(resblock_kernel_sizes)
|
||||
self.num_upsamples = len(upsample_rates)
|
||||
|
||||
in_channels = 128 if stereo else 64
|
||||
self.conv_pre = ops.Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3)
|
||||
resblock_class = ResBlock1 if resblock == "1" else ResBlock2
|
||||
|
||||
if self.resblock == "1":
|
||||
resblock_cls = ResBlock1
|
||||
elif self.resblock == "2":
|
||||
resblock_cls = ResBlock2
|
||||
elif self.resblock == "AMP1":
|
||||
resblock_cls = AMPBlock1
|
||||
else:
|
||||
raise ValueError(f"Unknown resblock type: {self.resblock}")
|
||||
|
||||
self.ups = nn.ModuleList()
|
||||
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
|
||||
@@ -157,25 +481,40 @@ class Vocoder(torch.nn.Module):
|
||||
self.resblocks = nn.ModuleList()
|
||||
for i in range(len(self.ups)):
|
||||
ch = upsample_initial_channel // (2 ** (i + 1))
|
||||
for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
|
||||
self.resblocks.append(resblock_class(ch, k, d))
|
||||
for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes):
|
||||
if self.resblock == "AMP1":
|
||||
self.resblocks.append(resblock_cls(ch, k, d, activation=activation))
|
||||
else:
|
||||
self.resblocks.append(resblock_cls(ch, k, d))
|
||||
|
||||
out_channels = 2 if stereo else 1
|
||||
self.conv_post = ops.Conv1d(ch, out_channels, 7, 1, padding=3)
|
||||
if self.resblock == "AMP1":
|
||||
act_cls = SnakeBeta if activation == "snakebeta" else Snake
|
||||
self.act_post = Activation1d(act_cls(ch))
|
||||
else:
|
||||
self.act_post = nn.LeakyReLU()
|
||||
|
||||
self.conv_post = ops.Conv1d(
|
||||
ch, out_channels, 7, 1, padding=3, bias=use_bias_at_final
|
||||
)
|
||||
|
||||
self.upsample_factor = np.prod([self.ups[i].stride[0] for i in range(len(self.ups))])
|
||||
|
||||
|
||||
def get_default_config(self):
|
||||
"""Generate default configuration for the vocoder."""
|
||||
|
||||
config = {
|
||||
"resblock_kernel_sizes": [3, 7, 11],
|
||||
"upsample_rates": [6, 5, 2, 2, 2],
|
||||
"upsample_kernel_sizes": [16, 15, 8, 4, 4],
|
||||
"upsample_rates": [5, 4, 2, 2, 2],
|
||||
"upsample_kernel_sizes": [16, 16, 8, 4, 4],
|
||||
"resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
||||
"upsample_initial_channel": 1024,
|
||||
"stereo": True,
|
||||
"resblock": "1",
|
||||
"activation": "snake",
|
||||
"use_bias_at_final": True,
|
||||
"use_tanh_at_final": True,
|
||||
}
|
||||
|
||||
return config
|
||||
@@ -196,8 +535,10 @@ class Vocoder(torch.nn.Module):
|
||||
assert x.shape[1] == 2, "Input must have 2 channels for stereo"
|
||||
x = torch.cat((x[:, 0, :, :], x[:, 1, :, :]), dim=1)
|
||||
x = self.conv_pre(x)
|
||||
|
||||
for i in range(self.num_upsamples):
|
||||
x = F.leaky_relu(x, LRELU_SLOPE)
|
||||
if self.resblock != "AMP1":
|
||||
x = F.leaky_relu(x, LRELU_SLOPE)
|
||||
x = self.ups[i](x)
|
||||
xs = None
|
||||
for j in range(self.num_kernels):
|
||||
@@ -206,8 +547,167 @@ class Vocoder(torch.nn.Module):
|
||||
else:
|
||||
xs += self.resblocks[i * self.num_kernels + j](x)
|
||||
x = xs / self.num_kernels
|
||||
x = F.leaky_relu(x)
|
||||
|
||||
x = self.act_post(x)
|
||||
x = self.conv_post(x)
|
||||
x = torch.tanh(x)
|
||||
|
||||
if self.apply_final_activation:
|
||||
if self.use_tanh_at_final:
|
||||
x = torch.tanh(x)
|
||||
else:
|
||||
x = torch.clamp(x, -1, 1)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class _STFTFn(nn.Module):
|
||||
"""Implements STFT as a convolution with precomputed DFT × Hann-window bases.
|
||||
|
||||
The DFT basis rows (real and imaginary parts interleaved) multiplied by the causal
|
||||
Hann window are stored as buffers and loaded from the checkpoint. Using the exact
|
||||
bfloat16 bases from training ensures the mel values fed to the BWE generator are
|
||||
bit-identical to what it was trained on.
|
||||
"""
|
||||
|
||||
def __init__(self, filter_length: int, hop_length: int, win_length: int):
|
||||
super().__init__()
|
||||
self.hop_length = hop_length
|
||||
self.win_length = win_length
|
||||
n_freqs = filter_length // 2 + 1
|
||||
self.register_buffer("forward_basis", torch.zeros(n_freqs * 2, 1, filter_length))
|
||||
self.register_buffer("inverse_basis", torch.zeros(n_freqs * 2, 1, filter_length))
|
||||
|
||||
def forward(self, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Compute magnitude and phase spectrogram from a batch of waveforms.
|
||||
|
||||
Applies causal (left-only) padding of win_length - hop_length samples so that
|
||||
each output frame depends only on past and present input — no lookahead.
|
||||
The STFT is computed by convolving the padded signal with forward_basis.
|
||||
|
||||
Args:
|
||||
y: Waveform tensor of shape (B, T).
|
||||
|
||||
Returns:
|
||||
magnitude: Linear amplitude spectrogram, shape (B, n_freqs, T_frames).
|
||||
phase: Phase spectrogram in radians, shape (B, n_freqs, T_frames).
|
||||
Computed in float32 for numerical stability, then cast back to
|
||||
the input dtype.
|
||||
"""
|
||||
if y.dim() == 2:
|
||||
y = y.unsqueeze(1) # (B, 1, T)
|
||||
left_pad = max(0, self.win_length - self.hop_length) # causal: left-only
|
||||
y = F.pad(y, (left_pad, 0))
|
||||
spec = F.conv1d(y, comfy.model_management.cast_to(self.forward_basis, dtype=y.dtype, device=y.device), stride=self.hop_length, padding=0)
|
||||
n_freqs = spec.shape[1] // 2
|
||||
real, imag = spec[:, :n_freqs], spec[:, n_freqs:]
|
||||
magnitude = torch.sqrt(real ** 2 + imag ** 2)
|
||||
phase = torch.atan2(imag.float(), real.float()).to(real.dtype)
|
||||
return magnitude, phase
|
||||
|
||||
|
||||
class MelSTFT(nn.Module):
|
||||
"""Causal log-mel spectrogram module whose buffers are loaded from the checkpoint.
|
||||
|
||||
Computes a log-mel spectrogram by running the causal STFT (_STFTFn) on the input
|
||||
waveform and projecting the linear magnitude spectrum onto the mel filterbank.
|
||||
|
||||
The module's state dict layout matches the 'mel_stft.*' keys stored in the checkpoint
|
||||
(mel_basis, stft_fn.forward_basis, stft_fn.inverse_basis).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
filter_length: int,
|
||||
hop_length: int,
|
||||
win_length: int,
|
||||
n_mel_channels: int,
|
||||
sampling_rate: int,
|
||||
mel_fmin: float,
|
||||
mel_fmax: float,
|
||||
):
|
||||
super().__init__()
|
||||
self.stft_fn = _STFTFn(filter_length, hop_length, win_length)
|
||||
|
||||
n_freqs = filter_length // 2 + 1
|
||||
self.register_buffer("mel_basis", torch.zeros(n_mel_channels, n_freqs))
|
||||
|
||||
def mel_spectrogram(
|
||||
self, y: torch.Tensor
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Compute log-mel spectrogram and auxiliary spectral quantities.
|
||||
|
||||
Args:
|
||||
y: Waveform tensor of shape (B, T).
|
||||
|
||||
Returns:
|
||||
log_mel: Log-compressed mel spectrogram, shape (B, n_mel_channels, T_frames).
|
||||
Computed as log(clamp(mel_basis @ magnitude, min=1e-5)).
|
||||
magnitude: Linear amplitude spectrogram, shape (B, n_freqs, T_frames).
|
||||
phase: Phase spectrogram in radians, shape (B, n_freqs, T_frames).
|
||||
energy: Per-frame energy (L2 norm over frequency), shape (B, T_frames).
|
||||
"""
|
||||
magnitude, phase = self.stft_fn(y)
|
||||
energy = torch.norm(magnitude, dim=1)
|
||||
mel = torch.matmul(comfy.model_management.cast_to(self.mel_basis, dtype=magnitude.dtype, device=y.device), magnitude)
|
||||
log_mel = torch.log(torch.clamp(mel, min=1e-5))
|
||||
return log_mel, magnitude, phase, energy
|
||||
|
||||
|
||||
class VocoderWithBWE(torch.nn.Module):
|
||||
"""Vocoder with bandwidth extension (BWE) for higher sample rate output.
|
||||
|
||||
Chains a base vocoder (mel → low-rate waveform) with a BWE stage that upsamples
|
||||
to a higher rate. The BWE computes a mel spectrogram from the low-rate waveform.
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
vocoder_config = config["vocoder"]
|
||||
bwe_config = config["bwe"]
|
||||
|
||||
self.vocoder = Vocoder(config=vocoder_config)
|
||||
self.bwe_generator = Vocoder(
|
||||
config={**bwe_config, "apply_final_activation": False}
|
||||
)
|
||||
|
||||
self.input_sample_rate = bwe_config["input_sampling_rate"]
|
||||
self.output_sample_rate = bwe_config["output_sampling_rate"]
|
||||
self.hop_length = bwe_config["hop_length"]
|
||||
|
||||
self.mel_stft = MelSTFT(
|
||||
filter_length=bwe_config["n_fft"],
|
||||
hop_length=bwe_config["hop_length"],
|
||||
win_length=bwe_config["n_fft"],
|
||||
n_mel_channels=bwe_config["num_mels"],
|
||||
sampling_rate=bwe_config["input_sampling_rate"],
|
||||
mel_fmin=0.0,
|
||||
mel_fmax=bwe_config["input_sampling_rate"] / 2.0,
|
||||
)
|
||||
self.resampler = UpSample1d(
|
||||
ratio=bwe_config["output_sampling_rate"] // bwe_config["input_sampling_rate"],
|
||||
persistent=False,
|
||||
window_type="hann",
|
||||
)
|
||||
|
||||
def _compute_mel(self, audio):
|
||||
"""Compute log-mel spectrogram from waveform using causal STFT bases."""
|
||||
B, C, T = audio.shape
|
||||
flat = audio.reshape(B * C, -1) # (B*C, T)
|
||||
mel, _, _, _ = self.mel_stft.mel_spectrogram(flat) # (B*C, n_mels, T_frames)
|
||||
return mel.reshape(B, C, mel.shape[1], mel.shape[2]) # (B, C, n_mels, T_frames)
|
||||
|
||||
def forward(self, mel_spec):
|
||||
x = self.vocoder(mel_spec)
|
||||
_, _, T_low = x.shape
|
||||
T_out = T_low * self.output_sample_rate // self.input_sample_rate
|
||||
|
||||
remainder = T_low % self.hop_length
|
||||
if remainder != 0:
|
||||
x = F.pad(x, (0, self.hop_length - remainder))
|
||||
|
||||
mel = self._compute_mel(x)
|
||||
residual = self.bwe_generator(mel)
|
||||
skip = self.resampler(x)
|
||||
assert residual.shape == skip.shape, f"residual {residual.shape} != skip {skip.shape}"
|
||||
|
||||
return torch.clamp(residual + skip, -1, 1)[..., :T_out]
|
||||
|
||||
@@ -14,6 +14,7 @@ from comfy.ldm.flux.layers import EmbedND
|
||||
from comfy.ldm.flux.math import apply_rope
|
||||
import comfy.patcher_extension
|
||||
import comfy.utils
|
||||
from comfy.ldm.chroma_radiance.layers import NerfEmbedder
|
||||
|
||||
|
||||
def invert_slices(slices, length):
|
||||
@@ -858,3 +859,267 @@ class NextDiT(nn.Module):
|
||||
img = self.unpatchify(img, img_size, cap_size, return_tensor=x_is_tensor)[:, :, :h, :w]
|
||||
return -img
|
||||
|
||||
|
||||
#############################################################################
|
||||
# Pixel Space Decoder Components #
|
||||
#############################################################################
|
||||
|
||||
def _modulate_shift_scale(x, shift, scale):
|
||||
return x * (1 + scale) + shift
|
||||
|
||||
|
||||
class PixelResBlock(nn.Module):
|
||||
"""
|
||||
Residual block with AdaLN modulation, zero-initialised so it starts as
|
||||
an identity at the beginning of training.
|
||||
"""
|
||||
|
||||
def __init__(self, channels: int, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.in_ln = operations.LayerNorm(channels, eps=1e-6, dtype=dtype, device=device)
|
||||
self.mlp = nn.Sequential(
|
||||
operations.Linear(channels, channels, bias=True, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Linear(channels, channels, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operations.Linear(channels, 3 * channels, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
shift, scale, gate = self.adaLN_modulation(y).chunk(3, dim=-1)
|
||||
h = _modulate_shift_scale(self.in_ln(x), shift, scale)
|
||||
h = self.mlp(h)
|
||||
return x + gate * h
|
||||
|
||||
|
||||
class DCTFinalLayer(nn.Module):
|
||||
"""Zero-initialised output projection (adopted from DiT)."""
|
||||
|
||||
def __init__(self, model_channels: int, out_channels: int, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.norm_final = operations.LayerNorm(model_channels, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.linear = operations.Linear(model_channels, out_channels, bias=True, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.linear(self.norm_final(x))
|
||||
|
||||
|
||||
class SimpleMLPAdaLN(nn.Module):
|
||||
"""
|
||||
Small MLP decoder head for the pixel-space variant.
|
||||
|
||||
Takes per-patch pixel values and a per-patch conditioning vector from the
|
||||
transformer backbone and predicts the denoised pixel values.
|
||||
|
||||
x : [B*N, P^2, C] – noisy pixel values per patch position
|
||||
c : [B*N, dim] – backbone hidden state per patch (conditioning)
|
||||
→ [B*N, P^2, C]
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
model_channels: int,
|
||||
out_channels: int,
|
||||
z_channels: int,
|
||||
num_res_blocks: int,
|
||||
max_freqs: int = 8,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
|
||||
# Project backbone hidden state → per-patch conditioning
|
||||
self.cond_embed = operations.Linear(z_channels, model_channels, dtype=dtype, device=device)
|
||||
|
||||
# Input projection with DCT positional encoding
|
||||
self.input_embedder = NerfEmbedder(
|
||||
in_channels=in_channels,
|
||||
hidden_size_input=model_channels,
|
||||
max_freqs=max_freqs,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
# Residual blocks
|
||||
self.res_blocks = nn.ModuleList([
|
||||
PixelResBlock(model_channels, dtype=dtype, device=device, operations=operations) for _ in range(num_res_blocks)
|
||||
])
|
||||
|
||||
# Output projection
|
||||
self.final_layer = DCTFinalLayer(model_channels, out_channels, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
|
||||
# x: [B*N, 1, P^2*C], c: [B*N, dim]
|
||||
original_dtype = x.dtype
|
||||
weight_dtype = self.cond_embed.weight.dtype if hasattr(self.cond_embed, "weight") and self.cond_embed.weight is not None else (self.dtype or x.dtype)
|
||||
x = self.input_embedder(x) # [B*N, 1, model_channels]
|
||||
y = self.cond_embed(c.to(weight_dtype)).unsqueeze(1) # [B*N, 1, model_channels]
|
||||
x = x.to(weight_dtype)
|
||||
for block in self.res_blocks:
|
||||
x = block(x, y)
|
||||
return self.final_layer(x).to(original_dtype) # [B*N, 1, P^2*C]
|
||||
|
||||
|
||||
#############################################################################
|
||||
# NextDiT – Pixel Space #
|
||||
#############################################################################
|
||||
|
||||
class NextDiTPixelSpace(NextDiT):
|
||||
"""
|
||||
Pixel-space variant of NextDiT.
|
||||
|
||||
Identical transformer backbone to NextDiT, but the output head is replaced
|
||||
with a small MLP decoder (SimpleMLPAdaLN) that operates on raw pixel values
|
||||
per patch rather than a single affine projection.
|
||||
|
||||
Key differences vs NextDiT:
|
||||
• ``final_layer`` is removed; ``dec_net`` (SimpleMLPAdaLN) is used instead.
|
||||
• ``_forward`` stores the raw patchified pixel values before the backbone
|
||||
embedding and feeds them to ``dec_net`` together with the per-patch
|
||||
backbone hidden states.
|
||||
• Supports optional x0 prediction via ``use_x0``.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
# decoder-specific
|
||||
decoder_hidden_size: int = 3840,
|
||||
decoder_num_res_blocks: int = 4,
|
||||
decoder_max_freqs: int = 8,
|
||||
decoder_in_channels: int = None, # full flattened patch size (patch_size^2 * in_channels)
|
||||
use_x0: bool = False,
|
||||
# all NextDiT args forwarded unchanged
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
# Remove the latent-space final layer – not used in pixel space
|
||||
del self.final_layer
|
||||
|
||||
patch_size = kwargs.get("patch_size", 2)
|
||||
in_channels = kwargs.get("in_channels", 4)
|
||||
dim = kwargs.get("dim", 4096)
|
||||
|
||||
# decoder_in_channels is the full flattened patch: patch_size^2 * in_channels
|
||||
dec_in_ch = decoder_in_channels if decoder_in_channels is not None else patch_size ** 2 * in_channels
|
||||
|
||||
self.dec_net = SimpleMLPAdaLN(
|
||||
in_channels=dec_in_ch,
|
||||
model_channels=decoder_hidden_size,
|
||||
out_channels=dec_in_ch,
|
||||
z_channels=dim,
|
||||
num_res_blocks=decoder_num_res_blocks,
|
||||
max_freqs=decoder_max_freqs,
|
||||
dtype=kwargs.get("dtype"),
|
||||
device=kwargs.get("device"),
|
||||
operations=kwargs.get("operations"),
|
||||
)
|
||||
|
||||
if use_x0:
|
||||
self.register_buffer("__x0__", torch.tensor([]))
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Forward — mirrors NextDiT._forward exactly, replacing final_layer
|
||||
# with the pixel-space dec_net decoder.
|
||||
# ------------------------------------------------------------------
|
||||
def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, ref_latents=[], ref_contexts=[], siglip_feats=[], transformer_options={}, **kwargs):
|
||||
omni = len(ref_latents) > 0
|
||||
if omni:
|
||||
timesteps = torch.cat([timesteps * 0, timesteps], dim=0)
|
||||
|
||||
t = 1.0 - timesteps
|
||||
cap_feats = context
|
||||
cap_mask = attention_mask
|
||||
bs, c, h, w = x.shape
|
||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
|
||||
|
||||
t = self.t_embedder(t * self.time_scale, dtype=x.dtype)
|
||||
adaln_input = t
|
||||
|
||||
if self.clip_text_pooled_proj is not None:
|
||||
pooled = kwargs.get("clip_text_pooled", None)
|
||||
if pooled is not None:
|
||||
pooled = self.clip_text_pooled_proj(pooled)
|
||||
else:
|
||||
pooled = torch.zeros((x.shape[0], self.clip_text_dim), device=x.device, dtype=x.dtype)
|
||||
adaln_input = self.time_text_embed(torch.cat((t, pooled), dim=-1))
|
||||
|
||||
# ---- capture raw pixel patches before patchify_and_embed embeds them ----
|
||||
pH = pW = self.patch_size
|
||||
B, C, H, W = x.shape
|
||||
pixel_patches = (
|
||||
x.view(B, C, H // pH, pH, W // pW, pW)
|
||||
.permute(0, 2, 4, 3, 5, 1) # [B, Ht, Wt, pH, pW, C]
|
||||
.flatten(3) # [B, Ht, Wt, pH*pW*C]
|
||||
.flatten(1, 2) # [B, N, pH*pW*C]
|
||||
)
|
||||
N = pixel_patches.shape[1]
|
||||
# decoder sees one token per patch: [B*N, 1, P^2*C]
|
||||
pixel_values = pixel_patches.reshape(B * N, 1, pH * pW * C)
|
||||
|
||||
patches = transformer_options.get("patches", {})
|
||||
x_is_tensor = isinstance(x, torch.Tensor)
|
||||
img, mask, img_size, cap_size, freqs_cis, timestep_zero_index = self.patchify_and_embed(
|
||||
x, cap_feats, cap_mask, adaln_input, num_tokens,
|
||||
ref_latents=ref_latents, ref_contexts=ref_contexts,
|
||||
siglip_feats=siglip_feats, transformer_options=transformer_options
|
||||
)
|
||||
freqs_cis = freqs_cis.to(img.device)
|
||||
|
||||
transformer_options["total_blocks"] = len(self.layers)
|
||||
transformer_options["block_type"] = "double"
|
||||
img_input = img
|
||||
for i, layer in enumerate(self.layers):
|
||||
transformer_options["block_index"] = i
|
||||
img = layer(img, mask, freqs_cis, adaln_input, timestep_zero_index=timestep_zero_index, transformer_options=transformer_options)
|
||||
if "double_block" in patches:
|
||||
for p in patches["double_block"]:
|
||||
out = p({"img": img[:, cap_size[0]:], "img_input": img_input[:, cap_size[0]:], "txt": img[:, :cap_size[0]], "pe": freqs_cis[:, cap_size[0]:], "vec": adaln_input, "x": x, "block_index": i, "transformer_options": transformer_options})
|
||||
if "img" in out:
|
||||
img[:, cap_size[0]:] = out["img"]
|
||||
if "txt" in out:
|
||||
img[:, :cap_size[0]] = out["txt"]
|
||||
|
||||
# ---- pixel-space decoder (replaces final_layer + unpatchify) ----
|
||||
# img may have padding tokens beyond N; only the first N are real image patches
|
||||
img_hidden = img[:, cap_size[0]:cap_size[0] + N, :] # [B, N, dim]
|
||||
decoder_cond = img_hidden.reshape(B * N, self.dim) # [B*N, dim]
|
||||
|
||||
output = self.dec_net(pixel_values, decoder_cond) # [B*N, 1, P^2*C]
|
||||
output = output.reshape(B, N, -1) # [B, N, P^2*C]
|
||||
|
||||
# prepend zero cap placeholder so unpatchify indexing works unchanged
|
||||
cap_placeholder = torch.zeros(
|
||||
B, cap_size[0], output.shape[-1], device=output.device, dtype=output.dtype
|
||||
)
|
||||
img_out = self.unpatchify(
|
||||
torch.cat([cap_placeholder, output], dim=1),
|
||||
img_size, cap_size, return_tensor=x_is_tensor
|
||||
)[:, :, :h, :w]
|
||||
|
||||
return -img_out
|
||||
|
||||
def forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs):
|
||||
# _forward returns neg_x0 = -x0 (negated decoder output).
|
||||
#
|
||||
# Reference inference (working_inference_reference.py):
|
||||
# out = _forward(img, t) # = -x0
|
||||
# pred = (img - out) / t # = (img + x0) / t [_apply_x0_residual]
|
||||
# img += (t_prev - t_curr) * pred # Euler step
|
||||
#
|
||||
# ComfyUI's Euler sampler does the same:
|
||||
# x_next = x + (sigma_next - sigma) * model_output
|
||||
# So model_output must equal pred = (x - neg_x0) / t = (x - (-x0)) / t = (x + x0) / t
|
||||
neg_x0 = comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {}))
|
||||
).execute(x, timesteps, context, num_tokens, attention_mask, **kwargs)
|
||||
|
||||
return (x - neg_x0) / timesteps.view(-1, 1, 1, 1)
|
||||
|
||||
@@ -18,6 +18,8 @@ import comfy.patcher_extension
|
||||
import comfy.ops
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
from ..sdpose import HeatmapHead
|
||||
|
||||
class TimestepBlock(nn.Module):
|
||||
"""
|
||||
Any module where forward() takes timestep embeddings as a second argument.
|
||||
@@ -441,6 +443,7 @@ class UNetModel(nn.Module):
|
||||
disable_temporal_crossattention=False,
|
||||
max_ddpm_temb_period=10000,
|
||||
attn_precision=None,
|
||||
heatmap_head=False,
|
||||
device=None,
|
||||
operations=ops,
|
||||
):
|
||||
@@ -827,6 +830,9 @@ class UNetModel(nn.Module):
|
||||
#nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
|
||||
)
|
||||
|
||||
if heatmap_head:
|
||||
self.heatmap_head = HeatmapHead(device=device, dtype=self.dtype, operations=operations)
|
||||
|
||||
def forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
|
||||
130
comfy/ldm/modules/sdpose.py
Normal file
130
comfy/ldm/modules/sdpose.py
Normal file
@@ -0,0 +1,130 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from scipy.ndimage import gaussian_filter
|
||||
|
||||
class HeatmapHead(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=640,
|
||||
out_channels=133,
|
||||
input_size=(768, 1024),
|
||||
heatmap_scale=4,
|
||||
deconv_out_channels=(640,),
|
||||
deconv_kernel_sizes=(4,),
|
||||
conv_out_channels=(640,),
|
||||
conv_kernel_sizes=(1,),
|
||||
final_layer_kernel_size=1,
|
||||
device=None, dtype=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.heatmap_size = (input_size[0] // heatmap_scale, input_size[1] // heatmap_scale)
|
||||
self.scale_factor = ((np.array(input_size) - 1) / (np.array(self.heatmap_size) - 1)).astype(np.float32)
|
||||
|
||||
# Deconv layers
|
||||
if deconv_out_channels:
|
||||
deconv_layers = []
|
||||
for out_ch, kernel_size in zip(deconv_out_channels, deconv_kernel_sizes):
|
||||
if kernel_size == 4:
|
||||
padding, output_padding = 1, 0
|
||||
elif kernel_size == 3:
|
||||
padding, output_padding = 1, 1
|
||||
elif kernel_size == 2:
|
||||
padding, output_padding = 0, 0
|
||||
else:
|
||||
raise ValueError(f'Unsupported kernel size {kernel_size}')
|
||||
|
||||
deconv_layers.extend([
|
||||
operations.ConvTranspose2d(in_channels, out_ch, kernel_size,
|
||||
stride=2, padding=padding, output_padding=output_padding, bias=False, device=device, dtype=dtype),
|
||||
torch.nn.InstanceNorm2d(out_ch, device=device, dtype=dtype),
|
||||
torch.nn.SiLU(inplace=True)
|
||||
])
|
||||
in_channels = out_ch
|
||||
self.deconv_layers = torch.nn.Sequential(*deconv_layers)
|
||||
else:
|
||||
self.deconv_layers = torch.nn.Identity()
|
||||
|
||||
# Conv layers
|
||||
if conv_out_channels:
|
||||
conv_layers = []
|
||||
for out_ch, kernel_size in zip(conv_out_channels, conv_kernel_sizes):
|
||||
padding = (kernel_size - 1) // 2
|
||||
conv_layers.extend([
|
||||
operations.Conv2d(in_channels, out_ch, kernel_size,
|
||||
stride=1, padding=padding, device=device, dtype=dtype),
|
||||
torch.nn.InstanceNorm2d(out_ch, device=device, dtype=dtype),
|
||||
torch.nn.SiLU(inplace=True)
|
||||
])
|
||||
in_channels = out_ch
|
||||
self.conv_layers = torch.nn.Sequential(*conv_layers)
|
||||
else:
|
||||
self.conv_layers = torch.nn.Identity()
|
||||
|
||||
self.final_layer = operations.Conv2d(in_channels, out_channels, kernel_size=final_layer_kernel_size, padding=final_layer_kernel_size // 2, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x): # Decode heatmaps to keypoints
|
||||
heatmaps = self.final_layer(self.conv_layers(self.deconv_layers(x)))
|
||||
heatmaps_np = heatmaps.float().cpu().numpy() # (B, K, H, W)
|
||||
B, K, H, W = heatmaps_np.shape
|
||||
|
||||
batch_keypoints = []
|
||||
batch_scores = []
|
||||
|
||||
for b in range(B):
|
||||
hm = heatmaps_np[b].copy() # (K, H, W)
|
||||
|
||||
# --- vectorised argmax ---
|
||||
flat = hm.reshape(K, -1)
|
||||
idx = np.argmax(flat, axis=1)
|
||||
scores = flat[np.arange(K), idx].copy()
|
||||
y_locs, x_locs = np.unravel_index(idx, (H, W))
|
||||
keypoints = np.stack([x_locs, y_locs], axis=-1).astype(np.float32) # (K, 2) in heatmap space
|
||||
invalid = scores <= 0.
|
||||
keypoints[invalid] = -1
|
||||
|
||||
# --- DARK sub-pixel refinement (UDP) ---
|
||||
# 1. Gaussian blur with max-preserving normalisation
|
||||
border = 5 # (kernel-1)//2 for kernel=11
|
||||
for k in range(K):
|
||||
origin_max = np.max(hm[k])
|
||||
dr = np.zeros((H + 2 * border, W + 2 * border), dtype=np.float32)
|
||||
dr[border:-border, border:-border] = hm[k].copy()
|
||||
dr = gaussian_filter(dr, sigma=2.0)
|
||||
hm[k] = dr[border:-border, border:-border].copy()
|
||||
cur_max = np.max(hm[k])
|
||||
if cur_max > 0:
|
||||
hm[k] *= origin_max / cur_max
|
||||
# 2. Log-space for Taylor expansion
|
||||
np.clip(hm, 1e-3, 50., hm)
|
||||
np.log(hm, hm)
|
||||
# 3. Hessian-based Newton step
|
||||
hm_pad = np.pad(hm, ((0, 0), (1, 1), (1, 1)), mode='edge').flatten()
|
||||
index = keypoints[:, 0] + 1 + (keypoints[:, 1] + 1) * (W + 2)
|
||||
index += (W + 2) * (H + 2) * np.arange(0, K)
|
||||
index = index.astype(int).reshape(-1, 1)
|
||||
i_ = hm_pad[index]
|
||||
ix1 = hm_pad[index + 1]
|
||||
iy1 = hm_pad[index + W + 2]
|
||||
ix1y1 = hm_pad[index + W + 3]
|
||||
ix1_y1_ = hm_pad[index - W - 3]
|
||||
ix1_ = hm_pad[index - 1]
|
||||
iy1_ = hm_pad[index - 2 - W]
|
||||
dx = 0.5 * (ix1 - ix1_)
|
||||
dy = 0.5 * (iy1 - iy1_)
|
||||
derivative = np.concatenate([dx, dy], axis=1).reshape(K, 2, 1)
|
||||
dxx = ix1 - 2 * i_ + ix1_
|
||||
dyy = iy1 - 2 * i_ + iy1_
|
||||
dxy = 0.5 * (ix1y1 - ix1 - iy1 + i_ + i_ - ix1_ - iy1_ + ix1_y1_)
|
||||
hessian = np.concatenate([dxx, dxy, dxy, dyy], axis=1).reshape(K, 2, 2)
|
||||
hessian = np.linalg.inv(hessian + np.finfo(np.float32).eps * np.eye(2))
|
||||
keypoints -= np.einsum('imn,ink->imk', hessian, derivative).squeeze(axis=-1)
|
||||
|
||||
# --- restore to input image space ---
|
||||
keypoints = keypoints * self.scale_factor
|
||||
keypoints[invalid] = -1
|
||||
|
||||
batch_keypoints.append(keypoints)
|
||||
batch_scores.append(scores)
|
||||
|
||||
return batch_keypoints, batch_scores
|
||||
@@ -1621,3 +1621,118 @@ class HumoWanModel(WanModel):
|
||||
# unpatchify
|
||||
x = self.unpatchify(x, grid_sizes)
|
||||
return x
|
||||
|
||||
class SCAILWanModel(WanModel):
|
||||
def __init__(self, model_type="scail", patch_size=(1, 2, 2), in_dim=20, dim=5120, operations=None, device=None, dtype=None, **kwargs):
|
||||
super().__init__(model_type='i2v', patch_size=patch_size, in_dim=in_dim, dim=dim, operations=operations, device=device, dtype=dtype, **kwargs)
|
||||
|
||||
self.patch_embedding_pose = operations.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size, device=device, dtype=torch.float32)
|
||||
|
||||
def forward_orig(self, x, t, context, clip_fea=None, freqs=None, transformer_options={}, pose_latents=None, reference_latent=None, **kwargs):
|
||||
|
||||
if reference_latent is not None:
|
||||
x = torch.cat((reference_latent, x), dim=2)
|
||||
|
||||
# embeddings
|
||||
x = self.patch_embedding(x.float()).to(x.dtype)
|
||||
grid_sizes = x.shape[2:]
|
||||
transformer_options["grid_sizes"] = grid_sizes
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
|
||||
scail_pose_seq_len = 0
|
||||
if pose_latents is not None:
|
||||
scail_x = self.patch_embedding_pose(pose_latents.float()).to(x.dtype)
|
||||
scail_x = scail_x.flatten(2).transpose(1, 2)
|
||||
scail_pose_seq_len = scail_x.shape[1]
|
||||
x = torch.cat([x, scail_x], dim=1)
|
||||
del scail_x
|
||||
|
||||
# time embeddings
|
||||
e = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(dtype=x[0].dtype))
|
||||
e = e.reshape(t.shape[0], -1, e.shape[-1])
|
||||
e0 = self.time_projection(e).unflatten(2, (6, self.dim))
|
||||
|
||||
# context
|
||||
context = self.text_embedding(context)
|
||||
|
||||
context_img_len = None
|
||||
if clip_fea is not None:
|
||||
if self.img_emb is not None:
|
||||
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
|
||||
context = torch.cat([context_clip, context], dim=1)
|
||||
context_img_len = clip_fea.shape[-2]
|
||||
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
transformer_options["total_blocks"] = len(self.blocks)
|
||||
transformer_options["block_type"] = "double"
|
||||
for i, block in enumerate(self.blocks):
|
||||
transformer_options["block_index"] = i
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, transformer_options=args["transformer_options"])
|
||||
return out
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})
|
||||
x = out["img"]
|
||||
else:
|
||||
x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options)
|
||||
|
||||
# head
|
||||
x = self.head(x, e)
|
||||
|
||||
if scail_pose_seq_len > 0:
|
||||
x = x[:, :-scail_pose_seq_len]
|
||||
|
||||
# unpatchify
|
||||
x = self.unpatchify(x, grid_sizes)
|
||||
|
||||
if reference_latent is not None:
|
||||
x = x[:, :, reference_latent.shape[2]:]
|
||||
|
||||
return x
|
||||
|
||||
def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, pose_latents=None, reference_latent=None, transformer_options={}):
|
||||
main_freqs = super().rope_encode(t, h, w, t_start=t_start, steps_t=steps_t, steps_h=steps_h, steps_w=steps_w, device=device, dtype=dtype, transformer_options=transformer_options)
|
||||
|
||||
if pose_latents is None:
|
||||
return main_freqs
|
||||
|
||||
ref_t_patches = 0
|
||||
if reference_latent is not None:
|
||||
ref_t_patches = (reference_latent.shape[2] + (self.patch_size[0] // 2)) // self.patch_size[0]
|
||||
|
||||
F_pose, H_pose, W_pose = pose_latents.shape[-3], pose_latents.shape[-2], pose_latents.shape[-1]
|
||||
|
||||
# if pose is at half resolution, scale_y/scale_x=2 stretches the position range to cover the same RoPE extent as the main frames
|
||||
h_scale = h / H_pose
|
||||
w_scale = w / W_pose
|
||||
|
||||
# 120 w-offset and shift 0.5 to place positions at midpoints (0.5, 2.5, ...) to match the original code
|
||||
h_shift = (h_scale - 1) / 2
|
||||
w_shift = (w_scale - 1) / 2
|
||||
pose_transformer_options = {"rope_options": {"shift_y": h_shift, "shift_x": 120.0 + w_shift, "scale_y": h_scale, "scale_x": w_scale}}
|
||||
pose_freqs = super().rope_encode(F_pose, H_pose, W_pose, t_start=t_start+ref_t_patches, device=device, dtype=dtype, transformer_options=pose_transformer_options)
|
||||
|
||||
return torch.cat([main_freqs, pose_freqs], dim=1)
|
||||
|
||||
def _forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, pose_latents=None, **kwargs):
|
||||
bs, c, t, h, w = x.shape
|
||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
|
||||
|
||||
if pose_latents is not None:
|
||||
pose_latents = comfy.ldm.common_dit.pad_to_patch_size(pose_latents, self.patch_size)
|
||||
|
||||
t_len = t
|
||||
if time_dim_concat is not None:
|
||||
time_dim_concat = comfy.ldm.common_dit.pad_to_patch_size(time_dim_concat, self.patch_size)
|
||||
x = torch.cat([x, time_dim_concat], dim=2)
|
||||
t_len = x.shape[2]
|
||||
|
||||
reference_latent = None
|
||||
if "reference_latent" in kwargs:
|
||||
reference_latent = comfy.ldm.common_dit.pad_to_patch_size(kwargs.pop("reference_latent"), self.patch_size)
|
||||
t_len += reference_latent.shape[2]
|
||||
|
||||
freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent)
|
||||
return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent, **kwargs)[:, :, :t, :h, :w]
|
||||
|
||||
@@ -459,6 +459,7 @@ class WanVAE(nn.Module):
|
||||
attn_scales=[],
|
||||
temperal_downsample=[True, True, False],
|
||||
image_channels=3,
|
||||
conv_out_channels=3,
|
||||
dropout=0.0):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
@@ -474,7 +475,7 @@ class WanVAE(nn.Module):
|
||||
attn_scales, self.temperal_downsample, dropout)
|
||||
self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
|
||||
self.conv2 = CausalConv3d(z_dim, z_dim, 1)
|
||||
self.decoder = Decoder3d(dim, z_dim, image_channels, dim_mult, num_res_blocks,
|
||||
self.decoder = Decoder3d(dim, z_dim, conv_out_channels, dim_mult, num_res_blocks,
|
||||
attn_scales, self.temperal_upsample, dropout)
|
||||
|
||||
def encode(self, x):
|
||||
@@ -484,7 +485,7 @@ class WanVAE(nn.Module):
|
||||
iter_ = 1 + (t - 1) // 4
|
||||
feat_map = None
|
||||
if iter_ > 1:
|
||||
feat_map = [None] * count_conv3d(self.decoder)
|
||||
feat_map = [None] * count_conv3d(self.encoder)
|
||||
## 对encode输入的x,按时间拆分为1、4、4、4....
|
||||
for i in range(iter_):
|
||||
conv_idx = [0]
|
||||
|
||||
@@ -99,6 +99,9 @@ def model_lora_keys_clip(model, key_map={}):
|
||||
for k in sdk:
|
||||
if k.endswith(".weight"):
|
||||
key_map["text_encoders.{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names
|
||||
tp = k.find(".transformer.") #also map without wrapper prefix for composite text encoder models
|
||||
if tp > 0 and not k.startswith("clip_"):
|
||||
key_map["text_encoders.{}".format(k[tp + 1:-len(".weight")])] = k
|
||||
|
||||
text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"
|
||||
clip_l_present = False
|
||||
@@ -337,6 +340,7 @@ def model_lora_keys_unet(model, key_map={}):
|
||||
if k.startswith("diffusion_model.decoder.") and k.endswith(".weight"):
|
||||
key_lora = k[len("diffusion_model.decoder."):-len(".weight")]
|
||||
key_map["base_model.model.{}".format(key_lora)] = k # Official base model loras
|
||||
key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = k # LyCORIS/LoKR format
|
||||
|
||||
return key_map
|
||||
|
||||
|
||||
@@ -78,4 +78,4 @@ def interpret_gathered_like(tensors, gathered):
|
||||
|
||||
return dest_views
|
||||
|
||||
aimdo_allocator = None
|
||||
aimdo_enabled = False
|
||||
|
||||
@@ -76,6 +76,7 @@ class ModelType(Enum):
|
||||
FLUX = 8
|
||||
IMG_TO_IMG = 9
|
||||
FLOW_COSMOS = 10
|
||||
IMG_TO_IMG_FLOW = 11
|
||||
|
||||
|
||||
def model_sampling(model_config, model_type):
|
||||
@@ -108,6 +109,8 @@ def model_sampling(model_config, model_type):
|
||||
elif model_type == ModelType.FLOW_COSMOS:
|
||||
c = comfy.model_sampling.COSMOS_RFLOW
|
||||
s = comfy.model_sampling.ModelSamplingCosmosRFlow
|
||||
elif model_type == ModelType.IMG_TO_IMG_FLOW:
|
||||
c = comfy.model_sampling.IMG_TO_IMG_FLOW
|
||||
|
||||
class ModelSampling(s, c):
|
||||
pass
|
||||
@@ -922,6 +925,25 @@ class Flux(BaseModel):
|
||||
out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()[2:]), ref_latents))])
|
||||
return out
|
||||
|
||||
class LongCatImage(Flux):
|
||||
def _apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
|
||||
transformer_options = transformer_options.copy()
|
||||
rope_opts = transformer_options.get("rope_options", {})
|
||||
rope_opts = dict(rope_opts)
|
||||
rope_opts.setdefault("shift_t", 1.0)
|
||||
rope_opts.setdefault("shift_y", 512.0)
|
||||
rope_opts.setdefault("shift_x", 512.0)
|
||||
transformer_options["rope_options"] = rope_opts
|
||||
return super()._apply_model(x, t, c_concat, c_crossattn, control, transformer_options, **kwargs)
|
||||
|
||||
def encode_adm(self, **kwargs):
|
||||
return None
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
out.pop('guidance', None)
|
||||
return out
|
||||
|
||||
class Flux2(Flux):
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
@@ -971,6 +993,10 @@ class LTXV(BaseModel):
|
||||
if keyframe_idxs is not None:
|
||||
out['keyframe_idxs'] = comfy.conds.CONDRegular(keyframe_idxs)
|
||||
|
||||
guide_attention_entries = kwargs.get("guide_attention_entries", None)
|
||||
if guide_attention_entries is not None:
|
||||
out['guide_attention_entries'] = comfy.conds.CONDConstant(guide_attention_entries)
|
||||
|
||||
return out
|
||||
|
||||
def process_timestep(self, timestep, x, denoise_mask=None, **kwargs):
|
||||
@@ -988,10 +1014,14 @@ class LTXAV(BaseModel):
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
attention_mask = kwargs.get("attention_mask", None)
|
||||
device = kwargs["device"]
|
||||
|
||||
if attention_mask is not None:
|
||||
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
if cross_attn is not None:
|
||||
if hasattr(self.diffusion_model, "preprocess_text_embeds"):
|
||||
cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype_inference()), unprocessed=kwargs.get("unprocessed_ltxav_embeds", False))
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
|
||||
out['frame_rate'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", 25))
|
||||
@@ -1019,6 +1049,10 @@ class LTXAV(BaseModel):
|
||||
if latent_shapes is not None:
|
||||
out['latent_shapes'] = comfy.conds.CONDConstant(latent_shapes)
|
||||
|
||||
guide_attention_entries = kwargs.get("guide_attention_entries", None)
|
||||
if guide_attention_entries is not None:
|
||||
out['guide_attention_entries'] = comfy.conds.CONDConstant(guide_attention_entries)
|
||||
|
||||
return out
|
||||
|
||||
def process_timestep(self, timestep, x, denoise_mask=None, audio_denoise_mask=None, **kwargs):
|
||||
@@ -1229,6 +1263,11 @@ class Lumina2(BaseModel):
|
||||
out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()[2:]), ref_latents))])
|
||||
return out
|
||||
|
||||
class ZImagePixelSpace(Lumina2):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
BaseModel.__init__(self, model_config, model_type, device=device, unet_model=comfy.ldm.lumina.model.NextDiTPixelSpace)
|
||||
self.memory_usage_factor_conds = ("ref_latents",)
|
||||
|
||||
class WAN21(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)
|
||||
@@ -1462,6 +1501,50 @@ class WAN22(WAN21):
|
||||
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
|
||||
return latent_image
|
||||
|
||||
class WAN21_FlowRVS(WAN21):
|
||||
def __init__(self, model_config, model_type=ModelType.IMG_TO_IMG_FLOW, image_to_video=False, device=None):
|
||||
model_config.unet_config["model_type"] = "t2v"
|
||||
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)
|
||||
self.image_to_video = image_to_video
|
||||
|
||||
class WAN21_SCAIL(WAN21):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
|
||||
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.SCAILWanModel)
|
||||
self.memory_usage_factor_conds = ("reference_latent", "pose_latents")
|
||||
self.memory_usage_shape_process = {"pose_latents": lambda shape: [shape[0], shape[1], 1.5, shape[-2], shape[-1]]}
|
||||
self.image_to_video = image_to_video
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
|
||||
reference_latents = kwargs.get("reference_latents", None)
|
||||
if reference_latents is not None:
|
||||
ref_latent = self.process_latent_in(reference_latents[-1])
|
||||
ref_mask = torch.ones_like(ref_latent[:, :4])
|
||||
ref_latent = torch.cat([ref_latent, ref_mask], dim=1)
|
||||
out['reference_latent'] = comfy.conds.CONDRegular(ref_latent)
|
||||
|
||||
pose_latents = kwargs.get("pose_video_latent", None)
|
||||
if pose_latents is not None:
|
||||
pose_latents = self.process_latent_in(pose_latents)
|
||||
pose_mask = torch.ones_like(pose_latents[:, :4])
|
||||
pose_latents = torch.cat([pose_latents, pose_mask], dim=1)
|
||||
out['pose_latents'] = comfy.conds.CONDRegular(pose_latents)
|
||||
|
||||
return out
|
||||
|
||||
def extra_conds_shapes(self, **kwargs):
|
||||
out = {}
|
||||
ref_latents = kwargs.get("reference_latents", None)
|
||||
if ref_latents is not None:
|
||||
out['reference_latent'] = list([1, 20, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
|
||||
|
||||
pose_latents = kwargs.get("pose_video_latent", None)
|
||||
if pose_latents is not None:
|
||||
out['pose_latents'] = [pose_latents.shape[0], 20, *pose_latents.shape[2:]]
|
||||
|
||||
return out
|
||||
|
||||
class Hunyuan3Dv2(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3d.model.Hunyuan3Dv2)
|
||||
|
||||
@@ -279,6 +279,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["txt_norm"] = any_suffix_in(state_dict_keys, key_prefix, 'txt_norm.', ["weight", "scale"])
|
||||
if dit_config["yak_mlp"] and dit_config["txt_norm"]: # Ovis model
|
||||
dit_config["txt_ids_dims"] = [1, 2]
|
||||
if dit_config.get("context_in_dim") == 3584 and dit_config["vec_in_dim"] is None: # LongCat-Image
|
||||
dit_config["txt_ids_dims"] = [1, 2]
|
||||
|
||||
return dit_config
|
||||
|
||||
@@ -421,7 +423,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["extra_per_block_abs_pos_emb_type"] = "learnable"
|
||||
return dit_config
|
||||
|
||||
if '{}cap_embedder.1.weight'.format(key_prefix) in state_dict_keys: # Lumina 2
|
||||
if '{}cap_embedder.1.weight'.format(key_prefix) in state_dict_keys and '{}noise_refiner.0.attention.k_norm.weight'.format(key_prefix) in state_dict_keys: # Lumina 2
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "lumina2"
|
||||
dit_config["patch_size"] = 2
|
||||
@@ -462,6 +464,29 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
if sig_weight is not None:
|
||||
dit_config["siglip_feat_dim"] = sig_weight.shape[0]
|
||||
|
||||
dec_cond_key = '{}dec_net.cond_embed.weight'.format(key_prefix)
|
||||
if dec_cond_key in state_dict_keys: # pixel-space variant
|
||||
dit_config["image_model"] = "zimage_pixel"
|
||||
# patch_size and in_channels are derived from x_embedder:
|
||||
# x_embedder: Linear(patch_size * patch_size * in_channels, dim)
|
||||
# The decoder also receives the full flat patch, so decoder_in_channels = x_embedder input dim.
|
||||
x_emb_in = state_dict['{}x_embedder.weight'.format(key_prefix)].shape[1]
|
||||
dec_out = state_dict['{}dec_net.final_layer.linear.weight'.format(key_prefix)].shape[0]
|
||||
# patch_size: infer from decoder final layer output matching x_embedder input
|
||||
# in_channels: infer from dec_net input_embedder (in_features = dec_in_ch + max_freqs^2)
|
||||
embedder_w = state_dict['{}dec_net.input_embedder.embedder.0.weight'.format(key_prefix)]
|
||||
dec_in_ch = dec_out # decoder in == decoder out (same pixel space)
|
||||
dit_config["patch_size"] = round((x_emb_in / 3) ** 0.5) # assume RGB (in_channels=3)
|
||||
dit_config["in_channels"] = 3
|
||||
dit_config["decoder_in_channels"] = dec_in_ch
|
||||
dit_config["decoder_hidden_size"] = state_dict[dec_cond_key].shape[0]
|
||||
dit_config["decoder_num_res_blocks"] = count_blocks(
|
||||
state_dict_keys, '{}dec_net.res_blocks.'.format(key_prefix) + '{}.'
|
||||
)
|
||||
dit_config["decoder_max_freqs"] = int((embedder_w.shape[1] - dec_in_ch) ** 0.5)
|
||||
if '{}__x0__'.format(key_prefix) in state_dict_keys:
|
||||
dit_config["use_x0"] = True
|
||||
|
||||
return dit_config
|
||||
|
||||
if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1
|
||||
@@ -496,6 +521,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["model_type"] = "humo"
|
||||
elif '{}face_adapter.fuser_blocks.0.k_norm.weight'.format(key_prefix) in state_dict_keys:
|
||||
dit_config["model_type"] = "animate"
|
||||
elif '{}patch_embedding_pose.weight'.format(key_prefix) in state_dict_keys:
|
||||
dit_config["model_type"] = "scail"
|
||||
else:
|
||||
if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys:
|
||||
dit_config["model_type"] = "i2v"
|
||||
@@ -509,6 +536,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
if ref_conv_weight is not None:
|
||||
dit_config["in_dim_ref_conv"] = ref_conv_weight.shape[1]
|
||||
|
||||
if metadata is not None and "config" in metadata:
|
||||
dit_config.update(json.loads(metadata["config"]).get("transformer", {}))
|
||||
|
||||
return dit_config
|
||||
|
||||
if '{}latent_in.weight'.format(key_prefix) in state_dict_keys: # Hunyuan 3D
|
||||
@@ -526,8 +556,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
|
||||
return dit_config
|
||||
|
||||
if f"{key_prefix}t_embedder.mlp.2.weight" in state_dict_keys: # Hunyuan 3D 2.1
|
||||
|
||||
if f"{key_prefix}t_embedder.mlp.2.weight" in state_dict_keys and f"{key_prefix}blocks.0.attn1.k_norm.weight" in state_dict_keys: # Hunyuan 3D 2.1
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "hunyuan3d2_1"
|
||||
dit_config["in_channels"] = state_dict[f"{key_prefix}x_embedder.weight"].shape[1]
|
||||
@@ -792,6 +821,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
unet_config["use_temporal_resblock"] = False
|
||||
unet_config["use_temporal_attention"] = False
|
||||
|
||||
heatmap_key = '{}heatmap_head.conv_layers.0.weight'.format(key_prefix)
|
||||
if heatmap_key in state_dict_keys:
|
||||
unet_config["heatmap_head"] = True
|
||||
|
||||
return unet_config
|
||||
|
||||
def model_config_from_unet_config(unet_config, state_dict=None):
|
||||
@@ -1012,7 +1045,7 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None):
|
||||
|
||||
LotusD = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'adm_in_channels': 4,
|
||||
'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0],
|
||||
'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024, 'num_heads': 8,
|
||||
'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024, 'num_head_channels': 64,
|
||||
'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
|
||||
'use_temporal_attention': False, 'use_temporal_resblock': False}
|
||||
|
||||
@@ -1044,6 +1077,13 @@ def convert_diffusers_mmdit(state_dict, output_prefix=""):
|
||||
elif 'adaln_single.emb.timestep_embedder.linear_1.bias' in state_dict and 'pos_embed.proj.bias' in state_dict: # PixArt
|
||||
num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.')
|
||||
sd_map = comfy.utils.pixart_to_diffusers({"depth": num_blocks}, output_prefix=output_prefix)
|
||||
elif 'noise_refiner.0.attention.norm_k.weight' in state_dict:
|
||||
n_layers = count_blocks(state_dict, 'layers.{}.')
|
||||
dim = state_dict['noise_refiner.0.attention.to_k.weight'].shape[0]
|
||||
sd_map = comfy.utils.z_image_to_diffusers({"n_layers": n_layers, "dim": dim}, output_prefix=output_prefix)
|
||||
for k in state_dict: # For zeta chroma
|
||||
if k not in sd_map:
|
||||
sd_map[k] = k
|
||||
elif 'x_embedder.weight' in state_dict: #Flux
|
||||
depth = count_blocks(state_dict, 'transformer_blocks.{}.')
|
||||
depth_single_blocks = count_blocks(state_dict, 'single_transformer_blocks.{}.')
|
||||
|
||||
@@ -32,9 +32,6 @@ import comfy.memory_management
|
||||
import comfy.utils
|
||||
import comfy.quant_ops
|
||||
|
||||
import comfy_aimdo.torch
|
||||
import comfy_aimdo.model_vbar
|
||||
|
||||
class VRAMState(Enum):
|
||||
DISABLED = 0 #No vram present: no need to move models to vram
|
||||
NO_VRAM = 1 #Very low vram: enable all the options to save vram
|
||||
@@ -180,6 +177,14 @@ def is_ixuca():
|
||||
return True
|
||||
return False
|
||||
|
||||
def is_wsl():
|
||||
version = platform.uname().release
|
||||
if version.endswith("-Microsoft"):
|
||||
return True
|
||||
elif version.endswith("microsoft-standard-WSL2"):
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_torch_device():
|
||||
global directml_enabled
|
||||
global cpu_state
|
||||
@@ -350,7 +355,7 @@ AMD_ENABLE_MIOPEN_ENV = 'COMFYUI_ENABLE_MIOPEN'
|
||||
|
||||
try:
|
||||
if is_amd():
|
||||
arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName
|
||||
arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName.split(':')[0]
|
||||
if not (any((a in arch) for a in AMD_RDNA2_AND_OLDER_ARCH)):
|
||||
if os.getenv(AMD_ENABLE_MIOPEN_ENV) != '1':
|
||||
torch.backends.cudnn.enabled = False # Seems to improve things a lot on AMD
|
||||
@@ -378,7 +383,7 @@ try:
|
||||
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
|
||||
if aotriton_supported(arch): # AMD efficient attention implementation depends on aotriton.
|
||||
if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much
|
||||
if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950
|
||||
if any((a in arch) for a in ["gfx90a", "gfx942", "gfx950", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950
|
||||
ENABLE_PYTORCH_ATTENTION = True
|
||||
if rocm_version >= (7, 0):
|
||||
if any((a in arch) for a in ["gfx1200", "gfx1201"]):
|
||||
@@ -631,12 +636,11 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_
|
||||
if not DISABLE_SMART_MEMORY:
|
||||
memory_to_free = memory_required - get_free_memory(device)
|
||||
ram_to_free = ram_required - get_free_ram()
|
||||
|
||||
if current_loaded_models[i].model.is_dynamic() and for_dynamic:
|
||||
#don't actually unload dynamic models for the sake of other dynamic models
|
||||
#as that works on-demand.
|
||||
memory_required -= current_loaded_models[i].model.loaded_size()
|
||||
memory_to_free = 0
|
||||
if current_loaded_models[i].model.is_dynamic() and for_dynamic:
|
||||
#don't actually unload dynamic models for the sake of other dynamic models
|
||||
#as that works on-demand.
|
||||
memory_required -= current_loaded_models[i].model.loaded_size()
|
||||
memory_to_free = 0
|
||||
if memory_to_free > 0 and current_loaded_models[i].model_unload(memory_to_free):
|
||||
logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
|
||||
unloaded_model.append(i)
|
||||
@@ -792,6 +796,8 @@ def archive_model_dtypes(model):
|
||||
for name, module in model.named_modules():
|
||||
for param_name, param in module.named_parameters(recurse=False):
|
||||
setattr(module, f"{param_name}_comfy_model_dtype", param.dtype)
|
||||
for buf_name, buf in module.named_buffers(recurse=False):
|
||||
setattr(module, f"{buf_name}_comfy_model_dtype", buf.dtype)
|
||||
|
||||
|
||||
def cleanup_models():
|
||||
@@ -824,11 +830,14 @@ def unet_offload_device():
|
||||
return torch.device("cpu")
|
||||
|
||||
def unet_inital_load_device(parameters, dtype):
|
||||
cpu_dev = torch.device("cpu")
|
||||
if comfy.memory_management.aimdo_enabled:
|
||||
return cpu_dev
|
||||
|
||||
torch_dev = get_torch_device()
|
||||
if vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.SHARED:
|
||||
return torch_dev
|
||||
|
||||
cpu_dev = torch.device("cpu")
|
||||
if DISABLE_SMART_MEMORY or vram_state == VRAMState.NO_VRAM:
|
||||
return cpu_dev
|
||||
|
||||
@@ -836,7 +845,7 @@ def unet_inital_load_device(parameters, dtype):
|
||||
|
||||
mem_dev = get_free_memory(torch_dev)
|
||||
mem_cpu = get_free_memory(cpu_dev)
|
||||
if mem_dev > mem_cpu and model_size < mem_dev and comfy.memory_management.aimdo_allocator is None:
|
||||
if mem_dev > mem_cpu and model_size < mem_dev:
|
||||
return torch_dev
|
||||
else:
|
||||
return cpu_dev
|
||||
@@ -930,7 +939,7 @@ def text_encoder_offload_device():
|
||||
def text_encoder_device():
|
||||
if args.gpu_only:
|
||||
return get_torch_device()
|
||||
elif vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.NORMAL_VRAM:
|
||||
elif vram_state in (VRAMState.HIGH_VRAM, VRAMState.NORMAL_VRAM) or comfy.memory_management.aimdo_enabled:
|
||||
if should_use_fp16(prioritize_performance=False):
|
||||
return get_torch_device()
|
||||
else:
|
||||
@@ -939,6 +948,9 @@ def text_encoder_device():
|
||||
return torch.device("cpu")
|
||||
|
||||
def text_encoder_initial_device(load_device, offload_device, model_size=0):
|
||||
if comfy.memory_management.aimdo_enabled:
|
||||
return offload_device
|
||||
|
||||
if load_device == offload_device or model_size <= 1024 * 1024 * 1024:
|
||||
return offload_device
|
||||
|
||||
@@ -1121,7 +1133,6 @@ def get_cast_buffer(offload_stream, device, size, ref):
|
||||
synchronize()
|
||||
del STREAM_CAST_BUFFERS[offload_stream]
|
||||
del cast_buffer
|
||||
#FIXME: This doesn't work in Aimdo because mempool cant clear cache
|
||||
soft_empty_cache()
|
||||
with wf_context:
|
||||
cast_buffer = torch.empty((size), dtype=torch.int8, device=device)
|
||||
@@ -1137,6 +1148,7 @@ def reset_cast_buffers():
|
||||
LARGEST_CASTED_WEIGHT = (None, 0)
|
||||
for offload_stream in STREAM_CAST_BUFFERS:
|
||||
offload_stream.synchronize()
|
||||
synchronize()
|
||||
STREAM_CAST_BUFFERS.clear()
|
||||
soft_empty_cache()
|
||||
|
||||
@@ -1200,43 +1212,6 @@ def cast_to_gathered(tensors, r, non_blocking=False, stream=None):
|
||||
|
||||
|
||||
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None, r=None):
|
||||
if hasattr(weight, "_v"):
|
||||
#Unexpected usage patterns. There is no reason these don't work but they
|
||||
#have no testing and no callers do this.
|
||||
assert r is None
|
||||
assert stream is None
|
||||
|
||||
cast_geometry = comfy.memory_management.tensors_to_geometries([ weight ])
|
||||
|
||||
if dtype is None:
|
||||
dtype = weight._model_dtype
|
||||
|
||||
signature = comfy_aimdo.model_vbar.vbar_fault(weight._v)
|
||||
if signature is not None:
|
||||
if comfy_aimdo.model_vbar.vbar_signature_compare(signature, weight._v_signature):
|
||||
v_tensor = weight._v_tensor
|
||||
else:
|
||||
raw_tensor = comfy_aimdo.torch.aimdo_to_tensor(weight._v, device)
|
||||
v_tensor = comfy.memory_management.interpret_gathered_like(cast_geometry, raw_tensor)[0]
|
||||
weight._v_tensor = v_tensor
|
||||
weight._v_signature = signature
|
||||
#Send it over
|
||||
v_tensor.copy_(weight, non_blocking=non_blocking)
|
||||
return v_tensor.to(dtype=dtype)
|
||||
|
||||
r = torch.empty_like(weight, dtype=dtype, device=device)
|
||||
|
||||
if weight.dtype != r.dtype and weight.dtype != weight._model_dtype:
|
||||
#Offloaded casting could skip this, however it would make the quantizations
|
||||
#inconsistent between loaded and offloaded weights. So force the double casting
|
||||
#that would happen in regular flow to make offload deterministic.
|
||||
cast_buffer = torch.empty_like(weight, dtype=weight._model_dtype, device=device)
|
||||
cast_buffer.copy_(weight, non_blocking=non_blocking)
|
||||
weight = cast_buffer
|
||||
r.copy_(weight, non_blocking=non_blocking)
|
||||
|
||||
return r
|
||||
|
||||
if device is None or weight.device == device:
|
||||
if not copy:
|
||||
if dtype is None or weight.dtype == dtype:
|
||||
@@ -1692,12 +1667,16 @@ def lora_compute_dtype(device):
|
||||
return dtype
|
||||
|
||||
def synchronize():
|
||||
if cpu_mode():
|
||||
return
|
||||
if is_intel_xpu():
|
||||
torch.xpu.synchronize()
|
||||
elif torch.cuda.is_available():
|
||||
torch.cuda.synchronize()
|
||||
|
||||
def soft_empty_cache(force=False):
|
||||
if cpu_mode():
|
||||
return
|
||||
global cpu_state
|
||||
if cpu_state == CPUState.MPS:
|
||||
torch.mps.empty_cache()
|
||||
|
||||
@@ -241,6 +241,7 @@ class ModelPatcher:
|
||||
|
||||
self.patches = {}
|
||||
self.backup = {}
|
||||
self.backup_buffers = {}
|
||||
self.object_patches = {}
|
||||
self.object_patches_backup = {}
|
||||
self.weight_wrapper_patches = {}
|
||||
@@ -271,6 +272,7 @@ class ModelPatcher:
|
||||
self.is_clip = False
|
||||
self.hook_mode = comfy.hooks.EnumHookMode.MaxSpeed
|
||||
|
||||
self.cached_patcher_init: tuple[Callable, tuple] | None = None
|
||||
if not hasattr(self.model, 'model_loaded_weight_memory'):
|
||||
self.model.model_loaded_weight_memory = 0
|
||||
|
||||
@@ -305,10 +307,30 @@ class ModelPatcher:
|
||||
return self.model.lowvram_patch_counter
|
||||
|
||||
def get_free_memory(self, device):
|
||||
return comfy.model_management.get_free_memory(device)
|
||||
#Prioritize batching (incl. CFG/conds etc) over keeping the model resident. In
|
||||
#the vast majority of setups a little bit of offloading on the giant model more
|
||||
#than pays for CFG. So return everything both torch and Aimdo could give us
|
||||
aimdo_mem = 0
|
||||
if comfy.memory_management.aimdo_enabled:
|
||||
aimdo_mem = comfy_aimdo.model_vbar.vbars_analyze()
|
||||
return comfy.model_management.get_free_memory(device) + aimdo_mem
|
||||
|
||||
def clone(self):
|
||||
n = self.__class__(self.model, self.load_device, self.offload_device, self.model_size(), weight_inplace_update=self.weight_inplace_update)
|
||||
def get_clone_model_override(self):
|
||||
return self.model, (self.backup, self.backup_buffers, self.object_patches_backup, self.pinned)
|
||||
|
||||
def clone(self, disable_dynamic=False, model_override=None):
|
||||
class_ = self.__class__
|
||||
if self.is_dynamic() and disable_dynamic:
|
||||
class_ = ModelPatcher
|
||||
if model_override is None:
|
||||
if self.cached_patcher_init is None:
|
||||
raise RuntimeError("Cannot create non-dynamic delegate: cached_patcher_init is not initialized.")
|
||||
temp_model_patcher = self.cached_patcher_init[0](*self.cached_patcher_init[1], disable_dynamic=True)
|
||||
model_override = temp_model_patcher.get_clone_model_override()
|
||||
if model_override is None:
|
||||
model_override = self.get_clone_model_override()
|
||||
|
||||
n = class_(model_override[0], self.load_device, self.offload_device, self.model_size(), weight_inplace_update=self.weight_inplace_update)
|
||||
n.patches = {}
|
||||
for k in self.patches:
|
||||
n.patches[k] = self.patches[k][:]
|
||||
@@ -317,13 +339,12 @@ class ModelPatcher:
|
||||
n.object_patches = self.object_patches.copy()
|
||||
n.weight_wrapper_patches = self.weight_wrapper_patches.copy()
|
||||
n.model_options = comfy.utils.deepcopy_list_dict(self.model_options)
|
||||
n.backup = self.backup
|
||||
n.object_patches_backup = self.object_patches_backup
|
||||
n.parent = self
|
||||
n.pinned = self.pinned
|
||||
|
||||
n.force_cast_weights = self.force_cast_weights
|
||||
|
||||
n.backup, n.backup_buffers, n.object_patches_backup, n.pinned = model_override[1]
|
||||
|
||||
# attachments
|
||||
n.attachments = {}
|
||||
for k in self.attachments:
|
||||
@@ -362,6 +383,8 @@ class ModelPatcher:
|
||||
n.is_clip = self.is_clip
|
||||
n.hook_mode = self.hook_mode
|
||||
|
||||
n.cached_patcher_init = self.cached_patcher_init
|
||||
|
||||
for callback in self.get_all_callbacks(CallbacksMP.ON_CLONE):
|
||||
callback(self, n)
|
||||
return n
|
||||
@@ -682,7 +705,7 @@ class ModelPatcher:
|
||||
for key in list(self.pinned):
|
||||
self.unpin_weight(key)
|
||||
|
||||
def _load_list(self, prio_comfy_cast_weights=False, default_device=None):
|
||||
def _load_list(self, for_dynamic=False, default_device=None):
|
||||
loading = []
|
||||
for n, m in self.model.named_modules():
|
||||
default = False
|
||||
@@ -692,8 +715,8 @@ class ModelPatcher:
|
||||
default = True # default random weights in non leaf modules
|
||||
break
|
||||
if default and default_device is not None:
|
||||
for param in params.values():
|
||||
param.data = param.data.to(device=default_device)
|
||||
for param_name, param in params.items():
|
||||
param.data = param.data.to(device=default_device, dtype=getattr(m, param_name + "_comfy_model_dtype", None))
|
||||
if not default and (hasattr(m, "comfy_cast_weights") or len(params) > 0):
|
||||
module_mem = comfy.model_management.module_size(m)
|
||||
module_offload_mem = module_mem
|
||||
@@ -710,8 +733,13 @@ class ModelPatcher:
|
||||
return 0
|
||||
module_offload_mem += check_module_offload_mem("{}.weight".format(n))
|
||||
module_offload_mem += check_module_offload_mem("{}.bias".format(n))
|
||||
prepend = (not hasattr(m, "comfy_cast_weights"),) if prio_comfy_cast_weights else ()
|
||||
loading.append(prepend + (module_offload_mem, module_mem, n, m, params))
|
||||
# Dynamic: small weights (<64KB) first, then larger weights prioritized by size.
|
||||
# Non-dynamic: prioritize by module offload cost.
|
||||
if for_dynamic:
|
||||
sort_criteria = (module_offload_mem >= 64 * 1024, -module_offload_mem)
|
||||
else:
|
||||
sort_criteria = (module_offload_mem,)
|
||||
loading.append(sort_criteria + (module_mem, n, m, params))
|
||||
return loading
|
||||
|
||||
def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
|
||||
@@ -1419,12 +1447,9 @@ class ModelPatcherDynamic(ModelPatcher):
|
||||
|
||||
def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
|
||||
super().__init__(model, load_device, offload_device, size, weight_inplace_update)
|
||||
#this is now way more dynamic and we dont support the same base model for both Dynamic
|
||||
#and non-dynamic patchers.
|
||||
if hasattr(self.model, "model_loaded_weight_memory"):
|
||||
del self.model.model_loaded_weight_memory
|
||||
if not hasattr(self.model, "dynamic_vbars"):
|
||||
self.model.dynamic_vbars = {}
|
||||
self.non_dynamic_delegate_model = None
|
||||
assert load_device is not None
|
||||
|
||||
def is_dynamic(self):
|
||||
@@ -1444,15 +1469,7 @@ class ModelPatcherDynamic(ModelPatcher):
|
||||
|
||||
def loaded_size(self):
|
||||
vbar = self._vbar_get()
|
||||
if vbar is None:
|
||||
return 0
|
||||
return vbar.loaded_size()
|
||||
|
||||
def get_free_memory(self, device):
|
||||
#NOTE: on high condition / batch counts, estimate should have already vacated
|
||||
#all non-dynamic models so this is safe even if its not 100% true that this
|
||||
#would all be avaiable for inference use.
|
||||
return comfy.model_management.get_total_memory(device) - self.model_size()
|
||||
return (vbar.loaded_size() if vbar is not None else 0) + self.model.model_loaded_weight_memory
|
||||
|
||||
#Pinning is deferred to ops time. Assert against this API to avoid pin leaks.
|
||||
|
||||
@@ -1487,6 +1504,7 @@ class ModelPatcherDynamic(ModelPatcher):
|
||||
|
||||
num_patches = 0
|
||||
allocated_size = 0
|
||||
self.model.model_loaded_weight_memory = 0
|
||||
|
||||
with self.use_ejected():
|
||||
self.unpatch_hooks()
|
||||
@@ -1495,15 +1513,11 @@ class ModelPatcherDynamic(ModelPatcher):
|
||||
if vbar is not None:
|
||||
vbar.prioritize()
|
||||
|
||||
#We force reserve VRAM for the non comfy-weight so we dont have to deal
|
||||
#with pin and unpin syncrhonization which can be expensive for small weights
|
||||
#with a high layer rate (e.g. autoregressive LLMs).
|
||||
#prioritize the non-comfy weights (note the order reverse).
|
||||
loading = self._load_list(prio_comfy_cast_weights=True, default_device=device_to)
|
||||
loading.sort(reverse=True)
|
||||
loading = self._load_list(for_dynamic=True, default_device=device_to)
|
||||
loading.sort()
|
||||
|
||||
for x in loading:
|
||||
_, _, _, n, m, params = x
|
||||
*_, module_mem, n, m, params = x
|
||||
|
||||
def set_dirty(item, dirty):
|
||||
if dirty or not hasattr(item, "_v_signature"):
|
||||
@@ -1541,6 +1555,9 @@ class ModelPatcherDynamic(ModelPatcher):
|
||||
if key in self.backup:
|
||||
comfy.utils.set_attr_param(self.model, key, self.backup[key].weight)
|
||||
self.patch_weight_to_device(key, device_to=device_to)
|
||||
weight, _, _ = get_key_weight(self.model, key)
|
||||
if weight is not None:
|
||||
self.model.model_loaded_weight_memory += weight.numel() * weight.element_size()
|
||||
|
||||
if hasattr(m, "comfy_cast_weights"):
|
||||
m.comfy_cast_weights = True
|
||||
@@ -1566,21 +1583,26 @@ class ModelPatcherDynamic(ModelPatcher):
|
||||
for param in params:
|
||||
key = key_param_name_to_key(n, param)
|
||||
weight, _, _ = get_key_weight(self.model, key)
|
||||
weight.seed_key = key
|
||||
set_dirty(weight, dirty)
|
||||
geometry = weight
|
||||
model_dtype = getattr(m, param + "_comfy_model_dtype", None) or weight.dtype
|
||||
geometry = comfy.memory_management.TensorGeometry(shape=weight.shape, dtype=model_dtype)
|
||||
weight_size = geometry.numel() * geometry.element_size()
|
||||
if vbar is not None and not hasattr(weight, "_v"):
|
||||
weight._v = vbar.alloc(weight_size)
|
||||
weight._model_dtype = model_dtype
|
||||
allocated_size += weight_size
|
||||
vbar.set_watermark_limit(allocated_size)
|
||||
if key not in self.backup:
|
||||
self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight, False)
|
||||
model_dtype = getattr(m, param + "_comfy_model_dtype", None)
|
||||
casted_weight = weight.to(dtype=model_dtype, device=device_to)
|
||||
comfy.utils.set_attr_param(self.model, key, casted_weight)
|
||||
self.model.model_loaded_weight_memory += casted_weight.numel() * casted_weight.element_size()
|
||||
|
||||
move_weight_functions(m, device_to)
|
||||
|
||||
logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.")
|
||||
for key, buf in self.model.named_buffers(recurse=True):
|
||||
if key not in self.backup_buffers:
|
||||
self.backup_buffers[key] = buf
|
||||
module, buf_name = comfy.utils.resolve_attr(self.model, key)
|
||||
model_dtype = getattr(module, buf_name + "_comfy_model_dtype", None)
|
||||
casted_buf = buf.to(dtype=model_dtype, device=device_to)
|
||||
comfy.utils.set_attr_buffer(self.model, key, casted_buf)
|
||||
self.model.model_loaded_weight_memory += casted_buf.numel() * casted_buf.element_size()
|
||||
|
||||
force_load_stat = f" Force pre-loaded {len(self.backup)} weights: {self.model.model_loaded_weight_memory // 1024} KB." if len(self.backup) > 0 else ""
|
||||
logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.{force_load_stat}")
|
||||
|
||||
self.model.device = device_to
|
||||
self.model.current_weight_patches_uuid = self.patches_uuid
|
||||
@@ -1596,12 +1618,23 @@ class ModelPatcherDynamic(ModelPatcher):
|
||||
assert self.load_device != torch.device("cpu")
|
||||
|
||||
vbar = self._vbar_get()
|
||||
return 0 if vbar is None else vbar.free_memory(memory_to_free)
|
||||
freed = 0 if vbar is None else vbar.free_memory(memory_to_free)
|
||||
|
||||
if freed < memory_to_free:
|
||||
for key in list(self.backup.keys()):
|
||||
bk = self.backup.pop(key)
|
||||
comfy.utils.set_attr_param(self.model, key, bk.weight)
|
||||
for key in list(self.backup_buffers.keys()):
|
||||
comfy.utils.set_attr_buffer(self.model, key, self.backup_buffers.pop(key))
|
||||
freed += self.model.model_loaded_weight_memory
|
||||
self.model.model_loaded_weight_memory = 0
|
||||
|
||||
return freed
|
||||
|
||||
def partially_unload_ram(self, ram_to_unload):
|
||||
loading = self._load_list(prio_comfy_cast_weights=True, default_device=self.offload_device)
|
||||
loading = self._load_list(for_dynamic=True, default_device=self.offload_device)
|
||||
for x in loading:
|
||||
_, _, _, _, m, _ = x
|
||||
*_, m, _ = x
|
||||
ram_to_unload -= comfy.pinned_memory.unpin_memory(m)
|
||||
if ram_to_unload <= 0:
|
||||
return
|
||||
@@ -1623,11 +1656,6 @@ class ModelPatcherDynamic(ModelPatcher):
|
||||
for m in self.model.modules():
|
||||
move_weight_functions(m, device_to)
|
||||
|
||||
keys = list(self.backup.keys())
|
||||
for k in keys:
|
||||
bk = self.backup[k]
|
||||
comfy.utils.set_attr_param(self.model, k, bk.weight)
|
||||
|
||||
def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
|
||||
assert not force_patch_weights #See above
|
||||
with self.use_ejected(skip_and_inject_on_exit_only=True):
|
||||
@@ -1659,4 +1687,10 @@ class ModelPatcherDynamic(ModelPatcher):
|
||||
def unpatch_hooks(self, whitelist_keys_set: set[str]=None) -> None:
|
||||
pass
|
||||
|
||||
def get_non_dynamic_delegate(self):
|
||||
model_patcher = self.clone(disable_dynamic=True, model_override=self.non_dynamic_delegate_model)
|
||||
self.non_dynamic_delegate_model = model_patcher.get_clone_model_override()
|
||||
return model_patcher
|
||||
|
||||
|
||||
CoreModelPatcher = ModelPatcher
|
||||
|
||||
@@ -83,6 +83,16 @@ class IMG_TO_IMG(X0):
|
||||
def calculate_input(self, sigma, noise):
|
||||
return noise
|
||||
|
||||
class IMG_TO_IMG_FLOW(CONST):
|
||||
def calculate_denoised(self, sigma, model_output, model_input):
|
||||
return model_output
|
||||
|
||||
def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
|
||||
return latent_image
|
||||
|
||||
def inverse_noise_scaling(self, sigma, latent):
|
||||
return 1.0 - latent
|
||||
|
||||
class COSMOS_RFLOW:
|
||||
def calculate_input(self, sigma, noise):
|
||||
sigma = (sigma / (sigma + 1))
|
||||
|
||||
66
comfy/ops.py
66
comfy/ops.py
@@ -19,7 +19,7 @@
|
||||
import torch
|
||||
import logging
|
||||
import comfy.model_management
|
||||
from comfy.cli_args import args, PerformanceFeature, enables_dynamic_vram
|
||||
from comfy.cli_args import args, PerformanceFeature
|
||||
import comfy.float
|
||||
import json
|
||||
import comfy.memory_management
|
||||
@@ -80,6 +80,21 @@ def cast_to_input(weight, input, non_blocking=False, copy=True):
|
||||
|
||||
|
||||
def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant):
|
||||
|
||||
#vbar doesn't support CPU weights, but some custom nodes have weird paths
|
||||
#that might switch the layer to the CPU and expect it to work. We have to take
|
||||
#a clone conservatively as we are mmapped and some SFT files are packed misaligned
|
||||
#If you are a custom node author reading this, please move your layer to the GPU
|
||||
#or declare your ModelPatcher as CPU in the first place.
|
||||
if comfy.model_management.is_device_cpu(device):
|
||||
weight = s.weight.to(dtype=dtype, copy=True)
|
||||
if isinstance(weight, QuantizedTensor):
|
||||
weight = weight.dequantize()
|
||||
bias = None
|
||||
if s.bias is not None:
|
||||
bias = s.bias.to(dtype=bias_dtype, copy=True)
|
||||
return weight, bias, (None, None, None)
|
||||
|
||||
offload_stream = None
|
||||
xfer_dest = None
|
||||
|
||||
@@ -167,17 +182,15 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
|
||||
x = to_dequant(x, dtype)
|
||||
if not resident and lowvram_fn is not None:
|
||||
x = to_dequant(x, dtype if compute_dtype is None else compute_dtype)
|
||||
#FIXME: this is not accurate, we need to be sensitive to the compute dtype
|
||||
x = lowvram_fn(x)
|
||||
if (isinstance(orig, QuantizedTensor) and
|
||||
(want_requant and len(fns) == 0 or update_weight)):
|
||||
if (want_requant and len(fns) == 0 or update_weight):
|
||||
seed = comfy.utils.string_to_seed(s.seed_key)
|
||||
y = QuantizedTensor.from_float(x, s.layout_type, scale="recalculate", stochastic_rounding=seed)
|
||||
if want_requant and len(fns) == 0:
|
||||
#The layer actually wants our freshly saved QT
|
||||
x = y
|
||||
elif update_weight:
|
||||
y = comfy.float.stochastic_rounding(x, orig.dtype, seed = comfy.utils.string_to_seed(s.seed_key))
|
||||
if isinstance(orig, QuantizedTensor):
|
||||
y = QuantizedTensor.from_float(x, s.layout_type, scale="recalculate", stochastic_rounding=seed)
|
||||
else:
|
||||
y = comfy.float.stochastic_rounding(x, orig.dtype, seed=seed)
|
||||
if want_requant and len(fns) == 0:
|
||||
x = y
|
||||
if update_weight:
|
||||
orig.copy_(y)
|
||||
for f in fns:
|
||||
@@ -271,8 +284,8 @@ def uncast_bias_weight(s, weight, bias, offload_stream):
|
||||
return
|
||||
os, weight_a, bias_a = offload_stream
|
||||
device=None
|
||||
#FIXME: This is not good RTTI
|
||||
if not isinstance(weight_a, torch.Tensor):
|
||||
#FIXME: This is really bad RTTI
|
||||
if weight_a is not None and not isinstance(weight_a, torch.Tensor):
|
||||
comfy_aimdo.model_vbar.vbar_unpin(s._v)
|
||||
device = weight_a
|
||||
if os is None:
|
||||
@@ -296,7 +309,7 @@ class disable_weight_init:
|
||||
class Linear(torch.nn.Linear, CastWeightBiasOp):
|
||||
|
||||
def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
|
||||
if not comfy.model_management.WINDOWS or not enables_dynamic_vram():
|
||||
if not comfy.model_management.WINDOWS or not comfy.memory_management.aimdo_enabled:
|
||||
super().__init__(in_features, out_features, bias, device, dtype)
|
||||
return
|
||||
|
||||
@@ -317,7 +330,7 @@ class disable_weight_init:
|
||||
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
|
||||
strict, missing_keys, unexpected_keys, error_msgs):
|
||||
|
||||
if not comfy.model_management.WINDOWS or not enables_dynamic_vram():
|
||||
if not comfy.model_management.WINDOWS or not comfy.memory_management.aimdo_enabled:
|
||||
return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
|
||||
missing_keys, unexpected_keys, error_msgs)
|
||||
assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False)
|
||||
@@ -617,7 +630,8 @@ def fp8_linear(self, input):
|
||||
|
||||
if input.ndim != 2:
|
||||
return None
|
||||
w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True)
|
||||
lora_compute_dtype=comfy.model_management.lora_compute_dtype(input.device)
|
||||
w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True, compute_dtype=lora_compute_dtype, want_requant=True)
|
||||
scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
|
||||
|
||||
scale_input = torch.ones((), device=input.device, dtype=torch.float32)
|
||||
@@ -661,23 +675,29 @@ class fp8_ops(manual_cast):
|
||||
|
||||
CUBLAS_IS_AVAILABLE = False
|
||||
try:
|
||||
from cublas_ops import CublasLinear
|
||||
from cublas_ops import CublasLinear, cublas_half_matmul
|
||||
CUBLAS_IS_AVAILABLE = True
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
if CUBLAS_IS_AVAILABLE:
|
||||
class cublas_ops(disable_weight_init):
|
||||
class Linear(CublasLinear, disable_weight_init.Linear):
|
||||
class cublas_ops(manual_cast):
|
||||
class Linear(CublasLinear, manual_cast.Linear):
|
||||
def reset_parameters(self):
|
||||
return None
|
||||
|
||||
def forward_comfy_cast_weights(self, input):
|
||||
return super().forward(input)
|
||||
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
||||
x = cublas_half_matmul(input, weight, bias, self._epilogue_str, self.has_bias)
|
||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return x
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return super().forward(*args, **kwargs)
|
||||
|
||||
run_every_op()
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||
else:
|
||||
return super().forward(*args, **kwargs)
|
||||
|
||||
# ==============================================================================
|
||||
# Mixed Precision Operations
|
||||
@@ -827,6 +847,10 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
else:
|
||||
sd = {}
|
||||
|
||||
if not hasattr(self, 'weight'):
|
||||
logging.warning("Warning: state dict on uninitialized op {}".format(prefix))
|
||||
return sd
|
||||
|
||||
if self.bias is not None:
|
||||
sd["{}bias".format(prefix)] = self.bias
|
||||
|
||||
|
||||
@@ -66,6 +66,18 @@ def convert_cond(cond):
|
||||
out.append(temp)
|
||||
return out
|
||||
|
||||
def cond_has_hooks(cond):
|
||||
for c in cond:
|
||||
temp = c[1]
|
||||
if "hooks" in temp:
|
||||
return True
|
||||
if "control" in temp:
|
||||
control = temp["control"]
|
||||
extra_hooks = control.get_extra_hooks()
|
||||
if len(extra_hooks) > 0:
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_additional_models(conds, dtype):
|
||||
"""loads additional models in conditioning"""
|
||||
cnets: list[ControlBase] = []
|
||||
|
||||
@@ -946,6 +946,8 @@ class CFGGuider:
|
||||
|
||||
def inner_set_conds(self, conds):
|
||||
for k in conds:
|
||||
if self.model_patcher.is_dynamic() and comfy.sampler_helpers.cond_has_hooks(conds[k]):
|
||||
self.model_patcher = self.model_patcher.get_non_dynamic_delegate()
|
||||
self.original_conds[k] = comfy.sampler_helpers.convert_cond(conds[k])
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user