rahul7star committed
Commit be09bfa · verified · 1 parent: ff12e01

Update app_flash.py

Files changed (1):
  1. app_flash.py +12 -18
app_flash.py CHANGED
@@ -1,21 +1,21 @@
 import gradio as gr
-from transformers import AutoTokenizer, pipeline
+from transformers import AutoTokenizer
 from flashpack.integrations.transformers import FlashPackTransformersModelMixin
-from transformers import AutoModelForCausalLM
+from transformers import AutoModelForCausalLM, pipeline as hf_pipeline

 # ============================================================
-# 1️⃣ FlashPack-enabled model class
+# 1️⃣ Define FlashPack-enabled model class
 # ============================================================
 class FlashPackGemmaModel(AutoModelForCausalLM, FlashPackTransformersModelMixin):
+    """Gemma 3 model wrapped with FlashPackTransformersModelMixin"""
     pass

 # ============================================================
-# 2️⃣ Model & tokenizer settings
+# 2️⃣ Load tokenizer
 # ============================================================
 MODEL_ID = "gokaygokay/prompt-enhancer-gemma-3-270m-it"
 FLASHPACK_REPO = "rahul7star/FlashPack"

-# Load tokenizer
 tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

 # ============================================================
@@ -25,17 +25,15 @@ try:
     print("📂 Loading model from FlashPack repository...")
     model = FlashPackGemmaModel.from_pretrained_flashpack(FLASHPACK_REPO)
 except FileNotFoundError:
-    print("⚠️ FlashPack model not found on Hub. Creating and uploading...")
-    # Load from HF Hub
+    print("⚠️ FlashPack model not found. Loading from HF Hub and uploading FlashPack...")
     model = FlashPackGemmaModel.from_pretrained(MODEL_ID)
-    # Save as FlashPack directly to Hub
     model.save_pretrained_flashpack(FLASHPACK_REPO, push_to_hub=True)
-    print(f"✅ Model uploaded as FlashPack to Hugging Face Hub: {FLASHPACK_REPO}")
+    print(f"✅ FlashPack model uploaded to Hugging Face Hub: {FLASHPACK_REPO}")

 # ============================================================
-# 4️⃣ Text-generation pipeline
+# 4️⃣ Build text-generation pipeline
 # ============================================================
-pipe = pipeline(
+pipe = hf_pipeline(
     "text-generation",
     model=model,
     tokenizer=tokenizer,
@@ -43,30 +41,26 @@ pipe = pipeline(
 )

 # ============================================================
-# 5️⃣ Prompt enhancement function
+# 5️⃣ Define prompt enhancement function
 # ============================================================
 def enhance_prompt(user_prompt, temperature, max_tokens, chat_history):
     chat_history = chat_history or []

-    # Build messages
     messages = [
         {"role": "system", "content": "Enhance and expand the following prompt with more details and context:"},
         {"role": "user", "content": user_prompt},
     ]

-    # Apply chat-template
     prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

-    # Generate output
     outputs = pipe(
         prompt,
         max_new_tokens=int(max_tokens),
         temperature=float(temperature),
-        do_sample=True,
+        do_sample=True
     )
     enhanced = outputs[0]["generated_text"].strip()

-    # Update chat history
     chat_history.append({"role": "user", "content": user_prompt})
     chat_history.append({"role": "assistant", "content": enhanced})
     return chat_history
@@ -96,7 +90,7 @@ with gr.Blocks(title="Prompt Enhancer – Gemma 3 270M", theme=gr.themes.Soft())
     send_btn = gr.Button("🚀 Enhance Prompt", variant="primary")
     clear_btn = gr.Button("🧹 Clear Chat")

-    # Bind actions
+    # Bind UI actions
     send_btn.click(enhance_prompt, [user_prompt, temperature, max_tokens, chatbot], chatbot)
     user_prompt.submit(enhance_prompt, [user_prompt, temperature, max_tokens, chatbot], chatbot)
     clear_btn.click(lambda: [], None, chatbot)
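
For reference, a minimal sketch of the load-or-convert flow the updated file implements, condensed so it can be run outside the Gradio app. It reuses the class definition, repo IDs, and FlashPack calls exactly as they appear in the diff; the sample prompt and generation settings below are illustrative assumptions, not values from the committed file.

# Sketch (assumptions labeled): one-shot version of app_flash.py's
# model loading and generation, without the Gradio UI.
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline as hf_pipeline
from flashpack.integrations.transformers import FlashPackTransformersModelMixin

# Same FlashPack-enabled class as in the diff
class FlashPackGemmaModel(AutoModelForCausalLM, FlashPackTransformersModelMixin):
    pass

MODEL_ID = "gokaygokay/prompt-enhancer-gemma-3-270m-it"
FLASHPACK_REPO = "rahul7star/FlashPack"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
try:
    # Fast path: pre-packed FlashPack weights already published on the Hub
    model = FlashPackGemmaModel.from_pretrained_flashpack(FLASHPACK_REPO)
except FileNotFoundError:
    # One-time conversion: load the original checkpoint, then publish the
    # FlashPack copy so later cold starts take the fast path
    model = FlashPackGemmaModel.from_pretrained(MODEL_ID)
    model.save_pretrained_flashpack(FLASHPACK_REPO, push_to_hub=True)

pipe = hf_pipeline("text-generation", model=model, tokenizer=tokenizer)

messages = [
    {"role": "system", "content": "Enhance and expand the following prompt with more details and context:"},
    {"role": "user", "content": "a cat on a windowsill"},  # illustrative prompt (assumption)
]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
out = pipe(prompt, max_new_tokens=128, temperature=0.7, do_sample=True)  # settings assumed
print(out[0]["generated_text"].strip())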