avsolatorio
commited on
Commit
•
09767db
1
Parent(s):
c47be47
Add device
Browse files
README.md
CHANGED
@@ -53,8 +53,9 @@ class WBGDocTopic:
|
|
53 |
each document, along with the mean and standard deviation of the topic classification scores.
|
54 |
"""
|
55 |
|
56 |
-
def __init__(self, classifiers: dict = None
|
57 |
self.classifiers = classifiers or {}
|
|
|
58 |
|
59 |
if classifiers is None:
|
60 |
self.load_classifiers()
|
@@ -71,7 +72,7 @@ class WBGDocTopic:
|
|
71 |
continue
|
72 |
|
73 |
model_name = f"avsolatorio/doc-topic-model_eval-{i:02}_train-{j:02}"
|
74 |
-
classifier = pipeline("text-classification", model=model_name, tokenizer=tokenizer, top_k=None)
|
75 |
|
76 |
self.classifiers[model_name] = classifier
|
77 |
|
@@ -118,7 +119,8 @@ sample_text = """A growing literature attributes gender inequality in labor mark
|
|
118 |
sents = sent_tokenize(inp)
|
119 |
|
120 |
# Create the instance which will load the models.
|
121 |
-
|
|
|
122 |
|
123 |
# Infer the topics and scores
|
124 |
outs = dtopic_model.suggest_topics(sents)
|
|
|
53 |
each document, along with the mean and standard deviation of the topic classification scores.
|
54 |
"""
|
55 |
|
56 |
+
def __init__(self, classifiers: dict = None, device: str = None
|
57 |
self.classifiers = classifiers or {}
|
58 |
+
self.device = device
|
59 |
|
60 |
if classifiers is None:
|
61 |
self.load_classifiers()
|
|
|
72 |
continue
|
73 |
|
74 |
model_name = f"avsolatorio/doc-topic-model_eval-{i:02}_train-{j:02}"
|
75 |
+
classifier = pipeline("text-classification", model=model_name, tokenizer=tokenizer, top_k=None, device=self.device)
|
76 |
|
77 |
self.classifiers[model_name] = classifier
|
78 |
|
|
|
119 |
sents = sent_tokenize(inp)
|
120 |
|
121 |
# Create the instance which will load the models.
|
122 |
+
# Set the device to "cuda" if you want to use a GPU.
|
123 |
+
dtopic_model = WBGDocTopic(device=None)
|
124 |
|
125 |
# Infer the topics and scores
|
126 |
outs = dtopic_model.suggest_topics(sents)
|