-
Notifications
You must be signed in to change notification settings - Fork 12
/
constants.py
31 lines (25 loc) · 894 Bytes
/
constants.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import os
# --- Audio parameters -------------------------------------------------------
# Presumably the target sample rate in Hz ("SR") — confirm against the
# resampling code that consumes it.
TARGET_SR = 8000
# Presumably a clip length in samples (32000 / 8000 = 4 s at TARGET_SR) —
# TODO confirm against the code that pads/crops the audio.
AUDIO_LENGTH = 32000

# --- Filesystem locations ---------------------------------------------------
# Directory holding the UrbanSound8K audio files. Machine-specific: edit this
# path to match the local copy of the dataset.
# Alternative location kept for reference:
#   '/Users/philipperemy/Downloads/UrbanSound8K/audio'
DATA_AUDIO_DIR = '/home/philippe/UrbanSound8K/audio'

# Root output directory and its 'train' / 'test' sub-directories.
OUTPUT_DIR = '/tmp/very-deep-conv-nets-raw-waveforms'
OUTPUT_DIR_TRAIN = os.path.join(OUTPUT_DIR, 'train')
OUTPUT_DIR_TEST = os.path.join(OUTPUT_DIR, 'test')
def print_delimiter():
    """Print a separator line made of 80 dash characters to standard output."""
    # Left-justifying the empty string pads it out to 80 dashes.
    rule = ''.ljust(80, '-')
    print(rule)
def print_total_trainable_parameters_count(variables=None):
    """Print every trainable variable's shape, then the total parameter count.

    Args:
        variables: Optional iterable of objects exposing ``get_shape()``.
            When omitted (the original call style), the variables are read
            from TensorFlow's trainable-variables collection.

    Returns:
        The total number of parameters, i.e. the sum over all variables of
        the product of their dimensions. The same value is printed last.
    """
    if variables is None:
        # Function-scope import: TensorFlow is only loaded when this
        # fallback path is actually used.
        import tensorflow as tf
        try:
            # TensorFlow 2.x only exposes the collection API under compat.v1.
            variables = tf.compat.v1.trainable_variables()
        except AttributeError:
            # Older 1.x releases that predate the compat.v1 namespace.
            variables = tf.trainable_variables()

    total_parameters = 0
    for variable in variables:
        shape = variable.get_shape()
        print(shape)
        # A scalar variable has an empty shape and counts as one parameter.
        variable_parameters = 1
        for dim in shape:
            # TF 1.x yields Dimension objects (size stored in ``.value``);
            # TF 2.x yields plain ints, which have no ``.value`` attribute.
            variable_parameters *= getattr(dim, 'value', dim)
        total_parameters += variable_parameters
    print(total_parameters)
    return total_parameters