| | |
| | |
| |
|
| | from cog import BasePredictor, Input, Path |
| | import os |
| | from factories import UNet_conditional |
| | from wrapper import DiffusionManager, Schedule |
| | import torch |
| | import re |
| | from bert_vectorize import vectorize_text_with_bert |
| | from logger import save_grid_with_label |
| | import torchvision |
| | import time |
| |
|
| |
|
class Predictor(BasePredictor):
    """Cog predictor that generates image grids from text prompts using a
    BERT-conditioned diffusion UNet running on CPU."""

    def setup(self) -> None:
        """Load the model into memory to make running multiple predictions efficient"""
        self.device = "cpu"
        self.model_dir = "runs/run_3_jxa"

        # Output directory for generated sample grids (used by predict()).
        os.makedirs(os.path.join(self.model_dir, "inferred"), exist_ok=True)

        # num_classes=768 matches the BERT sentence-embedding width used
        # as the conditioning vector — TODO confirm against bert_vectorize.
        self.net = UNet_conditional(num_classes=768, device=self.device)
        self.net.to(self.device)
        # NOTE(review): weights_only=False unpickles arbitrary objects from the
        # checkpoint; acceptable only because the file ships with the model —
        # never point this at an untrusted path.
        self.net.load_state_dict(
            torch.load(
                os.path.join(self.model_dir, "ckpt/latest_cpu.pt"),
                weights_only=False,
            )
        )
        # Inference only: freeze dropout / batch-norm statistics.
        self.net.eval()

        self.wrapper = DiffusionManager(self.net, device=self.device, noise_steps=1000)
        self.wrapper.set_schedule(Schedule.LINEAR)

    def predict(
        self,
        prompt: str = Input(description="Text prompt"),
        amt: int = Input(description="Amt", default=8),
    ) -> Path:
        """Run a single prediction on the model.

        Generates `amt` samples conditioned on `prompt`, tiles them into a
        grid, saves the grid as a PNG, and returns its Path — the declared
        `-> Path` contract Cog requires for file outputs. (Previously this
        returned a raw numpy array, which Cog cannot serialize as a Path.)
        """
        # Shape (1, 768): one conditioning vector for the whole batch.
        vprompt = vectorize_text_with_bert(prompt).unsqueeze(0)

        # 64 is the sample resolution expected by the trained UNet —
        # TODO confirm against DiffusionManager.sample.
        generated = self.wrapper.sample(64, vprompt, amt=amt).detach().cpu()

        # Persist the tiled grid and hand Cog a file path, not a tensor.
        out_path = os.path.join(
            self.model_dir, "inferred", f"{int(time.time())}.png"
        )
        grid = torchvision.utils.make_grid(generated)
        torchvision.utils.save_image(grid, out_path)
        return Path(out_path)