Spaces:
Sleeping
Sleeping
File size: 5,540 Bytes
22bd0c2 fcabc39 59b000d e23537b fcabc39 e23537b 59b000d e23537b fcabc39 e23537b 7e9b0bd e23537b be52360 7e9b0bd fcabc39 22bd0c2 fcabc39 59b000d fcabc39 7e9b0bd e23537b fcabc39 e23537b f7d8d0c 7e9b0bd e23537b f7d8d0c e23537b f7d8d0c e23537b f7d8d0c e23537b 7e9b0bd f7d8d0c e23537b 7e9b0bd 22bd0c2 7e9b0bd 22bd0c2 fcabc39 e23537b 22bd0c2 e23537b 22bd0c2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 |
import gradio as gr
import requests
import os
import json
import traceback
import sys
import re
# Tracing switch: debug_print() only writes to stderr while this is True.
ENABLE_TRACING = False

# Endpoint and credentials are read from the environment
# (RUNPOD_API_URL / RUNPOD_API_KEY); either may be None when unset.
API_BASE_URL = os.environ.get("RUNPOD_API_URL")
API_KEY = os.environ.get("RUNPOD_API_KEY")
API_URL = f"{API_BASE_URL}/chat/completions"

# HTTP headers attached to the chat-completions request.
headers = {
    "Authorization": f"Bearer {API_KEY}",
    "Content-Type": "application/json",
}
import re
def style_xml_content(text):
    """Turn raw model output into HTML that is safe to show in the chat window.

    All ``<`` and ``>`` characters are HTML-escaped so the model's text is
    displayed literally. Complete ``<thinking>...</thinking>`` and
    ``<reflection>...</reflection>`` sections are then wrapped in an open
    ``<details>`` element whose summary shows the tag name as visible text.
    A section whose closing tag has not arrived yet (mid-stream) is left with
    its raw opening tag, as before.

    Args:
        text: The model output accumulated so far.

    Returns:
        The styled HTML string.
    """
    def replace_content(match):
        tag = match.group(1)
        content = match.group(2)
        if tag == 'thinking':
            styled_content = f'<i><b>{content}</b></i>'
        else:  # 'reflection' -- the pattern below admits no other tag
            styled_content = f'<u><b>{content}</b></u>'
        # The tag names are shown as escaped text, not emitted as elements.
        return (
            f'<details open><summary>&lt;{tag}&gt;</summary>'
            f'{styled_content}<br>&lt;/{tag}&gt;</details>'
        )

    # Remove blacklisted text first: once escaped, the literal token would no
    # longer be present and the removal would silently stop working.
    text = text.replace("<|im_start|>", "")

    # Escape all < and > characters
    text = text.replace('<', '&lt;').replace('>', '&gt;')

    # Then, unescape the specific tags we want to process
    for tag in ('thinking', 'reflection'):
        text = text.replace(f'&lt;{tag}&gt;', f'<{tag}>')
        text = text.replace(f'&lt;/{tag}&gt;', f'</{tag}>')

    # Apply styling to content inside tags; DOTALL lets a section span lines
    return re.sub(
        r'<(thinking|reflection)>(.*?)</\1>',
        replace_content,
        text,
        flags=re.DOTALL,
    )
# Fixed system prompt: stream_response() sends it as the "system" message at
# the start of every request.
# NOTE(review): the wording is kept verbatim, grammar included; presumably it
# matches what the model expects -- confirm before editing the text.
SYSTEM_PROMPT = "You an advanced artificial intelligence system, capable of <thinking> and then creating a length <reflection>, where you ask if you were wrong? And then you correct yourself. Always use <reflection></reflection> unless it is a trivial or wikipedia question. Finally you output a brief and small to the point <output>."
def debug_print(*args, **kwargs):
    """Write a trace message to stderr, but only while ENABLE_TRACING is on.

    Positional and keyword arguments are forwarded to print(); the output
    stream is always sys.stderr.
    """
    if not ENABLE_TRACING:
        return
    print(*args, file=sys.stderr, **kwargs)
def parse_sse(data):
    """Decode one line of a server-sent-events stream into its JSON payload.

    Args:
        data: A single line of the stream, as bytes (decoded as UTF-8) or str.

    Returns:
        The parsed JSON value carried by a ``data:`` line, or None when the
        line is empty, is not a data field, is the ``[DONE]`` sentinel, or
        does not contain valid JSON.
    """
    if not data:
        return None

    if isinstance(data, (bytes, bytearray)):
        data = data.decode('utf-8')
    data = data.strip()
    debug_print(f"Raw SSE data: {data}")

    # The event-stream format permits "data:" with or without a following
    # space, so match the bare field name and strip the payload afterwards.
    prefix = 'data:'
    if not data.startswith(prefix):
        return None
    payload = data[len(prefix):].strip()

    # End-of-stream sentinel; there is nothing to parse.
    if payload == '[DONE]':
        return None

    try:
        return json.loads(payload)
    except json.JSONDecodeError:
        debug_print(f"Failed to parse SSE data: {payload}")
        return None
def stream_response(message, history, max_tokens, temperature, top_p):
    """Stream a chat completion from the API, yielding the styled reply so far.

    Args:
        message: The new user message.
        history: Iterable of (user, assistant) pairs from earlier turns.
        max_tokens: Forwarded to the API as "max_tokens".
        temperature: Forwarded to the API as "temperature".
        top_p: Forwarded to the API as "top_p".

    Yields:
        The entire reply accumulated so far, passed through
        style_xml_content(), once per non-empty content delta. On failure a
        single "Error: ..." or "Unexpected error: ..." string is yielded
        instead of raising.
    """
    messages = [{"role": "system", "content": SYSTEM_PROMPT}]
    for human, assistant in history:
        messages.append({"role": "user", "content": human})
        messages.append({"role": "assistant", "content": assistant})
    messages.append({"role": "user", "content": message})

    data = {
        "model": "forcemultiplier/fmx-reflective-2b",
        "messages": messages,
        "max_tokens": max_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "stream": True,
        "stop": ["</output>"],  # Add stop sequence
    }

    debug_print(f"Sending request to API: {API_URL}")
    debug_print(f"Request data: {json.dumps(data, indent=2)}")

    # (connect, read) timeouts in seconds. Without one, requests waits forever
    # on a stalled endpoint; the read value is generous to tolerate slow starts.
    request_timeout = (10, 300)

    try:
        # The context manager releases the connection even when the loop
        # breaks early or the caller stops consuming the generator.
        with requests.post(
            API_URL,
            headers=headers,
            json=data,
            stream=True,
            timeout=request_timeout,
        ) as response:
            debug_print(f"Response status code: {response.status_code}")
            debug_print(f"Response headers: {response.headers}")
            response.raise_for_status()

            accumulated_content = ""
            for line in response.iter_lines():
                if not line:
                    continue
                debug_print(f"Received line: {line}")

                parsed = parse_sse(line)
                if not parsed:
                    continue
                debug_print(f"Parsed SSE data: {parsed}")

                choices = parsed.get('choices') if isinstance(parsed, dict) else None
                if not choices:
                    continue

                # Some chunks carry no delta or a null content; skip them
                # rather than failing the whole stream.
                delta = choices[0].get('delta') or {}
                content = delta.get('content')
                if not content:
                    continue

                accumulated_content += content
                # Re-style the whole text each time: a tag may open in one
                # chunk and close in a later one.
                yield style_xml_content(accumulated_content)

                # Check if we've reached the stop sequence
                if accumulated_content.endswith("</output>"):
                    break
    except requests.exceptions.RequestException as e:
        debug_print(f"Request exception: {str(e)}")
        debug_print(f"Request exception traceback: {traceback.format_exc()}")
        yield f"Error: {str(e)}"
    except Exception as e:
        # Top-level boundary for the UI: report the failure in the chat
        # window instead of letting the generator raise.
        debug_print(f"Unexpected error: {str(e)}")
        debug_print(f"Error traceback: {traceback.format_exc()}")
        yield f"Unexpected error: {str(e)}"
# Chat UI: stream_response() handles each message; the three sliders are
# passed to it, in this order, as max_tokens, temperature and top_p.
demo = gr.ChatInterface(
    fn=stream_response,
    additional_inputs=[
        gr.Slider(label="Max tokens", minimum=1, maximum=2048, step=1, value=512),
        gr.Slider(label="Temperature", minimum=0.1, maximum=2.0, step=0.1, value=0.4),
        gr.Slider(label="Top-p (nucleus sampling)", minimum=0.1, maximum=1.0, step=0.05, value=0.83),
    ],
)
if __name__ == "__main__":
    # Startup diagnostics; debug_print() emits these only when
    # ENABLE_TRACING is on.
    debug_print(f"Starting application with API URL: {API_URL}")
    debug_print(f"Using system prompt: {SYSTEM_PROMPT}")
    debug_print(f"Tracing enabled: {ENABLE_TRACING}")
    demo.launch()