import torch import torch.nn as nn from huggingface_hub import PyTorchModelHubMixin import transformers import hazm import gradio as gr # Define class of the model class ParsbertHallu(nn.Module, PyTorchModelHubMixin): def __init__(self): super().__init__() self.transformer_model = transformers.AutoModelForSequenceClassification.from_pretrained("Pooya-Fallah/ParsBERT-nli-FarsTail-FarSick", num_labels=3) self.head = nn.Sequential( nn.Linear(3,1), nn.Sigmoid() ) def forward(self, x): out = self.transformer_model(**x)['logits'] return torch.squeeze(self.head(out)) # Example Inputs example_1 = [ """هنگامی که به تناسب اندام فکر می کنید، مصرف غذاهای پروتئین دار قبل و بعد از بدنسازی می تواند پروتئین لازم را برای کمک به رشد عضلات، فراهم کند. در این مطلب، بهترین منابع غذاهای پروتئین دار و میوه‌ های غنی از پروتئین را معرفی می کنیم. تمرینات بدنسازی معمولا سنگین بوده و نیازمند زمان قابل توجهی از روز هستند. اما برای بدنسازی پیش از هر کاری بهتر است، فهرستی از غذاهای پروتئین دار قبل و بعد از تمرین تهیه کنید. علاوه بر یک رژیم ورزشی خوب، رژیم غذایی مناسب نیز برای رسیدن به اندامی دلخواه، لازم است. در هر حال در بدنسازی هم رژیم غذایی و غذاهای پروتئین دار قبل و بعد از تمرین، امر مهمی است.""" , """چه ماده غذایی برای بدن سازی مهم است؟""" , """مهم‌ترین ماده غذایی برای بدن سازی، رنگین‌کمان‌های جادویی هستند که باعث افزایش عضلات و کاهش چربی می‌شوند!""" ] example_2 = [ """شمشیر بازی نوعی رشته ورزشی است که دو ورزشکار در آن با استفاده از یک شمشیر کوچک با هم مبارزه می‌کنند. شمشیربازی یکی از پنج رشته ورزشی است که در همه دوره‌های بازیهای المپیک برگزار شده‌است (چهار رشته دیگر دو و میدانی، شنا، ژیمناستیک و دوچرخه‌سواری هستند). مسابقات این رشته در سه بخش فلوره، اپه و سابر برگزار می‌شود که از نظر نوع شمشیر مورد استفاده و قوانین بازی با یکدیگر تفاوت دارند.""" , """شمشیرهای رشته شمشیربازی؟""" , """فلوره، اپه و سابر""" ] # Hazm normalizer normalizer = hazm.Normalizer() # tokenizer is from ParsBERT (HooshvareLab/bert-fa-zwnj-base) tokenizer = transformers.AutoTokenizer.from_pretrained('HooshvareLab/bert-fa-zwnj-base') # load model model = ParsbertHallu.from_pretrained("Pooya-Fallah/ParsbertHallu") def get_hallucination_label(knowledge, question, answer): knowledge = normalizer.normalize(knowledge) question = normalizer.normalize(question) answer = normalizer.normalize(answer) tokens = tokenizer(knowledge, question + " " + answer, truncation=True, padding=True, max_length=512, return_tensors='pt') prob = round(model(tokens).item(), 2) return {"Hallucinated": prob, "Not-Hallucinated": 1-prob} demo = gr.Interface(fn=get_hallucination_label, inputs=[gr.TextArea(lines=7, placeholder="knowledge"), gr.TextArea(lines=3, placeholder="question"), gr.TextArea(lines=3, placeholder="answer")], outputs=gr.Label(num_top_classes=2), examples=[example_1, example_2], title="Hallucination Detection Demo for Persian Question Answering Task", description="A straightforward binary classifier that determines whether the generated answer is hallucinated or not." ) demo.launch()