#!/usr/bin/env python3
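"""
"What Comes Next": a Gradio app in which Llama-3.2-3B-Instruct, running on CPU,
slowly reveals a continuation of a randomly chosen prompt (one token every
SECS_PER_TOKEN seconds) while visitors submit predictions of how the text will
continue.
"""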

import os, json, time, random, threading, logging
from datetime import datetime, timezone
import torch

# Use every available CPU core for both intra-op and inter-op parallelism.
torch.set_num_threads(os.cpu_count())
torch.set_num_interop_threads(os.cpu_count())

import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM

MODEL_NAME = "meta-llama/Llama-3.2-3B-Instruct"
PROMPTS_PATH = "full_prompts.json"
STATE_PATH = "current_state.json"
DATA_PATH = "data.json"
TOKENS_PER_PROMPT = 2048   # tokens to generate before moving on to a new prompt
SECS_PER_TOKEN = 15        # pause between single-token generation steps
TEMP = 0.9                 # sampling temperature
TOP_P = 0.95               # nucleus sampling threshold
MAX_CTX = 8192             # maximum number of tokens fed back to the model

logging.basicConfig(level=logging.INFO)
log = logging.getLogger()

def _rj(p, d):
    """Read JSON from path `p`, returning default `d` if the file is missing or invalid."""
    try:
        with open(p, encoding="utf-8") as f:
            return json.load(f)
    except (OSError, json.JSONDecodeError):
        return d


def _aw(p, o):
    """Atomically write `o` as JSON: write to a temp file, then rename it over `p`."""
    t = p + ".tmp"
    with open(t, "w", encoding="utf-8") as f:
        json.dump(o, f, ensure_ascii=False, indent=2)
    os.replace(t, p)

prompts = _rj(PROMPTS_PATH, [])
if not prompts:
    raise RuntimeError(f"No prompts found in {PROMPTS_PATH}")

# Hugging Face access token used when downloading the model (needed for gated repos).
tok = os.environ.get("HF_READ_TOKEN")
log.info("Loading model...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=tok)
# Truncate from the left so the most recent context is kept once it exceeds MAX_CTX.
tokenizer.truncation_side = "left"
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float32,
    low_cpu_mem_usage=False,
    token=tok
)
model.to("cpu")
model.eval()
log.info("Model is ready.")

# Serializes state/file writes between the background generation thread and Gradio callbacks.
lock = threading.Lock()

def _init():
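    """Return the active state, starting a fresh random prompt if none is in progress."""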
    state = _rj(STATE_PATH, {})
    if not state or state.get("finished"):
        idx = random.randrange(len(prompts))
        state = {"i": idx, "p": prompts[idx], "g": "", "c": 0, "t": time.time(), "finished": False}
        _aw(STATE_PATH, state)
    return state

def _es(start_time):
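    """Format the time elapsed since `start_time` as 'Xh Ym Zs'."""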
    elapsed = int(time.time() - start_time)
    h, rem = divmod(elapsed, 3600)
    m, s = divmod(rem, 60)
    return f"{h}h {m}m {s}s"

def _loop():
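    """Background worker: append one sampled token to the current generation every
    SECS_PER_TOKEN seconds and persist the state to disk."""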
    while True:
        with lock:
            st = _init()
        if st["finished"]:
            time.sleep(SECS_PER_TOKEN)
            continue
        context = st["p"] + st["g"]
        # Re-tokenize the prompt plus everything generated so far and sample one more token.
        enc = tokenizer(context, return_tensors="pt", truncation=True, max_length=MAX_CTX)
        with torch.no_grad():
            out = model.generate(
                enc.input_ids,
                attention_mask=enc.attention_mask,
                max_new_tokens=1,
                do_sample=True,
                temperature=TEMP,
                top_p=TOP_P,
                pad_token_id=tokenizer.eos_token_id,
            )
        next_token = tokenizer.decode(out[0, -1], skip_special_tokens=True, clean_up_tokenization_spaces=False)
        with lock:
            st["g"] += next_token
            st["c"] += 1
            if st["c"] >= TOKENS_PER_PROMPT:
                st["finished"] = True
            _aw(STATE_PATH, st)
        time.sleep(SECS_PER_TOKEN)

threading.Thread(target=_loop, daemon=True).start()

def _fetch():
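    """Return the current prompt, the text generated so far, and the elapsed time."""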
    state = _rj(STATE_PATH, {})
    if not state:
        return "...", "", "0h 0m 0s"
    return state["p"], state["g"], _es(state["t"])

def _submit_prediction(detailed, summary):
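    """Validate the visitor's prediction and append it, with a snapshot of the oracle
    state, as one JSON line to DATA_PATH."""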
    det = detailed.strip()
    if not det:
        return gr.update(value="Please enter at least a detailed prediction."), gr.update(value=""), gr.update(value="")
    prompt_text, oracle_resp, elapsed = _fetch()
    record = {
        "ts": datetime.now(timezone.utc).isoformat(),
        "prompt": prompt_text,
        "time": elapsed,
        "resp": oracle_resp,
        "prediction": det,
        "summary": summary.strip()
    }
    with lock:
        # Append the record as a single JSON line.
        with open(DATA_PATH, "a", encoding="utf-8") as f:
            f.write(json.dumps(record, ensure_ascii=False) + "\n")
    return gr.update(value="Prediction logged!"), gr.update(value=""), gr.update(value="")

with gr.Blocks(theme="darkdefault") as demo:
    gr.Markdown(
        "# What Comes Next\n"
        "Enter what you think will come next in the text.\n"
        "Provide a detailed continuation and optionally a brief summary for context."
    )
    # Lay the prompt, oracle output, and elapsed time out side by side. The components
    # must be instantiated inside the Row; creating them earlier and merely naming them
    # inside it would leave them stacked above the Row.
    with gr.Row():
        prompt_md = gr.Markdown()
        oracle_output = gr.Textbox(lines=10, interactive=False, label="Oracle Response")
        time_info = gr.Textbox(interactive=False, label="Elapsed Time")

    detailed = gr.Textbox(
        label="Your Detailed Prediction",
        placeholder="Enter the full text continuation you expect...",
        lines=3
    )
    summary = gr.Textbox(
        label="Prediction Summary (Optional)",
        placeholder="Optionally, summarize your prediction in a few words...",
        lines=2
    )
    status = gr.Textbox(interactive=False, label="Status")
    submit_btn = gr.Button("Submit Prediction")
    refresh_btn = gr.Button("Refresh Oracle")

    demo.load(_fetch, outputs=[prompt_md, oracle_output, time_info])
    refresh_btn.click(_fetch, outputs=[prompt_md, oracle_output, time_info])
    submit_btn.click(
        _submit_prediction,
        inputs=[detailed, summary],
        outputs=[status, detailed, summary]
    )

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)