From 404c5ee091626ddf4fa50a569abbcd2f5d5b888d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C3=ABl=20Benesty?= Date: Tue, 28 Dec 2021 23:28:40 +0100 Subject: [PATCH] Add QAT support to more models (#29) * first version of QDQ monkey patching * add Albert, Electra and Distilbert QAT support * add QDQDeberta V1 * fix distilbert * add ast patch add quant onnx export * simplify quantization process * fix qdq deberta * quantization refactoring * add documentation add quantization tests add deberta v2 * add quant of layernorm refactor ast modif add tests * add operator name in quantizer name update notebook * update notebook * update notebook --- README.md | 2 +- VERSION | 2 +- .../quantization_end_to_end.ipynb | 1989 ++++++----------- src/transformer_deploy/QDQModels/QDQAlbert.py | 20 + src/transformer_deploy/QDQModels/QDQBert.py | 20 + .../QDQModels/QDQDeberta.py | 71 + .../QDQModels/QDQDistilbert.py | 20 + .../QDQModels/QDQElectra.py | 21 + .../QDQModels/QDQRoberta.py | 1612 +------------ .../QDQModels/ast_module_patch.py | 196 ++ .../QDQModels/ast_operator_patch.py | 112 + .../QDQModels/calibration_utils.py | 111 + src/transformer_deploy/QDQModels/patch.py | 74 + src/transformer_deploy/backends/ort_utils.py | 48 +- src/transformer_deploy/benchmarks/utils.py | 29 +- src/transformer_deploy/convert.py | 16 +- src/transformer_deploy/utils/args.py | 5 + .../utils/python_tokenizer.py | 15 +- tests/test_ast_modifications.py | 92 + 19 files changed, 1572 insertions(+), 2883 deletions(-) rename demo/{ => quantization}/quantization_end_to_end.ipynb (72%) create mode 100644 src/transformer_deploy/QDQModels/QDQAlbert.py create mode 100644 src/transformer_deploy/QDQModels/QDQBert.py create mode 100644 src/transformer_deploy/QDQModels/QDQDeberta.py create mode 100644 src/transformer_deploy/QDQModels/QDQDistilbert.py create mode 100644 src/transformer_deploy/QDQModels/QDQElectra.py create mode 100644 src/transformer_deploy/QDQModels/ast_module_patch.py create mode 100644 src/transformer_deploy/QDQModels/ast_operator_patch.py create mode 100644 src/transformer_deploy/QDQModels/calibration_utils.py create mode 100644 src/transformer_deploy/QDQModels/patch.py create mode 100644 tests/test_ast_modifications.py diff --git a/README.md b/README.md index 8ae209a0..08f84069 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,7 @@ * [🐍 TensorRT usage in Python script](#tensorrt-usage-in-python-script) * [⏱ benchmarks](#benchmarks) * [πŸ€— end to end reproduction of Infinity Hugging Face demo](./demo/README.md) (to replay [Medium article](https://towardsdatascience.com/hugging-face-transformer-inference-under-1-millisecond-latency-e1be0057a51c?source=friends_link&sk=cd880e05c501c7880f2b9454830b8915)) -* [🏎️ end to end GPU quantization tutorial and many benchmarks (ONNX Runtime, TensorRT, vanilla Pytorch, etc.)](./demo/quantization_end_to_end.ipynb) +* [🏎️ end to end GPU quantization tutorial and many benchmarks (ONNX Runtime, TensorRT, vanilla Pytorch, etc.)](demo/quantization/quantization_end_to_end.ipynb) #### Why this tool? diff --git a/VERSION b/VERSION index 0c62199f..0d91a54c 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -0.2.1 +0.3.0 diff --git a/demo/quantization_end_to_end.ipynb b/demo/quantization/quantization_end_to_end.ipynb similarity index 72% rename from demo/quantization_end_to_end.ipynb rename to demo/quantization/quantization_end_to_end.ipynb index 6e339f81..33fb99ba 100644 --- a/demo/quantization_end_to_end.ipynb +++ b/demo/quantization/quantization_end_to_end.ipynb @@ -4,101 +4,91 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# Method to perform Nvidia GPU INT-8 quantization on most transformers model (encoder based)" + "# Nvidia GPU INT-8 quantization on any transformers model (encoder based)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "\n", - "\n", - "zation is one of the most effective and generic approach to make model inference faster.\n", + "Quantization is one of the most effective and generic approaches to make model inference faster.\n", "Basically, it replaces high precision float numbers in model tensors encoded in 32 or 16 bits by lower precision ones encoded in 8 bits or less:\n", "\n", "* it takes less memory\n", "* computation is easier / faster\n", "\n", - "It can be applied to any model in theory, and, if done well, it should not decrease its accuracy.\n", + "It can be applied to any model in theory, and, if done well, it should maintain accuracy.\n", "\n", - "The purpose of this notebook is to show 2 processes to perform quantization on most `transformer` architectures.\n", + "The purpose of this notebook is to show a process to perform quantization on any `transformer` architectures.\n", + "\n", + "Moreover, the library is designed to offer a simple API and still let advanced users tweak the algorithm.\n", "\n", "**TL;DR, we benchmarked Pytorch and Nvidia TensorRT, on both CPU and GPU, with/without quantization, our methods provide the fastest inference by large margin**.\n", "\n", - "| Framework | Precision | Latency (ms) | Accuracy | Speedup | Hardware |\n", - "|:----------------------------|-----------|--------------|----------|:-----------|:--------:|\n", - "| Pytorch | FP32 | 4000 | 86.8 % | X 0.02 | CPU |\n", - "| Pytorch | FP16 | 4005 | 86.8 % | X 0.02 | CPU |\n", - "| Pytorch | **INT-8** | 3670 | 86.8 % | X 0.02 | **CPU** |\n", - "| Pytorch | FP32 | 80 | 86.8 % | X 1 | GPU |\n", - "| Pytorch | FP16 | 58 | 86.8 % | X 1.38 | GPU |\n", - "| ONNX Runtime | FP32 | 74 | 86.8 % | X 1.08 | GPU |\n", - "| ONNX Runtime | FP16 | 34 | 86.8 % | X 2.35 | GPU |\n", - "| ONNX Runtime | FP32 | 3767 | 86.8 % | X 0.02 | CPU |\n", - "| ONNX Runtime | FP16 | 4607 | 86.8 % | X 0.02 | CPU |\n", - "| ONNX Runtime | **INT-8** | 3712 | 86.8 % | X 0.02 | **CPU** |\n", - "| TensorRT | FP16 | 30 | 86.8 % | X 2.67 | GPU |\n", - "| TensorRT (**our method 1**) | **INT-8** | 15 | 84.4 % | **X 5.33** | **GPU** |\n", - "| TensorRT (**our method 2**) | **INT-8** | 16 | 85.8 % | **X 5.00** | **GPU** |\n", + "| Framework | Precision | Latency (ms) | Accuracy | Speedup | Hardware |\n", + "|:--------------------------|-----------|--------------|----------|:-----------|:--------:|\n", + "| Pytorch | FP32 | 4267 | 86.6 % | X 0.02 | CPU |\n", + "| Pytorch | FP16 | 4428 | 86.6 % | X 0.02 | CPU |\n", + "| Pytorch | INT-8 | 3300 | 85.9 % | X 0.02 | CPU |\n", + "| Pytorch | FP32 | 77 | 86.6 % | X 1 | GPU |\n", + "| Pytorch | FP16 | 56 | 86.6 % | X 1.38 | GPU |\n", + "| ONNX Runtime | FP32 | 76 | 86.6 % | X 1.01 | GPU |\n", + "| ONNX Runtime | FP16 | 34 | 86.6 % | X 2.26 | GPU |\n", + "| ONNX Runtime | FP32 | 4023 | 86.6 % | X 0.02 | CPU |\n", + "| ONNX Runtime | FP16 | 3957 | 86.6 % | X 0.02 | CPU |\n", + "| ONNX Runtime | INT-8 | 3336 | 86.5 % | X 0.02 | CPU |\n", + "| TensorRT | FP16 | 30 | 86.6 % | X 2.57 | GPU |\n", + "| TensorRT (**our method**) | **INT-8** | **17** | 86.2 % | **X 4.53** | **GPU** |\n", "\n", - "> measures done on a Nvidia RTX 3090 GPU + 12 cores i7 Intel CPU (support AVX-2 instructions)\n", + "> measures done on a Nvidia RTX 3090 GPU + 12 cores i7 Intel CPU (support AVX-2 instruction)\n", ">\n", - "> architecture `Roberta-base` with batch of size 32 / seq len 256, similar results obtained for other sizes/seq len not included in the table.\n", + "> `base` architecture flavor with batch of size 32 / seq len 256, similar results obtained for other sizes/seq len not included in the table.\n", ">\n", "> accuracy obtained after a single epoch, no LR search or any hyper parameter optimization\n", - ">\n", - "> CPU measures are a bit unfair, it's still possible to push performance a bit by adding lots of (Python related) complexities and using last generation CPU, still those measurements are indicative of orders of magnitude to expect from Pytorch+CPU deployment.\n", - ">\n", - "> same kind of acceleration is observed on all seq len / batch sizes\n", "\n", "\n", "## A (very) short intro to INT-8 quantization\n", "\n", "Basic idea behind model quantization is to replace tensors made of float numbers (usually encoded on 32 bits) by lower precision representation (integers encoded on 8 bits for Nvidia GPUs).\n", "Therefore computation is faster and model memory footprint is lower. Making tensor storage smaller makes memory transfer faster... and is also a source of computation acceleration.\n", - "This technic is very interesting for its trade-off: you reduce inference time significantly, and when dataset is large enough, it costs close to nothing in accuracy.\n", + "This approach is very interesting for its trade-off: you reduce inference time significantly, and it costs close to nothing in accuracy.\n", "\n", "Replacing float numbers by integers is done through a mapping.\n", - "This step is called `calibration`, and its purpose is to compute for each tensor or each channel of a tensor (one of its dimensions) a range of all possible values and then define a scale and a distribution center to map float numbers to 8 bits integers.\n", - "The process is well described in this [Nvidia presentation](https://on-demand.gputechconf.com/gtc/2017/presentation/s7310-8-bit-inference-with-tensorrt.pdf).\n", + "This step is called `calibration`, and its purpose is to compute for each tensor or each channel of a tensor (one of its dimensions) a range covering most weights and then define a scale and a distribution center to map float numbers to 8 bits integers.\n", "\n", "There are several ways to perform quantization, depending of how and when the `calibration` is performed:\n", "\n", - "* dynamically: the mapping is done during the inference, there are some overhead but it's easy to put in place and usually the accuracy is preserved,\n", - "* statically, after training (`post training quantization` or `PTQ`): this way is efficient, but it may have a significant accuracy cost,\n", - "* statically, before training (`quantization aware training` or `QAT`): this way is efficient and has a low accuracy cost as the weights will take care of the result\n", + "* dynamically: the mapping is done online, during the inference, there are some overhead but it's usually the easiest to leverage, end user has very few configuration to set,\n", + "* statically, after training (`post training quantization` or `PTQ`): this way is efficient because quantization is done offline, before inference, but it may have an accuracy cost,\n", + "* statically, after training (`quantization aware training` or `QAT`): like a PTQ followed by a second fine tuning. Same efficiency but usually slightly better accuracy.\n", + "\n", + "Nvidia GPUs don't support dynamic quantization, CPU supports all types of quantization. \n", + "Compared to `PTQ`, `QAT` better preserves accuracy and should be preferred in most cases.\n", "\n", - "In this guide we will focus on the third option: `QAT`.\n", "\n", "During the quantization aware *training*:\n", "\n", "* in the inside, Pytorch will train with high precision float numbers,\n", "* on the outside, Pytorch will simulate that a quantization has already been applied and output results accordingly (for loss computation for instance)\n", "\n", - "The simulation process is done through the add of quantization / dequantization nodes, most often called `QDQ`, it's an abbreviation you will see often in quantization world.\n", + "The simulation process is done through the add of quantization / dequantization nodes, most often called `QDQ`, it's an abbreviation you will see often in the quantization world.\n", + "\n", + "\n", "\n", - "You can check this [high quality blog post](https://leimao.github.io/article/Neural-Networks-Quantization/) for more information.\n", + "> Want to learn more about quantization?\n", + "> \n", + "> * You can check this [high quality blog post](https://leimao.github.io/article/Neural-Networks-Quantization/) for more information.\n", + "> * The process is well described in this [Nvidia presentation](https://on-demand.gputechconf.com/gtc/2017/presentation/s7310-8-bit-inference-with-tensorrt.pdf)\n", "\n", "## Why this notebook?\n", "\n", "CPU quantization is supported out of the box by `Pytorch` and `ONNX Runtime`.\n", "**GPU quantization on the other side requires specific tools and process to be applied**.\n", "\n", - "In the specific case of `transformer` models, until recently (december 2021), the only way shown by Nvidia is to build manually the graph of our models in `TensorRT`. This is a low level approach, based on GPU capacity knowledge (which operators are supported, etc.). It's certainly out of reach of most NLP practitioners and is very time consuming to update/adapt to new architectures.\n", + "In the specific case of `transformer` models, few demos from Nvidia and Microsoft exist; they are all for the old vanilla Bert architecture.\n", "\n", - "Hopefully, Nvidia added to Hugging Face `transformer` library a new model called `QDQBert` few weeks ago.\n", - "Basically, it's a vanilla `Bert` architecture which supports INT-8 quantization.\n", - "It doesn't support any other architecture out of the box, like `Albert`, `Roberta`, or `Electra`.\n", - "Nvidia also provide a demo dedicated to the SQuaD task.\n", - "\n", - "This open the door to extension of the approach to other architectures.\n", - "\n", - "To be both simple and cover most use cases, in this notebook we will see:\n", - "\n", - "* how to perform GPU quantization on **any** transformer model (not just Bert) using a simple trick, a `transplatation`\n", - "* how to perform GPU quantization on `QDQRoberta`, a custom model similar to `QDQBert` and supported by `transformer-deploy` library\n", - "* how to apply quantization to a common task like classification (which is easier to understand than question answering)\n", - "* measure performance gain (latency)\n" + "It doesn't support modern architectures out of the box, like `Albert`, `Roberta`, `Deberta` or `Electra`.\n", + "\n" ] }, { @@ -114,7 +104,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "We install `master` branch of `transfomers` library to use a new model: **QDQBert** and `transformer-deploy` to leverage `TensorRT` models (TensorRT API is not something simple to master, it's highly advised to use a wrapper). Your machine should have Nvidia CUDA 11.X, TensorRT 8.2.1 and cuBLAS installed. It's said to be tricky to install, in my experience, just follow Nvidia instructions **and nothing else**, it should work out of the box. Docker image with TensorRT 8.2.1 has not yet been released, this notebook will be updated when it's ready." + "Your machine should have Nvidia CUDA 11.X, TensorRT 8.2.1 and cuBLAS installed. It's said to be tricky to install, in my experience, just follow Nvidia download page instructions **and nothing else**, it should work out of the box. Nvidia Docker image could be a good choice too." ] }, { @@ -125,12 +115,9 @@ }, "outputs": [], "source": [ - "#! pip install git+https://github.com/huggingface/transformers\n", - "#! pip install git+https://github.com/ELS-RD/transformer-deploy\n", - "#! pip install sklearn datasets\n", - "#! pip install pytorch-quantization --extra-index-url https://pypi.ngc.nvidia.com\n", - "# or install pytorch-quantization from https://github.com/NVIDIA/TensorRT/tree/master/tools/pytorch-quantization\n", - "# pip3 install git+ssh://git@github.com/NVIDIA/TensorRT#egg=pytorch-quantization\\&subdirectory=tools/pytorch-quantization/" + "#! pip3 install git+ssh://git@github.com/ELS-RD/transformer-deploy\n", + "#! pip3 install datasets sklearn\n", + "#! pip3 install git+ssh://git@github.com/NVIDIA/TensorRT#egg=pytorch-quantization\\&subdirectory=tools/pytorch-quantization/" ] }, { @@ -155,16 +142,16 @@ "name": "stdout", "output_type": "stream", "text": [ - "Thu Dec 9 19:47:47 2021 \r\n", + "Tue Dec 28 21:46:26 2021 \r\n", "+-----------------------------------------------------------------------------+\r\n", - "| NVIDIA-SMI 495.29.05 Driver Version: 495.29.05 CUDA Version: 11.5 |\r\n", + "| NVIDIA-SMI 495.44 Driver Version: 495.44 CUDA Version: 11.5 |\r\n", "|-------------------------------+----------------------+----------------------+\r\n", "| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\r\n", "| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\r\n", "| | | MIG M. |\r\n", "|===============================+======================+======================|\r\n", - "| 0 NVIDIA GeForce ... On | 00000000:03:00.0 On | N/A |\r\n", - "| 67% 55C P8 45W / 350W | 286MiB / 24267MiB | 6% Default |\r\n", + "| 0 NVIDIA GeForce ... Off | 00000000:03:00.0 On | N/A |\r\n", + "| 35% 42C P8 40W / 350W | 263MiB / 24267MiB | 4% Default |\r\n", "| | | N/A |\r\n", "+-------------------------------+----------------------+----------------------+\r\n", " \r\n", @@ -173,11 +160,10 @@ "| GPU GI CI PID Type Process name GPU Memory |\r\n", "| ID ID Usage |\r\n", "|=============================================================================|\r\n", - "| 0 N/A N/A 1944 G /usr/lib/xorg/Xorg 148MiB |\r\n", - "| 0 N/A N/A 7816 G /usr/bin/gnome-shell 40MiB |\r\n", - "| 0 N/A N/A 529613 G ...518105.log --shared-files 13MiB |\r\n", - "| 0 N/A N/A 540908 G ...AAAAAAAAA= --shared-files 49MiB |\r\n", - "| 0 N/A N/A 1378576 G ...AAAAAAAAA= --shared-files 31MiB |\r\n", + "| 0 N/A N/A 1604 G /usr/lib/xorg/Xorg 159MiB |\r\n", + "| 0 N/A N/A 8473 G /usr/bin/gnome-shell 44MiB |\r\n", + "| 0 N/A N/A 106329 G ..._18576.log --shared-files 17MiB |\r\n", + "| 0 N/A N/A 110356 G ...AAAAAAAAA= --shared-files 39MiB |\r\n", "+-----------------------------------------------------------------------------+\r\n" ] } @@ -223,46 +209,51 @@ }, "outputs": [], "source": [ - "import numpy as np\n", - "from tqdm.notebook import tqdm\n", - "import transformers\n", + "import logging\n", + "import os\n", + "from collections import OrderedDict\n", + "from typing import Dict, List\n", + "from typing import OrderedDict as OD\n", + "from typing import Union\n", + "\n", "import datasets\n", - "from typing import OrderedDict as OD, List, Dict, Union\n", + "import numpy as np\n", + "import pycuda.autoinit\n", + "import tensorrt as trt\n", "import torch\n", - "from torch import Tensor\n", + "import transformers\n", + "from datasets import load_dataset, load_metric\n", + "from pycuda._driver import Stream\n", + "from tensorrt.tensorrt import IExecutionContext, Logger, Runtime\n", + "from pytorch_quantization import nn as quant_nn\n", + "\n", "from transformers import (\n", " AutoModelForSequenceClassification,\n", - " PreTrainedModel,\n", - " QDQBertForSequenceClassification,\n", - " BertForSequenceClassification,\n", - " TrainingArguments,\n", - " Trainer,\n", - " IntervalStrategy,\n", " AutoTokenizer,\n", + " IntervalStrategy,\n", + " PreTrainedModel,\n", " PreTrainedTokenizer,\n", + " Trainer,\n", + " TrainingArguments,\n", + ")\n", + "\n", + "from transformer_deploy.backends.ort_utils import (\n", + " convert_to_onnx,\n", + " convert_to_quant_onnx,\n", + " cpu_quantization,\n", + " create_model_for_provider,\n", + " optimize_onnx,\n", ")\n", - "from datasets import load_dataset, load_metric\n", - "from transformer_deploy.QDQModels.QDQRoberta import QDQRobertaForSequenceClassification\n", - "import pytorch_quantization.nn as quant_nn\n", - "from pytorch_quantization.tensor_quant import QuantDescriptor\n", - "from pytorch_quantization import calib\n", - "import logging\n", - "from datasets import DatasetDict\n", "from transformer_deploy.backends.trt_utils import build_engine, get_binding_idxs, infer_tensorrt\n", - "from transformer_deploy.backends.ort_utils import convert_to_onnx\n", - "from collections import OrderedDict\n", - "from transformer_deploy.benchmarks.utils import track_infer_time, print_timings\n", - "from pycuda._driver import Stream\n", - "import tensorrt as trt\n", - "from tensorrt.tensorrt import IExecutionContext, Logger, Runtime\n", - "import pycuda.autoinit" + "from transformer_deploy.benchmarks.utils import print_timings, track_infer_time\n", + "from transformer_deploy.QDQModels.calibration_utils import QATCalibrate" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Set logging to `error` to make the `notebook` more readable on Github." + "Set logging to `error` level to ease readability of this `notebook` on Github." ] }, { @@ -287,14 +278,16 @@ "id": "rEJBSTyZIrIb" }, "source": [ - "### Download data" + "### Preprocess data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "This part is inspired from an [official Notebooks from Hugging Face](https://github.com/huggingface/notebooks/blob/master/examples/text_classification.ipynb)." + "This part is inspired from an [official Notebooks from Hugging Face](https://github.com/huggingface/notebooks/blob/master/examples/text_classification.ipynb).\n", + "\n", + "There is nothing special to do. Define the task:" ] }, { @@ -305,13 +298,15 @@ }, "outputs": [], "source": [ + "model_name = \"roberta-base\"\n", "task = \"mnli\"\n", "num_labels = 3\n", - "model_checkpoint = \"roberta-base\"\n", "batch_size = 32\n", "max_seq_len = 256\n", "validation_key = \"validation_matched\"\n", - "timings: Dict[str, List[float]] = dict()" + "timings: Dict[str, List[float]] = dict()\n", + "runtime: Runtime = trt.Runtime(trt_logger)\n", + "profile_index = 0" ] }, { @@ -320,7 +315,7 @@ "id": "W7QYTpxXIrIl" }, "source": [ - "We will use the [πŸ€— Datasets](https://github.com/huggingface/datasets) library to download the data and get the metric we need to use for evaluation (to compare our model to the benchmark)." + "Preprocess data (task specific):" ] }, { @@ -329,233 +324,14 @@ "metadata": { "id": "IreSlFmlIrIm" }, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "ff977d13b16d44ddbf9536d565248621", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " 0%| | 0/5 [00:00 PreTrainedModel:\n", - " # Find the TensorQuantizer and enable calibration\n", - " for name, module in model.named_modules():\n", - " if isinstance(module, quant_nn.TensorQuantizer):\n", - " if module._calibrator is not None:\n", - " module.disable_quant()\n", - " module.enable_calib()\n", - " else:\n", - " module.disable()\n", - "\n", - " with torch.no_grad():\n", - " for start_index in tqdm(range(0, nb_sample, batch_size)):\n", - " end_index = start_index + batch_size\n", - " data = encoded_dataset[\"train\"][start_index:end_index]\n", - " input_torch = {\n", - " k: torch.tensor(v, dtype=torch.long, device=\"cpu\")\n", - " for k, v in data.items()\n", - " if k in [\"input_ids\", \"attention_mask\", \"token_type_ids\"]\n", - " }\n", - " model(**input_torch)\n", - "\n", - " # Finalize calibration\n", - " for name, module in model.named_modules():\n", - " if isinstance(module, quant_nn.TensorQuantizer):\n", - " if module._calibrator is not None:\n", - " if isinstance(module._calibrator, calib.MaxCalibrator):\n", - " module.load_calib_amax()\n", - " else:\n", - " module.load_calib_amax(\"percentile\", percentile=99.99)\n", - " module.enable_quant()\n", - " module.disable_calib()\n", - " else:\n", - " module.enable()\n", - "\n", - " model.cuda()\n", - " return model\n", - "\n", - "\n", "def convert_tensor(data: OD[str, List[List[int]]], output: str) -> OD[str, Union[np.ndarray, torch.Tensor]]:\n", " input: OD[str, Union[np.ndarray, torch.Tensor]] = OrderedDict()\n", " for k in [\"input_ids\", \"attention_mask\", \"token_type_ids\"]:\n", @@ -615,45 +353,12 @@ " else:\n", " raise Exception(f\"unknown output type: {output}\")\n", " input[k] = value\n", - " return input" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Some `TensorRT` reused variables:" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [], - "source": [ - "runtime: Runtime = trt.Runtime(trt_logger)\n", - "profile_index = 0" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Measure accuracy for ONNX Runtime and TensorRT:" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [], - "source": [ - "validation_labels = [item[\"label\"] for item in encoded_dataset[validation_key]]\n", + " return input\n", "\n", "\n", "def measure_accuracy(infer, int64: bool) -> float:\n", " outputs = list()\n", - " for start_index in tqdm(range(0, len(encoded_dataset[validation_key]), batch_size)):\n", + " for start_index in range(0, len(encoded_dataset[validation_key]), batch_size):\n", " end_index = start_index + batch_size\n", " data = encoded_dataset[validation_key][start_index:end_index]\n", " inputs: OD[str, np.ndarray] = convert_tensor(data=data, output=\"np\")\n", @@ -663,30 +368,59 @@ " output = infer(inputs)\n", " output = np.argmax(output[0], axis=1).astype(int).tolist()\n", " outputs.extend(output)\n", - " return np.mean(np.array(outputs) == np.array(validation_labels))" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Fine-tuning model\n", + " return np.mean(np.array(outputs) == np.array(validation_labels))\n", "\n", - "Now that our data are ready, we can download the pretrained model and fine-tune it.\n", "\n", - "Default parameters to be used for the training:" + "def get_trainer(model: PreTrainedModel) -> Trainer:\n", + " trainer = Trainer(\n", + " model,\n", + " args,\n", + " train_dataset=encoded_dataset[\"train\"],\n", + " eval_dataset=encoded_dataset[validation_key],\n", + " tokenizer=tokenizer,\n", + " compute_metrics=compute_metrics,\n", + " )\n", + " transformers.logging.set_verbosity_error()\n", + " return trainer" ] }, { "cell_type": "code", - "execution_count": 13, - "metadata": {}, - "outputs": [], + "execution_count": 7, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "67a7277cd56f4f4ab5746977446e740a", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/5 [00:00 Bert, there is 1 to 1 correspondance, for other models, you may need to create your own mapping.\n", - "for bert_key in bert_keys:\n", - " # pop remove the first weights from the Ordered dict ...\n", - " _, weight = model_weights.popitem(last=False)\n", - " # ... and we re-insert them, in order, with a new key\n", - " model_weights[bert_key] = weight\n", - "\n", - "# we re-export the weights\n", - "torch.save(model_weights, \"roberta-in-bert/pytorch_model.bin\")\n", - "del model_weights" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We override the architecture name to make `transformers` believe it is `Bert`..." - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": {}, - "outputs": [], - "source": [ - "# =====> change architecture to bert base <======\n", - "import json\n", - "\n", - "with open(\"roberta-in-bert/config.json\") as f:\n", - " content = json.load(f)\n", - " content[\"architectures\"] = [\"bert\"]\n", - "\n", - "with open(\"roberta-in-bert/config.json\", mode=\"w\") as f:\n", - " json.dump(content, f)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Model training\n", - "\n", - "The goal of this first training is to update weights to the new architecture to help the next step, the calibration.\n", - "Indeed, `Roberta` architecture is a bit different from vanilla `Bert`, for instance position embeddings are not managed the same way, as they are at the very bottom of the model, they impact all model layers.\n", - "If we skip this step, the value ranges computed during the calibration step may be very wrong and the `QAT` would provide low accuracy score." - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "[INFO|trainer.py:437] 2021-12-09 19:48:11,412 >> Using amp half precision backend\n" + "[INFO|trainer.py:439] 2021-12-27 09:19:51,063 >> Using amp half precision backend\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "{'loss': 0.7474, 'learning_rate': 9.1875814863103e-06, 'epoch': 0.08}\n", - "{'eval_loss': 0.5083153247833252, 'eval_accuracy': 0.8007131940906775, 'eval_runtime': 18.9412, 'eval_samples_per_second': 518.182, 'eval_steps_per_second': 8.13, 'epoch': 0.08}\n", - "{'loss': 0.5457, 'learning_rate': 8.372718383311604e-06, 'epoch': 0.16}\n", - "{'eval_loss': 0.47982481122016907, 'eval_accuracy': 0.8163015792154865, 'eval_runtime': 19.417, 'eval_samples_per_second': 505.486, 'eval_steps_per_second': 7.931, 'epoch': 0.16}\n", - "{'loss': 0.5075, 'learning_rate': 7.557855280312908e-06, 'epoch': 0.24}\n", - "{'eval_loss': 0.4634084105491638, 'eval_accuracy': 0.8218033622007132, 'eval_runtime': 19.6824, 'eval_samples_per_second': 498.67, 'eval_steps_per_second': 7.824, 'epoch': 0.24}\n", - "{'loss': 0.483, 'learning_rate': 6.743807040417211e-06, 'epoch': 0.33}\n", - "{'eval_loss': 0.42370399832725525, 'eval_accuracy': 0.8344370860927153, 'eval_runtime': 19.5405, 'eval_samples_per_second': 502.29, 'eval_steps_per_second': 7.881, 'epoch': 0.33}\n", - "{'loss': 0.4652, 'learning_rate': 5.9289439374185145e-06, 'epoch': 0.41}\n", - "{'eval_loss': 0.4242672026157379, 'eval_accuracy': 0.8365766683647479, 'eval_runtime': 18.8626, 'eval_samples_per_second': 520.343, 'eval_steps_per_second': 8.164, 'epoch': 0.41}\n", - "{'loss': 0.451, 'learning_rate': 5.114080834419818e-06, 'epoch': 0.49}\n", - "{'eval_loss': 0.42570847272872925, 'eval_accuracy': 0.8365766683647479, 'eval_runtime': 18.9979, 'eval_samples_per_second': 516.636, 'eval_steps_per_second': 8.106, 'epoch': 0.49}\n", - "{'loss': 0.4505, 'learning_rate': 4.299217731421121e-06, 'epoch': 0.57}\n", - "{'eval_loss': 0.4005618989467621, 'eval_accuracy': 0.8451349974528782, 'eval_runtime': 18.8952, 'eval_samples_per_second': 519.443, 'eval_steps_per_second': 8.15, 'epoch': 0.57}\n", - "{'loss': 0.4422, 'learning_rate': 3.4859843546284226e-06, 'epoch': 0.65}\n", - "{'eval_loss': 0.3936935365200043, 'eval_accuracy': 0.8445236882322975, 'eval_runtime': 18.9752, 'eval_samples_per_second': 517.253, 'eval_steps_per_second': 8.116, 'epoch': 0.65}\n", - "{'loss': 0.4332, 'learning_rate': 2.6711212516297265e-06, 'epoch': 0.73}\n", - "{'eval_loss': 0.3954601585865021, 'eval_accuracy': 0.8470708099847173, 'eval_runtime': 18.7912, 'eval_samples_per_second': 522.32, 'eval_steps_per_second': 8.195, 'epoch': 0.73}\n", - "{'loss': 0.4254, 'learning_rate': 1.8570730117340288e-06, 'epoch': 0.81}\n", - "{'eval_loss': 0.39208605885505676, 'eval_accuracy': 0.8492103922567499, 'eval_runtime': 19.1256, 'eval_samples_per_second': 513.185, 'eval_steps_per_second': 8.052, 'epoch': 0.81}\n", - "{'loss': 0.4263, 'learning_rate': 1.0422099087353325e-06, 'epoch': 0.9}\n", - "{'eval_loss': 0.384700208902359, 'eval_accuracy': 0.851044319918492, 'eval_runtime': 19.1997, 'eval_samples_per_second': 511.207, 'eval_steps_per_second': 8.021, 'epoch': 0.9}\n", - "{'loss': 0.4263, 'learning_rate': 2.2816166883963498e-07, 'epoch': 0.98}\n", - "{'eval_loss': 0.38375720381736755, 'eval_accuracy': 0.8508405501782985, 'eval_runtime': 18.5961, 'eval_samples_per_second': 527.8, 'eval_steps_per_second': 8.281, 'epoch': 0.98}\n", - "{'train_runtime': 2701.3089, 'train_samples_per_second': 145.375, 'train_steps_per_second': 4.543, 'train_loss': 0.48233796906751014, 'epoch': 1.0}\n", - "{'eval_loss': 0.384700208902359, 'eval_accuracy': 0.851044319918492, 'eval_runtime': 18.5921, 'eval_samples_per_second': 527.911, 'eval_steps_per_second': 8.283, 'epoch': 1.0}\n", - "{'eval_loss': 0.384700208902359, 'eval_accuracy': 0.851044319918492, 'eval_runtime': 18.5921, 'eval_samples_per_second': 527.911, 'eval_steps_per_second': 8.283, 'epoch': 1.0}\n" + "{'loss': 0.6605, 'learning_rate': 9.1875814863103e-06, 'epoch': 0.08}\n", + "{'eval_loss': 0.4653007388114929, 'eval_accuracy': 0.8183392766174223, 'eval_runtime': 18.2981, 'eval_samples_per_second': 536.393, 'eval_steps_per_second': 8.416, 'epoch': 0.08}\n", + "{'loss': 0.4956, 'learning_rate': 8.372718383311604e-06, 'epoch': 0.16}\n", + "{'eval_loss': 0.4208127558231354, 'eval_accuracy': 0.8346408558329088, 'eval_runtime': 18.3709, 'eval_samples_per_second': 534.268, 'eval_steps_per_second': 8.383, 'epoch': 0.16}\n", + "{'loss': 0.4662, 'learning_rate': 7.557855280312908e-06, 'epoch': 0.24}\n", + "{'eval_loss': 0.42171549797058105, 'eval_accuracy': 0.8358634742740703, 'eval_runtime': 18.3642, 'eval_samples_per_second': 534.464, 'eval_steps_per_second': 8.386, 'epoch': 0.24}\n", + "{'loss': 0.4458, 'learning_rate': 6.7429921773142115e-06, 'epoch': 0.33}\n", + "{'eval_loss': 0.3808833658695221, 'eval_accuracy': 0.8527763627101376, 'eval_runtime': 18.3578, 'eval_samples_per_second': 534.649, 'eval_steps_per_second': 8.389, 'epoch': 0.33}\n", + "{'loss': 0.4295, 'learning_rate': 5.9289439374185145e-06, 'epoch': 0.41}\n", + "{'eval_loss': 0.383415549993515, 'eval_accuracy': 0.851044319918492, 'eval_runtime': 18.3946, 'eval_samples_per_second': 533.58, 'eval_steps_per_second': 8.372, 'epoch': 0.41}\n", + "{'loss': 0.4193, 'learning_rate': 5.1148956975228174e-06, 'epoch': 0.49}\n", + "{'eval_loss': 0.3880891799926758, 'eval_accuracy': 0.8494141619969434, 'eval_runtime': 18.4347, 'eval_samples_per_second': 532.418, 'eval_steps_per_second': 8.354, 'epoch': 0.49}\n", + "{'loss': 0.4166, 'learning_rate': 4.30003259452412e-06, 'epoch': 0.57}\n", + "{'eval_loss': 0.3630894124507904, 'eval_accuracy': 0.8582781456953642, 'eval_runtime': 18.5126, 'eval_samples_per_second': 530.181, 'eval_steps_per_second': 8.319, 'epoch': 0.57}\n", + "{'loss': 0.4111, 'learning_rate': 3.4851694915254244e-06, 'epoch': 0.65}\n", + "{'eval_loss': 0.3584975004196167, 'eval_accuracy': 0.8596026490066225, 'eval_runtime': 18.4771, 'eval_samples_per_second': 531.198, 'eval_steps_per_second': 8.335, 'epoch': 0.65}\n", + "{'loss': 0.4002, 'learning_rate': 2.6711212516297265e-06, 'epoch': 0.73}\n", + "{'eval_loss': 0.36166584491729736, 'eval_accuracy': 0.8625573102394295, 'eval_runtime': 18.489, 'eval_samples_per_second': 530.857, 'eval_steps_per_second': 8.329, 'epoch': 0.73}\n", + "{'loss': 0.3938, 'learning_rate': 1.8562581486310302e-06, 'epoch': 0.81}\n", + "{'eval_loss': 0.354215145111084, 'eval_accuracy': 0.8649006622516556, 'eval_runtime': 18.4614, 'eval_samples_per_second': 531.651, 'eval_steps_per_second': 8.342, 'epoch': 0.81}\n", + "{'loss': 0.3951, 'learning_rate': 1.0413950456323338e-06, 'epoch': 0.9}\n", + "{'eval_loss': 0.3511120676994324, 'eval_accuracy': 0.8663270504330107, 'eval_runtime': 18.512, 'eval_samples_per_second': 530.197, 'eval_steps_per_second': 8.319, 'epoch': 0.9}\n", + "{'loss': 0.3972, 'learning_rate': 2.265319426336376e-07, 'epoch': 0.98}\n", + "{'eval_loss': 0.34958672523498535, 'eval_accuracy': 0.8661232806928171, 'eval_runtime': 18.4826, 'eval_samples_per_second': 531.04, 'eval_steps_per_second': 8.332, 'epoch': 0.98}\n", + "{'train_runtime': 2606.7651, 'train_samples_per_second': 150.647, 'train_steps_per_second': 4.708, 'train_loss': 0.44322973124517767, 'epoch': 1.0}\n", + "{'eval_loss': 0.3511120676994324, 'eval_accuracy': 0.8663270504330107, 'eval_runtime': 18.5143, 'eval_samples_per_second': 530.13, 'eval_steps_per_second': 8.318, 'epoch': 1.0}\n", + "{'eval_loss': 0.3511120676994324, 'eval_accuracy': 0.8663270504330107, 'eval_runtime': 18.5143, 'eval_samples_per_second': 530.13, 'eval_steps_per_second': 8.318, 'epoch': 1.0}\n" ] } ], "source": [ - "transformers.logging.set_verbosity_error()\n", - "model_bert = BertForSequenceClassification.from_pretrained(\"roberta-in-bert\", num_labels=num_labels)\n", - "model_bert = model_bert.cuda()\n", - "\n", - "trainer = Trainer(\n", - " model_bert,\n", - " args,\n", - " train_dataset=encoded_dataset[\"train\"],\n", - " eval_dataset=encoded_dataset[validation_key],\n", - " tokenizer=tokenizer,\n", - " compute_metrics=compute_metrics,\n", - ")\n", + "model_fp16: PreTrainedModel = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=num_labels)\n", + "trainer = get_trainer(model_fp16)\n", "transformers.logging.set_verbosity_error()\n", "trainer.train()\n", "print(trainer.evaluate())\n", - "model_bert.save_pretrained(\"roberta-in-bert-trained\")\n", - "del trainer\n", - "del model_bert" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Quantization" + "\n", + "model_fp16.save_pretrained(\"model_trained_fp16\")" ] }, { "cell_type": "markdown", - "metadata": {}, + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, "source": [ - "Below we will start the quantization process.\n", - "It follow those steps:\n", + "## Add quantization support to any model\n", "\n", - "* perform the calibration\n", - "* perform a quantization aware training\n", + "The idea is to take the source code of a specific model and add automatically `QDQ` nodes. QDQ nodes will be placed before and after an operation that we want to quantize, that’s inside these nodes that the information to perform the mapping between high precision and low precision number is stored.\n", "\n", - "By passing validation values to the model, we will calibrate it, meaning it will get the right range / scale to convert FP32 weights to int-8 ones." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Calibration" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Activate histogram calibration\n", + "That way, quantization will work out of the box for the final user.\n", "\n", - "There are several kinds of calbrators, below we use the percentile one (99.99p) (`histogram`), basically, its purpose is to just remove the most extreme values before computing range / scale.\n", - "The other option in NLP is `max`, it's much faster but expect lower accuracy.\n", + "The process is based on Python AST modification, basically we parse the model source code in RAM, we convert it to a tree, then we patch the tree to add the QDQ nodes and we replace, still in RAM, the original module source code. Our library also offer the option to restore original behavior.\n", "\n", - "Second calibration option, choose between calibration done at the tensor level or per channel (finer grained value ranges, a bit slower).\n", - "Calibration is based on few samples (in our case 128 sequences)." - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "metadata": {}, - "outputs": [], - "source": [ - "# you can also use \"max\" instead of \"historgram\"\n", - "input_desc = QuantDescriptor(num_bits=8, calib_method=\"histogram\")\n", - "# below we do per-channel quantization for weights, set axis to None to get a per tensor calibration\n", - "weight_desc = QuantDescriptor(num_bits=8, axis=(0,))\n", - "quant_nn.QuantLinear.set_default_quant_desc_input(input_desc)\n", - "quant_nn.QuantLinear.set_default_quant_desc_weight(weight_desc)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Perform calibration\n", + "In theory it works for any model. However, not related to quantization, some models are not fully compliant with `TensorRT` (unsupported operators, etc.).\n", + "For those models, we rewrite some part of the source code, these patches are manually written but are applied to the model at run time (like the AST manipulation).\n", "\n", - "During this step we will enable the calibration nodes, and pass some representative data to the model.\n", - "It will then be used to compute the scale/range.\n", + "> concrete examples on `Roberta` architecture: in HF library, there is a `cumsum` operator used during the position embedding generation. Something very simple. It takes as input an integer tensor and output an integer tensor. It happens that the `cumsum` operator from TensorRT supports float but not integer (https://github.com/onnx/onnx-tensorrt/blob/master/docs/operators.md). It leads to a crash during the model conversion with a strange error message. Converting the input to float tensor fixes the issue. \n", "\n", - "Official recommendations from Nvidia is to calibrate over thousands of examples from the validation set.\n", - "Here we use 128 examples because it's a slow process. It's enough to be close from the original accuracy." - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "metadata": { - "scrolled": true - }, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "776c4daba3a04b34a72210697ad37e6a", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " 0%| | 0/4 [00:00 there are many ways to get a QDQ model, you can modify Pytorch source code (including doing it at runtime like here), patch ONNX graph (this approach is used at Microsoft for instance but only support PTQ, not QAT as ONNX file can't be trained on Pytorch for now) or leverage the new FX Pytorch interface (it's a bit experimental and it seems to miss some feature to support Nvidia QAT library). Modifying the source code is the most straightforward, and doing it through AST is the least intrusive (no need to duplicate the work of HF).\n", + "\n", + "### Post Training Quantization (PTQ)\n", + "A PTQ is basically a fine tuned model where we add quantization nodes and that we calibrate.\n", + "\n", + "Calibration is a key step in the static quantization process. Its quality depends on the final accuracy (the inference speed will stay the same). \n", + "Moreover, a good PTQ is a good basis for a good Quantization Aware Training (QAT).\n", + "\n", + "By calling `with QATCalibrate(...) as qat:`, the lib will patch transformer model AST (source code) in RAM, basically adding quantization support to each model.\n", + "\n", + "#### Calibration percentile grid search\n", + "\n", + "One of the things we try to guess during the calibration is what range of tensor values capture most of the information stored in the tensor. Indeed, a FP32 tensor can store at the same time very large and very small values, we obviously can't do the same with a 8-bits integer tensors and a scale. An 8-bits integer can only encode 255 values so we need to fix some limits and say, if a value is outside our limits, it just takes a maximum value instead of its real one. For instance, if we say our range is -1000 to +1000 and a tensor contains the value +4000, it will be replaced by the maximum value, +1000.\n", + "\n", + "As said before, we will use the histogram method to find the perfect range. We also need to choose a percentile. Usually, you will choose something very close to 100.\n", + "\n", + "If the percentile is too small, we put too many values outside the covered range. Values outside the range will be replaced by a single maximum value and you lose some granularity in model weights.\n", "\n", - "The query aware training is not a mandatory step, but **highly** recommended to get the best accuracy. Basically we will redo the training with the quantization enabled and a low learning rate to avoid overfitting." + "If the percentile is too big, your range will be very large and because 8-bits signed integers can only encode values between -127 to +127, even when you use a scale you lose in granularity.\n", + "\n", + "Therefore, we launch a grid search on percentile hyper parameter.\n" ] }, { "cell_type": "code", "execution_count": 19, "metadata": { - "id": "imY1oC3SIrJf" + "pycharm": { + "name": "#%%\n" + } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "[INFO|trainer.py:437] 2021-12-09 20:39:35,717 >> Using amp half precision backend\n" + "[INFO|trainer.py:439] 2021-12-27 17:25:51,070 >> Using amp half precision backend\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "{'eval_loss': 0.4307818114757538, 'eval_accuracy': 0.8327050433010698, 'eval_runtime': 45.8394, 'eval_samples_per_second': 214.117, 'eval_steps_per_second': 3.36}\n", - "{'eval_loss': 0.4307818114757538, 'eval_accuracy': 0.8327050433010698, 'eval_runtime': 45.8394, 'eval_samples_per_second': 214.117, 'eval_steps_per_second': 3.36}\n", - "{'loss': 0.4542, 'learning_rate': 9.187581486310299e-07, 'epoch': 0.08}\n", - "{'eval_loss': 0.4258573651313782, 'eval_accuracy': 0.8376974019358125, 'eval_runtime': 46.5114, 'eval_samples_per_second': 211.023, 'eval_steps_per_second': 3.311, 'epoch': 0.08}\n", - "{'loss': 0.4422, 'learning_rate': 8.372718383311604e-07, 'epoch': 0.16}\n", - "{'eval_loss': 0.4214017987251282, 'eval_accuracy': 0.8395313295975547, 'eval_runtime': 46.5334, 'eval_samples_per_second': 210.924, 'eval_steps_per_second': 3.309, 'epoch': 0.16}\n", - "{'loss': 0.4268, 'learning_rate': 7.557855280312907e-07, 'epoch': 0.24}\n", - "{'eval_loss': 0.41808152198791504, 'eval_accuracy': 0.8423841059602649, 'eval_runtime': 46.5578, 'eval_samples_per_second': 210.813, 'eval_steps_per_second': 3.308, 'epoch': 0.24}\n", - "{'loss': 0.4223, 'learning_rate': 6.742992177314211e-07, 'epoch': 0.33}\n", - "{'eval_loss': 0.42257484793663025, 'eval_accuracy': 0.838920020376974, 'eval_runtime': 46.599, 'eval_samples_per_second': 210.627, 'eval_steps_per_second': 3.305, 'epoch': 0.33}\n", - "{'loss': 0.4282, 'learning_rate': 5.928943937418513e-07, 'epoch': 0.41}\n", - "{'eval_loss': 0.41137370467185974, 'eval_accuracy': 0.8409577177789098, 'eval_runtime': 46.5435, 'eval_samples_per_second': 210.878, 'eval_steps_per_second': 3.309, 'epoch': 0.41}\n", - "{'loss': 0.4234, 'learning_rate': 5.114895697522817e-07, 'epoch': 0.49}\n", - "{'eval_loss': 0.41403627395629883, 'eval_accuracy': 0.841874681609781, 'eval_runtime': 46.5743, 'eval_samples_per_second': 210.739, 'eval_steps_per_second': 3.307, 'epoch': 0.49}\n", - "{'loss': 0.4202, 'learning_rate': 4.3000325945241197e-07, 'epoch': 0.57}\n", - "{'eval_loss': 0.4175918698310852, 'eval_accuracy': 0.8417727967396842, 'eval_runtime': 46.5654, 'eval_samples_per_second': 210.779, 'eval_steps_per_second': 3.307, 'epoch': 0.57}\n", - "{'loss': 0.4289, 'learning_rate': 3.4859843546284223e-07, 'epoch': 0.65}\n", - "{'eval_loss': 0.41122177243232727, 'eval_accuracy': 0.8441161487519103, 'eval_runtime': 46.5376, 'eval_samples_per_second': 210.905, 'eval_steps_per_second': 3.309, 'epoch': 0.65}\n", - "{'loss': 0.4283, 'learning_rate': 2.6711212516297263e-07, 'epoch': 0.73}\n", - "{'eval_loss': 0.4128565490245819, 'eval_accuracy': 0.8426897605705552, 'eval_runtime': 46.7511, 'eval_samples_per_second': 209.942, 'eval_steps_per_second': 3.294, 'epoch': 0.73}\n", - "{'loss': 0.4174, 'learning_rate': 1.8570730117340285e-07, 'epoch': 0.81}\n", - "{'eval_loss': 0.40628135204315186, 'eval_accuracy': 0.8441161487519103, 'eval_runtime': 46.6972, 'eval_samples_per_second': 210.184, 'eval_steps_per_second': 3.298, 'epoch': 0.81}\n", - "{'loss': 0.4191, 'learning_rate': 1.0422099087353324e-07, 'epoch': 0.9}\n", - "{'eval_loss': 0.41109439730644226, 'eval_accuracy': 0.8451349974528782, 'eval_runtime': 46.5565, 'eval_samples_per_second': 210.819, 'eval_steps_per_second': 3.308, 'epoch': 0.9}\n", - "{'loss': 0.4178, 'learning_rate': 2.2734680573663624e-08, 'epoch': 0.98}\n", - "{'eval_loss': 0.409859836101532, 'eval_accuracy': 0.8464595007641366, 'eval_runtime': 46.5572, 'eval_samples_per_second': 210.816, 'eval_steps_per_second': 3.308, 'epoch': 0.98}\n", - "{'train_runtime': 4943.333, 'train_samples_per_second': 79.441, 'train_steps_per_second': 2.483, 'train_loss': 0.4271986904169155, 'epoch': 1.0}\n", - "{'eval_loss': 0.409859836101532, 'eval_accuracy': 0.8464595007641366, 'eval_runtime': 46.5126, 'eval_samples_per_second': 211.018, 'eval_steps_per_second': 3.311, 'epoch': 1.0}\n", - "{'eval_loss': 0.409859836101532, 'eval_accuracy': 0.8464595007641366, 'eval_runtime': 46.5126, 'eval_samples_per_second': 211.018, 'eval_steps_per_second': 3.311, 'epoch': 1.0}\n" + "percentile: 99.9\n", + "{'eval_loss': 0.47421666979789734, 'eval_accuracy': 0.8121242995415181, 'eval_runtime': 47.9158, 'eval_samples_per_second': 204.839, 'eval_steps_per_second': 3.214}\n", + "{'eval_loss': 0.47421666979789734, 'eval_accuracy': 0.8121242995415181, 'eval_runtime': 47.9158, 'eval_samples_per_second': 204.839, 'eval_steps_per_second': 3.214}\n" ] - } - ], - "source": [ - "model_q = QDQBertForSequenceClassification.from_pretrained(\"roberta-in-bert-trained-quantized\", num_labels=num_labels)\n", - "model_q = model_q.cuda()\n", - "\n", - "args.learning_rate = 1e-6\n", - "\n", - "trainer = Trainer(\n", - " model_q,\n", - " args,\n", - " train_dataset=encoded_dataset[\"train\"],\n", - " eval_dataset=encoded_dataset[validation_key],\n", - " tokenizer=tokenizer,\n", - " compute_metrics=compute_metrics,\n", - ")\n", - "transformers.logging.set_verbosity_error()\n", - "print(trainer.evaluate())\n", - "trainer.train()\n", - "print(trainer.evaluate())\n", - "model_q.save_pretrained(\"roberta-in-bert-trained-quantized-bis\")\n", - "del model_q\n", - "del trainer" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Benchmark" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Export a `QDQ Pytorch` model on `ONNX`, we need to enable fake quantization mode from Pytorch." - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "metadata": {}, - "outputs": [ + }, { "name": "stderr", "output_type": "stream", "text": [ - "/home/geantvert/.local/share/virtualenvs/fast_transformer/lib/python3.9/site-packages/pytorch_quantization/nn/modules/tensor_quantizer.py:285: TracerWarning: Converting a tensor to a Python number might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!\n", - " inputs, amax.item() / bound, 0,\n", - "/home/geantvert/.local/share/virtualenvs/fast_transformer/lib/python3.9/site-packages/pytorch_quantization/nn/modules/tensor_quantizer.py:291: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!\n", - " quant_dim = list(amax.shape).index(list(amax_sequeeze.shape)[0])\n" + "[INFO|trainer.py:439] 2021-12-27 17:30:13,795 >> Using amp half precision backend\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "percentile: 99.99\n", + "{'eval_loss': 0.3841923773288727, 'eval_accuracy': 0.8487009679062659, 'eval_runtime': 46.6715, 'eval_samples_per_second': 210.3, 'eval_steps_per_second': 3.3}\n", + "{'eval_loss': 0.3841923773288727, 'eval_accuracy': 0.8487009679062659, 'eval_runtime': 46.6715, 'eval_samples_per_second': 210.3, 'eval_steps_per_second': 3.3}\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[INFO|trainer.py:439] 2021-12-27 17:34:34,280 >> Using amp half precision backend\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "percentile: 99.999\n", + "{'eval_loss': 0.3939284086227417, 'eval_accuracy': 0.850636780438105, 'eval_runtime': 49.1138, 'eval_samples_per_second': 199.842, 'eval_steps_per_second': 3.136}\n", + "{'eval_loss': 0.3939284086227417, 'eval_accuracy': 0.850636780438105, 'eval_runtime': 49.1138, 'eval_samples_per_second': 199.842, 'eval_steps_per_second': 3.136}\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[INFO|trainer.py:439] 2021-12-27 17:38:54,289 >> Using amp half precision backend\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "percentile: 99.9999\n", + "{'eval_loss': 1.0285985469818115, 'eval_accuracy': 0.4956698930208864, 'eval_runtime': 48.0849, 'eval_samples_per_second': 204.118, 'eval_steps_per_second': 3.203}\n", + "{'eval_loss': 1.0285985469818115, 'eval_accuracy': 0.4956698930208864, 'eval_runtime': 48.0849, 'eval_samples_per_second': 204.118, 'eval_steps_per_second': 3.203}\n" ] } ], "source": [ - "data = encoded_dataset[\"train\"][0:3]\n", - "input_torch = convert_tensor(data, output=\"torch\")\n", - "\n", - "model_q = QDQBertForSequenceClassification.from_pretrained(\n", - " \"roberta-in-bert-trained-quantized-bis\", num_labels=num_labels\n", - ")\n", - "model_q = model_q.cuda()\n", - "from pytorch_quantization.nn import TensorQuantizer\n", + "for percentile in [99.9, 99.99, 99.999, 99.9999]:\n", + " with QATCalibrate(method=\"histogram\", percentile=percentile) as qat:\n", + " model_q: PreTrainedModel = AutoModelForSequenceClassification.from_pretrained(\n", + " \"model_trained_fp16\", num_labels=num_labels\n", + " )\n", + " model_q = model_q.cuda()\n", + " qat.setup_model_qat(model_q) # prepare quantizer to any model\n", "\n", - "TensorQuantizer.use_fb_fake_quant = True\n", - "convert_to_onnx(model_q, output_path=\"model_q.onnx\", inputs_pytorch=input_torch, opset=13)\n", - "TensorQuantizer.use_fb_fake_quant = False\n", - "# del model_q" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, - "source": [ - "#### Convert `ONNX` graph to `TensorRT` engine" - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, - "outputs": [], - "source": [ - "engine = build_engine(\n", - " runtime=runtime,\n", - " onnx_file_path=\"model_q.onnx\",\n", - " logger=trt_logger,\n", - " min_shape=(1, max_seq_len), # 1 in batch size to support batch from size 1 to 32\n", - " optimal_shape=(batch_size, max_seq_len),\n", - " max_shape=(batch_size, max_seq_len),\n", - " workspace_size=10000 * 1024 * 1024,\n", - " fp16=False,\n", - " int8=True,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "metadata": { - "pycharm": { - "name": "#%%\n" - }, - "scrolled": true - }, - "outputs": [], - "source": [ - "# same thing from command line\n", - "# !/usr/src/tensorrt/bin/trtexec --onnx=model_q.onnx --shapes=input_ids:32x256,attention_mask:32x256 --int8 --workspace=10000 --saveEngine=\"test.plan\"" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, - "source": [ - "#### Prepare input and output buffer" - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, - "outputs": [], - "source": [ - "stream: Stream = pycuda.driver.Stream()\n", - "context: IExecutionContext = engine.create_execution_context()\n", - "context.set_optimization_profile_async(profile_index=profile_index, stream_handle=stream.handle)\n", - "input_binding_idxs, output_binding_idxs = get_binding_idxs(engine, profile_index) # type: List[int], List[int]" - ] - }, - { - "cell_type": "code", - "execution_count": 24, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, - "outputs": [], - "source": [ - "data = encoded_dataset[\"train\"][0:batch_size]\n", - "input_np: Dict[str, np.ndarray] = convert_tensor(data, output=\"np\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, - "source": [ - "#### Inference on `TensorRT`" + " with torch.no_grad():\n", + " for start_index in range(0, 128, batch_size):\n", + " end_index = start_index + batch_size\n", + " data = encoded_dataset[\"train\"][start_index:end_index]\n", + " input_torch = {\n", + " k: torch.tensor(v, dtype=torch.long, device=\"cuda\")\n", + " for k, v in data.items()\n", + " if k in [\"input_ids\", \"attention_mask\", \"token_type_ids\"]\n", + " }\n", + " model_q(**input_torch)\n", + " trainer = get_trainer(model_q)\n", + " print(f\"percentile: {percentile}\")\n", + " print(trainer.evaluate())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "We first check that inference is working correctly:" + "As you can see, the chosen percentile value has a high impact on the final accuracy.\n", + "\n", + "For the rest of the notebook, we apply the `99.999` percentile." ] }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 8, "metadata": { - "pycharm": { - "name": "#%%\n" - } + "scrolled": true }, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[INFO|trainer.py:439] 2021-12-28 13:52:09,215 >> Using amp half precision backend\n" + ] + }, { "name": "stdout", "output_type": "stream", "text": [ - "[array([[ 0.74428475, 0.92069125, -1.8263749 ],\n", - " [ 1.851499 , -1.0653014 , -1.4172065 ],\n", - " [ 1.7828548 , -1.1306885 , -1.1852558 ],\n", - " [ 1.6740881 , -0.7420906 , -1.5647126 ],\n", - " [ 2.2817519 , 0.01615529, -2.7685544 ],\n", - " [ 3.1013348 , -0.48788828, -3.3121827 ],\n", - " [-3.0679533 , 2.708288 , 0.70968 ],\n", - " [ 3.1545656 , -1.0913979 , -2.6073706 ],\n", - " [-0.3026344 , -1.7703965 , 1.6011946 ],\n", - " [-3.2131557 , -0.5275665 , 3.786335 ],\n", - " [ 2.2266033 , -1.2310914 , -1.523544 ],\n", - " [-1.5110059 , -0.46988845, 1.7940781 ],\n", - " [-2.4409676 , 3.7142613 , -0.73455316],\n", - " [-1.8158143 , 1.9259161 , -0.05558195],\n", - " [-0.33427513, -0.48280472, 0.6140785 ],\n", - " [ 2.3686104 , -1.4665173 , -1.5184819 ],\n", - " [ 3.58267 , -1.1251179 , -3.060151 ],\n", - " [-2.4983776 , -2.0526152 , 4.5359097 ],\n", - " [-3.441052 , -0.6358736 , 4.1798487 ],\n", - " [-2.2326443 , 4.032728 , -1.1005057 ],\n", - " [ 3.4742196 , -0.98982847, -3.2408576 ],\n", - " [ 1.7075734 , 0.56745094, -2.7780871 ],\n", - " [-2.6132822 , 0.45791242, 2.1319566 ],\n", - " [ 3.498353 , -0.68513054, -3.4510155 ],\n", - " [ 3.394199 , -1.578492 , -2.5097256 ],\n", - " [-1.5231444 , 0.22112232, 1.1882032 ],\n", - " [-2.7878394 , 1.368547 , 1.5938892 ],\n", - " [-2.263415 , 2.5507202 , -0.16721557],\n", - " [-2.716222 , 0.03395515, 2.6644425 ],\n", - " [ 2.663493 , -0.7295195 , -2.7137947 ],\n", - " [ 2.6217816 , -0.7861772 , -2.417176 ],\n", - " [ 2.506748 , -0.09974011, -3.0284772 ]], dtype=float32)]\n" + "{'eval_loss': 0.3939284086227417, 'eval_accuracy': 0.850636780438105, 'eval_runtime': 46.5572, 'eval_samples_per_second': 210.816, 'eval_steps_per_second': 3.308}\n", + "{'eval_loss': 0.3939284086227417, 'eval_accuracy': 0.850636780438105, 'eval_runtime': 46.5572, 'eval_samples_per_second': 210.816, 'eval_steps_per_second': 3.308}\n" ] } ], "source": [ - "tensorrt_output = infer_tensorrt(\n", - " context=context,\n", - " host_inputs=input_np,\n", - " input_binding_idxs=input_binding_idxs,\n", - " output_binding_idxs=output_binding_idxs,\n", - " stream=stream,\n", - ")\n", - "print(tensorrt_output)" + "with QATCalibrate(method=\"histogram\", percentile=99.999) as qat:\n", + " model_q: PreTrainedModel = AutoModelForSequenceClassification.from_pretrained(\n", + " \"model_trained_fp16\", num_labels=num_labels\n", + " )\n", + " model_q = model_q.cuda()\n", + " qat.setup_model_qat(model_q) # prepare quantizer to any model\n", + "\n", + " with torch.no_grad():\n", + " for start_index in range(0, 128, batch_size):\n", + " end_index = start_index + batch_size\n", + " data = encoded_dataset[\"train\"][start_index:end_index]\n", + " input_torch = {\n", + " k: torch.tensor(v, dtype=torch.long, device=\"cuda\")\n", + " for k, v in data.items()\n", + " if k in [\"input_ids\", \"attention_mask\", \"token_type_ids\"]\n", + " }\n", + " model_q(**input_torch)\n", + "trainer = get_trainer(model_q)\n", + "print(trainer.evaluate())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Measure of the accuracy when `TensorRT` is the engine:" + "#### Per layer quantization analysis\n", + "\n", + "Below we will run a sensitivity analysis, by enabling quantization of one layer at a time and measuring the accuracy. That way we will be able to detect if the quantization of a specific layer has a larger cost on accuracy than other layers." ] }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 10, "metadata": {}, "outputs": [ { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "e3123eb9ee20476aa55e70b181c124d7", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " 0%| | 0/307 [00:00 concrete examples on `Roberta` architecture: in HF library, there is a `cumsum` in the position embedding generation. Something very simple. It takes as input an integer tensor and output an integer tensor. It happens that the `cumsum` operator from TensorRT supports float but not integer (https://github.com/onnx/onnx-tensorrt/blob/master/docs/operators.md). It leads to a crash during the model conversion with a strange error message. Converting the input to float tensor fix the issue. Not complex, but requires some knowledge.\n", - "\n", - "The process below is a bit simpler than the method 1:\n", - "\n", - "* Calibrate\n", - "* Quantization Aware training (QAT)\n", - "\n", - "> there are many ways to get a QDQ model, you can modify Pytorch source code like here, patch ONNX graph (this approach is used at Microsoft for instance) or leverage the new FX Pytorch interface. Modifying the source code is the most straight forward so we choosed to do it that way.\n" + "for op in [\"matmul\", \"layernorm\"]:\n", + " for name, module in model_q.named_modules():\n", + " if isinstance(module, quant_nn.TensorQuantizer):\n", + " if op in name:\n", + " module.enable_quant()\n", + " else:\n", + " module.disable_quant()\n", + " print(op)\n", + " trainer.evaluate()\n", + " print(\"----\")" ] }, { "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, + "metadata": {}, "source": [ - "### Calibration" + "It appears that `LayerNorm` quantization has a significant accuracy cost.\n", + "\n", + "Our goal is to disable quantization for as few operations as possible while preserving accuracy as much as possible. Therefore we will try to only disable quantization for `LayerNorm` on Layers 2 to 6." ] }, { "cell_type": "code", - "execution_count": 28, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, + "execution_count": 11, + "metadata": {}, "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "disable roberta.encoder.layer.2.output.layernorm_quantizer_0\n", + "disable roberta.encoder.layer.2.output.layernorm_quantizer_1\n", + "disable roberta.encoder.layer.3.output.layernorm_quantizer_0\n", + "disable roberta.encoder.layer.3.output.layernorm_quantizer_1\n", + "disable roberta.encoder.layer.4.output.layernorm_quantizer_0\n", + "disable roberta.encoder.layer.4.output.layernorm_quantizer_1\n", + "disable roberta.encoder.layer.6.output.layernorm_quantizer_0\n", + "disable roberta.encoder.layer.6.output.layernorm_quantizer_1\n", + "{'eval_loss': 0.3660135269165039, 'eval_accuracy': 0.8618441161487519, 'eval_runtime': 45.9324, 'eval_samples_per_second': 213.684, 'eval_steps_per_second': 3.353}\n" + ] + }, { "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "b3d2a85cc8674f6db9ccc11bdb52794a", - "version_major": 2, - "version_minor": 0 - }, "text/plain": [ - " 0%| | 0/4 [00:00> Using amp half precision backend\n" + "[INFO|trainer.py:439] 2021-12-28 13:54:41,146 >> Using amp half precision backend\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "{'loss': 0.7446, 'learning_rate': 9.1875814863103e-06, 'epoch': 0.08}\n", - "{'eval_loss': 0.5072791576385498, 'eval_accuracy': 0.8059093224656139, 'eval_runtime': 47.7062, 'eval_samples_per_second': 205.738, 'eval_steps_per_second': 3.228, 'epoch': 0.08}\n", - "{'loss': 0.5424, 'learning_rate': 8.372718383311604e-06, 'epoch': 0.16}\n", - "{'eval_loss': 0.4495479464530945, 'eval_accuracy': 0.8264900662251655, 'eval_runtime': 46.5832, 'eval_samples_per_second': 210.698, 'eval_steps_per_second': 3.306, 'epoch': 0.16}\n", - "{'loss': 0.5071, 'learning_rate': 7.558670143415907e-06, 'epoch': 0.24}\n", - "{'eval_loss': 0.4541535973548889, 'eval_accuracy': 0.8242485990830362, 'eval_runtime': 50.0162, 'eval_samples_per_second': 196.236, 'eval_steps_per_second': 3.079, 'epoch': 0.24}\n", - "{'loss': 0.4854, 'learning_rate': 6.743807040417211e-06, 'epoch': 0.33}\n", - "{'eval_loss': 0.4110897183418274, 'eval_accuracy': 0.8425878757004585, 'eval_runtime': 47.5433, 'eval_samples_per_second': 206.443, 'eval_steps_per_second': 3.239, 'epoch': 0.33}\n", - "{'loss': 0.4656, 'learning_rate': 5.929758800521513e-06, 'epoch': 0.41}\n", - "{'eval_loss': 0.4083874821662903, 'eval_accuracy': 0.8410596026490066, 'eval_runtime': 47.486, 'eval_samples_per_second': 206.693, 'eval_steps_per_second': 3.243, 'epoch': 0.41}\n", - "{'loss': 0.4547, 'learning_rate': 5.1148956975228174e-06, 'epoch': 0.49}\n", - "{'eval_loss': 0.40900033712387085, 'eval_accuracy': 0.8411614875191035, 'eval_runtime': 48.7052, 'eval_samples_per_second': 201.518, 'eval_steps_per_second': 3.162, 'epoch': 0.49}\n", - "{'loss': 0.4503, 'learning_rate': 4.30003259452412e-06, 'epoch': 0.57}\n", - "{'eval_loss': 0.391275018453598, 'eval_accuracy': 0.8503311258278146, 'eval_runtime': 47.4931, 'eval_samples_per_second': 206.662, 'eval_steps_per_second': 3.243, 'epoch': 0.57}\n", - "{'loss': 0.4433, 'learning_rate': 3.4851694915254244e-06, 'epoch': 0.65}\n", - "{'eval_loss': 0.3878655731678009, 'eval_accuracy': 0.851044319918492, 'eval_runtime': 49.1368, 'eval_samples_per_second': 199.748, 'eval_steps_per_second': 3.134, 'epoch': 0.65}\n", - "{'loss': 0.4323, 'learning_rate': 2.6711212516297265e-06, 'epoch': 0.73}\n", - "{'eval_loss': 0.38398584723472595, 'eval_accuracy': 0.8541008660213958, 'eval_runtime': 48.2456, 'eval_samples_per_second': 203.438, 'eval_steps_per_second': 3.192, 'epoch': 0.73}\n", - "{'loss': 0.4238, 'learning_rate': 1.8570730117340288e-06, 'epoch': 0.81}\n", - "{'eval_loss': 0.38184261322021484, 'eval_accuracy': 0.8535914416709118, 'eval_runtime': 47.2522, 'eval_samples_per_second': 207.715, 'eval_steps_per_second': 3.259, 'epoch': 0.81}\n", - "{'loss': 0.4253, 'learning_rate': 1.0422099087353325e-06, 'epoch': 0.9}\n", - "{'eval_loss': 0.37562793493270874, 'eval_accuracy': 0.856953642384106, 'eval_runtime': 47.5902, 'eval_samples_per_second': 206.24, 'eval_steps_per_second': 3.236, 'epoch': 0.9}\n", - "{'loss': 0.4248, 'learning_rate': 2.2734680573663624e-07, 'epoch': 0.98}\n", - "{'eval_loss': 0.37268802523612976, 'eval_accuracy': 0.8588894549159449, 'eval_runtime': 48.7185, 'eval_samples_per_second': 201.463, 'eval_steps_per_second': 3.161, 'epoch': 0.98}\n", - "{'train_runtime': 5103.4383, 'train_samples_per_second': 76.949, 'train_steps_per_second': 2.405, 'train_loss': 0.4821581125570245, 'epoch': 1.0}\n", - "{'eval_loss': 0.37268802523612976, 'eval_accuracy': 0.8588894549159449, 'eval_runtime': 47.6115, 'eval_samples_per_second': 206.148, 'eval_steps_per_second': 3.235, 'epoch': 1.0}\n", - "{'eval_loss': 0.37268802523612976, 'eval_accuracy': 0.8588894549159449, 'eval_runtime': 47.6115, 'eval_samples_per_second': 206.148, 'eval_steps_per_second': 3.235, 'epoch': 1.0}\n" + "{'loss': 0.3591, 'learning_rate': 9.188396349413298e-08, 'epoch': 0.08}\n", + "{'eval_loss': 0.3738575875759125, 'eval_accuracy': 0.8596026490066225, 'eval_runtime': 46.992, 'eval_samples_per_second': 208.865, 'eval_steps_per_second': 3.277, 'epoch': 0.08}\n", + "{'loss': 0.3182, 'learning_rate': 8.373533246414603e-08, 'epoch': 0.16}\n", + "{'eval_loss': 0.38133203983306885, 'eval_accuracy': 0.8586856851757514, 'eval_runtime': 45.7335, 'eval_samples_per_second': 214.613, 'eval_steps_per_second': 3.367, 'epoch': 0.16}\n", + "{'loss': 0.3062, 'learning_rate': 7.558670143415906e-08, 'epoch': 0.24}\n", + "{'eval_loss': 0.3903615176677704, 'eval_accuracy': 0.8592969943963321, 'eval_runtime': 45.6544, 'eval_samples_per_second': 214.985, 'eval_steps_per_second': 3.373, 'epoch': 0.24}\n", + "{'loss': 0.2986, 'learning_rate': 6.74380704041721e-08, 'epoch': 0.33}\n", + "{'eval_loss': 0.39669597148895264, 'eval_accuracy': 0.8577687213448802, 'eval_runtime': 45.6583, 'eval_samples_per_second': 214.966, 'eval_steps_per_second': 3.373, 'epoch': 0.33}\n", + "{'loss': 0.2994, 'learning_rate': 5.9289439374185136e-08, 'epoch': 0.41}\n", + "{'eval_loss': 0.394754558801651, 'eval_accuracy': 0.8612328069281712, 'eval_runtime': 45.6439, 'eval_samples_per_second': 215.034, 'eval_steps_per_second': 3.374, 'epoch': 0.41}\n", + "{'loss': 0.3027, 'learning_rate': 5.1148956975228164e-08, 'epoch': 0.49}\n", + "{'eval_loss': 0.39516741037368774, 'eval_accuracy': 0.8626591951095263, 'eval_runtime': 45.7496, 'eval_samples_per_second': 214.538, 'eval_steps_per_second': 3.366, 'epoch': 0.49}\n", + "{'loss': 0.3164, 'learning_rate': 4.300032594524119e-08, 'epoch': 0.57}\n", + "{'eval_loss': 0.39596375823020935, 'eval_accuracy': 0.8609271523178808, 'eval_runtime': 45.6669, 'eval_samples_per_second': 214.926, 'eval_steps_per_second': 3.372, 'epoch': 0.57}\n", + "{'loss': 0.3298, 'learning_rate': 3.485984354628422e-08, 'epoch': 0.65}\n", + "{'eval_loss': 0.3958113491535187, 'eval_accuracy': 0.8599083036169128, 'eval_runtime': 45.6637, 'eval_samples_per_second': 214.941, 'eval_steps_per_second': 3.372, 'epoch': 0.65}\n", + "{'loss': 0.3379, 'learning_rate': 2.6719361147327245e-08, 'epoch': 0.73}\n", + "{'eval_loss': 0.39275360107421875, 'eval_accuracy': 0.8577687213448802, 'eval_runtime': 45.7659, 'eval_samples_per_second': 214.461, 'eval_steps_per_second': 3.365, 'epoch': 0.73}\n", + "{'loss': 0.354, 'learning_rate': 1.8570730117340286e-08, 'epoch': 0.81}\n", + "{'eval_loss': 0.39236611127853394, 'eval_accuracy': 0.8562404482934284, 'eval_runtime': 45.6972, 'eval_samples_per_second': 214.784, 'eval_steps_per_second': 3.37, 'epoch': 0.81}\n", + "{'loss': 0.3826, 'learning_rate': 1.0422099087353324e-08, 'epoch': 0.9}\n", + "{'eval_loss': 0.389812171459198, 'eval_accuracy': 0.8620478858889455, 'eval_runtime': 45.6861, 'eval_samples_per_second': 214.836, 'eval_steps_per_second': 3.371, 'epoch': 0.9}\n", + "{'loss': 0.4363, 'learning_rate': 2.2734680573663624e-09, 'epoch': 0.98}\n", + "{'eval_loss': 0.3902811110019684, 'eval_accuracy': 0.8583800305654611, 'eval_runtime': 45.6787, 'eval_samples_per_second': 214.87, 'eval_steps_per_second': 3.371, 'epoch': 0.98}\n", + "{'train_runtime': 4893.2972, 'train_samples_per_second': 80.253, 'train_steps_per_second': 2.508, 'train_loss': 0.33914821741633805, 'epoch': 1.0}\n", + "{'eval_loss': 0.39516741037368774, 'eval_accuracy': 0.8626591951095263, 'eval_runtime': 45.6231, 'eval_samples_per_second': 215.132, 'eval_steps_per_second': 3.375, 'epoch': 1.0}\n", + "{'eval_loss': 0.39516741037368774, 'eval_accuracy': 0.8626591951095263, 'eval_runtime': 45.6231, 'eval_samples_per_second': 215.132, 'eval_steps_per_second': 3.375, 'epoch': 1.0}\n" ] } ], "source": [ - "model_roberta_q: PreTrainedModel = QDQRobertaForSequenceClassification.from_pretrained(\n", - " \"roberta-untrained-quantized\", num_labels=num_labels\n", - ")\n", - "model_roberta_q = model_roberta_q.cuda()\n", - "\n", - "args.learning_rate = 1e-5\n", - "\n", - "trainer = Trainer(\n", - " model_roberta_q,\n", - " args,\n", - " train_dataset=encoded_dataset[\"train\"],\n", - " eval_dataset=encoded_dataset[validation_key],\n", - " tokenizer=tokenizer,\n", - " compute_metrics=compute_metrics,\n", - ")\n", - "transformers.logging.set_verbosity_error()\n", + "args.learning_rate = 1e-7\n", + "trainer = get_trainer(model_q)\n", "trainer.train()\n", "print(trainer.evaluate())\n", - "model_roberta_q.save_pretrained(\"roberta-trained-quantized\")\n", - "del model_roberta_q" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, - "source": [ - "### Benchmark" + "model_q.save_pretrained(\"model-qat\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "#### Export a `QDQ Pytorch` model on `ONNX`, we need to enable fake quantization mode from Pytorch." + "#### Export a `QDQ Pytorch` model to `ONNX`\n", + "\n", + "We need to enable fake quantization mode from Pytorch." ] }, { "cell_type": "code", - "execution_count": 30, + "execution_count": 13, "metadata": { "pycharm": { "name": "#%%\n" @@ -1552,19 +999,34 @@ } ], "source": [ - "model_roberta_q: PreTrainedModel = QDQRobertaForSequenceClassification.from_pretrained(\n", - " \"roberta-trained-quantized\", num_labels=num_labels\n", - ")\n", - "model_roberta_q = model_roberta_q.cuda()\n", - "\n", "data = encoded_dataset[\"train\"][1:3]\n", "input_torch = convert_tensor(data, output=\"torch\")\n", - "\n", - "from pytorch_quantization.nn import TensorQuantizer\n", - "\n", - "TensorQuantizer.use_fb_fake_quant = True\n", - "convert_to_onnx(model_pytorch=model_roberta_q, output_path=\"roberta_q.onnx\", inputs_pytorch=input_torch, opset=13)\n", - "TensorQuantizer.use_fb_fake_quant = False" + "convert_to_quant_onnx(model_pytorch=model_q, output_path=\"model_qat.onnx\", inputs_pytorch=input_torch)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "del model_q\n", + "QATCalibrate.restore()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "### Benchmark" ] }, { @@ -1576,7 +1038,7 @@ }, { "cell_type": "code", - "execution_count": 31, + "execution_count": 15, "metadata": { "pycharm": { "name": "#%%\n" @@ -1586,20 +1048,20 @@ "source": [ "engine = build_engine(\n", " runtime=runtime,\n", - " onnx_file_path=\"roberta_q.onnx\",\n", + " onnx_file_path=\"model_qat.onnx\",\n", " logger=trt_logger,\n", " min_shape=(1, max_seq_len),\n", " optimal_shape=(batch_size, max_seq_len),\n", " max_shape=(batch_size, max_seq_len),\n", " workspace_size=10000 * 1024 * 1024,\n", - " fp16=False,\n", + " fp16=True,\n", " int8=True,\n", ")" ] }, { "cell_type": "code", - "execution_count": 32, + "execution_count": 16, "metadata": { "pycharm": { "name": "#%%\n" @@ -1608,8 +1070,8 @@ }, "outputs": [], "source": [ - "# same conversion from the terminal\n", - "#!/usr/src/tensorrt/bin/trtexec --onnx=roberta_q.onnx --shapes=input_ids:32x256,attention_mask:32x256 --int8 --workspace=10000 --saveEngine=\"test.plan\"" + "# same as above, but from the terminal\n", + "# !/usr/src/tensorrt/bin/trtexec --onnx=model_qat.onnx --shapes=input_ids:32x256,attention_mask:32x256 --best --workspace=10000 --saveEngine=\"test.plan\"" ] }, { @@ -1621,7 +1083,7 @@ }, { "cell_type": "code", - "execution_count": 33, + "execution_count": 17, "metadata": { "pycharm": { "name": "#%%\n" @@ -1637,7 +1099,7 @@ }, { "cell_type": "code", - "execution_count": 34, + "execution_count": 18, "metadata": { "pycharm": { "name": "#%%\n" @@ -1661,7 +1123,7 @@ }, { "cell_type": "code", - "execution_count": 35, + "execution_count": 19, "metadata": { "pycharm": { "name": "#%%\n" @@ -1672,38 +1134,38 @@ "name": "stdout", "output_type": "stream", "text": [ - "[array([[-0.51485693, 1.7172506 , -1.5262733 ],\n", - " [ 2.353083 , -1.1611496 , -1.5365003 ],\n", - " [ 1.7413325 , -0.5790743 , -1.659783 ],\n", - " [ 1.564936 , -0.49288243, -1.3157034 ],\n", - " [ 1.625774 , -0.06960616, -2.3415694 ],\n", - " [ 3.4986691 , -1.4058641 , -2.909094 ],\n", - " [-2.8577392 , 2.3696911 , 0.9519228 ],\n", - " [ 3.3248267 , -1.3577703 , -2.7853382 ],\n", - " [ 0.24115235, -1.2206222 , 1.9764783 ],\n", - " [-2.1684752 , -0.5929435 , 4.1980004 ],\n", - " [ 2.7209766 , -0.85320175, -2.54238 ],\n", - " [-1.4474616 , -0.5539231 , 3.543574 ],\n", - " [-2.4900246 , 2.5807233 , 0.29105982],\n", - " [-2.5218582 , 2.4110076 , 0.4147416 ],\n", - " [-0.17686448, 0.19154471, 0.5593225 ],\n", - " [ 2.7820387 , -0.92807496, -2.466081 ],\n", - " [ 3.279974 , -1.1566027 , -3.082859 ],\n", - " [-2.0141928 , -1.5038209 , 4.759576 ],\n", - " [-2.8134325 , -0.09846646, 4.3754187 ],\n", - " [-2.4098513 , 3.2899659 , -1.0759622 ],\n", - " [ 3.3163776 , -1.0768431 , -3.2254322 ],\n", - " [ 1.8269717 , 0.69882363, -3.383636 ],\n", - " [-2.5522573 , -0.6264023 , 4.348268 ],\n", - " [ 3.4143322 , -1.0857687 , -3.3268075 ],\n", - " [ 3.418143 , -1.5472901 , -2.7069504 ],\n", - " [-0.9896777 , 0.2267024 , 1.1920347 ],\n", - " [-1.9947617 , 0.58624893, 2.7530055 ],\n", - " [-2.328186 , 3.3224452 , -1.4264905 ],\n", - " [-2.8432767 , 0.7639467 , 3.400511 ],\n", - " [ 2.8967564 , -0.8297742 , -3.1207962 ],\n", - " [ 3.0185122 , -1.3341582 , -2.4004488 ],\n", - " [ 3.050612 , -1.0362974 , -2.8093874 ]], dtype=float32)]\n" + "[array([[ 0.11111109, 2.9936233 , -2.5243347 ],\n", + " [ 3.2135723 , -0.4374885 , -2.4485767 ],\n", + " [ 2.1678474 , -1.1477091 , -0.7798154 ],\n", + " [ 1.8148003 , -0.2093072 , -1.416711 ],\n", + " [ 2.3070638 , 0.27601779, -2.2818418 ],\n", + " [ 4.1799006 , -0.83163625, -2.8492923 ],\n", + " [-3.695277 , 2.3409832 , 1.4314314 ],\n", + " [ 4.1796045 , -1.0709951 , -2.6119678 ],\n", + " [-0.44781622, -1.4288648 , 1.888488 ],\n", + " [-2.9845483 , -1.5895646 , 4.117529 ],\n", + " [ 3.9293122 , -0.68528754, -2.9477124 ],\n", + " [-2.516609 , 0.34680495, 2.2793124 ],\n", + " [-3.0710464 , 3.3439813 , 0.08079423],\n", + " [-2.2859852 , 1.9546673 , 0.37908432],\n", + " [ 0.3999826 , -1.0603418 , 0.5099453 ],\n", + " [ 2.9247677 , -0.6867883 , -1.7499886 ],\n", + " [ 4.1125493 , -0.7771612 , -2.986419 ],\n", + " [-2.58058 , -2.3291597 , 4.553415 ],\n", + " [-3.215447 , -1.3902456 , 4.2499046 ],\n", + " [-2.014185 , 4.117433 , -1.634403 ],\n", + " [ 4.051285 , -0.64716065, -2.9019048 ],\n", + " [ 3.742484 , -0.07188296, -3.272956 ],\n", + " [-3.302061 , -1.0159078 , 3.9711204 ],\n", + " [ 3.9316242 , -0.33764294, -3.209711 ],\n", + " [ 3.9900765 , -1.5201662 , -2.1166122 ],\n", + " [-1.2437494 , 1.410141 , -0.10993958],\n", + " [-3.1267605 , -0.8212991 , 3.6917076 ],\n", + " [-2.0607114 , 4.1098857 , -1.4996963 ],\n", + " [-3.5770578 , -0.736545 , 3.9671996 ],\n", + " [ 3.776105 , -0.60771704, -2.8707912 ],\n", + " [ 3.5450761 , -0.14414684, -2.9718893 ],\n", + " [ 3.4713674 , 0.12106885, -3.189211 ]], dtype=float32)]\n" ] } ], @@ -1727,30 +1189,16 @@ }, { "cell_type": "code", - "execution_count": 36, + "execution_count": 20, "metadata": {}, "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "39a7ce7f07b34049a82b0721dcb0322e", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " 0%| | 0/307 [00:00> Using amp half precision backend\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{'loss': 0.6612, 'learning_rate': 9.1875814863103e-06, 'epoch': 0.08}\n", - "{'eval_loss': 0.4690713882446289, 'eval_accuracy': 0.8217014773306164, 'eval_runtime': 19.0461, 'eval_samples_per_second': 515.328, 'eval_steps_per_second': 8.086, 'epoch': 0.08}\n", - "{'loss': 0.4984, 'learning_rate': 8.372718383311604e-06, 'epoch': 0.16}\n", - "{'eval_loss': 0.4219346344470978, 'eval_accuracy': 0.835048395313296, 'eval_runtime': 18.9872, 'eval_samples_per_second': 516.928, 'eval_steps_per_second': 8.111, 'epoch': 0.16}\n", - "{'loss': 0.4664, 'learning_rate': 7.558670143415907e-06, 'epoch': 0.24}\n", - "{'eval_loss': 0.4248501658439636, 'eval_accuracy': 0.835048395313296, 'eval_runtime': 18.5315, 'eval_samples_per_second': 529.639, 'eval_steps_per_second': 8.31, 'epoch': 0.24}\n", - "{'loss': 0.4471, 'learning_rate': 6.743807040417211e-06, 'epoch': 0.33}\n", - "{'eval_loss': 0.3851495087146759, 'eval_accuracy': 0.853998981151299, 'eval_runtime': 18.5175, 'eval_samples_per_second': 530.039, 'eval_steps_per_second': 8.316, 'epoch': 0.33}\n", - "{'loss': 0.4291, 'learning_rate': 5.9289439374185145e-06, 'epoch': 0.41}\n", - "{'eval_loss': 0.3886096775531769, 'eval_accuracy': 0.8529801324503311, 'eval_runtime': 19.2048, 'eval_samples_per_second': 511.07, 'eval_steps_per_second': 8.019, 'epoch': 0.41}\n", - "{'loss': 0.4198, 'learning_rate': 5.114080834419818e-06, 'epoch': 0.49}\n", - "{'eval_loss': 0.3939703404903412, 'eval_accuracy': 0.8499235863474274, 'eval_runtime': 19.079, 'eval_samples_per_second': 514.44, 'eval_steps_per_second': 8.072, 'epoch': 0.49}\n", - "{'loss': 0.417, 'learning_rate': 4.30003259452412e-06, 'epoch': 0.57}\n", - "{'eval_loss': 0.36645272374153137, 'eval_accuracy': 0.8580743759551707, 'eval_runtime': 18.5783, 'eval_samples_per_second': 528.305, 'eval_steps_per_second': 8.289, 'epoch': 0.57}\n", - "{'loss': 0.4117, 'learning_rate': 3.4851694915254244e-06, 'epoch': 0.65}\n", - "{'eval_loss': 0.3587413430213928, 'eval_accuracy': 0.860825267447784, 'eval_runtime': 18.6272, 'eval_samples_per_second': 526.919, 'eval_steps_per_second': 8.267, 'epoch': 0.65}\n", - "{'loss': 0.4014, 'learning_rate': 2.670306388526728e-06, 'epoch': 0.73}\n", - "{'eval_loss': 0.3596762418746948, 'eval_accuracy': 0.8618441161487519, 'eval_runtime': 18.5056, 'eval_samples_per_second': 530.379, 'eval_steps_per_second': 8.322, 'epoch': 0.73}\n", - "{'loss': 0.394, 'learning_rate': 1.8554432855280313e-06, 'epoch': 0.81}\n", - "{'eval_loss': 0.35547441244125366, 'eval_accuracy': 0.8645950076413652, 'eval_runtime': 18.5169, 'eval_samples_per_second': 530.056, 'eval_steps_per_second': 8.317, 'epoch': 0.81}\n", - "{'loss': 0.3967, 'learning_rate': 1.0422099087353325e-06, 'epoch': 0.9}\n", - "{'eval_loss': 0.350873202085495, 'eval_accuracy': 0.8677534386143657, 'eval_runtime': 18.5202, 'eval_samples_per_second': 529.963, 'eval_steps_per_second': 8.315, 'epoch': 0.9}\n", - "{'loss': 0.3971, 'learning_rate': 2.2734680573663624e-07, 'epoch': 0.98}\n", - "{'eval_loss': 0.350557416677475, 'eval_accuracy': 0.866225165562914, 'eval_runtime': 18.4974, 'eval_samples_per_second': 530.616, 'eval_steps_per_second': 8.326, 'epoch': 0.98}\n", - "{'train_runtime': 2679.9953, 'train_samples_per_second': 146.531, 'train_steps_per_second': 4.579, 'train_loss': 0.44397361524101964, 'epoch': 1.0}\n", - "{'eval_loss': 0.350873202085495, 'eval_accuracy': 0.8677534386143657, 'eval_runtime': 18.5374, 'eval_samples_per_second': 529.471, 'eval_steps_per_second': 8.308, 'epoch': 1.0}\n", - "{'eval_loss': 0.350873202085495, 'eval_accuracy': 0.8677534386143657, 'eval_runtime': 18.5374, 'eval_samples_per_second': 529.471, 'eval_steps_per_second': 8.308, 'epoch': 1.0}\n" - ] - } - ], - "source": [ - "model_roberta: PreTrainedModel = AutoModelForSequenceClassification.from_pretrained(\n", - " model_checkpoint, num_labels=num_labels\n", - ")\n", - "model_roberta = model_roberta.cuda()\n", - "\n", - "args.learning_rate = 1e-5\n", - "\n", - "trainer = Trainer(\n", - " model_roberta,\n", - " args,\n", - " train_dataset=encoded_dataset[\"train\"],\n", - " eval_dataset=encoded_dataset[validation_key],\n", - " tokenizer=tokenizer,\n", - " compute_metrics=compute_metrics,\n", - ")\n", - "transformers.logging.set_verbosity_error()\n", - "trainer.train()\n", - "print(trainer.evaluate())\n", - "# {'eval_loss': 0.3559744358062744, 'eval_accuracy': 0.8655119714722364, 'eval_runtime': 19.6678, 'eval_samples_per_second': 499.04, 'eval_steps_per_second': 7.83, 'epoch': 0.98}\n", - "trainer.save_model(\"roberta-baseline\")\n", - "del model_roberta\n", - "del trainer" - ] - }, - { - "cell_type": "markdown", + "execution_count": 23, "metadata": {}, + "outputs": [], "source": [ - "### GPU execution" + "del engine, context" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "To finish, we will measure vanilla Pytorch inference on both FP32 and FP16 precision, it will be our baseline:" + "## Pytorch baseline\n", + "\n", + "Time to get some numbers to compare with.\n", + "\n", + "### GPU execution\n", + "\n", + "We will measure vanilla Pytorch inference on both FP32 and FP16 precision on GPU, it will be our baseline:" ] }, { "cell_type": "code", - "execution_count": 39, + "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "[Pytorch (FP32)] mean=79.48ms, sd=0.84ms, min=78.64ms, max=82.72ms, median=79.26ms, 95p=81.66ms, 99p=82.47ms\n" + "[Pytorch (FP32)] mean=76.87ms, sd=1.24ms, min=75.68ms, max=82.38ms, median=76.52ms, 95p=79.29ms, 99p=81.90ms\n" ] } ], "source": [ - "baseline_model = AutoModelForSequenceClassification.from_pretrained(\"roberta-baseline\", num_labels=num_labels)\n", + "baseline_model = AutoModelForSequenceClassification.from_pretrained(\"model_trained_fp16\", num_labels=num_labels)\n", "baseline_model = baseline_model.cuda()\n", "baseline_model = baseline_model.eval()\n", "\n", @@ -1936,14 +1307,14 @@ }, { "cell_type": "code", - "execution_count": 40, + "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "[Pytorch (FP16)] mean=58.02ms, sd=0.59ms, min=57.46ms, max=60.90ms, median=57.80ms, 95p=59.52ms, 99p=60.49ms\n" + "[Pytorch (FP16)] mean=56.24ms, sd=0.67ms, min=55.53ms, max=59.61ms, median=56.05ms, 95p=57.80ms, 99p=58.18ms\n" ] } ], @@ -1971,26 +1342,24 @@ }, { "cell_type": "code", - "execution_count": 41, + "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "[Pytorch (FP32) - CPU] mean=3999.80ms, sd=27.86ms, min=3946.75ms, max=4042.56ms, median=4000.13ms, 95p=4037.90ms, 99p=4041.63ms\n" + "[Pytorch (FP32) - CPU] mean=4267.96ms, sd=249.08ms, min=3959.59ms, max=4697.79ms, median=4299.22ms, 95p=4632.12ms, 99p=4684.66ms\n" ] } ], "source": [ - "baseline_model = AutoModelForSequenceClassification.from_pretrained(\"roberta-baseline\", num_labels=num_labels)\n", + "baseline_model = AutoModelForSequenceClassification.from_pretrained(\"model_trained_fp16\", num_labels=num_labels)\n", "baseline_model = baseline_model.eval()\n", "data = encoded_dataset[\"train\"][0:batch_size]\n", "input_torch: OD[str, torch.Tensor] = convert_tensor(data=data, output=\"torch\")\n", "input_torch_cpu = {k: v.to(\"cpu\") for k, v in input_torch.items()}\n", "\n", - "import os\n", - "\n", "torch.set_num_threads(os.cpu_count())\n", "\n", "with torch.inference_mode():\n", @@ -2007,14 +1376,14 @@ }, { "cell_type": "code", - "execution_count": 42, + "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "[Pytorch (FP16) - CPU] mean=4005.57ms, sd=46.34ms, min=3922.37ms, max=4095.30ms, median=4010.20ms, 95p=4071.75ms, 99p=4090.59ms\n" + "[Pytorch (FP16) - CPU] mean=4428.94ms, sd=225.39ms, min=4148.26ms, max=4871.84ms, median=4404.70ms, 95p=4781.81ms, 99p=4853.83ms\n" ] } ], @@ -2042,19 +1411,21 @@ }, { "cell_type": "code", - "execution_count": 43, + "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "[Pytorch (INT-8) - CPU] mean=3669.84ms, sd=25.49ms, min=3633.48ms, max=3712.22ms, median=3670.07ms, 95p=3706.98ms, 99p=3711.17ms\n" + "[Pytorch (INT-8) - CPU] mean=3299.66ms, sd=37.76ms, min=3274.33ms, max=3405.91ms, median=3285.20ms, 95p=3366.88ms, 99p=3398.10ms\n" ] } ], "source": [ - "quantized_baseline_model = AutoModelForSequenceClassification.from_pretrained(\"roberta-baseline\", num_labels=num_labels)\n", + "quantized_baseline_model = AutoModelForSequenceClassification.from_pretrained(\n", + " \"model_trained_fp16\", num_labels=num_labels\n", + ")\n", "quantized_baseline_model = quantized_baseline_model.eval()\n", "quantized_baseline_model = torch.quantization.quantize_dynamic(\n", " quantized_baseline_model, {torch.nn.Linear}, dtype=torch.qint8\n", @@ -2069,8 +1440,7 @@ " with track_infer_time(time_buffer):\n", " _ = quantized_baseline_model(**input_torch_cpu)\n", " torch.cuda.synchronize()\n", - "print_timings(name=\"Pytorch (INT-8) - CPU\", timings=time_buffer)\n", - "del quantized_baseline_model" + "print_timings(name=\"Pytorch (INT-8) - CPU\", timings=time_buffer)" ] }, { @@ -2084,18 +1454,18 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Below we export a randomly initialized `Roberta` model, the purpose is to only check the performance on mixed precision (FP16, no quantization)." + "Below we export our finetuned model, the purpose is to only check the performance on mixed precision (FP16, no quantization)." ] }, { "cell_type": "code", - "execution_count": 44, + "execution_count": 13, "metadata": { "scrolled": true }, "outputs": [], "source": [ - "baseline_model = AutoModelForSequenceClassification.from_pretrained(\"roberta-baseline\", num_labels=num_labels)\n", + "baseline_model = AutoModelForSequenceClassification.from_pretrained(\"model_trained_fp16\", num_labels=num_labels)\n", "baseline_model = baseline_model.cuda()\n", "convert_to_onnx(baseline_model, output_path=\"baseline.onnx\", inputs_pytorch=input_torch, opset=12)\n", "del baseline_model" @@ -2103,7 +1473,7 @@ }, { "cell_type": "code", - "execution_count": 45, + "execution_count": 15, "metadata": { "scrolled": true }, @@ -2112,7 +1482,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "[TensorRT (FP16)] mean=30.23ms, sd=0.49ms, min=29.73ms, max=32.67ms, median=30.06ms, 95p=30.82ms, 99p=32.38ms\n" + "[TensorRT (FP16)] mean=29.90ms, sd=0.82ms, min=29.30ms, max=33.41ms, median=29.69ms, 95p=31.85ms, 99p=32.79ms\n" ] } ], @@ -2128,6 +1498,7 @@ " fp16=True,\n", " int8=False,\n", ")\n", + "input_np: OD[str, np.ndarray] = convert_tensor(data=data, output=\"np\")\n", "stream: Stream = pycuda.driver.Stream()\n", "context: IExecutionContext = engine.create_execution_context()\n", "context.set_optimization_profile_async(profile_index=profile_index, stream_handle=stream.handle)\n", @@ -2166,17 +1537,32 @@ "The recent 1.10 version of ONNX Runtime (with TensorRT support) is still a bit buggy on transformer models, that is why we use the 1.9.0 version in the measures below.\n", "\n", "As before, CPU quantization is dynamic.\n", - "Function `create_model_for_provider` will set ONNX Runtime to use all cores available and enable any possible optimizations." + "Function `\n", + "` will set ONNX Runtime to use all cores available and enable any possible optimizations." ] }, { "cell_type": "code", - "execution_count": 46, + "execution_count": 8, "metadata": { "pycharm": { "is_executing": true } }, + "outputs": [], + "source": [ + "optimize_onnx(\n", + " onnx_path=\"baseline.onnx\",\n", + " onnx_optim_model_path=\"baseline-optimized.onnx\",\n", + " fp16=True,\n", + " use_cuda=True,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, "outputs": [ { "name": "stderr", @@ -2184,60 +1570,113 @@ "text": [ "Warning: Unsupported operator Attention. No schema registered for this operator.\n", "Warning: Unsupported operator SkipLayerNormalization. No schema registered for this operator.\n", + "Warning: Unsupported operator BiasGelu. No schema registered for this operator.\n", "Warning: Unsupported operator SkipLayerNormalization. No schema registered for this operator.\n", "Warning: Unsupported operator Attention. No schema registered for this operator.\n", "Warning: Unsupported operator SkipLayerNormalization. No schema registered for this operator.\n", + "Warning: Unsupported operator BiasGelu. No schema registered for this operator.\n", "Warning: Unsupported operator SkipLayerNormalization. No schema registered for this operator.\n", "Warning: Unsupported operator Attention. No schema registered for this operator.\n", "Warning: Unsupported operator SkipLayerNormalization. No schema registered for this operator.\n", + "Warning: Unsupported operator BiasGelu. No schema registered for this operator.\n", "Warning: Unsupported operator SkipLayerNormalization. No schema registered for this operator.\n", "Warning: Unsupported operator Attention. No schema registered for this operator.\n", "Warning: Unsupported operator SkipLayerNormalization. No schema registered for this operator.\n", + "Warning: Unsupported operator BiasGelu. No schema registered for this operator.\n", "Warning: Unsupported operator SkipLayerNormalization. No schema registered for this operator.\n", "Warning: Unsupported operator Attention. No schema registered for this operator.\n", "Warning: Unsupported operator SkipLayerNormalization. No schema registered for this operator.\n", + "Warning: Unsupported operator BiasGelu. No schema registered for this operator.\n", "Warning: Unsupported operator SkipLayerNormalization. No schema registered for this operator.\n", "Warning: Unsupported operator Attention. No schema registered for this operator.\n", "Warning: Unsupported operator SkipLayerNormalization. No schema registered for this operator.\n", - "Warning: Unsupported operator SkipLayerNormalization. No schema registered for this operator.\n", + "Warning: Unsupported operator BiasGelu. No schema registered for this operator.\n", + "Warning: Unsupported operator SkipLayerNormalization. No schema registered for th" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Ignore MatMul due to non constant B: /[MatMul_113]\n", + "Ignore MatMul due to non constant B: /[MatMul_127]\n", + "Ignore MatMul due to non constant B: /[MatMul_137]\n", + "Ignore MatMul due to non constant B: /[MatMul_207]\n", + "Ignore MatMul due to non constant B: /[MatMul_221]\n", + "Ignore MatMul due to non constant B: /[MatMul_231]\n", + "Ignore MatMul due to non constant B: /[MatMul_301]\n", + "Ignore MatMul due to non constant B: /[MatMul_315]\n", + "Ignore MatMul due to non constant B: /[MatMul_325]\n", + "Ignore MatMul due to non constant B: /[MatMul_395]\n", + "Ignore MatMul due to non constant B: /[MatMul_409]\n", + "Ignore MatMul due to non constant B: /[MatMul_419]\n", + "Ignore MatMul due to non constant B: /[MatMul_489]\n", + "Ignore MatMul due to non constant B: /[MatMul_503]\n", + "Ignore MatMul due to non constant B: /[MatMul_513]\n", + "Ignore MatMul due to non constant B: /[MatMul_583]\n", + "Ignore MatMul due to non constant B: /[MatMul_597]\n", + "Ignore MatMul due to non constant B: /[MatMul_607]\n", + "Ignore MatMul due to non constant B: /[MatMul_677]\n", + "Ignore MatMul due to non constant B: /[MatMul_691]\n", + "Ignore MatMul due to non constant B: /[MatMul_701]\n", + "Ignore MatMul due to non constant B: /[MatMul_771]\n", + "Ignore MatMul due to non constant B: /[MatMul_785]\n", + "Ignore MatMul due to non constant B: /[MatMul_795]\n", + "Ignore MatMul due to non constant B: /[MatMul_865]\n", + "Ignore MatMul due to non constant B: /[MatMul_879]\n", + "Ignore MatMul due to non constant B: /[MatMul_889]\n", + "Ignore MatMul due to non constant B: /[MatMul_959]\n", + "Ignore MatMul due to non constant B: /[MatMul_973]\n", + "Ignore MatMul due to non constant B: /[MatMul_983]\n", + "Ignore MatMul due to non constant B: /[MatMul_1053]\n", + "Ignore MatMul due to non constant B: /[MatMul_1067]\n", + "Ignore MatMul due to non constant B: /[MatMul_1077]\n", + "Ignore MatMul due to non constant B: /[MatMul_1147]\n", + "Ignore MatMul due to non constant B: /[MatMul_1161]\n", + "Ignore MatMul due to non constant B: /[MatMul_1171]\n", + "Ignore MatMul due to non constant B: /[Gemm_1187_MatMul]\n", + "Ignore MatMul due to non constant B: /[Gemm_1189_MatMul]\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "is operator.\n", "Warning: Unsupported operator Attention. No schema registered for this operator.\n", "Warning: Unsupported operator SkipLayerNormalization. No schema registered for this operator.\n", + "Warning: Unsupported operator BiasGelu. No schema registered for this operator.\n", "Warning: Unsupported operator SkipLayerNormalization. No schema registered for this operator.\n", "Warning: Unsupported operator Attention. No schema registered for this operator.\n", "Warning: Unsupported operator SkipLayerNormalization. No schema registered for this operator.\n", + "Warning: Unsupported operator BiasGelu. No schema registered for this operator.\n", "Warning: Unsupported operator SkipLayerNormalization. No schema registered for this operator.\n", "Warning: Unsupported operator Attention. No schema registered for this operator.\n", "Warning: Unsupported operator SkipLayerNormalization. No schema registered for this operator.\n", + "Warning: Unsupported operator BiasGelu. No schema registered for this operator.\n", "Warning: Unsupported operator SkipLayerNormalization. No schema registered for this operator.\n", "Warning: Unsupported operator Attention. No schema registered for this operator.\n", "Warning: Unsupported operator SkipLayerNormalization. No schema registered for this operator.\n", + "Warning: Unsupported operator BiasGelu. No schema registered for this operator.\n", "Warning: Unsupported operator SkipLayerNormalization. No schema registered for this operator.\n", "Warning: Unsupported operator Attention. No schema registered for this operator.\n", "Warning: Unsupported operator SkipLayerNormalization. No schema registered for this operator.\n", + "Warning: Unsupported operator BiasGelu. No schema registered for this operator.\n", "Warning: Unsupported operator SkipLayerNormalization. No schema registered for this operator.\n", "Warning: Unsupported operator Attention. No schema registered for this operator.\n", "Warning: Unsupported operator SkipLayerNormalization. No schema registered for this operator.\n", + "Warning: Unsupported operator BiasGelu. No schema registered for this operator.\n", "Warning: Unsupported operator SkipLayerNormalization. No schema registered for this operator.\n" ] } ], "source": [ - "from transformer_deploy.backends.ort_utils import optimize_onnx, create_model_for_provider, cpu_quantization\n", - "from onnxruntime.quantization import quantize_dynamic, QuantType\n", - "\n", - "optimize_onnx(\n", - " onnx_path=\"baseline.onnx\",\n", - " onnx_optim_model_path=\"baseline-optimized.onnx\",\n", - " fp16=True,\n", - " use_cuda=True,\n", - ")\n", - "\n", "cpu_quantization(input_model_path=\"baseline-optimized.onnx\", output_model_path=\"baseline-quantized.onnx\")" ] }, { "cell_type": "code", - "execution_count": 47, + "execution_count": 10, "metadata": {}, "outputs": [], "source": [ @@ -2253,18 +1692,18 @@ }, { "cell_type": "code", - "execution_count": 48, + "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "[ONNX Runtime GPU (FP32)] mean=74.48ms, sd=0.55ms, min=73.87ms, max=76.61ms, median=74.36ms, 95p=75.89ms, 99p=76.34ms\n", - "[ONNX Runtime GPU (FP16)] mean=33.50ms, sd=0.62ms, min=32.90ms, max=37.58ms, median=33.39ms, 95p=34.42ms, 99p=35.47ms\n", - "[ONNX Runtime CPU (FP32)] mean=3767.02ms, sd=32.02ms, min=3720.72ms, max=3831.88ms, median=3766.35ms, 95p=3816.04ms, 99p=3828.71ms\n", - "[ONNX Runtime CPU (FP16)] mean=4607.67ms, sd=121.41ms, min=4513.24ms, max=4950.20ms, median=4573.18ms, 95p=4822.23ms, 99p=4924.61ms\n", - "[ONNX Runtime CPU (INT-8)] mean=3712.67ms, sd=45.19ms, min=3656.30ms, max=3827.99ms, median=3709.21ms, 95p=3788.00ms, 99p=3819.99ms\n" + "[ONNX Runtime GPU (FP32)] mean=76.38ms, sd=4.99ms, min=73.10ms, max=91.05ms, median=73.91ms, 95p=88.30ms, 99p=89.42ms\n", + "[ONNX Runtime GPU (FP16)] mean=34.21ms, sd=1.68ms, min=33.23ms, max=41.80ms, median=33.70ms, 95p=38.87ms, 99p=40.63ms\n", + "[ONNX Runtime CPU (FP32)] mean=4023.32ms, sd=92.76ms, min=3895.51ms, max=4267.63ms, median=4013.27ms, 95p=4170.44ms, 99p=4248.19ms\n", + "[ONNX Runtime CPU (FP16)] mean=3956.61ms, sd=167.65ms, min=3709.88ms, max=4188.62ms, median=3914.53ms, 95p=4180.81ms, 99p=4187.06ms\n", + "[ONNX Runtime CPU (INT-8)] mean=3336.29ms, sd=168.96ms, min=3170.64ms, max=3765.07ms, median=3299.52ms, 95p=3641.01ms, 99p=3740.26ms\n" ] } ], @@ -2301,30 +1740,16 @@ }, { "cell_type": "code", - "execution_count": 49, + "execution_count": 12, "metadata": {}, "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "ee60729e561d492f82a7db2c93fc44ac", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " 0%| | 0/307 [00:00 not supported by TensorRT + # attention_mask = attention_mask.byte() + elif attention_mask.dim() == 3: + attention_mask = attention_mask.unsqueeze(1) + + return attention_mask + + +# in class XSoftmax(torch.autograd.Function): +# @staticmethod +def symbolic(g, self, mask, dim): + import torch.onnx.symbolic_helper as sym_help + from torch.onnx.symbolic_opset9 import masked_fill, softmax + + mask_cast_value = g.op("Cast", mask, to_i=sym_help.cast_pytorch_to_onnx["Long"]) + # r_mask = g.op("Sub", g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)), mask_cast_value) + # replace Byte by Char to get signed numbers + r_mask = g.op( + "Cast", + g.op("Sub", g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)), mask_cast_value), + to_i=sym_help.cast_pytorch_to_onnx["Char"], + ) + output = masked_fill(g, self, r_mask, g.op("Constant", value_t=torch.tensor(float("-inf")))) + output = softmax(g, output, dim) + return masked_fill(g, output, r_mask, g.op("Constant", value_t=torch.tensor(0, dtype=torch.int8))) + + +qdq_deberta_mapping: PatchModule = PatchModule( + module="transformers.models.deberta.modeling_deberta", + monkey_patch={ + "XSoftmax.symbolic": (symbolic, "symbolic"), + "DebertaEncoder.get_attention_mask": (get_attention_mask, "get_attention_mask"), + }, +) + + +def toto(): + print("1") + + +qdq_deberta_v2_mapping: PatchModule = PatchModule( + module="transformers.models.deberta_v2.modeling_deberta_v2", + monkey_patch={ + "XSoftmax.symbolic": (toto, "toto"), + "DebertaV2Encoder.get_attention_mask": (get_attention_mask, "get_attention_mask"), + }, +) diff --git a/src/transformer_deploy/QDQModels/QDQDistilbert.py b/src/transformer_deploy/QDQModels/QDQDistilbert.py new file mode 100644 index 00000000..e39c22e4 --- /dev/null +++ b/src/transformer_deploy/QDQModels/QDQDistilbert.py @@ -0,0 +1,20 @@ +# Copyright 2021, Lefebvre Sarrut Services +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from transformer_deploy.QDQModels.ast_module_patch import PatchModule + + +qdq_distilbert_mapping: PatchModule = PatchModule( + module="transformers.models.distilbert.modeling_distilbert", +) diff --git a/src/transformer_deploy/QDQModels/QDQElectra.py b/src/transformer_deploy/QDQModels/QDQElectra.py new file mode 100644 index 00000000..9db1e619 --- /dev/null +++ b/src/transformer_deploy/QDQModels/QDQElectra.py @@ -0,0 +1,21 @@ +# Copyright 2021, Lefebvre Sarrut Services +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from transformer_deploy.QDQModels.ast_module_patch import PatchModule + + +qdq_electra_mapping: PatchModule = PatchModule( + module="transformers.models.electra.modeling_electra", +) diff --git a/src/transformer_deploy/QDQModels/QDQRoberta.py b/src/transformer_deploy/QDQModels/QDQRoberta.py index 3f9ae139..db079166 100644 --- a/src/transformer_deploy/QDQModels/QDQRoberta.py +++ b/src/transformer_deploy/QDQModels/QDQRoberta.py @@ -14,1617 +14,25 @@ # coding=utf-8 # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# copied from Hugging Face transformers library -# modified parts (outside imports) are preceded by -> # QDQ change below - -"""PyTorch RoBERTa model. """ -import math import torch import torch.utils.checkpoint -from packaging import version -from pytorch_quantization import nn as quant_nn -from pytorch_quantization.nn.modules.tensor_quantizer import TensorQuantizer -from torch import nn -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss -from transformers import RobertaConfig -from transformers.activations import ACT2FN, gelu -from transformers.file_utils import ( - add_code_sample_docstrings, - add_start_docstrings, - add_start_docstrings_to_model_forward, - replace_return_docstrings, -) -from transformers.modeling_outputs import ( - BaseModelOutputWithPastAndCrossAttentions, - BaseModelOutputWithPoolingAndCrossAttentions, - CausalLMOutputWithCrossAttentions, - MaskedLMOutput, - MultipleChoiceModelOutput, - QuestionAnsweringModelOutput, - SequenceClassifierOutput, - TokenClassifierOutput, -) -from transformers.modeling_utils import ( - PreTrainedModel, - apply_chunking_to_forward, - find_pruneable_heads_and_indices, - prune_linear_layer, -) -from transformers.utils import logging - - -logger = logging.get_logger(__name__) - -_CHECKPOINT_FOR_DOC = "roberta-base" -_CONFIG_FOR_DOC = "RobertaConfig" -_TOKENIZER_FOR_DOC = "RobertaTokenizer" - -ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST = [ - "roberta-base", - "roberta-large", - "roberta-large-mnli", - "distilroberta-base", - "roberta-base-openai-detector", - "roberta-large-openai-detector", - # See all RoBERTa models at https://huggingface.co/models?filter=roberta -] - - -class RobertaEmbeddings(nn.Module): - """ - Same as BertEmbeddings with a tiny tweak for positional embeddings indexing. - """ - - # Copied from transformers.models.bert.modeling_bert.BertEmbeddings.__init__ - def __init__(self, config): - super().__init__() - self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) - self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) - self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) - - # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load - # any TensorFlow checkpoint file - self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.dropout = nn.Dropout(config.hidden_dropout_prob) - # position_ids (1, len position emb) is contiguous in memory and exported when serialized - self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") - self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))) - if version.parse(torch.__version__) > version.parse("1.6.0"): - self.register_buffer( - "token_type_ids", - torch.zeros(self.position_ids.size(), dtype=torch.long), - persistent=False, - ) - - # End copy - self.padding_idx = config.pad_token_id - self.position_embeddings = nn.Embedding( - config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx - ) - - def forward( - self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0 - ): - if position_ids is None: - if input_ids is not None: - # Create the position ids from the input token ids. Any padded tokens remain padded. - position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length) - else: - position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds) - - if input_ids is not None: - input_shape = input_ids.size() - else: - input_shape = inputs_embeds.size()[:-1] - - seq_length = input_shape[1] - - # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs - # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves # noqa: E501 - # issue #5664 - if token_type_ids is None: - if hasattr(self, "token_type_ids"): - buffered_token_type_ids = self.token_type_ids[:, :seq_length] - buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length) - token_type_ids = buffered_token_type_ids_expanded - else: - token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device) - - if inputs_embeds is None: - inputs_embeds = self.word_embeddings(input_ids) - token_type_embeddings = self.token_type_embeddings(token_type_ids) - - embeddings = inputs_embeds + token_type_embeddings - if self.position_embedding_type == "absolute": - position_embeddings = self.position_embeddings(position_ids) - embeddings += position_embeddings - embeddings = self.LayerNorm(embeddings) - embeddings = self.dropout(embeddings) - return embeddings - - def create_position_ids_from_inputs_embeds(self, inputs_embeds): - """ - We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids. - - Args: - inputs_embeds: torch.Tensor - - Returns: torch.Tensor - """ - input_shape = inputs_embeds.size()[:-1] - sequence_length = input_shape[1] - - position_ids = torch.arange( - self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device - ) - return position_ids.unsqueeze(0).expand(input_shape) - - -# Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->Roberta -class RobertaSelfAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): - super().__init__() - if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): - raise ValueError( - f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " - f"heads ({config.num_attention_heads})" - ) - - self.num_attention_heads = config.num_attention_heads - self.attention_head_size = int(config.hidden_size / config.num_attention_heads) - self.all_head_size = self.num_attention_heads * self.attention_head_size - # QDQ change below - self.query = quant_nn.QuantLinear(config.hidden_size, self.all_head_size) - self.key = quant_nn.QuantLinear(config.hidden_size, self.all_head_size) - self.value = quant_nn.QuantLinear(config.hidden_size, self.all_head_size) - - self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - self.position_embedding_type = position_embedding_type or getattr(config, "position_embedding_type", "absolute") - if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": - self.max_position_embeddings = config.max_position_embeddings - self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) - - self.is_decoder = config.is_decoder - # QDQ change below - self.matmul_q_input_quantizer = TensorQuantizer(quant_nn.QuantLinear.default_quant_desc_input) - self.matmul_k_input_quantizer = TensorQuantizer(quant_nn.QuantLinear.default_quant_desc_input) - self.matmul_v_input_quantizer = TensorQuantizer(quant_nn.QuantLinear.default_quant_desc_input) - self.matmul_a_input_quantizer = TensorQuantizer(quant_nn.QuantLinear.default_quant_desc_input) - - def transpose_for_scores(self, x): - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(*new_x_shape) - return x.permute(0, 2, 1, 3) - - def forward( - self, - hidden_states, - attention_mask=None, - head_mask=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - past_key_value=None, - output_attentions=False, - ): - mixed_query_layer = self.query(hidden_states) - - # If this is instantiated as a cross-attention module, the keys - # and values come from an encoder; the attention mask needs to be - # such that the encoder's padding tokens are not attended to. - is_cross_attention = encoder_hidden_states is not None - - if is_cross_attention and past_key_value is not None: - # reuse k,v, cross_attentions - key_layer = past_key_value[0] - value_layer = past_key_value[1] - attention_mask = encoder_attention_mask - elif is_cross_attention: - key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) - value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) - attention_mask = encoder_attention_mask - elif past_key_value is not None: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - key_layer = torch.cat([past_key_value[0], key_layer], dim=2) - value_layer = torch.cat([past_key_value[1], value_layer], dim=2) - else: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - - query_layer = self.transpose_for_scores(mixed_query_layer) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_layer, value_layer) - - # Take the dot product between "query" and "key" to get the raw attention scores. - # QDQ change below - attention_scores = torch.matmul( - self.matmul_q_input_quantizer(query_layer), self.matmul_k_input_quantizer(key_layer.transpose(-1, -2)) - ) - - if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": - seq_length = hidden_states.size()[1] - position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) - position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1) - distance = position_ids_l - position_ids_r - positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) - positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility - - if self.position_embedding_type == "relative_key": - relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) - attention_scores = attention_scores + relative_position_scores - elif self.position_embedding_type == "relative_key_query": - relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) - relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) - attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key - - attention_scores = attention_scores / math.sqrt(self.attention_head_size) - if attention_mask is not None: - # Apply the attention mask is (precomputed for all layers in RobertaModel forward() function) - attention_scores = attention_scores + attention_mask - - # Normalize the attention scores to probabilities. - attention_probs = nn.Softmax(dim=-1)(attention_scores) - - # This is actually dropping out entire tokens to attend to, which might - # seem a bit unusual, but is taken from the original Transformer paper. - attention_probs = self.dropout(attention_probs) - - # Mask heads if we want to - if head_mask is not None: - attention_probs = attention_probs * head_mask - # QDQ change below - context_layer = torch.matmul( - self.matmul_a_input_quantizer(attention_probs), self.matmul_v_input_quantizer(value_layer) - ) - - context_layer = context_layer.permute(0, 2, 1, 3).contiguous() - new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) - context_layer = context_layer.view(*new_context_layer_shape) - - outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) - - if self.is_decoder: - outputs = outputs + (past_key_value,) - return outputs - - -# Copied from transformers.models.bert.modeling_bert.BertSelfOutput -class RobertaSelfOutput(nn.Module): - def __init__(self, config): - super().__init__() - # QDQ change below - # Quantize Linear layer - self.dense = quant_nn.QuantLinear(config.hidden_size, config.hidden_size) - - self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.dropout = nn.Dropout(config.hidden_dropout_prob) - # QDQ change below - # Quantize the inputs to the residual add - self.add_local_input_quantizer = TensorQuantizer(quant_nn.QuantLinear.default_quant_desc_input) - self.add_residual_input_quantizer = TensorQuantizer(quant_nn.QuantLinear.default_quant_desc_input) - - def forward(self, hidden_states, input_tensor): - hidden_states = self.dense(hidden_states) - hidden_states = self.dropout(hidden_states) - # QDQ change below - # Quantize the inputs to the residual add - add_local = self.add_local_input_quantizer(hidden_states) - add_residual = self.add_residual_input_quantizer(input_tensor) - hidden_states = self.LayerNorm(add_local + add_residual) - return hidden_states - - -# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Roberta -class RobertaAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): - super().__init__() - self.self = RobertaSelfAttention(config, position_embedding_type=position_embedding_type) - self.output = RobertaSelfOutput(config) - self.pruned_heads = set() - - def prune_heads(self, heads): - if len(heads) == 0: - return - heads, index = find_pruneable_heads_and_indices( - heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads - ) - - # Prune linear layers - self.self.query = prune_linear_layer(self.self.query, index) - self.self.key = prune_linear_layer(self.self.key, index) - self.self.value = prune_linear_layer(self.self.value, index) - self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) - - # Update hyper params and store pruned heads - self.self.num_attention_heads = self.self.num_attention_heads - len(heads) - self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads - self.pruned_heads = self.pruned_heads.union(heads) - - def forward( - self, - hidden_states, - attention_mask=None, - head_mask=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - past_key_value=None, - output_attentions=False, - ): - self_outputs = self.self( - hidden_states, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) - attention_output = self.output(self_outputs[0], hidden_states) - outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them - return outputs - - -# Copied from transformers.models.bert.modeling_bert.BertIntermediate -class RobertaIntermediate(nn.Module): - def __init__(self, config): - super().__init__() - # QDQ change below - self.dense = quant_nn.QuantLinear(config.hidden_size, config.intermediate_size) - if isinstance(config.hidden_act, str): - self.intermediate_act_fn = ACT2FN[config.hidden_act] - else: - self.intermediate_act_fn = config.hidden_act - - def forward(self, hidden_states): - hidden_states = self.dense(hidden_states) - hidden_states = self.intermediate_act_fn(hidden_states) - return hidden_states - - -# Copied from transformers.models.bert.modeling_bert.BertOutput -class RobertaOutput(nn.Module): - def __init__(self, config): - super().__init__() - # QDQ change below - # Quantize Linear layer - self.dense = quant_nn.QuantLinear(config.intermediate_size, config.hidden_size) - self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.dropout = nn.Dropout(config.hidden_dropout_prob) - # QDQ change below - # Quantize the inputs to the residual add - self.add_local_input_quantizer = TensorQuantizer(quant_nn.QuantLinear.default_quant_desc_input) - self.add_residual_input_quantizer = TensorQuantizer(quant_nn.QuantLinear.default_quant_desc_input) - - def forward(self, hidden_states, input_tensor): - hidden_states = self.dense(hidden_states) - hidden_states = self.dropout(hidden_states) - # QDQ change below - # Quantize the inputs to the residual add - add_local = self.add_local_input_quantizer(hidden_states) - add_residual = self.add_residual_input_quantizer(input_tensor) - hidden_states = self.LayerNorm(add_local + add_residual) - return hidden_states - - -# Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->Roberta -class RobertaLayer(nn.Module): - def __init__(self, config): - super().__init__() - self.chunk_size_feed_forward = config.chunk_size_feed_forward - self.seq_len_dim = 1 - self.attention = RobertaAttention(config) - self.is_decoder = config.is_decoder - self.add_cross_attention = config.add_cross_attention - if self.add_cross_attention: - if not self.is_decoder: - raise ValueError(f"{self} should be used as a decoder model if cross attention is added") - self.crossattention = RobertaAttention(config, position_embedding_type="absolute") - self.intermediate = RobertaIntermediate(config) - self.output = RobertaOutput(config) - - def forward( - self, - hidden_states, - attention_mask=None, - head_mask=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - past_key_value=None, - output_attentions=False, - ): - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None - self_attention_outputs = self.attention( - hidden_states, - attention_mask, - head_mask, - output_attentions=output_attentions, - past_key_value=self_attn_past_key_value, - ) - attention_output = self_attention_outputs[0] - - # if decoder, the last output is tuple of self-attn cache - if self.is_decoder: - outputs = self_attention_outputs[1:-1] - present_key_value = self_attention_outputs[-1] - else: - outputs = self_attention_outputs[1:] # add self attentions if we output attention weights - - cross_attn_present_key_value = None - if self.is_decoder and encoder_hidden_states is not None: - if not hasattr(self, "crossattention"): - raise ValueError( - f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`" # noqa: E501 - ) - - # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None - cross_attention_outputs = self.crossattention( - attention_output, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - cross_attn_past_key_value, - output_attentions, - ) - attention_output = cross_attention_outputs[0] - outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights - - # add cross-attn cache to positions 3,4 of present_key_value tuple - cross_attn_present_key_value = cross_attention_outputs[-1] - present_key_value = present_key_value + cross_attn_present_key_value - - layer_output = apply_chunking_to_forward( - self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output - ) - outputs = (layer_output,) + outputs - - # if decoder, return the attn key/values as the last output - if self.is_decoder: - outputs = outputs + (present_key_value,) - - return outputs - - def feed_forward_chunk(self, attention_output): - intermediate_output = self.intermediate(attention_output) - layer_output = self.output(intermediate_output, attention_output) - return layer_output - - -# Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->Roberta -class RobertaEncoder(nn.Module): - def __init__(self, config): - super().__init__() - self.config = config - self.layer = nn.ModuleList([RobertaLayer(config) for _ in range(config.num_hidden_layers)]) - self.gradient_checkpointing = False - - def forward( - self, - hidden_states, - attention_mask=None, - head_mask=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - past_key_values=None, - use_cache=None, - output_attentions=False, - output_hidden_states=False, - return_dict=True, - ): - all_hidden_states = () if output_hidden_states else None - all_self_attentions = () if output_attentions else None - all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None - - next_decoder_cache = () if use_cache else None - for i, layer_module in enumerate(self.layer): - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - layer_head_mask = head_mask[i] if head_mask is not None else None - past_key_value = past_key_values[i] if past_key_values is not None else None - - if self.gradient_checkpointing and self.training: - - if use_cache: - logger.warning( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - ) - else: - layer_outputs = layer_module( - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) - - hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache += (layer_outputs[-1],) - if output_attentions: - all_self_attentions = all_self_attentions + (layer_outputs[1],) - if self.config.add_cross_attention: - all_cross_attentions = all_cross_attentions + (layer_outputs[2],) - - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if not return_dict: - return tuple( - v - for v in [ - hidden_states, - next_decoder_cache, - all_hidden_states, - all_self_attentions, - all_cross_attentions, - ] - if v is not None - ) - return BaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - past_key_values=next_decoder_cache, - hidden_states=all_hidden_states, - attentions=all_self_attentions, - cross_attentions=all_cross_attentions, - ) - - -# Copied from transformers.models.bert.modeling_bert.BertPooler -class RobertaPooler(nn.Module): - def __init__(self, config): - super().__init__() - self.dense = nn.Linear(config.hidden_size, config.hidden_size) - self.activation = nn.Tanh() - - def forward(self, hidden_states): - # We "pool" the model by simply taking the hidden state corresponding - # to the first token. - first_token_tensor = hidden_states[:, 0] - pooled_output = self.dense(first_token_tensor) - pooled_output = self.activation(pooled_output) - return pooled_output - - -class RobertaPreTrainedModel(PreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = RobertaConfig - base_model_prefix = "roberta" - supports_gradient_checkpointing = True - - # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights - def _init_weights(self, module): - """Initialize the weights""" - if isinstance(module, nn.Linear): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) - - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, RobertaEncoder): - module.gradient_checkpointing = value - - def update_keys_to_ignore(self, config, del_keys_to_ignore): - """Remove some keys from ignore list""" - if not config.tie_word_embeddings: - # must make a new list, or the class variable gets modified! - self._keys_to_ignore_on_save = [k for k in self._keys_to_ignore_on_save if k not in del_keys_to_ignore] - self._keys_to_ignore_on_load_missing = [ - k for k in self._keys_to_ignore_on_load_missing if k not in del_keys_to_ignore - ] - - -ROBERTA_START_DOCSTRING = r""" - - This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic - methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, - pruning heads etc.) - - This model is also a PyTorch `torch.nn.Module `__ - subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to - general usage and behavior. - - Parameters: - config (:class:`~transformers.RobertaConfig`): Model configuration class with all the parameters of the - model. Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model - weights. -""" - -ROBERTA_INPUTS_DOCSTRING = r""" - Args: - input_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`): - Indices of input sequence tokens in the vocabulary. - - Indices can be obtained using :class:`~transformers.RobertaTokenizer`. See - :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for - details. - - `What are input IDs? <../glossary.html#input-ids>`__ - attention_mask (:obj:`torch.FloatTensor` of shape :obj:`({0})`, `optional`): - Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - `What are attention masks? <../glossary.html#attention-mask>`__ - token_type_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`): - Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0, - 1]``: - - - 0 corresponds to a `sentence A` token, - - 1 corresponds to a `sentence B` token. - - `What are token type IDs? <../glossary.html#token-type-ids>`_ - position_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0, - config.max_position_embeddings - 1]``. - - `What are position IDs? <../glossary.html#position-ids>`_ - head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`): - Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`({0}, hidden_size)`, `optional`): - Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. - This is useful if you want more control over how to convert :obj:`input_ids` indices into associated - vectors than the model's internal embedding lookup matrix. - output_attentions (:obj:`bool`, `optional`): - Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned - tensors for more detail. - output_hidden_states (:obj:`bool`, `optional`): - Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for - more detail. - return_dict (:obj:`bool`, `optional`): - Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. -""" - - -@add_start_docstrings( - "The bare RoBERTa Model transformer outputting raw hidden-states without any specific head on top.", - ROBERTA_START_DOCSTRING, -) -class RobertaModel(RobertaPreTrainedModel): - """ - - The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of - cross-attention is added between the self-attention layers, following the architecture described in `Attention is - all you need`_ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz - Kaiser and Illia Polosukhin. - - To behave as an decoder the model needs to be initialized with the :obj:`is_decoder` argument of the configuration - set to :obj:`True`. To be used in a Seq2Seq model, the model needs to initialized with both :obj:`is_decoder` - argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an - input to the forward pass. - - .. _`Attention is all you need`: https://arxiv.org/abs/1706.03762 - - """ - - _keys_to_ignore_on_load_missing = [r"position_ids"] - - # Copied from transformers.models.bert.modeling_bert.BertModel.__init__ with Bert->Roberta - def __init__(self, config, add_pooling_layer=True): - super().__init__(config) - self.config = config - - self.embeddings = RobertaEmbeddings(config) - self.encoder = RobertaEncoder(config) - - self.pooler = RobertaPooler(config) if add_pooling_layer else None - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.embeddings.word_embeddings - - def set_input_embeddings(self, value): - self.embeddings.word_embeddings = value - - def _prune_heads(self, heads_to_prune): - """ - Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base - class PreTrainedModel - """ - for layer, heads in heads_to_prune.items(): - self.encoder.layer[layer].attention.prune_heads(heads) - @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - processor_class=_TOKENIZER_FOR_DOC, - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=BaseModelOutputWithPoolingAndCrossAttentions, - config_class=_CONFIG_FOR_DOC, - ) - # Copied from transformers.models.bert.modeling_bert.BertModel.forward - def forward( - self, - input_ids=None, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - past_key_values=None, - use_cache=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ): - r""" - encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): - Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if - the model is configured as a decoder. - encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): - Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in - the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: +from transformer_deploy.QDQModels.ast_module_patch import PatchModule - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): - Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. - If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` - (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` - instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. - use_cache (:obj:`bool`, `optional`): - If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up - decoding (see :obj:`past_key_values`). - """ # noqa: E501 - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if self.config.is_decoder: - use_cache = use_cache if use_cache is not None else self.config.use_cache - else: - use_cache = False - - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - input_shape = input_ids.size() - elif inputs_embeds is not None: - input_shape = inputs_embeds.size()[:-1] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - batch_size, seq_length = input_shape - device = input_ids.device if input_ids is not None else inputs_embeds.device - - # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 - - if attention_mask is None: - attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) - - if token_type_ids is None: - if hasattr(self.embeddings, "token_type_ids"): - buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length] - buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length) - token_type_ids = buffered_token_type_ids_expanded - else: - token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) - - # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] - # ourselves in which case we just need to make it broadcastable to all heads. - extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device) - - # If a 2D or 3D attention mask is provided for the cross-attention - # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] - if self.config.is_decoder and encoder_hidden_states is not None: - encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() - encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) - if encoder_attention_mask is None: - encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) - encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) - else: - encoder_extended_attention_mask = None - - # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape bsz x n_heads x N x N - # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] - # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] - head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) - - embedding_output = self.embeddings( - input_ids=input_ids, - position_ids=position_ids, - token_type_ids=token_type_ids, - inputs_embeds=inputs_embeds, - past_key_values_length=past_key_values_length, - ) - encoder_outputs = self.encoder( - embedding_output, - attention_mask=extended_attention_mask, - head_mask=head_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_extended_attention_mask, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - sequence_output = encoder_outputs[0] - pooled_output = self.pooler(sequence_output) if self.pooler is not None else None - - if not return_dict: - return (sequence_output, pooled_output) + encoder_outputs[1:] - - return BaseModelOutputWithPoolingAndCrossAttentions( - last_hidden_state=sequence_output, - pooler_output=pooled_output, - past_key_values=encoder_outputs.past_key_values, - hidden_states=encoder_outputs.hidden_states, - attentions=encoder_outputs.attentions, - cross_attentions=encoder_outputs.cross_attentions, - ) - - -@add_start_docstrings( - """RoBERTa Model with a `language modeling` head on top for CLM fine-tuning. """, ROBERTA_START_DOCSTRING -) -class RobertaForCausalLM(RobertaPreTrainedModel): - _keys_to_ignore_on_save = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"] - _keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.weight", r"lm_head.decoder.bias"] - _keys_to_ignore_on_load_unexpected = [r"pooler"] - - def __init__(self, config): - super().__init__(config) - - if not config.is_decoder: - logger.warning("If you want to use `RobertaLMHeadModel` as a standalone, add `is_decoder=True.`") - - self.roberta = RobertaModel(config, add_pooling_layer=False) - self.lm_head = RobertaLMHead(config) - - # The LM head weights require special treatment only when they are tied with the word embeddings - self.update_keys_to_ignore(config, ["lm_head.decoder.weight"]) - - # Initialize weights and apply final processing - self.post_init() - - def get_output_embeddings(self): - return self.lm_head.decoder - - def set_output_embeddings(self, new_embeddings): - self.lm_head.decoder = new_embeddings - - @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC) - def forward( - self, - input_ids=None, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - labels=None, - past_key_values=None, - use_cache=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ): - r""" - encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): - Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if - the model is configured as a decoder. - encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): - Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in - the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): - Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in - ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are - ignored (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]`` - past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): - Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. - - If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` - (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` - instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. - use_cache (:obj:`bool`, `optional`): - If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up - decoding (see :obj:`past_key_values`). - - Returns: - - Example:: - - >>> from transformers import RobertaTokenizer, RobertaForCausalLM, RobertaConfig - >>> import torch - - >>> tokenizer = RobertaTokenizer.from_pretrained('roberta-base') - >>> config = RobertaConfig.from_pretrained("roberta-base") - >>> config.is_decoder = True - >>> model = RobertaForCausalLM.from_pretrained('roberta-base', config=config) - - >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") - >>> outputs = model(**inputs) - - >>> prediction_logits = outputs.logits - """ # noqa: E501 - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if labels is not None: - use_cache = False - - outputs = self.roberta( - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - sequence_output = outputs[0] - prediction_scores = self.lm_head(sequence_output) - - lm_loss = None - if labels is not None: - # we are doing next-token prediction; shift prediction scores and input ids by one - shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous() - labels = labels[:, 1:].contiguous() - loss_fct = CrossEntropyLoss() - lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) - - if not return_dict: - output = (prediction_scores,) + outputs[2:] - return ((lm_loss,) + output) if lm_loss is not None else output - - return CausalLMOutputWithCrossAttentions( - loss=lm_loss, - logits=prediction_scores, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - cross_attentions=outputs.cross_attentions, - ) - - def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs): - input_shape = input_ids.shape - # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly - if attention_mask is None: - attention_mask = input_ids.new_ones(input_shape) - - # cut decoder_input_ids if past is used - if past is not None: - input_ids = input_ids[:, -1:] - - return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past} - - def _reorder_cache(self, past, beam_idx): - reordered_past = () - for layer_past in past: - reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) - return reordered_past - - -@add_start_docstrings("""RoBERTa Model with a `language modeling` head on top. """, ROBERTA_START_DOCSTRING) -class RobertaForMaskedLM(RobertaPreTrainedModel): - _keys_to_ignore_on_save = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"] - _keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.weight", r"lm_head.decoder.bias"] - _keys_to_ignore_on_load_unexpected = [r"pooler"] - - def __init__(self, config): - super().__init__(config) - - if config.is_decoder: - logger.warning( - "If you want to use `RobertaForMaskedLM` make sure `config.is_decoder=False` for " - "bi-directional self-attention." - ) - - self.roberta = RobertaModel(config, add_pooling_layer=False) - self.lm_head = RobertaLMHead(config) - - # The LM head weights require special treatment only when they are tied with the word embeddings - self.update_keys_to_ignore(config, ["lm_head.decoder.weight"]) - - # Initialize weights and apply final processing - self.post_init() - - def get_output_embeddings(self): - return self.lm_head.decoder - - def set_output_embeddings(self, new_embeddings): - self.lm_head.decoder = new_embeddings - - @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - processor_class=_TOKENIZER_FOR_DOC, - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=MaskedLMOutput, - config_class=_CONFIG_FOR_DOC, - mask="", - ) - def forward( - self, - input_ids=None, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - labels=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ): - r""" - labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): - Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ..., - config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored - (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]`` - kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`): - Used to hide legacy arguments that have been deprecated. - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = self.roberta( - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - sequence_output = outputs[0] - prediction_scores = self.lm_head(sequence_output) - - masked_lm_loss = None - if labels is not None: - loss_fct = CrossEntropyLoss() - masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) - - if not return_dict: - output = (prediction_scores,) + outputs[2:] - return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output - - return MaskedLMOutput( - loss=masked_lm_loss, - logits=prediction_scores, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -class RobertaLMHead(nn.Module): - """Roberta Head for masked language modeling.""" - - def __init__(self, config): - super().__init__() - self.dense = nn.Linear(config.hidden_size, config.hidden_size) - self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - - self.decoder = nn.Linear(config.hidden_size, config.vocab_size) - self.bias = nn.Parameter(torch.zeros(config.vocab_size)) - self.decoder.bias = self.bias - - def forward(self, features, **kwargs): - x = self.dense(features) - x = gelu(x) - x = self.layer_norm(x) - - # project back to size of vocabulary with bias - x = self.decoder(x) - - return x - - def _tie_weights(self): - # To tie those two weights if they get disconnected (on TPU or when the bias is resized) - self.bias = self.decoder.bias - - -@add_start_docstrings( - """ - RoBERTa Model transformer with a sequence classification/regression head on top (a linear layer on top of the - pooled output) e.g. for GLUE tasks. - """, - ROBERTA_START_DOCSTRING, -) -class QDQRobertaForSequenceClassification(RobertaPreTrainedModel): - _keys_to_ignore_on_load_missing = [r"position_ids"] - - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_labels - self.config = config - - self.roberta = RobertaModel(config, add_pooling_layer=False) - self.classifier = RobertaClassificationHead(config) - - # Initialize weights and apply final processing - self.post_init() - - @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - processor_class=_TOKENIZER_FOR_DOC, - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=SequenceClassifierOutput, - config_class=_CONFIG_FOR_DOC, - ) - def forward( - self, - input_ids=None, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - labels=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ): - r""" - labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): - Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ..., - config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss), - If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = self.roberta( - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - sequence_output = outputs[0] - logits = self.classifier(sequence_output) - - loss = None - if labels is not None: - if self.config.problem_type is None: - if self.num_labels == 1: - self.config.problem_type = "regression" - elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): - self.config.problem_type = "single_label_classification" - else: - self.config.problem_type = "multi_label_classification" - - if self.config.problem_type == "regression": - loss_fct = MSELoss() - if self.num_labels == 1: - loss = loss_fct(logits.squeeze(), labels.squeeze()) - else: - loss = loss_fct(logits, labels) - elif self.config.problem_type == "single_label_classification": - loss_fct = CrossEntropyLoss() - loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) - elif self.config.problem_type == "multi_label_classification": - loss_fct = BCEWithLogitsLoss() - loss = loss_fct(logits, labels) - - if not return_dict: - output = (logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return SequenceClassifierOutput( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -@add_start_docstrings( - """ - Roberta Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a - softmax) e.g. for RocStories/SWAG tasks. - """, - ROBERTA_START_DOCSTRING, -) -class RobertaForMultipleChoice(RobertaPreTrainedModel): - _keys_to_ignore_on_load_missing = [r"position_ids"] - - def __init__(self, config): - super().__init__(config) - - self.roberta = RobertaModel(config) - self.dropout = nn.Dropout(config.hidden_dropout_prob) - self.classifier = nn.Linear(config.hidden_size, 1) - - # Initialize weights and apply final processing - self.post_init() - - @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")) - @add_code_sample_docstrings( - processor_class=_TOKENIZER_FOR_DOC, - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=MultipleChoiceModelOutput, - config_class=_CONFIG_FOR_DOC, - ) - def forward( - self, - input_ids=None, - token_type_ids=None, - attention_mask=None, - labels=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ): - r""" - labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): - Labels for computing the multiple choice classification loss. Indices should be in ``[0, ..., - num_choices-1]`` where :obj:`num_choices` is the size of the second dimension of the input tensors. (See - :obj:`input_ids` above) - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1] - - flat_input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None - flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None - flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None - flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None - flat_inputs_embeds = ( - inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) - if inputs_embeds is not None - else None - ) - - outputs = self.roberta( - flat_input_ids, - position_ids=flat_position_ids, - token_type_ids=flat_token_type_ids, - attention_mask=flat_attention_mask, - head_mask=head_mask, - inputs_embeds=flat_inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - pooled_output = outputs[1] - - pooled_output = self.dropout(pooled_output) - logits = self.classifier(pooled_output) - reshaped_logits = logits.view(-1, num_choices) - - loss = None - if labels is not None: - loss_fct = CrossEntropyLoss() - loss = loss_fct(reshaped_logits, labels) - - if not return_dict: - output = (reshaped_logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return MultipleChoiceModelOutput( - loss=loss, - logits=reshaped_logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -@add_start_docstrings( - """ - Roberta Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for - Named-Entity-Recognition (NER) tasks. - """, - ROBERTA_START_DOCSTRING, -) -class RobertaForTokenClassification(RobertaPreTrainedModel): - _keys_to_ignore_on_load_unexpected = [r"pooler"] - _keys_to_ignore_on_load_missing = [r"position_ids"] - - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_labels - - self.roberta = RobertaModel(config, add_pooling_layer=False) - classifier_dropout = ( - config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob - ) - self.dropout = nn.Dropout(classifier_dropout) - self.classifier = nn.Linear(config.hidden_size, config.num_labels) - - # Initialize weights and apply final processing - self.post_init() - - @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - processor_class=_TOKENIZER_FOR_DOC, - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TokenClassifierOutput, - config_class=_CONFIG_FOR_DOC, - ) - def forward( - self, - input_ids=None, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - labels=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ): - r""" - labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): - Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels - - 1]``. - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = self.roberta( - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - sequence_output = outputs[0] - - sequence_output = self.dropout(sequence_output) - logits = self.classifier(sequence_output) - - loss = None - if labels is not None: - loss_fct = CrossEntropyLoss() - # Only keep active parts of the loss - if attention_mask is not None: - active_loss = attention_mask.view(-1) == 1 - active_logits = logits.view(-1, self.num_labels) - active_labels = torch.where( - active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels) - ) - loss = loss_fct(active_logits, active_labels) - else: - loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) - - if not return_dict: - output = (logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return TokenClassifierOutput( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -class RobertaClassificationHead(nn.Module): - """Head for sentence-level classification tasks.""" - - def __init__(self, config): - super().__init__() - self.dense = nn.Linear(config.hidden_size, config.hidden_size) - classifier_dropout = ( - config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob - ) - self.dropout = nn.Dropout(classifier_dropout) - self.out_proj = nn.Linear(config.hidden_size, config.num_labels) - - def forward(self, features, **kwargs): - x = features[:, 0, :] # take token (equiv. to [CLS]) - x = self.dropout(x) - x = self.dense(x) - x = torch.tanh(x) - x = self.dropout(x) - x = self.out_proj(x) - return x - - -@add_start_docstrings( - """ - Roberta Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear - layers on top of the hidden-states output to compute `span start logits` and `span end logits`). - """, - ROBERTA_START_DOCSTRING, -) -class RobertaForQuestionAnswering(RobertaPreTrainedModel): - _keys_to_ignore_on_load_unexpected = [r"pooler"] - _keys_to_ignore_on_load_missing = [r"position_ids"] - - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_labels - - self.roberta = RobertaModel(config, add_pooling_layer=False) - self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) - - # Initialize weights and apply final processing - self.post_init() - - @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - processor_class=_TOKENIZER_FOR_DOC, - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=QuestionAnsweringModelOutput, - config_class=_CONFIG_FOR_DOC, - ) - def forward( - self, - input_ids=None, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - start_positions=None, - end_positions=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ): - r""" - start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): - Labels for position (index) of the start of the labelled span for computing the token classification loss. - Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the - sequence are not taken into account for computing the loss. - end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): - Labels for position (index) of the end of the labelled span for computing the token classification loss. - Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the - sequence are not taken into account for computing the loss. - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = self.roberta( - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - sequence_output = outputs[0] - - logits = self.qa_outputs(sequence_output) - start_logits, end_logits = logits.split(1, dim=-1) - start_logits = start_logits.squeeze(-1).contiguous() - end_logits = end_logits.squeeze(-1).contiguous() - - total_loss = None - if start_positions is not None and end_positions is not None: - # If we are on multi-GPU, split add a dimension - if len(start_positions.size()) > 1: - start_positions = start_positions.squeeze(-1) - if len(end_positions.size()) > 1: - end_positions = end_positions.squeeze(-1) - # sometimes the start/end positions are outside our model inputs, we ignore these terms - ignored_index = start_logits.size(1) - start_positions = start_positions.clamp(0, ignored_index) - end_positions = end_positions.clamp(0, ignored_index) - - loss_fct = CrossEntropyLoss(ignore_index=ignored_index) - start_loss = loss_fct(start_logits, start_positions) - end_loss = loss_fct(end_logits, end_positions) - total_loss = (start_loss + end_loss) / 2 - - if not return_dict: - output = (start_logits, end_logits) + outputs[2:] - return ((total_loss,) + output) if total_loss is not None else output - - return QuestionAnsweringModelOutput( - loss=total_loss, - start_logits=start_logits, - end_logits=end_logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0): - """ - Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols - are ignored. This is modified from fairseq's `utils.make_positions`. - - Args: - x: torch.Tensor x: - - Returns: torch.Tensor - """ +def qdq_create_position_tensorrt(input_ids, padding_idx, past_key_values_length=0): # QDQ change below # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA. # int() -> float() because of a limitations in cumsum operator implementation in TensorRT mask = input_ids.ne(padding_idx).float() incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask return incremental_indices.long() + padding_idx + + +qdq_roberta_mapping: PatchModule = PatchModule( + module="transformers.models.roberta.modeling_roberta", + monkey_patch={ + "create_position_ids_from_input_ids": (qdq_create_position_tensorrt, "qdq_create_position_tensorrt"), + }, +) diff --git a/src/transformer_deploy/QDQModels/ast_module_patch.py b/src/transformer_deploy/QDQModels/ast_module_patch.py new file mode 100644 index 00000000..d2f1b193 --- /dev/null +++ b/src/transformer_deploy/QDQModels/ast_module_patch.py @@ -0,0 +1,196 @@ +# Copyright 2021, Lefebvre Sarrut Services +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import ast +import code +import importlib +import inspect +import logging +from dataclasses import dataclass, field +from typing import Callable, Dict, List, Optional, Tuple + +from transformer_deploy.QDQModels.ast_operator_patch import Patch2ArgsNode, PatchAdd2ArgsNode, PatchNode + + +# list of Pytorch operations to optimize, you can reduce it to increase PTQ/QAT accuracy +op_to_quant: List[PatchNode] = [ + Patch2ArgsNode(op="matmul"), + Patch2ArgsNode(op="add"), + Patch2ArgsNode(op="bmm"), + PatchAdd2ArgsNode(op="LayerNorm"), +] + + +@dataclass +class PatchModule: + module: str + monkey_patch: Dict[str, Tuple[Callable, str]] = field(default_factory=dict) + + def print_code(self): + for class_name, cl in self.monkey_patch.items(): + print("---------") + print(class_name) + inspect.getsource(cl) + + def restore(self): + model_module = importlib.import_module(name=self.module) + importlib.reload(model_module) + + +def init_quantizer(name: str) -> ast.Assign: + """ + Generate quantization node initialization to add to the end of __init__() + :param name: generated name of the node + :return: quantization init ast node + """ + quant_linear = ast.Attribute(value=ast.Name(id="quant_nn", ctx=ast.Load()), attr="QuantLinear", ctx=ast.Load()) + default_quant_desc_input = ast.Attribute(value=quant_linear, attr="default_quant_desc_input", ctx=ast.Load()) + tensor_quant = ast.Name(id="TensorQuantizer", ctx=ast.Load()) + quant_value = ast.Attribute(value=ast.Name(id="self", ctx=ast.Load()), attr=name, ctx=ast.Store()) + return ast.Assign( + targets=[quant_value], + value=ast.Call(func=tensor_quant, args=[default_quant_desc_input], keywords=[]), + ) + + +def patch_nodes(head_node: ast.Module) -> Tuple[ast.Module, List[str]]: + """ + Replace an operation to optimize by its optimized version. + May have to generate some quantization node names. + :param head_node: ast node to modify + :return: the modified ast tree and the list of generated quantization nodes + """ + q_attr_names: List[str] = list() + for node in ast.walk(head_node): # type: ast.Call + for op in op_to_quant: + if op.should_patch(node=node): + quant_names = op.patch(node=node, nb_quant_node=len(q_attr_names)) + q_attr_names.extend(quant_names) + + return head_node, q_attr_names + + +def add_init_quantizer(head_node: ast.Module, q_attr_names: List[str]) -> ast.Module: + """ + Add initialization of quantizer to __init__() + :param head_node: node related to a class to optimize + :param q_attr_names: list of quantizer names to init + :return: modified ast tree + """ + for node in ast.walk(head_node): # type: ast.FunctionDef + if isinstance(node, ast.FunctionDef) and node.name == "__init__": + for name in q_attr_names: + quantizer = init_quantizer(name) + node.body.append(quantizer) + return head_node + + +def add_qdq_to_class_name(head_node: ast.Module, new_class_name: str) -> ast.Module: + """ + Change the name of the class to optimize (may help in debugging / error messages) + :param head_node: node related to the class to optimize + :param new_class_name: new name to use + :return: the modified ast tree + """ + for node in ast.walk(head_node): # type: ast.ClassDef + if isinstance(node, ast.ClassDef): + node.name = new_class_name + return head_node + + +def add_quant_to_module(module_to_patch: type, new_module_name: str) -> ast.Module: + """ + Modify a class to add quantization operations around each torch operation to optimize. + :param module_to_patch: Pytorch module to patch + :param new_module_name: new name for the module + :return: modified ast tree + """ + source_code = inspect.getsource(module_to_patch) + head = ast.parse(source_code) + head, nodes_to_add = patch_nodes(head) + add_init_quantizer(head_node=head, q_attr_names=nodes_to_add) + head = add_qdq_to_class_name(head_node=head, new_class_name=new_module_name) + return head + + +def contains_op(node: ast.AST) -> bool: + """ + Check if a tree contains some operations to optimize. + :param node: Head of the ast tree + :return: True if ast tree contains operations to optimize + """ + for node in ast.walk(node): + for op in op_to_quant: + if op.should_patch(node=node): + return True + return False + + +def list_class_to_patch(model_module) -> List[str]: + """ + List all classes which contain operations to be optimized. + :param model_module: Pytorch module + :return: the list of module names to be optimized + """ + module_names: List[str] = list() + module_source_code = inspect.getsource(model_module) + head_node = ast.parse(module_source_code) + for node in ast.walk(head_node): + if isinstance(node, ast.ClassDef) and contains_op(node=node): + module_names.append(node.name) + return module_names + + +def load_missing_imports(model_module) -> None: + """ + Execute some imports in the context of a module. + Override Linear layer by its quantized version + :param model_module: module to use for the imports + """ + import_code = """ + from pytorch_quantization import nn as quant_nn + from pytorch_quantization.nn import TensorQuantizer + import torch + torch.nn.Linear = quant_nn.QuantLinear + """ + # remove extra spaces + import_code = inspect.cleandoc(import_code) + # execute the code in the module context + exec(import_code, model_module.__dict__, model_module.__dict__) + + +def add_quantization_to_model( + module_path: str, + class_to_patch: Optional[List[str]], +): + """ + Add quantization support to a model. + :param module_path: model module to optimize + :param class_to_patch: name of modules to patch, if None it will be auto-detected. + :return: backup of original classes + """ + model_module = importlib.import_module(name=module_path) + load_missing_imports(model_module) + + if class_to_patch is None or len(class_to_patch) == 0: + class_to_patch = list_class_to_patch(model_module=model_module) + logging.info(f"modify class {', '.join(class_to_patch)}") + + for class_name in class_to_patch: + module_to_patch = getattr(model_module, class_name) + head = add_quant_to_module(module_to_patch=module_to_patch, new_module_name=class_name) + head = ast.fix_missing_locations(head) + module_patched: code = compile(head, filename="", mode="exec") + # execute the code in the module context so it overrides the original classes and leverage existing imports + exec(module_patched, model_module.__dict__, model_module.__dict__) diff --git a/src/transformer_deploy/QDQModels/ast_operator_patch.py b/src/transformer_deploy/QDQModels/ast_operator_patch.py new file mode 100644 index 00000000..7e457f9d --- /dev/null +++ b/src/transformer_deploy/QDQModels/ast_operator_patch.py @@ -0,0 +1,112 @@ +# Copyright 2021, Lefebvre Sarrut Services +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc +import ast +from typing import List + + +class PatchNode(object): + __metaclass__ = abc.ABCMeta + torch_op_to_quantize: str + + @abc.abstractmethod + def should_patch(self, node: ast.AST) -> bool: + """ + Check if a node should be patched + :param node: node to check + :return: return True if it matches the operator provided during the __init__ + """ + raise Exception("to implement") + + @abc.abstractmethod + def patch(self, node: ast.AST, nb_quant_node: int) -> List[str]: + """ + Patch node by adding quantizer nodes around the operator provided during the __init__ + :param node: node to patch + :param nb_quant_node: number of existing quantizer node + :return: return list of generated quantizer node names + """ + raise Exception("to implement") + + @staticmethod + def _wrap_attr(quantizer_name: str, tensor_var: ast.expr) -> ast.Call: + """ + Generate quantization wrapping each attribute of a torch operation to optimize (matmul, add, etc.) + :param quantizer_name: generated quantization name + :param tensor_var: the variable to wrap + :return: the ast tree to replace the original variable + """ + return ast.Call( + func=ast.Attribute(value=ast.Name(id="self", ctx=ast.Load()), attr=quantizer_name, ctx=ast.Load()), + args=[tensor_var], + keywords=[], + ) + + def get_quant_name(self, node_id: int) -> str: + return f"{self.torch_op_to_quantize.lower()}_quantizer_{node_id}" + + +class Patch2ArgsNode(PatchNode): + def __init__(self, op: str): + """ + Patch source code in the form torch.op(a, b) to torch.op(self.q1(a), self.q1(b)) + :param op: operator to match + """ + self.torch_op_to_quantize = op + + def should_patch(self, node: ast.AST) -> bool: + return ( + isinstance(node, ast.Call) + and isinstance(node.func, ast.Attribute) + and isinstance(node.func.value, ast.Name) + and node.func.value.id == "torch" + and node.func.attr == self.torch_op_to_quantize + ) + + def patch(self, node: ast.AST, nb_quant_node: int) -> List[str]: + q_attr_names = list() + for index in range(2): # only apply transfo to the 2 first args + arg = node.args[index] + q_name = self.get_quant_name(nb_quant_node + len(q_attr_names)) + q_attr_names.append(q_name) + node.args[index] = self._wrap_attr(q_name, arg) + return q_attr_names + + +class PatchAdd2ArgsNode(PatchNode): + def __init__(self, op: str): + """ + Patch source code in the form torch.op(a + b) to torch.op(self.q1(a) + self.q1(b)) + :param op: operator to match + """ + self.torch_op_to_quantize = op + + def should_patch(self, node: ast.AST) -> bool: + return ( + isinstance(node, ast.Call) + and isinstance(node.func, ast.Attribute) + and node.func.attr == self.torch_op_to_quantize + and isinstance(node.args, list) + and len(node.args) == 1 + and isinstance(node.args[0], ast.BinOp) + and isinstance(node.args[0].op, ast.Add) + ) + + def patch(self, node: ast.AST, nb_quant_node: int) -> List[str]: + left_name = self.get_quant_name(nb_quant_node) + right_name = self.get_quant_name(nb_quant_node + 1) + node.args[0].left = self._wrap_attr(left_name, node.args[0].left) + node.args[0].right = self._wrap_attr(right_name, node.args[0].right) + return [left_name, right_name] diff --git a/src/transformer_deploy/QDQModels/calibration_utils.py b/src/transformer_deploy/QDQModels/calibration_utils.py new file mode 100644 index 00000000..f03f42af --- /dev/null +++ b/src/transformer_deploy/QDQModels/calibration_utils.py @@ -0,0 +1,111 @@ +# Copyright 2021, Lefebvre Sarrut Services +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +import torch.cuda +from pytorch_quantization import calib +from pytorch_quantization import nn as quant_nn +from pytorch_quantization.tensor_quant import QuantDescriptor +from transformers import PreTrainedModel + +from transformer_deploy.QDQModels.patch import add_qdq, remove_qdq + + +class QATCalibrate: + def __init__(self, method: str = "histogram", percentile: float = 99.999, per_channel: bool = True): + """ + Calibration will learn how a float tensor should be mapped to an integer tensor. + Will learn range, bias and scale. + Quantization targets signe 8 bits integers as it's the best supported type for Nvidia GPUs + (there are dedicated 8 bits integer tensor cores on most modern Nvidia GPU architectures). + Don't forget to call setup_model_qat at some point. + :param method: the method calibration to use. One of [histogram, percentile]. + Recommended method for transformers is "histogram". + :param percentile: for histogram method, what do you define as an outlier value + :param per_channel: calibration granularity. per channel == per dimension. + """ + assert torch.cuda.is_available(), "CUDA not available" + self.model: Optional[PreTrainedModel] = None + assert method in [ + "histogram", + "max", + ], f"unknown calibration method (for NLP): {method}" + self.calib_method: str = method + self.calibration_percentile: float = percentile + self.calibration_per_channel: bool = per_channel + + def setup_nvidia_qat(self) -> None: + """ + Setup Nvidia QAT library global variables. + Should be called before initializing a model. + """ + input_desc = QuantDescriptor(num_bits=8, calib_method=self.calib_method) + axis = (0,) if self.calibration_per_channel else None + weight_desc = QuantDescriptor(num_bits=8, axis=axis) + quant_nn.QuantLinear.set_default_quant_desc_input(input_desc) + quant_nn.QuantLinear.set_default_quant_desc_weight(weight_desc) + + def setup_model_qat(self, model: PreTrainedModel) -> None: + """ + Enable calibration on each tensor to quantize. + :param model: model to optimize + """ + self.model = model + model = self.model.cuda() + # Find the TensorQuantizer and enable calibration + for name, module in model.named_modules(): + if isinstance(module, quant_nn.TensorQuantizer): + if module._calibrator is not None: + module.disable_quant() + module.enable_calib() + else: + module.disable() + + def finalize_calibration(self) -> None: + """ + Disable calibration process and enable quantized nodes. + """ + calib_method = "max" if self.calib_method == "max" else "percentile" + for _, module in self.model.named_modules(): + if isinstance(module, quant_nn.TensorQuantizer): + if module._calibrator is not None: + if isinstance(module._calibrator, calib.MaxCalibrator): + module.load_calib_amax() + else: + # strict=False -> avoid Exception when some quantizer are never used + # (because of a condition for instance) + module.load_calib_amax(calib_method, percentile=self.calibration_percentile, strict=False) + module.enable_quant() + module.disable_calib() + else: + module.enable() + # move back model to GPU memory + self.model.cuda() + + @staticmethod + def restore(): + """ + Restore behavior without quantization support. + """ + remove_qdq() + + def __enter__(self): + add_qdq() + self.setup_nvidia_qat() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if exc_type is None: + self.finalize_calibration() diff --git a/src/transformer_deploy/QDQModels/patch.py b/src/transformer_deploy/QDQModels/patch.py new file mode 100644 index 00000000..1272daca --- /dev/null +++ b/src/transformer_deploy/QDQModels/patch.py @@ -0,0 +1,74 @@ +# Copyright 2021, Lefebvre Sarrut Services +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib +import inspect +import logging +from typing import List, Optional + +from transformer_deploy.QDQModels.ast_module_patch import PatchModule, add_quantization_to_model +from transformer_deploy.QDQModels.QDQAlbert import qdq_albert_mapping +from transformer_deploy.QDQModels.QDQBert import qdq_bert_mapping +from transformer_deploy.QDQModels.QDQDeberta import qdq_deberta_mapping, qdq_deberta_v2_mapping +from transformer_deploy.QDQModels.QDQDistilbert import qdq_distilbert_mapping +from transformer_deploy.QDQModels.QDQElectra import qdq_electra_mapping +from transformer_deploy.QDQModels.QDQRoberta import qdq_roberta_mapping + + +tested_models: List[PatchModule] = [ + qdq_bert_mapping, + qdq_roberta_mapping, + qdq_electra_mapping, + qdq_distilbert_mapping, + qdq_albert_mapping, + qdq_deberta_mapping, # quantization is ok, ONNX export doesn't work + qdq_deberta_v2_mapping, # quantization is ok, ONNX export doesn't work +] + + +def patch_model(patch: PatchModule) -> None: + """ + Perform modifications to model to make it work with ONNX export and quantization. + :param patch: an object containing all the information to perform a modification + """ + add_quantization_to_model(module_path=patch.module, class_to_patch=None) + model_module = importlib.import_module(patch.module) + for target, (modified_object, object_name) in patch.monkey_patch.items(): + source_code = inspect.getsource(modified_object) + source_code += f"\n{target} = {object_name}" + exec(source_code, model_module.__dict__, model_module.__dict__) + + +def add_qdq(modules_to_patch: Optional[List[PatchModule]] = None) -> None: + """ + Add quantization support to each tested model by modifyin their AST. + :param modules_to_patch: list of operator to target + """ + if modules_to_patch is None: + modules_to_patch = tested_models + for patch in modules_to_patch: + logging.info(f"add quantization to module {patch.module}") + patch_model(patch) + + +def remove_qdq(modules_to_patch: Optional[List[PatchModule]] = None) -> None: + """ + Restore AST of modified modules. + :param modules_to_patch: list of operator to target + """ + if modules_to_patch is None: + modules_to_patch = tested_models + for patch in modules_to_patch: + logging.info(f"restore module {patch.module}") + patch.restore() diff --git a/src/transformer_deploy/backends/ort_utils.py b/src/transformer_deploy/backends/ort_utils.py index 6c74e37e..fa59e3ae 100644 --- a/src/transformer_deploy/backends/ort_utils.py +++ b/src/transformer_deploy/backends/ort_utils.py @@ -32,6 +32,14 @@ def create_model_for_provider( path: str, provider_to_use: Union[str, List], nb_threads: int = multiprocessing.cpu_count(), nb_instances: int = 0 ) -> InferenceSession: + """ + Create an ONNX Runtime instance. + :param path: path to ONNX file + :param provider_to_use: provider to use for inference + :param nb_threads: intra_op_num_threads to use + :param nb_instances: inter_op_num_threads to use + :return: ONNX Runtime inference session + """ options = SessionOptions() options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL if type(provider_to_use) != list: @@ -47,6 +55,14 @@ def create_model_for_provider( def convert_to_onnx( model_pytorch: PreTrainedModel, output_path: str, inputs_pytorch: OD[str, torch.Tensor], opset: int = 12 ) -> None: + """ + Convert a Pytorch model to an ONNX graph by tracing the provided input inside the Pytorch code. + :param model_pytorch: Pytorch model + :param output_path: where to save ONNX file + :param inputs_pytorch: Tensor, can be dummy data, shape is not important as we declare all axes as dynamic. + Should be on the same device than the model (CPU or GPU) + :param opset: version of ONNX protocol to use, usually 12, or 13 if you use per channel quantized model + """ # dynamic axis == variable length axis dynamic_axis = OrderedDict() for k in inputs_pytorch.keys(): @@ -67,7 +83,32 @@ def convert_to_onnx( ) +def convert_to_quant_onnx( + model_pytorch: PreTrainedModel, output_path: str, inputs_pytorch: OD[str, torch.Tensor] +) -> None: + """ + Convert a quantized Pytorch model to ONNX file. + :param model_pytorch: Pytorch model + :param output_path: ONNX file path + :param inputs_pytorch: some dummy input (Pytorch tensor on the same device than the model) + """ + from pytorch_quantization.nn import TensorQuantizer + + TensorQuantizer.use_fb_fake_quant = True + convert_to_onnx(model_pytorch=model_pytorch, output_path=output_path, inputs_pytorch=inputs_pytorch, opset=13) + TensorQuantizer.use_fb_fake_quant = False + + def optimize_onnx(onnx_path: str, onnx_optim_model_path: str, fp16: bool, use_cuda: bool) -> None: + """ + ONNX Runtime transformer graph optimization. + Performs some operator fusion (merge several nodes of the graph in a single one) + and may convert some nodes to reduced precision. + :param onnx_path: ONNX input path + :param onnx_optim_model_path: where to save optimized model + :param fp16: use mixed precision (faster inference) + :param use_cuda: perform optimization on GPU (should ) + """ optimization_options = FusionOptions("bert") optimization_options.enable_gelu_approximation = False # additional optimization optimized_model: BertOnnxModel = optimizer.optimize_model( @@ -85,7 +126,12 @@ def optimize_onnx(onnx_path: str, onnx_optim_model_path: str, fp16: bool, use_cu optimized_model.save_model_to_file(onnx_optim_model_path) -def cpu_quantization(input_model_path: str, output_model_path: str): +def cpu_quantization(input_model_path: str, output_model_path: str) -> None: + """ + ONNX CPU only dynamic quantization + :param input_model_path: ONNX graph (float) to quantize + :param output_model_path: where to save quantized model + """ quantize_dynamic( model_input=input_model_path, model_output=output_model_path, diff --git a/src/transformer_deploy/benchmarks/utils.py b/src/transformer_deploy/benchmarks/utils.py index 4f264b0f..1c2639f4 100644 --- a/src/transformer_deploy/benchmarks/utils.py +++ b/src/transformer_deploy/benchmarks/utils.py @@ -22,7 +22,12 @@ import torch -def print_timings(name: str, timings: List[float]): +def print_timings(name: str, timings: List[float]) -> None: + """ + Format and print latencies + :param name: engine name + :param timings: latencies measured during the inference + """ mean_time = 1e3 * np.mean(timings) std_time = 1e3 * np.std(timings) min_time = 1e3 * np.min(timings) @@ -40,12 +45,20 @@ def print_timings(name: str, timings: List[float]): ) -def setup_logging(level: int = logging.INFO): +def setup_logging(level: int = logging.INFO) -> None: + """ + Set the generic Python logger + :param level: logger level + """ logging.basicConfig(format="%(asctime)s %(levelname)-8s %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=level) @contextmanager -def track_infer_time(buffer: [int]): +def track_infer_time(buffer: List[int]) -> None: + """ + A context manager to perform latency measures + :param buffer: a List where to save latencies for each input + """ start = time.perf_counter() yield end = time.perf_counter() @@ -55,7 +68,15 @@ def track_infer_time(buffer: [int]): def generate_input( seq_len: int, batch_size: int, include_token_ids: bool, device: str = "cuda" ) -> Tuple[Dict[str, torch.Tensor], Dict[str, np.ndarray]]: - assert device in ["cuda", "cpu"] + """ + Generate dummy inputs. + :param seq_len: number of token per input. + :param batch_size: first dimension of the tensor + :param include_token_ids: should we add token_type_ids + :param device: where to store tensors (Pytorch only). One of [cpu, cuda] + :return: a tuple of tensors, Pytorch and numpy + """ + assert device in ["cpu", "cuda"] shape = (batch_size, seq_len) inputs_pytorch: OrderedDict[str, torch.Tensor] = OrderedDict() inputs_pytorch["input_ids"] = torch.randint(high=100, size=shape, dtype=torch.long, device=device) diff --git a/src/transformer_deploy/convert.py b/src/transformer_deploy/convert.py index 65245233..08faf12a 100644 --- a/src/transformer_deploy/convert.py +++ b/src/transformer_deploy/convert.py @@ -44,7 +44,14 @@ def check_accuracy( engine_name: str, pytorch_output: List[np.ndarray], engine_output: List[np.ndarray], tolerance: float -): +) -> None: + """ + Compare engine predictions with a reference. Assert that the difference is under a threshold. + :param engine_name: string used in error message, if any + :param pytorch_output: reference output used for the comparaison + :param engine_output: output from the engine + :param tolerance: if difference in outputs is above threshold, an error will be raised + """ discrepency = compare_outputs(pytorch_output=pytorch_output, engine_output=engine_output) assert discrepency < tolerance, ( f"{engine_name} discrepency is too high ({discrepency:.2f} > {tolerance}):\n" @@ -60,6 +67,13 @@ def check_accuracy( def launch_inference( infer: Callable, inputs: List[Dict[str, Union[np.ndarray, torch.Tensor]]], nb_measures: int ) -> Tuple[List[np.ndarray], List[float]]: + """ + Perform inference and measure latency + :param infer: a lambda which will perform the inference + :param inputs: tensor compatible with the lambda (Torch tensor for Pytorch, or numpy otherwise) + :param nb_measures: number of measures to perform for the latency measure + :return: a tuple of model output and inference latencies + """ assert type(inputs) == list assert len(inputs) > 0 outputs = list() diff --git a/src/transformer_deploy/utils/args.py b/src/transformer_deploy/utils/args.py index 111e70c4..4e8f1536 100644 --- a/src/transformer_deploy/utils/args.py +++ b/src/transformer_deploy/utils/args.py @@ -16,6 +16,11 @@ def parse_args(commands: List[str] = None) -> argparse.Namespace: + """ + Parse command line arguments + :param commands: to provide command line programatically + :return: parsed command line + """ parser = argparse.ArgumentParser( description="optimize and deploy transformers", formatter_class=argparse.ArgumentDefaultsHelpFormatter ) diff --git a/src/transformer_deploy/utils/python_tokenizer.py b/src/transformer_deploy/utils/python_tokenizer.py index 22e84a58..e64f2c61 100644 --- a/src/transformer_deploy/utils/python_tokenizer.py +++ b/src/transformer_deploy/utils/python_tokenizer.py @@ -14,7 +14,7 @@ # noinspection DuplicatedCode import os -from typing import Dict +from typing import Dict, List import numpy as np @@ -27,14 +27,23 @@ class TritonPythonModel: is_tensorrt: bool tokenizer: PreTrainedTokenizer - def initialize(self, args: Dict[str, str]): + def initialize(self, args: Dict[str, str]) -> None: + """ + Initialize the tokenization process + :param args: arguments from Triton config file + """ # more variables in https://github.com/triton-inference-server/python_backend/blob/main/src/python.cc path: str = os.path.join(args["model_repository"], args["model_version"]) model_name: str = args["model_name"] self.is_tensorrt = "tensorrt" in model_name self.tokenizer = AutoTokenizer.from_pretrained(path) - def execute(self, requests): + def execute(self, requests) -> List[List[pb_utils.Tensor]]: + """ + Parse and tokenize each request + :param requests: 1 or more requests received by Triton server. + :return: text as input tensors + """ responses = [] # for loop for batch requests (disabled in our case) for request in requests: diff --git a/tests/test_ast_modifications.py b/tests/test_ast_modifications.py new file mode 100644 index 00000000..7d98f7e9 --- /dev/null +++ b/tests/test_ast_modifications.py @@ -0,0 +1,92 @@ +# Copyright 2021, Lefebvre Sarrut Services +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import ast +import importlib +import logging + +import torch +from torch import nn + +from transformer_deploy.QDQModels.ast_module_patch import add_quant_to_module, list_class_to_patch +from transformer_deploy.QDQModels.ast_operator_patch import Patch2ArgsNode, PatchAdd2ArgsNode + + +class FakeModel(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(in_features=5, out_features=5, bias=True) + + def forward(self, inputs: torch.Tensor): + a: torch.Tensor = self.linear(inputs) + b = torch.ones(a.shape) + c = torch.matmul(a, b) + d = nn.LayerNorm(a + c) + return d + + +expected_class = """ +class QDQFakeModel(nn.Module): + + def __init__(self): + super().__init__() + self.linear = nn.Linear(in_features=5, out_features=5, bias=True) + self.matmul_quantizer_0 = TensorQuantizer(quant_nn.QuantLinear.default_quant_desc_input) + self.matmul_quantizer_1 = TensorQuantizer(quant_nn.QuantLinear.default_quant_desc_input) + self.layernorm_quantizer_2 = TensorQuantizer(quant_nn.QuantLinear.default_quant_desc_input) + self.layernorm_quantizer_3 = TensorQuantizer(quant_nn.QuantLinear.default_quant_desc_input) + + def forward(self, inputs: torch.Tensor): + a: torch.Tensor = self.linear(inputs) + b = torch.ones(a.shape) + c = torch.matmul(self.matmul_quantizer_0(a), self.matmul_quantizer_1(b)) + d = nn.LayerNorm(self.layernorm_quantizer_2(a) + self.layernorm_quantizer_3(c)) + return d +""".strip() + + +def test_list_class(): + model_module = importlib.import_module(name=__name__) + class_to_patch = list_class_to_patch(model_module=model_module) + assert class_to_patch == ["FakeModel"] + + +def test_add_quant(): + head = add_quant_to_module(module_to_patch=FakeModel, new_module_name="QDQFakeModel") + head = ast.fix_missing_locations(head) + logging.error(ast.unparse(head)) + assert ast.unparse(head) == expected_class + + +def test_patch_2_args_node(): + source_code = "torch.matmul(a, b)" + patch = Patch2ArgsNode(op="matmul") + head: ast.AST = ast.parse(source_code).body[0].value + assert patch.should_patch(head) + head_patched = patch.patch(node=head, nb_quant_node=0) + assert ast.unparse(head) == "torch.matmul(self.matmul_quantizer_0(a), self.matmul_quantizer_1(b))" + assert head_patched == ["matmul_quantizer_0", "matmul_quantizer_1"] + + +def test_add_2_args_node(): + source_code = "nn.LayerNorm(hidden_states + input_tensor)" + patch = PatchAdd2ArgsNode(op="LayerNorm") + head: ast.AST = ast.parse(source_code).body[0].value + assert patch.should_patch(head) + head_patched = patch.patch(node=head, nb_quant_node=0) + assert ( + ast.unparse(head) + == "nn.LayerNorm(self.layernorm_quantizer_0(hidden_states) + self.layernorm_quantizer_1(input_tensor))" + ) + assert head_patched == ["layernorm_quantizer_0", "layernorm_quantizer_1"]