File size: 2,466 Bytes
08caf3a
 
 
 
 
 
 
 
 
e0a6493
08caf3a
 
 
 
 
 
 
 
9c9d8ef
08caf3a
 
 
 
 
 
e0a6493
08caf3a
 
9c9d8ef
 
08caf3a
e0a6493
08caf3a
9c9d8ef
08caf3a
 
 
 
 
 
 
 
 
 
e0a6493
08caf3a
 
 
e0a6493
08caf3a
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import json
import os
import shutil
import requests

import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer


_CHECKPOINT_REPO = "bs-modeling-metadata/checkpoints_all_04_23"

# Holds the (model, tokenizer) pair after the first request so the checkpoint
# is not reloaded on every call to generate().
_MODEL_CACHE = {}


def _load_model_and_tokenizer():
    """Return the (model, tokenizer) pair, loading it only on first use."""
    pair = _MODEL_CACHE.get("pair")
    if pair is None:
        model = AutoModelForCausalLM.from_pretrained(
            _CHECKPOINT_REPO, subfolder="checkpoint-30000step"
        )
        tokenizer = AutoTokenizer.from_pretrained(
            _CHECKPOINT_REPO, subfolder="tokenizer", add_prefix_space=True
        )
        pair = (model, tokenizer)
        # Stored as a single tuple so a concurrent reader never observes a
        # half-populated cache.
        _MODEL_CACHE["pair"] = pair
    return pair


def _metadata_field(label, value):
    """Format one optional metadata field as ``"<label>: <value> | "``.

    Returns an empty string when *value* is empty or ``None``.
    """
    return f"{label}: {value} | " if value else ""


def _entity_prefix(entity):
    """Build the entity-chain segment of the prompt.

    *entity* is a comma-separated string. Each non-empty item is wrapped as
    ``" |item|"`` inside ``<ENTITY_CHAIN> ... </ENTITY_CHAIN>``. When no usable
    item is present, only the ``"||| "`` separator is returned.
    """
    # Drop empty items (e.g. from "a,,b" or a trailing comma) so the prompt
    # never contains an empty "||" entity marker.
    names = [item.strip() for item in (entity or "").split(",")]
    names = [name for name in names if name]
    if not names:
        return "||| "
    chain = "".join(f" |{name}|" for name in names)
    return "entity ||| <ENTITY_CHAIN>" + chain + " </ENTITY_CHAIN> "


def generate(html, entity, website_desc, datasource, year, month, title, prompt):
    """Generate a continuation of *prompt* conditioned on optional metadata.

    The metadata is serialised in front of the prompt in this order: html
    flag, year, month, website description, title, datasource, entity chain.
    Fields that are empty (or ``None``) are omitted.

    Args:
        html: ``"on"`` to add the ``"html | "`` prefix; any other value omits it.
        entity: Comma-separated entities/keywords, or an empty string.
        website_desc: Website description text.
        datasource: Datasource text.
        year: Year text.
        month: Month text.
        title: Website title text.
        prompt: The text to continue.

    Returns:
        The list of strings produced by ``tokenizer.batch_decode`` for the
        sequences returned by ``model.generate`` (special tokens skipped).
    """
    html_text = "html | " if html == "on" else ""
    final_prompt = (
        html_text
        + _metadata_field("Year", year)
        + _metadata_field("Month", month)
        + _metadata_field("Website Description", website_desc)
        + _metadata_field("Title", title)
        + _metadata_field("Datasource", datasource)
        + _entity_prefix(entity)
        + (prompt or "")
    )

    model, tokenizer = _load_model_and_tokenizer()
    # Forbid the model from emitting the entity-chain delimiters itself.
    bad_words_ids = tokenizer(["<ENTITY_CHAIN>", " </ENTITY_CHAIN> "]).input_ids

    inputs = tokenizer(final_prompt, return_tensors="pt")

    outputs = model.generate(**inputs, max_new_tokens=128, bad_words_ids=bad_words_ids)
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)
    

# --- Input widgets, one per argument of generate() -------------------------

# Toggle for the "html | " prompt prefix.
html = gr.Radio(
    choices=["on", "off"],
    label="html",
    info="turn html as on or off",
)

# Free-text metadata fields; leaving a field blank omits it from the prompt.
entity = gr.Textbox(
    label="list of entities",
    placeholder="enter a list of comma separated entities or keywords",
)
website_desc = gr.Textbox(
    label="website description",
    placeholder="enter a website description",
)
datasource = gr.Textbox(
    label="datasource",
    placeholder="enter a datasource",
)
year = gr.Textbox(
    label="year",
    placeholder="enter a year",
)
month = gr.Textbox(
    label="month",
    placeholder="enter a month",
)
title = gr.Textbox(
    label="website title",
    placeholder="enter a website title",
)

# The text the model is asked to continue.
prompt = gr.Textbox(
    label="prompt",
    placeholder="enter a prompt",
)

# --- App wiring -------------------------------------------------------------

# The order of `inputs` must match the positional parameters of generate().
demo = gr.Interface(
    fn=generate,
    inputs=[
        html,
        entity,
        website_desc,
        datasource,
        year,
        month,
        title,
        prompt,
    ],
    outputs="text",
)

# Start the web server for the demo.
demo.launch()