Atharva commited on
Commit
f28ff95
1 Parent(s): 2aa1d38

fixed torch device

Browse files
Files changed (3) hide show
  1. README.md +1 -0
  2. app.py +1 -2
  3. requirements.txt +3 -2
README.md CHANGED
@@ -6,6 +6,7 @@ colorTo: purple
6
  sdk: streamlit
7
  sdk_version: 1.2.0
8
  app_file: app.py
 
9
  pinned: false
10
  license: mit
11
  ---
 
6
  sdk: streamlit
7
  sdk_version: 1.2.0
8
  app_file: app.py
9
+ models: ["dslim/bert-base-NER"]
10
  pinned: false
11
  license: mit
12
  ---
app.py CHANGED
@@ -5,7 +5,6 @@ from transformers import AutoModelForTokenClassification, AutoTokenizer, pipelin
5
 
6
  from src import GBRT, wikipedia_search, google_search
7
 
8
-
9
  TYPE = {
10
  'LOC': ' location',
11
  'PER': ' person',
@@ -31,7 +30,7 @@ def load_models():
31
  tokenizer = AutoTokenizer.from_pretrained("dslim/bert-base-NER")
32
  bert_ner = AutoModelForTokenClassification.from_pretrained("dslim/bert-base-NER")
33
  tagger = pipeline("token-classification", model=bert_ner, tokenizer=tokenizer,
34
- device=0, aggregation_strategy="average")
35
  # NED
36
  model = GBRT()
37
  return model, tagger
 
5
 
6
  from src import GBRT, wikipedia_search, google_search
7
 
 
8
  TYPE = {
9
  'LOC': ' location',
10
  'PER': ' person',
 
30
  tokenizer = AutoTokenizer.from_pretrained("dslim/bert-base-NER")
31
  bert_ner = AutoModelForTokenClassification.from_pretrained("dslim/bert-base-NER")
32
  tagger = pipeline("token-classification", model=bert_ner, tokenizer=tokenizer,
33
+ device=-1, aggregation_strategy="average")
34
  # NED
35
  model = GBRT()
36
  return model, tagger
requirements.txt CHANGED
@@ -2,5 +2,6 @@ nltk
2
  numpy
3
  pandas
4
  wikipedia2vec
5
- torch
6
- transformers
 
 
2
  numpy
3
  pandas
4
  wikipedia2vec
5
+ transformers
6
+ --extra-index-url https://download.pytorch.org/whl/cpu
7
+ torch