import os
import time

import gradio as gr
import torch
from huggingface_hub import login
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Authenticate with the Hugging Face Hub using a token from the environment.
login(token=os.environ.get("HF_TOKEN"))

model_id = "kristianfischerai12345/fischgpt-sft"

print("Loading FischGPT model...")
model = GPT2LMHeadModel.from_pretrained(model_id)
tokenizer = GPT2Tokenizer.from_pretrained(model_id)

# GPT-2 tokenizers ship without a pad token; reuse EOS so generate() can pad.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

model.eval()
print("Model loaded successfully!")


def generate_api(user_message, temperature=0.8, max_length=150, top_p=0.9):
    """Generate a reply for `user_message` and return it with timing metadata."""
    if not user_message or not user_message.strip():
        return {"error": "Empty message", "response": None, "metadata": None}

    try:
        # Wrap the message in the chat template the SFT model was trained on.
        prompt = f"<|user|>{user_message.strip()}<|assistant|>"
        inputs = tokenizer.encode(prompt, return_tensors="pt")

        start_time = time.time()
        with torch.no_grad():
            outputs = model.generate(
                inputs,
                max_length=max_length,
                temperature=float(temperature),
                top_p=float(top_p),
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                attention_mask=torch.ones_like(inputs),
            )
        generation_time = time.time() - start_time

        full_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Split defensively: if "<|assistant|>" is registered as a special
        # token, the decode above strips it, so fall back to the full text
        # instead of raising an IndexError.
        parts = full_text.split("<|assistant|>", 1)
        response = parts[1].strip() if len(parts) > 1 else full_text.strip()

        input_tokens = len(inputs[0])
        output_tokens = len(outputs[0])
        new_tokens = output_tokens - input_tokens
        tokens_per_sec = new_tokens / generation_time if generation_time > 0 else 0

        return {
            "error": None,
            "response": response,
            "metadata": {
                "input_tokens": input_tokens,
                "output_tokens": output_tokens,
                "new_tokens": new_tokens,
                "generation_time": round(generation_time, 3),
                "tokens_per_second": round(tokens_per_sec, 1),
                "model": "FischGPT-SFT",
                "parameters": {
                    "temperature": temperature,
                    "max_length": max_length,
                    "top_p": top_p,
                },
            },
        }
    except Exception as e:
        return {"error": str(e), "response": None, "metadata": None}
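

# Quick local sanity check (a sketch, not part of the served app): run in a
# Python REPL after the model loads to exercise generate_api directly.
#
#   result = generate_api("What is FischGPT?", temperature=0.7, max_length=120)
#   if result["error"] is None:
#       print(result["response"])
#       print(result["metadata"]["tokens_per_second"], "tokens/sec")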


def wake_up():
    """Cheap endpoint clients can ping to wake a sleeping Space."""
    return {"status": "awake"}


with gr.Blocks(title="FischGPT API") as app:
    gr.Markdown("### FischGPT API is running.")

    # Generation endpoint, exposed to API clients as /predict.
    gr.Interface(
        fn=generate_api,
        inputs=[
            gr.Textbox(label="User Message"),
            gr.Slider(0.1, 2.0, value=0.8, label="Temperature"),
            gr.Slider(50, 300, value=150, label="Max Length"),
            gr.Slider(0.1, 1.0, value=0.9, label="Top-p"),
        ],
        outputs=gr.JSON(label="Response"),
        api_name="predict",
    )

    # Health-check endpoint, exposed as /wake-up.
    gr.Interface(
        fn=wake_up,
        inputs=[],
        outputs=gr.JSON(label="Status"),
        api_name="wake-up",
    )

app.queue(api_open=True).launch()
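
# Example remote call with gradio_client (a sketch; the Space ID below is
# hypothetical, and `pip install gradio_client` is assumed):
#
#   from gradio_client import Client
#
#   client = Client("kristianfischerai12345/fischgpt-api")  # hypothetical Space ID
#   result = client.predict(
#       "Tell me about FischGPT.",  # user_message
#       0.8,                        # temperature
#       150,                        # max_length
#       0.9,                        # top_p
#       api_name="/predict",
#   )
#   print(result["response"])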