File size: 931 Bytes
b34cc6a
14a9ccd
b34cc6a
14a9ccd
ed0e96f
b34cc6a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
from text_generation import Client
import os 

# Access token for the hosted inference endpoint, read from the environment.
# This is None when the API_TOKEN variable is not set.
API_TOKEN = os.environ.get("API_TOKEN")
API_URL = "https://api-inference.huggingface.co/models/HuggingFaceM4/idefics-9b-instruct"
# Supported values: "Greedy" or "Top P Sampling" (see the dispatch below).
DECODING_STRATEGY = "Greedy"
# Prompt sent to the model: a user turn containing a question and a
# markdown-style image link, closed by <end_of_utterance>, followed by the
# opening of the assistant turn.
QUERY = "User: What is in this image?![](https://upload.wikimedia.org/wikipedia/commons/8/86/Id%C3%A9fix.JPG)<end_of_utterance>\nAssistant:"

# Only attach an Authorization header when a token is actually configured.
# Formatting a missing token would otherwise send the literal string
# "Bearer None", which is never a valid credential.
request_headers = {"x-use-cache": "0"}
if API_TOKEN:
    request_headers["Authorization"] = f"Bearer {API_TOKEN}"

client = Client(
    base_url=API_URL,
    headers=request_headers,
)

# Arguments shared by every decoding strategy. Generation stops at the end of
# the assistant's utterance or as soon as the model starts a new user turn.
generation_args = {
    "max_new_tokens": 256,
    "repetition_penalty": 1.0,
    "stop_sequences": ["<end_of_utterance>", "\nUser:"],
}

if DECODING_STRATEGY == "Greedy":
    generation_args["do_sample"] = False
elif DECODING_STRATEGY == "Top P Sampling":
    generation_args["temperature"] = 1.0
    generation_args["do_sample"] = True
    generation_args["top_p"] = 0.95
else:
    # Fail loudly instead of silently sending a request with no decoding
    # settings when the constant above is mistyped.
    raise ValueError(
        f"Unsupported DECODING_STRATEGY {DECODING_STRATEGY!r}; "
        "expected 'Greedy' or 'Top P Sampling'"
    )

# `generate` returns a response object; the generated string itself lives in
# its `generated_text` attribute.
response = client.generate(prompt=QUERY, **generation_args)
generated_text = response.generated_text
print(generated_text)