File size: 9,411 Bytes
b1c8f17
26e0ddc
2e76cf7
26e0ddc
 
 
 
 
 
2e76cf7
 
22917d7
26e0ddc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b1c8f17
26e0ddc
b1c8f17
2e76cf7
 
 
 
b1c8f17
26e0ddc
1fc729a
26e0ddc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bdd570c
f87e22f
26e0ddc
 
 
 
 
 
00eb6f3
3963c7b
26e0ddc
1fc729a
7a5d35d
26e0ddc
 
7a5d35d
26e0ddc
 
 
 
 
 
 
 
 
 
005f2d7
26e0ddc
005f2d7
26e0ddc
 
 
 
 
 
 
0b115cd
b1c8f17
0b115cd
00eb6f3
b998742
26e0ddc
b1c8f17
1fc729a
26e0ddc
 
 
 
 
 
 
 
c8510e0
 
 
 
 
26e0ddc
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
import os
import re
from datetime import datetime, timedelta
from functools import lru_cache
from typing import Any, Dict, List, Literal, Optional

from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, Request, Query, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi_cache import FastAPICache
from fastapi_cache.backends.inmemory import InMemoryBackend
from fastapi_cache.decorator import cache
from pydantic import BaseModel, Field

from helper_functions_api import (
    has_tables, extract_data_from_tag, openrouter_response, md_to_html,
    search_brave, fetch_and_extract_content, limit_tokens, together_response, insert_data
)

# Load environment variables from a .env file (python-dotenv) before anything
# below reads os.environ / os.getenv, e.g. get_api_keys().
load_dotenv()

# Constants
# Model identifiers grouped by provider tier ("default" / "fallback") and size
# ("small" / "medium"). generate_report() uses LLM_MODELS["default"]["medium"];
# the other entries are not referenced in the visible part of this file.
LLM_MODELS = dict(
    default=dict(
        small="llama3-8b-8192",
        medium="llama3-70b-8192",
    ),
    fallback=dict(
        small="meta-llama/Llama-3-8b-chat-hf",
        medium="meta-llama/Llama-3-70b-chat-hf",
    ),
)

# System prompts keyed by response style. generate_report() selects "online" or
# "offline"; the remaining keys are not referenced in the visible part of this file.
# NOTE: inside the triple-quoted "online" prompt the leading spaces of the
# continuation lines are part of the text sent to the model - do not re-indent.
SYSTEM_PROMPTS = dict(
    json="You are now in the role of an expert AI who can extract structured information from user request. Both key and value pairs must be in double quotes. You must respond ONLY with a valid JSON file. Do not add any additional comments.",
    list="You are now in the role of an expert AI who can extract structured information from user request. All elements must be in double quotes. You must respond ONLY with a valid python List. Do not add any additional comments.",
    default="You are an expert AI, complete the given task. Do not add any additional comments.",
    md="You are an expert AI who can create a structured report using information provided in the context from user request. The report should be in markdown format consists of markdown tables structured into subtopics. Do not add any additional comments.",
    online="""You are an expert AI who can create a detailed structured report using internet search results.
                1. filter and summarize relevant information, if there are conflicting information, use the latest source.
                2. use it to construct a clear and factual answer.
                Your response should be structured and properly formatted using markdown headings, subheadings, tables, use as necessary. Ignore Links and references""",
    offline="You are an expert AI who can create detailed answers. Your response should be properly formatted and well readable using markdown formatting.",
)

# Prompt templates
# User-prompt templates, looked up as PROMPT_TEMPLATES[mode][output_format] by
# generate_report(). "online" templates take {description} and {reference};
# "offline" templates take only {description}.
PROMPT_TEMPLATES = dict(
    online=dict(
        chat="Write a well thought out, detailed and structured answer to the query:: {description} #### , refer the provided internet search results reference:{reference}",
        report="Write a well thought out, detailed and structured Report to the query:: {description} #### , refer the provided internet search results reference:{reference}, The report should be well formatted using markdown format structured into subtopics as necessary",
        report_table="Write a well thought out Report to the query:: {description},#### , refer the provided internet search results reference:{reference}. The report should be well formatted using markdown format, structured into subtopics, include tables or lists as needed to make it well readable",
    ),
    offline=dict(
        chat="Write a well thought out, detailed and structured answer to the query:: {description}",
        report="Write a well thought out, detailed and structured Report to the query:: {description}. The report should be well formatted using markdown format, structured into subtopics",
        report_table="Write a detailed and structured Report to the query:: {description}, The report should be well formatted using markdown format, structured into subtopics, include tables or lists as needed to make it well readable",
    ),
)

# FastAPI app setup: the application object that the route and middleware
# registrations below attach to.
app = FastAPI()

# NOTE(review): recent FastAPI releases document `on_event` as deprecated in
# favour of lifespan handlers - confirm against the pinned FastAPI version.
@app.on_event("startup")
async def startup() -> None:
    """Initialise fastapi-cache when the application starts.

    Configures FastAPICache with an InMemoryBackend and the key prefix
    "fastapi-cache"; the @cache decorator on generate_report relies on this
    having run.
    """
    # NOTE(review): an in-memory backend presumably means the cache is
    # per-process and lost on restart - confirm this is acceptable for deployment.
    FastAPICache.init(InMemoryBackend(), prefix="fastapi-cache")

# Pydantic model for query parameters
#
# Request body accepted by POST /generate_report. Only `description` is
# required; every other field has a default.
#
# `output_format` is restricted to the keys that exist in PROMPT_TEMPLATES
# ("chat", "report", "report_table"), so an unsupported value is rejected by
# request validation instead of failing later as a KeyError inside
# generate_report().
# (Documented with comments rather than a class docstring so the generated
# OpenAPI schema for the model is not altered.)
class QueryModel(BaseModel):
    user_query: str = Field(default="", description="Initial user query")
    topic: str = Field(default="", description="Topic name to generate Report")
    description: str = Field(..., description="Description/prompt for report (REQUIRED)")
    # generate_report() skips insert_data() when user_id == "test".
    user_id: str = Field(default="", description="unique user id")
    user_name: str = Field(default="", description="user name")
    # True -> search the web and ground the report in the results.
    internet: bool = Field(default=True, description="Enable Internet search")
    # Must match a key of PROMPT_TEMPLATES["online"] / PROMPT_TEMPLATES["offline"].
    output_format: Literal["chat", "report", "report_table"] = Field(
        default="report_table", description="Output format for the report"
    )
    # Passed through to fetch_and_extract_content() via get_internet_data().
    data_format: str = Field(default="Structured data", description="Type of data to extract from the internet")
    generate_charts: bool = Field(default=False, description="Include generated charts")
    output_as_md: bool = Field(default=False, description="Output report in markdown (default output in HTML)")

    class Config:
        # Example payload attached to the model's schema.
        schema_extra = {
            "example": {
                "user_query": "How does climate change affect biodiversity?",
                "topic": "Climate Change and Biodiversity",
                "description": "Provide a detailed report on the impacts of climate change on global biodiversity",
                "user_id": "user123",
                "user_name": "John Doe",
                "internet": True,
                "output_format": "report_table",
                "data_format": "Structured data",
                "generate_charts": True,
                "output_as_md": False
            }
        }

@lru_cache()
def get_api_keys() -> Dict[str, Optional[str]]:
    """Collect the service credentials from the environment.

    The result is memoised by ``lru_cache``, so the environment is read once
    per process (a call that raises is not cached and is retried next time).

    TOGETHER_API_KEY, BRAVE_API_KEY, GROQ_API_KEY and HELICON_API_KEY are
    optional and are ``None`` when unset. SUPABASE_USER, SUPABASE_PASSWORD and
    OPENROUTER_API_KEY are required.

    Returns:
        Mapping of credential name to value. ``OPENROUTER_API_KEY`` always
        carries the ``sk-or-v1-`` prefix, whether or not the environment
        variable already included it.

    Raises:
        KeyError: if any required variable is missing; the message lists every
            missing name rather than only the first one encountered.
    """
    required = ("SUPABASE_USER", "SUPABASE_PASSWORD", "OPENROUTER_API_KEY")
    missing = [name for name in required if name not in os.environ]
    if missing:
        raise KeyError(f"Missing required environment variables: {', '.join(missing)}")

    # Accept both the bare key and a value that already has the prefix, so the
    # prefix is never applied twice.
    openrouter_prefix = "sk-or-v1-"
    openrouter_key = os.environ["OPENROUTER_API_KEY"]
    if not openrouter_key.startswith(openrouter_prefix):
        openrouter_key = f"{openrouter_prefix}{openrouter_key}"

    return {
        "TOGETHER_API_KEY": os.getenv('TOGETHER_API_KEY'),
        "BRAVE_API_KEY": os.getenv('BRAVE_API_KEY'),
        "GROQ_API_KEY": os.getenv("GROQ_API_KEY"),
        "HELICON_API_KEY": os.getenv("HELICON_API_KEY"),
        "SUPABASE_USER": os.environ['SUPABASE_USER'],
        "SUPABASE_PASSWORD": os.environ['SUPABASE_PASSWORD'],
        "OPENROUTER_API_KEY": openrouter_key,
    }

def get_internet_data(description: str, data_format: str):
    """Search the web for ``description`` and gather text to ground a report.

    Characters that are neither word characters nor whitespace are removed
    from the description (and the result stripped) before it is passed to
    ``search_brave`` with ``num_results=8``. The returned URLs are then handed
    to ``fetch_and_extract_content`` together with ``data_format`` and the
    optimised query.

    Returns:
        A 4-tuple ``(all_text_with_urls, optimized_search_query,
        full_search_object, reference)`` where ``reference`` is
        ``str(all_text_with_urls)`` passed through ``limit_tokens`` with
        ``token_limit=5000``.
    """
    punctuation = re.compile(r'[^\w\s]')
    cleaned_query = punctuation.sub('', description).strip()

    search_result = search_brave(cleaned_query, num_results=8)
    urls, optimized_search_query, full_search_object = search_result

    all_text_with_urls = fetch_and_extract_content(data_format, urls, optimized_search_query)
    reference = limit_tokens(str(all_text_with_urls), token_limit=5000)

    return (all_text_with_urls, optimized_search_query, full_search_object, reference)

def generate_charts(md_report: str):
    """Ask an LLM to turn numeric tables in a markdown report into plotly.js charts.

    The report is appended to a fixed instruction and sent as a single user
    message through ``openrouter_response`` (model
    "anthropic/claude-3.5-sonnet"). The reply is passed to
    ``extract_data_from_tag`` with the tag name "report" and that result is
    returned.
    """
    instructions = "".join([
        "Convert the numerical data tables in the given content to embedded html plotly.js charts if appropriate, ",
        "use appropriate colors. Output format: <report>output the full content without any other changes in md ",
        "format enclosed in tags like this</report> using the following: ",
    ])
    chart_prompt = f"{instructions}{md_report}"

    llm_reply = openrouter_response(
        [{"role": "user", "content": chart_prompt}],
        model="anthropic/claude-3.5-sonnet",
    )
    return extract_data_from_tag(llm_reply, "report")

@cache(expire=604800)  # 604800 s = 7 days
async def generate_report(query: QueryModel, api_keys: Dict[str, str] = Depends(get_api_keys)):
    """Build a report for ``query``, optionally grounded in web search results.

    Flow:
      1. Pick the user-prompt template and system prompt for online/offline mode.
      2. When ``query.internet`` is set, call ``get_internet_data``; on any
         failure there, log it and fall back to the offline prompts.
      3. Generate the markdown report with ``together_response`` using
         ``LLM_MODELS["default"]["medium"]``.
      4. When ``query.generate_charts`` is set and the rendered report contains
         tables, try to embed charts via ``generate_charts``; on failure (or an
         empty result) the un-charted report is kept.
      5. Persist the run with ``insert_data`` unless ``query.user_id == "test"``.
      6. Return the report as markdown or HTML depending on ``query.output_as_md``.

    Args:
        query: The request parameters.
        api_keys: Not used in the body; kept so the signature (and the
            arguments the caching decorator sees) stays unchanged.

    Returns:
        A dict with keys ``"report"``, ``"references"`` (URL -> HTML string),
        ``"search_query"`` and ``"search_data_full"``.

    Raises:
        ValueError: if ``query.output_format`` has no prompt template.
        Exception: anything raised by ``together_response``, ``insert_data`` or
            ``md_to_html`` propagates to the caller.
    """
    # NOTE(review): the helper calls below are plain (non-awaited) calls inside
    # an async function; if they perform blocking network I/O they will hold up
    # the event loop for the duration - confirm and consider a thread pool.
    internet_mode = "online" if query.internet else "offline"
    templates = PROMPT_TEMPLATES[internet_mode]
    if query.output_format not in templates:
        # Fail with an actionable message instead of a bare KeyError('<format>').
        raise ValueError(
            f"Unsupported output_format {query.output_format!r}; "
            f"expected one of: {', '.join(sorted(templates))}"
        )
    user_prompt = templates[query.output_format]
    system_prompt = SYSTEM_PROMPTS[internet_mode]

    # Defaults used when internet search is disabled or fails.
    all_text_with_urls = []
    optimized_search_query = ""
    full_search_object = {}

    if query.internet:
        try:
            all_text_with_urls, optimized_search_query, full_search_object, reference = get_internet_data(query.description, query.data_format)
            user_prompt = user_prompt.format(description=query.description, reference=reference)
        except Exception as e:
            # Search/scrape is best-effort: degrade to an offline answer.
            print(f"Failed to search/scrape results: {e}")
            internet_mode = "offline"
            user_prompt = PROMPT_TEMPLATES[internet_mode][query.output_format].format(description=query.description)
            system_prompt = SYSTEM_PROMPTS[internet_mode]
    else:
        user_prompt = user_prompt.format(description=query.description)

    md_report = together_response(user_prompt, model=LLM_MODELS["default"]["medium"], SysPrompt=system_prompt)

    if query.generate_charts and has_tables(md_to_html(md_report)):
        try:
            charted_report = generate_charts(md_report)
        except Exception as e:
            print(f"Failed to generate charts: {e}")
        else:
            # Never replace a finished report with an empty chart result
            # (e.g. when the expected <report> tag was not found in the reply).
            if charted_report:
                md_report = charted_report
            else:
                print("Chart generation returned no content; keeping the original report")

    # NOTE(review): a failure in insert_data aborts the request even though the
    # report has already been generated - decide whether persistence should be
    # best-effort like the search and chart steps.
    if query.user_id != "test":
        insert_data(query.user_id, query.topic, query.description, str(all_text_with_urls), md_report)

    # Each item of all_text_with_urls is unpacked as (text, url).
    references_html = {url: str(md_to_html(text)) for text, url in all_text_with_urls}
    final_report = md_report if query.output_as_md else md_to_html(md_report)

    return {
        "report": final_report,
        "references": references_html,
        "search_query": optimized_search_query,
        "search_data_full": full_search_object
    }

@app.post("/generate_report", response_model=Dict[str, Any])
async def api_generate_report(query: QueryModel, api_keys: Dict[str, str] = Depends(get_api_keys)):
    """POST /generate_report: generate a report for the submitted query.

    Delegates to ``generate_report`` and returns its result dict.

    Raises:
        HTTPException: an ``HTTPException`` raised further down the stack is
            passed through unchanged; any other exception is reported as a
            500 whose detail is ``str(e)``.
    """
    try:
        return await generate_report(query, api_keys)
    except HTTPException:
        # Keep a deliberately chosen status code instead of flattening it to 500.
        raise
    except Exception as e:
        # Chain the cause so the original traceback is preserved in server logs.
        raise HTTPException(status_code=500, detail=str(e)) from e

# CORS middleware setup
# Allows every origin, method and header, with credentials enabled.
# NOTE(review): wildcard origins combined with allow_credentials=True is a very
# permissive policy (any site may make credentialed cross-origin requests) -
# confirm whether credentials are needed at all, or list explicit origins.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)