File size: 1,038 Bytes
231b4b5
 
29a0778
6ef97ad
29a0778
231b4b5
 
29a0778
6ef97ad
4a2e7c0
6ef97ad
4a2e7c0
231b4b5
6ef97ad
231b4b5
 
 
 
 
 
6ef97ad
231b4b5
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import gradio as gr
import random

# Hub checkpoint id shared by the tokenizer and the model so both are always
# loaded from the same source.
_CHECKPOINT = 'TrLOX/gpt2-tdk'

# Loaded once at import time; text_generation() reuses these module globals
# for every request.
tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT)
model = AutoModelForCausalLM.from_pretrained(_CHECKPOINT)

def text_generation(keywords, domain):
    """Generate text for a site from its keywords and domain.

    Builds the prompt ``"keyword <keywords> domain <domain> title"``, samples a
    continuation from the module-level ``model``, and decodes it with the
    module-level ``tokenizer``.

    Args:
        keywords: Free-text keywords (the UI default is comma-separated).
        domain: The site's domain, e.g. ``"test.com"``.

    Returns:
        The decoded text of the single generated sequence, with special tokens
        removed. A decoder-only model's ``generate`` output contains the prompt
        tokens followed by the continuation, so the text includes the prompt.
    """
    # NOTE(review): presumably this template matches the checkpoint's
    # fine-tuning format -- keep it byte-identical unless that is confirmed.
    prompt = 'keyword ' + keywords + ' domain ' + domain + ' title'
    encoded = tokenizer(prompt, return_tensors="pt")

    # Seed torch's global RNG with a fresh integer drawn from torch's full
    # accepted seed range [0, 2**64 - 1]. ``random.seed()`` returns None and
    # therefore cannot be used to produce the seed value.
    torch.manual_seed(random.randint(0, 18446744073709551615))

    outputs = model.generate(
        encoded.input_ids,
        # Pass the mask explicitly rather than letting generate() infer it.
        attention_mask=encoded.get("attention_mask"),
        do_sample=True,
        min_length=50,
        max_length=250,
    )
    generated_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    return generated_text[0]

title = "TDK GPT2"
description = "Title and description generation by keywords"

# NOTE(review): gr.inputs / gr.outputs are Gradio's legacy component
# namespaces (deprecated in 3.x, removed in 4.0); migrating to gr.Textbox
# requires pinning the Gradio version first.
# The input components are passed in this order, matching the parameter order
# of text_generation(keywords, domain).
keywords_box = gr.inputs.Textbox(default='test 1,test 2', lines=2, label="Enter keywords")
domain_box = gr.inputs.Textbox(lines=2, default='test.com', label="Enter domain")
result_box = gr.outputs.Textbox(type="auto", label="Text Generated")

demo = gr.Interface(
    text_generation,
    [keywords_box, domain_box],
    [result_box],
    title=title,
    description=description,
    theme="huggingface",
)
demo.launch()