# main.py — 115 lines (90 loc), ~4 KB.
# FER2013 facial-expression recognition: data loading, training, and evaluation script.
import numpy as np
import pandas as pd
import os
import tensorflow as tf
from keras.preprocessing.image import ImageDataGenerator, load_img
from keras.models import load_model
import matplotlib.pyplot as plt
from keras.utils import plot_model
from keras.callbacks import ModelCheckpoint, CSVLogger, TensorBoard, EarlyStopping, ReduceLROnPlateau
import datetime
from count_exp import count_exp
from display_distribution import display_distribution
from fer_model import fer_model
from model_data import model_data
from loss_plots import loss_plots
# Dataset locations: each split directory is read per class sub-directory below.
train_dir = '/kaggle/input/fer2013/train/' # Input directory path for training images
test_dir = '/kaggle/input/fer2013/test/' # Input directory path for testing images

# Model input geometry and number of output classes.
img_l = img_w = 48
num_classes = 7

# Per-class image counts for each split (computed by the project helper
# `count_exp`), echoed to stdout and then plotted side by side.
train_count = count_exp(train_dir, 'train')
test_count = count_exp(test_dir, 'test')
for split_counts in (train_count, test_count):
    print(split_counts)
display_distribution(train_count, test_count)
# Show one sample image from each class directory of the training split,
# laid out in a single row of `num_classes` panels.
#
# Directory and file lists are sorted because os.listdir() order is
# unspecified; taking the first sorted file (rather than a fixed index such as
# [1]) also works for class directories that contain a single image.
class_dirs = sorted(
    name for name in os.listdir(train_dir)
    if os.path.isdir(os.path.join(train_dir, name))  # skip stray files
)
plt.figure(figsize = (14, 22))
# Cap at num_classes so an extra directory cannot overflow the subplot grid.
for i, class_name in enumerate(class_dirs[:num_classes], start = 1):
    class_path = os.path.join(train_dir, class_name)
    samples = sorted(os.listdir(class_path))
    if not samples:
        continue  # nothing to show for an empty class directory
    img = load_img(os.path.join(class_path, samples[0]))
    plt.subplot(1, num_classes, i)
    plt.imshow(img)
    plt.title(class_name)
    plt.axis('off')
plt.show()
image_size = 48  # image is a 48*48 grid of pixels
batch_size = 64  # number of images to process at a time

# Training pipeline: scale pixels to [0, 1] and apply light augmentation
# (random zoom up to 30% and random horizontal flips).
train_datagen = ImageDataGenerator(rescale = 1./255,
                                   zoom_range = 0.3,
                                   horizontal_flip = True)
training_set = train_datagen.flow_from_directory(train_dir,
                                                 batch_size = batch_size,
                                                 target_size = (image_size, image_size),
                                                 shuffle = True,
                                                 color_mode = 'grayscale',
                                                 class_mode = 'categorical')

# Test pipeline: rescale only, no augmentation. Not shuffled, so validation
# and evaluation are deterministic and batch order matches test_set.classes.
test_datagen = ImageDataGenerator(rescale = 1./255)
test_set = test_datagen.flow_from_directory(test_dir,
                                            batch_size = batch_size,
                                            target_size = (image_size, image_size),
                                            shuffle = False,
                                            color_mode = 'grayscale',
                                            class_mode = 'categorical')

# Class-directory name -> label index mapping inferred by flow_from_directory.
print(training_set.class_indices)
# Build the network. NOTE: this rebinds the imported `fer_model` factory name
# to the model instance, so the factory cannot be called again below.
fer_model = fer_model((img_l, img_w, 1), num_classes)

steps_per_epoch = training_set.n // training_set.batch_size
# len() of a Keras directory iterator is ceil(n / batch_size), so validation
# covers every test image, including the final partial batch.
validation_steps = len(test_set)

chkpt_path = 'fer_model.h5'
log_dir = "checkpoint/logs/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")

# Keep only the model with the lowest validation loss seen so far.
checkpoint = ModelCheckpoint(filepath = chkpt_path,
                             save_best_only = True,
                             verbose = 1,
                             mode = 'min',
                             monitor = 'val_loss')
earlystop = EarlyStopping(monitor = 'val_loss',
                          min_delta = 0,
                          patience = 3,
                          verbose = 1,
                          restore_best_weights = True)
# Multiply the learning rate by 0.2 after 6 epochs without val_loss improvement.
reduce_lr = ReduceLROnPlateau(monitor = 'val_loss',
                              factor = 0.2,
                              patience = 6,
                              verbose = 1,
                              min_delta = 0.0001)
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir = log_dir, histogram_freq = 1)
csv_logger = CSVLogger('training.log')

# NOTE(review): `earlystop` and `tensorboard_callback` are built but not passed
# to fit(). If earlystop is enabled as-is, its patience (3) is shorter than
# reduce_lr's (6), so training would stop before the LR reduction could fire.
callback = [checkpoint, reduce_lr, csv_logger]
# Balanced class weights: n_samples / (num_classes * count_c), so that
# under-represented classes contribute proportionally more to the loss.
# max(count, 1) avoids division by zero for a class with no samples; such a
# class never appears in a batch, so its weight has no effect.
class_counts = np.bincount(training_set.classes, minlength = num_classes)
class_weights = {
    idx: float(training_set.n / (num_classes * max(int(count), 1)))
    for idx, count in enumerate(class_counts)
}

# Training
fer_model_hist = fer_model.fit(x = training_set,
                               validation_data = test_set,
                               epochs = 100,
                               callbacks = callback,
                               steps_per_epoch = steps_per_epoch,
                               validation_steps = validation_steps,
                               class_weight = class_weights)

# Despite the file name, these are the weights after the last epoch run.
# No callback in `callback` restores the best weights. The lowest-val_loss
# model is the one ModelCheckpoint wrote to chkpt_path.
fer_model.save_weights('fer_model_bestweights.h5')

# Reload the best checkpoint for evaluation and plotting.
model = load_model(chkpt_path)
model_data(model, training_set, test_set)
# NOTE(review): the original called `predict(model)` here, but no `predict` is
# defined or imported in this file (NameError). Re-enable it once imported.
loss_plots(fer_model_hist)