AliArshad commited on
Commit
5f847ad
1 Parent(s): bf429df

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -39
app.py CHANGED
@@ -1,48 +1,49 @@
1
- import torch
2
  import requests
3
- from transformers import XLNetTokenizer
 
4
  import gradio as gr
5
 
6
- # Link to the saved model on Hugging Face Spaces
7
- model_link = 'https://huggingface.co/spaces/AliArshad/SeverityPrediction/blob/main/severitypredictor.pt'
8
 
9
- # Download the model file
10
- response = requests.get(model_link)
11
- model_path = 'severitypredictor.pt'
12
- with open(model_path, 'wb') as f:
13
- f.write(response.content)
14
 
15
- # Try loading the downloaded file as a PyTorch model
16
- try:
17
- model = torch.load(model_path)
18
- tokenizer = XLNetTokenizer.from_pretrained('xlnet-base-cased')
19
 
20
- # Function for prediction
21
- def xl_net_predict(text):
22
- inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=100)
23
- with torch.no_grad():
24
- outputs = model(**inputs)
25
- logits = outputs.logits
26
- probabilities = torch.softmax(logits, dim=1)
27
- predicted_class = torch.argmax(probabilities).item()
28
- return "Severe" if predicted_class == 1 else "Non-severe"
29
 
30
- # Customizing the interface
31
- iface = gr.Interface(
32
- fn=xl_net_predict,
33
- inputs=gr.Textbox(lines=2, label="Summary", placeholder="Enter text here..."),
34
- outputs=gr.Textbox(label="Predicted Severity"),
35
- title="XLNet Based Bug Report Severity Prediction",
36
- description="Enter text and predict its severity (Severe or Non-severe).",
37
- theme="huggingface",
38
- examples=[
39
- ["Can't open multiple bookmarks at once from the bookmarks sidebar using the context menu"],
40
- ["Minor enhancements to make-source-package.sh"]
41
- ],
42
- allow_flagging=False
43
- )
44
 
45
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
- except Exception as e:
48
- print(f"An error occurred: {e}")
 
 
1
  import requests
2
+ import torch
3
+ from transformers import XLNetTokenizer, XLNetForSequenceClassification
4
  import gradio as gr
5
 
6
+ # URL of the saved model on GitHub
7
+ model_url = 'https://github.com/AliArshadswl/severity_prediction/raw/main/XLNet_model_project_Core.pt'
8
 
9
+ # Function to download the model from URL and load it
10
+ def download_model(url):
11
+ response = requests.get(url)
12
+ with open('XLNet_model_project_Core.pt', 'wb') as f:
13
+ f.write(response.content)
14
 
15
+ # Download the model
16
+ download_model(model_url)
 
 
17
 
18
+ # Load the saved model
19
+ tokenizer = XLNetTokenizer.from_pretrained('xlnet-base-cased')
20
+ model = XLNetForSequenceClassification.from_pretrained('xlnet-base-cased', num_labels=2)
21
+ model.load_state_dict(torch.load('XLNet_model_project_Core.pt', map_location=torch.device('cpu')))
22
+ model.eval()
 
 
 
 
23
 
24
+ # Function for prediction
25
+ def xl_net_predict(text):
26
+ inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=100)
27
+ with torch.no_grad():
28
+ outputs = model(**inputs)
29
+ logits = outputs.logits
30
+ probabilities = torch.softmax(logits, dim=1)
31
+ predicted_class = torch.argmax(probabilities).item()
32
+ return "Severe" if predicted_class == 1 else "Non-severe"
 
 
 
 
 
33
 
34
+ # Customizing the interface
35
+ iface = gr.Interface(
36
+ fn=xl_net_predict,
37
+ inputs=gr.Textbox(lines=2, label="Summary", placeholder="Enter text here..."),
38
+ outputs=gr.Textbox(label="Predicted Severity"),
39
+ title="XLNet Based Bug Report Severity Prediction",
40
+ description="Enter text and predict its severity (Severe or Non-severe).",
41
+ theme="huggingface",
42
+ examples=[
43
+ ["Can't open multiple bookmarks at once from the bookmarks sidebar using the context menu"],
44
+ ["Minor enhancements to make-source-package.sh"]
45
+ ],
46
+ allow_flagging=False
47
+ )
48
 
49
+ iface.launch()