Skip to content

Commit

Permalink
[TVMScript] Avoid segfault from invalid TVMScript (#17373)
Browse files Browse the repository at this point in the history
* [TVMScript] Avoid segfault from invalid TVMScript

Prior to this commit, after the `DiagnosticContext` prints its error,
it overwrites the `DiagnosticRenderer` with a NULL renderer.  If a
second call to `DiagnosticContext::Render` occurs, it will segfault.
This appears to be intended to prevent double-printing of error
messages, but double-printing error messages is much worse than a
segfault.

In addition, `DiagnosticContext::Render` should only be called once.
There's a common pattern in the parser where it will wrap exceptions
in `DiagnosticError`, but re-raise exceptions that are already a
`DiagnosticError`.  This requires every such location to include
`except DiagnosticError: raise`, and can easily be missed.

This PR makes two changes: First, the `DiagnosticRenderer` is updated
to have a no-op callback rather than a NULL callback.  Second, the
re-raising of `DiagnosticError` is moved to `Parser.report_error`, so
that it does not need to be handled separately at several independent
locations in the TVMScript parser.
  • Loading branch information
Lunderberg authored Sep 17, 2024
1 parent 9f28175 commit ff8e416
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 27 deletions.
12 changes: 6 additions & 6 deletions python/tvm/script/parser/core/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,8 +267,8 @@ def _visit(self, node: doc.AST) -> Any:
value = self._eval_slice(fields)
else:
value = self._eval_expr(node.__class__(**fields))
except Exception as e: # pylint: disable=broad-except,invalid-name
self.parser.report_error(node, e)
except Exception as err: # pylint: disable=broad-except
self.parser.report_error(node, err)
return self._add_intermediate_result(value)

def _eval_lambda(self, node: doc.Lambda) -> Any:
Expand All @@ -286,8 +286,8 @@ def _eval_lambda(self, node: doc.Lambda) -> Any:
"""
try:
value = self._eval_expr(node)
except Exception as e: # pylint: disable=broad-except,invalid-name
self.parser.report_error(node, str(e))
except Exception as err: # pylint: disable=broad-except
self.parser.report_error(node, err)
return self._add_intermediate_result(value)

def _eval_bool_op(self, fields: Dict[str, Any]) -> Any:
Expand Down Expand Up @@ -463,8 +463,8 @@ def eval_assign(
"""
try:
return _eval_assign(target, source)
except Exception as e: # pylint: disable=broad-except,invalid-name
parser.report_error(target, f"Failed to evaluate assignment: {str(e)}")
except Exception as err: # pylint: disable=broad-except
parser.report_error(target, err)
raise


Expand Down
19 changes: 10 additions & 9 deletions python/tvm/script/parser/core/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,10 +307,8 @@ def _dispatch_wrapper(func: dispatch.ParseMethod) -> dispatch.ParseMethod:
def _wrapper(self: "Parser", node: doc.AST) -> None:
try:
return func(self, node)
except DiagnosticError:
raise
except Exception as e: # pylint: disable=broad-except,invalid-name
self.report_error(node, e)
except Exception as err: # pylint: disable=broad-except
self.report_error(node, err)
raise

return _wrapper
Expand Down Expand Up @@ -547,6 +545,12 @@ def report_error(
err: Union[Exception, str]
The error to report.
"""

# If the error is already being raised as a DiagnosticError,
# re-raise it without wrapping it in a DiagnosticContext.
if isinstance(err, DiagnosticError):
raise err

# Only take the last line of the error message
if isinstance(err, TVMError):
msg = list(filter(None, str(err).split("\n")))[-1]
Expand Down Expand Up @@ -595,11 +599,8 @@ def visit(self, node: doc.AST) -> None:
raise NotImplementedError(f"Visitor of AST node is not implemented: {name}")
try:
func(node)
except DiagnosticError:
raise
except Exception as e: # pylint: disable=broad-except,invalid-name
self.report_error(node, str(e))
raise
except Exception as err: # pylint: disable=broad-except
self.report_error(node, err)

def visit_body(self, node: List[doc.stmt]) -> Any:
"""The general body visiting method.
Expand Down
10 changes: 5 additions & 5 deletions python/tvm/script/parser/relax/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,19 +104,19 @@ def eval_struct_info_proxy(self: Parser, node: doc.expr) -> StructInfoProxy:
try:
annotation = self.eval_expr(node)
return _normalize_struct_info_proxy(annotation)
except Exception as err:
self.report_error(node, str(err))
raise err
except Exception as err: # pylint: disable=broad-except
self.report_error(node, err)
raise


def eval_struct_info(self: Parser, node: doc.expr, eval_str: bool = False) -> StructInfo:
var_table = self.var_table.get() if eval_str else None
try:
struct_info = self.eval_expr(node)
return _normalize_struct_info(struct_info, var_table)
except Exception as err:
except Exception as err: # pylint: disable=broad-except
self.report_error(node, err)
raise err
raise


def is_called(node: Any, func_name: str) -> bool:
Expand Down
3 changes: 2 additions & 1 deletion src/ir/diagnostic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,8 @@ void DiagnosticContext::Render() {
}

if (errs) {
(*this)->renderer = DiagnosticRenderer();
(*this)->renderer = DiagnosticRenderer([](DiagnosticContext) {});
// (*this)->diagnostics.clear();
LOG(FATAL) << "DiagnosticError: one or more error diagnostics were "
<< "emitted, please check diagnostic render for output.";
}
Expand Down
14 changes: 11 additions & 3 deletions tests/python/relax/test_tvmscript_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,15 @@ def f(x: R.Tensor):
return x


def test_incorrect_tensor_shape():
with pytest.raises(tvm.error.DiagnosticError):

@R.function
def f(x: R.Tensor([16])):
y: R.Tensor(16) = R.add(x, x)
return y


def test_simple_module():
@I.ir_module
class TestModule:
Expand Down Expand Up @@ -1045,7 +1054,6 @@ def main(


def test_call_tir_inplace_with_tuple_var_raises_error():

with pytest.raises(tvm.error.DiagnosticError):

@tvm.script.ir_module
Expand Down Expand Up @@ -1838,7 +1846,7 @@ def mul_add(x: R.Tensor) -> R.Tensor:
_check(InputModule, OutputModule)


def test_context_aware_parsing():
def test_context_aware_parsing(monkeypatch):
@tvm.script.ir_module
class Module:
@T.prim_func
Expand All @@ -1863,7 +1871,7 @@ def main(x: R.Tensor((2, 4), dtype="float32")) -> R.Tensor((10,), dtype="float32
def _break_env(self, *args):
raise RuntimeError("Fail to pass context-aware parsing")

tvm.ir.GlobalVar.__call__ = _break_env
monkeypatch.setattr(tvm.ir.GlobalVar, "__call__", _break_env)

_check(Module)

Expand Down
8 changes: 5 additions & 3 deletions tests/python/tvmscript/test_tvmscript_printer_highlight.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import tvm.testing
from tvm import relay
from tvm.script import tir as T
from tvm.script.highlight import cprint
from tvm.script.highlight import cprint, _format


def test_highlight_script():
Expand Down Expand Up @@ -58,12 +58,14 @@ def test_cprint():
# Print nodes with `script` method, e.g. PrimExpr
cprint(tvm.tir.Var("v", "int32") + 1)

# Cannot print non-Python-style codes if black installed
# Cannot print non-Python-style codes when using the black
# formatter. This error comes from `_format`, used internally by
# `cprint`, and doesn't occur when using the `ruff` formatter.
try:
import black

with pytest.raises(ValueError):
cprint("if (a == 1) { a +=1; }")
_format("if (a == 1) { a +=1; }", formatter="black")
except ImportError:
pass

Expand Down

0 comments on commit ff8e416

Please sign in to comment.