-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathcreate_sequence_learning_data.py
117 lines (94 loc) · 4.65 KB
/
create_sequence_learning_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import argparse
import os
import numpy as np
import pickle
from pathlib import Path
from sklearn.model_selection import train_test_split
def make_fixed_length_data(samples, seq_len, increment, max_starting_point,
                           variable_increment=False):
    """Build ``samples`` arithmetic sequences that all have ``seq_len`` elements.

    Each row starts at a value drawn uniformly from
    ``[-max_starting_point, max_starting_point)`` and every following element
    adds that row's step to the previous one.

    Args:
        samples: number of sequences (rows) to generate.
        seq_len: number of elements in every sequence.
        increment: step used for all rows when ``variable_increment`` is false.
        max_starting_point: bound of the uniform range used for the starting
            points (and for the random steps).
        variable_increment: if true, draw one random step per row instead of
            using ``increment``.

    Returns:
        ``(sequences, targets)`` where ``sequences`` has shape
        ``(samples, seq_len)`` and ``targets[k]`` is the value that would
        follow ``sequences[k]``.
    """
    # The step is drawn before the starting points so that the global NumPy
    # random stream is consumed in the same order as it always has been.
    if variable_increment:
        step = np.random.uniform(-max_starting_point, high=max_starting_point,
                                 size=samples)
    else:
        step = increment

    sequences = np.zeros((samples, seq_len))
    sequences[:, 0] = np.random.uniform(
        -max_starting_point, high=max_starting_point, size=samples)
    for pos in range(1, seq_len):
        sequences[:, pos] = sequences[:, pos - 1] + step
    targets = sequences[:, -1] + step
    return sequences, targets


def make_variable_length_data(samples, seq_len, increment, max_starting_point,
                              variable_increment=False):
    """Build ``samples`` arithmetic sequences whose lengths vary.

    Sequence ``i`` has ``seq_len - i % seq_len`` elements, i.e. the lengths
    cycle through ``seq_len, seq_len - 1, ..., 1``.

    Args:
        samples: number of sequences to generate.
        seq_len: maximum sequence length.
        increment: step used for all sequences when ``variable_increment`` is
            false.
        max_starting_point: bound of the uniform range used for the starting
            points (and for the random steps).
        variable_increment: if true, draw one random step per sequence instead
            of using ``increment``.

    Returns:
        ``(sequences, targets)``: a list of 1-D float arrays and a list with
        the value that would follow each of them.
    """
    sequences = []
    targets = []
    for i in range(samples):
        # Per sequence the step is drawn before the starting point; this keeps
        # the order in which the global NumPy random stream is consumed.
        if variable_increment:
            step = np.random.uniform(-max_starting_point,
                                     high=max_starting_point, size=1)[0]
        else:
            step = increment
        start = np.random.uniform(-max_starting_point,
                                  high=max_starting_point, size=1)[0]

        # Accumulate in a plain list (one addition per element) rather than
        # re-allocating the array with np.append on every step.
        values = [start]
        for _ in range(1, seq_len - i % seq_len):
            values.append(values[-1] + step)
        sequence = np.array(values)

        sequences.append(sequence)
        # The target continues the sequence with the step that built it, so it
        # stays correct when the step is random.
        targets.append(sequence[-1] + step)
    return sequences, targets


def save_splits(splits, out_dir):
    """Pickle every split into ``out_dir`` as ``X_<split>`` and ``y_<split>``.

    Args:
        splits: mapping from split name to an ``(X, y)`` pair.
        out_dir: target directory; created (with parents) if missing.
    """
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    for split, (X, y) in splits.items():
        for prefix, payload in (("X", X), ("y", y)):
            # 'wb' truncates a file left over from a previous run.
            with open(out_dir / f"{prefix}_{split}", "wb") as handle:
                pickle.dump(payload, handle)


def main(argv=None):
    """Generate the sequence-learning data set and write train/val/test splits.

    Task: Every next number in a sequence is an increment of the previous one.
    Predict the number succeeding a sequence.

    Args:
        argv: command-line arguments; ``None`` means ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--seq_len', type=int, default=10, help='length of the sequence/sample')
    parser.add_argument('--samples', type=int, default=5000, help='the number of samples to return')
    parser.add_argument('--increment', type=int, default=1000,
                        help='the difference between consecutive numbers')
    parser.add_argument('--max_starting_point', type=int, default=1000)
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--test_fraction', type=float, default=0.1)
    parser.add_argument('--output_dir', type=str, help='name of the file to write', default="data/sequence_learning")
    parser.add_argument('--variable_length', action='store_true', help='if set, the sequences will have different length')
    parser.add_argument('--variable_increment', action='store_true', help='if set, the increment will be different for each sequence')
    args = parser.parse_args(argv)

    if args.seq_len < 1:
        parser.error('--seq_len must be at least 1')

    np.random.seed(args.seed)

    if args.variable_length:
        generate = make_variable_length_data
    else:
        generate = make_fixed_length_data
    out_array, targets_array = generate(
        args.samples, args.seq_len, args.increment, args.max_starting_point,
        variable_increment=args.variable_increment)

    # create data splits; the validation set is carved out of the remaining
    # training data using the same fraction as the test set
    X_train, X_test, y_train, y_test = train_test_split(out_array, targets_array,
                                                        test_size=args.test_fraction,
                                                        random_state=args.seed)
    X_train, X_val, y_train, y_val = train_test_split(X_train, y_train,
                                                      test_size=args.test_fraction,
                                                      random_state=args.seed)

    length_dir = "varied_length" if args.variable_length else "fixed_length"
    increment_dir = "variable_increment" if args.variable_increment else "fixed_increment"
    save_splits({"train": (X_train, y_train),
                 "val": (X_val, y_val),
                 "test": (X_test, y_test)},
                Path(args.output_dir) / length_dir / increment_dir)


if __name__ == '__main__':
    main()