# RecipeGenerator / app.py
# parkermoe
# 'updated text'
# f46039b
import torch
import torch.nn as nn
import os
import pickle
from torch.functional import F
import numpy as np
import gradio as gr
import torchtext
# Inference device. CUDA auto-detection is left commented out below, so the
# app is pinned to the CPU.
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")

# Model dimensions; these are the values passed to LSTMModel when the network
# is constructed further down in this file.
VOCAB_SIZE = 10000
EMBEDDING_DIM = 100
N_UNITS = 128

# The settings below are not referenced in the visible part of this file.
# NOTE(review): presumably carried over from the training script - confirm
# before removing.
MAX_LEN = 200
VALIDATION_SPLIT = 0.2
SEED = 42
LOAD_MODEL = False
BATCH_SIZE = 128
EPOCHS = 25
# model definition (the checkpoint itself is loaded further below)
class LSTMModel(nn.Module):
    """Next-token language model: embedding -> LSTM -> linear -> log-softmax.

    Args:
        vocab_size: Number of embedding rows and size of the output layer.
        embedding_dim: Width of each token embedding (LSTM input size).
        hidden_dim: LSTM hidden size (input width of the output layer).
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim):
        super().__init__()
        # Attribute names are part of the state_dict keys, so they must not
        # change if existing checkpoints are to keep loading.
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size, embedding_dim=embedding_dim
        )
        self.lstm = nn.LSTM(
            input_size=embedding_dim, hidden_size=hidden_dim, batch_first=True
        )
        self.fc = nn.Linear(in_features=hidden_dim, out_features=vocab_size)
        # dim=2 is the vocabulary axis of a (batch, seq, vocab) tensor.
        self.log_softmax = nn.LogSoftmax(dim=2)

    def forward(self, x):
        """Return log-probabilities over the vocabulary for every position.

        Args:
            x: Integer token ids laid out batch-first, i.e. (batch, seq_len).

        Returns:
            Tensor of shape (batch, seq_len, vocab_size) holding log-softmax
            scores along the last axis.
        """
        embedded = self.embedding(x)
        # The LSTM's final (h, c) state is discarded; only per-step outputs
        # feed the projection layer.
        hidden_states, _ = self.lstm(embedded)
        return self.log_softmax(self.fc(hidden_states))
# Build the network on the module-level `device` and restore its trained
# weights from the checkpoint file. `device` is used as already defined
# (a torch.device); it is deliberately not rebound to a plain string here so
# the global keeps a single, consistent type.
model = LSTMModel(VOCAB_SIZE, EMBEDDING_DIM, N_UNITS).to(device)
checkpoint_path = 'recipe_generator_LSTM.pth'
# map_location remaps the stored tensors onto `device` while loading.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source (newer PyTorch releases offer `weights_only=True`).
checkpoint = torch.load(checkpoint_path, map_location=device)
# The loaded object is handed straight to load_state_dict, i.e. the file is
# expected to contain a bare state_dict matching LSTMModel's parameters.
model.load_state_dict(checkpoint)
print('Loaded model from checkpoint')
def load_vocab(file_path):
    """Load a pickled vocabulary object from disk.

    Args:
        file_path: Path of the pickle file to read.

    Returns:
        Whatever object was pickled into the file.

    Raises:
        OSError: If the file cannot be opened.
        pickle.UnpicklingError: If the file is not a valid pickle.
    """
    # SECURITY: pickle.load can execute arbitrary code embedded in the file;
    # only use this on vocabulary files from a trusted source.
    with open(file_path, 'rb') as vocab_file:
        vocab = pickle.load(vocab_file)
    print(f"Vocabulary loaded from {file_path}")
    return vocab
# Load the token vocabulary at import time. 'vocab.pkl' is a relative path, so
# it is resolved against the process's current working directory.
vocab = load_vocab('vocab.pkl')
class TextGenerator:
    """Sample text token by token from a next-token prediction model.

    Args:
        vocab: Object exposing ``get_stoi()`` (token -> index mapping) and
            ``get_itos()`` (index -> token sequence).
        top_k: Stored on the instance. NOTE: it is not applied anywhere in
            this class; sampling draws from the full distribution.
    """

    def __init__(self, vocab, top_k=10):
        self.vocab = vocab
        self.top_k = top_k

    def sample_from(self, logits, temperature):
        """Draw one token index from temperature-scaled scores.

        Args:
            logits: 1-D tensor of per-token scores.
            temperature: Divisor applied to the scores before softmax; values
                below 1 sharpen the distribution, values above 1 flatten it.
                Must be non-zero.

        Returns:
            The sampled index, as returned by ``np.random.choice``.
        """
        probs = F.softmax(logits / temperature, dim=-1).cpu().numpy()
        return np.random.choice(len(probs), p=probs)

    def generate(self, model, device, start_prompt, max_tokens, temperature):
        """Extend ``start_prompt`` by ``max_tokens`` sampled tokens.

        Args:
            model: Callable returning scores of shape (batch, seq, vocab) for
                a (batch, seq) tensor of token ids; must provide ``eval()``.
            device: Device on which the token tensor is created.
            start_prompt: Whitespace-separated seed text; every word must be
                present in the vocabulary.
            max_tokens: Number of tokens to sample and append.
            temperature: Sampling temperature, see ``sample_from``.

        Returns:
            Prompt plus generated tokens joined by single spaces, with any
            '<pad>' tokens dropped.

        Raises:
            KeyError: If a prompt word is missing from the token -> index
                mapping (assuming ``get_stoi()`` returns a dict).
        """
        model.eval()
        # Fetch both lookup tables once. The accessors are called on every use
        # otherwise, and may rebuild the whole mapping per call (torchtext's
        # Vocab does) - TODO confirm for the vocab type in use.
        stoi = self.vocab.get_stoi()
        itos = self.vocab.get_itos()

        prompt_ids = [stoi[word] for word in start_prompt.split()]
        # Shape (1, prompt_len): a batch holding the single running sequence.
        tokens = torch.tensor(
            prompt_ids, dtype=torch.long, device=device
        ).unsqueeze(0)

        with torch.no_grad():
            for _ in range(max_tokens):
                # The whole sequence is re-fed each step; only the scores at
                # the last position are used to pick the next token.
                output = model(tokens)
                next_token_logits = output[0, -1, :]
                next_token = int(
                    self.sample_from(next_token_logits, temperature)
                )
                next_column = torch.tensor(
                    [[next_token]], dtype=torch.long, device=device
                )
                tokens = torch.cat([tokens, next_column], dim=1)

        # Convert to plain ints once instead of indexing with 0-dim tensors.
        words = [itos[index] for index in tokens[0].tolist()]
        return ' '.join(word for word in words if word != '<pad>')
# Shared generator instance used by the web handler, plus one sample printed
# to stdout when the module is loaded.
text_generator = TextGenerator(vocab=vocab, top_k=10)
generated_text = text_generator.generate(
    model=model,
    device=device,
    start_prompt="recipe for",
    max_tokens=100,
    temperature=0.5,
)
print(f"\nGenerated Text: {generated_text}")
def generate_recipe():
    """Return one newly sampled recipe string.

    Takes no arguments: the prompt ("recipe for"), length (100 tokens) and
    temperature (0.5) are fixed, and the module-level ``text_generator``,
    ``model`` and ``device`` are used.
    """
    sampling_settings = {
        "start_prompt": "recipe for",
        "max_tokens": 100,
        "temperature": 0.5,
    }
    recipe = text_generator.generate(
        model=model, device=device, **sampling_settings
    )
    return recipe
# Gradio front end: no input widgets, a single text output filled by
# generate_recipe each time the form is submitted.
iface = gr.Interface(
    title="Recipe Generator",
    description="This is a LSTM based Recurrent Neural Network trained to generate recipes. Press submit to generate a new recipe that can sometimes provide humor!",
    fn=generate_recipe,
    inputs=[],
    outputs="text",
)
# Start the web server as soon as the module runs.
iface.launch()