ApaCu committed
Commit 0ba12c0 · verified · 1 Parent(s): 038924a

Update app.py

Files changed (1)
  1. app.py +70 -27
app.py CHANGED
@@ -1,67 +1,110 @@
 import gradio as gr
 from fastapi import FastAPI, Query
 from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
-import uvicorn
 import torch

 # Initialize FastAPI
 app = FastAPI()

-# Initialize the model and tokenizer
 def load_model(model_name):
     if model_name == "mixtral":
         model_id = "mistralai/Mixtral-8x7B-Instruct-v0.1"
-        tokenizer = AutoTokenizer.from_pretrained(model_id)
-        model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", torch_dtype=torch.float16)
-        return pipeline("text-generation", model=model, tokenizer=tokenizer, device_map="auto")
     elif model_name == "gpt2":
-        return pipeline("text-generation", model="gpt2")
     else:
         raise ValueError("Unsupported model. Choose 'mixtral' or 'gpt2'.")

-# Function to generate text
-def generate_text(prompt, model_name, max_length=100):
     try:
-        generator = load_model(model_name)
-        # Generate the text
-        output = generator(prompt, max_length=max_length, num_return_sequences=1, do_sample=True)
         return output[0]["generated_text"]
     except Exception as e:
         return f"Error: {str(e)}"

 # API endpoint
 @app.get("/generate")
 async def generate(prompt: str = Query(..., description="Input text for the model"),
-                   model: str = Query("gpt2", description="AI model: 'mixtral' or 'gpt2'")):
-    result = generate_text(prompt, model)
     return {"prompt": prompt, "model": model, "generated_text": result}

 # Gradio interface
-def gradio_generate(prompt, model_choice):
-    return generate_text(prompt, model_choice)

 with gr.Blocks() as demo:
-    gr.Markdown("# AI Text Generation API")
-    gr.Markdown("Enter text and select a model to generate text. Use the API at `/generate` for programmatic access.")

-    # Input components
     prompt_input = gr.Textbox(label="Prompt", placeholder="Enter text here...")
     model_choice = gr.Dropdown(choices=["gpt2", "mixtral"], label="Select Model", value="gpt2")
     submit_button = gr.Button("Generate")

-    # Output component
     output_text = gr.Textbox(label="Generated Text")

-    # Wire the button to the function
     submit_button.click(
         fn=gradio_generate,
-        inputs=[prompt_input, model_choice],
         outputs=output_text
     )

-# Run the app (for local use, not on Hugging Face)
-if __name__ == "__main__":
-    uvicorn.run(app, host="0.0.0.0", port=7860)
-else:
-    # For Hugging Face Spaces, launch Gradio
-    demo.launch()
 
 import gradio as gr
 from fastapi import FastAPI, Query
 from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
+from functools import lru_cache
 import torch
+import logging
+
+# Set up logging for performance debugging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)

 # Initialize FastAPI
 app = FastAPI()

+# Preload the model and tokenizer for efficiency
+logger.info("Loading models at startup...")
+
+# Cache models in memory
+model_cache = {}
+
 def load_model(model_name):
+    if model_name in model_cache:
+        logger.info(f"Using model {model_name} from cache")
+        return model_cache[model_name]
+
+    logger.info(f"Loading model {model_name}...")
     if model_name == "mixtral":
         model_id = "mistralai/Mixtral-8x7B-Instruct-v0.1"
+        try:
+            tokenizer = AutoTokenizer.from_pretrained(model_id)
+            # Use 4-bit quantization to reduce memory usage
+            model = AutoModelForCausalLM.from_pretrained(
+                model_id,
+                device_map="auto",
+                torch_dtype=torch.float16,
+                load_in_4bit=True,  # Quantization for speed
+                low_cpu_mem_usage=True
+            )
+            pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device_map="auto")
+        except Exception as e:
+            logger.error(f"Failed to load Mixtral: {str(e)}")
+            raise
     elif model_name == "gpt2":
+        pipe = pipeline("text-generation", model="gpt2", device=0 if torch.cuda.is_available() else -1)
     else:
         raise ValueError("Unsupported model. Choose 'mixtral' or 'gpt2'.")
+
+    model_cache[model_name] = pipe
+    logger.info(f"Model {model_name} loaded successfully")
+    return pipe
+
+# Preload models at startup
+try:
+    load_model("gpt2")  # Load GPT-2 (lightweight) first
+    load_model("mixtral")  # Load Mixtral with quantization
+except Exception as e:
+    logger.error(f"Error while preloading models: {str(e)}")

+# Generation function with caching
+@lru_cache(maxsize=100)
+def generate_text(prompt: str, model_name: str, max_length: int = 100):
     try:
+        logger.info(f"Processing prompt: {prompt[:30]}... with model {model_name}")
+        generator = model_cache.get(model_name)
+        if not generator:
+            generator = load_model(model_name)
+        # Generate the text
+        output = generator(
+            prompt,
+            max_length=max_length,
+            num_return_sequences=1,
+            do_sample=True,
+            pad_token_id=generator.tokenizer.eos_token_id
+        )
         return output[0]["generated_text"]
     except Exception as e:
+        logger.error(f"Error during generation: {str(e)}")
         return f"Error: {str(e)}"

 # API endpoint
 @app.get("/generate")
 async def generate(prompt: str = Query(..., description="Input text for the model"),
+                   model: str = Query("gpt2", description="AI model: 'mixtral' or 'gpt2'"),
+                   max_length: int = Query(100, description="Maximum length of the generated text")):
+    result = generate_text(prompt, model, max_length)
     return {"prompt": prompt, "model": model, "generated_text": result}

 # Gradio interface
+def gradio_generate(prompt, model_choice, max_length):
+    return generate_text(prompt, model_choice, max_length)

 with gr.Blocks() as demo:
+    gr.Markdown("# AI Text Generation API (Optimized)")
+    gr.Markdown("Enter text and select a model to generate text. The API is available at `/generate`.")

     prompt_input = gr.Textbox(label="Prompt", placeholder="Enter text here...")
     model_choice = gr.Dropdown(choices=["gpt2", "mixtral"], label="Select Model", value="gpt2")
+    max_length = gr.Slider(minimum=50, maximum=500, value=100, step=10, label="Maximum Length")
     submit_button = gr.Button("Generate")

     output_text = gr.Textbox(label="Generated Text")

     submit_button.click(
         fn=gradio_generate,
+        inputs=[prompt_input, model_choice, max_length],
         outputs=output_text
     )

+# For Hugging Face Spaces, launch Gradio directly
+demo.launch()
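
Two usage notes on the new version. First, `demo.launch()` serves only the Gradio UI; run this way, the FastAPI routes defined on `app` are not exposed. A minimal sketch of one way to serve both from a single process, using Gradio's `mount_gradio_app` helper together with uvicorn (the mount path and port here are illustrative assumptions, not part of this commit):

import gradio as gr
import uvicorn
# Assumes `app` (FastAPI) and `demo` (gr.Blocks) as defined in app.py above

# Mount the Gradio UI onto the FastAPI app so /generate and the UI share one server
app = gr.mount_gradio_app(app, demo, path="/")

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)  # 7860 is the port Spaces expects

Second, a sketch of calling the `/generate` endpoint once it is served; the query parameters `prompt`, `model`, and `max_length` come from the endpoint definition above, while the base URL is an assumption:

import requests

resp = requests.get(
    "http://localhost:7860/generate",  # replace with your Space URL if deployed
    params={"prompt": "Once upon a time", "model": "gpt2", "max_length": 120},
)
print(resp.json()["generated_text"])

Note that because `generate_text` is wrapped in `@lru_cache` while sampling with `do_sample=True`, repeating an identical (prompt, model, max_length) request returns the cached text rather than a fresh sample; that appears to be the intended speed trade-off of this commit.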