diff --git a/guides/video-swin-transformer-kerascv-and-torchvision.ipynb b/guides/video-swin-transformer-kerascv-and-torchvision.ipynb new file mode 100644 index 0000000..248d6c3 --- /dev/null +++ b/guides/video-swin-transformer-kerascv-and-torchvision.ipynb @@ -0,0 +1,1029 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "104894cf", + "metadata": { + "papermill": { + "duration": 0.006219, + "end_time": "2024-03-25T17:36:29.648158", + "exception": false, + "start_time": "2024-03-25T17:36:29.641939", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "# Originated from https://github.com/keras-team/keras-cv/pull/2369" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "706e2a80", + "metadata": { + "_cell_guid": "b1076dfc-b9ad-4769-8c92-a6c4dae69d19", + "_uuid": "8f2839f25d086af736a60e9eeb907d3b93b6e0e5", + "execution": { + "iopub.execute_input": "2024-03-25T17:36:29.662221Z", + "iopub.status.busy": "2024-03-25T17:36:29.661017Z", + "iopub.status.idle": "2024-03-25T17:36:30.600470Z", + "shell.execute_reply": "2024-03-25T17:36:30.599532Z" + }, + "papermill": { + "duration": 0.948976, + "end_time": "2024-03-25T17:36:30.602921", + "exception": false, + "start_time": "2024-03-25T17:36:29.653945", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "import numpy as np # linear algebra\n", + "import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)\n", + "import os\n", + "import warnings" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "dcb8ddf2", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-25T17:36:30.616617Z", + "iopub.status.busy": "2024-03-25T17:36:30.616146Z", + "iopub.status.idle": "2024-03-25T17:36:30.621344Z", + "shell.execute_reply": "2024-03-25T17:36:30.620603Z" + }, + "papermill": { + "duration": 0.014131, + "end_time": "2024-03-25T17:36:30.623244", + "exception": false, + "start_time": "2024-03-25T17:36:30.609113", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "os.environ[\"KERAS_BACKEND\"] = \"torch\" # 'torch', 'tensorflow', 'jax'\n", + "\n", + "warnings.simplefilter(action=\"ignore\")\n", + "os.environ[\"TF_CPP_MIN_LOG_LEVEL\"] = \"3\"" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "6a5d5959", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-25T17:36:30.636389Z", + "iopub.status.busy": "2024-03-25T17:36:30.635887Z", + "iopub.status.idle": "2024-03-25T17:37:02.758550Z", + "shell.execute_reply": "2024-03-25T17:37:02.757028Z" + }, + "papermill": { + "duration": 32.132307, + "end_time": "2024-03-25T17:37:02.761267", + "exception": false, + "start_time": "2024-03-25T17:36:30.628960", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Cloning into 'keras-cv'...\r\n", + "remote: Enumerating objects: 13111, done.\u001b[K\r\n", + "remote: Counting objects: 100% (1248/1248), done.\u001b[K\r\n", + "remote: Compressing objects: 100% (488/488), done.\u001b[K\r\n", + "remote: Total 13111 (delta 861), reused 1044 (delta 748), pack-reused 11863\u001b[K\r\n", + "Receiving objects: 100% (13111/13111), 25.45 MiB | 20.42 MiB/s, done.\r\n", + "Resolving deltas: 100% (9312/9312), done.\r\n", + "/kaggle/working/keras-cv\n" + ] + } + ], + "source": [ + "!git clone --branch video_swin https://github.com/innat/keras-cv.git\n", + "%cd keras-cv\n", + "!pip install -q -e ." + ] + }, + { + "cell_type": "markdown", + "id": "cbeb45fd", + "metadata": { + "papermill": { + "duration": 0.006975, + "end_time": "2024-03-25T17:37:02.775736", + "exception": false, + "start_time": "2024-03-25T17:37:02.768761", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "# KerasCV: Video Swin : Pretrained: ImageNet 1K" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "9803f5a4", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-25T17:37:02.792296Z", + "iopub.status.busy": "2024-03-25T17:37:02.791890Z", + "iopub.status.idle": "2024-03-25T17:37:24.692452Z", + "shell.execute_reply": "2024-03-25T17:37:24.691574Z" + }, + "papermill": { + "duration": 21.911977, + "end_time": "2024-03-25T17:37:24.695081", + "exception": false, + "start_time": "2024-03-25T17:37:02.783104", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "'3.0.5'" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import keras\n", + "from keras import ops\n", + "from keras_cv.models import VideoSwinBackbone\n", + "from keras_cv.models import VideoClassifier\n", + "\n", + "keras.__version__" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "531afe0a", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-25T17:37:24.712504Z", + "iopub.status.busy": "2024-03-25T17:37:24.711588Z", + "iopub.status.idle": "2024-03-25T17:37:34.152330Z", + "shell.execute_reply": "2024-03-25T17:37:34.150808Z" + }, + "papermill": { + "duration": 9.45242, + "end_time": "2024-03-25T17:37:34.155078", + "exception": false, + "start_time": "2024-03-25T17:37:24.702658", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "!wget https://github.com/innat/VideoSwin/releases/download/v2.0/videoswin_tiny_kinetics400.weights.h5 -q\n", + "!wget https://github.com/innat/VideoSwin/releases/download/v2.0/videoswin_tiny_kinetics400_classifier.weights.h5 -q\n", + "\n", + "def vswin_tiny():\n", + " backbone=VideoSwinBackbone(\n", + " input_shape=(32, 224, 224, 3), \n", + " embed_dim=96,\n", + " depths=[2, 2, 6, 2],\n", + " num_heads=[3, 6, 12, 24],\n", + " include_rescaling=False, \n", + " )\n", + " backbone.load_weights(\n", + " '/kaggle/working/keras-cv/videoswin_tiny_kinetics400.weights.h5'\n", + " )\n", + " \n", + " keras_model = VideoClassifier(\n", + " backbone=backbone,\n", + " num_classes=400,\n", + " activation=None,\n", + " pooling='avg',\n", + " )\n", + " keras_model.load_weights(\n", + " '/kaggle/working/keras-cv/videoswin_tiny_kinetics400_classifier.weights.h5'\n", + " )\n", + " return keras_model" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "9e2c445a", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-25T17:37:34.171992Z", + "iopub.status.busy": "2024-03-25T17:37:34.171578Z", + "iopub.status.idle": "2024-03-25T17:39:09.556885Z", + "shell.execute_reply": "2024-03-25T17:39:09.555334Z" + }, + "papermill": { + "duration": 95.39738, + "end_time": "2024-03-25T17:39:09.559788", + "exception": false, + "start_time": "2024-03-25T17:37:34.162408", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "!wget https://github.com/innat/VideoSwin/releases/download/v2.0/videoswin_small_kinetics400.weights.h5 -q\n", + "!wget https://github.com/innat/VideoSwin/releases/download/v2.0/videoswin_small_kinetics400_classifier.weights.h5 -q\n", + "\n", + "def vswin_small():\n", + " backbone=VideoSwinBackbone(\n", + " input_shape=(32, 224, 224, 3), \n", + " embed_dim=96,\n", + " depths=[2, 2, 18, 2],\n", + " num_heads=[3, 6, 12, 24],\n", + " include_rescaling=False, \n", + " )\n", + " backbone.load_weights(\n", + " '/kaggle/working/keras-cv/videoswin_small_kinetics400.weights.h5'\n", + " )\n", + " \n", + " keras_model = VideoClassifier(\n", + " backbone=backbone,\n", + " num_classes=400,\n", + " activation=None,\n", + " pooling='avg',\n", + " )\n", + " keras_model.load_weights(\n", + " '/kaggle/working/keras-cv/videoswin_small_kinetics400_classifier.weights.h5'\n", + " )\n", + " return keras_model" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "fb20a006", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-25T17:39:09.578126Z", + "iopub.status.busy": "2024-03-25T17:39:09.577190Z", + "iopub.status.idle": "2024-03-25T17:39:27.045076Z", + "shell.execute_reply": "2024-03-25T17:39:27.043337Z" + }, + "papermill": { + "duration": 17.479658, + "end_time": "2024-03-25T17:39:27.048029", + "exception": false, + "start_time": "2024-03-25T17:39:09.568371", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "!wget https://github.com/innat/VideoSwin/releases/download/v2.0/videoswin_base_kinetics400.weights.h5 -q\n", + "!wget https://github.com/innat/VideoSwin/releases/download/v2.0/videoswin_base_kinetics400_classifier.weights.h5 -q\n", + "\n", + "def vswin_base():\n", + " backbone=VideoSwinBackbone(\n", + " input_shape=(32, 224, 224, 3), \n", + " embed_dim=128,\n", + " depths=[2, 2, 18, 2],\n", + " num_heads=[4, 8, 16, 32],\n", + " include_rescaling=False, \n", + " )\n", + " backbone.load_weights(\n", + " '/kaggle/working/keras-cv/videoswin_base_kinetics400.weights.h5'\n", + " )\n", + " \n", + " keras_model = VideoClassifier(\n", + " backbone=backbone,\n", + " num_classes=400,\n", + " activation=None,\n", + " pooling='avg',\n", + " )\n", + " keras_model.load_weights(\n", + " '/kaggle/working/keras-cv/videoswin_base_kinetics400_classifier.weights.h5'\n", + " )\n", + " return keras_model\n" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "1d9d05b2", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-25T17:39:27.064796Z", + "iopub.status.busy": "2024-03-25T17:39:27.064362Z", + "iopub.status.idle": "2024-03-25T17:39:35.853874Z", + "shell.execute_reply": "2024-03-25T17:39:35.852721Z" + }, + "papermill": { + "duration": 8.800351, + "end_time": "2024-03-25T17:39:35.855874", + "exception": false, + "start_time": "2024-03-25T17:39:27.055523", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
Model: \"video_classifier\"\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1mModel: \"video_classifier\"\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n", + "┃ Layer (type) ┃ Output Shape ┃ Param # ┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n", + "│ videos (InputLayer) │ (None, 32, 224, 224, │ 0 │\n", + "│ │ 3) │ │\n", + "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", + "│ video_swin_backbone │ (None, 16, 7, 7, 768) │ 27,850,470 │\n", + "│ (VideoSwinBackbone) │ │ │\n", + "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", + "│ avg_pool │ (None, 768) │ 0 │\n", + "│ (GlobalAveragePooling3D) │ │ │\n", + "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", + "│ predictions (Dense) │ (None, 400) │ 307,600 │\n", + "└─────────────────────────────────┴────────────────────────┴───────────────┘\n", + "\n" + ], + "text/plain": [ + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n", + "│ videos (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m32\u001b[0m, \u001b[38;5;34m224\u001b[0m, \u001b[38;5;34m224\u001b[0m, │ \u001b[38;5;34m0\u001b[0m │\n", + "│ │ \u001b[38;5;34m3\u001b[0m) │ │\n", + "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", + "│ video_swin_backbone │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m16\u001b[0m, \u001b[38;5;34m7\u001b[0m, \u001b[38;5;34m7\u001b[0m, \u001b[38;5;34m768\u001b[0m) │ \u001b[38;5;34m27,850,470\u001b[0m │\n", + "│ (\u001b[38;5;33mVideoSwinBackbone\u001b[0m) │ │ │\n", + "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", + "│ avg_pool │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m768\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │\n", + "│ (\u001b[38;5;33mGlobalAveragePooling3D\u001b[0m) │ │ │\n", + "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", + "│ predictions (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m400\u001b[0m) │ \u001b[38;5;34m307,600\u001b[0m │\n", + "└─────────────────────────────────┴────────────────────────┴───────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
Total params: 28,158,070 (107.41 MB)\n", + "\n" + ], + "text/plain": [ + "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m28,158,070\u001b[0m (107.41 MB)\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
Trainable params: 28,158,070 (107.41 MB)\n", + "\n" + ], + "text/plain": [ + "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m28,158,070\u001b[0m (107.41 MB)\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
Non-trainable params: 0 (0.00 B)\n", + "\n" + ], + "text/plain": [ + "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "keras_models = [vswin_tiny(), vswin_small(), vswin_base()]\n", + "keras_models[0].summary()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8533a98c", + "metadata": { + "papermill": { + "duration": 0.00777, + "end_time": "2024-03-25T17:39:35.871851", + "exception": false, + "start_time": "2024-03-25T17:39:35.864081", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "id": "f4f1cc5f", + "metadata": { + "papermill": { + "duration": 0.007748, + "end_time": "2024-03-25T17:39:35.888094", + "exception": false, + "start_time": "2024-03-25T17:39:35.880346", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "# TorchVision: Video Swin : Pretrained: ImageNet 1K" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "e4643e39", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-25T17:39:35.906217Z", + "iopub.status.busy": "2024-03-25T17:39:35.905835Z", + "iopub.status.idle": "2024-03-25T17:39:36.229742Z", + "shell.execute_reply": "2024-03-25T17:39:36.228912Z" + }, + "papermill": { + "duration": 0.335864, + "end_time": "2024-03-25T17:39:36.232060", + "exception": false, + "start_time": "2024-03-25T17:39:35.896196", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "import torch\n", + "import torchvision\n", + "from torchinfo import summary\n", + "from torchvision.models.video import Swin3D_T_Weights, Swin3D_S_Weights, Swin3D_B_Weights\n", + "\n", + "def torch_vswin_tiny():\n", + " torch_model = torchvision.models.video.swin3d_t(\n", + " weights=Swin3D_T_Weights.KINETICS400_V1\n", + " ).eval()\n", + " return torch_model\n", + "\n", + "def torch_vswin_small():\n", + " torch_model = torchvision.models.video.swin3d_s(\n", + " weights=Swin3D_S_Weights.KINETICS400_V1\n", + " ).eval()\n", + " return torch_model\n", + "\n", + "def torch_vswin_base():\n", + " torch_model = torchvision.models.video.swin3d_b(\n", + " weights=Swin3D_B_Weights.KINETICS400_V1\n", + " ).eval()\n", + " return torch_model\n" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "def75d23", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-25T17:39:36.250575Z", + "iopub.status.busy": "2024-03-25T17:39:36.249663Z", + "iopub.status.idle": "2024-03-25T17:40:00.880847Z", + "shell.execute_reply": "2024-03-25T17:40:00.879762Z" + }, + "papermill": { + "duration": 24.642887, + "end_time": "2024-03-25T17:40:00.883098", + "exception": false, + "start_time": "2024-03-25T17:39:36.240211", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Downloading: \"https://download.pytorch.org/models/swin3d_t-7615ae03.pth\" to /root/.cache/torch/hub/checkpoints/swin3d_t-7615ae03.pth\n", + "100%|██████████| 122M/122M [00:00<00:00, 146MB/s]\n", + "Downloading: \"https://download.pytorch.org/models/swin3d_s-da41c237.pth\" to /root/.cache/torch/hub/checkpoints/swin3d_s-da41c237.pth\n", + "100%|██████████| 218M/218M [00:01<00:00, 144MB/s]\n", + "Downloading: \"https://download.pytorch.org/models/swin3d_b_1k-24f7c7c6.pth\" to /root/.cache/torch/hub/checkpoints/swin3d_b_1k-24f7c7c6.pth\n", + "100%|██████████| 364M/364M [00:09<00:00, 38.6MB/s]\n" + ] + }, + { + "data": { + "text/plain": [ + "=========================================================================================================\n", + "Layer (type:depth-idx) Output Shape Param #\n", + "=========================================================================================================\n", + "SwinTransformer3d [1, 400] --\n", + "├─PatchEmbed3d: 1-1 [1, 16, 56, 56, 96] --\n", + "│ └─Conv3d: 2-1 [1, 96, 16, 56, 56] 9,312\n", + "│ └─LayerNorm: 2-2 [1, 16, 56, 56, 96] 192\n", + "├─Dropout: 1-2 [1, 16, 56, 56, 96] --\n", + "├─Sequential: 1-3 [1, 16, 7, 7, 768] --\n", + "│ └─Sequential: 2-3 [1, 16, 56, 56, 96] --\n", + "│ │ └─SwinTransformerBlock: 3-1 [1, 16, 56, 56, 96] 119,445\n", + "│ │ └─SwinTransformerBlock: 3-2 [1, 16, 56, 56, 96] 119,445\n", + "│ └─PatchMerging: 2-4 [1, 16, 28, 28, 192] --\n", + "│ │ └─LayerNorm: 3-3 [1, 16, 28, 28, 384] 768\n", + "│ │ └─Linear: 3-4 [1, 16, 28, 28, 192] 73,728\n", + "│ └─Sequential: 2-5 [1, 16, 28, 28, 192] --\n", + "│ │ └─SwinTransformerBlock: 3-5 [1, 16, 28, 28, 192] 460,074\n", + "│ │ └─SwinTransformerBlock: 3-6 [1, 16, 28, 28, 192] 460,074\n", + "│ └─PatchMerging: 2-6 [1, 16, 14, 14, 384] --\n", + "│ │ └─LayerNorm: 3-7 [1, 16, 14, 14, 768] 1,536\n", + "│ │ └─Linear: 3-8 [1, 16, 14, 14, 384] 294,912\n", + "│ └─Sequential: 2-7 [1, 16, 14, 14, 384] --\n", + "│ │ └─SwinTransformerBlock: 3-9 [1, 16, 14, 14, 384] 1,804,884\n", + "│ │ └─SwinTransformerBlock: 3-10 [1, 16, 14, 14, 384] 1,804,884\n", + "│ │ └─SwinTransformerBlock: 3-11 [1, 16, 14, 14, 384] 1,804,884\n", + "│ │ └─SwinTransformerBlock: 3-12 [1, 16, 14, 14, 384] 1,804,884\n", + "│ │ └─SwinTransformerBlock: 3-13 [1, 16, 14, 14, 384] 1,804,884\n", + "│ │ └─SwinTransformerBlock: 3-14 [1, 16, 14, 14, 384] 1,804,884\n", + "│ └─PatchMerging: 2-8 [1, 16, 7, 7, 768] --\n", + "│ │ └─LayerNorm: 3-15 [1, 16, 7, 7, 1536] 3,072\n", + "│ │ └─Linear: 3-16 [1, 16, 7, 7, 768] 1,179,648\n", + "│ └─Sequential: 2-9 [1, 16, 7, 7, 768] --\n", + "│ │ └─SwinTransformerBlock: 3-17 [1, 16, 7, 7, 768] 7,148,712\n", + "│ │ └─SwinTransformerBlock: 3-18 [1, 16, 7, 7, 768] 7,148,712\n", + "├─LayerNorm: 1-4 [1, 16, 7, 7, 768] 1,536\n", + "├─AdaptiveAvgPool3d: 1-5 [1, 768, 1, 1, 1] --\n", + "├─Linear: 1-6 [1, 400] 307,600\n", + "=========================================================================================================\n", + "Total params: 28,158,070\n", + "Trainable params: 28,158,070\n", + "Non-trainable params: 0\n", + "Total mult-adds (M): 486.39\n", + "=========================================================================================================\n", + "Input size (MB): 19.27\n", + "Forward/backward pass size (MB): 1464.34\n", + "Params size (MB): 76.66\n", + "Estimated Total Size (MB): 1560.26\n", + "=========================================================================================================" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "torch_models = [torch_vswin_tiny(), torch_vswin_small(), torch_vswin_base()]\n", + "summary(\n", + " torch_models[0], input_size=(1, 3, 32, 224, 224)\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "756df6fc", + "metadata": { + "papermill": { + "duration": 0.015792, + "end_time": "2024-03-25T17:40:00.914650", + "exception": false, + "start_time": "2024-03-25T17:40:00.898858", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "# Inference" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "746bff71", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-25T17:40:00.948422Z", + "iopub.status.busy": "2024-03-25T17:40:00.948038Z", + "iopub.status.idle": "2024-03-25T17:40:01.109399Z", + "shell.execute_reply": "2024-03-25T17:40:01.108088Z" + }, + "papermill": { + "duration": 0.180746, + "end_time": "2024-03-25T17:40:01.111946", + "exception": false, + "start_time": "2024-03-25T17:40:00.931200", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([1, 32, 224, 224, 3]) torch.Size([1, 3, 32, 224, 224])\n" + ] + } + ], + "source": [ + "common_input = np.random.normal(0, 1, (1, 32, 224, 224, 3)).astype('float32')\n", + "keras_input = ops.array(common_input)\n", + "torch_input = torch.from_numpy(common_input.transpose(0, 4, 1, 2, 3))\n", + "print(keras_input.shape, torch_input.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "c4defac0", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-25T17:40:01.146510Z", + "iopub.status.busy": "2024-03-25T17:40:01.146105Z", + "iopub.status.idle": "2024-03-25T17:40:01.152185Z", + "shell.execute_reply": "2024-03-25T17:40:01.151091Z" + }, + "papermill": { + "duration": 0.025902, + "end_time": "2024-03-25T17:40:01.154320", + "exception": false, + "start_time": "2024-03-25T17:40:01.128418", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "def logit_checking(keras_model, torch_model):\n", + " # forward pass\n", + " keras_predict = keras_model(keras_input)\n", + " torch_predict = torch_model(torch_input)\n", + " \n", + " print(keras_predict.shape, torch_predict.shape)\n", + " print('keras logits: ', keras_predict[0, :5])\n", + " print('torch logits: ', torch_predict[0, :5], end='\\n')\n", + " \n", + " np.testing.assert_allclose(\n", + " keras_predict.detach().numpy(),\n", + " torch_predict.detach().numpy(),\n", + " 1e-5, 1e-5\n", + " )\n", + "\n", + " del keras_model \n", + " del torch_model" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "33166bae", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-25T17:40:01.187535Z", + "iopub.status.busy": "2024-03-25T17:40:01.186732Z", + "iopub.status.idle": "2024-03-25T17:41:29.704036Z", + "shell.execute_reply": "2024-03-25T17:41:29.702590Z" + }, + "papermill": { + "duration": 88.536879, + "end_time": "2024-03-25T17:41:29.706850", + "exception": false, + "start_time": "2024-03-25T17:40:01.169971", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([1, 400]) torch.Size([1, 400])\n", + "keras logits: tensor([-0.1536, 1.2143, 1.1613, -0.3310, -1.5233], grad_fn=