avsolatorio commited on
Commit
09767db
1 Parent(s): c47be47

Add device

Browse files
Files changed (1) hide show
  1. README.md +5 -3
README.md CHANGED
@@ -53,8 +53,9 @@ class WBGDocTopic:
53
  each document, along with the mean and standard deviation of the topic classification scores.
54
  """
55
 
56
- def __init__(self, classifiers: dict = None):
57
  self.classifiers = classifiers or {}
 
58
 
59
  if classifiers is None:
60
  self.load_classifiers()
@@ -71,7 +72,7 @@ class WBGDocTopic:
71
  continue
72
 
73
  model_name = f"avsolatorio/doc-topic-model_eval-{i:02}_train-{j:02}"
74
- classifier = pipeline("text-classification", model=model_name, tokenizer=tokenizer, top_k=None)
75
 
76
  self.classifiers[model_name] = classifier
77
 
@@ -118,7 +119,8 @@ sample_text = """A growing literature attributes gender inequality in labor mark
118
  sents = sent_tokenize(inp)
119
 
120
  # Create the instance which will load the models.
121
- dtopic_model = WBGDocTopic()
 
122
 
123
  # Infer the topics and scores
124
  outs = dtopic_model.suggest_topics(sents)
 
53
  each document, along with the mean and standard deviation of the topic classification scores.
54
  """
55
 
56
+ def __init__(self, classifiers: dict = None, device: str = None
57
  self.classifiers = classifiers or {}
58
+ self.device = device
59
 
60
  if classifiers is None:
61
  self.load_classifiers()
 
72
  continue
73
 
74
  model_name = f"avsolatorio/doc-topic-model_eval-{i:02}_train-{j:02}"
75
+ classifier = pipeline("text-classification", model=model_name, tokenizer=tokenizer, top_k=None, device=self.device)
76
 
77
  self.classifiers[model_name] = classifier
78
 
 
119
  sents = sent_tokenize(inp)
120
 
121
  # Create the instance which will load the models.
122
+ # Set the device to "cuda" if you want to use a GPU.
123
+ dtopic_model = WBGDocTopic(device=None)
124
 
125
  # Infer the topics and scores
126
  outs = dtopic_model.suggest_topics(sents)