# test-four / main.py
from fastapi import FastAPI, File, UploadFile, HTTPException, Form
from pydantic import BaseModel
from deepface import DeepFace
from transformers import pipeline
import io
import base64
import pandas as pd
import numpy as np
import torch
#from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
from fastapi.staticfiles import StaticFiles
from fastapi.responses import HTMLResponse
from huggingface_hub import InferenceClient
app = FastAPI()

# Allow all origins during development; restrict allow_origins for production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
class ImageInfo(BaseModel):
    # image: str
    image: UploadFile
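# Note: ImageInfo is not referenced by any route below; the /generate_story endpoint
# reads the upload directly via File(...) and Form(...).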
# # Define quantization parameters through BitsAndBytesConfig from transformers.
# bnb_config = BitsAndBytesConfig(
#     load_in_4bit=True,
#     bnb_4bit_use_double_quant=True,
#     bnb_4bit_quant_type="nf4",
#     bnb_4bit_compute_dtype=torch.bfloat16
# )
get_blip = pipeline("image-to-text", model="Salesforce/blip-image-captioning-large")
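# Illustrative note: the image-to-text pipeline returns a list of dicts shaped like
#   [{"generated_text": "<caption text>"}]
# which is why captioner() below indexes caption[0]['generated_text'].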
# Use DeepFace to detect age, gender, and emotion.
def analyze_face(image):
    # Convert the PIL image to a numpy array.
    image_array = np.array(image)
    face_result = DeepFace.analyze(image_array, actions=['age', 'gender', 'emotion'], enforce_detection=False)
    # Convert the analysis result to a DataFrame.
    df = pd.DataFrame(face_result)
    # [0] accesses the value in the first row of each DataFrame column.
    return df['dominant_gender'][0], df['age'][0], df['dominant_emotion'][0]
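# Usage sketch (the file name "example.jpg" is a placeholder, not part of this app):
#   gender, age, emotion = analyze_face(Image.open("example.jpg"))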
# Use BLIP to generate a caption.
# image_to_base64_str converts a PIL image to a base64-encoded string.
def image_to_base64_str(pil_image):
    byte_arr = io.BytesIO()
    pil_image.save(byte_arr, format='PNG')
    byte_arr = byte_arr.getvalue()
    return str(base64.b64encode(byte_arr).decode('utf-8'))
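# If needed, the encoding can be reversed with the standard library, e.g.
#   Image.open(io.BytesIO(base64.b64decode(b64_str)))
# where b64_str is the string returned above.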
# captioner runs the BLIP pipeline on an image.
def captioner(image):
    base64_image = image_to_base64_str(image)
    caption = get_blip(base64_image)
    # [0] accesses the first element of the list returned by the pipeline.
    return caption[0]['generated_text']
def get_image_info(image):
    # Call captioner() for the image caption.
    image_caption = captioner(image)
    # Call analyze_face() for the face attributes.
    gender, age, emotion = analyze_face(image)
    return image_caption, gender, age, emotion
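# Usage sketch: for a PIL image `img`,
#   caption, gender, age, emotion = get_image_info(img)
# returns the BLIP caption plus DeepFace's dominant gender, estimated age, and dominant emotion.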
# # Load the model with quantization.
# model_id = "mistralai/Mistral-7B-Instruct-v0.1"
# # model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config, device_map="auto")
# # model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
# # No GPU: RuntimeError: No GPU found. A GPU is needed for quantization.
# model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", low_cpu_mem_usage=True)
# tokenizer = AutoTokenizer.from_pretrained(model_id)
# def generate_story(image, storyType, length):
#     image_caption, gender, age, emotion = get_image_info(image)
#     device = "cuda:0"
#     messages = [
#         {
#             "role": "user",
#             "content": f"generate a {storyType} story for the person in the image which describes a scenario:{image_caption}. Please also notice the person's age:{age}, gender:{gender} and emotion:{emotion} in the image\n\n"
#         }
#     ]
#     encodeds = tokenizer.apply_chat_template(messages, return_tensors="pt")
#     model_inputs = encodeds.to(device)
#     generated_ids = model.generate(model_inputs, max_new_tokens=length, do_sample=True)
#     decoded = tokenizer.batch_decode(generated_ids)
#     generated_Story = decoded[0].replace("<s>", "").replace("</s>", "").replace("[INST]", "").replace("[/INST]", "").strip()
#     return generated_Story
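# `client` and `generate_kwargs` are used by generate_story() below but are not defined
# elsewhere in this file. The following is a minimal sketch that assumes the Mistral
# model id from the commented-out code above and typical sampling defaults; adjust to
# the model and parameters actually deployed.
client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.1")
generate_kwargs = dict(
    temperature=0.7,        # assumed value
    max_new_tokens=512,     # assumed value
    top_p=0.95,             # assumed value
    repetition_penalty=1.1  # assumed value
)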
def generate_story(image, length):
    image_caption, gender, age, emotion = get_image_info(image)
    # prompt = f"[INST] generate a story for the person in the image which describes a scenario:{image_caption}. Please also notice the person's age:{age}, gender:{gender} and emotion:{emotion} in the image.[/INST]"
    prompt = (
        f"[INST] Develop a story inspired by the image and its caption: {image_caption}. "
        f"Factor in the person's age: {age} (e.g., child, teenager, adult, elder), gender: {gender}, and emotion: {emotion}. "
        f"Adjust the language style accordingly; use simple words and a child-friendly tone for children, a mature tone for adults, and a considerate, reflective tone for elders. "
        f"The generated story should be less than {length} words. Tailor the narrative to fit the specified story length, ensuring a satisfying and conclusive ending. "
        f"Ensure the narrative resonates with the appropriate age group for a more nuanced and relatable storytelling experience. [/INST]"
    )
    # Stream tokens from the Inference API, yielding the accumulated text after each token.
    stream = client.text_generation(prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
    output = ""
    for response in stream:
        output += response.token.text
        yield output
# Route to serve the HTML form.
@app.get("/")
async def read_item():
    content = open("app/static/index.html", "r").read()
    return HTMLResponse(content=content, status_code=200)
@app.post("/generate_story")
async def generate_story_endpoint(
    image: UploadFile = File(...),
    length: int = Form(...)
):
    try:
        contents = await image.read()
        pil_image = Image.open(io.BytesIO(contents))
        # generate_story() yields progressively longer text; keep the last chunk,
        # which contains the complete story.
        generated_story = ""
        for partial_story in generate_story(pil_image, length):
            generated_story = partial_story
        return {"generated_story": generated_story}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error generating story: {str(e)}")
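# Example request (sketch; assumes the app is served locally on port 7860, e.g. with
# `uvicorn main:app --host 0.0.0.0 --port 7860`, and that "photo.jpg" is a local file):
#   curl -X POST -F "image=@photo.jpg" -F "length=300" http://localhost:7860/generate_story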
# Mount the static directory to serve HTML, JS, and CSS files.
# Registered last so this catch-all mount at "/" does not shadow the API routes above.
app.mount("/", StaticFiles(directory="static", html=True), name="static")