diff --git a/mbta-performance/chalicelib/lamp/ingest.py b/mbta-performance/chalicelib/lamp/ingest.py index 3675609..858fe48 100644 --- a/mbta-performance/chalicelib/lamp/ingest.py +++ b/mbta-performance/chalicelib/lamp/ingest.py @@ -3,7 +3,7 @@ import os import pandas as pd import requests -from typing import Tuple +from typing import Sequence, Tuple from .constants import LAMP_COLUMNS, S3_COLUMNS from ..date import format_dateint, get_current_service_date @@ -22,7 +22,6 @@ S3_BUCKET = "tm-mbta-performance" # month and day are not zero-padded S3_KEY_TEMPLATE = "Events-lamp/daily-data/{stop_id}/Year={YYYY}/Month={_M}/Day={_D}/events.csv" -TEMP_GTFS_LOCAL_PREFIX = ".temp/gtfs-feeds/" COLUMN_RENAME_MAP = { "headway_trunk_seconds": "headway_seconds", @@ -34,6 +33,12 @@ # that the vehicle is currently on (this can be due to AVL glitches, trip diversions, test train trips, etc.) TRIP_IDS_TO_DROP = ("NONREV-", "ADDED-") +# information to fetch from GTFS +TEMP_GTFS_LOCAL_PREFIX = ".temp/gtfs-feeds/" +MAX_QUERY_DEPTH = 950  # batch size for trip_id lookups; stays safely under SQLite's default 1000-term expression/parameter limits +# defining these columns in particular because we use them everywhere +RTE_DIR_STOP = ["route_id", "direction_id", "stop_id"] + def _local_save(s3_key, stop_events): """TODO remove this temp code, it saves the output files locally!""" @@ -127,40 +132,70 @@ def fetch_pq_file_from_remote(service_date: date) -> pd.DataFrame: ) -def _recalculate_fields_from_gtfs(pq_df: pd.DataFrame, service_date: date): - # calculate gtfs traveltimes + def fetch_from_gtfs(trip_ids: Sequence[str]) -> pd.DataFrame: mbta_gtfs = MbtaGtfsArchive(TEMP_GTFS_LOCAL_PREFIX) + service_date = get_current_service_date() feed = mbta_gtfs.get_feed_for_date(service_date) feed.download_or_build() session = feed.create_sqlite_session() - gtfs_stops = pd.read_sql( - session.query( - StopTime.trip_id, - StopTime.stop_id, - StopTime.arrival_time, - # func.min(StopTime.arrival_time).label("trip_start_time") - ) - .filter( - or_(StopTime.trip_id == tid for tid in 
pq_df["trip_id"].unique()) - # ).group_by(StopTime.trip_id).statement, + + gtfs_stops = [] + for start in range(0, len(trip_ids), MAX_QUERY_DEPTH):  # start at 0 — a 1-based start silently dropped trip_ids[0] + gtfs_stops.append( + pd.read_sql( + session.query( + StopTime.trip_id, + StopTime.stop_id, + StopTime.arrival_time, + ) + .filter(or_(StopTime.trip_id == tid for tid in trip_ids[start : start + MAX_QUERY_DEPTH])) + .statement, + session.bind, + ) ) - .statement, - session.bind, - ) + # guard: pd.concat([]) raises ValueError when no trip_ids were given + return pd.concat(gtfs_stops) if gtfs_stops else pd.DataFrame(columns=["trip_id", "stop_id", "arrival_time"]) + + +def _recalculate_fields_from_gtfs(pq_df: pd.DataFrame, service_date: date): + trip_ids = pq_df["trip_id"].unique() + gtfs_stops = fetch_from_gtfs(trip_ids) + # we could do this groupby/min/merge in sql, but let's keep our computations in # pandas to stay consistent across services trip_start_times = gtfs_stops.groupby("trip_id").arrival_time.transform("min") - gtfs_stops["scheduled_tt"] = gtfs_stops.arrival_time - trip_start_times + gtfs_stops["scheduled_tt"] = gtfs_stops["arrival_time"] - trip_start_times + gtfs_stops["arrival_time"] = gtfs_stops["arrival_time"].astype(float) + # gtfs_stops["arrival_time"] = pd.to_datetime(gtfs_stops.arrival_time, unit="s") + + # assign each actual trip a scheduled trip_id, based on when it started the route + route_starts = pq_df.loc[pq_df.groupby("trip_id").event_time.idxmin()] + route_starts["arrival_time"] = ( + route_starts.event_time - pd.Timestamp(service_date).tz_localize("US/Eastern") + ).dt.total_seconds() + route_starts = route_starts.sort_values(by="arrival_time") + + trip_id_map = pd.merge_asof( + route_starts, + gtfs_stops[["trip_id", "stop_id", "arrival_time"]].sort_values(by="arrival_time"),  # merge_asof requires both sides sorted on the key + on="arrival_time", + direction="nearest", + by=["stop_id"],  # match by stop only: actual trip_ids differ from scheduled ones, and keeping trip_id out of by= lets the right-hand trip_id come through as trip_id_scheduled + suffixes=["", "_scheduled"], + ) + trip_id_map = trip_id_map.set_index("trip_id").trip_id_scheduled + # merged # TODO check, hamima: can one conceivably return to a stop_id multiple times in a trip? 
- augmented_events = pd.merge( + # use the scheduled trip matching to get the scheduled traveltime (NOTE(review): the merge below still joins gtfs_stops on the actual trip_id, not the scheduled_trip_id computed above — trips whose actual id has no GTFS match get NaN scheduled_tt; confirm this is intended) + pq_df["scheduled_trip_id"] = pq_df.trip_id.map(trip_id_map) + pq_df = pd.merge( pq_df, gtfs_stops[["trip_id", "stop_id", "scheduled_tt"]], how="left", on=["trip_id", "stop_id"], suffixes=["", "_gtfs"], ) - return augmented_events + return pq_df def ingest_pq_file(pq_df: pd.DataFrame, service_date: date) -> pd.DataFrame: