# ============================================================================
# CONTENTFORGE AI - FASTAPI BACKEND
# REST API for multi-modal AI platform
# ============================================================================
from fastapi import FastAPI, HTTPException, Header, File, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import Optional
import torch
import os
from huggingface_hub import login
import base64
from io import BytesIO
import numpy as np
import wave

# ============================================================================
# AUTHENTICATION
# ============================================================================
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN:
    print("🔑 Authenticating with HuggingFace...")
    login(token=HF_TOKEN)
    print("✅ Authenticated!\n")
from transformers import (
    T5Tokenizer, T5ForConditionalGeneration,
    Qwen2VLForConditionalGeneration, Qwen2VLProcessor,
    AutoProcessor, MusicgenForConditionalGeneration
)
from peft import PeftModel
from qwen_vl_utils import process_vision_info
from diffusers import StableDiffusionPipeline
from PIL import Image

# ============================================================================
# FASTAPI APP SETUP
# ============================================================================
app = FastAPI(
    title="ContentForge AI API",
    description="Multi-modal AI API for education and social media content generation",
    version="1.0.0"
)

# CORS - Allow requests from your frontend
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # In production: ["https://yourwebsite.vercel.app"]
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Simple API key authentication (improve this for production!)
API_KEYS = {
    "demo_key_123": "Demo User",
    "sk_test_456": "Test User",
}

def verify_api_key(x_api_key: str = Header(None)):
    """Verify the API key sent in the x-api-key header."""
    if x_api_key not in API_KEYS:
        raise HTTPException(status_code=401, detail="Invalid API Key")
    return API_KEYS[x_api_key]
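
# A minimal usage sketch (hypothetical client call; the key "demo_key_123"
# comes from API_KEYS above, and the host/port assume the uvicorn settings
# at the bottom of this file):
#
#   curl -H "x-api-key: demo_key_123" http://localhost:7860/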
# ============================================================================
# LOAD MODELS
# ============================================================================
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"🖥️ Using device: {device}")
print("📦 Loading models...\n")

# 1. T5 summarizer
print("📝 Loading T5...")
t5_tokenizer = T5Tokenizer.from_pretrained("Bashaarat1/t5-small-arxiv-summarizer")
t5_model = T5ForConditionalGeneration.from_pretrained(
    "Bashaarat1/t5-small-arxiv-summarizer"
).to(device)
t5_model.eval()
print("✅ T5 loaded!")

# 2. Qwen VLM (base model plus fine-tuned PEFT adapter)
print("🤖 Loading Qwen...")
qwen_base = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-2B-Instruct",
    device_map="auto",
    torch_dtype=torch.bfloat16
)
qwen_model = PeftModel.from_pretrained(
    qwen_base,
    "Bashaarat1/qwen-finetuned-scienceqa"
)
qwen_processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
qwen_model.eval()
print("✅ Qwen loaded!")

# 3. MusicGen
print("🎵 Loading MusicGen...")
music_processor = AutoProcessor.from_pretrained("Bashaarat1/fine-tuned-musicgen-small")
music_model = MusicgenForConditionalGeneration.from_pretrained(
    "Bashaarat1/fine-tuned-musicgen-small"
).to(device)
music_model.eval()
print("✅ MusicGen loaded!")

# 4. Stable Diffusion
print("🎨 Loading Stable Diffusion...")
sd_pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    safety_checker=None
).to(device)
print("✅ Stable Diffusion loaded!")

print("\n🎉 All models loaded! API ready.\n")
# ============================================================================
# REQUEST/RESPONSE MODELS
# ============================================================================
class SummarizeRequest(BaseModel):
    text: str
    max_length: int = 128

class SummarizeResponse(BaseModel):
    summary: str
    original_words: int
    summary_words: int

class QARequest(BaseModel):
    question: str
    image_base64: Optional[str] = None

class QAResponse(BaseModel):
    answer: str

class ImageRequest(BaseModel):
    prompt: str
    negative_prompt: str = ""
    num_steps: int = 25

class ImageResponse(BaseModel):
    image_base64: str

class MusicRequest(BaseModel):
    prompt: str
    duration: int = 10

class MusicResponse(BaseModel):
    audio_base64: str
    sampling_rate: int
    format: str
# ============================================================================
# HELPER FUNCTIONS
# ============================================================================
def numpy_to_wav(audio_data: np.ndarray, sampling_rate: int) -> bytes:
    """Convert a float numpy array to WAV-format bytes."""
    # Clip audio to the [-1, 1] range
    audio_data = np.clip(audio_data, -1, 1)
    # Convert to 16-bit PCM
    audio_int16 = (audio_data * 32767).astype(np.int16)
    # Create WAV file in memory
    wav_io = BytesIO()
    with wave.open(wav_io, 'wb') as wav_file:
        wav_file.setnchannels(1)  # Mono
        wav_file.setsampwidth(2)  # 16-bit
        wav_file.setframerate(sampling_rate)
        wav_file.writeframes(audio_int16.tobytes())
    return wav_io.getvalue()
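
# Illustrative example (assumed values, not part of the API): one second of a
# 440 Hz sine tone at 32 kHz becomes playable WAV bytes.
#
#   t = np.linspace(0, 1, 32000, endpoint=False)
#   wav_bytes = numpy_to_wav(0.5 * np.sin(2 * np.pi * 440 * t), 32000)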
# ============================================================================
# API ENDPOINTS
# ============================================================================
@app.get("/")
def root():
    """API health check"""
    return {
        "status": "online",
        "message": "ContentForge AI API",
        "version": "1.0.0",
        "endpoints": [
            "/summarize",
            "/qa",
            "/generate-image",
            "/generate-music"
        ]
    }
@app.post("/summarize", response_model=SummarizeResponse)
def summarize(
    request: SummarizeRequest,
    user: str = Header(None, alias="x-api-key")
):
    """Summarize text using fine-tuned T5"""
    verify_api_key(user)
    if not request.text.strip():
        raise HTTPException(status_code=400, detail="Text cannot be empty")
    try:
        inputs = t5_tokenizer(
            f"summarize: {request.text}",
            return_tensors="pt",
            max_length=512,
            truncation=True
        ).to(device)
        with torch.no_grad():
            outputs = t5_model.generate(
                **inputs,
                max_length=request.max_length,
                min_length=30,
                num_beams=4,
                early_stopping=True
            )
        summary = t5_tokenizer.decode(outputs[0], skip_special_tokens=True)
        return SummarizeResponse(
            summary=summary,
            original_words=len(request.text.split()),
            summary_words=len(summary.split())
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
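
# Example request (hypothetical text; assumes the server is running on the
# host/port configured at the bottom of this file):
#
#   curl -X POST http://localhost:7860/summarize \
#     -H "x-api-key: demo_key_123" -H "Content-Type: application/json" \
#     -d '{"text": "Large language models are ...", "max_length": 128}'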
@app.post("/qa", response_model=QAResponse)
def question_answer(
    request: QARequest,
    user: str = Header(None, alias="x-api-key")
):
    """Answer questions with optional image using Qwen VLM"""
    verify_api_key(user)
    if not request.question.strip():
        raise HTTPException(status_code=400, detail="Question cannot be empty")
    try:
        image = None
        if request.image_base64:
            # Decode base64 image
            image_data = base64.b64decode(request.image_base64)
            image = Image.open(BytesIO(image_data)).convert('RGB')
        # Prepare chat messages (with or without an image)
        if image is not None:
            messages = [{
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": request.question}
                ]
            }]
        else:
            messages = [{
                "role": "user",
                "content": [{"type": "text", "text": request.question}]
            }]
        text_prompt = qwen_processor.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        if image is not None:
            img_inputs, _ = process_vision_info(messages)
            inputs = qwen_processor(
                text=[text_prompt],
                images=img_inputs,
                return_tensors="pt"
            ).to(device)
        else:
            inputs = qwen_processor(
                text=[text_prompt],
                return_tensors="pt"
            ).to(device)
        with torch.no_grad():
            outputs = qwen_model.generate(**inputs, max_new_tokens=200)
        # Decode only the newly generated tokens (skip the prompt)
        answer = qwen_processor.batch_decode(
            outputs[:, inputs.input_ids.size(1):],
            skip_special_tokens=True
        )[0].strip()
        return QAResponse(answer=answer)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
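
# Client-side sketch for attaching an image (hypothetical file name; assumes
# the `requests` library, though any HTTP client works the same way):
#
#   import base64, requests
#   with open("diagram.png", "rb") as f:
#       img_b64 = base64.b64encode(f.read()).decode()
#   r = requests.post(
#       "http://localhost:7860/qa",
#       headers={"x-api-key": "demo_key_123"},
#       json={"question": "What process does this diagram show?",
#             "image_base64": img_b64},
#   )
#   print(r.json()["answer"])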
@app.post("/generate-image", response_model=ImageResponse)
def generate_image(
    request: ImageRequest,
    user: str = Header(None, alias="x-api-key")
):
    """Generate image using Stable Diffusion"""
    verify_api_key(user)
    if not request.prompt.strip():
        raise HTTPException(status_code=400, detail="Prompt cannot be empty")
    try:
        with torch.no_grad():
            image = sd_pipe(
                request.prompt,
                negative_prompt=request.negative_prompt,
                num_inference_steps=request.num_steps,
                guidance_scale=7.5
            ).images[0]
        # Convert image to base64-encoded PNG
        buffered = BytesIO()
        image.save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue()).decode()
        return ImageResponse(image_base64=img_str)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
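
# Client-side sketch to recover the PNG (assumes `r` is a successful response
# obtained from an HTTP client such as `requests`):
#
#   with open("output.png", "wb") as f:
#       f.write(base64.b64decode(r.json()["image_base64"]))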
@app.post("/generate-music", response_model=MusicResponse)
def generate_music(
    request: MusicRequest,
    user: str = Header(None, alias="x-api-key")
):
    """Generate music using MusicGen"""
    verify_api_key(user)
    if not request.prompt.strip():
        raise HTTPException(status_code=400, detail="Prompt cannot be empty")
    try:
        inputs = music_processor(
            text=[request.prompt],
            padding=True,
            return_tensors="pt"
        ).to(device)
        # MusicGen produces roughly 50 audio tokens per second of audio
        max_tokens = int(request.duration * 50)
        with torch.no_grad():
            audio_values = music_model.generate(
                **inputs,
                max_new_tokens=max_tokens,
                do_sample=True
            )
        sampling_rate = music_model.config.audio_encoder.sampling_rate
        audio_data = audio_values[0, 0].cpu().numpy()
        # Convert to WAV format
        wav_bytes = numpy_to_wav(audio_data, sampling_rate)
        # Encode to base64
        audio_str = base64.b64encode(wav_bytes).decode()
        return MusicResponse(
            audio_base64=audio_str,
            sampling_rate=sampling_rate,
            format="wav"
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
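
# Client-side sketch to save the returned audio (assumes `r` is a successful
# response obtained from an HTTP client such as `requests`):
#
#   with open("output.wav", "wb") as f:
#       f.write(base64.b64decode(r.json()["audio_base64"]))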
# ============================================================================
# RUN SERVER
# ============================================================================
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
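
# Alternatively, start the server with the uvicorn CLI (assumes this file is
# saved as app.py): uvicorn app:app --host 0.0.0.0 --port 7860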