Skip to content

Commit

Permalink
TL/SHM: progress api change
Browse files Browse the repository at this point in the history
  • Loading branch information
Valentin Petrov committed Mar 31, 2022
1 parent 1de4a87 commit d649161
Show file tree
Hide file tree
Showing 6 changed files with 81 additions and 155 deletions.
54 changes: 19 additions & 35 deletions src/components/tl/shm/barrier/barrier.c
Original file line number Diff line number Diff line change
Expand Up @@ -16,37 +16,31 @@ enum
BARRIER_STAGE_TOP_TREE_FANOUT,
};

static ucc_status_t ucc_tl_shm_barrier_progress(ucc_coll_task_t *coll_task)
static void ucc_tl_shm_barrier_progress(ucc_coll_task_t *coll_task)
{
ucc_tl_shm_task_t *task = ucc_derived_of(coll_task, ucc_tl_shm_task_t);
ucc_tl_shm_team_t *team = TASK_TEAM(task);
ucc_rank_t rank = UCC_TL_TEAM_RANK(team);
ucc_tl_shm_seg_t * seg = task->seg;
ucc_tl_shm_tree_t *tree = task->tree;
ucc_status_t status;
ucc_tl_shm_ctrl_t *my_ctrl;

next_stage:
switch (task->stage) {
case BARRIER_STAGE_START:
/* checks if previous collective has completed on the seg
TODO: can be optimized if we detect barrier->reduce pattern.*/
if (UCC_OK != ucc_tl_shm_reduce_seg_ready(seg, task->seg_ready_seq_num,
team, tree)) {
return UCC_INPROGRESS;
}
SHMCHECK_GOTO(ucc_tl_shm_reduce_seg_ready(seg, task->seg_ready_seq_num,
team, tree), task, out);
if (tree->base_tree) {
task->stage = BARRIER_STAGE_BASE_TREE_FANIN;
} else {
task->stage = BARRIER_STAGE_TOP_TREE_FANIN;
}
goto next_stage;
case BARRIER_STAGE_BASE_TREE_FANIN:
status = ucc_tl_shm_fanin_signal(team, seg, task, tree->base_tree);
if (UCC_OK != status) {
/* in progress */
return status;
}
SHMCHECK_GOTO(ucc_tl_shm_fanin_signal(team, seg, task, tree->base_tree),
task, out);
if (tree->top_tree) {
task->stage = BARRIER_STAGE_TOP_TREE_FANIN;
} else {
Expand All @@ -55,56 +49,46 @@ static ucc_status_t ucc_tl_shm_barrier_progress(ucc_coll_task_t *coll_task)
}
goto next_stage;
case BARRIER_STAGE_TOP_TREE_FANIN:
status = ucc_tl_shm_fanin_signal(team, seg, task, tree->top_tree);
if (UCC_OK != status) {
return status;
}
SHMCHECK_GOTO(ucc_tl_shm_fanin_signal(team, seg, task, tree->top_tree),
task, out);
task->stage = BARRIER_STAGE_TOP_TREE_FANOUT;
task->seq_num++; /* finished fanin, need seq_num to be updated for fanout */
goto next_stage;
case BARRIER_STAGE_TOP_TREE_FANOUT:
status = ucc_tl_shm_fanout_signal(team, seg, task, tree->top_tree);
if (UCC_OK != status) {
return status;
}
SHMCHECK_GOTO(ucc_tl_shm_fanout_signal(team, seg, task, tree->top_tree),
task, out);
if (tree->base_tree) {
task->stage = BARRIER_STAGE_BASE_TREE_FANOUT;
goto next_stage;
}
break;
case BARRIER_STAGE_BASE_TREE_FANOUT:
status = ucc_tl_shm_fanout_signal(team, seg, task, tree->base_tree);
if (UCC_OK != status) {
return status;
}
SHMCHECK_GOTO(ucc_tl_shm_fanout_signal(team, seg, task, tree->base_tree),
task, out);
break;
}

my_ctrl = ucc_tl_shm_get_ctrl(seg, team, rank);
/* task->seq_num was updated between fanin and fanout, now needs to be rewinded to fit general collectives order, as barrier is actually a single collective */
/* task->seq_num was updated between fanin and fanout, now needs to be
rewinded to fit general collectives order, as barrier is actually
a single collective */
my_ctrl->ci = task->seq_num - 1;
/* barrier done */
task->super.super.status = UCC_OK;
task->super.status = UCC_OK;
UCC_TL_SHM_PROFILE_REQUEST_EVENT(coll_task, "shm_barrier_progress_done", 0);
return UCC_OK;
out:
return;
}

static ucc_status_t ucc_tl_shm_barrier_start(ucc_coll_task_t *coll_task)
{
ucc_tl_shm_task_t *task = ucc_derived_of(coll_task, ucc_tl_shm_task_t);
ucc_tl_shm_team_t *team = TASK_TEAM(task);
ucc_status_t status;

UCC_TL_SHM_PROFILE_REQUEST_EVENT(coll_task, "shm_barrier_start", 0);
UCC_TL_SHM_SET_SEG_READY_SEQ_NUM(task, team);
task->super.super.status = UCC_INPROGRESS;
status = task->super.progress(&task->super);

if (UCC_INPROGRESS == status) {
ucc_progress_enqueue(UCC_TL_CORE_CTX(team)->pq, &task->super);
return UCC_OK;
}
return ucc_task_complete(coll_task);
task->super.status = UCC_INPROGRESS;
return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(team)->pq, &task->super);
}

ucc_status_t ucc_tl_shm_barrier_init(ucc_base_coll_args_t *coll_args,
Expand Down
49 changes: 16 additions & 33 deletions src/components/tl/shm/bcast/bcast.c
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ static ucc_status_t ucc_tl_shm_bcast_read(ucc_tl_shm_team_t *team,
return UCC_INPROGRESS;
}

static ucc_status_t ucc_tl_shm_bcast_progress(ucc_coll_task_t *coll_task)
static void ucc_tl_shm_bcast_progress(ucc_coll_task_t *coll_task)
{
ucc_tl_shm_task_t *task = ucc_derived_of(coll_task, ucc_tl_shm_task_t);
ucc_tl_shm_team_t *team = TASK_TEAM(task);
Expand All @@ -128,7 +128,6 @@ static ucc_status_t ucc_tl_shm_bcast_progress(ucc_coll_task_t *coll_task)
int is_inline = data_size <= team->max_inline;
int is_op_root = rank == root;
ucc_tl_shm_ctrl_t *my_ctrl, *parent_ctrl;
ucc_status_t status;
void * src;

next_stage:
Expand All @@ -138,10 +137,8 @@ static ucc_status_t ucc_tl_shm_bcast_progress(ucc_coll_task_t *coll_task)
(tree->base_tree == NULL && tree->top_tree->n_children > 0)) {
/* checks if previous collective has completed on the seg
TODO: can be optimized if we detect bcast->reduce pattern.*/
if (UCC_OK != ucc_tl_shm_bcast_seg_ready(
seg, task->seg_ready_seq_num, team, tree)) {
return UCC_INPROGRESS;
}
SHMCHECK_GOTO(ucc_tl_shm_bcast_seg_ready(seg,
task->seg_ready_seq_num, team, tree), task, out);
}
if (tree->top_tree) {
task->stage = BCAST_STAGE_TOP_TREE;
Expand All @@ -151,15 +148,11 @@ static ucc_status_t ucc_tl_shm_bcast_progress(ucc_coll_task_t *coll_task)
goto next_stage;
case BCAST_STAGE_TOP_TREE:
if (task->progress_alg == BCAST_WW || task->progress_alg == BCAST_WR) {
status = ucc_tl_shm_bcast_write(team, seg, task, tree->top_tree,
is_inline, &is_op_root, data_size);
SHMCHECK_GOTO(ucc_tl_shm_bcast_write(team, seg, task, tree->top_tree,
is_inline, &is_op_root, data_size), task, out);
} else {
status = ucc_tl_shm_bcast_read(team, seg, task, tree->top_tree,
is_inline, &is_op_root, data_size);
}
if (UCC_OK != status) {
/* in progress */
return status;
SHMCHECK_GOTO(ucc_tl_shm_bcast_read(team, seg, task, tree->top_tree,
is_inline, &is_op_root, data_size), task, out);
}
if (tree->base_tree) {
task->stage = BCAST_STAGE_BASE_TREE;
Expand All @@ -168,15 +161,11 @@ static ucc_status_t ucc_tl_shm_bcast_progress(ucc_coll_task_t *coll_task)
break;
case BCAST_STAGE_BASE_TREE:
if (task->progress_alg == BCAST_WW || task->progress_alg == BCAST_RW) {
status = ucc_tl_shm_bcast_write(team, seg, task, tree->base_tree,
is_inline, &is_op_root, data_size);
SHMCHECK_GOTO(ucc_tl_shm_bcast_write(team, seg, task, tree->base_tree,
is_inline, &is_op_root, data_size), task, out);
} else {
status = ucc_tl_shm_bcast_read(team, seg, task, tree->base_tree,
is_inline, &is_op_root, data_size);
}
if (UCC_OK != status) {
/* in progress */
return status;
SHMCHECK_GOTO(ucc_tl_shm_bcast_read(team, seg, task, tree->base_tree,
is_inline, &is_op_root, data_size), task, out);
}
break;
}
Expand Down Expand Up @@ -215,28 +204,22 @@ static ucc_status_t ucc_tl_shm_bcast_progress(ucc_coll_task_t *coll_task)

my_ctrl->ci = task->seq_num;
/* bcast done */
task->super.super.status = UCC_OK;
task->super.status = UCC_OK;
UCC_TL_SHM_PROFILE_REQUEST_EVENT(coll_task, "shm_bcast_rw_progress_done",
0);
return UCC_OK;
out:
return;
}

static ucc_status_t ucc_tl_shm_bcast_start(ucc_coll_task_t *coll_task)
{
ucc_tl_shm_task_t *task = ucc_derived_of(coll_task, ucc_tl_shm_task_t);
ucc_tl_shm_team_t *team = TASK_TEAM(task);
ucc_status_t status;

UCC_TL_SHM_PROFILE_REQUEST_EVENT(coll_task, "shm_bcast_start", 0);
UCC_TL_SHM_SET_SEG_READY_SEQ_NUM(task, team);
task->super.super.status = UCC_INPROGRESS;
status = task->super.progress(&task->super);

if (UCC_INPROGRESS == status) {
ucc_progress_enqueue(UCC_TL_CORE_CTX(team)->pq, &task->super);
return UCC_OK;
}
return ucc_task_complete(coll_task);
task->super.status = UCC_INPROGRESS;
return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(team)->pq, &task->super);
}

ucc_status_t ucc_tl_shm_bcast_init(ucc_base_coll_args_t *coll_args,
Expand Down
42 changes: 13 additions & 29 deletions src/components/tl/shm/fanin/fanin.c
Original file line number Diff line number Diff line change
Expand Up @@ -14,77 +14,61 @@ enum
FANIN_STAGE_TOP_TREE,
};

static ucc_status_t ucc_tl_shm_fanin_progress(ucc_coll_task_t *coll_task)
static void ucc_tl_shm_fanin_progress(ucc_coll_task_t *coll_task)
{
ucc_tl_shm_task_t *task = ucc_derived_of(coll_task, ucc_tl_shm_task_t);
ucc_tl_shm_team_t *team = TASK_TEAM(task);
ucc_rank_t rank = UCC_TL_TEAM_RANK(team);
ucc_tl_shm_seg_t * seg = task->seg;
ucc_tl_shm_tree_t *tree = task->tree;
ucc_status_t status;
ucc_tl_shm_ctrl_t *my_ctrl;

next_stage:
switch (task->stage) {
case FANIN_STAGE_START:
/* checks if previous collective has completed on the seg
TODO: can be optimized if we detect fanin->reduce pattern.*/
if (UCC_OK != ucc_tl_shm_reduce_seg_ready(seg, task->seg_ready_seq_num,
team, tree)) {
return UCC_INPROGRESS;
}
SHMCHECK_GOTO(ucc_tl_shm_reduce_seg_ready(seg, task->seg_ready_seq_num,
team, tree), task, out);
if (tree->base_tree) {
task->stage = FANIN_STAGE_BASE_TREE;
} else {
task->stage = FANIN_STAGE_TOP_TREE;
}
goto next_stage;
case FANIN_STAGE_BASE_TREE:
status = ucc_tl_shm_fanin_signal(team, seg, task, tree->base_tree);
if (UCC_OK != status) {
/* in progress */
return status;
}
SHMCHECK_GOTO(ucc_tl_shm_fanin_signal(team, seg, task, tree->base_tree),
task, out);
if (tree->top_tree) {
task->stage = FANIN_STAGE_TOP_TREE;
goto next_stage;
}
break;
case FANIN_STAGE_TOP_TREE:
status = ucc_tl_shm_fanin_signal(team, seg, task, tree->top_tree);

if (UCC_OK != status) {
/* in progress */
return status;
}
SHMCHECK_GOTO(ucc_tl_shm_fanin_signal(team, seg, task, tree->top_tree),
task, out);
break;
}

my_ctrl = ucc_tl_shm_get_ctrl(seg, team, rank);
my_ctrl->ci = task->seq_num;
/* fanin done */
task->super.super.status = UCC_OK;
task->super.status = UCC_OK;
UCC_TL_SHM_PROFILE_REQUEST_EVENT(coll_task, "shm_fanin_progress_done", 0);
return UCC_OK;
out:
return;
}

static ucc_status_t ucc_tl_shm_fanin_start(ucc_coll_task_t *coll_task)
{
ucc_tl_shm_task_t *task = ucc_derived_of(coll_task, ucc_tl_shm_task_t);
ucc_tl_shm_team_t *team = TASK_TEAM(task);
ucc_status_t status;

UCC_TL_SHM_PROFILE_REQUEST_EVENT(coll_task, "shm_fanin_start", 0);
UCC_TL_SHM_SET_SEG_READY_SEQ_NUM(task, team);
task->super.super.status = UCC_INPROGRESS;
status = task->super.progress(&task->super);

if (UCC_INPROGRESS == status) {
ucc_progress_enqueue(UCC_TL_CORE_CTX(team)->pq, &task->super);
return UCC_OK;
}
return ucc_task_complete(coll_task);
}
task->super.status = UCC_INPROGRESS;
return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(team)->pq, &task->super);
}

ucc_status_t ucc_tl_shm_fanin_init(ucc_base_coll_args_t *coll_args,
ucc_base_team_t * tl_team,
Expand Down
40 changes: 12 additions & 28 deletions src/components/tl/shm/fanout/fanout.c
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,13 @@ enum
FANOUT_STAGE_TOP_TREE,
};

static ucc_status_t ucc_tl_shm_fanout_progress(ucc_coll_task_t *coll_task)
static void ucc_tl_shm_fanout_progress(ucc_coll_task_t *coll_task)
{
ucc_tl_shm_task_t *task = ucc_derived_of(coll_task, ucc_tl_shm_task_t);
ucc_tl_shm_team_t *team = TASK_TEAM(task);
ucc_rank_t rank = UCC_TL_TEAM_RANK(team);
ucc_tl_shm_seg_t * seg = task->seg;
ucc_tl_shm_tree_t *tree = task->tree;
ucc_status_t status;
ucc_tl_shm_ctrl_t *my_ctrl;

next_stage:
Expand All @@ -32,10 +31,8 @@ static ucc_status_t ucc_tl_shm_fanout_progress(ucc_coll_task_t *coll_task)
tree->top_tree->n_children > 0)) { //similar to bcast
/* checks if previous collective has completed on the seg
TODO: can be optimized if we detect bcast->reduce pattern.*/
if (UCC_OK != ucc_tl_shm_bcast_seg_ready(
seg, task->seg_ready_seq_num, team, tree)) {
return UCC_INPROGRESS;
}
SHMCHECK_GOTO(ucc_tl_shm_bcast_seg_ready(seg,
task->seg_ready_seq_num, team, tree), task, out);
}
if (tree->top_tree) {
task->stage = FANOUT_STAGE_TOP_TREE;
Expand All @@ -44,50 +41,37 @@ static ucc_status_t ucc_tl_shm_fanout_progress(ucc_coll_task_t *coll_task)
}
goto next_stage;
case FANOUT_STAGE_TOP_TREE:
status = ucc_tl_shm_fanout_signal(team, seg, task, tree->top_tree);
if (UCC_OK != status) {
/* in progress */
return status;
}
SHMCHECK_GOTO(ucc_tl_shm_fanout_signal(team, seg, task, tree->top_tree),
task, out);
if (tree->base_tree) {
task->stage = FANOUT_STAGE_BASE_TREE;
goto next_stage;
}
break;
case FANOUT_STAGE_BASE_TREE:
status = ucc_tl_shm_fanout_signal(team, seg, task, tree->base_tree);

if (UCC_OK != status) {
/* in progress */
return status;
}
SHMCHECK_GOTO(ucc_tl_shm_fanout_signal(team, seg, task, tree->base_tree),
task, out);
break;
}

my_ctrl = ucc_tl_shm_get_ctrl(seg, team, rank);
my_ctrl->ci = task->seq_num;
/* fanout done */
task->super.super.status = UCC_OK;
task->super.status = UCC_OK;
UCC_TL_SHM_PROFILE_REQUEST_EVENT(coll_task, "shm_fanout_progress_done", 0);
return UCC_OK;
out:
return;
}

static ucc_status_t ucc_tl_shm_fanout_start(ucc_coll_task_t *coll_task)
{
ucc_tl_shm_task_t *task = ucc_derived_of(coll_task, ucc_tl_shm_task_t);
ucc_tl_shm_team_t *team = TASK_TEAM(task);
ucc_status_t status;

UCC_TL_SHM_PROFILE_REQUEST_EVENT(coll_task, "shm_fanout_start", 0);
UCC_TL_SHM_SET_SEG_READY_SEQ_NUM(task, team);
task->super.super.status = UCC_INPROGRESS;
status = task->super.progress(&task->super);

if (UCC_INPROGRESS == status) {
ucc_progress_enqueue(UCC_TL_CORE_CTX(team)->pq, &task->super);
return UCC_OK;
}
return ucc_task_complete(coll_task);
task->super.status = UCC_INPROGRESS;
return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(team)->pq, &task->super);
}

ucc_status_t ucc_tl_shm_fanout_init(ucc_base_coll_args_t *coll_args,
Expand Down
Loading

0 comments on commit d649161

Please sign in to comment.