# minelli/app.py
import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the fine-tuned GPT-2 model and its tokenizer from the Hugging Face Hub.
tokenizer = AutoTokenizer.from_pretrained("href/gpt2-schiappa")
model = AutoModelForCausalLM.from_pretrained("href/gpt2-schiappa")
def text_generation(input_text, seed, min_length, max_length, temperature, top_k, top_p, repetition_penalty):
    input_ids = tokenizer(input_text, return_tensors="pt").input_ids
    torch.manual_seed(int(seed))  # Max value: 18446744073709551615
    # Gradio Number/Slider inputs arrive as floats, so integer arguments are cast explicitly.
    outputs = model.generate(input_ids, do_sample=True, max_length=int(max_length), min_length=int(min_length),
                             temperature=temperature, top_k=int(top_k), top_p=top_p,
                             repetition_penalty=repetition_penalty)
    generated_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    # Return the first decoded sequence rather than the list itself.
    return generated_text[0]
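# A minimal sketch of calling text_generation() directly, e.g. for a quick local test
# without the Gradio UI (the prompt and sampling values below are illustrative only):
#
#   print(text_generation("Bonjour", seed=42, min_length=0, max_length=50,
#                         temperature=1.0, top_k=40, top_p=0.9, repetition_penalty=1.0))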
title = "Minelli"
description = "Écrit comme Marie Minelli"  # "Writes like Marie Minelli"
# Legacy Gradio input/output components (gr.inputs / gr.outputs namespace).
gr.Interface(
    text_generation,
    [
        gr.inputs.Textbox(lines=2, label="Prompt"),
        gr.inputs.Number(default=10, label="Seed"),
        gr.inputs.Slider(default=0, label="Minimum length", minimum=0),
        gr.inputs.Slider(default=100, label="Maximum length", minimum=0),
        gr.inputs.Slider(default=1.0, label="Temperature", minimum=0.1),
        gr.inputs.Slider(default=40, label="Top-k", minimum=1),
        gr.inputs.Slider(default=0.9, label="Top-p", minimum=0.0, maximum=1.0),
        gr.inputs.Slider(default=1.0, label="Repetition penalty", minimum=1.0),
    ],
    [gr.outputs.Textbox(type="auto", label="Generated")],
    title=title,
    description=description,
    theme="huggingface",
).launch()