user-agent commited on
Commit
cec1e53
·
verified ·
1 Parent(s): 8a34bf6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -27
app.py CHANGED
@@ -46,26 +46,34 @@ ATTRIBUTES_DICT = attributes_data['attribute_mapping']
46
  def shot(input, category, level):
47
  output_dict = {}
48
  if level == 'variant':
49
- subColour,mainColour,score = get_colour(ast.literal_eval(str(input)),category)
50
  openai_parsed_response = get_openAI_tags(ast.literal_eval(str(input)))
51
  face_embeddings = get_face_embeddings(ast.literal_eval(str(input)))
52
- cropped_images = get_cropped_images(ast.literal_eval(str(input)),category)
 
 
53
  output_dict['colors'] = {
54
- "main":mainColour,
55
- "sub":subColour,
56
- "score":score
57
  }
58
  output_dict['image_mapping'] = openai_parsed_response
59
  output_dict['face_embeddings'] = face_embeddings
60
  output_dict['cropped_images'] = cropped_images
61
 
62
-
63
  if level == 'product':
64
- common_result = get_predicted_attributes(ast.literal_eval(str(input)),category)
65
  output_dict['attributes'] = common_result
66
  output_dict['subcategory'] = category
67
 
68
- return json.dumps(output_dict)
 
 
 
 
 
 
 
69
 
70
 
71
 
@@ -411,35 +419,55 @@ def encode_images_to_base64(cropped_list):
411
  return base64_images
412
 
413
 
414
- def get_cropped_images(images,category):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
415
  cropped_list = []
416
  resultsPerCategory = {}
417
  for num, image in enumerate(images):
418
  image = open_image_from_url(image)
419
  class_counts, output_img, cropped_images, cropped_classes = get_objects(image, 0.37)
420
- print(cropped_images)
421
  if not class_counts:
422
  continue
423
 
424
- # Get the inverse category as any other mapping label except the current one corresponding category
425
- inverse_category = [label for i, labels in enumerate(label_mapping) for label in labels if i != get_category_index(category) and i != 0]
426
-
427
- # If category is a cardigan, we don't recommend category indices 1 and 3
428
- if category == 'women-sweatersknits-cardigan':
429
- inverse_category = [label for i, labels in enumerate(label_mapping) for label in labels if i != get_category_index(category) and i != 1 and i != 3]
430
-
431
  for i, image in enumerate(cropped_images):
432
- cropped_category = cropped_classes[i]
433
- print(cropped_category, cropped_classes[i], get_category_index(category))
434
-
435
- specific_category = label_mapping[cropped_category]
436
-
437
- if cropped_category == get_category_index(category):
438
- continue
439
-
440
  cropped_list.append(image)
441
-
442
-
443
  base64_images = encode_images_to_base64(cropped_list)
444
 
445
  return base64_images
@@ -447,6 +475,7 @@ def get_cropped_images(images,category):
447
 
448
 
449
 
 
450
  # Define the Gradio interface with the updated components
451
  iface = gr.Interface(
452
  fn=shot,
 
46
  def shot(input, category, level):
47
  output_dict = {}
48
  if level == 'variant':
49
+ subColour, mainColour, score = get_colour(ast.literal_eval(str(input)), category)
50
  openai_parsed_response = get_openAI_tags(ast.literal_eval(str(input)))
51
  face_embeddings = get_face_embeddings(ast.literal_eval(str(input)))
52
+ cropped_images = get_cropped_images(ast.literal_eval(str(input)), category)
53
+
54
+ # Ensure all outputs are JSON serializable
55
  output_dict['colors'] = {
56
+ "main": mainColour,
57
+ "sub": subColour,
58
+ "score": score
59
  }
60
  output_dict['image_mapping'] = openai_parsed_response
61
  output_dict['face_embeddings'] = face_embeddings
62
  output_dict['cropped_images'] = cropped_images
63
 
 
64
  if level == 'product':
65
+ common_result = get_predicted_attributes(ast.literal_eval(str(input)), category)
66
  output_dict['attributes'] = common_result
67
  output_dict['subcategory'] = category
68
 
69
+ # Convert the dictionary to a JSON-serializable format
70
+ try:
71
+ serialized_output = json.dumps(output_dict)
72
+ except TypeError as e:
73
+ print(f"Serialization Error: {e}")
74
+ return {"error": "Serialization failed"}
75
+
76
+ return serialized_output
77
 
78
 
79
 
 
419
  return base64_images
420
 
421
 
422
+ # def get_cropped_images(images,category):
423
+ # cropped_list = []
424
+ # resultsPerCategory = {}
425
+ # for num, image in enumerate(images):
426
+ # image = open_image_from_url(image)
427
+ # class_counts, output_img, cropped_images, cropped_classes = get_objects(image, 0.37)
428
+ # if not class_counts:
429
+ # continue
430
+
431
+ # # Get the inverse category as any other mapping label except the current one corresponding category
432
+ # inverse_category = [label for i, labels in enumerate(label_mapping) for label in labels if i != get_category_index(category) and i != 0]
433
+
434
+ # # If category is a cardigan, we don't recommend category indices 1 and 3
435
+ # if category == 'women-sweatersknits-cardigan':
436
+ # inverse_category = [label for i, labels in enumerate(label_mapping) for label in labels if i != get_category_index(category) and i != 1 and i != 3]
437
+
438
+ # for i, image in enumerate(cropped_images):
439
+ # cropped_category = cropped_classes[i]
440
+ # print(cropped_category, cropped_classes[i], get_category_index(category))
441
+
442
+ # specific_category = label_mapping[cropped_category]
443
+
444
+ # if cropped_category == get_category_index(category):
445
+ # continue
446
+
447
+ # cropped_list.append(image)
448
+
449
+
450
+ # base64_images = encode_images_to_base64(cropped_list)
451
+
452
+ # return base64_images
453
+
454
+
455
+
456
+
457
+ def get_cropped_images(images, category):
458
  cropped_list = []
459
  resultsPerCategory = {}
460
  for num, image in enumerate(images):
461
  image = open_image_from_url(image)
462
  class_counts, output_img, cropped_images, cropped_classes = get_objects(image, 0.37)
463
+
464
  if not class_counts:
465
  continue
466
 
 
 
 
 
 
 
 
467
  for i, image in enumerate(cropped_images):
 
 
 
 
 
 
 
 
468
  cropped_list.append(image)
469
+
470
+ # Convert cropped images to base64 strings
471
  base64_images = encode_images_to_base64(cropped_list)
472
 
473
  return base64_images
 
475
 
476
 
477
 
478
+
479
  # Define the Gradio interface with the updated components
480
  iface = gr.Interface(
481
  fn=shot,