prateekbh's picture
Update app.py
5da11b3 verified
import gradio as gr
import torch
import numpy as np
import torch.nn.functional as F
import PIL
import random
from threading import Thread
from transformers import AutoModel, AutoProcessor
from transformers import StoppingCriteria, TextIteratorStreamer, StoppingCriteriaList
from torchvision.transforms.functional import normalize
from huggingface_hub import hf_hub_download, InferenceClient
from briarmbg import BriaRMBG
from PIL import Image
from typing import Tuple
# --- Model loading (executed once, at import time) -------------------------
# Background-removal network; its checkpoint is downloaded from the
# "briaai/RMBG-1.4" repository on the Hugging Face Hub.
# NOTE(review): torch.load unpickles the checkpoint file; consider
# weights_only=True if the deployed torch version supports it.
net = BriaRMBG()
model_path = hf_hub_download("briaai/RMBG-1.4", 'model.pth')
if not torch.cuda.is_available():
    # CPU-only host: remap any CUDA tensors stored in the checkpoint.
    net.load_state_dict(torch.load(model_path, map_location="cpu"))
else:
    net.load_state_dict(torch.load(model_path))
    net = net.cuda()
net.eval()

# Vision-language model + processor used to describe the uploaded image.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = AutoModel.from_pretrained("unum-cloud/uform-gen2-qwen-500m", trust_remote_code=True).to(device)
processor = AutoProcessor.from_pretrained("unum-cloud/uform-gen2-qwen-500m", trust_remote_code=True)
class StopOnTokens(StoppingCriteria):
    """Stopping criterion that ends generation once a stop token is emitted.

    Only the most recent token of the first sequence in the batch is checked.
    """

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # 151645 is the stop id this app uses for uform-gen2-qwen
        # (presumably Qwen's <|im_end|> token; verify against the tokenizer).
        latest = input_ids[0][-1]
        return any(latest == token_id for token_id in (151645,))
def format_prompt(message, history):
    """Render a conversation in Gemma's chat-turn format.

    Args:
        message: The new user message, emitted as the final user turn.
        history: Iterable of ``(user_prompt, bot_response)`` pairs from earlier
            turns, or a falsy value when there is no prior conversation.

    Returns:
        The prompt string. It ends with an open ``model`` turn so the LLM
        answers as the model.
    """
    turns = []
    for user_prompt, bot_response in history or []:
        turns.append(f"<start_of_turn>user\n{user_prompt}<end_of_turn>\n")
        # Each completed model turn must be closed, otherwise the previous
        # reply runs straight into the next user turn.
        turns.append(f"<start_of_turn>model\n{bot_response}<end_of_turn>\n")
    turns.append(f"<start_of_turn>user\n{message}<end_of_turn>\n")
    # Gemma's assistant role name is exactly "model".
    turns.append("<start_of_turn>model\n")
    return "".join(turns)
def getProductTitle(history, context, image):
    """Describe the product image, then ask the LLM for product title options.

    Args:
        history: Current chatbot history (list of message pairs), possibly empty.
        context: Short free-text description typed by the user.
        image: The uploaded PIL image, or ``None`` when nothing was uploaded.

    Yields:
        The updated chat history returned by ``interactWithModel``.

    Raises:
        gr.Error: If no image has been uploaded.
    """
    if image is None:
        # Surface a readable message in the UI instead of a processor error.
        raise gr.Error("Please upload a product image first.")
    product_description = getImageDescription(image)
    context = (context or "").strip()
    # Note the space after the article: "a" + context would glue the words.
    intro = f"We have a product which is a {context}." if context else "We have a product."
    prompt = (
        intro
        + " Product description is as follows: "
        + product_description
        + ". Please write product title options for it."
    )
    yield interactWithModel(history, prompt)
def getProductDescription(history):
    """Ask the LLM for an SEO-friendly description within the current chat.

    Wired to run after ``getProductTitle`` so the conversation already holds
    the product details.

    Yields:
        The updated chat history returned by ``interactWithModel``.
    """
    seo_request = "Please also write an SEO friendly description for it describing its value to its users."
    updated_history = interactWithModel(history, seo_request)
    yield updated_history
def interactWithModel(history, prompt):
    """Send ``prompt`` plus the chat so far to Gemma-7B-it and record the reply.

    Args:
        history: List of ``(user, bot)`` pairs shown in the chatbot, or a falsy
            value for a new conversation. A non-empty list is appended to in
            place; otherwise a new list is created.
        prompt: The user message for this turn.

    Returns:
        The history with ``(prompt, reply)`` appended.
    """
    system_prompt = "You're a helpful e-commerce marketing assistant working on art products."
    client = InferenceClient("google/gemma-7b-it")
    if not history:
        history = []
    generate_kwargs = dict(
        temperature=0.67,
        max_new_tokens=1024,
        top_p=0.9,
        repetition_penalty=1,
        do_sample=True,
        # Fresh random seed per request so identical prompts can vary.
        seed=random.randint(1, 1111111111111111),
    )
    # The system prompt is prefixed to the current message only; the history
    # stores the bare prompt, so it is not repeated in later turns.
    formatted_prompt = format_prompt(f"{system_prompt}, {prompt}", history)
    stream = client.text_generation(
        formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False
    )
    # Drop special tokens (e.g. the end-of-sequence marker) so they do not
    # leak into the chat transcript.
    output = "".join(
        response.token.text for response in stream if not response.token.special
    )
    history.append((prompt, output))
    return history
@torch.no_grad()
def getImageDescription(image):
    """Generate an e-commerce product description for ``image`` with the VLM.

    Args:
        image: Image passed to ``processor.feature_extractor`` (the Gradio
            component supplies a PIL image).

    Returns:
        The generated text, with the prompt and special tokens skipped.
    """
    message = " <image>Generate an ecommerce product description for the image"
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": message},
    ]
    input_ids = processor.tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
    )
    pixel_values = processor.feature_extractor(image).unsqueeze(0)
    # Mask length = text tokens + num_image_latents - 1; presumably the single
    # <image> token is replaced by num_image_latents image embeddings.
    attention_mask = torch.ones(1, input_ids.shape[1] + processor.num_image_latents - 1)
    model_inputs = {
        "input_ids": input_ids,
        "images": pixel_values,
        "attention_mask": attention_mask,
    }
    model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
    streamer = TextIteratorStreamer(
        processor.tokenizer, timeout=30., skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=1024,
        stopping_criteria=StoppingCriteriaList([StopOnTokens()]),
    )
    # generate() runs in a worker thread and feeds decoded text to the
    # streamer; iterating the streamer blocks until generation ends (subject
    # to the streamer's 30 s timeout).
    # NOTE: grad mode is thread-local, so the no_grad decorator above does not
    # apply inside the worker; transformers' generate() disables grad itself.
    worker = Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()
    response = "".join(streamer)
    worker.join()
    return response
def resize_image(image):
    """Return an RGB copy of ``image`` resized to the fixed 1024x1024 input size.

    The aspect ratio is not preserved; bilinear resampling is used.
    """
    return image.convert('RGB').resize((1024, 1024), Image.BILINEAR)
def process(image):
    """Remove the background from ``image`` using the RMBG network.

    Args:
        image: PIL image from the Gradio upload component, or ``None``.

    Returns:
        An RGBA PIL image the size of ``image`` whose alpha channel is the
        predicted mask (transparent background), or ``None`` when no image
        was supplied.
    """
    if image is None:
        # Nothing uploaded: clear the output component instead of crashing.
        return None
    orig_image = image
    w, h = orig_image.size

    # Prepare input: 1024x1024 RGB -> (1, 3, H, W) float in [0, 1], then
    # normalize(mean=0.5, std=1.0) shifts it to [-0.5, 0.5].
    im_np = np.array(resize_image(orig_image))
    im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1)
    im_tensor = torch.unsqueeze(im_tensor, 0)
    im_tensor = torch.divide(im_tensor, 255.0)
    im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
    if torch.cuda.is_available():
        im_tensor = im_tensor.cuda()

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        result = net(im_tensor)

    # Post-process: result[0][0] must be 4-D for bilinear interpolation
    # (presumably (1, 1, H, W)); upsample to the original size, drop the
    # batch dim, then min-max scale to [0, 1].
    result = torch.squeeze(F.interpolate(result[0][0], size=(h, w), mode='bilinear'), 0)
    ma = torch.max(result)
    mi = torch.min(result)
    # Clamp the range so a constant mask (ma == mi) yields zeros, not NaN.
    result = (result - mi) / torch.clamp(ma - mi, min=1e-8)

    # Mask -> 8-bit greyscale PIL image, used as alpha when pasting the
    # original onto a fully transparent canvas.
    im_array = (result * 255).cpu().numpy().astype(np.uint8)
    pil_im = Image.fromarray(np.squeeze(im_array))
    new_im = Image.new("RGBA", pil_im.size, (0, 0, 0, 0))
    new_im.paste(orig_image, mask=pil_im)
    return new_im
# --- Gradio UI --------------------------------------------------------------
title = """<h1 style="text-align: center;">Product description generator</h1>"""
css = """
div#col-container {
margin: 0 auto;
max-width: 840px;
}
"""

with gr.Blocks(css=css) as demo:
    gr.HTML(title)
    with gr.Row():
        # Left column: product image, background-removed preview, and context.
        with gr.Column(elem_id="col-container"):
            image = gr.Image(type="pil")
            output = gr.Image(type="pil", interactive=False, label="Without background")
            context = gr.Textbox(label="Small description")
            submit = gr.Button(value="Upload", variant="primary")
        # Right column: the LLM conversation and a follow-up message box.
        with gr.Column():
            chat = gr.Chatbot(show_label=False)
            user_input = gr.Textbox()
            send = gr.Button(value="Send")

    # "Upload": generate title options, then chain the SEO description request.
    submit.click(fn=getProductTitle, inputs=[chat, context, image], outputs=[chat]).then(
        fn=getProductDescription, inputs=[chat], outputs=[chat]
    )
    # A second listener on the same button produces the background-free preview.
    submit.click(fn=process, inputs=[image], outputs=[output])
    # "Send": free-form follow-up in the same conversation.
    send.click(fn=interactWithModel, inputs=[chat, user_input], outputs=[chat])

demo.launch(share=True)