# LLM_Deployment / main.py
import os

from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load the Gemma-2B instruction-tuned tokenizer and model; the Hugging Face
# access token is read from the "token" environment variable.
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it", token=os.environ["token"])
model = AutoModelForCausalLM.from_pretrained("google/gemma-2b-it", token=os.environ["token"])

# Define a Pydantic model for request validation.
class GenerationRequest(BaseModel):
    prompt: str
    repeat_penalty: float = 1.0  # Default repetition penalty (1.0 = no penalty)
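
# Example request body accepted by the schema above (illustrative values only):
#   {"prompt": "Explain what a tokenizer does", "repeat_penalty": 1.2}
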
# Initialize FastAPI
app = FastAPI()

# Define the route for text generation.
@app.post("/generate_text")
async def generate_text(request: GenerationRequest):
    # Build the prompt. Note: gemma-2b-it normally expects its own chat markers
    # (<start_of_turn>/<end_of_turn>); the <s>/</s> tags below come from the
    # original template and are treated as plain text by the Gemma tokenizer.
    input_prompt = (
        "<s>Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.</s>\n"
        "<s>" + request.prompt + "</s>"
    )
    # Encode the prompt into token IDs.
    input_ids = tokenizer.encode(input_prompt, return_tensors="pt")
    # Generate a continuation. `generate` expects `repetition_penalty`
    # (not `repeat_penalty`); max_new_tokens caps the newly generated tokens
    # so output isn't cut short by the small library default max_length.
    outputs = model.generate(
        input_ids, repetition_penalty=request.repeat_penalty, max_new_tokens=256
    )
    # Decode the generated tokens (note: the decoded text includes the prompt).
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return {"generated_text": generated_text}
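
# Minimal local-run sketch (not part of the original file): assumes `uvicorn` is
# installed; host and port values are illustrative, adjust to your deployment.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)

# Example call once the server is running (hypothetical prompt value):
#   curl -X POST http://localhost:8000/generate_text \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Write a haiku about the sea", "repeat_penalty": 1.1}'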