import torch
from langgraph.graph import END, StateGraph
from typing import TypedDict, Any
from transformers import (
    AutoProcessor,
    BitsAndBytesConfig,
    Gemma3ForConditionalGeneration,
)


def get_quantization_config():
    # Optional 4-bit NF4 quantization config for loading the model on limited VRAM
    return BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
    )
# Define the state schema
class State(TypedDict):
    image: Any
    voice: str
    caption: str
    description: str


# Build the workflow graph
def build_graph():
    workflow = StateGraph(State)

    # Add nodes
    workflow.add_node("caption_image", caption_image)
    workflow.add_node("describe_with_voice", describe_with_voice)

    # Add edges
    workflow.set_entry_point("caption_image")
    workflow.add_edge("caption_image", "describe_with_voice")
    workflow.add_edge("describe_with_voice", END)

    # Compile the graph
    return workflow.compile()
model_id = "google/gemma-3-4b-it"

# Initialize processor and model
processor = AutoProcessor.from_pretrained(model_id)
model = Gemma3ForConditionalGeneration.from_pretrained(
    model_id,
    # quantization_config=get_quantization_config(),
    device_map="auto",
    torch_dtype=torch.float16,
)
# Placeholder node implementations, useful for testing the graph wiring
def describe_with_voice(state: State) -> State:
    state["description"] = "Dummy description"
    return state


def caption_image(state: State) -> State:
    state["caption"] = "Dummy caption"
    return state
# Full implementation of the voice node: rewrites the caption in the requested persona
def describe_with_voice2(state: State) -> State:
    caption = state["caption"]
    voice = state["voice"]

    # Voice prompt templates
    voice_prompts = {
        "scurvy-ridden pirate": "You are a scurvy-ridden pirate, angry and drunk.",
        "forgetful wizard": "You are a forgetful and easily distracted wizard.",
        "sarcastic teenager": "You are a sarcastic and disinterested teenager.",
    }

    messages = [
        {
            "role": "system",
            "content": [{"type": "text", "text": voice_prompts.get(voice)}],
        },
        {
            "role": "user",
            "content": [
                {"type": "text", "text": f"Describe the following:\n\n{caption}"}
            ],
        },
    ]

    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device, dtype=torch.float16)

    input_len = inputs["input_ids"].shape[-1]

    with torch.inference_mode():
        generation = model.generate(**inputs, max_new_tokens=100, do_sample=False)
        generation = generation[0][input_len:]

    description = processor.decode(generation, skip_special_tokens=True)
    state["description"] = description

    return state
# Full implementation of the caption node: asks Gemma 3 to describe the input image
def caption_image2(state: State) -> State:
    # image is a PIL image
    image = state["image"]

    messages = [
        {
            "role": "system",
            "content": [
                {
                    "type": "text",
                    "text": "You are a helpful assistant that will describe images in 3-5 sentences.",
                }
            ],
        },
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": "Describe this image."},
            ],
        },
    ]

    # Cast floating-point inputs (the pixel values) to the dtype the model was loaded with
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device, dtype=torch.float16)

    input_len = inputs["input_ids"].shape[-1]

    with torch.inference_mode():
        generation = model.generate(**inputs, max_new_tokens=100, do_sample=False)
        generation = generation[0][input_len:]

    caption = processor.decode(generation, skip_special_tokens=True)
    state["caption"] = caption

    return state
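
A minimal sketch of how the compiled graph might be invoked end to end. Note that build_graph as written above wires in the placeholder nodes; swapping caption_image2 and describe_with_voice2 into the add_node calls runs the real model. The image path "example.jpg" and the initial state values here are illustrative, not part of the original listing.

from PIL import Image

graph = build_graph()

initial_state = {
    "image": Image.open("example.jpg"),  # hypothetical path; any PIL image works
    "voice": "scurvy-ridden pirate",
    "caption": "",
    "description": "",
}

# invoke() runs caption_image followed by describe_with_voice and returns the final state
result = graph.invoke(initial_state)
print(result["description"])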