From fbf091ccf6bf320f841ad39d4a9186441f2c35b4 Mon Sep 17 00:00:00 2001 From: Benjamin Pritchard Date: Fri, 13 Oct 2023 11:55:44 -0400 Subject: [PATCH] Faster addition of dataset entries --- .../qcfractal/components/dataset_socket.py | 38 ++++++++++++------- 1 file changed, 25 insertions(+), 13 deletions(-) diff --git a/qcfractal/qcfractal/components/dataset_socket.py b/qcfractal/qcfractal/components/dataset_socket.py index 9eb74b7fa..5c84007f3 100644 --- a/qcfractal/qcfractal/components/dataset_socket.py +++ b/qcfractal/qcfractal/components/dataset_socket.py @@ -3,7 +3,7 @@ import logging from typing import TYPE_CHECKING -from sqlalchemy import select, delete, func, union, text +from sqlalchemy import select, delete, func, union, text, and_ from sqlalchemy.orm import load_only, lazyload, joinedload, with_polymorphic from qcfractal.components.dataset_db_models import BaseDatasetORM, ContributedValuesORM @@ -15,6 +15,7 @@ from qcportal.exceptions import AlreadyExistsError, MissingDataError from qcportal.metadata_models import InsertMetadata, DeleteMetadata, UpdateMetadata from qcportal.record_models import RecordStatusEnum, PriorityEnum +from qcportal.utils import chunk_iterable if TYPE_CHECKING: from sqlalchemy.orm.session import Session @@ -718,21 +719,32 @@ def add_entries( # Create orm for all entries (in derived class) entry_orm = self._create_entries(session, dataset_id, new_entries) - # Get all existing entries first - stmt = select(self.entry_orm.name) - stmt = stmt.where(self.entry_orm.dataset_id == dataset_id) - existing_entries = session.execute(stmt).scalars().all() - inserted_idx: List[int] = [] existing_idx: List[int] = [] - for idx, entry in enumerate(entry_orm): - # Only add if the entry does not exist - if entry.name in existing_entries: - existing_idx.append(idx) - else: - session.add(entry) - inserted_idx.append(idx) + # Go in batches of 200 to avoid huge queries + idx = 0 + for entry_orm_batch in chunk_iterable(entry_orm, 200): + # Get all existing entries first + stmt = select(self.entry_orm.name) + stmt = stmt.where( + and_( + self.entry_orm.dataset_id == dataset_id, + self.entry_orm.name.in_([x.name for x in entry_orm_batch]), + ) + ) + + existing_entries = session.execute(stmt).scalars().all() + + for entry in entry_orm_batch: + # Only add if the entry does not exist + if entry.name in existing_entries: + existing_idx.append(idx) + else: + session.add(entry) + inserted_idx.append(idx) + + idx += 1 return InsertMetadata(inserted_idx=inserted_idx, existing_idx=existing_idx)