|
import ast
import base64
import io
import json
import re
import unicodedata

import pandas as pd
from flask import current_app
from gradio_client import Client
from PIL import Image
|
|
|
class ContentService:
    """Service for AI content generation using Hugging Face models.

    Wraps a ``gradio_client.Client`` connected to the
    ``Zelyanoth/Linkedin_poster_dev`` Space and normalises its responses
    into plain text (and optional image data).
    """

    # Fallback post text used when the Space returns no usable content.
    _PLACEHOLDER_CONTENT = "Generated content will appear here..."

    def __init__(self, hugging_key=None):
        """Create the Gradio client.

        Args:
            hugging_key: Hugging Face token. When falsy, ``HUGGING_KEY`` is
                read from the current Flask app config (needs an app context).
        """
        self.hugging_key = hugging_key or current_app.config.get('HUGGING_KEY')
        self.client = Client("Zelyanoth/Linkedin_poster_dev", hf_token=self.hugging_key)

    # ------------------------------------------------------------------
    # Text helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _replace_unencodable(text):
        """Return *text* as-is if UTF-8 can encode it, else with bad code points replaced by '?'."""
        try:
            text.encode('utf-8')
        except UnicodeEncodeError:
            # In UTF-8 only surrogate code points are unencodable; 'replace'
            # swaps each for '?', and the result always decodes cleanly.
            return text.encode('utf-8', errors='replace').decode('utf-8')
        return text

    def validate_unicode_content(self, content):
        """Validate Unicode content while preserving original formatting and spaces.

        Non-string or empty input is returned unchanged. Strings keep all
        whitespace; only code points UTF-8 cannot encode are replaced.
        """
        if not content or not isinstance(content, str):
            return content
        return self._replace_unencodable(content)

    def preserve_formatting(self, content):
        """Preserve spaces, line breaks, and paragraph formatting.

        Behaves like :meth:`validate_unicode_content`: whitespace is left
        untouched, unencodable code points are replaced, and non-string or
        empty input is returned unchanged instead of raising.
        """
        return self.validate_unicode_content(content)

    def sanitize_content_for_api(self, content):
        """Sanitize content for API calls while preserving original text, spaces, and formatting.

        For strings: replaces unencodable code points, removes NUL characters
        and normalises ``\\r\\n`` / ``\\r`` line endings to ``\\n``. Empty and
        non-string values are returned unchanged.
        """
        if not content:
            return content
        cleaned = self.validate_unicode_content(content)
        if not isinstance(cleaned, str):
            return cleaned
        cleaned = cleaned.replace('\x00', '')
        return cleaned.replace('\r\n', '\n').replace('\r', '\n')

    # ------------------------------------------------------------------
    # Base64 image helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _base64_payload(data):
        """Return the text after the first comma (a data-URI header is dropped), or *data* itself."""
        _header, separator, payload = data.partition(',')
        return payload if separator else data

    def _is_base64_image(self, data):
        """Check if the data is a base64 encoded image string.

        A ``data:image/...`` string is accepted on its prefix alone. Otherwise
        the payload (see :meth:`_base64_payload`) must be non-empty, strictly
        valid base64.

        NOTE(review): short strings drawn from the base64 alphabet (e.g.
        "abcd") also pass; this is only applied to the image slot of a result.
        """
        if not isinstance(data, str):
            return False
        if data.startswith('data:image/'):
            return True
        payload = self._base64_payload(data)
        if not payload:
            return False
        try:
            base64.b64decode(payload, validate=True)
        except ValueError:  # binascii.Error is a ValueError subclass
            return False
        return True

    def _base64_to_bytes(self, base64_string):
        """Convert a base64 encoded string (optionally a data URI) to bytes.

        Uses the same payload extraction as :meth:`_is_base64_image`, so any
        string that check accepts via its payload is decoded consistently.

        Raises:
            Exception: if the payload is missing or not valid base64.
        """
        try:
            return base64.b64decode(self._base64_payload(base64_string), validate=True)
        except (ValueError, TypeError, AttributeError) as e:
            current_app.logger.error(f"Failed to decode base64 image: {str(e)}")
            raise Exception(f"Failed to decode base64 image: {str(e)}") from e

    # ------------------------------------------------------------------
    # Prediction parsing
    # ------------------------------------------------------------------

    @staticmethod
    def _decode_prediction_string(raw):
        """Decode a string prediction as JSON, then as a Python literal, else wrap it as ``[raw]``."""
        try:
            return json.loads(raw)
        except json.JSONDecodeError:
            pass
        try:
            # literal_eval only evaluates literals (no names or calls), so it
            # is safe on remote output, unlike eval().
            return ast.literal_eval(raw)
        except (ValueError, TypeError, SyntaxError, RecursionError):
            return [raw]

    def _split_prediction(self, result):
        """Split a raw ``predict`` result into ``(content_text, image_data_or_None)``."""
        # predict() may hand back a tuple/list directly (per the gradio_client
        # docs, multi-output endpoints do); only strings need decoding.
        parsed = self._decode_prediction_string(result) if isinstance(result, str) else result

        if isinstance(parsed, (list, tuple)):
            has_content = bool(parsed) and parsed[0] is not None
            content = parsed[0] if has_content else self._PLACEHOLDER_CONTENT
            image_data = parsed[1] if len(parsed) > 1 else None
        else:
            content = self._PLACEHOLDER_CONTENT if parsed is None else parsed
            image_data = None

        # Downstream sanitising works on text; coerce numbers/dicts to str.
        if not isinstance(content, str):
            content = str(content)
        return content, image_data

    def _resolve_image(self, image_data):
        """Decode base64 image data to bytes; pass other truthy values (presumably URLs/paths) through."""
        if not image_data:
            return None
        if self._is_base64_image(image_data):
            return self._base64_to_bytes(image_data)
        return image_data

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def generate_post_content(self, user_id: str) -> tuple:
        """
        Generate post content using AI.

        Args:
            user_id (str): User ID for personalization

        Returns:
            tuple: (Generated post content, Image bytes / passthrough value or None).
            Content falls back to a placeholder when the Space returns none.

        Raises:
            Exception: "Content generation failed: ..." wrapping any error
            from the remote call, parsing, or image decoding.
        """
        try:
            result = self.client.predict(
                code=user_id,
                api_name="/poster_linkedin"
            )
            content, image_data = self._split_prediction(result)
            final_content = self.sanitize_content_for_api(content)
            image_bytes = self._resolve_image(image_data)
            return (final_content, image_bytes)
        except Exception as e:
            error_message = str(e)
            current_app.logger.exception(f"Content generation failed: {error_message}")
            raise Exception(f"Content generation failed: {error_message}") from e

    def add_rss_source(self, rss_link: str, user_id: str) -> str:
        """
        Add an RSS source for content generation.

        Args:
            rss_link (str): RSS feed URL
            user_id (str): User ID

        Returns:
            str: Result message (sanitized). A non-string reply is converted
            with ``str()``; ``None`` is returned as ``None``.

        Raises:
            Exception: "Failed to add RSS source: ..." wrapping any error.
        """
        try:
            # NOTE(review): presumably the /ajouter_rss endpoint splits on this
            # marker to recover the user id; it must stay byte-for-byte identical.
            rss_input = f"{rss_link}__thi_irrh'èçs_my_id__! {user_id}"
            sanitized_rss_input = self.sanitize_content_for_api(rss_input)

            result = self.client.predict(
                rss_link=sanitized_rss_input,
                api_name="/ajouter_rss"
            )
            if result is not None and not isinstance(result, str):
                result = str(result)
            return self.sanitize_content_for_api(result)
        except Exception as e:
            current_app.logger.exception(f"Failed to add RSS source: {str(e)}")
            raise Exception(f"Failed to add RSS source: {str(e)}") from e