NeuroSpaceX committed
Commit 3230342 · verified · 1 Parent(s): 2e41a45

Update app.py

Files changed (1)
  1. app.py +15 -4
app.py CHANGED
@@ -1,11 +1,17 @@
 import gradio as gr
 import torch
 from transformers import AutoTokenizer, AutoModelForSequenceClassification
+import os
 
 def load_model(model_name):
+    token = os.getenv("HG_TOKEN")
+    if not token:
+        raise ValueError("Hugging Face API token not found. Please set HG_TOKEN environment variable.")
+
     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-    model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=1).to(device).eval()
-    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    # Load model with authentication token if necessary
+    model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=1, use_auth_token=token).to(device).eval()
+    tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=token)
     return model, tokenizer, device
 
 def classify_message(model, tokenizer, device, message, model_name):
@@ -26,7 +32,12 @@ def classify_message(model, tokenizer, device, message, model_name):
     return "Спам" if is_spam else "Не спам"
 
 def spam_classifier_interface(message, model_choice):
-    model_name = 'NeuroSpaceX/ruSpamNS_v1' if model_choice == "Model v1" else 'NeuroSpaceX/ruSpamNS_v6'
+    # Choose model based on user's choice
+    model_name = {
+        "Model v1": 'NeuroSpaceX/ruSpamNS_v1',
+        "Model v6": 'NeuroSpaceX/ruSpamNS_v6',
+        "Model v7beta": 'NeuroSpaceX/ruSpamNS_v7'
+    }[model_choice]
     model, tokenizer, device = load_model(model_name)
     return classify_message(model, tokenizer, device, message, model_name)
 
@@ -35,7 +46,7 @@ interface = gr.Interface(
     fn=spam_classifier_interface,
     inputs=[
         gr.Textbox(label="Введите сообщение для классификации", placeholder="Введите текст..."),
-        gr.Radio(["Model v1", "Model v6"], label="Выберите модель")
+        gr.Radio(["Model v1", "Model v6", "Model v7beta"], label="Выберите модель")
     ],
     outputs=gr.Textbox(label="Результат"),
     title="Классификатор Спам/Не Спам",