diff --git a/assets/src/hooks/useSwings.ts b/assets/src/hooks/useSwings.ts index d5cdf9beb..25c3f1712 100644 --- a/assets/src/hooks/useSwings.ts +++ b/assets/src/hooks/useSwings.ts @@ -12,9 +12,9 @@ const useSwings = (): Swing[] | null => { allOpenRouteIds(routeTabs) ) - const newRouteIds = allOpenRouteIds(routeTabs) + const newRouteIds = routeTabs.find((v) => v.isCurrentTab)?.selectedRouteIds - if (!equalByElements(routeIds, newRouteIds)) { + if (newRouteIds && !equalByElements(routeIds, newRouteIds)) { setRouteIds(newRouteIds) } diff --git a/assets/tests/factories/routeTab.ts b/assets/tests/factories/routeTab.ts index e66c25b00..b952a7f21 100644 --- a/assets/tests/factories/routeTab.ts +++ b/assets/tests/factories/routeTab.ts @@ -19,9 +19,8 @@ export const routeTabPresetFactory = defaultRouteTabFactory.params({ ordering: undefined, }) -const routeTabFactory = defaultRouteTabFactory.params({ +export const routeTabFactory = defaultRouteTabFactory.params({ presetName: undefined, }) -routeTabFactory.createList export default routeTabFactory diff --git a/assets/tests/hooks/useSwings.test.tsx b/assets/tests/hooks/useSwings.test.tsx index 60e400a9d..d4a3f160c 100644 --- a/assets/tests/hooks/useSwings.test.tsx +++ b/assets/tests/hooks/useSwings.test.tsx @@ -1,155 +1,114 @@ -import { jest, describe, test, expect } from "@jest/globals" -import { renderHook } from "@testing-library/react" +import { jest, describe, test, expect, beforeEach } from "@jest/globals" +import { renderHook, waitFor } from "@testing-library/react" import React, { ReactNode } from "react" import * as Api from "../../src/api" import useSwings from "../../src/hooks/useSwings" -import { instantPromise } from "../testHelpers/mockHelpers" import { initialState } from "../../src/state" import { StateDispatchProvider } from "../../src/contexts/stateDispatchContext" import { RouteTab } from "../../src/models/routeTab" -import routeTabFactory from "../factories/routeTab" +import routeTabFactory, { routeTabPresetFactory } from "../factories/routeTab" +import { neverPromise } from "../testHelpers/mockHelpers" +import { swingFactory } from "../factories/swing" -jest.mock("../../src/api", () => ({ - __esModule: true, +jest.mock("../../src/api") - fetchSwings: jest.fn(() => new Promise(() => {})), -})) +beforeEach(() => { + jest.mocked(Api.fetchSwings).mockReturnValue(neverPromise()) +}) describe("useSwings", () => { test("returns null while loading", () => { - const mockFetchSwings: jest.Mock = Api.fetchSwings as jest.Mock - const { result } = renderHook(() => { - return useSwings() - }) - expect(mockFetchSwings).toHaveBeenCalledTimes(1) + const { result } = renderHook(useSwings) + expect(jest.mocked(Api.fetchSwings)).toHaveBeenCalledTimes(1) expect(result.current).toEqual(null) }) - test("returns result when loaded", () => { - const swings = [ - { - from_route_id: "1", - from_run_id: "123-456", - from_trip_id: "1234", - to_route_id: "1", - to_run_id: "123-789", - to_trip_id: "5678", - time: 100, - }, - { - from_route_id: "2", - from_run_id: "124-456", - from_trip_id: "4321", - to_route_id: "2", - to_run_id: "124-789", - to_trip_id: "8765", - time: 100, - }, - ] - const mockFetchSwings: jest.Mock = Api.fetchSwings as jest.Mock - mockFetchSwings.mockImplementationOnce(() => instantPromise(swings)) - const { result } = renderHook( - () => { - return useSwings() - }, - { - wrapper: ({ children }) => ( - - {children} - - ), - } - ) + test("returns result when loaded", async () => { + const swings = swingFactory.buildList(2) + jest.mocked(Api.fetchSwings).mockResolvedValue(swings) + + const { result } = renderHook(useSwings, { + wrapper: ({ children }) => ( + + {children} + + ), + }) - expect(mockFetchSwings).toHaveBeenCalledWith(["1", "2"]) - expect(result.current).toEqual(swings) + await waitFor(() => { + expect(jest.mocked(Api.fetchSwings)).toHaveBeenCalledWith(["1"]) + expect(result.current).toEqual(swings) + }) }) test("doesn't refetch swings on every render", () => { - const mockFetchSwings: jest.Mock = Api.fetchSwings as jest.Mock - const { rerender } = renderHook( - () => { - useSwings() - }, - { - wrapper: ({ children }) => ( - - {children} - - ), - } - ) + const { rerender } = renderHook(useSwings, { + wrapper: ({ children }) => ( + + {children} + + ), + }) rerender() - expect(mockFetchSwings).toHaveBeenCalledTimes(1) + expect(jest.mocked(Api.fetchSwings)).toHaveBeenCalledTimes(1) }) test("doesn't refetch swings when route Ids don't change", () => { - const mockFetchSwings: jest.Mock = Api.fetchSwings as jest.Mock - let routeTabs = [ routeTabFactory.build({ selectedRouteIds: ["1"], isCurrentTab: true, }), routeTabFactory.build({ - selectedRouteIds: ["2"], + selectedRouteIds: ["1"], isCurrentTab: false, }), ] - const { rerender } = renderHook( - () => { - useSwings() - }, - { - wrapper: ({ children }) => ( - {children} - ), - } - ) + const { rerender } = renderHook(useSwings, { + wrapper: ({ children }) => ( + {children} + ), + }) routeTabs = [ routeTabFactory.build({ selectedRouteIds: ["1"], isCurrentTab: false }), - routeTabFactory.build({ selectedRouteIds: ["2"], isCurrentTab: true }), + routeTabFactory.build({ selectedRouteIds: ["1"], isCurrentTab: true }), ] rerender() - expect(mockFetchSwings).toHaveBeenCalledTimes(1) + expect(jest.mocked(Api.fetchSwings)).toHaveBeenCalledTimes(1) }) test("does refetch swings when selected routes change", () => { - const mockFetchSwings: jest.Mock = Api.fetchSwings as jest.Mock - - let routeTabs = [routeTabFactory.build({ selectedRouteIds: ["1"] })] - const { rerender } = renderHook( - () => { - useSwings() - }, - { - wrapper: ({ children }) => ( - {children} - ), - } - ) - routeTabs = [routeTabFactory.build({ selectedRouteIds: ["2"] })] + let routeTabs = [ + routeTabFactory.build({ selectedRouteIds: ["1"], isCurrentTab: true }), + ] + const { rerender } = renderHook(useSwings, { + wrapper: ({ children }) => ( + {children} + ), + }) + routeTabs = [ + routeTabFactory.build({ selectedRouteIds: ["2"], isCurrentTab: true }), + ] rerender() - expect(mockFetchSwings).toHaveBeenCalledTimes(2) + expect(jest.mocked(Api.fetchSwings)).toHaveBeenCalledTimes(2) }) })