Merge branch 'save-restore-function' into spec_future
SleepyMug committed Mar 21, 2024
2 parents 79a7546 + 113e321 commit ce32bff
Showing 2 changed files with 65 additions and 8 deletions.
50 changes: 42 additions & 8 deletions compiler/shell_ast/ast_to_ast.py
@@ -6,7 +6,7 @@

from env_var_names import *
from shell_ast.ast_util import *
from shasta.ast_node import ast_match, is_empty_cmd, string_of_arg
from shasta.ast_node import ast_match, is_empty_cmd, string_of_arg, BArgChar
from shasta.json_to_ast import to_ast_node
from parse import from_ast_objects_to_shell
from speculative import util_spec
@@ -41,6 +41,9 @@ def __init__(self, num: int):

def add_command(self, command):
self._is_emtpy = False

def add_assignment(self, assignment):
self._is_emtpy = False

def make_non_empty(self):
self._is_emtpy = False
@@ -148,6 +151,9 @@ def add_break(self):

def add_command(self, command):
self.bbs[self.current_bb].add_command(command)

def add_var_assignment(self, assignment):
self.bbs[self.current_bb].add_assignment(assignment)

## Use this object to pass options inside the preprocessing
## transformation.
@@ -157,6 +163,8 @@ def __init__(self, mode: TransformationType):
self.node_counter = 0
self.loop_counter = 0
self.loop_contexts = []
self.var_counter = 0
self.var_contexts = []
self.prog = ShellProg()

def get_mode(self):
@@ -190,6 +198,12 @@ def get_current_loop_id(self):
else:
return self.loop_contexts[0]

def get_number_of_var_assignments(self):
return self.var_counter

def get_var_nodes(self):
return self.var_contexts[:]

def current_bb(self):
return self.prog.current_bb

@@ -215,6 +229,11 @@ def exit_if(self):
def visit_command(self, command):
if len(command.arguments) > 0 and string_of_arg(command.arguments[0]) == 'break':
self.prog.add_break()
elif len(command.arguments) == 0 and len(command.assignments) > 0 and not contains_command_substitution(command):
self.prog.add_var_assignment(command)
## GL: HACK to get ids right
self.var_contexts.append(self.get_current_id() + 1)
self.var_counter += 1
else:
self.prog.add_command(command)

@@ -289,7 +308,17 @@ def get_all_node_bb(self):
lambda ast_node: preprocess_node_case(ast_node, trans_options, last_object=last_object))
}


# Checks whether any var assignment value contains a BArgChar (command substitution)
def contains_command_substitution(ast_node):
if len(ast_node.assignments) == 0:
return False
for assignment in ast_node.assignments:
if len(assignment.val) == 0:
return False
for val in assignment.val:
if type(val) == BArgChar:
return True
return False
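
To make the new routing concrete: a minimal, self-contained sketch with hypothetical stand-in classes in place of shasta's real arg-char nodes (BArgChar is assumed here to model a command substitution such as $(date)); the early returns of the committed version are simplified away.

class BArgChar:             # stand-in for shasta's command-substitution arg char
    pass

class CArgChar:             # stand-in for a plain literal character
    pass

class FakeAssignment:
    def __init__(self, var, val):
        self.var = var
        self.val = val      # list of arg-char nodes making up the assigned value

class FakeCommandNode:
    def __init__(self, assignments):
        self.assignments = assignments

def contains_command_substitution_sketch(ast_node):
    # Simplified form of the check above: any assignment whose value
    # contains a BArgChar counts as a command substitution.
    for assignment in ast_node.assignments:
        for val in assignment.val:
            if type(val) == BArgChar:
                return True
    return False

plain = FakeCommandNode([FakeAssignment('x', [CArgChar()])])   # x=1
subst = FakeCommandNode([FakeAssignment('x', [BArgChar()])])   # x=$(date)
print(contains_command_substitution_sketch(plain))   # False -> visit_command tracks it via add_var_assignment
print(contains_command_substitution_sketch(subst))   # True  -> visit_command keeps it as a regular command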

## Replace candidate dataflow AST regions with calls to PaSh's runtime.
def replace_ast_regions(ast_objects, trans_options):
@@ -435,16 +464,21 @@ def preprocess_node_command(ast_node, trans_options, last_object=False):
## If there are no arguments, the command is just an
## assignment (Q: or just redirections?)
trans_options : TransformationState
if(len(ast_node.arguments) == 0):

if trans_options.get_mode() is TransformationType.PASH \
and (len(ast_node.arguments) == 0):
preprocessed_ast_object = PreprocessedAST(ast_node,
replace_whole=False,
non_maximal=False,
something_replaced=False,
last_ast=last_object)
replace_whole=False,
non_maximal=False,
something_replaced=False,
last_ast=last_object)
return preprocessed_ast_object

## This means we have a command. Commands are always candidate dataflow
## regions.

## GL: In spec mode, we treat assignment nodes as commands
# breakpoint()
preprocessed_ast_object = PreprocessedAST(ast_node,
replace_whole=True,
non_maximal=False,
@@ -704,7 +738,7 @@ def preprocess_node_case(ast_node, trans_options, last_object=False):
##
## If we need to disable parallel pipelines, e.g., if we are in the context of an if,
## or if we are in the end of a script, then we set a variable.
def replace_df_region(asts, trans_options, disable_parallel_pipelines=False, ast_text=None) -> AstNode:
def replace_df_region(asts, trans_options: TransformationState, disable_parallel_pipelines=False, ast_text=None, var_assignment=False) -> AstNode:

transformation_mode = trans_options.get_mode()
if transformation_mode is TransformationType.PASH:
23 changes: 23 additions & 0 deletions compiler/speculative/util_spec.py
@@ -60,12 +60,18 @@ def serialize_edge(from_id: int, to_id: int) -> str:
def serialize_number_of_nodes(number_of_ids: int) -> str:
return f'{number_of_ids}\n'

def serialize_number_of_var_assignments(number_of_var_assignments: int) -> str:
return f'{number_of_var_assignments}\n'

def serialize_loop_context(node_id: int, bb_id) -> str:
## Galaxy brain serialization
# loop_contexts_str = ",".join([str(loop_ctx) for loop_ctx in loop_contexts])
bb_id_str = str(bb_id)
return f'{node_id}-loop_ctx-{bb_id_str}\n'

def serialize_var_assignments(node_id: int) -> str:
return f'{node_id}-var\n'
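
For reference, the exact one-line records the serializers above produce (ids are illustrative; this assumes util_spec is importable, as in ast_to_ast.py):

from speculative import util_spec

assert util_spec.serialize_number_of_nodes(12) == '12\n'
assert util_spec.serialize_number_of_var_assignments(2) == '2\n'
assert util_spec.serialize_loop_context(5, 3) == '5-loop_ctx-3\n'
assert util_spec.serialize_var_assignments(7) == '7-var\n'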

def save_current_env_to_file(trans_options):
initial_env_file = ptempfile()
subprocess.check_output([f"{os.getenv('PASH_TOP')}/compiler/orchestrator_runtime/pash_declare_vars.sh", initial_env_file])
@@ -89,6 +95,19 @@ def save_loop_contexts(trans_options):
bb_id = node_bb_dict[node_id]
po_file.write(serialize_loop_context(node_id, bb_id))

def save_var_assignment_contexts(trans_options):
var_nodes = trans_options.get_var_nodes()
partial_order_file_path = trans_options.get_partial_order_file()
with open(partial_order_file_path, "a") as po_file:
for node_id in var_nodes:
po_file.write(serialize_var_assignments(node_id))

def save_number_of_var_assignments(trans_options):
number_of_var_assignments = trans_options.get_number_of_var_assignments()
partial_order_file_path = trans_options.get_partial_order_file()
with open(partial_order_file_path, "a") as po_file:
po_file.write(serialize_number_of_var_assignments(number_of_var_assignments))
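
A hedged sketch of the section these two save helpers append to the partial order file, for a hypothetical script whose tracked assignments sit at nodes 3 and 7 (the node ids are illustrative; the surrounding node, loop-context, and edge records are elided):

from speculative import util_spec

var_nodes = [3, 7]   # hypothetical node ids, as returned by get_var_nodes()
section = util_spec.serialize_number_of_var_assignments(len(var_nodes))
section += ''.join(util_spec.serialize_var_assignments(n) for n in var_nodes)
print(section, end='')   # prints: 2, 3-var, 7-var, one per line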

def serialize_partial_order(trans_options):
## Initialize the po file
dir_path = partial_order_directory()
@@ -110,6 +129,10 @@ def serialize_partial_order(trans_options):

## Save loop contexts
save_loop_contexts(trans_options)

save_number_of_var_assignments(trans_options)

save_var_assignment_contexts(trans_options)

# Save the edges in the partial order file
edges = trans_options.get_all_edges()
