diff --git a/mbta-performance/chalicelib/lamp/ingest.py b/mbta-performance/chalicelib/lamp/ingest.py index 858fe48..0755edf 100644 --- a/mbta-performance/chalicelib/lamp/ingest.py +++ b/mbta-performance/chalicelib/lamp/ingest.py @@ -132,15 +132,15 @@ def fetch_pq_file_from_remote(service_date: date) -> pd.DataFrame: ) -def fetch_from_gtfs(trip_ids: Iterable[str]) -> pd.DataFrame: +def fetch_stop_times_from_gtfs(trip_ids: Iterable[str], service_date: date) -> pd.DataFrame: + """Fetch scheduled stop time information from GTFS.""" mbta_gtfs = MbtaGtfsArchive(TEMP_GTFS_LOCAL_PREFIX) - service_date = get_current_service_date() feed = mbta_gtfs.get_feed_for_date(service_date) feed.download_or_build() session = feed.create_sqlite_session() gtfs_stops = [] - for start in range(1, len(trip_ids), MAX_QUERY_DEPTH): + for start in range(0, len(trip_ids), MAX_QUERY_DEPTH): gtfs_stops.append( pd.read_sql( session.query( @@ -151,21 +151,22 @@ def fetch_from_gtfs(trip_ids: Iterable[str]) -> pd.DataFrame: .filter(or_(StopTime.trip_id == tid for tid in trip_ids[start : start + MAX_QUERY_DEPTH])) .statement, session.bind, + dtype_backend="numpy_nullable", ) ) return pd.concat(gtfs_stops) def _recalculate_fields_from_gtfs(pq_df: pd.DataFrame, service_date: date): + """Enrich LAMP data with GTFS data for some schedule information.""" trip_ids = pq_df["trip_id"].unique() - gtfs_stops = fetch_from_gtfs(trip_ids) + gtfs_stops = fetch_stop_times_from_gtfs(trip_ids, service_date) # we could do this groupby/min/merge in sql, but let's keep our computations in # pandas to stay consistent across services trip_start_times = gtfs_stops.groupby("trip_id").arrival_time.transform("min") gtfs_stops["scheduled_tt"] = gtfs_stops["arrival_time"] - trip_start_times gtfs_stops["arrival_time"] = gtfs_stops["arrival_time"].astype(float) - # gtfs_stops["arrival_time"] = pd.to_datetime(gtfs_stops.arrival_time, unit="s") # assign each actual trip a scheduled trip_id, based on when it started the route route_starts = pq_df.loc[pq_df.groupby("trip_id").event_time.idxmin()]