"""FastAPI service exposing a text-generation endpoint backed by Gemma-2B-it."""

import os

from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load the Gemma-2B instruction-tuned tokenizer and model once at import time,
# so every request reuses the same in-memory model.
# NOTE(review): the Hugging Face access token is read from the "token" env var;
# a missing variable raises KeyError at startup — confirm the expected name.
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it", token=os.environ["token"])
model = AutoModelForCausalLM.from_pretrained("google/gemma-2b-it", token=os.environ["token"])


class GenerationRequest(BaseModel):
    """Validated request body for the /generate_text endpoint."""

    # The user instruction to complete.
    prompt: str
    # Repetition penalty forwarded to generation; 1.0 (default) means no penalty,
    # values > 1.0 discourage repeated tokens.
    repeat_penalty: float = 1.0


# Initialize the FastAPI application.
app = FastAPI()


@app.post("/generate_text")
async def generate_text(request: GenerationRequest) -> dict:
    """Generate a completion for the request prompt and return it as JSON.

    Returns a JSON object of the form {"generated_text": <decoded model output>}.
    """
    # Prepend the fixed instruction preamble to the user's prompt.
    input_prompt = (
        "Below is an instruction that describes a task. Write a response that appropriately completes the request."
        + "\n"
        + request.prompt
    )

    # Encode the combined prompt into model input ids.
    input_ids = tokenizer.encode(input_prompt, return_tensors="pt")

    # BUG FIX: the transformers generate() keyword is `repetition_penalty`;
    # the original `repeat_penalty=` kwarg is unknown to generate() and made
    # every request raise a ValueError. The public request field name is kept
    # as `repeat_penalty` for API compatibility.
    outputs = model.generate(input_ids, repetition_penalty=request.repeat_penalty)

    # Decode the first (and only) generated sequence, dropping special tokens.
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return {"generated_text": generated_text}