liltom-eth commited on
Commit
88e86b9
1 Parent(s): 39f8585

Delete https:

Browse files
https://huggingface.co/anymodality/llava-v1.5-13b/tree/main/code/inference.py DELETED
@@ -1,74 +0,0 @@
1
- import requests
2
- from PIL import Image
3
- from io import BytesIO
4
- import torch
5
- from transformers import AutoTokenizer
6
-
7
- from llava.model import LlavaLlamaForCausalLM
8
- from llava.utils import disable_torch_init
9
- from llava.constants import IMAGE_TOKEN_INDEX
10
- from llava.mm_utils import tokenizer_image_token, KeywordsStoppingCriteria
11
-
12
-
13
def model_fn(model_dir):
    """Load the LLaVA checkpoint for SageMaker-style serving.

    Loads the causal-LM weights in fp16 with an automatic device map,
    the slow tokenizer, and the vision tower (moved to CUDA in fp16).

    Returns:
        tuple: (model, tokenizer, image_processor) — unpacked by predict_fn.
    """
    load_kwargs = {"device_map": "auto", "torch_dtype": torch.float16}
    model = LlavaLlamaForCausalLM.from_pretrained(
        model_dir, low_cpu_mem_usage=True, **load_kwargs
    )
    tokenizer = AutoTokenizer.from_pretrained(model_dir, use_fast=False)

    # The vision tower may be lazily initialized; make sure it is resident
    # on the GPU in half precision before serving requests.
    tower = model.get_vision_tower()
    if not tower.is_loaded:
        tower.load_model()
    tower.to(device="cuda", dtype=torch.float16)
    return model, tokenizer, tower.image_processor
27
-
28
-
29
def predict_fn(data, model_and_tokenizer):
    """Answer a visual question with the loaded LLaVA model.

    Args:
        data (dict): request payload. Required keys: "image" (URL or local
            path) and "question" (prompt already containing the image token).
            Optional: "max_new_tokens" (default 1024), "temperature"
            (default 0.2), "stop_str" (default "###").
        model_and_tokenizer (tuple): (model, tokenizer, image_processor)
            as returned by model_fn.

    Returns:
        str: the generated answer, stripped of surrounding whitespace.

    Raises:
        ValueError: if "image" or "question" is missing from the payload.
        requests.HTTPError: if the image URL returns a non-2xx status.
    """
    # unpack model and tokenizer
    model, tokenizer, image_processor = model_and_tokenizer

    # Fail fast with a clear message. The original fell back to the whole
    # `data` dict (`data.pop("image", data)`), which surfaced later as a
    # cryptic AttributeError on `.startswith`.
    image_file = data.pop("image", None)
    prompt = data.pop("question", None)
    if image_file is None or prompt is None:
        raise ValueError("request payload must contain both 'image' and 'question'")

    max_new_tokens = data.pop("max_new_tokens", 1024)
    temperature = data.pop("temperature", 0.2)
    stop_str = data.pop("stop_str", "###")

    # Original check `startswith("http") or startswith("https")` had a dead
    # second branch and misclassified local paths like "httpdir/img.png".
    if image_file.startswith(("http://", "https://")):
        # Bounded timeout so a hung download cannot stall the endpoint;
        # surface HTTP errors instead of handing an error page to PIL.
        response = requests.get(image_file, timeout=60)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content)).convert("RGB")
    else:
        image = Image.open(image_file).convert("RGB")

    disable_torch_init()
    image_tensor = (
        image_processor.preprocess(image, return_tensors="pt")["pixel_values"]
        .half()
        .cuda()
    )

    input_ids = (
        tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
        .unsqueeze(0)
        .cuda()
    )
    stopping_criteria = KeywordsStoppingCriteria([stop_str], tokenizer, input_ids)
    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=image_tensor,
            do_sample=True,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            use_cache=True,
            stopping_criteria=[stopping_criteria],
        )
    # Decode only the newly generated tokens (skip the echoed prompt).
    outputs = tokenizer.decode(
        output_ids[0, input_ids.shape[1]:], skip_special_tokens=True
    ).strip()
    return outputs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
https://huggingface.co/anymodality/llava-v1.5-13b/tree/main/code/requirements.txt DELETED
@@ -1 +0,0 @@
1
- llava @ git+https://github.com/haotian-liu/LLaVA@v1.1.1