DexterSptizu commited on
Commit
543559a
1 Parent(s): d8c8393

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -18
app.py CHANGED
@@ -1,18 +1,33 @@
1
- import gradio as gr
2
- import torch
3
- from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
4
-
5
- # Load pre-trained model and tokenizer
6
- tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
7
- model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
8
-
9
- def classify_text(text):
10
- inputs = tokenizer(text, return_tensors="pt")
11
- with torch.no_grad():
12
- logits = model(**inputs).logits
13
- predicted_class_id = logits.argmax().item()
14
- return model.config.id2label[predicted_class_id]
15
-
16
- # Create Gradio interface
17
- interface = gr.Interface(fn=classify_text, inputs="text", outputs="label")
18
- interface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
4
+
5
+ # Load pre-trained model and tokenizer
6
+ tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
7
+ model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
8
+
9
+ def classify_text(text):
10
+ inputs = tokenizer(text, return_tensors="pt")
11
+ with torch.no_grad():
12
+ logits = model(**inputs).logits
13
+ predicted_class_id = logits.argmax().item()
14
+ return model.config.id2label[predicted_class_id]
15
+
16
+ def try_launch(interface, port, max_attempts=5):
17
+ current_port = port
18
+ attempt = 0
19
+ while attempt < max_attempts:
20
+ try:
21
+ interface.launch(server_port=current_port)
22
+ print(f"Gradio running on http://localhost:{current_port}")
23
+ break
24
+ except OSError as e:
25
+ print(f"Port {current_port} is in use, trying next port.")
26
+ current_port += 1
27
+ attempt += 1
28
+ else:
29
+ print("Failed to find an open port.")
30
+
31
+ # Create Gradio interface
32
+ interface = gr.Interface(fn=classify_text, inputs="text", outputs="label")
33
+ try_launch(interface, 7861)