Spaces:
Sleeping
Sleeping
import json | |
import os | |
from dotenv import load_dotenv | |
from duckduckgo_search import DDGS | |
from langchain_core.messages.tool import BaseMessage, ToolMessage | |
from langchain_core.prompts import PromptTemplate | |
from langchain_core.tools import tool | |
from langgraph.graph import END, MessageGraph | |
from langgraph.prebuilt import ToolNode | |
from typing import TypedDict | |
from llm import get_text_llm | |
from log_util import logger | |
from time_it import time_it | |
from util import load_prompt | |
# Pull variables from a .env file (if one is found) into the process
# environment; this must happen before any setting below is read.
load_dotenv()

# Upper bound on image hits requested per search. Overridable through the
# MAX_IMAGE_SEARCH_RESULTS environment variable; falls back to 3.
MAX_IMAGE_SEARCH_RESULTS = int(os.environ.get('MAX_IMAGE_SEARCH_RESULTS', '3'))
class ImageSearchResult(TypedDict):
    """A single image hit produced by ``search_images``.

    Both keys are required.
    """

    title: str  # title reported by the search engine for the image
    url: str  # image URL, taken from the DDGS result's ``image`` field
def search_meal_image(meal: str) -> str | None:
    """Return the URL of a photo of *meal*, or ``None`` if none is found.

    Builds and runs a small LangGraph ``MessageGraph``:

    1. ``validate_is_meal``: the text LLM answers the
       ``validate_is_meal.prompt.txt`` prompt, formatted with ``phrase=meal``.
    2. If that answer is "yes", ``is_meal`` lets the LLM (bound to the
       ``search_meal_images`` tool) emit a tool call, and ``call_tools``
       executes it. Otherwise the graph ends immediately.

    The first ``ToolMessage`` produced by the run is decoded as JSON, and the
    ``url`` of its first entry is returned.

    Args:
        meal: Free-text phrase to validate and search images for.

    Returns:
        The first image URL, or ``None`` in any of these cases:
        - the LLM did not answer "yes";
        - no tool result was produced;
        - the tool returned no images;
        - the tool output was not a JSON list.
    """
    prompt_template = PromptTemplate.from_template(load_prompt('validate_is_meal.prompt.txt'))
    llm = get_text_llm()
    tools = [search_meal_images]

    def is_meal_router(messages: list[BaseMessage]) -> str:
        # Continue to the tool-calling node only on an affirmative answer.
        # Surrounding whitespace, case and trailing '.'/'!' are ignored, so
        # replies such as "Yes." or "yes\n" are not treated as "no".
        # Non-string (structured) content is treated as a negative answer.
        answer = messages[-1].content
        if isinstance(answer, str) and answer.strip().lower().rstrip('.!') == 'yes':
            return 'is_meal'
        return END

    graph = MessageGraph()
    graph.add_node('validate_is_meal', llm)
    graph.add_node('is_meal', llm.bind_tools(tools))
    graph.add_node('call_tools', ToolNode(tools))
    graph.set_entry_point('validate_is_meal')
    graph.add_conditional_edges('validate_is_meal', is_meal_router)
    graph.add_edge('is_meal', 'call_tools')
    graph.add_edge('call_tools', END)
    workflow = graph.compile()

    messages: list[BaseMessage] = workflow.invoke(prompt_template.format(phrase=meal))

    tool_message = next((message for message in messages if isinstance(message, ToolMessage)), None)
    if tool_message is None or not tool_message.content:
        return None

    try:
        meal_images: list[ImageSearchResult] = json.loads(tool_message.content)
    except (TypeError, json.JSONDecodeError):
        # The tool output is not a JSON string. For instance, it may be an
        # error message reported by ToolNode, or non-string content.
        # Treat this as "no image" rather than crashing the caller.
        logger.warning(f'Could not decode image search results: {tool_message.content!r}')
        return None

    if not isinstance(meal_images, list) or not meal_images:
        return None

    meal_image_url = meal_images[0]['url']
    logger.info(f'{meal_image_url=}')
    return meal_image_url
def search_meal_images(meal: str) -> list[ImageSearchResult]:
    '''Searches for images of the given meal.'''
    # NOTE: this function is passed as a tool to ``llm.bind_tools`` and
    # ``ToolNode`` in ``search_meal_image``. LangChain derives the tool's
    # name, argument schema and description from this function's name,
    # signature and docstring, so changing any of them changes what the LLM
    # sees.
    # NOTE(review): ``tool`` is imported at the top of the file but not
    # applied here; the plain function is passed to bind_tools/ToolNode
    # directly.
    # Delegates to ``search_images`` using its default ``max_results``.
    return search_images(meal)
def search_images(keywords: str, max_results: int | None = MAX_IMAGE_SEARCH_RESULTS) -> list[ImageSearchResult]:
    """Search DuckDuckGo Images for colour photos matching *keywords*.

    Args:
        keywords: Search query.
        max_results: Maximum number of results requested from DDGS. Defaults
            to ``MAX_IMAGE_SEARCH_RESULTS``, which is read from the environment
            and falls back to 3. The value, including ``None``, is forwarded
            to DDGS unchanged.

    Returns:
        One ``ImageSearchResult`` per hit. The DDGS ``title`` becomes
        ``title``, and the DDGS ``image`` field becomes ``url``.
    """
    # Use DDGS as a context manager so its HTTP session is released as soon as
    # the search finishes instead of lingering until garbage collection.
    with DDGS() as ddgs:
        # list() materialises the hits while the session is still open; some
        # DDGS versions return a lazy generator here.
        results = list(ddgs.images(
            keywords=keywords,
            region='wt-wt',  # no specific region
            safesearch='on',
            size=None,
            color='color',
            type_image='photo',
            layout=None,
            license_image=None,
            max_results=max_results,
        ))
    logger.info(f'{keywords=}: {results=}')
    return [ImageSearchResult(title=result['title'], url=result['image']) for result in results]