Makhinur committed on
Commit
e4ba08d
·
verified ·
1 Parent(s): 0db54a4

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +134 -99
main.py CHANGED
@@ -5,15 +5,18 @@ from fastapi import FastAPI, File, UploadFile, Form, HTTPException
5
  # from fastapi.templating import Jinja2Templates
6
  # from fastapi.responses import FileResponse
7
 
8
- import requests
9
- import base64
 
10
  import os
11
  import random
12
 
13
  # Import necessary classes from transformers
14
  import torch
15
- from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig # Added BitsAndBytesConfig in case you ever need quantization
16
 
 
 
17
 
18
  from deep_translator import GoogleTranslator
19
  from deep_translator.exceptions import InvalidSourceOrTargetLanguage
@@ -22,125 +25,164 @@ from deep_translator.exceptions import InvalidSourceOrTargetLanguage
22
  app = FastAPI()
23
 
24
  # --- Hugging Face Model Setup (Local) ---
25
- # Model name for Gemma 2B Instruction-Tuned
26
- # This version is trained to follow instructions, ideal for your task.
27
- model_name = "google/gemma-2b-it"
28
  tokenizer = None
29
  model = None
30
 
31
# Loader for the locally hosted language model and its tokenizer.
def load_model():
    """Populate the module-level ``tokenizer`` and ``model`` globals.

    Both objects are fetched via ``from_pretrained`` using the module-level
    ``model_name``; the model weights are requested in ``torch.float16``.
    Progress is reported on stdout before and after loading.
    """
    global tokenizer, model
    print(f"Loading model: {model_name}...")

    # Tokenizer first, so it is already set if the (heavier) model load fails.
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    model_options = {"torch_dtype": torch.float16}
    model = AutoModelForCausalLM.from_pretrained(model_name, **model_options)

    print(f"Model {model_name} loaded successfully.")
54
 
55
# FastAPI startup hook: the model is loaded once, before requests are served.
@app.on_event("startup")
async def startup_event():
    """Run ``load_model`` when the application starts."""
    load_model()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
# --- Image Captioning (External API) ---
def generate_image_caption(image_data, timeout=60):
    """Request a caption for raw image bytes from the external captioning Space.

    Args:
        image_data: Raw bytes of the uploaded image. They are base64-encoded
            and sent as a ``data:image/jpeg;base64,...`` URI.
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        The caption string on success. On any failure a string starting with
        ``"Error"`` is returned instead of raising, because the caller checks
        ``caption.startswith("Error")``.
    """
    encoded_image = base64.b64encode(image_data).decode('utf-8')
    payload = {"data": ["data:image/jpeg;base64," + encoded_image]}

    try:
        # A timeout keeps a stalled Space from hanging the request forever.
        response = requests.post(
            "https://makhinur-image-to-text-salesforce-blip-image-cap-c0a9076.hf.space/run/predict",
            json=payload,
            timeout=timeout,
        )
    except requests.RequestException as e:
        # Connection errors and timeouts follow the same "Error: ..." contract
        # as the other failure paths instead of escaping to the endpoint.
        return f"Error: Caption API request failed: {e}"

    if response.status_code != 200:
        return f"Error: Caption API returned status code {response.status_code}: {response.text}"

    try:
        result = response.json()
        caption = result.get("data", ["Error: Unexpected API response format"])[0]
    except Exception as e:
        return f"Error: Failed to parse caption API response: {e}"

    # The caller treats the result as a string; reject any other payload shape.
    if not isinstance(caption, str):
        return f"Error: Unexpected API response format: {caption!r}"
    return caption
76
 
 
 
 
77
 
78
# --- Gemma Story Generation Function ---
def generate_story_gemma(prompt_text: str, max_new_tokens: int = 300, temperature: float = 0.7, top_p: float = 0.9, top_k: int = 50) -> str:
    """
    Generates text using the loaded Gemma model.
    Applies the tokenizer's chat template to the prompt.

    Args:
        prompt_text: Instruction sent to the model as a single user turn.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature passed to ``model.generate``.
        top_p: Nucleus-sampling threshold passed to ``model.generate``.
        top_k: Top-k sampling cutoff passed to ``model.generate``.

    Returns:
        The newly generated text (prompt tokens excluded), whitespace-stripped.

    Raises:
        RuntimeError: If the global tokenizer or model has not been loaded.
    """
    if tokenizer is None or model is None:
        raise RuntimeError("Model and tokenizer not loaded. App startup failed?")

    messages = [
        {"role": "user", "content": prompt_text}
    ]

    # Render the conversation to a string; add_generation_prompt appends the
    # marker that tells the model to start its reply.
    input_text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )

    # The templated text is expected to carry its own special tokens (see the
    # transformers chat-templating guide), so do not let the tokenizer add
    # them a second time.
    inputs = tokenizer(
        input_text,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=1024,
        add_special_tokens=False,
    )

    generate_ids = model.generate(
        inputs.input_ids,
        # Forward the mask explicitly instead of relying on generate() to
        # infer it from pad_token_id.
        attention_mask=inputs.attention_mask,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        pad_token_id=tokenizer.pad_token_id,
    )

    # generate() returns prompt + continuation; keep only the continuation of
    # the first (and only) sequence in the batch.
    prompt_length = inputs.input_ids.shape[-1]
    generated_text = tokenizer.decode(generate_ids[0, prompt_length:], skip_special_tokens=True)
    return generated_text.strip()
137
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
  # --- FastAPI Endpoint ---
139
  @app.post("/generate-story/")
140
  async def generate_story_endpoint(image_file: UploadFile = File(...), language: str = Form(...)):
141
- image_data = await image_file.read()
 
142
 
143
- # Choose a random theme for the story prompt
144
  story_theme = random.choice([
145
  'an adventurous journey',
146
  'a mysterious encounter',
@@ -154,35 +196,33 @@ async def generate_story_endpoint(image_file: UploadFile = File(...), language:
154
  'a journey into the unknown'
155
  ])
156
 
157
- # Get image caption
158
- caption = generate_image_caption(image_data)
 
159
  if caption.startswith("Error"):
160
  print(f"Caption generation failed: {caption}")
161
  raise HTTPException(status_code=500, detail=caption)
162
 
163
- # Construct the detailed prompt for Gemma-IT.
164
- # Instruct it clearly to write a story based on the theme and incorporating the caption.
165
  prompt_text = f"Write an attractive story of around 300 words about {story_theme}. Incorporate the following details from an image description into the story: {caption}\n\nStory:"
166
 
167
- # Generate the story using the local Gemma model
168
  try:
169
- story = generate_story_gemma(
170
  prompt_text,
171
- max_new_tokens=300, # Generate up to 300 new tokens
172
- temperature=0.7, # Controls randomness (higher = more random)
173
- top_p=0.9, # Controls diversity (nucleus sampling)
174
- top_k=50 # Controls diversity (top-k sampling)
175
  )
176
- # Basic cleanup: Sometimes models might start with whitespace or unwanted characters
177
  story = story.strip()
178
 
179
  except Exception as e:
180
- print(f"Story generation failed: {e}") # Log generation errors
181
- # Provide more detail in the HTTP exception for debugging
182
  raise HTTPException(status_code=500, detail=f"Story generation failed: {e}. Please check Space logs for details.")
183
 
184
 
185
- # Translate the story if the target language is not English
186
  if language.lower() != "english":
187
  try:
188
  translator = GoogleTranslator(source='english', target=language.lower())
@@ -190,7 +230,6 @@ async def generate_story_endpoint(image_file: UploadFile = File(...), language:
190
 
191
  if translated_story is None:
192
  print(f"Translation returned None for language: {language}")
193
- # Return English story with a warning
194
  return {"story": story + "\n\n(Note: Automatic translation to your requested language failed.)"}
195
 
196
  story = translated_story
@@ -199,21 +238,17 @@ async def generate_story_endpoint(image_file: UploadFile = File(...), language:
199
  print(f"Invalid target language requested: {language}")
200
  raise HTTPException(status_code=400, detail=f"Invalid target language: {language}")
201
  except Exception as e:
202
- print(f"Translation failed for language {language}: {e}") # Log translation errors
203
  raise HTTPException(status_code=500, detail=f"Translation failed: {e}")
204
 
205
-
206
- # Return the generated (and potentially translated) story
207
  return {"story": story}
208
 
209
- # --- Optional: Serve a simple HTML form for testing (Needs templates dir and index.html) ---
210
  # from fastapi import Request
211
  # from fastapi.templating import Jinja2Templates
212
  # from fastapi.staticfiles import StaticFiles
213
-
214
  # templates = Jinja2Templates(directory="templates")
215
  # app.mount("/static", StaticFiles(directory="static"), name="static")
216
-
217
  # @app.get("/", response_class=HTMLResponse)
218
  # async def read_root(request: Request):
219
  # return templates.TemplateResponse("index.html", {"request": request})
 
5
  # from fastapi.templating import Jinja2Templates
6
  # from fastapi.responses import FileResponse
7
 
8
+ # Removed 'requests' as we'll primarily use gradio_client for captioning
9
+ # import requests
10
+ import base64 # Still useful if you need base64 for anything else
11
  import os
12
  import random
13
 
14
  # Import necessary classes from transformers
15
  import torch
16
+ from transformers import AutoModelForCausalLM, AutoTokenizer
17
 
18
+ # Import the Gradio Client
19
+ from gradio_client import Client
20
 
21
  from deep_translator import GoogleTranslator
22
  from deep_translator.exceptions import InvalidSourceOrTargetLanguage
 
25
  app = FastAPI()
26
 
27
  # --- Hugging Face Model Setup (Local) ---
28
+ # Model name for TinyLlama 1.1B Chat (instruction-tuned version)
29
+ # Or use "google/gemma-2b-it" if you got access and prefer its quality
30
+ model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
31
  tokenizer = None
32
  model = None
33
 
34
+ # Global Gradio Client for Captioning
35
+ caption_client = None
36
+ # The Space URL for the captioning API
37
+ CAPTION_SPACE_URL = "Makhinur/Image-to-Text-Salesforce-blip-image-captioning-base"
38
+
39
+
40
# Loader for the locally hosted language model and its tokenizer.
def load_model():
    """Populate the module-level ``tokenizer`` and ``model`` globals.

    Both are fetched via ``from_pretrained`` using the module-level
    ``model_name``. If the tokenizer has no pad token, its EOS token is
    reused as the pad token. Model weights are requested in
    ``torch.float16``. Progress is reported on stdout.
    """
    global tokenizer, model
    print(f"Loading language model: {model_name}...")

    loaded_tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer = loaded_tokenizer
    if loaded_tokenizer.pad_token is None:
        # generate() is later given pad_token_id, so one must exist.
        loaded_tokenizer.pad_token = loaded_tokenizer.eos_token

    model_options = {"torch_dtype": torch.float16}
    model = AutoModelForCausalLM.from_pretrained(model_name, **model_options)

    print(f"Language model {model_name} loaded successfully.")
57
+
58
# Best-effort construction of the Gradio client used for image captioning.
def initialize_caption_client():
    """Create the global ``caption_client`` for ``CAPTION_SPACE_URL``.

    Failures are logged to stdout and leave ``caption_client`` as ``None``
    rather than raising, so the application can still start; callers must
    check for ``None`` before using the client.
    """
    global caption_client
    print(f"Initializing Gradio client for {CAPTION_SPACE_URL}...")

    client = None
    try:
        client = Client(CAPTION_SPACE_URL)
        print("Gradio client initialized successfully.")
    except Exception as init_error:
        print(f"Error initializing Gradio client: {init_error}")
        client = None  # never keep a half-initialised client around

    caption_client = client
70
 
 
71
 
72
# FastAPI startup hook: prepare the language model and the caption client.
@app.on_event("startup")
async def startup_event():
    """Run the one-time initialisers, language model first, then the client."""
    for initializer in (load_model, initialize_caption_client):
        initializer()
77
+
78
+
79
# --- Image Captioning (Using gradio_client) ---
def generate_image_caption(image_file: UploadFile):
    """
    Generates a caption for the uploaded image using the external Gradio Space API.

    The upload is first copied to a temporary file on disk, because
    gradio_client documents file inputs as local paths or URLs (wrapped with
    ``handle_file`` in newer releases) rather than open file objects.

    Args:
        image_file: The uploaded image; its ``.file`` stream is read from the
            beginning.

    Returns:
        The caption as a string on success. On any failure a string starting
        with ``"Error"`` is returned instead of raising, because the caller
        checks ``caption.startswith("Error")``.
    """
    import os
    import shutil
    import tempfile

    if caption_client is None:
        # Client initialisation failed at startup; nothing can be called.
        error_msg = "Gradio caption client not initialized. Cannot generate caption."
        print(error_msg)
        return f"Error: {error_msg}"

    try:
        # Newer gradio_client versions require file inputs to be wrapped.
        from gradio_client import handle_file
    except ImportError:
        handle_file = None  # older client: a plain path string is accepted

    # Keep the original extension so the remote side can detect the format.
    suffix = os.path.splitext(image_file.filename or "")[1] or ".jpg"
    tmp_path = None
    try:
        print(f"Calling caption API /predict for file {image_file.filename}...")
        image_file.file.seek(0)
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
            shutil.copyfileobj(image_file.file, tmp)
            tmp_path = tmp.name

        image_arg = handle_file(tmp_path) if handle_file is not None else tmp_path
        caption = caption_client.predict(img=image_arg, api_name="/predict")
        print(f"Caption generated: {caption}")
        # The caller treats the result as a string.
        return caption if isinstance(caption, str) else str(caption)
    except Exception as e:
        # Network problems, remote API errors, and local I/O failures all
        # follow the same "Error: ..." contract.
        print(f"Error during caption generation API call: {e}")
        return f"Error: Unable to generate caption from API. Details: {e}"
    finally:
        # Do not leak temp files across requests.
        if tmp_path is not None:
            try:
                os.remove(tmp_path)
            except OSError:
                pass
102
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
 
104
# --- Language Model Story Generation Function ---
# The function called by the endpoint should match the configured model_name.

def generate_story_tinyllama(prompt_text: str, max_new_tokens: int = 300, temperature: float = 0.7, top_p: float = 0.9, top_k: int = 50) -> str:
    """
    Generates text using the loaded TinyLlama model.
    Applies the chat template.

    Args:
        prompt_text: Instruction sent to the model as a single user turn.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature passed to ``model.generate``.
        top_p: Nucleus-sampling threshold passed to ``model.generate``.
        top_k: Top-k sampling cutoff passed to ``model.generate``.

    Returns:
        The newly generated text (prompt tokens excluded), whitespace-stripped.

    Raises:
        RuntimeError: If the global tokenizer or model has not been loaded.
    """
    if tokenizer is None or model is None:
        raise RuntimeError("Language model and tokenizer not loaded. App startup failed?")

    messages = [
        {"role": "user", "content": prompt_text}
    ]

    # Whether the tokenizer should add its own special tokens when encoding.
    add_special_tokens = True
    try:
        input_text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
    except AttributeError:  # Fallback for tokenizers without apply_chat_template
        print("Warning: apply_chat_template not found. Using basic prompt formatting.")
        input_text = f"<s>[INST] {prompt_text} [/INST]"
        # The fallback prompt already spells out "<s>", so letting the
        # tokenizer prepend its own BOS would duplicate it.
        add_special_tokens = False

    inputs = tokenizer(
        input_text,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=1024,
        add_special_tokens=add_special_tokens,
    )

    generate_ids = model.generate(
        inputs.input_ids,
        # load_model may alias the pad token to EOS; in that case generate()
        # cannot infer the mask from pad_token_id, so pass it explicitly.
        attention_mask=inputs.attention_mask,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        pad_token_id=tokenizer.pad_token_id,
    )

    # generate() returns prompt + continuation; keep only the continuation of
    # the first (and only) sequence in the batch.
    prompt_length = inputs.input_ids.shape[-1]
    generated_text = tokenizer.decode(generate_ids[0, prompt_length:], skip_special_tokens=True)
    return generated_text.strip()
145
 
146
+ # If using Gemma 2B instead of TinyLlama, use this function:
147
+ # def generate_story_gemma(prompt_text: str, max_new_tokens: int = 300, temperature: float = 0.7, top_p: float = 0.9, top_k: int = 50) -> str:
148
+ # """
149
+ # Generates text using the loaded Gemma model.
150
+ # Applies the Gemma-IT chat template.
151
+ # """
152
+ # if tokenizer is None or model is None:
153
+ # raise RuntimeError("Language model and tokenizer not loaded. App startup failed?")
154
+
155
+ # messages = [
156
+ # {"role": "user", "content": prompt_text}
157
+ # ]
158
+ # input_text = tokenizer.apply_chat_template(
159
+ # messages,
160
+ # tokenize=False,
161
+ # add_generation_prompt=True
162
+ # )
163
+
164
+ # inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=1024)
165
+
166
+ # generate_ids = model.generate(
167
+ # inputs.input_ids,
168
+ # max_new_tokens=max_new_tokens,
169
+ # do_sample=True,
170
+ # temperature=temperature,
171
+ # top_p=top_p,
172
+ # top_k=top_k,
173
+ # pad_token_id=tokenizer.pad_token_id,
174
+ # )
175
+
176
+ # generated_text = tokenizer.decode(generate_ids[0, inputs.input_ids.shape[-1]:], skip_special_tokens=True)
177
+ # return generated_text.strip()
178
+
179
+
180
  # --- FastAPI Endpoint ---
181
  @app.post("/generate-story/")
182
  async def generate_story_endpoint(image_file: UploadFile = File(...), language: str = Form(...)):
183
+ # No longer need to read the image data fully here
184
+ # image_data = await image_file.read()
185
 
 
186
  story_theme = random.choice([
187
  'an adventurous journey',
188
  'a mysterious encounter',
 
196
  'a journey into the unknown'
197
  ])
198
 
199
+ # Get image caption using the gradio_client function
200
+ # Pass the UploadFile object directly
201
+ caption = generate_image_caption(image_file)
202
  if caption.startswith("Error"):
203
  print(f"Caption generation failed: {caption}")
204
  raise HTTPException(status_code=500, detail=caption)
205
 
206
+ # Construct the prompt for the language model
 
207
  prompt_text = f"Write an attractive story of around 300 words about {story_theme}. Incorporate the following details from an image description into the story: {caption}\n\nStory:"
208
 
209
+ # Generate the story using the appropriate function (adjust if using Gemma)
210
  try:
211
+ story = generate_story_tinyllama( # <--- Make sure this matches your chosen model function
212
  prompt_text,
213
+ max_new_tokens=300,
214
+ temperature=0.7,
215
+ top_p=0.9,
216
+ top_k=50
217
  )
 
218
  story = story.strip()
219
 
220
  except Exception as e:
221
+ print(f"Story generation failed: {e}")
 
222
  raise HTTPException(status_code=500, detail=f"Story generation failed: {e}. Please check Space logs for details.")
223
 
224
 
225
+ # Translate the story
226
  if language.lower() != "english":
227
  try:
228
  translator = GoogleTranslator(source='english', target=language.lower())
 
230
 
231
  if translated_story is None:
232
  print(f"Translation returned None for language: {language}")
 
233
  return {"story": story + "\n\n(Note: Automatic translation to your requested language failed.)"}
234
 
235
  story = translated_story
 
238
  print(f"Invalid target language requested: {language}")
239
  raise HTTPException(status_code=400, detail=f"Invalid target language: {language}")
240
  except Exception as e:
241
+ print(f"Translation failed for language {language}: {e}")
242
  raise HTTPException(status_code=500, detail=f"Translation failed: {e}")
243
 
 
 
244
  return {"story": story}
245
 
246
+ # --- Optional: HTML form for testing (Needs templates dir and index.html) ---
247
  # from fastapi import Request
248
  # from fastapi.templating import Jinja2Templates
249
  # from fastapi.staticfiles import StaticFiles
 
250
  # templates = Jinja2Templates(directory="templates")
251
  # app.mount("/static", StaticFiles(directory="static"), name="static")
 
252
  # @app.get("/", response_class=HTMLResponse)
253
  # async def read_root(request: Request):
254
  # return templates.TemplateResponse("index.html", {"request": request})