wing-nus commited on
Commit
fcc0878
1 Parent(s): 96bfa0d

automatically choose cpu, gpu settings based on running environment.

Files changed (1) hide show
  1. controlled_summarization.py +9 -5
controlled_summarization.py CHANGED
@@ -4,17 +4,21 @@ from SciAssist import Summarization
4
  import os
5
  import requests
6
  from datasets import load_dataset
 
7
  print(f"Is CUDA available: {torch.cuda.is_available()}")
8
  # True
9
- print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
10
-
11
- acl_data = load_dataset("dyxohjl666/CocoScisum_ACL", revision="refs/convert/parquet")
12
- device = "gpu" if torch.cuda.is_available() else "cpu"
 
 
 
13
 
14
- ctrlsum_pipeline = Summarization(os_name="nt",model_name="flan-t5-xl",checkpoint="dyxohjl666/flant5-xl-cocoscisum",device=device)
15
 
16
  acl_dict = {}
17
  recommended_kw = {}
 
18
 
19
 
20
  def convert_to_dict(data):
 
4
  import os
5
  import requests
6
  from datasets import load_dataset
7
+
8
  print(f"Is CUDA available: {torch.cuda.is_available()}")
9
  # True
10
+ if torch.cuda.is_available():
11
+ print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
12
+ device = 'gpu'
13
+ ctrlsum_pipeline = Summarization(os_name="nt",model_name="flan-t5-xl",checkpoint="dyxohjl666/flant5-xl-cocoscisum",device=device)
14
+ else:
15
+ device = 'cpu'
16
+ ctrlsum_pipeline = Summarization(os_name="nt",device=device)
17
 
 
18
 
19
  acl_dict = {}
20
  recommended_kw = {}
21
+ acl_data = load_dataset("dyxohjl666/CocoScisum_ACL", revision="refs/convert/parquet")
22
 
23
 
24
  def convert_to_dict(data):