Spaces:
Runtime error
Runtime error
| from flask import Flask, request, send_file | |
| from flask_cors import CORS | |
| import torch | |
| import shap_e | |
| from shap_e.diffusion.sample import sample_latents | |
| from shap_e.diffusion.gaussian_diffusion import diffusion_from_config | |
| from shap_e.models.download import load_model, load_config | |
| from shap_e.util.notebooks import create_pan_cameras, decode_latent_images, decode_latent_mesh | |
| import os | |
| # To remove on deploy | |
| #os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:768' | |
# Flask application with cross-origin access enabled for the generation endpoint.
app = Flask(__name__)
app.config['CORS_HEADERS'] = 'Content-Type'
# The keyword must be spelled `resources`; the previous spelling `resorces`
# meant the per-route origin rule below was never passed under that name.
cors = CORS(app, resources={r'/generate_3d': {"origins": '*'}})

# Use the GPU when torch reports one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device loaded: ', device)

print('Loading models...')
# Absolute path of the `shape_e_model_cache` directory that sits next to this file.
path_file = os.path.abspath(os.path.join(os.path.dirname(__file__), 'shape_e_model_cache'))
trasmitter_path = os.path.join(path_file, 'transmitter')
# NOTE(review): `load_model` is called with a filesystem path here but with a
# plain model name ('text300M') below -- confirm the installed shap_e version
# accepts a path, or whether this should be a model name plus a cache directory.
xm = load_model(trasmitter_path, device=device)
model = load_model('text300M', device=device)
diffusion = diffusion_from_config(load_config('diffusion'))
# Register the view: without a route the function is never reachable. OPTIONS is
# listed explicitly because the handler answers pre-flight requests itself.
@app.route('/generate_3d', methods=['POST', 'OPTIONS'])
def generate_3d():
    """Generate a 3D mesh from a text prompt and return it as an OBJ download.

    Expects a JSON body of the form ``{"prompt": "<text>"}``. OPTIONS requests
    are answered with an empty 200 response.

    Returns:
        The OBJ data as an attachment named ``tmp_mesh.obj``, or a JSON error
        with status 400 when the body has no usable ``prompt``.
    """
    import io  # local import: only needed for the in-memory OBJ buffer

    if request.method == 'OPTIONS':
        return '', 200

    # Validate the input up front so a malformed request yields a 400 instead
    # of an unhandled KeyError / TypeError (HTTP 500).
    payload = request.get_json(silent=True)
    prompt = payload.get('prompt') if isinstance(payload, dict) else None
    if not isinstance(prompt, str) or not prompt.strip():
        return {'error': "JSON body must contain a non-empty string 'prompt'"}, 400

    print('Generating 3D...')
    batch_size = 1
    guidance_scale = 15.0

    latents = sample_latents(
        batch_size=batch_size,
        model=model,
        diffusion=diffusion,
        guidance_scale=guidance_scale,
        model_kwargs=dict(texts=[prompt] * batch_size),
        progress=True,
        clip_denoised=True,
        use_fp16=True,
        use_karras=True,
        karras_steps=64,
        sigma_min=1E-3,
        sigma_max=160,
        s_churn=0,
    )

    # Only the mesh is returned to the client, so the preview-image rendering
    # whose result was previously discarded is no longer performed.
    mesh = decode_latent_mesh(xm, latents[0]).tri_mesh()

    # Serialise into memory rather than a fixed 'tmp_mesh.obj' on disk:
    # concurrent requests no longer overwrite each other's output, and the
    # response does not depend on the process working directory.
    obj_text = io.StringIO()
    mesh.write_obj(obj_text)
    print('3D asset generated')

    obj_bytes = io.BytesIO(obj_text.getvalue().encode('utf-8'))
    return send_file(obj_bytes, as_attachment=True, download_name='tmp_mesh.obj')
if __name__ == '__main__':
    # Script entry point: start Flask's built-in server on port 5000.
    # NOTE(review): debug mode is enabled here -- confirm it is switched off
    # before this service is deployed.
    server_options = {'port': 5000, 'debug': True}
    app.run(**server_options)