import gradio as gr
from fastapi import FastAPI, Query
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from functools import lru_cache
import torch
import logging

# Set up logging for performance debugging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize FastAPI
app = FastAPI()

# Preload the models and tokenizers for efficiency
logger.info("Loading models at startup...")

# In-memory model cache
model_cache = {}

def load_model(model_name):
    if model_name in model_cache:
        logger.info(f"Menggunakan model {model_name} dari cache")
        return model_cache[model_name]
    
    logger.info(f"Memuat model {model_name}...")
    if model_name == "mixtral":
        model_id = "mistralai/Mixtral-8x7B-Instruct-v0.1"
        try:
            tokenizer = AutoTokenizer.from_pretrained(model_id)
            # Use 4-bit quantization (bitsandbytes) to reduce memory usage;
            # passing a BitsAndBytesConfig is the current supported way to request it
            quant_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16
            )
            model = AutoModelForCausalLM.from_pretrained(
                model_id,
                device_map="auto",
                quantization_config=quant_config,
                low_cpu_mem_usage=True
            )
            # The model already carries its device placement from device_map="auto",
            # so the pipeline does not need a device_map of its own
            pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
        except Exception as e:
            logger.error(f"Gagal memuat Mixtral: {str(e)}")
            raise
    elif model_name == "gpt2":
        pipe = pipeline("text-generation", model="gpt2", device=0 if torch.cuda.is_available() else -1)
    else:
        raise ValueError("Unsupported model. Choose 'mixtral' or 'gpt2'.")
    
    model_cache[model_name] = pipe
    logger.info(f"Model {model_name} berhasil dimuat")
    return pipe

# Preload the models at startup
try:
    load_model("gpt2")  # Load GPT-2 (lightweight) first
    load_model("mixtral")  # Load Mixtral with 4-bit quantization
except Exception as e:
    logger.error(f"Error while preloading models: {str(e)}")

# Generation function with result caching (lru_cache memoizes identical requests)
@lru_cache(maxsize=100)
def generate_text(prompt: str, model_name: str, max_length: int = 100):
    try:
        logger.info(f"Memproses prompt: {prompt[:30]}... dengan model {model_name}")
        generator = model_cache.get(model_name)
        if not generator:
            generator = load_model(model_name)
        # Generate text
        output = generator(
            prompt,
            max_length=max_length,
            num_return_sequences=1,
            do_sample=True,
            pad_token_id=generator.tokenizer.eos_token_id
        )
        return output[0]["generated_text"]
    except Exception as e:
        logger.error(f"Error saat generasi: {str(e)}")
        return f"Error: {str(e)}"

# API endpoint
@app.get("/generate")
async def generate(prompt: str = Query(..., description="Input text for the model"),
                  model: str = Query("gpt2", description="AI model: 'mixtral' or 'gpt2'"),
                  max_length: int = Query(100, description="Maximum length of the generated text")):
    result = generate_text(prompt, model, max_length)
    return {"prompt": prompt, "model": model, "generated_text": result}

# Gradio interface
def gradio_generate(prompt, model_choice, max_length):
    # The slider returns a float; cast to int so generation and the lru_cache key stay consistent
    return generate_text(prompt, model_choice, int(max_length))

with gr.Blocks() as demo:
    gr.Markdown("# AI Text Generation API (Optimized)")
    gr.Markdown("Masukkan teks dan pilih model untuk menghasilkan teks. API tersedia di `/generate`.")
    
    prompt_input = gr.Textbox(label="Prompt", placeholder="Enter your text here...")
    model_choice = gr.Dropdown(choices=["gpt2", "mixtral"], label="Select Model", value="gpt2")
    max_length = gr.Slider(minimum=50, maximum=500, value=100, step=10, label="Maximum Length")
    submit_button = gr.Button("Generate")
    
    output_text = gr.Textbox(label="Generated Output")
    
    submit_button.click(
        fn=gradio_generate,
        inputs=[prompt_input, model_choice, max_length],
        outputs=output_text
    )

# For Hugging Face Spaces, launch the Gradio UI directly.
# Note: demo.launch() serves only the Gradio interface; to also expose the
# FastAPI /generate endpoint, mount the demo onto the FastAPI app
# (app = gr.mount_gradio_app(app, demo, path="/")) and serve it with uvicorn.
demo.launch()
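
# Example client call (illustrative sketch only, not part of the original app):
# it assumes the FastAPI routes are actually being served on port 7860, e.g. via
# gr.mount_gradio_app + uvicorn as noted above; with demo.launch() alone only
# the Gradio UI is reachable.
#
#   import requests
#   resp = requests.get(
#       "http://localhost:7860/generate",
#       params={"prompt": "Once upon a time", "model": "gpt2", "max_length": 50},
#   )
#   print(resp.json()["generated_text"])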