Update app.py
app.py CHANGED
@@ -7,6 +7,10 @@ import aiohttp
 import json
 import torch
 import re
+import nest_asyncio
+from hashlib import md5
+
+nest_asyncio.apply()
 
 repo_name = "BeardedMonster/SabiYarn-125M"
 device = "cuda" if torch.cuda.is_available() else "cpu"
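The stock asyncio.run() raises "RuntimeError: asyncio.run() cannot be called from a running event loop" when an event loop is already active in the thread, which can happen inside a Streamlit script. nest_asyncio.apply() patches asyncio.run and loop.run_until_complete to be re-entrant. A minimal sketch of the behavior this enables (the fetch/outer coroutines are illustrative, not part of app.py):

import asyncio
import nest_asyncio

nest_asyncio.apply()  # patch asyncio.run / run_until_complete to allow nesting

async def fetch():
    await asyncio.sleep(0)  # stand-in for a real aiohttp request
    return "ok"

async def outer():
    # Without nest_asyncio.apply(), this nested call would raise
    # "RuntimeError: asyncio.run() cannot be called from a running event loop".
    return asyncio.run(fetch())

print(asyncio.run(outer()))  # -> ok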
@@ -147,6 +151,14 @@ async def generate_from_api(user_input, generation_config):
 
         return "FAILED"
 
+def generate_cache_key(user_input, generation_config):
+    key_data = f"{user_input}_{json.dumps(generation_config, sort_keys=True)}"
+    return md5(key_data.encode()).hexdigest()
+
+@st.cache_data(show_spinner=False)
+def get_cached_response(user_input, generation_config):
+    return asyncio.run(generate_from_api(user_input, generation_config))
+
 
 # Sample texts
 sample_texts = {
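st.cache_data memoizes a function by hashing its arguments, so repeated calls with the same user_input and generation_config are served from Streamlit's cache instead of re-hitting the API; generate_cache_key additionally derives an explicit md5 digest from the same inputs. A minimal sketch of the memoization, where slow_echo is a hypothetical stand-in for the API call:

import time
import streamlit as st

@st.cache_data(show_spinner=False)
def slow_echo(text, config):
    time.sleep(2)  # stand-in for the API round trip
    return text.upper()

slow_echo("hello", {"max_new_tokens": 64})  # first call runs the body (~2 s)
slow_echo("hello", {"max_new_tokens": 64})  # cache hit: same argument hash, returns immediately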
@@ -245,10 +257,12 @@ if st.button("Generate"):
         # Attempt the asynchronous API call
         generation_config["max_new_tokens"] = min(max_new_tokens, 1024 - len(tokenizer.tokenize(wrapped_input)))
         # generated_text = asyncio.run(generate_from_api(wrapped_input, generation_config))
+        cache_key = generate_cache_key(user_input, generation_config)
+        generated_text = get_cached_response(user_input, generation_config)
 
-        loop = asyncio.new_event_loop()
-        asyncio.set_event_loop(loop)
-        generated_text = loop.run_until_complete(generate_from_api(wrapped_input, generation_config))
+        # loop = asyncio.new_event_loop()
+        # asyncio.set_event_loop(loop)
+        # generated_text = loop.run_until_complete(generate_from_api(wrapped_input, generation_config))
         # except Exception as e:
         # print(f"API call failed: {e}. Using local model for text generation.")
         # Use the locally loaded model for text generation
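The call site now routes generation through the cached wrapper instead of managing an event loop by hand. For reference, the three removed lines and the asyncio.run call inside get_cached_response are two ways of driving the same coroutine to completion; a sketch, with generate standing in for generate_from_api:

import asyncio

async def generate(prompt):
    await asyncio.sleep(0)  # stand-in for the aiohttp request
    return "generated: " + prompt

# The manual pattern removed above ...
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
result = loop.run_until_complete(generate("hello"))
loop.close()

# ... does the same work as the single call get_cached_response now makes:
result = asyncio.run(generate("hello"))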