File size: 1,760 Bytes
91785e6
 
2d6504a
91785e6
 
2d6504a
91785e6
2d6504a
91785e6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2d6504a
91785e6
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import pandas as pd
import numpy as np
import gradio as gr
import torch
from transformers import AutoModelForMultipleChoice, AutoTokenizer

model_id = "deepset/deberta-v3-large-squad2"

# Fetch the tokenizer and the checkpoint wrapped with a multiple-choice head.
# NOTE(review): this checkpoint is named for SQuAD2 extractive QA; if it does
# not ship a multiple-choice head, transformers initializes one from scratch
# and predictions will not be meaningful until fine-tuned — confirm.
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForMultipleChoice.from_pretrained(model_id)

# Define the preprocessing function
def preprocess(sample):
    """Tokenize one question as five (prompt, option) pairs.

    Args:
        sample: mapping with a ``"prompt"`` key and one key per option letter
            ``"A"`` through ``"E"``.

    Returns:
        The same ``sample`` object, updated in place with ``"input_ids"`` and
        ``"attention_mask"`` tensors holding one padded row per option.
    """
    choices = "ABCDE"
    # Pair the identical prompt with each candidate answer.
    prompts = [sample["prompt"] for _ in choices]
    answers = [sample[letter] for letter in choices]
    encoded = tokenizer(
        prompts,
        answers,
        truncation=True,
        padding=True,
        return_tensors="pt",
    )
    for key in ("input_ids", "attention_mask"):
        sample[key] = encoded[key]
    return sample

# Define the prediction function
def predict(data):
    """Rank the answer options for a batch of tokenized questions.

    Args:
        data: mapping whose ``"input_ids"`` and ``"attention_mask"`` values are
            sequences of tensors that ``torch.stack`` joins into one batch.
            NOTE(review): each element is presumably a (num_choices, seq_len)
            tensor from ``preprocess``; all must share a shape to be stacked.

    Returns:
        One string per batch row with the three highest-scoring option
        letters, best first (e.g. ``"CAE"``).
    """
    input_batch = torch.stack(data["input_ids"])
    mask_batch = torch.stack(data["attention_mask"])
    # Inference only, so skip autograd bookkeeping.
    with torch.no_grad():
        scores = model(input_batch, attention_mask=mask_batch).logits
    # Ascending sort of the negated scores gives best-first option indices.
    ranked = torch.argsort(-scores, dim=1).tolist()
    letters = "ABCDE"
    return ["".join(letters[idx] for idx in row[:3]) for row in ranked]
# Input widgets: one box for the question stem plus one per answer option,
# matching the description shown in the UI.
text = gr.Textbox(label="Prompt", placeholder="paste multiple choice questions.....")
option_boxes = [gr.Textbox(label=f"Option {letter}") for letter in "ABCDE"]
label = gr.Label(num_top_classes=3)


def _answer(prompt, a, b, c, d, e):
    """Rank the five options of one question and return the top three letters.

    Bridges the UI to the model code: builds the sample dict ``preprocess``
    expects, then wraps each tokenized tensor in a one-element list so that
    ``predict`` can ``torch.stack`` it into a batch of size 1.

    Returns:
        The single ranking string from ``predict`` (e.g. ``"BDA"``), which
        ``gr.Label`` displays as the predicted label.
    """
    sample = {"prompt": prompt, "A": a, "B": b, "C": c, "D": d, "E": e}
    encoded = preprocess(sample)
    batch = {
        "input_ids": [encoded["input_ids"]],
        "attention_mask": [encoded["attention_mask"]],
    }
    return predict(batch)[0]


# Create the Gradio interface.
# Live mode is left off: every run executes the full model, so inference
# should happen on Submit rather than on each keystroke.
iface = gr.Interface(
    fn=_answer,
    inputs=[text, *option_boxes],
    outputs=label,
    # One list of input values per example, in the same order as `inputs`.
    examples=[
        [
            "This is the prompt",
            "Option A text",
            "Option B text",
            "Option C text",
            "Option D text",
            "Option E text",
        ],
    ],
    title="LLM Science Exam Demo",
    description="Enter the prompt and options (A to E) below and get predictions.",
)

# Run the interface
iface.launch()