@@ -60,14 +60,6 @@ void ucc_tl_ucp_allgather_knomial_progress(ucc_coll_task_t *coll_task)
6060 ucc_status_t status ;
6161 size_t extra_count ;
6262
63- uint32_t USE_CUDA = UCC_TL_UCP_TEAM_LIB (team )-> cfg .allgather_use_cuda ;
64- if (!USE_CUDA ){
65- if (UCC_INPROGRESS == ucc_tl_ucp_test (task )){
66- // should I use ucc_tl_ucp_test_with_etasks ?
67- return ;
68- }
69- }
70-
7163 EXEC_TASK_TEST (UCC_KN_PHASE_INIT , "failed during ee task test" ,
7264 task -> allgather_kn .etask );
7365 task -> allgather_kn .etask = NULL ;
@@ -208,6 +200,8 @@ ucc_status_t ucc_tl_ucp_allgather_knomial_start(ucc_coll_task_t *coll_task)
208200 ucc_status_t status ;
209201 ptrdiff_t offset ;
210202 ucc_ee_executor_t * exec ;
203+ int use_loopback = UCC_TL_UCP_TEAM_LIB (team )-> cfg .allgather_use_loopback ;
204+
211205
212206 UCC_TL_UCP_PROFILE_REQUEST_EVENT (coll_task , "ucp_allgather_kn_start" , 0 );
213207 ucc_tl_ucp_task_reset (task , UCC_INPROGRESS );
@@ -218,29 +212,28 @@ ucc_status_t ucc_tl_ucp_allgather_knomial_start(ucc_coll_task_t *coll_task)
218212 & task -> allgather_kn .p );
219213 offset = ucc_buffer_block_offset (args -> dst .info .count , size , rank ) *
220214 ucc_dt_size (args -> dst .info .datatype );
221- if (USE_CUDA ){
222- status = ucc_coll_task_get_executor (& task -> super , & exec );
223- if (ucc_unlikely (status != UCC_OK )) {
224- task -> super .status = status ;
225- return status ;
226- }
227- eargs .task_type = UCC_EE_EXECUTOR_TASK_COPY ;
228- eargs .copy .dst = PTR_OFFSET (args -> dst .info .buffer , offset );
229- eargs .copy .src = args -> src .info .buffer ;
230- eargs .copy .len = args -> src .info .count *
231- ucc_dt_size (args -> src .info .datatype );
232- status = ucc_ee_executor_task_post (exec , & eargs ,
233- & task -> allgather_kn .etask );
234- if (ucc_unlikely (status != UCC_OK )) {
235- task -> super .status = status ;
236- return status ;
237- }
238- } else {
239- /*Loopback*/
215+ if (use_loopback ){
240216 UCPCHECK_GOTO (ucc_tl_ucp_send_nb (args -> src .info .buffer , args -> src .info .count * ucc_dt_size (args -> src .info .datatype ),
241- args -> src .info .mem_type , rank , team , task ),task , out );
217+ args -> src .info .mem_type , rank , team , task ),task , out );
242218 UCPCHECK_GOTO (ucc_tl_ucp_recv_nb (PTR_OFFSET (args -> dst .info .buffer , offset ), args -> src .info .count * ucc_dt_size (args -> src .info .datatype ),
243219 args -> dst .info .mem_type , rank , team , task ),task , out );
220+ } else {
221+ status = ucc_coll_task_get_executor (& task -> super , & exec );
222+ if (ucc_unlikely (status != UCC_OK )) {
223+ task -> super .status = status ;
224+ return status ;
225+ }
226+ eargs .task_type = UCC_EE_EXECUTOR_TASK_COPY ;
227+ eargs .copy .dst = PTR_OFFSET (args -> dst .info .buffer , offset );
228+ eargs .copy .src = args -> src .info .buffer ;
229+ eargs .copy .len = args -> src .info .count *
230+ ucc_dt_size (args -> src .info .datatype );
231+ status = ucc_ee_executor_task_post (exec , & eargs ,
232+ & task -> allgather_kn .etask );
233+ if (ucc_unlikely (status != UCC_OK )) {
234+ task -> super .status = status ;
235+ return status ;
236+ }
244237 }
245238 } else {
246239 ucc_kn_agx_pattern_init (size , rank , radix , args -> dst .info .count ,
@@ -418,12 +411,12 @@ ucc_status_t ucc_tl_ucp_allgather_knomial_init_r(
418411 ucc_base_coll_args_t * coll_args , ucc_base_team_t * team ,
419412 ucc_coll_task_t * * task_h , ucc_kn_radix_t radix )
420413{
421- ucc_tl_ucp_team_t * tl_team = ucc_derived_of (team , ucc_tl_ucp_team_t );
422- ucc_tl_ucp_task_t * task ;
414+ ucc_tl_ucp_team_t * tl_team = ucc_derived_of (team , ucc_tl_ucp_team_t );
415+ ucc_tl_ucp_task_t * task = ucc_tl_ucp_init_task ( coll_args , team ) ;
423416 ucc_sbgp_t * sbgp ;
424417 ucc_status_t status ;
425- int use_loopback = UCC_TL_UCP_TEAM_LIB ( team ) -> cfg . allgather_use_loopback ;
426- task = ucc_tl_ucp_init_task ( coll_args , team ) ;
418+ ucc_tl_ucp_team_t * team_loopback = TASK_TEAM ( task ) ;
419+ int use_loopback = UCC_TL_UCP_TEAM_LIB ( team_loopback ) -> cfg . allgather_use_loopback ;
427420 status = ucc_mpool_init (& task -> allgather_kn .etask_node_mpool , 0 , sizeof (node_ucc_ee_executor_task_t ),
428421 0 , UCC_CACHE_LINE_SIZE , 16 , UINT_MAX , NULL ,
429422 tl_team -> super .super .context -> ucc_context -> thread_mode , "etasks_linked_list_nodes" );
0 commit comments