Spaces:
Running
Running
| import os | |
| import io | |
| import torch | |
| from torch import nn | |
| from PIL import Image | |
| import torchvision.utils as vutils | |
| from fastapi import FastAPI, Response, HTTPException, Query | |
| from fastapi.responses import StreamingResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from huggingface_hub import hf_hub_download, login | |
| from models import Generator | |
# ASGI application object; the handlers defined below are meant to be served by it.
app = FastAPI()

# CORS configuration: every origin, HTTP method and request header is allowed.
_cors_options = {
    "allow_origins": ["*"],
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
# --- Configuration constants -------------------------------------------------

# Length of the latent noise vector passed to the generator.
Z_DIM = 100

# Device used for the model and for noise sampling.
DEVICE = torch.device("cpu")

# Hugging Face Hub repository and checkpoint file fetched by load_model().
REPO_ID = "SaniaE/GeoGen"
FILENAME = "dcgans_model_checkpoint.pt"

# Generator instance; stays None until load_model() has succeeded.
gen_model = None
# NOTE(review): no startup hook/decorator is visible for this function in this
# view of the file - confirm it is actually invoked before requests are served.
def load_model():
    """Download the generator checkpoint and publish it as ``gen_model``.

    Steps: optionally log in to the Hugging Face Hub, download
    ``FILENAME`` from ``REPO_ID``, load the checkpoint onto ``DEVICE``,
    restore ``checkpoint["gen_state_dict"]`` into a fresh ``Generator``
    and switch it to eval mode.

    The global ``gen_model`` is assigned only after every step has
    succeeded. On any failure the error is printed and ``gen_model`` is
    left untouched, so the endpoints keep answering 503 instead of
    serving images from a randomly initialised network.
    """
    global gen_model
    try:
        # The original code reads "HF_Token" while its log message mentions
        # "HF_TOKEN"; environment variable names are case-sensitive on Linux,
        # so accept both spellings.
        token = os.getenv("HF_Token") or os.getenv("HF_TOKEN") or None
        if token:
            login(token=token)
            print("Login successful.")
        else:
            print("No HF_TOKEN found - attempting public download.")

        model_path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME, token=token)
        print(f"File downloaded to: {model_path}")

        # NOTE(security): torch.load unpickles the file. On PyTorch versions
        # where weights_only defaults to False this can execute arbitrary code
        # embedded in the checkpoint. The repository is assumed to be trusted;
        # consider passing weights_only=True if the checkpoint allows it.
        checkpoint = torch.load(model_path, map_location=DEVICE)

        # Build into a local first so a failure below cannot leave a
        # half-initialised model visible to the request handlers.
        model = Generator(z_dim=Z_DIM).to(DEVICE)
        missing, unexpected = model.load_state_dict(
            checkpoint["gen_state_dict"], strict=False
        )
        # strict=False tolerates key mismatches, so report them explicitly.
        print("Unexpected keys: ", unexpected)
        print("Missing keys: ", missing)
        model.eval()
    except Exception as e:
        # Deliberate best-effort: the service still starts without a model.
        print(f"Error loading model: {e}")
        return

    gen_model = model
    print("SUCCESS: Petrol Pump GAN is live!")
def postprocess_image(tensor):
    """Turn a generator output tensor into a ``PIL.Image``.

    Values are mapped from [-1, 1] to [0, 1] (undoing the tanh-style
    normalisation), clamped, tiled into one image with ``make_grid``
    and quantised to 8-bit.
    """
    # Unnormalize: [-1, 1] -> [0, 1], then guard against overshoot.
    unit_range = ((tensor + 1) / 2).clamp(0, 1)

    tiled = vutils.make_grid(unit_range, padding=0, normalize=False)

    # Scale to 0..255; the +0.5 makes the truncating uint8 cast round to nearest.
    scaled = tiled.mul(255).add_(0.5).clamp_(0, 255)

    # Move the first axis last (presumably CHW -> HWC) for PIL.
    pixels = scaled.permute(1, 2, 0).to("cpu", torch.uint8).numpy()
    return Image.fromarray(pixels)
def get_image_stream(tensor):
    """Helper to convert tensor to a streaming-ready PNG.

    The tensor-to-image conversion is delegated to ``postprocess_image``
    so the un-normalise / grid / quantise pipeline lives in one place
    instead of being duplicated here.

    Args:
        tensor: Generator output; forwarded unchanged to
            ``postprocess_image``.

    Returns:
        An ``io.BytesIO`` holding the PNG-encoded image, rewound to
        position 0 so it can be read from the start.
    """
    png_buffer = io.BytesIO()
    postprocess_image(tensor).save(png_buffer, format="PNG")
    # Rewind: save() leaves the cursor at the end of the written data.
    png_buffer.seek(0)
    return png_buffer
# NOTE(review): no route decorator is visible for this handler in this view of
# the file - confirm it is registered on `app`.
def read_root():
    """Return a small status payload naming the model repository in use."""
    return dict(status="online", model=REPO_ID)
# NOTE(review): no route decorator is visible for this handler in this view of
# the file - confirm it is registered on `app`.
def generate_random(seed: int = Query(None)):
    """Endpoint 1: Fixed context generation for a session.

    Samples one latent vector of length ``Z_DIM`` and returns the
    generated image as a PNG stream. With a ``seed`` the same image is
    produced on every call; without one a fresh random seed is drawn.

    Args:
        seed: Optional RNG seed. Must lie in the range accepted by
            ``torch.manual_seed``.

    Raises:
        HTTPException: 503 if the model has not been loaded, 422 if
            ``seed`` is outside the range ``torch.manual_seed`` accepts.
    """
    if gen_model is None:
        raise HTTPException(status_code=503, detail="Model is not loaded.")

    # Use the provided session seed or fall back to a fresh random one.
    active_seed = seed if seed is not None else torch.seed()

    # torch.manual_seed raises RuntimeError outside this range; reject such
    # input up front so the client gets a 422 rather than an unhandled 500.
    if not -0x8000_0000_0000_0000 <= active_seed <= 0xFFFF_FFFF_FFFF_FFFF:
        raise HTTPException(status_code=422, detail="seed is out of range.")

    torch.manual_seed(active_seed)
    with torch.inference_mode():
        noise = torch.randn(1, Z_DIM, device=DEVICE)
        fake_img = gen_model(noise)
    return StreamingResponse(get_image_stream(fake_img), media_type="image/png")
# NOTE(review): no route decorator is visible for this handler in this view of
# the file - confirm it is registered on `app`.
def explore_latent(
    seed: int,
    x_shift: float = Query(0.0, ge=-5.0, le=5.0),
    y_shift: float = Query(0.0, ge=-5.0, le=5.0),
):
    """Endpoint 2: Controlled generation for 'Tuning'.

    Re-creates the latent vector for ``seed``, then perturbs it:
    components 0-9 are offset by ``x_shift``, components 10-19 by
    ``y_shift``, and a seeded random direction scaled by
    ``0.3 * (|x_shift| + |y_shift|)`` is added. With both shifts at 0
    the latent vector is exactly the one drawn for ``seed``.

    Args:
        seed: RNG seed; must lie in the range accepted by
            ``torch.manual_seed``.
        x_shift: Offset for the first ten latent components, in [-5, 5].
        y_shift: Offset for the next ten latent components, in [-5, 5].

    Raises:
        HTTPException: 503 if the model has not been loaded, 422 if
            ``seed`` is out of range, 500 (with the error text as
            detail) if generation or encoding fails.
    """
    if gen_model is None:
        raise HTTPException(status_code=503, detail="Model is not loaded.")

    # torch.manual_seed raises RuntimeError outside this range; treat that as
    # a client error instead of letting it surface as a 500.
    if not -0x8000_0000_0000_0000 <= seed <= 0xFFFF_FFFF_FFFF_FFFF:
        raise HTTPException(status_code=422, detail="seed is out of range.")

    try:
        with torch.inference_mode():
            # torch.manual_seed seeds the generators on all devices, so no
            # separate CUDA seeding call is needed.
            torch.manual_seed(seed)
            noise = torch.randn(1, Z_DIM, device=DEVICE)

            # Structured control
            noise[:, :10] += x_shift
            noise[:, 10:20] += y_shift

            # Random direction, drawn from the same seeded stream so the
            # result stays reproducible for a given (seed, shifts) triple.
            direction = torch.randn_like(noise)
            noise = noise + 0.3 * direction * (abs(x_shift) + abs(y_shift))
            print("NOISE:", noise[0, :5])

            fake_img = gen_model(noise)
        return StreamingResponse(get_image_stream(fake_img), media_type="image/png")
    except Exception as e:
        # Keep the original 500 + detail contract, but chain the cause so the
        # server-side traceback shows where the failure originated.
        raise HTTPException(status_code=500, detail=str(e)) from e