trainer.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from dataloader import get_data_for_given_ticker
from tqdm import tqdm_notebook
import pandas as po
def train(ticker, window_size, train_from, train_until, model, loss_function, optimizer,
          num_epochs, input_dim, num_output_classes, hidden_dim, dropout_prob):
    train_df, targets, dates = get_data_for_given_ticker(ticker, input_dim,
                                                         start_date=train_from,
                                                         end_date=train_until,
                                                         train=True)

    # Look up the pre-computed embedding for this ticker, then free the full table.
    ticker_embeddings_df = po.read_csv('more_data/ticker_embeddings.csv')
    embedding = ticker_embeddings_df[ticker_embeddings_df['#RIC'] == ticker]
    del ticker_embeddings_df
    embedding = embedding.drop('#RIC', axis=1)
    embedding *= 10000

    # Initialise the recurrent state randomly, then overwrite the leading entries
    # of the hidden state with the (scaled) ticker embedding.
    hidden_state = torch.randn(1, 1, hidden_dim)
    embedding_tensor = torch.tensor(embedding.to_numpy()).reshape(1, 1, len(embedding.columns))
    hidden_state[:, :, :len(embedding.columns)] = embedding_tensor
    cell_state = torch.randn(1, 1, hidden_dim)

    for epoch in range(int(num_epochs)):
        for i in tqdm_notebook(range(window_size, len(train_df))):
            model.zero_grad()
            # print(train_df[i-window_size:i][0])  # print this to check if the rolling windows are working correctly
            input_ = torch.tensor(train_df[i-window_size:i].to_numpy(),
                                  dtype=torch.float).view(window_size, 1, input_dim)
            prediction, (hidden_state, cell_state) = model(input_, hidden_state, cell_state, dropout_prob)
            target = torch.tensor([targets[i]], dtype=torch.long)
            prediction = prediction + 10**(-8)  # small epsilon for numerical stability

            # Detach the recurrent state so gradients do not flow across windows
            # (truncated backpropagation through time).
            hidden_state.detach_()
            cell_state.detach_()

            loss = loss_function(prediction, target)
            loss.backward()
            # torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
            optimizer.step()

        print(loss)

    return model, (hidden_state, cell_state)
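

# Usage sketch: a minimal, hypothetical invocation of train(). The stub model below
# is an assumption chosen only so that its forward() matches the call made inside
# train() -- model(input_, hidden_state, cell_state, dropout_prob) returning
# (prediction, (hidden_state, cell_state)). The ticker, dates, and hyperparameters
# are placeholders, and the sketch further assumes dataloader.get_data_for_given_ticker
# and more_data/ticker_embeddings.csv are available as imported above.

class StubLSTMClassifier(nn.Module):
    """Hypothetical stand-in model compatible with train()'s calling convention."""

    def __init__(self, input_dim, hidden_dim, num_output_classes):
        super().__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim)
        self.fc = nn.Linear(hidden_dim, num_output_classes)

    def forward(self, x, hidden_state, cell_state, dropout_prob):
        out, (hidden_state, cell_state) = self.lstm(x, (hidden_state, cell_state))
        out = F.dropout(out[-1], p=dropout_prob)
        # Return log-probabilities so nn.NLLLoss can serve as the loss function.
        return F.log_softmax(self.fc(out), dim=1), (hidden_state, cell_state)


if __name__ == '__main__':
    input_dim, hidden_dim, num_output_classes = 10, 64, 2   # assumed sizes
    model = StubLSTMClassifier(input_dim, hidden_dim, num_output_classes)
    trained_model, final_state = train(
        ticker='AAPL.O',                 # hypothetical #RIC ticker
        window_size=30,
        train_from='2015-01-01',
        train_until='2018-12-31',
        model=model,
        loss_function=nn.NLLLoss(),
        optimizer=optim.Adam(model.parameters(), lr=1e-3),
        num_epochs=1,
        input_dim=input_dim,
        num_output_classes=num_output_classes,
        hidden_dim=hidden_dim,
        dropout_prob=0.2,
    )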