Spaces:
Sleeping
Sleeping
import base64
from pathlib import Path

from langchain_core.messages import HumanMessage
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.prompts import ChatPromptTemplate, PromptTemplate

from image_generator import generate_image
from llm import get_text_llm, get_vision_llm, invoke_llm_async
from time_it import time_it_async
from util import UTF8_ENCODING, load_prompt
async def get_recipe_from_image(image_path: str) -> dict:
    """Extract a recipe from a food photo using the vision LLM.

    Args:
        image_path: Path to the image file on disk.

    Returns:
        The structured recipe produced by the
        ``get_recipe_from_image.prompt.txt`` prompt.
    """
    return await _analyze_image_from_path(image_path, 'get_recipe_from_image.prompt.txt')
async def get_altered_recipe(orig_recipe: dict, restrictions: list[str], diagnoses: list[str]) -> dict:
    """Ask the text LLM to adapt a recipe to dietary restrictions and diagnoses.

    Args:
        orig_recipe: The recipe to modify.
        restrictions: Dietary restrictions to respect.
        diagnoses: Medical diagnoses to take into account.

    Returns:
        The altered recipe parsed from the LLM's JSON reply. If the reply is a
        dict wrapping the result under a top-level ``'recipe'`` key, the
        wrapped value is returned instead.
    """
    response = await _generate_text(
        'alter_recipe_for_healthier.prompt.txt',
        {'recipe': orig_recipe, 'dietary_restrictions': restrictions, 'medical_diagnoses': diagnoses},
    )
    # _generate_text is typed `dict | str`. Only unwrap a real mapping: on a str,
    # `'recipe' in response` would be a substring test and indexing would raise TypeError.
    if isinstance(response, dict) and 'recipe' in response:
        return response['recipe']
    return response
async def _analyze_image_from_path(image_path: str, prompt_file: str, output_parser_type=JsonOutputParser) -> dict | str:
    """Load an image file and run it through the vision LLM with a prompt.

    Args:
        image_path: Path to the image; its extension selects the MIME subtype.
        prompt_file: Prompt file name, passed on to ``load_prompt``.
        output_parser_type: Parser class forwarded to
            ``_analyze_image_from_bytes``. It defaults to ``JsonOutputParser``,
            matching that function's default, so existing callers are unaffected.

    Returns:
        The parsed LLM response.

    Raises:
        ValueError: If the file extension is not a supported image format.
        OSError: If the file cannot be opened or read.
    """
    # Validate the extension before reading, so unsupported files are rejected
    # without loading their bytes into memory.
    image_format = _get_image_format(image_path)
    with open(image_path, 'rb') as image_file:
        image_data = image_file.read()
    return await _analyze_image_from_bytes(
        image_data, prompt_file, image_format, output_parser_type=output_parser_type
    )
async def _analyze_image_from_bytes(image_data: bytes, prompt_file: str, image_format: str, output_parser_type=JsonOutputParser) -> dict | str:
    """Send raw image bytes plus a text prompt to the vision LLM.

    The image is embedded inline as a base64 ``data:`` URL with
    ``detail='high'``, next to the prompt text, in a single human message.

    Args:
        image_data: Raw bytes of the image.
        prompt_file: Prompt file name, passed on to ``load_prompt``.
        image_format: MIME subtype used in ``image/<format>``
            (e.g. ``'jpeg'`` or ``'png'``).
        output_parser_type: Parser class; it is instantiated and appended to the chain.

    Returns:
        Whatever the instantiated parser produces from the LLM output.
    """
    encoded_image = base64.b64encode(image_data).decode(UTF8_ENCODING)
    prompt_text = load_prompt(prompt_file)

    text_part = {'type': 'text', 'text': prompt_text}
    image_part = {
        'type': 'image_url',
        'image_url': {
            'url': f'data:image/{image_format};base64,{encoded_image}',
            'detail': 'high',
        },
    }
    # NOTE(review): a concrete HumanMessage is passed instead of a template
    # tuple, presumably so braces in the prompt are not parsed as template
    # variables. Confirm against the LangChain docs.
    message_template = ChatPromptTemplate.from_messages([HumanMessage(content=[text_part, image_part])])

    vision_llm = get_vision_llm()
    pipeline = message_template | vision_llm | output_parser_type()
    return await invoke_llm_async(pipeline)
def _get_image_format(image_path: str) -> str: | |
file_ext = image_path.split('.')[-1].lower() | |
match file_ext: | |
case 'jpg' | 'jpeg': | |
return 'jpeg' | |
case 'png' | 'gif' | 'webp': | |
return file_ext | |
case _: | |
raise ValueError(f'Unsupported image format for {image_path=}') | |
async def _generate_text(prompt_file: str, input: dict) -> dict | str:
    """Render a text prompt template and return the LLM's JSON-parsed reply.

    Args:
        prompt_file: Prompt file name, passed on to ``load_prompt``; its contents
            are used as a ``PromptTemplate``.
        input: Template variables passed to the chain. (The name shadows the
            builtin but is kept because callers may pass it by keyword.)

    Returns:
        The output of ``JsonOutputParser`` applied to the LLM response.
    """
    template = PromptTemplate.from_template(load_prompt(prompt_file))
    text_llm = get_text_llm()
    pipeline = template | text_llm | JsonOutputParser()
    return await invoke_llm_async(pipeline, input)
def get_image_from_recipe(recipe: dict) -> str:
    """Generate an illustrative image for a recipe.

    Only the fields relevant to the picture are forwarded to the image
    prompt, in the same order they appear in *recipe*.

    Args:
        recipe: The recipe mapping; keys outside the allowed set are dropped.

    Returns:
        The value returned by ``generate_image`` (an image URL).
    """
    image_relevant_keys = {'name', 'ingredients', 'instructions', 'meal_type', 'serves'}
    trimmed_recipe = {key: value for key, value in recipe.items() if key in image_relevant_keys}
    return generate_image('get_image_from_recipe.prompt.txt', {'recipe': trimmed_recipe})