the-future-dev commited on
Commit
b5dc2f8
1 Parent(s): b688905

first commit

Browse files
Files changed (2) hide show
  1. app.py +3 -3
  2. kosmos.py +6 -4
app.py CHANGED
@@ -11,13 +11,13 @@ if uploaded_file is not None:
11
  image = Image.open(uploaded_file)
12
  st.image(image, caption='Uploaded Image.', use_column_width=True)
13
  st.write("")
14
- num_characters = st.text_input("Number of characters", "20")
15
  prompt = st.chat_input("Detect the main object in the image. The image is a")
16
- num_characters = int(num_characters)
17
  if prompt:
18
  st.write(f"User: {prompt}")
19
  with st.spinner('Processing...'):
20
- label = kosmos.single_image_classification(image, prompt, num_characters)
21
  st.write(f"Model: {label}")
22
  except Exception as e:
23
  st.error(f"An error occurred: {e}")
 
11
  image = Image.open(uploaded_file)
12
  st.image(image, caption='Uploaded Image.', use_column_width=True)
13
  st.write("")
14
+ num_tokens = st.text_input("Number of new tokens", "20")
15
  prompt = st.chat_input("Detect the main object in the image. The image is a")
16
+ num_tokens = int(num_tokens)
17
  if prompt:
18
  st.write(f"User: {prompt}")
19
  with st.spinner('Processing...'):
20
+ label = kosmos.single_image_classification(image, prompt, num_tokens)
21
  st.write(f"Model: {label}")
22
  except Exception as e:
23
  st.error(f"An error occurred: {e}")
kosmos.py CHANGED
@@ -21,7 +21,9 @@ def single_image_classification(image, prompt="", max_new_tokens=30):
21
  )
22
  generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
23
 
24
- caption, entities = processor.post_process_generation(generated_text)
25
- print("ROBOT:", caption)
26
- print("ENTITIES: ", entities)
27
- return caption
 
 
 
21
  )
22
  generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
23
 
24
+ print("GENERATED:", generated_text)
25
+
26
+ processed_text = processor.post_process_generation(generated_text, cleanup_and_extract=False)
27
+
28
+ print("PROCESSED:", processed_text)
29
+ return processed_text