# Scraped from a Hugging Face file view: manandey, "Update app.py" (commit e0a6493), 2.33 kB, no virus flagged.
import functools
import json
import os
import shutil

import gradio as gr
import requests
from transformers import AutoModelForCausalLM, AutoTokenizer
_CHECKPOINT_REPO = "bs-modeling-metadata/checkpoints_all_04_23"


@functools.lru_cache(maxsize=1)
def _load_model_and_tokenizer():
    """Load the causal LM and its tokenizer once; later calls reuse them."""
    model = AutoModelForCausalLM.from_pretrained(
        _CHECKPOINT_REPO, subfolder="checkpoint-30000step"
    )
    tokenizer = AutoTokenizer.from_pretrained(
        _CHECKPOINT_REPO, subfolder="tokenizer"
    )
    return model, tokenizer


def _metadata_field(label, value):
    """Return ``"<label>: <value> | "``, or ``""`` when *value* is empty or None."""
    return f"{label}: {value} | " if value else ""


def _entity_chain(entity):
    """Render a comma-separated entity string as the ENTITY_CHAIN prompt prefix."""
    if not entity:
        return ""
    # Drop blank items so "a,,b" or a trailing comma does not emit an empty " ||".
    names = [name for name in (part.strip() for part in entity.split(",")) if name]
    if not names:
        return ""
    chain = "".join(f" |{name}|" for name in names)
    return f"entity ||| <ENTITY_CHAIN>{chain} </ENTITY_CHAIN> "


def generate(html, entity, website_desc, datasource, year, month, title, prompt):
    """Build a metadata-prefixed prompt and return the model's continuation.

    Args:
        html: ``"on"`` to prepend the ``"html | "`` marker; any other value
            (including ``None``) omits it.
        entity: Comma-separated entities/keywords; blank items are ignored.
        website_desc: Website description, or empty to omit the field.
        datasource: Datasource name, or empty to omit the field.
        year: Year, or empty to omit the field.
        month: Month, or empty to omit the field.
        title: Website title, or empty to omit the field.
        prompt: Free text appended after all metadata prefixes.

    Returns:
        The list produced by ``tokenizer.batch_decode`` for the generated
        sequences, with special tokens skipped.
    """
    html_text = "html | " if html == "on" else ""

    # Field order is part of the prompt format; keep it exactly as below.
    final_prompt = (
        html_text
        + _metadata_field("Year", year)
        + _metadata_field("Month", month)
        + _metadata_field("Website Description", website_desc)
        + _metadata_field("Title", title)
        + _metadata_field("Datasource", datasource)
        + _entity_chain(entity)
        + (prompt or "")
    )

    # Loading is cached: previously both objects were re-created on every request.
    model, tokenizer = _load_model_and_tokenizer()
    inputs = tokenizer(final_prompt, return_tensors="pt")
    outputs = model.generate(**inputs, max_new_tokens=128)
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)
# --- Input widgets, one per argument of `generate` (same order as its signature) ---
html = gr.Radio(
    choices=["on", "off"],
    label="html",
    info="turn html as on or off",
)
entity = gr.Textbox(
    label="list of entities",
    placeholder="enter a list of comma separated entities or keywords",
)
website_desc = gr.Textbox(
    label="website description",
    placeholder="enter a website description",
)
datasource = gr.Textbox(
    label="datasource",
    placeholder="enter a datasource",
)
year = gr.Textbox(
    label="year",
    placeholder="enter a year",
)
month = gr.Textbox(
    label="month",
    placeholder="enter a month",
)
title = gr.Textbox(
    label="website title",
    placeholder="enter a website title",
)
prompt = gr.Textbox(
    label="prompt",
    placeholder="enter a prompt",
)

# --- Wire the widgets to `generate` and show its result as plain text ---
demo = gr.Interface(
    fn=generate,
    inputs=[
        html,
        entity,
        website_desc,
        datasource,
        year,
        month,
        title,
        prompt,
    ],
    outputs="text",
)

# Start the Gradio server for this interface.
demo.launch()