import gradio as gr
import spaces
import torch
import logging
import time

from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.utils import logging as hf_logging
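
# Application logging: write DEBUG-level logs to /tmp/app.log so they can be
# inspected inside the running container.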
logging.basicConfig(
    filename="/tmp/app.log",
    level=logging.DEBUG,
    format="%(asctime)s %(levelname)s: %(message)s"
)

logging.info("Starting app.py logging")
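
# Mirror the transformers library's internal logs to a second file at DEBUG verbosity.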
hf_logging.set_verbosity_debug()
hf_logging.enable_default_handler()
hf_logging.enable_explicit_format()
hf_logging.add_handler(logging.FileHandler("/tmp/transformers.log"))

model_id = "futurehouse/ether0"
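
# Load the tokenizer and model once at import time; float16 halves memory use
# versus float32, and device_map="auto" places the weights on the available GPU.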
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.float16
)
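
# spaces.GPU requests a GPU for the duration of each call (ZeroGPU Spaces).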
@spaces.GPU
def chat_fn(prompt, max_tokens=512):
    t0 = time.time()
    max_tokens = min(int(max_tokens), 32_000)
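
    # Format the prompt with the model's chat template, generate, and log timings.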
    try:
        messages = [{"role": "user", "content": prompt}]
        chat_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

        inputs = tokenizer(chat_prompt, return_tensors="pt").to(model.device)
        t1 = time.time()
        logging.info(f"🧠 Tokenization complete in {t1 - t0:.2f}s")
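
        # Low temperature keeps sampling near-deterministic; using the EOS token
        # as pad_token_id is the usual fallback when no pad token is defined.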
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            do_sample=True,
            temperature=0.1,
            pad_token_id=tokenizer.eos_token_id
        )
        t2 = time.time()
        logging.info(f"⚡️ Generation complete in {t2 - t1:.2f}s (max_tokens={max_tokens})")
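
        # Drop the prompt tokens so only the newly generated text is decoded.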
        generated_text = tokenizer.decode(
            outputs[0][inputs['input_ids'].shape[1]:],
            skip_special_tokens=True
        )
        t3 = time.time()
        logging.info(f"🔓 Decoding complete in {t3 - t2:.2f}s (output length: {len(generated_text)})")
        return generated_text

    except Exception:
        logging.exception("❌ Exception during generation")
        return "⚠️ Generation failed"
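
# Single-turn UI; ssr_mode=False disables Gradio's server-side rendering.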
gr.Interface(
    fn=chat_fn,
    inputs=[
        gr.Textbox(label="prompt"),
        gr.Number(label="max_tokens", value=512, precision=0)
    ],
    outputs="text",
    title="Ether0"
).launch(ssr_mode=False)