# NOTE(review): HuggingFace Space page header captured by the scrape
# ("Spaces: Running on CPU Upgrade") — not part of the module's code.
import re
import requests
import streamlit as st
def truncate_to_tokens(text, max_tokens):
    """
    Cap a text at roughly *max_tokens* whitespace-separated tokens.

    Tokenization here is an approximation: the text is simply split on
    whitespace, so one "token" is one word.

    Args:
        text (str): Input text.
        max_tokens (int): Maximum number of tokens/words to keep.

    Returns:
        str: The first *max_tokens* words joined by single spaces if the
        text was longer, otherwise the original text unchanged.
    """
    words = text.split()
    if len(words) <= max_tokens:
        return text
    return " ".join(words[:max_tokens])
def build_context_for_result(res, compute_title_fn):
    """
    Assemble the retrieval context string for one search hit.

    Combines the project's title, objective, and description — preferring
    the English description and falling back to the German one — into a
    single newline-separated string.

    Args:
        res: Search result object exposing a ``payload`` dict that may
            contain a 'metadata' entry.
        compute_title_fn (callable): Fallback used to derive a title when
            the metadata carries none.

    Returns:
        str: "<title>\\n<objective>\\n<description>".
    """
    metadata = res.payload.get('metadata', {})
    parts = [
        metadata.get("title", compute_title_fn(metadata)),
        metadata.get("objective", ""),
        # English description wins; German is the fallback when it is empty.
        metadata.get("description.en", "").strip()
        or metadata.get("description.de", "").strip(),
    ]
    return "\n".join(parts)
def highlight_query(text, query):
    """
    Highlight every occurrence of *query* in *text* with markdown bold.

    Matching is case-insensitive and literal: the query is regex-escaped,
    so regex metacharacters in it are treated as plain text.

    Args:
        text (str): The full text in which to highlight matches.
        query (str): The substring (query) to highlight.

    Returns:
        str: The markdown-formatted string with matches wrapped in ``**``.
    """
    # Guard: an empty query matches at every position and would litter the
    # text with empty bold markers ("**"), so return the text unchanged.
    if not query:
        return text
    pattern = re.compile(re.escape(query), re.IGNORECASE)
    return pattern.sub(lambda m: f"**{m.group(0)}**", text)
def format_project_id(pid):
    """
    Render a raw project ID in the dotted GIZ notation.

    Example: '201940485' -> '2019.4048.5'. Inputs with five or fewer
    characters are returned as-is.

    Args:
        pid (str | int): The project ID to format.

    Returns:
        str: Dotted project ID when long enough, otherwise str(pid).
    """
    raw = str(pid)
    if len(raw) <= 5:
        return raw
    # First four digits, middle block, final check digit.
    return f"{raw[:4]}.{raw[4:-1]}.{raw[-1]}"
def compute_title(metadata):
    """
    Derive a display title from project metadata.

    Prefers the English name ('name.en') and falls back to the German one
    ('name.de'); when a project ID is also present, it is appended in
    brackets in dotted GIZ format.

    Args:
        metadata (dict): Project metadata dictionary.

    Returns:
        str: Computed title (optionally suffixed with '[<formatted id>]'),
        or 'No Title' when neither name field is set.
    """
    base = metadata.get("name.en", "").strip() or metadata.get("name.de", "").strip()
    if not base:
        return "No Title"
    pid = metadata.get("id", "")
    if not pid:
        return base
    return f"{base} [{format_project_id(pid)}]"
def get_rag_answer(query, top_results, endpoint, token):
    """
    Query the LLM inference endpoint with a RAG-style prompt.

    Builds a context block from the top search results, truncates it to
    roughly 11.5k tokens, wraps it in an instruction prompt, and POSTs it
    to a HuggingFace Inference endpoint.

    Args:
        query (str): The user question.
        top_results (list): List of top search results from which to build context.
        endpoint (str): The HuggingFace Inference endpoint URL.
        token (str): The Bearer token (from st.secrets, for instance).

    Returns:
        str: The LLM-generated answer; an HTML "sleeping" notice when the
        endpoint is cold (HTTP 503); otherwise an error message.
    """
    # Imported lazily; presumably avoids a circular import at module load
    # time — TODO confirm against the appStore package layout.
    from appStore.rag_utils import truncate_to_tokens, build_context_for_result, compute_title

    # Build the context from the retrieved results.
    context = "\n\n".join(
        build_context_for_result(res, compute_title) for res in top_results
    )
    context = truncate_to_tokens(context, 11500)  # Truncate to ~11.5k tokens

    prompt = (
        "You are a project portfolio adviser at the development cooperation GIZ. "
        "Using the context below, answer the question in the same language as the question. "
        "Your answer must be formatted in bullet points. "
        "Ensure that every project title and project number in your answer is wrapped in double asterisks (e.g., **Project Title [2018.2101.6]**) to display them as markdown bold. "
        "Include at least one short sentence per project summarizing what the project does in relation to the query. "
        "Do not repeat any part of the provided context or the question in your final answer.\n\n"
        f"Context:\n{context}\n\n"
        f"Question: {query}\n\n"
        "Answer:"
    )
    headers = {"Authorization": f"Bearer {token}"}
    payload = {"inputs": prompt, "parameters": {"max_new_tokens": 300}}
    try:
        # A stalled endpoint would otherwise block the Streamlit UI forever;
        # fail fast with an explicit timeout instead.
        response = requests.post(endpoint, headers=headers, json=payload, timeout=120)
    except requests.RequestException as exc:
        return f"Error in generating answer: {exc}"

    if response.status_code == 200:
        try:
            result = response.json()
            answer = result[0].get("generated_text", "")
        except (ValueError, IndexError, KeyError, AttributeError):
            # Malformed/unexpected payload shape from the endpoint.
            return f"Error in generating answer: unexpected response format: {response.text}"
        # The model may echo the prompt; keep only what follows the last "Answer:".
        if "Answer:" in answer:
            answer = answer.split("Answer:")[-1].strip()
        return answer
    if response.status_code == 503:
        # Custom message with a larger llama icon and red highlighted text
        return (
            "<span style='color: red;'>"
            "<span style='font-size: 3em;'>🦙</span> Tzzz Tzzz I'm currently sleeping. "
            "Please come back in 10 minutes, and I'll be fully awake to answer your question."
            "</span>"
        )
    return f"Error in generating answer: {response.text}"