import gradio as gr
from deepface import DeepFace
from transformers import pipeline
import io
import base64
import pandas as pd
import numpy as np
from huggingface_hub import InferenceClient
from langchain.text_splitter import TokenTextSplitter
# from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceBgeEmbeddings
from langchain.vectorstores import Chroma
# from langchain.chain import RetrievalQA
# from langchain import PromptTemplate
get_blip = pipeline("image-to-text", model="Salesforce/blip-image-captioning-large")
# Use DeepFace to detect age, gender, and emotion (happy, neutral, surprise, sad, angry, fear, disgust)
def analyze_face(image):
# Convert the PIL image to a NumPy array for DeepFace
image_array = np.array(image)
face_result = DeepFace.analyze(image_array, actions=['age','gender','emotion'], enforce_detection=False)
# Convert the analysis result to a DataFrame
df = pd.DataFrame(face_result)
# [0] accesses the value in the first row of each DataFrame column, i.e. the first detected face
return df['dominant_gender'][0], df['age'][0], df['dominant_emotion'][0]
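# For reference, DeepFace.analyze returns one entry per detected face; a result looks roughly like the
# hypothetical example below (values are illustrative only):
# [{'age': 27, 'dominant_gender': 'Woman', 'dominant_emotion': 'happy', ...}]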
# Use BLIP to generate a caption
# image_to_base64_str converts a PIL image to a base64-encoded string
def image_to_base64_str(pil_image):
byte_arr = io.BytesIO()
pil_image.save(byte_arr, format='PNG')
byte_arr = byte_arr.getvalue()
return str(base64.b64encode(byte_arr).decode('utf-8'))
# captioner takes an image and returns its BLIP-generated caption
def captioner(image):
base64_image = image_to_base64_str(image)
caption = get_blip(base64_image)
return caption[0]['generated_text']
# The pipeline returns a list of results; [0]['generated_text'] takes the caption text from the first one.
def get_image_info(image):
#call captioner() function
image_caption = captioner(image)
#call analyze_face() function
gender, age, emotion = analyze_face(image)
# Return the caption together with the detected face attributes
return image_caption, gender, age, emotion
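# Example usage (commented-out sketch; "photo.jpg" is a hypothetical file):
# from PIL import Image
# caption, gender, age, emotion = get_image_info(Image.open("photo.jpg"))
# print(caption, gender, age, emotion)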
# loading the embedding model
model_name = "BAAI/bge-large-en-v1.5"
model_kwargs = {'device':'cpu'}
#encode_kwargs = {'normalize_embeddings':False}
# The embeddings will be normalized; normalization makes cosine similarity (angular distance) calculations more effective
# because the comparison is based on the directional similarity between vectors.
encode_kwargs = {'normalize_embeddings':True}
# initialize embeddings
embeddings = HuggingFaceBgeEmbeddings(model_name=model_name, model_kwargs=model_kwargs, encode_kwargs=encode_kwargs)
print("embeddings model loaded....................")
# load the txt file
with open("story.txt", "r") as f:
# r: read mode, reading only
state_of_the_union = f.read()
# read the file into a single string
# split the content into chunks
text_splitter = TokenTextSplitter(chunk_size=200, chunk_overlap=20)
# TokenTextSplitter() splits on token boundaries, so chunk_size and chunk_overlap are measured in tokens
# each chunk overlaps the previous chunk by 20 tokens
texts = text_splitter.split_text(state_of_the_union)
print("...........................................")
# print the first chunk
print("text[0]: ", texts[0])
# Create embeddings for the chunks with the bge model, then save these vectors into a Chroma vector database.
# Use an HNSW (hierarchical navigable small world) index to make searching efficient,
# and cosine similarity to measure similarity (similarity is what the similarity search relies on).
# HNSW builds a graph-based index for approximate nearest neighbor searches;
# it organizes the data into an efficient structure that supports rapid retrieval (it speeds up the search).
# Cosine similarity tells the HNSW algorithm how to measure the distance between vectors:
# by setting the space to cosine, the index measures the vectors' similarity with cosine similarity.
vector_store = Chroma.from_texts(texts, embeddings, collection_metadata = {"hnsw:space":"cosine"}, persist_directory="stores/story_cosine" )
print("vector store created........................")
load_vector_store = Chroma(persist_directory="stores/story_cosine", embedding_function=embeddings)
# persist_directory="stores/story_cosine": load the existing vector store from "stores/story_cosine"
# embedding_function=embeddings: use the bge embedding model when adding new data to the vector store
# Only retrieve the 3 most similar documents from the dataset
retriever = load_vector_store.as_retriever(search_kwargs={"k":3})
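# Illustrative sketch (commented out): the retriever returns at most k=3 Document objects.
# The query string below is a hypothetical example.
# example_docs = retriever.get_relevant_documents("a smiling young girl on a swing")
# for doc in example_docs:
#     print(doc.page_content[:80])
# The underlying store can also expose the cosine distances directly:
# print(load_vector_store.similarity_search_with_score("a smiling young girl on a swing", k=3))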
client = InferenceClient(
"mistralai/Mistral-7B-Instruct-v0.1"
)
def generate(image, temperature=0.9, max_new_tokens=1500, top_p=0.95, repetition_penalty=1.0):
image_caption, gender, age, emotion = get_image_info(image)
print("............................................")
print("image_caption:", image_caption)
print("age:", age)
print("gender:", gender)
print("emotion:", emotion)
print("............................................")
query = f"{image_caption}. {emotion} {age} years old {gender}"
# retrieve documents based on query
documents = retriever.get_relevant_documents(query)
# Embedding the query and comparing the query embedding with the chunk embeddings are handled internally by the get_relevant_documents() method.
# Embedding the query: when a query is made, the retriever first converts the query text into a vector using the same embedding model
# that was used to create the document vectors in the store. This ensures the query vector and document vectors are compatible for similarity comparisons.
# The similarity between the query vector and the chunk vectors is measured with
# cosine similarity over the HNSW index, because we configured the vector store with {"hnsw:space":"cosine"}.
# Both steps are therefore determined by how the vector store was set up:
# get_relevant_documents() uses the embedding function specified when the Chroma database was created.
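# Roughly equivalent manual steps (commented-out sketch, not used by the app):
# query_vector = embeddings.embed_query(query)  # same bge model as the stored chunks
# manual_docs = load_vector_store.similarity_search_by_vector(query_vector, k=3)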
if documents:
print("document:", dir(documents[0]))
# dir() lists the methods and attributes of the first Document object
print(documents[0])
print(".....................................")
print(documents)
else:
print("no documents")
# dir(documents[0]):
"""
document: ['Config', '__abstractmethods__', '__annotations__', '__class__', '__class_vars__', '__config__', '__custom_root_type__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__exclude_fields__',
'__fields__', '__fields_set__', '__format__', '__ge__', '__get_validators__', '__getattribute__', '__getstate__', '__gt__', '__hash__', '__include_fields__', '__init__', '__init_subclass__', '__iter__', '__json_encoder__',
'__le__', '__lt__', '__module__', '__ne__', '__new__', '__post_root_validators__', '__pre_root_validators__', '__pretty__', '__private_attributes__', '__reduce__', '__reduce_ex__', '__repr__', '__repr_args__', '__repr_name__',
'__repr_str__', '__rich_repr__', '__schema_cache__', '__setattr__', '__setstate__', '__signature__', '__sizeof__', '__slots__', '__str__', '__subclasshook__', '__try_update_forward_refs__', '__validators__', '_abc_impl', '_calculate_keys',
'_copy_and_set_values', '_decompose_class', '_enforce_dict_if_root', '_get_value', '_init_private_attributes', '_iter', 'construct', 'copy', 'dict', 'from_orm', 'get_lc_namespace', 'is_lc_serializable', 'json', 'lc_attributes', 'lc_id',
'lc_secrets', 'metadata', 'page_content', 'parse_file', 'parse_obj', 'parse_raw', 'schema', 'schema_json', 'to_json', 'to_json_not_implemented', 'type', 'update_forward_refs', 'validate']
"""
# context = ' '.join([doc.page_content for doc in documents])
#context = '\n'.join([f"Document {index + 1}: {doc}" for index, doc in enumerate(documents)])
# make the documents' format more clear
context = '\n'.join([f"Document {index + 1}: {doc.page_content}" for index, doc in enumerate(documents)])
#prompt = f"[INS] Generate a story based on person’s emotion: {emotion}, age: {age}, gender: {gender} of the image, and image’s caption: {image_caption}. Please use simple words and a child-friendly tone for children, a mature tone for adults, and a considerate, reflective tone for elders.[/INS]"
print("....................................................................")
print("context:",context)
#prompt = f"[INS] Generate a story based on person’s emotion: {emotion}, age: {age}, gender: {gender} of the image, and image’s caption: {image_caption}. The following are some sentence examples: {context}[/INS]"
prompt = (
f"[INS] Please generate a detailed and engaging story based on the person's emotion: {emotion}, "
f"age: {age}, and gender: {gender} shown in the image. Begin with the scene described in the image's caption: '{image_caption}'. "
f"Just use the following example story plots and formats as an inspiration: "
f"{context} "
f"The generated story should include a beginning, middle, and end, and the complete story should approximately be {max_new_tokens} words.[/INS]"
# f"Feel free to develop a complete story in depth and the generated story should approximately be {max_new_tokens} words.[/INS]"
)
temperature = float(temperature)
if temperature < 1e-2:
temperature = 1e-2
top_p = float(top_p)
generate_kwargs = dict(
temperature=temperature,
max_new_tokens=max_new_tokens,
top_p=top_p,
repetition_penalty=repetition_penalty,
do_sample=True,
seed=42,
)
stream = client.text_generation(prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
# return_full_text=False: the stream contains only the generated story
# return_full_text=True: would include the original prompt as well as the generated story
output = ""
for response in stream:
output += response.token.text
# yield "".join(output)
yield output
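# Yielding inside the loop turns generate() into a generator, so the Gradio Textbox
# updates with the partial story while tokens are still streaming in.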
print("..........................................................")
print("generated story:", output)
return output
demo = gr.Interface(fn=generate,
inputs=[
#gr.Video(sources=["webcam"], label="video")
gr.Image(sources=["upload", "webcam"], label="Upload Image", type="pil"),
gr.Slider(
label="Temperature",
value=0.9,
minimum=0.0,
maximum=1.0,
step=0.05,
interactive=True,
info="Higher values produce more diverse outputs",
),
gr.Slider(
label="Max new tokens",
value=1500,
minimum=0,
maximum=3000,
step=1.0,
interactive=True,
info="The maximum numbers of new tokens"),
gr.Slider(
label="Top-p (nucleus sampling)",
value=0.90,
minimum=0.0,
maximum=1,
step=0.05,
interactive=True,
info="Higher values sample more low-probability tokens",
),
gr.Slider(
label="Repetition penalty",
value=1.2,
minimum=1.0,
maximum=2.0,
step=0.05,
interactive=True,
info="Penalize repeated tokens",
)
],
outputs=[gr.Textbox(label="Generated Story")],
title="story generation",
description="generate a story for you",
allow_flagging="never"
)
demo.launch(debug=True)