LanHarmony commited on
Commit
1dc205a
1 Parent(s): 1e59525

update text2image

Browse files
Files changed (1) hide show
  1. visual_foundation_models.py +11 -10
visual_foundation_models.py CHANGED
@@ -137,26 +137,27 @@ class InstructPix2Pix:
137
 
138
  class Text2Image:
139
  def __init__(self, device):
140
- print("Initializing Text2Image to %s" % device)
141
  self.device = device
142
- self.pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5",torch_dtype=torch.float16)
143
- self.text_refine_tokenizer = AutoTokenizer.from_pretrained("Gustavosta/MagicPrompt-Stable-Diffusion")
144
- self.text_refine_model = AutoModelForCausalLM.from_pretrained("Gustavosta/MagicPrompt-Stable-Diffusion")
145
- self.text_refine_gpt2_pipe = pipeline("text-generation", model=self.text_refine_model,
146
- tokenizer=self.text_refine_tokenizer, device=self.device)
147
  self.pipe.to(device)
 
 
 
148
 
149
  @prompts(name="Generate Image From User Input Text",
150
  description="useful when you want to generate an image from a user input text and save it to a file. "
151
  "like: generate an image of an object or something, or generate an image that includes some objects. "
152
  "The input to this tool should be a string, representing the text used to generate image. ")
153
  def inference(self, text):
154
- image_filename = os.path.join('image', str(uuid.uuid4())[0:8] + ".png")
155
- refined_text = self.text_refine_gpt2_pipe(text)[0]["generated_text"]
156
- image = self.pipe(refined_text).images[0]
157
  image.save(image_filename)
158
  print(
159
- f"\nProcessed Text2Image, Input Text: {text}, Refined Text: {refined_text}, Output Image: {image_filename}")
160
  return image_filename
161
 
162
 
 
137
 
138
  class Text2Image:
139
  def __init__(self, device):
140
+ print(f"Initializing Text2Image to {device}")
141
  self.device = device
142
+ self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
143
+ self.pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5",
144
+ torch_dtype=self.torch_dtype)
 
 
145
  self.pipe.to(device)
146
+ self.a_prompt = 'best quality, extremely detailed'
147
+ self.n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, ' \
148
+ 'fewer digits, cropped, worst quality, low quality'
149
 
150
  @prompts(name="Generate Image From User Input Text",
151
  description="useful when you want to generate an image from a user input text and save it to a file. "
152
  "like: generate an image of an object or something, or generate an image that includes some objects. "
153
  "The input to this tool should be a string, representing the text used to generate image. ")
154
  def inference(self, text):
155
+ image_filename = os.path.join('image', f"{str(uuid.uuid4())[:8]}.png")
156
+ prompt = text + ', ' + self.a_prompt
157
+ image = self.pipe(prompt, negative_prompt=self.n_prompt).images[0]
158
  image.save(image_filename)
159
  print(
160
+ f"\nProcessed Text2Image, Input Text: {text}, Output Image: {image_filename}")
161
  return image_filename
162
 
163