From 6907448f3f0e06bedf2f89135541d366d63faac5 Mon Sep 17 00:00:00 2001 From: chuxij Date: Thu, 1 May 2025 10:17:27 +0000 Subject: [PATCH] fix some bugs --- pipeline_ace_step.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pipeline_ace_step.py b/pipeline_ace_step.py index d1d2891..2fae377 100644 --- a/pipeline_ace_step.py +++ b/pipeline_ace_step.py @@ -566,7 +566,8 @@ class ACEStepPipeline: repaint_mask[:, :, :, repaint_start_frame:repaint_end_frame] = 1.0 repaint_noise = torch.cos(retake_variance) * target_latents + torch.sin(retake_variance) * retake_latents repaint_noise = torch.where(repaint_mask == 1.0, repaint_noise, target_latents) - z0 = repaint_noise + zt_edit = x0.clone() + z0 = target_latents elif is_extend: to_right_pad_gt_latents = None to_left_pad_gt_latents = None @@ -615,9 +616,8 @@ class ACEStepPipeline: padd_list.append(retake_latents[:, :, :, -right_pad_frame_length:]) target_latents = torch.cat(padd_list, dim=-1) assert target_latents.shape[-1] == x0.shape[-1], f"{target_latents.shape=} {x0.shape=}" - - zt_edit = x0.clone() - z0 = target_latents + zt_edit = x0.clone() + z0 = target_latents attention_mask = torch.ones(bsz, frame_length, device=device, dtype=dtype)