Pooya-Fallah's picture
Create app.py
27517fc verified
import torch
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin
import transformers
import hazm
import gradio as gr
# Define class of the model
class ParsbertHallu(nn.Module, PyTorchModelHubMixin):
def __init__(self):
super().__init__()
self.transformer_model = transformers.AutoModelForSequenceClassification.from_pretrained("Pooya-Fallah/ParsBERT-nli-FarsTail-FarSick",
num_labels=3)
self.head = nn.Sequential(
nn.Linear(3,1),
nn.Sigmoid()
)
def forward(self, x):
out = self.transformer_model(**x)['logits']
return torch.squeeze(self.head(out))
# Example Inputs
example_1 = [
"""هنگامی که به تناسب اندام فکر می کنید، مصرف غذاهای پروتئین دار قبل و بعد از بدنسازی می تواند پروتئین لازم را برای کمک به رشد عضلات، فراهم کند. در این مطلب، بهترین منابع غذاهای پروتئین دار و میوه‌ های غنی از پروتئین را معرفی می کنیم. تمرینات بدنسازی معمولا سنگین بوده و نیازمند زمان قابل توجهی از روز هستند. اما برای بدنسازی پیش از هر کاری بهتر است، فهرستی از غذاهای پروتئین دار قبل و بعد از تمرین تهیه کنید. علاوه بر یک رژیم ورزشی خوب، رژیم غذایی مناسب نیز برای رسیدن به اندامی دلخواه، لازم است. در هر حال در بدنسازی هم رژیم غذایی و غذاهای پروتئین دار قبل و بعد از تمرین، امر مهمی است."""
,
"""چه ماده غذایی برای بدن سازی مهم است؟"""
,
"""مهم‌ترین ماده غذایی برای بدن سازی، رنگین‌کمان‌های جادویی هستند که باعث افزایش عضلات و کاهش چربی می‌شوند!"""
]
example_2 = [
"""شمشیر بازی نوعی رشته ورزشی است که دو ورزشکار در آن با استفاده از یک شمشیر کوچک با هم مبارزه می‌کنند. شمشیربازی یکی از پنج رشته ورزشی است که در همه دوره‌های بازیهای المپیک برگزار شده‌است (چهار رشته دیگر دو و میدانی، شنا، ژیمناستیک و دوچرخه‌سواری هستند). مسابقات این رشته در سه بخش فلوره، اپه و سابر برگزار می‌شود که از نظر نوع شمشیر مورد استفاده و قوانین بازی با یکدیگر تفاوت دارند."""
,
"""شمشیرهای رشته شمشیربازی؟"""
,
"""فلوره، اپه و سابر"""
]
# Hazm normalizer
normalizer = hazm.Normalizer()
# tokenizer is from ParsBERT (HooshvareLab/bert-fa-zwnj-base)
tokenizer = transformers.AutoTokenizer.from_pretrained('HooshvareLab/bert-fa-zwnj-base')
# load model
model = ParsbertHallu.from_pretrained("Pooya-Fallah/ParsbertHallu")
def get_hallucination_label(knowledge, question, answer):
knowledge = normalizer.normalize(knowledge)
question = normalizer.normalize(question)
answer = normalizer.normalize(answer)
tokens = tokenizer(knowledge, question + " " + answer, truncation=True, padding=True,
max_length=512, return_tensors='pt')
prob = round(model(tokens).item(), 2)
return {"Hallucinated": prob, "Not-Hallucinated": 1-prob}
demo = gr.Interface(fn=get_hallucination_label, inputs=[gr.TextArea(lines=7, placeholder="knowledge"), gr.TextArea(lines=3, placeholder="question"), gr.TextArea(lines=3, placeholder="answer")],
outputs=gr.Label(num_top_classes=2), examples=[example_1, example_2],
title="Hallucination Detection Demo for Persian Question Answering Task",
description="A straightforward binary classifier that determines whether the generated answer is hallucinated or not."
)
demo.launch()