From fd9fdfd6d90c4ffc370ef4f32a9377b5b207b454 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ter=C3=A9zia=20Slanin=C3=A1kov=C3=A1?= <445526@mail.muni.cz> Date: Mon, 30 Sep 2024 20:03:00 +0200 Subject: [PATCH 1/2] fix model params not parsed if model_path empty --- training/create-buckets.py | 27 ++++++++++++++++----------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/training/create-buckets.py b/training/create-buckets.py index 619744c..c3982cc 100644 --- a/training/create-buckets.py +++ b/training/create-buckets.py @@ -7,9 +7,8 @@ import numpy as np import pandas as pd import torch -from tqdm import tqdm - from model import LIDatasetPredict, load_model +from tqdm import tqdm from utils import ( create_dir, dir_exists, @@ -42,16 +41,22 @@ def load_all_embeddings(path): def parse_model_params(model_path): LOG.info(f'Parsing out model params from model path: {model_path}') pattern = r'model-(\w+)--.*?n_classes-(\d+)(?:--.*?dimensionality-(\d+))?' - match = re.search(pattern, model_path, re.MULTILINE) - # new model format - if match and len(match.groups()) == 3: - model = match.group(1) - n_classes = int(match.group(2)) - dimensionality = match.group(3) - dimensionality = int(dimensionality) if dimensionality is not None else DEFAULT_DIMENSIONALITY + + if model_path is None: + model = 'MLP' + dimensionality = DEFAULT_DIMENSIONALITY + n_classes = 2 else: - LOG.info(f'Failed to parse out model params from model path: {model_path}') - exit(1) + match = re.search(pattern, model_path, re.MULTILINE) + if match and len(match.groups()) == 3: + model = match.group(1) + n_classes = int(match.group(2)) + dimensionality = match.group(3) + dimensionality = int(dimensionality) if dimensionality is not None else DEFAULT_DIMENSIONALITY + else: + LOG.info(f'Failed to parse out model params from model path: {model_path}') + exit(1) + LOG.info(f'Parsed out model={model}, dimensionality={dimensionality}, n_classes={n_classes}') return model, dimensionality, n_classes From 2f41978099c3b51a82bcbfc58a69f4348ae67931 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ter=C3=A9zia=20Slanin=C3=A1kov=C3=A1?= <445526@mail.muni.cz> Date: Tue, 1 Oct 2024 12:44:27 +0200 Subject: [PATCH 2/2] improve model,dim,classes parsing --- training/create-buckets.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/training/create-buckets.py b/training/create-buckets.py index c3982cc..c950198 100644 --- a/training/create-buckets.py +++ b/training/create-buckets.py @@ -46,16 +46,17 @@ def parse_model_params(model_path): model = 'MLP' dimensionality = DEFAULT_DIMENSIONALITY n_classes = 2 + LOG.info(f'Parsed out model={model}, dimensionality={dimensionality}, n_classes={n_classes}') + return model, dimensionality, n_classes + + match = re.search(pattern, model_path, re.MULTILINE) + if match and len(match.groups()) == 3: + model, n_classes, dimensionality = match.groups() + dimensionality = int(dimensionality) if dimensionality is not None else DEFAULT_DIMENSIONALITY + n_classes = int(n_classes) else: - match = re.search(pattern, model_path, re.MULTILINE) - if match and len(match.groups()) == 3: - model = match.group(1) - n_classes = int(match.group(2)) - dimensionality = match.group(3) - dimensionality = int(dimensionality) if dimensionality is not None else DEFAULT_DIMENSIONALITY - else: - LOG.info(f'Failed to parse out model params from model path: {model_path}') - exit(1) + LOG.info(f'Failed to parse out model params from model path: {model_path}') + exit(1) LOG.info(f'Parsed out model={model}, dimensionality={dimensionality}, n_classes={n_classes}') return model, dimensionality, n_classes