PLTS / app.py
Linseypass's picture
Update app.py
e77c4bf
raw
history blame contribute delete
No virus
5.15 kB
import gradio as gr
from nltk.tokenize import sent_tokenize
import torch
import ujson as json
from transformers import AutoModelForCausalLM,LlamaTokenizer,BitsAndBytesConfig
from peft import PeftModel
from keybert import KeyBERT
from keyphrase_vectorizers import KeyphraseCountVectorizer
import nltk
# Sentence tokenizer data used by sent_tokenize().
nltk.download('punkt')

# Base LLaMA 7B checkpoint plus the Guanaco adapter that is applied on top.
# Loading happens once at import time and takes a few minutes.
model_name = "decapoda-research/llama-7b-hf"
adapters_name = 'timdettmers/guanaco-7b'

# Load the base weights, attach the adapter, then merge it so that `m` ends
# up being the object returned by merge_and_unload().
m = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
m = PeftModel.from_pretrained(m, adapters_name)
m = m.merge_and_unload()

# The tokenizer comes from the base checkpoint; the BOS id is pinned to 1.
tok = LlamaTokenizer.from_pretrained(model_name)
tok.bos_token_id = 1
stop_token_ids = [0]

print('Guanaco model loaded into memory.')
def _strip_abstract_heading(text):
    """Remove a leading "abstract." or "abstract" heading, ignoring case."""
    lowered = text.lower()
    # Check the longer variant first so the trailing period is removed too.
    for heading in ("abstract.", "abstract"):
        if lowered.startswith(heading):
            return text[len(heading):]
    return text


def _build_context(title, sentences):
    """Join the title with up to the first two abstract sentences."""
    if not sentences:
        return title
    return title + ". " + " ".join(sentences[:2])


def _chat_prompt(question):
    """Wrap a question in the human/assistant chat template."""
    return (
        # The two preamble sentences are separate fragments; the trailing
        # space keeps them from running together as "assistant.The".
        "A chat between a curious human and an artificial intelligence assistant. "
        "The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
        f"### Human: {question} \n"
        "### Assistant:"
    )


def _ask_model(question):
    """Run one chat turn through the module-level model; return the decoded text."""
    inputs = tok(_chat_prompt(question), return_tensors="pt")
    outputs = m.generate(inputs=inputs.input_ids, max_new_tokens=300)
    return tok.decode(outputs[0], skip_special_tokens=True)


def _extract_reply(decoded, skip_opener=None):
    """Return the assistant's reply from decoded text, cut at its last period.

    The reply is whatever follows the first "### Assistant:" marker. An
    optional opener (e.g. "Certainly!") is dropped from the start. If the
    reply contains no period it is returned untrimmed.
    """
    marker = "### Assistant:"
    start = decoded.find(marker)
    # find() returns -1 when the marker is absent; slicing from -1 + offset
    # would start at an arbitrary position, so fall back to the whole text.
    reply = decoded if start == -1 else decoded[start + len(marker):]
    reply = reply.lstrip()
    if skip_opener and reply.startswith(skip_opener):
        reply = reply[len(skip_opener):].lstrip()
    # Search for the period inside the reply only; the prompt itself contains
    # periods, so searching the full text could select one before the reply.
    last_period = reply.rfind(".")
    if last_period != -1:
        reply = reply[:last_period + 1]
    return reply


def _extract_keyphrases(doc, top_n=2):
    """Return up to top_n keyphrases of doc, dropping any contained in another."""
    kw_model = KeyBERT(model="all-MiniLM-L6-v2")
    vectorizer = KeyphraseCountVectorizer()
    keywords = kw_model.extract_keywords(
        doc, stop_words="english", top_n=top_n, vectorizer=vectorizer, use_mmr=True
    )
    # Fewer than top_n phrases (or none) may come back, so iterate over what
    # was actually returned instead of indexing range(top_n).
    phrases = [entry[0] for entry in keywords]
    return [
        phrase
        for i, phrase in enumerate(phrases)
        if not any(phrase in other for j, other in enumerate(phrases) if j != i)
    ]


def generate(title, abstract):
    """Produce keyphrases, a keyphrase elaboration, and a plain-language summary.

    The summary prompt is built from the title and the first two sentences of
    the abstract. Keyphrases are extracted from that same text, and a second
    prompt asks the model why those keyphrases are studied. When the abstract
    has fewer than two sentences, or no keyphrase is found, the keyphrase
    outputs are empty strings.

    Args:
        title: Paper title.
        abstract: Paper abstract; a leading "abstract"/"abstract." heading is
            removed before sentence splitting.

    Returns:
        A tuple (keyword_string, responseTwo, response): the comma-separated
        keyphrases, the model's elaboration on them, and the plain-language
        summary.
    """
    print("Started running.")
    title = title or ""
    sentences = sent_tokenize(_strip_abstract_heading(abstract or ""))
    print("Tokenized abstract to sentences.")

    # Plain-language summary.
    context = _build_context(title, sentences)
    summary_question = (
        "\nCan you explain the main idea of what is being studied in the following"
        " paragraph for someone who is not familiar with the topic."
        " Comment on areas of application.:\n"
        + context
    )
    response = _extract_reply(_ask_model(summary_question), skip_opener="Certainly!")
    print('Plain Language Summary Created.')

    # Keyphrase extraction and elaboration need at least two abstract
    # sentences; with less text the step is skipped entirely.
    keyword_string = ""
    responseTwo = ""
    if len(sentences) > 1:
        my_keywords = _extract_keyphrases(context)
        for entry in my_keywords:
            print(entry)
        if my_keywords:
            keyword_string = ', '.join(my_keywords)
            keyword_question = (
                "What is the purpose of studying "
                + keyword_string
                + "? Comment on areas of application."
            )
            responseTwo = _extract_reply(_ask_model(keyword_question))
            print('Keyphrase elaboration ran.')

    return keyword_string, responseTwo, response
# Two text inputs feed generate(); the three output boxes are listed in the
# same order as generate()'s return values.
demo = gr.Interface(
    fn=generate,
    inputs=[gr.Textbox(label=caption) for caption in ("Title", "Abstract")],
    outputs=[
        gr.Textbox(label=caption)
        for caption in ("Keyphrases", "Keyphrase Elaboration", "Plain Language Summary")
    ],
)
demo.launch()