import gradio as gr
from fastapi import FastAPI, Query
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from functools import lru_cache
import torch
import logging
# Set up logging for performance debugging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize FastAPI
app = FastAPI()

# Preload models and tokenizers for efficiency
logger.info("Loading models at startup...")

# In-memory model cache
model_cache = {}

def load_model(model_name):
    if model_name in model_cache:
        logger.info(f"Using cached model {model_name}")
        return model_cache[model_name]
    logger.info(f"Loading model {model_name}...")
    if model_name == "mixtral":
        model_id = "mistralai/Mixtral-8x7B-Instruct-v0.1"
        try:
            tokenizer = AutoTokenizer.from_pretrained(model_id)
            # Use 4-bit quantization (via bitsandbytes) to reduce memory usage;
            # quantization_config is the current replacement for the deprecated
            # load_in_4bit=True argument to from_pretrained().
            quant_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
            )
            model = AutoModelForCausalLM.from_pretrained(
                model_id,
                device_map="auto",
                quantization_config=quant_config,
                low_cpu_mem_usage=True,
            )
            # The model is already placed by device_map="auto"; passing
            # device_map to pipeline() again would conflict with that.
            pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
        except Exception as e:
            logger.error(f"Failed to load Mixtral: {str(e)}")
            raise
    elif model_name == "gpt2":
        pipe = pipeline("text-generation", model="gpt2", device=0 if torch.cuda.is_available() else -1)
    else:
        raise ValueError("Unsupported model. Choose 'mixtral' or 'gpt2'.")
    model_cache[model_name] = pipe
    logger.info(f"Model {model_name} loaded successfully")
    return pipe
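
# Example direct use (assuming the weights are already downloaded or cached):
#   pipe = load_model("gpt2")
#   print(pipe("Once upon a time", max_length=30)[0]["generated_text"])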

# Preload models at startup
try:
    load_model("gpt2")     # Load the lightweight GPT-2 first
    load_model("mixtral")  # Load Mixtral with quantization
except Exception as e:
    logger.error(f"Error while preloading models: {str(e)}")

# Generation function with memoization. Note: because of @lru_cache, repeated
# calls with the same (prompt, model_name, max_length) return the cached text,
# so identical prompts yield identical output even though do_sample=True.
@lru_cache(maxsize=100)
def generate_text(prompt: str, model_name: str, max_length: int = 100):
    try:
        logger.info(f"Processing prompt: {prompt[:30]}... with model {model_name}")
        generator = model_cache.get(model_name)
        if not generator:
            generator = load_model(model_name)
        # Generate text
        output = generator(
            prompt,
            max_length=max_length,
            num_return_sequences=1,
            do_sample=True,
            pad_token_id=generator.tokenizer.eos_token_id,
        )
        return output[0]["generated_text"]
    except Exception as e:
        logger.error(f"Error during generation: {str(e)}")
        return f"Error: {str(e)}"

# API endpoint
@app.get("/generate")
async def generate(prompt: str = Query(..., description="Input text for the model"),
                   model: str = Query("gpt2", description="AI model: 'mixtral' or 'gpt2'"),
                   max_length: int = Query(100, description="Maximum length of the generated text")):
    result = generate_text(prompt, model, max_length)
    return {"prompt": prompt, "model": model, "generated_text": result}

# Gradio interface
def gradio_generate(prompt, model_choice, max_length):
    # Slider values can arrive as floats; cast so the lru_cache key and the
    # pipeline's max_length stay integers.
    return generate_text(prompt, model_choice, int(max_length))

with gr.Blocks() as demo:
    gr.Markdown("# AI Text Generation API (Optimized)")
    gr.Markdown("Enter text and choose a model to generate text. The API is available at `/generate`.")
    prompt_input = gr.Textbox(label="Prompt", placeholder="Enter text here...")
    model_choice = gr.Dropdown(choices=["gpt2", "mixtral"], label="Choose Model", value="gpt2")
    max_length = gr.Slider(minimum=50, maximum=500, value=100, step=10, label="Maximum Length")
    submit_button = gr.Button("Generate")
    output_text = gr.Textbox(label="Generated Text")
    submit_button.click(
        fn=gradio_generate,
        inputs=[prompt_input, model_choice, max_length],
        outputs=output_text,
    )

# Mount the Gradio UI onto the FastAPI app so the /generate endpoint is
# actually served; demo.launch() alone would start Gradio's own server and
# leave the FastAPI routes unreachable.
app = gr.mount_gradio_app(app, demo, path="/")

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
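
# Assumed (not pinned here) dependencies: gradio, fastapi, uvicorn, torch,
# transformers, accelerate, and bitsandbytes for the 4-bit Mixtral path.
# Even in 4-bit, Mixtral-8x7B needs roughly 25 GB of GPU memory.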