|
from transformers import AutoProcessor, CLIPSegForImageSegmentation |
|
from io import BytesIO |
|
from PIL import Image |
|
import torch |
|
import requests |
|
import matplotlib.pyplot as plt |
|
|
|
|
|
class Generador:
    """Text-prompted image segmentation with a CLIPSeg checkpoint.

    The processor and model are loaded lazily on the first prediction and
    reused on later calls for as long as the configured checkpoint names
    (``self.tokenizer`` / ``self.modelo``) stay the same.
    """

    def __init__(self, configuraciones):
        """Store the checkpoint identifiers to load from.

        Args:
            configuraciones: mapping read with ``.get``. ``'model'`` is the
                name/path later passed to
                ``CLIPSegForImageSegmentation.from_pretrained`` and
                ``'tokenizer'`` the one passed to
                ``AutoProcessor.from_pretrained``. Missing keys become ``None``.
        """
        self.modelo = configuraciones.get('model')
        self.tokenizer = configuraciones.get('tokenizer')

    def _cargar(self):
        """Return ``(procesador, modelo)``, loading them only when needed.

        Calling ``from_pretrained`` on every prediction re-reads the full
        weights each time, so the pair is cached on the instance and only
        reloaded if ``self.tokenizer`` or ``self.modelo`` have changed.
        """
        clave = (self.tokenizer, self.modelo)
        cache = getattr(self, '_cargados', None)
        if cache is None or cache[0] != clave:
            procesador = AutoProcessor.from_pretrained(self.tokenizer)
            modelo = CLIPSegForImageSegmentation.from_pretrained(self.modelo)
            cache = (clave, procesador, modelo)
            self._cargados = cache
        return cache[1], cache[2]

    @staticmethod
    def _separar_prompts(new_prompt):
        """Split a comma-separated prompt string, trimming surrounding spaces.

        The number and order of entries is preserved (``"a,"`` still yields
        two prompts) so callers can align the masks with their own prompts.
        """
        return [prompt.strip() for prompt in new_prompt.split(',')]

    @staticmethod
    def _abrir_imagen(imagen):
        """Decode raw encoded image bytes with PIL; pass anything else through."""
        if isinstance(imagen, (bytes, bytearray)):
            return Image.open(BytesIO(imagen)).convert('RGB')
        return imagen

    @staticmethod
    def _logits_a_mascaras(logits):
        """Turn CLIPSeg logits into one sigmoid probability map per prompt.

        ``logits`` is expected as ``(num_prompts, H, W)``. A 2-D tensor is
        treated as a single prompt, because CLIPSeg's decoder may squeeze the
        batch dimension away when only one prompt is given.
        """
        if logits.dim() == 2:
            logits = logits.unsqueeze(0)
        probabilidades = torch.sigmoid(logits).detach().cpu().numpy()
        return list(probabilidades)

    def generar_prediccion(self, imagen_bytes, new_prompt):
        """Segment one image once per comma-separated text prompt.

        Args:
            imagen_bytes: encoded image bytes (decoded with PIL), or an image
                object the processor already accepts (e.g. a PIL image).
            new_prompt: comma-separated text prompts, e.g. ``"cat, dog"``.

        Returns:
            The list also stored in ``self.prediccion``: one 2-D NumPy array
            of per-pixel probabilities (sigmoid of the logits) per prompt, in
            prompt order. On any error the message is printed and an empty
            list is stored and returned instead of raising.
        """
        respuestas = []
        try:
            procesador, modelo = self._cargar()
            prompts = self._separar_prompts(new_prompt)
            imagen = self._abrir_imagen(imagen_bytes)
            inputs = procesador(
                text=prompts,
                images=[imagen] * len(prompts),
                padding=True,
                return_tensors="pt",
            )
            # Inference only: skip building the autograd graph.
            with torch.no_grad():
                outputs = modelo(**inputs)
            respuestas = self._logits_a_mascaras(outputs.logits)
        except Exception as error:
            # Best-effort contract kept from the original: report the problem
            # and fall through with an empty result instead of raising.
            print(f"No es Chems\n{error}")
        finally:
            self.prediccion = respuestas
        return self.prediccion
|
|
|
|
|
|
|
|