diff --git a/spade_anomaly_detection/csv_data_loader.py b/spade_anomaly_detection/csv_data_loader.py
index 96f038f..13d2fe7 100644
--- a/spade_anomaly_detection/csv_data_loader.py
+++ b/spade_anomaly_detection/csv_data_loader.py
@@ -528,15 +528,18 @@ def combine_features_dict_into_tensor(
     dataset = dataset.batch(batch_size, deterministic=True)
     dataset = dataset.prefetch(tf.data.AUTOTUNE)
 
-    # This Dataset was just created. Calculate the label distribution.
-    # Any string labels were already re-mapped to integers. So keys are always
-    # strings and values are always integers.
-    self._label_counts = self.counts_by_label(dataset)
+    # This Dataset was just created. Calculate the label distribution. Any
+    # string labels were already re-mapped to integers. So keys are always
+    # integers and values are EagerTensors. We need to extract the value within
+    # this Tensor for subsequent use.
+    self._label_counts = {
+        k: v.numpy() for k, v in self.counts_by_label(dataset).items()
+    }
     logging.info('Label counts: %s', self._label_counts)
 
     return dataset
 
-  def counts_by_label(self, dataset: tf.data.Dataset) -> Dict[int, int]:
+  def counts_by_label(self, dataset: tf.data.Dataset) -> Dict[int, tf.Tensor]:
     """Counts the number of samples in each label class in the dataset.
 
     When this function is called, the labels in the Dataset have already been
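Why the .numpy() extraction above is needed: when label counts are accumulated
with TensorFlow ops, the dictionary values come back as scalar EagerTensors
rather than plain Python integers. A minimal self-contained sketch of that
behavior (the counting loop below is illustrative only, not the loader's
actual counts_by_label implementation):

import tensorflow as tf

labels = tf.data.Dataset.from_tensor_slices([0, -1, -1, 1, -1])

def counts_by_label(ds):
  # Accumulate counts with tf.add so the values stay EagerTensors,
  # mirroring the Dict[int, tf.Tensor] return type in the diff.
  counts = {}
  for label in ds:
    key = int(label.numpy())
    counts[key] = tf.add(counts.get(key, tf.constant(0, tf.int64)), 1)
  return counts

raw = counts_by_label(labels)                       # values: EagerTensors
extracted = {k: v.numpy() for k, v in raw.items()}  # values: np.int64
print(extracted)  # {0: 1, -1: 3, 1: 1}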
diff --git a/spade_anomaly_detection/occ_ensemble.py b/spade_anomaly_detection/occ_ensemble.py
index 15baa02..d01c46b 100644
--- a/spade_anomaly_detection/occ_ensemble.py
+++ b/spade_anomaly_detection/occ_ensemble.py
@@ -42,6 +42,8 @@
 
 _RANDOM_SEED: Final[int] = 42
 
+_LABEL_TYPE: Final[str] = 'INT64'
+
 
 # TODO(b/247116870): Create abstract class for templating out future OCC models.
 class GmmEnsemble:
@@ -74,6 +76,12 @@ class GmmEnsemble:
       precision when raising this value, and an increase in recall when
       lowering it. Equavalent to saying the given data point needs to be X
       percentile or greater in order to be considered anomalous.
+    unlabeled_record_count: The number of unlabeled records in the dataset.
+    negative_record_count: The number of negative records in the dataset.
+    unlabeled_data_value: The value used in the label column to denote unlabeled
+      data.
+    negative_data_value: The value used in the label column to denote negative
+      data.
     verbose: Boolean denoting whether to send model performance and
       pseudo-labeling metrics to the GCP console.
     ensemble: A trained ensemble of one class classifiers.
@@ -90,6 +98,10 @@ def __init__(
       positive_threshold: float = 1.0,
       negative_threshold: float = 95.0,
       random_seed: int = _RANDOM_SEED,
+      unlabeled_record_count: int | None = None,
+      negative_record_count: int | None = None,
+      unlabeled_data_value: int | None = None,
+      negative_data_value: int | None = None,
       verbose: bool = False,
   ) -> None:
     self.n_components = n_components
@@ -100,6 +112,10 @@
     self.positive_threshold = positive_threshold
     self.negative_threshold = negative_threshold
     self._random_seed = random_seed
+    self.unlabeled_record_count = unlabeled_record_count
+    self.negative_record_count = negative_record_count
+    self.unlabeled_data_value = unlabeled_data_value
+    self.negative_data_value = negative_data_value
     self.verbose = verbose
 
     self.ensemble = []
@@ -121,6 +137,38 @@
         random_state=self._random_seed,
     )
 
+  def _get_filter_by_label_value_func(self, label_column_filter_value: int):
+    """Returns a function that filters a record based on the label column value.
+
+    Args:
+      label_column_filter_value: The value of the label column to use as a
+        filter. If None, all records are included.
+
+    Returns:
+      A function that returns True if the label column value is equal to the
+      label_column_filter_value parameter.
+    """
+
+    def filter_func(features: tf.Tensor, label: tf.Tensor) -> bool:  # pylint: disable=unused-argument
+      if label_column_filter_value is None:
+        return True
+      label_cast = tf.cast(label, tf.dtypes.as_dtype(_LABEL_TYPE.lower()))
+      label_column_filter_value_cast = tf.cast(
+          label_column_filter_value, label_cast.dtype
+      )
+      broadcast_equal = tf.equal(label_column_filter_value_cast, label_cast)
+      return tf.reduce_all(broadcast_equal)
+
+    return filter_func
+
+  def is_batched(self, dataset: tf.data.Dataset) -> bool:
+    """Returns True if the dataset is batched."""
+    # This suffices for the current use case of the OCC ensemble.
+    return len(dataset.element_spec[0].shape) == 2 and (
+        dataset.element_spec[0].shape[0] is None
+        or isinstance(dataset.element_spec[0].shape[0], int)
+    )
+
   def fit(
       self, train_x: tf.data.Dataset, batches_per_occ: int
   ) -> Sequence[mixture.GaussianMixture]:
@@ -142,15 +190,73 @@ def fit(
     if batches_per_occ > 1:
       self._warm_start = True
 
-    dataset_iterator = train_x.as_numpy_iterator()
+    has_batches = self.is_batched(train_x)
+    logging.info('has_batches is %s', has_batches)
+    negative_features = None
+
+    if (
+        not self.unlabeled_record_count
+        or not self.negative_record_count
+        or not has_batches
+        or self.unlabeled_data_value is None
+        or self.negative_data_value is None
+    ):
+      # Either the dataset is not batched, or we don't have all the details to
+      # extract the negative-labeled data. Hence we will use all the data for
+      # training.
+      dataset_iterator = train_x.as_numpy_iterator()
+    else:
+      # We unbatch the dataset so that we can separate out the unlabeled and
+      # negative data points.
+      ds_unbatched = train_x.unbatch()
+
+      ds_unlabeled = ds_unbatched.filter(
+          self._get_filter_by_label_value_func(self.unlabeled_data_value)
+      )
+
+      ds_negative = ds_unbatched.filter(
+          self._get_filter_by_label_value_func(self.negative_data_value)
+      )
+
+      negative_features_and_labels_zip = list(
+          zip(*ds_negative.as_numpy_iterator())
+      )
+
+      negative_features = (
+          negative_features_and_labels_zip[0]
+          if len(negative_features_and_labels_zip) == 2
+          else None
+      )
+
+      if negative_features is None:
+        # The negative features were not extracted. This can happen when the
+        # dataset elements are not tuples of features and labels. So we will use
+        # all the data for training.
+        ds_batched = train_x
+      else:
+        # The negative features were extracted. Now we can proceed with creating
+        # batches of unlabeled data, to which the negative data will be added
+        # before training.
+        batch_size = (
+            self.unlabeled_record_count // self.ensemble_count
+        ) // batches_per_occ
+        ds_batched = ds_unlabeled.batch(
+            batch_size,
+            drop_remainder=False,
+        )
+      dataset_iterator = ds_batched.as_numpy_iterator()
 
     for _ in range(self.ensemble_count):
       model = self._get_model()
 
       for _ in range(batches_per_occ):
-        features, labels = dataset_iterator.next()
-        del labels  # Not needed for this task.
-        model.fit(features)
+        features, _ = dataset_iterator.next()
+        all_features = (
+            np.concatenate([features, negative_features], axis=0)
+            if negative_features is not None
+            else features
+        )
+        model.fit(all_features)
 
       self.ensemble.append(model)
 
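A toy walkthrough of the new unbatch -> filter -> rebatch path in fit()
(synthetic features and labels; filter_by_label is a simplified stand-in for
_get_filter_by_label_value_func):

import numpy as np
import tensorflow as tf

features = np.random.rand(8, 3).astype(np.float32)
labels = np.array([-1, -1, 0, -1, 0, -1, -1, -1], dtype=np.int64)
ds = tf.data.Dataset.from_tensor_slices((features, labels)).batch(4)

def filter_by_label(value):
  def _filter(x, y):
    del x  # Features play no part in the label filter.
    return tf.reduce_all(tf.equal(tf.cast(value, y.dtype), y))
  return _filter

unbatched = ds.unbatch()
unlabeled = unbatched.filter(filter_by_label(-1))  # 6 records
negative = unbatched.filter(filter_by_label(0))    # 2 records

# As in fit(): every unlabeled batch is topped up with all negative records
# before being handed to a GaussianMixture.
neg_features = np.stack([x.numpy() for x, _ in negative])
for batch_x, _ in unlabeled.batch(3):
  train_x = np.concatenate([batch_x.numpy(), neg_features], axis=0)
  # model.fit(train_x) would run here; train_x has shape (3 + 2, 3).

Folding the scarce negatives into every batch keeps negative signal present
for each GMM in the ensemble, which is what the new record-count and
data-value plumbing in the constructor enables.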
diff --git a/spade_anomaly_detection/occ_ensemble_test.py b/spade_anomaly_detection/occ_ensemble_test.py
index 72d2c03..9f7ce9d 100644
--- a/spade_anomaly_detection/occ_ensemble_test.py
+++ b/spade_anomaly_detection/occ_ensemble_test.py
@@ -85,6 +85,82 @@ def test_ensemble_training_no_error(
         msg='Model count in ensemble not equal to specified ensemble size.',
     )
 
+  @parameterized.named_parameters(
+      ('components_1_ensemble_10_full', 1, 10, 'full'),
+      ('components_1_ensemble_5_tied', 1, 5, 'tied'),
+  )
+  def test_ensemble_training_unlabeled_negative_no_error(
+      self, n_components, ensemble_count, covariance_type
+  ):
+    batches_per_occ = 10
+    negative_data_value = 0
+    unlabeled_data_value = -1
+
+    ensemble_obj = occ_ensemble.GmmEnsemble(
+        n_components=n_components,
+        ensemble_count=ensemble_count,
+        covariance_type=covariance_type,
+        negative_data_value=negative_data_value,
+        unlabeled_data_value=unlabeled_data_value,
+    )
+
+    tf_dataset = data_loader.load_tf_dataset_from_csv(
+        dataset_name='covertype_pnu_100000', batch_size=None
+    )
+    # These are the actual counts of unlabeled and negative records in the
+    # dataset.
+    unlabeled_record_count = 94950
+    negative_record_count = 4333
+    ensemble_obj.unlabeled_record_count = unlabeled_record_count
+    ensemble_obj.negative_record_count = negative_record_count
+
+    features_len = tf_dataset.cardinality().numpy()
+    records_per_occ = features_len // ensemble_obj.ensemble_count
+    batch_size = records_per_occ // batches_per_occ
+
+    tf_dataset = tf_dataset.shuffle(batch_size).batch(
+        batch_size, drop_remainder=True
+    )
+
+    ensemble_models = ensemble_obj.fit(tf_dataset, batches_per_occ)
+
+    self.assertLen(
+        ensemble_models,
+        ensemble_obj.ensemble_count,
+        msg='Model count in ensemble not equal to specified ensemble size.',
+    )
+
+  def test_dataset_filtering(self):
+    positive_data_value = 1
+    negative_data_value = 0
+    unlabeled_data_value = -1
+    gmm_ensemble = occ_ensemble.GmmEnsemble(n_components=1, ensemble_count=10)
+
+    tf_dataset = data_loader.load_tf_dataset_from_csv(
+        dataset_name='covertype_pnu_100000', batch_size=None
+    )
+    tf_unlabeled_dataset = tf_dataset.filter(
+        gmm_ensemble._get_filter_by_label_value_func(unlabeled_data_value)
+    )
+    tf_negative_dataset = tf_dataset.filter(
+        gmm_ensemble._get_filter_by_label_value_func(negative_data_value)
+    )
+    tf_positive_dataset = tf_dataset.filter(
+        gmm_ensemble._get_filter_by_label_value_func(positive_data_value)
+    )
+    self.assertEqual(
+        tf_unlabeled_dataset.reduce(0, lambda x, _: x + 1).numpy(),
+        94950,
+    )
+    self.assertEqual(
+        tf_negative_dataset.reduce(0, lambda x, _: x + 1).numpy(),
+        4333,
+    )
+    self.assertEqual(
+        tf_positive_dataset.reduce(0, lambda x, _: x + 1).numpy(),
+        715,
+    )
+
   @parameterized.named_parameters(
       ('labels_are_integers', False),
       ('labels_are_strings', True),
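The reduce(0, lambda x, _: x + 1) idiom in test_dataset_filtering is the
dependable way to count a filtered dataset: filter() leaves cardinality()
unknown, so an explicit pass over the elements is required. A small sketch:

import tensorflow as tf

ds = tf.data.Dataset.range(10).filter(lambda x: x % 2 == 0)

# cardinality() cannot see through a filter predicate...
assert ds.cardinality() == tf.data.UNKNOWN_CARDINALITY

# ...but a full pass with reduce() counts the survivors exactly.
count = ds.reduce(0, lambda acc, _: acc + 1).numpy()
print(count)  # 5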
diff --git a/spade_anomaly_detection/runner.py b/spade_anomaly_detection/runner.py
index e70a4f9..798cccb 100644
--- a/spade_anomaly_detection/runner.py
+++ b/spade_anomaly_detection/runner.py
@@ -313,6 +313,8 @@ def instantiate_and_fit_ensemble(
         negative_threshold=self.runner_parameters.negative_threshold,
         random_seed=self.runner_parameters.random_seed,
         verbose=self.runner_parameters.verbose,
+        unlabeled_data_value=self.int_unlabeled_data_value,
+        negative_data_value=self.int_negative_data_value,
     )
 
     training_record_count = unlabeled_record_count + negative_record_count
@@ -327,7 +329,7 @@
       self.input_data_loader = cast(
           data_loader.DataLoader, self.input_data_loader
       )
-      unlabeled_data = self.input_data_loader.load_tf_dataset_from_bigquery(
+      training_data = self.input_data_loader.load_tf_dataset_from_bigquery(
          input_path=self.runner_parameters.input_bigquery_table_path,
          label_col_name=self.runner_parameters.label_col_name,
          where_statements=self.runner_parameters.where_statements,
@@ -346,7 +348,7 @@
       self.input_data_loader = cast(
           csv_data_loader.CsvDataLoader, self.input_data_loader
       )
-      unlabeled_data = self.input_data_loader.load_tf_dataset_from_csv(
+      training_data = self.input_data_loader.load_tf_dataset_from_csv(
          input_path=self.runner_parameters.data_input_gcs_uri,
          label_col_name=self.runner_parameters.label_col_name,
          batch_size=batch_size,
@@ -358,10 +360,12 @@
             self.int_negative_data_value,
          ],
      )
+    ensemble_object.unlabeled_record_count = unlabeled_record_count
+    ensemble_object.negative_record_count = negative_record_count
 
     logging.info('Fitting ensemble.')
     ensemble_object.fit(
-        train_x=unlabeled_data,
+        train_x=training_data,
         batches_per_occ=self.runner_parameters.batches_per_model,
     )
     logging.info('Ensemble fit complete.')
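Worked numbers for the wiring above, using the covertype_pnu_100000 counts
from occ_ensemble_test.py: the runner attaches unlabeled_record_count and
negative_record_count to the ensemble so that fit() can size each model's
unlabeled batches and then append every negative record to each batch:

unlabeled_record_count = 94950
negative_record_count = 4333
ensemble_count = 10      # models in the GmmEnsemble
batches_per_occ = 10     # batches_per_model in the runner parameters

# Mirrors the batch sizing in fit():
batch_size = (unlabeled_record_count // ensemble_count) // batches_per_occ
print(batch_size)                          # 949 unlabeled records per batch
print(batch_size + negative_record_count)  # 5282 records per model.fit() call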