Skip to content

Commit

Permalink
Train now works.
Browse files Browse the repository at this point in the history
  • Loading branch information
dohlee committed Feb 16, 2023
1 parent 734ce45 commit 45fcdf4
Show file tree
Hide file tree
Showing 7 changed files with 297 additions and 16 deletions.
6 changes: 5 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
spliceai_train_code/

note/.ipynb_checkpoints
data/
note/.ipynb_checkpoints/

*.egg-info/

**/__pycache__/
47 changes: 37 additions & 10 deletions note/investigating_dataset_h5.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,13 @@
"outputs": [],
"source": [
"import h5py\n",
"import numpy as np"
"import numpy as np\n",
"import torch"
]
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -48,42 +49,68 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<HDF5 dataset \"X0\": shape (5662, 5080, 4), type \"|i1\">"
"torch.Size([5662, 5080, 4])"
]
},
"execution_count": 15,
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"hf['X0']"
"torch.from_numpy(hf['X0'][:]).shape"
]
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 18,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<HDF5 dataset \"Y0\": shape (1, 5662, 5000, 3), type \"|i1\">"
"tensor([[0., 0., 0., 0.],\n",
" [0., 0., 0., 0.],\n",
" [0., 0., 0., 0.],\n",
" ...,\n",
" [0., 1., 0., 0.],\n",
" [0., 0., 0., 1.],\n",
" [1., 0., 0., 0.]])"
]
},
"execution_count": 16,
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"hf['Y0']"
"torch.from_numpy(hf['X0'][:])[0].float()"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([1, 5662, 5000, 3])"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"torch.from_numpy(hf['Y0'][:]).shape"
]
},
{
Expand Down
96 changes: 96 additions & 0 deletions note/multistep_lr_test.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.optim as optim\n",
"from torch.optim.lr_scheduler import MultiStepLR\n",
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYcAAAD4CAYAAAAHHSreAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAeeElEQVR4nO3de3hU9b3v8fc395BAAuQyAQIBCUISUDAiCvVSoASlgnqs2NPKqT1FWvS057S7otvebbft3m09WqvFdrdY21KqUrMVoYBYKwUFvAAhXGJQEi4hAQG5hpDf/iPLNiQhGZKQNZP5vJ4nz8ys+f1mPmse8TOz1po15pxDRESksSi/A4iISOhROYiISDMqBxERaUblICIizagcRESkmRi/A3SGtLQ0l5OT0665dXV1AMTE+PdShEIG5VCOUM+gHJ2fY8OGDTXOufSW7usW5ZCTk8P69evbNbempgaAtLS0zowUdhmUQzlCPYNydH4OM3v/XPdps5KIiDSjchARkWZUDiIi0ozKQUREmlE5iIhIM0GVg5kVmdk2Myszs3kt3G9m9oh3/0YzG9PWXDO71cxKzKzezAqbPN593vhtZjalIysoIiLnr81yMLNo4DFgKpAH3G5meU2GTQVyvb/ZwONBzN0M3Ay82uT58oCZQD5QBPzcexwREekiwXzPYSxQ5pwrBzCzhcB0YEujMdOBp1zD+b/XmlmqmWUBOeea65wr9ZY1fb7pwELn3Clgp5mVeRnWtG8Vz23f4ZOsLNkHQNWpA5398EHLjD/tewblCN0c/RLruHJIGv4eUS+RJphy6A9UNLpdCVwRxJj+Qc5t6fnWtvBYZzGz2TR8SmHgwIFtPGTLqo6cZEXpfgBeef9Eux6jM1w7KNH3DMoRwjkGJnKmHgb1z/Q1h0SWYMqh2Vt7oOkvBJ1rTDBz2/N8OOfmA/MBCgsL2/WLRZdkp/Lv/2MUoG+fKkfo5vjuM6+zefdhXzNI5Almh3QlkN3o9gBgT5BjgpnbnucTiRgj+6ew7/BJdtYc8zuKRJBgymEdkGtmg80sjoadxcVNxhQDd3hHLY0DDjvn9gY5t6liYKaZxZvZYBp2cr9xHusk0q0U9E8BYOnmfT4nkUjSZjk45+qAu4FlQCmwyDlXYmZzzGyON2wJUA6UAU8CX2ptLoCZ3WRmlcCVwItmtsybUwIsomGH91JgrnPuTCetr0jY6d0jluw+iSwtUTlI1wnqrKzOuSU0FEDjZU80uu6AucHO9ZYvBhafY873ge8Hk00kEhT0T+Gpt3ay59AJ+qUm+h1HIoC+IS0SBkZ6m5b+ok8P0kVUDiJhIL1nPMMyk7VpSbqMykEkTBTlB3hj50EOHD3ldxSJACoHkTAxpSBAvYMVpVV+R5EIoHIQCRN5Wb3I7pPISzqkVbqAykEkTJgZRfkBVpfVcOTkab/jSDenchAJI0UFAU6fcazaut/vKNLNqRxEwsjo7N5k9IzXt6XlglM5iISRqChjSn6AV7ZVc6JWJw6QC0flIBJmigoCnDh9hld3VPsdRboxlYNImBk7uA+pPWJZpk1LcgGpHETCTGx0FJNGZLKitIraunq/40g3pXIQCUNF+QGOnKxjTbm/P2Eq3ZfKQSQMTchNo0dctI5akgtG5SAShhJio7lueAbLt+zjTH27fiVXpFUqB5EwNbUgQM3RWja8/4HfUaQbUjmIhKlrL84gLiZKm5bkglA5iISp5PgYrs5NY1nJPhp+jFGk86gcRMLYlPwAuw+dYPPuI35HkW5G5SASxiaNyCQ6ylhastfvKNLNqBxEwljvpDjGDemj/Q7S6VQOImGuKD/Au9XH2FH1od9RpBtROYiEuU/kBwD06UE6lcpBJMxl9krgskG9WVqicpDOo3IQ6QaK8gOU7DlCxcHjfkeRbkLlININTPE2LS3TpwfpJCoHkW5gYN8e5GX10n4H6TQqB5FuoqggwIZdH7D/yEm/o0g3oHIQ6SaKCgI4B3/ZUuV3FOkGVA4i3URuRjJD0pK0aUk6hcpBpJswM6YUBFhTfoBDx2v9jiNhTuUg0o0U5Qc4U+9YUbrf7ygS5lQOIt3IqAEp9EtJ0KYl6bCgysHMisxsm5mVmdm8Fu43M3vEu3+jmY1pa66Z9TGz5Wa2w7vs7S2PNbMFZrbJzErN7L7OWFGRSPDRpqVXd1Rz7FSd33EkjLVZDmYWDTwGTAXygNvNLK/JsKlArvc3G3g8iLnzgJXOuVxgpXcb4FYg3jk3ErgMuMvMctq7giKRpig/QG1dPa9sq/Y7ioSxYD45jAXKnHPlzrlaYCEwvcmY6cBTrsFaINXMstqYOx1Y4F1fAMzwrjsgycxigESgFtAvmYgEqTCnD32T4nSuJemQYMqhP1DR6HaltyyYMa3NzXTO7QXwLjO85c8Ax4C9wC7gP5xzB5uGMrPZZrbezNZXV+sdkshHoqOMT+Rn8nJpFSdPn/E7joSpYMrBWljW9AdrzzUmmLlNjQXOAP2AwcBXzWxIswdxbr5zrtA5V5ient7GQ4pElin5AY7VnmF1WY3fUSRMBVMOlUB2o9sDgD1BjmltbpW36Qnv8qNj7z4NLHXOnXbO7QdWA4VB5BQRz1UXpdEzPkZHLUm7BVMO64BcMxtsZnHATKC4yZhi4A7vqKVxwGFvU1Frc4uBWd71WcDz3vVdwMe9x0oCxgFb27l+IhEpLiaKiSMyWF5aRd2Zer/jSBhqsxycc3XA3cAyoBRY5JwrMbM5ZjbHG7YEKAfKgCeBL7U215vzEDDZzHYAk73b0HB0UzKwmYZy+bVzbmNHV1Qk0hQVZHHo+Gne2Nlsl51Im2KCGeScW0JDATRe9kSj6w6YG+xcb/kBYGILy4/ScDiriHTANcPSSYiNYmnJPq4amuZ3HAkz+oa0SDeVGBfNtcMyWFayj/r6to4DETmbykGkGysqCFB15BRvVx7yO4qEGZWDSDd23fAMYqONZTpqSc6TykGkG0tJjOWqi9J4afM+GnYNigRH5SDSzRUVBNh18Dilez/0O4qEEZWDSDc3OS+TKEPnWpLzonIQ6ebSkuO5PKeP9jvIeVE5iESAooIA26o+pLz6qN9RJEyoHEQiwJT8AADLSqp8TiLhQuUgEgH6pSZyyYAU7XeQoKkcRCLElIIA71QcYs+hE35HkTCgchCJEEXepqW/6NODBEHlIBIhhqQnMywzmZd01JIEQeUgEkGK8gOse+8gNUdP+R1FQpzKQSSCFBVkUe9gxRYdtSStUzmIRJARWT0Z2KeHjlqSNqkcRCKImVFUEGB1WQ1HTp72O46EMJWDSISZkh/g9BnHqq37/Y4iIUzlIBJhRmenktEznqU6aklaoXIQiTBRUcaU/ACvbKvmRO0Zv+NIiFI5iESgooIAJ06f4a/bq/2OIiFK5SASgcYO7kNqj1iW6aglOQeVg0gEio2OYtKITFaUVlFbV+93HAlBKgeRCDW1IMCHJ+tYU37A7ygSglQOIhFq/NA0kuKiddSStEjlIBKhEmKjuW54Bsu37ONMvfM7joQYlYNIBCsqCFBztJYN73/gdxQJMSoHkQh27cUZxMVEadOSNKNyEIlgyfExXJ2bxrKSfTinTUvyTyoHkQg3JT/A7kMn2LT7sN9RJISoHEQi3KQRmURHmTYtyVlUDiIRrndSHFcO6cvSzdq0JP+kchARphQEKK85Rtn+o35HkRARVDmYWZGZbTOzMjOb18L9ZmaPePdvNLMxbc01sz5mttzMdniXvRvdN8rM1phZiZltMrOEjq6oiJzblLxMzNCmJfmHNsvBzKKBx4CpQB5wu5nlNRk2Fcj1/mYDjwcxdx6w0jmXC6z0bmNmMcDTwBznXD5wLaCfrBK5gDJ6JTBmYG/9fKj8QzCfHMYCZc65cudcLbAQmN5kzHTgKddgLZBqZlltzJ0OLPCuLwBmeNc/AWx0zr0D4Jw74JzTSedFLrCi/AAle45QcfC431EkBARTDv2Bika3K71lwYxpbW6mc24vgHeZ4S0fBjgzW2Zmb5rZ14NZERHpmCn5AUCblqRBMOVgLSxrekjDucYEM7epGGAC8D+9y5vMbGKzUGazzWy9ma2vrtYPloh01MC+PcjL6qVNSwIEVw6VQHaj2wOAPUGOaW1ulbfpCe/yo187rwT+6pyrcc4dB5YAY2jCOTffOVfonCtMT08PYjVEpC1FBQE2vP8B+4+c9DuK+CyYclgH5JrZYDOLA2YCxU3GFAN3eEctjQMOe5uKWptbDMzyrs8CnveuLwNGmVkPb+f0NcCWdq6fiJyHqQUNm5aWbanyOYn4rc1ycM7VAXfT8D/tUmCRc67EzOaY2Rxv2BKgHCgDngS+1Npcb85DwGQz2wFM9m7jnPsA+AkNxfI28KZz7sWOr6qItGVoRjJD0pNYpv0OES8mmEHOuSU0FEDjZU80uu6AucHO9ZYfAJrtS/Due5qGw1lFpAuZGUX5AX7xajmHjteS2iPO70jiE31DWkTOUlQQ4Ey9Y0Xp/rYHS7elchCRs4zsn0K/lAQd0hrhVA4ichYzY0pBgFd3VHPsVJ3fccQnKgcRaaYoP0BtXT2rtmnTUqRSOYhIM4U5fUhLjtOmpQimchCRZqKjjMl5AVZt3c/J0zq1WSRSOYhIi4oKAhyrPcPqshq/o4gPVA4i0qIrh/SlZ0KMNi1FKJWDiLQoLiaKSSMyWV5aRX29fj400qgcROScpuQHOHT8NOU1x/yOIl1M5SAi53TNsHQSYqPYtPuw31Gki6kcROScEuOiuXZYBpt2H8Zpy1JEUTmISKuKCgJ8eKKOXfr50IiichCRVl03PIPoKLRpKcKoHESkVSmJsQzN7OltWtK2pUihchCRNo3sn8LBo7WsKT/gdxTpIioHEWnT6IGp9EmO5YHFm3U6jQihchCRNsVFR3HLmAGU1xzjZy+X+R1HuoDKQUSCMiyzJ7eMGcATf32X0r1H/I4jF5jKQUSC9sANI0hJjGXesxs5o1NqdGsqBxEJWu+kOL51Yz7vVB7m16t3+h1HLiCVg4icl0+OymLi8Ax+/JftVOiLcd2WykFEzouZ8b0ZBUQZ3L94k7770E2pHETkvPVLTeTeqcP5244anntzt99x5AJQOYhIu3zmikFcNqg333txCzVHT/kdRzqZykFE2iUqynjo5pEcP3WG7/7XFr/jSCdTOYhIu+Vm9mTudUMpfmcPL2+t8juOdCKVg4h0yBevvYhhmck8sHgzR0/V+R1HOonKQUQ6JC4min+7eRR7j5zk35du9TuOdBKVg4h02GWDejPryhyeWvs+G94/6Hcc6QQqBxHpFF+bcjH9UhK599lNnKrTmVvDncpBRDpFcnwMD95UQNn+o/x81bt+x5EOUjmISKe57uIMpl/aj5+/Usb2qg/9jiMdoHIQkU71zWl5JMfHcK/O3BrWgioHMysys21mVmZm81q438zsEe/+jWY2pq25ZtbHzJab2Q7vsneTxxxoZkfN7GsdWUER6Vp9k+P55ifzeGvXIZ5e+77fcaSd2iwHM4sGHgOmAnnA7WaW12TYVCDX+5sNPB7E3HnASudcLrDSu93YT4GX2rFOIuKzGZf25+ph6fxo6VZ2Hzrhdxxph2A+OYwFypxz5c65WmAhML3JmOnAU67BWiDVzLLamDsdWOBdXwDM+OjBzGwGUA6UtGutRMRXZsb3ZxTggAd05tawFEw59AcqGt2u9JYFM6a1uZnOub0A3mUGgJklAfcC32ktlJnNNrP1Zra+uro6iNUQka6U3acHX/vExazaVk3xO3v8jiPnKZhysBaWNX0bcK4xwcxt6jvAT51zR1sb5Jyb75wrdM4Vpqent/GQIuKHWVflcEl2Kt/5ry0cPFbrdxw5D8GUQyWQ3ej2AKDp24BzjWltbpW36Qnvcr+3/ArgR2b2HvAV4H4zuzuInCISYqKjjB/eMpIjJ07z4Is6c2s4CaYc1gG5ZjbYzOKAmUBxkzHFwB3eUUvjgMPepqLW5hYDs7zrs4DnAZxzH3PO5TjncoCHgR84537W7jUUEV8ND/Tii9dexHNv7uav27UJOFy0WQ7OuTrgbmAZUAoscs6VmNkcM5vjDVtCww7kMuBJ4EutzfXmPARMNrMdwGTvtoh0Q3d/fCgXpSdx/3ObOKYzt4aFmGAGOeeW0FAAjZc90ei6A+YGO9dbfgCY2MbzfjuYfCIS2uJjonnollHc+sQafrJ8O9+Y1vRoeAk1+oa0iHSJy3P68JlxA/n16p28XXHI7zjSBpWDiHSZe4uGk9EzgXnPbqS2rt7vONIKlYOIdJmeCbF8b0YBW/d9yPxXdebWUKZyEJEuNTkvkxtGZfHIyjLK9rf6dSbxkcpBRLrctz+ZT2JcNPc/t4l6nbk1JKkcRKTLpfeM519vGMEb7x3k92/s8juOtEDlICK+uPWyAYwf2peHXtrKvsMn/Y4jTagcRMQXZsYPbhpJXX09D/x5s87cGmJUDiLim0F9k/h/k4exorSKlzbv8zuONKJyEBFf3Tl+MCP7p/DN50s4dFxnbg0VKgcR8VVMdBQP3TKSD47X8oMlpX7HEY/KQUR8l98vhdlXD2HR+kpWl9X4HUdQOYhIiPjyxFxy+vbg/sWbOFF7xu84EU/lICIhISE2mn+7eRTvHzjOwyu2+x0n4qkcRCRkXHlRX24fm82Tfytn8+7DfseJaCoHEQkp86aOoG9yPF9/ZiOnz+jMrX5ROYhISElJjOV70/PZsvcIv3ptp99xIpbKQURCTlFBFkX5AX66fDs7a475HSciqRxEJCR9Z3o+cTFR3PfcRp1awwcqBxEJSZm9Erj/+hGsLT/IovUVfseJOCoHEQlZtxVmc8XgPnz/xVL2H9GZW7uSykFEQlZUlPHQLaM4WVfPt4pL/I4TUVQOIhLSBqcl8ZVJuby0eR+b9N2HLqNyEJGQ94WPDWFEVi8Wv7WbE6d1ao2uoHIQkZAXGx3FD28ZydGTdby4ca/fcSKCykFEwsKoAalcMyyd18sP8vVn3uGkPkFcUCoHEQkbU0dmMSkvg0XrK7nl8b9TcfC435G6LZWDiISNKIMp+QF+eUchuw4eZ9qjr7Fq236/Y3VLKgcRCTuT8jJ54Z4J9EtN5M7frOPhFdupr9e3qDuTykFEwtKgvkk898WruGl0fx5esYM7F6zTb1B3IpWDiIStxLhofnzrJTw4o4DVZTVMe/Q1/Q5EJ1E5iEhYMzM+M24Qi+66kjP1jpsf/7vOxdQJVA4i0i2MHtibF+6ZQOGg3nz9mY3c99xGHe7aAUGVg5kVmdk2Myszs3kt3G9m9oh3/0YzG9PWXDPrY2bLzWyHd9nbWz7ZzDaY2Sbv8uOdsaIi0v31TY7nqTvH8sVrL+IPb1TwqV+sofIDHe7aHm2Wg5lFA48BU4E84HYzy2sybCqQ6/3NBh4PYu48YKVzLhdY6d0GqAE+6ZwbCcwCftvutRORiBMTHcW9RcOZ/9nL2Fl9jGmPvsar26v9jhV2gvnkMBYoc86VO+dqgYXA9CZjpgNPuQZrgVQzy2pj7nRggXd9ATADwDn3lnNuj7e8BEgws/j2rZ6IRKpP5AcovmcCmT0TmPXrN3h05Q4d7noegimH/kDjvTuV3rJgxrQ2N9M5txfAu8xo4blvAd5yzp1qeoeZzTaz9Wa2vrpa7wpEpLnBaUksnnsVN17Sjx8v384XnlrP4ROn/Y4VFoIpB2thWdP6PdeYYOa2/KRm+cAPgbtaut85N985V+icK0xPTw/mIUUkAvWIi+Hh2y7lOzfm89ft1dz4s9fYsueI37FCXjDlUAlkN7o9ANgT5JjW5lZ5m57wLv/xHXgzGwAsBu5wzr0bREYRkXMyM2ZdlcMf7xrHydNnuPnx1Ty7odLvWCEtmHJYB+Sa2WAziwNmAsVNxhQDd3hHLY0DDnubilqbW0zDDme8y+cBzCwVeBG4zzm3uv2rJiJytssG9eGFez7GpdmpfPVP7/DAnzdxqk6Hu7akzXJwztUBdwPLgFJgkXOuxMzmmNkcb9gSoBwoA54EvtTaXG/OQ8BkM9sBTPZu440fCnzDzN72/lraHyEict7Se8bz9Oev4K6rh/D02l3c9ou17Dl0wu9YIScmmEHOuSU0FEDjZU80uu6AucHO9ZYfACa2sPxB4MFgcomItEdMdBT3XT+CS7NT+ZdnNjLt0dd49PbRjB+a5ne0kKFvSItIxJo6Movn7x5P36Q4Pvur1/n5K2U0vNcVlYOIRLSL0pP589zxXD8yix8t3cZdv93AkZM63FXlICIRLyk+hkdvH803p+Xx8tb9TP/Zarbui+zDXVUOIiI0HO5654TB/GH2OI6equOmx/7O82/v9juWb1QOIiKNXJ7ThxfvmcDI/il8eeHbfLu4hNq6er9jdTmVg4hIExm9EvjdF67g8xMG85u/v8fM+WvYd/ik37G6lMpBRKQFsdFRfGNaHj/79Gi27vuQaY/+jTXvHvA7VpdROYiItGLaqH48P3c8KYmxfOZXrzP/1cg4o4/KQUSkDbmZPXn+7glMyc/kB0u28tSa97v9r8wF9Q1pEZFIlxwfw2OfHsMv/7aTNaXv8d6yY2RnfcCnCrMZ1DfJ73idTuUgIhIkM+MLVw9hdEY0L2/dz+OvvMtjq97lyiF9mTk2myn5ARJio/2O2SlUDiIi5yknLYk7Jwzms9cV8Mz6ShZtqODLC98mJTGWm0b351OF2eT16+V3zA5ROYiItFNWSiL3TMxl7nVDWVN+gIXrKvj967v4zd/fY9SAFG67PJsbL+lHz4RYv6OeN5WDiEgHRUUZ44emMX5oGh8cq+XPb+9m4RsV/OvizTz4QinXj8xi5thsCgf1xqylH8gMPSoHEZFO1Dspjs+NH8z/uiqHdyoP88d1uyh+ew/PvlnJkPQkZl6ezc1jBpCWHO931FapHERELgAz49LsVC7NTuWBG/J4cdNe/riugh8s2cqPlm5j0ohMbhubzdW56URHhd6nCZWDiMgFlhQfw6cKs/lUYTZl+z/kj+sqePbN3Swt2UdWSgK3XjaAWwuzye7Tw++o/6ByEBHpQkMzevKvN+TxL1OGs6K0ioXrKnh0VRmPripjwtA0brs8m8l5mcTH+HtIrMpBRMQHcTFRXD8yi+tHZrH70An+tL6CP62v5O7fv0XvHrHcNHoAt12ezcWBnr7kUzmIiPisf2oiX5k0jHs+nstrZTUsWlfBb9e+x3+u3snogancVpjNtEv6kRzfdf/LVjmIiISI6CjjmmHpXDMsnQNHT7H4rd0sXFfBvOc28d0XtvDJUf24bWw2o7NTL/ghsSoHEZEQ1Dc5nv/9sSF8fsJg3tx1qOGQ2Hf28Mf1FeRmJDfsmxiSRFL8hdk3oXIQEQlhZsZlg3pz2aDefGNaHi9sbDgk9sEXS1k7OJHxQ9P43MS0Tn9elYOISJjomRDL7WMHcvvYgWzdd4TVm3fSOynugjyXykFEJAwND/QiLab/BXt8/diPiIg0o3IQEZFmVA4iItKMykFERJpROYiISDMqBxERaUblICIizagcRESkGXPO+Z2hw8ysGnjf7xwdlAbU+B0ihOj1OJtej3/Sa3G2jrweg5xz6S3d0S3KoTsws/XOuUK/c4QKvR5n0+vxT3otznahXg9tVhIRkWZUDiIi0ozKIXTM9ztAiNHrcTa9Hv+k1+JsF+T10D4HERFpRp8cRESkGZWDiIg0o3LwmZllm9kqMys1sxIz+7LfmfxmZtFm9paZveB3Fr+ZWaqZPWNmW73/Rq70O5OfzOz/ev9ONpvZH8wswe9MXcnM/tPM9pvZ5kbL+pjZcjPb4V327oznUjn4rw74qnNuBDAOmGtmeT5n8tuXgVK/Q4SI/w8sdc4NBy4hgl8XM+sP/B+g0DlXAEQDM/1N1eV+AxQ1WTYPWOmcywVWerc7TOXgM+fcXufcm971D2n4x3/hfvsvxJnZAOAG4Jd+Z/GbmfUCrgZ+BeCcq3XOHfI1lP9igEQziwF6AHt8ztOlnHOvAgebLJ4OLPCuLwBmdMZzqRxCiJnlAKOB132O4qeHga8D9T7nCAVDgGrg195mtl+aWZLfofzinNsN/AewC9gLHHbO/cXfVCEh0zm3FxrebAIZnfGgKocQYWbJwLPAV5xzR/zO4wczmwbsd85t8DtLiIgBxgCPO+dGA8fopE0G4cjblj4dGAz0A5LM7DP+puq+VA4hwMxiaSiG3znnnvM7j4/GAzea2XvAQuDjZva0v5F8VQlUOuc++iT5DA1lEakmATudc9XOudPAc8BVPmcKBVVmlgXgXe7vjAdVOfjMzIyGbcqlzrmf+J3HT865+5xzA5xzOTTsaHzZORex7wydc/uACjO72Fs0EdjiYyS/7QLGmVkP79/NRCJ4B30jxcAs7/os4PnOeNCYzngQ6ZDxwGeBTWb2trfsfufcEv8iSQi5B/idmcUB5cDnfM7jG+fc62b2DPAmDUf5vUWEnUrDzP4AXAukmVkl8C3gIWCRmX2ehgK9tVOeS6fPEBGRprRZSUREmlE5iIhIMyoHERFpRuUgIiLNqBxERKQZlYOIiDSjchARkWb+G+V90zTQ9MMOAAAAAElFTkSuQmCC\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"x = torch.nn.Parameter(torch.randn(4, 4))\n",
"optimizer = optim.Adam([x], lr=1e-3)\n",
"scheduler = MultiStepLR(optimizer, milestones=[6, 7, 8, 9], gamma=0.5)\n",
"\n",
"def get_lr(optimizer):\n",
" for param_group in optimizer.param_groups:\n",
" return param_group['lr']\n",
" \n",
"lrs = []\n",
"for i in range(10):\n",
" lrs.append(get_lr(optimizer))\n",
" scheduler.step()\n",
"\n",
"plt.plot(range(1, 11), lrs)\n",
"for x in range(1, 11):\n",
" plt.axvline(x, c='0.8', alpha=0.5)"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([19, 11, 0, 15, 18, 10, 12, 14, 17, 6, 13, 1, 4, 9, 8, 5, 2,\n",
" 3, 7, 16])"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import numpy as np\n",
"np.random.permutation(20)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "dohoon",
"language": "python",
"name": "dohoon"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
1 change: 1 addition & 0 deletions spliceai_pytorch/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from spliceai_pytorch import SpliceAI_80nt, SpliceAI_400nt, SpliceAI_2k, SpliceAI_10k
16 changes: 16 additions & 0 deletions spliceai_pytorch/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from torch.utils.data import TensorDataset, DataLoader

if __name__ == '__main__':
import torch
import h5py

h5f = h5py.File('../spliceai_train_code/Canonical/dataset_train_all.h5')
idx = 1

X, Y = h5f[f'X{idx}'][:], h5f[f'Y{idx}'][0, ...]
ds = TensorDataset(torch.from_numpy(X), torch.from_numpy(Y))
loader = DataLoader(ds, batch_size=32, shuffle=True, num_workers=8)

for batch in loader:
print(batch[0].shape, batch[1].shape)
break
11 changes: 6 additions & 5 deletions spliceai_pytorch/spliceai_pytorch.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import torch
import torch.nn as nn

from einops import rearrange

class Residual(nn.Module):
def __init__(self, fn):
super().__init__()
Expand Down Expand Up @@ -42,7 +44,7 @@ def forward(self, x):
x = detour + self.block1(x)
x = self.conv_last(x)

return x[..., 40:5000 + 40].softmax(dim=-1)
return rearrange(x[..., 40:5000 + 40], 'b c l -> b l c')

class SpliceAI_400nt(nn.Module):
S = 400
Expand Down Expand Up @@ -82,8 +84,7 @@ def forward(self, x):
x = self.block2(x) + detour
x = self.conv_last(x)

return x[..., 200:5000 + 200].softmax(dim=-1)

return rearrange(x[..., 200:5000 + 200], 'b c l -> b l c')

class SpliceAI_2k(nn.Module):
S = 2000
Expand Down Expand Up @@ -135,7 +136,7 @@ def forward(self, x):
x = self.block3(x) + detour
x = self.conv_last(x)

return x[..., 1000:5000 + 1000].softmax(dim=-1)
return rearrange(x[..., 1000:5000 + 1000], 'b c l -> b l c')

class SpliceAI_10k(nn.Module):
S = 10000
Expand Down Expand Up @@ -198,7 +199,7 @@ def forward(self, x):
x = self.block4(x) + detour
x = self.conv_last(x)

return x[..., 5000:5000 + 5000].softmax(dim=-1)
return rearrange(x[..., 5000:5000 + 5000], 'b c l -> b l c')

class SpliceAI():

Expand Down
Loading

0 comments on commit 45fcdf4

Please sign in to comment.