import gradio as gr
from deepface import DeepFace
from transformers import pipeline
import io
import base64
import pandas as pd
import numpy as np
from huggingface_hub import InferenceClient

from langchain.text_splitter import TokenTextSplitter
from langchain.embeddings import HuggingFaceBgeEmbeddings
from langchain.vectorstores import Chroma
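
# BLIP image-captioning pipeline; the model weights are downloaded from the
# Hugging Face Hub on first use.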
get_blip = pipeline("image-to-text", model="Salesforce/blip-image-captioning-large")

def analyze_face(image):
    # DeepFace returns one result per detected face; index 0 is the first face.
    # enforce_detection=False prevents an exception when no face is found.
    image_array = np.array(image)
    face_result = DeepFace.analyze(image_array, actions=['age', 'gender', 'emotion'], enforce_detection=False)
    df = pd.DataFrame(face_result)
    return df['dominant_gender'][0], df['age'][0], df['dominant_emotion'][0]
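
# The image-to-text pipeline accepts URLs, local paths, PIL images, and (in
# recent transformers versions) base64 strings; this app serializes the PIL
# image to a base64-encoded PNG before captioning.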

def image_to_base64_str(pil_image):
    byte_arr = io.BytesIO()
    pil_image.save(byte_arr, format='PNG')
    byte_arr = byte_arr.getvalue()
    return base64.b64encode(byte_arr).decode('utf-8')


def captioner(image):
    base64_image = image_to_base64_str(image)
    caption = get_blip(base64_image)
    return caption[0]['generated_text']

def get_image_info(image):
    # Combine the BLIP caption with DeepFace's demographic and emotion estimates
    image_caption = captioner(image)
    gender, age, emotion = analyze_face(image)
    return image_caption, gender, age, emotion
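
# BGE embeddings on CPU; normalize_embeddings=True makes inner-product scores
# equivalent to cosine similarity.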

model_name = "BAAI/bge-large-en-v1.5"
model_kwargs = {'device': 'cpu'}
encode_kwargs = {'normalize_embeddings': True}

embeddings = HuggingFaceBgeEmbeddings(model_name=model_name, model_kwargs=model_kwargs, encode_kwargs=encode_kwargs)
print("embeddings model loaded....................")

with open("story.txt", "r") as f:
    story_text = f.read()

# Token-based splitting keeps each chunk within the embedding model's context window
text_splitter = TokenTextSplitter(chunk_size=200, chunk_overlap=20)
texts = text_splitter.split_text(story_text)
print("...........................................")
print("texts[0]:", texts[0])

vector_store = Chroma.from_texts(texts, embeddings, collection_metadata={"hnsw:space": "cosine"}, persist_directory="stores/story_cosine")
print("vector store created........................")

load_vector_store = Chroma(persist_directory="stores/story_cosine", embedding_function=embeddings)
retriever = load_vector_store.as_retriever(search_kwargs={"k": 3})
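
# Remote generation through the Hugging Face Inference API; depending on the
# model's gating, a Hugging Face access token may need to be configured.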

client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.1")

def generate(image, temperature=0.9, max_new_tokens=1500, top_p=0.95, repetition_penalty=1.0):
    image_caption, gender, age, emotion = get_image_info(image)
    print("............................................")
    print("image_caption:", image_caption)
    print("age:", age)
    print("gender:", gender)
    print("emotion:", emotion)
    print("............................................")

    # Build the retrieval query from the caption and the face attributes
    query = f"{image_caption}. {emotion} {age} years old {gender}"

    documents = retriever.get_relevant_documents(query)

    if documents:
        print("document:", documents[0])
        print(".....................................")
        print(documents)
    else:
        print("no documents")

    context = '\n'.join([f"Document {index + 1}: {doc.page_content}" for index, doc in enumerate(documents)])
    print("....................................................................")
    print("context:", context)
    prompt = (
        f"[INST] Please generate a detailed and engaging story based on the person's emotion: {emotion}, "
        f"age: {age}, and gender: {gender} shown in the image. Begin with the scene described in the image's caption: '{image_caption}'. "
        f"Just use the following example story plots and formats as an inspiration: "
        f"{context} "
        f"The generated story should include a beginning, middle, and end, and the complete story should be approximately {max_new_tokens} tokens long. [/INST]"
    )

    # The text-generation endpoint requires a strictly positive temperature
    temperature = float(temperature)
    if temperature < 1e-2:
        temperature = 1e-2
    top_p = float(top_p)

    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=42,
    )

    # Stream tokens and yield the accumulated story so the UI updates incrementally
    stream = client.text_generation(prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
    output = ""
    for response in stream:
        output += response.token.text
        yield output
    print("..........................................................")
    print("generated story:", output)
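
# Because generate() yields partial output, Gradio streams the story into the
# textbox as tokens arrive.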

demo = gr.Interface(
    fn=generate,
    inputs=[
        gr.Image(sources=["upload", "webcam"], label="Upload Image", type="pil"),
        gr.Slider(
            label="Temperature",
            value=0.9,
            minimum=0.0,
            maximum=1.0,
            step=0.05,
            interactive=True,
            info="Higher values produce more diverse outputs",
        ),
        gr.Slider(
            label="Max new tokens",
            value=1500,
            minimum=0,
            maximum=3000,
            step=1.0,
            interactive=True,
            info="The maximum number of new tokens",
        ),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            value=0.90,
            minimum=0.0,
            maximum=1.0,
            step=0.05,
            interactive=True,
            info="Higher values sample more low-probability tokens",
        ),
        gr.Slider(
            label="Repetition penalty",
            value=1.2,
            minimum=1.0,
            maximum=2.0,
            step=0.05,
            interactive=True,
            info="Penalize repeated tokens",
        ),
    ],
    outputs=[gr.Textbox(label="Generated Story")],
    title="Story Generation",
    description="Generate a story from an uploaded image",
    allow_flagging="never",
)

demo.launch(debug=True)