Skip to content

Commit 09f8a28

Browse files
committed
feat: merge PR aimacode#1280
1 parent 0a83120 commit 09f8a28

File tree

4 files changed

+872
-532
lines changed

4 files changed

+872
-532
lines changed

deep_learning4e.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
import numpy as np
77
from keras import Sequential, optimizers
88
from keras.layers import Embedding, SimpleRNN, Dense
9-
from keras.preprocessing import sequence
9+
#from keras.preprocessing import sequence
10+
from keras.utils.data_utils import pad_sequences
1011

1112
from utils4e import (conv1D, gaussian_kernel, element_wise_product, vector_add, random_weights,
1213
scalar_vector_product, map_vector, mean_squared_error_loss)
@@ -518,8 +519,10 @@ def keras_dataset_loader(dataset, max_length=500):
518519
# init dataset
519520
(X_train, y_train), (X_val, y_val) = dataset
520521
if max_length > 0:
521-
X_train = sequence.pad_sequences(X_train, maxlen=max_length)
522-
X_val = sequence.pad_sequences(X_val, maxlen=max_length)
522+
#X_train = sequence.pad_sequences(X_train, maxlen=max_length)
523+
#X_val = sequence.pad_sequences(X_val, maxlen=max_length)
524+
X_train = pad_sequences(X_train, maxlen=max_length)
525+
X_val = pad_sequences(X_val, maxlen=max_length)
523526
return (X_train[10:], y_train[10:]), (X_val, y_val), (X_train[:10], y_train[:10])
524527

525528

gui/grid_mdp.py

+26-9
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@
44
import tkinter.messagebox
55
from functools import partial
66
from tkinter import ttk
7+
import time
8+
9+
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
710

811
import matplotlib
912
import matplotlib.animation as animation
@@ -15,12 +18,11 @@
1518

1619
from mdp import *
1720

18-
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
1921

2022
matplotlib.use('TkAgg')
2123
style.use('ggplot')
2224

23-
fig = Figure(figsize=(20, 15))
25+
fig = Figure(figsize=(20, 25))
2426
sub = fig.add_subplot(111)
2527
plt.rcParams['axes.grid'] = False
2628

@@ -47,11 +49,12 @@ def extents(f):
4749
return [f[0] - delta / 2, f[-1] + delta / 2]
4850

4951

50-
def display(gridmdp, _height, _width):
52+
def display(gridmdp, _height, _width, _a):
5153
"""displays matrix"""
5254

5355
dialog = tk.Toplevel()
54-
dialog.wm_title('Values')
56+
#dialog.wm_title('Values')
57+
dialog.wm_title(_a)
5558

5659
container = tk.Frame(dialog)
5760
container.pack(side=tk.TOP, fill=tk.BOTH, expand=True)
@@ -125,7 +128,8 @@ def initialize_dialogbox(_width, _height, gridmdp, terminals, buttons):
125128
btn_ok = ttk.Button(container, text='Ok', command=dialog.destroy)
126129
btn_ok.grid(row=5, column=2, sticky='nsew', pady=5, padx=5)
127130

128-
dialog.geometry('400x200')
131+
#dialog.geometry('400x200')
132+
dialog.geometry('1600x200')
129133
dialog.mainloop()
130134

131135

@@ -393,7 +397,7 @@ def view_matrix(self):
393397
_height = self.shared_data['height'].get()
394398
_width = self.shared_data['width'].get()
395399
print(build_page.gridmdp)
396-
display(build_page.gridmdp, _height, _width)
400+
display(build_page.gridmdp, _height, _width, "aaaaa")
397401

398402
def view_terminals(self):
399403
"""prints current terminals to console"""
@@ -570,6 +574,9 @@ def __init__(self, parent, controller):
570574
self.epsilon = 0.001
571575
self.delta = 0
572576

577+
def print_inter_matrix(self, values, h, w):
578+
display(values, h, w, "aaaaa")
579+
573580
def process_data(self, terminals, _height, _width, gridmdp):
574581
"""preprocess variables"""
575582

@@ -606,7 +613,7 @@ def create_graph(self, gridmdp, terminals, _height, _width):
606613

607614
self.canvas = FigureCanvasTkAgg(fig, self.frame)
608615
self.canvas.get_tk_widget().pack(side=tk.TOP, fill=tk.BOTH, expand=True)
609-
self.anim = animation.FuncAnimation(fig, self.animate_graph, interval=50)
616+
self.anim = animation.FuncAnimation(fig, self.animate_graph, interval=300)
610617
self.canvas.show()
611618

612619
def animate_graph(self, i):
@@ -620,8 +627,14 @@ def animate_graph(self, i):
620627
y = np.linspace(0, len(self.gridmdp) - 1, y_interval)
621628

622629
sub.clear()
623-
sub.imshow(self.grid_to_show, cmap='BrBG', aspect='auto', interpolation='none', extent=extents(x) + extents(y),
630+
#sub.imshow(self.grid_to_show, cmap='BrBG', aspect='auto', interpolation='none', extent=extents(x) + extents(y),
631+
# origin='lower')
632+
633+
for (j,i),label in np.ndenumerate(self.grid_to_show):
634+
sub.text(i,j,label,ha='center',va='center')
635+
sub.imshow(self.grid_to_show, aspect='auto', interpolation='none', extent=extents(x) + extents(y),
624636
origin='lower')
637+
625638
fig.tight_layout()
626639

627640
U = self.U1.copy()
@@ -634,16 +647,20 @@ def animate_graph(self, i):
634647
self.grid_to_show = grid_to_show = [[0.0] * max(1, self._width) for _ in range(max(1, self._height))]
635648
for k, v in U.items():
636649
self.grid_to_show[k[1]][k[0]] = v
650+
651+
#time.sleep(1)
637652

638653
if (self.delta < self.epsilon * (1 - self.gamma) / self.gamma) or (
639654
self.iterations > 60) and self.terminated is False:
640655
self.terminated = True
641-
display(self.grid_to_show, self._height, self._width)
656+
display(self.grid_to_show, self._height, self._width, "Final Value")
642657

643658
pi = best_policy(self.sequential_decision_environment,
644659
value_iteration(self.sequential_decision_environment, .01))
645660
display_best_policy(self.sequential_decision_environment.to_arrows(pi), self._height, self._width)
646661

662+
663+
647664
ax = fig.gca()
648665
ax.xaxis.set_major_locator(MaxNLocator(integer=True))
649666
ax.yaxis.set_major_locator(MaxNLocator(integer=True))

0 commit comments

Comments
 (0)