import gradio as gr
from huggingface_hub import InferenceClient  # NOTE(review): not used in this file
from transformers import AutoTokenizer, AutoModelForCausalLM

# Initialize the Hugging Face model and tokenizer once at import time so every
# request handled by the Gradio app below reuses the same instances.
model_name = "HooshvareLab/gpt2-fa"  # Example Persian GPT-2 model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)


def suggest_next_word(prompt, num_suggestions=3, temper=0.7, max_new_tokens=5):
    """Suggest likely next words for ``prompt`` by sampling from the model.

    :param prompt: The text input so far. ``None`` is treated as "".
    :param num_suggestions: Number of continuations to sample; coerced to
        ``int`` (minimum 1) because UI sliders may deliver floats.
    :param temper: Sampling temperature (higher = more varied suggestions).
    :param max_new_tokens: Tokens generated per continuation. More than one
        is used because a single sub-word token is often only part of a word.
    :return: List of suggested next words, at most ``num_suggestions`` long.
        Continuations that contain no word (e.g. the model immediately emits
        an end-of-text token) are skipped instead of raising. An empty prompt
        yields an empty list.
    """
    # Tokenize via __call__ so we also get the attention mask for generate().
    inputs = tokenizer(prompt or "", return_tensors="pt")
    input_ids = inputs["input_ids"]
    if input_ids.shape[-1] == 0:
        # generate() cannot start from an empty sequence.
        return []

    outputs = model.generate(
        input_ids=input_ids,
        attention_mask=inputs["attention_mask"],
        max_new_tokens=max_new_tokens,
        num_return_sequences=max(1, int(num_suggestions)),
        do_sample=True,
        top_k=50,  # Adjust for randomness in suggestions
        temperature=temper,  # Adjust for creativity in suggestions
        pad_token_id=tokenizer.eos_token_id,  # Set pad_token_id explicitly
    )

    # Strip the prompt using the length of its *decoded* form rather than
    # len(prompt): decode(encode(prompt)) need not reproduce the raw prompt
    # character-for-character, which would shift the slice.
    prefix = tokenizer.decode(input_ids[0], skip_special_tokens=True)

    suggestions = []
    for output in outputs:
        decoded_text = tokenizer.decode(output, skip_special_tokens=True)
        words = decoded_text[len(prefix):].split()
        if words:  # skip continuations that produced no word at all
            suggestions.append(words[0])
    return suggestions


# Gradio interface
def chat_interface(prompt, num_suggestions, temperature):
    """Gradio callback: return the suggestions joined by " | "."""
    suggestions = suggest_next_word(prompt, num_suggestions, temperature)
    return " | ".join(suggestions)


# Define Gradio app
with gr.Blocks() as demo:
    gr.Markdown("# Persian Language Next Word Predictor")
    prompt_input = gr.Textbox(label="Enter your prompt:", lines=2)
    num_suggestions_input = gr.Slider(
        minimum=1, maximum=5, value=3, step=1, label="Number of suggestions"
    )
    temperature_input = gr.Slider(
        minimum=0.1, maximum=2, value=1, step=0.1, label="Temperature"
    )
    output = gr.Textbox(label="Next Word Suggestions:")

    # Link input and output
    suggest_btn = gr.Button("Suggest Next Word")
    suggest_btn.click(
        chat_interface,
        inputs=[prompt_input, num_suggestions_input, temperature_input],
        outputs=output,
    )

if __name__ == "__main__":
    demo.launch()