mirror of
https://github.com/Comfy-Org/ComfyUI.git
synced 2026-03-06 16:39:06 +00:00
Compare commits
3 Commits
painter-no
...
curve-node
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4c11d5f75f | ||
|
|
b6ddc590ed | ||
|
|
f719a9d928 |
@@ -1127,7 +1127,7 @@ class ZImagePixelSpace(ZImage):
|
||||
latent_format = latent_formats.ZImagePixelSpace
|
||||
|
||||
# Much lower memory than latent-space models (no VAE, small patches).
|
||||
memory_usage_factor = 0.05 # TODO: figure out the optimal value for this.
|
||||
memory_usage_factor = 0.03 # TODO: figure out the optimal value for this.
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
return model_base.ZImagePixelSpace(self, device=device)
|
||||
|
||||
@@ -1240,6 +1240,19 @@ class BoundingBox(ComfyTypeIO):
|
||||
return d
|
||||
|
||||
|
||||
@comfytype(io_type="CURVE")
|
||||
class Curve(ComfyTypeIO):
|
||||
CurvePoint = tuple[float, float]
|
||||
Type = list[CurvePoint]
|
||||
|
||||
class Input(WidgetInput):
|
||||
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None,
|
||||
socketless: bool=True, default: list[tuple[float, float]]=None, advanced: bool=None):
|
||||
super().__init__(id, display_name, optional, tooltip, None, default, socketless, None, None, None, None, advanced)
|
||||
if default is None:
|
||||
self.default = [(0.0, 0.0), (1.0, 1.0)]
|
||||
|
||||
|
||||
DYNAMIC_INPUT_LOOKUP: dict[str, Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]] = {}
|
||||
def register_dynamic_input_func(io_type: str, func: Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]):
|
||||
DYNAMIC_INPUT_LOOKUP[io_type] = func
|
||||
@@ -2226,5 +2239,6 @@ __all__ = [
|
||||
"PriceBadgeDepends",
|
||||
"PriceBadge",
|
||||
"BoundingBox",
|
||||
"Curve",
|
||||
"NodeReplace",
|
||||
]
|
||||
|
||||
@@ -1049,6 +1049,48 @@ class ManualSigmas(io.ComfyNode):
|
||||
sigmas = torch.FloatTensor(sigmas)
|
||||
return io.NodeOutput(sigmas)
|
||||
|
||||
class CurveToSigmas(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="CurveToSigmas",
|
||||
display_name="Curve to Sigmas",
|
||||
category="sampling/custom_sampling/schedulers",
|
||||
inputs=[
|
||||
io.Curve.Input("curve", default=[[0.0, 1.0], [1.0, 0.0]]),
|
||||
io.Model.Input("model", optional=True),
|
||||
io.Int.Input("steps", default=20, min=1, max=10000),
|
||||
io.Float.Input("sigma_max", default=14.614642, min=0.0, max=5000.0, step=0.01, round=False, optional=True),
|
||||
],
|
||||
outputs=[io.Sigmas.Output()]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, curve, steps, sigma_max=14.614642, model=None) -> io.NodeOutput:
|
||||
points = sorted(curve, key=lambda p: p[0])
|
||||
model_sampling = model.get_model_object("model_sampling") if model is not None else None
|
||||
|
||||
sigmas = []
|
||||
for i in range(steps + 1):
|
||||
t = i / steps
|
||||
y = points[0][1] if t < points[0][0] else points[-1][1]
|
||||
for j in range(len(points) - 1):
|
||||
if points[j][0] <= t <= points[j + 1][0]:
|
||||
x0, y0 = points[j]
|
||||
x1, y1 = points[j + 1]
|
||||
y = y0 if x1 == x0 else y0 + (y1 - y0) * (t - x0) / (x1 - x0)
|
||||
break
|
||||
if model_sampling is not None:
|
||||
sigmas.append(float(model_sampling.percent_to_sigma(1.0 - y)))
|
||||
else:
|
||||
sigmas.append(y * sigma_max)
|
||||
|
||||
sigmas[-1] = 0.0
|
||||
return io.NodeOutput(torch.FloatTensor(sigmas))
|
||||
|
||||
get_sigmas = execute
|
||||
|
||||
|
||||
class CustomSamplersExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
@@ -1088,6 +1130,7 @@ class CustomSamplersExtension(ComfyExtension):
|
||||
AddNoise,
|
||||
SamplerCustomAdvanced,
|
||||
ManualSigmas,
|
||||
CurveToSigmas,
|
||||
]
|
||||
|
||||
|
||||
|
||||
14
execution.py
14
execution.py
@@ -876,12 +876,14 @@ async def validate_inputs(prompt_id, prompt, item, validated):
|
||||
continue
|
||||
else:
|
||||
try:
|
||||
# Unwraps values wrapped in __value__ key. This is used to pass
|
||||
# list widget value to execution, as by default list value is
|
||||
# reserved to represent the connection between nodes.
|
||||
if isinstance(val, dict) and "__value__" in val:
|
||||
val = val["__value__"]
|
||||
inputs[x] = val
|
||||
# Unwraps values wrapped in __value__ key or typed wrapper.
|
||||
# This is used to pass list widget values to execution,
|
||||
# as by default list value is reserved to represent the
|
||||
# connection between nodes.
|
||||
if isinstance(val, dict):
|
||||
if "__value__" in val:
|
||||
val = val["__value__"]
|
||||
inputs[x] = val
|
||||
|
||||
if input_type == "INT":
|
||||
val = int(val)
|
||||
|
||||
20
nodes.py
20
nodes.py
@@ -2034,6 +2034,24 @@ class ImagePadForOutpaint:
|
||||
return (new_image, mask.unsqueeze(0))
|
||||
|
||||
|
||||
class CurveEditor:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"curve": ("CURVE", {"default": [[0, 0], [1, 1]]}),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("CURVE",)
|
||||
RETURN_NAMES = ("curve",)
|
||||
FUNCTION = "execute"
|
||||
CATEGORY = "utils"
|
||||
|
||||
def execute(self, curve):
|
||||
return (curve,)
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"KSampler": KSampler,
|
||||
"CheckpointLoaderSimple": CheckpointLoaderSimple,
|
||||
@@ -2102,6 +2120,7 @@ NODE_CLASS_MAPPINGS = {
|
||||
"ConditioningZeroOut": ConditioningZeroOut,
|
||||
"ConditioningSetTimestepRange": ConditioningSetTimestepRange,
|
||||
"LoraLoaderModelOnly": LoraLoaderModelOnly,
|
||||
"CurveEditor": CurveEditor,
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
@@ -2170,6 +2189,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
# _for_testing
|
||||
"VAEDecodeTiled": "VAE Decode (Tiled)",
|
||||
"VAEEncodeTiled": "VAE Encode (Tiled)",
|
||||
"CurveEditor": "Curve Editor",
|
||||
}
|
||||
|
||||
EXTENSION_WEB_DIRS = {}
|
||||
|
||||
Reference in New Issue
Block a user