IvaElen commited on
Commit
418bba9
·
1 Parent(s): af28313

Update pages/ImageToText.py

Browse files
Files changed (1) hide show
  1. pages/ImageToText.py +22 -16
pages/ImageToText.py CHANGED
@@ -11,33 +11,39 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
11
 
12
  vitgpt_model.to(device)
13
 
14
- def generate_caption(processor, model, image, tokenizer=None):
15
 
16
  inputs = processor(images=image, return_tensors="pt").to(device)
17
  generated_ids = model.generate(pixel_values=inputs.pixel_values,
18
- max_length=100,
19
  num_beams=5,
20
  do_sample=True,
21
- temperature=1.,
22
- top_k=50,
23
- top_p=0.6,
24
- no_repeat_ngram_size=3,
25
- num_return_sequences=3)
26
 
27
  if tokenizer is not None:
28
- generated_caption = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
29
  else:
30
- generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
31
  return generated_caption
32
 
33
- def generate_captions(image):
34
- caption_vitgpt = generate_caption(vitgpt_processor, vitgpt_model, image, vitgpt_tokenizer)
35
  return caption_vitgpt
36
 
 
37
  uploaded_file = st.file_uploader("Upload your image")
 
 
38
  if uploaded_file is not None:
39
- image = Image.open(uploaded_file)
40
- st.image(image)
41
- generated_caption = generate_caption(vitgpt_processor, vitgpt_model, image, vitgpt_tokenizer)
42
- st.write(generated_caption)
43
-
 
 
 
 
 
11
 
12
  vitgpt_model.to(device)
13
 
14
+ def generate_caption(processor, model, image, tokenizer=None, num_seq):
15
 
16
  inputs = processor(images=image, return_tensors="pt").to(device)
17
  generated_ids = model.generate(pixel_values=inputs.pixel_values,
18
+ max_length=50,
19
  num_beams=5,
20
  do_sample=True,
21
+ temperature=2.,
22
+ top_k = 20,
23
+ no_repeat_ngram_size=5,
24
+ num_return_sequences=num_seq)
 
25
 
26
  if tokenizer is not None:
27
+ generated_caption = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
28
  else:
29
+ generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)
30
  return generated_caption
31
 
32
+ def generate_captions(image, num_seq):
33
+ caption_vitgpt = generate_caption(vitgpt_processor, vitgpt_model, image, vitgpt_tokenizer, num_seq)
34
  return caption_vitgpt
35
 
36
+ st.title('Generate text to your image')
37
  uploaded_file = st.file_uploader("Upload your image")
38
+ num_seq = st.slider('Return sequences quantity', 1, 5, 2)
39
+
40
  if uploaded_file is not None:
41
+ if st.button('Generate!'):
42
+ col1, col2 = st.columns(2)
43
+ with col1:
44
+ image = Image.open(uploaded_file)
45
+ st.image(image)
46
+ with col2:
47
+ generated_caption = generate_caption(vitgpt_processor, vitgpt_model, image, vitgpt_tokenizer)
48
+ for i in generated_caption:
49
+ st.write(i)