Evenly distributed the negative-labeled records to all ensemble models.
The different models in the ensemble should not receive different proportions of
negative-labeled records; this ensures uniformity in their training and
performance. Through this commit, we explicitly assign all the negative-labeled
features, along with the batch of unlabeled records, to each model in the ensemble.

PiperOrigin-RevId: 705983061
Vineet Joshi authored and The spade_anomaly_detection Authors committed Dec 13, 2024
1 parent 2ec1a5a commit 462e53e
Showing 4 changed files with 201 additions and 12 deletions.
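
To make the change concrete, here is a minimal, hypothetical sketch (toy NumPy data, not the library's API) of the scheme the commit message describes: each ensemble member trains on its own slice of the unlabeled records plus the full set of negative-labeled records.

import numpy as np

rng = np.random.default_rng(42)
unlabeled_features = rng.normal(size=(90, 4))  # assumed toy data
negative_features = rng.normal(size=(9, 4))    # assumed toy data
ensemble_count = 3
batches_per_occ = 1

# Mirrors the batch-size arithmetic added to occ_ensemble.py below.
batch_size = (len(unlabeled_features) // ensemble_count) // batches_per_occ
for i in range(ensemble_count):
    unlabeled_batch = unlabeled_features[i * batch_size:(i + 1) * batch_size]
    # Every model receives the same, full set of negative records.
    train_x = np.concatenate([unlabeled_batch, negative_features], axis=0)
    print(f'model {i}: {train_x.shape[0]} rows '
          f'({unlabeled_batch.shape[0]} unlabeled + {negative_features.shape[0]} negative)')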
13 changes: 8 additions & 5 deletions spade_anomaly_detection/csv_data_loader.py
@@ -528,15 +528,18 @@ def combine_features_dict_into_tensor(
dataset = dataset.batch(batch_size, deterministic=True)
dataset = dataset.prefetch(tf.data.AUTOTUNE)

# This Dataset was just created. Calculate the label distribution.
# Any string labels were already re-mapped to integers. So keys are always
# strings and values are always integers.
self._label_counts = self.counts_by_label(dataset)
# This Dataset was just created. Calculate the label distribution. Any
# string labels were already re-mapped to integers. So keys are always
# integers and values are EagerTensors. We need to extract the value within
# this Tensor for subsequent use.
self._label_counts = {
k: v.numpy() for k, v in self.counts_by_label(dataset).items()
}
logging.info('Label counts: %s', self._label_counts)

return dataset

def counts_by_label(self, dataset: tf.data.Dataset) -> Dict[int, int]:
def counts_by_label(self, dataset: tf.data.Dataset) -> Dict[int, tf.Tensor]:
"""Counts the number of samples in each label class in the dataset.
When this function is called, the labels in the Dataset have already been
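
A minimal sketch (toy data; the real counts_by_label implementation is not shown in this diff) of why the .numpy() extraction above is needed: per-label counts accumulated with TensorFlow ops come back as EagerTensors, so the values are converted to plain integers before logging.

import tensorflow as tf

dataset = tf.data.Dataset.from_tensor_slices(
    (tf.random.normal([5, 3]), tf.constant([-1, -1, 0, -1, 1], dtype=tf.int64))
)

counts = {}
for _, label in dataset:
    key = int(label.numpy())
    # Accumulating with a tf.constant leaves EagerTensor values in the dict.
    counts[key] = counts.get(key, tf.constant(0, dtype=tf.int64)) + 1

label_counts = {k: v.numpy() for k, v in counts.items()}
print(label_counts)  # e.g. {-1: 3, 0: 1, 1: 1}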
114 changes: 110 additions & 4 deletions spade_anomaly_detection/occ_ensemble.py
@@ -42,6 +42,8 @@

_RANDOM_SEED: Final[int] = 42

_LABEL_TYPE: Final[str] = 'INT64'


# TODO(b/247116870): Create abstract class for templating out future OCC models.
class GmmEnsemble:
@@ -74,6 +76,12 @@ class GmmEnsemble:
precision when raising this value, and an increase in recall when lowering
it. Equivalent to saying the given data point needs to be X percentile or
greater in order to be considered anomalous.
unlabeled_record_count: The number of unlabeled records in the dataset.
negative_record_count: The number of negative records in the dataset.
unlabeled_data_value: The value used in the label column to denote unlabeled
data.
negative_data_value: The value used in the label column to denote negative
data.
verbose: Boolean denoting whether to send model performance and
pseudo-labeling metrics to the GCP console.
ensemble: A trained ensemble of one class classifiers.
@@ -90,6 +98,10 @@ def __init__(
positive_threshold: float = 1.0,
negative_threshold: float = 95.0,
random_seed: int = _RANDOM_SEED,
unlabeled_record_count: int | None = None,
negative_record_count: int | None = None,
unlabeled_data_value: int | None = None,
negative_data_value: int | None = None,
verbose: bool = False,
) -> None:
self.n_components = n_components
@@ -100,6 +112,10 @@ def __init__(
self.positive_threshold = positive_threshold
self.negative_threshold = negative_threshold
self._random_seed = random_seed
self.unlabeled_record_count = unlabeled_record_count
self.negative_record_count = negative_record_count
self.unlabeled_data_value = unlabeled_data_value
self.negative_data_value = negative_data_value
self.verbose = verbose

self.ensemble = []
@@ -121,6 +137,38 @@ def _get_model(self) -> mixture.GaussianMixture:
random_state=self._random_seed,
)

def _get_filter_by_label_value_func(self, label_column_filter_value: int):
"""Returns a function that filters a record based on the label column value.
Args:
label_column_filter_value: The value of the label column to use as a
filter. If None, all records are included.
Returns:
A function that returns True if the label column value is equal to the
label_column_filter_value parameter.
"""

def filter_func(features: tf.Tensor, label: tf.Tensor) -> bool: # pylint: disable=unused-argument
if label_column_filter_value is None:
return True
label_cast = tf.cast(label, tf.dtypes.as_dtype(_LABEL_TYPE.lower()))
label_column_filter_value_cast = tf.cast(
label_column_filter_value, label_cast.dtype
)
broadcast_equal = tf.equal(label_column_filter_value_cast, label_cast)
return tf.reduce_all(broadcast_equal)

return filter_func

def is_batched(self, dataset: tf.data.Dataset) -> bool:
"""Returns True if the dataset is batched."""
# This suffices for the current use case of the OCC ensemble.
return len(dataset.element_spec[0].shape) == 2 and (
dataset.element_spec[0].shape[0] is None
or isinstance(dataset.element_spec[0].shape[0], int)
)

def fit(
self, train_x: tf.data.Dataset, batches_per_occ: int
) -> Sequence[mixture.GaussianMixture]:
@@ -142,15 +190,73 @@ def fit(
if batches_per_occ > 1:
self._warm_start = True

dataset_iterator = train_x.as_numpy_iterator()
has_batches = self.is_batched(train_x)
logging.info('has_batches is %s', has_batches)
negative_features = None

if (
not self.unlabeled_record_count
or not self.negative_record_count
or not has_batches
or self.unlabeled_data_value is None
or self.negative_data_value is None
):
# Either the dataset is not batched, or we don't have all the details to
# extract the negative-labeled data. Hence we will use all the data for
# training.
dataset_iterator = train_x.as_numpy_iterator()
else:
# We unbatch the dataset so that we can separate out the unlabeled and
# negative data points.
ds_unbatched = train_x.unbatch()

ds_unlabeled = ds_unbatched.filter(
self._get_filter_by_label_value_func(self.unlabeled_data_value)
)

ds_negative = ds_unbatched.filter(
self._get_filter_by_label_value_func(self.negative_data_value)
)

negative_features_and_labels_zip = list(
zip(*ds_negative.as_numpy_iterator())
)

negative_features = (
negative_features_and_labels_zip[0]
if len(negative_features_and_labels_zip) == 2
else None
)

if negative_features is None:
# The negative features were not extracted. This can happen when the
# dataset elements are not tuples of features and labels. So we will use
# all the data for training.
ds_batched = train_x
else:
# The negative features were extracted. Now we can proceed with creating
# batches of unlabeled data, to which the negative data will be added
# before training.
batch_size = (
self.unlabeled_record_count // self.ensemble_count
) // batches_per_occ
ds_batched = ds_unlabeled.batch(
batch_size,
drop_remainder=False,
)
dataset_iterator = ds_batched.as_numpy_iterator()

for _ in range(self.ensemble_count):
model = self._get_model()

for _ in range(batches_per_occ):
features, labels = dataset_iterator.next()
del labels # Not needed for this task.
model.fit(features)
features, _ = dataset_iterator.next()
all_features = (
np.concatenate([features, negative_features], axis=0)
if negative_features is not None
else features
)
model.fit(all_features)

self.ensemble.append(model)

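
A minimal usage sketch of the new fit() path, assuming a batched tf.data.Dataset of (features, label) pairs. The constructor arguments mirror the diff above, but the data and record counts are made up.

import tensorflow as tf
from spade_anomaly_detection import occ_ensemble

features = tf.random.normal([100, 4])
labels = tf.constant([-1] * 90 + [0] * 10, dtype=tf.int64)  # -1 = unlabeled, 0 = negative
dataset = tf.data.Dataset.from_tensor_slices((features, labels)).batch(20)

ensemble = occ_ensemble.GmmEnsemble(
    n_components=1,
    ensemble_count=5,
    unlabeled_data_value=-1,
    negative_data_value=0,
    unlabeled_record_count=90,
    negative_record_count=10,
)
# Each of the 5 models is fit on (90 // 5) unlabeled rows plus all 10 negative rows.
models = ensemble.fit(dataset, batches_per_occ=1)
assert len(models) == 5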
76 changes: 76 additions & 0 deletions spade_anomaly_detection/occ_ensemble_test.py
@@ -85,6 +85,82 @@ def test_ensemble_training_no_error(
msg='Model count in ensemble not equal to specified ensemble size.',
)

@parameterized.named_parameters(
('components_1_ensemble_10_full', 1, 10, 'full'),
('components_1_ensemble_5_tied', 1, 5, 'tied'),
)
def test_ensemble_training_unlabeled_negative_no_error(
self, n_components, ensemble_count, covariance_type
):
batches_per_occ = 10
negative_data_value = 0
unlabeled_data_value = -1

ensemble_obj = occ_ensemble.GmmEnsemble(
n_components=n_components,
ensemble_count=ensemble_count,
covariance_type=covariance_type,
negative_data_value=negative_data_value,
unlabeled_data_value=unlabeled_data_value,
)

tf_dataset = data_loader.load_tf_dataset_from_csv(
dataset_name='covertype_pnu_100000', batch_size=None
)
# These are the actual counts of unlabeled and negative records in the
# dataset.
unlabeled_record_count = 94950
negative_record_count = 4333
ensemble_obj.unlabeled_record_count = unlabeled_record_count
ensemble_obj.negative_record_count = negative_record_count

features_len = tf_dataset.cardinality().numpy()
records_per_occ = features_len // ensemble_obj.ensemble_count
batch_size = records_per_occ // batches_per_occ

tf_dataset = tf_dataset.shuffle(batch_size).batch(
batch_size, drop_remainder=True
)

ensemble_models = ensemble_obj.fit(tf_dataset, batches_per_occ)

self.assertLen(
ensemble_models,
ensemble_obj.ensemble_count,
msg='Model count in ensemble not equal to specified ensemble size.',
)

def test_dataset_filtering(self):
positive_data_value = 1
negative_data_value = 0
unlabeled_data_value = -1
gmm_ensemble = occ_ensemble.GmmEnsemble(n_components=1, ensemble_count=10)

tf_dataset = data_loader.load_tf_dataset_from_csv(
dataset_name='covertype_pnu_100000', batch_size=None
)
tf_unlabeled_dataset = tf_dataset.filter(
gmm_ensemble._get_filter_by_label_value_func(unlabeled_data_value)
)
tf_negative_dataset = tf_dataset.filter(
gmm_ensemble._get_filter_by_label_value_func(negative_data_value)
)
tf_positive_dataset = tf_dataset.filter(
gmm_ensemble._get_filter_by_label_value_func(positive_data_value)
)
self.assertEqual(
tf_unlabeled_dataset.reduce(0, lambda x, _: x + 1).numpy(),
94950,
)
self.assertEqual(
tf_negative_dataset.reduce(0, lambda x, _: x + 1).numpy(),
4333,
)
self.assertEqual(
tf_positive_dataset.reduce(0, lambda x, _: x + 1).numpy(),
715,
)

@parameterized.named_parameters(
('labels_are_integers', False),
('labels_are_strings', True),
10 changes: 7 additions & 3 deletions spade_anomaly_detection/runner.py
@@ -313,6 +313,8 @@ def instantiate_and_fit_ensemble(
negative_threshold=self.runner_parameters.negative_threshold,
random_seed=self.runner_parameters.random_seed,
verbose=self.runner_parameters.verbose,
unlabeled_data_value=self.int_unlabeled_data_value,
negative_data_value=self.int_negative_data_value,
)

training_record_count = unlabeled_record_count + negative_record_count
@@ -327,7 +329,7 @@ def instantiate_and_fit_ensemble(
self.input_data_loader = cast(
data_loader.DataLoader, self.input_data_loader
)
unlabeled_data = self.input_data_loader.load_tf_dataset_from_bigquery(
training_data = self.input_data_loader.load_tf_dataset_from_bigquery(
input_path=self.runner_parameters.input_bigquery_table_path,
label_col_name=self.runner_parameters.label_col_name,
where_statements=self.runner_parameters.where_statements,
@@ -346,7 +348,7 @@ def instantiate_and_fit_ensemble(
self.input_data_loader = cast(
csv_data_loader.CsvDataLoader, self.input_data_loader
)
unlabeled_data = self.input_data_loader.load_tf_dataset_from_csv(
training_data = self.input_data_loader.load_tf_dataset_from_csv(
input_path=self.runner_parameters.data_input_gcs_uri,
label_col_name=self.runner_parameters.label_col_name,
batch_size=batch_size,
@@ -358,10 +360,12 @@
self.int_negative_data_value,
],
)
ensemble_object.unlabeled_record_count = unlabeled_record_count
ensemble_object.negative_record_count = negative_record_count

logging.info('Fitting ensemble.')
ensemble_object.fit(
train_x=unlabeled_data,
train_x=training_data,
batches_per_occ=self.runner_parameters.batches_per_model,
)
logging.info('Ensemble fit complete.')
