import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

SAVED_CHECKPOINT = 'mikegarts/distilgpt2-erichmariaremarque'
MIN_WORDS = 80


def get_model():
    # Load the fine-tuned model and its tokenizer from the Hugging Face Hub.
    model = AutoModelForCausalLM.from_pretrained(SAVED_CHECKPOINT)
    tokenizer = AutoTokenizer.from_pretrained(SAVED_CHECKPOINT)
    return model, tokenizer


def generate(prompt):
    model, tokenizer = get_model()

    # Run on GPU when one is available; otherwise fall back to CPU so the
    # demo also works on CPU-only hardware.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = model.to(device)

    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)

    outputs = model.generate(
        input_ids=input_ids,
        max_length=100,
        temperature=0.7,
        num_return_sequences=3,
        do_sample=True,
    )

    # Decode the first sampled sequence and trim it to the last complete sentence.
    return tokenizer.decode(outputs[0], skip_special_tokens=True).rsplit('.', 1)[0] + '.'


def predict(prompt):
    return generate(prompt=prompt)


title = "What would Remarques say?" |
|
description = """
The bot was trained to complete your prompt as if it were the beginning of a paragraph in one of Remarque's books.
<img src="https://upload.wikimedia.org/wikipedia/commons/1/10/Bundesarchiv_Bild_183-R04034%2C_Erich_Maria_Remarque_%28cropped%29.jpg" align=center width=200px>
"""

gr.Interface(
    fn=predict,
    inputs="textbox",
    outputs="text",
    title=title,
    description=description,
    examples=[["I was drinking because"], ["Who is Karl for me?"]],
).launch(debug=True)