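"""Streamlit app that serves a fine-tuned GPT-2 model as a simple question-answering bot."""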
import streamlit as st
import tensorflow
from transformers import GPT2Tokenizer, TFGPT2LMHeadModel
# Load the fine-tuned model
model = TFGPT2LMHeadModel.from_pretrained("fine-tuned-gpt2")
# Initialize the tokenizer
tokenizer = GPT2Tokenizer.from_pretrained("gpt2-medium")
# Set the maximum length for the generated text
max_length = 100


def generate_answer(prompt):
    # Encode the prompt
    input_ids = tokenizer.encode(prompt, return_tensors="tf")
    # Generate text using the model
    output = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        num_return_sequences=1,
        do_sample=True,
        temperature=0.7,
        # GPT-2 has no pad token; reuse EOS so generate() does not warn about padding
        pad_token_id=tokenizer.eos_token_id,
    )
    # Decode the generated text, dropping special tokens such as <|endoftext|>
    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
    return generated_text


def main():
    st.markdown("<h1 style='text-align: center;'>Agrinuture Bot</h1>", unsafe_allow_html=True)
    # Get user input
    st.markdown("<p style='text-align: center;'>Enter your question</p>", unsafe_allow_html=True)
    prompt = st.text_input("")
    # Center the button with a small CSS override
    st.markdown(
        """
        <style>
        .stButton > button {
            display: block;
            margin: 0 auto;
        }
        </style>
        """,
        unsafe_allow_html=True,
    )
    # Generate answer on button click
    if st.button("Generate Answer"):
        answer = generate_answer(prompt)
        st.markdown("<h3 style='text-align: center;'>Generated Answer:</h3>", unsafe_allow_html=True)
        st.text(answer)


if __name__ == '__main__':
    main()
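
# To try the app locally (assuming Streamlit is installed and the "fine-tuned-gpt2"
# checkpoint directory is available next to this file):
#   streamlit run app.py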