LukasHug committed on
Commit ac4a50f
1 Parent(s): 97e36f8

Update README.md

Files changed (1)
  1. README.md +40 -34
README.md CHANGED
@@ -36,54 +36,60 @@ LlavaGuard-v1.2-7B-OV is trained on [LlavaGuard-DS](https://huggingface.co/datas
 
 ## Model Compatibility
 
- - Inference: SGLang❌, LLaVA [repo](https://github.com/LLaVA-VL/LLaVA-NeXT)❌, HF Transformers✅
+ - Inference: HF Transformers✅, SGLang❌, LLaVA [repo](https://github.com/LLaVA-VL/LLaVA-NeXT)❌
 - Model Tuning: ❌
 
 ## Overview
 Here we provide the transformers-converted weights for LlavaGuard v1.2 7B.
 It builds upon LLaVA-OneVision 7B and has achieved the best overall performance so far, with improved reasoning capabilities within the rationales.
- This version is not compatible with the HF transformers implementation and must be used with the SGLang or LLaVA implementation.
- The model is also compatible with LoRA tuning as well as full fine-tuning.
- For tuning, you can adopt and use the training scripts provided in our repository (see [ml-research/LlavaGuard](https://github.com/ml-research/LlavaGuard)).
- A suitable Docker image can be found in our GitHub repo, too.
 
- #### Usage
-
- # 0. Install requirements
- For inference, you can use the following [sglang docker](https://github.com/sgl-project/sglang/blob/main/docker/Dockerfile) and proceed with step 1.
- Otherwise, you can also install sglang via pip or from source [see here](https://github.com/sgl-project/sglang).
-
- # 1. Select a model and start an SGLang server
-
- CUDA_VISIBLE_DEVICES=0 python3 -m sglang.launch_server --model-path AIML-TUDA/LlavaGuard-v1.2-7B-OV --port 10000
-
- # 2. Model Inference
+ #### Usage
+
 For model inference, you can access this server by running the code provided below, e.g.
 `python my_script.py`
 
 ```Python
- import sglang as sgl
- from sglang import RuntimeEndpoint
-
- @sgl.function
- def guard_gen(s, image_path, prompt):
-     s += sgl.user(sgl.image(image_path) + prompt)
-     hyperparameters = {
-         'temperature': 0.2,
-         'top_p': 0.95,
-         'top_k': 50,
-         'max_tokens': 500,
-     }
-     s += sgl.assistant(sgl.gen("json_output", **hyperparameters))
-
- im_path = 'path/to/your/image'
- prompt = safety_taxonomy_below  # the default policy prompt from the Safety Taxonomy section below
- backend = RuntimeEndpoint(f"http://localhost:10000")
- sgl.set_default_backend(backend)
- out = guard_gen.run(image_path=im_path, prompt=prompt)
- print(out['json_output'])
+ from transformers import AutoProcessor, LlavaForConditionalGeneration
+ from PIL import Image
+ import requests
+
+ model = LlavaForConditionalGeneration.from_pretrained('AIML-TUDA/LlavaGuard-v1.2-7B-OV-hf')
+ processor = AutoProcessor.from_pretrained('AIML-TUDA/LlavaGuard-v1.2-7B-OV-hf')
+
+ # `policy` holds the safety policy prompt, e.g. the default one from the Safety Taxonomy section below
+ conversation = [
+     {
+         "role": "user",
+         "content": [
+             {"type": "image"},
+             {"type": "text", "text": policy},
+         ],
+     },
+ ]
+
+ text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
+
+ url = "https://www.ilankelman.org/stopsigns/australia.jpg"
+ image = Image.open(requests.get(url, stream=True).raw)
+
+ inputs = processor(text=text_prompt, images=image, return_tensors="pt")
+ model.to('cuda:0')
+ inputs = {k: v.to('cuda:0') for k, v in inputs.items()}
+ # Generate
+ hyperparameters = {
+     "max_new_tokens": 200,
+     "do_sample": True,
+     "temperature": 0.2,
+     "top_p": 0.95,
+     "top_k": 50,
+     "num_beams": 2,
+     "use_cache": True,
+ }
+ output = model.generate(**inputs, **hyperparameters)
+ print(processor.decode(output[0], skip_special_tokens=True))
  ```
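
The decoded string above still contains the chat prompt. As a minimal post-processing sketch (an editorial illustration that reuses `inputs`, `output`, and `processor` from the example above, and assumes the model answers with a JSON assessment, as the `json_output` field in the earlier SGLang snippet suggests), you could trim the prompt tokens and parse the answer:

```Python
import json

# Trim the prompt tokens that `generate` echoes back, keeping only the newly generated ones.
prompt_len = inputs["input_ids"].shape[1]
generated_text = processor.decode(output[0][prompt_len:], skip_special_tokens=True)

try:
    assessment = json.loads(generated_text)  # the parsed safety assessment
    print(assessment)
except json.JSONDecodeError:
    print(generated_text)  # fall back to the raw model answer
```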
 
 ## Safety Taxonomy
 
 Our default policy prompt looks like this: