kunwarsaaim commited on
Commit
2e1968d
1 Parent(s): 490c46c

gpu inference

Browse files
Files changed (1) hide show
  1. app.py +4 -1
app.py CHANGED
@@ -16,6 +16,9 @@ DEBIASING_KEYWORDS = [
16
  "(rude) ", "(sexually explicit) ", "(hateful) ", "(aggressive) ", "(racist) ", "(threat) ", "(violent) ", "(sexist) "
17
  ]
18
 
 
 
 
19
  def debias(prompt, model,use_prefix, max_length=50, num_beam=3):
20
  """
21
  Debiasing inference function.
@@ -24,7 +27,7 @@ def debias(prompt, model,use_prefix, max_length=50, num_beam=3):
24
  :param max_length: The maximum length of the output sentence.
25
  :return: The debiased output sentence.
26
  """
27
- wrapper = GPT2Wrapper(model_name=str(model), use_cuda=False)
28
  if use_prefix == 'Prefixes':
29
  debiasing_prefixes = DEBIASING_PREFIXES
30
  else:
 
16
  "(rude) ", "(sexually explicit) ", "(hateful) ", "(aggressive) ", "(racist) ", "(threat) ", "(violent) ", "(sexist) "
17
  ]
18
 
19
+ if torch.cuda.is_available():
20
+ use_cuda = True
21
+
22
  def debias(prompt, model,use_prefix, max_length=50, num_beam=3):
23
  """
24
  Debiasing inference function.
 
27
  :param max_length: The maximum length of the output sentence.
28
  :return: The debiased output sentence.
29
  """
30
+ wrapper = GPT2Wrapper(model_name=str(model), use_cuda=use_cuda)
31
  if use_prefix == 'Prefixes':
32
  debiasing_prefixes = DEBIASING_PREFIXES
33
  else: