olya-const commited on
Commit
88b4df7
1 Parent(s): 7c14be4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -5
app.py CHANGED
@@ -5,9 +5,10 @@ import datasets
5
 
6
  @st.cache
7
  def load_model():
8
- return AutoTokenizer.from_pretrained('distilbert-base-cased'), AutoModelForSequenceClassification.from_pretrained('./')
9
 
10
- tokenizer, model = load_model()
 
11
 
12
  title = st.text_area('Title')
13
  summary = st.text_area('Summary')
@@ -15,14 +16,13 @@ summary = st.text_area('Summary')
15
  label_to_tag = {0: 'Computer science', 1: 'Math', 2: 'Physics',
16
  3: 'Quantum biology', 4: 'Statistic'}
17
 
18
- device= 'cuda' if torch.cuda.is_available() else 'cpu'
19
  def predict(title, summary):
20
  dataset = datasets.Dataset.from_dict({'title': [title],
21
  'summary': [summary.replace("\n", " ")]})
22
  dataset = tokenizer(dataset["title"], dataset['summary'],
23
  padding="max_length", truncation=True, return_tensors='pt')
24
- logits = model(input_ids=dataset['input_ids'].to(device),
25
- attention_mask=dataset['attention_mask'].to(device))['logits']
26
  probs = torch.nn.functional.softmax(logits)[0].cpu().detach()
27
  preds = []
28
  proba = 0.
 
5
 
6
  @st.cache
7
  def load_model():
8
+ return AutoModelForSequenceClassification.from_pretrained('./')
9
 
10
+ tokenizer = AutoTokenizer.from_pretrained('distilbert-base-cased')
11
+ model = load_model()
12
 
13
  title = st.text_area('Title')
14
  summary = st.text_area('Summary')
 
16
  label_to_tag = {0: 'Computer science', 1: 'Math', 2: 'Physics',
17
  3: 'Quantum biology', 4: 'Statistic'}
18
 
 
19
  def predict(title, summary):
20
  dataset = datasets.Dataset.from_dict({'title': [title],
21
  'summary': [summary.replace("\n", " ")]})
22
  dataset = tokenizer(dataset["title"], dataset['summary'],
23
  padding="max_length", truncation=True, return_tensors='pt')
24
+ logits = model(input_ids=dataset['input_ids'],
25
+ attention_mask=dataset['attention_mask'])['logits']
26
  probs = torch.nn.functional.softmax(logits)[0].cpu().detach()
27
  preds = []
28
  proba = 0.