[iree-import-onnx] improve handling of large models (iree-org#19217)
This PR adds a few options:

1. `--large-model` disables the ONNX model checker when a user knows
ahead of time that the model is too large for it. It also avoids loading
the external weights into memory unless the parameters are being saved.
2. `--param-gb-threshold` flushes initializers to the `.irpa` file in
batches once their total in-memory size exceeds the given gigabyte
threshold, instead of gathering all of the initializers in memory and
saving them to the `.irpa` at once. This reduces peak memory overhead
(see the first sketch after this list).
3. `--externalize-inputs-threshold` converts inputs to externalized
weights. This is useful for the following workflow: export a Hugging
Face PyTorch model with safetensors weights, save a `.irpa` directly
from the safetensors weights, and export to ONNX with
`export_params=False` and `do_constant_folding=False` (which converts
weights to inputs and avoids folding weights through ops like
transposes). When importing to MLIR, set
`externalize-inputs-threshold=<num_original_inputs>` and the importer
will convert the inputs at and beyond that threshold to `util.global`
ops (see the second sketch after this list).
4. `--save-params`/`--no-save-params` factors saving parameters out of
`import_initializer`, so `--no-save-params` generates MLIR with
externalized weights without writing an associated parameter archive.
Useful for debugging compilation failures.
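
For illustration, a minimal sketch of the batched saving path on a RAM-constrained machine; the file names are hypothetical, and `python -m iree.compiler.tools.import_onnx` invokes the importer module changed in this diff:

```python
# Hedged sketch: import a large model while flushing parameters to disk
# whenever roughly 2 GB of them are buffered in memory. The file names
# here are assumptions; the flags are the ones defined in this diff.
import subprocess

with open("large_model.mlir", "w") as f:
    subprocess.run(
        [
            "python", "-m", "iree.compiler.tools.import_onnx",
            "large_model.onnx",
            "--large-model",              # skip the checker, don't eagerly load weights
            "--externalize-params",
            "--param-gb-threshold=2.0",   # flush params once ~2 GB accumulate
            "--save-params-to=large_model_params.irpa",
            "--data-dir=large_model_data",
        ],
        stdout=f,
        check=True,
    )
```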
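
And a hedged end-to-end sketch of the input-externalization workflow from item 3, with a tiny stand-in module in place of a real Hugging Face model; the module, paths, and shapes are illustrative assumptions, and producing `model.irpa` from the safetensors weights is assumed to happen in a separate step:

```python
# Hedged sketch of the item-3 workflow. Everything except the importer
# flags is an assumption for illustration.
import subprocess

import torch


class TinyModel(torch.nn.Module):
    """Stand-in for a Hugging Face PyTorch model."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(16, 16)

    def forward(self, x):
        return self.linear(x)


# Export with export_params=False so the weights become graph inputs,
# and do_constant_folding=False so they are not folded through ops
# like transposes.
torch.onnx.export(
    TinyModel().eval(),
    (torch.randn(1, 16),),
    "model.onnx",
    export_params=False,
    do_constant_folding=False,
)

# (Separately: save model.irpa directly from the safetensors weights.)

# The original model has one real input, so inputs at index 1 and
# beyond (the weights) become util.global ops; --no-save-params skips
# writing a second .irpa since one already exists.
with open("model.mlir", "w") as f:
    subprocess.run(
        [
            "python", "-m", "iree.compiler.tools.import_onnx",
            "model.onnx",
            "--externalize-params",
            "--externalize-inputs-threshold=1",
            "--no-save-params",
        ],
        stdout=f,
        check=True,
    )
```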

## TODO:

Figure out what to do about loading the ONNX model and updating the
opset version. Since opset version updating fails for models larger than
2 GB, it should be possible to update the opset version without the
weights in a somewhat hacky way (sketched below).
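
One possible shape for that, as an untested sketch: leave the external data on disk while running the version converter (some conversions may still need initializer values).

```python
# Hedged sketch: update the opset without loading external weights into
# memory. Untested; the target opset here is an arbitrary assumption.
import onnx
from onnx import version_converter

model = onnx.load("model.onnx", load_external_data=False)
converted = version_converter.convert_version(model, 21)
onnx.save(converted, "model_v21.onnx")
```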

Add documentation

---------

Signed-off-by: zjgarvey <zjgarvey@gmail.com>
zjgarvey authored Dec 9, 2024
1 parent c62c3d0 commit ab3c9bb
Showing 3 changed files with 363 additions and 82 deletions.
175 changes: 121 additions & 54 deletions compiler/bindings/python/iree/compiler/tools/import_onnx/__main__.py
@@ -20,7 +20,6 @@
from pathlib import Path
import sys
import tempfile

from .importer_externalization_overrides import *


@@ -32,25 +31,45 @@ def main(args: argparse.Namespace):

imp: Any = None
if args.externalize_params:
imp = IREENodeImporter.define_function(
model_info.main_graph, m, args.num_elements_threshold, args.params_scope
if not args.save_params:
param_path = None
elif args.save_params_to:
param_path = args.save_params_to
elif (args.output_file is not None) and (args.output_file != "-"):
output_dir = Path(args.output_file).parent
output_stem = Path(args.output_file).stem
param_path = output_dir / (output_stem + "_params.irpa")
else:
raise ValueError(
"If `--externalize-params` is set and `--output-file` is stdout, either `--save-params-to` or `--no-save-params` must be set."
)
data_dir = (
args.data_dir
if args.data_dir is not None
else str(Path(args.input_file).parent)
)
param_bit_threshold = (
None
if args.param_gb_threshold is None
else int(args.param_gb_threshold * 8 * (10**9))
)
param_data = ParamData(
param_bit_threshold=param_bit_threshold,
num_elements_threshold=args.num_elements_threshold,
params_scope=args.params_scope,
data_dir=data_dir,
param_path=str(param_path),
input_index_threshold=args.externalize_inputs_threshold,
)
imp = IREENodeImporter.define_function(model_info.main_graph, m, param_data)
else:
imp = onnx_importer.NodeImporter.define_function(model_info.main_graph, m)

imp.import_all()

if not args.no_verify:
m.verify()

if args.externalize_params:
default_param_path = Path(args.output_file).parent / Path(args.output_file).stem
param_path = (
(str(default_param_path) + "_params.irpa")
if args.save_params_to is None
else str(args.save_params_to)
)
imp.param_archive.create_archive_file(param_path)

# TODO: This isn't very efficient output. If these files ever
# get large, enable bytecode and direct binary emission to save
# some copies.
@@ -60,41 +79,54 @@ def main(args: argparse.Namespace):
else:
print(m.get_asm(assume_verified=not args.no_verify))

if args.externalize_params and args.save_params:
imp.save_params()


def load_onnx_model(args: argparse.Namespace) -> onnx.ModelProto:
input_dir = os.path.dirname(os.path.abspath(args.input_file))

# Load the model, with possible external data coming from the default
# location, or the location specified on the command line.
if args.data_dir is None:
raw_model = onnx.load(args.input_file)
else:
raw_model = onnx.load(args.input_file, load_external_data=False)
onnx.load_external_data_for_model(raw_model, str(args.data_dir))

# Only change the opset version if it is greater than the current one.
if args.opset_version and args.opset_version > raw_model.opset_import[0].version:
raw_model = onnx.version_converter.convert_version(
raw_model, args.opset_version
# TODO: setup updating opset version without loading external weights.
if args.opset_version and args.large_model:
raise NotImplementedError(
"Updating the opset version for large models is currently unsupported."
)

# Do shape inference two ways. First, attempt in-memory to avoid redundant
# loading and the need for writing a temporary file somewhere. If that
# fails, typically because of the 2 GB protobuf size limit, try again via
# files. See
# https://onnx.ai/onnx/repo-docs/PythonAPIOverview.html#shape-inference-a-large-onnx-model-2gb
# for details about the file-based technique.

# Run the checker to test whether the file is above the threshold for
# in-memory shape inference. If not, go ahead and do the shape inference.
try:
onnx.checker.check_model(raw_model)
inferred_model = onnx.shape_inference.infer_shapes(
raw_model, data_prop=args.data_prop
)
return inferred_model
except ValueError:
pass
if not args.large_model:
# Load the model, with possible external data coming from the default
# location, or the location specified on the command line.
if args.data_dir is None:
raw_model = onnx.load(args.input_file)
else:
raw_model = onnx.load(args.input_file, load_external_data=False)
onnx.load_external_data_for_model(raw_model, str(args.data_dir))

# Only change the opset version if it is greater than the current one.
if (
args.opset_version
and args.opset_version > raw_model.opset_import[0].version
):
raw_model = onnx.version_converter.convert_version(
raw_model, args.opset_version
)

# Do shape inference two ways. First, attempt in-memory to avoid redundant
# loading and the need for writing a temporary file somewhere. If that
# fails, typically because of the 2 GB protobuf size limit, try again via
# files. See
# https://onnx.ai/onnx/repo-docs/PythonAPIOverview.html#shape-inference-a-large-onnx-model-2gb
# for details about the file-based technique.

# Run the checker to test whether the file is above the threshold for
# in-memory shape inference. If not, go ahead and do the shape inference.
try:
onnx.checker.check_model(raw_model)
inferred_model = onnx.shape_inference.infer_shapes(
raw_model, data_prop=args.data_prop
)
return inferred_model
except ValueError:
pass

# Model is too big for in-memory inference: do file-based shape inference
# to a temp file.
@@ -111,7 +143,9 @@ def load_onnx_model(args: argparse.Namespace) -> onnx.ModelProto:
# Load the temp file and the external data.
inferred_model = onnx.load(temp_inferred_file, load_external_data=False)
data_dir = Path(input_dir if args.data_dir is None else args.data_dir)
onnx.load_external_data_for_model(inferred_model, str(data_dir))
# we don't need to load the model weights in-memory when externalizing params
if not args.externalize_params:
onnx.load_external_data_for_model(inferred_model, str(data_dir))

return inferred_model

@@ -146,30 +180,63 @@ def parse_arguments(argv=None) -> argparse.Namespace:
type=int,
)
parser.add_argument(
"--large-model",
help="Setting this to true is recommended for large models that do not require --opset-version."
" It will bypass loading external weights and running the onnx checker to determine the model size.",
action=argparse.BooleanOptionalAction,
default=False,
)
# args for saving a file with externalized params
externalization_args = parser.add_argument_group(
"externalization", "args used to customize the externalization of model weights"
)
externalization_args.add_argument(
"--externalize-params",
help="Import the mlir file with large weights replaced by external reference calls.",
action=argparse.BooleanOptionalAction,
default=False,
)
externalization_args.add_argument(
"--externalize-inputs-threshold",
help="Treats inputs at or after the provided index as external parameters of the model."
" Only has an effect if 'externalize-params' is true.",
type=int,
)
externalization_args.add_argument(
"--num-elements-threshold",
help="Minimum number of elements for an initializer to be externalized. Only has an effect if 'externalize-params' is true.",
help="Minimum number of elements for an initializer to be externalized."
" Only has an effect if 'externalize-params' is true.",
type=int,
default=100,
)
parser.add_argument(
"--externalize-params",
help="Externalize large parameters and store them on the disk, to load at runtime.",
externalization_args.add_argument(
"--params-scope",
help="The namespace or the scope in which the externalized parameters are placed."
" Default is 'model'.",
type=str,
default="model",
)
# args for creating an external weight file
externalization_args.add_argument(
"--save-params",
help="Whether to save the params to a file. Setting this to false will generate mlir with externalized weights"
" without creating an associated .irpa file.",
action=argparse.BooleanOptionalAction,
default=False,
default=True,
)
parser.add_argument(
externalization_args.add_argument(
"--param-gb-threshold",
help="Setting this will flush params to a temp file when total in-memory param size exceeds the Gigabyte threshold."
" This is less efficient (about x2 slower) and only recommended for machines with limited RAM.",
type=float,
)
externalization_args.add_argument(
"--save-params-to",
help="Location to save the externalized parameters. When not set, the parameters will be written to '<output_file_name>_params.irpa'"
" under the namespace 'model', which can be configured by passing the namespace string to 'params-scope'.",
default=None,
type=Path,
)
parser.add_argument(
"--params-scope",
help="The namespace or the scope in which the externalized parameters are placed. Default is 'model'.",
type=str,
default="model",
)
args = parser.parse_args(argv)
return args

(Diffs for the remaining 2 changed files not shown.)
