|
import gradio as gr |
|
import os |
|
from langchain_community.document_loaders import JSONLoader |
|
from langchain_community.vectorstores import Qdrant |
|
from qdrant_client.http import models as rest |
|
from qdrant_client import QdrantClient, models |
|
from langchain_community.embeddings import HuggingFaceEmbeddings, HuggingFaceBgeEmbeddings |
|
from sentence_transformers.cross_encoder import CrossEncoder |
|
from groq import Groq |
|
|
|
# Groq API client used by run_query() for recipe expansion.
# The key is read from the GROQ_API environment variable; if unset,
# api_key is None and requests will fail at call time, not here.
client = Groq(

    api_key=os.environ.get("GROQ_API"),

)





# Path to the recipe dataset consumed by the JSONLoader below.
json_path = "format_food.json"
|
|
|
def metadata_func(record: dict, metadata: dict) -> dict:
    """Copy selected recipe fields from a JSON record into the metadata dict.

    Missing fields are stored as None (dict.get default). The metadata
    dict is updated in place and also returned, as the JSONLoader
    metadata_func contract expects.
    """
    for field in ("title", "cuisine", "time", "instructions"):
        metadata[field] = record.get(field)
    return metadata
|
|
|
def reranking_results(query, top_k_results, rerank_model):
    """Re-score retrieved documents against *query* with a reranking model.

    Args:
        query: The user's search string.
        top_k_results: Documents from the vector search; each must expose
            .metadata['title'] and .page_content.
        rerank_model: Object exposing .rank(query, docs, return_documents=True)
            (e.g. a sentence-transformers CrossEncoder).

    Returns:
        Whatever rerank_model.rank returns for the formatted candidates.
    """
    candidates = []
    for doc in top_k_results:
        candidates.append(f"{doc.metadata['title']}, {doc.page_content}")
    return rerank_model.rank(query, candidates, return_documents=True)
|
|
|
|
|
# Load every dish record from the JSON file. Each record's "doc" field
# becomes the document text (text_content=False allows non-string JSON),
# and metadata_func copies title/cuisine/time/instructions into metadata.
loader = JSONLoader(
    file_path=json_path,
    jq_schema='.dishes[].dish',
    text_content=False,
    content_key='doc',
    metadata_func=metadata_func
)

data = loader.load()
# Unique cuisines for the dropdown filter. Sorted so the UI order is
# deterministic across runs — the original list(set(...)) produced an
# arbitrary ordering. Assumes every record has a string cuisine.
country_list = sorted({item.metadata['cuisine'] for item in data})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Embedding model configuration: BGE-small English, run on CPU, with
# normalized vectors (appropriate for cosine-similarity search).
model_name = "BAAI/bge-small-en"

model_kwargs = {"device": "cpu"}

encode_kwargs = {"normalize_embeddings": True}

hf_embedding = HuggingFaceBgeEmbeddings(

    model_name=model_name,

    model_kwargs=model_kwargs,

    encode_kwargs=encode_kwargs

)



# Build an in-memory Qdrant collection from the loaded documents.
# Embeds every document at startup; the index is not persisted.
qdrant = Qdrant.from_documents(

    data,

    hf_embedding,

    location=":memory:",

    collection_name="my_documents",

)
|
|
|
def format_to_markdown(response_list):
    """Render a list of instruction steps as a markdown bullet list.

    Fixes over the original: the input list is no longer mutated in
    place, and an empty list returns "" instead of raising IndexError.
    Output for non-empty input is unchanged ("- step" per line).

    Args:
        response_list: Sequence of instruction strings.

    Returns:
        The steps joined as "- " markdown bullets, one per line.
    """
    return "\n".join(f"- {step}" for step in response_list)
|
|
|
def run_query(query: str, groq: bool, countries: str = "None"):
    """Find the dish best matching *query* and format it for the UI.

    Args:
        query: Free-text description of the desired meal.
        groq: When True, ask the Groq-hosted Llama model to expand the
            recipe; otherwise return a placeholder message.
        countries: Cuisine to filter on. The literal string "None" (the
            dropdown's sentinel value) disables the filter.

    Returns:
        Tuple of three markdown strings:
        (title and description, recipe, Groq update / placeholder).
    """
    print("Running Query")
    # Build a Qdrant payload filter only when a real cuisine was chosen.
    if countries != "None":
        countries_select = models.Filter(
            must=[
                models.FieldCondition(
                    key="metadata.cuisine",
                    match=models.MatchValue(value=countries),
                )
            ]
        )
    else:
        countries_select = None

    answer = qdrant.similarity_search(
        query=query,
        k=10,
        filter=countries_select
    )
    # Guard: with a filter applied the search can legitimately come back
    # empty; indexing answer[0] would previously raise IndexError here.
    if not answer:
        return (
            "# Best Choice:\nNo matching dish found - try a different query or filter.",
            "",
            "# Groq Update\nNothing to expand.",
        )

    best = answer[0]
    title_and_description = f"# Best Choice:\nA {best.metadata['title']}: {best.page_content}"
    instructions = format_to_markdown(best.metadata['instructions'])
    recipe = f"# Standard Method\n## Cooking time:\n{best.metadata['time']}\n\n## Recipe:\n{instructions}"
    print("Returning query")

    if groq:
        # Ask the LLM to expand the top recipe, preserving the markdown shape.
        chat_completion = client.chat.completions.create(
            messages=[
                {
                    "role": "user",
                    "content": f"please write a more detailed recipe for the following recipe:\n{recipe}\n\n please return it in the same format.",
                }
            ],
            model="Llama3-70b-8192",
        )
        groq_update = "# Groq Update\n" + chat_completion.choices[0].message.content
    else:
        groq_update = "# Groq Update \nPlease select the tick box if you need more information."
    return title_and_description, recipe, groq_update
|
|
|
# Gradio UI: a query box, a cuisine filter (with "None" sentinel matching
# run_query's default), an opt-in Groq checkbox, and three markdown panes
# that receive run_query's three return values in order.
with gr.Blocks() as demo:

    gr.Markdown("Start typing below and then click **Run** to see the output.")

    inp = gr.Textbox(placeholder="What sort of meal are you after?")

    dropdown = gr.Dropdown(['None'] + country_list, label='Filter on countries', value='None')

    groq_button = gr.Checkbox(value=False, label="Use Llama for a better recipe?")

    title_output = gr.Markdown(label="Title and description")

    instructions_output = gr.Markdown(label="Recipe")

    updated_recipe = gr.Markdown(label="Updated Recipe")

    btn = gr.Button("Run")

    # Wire the button: inputs map to run_query(query, groq, countries).
    btn.click(fn=run_query, inputs=[inp, groq_button, dropdown], outputs=[title_output, instructions_output, updated_recipe])



# Start the Gradio server (blocks until shut down).
demo.launch()