File size: 1,667 Bytes
f1b2ef2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
from wandb.integration.diffusers import autolog
from diffusers import DiffusionPipeline
import torch
from config import PROJECT_NAME
# Turn on Weights & Biases autologging for diffusers pipeline calls. The
# `init` mapping presumably supplies wandb.init() kwargs (here only the
# project name from config) — confirm against the W&B autolog docs.
autolog(init={"project": PROJECT_NAME})

class SDXLLoraInference:
    """
    Run text-to-image inference with an SDXL base pipeline plus a LoRA adapter.

    The base pipeline is loaded once in ``__init__``, moved to ``device``, and
    has the LoRA weights applied. Every ``run_inference`` call then reuses it.

    Args:
        num_inference_steps (int): Number of inference steps per generation.
        guidance_scale (float): Guidance scale passed to the pipeline call.
        model_path (str): Repo id or path of the LoRA weights to load.
        base_model (str): Repo id or path of the base diffusion pipeline.
        device (str): Device the pipeline is moved to, e.g. ``"cuda"``.
        torch_dtype (torch.dtype): Dtype the base pipeline is loaded with.
            NOTE(review): float16 is presumably only practical on GPU; pass
            ``torch.float32`` together with ``device="cpu"`` — confirm.
    """

    def __init__(
        self,
        num_inference_steps: int,
        guidance_scale: float,
        model_path: str = "VikramSingh178/sdxl-lora-finetune-product-caption",
        base_model: str = "stabilityai/stable-diffusion-xl-base-1.0",
        device: str = "cuda",
        torch_dtype: torch.dtype = torch.float16,
    ) -> None:
        self.model_path = model_path
        self.pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=torch_dtype)
        self.pipe.to(device)
        # LoRA weights are applied after the move, matching the original order.
        self.pipe.load_lora_weights(self.model_path)
        self.num_inference_steps = num_inference_steps
        self.guidance_scale = guidance_scale

    def run_inference(self, prompt, **pipe_kwargs):
        """
        Generate image(s) for ``prompt`` with the loaded SDXL-LoRA pipeline.

        Args:
            prompt: Text prompt describing the desired product photograph.
                It is passed positionally to the pipeline call.
            **pipe_kwargs: Extra keyword arguments forwarded to the pipeline
                call, e.g. ``negative_prompt``. They may also override
                ``num_inference_steps`` / ``guidance_scale`` for this call only.

        Returns:
            The ``images`` attribute of the pipeline output. With the
            pipeline's default ``output_type`` this is presumably a list of
            PIL images.
        """
        # Instance defaults come first so explicit per-call kwargs override
        # them rather than raising "got multiple values for keyword argument".
        call_kwargs = {
            "num_inference_steps": self.num_inference_steps,
            "guidance_scale": self.guidance_scale,
            **pipe_kwargs,
        }
        return self.pipe(prompt, **call_kwargs).images

# Script entry point: builds the pipeline wrapper (100 steps, guidance 2.5)
# and generates images for one hard-coded prompt. The returned images are
# discarded here; presumably the W&B autolog hook enabled near the top of the
# file captures them. NOTE(review): this runs on import — consider moving it
# under `if __name__ == "__main__":` if the module is ever imported.
inference = SDXLLoraInference(
    guidance_scale=2.5,
    num_inference_steps=100,
)
inference.run_inference(
    prompt="A stunning 4k Shot of a Balenciaga X Anime Hoodie with a person wearing it in a party",
)