Skip to content

Commit

Permalink
change dtype backend for cleaner merges
Browse files Browse the repository at this point in the history
  • Loading branch information
hamima-halim committed May 6, 2024
1 parent 661b662 commit 35485e1
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions mbta-performance/chalicelib/lamp/ingest.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,15 +132,15 @@ def fetch_pq_file_from_remote(service_date: date) -> pd.DataFrame:
)


def fetch_from_gtfs(trip_ids: Iterable[str]) -> pd.DataFrame:
def fetch_stop_times_from_gtfs(trip_ids: Iterable[str], service_date: date) -> pd.DataFrame:
"""Fetch scheduled stop time information from GTFS."""
mbta_gtfs = MbtaGtfsArchive(TEMP_GTFS_LOCAL_PREFIX)
service_date = get_current_service_date()
feed = mbta_gtfs.get_feed_for_date(service_date)
feed.download_or_build()
session = feed.create_sqlite_session()

gtfs_stops = []
for start in range(1, len(trip_ids), MAX_QUERY_DEPTH):
for start in range(0, len(trip_ids), MAX_QUERY_DEPTH):
gtfs_stops.append(
pd.read_sql(
session.query(
Expand All @@ -151,21 +151,22 @@ def fetch_from_gtfs(trip_ids: Iterable[str]) -> pd.DataFrame:
.filter(or_(StopTime.trip_id == tid for tid in trip_ids[start : start + MAX_QUERY_DEPTH]))
.statement,
session.bind,
dtype_backend="numpy_nullable",
)
)
return pd.concat(gtfs_stops)


def _recalculate_fields_from_gtfs(pq_df: pd.DataFrame, service_date: date):
"""Enrich LAMP data with GTFS data for some schedule information."""
trip_ids = pq_df["trip_id"].unique()
gtfs_stops = fetch_from_gtfs(trip_ids)
gtfs_stops = fetch_stop_times_from_gtfs(trip_ids, service_date)

# we could do this groupby/min/merge in sql, but let's keep our computations in
# pandas to stay consistent across services
trip_start_times = gtfs_stops.groupby("trip_id").arrival_time.transform("min")
gtfs_stops["scheduled_tt"] = gtfs_stops["arrival_time"] - trip_start_times
gtfs_stops["arrival_time"] = gtfs_stops["arrival_time"].astype(float)
# gtfs_stops["arrival_time"] = pd.to_datetime(gtfs_stops.arrival_time, unit="s")

# assign each actual trip a scheduled trip_id, based on when it started the route
route_starts = pq_df.loc[pq_df.groupby("trip_id").event_time.idxmin()]
Expand Down

0 comments on commit 35485e1

Please sign in to comment.