# Hugging Face Space: FastAPI service exposing google/gemma-2b-it text generation.
# Standard library
import os

# Third-party
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the Gemma-2B instruction-tuned tokenizer and model once at import time
# so every request reuses them.  The Hub access token is read from the "token"
# environment variable (gated model); os.environ[...] deliberately raises
# KeyError at startup if it is missing, rather than failing on the first request.
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it", token=os.environ["token"])
model = AutoModelForCausalLM.from_pretrained("google/gemma-2b-it", token=os.environ["token"])
# Define Pydantic object for input validation | |
class GenerationRequest(BaseModel): | |
prompt: str | |
repeat_penalty: float = 1.0 # Default repeat penalty | |
# Initialize the FastAPI application that serves the generation endpoint.
app = FastAPI()
# Define route for generating text | |
async def generate_text(request: GenerationRequest): | |
# Tokenize the input prompt | |
input_prompt = "<s>Below is an instruction that describes a task. Write a response that appropriately completes the request.</s>" + "\n" + "<s>" + request.prompt + "</s>" | |
# Encode the input prompt | |
input_ids = tokenizer.encode(input_prompt, return_tensors="pt") | |
# Generate text based on the input prompt | |
outputs = model.generate(input_ids, repeat_penalty=request.repeat_penalty) | |
# Decode the generated output and return | |
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True) | |
return {"generated_text": generated_text} | |