Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 70 additions & 0 deletions packages/router/src/__tests__/useBlocker.test.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -97,4 +97,74 @@ describe('useBlocker', () => {
gHistory.remove(listenerId)
unmount()
})

describe('when function', () => {
it('should initialize with IDLE state when using a function', () => {
const { result, unmount } = renderHook(() =>
useBlocker({ when: () => false }),
)
expect(result.current.state).toBe('IDLE')
unmount()
})

it('should block when function returns true', () => {
const whenFn = vi.fn(() => true)
const { result, unmount } = renderHook(() => useBlocker({ when: whenFn }))

act(() => {
navigate('/blocked-path')
})

expect(whenFn).toHaveBeenCalled()
expect(result.current.state).toBe('BLOCKED')
unmount()
})

it('should not block when function returns false', () => {
const whenFn = vi.fn(() => false)
const { result, unmount } = renderHook(() => useBlocker({ when: whenFn }))

act(() => {
navigate('/allowed-path')
})

expect(whenFn).toHaveBeenCalled()
expect(result.current.state).toBe('IDLE')
unmount()
})

it('should pass nextLocation to when function', () => {
const whenFn = vi.fn(() => true)
const { result, unmount } = renderHook(() => useBlocker({ when: whenFn }))

act(() => {
navigate('/new-destination')
})

expect(whenFn).toHaveBeenCalledWith({
nextLocation: '/new-destination',
})
expect(result.current.state).toBe('BLOCKED')
unmount()
})

it('should block based on nextLocation', () => {
const whenFn = vi.fn(({ nextLocation }: { nextLocation: string }) =>
nextLocation.startsWith('/protected'),
)
const { result, unmount } = renderHook(() => useBlocker({ when: whenFn }))

act(() => {
navigate('/allowed')
})
expect(result.current.state).toBe('IDLE')

act(() => {
navigate('/protected/page')
})
expect(result.current.state).toBe('BLOCKED')

unmount()
})
})
})
19 changes: 14 additions & 5 deletions packages/router/src/history.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@ export interface NavigateOptions {

export type Listener = (ev?: PopStateEvent, options?: NavigateOptions) => any
export type BeforeUnloadListener = (ev: BeforeUnloadEvent) => any
export type BlockerCallback = (tx: { retry: () => void }) => void
export type BlockerCallback = (tx: {
retry: () => void
nextLocation: string
}) => void
export type Blocker = { id: string; callback: BlockerCallback }

const createHistory = () => {
Expand Down Expand Up @@ -51,7 +54,7 @@ const createHistory = () => {
}

if (blockers.length > 0) {
processBlockers(0, performNavigation)
processBlockers(0, performNavigation, to)
} else {
performNavigation()
}
Expand All @@ -65,7 +68,8 @@ const createHistory = () => {
}

if (blockers.length > 0) {
processBlockers(0, performBack)
// FIXME: for navigating back, we don't have the next location info
processBlockers(0, performBack, '')
} else {
performBack()
}
Expand Down Expand Up @@ -105,10 +109,15 @@ const createHistory = () => {
},
}

const processBlockers = (index: number, navigate: () => void) => {
const processBlockers = (
index: number,
navigate: () => void,
nextLocation: string,
) => {
if (index < blockers.length) {
blockers[index].callback({
retry: () => processBlockers(index + 1, navigate),
retry: () => processBlockers(index + 1, navigate, nextLocation),
nextLocation,
})
} else {
navigate()
Expand Down
14 changes: 10 additions & 4 deletions packages/router/src/useBlocker.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@ import type { BlockerCallback } from './history.js'

type BlockerState = 'IDLE' | 'BLOCKED'

type WhenFunction = (args: { nextLocation: string }) => boolean

interface UseBlockerOptions {
when: boolean
when: boolean | WhenFunction
}

export function useBlocker({ when }: UseBlockerOptions) {
Expand All @@ -17,8 +19,11 @@ export function useBlocker({ when }: UseBlockerOptions) {
const blockerId = useId()

const blocker: BlockerCallback = useCallback(
({ retry }) => {
if (when) {
({ retry, nextLocation }) => {
const shouldBlock =
typeof when === 'function' ? when({ nextLocation }) : when

if (shouldBlock) {
setBlockerState('BLOCKED')
setPendingNavigation(() => retry)
} else {
Expand All @@ -29,7 +34,8 @@ export function useBlocker({ when }: UseBlockerOptions) {
)

useEffect(() => {
if (when) {
const shouldRegister = typeof when === 'function' || when
if (shouldRegister) {
block(blockerId, blocker)
} else {
unblock(blockerId)
Expand Down