import logging
import time
from threading import Thread

import torch
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
|
|
app = FastAPI()

# Allow cross-origin requests so browser-based frontends can call the API.
# Wildcard origins are convenient in development; restrict them in production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
MODEL_ID = "AshokGakr/model-tiny"

logger.info("Loading model...")

# Prefer the GPU when one is available; otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Some causal LMs ship without a pad token; falling back to EOS keeps
# generate() from warning about a missing pad_token_id.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# float32 keeps CPU inference numerically safe; low_cpu_mem_usage reduces
# peak RAM while the weights load.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float32,
    low_cpu_mem_usage=True,
).to(device)

# Inference only: disable dropout and other training-time behavior.
model.eval()

logger.info(f"Model loaded on {device}")
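# Optional warm-up, a sketch not present in the original flow: generating a
# single token at startup absorbs one-time lazy-initialization costs before
# the first real request. Uncomment if cold-start latency matters.
#
#   _ = model.generate(
#       **tokenizer("Hi", return_tensors="pt").to(device), max_new_tokens=1
#   )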
|
|
|
|
# Lightweight health check.
@app.get("/")
def root():
    return {"status": "API is running"}
|
|
|
|
def generate_stream(prompt: str):
    """Yield generated text chunks as the model produces them."""
    logger.info("Starting generation...")
    start_time = time.time()

    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    # TextIteratorStreamer exposes generate() output as an iterator of text
    # chunks; skip_prompt keeps the echoed prompt out of the response.
    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
    )

    generation_kwargs = dict(
        **inputs,
        max_new_tokens=120,
        temperature=0.7,
        top_p=0.9,
        repetition_penalty=1.1,
        do_sample=True,
        streamer=streamer,
    )

    # generate() blocks until completion, so run it in a background thread
    # and consume the streamer on this one.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    for new_text in streamer:
        yield new_text

    # The streamer is exhausted only after generation ends, so this join
    # returns immediately; it just reaps the worker thread.
    thread.join()

    duration = round(time.time() - start_time, 2)
    logger.info(f"Generation finished in {duration} seconds.")
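# Hypothetical direct usage of the generator, outside the HTTP layer:
#
#   for chunk in generate_stream("User: Hi\nAssistant:"):
#       print(chunk, end="", flush=True)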
|
|
|
|
@app.post("/chat")
async def chat(data: dict):
    """Stream a completion for a JSON body of {system, history, message}."""
    system_prompt = data.get("system", "You are a helpful AI assistant.")
    history = data.get("history", "")
    message = data.get("message", "")

    # Keep the prompt bounded by dropping all but the most recent history.
    max_history_chars = 2000
    if len(history) > max_history_chars:
        history = history[-max_history_chars:]

    logger.info("----- NEW REQUEST -----")
    logger.info(f"User message: {message}")
    logger.info(f"History length: {len(history)}")

    full_prompt = f"{system_prompt}\n{history}\nUser: {message}\nAssistant:"

    return StreamingResponse(
        generate_stream(full_prompt),
        media_type="text/plain",
    )
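# A minimal sketch for running the server directly, assuming uvicorn is
# installed (`pip install uvicorn`); in deployment you would more typically
# start it as `uvicorn main:app` (the module name here is an assumption).
#
# Example request, matching the fields read by /chat above (-N disables
# curl's output buffering so the stream prints as it arrives):
#   curl -N -X POST http://localhost:8000/chat \
#        -H "Content-Type: application/json" \
#        -d '{"message": "Hello", "history": "", "system": "You are helpful."}'
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)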