Makhinur committed on
Commit f9a665c · verified · 1 Parent(s): 2c1590e

Update main.py

Files changed (1)
  1. main.py +191 -73
main.py CHANGED
@@ -1,101 +1,219 @@
  from fastapi import FastAPI, File, UploadFile, Form, HTTPException
- from fastapi.responses import HTMLResponse
- from fastapi.staticfiles import StaticFiles
- from fastapi.templating import Jinja2Templates
- from fastapi.responses import FileResponse
  import requests
  import base64
- from typing import Iterator
  import os
- from text_generation import Client
  from deep_translator import GoogleTranslator

  app = FastAPI()

- model_id = 'codellama/CodeLlama-34b-Instruct-hf'
-
- API_URL = "https://api-inference.huggingface.co/models/" + model_id
- HF_TOKEN = os.environ.get("HF_TOKEN", None)
-
- client = Client(
-     API_URL,
-     headers={"Authorization": f"Bearer {HF_TOKEN}"},
- )
- EOS_STRING = "</s>"
- EOT_STRING = "<EOT>"
-
-
- def get_prompt(message: str, chat_history: list[tuple[str, str]],
-                system_prompt: str) -> str:
-     texts = [f'<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n']
-     do_strip = False
-     for user_input, response in chat_history:
-         user_input = user_input.strip() if do_strip else user_input
-         do_strip = True
-         texts.append(f'{user_input} [/INST] {response.strip()} </s><s>[INST] ')
-     message = message.strip() if do_strip else message
-     texts.append(f'{message} [/INST]')
-     return ''.join(texts)
-
-
- def run(message: str,
-         chat_history: list[tuple[str, str]],
-         system_prompt: str,
-         max_new_tokens: int = 1024,
-         temperature: float = 0.1,
-         top_p: float = 0.9,
-         top_k: int = 50) -> Iterator[str]:
-     prompt = get_prompt(message, chat_history, system_prompt)
-
-     generate_kwargs = dict(
-         max_new_tokens=max_new_tokens,
-         do_sample=True,
-         top_p=top_p,
-         top_k=top_k,
-         temperature=temperature,
      )
-     stream = client.generate_stream(prompt, **generate_kwargs)
-     output = ""
-     for response in stream:
-         if any([end_token in response.token.text for end_token in [EOS_STRING, EOT_STRING]]):
-             yield output
-             output = ""
-         else:
-             output += response.token.text


  def generate_image_caption(image_data):
-     image_base64 = base64.b64encode(image_data).decode('utf-8')
-     payload = {"data": ["data:image/jpeg;base64," + image_base64]}
      response = requests.post("https://makhinur-image-to-text-salesforce-blip-image-cap-c0a9076.hf.space/run/predict", json=payload)
      if response.status_code == 200:
-         caption = response.json()["data"][0]
-         return caption
      else:
-         return "Error: Unable to generate caption"

- import random

- from fastapi import Query
- from deep_translator import GoogleTranslator
- from deep_translator.exceptions import InvalidSourceOrTargetLanguage

- from fastapi import Query

  @app.post("/generate-story/")
- async def generate_story(image_file: UploadFile = File(...), language: str = Form(...)):
      image_data = await image_file.read()
-     system_prompt = f"write an attractive story in 300 words about {random.choice(['an adventurous journey', 'a mysterious encounter', 'a heroic quest', 'a magical adventure', 'a thrilling escape', 'an unexpected discovery', 'a dangerous mission', 'a romantic escapade', 'an epic battle', 'a journey into the unknown'])}"

      caption = generate_image_caption(image_data)
      if caption.startswith("Error"):
          raise HTTPException(status_code=500, detail=caption)
-     ai_response = next(run(caption, [], system_prompt))
-
-     if language != "english":
-         translator = GoogleTranslator(source='english', target=language)
-         ai_response = translator.translate(ai_response)

-     return {"story": ai_response}

  from fastapi import FastAPI, File, UploadFile, Form, HTTPException
+ # Keep these if you use them elsewhere in your app (HTML, static files)
+ # from fastapi.responses import HTMLResponse
+ # from fastapi.staticfiles import StaticFiles
+ # from fastapi.templating import Jinja2Templates
+ # from fastapi.responses import FileResponse
+
  import requests
  import base64
  import os
+ import random
+
+ # Import necessary classes from transformers
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig  # BitsAndBytesConfig kept in case quantization is ever needed
+
+
  from deep_translator import GoogleTranslator
+ from deep_translator.exceptions import InvalidSourceOrTargetLanguage
+

  app = FastAPI()

+ # --- Hugging Face Model Setup (Local) ---
+ # Model name for Gemma 2B Instruction-Tuned.
+ # This version is trained to follow instructions, which is ideal for this task.
+ model_name = "google/gemma-2b-it"
+ tokenizer = None
+ model = None
+
+ # Function to load the model and tokenizer
+ def load_model():
+     global tokenizer, model
+     print(f"Loading model: {model_name}...")
+
+     # Load tokenizer.
+     # trust_remote_code=True might be needed for some newer models/features,
+     # but standard Gemma usually works without it; omit it for security unless necessary.
+     tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+     # Load the model. Gemma can be loaded in float16 to save RAM.
+     # On CPU, float16 performance can vary, but it reduces memory bandwidth,
+     # which can sometimes help. 16GB RAM is plenty for Gemma 2B float16 (~2GB).
+     # Quantization (load_in_8bit/4bit) is not needed for Gemma 2B with 16GB RAM,
+     # but it is an option for larger models or less RAM.
+     model = AutoModelForCausalLM.from_pretrained(
+         model_name,
+         torch_dtype=torch.float16,  # Use float16 precision
+         # device_map="auto"  # Not strictly needed for single-CPU inference
      )
+     # model.to("cpu")  # Explicitly ensure it's on CPU, although from_pretrained on CPU does this.

+     print(f"Model {model_name} loaded successfully.")

+ # Load the model when the app starts
+ @app.on_event("startup")
+ async def startup_event():
+     load_model()
+
+ # --- Image Captioning (External API - Keep) ---
+ # Keep this as it is; it uses an external service.
  def generate_image_caption(image_data):
+     payload = {"data": ["data:image/jpeg;base64," + base64.b64encode(image_data).decode('utf-8')]}
+     # Use the correct URL for the captioning API. This is the one from the original code.
+     # Ensure it's stable or replace if needed.
      response = requests.post("https://makhinur-image-to-text-salesforce-blip-image-cap-c0a9076.hf.space/run/predict", json=payload)
      if response.status_code == 200:
+         try:
+             result = response.json()
+             caption = result.get("data", ["Error: Unexpected API response format"])[0]
+             return caption
+         except Exception as e:
+             return f"Error: Failed to parse caption API response: {e}"
      else:
+         return f"Error: Caption API returned status code {response.status_code}: {response.text}"
+
+
+ # --- Gemma Story Generation Function ---
+ # Replaces the old generation function with one specific to Gemma-IT.
+ def generate_story_gemma(prompt_text: str, max_new_tokens: int = 300, temperature: float = 0.7, top_p: float = 0.9, top_k: int = 50) -> str:
+     """
+     Generates text using the loaded Gemma model.
+     Applies the Gemma-IT chat template to the prompt.
+     """
+     if tokenizer is None or model is None:
+         raise RuntimeError("Model and tokenizer not loaded. App startup failed?")
+
+     # Gemma-IT uses a specific chat template; wrap the user's prompt in it.
+     messages = [
+         {"role": "user", "content": prompt_text}
+         # A system prompt could be added here if desired, but Gemma-IT
+         # often works well with a detailed user prompt.
+     ]
+
+     # Apply the chat template. This adds the special tokens and formatting
+     # the model needs to understand the instruction.
+     # `add_generation_prompt=True` adds the token that signals the model
+     # should start generating its response.
+     input_text = tokenizer.apply_chat_template(
+         messages,
+         tokenize=False,  # Keep as string for encoding later
+         add_generation_prompt=True  # Add the assistant turn prompt
+     )

+     # Encode the templated prompt.
+     # Max length should consider the prompt length + generated length.
+     # Gemma's max input context is 8192 tokens, but a shorter prompt is better for CPU.
+     inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=1024)  # A reasonable max_length for input

+     # Ensure inputs are on the correct device (CPU by default).
+     # inputs = {k: v.to(model.device) for k, v in inputs.items()}  # Redundant on CPU

+     # Generate text.
+     # The generate method returns the input_ids plus the generated tokens.
+     generate_ids = model.generate(
+         inputs.input_ids,
+         max_new_tokens=max_new_tokens,
+         do_sample=True,  # Sample for creative text generation
+         temperature=temperature,
+         top_p=top_p,
+         top_k=top_k,
+         pad_token_id=tokenizer.pad_token_id,  # Use the pad token during generation
+         # Gemma's EOS token is usually handled by the default generate logic
+         # eos_token_id=tokenizer.eos_token_id
+     )

+     # Decode the generated text.
+     # Slice generate_ids to exclude the input prompt tokens, decoding only the new ones.
+     # The slice [0, inputs.input_ids.shape[-1]:] selects the generated part for the first (and only) item in the batch.
+     # skip_special_tokens=True removes special tokens like <start_of_turn>, <end_of_turn>, <eos>.
+     generated_text = tokenizer.decode(generate_ids[0, inputs.input_ids.shape[-1]:], skip_special_tokens=True)

+     # Gemma responses may include extra whitespace or turn markers if decoding is not perfect.
+     # Further cleanup may be needed depending on the exact output, but skip_special_tokens helps.
+     # Also strip leading/trailing whitespace.
+     return generated_text.strip()

+ # --- FastAPI Endpoint ---
  @app.post("/generate-story/")
+ async def generate_story_endpoint(image_file: UploadFile = File(...), language: str = Form(...)):
      image_data = await image_file.read()

+     # Choose a random theme for the story prompt
+     story_theme = random.choice([
+         'an adventurous journey',
+         'a mysterious encounter',
+         'a heroic quest',
+         'a magical adventure',
+         'a thrilling escape',
+         'an unexpected discovery',
+         'a dangerous mission',
+         'a romantic escapade',
+         'an epic battle',
+         'a journey into the unknown'
+     ])
+
+     # Get image caption
      caption = generate_image_caption(image_data)
      if caption.startswith("Error"):
+         print(f"Caption generation failed: {caption}")
          raise HTTPException(status_code=500, detail=caption)

+     # Construct the detailed prompt for Gemma-IT.
+     # Instruct it clearly to write a story based on the theme, incorporating the caption.
+     prompt_text = f"Write an attractive story of around 300 words about {story_theme}. Incorporate the following details from an image description into the story: {caption}\n\nStory:"
+
+     # Generate the story using the local Gemma model
+     try:
+         story = generate_story_gemma(
+             prompt_text,
+             max_new_tokens=300,  # Generate up to 300 new tokens
+             temperature=0.7,     # Controls randomness (higher = more random)
+             top_p=0.9,           # Controls diversity (nucleus sampling)
+             top_k=50             # Controls diversity (top-k sampling)
+         )
+         # Basic cleanup: sometimes models start with whitespace or unwanted characters
+         story = story.strip()
+
+     except Exception as e:
+         print(f"Story generation failed: {e}")  # Log generation errors
+         # Provide more detail in the HTTP exception for debugging
+         raise HTTPException(status_code=500, detail=f"Story generation failed: {e}. Please check Space logs for details.")
+
+
+     # Translate the story if the target language is not English
+     if language.lower() != "english":
+         try:
+             translator = GoogleTranslator(source='english', target=language.lower())
+             translated_story = translator.translate(story)
+
+             if translated_story is None:
+                 print(f"Translation returned None for language: {language}")
+                 # Return the English story with a warning
+                 return {"story": story + "\n\n(Note: Automatic translation to your requested language failed.)"}
+
+             story = translated_story
+
+         except InvalidSourceOrTargetLanguage:
+             print(f"Invalid target language requested: {language}")
+             raise HTTPException(status_code=400, detail=f"Invalid target language: {language}")
+         except Exception as e:
+             print(f"Translation failed for language {language}: {e}")  # Log translation errors
+             raise HTTPException(status_code=500, detail=f"Translation failed: {e}")
+
+
+     # Return the generated (and potentially translated) story
+     return {"story": story}
+
+ # --- Optional: Serve a simple HTML form for testing (needs a templates dir and index.html) ---
+ # from fastapi import Request
+ # from fastapi.templating import Jinja2Templates
+ # from fastapi.staticfiles import StaticFiles
+
+ # templates = Jinja2Templates(directory="templates")
+ # app.mount("/static", StaticFiles(directory="static"), name="static")
+
+ # @app.get("/", response_class=HTMLResponse)
+ # async def read_root(request: Request):
+ #     return templates.TemplateResponse("index.html", {"request": request})
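
Switching from the hosted text-generation client to local transformers inference also changes the Space's dependencies. A plausible requirements.txt for this version of main.py (not part of this commit, with version pins left to the author) would be roughly the following; python-multipart is needed because the endpoint declares File(...) and Form(...) parameters:

    fastapi
    uvicorn
    python-multipart
    requests
    torch
    transformers
    deep-translator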
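BitsAndBytesConfig is imported but unused here; the comments in load_model() mention quantization only as an option for larger models or tighter RAM. If that route were ever taken, a minimal sketch would look like the code below. This is an illustration, not part of the commit, and it assumes the bitsandbytes package and a CUDA GPU, so it does not apply to a CPU-only Space:

    import torch
    from transformers import AutoModelForCausalLM, BitsAndBytesConfig

    # Hypothetical 4-bit loading sketch; not used by this commit.
    quant_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
    )
    model = AutoModelForCausalLM.from_pretrained(
        "google/gemma-2b-it",
        quantization_config=quant_config,
        device_map="auto",  # place layers on the available GPU(s)
    )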
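For readers unfamiliar with apply_chat_template: for a single user turn with tokenize=False and add_generation_prompt=True, the Gemma-IT template yields a plain string roughly of the shape sketched below. The exact special tokens come from the tokenizer, so treat this as illustrative rather than authoritative:

    # Approximate string that generate_story_gemma() encodes and feeds to model.generate():
    # <bos><start_of_turn>user
    # Write an attractive story of around 300 words about ...<end_of_turn>
    # <start_of_turn>model
    #
    # The model's completion follows the final "<start_of_turn>model" marker,
    # which is why the code decodes only the tokens generated after the prompt length.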
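A minimal client sketch for exercising the updated endpoint is shown below. The host, port, file name, and language value are placeholders (Spaces typically serve FastAPI apps on port 7860), so adjust them to the actual deployment:

    import requests

    # Hypothetical local test against the /generate-story/ endpoint.
    with open("photo.jpg", "rb") as f:
        resp = requests.post(
            "http://localhost:7860/generate-story/",
            files={"image_file": ("photo.jpg", f, "image/jpeg")},
            data={"language": "english"},
        )
    resp.raise_for_status()
    print(resp.json()["story"])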