-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathgenetic_algorithm.py
More file actions
124 lines (97 loc) · 3.89 KB
/
genetic_algorithm.py
File metadata and controls
124 lines (97 loc) · 3.89 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
"""Entry point to evolving the neural network. Start here."""
from __future__ import print_function
from evolver import Evolver
from tqdm import tqdm
from load_data import *
import csv
import datetime
import time
import logging
class GeneticAlgorithm:
    """Evolve a population of genomes and record the results on disk.

    Everything is written below ``<path>/genetic_algorithm``: a
    ``result.csv`` with one row per trained genome and a ``total_time.txt``
    with the wall-clock duration of the whole run.

    Args:
        path: Base output directory; ``'/genetic_algorithm'`` is appended.
        params: Gene names. Passed to ``Evolver`` and used as the CSV columns
            (each is looked up in ``genome.geneparam``).
        model: Forwarded unchanged to ``genome.train``.
        data_set: Forwarded unchanged to ``genome.train``.
        population: Number of genomes requested from
            ``Evolver.create_population``. Defaults to 20.
        generations: Number of train/evolve rounds. Defaults to 8.

    Raises:
        OSError: From ``create_dirs`` when the output directories already
            exist or cannot be created.
    """

    # Sub-directories created under ``self.path`` by ``create_dirs``.
    _SUBDIRS = (
        'models',
        'plots',
        'confusion_matrix',
        'conf_matrix_csv',
        'conf_matrix_details',
    )

    def __init__(self, path, params, model, data_set,
                 population=20, generations=8):
        self.path = path + '/genetic_algorithm'
        self.data_set = data_set
        self.params = params
        self.model = model
        self.population = population
        self.generations = generations
        self.create_dirs()

    def create_dirs(self):
        """Create the output directory and its fixed set of sub-directories.

        Raises:
            OSError: If any of the directories already exists. This is left
                as-is on purpose: a second run against the same path fails
                instead of silently overwriting earlier results.
        """
        # Imported locally because this file has no module-level
        # ``import os``; the name was previously only in scope if it leaked
        # in some other way (presumably via ``from load_data import *``).
        import os

        os.makedirs(self.path)
        for name in self._SUBDIRS:
            os.makedirs(self.path + '/' + name)

    def run(self):
        """Print a banner with the run settings, then start the evolution."""
        print("***Evolving for %d generations with population size = %d***"
              % (self.generations, self.population))
        self.generate()

    def train_genomes(self, genomes, writer):
        """Train every genome and append one CSV row per genome.

        Args:
            genomes: Iterable with a length; each item must provide
                ``train(model, data_set, path)``, a ``geneparam`` mapping and
                an ``accuracy`` attribute (read after training).
            writer: Object with a ``writerow`` method, e.g. a ``csv.writer``.
        """
        logging.info("***train_networks(networks, dataset)***")
        # The context manager closes the progress bar even when training
        # raises part-way through the population.
        with tqdm(total=len(genomes)) as pbar:
            for genome in genomes:
                genome.train(self.model, self.data_set, self.path)
                # Same column order as the header written by generate():
                # one stringified value per parameter, then the accuracy.
                row = [str(genome.geneparam[p]) for p in self.params]
                row.append(genome.accuracy)
                writer.writerow(row)
                pbar.update(1)

    def generate(self):
        """Run the full evolution loop and write the result files.

        Creates the initial population, then for each generation logs the
        genomes, trains them, logs the average accuracy and (except after
        the last generation) evolves the population. Finally logs the five
        most accurate genomes and writes ``total_time.txt``.
        """
        logging.info("***generate(generations, population, all_possible_genes, dataset)***")
        t_start = datetime.datetime.now()
        t = time.time()

        evolver = Evolver(self.params)
        genomes = evolver.create_population(self.population)

        # NOTE(review): on Python 3 the csv module recommends opening with
        # newline=''; not added here because the file also targets Python 2
        # (it imports print_function), where that argument does not exist.
        with open(self.path + '/result.csv', "w") as ofile:
            writer = csv.writer(ofile, delimiter=',')
            header = [str(p) for p in self.params]
            header.append("accuracy")
            writer.writerow(header)

            # Evolve the generation.
            for i in range(self.generations):
                logging.info("***Now in generation %d of %d***",
                             i + 1, self.generations)
                self.print_genomes(genomes)

                # Train and get accuracy for networks/genomes.
                self.train_genomes(genomes, writer)

                # Log the average accuracy of this generation.
                average_accuracy = self.get_average_accuracy(genomes)
                logging.info("Generation average: %.2f%%",
                             average_accuracy * 100)
                logging.info('-' * 80)

                # Evolve, except on the last iteration.
                if i != self.generations - 1:
                    genomes = evolver.evolve(genomes)

        # Sort our final population according to performance and show the
        # top 5 networks/genomes.
        genomes = sorted(genomes, key=lambda x: x.accuracy, reverse=True)
        self.print_genomes(genomes[:5])

        total = time.time() - t
        t_stop = datetime.datetime.now()
        self._write_total_time(t_start, t_stop, total)

    def _write_total_time(self, t_start, t_stop, total_seconds):
        """Write start, stop and elapsed time to ``total_time.txt``."""
        minutes, seconds = divmod(total_seconds, 60)
        hours, minutes = divmod(minutes, 60)
        days, hours = divmod(hours, 24)
        with open(self.path + '/total_time.txt', 'w') as time_file:
            time_file.write('Start : ' + str(t_start) + '\n')
            time_file.write('Stop : ' + str(t_stop) + '\n')
            time_file.write(
                'Total :'
                + "%d days, %d:%02d:%02d" % (days, hours, minutes, seconds)
                + '\n'
            )

    @staticmethod
    def get_average_accuracy(genomes):
        """Return the mean ``accuracy`` of ``genomes``.

        Args:
            genomes: Sized iterable of objects with an ``accuracy`` attribute.

        Returns:
            The arithmetic mean as a float, or ``0.0`` for an empty
            population (instead of raising ``ZeroDivisionError``).
        """
        if not genomes:
            return 0.0
        total_accuracy = sum(genome.accuracy for genome in genomes)
        # float() keeps true division on Python 2 even for integer inputs.
        return total_accuracy / float(len(genomes))

    @staticmethod
    def print_genomes(genomes):
        """Log a separator line, then ask each genome to print itself."""
        logging.info('-' * 80)
        for genome in genomes:
            genome.print_genome()