reflection / app.py
Sadmank's picture
Add system prompt, add correct stop tokens
19e04b5 verified
raw
history blame
6.28 kB
Hugging Face's logo
Hugging Face
Search models, datasets, users...
Models
Datasets
Spaces
Posts
Docs
Solutions
Pricing
Spaces:
Sadmank
/
reflection
private
Logs
App
Files
Community
Settings
reflection
/
app.py
Sadmank's picture
Sadmank
Update app.py
6c53f07
verified
3 minutes ago
raw
Copy download link
history
blame
edit
delete
No virus
5.92 kB
import gradio as gr
import requests
import os
import json
import traceback
import sys
import re
# Enable or disable tracing output (printed to stderr by debug_print).
ENABLE_TRACING = False

# Chat-completions endpoint and bearer key, both read from the environment.
API_BASE_URL = os.getenv("RUNPOD_API_URL")
API_KEY = os.getenv("RUNPOD_API_KEY")
# Strip a trailing "/" so a base URL such as "https://host/v1/" does not
# produce "https://host/v1//chat/completions". If RUNPOD_API_URL is unset,
# the URL no longer contains the literal text "None"; requests will reject
# the scheme-less URL when a request is made.
API_URL = f"{(API_BASE_URL or '').rstrip('/')}/chat/completions"
headers = {
    "Authorization": f"Bearer {API_KEY}",
    "Content-Type": "application/json",
}
import re
def style_xml_content(text):
    """Render the model's <thinking> and <reflection> sections as HTML.

    Every "<" and ">" in ``text`` is HTML-escaped first, so raw model output
    cannot inject markup. Each complete <thinking>...</thinking> or
    <reflection>...</reflection> pair then becomes an open ``<details>``
    block. Thinking bodies are italic-bold; reflection bodies are
    underlined-bold.

    Tag types are processed one at a time, so a <reflection> nested inside a
    <thinking> is styled too. A tag that is not (yet) closed, for example
    part-way through a streamed reply, stays visible as escaped text. Any
    "<|im_start|>" markers are removed at the end.

    Args:
        text: Raw model output, possibly a partial, still-streaming reply.

    Returns:
        The HTML string to display.
    """
    # (tag name, markup that opens the styled body, markup that closes it)
    sections = (
        ('thinking', '<i><b>', '</b></i>'),
        ('reflection', '<u><b>', '</b></u>'),
    )
    # Escape all < and > characters; tags are matched in their escaped form
    # below, so unmatched tags simply remain escaped (visible) text.
    styled_text = text.replace('<', '&lt;').replace('>', '&gt;')
    for tag, body_open, body_close in sections:
        open_tag = f'&lt;{tag}&gt;'
        close_tag = f'&lt;/{tag}&gt;'
        pattern = re.escape(open_tag) + r'(.*?)' + re.escape(close_tag)
        # \g<1> is the (already escaped) content between the two tags.
        replacement = (
            f'<details open><summary>{open_tag}</summary>'
            f'{body_open}\\g<1>{body_close}<br>{close_tag}</details>'
        )
        styled_text = re.sub(pattern, replacement, styled_text, flags=re.DOTALL)
    # Remove blacklisted text (the escaped form of the "<|im_start|>" token).
    return styled_text.replace("&lt;|im_start|&gt;", "")
# Fixed system prompt; stream_response sends it as the first ("system") message.
# NOTE(review): stream_response stops once the reply ends with "</output>",
# but this prompt never explicitly asks for a closing </output> tag.
SYSTEM_PROMPT = (
    "You are an advanced artificial intelligence system, capable of <thinking> "
    "and then creating a lengthy <reflection>, where you ask whether you were "
    "wrong and then correct yourself. Always use <reflection></reflection> "
    "unless it is a trivial or Wikipedia question. Finally, you output a brief, "
    "to-the-point <output>."
)
def debug_print(*args, **kwargs):
    """Print diagnostics to stderr, but only while ENABLE_TRACING is true.

    Takes the same arguments as print(), except ``file``: output always goes
    to sys.stderr.
    """
    if not ENABLE_TRACING:
        return
    print(*args, file=sys.stderr, **kwargs)
def parse_sse(data):
    """Parse one line of a server-sent-events (SSE) stream as JSON.

    Args:
        data: One line from ``response.iter_lines()``. This is normally bytes,
            but an already-decoded str is accepted as well.

    Returns:
        The decoded JSON value of the line's payload. Returns None when the
        line is empty, is an SSE comment (starts with ":"), is the ``[DONE]``
        end-of-stream marker, or does not contain valid JSON.
    """
    if not data:
        return None
    if isinstance(data, bytes):
        data = data.decode('utf-8')
    data = data.strip()
    debug_print(f"Raw SSE data: {data}")
    if data.startswith(':'):
        # SSE comment line (commonly a keep-alive); it carries no payload.
        return None
    if data.startswith('data:'):
        # The SSE format allows "data:" with or without a following space.
        data = data[len('data:'):].lstrip()
    if data == '[DONE]':
        return None
    try:
        return json.loads(data)
    except json.JSONDecodeError:
        debug_print(f"Failed to parse SSE data: {data}")
        return None
def _unstyle(text):
    """Approximately undo style_xml_content so a displayed reply can be resent.

    The chat history holds what stream_response yielded, which is HTML from
    style_xml_content. The tags are stripped and the &lt;/&gt; escapes are
    decoded so a previous reply goes back to the model as plain
    "<thinking>...</thinking>" text rather than HTML.
    """
    import html
    return html.unescape(re.sub(r'<[^>]+>', '', text))


def _build_messages(message, history):
    """Build the request's message list: system prompt, history, new message.

    ``history`` may be in Gradio's pair format ([user, assistant]) or its
    "messages" format ({"role": ..., "content": ...}). None means no history.
    """
    messages = [{"role": "system", "content": SYSTEM_PROMPT}]
    for turn in history or []:
        if isinstance(turn, dict):
            role = turn.get("role")
            content = turn.get("content")
            # Skip non-text turns (e.g. file attachments) and unknown roles.
            if role not in ("user", "assistant") or not isinstance(content, str):
                continue
            if role == "assistant":
                content = _unstyle(content)
            messages.append({"role": role, "content": content})
        else:
            human, assistant = turn
            messages.append({"role": "user", "content": human})
            # A turn whose reply never arrived can have assistant=None; send
            # "" so user/assistant roles keep alternating.
            reply = assistant if isinstance(assistant, str) else ""
            messages.append({"role": "assistant", "content": _unstyle(reply)})
    messages.append({"role": "user", "content": message})
    return messages


def _delta_content(chunk):
    """Return the text delta carried by one streamed chunk ("" if it has none)."""
    if not isinstance(chunk, dict):
        return ""
    choices = chunk.get("choices") or []
    if not choices:
        return ""
    # A chunk may lack "delta" or carry a null "content" (for example a final
    # chunk that only reports finish_reason); treat both as "no text".
    delta = choices[0].get("delta") or {}
    return delta.get("content") or ""


def stream_response(message, history, max_tokens, temperature, top_p):
    """gr.ChatInterface callback: stream a chat completion as styled HTML.

    Args:
        message: The new user message.
        history: Previous turns as supplied by gr.ChatInterface.
        max_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Yields:
        style_xml_content() applied to everything received so far. On
        failure, a single "Error: ..." string (request/HTTP errors) or
        "Unexpected error: ..." string (anything else).
    """
    data = {
        "model": "forcemultiplier/fmx-reflective-2b",
        "messages": _build_messages(message, history),
        "max_tokens": max_tokens,
        "temperature": temperature,
        # NOTE(review): "system" is a non-standard top-level field, and its
        # text differs from SYSTEM_PROMPT (already sent as the first message).
        # Confirm whether the backend reads it at all.
        "system": "You are a world-class AI system, capable of complex reasoning and reflection. Reason through the query inside <thinking> tags, and then provide your final response inside <output> tags. If you detect that you made a mistake in your reasoning at any point, correct yourself inside <reflection> tags.",
        "top_p": top_p,
        "stream": True,
        # Chat-template control tokens that must end generation.
        "stop": ["<|start_header_id|>", "<|end_header_id|>", "<|eot_id|>"],
    }
    # (connect, read) timeout in seconds. The read timeout is generous in case
    # the endpoint is slow to produce its first token; adjust as needed.
    request_timeout = (30, 600)
    debug_print(f"Sending request to API: {API_URL}")
    debug_print(f"Request data: {json.dumps(data, indent=2)}")
    accumulated_content = ""
    try:
        # Using the response as a context manager closes the streamed
        # connection on an early break, when Gradio closes this generator
        # (e.g. the user stops generation), and on errors.
        with requests.post(
            API_URL, headers=headers, json=data, stream=True, timeout=request_timeout
        ) as response:
            debug_print(f"Response status code: {response.status_code}")
            debug_print(f"Response headers: {response.headers}")
            response.raise_for_status()
            for line in response.iter_lines():
                if not line:
                    continue
                debug_print(f"Received line: {line}")
                parsed = parse_sse(line)
                if not parsed:
                    continue
                debug_print(f"Parsed SSE data: {parsed}")
                content = _delta_content(parsed)
                if not content:
                    continue
                accumulated_content += content
                yield style_xml_content(accumulated_content)
                # Stop once the final answer is closed; ignore trailing whitespace.
                if accumulated_content.rstrip().endswith("</output>"):
                    break
    except requests.exceptions.RequestException as e:
        debug_print(f"Request exception: {str(e)}")
        debug_print(f"Request exception traceback: {traceback.format_exc()}")
        yield f"Error: {str(e)}"
    except Exception as e:
        debug_print(f"Unexpected error: {str(e)}")
        debug_print(f"Error traceback: {traceback.format_exc()}")
        yield f"Unexpected error: {str(e)}"
# Chat UI. The slider values are passed to stream_response after
# (message, history), in the order listed here: max_tokens, temperature, top_p.
demo = gr.ChatInterface(
    fn=stream_response,
    additional_inputs=[
        gr.Slider(label="Max tokens", minimum=1, maximum=2048, step=1, value=512),
        gr.Slider(label="Temperature", minimum=0.1, maximum=2.0, step=0.1, value=0.4),
        # NOTE(review): the default 0.83 is not a multiple of step 0.05.
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.1,
            maximum=1.0,
            step=0.05,
            value=0.83,
        ),
    ],
)
if __name__ == "__main__":
debug_print(f"Starting application with API URL: {API_URL}")
debug_print(f"Using system prompt: {SYSTEM_PROMPT}")
debug_print(f"Tracing enabled: {ENABLE_TRACING}")
demo.launch()