qc7 committed on
Commit
571d116
1 Parent(s): dcbf77d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -2
app.py CHANGED
@@ -2,10 +2,11 @@ import streamlit as st
2
  import numpy as np
3
  import pandas as pd
4
 
 
5
  import transformers
6
  from transformers import TextClassificationPipeline, AutoTokenizer, AutoModelForSequenceClassification
7
 
8
- @st.cache(suppress_st_warning=True, hash_funcs={transformers.AutoTokenizer: lambda _: None})
9
  def load_tok_and_model():
10
  tokenizer = AutoTokenizer.from_pretrained('distilbert-base-cased')
11
  model = AutoModelForSequenceClassification.from_pretrained(".")
@@ -16,7 +17,7 @@ CATEGORIES = ["Computer Science", "Economics", "Electrical Engineering", "Mathem
16
  "Q. Biology", "Q. Finances", "Statistics" , "Physics"]
17
 
18
 
19
- @st.cache(suppress_st_warning=True, hash_funcs={transformers.AutoTokenizer: lambda _: None})
20
  def forward_pass(title, abstract, tokenizer, model):
21
  title_tensor = torch.tensor(tokenizer(title, padding="max_length", truncation=True, max_length=32)['input_ids'])
22
  abstract_tensor = torch.tensor(tokenizer(abstract, padding="max_length", truncation=True, max_length=480)['input_ids'])
 
2
  import numpy as np
3
  import pandas as pd
4
 
5
+ import tokenizers # for streamlit caching
6
  import transformers
7
  from transformers import TextClassificationPipeline, AutoTokenizer, AutoModelForSequenceClassification
8
 
9
+ @st.cache(suppress_st_warning=True, hash_funcs={tokenizers.Tokenizer: lambda _: None})
10
  def load_tok_and_model():
11
  tokenizer = AutoTokenizer.from_pretrained('distilbert-base-cased')
12
  model = AutoModelForSequenceClassification.from_pretrained(".")
 
17
  "Q. Biology", "Q. Finances", "Statistics" , "Physics"]
18
 
19
 
20
+ @st.cache(suppress_st_warning=True, hash_funcs={tokenizers.Tokenizer: lambda _: None})
21
  def forward_pass(title, abstract, tokenizer, model):
22
  title_tensor = torch.tensor(tokenizer(title, padding="max_length", truncation=True, max_length=32)['input_ids'])
23
  abstract_tensor = torch.tensor(tokenizer(abstract, padding="max_length", truncation=True, max_length=480)['input_ids'])