From 3c55dfc28fe3b2b656bdc9527702049da5890145 Mon Sep 17 00:00:00 2001 From: David Wu Date: Mon, 27 Nov 2023 01:06:49 -0500 Subject: [PATCH 1/5] Partial work on suppress vl visits --- cpp/search/search.cpp | 5 +- cpp/search/search.h | 23 +++++++-- cpp/search/searchexplorehelpers.cpp | 77 +++++++++++++++++++---------- cpp/search/searchmirror.cpp | 18 +++++-- cpp/search/searchparams.cpp | 2 + cpp/search/searchparams.h | 1 + cpp/search/searchresults.cpp | 2 +- 7 files changed, 92 insertions(+), 36 deletions(-) diff --git a/cpp/search/search.cpp b/cpp/search/search.cpp index b5653ff74..44657f6e5 100644 --- a/cpp/search/search.cpp +++ b/cpp/search/search.cpp @@ -1157,10 +1157,11 @@ bool Search::playoutDescend( int numChildrenFound; int bestChildIdx; Loc bestChildMoveLoc; + bool suppressEdgeVisit; //TODO use this in the update logic SearchNode* child = NULL; while(true) { - selectBestChildToDescend(thread,node,nodeState,numChildrenFound,bestChildIdx,bestChildMoveLoc,isRoot); + selectBestChildToDescend(thread,node,nodeState,numChildrenFound,bestChildIdx,bestChildMoveLoc,suppressEdgeVisit,isRoot); //The absurdly rare case that the move chosen is not legal //(this should only happen either on a bug or where the nnHash doesn't have full legality information or when there's an actual hash collision). 
@@ -1189,7 +1190,7 @@ bool Search::playoutDescend( //As isReInit is true, we don't return, just keep going, since we didn't count this as a true visit in the node stats nodeState = node.state.load(std::memory_order_acquire); - selectBestChildToDescend(thread,node,nodeState,numChildrenFound,bestChildIdx,bestChildMoveLoc,isRoot); + selectBestChildToDescend(thread,node,nodeState,numChildrenFound,bestChildIdx,bestChildMoveLoc,suppressEdgeVisit,isRoot); if(bestChildIdx >= 0) { //New child diff --git a/cpp/search/search.h b/cpp/search/search.h index 5f7bc426f..6e6ef1c9c 100644 --- a/cpp/search/search.h +++ b/cpp/search/search.h @@ -450,6 +450,7 @@ struct Search { ) const; void maybeApplyAntiMirrorForcedExplore( double& childUtility, + double& childUtilityNoVL, const double parentUtility, const Loc moveLoc, const float* policyProbs, @@ -511,14 +512,29 @@ struct Search { // Move selection during search // searchexplorehelpers.cpp //---------------------------------------------------------------------------------------- + struct ExploreInfo { + double exploreSelectionValue; + double exploreSelectionValueNoVL; // no virtual loss + double exploreComponent; + + static inline ExploreInfo constantSelectionValue(double d) { + ExploreInfo info; + info.exploreSelectionValue = d; + info.exploreSelectionValueNoVL = d; + info.exploreComponent = 0.0; + return info; + } + }; + double getExploreScaling( double totalChildWeight, double parentUtilityStdevFactor ) const; - double getExploreSelectionValue( + ExploreInfo getExploreSelectionValue( double exploreScaling, double nnPolicyProb, double childWeight, double childUtility, + double childUtilityNoVL, Player pla ) const; double getExploreSelectionValueInverse( @@ -528,7 +544,7 @@ struct Search { double childUtility, Player pla ) const; - double getExploreSelectionValueOfChild( + ExploreInfo getExploreSelectionValueOfChild( const SearchNode& parent, const float* parentPolicyProbs, const SearchNode* child, Loc moveLoc, double 
exploreScaling, @@ -536,7 +552,7 @@ struct Search { double parentUtility, double parentWeightPerVisit, bool isDuringSearch, bool antiMirror, double maxChildWeight, SearchThread* thread ) const; - double getNewExploreSelectionValue( + ExploreInfo getNewExploreSelectionValue( const SearchNode& parent, double exploreScaling, float nnPolicyProb, @@ -560,6 +576,7 @@ struct Search { void selectBestChildToDescend( SearchThread& thread, const SearchNode& node, SearchNodeState nodeState, int& numChildrenFound, int& bestChildIdx, Loc& bestChildMoveLoc, + bool& suppressEdgeVisit, bool isRoot ) const; diff --git a/cpp/search/searchexplorehelpers.cpp b/cpp/search/searchexplorehelpers.cpp index 06d00a20b..b6ff91368 100644 --- a/cpp/search/searchexplorehelpers.cpp +++ b/cpp/search/searchexplorehelpers.cpp @@ -24,22 +24,29 @@ double Search::getExploreScaling( * parentUtilityStdevFactor; } -double Search::getExploreSelectionValue( +Search::ExploreInfo Search::getExploreSelectionValue( double exploreScaling, double nnPolicyProb, double childWeight, double childUtility, + double childUtilityNoVL, Player pla ) const { if(nnPolicyProb < 0) - return POLICY_ILLEGAL_SELECTION_VALUE; + return ExploreInfo::constantSelectionValue(POLICY_ILLEGAL_SELECTION_VALUE); double exploreComponent = exploreScaling * nnPolicyProb / (1.0 + childWeight); //At the last moment, adjust value to be from the player's perspective, so that players prefer values in their favor //rather than in white's favor double valueComponent = pla == P_WHITE ? childUtility : -childUtility; - return exploreComponent + valueComponent; + double valueComponentNoVL = pla == P_WHITE ? 
childUtilityNoVL : -childUtilityNoVL; + + Search::ExploreInfo info; + info.exploreSelectionValue = exploreComponent + valueComponent; + info.exploreSelectionValueNoVL = exploreComponent + valueComponentNoVL; + info.exploreComponent = exploreComponent; + return info; } //Return the childWeight that would make Search::getExploreSelectionValue return the given explore selection value. @@ -87,7 +94,7 @@ static void maybeApplyWideRootNoise( } -double Search::getExploreSelectionValueOfChild( +Search::ExploreInfo Search::getExploreSelectionValueOfChild( const SearchNode& parent, const float* parentPolicyProbs, const SearchNode* child, Loc moveLoc, double exploreScaling, @@ -123,6 +130,7 @@ double Search::getExploreSelectionValueOfChild( } //Virtual losses to direct threads down different paths + double childUtilityNoVL = childUtility; if(childVirtualLosses > 0) { double virtualLossWeight = childVirtualLosses * searchParams.numVirtualLossesPerThread; @@ -143,12 +151,12 @@ double Search::getExploreSelectionValueOfChild( double averageVisitsPerWeight = (childEdgeVisits + 1.0) / (childWeight + parentWeightPerVisit); double estimatedRequiredVisits = requiredWeight * averageVisitsPerWeight; if(childVisits + thread->upperBoundVisitsLeft < estimatedRequiredVisits) - return FUTILE_VISITS_PRUNE_VALUE; + return ExploreInfo::constantSelectionValue(FUTILE_VISITS_PRUNE_VALUE); } //Hack to get the root to funnel more visits down child branches if(searchParams.rootDesiredPerChildVisitsCoeff > 0.0) { if(nnPolicyProb > 0 && childWeight < sqrt(nnPolicyProb * totalChildWeight * searchParams.rootDesiredPerChildVisitsCoeff)) { - return 1e20; + return ExploreInfo::constantSelectionValue(1e20); } } //Hack for hintloc - must search this move almost as often as the most searched move @@ -164,23 +172,25 @@ double Search::getExploreSelectionValueOfChild( int64_t cEdgeVisits = childPointer.getEdgeVisits(); double cWeight = c->stats.getChildWeight(cEdgeVisits); if(childWeight + averageWeightPerVisit < 
cWeight * 0.8) - return 1e20; + return ExploreInfo::constantSelectionValue(1e20); } } if(searchParams.wideRootNoise > 0.0 && nnPolicyProb >= 0) { + // TODO is this desired? + // Does NOT update childUtilityNoVL - wide root noise is also ignored. maybeApplyWideRootNoise(childUtility, nnPolicyProb, searchParams, thread, parent); } } if(isDuringSearch && antiMirror && nnPolicyProb >= 0) { maybeApplyAntiMirrorPolicy(nnPolicyProb, moveLoc, parentPolicyProbs, parent.nextPla, thread); - maybeApplyAntiMirrorForcedExplore(childUtility, parentUtility, moveLoc, parentPolicyProbs, childWeight, totalChildWeight, parent.nextPla, thread, parent); + maybeApplyAntiMirrorForcedExplore(childUtility, childUtilityNoVL, parentUtility, moveLoc, parentPolicyProbs, childWeight, totalChildWeight, parent.nextPla, thread, parent); } - return getExploreSelectionValue(exploreScaling,nnPolicyProb,childWeight,childUtility,parent.nextPla); + return getExploreSelectionValue(exploreScaling,nnPolicyProb,childWeight,childUtility,childUtilityNoVL,parent.nextPla); } -double Search::getNewExploreSelectionValue( +Search::ExploreInfo Search::getNewExploreSelectionValue( const SearchNode& parent, double exploreScaling, float nnPolicyProb, @@ -198,13 +208,13 @@ double Search::getNewExploreSelectionValue( double requiredWeight = searchParams.futileVisitsThreshold * maxChildWeight; double estimatedRequiredVisits = requiredWeight * averageVisitsPerWeight; if(thread->upperBoundVisitsLeft < estimatedRequiredVisits) - return FUTILE_VISITS_PRUNE_VALUE; + return ExploreInfo::constantSelectionValue(FUTILE_VISITS_PRUNE_VALUE); } if(searchParams.wideRootNoise > 0.0) { maybeApplyWideRootNoise(childUtility, nnPolicyProb, searchParams, thread, parent); } } - return getExploreSelectionValue(exploreScaling,nnPolicyProb,childWeight,childUtility,parent.nextPla); + return getExploreSelectionValue(exploreScaling,nnPolicyProb,childWeight,childUtility,childUtility,parent.nextPla); } double Search::getReducedPlaySelectionWeight( 
@@ -305,13 +315,18 @@ double Search::getFpuValueForChildrenAssumeVisited( void Search::selectBestChildToDescend( SearchThread& thread, const SearchNode& node, SearchNodeState nodeState, int& numChildrenFound, int& bestChildIdx, Loc& bestChildMoveLoc, + bool& suppressEdgeVisit, bool isRoot) const { assert(thread.pla == node.nextPla); - double maxSelectionValue = POLICY_ILLEGAL_SELECTION_VALUE; + double bestSelectionValue = POLICY_ILLEGAL_SELECTION_VALUE; + double bestChildExploreComponent = 0.0; bestChildIdx = -1; bestChildMoveLoc = Board::NULL_LOC; + suppressEdgeVisit = false; + + double bestSelectionValueNoVL = POLICY_ILLEGAL_SELECTION_VALUE; ConstSearchNodeChildrenReference children = node.getChildren(nodeState); int childrenCapacity = children.getCapacity(); @@ -372,7 +387,7 @@ void Search::selectBestChildToDescend( Loc moveLoc = childPointer.getMoveLocRelaxed(); bool isDuringSearch = true; - double selectionValue = getExploreSelectionValueOfChild( + ExploreInfo info = getExploreSelectionValueOfChild( node,policyProbs,child, moveLoc, exploreScaling, @@ -380,17 +395,17 @@ void Search::selectBestChildToDescend( parentUtility,parentWeightPerVisit, isDuringSearch,antiMirror,maxChildWeight,&thread ); - if(selectionValue > maxSelectionValue) { - // if(child->state.load(std::memory_order_seq_cst) == SearchNode::STATE_EVALUATING) { - // selectionValue -= EVALUATING_SELECTION_VALUE_PENALTY; - // if(isRoot && child->prevMoveLoc == Location::ofString("K4",thread.board)) { - // out << "ouch" << "\n"; - // } - // } - maxSelectionValue = selectionValue; + double selectionValue = info.exploreSelectionValue; + if(selectionValue > bestSelectionValue) { + bestSelectionValue = selectionValue; + bestChildExploreComponent = info.exploreComponent; bestChildIdx = i; bestChildMoveLoc = moveLoc; } + double selectionValueNoVL = info.exploreSelectionValueNoVL; + if(selectionValueNoVL > bestSelectionValueNoVL) { + bestSelectionValueNoVL = selectionValueNoVL; + } 
posesWithChildBuf[getPos(moveLoc)] = true; } @@ -438,17 +453,29 @@ void Search::selectBestChildToDescend( } } if(bestNewMoveLoc != Board::NULL_LOC) { - double selectionValue = getNewExploreSelectionValue( + ExploreInfo info = getNewExploreSelectionValue( node, exploreScaling, bestNewNNPolicyProb,fpuValue, parentWeightPerVisit, maxChildWeight,&thread ); - if(selectionValue > maxSelectionValue) { - maxSelectionValue = selectionValue; + double selectionValue = info.exploreSelectionValue; + if(selectionValue > bestSelectionValue) { + bestSelectionValue = selectionValue; + bestChildExploreComponent = info.exploreComponent; bestChildIdx = numChildrenFound; bestChildMoveLoc = bestNewMoveLoc; } + double selectionValueNoVL = info.exploreSelectionValueNoVL; + if(selectionValueNoVL > bestSelectionValueNoVL) { + bestSelectionValueNoVL = selectionValueNoVL; + } + } + + if(searchParams.suppressVirtualLossExploreFactor < 1e10) { + if(bestSelectionValue + bestChildExploreComponent * (searchParams.suppressVirtualLossExploreFactor-1.0) < bestSelectionValueNoVL) { + suppressEdgeVisit = true; + } } } diff --git a/cpp/search/searchmirror.cpp b/cpp/search/searchmirror.cpp index 9c474c4ec..b837e4c33 100644 --- a/cpp/search/searchmirror.cpp +++ b/cpp/search/searchmirror.cpp @@ -155,6 +155,7 @@ void Search::maybeApplyAntiMirrorPolicy( //to have bad values, and also tolerate us playing certain countering moves even if their values are a bit worse. void Search::maybeApplyAntiMirrorForcedExplore( double& childUtility, + double& childUtilityNoVL, const double parentUtility, const Loc moveLoc, const float* policyProbs, @@ -219,15 +220,18 @@ void Search::maybeApplyAntiMirrorForcedExplore( proportionToBias /= mirrorCenterSymmetryError; } + double utilityToAdd = 0.0; if(thisChildWeight < proportionToDump * totalChildWeight) { - childUtility += (parent.nextPla == P_WHITE ? 100.0 : -100.0); + utilityToAdd += (parent.nextPla == P_WHITE ? 
100.0 : -100.0); } if(thisChildWeight < proportionToBias * totalChildWeight) { - childUtility += (parent.nextPla == P_WHITE ? 0.18 : -0.18) * std::max(0.3, 1.0 - 0.7 * parentUtility * parentUtility); + utilityToAdd += (parent.nextPla == P_WHITE ? 0.18 : -0.18) * std::max(0.3, 1.0 - 0.7 * parentUtility * parentUtility); } if(thisChildWeight < 0.5 * proportionToBias * totalChildWeight) { - childUtility += (parent.nextPla == P_WHITE ? 0.36 : -0.36) * std::max(0.3, 1.0 - 0.7 * parentUtility * parentUtility); + utilityToAdd += (parent.nextPla == P_WHITE ? 0.36 : -0.36) * std::max(0.3, 1.0 - 0.7 * parentUtility * parentUtility); } + childUtility += utilityToAdd; + childUtilityNoVL += utilityToAdd; } } //Encourage us to find refuting moves, even if they look a little bad, in the difficult case @@ -236,8 +240,10 @@ void Search::maybeApplyAntiMirrorForcedExplore( double proportionToDump = 0.0; if(isDifficult) { if(thread->board.isAdjacentToChain(moveLoc,centerLoc)) { - childUtility += (parent.nextPla == P_WHITE ? 0.75 : -0.75) / (1.0 + thread->board.getNumLiberties(centerLoc)) + double utilityToAdd = (parent.nextPla == P_WHITE ? 0.75 : -0.75) / (1.0 + thread->board.getNumLiberties(centerLoc)) / std::max(1.0,mirrorCenterSymmetryError) * std::max(0.3, 1.0 - 0.7 * parentUtility * parentUtility); + childUtility += utilityToAdd; + childUtilityNoVL += utilityToAdd; proportionToDump = 0.10 / thread->board.getNumLiberties(centerLoc); } int distanceSq = Location::euclideanDistanceSquared(moveLoc,centerLoc,xSize); @@ -273,7 +279,9 @@ void Search::maybeApplyAntiMirrorForcedExplore( } if(thisChildWeight < proportionToDump * totalChildWeight) { - childUtility += (parent.nextPla == P_WHITE ? 100.0 : -100.0); + double utilityToAdd = (parent.nextPla == P_WHITE ? 
100.0 : -100.0); + childUtility += utilityToAdd; + childUtilityNoVL += utilityToAdd; } } } diff --git a/cpp/search/searchparams.cpp b/cpp/search/searchparams.cpp index b86946ce9..77afd5a7f 100644 --- a/cpp/search/searchparams.cpp +++ b/cpp/search/searchparams.cpp @@ -77,6 +77,7 @@ SearchParams::SearchParams() subtreeValueBiasWeightExponent(0.5), nodeTableShardsPowerOfTwo(16), numVirtualLossesPerThread(3.0), + suppressVirtualLossExploreFactor(1e10), numThreads(1), minPlayoutsPerThread(0.0), maxVisits(((int64_t)1) << 50), @@ -306,6 +307,7 @@ void SearchParams::printParams(std::ostream& out) { PRINTPARAM(nodeTableShardsPowerOfTwo); PRINTPARAM(numVirtualLossesPerThread); + PRINTPARAM(suppressVirtualLossExploreFactor); PRINTPARAM(numThreads); diff --git a/cpp/search/searchparams.h b/cpp/search/searchparams.h index 4da6c3cc6..b1d8296f3 100644 --- a/cpp/search/searchparams.h +++ b/cpp/search/searchparams.h @@ -110,6 +110,7 @@ struct SearchParams { //Threading-related int nodeTableShardsPowerOfTwo; //Controls number of shards of node table for graph search transposition lookup double numVirtualLossesPerThread; //Number of virtual losses for one thread to add + double suppressVirtualLossExploreFactor; //Asyncbot int numThreads; //Number of threads diff --git a/cpp/search/searchresults.cpp b/cpp/search/searchresults.cpp index 2bad4875b..e56ef9406 100644 --- a/cpp/search/searchresults.cpp +++ b/cpp/search/searchresults.cpp @@ -164,7 +164,7 @@ bool Search::getPlaySelectionValues( totalChildWeight,bestChildEdgeVisits,fpuValue, parentUtility,parentWeightPerVisit, isDuringSearch,false,nonLCBBestChildWeight,NULL - ); + ).exploreSelectionValue; for(int i = 0; i Date: Fri, 8 Dec 2023 23:58:42 -0500 Subject: [PATCH 2/5] Suppress vl visits and also nonincremental visits with slight avoid accum change --- cpp/program/setup.cpp | 3 +++ cpp/search/search.cpp | 39 +++++++++++++++++++++------- cpp/search/search.h | 5 ++-- cpp/search/searchexplorehelpers.cpp | 6 +++-- 
cpp/search/searchparams.h | 2 +- cpp/search/searchupdatehelpers.cpp | 24 ++++++++++------- cpp/tests/results/runOutputTests.txt | 26 ++++++++++++------- 7 files changed, 73 insertions(+), 32 deletions(-) diff --git a/cpp/program/setup.cpp b/cpp/program/setup.cpp index 4141e2184..6da363848 100644 --- a/cpp/program/setup.cpp +++ b/cpp/program/setup.cpp @@ -676,6 +676,9 @@ vector Setup::loadParams( if(cfg.contains("numVirtualLossesPerThread"+idxStr)) params.numVirtualLossesPerThread = cfg.getDouble("numVirtualLossesPerThread"+idxStr, 0.01, 1000.0); else if(cfg.contains("numVirtualLossesPerThread")) params.numVirtualLossesPerThread = cfg.getDouble("numVirtualLossesPerThread", 0.01, 1000.0); else params.numVirtualLossesPerThread = 1.0; + if(cfg.contains("suppressVirtualLossExploreFactor"+idxStr)) params.suppressVirtualLossExploreFactor = cfg.getDouble("suppressVirtualLossExploreFactor"+idxStr, 1.0, 1e10); + else if(cfg.contains("suppressVirtualLossExploreFactor")) params.suppressVirtualLossExploreFactor = cfg.getDouble("suppressVirtualLossExploreFactor", 1.0, 1e10); + else params.suppressVirtualLossExploreFactor = 1e10; if(cfg.contains("treeReuseCarryOverTimeFactor"+idxStr)) params.treeReuseCarryOverTimeFactor = cfg.getDouble("treeReuseCarryOverTimeFactor"+idxStr,0.0,1.0); else if(cfg.contains("treeReuseCarryOverTimeFactor")) params.treeReuseCarryOverTimeFactor = cfg.getDouble("treeReuseCarryOverTimeFactor",0.0,1.0); diff --git a/cpp/search/search.cpp b/cpp/search/search.cpp index 44657f6e5..2b333e9c8 100644 --- a/cpp/search/search.cpp +++ b/cpp/search/search.cpp @@ -747,7 +747,7 @@ void Search::beginSearch(bool pondering) { node.statsLock.clear(std::memory_order_release); //Update all other stats - recomputeNodeStats(node, dummyThread, 0, true); + recomputeNodeStats(node, dummyThread, true); } } @@ -992,7 +992,7 @@ void Search::recursivelyRecomputeStats(SearchNode& n) { } else { //Otherwise recompute it using the usual method - recomputeNodeStats(*node, thread, 0, 
isRoot); + recomputeNodeStats(*node, thread, isRoot); } }; @@ -1157,7 +1157,7 @@ bool Search::playoutDescend( int numChildrenFound; int bestChildIdx; Loc bestChildMoveLoc; - bool suppressEdgeVisit; //TODO use this in the update logic + bool suppressEdgeVisit; SearchNode* child = NULL; while(true) { @@ -1282,7 +1282,12 @@ bool Search::playoutDescend( //If edge visits is too much smaller than the child's visits, we can avoid descending. //Instead just add edge visits and treat that as a visit. - if(maybeCatchUpEdgeVisits(thread, node, child, nodeState, bestChildIdx)) { + + if(maybeCatchUpEdgeVisits(thread, node, child, nodeState, bestChildIdx, suppressEdgeVisit)) { + if(suppressEdgeVisit) { + child->virtualLosses.fetch_add(-1,std::memory_order_release); + return false; + } updateStatsAfterPlayout(node,thread,isRoot); child->virtualLosses.fetch_add(-1,std::memory_order_release); return true; @@ -1298,7 +1303,11 @@ bool Search::playoutDescend( //If edge visits is too much smaller than the child's visits, we can avoid descending. //Instead just add edge visits and treat that as a visit. 
- if(maybeCatchUpEdgeVisits(thread, node, child, nodeState, bestChildIdx)) { + if(maybeCatchUpEdgeVisits(thread, node, child, nodeState, bestChildIdx, suppressEdgeVisit)) { + if(suppressEdgeVisit) { + child->virtualLosses.fetch_add(-1,std::memory_order_release); + return false; + } updateStatsAfterPlayout(node,thread,isRoot); child->virtualLosses.fetch_add(-1,std::memory_order_release); return true; @@ -1324,6 +1333,10 @@ bool Search::playoutDescend( //No insertion, child was already there if(!result.second) { SearchNodeChildrenReference children = node.getChildren(nodeState); + if(suppressEdgeVisit) { + child->virtualLosses.fetch_add(-1,std::memory_order_release); + return false; + } children[bestChildIdx].addEdgeVisits(1); updateStatsAfterPlayout(node,thread,isRoot); child->virtualLosses.fetch_add(-1,std::memory_order_release); @@ -1335,9 +1348,11 @@ bool Search::playoutDescend( bool finishedPlayout = playoutDescend(thread,*child,false); //Update this node stats if(finishedPlayout) { - nodeState = node.state.load(std::memory_order_acquire); - SearchNodeChildrenReference children = node.getChildren(nodeState); - children[bestChildIdx].addEdgeVisits(1); + if(!suppressEdgeVisit) { + nodeState = node.state.load(std::memory_order_acquire); + SearchNodeChildrenReference children = node.getChildren(nodeState); + children[bestChildIdx].addEdgeVisits(1); + } updateStatsAfterPlayout(node,thread,isRoot); } child->virtualLosses.fetch_add(-1,std::memory_order_release); @@ -1348,12 +1363,15 @@ bool Search::playoutDescend( //If edge visits is too much smaller than the child's visits, we can avoid descending. //Instead just add edge visits and return immediately. +//Returns true if we do perform a catch up edge visit, OR if the child visits is already sufficient but suppressEdgeVisit +//is true. In other words, returns true when we can terminate the playout and false when we need to go deeper. 
bool Search::maybeCatchUpEdgeVisits( SearchThread& thread, SearchNode& node, SearchNode* child, const SearchNodeState& nodeState, - const int bestChildIdx + const int bestChildIdx, + bool suppressEdgeVisit ) { //Don't need to do this since we already are pretty recent as of finding the best child. //nodeState = node.state.load(std::memory_order_acquire); @@ -1374,6 +1392,9 @@ bool Search::maybeCatchUpEdgeVisits( if(searchParams.graphSearchCatchUpLeakProb > 0.0 && edgeVisits < childVisits && thread.rand.nextBool(searchParams.graphSearchCatchUpLeakProb)) return false; + if(suppressEdgeVisit) + return true; + //If the edge visits exceeds the child then we need to search the child more, but as long as that's not the case, //we can add more edge visits. constexpr int64_t numToAdd = 1; diff --git a/cpp/search/search.h b/cpp/search/search.h index 6e6ef1c9c..1219a77f5 100644 --- a/cpp/search/search.h +++ b/cpp/search/search.h @@ -601,7 +601,7 @@ struct Search { double computeWeightFromNNOutput(const NNOutput* nnOutput) const; void updateStatsAfterPlayout(SearchNode& node, SearchThread& thread, bool isRoot); - void recomputeNodeStats(SearchNode& node, SearchThread& thread, int32_t numVisitsToAdd, bool isRoot); + void recomputeNodeStats(SearchNode& node, SearchThread& thread, bool isRoot); void downweightBadChildrenAndNormalizeWeight( int numChildren, @@ -643,7 +643,8 @@ struct Search { SearchNode& node, SearchNode* child, const SearchNodeState& nodeState, - const int bestChildIdx + const int bestChildIdx, + bool suppressEdgeVisit ); //---------------------------------------------------------------------------------------- diff --git a/cpp/search/searchexplorehelpers.cpp b/cpp/search/searchexplorehelpers.cpp index b6ff91368..533c26e2e 100644 --- a/cpp/search/searchexplorehelpers.cpp +++ b/cpp/search/searchexplorehelpers.cpp @@ -177,7 +177,6 @@ Search::ExploreInfo Search::getExploreSelectionValueOfChild( } if(searchParams.wideRootNoise > 0.0 && nnPolicyProb >= 0) { - // TODO 
is this desired? // Does NOT update childUtilityNoVL - wide root noise is also ignored. maybeApplyWideRootNoise(childUtility, nnPolicyProb, searchParams, thread, parent); } @@ -322,6 +321,7 @@ void Search::selectBestChildToDescend( double bestSelectionValue = POLICY_ILLEGAL_SELECTION_VALUE; double bestChildExploreComponent = 0.0; + double bestChildSelectionValueNoVL = 0.0; bestChildIdx = -1; bestChildMoveLoc = Board::NULL_LOC; suppressEdgeVisit = false; @@ -399,6 +399,7 @@ void Search::selectBestChildToDescend( if(selectionValue > bestSelectionValue) { bestSelectionValue = selectionValue; bestChildExploreComponent = info.exploreComponent; + bestChildSelectionValueNoVL = info.exploreSelectionValueNoVL; bestChildIdx = i; bestChildMoveLoc = moveLoc; } @@ -464,6 +465,7 @@ void Search::selectBestChildToDescend( if(selectionValue > bestSelectionValue) { bestSelectionValue = selectionValue; bestChildExploreComponent = info.exploreComponent; + bestChildSelectionValueNoVL = info.exploreSelectionValueNoVL; bestChildIdx = numChildrenFound; bestChildMoveLoc = bestNewMoveLoc; } @@ -474,7 +476,7 @@ void Search::selectBestChildToDescend( } if(searchParams.suppressVirtualLossExploreFactor < 1e10) { - if(bestSelectionValue + bestChildExploreComponent * (searchParams.suppressVirtualLossExploreFactor-1.0) < bestSelectionValueNoVL) { + if(bestChildSelectionValueNoVL + bestChildExploreComponent * (searchParams.suppressVirtualLossExploreFactor-1.0) < bestSelectionValueNoVL) { suppressEdgeVisit = true; } } diff --git a/cpp/search/searchparams.h b/cpp/search/searchparams.h index b1d8296f3..60cfea84b 100644 --- a/cpp/search/searchparams.h +++ b/cpp/search/searchparams.h @@ -110,7 +110,7 @@ struct SearchParams { //Threading-related int nodeTableShardsPowerOfTwo; //Controls number of shards of node table for graph search transposition lookup double numVirtualLossesPerThread; //Number of virtual losses for one thread to add - double suppressVirtualLossExploreFactor; + double 
suppressVirtualLossExploreFactor; //Suppress edge visit if virtual loss or wide root noise explores child, but scaling cpuct by this factor wouldn't explore it. //Asyncbot int numThreads; //Number of threads diff --git a/cpp/search/searchupdatehelpers.cpp b/cpp/search/searchupdatehelpers.cpp index ff39b208c..ed9e5a328 100644 --- a/cpp/search/searchupdatehelpers.cpp +++ b/cpp/search/searchupdatehelpers.cpp @@ -127,20 +127,20 @@ void Search::updateStatsAfterPlayout(SearchNode& node, SearchThread& thread, boo //If we atomically grab a nonzero, then we know another thread must already be doing the work, so we can skip the update ourselves. if(oldDirtyCounter > 0) return; - int32_t numVisitsCompleted = 1; + int32_t numThreadsCompleted = 1; while(true) { //Perform update - recomputeNodeStats(node,thread,numVisitsCompleted,isRoot); + recomputeNodeStats(node,thread,isRoot); //Now attempt to undo the counter - oldDirtyCounter = node.dirtyCounter.fetch_add(-numVisitsCompleted,std::memory_order_acq_rel); - int32_t newDirtyCounter = oldDirtyCounter - numVisitsCompleted; + oldDirtyCounter = node.dirtyCounter.fetch_add(-numThreadsCompleted,std::memory_order_acq_rel); + int32_t newDirtyCounter = oldDirtyCounter - numThreadsCompleted; //If no other threads incremented it in the meantime, so our decrement hits zero, we're done. if(newDirtyCounter <= 0) { assert(newDirtyCounter == 0); break; } - //Otherwise, more threads incremented this more in the meantime. So we need to loop again and add their visits, recomputing again. - numVisitsCompleted = newDirtyCounter; + //Otherwise, more threads incremented this more in the meantime. So we need to loop, recomputing again. 
+ numThreadsCompleted = newDirtyCounter; continue; } } @@ -148,7 +148,7 @@ void Search::updateStatsAfterPlayout(SearchNode& node, SearchThread& thread, boo //Recompute all the stats of this node based on its children, except its visits and virtual losses, which are not child-dependent and //are updated in the manner specified. //Assumes this node has an nnOutput -void Search::recomputeNodeStats(SearchNode& node, SearchThread& thread, int numVisitsToAdd, bool isRoot) { +void Search::recomputeNodeStats(SearchNode& node, SearchThread& thread, bool isRoot) { //Find all children and compute weighting of the children based on their values vector& statsBuf = thread.statsBuf; int numGoodChildren = 0; @@ -156,6 +156,7 @@ void Search::recomputeNodeStats(SearchNode& node, SearchThread& thread, int numV ConstSearchNodeChildrenReference children = node.getChildren(); int childrenCapacity = children.getCapacity(); double origTotalChildWeight = 0.0; + int64_t thisNodeVisitsSum = 0; for(int i = 0; i Date: Sat, 9 Dec 2023 01:54:19 -0500 Subject: [PATCH 3/5] TODO on assert --- cpp/search/search.cpp | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/cpp/search/search.cpp b/cpp/search/search.cpp index 2b333e9c8..78009a773 100644 --- a/cpp/search/search.cpp +++ b/cpp/search/search.cpp @@ -974,8 +974,12 @@ void Search::recursivelyRecomputeStats(SearchNode& n) { //and has 0 visits because we began a search and then stopped it before any playouts happened. //In that case, there's not much to recompute. if(weightSum <= 0.0) { - assert(numVisits == 0); - assert(isRoot); + //It's also possible that a suppressed virtual loss edge visit on a multi-move chain + //causes the parent to have 0 visits... somehow??? 
+ if(searchParams.suppressVirtualLossExploreFactor >= 1e10) { + assert(numVisits == 0); + assert(isRoot); + } } else { double resultUtility = getResultUtility(winLossValueAvg, noResultValueAvg); From 56890b2918987474ab6c010bf1163618478f6c7f Mon Sep 17 00:00:00 2001 From: David Wu Date: Fri, 15 Dec 2023 22:46:19 -0500 Subject: [PATCH 4/5] Implement hindsight version of suppress vl --- cpp/program/setup.cpp | 3 +++ cpp/search/search.cpp | 12 +++++++-- cpp/search/search.h | 10 +++++--- cpp/search/searchexplorehelpers.cpp | 38 ++++++++++++++++++----------- cpp/search/searchparams.cpp | 2 ++ cpp/search/searchparams.h | 1 + 6 files changed, 47 insertions(+), 19 deletions(-) diff --git a/cpp/program/setup.cpp b/cpp/program/setup.cpp index 6da363848..853b2b979 100644 --- a/cpp/program/setup.cpp +++ b/cpp/program/setup.cpp @@ -679,6 +679,9 @@ vector Setup::loadParams( if(cfg.contains("suppressVirtualLossExploreFactor"+idxStr)) params.suppressVirtualLossExploreFactor = cfg.getDouble("suppressVirtualLossExploreFactor"+idxStr, 1.0, 1e10); else if(cfg.contains("suppressVirtualLossExploreFactor")) params.suppressVirtualLossExploreFactor = cfg.getDouble("suppressVirtualLossExploreFactor", 1.0, 1e10); else params.suppressVirtualLossExploreFactor = 1e10; + if(cfg.contains("suppressVirtualLossHindsight"+idxStr)) params.suppressVirtualLossHindsight = cfg.getBool("suppressVirtualLossHindsight"+idxStr); + else if(cfg.contains("suppressVirtualLossHindsight")) params.suppressVirtualLossHindsight = cfg.getBool("suppressVirtualLossHindsight"); + else params.suppressVirtualLossHindsight = false; if(cfg.contains("treeReuseCarryOverTimeFactor"+idxStr)) params.treeReuseCarryOverTimeFactor = cfg.getDouble("treeReuseCarryOverTimeFactor"+idxStr,0.0,1.0); else if(cfg.contains("treeReuseCarryOverTimeFactor")) params.treeReuseCarryOverTimeFactor = cfg.getDouble("treeReuseCarryOverTimeFactor",0.0,1.0); diff --git a/cpp/search/search.cpp b/cpp/search/search.cpp index 78009a773..be0914f2c 100644 --- 
a/cpp/search/search.cpp +++ b/cpp/search/search.cpp @@ -1162,10 +1162,11 @@ bool Search::playoutDescend( int bestChildIdx; Loc bestChildMoveLoc; bool suppressEdgeVisit; + double suppressEdgeVisitUtilityThreshold; SearchNode* child = NULL; while(true) { - selectBestChildToDescend(thread,node,nodeState,numChildrenFound,bestChildIdx,bestChildMoveLoc,suppressEdgeVisit,isRoot); + selectBestChildToDescend(thread,node,nodeState,numChildrenFound,bestChildIdx,bestChildMoveLoc,suppressEdgeVisit,suppressEdgeVisitUtilityThreshold,isRoot); //The absurdly rare case that the move chosen is not legal //(this should only happen either on a bug or where the nnHash doesn't have full legality information or when there's an actual hash collision). @@ -1194,7 +1195,7 @@ bool Search::playoutDescend( //As isReInit is true, we don't return, just keep going, since we didn't count this as a true visit in the node stats nodeState = node.state.load(std::memory_order_acquire); - selectBestChildToDescend(thread,node,nodeState,numChildrenFound,bestChildIdx,bestChildMoveLoc,suppressEdgeVisit,isRoot); + selectBestChildToDescend(thread,node,nodeState,numChildrenFound,bestChildIdx,bestChildMoveLoc,suppressEdgeVisit,suppressEdgeVisitUtilityThreshold,isRoot); if(bestChildIdx >= 0) { //New child @@ -1352,6 +1353,13 @@ bool Search::playoutDescend( bool finishedPlayout = playoutDescend(thread,*child,false); //Update this node stats if(finishedPlayout) { + if(searchParams.suppressVirtualLossHindsight) { + double childUtilityAvg = child->stats.utilityAvg.load(std::memory_order_acquire); //Hindsight check uses the child's own post-playout utility, not the parent's + if(node.nextPla == P_WHITE) + suppressEdgeVisit = childUtilityAvg < suppressEdgeVisitUtilityThreshold; + else + suppressEdgeVisit = childUtilityAvg > suppressEdgeVisitUtilityThreshold; + } if(!suppressEdgeVisit) { nodeState = node.state.load(std::memory_order_acquire); SearchNodeChildrenReference children = node.getChildren(nodeState); diff --git a/cpp/search/search.h b/cpp/search/search.h index 1219a77f5..5e8282f91 100644
--- a/cpp/search/search.h +++ b/cpp/search/search.h @@ -514,13 +514,13 @@ struct Search { //---------------------------------------------------------------------------------------- struct ExploreInfo { double exploreSelectionValue; - double exploreSelectionValueNoVL; // no virtual loss + double valueComponentNoVL; // no virtual loss double exploreComponent; static inline ExploreInfo constantSelectionValue(double d) { ExploreInfo info; info.exploreSelectionValue = d; - info.exploreSelectionValueNoVL = d; + info.valueComponentNoVL = d; info.exploreComponent = 0.0; return info; } @@ -573,10 +573,14 @@ struct Search { double& parentUtility, double& parentWeightPerVisit, double& parentUtilityStdevFactor ) const; + // suppressEdgeVisit is filled in with whether we think we want to suppress this edge visit + // due to over-exploration. + // suppressEdgeVisitUtilityThreshold is the value such that if child utility improves past + // this value, then we do want to keep the visit void selectBestChildToDescend( SearchThread& thread, const SearchNode& node, SearchNodeState nodeState, int& numChildrenFound, int& bestChildIdx, Loc& bestChildMoveLoc, - bool& suppressEdgeVisit, + bool& suppressEdgeVisit, double& suppressEdgeVisitUtilityThreshold, bool isRoot ) const; diff --git a/cpp/search/searchexplorehelpers.cpp b/cpp/search/searchexplorehelpers.cpp index 533c26e2e..6eb52ad3b 100644 --- a/cpp/search/searchexplorehelpers.cpp +++ b/cpp/search/searchexplorehelpers.cpp @@ -44,7 +44,7 @@ Search::ExploreInfo Search::getExploreSelectionValue( Search::ExploreInfo info; info.exploreSelectionValue = exploreComponent + valueComponent; - info.exploreSelectionValueNoVL = exploreComponent + valueComponentNoVL; + info.valueComponentNoVL = valueComponentNoVL; info.exploreComponent = exploreComponent; return info; } @@ -314,19 +314,21 @@ double Search::getFpuValueForChildrenAssumeVisited( void Search::selectBestChildToDescend( SearchThread& thread, const SearchNode& node, SearchNodeState 
nodeState, int& numChildrenFound, int& bestChildIdx, Loc& bestChildMoveLoc, - bool& suppressEdgeVisit, + bool& suppressEdgeVisit, double& suppressEdgeVisitUtilityThreshold, bool isRoot) const { assert(thread.pla == node.nextPla); double bestSelectionValue = POLICY_ILLEGAL_SELECTION_VALUE; double bestChildExploreComponent = 0.0; - double bestChildSelectionValueNoVL = 0.0; + double bestChildValueComponentNoVL = 0.0; bestChildIdx = -1; bestChildMoveLoc = Board::NULL_LOC; suppressEdgeVisit = false; + suppressEdgeVisitUtilityThreshold = thread.pla == P_WHITE ? -1e10 : 1e10; - double bestSelectionValueNoVL = POLICY_ILLEGAL_SELECTION_VALUE; + double noVLBestSelectionValue = POLICY_ILLEGAL_SELECTION_VALUE; + int noVLBestIdx = -1; ConstSearchNodeChildrenReference children = node.getChildren(nodeState); int childrenCapacity = children.getCapacity(); @@ -399,13 +401,14 @@ void Search::selectBestChildToDescend( if(selectionValue > bestSelectionValue) { bestSelectionValue = selectionValue; bestChildExploreComponent = info.exploreComponent; - bestChildSelectionValueNoVL = info.exploreSelectionValueNoVL; + bestChildValueComponentNoVL = info.valueComponentNoVL; bestChildIdx = i; bestChildMoveLoc = moveLoc; } - double selectionValueNoVL = info.exploreSelectionValueNoVL; - if(selectionValueNoVL > bestSelectionValueNoVL) { - bestSelectionValueNoVL = selectionValueNoVL; + double noVLSelectionValue = info.valueComponentNoVL + info.exploreComponent; + if(noVLSelectionValue > noVLBestSelectionValue) { + noVLBestSelectionValue = noVLSelectionValue; + noVLBestIdx = i; } posesWithChildBuf[getPos(moveLoc)] = true; @@ -465,18 +468,25 @@ void Search::selectBestChildToDescend( if(selectionValue > bestSelectionValue) { bestSelectionValue = selectionValue; bestChildExploreComponent = info.exploreComponent; - bestChildSelectionValueNoVL = info.exploreSelectionValueNoVL; + bestChildValueComponentNoVL = info.valueComponentNoVL; bestChildIdx = numChildrenFound; bestChildMoveLoc = bestNewMoveLoc; } - 
double selectionValueNoVL = info.exploreSelectionValueNoVL; - if(selectionValueNoVL > bestSelectionValueNoVL) { - bestSelectionValueNoVL = selectionValueNoVL; + double noVLSelectionValue = info.valueComponentNoVL + info.exploreComponent; + if(noVLSelectionValue > noVLBestSelectionValue) { + noVLBestSelectionValue = noVLSelectionValue; + noVLBestIdx = numChildrenFound; } } - if(searchParams.suppressVirtualLossExploreFactor < 1e10) { - if(bestChildSelectionValueNoVL + bestChildExploreComponent * (searchParams.suppressVirtualLossExploreFactor-1.0) < bestSelectionValueNoVL) { + if(searchParams.suppressVirtualLossExploreFactor < 1e10 && noVLBestIdx != bestChildIdx) { + // Compute the selection value if we used a much larger explore factor but with no virtual loss + // If that's still not good enough, suppress visit. + double expandedChildSelectionValue = bestChildValueComponentNoVL + bestChildExploreComponent * searchParams.suppressVirtualLossExploreFactor; + double gap = noVLBestSelectionValue - expandedChildSelectionValue; + // Taking advantage of the fact that the value component is just the sign-adjusted utility + suppressEdgeVisitUtilityThreshold = (thread.pla == P_WHITE) ? 
bestChildValueComponentNoVL+gap : -bestChildValueComponentNoVL-gap; + if(gap > 0) { suppressEdgeVisit = true; } } diff --git a/cpp/search/searchparams.cpp b/cpp/search/searchparams.cpp index 77afd5a7f..1e7eb881d 100644 --- a/cpp/search/searchparams.cpp +++ b/cpp/search/searchparams.cpp @@ -78,6 +78,7 @@ SearchParams::SearchParams() nodeTableShardsPowerOfTwo(16), numVirtualLossesPerThread(3.0), suppressVirtualLossExploreFactor(1e10), + suppressVirtualLossHindsight(false), numThreads(1), minPlayoutsPerThread(0.0), maxVisits(((int64_t)1) << 50), @@ -308,6 +309,7 @@ void SearchParams::printParams(std::ostream& out) { PRINTPARAM(nodeTableShardsPowerOfTwo); PRINTPARAM(numVirtualLossesPerThread); PRINTPARAM(suppressVirtualLossExploreFactor); + PRINTPARAM(suppressVirtualLossHindsight); PRINTPARAM(numThreads); diff --git a/cpp/search/searchparams.h b/cpp/search/searchparams.h index 60cfea84b..e5f2fc0fa 100644 --- a/cpp/search/searchparams.h +++ b/cpp/search/searchparams.h @@ -111,6 +111,7 @@ struct SearchParams { int nodeTableShardsPowerOfTwo; //Controls number of shards of node table for graph search transposition lookup double numVirtualLossesPerThread; //Number of virtual losses for one thread to add double suppressVirtualLossExploreFactor; //Suppress edge visit if virtual loss or wide root noise explores child, but scaling cpuct by this factor wouldn't explore it. 
+ bool suppressVirtualLossHindsight; //Suppression of edge visit uses the hindsight value on the child //Asyncbot int numThreads; //Number of threads From 86344de53e49a1b1bacbdadd3e289f1013782909 Mon Sep 17 00:00:00 2001 From: David Wu Date: Fri, 15 Dec 2023 23:32:17 -0500 Subject: [PATCH 5/5] Bugfix for virtual loss suppression not performing the child visit and add leakage option --- cpp/program/setup.cpp | 3 ++ cpp/search/search.cpp | 66 +++++++++++++++++++++---------------- cpp/search/search.h | 9 +++-- cpp/search/searchparams.cpp | 2 ++ cpp/search/searchparams.h | 1 + 5 files changed, 50 insertions(+), 31 deletions(-) diff --git a/cpp/program/setup.cpp b/cpp/program/setup.cpp index 853b2b979..ea78cb74a 100644 --- a/cpp/program/setup.cpp +++ b/cpp/program/setup.cpp @@ -682,6 +682,9 @@ vector Setup::loadParams( if(cfg.contains("suppressVirtualLossHindsight"+idxStr)) params.suppressVirtualLossHindsight = cfg.getBool("suppressVirtualLossHindsight"+idxStr); else if(cfg.contains("suppressVirtualLossHindsight")) params.suppressVirtualLossHindsight = cfg.getBool("suppressVirtualLossHindsight"); else params.suppressVirtualLossHindsight = false; + if(cfg.contains("suppressVirtualLossLeakCatchUp"+idxStr)) params.suppressVirtualLossLeakCatchUp = cfg.getBool("suppressVirtualLossLeakCatchUp"+idxStr); + else if(cfg.contains("suppressVirtualLossLeakCatchUp")) params.suppressVirtualLossLeakCatchUp = cfg.getBool("suppressVirtualLossLeakCatchUp"); + else params.suppressVirtualLossLeakCatchUp = false; if(cfg.contains("treeReuseCarryOverTimeFactor"+idxStr)) params.treeReuseCarryOverTimeFactor = cfg.getDouble("treeReuseCarryOverTimeFactor"+idxStr,0.0,1.0); else if(cfg.contains("treeReuseCarryOverTimeFactor")) params.treeReuseCarryOverTimeFactor = cfg.getDouble("treeReuseCarryOverTimeFactor",0.0,1.0); diff --git a/cpp/search/search.cpp b/cpp/search/search.cpp index be0914f2c..25537471c 100644 --- a/cpp/search/search.cpp +++ b/cpp/search/search.cpp @@ -554,8 +554,8 @@ void 
Search::runWholeSearch( upperBoundVisitsLeft = std::min(upperBoundVisitsLeft, (double)maxPlayouts - numPlayouts); upperBoundVisitsLeft = std::min(upperBoundVisitsLeft, (double)maxVisits - numPlayouts - numNonPlayoutVisits); - bool finishedPlayout = runSinglePlayout(*stbuf, upperBoundVisitsLeft); - if(finishedPlayout) { + PlayoutResult playoutResult = runSinglePlayout(*stbuf, upperBoundVisitsLeft); + if(playoutResult == PLAYOUT_SUCCESS) { numPlayouts = numPlayoutsShared.fetch_add((int64_t)1, std::memory_order_relaxed); numPlayouts += 1; } @@ -1072,12 +1072,11 @@ void Search::computeRootValues() { } } - -bool Search::runSinglePlayout(SearchThread& thread, double upperBoundVisitsLeft) { +PlayoutResult Search::runSinglePlayout(SearchThread& thread, double upperBoundVisitsLeft) { //Store this value, used for futile-visit pruning this thread's root children selections. thread.upperBoundVisitsLeft = upperBoundVisitsLeft; - bool finishedPlayout = playoutDescend(thread,*rootNode,true); + PlayoutResult playoutResult = playoutDescend(thread,*rootNode,true); //Restore thread state back to the root state thread.pla = rootPla; @@ -1086,10 +1085,10 @@ bool Search::runSinglePlayout(SearchThread& thread, double upperBoundVisitsLeft) thread.graphHash = rootGraphHash; thread.graphPath.clear(); - return finishedPlayout; + return playoutResult; } -bool Search::playoutDescend( +PlayoutResult Search::playoutDescend( SearchThread& thread, SearchNode& node, bool isRoot ) { @@ -1111,7 +1110,7 @@ bool Search::playoutDescend( double lead = 0.0; double weight = (searchParams.useUncertainty && nnEvaluator->supportsShorttermError()) ? 
searchParams.uncertaintyMaxWeight : 1.0; addLeafValue(node, winLossValue, noResultValue, scoreMean, scoreMeanSq, lead, weight, true, false); - return true; + return PLAYOUT_SUCCESS; } else { double winLossValue = 2.0 * ScoreValue::whiteWinsOfWinner(thread.history.winner, searchParams.drawEquivalentWinsForWhite) - 1; @@ -1121,7 +1120,7 @@ bool Search::playoutDescend( double lead = scoreMean; double weight = (searchParams.useUncertainty && nnEvaluator->supportsShorttermError()) ? searchParams.uncertaintyMaxWeight : 1.0; addLeafValue(node, winLossValue, noResultValue, scoreMean, scoreMeanSq, lead, weight, true, false); - return true; + return PLAYOUT_SUCCESS; } } @@ -1133,25 +1132,25 @@ bool Search::playoutDescend( //Leave the node as unevaluated - only the thread that first actually set the nnOutput into the node //gets to update the state, to avoid races where we update the state while the node stats aren't updated yet. if(!suc) - return false; + return PLAYOUT_FAILED; } bool suc = node.state.compare_exchange_strong(nodeState, SearchNode::STATE_EVALUATING, std::memory_order_seq_cst); if(!suc) { //Presumably someone else got there first. //Just give up on this playout and try again from the start. - return false; + return PLAYOUT_FAILED; } else { //Perform the nn evaluation and finish! node.initializeChildren(); node.state.store(SearchNode::STATE_EXPANDED0, std::memory_order_seq_cst); - return true; + return PLAYOUT_SUCCESS; } } else if(nodeState == SearchNode::STATE_EVALUATING) { //Just give up on this playout and try again from the start. - return false; + return PLAYOUT_FAILED; } assert(nodeState >= SearchNode::STATE_EXPANDED0); @@ -1204,7 +1203,7 @@ bool Search::playoutDescend( //against someone reInitializing the output to add dirichlet noise or something, who was doing so based on an older cached //nnOutput that still had the illegal move. If so, then just fail this playout and try again. 
if(!thread.history.isLegal(thread.board,bestChildMoveLoc,thread.pla)) - return false; + return PLAYOUT_FAILED; } //Existing child else { @@ -1215,7 +1214,7 @@ bool Search::playoutDescend( assert(childrenCapacity > bestChildIdx); (void)childrenCapacity; children[bestChildIdx].addEdgeVisits(1); - return true; + return PLAYOUT_SUCCESS; } } } @@ -1224,7 +1223,7 @@ bool Search::playoutDescend( //This might happen if all moves have been forbidden. The node will just get stuck counting visits without expanding //and we won't do any search. addCurrentNNOutputAsLeafValue(node,false); - return true; + return PLAYOUT_SUCCESS; } //Do we think we are searching a new child for the first time? @@ -1281,7 +1280,7 @@ bool Search::playoutDescend( //Even if the node was newly allocated, no need to delete the node, it will get cleaned up next time we mark and sweep the node table later. //Clean up virtual losses in case the node is a transposition and is being used. child->virtualLosses.fetch_add(-1,std::memory_order_release); - return false; + return PLAYOUT_FAILED; } } @@ -1291,11 +1290,11 @@ bool Search::playoutDescend( if(maybeCatchUpEdgeVisits(thread, node, child, nodeState, bestChildIdx, suppressEdgeVisit)) { if(suppressEdgeVisit) { child->virtualLosses.fetch_add(-1,std::memory_order_release); - return false; + return PLAYOUT_FAILED; } updateStatsAfterPlayout(node,thread,isRoot); child->virtualLosses.fetch_add(-1,std::memory_order_release); - return true; + return PLAYOUT_SUCCESS; } } //Searching an existing child @@ -1311,11 +1310,11 @@ bool Search::playoutDescend( if(maybeCatchUpEdgeVisits(thread, node, child, nodeState, bestChildIdx, suppressEdgeVisit)) { if(suppressEdgeVisit) { child->virtualLosses.fetch_add(-1,std::memory_order_release); - return false; + return PLAYOUT_FAILED; } updateStatsAfterPlayout(node,thread,isRoot); child->virtualLosses.fetch_add(-1,std::memory_order_release); - return true; + return PLAYOUT_SUCCESS; } //Make the move! 
@@ -1340,19 +1339,21 @@ bool Search::playoutDescend( SearchNodeChildrenReference children = node.getChildren(nodeState); if(suppressEdgeVisit) { child->virtualLosses.fetch_add(-1,std::memory_order_release); - return false; + return PLAYOUT_FAILED; } children[bestChildIdx].addEdgeVisits(1); updateStatsAfterPlayout(node,thread,isRoot); child->virtualLosses.fetch_add(-1,std::memory_order_release); - return true; + return PLAYOUT_SUCCESS; } } //Recurse! - bool finishedPlayout = playoutDescend(thread,*child,false); + PlayoutResult childPlayoutResult = playoutDescend(thread,*child,false); + PlayoutResult ourPlayoutResult = PLAYOUT_FAILED; + //Update this node stats - if(finishedPlayout) { + if(childPlayoutResult == PLAYOUT_NOINCREMENT || childPlayoutResult == PLAYOUT_SUCCESS) { if(searchParams.suppressVirtualLossHindsight) { double childUtilityAvg = node.stats.utilityAvg.load(std::memory_order_acquire); if(node.nextPla == P_WHITE) @@ -1360,16 +1361,21 @@ bool Search::playoutDescend( else suppressEdgeVisit = childUtilityAvg > suppressEdgeVisitUtilityThreshold; } + if(!suppressEdgeVisit) { nodeState = node.state.load(std::memory_order_acquire); SearchNodeChildrenReference children = node.getChildren(nodeState); children[bestChildIdx].addEdgeVisits(1); + ourPlayoutResult = PLAYOUT_SUCCESS; + } + else { + ourPlayoutResult = PLAYOUT_NOINCREMENT; } updateStatsAfterPlayout(node,thread,isRoot); } child->virtualLosses.fetch_add(-1,std::memory_order_release); - return finishedPlayout; + return ourPlayoutResult; } @@ -1404,18 +1410,20 @@ bool Search::maybeCatchUpEdgeVisits( if(searchParams.graphSearchCatchUpLeakProb > 0.0 && edgeVisits < childVisits && thread.rand.nextBool(searchParams.graphSearchCatchUpLeakProb)) return false; + if(edgeVisits >= childVisits) + return false; if(suppressEdgeVisit) - return true; + return !searchParams.suppressVirtualLossLeakCatchUp; //If the edge visits exceeds the child then we need to search the child more, but as long as that's not the case, //we 
can add more edge visits. constexpr int64_t numToAdd = 1; // int64_t numToAdd; - do { + while(!childPointer.compexweakEdgeVisits(edgeVisits, edgeVisits + numToAdd)) { if(edgeVisits >= childVisits) return false; // numToAdd = std::min((childVisits - edgeVisits + 3) / 4, maxNumToAdd); - } while(!childPointer.compexweakEdgeVisits(edgeVisits, edgeVisits + numToAdd)); + } return true; } diff --git a/cpp/search/search.h b/cpp/search/search.h index 5e8282f91..929ea3be2 100644 --- a/cpp/search/search.h +++ b/cpp/search/search.h @@ -24,6 +24,7 @@ #include "../external/nlohmann_json/json.hpp" typedef int SearchNodeState; // See SearchNode::STATE_* +typedef int PlayoutResult; struct SearchNode; struct SearchThread; @@ -253,9 +254,13 @@ struct Search { //Without performing a whole search, recompute the root nn output for any root-level parameters. void maybeRecomputeRootNNOutput(); + static constexpr PlayoutResult PLAYOUT_FAILED = 0; + static constexpr PlayoutResult PLAYOUT_SUCCESS = 1; + static constexpr PlayoutResult PLAYOUT_NOINCREMENT = 2; + //Expert manual playout-by-playout interface void beginSearch(bool pondering); - bool runSinglePlayout(SearchThread& thread, double upperBoundVisitsLeft); + PlayoutResult runSinglePlayout(SearchThread& thread, double upperBoundVisitsLeft); //================================================================================================================ // SEARCH RESULTS AND TREE INSPECTION METHODS @@ -637,7 +642,7 @@ struct Search { void computeRootValues(); // Helper for begin search void recursivelyRecomputeStats(SearchNode& node); // Helper for search initialization - bool playoutDescend( + PlayoutResult playoutDescend( SearchThread& thread, SearchNode& node, bool isRoot ); diff --git a/cpp/search/searchparams.cpp b/cpp/search/searchparams.cpp index 1e7eb881d..8f30f1b6b 100644 --- a/cpp/search/searchparams.cpp +++ b/cpp/search/searchparams.cpp @@ -79,6 +79,7 @@ SearchParams::SearchParams() numVirtualLossesPerThread(3.0), 
suppressVirtualLossExploreFactor(1e10), suppressVirtualLossHindsight(false), + suppressVirtualLossLeakCatchUp(false), numThreads(1), minPlayoutsPerThread(0.0), maxVisits(((int64_t)1) << 50), @@ -310,6 +311,7 @@ void SearchParams::printParams(std::ostream& out) { PRINTPARAM(numVirtualLossesPerThread); PRINTPARAM(suppressVirtualLossExploreFactor); PRINTPARAM(suppressVirtualLossHindsight); + PRINTPARAM(suppressVirtualLossLeakCatchUp); PRINTPARAM(numThreads); diff --git a/cpp/search/searchparams.h b/cpp/search/searchparams.h index e5f2fc0fa..151d2dd2c 100644 --- a/cpp/search/searchparams.h +++ b/cpp/search/searchparams.h @@ -112,6 +112,7 @@ struct SearchParams { double numVirtualLossesPerThread; //Number of virtual losses for one thread to add double suppressVirtualLossExploreFactor; //Suppress edge visit if virtual loss or wide root noise explores child, but scaling cpuct by this factor wouldn't explore it. bool suppressVirtualLossHindsight; //Suppression of edge visit uses the hindsight value on the child + bool suppressVirtualLossLeakCatchUp; //When suppressing edge visits, if child visits > edge visits, visit the child anyways and make it even greater. //Asyncbot int numThreads; //Number of threads