File size: 6,279 Bytes
19e04b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
01c5acc
f231295
 
9a94757
 
 
 
 
 
 
01c5acc
f231295
9a94757
 
 
01c5acc
f231295
 
 
 
01c5acc
9a94757
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f231295
 
 
 
 
01c5acc
f231295
 
9a94757
f231295
 
 
19e04b5
9a94757
 
19e04b5
 
 
f231295
 
9a94757
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f231295
9a94757
 
 
 
 
 
 
 
01c5acc
 
9a94757
01c5acc
f231295
9a94757
 
01c5acc
 
 
 
9a94757
 
 
01c5acc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
Hugging Face's logo
Hugging Face
Search models, datasets, users...
Models
Datasets
Spaces
Posts
Docs
Solutions
Pricing



Spaces:

Sadmank
/
reflection

private

Logs
App
Files
Community
Settings
reflection
/
app.py

Sadmank's picture
Sadmank
Update app.py
6c53f07
verified
3 minutes ago
raw

Copy download link
history
blame
edit
delete
No virus

5.92 kB
import gradio as gr
import requests
import os
import json
import traceback
import sys
import re

# Enable or disable tracing
ENABLE_TRACING = False

# Set up the API endpoint and key.
# NOTE(review): os.getenv returns None when the variable is unset, which would
# make API_URL the literal "None/chat/completions" and the Authorization
# header "Bearer None" — confirm the Space always defines both env vars.
API_BASE_URL = os.getenv("RUNPOD_API_URL")
API_KEY = os.getenv("RUNPOD_API_KEY")
API_URL = f"{API_BASE_URL}/chat/completions"

# Shared headers for every request to the RunPod OpenAI-compatible endpoint.
headers = {
    "Authorization": f"Bearer {API_KEY}",
    "Content-Type": "application/json"
}

def style_xml_content(text):
    """Render model output containing <thinking>/<reflection> tags as HTML.

    All ``<`` and ``>`` characters are escaped so arbitrary model output
    cannot inject HTML; only the whitelisted ``thinking`` and ``reflection``
    tags are then converted into open ``<details>`` sections with styled
    content.

    Parameters
    ----------
    text : str
        Raw (possibly partial) model output.

    Returns
    -------
    str
        HTML-safe string with collapsible, styled tag sections.
    """
    # Fix: removed a duplicate `import re` that shadowed the top-of-file
    # import (re is already imported at module level).
    def replace_content(match):
        full_match = match.group(0)
        tag = match.group(1)
        content = match.group(2)

        if tag == 'thinking':
            styled_content = f'<i><b>{content}</b></i>'
            return f'<details open><summary>&lt;thinking&gt;</summary>{styled_content}<br>&lt;/thinking&gt;</details>'
        elif tag == 'reflection':
            styled_content = f'<u><b>{content}</b></u>'
            return f'<details open><summary>&lt;reflection&gt;</summary>{styled_content}<br>&lt;/reflection&gt;</details>'
        else:
            # Defensive: after the pre-escaping below only the two
            # whitelisted tags can match, so this branch is effectively
            # unreachable, but re-escape anything unexpected just in case.
            return full_match.replace('<', '&lt;').replace('>', '&gt;')

    # First, escape all < and > characters.
    text = text.replace('<', '&lt;').replace('>', '&gt;')

    # Then, unescape only the specific tags we want to process.
    text = text.replace('&lt;thinking&gt;', '<thinking>').replace('&lt;/thinking&gt;', '</thinking>')
    text = text.replace('&lt;reflection&gt;', '<reflection>').replace('&lt;/reflection&gt;', '</reflection>')

    # Apply styling to content inside the whitelisted tags (DOTALL so the
    # tag body may span multiple lines).
    styled_text = re.sub(r'<(\w+)>(.*?)</\1>', replace_content, text, flags=re.DOTALL)

    # Strip the chat-template artifact if the model leaks it.
    styled_text = styled_text.replace("&lt;|im_start|&gt;", "")

    return styled_text

# Fixed system prompt, sent as the first chat message on every request.
# Fix: the original string was garbled English ("You an advanced", "a length
# <reflection>", "brief and small to the point"); cleaned up the grammar
# while keeping the same instructions and tag vocabulary.
# NOTE(review): stream_response also sends a separate top-level "system"
# field with a different prompt — confirm which one the backend honors.
SYSTEM_PROMPT = (
    "You are an advanced artificial intelligence system, capable of "
    "<thinking> and then creating a lengthy <reflection>, where you ask if "
    "you were wrong and then correct yourself. Always use "
    "<reflection></reflection> unless it is a trivial or wikipedia question. "
    "Finally, you output a brief and to-the-point <output>."
)

def debug_print(*args, **kwargs):
    """Forward the arguments to ``print`` on stderr, but only when the
    module-level ``ENABLE_TRACING`` flag is set."""
    if not ENABLE_TRACING:
        return
    print(*args, file=sys.stderr, **kwargs)

def parse_sse(data):
    """Decode one server-sent-events line (bytes) into a JSON object.

    Returns the parsed object, or None for empty input, the terminal
    '[DONE]' marker, or a payload that is not valid JSON.
    """
    if not data:
        return None

    payload = data.decode('utf-8').strip()
    debug_print(f"Raw SSE data: {payload}")

    # SSE frames carry their JSON after a 'data: ' prefix.
    if payload.startswith('data: '):
        payload = payload[len('data: '):]

    if payload == '[DONE]':
        return None

    try:
        return json.loads(payload)
    except json.JSONDecodeError:
        debug_print(f"Failed to parse SSE data: {payload}")
        return None

def stream_response(message, history, max_tokens, temperature, top_p):
    """Stream a chat completion from the RunPod endpoint for the Gradio UI.

    Parameters
    ----------
    message : str
        The new user message.
    history : list[tuple[str, str]]
        Prior (user, assistant) message pairs supplied by Gradio.
    max_tokens, temperature, top_p :
        Sampling parameters forwarded verbatim to the API.

    Yields
    ------
    str
        The accumulated assistant response so far, styled as HTML by
        style_xml_content, or an error string if the request fails.
    """
    messages = [{"role": "system", "content": SYSTEM_PROMPT}]

    for human, assistant in history:
        messages.append({"role": "user", "content": human})
        messages.append({"role": "assistant", "content": assistant})

    messages.append({"role": "user", "content": message})

    data = {
        "model": "forcemultiplier/fmx-reflective-2b",
        "messages": messages,
        "max_tokens": max_tokens,
        "temperature": temperature,
        # NOTE(review): this top-level "system" field duplicates (and differs
        # from) the SYSTEM_PROMPT message above — confirm which one the
        # backend actually honors.
        "system" : "You are a world-class AI system, capable of complex reasoning and reflection. Reason through the query inside <thinking> tags, and then provide your final response inside <output> tags. If you detect that you made a mistake in your reasoning at any point, correct yourself inside <reflection> tags.",
        "top_p": top_p,
        "stream": True,
        "stop": [ "<|start_header_id|>",
        "<|end_header_id|>",
        "<|eot_id|>"]  # chat-template stop sequences
    }

    debug_print(f"Sending request to API: {API_URL}")
    debug_print(f"Request data: {json.dumps(data, indent=2)}")

    try:
        # Fix: use the response as a context manager so the streaming
        # connection is released even when we break out early below —
        # the original leaked the connection after hitting "</output>".
        # Fix: add a connect-phase timeout (requests otherwise waits
        # forever); the read timeout stays unbounded for long generations.
        with requests.post(API_URL, headers=headers, json=data, stream=True,
                           timeout=(10, None)) as response:
            debug_print(f"Response status code: {response.status_code}")
            debug_print(f"Response headers: {response.headers}")

            response.raise_for_status()

            accumulated_content = ""
            for line in response.iter_lines():
                if not line:
                    continue
                debug_print(f"Received line: {line}")
                parsed = parse_sse(line)
                if not parsed:
                    continue
                debug_print(f"Parsed SSE data: {parsed}")
                choices = parsed.get('choices')
                if not choices:
                    continue
                content = choices[0]['delta'].get('content', '')
                if not content:
                    continue
                accumulated_content += content
                yield style_xml_content(accumulated_content)

                # Stop once the model has closed its <output> section.
                if accumulated_content.endswith("</output>"):
                    break

    except requests.exceptions.RequestException as e:
        debug_print(f"Request exception: {str(e)}")
        debug_print(f"Request exception traceback: {traceback.format_exc()}")
        yield f"Error: {str(e)}"
    except Exception as e:
        debug_print(f"Unexpected error: {str(e)}")
        debug_print(f"Error traceback: {traceback.format_exc()}")
        yield f"Unexpected error: {str(e)}"

# Gradio chat UI: stream_response is the (generator) callback; the three
# sliders are passed to it as the max_tokens, temperature and top_p
# arguments, in this order.
demo = gr.ChatInterface(
    stream_response,
    additional_inputs=[
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max tokens"),
        gr.Slider(minimum=0.1, maximum=2.0, value=0.4, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.83, step=0.05, label="Top-p (nucleus sampling)"),
    ],
)

if __name__ == "__main__":
    # Trace startup configuration (no-ops unless ENABLE_TRACING is True),
    # then start the Gradio server.
    debug_print(f"Starting application with API URL: {API_URL}")
    debug_print(f"Using system prompt: {SYSTEM_PROMPT}")
    debug_print(f"Tracing enabled: {ENABLE_TRACING}")
    demo.launch()