osanseviero HF staff commited on
Commit
7a2c78d
1 Parent(s): d5b4d55

Fix imports

Browse files
Files changed (2) hide show
  1. pipeline.py +1 -5
  2. vqgan_jax/__init__.py +1 -0
pipeline.py CHANGED
@@ -73,14 +73,10 @@ class CustomFlaxBartForConditionalGeneration(FlaxBartForConditionalGeneration):
73
 
74
  class PreTrainedPipeline():
75
  def __init__(self, path=""):
76
- # IMPLEMENT_THIS
77
- # Preload all the elements you are going to need at inference.
78
- # For instance your model, processors, tokenizer that might be needed.
79
- # This function is only called once, so do all the heavy processing I/O here"""
80
  self.tokenizer = BartTokenizer.from_pretrained(path)
81
  self.model = CustomFlaxBartForConditionalGeneration.from_pretrained(path)
82
 
83
- self.vqgan = VQModel.from_pretrained("flax-community/vqgan_f16_16384", revision="90cc46addd2dd8f5be21586a9a23e1b95aa506a9")
84
 
85
 
86
  def __call__(self, inputs: str):
 
73
 
74
  class PreTrainedPipeline():
75
  def __init__(self, path=""):
76
+ self.vqgan = VQModel.from_pretrained("flax-community/vqgan_f16_16384", revision="90cc46addd2dd8f5be21586a9a23e1b95aa506a9")
 
 
 
77
  self.tokenizer = BartTokenizer.from_pretrained(path)
78
  self.model = CustomFlaxBartForConditionalGeneration.from_pretrained(path)
79
 
 
80
 
81
 
82
  def __call__(self, inputs: str):
vqgan_jax/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from . import *