Annalyn Ng commited on
Commit
600139e
1 Parent(s): e8d5985

enable GPU use

Browse files
Files changed (1) hide show
  1. app.py +3 -1
app.py CHANGED
@@ -5,11 +5,13 @@ import plotly.express as px
5
  import torch
6
  from transformers import AutoTokenizer, AutoModelForMaskedLM
7
 
 
8
 
9
  model_checkpoint = "facebook/xlm-v-base"
10
 
11
  tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
12
  model = AutoModelForMaskedLM.from_pretrained(model_checkpoint)
 
13
  mask_token = tokenizer.mask_token
14
 
15
 
@@ -27,7 +29,7 @@ def eval_prob(target_word, text):
27
  target_idx = tokenizer.encode(target_word)[-2]
28
 
29
  # Convert masked text to token IDs
30
- inputs = tokenizer(text_masked, return_tensors="pt")
31
 
32
  # Calculate logits score (for each token, for each position)
33
  token_logits = model(**inputs).logits
 
5
  import torch
6
  from transformers import AutoTokenizer, AutoModelForMaskedLM
7
 
8
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
9
 
10
  model_checkpoint = "facebook/xlm-v-base"
11
 
12
  tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
13
  model = AutoModelForMaskedLM.from_pretrained(model_checkpoint)
14
+ model = model.to(device)
15
  mask_token = tokenizer.mask_token
16
 
17
 
 
29
  target_idx = tokenizer.encode(target_word)[-2]
30
 
31
  # Convert masked text to token IDs
32
+ inputs = tokenizer(text_masked, return_tensors="pt").to(device)
33
 
34
  # Calculate logits score (for each token, for each position)
35
  token_logits = model(**inputs).logits