Skip to content

Commit

Permalink
Internal update.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 702899732
  • Loading branch information
raj-sinha authored and The spade_anomaly_detection Authors committed Dec 6, 2024
1 parent e3ec5e3 commit c7307a0
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions spade_anomaly_detection/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
"""

import enum
# TODO(b/247116870): Change to collections when Vertex supports python 3.9
from typing import Mapping, Optional, Tuple, cast

from absl import logging
Expand All @@ -49,6 +48,8 @@
from spade_anomaly_detection import supervised_model
import tensorflow as tf

# TODO(b/247116870): Change to collections when Vertex supports python 3.9


@enum.unique
class DataFormat(enum.Enum):
Expand Down Expand Up @@ -135,6 +136,7 @@ def __init__(self, runner_parameters: parameters.RunnerParameters):
else:
self.supervised_model_object = None

# If the thresholds are not set, use the thresholds from the input table.
if (
self.runner_parameters.positive_threshold is None
or self.runner_parameters.negative_threshold is None
Expand Down Expand Up @@ -760,7 +762,7 @@ def run(self) -> None:
batch_size=1,
)
train_label_counts = self.input_data_loader.label_counts
# TODO(sinharaj): This is not ideal, we should not need to read the files
# This is not ideal, we should not need to read the files
# again. Find a way to get the label counts without reading the files.
# Assumes that data loader has already been used to read the input table.
total_record_count = sum(train_label_counts.values())
Expand Down Expand Up @@ -885,6 +887,7 @@ def run(self) -> None:
labels=updated_labels,
weights=weights,
)
# End of pseudolabeling and supervised model training loop.

if not self.runner_parameters.upload_only:
self.evaluate_model()
Expand Down

0 comments on commit c7307a0

Please sign in to comment.