File size: 2,226 Bytes
08caf3a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import json
import os
import shutil
from functools import lru_cache

import gradio as gr
import requests
from transformers import AutoModelForCausalLM, AutoTokenizer


# Hugging Face Hub repository holding the fine-tuned checkpoint and its tokenizer.
_MODEL_REPO = "bs-modeling-metadata/checkpoints_all_04_23"
_MODEL_SUBFOLDER = "checkpoint-30000step"
_TOKENIZER_SUBFOLDER = "tokenizer"


@lru_cache(maxsize=1)
def _load_model_and_tokenizer():
    """Load the causal LM and its tokenizer once and reuse them for every request.

    ``from_pretrained`` is expensive, so the pair is memoized instead of being
    re-created on each ``generate`` call.
    """
    model = AutoModelForCausalLM.from_pretrained(_MODEL_REPO, subfolder=_MODEL_SUBFOLDER)
    tokenizer = AutoTokenizer.from_pretrained(_MODEL_REPO, subfolder=_TOKENIZER_SUBFOLDER)
    return model, tokenizer


def _field(label, value):
    """Return ``"<label>: <value> | "``, or ``""`` when *value* is empty or None."""
    return f"{label}: {value} | " if value else ""


def _entity_chain(entity):
    """Format a comma-separated entity string as the ``<ENTITY_CHAIN>`` prompt segment.

    Entries are stripped and blank ones (e.g. from a trailing comma) are
    dropped; returns ``""`` when no non-blank entity remains.
    """
    entities = [ent.strip() for ent in entity.split(",")] if entity else []
    entities = [ent for ent in entities if ent]
    if not entities:
        return ""
    chain = "".join(f" |{ent}|" for ent in entities)
    return "entity ||| <ENTITY_CHAIN>" + chain + " </ENTITY_CHAIN> "


def _build_prompt(html, entity, website_desc, datasource, year, month, title):
    """Concatenate the metadata segments in the order the model prompt uses."""
    html_text = "html | " if html == "on" else ""
    return (
        html_text
        + _field("Year", year)
        + _field("Month", month)
        + _field("Website Description", website_desc)
        + _field("Title", title)
        + _field("Datasource", datasource)
        + _entity_chain(entity)
    )


def generate(html, entity, website_desc, datasource, year, month, title):
    """Generate text conditioned on the metadata entered in the form.

    Args:
        html: ``"on"`` to prefix the prompt with ``"html | "``; any other value
            (including an unselected radio) omits it.
        entity: Comma-separated list of entities/keywords; empty to omit.
        website_desc: Website description; empty to omit.
        datasource: Datasource name; empty to omit.
        year: Year string; empty to omit.
        month: Month string; empty to omit.
        title: Website title; empty to omit.

    Returns:
        The list of decoded sequences from ``tokenizer.batch_decode`` (prompt
        plus up to 128 newly generated tokens, special tokens removed).
    """
    prompt = _build_prompt(html, entity, website_desc, datasource, year, month, title)

    model, tokenizer = _load_model_and_tokenizer()
    inputs = tokenizer(prompt, return_tensors="pt")
    outputs = model.generate(**inputs, max_new_tokens=128)
    # NOTE(review): this is a list of strings; with outputs="text" Gradio
    # presumably shows its repr (brackets/quotes) — consider returning [0].
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)
    

# Input widgets, listed in the same order as generate()'s parameters.
html = gr.Radio(["on", "off"], label="html", info="turn html as on or off")
entity = gr.Textbox(placeholder="enter a list of comma separated entities or keywords", label="list of entities")
website_desc = gr.Textbox(placeholder="enter a website description", label="website description")
datasource = gr.Textbox(placeholder="enter a datasource", label="datasource")
year = gr.Textbox(placeholder="enter a year", label="year")
month = gr.Textbox(placeholder="enter a month", label="month")
title = gr.Textbox(placeholder="enter a website title", label="website title")

demo = gr.Interface(
    fn=generate,
    inputs=[html, entity, website_desc, datasource, year, month, title],
    outputs="text",
)

# Launch only when executed as a script, so importing this module (e.g. for
# tests or reuse of `demo`) does not start a server.
if __name__ == "__main__":
    demo.launch()