diff --git a/guides/jax_videoswin_video_classification.ipynb b/guides/jax_videoswin_video_classification.ipynb index a14569b..12e8f0b 100644 --- a/guides/jax_videoswin_video_classification.ipynb +++ b/guides/jax_videoswin_video_classification.ipynb @@ -165,6 +165,14 @@ "HOME" ] }, + { + "cell_type": "markdown", + "id": "d3ae2357", + "metadata": {}, + "source": [ + "It supports both single and multi-gpu training." + ] + }, { "cell_type": "code", "execution_count": 4, @@ -1242,11 +1250,11 @@ "description": "", "description_tooltip": null, "layout": "IPY_MODEL_784aa4850b694bc8b1bde38e2ff2fe2a", - "max": 532195102.0, - "min": 0.0, + "max": 532195102, + "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_1ca67d7949b54e7c844c153980f2723d", - "value": 532195102.0 + "value": 532195102 } }, "1ca67d7949b54e7c844c153980f2723d": { diff --git a/guides/k600-ssv2-logit-matching-torch-vs-keras-cv-backbone.ipynb b/guides/k600-ssv2-logit-matching-torch-vs-keras-cv-backbone.ipynb new file mode 100644 index 0000000..57043e5 --- /dev/null +++ b/guides/k600-ssv2-logit-matching-torch-vs-keras-cv-backbone.ipynb @@ -0,0 +1,1857 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "b71a4667", + "metadata": { + "_cell_guid": "b1076dfc-b9ad-4769-8c92-a6c4dae69d19", + "_uuid": "8f2839f25d086af736a60e9eeb907d3b93b6e0e5", + "execution": { + "iopub.execute_input": "2024-03-31T16:27:53.343125Z", + "iopub.status.busy": "2024-03-31T16:27:53.341887Z", + "iopub.status.idle": "2024-03-31T16:28:15.976258Z", + "shell.execute_reply": "2024-03-31T16:28:15.975233Z" + }, + "papermill": { + "duration": 22.646472, + "end_time": "2024-03-31T16:28:15.978789", + "exception": false, + "start_time": "2024-03-31T16:27:53.332317", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "!pip install einops -q\n", + "import logging\n", + "from functools import partial\n", + "import gc\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "import torch.utils.checkpoint as checkpoint\n", + "import numpy as np\n", + "from timm.models.layers import trunc_normal_ \n", + "\n", + "from functools import reduce, lru_cache\n", + "from operator import mul\n", + "from einops import rearrange\n", + "import logging" + ] + }, + { + "cell_type": "markdown", + "id": "18d5bd06", + "metadata": { + "papermill": { + "duration": 0.006646, + "end_time": "2024-03-31T16:28:15.992867", + "exception": false, + "start_time": "2024-03-31T16:28:15.986221", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "# Video Swin Model [PyTorch]" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "000c7886", + "metadata": { + "_kg_hide-input": true, + "execution": { + "iopub.execute_input": "2024-03-31T16:28:16.009192Z", + "iopub.status.busy": "2024-03-31T16:28:16.008839Z", + "iopub.status.idle": "2024-03-31T16:28:16.095208Z", + "shell.execute_reply": "2024-03-31T16:28:16.093560Z" + }, + "jupyter": { + "source_hidden": true + }, + "papermill": { + "duration": 0.098049, + "end_time": "2024-03-31T16:28:16.098022", + "exception": false, + "start_time": "2024-03-31T16:28:15.999973", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):\n", + " \"\"\"Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).\n", + " This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,\n", + " 
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...\n", + " See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for\n", + " changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use\n", + " 'survival rate' as the argument.\n", + " \"\"\"\n", + " if drop_prob == 0. or not training:\n", + " return x\n", + " keep_prob = 1 - drop_prob\n", + " shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets\n", + " random_tensor = x.new_empty(shape).bernoulli_(keep_prob)\n", + " if keep_prob > 0.0 and scale_by_keep:\n", + " random_tensor.div_(keep_prob)\n", + " return x * random_tensor\n", + "\n", + "class DropPath(nn.Module):\n", + " \"\"\"Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).\n", + " \"\"\"\n", + " def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):\n", + " super(DropPath, self).__init__()\n", + " self.drop_prob = drop_prob\n", + " self.scale_by_keep = scale_by_keep\n", + "\n", + " def forward(self, x):\n", + " return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)\n", + "\n", + " def extra_repr(self):\n", + " return f'drop_prob={round(self.drop_prob,3):0.3f}'\n", + " \n", + " \n", + "def get_root_logger(log_file=None, log_level=logging.INFO):\n", + " \"\"\"Use ``get_logger`` method in mmcv to get the root logger.\n", + " The logger will be initialized if it has not been initialized. By default a\n", + " StreamHandler will be added. If ``log_file`` is specified, a FileHandler\n", + " will also be added. The name of the root logger is the top-level package\n", + " name, e.g., \"mmaction\".\n", + " Args:\n", + " log_file (str | None): The log filename. If specified, a FileHandler\n", + " will be added to the root logger.\n", + " log_level (int): The root logger level. 
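To make the stochastic-depth behaviour described above concrete, a minimal sketch (assuming the `DropPath` layer defined in this cell is in scope; the shapes and drop probability are illustrative) shows that the layer is an identity in eval mode and, in training mode, zeroes whole samples and rescales the survivors by `1 / keep_prob`:

```python
import torch

layer = DropPath(drop_prob=0.2)      # keep_prob = 0.8
x = torch.ones(4, 3)                 # 4 samples; any trailing shape works

layer.eval()
print(layer(x))                      # identity: all ones

layer.train()
print(layer(x))                      # each row is either all zeros (dropped path)
                                     # or scaled by 1 / 0.8 = 1.25
```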
Note that only the process of\n", + " rank 0 is affected, while other processes will set the level to\n", + " \"Error\" and be silent most of the time.\n", + " Returns:\n", + " :obj:`logging.Logger`: The root logger.\n", + " \"\"\"\n", + " return get_logger(__name__.split('.')[0], log_file, log_level)\n", + "\n", + "\n", + "class Mlp(nn.Module):\n", + " \"\"\" Multilayer perceptron.\"\"\"\n", + "\n", + " def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):\n", + " super().__init__()\n", + " out_features = out_features or in_features\n", + " hidden_features = hidden_features or in_features\n", + " self.fc1 = nn.Linear(in_features, hidden_features)\n", + " self.act = act_layer()\n", + " self.fc2 = nn.Linear(hidden_features, out_features)\n", + " self.drop = nn.Dropout(drop)\n", + "\n", + " def forward(self, x):\n", + " x = self.fc1(x)\n", + " x = self.act(x)\n", + " x = self.drop(x)\n", + " x = self.fc2(x)\n", + " x = self.drop(x)\n", + " return x\n", + " \n", + " \n", + "def window_partition(x, window_size):\n", + " \"\"\"\n", + " Args:\n", + " x: (B, D, H, W, C)\n", + " window_size (tuple[int]): window size\n", + "\n", + " Returns:\n", + " windows: (B*num_windows, window_size*window_size, C)\n", + " \"\"\"\n", + " B, D, H, W, C = x.shape\n", + " x = x.view(\n", + " B, \n", + " D // window_size[0], window_size[0], \n", + " H // window_size[1], window_size[1], \n", + " W // window_size[2], window_size[2], \n", + " C\n", + " )\n", + " windows = x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous().view(-1, reduce(mul, window_size), C)\n", + " return windows\n", + "\n", + "\n", + "def window_reverse(windows, window_size, B, D, H, W):\n", + " \"\"\"\n", + " Args:\n", + " windows: (B*num_windows, window_size, window_size, C)\n", + " window_size (tuple[int]): Window size\n", + " H (int): Height of image\n", + " W (int): Width of image\n", + "\n", + " Returns:\n", + " x: (B, D, H, W, C)\n", + " \"\"\"\n", + " x = windows.view(\n", + " B, \n", + " D // window_size[0], \n", + " H // window_size[1], \n", + " W // window_size[2], \n", + " window_size[0], \n", + " window_size[1], \n", + " window_size[2], \n", + " -1\n", + " )\n", + " x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous().view(B, D, H, W, -1)\n", + " return x\n", + "\n", + "\n", + "def get_window_size(x_size, window_size, shift_size=None):\n", + " use_window_size = list(window_size)\n", + " if shift_size is not None:\n", + " use_shift_size = list(shift_size)\n", + " for i in range(len(x_size)):\n", + " if x_size[i] <= window_size[i]:\n", + " use_window_size[i] = x_size[i]\n", + " if shift_size is not None:\n", + " use_shift_size[i] = 0\n", + " if shift_size is None:\n", + " return tuple(use_window_size)\n", + " else:\n", + " return tuple(use_window_size), tuple(use_shift_size)\n", + " \n", + " \n", + "class WindowAttention3D(nn.Module):\n", + " \"\"\" Window based multi-head self attention (W-MSA) module with relative position bias.\n", + " It supports both of shifted and non-shifted window.\n", + " Args:\n", + " dim (int): Number of input channels.\n", + " window_size (tuple[int]): The temporal length, height and width of the window.\n", + " num_heads (int): Number of attention heads.\n", + " qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True\n", + " qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set\n", + " attn_drop (float, optional): Dropout ratio of attention weight. 
Default: 0.0\n", + " proj_drop (float, optional): Dropout ratio of output. Default: 0.0\n", + " \"\"\"\n", + "\n", + " def __init__(self, dim, window_size, num_heads, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):\n", + "\n", + " super().__init__()\n", + " self.dim = dim\n", + " self.window_size = window_size # Wd, Wh, Ww\n", + " self.num_heads = num_heads\n", + " head_dim = dim // num_heads\n", + " self.scale = qk_scale or head_dim ** -0.5\n", + "\n", + " # define a parameter table of relative position bias\n", + " self.relative_position_bias_table = nn.Parameter(\n", + " torch.zeros(\n", + " (2 * window_size[0] - 1) * (2 * window_size[1] - 1) * (2 * window_size[2] - 1), num_heads)\n", + " ) # 2*Wd-1 * 2*Wh-1 * 2*Ww-1, nH\n", + "\n", + " # get pair-wise relative position index for each token inside the window\n", + " coords_d = torch.arange(self.window_size[0])\n", + " coords_h = torch.arange(self.window_size[1])\n", + " coords_w = torch.arange(self.window_size[2])\n", + " coords = torch.stack(torch.meshgrid(coords_d, coords_h, coords_w)) # 3, Wd, Wh, Ww\n", + " coords_flatten = torch.flatten(coords, 1) # 3, Wd*Wh*Ww\n", + " relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 3, Wd*Wh*Ww, Wd*Wh*Ww\n", + " \n", + " relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wd*Wh*Ww, Wd*Wh*Ww, 3\n", + " relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0\n", + " relative_coords[:, :, 1] += self.window_size[1] - 1\n", + " relative_coords[:, :, 2] += self.window_size[2] - 1\n", + " relative_coords[:, :, 0] *= (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1)\n", + " relative_coords[:, :, 1] *= (2 * self.window_size[2] - 1)\n", + " relative_position_index = relative_coords.sum(-1) # Wd*Wh*Ww, Wd*Wh*Ww\n", + "\n", + " self.register_buffer(\"relative_position_index\", relative_position_index)\n", + " self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)\n", + " self.attn_drop = nn.Dropout(attn_drop)\n", + " self.proj = nn.Linear(dim, dim)\n", + " self.proj_drop = nn.Dropout(proj_drop)\n", + "\n", + " trunc_normal_(self.relative_position_bias_table, std=.02)\n", + " self.softmax = nn.Softmax(dim=-1)\n", + "\n", + " def forward(self, x, mask=None):\n", + " \"\"\" Forward function.\n", + " Args:\n", + " x: input features with shape of (num_windows*B, N, C)\n", + " mask: (0/-inf) mask with shape of (num_windows, N, N) or None\n", + " \"\"\"\n", + " B_, N, C = x.shape\n", + " qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)\n", + " q, k, v = qkv[0], qkv[1], qkv[2] # B_, nH, N, C\n", + " q = q * self.scale\n", + " attn = q @ k.transpose(-2, -1)\n", + "\n", + " relative_position_bias = self.relative_position_bias_table[\n", + " self.relative_position_index[:N, :N].reshape(-1)\n", + " ].reshape(N, N, -1) # Wd*Wh*Ww,Wd*Wh*Ww,nH\n", + " relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wd*Wh*Ww, Wd*Wh*Ww\n", + " attn = attn + relative_position_bias.unsqueeze(0) # B_, nH, N, N\n", + " \n", + " if mask is not None:\n", + " nW = mask.shape[0]\n", + " attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)\n", + " attn = attn.view(-1, self.num_heads, N, N)\n", + " attn = self.softmax(attn)\n", + " else:\n", + " attn = self.softmax(attn)\n", + "\n", + " attn = self.attn_drop(attn)\n", + " x = (attn @ v).transpose(1, 2).reshape(B_, N, C)\n", + " x = self.proj(x)\n", + " x = self.proj_drop(x)\n", + " \n", + " return x\n", + " 
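A quick sanity check of the 3D windowing helpers used by the attention module can be reassuring; the sketch below (shapes chosen only so that every dimension divides the window size) verifies that `window_partition` followed by `window_reverse` reproduces the input exactly:

```python
import torch

B, D, H, W, C = 1, 4, 14, 14, 32
window_size = (2, 7, 7)

x = torch.randn(B, D, H, W, C)
windows = window_partition(x, window_size)            # (B*num_windows, 2*7*7, C)
print(windows.shape)                                   # torch.Size([8, 98, 32])

y = window_reverse(windows, window_size, B, D, H, W)   # back to (B, D, H, W, C)
assert torch.equal(x, y)                               # exact roundtrip
```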
\n", + " \n", + "class SwinTransformerBlock3D(nn.Module):\n", + " \"\"\" Swin Transformer Block.\n", + "\n", + " Args:\n", + " dim (int): Number of input channels.\n", + " num_heads (int): Number of attention heads.\n", + " window_size (tuple[int]): Window size.\n", + " shift_size (tuple[int]): Shift size for SW-MSA.\n", + " mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.\n", + " qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True\n", + " qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.\n", + " drop (float, optional): Dropout rate. Default: 0.0\n", + " attn_drop (float, optional): Attention dropout rate. Default: 0.0\n", + " drop_path (float, optional): Stochastic depth rate. Default: 0.0\n", + " act_layer (nn.Module, optional): Activation layer. Default: nn.GELU\n", + " norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm\n", + " \"\"\"\n", + "\n", + " def __init__(self, dim, num_heads, window_size=(2,7,7), shift_size=(0,0,0),\n", + " mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,\n", + " act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_checkpoint=False):\n", + " super().__init__()\n", + " self.dim = dim\n", + " self.num_heads = num_heads\n", + " self.window_size = window_size\n", + " self.shift_size = shift_size\n", + " self.mlp_ratio = mlp_ratio\n", + " self.use_checkpoint=use_checkpoint\n", + "\n", + " assert 0 <= self.shift_size[0] < self.window_size[0], \"shift_size must in 0-window_size\"\n", + " assert 0 <= self.shift_size[1] < self.window_size[1], \"shift_size must in 0-window_size\"\n", + " assert 0 <= self.shift_size[2] < self.window_size[2], \"shift_size must in 0-window_size\"\n", + "\n", + " self.norm1 = norm_layer(dim)\n", + " self.attn = WindowAttention3D(\n", + " dim, window_size=self.window_size, num_heads=num_heads,\n", + " qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)\n", + "\n", + " self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity()\n", + " self.norm2 = norm_layer(dim)\n", + " mlp_hidden_dim = int(dim * mlp_ratio)\n", + " self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)\n", + "\n", + " def forward_part1(self, x, mask_matrix):\n", + " B, D, H, W, C = x.shape\n", + " window_size, shift_size = get_window_size((D, H, W), self.window_size, self.shift_size)\n", + "\n", + " x = self.norm1(x)\n", + " # pad feature maps to multiples of window size\n", + " pad_l = pad_t = pad_d0 = 0\n", + " pad_d1 = (window_size[0] - D % window_size[0]) % window_size[0]\n", + " pad_b = (window_size[1] - H % window_size[1]) % window_size[1]\n", + " pad_r = (window_size[2] - W % window_size[2]) % window_size[2]\n", + " x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b, pad_d0, pad_d1))\n", + " _, Dp, Hp, Wp, _ = x.shape\n", + " # cyclic shift\n", + " if any(i > 0 for i in shift_size):\n", + " shifted_x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1], -shift_size[2]), dims=(1, 2, 3))\n", + " attn_mask = mask_matrix\n", + " else:\n", + " shifted_x = x\n", + " attn_mask = None\n", + " # partition windows\n", + " x_windows = window_partition(shifted_x, window_size) # B*nW, Wd*Wh*Ww, C\n", + " # W-MSA/SW-MSA\n", + " attn_windows = self.attn(x_windows, mask=attn_mask) # B*nW, Wd*Wh*Ww, C\n", + " # merge windows\n", + " attn_windows = attn_windows.view(-1, *(window_size+(C,)))\n", + " shifted_x = window_reverse(attn_windows, window_size, B, Dp, Hp, Wp) # B D' H' W' C\n", + " # reverse cyclic shift\n", + " if any(i > 0 for i in shift_size):\n", + " x = torch.roll(shifted_x, shifts=(shift_size[0], shift_size[1], shift_size[2]), dims=(1, 2, 3))\n", + " else:\n", + " x = shifted_x\n", + "\n", + " if pad_d1 >0 or pad_r > 0 or pad_b > 0:\n", + " x = x[:, :D, :H, :W, :].contiguous()\n", + " return x\n", + "\n", + " def forward_part2(self, x):\n", + " return self.drop_path(self.mlp(self.norm2(x)))\n", + "\n", + " def forward(self, x, mask_matrix):\n", + " \"\"\" Forward function.\n", + "\n", + " Args:\n", + " x: Input feature, tensor size (B, D, H, W, C).\n", + " mask_matrix: Attention mask for cyclic shift.\n", + " \"\"\"\n", + " \n", + " shortcut = x\n", + " if self.use_checkpoint:\n", + " x = checkpoint.checkpoint(self.forward_part1, x, mask_matrix)\n", + " else:\n", + " x = self.forward_part1(x, mask_matrix)\n", + "\n", + " x = shortcut + self.drop_path(x)\n", + "\n", + " if self.use_checkpoint:\n", + " x = x + checkpoint.checkpoint(self.forward_part2, x)\n", + " else:\n", + " x = x + self.forward_part2(x)\n", + "\n", + " return x\n", + " \n", + " \n", + "class PatchMerging(nn.Module):\n", + " \"\"\" Patch Merging Layer\n", + "\n", + " Args:\n", + " dim (int): Number of input channels.\n", + " norm_layer (nn.Module, optional): Normalization layer. 
Default: nn.LayerNorm\n", + " \"\"\"\n", + " def __init__(self, dim, norm_layer=nn.LayerNorm):\n", + " super().__init__()\n", + " self.dim = dim\n", + " self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)\n", + " self.norm = norm_layer(4 * dim)\n", + "\n", + " def forward(self, x):\n", + " \"\"\" Forward function.\n", + "\n", + " Args:\n", + " x: Input feature, tensor size (B, D, H, W, C).\n", + " \"\"\"\n", + " B, D, H, W, C = x.shape\n", + "\n", + " # padding\n", + " pad_input = (H % 2 == 1) or (W % 2 == 1)\n", + " if pad_input:\n", + " x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))\n", + "\n", + " x0 = x[:, :, 0::2, 0::2, :] # B D H/2 W/2 C\n", + " x1 = x[:, :, 1::2, 0::2, :] # B D H/2 W/2 C\n", + " x2 = x[:, :, 0::2, 1::2, :] # B D H/2 W/2 C\n", + " x3 = x[:, :, 1::2, 1::2, :] # B D H/2 W/2 C\n", + " x = torch.cat([x0, x1, x2, x3], -1) # B D H/2 W/2 4*C\n", + "\n", + " x = self.norm(x)\n", + " x = self.reduction(x)\n", + "\n", + " return x\n", + " \n", + " \n", + "def compute_mask(D, H, W, window_size, shift_size, device):\n", + " img_mask = torch.zeros((1, D, H, W, 1), device=device) # 1 Dp Hp Wp 1\n", + " cnt = 0\n", + " for d in slice(-window_size[0]), slice(-window_size[0], -shift_size[0]), slice(-shift_size[0],None):\n", + " for h in slice(-window_size[1]), slice(-window_size[1], -shift_size[1]), slice(-shift_size[1],None):\n", + " for w in slice(-window_size[2]), slice(-window_size[2], -shift_size[2]), slice(-shift_size[2],None):\n", + " img_mask[:, d, h, w, :] = cnt\n", + " cnt += 1\n", + " mask_windows = window_partition(img_mask, window_size) # nW, ws[0]*ws[1]*ws[2], 1\n", + " mask_windows = mask_windows.squeeze(-1) # nW, ws[0]*ws[1]*ws[2]\n", + " attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)\n", + " attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))\n", + " return attn_mask\n", + "\n", + "class BasicLayer(nn.Module):\n", + " \"\"\" A basic Swin Transformer layer for one stage.\n", + "\n", + " Args:\n", + " dim (int): Number of feature channels\n", + " depth (int): Depths of this stage.\n", + " num_heads (int): Number of attention head.\n", + " window_size (tuple[int]): Local window size. Default: (1,7,7).\n", + " mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.\n", + " qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True\n", + " qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.\n", + " drop (float, optional): Dropout rate. Default: 0.0\n", + " attn_drop (float, optional): Attention dropout rate. Default: 0.0\n", + " drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0\n", + " norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm\n", + " downsample (nn.Module | None, optional): Downsample layer at the end of the layer. 
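`PatchMerging` is the downsample layer referred to here: it halves the spatial resolution, doubles the channel width, and leaves the temporal axis untouched. A short shape check (illustrative dimensions, assuming the class defined above) makes that contract explicit:

```python
import torch

merge = PatchMerging(dim=96)
x = torch.randn(1, 8, 56, 56, 96)    # (B, D, H, W, C)
print(merge(x).shape)                 # torch.Size([1, 8, 28, 28, 192]): H, W halved; C doubled; D kept
```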
Default: None\n", + " \"\"\"\n", + "\n", + " def __init__(self,\n", + " dim,\n", + " depth,\n", + " num_heads,\n", + " window_size=(1,7,7),\n", + " mlp_ratio=4.,\n", + " qkv_bias=False,\n", + " qk_scale=None,\n", + " drop=0.,\n", + " attn_drop=0.,\n", + " drop_path=0.,\n", + " norm_layer=nn.LayerNorm,\n", + " downsample=None,\n", + " use_checkpoint=False):\n", + " super().__init__()\n", + " self.window_size = window_size\n", + " self.shift_size = tuple(i // 2 for i in window_size)\n", + " self.depth = depth\n", + " self.use_checkpoint = use_checkpoint\n", + "\n", + " # build blocks\n", + " self.blocks = nn.ModuleList([\n", + " SwinTransformerBlock3D(\n", + " dim=dim,\n", + " num_heads=num_heads,\n", + " window_size=window_size,\n", + " shift_size=(0,0,0) if (i % 2 == 0) else self.shift_size,\n", + " mlp_ratio=mlp_ratio,\n", + " qkv_bias=qkv_bias,\n", + " qk_scale=qk_scale,\n", + " drop=drop,\n", + " attn_drop=attn_drop,\n", + " drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,\n", + " norm_layer=norm_layer,\n", + " use_checkpoint=use_checkpoint,\n", + " )\n", + " for i in range(depth)])\n", + " \n", + " self.downsample = downsample\n", + " if self.downsample is not None:\n", + " self.downsample = downsample(dim=dim, norm_layer=norm_layer)\n", + "\n", + " def forward(self, x):\n", + " \"\"\" Forward function.\n", + "\n", + " Args:\n", + " x: Input feature, tensor size (B, C, D, H, W).\n", + " \"\"\"\n", + " # calculate attention mask for SW-MSA\n", + " B, C, D, H, W = x.shape\n", + " window_size, shift_size = get_window_size((D,H,W), self.window_size, self.shift_size)\n", + " x = rearrange(x, 'b c d h w -> b d h w c')\n", + " Dp = int(np.ceil(D / window_size[0])) * window_size[0]\n", + " Hp = int(np.ceil(H / window_size[1])) * window_size[1]\n", + " Wp = int(np.ceil(W / window_size[2])) * window_size[2]\n", + " attn_mask = compute_mask(Dp, Hp, Wp, window_size, shift_size, x.device)\n", + " \n", + "\n", + " for blk in self.blocks:\n", + " x = blk(x, attn_mask)\n", + " x = x.view(B, D, H, W, -1)\n", + "\n", + " if self.downsample is not None:\n", + " x = self.downsample(x)\n", + " \n", + " x = rearrange(x, 'b d h w c -> b c d h w')\n", + " return x\n", + " \n", + "class PatchEmbed3D(nn.Module):\n", + " \"\"\" Video to Patch Embedding.\n", + "\n", + " Args:\n", + " patch_size (int): Patch token size. Default: (2,4,4).\n", + " in_chans (int): Number of input video channels. Default: 3.\n", + " embed_dim (int): Number of linear projection output channels. Default: 96.\n", + " norm_layer (nn.Module, optional): Normalization layer. 
Default: None\n", + " \"\"\"\n", + " def __init__(self, patch_size=(2,4,4), in_chans=3, embed_dim=96, norm_layer=None):\n", + " super().__init__()\n", + " self.patch_size = patch_size\n", + "\n", + " self.in_chans = in_chans\n", + " self.embed_dim = embed_dim\n", + "\n", + " self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)\n", + " if norm_layer is not None:\n", + " self.norm = norm_layer(embed_dim)\n", + " else:\n", + " self.norm = None\n", + "\n", + " def forward(self, x):\n", + " \"\"\"Forward function.\"\"\"\n", + " # padding\n", + " _, _, D, H, W = x.size()\n", + " if W % self.patch_size[2] != 0:\n", + " x = F.pad(x, (0, self.patch_size[2] - W % self.patch_size[2]))\n", + " if H % self.patch_size[1] != 0:\n", + " x = F.pad(x, (0, 0, 0, self.patch_size[1] - H % self.patch_size[1]))\n", + " if D % self.patch_size[0] != 0:\n", + " x = F.pad(x, (0, 0, 0, 0, 0, self.patch_size[0] - D % self.patch_size[0]))\n", + " \n", + " x = self.proj(x) # B C D Wh Ww\n", + "\n", + " if self.norm is not None:\n", + " D, Wh, Ww = x.size(2), x.size(3), x.size(4)\n", + " x = x.flatten(2).transpose(1, 2)\n", + " x = self.norm(x)\n", + " x = x.transpose(1, 2).view(-1, self.embed_dim, D, Wh, Ww)\n", + "\n", + " return x\n", + " \n", + " \n", + "class SwinTransformer3D(nn.Module):\n", + " \"\"\" Swin Transformer backbone.\n", + " A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -\n", + " https://arxiv.org/pdf/2103.14030\n", + "\n", + " Args:\n", + " patch_size (int | tuple(int)): Patch size. Default: (4,4,4).\n", + " in_chans (int): Number of input image channels. Default: 3.\n", + " embed_dim (int): Number of linear projection output channels. Default: 96.\n", + " depths (tuple[int]): Depths of each Swin Transformer stage.\n", + " num_heads (tuple[int]): Number of attention head of each stage.\n", + " window_size (int): Window size. Default: 7.\n", + " mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.\n", + " qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: Truee\n", + " qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.\n", + " drop_rate (float): Dropout rate.\n", + " attn_drop_rate (float): Attention dropout rate. Default: 0.\n", + " drop_path_rate (float): Stochastic depth rate. Default: 0.2.\n", + " norm_layer: Normalization layer. Default: nn.LayerNorm.\n", + " patch_norm (bool): If True, add normalization after patch embedding. 
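Since the patch-embedding stage fixes the token grid that every later windowing step operates on, a brief shape check of `PatchEmbed3D` (with the default (2, 4, 4) patch size and an illustrative 32-frame clip) can help when reasoning about the later stages:

```python
import torch

embed = PatchEmbed3D(patch_size=(2, 4, 4), in_chans=3, embed_dim=96)
video = torch.randn(1, 3, 32, 224, 224)   # (B, C, D, H, W)
print(embed(video).shape)                  # torch.Size([1, 96, 16, 56, 56]): D/2, H/4, W/4
```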
Default: False.\n", + " frozen_stages (int): Stages to be frozen (stop grad and set eval mode).\n", + " -1 means not freezing any parameters.\n", + " \"\"\"\n", + "\n", + " def __init__(self,\n", + " pretrained=None,\n", + " pretrained2d=True,\n", + " patch_size=(4,4,4),\n", + " in_chans=3,\n", + " embed_dim=96,\n", + " depths=[2, 2, 6, 2],\n", + " num_heads=[3, 6, 12, 24],\n", + " window_size=(2,7,7),\n", + " mlp_ratio=4.,\n", + " qkv_bias=True,\n", + " qk_scale=None,\n", + " drop_rate=0.,\n", + " attn_drop_rate=0.,\n", + " drop_path_rate=0.2,\n", + " norm_layer=nn.LayerNorm,\n", + " patch_norm=False,\n", + " frozen_stages=-1,\n", + " use_checkpoint=False,\n", + " \n", + " # class head\n", + " spatial_type='avg',\n", + " in_channels=768,\n", + " num_classes=400,\n", + " dropout_ratio=0.5 # to do check: no dropout layer in weight state\n", + " ):\n", + " super().__init__()\n", + "\n", + " self.pretrained = pretrained\n", + " self.pretrained2d = pretrained2d\n", + " self.num_layers = len(depths)\n", + " self.embed_dim = embed_dim\n", + " self.patch_norm = patch_norm\n", + " self.frozen_stages = frozen_stages\n", + " self.window_size = window_size\n", + " self.patch_size = patch_size\n", + "\n", + " # split image into non-overlapping patches\n", + " self.patch_embed = PatchEmbed3D(\n", + " patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,\n", + " norm_layer=norm_layer if self.patch_norm else None)\n", + "\n", + " self.pos_drop = nn.Dropout(p=drop_rate)\n", + "\n", + " # stochastic depth\n", + " dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule\n", + "\n", + " # build layers\n", + " self.layers = nn.ModuleList()\n", + " for i_layer in range(self.num_layers):\n", + " layer = BasicLayer(\n", + " dim=int(embed_dim * 2**i_layer),\n", + " depth=depths[i_layer],\n", + " num_heads=num_heads[i_layer],\n", + " window_size=window_size,\n", + " mlp_ratio=mlp_ratio,\n", + " qkv_bias=qkv_bias,\n", + " qk_scale=qk_scale,\n", + " drop=drop_rate,\n", + " attn_drop=attn_drop_rate,\n", + " drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],\n", + " norm_layer=norm_layer,\n", + " downsample=PatchMerging if i_layer= 0:\n", + " self.patch_embed.eval()\n", + " for param in self.patch_embed.parameters():\n", + " param.requires_grad = False\n", + "\n", + " if self.frozen_stages >= 1:\n", + " self.pos_drop.eval()\n", + " for i in range(0, self.frozen_stages):\n", + " m = self.layers[i]\n", + " m.eval()\n", + " for param in m.parameters():\n", + " param.requires_grad = False\n", + "\n", + " def inflate_weights(self, logger):\n", + " \"\"\"Inflate the swin2d parameters to swin3d.\n", + "\n", + " The differences between swin3d and swin2d mainly lie in an extra\n", + " axis. 
To utilize the pretrained parameters in 2d model,\n", + " the weight of swin2d models should be inflated to fit in the shapes of\n", + " the 3d counterpart.\n", + "\n", + " Args:\n", + " logger (logging.Logger): The logger used to print\n", + " debugging infomation.\n", + " \"\"\"\n", + " checkpoint = torch.load(self.pretrained, map_location='cpu')\n", + " state_dict = checkpoint['model']\n", + "\n", + " # delete relative_position_index since we always re-init it\n", + " relative_position_index_keys = [k for k in state_dict.keys() if \"relative_position_index\" in k]\n", + " for k in relative_position_index_keys:\n", + " del state_dict[k]\n", + "\n", + " # delete attn_mask since we always re-init it\n", + " attn_mask_keys = [k for k in state_dict.keys() if \"attn_mask\" in k]\n", + " for k in attn_mask_keys:\n", + " del state_dict[k]\n", + "\n", + " state_dict['patch_embed.proj.weight'] = state_dict['patch_embed.proj.weight'].unsqueeze(2).repeat(1,1,self.patch_size[0],1,1) / self.patch_size[0]\n", + "\n", + " # bicubic interpolate relative_position_bias_table if not match\n", + " relative_position_bias_table_keys = [k for k in state_dict.keys() if \"relative_position_bias_table\" in k]\n", + " for k in relative_position_bias_table_keys:\n", + " relative_position_bias_table_pretrained = state_dict[k]\n", + " relative_position_bias_table_current = self.state_dict()[k]\n", + " L1, nH1 = relative_position_bias_table_pretrained.size()\n", + " L2, nH2 = relative_position_bias_table_current.size()\n", + " L2 = (2*self.window_size[1]-1) * (2*self.window_size[2]-1)\n", + " wd = self.window_size[0]\n", + " if nH1 != nH2:\n", + " logger.warning(f\"Error in loading {k}, passing\")\n", + " else:\n", + " if L1 != L2:\n", + " S1 = int(L1 ** 0.5)\n", + " relative_position_bias_table_pretrained_resized = torch.nn.functional.interpolate(\n", + " relative_position_bias_table_pretrained.permute(1, 0).view(1, nH1, S1, S1), size=(2*self.window_size[1]-1, 2*self.window_size[2]-1),\n", + " mode='bicubic')\n", + " relative_position_bias_table_pretrained = relative_position_bias_table_pretrained_resized.view(nH2, L2).permute(1, 0)\n", + " state_dict[k] = relative_position_bias_table_pretrained.repeat(2*wd-1,1)\n", + "\n", + " msg = self.load_state_dict(state_dict, strict=False)\n", + " logger.info(msg)\n", + " logger.info(f\"=> loaded successfully '{self.pretrained}'\")\n", + " del checkpoint\n", + " torch.cuda.empty_cache()\n", + "\n", + " def init_weights(self, pretrained=None):\n", + " \"\"\"Initialize the weights in backbone.\n", + "\n", + " Args:\n", + " pretrained (str, optional): Path to pre-trained weights.\n", + " Defaults to None.\n", + " \"\"\"\n", + " def _init_weights(m):\n", + " if isinstance(m, nn.Linear):\n", + " trunc_normal_(m.weight, std=.02)\n", + " if isinstance(m, nn.Linear) and m.bias is not None:\n", + " nn.init.constant_(m.bias, 0)\n", + " elif isinstance(m, nn.LayerNorm):\n", + " nn.init.constant_(m.bias, 0)\n", + " nn.init.constant_(m.weight, 1.0)\n", + "\n", + " if pretrained:\n", + " self.pretrained = pretrained\n", + " if isinstance(self.pretrained, str):\n", + " self.apply(_init_weights)\n", + " logger = get_root_logger()\n", + " logger.info(f'load model from: {self.pretrained}')\n", + "\n", + " if self.pretrained2d:\n", + " # Inflate 2D model into 3D model.\n", + " self.inflate_weights(logger)\n", + " else:\n", + " # Directly load 3D model.\n", + " load_checkpoint(self, self.pretrained, strict=False, logger=logger)\n", + " elif self.pretrained is None:\n", + " 
self.apply(_init_weights)\n", + " else:\n", + " raise TypeError('pretrained must be a str or None')\n", + "\n", + " def forward(self, x):\n", + " \"\"\"Forward function.\"\"\"\n", + "\n", + " x = self.patch_embed(x)\n", + " x = self.pos_drop(x)\n", + "\n", + " for layer in self.layers:\n", + " x = layer(x.contiguous())\n", + " \n", + " x = rearrange(x, 'n c d h w -> n d h w c')\n", + " x = self.norm(x)\n", + " x = rearrange(x, 'n d h w c -> n c d h w')\n", + " return x\n", + " \n", + "\n", + " def train(self, mode=True):\n", + " \"\"\"Convert the model into training mode while keep layers freezed.\"\"\"\n", + " super(SwinTransformer3D, self).train(mode)\n", + " self._freeze_stages()" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "c98f81e7", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-31T16:28:16.115247Z", + "iopub.status.busy": "2024-03-31T16:28:16.114872Z", + "iopub.status.idle": "2024-03-31T16:28:16.122078Z", + "shell.execute_reply": "2024-03-31T16:28:16.120287Z" + }, + "papermill": { + "duration": 0.018356, + "end_time": "2024-03-31T16:28:16.124734", + "exception": false, + "start_time": "2024-03-31T16:28:16.106378", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "def video_swin_base(window_size=(8,7,7), **kwargs):\n", + " model = SwinTransformer3D(\n", + " patch_size=(2,4,4),\n", + " embed_dim=128,\n", + " depths=[2, 2, 18, 2],\n", + " num_heads=[4, 8, 16, 32],\n", + " window_size=window_size,\n", + " mlp_ratio=4.,\n", + " qkv_bias=True,\n", + " qk_scale=None,\n", + " drop_rate=0.,\n", + " attn_drop_rate=0.,\n", + " drop_path_rate=0.2,\n", + " norm_layer=nn.LayerNorm,\n", + " patch_norm=True,\n", + " in_channels=1024,\n", + " **kwargs\n", + " )\n", + " return model\n" + ] + }, + { + "cell_type": "markdown", + "id": "b4f2ba54", + "metadata": { + "papermill": { + "duration": 0.006702, + "end_time": "2024-03-31T16:28:16.138791", + "exception": false, + "start_time": "2024-03-31T16:28:16.132089", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "# Video Swin K600 PyTorch" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "d5b7dbf5", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-31T16:28:16.155776Z", + "iopub.status.busy": "2024-03-31T16:28:16.155312Z", + "iopub.status.idle": "2024-03-31T16:28:17.291208Z", + "shell.execute_reply": "2024-03-31T16:28:17.289281Z" + }, + "papermill": { + "duration": 1.147943, + "end_time": "2024-03-31T16:28:17.293909", + "exception": false, + "start_time": "2024-03-31T16:28:16.145966", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/opt/conda/lib/python3.10/site-packages/torch/functional.py:504: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. 
(Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3526.)\n", + " return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined]\n" + ] + } + ], + "source": [ + "model_pt = video_swin_base(\n", + " window_size=(8,7,7), num_classes=600\n", + ")\n", + "model_pt.eval()" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "83bb8b13", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-31T16:28:17.310229Z", + "iopub.status.busy": "2024-03-31T16:28:17.309869Z", + "iopub.status.idle": "2024-03-31T16:28:19.168922Z", + "shell.execute_reply": "2024-03-31T16:28:19.167768Z" + }, + "papermill": { + "duration": 1.870111, + "end_time": "2024-03-31T16:28:19.171289", + "exception": false, + "start_time": "2024-03-31T16:28:17.301178", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/opt/conda/lib/python3.10/pty.py:89: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n", + " pid, fd = os.forkpty()\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "--2024-03-31 16:28:17-- https://github.com/SwinTransformer/storage/releases/download/v1.0.4/swin_base_patch244_window877_kinetics600_22k.pth\r\n", + "Resolving github.com (github.com)... 140.82.113.4\r\n", + "Connecting to github.com (github.com)|140.82.113.4|:443... connected.\r\n", + "HTTP request sent, awaiting response... 302 Found\r\n", + "Location: https://objects.githubusercontent.com/github-production-release-asset-2e65be/357198522/099f2980-d55e-11eb-8848-6616f5f65526?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240331%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240331T162817Z&X-Amz-Expires=300&X-Amz-Signature=1ecb02f3ad86594d9680a52ff630b0668a6c2a170f2902add41aac346e8a286b&X-Amz-SignedHeaders=host&actor_id=0&key_id=0&repo_id=357198522&response-content-disposition=attachment%3B%20filename%3Dswin_base_patch244_window877_kinetics600_22k.pth&response-content-type=application%2Foctet-stream [following]\r\n", + "--2024-03-31 16:28:17-- https://objects.githubusercontent.com/github-production-release-asset-2e65be/357198522/099f2980-d55e-11eb-8848-6616f5f65526?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240331%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240331T162817Z&X-Amz-Expires=300&X-Amz-Signature=1ecb02f3ad86594d9680a52ff630b0668a6c2a170f2902add41aac346e8a286b&X-Amz-SignedHeaders=host&actor_id=0&key_id=0&repo_id=357198522&response-content-disposition=attachment%3B%20filename%3Dswin_base_patch244_window877_kinetics600_22k.pth&response-content-type=application%2Foctet-stream\r\n", + "Resolving objects.githubusercontent.com (objects.githubusercontent.com)... 185.199.110.133, 185.199.108.133, 185.199.111.133, ...\r\n", + "Connecting to objects.githubusercontent.com (objects.githubusercontent.com)|185.199.110.133|:443... connected.\r\n", + "HTTP request sent, awaiting response... 
200 OK\r\n", + "Length: 382579368 (365M) [application/octet-stream]\r\n", + "Saving to: 'checkpoint.pt'\r\n", + "\r\n", + "checkpoint.pt 100%[===================>] 364.86M 303MB/s in 1.2s \r\n", + "\r\n", + "2024-03-31 16:28:19 (303 MB/s) - 'checkpoint.pt' saved [382579368/382579368]\r\n", + "\r\n" + ] + } + ], + "source": [ + "base_url = \"https://github.com/SwinTransformer/storage/releases/download/v1.0.4/\"\n", + "checkpoints_pt = f\"{base_url}swin_base_patch244_window877_kinetics600_22k.pth\"\n", + "!wget {checkpoints_pt} -O checkpoint.pt" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "ace994fc", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-31T16:28:19.190480Z", + "iopub.status.busy": "2024-03-31T16:28:19.190060Z", + "iopub.status.idle": "2024-03-31T16:28:19.389671Z", + "shell.execute_reply": "2024-03-31T16:28:19.388429Z" + }, + "papermill": { + "duration": 0.21121, + "end_time": "2024-03-31T16:28:19.391988", + "exception": false, + "start_time": "2024-03-31T16:28:19.180778", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "state_dict = torch.load(\n", + " 'checkpoint.pt', map_location=\"cpu\"\n", + ")\n", + "state_dict = state_dict['state_dict']\n", + "state_dict = {k.replace('backbone.', ''): v for k, v in state_dict.items()}\n", + "state_dict = {k.replace('cls_head.', ''): v for k, v in state_dict.items()}" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "5a81cc80", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-31T16:28:19.425761Z", + "iopub.status.busy": "2024-03-31T16:28:19.425410Z", + "iopub.status.idle": "2024-03-31T16:28:19.503955Z", + "shell.execute_reply": "2024-03-31T16:28:19.502975Z" + }, + "papermill": { + "duration": 0.097974, + "end_time": "2024-03-31T16:28:19.505852", + "exception": false, + "start_time": "2024-03-31T16:28:19.407878", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "_IncompatibleKeys(missing_keys=[], unexpected_keys=['fc_cls.weight', 'fc_cls.bias'])" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model_pt.load_state_dict(state_dict, strict=False) " + ] + }, + { + "cell_type": "markdown", + "id": "a282a412", + "metadata": { + "papermill": { + "duration": 0.007572, + "end_time": "2024-03-31T16:28:19.521247", + "exception": false, + "start_time": "2024-03-31T16:28:19.513675", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "_IncompatibleKeys(missing_keys=[], unexpected_keys=['fc_cls.weight', 'fc_cls.bias'])\n", + "\n", + "\n", + "Expected, we removed final layer" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0db2e374", + "metadata": { + "papermill": { + "duration": 0.007589, + "end_time": "2024-03-31T16:28:19.536897", + "exception": false, + "start_time": "2024-03-31T16:28:19.529308", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "id": "d671cf7a", + "metadata": { + "papermill": { + "duration": 0.007869, + "end_time": "2024-03-31T16:28:19.552746", + "exception": false, + "start_time": "2024-03-31T16:28:19.544877", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "# Video Swin K600 [Keras CV]" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "97a792d6", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-31T16:28:19.570195Z", + "iopub.status.busy": 
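The two `unexpected_keys` above are expected because the classifier head (`fc_cls`) of the original checkpoint is intentionally not part of the backbone being built here. A small follow-up check (reusing `model_pt` and `state_dict` from the cells above) can make that explicit:

```python
result = model_pt.load_state_dict(state_dict, strict=False)

# nothing from the backbone should be missing, and the only leftover
# entries should belong to the original classification head
assert result.missing_keys == []
assert all(k.startswith("fc_cls.") for k in result.unexpected_keys)
```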
"2024-03-31T16:28:19.569809Z", + "iopub.status.idle": "2024-03-31T16:28:19.574972Z", + "shell.execute_reply": "2024-03-31T16:28:19.573575Z" + }, + "papermill": { + "duration": 0.016661, + "end_time": "2024-03-31T16:28:19.577193", + "exception": false, + "start_time": "2024-03-31T16:28:19.560532", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "import os\n", + "os.environ[\"KERAS_BACKEND\"] = \"torch\"" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "8097a3bb", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-31T16:28:19.594948Z", + "iopub.status.busy": "2024-03-31T16:28:19.594579Z", + "iopub.status.idle": "2024-03-31T16:28:47.664966Z", + "shell.execute_reply": "2024-03-31T16:28:47.663253Z" + }, + "papermill": { + "duration": 28.081995, + "end_time": "2024-03-31T16:28:47.667322", + "exception": false, + "start_time": "2024-03-31T16:28:19.585327", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Cloning into 'keras-cv'...\r\n", + "remote: Enumerating objects: 13766, done.\u001b[K\r\n", + "remote: Counting objects: 100% (1902/1902), done.\u001b[K\r\n", + "remote: Compressing objects: 100% (764/764), done.\u001b[K\r\n", + "remote: Total 13766 (delta 1325), reused 1611 (delta 1122), pack-reused 11864\u001b[K\r\n", + "Receiving objects: 100% (13766/13766), 25.66 MiB | 26.81 MiB/s, done.\r\n", + "Resolving deltas: 100% (9760/9760), done.\r\n", + "/kaggle/working/keras-cv\n" + ] + } + ], + "source": [ + "!git clone --branch video_swin https://github.com/innat/keras-cv.git\n", + "%cd keras-cv\n", + "!pip install -q -e ." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "9cd988de", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-31T16:28:47.688615Z", + "iopub.status.busy": "2024-03-31T16:28:47.688157Z", + "iopub.status.idle": "2024-03-31T16:29:06.875822Z", + "shell.execute_reply": "2024-03-31T16:29:06.874231Z" + }, + "papermill": { + "duration": 19.200845, + "end_time": "2024-03-31T16:29:06.877810", + "exception": false, + "start_time": "2024-03-31T16:28:47.676965", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-03-31 16:28:52.233542: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", + "2024-03-31 16:28:52.233664: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", + "2024-03-31 16:28:52.410813: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n" + ] + }, + { + "data": { + "text/plain": [ + "'3.0.5'" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import keras\n", + "from keras import ops\n", + "from keras_cv.models import VideoSwinBackbone\n", + "from keras_cv.models import VideoClassifier\n", + "keras.__version__" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "663ffd78", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-31T16:29:06.899316Z", + "iopub.status.busy": "2024-03-31T16:29:06.898673Z", + "iopub.status.idle": 
"2024-03-31T16:29:08.582025Z", + "shell.execute_reply": "2024-03-31T16:29:08.580526Z" + }, + "papermill": { + "duration": 1.696839, + "end_time": "2024-03-31T16:29:08.584655", + "exception": false, + "start_time": "2024-03-31T16:29:06.887816", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/opt/conda/lib/python3.10/pty.py:89: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n", + " pid, fd = os.forkpty()\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "--2024-03-31 16:29:07-- https://github.com/innat/VideoSwin/releases/download/v2.0/videoswin_base_kinetics600_imagenet22k.weights.h5\r\n", + "Resolving github.com (github.com)... 140.82.112.4\r\n", + "Connecting to github.com (github.com)|140.82.112.4|:443... connected.\r\n", + "HTTP request sent, awaiting response... 302 Found\r\n", + "Location: https://objects.githubusercontent.com/github-production-release-asset-2e65be/697696973/dde3749b-9dae-47e3-8b07-5ef1ef982a48?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240331%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240331T162907Z&X-Amz-Expires=300&X-Amz-Signature=96eef6d9c80d98ae9c19cb6985f5c8511e36a16bc6e4d0f93f403d0a5cb5654c&X-Amz-SignedHeaders=host&actor_id=0&key_id=0&repo_id=697696973&response-content-disposition=attachment%3B%20filename%3Dvideoswin_base_kinetics600_imagenet22k.weights.h5&response-content-type=application%2Foctet-stream [following]\r\n", + "--2024-03-31 16:29:07-- https://objects.githubusercontent.com/github-production-release-asset-2e65be/697696973/dde3749b-9dae-47e3-8b07-5ef1ef982a48?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240331%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240331T162907Z&X-Amz-Expires=300&X-Amz-Signature=96eef6d9c80d98ae9c19cb6985f5c8511e36a16bc6e4d0f93f403d0a5cb5654c&X-Amz-SignedHeaders=host&actor_id=0&key_id=0&repo_id=697696973&response-content-disposition=attachment%3B%20filename%3Dvideoswin_base_kinetics600_imagenet22k.weights.h5&response-content-type=application%2Foctet-stream\r\n", + "Resolving objects.githubusercontent.com (objects.githubusercontent.com)... 185.199.110.133, 185.199.108.133, 185.199.111.133, ...\r\n", + "Connecting to objects.githubusercontent.com (objects.githubusercontent.com)|185.199.110.133|:443... connected.\r\n", + "HTTP request sent, awaiting response... 
200 OK\r\n", + "Length: 351381896 (335M) [application/octet-stream]\r\n", + "Saving to: 'videoswin_base_kinetics600_imagenet22k.weights.h5'\r\n", + "\r\n", + "videoswin_base_kine 100%[===================>] 335.10M 305MB/s in 1.1s \r\n", + "\r\n", + "2024-03-31 16:29:08 (305 MB/s) - 'videoswin_base_kinetics600_imagenet22k.weights.h5' saved [351381896/351381896]\r\n", + "\r\n" + ] + } + ], + "source": [ + "!wget https://github.com/innat/VideoSwin/releases/download/v2.0/videoswin_base_kinetics600_imagenet22k.weights.h5" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "db6e9815", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-31T16:29:08.607646Z", + "iopub.status.busy": "2024-03-31T16:29:08.607201Z", + "iopub.status.idle": "2024-03-31T16:29:08.614148Z", + "shell.execute_reply": "2024-03-31T16:29:08.612947Z" + }, + "papermill": { + "duration": 0.021075, + "end_time": "2024-03-31T16:29:08.616312", + "exception": false, + "start_time": "2024-03-31T16:29:08.595237", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "def vswin_base():\n", + " backbone=VideoSwinBackbone(\n", + " input_shape=(32, 224, 224, 3), \n", + " embed_dim=128,\n", + " depths=[2, 2, 18, 2],\n", + " num_heads=[4, 8, 16, 32],\n", + " include_rescaling=False, \n", + " )\n", + " backbone.load_weights(\n", + " 'videoswin_base_kinetics600_imagenet22k.weights.h5'\n", + " )\n", + " return backbone" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "5300dedb", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-31T16:29:08.640284Z", + "iopub.status.busy": "2024-03-31T16:29:08.639902Z", + "iopub.status.idle": "2024-03-31T16:29:10.903494Z", + "shell.execute_reply": "2024-03-31T16:29:10.901498Z" + }, + "papermill": { + "duration": 2.279472, + "end_time": "2024-03-31T16:29:10.906342", + "exception": false, + "start_time": "2024-03-31T16:29:08.626870", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "model_ks = vswin_base()" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "15599c39", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-31T16:29:10.928739Z", + "iopub.status.busy": "2024-03-31T16:29:10.928268Z", + "iopub.status.idle": "2024-03-31T16:29:10.936850Z", + "shell.execute_reply": "2024-03-31T16:29:10.935945Z" + }, + "papermill": { + "duration": 0.021784, + "end_time": "2024-03-31T16:29:10.938542", + "exception": false, + "start_time": "2024-03-31T16:29:10.916758", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "PyTorch, number of params (M): 87.64\n", + "Keras, number of params (M): 87.64\n" + ] + } + ], + "source": [ + "n_parameters = sum(p.numel() for p in model_pt.parameters() if p.requires_grad)\n", + "print(\"PyTorch, number of params (M): %.2f\" % (n_parameters / 1.0e6))\n", + "n_parameters = model_ks.count_params()\n", + "print(\"Keras, number of params (M): %.2f\" % (n_parameters / 1.0e6))" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "5910a317", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-31T16:29:10.960763Z", + "iopub.status.busy": "2024-03-31T16:29:10.960337Z", + "iopub.status.idle": "2024-03-31T16:29:11.091219Z", + "shell.execute_reply": "2024-03-31T16:29:11.089796Z" + }, + "papermill": { + "duration": 0.145028, + "end_time": "2024-03-31T16:29:11.094030", + "exception": false, + "start_time": "2024-03-31T16:29:10.949002", + 
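Matching total parameter counts is a useful first signal that the PyTorch and Keras configurations agree. If the totals ever diverged, a per-module breakdown on the PyTorch side (a hypothetical debugging aid using `model_pt` from above) narrows down which stage differs:

```python
def params_per_child(model):
    """Parameter count for each top-level submodule."""
    return {
        name: sum(p.numel() for p in child.parameters())
        for name, child in model.named_children()
    }

for name, count in params_per_child(model_pt).items():
    print(f"{name:12s} {count / 1e6:6.2f}M")
```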
"status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([1, 32, 224, 224, 3]) torch.Size([1, 3, 32, 224, 224])\n" + ] + } + ], + "source": [ + "common_input = np.random.normal(0, 1, (1, 32, 224, 224, 3)).astype('float32')\n", + "keras_input = ops.array(common_input)\n", + "torch_input = torch.from_numpy(common_input.transpose(0, 4, 1, 2, 3))\n", + "print(keras_input.shape, torch_input.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "eacab339", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-31T16:29:11.122133Z", + "iopub.status.busy": "2024-03-31T16:29:11.120966Z", + "iopub.status.idle": "2024-03-31T16:29:11.129797Z", + "shell.execute_reply": "2024-03-31T16:29:11.128454Z" + }, + "papermill": { + "duration": 0.026546, + "end_time": "2024-03-31T16:29:11.132183", + "exception": false, + "start_time": "2024-03-31T16:29:11.105637", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "def logit_checking(keras_model, torch_model):\n", + " # forward pass\n", + " keras_predict = keras_model(keras_input)\n", + " torch_predict = torch_model(torch_input).permute(0,2,3,4,1)\n", + " print(keras_predict.shape, torch_predict.shape)\n", + " np.testing.assert_allclose(\n", + " keras_predict.detach().numpy(),\n", + " torch_predict.detach().numpy(),\n", + " 1e-4, 1e-4\n", + " )\n", + " del keras_model \n", + " del torch_model\n", + " gc.collect()" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "12b00ce4", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-31T16:29:11.155937Z", + "iopub.status.busy": "2024-03-31T16:29:11.155564Z", + "iopub.status.idle": "2024-03-31T16:29:31.486409Z", + "shell.execute_reply": "2024-03-31T16:29:31.484870Z" + }, + "papermill": { + "duration": 20.346445, + "end_time": "2024-03-31T16:29:31.489813", + "exception": false, + "start_time": "2024-03-31T16:29:11.143368", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([1, 16, 7, 7, 1024]) torch.Size([1, 16, 7, 7, 1024])\n" + ] + } + ], + "source": [ + "logit_checking(\n", + " model_ks, model_pt\n", + ")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "39156fbb", + "metadata": { + "papermill": { + "duration": 0.010275, + "end_time": "2024-03-31T16:29:31.511215", + "exception": false, + "start_time": "2024-03-31T16:29:31.500940", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "id": "221fce02", + "metadata": { + "papermill": { + "duration": 0.010258, + "end_time": "2024-03-31T16:29:31.532207", + "exception": false, + "start_time": "2024-03-31T16:29:31.521949", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "# Something Something V2 PyTorch" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "6714c913", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-31T16:29:31.555235Z", + "iopub.status.busy": "2024-03-31T16:29:31.554885Z", + "iopub.status.idle": "2024-03-31T16:29:32.602168Z", + "shell.execute_reply": "2024-03-31T16:29:32.600775Z" + }, + "papermill": { + "duration": 1.062215, + "end_time": "2024-03-31T16:29:32.604968", + "exception": false, + "start_time": "2024-03-31T16:29:31.542753", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "model_pt = video_swin_base(\n", + " 
window_size=(16,7,7), num_classes=174\n", + ")\n", + "model_pt.eval()" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "13c68885", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-31T16:29:32.632138Z", + "iopub.status.busy": "2024-03-31T16:29:32.631743Z", + "iopub.status.idle": "2024-03-31T16:29:35.072409Z", + "shell.execute_reply": "2024-03-31T16:29:35.070658Z" + }, + "papermill": { + "duration": 2.456138, + "end_time": "2024-03-31T16:29:35.074968", + "exception": false, + "start_time": "2024-03-31T16:29:32.618830", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/opt/conda/lib/python3.10/pty.py:89: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n", + " pid, fd = os.forkpty()\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "--2024-03-31 16:29:32-- https://github.com/SwinTransformer/storage/releases/download/v1.0.4/swin_base_patch244_window1677_sthv2.pth\r\n", + "Resolving github.com (github.com)... 140.82.114.4\r\n", + "Connecting to github.com (github.com)|140.82.114.4|:443... connected.\r\n", + "HTTP request sent, awaiting response... 302 Found\r\n", + "Location: https://objects.githubusercontent.com/github-production-release-asset-2e65be/357198522/20458080-d55e-11eb-9021-4730e624e0ea?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240331%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240331T162933Z&X-Amz-Expires=300&X-Amz-Signature=978c7c4b6a25743f9500f0ec50fc31c41ef41856d255c0a3dad8261388391329&X-Amz-SignedHeaders=host&actor_id=0&key_id=0&repo_id=357198522&response-content-disposition=attachment%3B%20filename%3Dswin_base_patch244_window1677_sthv2.pth&response-content-type=application%2Foctet-stream [following]\r\n", + "--2024-03-31 16:29:33-- https://objects.githubusercontent.com/github-production-release-asset-2e65be/357198522/20458080-d55e-11eb-9021-4730e624e0ea?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240331%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240331T162933Z&X-Amz-Expires=300&X-Amz-Signature=978c7c4b6a25743f9500f0ec50fc31c41ef41856d255c0a3dad8261388391329&X-Amz-SignedHeaders=host&actor_id=0&key_id=0&repo_id=357198522&response-content-disposition=attachment%3B%20filename%3Dswin_base_patch244_window1677_sthv2.pth&response-content-type=application%2Foctet-stream\r\n", + "Resolving objects.githubusercontent.com (objects.githubusercontent.com)... 185.199.110.133, 185.199.108.133, 185.199.111.133, ...\r\n", + "Connecting to objects.githubusercontent.com (objects.githubusercontent.com)|185.199.110.133|:443... connected.\r\n", + "HTTP request sent, awaiting response... 
200 OK\r\n", + "Length: 473410081 (451M) [application/octet-stream]\r\n", + "Saving to: 'checkpoint.pt'\r\n", + "\r\n", + "checkpoint.pt 100%[===================>] 451.48M 262MB/s in 1.7s \r\n", + "\r\n", + "2024-03-31 16:29:34 (262 MB/s) - 'checkpoint.pt' saved [473410081/473410081]\r\n", + "\r\n" + ] + } + ], + "source": [ + "base_url = \"https://github.com/SwinTransformer/storage/releases/download/v1.0.4/\"\n", + "checkpoints_pt = f\"{base_url}swin_base_patch244_window1677_sthv2.pth\"\n", + "!wget {checkpoints_pt} -O checkpoint.pt" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "a923aa94", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-31T16:29:35.099503Z", + "iopub.status.busy": "2024-03-31T16:29:35.099110Z", + "iopub.status.idle": "2024-03-31T16:29:35.285266Z", + "shell.execute_reply": "2024-03-31T16:29:35.283861Z" + }, + "papermill": { + "duration": 0.201555, + "end_time": "2024-03-31T16:29:35.287934", + "exception": false, + "start_time": "2024-03-31T16:29:35.086379", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "state_dict = torch.load(\n", + " 'checkpoint.pt', map_location=\"cpu\"\n", + ")\n", + "state_dict = state_dict['state_dict']\n", + "state_dict = {k.replace('backbone.', ''): v for k, v in state_dict.items()}\n", + "state_dict = {k.replace('cls_head.', ''): v for k, v in state_dict.items()}" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "fe4eec0e", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-31T16:29:35.313216Z", + "iopub.status.busy": "2024-03-31T16:29:35.312863Z", + "iopub.status.idle": "2024-03-31T16:29:35.412953Z", + "shell.execute_reply": "2024-03-31T16:29:35.411952Z" + }, + "papermill": { + "duration": 0.115572, + "end_time": "2024-03-31T16:29:35.415473", + "exception": false, + "start_time": "2024-03-31T16:29:35.299901", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "_IncompatibleKeys(missing_keys=[], unexpected_keys=['fc_cls.weight', 'fc_cls.bias'])" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model_pt.load_state_dict(state_dict, strict=False)" + ] + }, + { + "cell_type": "markdown", + "id": "46daa259", + "metadata": { + "papermill": { + "duration": 0.010798, + "end_time": "2024-03-31T16:29:35.438752", + "exception": false, + "start_time": "2024-03-31T16:29:35.427954", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "# Something Somethinb V2 KerasCV" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "65566f83", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-31T16:29:35.462922Z", + "iopub.status.busy": "2024-03-31T16:29:35.462545Z", + "iopub.status.idle": "2024-03-31T16:29:37.242928Z", + "shell.execute_reply": "2024-03-31T16:29:37.241555Z" + }, + "papermill": { + "duration": 1.795716, + "end_time": "2024-03-31T16:29:37.245565", + "exception": false, + "start_time": "2024-03-31T16:29:35.449849", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "--2024-03-31 16:29:35-- https://github.com/innat/VideoSwin/releases/download/v2.0/videoswin_base_something_something_v2.weights.h5\r\n", + "Resolving github.com (github.com)... 140.82.114.4\r\n", + "Connecting to github.com (github.com)|140.82.114.4|:443... connected.\r\n", + "HTTP request sent, awaiting response... 
302 Found\r\n", + "Location: https://objects.githubusercontent.com/github-production-release-asset-2e65be/697696973/9faf8cf2-0308-4a38-9022-0d227ece6073?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240331%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240331T162935Z&X-Amz-Expires=300&X-Amz-Signature=d90abf874f38fe7b07b9464b483f46bb54da77d8dde6256b1efb369af77b1896&X-Amz-SignedHeaders=host&actor_id=0&key_id=0&repo_id=697696973&response-content-disposition=attachment%3B%20filename%3Dvideoswin_base_something_something_v2.weights.h5&response-content-type=application%2Foctet-stream [following]\r\n", + "--2024-03-31 16:29:35-- https://objects.githubusercontent.com/github-production-release-asset-2e65be/697696973/9faf8cf2-0308-4a38-9022-0d227ece6073?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240331%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240331T162935Z&X-Amz-Expires=300&X-Amz-Signature=d90abf874f38fe7b07b9464b483f46bb54da77d8dde6256b1efb369af77b1896&X-Amz-SignedHeaders=host&actor_id=0&key_id=0&repo_id=697696973&response-content-disposition=attachment%3B%20filename%3Dvideoswin_base_something_something_v2.weights.h5&response-content-type=application%2Foctet-stream\r\n", + "Resolving objects.githubusercontent.com (objects.githubusercontent.com)... 185.199.110.133, 185.199.108.133, 185.199.111.133, ...\r\n", + "Connecting to objects.githubusercontent.com (objects.githubusercontent.com)|185.199.110.133|:443... connected.\r\n", + "HTTP request sent, awaiting response... 200 OK\r\n", + "Length: 355448712 (339M) [application/octet-stream]\r\n", + "Saving to: 'videoswin_base_something_something_v2.weights.h5'\r\n", + "\r\n", + "videoswin_base_some 100%[===================>] 338.98M 313MB/s in 1.1s \r\n", + "\r\n", + "2024-03-31 16:29:37 (313 MB/s) - 'videoswin_base_something_something_v2.weights.h5' saved [355448712/355448712]\r\n", + "\r\n" + ] + } + ], + "source": [ + "!wget https://github.com/innat/VideoSwin/releases/download/v2.0/videoswin_base_something_something_v2.weights.h5\n", + "\n", + "def vswin_base():\n", + " backbone=VideoSwinBackbone(\n", + " input_shape=(32, 224, 224, 3), \n", + " embed_dim=128,\n", + " depths=[2, 2, 18, 2],\n", + " num_heads=[4, 8, 16, 32],\n", + " window_size=[16, 7, 7],\n", + " include_rescaling=False, \n", + " )\n", + " backbone.load_weights(\n", + " 'videoswin_base_something_something_v2.weights.h5'\n", + " )\n", + " return backbone" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "b74004a9", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-31T16:29:37.273090Z", + "iopub.status.busy": "2024-03-31T16:29:37.272409Z", + "iopub.status.idle": "2024-03-31T16:29:39.469377Z", + "shell.execute_reply": "2024-03-31T16:29:39.467493Z" + }, + "papermill": { + "duration": 2.213317, + "end_time": "2024-03-31T16:29:39.472153", + "exception": false, + "start_time": "2024-03-31T16:29:37.258836", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "model_ks = vswin_base()" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "836a50fe", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-31T16:29:39.499550Z", + "iopub.status.busy": "2024-03-31T16:29:39.499121Z", + "iopub.status.idle": "2024-03-31T16:30:20.045949Z", + "shell.execute_reply": "2024-03-31T16:30:20.043774Z" + }, + "papermill": { + "duration": 40.564324, + "end_time": "2024-03-31T16:30:20.049374", + "exception": false, + "start_time": "2024-03-31T16:29:39.485050", + 
"status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([1, 16, 7, 7, 1024]) torch.Size([1, 16, 7, 7, 1024])\n" + ] + } + ], + "source": [ + "logit_checking(\n", + " model_ks, model_pt\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e72ae4bd", + "metadata": { + "papermill": { + "duration": 0.012143, + "end_time": "2024-03-31T16:30:20.074137", + "exception": false, + "start_time": "2024-03-31T16:30:20.061994", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kaggle": { + "accelerator": "none", + "dataSources": [ + { + "modelInstanceId": 17533, + "sourceId": 21184, + "sourceType": "modelInstanceVersion" + } + ], + "dockerImageVersionId": 30673, + "isGpuEnabled": false, + "isInternetEnabled": true, + "language": "python", + "sourceType": "notebook" + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + }, + "papermill": { + "default_parameters": {}, + "duration": 153.194598, + "end_time": "2024-03-31T16:30:23.699815", + "environment_variables": {}, + "exception": null, + "input_path": "__notebook__.ipynb", + "output_path": "__notebook__.ipynb", + "parameters": {}, + "start_time": "2024-03-31T16:27:50.505217", + "version": "2.5.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/guides/k600-ssv2-logit-matching-torch-vs-keras-cv.ipynb b/guides/k600-ssv2-logit-matching-torch-vs-keras-cv.ipynb new file mode 100644 index 0000000..7ec868d --- /dev/null +++ b/guides/k600-ssv2-logit-matching-torch-vs-keras-cv.ipynb @@ -0,0 +1,1870 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "021723d0", + "metadata": { + "_cell_guid": "b1076dfc-b9ad-4769-8c92-a6c4dae69d19", + "_uuid": "8f2839f25d086af736a60e9eeb907d3b93b6e0e5", + "execution": { + "iopub.execute_input": "2024-03-31T16:04:36.141644Z", + "iopub.status.busy": "2024-03-31T16:04:36.140915Z", + "iopub.status.idle": "2024-03-31T16:04:56.916989Z", + "shell.execute_reply": "2024-03-31T16:04:56.915927Z" + }, + "papermill": { + "duration": 20.78853, + "end_time": "2024-03-31T16:04:56.919508", + "exception": false, + "start_time": "2024-03-31T16:04:36.130978", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "!pip install einops -q\n", + "import logging\n", + "from functools import partial\n", + "\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "import torch.utils.checkpoint as checkpoint\n", + "import numpy as np\n", + "from timm.models.layers import trunc_normal_ \n", + "\n", + "from functools import reduce, lru_cache\n", + "from operator import mul\n", + "from einops import rearrange\n", + "import logging" + ] + }, + { + "cell_type": "markdown", + "id": "54727125", + "metadata": { + "papermill": { + "duration": 0.007894, + "end_time": "2024-03-31T16:04:56.935360", + "exception": false, + "start_time": "2024-03-31T16:04:56.927466", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "# Video Swin Model [PyTorch]" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "ea2bdf5b", + "metadata": { + "_kg_hide-input": true, + "execution": { + 
"iopub.execute_input": "2024-03-31T16:04:56.953573Z", + "iopub.status.busy": "2024-03-31T16:04:56.953176Z", + "iopub.status.idle": "2024-03-31T16:04:57.046559Z", + "shell.execute_reply": "2024-03-31T16:04:57.045683Z" + }, + "jupyter": { + "source_hidden": true + }, + "papermill": { + "duration": 0.105822, + "end_time": "2024-03-31T16:04:57.049010", + "exception": false, + "start_time": "2024-03-31T16:04:56.943188", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):\n", + " \"\"\"Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).\n", + " This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,\n", + " the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...\n", + " See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for\n", + " changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use\n", + " 'survival rate' as the argument.\n", + " \"\"\"\n", + " if drop_prob == 0. or not training:\n", + " return x\n", + " keep_prob = 1 - drop_prob\n", + " shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets\n", + " random_tensor = x.new_empty(shape).bernoulli_(keep_prob)\n", + " if keep_prob > 0.0 and scale_by_keep:\n", + " random_tensor.div_(keep_prob)\n", + " return x * random_tensor\n", + "\n", + "class DropPath(nn.Module):\n", + " \"\"\"Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).\n", + " \"\"\"\n", + " def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):\n", + " super(DropPath, self).__init__()\n", + " self.drop_prob = drop_prob\n", + " self.scale_by_keep = scale_by_keep\n", + "\n", + " def forward(self, x):\n", + " return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)\n", + "\n", + " def extra_repr(self):\n", + " return f'drop_prob={round(self.drop_prob,3):0.3f}'\n", + " \n", + " \n", + "def get_root_logger(log_file=None, log_level=logging.INFO):\n", + " \"\"\"Use ``get_logger`` method in mmcv to get the root logger.\n", + " The logger will be initialized if it has not been initialized. By default a\n", + " StreamHandler will be added. If ``log_file`` is specified, a FileHandler\n", + " will also be added. The name of the root logger is the top-level package\n", + " name, e.g., \"mmaction\".\n", + " Args:\n", + " log_file (str | None): The log filename. If specified, a FileHandler\n", + " will be added to the root logger.\n", + " log_level (int): The root logger level. 
Note that only the process of\n", + " rank 0 is affected, while other processes will set the level to\n", + " \"Error\" and be silent most of the time.\n", + " Returns:\n", + " :obj:`logging.Logger`: The root logger.\n", + " \"\"\"\n", + " return get_logger(__name__.split('.')[0], log_file, log_level)\n", + "\n", + "\n", + "class Mlp(nn.Module):\n", + " \"\"\" Multilayer perceptron.\"\"\"\n", + "\n", + " def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):\n", + " super().__init__()\n", + " out_features = out_features or in_features\n", + " hidden_features = hidden_features or in_features\n", + " self.fc1 = nn.Linear(in_features, hidden_features)\n", + " self.act = act_layer()\n", + " self.fc2 = nn.Linear(hidden_features, out_features)\n", + " self.drop = nn.Dropout(drop)\n", + "\n", + " def forward(self, x):\n", + " x = self.fc1(x)\n", + " x = self.act(x)\n", + " x = self.drop(x)\n", + " x = self.fc2(x)\n", + " x = self.drop(x)\n", + " return x\n", + " \n", + " \n", + "def window_partition(x, window_size):\n", + " \"\"\"\n", + " Args:\n", + " x: (B, D, H, W, C)\n", + " window_size (tuple[int]): window size\n", + "\n", + " Returns:\n", + " windows: (B*num_windows, window_size*window_size, C)\n", + " \"\"\"\n", + " B, D, H, W, C = x.shape\n", + " x = x.view(\n", + " B, \n", + " D // window_size[0], window_size[0], \n", + " H // window_size[1], window_size[1], \n", + " W // window_size[2], window_size[2], \n", + " C\n", + " )\n", + " windows = x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous().view(-1, reduce(mul, window_size), C)\n", + " return windows\n", + "\n", + "\n", + "def window_reverse(windows, window_size, B, D, H, W):\n", + " \"\"\"\n", + " Args:\n", + " windows: (B*num_windows, window_size, window_size, C)\n", + " window_size (tuple[int]): Window size\n", + " H (int): Height of image\n", + " W (int): Width of image\n", + "\n", + " Returns:\n", + " x: (B, D, H, W, C)\n", + " \"\"\"\n", + " x = windows.view(\n", + " B, \n", + " D // window_size[0], \n", + " H // window_size[1], \n", + " W // window_size[2], \n", + " window_size[0], \n", + " window_size[1], \n", + " window_size[2], \n", + " -1\n", + " )\n", + " x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous().view(B, D, H, W, -1)\n", + " return x\n", + "\n", + "\n", + "def get_window_size(x_size, window_size, shift_size=None):\n", + " use_window_size = list(window_size)\n", + " if shift_size is not None:\n", + " use_shift_size = list(shift_size)\n", + " for i in range(len(x_size)):\n", + " if x_size[i] <= window_size[i]:\n", + " use_window_size[i] = x_size[i]\n", + " if shift_size is not None:\n", + " use_shift_size[i] = 0\n", + " if shift_size is None:\n", + " return tuple(use_window_size)\n", + " else:\n", + " return tuple(use_window_size), tuple(use_shift_size)\n", + " \n", + " \n", + "class WindowAttention3D(nn.Module):\n", + " \"\"\" Window based multi-head self attention (W-MSA) module with relative position bias.\n", + " It supports both of shifted and non-shifted window.\n", + " Args:\n", + " dim (int): Number of input channels.\n", + " window_size (tuple[int]): The temporal length, height and width of the window.\n", + " num_heads (int): Number of attention heads.\n", + " qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True\n", + " qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set\n", + " attn_drop (float, optional): Dropout ratio of attention weight. 
Default: 0.0\n", + " proj_drop (float, optional): Dropout ratio of output. Default: 0.0\n", + " \"\"\"\n", + "\n", + " def __init__(self, dim, window_size, num_heads, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):\n", + "\n", + " super().__init__()\n", + " self.dim = dim\n", + " self.window_size = window_size # Wd, Wh, Ww\n", + " self.num_heads = num_heads\n", + " head_dim = dim // num_heads\n", + " self.scale = qk_scale or head_dim ** -0.5\n", + "\n", + " # define a parameter table of relative position bias\n", + " self.relative_position_bias_table = nn.Parameter(\n", + " torch.zeros(\n", + " (2 * window_size[0] - 1) * (2 * window_size[1] - 1) * (2 * window_size[2] - 1), num_heads)\n", + " ) # 2*Wd-1 * 2*Wh-1 * 2*Ww-1, nH\n", + "\n", + " # get pair-wise relative position index for each token inside the window\n", + " coords_d = torch.arange(self.window_size[0])\n", + " coords_h = torch.arange(self.window_size[1])\n", + " coords_w = torch.arange(self.window_size[2])\n", + " coords = torch.stack(torch.meshgrid(coords_d, coords_h, coords_w)) # 3, Wd, Wh, Ww\n", + " coords_flatten = torch.flatten(coords, 1) # 3, Wd*Wh*Ww\n", + " relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 3, Wd*Wh*Ww, Wd*Wh*Ww\n", + " \n", + " relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wd*Wh*Ww, Wd*Wh*Ww, 3\n", + " relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0\n", + " relative_coords[:, :, 1] += self.window_size[1] - 1\n", + " relative_coords[:, :, 2] += self.window_size[2] - 1\n", + " relative_coords[:, :, 0] *= (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1)\n", + " relative_coords[:, :, 1] *= (2 * self.window_size[2] - 1)\n", + " relative_position_index = relative_coords.sum(-1) # Wd*Wh*Ww, Wd*Wh*Ww\n", + "\n", + " self.register_buffer(\"relative_position_index\", relative_position_index)\n", + " self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)\n", + " self.attn_drop = nn.Dropout(attn_drop)\n", + " self.proj = nn.Linear(dim, dim)\n", + " self.proj_drop = nn.Dropout(proj_drop)\n", + "\n", + " trunc_normal_(self.relative_position_bias_table, std=.02)\n", + " self.softmax = nn.Softmax(dim=-1)\n", + "\n", + " def forward(self, x, mask=None):\n", + " \"\"\" Forward function.\n", + " Args:\n", + " x: input features with shape of (num_windows*B, N, C)\n", + " mask: (0/-inf) mask with shape of (num_windows, N, N) or None\n", + " \"\"\"\n", + " B_, N, C = x.shape\n", + " qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)\n", + " q, k, v = qkv[0], qkv[1], qkv[2] # B_, nH, N, C\n", + " q = q * self.scale\n", + " attn = q @ k.transpose(-2, -1)\n", + "\n", + " relative_position_bias = self.relative_position_bias_table[\n", + " self.relative_position_index[:N, :N].reshape(-1)\n", + " ].reshape(N, N, -1) # Wd*Wh*Ww,Wd*Wh*Ww,nH\n", + " relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wd*Wh*Ww, Wd*Wh*Ww\n", + " attn = attn + relative_position_bias.unsqueeze(0) # B_, nH, N, N\n", + " \n", + " if mask is not None:\n", + " nW = mask.shape[0]\n", + " attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)\n", + " attn = attn.view(-1, self.num_heads, N, N)\n", + " attn = self.softmax(attn)\n", + " else:\n", + " attn = self.softmax(attn)\n", + "\n", + " attn = self.attn_drop(attn)\n", + " x = (attn @ v).transpose(1, 2).reshape(B_, N, C)\n", + " x = self.proj(x)\n", + " x = self.proj_drop(x)\n", + " \n", + " return x\n", + " 
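\n", + "# SwinTransformerBlock3D: (shifted-)window attention followed by an MLP, each wrapped in a residual connection with stochastic depth\n", + "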
\n", + " \n", + "class SwinTransformerBlock3D(nn.Module):\n", + " \"\"\" Swin Transformer Block.\n", + "\n", + " Args:\n", + " dim (int): Number of input channels.\n", + " num_heads (int): Number of attention heads.\n", + " window_size (tuple[int]): Window size.\n", + " shift_size (tuple[int]): Shift size for SW-MSA.\n", + " mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.\n", + " qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True\n", + " qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.\n", + " drop (float, optional): Dropout rate. Default: 0.0\n", + " attn_drop (float, optional): Attention dropout rate. Default: 0.0\n", + " drop_path (float, optional): Stochastic depth rate. Default: 0.0\n", + " act_layer (nn.Module, optional): Activation layer. Default: nn.GELU\n", + " norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm\n", + " \"\"\"\n", + "\n", + " def __init__(self, dim, num_heads, window_size=(2,7,7), shift_size=(0,0,0),\n", + " mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,\n", + " act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_checkpoint=False):\n", + " super().__init__()\n", + " self.dim = dim\n", + " self.num_heads = num_heads\n", + " self.window_size = window_size\n", + " self.shift_size = shift_size\n", + " self.mlp_ratio = mlp_ratio\n", + " self.use_checkpoint=use_checkpoint\n", + "\n", + " assert 0 <= self.shift_size[0] < self.window_size[0], \"shift_size must in 0-window_size\"\n", + " assert 0 <= self.shift_size[1] < self.window_size[1], \"shift_size must in 0-window_size\"\n", + " assert 0 <= self.shift_size[2] < self.window_size[2], \"shift_size must in 0-window_size\"\n", + "\n", + " self.norm1 = norm_layer(dim)\n", + " self.attn = WindowAttention3D(\n", + " dim, window_size=self.window_size, num_heads=num_heads,\n", + " qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)\n", + "\n", + " self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity()\n", + " self.norm2 = norm_layer(dim)\n", + " mlp_hidden_dim = int(dim * mlp_ratio)\n", + " self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)\n", + "\n", + " def forward_part1(self, x, mask_matrix):\n", + " B, D, H, W, C = x.shape\n", + " window_size, shift_size = get_window_size((D, H, W), self.window_size, self.shift_size)\n", + "\n", + " x = self.norm1(x)\n", + " # pad feature maps to multiples of window size\n", + " pad_l = pad_t = pad_d0 = 0\n", + " pad_d1 = (window_size[0] - D % window_size[0]) % window_size[0]\n", + " pad_b = (window_size[1] - H % window_size[1]) % window_size[1]\n", + " pad_r = (window_size[2] - W % window_size[2]) % window_size[2]\n", + " x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b, pad_d0, pad_d1))\n", + " _, Dp, Hp, Wp, _ = x.shape\n", + " # cyclic shift\n", + " if any(i > 0 for i in shift_size):\n", + " shifted_x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1], -shift_size[2]), dims=(1, 2, 3))\n", + " attn_mask = mask_matrix\n", + " else:\n", + " shifted_x = x\n", + " attn_mask = None\n", + " # partition windows\n", + " x_windows = window_partition(shifted_x, window_size) # B*nW, Wd*Wh*Ww, C\n", + " # W-MSA/SW-MSA\n", + " attn_windows = self.attn(x_windows, mask=attn_mask) # B*nW, Wd*Wh*Ww, C\n", + " # merge windows\n", + " attn_windows = attn_windows.view(-1, *(window_size+(C,)))\n", + " shifted_x = window_reverse(attn_windows, window_size, B, Dp, Hp, Wp) # B D' H' W' C\n", + " # reverse cyclic shift\n", + " if any(i > 0 for i in shift_size):\n", + " x = torch.roll(shifted_x, shifts=(shift_size[0], shift_size[1], shift_size[2]), dims=(1, 2, 3))\n", + " else:\n", + " x = shifted_x\n", + "\n", + " if pad_d1 >0 or pad_r > 0 or pad_b > 0:\n", + " x = x[:, :D, :H, :W, :].contiguous()\n", + " return x\n", + "\n", + " def forward_part2(self, x):\n", + " return self.drop_path(self.mlp(self.norm2(x)))\n", + "\n", + " def forward(self, x, mask_matrix):\n", + " \"\"\" Forward function.\n", + "\n", + " Args:\n", + " x: Input feature, tensor size (B, D, H, W, C).\n", + " mask_matrix: Attention mask for cyclic shift.\n", + " \"\"\"\n", + " \n", + " shortcut = x\n", + " if self.use_checkpoint:\n", + " x = checkpoint.checkpoint(self.forward_part1, x, mask_matrix)\n", + " else:\n", + " x = self.forward_part1(x, mask_matrix)\n", + "\n", + " x = shortcut + self.drop_path(x)\n", + "\n", + " if self.use_checkpoint:\n", + " x = x + checkpoint.checkpoint(self.forward_part2, x)\n", + " else:\n", + " x = x + self.forward_part2(x)\n", + "\n", + " return x\n", + " \n", + " \n", + "class PatchMerging(nn.Module):\n", + " \"\"\" Patch Merging Layer\n", + "\n", + " Args:\n", + " dim (int): Number of input channels.\n", + " norm_layer (nn.Module, optional): Normalization layer. 
Default: nn.LayerNorm\n", + " \"\"\"\n", + " def __init__(self, dim, norm_layer=nn.LayerNorm):\n", + " super().__init__()\n", + " self.dim = dim\n", + " self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)\n", + " self.norm = norm_layer(4 * dim)\n", + "\n", + " def forward(self, x):\n", + " \"\"\" Forward function.\n", + "\n", + " Args:\n", + " x: Input feature, tensor size (B, D, H, W, C).\n", + " \"\"\"\n", + " B, D, H, W, C = x.shape\n", + "\n", + " # padding\n", + " pad_input = (H % 2 == 1) or (W % 2 == 1)\n", + " if pad_input:\n", + " x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))\n", + "\n", + " x0 = x[:, :, 0::2, 0::2, :] # B D H/2 W/2 C\n", + " x1 = x[:, :, 1::2, 0::2, :] # B D H/2 W/2 C\n", + " x2 = x[:, :, 0::2, 1::2, :] # B D H/2 W/2 C\n", + " x3 = x[:, :, 1::2, 1::2, :] # B D H/2 W/2 C\n", + " x = torch.cat([x0, x1, x2, x3], -1) # B D H/2 W/2 4*C\n", + "\n", + " x = self.norm(x)\n", + " x = self.reduction(x)\n", + "\n", + " return x\n", + " \n", + " \n", + "def compute_mask(D, H, W, window_size, shift_size, device):\n", + " img_mask = torch.zeros((1, D, H, W, 1), device=device) # 1 Dp Hp Wp 1\n", + " cnt = 0\n", + " for d in slice(-window_size[0]), slice(-window_size[0], -shift_size[0]), slice(-shift_size[0],None):\n", + " for h in slice(-window_size[1]), slice(-window_size[1], -shift_size[1]), slice(-shift_size[1],None):\n", + " for w in slice(-window_size[2]), slice(-window_size[2], -shift_size[2]), slice(-shift_size[2],None):\n", + " img_mask[:, d, h, w, :] = cnt\n", + " cnt += 1\n", + " mask_windows = window_partition(img_mask, window_size) # nW, ws[0]*ws[1]*ws[2], 1\n", + " mask_windows = mask_windows.squeeze(-1) # nW, ws[0]*ws[1]*ws[2]\n", + " attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)\n", + " attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))\n", + " return attn_mask\n", + "\n", + "class BasicLayer(nn.Module):\n", + " \"\"\" A basic Swin Transformer layer for one stage.\n", + "\n", + " Args:\n", + " dim (int): Number of feature channels\n", + " depth (int): Depths of this stage.\n", + " num_heads (int): Number of attention head.\n", + " window_size (tuple[int]): Local window size. Default: (1,7,7).\n", + " mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.\n", + " qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True\n", + " qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.\n", + " drop (float, optional): Dropout rate. Default: 0.0\n", + " attn_drop (float, optional): Attention dropout rate. Default: 0.0\n", + " drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0\n", + " norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm\n", + " downsample (nn.Module | None, optional): Downsample layer at the end of the layer. 
Default: None\n", + " \"\"\"\n", + "\n", + " def __init__(self,\n", + " dim,\n", + " depth,\n", + " num_heads,\n", + " window_size=(1,7,7),\n", + " mlp_ratio=4.,\n", + " qkv_bias=False,\n", + " qk_scale=None,\n", + " drop=0.,\n", + " attn_drop=0.,\n", + " drop_path=0.,\n", + " norm_layer=nn.LayerNorm,\n", + " downsample=None,\n", + " use_checkpoint=False):\n", + " super().__init__()\n", + " self.window_size = window_size\n", + " self.shift_size = tuple(i // 2 for i in window_size)\n", + " self.depth = depth\n", + " self.use_checkpoint = use_checkpoint\n", + "\n", + " # build blocks\n", + " self.blocks = nn.ModuleList([\n", + " SwinTransformerBlock3D(\n", + " dim=dim,\n", + " num_heads=num_heads,\n", + " window_size=window_size,\n", + " shift_size=(0,0,0) if (i % 2 == 0) else self.shift_size,\n", + " mlp_ratio=mlp_ratio,\n", + " qkv_bias=qkv_bias,\n", + " qk_scale=qk_scale,\n", + " drop=drop,\n", + " attn_drop=attn_drop,\n", + " drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,\n", + " norm_layer=norm_layer,\n", + " use_checkpoint=use_checkpoint,\n", + " )\n", + " for i in range(depth)])\n", + " \n", + " self.downsample = downsample\n", + " if self.downsample is not None:\n", + " self.downsample = downsample(dim=dim, norm_layer=norm_layer)\n", + "\n", + " def forward(self, x):\n", + " \"\"\" Forward function.\n", + "\n", + " Args:\n", + " x: Input feature, tensor size (B, C, D, H, W).\n", + " \"\"\"\n", + " # calculate attention mask for SW-MSA\n", + " B, C, D, H, W = x.shape\n", + " window_size, shift_size = get_window_size((D,H,W), self.window_size, self.shift_size)\n", + " x = rearrange(x, 'b c d h w -> b d h w c')\n", + " Dp = int(np.ceil(D / window_size[0])) * window_size[0]\n", + " Hp = int(np.ceil(H / window_size[1])) * window_size[1]\n", + " Wp = int(np.ceil(W / window_size[2])) * window_size[2]\n", + " attn_mask = compute_mask(Dp, Hp, Wp, window_size, shift_size, x.device)\n", + " \n", + "\n", + " for blk in self.blocks:\n", + " x = blk(x, attn_mask)\n", + " x = x.view(B, D, H, W, -1)\n", + "\n", + " if self.downsample is not None:\n", + " x = self.downsample(x)\n", + " \n", + " x = rearrange(x, 'b d h w c -> b c d h w')\n", + " return x\n", + " \n", + "class PatchEmbed3D(nn.Module):\n", + " \"\"\" Video to Patch Embedding.\n", + "\n", + " Args:\n", + " patch_size (int): Patch token size. Default: (2,4,4).\n", + " in_chans (int): Number of input video channels. Default: 3.\n", + " embed_dim (int): Number of linear projection output channels. Default: 96.\n", + " norm_layer (nn.Module, optional): Normalization layer. 
Default: None\n", + " \"\"\"\n", + " def __init__(self, patch_size=(2,4,4), in_chans=3, embed_dim=96, norm_layer=None):\n", + " super().__init__()\n", + " self.patch_size = patch_size\n", + "\n", + " self.in_chans = in_chans\n", + " self.embed_dim = embed_dim\n", + "\n", + " self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)\n", + " if norm_layer is not None:\n", + " self.norm = norm_layer(embed_dim)\n", + " else:\n", + " self.norm = None\n", + "\n", + " def forward(self, x):\n", + " \"\"\"Forward function.\"\"\"\n", + " # padding\n", + " _, _, D, H, W = x.size()\n", + " if W % self.patch_size[2] != 0:\n", + " x = F.pad(x, (0, self.patch_size[2] - W % self.patch_size[2]))\n", + " if H % self.patch_size[1] != 0:\n", + " x = F.pad(x, (0, 0, 0, self.patch_size[1] - H % self.patch_size[1]))\n", + " if D % self.patch_size[0] != 0:\n", + " x = F.pad(x, (0, 0, 0, 0, 0, self.patch_size[0] - D % self.patch_size[0]))\n", + " \n", + " x = self.proj(x) # B C D Wh Ww\n", + "\n", + " if self.norm is not None:\n", + " D, Wh, Ww = x.size(2), x.size(3), x.size(4)\n", + " x = x.flatten(2).transpose(1, 2)\n", + " x = self.norm(x)\n", + " x = x.transpose(1, 2).view(-1, self.embed_dim, D, Wh, Ww)\n", + "\n", + " return x\n", + " \n", + " \n", + "class SwinTransformer3D(nn.Module):\n", + " \"\"\" Swin Transformer backbone.\n", + " A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -\n", + " https://arxiv.org/pdf/2103.14030\n", + "\n", + " Args:\n", + " patch_size (int | tuple(int)): Patch size. Default: (4,4,4).\n", + " in_chans (int): Number of input image channels. Default: 3.\n", + " embed_dim (int): Number of linear projection output channels. Default: 96.\n", + " depths (tuple[int]): Depths of each Swin Transformer stage.\n", + " num_heads (tuple[int]): Number of attention head of each stage.\n", + " window_size (int): Window size. Default: 7.\n", + " mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.\n", + " qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: Truee\n", + " qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.\n", + " drop_rate (float): Dropout rate.\n", + " attn_drop_rate (float): Attention dropout rate. Default: 0.\n", + " drop_path_rate (float): Stochastic depth rate. Default: 0.2.\n", + " norm_layer: Normalization layer. Default: nn.LayerNorm.\n", + " patch_norm (bool): If True, add normalization after patch embedding. 
Default: False.\n", + " frozen_stages (int): Stages to be frozen (stop grad and set eval mode).\n", + " -1 means not freezing any parameters.\n", + " \"\"\"\n", + "\n", + " def __init__(self,\n", + " pretrained=None,\n", + " pretrained2d=True,\n", + " patch_size=(4,4,4),\n", + " in_chans=3,\n", + " embed_dim=96,\n", + " depths=[2, 2, 6, 2],\n", + " num_heads=[3, 6, 12, 24],\n", + " window_size=(2,7,7),\n", + " mlp_ratio=4.,\n", + " qkv_bias=True,\n", + " qk_scale=None,\n", + " drop_rate=0.,\n", + " attn_drop_rate=0.,\n", + " drop_path_rate=0.2,\n", + " norm_layer=nn.LayerNorm,\n", + " patch_norm=False,\n", + " frozen_stages=-1,\n", + " use_checkpoint=False,\n", + " \n", + " # class head\n", + " spatial_type='avg',\n", + " in_channels=768,\n", + " num_classes=400,\n", + " dropout_ratio=0.5 # to do check: no dropout layer in weight state\n", + " ):\n", + " super().__init__()\n", + "\n", + " self.pretrained = pretrained\n", + " self.pretrained2d = pretrained2d\n", + " self.num_layers = len(depths)\n", + " self.embed_dim = embed_dim\n", + " self.patch_norm = patch_norm\n", + " self.frozen_stages = frozen_stages\n", + " self.window_size = window_size\n", + " self.patch_size = patch_size\n", + "\n", + " # split image into non-overlapping patches\n", + " self.patch_embed = PatchEmbed3D(\n", + " patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,\n", + " norm_layer=norm_layer if self.patch_norm else None)\n", + "\n", + " self.pos_drop = nn.Dropout(p=drop_rate)\n", + "\n", + " # stochastic depth\n", + " dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule\n", + "\n", + " # build layers\n", + " self.layers = nn.ModuleList()\n", + " for i_layer in range(self.num_layers):\n", + " layer = BasicLayer(\n", + " dim=int(embed_dim * 2**i_layer),\n", + " depth=depths[i_layer],\n", + " num_heads=num_heads[i_layer],\n", + " window_size=window_size,\n", + " mlp_ratio=mlp_ratio,\n", + " qkv_bias=qkv_bias,\n", + " qk_scale=qk_scale,\n", + " drop=drop_rate,\n", + " attn_drop=attn_drop_rate,\n", + " drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],\n", + " norm_layer=norm_layer,\n", + " downsample=PatchMerging if i_layer= 0:\n", + " self.patch_embed.eval()\n", + " for param in self.patch_embed.parameters():\n", + " param.requires_grad = False\n", + "\n", + " if self.frozen_stages >= 1:\n", + " self.pos_drop.eval()\n", + " for i in range(0, self.frozen_stages):\n", + " m = self.layers[i]\n", + " m.eval()\n", + " for param in m.parameters():\n", + " param.requires_grad = False\n", + "\n", + " def inflate_weights(self, logger):\n", + " \"\"\"Inflate the swin2d parameters to swin3d.\n", + "\n", + " The differences between swin3d and swin2d mainly lie in an extra\n", + " axis. 
To utilize the pretrained parameters in 2d model,\n", + " the weight of swin2d models should be inflated to fit in the shapes of\n", + " the 3d counterpart.\n", + "\n", + " Args:\n", + " logger (logging.Logger): The logger used to print\n", + " debugging infomation.\n", + " \"\"\"\n", + " checkpoint = torch.load(self.pretrained, map_location='cpu')\n", + " state_dict = checkpoint['model']\n", + "\n", + " # delete relative_position_index since we always re-init it\n", + " relative_position_index_keys = [k for k in state_dict.keys() if \"relative_position_index\" in k]\n", + " for k in relative_position_index_keys:\n", + " del state_dict[k]\n", + "\n", + " # delete attn_mask since we always re-init it\n", + " attn_mask_keys = [k for k in state_dict.keys() if \"attn_mask\" in k]\n", + " for k in attn_mask_keys:\n", + " del state_dict[k]\n", + "\n", + " state_dict['patch_embed.proj.weight'] = state_dict['patch_embed.proj.weight'].unsqueeze(2).repeat(1,1,self.patch_size[0],1,1) / self.patch_size[0]\n", + "\n", + " # bicubic interpolate relative_position_bias_table if not match\n", + " relative_position_bias_table_keys = [k for k in state_dict.keys() if \"relative_position_bias_table\" in k]\n", + " for k in relative_position_bias_table_keys:\n", + " relative_position_bias_table_pretrained = state_dict[k]\n", + " relative_position_bias_table_current = self.state_dict()[k]\n", + " L1, nH1 = relative_position_bias_table_pretrained.size()\n", + " L2, nH2 = relative_position_bias_table_current.size()\n", + " L2 = (2*self.window_size[1]-1) * (2*self.window_size[2]-1)\n", + " wd = self.window_size[0]\n", + " if nH1 != nH2:\n", + " logger.warning(f\"Error in loading {k}, passing\")\n", + " else:\n", + " if L1 != L2:\n", + " S1 = int(L1 ** 0.5)\n", + " relative_position_bias_table_pretrained_resized = torch.nn.functional.interpolate(\n", + " relative_position_bias_table_pretrained.permute(1, 0).view(1, nH1, S1, S1), size=(2*self.window_size[1]-1, 2*self.window_size[2]-1),\n", + " mode='bicubic')\n", + " relative_position_bias_table_pretrained = relative_position_bias_table_pretrained_resized.view(nH2, L2).permute(1, 0)\n", + " state_dict[k] = relative_position_bias_table_pretrained.repeat(2*wd-1,1)\n", + "\n", + " msg = self.load_state_dict(state_dict, strict=False)\n", + " logger.info(msg)\n", + " logger.info(f\"=> loaded successfully '{self.pretrained}'\")\n", + " del checkpoint\n", + " torch.cuda.empty_cache()\n", + "\n", + " def init_weights(self, pretrained=None):\n", + " \"\"\"Initialize the weights in backbone.\n", + "\n", + " Args:\n", + " pretrained (str, optional): Path to pre-trained weights.\n", + " Defaults to None.\n", + " \"\"\"\n", + " def _init_weights(m):\n", + " if isinstance(m, nn.Linear):\n", + " trunc_normal_(m.weight, std=.02)\n", + " if isinstance(m, nn.Linear) and m.bias is not None:\n", + " nn.init.constant_(m.bias, 0)\n", + " elif isinstance(m, nn.LayerNorm):\n", + " nn.init.constant_(m.bias, 0)\n", + " nn.init.constant_(m.weight, 1.0)\n", + "\n", + " if pretrained:\n", + " self.pretrained = pretrained\n", + " if isinstance(self.pretrained, str):\n", + " self.apply(_init_weights)\n", + " logger = get_root_logger()\n", + " logger.info(f'load model from: {self.pretrained}')\n", + "\n", + " if self.pretrained2d:\n", + " # Inflate 2D model into 3D model.\n", + " self.inflate_weights(logger)\n", + " else:\n", + " # Directly load 3D model.\n", + " load_checkpoint(self, self.pretrained, strict=False, logger=logger)\n", + " elif self.pretrained is None:\n", + " 
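# no checkpoint path given: truncated-normal init for Linear layers, constant init for LayerNorm\n", + "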
self.apply(_init_weights)\n", + " else:\n", + " raise TypeError('pretrained must be a str or None')\n", + "\n", + " def forward(self, x):\n", + " \"\"\"Forward function.\"\"\"\n", + "\n", + " x = self.patch_embed(x)\n", + " x = self.pos_drop(x)\n", + "\n", + " for layer in self.layers:\n", + " x = layer(x.contiguous())\n", + " \n", + " x = rearrange(x, 'n c d h w -> n d h w c')\n", + " x = self.norm(x)\n", + " x = rearrange(x, 'n d h w c -> n c d h w')\n", + " \n", + " x = self.avg_pool3d(x).squeeze()\n", + " x = self.fc_cls(x)\n", + " return x\n", + " \n", + "\n", + " def train(self, mode=True):\n", + " \"\"\"Convert the model into training mode while keep layers freezed.\"\"\"\n", + " super(SwinTransformer3D, self).train(mode)\n", + " self._freeze_stages()" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "6bba7805", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-31T16:04:57.066377Z", + "iopub.status.busy": "2024-03-31T16:04:57.066007Z", + "iopub.status.idle": "2024-03-31T16:04:57.072137Z", + "shell.execute_reply": "2024-03-31T16:04:57.071055Z" + }, + "jupyter": { + "source_hidden": true + }, + "papermill": { + "duration": 0.017448, + "end_time": "2024-03-31T16:04:57.074285", + "exception": false, + "start_time": "2024-03-31T16:04:57.056837", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "def video_swin_base(window_size=(8,7,7), **kwargs):\n", + " model = SwinTransformer3D(\n", + " patch_size=(2,4,4),\n", + " embed_dim=128,\n", + " depths=[2, 2, 18, 2],\n", + " num_heads=[4, 8, 16, 32],\n", + " window_size=window_size,\n", + " mlp_ratio=4.,\n", + " qkv_bias=True,\n", + " qk_scale=None,\n", + " drop_rate=0.,\n", + " attn_drop_rate=0.,\n", + " drop_path_rate=0.2,\n", + " norm_layer=nn.LayerNorm,\n", + " patch_norm=True,\n", + " in_channels=1024,\n", + " **kwargs\n", + " )\n", + " return model\n" + ] + }, + { + "cell_type": "markdown", + "id": "6aea5774", + "metadata": { + "papermill": { + "duration": 0.007483, + "end_time": "2024-03-31T16:04:57.089632", + "exception": false, + "start_time": "2024-03-31T16:04:57.082149", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "# Video Swin K600 PyTorch" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "b1e6da82", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-31T16:04:57.106762Z", + "iopub.status.busy": "2024-03-31T16:04:57.106404Z", + "iopub.status.idle": "2024-03-31T16:04:58.378096Z", + "shell.execute_reply": "2024-03-31T16:04:58.377033Z" + }, + "papermill": { + "duration": 1.283258, + "end_time": "2024-03-31T16:04:58.380748", + "exception": false, + "start_time": "2024-03-31T16:04:57.097490", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/opt/conda/lib/python3.10/site-packages/torch/functional.py:504: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. 
(Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3526.)\n", + " return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined]\n" + ] + } + ], + "source": [ + "model_pt = video_swin_base(\n", + " window_size=(8,7,7), num_classes=600\n", + ")\n", + "model_pt.eval()" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "715f2fbd", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-31T16:04:58.398607Z", + "iopub.status.busy": "2024-03-31T16:04:58.397845Z", + "iopub.status.idle": "2024-03-31T16:05:08.659205Z", + "shell.execute_reply": "2024-03-31T16:05:08.658191Z" + }, + "papermill": { + "duration": 10.273307, + "end_time": "2024-03-31T16:05:08.662022", + "exception": false, + "start_time": "2024-03-31T16:04:58.388715", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/opt/conda/lib/python3.10/pty.py:89: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n", + " pid, fd = os.forkpty()\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "--2024-03-31 16:04:59-- https://github.com/SwinTransformer/storage/releases/download/v1.0.4/swin_base_patch244_window877_kinetics600_22k.pth\r\n", + "Resolving github.com (github.com)... 20.27.177.113\r\n", + "Connecting to github.com (github.com)|20.27.177.113|:443... connected.\r\n", + "HTTP request sent, awaiting response... 302 Found\r\n", + "Location: https://objects.githubusercontent.com/github-production-release-asset-2e65be/357198522/099f2980-d55e-11eb-8848-6616f5f65526?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240331%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240331T160459Z&X-Amz-Expires=300&X-Amz-Signature=183af2b42b064ab70553197314b19c2cd1f7ca3f36da0171fdc7f489de5bcbd2&X-Amz-SignedHeaders=host&actor_id=0&key_id=0&repo_id=357198522&response-content-disposition=attachment%3B%20filename%3Dswin_base_patch244_window877_kinetics600_22k.pth&response-content-type=application%2Foctet-stream [following]\r\n", + "--2024-03-31 16:04:59-- https://objects.githubusercontent.com/github-production-release-asset-2e65be/357198522/099f2980-d55e-11eb-8848-6616f5f65526?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240331%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240331T160459Z&X-Amz-Expires=300&X-Amz-Signature=183af2b42b064ab70553197314b19c2cd1f7ca3f36da0171fdc7f489de5bcbd2&X-Amz-SignedHeaders=host&actor_id=0&key_id=0&repo_id=357198522&response-content-disposition=attachment%3B%20filename%3Dswin_base_patch244_window877_kinetics600_22k.pth&response-content-type=application%2Foctet-stream\r\n", + "Resolving objects.githubusercontent.com (objects.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...\r\n", + "Connecting to objects.githubusercontent.com (objects.githubusercontent.com)|185.199.108.133|:443... connected.\r\n", + "HTTP request sent, awaiting response... 
200 OK\r\n", + "Length: 382579368 (365M) [application/octet-stream]\r\n", + "Saving to: 'checkpoint.pt'\r\n", + "\r\n", + "checkpoint.pt 100%[===================>] 364.86M 47.3MB/s in 7.8s \r\n", + "\r\n", + "2024-03-31 16:05:08 (47.0 MB/s) - 'checkpoint.pt' saved [382579368/382579368]\r\n", + "\r\n" + ] + } + ], + "source": [ + "base_url = \"https://github.com/SwinTransformer/storage/releases/download/v1.0.4/\"\n", + "checkpoints_pt = f\"{base_url}swin_base_patch244_window877_kinetics600_22k.pth\"\n", + "!wget {checkpoints_pt} -O checkpoint.pt" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "641ac80a", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-31T16:05:08.688298Z", + "iopub.status.busy": "2024-03-31T16:05:08.687277Z", + "iopub.status.idle": "2024-03-31T16:05:08.989678Z", + "shell.execute_reply": "2024-03-31T16:05:08.988889Z" + }, + "papermill": { + "duration": 0.317536, + "end_time": "2024-03-31T16:05:08.991999", + "exception": false, + "start_time": "2024-03-31T16:05:08.674463", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "state_dict = torch.load(\n", + " 'checkpoint.pt', map_location=\"cpu\"\n", + ")\n", + "state_dict = state_dict['state_dict']\n", + "state_dict = {k.replace('backbone.', ''): v for k, v in state_dict.items()}\n", + "state_dict = {k.replace('cls_head.', ''): v for k, v in state_dict.items()}" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "8cc7b92e", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-31T16:05:09.017187Z", + "iopub.status.busy": "2024-03-31T16:05:09.015048Z", + "iopub.status.idle": "2024-03-31T16:05:09.147071Z", + "shell.execute_reply": "2024-03-31T16:05:09.146018Z" + }, + "papermill": { + "duration": 0.145924, + "end_time": "2024-03-31T16:05:09.149026", + "exception": false, + "start_time": "2024-03-31T16:05:09.003102", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model_pt.load_state_dict(state_dict, strict=False)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e8d96566", + "metadata": { + "papermill": { + "duration": 0.010665, + "end_time": "2024-03-31T16:05:09.170511", + "exception": false, + "start_time": "2024-03-31T16:05:09.159846", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "id": "704927d6", + "metadata": { + "papermill": { + "duration": 0.010737, + "end_time": "2024-03-31T16:05:09.192102", + "exception": false, + "start_time": "2024-03-31T16:05:09.181365", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "# Video Swin K600 [Keras CV]" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "f0e0a7f8", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-31T16:05:09.215666Z", + "iopub.status.busy": "2024-03-31T16:05:09.214689Z", + "iopub.status.idle": "2024-03-31T16:05:09.218817Z", + "shell.execute_reply": "2024-03-31T16:05:09.218152Z" + }, + "papermill": { + "duration": 0.018009, + "end_time": "2024-03-31T16:05:09.220824", + "exception": false, + "start_time": "2024-03-31T16:05:09.202815", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "import os\n", + "os.environ[\"KERAS_BACKEND\"] = \"torch\"" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "92998234", + "metadata": { + 
"execution": { + "iopub.execute_input": "2024-03-31T16:05:09.244982Z", + "iopub.status.busy": "2024-03-31T16:05:09.244117Z", + "iopub.status.idle": "2024-03-31T16:05:40.168044Z", + "shell.execute_reply": "2024-03-31T16:05:40.166654Z" + }, + "papermill": { + "duration": 30.938826, + "end_time": "2024-03-31T16:05:40.170636", + "exception": false, + "start_time": "2024-03-31T16:05:09.231810", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Cloning into 'keras-cv'...\r\n", + "remote: Enumerating objects: 13766, done.\u001b[K\r\n", + "remote: Counting objects: 100% (1903/1903), done.\u001b[K\r\n", + "remote: Compressing objects: 100% (770/770), done.\u001b[K\r\n", + "remote: Total 13766 (delta 1322), reused 1606 (delta 1117), pack-reused 11863\u001b[K\r\n", + "Receiving objects: 100% (13766/13766), 25.65 MiB | 21.65 MiB/s, done.\r\n", + "Resolving deltas: 100% (9767/9767), done.\r\n", + "/kaggle/working/keras-cv\n" + ] + } + ], + "source": [ + "!git clone --branch video_swin https://github.com/innat/keras-cv.git\n", + "%cd keras-cv\n", + "!pip install -q -e ." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "d3831628", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-31T16:05:40.198732Z", + "iopub.status.busy": "2024-03-31T16:05:40.198351Z", + "iopub.status.idle": "2024-03-31T16:05:56.094959Z", + "shell.execute_reply": "2024-03-31T16:05:56.093869Z" + }, + "papermill": { + "duration": 15.913171, + "end_time": "2024-03-31T16:05:56.097127", + "exception": false, + "start_time": "2024-03-31T16:05:40.183956", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-03-31 16:05:43.821248: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", + "2024-03-31 16:05:43.821374: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", + "2024-03-31 16:05:43.983458: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n" + ] + }, + { + "data": { + "text/plain": [ + "'3.0.5'" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import keras\n", + "from keras import ops\n", + "from keras_cv.models import VideoSwinBackbone\n", + "from keras_cv.models import VideoClassifier\n", + "keras.__version__" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "90c1cfde", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-31T16:05:56.125268Z", + "iopub.status.busy": "2024-03-31T16:05:56.124578Z", + "iopub.status.idle": "2024-03-31T16:06:05.310596Z", + "shell.execute_reply": "2024-03-31T16:06:05.309379Z" + }, + "papermill": { + "duration": 9.202803, + "end_time": "2024-03-31T16:06:05.313216", + "exception": false, + "start_time": "2024-03-31T16:05:56.110413", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/opt/conda/lib/python3.10/pty.py:89: RuntimeWarning: os.fork() was called. 
os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n", + " pid, fd = os.forkpty()\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "--2024-03-31 16:05:57-- https://github.com/innat/VideoSwin/releases/download/v2.0/videoswin_base_kinetics600_imagenet22k_classifier.weights.h5\r\n", + "Resolving github.com (github.com)... 20.27.177.113\r\n", + "Connecting to github.com (github.com)|20.27.177.113|:443... connected.\r\n", + "HTTP request sent, awaiting response... 302 Found\r\n", + "Location: https://objects.githubusercontent.com/github-production-release-asset-2e65be/697696973/7d830aea-e75f-4b5c-a44f-9395b90cd47a?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240331%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240331T160557Z&X-Amz-Expires=300&X-Amz-Signature=e8d60d912fe3947dd7f0016cbe19d9f1fe549251dae5955fc13640a364c6504b&X-Amz-SignedHeaders=host&actor_id=0&key_id=0&repo_id=697696973&response-content-disposition=attachment%3B%20filename%3Dvideoswin_base_kinetics600_imagenet22k_classifier.weights.h5&response-content-type=application%2Foctet-stream [following]\r\n", + "--2024-03-31 16:05:57-- https://objects.githubusercontent.com/github-production-release-asset-2e65be/697696973/7d830aea-e75f-4b5c-a44f-9395b90cd47a?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240331%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240331T160557Z&X-Amz-Expires=300&X-Amz-Signature=e8d60d912fe3947dd7f0016cbe19d9f1fe549251dae5955fc13640a364c6504b&X-Amz-SignedHeaders=host&actor_id=0&key_id=0&repo_id=697696973&response-content-disposition=attachment%3B%20filename%3Dvideoswin_base_kinetics600_imagenet22k_classifier.weights.h5&response-content-type=application%2Foctet-stream\r\n", + "Resolving objects.githubusercontent.com (objects.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...\r\n", + "Connecting to objects.githubusercontent.com (objects.githubusercontent.com)|185.199.108.133|:443... connected.\r\n", + "HTTP request sent, awaiting response... 
200 OK\r\n", + "Length: 353850608 (337M) [application/octet-stream]\r\n", + "Saving to: 'videoswin_base_kinetics600_imagenet22k_classifier.weights.h5'\r\n", + "\r\n", + "videoswin_base_kine 100%[===================>] 337.46M 50.7MB/s in 6.7s \r\n", + "\r\n", + "2024-03-31 16:06:05 (50.3 MB/s) - 'videoswin_base_kinetics600_imagenet22k_classifier.weights.h5' saved [353850608/353850608]\r\n", + "\r\n" + ] + } + ], + "source": [ + "!wget https://github.com/innat/VideoSwin/releases/download/v2.0/videoswin_base_kinetics600_imagenet22k_classifier.weights.h5" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "0c306263", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-31T16:06:05.346765Z", + "iopub.status.busy": "2024-03-31T16:06:05.346082Z", + "iopub.status.idle": "2024-03-31T16:06:05.352536Z", + "shell.execute_reply": "2024-03-31T16:06:05.351649Z" + }, + "papermill": { + "duration": 0.025676, + "end_time": "2024-03-31T16:06:05.354666", + "exception": false, + "start_time": "2024-03-31T16:06:05.328990", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "def vswin_base():\n", + " backbone=VideoSwinBackbone(\n", + " input_shape=(32, 224, 224, 3), \n", + " embed_dim=128,\n", + " depths=[2, 2, 18, 2],\n", + " num_heads=[4, 8, 16, 32],\n", + " include_rescaling=False, \n", + " )\n", + " keras_model = VideoClassifier(\n", + " backbone=backbone,\n", + " num_classes=600,\n", + " activation=None,\n", + " pooling='avg',\n", + " )\n", + " keras_model.load_weights(\n", + " 'videoswin_base_kinetics600_imagenet22k_classifier.weights.h5'\n", + " )\n", + " return keras_model\n" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "e6004862", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-31T16:06:05.386931Z", + "iopub.status.busy": "2024-03-31T16:06:05.386544Z", + "iopub.status.idle": "2024-03-31T16:06:08.324183Z", + "shell.execute_reply": "2024-03-31T16:06:08.323201Z" + }, + "papermill": { + "duration": 2.957008, + "end_time": "2024-03-31T16:06:08.326923", + "exception": false, + "start_time": "2024-03-31T16:06:05.369915", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "model_ks = vswin_base()" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "9fa1b05a", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-31T16:06:08.359256Z", + "iopub.status.busy": "2024-03-31T16:06:08.358882Z", + "iopub.status.idle": "2024-03-31T16:06:08.368428Z", + "shell.execute_reply": "2024-03-31T16:06:08.367371Z" + }, + "papermill": { + "duration": 0.028719, + "end_time": "2024-03-31T16:06:08.370959", + "exception": false, + "start_time": "2024-03-31T16:06:08.342240", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "PyTorch, number of params (M): 88.25\n", + "Keras, number of params (M): 88.25\n" + ] + } + ], + "source": [ + "n_parameters = sum(p.numel() for p in model_pt.parameters() if p.requires_grad)\n", + "print(\"PyTorch, number of params (M): %.2f\" % (n_parameters / 1.0e6))\n", + "n_parameters = model_ks.count_params()\n", + "print(\"Keras, number of params (M): %.2f\" % (n_parameters / 1.0e6))" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "688e7664", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-31T16:06:08.403342Z", + "iopub.status.busy": "2024-03-31T16:06:08.402987Z", + "iopub.status.idle": "2024-03-31T16:06:08.586975Z", + 
"shell.execute_reply": "2024-03-31T16:06:08.585816Z" + }, + "papermill": { + "duration": 0.203142, + "end_time": "2024-03-31T16:06:08.589445", + "exception": false, + "start_time": "2024-03-31T16:06:08.386303", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([1, 32, 224, 224, 3]) torch.Size([1, 3, 32, 224, 224])\n" + ] + } + ], + "source": [ + "common_input = np.random.normal(0, 1, (1, 32, 224, 224, 3)).astype('float32')\n", + "keras_input = ops.array(common_input)\n", + "torch_input = torch.from_numpy(common_input.transpose(0, 4, 1, 2, 3))\n", + "print(keras_input.shape, torch_input.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "be091dad", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-31T16:06:08.626529Z", + "iopub.status.busy": "2024-03-31T16:06:08.625742Z", + "iopub.status.idle": "2024-03-31T16:06:08.632827Z", + "shell.execute_reply": "2024-03-31T16:06:08.631746Z" + }, + "papermill": { + "duration": 0.030273, + "end_time": "2024-03-31T16:06:08.636166", + "exception": false, + "start_time": "2024-03-31T16:06:08.605893", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "def logit_checking(keras_model, torch_model):\n", + " # forward pass\n", + " keras_predict = keras_model(keras_input)\n", + " torch_predict = torch_model(torch_input)[None, ...]\n", + " print(keras_predict.shape, torch_predict.shape)\n", + " print('keras logits: ', keras_predict[0, :5])\n", + " print('torch logits: ', torch_predict[0, :5], end='\\n')\n", + " np.testing.assert_allclose(\n", + " keras_predict.detach().numpy(),\n", + " torch_predict.detach().numpy(),\n", + " 1e-5, 1e-5\n", + " )\n", + " del keras_model \n", + " del torch_model" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "a1910940", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-31T16:06:08.671093Z", + "iopub.status.busy": "2024-03-31T16:06:08.670458Z", + "iopub.status.idle": "2024-03-31T16:06:43.632561Z", + "shell.execute_reply": "2024-03-31T16:06:43.631530Z" + }, + "papermill": { + "duration": 34.98197, + "end_time": "2024-03-31T16:06:43.635238", + "exception": false, + "start_time": "2024-03-31T16:06:08.653268", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([1, 600]) torch.Size([1, 600])\n", + "keras logits: tensor([ 0.6555, -0.0984, -0.3937, 1.3615, -1.1496], grad_fn=)\n", + "torch logits: tensor([ 0.6555, -0.0984, -0.3937, 1.3615, -1.1496], grad_fn=)\n" + ] + } + ], + "source": [ + "logit_checking(\n", + " model_ks, model_pt\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "72e3c73c", + "metadata": { + "papermill": { + "duration": 0.015282, + "end_time": "2024-03-31T16:06:43.666082", + "exception": false, + "start_time": "2024-03-31T16:06:43.650800", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "id": "14ade18d", + "metadata": { + "papermill": { + "duration": 0.015046, + "end_time": "2024-03-31T16:06:43.696550", + "exception": false, + "start_time": "2024-03-31T16:06:43.681504", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "# Something Something V2 PyTorch" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "7cfa33a2", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-31T16:06:43.729399Z", + 
"iopub.status.busy": "2024-03-31T16:06:43.728586Z", + "iopub.status.idle": "2024-03-31T16:06:44.777480Z", + "shell.execute_reply": "2024-03-31T16:06:44.776538Z" + }, + "papermill": { + "duration": 1.068072, + "end_time": "2024-03-31T16:06:44.780085", + "exception": false, + "start_time": "2024-03-31T16:06:43.712013", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "model_pt = video_swin_base(\n", + " window_size=(16,7,7), num_classes=174\n", + ")\n", + "model_pt.eval()" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "ef8184b1", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-31T16:06:44.813809Z", + "iopub.status.busy": "2024-03-31T16:06:44.812698Z", + "iopub.status.idle": "2024-03-31T16:06:52.231172Z", + "shell.execute_reply": "2024-03-31T16:06:52.229682Z" + }, + "papermill": { + "duration": 7.438448, + "end_time": "2024-03-31T16:06:52.233971", + "exception": false, + "start_time": "2024-03-31T16:06:44.795523", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/opt/conda/lib/python3.10/pty.py:89: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n", + " pid, fd = os.forkpty()\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "--2024-03-31 16:06:45-- https://github.com/SwinTransformer/storage/releases/download/v1.0.4/swin_base_patch244_window1677_sthv2.pth\r\n", + "Resolving github.com (github.com)... 20.27.177.113\r\n", + "Connecting to github.com (github.com)|20.27.177.113|:443... connected.\r\n", + "HTTP request sent, awaiting response... 302 Found\r\n", + "Location: https://objects.githubusercontent.com/github-production-release-asset-2e65be/357198522/20458080-d55e-11eb-9021-4730e624e0ea?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240331%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240331T160646Z&X-Amz-Expires=300&X-Amz-Signature=db24bcf8e1de48dcb05cb9484a45ba7bfd26e989573c3b6d9c2a777a13e8e75f&X-Amz-SignedHeaders=host&actor_id=0&key_id=0&repo_id=357198522&response-content-disposition=attachment%3B%20filename%3Dswin_base_patch244_window1677_sthv2.pth&response-content-type=application%2Foctet-stream [following]\r\n", + "--2024-03-31 16:06:46-- https://objects.githubusercontent.com/github-production-release-asset-2e65be/357198522/20458080-d55e-11eb-9021-4730e624e0ea?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240331%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240331T160646Z&X-Amz-Expires=300&X-Amz-Signature=db24bcf8e1de48dcb05cb9484a45ba7bfd26e989573c3b6d9c2a777a13e8e75f&X-Amz-SignedHeaders=host&actor_id=0&key_id=0&repo_id=357198522&response-content-disposition=attachment%3B%20filename%3Dswin_base_patch244_window1677_sthv2.pth&response-content-type=application%2Foctet-stream\r\n", + "Resolving objects.githubusercontent.com (objects.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...\r\n", + "Connecting to objects.githubusercontent.com (objects.githubusercontent.com)|185.199.108.133|:443... connected.\r\n", + "HTTP request sent, awaiting response... 
200 OK\r\n", + "Length: 473410081 (451M) [application/octet-stream]\r\n", + "Saving to: 'checkpoint.pt'\r\n", + "\r\n", + "checkpoint.pt 100%[===================>] 451.48M 99.7MB/s in 4.8s \r\n", + "\r\n", + "2024-03-31 16:06:52 (94.3 MB/s) - 'checkpoint.pt' saved [473410081/473410081]\r\n", + "\r\n" + ] + } + ], + "source": [ + "base_url = \"https://github.com/SwinTransformer/storage/releases/download/v1.0.4/\"\n", + "checkpoints_pt = f\"{base_url}swin_base_patch244_window1677_sthv2.pth\"\n", + "!wget {checkpoints_pt} -O checkpoint.pt" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "031319c7", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-31T16:06:52.271026Z", + "iopub.status.busy": "2024-03-31T16:06:52.270577Z", + "iopub.status.idle": "2024-03-31T16:06:52.520063Z", + "shell.execute_reply": "2024-03-31T16:06:52.519136Z" + }, + "papermill": { + "duration": 0.271107, + "end_time": "2024-03-31T16:06:52.522630", + "exception": false, + "start_time": "2024-03-31T16:06:52.251523", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "state_dict = torch.load(\n", + " 'checkpoint.pt', map_location=\"cpu\"\n", + ")\n", + "state_dict = state_dict['state_dict']\n", + "state_dict = {k.replace('backbone.', ''): v for k, v in state_dict.items()}\n", + "state_dict = {k.replace('cls_head.', ''): v for k, v in state_dict.items()}" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "5b5d4dab", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-31T16:06:52.560500Z", + "iopub.status.busy": "2024-03-31T16:06:52.560135Z", + "iopub.status.idle": "2024-03-31T16:06:52.708753Z", + "shell.execute_reply": "2024-03-31T16:06:52.707666Z" + }, + "papermill": { + "duration": 0.170469, + "end_time": "2024-03-31T16:06:52.711087", + "exception": false, + "start_time": "2024-03-31T16:06:52.540618", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model_pt.load_state_dict(state_dict)" + ] + }, + { + "cell_type": "markdown", + "id": "5cc60f70", + "metadata": { + "papermill": { + "duration": 0.018311, + "end_time": "2024-03-31T16:06:52.747151", + "exception": false, + "start_time": "2024-03-31T16:06:52.728840", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "# Something Somethinb V2 KerasCV" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "88e869a1", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-31T16:06:52.848487Z", + "iopub.status.busy": "2024-03-31T16:06:52.848101Z", + "iopub.status.idle": "2024-03-31T16:07:02.546608Z", + "shell.execute_reply": "2024-03-31T16:07:02.545083Z" + }, + "papermill": { + "duration": 9.784241, + "end_time": "2024-03-31T16:07:02.549171", + "exception": false, + "start_time": "2024-03-31T16:06:52.764930", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "--2024-03-31 16:06:53-- https://github.com/innat/VideoSwin/releases/download/v2.0/videoswin_base_something_something_v2_classifier.weights.h5\r\n", + "Resolving github.com (github.com)... 20.27.177.113\r\n", + "Connecting to github.com (github.com)|20.27.177.113|:443... connected.\r\n", + "HTTP request sent, awaiting response... 
302 Found\r\n", + "Location: https://objects.githubusercontent.com/github-production-release-asset-2e65be/697696973/2ed443ee-14b8-4243-b5fa-baaf6067e22e?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240331%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240331T160654Z&X-Amz-Expires=300&X-Amz-Signature=bf014737d4edf44b89e7beb5796f36b25954606db5f7c05b379dfbaa6c2bd505&X-Amz-SignedHeaders=host&actor_id=0&key_id=0&repo_id=697696973&response-content-disposition=attachment%3B%20filename%3Dvideoswin_base_something_something_v2_classifier.weights.h5&response-content-type=application%2Foctet-stream [following]\r\n", + "--2024-03-31 16:06:54-- https://objects.githubusercontent.com/github-production-release-asset-2e65be/697696973/2ed443ee-14b8-4243-b5fa-baaf6067e22e?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240331%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240331T160654Z&X-Amz-Expires=300&X-Amz-Signature=bf014737d4edf44b89e7beb5796f36b25954606db5f7c05b379dfbaa6c2bd505&X-Amz-SignedHeaders=host&actor_id=0&key_id=0&repo_id=697696973&response-content-disposition=attachment%3B%20filename%3Dvideoswin_base_something_something_v2_classifier.weights.h5&response-content-type=application%2Foctet-stream\r\n", + "Resolving objects.githubusercontent.com (objects.githubusercontent.com)... 185.199.111.133, 185.199.110.133, 185.199.109.133, ...\r\n", + "Connecting to objects.githubusercontent.com (objects.githubusercontent.com)|185.199.111.133|:443... connected.\r\n", + "HTTP request sent, awaiting response... 200 OK\r\n", + "Length: 356168352 (340M) [application/octet-stream]\r\n", + "Saving to: 'videoswin_base_something_something_v2_classifier.weights.h5'\r\n", + "\r\n", + "videoswin_base_some 100%[===================>] 339.67M 48.4MB/s in 7.0s \r\n", + "\r\n", + "2024-03-31 16:07:02 (48.3 MB/s) - 'videoswin_base_something_something_v2_classifier.weights.h5' saved [356168352/356168352]\r\n", + "\r\n" + ] + } + ], + "source": [ + "!wget https://github.com/innat/VideoSwin/releases/download/v2.0/videoswin_base_something_something_v2_classifier.weights.h5\n", + "\n", + "def vswin_base():\n", + " backbone=VideoSwinBackbone(\n", + " input_shape=(32, 224, 224, 3), \n", + " embed_dim=128,\n", + " depths=[2, 2, 18, 2],\n", + " num_heads=[4, 8, 16, 32],\n", + " window_size=[16, 7, 7],\n", + " include_rescaling=False, \n", + " )\n", + " keras_model = VideoClassifier(\n", + " backbone=backbone,\n", + " num_classes=174,\n", + " activation=None,\n", + " pooling='avg',\n", + " )\n", + " keras_model.load_weights(\n", + " 'videoswin_base_something_something_v2_classifier.weights.h5'\n", + " )\n", + " return keras_model" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "cb2ecf5e", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-31T16:07:02.592931Z", + "iopub.status.busy": "2024-03-31T16:07:02.592216Z", + "iopub.status.idle": "2024-03-31T16:07:05.652536Z", + "shell.execute_reply": "2024-03-31T16:07:05.651238Z" + }, + "papermill": { + "duration": 3.085886, + "end_time": "2024-03-31T16:07:05.655498", + "exception": false, + "start_time": "2024-03-31T16:07:02.569612", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "model_ks = vswin_base()" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "d986e517", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-31T16:07:07.667631Z", + "iopub.status.busy": "2024-03-31T16:07:07.666391Z", + "iopub.status.idle": 
"2024-03-31T16:07:51.632549Z", + "shell.execute_reply": "2024-03-31T16:07:51.631546Z" + }, + "papermill": { + "duration": 44.05175, + "end_time": "2024-03-31T16:07:51.635161", + "exception": false, + "start_time": "2024-03-31T16:07:07.583411", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([1, 174]) torch.Size([1, 174])\n", + "keras logits: tensor([-0.9237, 0.6242, -0.2347, -0.6530, -0.7699], grad_fn=)\n", + "torch logits: tensor([-0.9237, 0.6242, -0.2347, -0.6530, -0.7699], grad_fn=)\n" + ] + } + ], + "source": [ + "logit_checking(\n", + " model_ks, model_pt\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1d6bfa12", + "metadata": { + "papermill": { + "duration": 0.019954, + "end_time": "2024-03-31T16:07:51.675491", + "exception": false, + "start_time": "2024-03-31T16:07:51.655537", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kaggle": { + "accelerator": "none", + "dataSources": [ + { + "isSourceIdPinned": true, + "modelInstanceId": 17533, + "sourceId": 21184, + "sourceType": "modelInstanceVersion" + } + ], + "dockerImageVersionId": 30673, + "isGpuEnabled": false, + "isInternetEnabled": true, + "language": "python", + "sourceType": "notebook" + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + }, + "papermill": { + "default_parameters": {}, + "duration": 202.097042, + "end_time": "2024-03-31T16:07:55.574077", + "environment_variables": {}, + "exception": null, + "input_path": "__notebook__.ipynb", + "output_path": "__notebook__.ipynb", + "parameters": {}, + "start_time": "2024-03-31T16:04:33.477035", + "version": "2.5.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/guides/kerascv-kinetics-400-evaluation-in-pytorch.ipynb b/guides/kerascv-kinetics-400-evaluation-in-pytorch.ipynb new file mode 100644 index 0000000..616640b --- /dev/null +++ b/guides/kerascv-kinetics-400-evaluation-in-pytorch.ipynb @@ -0,0 +1,1624 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "1eaf876c", + "metadata": { + "_cell_guid": "b1076dfc-b9ad-4769-8c92-a6c4dae69d19", + "_uuid": "8f2839f25d086af736a60e9eeb907d3b93b6e0e5", + "execution": { + "iopub.execute_input": "2024-04-03T06:51:23.922069Z", + "iopub.status.busy": "2024-04-03T06:51:23.921347Z", + "iopub.status.idle": "2024-04-03T06:51:38.385947Z", + "shell.execute_reply": "2024-04-03T06:51:38.384845Z" + }, + "papermill": { + "duration": 14.478032, + "end_time": "2024-04-03T06:51:38.388525", + "exception": false, + "start_time": "2024-04-03T06:51:23.910493", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "!pip install decord -q" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "4dbec267", + "metadata": { + "execution": { + "iopub.execute_input": "2024-04-03T06:51:38.408525Z", + "iopub.status.busy": "2024-04-03T06:51:38.408119Z", + "iopub.status.idle": "2024-04-03T06:52:09.447505Z", + "shell.execute_reply": "2024-04-03T06:52:09.446420Z" + }, + "papermill": { + "duration": 31.052127, + "end_time": "2024-04-03T06:52:09.449916", + "exception": false, + "start_time": 
"2024-04-03T06:51:38.397789", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Cloning into 'keras-cv'...\r\n", + "remote: Enumerating objects: 13782, done.\u001b[K\r\n", + "remote: Counting objects: 100% (1919/1919), done.\u001b[K\r\n", + "remote: Compressing objects: 100% (769/769), done.\u001b[K\r\n", + "remote: Total 13782 (delta 1337), reused 1628 (delta 1134), pack-reused 11863\u001b[K\r\n", + "Receiving objects: 100% (13782/13782), 25.65 MiB | 27.53 MiB/s, done.\r\n", + "Resolving deltas: 100% (9788/9788), done.\r\n", + "/kaggle/working/keras-cv\n" + ] + } + ], + "source": [ + "!git clone --branch video_swin https://github.com/innat/keras-cv.git\n", + "%cd keras-cv\n", + "!pip install -q -e ." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "1d893c5b", + "metadata": { + "execution": { + "iopub.execute_input": "2024-04-03T06:52:09.472818Z", + "iopub.status.busy": "2024-04-03T06:52:09.472498Z", + "iopub.status.idle": "2024-04-03T06:52:09.477540Z", + "shell.execute_reply": "2024-04-03T06:52:09.476554Z" + }, + "papermill": { + "duration": 0.018822, + "end_time": "2024-04-03T06:52:09.479574", + "exception": false, + "start_time": "2024-04-03T06:52:09.460752", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "import os, warnings\n", + "os.environ[\"KERAS_BACKEND\"] = \"torch\" \n", + "warnings.simplefilter(action=\"ignore\")\n", + "os.environ[\"TF_CPP_MIN_LOG_LEVEL\"] = \"3\"" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "ef3dd5d0", + "metadata": { + "execution": { + "iopub.execute_input": "2024-04-03T06:52:09.502723Z", + "iopub.status.busy": "2024-04-03T06:52:09.502106Z", + "iopub.status.idle": "2024-04-03T06:52:30.729858Z", + "shell.execute_reply": "2024-04-03T06:52:30.728816Z" + }, + "papermill": { + "duration": 21.241686, + "end_time": "2024-04-03T06:52:30.732079", + "exception": false, + "start_time": "2024-04-03T06:52:09.490393", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "('3.0.5', '2.1.2')" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import numpy as np\n", + "import pandas as pd\n", + "import os, sys\n", + "import cv2\n", + "\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "from tqdm import tqdm\n", + "from decord import VideoReader\n", + "from decord import cpu, gpu\n", + "from torch.utils.data import Dataset, DataLoader\n", + "\n", + "import keras\n", + "from keras import ops\n", + "from keras_cv.models import VideoSwinBackbone\n", + "from keras_cv.models import VideoClassifier\n", + "\n", + "keras.__version__, torch.__version__" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "156d3160", + "metadata": { + "execution": { + "iopub.execute_input": "2024-04-03T06:52:30.754915Z", + "iopub.status.busy": "2024-04-03T06:52:30.754366Z", + "iopub.status.idle": "2024-04-03T06:52:30.760351Z", + "shell.execute_reply": "2024-04-03T06:52:30.759474Z" + }, + "papermill": { + "duration": 0.019592, + "end_time": "2024-04-03T06:52:30.762442", + "exception": false, + "start_time": "2024-04-03T06:52:30.742850", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "torch.set_grad_enabled(False)" + ] + }, + { + 
"cell_type": "markdown", + "id": "c58e1f41", + "metadata": { + "papermill": { + "duration": 0.010719, + "end_time": "2024-04-03T06:52:30.784957", + "exception": false, + "start_time": "2024-04-03T06:52:30.774238", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "# Data Set" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "57a701bf", + "metadata": { + "execution": { + "iopub.execute_input": "2024-04-03T06:52:30.807733Z", + "iopub.status.busy": "2024-04-03T06:52:30.807420Z", + "iopub.status.idle": "2024-04-03T06:52:31.980326Z", + "shell.execute_reply": "2024-04-03T06:52:31.979189Z" + }, + "papermill": { + "duration": 1.186849, + "end_time": "2024-04-03T06:52:31.982494", + "exception": false, + "start_time": "2024-04-03T06:52:30.795645", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
idname
00abseiling
11air drumming
22answering questions
33applauding
44applying cream
\n", + "
" + ], + "text/plain": [ + " id name\n", + "0 0 abseiling\n", + "1 1 air drumming\n", + "2 2 answering questions\n", + "3 3 applauding\n", + "4 4 applying cream" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "!wget https://raw.githubusercontent.com/innat/VideoSwin/main/data/kinetics_400_labels.csv -q\n", + "labels = pd.read_csv('kinetics_400_labels.csv')\n", + "labels.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "8d4f3c4c", + "metadata": { + "execution": { + "iopub.execute_input": "2024-04-03T06:52:32.007009Z", + "iopub.status.busy": "2024-04-03T06:52:32.006111Z", + "iopub.status.idle": "2024-04-03T06:52:32.012778Z", + "shell.execute_reply": "2024-04-03T06:52:32.011933Z" + }, + "papermill": { + "duration": 0.02052, + "end_time": "2024-04-03T06:52:32.014633", + "exception": false, + "start_time": "2024-04-03T06:52:31.994113", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "id2label = dict(zip(labels.id.tolist(), labels.name.tolist()))\n", + "label2id = dict(zip(labels.name.tolist(), labels.id.tolist()))" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "c43413c5", + "metadata": { + "execution": { + "iopub.execute_input": "2024-04-03T06:52:32.038344Z", + "iopub.status.busy": "2024-04-03T06:52:32.037721Z", + "iopub.status.idle": "2024-04-03T06:53:51.132927Z", + "shell.execute_reply": "2024-04-03T06:53:51.131886Z" + }, + "papermill": { + "duration": 79.144924, + "end_time": "2024-04-03T06:53:51.170549", + "exception": false, + "start_time": "2024-04-03T06:52:32.025625", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "19796it [00:30, 659.60it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(19796, 3)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
video_pathlabelstring_label
0/kaggle/input/k4testset/videos_val/jf7RDuUTrsQ...325somersaulting
1/kaggle/input/k4testset/videos_val/JTlatknwOrY...233playing harmonica
2/kaggle/input/k4testset/videos_val/8UxlDNur-Z0...262pushing cart
3/kaggle/input/k4testset/videos_val/y9r115bgfNk...320sniffing
4/kaggle/input/k4testset/videos_val/ZnIDviwA8CE...244playing saxophone
\n", + "
" + ], + "text/plain": [ + " video_path label string_label\n", + "0 /kaggle/input/k4testset/videos_val/jf7RDuUTrsQ... 325 somersaulting\n", + "1 /kaggle/input/k4testset/videos_val/JTlatknwOrY... 233 playing harmonica\n", + "2 /kaggle/input/k4testset/videos_val/8UxlDNur-Z0... 262 pushing cart\n", + "3 /kaggle/input/k4testset/videos_val/y9r115bgfNk... 320 sniffing\n", + "4 /kaggle/input/k4testset/videos_val/ZnIDviwA8CE... 244 playing saxophone" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "def process_data(text_file_path, data_folder_path, n=10):\n", + " video_paths = []\n", + " labels = []\n", + " string_labels = []\n", + "\n", + " # Get all video filenames from the data folder\n", + " all_files_in_data_folder = [\n", + " f for f in os.listdir(data_folder_path) \n", + " if os.path.isfile(os.path.join(data_folder_path, f))\n", + " ]\n", + " with open(text_file_path, 'r') as f:\n", + " for line in tqdm(f):\n", + " parts = line.strip().split()\n", + " if len(parts) == 2:\n", + " filename, label = parts\n", + " search_string = filename[-n:]\n", + " matching_file = next(\n", + " (\n", + " f for f in all_files_in_data_folder \n", + " if f.endswith(search_string)\n", + " ), None\n", + " )\n", + " if matching_file:\n", + " abs_path = os.path.join(data_folder_path, matching_file)\n", + " video_paths.append(abs_path)\n", + " labels.append(int(label))\n", + " string_labels.append(id2label[int(label)])\n", + " \n", + " df = pd.DataFrame({\n", + " 'video_path': video_paths,\n", + " 'label': labels,\n", + " 'string_label': string_labels\n", + " })\n", + " return df\n", + "\n", + "# Example usage:\n", + "text_file_path = \"/kaggle/input/k4testset/kinetics400_val_list_videos.txt\"\n", + "data_folder_path = \"/kaggle/input/k4testset/videos_val\"\n", + "df = process_data(text_file_path, data_folder_path)\n", + "print(df.shape)\n", + "df.head()" + ] + }, + { + "cell_type": "markdown", + "id": "27bd6782", + "metadata": { + "papermill": { + "duration": 0.033667, + "end_time": "2024-04-03T06:53:51.238227", + "exception": false, + "start_time": "2024-04-03T06:53:51.204560", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "# Data Loader\n", + "\n", + "To build the dataloader, we will be using [mmaction](https://mmaction2.readthedocs.io/en/latest/index.html) recipe. 
" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "27c2fd34", + "metadata": { + "execution": { + "iopub.execute_input": "2024-04-03T06:53:51.314733Z", + "iopub.status.busy": "2024-04-03T06:53:51.313960Z", + "iopub.status.idle": "2024-04-03T06:53:51.320470Z", + "shell.execute_reply": "2024-04-03T06:53:51.319317Z" + }, + "papermill": { + "duration": 0.047102, + "end_time": "2024-04-03T06:53:51.322763", + "exception": false, + "start_time": "2024-04-03T06:53:51.275661", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "class VideoInit:\n", + " def transform(self, results):\n", + " container = VideoReader(results['filename'])\n", + " results['total_frames'] = len(container)\n", + " results['video_reader'] = container\n", + " results['avg_fps'] = container.get_avg_fps()\n", + " results['start_index'] = 0\n", + " return results" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "1091a7da", + "metadata": { + "execution": { + "iopub.execute_input": "2024-04-03T06:53:51.403976Z", + "iopub.status.busy": "2024-04-03T06:53:51.403564Z", + "iopub.status.idle": "2024-04-03T06:53:51.438391Z", + "shell.execute_reply": "2024-04-03T06:53:51.437394Z" + }, + "papermill": { + "duration": 0.079147, + "end_time": "2024-04-03T06:53:51.440585", + "exception": false, + "start_time": "2024-04-03T06:53:51.361438", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "class VideoSample:\n", + " \"\"\"Sample frames from the video.\n", + "\n", + " Required keys are \"total_frames\", \"start_index\" , added or modified keys\n", + " are \"frame_inds\", \"frame_interval\" and \"num_clips\".\n", + "\n", + " Args:\n", + " clip_len (int): Frames of each sampled output clip.\n", + " frame_interval (int): Temporal interval of adjacent sampled frames.\n", + " Default: 1.\n", + " num_clips (int): Number of clips to be sampled. Default: 1.\n", + " temporal_jitter (bool): Whether to apply temporal jittering.\n", + " Default: False.\n", + " twice_sample (bool): Whether to use twice sample when testing.\n", + " If set to True, it will sample frames with and without fixed shift,\n", + " which is commonly used for testing in TSM model. Default: False.\n", + " out_of_bound_opt (str): The way to deal with out of bounds frame\n", + " indexes. 
Available options are 'loop', 'repeat_last'.\n", + " Default: 'loop'.\n", + " test_mode (bool): Store True when building test or validation dataset.\n", + " Default: False.\n", + " start_index (None): This argument is deprecated and moved to dataset\n", + " class (``BaseDataset``, ``VideoDatset``, ``RawframeDataset``, etc),\n", + " see this: https://github.com/open-mmlab/mmaction2/pull/89.\n", + " \"\"\"\n", + "\n", + " def __init__(\n", + " self,\n", + " clip_len,\n", + " frame_interval=1,\n", + " num_clips=1,\n", + " temporal_jitter=False,\n", + " twice_sample=False,\n", + " out_of_bound_opt='loop',\n", + " test_mode=False,\n", + " start_index=None,\n", + " frame_uniform=False,\n", + " multiview=1\n", + " ):\n", + " self.clip_len = clip_len\n", + " self.frame_interval = frame_interval\n", + " self.num_clips = num_clips\n", + " self.temporal_jitter = temporal_jitter\n", + " self.twice_sample = twice_sample\n", + " self.out_of_bound_opt = out_of_bound_opt\n", + " self.test_mode = test_mode\n", + " self.frame_uniform = frame_uniform\n", + " self.multiview=multiview\n", + " assert self.out_of_bound_opt in ['loop', 'repeat_last']\n", + "\n", + " if start_index is not None:\n", + " warnings.warn(\n", + " 'No longer support \"start_index\" in \"SampleFrames\", '\n", + " 'it should be set in dataset class, see this pr: '\n", + " 'https://github.com/open-mmlab/mmaction2/pull/89'\n", + " )\n", + "\n", + " def _get_train_clips(self, num_frames):\n", + " \"\"\"Get clip offsets in train mode.\n", + "\n", + " It will calculate the average interval for selected frames,\n", + " and randomly shift them within offsets between [0, avg_interval].\n", + " If the total number of frames is smaller than clips num or origin\n", + " frames length, it will return all zero indices.\n", + "\n", + " Args:\n", + " num_frames (int): Total number of frame in the video.\n", + "\n", + " Returns:\n", + " np.ndarray: Sampled frame indices in train mode.\n", + " \"\"\"\n", + " ori_clip_len = self.clip_len * self.frame_interval\n", + " avg_interval = (num_frames - ori_clip_len + 1) // self.num_clips\n", + "\n", + " if avg_interval > 0:\n", + " base_offsets = np.arange(self.num_clips) * avg_interval\n", + " clip_offsets = base_offsets + np.random.randint(\n", + " avg_interval, size=self.num_clips)\n", + " elif num_frames > max(self.num_clips, ori_clip_len):\n", + " clip_offsets = np.sort(\n", + " np.random.randint(\n", + " num_frames - ori_clip_len + 1, size=self.num_clips))\n", + " elif avg_interval == 0:\n", + " ratio = (num_frames - ori_clip_len + 1.0) / self.num_clips\n", + " clip_offsets = np.around(np.arange(self.num_clips) * ratio)\n", + " else:\n", + " clip_offsets = np.zeros((self.num_clips, ), dtype=np.int32)\n", + "\n", + " return clip_offsets\n", + "\n", + " def _get_test_clips(self, num_frames):\n", + " \"\"\"Get clip offsets in test mode.\n", + "\n", + " Calculate the average interval for selected frames, and shift them\n", + " fixedly by avg_interval/2. If set twice_sample True, it will sample\n", + " frames together without fixed shift. 
If the total number of frames is\n", + " not enough, it will return all zero indices.\n", + "\n", + " Args:\n", + " num_frames (int): Total number of frame in the video.\n", + "\n", + " Returns:\n", + " np.ndarray: Sampled frame indices in test mode.\n", + " \"\"\"\n", + " ori_clip_len = self.clip_len * self.frame_interval\n", + " avg_interval = (num_frames - ori_clip_len + 1) / float(self.num_clips)\n", + " if num_frames > ori_clip_len - 1:\n", + " base_offsets = np.arange(self.num_clips) * avg_interval\n", + " clip_offsets = (base_offsets + avg_interval / 2.0).astype(np.int32)\n", + " if self.twice_sample:\n", + " clip_offsets = np.concatenate([clip_offsets, base_offsets])\n", + " else:\n", + " clip_offsets = np.zeros((self.num_clips, ), dtype=np.int32)\n", + " return clip_offsets\n", + "\n", + " def _sample_clips(self, num_frames):\n", + " \"\"\"Choose clip offsets for the video in a given mode.\n", + "\n", + " Args:\n", + " num_frames (int): Total number of frame in the video.\n", + "\n", + " Returns:\n", + " np.ndarray: Sampled frame indices.\n", + " \"\"\"\n", + " if self.test_mode:\n", + " clip_offsets = self._get_test_clips(num_frames)\n", + " else:\n", + " if self.multiview == 1:\n", + " clip_offsets = self._get_train_clips(num_frames)\n", + " else:\n", + " clip_offsets = np.concatenate(\n", + " [\n", + " self._get_train_clips(num_frames) \n", + " for _ in range(self.multiview)\n", + " ]\n", + " )\n", + " return clip_offsets\n", + "\n", + " def get_seq_frames(self, num_frames):\n", + " seg_size = float(num_frames - 1) / self.clip_len\n", + " seq = []\n", + " for i in range(self.clip_len):\n", + " start = int(np.round(seg_size * i))\n", + " end = int(np.round(seg_size * (i + 1)))\n", + " if not self.test_mode:\n", + " seq.append(random.randint(start, end))\n", + " else:\n", + " seq.append((start + end) // 2)\n", + "\n", + " return np.array(seq)\n", + "\n", + " def transform(self, results):\n", + " \"\"\"Perform the SampleFrames loading.\n", + "\n", + " Args:\n", + " results (dict): The resulting dict to be modified and passed\n", + " to the next transform in pipeline.\n", + " \"\"\"\n", + " total_frames = results['total_frames']\n", + " if self.frame_uniform: # sthv2 sampling strategy\n", + " assert results['start_index'] == 0\n", + " frame_inds = self.get_seq_frames(total_frames)\n", + " else:\n", + " clip_offsets = self._sample_clips(total_frames)\n", + " frame_inds = clip_offsets[:, None] + np.arange(\n", + " self.clip_len)[None, :] * self.frame_interval\n", + " frame_inds = np.concatenate(frame_inds)\n", + "\n", + " if self.temporal_jitter:\n", + " perframe_offsets = np.random.randint(\n", + " self.frame_interval, size=len(frame_inds))\n", + " frame_inds += perframe_offsets\n", + "\n", + " frame_inds = frame_inds.reshape((-1, self.clip_len))\n", + " if self.out_of_bound_opt == 'loop':\n", + " frame_inds = np.mod(frame_inds, total_frames)\n", + " elif self.out_of_bound_opt == 'repeat_last':\n", + " safe_inds = frame_inds < total_frames\n", + " unsafe_inds = 1 - safe_inds\n", + " last_ind = np.max(safe_inds * frame_inds, axis=1)\n", + " new_inds = (safe_inds * frame_inds + (unsafe_inds.T * last_ind).T)\n", + " frame_inds = new_inds\n", + " else:\n", + " raise ValueError('Illegal out_of_bound option.')\n", + "\n", + " start_index = results['start_index']\n", + " frame_inds = np.concatenate(frame_inds) + start_index\n", + "\n", + " results['frame_inds'] = frame_inds.astype(np.int32)\n", + " results['clip_len'] = self.clip_len\n", + " results['frame_interval'] = 
self.frame_interval\n", + " results['num_clips'] = self.num_clips\n", + " return results\n", + "\n", + " def __repr__(self):\n", + " repr_str = (f'{self.__class__.__name__}('\n", + " f'clip_len={self.clip_len}, '\n", + " f'frame_interval={self.frame_interval}, '\n", + " f'num_clips={self.num_clips}, '\n", + " f'temporal_jitter={self.temporal_jitter}, '\n", + " f'twice_sample={self.twice_sample}, '\n", + " f'out_of_bound_opt={self.out_of_bound_opt}, '\n", + " f'test_mode={self.test_mode})')\n", + " return repr_str" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "671e47f5", + "metadata": { + "execution": { + "iopub.execute_input": "2024-04-03T06:53:51.521387Z", + "iopub.status.busy": "2024-04-03T06:53:51.520881Z", + "iopub.status.idle": "2024-04-03T06:53:51.527364Z", + "shell.execute_reply": "2024-04-03T06:53:51.526390Z" + }, + "papermill": { + "duration": 0.049057, + "end_time": "2024-04-03T06:53:51.529422", + "exception": false, + "start_time": "2024-04-03T06:53:51.480365", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "class VideoDecode:\n", + " def transform(self, results):\n", + " frame_inds = results['frame_inds']\n", + " container = results['video_reader']\n", + " imgs = container.get_batch(frame_inds).asnumpy()\n", + " imgs = list(imgs)\n", + " results['video_reader'] = None\n", + " del container\n", + " results['imgs'] = imgs\n", + " results['img_shape'] = imgs[0].shape[:2]\n", + " return results" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "d25b4c36", + "metadata": { + "execution": { + "iopub.execute_input": "2024-04-03T06:53:51.607530Z", + "iopub.status.busy": "2024-04-03T06:53:51.607080Z", + "iopub.status.idle": "2024-04-03T06:53:51.620583Z", + "shell.execute_reply": "2024-04-03T06:53:51.619446Z" + }, + "papermill": { + "duration": 0.056298, + "end_time": "2024-04-03T06:53:51.623345", + "exception": false, + "start_time": "2024-04-03T06:53:51.567047", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "def _scale_size(\n", + " size,\n", + " scale,\n", + "):\n", + " if isinstance(scale, (float, int)):\n", + " scale = (scale, scale)\n", + " w, h = size\n", + " return int(w * float(scale[0]) + 0.5), int(h * float(scale[1]) + 0.5)\n", + "\n", + "def rescale_size(\n", + " old_size,\n", + " scale,\n", + " return_scale=False\n", + "):\n", + " w, h = old_size\n", + " if isinstance(scale, (float, int)):\n", + " if scale <= 0:\n", + " raise ValueError(f'Invalid scale {scale}, must be positive.')\n", + " scale_factor = scale\n", + " elif isinstance(scale, tuple):\n", + " max_long_edge = max(scale)\n", + " max_short_edge = min(scale)\n", + " scale_factor = min(\n", + " max_long_edge / max(h, w),\n", + " max_short_edge / min(h, w)\n", + " )\n", + " else:\n", + " raise TypeError(\n", + " f'Scale must be a number or tuple of int, but got {type(scale)}'\n", + " )\n", + "\n", + " new_size = _scale_size((w, h), scale_factor)\n", + "\n", + " if return_scale:\n", + " return new_size, scale_factor\n", + " else:\n", + " return new_size\n", + "\n", + "class VideoResize:\n", + " def __init__(self, r_size):\n", + " self.r_size = (np.inf, r_size)\n", + "\n", + " def transform(self, results):\n", + " img_h, img_w = results['img_shape']\n", + " new_w, new_h = rescale_size((img_w, img_h), self.r_size)\n", + "\n", + " imgs = [\n", + " cv2.resize(img, (new_w, new_h))\n", + " for img in results['imgs']\n", + " ]\n", + " results['imgs'] = imgs\n", + " results['img_shape'] = imgs[0].shape[:2]\n", + " 
return results" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "46d31ee2", + "metadata": { + "execution": { + "iopub.execute_input": "2024-04-03T06:53:51.702400Z", + "iopub.status.busy": "2024-04-03T06:53:51.701648Z", + "iopub.status.idle": "2024-04-03T06:53:51.708951Z", + "shell.execute_reply": "2024-04-03T06:53:51.707952Z" + }, + "papermill": { + "duration": 0.048056, + "end_time": "2024-04-03T06:53:51.711094", + "exception": false, + "start_time": "2024-04-03T06:53:51.663038", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "class VideoCrop:\n", + " def __init__(self, c_size):\n", + " self.c_size = c_size\n", + "\n", + " def transform(self, results):\n", + " img_h, img_w = results['img_shape']\n", + " center_x, center_y = img_w // 2, img_h // 2\n", + " x1, x2 = center_x - self.c_size // 2, center_x + self.c_size // 2\n", + " y1, y2 = center_y - self.c_size // 2, center_y + self.c_size // 2\n", + " imgs = [img[y1:y2, x1:x2] for img in results['imgs']]\n", + " results['imgs'] = imgs\n", + " results['img_shape'] = imgs[0].shape[:2]\n", + " return results" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "9c88cf88", + "metadata": { + "execution": { + "iopub.execute_input": "2024-04-03T06:53:51.782601Z", + "iopub.status.busy": "2024-04-03T06:53:51.781764Z", + "iopub.status.idle": "2024-04-03T06:53:51.787948Z", + "shell.execute_reply": "2024-04-03T06:53:51.786987Z" + }, + "papermill": { + "duration": 0.04297, + "end_time": "2024-04-03T06:53:51.789826", + "exception": false, + "start_time": "2024-04-03T06:53:51.746856", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "class VideoFormat:\n", + " def transform(self, results):\n", + " num_clips = results['num_clips']\n", + " clip_len = results['clip_len']\n", + " imgs = results['imgs']\n", + " # [num_clips*clip_len, H, W, C]\n", + " imgs = np.array(imgs)\n", + " # [num_clips, clip_len, H, W, C]\n", + " imgs = imgs.reshape((num_clips, clip_len) + imgs.shape[1:])\n", + " results['imgs'] = imgs\n", + " return results" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "a0124083", + "metadata": { + "execution": { + "iopub.execute_input": "2024-04-03T06:53:51.909478Z", + "iopub.status.busy": "2024-04-03T06:53:51.909112Z", + "iopub.status.idle": "2024-04-03T06:53:52.048494Z", + "shell.execute_reply": "2024-04-03T06:53:52.047019Z" + }, + "papermill": { + "duration": 0.17638, + "end_time": "2024-04-03T06:53:52.050476", + "exception": false, + "start_time": "2024-04-03T06:53:51.874096", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "dict_keys(['filename', 'total_frames', 'video_reader', 'avg_fps', 'start_index', 'frame_inds', 'clip_len', 'frame_interval', 'num_clips', 'imgs', 'img_shape'])\n" + ] + }, + { + "data": { + "text/plain": [ + "(1, 16, 224, 224, 3)" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "item = dict()\n", + "item['filename'] = '/kaggle/input/k4testset/videos_val/--07WQ2iBlw.mp4'\n", + "v_init = VideoInit().transform(item)\n", + "v_sample = VideoSample(clip_len=16, num_clips=1, test_mode=True).transform(v_init)\n", + "v_decode = VideoDecode().transform(v_sample)\n", + "v_resize = VideoResize(r_size=256).transform(v_decode)\n", + "v_crop = VideoCrop(c_size=224).transform(v_resize)\n", + "v_format = VideoFormat().transform(v_crop)\n", + "print(v_format.keys())\n", + 
"v_format['imgs'].shape" + ] + }, + { + "cell_type": "markdown", + "id": "be48ff6d", + "metadata": { + "papermill": { + "duration": 0.033437, + "end_time": "2024-04-03T06:53:52.118351", + "exception": false, + "start_time": "2024-04-03T06:53:52.084914", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "[inference-config-video-swin](https://github.com/SwinTransformer/Video-Swin-Transformer/blob/db018fb8896251711791386bbd2127562fd8d6a6/configs/recognition/swin/swin_tiny_patch244_window877_kinetics400_1k.py#L45-L61)\n", + "\n", + "```python\n", + "test_pipeline = [\n", + " dict(type='DecordInit'),\n", + " dict(\n", + " type='SampleFrames',\n", + " clip_len=32,\n", + " frame_interval=2,\n", + " num_clips=4,\n", + " test_mode=True),\n", + " dict(type='DecordDecode'),\n", + " dict(type='Resize', scale=(-1, 224)),\n", + " dict(type='ThreeCrop', crop_size=224),\n", + " dict(type='Flip', flip_ratio=0),\n", + " dict(type='Normalize', **img_norm_cfg),\n", + " dict(type='FormatShape', input_format='NCTHW'),\n", + " dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),\n", + " dict(type='ToTensor', keys=['imgs'])\n", + "]\n", + "```\n", + "\n", + "We will skip `ThreeCrop` and `Flip` at the moment." + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "b62d1649", + "metadata": { + "execution": { + "iopub.execute_input": "2024-04-03T06:53:52.186989Z", + "iopub.status.busy": "2024-04-03T06:53:52.186528Z", + "iopub.status.idle": "2024-04-03T06:53:52.191589Z", + "shell.execute_reply": "2024-04-03T06:53:52.190623Z" + }, + "papermill": { + "duration": 0.041923, + "end_time": "2024-04-03T06:53:52.193583", + "exception": false, + "start_time": "2024-04-03T06:53:52.151660", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "num_classes=400\n", + "batch_size=16\n", + "num_clips=4\n", + "frame_rate=2 \n", + "input_frame=32\n", + "h_crop_size=w_crop_size=224" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "38ad84b5", + "metadata": { + "execution": { + "iopub.execute_input": "2024-04-03T06:53:52.266322Z", + "iopub.status.busy": "2024-04-03T06:53:52.265343Z", + "iopub.status.idle": "2024-04-03T06:53:52.275520Z", + "shell.execute_reply": "2024-04-03T06:53:52.274464Z" + }, + "papermill": { + "duration": 0.049125, + "end_time": "2024-04-03T06:53:52.277732", + "exception": false, + "start_time": "2024-04-03T06:53:52.228607", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "class VideoDataset(Dataset):\n", + " def __init__(self, dataframe, clip_len=1, frame_sample_rate=8):\n", + " self.dataframe = dataframe\n", + " self.clip_len = clip_len\n", + " self.frame_sample_rate = frame_sample_rate\n", + "\n", + " def __len__(self):\n", + " return len(self.dataframe)\n", + "\n", + " def get_frames(self, video_path):\n", + " item = dict()\n", + " item['filename'] = video_path\n", + " v_init = VideoInit().transform(item)\n", + " v_sample = VideoSample(\n", + " clip_len=input_frame, \n", + " num_clips=num_clips, \n", + " frame_interval=frame_rate,\n", + " test_mode=True\n", + " ).transform(v_init)\n", + " v_decode = VideoDecode().transform(v_sample)\n", + " v_resize = VideoResize(r_size=256).transform(v_decode)\n", + " v_crop = VideoCrop(c_size=224).transform(v_resize)\n", + " v_format = VideoFormat().transform(v_crop)\n", + " frames = v_format['imgs']\n", + " return frames\n", + "\n", + " def __getitem__(self, idx):\n", + " video_path = self.dataframe.iloc[idx, 0]\n", + " label = self.dataframe.iloc[idx, 1]\n", + " 
video = self.get_frames(video_path)\n", + " return torch.tensor(video).to(torch.float32), torch.tensor(label).to(torch.float32)" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "094a3a1a", + "metadata": { + "execution": { + "iopub.execute_input": "2024-04-03T06:53:52.350322Z", + "iopub.status.busy": "2024-04-03T06:53:52.349529Z", + "iopub.status.idle": "2024-04-03T06:53:52.355074Z", + "shell.execute_reply": "2024-04-03T06:53:52.354169Z" + }, + "papermill": { + "duration": 0.044263, + "end_time": "2024-04-03T06:53:52.357228", + "exception": false, + "start_time": "2024-04-03T06:53:52.312965", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "dataset = VideoDataset(\n", + " dataframe=df, \n", + " clip_len=num_clips, frame_sample_rate=frame_rate\n", + ")\n", + "dataloader = DataLoader(\n", + " dataset, \n", + " batch_size=batch_size, \n", + " shuffle=False, \n", + " pin_memory=True, \n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "063e4cf8", + "metadata": { + "execution": { + "iopub.execute_input": "2024-04-03T06:53:52.428604Z", + "iopub.status.busy": "2024-04-03T06:53:52.427956Z", + "iopub.status.idle": "2024-04-03T06:54:15.228214Z", + "shell.execute_reply": "2024-04-03T06:54:15.227198Z" + }, + "papermill": { + "duration": 22.839482, + "end_time": "2024-04-03T06:54:15.230535", + "exception": false, + "start_time": "2024-04-03T06:53:52.391053", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([16, 4, 32, 224, 224, 3]) torch.Size([16])\n", + "torch.Size([16, 4, 32, 224, 224, 3]) torch.Size([16])\n", + "torch.Size([16, 4, 32, 224, 224, 3]) torch.Size([16])\n" + ] + } + ], + "source": [ + "for i, (videos, labels) in enumerate(dataloader):\n", + " print(videos.shape, labels.shape)\n", + " if i == 2 :\n", + " break" + ] + }, + { + "cell_type": "markdown", + "id": "b1a09bd1", + "metadata": { + "papermill": { + "duration": 0.034174, + "end_time": "2024-04-03T06:54:15.299292", + "exception": false, + "start_time": "2024-04-03T06:54:15.265118", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "# Model" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "f6472d84", + "metadata": { + "execution": { + "iopub.execute_input": "2024-04-03T06:54:15.370405Z", + "iopub.status.busy": "2024-04-03T06:54:15.369675Z", + "iopub.status.idle": "2024-04-03T06:54:17.324247Z", + "shell.execute_reply": "2024-04-03T06:54:17.323102Z" + }, + "papermill": { + "duration": 1.993356, + "end_time": "2024-04-03T06:54:17.326792", + "exception": false, + "start_time": "2024-04-03T06:54:15.333436", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "!wget https://github.com/innat/VideoSwin/releases/download/v2.0/videoswin_tiny_kinetics400_classifier.weights.h5 -q\n", + "\n", + "def vswin_tiny():\n", + " backbone=VideoSwinBackbone(\n", + " input_shape=(32, 224, 224, 3), \n", + " embed_dim=96,\n", + " depths=[2, 2, 6, 2],\n", + " num_heads=[3, 6, 12, 24],\n", + " include_rescaling=True, \n", + " )\n", + " keras_model = VideoClassifier(\n", + " backbone=backbone,\n", + " num_classes=num_classes,\n", + " activation=None,\n", + " pooling='avg',\n", + " )\n", + " keras_model.load_weights(\n", + " 'videoswin_tiny_kinetics400_classifier.weights.h5'\n", + " )\n", + " return keras_model" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "a0a1a37c", + "metadata": { + "execution": { + 
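For reference, the tiny builder above differs from the base builder used in the logit-matching notebooks earlier only in embedding width, stage depths and attention head counts (the SSV2 variant additionally changes the temporal window size). The values below are collected from the builder functions in this guide; the resulting sizes are roughly 88 M parameters for base with a 600-class head (printed earlier) and roughly 28 M for tiny with a 400-class head (shown in the summary that follows):

```python
# Backbone hyper-parameters as passed to VideoSwinBackbone in this guide.
VSWIN_VARIANTS = {
    "tiny": dict(embed_dim=96,  depths=[2, 2, 6, 2],  num_heads=[3, 6, 12, 24]),
    "base": dict(embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32]),
}
```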
"iopub.execute_input": "2024-04-03T06:54:17.398302Z", + "iopub.status.busy": "2024-04-03T06:54:17.397925Z", + "iopub.status.idle": "2024-04-03T06:54:18.517080Z", + "shell.execute_reply": "2024-04-03T06:54:18.516189Z" + }, + "papermill": { + "duration": 1.15803, + "end_time": "2024-04-03T06:54:18.519232", + "exception": false, + "start_time": "2024-04-03T06:54:17.361202", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
Model: \"video_classifier\"\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1mModel: \"video_classifier\"\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
+       "┃ Layer (type)                     Output Shape                  Param # ┃\n",
+       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
+       "│ videos (InputLayer)             │ (None, 32, 224, 224,   │             0 │\n",
+       "│                                 │ 3)                     │               │\n",
+       "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
+       "│ video_swin_backbone             │ (None, 16, 7, 7, 768)  │    27,850,470 │\n",
+       "│ (VideoSwinBackbone)             │                        │               │\n",
+       "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
+       "│ avg_pool                        │ (None, 768)            │             0 │\n",
+       "│ (GlobalAveragePooling3D)        │                        │               │\n",
+       "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
+       "│ predictions (Dense)             │ (None, 400)            │       307,600 │\n",
+       "└─────────────────────────────────┴────────────────────────┴───────────────┘\n",
+       "
\n" + ], + "text/plain": [ + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n", + "│ videos (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m32\u001b[0m, \u001b[38;5;34m224\u001b[0m, \u001b[38;5;34m224\u001b[0m, │ \u001b[38;5;34m0\u001b[0m │\n", + "│ │ \u001b[38;5;34m3\u001b[0m) │ │\n", + "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", + "│ video_swin_backbone │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m16\u001b[0m, \u001b[38;5;34m7\u001b[0m, \u001b[38;5;34m7\u001b[0m, \u001b[38;5;34m768\u001b[0m) │ \u001b[38;5;34m27,850,470\u001b[0m │\n", + "│ (\u001b[38;5;33mVideoSwinBackbone\u001b[0m) │ │ │\n", + "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", + "│ avg_pool │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m768\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │\n", + "│ (\u001b[38;5;33mGlobalAveragePooling3D\u001b[0m) │ │ │\n", + "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", + "│ predictions (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m400\u001b[0m) │ \u001b[38;5;34m307,600\u001b[0m │\n", + "└─────────────────────────────────┴────────────────────────┴───────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
 Total params: 28,158,070 (107.41 MB)\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m28,158,070\u001b[0m (107.41 MB)\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
 Trainable params: 28,158,070 (107.41 MB)\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m28,158,070\u001b[0m (107.41 MB)\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
 Non-trainable params: 0 (0.00 B)\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "model = vswin_tiny()\n", + "model.summary()" + ] + }, + { + "cell_type": "markdown", + "id": "2ea5b058", + "metadata": { + "papermill": { + "duration": 0.035071, + "end_time": "2024-04-03T06:54:18.591038", + "exception": false, + "start_time": "2024-04-03T06:54:18.555967", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "# Training API" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "5b1fbf80", + "metadata": { + "execution": { + "iopub.execute_input": "2024-04-03T06:54:18.664301Z", + "iopub.status.busy": "2024-04-03T06:54:18.663523Z", + "iopub.status.idle": "2024-04-03T06:54:18.669902Z", + "shell.execute_reply": "2024-04-03T06:54:18.669012Z" + }, + "papermill": { + "duration": 0.045088, + "end_time": "2024-04-03T06:54:18.671800", + "exception": false, + "start_time": "2024-04-03T06:54:18.626712", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "class AverageMeter:\n", + " \"\"\"Computes and stores the average and current value\"\"\"\n", + " def __init__(self):\n", + " self.reset()\n", + "\n", + " def reset(self):\n", + " self.val = 0\n", + " self.avg = 0\n", + " self.sum = 0\n", + " self.count = 0\n", + "\n", + " def update(self, val, n=1):\n", + " self.val = val\n", + " self.sum += val * n\n", + " self.count += n\n", + " self.avg = self.sum / self.count" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "2c17e3e2", + "metadata": { + "execution": { + "iopub.execute_input": "2024-04-03T06:54:18.742637Z", + "iopub.status.busy": "2024-04-03T06:54:18.742319Z", + "iopub.status.idle": "2024-04-03T06:54:18.752982Z", + "shell.execute_reply": "2024-04-03T06:54:18.752268Z" + }, + "papermill": { + "duration": 0.048423, + "end_time": "2024-04-03T06:54:18.755033", + "exception": false, + "start_time": "2024-04-03T06:54:18.706610", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "if torch.cuda.is_available():\n", + " model.cuda().eval()\n", + "else:\n", + " model.eval()" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "25e18510", + "metadata": { + "execution": { + "iopub.execute_input": "2024-04-03T06:54:18.827094Z", + "iopub.status.busy": "2024-04-03T06:54:18.826304Z", + "iopub.status.idle": "2024-04-03T06:54:18.830842Z", + "shell.execute_reply": "2024-04-03T06:54:18.829964Z" + }, + "papermill": { + "duration": 0.042234, + "end_time": "2024-04-03T06:54:18.832775", + "exception": false, + "start_time": "2024-04-03T06:54:18.790541", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "acc1_meter, acc5_meter = AverageMeter(), AverageMeter()\n", + "log_print_freq = 50" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "1fdfd2e0", + "metadata": { + "execution": { + "iopub.execute_input": "2024-04-03T06:54:18.904983Z", + "iopub.status.busy": "2024-04-03T06:54:18.904128Z", + "iopub.status.idle": "2024-04-03T10:54:14.344321Z", + "shell.execute_reply": "2024-04-03T10:54:14.343208Z" + }, + "papermill": { + "duration": 14395.478546, + "end_time": "2024-04-03T10:54:14.346498", + "exception": false, + "start_time": "2024-04-03T06:54:18.867952", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|███████████| 1238/1238 [3:59:55<00:00, 11.63s/it, Acc1=77.690, 
Acc5=93.297]\n" + ] + } + ], + "source": [ + "pbar = tqdm(enumerate(dataloader), total=len(dataloader), ncols=80, leave=True)\n", + "\n", + "with torch.no_grad():\n", + " for idx, (image, label) in pbar:\n", + " label_id = label\n", + " label_id = label_id.reshape(-1)\n", + " \n", + " b, n, t, h, w, c = image.size() # batch, clip, time-dim, height, width, channel\n", + " tot_similarity = torch.zeros((b,num_classes)).cuda()\n", + " \n", + " for i in range(n):\n", + " image_input = image[:, i, :, :, :, :] # [b,t,h,w,c]\n", + " label_id = label_id.cuda(non_blocking=True)\n", + " image_input = image_input.cuda(non_blocking=True)\n", + " output = model(image_input)\n", + " similarity = output.view(b, -1).softmax(dim=-1)\n", + " tot_similarity += similarity\n", + " \n", + " values_1, indices_1 = tot_similarity.topk(1, dim=-1)\n", + " values_5, indices_5 = tot_similarity.topk(5, dim=-1)\n", + " acc1, acc5 = 0, 0\n", + " \n", + " for i in range(b):\n", + " if indices_1[i] == label_id[i]:\n", + " acc1 += 1\n", + " if label_id[i] in indices_5[i]:\n", + " acc5 += 1\n", + " \n", + " acc1_meter.update(float(acc1) / b * 100, b)\n", + " acc5_meter.update(float(acc5) / b * 100, b)\n", + " \n", + " if idx % log_print_freq == 0:\n", + " pbar.set_postfix(\n", + " Acc1=f\"{acc1_meter.avg:.3f}\", Acc5=f\"{acc5_meter.avg:.3f}\"\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c50c36d7", + "metadata": { + "papermill": { + "duration": 0.129947, + "end_time": "2024-04-03T10:54:14.609281", + "exception": false, + "start_time": "2024-04-03T10:54:14.479334", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kaggle": { + "accelerator": "gpu", + "dataSources": [ + { + "datasetId": 3721472, + "sourceId": 6446831, + "sourceType": "datasetVersion" + } + ], + "dockerImageVersionId": 30673, + "isGpuEnabled": true, + "isInternetEnabled": true, + "language": "python", + "sourceType": "notebook" + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + }, + "papermill": { + "default_parameters": {}, + "duration": 14578.408157, + "end_time": "2024-04-03T10:54:19.452722", + "environment_variables": {}, + "exception": null, + "input_path": "__notebook__.ipynb", + "output_path": "__notebook__.ipynb", + "parameters": {}, + "start_time": "2024-04-03T06:51:21.044565", + "version": "2.5.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/guides/kerascv-torchvision-video-swin-logits-backbone.ipynb b/guides/kerascv-torchvision-video-swin-logits-backbone.ipynb new file mode 100644 index 0000000..977b801 --- /dev/null +++ b/guides/kerascv-torchvision-video-swin-logits-backbone.ipynb @@ -0,0 +1,1114 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "d2a41fe2", + "metadata": { + "papermill": { + "duration": 0.008476, + "end_time": "2024-03-31T16:56:33.025353", + "exception": false, + "start_time": "2024-03-31T16:56:33.016877", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "# About\n", + "\n", + "This notebook demonstrates that the video swin transformer imported from the `keras-cv` and `torch-vision` libraries produces identical results. 
The `keras-cv` version of video swin is implemented in `keras 3`, which makes it able to run on multiple backends, i.e. `tensorflow`, `torch`, and `jax`." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "ab79b749", + "metadata": { + "_cell_guid": "b1076dfc-b9ad-4769-8c92-a6c4dae69d19", + "_uuid": "8f2839f25d086af736a60e9eeb907d3b93b6e0e5", + "execution": { + "iopub.execute_input": "2024-03-31T16:56:33.043892Z", + "iopub.status.busy": "2024-03-31T16:56:33.043098Z", + "iopub.status.idle": "2024-03-31T16:56:34.096735Z", + "shell.execute_reply": "2024-03-31T16:56:34.095522Z" + }, + "papermill": { + "duration": 1.066465, + "end_time": "2024-03-31T16:56:34.099589", + "exception": false, + "start_time": "2024-03-31T16:56:33.033124", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "import numpy as np # linear algebra\n", + "import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)\n", + "import os\n", + "import warnings" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "97b1c2e6", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-31T16:56:34.116409Z", + "iopub.status.busy": "2024-03-31T16:56:34.115809Z", + "iopub.status.idle": "2024-03-31T16:56:34.121637Z", + "shell.execute_reply": "2024-03-31T16:56:34.120415Z" + }, + "papermill": { + "duration": 0.016785, + "end_time": "2024-03-31T16:56:34.123892", + "exception": false, + "start_time": "2024-03-31T16:56:34.107107", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "os.environ[\"KERAS_BACKEND\"] = \"torch\" # 'torch', 'tensorflow', 'jax'\n", + "warnings.simplefilter(action=\"ignore\")\n", + "os.environ[\"TF_CPP_MIN_LOG_LEVEL\"] = \"3\"" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "f3fe802e", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-31T16:56:34.139998Z", + "iopub.status.busy": "2024-03-31T16:56:34.139606Z", + "iopub.status.idle": "2024-03-31T16:57:12.230426Z", + "shell.execute_reply": "2024-03-31T16:57:12.229322Z" + }, + "papermill": { + "duration": 38.102278, + "end_time": "2024-03-31T16:57:12.233346", + "exception": false, + "start_time": "2024-03-31T16:56:34.131068", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Cloning into 'keras-cv'...\r\n", + "remote: Enumerating objects: 13766, done.\u001b[K\r\n", + "remote: Counting objects: 100% (1903/1903), done.\u001b[K\r\n", + "remote: Compressing objects: 100% (760/760), done.\u001b[K\r\n", + "remote: Total 13766 (delta 1325), reused 1617 (delta 1127), pack-reused 11863\u001b[K\r\n", + "Receiving objects: 100% (13766/13766), 25.64 MiB | 27.10 MiB/s, done.\r\n", + "Resolving deltas: 100% (9776/9776), done.\r\n", + "/kaggle/working/keras-cv\n" + ] + } + ], + "source": [ + "!git clone --branch video_swin https://github.com/innat/keras-cv.git\n", + "%cd keras-cv\n", + "!pip install -q -e ." 
+ ] + }, + { + "cell_type": "markdown", + "id": "7d595aeb", + "metadata": { + "papermill": { + "duration": 0.009127, + "end_time": "2024-03-31T16:57:12.252522", + "exception": false, + "start_time": "2024-03-31T16:57:12.243395", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "# KerasCV: Video Swin : Pretrained: ImageNet 1K" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "27594587", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-31T16:57:12.273806Z", + "iopub.status.busy": "2024-03-31T16:57:12.273338Z", + "iopub.status.idle": "2024-03-31T16:57:36.167429Z", + "shell.execute_reply": "2024-03-31T16:57:36.166267Z" + }, + "papermill": { + "duration": 23.908121, + "end_time": "2024-03-31T16:57:36.169910", + "exception": false, + "start_time": "2024-03-31T16:57:12.261789", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "'3.0.5'" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import keras\n", + "from keras import ops\n", + "from keras_cv.models import VideoSwinBackbone\n", + "from keras_cv.models import VideoClassifier\n", + "\n", + "keras.__version__" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "0b342441", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-31T16:57:36.193323Z", + "iopub.status.busy": "2024-03-31T16:57:36.191636Z", + "iopub.status.idle": "2024-03-31T16:57:38.404512Z", + "shell.execute_reply": "2024-03-31T16:57:38.403189Z" + }, + "papermill": { + "duration": 2.227219, + "end_time": "2024-03-31T16:57:38.407364", + "exception": false, + "start_time": "2024-03-31T16:57:36.180145", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "--2024-03-31 16:57:37-- https://github.com/innat/VideoSwin/releases/download/v2.0/videoswin_tiny_kinetics400.weights.h5\r\n", + "Resolving github.com (github.com)... 140.82.121.3\r\n", + "Connecting to github.com (github.com)|140.82.121.3|:443... connected.\r\n", + "HTTP request sent, awaiting response... 302 Found\r\n", + "Location: https://objects.githubusercontent.com/github-production-release-asset-2e65be/697696973/5153e756-236b-41e7-a602-ab854a57034f?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240331%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240331T165737Z&X-Amz-Expires=300&X-Amz-Signature=6188cd48f4cffee2ddbc4c5a3c8e4701e15588c12e3446ed5b8a52a002072164&X-Amz-SignedHeaders=host&actor_id=0&key_id=0&repo_id=697696973&response-content-disposition=attachment%3B%20filename%3Dvideoswin_tiny_kinetics400.weights.h5&response-content-type=application%2Foctet-stream [following]\r\n", + "--2024-03-31 16:57:37-- https://objects.githubusercontent.com/github-production-release-asset-2e65be/697696973/5153e756-236b-41e7-a602-ab854a57034f?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240331%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240331T165737Z&X-Amz-Expires=300&X-Amz-Signature=6188cd48f4cffee2ddbc4c5a3c8e4701e15588c12e3446ed5b8a52a002072164&X-Amz-SignedHeaders=host&actor_id=0&key_id=0&repo_id=697696973&response-content-disposition=attachment%3B%20filename%3Dvideoswin_tiny_kinetics400.weights.h5&response-content-type=application%2Foctet-stream\r\n", + "Resolving objects.githubusercontent.com (objects.githubusercontent.com)... 
185.199.108.133, 185.199.109.133, 185.199.110.133, ...\r\n", + "Connecting to objects.githubusercontent.com (objects.githubusercontent.com)|185.199.108.133|:443... connected.\r\n", + "HTTP request sent, awaiting response... 200 OK\r\n", + "Length: 111855496 (107M) [application/octet-stream]\r\n", + "Saving to: 'videoswin_tiny_kinetics400.weights.h5'\r\n", + "\r\n", + "videoswin_tiny_kine 100%[===================>] 106.67M 230MB/s in 0.5s \r\n", + "\r\n", + "2024-03-31 16:57:38 (230 MB/s) - 'videoswin_tiny_kinetics400.weights.h5' saved [111855496/111855496]\r\n", + "\r\n" + ] + } + ], + "source": [ + "!wget https://github.com/innat/VideoSwin/releases/download/v2.0/videoswin_tiny_kinetics400.weights.h5\n", + "\n", + "def vswin_tiny():\n", + " backbone=VideoSwinBackbone(\n", + " input_shape=(32, 224, 224, 3), \n", + " embed_dim=96,\n", + " depths=[2, 2, 6, 2],\n", + " num_heads=[3, 6, 12, 24],\n", + " include_rescaling=False, \n", + " )\n", + " backbone.load_weights(\n", + " 'videoswin_tiny_kinetics400.weights.h5'\n", + " )\n", + " return backbone" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "314cb4d0", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-31T16:57:38.430540Z", + "iopub.status.busy": "2024-03-31T16:57:38.430074Z", + "iopub.status.idle": "2024-03-31T16:57:40.912946Z", + "shell.execute_reply": "2024-03-31T16:57:40.911656Z" + }, + "papermill": { + "duration": 2.498091, + "end_time": "2024-03-31T16:57:40.915960", + "exception": false, + "start_time": "2024-03-31T16:57:38.417869", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "--2024-03-31 16:57:39-- https://github.com/innat/VideoSwin/releases/download/v2.0/videoswin_small_kinetics400.weights.h5\r\n", + "Resolving github.com (github.com)... 140.82.121.4\r\n", + "Connecting to github.com (github.com)|140.82.121.4|:443... connected.\r\n", + "HTTP request sent, awaiting response... 302 Found\r\n", + "Location: https://objects.githubusercontent.com/github-production-release-asset-2e65be/697696973/898b24c6-f517-4b01-872b-8f19acd2c54d?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240331%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240331T165739Z&X-Amz-Expires=300&X-Amz-Signature=ebd091cee3f64c57654966b81170827c3c667d8bab72a6b4969a987561926f0f&X-Amz-SignedHeaders=host&actor_id=0&key_id=0&repo_id=697696973&response-content-disposition=attachment%3B%20filename%3Dvideoswin_small_kinetics400.weights.h5&response-content-type=application%2Foctet-stream [following]\r\n", + "--2024-03-31 16:57:39-- https://objects.githubusercontent.com/github-production-release-asset-2e65be/697696973/898b24c6-f517-4b01-872b-8f19acd2c54d?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240331%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240331T165739Z&X-Amz-Expires=300&X-Amz-Signature=ebd091cee3f64c57654966b81170827c3c667d8bab72a6b4969a987561926f0f&X-Amz-SignedHeaders=host&actor_id=0&key_id=0&repo_id=697696973&response-content-disposition=attachment%3B%20filename%3Dvideoswin_small_kinetics400.weights.h5&response-content-type=application%2Foctet-stream\r\n", + "Resolving objects.githubusercontent.com (objects.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...\r\n", + "Connecting to objects.githubusercontent.com (objects.githubusercontent.com)|185.199.108.133|:443... connected.\r\n", + "HTTP request sent, awaiting response... 
200 OK\r\n", + "Length: 198902800 (190M) [application/octet-stream]\r\n", + "Saving to: 'videoswin_small_kinetics400.weights.h5'\r\n", + "\r\n", + "videoswin_small_kin 100%[===================>] 189.69M 237MB/s in 0.8s \r\n", + "\r\n", + "2024-03-31 16:57:40 (237 MB/s) - 'videoswin_small_kinetics400.weights.h5' saved [198902800/198902800]\r\n", + "\r\n" + ] + } + ], + "source": [ + "!wget https://github.com/innat/VideoSwin/releases/download/v2.0/videoswin_small_kinetics400.weights.h5\n", + "\n", + "def vswin_small():\n", + " backbone=VideoSwinBackbone(\n", + " input_shape=(32, 224, 224, 3), \n", + " embed_dim=96,\n", + " depths=[2, 2, 18, 2],\n", + " num_heads=[3, 6, 12, 24],\n", + " include_rescaling=False, \n", + " )\n", + " backbone.load_weights(\n", + " 'videoswin_small_kinetics400.weights.h5'\n", + " )\n", + " return backbone" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "30973043", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-31T16:57:40.941747Z", + "iopub.status.busy": "2024-03-31T16:57:40.941255Z", + "iopub.status.idle": "2024-03-31T16:57:44.047679Z", + "shell.execute_reply": "2024-03-31T16:57:44.046374Z" + }, + "papermill": { + "duration": 3.123539, + "end_time": "2024-03-31T16:57:44.050740", + "exception": false, + "start_time": "2024-03-31T16:57:40.927201", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "--2024-03-31 16:57:41-- https://github.com/innat/VideoSwin/releases/download/v2.0/videoswin_base_kinetics400.weights.h5\r\n", + "Resolving github.com (github.com)... 140.82.121.4\r\n", + "Connecting to github.com (github.com)|140.82.121.4|:443... connected.\r\n", + "HTTP request sent, awaiting response... 302 Found\r\n", + "Location: https://objects.githubusercontent.com/github-production-release-asset-2e65be/697696973/d5a7b9f0-78b7-4151-b1d3-ddba5c66c7c1?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240331%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240331T165742Z&X-Amz-Expires=300&X-Amz-Signature=4dbdc130edd48a081675524290e07e70aa8d48bcd5862bb80d8f15b72903abdd&X-Amz-SignedHeaders=host&actor_id=0&key_id=0&repo_id=697696973&response-content-disposition=attachment%3B%20filename%3Dvideoswin_base_kinetics400.weights.h5&response-content-type=application%2Foctet-stream [following]\r\n", + "--2024-03-31 16:57:42-- https://objects.githubusercontent.com/github-production-release-asset-2e65be/697696973/d5a7b9f0-78b7-4151-b1d3-ddba5c66c7c1?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240331%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240331T165742Z&X-Amz-Expires=300&X-Amz-Signature=4dbdc130edd48a081675524290e07e70aa8d48bcd5862bb80d8f15b72903abdd&X-Amz-SignedHeaders=host&actor_id=0&key_id=0&repo_id=697696973&response-content-disposition=attachment%3B%20filename%3Dvideoswin_base_kinetics400.weights.h5&response-content-type=application%2Foctet-stream\r\n", + "Resolving objects.githubusercontent.com (objects.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...\r\n", + "Connecting to objects.githubusercontent.com (objects.githubusercontent.com)|185.199.108.133|:443... connected.\r\n", + "HTTP request sent, awaiting response... 
200 OK\r\n", + "Length: 351381896 (335M) [application/octet-stream]\r\n", + "Saving to: 'videoswin_base_kinetics400.weights.h5'\r\n", + "\r\n", + "videoswin_base_kine 100%[===================>] 335.10M 241MB/s in 1.4s \r\n", + "\r\n", + "2024-03-31 16:57:43 (241 MB/s) - 'videoswin_base_kinetics400.weights.h5' saved [351381896/351381896]\r\n", + "\r\n" + ] + } + ], + "source": [ + "!wget https://github.com/innat/VideoSwin/releases/download/v2.0/videoswin_base_kinetics400.weights.h5\n", + "\n", + "def vswin_base():\n", + " backbone=VideoSwinBackbone(\n", + " input_shape=(32, 224, 224, 3), \n", + " embed_dim=128,\n", + " depths=[2, 2, 18, 2],\n", + " num_heads=[4, 8, 16, 32],\n", + " include_rescaling=False, \n", + " )\n", + " backbone.load_weights(\n", + " 'videoswin_base_kinetics400.weights.h5'\n", + " )\n", + " return backbone" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "e3bdd3ec", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-31T16:57:44.077645Z", + "iopub.status.busy": "2024-03-31T16:57:44.077210Z", + "iopub.status.idle": "2024-03-31T16:57:51.781985Z", + "shell.execute_reply": "2024-03-31T16:57:51.780643Z" + }, + "papermill": { + "duration": 7.721169, + "end_time": "2024-03-31T16:57:51.784182", + "exception": false, + "start_time": "2024-03-31T16:57:44.063013", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
Model: \"video_swin_backbone\"\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1mModel: \"video_swin_backbone\"\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
+       "┃ Layer (type)                     Output Shape                  Param # ┃\n",
+       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
+       "│ videos (InputLayer)             │ (None, 32, 224, 224,   │             0 │\n",
+       "│                                 │ 3)                     │               │\n",
+       "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
+       "│ videoswin_patching_and_embeddi… │ (None, 16, 56, 56, 96) │         9,504 │\n",
+       "│ (VideoSwinPatchingAndEmbedding) │                        │               │\n",
+       "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
+       "│ pos_drop (Dropout)              │ (None, 16, 56, 56, 96) │             0 │\n",
+       "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
+       "│ videoswin_basic_layer_1         │ (None, 16, 28, 28,     │       313,386 │\n",
+       "│ (VideoSwinBasicLayer)           │ 192)                   │               │\n",
+       "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
+       "│ videoswin_basic_layer_2         │ (None, 16, 14, 14,     │     1,216,596 │\n",
+       "│ (VideoSwinBasicLayer)           │ 384)                   │               │\n",
+       "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
+       "│ videoswin_basic_layer_3         │ (None, 16, 7, 7, 768)  │    12,012,024 │\n",
+       "│ (VideoSwinBasicLayer)           │                        │               │\n",
+       "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
+       "│ videoswin_basic_layer_4         │ (None, 16, 7, 7, 768)  │    14,297,424 │\n",
+       "│ (VideoSwinBasicLayer)           │                        │               │\n",
+       "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
+       "│ videoswin_top_norm              │ (None, 16, 7, 7, 768)  │         1,536 │\n",
+       "│ (LayerNormalization)            │                        │               │\n",
+       "└─────────────────────────────────┴────────────────────────┴───────────────┘\n",
+       "
\n" + ], + "text/plain": [ + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n", + "│ videos (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m32\u001b[0m, \u001b[38;5;34m224\u001b[0m, \u001b[38;5;34m224\u001b[0m, │ \u001b[38;5;34m0\u001b[0m │\n", + "│ │ \u001b[38;5;34m3\u001b[0m) │ │\n", + "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", + "│ videoswin_patching_and_embeddi… │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m16\u001b[0m, \u001b[38;5;34m56\u001b[0m, \u001b[38;5;34m56\u001b[0m, \u001b[38;5;34m96\u001b[0m) │ \u001b[38;5;34m9,504\u001b[0m │\n", + "│ (\u001b[38;5;33mVideoSwinPatchingAndEmbedding\u001b[0m) │ │ │\n", + "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", + "│ pos_drop (\u001b[38;5;33mDropout\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m16\u001b[0m, \u001b[38;5;34m56\u001b[0m, \u001b[38;5;34m56\u001b[0m, \u001b[38;5;34m96\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │\n", + "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", + "│ videoswin_basic_layer_1 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m16\u001b[0m, \u001b[38;5;34m28\u001b[0m, \u001b[38;5;34m28\u001b[0m, │ \u001b[38;5;34m313,386\u001b[0m │\n", + "│ (\u001b[38;5;33mVideoSwinBasicLayer\u001b[0m) │ \u001b[38;5;34m192\u001b[0m) │ │\n", + "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", + "│ videoswin_basic_layer_2 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m16\u001b[0m, \u001b[38;5;34m14\u001b[0m, \u001b[38;5;34m14\u001b[0m, │ \u001b[38;5;34m1,216,596\u001b[0m │\n", + "│ (\u001b[38;5;33mVideoSwinBasicLayer\u001b[0m) │ \u001b[38;5;34m384\u001b[0m) │ │\n", + "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", + "│ videoswin_basic_layer_3 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m16\u001b[0m, \u001b[38;5;34m7\u001b[0m, \u001b[38;5;34m7\u001b[0m, \u001b[38;5;34m768\u001b[0m) │ \u001b[38;5;34m12,012,024\u001b[0m │\n", + "│ (\u001b[38;5;33mVideoSwinBasicLayer\u001b[0m) │ │ │\n", + "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", + "│ videoswin_basic_layer_4 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m16\u001b[0m, \u001b[38;5;34m7\u001b[0m, \u001b[38;5;34m7\u001b[0m, \u001b[38;5;34m768\u001b[0m) │ \u001b[38;5;34m14,297,424\u001b[0m │\n", + "│ (\u001b[38;5;33mVideoSwinBasicLayer\u001b[0m) │ │ │\n", + "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", + "│ videoswin_top_norm │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m16\u001b[0m, \u001b[38;5;34m7\u001b[0m, \u001b[38;5;34m7\u001b[0m, \u001b[38;5;34m768\u001b[0m) │ \u001b[38;5;34m1,536\u001b[0m │\n", + "│ (\u001b[38;5;33mLayerNormalization\u001b[0m) │ │ │\n", + "└─────────────────────────────────┴────────────────────────┴───────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
 Total params: 27,850,470 (106.24 MB)\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m27,850,470\u001b[0m (106.24 MB)\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
 Trainable params: 27,850,470 (106.24 MB)\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m27,850,470\u001b[0m (106.24 MB)\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
 Non-trainable params: 0 (0.00 B)\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "keras_models = [vswin_tiny(), vswin_small(), vswin_base()]\n", + "keras_models[0].summary()" + ] + }, + { + "cell_type": "markdown", + "id": "93c03854", + "metadata": { + "papermill": { + "duration": 0.012956, + "end_time": "2024-03-31T16:57:51.810513", + "exception": false, + "start_time": "2024-03-31T16:57:51.797557", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "# TorchVision: Video Swin : Pretrained: ImageNet 1K" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "649b28ba", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-31T16:57:51.839013Z", + "iopub.status.busy": "2024-03-31T16:57:51.838285Z", + "iopub.status.idle": "2024-03-31T16:57:52.183960Z", + "shell.execute_reply": "2024-03-31T16:57:52.182857Z" + }, + "papermill": { + "duration": 0.362617, + "end_time": "2024-03-31T16:57:52.186447", + "exception": false, + "start_time": "2024-03-31T16:57:51.823830", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "import torch\n", + "import torchvision\n", + "from torchinfo import summary\n", + "from torchvision.models.video import Swin3D_T_Weights, Swin3D_S_Weights, Swin3D_B_Weights" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "b6cb791e", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-31T16:57:52.216057Z", + "iopub.status.busy": "2024-03-31T16:57:52.215643Z", + "iopub.status.idle": "2024-03-31T16:57:52.224782Z", + "shell.execute_reply": "2024-03-31T16:57:52.223537Z" + }, + "papermill": { + "duration": 0.02691, + "end_time": "2024-03-31T16:57:52.227103", + "exception": false, + "start_time": "2024-03-31T16:57:52.200193", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "def exclude_top(model):\n", + " backbone = torch.nn.Sequential(\n", + " *(list(model.children())[:-2])\n", + " )\n", + " backbone.eval()\n", + " return backbone\n", + "\n", + "def torch_vswin_tiny():\n", + " torch_model = torchvision.models.video.swin3d_t(\n", + " weights=Swin3D_T_Weights.KINETICS400_V1\n", + " ).eval()\n", + " backbone = exclude_top(torch_model)\n", + " return backbone\n", + "\n", + "def torch_vswin_small():\n", + " torch_model = torchvision.models.video.swin3d_s(\n", + " weights=Swin3D_S_Weights.KINETICS400_V1\n", + " ).eval()\n", + " backbone = exclude_top(torch_model)\n", + " return backbone\n", + "\n", + "def torch_vswin_base():\n", + " torch_model = torchvision.models.video.swin3d_b(\n", + " weights=Swin3D_B_Weights.KINETICS400_V1\n", + " ).eval()\n", + " backbone = exclude_top(torch_model)\n", + " return backbone" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "0d69b1bc", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-31T16:57:52.256949Z", + "iopub.status.busy": "2024-03-31T16:57:52.256159Z", + "iopub.status.idle": "2024-03-31T16:58:14.269808Z", + "shell.execute_reply": "2024-03-31T16:58:14.268431Z" + }, + "papermill": { + "duration": 22.031378, + "end_time": "2024-03-31T16:58:14.272471", + "exception": false, + "start_time": "2024-03-31T16:57:52.241093", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Downloading: \"https://download.pytorch.org/models/swin3d_t-7615ae03.pth\" to 
/root/.cache/torch/hub/checkpoints/swin3d_t-7615ae03.pth\n", + "100%|██████████| 122M/122M [00:00<00:00, 131MB/s]\n", + "Downloading: \"https://download.pytorch.org/models/swin3d_s-da41c237.pth\" to /root/.cache/torch/hub/checkpoints/swin3d_s-da41c237.pth\n", + "100%|██████████| 218M/218M [00:06<00:00, 37.7MB/s]\n", + "Downloading: \"https://download.pytorch.org/models/swin3d_b_1k-24f7c7c6.pth\" to /root/.cache/torch/hub/checkpoints/swin3d_b_1k-24f7c7c6.pth\n", + "100%|██████████| 364M/364M [00:02<00:00, 135MB/s]\n" + ] + }, + { + "data": { + "text/plain": [ + "=========================================================================================================\n", + "Layer (type:depth-idx) Output Shape Param #\n", + "=========================================================================================================\n", + "Sequential [1, 16, 7, 7, 768] --\n", + "├─PatchEmbed3d: 1-1 [1, 16, 56, 56, 96] --\n", + "│ └─Conv3d: 2-1 [1, 96, 16, 56, 56] 9,312\n", + "│ └─LayerNorm: 2-2 [1, 16, 56, 56, 96] 192\n", + "├─Dropout: 1-2 [1, 16, 56, 56, 96] --\n", + "├─Sequential: 1-3 [1, 16, 7, 7, 768] --\n", + "│ └─Sequential: 2-3 [1, 16, 56, 56, 96] --\n", + "│ │ └─SwinTransformerBlock: 3-1 [1, 16, 56, 56, 96] 119,445\n", + "│ │ └─SwinTransformerBlock: 3-2 [1, 16, 56, 56, 96] 119,445\n", + "│ └─PatchMerging: 2-4 [1, 16, 28, 28, 192] --\n", + "│ │ └─LayerNorm: 3-3 [1, 16, 28, 28, 384] 768\n", + "│ │ └─Linear: 3-4 [1, 16, 28, 28, 192] 73,728\n", + "│ └─Sequential: 2-5 [1, 16, 28, 28, 192] --\n", + "│ │ └─SwinTransformerBlock: 3-5 [1, 16, 28, 28, 192] 460,074\n", + "│ │ └─SwinTransformerBlock: 3-6 [1, 16, 28, 28, 192] 460,074\n", + "│ └─PatchMerging: 2-6 [1, 16, 14, 14, 384] --\n", + "│ │ └─LayerNorm: 3-7 [1, 16, 14, 14, 768] 1,536\n", + "│ │ └─Linear: 3-8 [1, 16, 14, 14, 384] 294,912\n", + "│ └─Sequential: 2-7 [1, 16, 14, 14, 384] --\n", + "│ │ └─SwinTransformerBlock: 3-9 [1, 16, 14, 14, 384] 1,804,884\n", + "│ │ └─SwinTransformerBlock: 3-10 [1, 16, 14, 14, 384] 1,804,884\n", + "│ │ └─SwinTransformerBlock: 3-11 [1, 16, 14, 14, 384] 1,804,884\n", + "│ │ └─SwinTransformerBlock: 3-12 [1, 16, 14, 14, 384] 1,804,884\n", + "│ │ └─SwinTransformerBlock: 3-13 [1, 16, 14, 14, 384] 1,804,884\n", + "│ │ └─SwinTransformerBlock: 3-14 [1, 16, 14, 14, 384] 1,804,884\n", + "│ └─PatchMerging: 2-8 [1, 16, 7, 7, 768] --\n", + "│ │ └─LayerNorm: 3-15 [1, 16, 7, 7, 1536] 3,072\n", + "│ │ └─Linear: 3-16 [1, 16, 7, 7, 768] 1,179,648\n", + "│ └─Sequential: 2-9 [1, 16, 7, 7, 768] --\n", + "│ │ └─SwinTransformerBlock: 3-17 [1, 16, 7, 7, 768] 7,148,712\n", + "│ │ └─SwinTransformerBlock: 3-18 [1, 16, 7, 7, 768] 7,148,712\n", + "├─LayerNorm: 1-4 [1, 16, 7, 7, 768] 1,536\n", + "=========================================================================================================\n", + "Total params: 27,850,470\n", + "Trainable params: 27,850,470\n", + "Non-trainable params: 0\n", + "Total mult-adds (M): 486.09\n", + "=========================================================================================================\n", + "Input size (MB): 19.27\n", + "Forward/backward pass size (MB): 1464.34\n", + "Params size (MB): 75.43\n", + "Estimated Total Size (MB): 1559.03\n", + "=========================================================================================================" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "torch_models = [torch_vswin_tiny(), torch_vswin_small(), torch_vswin_base()]\n", + "summary(\n", + " torch_models[0], input_size=(1, 
3, 32, 224, 224)\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "8f5f856f", + "metadata": { + "papermill": { + "duration": 0.022661, + "end_time": "2024-03-31T16:58:14.318044", + "exception": false, + "start_time": "2024-03-31T16:58:14.295383", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "# Inference" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "68608e40", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-31T16:58:14.365429Z", + "iopub.status.busy": "2024-03-31T16:58:14.364979Z", + "iopub.status.idle": "2024-03-31T16:58:14.532780Z", + "shell.execute_reply": "2024-03-31T16:58:14.531291Z" + }, + "papermill": { + "duration": 0.194807, + "end_time": "2024-03-31T16:58:14.535575", + "exception": false, + "start_time": "2024-03-31T16:58:14.340768", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([1, 32, 224, 224, 3]) torch.Size([1, 3, 32, 224, 224])\n" + ] + } + ], + "source": [ + "common_input = np.random.normal(0, 1, (1, 32, 224, 224, 3)).astype('float32')\n", + "keras_input = ops.array(common_input)\n", + "torch_input = torch.from_numpy(common_input.transpose(0, 4, 1, 2, 3))\n", + "print(keras_input.shape, torch_input.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "c4ade2b5", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-31T16:58:14.583908Z", + "iopub.status.busy": "2024-03-31T16:58:14.583167Z", + "iopub.status.idle": "2024-03-31T16:58:14.589488Z", + "shell.execute_reply": "2024-03-31T16:58:14.588707Z" + }, + "papermill": { + "duration": 0.033121, + "end_time": "2024-03-31T16:58:14.591746", + "exception": false, + "start_time": "2024-03-31T16:58:14.558625", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "def logit_checking(keras_model, torch_model):\n", + " # forward pass\n", + " keras_predict = keras_model(keras_input)\n", + " torch_predict = torch_model(torch_input) \n", + " print(keras_predict.shape, torch_predict.shape)\n", + " np.testing.assert_allclose(\n", + " keras_predict.detach().numpy(),\n", + " torch_predict.detach().numpy(),\n", + " 1e-4, 1e-4\n", + " )\n", + " del keras_model \n", + " del torch_model" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "8930bfc5", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-31T16:58:14.640567Z", + "iopub.status.busy": "2024-03-31T16:58:14.639913Z", + "iopub.status.idle": "2024-03-31T16:59:23.764724Z", + "shell.execute_reply": "2024-03-31T16:59:23.763372Z" + }, + "papermill": { + "duration": 69.152659, + "end_time": "2024-03-31T16:59:23.767439", + "exception": false, + "start_time": "2024-03-31T16:58:14.614780", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([1, 16, 7, 7, 768]) torch.Size([1, 16, 7, 7, 768])\n", + "torch.Size([1, 16, 7, 7, 768]) torch.Size([1, 16, 7, 7, 768])\n", + "torch.Size([1, 16, 7, 7, 1024]) torch.Size([1, 16, 7, 7, 1024])\n" + ] + } + ], + "source": [ + "for km, tm in zip(keras_models, torch_models):\n", + " logit_checking(\n", + " km, tm\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "68938336", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-31T16:59:23.817302Z", + "iopub.status.busy": "2024-03-31T16:59:23.816843Z", + "iopub.status.idle": "2024-03-31T16:59:24.215821Z", + "shell.execute_reply": 
"2024-03-31T16:59:24.214493Z" + }, + "papermill": { + "duration": 0.42734, + "end_time": "2024-03-31T16:59:24.218241", + "exception": false, + "start_time": "2024-03-31T16:59:23.790901", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "40" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import gc\n", + "gc.collect()" + ] + }, + { + "cell_type": "markdown", + "id": "55b12581", + "metadata": { + "papermill": { + "duration": 0.023243, + "end_time": "2024-03-31T16:59:24.265940", + "exception": false, + "start_time": "2024-03-31T16:59:24.242697", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "# Keras: Video Swin Base - Pretrained: ImageNet 22K" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "43d4625f", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-31T16:59:24.315830Z", + "iopub.status.busy": "2024-03-31T16:59:24.315022Z", + "iopub.status.idle": "2024-03-31T16:59:32.786393Z", + "shell.execute_reply": "2024-03-31T16:59:32.784606Z" + }, + "papermill": { + "duration": 8.49938, + "end_time": "2024-03-31T16:59:32.789273", + "exception": false, + "start_time": "2024-03-31T16:59:24.289893", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "--2024-03-31 16:59:25-- https://github.com/innat/VideoSwin/releases/download/v2.0/videoswin_base_kinetics400_imagenet22k.weights.h5\r\n", + "Resolving github.com (github.com)... 140.82.121.4\r\n", + "Connecting to github.com (github.com)|140.82.121.4|:443... connected.\r\n", + "HTTP request sent, awaiting response... 302 Found\r\n", + "Location: https://objects.githubusercontent.com/github-production-release-asset-2e65be/697696973/75b53567-f9ae-4739-87c1-0d5d9d423f25?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240331%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240331T165925Z&X-Amz-Expires=300&X-Amz-Signature=b3146437a5644138b963a8e376f8cc066e3ff2c0b4bb7a05e1f705e930095453&X-Amz-SignedHeaders=host&actor_id=0&key_id=0&repo_id=697696973&response-content-disposition=attachment%3B%20filename%3Dvideoswin_base_kinetics400_imagenet22k.weights.h5&response-content-type=application%2Foctet-stream [following]\r\n", + "--2024-03-31 16:59:25-- https://objects.githubusercontent.com/github-production-release-asset-2e65be/697696973/75b53567-f9ae-4739-87c1-0d5d9d423f25?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240331%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240331T165925Z&X-Amz-Expires=300&X-Amz-Signature=b3146437a5644138b963a8e376f8cc066e3ff2c0b4bb7a05e1f705e930095453&X-Amz-SignedHeaders=host&actor_id=0&key_id=0&repo_id=697696973&response-content-disposition=attachment%3B%20filename%3Dvideoswin_base_kinetics400_imagenet22k.weights.h5&response-content-type=application%2Foctet-stream\r\n", + "Resolving objects.githubusercontent.com (objects.githubusercontent.com)... 185.199.110.133, 185.199.111.133, 185.199.108.133, ...\r\n", + "Connecting to objects.githubusercontent.com (objects.githubusercontent.com)|185.199.110.133|:443... connected.\r\n", + "HTTP request sent, awaiting response... 
200 OK\r\n", + "Length: 351381896 (335M) [application/octet-stream]\r\n", + "Saving to: 'videoswin_base_kinetics400_imagenet22k.weights.h5'\r\n", + "\r\n", + "videoswin_base_kine 100%[===================>] 335.10M 47.1MB/s in 6.5s \r\n", + "\r\n", + "2024-03-31 16:59:32 (51.9 MB/s) - 'videoswin_base_kinetics400_imagenet22k.weights.h5' saved [351381896/351381896]\r\n", + "\r\n" + ] + } + ], + "source": [ + "!wget https://github.com/innat/VideoSwin/releases/download/v2.0/videoswin_base_kinetics400_imagenet22k.weights.h5\n", + "\n", + "def vswin_base():\n", + " backbone=VideoSwinBackbone(\n", + " input_shape=(32, 224, 224, 3), \n", + " embed_dim=128,\n", + " depths=[2, 2, 18, 2],\n", + " num_heads=[4, 8, 16, 32],\n", + " include_rescaling=False, \n", + " )\n", + " backbone.load_weights(\n", + " 'videoswin_base_kinetics400_imagenet22k.weights.h5'\n", + " )\n", + " return backbone" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "2da61fbf", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-31T16:59:32.846326Z", + "iopub.status.busy": "2024-03-31T16:59:32.845846Z", + "iopub.status.idle": "2024-03-31T16:59:35.458124Z", + "shell.execute_reply": "2024-03-31T16:59:35.456891Z" + }, + "papermill": { + "duration": 2.644126, + "end_time": "2024-03-31T16:59:35.461400", + "exception": false, + "start_time": "2024-03-31T16:59:32.817274", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "keras_models = vswin_base()" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "29944adc", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-31T16:59:35.518541Z", + "iopub.status.busy": "2024-03-31T16:59:35.518114Z", + "iopub.status.idle": "2024-03-31T16:59:58.027938Z", + "shell.execute_reply": "2024-03-31T16:59:58.026294Z" + }, + "papermill": { + "duration": 22.542295, + "end_time": "2024-03-31T16:59:58.031557", + "exception": false, + "start_time": "2024-03-31T16:59:35.489262", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Downloading: \"https://download.pytorch.org/models/swin3d_b_22k-7c6ae6fa.pth\" to /root/.cache/torch/hub/checkpoints/swin3d_b_22k-7c6ae6fa.pth\n", + "100%|██████████| 364M/364M [00:19<00:00, 19.7MB/s]\n" + ] + } + ], + "source": [ + "import torchvision\n", + "from torchvision.models.video import Swin3D_B_Weights\n", + "\n", + "torch_model = torchvision.models.video.swin3d_b(\n", + " weights=Swin3D_B_Weights.KINETICS400_IMAGENET22K_V1\n", + ").eval()\n", + "torch_model = exclude_top(torch_model)" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "5bb18406", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-31T16:59:58.125789Z", + "iopub.status.busy": "2024-03-31T16:59:58.124570Z", + "iopub.status.idle": "2024-03-31T17:00:31.735065Z", + "shell.execute_reply": "2024-03-31T17:00:31.733788Z" + }, + "papermill": { + "duration": 33.660037, + "end_time": "2024-03-31T17:00:31.738443", + "exception": false, + "start_time": "2024-03-31T16:59:58.078406", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([1, 16, 7, 7, 1024]) torch.Size([1, 16, 7, 7, 1024])\n" + ] + } + ], + "source": [ + "logit_checking(\n", + " keras_models, torch_model\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "20f490d2", + "metadata": { + "papermill": { + "duration": 0.042622, + "end_time": 
"2024-03-31T17:00:31.825841", + "exception": false, + "start_time": "2024-03-31T17:00:31.783219", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kaggle": { + "accelerator": "none", + "dataSources": [ + { + "modelInstanceId": 17431, + "sourceId": 21048, + "sourceType": "modelInstanceVersion" + }, + { + "modelInstanceId": 17474, + "sourceId": 21097, + "sourceType": "modelInstanceVersion" + }, + { + "modelInstanceId": 17533, + "sourceId": 21184, + "sourceType": "modelInstanceVersion" + } + ], + "dockerImageVersionId": 30673, + "isGpuEnabled": false, + "isInternetEnabled": true, + "language": "python", + "sourceType": "notebook" + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + }, + "papermill": { + "default_parameters": {}, + "duration": 246.289645, + "end_time": "2024-03-31T17:00:36.170724", + "environment_variables": {}, + "exception": null, + "input_path": "__notebook__.ipynb", + "output_path": "__notebook__.ipynb", + "parameters": {}, + "start_time": "2024-03-31T16:56:29.881079", + "version": "2.5.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/guides/kerascv-torchvision-video-swin-logits-comparison.ipynb b/guides/kerascv-torchvision-video-swin-logits-comparison.ipynb new file mode 100644 index 0000000..2d9fb0b --- /dev/null +++ b/guides/kerascv-torchvision-video-swin-logits-comparison.ipynb @@ -0,0 +1,979 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "cc7788a6", + "metadata": { + "papermill": { + "duration": 0.006608, + "end_time": "2024-03-28T18:55:53.448019", + "exception": false, + "start_time": "2024-03-28T18:55:53.441411", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "# About\n", + "\n", + "This notebook demonstrates the identical results of vidoe swin transformer, imported from `keras-cv` and `torch-vision` libraries. The `keras-cv` version of video swin is implemented in `keras 3`, makes it able to run in multiple backend, i.e. `tensorflow`, `torch`, and `jax`." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "86437e84", + "metadata": { + "_cell_guid": "b1076dfc-b9ad-4769-8c92-a6c4dae69d19", + "_uuid": "8f2839f25d086af736a60e9eeb907d3b93b6e0e5", + "execution": { + "iopub.execute_input": "2024-03-28T18:55:53.461887Z", + "iopub.status.busy": "2024-03-28T18:55:53.461497Z", + "iopub.status.idle": "2024-03-28T18:55:54.343326Z", + "shell.execute_reply": "2024-03-28T18:55:54.342152Z" + }, + "papermill": { + "duration": 0.89173, + "end_time": "2024-03-28T18:55:54.346095", + "exception": false, + "start_time": "2024-03-28T18:55:53.454365", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "import numpy as np # linear algebra\n", + "import pandas as pd # data processing, CSV file I/O (e.g. 
pd.read_csv)\n", + "import os\n", + "import warnings" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "6c8b9d13", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-28T18:55:54.359527Z", + "iopub.status.busy": "2024-03-28T18:55:54.359011Z", + "iopub.status.idle": "2024-03-28T18:55:54.364144Z", + "shell.execute_reply": "2024-03-28T18:55:54.363316Z" + }, + "papermill": { + "duration": 0.014217, + "end_time": "2024-03-28T18:55:54.366189", + "exception": false, + "start_time": "2024-03-28T18:55:54.351972", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "os.environ[\"KERAS_BACKEND\"] = \"torch\" # 'torch', 'tensorflow', 'jax'\n", + "\n", + "warnings.simplefilter(action=\"ignore\")\n", + "os.environ[\"TF_CPP_MIN_LOG_LEVEL\"] = \"3\"" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "9641510f", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-28T18:55:54.379961Z", + "iopub.status.busy": "2024-03-28T18:55:54.378975Z", + "iopub.status.idle": "2024-03-28T18:56:26.782756Z", + "shell.execute_reply": "2024-03-28T18:56:26.781252Z" + }, + "papermill": { + "duration": 32.413608, + "end_time": "2024-03-28T18:56:26.785582", + "exception": false, + "start_time": "2024-03-28T18:55:54.371974", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Cloning into 'keras-cv'...\r\n", + "remote: Enumerating objects: 13735, done.\u001b[K\r\n", + "remote: Counting objects: 100% (1872/1872), done.\u001b[K\r\n", + "remote: Compressing objects: 100% (752/752), done.\u001b[K\r\n", + "remote: Total 13735 (delta 1297), reused 1587 (delta 1104), pack-reused 11863\u001b[K\r\n", + "Receiving objects: 100% (13735/13735), 25.64 MiB | 31.71 MiB/s, done.\r\n", + "Resolving deltas: 100% (9742/9742), done.\r\n", + "/kaggle/working/keras-cv\n" + ] + } + ], + "source": [ + "!git clone --branch video_swin https://github.com/innat/keras-cv.git\n", + "%cd keras-cv\n", + "!pip install -q -e ." 
+ ] + }, + { + "cell_type": "markdown", + "id": "dbbc08fc", + "metadata": { + "papermill": { + "duration": 0.007, + "end_time": "2024-03-28T18:56:26.800170", + "exception": false, + "start_time": "2024-03-28T18:56:26.793170", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "# KerasCV: Video Swin : Pretrained: ImageNet 1K" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "a42e9c56", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-28T18:56:26.816883Z", + "iopub.status.busy": "2024-03-28T18:56:26.816366Z", + "iopub.status.idle": "2024-03-28T18:56:52.607621Z", + "shell.execute_reply": "2024-03-28T18:56:52.606236Z" + }, + "papermill": { + "duration": 25.803068, + "end_time": "2024-03-28T18:56:52.610477", + "exception": false, + "start_time": "2024-03-28T18:56:26.807409", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "'3.0.5'" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import keras\n", + "from keras import ops\n", + "from keras_cv.models import VideoSwinBackbone\n", + "from keras_cv.models import VideoClassifier\n", + "\n", + "keras.__version__" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "af7810e0", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-28T18:56:52.629367Z", + "iopub.status.busy": "2024-03-28T18:56:52.628675Z", + "iopub.status.idle": "2024-03-28T18:56:52.635618Z", + "shell.execute_reply": "2024-03-28T18:56:52.634364Z" + }, + "papermill": { + "duration": 0.019551, + "end_time": "2024-03-28T18:56:52.638019", + "exception": false, + "start_time": "2024-03-28T18:56:52.618468", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "def vswin_tiny():\n", + " backbone=VideoSwinBackbone(\n", + " input_shape=(32, 224, 224, 3), \n", + " embed_dim=96,\n", + " depths=[2, 2, 6, 2],\n", + " num_heads=[3, 6, 12, 24],\n", + " include_rescaling=False, \n", + " )\n", + " keras_model = VideoClassifier(\n", + " backbone=backbone,\n", + " num_classes=400,\n", + " activation=None,\n", + " pooling='avg',\n", + " )\n", + " keras_model.load_weights(\n", + " '/kaggle/input/videoswin/keras/tiny/1/videoswin_tiny_kinetics400_classifier.weights.h5'\n", + " )\n", + " return keras_model" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "f27545dc", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-28T18:56:52.655360Z", + "iopub.status.busy": "2024-03-28T18:56:52.654951Z", + "iopub.status.idle": "2024-03-28T18:56:52.661588Z", + "shell.execute_reply": "2024-03-28T18:56:52.660250Z" + }, + "papermill": { + "duration": 0.018416, + "end_time": "2024-03-28T18:56:52.664262", + "exception": false, + "start_time": "2024-03-28T18:56:52.645846", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "def vswin_small():\n", + " backbone=VideoSwinBackbone(\n", + " input_shape=(32, 224, 224, 3), \n", + " embed_dim=96,\n", + " depths=[2, 2, 18, 2],\n", + " num_heads=[3, 6, 12, 24],\n", + " include_rescaling=False, \n", + " )\n", + " keras_model = VideoClassifier(\n", + " backbone=backbone,\n", + " num_classes=400,\n", + " activation=None,\n", + " pooling='avg',\n", + " )\n", + " keras_model.load_weights(\n", + " '/kaggle/input/videoswin/keras/small/1/videoswin_small_kinetics400_classifier.weights.h5'\n", + " )\n", + " return keras_model" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "0a6bcb76", + "metadata": { + 
"execution": { + "iopub.execute_input": "2024-03-28T18:56:52.682793Z", + "iopub.status.busy": "2024-03-28T18:56:52.681811Z", + "iopub.status.idle": "2024-03-28T18:56:52.688494Z", + "shell.execute_reply": "2024-03-28T18:56:52.687403Z" + }, + "papermill": { + "duration": 0.0189, + "end_time": "2024-03-28T18:56:52.690784", + "exception": false, + "start_time": "2024-03-28T18:56:52.671884", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "def vswin_base():\n", + " backbone=VideoSwinBackbone(\n", + " input_shape=(32, 224, 224, 3), \n", + " embed_dim=128,\n", + " depths=[2, 2, 18, 2],\n", + " num_heads=[4, 8, 16, 32],\n", + " include_rescaling=False, \n", + " )\n", + " keras_model = VideoClassifier(\n", + " backbone=backbone,\n", + " num_classes=400,\n", + " activation=None,\n", + " pooling='avg',\n", + " )\n", + " keras_model.load_weights(\n", + " '/kaggle/input/videoswin/keras/base/1/videoswin_base_kinetics400_classifier.weights.h5'\n", + " )\n", + " return keras_model" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "25687c67", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-28T18:56:52.709040Z", + "iopub.status.busy": "2024-03-28T18:56:52.708618Z", + "iopub.status.idle": "2024-03-28T18:57:07.215996Z", + "shell.execute_reply": "2024-03-28T18:57:07.215121Z" + }, + "papermill": { + "duration": 14.519874, + "end_time": "2024-03-28T18:57:07.218056", + "exception": false, + "start_time": "2024-03-28T18:56:52.698182", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
Model: \"video_classifier\"\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1mModel: \"video_classifier\"\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
+       "┃ Layer (type)                     Output Shape                  Param # ┃\n",
+       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
+       "│ videos (InputLayer)             │ (None, 32, 224, 224,   │             0 │\n",
+       "│                                 │ 3)                     │               │\n",
+       "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
+       "│ video_swin_backbone             │ (None, 16, 7, 7, 768)  │    27,850,470 │\n",
+       "│ (VideoSwinBackbone)             │                        │               │\n",
+       "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
+       "│ avg_pool                        │ (None, 768)            │             0 │\n",
+       "│ (GlobalAveragePooling3D)        │                        │               │\n",
+       "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
+       "│ predictions (Dense)             │ (None, 400)            │       307,600 │\n",
+       "└─────────────────────────────────┴────────────────────────┴───────────────┘\n",
+       "
\n" + ], + "text/plain": [ + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n", + "│ videos (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m32\u001b[0m, \u001b[38;5;34m224\u001b[0m, \u001b[38;5;34m224\u001b[0m, │ \u001b[38;5;34m0\u001b[0m │\n", + "│ │ \u001b[38;5;34m3\u001b[0m) │ │\n", + "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", + "│ video_swin_backbone │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m16\u001b[0m, \u001b[38;5;34m7\u001b[0m, \u001b[38;5;34m7\u001b[0m, \u001b[38;5;34m768\u001b[0m) │ \u001b[38;5;34m27,850,470\u001b[0m │\n", + "│ (\u001b[38;5;33mVideoSwinBackbone\u001b[0m) │ │ │\n", + "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", + "│ avg_pool │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m768\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │\n", + "│ (\u001b[38;5;33mGlobalAveragePooling3D\u001b[0m) │ │ │\n", + "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", + "│ predictions (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m400\u001b[0m) │ \u001b[38;5;34m307,600\u001b[0m │\n", + "└─────────────────────────────────┴────────────────────────┴───────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
 Total params: 28,158,070 (107.41 MB)\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m28,158,070\u001b[0m (107.41 MB)\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
 Trainable params: 28,158,070 (107.41 MB)\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m28,158,070\u001b[0m (107.41 MB)\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
 Non-trainable params: 0 (0.00 B)\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "keras_models = [vswin_tiny(), vswin_small(), vswin_base()]\n", + "keras_models[0].summary()" + ] + }, + { + "cell_type": "markdown", + "id": "544542ba", + "metadata": { + "papermill": { + "duration": 0.00794, + "end_time": "2024-03-28T18:57:07.234196", + "exception": false, + "start_time": "2024-03-28T18:57:07.226256", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "# TorchVision: Video Swin : Pretrained: ImageNet 1K" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "a8fe6d48", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-28T18:57:07.254724Z", + "iopub.status.busy": "2024-03-28T18:57:07.253993Z", + "iopub.status.idle": "2024-03-28T18:57:07.628853Z", + "shell.execute_reply": "2024-03-28T18:57:07.627884Z" + }, + "papermill": { + "duration": 0.388364, + "end_time": "2024-03-28T18:57:07.631527", + "exception": false, + "start_time": "2024-03-28T18:57:07.243163", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "import torch\n", + "import torchvision\n", + "from torchinfo import summary\n", + "from torchvision.models.video import Swin3D_T_Weights, Swin3D_S_Weights, Swin3D_B_Weights\n", + "\n", + "def torch_vswin_tiny():\n", + " torch_model = torchvision.models.video.swin3d_t(\n", + " weights=Swin3D_T_Weights.KINETICS400_V1\n", + " ).eval()\n", + " return torch_model\n", + "\n", + "def torch_vswin_small():\n", + " torch_model = torchvision.models.video.swin3d_s(\n", + " weights=Swin3D_S_Weights.KINETICS400_V1\n", + " ).eval()\n", + " return torch_model\n", + "\n", + "def torch_vswin_base():\n", + " torch_model = torchvision.models.video.swin3d_b(\n", + " weights=Swin3D_B_Weights.KINETICS400_V1\n", + " ).eval()\n", + " return torch_model" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "baa9604b", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-28T18:57:07.649608Z", + "iopub.status.busy": "2024-03-28T18:57:07.649212Z", + "iopub.status.idle": "2024-03-28T18:57:34.795055Z", + "shell.execute_reply": "2024-03-28T18:57:34.794149Z" + }, + "papermill": { + "duration": 27.157606, + "end_time": "2024-03-28T18:57:34.797479", + "exception": false, + "start_time": "2024-03-28T18:57:07.639873", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Downloading: \"https://download.pytorch.org/models/swin3d_t-7615ae03.pth\" to /root/.cache/torch/hub/checkpoints/swin3d_t-7615ae03.pth\n", + "100%|██████████| 122M/122M [00:02<00:00, 54.0MB/s]\n", + "Downloading: \"https://download.pytorch.org/models/swin3d_s-da41c237.pth\" to /root/.cache/torch/hub/checkpoints/swin3d_s-da41c237.pth\n", + "100%|██████████| 218M/218M [00:04<00:00, 55.4MB/s]\n", + "Downloading: \"https://download.pytorch.org/models/swin3d_b_1k-24f7c7c6.pth\" to /root/.cache/torch/hub/checkpoints/swin3d_b_1k-24f7c7c6.pth\n", + "100%|██████████| 364M/364M [00:06<00:00, 57.0MB/s]\n" + ] + }, + { + "data": { + "text/plain": [ + "=========================================================================================================\n", + "Layer (type:depth-idx) Output Shape Param #\n", + "=========================================================================================================\n", + "SwinTransformer3d [1, 400] --\n", + "├─PatchEmbed3d: 1-1 [1, 16, 
56, 56, 96] --\n", + "│ └─Conv3d: 2-1 [1, 96, 16, 56, 56] 9,312\n", + "│ └─LayerNorm: 2-2 [1, 16, 56, 56, 96] 192\n", + "├─Dropout: 1-2 [1, 16, 56, 56, 96] --\n", + "├─Sequential: 1-3 [1, 16, 7, 7, 768] --\n", + "│ └─Sequential: 2-3 [1, 16, 56, 56, 96] --\n", + "│ │ └─SwinTransformerBlock: 3-1 [1, 16, 56, 56, 96] 119,445\n", + "│ │ └─SwinTransformerBlock: 3-2 [1, 16, 56, 56, 96] 119,445\n", + "│ └─PatchMerging: 2-4 [1, 16, 28, 28, 192] --\n", + "│ │ └─LayerNorm: 3-3 [1, 16, 28, 28, 384] 768\n", + "│ │ └─Linear: 3-4 [1, 16, 28, 28, 192] 73,728\n", + "│ └─Sequential: 2-5 [1, 16, 28, 28, 192] --\n", + "│ │ └─SwinTransformerBlock: 3-5 [1, 16, 28, 28, 192] 460,074\n", + "│ │ └─SwinTransformerBlock: 3-6 [1, 16, 28, 28, 192] 460,074\n", + "│ └─PatchMerging: 2-6 [1, 16, 14, 14, 384] --\n", + "│ │ └─LayerNorm: 3-7 [1, 16, 14, 14, 768] 1,536\n", + "│ │ └─Linear: 3-8 [1, 16, 14, 14, 384] 294,912\n", + "│ └─Sequential: 2-7 [1, 16, 14, 14, 384] --\n", + "│ │ └─SwinTransformerBlock: 3-9 [1, 16, 14, 14, 384] 1,804,884\n", + "│ │ └─SwinTransformerBlock: 3-10 [1, 16, 14, 14, 384] 1,804,884\n", + "│ │ └─SwinTransformerBlock: 3-11 [1, 16, 14, 14, 384] 1,804,884\n", + "│ │ └─SwinTransformerBlock: 3-12 [1, 16, 14, 14, 384] 1,804,884\n", + "│ │ └─SwinTransformerBlock: 3-13 [1, 16, 14, 14, 384] 1,804,884\n", + "│ │ └─SwinTransformerBlock: 3-14 [1, 16, 14, 14, 384] 1,804,884\n", + "│ └─PatchMerging: 2-8 [1, 16, 7, 7, 768] --\n", + "│ │ └─LayerNorm: 3-15 [1, 16, 7, 7, 1536] 3,072\n", + "│ │ └─Linear: 3-16 [1, 16, 7, 7, 768] 1,179,648\n", + "│ └─Sequential: 2-9 [1, 16, 7, 7, 768] --\n", + "│ │ └─SwinTransformerBlock: 3-17 [1, 16, 7, 7, 768] 7,148,712\n", + "│ │ └─SwinTransformerBlock: 3-18 [1, 16, 7, 7, 768] 7,148,712\n", + "├─LayerNorm: 1-4 [1, 16, 7, 7, 768] 1,536\n", + "├─AdaptiveAvgPool3d: 1-5 [1, 768, 1, 1, 1] --\n", + "├─Linear: 1-6 [1, 400] 307,600\n", + "=========================================================================================================\n", + "Total params: 28,158,070\n", + "Trainable params: 28,158,070\n", + "Non-trainable params: 0\n", + "Total mult-adds (M): 486.39\n", + "=========================================================================================================\n", + "Input size (MB): 19.27\n", + "Forward/backward pass size (MB): 1464.34\n", + "Params size (MB): 76.66\n", + "Estimated Total Size (MB): 1560.26\n", + "=========================================================================================================" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "torch_models = [torch_vswin_tiny(), torch_vswin_small(), torch_vswin_base()]\n", + "summary(\n", + " torch_models[0], input_size=(1, 3, 32, 224, 224)\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "967ea820", + "metadata": { + "papermill": { + "duration": 0.015667, + "end_time": "2024-03-28T18:57:34.828907", + "exception": false, + "start_time": "2024-03-28T18:57:34.813240", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "# Inference" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "45dcf674", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-28T18:57:34.861533Z", + "iopub.status.busy": "2024-03-28T18:57:34.860787Z", + "iopub.status.idle": "2024-03-28T18:57:35.023585Z", + "shell.execute_reply": "2024-03-28T18:57:35.022058Z" + }, + "papermill": { + "duration": 0.182132, + "end_time": "2024-03-28T18:57:35.025987", + "exception": false, + "start_time": 
"2024-03-28T18:57:34.843855", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([1, 32, 224, 224, 3]) torch.Size([1, 3, 32, 224, 224])\n" + ] + } + ], + "source": [ + "common_input = np.random.normal(0, 1, (1, 32, 224, 224, 3)).astype('float32')\n", + "keras_input = ops.array(common_input)\n", + "torch_input = torch.from_numpy(common_input.transpose(0, 4, 1, 2, 3))\n", + "print(keras_input.shape, torch_input.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "1758718d", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-28T18:57:35.058379Z", + "iopub.status.busy": "2024-03-28T18:57:35.057419Z", + "iopub.status.idle": "2024-03-28T18:57:35.065129Z", + "shell.execute_reply": "2024-03-28T18:57:35.063980Z" + }, + "papermill": { + "duration": 0.026627, + "end_time": "2024-03-28T18:57:35.067490", + "exception": false, + "start_time": "2024-03-28T18:57:35.040863", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "def logit_checking(keras_model, torch_model):\n", + " # forward pass\n", + " keras_predict = keras_model(keras_input)\n", + " torch_predict = torch_model(torch_input)\n", + " print(keras_predict.shape, torch_predict.shape)\n", + " print('keras logits: ', keras_predict[0, :5])\n", + " print('torch logits: ', torch_predict[0, :5], end='\\n')\n", + " np.testing.assert_allclose(\n", + " keras_predict.detach().numpy(),\n", + " torch_predict.detach().numpy(),\n", + " 1e-5, 1e-5\n", + " )\n", + " del keras_model \n", + " del torch_model" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "68a1e59a", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-28T18:57:35.100767Z", + "iopub.status.busy": "2024-03-28T18:57:35.100079Z", + "iopub.status.idle": "2024-03-28T18:59:16.615601Z", + "shell.execute_reply": "2024-03-28T18:59:16.613891Z" + }, + "papermill": { + "duration": 101.535287, + "end_time": "2024-03-28T18:59:16.618485", + "exception": false, + "start_time": "2024-03-28T18:57:35.083198", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([1, 400]) torch.Size([1, 400])\n", + "keras logits: tensor([-0.0906, 1.2267, 1.1639, -0.3530, -1.5449], grad_fn=)\n", + "torch logits: tensor([-0.0906, 1.2267, 1.1639, -0.3530, -1.5449], grad_fn=)\n", + "torch.Size([1, 400]) torch.Size([1, 400])\n", + "keras logits: tensor([ 0.6399, 1.2136, 0.9395, -0.4962, -1.9626], grad_fn=)\n", + "torch logits: tensor([ 0.6399, 1.2136, 0.9395, -0.4962, -1.9626], grad_fn=)\n", + "torch.Size([1, 400]) torch.Size([1, 400])\n", + "keras logits: tensor([ 1.1572, 0.0092, 0.0929, -1.8786, -2.8799], grad_fn=)\n", + "torch logits: tensor([ 1.1572, 0.0092, 0.0929, -1.8786, -2.8799], grad_fn=)\n" + ] + } + ], + "source": [ + "for km, tm in zip(keras_models, torch_models):\n", + " logit_checking(\n", + " km, tm\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "8b7aefae", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-28T18:59:16.653715Z", + "iopub.status.busy": "2024-03-28T18:59:16.653289Z", + "iopub.status.idle": "2024-03-28T18:59:17.266830Z", + "shell.execute_reply": "2024-03-28T18:59:17.265516Z" + }, + "papermill": { + "duration": 0.633776, + "end_time": "2024-03-28T18:59:17.269193", + "exception": false, + "start_time": "2024-03-28T18:59:16.635417", + "status": "completed" + }, + "tags": [] + }, + 
"outputs": [ + { + "data": { + "text/plain": [ + "27" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import gc\n", + "gc.collect()" + ] + }, + { + "cell_type": "markdown", + "id": "e73db7f4", + "metadata": { + "papermill": { + "duration": 0.01687, + "end_time": "2024-03-28T18:59:17.302135", + "exception": false, + "start_time": "2024-03-28T18:59:17.285265", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "# Keras: Video Swin Base - Pretrained: ImageNet 22K" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "ef9221ed", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-28T18:59:17.334873Z", + "iopub.status.busy": "2024-03-28T18:59:17.334394Z", + "iopub.status.idle": "2024-03-28T18:59:17.341984Z", + "shell.execute_reply": "2024-03-28T18:59:17.340602Z" + }, + "papermill": { + "duration": 0.027379, + "end_time": "2024-03-28T18:59:17.344960", + "exception": false, + "start_time": "2024-03-28T18:59:17.317581", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "def vswin_base():\n", + " backbone=VideoSwinBackbone(\n", + " input_shape=(32, 224, 224, 3), \n", + " embed_dim=128,\n", + " depths=[2, 2, 18, 2],\n", + " num_heads=[4, 8, 16, 32],\n", + " include_rescaling=False, \n", + " )\n", + " keras_model = VideoClassifier(\n", + " backbone=backbone,\n", + " num_classes=400,\n", + " activation=None,\n", + " pooling='avg',\n", + " )\n", + " keras_model.load_weights(\n", + " '/kaggle/input/videoswin/keras/base/1/videoswin_base_kinetics400_imagenet22k_classifier.weights.h5'\n", + " )\n", + " return keras_model" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "1f8d0375", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-28T18:59:17.380760Z", + "iopub.status.busy": "2024-03-28T18:59:17.380288Z", + "iopub.status.idle": "2024-03-28T18:59:23.407797Z", + "shell.execute_reply": "2024-03-28T18:59:23.406765Z" + }, + "papermill": { + "duration": 6.048239, + "end_time": "2024-03-28T18:59:23.411009", + "exception": false, + "start_time": "2024-03-28T18:59:17.362770", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "keras_models = vswin_base()" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "25e03a40", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-28T18:59:23.444463Z", + "iopub.status.busy": "2024-03-28T18:59:23.444048Z", + "iopub.status.idle": "2024-03-28T18:59:33.888855Z", + "shell.execute_reply": "2024-03-28T18:59:33.887528Z" + }, + "papermill": { + "duration": 10.465804, + "end_time": "2024-03-28T18:59:33.892692", + "exception": false, + "start_time": "2024-03-28T18:59:23.426888", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Downloading: \"https://download.pytorch.org/models/swin3d_b_22k-7c6ae6fa.pth\" to /root/.cache/torch/hub/checkpoints/swin3d_b_22k-7c6ae6fa.pth\n", + "100%|██████████| 364M/364M [00:07<00:00, 51.8MB/s]\n" + ] + } + ], + "source": [ + "import torchvision\n", + "from torchvision.models.video import Swin3D_B_Weights\n", + "\n", + "torch_model = torchvision.models.video.swin3d_b(\n", + " weights=Swin3D_B_Weights.KINETICS400_IMAGENET22K_V1\n", + ").eval()" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "908ae048", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-28T18:59:33.936777Z", + "iopub.status.busy": 
"2024-03-28T18:59:33.935944Z", + "iopub.status.idle": "2024-03-28T19:00:04.235561Z", + "shell.execute_reply": "2024-03-28T19:00:04.234520Z" + }, + "papermill": { + "duration": 30.323177, + "end_time": "2024-03-28T19:00:04.238763", + "exception": false, + "start_time": "2024-03-28T18:59:33.915586", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([1, 400]) torch.Size([1, 400])\n", + "keras logits: tensor([ 0.2773, 0.8488, 1.4034, -1.0703, -1.4610], grad_fn=)\n", + "torch logits: tensor([ 0.2773, 0.8488, 1.4034, -1.0703, -1.4610], grad_fn=)\n" + ] + } + ], + "source": [ + "logit_checking(\n", + " keras_models, torch_model\n", + ")" + ] + } + ], + "metadata": { + "kaggle": { + "accelerator": "none", + "dataSources": [ + { + "isSourceIdPinned": true, + "modelInstanceId": 17431, + "sourceId": 21048, + "sourceType": "modelInstanceVersion" + }, + { + "isSourceIdPinned": true, + "modelInstanceId": 17474, + "sourceId": 21097, + "sourceType": "modelInstanceVersion" + }, + { + "isSourceIdPinned": true, + "modelInstanceId": 17533, + "sourceId": 21184, + "sourceType": "modelInstanceVersion" + } + ], + "dockerImageVersionId": 30673, + "isGpuEnabled": false, + "isInternetEnabled": true, + "language": "python", + "sourceType": "notebook" + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + }, + "papermill": { + "default_parameters": {}, + "duration": 257.248119, + "end_time": "2024-03-28T19:00:07.884277", + "environment_variables": {}, + "exception": null, + "input_path": "__notebook__.ipynb", + "output_path": "__notebook__.ipynb", + "parameters": {}, + "start_time": "2024-03-28T18:55:50.636158", + "version": "2.5.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}