vishalkatheriya18 commited on
Commit
2c867d4
1 Parent(s): 0bb757b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -40
app.py CHANGED
@@ -16,62 +16,52 @@ if 'models_loaded' not in st.session_state:
16
  st.session_state.models_loaded = True
17
 
18
  # Define image processing and classification functions
19
- def topwear(encoding,top_wear_model):
20
- # Make prediction
21
  outputs = st.session_state.top_wear_model(**encoding)
22
  logits = outputs.logits
23
  predicted_class_idx = logits.argmax(-1).item()
24
- # Print the result
25
  return st.session_state.top_wear_model.config.id2label[predicted_class_idx]
26
 
27
- def patterns(encoding,pattern_model):
28
- # Make prediction
29
  outputs = st.session_state.pattern_model(**encoding)
30
  logits = outputs.logits
31
  predicted_class_idx = logits.argmax(-1).item()
32
- # Print the result
33
  return st.session_state.pattern_model.config.id2label[predicted_class_idx]
34
 
35
- def prints(encoding,print_model):
36
- # Make prediction
37
  outputs = st.session_state.print_model(**encoding)
38
  logits = outputs.logits
39
  predicted_class_idx = logits.argmax(-1).item()
40
- # Print the result
41
  return st.session_state.print_model.config.id2label[predicted_class_idx]
42
 
43
- def sleevelengths(encoding,sleeve_length_model):
44
- # Make prediction
45
  outputs = st.session_state.sleeve_length_model(**encoding)
46
  logits = outputs.logits
47
  predicted_class_idx = logits.argmax(-1).item()
48
- # Print the result
49
  return st.session_state.sleeve_length_model.config.id2label[predicted_class_idx]
50
 
51
- def imageprocessing(url):
52
- response = requests.get(url)
53
- if response.status_code == 200:
54
- image = Image.open(BytesIO(response.content))
55
- encoding = image_processor(image.convert("RGB"), return_tensors="pt")
56
  return encoding
57
-
58
  # Define the function that will be used in each thread
59
- def call_model(func, encoding, model, results, index):
60
- results[index] = func(encoding, model)
 
61
  # Run all models in parallel
62
- def pipes(imagepath):
63
  # Process the image once and reuse the encoding
64
- encoding = imageprocessing(imagepath)
65
 
66
  # Prepare a list to store the results from each thread
67
  results = [None] * 4
68
 
69
  # Create threads for each function call
70
  threads = [
71
- threading.Thread(target=call_model, args=(topwear, encoding, top_wear_model, results, 0)),
72
- threading.Thread(target=call_model, args=(patterns, encoding, pattern_model, results, 1)),
73
- threading.Thread(target=call_model, args=(prints, encoding, print_model, results, 2)),
74
- threading.Thread(target=call_model, args=(sleevelengths, encoding, sleeve_length_model, results, 3)),
75
  ]
76
 
77
  # Start all threads
@@ -97,20 +87,19 @@ st.title("Clothing Classification Pipeline")
97
 
98
  url = st.text_input("Paste image URL here...")
99
  if url:
100
- response = requests.get(url)
101
- if response.status_code == 200:
102
- image = Image.open(BytesIO(response.content))
103
- encoding = image_processor(image.convert("RGB"), return_tensors="pt")
104
- st.image(image.resize((200, 200)), caption="Uploaded Image", use_column_width=False)
105
-
106
- start_time = time.time()
107
-
108
- try:
109
- result = pipes(url)
110
  st.write("Classification Results (JSON):")
111
  st.json(result) # Display results in JSON format
112
  st.write(f"Time taken: {time.time() - start_time:.2f} seconds")
113
- except Exception as e:
114
- st.error(f"Error processing the image: {str(e)}")
115
- else:
116
- st.error("Failed to load image from URL. Please check the URL.")
 
16
  st.session_state.models_loaded = True
17
 
18
  # Define image processing and classification functions
19
+ def topwear(encoding):
 
20
  outputs = st.session_state.top_wear_model(**encoding)
21
  logits = outputs.logits
22
  predicted_class_idx = logits.argmax(-1).item()
 
23
  return st.session_state.top_wear_model.config.id2label[predicted_class_idx]
24
 
25
+ def patterns(encoding):
 
26
  outputs = st.session_state.pattern_model(**encoding)
27
  logits = outputs.logits
28
  predicted_class_idx = logits.argmax(-1).item()
 
29
  return st.session_state.pattern_model.config.id2label[predicted_class_idx]
30
 
31
+ def prints(encoding):
 
32
  outputs = st.session_state.print_model(**encoding)
33
  logits = outputs.logits
34
  predicted_class_idx = logits.argmax(-1).item()
 
35
  return st.session_state.print_model.config.id2label[predicted_class_idx]
36
 
37
+ def sleevelengths(encoding):
 
38
  outputs = st.session_state.sleeve_length_model(**encoding)
39
  logits = outputs.logits
40
  predicted_class_idx = logits.argmax(-1).item()
 
41
  return st.session_state.sleeve_length_model.config.id2label[predicted_class_idx]
42
 
43
+ def imageprocessing(image):
44
+ encoding = st.session_state.image_processor(image.convert("RGB"), return_tensors="pt")
 
 
 
45
  return encoding
46
+
47
  # Define the function that will be used in each thread
48
+ def call_model(func, encoding, results, index):
49
+ results[index] = func(encoding)
50
+
51
  # Run all models in parallel
52
+ def pipes(image):
53
  # Process the image once and reuse the encoding
54
+ encoding = imageprocessing(image)
55
 
56
  # Prepare a list to store the results from each thread
57
  results = [None] * 4
58
 
59
  # Create threads for each function call
60
  threads = [
61
+ threading.Thread(target=call_model, args=(topwear, encoding, results, 0)),
62
+ threading.Thread(target=call_model, args=(patterns, encoding, results, 1)),
63
+ threading.Thread(target=call_model, args=(prints, encoding, results, 2)),
64
+ threading.Thread(target=call_model, args=(sleevelengths, encoding, results, 3)),
65
  ]
66
 
67
  # Start all threads
 
87
 
88
  url = st.text_input("Paste image URL here...")
89
  if url:
90
+ try:
91
+ response = requests.get(url)
92
+ if response.status_code == 200:
93
+ image = Image.open(BytesIO(response.content))
94
+ st.image(image.resize((200, 200)), caption="Uploaded Image", use_column_width=False)
95
+
96
+ start_time = time.time()
97
+
98
+ result = pipes(image)
 
99
  st.write("Classification Results (JSON):")
100
  st.json(result) # Display results in JSON format
101
  st.write(f"Time taken: {time.time() - start_time:.2f} seconds")
102
+ else:
103
+ st.error("Failed to load image from URL. Please check the URL.")
104
+ except Exception as e:
105
+ st.error(f"Error processing the image: {str(e)}")