project21 committed · Commit eed221f · verified · 1 Parent(s): 1113ff0

Create app.py

Files changed (1): app.py (+481, -0)
app.py ADDED
@@ -0,0 +1,481 @@
# app.py
"""
ChatGPT-Premium-like open-source Gradio app with:
- multi-image upload (practical "unlimited" via disk+queue)
- OCR (PaddleOCR preferred, fallback to pytesseract)
- Visual reasoning (LLaVA/MiniGPT-style if model available)
- Math/aptitude pipeline (OCR -> math-specialized LLM)
- Caching of processed images & embeddings
- Simple in-process queue & streaming text output
- Rate-limiting per-client (token-bucket)

NOTES:
- Replace model IDs with ones that match your hardware/quotas.
- For production, swap the in-process queue with Redis/Celery and use S3/MinIO for storage.
- Achieving strictly "better than ChatGPT" across the board is unrealistic; this app aims to be the best open-source approximation.
"""
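
# Rough dependency sketch (an assumption, not part of the original commit; pin
# versions to match your environment):
#   pip install gradio transformers torch pillow
#   pip install paddleocr      # optional, preferred OCR backend
#   pip install pytesseract    # optional fallback; also requires the tesseract binary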

import os
import time
import uuid
import threading
import queue
import json
import math
from pathlib import Path
from typing import List, Dict, Tuple, Optional
from collections import defaultdict, deque

import gradio as gr
from PIL import Image
import torch
from transformers import (
    AutoProcessor, AutoModelForCausalLM,
    AutoTokenizer, TextIteratorStreamer
)

# Optional OCR libs
try:
    from paddleocr import PaddleOCR  # pip install paddleocr
    PADDLE_AVAILABLE = True
except Exception:
    PADDLE_AVAILABLE = False

try:
    import pytesseract  # pip install pytesseract
    TESSERACT_AVAILABLE = True
except Exception:
    TESSERACT_AVAILABLE = False

# ---------------------------
# CONFIG: change these values
# ---------------------------
# Paths
DATA_DIR = Path("data")
IMAGES_DIR = DATA_DIR / "images"
CACHE_DIR = DATA_DIR / "cache"
IMAGES_DIR.mkdir(parents=True, exist_ok=True)
CACHE_DIR.mkdir(parents=True, exist_ok=True)

# Models - pick models appropriate to your hardware.
# Visual reasoning model (LLaVA-style). If it cannot be loaded, the pipeline skips the visual-model step.
# NOTE: the original liuhaotian checkpoint may not load through the generic Auto* classes used below;
# the llava-hf conversions (e.g. "llava-hf/llava-1.5-7b-hf") are the transformers-compatible variants.
VISUAL_MODEL_ID = "liuhaotian/llava-v1.5-7b"  # heavy; change to a smaller model if needed
VISUAL_USE = True  # set False to skip the LLaVA step

# Math/Reasoning LLM
MATH_LLM_ID = "mistralai/Mistral-7B-Instruct-v0.2"  # good balance; change if you prefer LLaMA etc.

# Device
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Limits & performance tuning
MAX_IMAGES_PER_REQUEST = 64  # reasonable UI limit
BATCH_SIZE = 4  # how many images we process at once for visual models
MAX_HISTORY_TOKENS = 2048
STREAM_CHUNK_SECONDS = 0.12  # how often we yield tokens to the user during streaming

# Rate limit settings (simple token bucket)
RATE_TOKENS = 40  # tokens added per interval
RATE_INTERVAL = 60  # seconds for refill
TOKENS_PER_REQUEST = 1  # cost per chat request (tune)

# ---------------------------
# Utilities: storage, caching
# ---------------------------
def save_uploaded_image(uploaded) -> Path:
    # Gradio may hand us either a plain file path (str) or a tempfile-like
    # object with a .name attribute, depending on version and component settings.
    src_path = uploaded if isinstance(uploaded, str) else uploaded.name
    uid = uuid.uuid4().hex
    ext = Path(src_path).suffix or ".png"
    dest = IMAGES_DIR / f"{int(time.time())}_{uid}{ext}"
    # Copy content
    with open(src_path, "rb") as src, open(dest, "wb") as dst:
        dst.write(src.read())
    return dest

# simple file-based cache for captions & OCR text
def cache_get(key: str) -> Optional[str]:
    p = CACHE_DIR / f"{key}.json"
    if p.exists():
        try:
            return json.loads(p.read_text())["value"]
        except Exception:
            return None
    return None

def cache_set(key: str, value: str):
    p = CACHE_DIR / f"{key}.json"
    p.write_text(json.dumps({"value": value}))

def path_hash(p: Path) -> str:
    # simple hash: file name + size + mtime
    st = p.stat()
    return f"{p.name}-{st.st_size}-{int(st.st_mtime)}"

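# Illustration of the cache layout implemented above (hypothetical file name):
# each entry is a small JSON file keyed by kind + path_hash, e.g.
#   key = "ocr-" + path_hash(Path("data/images/1700000000_ab12.png"))
#   cache_set(key, "3x + 5 = 20")   # writes data/cache/<key>.json -> {"value": "3x + 5 = 20"}
#   cache_get(key)                  # -> "3x + 5 = 20" (or None on a cache miss)
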
# ---------------------------
# Rate limiter (per client id; the UI uses a per-window UUID rather than the real IP)
# ---------------------------
class TokenBucket:
    def __init__(self, rate=RATE_TOKENS, per=RATE_INTERVAL):
        self.rate = rate
        self.per = per
        self.allowance = rate
        self.last_check = time.time()

    def consume(self, tokens=1) -> bool:
        now = time.time()
        elapsed = now - self.last_check
        self.last_check = now
        self.allowance += elapsed * (self.rate / self.per)
        if self.allowance > self.rate:
            self.allowance = self.rate
        if self.allowance >= tokens:
            self.allowance -= tokens
            return True
        return False

rate_buckets = defaultdict(lambda: TokenBucket())

def rate_ok(client_id: str) -> bool:
    return rate_buckets[client_id].consume(TOKENS_PER_REQUEST)

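# Worked example of the bucket arithmetic above (illustrative numbers): with
# RATE_TOKENS=40 and RATE_INTERVAL=60 the allowance refills at 40/60 ≈ 0.67
# tokens per second, so a client can burst up to 40 requests and then sustain
# roughly 40 requests per minute. Quick sanity check (sketch):
#   b = TokenBucket(rate=2, per=1)            # 2 tokens per second
#   b.consume(), b.consume(), b.consume()     # -> True, True, False (bucket empty)
#   time.sleep(0.6); b.consume()              # -> True (~1.2 tokens refilled)
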
# ---------------------------
# OCR utilities
# ---------------------------
paddle_ocr = None
if PADDLE_AVAILABLE:
    paddle_ocr = PaddleOCR(use_angle_cls=True, lang="en")  # slow to init the first time

def run_ocr(path: Path) -> str:
    """
    High-quality OCR pipeline: PaddleOCR -> pytesseract fallback.
    """
    key = f"ocr-{path_hash(path)}"
    cached = cache_get(key)
    if cached:
        return cached

    text = ""
    try:
        if paddle_ocr:
            result = paddle_ocr.ocr(str(path), cls=True)
            lines = []
            for rec in (result or []):  # result can be None / contain None when nothing is detected
                for box, rec_res in (rec or []):
                    txt = rec_res[0]
                    lines.append(txt)
            text = "\n".join(lines).strip()
    except Exception:
        # PaddleOCR may fail on some setups; fall through to pytesseract
        text = ""

    if not text and TESSERACT_AVAILABLE:
        try:
            pil = Image.open(path).convert("RGB")
            text = pytesseract.image_to_string(pil).strip()
        except Exception:
            text = ""

    cache_set(key, text or "")
    return text

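# Result shape assumed by the PaddleOCR loop in run_ocr above (PaddleOCR 2.x;
# verify against your installed version):
#   result = [                                     # one entry per image/page
#       [                                          # one entry per detected text line
#           [box_points, ("recognized text", confidence)],
#           ...
#       ],
#   ]
# so rec_res[0] is the recognized string and rec_res[1] its confidence score.
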
# ---------------------------
# Visual reasoning (LLaVA) wrapper
# ---------------------------
visual_processor = None
visual_model = None
visual_tokenizer = None

def init_visual_model():
    global visual_processor, visual_model, visual_tokenizer
    if not VISUAL_USE:
        return
    try:
        visual_processor = AutoProcessor.from_pretrained(VISUAL_MODEL_ID)
        visual_model = AutoModelForCausalLM.from_pretrained(
            VISUAL_MODEL_ID,
            torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
            device_map="auto"
        )
        # Some LLaVA models need the tokenizer from the model repo
        visual_tokenizer = AutoTokenizer.from_pretrained(VISUAL_MODEL_ID, use_fast=False)
        print("Visual model loaded.")
    except Exception as e:
        print("Could not load visual model:", e)
        # disable the visual step if loading fails
        visual_processor = visual_model = visual_tokenizer = None

# Combine visual and text pipelines: pass image + question -> string answer
def run_visual_reasoning(image_path: Path, question: str, max_new_tokens=256) -> str:
    if visual_processor is None or visual_model is None:
        return ""
    key = f"visual-{path_hash(image_path)}-{question[:96]}"
    cached = cache_get(key)
    if cached:
        return cached

    try:
        image = Image.open(image_path).convert("RGB")
        inputs = visual_processor(images=image, text=question, return_tensors="pt").to(DEVICE)
        with torch.no_grad():
            outs = visual_model.generate(**inputs, max_new_tokens=max_new_tokens)
        ans = visual_tokenizer.decode(outs[0], skip_special_tokens=True)
        cache_set(key, ans)
        return ans
    except Exception as e:
        print("Visual reasoning error:", e)
        return ""

# ---------------------------
# Math/Reasoning LLM init
# ---------------------------
math_tokenizer = None
math_model = None

def init_math_model():
    global math_tokenizer, math_model
    try:
        math_tokenizer = AutoTokenizer.from_pretrained(MATH_LLM_ID, use_fast=False)
        math_model = AutoModelForCausalLM.from_pretrained(
            MATH_LLM_ID,
            torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
            device_map="auto"
        )
        print("Math LLM loaded.")
    except Exception as e:
        print("Could not load math model:", e)
        math_model = None

def ask_math_llm(prompt: str, stream=False):
    """
    If stream=True, return a generator that yields the growing partial text.
    Otherwise, return the final string.
    """
    if math_model is None:
        msg = "Math model not available."
        return iter([msg]) if stream else msg

    inputs = math_tokenizer(prompt, return_tensors="pt", truncation=True, max_length=MAX_HISTORY_TOKENS).to(DEVICE)

    if not stream:
        with torch.no_grad():
            out_ids = math_model.generate(**inputs, max_new_tokens=512)
        return math_tokenizer.decode(out_ids[0], skip_special_tokens=True)

    # Streaming mode using TextIteratorStreamer. The streaming body lives in an
    # inner generator so this outer function can still *return* plain strings
    # above (a def containing `yield` anywhere would otherwise always return a
    # generator and those return values would be lost).
    streamer = TextIteratorStreamer(math_tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=512,
        do_sample=True,
        temperature=0.7,
        top_p=0.9
    )

    def _stream():
        # run generation in a background thread; yield decoded text as it arrives
        thread = threading.Thread(target=math_model.generate, kwargs=generation_kwargs)
        thread.start()
        buffer = ""
        for new_text in streamer:
            buffer += new_text
            yield buffer

    return _stream()

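# Usage sketch for ask_math_llm (illustrative prompt and output):
#   answer = ask_math_llm("Solve 3x + 5 = 20 step by step.")        # returns a plain string
#   for partial in ask_math_llm("Solve 3x + 5 = 20.", stream=True):
#       ...  # each `partial` is the full text generated so far, e.g. "3x = 15, so x = 5"
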
# ---------------------------
# Simple in-process queue for heavy tasks (visual + OCR)
# ---------------------------
work_q = queue.Queue(maxsize=256)
results_cache = {}  # job_id -> result

def worker_loop():
    while True:
        job = work_q.get()
        if job is None:
            break
        job_id, image_paths, question = job
        try:
            ocr_texts = [run_ocr(p) for p in image_paths]
            visual_texts = []
            if visual_processor and visual_model:
                for p in image_paths:
                    v = run_visual_reasoning(p, question)
                    visual_texts.append(v)
            # combine
            combined = {
                "ocr": ocr_texts,
                "visual": visual_texts
            }
            results_cache[job_id] = combined
        except Exception as e:
            results_cache[job_id] = {"error": str(e)}
        finally:
            work_q.task_done()

# start a few worker threads
NUM_WORKERS = max(1, min(4, (os.cpu_count() or 2) // 2))
for _ in range(NUM_WORKERS):
    t = threading.Thread(target=worker_loop, daemon=True)
    t.start()

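# Shape of a finished job in results_cache (illustrative values): one OCR string
# and, when the visual model is loaded, one caption/answer per uploaded image:
#   results_cache[job_id] == {
#       "ocr":    ["Q1) If 3x + 5 = 20, find x."],
#       "visual": ["The image shows a printed algebra question."],
#   }
# or {"error": "..."} if processing failed.
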
# ---------------------------
# Main chat pipeline: orchestrates OCR/visual + math llm + chat memory
# ---------------------------
def build_prompt(system_prompt: str, chat_history: List[Tuple[str, str]], extracted_texts: List[str], user_question: str) -> str:
    # Keep a compact, relevant prompt
    history_text = ""
    for role, text in chat_history[-8:]:  # keep last N turns
        history_text += f"{role}: {text}\n"
    img_ctx = ""
    if extracted_texts:
        img_ctx = "\n\nEXTRACTED_FROM_IMAGES:\n" + "\n---\n".join(extracted_texts)
    prompt = f"""{system_prompt}

Conversation:
{history_text}

User question:
{user_question}

{img_ctx}

Assistant (explain step-by-step, show calculations if any):"""
    return prompt

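# Illustrative output of build_prompt for a single turn (abridged; the exact
# wording comes from SYSTEM_PROMPT defined just below):
#
#   You are a helpful assistant that solves aptitude, math, and image-based questions. ...
#
#   Conversation:
#   User: If 3x + 5 = 20, find x.
#
#   User question:
#   If 3x + 5 = 20, find x.
#
#   EXTRACTED_FROM_IMAGES:
#   OCR: 3x + 5 = 20
#
#   Assistant (explain step-by-step, show calculations if any):
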
SYSTEM_PROMPT = "You are a helpful assistant that solves aptitude, math, and image-based questions. Be precise, show steps, and if images contain diagrams refer to them."

# simple per-session memory (in-process). For production, persist it in a DB.
SESSION_MEMORY = defaultdict(lambda: {"history": [], "embeddings": []})

def process_request(client_id: str, uploaded_files, user_question: str, stream=True):
    # This function contains `yield`, so every call returns a generator;
    # early exits must yield their message rather than return it.
    # Rate limiting
    if not rate_ok(client_id):
        yield "Rate limit exceeded. Try again later."
        return

    # Enforce the image limit before doing any work
    if uploaded_files and len(uploaded_files) > MAX_IMAGES_PER_REQUEST:
        yield f"Too many images - max {MAX_IMAGES_PER_REQUEST}"
        return

    # Save uploaded files
    image_paths = []
    for f in (uploaded_files or []):
        image_paths.append(save_uploaded_image(f))

    # Create a job to process OCR + visual reasoning
    job_id = uuid.uuid4().hex
    work_q.put((job_id, image_paths, user_question))

    # Wait for the job to complete (small timeout); for a more scalable UI this
    # should be async, notifying the user when results are ready.
    wait_seconds = 0
    while job_id not in results_cache and wait_seconds < 12:
        time.sleep(0.25)
        wait_seconds += 0.25

    if job_id not in results_cache:
        # fallback: run basic OCR inline (slower but reliable)
        ocr_texts = [run_ocr(p) for p in image_paths]
        visual_texts = []
        if visual_processor and visual_model:
            for p in image_paths:
                visual_texts.append(run_visual_reasoning(p, user_question))
        results = {"ocr": ocr_texts, "visual": visual_texts}
    else:
        results = results_cache.pop(job_id, {"ocr": [], "visual": []})

    # Combine OCR text and visual captions per image. Pad the visual list so
    # OCR output is not dropped when the visual model is disabled (zip() stops
    # at the shorter sequence).
    ocr_list = results.get("ocr", [])
    visual_list = results.get("visual", [])
    visual_list = visual_list + [""] * (len(ocr_list) - len(visual_list))
    extracted_texts = []
    for o, v in zip(ocr_list, visual_list):
        parts = []
        if o:
            parts.append("OCR: " + o)
        if v:
            parts.append("Visual: " + v)
        combined = "\n".join(parts).strip()
        if combined:
            extracted_texts.append(combined)

    # add to session memory
    sess = SESSION_MEMORY[client_id]
    sess["history"].append(("User", user_question))
    # Build the LLM prompt
    prompt = build_prompt(SYSTEM_PROMPT, sess["history"], extracted_texts, user_question)

    # stream or non-stream generation
    if stream:
        # streaming generator using ask_math_llm(stream=True)
        yield from _stream_llm_response_generator(prompt, client_id)
    else:
        answer = ask_math_llm(prompt, stream=False)
        sess["history"].append(("Assistant", answer))
        yield answer

def _stream_llm_response_generator(prompt: str, client_id: str):
    # Yields the growing partial answer as plain strings so the Gradio UI can
    # update progressively.
    session = SESSION_MEMORY[client_id]
    gen = ask_math_llm(prompt, stream=True)
    partial = ""
    for chunk in gen:
        # chunk is the full buffer generated so far
        partial = chunk
        yield partial
    # record the final answer in session memory
    session["history"].append(("Assistant", partial))

# ---------------------------
# GRADIO UI
# ---------------------------
with gr.Blocks(css="""
/* small CSS to make the chat look nicer */
.chat-column { max-width: 900px; margin-left: auto; margin-right: auto; }
""") as demo:

    gr.Markdown("# 🚀 Open-Source ChatGPT-like (Multimodal)")

    with gr.Row():
        with gr.Column(scale=8, elem_classes="chat-column"):
            # height is a constructor argument in recent Gradio releases; the old .style() helper was removed
            chatbot = gr.Chatbot(label="Assistant", elem_id="chatbot", show_label=False, height=600)
            with gr.Row():
                txt = gr.Textbox(label="Type a message...", placeholder="Ask a question or upload images", show_label=False)
                submit = gr.Button("Send")
            with gr.Row():
                img_in = gr.File(label="Upload images (multiple)", file_count="multiple", file_types=["image"])
                clear_btn = gr.Button("New Chat")
            client_id_state = gr.State(str(uuid.uuid4()))  # simple per-window client id for rate limiting

    def handle_send(message, client_state, files):
        client_id = client_state or str(uuid.uuid4())
        # process_request returns a generator; Gradio treats a generator handler
        # as a stream of progressive updates. gr.Chatbot expects a list of
        # (user_message, assistant_message) pairs, so each yield rebuilds the
        # current turn. Only the current turn is rendered here; the full history
        # lives in SESSION_MEMORY.
        gen = process_request(client_id, files, message, stream=True)
        collected = ""
        try:
            for part in gen:
                collected = part  # current partial buffer
                yield "", [(message, collected)]
        except Exception as e:
            yield "", [(message, f"Error generating: {e}")]
            return
        # final update (guarantee)
        yield "", [(message, collected)]

    # Connect the send button and the textbox
    submit.click(handle_send, inputs=[txt, client_id_state, img_in], outputs=[txt, chatbot])
    txt.submit(handle_send, inputs=[txt, client_id_state, img_in], outputs=[txt, chatbot])

    def clear_chat():
        # reset the visible chat and the textbox; SESSION_MEMORY and the rate
        # bucket are keyed by the per-window client id and are left untouched
        return [], ""
    clear_btn.click(clear_chat, None, [chatbot, txt])

    # initialize the heavy models in the background to avoid blocking Gradio start-up
    def bg_init():
        init_visual_model()
        init_math_model()
    threading.Thread(target=bg_init, daemon=True).start()

demo.launch(server_name="0.0.0.0", server_port=7860, share=False)