import os

import torch
import uvicorn
from fastapi import FastAPI, HTTPException
from huggingface_hub import login
from pydantic import BaseModel, Field
from transformers import AutoTokenizer, AutoModelForCausalLM
|
|
# Authenticate with the Hugging Face Hub only when a token is configured.
# Calling login() without a token makes huggingface_hub fall back to an
# interactive prompt, which is not usable in a non-interactive server process.
hf_token = os.getenv("HF_TOKEN")
if hf_token:
    login(hf_token)
|
|
# FastAPI application object; the metadata below is what the app is
# constructed with (title, description and version).
app = FastAPI(
    version="0.9",
    title="VexaAI-Lab: Google Gemma-3-270M",
    description="Self-hosted AI-Model Google Gemma-3-270M, powered by VexaAI.",
)
|
|
# Hub identifier of the checkpoint loaded below.
model_name = "google/gemma-3-270m"

# NOTE(review): trust_remote_code=True permits code shipped in the model
# repository to run at load time — confirm it is actually required here.
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

# Load the weights in float32 with automatic device placement, then switch the
# module to evaluation mode (eval() returns the module itself).
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float32,
    trust_remote_code=True,
    device_map="auto",
).eval()
|
|
class GenerateRequest(BaseModel):
    """Request body accepted by the ``/generate`` endpoint."""

    # Text the model is asked to continue.
    prompt: str
    # Maximum number of tokens to generate; must be at least 1 so an invalid
    # value is rejected during request validation instead of failing later
    # inside generation.
    # NOTE(review): there is no upper bound — consider capping this.
    max_new_tokens: int = Field(512, ge=1)
    # Sampling temperature; must be strictly positive because the endpoint
    # always samples (do_sample=True).
    temperature: float = Field(0.7, gt=0)
|
|
@app.post("/generate")
async def generate_text(request: GenerateRequest):
    """Generate a completion for ``request.prompt`` with the loaded model.

    The prompt is tokenized and moved to the model's device, a continuation
    is sampled with ``model.generate`` using the request's ``max_new_tokens``
    and ``temperature``, and only the newly generated tokens are decoded.

    Args:
        request: Prompt text and sampling settings.

    Returns:
        A dict ``{"generated_text": <str>}`` holding the decoded continuation
        with leading and trailing whitespace stripped.

    Raises:
        HTTPException: Status 500 if tokenization, generation or decoding
            raises; the original exception is chained as the cause.
    """
    # NOTE(review): model.generate() is a blocking call inside an async
    # handler, so the event loop cannot serve other requests while it runs —
    # consider offloading it to a worker thread.
    try:
        inputs = tokenizer(request.prompt, return_tensors="pt").to(model.device)
        prompt_length = inputs["input_ids"].shape[-1]

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=request.max_new_tokens,
                temperature=request.temperature,
                do_sample=True,
                repetition_penalty=1.1,
                pad_token_id=tokenizer.eos_token_id,
            )

        # Drop the prompt by token count instead of by decoded string length:
        # the decoded prompt is not guaranteed to be an exact prefix of the
        # decoded full sequence, so character slicing can cut the wrong spot.
        new_token_ids = outputs[0][prompt_length:]
        generated_text = tokenizer.decode(
            new_token_ids, skip_special_tokens=True
        ).strip()
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"VexaAI-Lab: HTTP/S error: {str(e)}"
        ) from e

    return {"generated_text": generated_text}
|
|
@app.get("/")
async def root():
    """Return a short hint pointing clients at the generation endpoint."""
    hint = {"message": "To start generating text, use /generate."}
    return hint
|
|
if __name__ == "__main__":
    # When executed as a script, serve the app with uvicorn on all
    # interfaces at port 7860.
    uvicorn.run(
        app,
        port=7860,
        host="0.0.0.0",
    )