diff --git a/cvdm/utils/metrics_utils.py b/cvdm/utils/metrics_utils.py
deleted file mode 100644
index 5e9ce81..0000000
--- a/cvdm/utils/metrics_utils.py
+++ /dev/null
@@ -1,55 +0,0 @@
-from typing import Dict, Optional
-
-import numpy as np
-from skimage.metrics import peak_signal_noise_ratio, structural_similarity
-
-
-def nmae(y_pred: np.ndarray, y_real: np.ndarray) -> float:
-    nmae: float = np.sqrt(np.sum((y_pred - y_real) ** 2)) / np.sqrt(np.sum(y_real**2))
-    return nmae
-
-
-def calculate_metrics(
-    y_pred_batch: np.ndarray,
-    y_real_batch: np.ndarray,
-) -> Dict[str, float]:
-    y_pred_batch = np.array(y_pred_batch)
-    y_real_batch = np.array(y_real_batch)
-
-    if y_pred_batch.shape[3] > 1:
-        channel_axis = 2
-    else:
-        channel_axis = None
-
-    metrics = {
-        "mse": np.mean((y_pred_batch - y_real_batch) ** 2),
-        "mape": np.mean(np.abs((y_real_batch - y_pred_batch) / y_real_batch + 1e-10))
-        * 100,
-        "nmae": np.mean(
-            [
-                np.mean(np.abs(y_pred - y_real)) / np.mean(np.abs(y_real) + 1e-10)
-                for y_pred, y_real in zip(y_pred_batch, y_real_batch)
-            ]
-        ),
-        "psnr": np.mean(
-            [
-                peak_signal_noise_ratio(
-                    np.squeeze(y_pred), np.squeeze(y_real), data_range=2
-                )
-                for y_pred, y_real in zip(y_pred_batch, y_real_batch)
-            ]
-        ),
-        "ssim": np.mean(
-            [
-                structural_similarity(
-                    np.squeeze(y_pred),
-                    np.squeeze(y_real),
-                    data_range=2,
-                    channel_axis=channel_axis,
-                )
-                for y_pred, y_real in zip(y_pred_batch, y_real_batch)
-            ]
-        ),
-    }
-
-    return metrics
diff --git a/scripts/eval.py b/eval.py
similarity index 100%
rename from scripts/eval.py
rename to eval.py
diff --git a/scripts/train.py b/train.py
similarity index 73%
rename from scripts/train.py
rename to train.py
index d94b85a..779e39f 100644
--- a/scripts/train.py
+++ b/train.py
@@ -67,16 +67,11 @@ def main() -> None:
     print("Getting data...")
     batch_size = data_config.batch_size
     dataset, x_shape, y_shape = prepare_dataset(task, data_config, training=True)
-    val_dataset, x_shape, y_shape = prepare_dataset(task, data_config, training=False)
-
     dataset = dataset.shuffle(5000, reshuffle_each_iteration=False)
 
-    val_len = eval_config.val_len
-    val_dataset = val_dataset.take(val_len)
     dataset = dataset.skip(val_len)
     dataset = dataset.batch(batch_size, drop_remainder=True)
-    val_dataset = val_dataset.batch(batch_size, drop_remainder=True)
 
     epochs = training_config.epochs
     generation_timesteps = eval_config.generation_timesteps
 
@@ -107,7 +102,6 @@ def main() -> None:
     log_freq = eval_config.log_freq
     checkpoint_freq = eval_config.checkpoint_freq
     image_freq = eval_config.image_freq
-    val_freq = eval_config.val_freq
     output_path = eval_config.output_path
 
     diff_inp = model_config.diff_inp
@@ -161,46 +155,6 @@ def main() -> None:
                     prefix="train",
                     cmap=cmap,
                 )
-
-            if step % val_freq == 0:
-                if model_config.zmd:
-                    val_loss = np.zeros(6)
-                else:
-                    val_loss = np.zeros(5)
-                for batch in val_dataset:
-                    batch_x, batch_y = batch
-                    model_input = prepare_model_input(
-                        batch_x, batch_y, diff_inp=diff_inp
-                    )
-                    val_loss += joint_model.evaluate(
-                        model_input, np.zeros_like(batch_y), verbose=0
-                    )
-
-                log_loss(run=run, avg_loss=val_loss, prefix="val")
-                # To speed up, images are only generated and metrics are calculated only for one batch.
-                random_batch = val_dataset.take(1)
-                for batch_x, batch_y in random_batch:
-                    output_montage, metrics = obtain_output_montage_and_metrics(
-                        batch_x,
-                        batch_y.numpy(),
-                        noise_model,
-                        schedule_model,
-                        mu_model,
-                        generation_timesteps,
-                        diff_inp,
-                        task,
-                    )
-                    log_metrics(run, metrics, prefix="val")
-                    save_output_montage(
-                        run=run,
-                        output_montage=output_montage,
-                        step=step,
-                        output_path=output_path,
-                        run_id=run_id,
-                        prefix="val",
-                        cmap=cmap,
-                    )
-
             step += 1
 
     if run is not None:
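For reference, a minimal standalone sketch of the batch metrics that the deleted `cvdm/utils/metrics_utils.py` computed. This is illustrative, not a drop-in restore: the epsilon placement in `mape`/`nmae` is adjusted relative to the deleted code, and `data_range=2` assumes images normalized to [-1, 1], as the deleted file did.

```python
# Illustrative re-creation of the removed batch metrics (not a drop-in restore).
# Assumes y_pred_batch / y_real_batch are float arrays of shape (N, H, W, C)
# normalized to [-1, 1], which is why data_range=2 below.
from typing import Dict

import numpy as np
from skimage.metrics import peak_signal_noise_ratio, structural_similarity


def calculate_metrics(
    y_pred_batch: np.ndarray, y_real_batch: np.ndarray
) -> Dict[str, float]:
    y_pred_batch = np.asarray(y_pred_batch)
    y_real_batch = np.asarray(y_real_batch)

    # After np.squeeze, a multichannel image is (H, W, C), so SSIM needs
    # channel_axis; single-channel images squeeze to (H, W) and do not.
    channel_axis = -1 if y_pred_batch.shape[-1] > 1 else None
    eps = 1e-10  # guards the divisions below against all-zero targets

    return {
        "mse": float(np.mean((y_pred_batch - y_real_batch) ** 2)),
        "mape": float(
            np.mean(np.abs((y_real_batch - y_pred_batch) / (y_real_batch + eps))) * 100
        ),
        "nmae": float(
            np.mean(np.abs(y_pred_batch - y_real_batch))
            / (np.mean(np.abs(y_real_batch)) + eps)
        ),
        "psnr": float(
            np.mean(
                [
                    peak_signal_noise_ratio(np.squeeze(p), np.squeeze(r), data_range=2)
                    for p, r in zip(y_pred_batch, y_real_batch)
                ]
            )
        ),
        "ssim": float(
            np.mean(
                [
                    structural_similarity(
                        np.squeeze(p),
                        np.squeeze(r),
                        data_range=2,
                        channel_axis=channel_axis,
                    )
                    for p, r in zip(y_pred_batch, y_real_batch)
                ]
            )
        ),
    }
```

Called on one batch, e.g. `calculate_metrics(pred.numpy(), target.numpy())`, it returns plain floats suitable for logging.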