coder
first commit
7245dc6
raw
history blame contribute delete
No virus
1.62 kB
from transformers import AutoProcessor, CLIPSegForImageSegmentation
from io import BytesIO
from PIL import Image
import torch
import requests
import matplotlib.pyplot as plt
class Generador():
def __init__(self, configuraciones):
self.modelo = configuraciones.get('model')
self.tokenizer = configuraciones.get('tokenizer')
def generar_prediccion(self, imagen_bytes, new_prompt):
respuestas = []
try:
# Inicializamos los procesadores y el modelo
procesador = AutoProcessor.from_pretrained(self.tokenizer)
modelo = CLIPSegForImageSegmentation.from_pretrained(self.modelo)
# Procesamos nuestra imagen y objetos
prompts = new_prompt.split(',')
inputs = procesador(text=prompts, images=[imagen_bytes] * len(prompts), padding=True, return_tensors="pt")
outputs = modelo(**inputs)
logits = outputs.logits
predicciones = outputs.logits.unsqueeze(1)
# Creamos un espacio para cada uno de los prompts
_, cajas = plt.subplots(1, len(prompts), figsize=(15, 4))
# por cada caja, agregamos una predicción
for indice, caja in enumerate(cajas.flatten()):
caja.axis('off')
_img = torch.sigmoid(predicciones[indice][0]).detach().numpy()
#caja.imshow(_img)
#caja.text(0, -15, prompts[indice])
respuestas.append(_img)
except Exception as error:
print(f"No es Chems\n{error}")
finally:
self.prediccion = respuestas