Update README.md
README.md
CHANGED
@@ -3,4 +3,41 @@ license: mit
language:
- en
pipeline_tag: image-to-text
---

## **Description**

This is a ViT model fine-tuned with **LoRA** on a dataset of **Stable Diffusion 2.0** images.
It gives good results in a reasonable time and is straightforward to use with PyTorch.

<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/lora-assets/latent-diffusion.png" alt="Image" width="600">
* Reference: *https://huggingface.co/blog/lora*
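
For context on what LoRA changes, the arithmetic below is a rough sketch and not this model's actual configuration. LoRA freezes a weight matrix of shape `d x k` and trains two low-rank factors of shapes `d x r` and `r x k` instead. The values of `d`, `k`, and `r` here are illustrative assumptions; the rank used for this model is not stated in this README.

```python
# Illustrative sizes only: these are assumptions, not this model's settings.
d, k = 768, 768   # shape of one frozen weight matrix (assumed)
r = 4             # LoRA rank (assumed)

full_params = d * k          # parameters updated by full fine-tuning of this matrix
lora_params = r * (d + k)    # parameters in the two factors, B (d x r) and A (r x k)

print(f"full fine-tuning: {full_params} trainable parameters")
print(f"LoRA (rank {r}): {lora_params} trainable parameters")
print(f"reduction: {full_params / lora_params:.0f}x")
```

Under these assumed sizes the trainable parameter count for the matrix drops by roughly two orders of magnitude, which is the main reason LoRA fine-tuning is cheap.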
## **Usage**

```python
# Libraries
import torch
from PIL import Image
from transformers import ViTFeatureExtractor, AutoTokenizer, VisionEncoderDecoderModel

# Model
model = VisionEncoderDecoderModel.from_pretrained("nttdataspain/vit-gpt2-coco-lora")
tokenizer = AutoTokenizer.from_pretrained("nttdataspain/vit-gpt2-coco-lora")
feature_extractor = ViTFeatureExtractor.from_pretrained("nttdataspain/vit-gpt2-coco-lora")

# Predict function
def predict_prompts(list_images, max_length=16):
    model.eval()
    # Convert the PIL images into a batch of pixel tensors
    pixel_values = feature_extractor(images=list_images, return_tensors="pt").pixel_values
    with torch.no_grad():
        output_ids = model.generate(pixel_values, max_length=max_length, num_beams=4, return_dict_in_generate=True).sequences

    # Decode the generated token ids into text
    preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    preds = [pred.strip() for pred in preds]
    return preds

# Get an image and predict
image_path = "image.jpg"  # placeholder: replace with the path to your image
img = Image.open(image_path).convert('RGB')
pred_prompts = predict_prompts([img], max_length=16)
```
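
Since `predict_prompts` takes a list of images, a larger collection can be processed in small groups to keep memory use bounded. The helper below is a minimal sketch of that grouping step only. `batched` and the file names are hypothetical, introduced for illustration, and are not part of the model repository.

```python
from typing import Iterator, List, Sequence, TypeVar

T = TypeVar("T")

def batched(items: Sequence[T], batch_size: int) -> Iterator[List[T]]:
    """Yield consecutive slices of `items` with at most `batch_size` elements."""
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")
    for start in range(0, len(items), batch_size):
        yield list(items[start:start + batch_size])

# Hypothetical file names, used only to show the grouping.
image_paths = ["a.jpg", "b.jpg", "c.jpg", "d.jpg", "e.jpg"]
batches = list(batched(image_paths, batch_size=2))
print(batches)
```

Each group of paths can then be opened with `Image.open(path).convert('RGB')` and passed to `predict_prompts` as one list.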