# app.py — Hugging Face Space by gianma
# Last commit: "Update app.py" (9893cb3), 1.13 kB
from functools import lru_cache
from os import environ

import gradio as gr
from transformers import pipeline
# Auth token for the model repo, read from the environment (None if unset).
api_key = environ.get("api_key")


@lru_cache(maxsize=1)
def _get_classifier():
    """Build the text-classification pipeline on first use and reuse it afterwards.

    Constructing the pipeline loads the model and tokenizer, so doing it once
    (instead of on every request) avoids repeating that cost per call.
    """
    return pipeline(model="gianma/testModel", tokenizer='gianma/testModel', use_auth_token=api_key)


def _filter_scores(predictions, filter_strategy, relevence_threshold):
    """Turn pipeline predictions into the ``{label: score}`` dict Gradio's label output expects.

    Args:
        predictions: Iterable of dicts with ``'label'`` and ``'score'`` keys.
        filter_strategy: When ``'soglia_di_confidenza'`` ("confidence threshold"),
            only labels whose score is >= ``relevence_threshold`` are kept;
            any other value returns every prediction unchanged.
        relevence_threshold: Minimum score kept by the threshold strategy.

    Returns:
        Dict mapping label names to scores.
    """
    scores = {el['label']: el['score'] for el in predictions}
    if filter_strategy == 'soglia_di_confidenza':
        # Compare the *score* to the threshold (comparing the label string to a
        # float raises TypeError).
        scores = {label: score for label, score in scores.items() if score >= relevence_threshold}
    return scores


def app(input, filter_strategy, relevence_threshold, k):
    """Classify ``input`` and return ``{label: score}`` for the Gradio ``"label"`` output.

    Args:
        input: Text to classify.
        filter_strategy: ``'top_k'`` keeps the ``k`` highest-scoring labels;
            ``'soglia_di_confidenza'`` ("confidence threshold") keeps every
            label whose score is at least ``relevence_threshold``.
        relevence_threshold: Minimum score, used only by the threshold strategy.
        k: Number of labels to return, used only by the ``'top_k'`` strategy.
            Cast to ``int`` because slider values may arrive as floats.

    Returns:
        Dict mapping label names to scores.
    """
    # Tokenizer options: pad and truncate inputs to at most 512 tokens.
    kwargs = {'padding': True, 'truncation': True, 'max_length': 512}
    if filter_strategy == 'top_k':
        kwargs['top_k'] = int(k)
    elif filter_strategy == 'soglia_di_confidenza':
        # top_k=None asks the pipeline for every label's score; without it only
        # the single best label is returned and the threshold has nothing to filter.
        kwargs['top_k'] = None
    res = _get_classifier()(input, **kwargs)
    return _filter_scores(res, filter_strategy, relevence_threshold)
# Gradio UI: a text box, a choice of filtering strategy
# ('soglia_di_confidenza' = confidence threshold, or 'top_k'), a threshold
# slider in [0, 1] and an integer k slider; results are shown as a label widget.
iface = gr.Interface(
    fn=app,
    inputs=[
        "text",
        gr.Radio(
            ['soglia_di_confidenza', 'top_k'],
            value='soglia_di_confidenza',
            label='come vuoi filtrare i risultati',  # "how do you want to filter the results"
        ),
        gr.Slider(0, 1, value=0.5, label='soglia di confidenza'),  # "confidence threshold"
        # step=1 keeps k on whole numbers (the original had a stray double comma here).
        gr.Slider(1, 21, step=1, value=4, label='k'),
    ],
    outputs="label",
)
iface.launch()