import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer
from PIL import Image
import warnings
import gradio as gr
from gradio_client import Client
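
# Keep the console quiet: suppress transformers logging, progress bars, and warnings.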
transformers.logging.set_verbosity_error()
transformers.logging.disable_progress_bar()
warnings.filterwarnings('ignore')
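
# Use the GPU when available and make it the default device for new tensors.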
device = 'cuda' if torch.cuda.is_available() else 'cpu'
torch.set_default_device(device)

model_name = 'qnguyen3/nanoLLaVA-1.5'

print(f"Loading model {model_name} on {device}...")
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map='auto',
    trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    trust_remote_code=True)

print("Model loaded successfully!")
chatter = "K00B404/transcript_image_generator"
chatbot_client = Client(chatter)


def analyze_character(image_path, analysis_type):
    """
    Analyze a character image for dramaturgical insights.

    Args:
        image_path: Path to the character image
        analysis_type: Type of character analysis to perform

    Returns:
        str: The generated character analysis
    """
    try:
        image = Image.open(image_path).convert('RGB')
        # Downscale with a high-quality filter to keep preprocessing cheap.
        image = image.resize((256, 256), Image.Resampling.LANCZOS)
        # process_images comes from nanoLLaVA's remote code; match the model's dtype.
        image_tensor = model.process_images([image], model.config).to(dtype=model.dtype)
    except Exception as e:
        return f"Error processing image: {str(e)}"

    if analysis_type == "full_analysis":
        prompt = ("Analyze this character as a dramaturg would. Describe their appearance, "
                  "potential personality traits, character archetype, suitable roles, and how they might "
                  "function within a dramatic narrative. Consider costume, posture, expression, and visual symbolism.")
    elif analysis_type == "archetype":
        prompt = ("Identify the potential character archetype(s) represented in this image. "
                  "Consider both classical archetypes (hero, mentor, trickster, etc.) and modern "
                  "interpretations. Explain your reasoning based on visual cues.")
    elif analysis_type == "historical_context":
        prompt = ("Analyze this character's appearance in terms of historical context. "
                  "Identify the likely time period, cultural influences, and how these elements "
                  "would influence the character's role in a dramatic work. Consider costume details, "
                  "props, and stylistic elements.")
    else:
        prompt = "Describe this character in detail for dramatic casting purposes."

    messages = [
        {"role": "system", "content": "You are an expert dramaturg with deep knowledge of character analysis, theatrical traditions, and visual storytelling."},
        {"role": "user", "content": f'<image>\n{prompt}'}
    ]

    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
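
    # Tokenize around the <image> placeholder and splice in -200, the sentinel
    # token id LLaVA-style models use to mark where image features are injected.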
    text_chunks = [tokenizer(chunk).input_ids for chunk in text.split('<image>')]
    input_ids = torch.tensor(text_chunks[0] + [-200] + text_chunks[1], dtype=torch.long).unsqueeze(0)

    try:
        output_ids = model.generate(
            input_ids,
            images=image_tensor,
            max_new_tokens=1024,
            temperature=0.7,
            top_p=0.9,
            use_cache=False,
            do_sample=True)
        # Decode only the newly generated tokens, skipping the prompt.
        response = tokenizer.decode(output_ids[0][input_ids.shape[1]:], skip_special_tokens=True).strip()
        return response
    except Exception as e:
        try:
            print(f"First generation method failed with: {str(e)}. Trying fallback method...")
            with torch.inference_mode():
                output = model.generate(
                    input_ids,
                    images=image_tensor,
                    max_new_tokens=1024,
                    do_sample=True,
                    top_p=0.9,
                    temperature=0.7,
                    eos_token_id=tokenizer.eos_token_id,
                    # pad_token_id can legitimately be 0, so test against None
                    # rather than relying on truthiness.
                    pad_token_id=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
                )
            response = tokenizer.decode(output[0][input_ids.shape[1]:], skip_special_tokens=True).strip()
            return response
        except Exception as e2:
            return (f"Error generating analysis: {str(e)}\nFallback also failed: {str(e2)}\n\n"
                    "Please try a different image or check model compatibility.")


def chat_with_persona(message, history, system_message, max_tokens, temperature, top_p):
    """Send a message to the remote chatbot API, using the generated persona as the system prompt."""
    try:
        # `history` is accepted for interface symmetry but not forwarded; the
        # remote /chat endpoint only receives the latest message.
        response = chatbot_client.predict(
            message=message,
            system_message=system_message,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            api_name="/chat"
        )
        return response
    except Exception as e:
        return f"Error communicating with the chatbot API: {str(e)}"


def create_ui():
    with gr.Blocks(title="Dramaturg Character Analyzer") as demo:
        # Holds the latest analysis so it can be handed to the Test Bot tab.
        analysis_result = gr.State("")

        with gr.Tabs() as tabs:
            # Explicit ids let the copy handler switch tabs via gr.update(selected=...).
            with gr.TabItem("Character Analysis", id=0):
                gr.Markdown("# Dramaturg Character Analyzer")
                gr.Markdown("Upload a character image to receive a dramaturgical analysis")

                with gr.Row():
                    with gr.Column():
                        input_image = gr.Image(type="filepath", label="Upload Character Image")
                        analysis_type = gr.Radio(
                            ["full_analysis", "archetype", "historical_context", "basic_description"],
                            label="Analysis Type",
                            value="full_analysis"
                        )
                        analyze_btn = gr.Button("Analyze Character")

                    with gr.Column():
                        output_text = gr.Textbox(label="Character Analysis", lines=20)
                        copy_to_test_btn = gr.Button("Copy to Test Bot", interactive=False)

                def run_analysis(image_path, analysis_kind):
                    # analyze_character() returns a single string, but this click
                    # handler has two outputs: the textbox and the copy button,
                    # which is enabled once a result exists.
                    result = analyze_character(image_path, analysis_kind)
                    return result, gr.update(interactive=True)

                analyze_btn.click(
                    fn=run_analysis,
                    inputs=[input_image, analysis_type],
                    outputs=[output_text, copy_to_test_btn]
                )

                # copy_to_test_btn is wired up after the Test Bot tab is built,
                # since its outputs include that tab's system prompt box.

            with gr.TabItem("Test Bot", id=1):
                gr.Markdown("# Test Your Character Persona")
                gr.Markdown("The character analysis will be used as the system prompt for the test bot.")

                with gr.Row():
                    with gr.Column():
                        system_prompt = gr.Textbox(label="System Prompt (Character Persona)", lines=10)

                        with gr.Row():
                            max_tokens = gr.Slider(minimum=100, maximum=4000, value=1000, step=100, label="Max Tokens")
                            temperature = gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature")
                            top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.1, label="Top P")

                        user_input = gr.Textbox(label="Your message", placeholder="Ask something about the character...")
                        send_btn = gr.Button("Send Message")

                    with gr.Column():
                        chatbot = gr.Chatbot(label="Conversation")

            # Wired here rather than in the first tab because the outputs
            # include system_prompt, which must exist before it is referenced.
            # (A demo.load hook would only run at page load, long before any
            # analysis exists, so the prompt is filled directly instead.)
            def copy_to_test(result):
                # Store the analysis, fill the system prompt, and switch to the
                # Test Bot tab (selected matches the TabItem id above).
                return result, result, gr.update(selected=1)

            copy_to_test_btn.click(
                fn=copy_to_test,
                inputs=[output_text],
                outputs=[analysis_result, system_prompt, tabs]
            )

            def respond(message, history, system_message, max_tokens_val, temperature_val, top_p_val):
                # Echo the user's message with an empty reply slot, then fill it in.
                history.append((message, ""))
                response = chat_with_persona(
                    message=message,
                    history=history,
                    system_message=system_message,
                    max_tokens=max_tokens_val,
                    temperature=temperature_val,
                    top_p=top_p_val
                )
                history[-1] = (message, response)
                # Clear the input box and return the updated conversation.
                return "", history

            send_btn.click(
                fn=respond,
                inputs=[user_input, chatbot, system_prompt, max_tokens, temperature, top_p],
                outputs=[user_input, chatbot]
            )

            # Pressing Enter in the textbox behaves the same as the Send button.
            user_input.submit(
                fn=respond,
                inputs=[user_input, chatbot, system_prompt, max_tokens, temperature, top_p],
                outputs=[user_input, chatbot]
            )

    return demo
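

# Entry point: build the UI and launch the Gradio server.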
if __name__ == "__main__":
    demo = create_ui()
    # launch() blocks until the server is stopped, so the status line is
    # printed first; share=True also exposes a temporary public URL.
    print("Dramaturg Character Analyzer is now running with Test Bot integration!")
    demo.launch(share=True)