diff --git a/cpp/program/setup.cpp b/cpp/program/setup.cpp index 4141e2184..ea78cb74a 100644 --- a/cpp/program/setup.cpp +++ b/cpp/program/setup.cpp @@ -676,6 +676,15 @@ vector Setup::loadParams( if(cfg.contains("numVirtualLossesPerThread"+idxStr)) params.numVirtualLossesPerThread = cfg.getDouble("numVirtualLossesPerThread"+idxStr, 0.01, 1000.0); else if(cfg.contains("numVirtualLossesPerThread")) params.numVirtualLossesPerThread = cfg.getDouble("numVirtualLossesPerThread", 0.01, 1000.0); else params.numVirtualLossesPerThread = 1.0; + if(cfg.contains("suppressVirtualLossExploreFactor"+idxStr)) params.suppressVirtualLossExploreFactor = cfg.getDouble("suppressVirtualLossExploreFactor"+idxStr, 1.0, 1e10); + else if(cfg.contains("suppressVirtualLossExploreFactor")) params.suppressVirtualLossExploreFactor = cfg.getDouble("suppressVirtualLossExploreFactor", 1.0, 1e10); + else params.suppressVirtualLossExploreFactor = 1e10; + if(cfg.contains("suppressVirtualLossHindsight"+idxStr)) params.suppressVirtualLossHindsight = cfg.getBool("suppressVirtualLossHindsight"+idxStr); + else if(cfg.contains("suppressVirtualLossHindsight")) params.suppressVirtualLossHindsight = cfg.getBool("suppressVirtualLossHindsight"); + else params.suppressVirtualLossHindsight = false; + if(cfg.contains("suppressVirtualLossLeakCatchUp"+idxStr)) params.suppressVirtualLossLeakCatchUp = cfg.getBool("suppressVirtualLossLeakCatchUp"+idxStr); + else if(cfg.contains("suppressVirtualLossLeakCatchUp")) params.suppressVirtualLossLeakCatchUp = cfg.getBool("suppressVirtualLossLeakCatchUp"); + else params.suppressVirtualLossLeakCatchUp = false; if(cfg.contains("treeReuseCarryOverTimeFactor"+idxStr)) params.treeReuseCarryOverTimeFactor = cfg.getDouble("treeReuseCarryOverTimeFactor"+idxStr,0.0,1.0); else if(cfg.contains("treeReuseCarryOverTimeFactor")) params.treeReuseCarryOverTimeFactor = cfg.getDouble("treeReuseCarryOverTimeFactor",0.0,1.0); diff --git a/cpp/search/search.cpp b/cpp/search/search.cpp index 
b5653ff74..25537471c 100644 --- a/cpp/search/search.cpp +++ b/cpp/search/search.cpp @@ -554,8 +554,8 @@ void Search::runWholeSearch( upperBoundVisitsLeft = std::min(upperBoundVisitsLeft, (double)maxPlayouts - numPlayouts); upperBoundVisitsLeft = std::min(upperBoundVisitsLeft, (double)maxVisits - numPlayouts - numNonPlayoutVisits); - bool finishedPlayout = runSinglePlayout(*stbuf, upperBoundVisitsLeft); - if(finishedPlayout) { + PlayoutResult playoutResult = runSinglePlayout(*stbuf, upperBoundVisitsLeft); + if(playoutResult == PLAYOUT_SUCCESS) { numPlayouts = numPlayoutsShared.fetch_add((int64_t)1, std::memory_order_relaxed); numPlayouts += 1; } @@ -747,7 +747,7 @@ void Search::beginSearch(bool pondering) { node.statsLock.clear(std::memory_order_release); //Update all other stats - recomputeNodeStats(node, dummyThread, 0, true); + recomputeNodeStats(node, dummyThread, true); } } @@ -974,8 +974,12 @@ void Search::recursivelyRecomputeStats(SearchNode& n) { //and has 0 visits because we began a search and then stopped it before any playouts happened. //In that case, there's not much to recompute. if(weightSum <= 0.0) { - assert(numVisits == 0); - assert(isRoot); + //It's also possible that a suppressed virtual loss edge visit on a multi-move chain + //causes the parent to have 0 visits... somehow??? 
+ if(searchParams.suppressVirtualLossExploreFactor >= 1e10) { + assert(numVisits == 0); + assert(isRoot); + } } else { double resultUtility = getResultUtility(winLossValueAvg, noResultValueAvg); @@ -992,7 +996,7 @@ void Search::recursivelyRecomputeStats(SearchNode& n) { } else { //Otherwise recompute it using the usual method - recomputeNodeStats(*node, thread, 0, isRoot); + recomputeNodeStats(*node, thread, isRoot); } }; @@ -1068,12 +1072,11 @@ void Search::computeRootValues() { } } - -bool Search::runSinglePlayout(SearchThread& thread, double upperBoundVisitsLeft) { +PlayoutResult Search::runSinglePlayout(SearchThread& thread, double upperBoundVisitsLeft) { //Store this value, used for futile-visit pruning this thread's root children selections. thread.upperBoundVisitsLeft = upperBoundVisitsLeft; - bool finishedPlayout = playoutDescend(thread,*rootNode,true); + PlayoutResult playoutResult = playoutDescend(thread,*rootNode,true); //Restore thread state back to the root state thread.pla = rootPla; @@ -1082,10 +1085,10 @@ bool Search::runSinglePlayout(SearchThread& thread, double upperBoundVisitsLeft) thread.graphHash = rootGraphHash; thread.graphPath.clear(); - return finishedPlayout; + return playoutResult; } -bool Search::playoutDescend( +PlayoutResult Search::playoutDescend( SearchThread& thread, SearchNode& node, bool isRoot ) { @@ -1107,7 +1110,7 @@ bool Search::playoutDescend( double lead = 0.0; double weight = (searchParams.useUncertainty && nnEvaluator->supportsShorttermError()) ? searchParams.uncertaintyMaxWeight : 1.0; addLeafValue(node, winLossValue, noResultValue, scoreMean, scoreMeanSq, lead, weight, true, false); - return true; + return PLAYOUT_SUCCESS; } else { double winLossValue = 2.0 * ScoreValue::whiteWinsOfWinner(thread.history.winner, searchParams.drawEquivalentWinsForWhite) - 1; @@ -1117,7 +1120,7 @@ bool Search::playoutDescend( double lead = scoreMean; double weight = (searchParams.useUncertainty && nnEvaluator->supportsShorttermError()) ? 
searchParams.uncertaintyMaxWeight : 1.0; addLeafValue(node, winLossValue, noResultValue, scoreMean, scoreMeanSq, lead, weight, true, false); - return true; + return PLAYOUT_SUCCESS; } } @@ -1129,25 +1132,25 @@ bool Search::playoutDescend( //Leave the node as unevaluated - only the thread that first actually set the nnOutput into the node //gets to update the state, to avoid races where we update the state while the node stats aren't updated yet. if(!suc) - return false; + return PLAYOUT_FAILED; } bool suc = node.state.compare_exchange_strong(nodeState, SearchNode::STATE_EVALUATING, std::memory_order_seq_cst); if(!suc) { //Presumably someone else got there first. //Just give up on this playout and try again from the start. - return false; + return PLAYOUT_FAILED; } else { //Perform the nn evaluation and finish! node.initializeChildren(); node.state.store(SearchNode::STATE_EXPANDED0, std::memory_order_seq_cst); - return true; + return PLAYOUT_SUCCESS; } } else if(nodeState == SearchNode::STATE_EVALUATING) { //Just give up on this playout and try again from the start. - return false; + return PLAYOUT_FAILED; } assert(nodeState >= SearchNode::STATE_EXPANDED0); @@ -1157,10 +1160,12 @@ bool Search::playoutDescend( int numChildrenFound; int bestChildIdx; Loc bestChildMoveLoc; + bool suppressEdgeVisit; + double suppressEdgeVisitUtilityThreshold; SearchNode* child = NULL; while(true) { - selectBestChildToDescend(thread,node,nodeState,numChildrenFound,bestChildIdx,bestChildMoveLoc,isRoot); + selectBestChildToDescend(thread,node,nodeState,numChildrenFound,bestChildIdx,bestChildMoveLoc,suppressEdgeVisit,suppressEdgeVisitUtilityThreshold,isRoot); //The absurdly rare case that the move chosen is not legal //(this should only happen either on a bug or where the nnHash doesn't have full legality information or when there's an actual hash collision). 
@@ -1189,7 +1194,7 @@ bool Search::playoutDescend( //As isReInit is true, we don't return, just keep going, since we didn't count this as a true visit in the node stats nodeState = node.state.load(std::memory_order_acquire); - selectBestChildToDescend(thread,node,nodeState,numChildrenFound,bestChildIdx,bestChildMoveLoc,isRoot); + selectBestChildToDescend(thread,node,nodeState,numChildrenFound,bestChildIdx,bestChildMoveLoc,suppressEdgeVisit,suppressEdgeVisitUtilityThreshold,isRoot); if(bestChildIdx >= 0) { //New child @@ -1198,7 +1203,7 @@ bool Search::playoutDescend( //against someone reInitializing the output to add dirichlet noise or something, who was doing so based on an older cached //nnOutput that still had the illegal move. If so, then just fail this playout and try again. if(!thread.history.isLegal(thread.board,bestChildMoveLoc,thread.pla)) - return false; + return PLAYOUT_FAILED; } //Existing child else { @@ -1209,7 +1214,7 @@ bool Search::playoutDescend( assert(childrenCapacity > bestChildIdx); (void)childrenCapacity; children[bestChildIdx].addEdgeVisits(1); - return true; + return PLAYOUT_SUCCESS; } } } @@ -1218,7 +1223,7 @@ bool Search::playoutDescend( //This might happen if all moves have been forbidden. The node will just get stuck counting visits without expanding //and we won't do any search. addCurrentNNOutputAsLeafValue(node,false); - return true; + return PLAYOUT_SUCCESS; } //Do we think we are searching a new child for the first time? @@ -1275,16 +1280,21 @@ bool Search::playoutDescend( //Even if the node was newly allocated, no need to delete the node, it will get cleaned up next time we mark and sweep the node table later. //Clean up virtual losses in case the node is a transposition and is being used. child->virtualLosses.fetch_add(-1,std::memory_order_release); - return false; + return PLAYOUT_FAILED; } } //If edge visits is too much smaller than the child's visits, we can avoid descending. 
//Instead just add edge visits and treat that as a visit. - if(maybeCatchUpEdgeVisits(thread, node, child, nodeState, bestChildIdx)) { + + if(maybeCatchUpEdgeVisits(thread, node, child, nodeState, bestChildIdx, suppressEdgeVisit)) { + if(suppressEdgeVisit) { + child->virtualLosses.fetch_add(-1,std::memory_order_release); + return PLAYOUT_FAILED; + } updateStatsAfterPlayout(node,thread,isRoot); child->virtualLosses.fetch_add(-1,std::memory_order_release); - return true; + return PLAYOUT_SUCCESS; } } //Searching an existing child @@ -1297,10 +1307,14 @@ bool Search::playoutDescend( //If edge visits is too much smaller than the child's visits, we can avoid descending. //Instead just add edge visits and treat that as a visit. - if(maybeCatchUpEdgeVisits(thread, node, child, nodeState, bestChildIdx)) { + if(maybeCatchUpEdgeVisits(thread, node, child, nodeState, bestChildIdx, suppressEdgeVisit)) { + if(suppressEdgeVisit) { + child->virtualLosses.fetch_add(-1,std::memory_order_release); + return PLAYOUT_FAILED; + } updateStatsAfterPlayout(node,thread,isRoot); child->virtualLosses.fetch_add(-1,std::memory_order_release); - return true; + return PLAYOUT_SUCCESS; } //Make the move! @@ -1323,36 +1337,59 @@ bool Search::playoutDescend( //No insertion, child was already there if(!result.second) { SearchNodeChildrenReference children = node.getChildren(nodeState); + if(suppressEdgeVisit) { + child->virtualLosses.fetch_add(-1,std::memory_order_release); + return PLAYOUT_FAILED; + } children[bestChildIdx].addEdgeVisits(1); updateStatsAfterPlayout(node,thread,isRoot); child->virtualLosses.fetch_add(-1,std::memory_order_release); - return true; + return PLAYOUT_SUCCESS; } } //Recurse! 
- bool finishedPlayout = playoutDescend(thread,*child,false); + PlayoutResult childPlayoutResult = playoutDescend(thread,*child,false); + PlayoutResult ourPlayoutResult = PLAYOUT_FAILED; + //Update this node stats - if(finishedPlayout) { - nodeState = node.state.load(std::memory_order_acquire); - SearchNodeChildrenReference children = node.getChildren(nodeState); - children[bestChildIdx].addEdgeVisits(1); + if(childPlayoutResult == PLAYOUT_NOINCREMENT || childPlayoutResult == PLAYOUT_SUCCESS) { + if(searchParams.suppressVirtualLossHindsight) { + double childUtilityAvg = node.stats.utilityAvg.load(std::memory_order_acquire); + if(node.nextPla == P_WHITE) + suppressEdgeVisit = childUtilityAvg < suppressEdgeVisitUtilityThreshold; + else + suppressEdgeVisit = childUtilityAvg > suppressEdgeVisitUtilityThreshold; + } + + if(!suppressEdgeVisit) { + nodeState = node.state.load(std::memory_order_acquire); + SearchNodeChildrenReference children = node.getChildren(nodeState); + children[bestChildIdx].addEdgeVisits(1); + ourPlayoutResult = PLAYOUT_SUCCESS; + } + else { + ourPlayoutResult = PLAYOUT_NOINCREMENT; + } updateStatsAfterPlayout(node,thread,isRoot); } child->virtualLosses.fetch_add(-1,std::memory_order_release); - return finishedPlayout; + return ourPlayoutResult; } //If edge visits is too much smaller than the child's visits, we can avoid descending. //Instead just add edge visits and return immediately. +//Returns true if we do perform a catch up edge visit, OR if the child has more visits than the edge but suppressEdgeVisit +//is true and suppressVirtualLossLeakCatchUp is false (no edge visit is added then). In other words, returns true when we can terminate the playout and false when we need to go deeper. bool Search::maybeCatchUpEdgeVisits( SearchThread& thread, SearchNode& node, SearchNode* child, const SearchNodeState& nodeState, - const int bestChildIdx + const int bestChildIdx, + bool suppressEdgeVisit ) { //Don't need to do this since we already are pretty recent as of finding the best child. 
//nodeState = node.state.load(std::memory_order_acquire); @@ -1373,15 +1410,20 @@ bool Search::maybeCatchUpEdgeVisits( if(searchParams.graphSearchCatchUpLeakProb > 0.0 && edgeVisits < childVisits && thread.rand.nextBool(searchParams.graphSearchCatchUpLeakProb)) return false; + if(edgeVisits >= childVisits) + return false; + if(suppressEdgeVisit) + return !searchParams.suppressVirtualLossLeakCatchUp; + //If the edge visits exceeds the child then we need to search the child more, but as long as that's not the case, //we can add more edge visits. constexpr int64_t numToAdd = 1; // int64_t numToAdd; - do { + while(!childPointer.compexweakEdgeVisits(edgeVisits, edgeVisits + numToAdd)) { if(edgeVisits >= childVisits) return false; // numToAdd = std::min((childVisits - edgeVisits + 3) / 4, maxNumToAdd); - } while(!childPointer.compexweakEdgeVisits(edgeVisits, edgeVisits + numToAdd)); + } return true; } diff --git a/cpp/search/search.h b/cpp/search/search.h index 5f7bc426f..929ea3be2 100644 --- a/cpp/search/search.h +++ b/cpp/search/search.h @@ -24,6 +24,7 @@ #include "../external/nlohmann_json/json.hpp" typedef int SearchNodeState; // See SearchNode::STATE_* +typedef int PlayoutResult; struct SearchNode; struct SearchThread; @@ -253,9 +254,13 @@ struct Search { //Without performing a whole search, recompute the root nn output for any root-level parameters. 
void maybeRecomputeRootNNOutput(); + static constexpr PlayoutResult PLAYOUT_FAILED = 0; + static constexpr PlayoutResult PLAYOUT_SUCCESS = 1; + static constexpr PlayoutResult PLAYOUT_NOINCREMENT = 2; + //Expert manual playout-by-playout interface void beginSearch(bool pondering); - bool runSinglePlayout(SearchThread& thread, double upperBoundVisitsLeft); + PlayoutResult runSinglePlayout(SearchThread& thread, double upperBoundVisitsLeft); //================================================================================================================ // SEARCH RESULTS AND TREE INSPECTION METHODS @@ -450,6 +455,7 @@ struct Search { ) const; void maybeApplyAntiMirrorForcedExplore( double& childUtility, + double& childUtilityNoVL, const double parentUtility, const Loc moveLoc, const float* policyProbs, @@ -511,14 +517,29 @@ struct Search { // Move selection during search // searchexplorehelpers.cpp //---------------------------------------------------------------------------------------- + struct ExploreInfo { + double exploreSelectionValue; + double valueComponentNoVL; // no virtual loss + double exploreComponent; + + static inline ExploreInfo constantSelectionValue(double d) { + ExploreInfo info; + info.exploreSelectionValue = d; + info.valueComponentNoVL = d; + info.exploreComponent = 0.0; + return info; + } + }; + double getExploreScaling( double totalChildWeight, double parentUtilityStdevFactor ) const; - double getExploreSelectionValue( + ExploreInfo getExploreSelectionValue( double exploreScaling, double nnPolicyProb, double childWeight, double childUtility, + double childUtilityNoVL, Player pla ) const; double getExploreSelectionValueInverse( @@ -528,7 +549,7 @@ struct Search { double childUtility, Player pla ) const; - double getExploreSelectionValueOfChild( + ExploreInfo getExploreSelectionValueOfChild( const SearchNode& parent, const float* parentPolicyProbs, const SearchNode* child, Loc moveLoc, double exploreScaling, @@ -536,7 +557,7 @@ struct Search { 
double parentUtility, double parentWeightPerVisit, bool isDuringSearch, bool antiMirror, double maxChildWeight, SearchThread* thread ) const; - double getNewExploreSelectionValue( + ExploreInfo getNewExploreSelectionValue( const SearchNode& parent, double exploreScaling, float nnPolicyProb, @@ -557,9 +578,14 @@ struct Search { double& parentUtility, double& parentWeightPerVisit, double& parentUtilityStdevFactor ) const; + // suppressEdgeVisit is filled in with whether we think we want to suppress this edge visit + // due to over-exploration. + // suppressEdgeVisitUtilityThreshold is the value such that if child utility improves past + // this value, then we do want to keep the visit void selectBestChildToDescend( SearchThread& thread, const SearchNode& node, SearchNodeState nodeState, int& numChildrenFound, int& bestChildIdx, Loc& bestChildMoveLoc, + bool& suppressEdgeVisit, double& suppressEdgeVisitUtilityThreshold, bool isRoot ) const; @@ -584,7 +610,7 @@ struct Search { double computeWeightFromNNOutput(const NNOutput* nnOutput) const; void updateStatsAfterPlayout(SearchNode& node, SearchThread& thread, bool isRoot); - void recomputeNodeStats(SearchNode& node, SearchThread& thread, int32_t numVisitsToAdd, bool isRoot); + void recomputeNodeStats(SearchNode& node, SearchThread& thread, bool isRoot); void downweightBadChildrenAndNormalizeWeight( int numChildren, @@ -616,7 +642,7 @@ struct Search { void computeRootValues(); // Helper for begin search void recursivelyRecomputeStats(SearchNode& node); // Helper for search initialization - bool playoutDescend( + PlayoutResult playoutDescend( SearchThread& thread, SearchNode& node, bool isRoot ); @@ -626,7 +652,8 @@ struct Search { SearchNode& node, SearchNode* child, const SearchNodeState& nodeState, - const int bestChildIdx + const int bestChildIdx, + bool suppressEdgeVisit ); //---------------------------------------------------------------------------------------- diff --git a/cpp/search/searchexplorehelpers.cpp 
b/cpp/search/searchexplorehelpers.cpp index 06d00a20b..6eb52ad3b 100644 --- a/cpp/search/searchexplorehelpers.cpp +++ b/cpp/search/searchexplorehelpers.cpp @@ -24,22 +24,29 @@ double Search::getExploreScaling( * parentUtilityStdevFactor; } -double Search::getExploreSelectionValue( +Search::ExploreInfo Search::getExploreSelectionValue( double exploreScaling, double nnPolicyProb, double childWeight, double childUtility, + double childUtilityNoVL, Player pla ) const { if(nnPolicyProb < 0) - return POLICY_ILLEGAL_SELECTION_VALUE; + return ExploreInfo::constantSelectionValue(POLICY_ILLEGAL_SELECTION_VALUE); double exploreComponent = exploreScaling * nnPolicyProb / (1.0 + childWeight); //At the last moment, adjust value to be from the player's perspective, so that players prefer values in their favor //rather than in white's favor double valueComponent = pla == P_WHITE ? childUtility : -childUtility; - return exploreComponent + valueComponent; + double valueComponentNoVL = pla == P_WHITE ? childUtilityNoVL : -childUtilityNoVL; + + Search::ExploreInfo info; + info.exploreSelectionValue = exploreComponent + valueComponent; + info.valueComponentNoVL = valueComponentNoVL; + info.exploreComponent = exploreComponent; + return info; } //Return the childWeight that would make Search::getExploreSelectionValue return the given explore selection value. 
@@ -87,7 +94,7 @@ static void maybeApplyWideRootNoise( } -double Search::getExploreSelectionValueOfChild( +Search::ExploreInfo Search::getExploreSelectionValueOfChild( const SearchNode& parent, const float* parentPolicyProbs, const SearchNode* child, Loc moveLoc, double exploreScaling, @@ -123,6 +130,7 @@ double Search::getExploreSelectionValueOfChild( } //Virtual losses to direct threads down different paths + double childUtilityNoVL = childUtility; if(childVirtualLosses > 0) { double virtualLossWeight = childVirtualLosses * searchParams.numVirtualLossesPerThread; @@ -143,12 +151,12 @@ double Search::getExploreSelectionValueOfChild( double averageVisitsPerWeight = (childEdgeVisits + 1.0) / (childWeight + parentWeightPerVisit); double estimatedRequiredVisits = requiredWeight * averageVisitsPerWeight; if(childVisits + thread->upperBoundVisitsLeft < estimatedRequiredVisits) - return FUTILE_VISITS_PRUNE_VALUE; + return ExploreInfo::constantSelectionValue(FUTILE_VISITS_PRUNE_VALUE); } //Hack to get the root to funnel more visits down child branches if(searchParams.rootDesiredPerChildVisitsCoeff > 0.0) { if(nnPolicyProb > 0 && childWeight < sqrt(nnPolicyProb * totalChildWeight * searchParams.rootDesiredPerChildVisitsCoeff)) { - return 1e20; + return ExploreInfo::constantSelectionValue(1e20); } } //Hack for hintloc - must search this move almost as often as the most searched move @@ -164,23 +172,24 @@ double Search::getExploreSelectionValueOfChild( int64_t cEdgeVisits = childPointer.getEdgeVisits(); double cWeight = c->stats.getChildWeight(cEdgeVisits); if(childWeight + averageWeightPerVisit < cWeight * 0.8) - return 1e20; + return ExploreInfo::constantSelectionValue(1e20); } } if(searchParams.wideRootNoise > 0.0 && nnPolicyProb >= 0) { + // Does NOT update childUtilityNoVL - wide root noise is also ignored. 
maybeApplyWideRootNoise(childUtility, nnPolicyProb, searchParams, thread, parent); } } if(isDuringSearch && antiMirror && nnPolicyProb >= 0) { maybeApplyAntiMirrorPolicy(nnPolicyProb, moveLoc, parentPolicyProbs, parent.nextPla, thread); - maybeApplyAntiMirrorForcedExplore(childUtility, parentUtility, moveLoc, parentPolicyProbs, childWeight, totalChildWeight, parent.nextPla, thread, parent); + maybeApplyAntiMirrorForcedExplore(childUtility, childUtilityNoVL, parentUtility, moveLoc, parentPolicyProbs, childWeight, totalChildWeight, parent.nextPla, thread, parent); } - return getExploreSelectionValue(exploreScaling,nnPolicyProb,childWeight,childUtility,parent.nextPla); + return getExploreSelectionValue(exploreScaling,nnPolicyProb,childWeight,childUtility,childUtilityNoVL,parent.nextPla); } -double Search::getNewExploreSelectionValue( +Search::ExploreInfo Search::getNewExploreSelectionValue( const SearchNode& parent, double exploreScaling, float nnPolicyProb, @@ -198,13 +207,13 @@ double Search::getNewExploreSelectionValue( double requiredWeight = searchParams.futileVisitsThreshold * maxChildWeight; double estimatedRequiredVisits = requiredWeight * averageVisitsPerWeight; if(thread->upperBoundVisitsLeft < estimatedRequiredVisits) - return FUTILE_VISITS_PRUNE_VALUE; + return ExploreInfo::constantSelectionValue(FUTILE_VISITS_PRUNE_VALUE); } if(searchParams.wideRootNoise > 0.0) { maybeApplyWideRootNoise(childUtility, nnPolicyProb, searchParams, thread, parent); } } - return getExploreSelectionValue(exploreScaling,nnPolicyProb,childWeight,childUtility,parent.nextPla); + return getExploreSelectionValue(exploreScaling,nnPolicyProb,childWeight,childUtility,childUtility,parent.nextPla); } double Search::getReducedPlaySelectionWeight( @@ -305,13 +314,21 @@ double Search::getFpuValueForChildrenAssumeVisited( void Search::selectBestChildToDescend( SearchThread& thread, const SearchNode& node, SearchNodeState nodeState, int& numChildrenFound, int& bestChildIdx, Loc& 
bestChildMoveLoc, + bool& suppressEdgeVisit, double& suppressEdgeVisitUtilityThreshold, bool isRoot) const { assert(thread.pla == node.nextPla); - double maxSelectionValue = POLICY_ILLEGAL_SELECTION_VALUE; + double bestSelectionValue = POLICY_ILLEGAL_SELECTION_VALUE; + double bestChildExploreComponent = 0.0; + double bestChildValueComponentNoVL = 0.0; bestChildIdx = -1; bestChildMoveLoc = Board::NULL_LOC; + suppressEdgeVisit = false; + suppressEdgeVisitUtilityThreshold = thread.pla == P_WHITE ? -1e10 : 1e10; + + double noVLBestSelectionValue = POLICY_ILLEGAL_SELECTION_VALUE; + int noVLBestIdx = -1; ConstSearchNodeChildrenReference children = node.getChildren(nodeState); int childrenCapacity = children.getCapacity(); @@ -372,7 +389,7 @@ void Search::selectBestChildToDescend( Loc moveLoc = childPointer.getMoveLocRelaxed(); bool isDuringSearch = true; - double selectionValue = getExploreSelectionValueOfChild( + ExploreInfo info = getExploreSelectionValueOfChild( node,policyProbs,child, moveLoc, exploreScaling, @@ -380,17 +397,19 @@ void Search::selectBestChildToDescend( parentUtility,parentWeightPerVisit, isDuringSearch,antiMirror,maxChildWeight,&thread ); - if(selectionValue > maxSelectionValue) { - // if(child->state.load(std::memory_order_seq_cst) == SearchNode::STATE_EVALUATING) { - // selectionValue -= EVALUATING_SELECTION_VALUE_PENALTY; - // if(isRoot && child->prevMoveLoc == Location::ofString("K4",thread.board)) { - // out << "ouch" << "\n"; - // } - // } - maxSelectionValue = selectionValue; + double selectionValue = info.exploreSelectionValue; + if(selectionValue > bestSelectionValue) { + bestSelectionValue = selectionValue; + bestChildExploreComponent = info.exploreComponent; + bestChildValueComponentNoVL = info.valueComponentNoVL; bestChildIdx = i; bestChildMoveLoc = moveLoc; } + double noVLSelectionValue = info.valueComponentNoVL + info.exploreComponent; + if(noVLSelectionValue > noVLBestSelectionValue) { + noVLBestSelectionValue = noVLSelectionValue; + 
noVLBestIdx = i; + } posesWithChildBuf[getPos(moveLoc)] = true; } @@ -438,17 +457,37 @@ void Search::selectBestChildToDescend( } } if(bestNewMoveLoc != Board::NULL_LOC) { - double selectionValue = getNewExploreSelectionValue( + ExploreInfo info = getNewExploreSelectionValue( node, exploreScaling, bestNewNNPolicyProb,fpuValue, parentWeightPerVisit, maxChildWeight,&thread ); - if(selectionValue > maxSelectionValue) { - maxSelectionValue = selectionValue; + double selectionValue = info.exploreSelectionValue; + if(selectionValue > bestSelectionValue) { + bestSelectionValue = selectionValue; + bestChildExploreComponent = info.exploreComponent; + bestChildValueComponentNoVL = info.valueComponentNoVL; bestChildIdx = numChildrenFound; bestChildMoveLoc = bestNewMoveLoc; } + double noVLSelectionValue = info.valueComponentNoVL + info.exploreComponent; + if(noVLSelectionValue > noVLBestSelectionValue) { + noVLBestSelectionValue = noVLSelectionValue; + noVLBestIdx = numChildrenFound; + } + } + + if(searchParams.suppressVirtualLossExploreFactor < 1e10 && noVLBestIdx != bestChildIdx) { + // Compute the selection value if we used a much larger explore factor but with no virtual loss + // If that's still not good enough, suppress visit. + double expandedChildSelectionValue = bestChildValueComponentNoVL + bestChildExploreComponent * searchParams.suppressVirtualLossExploreFactor; + double gap = noVLBestSelectionValue - expandedChildSelectionValue; + // Taking advantage of the fact that the value component is just the sign-adjusted utility + suppressEdgeVisitUtilityThreshold = (thread.pla == P_WHITE) ? 
bestChildValueComponentNoVL+gap : -bestChildValueComponentNoVL-gap; + if(gap > 0) { + suppressEdgeVisit = true; + } } } diff --git a/cpp/search/searchmirror.cpp b/cpp/search/searchmirror.cpp index 9c474c4ec..b837e4c33 100644 --- a/cpp/search/searchmirror.cpp +++ b/cpp/search/searchmirror.cpp @@ -155,6 +155,7 @@ void Search::maybeApplyAntiMirrorPolicy( //to have bad values, and also tolerate us playing certain countering moves even if their values are a bit worse. void Search::maybeApplyAntiMirrorForcedExplore( double& childUtility, + double& childUtilityNoVL, const double parentUtility, const Loc moveLoc, const float* policyProbs, @@ -219,15 +220,18 @@ void Search::maybeApplyAntiMirrorForcedExplore( proportionToBias /= mirrorCenterSymmetryError; } + double utilityToAdd = 0.0; if(thisChildWeight < proportionToDump * totalChildWeight) { - childUtility += (parent.nextPla == P_WHITE ? 100.0 : -100.0); + utilityToAdd += (parent.nextPla == P_WHITE ? 100.0 : -100.0); } if(thisChildWeight < proportionToBias * totalChildWeight) { - childUtility += (parent.nextPla == P_WHITE ? 0.18 : -0.18) * std::max(0.3, 1.0 - 0.7 * parentUtility * parentUtility); + utilityToAdd += (parent.nextPla == P_WHITE ? 0.18 : -0.18) * std::max(0.3, 1.0 - 0.7 * parentUtility * parentUtility); } if(thisChildWeight < 0.5 * proportionToBias * totalChildWeight) { - childUtility += (parent.nextPla == P_WHITE ? 0.36 : -0.36) * std::max(0.3, 1.0 - 0.7 * parentUtility * parentUtility); + utilityToAdd += (parent.nextPla == P_WHITE ? 0.36 : -0.36) * std::max(0.3, 1.0 - 0.7 * parentUtility * parentUtility); } + childUtility += utilityToAdd; + childUtilityNoVL += utilityToAdd; } } //Encourage us to find refuting moves, even if they look a little bad, in the difficult case @@ -236,8 +240,10 @@ void Search::maybeApplyAntiMirrorForcedExplore( double proportionToDump = 0.0; if(isDifficult) { if(thread->board.isAdjacentToChain(moveLoc,centerLoc)) { - childUtility += (parent.nextPla == P_WHITE ? 
0.75 : -0.75) / (1.0 + thread->board.getNumLiberties(centerLoc)) + double utilityToAdd = (parent.nextPla == P_WHITE ? 0.75 : -0.75) / (1.0 + thread->board.getNumLiberties(centerLoc)) / std::max(1.0,mirrorCenterSymmetryError) * std::max(0.3, 1.0 - 0.7 * parentUtility * parentUtility); + childUtility += utilityToAdd; + childUtilityNoVL += utilityToAdd; proportionToDump = 0.10 / thread->board.getNumLiberties(centerLoc); } int distanceSq = Location::euclideanDistanceSquared(moveLoc,centerLoc,xSize); @@ -273,7 +279,9 @@ void Search::maybeApplyAntiMirrorForcedExplore( } if(thisChildWeight < proportionToDump * totalChildWeight) { - childUtility += (parent.nextPla == P_WHITE ? 100.0 : -100.0); + double utilityToAdd = (parent.nextPla == P_WHITE ? 100.0 : -100.0); + childUtility += utilityToAdd; + childUtilityNoVL += utilityToAdd; } } } diff --git a/cpp/search/searchparams.cpp b/cpp/search/searchparams.cpp index b86946ce9..8f30f1b6b 100644 --- a/cpp/search/searchparams.cpp +++ b/cpp/search/searchparams.cpp @@ -77,6 +77,9 @@ SearchParams::SearchParams() subtreeValueBiasWeightExponent(0.5), nodeTableShardsPowerOfTwo(16), numVirtualLossesPerThread(3.0), + suppressVirtualLossExploreFactor(1e10), + suppressVirtualLossHindsight(false), + suppressVirtualLossLeakCatchUp(false), numThreads(1), minPlayoutsPerThread(0.0), maxVisits(((int64_t)1) << 50), @@ -306,6 +309,9 @@ void SearchParams::printParams(std::ostream& out) { PRINTPARAM(nodeTableShardsPowerOfTwo); PRINTPARAM(numVirtualLossesPerThread); + PRINTPARAM(suppressVirtualLossExploreFactor); + PRINTPARAM(suppressVirtualLossHindsight); + PRINTPARAM(suppressVirtualLossLeakCatchUp); PRINTPARAM(numThreads); diff --git a/cpp/search/searchparams.h b/cpp/search/searchparams.h index 4da6c3cc6..151d2dd2c 100644 --- a/cpp/search/searchparams.h +++ b/cpp/search/searchparams.h @@ -110,6 +110,9 @@ struct SearchParams { //Threading-related int nodeTableShardsPowerOfTwo; //Controls number of shards of node table for graph search transposition 
lookup double numVirtualLossesPerThread; //Number of virtual losses for one thread to add + double suppressVirtualLossExploreFactor; //Suppress edge visit if virtual loss or wide root noise explores child, but scaling cpuct by this factor wouldn't explore it. + bool suppressVirtualLossHindsight; //Suppression of edge visit uses the hindsight value on the child + bool suppressVirtualLossLeakCatchUp; //When suppressing edge visits, if child visits > edge visits, visit the child anyways and make it even greater. //Asyncbot int numThreads; //Number of threads diff --git a/cpp/search/searchresults.cpp b/cpp/search/searchresults.cpp index 2bad4875b..e56ef9406 100644 --- a/cpp/search/searchresults.cpp +++ b/cpp/search/searchresults.cpp @@ -164,7 +164,7 @@ bool Search::getPlaySelectionValues( totalChildWeight,bestChildEdgeVisits,fpuValue, parentUtility,parentWeightPerVisit, isDuringSearch,false,nonLCBBestChildWeight,NULL - ); + ).exploreSelectionValue; for(int i = 0; i 0) return; - int32_t numVisitsCompleted = 1; + int32_t numThreadsCompleted = 1; while(true) { //Perform update - recomputeNodeStats(node,thread,numVisitsCompleted,isRoot); + recomputeNodeStats(node,thread,isRoot); //Now attempt to undo the counter - oldDirtyCounter = node.dirtyCounter.fetch_add(-numVisitsCompleted,std::memory_order_acq_rel); - int32_t newDirtyCounter = oldDirtyCounter - numVisitsCompleted; + oldDirtyCounter = node.dirtyCounter.fetch_add(-numThreadsCompleted,std::memory_order_acq_rel); + int32_t newDirtyCounter = oldDirtyCounter - numThreadsCompleted; //If no other threads incremented it in the meantime, so our decrement hits zero, we're done. if(newDirtyCounter <= 0) { assert(newDirtyCounter == 0); break; } - //Otherwise, more threads incremented this more in the meantime. So we need to loop again and add their visits, recomputing again. - numVisitsCompleted = newDirtyCounter; + //Otherwise, more threads incremented this more in the meantime. So we need to loop, recomputing again. 
+ numThreadsCompleted = newDirtyCounter; continue; } } @@ -148,7 +148,7 @@ void Search::updateStatsAfterPlayout(SearchNode& node, SearchThread& thread, boo //Recompute all the stats of this node based on its children, except its visits and virtual losses, which are not child-dependent and //are updated in the manner specified. //Assumes this node has an nnOutput -void Search::recomputeNodeStats(SearchNode& node, SearchThread& thread, int numVisitsToAdd, bool isRoot) { +void Search::recomputeNodeStats(SearchNode& node, SearchThread& thread, bool isRoot) { //Find all children and compute weighting of the children based on their values vector& statsBuf = thread.statsBuf; int numGoodChildren = 0; @@ -156,6 +156,7 @@ void Search::recomputeNodeStats(SearchNode& node, SearchThread& thread, int numV ConstSearchNodeChildrenReference children = node.getChildren(); int childrenCapacity = children.getCapacity(); double origTotalChildWeight = 0.0; + int64_t thisNodeVisitsSum = 0; for(int i = 0; i