davidaf3 commited on
Commit
6cc1e81
1 Parent(s): 0832cb7

Fixed pipeline errors

Browse files
Files changed (1) hide show
  1. pipeline.py +4 -3
pipeline.py CHANGED
@@ -11,12 +11,13 @@ import numpy as np
11
 
12
  class PreTrainedPipeline():
13
  def __init__(self, path=""):
14
- crop_size = (224, 224)
15
  embed_dim = 256
16
  num_layers = 3
17
  seq_length = 20
18
  hidden_dim = 1024
19
  num_heads = 8
 
 
20
  self.nutr_names = ('energy', 'fat', 'protein', 'carbs')
21
  with open(os.path.join(path, "ingredients_metadata.json"), encoding='UTF-8') as f:
22
  self.ingredients = json.load(f)
@@ -25,7 +26,7 @@ class PreTrainedPipeline():
25
  self.seq_length = seq_length
26
 
27
  self.tfing = TFIng(
28
- crop_size,
29
  embed_dim,
30
  num_layers,
31
  seq_length,
@@ -38,7 +39,7 @@ class PreTrainedPipeline():
38
  self.tfing.load_weights(os.path.join(path, 'tfing.h5'))
39
 
40
  self.tfport = TFPort(
41
- crop_size,
42
  embed_dim,
43
  num_layers,
44
  num_layers,
 
11
 
12
  class PreTrainedPipeline():
13
  def __init__(self, path=""):
 
14
  embed_dim = 256
15
  num_layers = 3
16
  seq_length = 20
17
  hidden_dim = 1024
18
  num_heads = 8
19
+ self.crop_size = (224, 224)
20
+ self.img_size = 256
21
  self.nutr_names = ('energy', 'fat', 'protein', 'carbs')
22
  with open(os.path.join(path, "ingredients_metadata.json"), encoding='UTF-8') as f:
23
  self.ingredients = json.load(f)
 
26
  self.seq_length = seq_length
27
 
28
  self.tfing = TFIng(
29
+ self.crop_size,
30
  embed_dim,
31
  num_layers,
32
  seq_length,
 
39
  self.tfing.load_weights(os.path.join(path, 'tfing.h5'))
40
 
41
  self.tfport = TFPort(
42
+ self.crop_size,
43
  embed_dim,
44
  num_layers,
45
  num_layers,