word_suggester / app.py
alijkdkar's picture
set pad_token_id explicitly in suggest_next_word function
495bccc
import gradio as gr
from huggingface_hub import InferenceClient
from transformers import AutoTokenizer, AutoModelForCausalLM
# Load the Persian GPT-2 tokenizer and causal-LM weights once, at import time,
# so every call to the suggestion function reuses the same objects.
model_name = "HooshvareLab/gpt2-fa"  # Persian GPT-2 checkpoint identifier
tokenizer, model = (
    AutoTokenizer.from_pretrained(model_name),
    AutoModelForCausalLM.from_pretrained(model_name),
)
def suggest_next_word(prompt, num_suggestions=3, temper=0.7, max_new_tokens=5):
    """
    Suggest candidate next words for the given prompt by sampling the model.

    Each suggestion comes from one independently sampled continuation. The
    first whitespace-delimited word of that continuation is kept.

    :param prompt: The text input so far.
    :param num_suggestions: Number of continuations to sample (one word each).
        Converted to int because UI sliders may deliver a float.
    :param temper: Sampling temperature; higher values give more varied words.
    :param max_new_tokens: Tokens sampled per continuation. More than one is
        needed because a single BPE token is often only a fragment of a word.
    :return: List of suggested next words. It is empty for an empty or
        whitespace-only prompt. It may be shorter than ``num_suggestions`` when
        a sampled continuation contains no word (e.g. only EOS or whitespace).
    """
    # An empty prompt yields zero input ids, which generate() cannot start from.
    if not prompt or not prompt.strip():
        return []
    num_suggestions = int(num_suggestions)

    # Calling the tokenizer (rather than encode) also returns the attention mask.
    inputs = tokenizer(prompt, return_tensors="pt")
    prompt_len = inputs["input_ids"].shape[-1]
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        num_return_sequences=num_suggestions,
        do_sample=True,
        top_k=50,  # Adjust for randomness in suggestions
        temperature=temper,  # Adjust for creativity in suggestions
        pad_token_id=tokenizer.eos_token_id,  # Set pad_token_id explicitly
    )

    suggestions = []
    for output in outputs:
        # Decode only the newly generated ids. Slicing the full decoded text by
        # len(prompt) breaks when decode(encode(prompt)) != prompt.
        continuation = tokenizer.decode(output[prompt_len:], skip_special_tokens=True)
        words = continuation.split()
        if words:  # skip samples that produced no word instead of crashing
            suggestions.append(words[0])
    return suggestions
# Gradio interface
def chat_interface(prompt, num_suggestions, temperature):
    """
    Gradio callback: produce the next-word suggestions as one display string.

    :param prompt: Text typed by the user.
    :param num_suggestions: Number of suggestions requested via the slider.
    :param temperature: Sampling temperature chosen via the slider.
    :return: The suggestions separated by " | ".
    """
    return " | ".join(suggest_next_word(prompt, num_suggestions, temperature))
# Define Gradio app
with gr.Blocks() as demo:
    # Keep the component creation order unchanged: it determines the page layout.
    gr.Markdown("# Persian Language Next Word Predictor")
    prompt_input = gr.Textbox(lines=2, label="Enter your prompt:")
    num_suggestions_input = gr.Slider(
        label="Number of suggestions", minimum=1, maximum=5, step=1, value=3,
    )
    temperature_input = gr.Slider(
        label="Temperature", minimum=0.1, maximum=2, step=0.1, value=1,
    )
    output = gr.Textbox(label="Next Word Suggestions:")
    suggest_btn = gr.Button("Suggest Next Word")

    # Clicking the button passes (prompt, count, temperature) to chat_interface
    # and writes its string result into the output textbox.
    callback_inputs = [prompt_input, num_suggestions_input, temperature_input]
    suggest_btn.click(fn=chat_interface, inputs=callback_inputs, outputs=output)

if __name__ == "__main__":
    # Start the Gradio server only when run as a script, not on import.
    demo.launch()