Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import os | |
| from langchain_community.document_loaders import JSONLoader | |
| from langchain_community.vectorstores import Qdrant | |
| from qdrant_client.http import models as rest | |
| from qdrant_client import QdrantClient, models | |
| from langchain_community.embeddings import HuggingFaceEmbeddings, HuggingFaceBgeEmbeddings | |
| from sentence_transformers.cross_encoder import CrossEncoder | |
| from groq import Groq | |
# Groq client used by run_query to optionally expand a recipe with an LLM.
# The API key is read from the GROQ_API environment variable.
# NOTE(review): when GROQ_API is unset the key passed here is None; confirm how
# the Groq client handles that before relying on it starting up.
client = Groq(api_key=os.getenv("GROQ_API"))

# Loading data: path of the recipe JSON file read by the JSONLoader.
json_path = "format_food.json"
def metadata_func(record: dict, metadata: dict) -> dict:
    """Copy the recipe fields of a JSON record into a document's metadata.

    Used as the ``metadata_func`` hook of the JSONLoader. The ``title``,
    ``cuisine``, ``time`` and ``instructions`` keys are copied from *record*
    (``None`` when a key is absent) into *metadata*, which is updated in
    place and returned.
    """
    for field in ("title", "cuisine", "time", "instructions"):
        metadata[field] = record.get(field)
    return metadata
def reranking_results(query, top_k_results, rerank_model):
    """Rerank retrieved documents against *query* with a caller-supplied model.

    Each document is flattened to ``"<title>, <page_content>"`` and the list
    is passed to ``rerank_model.rank(query, texts, return_documents=True)``
    (e.g. a sentence-transformers ``CrossEncoder``). The result of that call
    is returned unchanged.
    """
    def _as_text(doc):
        # Prefix the content with the dish title so the reranker sees both.
        return f"{doc.metadata['title']}, {doc.page_content}"

    texts = [_as_text(doc) for doc in top_k_results]
    return rerank_model.rank(query, texts, return_documents=True)
# Load every item matched by ``.dishes[].dish`` in the JSON file as a Document.
# The page content comes from each record's ``doc`` field (content_key), and
# metadata_func copies the recipe fields into the Document's metadata.
loader = JSONLoader(
    file_path=json_path,
    jq_schema=".dishes[].dish",
    text_content=False,
    content_key="doc",
    metadata_func=metadata_func,
)
data = loader.load()

# Distinct cuisines offered in the filter dropdown. The values are sorted so
# the options keep a stable alphabetical order across restarts; iterating a
# bare set of strings varies with hash randomisation. Records without a
# cuisine (metadata_func stores None) are skipped because None is not a
# usable filter value.
country_list = sorted(
    {item.metadata["cuisine"] for item in data if item.metadata["cuisine"] is not None}
)
# Models
# Embedding model used to index and query the recipes.
# Previously configured here but disabled: "Snowflake/snowflake-arctic-embed-xs"
# via HuggingFaceEmbeddings, and a "mixedbread-ai/mxbai-rerank-xsmall-v1"
# CrossEncoder reranker (for use with reranking_results).
model_name = "BAAI/bge-small-en"
model_kwargs = {"device": "cpu"}  # run the embedding model on CPU
encode_kwargs = {"normalize_embeddings": True}  # normalise embedding vectors
hf_embedding = HuggingFaceBgeEmbeddings(
    model_name=model_name,
    model_kwargs=model_kwargs,
    encode_kwargs=encode_kwargs,
)
# Embed every loaded document and index it in an in-memory Qdrant collection.
# Nothing is persisted; the index is rebuilt each time this module runs.
qdrant = Qdrant.from_documents(
    documents=data,
    embedding=hf_embedding,
    collection_name="my_documents",
    location=":memory:",  # local mode, in-memory storage only
)
def format_to_markdown(response_list):
    """Render a sequence of instruction steps as a Markdown bullet list.

    Each item becomes one ``- item`` line, joined with newlines, e.g.
    ``["Boil water", "Add pasta"]`` -> ``"- Boil water\\n- Add pasta"``.
    Unlike the previous version, the input is never modified, and an empty
    (or ``None``) sequence yields an empty string instead of raising.

    Args:
        response_list: Iterable of steps; each step is converted with ``str``.

    Returns:
        The Markdown bullet list as a single string.
    """
    if not response_list:
        return ""
    return "\n".join(f"- {step}" for step in response_list)
def _cuisine_filter(countries):
    """Build a Qdrant filter restricting hits to one cuisine, or None for no filter."""
    # "None" is the dropdown's "no filter" choice; also guard against a real
    # None (e.g. a cleared dropdown), which MatchValue cannot match on.
    if countries is None or countries == "None":
        return None
    return models.Filter(
        must=[
            models.FieldCondition(
                # Document metadata is stored under the "metadata" payload key
                # by LangChain's Qdrant wrapper; adjust if that key changes.
                key="metadata.cuisine",
                match=models.MatchValue(value=countries),
            )
        ]
    )


def _format_recipe(doc):
    """Return the (title/description, standard recipe) Markdown for one search hit."""
    title_and_description = f"# Best Choice:\nA {doc.metadata['title']}: {doc.page_content}"
    # Pass a copy so formatting can never mutate the document's metadata, and
    # treat missing instructions (metadata_func stores None) as no steps.
    steps = list(doc.metadata.get("instructions") or [])
    instructions = format_to_markdown(steps) if steps else ""
    recipe = f"# Standard Method\n## Cooking time:\n{doc.metadata['time']}\n\n## Recipe:\n{instructions}"
    return title_and_description, recipe


def _groq_update(recipe):
    """Ask the Groq-hosted Llama model for a more detailed version of *recipe*."""
    chat_completion = client.chat.completions.create(
        messages=[
            {
                "role": "user",
                "content": f"please write a more detailed recipe for the following recipe:\n{recipe}\n\n please return it in the same format.",
            }
        ],
        model="Llama3-70b-8192",
    )
    # Message content may be None; render that as an empty update, not a crash.
    return "# Groq Update\n" + (chat_completion.choices[0].message.content or "")


def run_query(query: str, groq: bool, countries: str = "None"):
    """Search the recipe index and build the three Markdown panels for the UI.

    Args:
        query: Free-text description of the desired meal, used for the
            similarity search.
        groq: When true, ask the Groq client to expand the top recipe.
        countries: Cuisine to filter on; "None" (the dropdown default) or
            None disables the filter.

    Returns:
        A ``(title_and_description, recipe, groq_update)`` tuple of Markdown
        strings. When the search finds nothing, placeholder messages are
        returned instead of raising IndexError.
    """
    print("Running Query")
    answer = qdrant.similarity_search(
        query=query,
        k=10,  # NOTE(review): only the top hit is used below
        filter=_cuisine_filter(countries),
    )
    if not answer:
        print("Returning query")
        return (
            "# Best Choice:\nNo matching recipe found.",
            "# Standard Method\nNo recipe available.",
            "# Groq Update \nNo recipe to expand.",
        )

    title_and_description, recipe = _format_recipe(answer[0])
    if groq:
        groq_update = _groq_update(recipe)
    else:
        groq_update = "# Groq Update \nPlease select the tick box if you need more information."
    print("Returning query")
    return title_and_description, recipe, groq_update
# Gradio UI: a free-text query, an optional cuisine filter and an opt-in
# Llama (Groq) expansion. Clicking Run passes (query, checkbox, cuisine) to
# run_query and fills the three Markdown panels with its results, in order.
with gr.Blocks() as demo:
    gr.Markdown("Start typing below and then click **Run** to see the output.")
    inp = gr.Textbox(placeholder="What sort of meal are you after?")
    dropdown = gr.Dropdown(
        ["None", *country_list],
        value="None",
        label="Filter on countries",
    )
    # Despite its name this is a checkbox toggling the Groq expansion.
    groq_button = gr.Checkbox(label="Use Llama for a better recipe?", value=False)

    title_output = gr.Markdown(label="Title and description")
    instructions_output = gr.Markdown(label="Recipe")
    updated_recipe = gr.Markdown(label="Updated Recipe")

    btn = gr.Button("Run")
    btn.click(
        run_query,
        inputs=[inp, groq_button, dropdown],
        outputs=[title_output, instructions_output, updated_recipe],
    )

demo.launch()