# qg_generation / app.py
# Last commit 8df9ec0 by dhmeltzer: "Update app.py" (file size 4.56 kB).
import numpy as np
import requests
import streamlit as st
import openai
import json
# Markdown shown under the title. Kept as one constant so main() stays readable.
_DESCRIPTION = (
    "This application is designed to generate a question given a piece of scientific text. "
    "We include the output from four different models, the "
    "[BART-Large](https://huggingface.co/dhmeltzer/bart-large_askscience-qg) and "
    "[FLAN-T5-Base](https://huggingface.co/dhmeltzer/flan-t5-base_askscience-qg) models "
    "fine-tuned on the r/AskScience split of the "
    "[ELI5 dataset](https://huggingface.co/datasets/eli5) as well as the zero-shot output "
    "of the [FLAN-T5-XXL](https://huggingface.co/google/flan-t5-xxl) model and the "
    "[GPT-3.5-turbo](https://platform.openai.com/docs/models/gpt-3-5) model."
    "\n\n"
    "For a more thorough discussion of question generation see this "
    "[report](https://wandb.ai/dmeltzer/Question_Generation/reports/Exploratory-Data-Analysis-for-r-AskScience--Vmlldzo0MjQwODg1?accessToken=fndbu2ar26mlbzqdphvb819847qqth2bxyi4hqhugbnv97607mj01qc7ed35v6w8) "
    "for EDA on the r/AskScience dataset and this "
    "[report](https://api.wandb.ai/links/dmeltzer/7an677es) for details on our training procedure."
    "\n\n"
    "The two fine-tuned models (BART-Large and FLAN-T5-Base) are hosted on AWS using a "
    "combination of AWS Sagemaker, Lambda, and API gateway. "
    "GPT-3.5 is called using the OpenAI API and the FLAN-T5-XXL model is hosted by "
    "HuggingFace and is called with their Inference API."
    "\n\n"
    "**Disclaimer**: You may receive an error message when calling the FLAN-T5-XXL model "
    "since the Inference API takes around 20 seconds to load the model."
)

# Fine-tuned models served behind AWS API Gateway endpoints (display name -> URL).
_AWS_CHECKPOINTS = {
    'BART-Large': 'https://8hlnvys7bh.execute-api.us-east-1.amazonaws.com/beta/',
    'FLAN-T5-Base': 'https://gnrxh05827.execute-api.us-east-1.amazonaws.com/beta/',
}

# Models queried through the Hugging Face Inference API. Right now this is just
# FLAN-T5-XXL, but more checkpoints may be added later.
_HF_CHECKPOINTS = ['google/flan-t5-xxl']

# Name of the Streamlit secret holding the AWS API Gateway key (sent as x-api-key).
# NOTE(review): the original referenced an undefined `key`; this secret name is an
# assumption chosen to match 'HF_token'/'OpenAI_token'. Confirm against the
# deployment's secrets configuration.
_AWS_KEY_SECRET = 'AWS_token'

# Instruction prefix used for instruction-following models (GPT-3.5, FLAN).
_QG_PREFIX = "generate a question: "


def _extract_generated_text(payload):
    """Return ``payload[0]['generated_text']``, or ``None`` if it is not present.

    Successful model responses are indexed this way. Error responses (for
    example a JSON object instead of a list) raise on the lookup and map to
    ``None``.
    """
    try:
        return payload[0]['generated_text']
    except (KeyError, IndexError, TypeError):
        return None


def _query_hf(checkpoint, payload, headers):
    """POST ``payload`` to the Hugging Face Inference API for ``checkpoint``.

    Returns the decoded JSON body unchanged. The caller must check whether it is
    a generation result or an error object.
    """
    api_url = f"https://api-inference.huggingface.co/models/{checkpoint}"
    response = requests.post(api_url, headers=headers, json=payload)
    return response.json()


def _write_aws_outputs(user_input, api_key):
    """Query each model in ``_AWS_CHECKPOINTS`` and write its generated question.

    If a response lacks ``[0]['generated_text']``, it is written raw and the
    remaining models are still queried.
    """
    headers = {'x-api-key': api_key}
    input_data = json.dumps({'inputs': user_input})
    for name, url in _AWS_CHECKPOINTS.items():
        # NOTE(review): GET with a JSON body, exactly as the original does. Keep this
        # unless the API Gateway method is switched to POST.
        response = requests.get(url, data=input_data, headers=headers)
        payload = response.json()
        output = _extract_generated_text(payload)
        if output is None:
            st.write(f'**{name}** returned an unexpected response:', payload)
            continue
        st.write(f'**{name}**: {output}')


def _write_openai_output(user_input):
    """Ask GPT-3.5-turbo (legacy ``openai.ChatCompletion`` API) for a question."""
    model_engine = "gpt-3.5-turbo"
    # Max tokens to produce; a single question needs few tokens.
    max_tokens = 50
    response = openai.ChatCompletion.create(
        model=model_engine,
        max_tokens=max_tokens,
        messages=[
            # The system message tells GPT-3.5 to generate questions from text.
            {"role": "system", "content": "You are a helpful assistant that generates questions from text."},
            # Prompt with an explicit instruction, as for the FLAN models.
            {"role": "user", "content": _QG_PREFIX + user_input},
        ],
    )
    output = response['choices'][0]['message']['content']
    st.write(f'**{model_engine}**: {output}')


def _write_hf_outputs(user_input, headers):
    """Query each model in ``_HF_CHECKPOINTS`` and write its generated question.

    If a response cannot be parsed (for example while the model is still
    loading), it is written raw instead.
    """
    for checkpoint in _HF_CHECKPOINTS:
        # 'org/name' -> 'name'. Also tolerates checkpoints with no org prefix.
        model_name = checkpoint.rsplit('/', 1)[-1]
        # FLAN models need the task instruction given explicitly.
        if 'flan' in model_name.lower():
            prompt = _QG_PREFIX + user_input
        else:
            prompt = user_input
        payload = _query_hf(
            checkpoint,
            {
                "inputs": prompt,
                # The Inference API reads wait_for_model from the "options" object,
                # not from the top level of the request.
                "options": {"wait_for_model": True},
            },
            headers,
        )
        output = _extract_generated_text(payload)
        if output is None:
            st.write(f'**{model_name}** returned an unexpected response:', payload)
            continue
        st.write(f'**{model_name}**: {output}')


def main():
    """Render the question-generation demo and show one question per model.

    Writes the title and description, reads API credentials from
    ``st.secrets``, takes a passage from a text area, and then queries:
    the fine-tuned models behind AWS, GPT-3.5 via OpenAI, and the Hugging Face
    Inference API models, in that order.
    """
    st.title("Scientific Question Generation")
    st.write(_DESCRIPTION)

    # Token for the Hugging Face Inference API.
    hf_headers = {"Authorization": f"Bearer {st.secrets['HF_token']}"}
    # Token for the OpenAI API (the legacy client reads this module attribute).
    openai.api_key = st.secrets['OpenAI_token']
    aws_api_key = st.secrets[_AWS_KEY_SECRET]

    user_input = st.text_area(
        "Question Generator",
        """Black holes are the most gravitationally dense objects in the universe.""",
    )
    # Nothing to generate from: skip all (paid/slow) model calls.
    if not user_input.strip():
        return

    _write_aws_outputs(user_input, aws_api_key)
    _write_openai_output(user_input)
    _write_hf_outputs(user_input, hf_headers)
# Run the app when this file is executed as the main script (e.g. via `streamlit run`).
if __name__ == "__main__":
    main()