# Spaces: Runtime error
import os

import torch
from transformers import XLNetTokenizer, XLNetForSequenceClassification
import gradio as gr
from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
# ID of the fine-tuned model file in Google Drive.
file_id = '1-7O5gAFgcIzgJ68WkSSpmh1H6kJL6fAO'  # Replace this with your file's ID from Google Drive
# Local path for the downloaded weights. The previous '/content/...' location only
# exists on Google Colab, so default to the working directory instead; it can be
# overridden with the XLNET_MODEL_PATH environment variable.
destination_path = os.environ.get('XLNET_MODEL_PATH', 'XLNet_model_project_Core.pt')

# Only authenticate and download when the weights are not already on disk (for
# example when the .pt file is shipped next to this script). LocalWebserverAuth()
# starts an interactive browser-based OAuth flow, which a headless server cannot
# complete, so skipping it when possible lets the app start unattended.
# NOTE(review): pydrive's default settings also expect a client_secrets.json for
# this flow — confirm it is available wherever a download is actually needed.
if not os.path.exists(destination_path):
    os.makedirs(os.path.dirname(os.path.abspath(destination_path)), exist_ok=True)
    gauth = GoogleAuth()
    gauth.LocalWebserverAuth()
    drive = GoogleDrive(gauth)
    downloaded_file = drive.CreateFile({'id': file_id})
    downloaded_file.GetContentFile(destination_path)
# Build the tokenizer and the base XLNet classifier (two output labels), then
# overwrite its weights with the fine-tuned checkpoint at destination_path.
tokenizer = XLNetTokenizer.from_pretrained('xlnet-base-cased')
model = XLNetForSequenceClassification.from_pretrained('xlnet-base-cased', num_labels=2)
# map_location='cpu' lets a checkpoint saved from GPU tensors (e.g. in Colab)
# deserialize on a CPU-only host; without it torch.load raises when CUDA is
# unavailable. load_state_dict copies the values into the model's existing
# parameters, so the model's device is unaffected.
# NOTE(review): torch.load unpickles the file — only load checkpoints from a
# trusted source.
model.load_state_dict(torch.load(destination_path, map_location='cpu'))
# Switch to inference mode (disables dropout).
model.eval()
# Function for prediction
def xl_net_predict(text):
    """Classify a bug-report summary as "Severe" or "Non-severe".

    Args:
        text: The summary text to classify (the string entered in the Gradio textbox).

    Returns:
        "Severe" when the model's highest-scoring class is 1, otherwise "Non-severe".
    """
    # Inputs longer than 100 tokens are truncated.
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=100)
    with torch.no_grad():  # inference only; no gradients needed
        logits = model(**inputs).logits
    # Take the best class along the class axis for the first (only) example.
    # softmax is monotonic, so argmax over raw logits selects the same class;
    # a dimensionless torch.argmax would flatten batch and class axes together.
    predicted_class = logits.argmax(dim=-1)[0].item()
    return "Severe" if predicted_class == 1 else "Non-severe"
# Gradio UI: one text box in, the predicted severity label out.
iface = gr.Interface(
    fn=xl_net_predict,
    inputs=gr.Textbox(lines=2, label="Summary", placeholder="Enter text here..."),
    outputs=gr.Textbox(label="Predicted Severity"),
    title="XLNet Based Bug Report Severity Prediction",
    description="Enter text and predict its severity (Severe or Non-severe).",
    # NOTE(review): "huggingface" may not be a built-in theme name in recent
    # Gradio releases — confirm it resolves, otherwise Gradio falls back to default.
    theme="huggingface",
    examples=[
        ["Can't open multiple bookmarks at once from the bookmarks sidebar using the context menu"],
        ["Minor enhancements to make-source-package.sh"],
    ],
    # Gradio expects "never", "auto" or "manual"; boolean False is a deprecated
    # alias for "never".
    allow_flagging="never",
)

# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    iface.launch()