From 3323519bad0f2b91532a8272234034920101709b Mon Sep 17 00:00:00 2001 From: Scott Todd Date: Fri, 20 Oct 2023 15:12:44 -0700 Subject: [PATCH] Port samples/dynamic_shapes/ to PyTorch using SHARK-Turbine. (#15255) Progress on https://github.com/openxla/iree/issues/15117 and towards documenting the "advanced AOT toolkit" on https://www.iree.dev/guides/ml-frameworks/pytorch/ (trying out the various features so I can write about them). This notebook produces a program with the same interface as the existing TensorFlow notebook used by this sample. * `export_name="module"` could be removed, but then I'd want to somehow change the TF program and update the C code from `iree_make_cstring_view("module.reduce_sum_2d")` to `iree_make_cstring_view("dynamic_shapes.reduce_sum_2d")`. Keeping that messiness at least as long as the TF notebook remains. * Yes, I want to delete the TF notebook and rebase on PyTorch + JAX... New notebook preview for review: https://colab.research.google.com/github/scotttodd/iree/blob/samples-pytorch/samples/dynamic_shapes/pytorch_dynamic_shapes.ipynb --- samples/dynamic_shapes/README.md | 60 +- .../pytorch_dynamic_shapes.ipynb | 560 ++++++++++++++++++ ....ipynb => tensorflow_dynamic_shapes.ipynb} | 2 +- samples/dynamic_shapes/test.sh | 7 +- 4 files changed, 605 insertions(+), 24 deletions(-) create mode 100644 samples/dynamic_shapes/pytorch_dynamic_shapes.ipynb rename samples/dynamic_shapes/{dynamic_shapes.ipynb => tensorflow_dynamic_shapes.ipynb} (99%) diff --git a/samples/dynamic_shapes/README.md b/samples/dynamic_shapes/README.md index 7473115d7cf9..bfcdafc14890 100644 --- a/samples/dynamic_shapes/README.md +++ b/samples/dynamic_shapes/README.md @@ -2,18 +2,21 @@ This sample shows how to -1. Create a TensorFlow program that includes dynamic shapes in program inputs - and outputs +1. Create a program that includes dynamic shapes in program inputs and outputs 2. Import that program into IREE's compiler 3. Compile that program to an IREE VM bytecode module 4. Load the compiled program using IREE's high level runtime C API 5. Call exported functions on the loaded program Steps 1-2 are performed in Python via the -[`dynamic_shapes.ipynb`](./dynamic_shapes.ipynb) -[Colab](https://research.google.com/colaboratory/) notebook: +[`pytorch_dynamic_shapes.ipynb`](./pytorch_dynamic_shapes.ipynb) or +[`tensorflow_dynamic_shapes.ipynb`](./tensorflow_dynamic_shapes.ipynb) +[Colab](https://colab.google/) notebooks: -[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/openxla/iree/blob/main/samples/dynamic_shapes/dynamic_shapes.ipynb) +| Framework | Notebook | +| --------- | -------- | +PyTorch | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/openxla/iree/blob/main/samples/dynamic_shapes/pytorch_dynamic_shapes.ipynb) +TensorFlow | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/openxla/iree/blob/main/samples/dynamic_shapes/tensorflow_dynamic_shapes.ipynb) Step 3 should be performed on your development host machine @@ -23,33 +26,51 @@ The program used to demonstrate includes functions with varying uses of dynamic shapes: ```python -class DynamicShapesModule(tf.Module): +import torch +import shark_turbine.aot as aot + +class DynamicShapesModule(aot.CompiledModule, export_name="module"): # reduce_sum_1d (dynamic input size, static output size) + # tensor -> tensor # e.g. [1, 2, 3] -> 6 - @tf.function(input_signature=[tf.TensorSpec([None], tf.int32)]) - def reduce_sum_1d(self, values): - return tf.math.reduce_sum(values) + def reduce_sum_1d(self, values=aot.AbstractTensor(None, dtype=torch.int32)): + return self.compute_reduce_sum_1d(values) + + @aot.jittable + def compute_reduce_sum_1d(values): + return torch.sum(values, dtype=torch.int32) # reduce_sum_2d (partially dynamic input size, static output size) + # tensor -> tensor<3xi32> # e.g. [[1, 2, 3], [10, 20, 30]] -> [11, 22, 33] - @tf.function(input_signature=[tf.TensorSpec([None, 3], tf.int32)]) - def reduce_sum_2d(self, values): - return tf.math.reduce_sum(values, 0) + def reduce_sum_2d(self, values=aot.AbstractTensor(None, 3, dtype=torch.int32)): + return self.compute_reduce_sum_2d(values) + + @aot.jittable + def compute_reduce_sum_2d(values): + return torch.sum(values, 0, dtype=torch.int32) # add_one (dynamic input size, dynamic output size) + # tensor) -> tensor # e.g. [1, 2, 3] -> [2, 3, 4] - @tf.function(input_signature=[tf.TensorSpec([None], tf.int32)]) - def add_one(self, values): - return tf.math.add(values, tf.constant(1, dtype=tf.int32)) + def add_one(self, values=aot.AbstractTensor(None, dtype=torch.int32)): + return self.compute_add_one(values) + + @aot.jittable + def compute_add_one(values): + return values + 1 ``` ## Background Tensors are multi-dimensional arrays with a uniform type (e.g. int32, float32) and a shape. Shapes consist of a rank and a list of dimensions and may be -static (i.e. fully known and fixed) or varying degrees of dynamic. See -TensorFlow's [Introduction to Tensors](https://www.tensorflow.org/guide/tensor) -for more information on how tensors are used in TensorFlow programs. +static (i.e. fully known and fixed) or varying degrees of dynamic. For more +information, see these references: +* PyTorch: +[Compiler dynamic shapes](https://pytorch.org/docs/stable/torch.compiler_dynamic_shapes.html), +[`torch.Tensor`](https://pytorch.org/docs/stable/tensors.html) +* TensorFlow: [Introduction to Tensors](https://www.tensorflow.org/guide/tensor) Dynamic shapes are useful for passing variable sized batches as input, receiving variable length sentences of text as output, etc. @@ -64,7 +85,7 @@ them. ## Instructions -1. Run the Colab notebook and download the `dynamic_shapes.mlir` file it +1. Run either Colab notebook and download the `dynamic_shapes.mlir` file it generates 2. Build the `iree-compile` tool (see @@ -83,7 +104,6 @@ them. ``` ../iree-build/tools/iree-compile \ --iree-hal-target-backends=llvm-cpu \ - --iree-input-type=stablehlo \ dynamic_shapes.mlir -o dynamic_shapes_cpu.vmfb ``` diff --git a/samples/dynamic_shapes/pytorch_dynamic_shapes.ipynb b/samples/dynamic_shapes/pytorch_dynamic_shapes.ipynb new file mode 100644 index 000000000000..a03049063936 --- /dev/null +++ b/samples/dynamic_shapes/pytorch_dynamic_shapes.ipynb @@ -0,0 +1,560 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [], + "collapsed_sections": [ + "FH3IRpYTta2v" + ] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + } + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "FH3IRpYTta2v" + }, + "source": [ + "##### Copyright 2023 The IREE Authors" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "mWGa71_Ct2ug", + "cellView": "form" + }, + "source": [ + "#@title Licensed under the Apache License v2.0 with LLVM Exceptions.\n", + "# See https://llvm.org/LICENSE.txt for license information.\n", + "# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception" + ], + "execution_count": 1, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "h5s6ncerSpc5" + }, + "source": [ + "# Dynamic Shapes\n", + "\n", + "This notebook\n", + "\n", + "1. Creates a PyTorch program with dynamic shapes using [SHARK-Turbine](https://github.com/nod-ai/SHARK-Turbine)'s advanced AOT toolkit\n", + "2. Compiles the program to an IREE VM bytecode module\n", + "3. Tests running the compiled VM module using IREE's runtime\n", + "4. Downloads compilation artifacts for use with the native (C API) sample application" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "s2bScbYkP6VZ", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "7b268798-a20d-4df4-f00d-ed7811f77767" + }, + "source": [ + "#@title General setup\n", + "\n", + "import os\n", + "import tempfile\n", + "\n", + "ARTIFACTS_DIR = os.path.join(tempfile.gettempdir(), \"iree\", \"colab_artifacts\")\n", + "os.makedirs(ARTIFACTS_DIR, exist_ok=True)\n", + "print(f\"Using artifacts directory '{ARTIFACTS_DIR}'\")" + ], + "execution_count": 1, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Using artifacts directory '/tmp/iree/colab_artifacts'\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "%%capture\n", + "#@title Uninstall existing packages\n", + "# This avoids some warnings when installing specific PyTorch packages below.\n", + "!python -m pip uninstall -y fastai torchaudio torchdata torchtext torchvision" + ], + "metadata": { + "id": "y9KOsqosg6Ms" + }, + "execution_count": 2, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "#@title Install SHARK-Turbine\n", + "\n", + "# Limit cell height.\n", + "from IPython.display import Javascript\n", + "display(Javascript('''google.colab.output.setIframeHeight(0, true, {maxHeight: 300})'''))\n", + "\n", + "!python -m pip install shark-turbine" + ], + "metadata": { + "id": "SdCAvI3sqBO7", + "outputId": "2be248d9-bf6b-475e-c44a-aa529f20de23", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 300 + } + }, + "execution_count": 3, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "application/javascript": [ + "google.colab.output.setIframeHeight(0, true, {maxHeight: 300})" + ] + }, + "metadata": {} + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Collecting shark-turbine\n", + " Downloading shark-turbine-0.9.1.dev3.tar.gz (60 kB)\n", + "\u001b[?25l \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.0/60.2 kB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m60.2/60.2 kB\u001b[0m \u001b[31m2.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25h Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", + " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", + " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", + "Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from shark-turbine) (1.23.5)\n", + "Collecting iree-compiler>=20231004.665 (from shark-turbine)\n", + " Downloading iree_compiler-20231004.665-cp310-cp310-manylinux_2_28_x86_64.whl (57.2 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m57.2/57.2 MB\u001b[0m \u001b[31m17.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting iree-runtime>=20231004.665 (from shark-turbine)\n", + " Downloading iree_runtime-20231004.665-cp310-cp310-manylinux_2_28_x86_64.whl (7.8 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.8/7.8 MB\u001b[0m \u001b[31m91.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: torch>=2.1.0 in /usr/local/lib/python3.10/dist-packages (from shark-turbine) (2.1.0+cu118)\n", + "Requirement already satisfied: PyYAML in /usr/local/lib/python3.10/dist-packages (from iree-compiler>=20231004.665->shark-turbine) (6.0.1)\n", + "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->shark-turbine) (3.12.4)\n", + "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->shark-turbine) (4.5.0)\n", + "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->shark-turbine) (1.12)\n", + "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->shark-turbine) (3.1)\n", + "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->shark-turbine) (3.1.2)\n", + "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->shark-turbine) (2023.6.0)\n", + "Requirement already satisfied: triton==2.1.0 in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->shark-turbine) (2.1.0)\n", + "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=2.1.0->shark-turbine) (2.1.3)\n", + "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=2.1.0->shark-turbine) (1.3.0)\n", + "Building wheels for collected packages: shark-turbine\n", + " Building wheel for shark-turbine (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", + " Created wheel for shark-turbine: filename=shark_turbine-0.9.1.dev3-py3-none-any.whl size=70102 sha256=507dec827b9a2eea18f47c6ebdc84347c9956b8f2e0b186d3107a006e0742d81\n", + " Stored in directory: /root/.cache/pip/wheels/e9/78/0f/88c9d8224ef1550fe00b18a014eab5121f26264e2261f31926\n", + "Successfully built shark-turbine\n", + "Installing collected packages: iree-runtime, iree-compiler, shark-turbine\n", + "Successfully installed iree-compiler-20231004.665 iree-runtime-20231004.665 shark-turbine-0.9.1.dev3\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "#@title Report version information\n", + "!echo \"Installed SHARK-Turbine, $(python -m pip show shark_turbine | grep Version)\"\n", + "\n", + "!echo -e \"\\nInstalled IREE, compiler version information:\"\n", + "!iree-compile --version\n", + "\n", + "import torch\n", + "print(\"\\nInstalled PyTorch, version:\", torch.__version__)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "Oj5I6R9LI7t_", + "outputId": "35d79e6a-7bd0-46e1-8113-5af1a7bcbb5b" + }, + "execution_count": 4, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Installed SHARK-Turbine, Version: 0.9.1.dev3\n", + "\n", + "Installed IREE, compiler version information:\n", + "IREE (https://openxla.github.io/iree):\n", + " IREE compiler version 20231004.665 @ bb51f6f1a1b4ee619fb09a7396f449dadb211447\n", + " LLVM version 18.0.0git\n", + " Optimized build\n", + "\n", + "Installed PyTorch, version: 2.1.0+cu118\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "## Create a program using PyTorch + SHARK-Turbine\n", + "\n", + "NOTE: as in other domains, providing more information to a compiler allows it\n", + "to generate more efficient code. As a general rule, the slowest varying\n", + "dimensions of program data like batch index or timestep are safer to treat as\n", + "dynamic than faster varying dimensions like image x/y/channel. See\n", + "[this paper](https://arxiv.org/pdf/2006.03031.pdf) for a discussion of the\n", + "challenges imposed by dynamic shapes and one project's approach to addressing\n", + "them." + ], + "metadata": { + "id": "C3mhaullI940" + } + }, + { + "cell_type": "code", + "source": [ + "#@title Define a sample `shark_turbine.aot.CompiledModule` using dynamic shapes\n", + "\n", + "import shark_turbine.aot as aot\n", + "\n", + "class DynamicShapesModule(aot.CompiledModule, export_name=\"module\"):\n", + " # reduce_sum_1d (dynamic input size, static output size)\n", + " # tensor -> tensor\n", + " # e.g. [1, 2, 3] -> 6\n", + " def reduce_sum_1d(self, values=aot.AbstractTensor(None, dtype=torch.int32)):\n", + " return self.compute_reduce_sum_1d(values)\n", + "\n", + " @aot.jittable\n", + " def compute_reduce_sum_1d(values):\n", + " return torch.sum(values, dtype=torch.int32)\n", + "\n", + " # reduce_sum_2d (partially dynamic input size, static output size)\n", + " # tensor -> tensor<3xi32>\n", + " # e.g. [[1, 2, 3], [10, 20, 30]] -> [11, 22, 33]\n", + " def reduce_sum_2d(self, values=aot.AbstractTensor(None, 3, dtype=torch.int32)):\n", + " return self.compute_reduce_sum_2d(values)\n", + "\n", + " @aot.jittable\n", + " def compute_reduce_sum_2d(values):\n", + " return torch.sum(values, 0, dtype=torch.int32)\n", + "\n", + " # add_one (dynamic input size, dynamic output size)\n", + " # tensor) -> tensor\n", + " # e.g. [1, 2, 3] -> [2, 3, 4]\n", + " def add_one(self, values=aot.AbstractTensor(None, dtype=torch.int32)):\n", + " return self.compute_add_one(values)\n", + "\n", + " @aot.jittable\n", + " def compute_add_one(values):\n", + " return values + 1" + ], + "metadata": { + "id": "vsf9F4WxI_DX" + }, + "execution_count": 5, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "from iree.compiler.ir import Context\n", + "\n", + "# Import into MLIR and save to disk.\n", + "dynamic_shapes_instance = DynamicShapesModule(context=Context())\n", + "imported_mlir_path = os.path.join(ARTIFACTS_DIR, \"dynamic_shapes.mlir\")\n", + "aot.CompiledModule.save_mlir(dynamic_shapes_instance, imported_mlir_path)\n", + "print(f\"Wrote MLIR to path '{imported_mlir_path}'\")\n", + "\n", + "# Inspect the IR.\n", + "# Note the question marks for dynamic shapes in types, like `tensor`.\n", + "print(\"\\nDynamic Shapes MLIR:\")\n", + "!cat {imported_mlir_path}" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "_OQIpOtNr4Gh", + "outputId": "888c0bf3-bec6-403c-9993-ad45d21364fb" + }, + "execution_count": 6, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Wrote MLIR to path '/tmp/iree/colab_artifacts/dynamic_shapes.mlir'\n", + "\n", + "Dynamic Shapes MLIR:\n", + "#map = affine_map<(d0) -> (d0)>\n", + "#map1 = affine_map<(d0) -> ()>\n", + "#map2 = affine_map<(d0, d1) -> (d0, d1)>\n", + "#map3 = affine_map<(d0, d1) -> (d1)>\n", + "module @module {\n", + " func.func @reduce_sum_1d(%arg0: tensor) -> tensor attributes {torch.args_schema = \"[1, {\\22type\\22: \\22builtins.tuple\\22, \\22context\\22: \\22null\\22, \\22children_spec\\22: [{\\22type\\22: \\22builtins.list\\22, \\22context\\22: \\22null\\22, \\22children_spec\\22: [{\\22type\\22: null, \\22context\\22: null, \\22children_spec\\22: []}]}, {\\22type\\22: \\22builtins.dict\\22, \\22context\\22: \\22[]\\22, \\22children_spec\\22: []}]}]\", torch.return_schema = \"[1, {\\22type\\22: null, \\22context\\22: null, \\22children_spec\\22: []}]\"} {\n", + " %0 = call @compute_reduce_sum_1d(%arg0) : (tensor) -> tensor\n", + " return %0 : tensor\n", + " }\n", + " func.func private @compute_reduce_sum_1d(%arg0: tensor) -> tensor {\n", + " %c0_i32 = arith.constant 0 : i32\n", + " %0 = tensor.empty() : tensor\n", + " %1 = linalg.fill ins(%c0_i32 : i32) outs(%0 : tensor) -> tensor\n", + " %2 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = [\"reduction\"]} ins(%arg0 : tensor) outs(%1 : tensor) {\n", + " ^bb0(%in: i32, %out: i32):\n", + " %3 = arith.addi %in, %out : i32\n", + " linalg.yield %3 : i32\n", + " } -> tensor\n", + " return %2 : tensor\n", + " }\n", + " func.func @reduce_sum_2d(%arg0: tensor) -> tensor<3xi32> attributes {torch.args_schema = \"[1, {\\22type\\22: \\22builtins.tuple\\22, \\22context\\22: \\22null\\22, \\22children_spec\\22: [{\\22type\\22: \\22builtins.list\\22, \\22context\\22: \\22null\\22, \\22children_spec\\22: [{\\22type\\22: null, \\22context\\22: null, \\22children_spec\\22: []}]}, {\\22type\\22: \\22builtins.dict\\22, \\22context\\22: \\22[]\\22, \\22children_spec\\22: []}]}]\", torch.return_schema = \"[1, {\\22type\\22: null, \\22context\\22: null, \\22children_spec\\22: []}]\"} {\n", + " %0 = call @compute_reduce_sum_2d(%arg0) : (tensor) -> tensor<3xi32>\n", + " return %0 : tensor<3xi32>\n", + " }\n", + " func.func private @compute_reduce_sum_2d(%arg0: tensor) -> tensor<3xi32> {\n", + " %c0_i32 = arith.constant 0 : i32\n", + " %0 = tensor.empty() : tensor<3xi32>\n", + " %1 = linalg.fill ins(%c0_i32 : i32) outs(%0 : tensor<3xi32>) -> tensor<3xi32>\n", + " %2 = linalg.generic {indexing_maps = [#map2, #map3], iterator_types = [\"reduction\", \"parallel\"]} ins(%arg0 : tensor) outs(%1 : tensor<3xi32>) {\n", + " ^bb0(%in: i32, %out: i32):\n", + " %3 = arith.addi %in, %out : i32\n", + " linalg.yield %3 : i32\n", + " } -> tensor<3xi32>\n", + " return %2 : tensor<3xi32>\n", + " }\n", + " func.func @add_one(%arg0: tensor) -> tensor attributes {torch.args_schema = \"[1, {\\22type\\22: \\22builtins.tuple\\22, \\22context\\22: \\22null\\22, \\22children_spec\\22: [{\\22type\\22: \\22builtins.list\\22, \\22context\\22: \\22null\\22, \\22children_spec\\22: [{\\22type\\22: null, \\22context\\22: null, \\22children_spec\\22: []}]}, {\\22type\\22: \\22builtins.dict\\22, \\22context\\22: \\22[]\\22, \\22children_spec\\22: []}]}]\", torch.return_schema = \"[1, {\\22type\\22: null, \\22context\\22: null, \\22children_spec\\22: []}]\"} {\n", + " %0 = call @compute_add_one(%arg0) : (tensor) -> tensor\n", + " return %0 : tensor\n", + " }\n", + " func.func private @compute_add_one(%arg0: tensor) -> tensor {\n", + " %c0 = arith.constant 0 : index\n", + " %c1_i32 = arith.constant 1 : i32\n", + " %dim = tensor.dim %arg0, %c0 : tensor\n", + " %0 = tensor.empty(%dim) : tensor\n", + " %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = [\"parallel\"]} ins(%arg0 : tensor) outs(%0 : tensor) {\n", + " ^bb0(%in: i32, %out: i32):\n", + " %2 = arith.addi %in, %c1_i32 : i32\n", + " linalg.yield %2 : i32\n", + " } -> tensor\n", + " return %1 : tensor\n", + " }\n", + "}\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "## Test the imported program\n", + "\n", + "_Note: you can stop after each step and use intermediate outputs with other tools outside of Colab._\n", + "\n", + "_See the [README](https://github.com/openxla/iree/tree/main/samples/dynamic_shapes#instructions) for more details and example command line instructions._\n", + "\n", + "* _The \"imported MLIR\" (above) can be used by IREE's generic compiler tools_\n", + "* _The \"binary\" can be saved and used by runtime applications_\n", + "\n", + "_The specific point at which you switch from Python to native tools will depend on your project._" + ], + "metadata": { + "id": "z6w_Pbl6tUtJ" + } + }, + { + "cell_type": "code", + "source": [ + "# Export and compile.\n", + "exported_output = aot.export(DynamicShapesModule)\n", + "\n", + "# Compile to a file on disk for usage outside of Python.\n", + "flatbuffer_path = os.path.join(ARTIFACTS_DIR, \"dynamic_shapes_cpu.vmfb\")\n", + "exported_output.compile(save_to=flatbuffer_path)\n", + "print(f\"Wrote compiled program to path '{flatbuffer_path}'\")\n", + "\n", + "# Compile into memory for testing.\n", + "binary = exported_output.compile(save_to=None)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "0PGyH1tvI_Ic", + "outputId": "23b53928-4d77-461f-e4b8-b2c8ffb25ef0" + }, + "execution_count": 7, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Wrote compiled program to path '/tmp/iree/colab_artifacts/dynamic_shapes_cpu.vmfb'\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "import iree.runtime as ireert\n", + "import numpy as np\n", + "\n", + "# Use the IREE runtime API to test the compiled program.\n", + "config = ireert.Config(\"local-task\")\n", + "vm_module = ireert.load_vm_module(\n", + " ireert.VmModule.wrap_buffer(config.vm_instance, binary.map_memory()),\n", + " config,\n", + ")\n", + "\n", + "print(vm_module.reduce_sum_1d(np.array([1, 10, 100], dtype=np.int32)).to_host())\n", + "print(vm_module.reduce_sum_2d(np.array([[1, 2, 3], [10, 20, 30]], dtype=np.int32)).to_host())\n", + "print(vm_module.reduce_sum_2d(np.array([[1, 2, 3], [10, 20, 30], [100, 200, 300]], dtype=np.int32)).to_host())\n", + "print(vm_module.add_one(np.array([1, 10, 100], dtype=np.int32)).to_host())" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "9ilJY15BI_LD", + "outputId": "57db6e52-83f1-4283-fc08-31e743cc9b42" + }, + "execution_count": 8, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "111\n", + "[11 22 33]\n", + "[111 222 333]\n", + "[ 2 11 101]\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "## Download compilation artifacts" + ], + "metadata": { + "id": "3mizlpY9uJEW" + } + }, + { + "cell_type": "code", + "source": [ + "ARTIFACTS_ZIP = \"/tmp/dynamic_shapes_colab_artifacts.zip\"\n", + "\n", + "print(f\"Zipping '{ARTIFACTS_DIR}' to '{ARTIFACTS_ZIP}' for download...\")\n", + "!cd {ARTIFACTS_DIR} && zip -r {ARTIFACTS_ZIP} .\n", + "\n", + "# Note: you can also download files using Colab's file explorer\n", + "try:\n", + " from google.colab import files\n", + " print(\"Downloading the artifacts zip file...\")\n", + " files.download(ARTIFACTS_ZIP)\n", + "except ImportError:\n", + " print(\"Missing google_colab Python package, can't download files\")" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 86 + }, + "id": "dgaXpdiWuGtx", + "outputId": "dc0fbca1-c5b0-44f9-e1ff-9bf1307c049f" + }, + "execution_count": 9, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Zipping '/tmp/iree/colab_artifacts' to '/tmp/dynamic_shapes_colab_artifacts.zip' for download...\n", + " adding: dynamic_shapes_cpu.vmfb (deflated 66%)\n", + " adding: dynamic_shapes.mlir (deflated 82%)\n", + "Downloading the artifacts zip file...\n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "application/javascript": [ + "\n", + " async function download(id, filename, size) {\n", + " if (!google.colab.kernel.accessAllowed) {\n", + " return;\n", + " }\n", + " const div = document.createElement('div');\n", + " const label = document.createElement('label');\n", + " label.textContent = `Downloading \"${filename}\": `;\n", + " div.appendChild(label);\n", + " const progress = document.createElement('progress');\n", + " progress.max = size;\n", + " div.appendChild(progress);\n", + " document.body.appendChild(div);\n", + "\n", + " const buffers = [];\n", + " let downloaded = 0;\n", + "\n", + " const channel = await google.colab.kernel.comms.open(id);\n", + " // Send a message to notify the kernel that we're ready.\n", + " channel.send({})\n", + "\n", + " for await (const message of channel.messages) {\n", + " // Send a message to notify the kernel that we're ready.\n", + " channel.send({})\n", + " if (message.buffers) {\n", + " for (const buffer of message.buffers) {\n", + " buffers.push(buffer);\n", + " downloaded += buffer.byteLength;\n", + " progress.value = downloaded;\n", + " }\n", + " }\n", + " }\n", + " const blob = new Blob(buffers, {type: 'application/binary'});\n", + " const a = document.createElement('a');\n", + " a.href = window.URL.createObjectURL(blob);\n", + " a.download = filename;\n", + " div.appendChild(a);\n", + " a.click();\n", + " div.remove();\n", + " }\n", + " " + ] + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "application/javascript": [ + "download(\"download_e2630f9b-e811-4164-b2d8-80cf52f17145\", \"dynamic_shapes_colab_artifacts.zip\", 5699)" + ] + }, + "metadata": {} + } + ] + } + ] +} \ No newline at end of file diff --git a/samples/dynamic_shapes/dynamic_shapes.ipynb b/samples/dynamic_shapes/tensorflow_dynamic_shapes.ipynb similarity index 99% rename from samples/dynamic_shapes/dynamic_shapes.ipynb rename to samples/dynamic_shapes/tensorflow_dynamic_shapes.ipynb index 294b88cf9b14..126735190cc5 100644 --- a/samples/dynamic_shapes/dynamic_shapes.ipynb +++ b/samples/dynamic_shapes/tensorflow_dynamic_shapes.ipynb @@ -3,7 +3,7 @@ "nbformat_minor": 0, "metadata": { "colab": { - "name": "dynamic_shapes.ipynb", + "name": "tensorflow_dynamic_shapes.ipynb", "provenance": [], "collapsed_sections": [ "FH3IRpYTta2v" diff --git a/samples/dynamic_shapes/test.sh b/samples/dynamic_shapes/test.sh index dc302de82087..2336d0425a43 100755 --- a/samples/dynamic_shapes/test.sh +++ b/samples/dynamic_shapes/test.sh @@ -16,9 +16,11 @@ ROOT_DIR=$(git rev-parse --show-toplevel) BUILD_DIR=${ROOT_DIR}/build-samples ARTIFACTS_DIR=/tmp/iree/colab_artifacts -# 1. Run the notebook to generate `counter.mlir` and `counter_vmvx.vmfb` +# 1. Run the notebook to generate `dynamic_shapes.mlir` and +# `dynamic_shapes_cpu.vmfb` +# TODO(scotttodd): Test pytorch_dynamic_shapes.ipynb instead/also ${ROOT_DIR}/build_tools/testing/run_python_notebook.sh \ - ${ROOT_DIR}/samples/dynamic_shapes/dynamic_shapes.ipynb + ${ROOT_DIR}/samples/dynamic_shapes/tensorflow_dynamic_shapes.ipynb test -f ${ARTIFACTS_DIR}/dynamic_shapes.mlir && echo "dynamic_shapes.mlir exists" # 2. Build the `iree-compile` tool. @@ -28,7 +30,6 @@ cmake --build ${BUILD_DIR} --target iree-compile -- -k 0 # 3. Compile `dynamic_shapes.mlir` using `iree-compile`. ${BUILD_DIR}/tools/iree-compile \ --iree-hal-target-backends=llvm-cpu \ - --iree-input-type=stablehlo \ ${ARTIFACTS_DIR}/dynamic_shapes.mlir -o ${ARTIFACTS_DIR}/dynamic_shapes_cpu.vmfb # 4. Build the `iree_samples_dynamic_shapes` CMake target.