nlp-cirlce-demo / src /nlp_circle_demo /wrapped_gradio_objects.py
JOELDSL's picture
Update src/nlp_circle_demo/wrapped_gradio_objects.py
6cf00e4
raw
history blame
2.66 kB
import gradio as gr
import yaml
class WrappedGradioObject:
    """Base wrapper pairing a Gradio element with a display title.

    Subclasses add alternate constructors (such as ``from_yaml``) that build
    the underlying Gradio element; this base class only stores the element,
    delegates ``launch`` to it, and provides a YAML-reading helper.
    """

    def __init__(self, title, gradio_element):
        # Display title; tab wrappers use it as the tab label for this object.
        self.title = title
        # The wrapped Gradio object (e.g. an Interface or TabbedInterface).
        self.gradio_element = gradio_element

    def launch(self, *args, **kwargs):
        """Launch the wrapped Gradio element.

        Any positional or keyword arguments are forwarded unchanged to the
        wrapped element's own ``launch`` method, so launch options can be
        supplied through the wrapper.

        Returns:
            Whatever the wrapped element's ``launch`` returns.
        """
        return self.gradio_element.launch(*args, **kwargs)

    @staticmethod
    def read_yaml(path):
        """Parse the YAML file at ``path`` and return its contents.

        ``yaml.safe_load`` is used so only plain YAML types are constructed.
        The file is decoded as UTF-8 explicitly so that non-ASCII titles or
        descriptions load the same way regardless of the platform's default
        locale encoding.

        Returns:
            The parsed YAML document (``None`` for an empty file).
        """
        with open(path, encoding="utf-8") as f:
            return yaml.safe_load(f)
class GradioInterfaceWrapper(WrappedGradioObject):
    """Wraps a single Gradio Interface loaded by name via ``gr.Interface.load``."""

    @classmethod
    def from_yaml(cls, path):
        """Initialize the Interface from a YAML file.

        The top-level YAML mapping is unpacked as keyword arguments to
        ``create_interface``, so it must contain ``name``, ``title`` and
        ``description`` and may contain ``examples``.
        """
        content_dict = cls.read_yaml(path)
        return cls.create_interface(**content_dict)

    @classmethod
    def create_interface(cls, name, title, description, examples=None):
        """Create the Gradio element containing a loaded interface.

        Args:
            name: Identifier passed to ``gr.Interface.load``. A leading
                ``"huggingface/"`` is stripped when building the page link.
            title: Tab name for this interface; also the link text that is
                prepended to the description.
            description: Text shown after the link.
            examples: Optional examples forwarded to ``gr.Interface.load``.

        Returns:
            A new instance wrapping the loaded interface.
        """
        description = cls._prepend_link_to_description(name, title, description)
        # NOTE(review): ``gr.Interface.load`` is deprecated in newer Gradio
        # releases in favour of ``gr.load`` — confirm against the pinned version.
        interface = gr.Interface.load(
            name,
            title=None,  # Having the Tab-Name is sufficient.
            description=description,
            examples=examples,
        )
        return cls(title, interface)

    @staticmethod
    def _prepend_link_to_description(name, title, description):
        """Return ``description`` prefixed by an HTML link to the model page.

        ``"huggingface/<repo>"`` maps to ``https://huggingface.co/<repo>``;
        names without that prefix are appended to the base URL unchanged.
        The link and description are separated by the void element ``<br>``
        (``</br>`` is not valid HTML).
        """
        without_huggingface = name.removeprefix("huggingface/")
        link = f"https://huggingface.co/{without_huggingface}"
        return f'<a href="{link}">{title}</a> <br> {description}'
class GradioTabWrapper(WrappedGradioObject):
    """Wraps a ``gr.TabbedInterface`` whose tabs are other wrapped objects."""

    @classmethod
    def from_gradio_object_list(cls, title, gradio_objects):
        """Construct a GradioTabWrapper from a title and WrappedGradioObjects.

        Args:
            title: Title of the resulting wrapper (its tab label when it is
                itself nested inside another tab wrapper).
            gradio_objects: Iterable of WrappedGradioObject instances; each
                becomes one tab labelled with its ``title``.

        Returns:
            A new instance wrapping the ``gr.TabbedInterface``.
        """
        # Materialize once: the objects are traversed twice below, and a
        # one-shot iterator (e.g. a generator) would otherwise be exhausted
        # by the first pass, leaving no tab names for the second.
        gradio_objects = list(gradio_objects)
        interface = gr.TabbedInterface(
            [obj.gradio_element for obj in gradio_objects],
            [obj.title for obj in gradio_objects],
        )
        return cls(title, interface)

    @classmethod
    def from_yaml(cls, path):
        """Build a tabbed interface from a YAML file.

        The YAML mapping must provide ``title`` and ``dependencies``. Each
        dependency is a path relative to ``resources/`` that is loaded with
        ``_read_dependency`` and becomes one tab, in the listed order.
        """
        content_dict = cls.read_yaml(path)
        gradio_objects = [
            cls._read_dependency(dependency)
            for dependency in content_dict["dependencies"]
        ]
        return cls.from_gradio_object_list(content_dict["title"], gradio_objects)

    @classmethod
    def _read_dependency(cls, path):
        """Load the wrapped object described by ``resources/<path>``.

        The object type is inferred from the path prefix: ``interfaces...``
        yields a GradioInterfaceWrapper, ``tabs...`` yields a nested tab
        wrapper (built via ``cls.from_yaml``, so subclasses nest themselves).

        Raises:
            ValueError: If ``path`` starts with neither prefix.
        """
        # NOTE(review): this relative path is resolved against the current
        # working directory, so the app must be started from the directory
        # that contains ``resources/``. Tab YAMLs that reference each other
        # in a cycle would recurse without bound.
        full_path = f"resources/{path}"
        if path.startswith("interfaces"):
            return GradioInterfaceWrapper.from_yaml(full_path)
        if path.startswith("tabs"):
            return cls.from_yaml(full_path)
        raise ValueError(
            "Gradio Object Type could not be inferred from path name. Make sure "
            "that all interface object yamls are in resources/interfaces, and that "
            "all tab object yamls are in resources/tabs."
        )