SmartPy commited on
Commit
92e1aef
1 Parent(s): 6f958e0

Update textrank.py

Browse files
Files changed (1) hide show
  1. textrank.py +10 -6
textrank.py CHANGED
@@ -1,19 +1,23 @@
1
  import numpy as np
2
  import pandas as pd
3
  import nltk
4
- nltk.download('punkt') # one time execution
5
  import re
6
- import warnings
7
- warnings.filterwarnings('ignore')
8
- from sklearn.metrics.pairwise import cosine_similarity
9
  import networkx as nx
10
  from tqdm import tqdm
11
-
12
  from sentence_transformers import SentenceTransformer
 
 
 
 
 
 
13
 
14
  model = SentenceTransformer('all-mpnet-base-v2')
 
 
15
 
16
- model.to('cuda')
17
  def get_summary(text, num_words: int=1000):
18
  sentences = nltk.sent_tokenize(text)
19
  embeddings = model.encode(sentences, show_progress_bar=False)
 
1
  import numpy as np
2
  import pandas as pd
3
  import nltk
 
4
  import re
5
+
6
+ import torch
 
7
  import networkx as nx
8
  from tqdm import tqdm
 
9
  from sentence_transformers import SentenceTransformer
10
+ from sklearn.metrics.pairwise import cosine_similarity
11
+
12
+ import warnings
13
+ warnings.filterwarnings('ignore')
14
+
15
+ nltk.download('punkt')
16
 
17
  model = SentenceTransformer('all-mpnet-base-v2')
18
+ device = "cuda" if torch.cuda.is_available() else "cpu"
19
+ model.to(device)
20
 
 
21
  def get_summary(text, num_words: int=1000):
22
  sentences = nltk.sent_tokenize(text)
23
  embeddings = model.encode(sentences, show_progress_bar=False)