timeki's picture
Add content recommendation (#17)
bcc8503 verified
from langchain_core.runnables.schema import StreamEvent
from gradio import ChatMessage
from climateqa.engine.chains.prompts import audience_prompts
from front.utils import make_html_source,parse_output_llm_with_sources,serialize_docs,make_toolbox,generate_html_graphs
import numpy as np
def init_audience(audience :str) -> str:
    """
    Return the prompt from ``audience_prompts`` that matches a UI audience label.

    Args:
        audience (str): The audience label selected in the UI, e.g. "Children",
            "General public" or "Experts".
    Returns:
        str: The prompt stored under "children", "general" or "experts". Any label
        other than "Children" or "General public" (including "Experts") falls back
        to the "experts" prompt.
    """
    # Ordered (label, key) pairs; compared with == exactly like an if/elif chain.
    label_to_key = (
        ("Children", "children"),
        ("General public", "general"),
    )
    for label, key in label_to_key:
        if audience == label:
            return audience_prompts[key]
    return audience_prompts["experts"]
def handle_retrieved_documents(event: StreamEvent, history : list[ChatMessage], used_documents : list[str]) -> tuple[list, str, list[ChatMessage], list[str], list]:
    """
    Handles the retrieved documents and returns their HTML representation.

    Documents whose ``metadata["chunk_type"]`` is "text" are rendered with
    ``make_html_source`` and numbered from 1. Every retrieved document, textual or
    not, is recorded in ``used_documents`` as "<short_name> - <name>". When that
    list is non-empty, the content of the last message in ``history`` is replaced
    with the de-duplicated, sorted list of sources (via ``np.unique``).

    Errors are handled best-effort: the exception and the event are printed, and
    whatever was computed so far is returned, with empty defaults for anything
    that could not be read.

    Args:
        event (StreamEvent): Event whose ``["data"]["output"]`` holds the
            "documents" and "related_contents" entries.
        history (list[ChatMessage]): The current message history; its last
            message is modified in place.
        used_documents (list[str]): Source labels collected so far. It is never
            mutated; an extended copy is returned.

    Returns:
        tuple[list, str, list[ChatMessage], list[str], list]: The retrieved
        documents, the HTML of the textual documents, the message history, the
        updated list of used documents, and the related contents.
    """
    # Defaults so the return below never references an unbound name when the
    # event is malformed (previously this raised UnboundLocalError).
    docs = []
    docs_html = ""
    related_contents = []
    try:
        output = event["data"]["output"]
        docs = output["documents"]
        textual_docs = [d for d in docs if d.metadata["chunk_type"] == "text"]
        docs_html = "".join(make_html_source(d, i) for i, d in enumerate(textual_docs, 1))
        used_documents = used_documents + [f"{d.metadata['short_name']} - {d.metadata['name']}" for d in docs]
        if used_documents:
            history[-1].content = "Adding sources :\n\n - " + "\n - ".join(np.unique(used_documents))
        related_contents = output["related_contents"]
    except Exception as e:
        # Best-effort: a bad event must not interrupt the streaming UI.
        print(f"Error getting documents: {e}")
        print(event)
    return docs, docs_html, history, used_documents, related_contents
def stream_answer(history: list[ChatMessage], event : StreamEvent, start_streaming : bool, answer_message_content : str)-> tuple[list[ChatMessage], bool, str]:
    """
    Handles the streaming of the answer and updates the history with the new message content.

    On the first chunk (``start_streaming`` is false), an empty assistant message is
    appended to ``history``. On every chunk, the new text is appended to the
    accumulated answer, the whole answer is passed through
    ``parse_output_llm_with_sources``, and the last message of ``history`` is
    replaced with an assistant message holding the result.

    Args:
        history (list[ChatMessage]): The current message history; modified in place.
        event (StreamEvent): Event whose ``["data"]["chunk"].content`` holds the new text.
        start_streaming (bool): Whether an assistant message was already opened for this answer.
        answer_message_content (str): The answer text accumulated so far.
    Returns:
        tuple[list[ChatMessage], bool, str]: The updated history, the streaming flag
        (True once this function has run) and the updated answer content.
    """
    if not start_streaming:
        start_streaming = True
        history.append(ChatMessage(role="assistant", content=""))
    answer_message_content += event["data"]["chunk"].content
    # The full accumulated text (not only the new chunk) is re-parsed on every call.
    answer_message_content = parse_output_llm_with_sources(answer_message_content)
    history[-1] = ChatMessage(role="assistant", content=answer_message_content)
    return history, start_streaming, answer_message_content
def handle_retrieved_owid_graphs(event :StreamEvent, graphs_html: str) -> str:
    """
    Append the recommended OWID graphs carried by ``event`` to ``graphs_html``.

    Graphs are de-duplicated on ``metadata["returned_content"]``; the first
    occurrence wins, including its category. They are then rendered grouped by
    category in first-seen order: one ``<h3>`` heading per category, followed by
    one ``<div>`` per graph.

    Args:
        event (StreamEvent): Event whose ``["data"]["output"]["recommended_content"]``
            lists the recommended items.
        graphs_html (str): HTML accumulated so far; the new markup is appended to it.
    Returns:
        str: The extended HTML. Errors are printed rather than raised; in that case
        the returned HTML holds only what had been appended before the failure.
    """
    try:
        # Phase 1: keep the first item for each distinct graph snippet.
        first_occurrences = []
        seen_snippets = set()
        for item in event["data"]["output"]["recommended_content"]:
            meta = item.metadata
            snippet = meta["returned_content"]
            if snippet in seen_snippets:
                continue
            # "source" is never rendered, but a first occurrence must still carry it.
            first_occurrences.append((snippet, meta["source"], meta["category"]))
            seen_snippets.add(snippet)

        # Phase 2: bucket snippets by category (dict preserves first-seen order).
        by_category = {}
        for snippet, _source, category in first_occurrences:
            by_category.setdefault(category, []).append(snippet)

        # Phase 3: render each category heading followed by its graphs.
        for category, snippets in by_category.items():
            graphs_html += f"<h3>{category}</h3>" + "".join(f"<div>{s}</div>" for s in snippets)
    except Exception as e:
        print(f"Error getting graphs: {e}")
    return graphs_html