John6666's picture
Upload 43 files
cd39c08 verified
raw
history blame
2.66 kB
from transformers import AutoProcessor, AutoModelForCausalLM
import spaces
from PIL import Image
import torch
import re
import numpy as np
import os
import subprocess

# Device used for inference: CUDA when torch reports it as available, else CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Best-effort install of flash-attn at start-up (the return code is deliberately
# not checked).  FLASH_ATTENTION_SKIP_CUDA_BUILD presumably tells the package's
# build to skip compiling its CUDA extension -- confirm against flash-attn docs.
# The flag is layered on top of the inherited environment: passing a dict with
# only this key would strip PATH/HOME/etc. from the shell that has to find pip.
subprocess.run(
    'pip install flash-attn --no-build-isolation',
    env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
    shell=True,
)

# The model is loaded once, put in eval mode and parked on the CPU;
# process_image() moves it to `device` only for the duration of a request.
fl_model = AutoModelForCausalLM.from_pretrained(
    'thwri/CogFlorence-2.1-Large', trust_remote_code=True
).to("cpu").eval()
fl_processor = AutoProcessor.from_pretrained(
    'thwri/CogFlorence-2.1-Large', trust_remote_code=True
)
def modify_caption(caption: str) -> str:
    """Clean up a generated caption so it reads as a single line of text.

    The following steps are applied in order:

    1. Remove boilerplate lead-ins such as "The image shows " (matched
       case-insensitively, anywhere in the text).
    2. Join all lines into one, separated by single spaces.  Line breaks are
       replaced by a space rather than deleted so that words on adjacent
       lines are not glued together.
    3. Insert a space after ``.``, ``,``, ``?`` or ``!`` whenever the next
       character is not whitespace.

    Args:
        caption: Raw caption text.

    Returns:
        The cleaned caption with no leading or trailing whitespace.
    """
    # Order matters: the longer, more specific phrases must be removed before
    # the generic "the image " fallback gets a chance to match their prefix.
    special_patterns = [
        (r'the image is ', ''),
        (r'the image captures ', ''),
        (r'the image showcases ', ''),
        (r'the image shows ', ''),
        (r'the image ', ''),
    ]
    for pattern, replacement in special_patterns:
        caption = re.sub(pattern, replacement, caption, flags=re.IGNORECASE)

    # Collapse every line break (and blank line) into a single space.
    stripped_lines = (line.strip() for line in caption.splitlines())
    caption = ' '.join(line for line in stripped_lines if line)

    # Zero-width match: sits between a punctuation mark and a non-space char.
    caption = re.sub(r'(?<=[.,?!])(?=[^\s])', ' ', caption)
    return caption.strip()
# NOTE(review): spaces.GPU presumably requests a GPU slot for up to 30 seconds
# per call on Hugging Face ZeroGPU hardware -- confirm against the `spaces` docs.
@spaces.GPU(duration=30)
def process_image(image):
    """Generate a detailed caption for an image with the CogFlorence model.

    Args:
        image: A PIL image, a NumPy array (converted with ``Image.fromarray``)
            or a string path (opened with ``Image.open``).

    Returns:
        The caption produced for the ``<MORE_DETAILED_CAPTION>`` task after
        clean-up by ``modify_caption``.
    """
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    elif isinstance(image, str):
        image = Image.open(image)
    if image.mode != "RGB":
        image = image.convert("RGB")

    prompt = "<MORE_DETAILED_CAPTION>"
    try:
        fl_model.to(device)
        inputs = fl_processor(text=prompt, images=image, return_tensors="pt").to(device)
        generated_ids = fl_model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=1024,
            num_beams=3,
            do_sample=True,
        )
    finally:
        # Always park the model back on the CPU, even if preprocessing or
        # generation raised, so a failed request does not leave it on the GPU.
        fl_model.to("cpu")

    # Special tokens are kept because post_process_generation parses the raw
    # text for the given task itself.
    generated_text = fl_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    parsed_answer = fl_processor.post_process_generation(
        generated_text, task=prompt, image_size=(image.width, image.height)
    )
    return modify_caption(parsed_answer[prompt])
def predict_tags_fl2_cog(image: Image.Image, input_tags: str, algo: list[str]):
    """Append caption-derived tags for ``image`` to a comma-separated tag string.

    Args:
        image: Image handed to ``process_image`` for captioning.
        input_tags: Existing tags, separated by commas.
        algo: Names of the enabled algorithms; captioning only runs when
            ``"Use CogFlorence-2.1-Large"`` is among them.

    Returns:
        ``input_tags`` unchanged when the algorithm is not enabled.  Otherwise
        a ``", "``-joined string of the existing tags followed by the
        comma-separated pieces of the generated caption, with surrounding
        whitespace stripped, empty entries dropped and duplicates removed
        (first occurrence wins).
    """
    if "Use CogFlorence-2.1-Large" not in algo:
        return input_tags

    def split_tags(text):
        # Split on commas, trim each piece and discard the empty ones.
        pieces = (piece.strip() for piece in text.split(","))
        return [piece for piece in pieces if piece]

    # `or ""` tolerates a missing tag string instead of failing on None.split.
    tags = split_tags(input_tags or "") + split_tags(process_image(image))
    # dict.fromkeys keeps insertion order, giving a linear-time ordered de-dup.
    return ", ".join(dict.fromkeys(tags))