diff --git a/.gitignore b/.gitignore index b1bffcb..b07a8d2 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,8 @@ spliceai_train_code/ -note/.ipynb_checkpoints +data/ +note/.ipynb_checkpoints/ +*.egg-info/ + +**/__pycache__/ diff --git a/note/investigating_dataset_h5.ipynb b/note/investigating_dataset_h5.ipynb index 10c6835..69a6bdc 100644 --- a/note/investigating_dataset_h5.ipynb +++ b/note/investigating_dataset_h5.ipynb @@ -7,12 +7,13 @@ "outputs": [], "source": [ "import h5py\n", - "import numpy as np" + "import numpy as np\n", + "import torch" ] }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -48,42 +49,68 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "" + "torch.Size([5662, 5080, 4])" ] }, - "execution_count": 15, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "hf['X0']" + "torch.from_numpy(hf['X0'][:]).shape" ] }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "" + "tensor([[0., 0., 0., 0.],\n", + " [0., 0., 0., 0.],\n", + " [0., 0., 0., 0.],\n", + " ...,\n", + " [0., 1., 0., 0.],\n", + " [0., 0., 0., 1.],\n", + " [1., 0., 0., 0.]])" ] }, - "execution_count": 16, + "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "hf['Y0']" + "torch.from_numpy(hf['X0'][:])[0].float()" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([1, 5662, 5000, 3])" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "torch.from_numpy(hf['Y0'][:]).shape" ] }, { diff --git a/note/multistep_lr_test.ipynb b/note/multistep_lr_test.ipynb new file mode 100644 index 0000000..a01822f --- /dev/null +++ b/note/multistep_lr_test.ipynb @@ -0,0 +1,96 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import torch.optim as optim\n", + "from torch.optim.lr_scheduler import MultiStepLR\n", + "import matplotlib.pyplot as plt" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYcAAAD4CAYAAAAHHSreAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAeeElEQVR4nO3de3hU9b3v8fc395BAAuQyAQIBCUISUDAiCvVSoASlgnqs2NPKqT1FWvS057S7otvebbft3m09WqvFdrdY21KqUrMVoYBYKwUFvAAhXGJQEi4hAQG5hpDf/iPLNiQhGZKQNZP5vJ4nz8ys+f1mPmse8TOz1po15pxDRESksSi/A4iISOhROYiISDMqBxERaUblICIizagcRESkmRi/A3SGtLQ0l5OT0665dXV1AMTE+PdShEIG5VCOUM+gHJ2fY8OGDTXOufSW7usW5ZCTk8P69evbNbempgaAtLS0zowUdhmUQzlCPYNydH4OM3v/XPdps5KIiDSjchARkWZUDiIi0ozKQUREmlE5iIhIM0GVg5kVmdk2Myszs3kt3G9m9oh3/0YzG9PWXDO71cxKzKzezAqbPN593vhtZjalIysoIiLnr81yMLNo4DFgKpAH3G5meU2GTQVyvb/ZwONBzN0M3Ay82uT58oCZQD5QBPzcexwREekiwXzPYSxQ5pwrBzCzhcB0YEujMdOBp1zD+b/XmlmqmWUBOeea65wr9ZY1fb7pwELn3Clgp5mVeRnWtG8Vz23f4ZOsLNkHQNWpA5398EHLjD/tewblCN0c/RLruHJIGv4eUS+RJphy6A9UNLpdCVwRxJj+Qc5t6fnWtvBYZzGz2TR8SmHgwIFtPGTLqo6cZEXpfgBeef9Eux6jM1w7KNH3DMoRwjkGJnKmHgb1z/Q1h0SWYMqh2Vt7oOkvBJ1rTDBz2/N8OOfmA/MBCgsL2/WLRZdkp/Lv/2MUoG+fKkfo5vjuM6+zefdhXzNI5Almh3QlkN3o9gBgT5BjgpnbnucTiRgj+6ew7/BJdtYc8zuKRJBgymEdkGtmg80sjoadxcVNxhQDd3hHLY0DDjvn9gY5t6liYKaZxZvZYBp2cr9xHusk0q0U9E8BYOnmfT4nkUjSZjk45+qAu4FlQCmwyDlXYmZzzGyON2wJUA6UAU8CX2ptLoCZ3WRmlcCVwItmtsybUwIsomGH91JgrnPuTCetr0jY6d0jluw+iSwtUTlI1wnqrKzOuSU0FEDjZU80uu6AucHO9ZYvBhafY873ge8Hk00kEhT0T+Gpt3ay59AJ+qUm+h1HIoC+IS0SBkZ6m5b+ok8P0kVUDiJhIL1nPMMyk7VpSbqMykEkTBTlB3hj50EOHD3ldxSJACoHkTAxpSBAvYMVpVV+R5EIoHIQCRN5Wb3I7pPISzqkVbqAykEkTJgZRfkBVpfVcOTkab/jSDenchAJI0UFAU6fcazaut/vKNLNqRxEwsjo7N5k9IzXt6XlglM5iISRqChjSn6AV7ZVc6JWJw6QC0flIBJmigoCnDh9hld3VPsdRboxlYNImBk7uA+pPWJZpk1LcgGpHETCTGx0FJNGZLKitIraunq/40g3pXIQCUNF+QGOnKxjTbm/P2Eq3ZfKQSQMTchNo0dctI5akgtG5SAShhJio7lueAbLt+zjTH27fiVXpFUqB5EwNbUgQM3RWja8/4HfUaQbUjmIhKlrL84gLiZKm5bkglA5iISp5PgYrs5NY1nJPhp+jFGk86gcRMLYlPwAuw+dYPPuI35HkW5G5SASxiaNyCQ6ylhastfvKNLNqBxEwljvpDjGDemj/Q7S6VQOImGuKD/Au9XH2FH1od9RpBtROYiEuU/kBwD06UE6lcpBJMxl9krgskG9WVqicpDOo3IQ6QaK8gOU7DlCxcHjfkeRbkLlININTPE2LS3TpwfpJCoHkW5gYN8e5GX10n4H6TQqB5FuoqggwIZdH7D/yEm/o0g3oHIQ6SaKCgI4B3/ZUuV3FOkGVA4i3URuRjJD0pK0aUk6hcpBpJswM6YUBFhTfoBDx2v9jiNhTuUg0o0U5Qc4U+9YUbrf7ygS5lQOIt3IqAEp9EtJ0KYl6bCgysHMisxsm5mVmdm8Fu43M3vEu3+jmY1pa66Z9TGz5Wa2w7vs7S2PNbMFZrbJzErN7L7OWFGRSPDRpqVXd1Rz7FSd33EkjLVZDmYWDTwGTAXygNvNLK/JsKlArvc3G3g8iLnzgJXOuVxgpXcb4FYg3jk3ErgMuMvMctq7giKRpig/QG1dPa9sq/Y7ioSxYD45jAXKnHPlzrlaYCEwvcmY6cBTrsFaINXMstqYOx1Y4F1fAMzwrjsgycxigESgFtAvmYgEqTCnD32T4nSuJemQYMqhP1DR6HaltyyYMa3NzXTO7QXwLjO85c8Ax4C9wC7gP5xzB5uGMrPZZrbezNZXV+sdkshHoqOMT+Rn8nJpFSdPn/E7joSpYMrBWljW9AdrzzUmmLlNjQXOAP2AwcBXzWxIswdxbr5zrtA5V5ient7GQ4pElin5AY7VnmF1WY3fUSRMBVMOlUB2o9sDgD1BjmltbpW36Qnv8qNj7z4NLHXOnXbO7QdWA4VB5BQRz1UXpdEzPkZHLUm7BVMO64BcMxtsZnHATKC4yZhi4A7vqKVxwGFvU1Frc4uBWd71WcDz3vVdwMe9x0oCxgFb27l+IhEpLiaKiSMyWF5aRd2Zer/jSBhqsxycc3XA3cAyoBRY5JwrMbM5ZjbHG7YEKAfKgCeBL7U215vzEDDZzHYAk73b0HB0UzKwmYZy+bVzbmNHV1Qk0hQVZHHo+Gne2Nlsl51Im2KCGeScW0JDATRe9kSj6w6YG+xcb/kBYGILy4/ScDiriHTANcPSSYiNYmnJPq4amuZ3HAkz+oa0SDeVGBfNtcMyWFayj/r6to4DETmbykGkGysqCFB15BRvVx7yO4qEGZWDSDd23fAMYqONZTpqSc6TykGkG0tJjOWqi9J4afM+GnYNigRH5SDSzRUVBNh18Dilez/0O4qEEZWDSDc3OS+TKEPnWpLzonIQ6ebSkuO5PKeP9jvIeVE5iESAooIA26o+pLz6qN9RJEyoHEQiwJT8AADLSqp8TiLhQuUgEgH6pSZyyYAU7XeQoKkcRCLElIIA71QcYs+hE35HkTCgchCJEEXepqW/6NODBEHlIBIhhqQnMywzmZd01JIEQeUgEkGK8gOse+8gNUdP+R1FQpzKQSSCFBVkUe9gxRYdtSStUzmIRJARWT0Z2KeHjlqSNqkcRCKImVFUEGB1WQ1HTp72O46EMJWDSISZkh/g9BnHqq37/Y4iIUzlIBJhRmenktEznqU6aklaoXIQiTBRUcaU/ACvbKvmRO0Zv+NIiFI5iESgooIAJ06f4a/bq/2OIiFK5SASgcYO7kNqj1iW6aglOQeVg0gEio2OYtKITFaUVlFbV+93HAlBKgeRCDW1IMCHJ+tYU37A7ygSglQOIhFq/NA0kuKiddSStEjlIBKhEmKjuW54Bsu37ONMvfM7joQYlYNIBCsqCFBztJYN73/gdxQJMSoHkQh27cUZxMVEadOSNKNyEIlgyfExXJ2bxrKSfTinTUvyTyoHkQg3JT/A7kMn2LT7sN9RJISoHEQi3KQRmURHmTYtyVlUDiIRrndSHFcO6cvSzdq0JP+kchARphQEKK85Rtn+o35HkRARVDmYWZGZbTOzMjOb18L9ZmaPePdvNLMxbc01sz5mttzMdniXvRvdN8rM1phZiZltMrOEjq6oiJzblLxMzNCmJfmHNsvBzKKBx4CpQB5wu5nlNRk2Fcj1/mYDjwcxdx6w0jmXC6z0bmNmMcDTwBznXD5wLaCfrBK5gDJ6JTBmYG/9fKj8QzCfHMYCZc65cudcLbAQmN5kzHTgKddgLZBqZlltzJ0OLPCuLwBmeNc/AWx0zr0D4Jw74JzTSedFLrCi/AAle45QcfC431EkBARTDv2Bika3K71lwYxpbW6mc24vgHeZ4S0fBjgzW2Zmb5rZ14NZERHpmCn5AUCblqRBMOVgLSxrekjDucYEM7epGGAC8D+9y5vMbGKzUGazzWy9ma2vrtYPloh01MC+PcjL6qVNSwIEVw6VQHaj2wOAPUGOaW1ulbfpCe/yo187rwT+6pyrcc4dB5YAY2jCOTffOVfonCtMT08PYjVEpC1FBQE2vP8B+4+c9DuK+CyYclgH5JrZYDOLA2YCxU3GFAN3eEctjQMOe5uKWptbDMzyrs8CnveuLwNGmVkPb+f0NcCWdq6fiJyHqQUNm5aWbanyOYn4rc1ycM7VAXfT8D/tUmCRc67EzOaY2Rxv2BKgHCgDngS+1Npcb85DwGQz2wFM9m7jnPsA+AkNxfI28KZz7sWOr6qItGVoRjJD0pNYpv0OES8mmEHOuSU0FEDjZU80uu6AucHO9ZYfAJrtS/Due5qGw1lFpAuZGUX5AX7xajmHjteS2iPO70jiE31DWkTOUlQQ4Ey9Y0Xp/rYHS7elchCRs4zsn0K/lAQd0hrhVA4ichYzY0pBgFd3VHPsVJ3fccQnKgcRaaYoP0BtXT2rtmnTUqRSOYhIM4U5fUhLjtOmpQimchCRZqKjjMl5AVZt3c/J0zq1WSRSOYhIi4oKAhyrPcPqshq/o4gPVA4i0qIrh/SlZ0KMNi1FKJWDiLQoLiaKSSMyWV5aRX29fj400qgcROScpuQHOHT8NOU1x/yOIl1M5SAi53TNsHQSYqPYtPuw31Gki6kcROScEuOiuXZYBpt2H8Zpy1JEUTmISKuKCgJ8eKKOXfr50IiichCRVl03PIPoKLRpKcKoHESkVSmJsQzN7OltWtK2pUihchCRNo3sn8LBo7WsKT/gdxTpIioHEWnT6IGp9EmO5YHFm3U6jQihchCRNsVFR3HLmAGU1xzjZy+X+R1HuoDKQUSCMiyzJ7eMGcATf32X0r1H/I4jF5jKQUSC9sANI0hJjGXesxs5o1NqdGsqBxEJWu+kOL51Yz7vVB7m16t3+h1HLiCVg4icl0+OymLi8Ax+/JftVOiLcd2WykFEzouZ8b0ZBUQZ3L94k7770E2pHETkvPVLTeTeqcP5244anntzt99x5AJQOYhIu3zmikFcNqg333txCzVHT/kdRzqZykFE2iUqynjo5pEcP3WG7/7XFr/jSCdTOYhIu+Vm9mTudUMpfmcPL2+t8juOdCKVg4h0yBevvYhhmck8sHgzR0/V+R1HOonKQUQ6JC4min+7eRR7j5zk35du9TuOdBKVg4h02GWDejPryhyeWvs+G94/6Hcc6QQqBxHpFF+bcjH9UhK599lNnKrTmVvDncpBRDpFcnwMD95UQNn+o/x81bt+x5EOUjmISKe57uIMpl/aj5+/Usb2qg/9jiMdoHIQkU71zWl5JMfHcK/O3BrWgioHMysys21mVmZm81q438zsEe/+jWY2pq25ZtbHzJab2Q7vsneTxxxoZkfN7GsdWUER6Vp9k+P55ifzeGvXIZ5e+77fcaSd2iwHM4sGHgOmAnnA7WaW12TYVCDX+5sNPB7E3HnASudcLrDSu93YT4GX2rFOIuKzGZf25+ph6fxo6VZ2Hzrhdxxph2A+OYwFypxz5c65WmAhML3JmOnAU67BWiDVzLLamDsdWOBdXwDM+OjBzGwGUA6UtGutRMRXZsb3ZxTggAd05tawFEw59AcqGt2u9JYFM6a1uZnOub0A3mUGgJklAfcC32ktlJnNNrP1Zra+uro6iNUQka6U3acHX/vExazaVk3xO3v8jiPnKZhysBaWNX0bcK4xwcxt6jvAT51zR1sb5Jyb75wrdM4Vpqent/GQIuKHWVflcEl2Kt/5ry0cPFbrdxw5D8GUQyWQ3ej2AKDp24BzjWltbpW36Qnvcr+3/ArgR2b2HvAV4H4zuzuInCISYqKjjB/eMpIjJ07z4Is6c2s4CaYc1gG5ZjbYzOKAmUBxkzHFwB3eUUvjgMPepqLW5hYDs7zrs4DnAZxzH3PO5TjncoCHgR84537W7jUUEV8ND/Tii9dexHNv7uav27UJOFy0WQ7OuTrgbmAZUAoscs6VmNkcM5vjDVtCww7kMuBJ4EutzfXmPARMNrMdwGTvtoh0Q3d/fCgXpSdx/3ObOKYzt4aFmGAGOeeW0FAAjZc90ei6A+YGO9dbfgCY2MbzfjuYfCIS2uJjonnollHc+sQafrJ8O9+Y1vRoeAk1+oa0iHSJy3P68JlxA/n16p28XXHI7zjSBpWDiHSZe4uGk9EzgXnPbqS2rt7vONIKlYOIdJmeCbF8b0YBW/d9yPxXdebWUKZyEJEuNTkvkxtGZfHIyjLK9rf6dSbxkcpBRLrctz+ZT2JcNPc/t4l6nbk1JKkcRKTLpfeM519vGMEb7x3k92/s8juOtEDlICK+uPWyAYwf2peHXtrKvsMn/Y4jTagcRMQXZsYPbhpJXX09D/x5s87cGmJUDiLim0F9k/h/k4exorSKlzbv8zuONKJyEBFf3Tl+MCP7p/DN50s4dFxnbg0VKgcR8VVMdBQP3TKSD47X8oMlpX7HEY/KQUR8l98vhdlXD2HR+kpWl9X4HUdQOYhIiPjyxFxy+vbg/sWbOFF7xu84EU/lICIhISE2mn+7eRTvHzjOwyu2+x0n4qkcRCRkXHlRX24fm82Tfytn8+7DfseJaCoHEQkp86aOoG9yPF9/ZiOnz+jMrX5ROYhISElJjOV70/PZsvcIv3ptp99xIpbKQURCTlFBFkX5AX66fDs7a475HSciqRxEJCR9Z3o+cTFR3PfcRp1awwcqBxEJSZm9Erj/+hGsLT/IovUVfseJOCoHEQlZtxVmc8XgPnz/xVL2H9GZW7uSykFEQlZUlPHQLaM4WVfPt4pL/I4TUVQOIhLSBqcl8ZVJuby0eR+b9N2HLqNyEJGQ94WPDWFEVi8Wv7WbE6d1ao2uoHIQkZAXGx3FD28ZydGTdby4ca/fcSKCykFEwsKoAalcMyyd18sP8vVn3uGkPkFcUCoHEQkbU0dmMSkvg0XrK7nl8b9TcfC435G6LZWDiISNKIMp+QF+eUchuw4eZ9qjr7Fq236/Y3VLKgcRCTuT8jJ54Z4J9EtN5M7frOPhFdupr9e3qDuTykFEwtKgvkk898WruGl0fx5esYM7F6zTb1B3IpWDiIStxLhofnzrJTw4o4DVZTVMe/Q1/Q5EJ1E5iEhYMzM+M24Qi+66kjP1jpsf/7vOxdQJVA4i0i2MHtibF+6ZQOGg3nz9mY3c99xGHe7aAUGVg5kVmdk2Myszs3kt3G9m9oh3/0YzG9PWXDPrY2bLzWyHd9nbWz7ZzDaY2Sbv8uOdsaIi0v31TY7nqTvH8sVrL+IPb1TwqV+sofIDHe7aHm2Wg5lFA48BU4E84HYzy2sybCqQ6/3NBh4PYu48YKVzLhdY6d0GqAE+6ZwbCcwCftvutRORiBMTHcW9RcOZ/9nL2Fl9jGmPvsar26v9jhV2gvnkMBYoc86VO+dqgYXA9CZjpgNPuQZrgVQzy2pj7nRggXd9ATADwDn3lnNuj7e8BEgws/j2rZ6IRKpP5AcovmcCmT0TmPXrN3h05Q4d7noegimH/kDjvTuV3rJgxrQ2N9M5txfAu8xo4blvAd5yzp1qeoeZzTaz9Wa2vrpa7wpEpLnBaUksnnsVN17Sjx8v384XnlrP4ROn/Y4VFoIpB2thWdP6PdeYYOa2/KRm+cAPgbtaut85N985V+icK0xPTw/mIUUkAvWIi+Hh2y7lOzfm89ft1dz4s9fYsueI37FCXjDlUAlkN7o9ANgT5JjW5lZ5m57wLv/xHXgzGwAsBu5wzr0bREYRkXMyM2ZdlcMf7xrHydNnuPnx1Ty7odLvWCEtmHJYB+Sa2WAziwNmAsVNxhQDd3hHLY0DDnubilqbW0zDDme8y+cBzCwVeBG4zzm3uv2rJiJytssG9eGFez7GpdmpfPVP7/DAnzdxqk6Hu7akzXJwztUBdwPLgFJgkXOuxMzmmNkcb9gSoBwoA54EvtTaXG/OQ8BkM9sBTPZu440fCnzDzN72/lraHyEict7Se8bz9Oev4K6rh/D02l3c9ou17Dl0wu9YIScmmEHOuSU0FEDjZU80uu6AucHO9ZYfACa2sPxB4MFgcomItEdMdBT3XT+CS7NT+ZdnNjLt0dd49PbRjB+a5ne0kKFvSItIxJo6Movn7x5P36Q4Pvur1/n5K2U0vNcVlYOIRLSL0pP589zxXD8yix8t3cZdv93AkZM63FXlICIRLyk+hkdvH803p+Xx8tb9TP/Zarbui+zDXVUOIiI0HO5654TB/GH2OI6equOmx/7O82/v9juWb1QOIiKNXJ7ThxfvmcDI/il8eeHbfLu4hNq6er9jdTmVg4hIExm9EvjdF67g8xMG85u/v8fM+WvYd/ik37G6lMpBRKQFsdFRfGNaHj/79Gi27vuQaY/+jTXvHvA7VpdROYiItGLaqH48P3c8KYmxfOZXrzP/1cg4o4/KQUSkDbmZPXn+7glMyc/kB0u28tSa97v9r8wF9Q1pEZFIlxwfw2OfHsMv/7aTNaXv8d6yY2RnfcCnCrMZ1DfJ73idTuUgIhIkM+MLVw9hdEY0L2/dz+OvvMtjq97lyiF9mTk2myn5ARJio/2O2SlUDiIi5yknLYk7Jwzms9cV8Mz6ShZtqODLC98mJTGWm0b351OF2eT16+V3zA5ROYiItFNWSiL3TMxl7nVDWVN+gIXrKvj967v4zd/fY9SAFG67PJsbL+lHz4RYv6OeN5WDiEgHRUUZ44emMX5oGh8cq+XPb+9m4RsV/OvizTz4QinXj8xi5thsCgf1xqylH8gMPSoHEZFO1Dspjs+NH8z/uiqHdyoP88d1uyh+ew/PvlnJkPQkZl6ezc1jBpCWHO931FapHERELgAz49LsVC7NTuWBG/J4cdNe/riugh8s2cqPlm5j0ohMbhubzdW56URHhd6nCZWDiMgFlhQfw6cKs/lUYTZl+z/kj+sqePbN3Swt2UdWSgK3XjaAWwuzye7Tw++o/6ByEBHpQkMzevKvN+TxL1OGs6K0ioXrKnh0VRmPripjwtA0brs8m8l5mcTH+HtIrMpBRMQHcTFRXD8yi+tHZrH70An+tL6CP62v5O7fv0XvHrHcNHoAt12ezcWBnr7kUzmIiPisf2oiX5k0jHs+nstrZTUsWlfBb9e+x3+u3snogancVpjNtEv6kRzfdf/LVjmIiISI6CjjmmHpXDMsnQNHT7H4rd0sXFfBvOc28d0XtvDJUf24bWw2o7NTL/ghsSoHEZEQ1Dc5nv/9sSF8fsJg3tx1qOGQ2Hf28Mf1FeRmJDfsmxiSRFL8hdk3oXIQEQlhZsZlg3pz2aDefGNaHi9sbDgk9sEXS1k7OJHxQ9P43MS0Tn9elYOISJjomRDL7WMHcvvYgWzdd4TVm3fSOynugjyXykFEJAwND/QiLab/BXt8/diPiIg0o3IQEZFmVA4iItKMykFERJpROYiISDMqBxERaUblICIizagcRESkGXPO+Z2hw8ysGnjf7xwdlAbU+B0ihOj1OJtej3/Sa3G2jrweg5xz6S3d0S3KoTsws/XOuUK/c4QKvR5n0+vxT3otznahXg9tVhIRkWZUDiIi0ozKIXTM9ztAiNHrcTa9Hv+k1+JsF+T10D4HERFpRp8cRESkGZWDiIg0o3LwmZllm9kqMys1sxIz+7LfmfxmZtFm9paZveB3Fr+ZWaqZPWNmW73/Rq70O5OfzOz/ev9ONpvZH8wswe9MXcnM/tPM9pvZ5kbL+pjZcjPb4V327oznUjn4rw74qnNuBDAOmGtmeT5n8tuXgVK/Q4SI/w8sdc4NBy4hgl8XM+sP/B+g0DlXAEQDM/1N1eV+AxQ1WTYPWOmcywVWerc7TOXgM+fcXufcm971D2n4x3/hfvsvxJnZAOAG4Jd+Z/GbmfUCrgZ+BeCcq3XOHfI1lP9igEQziwF6AHt8ztOlnHOvAgebLJ4OLPCuLwBmdMZzqRxCiJnlAKOB132O4qeHga8D9T7nCAVDgGrg195mtl+aWZLfofzinNsN/AewC9gLHHbO/cXfVCEh0zm3FxrebAIZnfGgKocQYWbJwLPAV5xzR/zO4wczmwbsd85t8DtLiIgBxgCPO+dGA8fopE0G4cjblj4dGAz0A5LM7DP+puq+VA4hwMxiaSiG3znnnvM7j4/GAzea2XvAQuDjZva0v5F8VQlUOuc++iT5DA1lEakmATudc9XOudPAc8BVPmcKBVVmlgXgXe7vjAdVOfjMzIyGbcqlzrmf+J3HT865+5xzA5xzOTTsaHzZORex7wydc/uACjO72Fs0EdjiYyS/7QLGmVkP79/NRCJ4B30jxcAs7/os4PnOeNCYzngQ6ZDxwGeBTWb2trfsfufcEv8iSQi5B/idmcUB5cDnfM7jG+fc62b2DPAmDUf5vUWEnUrDzP4AXAukmVkl8C3gIWCRmX2ehgK9tVOeS6fPEBGRprRZSUREmlE5iIhIMyoHERFpRuUgIiLNqBxERKQZlYOIiDSjchARkWb+G+V90zTQ9MMOAAAAAElFTkSuQmCC\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "x = torch.nn.Parameter(torch.randn(4, 4))\n", + "optimizer = optim.Adam([x], lr=1e-3)\n", + "scheduler = MultiStepLR(optimizer, milestones=[6, 7, 8, 9], gamma=0.5)\n", + "\n", + "def get_lr(optimizer):\n", + " for param_group in optimizer.param_groups:\n", + " return param_group['lr']\n", + " \n", + "lrs = []\n", + "for i in range(10):\n", + " lrs.append(get_lr(optimizer))\n", + " scheduler.step()\n", + "\n", + "plt.plot(range(1, 11), lrs)\n", + "for x in range(1, 11):\n", + " plt.axvline(x, c='0.8', alpha=0.5)" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([19, 11, 0, 15, 18, 10, 12, 14, 17, 6, 13, 1, 4, 9, 8, 5, 2,\n", + " 3, 7, 16])" + ] + }, + "execution_count": 27, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import numpy as np\n", + "np.random.permutation(20)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "dohoon", + "language": "python", + "name": "dohoon" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.13" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/spliceai_pytorch/__init__.py b/spliceai_pytorch/__init__.py index e69de29..f4ca5bc 100644 --- a/spliceai_pytorch/__init__.py +++ b/spliceai_pytorch/__init__.py @@ -0,0 +1 @@ +from spliceai_pytorch import SpliceAI_80nt, SpliceAI_400nt, SpliceAI_2k, SpliceAI_10k \ No newline at end of file diff --git a/spliceai_pytorch/data.py b/spliceai_pytorch/data.py new file mode 100644 index 0000000..df25eb7 --- /dev/null +++ b/spliceai_pytorch/data.py @@ -0,0 +1,16 @@ +from torch.utils.data import TensorDataset, DataLoader + +if __name__ == '__main__': + import torch + import h5py + + h5f = h5py.File('../spliceai_train_code/Canonical/dataset_train_all.h5') + idx = 1 + + X, Y = h5f[f'X{idx}'][:], h5f[f'Y{idx}'][0, ...] + ds = TensorDataset(torch.from_numpy(X), torch.from_numpy(Y)) + loader = DataLoader(ds, batch_size=32, shuffle=True, num_workers=8) + + for batch in loader: + print(batch[0].shape, batch[1].shape) + break \ No newline at end of file diff --git a/spliceai_pytorch/spliceai_pytorch.py b/spliceai_pytorch/spliceai_pytorch.py index 0acb4f9..d3624fa 100644 --- a/spliceai_pytorch/spliceai_pytorch.py +++ b/spliceai_pytorch/spliceai_pytorch.py @@ -1,6 +1,8 @@ import torch import torch.nn as nn +from einops import rearrange + class Residual(nn.Module): def __init__(self, fn): super().__init__() @@ -42,7 +44,7 @@ def forward(self, x): x = detour + self.block1(x) x = self.conv_last(x) - return x[..., 40:5000 + 40].softmax(dim=-1) + return rearrange(x[..., 40:5000 + 40], 'b c l -> b l c') class SpliceAI_400nt(nn.Module): S = 400 @@ -82,8 +84,7 @@ def forward(self, x): x = self.block2(x) + detour x = self.conv_last(x) - return x[..., 200:5000 + 200].softmax(dim=-1) - + return rearrange(x[..., 200:5000 + 200], 'b c l -> b l c') class SpliceAI_2k(nn.Module): S = 2000 @@ -135,7 +136,7 @@ def forward(self, x): x = self.block3(x) + detour x = self.conv_last(x) - return x[..., 1000:5000 + 1000].softmax(dim=-1) + return rearrange(x[..., 1000:5000 + 1000], 'b c l -> b l c') class SpliceAI_10k(nn.Module): S = 10000 @@ -198,7 +199,7 @@ def forward(self, x): x = self.block4(x) + detour x = self.conv_last(x) - return x[..., 5000:5000 + 5000].softmax(dim=-1) + return rearrange(x[..., 5000:5000 + 5000], 'b c l -> b l c') class SpliceAI(): diff --git a/spliceai_pytorch/train.py b/spliceai_pytorch/train.py new file mode 100644 index 0000000..56dfb05 --- /dev/null +++ b/spliceai_pytorch/train.py @@ -0,0 +1,136 @@ +import argparse +import tqdm +import random + +import torch +import torch.nn as nn +import torch.optim as optim +import torch.nn.functional as F + +from torch.utils.data import TensorDataset, DataLoader +from spliceai_pytorch import SpliceAI_80nt +import numpy as np + +def shuffle(arr): + return np.random.choice(arr, size=len(arr), replace=False) + +def train(model, h5f, train_shard_idxs, batch_size, optimizer, criterion): + model.train() + running_output, running_label = [], [] + + for i, shard_idx in enumerate(shuffle(train_shard_idxs), 1): + X = h5f[f'X{shard_idx}'][:].transpose(0, 2, 1) + Y = h5f[f'Y{shard_idx}'][0, ...] + + ds = TensorDataset(torch.from_numpy(X).float(), torch.from_numpy(Y).float()) + loader = DataLoader(ds, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True) + + bar = tqdm.tqdm(loader, leave=False, total=len(loader), desc=f'Shard {i}/{len(train_shard_idxs)}') + for idx, batch in enumerate(bar): + X, Y = batch[0].cuda(), batch[1].cuda() + optimizer.zero_grad() + out = model(X) # (batch_size, 5000, 3) + loss = criterion(out, Y) + loss.backward() + optimizer.step() + + running_output.append(out.detach().cpu()) + running_label.append(Y.detach().cpu()) + + if idx % 100 == 0: + running_output = torch.cat(running_output, dim=0) + running_label = torch.cat(running_label, dim=0) + + loss = criterion(running_output, running_label) + bar.set_postfix(loss=f'{loss.item():.4f}') + + running_output, running_label = [], [] + + +def validate(model, h5f, val_shard_idxs, batch_size, criterion): + model.eval() + + out, label = [], [] + for shard_idx in val_shard_idxs: + X = h5f[f'X{shard_idx}'][:].transpose(0, 2, 1) + Y = h5f[f'Y{shard_idx}'][0, ...] + + ds = TensorDataset(torch.from_numpy(X).float(), torch.from_numpy(Y).float()) + loader = DataLoader(ds, batch_size=batch_size, shuffle=False, num_workers=8, pin_memory=True) + + bar = tqdm.tqdm(loader, leave=False, total=len(loader)) + for idx, batch in enumerate(bar): + X, Y = batch[0].cuda(), batch[1].cuda() + _out = model(X).detach().cpu() + _label = Y.detach().cpu() + + out.append(_out) + label.append(_label) + + loss = criterion(torch.cat(out, dim=0), torch.cat(label, dim=0)) + return loss.item() + +def test(model, test_loader, device): + model.eval() + correct = 0 + with torch.no_grad(): + for data, target in test_loader: + data, target = data.to(device), target.to(device) + output = model(data) + pred = output.argmax(dim=1, keepdim=True) + correct += pred.eq(target.view_as(pred)).sum().item() + test_accuracy = 100. * correct / len(test_loader.dataset) + return test_accuracy + +def seed_everything(seed): + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + np.random.seed(seed) + random.seed(seed) + + # Performance drops, so commenting out for now. + # torch.backends.cudnn.benchmark = False + # torch.backends.cudnn.deterministic = True + +def main(): + import pandas as pd + import h5py + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument('--train-h5', required=True) + parser.add_argument('--test-h5', required=True) + parser.add_argument('--epochs', type=int, default=10) + parser.add_argument('--batch-size', '-b', type=int, default=6) + parser.add_argument('--learning-rate', '-lr', type=float, default=1e-3) + parser.add_argument('--seed', type=int, default=42) + args = parser.parse_args() + + seed_everything(args.seed) + + train_h5f = h5py.File(args.train_h5, 'r') + test_h5f = h5py.File(args.test_h5, 'r') + + num_shards = len(train_h5f.keys()) // 2 + shard_idxs = np.random.permutation(num_shards) + train_shard_idxs = shard_idxs[:int(0.9 * num_shards)] + val_shard_idxs = shard_idxs[int(0.9 * num_shards):] + + model = SpliceAI_80nt() + model.cuda() + + criterion = torch.nn.BCEWithLogitsLoss() + optimizer = optim.Adam(model.parameters(), lr=1e-3) + scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[6, 7, 8, 9], gamma=0.5) + + for epoch in range(args.epochs): + train(model, train_h5f, train_shard_idxs, args.batch_size, optimizer, criterion) + validate(model, train_h5f, val_shard_idxs, args.batch_size, criterion) + + + scheduler.step() + + +if __name__ == '__main__': + main() \ No newline at end of file