import gradio as gr
from haystack.document_stores import FAISSDocumentStore
from haystack.nodes import EmbeddingRetriever
import openai
import pandas as pd
import os
from utils import (
    make_pairs,
    set_openai_api_key,
    create_user_id,
    to_completion,
)
from datetime import datetime

# from azure.storage.fileshare import ShareServiceClient

try:
    from dotenv import load_dotenv

    load_dotenv()
except ImportError:
    pass
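# NOTE (assumption): the helpers imported from `utils` are not defined in this file.
# From the way they are used below, they are assumed to behave roughly as follows:
#   - make_pairs(messages): groups a flat [user, assistant, user, assistant, ...] list
#     of message contents into (user, assistant) pairs for gr.Chatbot.
#   - to_completion(messages): flattens chat-style {"role", "content"} dicts into a
#     single prompt string for the Completion API.
#   - create_user_id(n): returns a random identifier string of length n.
#   - set_openai_api_key is imported but not called here.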
theme = gr.themes.Soft(
    primary_hue="sky",
    font=[gr.themes.GoogleFont("Poppins"), "ui-sans-serif", "system-ui", "sans-serif"],
)

init_prompt = (
    "TKOQA, an AI Assistant for Tikehau. "
)

sources_prompt = (
    "When relevant, use facts and numbers from the following documents in your answer. "
)
def get_reformulation_prompt(query: str) -> str:
    return f"""Reformulate the following user message to be a short standalone question in English, in the context of the Universal Registration Document of Tikehau.
---
query: what is the AUM of Tikehau in 2022?
standalone question: What is the AUM of Tikehau in 2022?
language: English
---
query: what is T2?
standalone question: what is the transition energy fund at Tikehau?
language: English
---
query: what is the business of Tikehau?
standalone question: What are the main business units of Tikehau?
language: English
---
query: {query}
standalone question:"""
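# The few-shot prompt above ends with "standalone question:", so the completion is
# expected to contain the reformulated question on the first line followed by a
# "language: ..." line; chat() below relies on exactly this two-line format when it
# splits the completion text on "\n".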
system_template = {
    "role": "system",
    "content": init_prompt,
}

openai.api_key = os.environ["OPENAI_API_KEY"]
# BHO
# openai.api_base = os.environ["ressource_endpoint"]
# openai.api_version = "2022-12-01"
document_store = FAISSDocumentStore()

ds = FAISSDocumentStore.load(
    index_path="./tko_urd.faiss",
    config_path="./tko_urd.json",
)

retriever = EmbeddingRetriever(
    document_store=ds,
    embedding_model="sentence-transformers/multi-qa-mpnet-base-dot-v1",
    model_format="sentence_transformers",
    progress_bar=False,
)
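# The ./tko_urd.faiss and ./tko_urd.json files loaded above are assumed to have been
# built offline. A minimal sketch of how such an index could be created with Haystack
# (document contents and file names below are placeholders, not part of this app):
#
#     store = FAISSDocumentStore(faiss_index_factory_str="Flat")
#     store.write_documents(
#         [{"content": page_text, "meta": {"file_name": "URD_2022.pdf", "page_number": 12}}]
#     )
#     store.update_embeddings(
#         EmbeddingRetriever(
#             document_store=store,
#             embedding_model="sentence-transformers/multi-qa-mpnet-base-dot-v1",
#             model_format="sentence_transformers",
#         )
#     )
#     store.save(index_path="./tko_urd.faiss", config_path="./tko_urd.json")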
# retrieve_giec = EmbeddingRetriever(
#     document_store=FAISSDocumentStore.load(
#         index_path="./documents/climate_gpt_v2_only_giec.faiss",
#         config_path="./documents/climate_gpt_v2_only_giec.json",
#     ),
#     embedding_model="sentence-transformers/multi-qa-mpnet-base-dot-v1",
#     model_format="sentence_transformers",
# )

# BHO
# For Azure connection in secrets in HuggingFace
# credential = {
#     "account_key": os.environ["account_key"],
#     "account_name": os.environ["account_name"],
# }

# BHO
# account_url = os.environ["account_url"]
# file_share_name = "climategpt"
# service = ShareServiceClient(account_url=account_url, credential=credential)
# share_client = service.get_share_client(file_share_name)

user_id = create_user_id(10)
def filter_sources(df, k_summary=3, k_total=10, source="ipcc"):
    """Filter retrieved passages by source and keep the top passages."""
    assert source in ["ipcc", "ipbes", "all"]

    # Filter by source
    if source == "ipcc":
        df = df.loc[df["source"] == "IPCC"]
    elif source == "ipbes":
        df = df.loc[df["source"] == "IPBES"]
    else:
        pass

    # Prepare summaries
    df_summaries = df  # .loc[df.loc.obj.values]

    # Separate summaries and full reports
    # df_summaries = df.loc[df["report_type"].isin(["SPM", "TS"])]
    # df_full = df.loc[~df["report_type"].isin(["SPM", "TS"])]

    # Find passages from summaries dataset
    passages_summaries = df_summaries.head(k_summary)

    # Find passages from full reports dataset
    # passages_fullreports = df_full.head(k_total - len(passages_summaries))

    # Concatenate passages
    # passages = pd.concat([passages_summaries, passages_fullreports], axis=0, ignore_index=True)
    passages = passages_summaries

    return passages
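# Hypothetical usage sketch of filter_sources, assuming the retrieved metadata carries
# a "source" column:
#
#     example = pd.DataFrame(
#         [
#             {"source": "IPCC", "content": "passage A", "score": 0.70},
#             {"source": "IPBES", "content": "passage B", "score": 0.65},
#         ]
#     )
#     filter_sources(example, k_summary=1, source="all")  # -> the first row only
#
# In this app, chat() always calls it with source="all", so no source filtering is
# applied and the top k_summary passages are returned as-is.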
def retrieve_with_summaries(
    query,
    retriever,
    k_summary=3,
    k_total=10,
    source="ipcc",
    max_k=100,
    threshold=0.555,
    as_dict=True,
):
    """Retrieve passages above the similarity threshold and filter them by source."""
    assert max_k > k_total

    docs = retriever.retrieve(query, top_k=max_k)
    docs = [{**x.meta, "score": x.score, "content": x.content} for x in docs if x.score > threshold]
    if len(docs) == 0:
        return []
    res = pd.DataFrame(docs)
    passages_df = filter_sources(res, k_summary, k_total, source)

    if as_dict:
        contents = passages_df["content"].tolist()
        meta = passages_df.drop(columns=["content"]).to_dict(orient="records")
        passages = []
        for i in range(len(contents)):
            passages.append({"content": contents[i], "meta": meta[i]})
        return passages
    else:
        return passages_df
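# Hypothetical call, assuming the FAISS index and embedding model above are loaded:
#
#     passages = retrieve_with_summaries(
#         "What is the AUM of Tikehau in 2022?",
#         retriever,
#         k_total=10,
#         k_summary=3,
#         source="all",
#         threshold=0.555,
#         as_dict=True,
#     )
#     # -> [{"content": "...", "meta": {"file_name": "...", "page_number": ..., "score": ...}}, ...]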
def make_html_source(source, i):
    meta = source["meta"]
    return f"""
<div class="card">
    <div class="card-content">
        <h2>Doc {i} - {meta['file_name']} - Page {meta['page_number']}</h2>
        <p>{source['content']}</p>
    </div>
</div>
"""
def chat(
    user_id: str,
    query: str,
    history: list = [system_template],
    report_type: str = "All available",
    threshold: float = 0.555,
) -> tuple:
    """Retrieve relevant documents in the document store, then query GPT.

    Args:
        query (str): user message.
        history (list, optional): history of the conversation. Defaults to [system_template].
        report_type (str, optional): should be "All available" or "IPCC only". Defaults to "All available".
        threshold (float, optional): similarity threshold, don't increase above 0.568. Defaults to 0.555.

    Yields:
        tuple: chat in Gradio format, chat in OpenAI format, sources used.
    """
    if report_type not in ["IPCC", "IPBES"]:
        report_type = "all"
    print("Searching in", report_type, "reports")
    reformulated_query = openai.Completion.create(
        engine="text-davinci-003",
        prompt=get_reformulation_prompt(query),
        temperature=0,
        max_tokens=128,
        stop=["\n---\n", "<|im_end|>"],
    )
    reformulated_query = reformulated_query["choices"][0]["text"]
    reformulated_query, language = reformulated_query.split("\n")
    language = language.split(":")[1].strip()

    sources = retrieve_with_summaries(
        reformulated_query,
        retriever,
        k_total=10,
        k_summary=3,
        as_dict=True,
        source=report_type.lower(),
        threshold=threshold,
    )
    response_retriever = {
        "language": language,
        "reformulated_query": reformulated_query,
        "query": query,
        "sources": sources,
    }

    # docs = [d for d in retriever.retrieve(query=reformulated_query, top_k=10) if d.score > threshold]
    messages = history + [{"role": "user", "content": query}]

    if len(sources) > 0:
        docs_string = []
        docs_html = []
        for i, d in enumerate(sources, 1):
            # docs_string.append(f"📃 Doc {i}: {d['meta']['short_name']} page {d['meta']['page_number']}\n{d['content']}")
            docs_string.append(f"📃 Doc {i}: {d['meta']['file_name']} page {d['meta']['page_number']}\n{d['content']}")
            docs_html.append(make_html_source(d, i))
        docs_string = "\n\n".join([f"Query used for retrieval:\n{reformulated_query}"] + docs_string)
        docs_html = "\n\n".join([f"Query used for retrieval:\n{reformulated_query}"] + docs_html)

        messages.append({"role": "system", "content": f"{sources_prompt}\n\n{docs_string}\n\nAnswer in {language}:"})

        response = openai.Completion.create(
            # engine="climateGPT",
            engine="text-davinci-003",
            prompt=to_completion(messages),
            temperature=0,  # deterministic
            stream=True,
            max_tokens=1024,
        )

        complete_response = ""
        messages.pop()
        messages.append({"role": "assistant", "content": complete_response})

        timestamp = str(datetime.now().timestamp())
        file = user_id[0] + timestamp + ".json"
        logs = {
            "user_id": user_id[0],
            "prompt": query,
            "retrieved": sources,
            "report_type": report_type,
            "prompt_eng": messages[0],
            "answer": messages[-1]["content"],
            "time": timestamp,
        }
        # log_on_azure(file, logs, share_client)
        print(logs)
        for chunk in response:
            if (chunk_message := chunk["choices"][0].get("text")) and chunk_message != "<|im_end|>":
                complete_response += chunk_message
                messages[-1]["content"] = complete_response
                gradio_format = make_pairs([a["content"] for a in messages[1:]])
                yield gradio_format, messages, docs_html

    else:
        docs_string = "⚠️ No relevant passages found in the URDs"
        complete_response = "**⚠️ No relevant passages found in the URDs**"
        messages.append({"role": "assistant", "content": complete_response})
        gradio_format = make_pairs([a["content"] for a in messages[1:]])
        yield gradio_format, messages, docs_string
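# Hypothetical manual invocation outside Gradio (requires OPENAI_API_KEY and the FAISS
# index on disk). chat() is a generator, so the answer streams chunk by chunk:
#
#     for gradio_history, openai_messages, sources_html in chat(
#         [user_id], "What is the AUM of Tikehau in 2022?"
#     ):
#         pass  # gradio_history holds the (user, assistant) pairs built so far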
def save_feedback(feed: str, user_id):
    if len(feed) > 1:
        timestamp = str(datetime.now().timestamp())
        file = user_id[0] + timestamp + ".json"
        logs = {
            "user_id": user_id[0],
            "feedback": feed,
            "time": timestamp,
        }
        # log_on_azure(file, logs, share_client)
        print(logs)
        return "Feedback submitted, thank you!"


def reset_textbox():
    return gr.update(value="")


# def log_on_azure(file, logs, share_client):
#     file_client = share_client.get_file_client(file)
#     file_client.upload_file(str(logs))
with gr.Blocks(title="TKO URD Q&A", css="style.css", theme=theme) as demo:
    user_id_state = gr.State([user_id])

    # Gradio
    gr.Markdown("<h1><center>Tikehau Capital Q&A</center></h1>")

    with gr.Row():
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(elem_id="chatbot", label="Tikehau Capital Q&A chatbot", show_label=False)
            state = gr.State([system_template])

            with gr.Row():
                ask = gr.Textbox(
                    show_label=True,
                    placeholder="Ask your Tikehau-related question here and press Enter",
                ).style(container=False)
                # ask_examples_hidden = gr.Textbox(elem_id="hidden-message")

            # examples_questions = gr.Examples(
            #     [
            #         "What is the AUM of Tikehau in 2022?",
            #     ],
            #     [ask_examples_hidden],
            #     examples_per_page=15,
            # )

        with gr.Column(scale=1, variant="panel"):
            gr.Markdown("### Sources")
            sources_textbox = gr.Markdown(show_label=False)

    # dropdown_sources = gr.inputs.Dropdown(
    #     ["IPCC", "IPBES", "ALL"],
    #     default="ALL",
    #     label="Select reports",
    # )
    dropdown_sources = gr.State(["All"])

    ask.submit(
        fn=chat,
        inputs=[
            user_id_state,
            ask,
            state,
            dropdown_sources,
        ],
        outputs=[chatbot, state, sources_textbox],
    )
    ask.submit(reset_textbox, [], [ask])

    # ask_examples_hidden.change(
    #     fn=chat,
    #     inputs=[
    #         user_id_state,
    #         ask_examples_hidden,
    #         state,
    #         dropdown_sources,
    #     ],
    #     outputs=[chatbot, state, sources_textbox],
    # )
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown(
                """
<div class="warning-box">
Version 0.1-beta - This tool is under active development
</div>
"""
            )

        with gr.Column(scale=1):
            gr.Markdown("*Source: Tikehau Universal Registration Documents*")

    gr.Markdown("## How to use TKO URD Q&A")
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown(
                """
### 💪 Getting started
- In the chatbot section, simply type your Tikehau-related question; answers will be provided with references to the relevant URDs.
"""
            )
        with gr.Column(scale=1):
            gr.Markdown(
                """
### ⚠️ Limitations
<div class="warning-box">
<ul>
<li>Please note that, like any AI, the model may occasionally generate an inaccurate or imprecise answer.</li>
</ul>
</div>
"""
            )

    gr.Markdown("## 🙏 Feedback and feature requests")
    gr.Markdown(
        """
### Beta test
- Feedback is welcome. Inspired by the Climate tool by Ekimetrics.
"""
    )

    gr.Markdown(
        """
## 🛢️ Carbon Footprint

Carbon emissions were measured during the development and inference process using CodeCarbon [https://github.com/mlco2/codecarbon](https://github.com/mlco2/codecarbon)

| Phase | Description | Emissions | Source |
| --- | --- | --- | --- |
| Inference | API call to turbo-GPT | ~0.38 gCO2e / call | https://medium.com/@chrispointon/the-carbon-footprint-of-chatgpt-e1bc14e4cc2a |

Carbon emissions are **relatively low but not negligible** compared to other usages: one question asked to TKO Q&A is around 0.482 gCO2e - equivalent to 2.2 m by car (https://datagir.ademe.fr/apps/impact-co2/), or around 2 to 4 times more than a typical Google search.
"""
    )
demo.queue(concurrency_count=16)

demo.launch()