# app.py — Hugging Face Space by raminass (commit f80cc50, "create app.py", 1.32 kB).
import gradio as gr
from transformers import pipeline, TextClassificationPipeline
# Text-classification pipeline for the "raminass/scotus-v10" checkpoint.
# top_k=13 requests 13 scored labels per input (NOTE(review): presumably the
# model's full label set — confirm). padding/truncation are passed through
# the pipeline kwargs, presumably to the tokenizer so over-long inputs are cut
# to the model's maximum length — confirm against the transformers version.
pipe = pipeline(
    model="raminass/scotus-v10",
    top_k=13,
    padding=True,
    truncation=True,
)
def average_text(text, model):
    """Classify a batch of text chunks and aggregate scores per label.

    Args:
        text: inputs passed straight to ``model``; in this file, a list of
            paragraph strings.
        model: callable returning, for each input, an iterable of
            ``{'label': ..., 'score': ...}`` dicts (e.g. the module-level
            text-classification pipeline created with ``top_k``).

    Returns:
        A 2-tuple ``(summary, per_chunk)``:

        * ``summary``: dict mapping each label to the mean of its scores
          over all chunks, rounded to 2 decimals, ordered by descending mean.
        * ``per_chunk``: the model output with every ``'score'`` value
          rounded to 2 decimals (other keys unchanged).
    """
    # Materialise once: the output is walked twice below, which would silently
    # yield nothing the second time if ``model`` returned a one-shot iterator.
    result = list(model(text))

    scores_by_label = {}
    for chunk_predictions in result:
        for prediction in chunk_predictions:
            # Keep full precision here. Rounding each score before averaging
            # and then rounding the mean again (double rounding) can shift
            # the reported average by 0.01.
            scores_by_label.setdefault(prediction['label'], []).append(prediction['score'])

    summary = {
        label: round(sum(values) / len(values), 2)
        for label, values in scores_by_label.items()
    }
    per_chunk = [
        [{k: round(v, 2) if k == 'score' else v for k, v in dct.items()} for dct in lst]
        for lst in result
    ]
    ranked = dict(sorted(summary.items(), key=lambda item: item[1], reverse=True))
    return ranked, per_chunk
def greet(opinion):
    """Gradio click handler: score an opinion and return the label summary.

    The opinion text goes through ``remove_citations`` and then
    ``chunk_data``. The ``'text'`` column of the chunked result is classified
    with the module-level ``pipe`` via ``average_text``.

    NOTE(review): ``remove_citations`` and ``chunk_data`` are neither defined
    nor imported in this file as shown, so this raises NameError unless they
    are supplied elsewhere (e.g. a helper module) — confirm. ``chunk_data``
    presumably returns a pandas DataFrame with a ``'text'`` column — confirm.

    Returns:
        The first element of ``average_text``'s result: a dict of
        label -> averaged score, highest score first.
    """
    without_citations = remove_citations(opinion)
    chunks = chunk_data(without_citations)
    paragraphs = chunks['text'].to_list()
    averaged = average_text(paragraphs, pipe)
    # Only the aggregate summary is returned to the UI; the per-paragraph
    # scores in averaged[1] are dropped.
    return averaged[0]
# Gradio UI: an input box for the opinion text, an output box, and a button
# that runs `greet` on the input. Gradio Blocks lays components out in the
# order they are created inside the context, so statement order here is the
# page layout.
with gr.Blocks() as demo:
    opinion = gr.Textbox(label="Opinion")
    # `greet` returns a dict (label -> score). NOTE(review): presumably shown
    # as its string form in a Textbox; gr.Label may render it better — confirm.
    output = gr.Textbox(label="Result")
    greet_btn = gr.Button("Predict")
    # api_name names this event's endpoint in the app's API.
    greet_btn.click(fn=greet, inputs=opinion, outputs=output, api_name="SCOTUS")
if __name__ == "__main__":
    # Launch the server only when run as a script, not when imported.
    demo.launch()