mirror of
https://github.com/Comfy-Org/ComfyUI.git
synced 2026-03-02 22:29:01 +00:00
Compare commits
16 Commits
node-essen
...
assets-api
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1ad4b76b55 | ||
|
|
facda426b4 | ||
|
|
65a5992f2d | ||
|
|
287da646e5 | ||
|
|
63f9f1b11b | ||
|
|
9e3f559189 | ||
|
|
63c98d0c75 | ||
|
|
e69a5aa1be | ||
|
|
e0c063f93e | ||
|
|
6db4f4e3f1 | ||
|
|
41d364030b | ||
|
|
fab9b71f5d | ||
|
|
e5c1de4777 | ||
|
|
a5ed151e51 | ||
|
|
e527b72b09 | ||
|
|
f14129947c |
@@ -1,14 +1,20 @@
|
||||
import logging
|
||||
import uuid
|
||||
import urllib.parse
|
||||
import os
|
||||
import contextlib
|
||||
from aiohttp import web
|
||||
|
||||
from pydantic import ValidationError
|
||||
|
||||
import app.assets.manager as manager
|
||||
import app.assets.scanner as scanner
|
||||
from app import user_manager
|
||||
from app.assets.api import schemas_in
|
||||
from app.assets.helpers import get_query_dict
|
||||
|
||||
import folder_paths
|
||||
|
||||
ROUTES = web.RouteTableDef()
|
||||
USER_MANAGER: user_manager.UserManager | None = None
|
||||
|
||||
@@ -28,6 +34,18 @@ def _validation_error_response(code: str, ve: ValidationError) -> web.Response:
|
||||
return _error_response(400, code, "Validation failed.", {"errors": ve.json()})
|
||||
|
||||
|
||||
@ROUTES.head("/api/assets/hash/{hash}")
|
||||
async def head_asset_by_hash(request: web.Request) -> web.Response:
|
||||
hash_str = request.match_info.get("hash", "").strip().lower()
|
||||
if not hash_str or ":" not in hash_str:
|
||||
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
|
||||
algo, digest = hash_str.split(":", 1)
|
||||
if algo != "blake3" or not digest or any(c for c in digest if c not in "0123456789abcdef"):
|
||||
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
|
||||
exists = manager.asset_exists(asset_hash=hash_str)
|
||||
return web.Response(status=200 if exists else 404)
|
||||
|
||||
|
||||
@ROUTES.get("/api/assets")
|
||||
async def list_assets(request: web.Request) -> web.Response:
|
||||
"""
|
||||
@@ -76,6 +94,321 @@ async def get_asset(request: web.Request) -> web.Response:
|
||||
return web.json_response(result.model_dump(mode="json"), status=200)
|
||||
|
||||
|
||||
@ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}/content")
|
||||
async def download_asset_content(request: web.Request) -> web.Response:
|
||||
# question: do we need disposition? could we just stick with one of these?
|
||||
disposition = request.query.get("disposition", "attachment").lower().strip()
|
||||
if disposition not in {"inline", "attachment"}:
|
||||
disposition = "attachment"
|
||||
|
||||
try:
|
||||
abs_path, content_type, filename = manager.resolve_asset_content_for_download(
|
||||
asset_info_id=str(uuid.UUID(request.match_info["id"])),
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
except ValueError as ve:
|
||||
return _error_response(404, "ASSET_NOT_FOUND", str(ve))
|
||||
except NotImplementedError as nie:
|
||||
return _error_response(501, "BACKEND_UNSUPPORTED", str(nie))
|
||||
except FileNotFoundError:
|
||||
return _error_response(404, "FILE_NOT_FOUND", "Underlying file not found on disk.")
|
||||
|
||||
quoted = (filename or "").replace("\r", "").replace("\n", "").replace('"', "'")
|
||||
cd = f'{disposition}; filename="{quoted}"; filename*=UTF-8\'\'{urllib.parse.quote(filename)}'
|
||||
|
||||
resp = web.FileResponse(abs_path)
|
||||
resp.content_type = content_type
|
||||
resp.headers["Content-Disposition"] = cd
|
||||
return resp
|
||||
|
||||
|
||||
@ROUTES.post("/api/assets/from-hash")
|
||||
async def create_asset_from_hash(request: web.Request) -> web.Response:
|
||||
try:
|
||||
payload = await request.json()
|
||||
body = schemas_in.CreateFromHashBody.model_validate(payload)
|
||||
except ValidationError as ve:
|
||||
return _validation_error_response("INVALID_BODY", ve)
|
||||
except Exception:
|
||||
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
|
||||
|
||||
result = manager.create_asset_from_hash(
|
||||
hash_str=body.hash,
|
||||
name=body.name,
|
||||
tags=body.tags,
|
||||
user_metadata=body.user_metadata,
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
if result is None:
|
||||
return _error_response(404, "ASSET_NOT_FOUND", f"Asset content {body.hash} does not exist")
|
||||
return web.json_response(result.model_dump(mode="json"), status=201)
|
||||
|
||||
|
||||
@ROUTES.post("/api/assets")
|
||||
async def upload_asset(request: web.Request) -> web.Response:
|
||||
"""Multipart/form-data endpoint for Asset uploads."""
|
||||
|
||||
if not (request.content_type or "").lower().startswith("multipart/"):
|
||||
return _error_response(415, "UNSUPPORTED_MEDIA_TYPE", "Use multipart/form-data for uploads.")
|
||||
|
||||
reader = await request.multipart()
|
||||
|
||||
file_present = False
|
||||
file_client_name: str | None = None
|
||||
tags_raw: list[str] = []
|
||||
provided_name: str | None = None
|
||||
user_metadata_raw: str | None = None
|
||||
provided_hash: str | None = None
|
||||
provided_hash_exists: bool | None = None
|
||||
|
||||
file_written = 0
|
||||
tmp_path: str | None = None
|
||||
while True:
|
||||
field = await reader.next()
|
||||
if field is None:
|
||||
break
|
||||
|
||||
fname = getattr(field, "name", "") or ""
|
||||
|
||||
if fname == "hash":
|
||||
try:
|
||||
s = ((await field.text()) or "").strip().lower()
|
||||
except Exception:
|
||||
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
|
||||
|
||||
if s:
|
||||
if ":" not in s:
|
||||
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
|
||||
algo, digest = s.split(":", 1)
|
||||
if algo != "blake3" or not digest or any(c for c in digest if c not in "0123456789abcdef"):
|
||||
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
|
||||
provided_hash = f"{algo}:{digest}"
|
||||
try:
|
||||
provided_hash_exists = manager.asset_exists(asset_hash=provided_hash)
|
||||
except Exception:
|
||||
provided_hash_exists = None # do not fail the whole request here
|
||||
|
||||
elif fname == "file":
|
||||
file_present = True
|
||||
file_client_name = (field.filename or "").strip()
|
||||
|
||||
if provided_hash and provided_hash_exists is True:
|
||||
# If client supplied a hash that we know exists, drain but do not write to disk
|
||||
try:
|
||||
while True:
|
||||
chunk = await field.read_chunk(8 * 1024 * 1024)
|
||||
if not chunk:
|
||||
break
|
||||
file_written += len(chunk)
|
||||
except Exception:
|
||||
return _error_response(500, "UPLOAD_IO_ERROR", "Failed to receive uploaded file.")
|
||||
continue # Do not create temp file; we will create AssetInfo from the existing content
|
||||
|
||||
# Otherwise, store to temp for hashing/ingest
|
||||
uploads_root = os.path.join(folder_paths.get_temp_directory(), "uploads")
|
||||
unique_dir = os.path.join(uploads_root, uuid.uuid4().hex)
|
||||
os.makedirs(unique_dir, exist_ok=True)
|
||||
tmp_path = os.path.join(unique_dir, ".upload.part")
|
||||
|
||||
try:
|
||||
with open(tmp_path, "wb") as f:
|
||||
while True:
|
||||
chunk = await field.read_chunk(8 * 1024 * 1024)
|
||||
if not chunk:
|
||||
break
|
||||
f.write(chunk)
|
||||
file_written += len(chunk)
|
||||
except Exception:
|
||||
try:
|
||||
if os.path.exists(tmp_path or ""):
|
||||
os.remove(tmp_path)
|
||||
finally:
|
||||
return _error_response(500, "UPLOAD_IO_ERROR", "Failed to receive and store uploaded file.")
|
||||
elif fname == "tags":
|
||||
tags_raw.append((await field.text()) or "")
|
||||
elif fname == "name":
|
||||
provided_name = (await field.text()) or None
|
||||
elif fname == "user_metadata":
|
||||
user_metadata_raw = (await field.text()) or None
|
||||
|
||||
# If client did not send file, and we are not doing a from-hash fast path -> error
|
||||
if not file_present and not (provided_hash and provided_hash_exists):
|
||||
return _error_response(400, "MISSING_FILE", "Form must include a 'file' part or a known 'hash'.")
|
||||
|
||||
if file_present and file_written == 0 and not (provided_hash and provided_hash_exists):
|
||||
# Empty upload is only acceptable if we are fast-pathing from existing hash
|
||||
try:
|
||||
if tmp_path and os.path.exists(tmp_path):
|
||||
os.remove(tmp_path)
|
||||
finally:
|
||||
return _error_response(400, "EMPTY_UPLOAD", "Uploaded file is empty.")
|
||||
|
||||
try:
|
||||
spec = schemas_in.UploadAssetSpec.model_validate({
|
||||
"tags": tags_raw,
|
||||
"name": provided_name,
|
||||
"user_metadata": user_metadata_raw,
|
||||
"hash": provided_hash,
|
||||
})
|
||||
except ValidationError as ve:
|
||||
try:
|
||||
if tmp_path and os.path.exists(tmp_path):
|
||||
os.remove(tmp_path)
|
||||
finally:
|
||||
return _validation_error_response("INVALID_BODY", ve)
|
||||
|
||||
# Validate models category against configured folders (consistent with previous behavior)
|
||||
if spec.tags and spec.tags[0] == "models":
|
||||
if len(spec.tags) < 2 or spec.tags[1] not in folder_paths.folder_names_and_paths:
|
||||
if tmp_path and os.path.exists(tmp_path):
|
||||
os.remove(tmp_path)
|
||||
return _error_response(
|
||||
400, "INVALID_BODY", f"unknown models category '{spec.tags[1] if len(spec.tags) >= 2 else ''}'"
|
||||
)
|
||||
|
||||
owner_id = USER_MANAGER.get_request_user_id(request)
|
||||
|
||||
# Fast path: if a valid provided hash exists, create AssetInfo without writing anything
|
||||
if spec.hash and provided_hash_exists is True:
|
||||
try:
|
||||
result = manager.create_asset_from_hash(
|
||||
hash_str=spec.hash,
|
||||
name=spec.name or (spec.hash.split(":", 1)[1]),
|
||||
tags=spec.tags,
|
||||
user_metadata=spec.user_metadata or {},
|
||||
owner_id=owner_id,
|
||||
)
|
||||
except Exception:
|
||||
logging.exception("create_asset_from_hash failed for hash=%s, owner_id=%s", spec.hash, owner_id)
|
||||
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
|
||||
if result is None:
|
||||
return _error_response(404, "ASSET_NOT_FOUND", f"Asset content {spec.hash} does not exist")
|
||||
|
||||
# Drain temp if we accidentally saved (e.g., hash field came after file)
|
||||
if tmp_path and os.path.exists(tmp_path):
|
||||
with contextlib.suppress(Exception):
|
||||
os.remove(tmp_path)
|
||||
|
||||
status = 200 if (not result.created_new) else 201
|
||||
return web.json_response(result.model_dump(mode="json"), status=status)
|
||||
|
||||
# Otherwise, we must have a temp file path to ingest
|
||||
if not tmp_path or not os.path.exists(tmp_path):
|
||||
# The only case we reach here without a temp file is: client sent a hash that does not exist and no file
|
||||
return _error_response(404, "ASSET_NOT_FOUND", "Provided hash not found and no file uploaded.")
|
||||
|
||||
try:
|
||||
created = manager.upload_asset_from_temp_path(
|
||||
spec,
|
||||
temp_path=tmp_path,
|
||||
client_filename=file_client_name,
|
||||
owner_id=owner_id,
|
||||
expected_asset_hash=spec.hash,
|
||||
)
|
||||
status = 201 if created.created_new else 200
|
||||
return web.json_response(created.model_dump(mode="json"), status=status)
|
||||
except ValueError as e:
|
||||
if tmp_path and os.path.exists(tmp_path):
|
||||
os.remove(tmp_path)
|
||||
msg = str(e)
|
||||
if "HASH_MISMATCH" in msg or msg.strip().upper() == "HASH_MISMATCH":
|
||||
return _error_response(
|
||||
400,
|
||||
"HASH_MISMATCH",
|
||||
"Uploaded file hash does not match provided hash.",
|
||||
)
|
||||
return _error_response(400, "BAD_REQUEST", "Invalid inputs.")
|
||||
except Exception:
|
||||
if tmp_path and os.path.exists(tmp_path):
|
||||
os.remove(tmp_path)
|
||||
logging.exception("upload_asset_from_temp_path failed for tmp_path=%s, owner_id=%s", tmp_path, owner_id)
|
||||
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
|
||||
|
||||
@ROUTES.put(f"/api/assets/{{id:{UUID_RE}}}")
|
||||
async def update_asset(request: web.Request) -> web.Response:
|
||||
asset_info_id = str(uuid.UUID(request.match_info["id"]))
|
||||
try:
|
||||
body = schemas_in.UpdateAssetBody.model_validate(await request.json())
|
||||
except ValidationError as ve:
|
||||
return _validation_error_response("INVALID_BODY", ve)
|
||||
except Exception:
|
||||
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
|
||||
|
||||
try:
|
||||
result = manager.update_asset(
|
||||
asset_info_id=asset_info_id,
|
||||
name=body.name,
|
||||
tags=body.tags,
|
||||
user_metadata=body.user_metadata,
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
except (ValueError, PermissionError) as ve:
|
||||
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
|
||||
except Exception:
|
||||
logging.exception(
|
||||
"update_asset failed for asset_info_id=%s, owner_id=%s",
|
||||
asset_info_id,
|
||||
USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
return web.json_response(result.model_dump(mode="json"), status=200)
|
||||
|
||||
|
||||
@ROUTES.put(f"/api/assets/{{id:{UUID_RE}}}/preview")
|
||||
async def set_asset_preview(request: web.Request) -> web.Response:
|
||||
asset_info_id = str(uuid.UUID(request.match_info["id"]))
|
||||
try:
|
||||
body = schemas_in.SetPreviewBody.model_validate(await request.json())
|
||||
except ValidationError as ve:
|
||||
return _validation_error_response("INVALID_BODY", ve)
|
||||
except Exception:
|
||||
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
|
||||
|
||||
try:
|
||||
result = manager.set_asset_preview(
|
||||
asset_info_id=asset_info_id,
|
||||
preview_asset_id=body.preview_id,
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
except (PermissionError, ValueError) as ve:
|
||||
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
|
||||
except Exception:
|
||||
logging.exception(
|
||||
"set_asset_preview failed for asset_info_id=%s, owner_id=%s",
|
||||
asset_info_id,
|
||||
USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
return web.json_response(result.model_dump(mode="json"), status=200)
|
||||
|
||||
|
||||
@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}")
|
||||
async def delete_asset(request: web.Request) -> web.Response:
|
||||
asset_info_id = str(uuid.UUID(request.match_info["id"]))
|
||||
delete_content = request.query.get("delete_content")
|
||||
delete_content = True if delete_content is None else delete_content.lower() not in {"0", "false", "no"}
|
||||
|
||||
try:
|
||||
deleted = manager.delete_asset_reference(
|
||||
asset_info_id=asset_info_id,
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
delete_content_if_orphan=delete_content,
|
||||
)
|
||||
except Exception:
|
||||
logging.exception(
|
||||
"delete_asset_reference failed for asset_info_id=%s, owner_id=%s",
|
||||
asset_info_id,
|
||||
USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
|
||||
if not deleted:
|
||||
return _error_response(404, "ASSET_NOT_FOUND", f"AssetInfo {asset_info_id} not found.")
|
||||
return web.Response(status=204)
|
||||
|
||||
|
||||
@ROUTES.get("/api/tags")
|
||||
async def get_tags(request: web.Request) -> web.Response:
|
||||
"""
|
||||
@@ -100,3 +433,83 @@ async def get_tags(request: web.Request) -> web.Response:
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
return web.json_response(result.model_dump(mode="json"))
|
||||
|
||||
@ROUTES.post(f"/api/assets/{{id:{UUID_RE}}}/tags")
|
||||
async def add_asset_tags(request: web.Request) -> web.Response:
|
||||
asset_info_id = str(uuid.UUID(request.match_info["id"]))
|
||||
try:
|
||||
payload = await request.json()
|
||||
data = schemas_in.TagsAdd.model_validate(payload)
|
||||
except ValidationError as ve:
|
||||
return _error_response(400, "INVALID_BODY", "Invalid JSON body for tags add.", {"errors": ve.errors()})
|
||||
except Exception:
|
||||
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
|
||||
|
||||
try:
|
||||
result = manager.add_tags_to_asset(
|
||||
asset_info_id=asset_info_id,
|
||||
tags=data.tags,
|
||||
origin="manual",
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
except (ValueError, PermissionError) as ve:
|
||||
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
|
||||
except Exception:
|
||||
logging.exception(
|
||||
"add_tags_to_asset failed for asset_info_id=%s, owner_id=%s",
|
||||
asset_info_id,
|
||||
USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
|
||||
return web.json_response(result.model_dump(mode="json"), status=200)
|
||||
|
||||
|
||||
@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}/tags")
|
||||
async def delete_asset_tags(request: web.Request) -> web.Response:
|
||||
asset_info_id = str(uuid.UUID(request.match_info["id"]))
|
||||
try:
|
||||
payload = await request.json()
|
||||
data = schemas_in.TagsRemove.model_validate(payload)
|
||||
except ValidationError as ve:
|
||||
return _error_response(400, "INVALID_BODY", "Invalid JSON body for tags remove.", {"errors": ve.errors()})
|
||||
except Exception:
|
||||
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
|
||||
|
||||
try:
|
||||
result = manager.remove_tags_from_asset(
|
||||
asset_info_id=asset_info_id,
|
||||
tags=data.tags,
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
except ValueError as ve:
|
||||
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
|
||||
except Exception:
|
||||
logging.exception(
|
||||
"remove_tags_from_asset failed for asset_info_id=%s, owner_id=%s",
|
||||
asset_info_id,
|
||||
USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
|
||||
return web.json_response(result.model_dump(mode="json"), status=200)
|
||||
|
||||
|
||||
@ROUTES.post("/api/assets/scan/seed")
|
||||
async def seed_assets(request: web.Request) -> web.Response:
|
||||
try:
|
||||
payload = await request.json()
|
||||
except Exception:
|
||||
payload = {}
|
||||
|
||||
try:
|
||||
body = schemas_in.ScheduleAssetScanBody.model_validate(payload)
|
||||
except ValidationError as ve:
|
||||
return _validation_error_response("INVALID_BODY", ve)
|
||||
|
||||
try:
|
||||
scanner.seed_assets(body.roots)
|
||||
except Exception:
|
||||
logging.exception("seed_assets failed for roots=%s", body.roots)
|
||||
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
return web.json_response({"synced": True, "roots": body.roots}, status=200)
|
||||
|
||||
@@ -8,8 +8,10 @@ from pydantic import (
|
||||
Field,
|
||||
conint,
|
||||
field_validator,
|
||||
model_validator,
|
||||
)
|
||||
|
||||
from app.assets.helpers import RootType
|
||||
|
||||
class ListAssetsQuery(BaseModel):
|
||||
include_tags: list[str] = Field(default_factory=list)
|
||||
@@ -57,6 +59,61 @@ class ListAssetsQuery(BaseModel):
|
||||
return None
|
||||
|
||||
|
||||
class UpdateAssetBody(BaseModel):
    """Partial-update body: at least one of name/tags/user_metadata must be set."""

    name: str | None = None
    tags: list[str] | None = None
    user_metadata: dict[str, Any] | None = None

    @model_validator(mode="after")
    def _at_least_one(self):
        """Reject an empty update and non-string tag entries."""
        if all(v is None for v in (self.name, self.tags, self.user_metadata)):
            raise ValueError("Provide at least one of: name, tags, user_metadata.")
        if self.tags is not None and (
            not isinstance(self.tags, list)
            or any(not isinstance(t, str) for t in self.tags)
        ):
            raise ValueError("Field 'tags' must be an array of strings.")
        return self
|
||||
|
||||
|
||||
class CreateFromHashBody(BaseModel):
    """Body for POST /api/assets/from-hash: reference existing content by hash."""

    model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)

    hash: str
    name: str
    tags: list[str] = Field(default_factory=list)
    user_metadata: dict[str, Any] = Field(default_factory=dict)

    @field_validator("hash")
    @classmethod
    def _require_blake3(cls, v):
        """Accept only canonical 'blake3:<lowercase hex>'; return it normalized."""
        s = (v or "").strip().lower()
        if ":" not in s:
            raise ValueError("hash must be 'blake3:<hex>'")
        algo, digest = s.split(":", 1)
        if algo != "blake3":
            raise ValueError("only canonical 'blake3:<hex>' is accepted here")
        if not digest or any(c not in "0123456789abcdef" for c in digest):
            raise ValueError("hash digest must be lowercase hex")
        return s

    @field_validator("tags", mode="before")
    @classmethod
    def _tags_norm(cls, v):
        """Coerce list or CSV string input to lowercase, stripped tag names.

        List input is additionally deduplicated while preserving order.
        """
        if v is None:
            return []
        if isinstance(v, list):
            cleaned = [str(t).strip().lower() for t in v if str(t).strip()]
            return list(dict.fromkeys(cleaned))  # order-preserving dedupe
        if isinstance(v, str):
            return [t.strip().lower() for t in v.split(",") if t.strip()]
        return []
|
||||
|
||||
|
||||
class TagsListQuery(BaseModel):
|
||||
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
|
||||
|
||||
@@ -75,6 +132,145 @@ class TagsListQuery(BaseModel):
|
||||
return v.lower() or None
|
||||
|
||||
|
||||
class TagsAdd(BaseModel):
    """Body for POST /api/assets/{id}/tags: a non-empty list of tag names."""

    model_config = ConfigDict(extra="ignore")
    tags: list[str] = Field(..., min_length=1)

    @field_validator("tags")
    @classmethod
    def normalize_tags(cls, v: list[str]) -> list[str]:
        """Strip/lowercase each tag, drop empties, dedupe preserving order."""
        cleaned: list[str] = []
        for raw in v:
            if not isinstance(raw, str):
                raise TypeError("tags must be strings")
            norm = raw.strip().lower()
            if norm:
                cleaned.append(norm)
        return list(dict.fromkeys(cleaned))  # order-preserving dedupe
|
||||
|
||||
|
||||
class TagsRemove(TagsAdd):
    """Body for DELETE /api/assets/{id}/tags; identical shape/normalization to TagsAdd."""
|
||||
|
||||
|
||||
class UploadAssetSpec(BaseModel):
    """Upload Asset operation.

    - tags: ordered; first is root ('models'|'input'|'output');
      if root == 'models', second must be a valid category from folder_paths.folder_names_and_paths
    - name: display name
    - user_metadata: arbitrary JSON object (optional)
    - hash: optional canonical 'blake3:<hex>' provided by the client for validation / fast-path

    Files created via this endpoint are stored on disk using the **content hash** as the filename stem
    and the original extension is preserved when available.
    """
    model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)

    tags: list[str] = Field(..., min_length=1)
    name: str | None = Field(default=None, max_length=512, description="Display Name")
    user_metadata: dict[str, Any] = Field(default_factory=dict)
    hash: str | None = Field(default=None)

    @field_validator("hash", mode="before")
    @classmethod
    def _parse_hash(cls, v):
        """Return canonical 'blake3:<hex>', None for empty input, or raise."""
        if v is None:
            return None
        candidate = str(v).strip().lower()
        if not candidate:
            return None
        if ":" not in candidate:
            raise ValueError("hash must be 'blake3:<hex>'")
        algo, digest = candidate.split(":", 1)
        if algo != "blake3":
            raise ValueError("only canonical 'blake3:<hex>' is accepted here")
        if not digest or any(c not in "0123456789abcdef" for c in digest):
            raise ValueError("hash digest must be lowercase hex")
        return f"{algo}:{digest}"

    @field_validator("tags", mode="before")
    @classmethod
    def _parse_tags(cls, v):
        """
        Accepts a list of strings (possibly multiple form fields),
        where each string can be:
          - JSON array (e.g., '["models","loras","foo"]')
          - comma-separated ('models, loras, foo')
          - single token ('models')
        Returns a normalized, deduplicated, ordered list.
        """
        if v is None:
            return []
        if isinstance(v, str):
            v = [v]
        if not isinstance(v, list):
            return []

        collected: list[str] = []
        for entry in v:
            if entry is None:
                continue
            text = str(entry).strip()
            if not text:
                continue
            if text.startswith("["):
                try:
                    parsed = json.loads(text)
                except Exception:
                    parsed = None  # malformed JSON -> fall through to CSV parse
                if isinstance(parsed, list):
                    collected.extend(str(x) for x in parsed)
                    continue
            collected.extend(part for part in text.split(",") if part.strip())

        # normalize + dedupe, preserving first-seen order
        seen: set[str] = set()
        ordered: list[str] = []
        for raw in collected:
            norm = str(raw).strip().lower()
            if norm and norm not in seen:
                seen.add(norm)
                ordered.append(norm)
        return ordered

    @field_validator("user_metadata", mode="before")
    @classmethod
    def _parse_metadata_json(cls, v):
        """Accept a dict, a JSON-object string, or None/empty; anything else -> {}."""
        if v is None or isinstance(v, dict):
            return v or {}
        if not isinstance(v, str):
            return {}
        text = v.strip()
        if not text:
            return {}
        try:
            parsed = json.loads(text)
        except Exception as e:
            raise ValueError(f"user_metadata must be JSON: {e}") from e
        if not isinstance(parsed, dict):
            raise ValueError("user_metadata must be a JSON object")
        return parsed

    @model_validator(mode="after")
    def _validate_order(self):
        """Enforce tag ordering: known root first; 'models' needs a category second."""
        if not self.tags:
            raise ValueError("tags must be provided and non-empty")
        root = self.tags[0]
        if root not in {"models", "input", "output"}:
            raise ValueError("first tag must be one of: models, input, output")
        if root == "models" and len(self.tags) < 2:
            raise ValueError("models uploads require a category tag as the second tag")
        return self
|
||||
|
||||
|
||||
class SetPreviewBody(BaseModel):
|
||||
"""Set or clear the preview for an AssetInfo. Provide an Asset.id or null."""
|
||||
preview_id: str | None = None
|
||||
@@ -92,3 +288,7 @@ class SetPreviewBody(BaseModel):
|
||||
except Exception:
|
||||
raise ValueError("preview_id must be a UUID")
|
||||
return s
|
||||
|
||||
|
||||
class ScheduleAssetScanBody(BaseModel):
    """Body for POST /api/assets/scan/seed: the non-empty list of roots to seed."""

    roots: list[RootType] = Field(..., min_length=1)
|
||||
|
||||
@@ -29,6 +29,21 @@ class AssetsList(BaseModel):
|
||||
has_more: bool
|
||||
|
||||
|
||||
class AssetUpdated(BaseModel):
    """Response payload describing an AssetInfo after a successful update."""

    model_config = ConfigDict(from_attributes=True)

    id: str
    name: str
    asset_hash: str | None = None
    tags: list[str] = Field(default_factory=list)
    user_metadata: dict[str, Any] = Field(default_factory=dict)
    updated_at: datetime | None = None

    @field_serializer("updated_at")
    def _ser_updated(self, v: datetime | None, _info):
        """Emit the timestamp as an ISO-8601 string, or None when unset."""
        return None if v is None else v.isoformat()
|
||||
|
||||
|
||||
class AssetDetail(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
@@ -48,6 +63,10 @@ class AssetDetail(BaseModel):
|
||||
return v.isoformat() if v else None
|
||||
|
||||
|
||||
class AssetCreated(AssetDetail):
    """AssetDetail plus a flag for whether new content was ingested."""

    # True when content was newly stored; False when an existing asset was reused.
    created_new: bool
|
||||
|
||||
|
||||
class TagUsage(BaseModel):
|
||||
name: str
|
||||
count: int
|
||||
@@ -58,3 +77,17 @@ class TagsList(BaseModel):
|
||||
tags: list[TagUsage] = Field(default_factory=list)
|
||||
total: int
|
||||
has_more: bool
|
||||
|
||||
|
||||
class TagsAdd(BaseModel):
    """Result of a tags-add operation: what changed and the final tag set."""

    model_config = ConfigDict(str_strip_whitespace=True)

    added: list[str] = Field(default_factory=list)
    already_present: list[str] = Field(default_factory=list)
    total_tags: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class TagsRemove(BaseModel):
    """Result of a tags-remove operation: what changed and the final tag set."""

    model_config = ConfigDict(str_strip_whitespace=True)

    removed: list[str] = Field(default_factory=list)
    not_present: list[str] = Field(default_factory=list)
    total_tags: list[str] = Field(default_factory=list)
|
||||
|
||||
@@ -1,9 +1,17 @@
|
||||
import os
|
||||
import logging
|
||||
import sqlalchemy as sa
|
||||
from collections import defaultdict
|
||||
from sqlalchemy import select, exists, func
|
||||
from datetime import datetime
|
||||
from typing import Iterable, Any
|
||||
from sqlalchemy import select, delete, exists, func
|
||||
from sqlalchemy.dialects import sqlite
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session, contains_eager, noload
|
||||
from app.assets.database.models import Asset, AssetInfo, AssetInfoMeta, AssetInfoTag, Tag
|
||||
from app.assets.helpers import escape_like_prefix, normalize_tags
|
||||
from app.assets.database.models import Asset, AssetInfo, AssetCacheState, AssetInfoMeta, AssetInfoTag, Tag
|
||||
from app.assets.helpers import (
|
||||
compute_relative_filename, escape_like_prefix, normalize_tags, project_kv, utcnow
|
||||
)
|
||||
from typing import Sequence
|
||||
|
||||
|
||||
@@ -15,6 +23,22 @@ def visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement:
|
||||
return AssetInfo.owner_id.in_(["", owner_id])
|
||||
|
||||
|
||||
def pick_best_live_path(states: Sequence[AssetCacheState]) -> str:
    """
    Return the best on-disk path among cache states:
      1) Prefer a path that exists with needs_verify == False (already verified).
      2) Otherwise, pick the first path that exists.
      3) Otherwise return empty string.
    """
    existing = [
        s for s in states
        if getattr(s, "file_path", None) and os.path.isfile(s.file_path)
    ]
    if not existing:
        return ""
    verified = next(
        (s.file_path for s in existing if not getattr(s, "needs_verify", False)),
        None,
    )
    return verified if verified is not None else existing[0].file_path
|
||||
|
||||
|
||||
def apply_tag_filters(
|
||||
stmt: sa.sql.Select,
|
||||
include_tags: Sequence[str] | None = None,
|
||||
@@ -42,6 +66,7 @@ def apply_tag_filters(
|
||||
)
|
||||
return stmt
|
||||
|
||||
|
||||
def apply_metadata_filter(
|
||||
stmt: sa.sql.Select,
|
||||
metadata_filter: dict | None = None,
|
||||
@@ -94,7 +119,11 @@ def apply_metadata_filter(
|
||||
return stmt
|
||||
|
||||
|
||||
def asset_exists_by_hash(session: Session, asset_hash: str) -> bool:
|
||||
def asset_exists_by_hash(
|
||||
session: Session,
|
||||
*,
|
||||
asset_hash: str,
|
||||
) -> bool:
|
||||
"""
|
||||
Check if an asset with a given hash exists in database.
|
||||
"""
|
||||
@@ -105,9 +134,39 @@ def asset_exists_by_hash(session: Session, asset_hash: str) -> bool:
|
||||
).first()
|
||||
return row is not None
|
||||
|
||||
def get_asset_info_by_id(session: Session, asset_info_id: str) -> AssetInfo | None:
|
||||
|
||||
def asset_info_exists_for_asset_id(
    session: Session,
    *,
    asset_id: str,
) -> bool:
    """Return True when at least one AssetInfo row references the given asset id."""
    stmt = (
        select(sa.literal(True))
        .select_from(AssetInfo)
        .where(AssetInfo.asset_id == asset_id)
        .limit(1)
    )
    return session.execute(stmt).first() is not None
|
||||
|
||||
|
||||
def get_asset_by_hash(
    session: Session,
    *,
    asset_hash: str,
) -> Asset | None:
    """Fetch the Asset row with the given content hash, or None when absent."""
    stmt = select(Asset).where(Asset.hash == asset_hash).limit(1)
    return session.execute(stmt).scalars().first()
|
||||
|
||||
|
||||
def get_asset_info_by_id(
    session: Session,
    *,
    asset_info_id: str,
) -> AssetInfo | None:
    """Primary-key lookup of an AssetInfo; None when no such row exists."""
    return session.get(AssetInfo, asset_info_id)
|
||||
|
||||
|
||||
def list_asset_infos_page(
|
||||
session: Session,
|
||||
owner_id: str = "",
|
||||
@@ -177,6 +236,7 @@ def list_asset_infos_page(
|
||||
|
||||
return infos, tag_map, total
|
||||
|
||||
|
||||
def fetch_asset_info_asset_and_tags(
|
||||
session: Session,
|
||||
asset_info_id: str,
|
||||
@@ -208,6 +268,494 @@ def fetch_asset_info_asset_and_tags(
|
||||
tags.append(tag_name)
|
||||
return first_info, first_asset, tags
|
||||
|
||||
|
||||
def fetch_asset_info_and_asset(
    session: Session,
    *,
    asset_info_id: str,
    owner_id: str = "",
) -> tuple[AssetInfo, Asset] | None:
    """Load one (AssetInfo, Asset) pair visible to *owner_id*, or None.

    Tag loading is suppressed (noload) because callers only need the two
    row objects here.
    """
    query = (
        select(AssetInfo, Asset)
        .join(Asset, Asset.id == AssetInfo.asset_id)
        .where(
            AssetInfo.id == asset_info_id,
            visible_owner_clause(owner_id),
        )
        .limit(1)
        .options(noload(AssetInfo.tags))
    )
    hit = session.execute(query).first()
    if hit is None:
        return None
    info, asset = hit
    return info, asset
|
||||
|
||||
def list_cache_states_by_asset_id(
    session: Session, *, asset_id: str
) -> Sequence[AssetCacheState]:
    """All AssetCacheState rows for *asset_id*, ordered by ascending id."""
    stmt = (
        select(AssetCacheState)
        .where(AssetCacheState.asset_id == asset_id)
        .order_by(AssetCacheState.id.asc())
    )
    return session.execute(stmt).scalars().all()
|
||||
|
||||
|
||||
def touch_asset_info_by_id(
    session: Session,
    *,
    asset_info_id: str,
    ts: datetime | None = None,
    only_if_newer: bool = True,
) -> None:
    """Stamp last_access_time on one AssetInfo row.

    With only_if_newer (default) the row is updated only when the stored
    timestamp is NULL or older than *ts*, so concurrent touches never move
    the access time backwards.
    """
    if ts is None:
        ts = utcnow()
    stmt = sa.update(AssetInfo).where(AssetInfo.id == asset_info_id)
    if only_if_newer:
        is_older = sa.or_(
            AssetInfo.last_access_time.is_(None),
            AssetInfo.last_access_time < ts,
        )
        stmt = stmt.where(is_older)
    session.execute(stmt.values(last_access_time=ts))
|
||||
|
||||
|
||||
def create_asset_info_for_existing_asset(
    session: Session,
    *,
    asset_hash: str,
    name: str,
    user_metadata: dict | None = None,
    tags: Sequence[str] | None = None,
    tag_origin: str = "manual",
    owner_id: str = "",
) -> AssetInfo:
    """Create or return an existing AssetInfo for an Asset identified by asset_hash.

    Raises ValueError when no Asset row exists for *asset_hash*. On a
    uniqueness conflict (same asset/name/owner) the pre-existing row is
    returned unchanged; otherwise metadata and tags are applied to the
    freshly created row.
    """
    now = utcnow()
    asset = get_asset_by_hash(session, asset_hash=asset_hash)
    if not asset:
        raise ValueError(f"Unknown asset hash {asset_hash}")

    info = AssetInfo(
        owner_id=owner_id,
        name=name,
        asset_id=asset.id,
        preview_id=None,
        created_at=now,
        updated_at=now,
        last_access_time=now,
    )
    try:
        # Savepoint so a unique-constraint violation does not poison the
        # outer transaction.
        with session.begin_nested():
            session.add(info)
            session.flush()
    except IntegrityError:
        # A row with the same (asset_id, name, owner_id) already exists:
        # return it instead of creating a duplicate.
        existing = (
            session.execute(
                select(AssetInfo)
                .options(noload(AssetInfo.tags))
                .where(
                    AssetInfo.asset_id == asset.id,
                    AssetInfo.name == name,
                    AssetInfo.owner_id == owner_id,
                )
                .limit(1)
            )
        ).unique().scalars().first()
        if not existing:
            raise RuntimeError("AssetInfo upsert failed to find existing row after conflict.")
        return existing

    # metadata["filename"] hack
    # Best-effort: derive a human-friendly relative filename from a live
    # cached path and fold it into the stored metadata.
    new_meta = dict(user_metadata or {})
    computed_filename = None
    try:
        p = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=asset.id))
        if p:
            computed_filename = compute_relative_filename(p)
    except Exception:
        computed_filename = None
    if computed_filename:
        new_meta["filename"] = computed_filename
    if new_meta:
        replace_asset_info_metadata_projection(
            session,
            asset_info_id=info.id,
            user_metadata=new_meta,
        )

    if tags is not None:
        set_asset_info_tags(
            session,
            asset_info_id=info.id,
            tags=tags,
            origin=tag_origin,
        )
    return info
|
||||
|
||||
|
||||
def set_asset_info_tags(
    session: Session,
    *,
    asset_info_id: str,
    tags: Sequence[str],
    origin: str = "manual",
) -> dict:
    """Make the tag set of *asset_info_id* exactly equal the normalized *tags*.

    Returns {"added": [...], "removed": [...], "total": [...]} where
    "total" is the normalized desired tag list.
    """
    desired = normalize_tags(tags)

    link_rows = session.execute(
        select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
    ).all()
    current = {name for (name,) in link_rows}

    to_add = [t for t in desired if t not in current]
    to_remove = [t for t in current if t not in desired]

    if to_add:
        ensure_tags_exist(session, to_add, tag_type="user")
        session.add_all(
            [
                AssetInfoTag(
                    asset_info_id=asset_info_id,
                    tag_name=t,
                    origin=origin,
                    added_at=utcnow(),
                )
                for t in to_add
            ]
        )
        session.flush()

    if to_remove:
        session.execute(
            delete(AssetInfoTag).where(
                AssetInfoTag.asset_info_id == asset_info_id,
                AssetInfoTag.tag_name.in_(to_remove),
            )
        )
        session.flush()

    return {"added": to_add, "removed": to_remove, "total": desired}
|
||||
|
||||
|
||||
def replace_asset_info_metadata_projection(
    session: Session,
    *,
    asset_info_id: str,
    user_metadata: dict | None = None,
) -> None:
    """Overwrite user_metadata on one AssetInfo and rebuild its typed projection.

    Raises ValueError when the AssetInfo does not exist. The projection
    rows (AssetInfoMeta) are deleted and re-inserted from scratch.
    """
    info = session.get(AssetInfo, asset_info_id)
    if not info:
        raise ValueError(f"AssetInfo {asset_info_id} not found")

    # Store the raw mapping on the row itself and bump updated_at.
    info.user_metadata = user_metadata or {}
    info.updated_at = utcnow()
    session.flush()

    # Full rebuild: drop every old projection row before inserting new ones.
    session.execute(delete(AssetInfoMeta).where(AssetInfoMeta.asset_info_id == asset_info_id))
    session.flush()

    if not user_metadata:
        return

    rows: list[AssetInfoMeta] = []
    for k, v in user_metadata.items():
        # project_kv expands one key/value into one or more typed rows
        # (val_str / val_num / val_bool / val_json).
        for r in project_kv(k, v):
            rows.append(
                AssetInfoMeta(
                    asset_info_id=asset_info_id,
                    key=r["key"],
                    ordinal=int(r["ordinal"]),
                    val_str=r.get("val_str"),
                    val_num=r.get("val_num"),
                    val_bool=r.get("val_bool"),
                    val_json=r.get("val_json"),
                )
            )
    if rows:
        session.add_all(rows)
        session.flush()
|
||||
|
||||
|
||||
def ingest_fs_asset(
    session: Session,
    *,
    asset_hash: str,
    abs_path: str,
    size_bytes: int,
    mtime_ns: int,
    mime_type: str | None = None,
    info_name: str | None = None,
    owner_id: str = "",
    preview_id: str | None = None,
    user_metadata: dict | None = None,
    tags: Sequence[str] = (),
    tag_origin: str = "manual",
    require_existing_tags: bool = False,
) -> dict:
    """
    Idempotently upsert:
      - Asset by content hash (create if missing)
      - AssetCacheState(file_path) pointing to asset_id
      - Optionally AssetInfo + tag links and metadata projection
    Returns flags and ids.
    """
    locator = os.path.abspath(abs_path)
    now = utcnow()

    # Drop a dangling preview reference up front rather than failing later.
    if preview_id:
        if not session.get(Asset, preview_id):
            preview_id = None

    out: dict[str, Any] = {
        "asset_created": False,
        "asset_updated": False,
        "state_created": False,
        "state_updated": False,
        "asset_info_id": None,
    }

    # 1) Asset by hash
    asset = (
        session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))
    ).scalars().first()
    if not asset:
        vals = {
            "hash": asset_hash,
            "size_bytes": int(size_bytes),
            "mime_type": mime_type,
            "created_at": now,
        }
        # INSERT ... ON CONFLICT DO NOTHING keeps this race-safe against a
        # concurrent ingest of the same hash.
        res = session.execute(
            sqlite.insert(Asset)
            .values(**vals)
            .on_conflict_do_nothing(index_elements=[Asset.hash])
        )
        if int(res.rowcount or 0) > 0:
            out["asset_created"] = True
        # Re-select regardless of who won the insert race.
        asset = (
            session.execute(
                select(Asset).where(Asset.hash == asset_hash).limit(1)
            )
        ).scalars().first()
        if not asset:
            raise RuntimeError("Asset row not found after upsert.")
    else:
        # Refresh size/mime when the caller supplies better values.
        changed = False
        if asset.size_bytes != int(size_bytes) and int(size_bytes) > 0:
            asset.size_bytes = int(size_bytes)
            changed = True
        if mime_type and asset.mime_type != mime_type:
            asset.mime_type = mime_type
            changed = True
        if changed:
            out["asset_updated"] = True

    # 2) AssetCacheState upsert by file_path (unique)
    vals = {
        "asset_id": asset.id,
        "file_path": locator,
        "mtime_ns": int(mtime_ns),
    }
    ins = (
        sqlite.insert(AssetCacheState)
        .values(**vals)
        .on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
    )

    res = session.execute(ins)
    if int(res.rowcount or 0) > 0:
        out["state_created"] = True
    else:
        # Row already existed: repoint/refresh it only when something differs,
        # so rowcount reflects a real change.
        upd = (
            sa.update(AssetCacheState)
            .where(AssetCacheState.file_path == locator)
            .where(
                sa.or_(
                    AssetCacheState.asset_id != asset.id,
                    AssetCacheState.mtime_ns.is_(None),
                    AssetCacheState.mtime_ns != int(mtime_ns),
                )
            )
            .values(asset_id=asset.id, mtime_ns=int(mtime_ns))
        )
        res2 = session.execute(upd)
        if int(res2.rowcount or 0) > 0:
            out["state_updated"] = True

    # 3) Optional AssetInfo + tags + metadata
    if info_name:
        try:
            # Savepoint: a unique conflict falls through to the re-select below.
            with session.begin_nested():
                info = AssetInfo(
                    owner_id=owner_id,
                    name=info_name,
                    asset_id=asset.id,
                    preview_id=preview_id,
                    created_at=now,
                    updated_at=now,
                    last_access_time=now,
                )
                session.add(info)
                session.flush()
                out["asset_info_id"] = info.id
        except IntegrityError:
            pass

        existing_info = (
            session.execute(
                select(AssetInfo)
                .where(
                    AssetInfo.asset_id == asset.id,
                    AssetInfo.name == info_name,
                    (AssetInfo.owner_id == owner_id),
                )
                .limit(1)
            )
        ).unique().scalar_one_or_none()
        if not existing_info:
            raise RuntimeError("Failed to update or insert AssetInfo.")

        if preview_id and existing_info.preview_id != preview_id:
            existing_info.preview_id = preview_id

        # Touch timestamps; never move last_access_time backwards.
        existing_info.updated_at = now
        if existing_info.last_access_time < now:
            existing_info.last_access_time = now
        session.flush()
        out["asset_info_id"] = existing_info.id

    # Link tags (only meaningful when an AssetInfo exists).
    norm = [t.strip().lower() for t in (tags or []) if (t or "").strip()]
    if norm and out["asset_info_id"] is not None:
        if not require_existing_tags:
            ensure_tags_exist(session, norm, tag_type="user")

        existing_tag_names = set(
            name for (name,) in (session.execute(select(Tag.name).where(Tag.name.in_(norm)))).all()
        )
        missing = [t for t in norm if t not in existing_tag_names]
        if missing and require_existing_tags:
            raise ValueError(f"Unknown tags: {missing}")

        existing_links = set(
            tag_name
            for (tag_name,) in (
                session.execute(
                    select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == out["asset_info_id"])
                )
            ).all()
        )
        # Only link tags that exist and are not already linked.
        to_add = [t for t in norm if t in existing_tag_names and t not in existing_links]
        if to_add:
            session.add_all(
                [
                    AssetInfoTag(
                        asset_info_id=out["asset_info_id"],
                        tag_name=t,
                        origin=tag_origin,
                        added_at=now,
                    )
                    for t in to_add
                ]
            )
            session.flush()

    # metadata["filename"] hack
    if out["asset_info_id"] is not None:
        primary_path = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=asset.id))
        computed_filename = compute_relative_filename(primary_path) if primary_path else None

        # Merge caller metadata over the stored mapping, then overlay the
        # computed relative filename; only rewrite when something changed.
        current_meta = existing_info.user_metadata or {}
        new_meta = dict(current_meta)
        if user_metadata is not None:
            for k, v in user_metadata.items():
                new_meta[k] = v
        if computed_filename:
            new_meta["filename"] = computed_filename

        if new_meta != current_meta:
            replace_asset_info_metadata_projection(
                session,
                asset_info_id=out["asset_info_id"],
                user_metadata=new_meta,
            )

    # Best-effort cleanup of the automatic 'missing' tag; never fail ingest.
    try:
        remove_missing_tag_for_asset_id(session, asset_id=asset.id)
    except Exception:
        logging.exception("Failed to clear 'missing' tag for asset %s", asset.id)
    return out
|
||||
|
||||
|
||||
def update_asset_info_full(
    session: Session,
    *,
    asset_info_id: str,
    name: str | None = None,
    tags: Sequence[str] | None = None,
    user_metadata: dict | None = None,
    tag_origin: str = "manual",
    asset_info_row: Any = None,
) -> AssetInfo:
    """Apply a partial update (name / tags / user_metadata) to one AssetInfo.

    *asset_info_row* lets the caller pass an already-loaded row to skip the
    lookup. Raises ValueError when the AssetInfo does not exist.
    """
    if not asset_info_row:
        info = session.get(AssetInfo, asset_info_id)
        if not info:
            raise ValueError(f"AssetInfo {asset_info_id} not found")
    else:
        info = asset_info_row

    touched = False
    if name is not None and name != info.name:
        info.name = name
        touched = True

    # Best-effort: recompute the relative "filename" from a live cached path.
    computed_filename = None
    try:
        p = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=info.asset_id))
        if p:
            computed_filename = compute_relative_filename(p)
    except Exception:
        computed_filename = None

    if user_metadata is not None:
        # Caller-supplied metadata replaces the stored mapping, with the
        # computed filename overlaid on top.
        new_meta = dict(user_metadata)
        if computed_filename:
            new_meta["filename"] = computed_filename
        replace_asset_info_metadata_projection(
            session, asset_info_id=asset_info_id, user_metadata=new_meta
        )
        touched = True
    else:
        # No metadata given: still refresh the stored filename if it drifted.
        if computed_filename:
            current_meta = info.user_metadata or {}
            if current_meta.get("filename") != computed_filename:
                new_meta = dict(current_meta)
                new_meta["filename"] = computed_filename
                replace_asset_info_metadata_projection(
                    session, asset_info_id=asset_info_id, user_metadata=new_meta
                )
                touched = True

    if tags is not None:
        set_asset_info_tags(
            session,
            asset_info_id=asset_info_id,
            tags=tags,
            origin=tag_origin,
        )
        touched = True

    # replace_asset_info_metadata_projection already bumps updated_at, so
    # only stamp it here when metadata was not rewritten via that path.
    if touched and user_metadata is None:
        info.updated_at = utcnow()
        session.flush()

    return info
|
||||
|
||||
|
||||
def delete_asset_info_by_id(
    session: Session,
    *,
    asset_info_id: str,
    owner_id: str,
) -> bool:
    """Delete one AssetInfo visible to *owner_id*; True when a row was removed."""
    result = session.execute(
        sa.delete(AssetInfo).where(
            AssetInfo.id == asset_info_id,
            visible_owner_clause(owner_id),
        )
    )
    deleted = int(result.rowcount or 0)
    return deleted > 0
|
||||
|
||||
|
||||
def list_tags_with_usage(
|
||||
session: Session,
|
||||
prefix: str | None = None,
|
||||
@@ -265,3 +813,163 @@ def list_tags_with_usage(
|
||||
|
||||
rows_norm = [(name, ttype, int(count or 0)) for (name, ttype, count) in rows]
|
||||
return rows_norm, int(total or 0)
|
||||
|
||||
|
||||
def ensure_tags_exist(session: Session, names: Iterable[str], tag_type: str = "user") -> None:
    """Insert any missing Tag rows for *names*; existing names are untouched."""
    wanted = normalize_tags(list(names))
    if not wanted:
        return
    # dict.fromkeys de-duplicates while keeping first-seen order.
    unique_names = list(dict.fromkeys(wanted))
    session.execute(
        sqlite.insert(Tag)
        .values([{"name": n, "tag_type": tag_type} for n in unique_names])
        .on_conflict_do_nothing(index_elements=[Tag.name])
    )
|
||||
|
||||
|
||||
def get_asset_tags(session: Session, *, asset_info_id: str) -> list[str]:
    """Tag names currently linked to *asset_info_id*."""
    stmt = select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
    return [row[0] for row in session.execute(stmt).all()]
|
||||
|
||||
|
||||
def add_tags_to_asset_info(
    session: Session,
    *,
    asset_info_id: str,
    tags: Sequence[str],
    origin: str = "manual",
    create_if_missing: bool = True,
    asset_info_row: Any = None,
) -> dict:
    """Link *tags* to one AssetInfo, optionally creating missing Tag rows.

    Returns {"added": [...], "already_present": [...], "total_tags": [...]}.
    Raises ValueError when the AssetInfo does not exist (skipped when the
    caller passes an already-loaded *asset_info_row*).
    """
    if not asset_info_row:
        info = session.get(AssetInfo, asset_info_id)
        if not info:
            raise ValueError(f"AssetInfo {asset_info_id} not found")

    norm = normalize_tags(tags)
    if not norm:
        total = get_asset_tags(session, asset_info_id=asset_info_id)
        return {"added": [], "already_present": [], "total_tags": total}

    if create_if_missing:
        ensure_tags_exist(session, norm, tag_type="user")

    current = {
        tag_name
        for (tag_name,) in (
            session.execute(
                sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
            )
        ).all()
    }

    want = set(norm)
    to_add = sorted(want - current)

    if to_add:
        # Savepoint: if a concurrent writer inserts the same link first, the
        # IntegrityError rolls back only this batch, not the whole session.
        with session.begin_nested() as nested:
            try:
                session.add_all(
                    [
                        AssetInfoTag(
                            asset_info_id=asset_info_id,
                            tag_name=t,
                            origin=origin,
                            added_at=utcnow(),
                        )
                        for t in to_add
                    ]
                )
                session.flush()
            except IntegrityError:
                nested.rollback()

    # Re-read to report what actually ended up linked.
    after = set(get_asset_tags(session, asset_info_id=asset_info_id))
    return {
        "added": sorted(((after - current) & want)),
        "already_present": sorted(want & current),
        "total_tags": sorted(after),
    }
|
||||
|
||||
|
||||
def remove_tags_from_asset_info(
    session: Session,
    *,
    asset_info_id: str,
    tags: Sequence[str],
) -> dict:
    """Unlink *tags* from one AssetInfo.

    Returns {"removed": [...], "not_present": [...], "total_tags": [...]}.
    Raises ValueError when the AssetInfo does not exist.
    """
    if session.get(AssetInfo, asset_info_id) is None:
        raise ValueError(f"AssetInfo {asset_info_id} not found")

    norm = normalize_tags(tags)
    if not norm:
        return {
            "removed": [],
            "not_present": [],
            "total_tags": get_asset_tags(session, asset_info_id=asset_info_id),
        }

    existing = {
        row[0]
        for row in session.execute(
            sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
        ).all()
    }

    requested = set(norm)
    to_remove = sorted(requested & existing)
    not_present = sorted(requested - existing)

    if to_remove:
        session.execute(
            delete(AssetInfoTag).where(
                AssetInfoTag.asset_info_id == asset_info_id,
                AssetInfoTag.tag_name.in_(to_remove),
            )
        )
        session.flush()

    return {
        "removed": to_remove,
        "not_present": not_present,
        "total_tags": get_asset_tags(session, asset_info_id=asset_info_id),
    }
|
||||
|
||||
|
||||
def remove_missing_tag_for_asset_id(
    session: Session,
    *,
    asset_id: str,
) -> None:
    """Strip the automatic 'missing' tag from every AssetInfo of *asset_id*."""
    info_ids = sa.select(AssetInfo.id).where(AssetInfo.asset_id == asset_id)
    session.execute(
        sa.delete(AssetInfoTag).where(
            AssetInfoTag.asset_info_id.in_(info_ids),
            AssetInfoTag.tag_name == "missing",
        )
    )
|
||||
|
||||
|
||||
def set_asset_info_preview(
    session: Session,
    *,
    asset_info_id: str,
    preview_asset_id: str | None = None,
) -> None:
    """Set or clear preview_id and bump updated_at. Raises on unknown IDs."""
    info = session.get(AssetInfo, asset_info_id)
    if info is None:
        raise ValueError(f"AssetInfo {asset_info_id} not found")

    # A non-null preview must point at a real Asset row.
    if preview_asset_id is not None and not session.get(Asset, preview_asset_id):
        raise ValueError(f"Preview Asset {preview_asset_id} not found")

    info.preview_id = preview_asset_id
    info.updated_at = utcnow()
    session.flush()
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import contextlib
|
||||
import os
|
||||
from decimal import Decimal
|
||||
from aiohttp import web
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
@@ -87,6 +88,40 @@ def get_comfy_models_folders() -> list[tuple[str, list[str]]]:
|
||||
targets.append((name, paths))
|
||||
return targets
|
||||
|
||||
def resolve_destination_from_tags(tags: list[str]) -> tuple[str, list[str]]:
    """Validates and maps tags -> (base_dir, subdirs_for_fs).

    tags[0] selects the root: "models" (tags[1] must then name a known model
    category), "input", or anything else (mapped to the output directory).
    Remaining tags become filesystem subdirectories.

    Raises ValueError for empty tags, unknown/unconfigured model categories,
    or path components that could escape the destination ("." / "..").
    """
    # Fix: empty tags previously crashed with IndexError on tags[0];
    # report it as a validation error like the other bad inputs.
    if not tags:
        raise ValueError("at least one tag is required")
    root = tags[0]
    if root == "models":
        if len(tags) < 2:
            raise ValueError("at least two tags required for model asset")
        try:
            bases = folder_paths.folder_names_and_paths[tags[1]][0]
        except KeyError:
            raise ValueError(f"unknown model category '{tags[1]}'") from None
        if not bases:
            raise ValueError(f"no base path configured for category '{tags[1]}'")
        base_dir = os.path.abspath(bases[0])
        raw_subdirs = tags[2:]
    else:
        # NOTE(review): any root other than "input" falls through to the
        # output directory — confirm unknown roots are intended to do so.
        base_dir = os.path.abspath(
            folder_paths.get_input_directory() if root == "input" else folder_paths.get_output_directory()
        )
        raw_subdirs = tags[1:]
    for part in raw_subdirs:
        if part in (".", ".."):
            raise ValueError("invalid path component in tags")

    return base_dir, raw_subdirs if raw_subdirs else []
|
||||
|
||||
def ensure_within_base(candidate: str, base: str) -> None:
    """Raise ValueError unless *candidate* resolves to a path inside *base*.

    Fix: the original wrapped the whole check in a blanket ``except
    Exception``, which caught its own "destination escapes base directory"
    ValueError and rewrote it to the generic "invalid destination path",
    and would also mask unexpected errors. The try now covers only
    os.path.commonpath, which raises ValueError itself when the paths
    cannot be compared (e.g. different drives on Windows).
    """
    cand_abs = os.path.abspath(candidate)
    base_abs = os.path.abspath(base)
    try:
        common = os.path.commonpath([cand_abs, base_abs])
    except ValueError:
        # Paths not comparable (mixed drives / mixed absolute-relative).
        raise ValueError("invalid destination path") from None
    if common != base_abs:
        raise ValueError("destination escapes base directory")
|
||||
|
||||
def compute_relative_filename(file_path: str) -> str | None:
|
||||
"""
|
||||
Return the model's path relative to the last well-known folder (the model category),
|
||||
@@ -113,7 +148,6 @@ def compute_relative_filename(file_path: str) -> str | None:
|
||||
return "/".join(inside)
|
||||
return "/".join(parts) # input/output: keep all parts
|
||||
|
||||
|
||||
def get_relative_to_root_category_path_of_asset(file_path: str) -> tuple[Literal["input", "output", "models"], str]:
|
||||
"""Given an absolute or relative file path, determine which root category the path belongs to:
|
||||
- 'input' if the file resides under `folder_paths.get_input_directory()`
|
||||
@@ -215,3 +249,64 @@ def collect_models_files() -> list[str]:
|
||||
if allowed:
|
||||
out.append(abs_path)
|
||||
return out
|
||||
|
||||
def is_scalar(v):
    """True for values that project into a single typed column:
    None, bool, int, float, Decimal, or str."""
    return v is None or isinstance(v, (bool, int, float, Decimal, str))
|
||||
|
||||
def project_kv(key: str, value):
    """
    Turn a metadata key/value into typed projection rows.

    Returns list[dict] with keys:
        key, ordinal, and one of val_str / val_num / val_bool / val_json (others None)
    """

    def _null_row(ordinal: int) -> dict:
        return {
            "key": key, "ordinal": ordinal,
            "val_str": None, "val_num": None, "val_bool": None, "val_json": None
        }

    def _scalar_row(ordinal: int, item) -> dict:
        # bool must be tested before int/float: bool is a subclass of int.
        if item is None:
            return _null_row(ordinal)
        if isinstance(item, bool):
            return {"key": key, "ordinal": ordinal, "val_bool": bool(item)}
        if isinstance(item, (int, float, Decimal)):
            num = item if isinstance(item, Decimal) else Decimal(str(item))
            return {"key": key, "ordinal": ordinal, "val_num": num}
        if isinstance(item, str):
            return {"key": key, "ordinal": ordinal, "val_str": item}
        return {"key": key, "ordinal": ordinal, "val_json": item}

    if value is None:
        return [_null_row(0)]

    if is_scalar(value):
        return [_scalar_row(0, value)]

    if isinstance(value, list):
        if all(is_scalar(x) for x in value):
            return [_scalar_row(i, x) for i, x in enumerate(value)]
        # Heterogeneous list: store each element as JSON.
        return [{"key": key, "ordinal": i, "val_json": x} for i, x in enumerate(value)]

    # Any other structure (dict, etc.) is stored whole as JSON.
    return [{"key": key, "ordinal": 0, "val_json": value}]
|
||||
|
||||
@@ -1,13 +1,34 @@
|
||||
import os
|
||||
import mimetypes
|
||||
import contextlib
|
||||
from typing import Sequence
|
||||
|
||||
from app.database.db import create_session
|
||||
from app.assets.api import schemas_out
|
||||
from app.assets.api import schemas_out, schemas_in
|
||||
from app.assets.database.queries import (
|
||||
asset_exists_by_hash,
|
||||
asset_info_exists_for_asset_id,
|
||||
get_asset_by_hash,
|
||||
get_asset_info_by_id,
|
||||
fetch_asset_info_asset_and_tags,
|
||||
fetch_asset_info_and_asset,
|
||||
create_asset_info_for_existing_asset,
|
||||
touch_asset_info_by_id,
|
||||
update_asset_info_full,
|
||||
delete_asset_info_by_id,
|
||||
list_cache_states_by_asset_id,
|
||||
list_asset_infos_page,
|
||||
list_tags_with_usage,
|
||||
get_asset_tags,
|
||||
add_tags_to_asset_info,
|
||||
remove_tags_from_asset_info,
|
||||
pick_best_live_path,
|
||||
ingest_fs_asset,
|
||||
set_asset_info_preview,
|
||||
)
|
||||
from app.assets.helpers import resolve_destination_from_tags, ensure_within_base
|
||||
from app.assets.database.models import Asset
|
||||
import app.assets.hashing as hashing
|
||||
|
||||
|
||||
def _safe_sort_field(requested: str | None) -> str:
|
||||
@@ -19,11 +40,28 @@ def _safe_sort_field(requested: str | None) -> str:
|
||||
return "created_at"
|
||||
|
||||
|
||||
def asset_exists(asset_hash: str) -> bool:
|
||||
def _get_size_mtime_ns(path: str) -> tuple[int, int]:
|
||||
st = os.stat(path, follow_symlinks=True)
|
||||
return st.st_size, getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000))
|
||||
|
||||
|
||||
def _safe_filename(name: str | None, fallback: str) -> str:
|
||||
n = os.path.basename((name or "").strip() or fallback)
|
||||
if n:
|
||||
return n
|
||||
return fallback
|
||||
|
||||
|
||||
def asset_exists(*, asset_hash: str) -> bool:
    """
    Check if an asset with a given hash exists in database.
    """
    with create_session() as session:
        found = asset_exists_by_hash(session, asset_hash=asset_hash)
    return found
|
||||
|
||||
|
||||
def list_assets(
|
||||
*,
|
||||
include_tags: Sequence[str] | None = None,
|
||||
exclude_tags: Sequence[str] | None = None,
|
||||
name_contains: str | None = None,
|
||||
@@ -76,7 +114,12 @@ def list_assets(
|
||||
has_more=(offset + len(summaries)) < total,
|
||||
)
|
||||
|
||||
def get_asset(asset_info_id: str, owner_id: str = "") -> schemas_out.AssetDetail:
|
||||
|
||||
def get_asset(
|
||||
*,
|
||||
asset_info_id: str,
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.AssetDetail:
|
||||
with create_session() as session:
|
||||
res = fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id)
|
||||
if not res:
|
||||
@@ -97,6 +140,349 @@ def get_asset(asset_info_id: str, owner_id: str = "") -> schemas_out.AssetDetail
|
||||
last_access_time=info.last_access_time,
|
||||
)
|
||||
|
||||
|
||||
def resolve_asset_content_for_download(
    *,
    asset_info_id: str,
    owner_id: str = "",
) -> tuple[str, str, str]:
    """Resolve an AssetInfo to (absolute_path, content_type, download_name).

    Raises ValueError when the AssetInfo is missing or not visible to
    *owner_id*, and FileNotFoundError when no live cached file exists.
    Side effect: bumps the row's last_access_time and commits.
    """
    with create_session() as session:
        pair = fetch_asset_info_and_asset(session, asset_info_id=asset_info_id, owner_id=owner_id)
        if not pair:
            raise ValueError(f"AssetInfo {asset_info_id} not found")

        info, asset = pair
        # Pick a cache state whose backing file is still present on disk.
        states = list_cache_states_by_asset_id(session, asset_id=asset.id)
        abs_path = pick_best_live_path(states)
        if not abs_path:
            raise FileNotFoundError

        # Record the access before handing the path back to the caller.
        touch_asset_info_by_id(session, asset_info_id=asset_info_id)
        session.commit()

        # Content type: stored mime, else guess from name/path, else generic.
        ctype = asset.mime_type or mimetypes.guess_type(info.name or abs_path)[0] or "application/octet-stream"
        download_name = info.name or os.path.basename(abs_path)
        return abs_path, ctype, download_name
|
||||
|
||||
|
||||
def upload_asset_from_temp_path(
    spec: schemas_in.UploadAssetSpec,
    *,
    temp_path: str,
    client_filename: str | None = None,
    owner_id: str = "",
    expected_asset_hash: str | None = None,
) -> schemas_out.AssetCreated:
    """Finalize an upload: hash the temp file, dedupe by content hash, or
    move it into the tag-derived destination and register it.

    Raises ValueError("HASH_MISMATCH") when the client-declared hash does
    not match, and RuntimeError for hashing / filesystem / DB failures.
    """
    try:
        digest = hashing.blake3_hash(temp_path)
    except Exception as e:
        raise RuntimeError(f"failed to hash uploaded file: {e}")
    asset_hash = "blake3:" + digest

    # Verify the client-declared hash, if any, before touching the DB.
    if expected_asset_hash and asset_hash != expected_asset_hash.strip().lower():
        raise ValueError("HASH_MISMATCH")

    with create_session() as session:
        existing = get_asset_by_hash(session, asset_hash=asset_hash)
        if existing is not None:
            # Content already stored: discard the temp file and just attach
            # a new AssetInfo to the existing Asset.
            with contextlib.suppress(Exception):
                if temp_path and os.path.exists(temp_path):
                    os.remove(temp_path)

            display_name = _safe_filename(spec.name or (client_filename or ""), fallback=digest)
            info = create_asset_info_for_existing_asset(
                session,
                asset_hash=asset_hash,
                name=display_name,
                user_metadata=spec.user_metadata or {},
                tags=spec.tags or [],
                tag_origin="manual",
                owner_id=owner_id,
            )
            tag_names = get_asset_tags(session, asset_info_id=info.id)
            session.commit()

            return schemas_out.AssetCreated(
                id=info.id,
                name=info.name,
                asset_hash=existing.hash,
                size=int(existing.size_bytes) if existing.size_bytes is not None else None,
                mime_type=existing.mime_type,
                tags=tag_names,
                user_metadata=info.user_metadata or {},
                preview_id=info.preview_id,
                created_at=info.created_at,
                last_access_time=info.last_access_time,
                created_new=False,
            )

    # New content: derive the destination directory from the tags.
    base_dir, subdirs = resolve_destination_from_tags(spec.tags)
    dest_dir = os.path.join(base_dir, *subdirs) if subdirs else base_dir
    os.makedirs(dest_dir, exist_ok=True)

    # Keep the original extension (bounded to 16 chars) but store the file
    # under its content digest to avoid name collisions.
    src_for_ext = (client_filename or spec.name or "").strip()
    _ext = os.path.splitext(os.path.basename(src_for_ext))[1] if src_for_ext else ""
    ext = _ext if 0 < len(_ext) <= 16 else ""
    hashed_basename = f"{digest}{ext}"
    dest_abs = os.path.abspath(os.path.join(dest_dir, hashed_basename))
    ensure_within_base(dest_abs, base_dir)

    content_type = (
        mimetypes.guess_type(os.path.basename(src_for_ext), strict=False)[0]
        or mimetypes.guess_type(hashed_basename, strict=False)[0]
        or "application/octet-stream"
    )

    # Atomic move within the same filesystem.
    try:
        os.replace(temp_path, dest_abs)
    except Exception as e:
        raise RuntimeError(f"failed to move uploaded file into place: {e}")

    try:
        size_bytes, mtime_ns = _get_size_mtime_ns(dest_abs)
    except OSError as e:
        raise RuntimeError(f"failed to stat destination file: {e}")

    with create_session() as session:
        # Register asset + cache state + info/tags/metadata in one pass.
        result = ingest_fs_asset(
            session,
            asset_hash=asset_hash,
            abs_path=dest_abs,
            size_bytes=size_bytes,
            mtime_ns=mtime_ns,
            mime_type=content_type,
            info_name=_safe_filename(spec.name or (client_filename or ""), fallback=digest),
            owner_id=owner_id,
            preview_id=None,
            user_metadata=spec.user_metadata or {},
            tags=spec.tags,
            tag_origin="manual",
            require_existing_tags=False,
        )
        info_id = result["asset_info_id"]
        if not info_id:
            raise RuntimeError("failed to create asset metadata")

        pair = fetch_asset_info_and_asset(session, asset_info_id=info_id, owner_id=owner_id)
        if not pair:
            raise RuntimeError("inconsistent DB state after ingest")
        info, asset = pair
        tag_names = get_asset_tags(session, asset_info_id=info.id)
        session.commit()

        return schemas_out.AssetCreated(
            id=info.id,
            name=info.name,
            asset_hash=asset.hash,
            size=int(asset.size_bytes),
            mime_type=asset.mime_type,
            tags=tag_names,
            user_metadata=info.user_metadata or {},
            preview_id=info.preview_id,
            created_at=info.created_at,
            last_access_time=info.last_access_time,
            created_new=result["asset_created"],
        )
|
||||
|
||||
|
||||
def update_asset(
    *,
    asset_info_id: str,
    name: str | None = None,
    tags: list[str] | None = None,
    user_metadata: dict | None = None,
    owner_id: str = "",
) -> schemas_out.AssetUpdated:
    """Update name, tags and/or user metadata of an AssetInfo record.

    None values are forwarded to update_asset_info_full unchanged
    (presumably meaning "leave this field as-is" — confirm against that
    helper). Tag changes are recorded with origin "manual".

    Raises:
        ValueError: if no AssetInfo with this id exists.
        PermissionError: if the record is owned by a different user.
    """
    with create_session() as session:
        info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
        if not info_row:
            raise ValueError(f"AssetInfo {asset_info_id} not found")
        # An empty owner_id on the row is treated as unowned; only a
        # non-empty, mismatching owner blocks the update.
        if info_row.owner_id and info_row.owner_id != owner_id:
            raise PermissionError("not owner")

        info = update_asset_info_full(
            session,
            asset_info_id=asset_info_id,
            name=name,
            tags=tags,
            user_metadata=user_metadata,
            tag_origin="manual",
            asset_info_row=info_row,
        )

        # Re-read the tag list after the update so the response reflects
        # the final state; everything is persisted in a single commit.
        tag_names = get_asset_tags(session, asset_info_id=asset_info_id)
        session.commit()

        return schemas_out.AssetUpdated(
            id=info.id,
            name=info.name,
            asset_hash=info.asset.hash if info.asset else None,
            tags=tag_names,
            user_metadata=info.user_metadata or {},
            updated_at=info.updated_at,
        )
|
||||
|
||||
|
||||
def set_asset_preview(
    *,
    asset_info_id: str,
    preview_asset_id: str | None = None,
    owner_id: str = "",
) -> schemas_out.AssetDetail:
    """Set (or clear, when preview_asset_id is None) the preview of an AssetInfo.

    Returns the full detail view of the updated record.

    Raises:
        ValueError: if no AssetInfo with this id exists.
        PermissionError: if the record is owned by a different user.
        RuntimeError: if the record disappears between update and re-read.
    """
    with create_session() as session:
        info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
        if not info_row:
            raise ValueError(f"AssetInfo {asset_info_id} not found")
        # Empty owner_id means unowned/shared; anyone may change it.
        if info_row.owner_id and info_row.owner_id != owner_id:
            raise PermissionError("not owner")

        set_asset_info_preview(
            session,
            asset_info_id=asset_info_id,
            preview_asset_id=preview_asset_id,
        )

        # Re-fetch info, asset and tags in one query so the response is
        # built from post-update state.
        res = fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id)
        if not res:
            raise RuntimeError("State changed during preview update")
        info, asset, tags = res
        session.commit()

        return schemas_out.AssetDetail(
            id=info.id,
            name=info.name,
            asset_hash=asset.hash if asset else None,
            size=int(asset.size_bytes) if asset and asset.size_bytes is not None else None,
            mime_type=asset.mime_type if asset else None,
            tags=tags,
            user_metadata=info.user_metadata or {},
            preview_id=info.preview_id,
            created_at=info.created_at,
            last_access_time=info.last_access_time,
        )
|
||||
|
||||
|
||||
def delete_asset_reference(*, asset_info_id: str, owner_id: str, delete_content_if_orphan: bool = True) -> bool:
    """Delete an AssetInfo reference; optionally garbage-collect orphaned content.

    When the deleted reference was the last one pointing at its Asset and
    delete_content_if_orphan is True, the Asset row and its cached files on
    disk are removed as well.

    Returns True if the reference was deleted, False if it did not exist
    or was not visible to this owner.
    """
    with create_session() as session:
        info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
        # Capture asset_id before deleting the row so we can check for
        # orphaned content afterwards.
        asset_id = info_row.asset_id if info_row else None
        deleted = delete_asset_info_by_id(session, asset_info_id=asset_info_id, owner_id=owner_id)
        if not deleted:
            session.commit()
            return False

        if not delete_content_if_orphan or not asset_id:
            session.commit()
            return True

        # Another AssetInfo still references the same content: keep it.
        still_exists = asset_info_exists_for_asset_id(session, asset_id=asset_id)
        if still_exists:
            session.commit()
            return True

        # Collect on-disk cache paths before the Asset row (and, via DB
        # cascade presumably, its cache-state rows) is deleted.
        states = list_cache_states_by_asset_id(session, asset_id=asset_id)
        file_paths = [s.file_path for s in (states or []) if getattr(s, "file_path", None)]

        asset_row = session.get(Asset, asset_id)
        if asset_row is not None:
            session.delete(asset_row)

        # Commit the DB changes first; file removal afterwards is
        # best-effort (errors are deliberately suppressed).
        session.commit()
        for p in file_paths:
            with contextlib.suppress(Exception):
                if p and os.path.isfile(p):
                    os.remove(p)
        return True
|
||||
|
||||
|
||||
def create_asset_from_hash(
    *,
    hash_str: str,
    name: str,
    tags: list[str] | None = None,
    user_metadata: dict | None = None,
    owner_id: str = "",
) -> schemas_out.AssetCreated | None:
    """Create a new AssetInfo that references already-ingested content.

    The hash is canonicalized (trimmed, lowercased) before lookup.

    Returns:
        The created (or pre-existing, see create_asset_info_for_existing_asset)
        reference with created_new=False, or None when no asset with this
        hash exists.
    """
    canonical = hash_str.strip().lower()
    # Fallback display name is the digest part of "algo:digest". Guard
    # against a malformed hash with no ":" separator, which previously
    # raised IndexError from split(...)[1]; fall back to the whole string.
    digest = canonical.split(":", 1)[1] if ":" in canonical else canonical
    with create_session() as session:
        asset = get_asset_by_hash(session, asset_hash=canonical)
        if not asset:
            return None

        info = create_asset_info_for_existing_asset(
            session,
            asset_hash=canonical,
            name=_safe_filename(name, fallback=digest),
            user_metadata=user_metadata or {},
            tags=tags or [],
            tag_origin="manual",
            owner_id=owner_id,
        )
        tag_names = get_asset_tags(session, asset_info_id=info.id)
        session.commit()

        return schemas_out.AssetCreated(
            id=info.id,
            name=info.name,
            asset_hash=asset.hash,
            size=int(asset.size_bytes),
            mime_type=asset.mime_type,
            tags=tag_names,
            user_metadata=info.user_metadata or {},
            preview_id=info.preview_id,
            created_at=info.created_at,
            last_access_time=info.last_access_time,
            created_new=False,
        )
|
||||
|
||||
|
||||
def add_tags_to_asset(
    *,
    asset_info_id: str,
    tags: list[str],
    origin: str = "manual",
    owner_id: str = "",
) -> schemas_out.TagsAdd:
    """Attach tags to an AssetInfo, creating unknown tag rows as needed.

    Raises:
        ValueError: if no AssetInfo with this id exists.
        PermissionError: if the record is owned by a different user.
    """
    with create_session() as session:
        info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
        if not info_row:
            raise ValueError(f"AssetInfo {asset_info_id} not found")
        # Empty owner_id means unowned/shared; anyone may tag it.
        if info_row.owner_id and info_row.owner_id != owner_id:
            raise PermissionError("not owner")
        data = add_tags_to_asset_info(
            session,
            asset_info_id=asset_info_id,
            tags=tags,
            origin=origin,
            create_if_missing=True,
            asset_info_row=info_row,
        )
        session.commit()
        return schemas_out.TagsAdd(**data)
|
||||
|
||||
|
||||
def remove_tags_from_asset(
    *,
    asset_info_id: str,
    tags: list[str],
    owner_id: str = "",
) -> schemas_out.TagsRemove:
    """Detach tags from an AssetInfo.

    Raises:
        ValueError: if no AssetInfo with this id exists.
        PermissionError: if the record is owned by a different user.
    """
    with create_session() as session:
        info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
        if not info_row:
            raise ValueError(f"AssetInfo {asset_info_id} not found")
        # Empty owner_id means unowned/shared; anyone may untag it.
        if info_row.owner_id and info_row.owner_id != owner_id:
            raise PermissionError("not owner")

        data = remove_tags_from_asset_info(
            session,
            asset_info_id=asset_info_id,
            tags=tags,
        )
        session.commit()
        return schemas_out.TagsRemove(**data)
|
||||
|
||||
|
||||
def list_tags(
|
||||
prefix: str | None = None,
|
||||
limit: int = 100,
|
||||
|
||||
@@ -22,6 +22,7 @@ alembic
|
||||
SQLAlchemy
|
||||
av>=14.2.0
|
||||
comfy-kitchen>=0.2.6
|
||||
blake3
|
||||
|
||||
#non essential dependencies:
|
||||
kornia>=0.7.1
|
||||
|
||||
0
tests-unit/assets_test/__init__.py
Normal file
0
tests-unit/assets_test/__init__.py
Normal file
104
tests-unit/assets_test/conftest.py
Normal file
104
tests-unit/assets_test/conftest.py
Normal file
@@ -0,0 +1,104 @@
|
||||
"""
|
||||
Pytest fixtures for assets API tests.
|
||||
"""
|
||||
|
||||
import io
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker, Session
|
||||
|
||||
from aiohttp import web
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def in_memory_engine():
    """Create an in-memory SQLite engine with all asset tables."""
    engine = create_engine("sqlite:///:memory:", echo=False)

    from app.database.models import Base
    # Model classes are imported for their side effect of registering
    # their tables on Base.metadata before create_all runs.
    from app.assets.database.models import (
        Asset,
        AssetInfo,
        AssetCacheState,
        AssetInfoMeta,
        AssetInfoTag,
        Tag,
    )

    Base.metadata.create_all(engine)

    yield engine

    engine.dispose()
|
||||
|
||||
|
||||
@pytest.fixture
def db_session(in_memory_engine) -> Session:
    """Create a fresh database session for each test.

    Rolls back any uncommitted work at teardown so tests that share the
    session-scoped engine do not leak state into each other.
    """
    SessionLocal = sessionmaker(bind=in_memory_engine)
    session = SessionLocal()

    yield session

    session.rollback()
    session.close()
|
||||
|
||||
|
||||
@pytest.fixture
def mock_user_manager():
    """Create a mock UserManager that returns a predictable owner_id."""
    mock = MagicMock()
    # Routes resolve the request's user via get_request_user_id; pin it
    # to a fixed value so ownership checks are deterministic in tests.
    mock.get_request_user_id = MagicMock(return_value="test-user-123")
    return mock
|
||||
|
||||
|
||||
@pytest.fixture
def app(mock_user_manager) -> web.Application:
    """Create an aiohttp Application with assets routes registered."""
    # Imported lazily so merely collecting tests does not import the
    # full assets route module.
    from app.assets.api.routes import register_assets_system

    application = web.Application()
    application = web.Application() if False else application  # NOTE(review): placeholder-free; see below
    register_assets_system(application, mock_user_manager)
    return application
|
||||
|
||||
|
||||
@pytest.fixture
def test_image_bytes() -> bytes:
    """Generate a minimal valid PNG image (10x10 red pixels)."""
    from PIL import Image

    img = Image.new("RGB", (10, 10), color="red")
    buffer = io.BytesIO()
    img.save(buffer, format="PNG")
    return buffer.getvalue()
|
||||
|
||||
|
||||
@pytest.fixture
def tmp_upload_dir(tmp_path):
    """Create a temporary directory for uploads and patch folder_paths."""
    upload_dir = tmp_path / "uploads"
    upload_dir.mkdir()

    # NOTE(review): only get_temp_directory is patched and the fixture
    # yields tmp_path, not upload_dir — confirm consumers expect that.
    with patch("folder_paths.get_temp_directory", return_value=str(tmp_path)):
        yield tmp_path
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def patch_create_session(in_memory_engine):
    """Patch create_session to use our in-memory database.

    NOTE(review): this patches names on app.database.db itself; modules
    that bound create_session via "from app.database.db import ..." at
    import time would not see the patch — confirm consumers import the
    module, not the name.
    """
    SessionLocal = sessionmaker(bind=in_memory_engine)

    with patch("app.database.db.Session", SessionLocal):
        with patch("app.database.db.create_session", lambda: SessionLocal()):
            with patch("app.database.db.can_create_session", return_value=True):
                yield
|
||||
|
||||
|
||||
async def test_fixtures_work(db_session, mock_user_manager):
    """Smoke test to verify fixtures are working."""
    assert db_session is not None
    assert mock_user_manager.get_request_user_id(None) == "test-user-123"
|
||||
317
tests-unit/assets_test/helpers_test.py
Normal file
317
tests-unit/assets_test/helpers_test.py
Normal file
@@ -0,0 +1,317 @@
|
||||
"""Tests for app.assets.helpers utility functions."""
|
||||
|
||||
import os
|
||||
import pytest
|
||||
from datetime import datetime, timezone
|
||||
from decimal import Decimal
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from app.assets.helpers import (
|
||||
normalize_tags,
|
||||
escape_like_prefix,
|
||||
ensure_within_base,
|
||||
get_query_dict,
|
||||
utcnow,
|
||||
project_kv,
|
||||
is_scalar,
|
||||
fast_asset_file_check,
|
||||
list_tree,
|
||||
RootType,
|
||||
ALLOWED_ROOTS,
|
||||
)
|
||||
|
||||
|
||||
class TestNormalizeTags:
    """normalize_tags: lowercases and strips each tag, drops empties,
    preserves order, and does NOT deduplicate."""

    def test_lowercases(self):
        assert normalize_tags(["FOO", "Bar"]) == ["foo", "bar"]

    def test_strips_whitespace(self):
        assert normalize_tags([" hello ", "world "]) == ["hello", "world"]

    def test_does_not_deduplicate(self):
        result = normalize_tags(["a", "A", "a"])
        assert result == ["a", "a", "a"]

    def test_none_returns_empty(self):
        assert normalize_tags(None) == []

    def test_empty_list_returns_empty(self):
        assert normalize_tags([]) == []

    def test_filters_empty_strings(self):
        assert normalize_tags(["a", "", " ", "b"]) == ["a", "b"]

    def test_preserves_order(self):
        result = normalize_tags(["Z", "A", "z", "B"])
        assert result == ["z", "a", "z", "b"]
|
||||
|
||||
|
||||
class TestEscapeLikePrefix:
    """escape_like_prefix: escapes SQL LIKE wildcards (%, _) and the
    escape character itself; returns (escaped_string, escape_char)."""

    def test_escapes_percent(self):
        result, esc = escape_like_prefix("50%")
        assert result == "50!%"
        assert esc == "!"

    def test_escapes_underscore(self):
        result, esc = escape_like_prefix("file_name")
        assert result == "file!_name"
        assert esc == "!"

    def test_escapes_escape_char(self):
        result, esc = escape_like_prefix("a!b")
        assert result == "a!!b"
        assert esc == "!"

    def test_normal_string_unchanged(self):
        result, esc = escape_like_prefix("hello")
        assert result == "hello"
        assert esc == "!"

    def test_complex_string(self):
        result, esc = escape_like_prefix("50%_!x")
        assert result == "50!%!_!!x"

    def test_custom_escape_char(self):
        result, esc = escape_like_prefix("50%", escape="\\")
        assert result == "50\\%"
        assert esc == "\\"
|
||||
|
||||
class TestEnsureWithinBase:
    """ensure_within_base: raises ValueError when a candidate path
    escapes the base directory (path-traversal guard)."""

    def test_valid_path_within_base(self, tmp_path):
        base = str(tmp_path)
        candidate = str(tmp_path / "subdir" / "file.txt")
        # Must not raise.
        ensure_within_base(candidate, base)

    def test_path_traversal_rejected(self, tmp_path):
        base = str(tmp_path / "safe")
        candidate = str(tmp_path / "safe" / ".." / "unsafe")
        with pytest.raises(ValueError, match="escapes base directory|invalid destination"):
            ensure_within_base(candidate, base)

    def test_completely_outside_path_rejected(self, tmp_path):
        base = str(tmp_path / "safe")
        candidate = "/etc/passwd"
        with pytest.raises(ValueError):
            ensure_within_base(candidate, base)

    def test_same_path_is_valid(self, tmp_path):
        base = str(tmp_path)
        # The base itself counts as "within" the base.
        ensure_within_base(base, base)
|
||||
|
||||
|
||||
class TestGetQueryDict:
    """get_query_dict: flattens an aiohttp multi-dict query into a plain
    dict, keeping a list only for repeated keys."""

    def test_single_values(self):
        request = MagicMock()
        request.query.keys.return_value = ["a", "b"]
        request.query.get.side_effect = lambda k: {"a": "1", "b": "2"}[k]
        request.query.getall.side_effect = lambda k: [{"a": "1", "b": "2"}[k]]

        result = get_query_dict(request)
        assert result == {"a": "1", "b": "2"}

    def test_multiple_values_same_key(self):
        request = MagicMock()
        request.query.keys.return_value = ["tags"]
        request.query.get.return_value = "tag1"
        request.query.getall.return_value = ["tag1", "tag2", "tag3"]

        result = get_query_dict(request)
        assert result == {"tags": ["tag1", "tag2", "tag3"]}

    def test_empty_query(self):
        request = MagicMock()
        request.query.keys.return_value = []

        result = get_query_dict(request)
        assert result == {}
|
||||
|
||||
|
||||
class TestUtcnow:
    """utcnow: returns the current UTC time as a naive datetime."""

    def test_returns_datetime(self):
        result = utcnow()
        assert isinstance(result, datetime)

    def test_no_tzinfo(self):
        # Naive by design (DB columns store naive UTC timestamps —
        # presumably; confirm against the models).
        result = utcnow()
        assert result.tzinfo is None

    def test_is_approximately_now(self):
        before = datetime.now(timezone.utc).replace(tzinfo=None)
        result = utcnow()
        after = datetime.now(timezone.utc).replace(tzinfo=None)
        assert before <= result <= after
|
||||
|
||||
|
||||
class TestIsScalar:
    """is_scalar: True for None/bool/int/float/Decimal/str, False for
    containers."""

    def test_none_is_scalar(self):
        assert is_scalar(None) is True

    def test_bool_is_scalar(self):
        assert is_scalar(True) is True
        assert is_scalar(False) is True

    def test_int_is_scalar(self):
        assert is_scalar(42) is True

    def test_float_is_scalar(self):
        assert is_scalar(3.14) is True

    def test_decimal_is_scalar(self):
        assert is_scalar(Decimal("10.5")) is True

    def test_str_is_scalar(self):
        assert is_scalar("hello") is True

    def test_list_is_not_scalar(self):
        assert is_scalar([1, 2, 3]) is False

    def test_dict_is_not_scalar(self):
        assert is_scalar({"a": 1}) is False
|
||||
|
||||
|
||||
class TestProjectKv:
    """project_kv: projects a metadata key/value into typed rows — one
    row per scalar (val_str/val_num/val_bool), one per list element with
    an ordinal, and val_json for dicts."""

    def test_none_value(self):
        result = project_kv("key", None)
        assert len(result) == 1
        assert result[0]["key"] == "key"
        assert result[0]["ordinal"] == 0
        assert result[0]["val_str"] is None
        assert result[0]["val_num"] is None

    def test_string_value(self):
        result = project_kv("name", "test")
        assert len(result) == 1
        assert result[0]["val_str"] == "test"

    def test_int_value(self):
        result = project_kv("count", 42)
        assert len(result) == 1
        assert result[0]["val_num"] == Decimal("42")

    def test_float_value(self):
        result = project_kv("ratio", 3.14)
        assert len(result) == 1
        assert result[0]["val_num"] == Decimal("3.14")

    def test_bool_value(self):
        result = project_kv("enabled", True)
        assert len(result) == 1
        assert result[0]["val_bool"] is True

    def test_list_of_strings(self):
        result = project_kv("tags", ["a", "b", "c"])
        assert len(result) == 3
        assert result[0]["ordinal"] == 0
        assert result[0]["val_str"] == "a"
        assert result[1]["ordinal"] == 1
        assert result[1]["val_str"] == "b"
        assert result[2]["ordinal"] == 2
        assert result[2]["val_str"] == "c"

    def test_list_of_mixed_scalars(self):
        result = project_kv("mixed", [1, "two", True])
        assert len(result) == 3
        assert result[0]["val_num"] == Decimal("1")
        assert result[1]["val_str"] == "two"
        assert result[2]["val_bool"] is True

    def test_list_with_none(self):
        result = project_kv("items", ["a", None, "b"])
        assert len(result) == 3
        assert result[1]["val_str"] is None
        assert result[1]["val_num"] is None

    def test_dict_value_stored_as_json(self):
        result = project_kv("meta", {"nested": "value"})
        assert len(result) == 1
        assert result[0]["val_json"] == {"nested": "value"}

    def test_list_of_dicts_stored_as_json(self):
        result = project_kv("items", [{"a": 1}, {"b": 2}])
        assert len(result) == 2
        assert result[0]["val_json"] == {"a": 1}
        assert result[1]["val_json"] == {"b": 2}
|
||||
|
||||
|
||||
class TestFastAssetFileCheck:
    """fast_asset_file_check: cheap mtime/size comparison between DB
    record and an os.stat result; size 0 in the DB skips the size check."""

    def test_none_mtime_returns_false(self):
        stat = MagicMock()
        assert fast_asset_file_check(mtime_db=None, size_db=100, stat_result=stat) is False

    def test_matching_mtime_and_size(self):
        stat = MagicMock()
        stat.st_mtime_ns = 1234567890123456789
        stat.st_size = 100

        result = fast_asset_file_check(
            mtime_db=1234567890123456789,
            size_db=100,
            stat_result=stat
        )
        assert result is True

    def test_mismatched_mtime(self):
        stat = MagicMock()
        stat.st_mtime_ns = 9999999999999999999
        stat.st_size = 100

        result = fast_asset_file_check(
            mtime_db=1234567890123456789,
            size_db=100,
            stat_result=stat
        )
        assert result is False

    def test_mismatched_size(self):
        stat = MagicMock()
        stat.st_mtime_ns = 1234567890123456789
        stat.st_size = 200

        result = fast_asset_file_check(
            mtime_db=1234567890123456789,
            size_db=100,
            stat_result=stat
        )
        assert result is False

    def test_zero_size_skips_size_check(self):
        stat = MagicMock()
        stat.st_mtime_ns = 1234567890123456789
        stat.st_size = 999

        result = fast_asset_file_check(
            mtime_db=1234567890123456789,
            size_db=0,
            stat_result=stat
        )
        assert result is True
|
||||
|
||||
|
||||
class TestListTree:
    """list_tree: recursively lists absolute file paths under a
    directory; a nonexistent root yields an empty list."""

    def test_lists_files_in_directory(self, tmp_path):
        (tmp_path / "file1.txt").touch()
        (tmp_path / "file2.txt").touch()
        subdir = tmp_path / "subdir"
        subdir.mkdir()
        (subdir / "file3.txt").touch()

        result = list_tree(str(tmp_path))
        assert len(result) == 3
        assert all(os.path.isabs(p) for p in result)
        assert str(tmp_path / "file1.txt") in result
        assert str(tmp_path / "subdir" / "file3.txt") in result

    def test_nonexistent_directory_returns_empty(self):
        result = list_tree("/nonexistent/path/that/does/not/exist")
        assert result == []
|
||||
|
||||
|
||||
class TestRootType:
    """ALLOWED_ROOTS: immutable tuple of the permitted asset root names."""

    def test_allowed_roots_contains_expected_values(self):
        assert "models" in ALLOWED_ROOTS
        assert "input" in ALLOWED_ROOTS
        assert "output" in ALLOWED_ROOTS

    def test_allowed_roots_is_tuple(self):
        assert isinstance(ALLOWED_ROOTS, tuple)
|
||||
597
tests-unit/assets_test/queries_crud_test.py
Normal file
597
tests-unit/assets_test/queries_crud_test.py
Normal file
@@ -0,0 +1,597 @@
|
||||
"""
|
||||
Tests for core CRUD database query functions in app.assets.database.queries.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from app.assets.database.queries import (
|
||||
asset_exists_by_hash,
|
||||
get_asset_by_hash,
|
||||
get_asset_info_by_id,
|
||||
create_asset_info_for_existing_asset,
|
||||
ingest_fs_asset,
|
||||
delete_asset_info_by_id,
|
||||
touch_asset_info_by_id,
|
||||
update_asset_info_full,
|
||||
fetch_asset_info_and_asset,
|
||||
fetch_asset_info_asset_and_tags,
|
||||
ensure_tags_exist,
|
||||
)
|
||||
from app.assets.database.models import Asset, AssetInfo, AssetCacheState
|
||||
|
||||
|
||||
def make_hash(seed: str = "a") -> str:
    """Return a deterministic fake blake3 hash built by repeating *seed*."""
    digest = seed * 64
    return f"blake3:{digest}"
|
||||
|
||||
|
||||
def make_unique_hash() -> str:
    """Return a random blake3-style hash (64 hex chars from two UUIDs)."""
    digest = "".join(uuid.uuid4().hex for _ in range(2))
    return "blake3:" + digest
|
||||
|
||||
|
||||
class TestAssetExistsByHash:
    """asset_exists_by_hash: existence check keyed on the content hash."""

    def test_returns_true_when_exists(self, db_session, tmp_path):
        test_file = tmp_path / "test.png"
        test_file.write_bytes(b"fake png data")
        asset_hash = make_unique_hash()

        ingest_fs_asset(
            db_session,
            asset_hash=asset_hash,
            abs_path=str(test_file),
            size_bytes=len(b"fake png data"),
            mtime_ns=1000000,
            mime_type="image/png",
        )
        db_session.flush()

        assert asset_exists_by_hash(db_session, asset_hash=asset_hash) is True

    def test_returns_false_when_missing(self, db_session):
        assert asset_exists_by_hash(db_session, asset_hash=make_unique_hash()) is False
|
||||
|
||||
|
||||
class TestGetAssetByHash:
    """get_asset_by_hash: fetches the Asset row for a content hash."""

    def test_returns_asset_when_exists(self, db_session, tmp_path):
        test_file = tmp_path / "test.png"
        test_file.write_bytes(b"test data")
        asset_hash = make_unique_hash()

        ingest_fs_asset(
            db_session,
            asset_hash=asset_hash,
            abs_path=str(test_file),
            size_bytes=9,
            mtime_ns=1000000,
            mime_type="image/png",
        )
        db_session.flush()

        asset = get_asset_by_hash(db_session, asset_hash=asset_hash)
        assert asset is not None
        assert asset.hash == asset_hash
        assert asset.size_bytes == 9
        assert asset.mime_type == "image/png"

    def test_returns_none_when_missing(self, db_session):
        result = get_asset_by_hash(db_session, asset_hash=make_unique_hash())
        assert result is None
|
||||
|
||||
|
||||
class TestGetAssetInfoById:
    """get_asset_info_by_id: fetches an AssetInfo row by primary key."""

    def test_returns_asset_info_when_exists(self, db_session, tmp_path):
        test_file = tmp_path / "test.png"
        test_file.write_bytes(b"test data")
        asset_hash = make_unique_hash()

        result = ingest_fs_asset(
            db_session,
            asset_hash=asset_hash,
            abs_path=str(test_file),
            size_bytes=9,
            mtime_ns=1000000,
            info_name="my-asset",
            owner_id="user1",
        )
        db_session.flush()

        info = get_asset_info_by_id(db_session, asset_info_id=result["asset_info_id"])
        assert info is not None
        assert info.name == "my-asset"
        assert info.owner_id == "user1"

    def test_returns_none_when_missing(self, db_session):
        fake_id = str(uuid.uuid4())
        result = get_asset_info_by_id(db_session, asset_info_id=fake_id)
        assert result is None
|
||||
|
||||
|
||||
class TestCreateAssetInfoForExistingAsset:
    """create_asset_info_for_existing_asset: links a new AssetInfo to an
    already-ingested Asset; rejects unknown hashes; idempotent for the
    same (name, owner) pair."""

    def test_creates_linked_asset_info(self, db_session, tmp_path):
        test_file = tmp_path / "test.png"
        test_file.write_bytes(b"test data")
        asset_hash = make_unique_hash()

        ingest_fs_asset(
            db_session,
            asset_hash=asset_hash,
            abs_path=str(test_file),
            size_bytes=9,
            mtime_ns=1000000,
        )
        db_session.flush()

        info = create_asset_info_for_existing_asset(
            db_session,
            asset_hash=asset_hash,
            name="new-info",
            owner_id="owner123",
            user_metadata={"key": "value"},
        )
        db_session.flush()

        assert info is not None
        assert info.name == "new-info"
        assert info.owner_id == "owner123"

        # The new info must point at the same Asset row.
        asset = get_asset_by_hash(db_session, asset_hash=asset_hash)
        assert info.asset_id == asset.id

    def test_raises_on_unknown_hash(self, db_session):
        with pytest.raises(ValueError, match="Unknown asset hash"):
            create_asset_info_for_existing_asset(
                db_session,
                asset_hash=make_unique_hash(),
                name="test",
            )

    def test_returns_existing_on_duplicate(self, db_session, tmp_path):
        test_file = tmp_path / "test.png"
        test_file.write_bytes(b"test data")
        asset_hash = make_unique_hash()

        ingest_fs_asset(
            db_session,
            asset_hash=asset_hash,
            abs_path=str(test_file),
            size_bytes=9,
            mtime_ns=1000000,
        )
        db_session.flush()

        info1 = create_asset_info_for_existing_asset(
            db_session,
            asset_hash=asset_hash,
            name="same-name",
            owner_id="owner1",
        )
        db_session.flush()

        info2 = create_asset_info_for_existing_asset(
            db_session,
            asset_hash=asset_hash,
            name="same-name",
            owner_id="owner1",
        )
        db_session.flush()

        assert info1.id == info2.id
|
||||
|
||||
|
||||
class TestIngestFsAsset:
    """ingest_fs_asset: registers a file on disk as Asset + AssetInfo +
    AssetCacheState in one call; idempotent for identical inputs."""

    def test_creates_all_records(self, db_session, tmp_path):
        test_file = tmp_path / "test.png"
        test_file.write_bytes(b"fake png data")
        asset_hash = make_unique_hash()

        result = ingest_fs_asset(
            db_session,
            asset_hash=asset_hash,
            abs_path=str(test_file),
            size_bytes=len(b"fake png data"),
            mtime_ns=1000000,
            mime_type="image/png",
            info_name="test-asset",
            owner_id="user1",
        )
        db_session.flush()

        assert result["asset_created"] is True
        assert result["state_created"] is True
        assert result["asset_info_id"] is not None

        asset = get_asset_by_hash(db_session, asset_hash=asset_hash)
        assert asset is not None
        assert asset.size_bytes == len(b"fake png data")

        info = get_asset_info_by_id(db_session, asset_info_id=result["asset_info_id"])
        assert info is not None
        assert info.name == "test-asset"

        # Exactly one cache state should record the on-disk location.
        cache_states = db_session.query(AssetCacheState).filter_by(asset_id=asset.id).all()
        assert len(cache_states) == 1
        assert cache_states[0].file_path == str(test_file)

    def test_idempotent_on_same_file(self, db_session, tmp_path):
        test_file = tmp_path / "test.png"
        test_file.write_bytes(b"data")
        asset_hash = make_unique_hash()

        result1 = ingest_fs_asset(
            db_session,
            asset_hash=asset_hash,
            abs_path=str(test_file),
            size_bytes=4,
            mtime_ns=1000000,
            info_name="test",
        )
        db_session.flush()

        result2 = ingest_fs_asset(
            db_session,
            asset_hash=asset_hash,
            abs_path=str(test_file),
            size_bytes=4,
            mtime_ns=1000000,
            info_name="test",
        )
        db_session.flush()

        assert result1["asset_info_id"] == result2["asset_info_id"]
        assert result2["asset_created"] is False

    def test_creates_with_tags(self, db_session, tmp_path):
        test_file = tmp_path / "test.png"
        test_file.write_bytes(b"data")
        asset_hash = make_unique_hash()

        result = ingest_fs_asset(
            db_session,
            asset_hash=asset_hash,
            abs_path=str(test_file),
            size_bytes=4,
            mtime_ns=1000000,
            info_name="test",
            tags=["tag1", "tag2"],
        )
        db_session.flush()

        info, asset, tags = fetch_asset_info_asset_and_tags(
            db_session,
            asset_info_id=result["asset_info_id"],
        )
        assert set(tags) == {"tag1", "tag2"}
|
||||
|
||||
|
||||
class TestDeleteAssetInfoById:
    """delete_asset_info_by_id: deletes only rows visible to the given
    owner; returns whether a row was removed."""

    def test_deletes_existing_record(self, db_session, tmp_path):
        test_file = tmp_path / "test.png"
        test_file.write_bytes(b"data")
        asset_hash = make_unique_hash()

        result = ingest_fs_asset(
            db_session,
            asset_hash=asset_hash,
            abs_path=str(test_file),
            size_bytes=4,
            mtime_ns=1000000,
            info_name="to-delete",
            owner_id="user1",
        )
        db_session.flush()

        deleted = delete_asset_info_by_id(
            db_session,
            asset_info_id=result["asset_info_id"],
            owner_id="user1",
        )
        db_session.flush()

        assert deleted is True
        assert get_asset_info_by_id(db_session, asset_info_id=result["asset_info_id"]) is None

    def test_returns_false_for_nonexistent(self, db_session):
        result = delete_asset_info_by_id(
            db_session,
            asset_info_id=str(uuid.uuid4()),
            owner_id="user1",
        )
        assert result is False

    def test_respects_owner_visibility(self, db_session, tmp_path):
        test_file = tmp_path / "test.png"
        test_file.write_bytes(b"data")
        asset_hash = make_unique_hash()

        result = ingest_fs_asset(
            db_session,
            asset_hash=asset_hash,
            abs_path=str(test_file),
            size_bytes=4,
            mtime_ns=1000000,
            info_name="owned-asset",
            owner_id="user1",
        )
        db_session.flush()

        # A different owner must neither delete nor observe the row gone.
        deleted = delete_asset_info_by_id(
            db_session,
            asset_info_id=result["asset_info_id"],
            owner_id="different-user",
        )
        assert deleted is False

        assert get_asset_info_by_id(db_session, asset_info_id=result["asset_info_id"]) is not None
|
||||
|
||||
|
||||
class TestTouchAssetInfoById:
    """touch_asset_info_by_id: bumps last_access_time; with
    only_if_newer=True an older timestamp is ignored."""

    def test_updates_last_access_time(self, db_session, tmp_path):
        test_file = tmp_path / "test.png"
        test_file.write_bytes(b"data")
        asset_hash = make_unique_hash()

        result = ingest_fs_asset(
            db_session,
            asset_hash=asset_hash,
            abs_path=str(test_file),
            size_bytes=4,
            mtime_ns=1000000,
            info_name="test",
        )
        db_session.flush()

        info_before = get_asset_info_by_id(db_session, asset_info_id=result["asset_info_id"])
        original_time = info_before.last_access_time

        # Timestamps are naive UTC in the DB, hence tzinfo=None.
        new_time = datetime.now(timezone.utc).replace(tzinfo=None) + timedelta(hours=1)
        touch_asset_info_by_id(
            db_session,
            asset_info_id=result["asset_info_id"],
            ts=new_time,
        )
        db_session.flush()
        db_session.expire_all()

        info_after = get_asset_info_by_id(db_session, asset_info_id=result["asset_info_id"])
        assert info_after.last_access_time == new_time
        assert info_after.last_access_time > original_time

    def test_only_if_newer_respects_flag(self, db_session, tmp_path):
        test_file = tmp_path / "test.png"
        test_file.write_bytes(b"data")
        asset_hash = make_unique_hash()

        result = ingest_fs_asset(
            db_session,
            asset_hash=asset_hash,
            abs_path=str(test_file),
            size_bytes=4,
            mtime_ns=1000000,
            info_name="test",
        )
        db_session.flush()

        info = get_asset_info_by_id(db_session, asset_info_id=result["asset_info_id"])
        original_time = info.last_access_time

        older_time = original_time - timedelta(hours=1)
        touch_asset_info_by_id(
            db_session,
            asset_info_id=result["asset_info_id"],
            ts=older_time,
            only_if_newer=True,
        )
        db_session.flush()
        db_session.expire_all()

        info_after = get_asset_info_by_id(db_session, asset_info_id=result["asset_info_id"])
        assert info_after.last_access_time == original_time
|
||||
|
||||
|
||||
class TestUpdateAssetInfoFull:
    """Tests for update_asset_info_full: replacing name/tags/metadata on an AssetInfo."""

    def test_updates_name(self, db_session, tmp_path):
        """A new name is persisted and present on the returned row."""
        test_file = tmp_path / "test.png"
        test_file.write_bytes(b"data")
        asset_hash = make_unique_hash()

        result = ingest_fs_asset(
            db_session,
            asset_hash=asset_hash,
            abs_path=str(test_file),
            size_bytes=4,
            mtime_ns=1000000,
            info_name="original-name",
        )
        db_session.flush()

        updated = update_asset_info_full(
            db_session,
            asset_info_id=result["asset_info_id"],
            name="new-name",
        )
        db_session.flush()

        assert updated.name == "new-name"

    def test_updates_tags(self, db_session, tmp_path):
        """Passing tags replaces the asset's tag set wholesale."""
        test_file = tmp_path / "test.png"
        test_file.write_bytes(b"data")
        asset_hash = make_unique_hash()

        result = ingest_fs_asset(
            db_session,
            asset_hash=asset_hash,
            abs_path=str(test_file),
            size_bytes=4,
            mtime_ns=1000000,
            info_name="test",
        )
        db_session.flush()

        update_asset_info_full(
            db_session,
            asset_info_id=result["asset_info_id"],
            tags=["newtag1", "newtag2"],
        )
        db_session.flush()

        # Re-read through the fetch helper so the assertion covers what is
        # stored, not just the in-memory object.
        _, _, tags = fetch_asset_info_asset_and_tags(
            db_session,
            asset_info_id=result["asset_info_id"],
        )
        assert set(tags) == {"newtag1", "newtag2"}

    def test_updates_metadata(self, db_session, tmp_path):
        """user_metadata updates are persisted to the database."""
        test_file = tmp_path / "test.png"
        test_file.write_bytes(b"data")
        asset_hash = make_unique_hash()

        result = ingest_fs_asset(
            db_session,
            asset_hash=asset_hash,
            abs_path=str(test_file),
            size_bytes=4,
            mtime_ns=1000000,
            info_name="test",
        )
        db_session.flush()

        update_asset_info_full(
            db_session,
            asset_info_id=result["asset_info_id"],
            user_metadata={"custom_key": "custom_value"},
        )
        db_session.flush()
        # Expire ORM state so the assertions read persisted values.
        db_session.expire_all()

        info = get_asset_info_by_id(db_session, asset_info_id=result["asset_info_id"])
        assert "custom_key" in info.user_metadata
        assert info.user_metadata["custom_key"] == "custom_value"

    def test_raises_on_invalid_id(self, db_session):
        """Updating a nonexistent AssetInfo raises ValueError mentioning 'not found'."""
        with pytest.raises(ValueError, match="not found"):
            update_asset_info_full(
                db_session,
                asset_info_id=str(uuid.uuid4()),
                name="test",
            )
|
||||
|
||||
|
||||
class TestFetchAssetInfoAndAsset:
    """Tests for fetch_asset_info_and_asset: (AssetInfo, Asset) lookup by info id."""

    def test_returns_tuple_when_exists(self, db_session, tmp_path):
        """A stored asset comes back as an (info, asset) pair with its fields intact."""
        png_path = tmp_path / "test.png"
        png_path.write_bytes(b"data")
        digest = make_unique_hash()

        ingested = ingest_fs_asset(
            db_session,
            asset_hash=digest,
            abs_path=str(png_path),
            size_bytes=4,
            mtime_ns=1000000,
            info_name="test",
            mime_type="image/png",
        )
        db_session.flush()

        pair = fetch_asset_info_and_asset(
            db_session,
            asset_info_id=ingested["asset_info_id"],
        )

        assert pair is not None
        info_row, asset_row = pair
        assert info_row.name == "test"
        assert asset_row.hash == digest
        assert asset_row.mime_type == "image/png"

    def test_returns_none_when_missing(self, db_session):
        """An unknown id yields None rather than raising."""
        missing = fetch_asset_info_and_asset(
            db_session,
            asset_info_id=str(uuid.uuid4()),
        )
        assert missing is None

    def test_respects_owner_visibility(self, db_session, tmp_path):
        """A different owner cannot see user1's private asset."""
        png_path = tmp_path / "test.png"
        png_path.write_bytes(b"data")
        digest = make_unique_hash()

        ingested = ingest_fs_asset(
            db_session,
            asset_hash=digest,
            abs_path=str(png_path),
            size_bytes=4,
            mtime_ns=1000000,
            info_name="test",
            owner_id="user1",
        )
        db_session.flush()

        hidden = fetch_asset_info_and_asset(
            db_session,
            asset_info_id=ingested["asset_info_id"],
            owner_id="different-user",
        )
        assert hidden is None
|
||||
|
||||
|
||||
class TestFetchAssetInfoAssetAndTags:
    """Tests for fetch_asset_info_asset_and_tags: (AssetInfo, Asset, tag names)."""

    def test_returns_tuple_with_tags(self, db_session, tmp_path):
        """Returns the info row, the asset row, and all associated tag names."""
        test_file = tmp_path / "test.png"
        test_file.write_bytes(b"data")
        asset_hash = make_unique_hash()

        result = ingest_fs_asset(
            db_session,
            asset_hash=asset_hash,
            abs_path=str(test_file),
            size_bytes=4,
            mtime_ns=1000000,
            info_name="test",
            tags=["alpha", "beta"],
        )
        db_session.flush()

        fetched = fetch_asset_info_asset_and_tags(
            db_session,
            asset_info_id=result["asset_info_id"],
        )

        assert fetched is not None
        info, asset, tags = fetched
        assert info.name == "test"
        assert asset.hash == asset_hash
        # Tag ordering is not asserted; only membership.
        assert set(tags) == {"alpha", "beta"}

    def test_returns_empty_tags_when_none(self, db_session, tmp_path):
        """An asset ingested without tags yields an empty tag list, not None."""
        test_file = tmp_path / "test.png"
        test_file.write_bytes(b"data")
        asset_hash = make_unique_hash()

        result = ingest_fs_asset(
            db_session,
            asset_hash=asset_hash,
            abs_path=str(test_file),
            size_bytes=4,
            mtime_ns=1000000,
            info_name="test",
        )
        db_session.flush()

        fetched = fetch_asset_info_asset_and_tags(
            db_session,
            asset_info_id=result["asset_info_id"],
        )

        assert fetched is not None
        info, asset, tags = fetched
        assert tags == []

    def test_returns_none_when_missing(self, db_session):
        """An unknown asset_info_id yields None rather than raising."""
        result = fetch_asset_info_asset_and_tags(
            db_session,
            asset_info_id=str(uuid.uuid4()),
        )
        assert result is None
|
||||
# ==== diff boundary: new file tests-unit/assets_test/queries_filter_test.py (471 added lines, hunk @@ -0,0 +1,471 @@) ====
|
||||
"""
|
||||
Tests for filtering and pagination query functions in app.assets.database.queries.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import create_engine, delete
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from app.assets.database.models import Asset, AssetInfo, AssetInfoTag, AssetInfoMeta, AssetCacheState, Tag
|
||||
from app.assets.database.queries import (
|
||||
apply_metadata_filter,
|
||||
apply_tag_filters,
|
||||
ingest_fs_asset,
|
||||
list_asset_infos_page,
|
||||
replace_asset_info_metadata_projection,
|
||||
visible_owner_clause,
|
||||
)
|
||||
from app.assets.helpers import utcnow
|
||||
from sqlalchemy import select
|
||||
from app.database.models import Base
|
||||
|
||||
|
||||
@pytest.fixture
def clean_db_session():
    """Yield a Session bound to a brand-new in-memory SQLite database.

    Fix: the original ran teardown as bare statements after ``yield``. When a
    test body raises, pytest throws the exception into the generator at the
    yield point, so those statements never ran and the engine leaked. The
    try/finally guarantees rollback/close/dispose in every outcome.
    """
    engine = create_engine("sqlite:///:memory:", echo=False)
    Base.metadata.create_all(engine)
    SessionLocal = sessionmaker(bind=engine)
    session = SessionLocal()
    try:
        yield session
    finally:
        # Discard any uncommitted work, release the connection, and drop the
        # engine's pool so no state leaks between tests.
        session.rollback()
        session.close()
        engine.dispose()
|
||||
|
||||
|
||||
def seed_assets(
    session: Session,
    tmp_path: Path,
    count: int = 10,
    owner_id: str = "",
    tag_sets: list[list[str]] | None = None,
) -> list[str]:
    """Populate the database with `count` small on-disk PNG assets.

    Tags cycle through `tag_sets` when given; otherwise even-indexed assets
    get ["input"] and odd-indexed get ["models", "loras"]. Commits once at
    the end and returns the created asset_info ids.
    """
    created_ids: list[str] = []
    for idx in range(count):
        asset_file = tmp_path / f"test_{idx}.png"
        asset_file.write_bytes(b"x" * (100 + idx))
        # Hash over a fresh UUID so every seeded asset is unique.
        digest = hashlib.sha256(f"unique-{uuid.uuid4()}".encode()).hexdigest()

        if tag_sets is None:
            applied_tags = ["input"] if idx % 2 == 0 else ["models", "loras"]
        else:
            applied_tags = tag_sets[idx % len(tag_sets)]

        outcome = ingest_fs_asset(
            session,
            asset_hash=digest,
            abs_path=str(asset_file),
            size_bytes=100 + idx,
            mtime_ns=1000000000 + idx,
            mime_type="image/png",
            info_name=f"test_asset_{idx}.png",
            owner_id=owner_id,
            tags=applied_tags,
        )
        new_id = outcome.get("asset_info_id")
        if new_id:
            created_ids.append(new_id)

    session.commit()
    return created_ids
|
||||
|
||||
|
||||
class TestListAssetInfosPage:
    """Tests for list_asset_infos_page pagination and its (infos, tag_map, total) shape."""

    @pytest.fixture
    def seeded_db(self, clean_db_session, tmp_path):
        # 15 public (owner_id="") assets to page through.
        seed_assets(clean_db_session, tmp_path, 15, owner_id="")
        return clean_db_session

    def test_pagination_limit(self, seeded_db):
        """limit caps the page size while total reflects the full result set."""
        infos, _, total = list_asset_infos_page(
            seeded_db, owner_id="", limit=5, offset=0
        )
        assert len(infos) <= 5
        assert total >= 5

    def test_pagination_offset(self, seeded_db):
        """Consecutive pages share no rows."""
        first_page, _, total = list_asset_infos_page(
            seeded_db, owner_id="", limit=5, offset=0
        )
        second_page, _, _ = list_asset_infos_page(
            seeded_db, owner_id="", limit=5, offset=5
        )

        first_ids = {i.id for i in first_page}
        second_ids = {i.id for i in second_page}
        assert first_ids.isdisjoint(second_ids)

    def test_returns_tuple_with_tag_map(self, seeded_db):
        """Result shape is (list of infos, {info_id: [tag, ...]}, int total)."""
        infos, tag_map, total = list_asset_infos_page(
            seeded_db, owner_id="", limit=10, offset=0
        )
        assert isinstance(infos, list)
        assert isinstance(tag_map, dict)
        assert isinstance(total, int)

        # tag_map may omit untagged infos; present entries must be lists.
        for info in infos:
            if info.id in tag_map:
                assert isinstance(tag_map[info.id], list)

    def test_total_count_matches(self, seeded_db):
        """total equals the number of seeded assets when no filters apply."""
        _, _, total = list_asset_infos_page(seeded_db, owner_id="", limit=100, offset=0)
        assert total == 15
|
||||
|
||||
|
||||
class TestApplyTagFilters:
    """Tests for include/exclude tag filtering in list_asset_infos_page."""

    @pytest.fixture
    def tagged_db(self, clean_db_session, tmp_path):
        # Five assets, one per tag combination, covering all overlap cases.
        tag_sets = [
            ["alpha", "beta"],
            ["alpha", "gamma"],
            ["beta", "gamma"],
            ["alpha", "beta", "gamma"],
            ["delta"],
        ]
        seed_assets(clean_db_session, tmp_path, 5, owner_id="", tag_sets=tag_sets)
        return clean_db_session

    def test_include_tags_requires_all(self, tagged_db):
        """include_tags uses AND semantics: every listed tag must be present."""
        infos, tag_map, _ = list_asset_infos_page(
            tagged_db,
            owner_id="",
            include_tags=["alpha", "beta"],
            limit=100,
        )
        for info in infos:
            tags = tag_map.get(info.id, [])
            assert "alpha" in tags and "beta" in tags

    def test_include_single_tag(self, tagged_db):
        """A single include tag matches every asset carrying it."""
        infos, tag_map, total = list_asset_infos_page(
            tagged_db,
            owner_id="",
            include_tags=["alpha"],
            limit=100,
        )
        assert total >= 1
        for info in infos:
            tags = tag_map.get(info.id, [])
            assert "alpha" in tags

    def test_exclude_tags_excludes_any(self, tagged_db):
        """Assets carrying an excluded tag never appear in the results."""
        infos, tag_map, _ = list_asset_infos_page(
            tagged_db,
            owner_id="",
            exclude_tags=["delta"],
            limit=100,
        )
        for info in infos:
            tags = tag_map.get(info.id, [])
            assert "delta" not in tags

    def test_exclude_multiple_tags(self, tagged_db):
        """exclude_tags uses OR semantics: any listed tag disqualifies an asset."""
        infos, tag_map, _ = list_asset_infos_page(
            tagged_db,
            owner_id="",
            exclude_tags=["alpha", "delta"],
            limit=100,
        )
        for info in infos:
            tags = tag_map.get(info.id, [])
            assert "alpha" not in tags
            assert "delta" not in tags

    def test_combine_include_and_exclude(self, tagged_db):
        """Include and exclude filters compose in the same query."""
        infos, tag_map, _ = list_asset_infos_page(
            tagged_db,
            owner_id="",
            include_tags=["alpha"],
            exclude_tags=["gamma"],
            limit=100,
        )
        for info in infos:
            tags = tag_map.get(info.id, [])
            assert "alpha" in tags
            assert "gamma" not in tags
|
||||
|
||||
|
||||
class TestApplyMetadataFilter:
    """Tests for user_metadata filtering in list_asset_infos_page."""

    @pytest.fixture
    def metadata_db(self, clean_db_session, tmp_path):
        # Five assets, each assigned one known metadata dict (same order as ids).
        ids = seed_assets(clean_db_session, tmp_path, 5, owner_id="")
        metadata_sets = [
            {"author": "alice", "version": 1},
            {"author": "bob", "version": 2},
            {"author": "alice", "version": 2},
            {"author": "charlie", "active": True},
            {"author": "alice", "active": False},
        ]
        for i, info_id in enumerate(ids):
            replace_asset_info_metadata_projection(
                clean_db_session,
                asset_info_id=info_id,
                user_metadata=metadata_sets[i],
            )
        clean_db_session.commit()
        return clean_db_session

    def test_filter_by_string_value(self, metadata_db):
        """String equality filter: three assets have author == "alice"."""
        infos, _, total = list_asset_infos_page(
            metadata_db,
            owner_id="",
            metadata_filter={"author": "alice"},
            limit=100,
        )
        assert total == 3
        for info in infos:
            assert info.user_metadata.get("author") == "alice"

    def test_filter_by_numeric_value(self, metadata_db):
        """Numeric equality filter: two assets have version == 2."""
        infos, _, total = list_asset_infos_page(
            metadata_db,
            owner_id="",
            metadata_filter={"version": 2},
            limit=100,
        )
        assert total == 2

    def test_filter_by_boolean_value(self, metadata_db):
        """Boolean filter distinguishes True from False: only one active asset."""
        infos, _, total = list_asset_infos_page(
            metadata_db,
            owner_id="",
            metadata_filter={"active": True},
            limit=100,
        )
        assert total == 1

    def test_filter_by_multiple_keys(self, metadata_db):
        """Multiple keys are ANDed: exactly one asset is alice + version 2."""
        infos, _, total = list_asset_infos_page(
            metadata_db,
            owner_id="",
            metadata_filter={"author": "alice", "version": 2},
            limit=100,
        )
        assert total == 1

    def test_filter_with_list_values(self, metadata_db):
        """A list value acts as an IN filter: alice (3) + bob (1) = 4 matches."""
        infos, _, total = list_asset_infos_page(
            metadata_db,
            owner_id="",
            metadata_filter={"author": ["alice", "bob"]},
            limit=100,
        )
        assert total == 4
|
||||
|
||||
|
||||
class TestVisibleOwnerClause:
    """Tests that listing shows only the caller's own assets plus public ones."""

    @pytest.fixture
    def multi_owner_db(self, clean_db_session, tmp_path):
        # 3 assets for user1, 2 for user2, 4 public (owner_id="").
        seed_assets(clean_db_session, tmp_path, 3, owner_id="user1")
        seed_assets(clean_db_session, tmp_path, 2, owner_id="user2")
        seed_assets(clean_db_session, tmp_path, 4, owner_id="")
        return clean_db_session

    def test_empty_owner_sees_only_public(self, multi_owner_db):
        """An anonymous caller (owner_id="") sees only the 4 public assets."""
        infos, _, total = list_asset_infos_page(
            multi_owner_db,
            owner_id="",
            limit=100,
        )
        assert total == 4
        for info in infos:
            assert info.owner_id == ""

    def test_owner_sees_own_plus_public(self, multi_owner_db):
        """user1 sees its 3 own assets plus the 4 public ones."""
        infos, _, total = list_asset_infos_page(
            multi_owner_db,
            owner_id="user1",
            limit=100,
        )
        assert total == 7
        owner_ids = {info.owner_id for info in infos}
        assert owner_ids == {"user1", ""}

    def test_owner_sees_only_own_and_public(self, multi_owner_db):
        """user2 sees its 2 own assets plus the 4 public ones, never user1's."""
        infos, _, total = list_asset_infos_page(
            multi_owner_db,
            owner_id="user2",
            limit=100,
        )
        assert total == 6
        owner_ids = {info.owner_id for info in infos}
        assert owner_ids == {"user2", ""}
        assert all(info.owner_id in ("user2", "") for info in infos)

    def test_nonexistent_owner_sees_public(self, multi_owner_db):
        """An unknown owner id degrades to public-only visibility."""
        infos, _, total = list_asset_infos_page(
            multi_owner_db,
            owner_id="unknown-user",
            limit=100,
        )
        assert total == 4
        for info in infos:
            assert info.owner_id == ""
|
||||
|
||||
|
||||
class TestSorting:
    """Tests for the sort/order parameters of list_asset_infos_page."""

    @pytest.fixture
    def sortable_db(self, clean_db_session, tmp_path):
        import time

        ids = []
        # Names and sizes deliberately out of order so sorting is observable.
        names = ["zebra.png", "alpha.png", "mango.png"]
        sizes = [500, 100, 300]

        for i, name in enumerate(names):
            f = tmp_path / f"sort_{i}.png"
            f.write_bytes(b"x" * sizes[i])
            asset_hash = hashlib.sha256(f"sort-{uuid.uuid4()}".encode()).hexdigest()

            result = ingest_fs_asset(
                clean_db_session,
                asset_hash=asset_hash,
                abs_path=str(f),
                size_bytes=sizes[i],
                mtime_ns=1000000000 + i,
                mime_type="image/png",
                info_name=name,
                owner_id="",
                tags=["test"],
            )
            ids.append(result["asset_info_id"])
            # Small delay so created_at/updated_at timestamps are distinct.
            time.sleep(0.01)

        clean_db_session.commit()
        return clean_db_session

    def test_sort_by_name_asc(self, sortable_db):
        """sort="name", order="asc" returns names in ascending order."""
        infos, _, _ = list_asset_infos_page(
            sortable_db,
            owner_id="",
            sort="name",
            order="asc",
            limit=100,
        )
        names = [i.name for i in infos]
        assert names == sorted(names)

    def test_sort_by_name_desc(self, sortable_db):
        """sort="name", order="desc" returns names in descending order."""
        infos, _, _ = list_asset_infos_page(
            sortable_db,
            owner_id="",
            sort="name",
            order="desc",
            limit=100,
        )
        names = [i.name for i in infos]
        assert names == sorted(names, reverse=True)

    def test_sort_by_size(self, sortable_db):
        """sort="size" orders by the related Asset's size_bytes."""
        infos, _, _ = list_asset_infos_page(
            sortable_db,
            owner_id="",
            sort="size",
            order="asc",
            limit=100,
        )
        sizes = [i.asset.size_bytes for i in infos]
        assert sizes == sorted(sizes)

    def test_sort_by_created_at_desc(self, sortable_db):
        """sort="created_at" descending returns newest first."""
        infos, _, _ = list_asset_infos_page(
            sortable_db,
            owner_id="",
            sort="created_at",
            order="desc",
            limit=100,
        )
        dates = [i.created_at for i in infos]
        assert dates == sorted(dates, reverse=True)

    def test_sort_by_updated_at(self, sortable_db):
        """sort="updated_at" descending returns most recently updated first."""
        infos, _, _ = list_asset_infos_page(
            sortable_db,
            owner_id="",
            sort="updated_at",
            order="desc",
            limit=100,
        )
        dates = [i.updated_at for i in infos]
        assert dates == sorted(dates, reverse=True)

    def test_sort_by_last_access_time(self, sortable_db):
        """sort="last_access_time" ascending returns least recently accessed first."""
        infos, _, _ = list_asset_infos_page(
            sortable_db,
            owner_id="",
            sort="last_access_time",
            order="asc",
            limit=100,
        )
        times = [i.last_access_time for i in infos]
        assert times == sorted(times)

    def test_invalid_sort_defaults_to_created_at(self, sortable_db):
        """An unrecognized sort key falls back to created_at rather than erroring."""
        infos, _, _ = list_asset_infos_page(
            sortable_db,
            owner_id="",
            sort="invalid_column",
            order="desc",
            limit=100,
        )
        dates = [i.created_at for i in infos]
        assert dates == sorted(dates, reverse=True)
|
||||
|
||||
|
||||
class TestNameContainsFilter:
    """Tests for the case-insensitive name_contains substring filter."""

    @pytest.fixture
    def named_db(self, clean_db_session, tmp_path):
        # Two "report" names plus two non-matching controls.
        names = ["report_2023.pdf", "report_2024.pdf", "image.png", "data.csv"]
        for i, name in enumerate(names):
            f = tmp_path / f"file_{i}.bin"
            f.write_bytes(b"x" * 100)
            asset_hash = hashlib.sha256(f"named-{uuid.uuid4()}".encode()).hexdigest()
            ingest_fs_asset(
                clean_db_session,
                asset_hash=asset_hash,
                abs_path=str(f),
                size_bytes=100,
                mtime_ns=1000000000,
                mime_type="application/octet-stream",
                info_name=name,
                owner_id="",
                tags=["test"],
            )
        clean_db_session.commit()
        return clean_db_session

    def test_name_contains_filter(self, named_db):
        """Substring match returns only names containing the needle."""
        infos, _, total = list_asset_infos_page(
            named_db,
            owner_id="",
            name_contains="report",
            limit=100,
        )
        assert total == 2
        for info in infos:
            assert "report" in info.name.lower()

    def test_name_contains_case_insensitive(self, named_db):
        """Matching ignores case: "REPORT" finds the lowercase names."""
        infos, _, total = list_asset_infos_page(
            named_db,
            owner_id="",
            name_contains="REPORT",
            limit=100,
        )
        assert total == 2

    def test_name_contains_partial_match(self, named_db):
        """Any substring works, including extension fragments like ".p"."""
        infos, _, total = list_asset_infos_page(
            named_db,
            owner_id="",
            name_contains=".p",
            limit=100,
        )
        assert total >= 1
|
||||
# ==== diff boundary: new file tests-unit/assets_test/queries_tags_test.py (380 added lines, hunk @@ -0,0 +1,380 @@) ====
|
||||
"""
|
||||
Tests for tag-related database query functions in app.assets.database.queries.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import uuid
|
||||
|
||||
from app.assets.database.queries import (
|
||||
add_tags_to_asset_info,
|
||||
remove_tags_from_asset_info,
|
||||
get_asset_tags,
|
||||
list_tags_with_usage,
|
||||
set_asset_info_preview,
|
||||
ingest_fs_asset,
|
||||
get_asset_by_hash,
|
||||
)
|
||||
|
||||
|
||||
def make_unique_hash() -> str:
    """Return a unique, well-formed fake asset hash ("blake3:" + 64 hex chars)."""
    digest = "".join(uuid.uuid4().hex for _ in range(2))
    return f"blake3:{digest}"
|
||||
|
||||
|
||||
def create_test_asset(db_session, tmp_path, name="test", tags=None, owner_id=""):
    """Write a small PNG fixture to disk, ingest it, flush, and return the ingest result."""
    payload = b"fake png data"
    fixture_path = tmp_path / f"{name}.png"
    fixture_path.write_bytes(payload)
    digest = make_unique_hash()

    ingest_result = ingest_fs_asset(
        db_session,
        asset_hash=digest,
        abs_path=str(fixture_path),
        size_bytes=len(payload),
        mtime_ns=1000000,
        mime_type="image/png",
        info_name=name,
        owner_id=owner_id,
        tags=tags,
    )
    db_session.flush()
    return ingest_result
|
||||
|
||||
|
||||
class TestAddTagsToAssetInfo:
    """Tests for add_tags_to_asset_info and its added/already_present/total_tags report."""

    def test_adds_new_tags(self, db_session, tmp_path):
        """Fresh tags are reported as added and become the asset's full tag set."""
        result = create_test_asset(db_session, tmp_path, name="test-add-tags")

        add_result = add_tags_to_asset_info(
            db_session,
            asset_info_id=result["asset_info_id"],
            tags=["tag1", "tag2"],
            origin="manual",
        )
        db_session.flush()

        assert set(add_result["added"]) == {"tag1", "tag2"}
        assert add_result["already_present"] == []
        assert set(add_result["total_tags"]) == {"tag1", "tag2"}

    def test_idempotent_on_duplicates(self, db_session, tmp_path):
        """Re-adding an existing tag reports it as already_present, not added."""
        result = create_test_asset(db_session, tmp_path, name="test-idempotent")

        add_tags_to_asset_info(
            db_session,
            asset_info_id=result["asset_info_id"],
            tags=["dup-tag"],
            origin="manual",
        )
        db_session.flush()

        second_result = add_tags_to_asset_info(
            db_session,
            asset_info_id=result["asset_info_id"],
            tags=["dup-tag"],
            origin="manual",
        )
        db_session.flush()

        assert second_result["added"] == []
        assert second_result["already_present"] == ["dup-tag"]
        assert second_result["total_tags"] == ["dup-tag"]

    def test_mixed_new_and_existing_tags(self, db_session, tmp_path):
        """One call can report both new and already-present tags."""
        result = create_test_asset(db_session, tmp_path, name="test-mixed", tags=["existing"])

        add_result = add_tags_to_asset_info(
            db_session,
            asset_info_id=result["asset_info_id"],
            tags=["existing", "new-tag"],
            origin="manual",
        )
        db_session.flush()

        assert add_result["added"] == ["new-tag"]
        assert add_result["already_present"] == ["existing"]
        assert set(add_result["total_tags"]) == {"existing", "new-tag"}

    def test_empty_tags_list(self, db_session, tmp_path):
        """An empty tags list is a no-op that still reports the current tag set."""
        result = create_test_asset(db_session, tmp_path, name="test-empty", tags=["pre-existing"])

        add_result = add_tags_to_asset_info(
            db_session,
            asset_info_id=result["asset_info_id"],
            tags=[],
            origin="manual",
        )

        assert add_result["added"] == []
        assert add_result["already_present"] == []
        assert add_result["total_tags"] == ["pre-existing"]

    def test_raises_on_invalid_asset_info_id(self, db_session):
        """Tagging a nonexistent AssetInfo raises ValueError mentioning 'not found'."""
        with pytest.raises(ValueError, match="not found"):
            add_tags_to_asset_info(
                db_session,
                asset_info_id=str(uuid.uuid4()),
                tags=["tag1"],
                origin="manual",
            )
|
||||
|
||||
|
||||
class TestRemoveTagsFromAssetInfo:
    """Tests for remove_tags_from_asset_info and its removed/not_present/total_tags report."""

    def test_removes_existing_tags(self, db_session, tmp_path):
        """Removing a subset leaves the remainder as total_tags."""
        result = create_test_asset(db_session, tmp_path, name="test-remove", tags=["tag1", "tag2", "tag3"])

        remove_result = remove_tags_from_asset_info(
            db_session,
            asset_info_id=result["asset_info_id"],
            tags=["tag1", "tag2"],
        )
        db_session.flush()

        assert set(remove_result["removed"]) == {"tag1", "tag2"}
        assert remove_result["not_present"] == []
        assert remove_result["total_tags"] == ["tag3"]

    def test_handles_nonexistent_tags_gracefully(self, db_session, tmp_path):
        """Removing an unattached tag reports it as not_present instead of raising."""
        result = create_test_asset(db_session, tmp_path, name="test-nonexistent", tags=["existing"])

        remove_result = remove_tags_from_asset_info(
            db_session,
            asset_info_id=result["asset_info_id"],
            tags=["nonexistent"],
        )
        db_session.flush()

        assert remove_result["removed"] == []
        assert remove_result["not_present"] == ["nonexistent"]
        assert remove_result["total_tags"] == ["existing"]

    def test_mixed_existing_and_nonexistent(self, db_session, tmp_path):
        """One call can report both removed and not-present tags."""
        result = create_test_asset(db_session, tmp_path, name="test-mixed-remove", tags=["tag1", "tag2"])

        remove_result = remove_tags_from_asset_info(
            db_session,
            asset_info_id=result["asset_info_id"],
            tags=["tag1", "nonexistent"],
        )
        db_session.flush()

        assert remove_result["removed"] == ["tag1"]
        assert remove_result["not_present"] == ["nonexistent"]
        assert remove_result["total_tags"] == ["tag2"]

    def test_empty_tags_list(self, db_session, tmp_path):
        """An empty tags list is a no-op that still reports the current tag set."""
        result = create_test_asset(db_session, tmp_path, name="test-empty-remove", tags=["existing"])

        remove_result = remove_tags_from_asset_info(
            db_session,
            asset_info_id=result["asset_info_id"],
            tags=[],
        )

        assert remove_result["removed"] == []
        assert remove_result["not_present"] == []
        assert remove_result["total_tags"] == ["existing"]

    def test_raises_on_invalid_asset_info_id(self, db_session):
        """Untagging a nonexistent AssetInfo raises ValueError mentioning 'not found'."""
        with pytest.raises(ValueError, match="not found"):
            remove_tags_from_asset_info(
                db_session,
                asset_info_id=str(uuid.uuid4()),
                tags=["tag1"],
            )
|
||||
|
||||
|
||||
class TestGetAssetTags:
    """Tests for get_asset_tags: plain list of tag names for an AssetInfo."""

    def test_returns_list_of_tag_names(self, db_session, tmp_path):
        """All tags attached at ingest come back (order not asserted)."""
        created = create_test_asset(
            db_session, tmp_path, name="test-get-tags", tags=["alpha", "beta", "gamma"]
        )
        fetched = get_asset_tags(db_session, asset_info_id=created["asset_info_id"])
        assert set(fetched) == {"alpha", "beta", "gamma"}

    def test_returns_empty_list_when_no_tags(self, db_session, tmp_path):
        """An untagged asset yields an empty list."""
        created = create_test_asset(db_session, tmp_path, name="test-no-tags")
        fetched = get_asset_tags(db_session, asset_info_id=created["asset_info_id"])
        assert fetched == []

    def test_returns_empty_for_nonexistent_asset(self, db_session):
        """An unknown asset_info_id yields an empty list rather than raising."""
        fetched = get_asset_tags(db_session, asset_info_id=str(uuid.uuid4()))
        assert fetched == []
|
||||
|
||||
|
||||
class TestListTagsWithUsage:
|
||||
def test_returns_tags_with_counts(self, db_session, tmp_path):
|
||||
create_test_asset(db_session, tmp_path, name="asset1", tags=["shared-tag", "unique1"])
|
||||
create_test_asset(db_session, tmp_path, name="asset2", tags=["shared-tag", "unique2"])
|
||||
create_test_asset(db_session, tmp_path, name="asset3", tags=["shared-tag"])
|
||||
|
||||
tags, total = list_tags_with_usage(db_session)
|
||||
|
||||
tag_dict = {name: count for name, _, count in tags}
|
||||
assert tag_dict["shared-tag"] == 3
|
||||
assert tag_dict.get("unique1", 0) == 1
|
||||
assert tag_dict.get("unique2", 0) == 1
|
||||
|
||||
def test_prefix_filtering(self, db_session, tmp_path):
|
||||
create_test_asset(db_session, tmp_path, name="asset-prefix", tags=["prefix-a", "prefix-b", "other"])
|
||||
|
||||
tags, total = list_tags_with_usage(db_session, prefix="prefix")
|
||||
|
||||
tag_names = [name for name, _, _ in tags]
|
||||
assert "prefix-a" in tag_names
|
||||
assert "prefix-b" in tag_names
|
||||
assert "other" not in tag_names
|
||||
|
||||
def test_pagination(self, db_session, tmp_path):
|
||||
create_test_asset(db_session, tmp_path, name="asset-page", tags=["page1", "page2", "page3", "page4", "page5"])
|
||||
|
||||
first_page, _ = list_tags_with_usage(db_session, limit=2, offset=0)
|
||||
second_page, _ = list_tags_with_usage(db_session, limit=2, offset=2)
|
||||
|
||||
first_names = {name for name, _, _ in first_page}
|
||||
second_names = {name for name, _, _ in second_page}
|
||||
|
||||
assert len(first_page) == 2
|
||||
assert len(second_page) == 2
|
||||
assert first_names.isdisjoint(second_names)
|
||||
|
||||
def test_order_by_count_desc(self, db_session, tmp_path):
|
||||
create_test_asset(db_session, tmp_path, name="count1", tags=["popular", "rare"])
|
||||
create_test_asset(db_session, tmp_path, name="count2", tags=["popular"])
|
||||
create_test_asset(db_session, tmp_path, name="count3", tags=["popular"])
|
||||
|
||||
tags, _ = list_tags_with_usage(db_session, order="count_desc", include_zero=False)
|
||||
|
||||
counts = [count for _, _, count in tags]
|
||||
assert counts == sorted(counts, reverse=True)
|
||||
|
||||
def test_order_by_name_asc(self, db_session, tmp_path):
|
||||
create_test_asset(db_session, tmp_path, name="name-order", tags=["zebra", "apple", "mango"])
|
||||
|
||||
tags, _ = list_tags_with_usage(db_session, order="name_asc", include_zero=False)
|
||||
|
||||
names = [name for name, _, _ in tags]
|
||||
assert names == sorted(names)
|
||||
|
||||
def test_include_zero_false_excludes_unused_tags(self, db_session, tmp_path):
|
||||
create_test_asset(db_session, tmp_path, name="used-tag-asset", tags=["used-tag"])
|
||||
|
||||
add_tags_to_asset_info(
|
||||
db_session,
|
||||
asset_info_id=create_test_asset(db_session, tmp_path, name="temp")["asset_info_id"],
|
||||
tags=["orphan-tag"],
|
||||
origin="manual",
|
||||
)
|
||||
db_session.flush()
|
||||
remove_tags_from_asset_info(
|
||||
db_session,
|
||||
asset_info_id=create_test_asset(db_session, tmp_path, name="temp2")["asset_info_id"],
|
||||
tags=["orphan-tag"],
|
||||
)
|
||||
db_session.flush()
|
||||
|
||||
tags_with_zero, _ = list_tags_with_usage(db_session, include_zero=True)
|
||||
tags_without_zero, _ = list_tags_with_usage(db_session, include_zero=False)
|
||||
|
||||
with_zero_names = {name for name, _, _ in tags_with_zero}
|
||||
without_zero_names = {name for name, _, _ in tags_without_zero}
|
||||
|
||||
assert "used-tag" in without_zero_names
|
||||
assert len(without_zero_names) <= len(with_zero_names)
|
||||
|
||||
def test_respects_owner_visibility(self, db_session, tmp_path):
    """Listing scoped to an owner surfaces that owner's tags."""
    create_test_asset(db_session, tmp_path, name="user1-asset", tags=["user1-tag"], owner_id="user1")
    create_test_asset(db_session, tmp_path, name="user2-asset", tags=["user2-tag"], owner_id="user2")

    visible, _ = list_tags_with_usage(db_session, owner_id="user1", include_zero=False)

    assert "user1-tag" in {row[0] for row in visible}
class TestSetAssetInfoPreview:
    """Attaching and detaching a preview asset via set_asset_info_preview."""

    def test_sets_preview_id(self, db_session, tmp_path):
        """A valid preview asset id is persisted on the asset-info row."""
        main_asset = create_test_asset(db_session, tmp_path, name="main-asset")

        preview_path = tmp_path / "preview.png"
        preview_path.write_bytes(b"preview data")
        preview_hash = make_unique_hash()
        ingest_fs_asset(
            db_session,
            asset_hash=preview_hash,
            abs_path=str(preview_path),
            size_bytes=len(b"preview data"),
            mtime_ns=1000000,
            mime_type="image/png",
            info_name="preview",
        )
        db_session.flush()

        preview = get_asset_by_hash(db_session, asset_hash=preview_hash)

        set_asset_info_preview(
            db_session,
            asset_info_id=main_asset["asset_info_id"],
            preview_asset_id=preview.id,
        )
        db_session.flush()

        from app.assets.database.queries import get_asset_info_by_id
        refreshed = get_asset_info_by_id(db_session, asset_info_id=main_asset["asset_info_id"])
        assert refreshed.preview_id == preview.id

    def test_clears_preview_with_none(self, db_session, tmp_path):
        """Passing preview_asset_id=None removes a previously set preview."""
        main_asset = create_test_asset(db_session, tmp_path, name="clear-preview")

        preview_path = tmp_path / "preview2.png"
        preview_path.write_bytes(b"preview data")
        preview_hash = make_unique_hash()
        ingest_fs_asset(
            db_session,
            asset_hash=preview_hash,
            abs_path=str(preview_path),
            size_bytes=len(b"preview data"),
            mtime_ns=1000000,
            info_name="preview2",
        )
        db_session.flush()

        preview = get_asset_by_hash(db_session, asset_hash=preview_hash)

        # Set the preview first, then clear it.
        set_asset_info_preview(
            db_session,
            asset_info_id=main_asset["asset_info_id"],
            preview_asset_id=preview.id,
        )
        db_session.flush()

        set_asset_info_preview(
            db_session,
            asset_info_id=main_asset["asset_info_id"],
            preview_asset_id=None,
        )
        db_session.flush()

        from app.assets.database.queries import get_asset_info_by_id
        refreshed = get_asset_info_by_id(db_session, asset_info_id=main_asset["asset_info_id"])
        assert refreshed.preview_id is None

    def test_raises_on_invalid_asset_info_id(self, db_session):
        """An unknown asset_info_id raises ValueError."""
        with pytest.raises(ValueError, match="AssetInfo.*not found"):
            set_asset_info_preview(
                db_session,
                asset_info_id=str(uuid.uuid4()),
                preview_asset_id=None,
            )

    def test_raises_on_invalid_preview_asset_id(self, db_session, tmp_path):
        """An unknown preview asset id raises ValueError."""
        target = create_test_asset(db_session, tmp_path, name="invalid-preview")

        with pytest.raises(ValueError, match="Preview Asset.*not found"):
            set_asset_info_preview(
                db_session,
                asset_info_id=target["asset_info_id"],
                preview_asset_id=str(uuid.uuid4()),
            )
340
tests-unit/assets_test/routes_read_update_test.py
Normal file
340
tests-unit/assets_test/routes_read_update_test.py
Normal file
@@ -0,0 +1,340 @@
|
||||
"""
|
||||
Tests for read and update endpoints in the assets API.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import uuid
|
||||
from aiohttp import FormData
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
|
||||
def make_mock_asset(asset_id=None, name="Test Asset", tags=None, user_metadata=None, preview_id=None):
    """Build a MagicMock whose model_dump() mimics a single-asset API payload.

    Unset optionals get defaults: a fresh random UUID for ``id``,
    ``["input"]`` for ``tags``, and ``{}`` for ``user_metadata``.
    """
    asset_mock = MagicMock()
    asset_mock.model_dump.return_value = {
        "id": str(uuid.uuid4()) if asset_id is None else asset_id,
        "name": name,
        "tags": ["input"] if tags is None else tags,
        "user_metadata": {} if user_metadata is None else user_metadata,
        "preview_id": preview_id,
    }
    return asset_mock
def make_mock_list_result(assets, total=None):
    """Build a MagicMock whose model_dump() mimics a paginated asset listing.

    ``assets`` may mix plain dicts and objects exposing ``model_dump()``;
    ``total`` defaults to ``len(assets)``.
    """
    listing_mock = MagicMock()
    listing_mock.model_dump.return_value = {
        "assets": [
            entry.model_dump() if hasattr(entry, "model_dump") else entry
            for entry in assets
        ],
        "total": len(assets) if total is None else total,
    }
    return listing_mock
class TestListAssets:
    """GET /api/assets: listing, filters, sorting, and pagination."""

    async def test_returns_list(self, aiohttp_client, app):
        """Basic listing responds 200 with an assets array and a total."""
        with patch("app.assets.manager.list_assets") as list_mock:
            list_mock.return_value = make_mock_list_result(
                [{"id": str(uuid.uuid4()), "name": "Asset 1", "tags": ["input"]}],
                total=1,
            )

            client = await aiohttp_client(app)
            response = await client.get('/api/assets')
            assert response.status == 200
            payload = await response.json()
            assert 'assets' in payload
            assert 'total' in payload
            assert payload['total'] == 1

    async def test_returns_list_with_pagination(self, aiohttp_client, app):
        """limit/offset query params are forwarded to the manager layer."""
        with patch("app.assets.manager.list_assets") as list_mock:
            list_mock.return_value = make_mock_list_result(
                [
                    {"id": str(uuid.uuid4()), "name": "Asset 1", "tags": ["input"]},
                    {"id": str(uuid.uuid4()), "name": "Asset 2", "tags": ["input"]},
                ],
                total=5,
            )

            client = await aiohttp_client(app)
            response = await client.get('/api/assets?limit=2&offset=0')
            assert response.status == 200
            payload = await response.json()
            assert len(payload['assets']) == 2
            assert payload['total'] == 5
            list_mock.assert_called_once()
            forwarded = list_mock.call_args.kwargs
            assert forwarded['limit'] == 2
            assert forwarded['offset'] == 0

    async def test_filter_by_include_tags(self, aiohttp_client, app):
        """include_tags narrows results and reaches the manager call."""
        with patch("app.assets.manager.list_assets") as list_mock:
            list_mock.return_value = make_mock_list_result(
                [{"id": str(uuid.uuid4()), "name": "Special Asset", "tags": ["special"]}],
                total=1,
            )

            client = await aiohttp_client(app)
            response = await client.get('/api/assets?include_tags=special')
            assert response.status == 200
            payload = await response.json()
            for entry in payload['assets']:
                assert 'special' in entry.get('tags', [])
            list_mock.assert_called_once()
            assert 'special' in list_mock.call_args.kwargs['include_tags']

    async def test_filter_by_exclude_tags(self, aiohttp_client, app):
        """exclude_tags removes matching assets and reaches the manager call."""
        with patch("app.assets.manager.list_assets") as list_mock:
            list_mock.return_value = make_mock_list_result(
                [{"id": str(uuid.uuid4()), "name": "Kept Asset", "tags": ["keep"]}],
                total=1,
            )

            client = await aiohttp_client(app)
            response = await client.get('/api/assets?exclude_tags=exclude_me')
            assert response.status == 200
            payload = await response.json()
            for entry in payload['assets']:
                assert 'exclude_me' not in entry.get('tags', [])
            list_mock.assert_called_once()
            assert 'exclude_me' in list_mock.call_args.kwargs['exclude_tags']

    async def test_filter_by_name_contains(self, aiohttp_client, app):
        """The name_contains substring filter is passed through verbatim."""
        with patch("app.assets.manager.list_assets") as list_mock:
            list_mock.return_value = make_mock_list_result(
                [{"id": str(uuid.uuid4()), "name": "UniqueSearchName", "tags": ["input"]}],
                total=1,
            )

            client = await aiohttp_client(app)
            response = await client.get('/api/assets?name_contains=UniqueSearch')
            assert response.status == 200
            payload = await response.json()
            for entry in payload['assets']:
                assert 'UniqueSearch' in entry.get('name', '')
            list_mock.assert_called_once()
            assert list_mock.call_args.kwargs['name_contains'] == 'UniqueSearch'

    async def test_sort_and_order(self, aiohttp_client, app):
        """sort and order query params are forwarded unchanged."""
        with patch("app.assets.manager.list_assets") as list_mock:
            list_mock.return_value = make_mock_list_result(
                [
                    {"id": str(uuid.uuid4()), "name": "Alpha", "tags": ["input"]},
                    {"id": str(uuid.uuid4()), "name": "Zeta", "tags": ["input"]},
                ],
                total=2,
            )

            client = await aiohttp_client(app)
            response = await client.get('/api/assets?sort=name&order=asc')
            assert response.status == 200
            list_mock.assert_called_once()
            forwarded = list_mock.call_args.kwargs
            assert forwarded['sort'] == 'name'
            assert forwarded['order'] == 'asc'
class TestGetAssetById:
    """GET /api/assets/{id} single-asset lookups."""

    async def test_returns_asset(self, aiohttp_client, app):
        """A known id responds 200 with the asset payload."""
        asset_id = str(uuid.uuid4())
        with patch("app.assets.manager.get_asset") as get_mock:
            get_mock.return_value = make_mock_asset(asset_id=asset_id, name="Test Asset")

            client = await aiohttp_client(app)
            response = await client.get(f'/api/assets/{asset_id}')
            assert response.status == 200
            payload = await response.json()
            assert payload['id'] == asset_id

    async def test_returns_404_for_missing_id(self, aiohttp_client, app):
        """An unknown id maps the manager ValueError to 404/ASSET_NOT_FOUND."""
        missing_id = str(uuid.uuid4())
        with patch("app.assets.manager.get_asset") as get_mock:
            get_mock.side_effect = ValueError("Asset not found")

            client = await aiohttp_client(app)
            response = await client.get(f'/api/assets/{missing_id}')
            assert response.status == 404
            payload = await response.json()
            assert payload['error']['code'] == 'ASSET_NOT_FOUND'

    async def test_returns_404_for_wrong_owner(self, aiohttp_client, app):
        """Ownership mismatches are indistinguishable from missing assets (404)."""
        asset_id = str(uuid.uuid4())
        with patch("app.assets.manager.get_asset") as get_mock:
            get_mock.side_effect = ValueError("Asset not found for this owner")

            client = await aiohttp_client(app)
            response = await client.get(f'/api/assets/{asset_id}')
            assert response.status == 404
            payload = await response.json()
            assert payload['error']['code'] == 'ASSET_NOT_FOUND'
class TestDownloadAssetContent:
    """GET /api/assets/{id}/content file downloads."""

    async def test_returns_file_content(self, aiohttp_client, app, test_image_bytes, tmp_path):
        """Resolved content is streamed back with an image content type."""
        asset_id = str(uuid.uuid4())
        image_path = tmp_path / "test_image.png"
        image_path.write_bytes(test_image_bytes)

        with patch("app.assets.manager.resolve_asset_content_for_download") as resolve_mock:
            resolve_mock.return_value = (str(image_path), "image/png", "test_image.png")

            client = await aiohttp_client(app)
            response = await client.get(f'/api/assets/{asset_id}/content')
            assert response.status == 200
            assert 'image' in response.content_type

    async def test_sets_content_disposition_header(self, aiohttp_client, app, test_image_bytes, tmp_path):
        """The resolved filename appears in the Content-Disposition header."""
        asset_id = str(uuid.uuid4())
        image_path = tmp_path / "test_image.png"
        image_path.write_bytes(test_image_bytes)

        with patch("app.assets.manager.resolve_asset_content_for_download") as resolve_mock:
            resolve_mock.return_value = (str(image_path), "image/png", "test_image.png")

            client = await aiohttp_client(app)
            response = await client.get(f'/api/assets/{asset_id}/content')
            assert response.status == 200
            assert 'Content-Disposition' in response.headers
            assert 'test_image.png' in response.headers['Content-Disposition']

    async def test_returns_404_for_missing_asset(self, aiohttp_client, app):
        """An unknown asset id yields 404/ASSET_NOT_FOUND."""
        missing_id = str(uuid.uuid4())
        with patch("app.assets.manager.resolve_asset_content_for_download") as resolve_mock:
            resolve_mock.side_effect = ValueError("Asset not found")

            client = await aiohttp_client(app)
            response = await client.get(f'/api/assets/{missing_id}/content')
            assert response.status == 404
            payload = await response.json()
            assert payload['error']['code'] == 'ASSET_NOT_FOUND'

    async def test_returns_404_for_missing_file(self, aiohttp_client, app):
        """A known asset whose backing file vanished yields 404/FILE_NOT_FOUND."""
        asset_id = str(uuid.uuid4())
        with patch("app.assets.manager.resolve_asset_content_for_download") as resolve_mock:
            resolve_mock.side_effect = FileNotFoundError("File not found on disk")

            client = await aiohttp_client(app)
            response = await client.get(f'/api/assets/{asset_id}/content')
            assert response.status == 404
            payload = await response.json()
            assert payload['error']['code'] == 'FILE_NOT_FOUND'
class TestUpdateAsset:
    """PUT /api/assets/{id} partial updates (name, tags, user_metadata)."""

    async def test_update_name(self, aiohttp_client, app):
        """A name change is forwarded to the manager and echoed back."""
        asset_id = str(uuid.uuid4())
        with patch("app.assets.manager.update_asset") as update_mock:
            update_mock.return_value = make_mock_asset(asset_id=asset_id, name="New Name")

            client = await aiohttp_client(app)
            response = await client.put(f'/api/assets/{asset_id}', json={'name': 'New Name'})
            assert response.status == 200
            payload = await response.json()
            assert payload['name'] == 'New Name'
            update_mock.assert_called_once()
            assert update_mock.call_args.kwargs['name'] == 'New Name'

    async def test_update_tags(self, aiohttp_client, app):
        """A tag-list change is forwarded verbatim and echoed back."""
        asset_id = str(uuid.uuid4())
        with patch("app.assets.manager.update_asset") as update_mock:
            update_mock.return_value = make_mock_asset(
                asset_id=asset_id, tags=['new_tag', 'another_tag']
            )

            client = await aiohttp_client(app)
            response = await client.put(f'/api/assets/{asset_id}', json={'tags': ['new_tag', 'another_tag']})
            assert response.status == 200
            payload = await response.json()
            assert 'new_tag' in payload.get('tags', [])
            assert 'another_tag' in payload.get('tags', [])
            update_mock.assert_called_once()
            assert update_mock.call_args.kwargs['tags'] == ['new_tag', 'another_tag']

    async def test_update_user_metadata(self, aiohttp_client, app):
        """A user_metadata change is forwarded verbatim and echoed back."""
        asset_id = str(uuid.uuid4())
        with patch("app.assets.manager.update_asset") as update_mock:
            update_mock.return_value = make_mock_asset(
                asset_id=asset_id, user_metadata={'key': 'value'}
            )

            client = await aiohttp_client(app)
            response = await client.put(f'/api/assets/{asset_id}', json={'user_metadata': {'key': 'value'}})
            assert response.status == 200
            payload = await response.json()
            assert payload.get('user_metadata', {}).get('key') == 'value'
            update_mock.assert_called_once()
            assert update_mock.call_args.kwargs['user_metadata'] == {'key': 'value'}

    async def test_returns_400_on_empty_body(self, aiohttp_client, app):
        """An empty request body is rejected with INVALID_JSON."""
        asset_id = str(uuid.uuid4())

        client = await aiohttp_client(app)
        response = await client.put(f'/api/assets/{asset_id}', data=b'')
        assert response.status == 400
        payload = await response.json()
        assert payload['error']['code'] == 'INVALID_JSON'

    async def test_returns_404_for_missing_asset(self, aiohttp_client, app):
        """Updating an unknown asset yields 404/ASSET_NOT_FOUND."""
        missing_id = str(uuid.uuid4())
        with patch("app.assets.manager.update_asset") as update_mock:
            update_mock.side_effect = ValueError("Asset not found")

            client = await aiohttp_client(app)
            response = await client.put(f'/api/assets/{missing_id}', json={'name': 'New Name'})
            assert response.status == 404
            payload = await response.json()
            assert payload['error']['code'] == 'ASSET_NOT_FOUND'
class TestSetAssetPreview:
    """PUT /api/assets/{id}/preview preview assignment."""

    async def test_sets_preview_id(self, aiohttp_client, app):
        """A preview id is forwarded as preview_asset_id and echoed back."""
        asset_id = str(uuid.uuid4())
        preview_id = str(uuid.uuid4())
        with patch("app.assets.manager.set_asset_preview") as preview_mock:
            preview_mock.return_value = make_mock_asset(
                asset_id=asset_id, preview_id=preview_id
            )

            client = await aiohttp_client(app)
            response = await client.put(f'/api/assets/{asset_id}/preview', json={'preview_id': preview_id})
            assert response.status == 200
            payload = await response.json()
            assert payload.get('preview_id') == preview_id
            preview_mock.assert_called_once()
            assert preview_mock.call_args.kwargs['preview_asset_id'] == preview_id

    async def test_clears_preview_with_null(self, aiohttp_client, app):
        """A JSON null preview_id clears the preview."""
        asset_id = str(uuid.uuid4())
        with patch("app.assets.manager.set_asset_preview") as preview_mock:
            preview_mock.return_value = make_mock_asset(
                asset_id=asset_id, preview_id=None
            )

            client = await aiohttp_client(app)
            response = await client.put(f'/api/assets/{asset_id}/preview', json={'preview_id': None})
            assert response.status == 200
            payload = await response.json()
            assert payload.get('preview_id') is None
            preview_mock.assert_called_once()
            assert preview_mock.call_args.kwargs['preview_asset_id'] is None

    async def test_returns_404_for_missing_asset(self, aiohttp_client, app):
        """Setting a preview on an unknown asset yields 404/ASSET_NOT_FOUND."""
        missing_id = str(uuid.uuid4())
        with patch("app.assets.manager.set_asset_preview") as preview_mock:
            preview_mock.side_effect = ValueError("Asset not found")

            client = await aiohttp_client(app)
            response = await client.put(f'/api/assets/{missing_id}/preview', json={'preview_id': None})
            assert response.status == 404
            payload = await response.json()
            assert payload['error']['code'] == 'ASSET_NOT_FOUND'
175
tests-unit/assets_test/routes_tags_delete_test.py
Normal file
175
tests-unit/assets_test/routes_tags_delete_test.py
Normal file
@@ -0,0 +1,175 @@
|
||||
"""
|
||||
Tests for tag management and delete endpoints.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from aiohttp import FormData
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
|
||||
async def create_test_asset(client, test_image_bytes, tags=None):
    """Upload a PNG asset via POST /api/assets and return the parsed JSON reply.

    ``tags`` is a comma-separated string; defaults to 'input' when falsy.
    """
    form = FormData()
    form.add_field('file', test_image_bytes, filename='test.png', content_type='image/png')
    form.add_field('tags', tags if tags else 'input')
    form.add_field('name', 'Test Asset')
    response = await client.post('/api/assets', data=form)
    return await response.json()
class TestListTags:
    """GET /api/tags listing, filtering, pagination, and ordering."""

    async def test_returns_tags(self, aiohttp_client, app, test_image_bytes, tmp_upload_dir):
        """Listing tags after an upload responds 200 with a tags key."""
        client = await aiohttp_client(app)
        await create_test_asset(client, test_image_bytes)

        response = await client.get('/api/tags')
        assert response.status == 200
        payload = await response.json()
        assert 'tags' in payload

    async def test_prefix_filtering(self, aiohttp_client, app, test_image_bytes, tmp_upload_dir):
        """A prefix query param is accepted and still returns a tags key."""
        client = await aiohttp_client(app)
        await create_test_asset(client, test_image_bytes, tags='input,mytag')

        response = await client.get('/api/tags', params={'prefix': 'my'})
        assert response.status == 200
        payload = await response.json()
        assert 'tags' in payload

    async def test_pagination(self, aiohttp_client, app, test_image_bytes, tmp_upload_dir):
        """limit/offset query params are accepted."""
        client = await aiohttp_client(app)
        await create_test_asset(client, test_image_bytes)

        response = await client.get('/api/tags', params={'limit': 10, 'offset': 0})
        assert response.status == 200
        payload = await response.json()
        assert 'tags' in payload

    async def test_order_by_count_desc(self, aiohttp_client, app, test_image_bytes, tmp_upload_dir):
        """order=count_desc is accepted."""
        client = await aiohttp_client(app)
        await create_test_asset(client, test_image_bytes)

        response = await client.get('/api/tags', params={'order': 'count_desc'})
        assert response.status == 200
        payload = await response.json()
        assert 'tags' in payload

    async def test_order_by_name_asc(self, aiohttp_client, app, test_image_bytes, tmp_upload_dir):
        """order=name_asc is accepted."""
        client = await aiohttp_client(app)
        await create_test_asset(client, test_image_bytes)

        response = await client.get('/api/tags', params={'order': 'name_asc'})
        assert response.status == 200
        payload = await response.json()
        assert 'tags' in payload
class TestAddAssetTags:
    """POST /api/assets/{id}/tags tag addition."""

    async def test_add_tags_success(self, aiohttp_client, app, test_image_bytes, tmp_upload_dir):
        """Adding a new tag responds 200 with an added/total_tags summary."""
        client = await aiohttp_client(app)
        asset = await create_test_asset(client, test_image_bytes)

        response = await client.post(f'/api/assets/{asset["id"]}/tags', json={'tags': ['newtag']})
        assert response.status == 200
        payload = await response.json()
        assert 'added' in payload or 'total_tags' in payload

    async def test_add_tags_returns_already_present(self, aiohttp_client, app, test_image_bytes, tmp_upload_dir):
        """Re-adding an existing tag reports it as already present."""
        client = await aiohttp_client(app)
        asset = await create_test_asset(client, test_image_bytes, tags='input,existingtag')

        response = await client.post(f'/api/assets/{asset["id"]}/tags', json={'tags': ['existingtag']})
        assert response.status == 200
        payload = await response.json()
        assert 'already_present' in payload or 'added' in payload

    async def test_add_tags_missing_asset_returns_404(self, aiohttp_client, app):
        """Adding tags to an unknown asset yields 404."""
        client = await aiohttp_client(app)

        response = await client.post('/api/assets/00000000-0000-0000-0000-000000000000/tags', json={'tags': ['newtag']})
        assert response.status == 404

    async def test_add_tags_empty_tags_returns_400(self, aiohttp_client, app, test_image_bytes, tmp_upload_dir):
        """An empty tag list is rejected with 400."""
        client = await aiohttp_client(app)
        asset = await create_test_asset(client, test_image_bytes)

        response = await client.post(f'/api/assets/{asset["id"]}/tags', json={'tags': []})
        assert response.status == 400
class TestDeleteAssetTags:
    """DELETE /api/assets/{id}/tags tag removal."""

    async def test_remove_tags_success(self, aiohttp_client, app, test_image_bytes, tmp_upload_dir):
        """Removing an attached tag responds 200 with a removed/total summary."""
        client = await aiohttp_client(app)
        asset = await create_test_asset(client, test_image_bytes, tags='input,removeme')

        response = await client.delete(f'/api/assets/{asset["id"]}/tags', json={'tags': ['removeme']})
        assert response.status == 200
        payload = await response.json()
        assert 'removed' in payload or 'total_tags' in payload

    async def test_remove_tags_returns_not_present(self, aiohttp_client, app, test_image_bytes, tmp_upload_dir):
        """Removing a tag the asset never had reports it as not present."""
        client = await aiohttp_client(app)
        asset = await create_test_asset(client, test_image_bytes)

        response = await client.delete(f'/api/assets/{asset["id"]}/tags', json={'tags': ['nonexistent']})
        assert response.status == 200
        payload = await response.json()
        assert 'not_present' in payload or 'removed' in payload

    async def test_remove_tags_missing_asset_returns_404(self, aiohttp_client, app):
        """Removing tags from an unknown asset yields 404."""
        client = await aiohttp_client(app)

        response = await client.delete('/api/assets/00000000-0000-0000-0000-000000000000/tags', json={'tags': ['sometag']})
        assert response.status == 404

    async def test_remove_tags_empty_tags_returns_400(self, aiohttp_client, app, test_image_bytes, tmp_upload_dir):
        """An empty tag list is rejected with 400."""
        client = await aiohttp_client(app)
        asset = await create_test_asset(client, test_image_bytes)

        response = await client.delete(f'/api/assets/{asset["id"]}/tags', json={'tags': []})
        assert response.status == 400
class TestDeleteAsset:
    """DELETE /api/assets/{id} asset removal."""

    async def test_delete_success(self, aiohttp_client, app, test_image_bytes, tmp_upload_dir):
        """A successful delete returns 204 and the asset is gone afterwards."""
        client = await aiohttp_client(app)
        asset = await create_test_asset(client, test_image_bytes)

        response = await client.delete(f'/api/assets/{asset["id"]}')
        assert response.status == 204

        response = await client.get(f'/api/assets/{asset["id"]}')
        assert response.status == 404

    async def test_delete_missing_asset_returns_404(self, aiohttp_client, app):
        """Deleting an unknown asset yields 404."""
        client = await aiohttp_client(app)

        response = await client.delete('/api/assets/00000000-0000-0000-0000-000000000000')
        assert response.status == 404

    async def test_delete_with_delete_content_false(self, aiohttp_client, app, test_image_bytes, tmp_upload_dir):
        """delete_content=false still removes the asset record (204 then 404)."""
        client = await aiohttp_client(app)
        asset = await create_test_asset(client, test_image_bytes)
        if 'id' not in asset:
            pytest.skip("Asset creation failed due to transient DB session issue")

        response = await client.delete(f'/api/assets/{asset["id"]}', params={'delete_content': 'false'})
        assert response.status == 204

        response = await client.get(f'/api/assets/{asset["id"]}')
        assert response.status == 404
class TestSeedAssets:
    """POST /api/assets/scan/seed scan seeding."""

    async def test_seed_returns_200(self, aiohttp_client, app):
        """Seeding with a single root responds 200."""
        client = await aiohttp_client(app)
        response = await client.post('/api/assets/scan/seed', json={'roots': ['input']})
        assert response.status == 200

    async def test_seed_accepts_roots_parameter(self, aiohttp_client, app):
        """Multiple roots are accepted and echoed back."""
        client = await aiohttp_client(app)
        response = await client.post('/api/assets/scan/seed', json={'roots': ['input', 'output']})
        assert response.status == 200
        payload = await response.json()
        assert payload.get('roots') == ['input', 'output']
240
tests-unit/assets_test/routes_upload_test.py
Normal file
240
tests-unit/assets_test/routes_upload_test.py
Normal file
@@ -0,0 +1,240 @@
|
||||
"""
|
||||
Tests for upload and create endpoints in assets API routes.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from aiohttp import FormData
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
|
||||
class TestUploadAsset:
    """POST /api/assets multipart upload behaviors."""

    async def test_upload_success(self, aiohttp_client, app, test_image_bytes, tmp_upload_dir):
        """A fresh upload responds 201 with the created asset payload."""
        with patch("app.assets.manager.upload_asset_from_temp_path") as upload_mock:
            upload_result = MagicMock()
            upload_result.created_new = True
            upload_result.model_dump.return_value = {
                "id": "11111111-1111-1111-1111-111111111111",
                "name": "Test Asset",
                "tags": ["input"],
            }
            upload_mock.return_value = upload_result

            client = await aiohttp_client(app)

            form = FormData()
            form.add_field("file", test_image_bytes, filename="test.png", content_type="image/png")
            form.add_field("tags", "input")
            form.add_field("name", "Test Asset")

            response = await client.post("/api/assets", data=form)
            assert response.status == 201
            payload = await response.json()
            assert "id" in payload
            assert payload["name"] == "Test Asset"

    async def test_upload_existing_hash_returns_200(self, aiohttp_client, app, test_image_bytes, tmp_upload_dir):
        """Uploading a hash already on record dedupes and responds 200."""
        with patch("app.assets.manager.asset_exists", return_value=True), \
                patch("app.assets.manager.create_asset_from_hash") as create_mock:
            existing_result = MagicMock()
            existing_result.created_new = False
            existing_result.model_dump.return_value = {
                "id": "22222222-2222-2222-2222-222222222222",
                "name": "Existing Asset",
                "tags": ["input"],
            }
            create_mock.return_value = existing_result

            client = await aiohttp_client(app)

            form = FormData()
            form.add_field("hash", "blake3:" + "a" * 64)
            form.add_field("file", test_image_bytes, filename="test.png", content_type="image/png")
            form.add_field("tags", "input")
            form.add_field("name", "Existing Asset")

            response = await client.post("/api/assets", data=form)
            assert response.status == 200
            payload = await response.json()
            assert "id" in payload

    async def test_upload_missing_file_returns_400(self, aiohttp_client, app):
        """A multipart body without a file part is rejected."""
        client = await aiohttp_client(app)

        form = FormData()
        form.add_field("tags", "input")
        form.add_field("name", "No File Asset")

        response = await client.post("/api/assets", data=form)
        assert response.status in (400, 415)

    async def test_upload_empty_file_returns_400(self, aiohttp_client, app, tmp_upload_dir):
        """Zero-byte uploads fail with EMPTY_UPLOAD."""
        client = await aiohttp_client(app)

        form = FormData()
        form.add_field("file", b"", filename="empty.png", content_type="image/png")
        form.add_field("tags", "input")
        form.add_field("name", "Empty File Asset")

        response = await client.post("/api/assets", data=form)
        assert response.status == 400
        payload = await response.json()
        assert payload["error"]["code"] == "EMPTY_UPLOAD"

    async def test_upload_invalid_tags_missing_root_returns_400(self, aiohttp_client, app, test_image_bytes, tmp_upload_dir):
        """Tags lacking a recognized root tag fail body validation."""
        client = await aiohttp_client(app)

        form = FormData()
        form.add_field("file", test_image_bytes, filename="test.png", content_type="image/png")
        form.add_field("tags", "invalid_root_tag")
        form.add_field("name", "Invalid Tags Asset")

        response = await client.post("/api/assets", data=form)
        assert response.status == 400
        payload = await response.json()
        assert payload["error"]["code"] == "INVALID_BODY"

    async def test_upload_hash_mismatch_returns_400(self, aiohttp_client, app, test_image_bytes, tmp_upload_dir):
        """A declared hash that disagrees with the content yields HASH_MISMATCH."""
        with patch("app.assets.manager.asset_exists", return_value=False), \
                patch("app.assets.manager.upload_asset_from_temp_path") as upload_mock:
            upload_mock.side_effect = ValueError("HASH_MISMATCH")

            client = await aiohttp_client(app)

            form = FormData()
            form.add_field("hash", "blake3:" + "b" * 64)
            form.add_field("file", test_image_bytes, filename="test.png", content_type="image/png")
            form.add_field("tags", "input")
            form.add_field("name", "Hash Mismatch Asset")

            response = await client.post("/api/assets", data=form)
            assert response.status == 400
            payload = await response.json()
            assert payload["error"]["code"] == "HASH_MISMATCH"

    async def test_upload_non_multipart_returns_415(self, aiohttp_client, app):
        """A JSON (non-multipart) POST is rejected with 415."""
        client = await aiohttp_client(app)

        response = await client.post("/api/assets", json={"name": "test"})
        assert response.status == 415
        payload = await response.json()
        assert payload["error"]["code"] == "UNSUPPORTED_MEDIA_TYPE"
class TestCreateFromHash:
    """POST /api/assets/from-hash creating asset records from known hashes."""

    async def test_create_from_hash_success(self, aiohttp_client, app):
        """A known hash responds 201 with the created asset payload."""
        with patch("app.assets.manager.create_asset_from_hash") as create_mock:
            created = MagicMock()
            created.model_dump.return_value = {
                "id": "33333333-3333-3333-3333-333333333333",
                "name": "Created From Hash",
                "tags": ["input"],
            }
            create_mock.return_value = created

            client = await aiohttp_client(app)

            response = await client.post("/api/assets/from-hash", json={
                "hash": "blake3:" + "c" * 64,
                "name": "Created From Hash",
                "tags": ["input"],
            })
            assert response.status == 201
            payload = await response.json()
            assert payload["id"] == "33333333-3333-3333-3333-333333333333"
            assert payload["name"] == "Created From Hash"

    async def test_create_from_hash_unknown_hash_returns_404(self, aiohttp_client, app):
        """An unknown hash (manager returns None) yields 404/ASSET_NOT_FOUND."""
        with patch("app.assets.manager.create_asset_from_hash", return_value=None):
            client = await aiohttp_client(app)

            response = await client.post("/api/assets/from-hash", json={
                "hash": "blake3:" + "d" * 64,
                "name": "Unknown Hash",
                "tags": ["input"],
            })
            assert response.status == 404
            payload = await response.json()
            assert payload["error"]["code"] == "ASSET_NOT_FOUND"

    async def test_create_from_hash_invalid_hash_format_returns_400(self, aiohttp_client, app):
        """A hash without an algorithm prefix fails body validation."""
        client = await aiohttp_client(app)

        response = await client.post("/api/assets/from-hash", json={
            "hash": "invalid_hash_no_colon",
            "name": "Invalid Hash",
            "tags": ["input"],
        })
        assert response.status == 400
        payload = await response.json()
        assert payload["error"]["code"] == "INVALID_BODY"

    async def test_create_from_hash_missing_name_returns_400(self, aiohttp_client, app):
        """Omitting the required name field fails body validation."""
        client = await aiohttp_client(app)

        response = await client.post("/api/assets/from-hash", json={
            "hash": "blake3:" + "e" * 64,
            "tags": ["input"],
        })
        assert response.status == 400
        payload = await response.json()
        assert payload["error"]["code"] == "INVALID_BODY"

    async def test_create_from_hash_invalid_json_returns_400(self, aiohttp_client, app):
        """A malformed JSON body is rejected with INVALID_JSON."""
        client = await aiohttp_client(app)

        response = await client.post(
            "/api/assets/from-hash",
            data="not valid json",
            headers={"Content-Type": "application/json"},
        )
        assert response.status == 400
        payload = await response.json()
        assert payload["error"]["code"] == "INVALID_JSON"
class TestHeadAssetByHash:
    """Tests for HEAD /api/assets/hash/{hash}."""

    async def test_head_existing_hash_returns_200(self, aiohttp_client, app):
        """An existing blake3 hash answers 200."""
        with patch("app.assets.manager.asset_exists", return_value=True):
            http = await aiohttp_client(app)
            response = await http.head("/api/assets/hash/blake3:" + "f" * 64)
            assert response.status == 200

    async def test_head_missing_hash_returns_404(self, aiohttp_client, app):
        """A well-formed but unknown hash answers 404."""
        with patch("app.assets.manager.asset_exists", return_value=False):
            http = await aiohttp_client(app)
            response = await http.head("/api/assets/hash/blake3:" + "0" * 64)
            assert response.status == 404

    async def test_head_invalid_hash_no_colon_returns_400(self, aiohttp_client, app):
        """A hash without an 'algo:' separator is rejected with 400."""
        http = await aiohttp_client(app)
        response = await http.head("/api/assets/hash/invalidhashwithoutcolon")
        assert response.status == 400

    async def test_head_invalid_hash_wrong_algo_returns_400(self, aiohttp_client, app):
        """Only blake3 is accepted; sha256 is rejected with 400."""
        http = await aiohttp_client(app)
        response = await http.head("/api/assets/hash/sha256:" + "a" * 64)
        assert response.status == 400

    async def test_head_invalid_hash_non_hex_returns_400(self, aiohttp_client, app):
        """A digest containing non-hex characters is rejected with 400."""
        http = await aiohttp_client(app)
        response = await http.head("/api/assets/hash/blake3:zzzz")
        assert response.status == 400

    async def test_head_empty_hash_returns_400(self, aiohttp_client, app):
        """An empty digest after the algorithm prefix is rejected with 400."""
        http = await aiohttp_client(app)
        response = await http.head("/api/assets/hash/blake3:")
        assert response.status == 400
|
||||
509
tests-unit/assets_test/schemas_test.py
Normal file
509
tests-unit/assets_test/schemas_test.py
Normal file
@@ -0,0 +1,509 @@
|
||||
"""
|
||||
Comprehensive tests for Pydantic schemas in the assets API.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from app.assets.api.schemas_in import (
|
||||
ListAssetsQuery,
|
||||
UpdateAssetBody,
|
||||
CreateFromHashBody,
|
||||
UploadAssetSpec,
|
||||
SetPreviewBody,
|
||||
TagsAdd,
|
||||
TagsRemove,
|
||||
TagsListQuery,
|
||||
ScheduleAssetScanBody,
|
||||
)
|
||||
|
||||
|
||||
class TestListAssetsQuery:
    """Validation behavior of the ListAssetsQuery schema (paging, sorting, tag CSV, metadata filter)."""

    def test_defaults(self):
        parsed = ListAssetsQuery()
        assert parsed.limit == 20
        assert parsed.offset == 0
        assert parsed.sort == "created_at"
        assert parsed.order == "desc"
        assert parsed.include_tags == []
        assert parsed.exclude_tags == []
        assert parsed.name_contains is None
        assert parsed.metadata_filter is None

    def test_csv_tags_parsing_string(self):
        parsed = ListAssetsQuery.model_validate({"include_tags": "a,b,c"})
        assert parsed.include_tags == ["a", "b", "c"]

    def test_csv_tags_parsing_with_whitespace(self):
        # Spaces around each CSV element are stripped.
        parsed = ListAssetsQuery.model_validate({"include_tags": " a , b , c "})
        assert parsed.include_tags == ["a", "b", "c"]

    def test_csv_tags_parsing_list(self):
        parsed = ListAssetsQuery.model_validate({"include_tags": ["a", "b", "c"]})
        assert parsed.include_tags == ["a", "b", "c"]

    def test_csv_tags_parsing_list_with_csv(self):
        # CSV strings inside a list are expanded as well.
        parsed = ListAssetsQuery.model_validate({"include_tags": ["a,b", "c"]})
        assert parsed.include_tags == ["a", "b", "c"]

    def test_csv_tags_exclude_tags(self):
        parsed = ListAssetsQuery.model_validate({"exclude_tags": "x,y,z"})
        assert parsed.exclude_tags == ["x", "y", "z"]

    def test_csv_tags_empty_string(self):
        parsed = ListAssetsQuery.model_validate({"include_tags": ""})
        assert parsed.include_tags == []

    def test_csv_tags_none(self):
        parsed = ListAssetsQuery.model_validate({"include_tags": None})
        assert parsed.include_tags == []

    def test_metadata_filter_json_string(self):
        parsed = ListAssetsQuery.model_validate({"metadata_filter": '{"key": "value"}'})
        assert parsed.metadata_filter == {"key": "value"}

    def test_metadata_filter_dict(self):
        parsed = ListAssetsQuery.model_validate({"metadata_filter": {"key": "value"}})
        assert parsed.metadata_filter == {"key": "value"}

    def test_metadata_filter_none(self):
        parsed = ListAssetsQuery.model_validate({"metadata_filter": None})
        assert parsed.metadata_filter is None

    def test_metadata_filter_empty_string(self):
        parsed = ListAssetsQuery.model_validate({"metadata_filter": ""})
        assert parsed.metadata_filter is None

    def test_metadata_filter_invalid_json(self):
        with pytest.raises(ValidationError) as exc_info:
            ListAssetsQuery.model_validate({"metadata_filter": "not json"})
        assert "must be JSON" in str(exc_info.value)

    def test_metadata_filter_non_object_json(self):
        # Valid JSON that is not an object (e.g. an array) is rejected.
        with pytest.raises(ValidationError) as exc_info:
            ListAssetsQuery.model_validate({"metadata_filter": "[1, 2, 3]"})
        assert "must be a JSON object" in str(exc_info.value)

    def test_limit_bounds_min(self):
        with pytest.raises(ValidationError):
            ListAssetsQuery.model_validate({"limit": 0})

    def test_limit_bounds_max(self):
        with pytest.raises(ValidationError):
            ListAssetsQuery.model_validate({"limit": 501})

    def test_limit_bounds_valid(self):
        parsed = ListAssetsQuery.model_validate({"limit": 500})
        assert parsed.limit == 500

    def test_offset_bounds_min(self):
        with pytest.raises(ValidationError):
            ListAssetsQuery.model_validate({"offset": -1})

    def test_sort_enum_valid(self):
        for candidate in ("name", "created_at", "updated_at", "size", "last_access_time"):
            parsed = ListAssetsQuery.model_validate({"sort": candidate})
            assert parsed.sort == candidate

    def test_sort_enum_invalid(self):
        with pytest.raises(ValidationError):
            ListAssetsQuery.model_validate({"sort": "invalid"})

    def test_order_enum_valid(self):
        for candidate in ("asc", "desc"):
            parsed = ListAssetsQuery.model_validate({"order": candidate})
            assert parsed.order == candidate

    def test_order_enum_invalid(self):
        with pytest.raises(ValidationError):
            ListAssetsQuery.model_validate({"order": "invalid"})
|
||||
|
||||
|
||||
class TestUpdateAssetBody:
    """Validation behavior of the UpdateAssetBody schema (partial update payload)."""

    def test_requires_at_least_one_field(self):
        # An empty update is meaningless and must be rejected.
        with pytest.raises(ValidationError) as exc_info:
            UpdateAssetBody.model_validate({})
        assert "at least one of" in str(exc_info.value)

    def test_name_only(self):
        parsed = UpdateAssetBody.model_validate({"name": "new_name"})
        assert parsed.name == "new_name"
        assert parsed.tags is None
        assert parsed.user_metadata is None

    def test_tags_only(self):
        parsed = UpdateAssetBody.model_validate({"tags": ["tag1", "tag2"]})
        assert parsed.tags == ["tag1", "tag2"]

    def test_user_metadata_only(self):
        parsed = UpdateAssetBody.model_validate({"user_metadata": {"key": "value"}})
        assert parsed.user_metadata == {"key": "value"}

    def test_tags_must_be_list_of_strings(self):
        with pytest.raises(ValidationError) as exc_info:
            UpdateAssetBody.model_validate({"tags": "not_a_list"})
        assert "list" in str(exc_info.value).lower()

    def test_tags_must_contain_strings(self):
        with pytest.raises(ValidationError) as exc_info:
            UpdateAssetBody.model_validate({"tags": [1, 2, 3]})
        assert "string" in str(exc_info.value).lower()

    def test_multiple_fields(self):
        parsed = UpdateAssetBody.model_validate({
            "name": "new_name",
            "tags": ["tag1"],
            "user_metadata": {"foo": "bar"}
        })
        assert parsed.name == "new_name"
        assert parsed.tags == ["tag1"]
        assert parsed.user_metadata == {"foo": "bar"}
|
||||
|
||||
|
||||
class TestCreateFromHashBody:
    """Validation behavior of the CreateFromHashBody schema (hash format, tag normalization)."""

    def test_valid_blake3(self):
        parsed = CreateFromHashBody(
            hash="blake3:" + "a" * 64,
            name="test"
        )
        assert parsed.hash.startswith("blake3:")
        assert parsed.name == "test"

    def test_valid_blake3_lowercase(self):
        # Uppercase algorithm and digest are normalized to lowercase.
        parsed = CreateFromHashBody(
            hash="BLAKE3:" + "A" * 64,
            name="test"
        )
        assert parsed.hash == "blake3:" + "a" * 64

    def test_rejects_sha256(self):
        with pytest.raises(ValidationError) as exc_info:
            CreateFromHashBody(hash="sha256:" + "a" * 64, name="test")
        assert "blake3" in str(exc_info.value).lower()

    def test_rejects_no_colon(self):
        with pytest.raises(ValidationError) as exc_info:
            CreateFromHashBody(hash="a" * 64, name="test")
        assert "blake3:<hex>" in str(exc_info.value)

    def test_rejects_invalid_hex(self):
        with pytest.raises(ValidationError) as exc_info:
            CreateFromHashBody(hash="blake3:" + "g" * 64, name="test")
        assert "hex" in str(exc_info.value).lower()

    def test_rejects_empty_digest(self):
        with pytest.raises(ValidationError) as exc_info:
            CreateFromHashBody(hash="blake3:", name="test")
        assert "hex" in str(exc_info.value).lower()

    def test_default_tags_empty(self):
        parsed = CreateFromHashBody(hash="blake3:" + "a" * 64, name="test")
        assert parsed.tags == []

    def test_default_user_metadata_empty(self):
        parsed = CreateFromHashBody(hash="blake3:" + "a" * 64, name="test")
        assert parsed.user_metadata == {}

    def test_tags_normalized_lowercase(self):
        parsed = CreateFromHashBody(
            hash="blake3:" + "a" * 64,
            name="test",
            tags=["TAG1", "Tag2"]
        )
        assert parsed.tags == ["tag1", "tag2"]

    def test_tags_deduplicated(self):
        # Deduplication is case-insensitive after lowercasing.
        parsed = CreateFromHashBody(
            hash="blake3:" + "a" * 64,
            name="test",
            tags=["tag", "TAG", "tag"]
        )
        assert parsed.tags == ["tag"]

    def test_tags_csv_parsing(self):
        parsed = CreateFromHashBody(
            hash="blake3:" + "a" * 64,
            name="test",
            tags="a,b,c"
        )
        assert parsed.tags == ["a", "b", "c"]

    def test_whitespace_stripping(self):
        parsed = CreateFromHashBody(
            hash=" blake3:" + "a" * 64 + " ",
            name=" test "
        )
        assert parsed.hash == "blake3:" + "a" * 64
        assert parsed.name == "test"
|
||||
|
||||
|
||||
class TestUploadAssetSpec:
    """Validation behavior of the UploadAssetSpec schema (root tag, hash, metadata parsing)."""

    def test_first_tag_must_be_root_type_models(self):
        spec = UploadAssetSpec.model_validate({"tags": ["models", "loras"]})
        assert spec.tags[0] == "models"

    def test_first_tag_must_be_root_type_input(self):
        spec = UploadAssetSpec.model_validate({"tags": ["input"]})
        assert spec.tags[0] == "input"

    def test_first_tag_must_be_root_type_output(self):
        spec = UploadAssetSpec.model_validate({"tags": ["output"]})
        assert spec.tags[0] == "output"

    def test_rejects_invalid_first_tag(self):
        with pytest.raises(ValidationError) as exc_info:
            UploadAssetSpec.model_validate({"tags": ["invalid"]})
        assert "models, input, output" in str(exc_info.value)

    def test_models_requires_category_tag(self):
        # "models" alone is insufficient: a category (e.g. "loras") must follow.
        with pytest.raises(ValidationError) as exc_info:
            UploadAssetSpec.model_validate({"tags": ["models"]})
        assert "category tag" in str(exc_info.value)

    def test_input_does_not_require_second_tag(self):
        spec = UploadAssetSpec.model_validate({"tags": ["input"]})
        assert spec.tags == ["input"]

    def test_output_does_not_require_second_tag(self):
        spec = UploadAssetSpec.model_validate({"tags": ["output"]})
        assert spec.tags == ["output"]

    def test_tags_empty_rejected(self):
        with pytest.raises(ValidationError):
            UploadAssetSpec.model_validate({"tags": []})

    def test_tags_csv_parsing(self):
        spec = UploadAssetSpec.model_validate({"tags": "models,loras"})
        assert spec.tags == ["models", "loras"]

    def test_tags_json_array_parsing(self):
        # A JSON-encoded array string is also accepted for tags.
        spec = UploadAssetSpec.model_validate({"tags": '["models", "loras"]'})
        assert spec.tags == ["models", "loras"]

    def test_tags_normalized_lowercase(self):
        spec = UploadAssetSpec.model_validate({"tags": ["MODELS", "LORAS"]})
        assert spec.tags == ["models", "loras"]

    def test_tags_deduplicated(self):
        spec = UploadAssetSpec.model_validate({"tags": ["models", "loras", "models"]})
        assert spec.tags == ["models", "loras"]

    def test_hash_validation_valid_blake3(self):
        spec = UploadAssetSpec.model_validate({
            "tags": ["input"],
            "hash": "blake3:" + "a" * 64
        })
        assert spec.hash == "blake3:" + "a" * 64

    def test_hash_validation_rejects_sha256(self):
        with pytest.raises(ValidationError):
            UploadAssetSpec.model_validate({
                "tags": ["input"],
                "hash": "sha256:" + "a" * 64
            })

    def test_hash_none_allowed(self):
        spec = UploadAssetSpec.model_validate({"tags": ["input"], "hash": None})
        assert spec.hash is None

    def test_hash_empty_string_becomes_none(self):
        spec = UploadAssetSpec.model_validate({"tags": ["input"], "hash": ""})
        assert spec.hash is None

    def test_name_optional(self):
        spec = UploadAssetSpec.model_validate({"tags": ["input"]})
        assert spec.name is None

    def test_name_max_length(self):
        # 513 characters exceeds the 512-character cap.
        with pytest.raises(ValidationError):
            UploadAssetSpec.model_validate({
                "tags": ["input"],
                "name": "x" * 513
            })

    def test_user_metadata_json_string(self):
        spec = UploadAssetSpec.model_validate({
            "tags": ["input"],
            "user_metadata": '{"key": "value"}'
        })
        assert spec.user_metadata == {"key": "value"}

    def test_user_metadata_dict(self):
        spec = UploadAssetSpec.model_validate({
            "tags": ["input"],
            "user_metadata": {"key": "value"}
        })
        assert spec.user_metadata == {"key": "value"}

    def test_user_metadata_empty_string(self):
        spec = UploadAssetSpec.model_validate({
            "tags": ["input"],
            "user_metadata": ""
        })
        assert spec.user_metadata == {}

    def test_user_metadata_invalid_json(self):
        with pytest.raises(ValidationError) as exc_info:
            UploadAssetSpec.model_validate({
                "tags": ["input"],
                "user_metadata": "not json"
            })
        assert "must be JSON" in str(exc_info.value)
|
||||
|
||||
|
||||
class TestSetPreviewBody:
    """Validation behavior of the SetPreviewBody schema (optional UUID preview id)."""

    def test_valid_uuid(self):
        parsed = SetPreviewBody.model_validate({"preview_id": "550e8400-e29b-41d4-a716-446655440000"})
        assert parsed.preview_id == "550e8400-e29b-41d4-a716-446655440000"

    def test_none_allowed(self):
        parsed = SetPreviewBody.model_validate({"preview_id": None})
        assert parsed.preview_id is None

    def test_empty_string_becomes_none(self):
        parsed = SetPreviewBody.model_validate({"preview_id": ""})
        assert parsed.preview_id is None

    def test_whitespace_only_becomes_none(self):
        parsed = SetPreviewBody.model_validate({"preview_id": " "})
        assert parsed.preview_id is None

    def test_invalid_uuid(self):
        with pytest.raises(ValidationError) as exc_info:
            SetPreviewBody.model_validate({"preview_id": "not-a-uuid"})
        assert "UUID" in str(exc_info.value)

    def test_default_is_none(self):
        parsed = SetPreviewBody.model_validate({})
        assert parsed.preview_id is None
|
||||
|
||||
|
||||
class TestTagsAdd:
    """Validation behavior of the TagsAdd schema (non-empty, normalized tag list)."""

    def test_non_empty_required(self):
        with pytest.raises(ValidationError):
            TagsAdd.model_validate({"tags": []})

    def test_valid_tags(self):
        parsed = TagsAdd.model_validate({"tags": ["tag1", "tag2"]})
        assert parsed.tags == ["tag1", "tag2"]

    def test_tags_normalized_lowercase(self):
        parsed = TagsAdd.model_validate({"tags": ["TAG1", "Tag2"]})
        assert parsed.tags == ["tag1", "tag2"]

    def test_tags_whitespace_stripped(self):
        parsed = TagsAdd.model_validate({"tags": [" tag1 ", " tag2 "]})
        assert parsed.tags == ["tag1", "tag2"]

    def test_tags_deduplicated(self):
        parsed = TagsAdd.model_validate({"tags": ["tag", "TAG", "tag"]})
        assert parsed.tags == ["tag"]

    def test_empty_strings_filtered(self):
        # Empty and whitespace-only entries are dropped rather than rejected.
        parsed = TagsAdd.model_validate({"tags": ["tag1", "", " ", "tag2"]})
        assert parsed.tags == ["tag1", "tag2"]

    def test_missing_tags_field_fails(self):
        with pytest.raises(ValidationError):
            TagsAdd.model_validate({})
|
||||
|
||||
|
||||
class TestTagsRemove:
    """Validation behavior of the TagsRemove schema (shares TagsAdd's normalization)."""

    def test_non_empty_required(self):
        with pytest.raises(ValidationError):
            TagsRemove.model_validate({"tags": []})

    def test_valid_tags(self):
        parsed = TagsRemove.model_validate({"tags": ["tag1", "tag2"]})
        assert parsed.tags == ["tag1", "tag2"]

    def test_inherits_normalization(self):
        parsed = TagsRemove.model_validate({"tags": ["TAG1", "Tag2"]})
        assert parsed.tags == ["tag1", "tag2"]
|
||||
|
||||
|
||||
class TestTagsListQuery:
    """Validation behavior of the TagsListQuery schema (prefix search, paging, ordering)."""

    def test_defaults(self):
        parsed = TagsListQuery()
        assert parsed.prefix is None
        assert parsed.limit == 100
        assert parsed.offset == 0
        assert parsed.order == "count_desc"
        assert parsed.include_zero is True

    def test_prefix_normalized_lowercase(self):
        parsed = TagsListQuery.model_validate({"prefix": "PREFIX"})
        assert parsed.prefix == "prefix"

    def test_prefix_whitespace_stripped(self):
        parsed = TagsListQuery.model_validate({"prefix": " prefix "})
        assert parsed.prefix == "prefix"

    def test_prefix_whitespace_only_fails_min_length(self):
        # After stripping, whitespace-only prefix becomes empty, which fails min_length=1
        # The min_length check happens before the normalizer can return None
        with pytest.raises(ValidationError):
            TagsListQuery.model_validate({"prefix": " "})

    def test_prefix_min_length(self):
        with pytest.raises(ValidationError):
            TagsListQuery.model_validate({"prefix": ""})

    def test_prefix_max_length(self):
        with pytest.raises(ValidationError):
            TagsListQuery.model_validate({"prefix": "x" * 257})

    def test_limit_bounds_min(self):
        with pytest.raises(ValidationError):
            TagsListQuery.model_validate({"limit": 0})

    def test_limit_bounds_max(self):
        with pytest.raises(ValidationError):
            TagsListQuery.model_validate({"limit": 1001})

    def test_limit_bounds_valid(self):
        parsed = TagsListQuery.model_validate({"limit": 1000})
        assert parsed.limit == 1000

    def test_offset_bounds_min(self):
        with pytest.raises(ValidationError):
            TagsListQuery.model_validate({"offset": -1})

    def test_offset_bounds_max(self):
        with pytest.raises(ValidationError):
            TagsListQuery.model_validate({"offset": 10_000_001})

    def test_order_valid_values(self):
        for candidate in ("count_desc", "name_asc"):
            parsed = TagsListQuery.model_validate({"order": candidate})
            assert parsed.order == candidate

    def test_order_invalid(self):
        with pytest.raises(ValidationError):
            TagsListQuery.model_validate({"order": "invalid"})

    def test_include_zero_bool(self):
        parsed = TagsListQuery.model_validate({"include_zero": False})
        assert parsed.include_zero is False
|
||||
|
||||
|
||||
class TestScheduleAssetScanBody:
    """Validation behavior of the ScheduleAssetScanBody schema (allowed scan roots)."""

    def test_valid_roots(self):
        parsed = ScheduleAssetScanBody.model_validate({"roots": ["models"]})
        assert parsed.roots == ["models"]

    def test_multiple_roots(self):
        parsed = ScheduleAssetScanBody.model_validate({"roots": ["models", "input", "output"]})
        assert parsed.roots == ["models", "input", "output"]

    def test_empty_roots_rejected(self):
        with pytest.raises(ValidationError):
            ScheduleAssetScanBody.model_validate({"roots": []})

    def test_invalid_root_rejected(self):
        with pytest.raises(ValidationError):
            ScheduleAssetScanBody.model_validate({"roots": ["invalid"]})

    def test_missing_roots_rejected(self):
        with pytest.raises(ValidationError):
            ScheduleAssetScanBody.model_validate({})
|
||||
Reference in New Issue
Block a user