"""Gradio demo that prepends metadata to a prompt and generates a continuation."""

import functools
import json
import os
import shutil

import gradio as gr
import requests
from transformers import AutoModelForCausalLM, AutoTokenizer

# Hub repository and subfolders holding the model weights and the tokenizer.
MODEL_REPO = "bs-modeling-metadata/checkpoints_all_04_23"
MODEL_SUBFOLDER = "checkpoint-30000step"
TOKENIZER_SUBFOLDER = "tokenizer"

# Upper bound on the number of tokens generated per request.
MAX_NEW_TOKENS = 128


@functools.lru_cache(maxsize=1)
def _load_model_and_tokenizer():
    """Load the model and tokenizer once and reuse them on later calls.

    The original code called ``from_pretrained`` inside ``generate``, so the
    checkpoint was re-loaded on every request. Caching the pair keeps the
    first request as the only one that pays the loading cost.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_REPO, subfolder=MODEL_SUBFOLDER
    )
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_REPO, subfolder=TOKENIZER_SUBFOLDER, add_prefix_space=True
    )
    return model, tokenizer


def _labelled_field(label, value):
    """Return ``"<label>: <value> | "``, or ``""`` when *value* is empty."""
    if value == "":
        return ""
    return label + ": " + value + " | "


def _build_prompt(html, entity, website_desc, datasource, year, month, title, prompt):
    """Assemble the metadata prefix and append the user prompt to it.

    Field order in the result is: html flag, year, month, website
    description, title, datasource, entity section, prompt.
    """
    html_text = "html | " if html == "on" else ""

    if entity != "":
        # Each comma-separated entity is stripped and wrapped as " |name|".
        wrapped = "".join(" |" + ent.strip() + "|" for ent in entity.split(","))
        entity_text = "entity ||| " + wrapped + " "
    else:
        # With no entities only the bare separator is emitted.
        entity_text = "||| "

    return (
        html_text
        + _labelled_field("Year", year)
        + _labelled_field("Month", month)
        + _labelled_field("Website Description", website_desc)
        + _labelled_field("Title", title)
        + _labelled_field("Datasource", datasource)
        + entity_text
        + prompt
    )


def generate(html, entity, website_desc, datasource, year, month, title, prompt):
    """Generate text for *prompt* conditioned on the supplied metadata.

    Args:
        html: ``"on"`` to add the ``html`` marker to the prefix; any other
            value leaves it out.
        entity: Comma-separated entities or keywords; ``""`` for none.
        website_desc: Website description; ``""`` to omit.
        datasource: Datasource name; ``""`` to omit.
        year: Year as text; ``""`` to omit.
        month: Month as text; ``""`` to omit.
        title: Website title; ``""`` to omit.
        prompt: Text appended after the metadata prefix.

    Returns:
        The list produced by ``tokenizer.batch_decode`` for the generated
        sequences, with special tokens skipped.
    """
    final_prompt = _build_prompt(
        html, entity, website_desc, datasource, year, month, title, prompt
    )
    model, tokenizer = _load_model_and_tokenizer()

    # NOTE(review): the first bad-word entry is an empty string; confirm it
    # tokenizes to a non-empty id list for this tokenizer.
    bad_words_ids = tokenizer(["", " "]).input_ids
    inputs = tokenizer(final_prompt, return_tensors="pt")
    outputs = model.generate(
        **inputs, max_new_tokens=MAX_NEW_TOKENS, bad_words_ids=bad_words_ids
    )
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)


# Input widgets, listed in the same order as the parameters of ``generate``.
html = gr.Radio(["on", "off"], label="html", info="turn html as on or off")
entity = gr.Textbox(
    placeholder="enter a list of comma separated entities or keywords",
    label="list of entities",
)
website_desc = gr.Textbox(
    placeholder="enter a website description", label="website description"
)
datasource = gr.Textbox(placeholder="enter a datasource", label="datasource")
year = gr.Textbox(placeholder="enter a year", label="year")
month = gr.Textbox(placeholder="enter a month", label="month")
title = gr.Textbox(placeholder="enter a website title", label="website title")
prompt = gr.Textbox(placeholder="enter a prompt", label="prompt")

demo = gr.Interface(
    fn=generate,
    inputs=[html, entity, website_desc, datasource, year, month, title, prompt],
    outputs="text",
)

demo.launch()