import logging
from functools import lru_cache

import gradio as gr
import torch
from fastapi import FastAPI, Query
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, pipeline

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()

logger.info("Loading models at startup...")

# Cache of already-initialised text-generation pipelines, keyed by model name.
model_cache = {}


def load_model(model_name):
    """Load (or fetch from the cache) a text-generation pipeline for the given model name."""
    if model_name in model_cache:
        logger.info(f"Using cached model {model_name}")
        return model_cache[model_name]

    logger.info(f"Loading model {model_name}...")
    if model_name == "mixtral":
        model_id = "mistralai/Mixtral-8x7B-Instruct-v0.1"
        try:
            tokenizer = AutoTokenizer.from_pretrained(model_id)
            model = AutoModelForCausalLM.from_pretrained(
                model_id,
                device_map="auto",
                torch_dtype=torch.float16,
                # 4-bit quantization to reduce memory usage (requires bitsandbytes).
                quantization_config=BitsAndBytesConfig(load_in_4bit=True),
                low_cpu_mem_usage=True,
            )
            # The model is already placed on devices by device_map="auto",
            # so no device argument is passed to the pipeline here.
            pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
        except Exception as e:
            logger.error(f"Failed to load Mixtral: {str(e)}")
            raise
    elif model_name == "gpt2":
        pipe = pipeline("text-generation", model="gpt2", device=0 if torch.cuda.is_available() else -1)
    else:
        raise ValueError("Unsupported model. Choose 'mixtral' or 'gpt2'.")

    model_cache[model_name] = pipe
    logger.info(f"Model {model_name} loaded successfully")
    return pipe
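
# Note: 4-bit loading of Mixtral assumes the bitsandbytes and accelerate packages are
# installed and that a GPU with enough memory is available; if not, loading will fail
# and the exception is logged by the preload block below.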

try:
    load_model("gpt2")
    load_model("mixtral")
except Exception as e:
    logger.error(f"Error while preloading models: {str(e)}")


# lru_cache memoises results per (prompt, model_name, max_length) tuple, so repeated
# identical requests return the same cached text even though sampling is enabled.
@lru_cache(maxsize=100)
def generate_text(prompt: str, model_name: str, max_length: int = 100):
    try:
        logger.info(f"Processing prompt: {prompt[:30]}... with model {model_name}")
        generator = model_cache.get(model_name)
        if not generator:
            generator = load_model(model_name)

        output = generator(
            prompt,
            max_length=max_length,
            num_return_sequences=1,
            do_sample=True,
            pad_token_id=generator.tokenizer.eos_token_id,
        )
        return output[0]["generated_text"]
    except Exception as e:
        logger.error(f"Error during generation: {str(e)}")
        return f"Error: {str(e)}"
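
# Example direct call (illustrative values only):
#   generate_text("Once upon a time", "gpt2", 80)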
@app.get("/generate") |
|
async def generate(prompt: str = Query(..., description="Teks input untuk model"), |
|
model: str = Query("gpt2", description="Model AI: 'mixtral' atau 'gpt2'"), |
|
max_length: int = Query(100, description="Panjang maksimum teks yang dihasilkan")): |
|
result = generate_text(prompt, model, max_length) |
|
return {"prompt": prompt, "model": model, "generated_text": result} |
|
|
|
|
|
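
# Example request (assuming the app is served locally on port 7860; adjust host/port as needed):
#   curl "http://localhost:7860/generate?prompt=Hello&model=gpt2&max_length=80"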


def gradio_generate(prompt, model_choice, max_length):
    # Gradio sliders return numbers that may be floats; cast so the cached
    # generate_text always receives an int.
    return generate_text(prompt, model_choice, int(max_length))


with gr.Blocks() as demo:
    gr.Markdown("# AI Text Generation API (Optimized)")
    gr.Markdown("Enter text and choose a model to generate text. The API is available at `/generate`.")

    prompt_input = gr.Textbox(label="Prompt", placeholder="Enter text here...")
    model_choice = gr.Dropdown(choices=["gpt2", "mixtral"], label="Select Model", value="gpt2")
    max_length = gr.Slider(minimum=50, maximum=500, value=100, step=10, label="Maximum Length")
    submit_button = gr.Button("Generate")

    output_text = gr.Textbox(label="Generated Output")

    submit_button.click(
        fn=gradio_generate,
        inputs=[prompt_input, model_choice, max_length],
        outputs=output_text,
    )

# Mount the Gradio UI onto the FastAPI app so that the web UI and the /generate
# endpoint are served by the same server (demo.launch() alone would not expose
# the FastAPI routes).
app = gr.mount_gradio_app(app, demo, path="/")

if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)