Spaces: Runtime error
import gradio as gr
from transformers import AutoModelForCausalLM, AutoProcessor
import torch
import numpy as np
import cv2
import logging

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Load the model and processor
try:
    logger.info("Loading model: microsoft/florence-2-base")
    model = AutoModelForCausalLM.from_pretrained("microsoft/florence-2-base", trust_remote_code=True)
    processor = AutoProcessor.from_pretrained("microsoft/florence-2-base", trust_remote_code=True)
    logger.info("Model and processor loaded successfully")
except Exception as e:
    logger.error("Failed to load model: %s", str(e))
    raise
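# Note: on CPU-only Spaces, loading Florence-2 sometimes fails because its
# remote code tries to import flash_attn. If that is what triggers the runtime
# error (an assumption; no log is shown here), requesting eager attention at
# load time is a commonly used workaround:
#   model = AutoModelForCausalLM.from_pretrained(
#       "microsoft/florence-2-base", trust_remote_code=True,
#       attn_implementation="eager")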
def analyze_image(image, prompt):
    logger.info("Starting image analysis with prompt: %s", prompt)

    # Convert the PIL image to a numpy array (RGB) and an OpenCV copy (BGR)
    try:
        image_np = np.array(image)
        image_cv = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
        logger.info("Image shape: %s", image_np.shape)
    except Exception as e:
        logger.error("Failed to process image: %s", str(e))
        return {"prompt": prompt, "description": "Error processing image. Ensure a valid image is uploaded."}
    # Image preprocessing: boost contrast with CLAHE. The result is kept in a
    # separate variable so the original BGR image stays available for the color
    # extraction below (overwriting image_cv would make every candle look gray).
    image_enhanced = image_cv
    try:
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        gray = cv2.cvtColor(image_cv, cv2.COLOR_BGR2GRAY)
        enhanced = clahe.apply(gray)
        image_enhanced = cv2.cvtColor(enhanced, cv2.COLOR_GRAY2BGR)
        logger.info("Image preprocessing completed")
    except Exception as e:
        logger.warning("Failed to preprocess image: %s", str(e))
    # General image description
    if "what do you see" in prompt.lower() or "was siehst du" in prompt.lower():
        try:
            # Florence-2 only understands its fixed task prompts, so map the
            # free-form question to the detailed-captioning task instead of
            # passing the raw prompt through.
            inputs = processor(text="<MORE_DETAILED_CAPTION>", images=image_np, return_tensors="pt")
            with torch.no_grad():
                outputs = model.generate(
                    input_ids=inputs["input_ids"],
                    pixel_values=inputs["pixel_values"],
                    max_length=1024,
                    num_beams=3
                )
            description = processor.batch_decode(outputs, skip_special_tokens=True)[0]
            return {"prompt": prompt, "description": description}
        except Exception as e:
            logger.error("Failed to generate description: %s", str(e))
            return {"prompt": prompt, "description": "Error generating description. Try again with a clear image."}
    # Candlestick analysis
    elif "last 8 candles" in prompt.lower() or "letzte 8 kerzen" in prompt.lower():
        try:
            task_prompt = "<OD>"  # object detection
            inputs = processor(text=task_prompt, images=image_np, return_tensors="pt")
            with torch.no_grad():
                outputs = model.generate(
                    input_ids=inputs["input_ids"],
                    pixel_values=inputs["pixel_values"],
                    max_length=1024,
                    num_beams=3
                )
            # post_process_generation expects the decoded text (special tokens
            # intact), not the raw output tensor
            generated_text = processor.batch_decode(outputs, skip_special_tokens=False)[0]
            predictions = processor.post_process_generation(
                generated_text, task=task_prompt,
                image_size=(image_np.shape[1], image_np.shape[0])
            )
            logger.info("Detected objects: %s", predictions)
            detections = []
            if "<OD>" in predictions:
                for bbox, label in zip(predictions["<OD>"]["bboxes"], predictions["<OD>"]["labels"]):
                    # Keep only candle-like detections
                    if not any(k in label.lower() for k in ("candle", "candlestick", "bar", "chart")):
                        continue
                    xmin, ymin, xmax, ymax = map(int, bbox)
                    # Extract the mean color from the original (non-enhanced) BGR image
                    candle_roi = image_cv[ymin:ymax, xmin:xmax]
                    if candle_roi.size == 0:
                        logger.warning("Empty ROI for box: (%d, %d, %d, %d)", xmin, ymin, xmax, ymax)
                        continue
                    mean_color = np.mean(candle_roi, axis=(0, 1)).astype(int)
                    # The ROI is BGR, so reverse the channel order for the RGB label
                    color_rgb = f"RGB({mean_color[2]},{mean_color[1]},{mean_color[0]})"
                    # OCR for prices on a widened ROI; use the contrast-enhanced
                    # image and convert BGR -> RGB before handing it to the processor
                    price_roi = image_enhanced[max(0, ymin - 200):min(image_np.shape[0], ymax + 200),
                                               max(0, xmin - 200):min(image_np.shape[1], xmax + 200)]
                    price_roi = cv2.cvtColor(price_roi, cv2.COLOR_BGR2RGB)
                    ocr_inputs = processor(text="<OCR>", images=price_roi, return_tensors="pt")
                    with torch.no_grad():
                        ocr_outputs = model.generate(
                            input_ids=ocr_inputs["input_ids"],
                            pixel_values=ocr_inputs["pixel_values"],
                            max_length=1024
                        )
                    prices = processor.batch_decode(ocr_outputs, skip_special_tokens=True)[0]
                    detections.append({
                        "pattern": label,
                        "color": color_rgb,
                        "prices": prices if prices else "No price detected",
                        "x_center": (xmin + xmax) / 2
                    })
            # Sort by x position (right to left = most recent candles) and keep 8
            detections = sorted(detections, key=lambda d: d["x_center"], reverse=True)[:8]
            logger.info("Sorted detections: %d", len(detections))
            if not detections:
                logger.warning("No candlesticks detected. Ensure clear image with visible candles.")
                return {"prompt": prompt, "description": "No candlesticks detected. Try a clearer screenshot with visible candles and prices."}
            return {"prompt": prompt, "detections": detections}
        except Exception as e:
            logger.error("Failed to analyze candles: %s", str(e))
            return {"prompt": prompt, "description": "Error analyzing candles. Try a clearer screenshot with visible candles and prices."}
    else:
        return {"prompt": prompt, "description": "Unsupported prompt. Use 'Was siehst du auf dem Bild?' or 'List last 8 candles with their colors'."}
# Create the Gradio interface
iface = gr.Interface(
    fn=analyze_image,
    inputs=[
        gr.Image(type="pil", label="Upload an Image"),
        gr.Textbox(label="Prompt", placeholder="Enter your prompt, e.g., 'Was siehst du auf dem Bild?' or 'List last 8 candles with their colors'")
    ],
    outputs="json",
    title="Image Analysis with Florence-2-base",
    description="Upload an image to analyze candlesticks or get a general description."
)

iface.launch()
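
# A quick local smoke test (run before iface.launch(), or in a separate
# session; "chart.png" is an assumed sample filename, not part of the Space):
#   from PIL import Image
#   print(analyze_image(Image.open("chart.png"), "Was siehst du auf dem Bild?"))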