|
from langchain_core.runnables.schema import StreamEvent |
|
from gradio import ChatMessage |
|
from climateqa.engine.chains.prompts import audience_prompts |
|
from front.utils import make_html_source,parse_output_llm_with_sources,serialize_docs,make_toolbox,generate_html_graphs |
|
import numpy as np |
|
|
|
def init_audience(audience :str) -> str:
    """
    Picks the prompt fragment that corresponds to the requested audience

    Args:
        audience (str): The audience label. "Children", "General public" and "Experts" are recognised

    Returns:
        str: The matching entry of `audience_prompts`; any other label falls back to the "experts" entry
    """
    # UI label -> key in `audience_prompts`
    label_to_key = (
        ("Children", "children"),
        ("General public", "general"),
        ("Experts", "experts"),
    )
    for label, prompt_key in label_to_key:
        if audience == label:
            return audience_prompts[prompt_key]

    # Unrecognised labels are served the expert prompt
    return audience_prompts["experts"]
|
|
|
def handle_retrieved_documents(event: StreamEvent, history : list[ChatMessage], used_documents : list[str]) -> tuple[list, str, list[ChatMessage], list[str], list]:
    """
    Handles the retrieved documents and returns the HTML representation of the documents

    Args:
        event (StreamEvent): The event containing the retrieved documents under event["data"]["output"]
        history (list[ChatMessage]): The current message history; the content of its last message is overwritten in place with the list of sources
        used_documents (list[str]): The list of used documents (not mutated, a new list is returned)

    Returns:
        tuple[list, str, list[ChatMessage], list[str], list]: In this order:
            - the retrieved documents as found in the event
            - the HTML of the documents whose chunk_type is "text", concatenated
            - the message history
            - the list of used documents extended with "<short_name> - <name>" for every retrieved document
            - the value stored under "related_contents" in the event output

        If anything goes wrong the error is printed and the values computed so far are returned,
        with an empty list / empty string / empty list standing in for the documents, the HTML and
        the related contents that could not be read.
    """
    # Defaults so that the final return never hits an unbound name on the error path
    docs = []
    docs_html = ""
    related_contents = []

    try:
        output = event["data"]["output"]
        docs = output["documents"]

        # Only textual chunks are rendered; they are numbered from 1 in retrieval order
        textual_docs = [d for d in docs if d.metadata["chunk_type"] == "text"]
        html_parts = [make_html_source(d, i) for i, d in enumerate(textual_docs, 1)]

        used_documents = used_documents + [f"{d.metadata['short_name']} - {d.metadata['name']}" for d in docs]
        if used_documents:
            # np.unique both de-duplicates and sorts the source labels
            history[-1].content = "Adding sources :\n\n - " + "\n - ".join(np.unique(used_documents))

        docs_html = "".join(html_parts)

        related_contents = output["related_contents"]

    except Exception as e:
        # Best effort: a malformed event must not break the stream
        print(f"Error getting documents: {e}")
        print(event)

    return docs, docs_html, history, used_documents, related_contents
|
|
|
def stream_answer(history: list[ChatMessage], event : StreamEvent, start_streaming : bool, answer_message_content : str)-> tuple[list[ChatMessage], bool, str]:
    """
    Appends the newly streamed chunk to the answer and mirrors it in the last message of the history

    Args:
        history (list[ChatMessage]): The current message history
        event (StreamEvent): The event whose event["data"]["chunk"].content holds the new piece of text
        start_streaming (bool): A flag indicating if the streaming has started
        answer_message_content (str): The answer content accumulated so far

    Returns:
        tuple[list[ChatMessage], bool, str]: The updated history, the streaming flag (always True on return when it was False on entry) and the updated answer content
    """
    # First chunk of an answer: open an empty assistant message that the code below fills in.
    # The explicit comparison with False is kept as-is on purpose.
    if start_streaming == False:
        history.append(ChatMessage(role="assistant", content = ""))
        start_streaming = True

    raw_content = answer_message_content + event["data"]["chunk"].content

    # The parsed text is what gets stored and returned, so the next call appends to it
    answer_message_content = parse_output_llm_with_sources(raw_content)
    history[-1] = ChatMessage(role="assistant", content = answer_message_content)

    return history, start_streaming, answer_message_content
|
|
|
def handle_retrieved_owid_graphs(event :StreamEvent, graphs_html: str) -> str:
    """
    Appends the HTML of the recommended graphs found in the event to the current graphs HTML

    Graphs are de-duplicated on their "returned_content" metadata (first occurrence wins) and
    grouped by their "category" metadata, categories and graphs keeping their first-seen order.

    Args:
        event (StreamEvent): The event containing the recommended content under event["data"]["output"]["recommended_content"]
        graphs_html (str): The current HTML representation of the graphs

    Returns:
        str: The HTML with one <h3> title per category followed by one <div> per graph appended;
        if an error occurs it is printed and the HTML is returned as it was received
    """
    try:
        # 1. Keep the first occurrence of every distinct graph
        already_seen = set()
        unique_items = []
        for item in event["data"]["output"]["recommended_content"]:
            meta = item.metadata
            content = meta["returned_content"]
            if content in already_seen:
                continue
            source, category = meta["source"], meta["category"]
            unique_items.append((content, source, category))
            already_seen.add(content)

        # 2. Group the graphs by category (dicts preserve insertion order)
        by_category = {}
        for content, _source, category in unique_items:
            by_category.setdefault(category, []).append(content)

        # 3. Render a title per category followed by its graphs
        fragments = []
        for category, contents in by_category.items():
            fragments.append(f"<h3>{category}</h3>")
            fragments.extend(f"<div>{content}</div>" for content in contents)
        if fragments:
            graphs_html += "".join(fragments)

    except Exception as e:
        print(f"Error getting graphs: {e}")

    return graphs_html