lele-cecere committed on
Commit
809d04b
1 Parent(s): 6dc6766

switched to non quantized mistral instruct

Browse files
Files changed (2) hide show
  1. examples_metadata.py +23 -0
  2. main.py +20 -19
examples_metadata.py CHANGED
@@ -263,6 +263,29 @@ bag_metadata = {
263
  },
264
  }
265
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
266
  dress_example = f'''
267
  #### Input
268
  Metadata: {dress_metadata};
 
263
  },
264
  }
265
 
266
def clean_json(json_data):
    """
    Extract the English description and English tags from a fashion-item JSON dict.

    :param json_data: Dict with a ``descriptions`` list (each entry has
        ``language`` and ``text`` keys) and a ``tagsData`` dict whose ``tags``
        list holds entries shaped like ``{"tag": {"languages": {<lang>: <str>}}}``.
    :return: A two-element list ``[description_en, tags_en]`` where
        ``description_en`` is the English description text (``None`` when no
        English entry exists) and ``tags_en`` is a list of English tag strings.
    :raises KeyError: If ``descriptions`` or ``tagsData``/``tags`` is missing.
    """
    # First description whose language is English, or None when absent.
    description_en = next(
        (desc["text"] for desc in json_data["descriptions"] if desc["language"] == "en"),
        None,
    )

    # Keep only the tags that carry an English translation.
    tags_en = [
        tag["tag"]["languages"]["en"]
        for tag in json_data["tagsData"]["tags"]
        if "en" in tag["tag"]["languages"]
    ]

    # NOTE: the order [description, tags] matters — callers unpack positionally.
    # (The original built a throwaway dict and returned list(d.values());
    # returning the list directly is equivalent and honest about the type.)
    return [description_en, tags_en]
284
+
285
+ dress_metadata = clean_json(dress_metadata)
286
+ bomber_metadata = clean_json(bomber_metadata)
287
+ bag_metadata = clean_json(bag_metadata)
288
+
289
  dress_example = f'''
290
  #### Input
291
  Metadata: {dress_metadata};
main.py CHANGED
@@ -27,7 +27,8 @@ from examples_metadata import (
27
  dress_example,
28
  bomber_example,
29
  )
30
- #init()
 
31
  logging.basicConfig(level=logging.DEBUG)
32
  os.system("pip list")
33
  #print Cuda version
@@ -126,17 +127,19 @@ def shortenMods(generator, res):
126
  return res[0]["generated_text"]
127
 
128
 
129
- # usare "You:" evita che il modello generi samples extra ma legga l'input
130
  def formatMods(generator, res):
131
- prompt = f"""
132
- Given as input a list like:
133
  -var 1
134
  -var 2
135
  -var 3...
136
 
137
- Return as output a list as:
138
  [var1, var2, var3, ...]
139
 
 
 
 
140
  Examples:
141
  {bomber_format_example}
142
  {shirt_format_example}
@@ -148,18 +151,17 @@ def formatMods(generator, res):
148
 
149
  Output:
150
 
151
- """
152
 
153
- prompt_template = f"""<s> [INST]
154
  {prompt}
155
  [/INST]
156
- """
157
- print("before inference")
158
- print_gpu_utilization()
159
- with torch.no_grad():
160
- res = generator(prompt_template)
161
- # print(res)
162
- return res[0]["generated_text"]
163
 
164
 
165
  def initModel(model_name_or_path, revision):
@@ -208,30 +210,29 @@ def initModel(model_name_or_path, revision):
208
  def generateTags():
209
  start = time.time()
210
  res = generateMods(generator, bag_metadata, dress_example, bomber_example)
 
211
  print("generation mods response:")
212
  print(res)
213
- stripped_res = remove_before_word(res, "[/INST]")
214
  shorten_res = shortenMods(generator, stripped_res)
215
  print("shortened response:")
216
  print(shorten_res)
217
  shorten_res = remove_before_word(shorten_res, "[/INST]")
218
  formatted_res = formatMods(generator, shorten_res)
219
- formatted_res = remove_before_word(formatted_res, "[/INST]")
220
  print("formatted response:")
221
  print(formatted_res)
 
222
  end = time.time()
223
  print("time spent for generating tags:", end - start)
224
  return {"response": stripped_res, "shortened response:": shorten_res, "formatted response": formatted_res}
225
 
226
-
227
  app = FastAPI()
228
 
229
 
230
  @app.on_event("startup")
231
  def load_model():
232
  global generator
233
- model_name_or_path = "TheBloke/Mistral-7B-Instruct-v0.1-GPTQ"
234
- revision = "gptq-8bit-128g-actorder_True"
235
  generator = initModel(model_name_or_path, revision)
236
  print("Model loaded")
237
 
 
27
  dress_example,
28
  bomber_example,
29
  )
30
+
31
+
32
  logging.basicConfig(level=logging.DEBUG)
33
  os.system("pip list")
34
  #print Cuda version
 
127
  return res[0]["generated_text"]
128
 
129
 
 
130
  def formatMods(generator, res):
131
+ prompt = f'''
132
+ I have a list like:
133
  -var 1
134
  -var 2
135
  -var 3...
136
 
137
+ Rewrite the list and put it in square brackets
138
  [var1, var2, var3, ...]
139
 
140
+ no code, just the list
141
+ It must begin with "[" and end with "]".
142
+
143
  Examples:
144
  {bomber_format_example}
145
  {shirt_format_example}
 
151
 
152
  Output:
153
 
154
+ '''
155
 
156
+ prompt_template=f'''<s> [INST]
157
  {prompt}
158
  [/INST]
159
+ '''
160
+ print("before inference")
161
+ print_gpu_utilization()
162
+ res = generator(prompt_template)
163
+ #print(res)
164
+ return res[0]['generated_text']
 
165
 
166
 
167
  def initModel(model_name_or_path, revision):
 
210
def generateTags():
    """Run the generate → shorten → format pipeline over the bag metadata.

    Relies on the module-level ``generator`` pipeline being initialised at
    startup. Returns a dict with the raw (prompt-stripped), shortened, and
    formatted model outputs, and prints timing information.
    """
    pipeline_start = time.time()

    # Pass 1: raw generation from the bag metadata plus few-shot examples.
    raw_out = generateMods(generator, bag_metadata, dress_example, bomber_example)
    stripped_res = remove_before_word(raw_out, "[/INST]")
    print("generation mods response:")
    print(raw_out)

    # Pass 2: shorten the result, then drop the echoed prompt again.
    shorten_res = shortenMods(generator, stripped_res)
    print("shortened response:")
    print(shorten_res)
    shorten_res = remove_before_word(shorten_res, "[/INST]")

    # Pass 3: reformat into a bracketed list; printed before stripping
    # so the full model echo is visible in the logs.
    formatted_res = formatMods(generator, shorten_res)
    print("formatted response:")
    print(formatted_res)
    formatted_res = remove_before_word(formatted_res, "[/INST]")

    print("time spent for generating tags:", time.time() - pipeline_start)
    # NOTE(review): the key "shortened response:" carries a stray trailing
    # colon — reproduced byte-for-byte since API consumers may depend on it.
    return {"response": stripped_res, "shortened response:": shorten_res, "formatted response": formatted_res}
227
 
 
228
  app = FastAPI()
229
 
230
 
231
@app.on_event("startup")
def load_model():
    """FastAPI startup hook: build the text-generation pipeline exactly once.

    Stores the pipeline in the module-level ``generator`` so request handlers
    can reuse it. Loads the full-precision (non-quantized) Mistral-7B-Instruct
    model from the ``main`` revision.
    """
    global generator
    generator = initModel("mistralai/Mistral-7B-Instruct-v0.1", "main")
    print("Model loaded")
238