# gamechat / app_old.py
# Author: Rajkumar Pramanik ("RJproz")
# Commit: no message (1419291)
import logging

import gradio as gr
# from transformers import GPTJForCausalLM, GPT2Tokenizer
# # Load the GPT-J model and tokenizer
# model_name = "EleutherAI/gpt-j-6B"
# tokenizer = GPT2Tokenizer.from_pretrained(model_name)
# model = GPTJForCausalLM.from_pretrained(model_name)
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
# Load GPT-2 model and tokenizer
#model_name = "gpt2" # You can use "gpt2-medium" or "gpt2-large" for more power
model_name = "../custom_model/custom_gpt2"
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name)
model.config.eos_token_id = tokenizer.eos_token_id
tokenizer.pad_token = tokenizer.eos_token
# Set the pad_token to eos_token
# Function to generate text based on the user input
def generate_text(prompt):
# Tokenizing the input
inputs = tokenizer(prompt, return_tensors="pt", truncation=False, padding=False, max_length=512)
# Generate output
outputs = model.generate(inputs['input_ids'],max_length = 150, num_return_sequences=1, no_repeat_ngram_size=2, top_k=50,
top_p=0.95,
eos_token_id=model.config.eos_token_id
)
# Decode the output
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
return generated_text.strip()
# Gradio interface setup
iface = gr.Interface(fn=generate_text,
inputs=gr.inputs.Textbox(lines=10, placeholder="Enter your prompt here..."),
outputs="text")
# Launch the Gradio interface
iface.launch(share=True)