# rphrp1985's picture
# Update app.py
# d6fc7c0 verified
import gradio as gr
from gradio.data_classes import FileData
from huggingface_hub import snapshot_download
from pathlib import Path
import base64
import spaces
import os
from mistral_inference.transformer import Transformer
from mistral_inference.generate import generate
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistral_common.protocol.instruct.messages import UserMessage, AssistantMessage, TextChunk, ImageURLChunk
from mistral_common.protocol.instruct.request import ChatCompletionRequest
# Download just the three files listed in allow_patterns from the Hugging Face
# repo into ~/pixtral/Pixtral (created if missing), then build the tokenizer
# from tekken.json and the model from that folder.
models_path = Path.home() / 'pixtral' / 'Pixtral'
models_path.mkdir(exist_ok=True, parents=True)
snapshot_download(
    repo_id="mistral-community/pixtral-12b-240910",
    allow_patterns=["params.json", "consolidated.safetensors", "tekken.json"],
    local_dir=models_path,
)
tokenizer = MistralTokenizer.from_file(f"{models_path}/tekken.json")
model = Transformer.from_folder(models_path)
def image_to_base64(image_path):
    """Read a local image file and return it as a base64 ``data:`` URL.

    The MIME type is guessed from the file name's extension. If it cannot be
    guessed, or is not an ``image/*`` type, ``image/jpeg`` is used, so the
    result always starts with ``data:image/``.
    NOTE(review): mistral_common appears to recognise inline images by that
    ``data:image`` prefix — confirm before changing the fallback.

    Args:
        image_path: Path of the image file on disk.

    Returns:
        A string of the form ``data:<mime>;base64,<payload>``.

    Raises:
        OSError: If the file cannot be opened or read.
    """
    import mimetypes  # local import: the module-level one appears later in the file

    with open(image_path, 'rb') as img:
        raw = img.read()
    mime_type, _ = mimetypes.guess_type(str(image_path))
    if not mime_type or not mime_type.startswith('image/'):
        mime_type = 'image/jpeg'
    encoded_string = base64.b64encode(raw).decode('utf-8')
    return f"data:{mime_type};base64,{encoded_string}"
import requests
import base64
import mimetypes
def url_to_base64(image_url):
    """Download an image over HTTP(S) and return it as a base64 ``data:`` URL.

    The MIME type comes from the response's ``Content-Type`` header when that
    header names an ``image/*`` type. In every other case it is guessed from
    the filename: the ``Content-Disposition`` filename if there is one,
    otherwise the last path segment of the URL. Other cases means a missing
    header, a generic type such as ``application/octet-stream`` or
    ``binary/octet-stream``, or any non-image type. If no image type can be
    guessed, ``image/jpeg`` is used. The returned string therefore always
    starts with ``data:image/``.
    NOTE(review): mistral_common appears to recognise inline images by that
    prefix — confirm before changing the fallbacks.

    Args:
        image_url: URL of the image. Redirects are followed; the request
            times out after 15 seconds.

    Returns:
        ``data:<mime>;base64,<payload>`` on success. On any error the error is
        printed (not raised) and the empty-payload string
        ``"data:image/jpeg;base64,"`` is returned.
    """
    headers = {
        "User-Agent": (
            "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
            "AppleWebKit/537.36 (KHTML, like Gecko) "
            "Chrome/122.0.0.0 Safari/537.36"
        ),
        # image/avif is deliberately not advertised. A CDN doing content
        # negotiation could otherwise answer with AVIF.
        # NOTE(review): older Pillow builds cannot decode AVIF — confirm
        # against the deployed Pillow version.
        "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,*/*;q=0.8",
        # Only advertise encodings requests can always decode. "br" needs the
        # optional brotli package; without it the body would stay compressed.
        "Accept-Encoding": "gzip, deflate",
        "Accept-Language": "en-US,en;q=0.9",
    }
    try:
        # stream=True defers reading the body. The `with` closes the response,
        # so the connection is released even if raise_for_status() raises.
        with requests.get(image_url, headers=headers, stream=True,
                          allow_redirects=True, timeout=15) as response:
            response.raise_for_status()
            payload = response.content
            content_type = response.headers.get('Content-Type', '')
            content_disp = response.headers.get('Content-Disposition', '')

        # Drop parameters such as "; charset=binary" and normalise the case.
        content_type = content_type.split(';', 1)[0].strip().lower()
        if not content_type.startswith('image/'):
            filename = ''
            if 'filename=' in content_disp:
                # Keep only the filename value, not any trailing "; param=..." parts.
                filename = content_disp.split('filename=')[-1].split(';', 1)[0].strip('" ')
            if not filename:
                # Fall back to the URL's last path segment, ignoring query and fragment.
                filename = os.path.basename(image_url.split('?', 1)[0].split('#', 1)[0])
            guessed, _ = mimetypes.guess_type(filename)
            content_type = guessed if guessed and guessed.startswith('image/') else 'image/jpeg'

        encoded = base64.b64encode(payload).decode('utf-8')
        # Log a summary only: printing the whole base64 payload floods the logs.
        print(f"base64 {content_type}, {len(payload)} bytes")
        return f"data:{content_type};base64,{encoded}"
    except Exception as e:
        print(f"Error fetching image: {e}")
        return "data:image/jpeg;base64,"
import json
def _generate_reply(chat_messages, max_tokens):
    """Encode the messages with the module-level tokenizer and return the decoded model output."""
    request = ChatCompletionRequest(messages=chat_messages)
    encoded = tokenizer.encode_chat_completion(request)
    out_tokens, _ = generate(
        [encoded.tokens],
        model,
        images=[encoded.images],
        max_tokens=max_tokens,
        temperature=0.45,
        eos_id=tokenizer.instruct_tokenizer.tokenizer.eos_id,
    )
    return tokenizer.decode(out_tokens[0])


def _messages_from_json(raw_text):
    """Parse an API-style JSON conversation into mistral_common messages.

    Expects a non-empty JSON list of ``{"role": ..., "content": ...}`` objects.

    * A ``"user"`` entry's content is a list of parts. Parts with
      ``"type": "image"`` are fetched from their ``"url"``; any other part
      contributes its ``"text"``. Plain-string content is also accepted.
    * Every non-``"user"`` entry becomes an assistant turn whose text is
      ``content[0]["text"]`` (or the content itself if it is a string).
      NOTE(review): this includes any "system" role — confirm that is intended.

    Raises:
        ValueError, TypeError, KeyError, IndexError: if ``raw_text`` is not such
        a conversation. This includes JSON decode errors and pydantic
        validation errors, which subclass ValueError.
    """
    entries = json.loads(raw_text)
    if not isinstance(entries, list) or not entries:
        raise ValueError("expected a non-empty JSON list of chat messages")
    converted = []
    for entry in entries:
        content = entry['content']
        if entry['role'] == 'user':
            if isinstance(content, str):
                chunks = [TextChunk(text=content)]
            else:
                chunks = []
                for part in content:
                    if part['type'] == 'image':
                        print('inserting image')
                        chunks.append(ImageURLChunk(image_url=url_to_base64(part['url'])))
                    else:
                        chunks.append(TextChunk(text=part['text']))
            converted.append(UserMessage(content=chunks))
        else:
            text = content if isinstance(content, str) else content[0]['text']
            converted.append(AssistantMessage(content=text))
    return converted


def _messages_from_history(message, history):
    """Build messages from the Gradio chat history plus the current multimodal message.

    The history is assumed to be Gradio's tuple format: ``[user, bot]`` pairs.
    A file upload appears as a ``(path,)`` or ``(path, alt_text)`` tuple in the
    user slot with ``None`` as the bot reply. A user slot of ``None`` marks a
    reply to files sent without text.
    NOTE(review): verify against the installed Gradio version.

    Uploaded images accumulate until the next user turn, where they are placed
    before that turn's text.
    """
    converted = []
    pending_images = []
    for turn in history:
        user_part, bot_part = turn[0], turn[1]
        if isinstance(user_part, tuple):
            # Only the first element is a file path; a second one would be alt text.
            pending_images.append(user_part[0])
            continue
        text = user_part if isinstance(user_part, str) else ''
        if not text and not pending_images:
            continue
        chunks = [ImageURLChunk(image_url=image_to_base64(path)) for path in pending_images]
        if text:
            chunks.append(TextChunk(text=text))
        converted.append(UserMessage(content=chunks))
        if bot_part is not None:
            converted.append(AssistantMessage(content=bot_part))
        pending_images = []

    # Current turn: files may arrive as plain paths or as {"path": ...} dicts.
    paths = [f if isinstance(f, str) else f['path'] for f in (message.get('files') or [])]
    converted.append(UserMessage(
        content=[ImageURLChunk(image_url=image_to_base64(path)) for path in paths]
        + [TextChunk(text=message.get('text') or '')]
    ))
    return converted


@spaces.GPU(duration=90)
def run_inference(message, history):
    """Answer one chat turn from the Gradio multimodal ChatInterface.

    Two input formats are accepted:

    * API mode: ``message["text"]`` holds a JSON conversation (see
      ``_messages_from_json``). Output is capped at 2048 new tokens.
    * Chat mode, used whenever that text is not such a JSON conversation: the
      Gradio ``history`` plus the current text and files are converted (see
      ``_messages_from_history``). Output is capped at 512 new tokens.

    Only a failure to parse the JSON conversation selects chat mode. Errors
    raised while encoding or generating propagate to Gradio, instead of
    re-running generation on a misinterpreted prompt.

    Args:
        message: Multimodal message dict with ``"text"`` and ``"files"`` keys.
        history: Chat history supplied by Gradio.

    Returns:
        The decoded model reply as a string.
    """
    try:
        print("messages ", message['text'])
        chat_messages = _messages_from_json(message['text'])
        max_tokens = 2048
    except (ValueError, TypeError, KeyError, IndexError, AttributeError) as e:
        # Not a JSON conversation: treat it as an ordinary chat turn.
        print('using default chat format: ', e)
        chat_messages = _messages_from_history(message, history)
        max_tokens = 512
    # Print a count only; the messages embed full base64 images.
    print(f'final messages: {len(chat_messages)}')
    return _generate_reply(chat_messages, max_tokens)
# Multimodal chat UI backed by run_inference; queuing is enabled before the
# server is started.
demo = gr.ChatInterface(
    fn=run_inference,
    title="Pixtral 12B",
    multimodal=True,
    description="A demo chat interface with Pixtral 12B, deployed using Mistral Inference.",
)
demo.queue().launch()