burtenshaw HF Staff commited on
Commit
2148c1c
·
verified ·
1 Parent(s): cb5545e

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +329 -0
app.py ADDED
@@ -0,0 +1,329 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import gradio as gr
3
+ from app import demo as app
4
+ import os
5
+
6
+ _docs = {'WordleBoard': {'description': 'Interactive Wordle board component.', 'members': {'__init__': {'word_length': {'type': 'int', 'default': '5', 'description': None}, 'max_attempts': {'type': 'int', 'default': '6', 'description': None}, 'return': {'type': 'None', 'description': None}}, 'postprocess': {'value': {'type': 'typing.Union[\n gradio_wordleboard.wordleboard.PublicWordleState,\n typing.Dict,\n str,\n NoneType,\n][PublicWordleState, Dict, str, None]', 'description': None}}, 'preprocess': {'return': {'type': 'typing.Optional[typing.Dict][Dict, None]', 'description': "The preprocessed input data sent to the user's function in the backend."}, 'value': None}}, 'events': {}}, '__meta__': {'additional_interfaces': {'PublicWordleState': {'source': '@dataclass\nclass PublicWordleState:\n board: List[WordleRow]\n current_row: int\n status: str\n message: str\n max_rows: int', 'refs': ['WordleRow']}, 'WordleRow': {'source': '@dataclass\nclass WordleRow:\n letters: List[str] = field(\n default_factory=lambda: [""] * 5\n )\n statuses: List[TileStatus] = field(\n default_factory=lambda: ["empty"] * 5\n )'}}, 'user_fn_refs': {'WordleBoard': ['PublicWordleState']}}}
7
+
8
+ abs_path = os.path.join(os.path.dirname(__file__), "css.css")
9
+
10
+ with gr.Blocks(
11
+ css=abs_path,
12
+ theme=gr.themes.Default(
13
+ font_mono=[
14
+ gr.themes.GoogleFont("Inconsolata"),
15
+ "monospace",
16
+ ],
17
+ ),
18
+ ) as demo:
19
+ gr.Markdown(
20
+ """
21
+ # `gradio_wordleboard`
22
+
23
+ <div style="display: flex; gap: 7px;">
24
+ <img alt="Static Badge" src="https://img.shields.io/badge/version%20-%200.0.1%20-%20orange">
25
+ </div>
26
+
27
+ A custom Gradio component that renders and plays the Wordle word game
28
+ """, elem_classes=["md-custom"], header_links=True)
29
+ app.render()
30
+ gr.Markdown(
31
+ """
32
+ ## Installation
33
+
34
+ ```bash
35
+ pip install gradio_wordleboard
36
+ ```
37
+
38
+ ## Usage
39
+
40
+ ```python
41
+
42
+ from __future__ import annotations
43
+
44
+ import asyncio
45
+ import os
46
+ import re
47
+ from typing import AsyncIterator, Dict, List
48
+
49
+ import gradio as gr
50
+ from gradio_wordleboard import WordleBoard
51
+ from openai import AsyncOpenAI
52
+
53
+ from envs.textarena_env import TextArenaAction, TextArenaEnv
54
+ from envs.textarena_env.models import TextArenaMessage
55
+
56
+
57
+ API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
58
+ API_KEY = os.getenv("API_KEY") or os.getenv("HF_TOKEN")
59
+ MODEL = os.getenv("MODEL", "openai/gpt-oss-120b:novita")
60
+ MAX_TURNS = int(os.getenv("MAX_TURNS", "6"))
61
+ DOCKER_IMAGE = os.getenv("TEXTARENA_IMAGE", "textarena-env:latest")
62
+
63
+
64
+ def _format_history(messages: List[TextArenaMessage]) -> str:
65
+ lines: List[str] = []
66
+ for message in messages:
67
+ tag = message.category or "MESSAGE"
68
+ lines.append(f"[{tag}] {message.content}")
69
+ return "\n".join(lines)
70
+
71
+
72
+ def _make_user_prompt(prompt_text: str, messages: List[TextArenaMessage]) -> str:
73
+ history = _format_history(messages)
74
+ return (
75
+ f"Current prompt:\n{prompt_text}\n\n"
76
+ f"Conversation so far:\n{history}\n\n"
77
+ "Reply with your next guess enclosed in square brackets."
78
+ )
79
+
80
+
81
+ async def _generate_guesses(client: AsyncOpenAI, prompt: str, history: List[TextArenaMessage]) -> str:
82
+ response = await client.chat.completions.create(
83
+ model=MODEL,
84
+ messages=[
85
+ {
86
+ "role": "system",
87
+ "content": (
88
+ "You are an expert Wordle solver."
89
+ " Always respond with a single guess inside square brackets, e.g. [crane]."
90
+ " Use lowercase letters, exactly one five-letter word per reply."
91
+ " Reason about prior feedback before choosing the next guess."
92
+ " Words must be 5 letters long and real English words."
93
+ " Do not include any other text in your response."
94
+ " Do not repeat the same guess twice."
95
+ ),
96
+ },
97
+ {"role": "user", "content": _make_user_prompt(prompt, history)},
98
+ ],
99
+ max_tokens=64,
100
+ temperature=0.7,
101
+ )
102
+
103
+ content = response.choices[0].message.content
104
+ response_text = content.strip() if content else ""
105
+ print(f"Response text: {response_text}")
106
+ return response_text
107
+
108
+
109
+ async def _play_wordle(env: TextArenaEnv, client: AsyncOpenAI) -> AsyncIterator[Dict[str, str]]:
110
+ state = await asyncio.to_thread(env.reset)
111
+ observation = state.observation
112
+
113
+ for turn in range(1, MAX_TURNS + 1):
114
+ if state.done:
115
+ break
116
+
117
+ model_output = await _generate_guesses(client, observation.prompt, observation.messages)
118
+ guess = _extract_guess(model_output)
119
+
120
+ state = await asyncio.to_thread(env.step, TextArenaAction(message=guess))
121
+ observation = state.observation
122
+
123
+ feedback = _collect_feedback(observation.messages)
124
+ yield {"guess": guess, "feedback": feedback}
125
+
126
+ yield {
127
+ "guess": "",
128
+ "feedback": _collect_feedback(observation.messages),
129
+ }
130
+
131
+
132
+ def _extract_guess(text: str) -> str:
133
+ if not text:
134
+ return "[crane]"
135
+
136
+ match = re.search(r"\[([A-Za-z]{5})\]", text)
137
+ if match:
138
+ guess = match.group(1).lower()
139
+ return f"[{guess}]"
140
+
141
+ cleaned = re.sub(r"[^a-zA-Z]", "", text).lower()
142
+ if len(cleaned) >= 5:
143
+ return f"[{cleaned[:5]}]"
144
+
145
+ return "[crane]"
146
+
147
+
148
+ def _collect_feedback(messages: List[TextArenaMessage]) -> str:
149
+ parts: List[str] = []
150
+ for message in messages:
151
+ tag = message.category or "MESSAGE"
152
+ if tag.upper() in {"FEEDBACK", "SYSTEM", "MESSAGE"}:
153
+ parts.append(message.content.strip())
154
+ return "\n".join(parts).strip()
155
+
156
+
157
+ async def inference_handler(api_key: str) -> AsyncIterator[str]:
158
+ if not api_key:
159
+ raise RuntimeError("HF_TOKEN or API_KEY environment variable must be set.")
160
+
161
+ client = AsyncOpenAI(base_url=API_BASE_URL, api_key=api_key)
162
+ env = TextArenaEnv.from_docker_image(
163
+ DOCKER_IMAGE,
164
+ env_vars={
165
+ "TEXTARENA_ENV_ID": "Wordle-v0",
166
+ "TEXTARENA_NUM_PLAYERS": "1",
167
+ },
168
+ ports={8000: 8000},
169
+ )
170
+
171
+ try:
172
+ async for result in _play_wordle(env, client):
173
+ yield result["feedback"]
174
+ finally:
175
+ env.close()
176
+
177
+
178
+ wordle_component = WordleBoard()
179
+
180
+
181
+ async def run_inference() -> AsyncIterator[Dict]:
182
+ feedback_history: List[str] = []
183
+
184
+ async for feedback in inference_handler(API_KEY):
185
+ stripped = feedback.strip()
186
+ if not stripped:
187
+ continue
188
+
189
+ feedback_history.append(stripped)
190
+ combined_feedback = "\n\n".join(feedback_history)
191
+ state = wordle_component.parse_feedback(combined_feedback)
192
+ yield wordle_component.to_public_dict(state)
193
+
194
+ if not feedback_history:
195
+ yield wordle_component.to_public_dict(wordle_component.create_game_state())
196
+
197
+
198
+ with gr.Blocks() as demo:
199
+ gr.Markdown("# Wordle TextArena Inference Demo")
200
+
201
+ board = WordleBoard(value=wordle_component.to_public_dict(wordle_component.create_game_state()))
202
+ run_button = gr.Button("Run Inference", variant="primary")
203
+
204
+ run_button.click(
205
+ fn=run_inference,
206
+ inputs=None,
207
+ outputs=board,
208
+ show_progress=True,
209
+ api_name="run",
210
+ )
211
+
212
+ demo.queue()
213
+
214
+
215
+ if __name__ == "__main__":
216
+ if not API_KEY:
217
+ raise SystemExit("HF_TOKEN (or API_KEY) must be set to query the model.")
218
+
219
+ demo.launch()
220
+
221
+ ```
222
+ """, elem_classes=["md-custom"], header_links=True)
223
+
224
+
225
+ gr.Markdown("""
226
+ ## `WordleBoard`
227
+
228
+ ### Initialization
229
+ """, elem_classes=["md-custom"], header_links=True)
230
+
231
+ gr.ParamViewer(value=_docs["WordleBoard"]["members"]["__init__"], linkify=['PublicWordleState', 'WordleRow'])
232
+
233
+
234
+
235
+
236
+ gr.Markdown("""
237
+
238
+ ### User function
239
+
240
+ The impact on the users predict function varies depending on whether the component is used as an input or output for an event (or both).
241
+
242
+ - When used as an Input, the component only impacts the input signature of the user function.
243
+ - When used as an output, the component only impacts the return signature of the user function.
244
+
245
+ The code snippet below is accurate in cases where the component is used as both an input and an output.
246
+
247
+ - **As input:** Is passed, the preprocessed input data sent to the user's function in the backend.
248
+
249
+
250
+ ```python
251
+ def predict(
252
+ value: typing.Optional[typing.Dict][Dict, None]
253
+ ) -> typing.Union[
254
+ gradio_wordleboard.wordleboard.PublicWordleState,
255
+ typing.Dict,
256
+ str,
257
+ NoneType,
258
+ ][PublicWordleState, Dict, str, None]:
259
+ return value
260
+ ```
261
+ """, elem_classes=["md-custom", "WordleBoard-user-fn"], header_links=True)
262
+
263
+
264
+
265
+
266
+ code_PublicWordleState = gr.Markdown("""
267
+ ## `PublicWordleState`
268
+ ```python
269
+ @dataclass
270
+ class PublicWordleState:
271
+ board: List[WordleRow]
272
+ current_row: int
273
+ status: str
274
+ message: str
275
+ max_rows: int
276
+ ```""", elem_classes=["md-custom", "PublicWordleState"], header_links=True)
277
+
278
+ code_WordleRow = gr.Markdown("""
279
+ ## `WordleRow`
280
+ ```python
281
+ @dataclass
282
+ class WordleRow:
283
+ letters: List[str] = field(
284
+ default_factory=lambda: [""] * 5
285
+ )
286
+ statuses: List[TileStatus] = field(
287
+ default_factory=lambda: ["empty"] * 5
288
+ )
289
+ ```""", elem_classes=["md-custom", "WordleRow"], header_links=True)
290
+
291
+ demo.load(None, js=r"""function() {
292
+ const refs = {
293
+ PublicWordleState: ['WordleRow'],
294
+ WordleRow: [], };
295
+ const user_fn_refs = {
296
+ WordleBoard: ['PublicWordleState'], };
297
+ requestAnimationFrame(() => {
298
+
299
+ Object.entries(user_fn_refs).forEach(([key, refs]) => {
300
+ if (refs.length > 0) {
301
+ const el = document.querySelector(`.${key}-user-fn`);
302
+ if (!el) return;
303
+ refs.forEach(ref => {
304
+ el.innerHTML = el.innerHTML.replace(
305
+ new RegExp("\\b"+ref+"\\b", "g"),
306
+ `<a href="#h-${ref.toLowerCase()}">${ref}</a>`
307
+ );
308
+ })
309
+ }
310
+ })
311
+
312
+ Object.entries(refs).forEach(([key, refs]) => {
313
+ if (refs.length > 0) {
314
+ const el = document.querySelector(`.${key}`);
315
+ if (!el) return;
316
+ refs.forEach(ref => {
317
+ el.innerHTML = el.innerHTML.replace(
318
+ new RegExp("\\b"+ref+"\\b", "g"),
319
+ `<a href="#h-${ref.toLowerCase()}">${ref}</a>`
320
+ );
321
+ })
322
+ }
323
+ })
324
+ })
325
+ }
326
+
327
+ """)
328
+
329
+ demo.launch()