diff --git a/src/components/tl/shm/barrier/barrier.c b/src/components/tl/shm/barrier/barrier.c index 97b919ab40..aaa97622fd 100644 --- a/src/components/tl/shm/barrier/barrier.c +++ b/src/components/tl/shm/barrier/barrier.c @@ -16,14 +16,13 @@ enum BARRIER_STAGE_TOP_TREE_FANOUT, }; -static ucc_status_t ucc_tl_shm_barrier_progress(ucc_coll_task_t *coll_task) +static void ucc_tl_shm_barrier_progress(ucc_coll_task_t *coll_task) { ucc_tl_shm_task_t *task = ucc_derived_of(coll_task, ucc_tl_shm_task_t); ucc_tl_shm_team_t *team = TASK_TEAM(task); ucc_rank_t rank = UCC_TL_TEAM_RANK(team); ucc_tl_shm_seg_t * seg = task->seg; ucc_tl_shm_tree_t *tree = task->tree; - ucc_status_t status; ucc_tl_shm_ctrl_t *my_ctrl; next_stage: @@ -31,10 +30,8 @@ static ucc_status_t ucc_tl_shm_barrier_progress(ucc_coll_task_t *coll_task) case BARRIER_STAGE_START: /* checks if previous collective has completed on the seg TODO: can be optimized if we detect barrier->reduce pattern.*/ - if (UCC_OK != ucc_tl_shm_reduce_seg_ready(seg, task->seg_ready_seq_num, - team, tree)) { - return UCC_INPROGRESS; - } + SHMCHECK_GOTO(ucc_tl_shm_reduce_seg_ready(seg, task->seg_ready_seq_num, + team, tree), task, out); if (tree->base_tree) { task->stage = BARRIER_STAGE_BASE_TREE_FANIN; } else { @@ -42,11 +39,8 @@ static ucc_status_t ucc_tl_shm_barrier_progress(ucc_coll_task_t *coll_task) } goto next_stage; case BARRIER_STAGE_BASE_TREE_FANIN: - status = ucc_tl_shm_fanin_signal(team, seg, task, tree->base_tree); - if (UCC_OK != status) { - /* in progress */ - return status; - } + SHMCHECK_GOTO(ucc_tl_shm_fanin_signal(team, seg, task, tree->base_tree), + task, out); if (tree->top_tree) { task->stage = BARRIER_STAGE_TOP_TREE_FANIN; } else { @@ -55,56 +49,46 @@ static ucc_status_t ucc_tl_shm_barrier_progress(ucc_coll_task_t *coll_task) } goto next_stage; case BARRIER_STAGE_TOP_TREE_FANIN: - status = ucc_tl_shm_fanin_signal(team, seg, task, tree->top_tree); - if (UCC_OK != status) { - return status; - } + 
SHMCHECK_GOTO(ucc_tl_shm_fanin_signal(team, seg, task, tree->top_tree), + task, out); task->stage = BARRIER_STAGE_TOP_TREE_FANOUT; task->seq_num++; /* finished fanin, need seq_num to be updated for fanout */ goto next_stage; case BARRIER_STAGE_TOP_TREE_FANOUT: - status = ucc_tl_shm_fanout_signal(team, seg, task, tree->top_tree); - if (UCC_OK != status) { - return status; - } + SHMCHECK_GOTO(ucc_tl_shm_fanout_signal(team, seg, task, tree->top_tree), + task, out); if (tree->base_tree) { task->stage = BARRIER_STAGE_BASE_TREE_FANOUT; goto next_stage; } break; case BARRIER_STAGE_BASE_TREE_FANOUT: - status = ucc_tl_shm_fanout_signal(team, seg, task, tree->base_tree); - if (UCC_OK != status) { - return status; - } + SHMCHECK_GOTO(ucc_tl_shm_fanout_signal(team, seg, task, tree->base_tree), + task, out); break; } my_ctrl = ucc_tl_shm_get_ctrl(seg, team, rank); - /* task->seq_num was updated between fanin and fanout, now needs to be rewinded to fit general collectives order, as barrier is actually a single collective */ + /* task->seq_num was updated between fanin and fanout, now needs to be + rewinded to fit general collectives order, as barrier is actually + a single collective */ my_ctrl->ci = task->seq_num - 1; /* barrier done */ - task->super.super.status = UCC_OK; + task->super.status = UCC_OK; UCC_TL_SHM_PROFILE_REQUEST_EVENT(coll_task, "shm_barrier_progress_done", 0); - return UCC_OK; +out: + return; } static ucc_status_t ucc_tl_shm_barrier_start(ucc_coll_task_t *coll_task) { ucc_tl_shm_task_t *task = ucc_derived_of(coll_task, ucc_tl_shm_task_t); ucc_tl_shm_team_t *team = TASK_TEAM(task); - ucc_status_t status; UCC_TL_SHM_PROFILE_REQUEST_EVENT(coll_task, "shm_barrier_start", 0); UCC_TL_SHM_SET_SEG_READY_SEQ_NUM(task, team); - task->super.super.status = UCC_INPROGRESS; - status = task->super.progress(&task->super); - - if (UCC_INPROGRESS == status) { - ucc_progress_enqueue(UCC_TL_CORE_CTX(team)->pq, &task->super); - return UCC_OK; - } - return 
ucc_task_complete(coll_task); + task->super.status = UCC_INPROGRESS; + return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(team)->pq, &task->super); } ucc_status_t ucc_tl_shm_barrier_init(ucc_base_coll_args_t *coll_args, diff --git a/src/components/tl/shm/bcast/bcast.c b/src/components/tl/shm/bcast/bcast.c index 6f5bf533d5..de77179e5e 100644 --- a/src/components/tl/shm/bcast/bcast.c +++ b/src/components/tl/shm/bcast/bcast.c @@ -113,7 +113,7 @@ static ucc_status_t ucc_tl_shm_bcast_read(ucc_tl_shm_team_t *team, return UCC_INPROGRESS; } -static ucc_status_t ucc_tl_shm_bcast_progress(ucc_coll_task_t *coll_task) +static void ucc_tl_shm_bcast_progress(ucc_coll_task_t *coll_task) { ucc_tl_shm_task_t *task = ucc_derived_of(coll_task, ucc_tl_shm_task_t); ucc_tl_shm_team_t *team = TASK_TEAM(task); @@ -128,7 +128,6 @@ static ucc_status_t ucc_tl_shm_bcast_progress(ucc_coll_task_t *coll_task) int is_inline = data_size <= team->max_inline; int is_op_root = rank == root; ucc_tl_shm_ctrl_t *my_ctrl, *parent_ctrl; - ucc_status_t status; void * src; next_stage: @@ -138,10 +137,8 @@ static ucc_status_t ucc_tl_shm_bcast_progress(ucc_coll_task_t *coll_task) (tree->base_tree == NULL && tree->top_tree->n_children > 0)) { /* checks if previous collective has completed on the seg TODO: can be optimized if we detect bcast->reduce pattern.*/ - if (UCC_OK != ucc_tl_shm_bcast_seg_ready( - seg, task->seg_ready_seq_num, team, tree)) { - return UCC_INPROGRESS; - } + SHMCHECK_GOTO(ucc_tl_shm_bcast_seg_ready(seg, + task->seg_ready_seq_num, team, tree), task, out); } if (tree->top_tree) { task->stage = BCAST_STAGE_TOP_TREE; @@ -151,15 +148,11 @@ static ucc_status_t ucc_tl_shm_bcast_progress(ucc_coll_task_t *coll_task) goto next_stage; case BCAST_STAGE_TOP_TREE: if (task->progress_alg == BCAST_WW || task->progress_alg == BCAST_WR) { - status = ucc_tl_shm_bcast_write(team, seg, task, tree->top_tree, - is_inline, &is_op_root, data_size); + SHMCHECK_GOTO(ucc_tl_shm_bcast_write(team, seg, task, 
tree->top_tree, + is_inline, &is_op_root, data_size), task, out); } else { - status = ucc_tl_shm_bcast_read(team, seg, task, tree->top_tree, - is_inline, &is_op_root, data_size); - } - if (UCC_OK != status) { - /* in progress */ - return status; + SHMCHECK_GOTO(ucc_tl_shm_bcast_read(team, seg, task, tree->top_tree, + is_inline, &is_op_root, data_size), task, out); } if (tree->base_tree) { task->stage = BCAST_STAGE_BASE_TREE; @@ -168,15 +161,11 @@ static ucc_status_t ucc_tl_shm_bcast_progress(ucc_coll_task_t *coll_task) break; case BCAST_STAGE_BASE_TREE: if (task->progress_alg == BCAST_WW || task->progress_alg == BCAST_RW) { - status = ucc_tl_shm_bcast_write(team, seg, task, tree->base_tree, - is_inline, &is_op_root, data_size); + SHMCHECK_GOTO(ucc_tl_shm_bcast_write(team, seg, task, tree->base_tree, + is_inline, &is_op_root, data_size), task, out); } else { - status = ucc_tl_shm_bcast_read(team, seg, task, tree->base_tree, - is_inline, &is_op_root, data_size); - } - if (UCC_OK != status) { - /* in progress */ - return status; + SHMCHECK_GOTO(ucc_tl_shm_bcast_read(team, seg, task, tree->base_tree, + is_inline, &is_op_root, data_size), task, out); } break; } @@ -215,28 +204,22 @@ static ucc_status_t ucc_tl_shm_bcast_progress(ucc_coll_task_t *coll_task) my_ctrl->ci = task->seq_num; /* bcast done */ - task->super.super.status = UCC_OK; + task->super.status = UCC_OK; UCC_TL_SHM_PROFILE_REQUEST_EVENT(coll_task, "shm_bcast_rw_progress_done", 0); - return UCC_OK; +out: + return; } static ucc_status_t ucc_tl_shm_bcast_start(ucc_coll_task_t *coll_task) { ucc_tl_shm_task_t *task = ucc_derived_of(coll_task, ucc_tl_shm_task_t); ucc_tl_shm_team_t *team = TASK_TEAM(task); - ucc_status_t status; UCC_TL_SHM_PROFILE_REQUEST_EVENT(coll_task, "shm_bcast_start", 0); UCC_TL_SHM_SET_SEG_READY_SEQ_NUM(task, team); - task->super.super.status = UCC_INPROGRESS; - status = task->super.progress(&task->super); - - if (UCC_INPROGRESS == status) { - ucc_progress_enqueue(UCC_TL_CORE_CTX(team)->pq, 
&task->super); - return UCC_OK; - } - return ucc_task_complete(coll_task); + task->super.status = UCC_INPROGRESS; + return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(team)->pq, &task->super); } ucc_status_t ucc_tl_shm_bcast_init(ucc_base_coll_args_t *coll_args, diff --git a/src/components/tl/shm/fanin/fanin.c b/src/components/tl/shm/fanin/fanin.c index f240f8696c..a89f944708 100644 --- a/src/components/tl/shm/fanin/fanin.c +++ b/src/components/tl/shm/fanin/fanin.c @@ -14,14 +14,13 @@ enum FANIN_STAGE_TOP_TREE, }; -static ucc_status_t ucc_tl_shm_fanin_progress(ucc_coll_task_t *coll_task) +static void ucc_tl_shm_fanin_progress(ucc_coll_task_t *coll_task) { ucc_tl_shm_task_t *task = ucc_derived_of(coll_task, ucc_tl_shm_task_t); ucc_tl_shm_team_t *team = TASK_TEAM(task); ucc_rank_t rank = UCC_TL_TEAM_RANK(team); ucc_tl_shm_seg_t * seg = task->seg; ucc_tl_shm_tree_t *tree = task->tree; - ucc_status_t status; ucc_tl_shm_ctrl_t *my_ctrl; next_stage: @@ -29,10 +28,8 @@ static ucc_status_t ucc_tl_shm_fanin_progress(ucc_coll_task_t *coll_task) case FANIN_STAGE_START: /* checks if previous collective has completed on the seg TODO: can be optimized if we detect fanin->reduce pattern.*/ - if (UCC_OK != ucc_tl_shm_reduce_seg_ready(seg, task->seg_ready_seq_num, - team, tree)) { - return UCC_INPROGRESS; - } + SHMCHECK_GOTO(ucc_tl_shm_reduce_seg_ready(seg, task->seg_ready_seq_num, + team, tree), task, out); if (tree->base_tree) { task->stage = FANIN_STAGE_BASE_TREE; } else { @@ -40,51 +37,38 @@ static ucc_status_t ucc_tl_shm_fanin_progress(ucc_coll_task_t *coll_task) } goto next_stage; case FANIN_STAGE_BASE_TREE: - status = ucc_tl_shm_fanin_signal(team, seg, task, tree->base_tree); - if (UCC_OK != status) { - /* in progress */ - return status; - } + SHMCHECK_GOTO(ucc_tl_shm_fanin_signal(team, seg, task, tree->base_tree), + task, out); if (tree->top_tree) { task->stage = FANIN_STAGE_TOP_TREE; goto next_stage; } break; case FANIN_STAGE_TOP_TREE: - status = 
ucc_tl_shm_fanin_signal(team, seg, task, tree->top_tree); - - if (UCC_OK != status) { - /* in progress */ - return status; - } + SHMCHECK_GOTO(ucc_tl_shm_fanin_signal(team, seg, task, tree->top_tree), + task, out); break; } my_ctrl = ucc_tl_shm_get_ctrl(seg, team, rank); my_ctrl->ci = task->seq_num; /* fanin done */ - task->super.super.status = UCC_OK; + task->super.status = UCC_OK; UCC_TL_SHM_PROFILE_REQUEST_EVENT(coll_task, "shm_fanin_progress_done", 0); - return UCC_OK; +out: + return; } static ucc_status_t ucc_tl_shm_fanin_start(ucc_coll_task_t *coll_task) { ucc_tl_shm_task_t *task = ucc_derived_of(coll_task, ucc_tl_shm_task_t); ucc_tl_shm_team_t *team = TASK_TEAM(task); - ucc_status_t status; UCC_TL_SHM_PROFILE_REQUEST_EVENT(coll_task, "shm_fanin_start", 0); UCC_TL_SHM_SET_SEG_READY_SEQ_NUM(task, team); - task->super.super.status = UCC_INPROGRESS; - status = task->super.progress(&task->super); - - if (UCC_INPROGRESS == status) { - ucc_progress_enqueue(UCC_TL_CORE_CTX(team)->pq, &task->super); - return UCC_OK; - } - return ucc_task_complete(coll_task); -} + task->super.status = UCC_INPROGRESS; + return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(team)->pq, &task->super); + } ucc_status_t ucc_tl_shm_fanin_init(ucc_base_coll_args_t *coll_args, ucc_base_team_t * tl_team, diff --git a/src/components/tl/shm/fanout/fanout.c b/src/components/tl/shm/fanout/fanout.c index 2f5d77402e..581a6678b8 100644 --- a/src/components/tl/shm/fanout/fanout.c +++ b/src/components/tl/shm/fanout/fanout.c @@ -14,14 +14,13 @@ enum FANOUT_STAGE_TOP_TREE, }; -static ucc_status_t ucc_tl_shm_fanout_progress(ucc_coll_task_t *coll_task) +static void ucc_tl_shm_fanout_progress(ucc_coll_task_t *coll_task) { ucc_tl_shm_task_t *task = ucc_derived_of(coll_task, ucc_tl_shm_task_t); ucc_tl_shm_team_t *team = TASK_TEAM(task); ucc_rank_t rank = UCC_TL_TEAM_RANK(team); ucc_tl_shm_seg_t * seg = task->seg; ucc_tl_shm_tree_t *tree = task->tree; - ucc_status_t status; ucc_tl_shm_ctrl_t *my_ctrl; next_stage: 
@@ -32,10 +31,8 @@ static ucc_status_t ucc_tl_shm_fanout_progress(ucc_coll_task_t *coll_task) tree->top_tree->n_children > 0)) { //similar to bcast /* checks if previous collective has completed on the seg TODO: can be optimized if we detect bcast->reduce pattern.*/ - if (UCC_OK != ucc_tl_shm_bcast_seg_ready( - seg, task->seg_ready_seq_num, team, tree)) { - return UCC_INPROGRESS; - } + SHMCHECK_GOTO(ucc_tl_shm_bcast_seg_ready(seg, + task->seg_ready_seq_num, team, tree), task, out); } if (tree->top_tree) { task->stage = FANOUT_STAGE_TOP_TREE; @@ -44,50 +41,37 @@ static ucc_status_t ucc_tl_shm_fanout_progress(ucc_coll_task_t *coll_task) } goto next_stage; case FANOUT_STAGE_TOP_TREE: - status = ucc_tl_shm_fanout_signal(team, seg, task, tree->top_tree); - if (UCC_OK != status) { - /* in progress */ - return status; - } + SHMCHECK_GOTO(ucc_tl_shm_fanout_signal(team, seg, task, tree->top_tree), + task, out); if (tree->base_tree) { task->stage = FANOUT_STAGE_BASE_TREE; goto next_stage; } break; case FANOUT_STAGE_BASE_TREE: - status = ucc_tl_shm_fanout_signal(team, seg, task, tree->base_tree); - - if (UCC_OK != status) { - /* in progress */ - return status; - } + SHMCHECK_GOTO(ucc_tl_shm_fanout_signal(team, seg, task, tree->base_tree), + task, out); break; } my_ctrl = ucc_tl_shm_get_ctrl(seg, team, rank); my_ctrl->ci = task->seq_num; /* fanout done */ - task->super.super.status = UCC_OK; + task->super.status = UCC_OK; UCC_TL_SHM_PROFILE_REQUEST_EVENT(coll_task, "shm_fanout_progress_done", 0); - return UCC_OK; +out: + return; } static ucc_status_t ucc_tl_shm_fanout_start(ucc_coll_task_t *coll_task) { ucc_tl_shm_task_t *task = ucc_derived_of(coll_task, ucc_tl_shm_task_t); ucc_tl_shm_team_t *team = TASK_TEAM(task); - ucc_status_t status; UCC_TL_SHM_PROFILE_REQUEST_EVENT(coll_task, "shm_fanout_start", 0); UCC_TL_SHM_SET_SEG_READY_SEQ_NUM(task, team); - task->super.super.status = UCC_INPROGRESS; - status = task->super.progress(&task->super); - - if (UCC_INPROGRESS == status) { 
- ucc_progress_enqueue(UCC_TL_CORE_CTX(team)->pq, &task->super); - return UCC_OK; - } - return ucc_task_complete(coll_task); + task->super.status = UCC_INPROGRESS; + return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(team)->pq, &task->super); } ucc_status_t ucc_tl_shm_fanout_init(ucc_base_coll_args_t *coll_args, diff --git a/src/components/tl/shm/reduce/reduce.c b/src/components/tl/shm/reduce/reduce.c index 77592275ba..937cda3bbe 100644 --- a/src/components/tl/shm/reduce/reduce.c +++ b/src/components/tl/shm/reduce/reduce.c @@ -86,7 +86,7 @@ ucc_tl_shm_reduce_read(ucc_tl_shm_team_t *team, ucc_tl_shm_seg_t *seg, return UCC_OK; } -static ucc_status_t ucc_tl_shm_reduce_progress(ucc_coll_task_t *coll_task) +static void ucc_tl_shm_reduce_progress(ucc_coll_task_t *coll_task) { ucc_tl_shm_task_t *task = ucc_derived_of(coll_task, ucc_tl_shm_task_t); ucc_tl_shm_team_t *team = TASK_TEAM(task); @@ -101,7 +101,6 @@ static ucc_status_t ucc_tl_shm_reduce_progress(ucc_coll_task_t *coll_task) ucc_tl_shm_tree_t *tree = task->tree; int is_inline; int is_op_root = rank == root; - ucc_status_t status; ucc_tl_shm_ctrl_t *my_ctrl; if (is_op_root) { @@ -121,10 +120,8 @@ static ucc_status_t ucc_tl_shm_reduce_progress(ucc_coll_task_t *coll_task) case REDUCE_STAGE_START: /* checks if previous collective has completed on the seg TODO: can be optimized if we detect bcast->reduce pattern.*/ - if (UCC_OK != ucc_tl_shm_reduce_seg_ready(seg, task->seg_ready_seq_num, - team, tree)) { - return UCC_INPROGRESS; - } + SHMCHECK_GOTO(ucc_tl_shm_reduce_seg_ready(seg, task->seg_ready_seq_num, + team, tree), task, out); if (tree->base_tree) { task->stage = REDUCE_STAGE_BASE_TREE; } else { @@ -132,13 +129,8 @@ static ucc_status_t ucc_tl_shm_reduce_progress(ucc_coll_task_t *coll_task) } goto next_stage; case REDUCE_STAGE_BASE_TREE: - status = ucc_tl_shm_reduce_read(team, seg, task, tree->base_tree, - is_inline, count, dt, mtype, &args); - - if (UCC_OK != status) { - /* in progress or reduction failed */ - 
return status; - } + SHMCHECK_GOTO(ucc_tl_shm_reduce_read(team, seg, task, tree->base_tree, + is_inline, count, dt, mtype, &args), task, out); task->cur_child = 0; if (tree->top_tree) { task->stage = REDUCE_STAGE_TOP_TREE; @@ -146,12 +138,8 @@ static ucc_status_t ucc_tl_shm_reduce_progress(ucc_coll_task_t *coll_task) } break; case REDUCE_STAGE_TOP_TREE: - status = ucc_tl_shm_reduce_read(team, seg, task, tree->top_tree, - is_inline, count, dt, mtype, &args); - if (UCC_OK != status) { - /* in progress or reduction failed */ - return status; - } + SHMCHECK_GOTO(ucc_tl_shm_reduce_read(team, seg, task, tree->top_tree, + is_inline, count, dt, mtype, &args), task, out); break; } @@ -159,27 +147,21 @@ static ucc_status_t ucc_tl_shm_reduce_progress(ucc_coll_task_t *coll_task) my_ctrl->ci = task->seq_num; /* reduce done */ - task->super.super.status = UCC_OK; + task->super.status = UCC_OK; UCC_TL_SHM_PROFILE_REQUEST_EVENT(coll_task, "shm_reduce_rr_done", 0); - return UCC_OK; +out: + return; } static ucc_status_t ucc_tl_shm_reduce_start(ucc_coll_task_t *coll_task) { ucc_tl_shm_task_t *task = ucc_derived_of(coll_task, ucc_tl_shm_task_t); ucc_tl_shm_team_t *team = TASK_TEAM(task); - ucc_status_t status; UCC_TL_SHM_PROFILE_REQUEST_EVENT(coll_task, "shm_reduce_start", 0); UCC_TL_SHM_SET_SEG_READY_SEQ_NUM(task, team); - task->super.super.status = UCC_INPROGRESS; - status = task->super.progress(&task->super); - - if (UCC_INPROGRESS == status) { - ucc_progress_enqueue(UCC_TL_CORE_CTX(team)->pq, &task->super); - return UCC_OK; - } - return ucc_task_complete(coll_task); + task->super.status = UCC_INPROGRESS; + return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(team)->pq, &task->super); } ucc_status_t ucc_tl_shm_reduce_init(ucc_base_coll_args_t *coll_args, diff --git a/src/components/tl/shm/tl_shm_coll.h b/src/components/tl/shm/tl_shm_coll.h index 98f58dd97f..cce97c8dbf 100644 --- a/src/components/tl/shm/tl_shm_coll.h +++ b/src/components/tl/shm/tl_shm_coll.h @@ -85,6 +85,15 @@ 
ucc_tl_shm_get_data(ucc_tl_shm_seg_t *seg, ucc_tl_shm_team_t *team, return PTR_OFFSET(seg->data, data_size * rank); } +#define SHMCHECK_GOTO(_cmd, _task, _label) \ + do { \ + ucc_status_t _status = (_cmd); /* _cmd is evaluated exactly once */ \ + if (UCC_OK != _status) { \ + (_task)->super.status = _status; /* store the non-OK status on the task */ \ + goto _label; \ + } \ + } while (0) + static inline ucc_status_t ucc_tl_shm_bcast_seg_ready(ucc_tl_shm_seg_t *seg, uint32_t seq_num, ucc_tl_shm_team_t *team,