@@ -5829,8 +5829,8 @@ def save_sd_model_on_train_end_common(
58295829
58305830
58315831def get_timesteps (min_timestep , max_timestep , b_size , device ):
5832- timesteps = torch .randint (min_timestep , max_timestep , (b_size ,), device = device )
5833- timesteps = timesteps .long ()
5832+ timesteps = torch .randint (min_timestep , max_timestep , (b_size ,), device = "cpu" )
5833+ timesteps = timesteps .long (). to ( device )
58345834 return timesteps
58355835
58365836
@@ -5875,8 +5875,8 @@ def get_huber_threshold(args, timesteps: torch.Tensor, noise_scheduler) -> torch
58755875 alpha = - math .log (args .huber_c ) / noise_scheduler .config .num_train_timesteps
58765876 result = torch .exp (- alpha * timesteps ) * args .huber_scale
58775877 elif args .huber_schedule == "snr" :
5878- if not hasattr (noise_scheduler , ' alphas_cumprod' ):
5879- raise NotImplementedError (f "Huber schedule 'snr' is not supported with the current model." )
5878+ if not hasattr (noise_scheduler , " alphas_cumprod" ):
5879+ raise NotImplementedError ("Huber schedule 'snr' is not supported with the current model." )
58805880 alphas_cumprod = torch .index_select (noise_scheduler .alphas_cumprod , 0 , timesteps .cpu ())
58815881 sigmas = ((1.0 - alphas_cumprod ) / alphas_cumprod ) ** 0.5
58825882 result = (1 - args .huber_c ) / (1 + sigmas ) ** 2 + args .huber_c
0 commit comments