Image2Paragraph / models /image_text_transformation.py
shideqin's picture
Update models/image_text_transformation.py
72762a4
raw
history blame contribute delete
No virus
3.08 kB
from models.blip2_model import ImageCaptioning
from models.grit_model import DenseCaptioning
from models.gpt_model import ImageToText
from models.controlnet_model import TextToImage
from models.region_semantic import RegionSemantic
from utils.util import read_image_width_height, display_images_and_text, resize_long_edge
import argparse
from PIL import Image
import base64
from io import BytesIO
import os
def pil_image_to_base64(image, image_format="JPEG"):
    """Serialize a PIL image and return it as a base64-encoded ``str``.

    Args:
        image: a PIL ``Image`` (anything with ``mode``, ``convert`` and
            ``save`` compatible with Pillow's API).
        image_format: Pillow format name passed to ``Image.save``.
            Defaults to ``"JPEG"``, which matches the previous behavior.

    Returns:
        The encoded image bytes as an ASCII base64 string.

    Raises:
        Whatever ``Image.save`` raises for an unsupported format/mode
        combination other than the alpha/palette cases handled below.
    """
    # JPEG stores neither an alpha channel nor a palette; Pillow refuses to
    # write such modes as JPEG ("cannot write mode RGBA as JPEG"), which is
    # common for PNG inputs. Flatten them to RGB first.
    jpeg_incompatible_modes = ("RGBA", "RGBa", "LA", "La", "P", "PA")
    if image_format.upper() == "JPEG" and image.mode in jpeg_incompatible_modes:
        image = image.convert("RGB")

    buffered = BytesIO()
    image.save(buffered, format=image_format)
    # base64 output is pure ASCII, so decoding to str cannot fail.
    return base64.b64encode(buffered.getvalue()).decode("ascii")
class ImageTextTransformation:
    """Turn an image into descriptive text, and text back into an image.

    Bundles the captioning, dense-captioning, region-semantic, GPT and
    ControlNet models created in ``init_models`` behind two entry points:
    ``image_to_text`` and ``text_to_image``.
    """

    # Text handed to the GPT summarizer for any analysis stage that is
    # switched off in ``args``.
    _DISABLED_STAGE_TEXT = " "

    def __init__(self, args):
        """Store the configuration and load every model.

        Args:
            args: configuration namespace (e.g. from ``argparse``). The
                attributes read are the device/model options used in
                ``init_models`` and the ``image_caption`` /
                ``dense_caption`` / ``semantic_segment`` toggles used in
                ``image_to_text``.
        """
        # Load your big model here
        self.args = args
        self.init_models()
        # Reference image from the most recent ``image_to_text`` call;
        # ``text_to_image`` forwards it to the ControlNet model.
        self.ref_image = None

    def init_models(self):
        """Print progress banners and instantiate all models.

        Each model is placed on the device named by the corresponding
        ``self.args`` attribute.
        """
        print(self.args)
        print('\033[1;34m' + "Welcome to the Image2Paragraph toolbox...".center(50, '-') + '\033[0m')
        print('\033[1;33m' + "Initializing models...".center(50, '-') + '\033[0m')
        print('\033[1;31m' + "This is time-consuming, please wait...".center(50, '-') + '\033[0m')
        self.image_caption_model = ImageCaptioning(device=self.args.image_caption_device, captioner_base_model=self.args.captioner_base_model)
        self.dense_caption_model = DenseCaptioning(device=self.args.dense_caption_device)
        self.gpt_model = ImageToText()
        # NOTE: "contolnet_device" (sic) is the attribute name this code reads
        # from args; renaming it requires updating the argument definition too.
        self.controlnet_model = TextToImage(device=self.args.contolnet_device)
        # The region-semantic model reuses the caption model created above.
        self.region_semantic_model = RegionSemantic(device=self.args.semantic_segment_device, image_caption_model=self.image_caption_model, region_classify_model=self.args.region_classify_model, sam_arch=self.args.sam_arch)
        print('\033[1;32m' + "Model initialization finished!".center(50, '-') + '\033[0m')

    def image_to_text(self, img_src):
        """Describe an image and summarize the description into a paragraph.

        Also stores a resized copy of the image in ``self.ref_image`` for
        later use by ``text_to_image``.

        Args:
            img_src: image source passed to ``Image.open`` and to every
                model; presumably a file path, since it is opened more
                than once (TODO confirm with callers).

        Returns:
            Tuple ``(image_caption, dense_caption, region_semantic,
            generated_text)``. A stage disabled in ``self.args`` yields
            ``" "`` in its slot.
        """
        # The context manager releases the file handle deterministically.
        # copy() detaches the pixel data from the file before it is closed.
        with Image.open(img_src) as opened:
            # resize image to long edge 384
            self.ref_image = resize_long_edge(opened.copy(), 384)
        # Width/height are read from the source image, not from the resized
        # reference copy.
        width, height = read_image_width_height(img_src)
        print(self.args)

        # Each model is only invoked when its toggle in args is truthy.
        image_caption = (
            self.image_caption_model.image_caption(img_src)
            if self.args.image_caption
            else self._DISABLED_STAGE_TEXT
        )
        dense_caption = (
            self.dense_caption_model.image_dense_caption(img_src)
            if self.args.dense_caption
            else self._DISABLED_STAGE_TEXT
        )
        region_semantic = (
            self.region_semantic_model.region_semantic(img_src)
            if self.args.semantic_segment
            else self._DISABLED_STAGE_TEXT
        )

        generated_text = self.gpt_model.paragraph_summary_with_gpt(image_caption, dense_caption, region_semantic, width, height)
        return image_caption, dense_caption, region_semantic, generated_text

    def text_to_image(self, text):
        """Generate an image from ``text`` with the ControlNet model.

        ``self.ref_image`` is passed along as the reference image. It is
        ``None`` until ``image_to_text`` has been called; whether the
        ControlNet wrapper accepts ``None`` is not visible here.
        """
        return self.controlnet_model.text_to_image(text, self.ref_image)

    def text_to_image_retrieval(self, text):
        """Placeholder for text-to-image retrieval; currently returns None."""
        pass

    def image_to_text_retrieval(self, image):
        """Placeholder for image-to-text retrieval; currently returns None."""
        pass