# app.py — Gradio recipe-search demo: Qdrant in-memory vector search over a
# recipe JSON dump, with an optional Groq/Llama pass to expand the top recipe.
import gradio as gr
import os
from langchain_community.document_loaders import JSONLoader
from langchain_community.vectorstores import Qdrant
from langchain_community.embeddings import HuggingFaceEmbeddings, HuggingFaceBgeEmbeddings
from sentence_transformers.cross_encoder import CrossEncoder
from groq import Groq
# Groq chat client; the API key is read from the GROQ_API environment
# variable (None if unset — requests would then fail at call time).
client = Groq(
    api_key=os.environ.get("GROQ_API"),
)
# loading data
# Path to the recipe corpus consumed by JSONLoader below. An earlier
# assignment to "format_food.json" was dead code (immediately overwritten),
# so only the dump actually in use is kept.
json_path = "llama70b_food_dump.json"
def metadata_func(record: dict, metadata: dict) -> dict:
    """Copy the recipe fields from a source record into the loader's metadata.

    Missing fields come through as None (dict.get default). The passed-in
    metadata dict is updated in place and returned.
    """
    for field in ("title", "cuisine", "time", "instructions"):
        metadata[field] = record.get(field)
    return metadata
def reranking_results(query, top_k_results, rerank_model):
    """Re-rank retrieved documents against the query with a cross-encoder.

    Each hit is flattened to the "<title>, <page_content>" text form the
    ranker scores; the ranker's result is returned unchanged.
    """
    candidates = [
        f"{doc.metadata['title']}, {doc.page_content}" for doc in top_k_results
    ]
    return rerank_model.rank(query, candidates, return_documents=True)
# Load recipe documents at import time: each entry under .dishes[].dish in
# the JSON file becomes one Document whose page_content is the record's
# "doc" field and whose metadata is populated by metadata_func above.
loader = JSONLoader(
    file_path=json_path,
    jq_schema='.dishes[].dish',
    text_content=False,  # "doc" may be non-string JSON; loader serializes it
    content_key='doc',
    metadata_func=metadata_func
)
data = loader.load()
# Models
# model_name = "Snowflake/snowflake-arctic-embed-xs"
# rerank_model = CrossEncoder("mixedbread-ai/mxbai-rerank-xsmall-v1")
# Embedding
# model_kwargs = {"device": "cpu"}
# encode_kwargs = {"normalize_embeddings": True}
# hf_embedding = HuggingFaceEmbeddings(
# model_name=model_name,
# encode_kwargs=encode_kwargs,
# model_kwargs=model_kwargs,
# show_progress=True
# )
# Embedding model: BGE-small on CPU, with normalize_embeddings=True so the
# output vectors are unit-length. The model is downloaded/loaded at import.
model_name = "BAAI/bge-small-en"
model_kwargs = {"device": "cpu"}
encode_kwargs = {"normalize_embeddings": True}
hf_embedding = HuggingFaceBgeEmbeddings(
    model_name=model_name,
    model_kwargs=model_kwargs,
    encode_kwargs=encode_kwargs
)
# Build an in-memory Qdrant index over the loaded documents. Nothing is
# persisted: the corpus is re-embedded on every app start.
qdrant = Qdrant.from_documents(
    data,
    hf_embedding,
    location=":memory:",  # local mode with in-memory storage only
    collection_name="my_documents",
)
def format_to_markdown(response_list):
    """Render a list of step strings as a Markdown bullet list.

    Returns one "- step" line per entry joined with newlines. Fixes two
    defects in the original: it mutated the caller's list in place
    (response_list[0] was rewritten), and it raised IndexError on an
    empty list — an empty list now yields "".
    """
    return "\n".join(f"- {step}" for step in response_list)
def run_query(query: str, groq: bool):
    """Search the vector store and format the best-matching recipe.

    Returns a 3-tuple of Markdown strings (title/description, recipe,
    groq_update) for the three Gradio output panes. When `groq` is True
    the recipe is additionally expanded via the Groq Llama model.
    """
    print("Running Query")
    hits = qdrant.similarity_search(query=query, k=10)
    best = hits[0]
    title_and_description = (
        f"# Best Choice:\nA {best.metadata['title']}: {best.page_content}"
    )
    steps = format_to_markdown(best.metadata['instructions'])
    recipe = (
        f"# Standard Method\n## Cooking time:\n{best.metadata['time']}"
        f"\n\n## Recipe:\n{steps}"
    )
    print("Returning query")
    if not groq:
        # Checkbox not ticked: skip the LLM call and show a hint instead.
        groq_update = "# Groq Update \nPlease select the tick box if you need more information."
        return title_and_description, recipe, groq_update
    completion = client.chat.completions.create(
        messages=[
            {
                "role": "user",
                "content": f"please write a more detailed recipe for the following recipe:\n{recipe}\n\n please return it in the same format.",
            }
        ],
        model="Llama3-70b-8192",
    )
    groq_update = "# Groq Update\n" + completion.choices[0].message.content
    return title_and_description, recipe, groq_update
# --- Gradio UI: wires the inputs and the Run button to run_query ---
with gr.Blocks() as demo:
    gr.Markdown("Start typing below and then click **Run** to see the output.")
    # Inputs: free-text meal query plus an opt-in checkbox for the Groq rewrite.
    inp = gr.Textbox(placeholder="What sort of meal are you after?")
    groq_button = gr.Checkbox(value=False, label="Use Llama for a better recipe?")
    # Outputs: three Markdown panes matching run_query's 3-tuple return.
    title_output = gr.Markdown(label="Title and description")
    instructions_output = gr.Markdown(label="Recipe")
    updated_recipe = gr.Markdown(label="Updated Recipe")
    btn = gr.Button("Run")
    btn.click(fn=run_query, inputs=[inp, groq_button], outputs=[title_output, instructions_output, updated_recipe])
# Starts the web server (blocks when run as a script).
demo.launch()