# gpt2 / run.py — uploaded by ElliNet13 (commit c0579aa, "Update run.py", verified; 832 bytes).
import time
import gradio as gr
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
# Hugging Face checkpoint shared by the tokenizer and the model below.
_CHECKPOINT = "gpt2"

# Both objects are created once, at import time; generate_response reads them
# as module globals, so every chat turn reuses the same in-memory model.
tokenizer = GPT2Tokenizer.from_pretrained(_CHECKPOINT)
model = GPT2LMHeadModel.from_pretrained(_CHECKPOINT)
def generate_response(user_input, max_length=50):
    """Generate a GPT-2 continuation of *user_input* and return it as text.

    Args:
        user_input: Prompt text from the user.
        max_length: Passed to ``model.generate`` as ``max_length``, i.e. the
            cap on the *total* sequence length in tokens, prompt included (not
            a count of newly generated tokens).

    Returns:
        The first generated sequence decoded with special tokens removed. For
        GPT-2, ``generate`` returns prompt + continuation, so the text normally
        begins with the prompt. An empty prompt returns ``""``.
    """
    # An empty string encodes to zero tokens, which generate() cannot start
    # from (it fails instead of producing text); reply with nothing instead.
    if not user_input:
        return ""

    # Calling the tokenizer directly (instead of .encode) also yields the
    # attention mask, so generate() does not have to infer it.
    inputs = tokenizer(user_input, return_tensors="pt")

    # GPT-2 defines no pad token; supplying EOS explicitly matches what
    # generate() would fall back to, minus the per-call warning.
    output = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_length=max_length,
        num_return_sequences=1,
        pad_token_id=tokenizer.eos_token_id,
    )

    # Only one sequence was requested, so decode the first (and only) row.
    return tokenizer.decode(output[0], skip_special_tokens=True)
def resposeYielder(message, history):
    """Chat callback: yield exactly one GPT-2 reply for *message*.

    *history* is accepted to match the callback signature but is not used, so
    each reply is generated from the current message alone.
    """
    reply = generate_response(message)
    yield reply
# Wrap the generator callback in Gradio's chat UI, then enable the request
# queue. NOTE(review): older Gradio releases required queue() for generator
# callbacks; confirm whether the installed version still needs it.
demo = gr.ChatInterface(resposeYielder)
demo = demo.queue()

if __name__ == "__main__":
    # Serve the app only when executed as a script, not when imported.
    demo.launch()