import os
import pickle

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchtext  # needed so the pickled torchtext Vocab can be unpickled
import gradio as gr

# Inference runs on CPU; switch to the commented line to use a GPU when available.
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device('cpu')

# Hyperparameters (the training-only values are kept for reference).
VOCAB_SIZE = 10000
MAX_LEN = 200
EMBEDDING_DIM = 100
N_UNITS = 128
VALIDATION_SPLIT = 0.2
SEED = 42
LOAD_MODEL = False
BATCH_SIZE = 128
EPOCHS = 25


class LSTMModel(nn.Module):
    """Token-level language model: embedding -> LSTM -> linear -> log-softmax."""

    def __init__(self, vocab_size, embedding_dim, hidden_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)
        self.log_softmax = nn.LogSoftmax(dim=2)

    def forward(self, x):
        x = self.embedding(x)
        x, _ = self.lstm(x)
        x = self.fc(x)
        return self.log_softmax(x)


# Load the trained weights from the checkpoint.
model = LSTMModel(VOCAB_SIZE, EMBEDDING_DIM, N_UNITS).to(device)
checkpoint_path = 'recipe_generator_LSTM.pth'
checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint)
print('Loaded model from checkpoint')


def load_vocab(file_path):
    with open(file_path, 'rb') as f:
        vocab = pickle.load(f)
    print(f"Vocabulary loaded from {file_path}")
    return vocab


vocab = load_vocab('vocab.pkl')


class TextGenerator:
    def __init__(self, vocab, top_k=10):
        self.vocab = vocab
        self.top_k = top_k  # currently unused; sampling is temperature-only

    def sample_from(self, logits, temperature):
        # Temperature-scaled softmax, then draw one token index from the distribution.
        probs = F.softmax(logits / temperature, dim=-1).cpu().numpy()
        return np.random.choice(len(probs), p=probs)

    def generate(self, model, device, start_prompt, max_tokens, temperature):
        model.eval()
        # Encode the prompt; tokens missing from the vocabulary would raise a KeyError.
        tokens = [self.vocab.get_stoi()[token] for token in start_prompt.split()]
        tokens = torch.LongTensor(tokens).unsqueeze(0).to(device)
        with torch.no_grad():
            for _ in range(max_tokens):
                output = model(tokens)
                # Sample the next token from the distribution at the last position.
                next_token_logits = output[0, -1, :]
                next_token = self.sample_from(next_token_logits, temperature)
                tokens = torch.cat([tokens, torch.LongTensor([[next_token]]).to(device)], dim=1)
        # Drop tokens that map to the empty string (padding) and join the rest.
        generated_tokens = [token for token in tokens[0] if self.vocab.get_itos()[token] != '']
        generated_text = ' '.join(self.vocab.get_itos()[token] for token in generated_tokens)
        return generated_text


text_generator = TextGenerator(vocab=vocab, top_k=10)
generated_text = text_generator.generate(model=model, device=device,
                                         start_prompt="recipe for",
                                         max_tokens=100, temperature=0.5)
print(f"\nGenerated Text: {generated_text}")


def generate_recipe():
    return text_generator.generate(model=model, device=device,
                                   start_prompt="recipe for",
                                   max_tokens=100, temperature=0.5)


iface = gr.Interface(
    fn=generate_recipe,
    inputs=[],
    outputs="text",
    title="Recipe Generator",
    description="This is an LSTM-based recurrent neural network trained to generate recipes. Press Submit to generate a new recipe that can sometimes provide humor!",
)
iface.launch()
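
# The TextGenerator above accepts a `top_k` argument but never applies it;
# sampling is temperature-only. Below is a minimal sketch of how top-k
# filtered sampling could look, kept here for reference; the helper name
# `sample_top_k` is hypothetical and not part of the original script.
def sample_top_k(logits, temperature, k):
    # Keep only the k highest-scoring logits, then renormalize with a
    # temperature-scaled softmax before drawing a single token index.
    values, indices = torch.topk(logits, k)
    probs = F.softmax(values / temperature, dim=-1).cpu().numpy()
    return indices[np.random.choice(len(probs), p=probs)].item()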