diff --git a/assets/src/hooks/useSwings.ts b/assets/src/hooks/useSwings.ts
index d5cdf9beb..25c3f1712 100644
--- a/assets/src/hooks/useSwings.ts
+++ b/assets/src/hooks/useSwings.ts
@@ -12,9 +12,9 @@ const useSwings = (): Swing[] | null => {
allOpenRouteIds(routeTabs)
)
- const newRouteIds = allOpenRouteIds(routeTabs)
+ const newRouteIds = routeTabs.find((v) => v.isCurrentTab)?.selectedRouteIds
- if (!equalByElements(routeIds, newRouteIds)) {
+ if (newRouteIds && !equalByElements(routeIds, newRouteIds)) {
setRouteIds(newRouteIds)
}
diff --git a/assets/tests/factories/routeTab.ts b/assets/tests/factories/routeTab.ts
index e66c25b00..b952a7f21 100644
--- a/assets/tests/factories/routeTab.ts
+++ b/assets/tests/factories/routeTab.ts
@@ -19,9 +19,8 @@ export const routeTabPresetFactory = defaultRouteTabFactory.params({
ordering: undefined,
})
-const routeTabFactory = defaultRouteTabFactory.params({
+export const routeTabFactory = defaultRouteTabFactory.params({
presetName: undefined,
})
-routeTabFactory.createList
export default routeTabFactory
diff --git a/assets/tests/hooks/useSwings.test.tsx b/assets/tests/hooks/useSwings.test.tsx
index 60e400a9d..d4a3f160c 100644
--- a/assets/tests/hooks/useSwings.test.tsx
+++ b/assets/tests/hooks/useSwings.test.tsx
@@ -1,155 +1,114 @@
-import { jest, describe, test, expect } from "@jest/globals"
-import { renderHook } from "@testing-library/react"
+import { jest, describe, test, expect, beforeEach } from "@jest/globals"
+import { renderHook, waitFor } from "@testing-library/react"
import React, { ReactNode } from "react"
import * as Api from "../../src/api"
import useSwings from "../../src/hooks/useSwings"
-import { instantPromise } from "../testHelpers/mockHelpers"
import { initialState } from "../../src/state"
import { StateDispatchProvider } from "../../src/contexts/stateDispatchContext"
import { RouteTab } from "../../src/models/routeTab"
-import routeTabFactory from "../factories/routeTab"
+import routeTabFactory, { routeTabPresetFactory } from "../factories/routeTab"
+import { neverPromise } from "../testHelpers/mockHelpers"
+import { swingFactory } from "../factories/swing"
-jest.mock("../../src/api", () => ({
- __esModule: true,
+jest.mock("../../src/api")
- fetchSwings: jest.fn(() => new Promise(() => {})),
-}))
+beforeEach(() => {
+ jest.mocked(Api.fetchSwings).mockReturnValue(neverPromise())
+})
describe("useSwings", () => {
test("returns null while loading", () => {
- const mockFetchSwings: jest.Mock = Api.fetchSwings as jest.Mock
- const { result } = renderHook(() => {
- return useSwings()
- })
- expect(mockFetchSwings).toHaveBeenCalledTimes(1)
+ const { result } = renderHook(useSwings)
+ expect(jest.mocked(Api.fetchSwings)).toHaveBeenCalledTimes(1)
expect(result.current).toEqual(null)
})
- test("returns result when loaded", () => {
- const swings = [
- {
- from_route_id: "1",
- from_run_id: "123-456",
- from_trip_id: "1234",
- to_route_id: "1",
- to_run_id: "123-789",
- to_trip_id: "5678",
- time: 100,
- },
- {
- from_route_id: "2",
- from_run_id: "124-456",
- from_trip_id: "4321",
- to_route_id: "2",
- to_run_id: "124-789",
- to_trip_id: "8765",
- time: 100,
- },
- ]
- const mockFetchSwings: jest.Mock = Api.fetchSwings as jest.Mock
- mockFetchSwings.mockImplementationOnce(() => instantPromise(swings))
- const { result } = renderHook(
- () => {
- return useSwings()
- },
- {
- wrapper: ({ children }) => (
-
- {children}
-
- ),
- }
- )
+ test("returns result when loaded", async () => {
+ const swings = swingFactory.buildList(2)
+ jest.mocked(Api.fetchSwings).mockResolvedValue(swings)
+
+ const { result } = renderHook(useSwings, {
+ wrapper: ({ children }) => (
+
+ {children}
+
+ ),
+ })
- expect(mockFetchSwings).toHaveBeenCalledWith(["1", "2"])
- expect(result.current).toEqual(swings)
+ await waitFor(() => {
+ expect(jest.mocked(Api.fetchSwings)).toHaveBeenCalledWith(["1"])
+ expect(result.current).toEqual(swings)
+ })
})
test("doesn't refetch swings on every render", () => {
- const mockFetchSwings: jest.Mock = Api.fetchSwings as jest.Mock
- const { rerender } = renderHook(
- () => {
- useSwings()
- },
- {
- wrapper: ({ children }) => (
-
- {children}
-
- ),
- }
- )
+ const { rerender } = renderHook(useSwings, {
+ wrapper: ({ children }) => (
+
+ {children}
+
+ ),
+ })
rerender()
- expect(mockFetchSwings).toHaveBeenCalledTimes(1)
+ expect(jest.mocked(Api.fetchSwings)).toHaveBeenCalledTimes(1)
})
test("doesn't refetch swings when route Ids don't change", () => {
- const mockFetchSwings: jest.Mock = Api.fetchSwings as jest.Mock
-
let routeTabs = [
routeTabFactory.build({
selectedRouteIds: ["1"],
isCurrentTab: true,
}),
routeTabFactory.build({
- selectedRouteIds: ["2"],
+ selectedRouteIds: ["1"],
isCurrentTab: false,
}),
]
- const { rerender } = renderHook(
- () => {
- useSwings()
- },
- {
- wrapper: ({ children }) => (
- {children}
- ),
- }
- )
+ const { rerender } = renderHook(useSwings, {
+ wrapper: ({ children }) => (
+ {children}
+ ),
+ })
routeTabs = [
routeTabFactory.build({ selectedRouteIds: ["1"], isCurrentTab: false }),
- routeTabFactory.build({ selectedRouteIds: ["2"], isCurrentTab: true }),
+ routeTabFactory.build({ selectedRouteIds: ["1"], isCurrentTab: true }),
]
rerender()
- expect(mockFetchSwings).toHaveBeenCalledTimes(1)
+ expect(jest.mocked(Api.fetchSwings)).toHaveBeenCalledTimes(1)
})
test("does refetch swings when selected routes change", () => {
- const mockFetchSwings: jest.Mock = Api.fetchSwings as jest.Mock
-
- let routeTabs = [routeTabFactory.build({ selectedRouteIds: ["1"] })]
- const { rerender } = renderHook(
- () => {
- useSwings()
- },
- {
- wrapper: ({ children }) => (
- {children}
- ),
- }
- )
- routeTabs = [routeTabFactory.build({ selectedRouteIds: ["2"] })]
+ let routeTabs = [
+ routeTabFactory.build({ selectedRouteIds: ["1"], isCurrentTab: true }),
+ ]
+ const { rerender } = renderHook(useSwings, {
+ wrapper: ({ children }) => (
+ {children}
+ ),
+ })
+ routeTabs = [
+ routeTabFactory.build({ selectedRouteIds: ["2"], isCurrentTab: true }),
+ ]
rerender()
- expect(mockFetchSwings).toHaveBeenCalledTimes(2)
+ expect(jest.mocked(Api.fetchSwings)).toHaveBeenCalledTimes(2)
})
})