import streamlit as st
from huggingface_hub import InferenceClient
from PIL import Image
from pathlib import Path
import os
st.set_page_config(page_title="HF Inference Client Demo", layout="wide")
# Header of the app
def display_app_header():
    """Display the header of the Streamlit app."""
    st.title("1️⃣ HF Inference Client Demo 📊")
    st.subheader("Just a little demonstrator")
# Display the header of the app
display_app_header()
# UI sidebar parameters ####################################
st.sidebar.header("Login")
if hg_token := st.sidebar.text_input("Enter your HF token"):
    st.sidebar.info("Logged in", icon="ℹ️")
else:
    st.sidebar.warning("Enter your token to continue")
st.sidebar.header("Model") | |
selected_model = st.sidebar.radio( | |
"Choose a model or let the client do it", | |
["Not choose", "Choose"] | |
) | |
if selected_model == "Choose": | |
model = st.sidebar.text_input('Enter a model name. ex : facebook/fastspeech2-en-ljspeech') | |
else: | |
model = None | |
st.sidebar.header("Task") | |
dict_hg_tasks = { | |
"Automatic Speech Recognition":"automatic_speech_recognition", | |
"Text-to-Speech (choose model)":"text_to_speech", | |
"Image Classification":"image_classification", | |
"Image Segmentation":"image_segmentation", | |
"Object Detection":"object_detection", | |
"Text-to-Image":"text_to_image", | |
"Visual Question Answering":"visual_question_answering", | |
"Conversational":"conversational", | |
"Feature Extraction":"feature_extraction", | |
"Question Answering":"question_answering", | |
"Summarization":"summarization", | |
"Text Classification":"text_classification", | |
"Text Generation":"text_generation", | |
"Token Classification":"token_classification", | |
"Translation (choose model)":"translation", | |
} | |
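# The values above are InferenceClient method names: the chosen task is
# dispatched below with getattr(client, task)(...), so each value must match a
# method of huggingface_hub.InferenceClient.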
dict_hg_tasks_params = {
    "automatic_speech_recognition": {
        "input": "upload,url",
        "output": "text",
        "prompt": False,
        "context": False
    },
    "text_to_speech": {
        "input": "text",
        "output": "audio",
        "prompt": False,
        "context": False
    },
    "image_classification": {
        "input": "upload,url",
        "output": "image,text",
        "prompt": False,
        "context": False
    },
    "image_segmentation": {
        "input": "upload,url",
        "output": "image,text",
        "prompt": False,
        "context": False
    },
    "object_detection": {
        "input": "upload,url",
        "output": "image,text",
        "prompt": False,
        "context": False
    },
    "text_to_image": {
        "input": "text",
        "output": "image",
        "prompt": False,
        "context": False
    },
    "visual_question_answering": {
        "input": "upload,url",
        "output": "image,text",
        "prompt": True,
        "context": False
    },
    "image_to_image": {
        "input": "upload,url",
        "output": "image,text",
        "prompt": True,
        "context": False
    },
    "feature_extraction": {
        "input": "text",
        "output": "text",
        "prompt": False,
        "context": False
    },
    "conversational": {
        "input": "text",
        "output": "text",
        "prompt": False,
        "context": False
    },
    "question_answering": {
        "input": None,
        "output": "text",
        "prompt": True,
        "context": True
    },
    "text_classification": {
        "input": "text",
        "output": "text",
        "prompt": False,
        "context": False
    },
    "token_classification": {
        "input": "text",
        "output": "text",
        "prompt": False,
        "context": False
    },
    "text_generation": {
        "input": "text",
        "output": "text",
        "prompt": False,
        "context": False
    },
"text_classification": { | |
"input": "text", | |
"output": "text", | |
"prompt": False, | |
"context": False | |
}, | |
"translation": { | |
"input": "text", | |
"output": "text", | |
"prompt": False, | |
"context": False | |
}, | |
"summarization": { | |
"input": "text", | |
"output": "text", | |
"prompt": False, | |
"context": False | |
}, | |
} | |
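# Each entry describes the UI for a task: which input widgets to show
# ("upload,url" or "text"), how to render the response ("text", "audio",
# "image" or "image,text"), and whether the task needs a prompt and/or context.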
selected_task = st.sidebar.radio(
    "Choose the task you want to do",  # see https://huggingface.co/docs/huggingface_hub/guides/inference
    dict_hg_tasks.keys()
)
st.write(f"The currently selected task is: {dict_hg_tasks[selected_task]}")
with st.sidebar.expander("Tasks documentation"):
    st.write("https://huggingface.co/docs/huggingface_hub/package_reference/inference_client")
# Functions ########################################
cwd = os.getcwd()
def get_input(upload, url, text):
    """Return whichever input the user provided: upload first, then URL, then text."""
    if upload is not None:
        return upload
    if url:
        return url
    if text:
        return text
    return None  # Default return if no input was provided
def display_inputs(task):
    """Render the input widgets the task needs and return (upload, url, text)."""
    if dict_hg_tasks_params[task]["input"] == "upload,url":
        return st.file_uploader("Choose a file"), st.text_input("or enter a file URL"), ""
    elif dict_hg_tasks_params[task]["input"] == "text":
        return None, "", st.text_input("Enter a text")
    else:
        return None, "", ""
def display_prompt(task):
    """Render a question field for tasks that take a prompt."""
    if dict_hg_tasks_params[task]["prompt"] is True:
        return st.text_input("Enter a question")
    return None
def display_context(task):
    """Render a context field for tasks that take a context."""
    if dict_hg_tasks_params[task]["context"] is True:
        return st.text_area("Enter a context")
    return None
# UI main client ####################################
if selected_task:
    response = None
    task = dict_hg_tasks[selected_task]
    # Without an explicit model, InferenceClient falls back to a recommended model for the task
    if model:
        client = InferenceClient(model=model, token=hg_token)
    else:
        client = InferenceClient(token=hg_token)
    uploaded_input, url_input, text_input = display_inputs(task)
    prompt_input = display_prompt(task)
    context_input = display_context(task)
    user_input = get_input(uploaded_input, url_input, text_input)  # named to avoid shadowing the built-in input()
    if user_input is not None and prompt_input:
        # Tasks like visual question answering take both an input and a prompt
        response = getattr(client, task)(user_input, prompt=prompt_input)
    elif prompt_input and context_input:
        # Question answering takes a question and a context instead of a file or text input
        response = getattr(client, task)(question=prompt_input, context=context_input)
    elif user_input is not None:
        st.write(user_input)
        response = getattr(client, task)(user_input)
    if response is not None:
        col1, col2 = st.columns(2)
        with col1:
            if "text" in dict_hg_tasks_params[task]["output"]:
                st.write(response)
            elif "audio" in dict_hg_tasks_params[task]["output"]:
                # text_to_speech returns raw audio bytes
                Path(os.path.join(cwd, "audio.flac")).write_bytes(response)
                st.audio(os.path.join(cwd, "audio.flac"))
        with col2:
            if dict_hg_tasks_params[task]["output"] == "image,text":
                # st.image accepts uploaded files as well as URLs, unlike Image.open
                st.image(user_input)
            elif dict_hg_tasks_params[task]["output"] == "image":
                # text_to_image returns a PIL image; persist it, then display it
                response.save(os.path.join(cwd, "generated_image.png"))
                image = Image.open(os.path.join(cwd, "generated_image.png"))
                st.image(image)
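# To try this app locally (a minimal sketch, assuming the file is saved as
# app.py and the dependencies are installed):
#   pip install streamlit huggingface_hub pillow
#   streamlit run app.py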