# HealthAI-Chef / meal_image_search.py
# Last commit: 96f6720 ("init") — kikomiko
import json
import os
from dotenv import load_dotenv
from duckduckgo_search import DDGS
from langchain_core.messages.tool import BaseMessage, ToolMessage
from langchain_core.prompts import PromptTemplate
from langchain_core.tools import tool
from langgraph.graph import END, MessageGraph
from langgraph.prebuilt import ToolNode
from typing import TypedDict
from llm import get_text_llm
from log_util import logger
from time_it import time_it
from util import load_prompt
# Populate the process environment from a local .env file before any
# configuration below is read.
load_dotenv()

# Default cap on how many image hits `search_images` requests per query;
# overridable through the MAX_IMAGE_SEARCH_RESULTS environment variable.
MAX_IMAGE_SEARCH_RESULTS = int(os.environ.get('MAX_IMAGE_SEARCH_RESULTS', '3'))
# One simplified image hit as produced by `search_images`:
#   title - the result's title text
#   url   - direct URL of the image itself (DuckDuckGo's 'image' field)
ImageSearchResult = TypedDict('ImageSearchResult', {'title': str, 'url': str})
def _is_affirmative(content) -> bool:
    """Return True when an LLM reply is a plain "yes".

    Surrounding whitespace, quotes and trailing '.'/'!' are ignored, so replies
    such as "Yes.", " yes\\n" or '"YES"' all count. Non-string content (e.g. a
    list of content blocks) is treated as a non-affirmative answer instead of
    raising.
    """
    if not isinstance(content, str):
        return False
    return content.strip().strip('"\'.!').strip().lower() == 'yes'


@time_it
def search_meal_image(meal: str) -> str | None:
    """Return the URL of an image for *meal*, or None.

    Builds a small LangGraph workflow. The text LLM first answers the
    ``validate_is_meal.prompt.txt`` prompt for *meal*. Only on an affirmative
    answer does the graph hand over to the tool-bound LLM, which may call
    ``search_meal_images``. The URL of the first image that tool returns is
    the result.

    Args:
        meal: Free-text phrase to validate and search images for.

    Returns:
        The first image URL, or None when the phrase was not judged a meal, no
        tool result was produced, the tool output could not be parsed, or no
        images were found.
    """
    prompt_text = load_prompt('validate_is_meal.prompt.txt')
    llm = get_text_llm()
    tools = [search_meal_images]

    def is_meal_router(messages: list[BaseMessage]) -> str:
        # Continue to the tool-calling step only on a "yes" from validation.
        return 'is_meal' if _is_affirmative(messages[-1].content) else END

    graph = MessageGraph()
    graph.add_node('validate_is_meal', llm)
    graph.add_conditional_edges('validate_is_meal', is_meal_router)
    graph.add_node('is_meal', llm.bind_tools(tools))
    graph.add_edge('is_meal', 'call_tools')
    graph.add_node('call_tools', ToolNode(tools))
    graph.add_edge('call_tools', END)
    graph.set_entry_point('validate_is_meal')

    prompt_template = PromptTemplate.from_template(prompt_text)
    prompt = prompt_template.format(phrase=meal)
    workflow = graph.compile()
    messages: list = workflow.invoke(prompt)

    tool_messages = [message for message in messages if isinstance(message, ToolMessage)]
    if not tool_messages or not tool_messages[0].content:
        return None

    tool_output = tool_messages[0].content
    try:
        meal_images: list[ImageSearchResult] = json.loads(tool_output)
    except (TypeError, json.JSONDecodeError):
        # The tool output is not a JSON string. For example, ToolNode may
        # surface a failed search as a plain-text error message
        # (version-dependent). Treat it as "no image" rather than crashing.
        logger.warning(f'Could not parse image search tool output: {tool_output!r}')
        return None

    if not isinstance(meal_images, list) or not meal_images:
        return None
    meal_image_url = meal_images[0]['url']
    logger.info(f'{meal_image_url=}')
    return meal_image_url
@tool
def search_meal_images(meal: str) -> list[ImageSearchResult]:
    '''Searches for images of the given meal.'''
    # NOTE: @tool publishes this function to the LLM. The docstring above is
    # used as the tool description and `meal` as its argument, so both are
    # runtime-visible and should stay stable. Delegates to `search_images`
    # with its default result cap.
    found_images = search_images(meal)
    return found_images
@time_it
def search_images(keywords: str, max_results: int | None = MAX_IMAGE_SEARCH_RESULTS) -> list[ImageSearchResult]:
    """Search DuckDuckGo Images for *keywords* and return simplified hits.

    The query runs with region 'wt-wt', SafeSearch 'on', and is restricted to
    colour photos; size, layout and license are left unfiltered.

    Args:
        keywords: Search query text.
        max_results: Maximum number of results, passed straight through to
            ``DDGS.images``. With None, the limit is left to the library;
            presumably only the first page is returned (verify against the
            installed duckduckgo_search version).

    Returns:
        A list of ``ImageSearchResult`` dicts whose ``url`` is the direct image
        URL (DuckDuckGo's ``image`` field). Empty when nothing was found.
    """
    # Use DDGS as a context manager so its HTTP session is released when done.
    with DDGS() as ddgs:
        raw_results = ddgs.images(
            keywords=keywords,
            region='wt-wt',
            safesearch='on',
            size=None,
            color='color',
            type_image='photo',
            layout=None,
            license_image=None,
            max_results=max_results,
        )
        # Some duckduckgo_search versions return a lazy generator. Materialise
        # it while the session is still open, which also makes the log line
        # below show the actual hits.
        results = list(raw_results or [])
    logger.info(f'{keywords=}: {results=}')
    return [ImageSearchResult(title=result['title'], url=result['image']) for result in results]