From 401b9109369f4b1bf78d39bfb7a7866457cb6c85 Mon Sep 17 00:00:00 2001 From: Michael Hedman Date: Mon, 19 May 2025 07:57:05 +0200 Subject: [PATCH] missed a dtype cleanup immediate surroundings --- acestep/pipeline_ace_step.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/acestep/pipeline_ace_step.py b/acestep/pipeline_ace_step.py index 3d4fc41..1a2ec1a 100644 --- a/acestep/pipeline_ace_step.py +++ b/acestep/pipeline_ace_step.py @@ -702,16 +702,13 @@ class ACEStepPipeline: attention_mask=attention_mask, momentum_buffer=momentum_buffer, ) - V_delta_avg += (1 / n_avg) * ( - Vt_tar - Vt_src - ) # - (hfg-1)*( x_src)) + V_delta_avg += (1 / n_avg) * (Vt_tar - Vt_src) # - (hfg - 1) * (x_src) - zt_edit = zt_edit.to(torch.float32) + zt_edit = zt_edit.to(torch.float32) # arbitrary, should be settable for compatibility if scheduler_type != "pingpong": # propagate direct ODE - zt_edit = zt_edit.to(torch.float32) zt_edit = zt_edit + (t_im1 - t_i) * V_delta_avg - zt_edit = zt_edit.to(V_delta_avg.dtype) + zt_edit = zt_edit.to(self.dtype) else: # propagate pingpong SDE zt_edit_denoised = zt_edit - t_i * V_delta_avg