This repository was archived by the owner on Nov 29, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrainAI.py
221 lines (174 loc) · 6.66 KB
/
trainAI.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
import pygame
import numpy
import random
import neat
import os
import pickle
from resources.reference import *
from resources.util import *
from gamesrc.grid import Grid
from gamesrc.snake import Snake
from gamesrc.food import Food
def has_failed(snake, genome):
    """
    Report whether the snake has met a fail condition, penalising the genome.

    The board is the fixed 30 by 30 grid, so the bounds are hard-coded here and
    no Grid object is needed.

    Arguments:
        snake {Snake} -- The snake to inspect; only ``snake.coords`` is read,
            with ``coords[0]`` treated as the head.
        genome -- The genome controlling this snake; its ``fitness`` is
            reduced when a fail condition is met.

    Returns:
        boolean -- True if a fail condition was met, False otherwise.
    """
    head_x = snake.coords[0][0]
    head_y = snake.coords[0][1]

    # NOTE(review): the head is compared as a tuple, so this only detects
    # self-collision if Snake stores its coords as tuples -- confirm.
    hit_body = (head_x, head_y) in snake.coords[1:]
    if hit_body:
        # Biting itself is the costliest way to die.
        genome.fitness -= 3
        return True

    off_board = not (0 <= head_x <= 29 and 0 <= head_y <= 29)
    if off_board:
        # Wall hits cost slightly fewer points because the network's inputs
        # already make this case rare.
        genome.fitness -= 2.5
        return True

    return False
def draw(window, snake, food, score):
    """
    Render one frame of the game onto ``window`` and update the display.

    Every time the module-level ANIMATION_TICK countdown reaches zero, the food
    image is recoloured with a random RGB value. That value is stored in the
    module-level FOOD_RGB and reset to a 25-frame countdown.

    Arguments:
        window {Surface} -- The active PyGame window to draw on.
        snake {Snake} -- The snake in the game; ``coords[0]`` is drawn as the
            (white) head, the rest with SNAKE_IMG.
        food {Food} -- The food in the game, at grid cell (``x``, ``y``).
        score {int} -- The current score, drawn in the top-right corner.
    """
    global ANIMATION_TICK, FOOD_RGB
    cell = 15  # pixel size of one grid cell

    ANIMATION_TICK -= 1  # Used to change food color
    window.fill((0, 0, 51))

    # Recolour the food when the countdown expires.
    if ANIMATION_TICK == 0:
        FOOD_RGB = (
            random.randrange(50, 255),
            random.randrange(50, 255),
            random.randrange(50, 255),
        )
        FOOD_IMG.fill(FOOD_RGB)
        ANIMATION_TICK = 25
    # Always draw the food, including on recolour frames, so it does not
    # flicker.
    window.blit(FOOD_IMG, (food.x * cell, food.y * cell))

    # Draw the snake on the same surface as everything else (the window
    # parameter, not a global).
    head = pygame.Surface((cell, cell))
    head.fill((255, 255, 255))
    for i, coord in enumerate(snake.coords):
        pos = (coord[0] * cell, coord[1] * cell)
        window.blit(head if i == 0 else SNAKE_IMG, pos)

    # Draw the score in the top-right corner of the screen.
    score_txt = STAT_FONT.render("Score: " + str(score), 1, (255, 255, 255))
    window.blit(score_txt, (WIN_WIDTH - 10 - score_txt.get_width(), 10))
    pygame.display.update()
def eval(genomes, config):
    """
    Play one game of Snake per genome and score it (the NEAT fitness function).

    Each genome's network steers its own snake. The game ends when the snake
    fails, goes 4 seconds without food, or reaches a score of 45. At 45 the
    network is pickled to ``best_model.pickle``.

    Fitness is accumulated on ``genome.fitness``:

    * +1 for a step that brings the head closer to the food on either axis,
      -1.5 otherwise, so aimless looping loses points overall;
    * +4 for eating food;
    * -500 for going 4 seconds without food;
    * the penalties applied by ``has_failed`` on death.

    Arguments:
        genomes -- Iterable of ``(genome_id, genome)`` pairs supplied by NEAT.
        config {neat.config.Config} -- Config used to build each network.
    """
    # NOTE: this name shadows the builtin ``eval``. It is kept because run()
    # passes this function to ``population.run`` by name.
    for _, genome in genomes:
        network = neat.nn.FeedForwardNetwork.create(genome, config)
        genome.fitness = 0
        grid = Grid()
        snake = Snake()
        food = generate_food(grid, snake)
        score = 0
        # Track the per-axis distance to the current food and the time since
        # the last food was eaten.
        food_cur_dist = _food_distance(snake, food)
        last_food = pygame.time.get_ticks()

        while True:
            game_clock.tick(1000)  # I am speed

            # Wall-clock seconds since the last food. At ~1000fps, 4 seconds
            # without food almost certainly means the snake is looping.
            # NOTE(review): this is wall time, not frames, so the frame budget
            # depends on machine speed and rendering cost.
            time_since_food = (pygame.time.get_ticks() - last_food) / 1000
            if time_since_food >= 4:
                genome.fitness -= 500
                break

            # Handle quitting the window.
            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    pygame.quit()
                    raise SystemExit

            # The decision has 3 outputs (turn left, turn right, keep going);
            # take the largest, with ties resolved in that order.
            decision = make_decision(network, grid, snake, food)
            max_val = max(decision)
            if max_val == decision[0]:
                snake.move(grid, "L")
            elif max_val == decision[1]:
                snake.move(grid, "R")
            else:
                snake.tick(grid)

            # Reward moving closer to the food; penalise moving away.
            food_prev_dist = food_cur_dist
            food_cur_dist = _food_distance(snake, food)
            if (food_cur_dist[0] < food_prev_dist[0]
                    or food_cur_dist[1] < food_prev_dist[1]):
                genome.fitness += 1
            else:
                genome.fitness -= 1.5

            if snake.collide(food):
                last_food = pygame.time.get_ticks()  # Reset last food timer
                genome.fitness += 4
                score += 1
                snake.elongate(grid)
                # RGB snakes are essential for training /s
                SNAKE_IMG.fill(FOOD_RGB)
                food = generate_food(grid, snake)
                # Re-baseline against the NEW food. Otherwise the next step is
                # compared with the stale distance to the eaten food and
                # always takes the -1.5 penalty.
                food_cur_dist = _food_distance(snake, food)

            # Check for failure (has_failed deducts the points itself).
            if has_failed(snake, genome):
                break

            draw(WINDOW, snake, food, score)

            # A score of 45 generally means the network is as good as it will
            # get, so store the model and move on.
            if score >= 45:
                _save_model(network)
                break


def _food_distance(snake, food):
    """Return the per-axis (dx, dy) distance from the snake's head to the food."""
    head = snake.coords[0]
    return (abs(food.x - head[0]), abs(food.y - head[1]))


def _save_model(network, path="best_model.pickle"):
    """Pickle ``network`` to ``path``, closing the file even if dumping fails."""
    with open(path, "wb") as nn_file:
        pickle.dump(network, nn_file)
def run(config_path, generations=50):
    """
    Evolve snake-playing networks with NEAT, using ``eval`` as the fitness function.

    Arguments:
        config_path -- Path to the NEAT config file.
        generations {int} -- Maximum number of generations to run (default 50).

    Returns:
        The best genome found, as returned by ``neat.Population.run``.
    """
    config = neat.config.Config(neat.DefaultGenome, neat.DefaultReproduction,
                                neat.DefaultSpeciesSet, neat.DefaultStagnation,
                                config_path)
    population = neat.Population(config)
    # Print per-generation progress and collect statistics.
    population.add_reporter(neat.StdOutReporter(True))
    population.add_reporter(neat.StatisticsReporter())
    return population.run(eval, generations)
def main():
    """Locate the NEAT config shipped next to this script and start training."""
    script_dir = os.path.dirname(__file__)
    run(os.path.join(script_dir, "resources/config-feedforward.txt"))


if __name__ == "__main__":
    main()