Spaces:
Running on Zero
Running on Zero
| import gradio as gr | |
| from gradio.data_classes import FileData | |
| from huggingface_hub import snapshot_download | |
| from pathlib import Path | |
| import base64 | |
| import spaces | |
| import os | |
| from mistral_inference.transformer import Transformer | |
| from mistral_inference.generate import generate | |
| from mistral_common.tokens.tokenizers.mistral import MistralTokenizer | |
| from mistral_common.protocol.instruct.messages import UserMessage, AssistantMessage, TextChunk, ImageURLChunk | |
| from mistral_common.protocol.instruct.request import ChatCompletionRequest | |
# Local directory that receives the Pixtral params, weights and tokenizer file.
models_path = Path.home() / 'pixtral' / 'Pixtral'
models_path.mkdir(parents=True, exist_ok=True)

# Download only the three files named in allow_patterns into models_path.
snapshot_download(
    repo_id="mistral-community/pixtral-12b-240910",
    allow_patterns=["params.json", "consolidated.safetensors", "tekken.json"],
    local_dir=models_path,
)

# Module-level tokenizer and model; the chat handler below reads both.
tokenizer = MistralTokenizer.from_file(str(models_path) + "/tekken.json")
model = Transformer.from_folder(models_path)
def image_to_base64(image_path):
    """Read a local image file and return it as a base64 ``data:`` URL.

    The MIME type in the URL is guessed from the file extension. When the
    extension is unknown, or maps to something that is not an image type,
    ``image/jpeg`` is used, which is the label this function always emitted
    before.

    Args:
        image_path: Path of the image file to read.

    Returns:
        A string of the form ``data:<mime>;base64,<payload>``.

    Raises:
        OSError: If the file cannot be opened or read.
    """
    mime_type, _ = mimetypes.guess_type(str(image_path))
    if not mime_type or not mime_type.startswith("image/"):
        # Unknown or non-image extension: fall back to the historical label.
        mime_type = "image/jpeg"
    with open(image_path, 'rb') as img:
        encoded_string = base64.b64encode(img.read()).decode('utf-8')
    return f"data:{mime_type};base64,{encoded_string}"
| import requests | |
| import base64 | |
| import mimetypes | |
def _guess_image_mime(content_type, content_disposition, image_url):
    """Pick a MIME type from the response headers, falling back to the URL."""
    # Drop parameters such as "; charset=binary" so the data URL stays clean.
    mime_type = content_type.split(";", 1)[0].strip().lower()
    generic_types = ("application/octet-stream", "binary/octet-stream")
    if mime_type and mime_type not in generic_types:
        return mime_type

    filename = None
    if "filename=" in content_disposition:
        # Keep only the filename token; other parameters may follow after ";".
        after_key = content_disposition.split("filename=", 1)[1]
        filename = after_key.split(";", 1)[0].strip('" ')
    if not filename:
        # Fallback: last path segment of the URL, without query or fragment.
        bare_url = image_url.split("?", 1)[0].split("#", 1)[0]
        filename = os.path.basename(bare_url)

    guessed, _ = mimetypes.guess_type(filename)
    return guessed or "image/jpeg"


def url_to_base64(image_url):
    """Download an image and return it as a base64 ``data:`` URL.

    The MIME type comes from the ``Content-Type`` header. When that header is
    missing or generic, it is guessed from the ``Content-Disposition``
    filename or from the URL path, and defaults to ``image/jpeg``.

    Args:
        image_url: HTTP(S) URL of the image to fetch.

    Returns:
        ``data:<mime>;base64,<payload>`` on success. On any failure the error
        is printed and the empty placeholder ``data:image/jpeg;base64,`` is
        returned instead of raising.
    """
    try:
        # Browser-like headers. Accept-Encoding is deliberately left to
        # requests: forcing "br" makes servers send Brotli bodies that are
        # returned undecoded when no Brotli decoder is installed.
        headers = {
            "User-Agent": (
                "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
                "AppleWebKit/537.36 (KHTML, like Gecko) "
                "Chrome/122.0.0.0 Safari/537.36"
            ),
            "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,*/*;q=0.8",
            "Accept-Language": "en-US,en;q=0.9",
        }
        # The context manager releases the connection once the body is read.
        with requests.get(image_url, headers=headers, allow_redirects=True, timeout=15) as response:
            response.raise_for_status()
            content_type = _guess_image_mime(
                response.headers.get('Content-Type', ''),
                response.headers.get('Content-Disposition', ''),
                image_url,
            )
            payload = response.content

        base64_image = base64.b64encode(payload).decode('utf-8')
        # Log a summary only; the full payload can be megabytes of user data.
        print("base64 ", f"{content_type}, {len(payload)} bytes")
        return f"data:{content_type};base64,{base64_image}"
    except Exception as e:
        # Best-effort boundary: callers always get a data URL string back.
        print(f"Error fetching image: {e}")
        return "data:image/jpeg;base64,"
| import json | |
def _generate_reply(chat_messages, max_tokens):
    """Tokenize chat_messages, run one generation and return the decoded text."""
    completion_request = ChatCompletionRequest(messages=chat_messages)
    encoded = tokenizer.encode_chat_completion(completion_request)
    out_tokens, _ = generate(
        [encoded.tokens],
        model,
        images=[encoded.images],
        max_tokens=max_tokens,
        temperature=0.45,
        eos_id=tokenizer.instruct_tokenizer.tokenizer.eos_id,
    )
    return tokenizer.decode(out_tokens[0])


def _messages_from_json(raw_text):
    """Build chat messages from a JSON-encoded conversation held in raw_text."""
    print("messages ", raw_text)
    final_msg = []
    for turn in json.loads(raw_text):
        if turn['role'] == 'user':
            chunks = []
            for part in turn['content']:
                if part['type'] == 'image':
                    print('inserting image')
                    chunks.append(ImageURLChunk(image_url=url_to_base64(part['url'])))
                else:
                    chunks.append(TextChunk(text=part['text']))
            final_msg.append(UserMessage(content=chunks))
        else:
            # Any non-user role is treated as an assistant turn.
            final_msg.append(AssistantMessage(content=turn['content'][0]['text']))
    print('final msg ', final_msg)
    return final_msg


def _messages_from_chat(message, history):
    """Build chat messages from the chat UI's history plus the new message."""
    messages = []
    images = []
    print('\n\nmessage ', message)
    print('\n\nhistoery ', history)
    for couple in history:
        user_turn = couple[0]
        if isinstance(user_turn, tuple):
            # Tuple turns carry file paths; hold them for the next text turn.
            images.extend(user_turn)
            continue
        # A plain-text turn is a str and must be used whole: indexing it with
        # [1] would pick only its second character. Other shapes keep the
        # original [1] lookup. NOTE(review): confirm against the history
        # format of the installed Gradio version.
        text = user_turn if isinstance(user_turn, str) else user_turn[1]
        if text:
            image_chunks = [ImageURLChunk(image_url=image_to_base64(path)) for path in images]
            messages.append(UserMessage(content=image_chunks + [TextChunk(text=text)]))
            messages.append(AssistantMessage(content=couple[1]))
            images = []

    # File entries may be dicts with a "path" key or plain path strings.
    new_image_chunks = [
        ImageURLChunk(image_url=image_to_base64(file["path"] if isinstance(file, dict) else file))
        for file in message["files"]
    ]
    messages.append(UserMessage(content=new_image_chunks + [TextChunk(text=message["text"])]))
    print('\n\nfinal messageds', messages)
    return messages


def run_inference(message, history):
    """Chat handler: produce the model's reply to the incoming message.

    Two input styles are supported:

    1. ``message["text"]`` is a JSON list of turns, each with a ``role`` and a
       ``content`` list of ``{"type": "image", "url": ...}`` or text parts.
       Images are fetched by URL and up to 2048 tokens are generated.
    2. Otherwise the conversation is rebuilt from ``history`` and
       ``message["files"]`` / ``message["text"]``, and up to 512 tokens are
       generated.

    Args:
        message: Mapping with a ``"text"`` entry and a ``"files"`` entry.
        history: Sequence of ``(user_turn, assistant_reply)`` pairs.

    Returns:
        The decoded text generated by the model.
    """
    try:
        return _generate_reply(_messages_from_json(message['text']), max_tokens=2048)
    except Exception as e:
        # Deliberate best-effort: any failure on the JSON path (including
        # text that simply is not JSON) falls through to the regular chat path.
        print('usig deqfualt ', e)
    return _generate_reply(_messages_from_chat(message, history), max_tokens=512)
# Multimodal chat UI whose handler is run_inference.
demo = gr.ChatInterface(
    fn=run_inference,
    title="Pixtral 12B",
    multimodal=True,
    description="A demo chat interface with Pixtral 12B, deployed using Mistral Inference.",
)

# Enable the request queue, then start serving.
demo.queue().launch()