# ============================================================================
# CONTENTFORGE AI - FASTAPI BACKEND
# REST API for multi-modal AI platform
# ============================================================================
from fastapi import FastAPI, HTTPException, Header
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import Optional
import torch
import os
from huggingface_hub import login
import base64
from io import BytesIO
import numpy as np
import wave
# ============================================================================
# AUTHENTICATION
# ============================================================================
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN:
    print("🔐 Authenticating with HuggingFace...")
    login(token=HF_TOKEN)
    print("✅ Authenticated!\n")
from transformers import (
    T5Tokenizer, T5ForConditionalGeneration,
    Qwen2VLForConditionalGeneration, Qwen2VLProcessor,
    AutoProcessor, MusicgenForConditionalGeneration
)
from peft import PeftModel
from qwen_vl_utils import process_vision_info
from diffusers import StableDiffusionPipeline
from PIL import Image
# ============================================================================
# FASTAPI APP SETUP
# ============================================================================
app = FastAPI(
    title="ContentForge AI API",
    description="Multi-modal AI API for education and social media content generation",
    version="1.0.0"
)
# CORS - Allow requests from your frontend
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # In production: ["https://yourwebsite.vercel.app"]
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Simple API key authentication (improve this for production!)
API_KEYS = {
    "demo_key_123": "Demo User",
    "sk_test_456": "Test User",
}
def verify_api_key(x_api_key: str = Header(None)):
    """Verify API key from header"""
    if x_api_key not in API_KEYS:
        raise HTTPException(status_code=401, detail="Invalid API Key")
    return API_KEYS[x_api_key]
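# Every endpoint below expects the key in an "x-api-key" header. A minimal
# example call (a sketch, assuming the server runs locally on port 7860):
#   curl -X POST http://localhost:7860/summarize \
#        -H "x-api-key: demo_key_123" \
#        -H "Content-Type: application/json" \
#        -d '{"text": "Your article text here..."}'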
# ============================================================================
# LOAD MODELS
# ============================================================================
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"πŸ–₯️ Using device: {device}")
print("πŸ“¦ Loading models...\n")
# 1. T5 Model
print("πŸ“ Loading T5...")
t5_tokenizer = T5Tokenizer.from_pretrained("Bashaarat1/t5-small-arxiv-summarizer")
t5_model = T5ForConditionalGeneration.from_pretrained(
"Bashaarat1/t5-small-arxiv-summarizer"
).to(device)
t5_model.eval()
print("βœ… T5 loaded!")
# 2. Qwen VLM
print("πŸ€– Loading Qwen...")
qwen_base = Qwen2VLForConditionalGeneration.from_pretrained(
"Qwen/Qwen2-VL-2B-Instruct",
device_map="auto",
torch_dtype=torch.bfloat16
)
qwen_model = PeftModel.from_pretrained(
qwen_base,
"Bashaarat1/qwen-finetuned-scienceqa"
)
qwen_processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
qwen_model.eval()
print("βœ… Qwen loaded!")
# 3. MusicGen
print("🎡 Loading MusicGen...")
music_processor = AutoProcessor.from_pretrained("Bashaarat1/fine-tuned-musicgen-small")
music_model = MusicgenForConditionalGeneration.from_pretrained(
"Bashaarat1/fine-tuned-musicgen-small"
).to(device)
music_model.eval()
print("βœ… MusicGen loaded!")
# 4. Stable Diffusion
print("🎨 Loading Stable Diffusion...")
sd_pipe = StableDiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5",
torch_dtype=torch.float16 if device == "cuda" else torch.float32,
safety_checker=None
).to(device)
print("βœ… Stable Diffusion loaded!")
print("\nπŸŽ‰ All models loaded! API ready.\n")
# ============================================================================
# REQUEST/RESPONSE MODELS
# ============================================================================
class SummarizeRequest(BaseModel):
    text: str
    max_length: int = 128
class SummarizeResponse(BaseModel):
    summary: str
    original_words: int
    summary_words: int
class QARequest(BaseModel):
    question: str
    image_base64: Optional[str] = None
class QAResponse(BaseModel):
    answer: str
class ImageRequest(BaseModel):
    prompt: str
    negative_prompt: str = ""
    num_steps: int = 25
class ImageResponse(BaseModel):
    image_base64: str
class MusicRequest(BaseModel):
    prompt: str
    duration: int = 10
class MusicResponse(BaseModel):
    audio_base64: str
    sampling_rate: int
    format: str
# ============================================================================
# HELPER FUNCTIONS
# ============================================================================
def numpy_to_wav(audio_data: np.ndarray, sampling_rate: int) -> bytes:
    """Convert a numpy audio array to WAV-format bytes"""
    # Clamp audio to the [-1, 1] range
    audio_data = np.clip(audio_data, -1, 1)
    # Convert to 16-bit PCM
    audio_int16 = (audio_data * 32767).astype(np.int16)
    # Create WAV file in memory
    wav_io = BytesIO()
    with wave.open(wav_io, 'wb') as wav_file:
        wav_file.setnchannels(1)  # Mono
        wav_file.setsampwidth(2)  # 16-bit
        wav_file.setframerate(sampling_rate)
        wav_file.writeframes(audio_int16.tobytes())
    return wav_io.getvalue()
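# Quick sanity check (a sketch, not part of the API): one second of a 440 Hz
# sine wave at 32 kHz should round-trip through the helper cleanly.
#   tone = np.sin(2 * np.pi * 440 * np.arange(32000) / 32000)
#   wav_bytes = numpy_to_wav(tone, 32000)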
# ============================================================================
# API ENDPOINTS
# ============================================================================
@app.get("/")
def root():
"""API health check"""
return {
"status": "online",
"message": "ContentForge AI API",
"version": "1.0.0",
"endpoints": [
"/summarize",
"/qa",
"/generate-image",
"/generate-music"
]
}
@app.post("/summarize", response_model=SummarizeResponse)
def summarize(
request: SummarizeRequest,
user: str = Header(None, alias="x-api-key")
):
"""Summarize text using fine-tuned T5"""
verify_api_key(user)
if not request.text.strip():
raise HTTPException(status_code=400, detail="Text cannot be empty")
try:
inputs = t5_tokenizer(
f"summarize: {request.text}",
return_tensors="pt",
max_length=512,
truncation=True
).to(device)
with torch.no_grad():
outputs = t5_model.generate(
**inputs,
max_length=request.max_length,
min_length=30,
num_beams=4,
early_stopping=True
)
summary = t5_tokenizer.decode(outputs[0], skip_special_tokens=True)
return SummarizeResponse(
summary=summary,
original_words=len(request.text.split()),
summary_words=len(summary.split())
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post("/qa", response_model=QAResponse)
def question_answer(
request: QARequest,
user: str = Header(None, alias="x-api-key")
):
"""Answer questions with optional image using Qwen VLM"""
verify_api_key(user)
if not request.question.strip():
raise HTTPException(status_code=400, detail="Question cannot be empty")
try:
image = None
if request.image_base64:
# Decode base64 image
image_data = base64.b64decode(request.image_base64)
image = Image.open(BytesIO(image_data)).convert('RGB')
# Prepare messages
if image is not None:
messages = [{
"role": "user",
"content": [
{"type": "image", "image": image},
{"type": "text", "text": request.question}
]
}]
else:
messages = [{
"role": "user",
"content": [{"type": "text", "text": request.question}]
}]
text_prompt = qwen_processor.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True
)
if image is not None:
img_inputs, _ = process_vision_info(messages)
inputs = qwen_processor(
text=[text_prompt],
images=img_inputs,
return_tensors="pt"
).to(device)
else:
inputs = qwen_processor(
text=[text_prompt],
return_tensors="pt"
).to(device)
with torch.no_grad():
outputs = qwen_model.generate(**inputs, max_new_tokens=200)
answer = qwen_processor.batch_decode(
outputs[:, inputs.input_ids.size(1):],
skip_special_tokens=True
)[0].strip()
return QAResponse(answer=answer)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
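# Hypothetical client call with an image (file name and URL are illustrative):
#   import base64, requests
#   b64 = base64.b64encode(open("diagram.png", "rb").read()).decode()
#   resp = requests.post("http://localhost:7860/qa",
#                        headers={"x-api-key": "demo_key_123"},
#                        json={"question": "What does this diagram show?",
#                              "image_base64": b64})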
@app.post("/generate-image", response_model=ImageResponse)
def generate_image(
request: ImageRequest,
user: str = Header(None, alias="x-api-key")
):
"""Generate image using Stable Diffusion"""
verify_api_key(user)
if not request.prompt.strip():
raise HTTPException(status_code=400, detail="Prompt cannot be empty")
try:
with torch.no_grad():
image = sd_pipe(
request.prompt,
negative_prompt=request.negative_prompt,
num_inference_steps=request.num_steps,
guidance_scale=7.5
).images[0]
# Convert image to base64
buffered = BytesIO()
image.save(buffered, format="PNG")
img_str = base64.b64encode(buffered.getvalue()).decode()
return ImageResponse(image_base64=img_str)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
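# Sketch of decoding the PNG on the client side (assumes Pillow is installed there too):
#   from PIL import Image
#   from io import BytesIO
#   import base64
#   img = Image.open(BytesIO(base64.b64decode(resp.json()["image_base64"])))
#   img.save("generated.png")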
@app.post("/generate-music", response_model=MusicResponse)
def generate_music(
request: MusicRequest,
user: str = Header(None, alias="x-api-key")
):
"""Generate music using MusicGen"""
verify_api_key(user)
if not request.prompt.strip():
raise HTTPException(status_code=400, detail="Prompt cannot be empty")
try:
inputs = music_processor(
text=[request.prompt],
padding=True,
return_tensors="pt"
).to(device)
max_tokens = int(request.duration * 50)
with torch.no_grad():
audio_values = music_model.generate(
**inputs,
max_new_tokens=max_tokens,
do_sample=True
)
sampling_rate = music_model.config.audio_encoder.sampling_rate
audio_data = audio_values[0, 0].cpu().numpy()
# Convert to WAV format
wav_bytes = numpy_to_wav(audio_data, sampling_rate)
# Encode to base64
audio_str = base64.b64encode(wav_bytes).decode()
return MusicResponse(
audio_base64=audio_str,
sampling_rate=sampling_rate,
format="wav"
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
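# Hypothetical client-side save of the generated clip:
#   import base64
#   audio = base64.b64decode(resp.json()["audio_base64"])
#   with open("clip.wav", "wb") as f:
#       f.write(audio)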
# ============================================================================
# RUN SERVER
# ============================================================================
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
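# Once running, FastAPI's auto-generated interactive docs are available at
# http://localhost:7860/docs for trying each endpoint in the browser.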