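"""Gradio chat demo for Cohere's Command A Reasoning model.

Streams the model's intermediate "thinking" tokens into a collapsible
bubble and its final answer below, with a user-adjustable thinking
token budget.
"""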
import os
from collections.abc import Iterator
import gradio as gr
from gradio import ChatMessage
from cohere import ClientV2
from cohere.core import RequestOptions

model_id = "command-a-reasoning-08-2025"

# Initialize the Cohere client
api_key = os.getenv("COHERE_API_KEY")
if not api_key:
raise ValueError("COHERE_API_KEY environment variable is required")
client = ClientV2(api_key=api_key, client_name="hf-command-a-reasoning-08-2025")
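# client_name is optional metadata identifying this app to the Cohere API
# (assumption: used for request attribution only, not authentication).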


def format_chat_history(messages: list) -> list:
    """
    Format the chat history into the message structure the Cohere Chat API
    expects, dropping the "thinking" messages this app renders in the UI.
    """
formatted_history = []
for message in messages:
        # Skip thinking messages (messages carrying metadata), whether they
        # arrive as ChatMessage objects or as plain dicts
        if hasattr(message, "metadata") and message.metadata:
            continue
        if isinstance(message, dict) and message.get("metadata"):
            continue
# Extract role and content safely
if hasattr(message, "role"):
role = message.role
content = message.content
elif isinstance(message, dict):
role = message.get("role")
content = message.get("content")
else:
continue
        if role and content:
            # The Cohere API expects string content; coerce anything else
            if not isinstance(content, str):
                content = str(content)
            formatted_history.append({
                "role": role,
                "content": content
            })
return formatted_history
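
# Illustrative example (not executed): plain messages pass through, while
# messages carrying metadata (the thinking bubbles this app yields) are dropped:
#   format_chat_history([
#       {"role": "user", "content": "Hi"},
#       {"role": "assistant", "content": "...", "metadata": {"title": "🧠 Thinking..."}},
#   ])
#   -> [{"role": "user", "content": "Hi"}]
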
def generate(message: str, history: list, thinking_budget: int) -> Iterator[list]:
# Create a clean working copy of the history (excluding thinking messages)
working_history = []
    for msg in history:
        # Skip thinking messages (messages with metadata), whether ChatMessage
        # objects or plain dicts
        if hasattr(msg, "metadata") and msg.metadata:
            continue
        if isinstance(msg, dict) and msg.get("metadata"):
            continue
        working_history.append(msg)
# Format chat history for Cohere API (exclude thinking messages)
messages = format_chat_history(working_history)
# Add current message
if message:
messages.append({"role": "user", "content": message})
try:
# Set thinking type based on thinking_budget
if thinking_budget == 0:
thinking_param = {"type": "disabled"}
else:
thinking_param = {"type": "enabled", "token_budget": thinking_budget}
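        # A budget of 0 disables reasoning entirely; any positive budget caps
        # the number of thinking tokens. The parameter is forwarded through the
        # raw request body below, since (assumption) this SDK version does not
        # expose `thinking` as a named argument of chat_stream.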
        # Stream the chat response from the Cohere API
response = client.chat_stream(
model=model_id,
messages=messages,
temperature=0.3,
request_options=RequestOptions(additional_body_parameters={"thinking": thinking_param})
)
        # Helper: render ChatMessage objects as the plain dicts that Gradio's
        # "messages" format expects (used by both yield sites below)
        def to_payload(msgs):
            return [
                {
                    "role": m.role,
                    "content": m.content,
                    "metadata": getattr(m, "metadata", None),
                }
                for m in msgs
            ]

        # Initialize streaming buffers
        thought_buffer = ""
        response_buffer = ""
        thinking_complete = False
# Start with just the new assistant messages for this interaction
current_interaction = [
ChatMessage(
role="assistant",
content="",
metadata={"title": "🧠 Thinking..."}
)
]
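        # Stream handling: each "content-delta" event carries either a
        # `thinking` fragment or a `text` fragment (event.delta.message
        # .content.thinking / .text, as accessed below). Thinking fragments
        # grow the collapsible bubble at current_interaction[0]; the first
        # text fragment ends the thinking phase, after which text streams
        # into the answer message.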
for event in response:
if getattr(event, "type", None) == "content-delta":
                delta = event.delta
                if hasattr(delta, "message"):
                    # Use a distinct name so the `message` parameter is not shadowed
                    delta_message = delta.message
                    if hasattr(delta_message, "content"):
                        content = delta_message.content
# Check for thinking tokens first
thinking_text = getattr(content, 'thinking', None)
if thinking_text:
thought_buffer += thinking_text
# Update thinking message with metadata
current_interaction[0] = ChatMessage(
role="assistant",
content=thought_buffer,
metadata={"title": "🧠 Thinking..."}
)
                            # Yield only the current interaction as plain dicts
                            yield to_payload(current_interaction)
continue
                        # Check for regular text tokens
                        text = getattr(content, 'text', None)
                        if text:
                            # Coerce non-string deltas defensively
                            if not isinstance(text, str):
                                text = str(text)
                            # If thinking just finished, add a separate response
                            # message below the thinking bubble. With thinking
                            # disabled there is no thought_buffer, so the
                            # placeholder itself is overwritten by the response
                            # message at the end of this block.
                            if not thinking_complete and thought_buffer:
                                thinking_complete = True
                                current_interaction.append(
                                    ChatMessage(
                                        role="assistant",
                                        content=""
                                    )
                                )
                            if thinking_complete:
                                # Thinking is complete: collapse the thinking message
                                current_interaction[0] = ChatMessage(
                                    role="assistant",
                                    content=thought_buffer,
                                    metadata={"title": "🧠 Thoughts", "status": "done"}
                                )
                            response_buffer += text
                            # Update the streaming response message
                            current_interaction[-1] = ChatMessage(
                                role="assistant",
                                content=response_buffer
                            )
                            # Yield the updated interaction as plain dicts
                            yield to_payload(current_interaction)
    except Exception as e:
        # Surface the error as a Gradio warning toast and clear this turn's output
        gr.Warning(f"Error calling Cohere API: {e}")
        yield []
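
# Example prompts: classic reasoning riddles plus multilingual questions
# (Portuguese, French) to exercise the model's multilingual reasoning.
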
examples = [
[
"A man walks into a bar and asks the bartender for a glass of water. The bartender pulls out a gun instead. The man says 'thank you' and leaves. Why?"
],
[
"Twenty-four red socks and 24 blue socks are lying in a drawer in a dark room. What is the minimum number of socks I must take out of the drawer which will guarantee that I have at least two socks of the same color?"
],
[
"A farmer needs to transport a fox, a chicken, and a sack of grain across a river in a boat that can only carry the farmer and one other thing at a time. If left alone together, the fox will eat the chicken, and the chicken will eat the grain. How does the farmer get everything across safely?"
],
[
"'''\nX +\n *\n''' \n\nReason about the above scene depicted in the markdown code block. If I interchange the locations of * and X, and then I interchange the locations of * and +, and then I flip the image like a left-right mirror, which symbol is on the leftmost part of the image?"
],
[
"You are running a race and overtake the person at position 76487423. What place are you in now?"
],
[
"A man dies of old age on his 25 birthday. How is this possible?"
],
[
"Como sair de um helicóptero que caiu na água?"
],
[
"What is the best way to learn machine learning?"
],
[
"Explain quantum computing in simple terms"
],
[
"How many months have 28 days?"
],
[
"Explique la théorie de la relativité en français"
],
[
"Write a COBOL function to reverse a string"
]
]

demo = gr.ChatInterface(
fn=generate,
type="messages",
autofocus=True,
title="Command A Reasoning",
examples=examples,
run_examples_on_click=True,
css_paths="style.css",
delete_cache=(1800, 1800),
cache_examples=False,
additional_inputs=[
gr.Slider(label="Thinking Budget", minimum=0, maximum=2000, step=10, value=500),
],
)

if __name__ == "__main__":
demo.launch()