From 0c3ff2f94560037aa23d655f7a95aa42849119f1 Mon Sep 17 00:00:00 2001 From: Vineet Joshi Date: Thu, 12 Dec 2024 16:59:07 -0800 Subject: [PATCH] Evenly distributed the negative-labeled records to all ensemble models. The different models in the ensemble should not not receive different proportion of negative labeled records to ensure uniformity in their training and performance. Through this commit, we explicitly assign all the negative-labeled features with the batch of unlabeled records to each model in the ensemble. PiperOrigin-RevId: 705675828 --- spade_anomaly_detection/csv_data_loader.py | 13 ++- spade_anomaly_detection/occ_ensemble.py | 114 ++++++++++++++++++- spade_anomaly_detection/occ_ensemble_test.py | 76 +++++++++++++ spade_anomaly_detection/runner.py | 10 +- 4 files changed, 201 insertions(+), 12 deletions(-) diff --git a/spade_anomaly_detection/csv_data_loader.py b/spade_anomaly_detection/csv_data_loader.py index 96f038f..13d2fe7 100644 --- a/spade_anomaly_detection/csv_data_loader.py +++ b/spade_anomaly_detection/csv_data_loader.py @@ -528,15 +528,18 @@ def combine_features_dict_into_tensor( dataset = dataset.batch(batch_size, deterministic=True) dataset = dataset.prefetch(tf.data.AUTOTUNE) - # This Dataset was just created. Calculate the label distribution. - # Any string labels were already re-mapped to integers. So keys are always - # strings and values are always integers. - self._label_counts = self.counts_by_label(dataset) + # This Dataset was just created. Calculate the label distribution. Any + # string labels were already re-mapped to integers. So keys are always + # integers and values are EagerTensors. We need to extract the value within + # this Tensor for subsequent use. + self._label_counts = { + k: v.numpy() for k, v in self.counts_by_label(dataset).items() + } logging.info('Label counts: %s', self._label_counts) return dataset - def counts_by_label(self, dataset: tf.data.Dataset) -> Dict[int, int]: + def counts_by_label(self, dataset: tf.data.Dataset) -> Dict[int, tf.Tensor]: """Counts the number of samples in each label class in the dataset. When this function is called, the labels in the Dataset have already been diff --git a/spade_anomaly_detection/occ_ensemble.py b/spade_anomaly_detection/occ_ensemble.py index 15baa02..d01c46b 100644 --- a/spade_anomaly_detection/occ_ensemble.py +++ b/spade_anomaly_detection/occ_ensemble.py @@ -42,6 +42,8 @@ _RANDOM_SEED: Final[int] = 42 +_LABEL_TYPE: Final[str] = 'INT64' + # TODO(b/247116870): Create abstract class for templating out future OCC models. class GmmEnsemble: @@ -74,6 +76,12 @@ class GmmEnsemble: precision when raising this value, and an increase in recall when lowering it. Equavalent to saying the given data point needs to be X percentile or greater in order to be considered anomalous. + unlabeled_record_count: The number of unlabeled records in the dataset. + negative_record_count: The number of negative records in the dataset. + unlabeled_data_value: The value used in the label column to denote unlabeled + data. + negative_data_value: The value used in the label column to denote negative + data. verbose: Boolean denoting whether to send model performance and pseudo-labeling metrics to the GCP console. ensemble: A trained ensemble of one class classifiers. @@ -90,6 +98,10 @@ def __init__( positive_threshold: float = 1.0, negative_threshold: float = 95.0, random_seed: int = _RANDOM_SEED, + unlabeled_record_count: int | None = None, + negative_record_count: int | None = None, + unlabeled_data_value: int | None = None, + negative_data_value: int | None = None, verbose: bool = False, ) -> None: self.n_components = n_components @@ -100,6 +112,10 @@ def __init__( self.positive_threshold = positive_threshold self.negative_threshold = negative_threshold self._random_seed = random_seed + self.unlabeled_record_count = unlabeled_record_count + self.negative_record_count = negative_record_count + self.unlabeled_data_value = unlabeled_data_value + self.negative_data_value = negative_data_value self.verbose = verbose self.ensemble = [] @@ -121,6 +137,38 @@ def _get_model(self) -> mixture.GaussianMixture: random_state=self._random_seed, ) + def _get_filter_by_label_value_func(self, label_column_filter_value: int): + """Returns a function that filters a record based on the label column value. + + Args: + label_column_filter_value: The value of the label column to use as a + filter. If None, all records are included. + + Returns: + A function that returns True if the label column value is equal to the + label_column_filter_value parameter. + """ + + def filter_func(features: tf.Tensor, label: tf.Tensor) -> bool: # pylint: disable=unused-argument + if label_column_filter_value is None: + return True + label_cast = tf.cast(label, tf.dtypes.as_dtype(_LABEL_TYPE.lower())) + label_column_filter_value_cast = tf.cast( + label_column_filter_value, label_cast.dtype + ) + broadcast_equal = tf.equal(label_column_filter_value_cast, label_cast) + return tf.reduce_all(broadcast_equal) + + return filter_func + + def is_batched(self, dataset: tf.data.Dataset) -> bool: + """Returns True if the dataset is batched.""" + # This suffices for the current use case of the OCC ensemble. + return len(dataset.element_spec[0].shape) == 2 and ( + dataset.element_spec[0].shape[0] is None + or isinstance(dataset.element_spec[0].shape[0], int) + ) + def fit( self, train_x: tf.data.Dataset, batches_per_occ: int ) -> Sequence[mixture.GaussianMixture]: @@ -142,15 +190,73 @@ def fit( if batches_per_occ > 1: self._warm_start = True - dataset_iterator = train_x.as_numpy_iterator() + has_batches = self.is_batched(train_x) + logging.info('has_batches is %s', has_batches) + negative_features = None + + if ( + not self.unlabeled_record_count + or not self.negative_record_count + or not has_batches + or self.unlabeled_data_value is None + or self.negative_data_value is None + ): + # Either the dataset is not batched, or we don't have all the details to + # extract the negative-labeled data. Hence we will use all the data for + # training. + dataset_iterator = train_x.as_numpy_iterator() + else: + # We unbatch the dataset so that we can separate-out the unlabeled and + # negative data points + ds_unbatched = train_x.unbatch() + + ds_unlabeled = ds_unbatched.filter( + self._get_filter_by_label_value_func(self.unlabeled_data_value) + ) + + ds_negative = ds_unbatched.filter( + self._get_filter_by_label_value_func(self.negative_data_value) + ) + + negative_features_and_labels_zip = list( + zip(*ds_negative.as_numpy_iterator()) + ) + + negative_features = ( + negative_features_and_labels_zip[0] + if len(negative_features_and_labels_zip) == 2 + else None + ) + + if negative_features is None: + # The negative features were not extracted. This can happen when the + # dataset elements are not tuples of features and labels. So we will use + # all the data for training. + ds_batched = train_x + else: + # The negative features were extracted. How we can proceed with creating + # batches of unlabeled data, to which the negative data will be added + # before training. + batch_size = ( + self.unlabeled_record_count // self.ensemble_count + ) // batches_per_occ + ds_batched = ds_unlabeled.batch( + batch_size, + drop_remainder=False, + ) + dataset_iterator = ds_batched.as_numpy_iterator() for _ in range(self.ensemble_count): model = self._get_model() for _ in range(batches_per_occ): - features, labels = dataset_iterator.next() - del labels # Not needed for this task. - model.fit(features) + features, _ = dataset_iterator.next() + all_features = ( + np.concatenate([features, negative_features], axis=0) + if negative_features is not None + else features + ) + model.fit(all_features) self.ensemble.append(model) diff --git a/spade_anomaly_detection/occ_ensemble_test.py b/spade_anomaly_detection/occ_ensemble_test.py index 72d2c03..9f7ce9d 100644 --- a/spade_anomaly_detection/occ_ensemble_test.py +++ b/spade_anomaly_detection/occ_ensemble_test.py @@ -85,6 +85,82 @@ def test_ensemble_training_no_error( msg='Model count in ensemble not equal to specified ensemble size.', ) + @parameterized.named_parameters( + ('components_1_ensemble_10_full', 1, 10, 'full'), + ('components_1_ensemble_5_tied', 1, 5, 'tied'), + ) + def test_ensemble_training_unlabeled_negative_no_error( + self, n_components, ensemble_count, covariance_type + ): + batches_per_occ = 10 + negative_data_value = 0 + unlabeled_data_value = -1 + + ensemble_obj = occ_ensemble.GmmEnsemble( + n_components=n_components, + ensemble_count=ensemble_count, + covariance_type=covariance_type, + negative_data_value=negative_data_value, + unlabeled_data_value=unlabeled_data_value, + ) + + tf_dataset = data_loader.load_tf_dataset_from_csv( + dataset_name='covertype_pnu_100000', batch_size=None + ) + # These are the actual counts of unlabeled and negative records in the + # dataset. + unlabeled_record_count = 94950 + negative_record_count = 4333 + ensemble_obj.unlabeled_record_count = unlabeled_record_count + ensemble_obj.negative_record_count = negative_record_count + + features_len = tf_dataset.cardinality().numpy() + records_per_occ = features_len // ensemble_obj.ensemble_count + batch_size = records_per_occ // batches_per_occ + + tf_dataset = tf_dataset.shuffle(batch_size).batch( + batch_size, drop_remainder=True + ) + + ensemble_models = ensemble_obj.fit(tf_dataset, batches_per_occ) + + self.assertLen( + ensemble_models, + ensemble_obj.ensemble_count, + msg='Model count in ensemble not equal to specified ensemble size.', + ) + + def test_dataset_filtering(self): + positive_data_value = 1 + negative_data_value = 0 + unlabeled_data_value = -1 + gmm_ensemble = occ_ensemble.GmmEnsemble(n_components=1, ensemble_count=10) + + tf_dataset = data_loader.load_tf_dataset_from_csv( + dataset_name='covertype_pnu_100000', batch_size=None + ) + tf_unlabeled_dataset = tf_dataset.filter( + gmm_ensemble._get_filter_by_label_value_func(unlabeled_data_value) + ) + tf_negative_dataset = tf_dataset.filter( + gmm_ensemble._get_filter_by_label_value_func(negative_data_value) + ) + tf_positive_dataset = tf_dataset.filter( + gmm_ensemble._get_filter_by_label_value_func(positive_data_value) + ) + self.assertEqual( + tf_unlabeled_dataset.reduce(0, lambda x, _: x + 1).numpy(), + 94950, + ) + self.assertEqual( + tf_negative_dataset.reduce(0, lambda x, _: x + 1).numpy(), + 4333, + ) + self.assertEqual( + tf_positive_dataset.reduce(0, lambda x, _: x + 1).numpy(), + 715, + ) + @parameterized.named_parameters( ('labels_are_integers', False), ('labels_are_strings', True), diff --git a/spade_anomaly_detection/runner.py b/spade_anomaly_detection/runner.py index e70a4f9..798cccb 100644 --- a/spade_anomaly_detection/runner.py +++ b/spade_anomaly_detection/runner.py @@ -313,6 +313,8 @@ def instantiate_and_fit_ensemble( negative_threshold=self.runner_parameters.negative_threshold, random_seed=self.runner_parameters.random_seed, verbose=self.runner_parameters.verbose, + unlabeled_data_value=self.int_unlabeled_data_value, + negative_data_value=self.int_negative_data_value, ) training_record_count = unlabeled_record_count + negative_record_count @@ -327,7 +329,7 @@ def instantiate_and_fit_ensemble( self.input_data_loader = cast( data_loader.DataLoader, self.input_data_loader ) - unlabeled_data = self.input_data_loader.load_tf_dataset_from_bigquery( + training_data = self.input_data_loader.load_tf_dataset_from_bigquery( input_path=self.runner_parameters.input_bigquery_table_path, label_col_name=self.runner_parameters.label_col_name, where_statements=self.runner_parameters.where_statements, @@ -346,7 +348,7 @@ def instantiate_and_fit_ensemble( self.input_data_loader = cast( csv_data_loader.CsvDataLoader, self.input_data_loader ) - unlabeled_data = self.input_data_loader.load_tf_dataset_from_csv( + training_data = self.input_data_loader.load_tf_dataset_from_csv( input_path=self.runner_parameters.data_input_gcs_uri, label_col_name=self.runner_parameters.label_col_name, batch_size=batch_size, @@ -358,10 +360,12 @@ def instantiate_and_fit_ensemble( self.int_negative_data_value, ], ) + ensemble_object.unlabeled_record_count = unlabeled_record_count + ensemble_object.negative_record_count = negative_record_count logging.info('Fitting ensemble.') ensemble_object.fit( - train_x=unlabeled_data, + train_x=training_data, batches_per_occ=self.runner_parameters.batches_per_model, ) logging.info('Ensemble fit complete.')