Update main.py
main.py
CHANGED
@@ -1,101 +1,219 @@
 from fastapi import FastAPI, File, UploadFile, Form, HTTPException
-
-from fastapi.
-from fastapi.
-from fastapi.
 import requests
 import base64
-from typing import Iterator
 import os
-
 from deep_translator import GoogleTranslator

 app = FastAPI()

-def run(message: str,
-        chat_history: list[tuple[str, str]],
-        system_prompt: str,
-        max_new_tokens: int = 1024,
-        temperature: float = 0.1,
-        top_p: float = 0.9,
-        top_k: int = 50) -> Iterator[str]:
-    prompt = get_prompt(message, chat_history, system_prompt)
-
-    generate_kwargs = dict(
-        max_new_tokens=max_new_tokens,
-        do_sample=True,
-        top_p=top_p,
-        top_k=top_k,
-        temperature=temperature,
-    )
-
-    output = ""
-    for response in stream:
-        if any([end_token in response.token.text for end_token in [EOS_STRING, EOT_STRING]]):
-            yield output
-            output = ""
-        else:
-            output += response.token.text


 def generate_image_caption(image_data):
-
-
     response = requests.post("https://makhinur-image-to-text-salesforce-blip-image-cap-c0a9076.hf.space/run/predict", json=payload)
     if response.status_code == 200:
-
-
     else:
-        return "Error:


 @app.post("/generate-story/")
-async def
     image_data = await image_file.read()
-    system_prompt = f"write an attractive story in 300 words about {random.choice(['an adventurous journey', 'a mysterious encounter', 'a heroic quest', 'a magical adventure', 'a thrilling escape', 'an unexpected discovery', 'a dangerous mission', 'a romantic escapade', 'an epic battle', 'a journey into the unknown'])}"

     caption = generate_image_caption(image_data)
     if caption.startswith("Error"):
         raise HTTPException(status_code=500, detail=caption)
-    ai_response = next(run(caption, [], system_prompt))
-
-    if language != "english":
-        translator = GoogleTranslator(source='english', target=language)
-        ai_response = translator.translate(ai_response)

 from fastapi import FastAPI, File, UploadFile, Form, HTTPException
+# Keep these if you use them elsewhere in your app (HTML, static files)
+# from fastapi.responses import HTMLResponse
+# from fastapi.staticfiles import StaticFiles
+# from fastapi.templating import Jinja2Templates
+# from fastapi.responses import FileResponse
+
 import requests
 import base64
 import os
+import random
+
+# Import necessary classes from transformers
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig  # Added BitsAndBytesConfig in case you ever need quantization
+
+
 from deep_translator import GoogleTranslator
+from deep_translator.exceptions import InvalidSourceOrTargetLanguage
+

 app = FastAPI()

+# --- Hugging Face Model Setup (Local) ---
+# Model name for Gemma 2B Instruction-Tuned
+# This version is trained to follow instructions, ideal for your task.
+model_name = "google/gemma-2b-it"
+tokenizer = None
+model = None
+
+# Function to load the model and tokenizer
+def load_model():
+    global tokenizer, model
+    print(f"Loading model: {model_name}...")
+
+    # Load tokenizer
+    # trust_remote_code=True might be needed for some newer models/features,
+    # but standard Gemma usually works without it. Let's omit it for security unless necessary.
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+    # Load model - Gemma can be loaded in float16 to save RAM
+    # On CPU, float16 performance can vary, but it reduces memory bandwidth
+    # which can sometimes help. 16GB RAM is plenty for Gemma 2B in float16 (~5GB).
+    # We don't need quantization (load_in_8bit/4bit) for Gemma 2B with 16GB RAM,
+    # but it's an option for larger models or less RAM.
+    model = AutoModelForCausalLM.from_pretrained(
+        model_name,
+        torch_dtype=torch.float16,  # Use float16 precision
+        # device_map="auto"  # Not strictly needed for single CPU inference
    )
+    # model.to("cpu")  # Explicitly ensure it's on CPU, although from_pretrained on CPU does this.

+    print(f"Model {model_name} loaded successfully.")

+# Load the model when the app starts
+@app.on_event("startup")
+async def startup_event():
+    load_model()
+
+# --- Image Captioning (External API - Keep) ---
+# Keep this as it is, it uses an external service
 def generate_image_caption(image_data):
+    payload = {"data": ["data:image/jpeg;base64," + base64.b64encode(image_data).decode('utf-8')]}
+    # Use the correct URL for the captioning API. This is the one from your original code.
+    # Ensure it's stable or replace if needed.
     response = requests.post("https://makhinur-image-to-text-salesforce-blip-image-cap-c0a9076.hf.space/run/predict", json=payload)
     if response.status_code == 200:
+        try:
+            result = response.json()
+            caption = result.get("data", ["Error: Unexpected API response format"])[0]
+            return caption
+        except Exception as e:
+            return f"Error: Failed to parse caption API response: {e}"
     else:
+        return f"Error: Caption API returned status code {response.status_code}: {response.text}"
+
+
+# --- Gemma Story Generation Function ---
+# Replace the old generation function with one specific to Gemma-IT
+def generate_story_gemma(prompt_text: str, max_new_tokens: int = 300, temperature: float = 0.7, top_p: float = 0.9, top_k: int = 50) -> str:
+    """
+    Generates text using the loaded Gemma model.
+    Applies the Gemma-IT chat template to the prompt.
+    """
+    if tokenizer is None or model is None:
+        raise RuntimeError("Model and tokenizer not loaded. App startup failed?")
+
+    # Gemma-IT uses a specific chat template. We wrap the user's prompt in it.
+    messages = [
+        {"role": "user", "content": prompt_text}
+        # Note: Gemma-IT's chat template does not support a separate system role,
+        # so keep any extra instructions inside the detailed user prompt.
+    ]
+
+    # Apply the chat template. This adds the necessary special tokens
+    # and formatting for the model to understand the instruction.
+    # `add_generation_prompt=True` adds the token that signals the model
+    # should start generating its response.
+    input_text = tokenizer.apply_chat_template(
+        messages,
+        tokenize=False,  # Keep as string for encoding later
+        add_generation_prompt=True  # Add the assistant turn prompt
+    )

+    # Encode the templated prompt
+    # Max length should consider the prompt length + generated length
+    # Max input context for Gemma is 8192 tokens, but keeping the prompt shorter is better for CPU
+    inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=1024)  # Using a reasonable max_length for input

+    # Ensure inputs are on the correct device (CPU by default)
+    # inputs = {k: v.to(model.device) for k, v in inputs.items()}  # Redundant on CPU

+    # Generate text
+    # The generate method returns the input_ids plus the generated tokens
+    generate_ids = model.generate(
+        inputs.input_ids,
+        max_new_tokens=max_new_tokens,
+        do_sample=True,  # Set to True for creative text generation
+        temperature=temperature,
+        top_p=top_p,
+        top_k=top_k,
+        pad_token_id=tokenizer.pad_token_id,  # Use the pad token during generation
+        # Gemma's EOS token is usually handled by the default generate logic
+        # eos_token_id=tokenizer.eos_token_id
+    )

+    # Decode the generated text.
+    # We slice generate_ids to exclude the input prompt tokens, only decoding the new ones.
+    # The slicing [0, inputs.input_ids.shape[-1]:] selects the generated part for the first (and only) item in the batch
+    # `skip_special_tokens=True` removes special tokens like <start_of_turn>, <end_of_turn>, <eos>
+    generated_text = tokenizer.decode(generate_ids[0, inputs.input_ids.shape[-1]:], skip_special_tokens=True)

+    # Gemma responses might sometimes include extra whitespace or turn markers if decoding is not perfect.
+    # Further cleanup might be needed depending on the exact output format, but skip_special_tokens helps.
+    # We can also remove leading/trailing whitespace.
+    return generated_text.strip()

+# --- FastAPI Endpoint ---
 @app.post("/generate-story/")
+async def generate_story_endpoint(image_file: UploadFile = File(...), language: str = Form(...)):
     image_data = await image_file.read()

+    # Choose a random theme for the story prompt
+    story_theme = random.choice([
+        'an adventurous journey',
+        'a mysterious encounter',
+        'a heroic quest',
+        'a magical adventure',
+        'a thrilling escape',
+        'an unexpected discovery',
+        'a dangerous mission',
+        'a romantic escapade',
+        'an epic battle',
+        'a journey into the unknown'
+    ])
+
+    # Get image caption
     caption = generate_image_caption(image_data)
     if caption.startswith("Error"):
+        print(f"Caption generation failed: {caption}")
         raise HTTPException(status_code=500, detail=caption)

+    # Construct the detailed prompt for Gemma-IT.
+    # Instruct it clearly to write a story based on the theme and incorporating the caption.
+    prompt_text = f"Write an attractive story of around 300 words about {story_theme}. Incorporate the following details from an image description into the story: {caption}\n\nStory:"
+
+    # Generate the story using the local Gemma model
+    try:
+        story = generate_story_gemma(
+            prompt_text,
+            max_new_tokens=300,  # Generate up to 300 new tokens
+            temperature=0.7,     # Controls randomness (higher = more random)
+            top_p=0.9,           # Controls diversity (nucleus sampling)
+            top_k=50             # Controls diversity (top-k sampling)
+        )
+        # Basic cleanup: Sometimes models might start with whitespace or unwanted characters
+        story = story.strip()
+
+    except Exception as e:
+        print(f"Story generation failed: {e}")  # Log generation errors
+        # Provide more detail in the HTTP exception for debugging
+        raise HTTPException(status_code=500, detail=f"Story generation failed: {e}. Please check Space logs for details.")
+
+    # Translate the story if the target language is not English
+    if language.lower() != "english":
+        try:
+            translator = GoogleTranslator(source='english', target=language.lower())
+            translated_story = translator.translate(story)
+
+            if translated_story is None:
+                print(f"Translation returned None for language: {language}")
+                # Return English story with a warning
+                return {"story": story + "\n\n(Note: Automatic translation to your requested language failed.)"}
+
+            story = translated_story
+
+        except InvalidSourceOrTargetLanguage:
+            print(f"Invalid target language requested: {language}")
+            raise HTTPException(status_code=400, detail=f"Invalid target language: {language}")
+        except Exception as e:
+            print(f"Translation failed for language {language}: {e}")  # Log translation errors
+            raise HTTPException(status_code=500, detail=f"Translation failed: {e}")
+
+    # Return the generated (and potentially translated) story
+    return {"story": story}
+
+# --- Optional: Serve a simple HTML form for testing (Needs templates dir and index.html) ---
+# from fastapi import Request
+# from fastapi.templating import Jinja2Templates
+# from fastapi.staticfiles import StaticFiles
+
+# templates = Jinja2Templates(directory="templates")
+# app.mount("/static", StaticFiles(directory="static"), name="static")
+
+# @app.get("/", response_class=HTMLResponse)
+# async def read_root(request: Request):
+#     return templates.TemplateResponse("index.html", {"request": request})