Merge pull request #2387 from bmaltais/2370-2406-train-toml-config-seed-type-error

Fix [24.0.6] Train toml config seed type error #2370
bmaltais authored Apr 25, 2024
2 parents 7e6d805 + 433fabf commit d7e39c3
Showing 6 changed files with 101 additions and 46 deletions.
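
Every change below follows one pattern: values handed back by the GUI can arrive as strings (or floats), and writing them to the training TOML unconverted changes the field's TOML type, which the int-typed options in the training scripts then reject. A minimal sketch of the failure mode and the fix, assuming a hypothetical `seed_from_gui` value and the `toml` package:

```python
import toml

seed_from_gui = "1234"  # hypothetical: a numeric GUI field returned as a string

# Before the fix: the string passes the `!= 0` check and is written
# to the config as a TOML string, e.g. seed = "1234".
print(toml.dumps({"seed": seed_from_gui} if seed_from_gui != 0 else {}))

# After the fix: cast first, so both the zero check and the TOML
# type are integer, e.g. seed = 1234.
print(toml.dumps({"seed": int(seed_from_gui)} if int(seed_from_gui) != 0 else {}))
```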
1 change: 1 addition & 0 deletions README.md
@@ -418,6 +418,7 @@ ControlNet dataset is used to specify the mask. The mask images should be the RG
 ### 2024/04/25 (v24.0.7)
 
 - Prevent crash if tkinter is not installed
+- Fix [24.0.6] Train toml config seed type error #2370
 
 ### 2024/04/22 (v24.0.6)
 
14 changes: 7 additions & 7 deletions kohya_gui/dreambooth_gui.py
@@ -684,7 +684,7 @@ def train_model(
         "bucket_reso_steps": bucket_reso_steps,
         "cache_latents": cache_latents,
         "cache_latents_to_disk": cache_latents_to_disk,
-        "caption_dropout_every_n_epochs": caption_dropout_every_n_epochs,
+        "caption_dropout_every_n_epochs": int(caption_dropout_every_n_epochs),
         "caption_dropout_rate": caption_dropout_rate,
         "caption_extension": caption_extension,
         "clip_skip": clip_skip if clip_skip != 0 else None,
@@ -727,16 +727,16 @@ def train_model(
         "lr_scheduler": lr_scheduler,
         "lr_scheduler_args": str(lr_scheduler_args).replace('"', "").split(),
         "lr_scheduler_num_cycles": (
-            lr_scheduler_num_cycles if lr_scheduler_num_cycles != "" else int(epoch)
+            int(lr_scheduler_num_cycles) if lr_scheduler_num_cycles != "" else int(epoch)
         ),
         "lr_scheduler_power": lr_scheduler_power,
         "lr_warmup_steps": lr_warmup_steps,
         "masked_loss": masked_loss,
         "max_bucket_reso": max_bucket_reso,
         "max_timestep": max_timestep if max_timestep != 0 else None,
         "max_token_length": int(max_token_length),
-        "max_train_epochs": max_train_epochs if max_train_epochs != 0 else None,
-        "max_train_steps": max_train_steps if max_train_steps != 0 else None,
+        "max_train_epochs": int(max_train_epochs) if int(max_train_epochs) != 0 else None,
+        "max_train_steps": int(max_train_steps) if int(max_train_steps) != 0 else None,
         "mem_eff_attn": mem_eff_attn,
         "metadata_author": metadata_author,
         "metadata_description": metadata_description,
@@ … @@
         "optimizer_type": optimizer,
         "output_dir": output_dir,
         "output_name": output_name,
-        "persistent_data_loader_workers": persistent_data_loader_workers,
+        "persistent_data_loader_workers": int(persistent_data_loader_workers),
         "pretrained_model_name_or_path": pretrained_model_name_or_path,
         "prior_loss_weight": prior_loss_weight,
         "random_crop": random_crop,
@@ -792,7 +792,7 @@
         "save_state_to_huggingface": save_state_to_huggingface,
         "scale_v_pred_loss_like_noise_pred": scale_v_pred_loss_like_noise_pred,
         "sdpa": True if xformers == "sdpa" else None,
-        "seed": seed if seed != 0 else None,
+        "seed": int(seed) if int(seed) != 0 else None,
         "shuffle_caption": shuffle_caption,
         "stop_text_encoder_training": (
             stop_text_encoder_training if stop_text_encoder_training != 0 else None
@@ … @@
         if value not in ["", False, None]
     }
 
-    config_toml_data["max_data_loader_n_workers"] = max_data_loader_n_workers
+    config_toml_data["max_data_loader_n_workers"] = int(max_data_loader_n_workers)
 
     # Sort the dictionary by keys
     config_toml_data = dict(sorted(config_toml_data.items()))
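
The same cast-and-drop-zero expression recurs in all four GUIs. A hypothetical consolidation (not part of this PR) could factor it into a single helper:

```python
def int_or_none(value):
    """Cast a GUI-supplied value to int, mapping 0 to None so the key
    is later filtered out of the generated TOML config."""
    value = int(value)
    return value if value != 0 else None

# Usage (hypothetical): "seed": int_or_none(seed)
# instead of:           "seed": int(seed) if int(seed) != 0 else None
```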
12 changes: 6 additions & 6 deletions kohya_gui/finetune_gui.py
@@ -765,7 +765,7 @@ def train_model(
         "cache_latents": cache_latents,
         "cache_latents_to_disk": cache_latents_to_disk,
         "cache_text_encoder_outputs": cache_text_encoder_outputs,
-        "caption_dropout_every_n_epochs": caption_dropout_every_n_epochs,
+        "caption_dropout_every_n_epochs": int(caption_dropout_every_n_epochs),
         "caption_dropout_rate": caption_dropout_rate,
         "caption_extension": caption_extension,
         "clip_skip": clip_skip if clip_skip != 0 else None,
@@ -812,8 +812,8 @@ def train_model(
         "max_bucket_reso": int(max_bucket_reso),
         "max_timestep": max_timestep if max_timestep != 0 else None,
         "max_token_length": int(max_token_length),
-        "max_train_epochs": max_train_epochs if max_train_epochs != 0 else None,
-        "max_train_steps": max_train_steps if max_train_steps != 0 else None,
+        "max_train_epochs": int(max_train_epochs) if int(max_train_epochs) != 0 else None,
+        "max_train_steps": int(max_train_steps) if int(max_train_steps) != 0 else None,
         "mem_eff_attn": mem_eff_attn,
         "metadata_author": metadata_author,
         "metadata_description": metadata_description,
@@ … @@
         "optimizer_args": str(optimizer_args).replace('"', "").split(),
         "output_dir": output_dir,
         "output_name": output_name,
-        "persistent_data_loader_workers": persistent_data_loader_workers,
+        "persistent_data_loader_workers": int(persistent_data_loader_workers),
         "pretrained_model_name_or_path": pretrained_model_name_or_path,
         "random_crop": random_crop,
         "resolution": max_resolution,
@@ -865,7 +865,7 @@ def train_model(
         "save_state_to_huggingface": save_state_to_huggingface,
         "scale_v_pred_loss_like_noise_pred": scale_v_pred_loss_like_noise_pred,
         "sdpa": True if xformers == "sdpa" else None,
-        "seed": seed if seed != 0 else None,
+        "seed": int(seed) if int(seed) != 0 else None,
         "shuffle_caption": shuffle_caption,
         "train_batch_size": train_batch_size,
         "train_data_dir": image_folder,
@@ … @@
         if value not in ["", False, None]
     }
 
-    config_toml_data["max_data_loader_n_workers"] = max_data_loader_n_workers
+    config_toml_data["max_data_loader_n_workers"] = int(max_data_loader_n_workers)
 
     # Sort the dictionary by keys
     config_toml_data = dict(sorted(config_toml_data.items()))
14 changes: 7 additions & 7 deletions kohya_gui/lora_gui.py
@@ -1044,7 +1044,7 @@ def train_model(
         "cache_text_encoder_outputs": (
             True if sdxl and sdxl_cache_text_encoder_outputs else None
         ),
-        "caption_dropout_every_n_epochs": caption_dropout_every_n_epochs,
+        "caption_dropout_every_n_epochs": int(caption_dropout_every_n_epochs),
         "caption_dropout_rate": caption_dropout_rate,
         "caption_extension": caption_extension,
         "clip_skip": clip_skip if clip_skip != 0 else None,
@@ -1079,7 +1079,7 @@ def train_model(
         "lr_scheduler": lr_scheduler,
         "lr_scheduler_args": str(lr_scheduler_args).replace('"', "").split(),
         "lr_scheduler_num_cycles": (
-            lr_scheduler_num_cycles if lr_scheduler_num_cycles != "" else int(epoch)
+            int(lr_scheduler_num_cycles) if lr_scheduler_num_cycles != "" else int(epoch)
         ),
         "lr_scheduler_power": lr_scheduler_power,
         "lr_warmup_steps": lr_warmup_steps,
@@ … @@
         "max_grad_norm": max_grad_norm,
         "max_timestep": max_timestep if max_timestep != 0 else None,
         "max_token_length": int(max_token_length),
-        "max_train_epochs": max_train_epochs if max_train_epochs != 0 else None,
-        "max_train_steps": max_train_steps if max_train_steps != 0 else None,
+        "max_train_epochs": int(max_train_epochs) if int(max_train_epochs) != 0 else None,
+        "max_train_steps": int(max_train_steps) if int(max_train_steps) != 0 else None,
         "mem_eff_attn": mem_eff_attn,
         "metadata_author": metadata_author,
         "metadata_description": metadata_description,
@@ -1120,7 +1120,7 @@
         "optimizer_args": str(optimizer_args).replace('"', "").split(),
         "output_dir": output_dir,
         "output_name": output_name,
-        "persistent_data_loader_workers": persistent_data_loader_workers,
+        "persistent_data_loader_workers": int(persistent_data_loader_workers),
         "pretrained_model_name_or_path": pretrained_model_name_or_path,
         "prior_loss_weight": prior_loss_weight,
         "random_crop": random_crop,
@@ -1152,7 +1152,7 @@
         "scale_v_pred_loss_like_noise_pred": scale_v_pred_loss_like_noise_pred,
         "scale_weight_norms": scale_weight_norms,
         "sdpa": True if xformers == "sdpa" else None,
-        "seed": seed if seed != 0 else None,
+        "seed": int(seed) if int(seed) != 0 else None,
         "shuffle_caption": shuffle_caption,
         "stop_text_encoder_training": (
             stop_text_encoder_training if stop_text_encoder_training != 0 else None
Expand Down Expand Up @@ -1182,7 +1182,7 @@ def train_model(
if value not in ["", False, None]
}

config_toml_data["max_data_loader_n_workers"] = max_data_loader_n_workers
config_toml_data["max_data_loader_n_workers"] = int(max_data_loader_n_workers)

# Sort the dictionary by keys
config_toml_data = dict(sorted(config_toml_data.items()))
14 changes: 7 additions & 7 deletions kohya_gui/textual_inversion_gui.py
@@ -712,7 +712,7 @@ def train_model(
         "bucket_reso_steps": bucket_reso_steps,
         "cache_latents": cache_latents,
         "cache_latents_to_disk": cache_latents_to_disk,
-        "caption_dropout_every_n_epochs": caption_dropout_every_n_epochs,
+        "caption_dropout_every_n_epochs": int(caption_dropout_every_n_epochs),
         "caption_extension": caption_extension,
         "clip_skip": clip_skip if clip_skip != 0 else None,
         "color_aug": color_aug,
@@ -743,15 +743,15 @@ def train_model(
         "lr_scheduler": lr_scheduler,
         "lr_scheduler_args": str(lr_scheduler_args).replace('"', "").split(),
         "lr_scheduler_num_cycles": (
-            lr_scheduler_num_cycles if lr_scheduler_num_cycles != "" else int(epoch)
+            int(lr_scheduler_num_cycles) if lr_scheduler_num_cycles != "" else int(epoch)
         ),
         "lr_scheduler_power": lr_scheduler_power,
         "lr_warmup_steps": lr_warmup_steps,
         "max_bucket_reso": max_bucket_reso,
         "max_timestep": max_timestep if max_timestep != 0 else None,
         "max_token_length": int(max_token_length),
-        "max_train_epochs": max_train_epochs if max_train_epochs != 0 else None,
-        "max_train_steps": max_train_steps if max_train_steps != 0 else None,
+        "max_train_epochs": int(max_train_epochs) if int(max_train_epochs) != 0 else None,
+        "max_train_steps": int(max_train_steps) if int(max_train_steps) != 0 else None,
         "mem_eff_attn": mem_eff_attn,
         "metadata_author": metadata_author,
         "metadata_description": metadata_description,
@@ … @@
         "optimizer_args": str(optimizer_args).replace('"', "").split(),
         "output_dir": output_dir,
         "output_name": output_name,
-        "persistent_data_loader_workers": persistent_data_loader_workers,
+        "persistent_data_loader_workers": int(persistent_data_loader_workers),
         "pretrained_model_name_or_path": pretrained_model_name_or_path,
         "prior_loss_weight": prior_loss_weight,
         "random_crop": random_crop,
@@ -807,7 +807,7 @@ def train_model(
         "save_state_to_huggingface": save_state_to_huggingface,
         "scale_v_pred_loss_like_noise_pred": scale_v_pred_loss_like_noise_pred,
         "sdpa": True if xformers == "sdpa" else None,
-        "seed": seed if seed != 0 else None,
+        "seed": int(seed) if int(seed) != 0 else None,
         "shuffle_caption": shuffle_caption,
         "stop_text_encoder_training": (
             stop_text_encoder_training if stop_text_encoder_training != 0 else None
Expand Down Expand Up @@ -837,7 +837,7 @@ def train_model(
if value not in ["", False, None]
}

config_toml_data["max_data_loader_n_workers"] = max_data_loader_n_workers
config_toml_data["max_data_loader_n_workers"] = int(max_data_loader_n_workers)

# Sort the dictionary by keys
config_toml_data = dict(sorted(config_toml_data.items()))
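
A quick round-trip check of the corrected behavior (illustrative only, not a test from the repository):

```python
import toml

raw = {"seed": "1234", "max_train_epochs": "10"}  # hypothetical GUI values

normalized = {key: int(value) for key, value in raw.items()}
reloaded = toml.loads(toml.dumps(normalized))

assert isinstance(reloaded["seed"], int)              # 1234, not "1234"
assert isinstance(reloaded["max_train_epochs"], int)  # 10, not "10"
```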
(1 more changed file not shown)
