# app.py — Hugging Face Space (JoPmt), commit c062a99 (verified)
import gradio as gr
import torch
import os
from transformers import AutoTokenizer, AutoModelForCausalLM, TextDataset, DataCollatorForLanguageModeling, Trainer, TrainingArguments, pipeline
from accelerate import Accelerator
# --- One-time fine-tuning pipeline (runs at module import / app startup) ---
# Force CPU execution via accelerate (this Space has no GPU).
accelerator = Accelerator(cpu=True)
cwd = "./models"  # output directory for checkpoints and the final model

# Base model + tokenizer. NOTE(review): accelerator.prepare() on a tokenizer
# is a pass-through (tokenizers are not trainable objects); kept to mirror
# the model preparation.
tokenizer = accelerator.prepare(AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-125m"))
model = accelerator.prepare(AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-neo-125m"))

# Causal-LM training data: plain text split into 128-token blocks.
# NOTE(review): TextDataset is deprecated in recent transformers releases;
# the `datasets` library is the recommended replacement.
train_dataset = TextDataset(
    tokenizer=tokenizer,
    file_path='./train_text.txt',
    block_size=128,
)
# mlm=False -> collator builds causal-LM labels (inputs shifted), not masked-LM.
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False,
)
training_args = TrainingArguments(
    output_dir=cwd,
    overwrite_output_dir=True,
    num_train_epochs=5,
    per_device_train_batch_size=5,
    save_steps=500,
    save_total_limit=5,
)
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
)
trainer.train()

# Persist the fine-tuned artifacts.
tokenizer.save_pretrained('./models')
# BUG FIX: Trainer.save_model's signature is save_model(output_dir, _internal_call).
# The original passed 'pytorch_model' as the second positional argument, which
# set the private _internal_call flag to a truthy value instead of naming a file.
trainer.save_model('./models')

# Install a local config into the model directory.
# NOTE(review): this overwrites the config.json that save_model just wrote and
# raises FileNotFoundError if './config.json' is absent — confirm intentional.
src = './config.json'
des = './models/config.json'
os.rename(src, des)

# Reload for inference: fresh base tokenizer + the freshly fine-tuned weights.
tokenizer = accelerator.prepare(AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-125m"))
model = accelerator.prepare(AutoModelForCausalLM.from_pretrained("./models"))
def plex(input_text):
    """Generate a short continuation of *input_text* with the fine-tuned model.

    Returns the first line of the decoded generation so the UI shows a
    single-line story opener, or "" if the model produced no text.
    """
    inputs = tokenizer(input_text, return_tensors='pt')
    # Inference only: no_grad avoids building an autograd graph (greedy
    # decoding is deterministic, so the output is unchanged).
    with torch.no_grad():
        prediction = model.generate(
            inputs['input_ids'],
            min_length=20,
            max_length=150,
            num_return_sequences=1,
        )
    lines = tokenizer.decode(prediction[0]).splitlines()
    # Guard: splitlines() on an empty decode yields [], making lines[0] raise.
    return lines[0] if lines else ""
# --- Gradio UI: single textbox in, single textbox out ---
iface = gr.Interface(
    fn=plex,
    inputs=gr.Textbox(
        label="Prompt Finetuned Model Exmpl: a cat, Exmpl: 3 little bears, Exmpl: once upon a time",
        value="Once upon a",
    ),
    outputs=gr.Textbox(label="Generated_Text"),
    title="GPT-Neo-125M fine-tuned on a small set of shortstories with Gradio",
    description="Prompt for a short bedtime story.",
)
# Serialize requests (one at a time) and hide the programmatic API endpoint.
iface.queue(max_size=1, api_open=False)
iface.launch(max_threads=1)