coder
first commit
7245dc6
from transformers import AutoProcessor, CLIPSegForImageSegmentation
from io import BytesIO
from PIL import Image
import torch
import requests
import matplotlib.pyplot as plt
class Generador():
def __init__(self, configuraciones):
self.modelo = configuraciones.get('model')
self.tokenizer = configuraciones.get('tokenizer')
def generar_prediccion(self, imagen_bytes, new_prompt):
respuestas = []
try:
# Inicializamos los procesadores y el modelo
procesador = AutoProcessor.from_pretrained(self.tokenizer)
modelo = CLIPSegForImageSegmentation.from_pretrained(self.modelo)
# Procesamos nuestra imagen y objetos
prompts = new_prompt.split(',')
inputs = procesador(text=prompts, images=[imagen_bytes] * len(prompts), padding=True, return_tensors="pt")
outputs = modelo(**inputs)
logits = outputs.logits
predicciones = outputs.logits.unsqueeze(1)
# Creamos un espacio para cada uno de los prompts
_, cajas = plt.subplots(1, len(prompts), figsize=(15, 4))
# por cada caja, agregamos una predicción
for indice, caja in enumerate(cajas.flatten()):
caja.axis('off')
_img = torch.sigmoid(predicciones[indice][0]).detach().numpy()
#caja.imshow(_img)
#caja.text(0, -15, prompts[indice])
respuestas.append(_img)
except Exception as error:
print(f"No es Chems\n{error}")
finally:
self.prediccion = respuestas