what-comes-next / app.py
ProCreations's picture
Update app.py
7f977c5 verified
#!/usr/bin/env python3
import os, json, time, random, threading, logging
from datetime import datetime, timezone
import torch; torch.set_num_threads(os.cpu_count()); torch.set_num_interop_threads(os.cpu_count())
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
# --- Model and file locations ------------------------------------------------
MODEL_NAME = "meta-llama/Llama-3.2-3B-Instruct"
PROMPTS_PATH = "full_prompts.json"   # JSON list of prompt strings to sample from
STATE_PATH = "current_state.json"    # persisted state of the prompt being generated
DATA_PATH = "data.json"              # user predictions, appended one JSON object per line

# --- Generation schedule -------------------------------------------------------
TOKENS_PER_PROMPT = 2048  # tokens generated before a new prompt is chosen
SECS_PER_TOKEN = 15       # pause after each generated token

# --- Sampling parameters -------------------------------------------------------
TEMP = 0.9
TOP_P = 0.95
MAX_CTX = 8192            # maximum number of context tokens fed to the model

logging.basicConfig(level=logging.INFO)
log = logging.getLogger()  # root logger
def _rj(p, d):
try:
return json.load(open(p, encoding="utf-8"))
except:
return d
def _aw(p, o):
t = p + ".tmp"
open(t, "w", encoding="utf-8").write(json.dumps(o, ensure_ascii=False, indent=2))
os.replace(t, p)
# Load the prompt pool once at startup. _init() samples prompts from this list
# and the generation loop concatenates them with generated text, so fail fast
# here rather than crash later inside the background thread.
prompts = _rj(PROMPTS_PATH, [])
if not prompts:
    raise RuntimeError(f"No prompts found in {PROMPTS_PATH}")
if not isinstance(prompts, list) or not all(isinstance(p, str) for p in prompts):
    raise RuntimeError(f"{PROMPTS_PATH} must contain a JSON list of strings")
# Optional Hugging Face access token, passed to both from_pretrained calls
# (presumably needed because the meta-llama repository is gated).
tok = os.environ.get("HF_READ_TOKEN")

log.info("Loading model...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=tok)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    token=tok,
    low_cpu_mem_usage=False,
    torch_dtype=torch.float32,
)
# Inference only, on CPU.
model.to("cpu")
model.eval()
log.info("Model is ready.")

# Serialises access to STATE_PATH / DATA_PATH between the generation thread
# and the UI handlers.
lock = threading.Lock()
def _init():
    """Return the active generation state, starting a fresh prompt if needed.

    The persisted state at STATE_PATH is reused when it is a complete, not
    yet finished state dict. Otherwise (missing file, corrupt or foreign
    content, or a finished run) a random prompt is chosen and a new state is
    written. Callers should hold ``lock``, as ``_loop`` does.

    Returns:
        dict with keys ``i`` (prompt index), ``p`` (prompt text),
        ``g`` (generated text so far), ``c`` (tokens generated),
        ``t`` (start time, epoch seconds) and ``finished`` (bool).
    """
    state = _rj(STATE_PATH, {})
    required = {"i", "p", "g", "c", "t", "finished"}
    # An incomplete or non-dict state would otherwise raise later
    # (AttributeError here or KeyError in the worker), killing generation.
    if not isinstance(state, dict) or not required.issubset(state) or state["finished"]:
        idx = random.randrange(len(prompts))
        state = {"i": idx, "p": prompts[idx], "g": "", "c": 0, "t": time.time(), "finished": False}
        _aw(STATE_PATH, state)
    return state
def _es(start_time):
elapsed = int(time.time() - start_time)
h, rem = divmod(elapsed, 3600)
m, s = divmod(rem, 60)
return f"{h}h {m}m {s}s"
def _sample_next_token(context):
    """Sample one token continuing ``context`` and return its decoded text.

    Only the most recent MAX_CTX tokens are fed to the model. Passing
    ``truncation=True`` to the tokenizer would cut on its configured side
    (right by default in transformers), discarding exactly the freshly
    generated tail that the next token must continue.
    """
    ids = tokenizer(context, return_tensors="pt").input_ids
    if ids.shape[1] > MAX_CTX:
        ids = ids[:, -MAX_CTX:]
    with torch.no_grad():
        out = model.generate(
            ids,
            attention_mask=torch.ones_like(ids),  # single unpadded sequence
            max_new_tokens=1,
            do_sample=True,
            temperature=TEMP,
            top_p=TOP_P,
        )
    # NOTE(review): if an end-of-sequence token is sampled, skip_special_tokens
    # decodes it to "" and the caller still counts it as a token; confirm that
    # running to TOKENS_PER_PROMPT regardless of EOS is intended.
    return tokenizer.decode(out[0, -1], skip_special_tokens=True, clean_up_tokenization_spaces=False)


def _loop():
    """Background worker: append one sampled token to the shared state forever.

    Each iteration loads (or starts) the state under ``lock``, samples the
    next token outside the lock, then appends it, bumps the counter, marks
    the run finished after TOKENS_PER_PROMPT tokens, and persists the state
    under the lock again. It sleeps SECS_PER_TOKEN seconds between steps.
    """
    while True:
        try:
            with lock:
                st = _init()
            if st["finished"]:
                time.sleep(SECS_PER_TOKEN)
                continue
            # Only this thread mutates the state, so reading it unlocked is safe.
            next_token = _sample_next_token(st["p"] + st["g"])
            with lock:
                st["g"] += next_token
                st["c"] += 1
                if st["c"] >= TOKENS_PER_PROMPT:
                    st["finished"] = True
                _aw(STATE_PATH, st)
        except Exception:
            # Top-level boundary of a daemon thread: an uncaught error would
            # silently stop the oracle for the rest of the process lifetime.
            log.exception("Oracle step failed; retrying in %s s", SECS_PER_TOKEN)
        time.sleep(SECS_PER_TOKEN)


threading.Thread(target=_loop, daemon=True).start()
def _fetch():
    """Read the persisted oracle state for display in the UI.

    Returns:
        A ``(prompt, generated_text, elapsed)`` tuple, or
        ``("...", "", "0h 0m 0s")`` when no state has been written yet.
    """
    snapshot = _rj(STATE_PATH, {})
    if snapshot:
        return snapshot["p"], snapshot["g"], _es(snapshot["t"])
    return "...", "", "0h 0m 0s"
def _submit_prediction(detailed, summary):
    """Append the user's prediction, with a snapshot of the oracle, to DATA_PATH.

    Args:
        detailed: Required free-text continuation predicted by the user.
        summary: Optional short summary of the prediction.

    Returns:
        Gradio updates for ``(status, detailed, summary)``. On success both
        inputs are cleared. When ``detailed`` is blank, only the status
        changes, so an already-entered summary is not lost.
    """
    det = (detailed or "").strip()
    if not det:
        return gr.update(value="Please enter at least a detailed prediction."), gr.update(), gr.update()
    prompt_text, oracle_resp, elapsed = _fetch()
    record = {
        "ts": datetime.now(timezone.utc).isoformat(),
        "prompt": prompt_text,
        "time": elapsed,
        "resp": oracle_resp,
        "prediction": det,
        "summary": (summary or "").strip(),
    }
    line = json.dumps(record, ensure_ascii=False) + "\n"
    # One JSON object per line. The lock keeps concurrent submissions from
    # interleaving, and `with` guarantees the handle is closed.
    with lock, open(DATA_PATH, "a", encoding="utf-8") as f:
        f.write(line)
    return gr.update(value="Prediction logged!"), gr.update(value=""), gr.update(value="")
# NOTE(review): "darkdefault" is an older Gradio theme name; newer Gradio
# releases may not recognise it and fall back to the default theme with a
# warning. Confirm against the pinned Gradio version.
with gr.Blocks(theme="darkdefault") as demo:
    gr.Markdown(
        "# What Comes Next\n"
        "Enter what you think will come next in the text.\n"
        "Provide a detailed continuation and optionally a brief summary for context."
    )
    # Gradio places a component in the layout context where it is created, so
    # these must be constructed inside the Row to be displayed side by side.
    with gr.Row():
        prompt_md = gr.Markdown()
        oracle_output = gr.Textbox(lines=10, interactive=False, label="Oracle Response")
        time_info = gr.Textbox(interactive=False, label="Elapsed Time")
    detailed = gr.Textbox(
        label="Your Detailed Prediction",
        placeholder="Enter the full text continuation you expect...",
        lines=3,
    )
    summary = gr.Textbox(
        label="Prediction Summary (Optional)",
        placeholder="Optionally, summarize your prediction in a few words...",
        lines=2,
    )
    status = gr.Textbox(interactive=False, label="Status")
    submit_btn = gr.Button("Submit Prediction")
    refresh_btn = gr.Button("Refresh Oracle")

    # The oracle panel is populated on page load and on demand; it does not
    # auto-refresh.
    demo.load(_fetch, outputs=[prompt_md, oracle_output, time_info])
    refresh_btn.click(_fetch, outputs=[prompt_md, oracle_output, time_info])
    submit_btn.click(
        _submit_prediction,
        inputs=[detailed, summary],
        outputs=[status, detailed, summary],
    )
if __name__ == "__main__":
    # Bind on all interfaces (0.0.0.0) at port 7860 so other hosts can reach the app.
    demo.launch(server_port=7860, server_name="0.0.0.0")