TYH71
qol: linting
2449b1f
"""CLIP interface module"""
# libraries
from typing import Dict, List, Union
from PIL import Image
# modules
from src.core.logger import logger
from src.model.clip import ClipModel
# Single model instance shared by every clip_demo_fn call. It is constructed
# once, at import time, so importing this module pays the ClipModel()
# construction cost (presumably weight loading — confirm in src.model.clip).
MODEL = ClipModel()
def clean_text(text: List[str]) -> List[str]:
    """Strip surrounding whitespace from every entry of a list of strings.

    Args:
        text (List[str]): raw text entries, e.g. the pieces of a
            comma-separated string after ``str.split(",")``

    Returns:
        List[str]: a new list with each entry ``str.strip()``-ed, in the
        original order; entries that are blank become ``""`` and are kept
    """
    return [entry.strip() for entry in text]
def clip_demo_fn(image: Image.Image, text: Union[str, List[str]]) -> Dict[str, float]:
    """Score an image against a set of candidate text labels (gradio callback).

    Args:
        image (Image.Image): PIL image to score
        text (Union[str, List[str]]): candidate labels, either as a single
            comma-separated string (e.g. ``"cat, dog, bird"``) or as a list
            of strings

    Returns:
        Dict[str, float]: the result of ``MODEL(image, labels)`` — per the
        return annotation, a mapping from each label to its associated
        probability

    Raises:
        ValueError: if no non-empty label remains after cleaning
    """
    try:
        logger.info("demo function invoked")
        # Normalise both accepted input forms to one list, then clean it
        # exactly once (string input used to be stripped twice, and other
        # iterables such as tuples skipped cleaning entirely).
        raw_labels = text.split(",") if isinstance(text, str) else list(text)
        # Drop blank entries, e.g. from a trailing comma ("cat, dog,") or an
        # empty textbox, so the model is never asked to score an empty prompt.
        labels = [label for label in clean_text(raw_labels) if label]
        if not labels:
            raise ValueError("no non-empty text labels provided")
        logger.debug("clean text: %s", labels)
        return MODEL(image, labels)
    except Exception:
        # Record the traceback before the error propagates to the caller.
        logger.exception("demo function failed")
        raise
    finally:
        logger.info("demo function completed")