Omarrran commited on
Commit
e6f29fe
·
verified ·
1 Parent(s): 2d50f60

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +118 -0
app.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import json
4
+ from transformers import GPT2Config
5
+ from torch import nn
6
+ import requests
7
+ from pathlib import Path
8
+
9
+ class TextGenerator(nn.Module):
10
+ def __init__(self, vocab_size, embedding_dim, hidden_dim):
11
+ super().__init__()
12
+ self.embedding = nn.Embedding(vocab_size, embedding_dim)
13
+ self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
14
+ self.fc = nn.Linear(hidden_dim, vocab_size)
15
+
16
+ def forward(self, x):
17
+ x = self.embedding(x)
18
+ lstm_out, _ = self.lstm(x)
19
+ return self.fc(lstm_out)
20
+
21
+ def download_file(url, local_path):
22
+ response = requests.get(url)
23
+ if response.status_code == 200:
24
+ Path(local_path).parent.mkdir(parents=True, exist_ok=True)
25
+ with open(local_path, 'wb') as f:
26
+ f.write(response.content)
27
+ else:
28
+ raise Exception(f"Failed to download {url}")
29
+
30
+ def load_model_and_tokenizers():
31
+ # Create a local directory for downloaded files
32
+ cache_dir = Path("model_cache")
33
+ cache_dir.mkdir(exist_ok=True)
34
+
35
+ # URLs for the files
36
+ base_url = "https://huggingface.co/Omarrran/temp_data/raw/main"
37
+ files = {
38
+ "model.pt": f"{base_url}/model.pt",
39
+ "word_to_int.json": f"{base_url}/word_to_int.json",
40
+ "int_to_word.json": f"{base_url}/int_to_word.json",
41
+ "model_config.json": f"{base_url}/model_config.json"
42
+ }
43
+
44
+ # Download all files
45
+ for filename, url in files.items():
46
+ local_path = cache_dir / filename
47
+ if not local_path.exists():
48
+ print(f"Downloading {filename}...")
49
+ download_file(url, local_path)
50
+
51
+ # Load configuration
52
+ with open(cache_dir / "model_config.json", "r") as f:
53
+ config = json.load(f)
54
+
55
+ # Load tokenizers
56
+ with open(cache_dir / "word_to_int.json", "r") as f:
57
+ word_to_int = json.load(f)
58
+ with open(cache_dir / "int_to_word.json", "r") as f:
59
+ int_to_word = json.load(f)
60
+
61
+ # Initialize model
62
+ model = TextGenerator(
63
+ vocab_size=config['vocab_size'],
64
+ embedding_dim=config['embedding_dim'],
65
+ hidden_dim=config['hidden_dim']
66
+ )
67
+
68
+ # Load model weights
69
+ model.load_state_dict(torch.load(cache_dir / "model.pt", map_location=torch.device('cpu')))
70
+ model.eval()
71
+
72
+ return model, word_to_int, int_to_word
73
+
74
+ def generate_text(prompt, max_length=100):
75
+ # Load model and tokenizers (will use cached files after first load)
76
+ model, word_to_int, int_to_word = load_model_and_tokenizers()
77
+
78
+ # Tokenize input prompt
79
+ input_ids = [word_to_int.get(word, word_to_int['<UNK>']) for word in prompt.split()]
80
+ input_tensor = torch.tensor([input_ids])
81
+
82
+ # Generate text
83
+ generated_ids = input_ids.copy()
84
+
85
+ with torch.no_grad():
86
+ for _ in range(max_length):
87
+ current_input = torch.tensor([generated_ids[-50:]]) # Use last 50 tokens as context
88
+ outputs = model(current_input)
89
+ next_token_id = outputs[0, -1, :].argmax().item()
90
+ generated_ids.append(next_token_id)
91
+
92
+ if next_token_id == word_to_int.get('<EOS>', 0):
93
+ break
94
+
95
+ # Convert ids back to text
96
+ generated_text = ' '.join([int_to_word.get(str(idx), '<UNK>') for idx in generated_ids])
97
+ return generated_text
98
+
99
+ # Create Gradio interface
100
+ iface = gr.Interface(
101
+ fn=generate_text,
102
+ inputs=[
103
+ gr.Textbox(label="Enter your prompt", placeholder="Type your text here..."),
104
+ gr.Slider(minimum=10, maximum=200, value=100, label="Maximum length", step=1)
105
+ ],
106
+ outputs=gr.Textbox(label="Generated Text"),
107
+ title="Text Generation Model",
108
+ description="Enter a prompt and the model will generate text based on it.",
109
+ examples=[
110
+ ["The quick brown fox"],
111
+ ["Once upon a time"],
112
+ ["In a galaxy far"]
113
+ ]
114
+ )
115
+
116
+ # Launch the interface
117
+ if __name__ == "__main__":
118
+ iface.launch()