-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathmain.py
91 lines (73 loc) · 3.6 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
from gender_cnn import GenderCnn
import pandas as panda
from time import time
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
def save_csv(data_frame, file_path, separator=',', encoding='utf-8', float_format='%.7f'):
    """Write a pandas DataFrame to a CSV file.

    Args:
        data_frame: The DataFrame to write.
        file_path: Destination path of the CSV file.
        separator: Field delimiter placed between columns.
        encoding: Text encoding used for the output file.
        float_format: printf-style format applied to floating-point values.
    """
    # Every option is passed by keyword.  Positionally, the third parameter of
    # DataFrame.to_csv is ``na_rep``, so the encoding string used to be
    # written in place of missing values (and recent pandas versions reject
    # positional options altogether).
    data_frame.to_csv(
        file_path,
        sep=separator,
        encoding=encoding,
        float_format=float_format,
    )
def train_net(net, train_loader, val_loader, n_epochs, loss_fn, optimizer,
              output_dir='LFW_model_torch'):
    """Train ``net`` for ``n_epochs`` epochs, validating after every epoch.

    After each epoch the network's ``state_dict`` is saved to
    ``<output_dir>/cnn_epoch<epoch>.pkl`` and the mean training loss, mean
    validation loss and validation accuracy are printed.  Once all epochs
    are done, the per-epoch statistics are written to
    ``<output_dir>/statistics_<timestamp>.csv`` through ``save_csv``.

    Args:
        net: The network to train; called as ``net(inputs)``.
        train_loader: Iterable of ``(inputs, labels)`` training batches;
            must support ``len()``.
        val_loader: Iterable of ``(inputs, labels)`` validation batches;
            must support ``len()``.
        n_epochs: Number of passes over ``train_loader``.
        loss_fn: Callable ``loss_fn(outputs, labels)`` returning a scalar
            tensor.
        optimizer: Optimizer holding the parameters of ``net``.
        output_dir: Directory receiving checkpoints and the statistics CSV.
            It is created when missing.

    Note:
        The network is left in evaluation mode when this function returns.
    """
    import os  # Local import so the block does not depend on file-level imports.

    # torch.save does not create missing parent directories.
    os.makedirs(output_dir, exist_ok=True)

    # Uses a list to record down some statistics (one dict per epoch).
    statistics = []
    for epoch in range(n_epochs):
        # Validation below switches to eval mode, so re-enable training
        # behaviour (dropout, batch-norm updates) at the start of each epoch.
        net.train()
        total_train_loss = 0.0
        # Iterates through every batch in the training data-set.
        for inputs, labels in train_loader:
            # Set the parameter gradients to zero.
            optimizer.zero_grad()
            # Forward pass, backward pass, optimize.
            outputs = net(inputs)
            loss_size = loss_fn(outputs, labels)
            loss_size.backward()
            optimizer.step()
            # Updates the loss value.
            total_train_loss += loss_size.item()
        # ``epoch`` is zero-based; report the number of completed epochs.
        print("Finished the training of %d epoch(es)." % (epoch + 1))
        storage_path = '{}/cnn_epoch{}.pkl'.format(output_dir, epoch)
        torch.save(net.state_dict(), storage_path)
        print("Saved the network at %s." % storage_path)

        # Iterates through every batch in the validation data-set.
        total_val_loss = 0.0
        total_count = 0
        correct_count = 0
        net.eval()
        # No gradients are needed for validation; skipping graph construction
        # saves memory and time.
        with torch.no_grad():
            for inputs, label in val_loader:
                val_outputs = net(inputs)
                # Index of the highest score along dimension 1 is the
                # predicted class.
                _, predicted = torch.max(val_outputs, 1)
                total_count += label.size(0)
                correct_count += (predicted == label).sum().item()
                val_loss_size = loss_fn(val_outputs, label)
                total_val_loss += val_loss_size.item()

        train_loss = total_train_loss / len(train_loader)
        val_loss = total_val_loss / len(val_loader)
        accuracy = 1.0 * correct_count / total_count
        print("epoch=%d training loss=%.3f." % (epoch, train_loss))
        print("epoch=%d validation loss=%.3f." % (epoch, val_loss))
        print("epoch=%d accuracy=%.3f.\n" % (epoch, accuracy))
        statistics.append({"train_loss": train_loss,
                           "val_loss": val_loss,
                           "accuracy": accuracy})
    save_csv(panda.DataFrame(statistics),
             '{}/statistics_{}.csv'.format(output_dir, time()))
    print("Finished the training of all epoch(es).")
def main():
    """Build the data-sets and loaders, then train and validate the CNN.

    Images are read from ``LFW_extract/train`` and ``LFW_extract/val``,
    resized to 32x32 and converted to tensors.  A ``GenderCnn`` is trained
    for 20 epochs with cross-entropy loss and SGD (lr=0.001, momentum=0.9).
    """
    # Creates the data-set for training and validation.
    transformer = transforms.Compose([transforms.Resize((32, 32)), transforms.ToTensor()])
    train_set = ImageFolder("LFW_extract/train", transformer)
    print("Finished loading the training data-set.")
    val_set = ImageFolder("LFW_extract/val", transformer)
    print("Finished loading the validation data-set.")
    # Creates the data-loader for training and validation.
    train_loader = DataLoader(train_set, batch_size=4, shuffle=True, num_workers=2)
    print("Finished creating the training data-loader.")
    val_loader = DataLoader(val_set, batch_size=4, shuffle=True, num_workers=2)
    print("Finished creating the validation data-loader.")
    # Initializes a CNN instance.
    cnn_net = GenderCnn()
    # Defines the loss function.
    loss_fn = nn.CrossEntropyLoss()
    # Uses a SGD-based optimizer.
    sgd_optimizer = optim.SGD(cnn_net.parameters(), lr=0.001, momentum=0.9)
    # Starts the training and validation of the model.
    print("Going to train the model...")
    train_net(cnn_net, train_loader, val_loader, 20, loss_fn, sgd_optimizer)
    print("Finished training the model...")


# The loaders use worker processes (num_workers=2).  On platforms that start
# workers with "spawn", every worker re-imports this module, so the script
# must only run when executed directly, not on import.
if __name__ == "__main__":
    main()