from fastapi import FastAPI, File, UploadFile, HTTPException, Form
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.responses import HTMLResponse
from pydantic import BaseModel
from deepface import DeepFace
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from PIL import Image
import io
import base64
import pandas as pd
import numpy as np
import torch

app = FastAPI()

# Allow all origins during development; restrict for production
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class ImageInfo(BaseModel):
    # image: str
    image: UploadFile


# Define quantization parameters through the BitsAndBytesConfig from transformers
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# BLIP pipeline for image captioning
get_blip = pipeline("image-to-text", model="Salesforce/blip-image-captioning-large")


# Use DeepFace to detect age, gender, and emotion
def analyze_face(image):
    # Convert the PIL image to a numpy array
    image_array = np.array(image)
    face_result = DeepFace.analyze(
        image_array,
        actions=['age', 'gender', 'emotion'],
        enforce_detection=False,
    )
    # Convert the result to a DataFrame
    df = pd.DataFrame(face_result)
    # The [0] accesses the value in the first row of each DataFrame column
    return df['dominant_gender'][0], df['age'][0], df['dominant_emotion'][0]


# Use BLIP to generate a caption
# image_to_base64_str converts a PIL image to a base64 string
def image_to_base64_str(pil_image):
    byte_arr = io.BytesIO()
    pil_image.save(byte_arr, format='PNG')
    byte_arr = byte_arr.getvalue()
    return str(base64.b64encode(byte_arr).decode('utf-8'))


# captioner takes an image and returns the BLIP-generated caption
def captioner(image):
    base64_image = image_to_base64_str(image)
    caption = get_blip(base64_image)
    # The [0] accesses the first element of the list returned by the pipeline
    return caption[0]['generated_text']


def get_image_info(image):
    # Call captioner() for a scene caption
    image_caption = captioner(image)
    # Call analyze_face() for face attributes
    gender, age, emotion = analyze_face(image)
    return image_caption, gender, age, emotion


# Load the model (quantized loading is commented out below because no GPU is available)
model_id = "mistralai/Mistral-7B-Instruct-v0.1"
# model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config, device_map="auto")
# No GPU: RuntimeError: No GPU found. A GPU is needed for quantization.
# model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", low_cpu_mem_usage=True)
tokenizer = AutoTokenizer.from_pretrained(model_id)


def generate_story(image, storyType, length):
    image_caption, gender, age, emotion = get_image_info(image)
    # Keep the inputs on the same device as the model (a hard-coded "cuda:0" fails on CPU-only machines)
    device = model.device
    messages = [
        {
            "role": "user",
            "content": f"generate a {storyType} story for the person in the image which describes a scenario: {image_caption}. "
            f"Please also notice the person's age: {age}, gender: {gender} and emotion: {emotion} in the image\n\n",
        }
    ]
    # TODO (translated): should the prompt also ask the model to match sentence complexity
    # to the person's age, e.g. use simple sentences for a young child?
    encodeds = tokenizer.apply_chat_template(messages, return_tensors="pt")
    model_inputs = encodeds.to(device)
    generated_ids = model.generate(model_inputs, max_new_tokens=length, do_sample=True)
    decoded = tokenizer.batch_decode(generated_ids)
    # Strip the special tokens (<s>, </s>) and instruction markers from the decoded text
    generated_Story = (
        decoded[0]
        .replace("<s>", "")
        .replace("</s>", "")
        .replace("[INST]", "")
        .replace("[/INST]", "")
        .strip()
    )
    return generated_Story


# Mount the static directory to serve HTML, JS, and CSS files.
# Note: a mount at "/" is matched before the routes declared below, so the "/" route
# only takes effect if this mount is moved to a sub-path (e.g. "/static").
app.mount("/", StaticFiles(directory="static", html=True), name="static")


# Additional route to serve the HTML form
@app.get("/")
async def read_item():
    with open("app/static/index.html", "r") as f:
        content = f.read()
    return HTMLResponse(content=content, status_code=200)


@app.post("/generate_story")
async def generate_story_endpoint(
    image: UploadFile = File(...),
    storyType: str = Form(...),
    length: int = Form(...),
):
    try:
        contents = await image.read()
        pil_image = Image.open(io.BytesIO(contents))
        generated_story = generate_story(pil_image, storyType, length)
        return {"generated_story": generated_story}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error generating story: {str(e)}")
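

# --- Example usage (illustrative sketch, not part of the application) ---
# A minimal client for the /generate_story endpoint, assuming the app is served locally
# on port 8000 (e.g. `uvicorn main:app`; the module name "main" and the file "photo.jpg"
# are placeholders, not defined in this project):
#
#   import requests
#
#   with open("photo.jpg", "rb") as f:
#       resp = requests.post(
#           "http://127.0.0.1:8000/generate_story",
#           files={"image": ("photo.jpg", f, "image/jpeg")},
#           data={"storyType": "adventure", "length": 200},
#       )
#   resp.raise_for_status()
#   print(resp.json()["generated_story"])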