Files changed (2) hide show
  1. app.py +5 -4
  2. requirements.txt +4 -3
app.py CHANGED
@@ -1,23 +1,24 @@
1
  import gradio as gr
2
  import torch
 
3
  from PIL import Image
4
  from transformers import AutoModelForCausalLM, AutoTokenizer
5
 
6
  model_name = "Lin-Chen/ShareCaptioner"
7
  tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
8
  model = AutoModelForCausalLM.from_pretrained(
9
- model_name, device_map="cpu", trust_remote_code=True).eval()
10
  model.tokenizer = tokenizer
11
 
12
  model.cuda()
13
- model.half()
14
 
15
  seg1 = '<|User|>:'
16
  seg2 = f'Analyze the image in a comprehensive and detailed manner.{model.eoh}\n<|Bot|>:'
17
- seg_emb1 = model.encode_text(seg1, add_special_tokens=True)
18
- seg_emb2 = model.encode_text(seg2, add_special_tokens=False)
19
 
20
 
 
21
  def detailed_caption(img_path):
22
  subs = []
23
  image = Image.open(img_path).convert("RGB")
 
1
  import gradio as gr
2
  import torch
3
+ import spaces
4
  from PIL import Image
5
  from transformers import AutoModelForCausalLM, AutoTokenizer
6
 
7
  model_name = "Lin-Chen/ShareCaptioner"
8
  tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
9
  model = AutoModelForCausalLM.from_pretrained(
10
+ model_name, device_map="cpu", torch_dtype=torch.float16, trust_remote_code=True).eval()
11
  model.tokenizer = tokenizer
12
 
13
  model.cuda()
 
14
 
15
  seg1 = '<|User|>:'
16
  seg2 = f'Analyze the image in a comprehensive and detailed manner.{model.eoh}\n<|Bot|>:'
17
+ seg_emb1 = model.encode_text(seg1, add_special_tokens=True).cuda()
18
+ seg_emb2 = model.encode_text(seg2, add_special_tokens=False).cuda()
19
 
20
 
21
+ @spaces.GPU
22
  def detailed_caption(img_path):
23
  subs = []
24
  image = Image.open(img_path).convert("RGB")
requirements.txt CHANGED
@@ -4,10 +4,11 @@ tiktoken==0.5.1
4
  einops==0.7.0
5
  transformers_stream_generator==0.0.4
6
  scipy==1.11.3
7
- torchvision==0.15.2
 
8
  pillow==10.0.1
9
  matplotlib==3.8.0
10
- gradio==3.50.2
11
  sentencepiece
12
  urllib3==1.26.18
13
- timm==0.6.13
 
 
4
  einops==0.7.0
5
  transformers_stream_generator==0.0.4
6
  scipy==1.11.3
7
+ torch==2.1.2
8
+ torchvision==0.16.2
9
  pillow==10.0.1
10
  matplotlib==3.8.0
 
11
  sentencepiece
12
  urllib3==1.26.18
13
+ timm==1.0.3
14
+ spaces