# AneriThakkar's picture
# Upload main.py
# 2a0d6cc verified
# raw
# history blame
# No virus
# 2.7 kB
# import torch
import streamlit as st
# import numpy as np
from transformers import T5ForConditionalGeneration, T5Tokenizer
# from transformers import pipeline
from transformers import AutoTokenizer, AutoModelForCausalLM
# Model choices offered in the UI, mapped to
# (model class, tokenizer class, Hugging Face Hub repo id).
_MODEL_REGISTRY = {
    "T5": (T5ForConditionalGeneration, T5Tokenizer, 'google/flan-t5-base'),
    "Llama3": (AutoModelForCausalLM, AutoTokenizer, "meta-llama/Meta-Llama-3-8B"),
    "Llama3-Instruct": (
        AutoModelForCausalLM,
        AutoTokenizer,
        "meta-llama/Meta-Llama-3-8B-Instruct",
    ),
}


@st.cache_resource
def _load_pretrained(model_name):
    """Load the model/tokenizer pair registered under *model_name*, cached.

    ``st.cache_resource`` keeps the loaded objects across script reruns, so
    the weights are read once instead of on every button click. Exceptions
    raised by ``from_pretrained`` (network, auth for gated repos, ...) are
    not cached and propagate to the caller.
    """
    model_cls, tokenizer_cls, repo_id = _MODEL_REGISTRY[model_name]
    model = model_cls.from_pretrained(repo_id)
    tokenizer = tokenizer_cls.from_pretrained(repo_id)
    return model, tokenizer


def load_model(model_name):
    """Return ``(model, tokenizer)`` for one of the supported model names.

    Args:
        model_name: One of ``"T5"``, ``"Llama3"`` or ``"Llama3-Instruct"``.

    Returns:
        A ``(model, tokenizer)`` tuple on success. For an unknown name an
        error is shown in the Streamlit page and ``(None, None)`` is returned.
    """
    if model_name not in _MODEL_REGISTRY:
        # Kept outside the cached helper so the message is emitted on every
        # call rather than depending on cache replay behaviour.
        st.error(f"Model {model_name} not available.")
        return None, None
    return _load_pretrained(model_name)
def generate_question(model, tokenizer, context, max_length=512):
    """Generate a single question about *context* with a transformers model.

    The model is prompted with ``'Generate a question from this: ' + context``.

    Args:
        model: A loaded transformers model exposing ``generate()`` and
            ``config.is_encoder_decoder`` (e.g. the T5 or Llama models
            returned by ``load_model``).
        tokenizer: The tokenizer matching *model*.
        context: Source text the question should be about.
        max_length: Generation length budget. Encoder-decoder models get it
            as ``max_length``, which was the previous behaviour. Decoder-only
            models get it as ``max_new_tokens``, so a long prompt does not
            consume the budget.

    Returns:
        The decoded question text, without special tokens. For decoder-only
        models the echoed prompt is also removed.
    """
    input_text = 'Generate a question from this: ' + context
    encoded = tokenizer(input_text, return_tensors='pt')
    input_ids = encoded['input_ids']
    if model.config.is_encoder_decoder:
        # Seq2seq models (T5) return only decoder tokens: a start token, the
        # question, and usually an end-of-sequence token.
        outputs = model.generate(**encoded, max_length=max_length)
        generated = outputs[0]
    else:
        # Causal LMs (Llama) return prompt + continuation, and their
        # max_length would count the prompt tokens too.
        outputs = model.generate(**encoded, max_new_tokens=max_length)
        generated = outputs[0][len(input_ids[0]):]
    # skip_special_tokens drops pad/BOS/EOS wherever they occur. Blindly
    # slicing off the first and last token would cut a real token whenever
    # generation stopped at the length limit without emitting EOS.
    return tokenizer.decode(generated, skip_special_tokens=True)
def main():
    """Render the question-generation Streamlit page.

    Shows a text area pre-filled with a sample sentence, a model picker and a
    button. On click, it loads the chosen model with ``load_model`` and
    writes the question returned by ``generate_question``.
    """
    st.title("Question Generation From Given Text")
    context = st.text_area("Enter text", "Laughter is the best medicine.")
    st.write("Select a model and provide the text to generate questions.")
    model_choice = st.selectbox("Select a model", ["T5", "Llama3", "Llama3-Instruct"])

    if not st.button("Generate Questions"):
        return
    if not context.strip():
        # With no text the model would only see the bare instruction prefix.
        st.warning("Please enter some text to generate a question from.")
        return

    with st.spinner(f"Loading {model_choice}..."):
        model, tokenizer = load_model(model_choice)
    # load_model signals failure with (None, None). Compare explicitly rather
    # than relying on the truthiness of arbitrary model/tokenizer objects
    # (e.g. objects defining __len__).
    if model is None or tokenizer is None:
        st.error("Model loading failed.")
        return

    with st.spinner("Generating question..."):
        questions = generate_question(model, tokenizer, context)
    st.write("Generated Question:")
    st.write(questions)
if __name__ == '__main__':
    # Render the app only when this file is executed as a script
    # (e.g. via `streamlit run`), not when it is imported as a module.
    main()