wangjin2000 committed
Commit b2cb887
1 Parent(s): 4027612

Update app.py

Files changed (1)
  1. app.py +8 -3
app.py CHANGED
@@ -1,5 +1,5 @@
 import gradio as gr
-from transformers import AutoProcessor, AutoModelForCausalLM
+from transformers import AutoProcessor, AutoModelForCausalLM, pipeline
 import torch
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -15,18 +15,23 @@ checkpoint2 = "wangjin2000/git-base-finetune"
 #model2 = AutoModelForCausalLM.from_pretrained(checkpoint2, use_auth_token=access_token)
 model2 = AutoModelForCausalLM.from_pretrained(checkpoint2)
 
+en_zh_translator = pipeline("translation_en_to_zh")
+
 def img2cap_com(image):
     input1 = processor(images=image, return_tensors="pt").to(device)
     pixel_values1 = input1.pixel_values
     generated_id1 = model1.generate(pixel_values=pixel_values1, max_length=50)
     generated_caption1 = processor.batch_decode(generated_id1, skip_special_tokens=True)[0]
-
+    #translated_caption1 = en_zh_translator(generated_caption1)
+    translated_caption1 = [generated_caption1, en_zh_translator(generated_caption1)]
+
     input2 = processor(images=image, return_tensors="pt").to(device)
     pixel_values2 = input2.pixel_values
     generated_id2 = model2.generate(pixel_values=pixel_values2, max_length=50)
     generated_caption2 = processor.batch_decode(generated_id2, skip_special_tokens=True)[0]
+    translated_caption2 = [generated_caption2, en_zh_translator(generated_caption2)]
 
-    return generated_caption1,generated_caption2
+    return translated_caption1,translated_caption2
 
 inputs = [
     gr.inputs.Image(type="pil", label="Original Image")
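
For context, this commit wires an English-to-Chinese translation step after each generated caption. Below is a minimal standalone sketch of that flow. It assumes "microsoft/git-base" as the captioning checkpoint (only checkpoint2 is visible in this diff, so the base model name is a guess) and pins "Helsinki-NLP/opus-mt-en-zh" as the translation model, since pipeline("translation_en_to_zh") is not guaranteed to resolve a default model for the en-zh pair.

# Standalone sketch of the caption-then-translate flow this commit adds.
# Assumptions (not confirmed by the diff): "microsoft/git-base" as the
# captioning checkpoint and "Helsinki-NLP/opus-mt-en-zh" as the en->zh
# translation model.
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM, pipeline

device = "cuda" if torch.cuda.is_available() else "cpu"

processor = AutoProcessor.from_pretrained("microsoft/git-base")
model = AutoModelForCausalLM.from_pretrained("microsoft/git-base").to(device)

# The commit calls pipeline("translation_en_to_zh") without a model;
# pinning one explicitly avoids depending on a default existing for
# this language pair.
en_zh_translator = pipeline("translation_en_to_zh",
                            model="Helsinki-NLP/opus-mt-en-zh")

def caption_and_translate(image):
    # Encode the image, generate caption token ids, decode to text.
    inputs = processor(images=image, return_tensors="pt").to(device)
    ids = model.generate(pixel_values=inputs.pixel_values, max_length=50)
    caption = processor.batch_decode(ids, skip_special_tokens=True)[0]
    # A transformers translation pipeline returns a list of dicts,
    # e.g. [{"translation_text": "..."}]; unwrap it to a plain string.
    translation = en_zh_translator(caption)[0]["translation_text"]
    return caption, translation

# Example: caption, zh = caption_and_translate(Image.open("example.jpg"))

One difference from the committed code: img2cap_com returns two-element lists of the form [caption, pipeline_output], so Gradio will render the pipeline's raw dict structure verbatim; extracting "translation_text" as in the sketch yields a clean caption string instead.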