Spaces:
Sleeping
Sleeping
File size: 5,540 Bytes
01c5acc f231295 9a94757 01c5acc f231295 9a94757 01c5acc f231295 01c5acc 9a94757 f231295 01c5acc f231295 9a94757 f231295 9a94757 f231295 9a94757 f231295 9a94757 01c5acc 9a94757 01c5acc f231295 9a94757 01c5acc 9a94757 01c5acc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 |
import gradio as gr
import requests
import os
import json
import traceback
import sys
import re
# Enable or disable tracing (debug output to stderr via debug_print).
ENABLE_TRACING = False

# Set up the API endpoint and key from the environment.
API_BASE_URL = os.getenv("RUNPOD_API_URL")
API_KEY = os.getenv("RUNPOD_API_KEY")
if not API_BASE_URL or not API_KEY:
    # Warn at startup instead of silently building a "None/..." URL or a
    # "Bearer None" header that only surface later as opaque request errors.
    print(
        "Warning: RUNPOD_API_URL and/or RUNPOD_API_KEY is not set; "
        "chat requests will fail.",
        file=sys.stderr,
    )
# Strip a trailing slash so a base like ".../v1/" does not produce
# ".../v1//chat/completions", which many servers answer with 404.
_base_url = API_BASE_URL.rstrip("/") if API_BASE_URL else API_BASE_URL
API_URL = f"{_base_url}/chat/completions"
headers = {
    "Authorization": f"Bearer {API_KEY}",
    "Content-Type": "application/json",
}
import re
def style_xml_content(text):
    """Render raw model output as safe HTML with <thinking>/<reflection> highlighted.

    The whole text is HTML-escaped first, so any markup the model emits is
    shown literally instead of being interpreted by the browser. Complete
    ``<thinking>...</thinking>`` and ``<reflection>...</reflection>`` pairs
    are then turned into open ``<details>`` sections. Thinking bodies are
    bold-italic and reflection bodies are bold-underlined.

    Unclosed tags stay escaped and appear as literal text. This happens, for
    example, while the response is still streaming. The ``<|im_start|>``
    chat-template token is removed.

    Args:
        text: Raw (possibly partial) model output.

    Returns:
        An HTML string for the chat window.
    """
    import html

    def replace_content(match):
        # group(1) is the tag name, group(2) the (already escaped) body.
        tag, content = match.group(1), match.group(2)
        if tag == 'thinking':
            styled_content = f'<i><b>{content}</b></i>'
        else:  # 'reflection' -- the only other tag the pattern can match
            styled_content = f'<u><b>{content}</b></u>'
        # The tag names are shown as escaped text, not emitted as raw elements.
        return (
            f'<details open><summary>&lt;{tag}&gt;</summary>'
            f'{styled_content}<br>&lt;/{tag}&gt;</details>'
        )

    # Remove blacklisted text while it is still in raw (unescaped) form.
    text = text.replace("<|im_start|>", "")
    # Escape &, < and > so model output cannot inject HTML into the page.
    text = html.escape(text, quote=False)
    # Style only complete pairs of the tags we care about, matched in escaped form.
    return re.sub(
        r'&lt;(thinking|reflection)&gt;(.*?)&lt;/\1&gt;',
        replace_content,
        text,
        flags=re.DOTALL,
    )
# Fixed system prompt, sent as the first message of every chat request.
# NOTE(review): the wording (e.g. "You an", "a length") is kept verbatim --
# presumably it matches the prompt the model was tuned with; confirm before
# editing it.
SYSTEM_PROMPT = (
    "You an advanced artificial intelligence system, capable of <thinking> "
    "and then creating a length <reflection>, where you ask if you were wrong? "
    "And then you correct yourself. Always use <reflection></reflection> "
    "unless it is a trivial or wikipedia question. "
    "Finally you output a brief and small to the point <output>."
)
def debug_print(*args, **kwargs):
    """Forward the arguments to ``print`` on stderr, but only when tracing is on.

    Takes the same arguments as ``print`` except ``file``, which is always
    ``sys.stderr``. Does nothing while ``ENABLE_TRACING`` is false.
    """
    if not ENABLE_TRACING:
        return
    print(*args, file=sys.stderr, **kwargs)
def parse_sse(data):
    """Decode one Server-Sent Events line into its JSON payload.

    Args:
        data: One line of the event stream. Either ``bytes`` (what
            ``requests.Response.iter_lines`` yields) or ``str``.

    Returns:
        The object decoded from the line's JSON payload. Returns ``None`` for
        an empty line, for the ``[DONE]`` end-of-stream marker, or when the
        payload is not valid JSON.
    """
    if not data:
        return None
    if isinstance(data, bytes):
        # Undecodable bytes become U+FFFD instead of aborting the whole stream.
        data = data.decode('utf-8', errors='replace')
    data = data.strip()
    debug_print(f"Raw SSE data: {data}")
    # The SSE field prefix is "data:"; the space after the colon is optional
    # per the spec, so accept both "data: {...}" and "data:{...}".
    if data.startswith('data:'):
        data = data[len('data:'):].strip()
    if data == '[DONE]':
        return None
    try:
        return json.loads(data)
    except json.JSONDecodeError:
        debug_print(f"Failed to parse SSE data: {data}")
        return None
# (connect, read) timeout in seconds for the completion request. The read
# timeout bounds the wait for each streamed chunk, including the first one.
# NOTE(review): 300 s is a generous guess to tolerate backend cold starts; tune.
_REQUEST_TIMEOUT = (10, 300)


def _history_to_messages(history):
    """Convert Gradio chat history into OpenAI-style message dicts.

    Accepts two history shapes:

    - Pair format: ``[[user_text, assistant_text], ...]``.
    - Message format: ``[{"role": ..., "content": ...}, ...]``.

    Turns whose content is missing are skipped rather than sent as null.
    """
    # NOTE(review): assistant turns are presumably the HTML that
    # stream_response yielded, so that markup is re-sent to the model;
    # consider keeping the raw text instead -- confirm.
    messages = []
    for turn in history or []:
        if isinstance(turn, dict):
            role, content = turn.get("role"), turn.get("content")
            if role in ("user", "assistant") and isinstance(content, str):
                messages.append({"role": role, "content": content})
            continue
        human, assistant = turn
        if human is not None:
            messages.append({"role": "user", "content": human})
        if assistant is not None:
            messages.append({"role": "assistant", "content": assistant})
    return messages


def _extract_delta_text(chunk):
    """Return the text fragment carried by one streamed completion chunk, or ''."""
    if not isinstance(chunk, dict):
        return ""
    choices = chunk.get("choices") or []
    if not choices:
        return ""
    delta = choices[0].get("delta") or {}
    return delta.get("content") or ""


def stream_response(message, history, max_tokens, temperature, top_p):
    """Stream a chat completion and yield the styled reply as it grows.

    Args:
        message: The new user message.
        history: Prior conversation as provided by Gradio's ChatInterface.
        max_tokens: Maximum number of tokens to generate (coerced to int).
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Yields:
        The full reply so far, rendered by ``style_xml_content``. On failure,
        yields a single ``"Error: ..."`` or ``"Unexpected error: ..."`` string.
    """
    messages = [{"role": "system", "content": SYSTEM_PROMPT}]
    messages.extend(_history_to_messages(history))
    messages.append({"role": "user", "content": message})
    data = {
        "model": "forcemultiplier/fmx-reflective-2b",
        "messages": messages,
        # Sliders may deliver floats; the API field is an integer count.
        "max_tokens": int(max_tokens),
        "temperature": temperature,
        "top_p": top_p,
        "stream": True,
        # The server stops at </output>. Whether the stop string itself is
        # returned depends on the backend, hence the endswith check below.
        "stop": ["</output>"],
    }
    debug_print(f"Sending request to API: {API_URL}")
    debug_print(f"Request data: {json.dumps(data, indent=2)}")
    try:
        # The context manager closes the streamed connection even when we break
        # early or the consumer stops iterating this generator.
        with requests.post(
            API_URL, headers=headers, json=data, stream=True, timeout=_REQUEST_TIMEOUT
        ) as response:
            debug_print(f"Response status code: {response.status_code}")
            debug_print(f"Response headers: {response.headers}")
            response.raise_for_status()
            accumulated_content = ""
            for line in response.iter_lines():
                if not line:
                    continue
                debug_print(f"Received line: {line}")
                parsed = parse_sse(line)
                if not parsed:
                    continue
                debug_print(f"Parsed SSE data: {parsed}")
                content = _extract_delta_text(parsed)
                if not content:
                    continue
                accumulated_content += content
                yield style_xml_content(accumulated_content)
                # Check if we've reached the stop sequence.
                if accumulated_content.endswith("</output>"):
                    break
    except requests.exceptions.RequestException as e:
        debug_print(f"Request exception: {str(e)}")
        debug_print(f"Request exception traceback: {traceback.format_exc()}")
        yield f"Error: {str(e)}"
    except Exception as e:
        # Top-level UI boundary: report instead of crashing the chat handler.
        debug_print(f"Unexpected error: {str(e)}")
        debug_print(f"Error traceback: {traceback.format_exc()}")
        yield f"Unexpected error: {str(e)}"
# Chat UI. Each user message is streamed through stream_response. The three
# sliders supply its trailing max_tokens, temperature and top_p arguments, in
# that order.
demo = gr.ChatInterface(
    stream_response,
    additional_inputs=[
        gr.Slider(label="Max tokens", minimum=1, maximum=2048, step=1, value=512),
        gr.Slider(label="Temperature", minimum=0.1, maximum=2.0, step=0.1, value=0.4),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.1,
            maximum=1.0,
            step=0.05,
            value=0.83,
        ),
    ],
)
if __name__ == "__main__":
    # Startup diagnostics (only printed when tracing is enabled), then serve the UI.
    debug_print(f"Starting application with API URL: {API_URL}")
    debug_print(f"Using system prompt: {SYSTEM_PROMPT}")
    debug_print(f"Tracing enabled: {ENABLE_TRACING}")
    demo.launch()