import gradio as gr
from huggingface_hub import InferenceClient
import os  # Used to read the HF token from the environment / Space secrets

"""
For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
"""
# !! REPLACE THIS WITH YOUR HUGGING FACE MODEL ID !!
MODEL_ID = "drwlf/PsychoQwen14b"
# It's recommended to use HF_TOKEN from environment/secrets
HF_TOKEN = os.getenv("HF_TOKEN")

# Initialize client, handle potential missing token
client = None # Initialize client to None
if not HF_TOKEN:
    print("Warning: HF_TOKEN secret not found. Cannot initialize InferenceClient.")
    # Optionally raise an error or handle this case in the respond function
else:
    try:
        client = InferenceClient(model=MODEL_ID, token=HF_TOKEN)
        print("InferenceClient initialized successfully.")
    except Exception as e:
        print(f"Error initializing InferenceClient: {e}")
        # Client remains None
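
# Hedged reference sketch (not called by the app): a single non-streaming
# chat_completion request against the same client, to contrast with the streaming
# generator below. The prompt and max_tokens values are illustrative only.
def _example_single_completion(prompt: str) -> str:
    if client is None:
        raise RuntimeError("InferenceClient is not initialized (missing HF_TOKEN).")
    result = client.chat_completion(
        [{"role": "user", "content": prompt}],
        max_tokens=64,
        stream=False,
    )
    return result.choices[0].message.content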


def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
    top_k # Added top_k parameter
):
    """
    Generator function to stream responses from the HF Inference API.
    """
    if not client:
        yield "Error: Inference Client not initialized. Check HF_TOKEN secret."
        return
    if not message or not message.strip():
        yield "Please enter a message."
        return

    messages = [{"role": "system", "content": system_message}]

    for val in history:
        if val[0]:
            messages.append({"role": "user", "content": val[0]})
        if val[1]:
            messages.append({"role": "assistant", "content": val[1]})

    messages.append({"role": "user", "content": message})

    response = ""
    stream = None # Initialize stream variable

    # Handle Top-K value (API often expects None to disable, not 0)
    top_k_val = top_k if top_k > 0 else None
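    # e.g. a slider value of 0 becomes None (sample over the full distribution),
    # while 40 would keep only the 40 most likely tokens at each step.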

    # Debugging: Print parameters being sent
    print(f"--- Sending Request ---")
    print(f"Model: {MODEL_ID}")
    print(f"Messages: {messages}")
    print(f"Max Tokens: {max_tokens}, Temp: {temperature}, Top-P: {top_p}, Top-K: {top_k_val}")
    print(f"-----------------------")


    try:
        stream = client.chat_completion(
            messages,
            max_tokens=max_tokens,
            stream=True,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k_val  # Pass the adjusted Top-K; note: some huggingface_hub releases do not accept top_k here
        )

        for message_chunk in stream:
            # Check for content and delta before accessing
            if (hasattr(message_chunk, 'choices') and
                len(message_chunk.choices) > 0 and
                hasattr(message_chunk.choices[0], 'delta') and
                message_chunk.choices[0].delta and
                hasattr(message_chunk.choices[0].delta, 'content')):

                token = message_chunk.choices[0].delta.content
                if token: # Ensure token is not None or empty
                    response += token
                    # print(token, end="") # Debugging stream locally
                    yield response
            # Optional: Add error checking within the loop if needed

    except Exception as e:
        print(f"Error during chat completion: {e}")
        yield f"Sorry, an error occurred: {e}"
    # No finally block needed unless specific cleanup is required
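
# Hedged usage sketch (not called by the app): how the `respond` generator could be
# exercised outside the Gradio UI for a quick smoke test. The prompt and parameter
# values are illustrative only.
def _example_stream_to_stdout():
    for partial in respond(
        "Hello",  # message
        [],  # empty history
        "You are a friendly psychotherapy AI capable of thinking.",  # system message
        128,  # max_tokens
        0.7,  # temperature
        0.95,  # top_p
        0,  # top_k (0 = disabled)
    ):
        print(partial)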


"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
demo = gr.ChatInterface(
    respond,
    chatbot=gr.Chatbot(height=500), # Set chatbot height
    additional_inputs=[
        gr.Textbox(value="You are a friendly psychotherapy AI capable of thinking.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature"), # Adjusted max temp based on common usage
        gr.Slider(
            minimum=0.05, # Min Top-P often > 0
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-P (nucleus sampling)",
        ),
        # Added Top-K slider
        gr.Slider(
            minimum=0, # 0 disables Top-K
            maximum=100, # Common range, adjust if needed
            value=0, # Default to disabled
            step=1,
            label="Top-K (0 = disabled)",
        ),
    ],
     title="PsychoQwen Chat",
     description=f"Chat with {MODEL_ID}. Adjust generation parameters below.",
     retry_btn="Retry",
     undo_btn="Undo",
     clear_btn="Clear Chat",
)

# --- Launch the app directly ---
# No `if __name__ == "__main__":` guard: the Space runs this module directly on startup.
demo.queue().launch(debug=True) # debug=True is useful for seeing logs in the Space
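
# For local development, one could instead bind host/port explicitly (standard Gradio
# launch arguments; the values below are illustrative):
# demo.queue().launch(server_name="0.0.0.0", server_port=7860, debug=True)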