Syrinx commited on
Commit
ec83c3c
1 Parent(s): a478e11

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -1
app.py CHANGED
@@ -22,12 +22,17 @@ def main():
22
  description = generate_description(title)
23
  st.success(description)
24
 
 
 
 
 
 
25
 
26
  # Define the function that generates the description
27
  def generate_description(title):
28
  # Preprocess the input
29
  input_text = f"{title}"
30
- input_ids = tokenizer.encode(input_text, return_tensors='pt')
31
 
32
  # Generate the output using the model
33
  output = model.generate(
 
22
  description = generate_description(title)
23
  st.success(description)
24
 
25
+ # Check if GPU is available
26
+ if torch.cuda.is_available():
27
+ device = torch.device("cuda")
28
+ else:
29
+ device = torch.device("cpu")
30
 
31
  # Define the function that generates the description
32
  def generate_description(title):
33
  # Preprocess the input
34
  input_text = f"{title}"
35
+ input_ids = tokenizer.encode(input_text, return_tensors='pt').to(device)
36
 
37
  # Generate the output using the model
38
  output = model.generate(