# Hugging Face Space snapshot (status: Sleeping) - revision b0095ef, 2,203 bytes.
import io
import threading
from typing import List

import torch
from fastapi import FastAPI, File, UploadFile, Form, HTTPException
from fastapi.concurrency import run_in_threadpool
from PIL import Image
from pydantic import BaseModel
from transformers import AutoModel, AutoTokenizer
app = FastAPI()

# Hugging Face checkpoint used for BOTH the model and its tokenizer; keeping it
# in one constant prevents the two from silently being loaded from different
# repos when the checkpoint is changed.
MODEL_ID = 'openbmb/MiniCPM-V-2_6'

# Load model and tokenizer once, at import time, so every request reuses them.
# NOTE(review): trust_remote_code=True executes Python code shipped in the model
# repo; consider pinning `revision=` to a known commit.
model = AutoModel.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    attn_implementation='sdpa',
    torch_dtype=torch.bfloat16,
)
# eval() switches off training-only behaviour (e.g. dropout); cuda() moves the
# weights to the GPU, so this service requires a CUDA device at startup.
model = model.eval().cuda()
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
class FewshotExample(BaseModel):
    """A single few-shot demonstration: an image, a question, and its answer."""

    # Image content as raw bytes.
    # NOTE(review): pydantic does not base64-decode `bytes` fields from JSON by
    # default - confirm how clients encode this if it is used as a JSON body.
    image: bytes
    question: str  # question asked about `image`
    answer: str  # expected answer for this demonstration
# NOTE(review): the `predict_with_fewshot` route does not take this model; it
# receives the same data as multipart form fields instead.
class PredictRequest(BaseModel):
    """Few-shot demonstrations plus the image/question pair to be answered."""

    fewshot_examples: List[FewshotExample]  # demonstration examples
    test_image: bytes  # image the final question refers to (raw bytes)
    test_question: str  # question to answer about `test_image`
# Guards the shared model. Generation is handed to a worker thread so it no
# longer blocks the event loop; this lock keeps at most one model.chat call
# running at a time, matching the previous behaviour where the blocking call
# on the event loop serialized all requests.
_MODEL_LOCK = threading.Lock()


async def _read_rgb_image(upload: UploadFile) -> Image.Image:
    """Read *upload* and decode it into an RGB PIL image.

    Raises:
        HTTPException: 400 if the bytes cannot be decoded as an image
            (Pillow's UnidentifiedImageError and truncated-data errors are
            OSError subclasses) or exceed Pillow's decompression-bomb limit.
    """
    content = await upload.read()
    try:
        # Image.open is lazy; convert() forces the full decode, so decoding
        # errors surface here instead of later inside the model.
        return Image.open(io.BytesIO(content)).convert('RGB')
    except (OSError, Image.DecompressionBombError) as e:
        raise HTTPException(
            status_code=400,
            detail=f"Could not decode image {upload.filename!r}: {e}",
        ) from e


def _generate(msgs):
    """Run model.chat on a prepared conversation; executed in a worker thread."""
    with _MODEL_LOCK:
        # image=None because every image is already embedded in `msgs`.
        return model.chat(image=None, msgs=msgs, tokenizer=tokenizer)


@app.post("/predict_with_fewshot")
async def predict_with_fewshot(
    fewshot_images: List[UploadFile] = File(...),
    fewshot_questions: List[str] = Form(...),
    fewshot_answers: List[str] = Form(...),
    test_image: UploadFile = File(...),
    test_question: str = Form(...)
):
    """Answer ``test_question`` about ``test_image`` using few-shot examples.

    The i-th few-shot image, question and answer become a user turn
    ``[image, question]`` followed by an assistant turn ``[answer]``; the test
    image and question form the final user turn of the conversation passed
    to ``model.chat``.

    Returns:
        ``{"answer": <model.chat output>}``.

    Raises:
        HTTPException: 400 if the three few-shot lists differ in length or an
            upload cannot be decoded as an image; 500 for any other failure.
    """
    # Each few-shot example needs exactly one image, question and answer.
    if len(fewshot_images) != len(fewshot_questions) or len(fewshot_questions) != len(fewshot_answers):
        raise HTTPException(status_code=400, detail="Number of few-shot images, questions, and answers must match.")

    try:
        msgs = []
        for fs_img, fs_q, fs_a in zip(fewshot_images, fewshot_questions, fewshot_answers):
            img = await _read_rgb_image(fs_img)
            msgs.append({'role': 'user', 'content': [img, fs_q]})
            msgs.append({'role': 'assistant', 'content': [fs_a]})

        # Test example: the final user turn the model must answer.
        test_img = await _read_rgb_image(test_image)
        msgs.append({'role': 'user', 'content': [test_img, test_question]})

        # Generation is synchronous and GPU-bound; run it off the event loop.
        answer = await run_in_threadpool(_generate, msgs)
    except HTTPException:
        # Deliberate client errors (e.g. undecodable image) keep their status.
        raise
    except Exception as e:
        # NOTE(review): str(e) exposes internal error text to the client.
        raise HTTPException(status_code=500, detail=f"Error processing request: {str(e)}") from e
    return {"answer": answer}