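"""Length-bucketed speech/text dataset utilities.

Records are sorted by utterance duration and grouped into batch-sized
buckets, so every item in a batch is padded to the same bucket-wide
maximum speech and text lengths.
"""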
import torch
from typing import Optional, Tuple
from torch.utils.data import Dataset, DataLoader

from interfaces import IDataLoader, IPipeline, IPadder

# Field indices within each record line of the data file.
SPK_ID = 0
PATH_ID = 1
TEXT_ID = 2
DURATION_ID = 3

class Data(Dataset):
    def __init__(
            self,
            data_loader: IDataLoader,
            aud_pipeline: IPipeline,
            text_pipeline: IPipeline,
            aud_padder: IPadder,
            text_padder: IPadder,
            batch_size: int,
            sep: str
            ) -> None:
        super().__init__()
        self.sep = sep
        self.aud_pipeline = aud_pipeline
        self.text_pipeline = text_pipeline
        self.aud_padder = aud_padder
        self.text_padder = text_padder
        self.batch_size = batch_size
        self.data = self.process(data_loader)
        # Per-bucket maximum lengths; a "bucket" is one batch-sized slice
        # of the duration-sorted data. Speech lengths are filled lazily in
        # __getitem__, text lengths are precomputed below.
        self.max_speech_lens = []
        self.max_text_lens = []
        self.n_batches = len(self.data) // self.batch_size
        if len(self.data) % self.batch_size > 0:
            self.n_batches += 1
        self.__set_max_text_lens()
    def process(self, data_loader: IDataLoader):
        # Each non-empty line holds one record: spk_id, path, text, duration.
        lines = data_loader.load().split('\n')
        data = [line.split(self.sep) for line in lines if line.strip()]
        # Sort numerically by duration, longest first, so each batch groups
        # utterances of similar length and padding stays minimal.
        data = sorted(data, key=lambda x: float(x[DURATION_ID]), reverse=True)
        return data
    def __set_max_text_lens(self):
        # Record the longest transcript length within each bucket.
        for i, item in enumerate(self.data):
            idx = i // self.batch_size
            length = len(item[TEXT_ID])
            if idx >= len(self.max_text_lens):
                self.max_text_lens.append(length)
            else:
                self.max_text_lens[idx] = max(length, self.max_text_lens[idx])
    def __len__(self) -> int:
        return len(self.data)
    def _get_max_len(self, idx: int) -> Tuple[Optional[int], int]:
        bucket_id = idx // self.batch_size
        # Text buckets are padded one step past the longest transcript.
        if bucket_id >= len(self.max_speech_lens):
            # The bucket's speech length is unknown until its first (and
            # longest) item has been processed in __getitem__.
            return None, self.max_text_lens[bucket_id] + 1
        return (
            self.max_speech_lens[bucket_id],
            self.max_text_lens[bucket_id] + 1
        )
    def __getitem__(self, idx: int):
        spk_id, file_path, text, _ = self.data[idx]
        spk_id = int(spk_id)
        max_speech_len, max_text_len = self._get_max_len(idx)
        text = self.text_pipeline.run(text)
        text = self.text_padder.pad(text, max_text_len)
        speech = self.aud_pipeline.run(file_path)
        speech_length = speech.shape[0]
        mask = [True] * speech_length
        if max_speech_len is not None:
            # Pad to the bucket's max speech length and mark the padded
            # frames as invalid in the mask.
            mask.extend([False] * (max_speech_len - speech_length))
            speech = self.aud_padder.pad(speech, max_speech_len)
        else:
            # First item of the bucket: since the data is sorted longest
            # first, its length defines the bucket's max speech length.
            self.max_speech_lens.append(speech_length)
        mask = torch.BoolTensor(mask)
        spk_id = torch.LongTensor([spk_id])
        return speech, speech_length, mask, text, spk_id

def get_batch_loader(
        data_loader: IDataLoader,
        aud_pipeline: IPipeline,
        text_pipeline: IPipeline,
        aud_padder: IPadder,
        text_padder: IPadder,
        batch_size: int,
        sep: str
        ):
    return DataLoader(
        Data(
            data_loader=data_loader,
            aud_pipeline=aud_pipeline,
            text_pipeline=text_pipeline,
            aud_padder=aud_padder,
            text_padder=text_padder,
            batch_size=batch_size,
            sep=sep
        ),
        batch_size=batch_size,
        # shuffle must stay False: buckets are slices of the duration-sorted
        # data, and each bucket's speech length is recorded lazily the first
        # time the bucket is visited, which assumes sequential access.
        shuffle=False
    )
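
# A minimal usage sketch, not part of the module. The concrete classes
# CsvLoader, MelPipeline, CharPipeline, and ZeroPadder are hypothetical
# stand-ins for project-specific implementations of the IDataLoader,
# IPipeline, and IPadder interfaces; only get_batch_loader above is real.
#
#     loader = get_batch_loader(
#         data_loader=CsvLoader('train.csv'),   # hypothetical
#         aud_pipeline=MelPipeline(),           # hypothetical
#         text_pipeline=CharPipeline(),         # hypothetical
#         aud_padder=ZeroPadder(),              # hypothetical
#         text_padder=ZeroPadder(),             # hypothetical
#         batch_size=32,
#         sep=','
#     )
#     for speech, speech_len, mask, text, spk_id in loader:
#         ...  # one training step per batch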