diff --git a/package.json b/package.json index 19ee576..8e08368 100644 --- a/package.json +++ b/package.json @@ -43,8 +43,8 @@ "eslint-plugin-react": "^7.1.0", "jest": "^21.1.0", "prettier": "^1.7.0", - "react": "^15.6.1", - "react-dom": "^15.6.1", + "react": "^16.0.0", + "react-dom": "^16.0.0", "rimraf": "^2.6.2" } } diff --git a/src/index.js b/src/index.js index 2a366a6..f3a7ee3 100644 --- a/src/index.js +++ b/src/index.js @@ -24,16 +24,13 @@ function traverseRenderedChildren(internalInstance, callback, argument) { } } -function setPendingForceUpdate(internalInstance, shouldUpdate) { - if ( - internalInstance._pendingForceUpdate === false && - shouldUpdate(internalInstance) - ) { +function setPendingForceUpdate(internalInstance) { + if (internalInstance._pendingForceUpdate === false) { internalInstance._pendingForceUpdate = true } } -function forceUpdateIfPending(internalInstance, onUpdate) { +function forceUpdateIfPending(internalInstance) { if (internalInstance._pendingForceUpdate === true) { const publicInstance = internalInstance._instance const { updater } = publicInstance @@ -43,57 +40,86 @@ function forceUpdateIfPending(internalInstance, onUpdate) { } else if (updater && typeof updater.enqueueForceUpdate === 'function') { updater.enqueueForceUpdate(publicInstance) } - onUpdate(internalInstance) } } -function deepForceUpdateStack(instance, shouldUpdate, onUpdate) { +function deepForceUpdateStack(instance) { const internalInstance = instance._reactInternalInstance - traverseRenderedChildren( - internalInstance, - setPendingForceUpdate, - shouldUpdate, - ) - traverseRenderedChildren(internalInstance, forceUpdateIfPending, onUpdate) + traverseRenderedChildren(internalInstance, setPendingForceUpdate) + traverseRenderedChildren(internalInstance, forceUpdateIfPending) +} + +function onEnterFiber(node, toUpdate, shouldUpdate) { + if (!toUpdate) { + return undefined + } + if (node.tag === ReactClassComponent) { + toUpdate.push(shouldUpdate( + node.tag, + node.stateNode, + node._debugSource && node._debugSource.fileName, + )) + } else if (!toUpdate[toUpdate.length - 1]) { + toUpdate[toUpdate.length - 1] = shouldUpdate( + node.tag, + null, // publicInstance + node._debugSource && node._debugSource.fileName, + ) + } + return undefined +} + +function onLeaveFiber(node, toUpdate, onUpdate) { + if (node.tag !== ReactClassComponent || (toUpdate && !toUpdate.pop())) { + return undefined + } + const publicInstance = node.stateNode + const { updater } = publicInstance + if (typeof publicInstance.forceUpdate === 'function') { + publicInstance.forceUpdate() + } else if ( + updater && typeof updater.enqueueForceUpdate === 'function' + ) { + updater.enqueueForceUpdate(publicInstance) + } + if (onUpdate) { + onUpdate(publicInstance) + } + return undefined } export default function deepForceUpdate( instance, - shouldUpdate = () => true, - onUpdate = () => {}, + shouldUpdateClosestClassInstance, + onUpdateClassInstance, ) { const root = instance._reactInternalFiber || instance._reactInternalInstance if (typeof root.tag !== 'number') { + if (shouldUpdateClosestClassInstance || onUpdateClassInstance) { + throw new Error('shouldUpdateClosestClassInstance and ' + + 'onUpdateClassInstance are only supported in React Fiber') + } // Traverse stack-based React tree. - return deepForceUpdateStack(instance, shouldUpdate, onUpdate) + return deepForceUpdateStack(instance) } let node = root + const toUpdate = shouldUpdateClosestClassInstance ? [] : null while (true) { - if (node.tag === ReactClassComponent && shouldUpdate(node)) { - const publicInstance = node.stateNode - const { updater } = publicInstance - if (typeof publicInstance.forceUpdate === 'function') { - publicInstance.forceUpdate() - } else if (updater && typeof updater.enqueueForceUpdate === 'function') { - updater.enqueueForceUpdate(publicInstance) - } - onUpdate(node) - } + onEnterFiber(node, toUpdate, shouldUpdateClosestClassInstance) if (node.child) { node.child.return = node node = node.child continue } - if (node === root) { - return undefined - } - while (!node.sibling) { - if (!node.return || node.return === root) { + while (!node.sibling || node === root) { + onLeaveFiber(node, toUpdate, onUpdateClassInstance) + if (!node.return || node === root) { return undefined } node = node.return } + onLeaveFiber(node, toUpdate, onUpdateClassInstance) node.sibling.return = node.return node = node.sibling }