trminhnam20082002 commited on
Commit
2b0ea98
1 Parent(s): 9922ab1
Files changed (2) hide show
  1. app.py +3 -0
  2. model.py +8 -2
app.py CHANGED
@@ -7,6 +7,7 @@ from st_utils import (
7
  )
8
  from huggingface_hub import hf_hub_download
9
  import os
 
10
 
11
  # list_files(os.getcwd())
12
 
@@ -19,6 +20,8 @@ Simply select one of the sample Python functions from the dropdown menu below, a
19
  """
20
  )
21
 
 
 
22
  # Download the model from the Hugging Face Hub if it doesn't exist
23
  download_model()
24
 
 
7
  )
8
  from huggingface_hub import hf_hub_download
9
  import os
10
+ import torch
11
 
12
  # list_files(os.getcwd())
13
 
 
20
  """
21
  )
22
 
23
+ st.write(f"Has CUDA: {torch.cuda.is_available()}")
24
+
25
  # Download the model from the Hugging Face Hub if it doesn't exist
26
  download_model()
27
 
model.py CHANGED
@@ -104,7 +104,10 @@ class Seq2Seq(nn.Module):
104
  else:
105
  # Predict
106
  preds = []
107
- zero = torch.cuda.LongTensor(1).fill_(0)
 
 
 
108
  for i in range(source_ids.shape[0]):
109
  context = encoder_output[:, i : i + 1]
110
  context_mask = source_mask[i : i + 1, :]
@@ -154,7 +157,10 @@ class Seq2Seq(nn.Module):
154
  class Beam(object):
155
  def __init__(self, size, sos, eos):
156
  self.size = size
157
- self.tt = torch.cuda
 
 
 
158
  # The score for each translation on the beam.
159
  self.scores = self.tt.FloatTensor(size).zero_()
160
  # The backpointers at each time-step.
 
104
  else:
105
  # Predict
106
  preds = []
107
+ try:
108
+ zero = torch.cuda.LongTensor(1).fill_(0)
109
+ except Exception as e:
110
+ zero = torch.LongTensor(1).fill_(0)
111
  for i in range(source_ids.shape[0]):
112
  context = encoder_output[:, i : i + 1]
113
  context_mask = source_mask[i : i + 1, :]
 
157
  class Beam(object):
158
  def __init__(self, size, sos, eos):
159
  self.size = size
160
+ if torch.cuda.is_available():
161
+ self.tt = torch.cuda
162
+ else:
163
+ self.tt = torch
164
  # The score for each translation on the beam.
165
  self.scores = self.tt.FloatTensor(size).zero_()
166
  # The backpointers at each time-step.