# CelebChat / utils.py
# Repository page header (not code): lhzstar · new commits · 31bd8d7 ·
# raw · history · blame · contribute · delete · No virus · 3.26 kB
import re
import spacy
import json
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM, AutoModel
from unlimiformer import Unlimiformer, UnlimiformerArguments
import streamlit as st
from urllib.request import Request, urlopen, HTTPError
from bs4 import BeautifulSoup
def hide_footer():
    """Hide the page's ``<footer>`` element by injecting a CSS rule.

    The rule is sent through ``st.markdown`` as raw HTML.
    """
    # Joined with newlines this yields the same markup as a multi-line
    # literal, including the leading and trailing newline.
    css_lines = ["", "<style>", "footer {visibility: hidden;}", "</style>", ""]
    footer_css = "\n".join(css_lines)
    # unsafe_allow_html lets the raw <style> tag through instead of escaping it.
    st.markdown(footer_css, unsafe_allow_html=True)
@st.cache_resource
def get_seq2seq_model(model_id, use_unlimiformer=True, _tokenizer=None):
    """Load a seq2seq model, optionally converted with Unlimiformer.

    Args:
        model_id: Identifier passed to ``AutoModelForSeq2SeqLM.from_pretrained``.
        use_unlimiformer: When true, the loaded model is passed through
            ``Unlimiformer.convert_model``; otherwise it is returned as loaded.
        _tokenizer: Tokenizer forwarded to Unlimiformer as ``tokenizer``.
            The leading underscore tells Streamlit's cache not to hash
            this argument.

    Returns:
        The plain model, or whatever ``Unlimiformer.convert_model`` returns.
    """
    base_model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
    if not use_unlimiformer:
        return base_model

    # Every conversion option except the tokenizer comes from the library's
    # default argument object.
    cfg = UnlimiformerArguments()
    conversion_options = dict(
        layer_begin=cfg.layer_begin,
        layer_end=cfg.layer_end,
        unlimiformer_head_num=cfg.unlimiformer_head_num,
        exclude_attention=cfg.unlimiformer_exclude,
        chunk_overlap=cfg.unlimiformer_chunk_overlap,
        model_encoder_max_len=cfg.unlimiformer_chunk_size,
        verbose=cfg.unlimiformer_verbose,
        tokenizer=_tokenizer,
        unlimiformer_training=cfg.unlimiformer_training,
        use_datastore=cfg.use_datastore,
        flat_index=cfg.flat_index,
        test_datastore=cfg.test_datastore,
        reconstruct_embeddings=cfg.reconstruct_embeddings,
        gpu_datastore=cfg.gpu_datastore,
        gpu_index=cfg.gpu_index,
    )
    return Unlimiformer.convert_model(base_model, **conversion_options)
@st.cache_resource
def get_causal_model(model_id):
    """Load a causal language model for ``model_id``, cached by Streamlit.

    NOTE(review): ``trust_remote_code=True`` is passed to ``from_pretrained``;
    confirm that every ``model_id`` used here comes from a trusted repository.
    """
    causal_model = AutoModelForCausalLM.from_pretrained(
        model_id,
        trust_remote_code=True,
    )
    return causal_model
@st.cache_resource
def get_auto_model(model_id):
    """Load a model via ``AutoModel.from_pretrained``, cached by Streamlit."""
    loaded_model = AutoModel.from_pretrained(model_id)
    return loaded_model
@st.cache_resource
def get_tokenizer(model_id):
    """Load the tokenizer for ``model_id``, cached by Streamlit."""
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    return tokenizer
@st.cache_data
def get_celeb_data(fpath):
    """Read the UTF-8 JSON file at ``fpath`` and return the parsed content.

    The result is cached with ``st.cache_data``.
    """
    with open(fpath, encoding='UTF-8') as source:
        celeb_data = json.load(source)
    return celeb_data
def get_article(url, timeout=30):
    """Download a page and return the text of its ``topic-paragraph`` blocks.

    Args:
        url: Address of the page to fetch.
        timeout: Seconds to wait on a blocking network operation before
            giving up, so an unresponsive server cannot hang the app.

    Returns:
        The text of every ``<p class="topic-paragraph">`` element, split on
        spaces, with empty pieces dropped, re-joined by single spaces.
        Returns ``""`` (after showing a notice through ``st.markdown``)
        when fetching or parsing fails.
    """
    req = Request(
        url=url,
        headers={'User-Agent': 'Mozilla/5.0'},
    )
    try:
        # The context manager closes the HTTP response even if read() fails.
        with urlopen(req, timeout=timeout) as response:
            html = response.read()
        soup = BeautifulSoup(html, features="html.parser")
        # Remove script and style elements before extracting text.
        for tag in soup(["script", "style"]):
            tag.extract()
        lines = [
            para.get_text().strip()
            for para in soup.find_all("p", class_='topic-paragraph')
        ]
        # Split each paragraph on spaces and trim the pieces ...
        chunks = (phrase.strip() for line in lines for phrase in line.split(" "))
        # ... then drop the empty ones and re-join with single spaces.
        return ' '.join(chunk for chunk in chunks if chunk)
    except Exception:
        # Deliberately best-effort: any failure yields an empty article.
        # Unlike a bare ``except:``, this lets KeyboardInterrupt and
        # SystemExit propagate.
        st.markdown("The internet is not stable.")
        return ""
@st.cache_resource
def get_spacy_model(model_id):
    """Load the spaCy pipeline named ``model_id``, cached by Streamlit."""
    pipeline = spacy.load(model_id)
    return pipeline
def preprocess_text(name, text:str, model_id):
    """Split ``text`` into stripped sentences using a spaCy pipeline.

    Args:
        name: Not used by this function; kept in the signature for callers.
        text: Raw text to segment.
        model_id: spaCy model name handed to ``get_spacy_model``.

    Returns:
        A ``(pipeline, sentences)`` tuple: the loaded spaCy pipeline and a
        list with the whitespace-stripped text of each detected sentence.
    """
    nlp = get_spacy_model(model_id)
    doc = nlp(text)
    sentences = []
    for sentence in doc.sents:
        sentences.append(sentence.text.strip())
    return nlp, sentences