diff --git a/pyproject.toml b/pyproject.toml index 59427f1..0583d5e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "malcolm-test" -version = "0.5.0" +version = "0.5.1" authors = [ { name="Seth Grover", email="mero.mero.guero@gmail.com" }, ] diff --git a/src/maltest/maltest.py b/src/maltest/maltest.py index 3720571..e0efd57 100755 --- a/src/maltest/maltest.py +++ b/src/maltest/maltest.py @@ -348,7 +348,7 @@ def main(): # get connection information about the VM and set it so the tests can access it as a fixture malcolmInfo = malcolmVm.Info() - logging.info(json.dumps(malcolmInfo)) + logging.info(json.dumps(malcolmInfo, default=mmguero.JsonObjSerializer)) set_malcolm_vm_info(malcolmInfo) # malcolm is started; wait for it to be ready to process data, then start testing diff --git a/src/maltest/tests/conftest.py b/src/maltest/tests/conftest.py index d66c6e0..7cb529a 100644 --- a/src/maltest/tests/conftest.py +++ b/src/maltest/tests/conftest.py @@ -6,6 +6,7 @@ get_pcap_hash_map, get_malcolm_http_auth, get_malcolm_url, + get_database_objs, ) @@ -24,6 +25,11 @@ def malcolm_http_auth(): yield get_malcolm_http_auth() +@pytest.fixture +def database_objs(): + yield get_database_objs() + + @pytest.fixture def malcolm_url(): yield get_malcolm_url() diff --git a/src/maltest/tests/test_malcolm_db_health.py b/src/maltest/tests/test_malcolm_db_health.py new file mode 100644 index 0000000..bd612a7 --- /dev/null +++ b/src/maltest/tests/test_malcolm_db_health.py @@ -0,0 +1,14 @@ +def test_malcolm_db_health( + malcolm_url, + database_objs, +): + dbObjs = database_objs + healthDict = dict( + dbObjs.DatabaseClass( + hosts=[ + f"{malcolm_url}/mapi/opensearch", + ], + **dbObjs.DatabaseInitArgs, + ).cluster.health() + ) + assert healthDict.get("status", "unknown") in ["green", "yellow"] diff --git a/src/maltest/utils.py b/src/maltest/utils.py index bf206b5..a223498 100644 --- a/src/maltest/utils.py +++ b/src/maltest/utils.py @@ -35,6 +35,14 @@ # the hash can be used as a query filter for tags. PcapHashMap = defaultdict(lambda: None) + +class DatabaseObjs: + def __init__(self): + self.DatabaseClass = None + self.SearchClass = None + self.DatabaseInitArgs = defaultdict(lambda: None) + + UPLOAD_ARTIFACT_LIST_NAME = 'UPLOAD_ARTIFACTS' MALCOLM_READY_TIMEOUT_SECONDS = 600 @@ -110,6 +118,14 @@ def get_malcolm_url(info=None): return 'http://localhost' +def get_database_objs(info=None): + global MalcolmVmInfo + if tmpInfo := info if info else MalcolmVmInfo: + return tmpInfo.get('database_objs', DatabaseObjs()) + else: + return DatabaseObjs() + + ################################################################################################### def parse_virter_log_line(log_line): pattern = r'(\w+)=(".*?"|\S+)' @@ -174,18 +190,7 @@ def __init__( self.debug = debug self.logger = logger self.apiSession = requests.Session() - ( - self.name, - self.DatabaseClass, - self.DatabaseInitArgs, - self.SearchClass, - ) = ( - None, - None, - None, - None, - ) - + self.dbObjs = None self.provisionErrorEncountered = False self.buildMode = False @@ -431,14 +436,14 @@ def ArkimeAlreadyHasFile( if not self.buildMode: url, auth = self.ConnectionParams() - if self.DatabaseClass: + if self.dbObjs: try: - s = self.SearchClass( - using=self.DatabaseClass( + s = self.dbObjs.SearchClass( + using=self.dbObjs.DatabaseClass( hosts=[ f"{url}/mapi/opensearch", ], - **self.DatabaseInitArgs, + **self.dbObjs.DatabaseInitArgs, ), index=ARKIME_FILES_INDEX, ).query("wildcard", name=f"*{os.path.basename(filename)}") @@ -525,13 +530,13 @@ def Info(self): try: # the first time we call Info for this object, set up our database classes, etc. - if self.DatabaseClass is None: + if self.dbObjs is None: - self.DatabaseInitArgs = {} - self.DatabaseInitArgs['request_timeout'] = 1 - self.DatabaseInitArgs['verify_certs'] = False - self.DatabaseInitArgs['ssl_assert_hostname'] = False - self.DatabaseInitArgs['ssl_show_warn'] = False + self.dbObjs = DatabaseObjs() + self.dbObjs.DatabaseInitArgs['request_timeout'] = 1 + self.dbObjs.DatabaseInitArgs['verify_certs'] = False + self.dbObjs.DatabaseInitArgs['ssl_assert_hostname'] = False + self.dbObjs.DatabaseInitArgs['ssl_show_warn'] = False if 'elastic' in mmguero.DeepGet(result, ['version', 'mode'], '').lower(): elasticImport = mmguero.DoDynamicImport( @@ -540,22 +545,24 @@ def Info(self): elasticDslImport = mmguero.DoDynamicImport( 'elasticsearch_dsl', 'elasticsearch-dsl', interactive=False, debug=self.debug ) - self.DatabaseClass = elasticImport.Elasticsearch - self.SearchClass = elasticDslImport.Search + self.dbObjs.DatabaseClass = elasticImport.Elasticsearch + self.dbObjs.SearchClass = elasticDslImport.Search if self.malcolmUsername: - self.DatabaseInitArgs['basic_auth'] = (self.malcolmUsername, self.malcolmPassword) + self.dbObjs.DatabaseInitArgs['basic_auth'] = (self.malcolmUsername, self.malcolmPassword) else: osImport = mmguero.DoDynamicImport( 'opensearchpy', 'opensearch-py', interactive=False, debug=self.debug ) - self.DatabaseClass = osImport.OpenSearch - self.SearchClass = osImport.Search + self.dbObjs.DatabaseClass = osImport.OpenSearch + self.dbObjs.SearchClass = osImport.Search if self.malcolmUsername: - self.DatabaseInitArgs['http_auth'] = (self.malcolmUsername, self.malcolmPassword) + self.dbObjs.DatabaseInitArgs['http_auth'] = (self.malcolmUsername, self.malcolmPassword) except Exception as e: self.logger.error(f"Error getting database objects: {e}") + result['database_objs'] = self.dbObjs + return result # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~