Avijit Ghosh commited on
Commit
ab041ea
1 Parent(s): 8a06c57

add gpu wrapper

Browse files
Files changed (1) hide show
  1. app.py +12 -8
app.py CHANGED
@@ -10,6 +10,7 @@ import matplotlib.pyplot as plt
10
  from matplotlib.colors import hex2color
11
  import stone
12
  import os
 
13
 
14
  # Define model initialization functions
15
  def load_model(model_name):
@@ -47,27 +48,29 @@ def load_model(model_name):
47
  raise ValueError("Unknown model name")
48
  return pipeline
49
 
50
- # choices=[
51
- # "stabilityai/sdxl-turbo",
52
- # "runwayml/stable-diffusion-v1-5",
53
- # "ByteDance/SDXL-Lightning",
54
- # "segmind/SSD-1B"
55
- # ]
56
 
57
- # for model_name in choices:
58
- # load_model(model_name)
59
 
60
  # Initialize the default model
61
  default_model = "stabilityai/sdxl-turbo"
62
 
63
  pipeline_text2image = load_model(default_model)
64
 
 
65
  def getimgen(prompt):
66
  return pipeline_text2image(prompt=prompt, guidance_scale=0.0, num_inference_steps=2).images[0]
67
 
68
  blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
69
  blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large", torch_dtype=torch.float16).to("cuda")
70
 
 
71
  def blip_caption_image(image, prefix):
72
  inputs = blip_processor(image, prefix, return_tensors="pt").to("cuda", torch.float16)
73
  out = blip_model.generate(**inputs)
@@ -106,6 +109,7 @@ def skintoneplot(hex_codes):
106
  ax.add_patch(plt.Rectangle((0, 0), 1, 1, color=sorted_hex_codes[i]))
107
  return fig
108
 
 
109
  def generate_images_plots(prompt, model_name):
110
  global pipeline_text2image
111
  pipeline_text2image = load_model(model_name)
 
10
  from matplotlib.colors import hex2color
11
  import stone
12
  import os
13
+ import spaces
14
 
15
  # Define model initialization functions
16
  def load_model(model_name):
 
48
  raise ValueError("Unknown model name")
49
  return pipeline
50
 
51
+ choices=[
52
+ "stabilityai/sdxl-turbo",
53
+ "runwayml/stable-diffusion-v1-5",
54
+ "ByteDance/SDXL-Lightning",
55
+ "segmind/SSD-1B"
56
+ ]
57
 
58
+ for model_name in choices:
59
+ load_model(model_name)
60
 
61
  # Initialize the default model
62
  default_model = "stabilityai/sdxl-turbo"
63
 
64
  pipeline_text2image = load_model(default_model)
65
 
66
+ @spaces.GPU
67
  def getimgen(prompt):
68
  return pipeline_text2image(prompt=prompt, guidance_scale=0.0, num_inference_steps=2).images[0]
69
 
70
  blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
71
  blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large", torch_dtype=torch.float16).to("cuda")
72
 
73
+ @spaces.GPU
74
  def blip_caption_image(image, prefix):
75
  inputs = blip_processor(image, prefix, return_tensors="pt").to("cuda", torch.float16)
76
  out = blip_model.generate(**inputs)
 
109
  ax.add_patch(plt.Rectangle((0, 0), 1, 1, color=sorted_hex_codes[i]))
110
  return fig
111
 
112
+ @spaces.GPU
113
  def generate_images_plots(prompt, model_name):
114
  global pipeline_text2image
115
  pipeline_text2image = load_model(model_name)