# laba2.2 / app.py
import streamlit as st
from transformers import AutoTokenizer, DistilBertForSequenceClassification
import torch
from torch.nn.functional import softmax

# Cache the tag vocabulary so tag.txt is parsed only once per session.
# (st.cache is deprecated in current Streamlit; st.cache_data replaces it here.)
@st.cache_data
def load_tags_info():
    tag_id = {}
    id_tag = {}
    with open('tag.txt', 'r') as file:
        for i, line in enumerate(file):
            # Drop the leading character, then take everything up to the
            # first comma as the tag name.
            tag = line[1:].split(',')[0]
            tag_id[tag] = i
            id_tag[i] = tag
    # Reserve id 150 for the "no tag" class.
    tag_id['None'] = 150
    id_tag[150] = 'None'
    return tag_id, id_tag

tag_id, id_tag = load_tags_info()
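
# Illustration only (the real tag.txt is not shown in this repo view): a line
# such as
#   "cs.LG,Machine Learning"
# would be parsed as tag 'cs.LG', since line[1:] drops the leading quote and
# split(',')[0] keeps everything before the first comma.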

# The model and tokenizer are heavyweight, non-serializable objects, so
# cache them as resources rather than data.
@st.cache_resource
def load_tokenizer():
    return AutoTokenizer.from_pretrained('./')

@st.cache_resource
def load_model():
    return DistilBertForSequenceClassification.from_pretrained('./')
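
# Note: from_pretrained('./') resolves to the Space's repo root, so the
# fine-tuned weights and tokenizer files (config.json, model weights, vocab)
# are expected to sit alongside app.py.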

def top_xx(preds, xx=95):
    """Return the most likely tags whose cumulative probability reaches
    xx percent (a nucleus-style cutoff), skipping the 'None' class."""
    tops = torch.argsort(preds, 1, descending=True)
    total = 0.0
    index = 0
    result = []
    # Stop once xx% of the probability mass is covered; the bounds check
    # prevents an overrun when the skipped 'None' class holds the rest.
    while total < xx / 100 and index < tops.size(1):
        next_id = tops[0, index].item()
        index += 1
        if next_id == 150:  # reserved 'None' class
            continue
        total += preds[0, next_id].item()
        result.append(id_tag[next_id])
    return result
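
# Worked example with made-up numbers: if preds were
# [[0.50, 0.30, 0.15, 0.05]] and xx=80, the loop takes ids 0 and 1
# (0.50 + 0.30 = 0.80 >= 0.80) and returns their two tag names.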

# Instantiate (or fetch from cache) the model and tokenizer.
model = load_model()
tokenizer = load_tokenizer()

title = st.text_input(label='Title')  # a single-line field suits a title
abstract = st.text_area(label='Abstract (optional)', height=200)

# Only run inference once the user has entered at least a title.
if title:
    st.caption('Generation:')
    prompt = f'Title: {title} Abstract: {abstract}'
    # return_tensors='pt' already yields a (1, seq_len) batch of ids,
    # so they can be fed to the model directly.
    tokens = tokenizer(prompt, truncation=True, padding='max_length',
                       return_tensors='pt')['input_ids']
    with torch.no_grad():  # inference only; no gradients needed
        preds = softmax(model(tokens).logits, dim=1)
    tags = top_xx(preds)
    st.header('Inferred tags:')
    for tag in tags:
        st.caption(tag)
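
# To try the app locally (standard Streamlit invocation):
#   streamlit run app.py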