#!/usr/bin/env python3
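"""What Comes Next: a Gradio Space in which a CPU-hosted "oracle" model slowly
generates one token every SECS_PER_TOKEN seconds while visitors submit
predictions of how the text will continue. Predictions are appended to a
JSON Lines log."""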
import os, json, time, random, threading, logging
from datetime import datetime, timezone

import torch

# Use every available CPU core for both intra- and inter-op parallelism;
# this must run before any parallel torch work starts.
torch.set_num_threads(os.cpu_count())
torch.set_num_interop_threads(os.cpu_count())

import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
MODEL_NAME = "meta-llama/Llama-3.2-3B-Instruct"
PROMPTS_PATH = "full_prompts.json"   # JSON array of prompt strings
STATE_PATH = "current_state.json"    # persisted generation state (written atomically)
DATA_PATH = "data.json"              # prediction log, written as JSON Lines
TOKENS_PER_PROMPT = 2048             # tokens to generate before rolling a new prompt
SECS_PER_TOKEN = 15                  # pacing: one token every 15 seconds
TEMP = 0.9                           # sampling temperature
TOP_P = 0.95                         # nucleus sampling cutoff
MAX_CTX = 8192                       # max context length fed to the model
logging.basicConfig(level=logging.INFO)
log = logging.getLogger(__name__)
def _rj(p, d):
    """Read JSON from path `p`, returning default `d` on any failure."""
    try:
        with open(p, encoding="utf-8") as f:
            return json.load(f)
    except (OSError, json.JSONDecodeError):
        return d

def _aw(p, o):
    """Atomically write object `o` as JSON to path `p` via a temp file."""
    t = p + ".tmp"
    with open(t, "w", encoding="utf-8") as f:
        json.dump(o, f, ensure_ascii=False, indent=2)
    os.replace(t, p)
prompts = _rj(PROMPTS_PATH, [])
if not prompts:
    raise RuntimeError(f"No prompts found in {PROMPTS_PATH}")
tok = os.environ.get("HF_READ_TOKEN")

log.info("Loading model...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=tok)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float32,   # full precision; this Space runs on CPU
    low_cpu_mem_usage=False,
    token=tok,
)
model.to("cpu")
model.eval()
log.info("Model is ready.")
# Guards the shared state file and the prediction log across threads.
lock = threading.Lock()

def _init():
    """Load the current state, rolling a fresh random prompt if the previous
    one finished (or if no state exists yet)."""
    state = _rj(STATE_PATH, {})
    if not state or state.get("finished"):
        idx = random.randrange(len(prompts))
        # i: prompt index, p: prompt text, g: generated text so far,
        # c: token count, t: start timestamp, finished: completion flag
        state = {"i": idx, "p": prompts[idx], "g": "", "c": 0, "t": time.time(), "finished": False}
        _aw(STATE_PATH, state)
    return state
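# Illustrative shape of current_state.json mid-run (values are made up):
# {"i": 3, "p": "Once upon a time...", "g": " there was", "c": 2,
#  "t": 1700000000.0, "finished": false}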
def _es(start_time):
    """Format the seconds elapsed since `start_time` as 'Xh Ym Zs'."""
    elapsed = int(time.time() - start_time)
    h, rem = divmod(elapsed, 3600)
    m, s = divmod(rem, 60)
    return f"{h}h {m}m {s}s"
def _loop():
    """Background worker: generate one token every SECS_PER_TOKEN seconds,
    persisting progress after each token."""
    while True:
        with lock:
            st = _init()
        if st["finished"]:
            # Defensive: _init() rolls finished prompts into fresh ones,
            # so this branch should not normally be reached.
            time.sleep(SECS_PER_TOKEN)
            continue
        # Tokenize and sample the next token outside the lock so UI threads
        # are never blocked by the (slow) CPU forward pass.
        context = st["p"] + st["g"]
        enc = tokenizer(context, return_tensors="pt", truncation=True, max_length=MAX_CTX)
        with torch.no_grad():
            out = model.generate(
                enc.input_ids,
                attention_mask=enc.attention_mask,
                max_new_tokens=1,
                do_sample=True,
                temperature=TEMP,
                top_p=TOP_P,
            )
        next_token = tokenizer.decode(out[0, -1], skip_special_tokens=True, clean_up_tokenization_spaces=False)
        with lock:
            st["g"] += next_token
            st["c"] += 1
            if st["c"] >= TOKENS_PER_PROMPT:
                st["finished"] = True
            _aw(STATE_PATH, st)
        time.sleep(SECS_PER_TOKEN)
# Run the generator in the background; daemon=True so it dies with the process.
threading.Thread(target=_loop, daemon=True).start()
def _fetch():
    """Return (prompt, generated text, elapsed time) for the UI."""
    state = _rj(STATE_PATH, {})
    if not state:
        return "...", "", "0h 0m 0s"
    return state["p"], state["g"], _es(state["t"])
def _submit_prediction(detailed, summary):
    det = detailed.strip()
    if not det:
        # Keep the user's inputs intact on validation failure.
        return gr.update(value="Please enter at least a detailed prediction."), gr.update(), gr.update()
    prompt_text, oracle_resp, elapsed = _fetch()
    record = {
        "ts": datetime.now(timezone.utc).isoformat(),
        "prompt": prompt_text,
        "time": elapsed,
        "resp": oracle_resp,
        "prediction": det,
        "summary": summary.strip(),
    }
    # Append one record per line (JSON Lines), despite the .json extension.
    with lock:
        with open(DATA_PATH, "a", encoding="utf-8") as f:
            f.write(json.dumps(record, ensure_ascii=False) + "\n")
    return gr.update(value="Prediction logged!"), gr.update(value=""), gr.update(value="")
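# A minimal sketch (not used by the app) of how the JSON Lines log written
# above could be read back for offline analysis. `load_predictions` is a
# hypothetical helper name, not something the original script defines.
def load_predictions(path=DATA_PATH):
    """Yield each logged prediction record from the JSON Lines file."""
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:
                yield json.loads(line)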
with gr.Blocks(theme="darkdefault") as demo:
    gr.Markdown(
        "# What Comes Next\n"
        "Enter what you think will come next in the text.\n"
        "Provide a detailed continuation and optionally a brief summary for context."
    )
    # Components must be created inside the layout context to be placed in it;
    # a bare tuple of already-created components is a no-op in Gradio.
    with gr.Row():
        prompt_md = gr.Markdown()
        oracle_output = gr.Textbox(lines=10, interactive=False, label="Oracle Response")
        time_info = gr.Textbox(interactive=False, label="Elapsed Time")
    detailed = gr.Textbox(
        label="Your Detailed Prediction",
        placeholder="Enter the full text continuation you expect...",
        lines=3,
    )
    summary = gr.Textbox(
        label="Prediction Summary (Optional)",
        placeholder="Optionally, summarize your prediction in a few words...",
        lines=2,
    )
    status = gr.Textbox(interactive=False, label="Status")
    submit_btn = gr.Button("Submit Prediction")
    refresh_btn = gr.Button("Refresh Oracle")
    demo.load(_fetch, outputs=[prompt_md, oracle_output, time_info])
    refresh_btn.click(_fetch, outputs=[prompt_md, oracle_output, time_info])
    submit_btn.click(
        _submit_prediction,
        inputs=[detailed, summary],
        outputs=[status, detailed, summary],
    )
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)