Skip to content

Commit

Permalink
fix: BioSR dataloder and pathing of the project.
Browse files Browse the repository at this point in the history
  • Loading branch information
nanoxas committed Oct 28, 2024
1 parent 7e52f60 commit c745ab3
Show file tree
Hide file tree
Showing 3 changed files with 0 additions and 101 deletions.
55 changes: 0 additions & 55 deletions cvdm/utils/metrics_utils.py

This file was deleted.

File renamed without changes.
46 changes: 0 additions & 46 deletions scripts/train.py → train.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,16 +67,11 @@ def main() -> None:
print("Getting data...")
batch_size = data_config.batch_size
dataset, x_shape, y_shape = prepare_dataset(task, data_config, training=True)
val_dataset, x_shape, y_shape = prepare_dataset(task, data_config, training=False)

dataset = dataset.shuffle(5000, reshuffle_each_iteration=False)

val_len = eval_config.val_len
val_dataset = val_dataset.take(val_len)
dataset = dataset.skip(val_len)

dataset = dataset.batch(batch_size, drop_remainder=True)
val_dataset = val_dataset.batch(batch_size, drop_remainder=True)

epochs = training_config.epochs
generation_timesteps = eval_config.generation_timesteps
Expand Down Expand Up @@ -107,7 +102,6 @@ def main() -> None:
log_freq = eval_config.log_freq
checkpoint_freq = eval_config.checkpoint_freq
image_freq = eval_config.image_freq
val_freq = eval_config.val_freq
output_path = eval_config.output_path
diff_inp = model_config.diff_inp

Expand Down Expand Up @@ -161,46 +155,6 @@ def main() -> None:
prefix="train",
cmap=cmap,
)

if step % val_freq == 0:
if model_config.zmd:
val_loss = np.zeros(6)
else:
val_loss = np.zeros(5)
for batch in val_dataset:
batch_x, batch_y = batch
model_input = prepare_model_input(
batch_x, batch_y, diff_inp=diff_inp
)
val_loss += joint_model.evaluate(
model_input, np.zeros_like(batch_y), verbose=0
)

log_loss(run=run, avg_loss=val_loss, prefix="val")
# To speed up, images are only generated and metrics are calculated only for one batch.
random_batch = val_dataset.take(1)
for batch_x, batch_y in random_batch:
output_montage, metrics = obtain_output_montage_and_metrics(
batch_x,
batch_y.numpy(),
noise_model,
schedule_model,
mu_model,
generation_timesteps,
diff_inp,
task,
)
log_metrics(run, metrics, prefix="val")
save_output_montage(
run=run,
output_montage=output_montage,
step=step,
output_path=output_path,
run_id=run_id,
prefix="val",
cmap=cmap,
)

step += 1

if run is not None:
Expand Down

0 comments on commit c745ab3

Please sign in to comment.