|
import gradio as gr |
|
import torch |
|
import os |
|
from transformers import AutoTokenizer, T5ForConditionalGeneration |
|
|
|
# Access token for the Hugging Face Hub, read from the TOKEN environment
# variable; it is None when the variable is not set.
auth_token = os.getenv('TOKEN')

# Hub repository that both the tokenizer and the model are loaded from.
model_id = 'ksabeh/qpave'

# Token budgets. max_input_length is the truncation/padding length applied
# by predict(); max_target_length is defined here but not referenced by the
# code visible in this file.
max_input_length = 512
max_target_length = 20

# Both loaders take the same credentials, so build the keyword set once.
_hub_kwargs = {"use_auth_token": auth_token}

model = T5ForConditionalGeneration.from_pretrained(model_id, **_hub_kwargs)
tokenizer = AutoTokenizer.from_pretrained(model_id, **_hub_kwargs)
|
|
|
def predict(cg_attribute: str, text: str, fg_attribute: str, category: str) -> str:
    """Generate a value for a fine-grained attribute from a context string.

    The prompt sent to the model is ``"{fg_attribute}: {text}"``. It is
    tokenized with truncation and padding to ``max_input_length`` and decoded
    from the first sequence returned by ``model.generate``.

    Args:
        cg_attribute: Coarse-grained attribute name. Accepted so the signature
            matches the UI inputs, but not used when building the prompt.
        text: Context string (the coarse-grained attribute's value).
        fg_attribute: Name of the fine-grained attribute to extract.
        category: Product category. Accepted but not used when building the
            prompt.

    Returns:
        The first decoded prediction with special tokens removed.
    """
    # Named `prompt` rather than `input` so the builtin is not shadowed.
    prompt = f"{fg_attribute}: {text}"

    # return_tensors="pt" makes the tokenizer hand back batched PyTorch
    # tensors directly, replacing the manual torch.tensor(...) + unsqueeze.
    encoded = tokenizer(
        prompt,
        max_length=max_input_length,
        truncation=True,
        padding="max_length",
        return_tensors="pt",
    )

    # Sampling is enabled alongside beam search, so repeated calls with the
    # same inputs may return different strings.
    # NOTE(review): generation is capped at 10 tokens here although the module
    # defines max_target_length = 20 -- confirm which limit is intended.
    predictions = model.generate(
        **encoded,
        num_beams=4,
        do_sample=True,
        max_length=10,
    )

    decoded = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    return decoded[0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# (label, info) for each single-line text input, listed in the same order as
# predict()'s positional parameters.
_input_fields = [
    ("Coarse-grained Attribute", "The coarse-grained attribute name"),
    ("Context", "The value of the coarse-grained attribute"),
    ("Fine-grained Attribute", "The target fine-grained attribute name"),
    ("Category", "The product category"),
]

# Pre-filled example rows; each row supplies one value per input field.
_example_rows = [
    ["Processor", "3ghz intel core i5", "Brand Name", "Computers & Tablets"],
    ["Special Feature", "Electric Razor for Men,Beard Trimmer,Rechargeable,Wet and Dry,Cordles", "Uses", "Electric Shavers"],
    ["Special Feature", "Electric Razor for Men,Beard Trimmer,Rechargeable,Wet and Dry,Cordles", "Skin Type", "Electric Shavers"],
    ["Color", "Black Foil Razor", "Head Type", "Electric Shavers"],
    ["Color", "2 ustzc-541 black with silver wood", "Material Type", "Office Electronics Accessories"],
    ["Brand Name", "beiter gray battery power", "Power Source", "Laptop Accessories"],
    ["Color", "14.5-inch red color spectrum", "Size", "Novelty Lighting"],
    ["Fixture Features", "wattage 21w type 1200 mm input end", "Wattage", "Fluorescent Tubes"],
    ["Fixture Features", "wattage 21w type 1200 mm input end", "Size", "Fluorescent Tubes"],
]

# Web UI: four text inputs feeding predict(), one text output.
demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Textbox(label=field_label, info=field_info, lines=1)
        for field_label, field_info in _input_fields
    ],
    outputs="textbox",
    title="QPAVE",
    examples=_example_rows,
)

# Start the Gradio server as soon as the script is executed.
demo.launch()