diff --git a/char-rnn-generation/generate.py b/char-rnn-generation/generate.py
index 1302d53..a406577 100644
--- a/char-rnn-generation/generate.py
+++ b/char-rnn-generation/generate.py
@@ -1,46 +1,52 @@
 # https://github.com/spro/practical-pytorch
+# -*- coding: utf-8 -*-
 
 import torch
 
 from helpers import *
 from model import *
 
-def generate(decoder, prime_str='A', predict_len=100, temperature=0.8):
+def generate(decoder, all_characters, prime_str='A', predict_len=100, temperature=0.8):
     hidden = decoder.init_hidden()
-    prime_input = char_tensor(prime_str)
+    prime_input = char_tensor(prime_str, all_characters)
     predicted = prime_str
 
     # Use priming string to "build up" hidden state
     for p in range(len(prime_str) - 1):
         _, hidden = decoder(prime_input[p], hidden)
-    
+
     inp = prime_input[-1]
-    
+
     for p in range(predict_len):
         output, hidden = decoder(inp, hidden)
-        
+
         # Sample from the network as a multinomial distribution
         output_dist = output.data.view(-1).div(temperature).exp()
         top_i = torch.multinomial(output_dist, 1)[0]
-        
+
         # Add predicted character to string and use as next input
         predicted_char = all_characters[top_i]
         predicted += predicted_char
-        inp = char_tensor(predicted_char)
+        inp = char_tensor(predicted_char, all_characters)
 
     return predicted
 
 if __name__ == '__main__':
 # Parse command line arguments
-    import argparse
+    import argparse, pickle
     argparser = argparse.ArgumentParser()
     argparser.add_argument('filename', type=str)
     argparser.add_argument('-p', '--prime_str', type=str, default='A')
     argparser.add_argument('-l', '--predict_len', type=int, default=100)
     argparser.add_argument('-t', '--temperature', type=float, default=0.8)
+    argparser.add_argument('-f', '--charset-file', type=str, default='charset.pickle')
     args = argparser.parse_args()
-    
+    print(args)
+    with open(args.charset_file, 'rb') as fd:
+        all_characters = pickle.load(fd)
     decoder = torch.load(args.filename)
     del args.filename
-    print(generate(decoder, **vars(args)))
-    
+    del args.charset_file
+    print(all_characters)
+    #print(generate(decoder=decoder, all_characters=all_characters, **vars(args)))
+    print(generate(decoder, all_characters=all_characters, prime_str='अध्याय', predict_len=500))
diff --git a/char-rnn-generation/helpers.py b/char-rnn-generation/helpers.py
index 291f5b1..6c2325f 100644
--- a/char-rnn-generation/helpers.py
+++ b/char-rnn-generation/helpers.py
@@ -10,19 +10,25 @@
 
 # Reading and un-unicode-encoding data
 
-all_characters = string.printable
-n_characters = len(all_characters)
+#all_characters = string.printable
+#n_characters = len(all_characters)
+#all_characters = ['\x83', '\x87', '\x8b', '\x8f', '\x93', '\x97', '\x9b', '\x9f', ' ', '\xa3', '\xa7', '(', '\xab', ',', '\xaf', '\xb7', '\xbb', '\xbf', '\xc3', 'H', 'L', 'P', 'T', 'd', 'h', 'l', '\xef', 'p', 't', '|', '\x80', '\x88', '\x8c', '\x90', '\x94', '\x98', '\x9c', '\xa0', '\xa4', "'", '\xa8', '\xac', '/', '\xb0', '\xb8', '\xbc', '?', 'C', 'S', 'W', '\xe0', 'c', 'g', 'k', 'o', 's', 'w', '\x81', '\x85', '\x89', '\n', '\x8d', '\x95', '\x99', '\x9d', '\xa1', '\xa5', '*', '\xad', '.', '\xb1', '\xb5', '\xb9', ':', '\xbd', 'B', 'F', 'J', 'V', 'b', 'f', 'n', 'r', 'v', 'z', '~', '\x82', '\x86', '\x8a', '\x96', '\x9a', '\x9e', '!', '\xa2', '\xa6', ')', '\xaa', '-', '\xae', '\xb2', '\xb6', '\xbe', 'A', '\xc2', 'E', 'I', 'M', 'U', 'Y', 'a', '\xe2', 'e', 'i', 'm', 'q', 'u', 'y']
+#n_characters = len(all_characters)
 
 def read_file(filename):
-    file = unidecode.unidecode(open(filename).read())
-    return file, len(file)
+    #global all_characters
+    #global n_characters
+    s = open(filename).read()
+    all_characters = [i for i in set(s)]
+    n_characters = len(all_characters)
+    return s, len(s), all_characters, n_characters
 
 # Turning a string into a tensor
 
-def char_tensor(string):
-    tensor = torch.zeros(len(string)).long()
-    for c in range(len(string)):
-        tensor[c] = all_characters.index(string[c])
+def char_tensor(s, all_characters):
+    tensor = torch.zeros(len(s)).long()
+    for c in range(len(s)):
+        tensor[c] = all_characters.index(s[c])
     return Variable(tensor)
 
 # Readable time elapsed
@@ -32,4 +38,3 @@ def time_since(since):
     m = math.floor(s / 60)
     s -= m * 60
     return '%dm %ds' % (m, s)
-
diff --git a/char-rnn-generation/train.py b/char-rnn-generation/train.py
index ab46538..c35b217 100644
--- a/char-rnn-generation/train.py
+++ b/char-rnn-generation/train.py
@@ -1,4 +1,5 @@
 # https://github.com/spro/practical-pytorch
+# -*- coding: utf-8 -*-
 
 import torch
 import torch.nn as nn
@@ -21,14 +22,14 @@
 argparser.add_argument('--chunk_len', type=int, default=200)
 args = argparser.parse_args()
 
-file, file_len = read_file(args.filename)
+file, file_len, all_characters, n_characters = read_file(args.filename)
 
 def random_training_set(chunk_len):
     start_index = random.randint(0, file_len - chunk_len)
     end_index = start_index + chunk_len + 1
     chunk = file[start_index:end_index]
-    inp = char_tensor(chunk[:-1])
-    target = char_tensor(chunk[1:])
+    inp = char_tensor(chunk[:-1], all_characters)
+    target = char_tensor(chunk[1:], all_characters)
     return inp, target
 
 decoder = RNN(n_characters, args.hidden_size, n_characters, args.n_layers)
@@ -56,6 +57,9 @@ def train(inp, target):
 def save():
     save_filename = os.path.splitext(os.path.basename(args.filename))[0] + '.pt'
     torch.save(decoder, save_filename)
+    import pickle
+    with open("charset.pickle", "wb") as fd:
+        pickle.dump(all_characters, fd)
     print('Saved as %s' % save_filename)
 
 try:
@@ -66,7 +70,8 @@ def save():
 
         if epoch % args.print_every == 0:
             print('[%s (%d %d%%) %.4f]' % (time_since(start), epoch, epoch / args.n_epochs * 100, loss))
-            print(generate(decoder, 'Wh', 100), '\n')
+            #print(generate(decoder, 'Wh', 100), '\n')
+            print(generate(decoder, all_characters=all_characters, prime_str='अध्याय', predict_len=500))
 
     print("Saving...")
     save()
@@ -74,4 +79,3 @@ def save():
 except KeyboardInterrupt:
     print("Saving before quit...")
     save()
-
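
The patch above replaces the fixed string.printable vocabulary with a character set derived from the training file: save() in train.py pickles all_characters to charset.pickle next to the .pt checkpoint, and generate.py reloads it so char_tensor() maps characters to the same indices the model was trained with. Below is a minimal standalone sketch of that round trip, assuming only the standard library; it is an illustration, not part of the patch, and the corpus string is a placeholder.

# -*- coding: utf-8 -*-
# Standalone sketch of the charset round trip introduced by the patch.
# The corpus string is a placeholder, not real training data.
import pickle

corpus = u"अध्याय one two"                 # stands in for the training file contents
all_characters = [c for c in set(corpus)]  # same construction as read_file() above
n_characters = len(all_characters)

# train.py side: persist the charset alongside the model checkpoint
with open("charset.pickle", "wb") as fd:
    pickle.dump(all_characters, fd)

# generate.py side: reload it before building any input tensors
with open("charset.pickle", "rb") as fd:
    loaded = pickle.load(fd)

# the index lookup char_tensor() performs; an identical charset gives identical indices
indices = [loaded.index(c) for c in u"अध्याय"]
print(n_characters, indices)

Saving the exact list, rather than rebuilding it at generation time, is what keeps the character-to-index mapping stable between training and generation; a set rebuilt from a different file would order the characters differently and scramble the decoder's output.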