TrLOX commited on
Commit
6ef97ad
·
1 Parent(s): 4a2e7c0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -4
app.py CHANGED
@@ -1,23 +1,24 @@
1
  from transformers import AutoTokenizer, AutoModelForCausalLM
2
  import torch
3
  import gradio as gr
 
4
 
5
  tokenizer = AutoTokenizer.from_pretrained('TrLOX/gpt2-tdk')
6
  model = AutoModelForCausalLM.from_pretrained('TrLOX/gpt2-tdk')
7
 
8
- def text_generation(keywords, domain, seed):
9
  input_ids = tokenizer('keyword ' + keywords + ' domain ' + domain + ' title', return_tensors="pt").input_ids
10
- torch.manual_seed(seed) # Max value: 18446744073709551615
11
  outputs = model.generate(input_ids, do_sample=True, min_length=50, max_length=250)
12
  generated_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)
13
- return generated_text
14
 
15
  title = "TDK GPT2"
16
  description = "Title and description generation by keywords"
17
 
18
  gr.Interface(
19
  text_generation,
20
- [gr.inputs.Textbox(default='test 1,test 2',lines=2, label="Enter keywords"), gr.inputs.Textbox(lines=2, default='test.com',label="Enter domain"), gr.inputs.Number(default=10, label="Enter seed number")],
21
  [gr.outputs.Textbox(type="auto", label="Text Generated")],
22
  title=title,
23
  description=description,
 
1
  from transformers import AutoTokenizer, AutoModelForCausalLM
2
  import torch
3
  import gradio as gr
4
+ import random
5
 
6
  tokenizer = AutoTokenizer.from_pretrained('TrLOX/gpt2-tdk')
7
  model = AutoModelForCausalLM.from_pretrained('TrLOX/gpt2-tdk')
8
 
9
+ def text_generation(keywords, domain):
10
  input_ids = tokenizer('keyword ' + keywords + ' domain ' + domain + ' title', return_tensors="pt").input_ids
11
+ torch.manual_seed(random.seed(18446744073709551615))
12
  outputs = model.generate(input_ids, do_sample=True, min_length=50, max_length=250)
13
  generated_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)
14
+ return generated_text[0]
15
 
16
  title = "TDK GPT2"
17
  description = "Title and description generation by keywords"
18
 
19
  gr.Interface(
20
  text_generation,
21
+ [gr.inputs.Textbox(default='test 1,test 2',lines=2, label="Enter keywords"), gr.inputs.Textbox(lines=2, default='test.com',label="Enter domain")],
22
  [gr.outputs.Textbox(type="auto", label="Text Generated")],
23
  title=title,
24
  description=description,