Upload multit2i.py
Browse files- multit2i.py +5 -5
multit2i.py
CHANGED
@@ -104,7 +104,7 @@ def save_gallery_images(images, progress=gr.Progress(track_tqdm=True)):
|
|
104 |
def load_model(model_name: str):
|
105 |
global loaded_models
|
106 |
global model_info_dict
|
107 |
-
if model_name in loaded_models.keys(): return model_name
|
108 |
try:
|
109 |
loaded_models[model_name] = gr.load(f'models/{model_name}')
|
110 |
print(f"Loaded: {model_name}")
|
@@ -112,13 +112,13 @@ def load_model(model_name: str):
|
|
112 |
if model_name in loaded_models.keys(): del loaded_models[model_name]
|
113 |
print(f"Failed to load: {model_name}")
|
114 |
print(e)
|
115 |
-
return
|
116 |
try:
|
117 |
model_info_dict[model_name] = get_t2i_model_info_dict(model_name)
|
118 |
except Exception as e:
|
119 |
if model_name in model_info_dict.keys(): del model_info_dict[model_name]
|
120 |
print(e)
|
121 |
-
return model_name
|
122 |
|
123 |
|
124 |
async def async_load_models(models: list, limit: int=5):
|
@@ -163,12 +163,12 @@ def infer(prompt: str, model_name: str, recom_prompt: bool, progress=gr.Progress
|
|
163 |
caption = model_name.split("/")[-1]
|
164 |
try:
|
165 |
model = load_model(model_name)
|
166 |
-
if not model: return (
|
167 |
image_path = model(prompt + rprompt + seed)
|
168 |
image = Image.open(image_path).convert('RGB')
|
169 |
except Exception as e:
|
170 |
print(e)
|
171 |
-
return (
|
172 |
return (image, caption)
|
173 |
|
174 |
|
|
|
104 |
def load_model(model_name: str):
|
105 |
global loaded_models
|
106 |
global model_info_dict
|
107 |
+
if model_name in loaded_models.keys(): return loaded_models[model_name]
|
108 |
try:
|
109 |
loaded_models[model_name] = gr.load(f'models/{model_name}')
|
110 |
print(f"Loaded: {model_name}")
|
|
|
112 |
if model_name in loaded_models.keys(): del loaded_models[model_name]
|
113 |
print(f"Failed to load: {model_name}")
|
114 |
print(e)
|
115 |
+
return None
|
116 |
try:
|
117 |
model_info_dict[model_name] = get_t2i_model_info_dict(model_name)
|
118 |
except Exception as e:
|
119 |
if model_name in model_info_dict.keys(): del model_info_dict[model_name]
|
120 |
print(e)
|
121 |
+
return loaded_models[model_name]
|
122 |
|
123 |
|
124 |
async def async_load_models(models: list, limit: int=5):
|
|
|
163 |
caption = model_name.split("/")[-1]
|
164 |
try:
|
165 |
model = load_model(model_name)
|
166 |
+
if not model: return (Image.Image(), None)
|
167 |
image_path = model(prompt + rprompt + seed)
|
168 |
image = Image.open(image_path).convert('RGB')
|
169 |
except Exception as e:
|
170 |
print(e)
|
171 |
+
return (Image.Image(), None)
|
172 |
return (image, caption)
|
173 |
|
174 |
|