Spaces:
Runtime error
Runtime error
| from src.information_retrieval import info_retrieval as ir | |
| import logging | |
| logger = logging.getLogger(__name__) | |
| logging.basicConfig(encoding='utf-8', level=logging.DEBUG) | |
| from src.augmentation.prompts import SYSTEM_PROMPT, SUSTAINABILITY_PROMPT, USER_PROMPT | |
def generate_prompt(query, context, template=None):
    """
    Build the chat messages (system + user) that are sent to the LLM.

    The system message is ``template`` when a truthy one is supplied, otherwise the imported SYSTEM_PROMPT.
    The user message is USER_PROMPT with ``query`` and ``context`` substituted positionally, in that order.

    Args:
        - query: str; the user query
        - context: the retrieved context to embed in the user message (the caller in this module passes the
          string produced by format_context)
        - template: str; optional replacement for the default system prompt

    Returns:
        list[dict]: two messages with "role"/"content" keys, system first and user second.
    """
    system_content = template if template else SYSTEM_PROMPT
    user_content = USER_PROMPT.format(query, context)

    return [
        {"role": "system", "content": system_content},
        {"role": "user", "content": user_content},
    ]
def format_context(context):
    """
    Turn the retrieved context into plain text that is easy for the LLM to read.

    Each city becomes one "Option N" paragraph: an introductory sentence, the sustainability sentence (if the
    city has a "sustainability" entry), the free-text description and then up to three listing lines
    (attractions, places to eat/drink, hotels). Paragraphs are terminated with "\\n\\n ".

    Args:
        - context: dict mapping a city name to a dict with the keys
            - 'country': str
            - 'text': str; description of the city
            - 'listings' (optional): list of dicts with 'type', 'name' and 'description'
            - 'sustainability' (optional): dict with 'month' and, unless month is 'No data available',
              's-fairness'

    Returns:
        str: the formatted context; an empty string when ``context`` is empty.
    """
    attraction_types = ('see', 'do', 'go', 'view')
    dining_types = ('eat', 'drink')

    city_blocks = []
    for option_number, (city, info) in enumerate(context.items(), start=1):
        # The trailing space keeps the next sentence from being glued to this one.
        text = f"Option {option_number}: {city} is a city in {info['country']}. "

        # If we add sustainability in the end then it could get truncated because of context window
        if "sustainability" in info:
            sustainability = info['sustainability']
            if sustainability['month'] == 'No data available':
                text += "This city has no sustainability (or s-fairness) score available. "
            else:
                text += (
                    f"The sustainability (or s-fairness) score for {city} in {sustainability['month']} "
                    f"is {sustainability['s-fairness']}. \n "
                )

        text += f"Here is some information about the city. {info['text']}"

        attractions, restaurants, hotels = [], [], []
        # A city without a 'listings' key (or with an empty/None value) simply has no listing lines.
        for listing in info.get('listings') or []:
            entry = f"{listing['name']} ({listing['description']})"
            if listing['type'] in attraction_types:
                attractions.append(entry)
            elif listing['type'] in dining_types:
                restaurants.append(entry)
            else:
                # NOTE(review): every other listing type is treated as a hotel, as in the original code;
                # confirm against the listing schema that no non-lodging types can reach this branch.
                hotels.append(entry)

        # Joining avoids the dangling ", " that incremental concatenation left at the end of each line.
        if attractions:
            text += f"\nHere are some attractions: {', '.join(attractions)}"
        if restaurants:
            text += f"\nHere are some places to eat/drink: {', '.join(restaurants)}"
        if hotels:
            text += f"\nHere are some hotels: {', '.join(hotels)}"

        city_blocks.append(text + "\n\n ")

    return "".join(city_blocks)
def augment_prompt(query: str, starting_point: str, context: dict, **params: dict):
    """
    Augment the user query with the retrieved context and return the chat messages for the LLM.

    The query is prefixed with the starting point, the context is flattened with format_context (the LLM cannot
    make sense of the nested dictionaries directly) and both are handed to generate_prompt.

    Args:
        - query: str; the user query
        - starting_point: str; inserted into the query as "With <starting_point> as the starting point, ..."
        - context: dict; retrieved context, in the structure expected by format_context
        - params: keyword arguments; the settings dict is expected under the keyword ``params``
          (i.e. ``augment_prompt(..., params={...})``). If it contains a truthy 'sustainability' entry,
          SUSTAINABILITY_PROMPT is used as the system prompt instead of the default one.

    Returns:
        The message list produced by generate_prompt.
    """
    # what about the cities without s-fairness scores? i.e. they don't have seasonality data
    updated_query = f"With {starting_point} as the starting point, {query}"

    # format context
    formatted_context = format_context(context)

    # The settings arrive nested under the ``params`` keyword; a missing or empty value means "no options"
    # rather than a KeyError.
    options = params.get("params") or {}
    if options.get("sustainability"):
        return generate_prompt(updated_query, formatted_context, SUSTAINABILITY_PROMPT)
    return generate_prompt(updated_query, formatted_context)
def test():
    """
    Smoke test: retrieve context for a sample query and build the prompt with and without sustainability.

    Both variants are built so that either code path failing surfaces here; only the prompt generated with
    sustainability enabled is returned.

    Returns:
        The message list returned by augment_prompt for the sustainability-enabled run.
    """
    context_params = {
        'limit': 3,
        'reranking': 0,
        'sustainability': 0,
    }
    query = (
        "Suggest some places to visit during winter. I like hiking, nature and the mountains and I enjoy skiing "
        "in winter. "
    )
    # augment_prompt requires a starting point; omitting it raises TypeError before any prompt is built.
    starting_point = "Munich"

    # without sustainability (result intentionally unused; the call only exercises this code path)
    context = ir.get_context(query, **context_params)
    without_sfairness = augment_prompt(
        query=query,
        starting_point=starting_point,
        context=context,
        params=context_params,
    )

    # with sustainability
    context_params.update({'sustainability': 1})
    s_context = ir.get_context(query, **context_params)
    with_sfairness = augment_prompt(
        query=query,
        starting_point=starting_point,
        context=s_context,
        params=context_params,
    )
    return with_sfairness
if __name__ == "__main__":
    # Script entry point: run the smoke test and show the resulting prompt messages.
    print(test())