File size: 1,236 Bytes
231b4b5
 
29a0778
6ef97ad
29a0778
231b4b5
 
29a0778
6ef97ad
4a2e7c0
9c9e3b5
4a2e7c0
231b4b5
fc6e0e6
555dc8c
 
0a7a5c3
231b4b5
 
 
 
5546b1b
231b4b5
 
6ef97ad
231b4b5
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import gradio as gr
import random

# Hugging Face repo id (or local path) providing both the tokenizer and the
# causal-LM weights; keeping it in one place ensures the two always match.
_MODEL_ID = 'TrLOX/gpt2-tdk'
tokenizer = AutoTokenizer.from_pretrained(_MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(_MODEL_ID)

def _split_title_description(continuation):
    """Split the model's continuation text into ``(title, description)``.

    The prompt built by ``text_generation`` ends with the marker ``title``,
    so the generated continuation is expected to read
    ``<title> description <description>``. Only the FIRST ``description``
    marker is treated as the separator, so the word may still appear inside
    the description text without truncating it. If the marker is missing,
    the whole continuation becomes the title and the description is empty.
    """
    title, _, description = continuation.partition('description')
    return title.strip(), description.strip()


def text_generation(keywords, domain):
    """Generate a page title and description from keywords and a domain.

    Args:
        keywords: Free-text keywords (the UI default is comma-separated).
        domain: Domain name to condition the generation on.

    Returns:
        The generated title, followed by a blank line (``\\r\\n\\r\\n``) and the
        generated description. If the model produced no ``description``
        marker, only the title is returned instead of raising.
    """
    prompt = 'keyword ' + keywords + ' domain ' + domain + ' title'
    encoded = tokenizer(prompt, return_tensors="pt")
    prompt_len = encoded.input_ids.shape[-1]

    # Re-seed torch's global RNG from Python's `random`, as before.
    torch.manual_seed(random.randint(0, 18446744073709551615))
    outputs = model.generate(
        encoded.input_ids,
        # Passed explicitly; for a single unpadded GPT-2 prompt these match
        # what `generate` would otherwise infer (with a warning).
        attention_mask=encoded.attention_mask,
        pad_token_id=tokenizer.eos_token_id,
        do_sample=True,
        min_length=50,
        max_length=250,
    )

    # Decode only the newly generated tokens. Splitting the full text on
    # ' title ' broke whenever the prompt or the generated title itself
    # contained that word.
    continuation = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    title, description = _split_title_description(continuation)
    if not description:
        return title
    return title + "\r\n\r\n" + description

title = "TDK GPT2"
description = "Title and description generation by keywords"

# NOTE(review): gr.inputs / gr.outputs, Textbox(type="auto") and
# theme="huggingface" look like the legacy Gradio 2.x API; later Gradio
# releases replace them with gr.Textbox(value=...). Confirm against the pinned
# Gradio version before upgrading.
demo = gr.Interface(
    text_generation,
    [
        gr.inputs.Textbox(default='test 1,test 2', lines=2, label="Enter keywords"),
        gr.inputs.Textbox(lines=2, default='test.com', label="Enter domain"),
    ],
    [gr.outputs.Textbox(type="auto", label="Text Generated")],
    title=title,
    description=description,
    theme="huggingface",
)

# Start the web server only when run as a script (e.g. `python app.py`), so
# importing this module does not block on launching the UI.
if __name__ == "__main__":
    demo.launch()