JRQi commited on
Commit
05e2529
1 Parent(s): 170a3e6

Update game3.py

Browse files
Files changed (1) hide show
  1. game3.py +5 -3
game3.py CHANGED
@@ -47,7 +47,8 @@ def func3(num_selected, human_predict, num1, num2, user_important):
47
  # Load model directly
48
  # Use a pipeline as a high-level helper
49
 
50
- classifier = pipeline("text-classification", model="padmajabfrl/Gender-Classification")
 
51
  output = classifier([text['text']])
52
 
53
  print(output)
@@ -153,7 +154,8 @@ def func3_written(text_written, human_predict, lang_written):
153
  # tokenizer = AutoTokenizer.from_pretrained("nlptown/bert-base-multilingual-uncased-sentiment")
154
  # model = AutoModelForSequenceClassification.from_pretrained("nlptown/bert-base-multilingual-uncased-sentiment")
155
 
156
- classifier = pipeline("text-classification", model="padmajabfrl/Gender-Classification")
 
157
 
158
  output = classifier([text_written])
159
 
@@ -173,7 +175,7 @@ def func3_written(text_written, human_predict, lang_written):
173
 
174
  import shap
175
 
176
- gender_classifier = pipeline("text-classification", model="padmajabfrl/Gender-Classification", return_all_scores=True)
177
 
178
  explainer = shap.Explainer(gender_classifier)
179
 
 
47
  # Load model directly
48
  # Use a pipeline as a high-level helper
49
 
50
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
51
+ classifier = pipeline("text-classification", model="padmajabfrl/Gender-Classification", device=device)
52
  output = classifier([text['text']])
53
 
54
  print(output)
 
154
  # tokenizer = AutoTokenizer.from_pretrained("nlptown/bert-base-multilingual-uncased-sentiment")
155
  # model = AutoModelForSequenceClassification.from_pretrained("nlptown/bert-base-multilingual-uncased-sentiment")
156
 
157
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
158
+ classifier = pipeline("text-classification", model="padmajabfrl/Gender-Classification", device=device)
159
 
160
  output = classifier([text_written])
161
 
 
175
 
176
  import shap
177
 
178
+ gender_classifier = pipeline("text-classification", model="padmajabfrl/Gender-Classification", return_all_scores=True, device=device)
179
 
180
  explainer = shap.Explainer(gender_classifier)
181