Spaces:
Sleeping
Sleeping
import gradio as gr | |
import anthropic | |
import json | |
import logging | |
from tool_handler import process_tool_call, tools | |
from config import SYSTEM_PROMPT, API_KEY, MODEL_NAME | |
from datasets import load_dataset | |
import pandas as pd | |
from dotenv import load_dotenv | |
# Load environment variables from a local .env file, if one exists.
# NOTE(review): `config` (which provides API_KEY) is imported above, before
# this call runs. If config.py reads os.environ at import time, values that
# exist only in .env will not be visible to it — confirm against config.py.
load_dotenv()

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Initialize the Anthropic client with the configured API key.
# `anthropic.Anthropic` is the SDK's primary client class (`Client` is an alias).
client = anthropic.Anthropic(api_key=API_KEY)
def _content_block_to_dict(block):
    """Convert one Anthropic response content block into a plain dict.

    The dict form can be logged with ``json.dumps`` and sent back to the
    Messages API in a later request. Text and tool-use blocks are rebuilt with
    only the fields the API accepts on input. Any other block type falls back
    to the SDK's own ``to_dict()``.
    """
    if isinstance(block, dict):
        return block
    if block.type == "text":
        return {"type": "text", "text": block.text}
    if block.type == "tool_use":
        return {"type": "tool_use", "id": block.id, "name": block.name, "input": block.input}
    return block.to_dict()


def _run_tool_calls(tool_uses):
    """Execute each tool_use block and return the matching tool_result blocks.

    Every ``tool_use`` id in an assistant turn must be answered by a
    ``tool_result`` carrying the same id in the next user turn.
    """
    results = []
    for tool_use in tool_uses:
        output = process_tool_call(tool_use["name"], tool_use["input"])
        results.append({
            "type": "tool_result",
            "tool_use_id": tool_use["id"],
            # tool_result content must be text, so anything else is serialised.
            # NOTE(review): process_tool_call's return type is not visible here.
            "content": output if isinstance(output, str) else json.dumps(output),
        })
    return results


def simple_chat(user_message, history):
    """Send one user turn to Claude, running requested tools until it answers.

    Args:
        user_message: The new message typed by the user.
        history: Prior ``(user, assistant)`` pairs from the Gradio chatbot.
            It is appended to in place. ``None`` is treated as empty.

    Returns:
        ``(history, "", messages)``: the updated history, an empty string that
        clears the input textbox, and the full message list sent to the API.
    """
    if history is None:
        history = []

    # Rebuild the API conversation from the chatbot's (user, assistant) pairs.
    messages = []
    for user_msg, assistant_msg in history:
        messages.append({"role": "user", "content": user_msg})
        messages.append({"role": "assistant", "content": assistant_msg})
    messages.append({"role": "user", "content": user_message})

    full_response = ""
    MAX_ITERATIONS = 5
    iteration_count = 0
    while iteration_count < MAX_ITERATIONS:
        try:
            logger.info(f"Sending messages to LLM API: {json.dumps(messages, indent=2)}")
            response = client.messages.create(
                model=MODEL_NAME,
                system=SYSTEM_PROMPT,
                max_tokens=4096,
                tools=tools,
                messages=messages,
            )
            logger.info(f"LLM API response: {json.dumps(response.to_dict(), indent=2)}")

            blocks = [_content_block_to_dict(block) for block in response.content]
            # Concatenate every text block; content[0] may be a tool_use block
            # that has no `.text` attribute.
            assistant_text = "".join(b.get("text", "") for b in blocks if b.get("type") == "text")
            tool_uses = [b for b in blocks if b.get("type") == "tool_use"]

            if response.stop_reason == "tool_use" and tool_uses:
                # Keep the tool_use blocks in the assistant turn so that the
                # tool_result blocks in the following user turn can reference
                # their ids. This also preserves user/assistant alternation.
                messages.append({"role": "assistant", "content": blocks})
                messages.append({"role": "user", "content": _run_tool_calls(tool_uses)})
                for tool_use in tool_uses:
                    full_response += f"\nUsing tool: {tool_use['name']}\n"
                iteration_count += 1
                continue

            # Final answer: show the text to the user and record it.
            full_response += assistant_text
            messages.append({"role": "assistant", "content": assistant_text})
            break
        except anthropic.BadRequestError as e:
            logger.error(f"BadRequestError: {str(e)}")
            full_response = f"Error: {str(e)}"
            break
        except Exception as e:
            # Top-level boundary for the UI: log with traceback, report to user.
            logger.exception(f"Unexpected error: {str(e)}")
            full_response = f"An unexpected error occurred: {str(e)}"
            break

    logger.info(f"Final messages: {json.dumps(messages, indent=2)}")
    if iteration_count == MAX_ITERATIONS:
        logger.warning("Maximum iterations reached in simple_chat")
    history.append((user_message, full_response))
    return history, "", messages  # Return messages as well
def messages_to_dataframe(messages):
    """Flatten an Anthropic-style message list into a table for the UI.

    Each message becomes one row with these columns:

    - ``role``
    - ``content``: the raw content, JSON-encoded when it is not a plain string
    - ``tool_use``: JSON of a tool_use block found in that message, or ``None``
    - ``tool_result``: JSON of a tool_result block found in that message, or ``None``

    Tool-use blocks appear in assistant turns, while tool-result blocks appear
    in user turns, so list content is scanned for every role.
    """
    columns = ["role", "content", "tool_use", "tool_result"]
    rows = []
    for msg in messages:
        content = msg["content"]
        row = {
            "role": msg["role"],
            "content": content if isinstance(content, str) else json.dumps(content),
            "tool_use": None,
            "tool_result": None,
        }
        if isinstance(content, list):
            for item in content:
                if isinstance(item, dict) and "type" in item:
                    # If a message holds several blocks of one kind, the last wins.
                    if item["type"] == "tool_use":
                        row["tool_use"] = json.dumps(item)
                    elif item["type"] == "tool_result":
                        row["tool_result"] = json.dumps(item)
        rows.append(row)
    # Explicit columns keep the schema stable even when there are no messages.
    return pd.DataFrame(rows, columns=columns)
def submit_message(message, history):
    """Run one chat turn and build the conversation-analysis table.

    Returns:
        ``(chat history, "", DataFrame)``: the updated chatbot history, an
        empty string that clears the textbox, and a table of the raw API
        messages.
    """
    updated_history, _cleared, raw_messages = simple_chat(message, history)
    analysis = messages_to_dataframe(raw_messages)
    # Echo the table to the server console as well as the UI.
    print(analysis)
    return updated_history, "", analysis
def _load_dataset_frame(repo_id, split="train"):
    """Load one split of a Hugging Face dataset and return it as a DataFrame."""
    dataset = load_dataset(repo_id, split=split)
    return pd.DataFrame(dataset)


def load_customers_dataset():
    """Return the ``dwb2023/blackbird-customers`` train split as a DataFrame."""
    return _load_dataset_frame("dwb2023/blackbird-customers")


def load_orders_dataset():
    """Return the ``dwb2023/blackbird-orders`` train split as a DataFrame."""
    return _load_dataset_frame("dwb2023/blackbird-orders")
# Canned prompts offered as clickable examples under the chat textbox. Each
# one ends with "..." for the user to complete with their own email or id.
example_inputs = [
    "Can you lookup my user id? My email is...",
    "I'm checking on the status of an order, the order id is...",
    "Can you send me a list of my recent orders? My customer id is...",
    "I need to cancel Order ID...",
    "I lost my phone and need to update my contact information. My user id is...",
    "I need to confirm my current user info and order status. My email is...",
]


def _clear_conversation():
    """Reset both the chatbot and the conversation-analysis table."""
    return None, None


# Create Gradio App
with gr.Blocks(theme="sudeepshouche/minimalist") as app:
    with gr.Tab("Chatbot"):
        gr.Markdown("# BlackBird Customer Support Chat")
        gr.Markdown("## leveraging **Claude Sonnet 3.5** for microservice-based function calling")
        gr.Markdown("FastAPI Backend - running on Docker: [blackbird-svc](https://huggingface.co/spaces/dwb2023/blackbird-svc)")
        gr.Markdown("Data Sources - HF Datasets: [blackbird-customers](https://huggingface.co/datasets/dwb2023/blackbird-customers) [blackbird-orders](https://huggingface.co/datasets/dwb2023/blackbird-orders)")
        with gr.Row():
            with gr.Column():
                msg = gr.Textbox(label="Your message")
                gr.Markdown("⬅️ check out the *Customers* and *Orders* tabs above 👆 for sample email addresses, order ids, etc.")
                examples = gr.Examples(
                    examples=example_inputs,
                    inputs=msg,
                )
                submit = gr.Button("Submit", variant="primary")
                clear = gr.Button("Clear", variant="secondary")
            with gr.Column():
                chatbot = gr.Chatbot()
                df_output = gr.Dataframe(label="Conversation Analysis")

        # Pressing Enter and clicking Submit run the same handler with the
        # same options. submit_message already returns "" for the textbox, so
        # no follow-up step is needed to clear it.
        chat_inputs = [msg, chatbot]
        chat_outputs = [chatbot, msg, df_output]
        msg.submit(submit_message, chat_inputs, chat_outputs, show_progress="full")
        submit.click(submit_message, chat_inputs, chat_outputs, show_progress="full")
        # Clearing the chat also clears its analysis table.
        clear.click(_clear_conversation, None, [chatbot, df_output], queue=False)
    with gr.Tab("Customers"):
        customers_df = gr.Dataframe(load_customers_dataset(), label="Customers Data")
    with gr.Tab("Orders"):
        orders_df = gr.Dataframe(load_orders_dataset(), label="Orders Data")

if __name__ == "__main__":
    app.launch()