blunt-octopus commited on
Commit
35556d6
1 Parent(s): e187312

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -32,7 +32,7 @@ tag_to_id, id_to_tag, id_to_description = load_tags_info()
32
 
33
  @st.cache
34
  def load_model():
35
- return DistilBertForSequenceClassification.from_pretrained('./').to('cuda')
36
 
37
  def load_tokenizer():
38
  return AutoTokenizer.from_pretrained(base_model_name)
@@ -68,7 +68,7 @@ with st.form("Taxonomizer"):
68
  if submitted:
69
  prompt = 'Title: ' + title + ' Abstract: ' + abstract
70
  tokens = tokenizer(prompt, truncation=True, padding='max_length', return_tensors='pt')['input_ids']
71
- preds = softmax(model(tokens.reshape(1, -1).to('cuda')).logits / temperature, dim=1)
72
  tags = top_xx(preds, xx)
73
  other_tags = []
74
  st.header('Inferred tags:')
 
32
 
33
  @st.cache
34
  def load_model():
35
+ return DistilBertForSequenceClassification.from_pretrained('./')
36
 
37
  def load_tokenizer():
38
  return AutoTokenizer.from_pretrained(base_model_name)
 
68
  if submitted:
69
  prompt = 'Title: ' + title + ' Abstract: ' + abstract
70
  tokens = tokenizer(prompt, truncation=True, padding='max_length', return_tensors='pt')['input_ids']
71
+ preds = softmax(model(tokens.reshape(1, -1)).logits / temperature, dim=1)
72
  tags = top_xx(preds, xx)
73
  other_tags = []
74
  st.header('Inferred tags:')