diff --git a/training/create-buckets.py b/training/create-buckets.py
index 619744c..c950198 100644
--- a/training/create-buckets.py
+++ b/training/create-buckets.py
@@ -7,9 +7,8 @@
 import numpy as np
 import pandas as pd
 import torch
-from tqdm import tqdm
-
 from model import LIDatasetPredict, load_model
+from tqdm import tqdm
 from utils import (
     create_dir,
     dir_exists,
@@ -42,16 +41,23 @@ def load_all_embeddings(path):
 def parse_model_params(model_path):
     LOG.info(f'Parsing out model params from model path: {model_path}')
     pattern = r'model-(\w+)--.*?n_classes-(\d+)(?:--.*?dimensionality-(\d+))?'
+
+    if model_path is None:
+        model = 'MLP'
+        dimensionality = DEFAULT_DIMENSIONALITY
+        n_classes = 2
+        LOG.info(f'Parsed out model={model}, dimensionality={dimensionality}, n_classes={n_classes}')
+        return model, dimensionality, n_classes
+
     match = re.search(pattern, model_path, re.MULTILINE)
-    # new model format
     if match and len(match.groups()) == 3:
-        model = match.group(1)
-        n_classes = int(match.group(2))
-        dimensionality = match.group(3)
+        model, n_classes, dimensionality = match.groups()
         dimensionality = int(dimensionality) if dimensionality is not None else DEFAULT_DIMENSIONALITY
+        n_classes = int(n_classes)
     else:
         LOG.info(f'Failed to parse out model params from model path: {model_path}')
         exit(1)
+
     LOG.info(f'Parsed out model={model}, dimensionality={dimensionality}, n_classes={n_classes}')
     return model, dimensionality, n_classes
 