import torch
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, TextIteratorStreamer
from threading import Thread
import time
import logging
import os

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Initialize FastAPI app
app = FastAPI(
    title="Qwen 2.5 0.5B API",
    description="API for generating text with Qwen 2.5 0.5B Instruct model",
)

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Request model
class GenerationRequest(BaseModel):
    prompt: str
    # Note: Qwen2.5-0.5B-Instruct has a 32K-token context window, so very large
    # values are effectively capped by the model's limits.
    max_new_tokens: int = 128000
    temperature: float = 0.7
    top_p: float = 0.9
    do_sample: bool = True
    stream: bool = False
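# Example request body for the /generate endpoint defined further down (a sketch;
# field names mirror GenerationRequest, and the /generate path assumes the route
# registered below):
#
#   {
#     "prompt": "Write a haiku about autumn.",
#     "max_new_tokens": 256,
#     "temperature": 0.7,
#     "top_p": 0.9,
#     "do_sample": true,
#     "stream": false
#   }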
# Global variables for model and tokenizer
model = None
tokenizer = None
device = None
@app.on_event("startup")
async def startup_event():
    global model, tokenizer, device
    try:
        # Model name
        model_name = "Qwen/Qwen2.5-0.5B-Instruct"
        logger.info(f"Loading {model_name}...")

        # Force CPU usage
        device = torch.device("cpu")

        # Load tokenizer with trust_remote_code
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        logger.info("Tokenizer loaded successfully")

        # Load the model config to modify sliding window settings
        config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)

        # Disable sliding window attention in the config
        if hasattr(config, "sliding_window"):
            logger.info("Disabling sliding window attention to resolve warning...")
            config.sliding_window = None

        # Load model with modified config
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            config=config,
            device_map="cpu",  # Ensure CPU usage
            torch_dtype=torch.float32,  # Use float32 for CPU compatibility
            trust_remote_code=True,
        )
        logger.info(f"Model loaded successfully on {device}")
    except Exception as e:
        logger.error(f"Error loading model: {str(e)}")
        # Print stack trace for debugging
        import traceback
        logger.error(traceback.format_exc())
        raise e
@app.get("/")
async def root():
    return {"message": "Qwen 2.5 0.5B Instruct API is running."}
@app.post("/generate")
async def generate(request: GenerationRequest):
    global model, tokenizer, device

    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model or tokenizer not loaded yet")

    try:
        # If streaming is requested, use streaming response
        if request.stream:
            return StreamingResponse(
                stream_generate(request),
                media_type="text/event-stream",
            )

        # Regular non-streaming generation
        start_time = time.time()

        # Format the input for the Qwen2.5-Instruct chat format
        prompt = f"<|im_start|>user\n{request.prompt}<|im_end|>\n<|im_start|>assistant\n"

        # Tokenize input
        inputs = tokenizer(prompt, return_tensors="pt").to(device)

        # Generate response
        with torch.no_grad():
            output = model.generate(
                **inputs,
                max_new_tokens=request.max_new_tokens,
                temperature=request.temperature,
                top_p=request.top_p,
                do_sample=request.do_sample,
                pad_token_id=tokenizer.eos_token_id,
            )

        # Decode the response
        response = tokenizer.decode(output[0], skip_special_tokens=False)

        # Extract just the assistant's response
        assistant_response = response.split("<|im_start|>assistant\n")[-1].split("<|im_end|>")[0]

        generation_time = time.time() - start_time

        return {
            "generated_text": assistant_response,
            "generation_time_seconds": generation_time,
        }
    except Exception as e:
        logger.error(f"Error generating response: {str(e)}")
        # Print stack trace for debugging
        import traceback
        logger.error(traceback.format_exc())
        raise HTTPException(status_code=500, detail=f"Error generating response: {str(e)}")
async def stream_generate(request: GenerationRequest):
    """Generate text in a streaming fashion using the TextIteratorStreamer."""
    try:
        start_time = time.time()

        # Format the input for the Qwen2.5-Instruct chat format
        prompt = f"<|im_start|>user\n{request.prompt}<|im_end|>\n<|im_start|>assistant\n"

        # Tokenize input
        inputs = tokenizer(prompt, return_tensors="pt").to(device)

        # Create a streamer
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=False)

        # Set up generation parameters
        generation_kwargs = {
            "input_ids": inputs.input_ids,
            "attention_mask": inputs.attention_mask,
            "max_new_tokens": request.max_new_tokens,
            "temperature": request.temperature,
            "top_p": request.top_p,
            "do_sample": request.do_sample,
            "pad_token_id": tokenizer.eos_token_id,
            "streamer": streamer,
        }

        # Run generation in a background thread so we can consume the streamer here
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()

        # Track the complete response so we can detect the end-of-turn tag
        complete_response = ""

        # Stream tokens as they are generated
        for new_text in streamer:
            complete_response += new_text
            # Stop at the end-of-turn tag, emitting only the text before it
            if "<|im_end|>" in complete_response:
                remaining = new_text.split("<|im_end|>")[0]
                if remaining:
                    yield f"data: {remaining}\n\n"
                break
            # Send the new text chunk
            yield f"data: {new_text}\n\n"

        # Signal completion
        yield "data: [DONE]\n\n"
        logger.info(f"Streaming generation completed in {time.time() - start_time:.2f} seconds")
    except Exception as e:
        logger.error(f"Error in streaming generation: {str(e)}")
        yield f"data: [ERROR] {str(e)}\n\n"
@app.middleware("http")
async def add_process_time_header(request: Request, call_next):
    start_time = time.time()
    response = await call_next(request)
    process_time = time.time() - start_time
    response.headers["X-Process-Time"] = str(process_time)
    return response
if __name__ == "__main__":
    import uvicorn
    uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=False)
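# Example client (a sketch, assuming the server is reachable at http://localhost:7860
# and the /generate route is registered as above; `requests` is an extra dependency
# not imported by this app):
#
#   import requests
#
#   # Non-streaming call: the JSON response contains "generated_text"
#   resp = requests.post(
#       "http://localhost:7860/generate",
#       json={"prompt": "Tell me a joke.", "max_new_tokens": 128, "stream": False},
#   )
#   print(resp.json()["generated_text"])
#
#   # Streaming call: the body is a text/event-stream of "data: ..." lines
#   with requests.post(
#       "http://localhost:7860/generate",
#       json={"prompt": "Tell me a joke.", "max_new_tokens": 128, "stream": True},
#       stream=True,
#   ) as stream_resp:
#       for line in stream_resp.iter_lines(decode_unicode=True):
#           if line and line.startswith("data: "):
#               print(line[len("data: "):], end="", flush=True)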