# Space status (scraped page header): Runtime error
import streamlit as st
import tensorflow  # noqa: F401 -- TensorFlow must be installed for TFGPT2LMHeadModel
from transformers import GPT2Tokenizer, TFGPT2LMHeadModel

# Local directory (or hub id) holding the fine-tuned weights.
MODEL_PATH = "fine-tuned-gpt2"
# NOTE(review): the tokenizer comes from the stock "gpt2-medium" checkpoint rather
# than from MODEL_PATH. This relies on the fine-tuned model sharing GPT-2's BPE
# vocabulary; confirm against how the model was trained and saved.
TOKENIZER_NAME = "gpt2-medium"


@st.cache_resource
def _load_model():
    """Load the fine-tuned TF GPT-2 model once per server process.

    Streamlit re-executes this script on every widget interaction. Without
    caching, the weights would be reloaded from disk on every button click.
    (``st.cache_resource`` requires Streamlit >= 1.18.)
    """
    return TFGPT2LMHeadModel.from_pretrained(MODEL_PATH)


@st.cache_resource
def _load_tokenizer():
    """Load the GPT-2 tokenizer once per server process (see ``_load_model``)."""
    return GPT2Tokenizer.from_pretrained(TOKENIZER_NAME)


# Module-level handles read by generate_answer().
model = _load_model()
tokenizer = _load_tokenizer()
# Maximum total sequence length (prompt tokens + generated tokens) for generate().
max_length = 100
def generate_answer(prompt):
    """Generate a sampled continuation of *prompt* with the fine-tuned GPT-2 model.

    Uses the module-level ``model``, ``tokenizer`` and ``max_length``.

    Args:
        prompt: The user's question text.

    Returns:
        The decoded output sequence (the prompt followed by the sampled
        continuation) with special tokens such as ``<|endoftext|>`` removed.
        Returns an empty string when *prompt* is empty or whitespace only.
    """
    # A blank prompt has no useful tokens to condition on, so skip generation
    # instead of handing an empty input to model.generate().
    if not prompt or not prompt.strip():
        return ""

    input_ids = tokenizer.encode(prompt, return_tensors="tf")

    output = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        num_return_sequences=1,
        do_sample=True,
        temperature=0.7,
        # GPT-2 defines no pad token. Passing EOS explicitly avoids the
        # "Setting `pad_token_id` to `eos_token_id`" warning on every call.
        pad_token_id=tokenizer.eos_token_id,
    )

    # Strip <|endoftext|> and other special tokens from the text shown to the user.
    return tokenizer.decode(output[0], skip_special_tokens=True)
def main():
    """Render the Streamlit page: title, question box, centred button and answer."""
    st.markdown("<h1 style='text-align: center;'>Agrinuture Bot</h1>", unsafe_allow_html=True)

    # The centred caption acts as the visible label for the input below.
    st.markdown("<p style='text-align: center;'>Enter your question</p>", unsafe_allow_html=True)
    # Streamlit warns about empty labels for accessibility reasons, so give the
    # widget a real label and hide it visually instead of passing "".
    prompt = st.text_input("Question", label_visibility="collapsed")

    # Centre the "Generate Answer" button with a small CSS override.
    st.markdown(
        """
        <style>
        .stButton > button {
            display: block;
            margin: 0 auto;
        }
        </style>
        """,
        unsafe_allow_html=True,
    )

    if st.button("Generate Answer"):
        # Don't run the model on an empty question; tell the user instead.
        if not prompt.strip():
            st.warning("Please enter a question first.")
            return
        # Generation can take a while, so show progress feedback.
        with st.spinner("Generating answer..."):
            answer = generate_answer(prompt)
        st.markdown("<h3 style='text-align: center;'>Generated Answer:</h3>", unsafe_allow_html=True)
        st.text(answer)
# `streamlit run <file>` executes this script as __main__, so this guard renders
# the page under Streamlit while keeping imports of this module side-effect free
# apart from the model/tokenizer setup above.
if __name__ == "__main__":
    main()