import gradio as gr
from fastapi import FastAPI, Query
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
from functools import lru_cache
import torch
import logging
# Set up logging for performance debugging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize FastAPI
app = FastAPI()

# Preload the models and tokenizers for efficiency
logger.info("Loading models at startup...")

# In-memory model cache
model_cache = {}
def load_model(model_name):
    if model_name in model_cache:
        logger.info(f"Using model {model_name} from cache")
        return model_cache[model_name]

    logger.info(f"Loading model {model_name}...")
    if model_name == "mixtral":
        model_id = "mistralai/Mixtral-8x7B-Instruct-v0.1"
        try:
            tokenizer = AutoTokenizer.from_pretrained(model_id)
            # Use 4-bit quantization to reduce memory usage
            model = AutoModelForCausalLM.from_pretrained(
                model_id,
                device_map="auto",
                torch_dtype=torch.float16,
                load_in_4bit=True,  # quantization for speed (requires bitsandbytes)
                low_cpu_mem_usage=True
            )
            pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device_map="auto")
        except Exception as e:
            logger.error(f"Failed to load Mixtral: {str(e)}")
            raise
    elif model_name == "gpt2":
        pipe = pipeline("text-generation", model="gpt2", device=0 if torch.cuda.is_available() else -1)
    else:
        raise ValueError("Unsupported model. Choose 'mixtral' or 'gpt2'.")

    model_cache[model_name] = pipe
    logger.info(f"Model {model_name} loaded successfully")
    return pipe
# Preload the models at startup
try:
    load_model("gpt2")     # load the lightweight GPT-2 first
    load_model("mixtral")  # load Mixtral with quantization
except Exception as e:
    logger.error(f"Error while preloading models: {str(e)}")
# Generation function with result caching
@lru_cache(maxsize=100)
def generate_text(prompt: str, model_name: str, max_length: int = 100):
    try:
        logger.info(f"Processing prompt: {prompt[:30]}... with model {model_name}")
        generator = model_cache.get(model_name)
        if not generator:
            generator = load_model(model_name)
        # Generate text
        output = generator(
            prompt,
            max_length=max_length,
            num_return_sequences=1,
            do_sample=True,
            pad_token_id=generator.tokenizer.eos_token_id
        )
        return output[0]["generated_text"]
    except Exception as e:
        logger.error(f"Error during generation: {str(e)}")
        return f"Error: {str(e)}"
# API endpoint
@app.get("/generate")
async def generate(prompt: str = Query(..., description="Input text for the model"),
                   model: str = Query("gpt2", description="AI model: 'mixtral' or 'gpt2'"),
                   max_length: int = Query(100, description="Maximum length of the generated text")):
    result = generate_text(prompt, model, max_length)
    return {"prompt": prompt, "model": model, "generated_text": result}
# Gradio interface
def gradio_generate(prompt, model_choice, max_length):
    # Slider values arrive as floats; cast so max_length and the lru_cache key stay integers
    return generate_text(prompt, model_choice, int(max_length))

with gr.Blocks() as demo:
    gr.Markdown("# AI Text Generation API (Optimized)")
    gr.Markdown("Enter text and choose a model to generate text. The API is available at `/generate`.")
    prompt_input = gr.Textbox(label="Prompt", placeholder="Enter your text here...")
    model_choice = gr.Dropdown(choices=["gpt2", "mixtral"], label="Select Model", value="gpt2")
    max_length = gr.Slider(minimum=50, maximum=500, value=100, step=10, label="Maximum Length")
    submit_button = gr.Button("Generate")
    output_text = gr.Textbox(label="Generated Text")
    submit_button.click(
        fn=gradio_generate,
        inputs=[prompt_input, model_choice, max_length],
        outputs=output_text
    )
# For Hugging Face Spaces, launch the Gradio app directly
# (note: the FastAPI routes above are only served when `app` itself is run by an ASGI server)
demo.launch()
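# Illustrative shape of a request to the /generate endpoint (a sketch only; the host and
# port depend on how the FastAPI `app` is served, which this file does not do by itself
# since demo.launch() starts only the Gradio server):
#   GET /generate?prompt=Hello%20world&model=gpt2&max_length=80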