"""FastAPI service exposing the Molmo-7B-D vision-language model.

Endpoints (both require the ``X-API-Key`` header to match the ``API_KEY``
environment variable):

* ``POST /upload`` - multipart image upload plus a ``text`` query parameter.
* ``POST /base64`` - JSON body ``{"image": <base64>, "text": <prompt>}``.

Both return ``{"response": <generated text>}``.
"""
import base64
import io
import os
import secrets
import threading
from typing import Optional

import torch
from fastapi import Depends, FastAPI, File, HTTPException, UploadFile
from fastapi.security.api_key import APIKeyHeader
from PIL import Image
from pydantic import BaseModel
from starlette.concurrency import run_in_threadpool
from starlette.status import HTTP_400_BAD_REQUEST, HTTP_403_FORBIDDEN
from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig

app = FastAPI()

# Load the processor and model once at import time; every request reuses them.
processor = AutoProcessor.from_pretrained(
    'allenai/Molmo-7B-D-0924',
    trust_remote_code=True,
    torch_dtype='auto',
    device_map='auto'
)
model = AutoModelForCausalLM.from_pretrained(
    'allenai/Molmo-7B-D-0924',
    trust_remote_code=True,
    torch_dtype='auto',
    device_map='auto'
)

# API Key setup.
# NOTE: if API_KEY is unset, get_api_key rejects every request (fail closed)
# instead of accepting requests that simply omit the header.
API_KEY = os.environ.get("API_KEY")
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)

# Errors that mean "the client sent bytes that are not a decodable image":
# binascii.Error (bad base64) is a ValueError subclass; PIL raises
# UnidentifiedImageError / truncated-file errors as OSError subclasses; and
# DecompressionBombError guards against absurdly large images.
_IMAGE_DECODE_ERRORS = (OSError, ValueError, Image.DecompressionBombError)

# Serializes model execution. Handlers offload generation to a thread pool so
# the event loop stays responsive, but only one generation runs on the model
# at a time (matching the original, which ran generation inline and therefore
# one request at a time).
_generation_lock = threading.Lock()


async def get_api_key(provided_key: Optional[str] = Depends(api_key_header)) -> str:
    """FastAPI dependency that authorizes a request via the ``X-API-Key`` header.

    Returns:
        The validated key.

    Raises:
        HTTPException: 403 if the server has no API_KEY configured, the header
            is missing, or the key does not match.
    """
    # Compare as bytes: compare_digest on str rejects non-ASCII input with a
    # TypeError, and header values are client-controlled.
    if (
        API_KEY
        and provided_key is not None
        and secrets.compare_digest(provided_key.encode(), API_KEY.encode())
    ):
        return provided_key
    raise HTTPException(
        status_code=HTTP_403_FORBIDDEN, detail="Could not validate credentials"
    )


def _open_image(data: bytes) -> Image.Image:
    """Decode raw bytes into a fully loaded PIL image.

    Raises one of ``_IMAGE_DECODE_ERRORS`` if the bytes are not a valid image.
    """
    image = Image.open(io.BytesIO(data))
    # Image.open only parses the header; force the full decode now so that
    # corrupt or truncated data fails here (client error) rather than later
    # inside the model pipeline (server error).
    image.load()
    return image


def process_image_and_text(image, text, *, max_new_tokens: int = 200):
    """Run Molmo on one image and prompt and return the generated text.

    Blocking and potentially long-running; call it from a worker thread
    (e.g. ``run_in_threadpool``) rather than directly inside an async handler.

    Args:
        image: PIL image to describe/answer about.
        text: Prompt text.
        max_new_tokens: Upper bound on generated tokens (default 200).

    Returns:
        The decoded continuation, with special tokens stripped.
    """
    inputs = processor.process(images=[image], text=text)
    with _generation_lock, torch.no_grad():
        # Move each tensor to the model's device and add a leading batch
        # dimension of size 1.
        inputs = {k: v.to(model.device).unsqueeze(0) for k, v in inputs.items()}
        output = model.generate_from_batch(
            inputs,
            GenerationConfig(max_new_tokens=max_new_tokens, stop_strings="<|endoftext|>"),
            tokenizer=processor.tokenizer,
        )
    # The output row contains the prompt tokens followed by the new ones; keep
    # only the newly generated part.
    generated_tokens = output[0, inputs['input_ids'].size(1):]
    return processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)


# Request body for /base64: `image` is base64-encoded image bytes, `text` is
# the prompt.
class Base64Request(BaseModel):
    image: str
    text: str


@app.post("/upload")
async def upload_image(
    file: UploadFile = File(...),
    text: str = "",
    api_key: str = Depends(get_api_key),
):
    """Run the model on an uploaded image file and a prompt."""
    contents = await file.read()
    try:
        image = _open_image(contents)
    except _IMAGE_DECODE_ERRORS as err:
        raise HTTPException(
            status_code=HTTP_400_BAD_REQUEST, detail="Invalid image file"
        ) from err
    response = await run_in_threadpool(process_image_and_text, image, text)
    return {"response": response}


@app.post("/base64")
async def process_base64(request: Base64Request, api_key: str = Depends(get_api_key)):
    """Run the model on a base64-encoded image and a prompt."""
    try:
        image = _open_image(base64.b64decode(request.image))
    except _IMAGE_DECODE_ERRORS as err:
        raise HTTPException(
            status_code=HTTP_400_BAD_REQUEST, detail="Invalid base64 image"
        ) from err
    response = await run_in_threadpool(process_image_and_text, image, request.text)
    return {"response": response}