from fastapi import FastAPI, File, UploadFile, Form, HTTPException
from typing import List
from pydantic import BaseModel
from PIL import Image
import io
from transformers import AutoModel, AutoTokenizer
import torch

app = FastAPI()

# Load the model and tokenizer once at startup. MiniCPM-V-2.6 requires
# trust_remote_code for its custom modeling code, and bfloat16 inference
# here assumes a CUDA GPU is available.
model = AutoModel.from_pretrained('openbmb/MiniCPM-V-2_6', trust_remote_code=True,
                                  attn_implementation='sdpa', torch_dtype=torch.bfloat16)
model = model.eval().cuda()
tokenizer = AutoTokenizer.from_pretrained('openbmb/MiniCPM-V-2_6', trust_remote_code=True)
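
# A hedged fallback sketch (an assumption about deployment, not part of the
# server logic): .cuda() above fails on machines without a GPU, so the device
# could instead be selected at startup. Note bfloat16 on CPU is very slow.
#
#   device = 'cuda' if torch.cuda.is_available() else 'cpu'
#   model = model.eval().to(device)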

# Pydantic schemas describing the request shape. Note: these are not used by
# the multipart endpoint below, which takes files and form fields directly;
# raw image bytes would need to be base64-encoded to travel in a JSON body.
class FewshotExample(BaseModel):
    image: bytes
    question: str
    answer: str

class PredictRequest(BaseModel):
    fewshot_examples: List[FewshotExample]
    test_image: bytes
    test_question: str

@app.post("/predict_with_fewshot")
async def predict_with_fewshot(
    fewshot_images: List[UploadFile] = File(...),
    fewshot_questions: List[str] = Form(...),
    fewshot_answers: List[str] = Form(...),
    test_image: UploadFile = File(...),
    test_question: str = Form(...)
):
    # Validate input lengths
    if len(fewshot_images) != len(fewshot_questions) or len(fewshot_questions) != len(fewshot_answers):
        raise HTTPException(status_code=400, detail="Number of few-shot images, questions, and answers must match.")
    
    msgs = []
    try:
        # Build the interleaved few-shot conversation: each example becomes a
        # user turn ([image, question]) followed by an assistant turn (answer),
        # matching the msgs format expected by MiniCPM-V's chat() API.
        for fs_img, fs_q, fs_a in zip(fewshot_images, fewshot_questions, fewshot_answers):
            img_content = await fs_img.read()
            img = Image.open(io.BytesIO(img_content)).convert('RGB')
            msgs.append({'role': 'user', 'content': [img, fs_q]})
            msgs.append({'role': 'assistant', 'content': [fs_a]})
        
        # Append the actual query as the final user turn.
        test_img_content = await test_image.read()
        test_img = Image.open(io.BytesIO(test_img_content)).convert('RGB')
        msgs.append({'role': 'user', 'content': [test_img, test_question]})
        
        # Run inference; the images travel inside msgs, so image=None here.
        answer = model.chat(
            image=None,
            msgs=msgs,
            tokenizer=tokenizer
        )
        
        return {"answer": answer}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing request: {str(e)}")
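
# What follows is a minimal usage sketch, not part of the endpoint itself: a
# __main__ guard to serve the app with uvicorn, plus a commented-out example
# client call. File names ("cat1.jpg", etc.) and the port are placeholders.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)

# Example client call with the `requests` library (repeating a field name in
# the tuple lists is how multipart form data encodes list-valued fields):
#
#   import requests
#   files = [
#       ("fewshot_images", open("cat1.jpg", "rb")),
#       ("fewshot_images", open("cat2.jpg", "rb")),
#       ("test_image", open("query.jpg", "rb")),
#   ]
#   data = [
#       ("fewshot_questions", "What animal is this?"),
#       ("fewshot_questions", "What animal is this?"),
#       ("fewshot_answers", "A cat."),
#       ("fewshot_answers", "A tabby cat."),
#       ("test_question", "What animal is this?"),
#   ]
#   r = requests.post("http://localhost:8000/predict_with_fewshot",
#                     files=files, data=data)
#   print(r.json()["answer"])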