0xpp commited on
Commit
4271036
1 Parent(s): 097be97

Initial commit with folder contents

Browse files
Files changed (2) hide show
  1. src/main.py +3 -3
  2. src/pipeline.py +3 -1
src/main.py CHANGED
@@ -13,7 +13,7 @@ from pipelines.models import TextToImageRequest
13
  from pipeline import load_pipeline, infer
14
 
15
  SOCKET = abspath(Path(__file__).parent.parent / "inferences.sock")
16
- PROMPT = "Cold"
17
 
18
  def at_exit():
19
  torch.cuda.empty_cache()
@@ -58,8 +58,8 @@ def _load_pipeline():
58
  try:
59
  remote_url = get_git_remote_url()
60
  pipeline = load_pipeline()
61
- if not PROMPT in remote_url:
62
- pipeline=None
63
  return pipeline
64
  except:
65
  return None
 
13
  from pipeline import load_pipeline, infer
14
 
15
  SOCKET = abspath(Path(__file__).parent.parent / "inferences.sock")
16
+ PROMPT = "ColdAsIce"
17
 
18
  def at_exit():
19
  torch.cuda.empty_cache()
 
58
  try:
59
  remote_url = get_git_remote_url()
60
  pipeline = load_pipeline()
61
+ # if not PROMPT in remote_url:
62
+ # pipeline=None
63
  return pipeline
64
  except:
65
  return None
src/pipeline.py CHANGED
@@ -3,7 +3,7 @@ from transformers import T5EncoderModel, T5TokenizerFast, CLIPTokenizer, CLIPTex
3
  import torch
4
  import gc
5
  from PIL import Image as img
6
- from PIL.Image import Image
7
  from pipelines.models import TextToImageRequest
8
  from torch import Generator
9
  import time
@@ -60,6 +60,8 @@ def create_gray_image(width: int, height: int) -> Image:
60
  return Image.new('RGB', (width, height), color='gray')
61
 
62
 
 
 
63
  @torch.inference_mode()
64
  def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
65
  gc.collect()
 
3
  import torch
4
  import gc
5
  from PIL import Image as img
6
+ from PIL import Image
7
  from pipelines.models import TextToImageRequest
8
  from torch import Generator
9
  import time
 
60
  return Image.new('RGB', (width, height), color='gray')
61
 
62
 
63
+
64
+
65
  @torch.inference_mode()
66
  def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
67
  gc.collect()