File size: 5,540 Bytes
22bd0c2
fcabc39
 
59b000d
e23537b
 
 
 
 
 
fcabc39
 
e23537b
59b000d
e23537b
fcabc39
 
 
 
 
 
e23537b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7e9b0bd
e23537b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
be52360
7e9b0bd
 
fcabc39
 
 
 
 
22bd0c2
fcabc39
 
59b000d
fcabc39
 
 
7e9b0bd
e23537b
 
fcabc39
 
e23537b
 
 
f7d8d0c
7e9b0bd
e23537b
 
f7d8d0c
e23537b
f7d8d0c
e23537b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f7d8d0c
e23537b
 
7e9b0bd
f7d8d0c
e23537b
 
7e9b0bd
22bd0c2
 
7e9b0bd
22bd0c2
fcabc39
e23537b
 
22bd0c2
 
 
 
e23537b
 
 
22bd0c2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import gradio as gr
import requests
import os
import json
import traceback
import sys
import re

# Master switch for debug_print(): when True, request/response tracing is
# written to stderr.
ENABLE_TRACING = False

# Endpoint settings for the chat-completions API (presumably a RunPod-hosted
# OpenAI-style server), read from the environment at import time.
# NOTE(review): if RUNPOD_API_URL is unset, API_URL becomes
# "None/chat/completions" and every request fails — confirm deployment sets it.
API_BASE_URL = os.getenv("RUNPOD_API_URL")
API_KEY = os.getenv("RUNPOD_API_KEY")
API_URL = "{}/chat/completions".format(API_BASE_URL)

# HTTP headers attached to every chat-completions request.
headers = {}
headers["Authorization"] = "Bearer {}".format(API_KEY)
headers["Content-Type"] = "application/json"

import re

# Tags the model is prompted to emit that are rendered as collapsible sections.
_STYLED_TAGS = ("thinking", "reflection")
_STYLED_TAG_RE = re.compile(r"<(thinking|reflection)>(.*?)</\1>", re.DOTALL)


def style_xml_content(text):
    """Convert raw model output into HTML for the chat display.

    Every ``<`` and ``>`` is HTML-escaped except complete
    ``<thinking>...</thinking>`` and ``<reflection>...</reflection>`` pairs.
    Those pairs become open ``<details>`` sections: thinking in bold italics,
    reflection in bold underline. A pair nested inside another pair (e.g. a
    reflection inside thinking) is rendered too. A tag without a matching
    partner is shown escaped like any other tag. That includes an opening tag
    whose closing tag has not streamed in yet. The ``<|im_start|>`` token is
    removed.

    Args:
        text: Raw, possibly partial, assistant output.

    Returns:
        The HTML string to display.
    """

    def render(match):
        tag = match.group(1)
        # Style any complete pairs nested inside this one first.
        content = _STYLED_TAG_RE.sub(render, match.group(2))
        if tag == "thinking":
            styled_content = f"<i><b>{content}</b></i>"
        else:
            styled_content = f"<u><b>{content}</b></u>"
        return (
            f"<details open><summary>&lt;{tag}&gt;</summary>"
            f"{styled_content}<br>&lt;/{tag}&gt;</details>"
        )

    # Escape every angle bracket, then restore only the tags we render.
    text = text.replace("<", "&lt;").replace(">", "&gt;")
    for tag in _STYLED_TAGS:
        text = text.replace(f"&lt;{tag}&gt;", f"<{tag}>")
        text = text.replace(f"&lt;/{tag}&gt;", f"</{tag}>")

    styled_text = _STYLED_TAG_RE.sub(render, text)

    # Any restored tag still raw here had no partner; escape it again so it is
    # displayed as text rather than sent to the browser as an unknown tag.
    # The HTML generated above never contains these raw tags.
    for tag in _STYLED_TAGS:
        styled_text = styled_text.replace(f"<{tag}>", f"&lt;{tag}&gt;")
        styled_text = styled_text.replace(f"</{tag}>", f"&lt;/{tag}&gt;")

    # Remove blacklisted text
    styled_text = styled_text.replace("&lt;|im_start|&gt;", "")

    return styled_text

# System prompt sent as the first message of every chat request.
# NOTE(review): the wording contains typos ("You an", "a length", a double
# space after "Finally"). It is kept verbatim because editing a system prompt
# changes model behaviour — confirm whether the model was tuned on this text.
SYSTEM_PROMPT = (
    "You an advanced artificial intelligence system, capable of <thinking> "
    "and then creating a length <reflection>, where you ask if you were wrong? "
    "And then you correct yourself. Always use <reflection></reflection> "
    "unless it is a trivial or wikipedia question. "
    "Finally  you output a brief and small to the point <output>."
)

def debug_print(*args, **kwargs):
    """Forward the arguments to print() on stderr, but only while tracing.

    Does nothing unless the module-level ENABLE_TRACING flag is truthy.
    """
    if not ENABLE_TRACING:
        return
    print(*args, file=sys.stderr, **kwargs)

def parse_sse(data):
    """Parse one line of a Server-Sent Events stream as JSON.

    Args:
        data: One line from the event stream, as ``bytes`` (what
            ``response.iter_lines()`` yields by default) or ``str``.
            ``None`` or empty input is allowed.

    Returns:
        The JSON-decoded payload of the line, or ``None`` in these cases:
        empty input, the ``[DONE]`` end-of-stream sentinel, or a line that
        does not hold valid JSON (e.g. an SSE comment line).
    """
    if not data:
        return None
    if isinstance(data, bytes):
        data = data.decode("utf-8")
    data = data.strip()
    debug_print(f"Raw SSE data: {data}")
    # The SSE spec makes the space after "data:" optional, so accept both
    # "data: {...}" and "data:{...}".
    if data.startswith("data:"):
        data = data[len("data:"):].lstrip()
    if data == "[DONE]":
        return None
    try:
        return json.loads(data)
    except json.JSONDecodeError:
        debug_print(f"Failed to parse SSE data: {data}")
        return None

def _history_to_messages(history):
    """Convert Gradio chat history into chat-completions message dicts.

    Accepts either a list of ``[user, assistant]`` pairs or a list of
    ``{"role": ..., "content": ...}`` dicts. Turns whose text is ``None``
    are skipped rather than sent as null content.
    """
    messages = []
    for turn in history or []:
        if isinstance(turn, dict):
            messages.append({"role": turn["role"], "content": turn["content"]})
            continue
        human, assistant = turn
        if human is not None:
            messages.append({"role": "user", "content": human})
        if assistant is not None:
            messages.append({"role": "assistant", "content": assistant})
    return messages


def stream_response(message, history, max_tokens, temperature, top_p):
    """Stream a chat completion for ``message`` and yield rendered HTML.

    Args:
        message: The new user message.
        history: Prior conversation turns as supplied by gr.ChatInterface.
            NOTE(review): assistant turns are presumably the HTML previously
            yielded here, so the model sees its past answers as HTML.
            Confirm this is intended.
        max_tokens: Upper bound on generated tokens.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Yields:
        The whole response so far, passed through style_xml_content, after
        each streamed content chunk. On failure, a single
        ``"Error: ..."`` / ``"Unexpected error: ..."`` string is yielded instead.
    """
    messages = [{"role": "system", "content": SYSTEM_PROMPT}]
    messages.extend(_history_to_messages(history))
    messages.append({"role": "user", "content": message})

    data = {
        "model": "forcemultiplier/fmx-reflective-2b",
        "messages": messages,
        "max_tokens": max_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "stream": True,
        "stop": ["</output>"]  # Add stop sequence
    }

    debug_print(f"Sending request to API: {API_URL}")
    debug_print(f"Request data: {json.dumps(data, indent=2)}")

    try:
        # Use the response as a context manager so the streamed connection is
        # released in every case: the loop finishes, hits `break`, raises, or
        # Gradio closes this generator.
        with requests.post(API_URL, headers=headers, json=data, stream=True) as response:
            debug_print(f"Response status code: {response.status_code}")
            debug_print(f"Response headers: {response.headers}")

            response.raise_for_status()

            accumulated_content = ""
            for line in response.iter_lines():
                if not line:
                    continue
                debug_print(f"Received line: {line}")
                parsed = parse_sse(line)
                if not parsed:
                    continue
                debug_print(f"Parsed SSE data: {parsed}")
                choices = parsed.get("choices") if isinstance(parsed, dict) else None
                if not choices:
                    continue
                # Some chunks (e.g. the final one) carry no "delta"; skip them
                # instead of raising and discarding the text streamed so far.
                delta = choices[0].get("delta") or {}
                content = delta.get("content")
                if not content:
                    continue
                accumulated_content += content
                yield style_xml_content(accumulated_content)

                # Check if we've reached the stop sequence
                if accumulated_content.endswith("</output>"):
                    break

    except requests.exceptions.RequestException as e:
        debug_print(f"Request exception: {str(e)}")
        debug_print(f"Request exception traceback: {traceback.format_exc()}")
        yield f"Error: {str(e)}"
    except Exception as e:
        debug_print(f"Unexpected error: {str(e)}")
        debug_print(f"Error traceback: {traceback.format_exc()}")
        yield f"Unexpected error: {str(e)}"

# Chat UI wiring. The sliders line up, in order, with stream_response's
# max_tokens, temperature and top_p parameters (after message and history).
demo = gr.ChatInterface(
    fn=stream_response,
    additional_inputs=[
        gr.Slider(label="Max tokens", minimum=1, maximum=2048, value=512, step=1),
        gr.Slider(label="Temperature", minimum=0.1, maximum=2.0, value=0.4, step=0.1),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.1,
            maximum=1.0,
            value=0.83,
            step=0.05,
        ),
    ],
)

if __name__ == "__main__":
    # Startup diagnostics (printed only when ENABLE_TRACING is on), then serve.
    for startup_line in (
        f"Starting application with API URL: {API_URL}",
        f"Using system prompt: {SYSTEM_PROMPT}",
        f"Tracing enabled: {ENABLE_TRACING}",
    ):
        debug_print(startup_line)
    demo.launch()