Spaces: Runtime error

Update app.py #21
by Alfaxad - opened
app.py CHANGED
@@ -1,191 +1,149 @@
+# app.py
 import gradio as gr
+from transformers import AutoModelForCausalLM, AutoTokenizer
+import torch
 import os
-import sys
-import json
-import requests
-import random
-
-
-MODEL = "gpt-4.1-mini"
-API_URL = os.getenv("API_URL")
-DISABLED = os.getenv("DISABLED") == 'True'
-OPENAI_API_KEYS = os.getenv("OPENAI_API_KEYS").split(',')
-print (API_URL)
-print (OPENAI_API_KEYS)
-NUM_THREADS = int(os.getenv("NUM_THREADS"))
-
-print (NUM_THREADS)
-
-def exception_handler(exception_type, exception, traceback):
-    print("%s: %s" % (exception_type.__name__, exception))
-sys.excepthook = exception_handler
-sys.tracebacklimit = 0
-
-def predict(inputs, top_p, temperature, chat_counter, chatbot, history, request: gr.Request):
-    payload = {
-        "model": MODEL,
-        "messages": [{"role": "user", "content": f"{inputs}"}],
-        "temperature": temperature,
-        "top_p": top_p,
-        "n": 1,
-        "stream": True,
-        "presence_penalty": 0,
-        "frequency_penalty": 0,
-    }
-    OPENAI_API_KEY = random.choice(OPENAI_API_KEYS)
-    print (OPENAI_API_KEY)
-
-    headers_dict = {key.decode('utf-8'): value.decode('utf-8') for key, value in request.headers.raw}
-
-    headers = {
-        "Content-Type": "application/json",
-        "Authorization": f"Bearer {OPENAI_API_KEY}",
-        "Headers": f"{headers_dict}"
-    }
-
-    # print(f"chat_counter - {chat_counter}")
-    if chat_counter != 0:
-        messages = []
-        for i, data in enumerate(history):
-            if i % 2 == 0:
-                role = 'user'
-            else:
-                role = 'assistant'
-            message = {}
-            message["role"] = role
-            message["content"] = data
-            messages.append(message)
-
-        message = {}
-        message["role"] = "user"
-        message["content"] = inputs
-        messages.append(message)
-        payload = {
-            "model": MODEL,
-            "messages": messages,
-            "temperature": temperature,
-            "top_p": top_p,
-            "n": 1,
-            "stream": True,
-            "presence_penalty": 0,
-            "frequency_penalty": 0,
-        }
-
-    chat_counter += 1
-
-    history.append(inputs)
-    token_counter = 0
-    partial_words = ""
-    counter = 0
-
+
+# --- Configuration ---
+# <<< --- Updated Model ID --- >>>
+MODEL_ID = "Alfaxad/gemma2-2b-swahili-it"
+# <<< ---------------------- >>>
+
+# Use bf16 for performance on compatible GPUs, otherwise fp32
+TORCH_DTYPE = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
+
+# device_map="auto" will handle placing the model on GPU if available, otherwise CPU.
+# For a 2B model, this works perfectly and is simpler than manual device handling.
+device_map = "auto"
+print(f"--- Using model: {MODEL_ID} ---")
+print(f"--- Using dtype: {TORCH_DTYPE} ---")
+print(f"--- Using device_map: {device_map} ---")
+
+# --- Load Model and Tokenizer ---
+# Use the HF_TOKEN environment variable in Space secrets if the model requires authentication
+# access_token = os.getenv("HF_TOKEN")
+
+try:
+    print("Loading tokenizer...")
+    # trust_remote_code=True might be needed depending on the model's specifics
+    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)  # , token=access_token)
+
+    print("Loading model...")
+    # device_map="auto" handles GPU/CPU placement
+    # torch_dtype optimizes for performance/memory on GPU
+    # Quantization options (load_in_8bit/4bit) are removed as they are unnecessary for a 2B model on A10G/CPU
+    model = AutoModelForCausalLM.from_pretrained(
+        MODEL_ID,
+        torch_dtype=TORCH_DTYPE,
+        device_map=device_map,
+        # token=access_token,  # Uncomment if needed
+        # trust_remote_code=True,  # Uncomment if required by the model
+    )
+    print("Model loaded successfully.")
+
+    # <<< --- Set model to evaluation mode --- >>>
+    model.eval()
+    print("Model set to evaluation mode.")
+    # <<< ----------------------------------- >>>
+
+except Exception as e:
+    print(f"Error loading model: {e}")
+    # Use the updated MODEL_ID in the error message
+    raise gr.Error(f"Failed to load model {MODEL_ID}. Error: {e}")
+
+# --- Define Inference Function ---
+# Use inference mode decorator for efficiency
+@torch.inference_mode()
+def generate_swahili_text(prompt, max_new_tokens=150, temperature=0.7, top_p=0.9):
+    """
+    Generates text continuation in Swahili using the loaded model.
+    """
+    print(f"\n--- Received Prompt: ---\n{prompt}\n-----------------------")
+
+    # Tokenize input - device placement handled by device_map="auto"
+    inputs = tokenizer(prompt, return_tensors="pt")
+
+    print("Generating response...")
     try:
-        … (removed lines 82-97 were not rendered in the diff view)
-            if chunk.decode():
-                chunk = chunk.decode()
-                # decode each line as response data is in bytes
-                if len(chunk) > 12 and "content" in json.loads(chunk[6:])['choices'][0]['delta']:
-                    partial_words = partial_words + json.loads(chunk[6:])['choices'][0]["delta"]["content"]
-                    if token_counter == 0:
-                        history.append(" " + partial_words)
-                    else:
-                        history[-1] = partial_words
-                    token_counter += 1
-                    yield [(history[i], history[i + 1]) for i in range(0, len(history) - 1, 2)], history, chat_counter, response, gr.update(interactive=False), gr.update(interactive=False)  # resembles {chatbot: chat, state: history}
+        # Generate text
+        outputs = model.generate(
+            **inputs,
+            max_new_tokens=max_new_tokens,
+            temperature=temperature,
+            top_p=top_p,
+            do_sample=True,
+            pad_token_id=tokenizer.eos_token_id
+        )
+
+        # Decode the generated tokens
+        result = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+        print(f"--- Generated Output: ---\n{result}\n------------------------")
+        return result
+
     except Exception as e:
-        print
-        … (removed lines 111-174 were not rendered in the diff view)
-            <p>By continuing to use our app, you provide your explicit consent to the collection, use, and potential sharing of your data as described above. If you do not agree with our data collection, use, and sharing practices, please do not use our app.</p>
-            </div>
-            """)
-            accept_button = gr.Button("I Agree")
-
-            def enable_inputs():
-                return gr.update(visible=False), gr.update(visible=True)
-
-            accept_button.click(None, None, accept_checkbox, js=js, queue=False)
-            accept_checkbox.change(fn=enable_inputs, inputs=[], outputs=[user_consent_block, main_block], queue=False)
-
-    inputs.submit(reset_textbox, [], [inputs, b1], queue=False)
-    inputs.submit(predict, [inputs, top_p, temperature, chat_counter, chatbot, state], [chatbot, state, chat_counter, server_status_code, inputs, b1],)  # openai_api_key
-    b1.click(reset_textbox, [], [inputs, b1], queue=False)
-    b1.click(predict, [inputs, top_p, temperature, chat_counter, chatbot, state], [chatbot, state, chat_counter, server_status_code, inputs, b1],)  # openai_api_key
-
-demo.queue(max_size=10, default_concurrency_limit=NUM_THREADS, api_open=False).launch(share=False)
+        print(f"Error during generation: {e}")
+        return f"Samahani, kosa limetokea wakati wa kutengeneza maandishi. (Sorry, an error occurred during text generation: {e})"
+
+# --- Create Gradio Interface ---
+theme = gr.themes.Default(primary_hue="blue", secondary_hue="neutral").set(
+    body_background_fill="#f0f0f0",
+)
+
+with gr.Blocks(theme=theme) as demo:
+    # <<< --- Updated Markdown Description --- >>>
+    gr.Markdown(f"""
+    # 🇹🇿🇰🇪🇺🇬 Gemma-2 **2B** Swahili IT Demo 🇰🇪🇺🇬🇹🇿
+
+    This Space runs the `Alfaxad/gemma2-2b-swahili-it` (2 Billion parameter) model.
+
+    Enter a prompt in Swahili, and the model will generate a continuation.
+    *This 2B model should load and respond relatively quickly.*
+    """)
+    # <<< ---------------------------------- >>>
+
+    with gr.Row():
+        with gr.Column(scale=2):
+            prompt_input = gr.Textbox(
+                label="Andika Maandishi Yako Hapa (Enter Your Prompt Here)",
+                placeholder="Mfano: Habari za asubuhi! Leo hali ya hewa ikoje Nairobi?",
+                lines=4
+            )
+            submit_button = gr.Button("Tengeneza Maandishi (Generate Text)", variant="primary")
+
+        with gr.Column(scale=3):
+            output_text = gr.Textbox(
+                label="Matokeo (Output)",
+                lines=10,
+                interactive=False
+            )
+
+    with gr.Accordion("Advanced Settings", open=False):
+        max_tokens_slider = gr.Slider(minimum=10, maximum=500, value=150, step=10, label="Max New Tokens")
+        temperature_slider = gr.Slider(minimum=0.1, maximum=1.5, value=0.7, step=0.1, label="Temperature (Randomness)")
+        top_p_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top-P (Nucleus Sampling)")
+
+    # Link components to the function
+    submit_button.click(
+        fn=generate_swahili_text,
+        inputs=[prompt_input, max_tokens_slider, temperature_slider, top_p_slider],
+        outputs=output_text,
+        api_name="generate"
+    )
+
+    gr.Examples(
+        examples=[
+            ["Kenya ni nchi iliyoko wapi?", 100, 0.7, 0.9],
+            ["Niambie kuhusu historia ya Zanzibar.", 200, 0.8, 0.95],
+            ["Eleza maana ya methali 'Haraka haraka haina baraka'.", 150, 0.6, 0.9],
+        ],
+        inputs=[prompt_input, max_tokens_slider, temperature_slider, top_p_slider],
+        outputs=output_text,
+        fn=generate_swahili_text,
+        cache_examples=False
+    )
+
+# --- Launch the Gradio App ---
+print("Launching Gradio interface...")
+demo.queue().launch()
+print("Gradio interface launched.")
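Review note: `generate_swahili_text` tokenizes the raw prompt, but `Alfaxad/gemma2-2b-swahili-it` is an instruction-tuned checkpoint, and such models generally respond better when the prompt is wrapped in their chat template. A minimal sketch of what the tokenization step could look like, assuming the model repo ships a Gemma-2 chat template (`build_inputs` is a hypothetical helper, not part of this PR):

```python
# Sketch only: prefer the tokenizer's chat template when one is available, so
# the instruction-tuned model sees the turn format it was trained on.
def build_inputs(tokenizer, prompt, device):
    if tokenizer.chat_template is not None:
        # Single user turn; add_generation_prompt=True appends the assistant
        # header so the model answers rather than continuing the question.
        return tokenizer.apply_chat_template(
            [{"role": "user", "content": prompt}],
            add_generation_prompt=True,
            return_tensors="pt",
        ).to(device)
    # Fallback: plain tokenization, as the PR does now.
    return tokenizer(prompt, return_tensors="pt").input_ids.to(device)

# Possible usage inside generate_swahili_text:
#   input_ids = build_inputs(tokenizer, prompt, model.device)
#   outputs = model.generate(input_ids, max_new_tokens=max_new_tokens, ...)
```

Moving the tensors to `model.device` also covers the GPU case explicitly instead of relying on `device_map="auto"` to reconcile CPU inputs with a GPU-resident model.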
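Relatedly, `tokenizer.decode(outputs[0], skip_special_tokens=True)` decodes the prompt together with the completion, so the output box echoes the user's input back. A small sketch of decoding only the newly generated tokens (names match the PR's function; a suggested drop-in for the decode step, not tested here):

```python
# Sketch only: slice off the prompt tokens before decoding, so the UI shows
# just the generated continuation rather than prompt + continuation.
def decode_new_tokens(tokenizer, inputs, outputs):
    prompt_len = inputs["input_ids"].shape[-1]  # number of prompt tokens
    completion = outputs[0][prompt_len:]        # generated token ids only
    return tokenizer.decode(completion, skip_special_tokens=True)
```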
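Finally, because the click handler registers `api_name="generate"`, the running Space should also expose a `/generate` endpoint for programmatic use. A hedged example with `gradio_client` (the Space id below is a placeholder, since the PR does not name the Space):

```python
from gradio_client import Client

client = Client("Alfaxad/SPACE_NAME")  # placeholder Space id, not from the PR
result = client.predict(
    "Kenya ni nchi iliyoko wapi?",  # prompt_input
    100,                            # max_new_tokens
    0.7,                            # temperature
    0.9,                            # top_p
    api_name="/generate",
)
print(result)
```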