from fastapi import FastAPI, HTTPException, Depends, Security, BackgroundTasks
from fastapi.security import APIKeyHeader
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
from typing import Literal, List, Dict
import os
from functools import lru_cache
from openai import OpenAI
from uuid import uuid4
import tiktoken
import sqlite3
import time
from datetime import datetime, timedelta
import asyncio
import re
import json
import requests
from prompts import CODING_ASSISTANT_PROMPT, NEWS_ASSISTANT_PROMPT, generate_news_prompt, SEARCH_ASSISTANT_PROMPT, generate_search_prompt
from fastapi_cache import FastAPICache
from fastapi_cache.backends.inmemory import InMemoryBackend
from fastapi_cache.decorator import cache
import logging
# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler("app.log"),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

app = FastAPI()

API_KEY_NAME = "X-API-Key"
API_KEY = os.environ.get("CHAT_AUTH_KEY", "default_secret_key")
api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False)

ModelID = Literal[
    "openai/gpt-4o-mini",
    "meta-llama/llama-3-70b-instruct",
    "anthropic/claude-3.5-sonnet",
    "deepseek/deepseek-coder",
    "anthropic/claude-3-haiku",
    "openai/gpt-3.5-turbo-instruct",
    "qwen/qwen-72b-chat",
    "google/gemma-2-27b-it"
]
class QueryModel(BaseModel):
    user_query: str = Field(..., description="User's coding query")
    model_id: ModelID = Field(
        default="meta-llama/llama-3-70b-instruct",
        description="ID of the model to use for response generation"
    )
    conversation_id: str = Field(default_factory=lambda: str(uuid4()), description="Unique identifier for the conversation")
    user_id: str = Field(..., description="Unique identifier for the user")

    class Config:
        schema_extra = {
            "example": {
                "user_query": "How do I implement a binary search in Python?",
                "model_id": "meta-llama/llama-3-70b-instruct",
                "conversation_id": "123e4567-e89b-12d3-a456-426614174000",
                "user_id": "user123"
            }
        }

class NewsQueryModel(BaseModel):
    query: str = Field(..., description="News topic to search for")
    model_id: ModelID = Field(
        default="openai/gpt-4o-mini",
        description="ID of the model to use for response generation"
    )

    class Config:
        schema_extra = {
            "example": {
                "query": "Latest developments in AI",
                "model_id": "openai/gpt-4o-mini"
            }
        }
@lru_cache()  # cache the environment lookup; lru_cache is imported for this purpose
def get_api_keys():
    logger.info("Loading API keys")
    return {
        "OPENROUTER_API_KEY": f"sk-or-v1-{os.environ['OPENROUTER_API_KEY']}",
        "BRAVE_API_KEY": os.environ['BRAVE_API_KEY']
    }

api_keys = get_api_keys()
or_client = OpenAI(api_key=api_keys["OPENROUTER_API_KEY"], base_url="https://openrouter.ai/api/v1")
# In-memory storage for conversations
conversations: Dict[str, List[Dict[str, str]]] = {}
last_activity: Dict[str, float] = {}

# Token encoding
encoding = tiktoken.encoding_for_model("gpt-3.5-turbo")

def limit_tokens(input_string, token_limit=6000):
    # Truncate a string to its first `token_limit` tokens.
    return encoding.decode(encoding.encode(input_string)[:token_limit])

def calculate_tokens(msgs):
    # Approximate token count across all messages (each message dict is stringified).
    return sum(len(encoding.encode(str(m))) for m in msgs)
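# Example (illustrative): calculate_tokens([{"role": "user", "content": "hi"}])
# returns the token count of the stringified message, while
# limit_tokens("some very long text", token_limit=10) keeps only the first
# 10 tokens of the input string.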
def chat_with_llama_stream(messages, model="openai/gpt-4o-mini", max_llm_history=4, max_output_tokens=2500):
    logger.info(f"Starting chat with model: {model}")
    # Trim the history until the prompt fits an 8000-token context budget,
    # always keeping the system prompt plus the most recent turns.
    while calculate_tokens(messages) > (8000 - max_output_tokens):
        if len(messages) > max_llm_history + 1:
            # Slice-assign so the caller's conversation list is updated in place.
            messages[:] = [messages[0]] + messages[-max_llm_history:]
        else:
            # Already trimmed to the window; shrink the window itself.
            max_llm_history -= 1
            if max_llm_history < 2:
                error_message = "Token limit exceeded. Please shorten your input or start a new conversation."
                logger.error(error_message)
                raise HTTPException(status_code=400, detail=error_message)

    try:
        response = or_client.chat.completions.create(
            model=model,
            messages=messages,
            max_tokens=max_output_tokens,
            stream=True
        )

        full_response = ""
        for chunk in response:
            if chunk.choices[0].delta.content is not None:
                content = chunk.choices[0].delta.content
                full_response += content
                yield content

        # After streaming, add the full response to the conversation history
        messages.append({"role": "assistant", "content": full_response})
        logger.info("Chat completed successfully")
    except Exception as e:
        logger.error(f"Error in model response: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error in model response: {str(e)}")
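# Example (illustrative): stream a reply and collect the full text.
#   msgs = [{"role": "system", "content": "You are helpful."},
#           {"role": "user", "content": "Say hi."}]
#   reply = "".join(chat_with_llama_stream(msgs))
#   # msgs now also contains the assistant turn appended by the generator.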
async def verify_api_key(api_key: str = Security(api_key_header)):
    if api_key != API_KEY:
        logger.warning("Invalid API key used")
        raise HTTPException(status_code=403, detail="Could not validate credentials")
    return api_key

# SQLite setup
DB_PATH = '/app/data/conversations.db'

def init_db():
    logger.info("Initializing database")
    os.makedirs(os.path.dirname(DB_PATH), exist_ok=True)
    conn = sqlite3.connect(DB_PATH)
    c = conn.cursor()
    c.execute('''CREATE TABLE IF NOT EXISTS conversations
                 (id INTEGER PRIMARY KEY AUTOINCREMENT,
                  user_id TEXT,
                  conversation_id TEXT,
                  message TEXT,
                  response TEXT,
                  timestamp DATETIME DEFAULT CURRENT_TIMESTAMP)''')
    conn.commit()
    conn.close()
    logger.info("Database initialized successfully")

init_db()

def update_db(user_id, conversation_id, message, response):
    logger.info(f"Updating database for conversation: {conversation_id}")
    conn = sqlite3.connect(DB_PATH)
    c = conn.cursor()
    c.execute('''INSERT INTO conversations (user_id, conversation_id, message, response)
                 VALUES (?, ?, ?, ?)''', (user_id, conversation_id, message, response))
    conn.commit()
    conn.close()
    logger.info("Database updated successfully")
async def clear_inactive_conversations():
    while True:
        logger.info("Clearing inactive conversations")
        current_time = time.time()
        inactive_convos = [conv_id for conv_id, last_time in last_activity.items()
                           if current_time - last_time > 1800]  # 30 minutes
        for conv_id in inactive_convos:
            if conv_id in conversations:
                del conversations[conv_id]
            if conv_id in last_activity:
                del last_activity[conv_id]
        logger.info(f"Cleared {len(inactive_convos)} inactive conversations")
        await asyncio.sleep(60)  # Check every minute
@app.on_event("startup")
async def startup_event():
    logger.info("Starting up the application")
    FastAPICache.init(InMemoryBackend(), prefix="fastapi-cache")
    asyncio.create_task(clear_inactive_conversations())
@app.post("/coding_assistant")  # route path assumed from the handler name
async def coding_assistant(query: QueryModel, background_tasks: BackgroundTasks, api_key: str = Depends(verify_api_key)):
    """
    Coding assistant endpoint that provides programming help based on user queries.

    Available models:
    - meta-llama/llama-3-70b-instruct (default)
    - anthropic/claude-3.5-sonnet
    - deepseek/deepseek-coder
    - anthropic/claude-3-haiku
    - openai/gpt-3.5-turbo-instruct
    - qwen/qwen-72b-chat
    - google/gemma-2-27b-it
    - openai/gpt-4o-mini

    Requires API Key authentication via X-API-Key header.
    """
    logger.info(f"Received coding assistant query: {query.user_query}")
    if query.conversation_id not in conversations:
        conversations[query.conversation_id] = [
            {"role": "system", "content": "You are a helpful assistant proficient in coding tasks. Help the user in understanding and writing code."}
        ]
    conversations[query.conversation_id].append({"role": "user", "content": query.user_query})
    last_activity[query.conversation_id] = time.time()

    # chat_with_llama_stream trims this history to the token budget as it streams.
    limited_conversation = conversations[query.conversation_id]

    def process_response():
        full_response = ""
        for content in chat_with_llama_stream(limited_conversation, model=query.model_id):
            full_response += content
            yield content
        background_tasks.add_task(update_db, query.user_id, query.conversation_id, query.user_query, full_response)
        logger.info(f"Completed coding assistant response for query: {query.user_query}")

    return StreamingResponse(process_response(), media_type="text/event-stream")
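# Example client usage (illustrative; assumes the route above and a local
# server on port 7860):
#   import requests
#   resp = requests.post(
#       "http://localhost:7860/coding_assistant",
#       headers={"X-API-Key": "<your CHAT_AUTH_KEY>"},
#       json={"user_query": "How do I reverse a list in Python?", "user_id": "user123"},
#       stream=True,
#   )
#   for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
#       print(chunk, end="")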
# New functions for news assistant

def internet_search(query, search_type="web", num_results=20):
    logger.info(f"Performing internet search for query: {query}, type: {search_type}")
    if search_type == "web":
        url = "https://api.search.brave.com/res/v1/web/search"
    else:
        url = "https://api.search.brave.com/res/v1/news/search"

    headers = {
        "Accept": "application/json",
        "Accept-Encoding": "gzip",
        "X-Subscription-Token": api_keys["BRAVE_API_KEY"]
    }
    params = {"q": query}

    response = requests.get(url, headers=headers, params=params)
    if response.status_code != 200:
        logger.error(f"Failed to fetch search results. Status code: {response.status_code}")
        return []

    if search_type == "web":
        search_data = response.json()["web"]["results"]
    else:
        search_data = response.json()["results"]

    processed_results = []
    for item in search_data:
        # Skip results without snippets; they carry too little context for the LLM.
        if not item.get("extra_snippets"):
            continue
        result = {
            "title": item["title"],
            "snippet": item["extra_snippets"][0],
            "last_updated": item.get("age", "")
        }
        processed_results.append(result)

    logger.info(f"Retrieved {len(processed_results)} search results")
    return processed_results[:num_results]
@cache(expire=3600)  # TTL is an assumption; the fastapi_cache decorator import suggests this function was cached
def cached_internet_search(query: str):
    logger.info(f"Performing cached internet search for query: {query}")
    return internet_search(query, search_type="news")
def analyze_news(query):
    logger.info(f"Analyzing news for query: {query}")
    news_data = cached_internet_search(query)

    if not news_data:
        logger.error("Failed to fetch news data")
        # Return an empty list so callers can detect failure with `if not messages`.
        return []

    # Build the chat messages; generate_news_prompt includes today's date.
    prompt = generate_news_prompt(query, news_data)
    messages = [
        {"role": "system", "content": NEWS_ASSISTANT_PROMPT},
        {"role": "user", "content": prompt}
    ]
    logger.info("News analysis completed")
    return messages
@app.post("/news_assistant")  # route path assumed from the handler name
async def news_assistant(query: NewsQueryModel, api_key: str = Depends(verify_api_key)):
    """
    News assistant endpoint that provides summaries and analysis of recent news based on user queries.

    Requires API Key authentication via X-API-Key header.
    """
    logger.info(f"Received news assistant query: {query.query}")
    messages = analyze_news(query.query)

    if not messages:
        logger.error("Failed to fetch news data")
        raise HTTPException(status_code=500, detail="Failed to fetch news data")

    def process_response():
        for content in chat_with_llama_stream(messages, model=query.model_id):
            yield content
        logger.info(f"Completed news assistant response for query: {query.query}")

    return StreamingResponse(process_response(), media_type="text/event-stream")
class SearchQueryModel(BaseModel):
    query: str = Field(..., description="Search query")
    model_id: ModelID = Field(
        default="openai/gpt-4o-mini",
        description="ID of the model to use for response generation"
    )

    class Config:
        schema_extra = {
            "example": {
                "query": "What are the latest advancements in quantum computing?",
                "model_id": "meta-llama/llama-3-70b-instruct"
            }
        }
def analyze_search_results(query):
    logger.info(f"Analyzing search results for query: {query}")
    search_data = internet_search(query, search_type="web")

    if not search_data:
        logger.error("Failed to fetch search data")
        # Return an empty list so callers can detect failure with `if not messages`.
        return []

    # Prepare the prompt for the AI
    prompt = generate_search_prompt(query, search_data)
    messages = [
        {"role": "system", "content": SEARCH_ASSISTANT_PROMPT},
        {"role": "user", "content": prompt}
    ]
    logger.info("Search results analysis completed")
    return messages
@app.post("/search_assistant")  # route path assumed from the handler name
async def search_assistant(query: SearchQueryModel, api_key: str = Depends(verify_api_key)):
    """
    Search assistant endpoint that provides summaries and analysis of web search results based on user queries.

    Requires API Key authentication via X-API-Key header.
    """
    logger.info(f"Received search assistant query: {query.query}")
    messages = analyze_search_results(query.query)

    if not messages:
        logger.error("Failed to fetch search data")
        raise HTTPException(status_code=500, detail="Failed to fetch search data")

    def process_response():
        logger.info(f"Generating response using LLM: {messages}")
        full_response = ""
        for content in chat_with_llama_stream(messages, model=query.model_id):
            full_response += content
            yield content
        logger.info(f"Completed search assistant response for query: {query.query}")
        logger.info(f"LLM Response: {full_response}")

    return StreamingResponse(process_response(), media_type="text/event-stream")
class FollowupQueryModel(BaseModel):
    query: str = Field(..., description="User's query for the followup agent")
    model_id: ModelID = Field(
        default="openai/gpt-4o-mini",
        description="ID of the model to use for response generation"
    )
    conversation_id: str = Field(default_factory=lambda: str(uuid4()), description="Unique identifier for the conversation")
    user_id: str = Field(..., description="Unique identifier for the user")

    class Config:
        schema_extra = {
            "example": {
                "query": "How can I improve my productivity?",
                "model_id": "openai/gpt-4o-mini",
                "conversation_id": "123e4567-e89b-12d3-a456-426614174000",
                "user_id": "user123"
            }
        }
FOLLOWUP_AGENT_PROMPT = """ | |
You are a helpful assistant with the following skills, use them, as necessary. If the user request needs further clarification, analyze it and generate clarifying questions with options. Else respond with a helpful answer. <response>response to user request in markdown</response> <clarification> questions: - text: [First clarifying question] options: - [Option 1] - [Option 2] - [Option 3] - [Option 4 (if needed)] - text: [Second clarifying question] options: - [Option 1] - [Option 2] - [Option 3] # Add more questions as needed # make sure this section is in valid YAML format </clarification> | |
""" | |
def parse_followup_response(input_text):
    # Extract <response>...</response> and <clarification>...</clarification> blocks.
    response_pattern = re.compile(r'<response>(.*?)</response>', re.DOTALL)
    clarification_pattern = re.compile(r'<clarification>(.*?)</clarification>', re.DOTALL)

    response_matches = response_pattern.finditer(input_text)
    clarification_matches = clarification_pattern.finditer(input_text)

    last_end = 0
    combined_response = ""
    parsed_clarifications = []

    # Combine responses and capture everything in between
    for response_match in response_matches:
        combined_response += input_text[last_end:response_match.start()].strip() + "\n"
        combined_response += response_match.group(1).strip() + "\n"
        last_end = response_match.end()

    # Parse each clarification block's YAML-style question list
    for clarification_match in clarification_matches:
        combined_response += input_text[last_end:clarification_match.start()].strip() + "\n"
        clarification_text = clarification_match.group(1).strip()
        if clarification_text:
            # Split by "- text:" to separate each question block
            question_blocks = clarification_text.split("- text:")
            for block in question_blocks[1:]:
                # The question runs up to the "options:" keyword
                question_match = re.search(r'^(.*?)\s*options:', block, re.DOTALL)
                if question_match:
                    question = question_match.group(1).strip()
                    options_match = re.search(r'options:\s*(.*?)$', block, re.DOTALL)
                    if options_match:
                        options = [option.strip() for option in options_match.group(1).split('-') if option.strip()]
                        parsed_clarifications.append({'question': question, 'options': options})
        last_end = clarification_match.end()

    # Capture any remaining text after the last tag
    combined_response += input_text[last_end:].strip()

    return combined_response.strip(), parsed_clarifications
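# Example (illustrative) of what the parser produces:
#   text = ("<response>Here are some tips.</response>\n"
#           "<clarification>\nquestions:\n  - text: What area?\n"
#           "    options:\n      - Work\n      - Study\n</clarification>")
#   parse_followup_response(text)
#   # -> ("Here are some tips.",
#   #     [{'question': 'What area?', 'options': ['Work', 'Study']}])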
@app.post("/followup_agent")  # route path assumed from the handler name
async def followup_agent(query: FollowupQueryModel, background_tasks: BackgroundTasks, api_key: str = Depends(verify_api_key)):
    """
    Followup agent endpoint that provides helpful responses or generates clarifying questions based on user queries.

    Requires API Key authentication via X-API-Key header.
    """
    logger.info(f"Received followup agent query: {query.query}")
    if query.conversation_id not in conversations:
        conversations[query.conversation_id] = [
            {"role": "system", "content": FOLLOWUP_AGENT_PROMPT}
        ]
    conversations[query.conversation_id].append({"role": "user", "content": query.query})
    last_activity[query.conversation_id] = time.time()

    # chat_with_llama_stream trims this history to the token budget as it streams.
    limited_conversation = conversations[query.conversation_id]

    def process_response():
        full_response = ""
        for content in chat_with_llama_stream(limited_conversation, model=query.model_id):
            full_response += content
            yield content

        logger.info(f"LLM RAW response for query: {query.query}: {full_response}")
        response_content, clarification = parse_followup_response(full_response)

        result = {
            "response": response_content,
            "clarification": clarification
        }
        yield "\n\n" + json.dumps(result)

        # chat_with_llama_stream already appended the assistant turn to this
        # conversation's history, so no second append is needed here.
        background_tasks.add_task(update_db, query.user_id, query.conversation_id, query.query, full_response)
        logger.info(f"Completed followup agent response for query: {query.query}")

    return StreamingResponse(process_response(), media_type="text/event-stream")
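# Note on the stream format (illustrative): clients first receive the raw LLM
# text, then a blank line followed by a JSON trailer, e.g.
#   ...streamed markdown...
#
#   {"response": "...", "clarification": [{"question": "...", "options": ["..."]}]}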
if __name__ == "__main__":
    import uvicorn
    logger.info("Starting the application")
    uvicorn.run(app, host="0.0.0.0", port=7860)