made1570 committed on
Commit 76ca090 · verified · 1 Parent(s): 4cbaeb1

Update app.py

Files changed (1)
  1. app.py +28 -35
app.py CHANGED
@@ -1,55 +1,48 @@
 import torch
-from transformers import AutoModelForCausalLM, AutoProcessor
+from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
 import gradio as gr

-device = "cuda" if torch.cuda.is_available() else "cpu"
+# Load model and tokenizer using Unsloth-style
+model_name = "adarsh3601/my_gemma3_pt"
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16)

-# Load model and processor
-model = AutoModelForCausalLM.from_pretrained("adarsh3601/my_gemma3_pt", device_map="auto")
-processor = AutoProcessor.from_pretrained("adarsh3601/my_gemma3_pt")
+device = "cuda" if torch.cuda.is_available() else "cpu"

 def chat(user_input, history):
-    # Format history as messages
     messages = []
-    for i, (user_msg, bot_msg) in enumerate(history):
+    for user_msg, bot_msg in history:
         messages.append({"role": "user", "content": user_msg})
         messages.append({"role": "assistant", "content": bot_msg})
     messages.append({"role": "user", "content": user_input})

-    try:
-        # Try using chat template
-        prompt = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
-    except Exception as e:
-        print(f"[WARNING] Failed to apply chat_template: {e}")
-        prompt = None
-
-    # Fallback if prompt fails
-    if not prompt:
-        prompt = "<bos>"
-        for i, msg in enumerate(messages):
-            role = "model" if msg["role"] == "assistant" else msg["role"]
-            prompt += f"<start_of_turn>{role}\n{msg['content'].strip()}<end_of_turn>\n"
-        prompt += "<start_of_turn>model\n"
-
-    print(f"[DEBUG] Prompt:\n{prompt}")
+    # Apply chat template
+    prompt = tokenizer.apply_chat_template(
+        messages,
+        add_generation_prompt=True,
+        tokenize=False
+    )

-    inputs = processor(prompt, return_tensors="pt").to(device)
+    inputs = tokenizer(prompt, return_tensors="pt").to(device)

     outputs = model.generate(
         **inputs,
-        max_new_tokens=512,
-        do_sample=False,
-        num_beams=1,
-        eos_token_id=processor.tokenizer.eos_token_id,
-        pad_token_id=processor.tokenizer.pad_token_id
+        max_new_tokens=1024,
+        temperature=1.0,
+        top_p=0.95,
+        top_k=64,
+        do_sample=True,
+        pad_token_id=tokenizer.pad_token_id,
+        eos_token_id=tokenizer.eos_token_id
     )

-    response = processor.decode(outputs[0], skip_special_tokens=True)
-    # Extract only the assistant response (after last <start_of_turn>model)
-    if "<start_of_turn>model" in response:
-        response = response.split("<start_of_turn>model")[-1].strip()
+    # Decode and extract just the last assistant message
+    decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
+    if "<start_of_turn>assistant" in decoded:
+        response = decoded.split("<start_of_turn>assistant")[-1].strip()
+    else:
+        response = decoded

     return response

-# Launch the Gradio interface
-iface = gr.ChatInterface(fn=chat, title="Gemma-3 Chat").launch(share=True)
+gr.ChatInterface(fn=chat, title="Chat with Gemma-3").launch(share=True)
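
For reference, the updated chat() function can be exercised without launching the Gradio UI. The snippet below is only an illustrative sketch, not part of the commit: it assumes it would sit in app.py above the gr.ChatInterface(...).launch(...) call, and that history is passed as (user, assistant) pairs, the tuple format the loop inside chat() unpacks.

# Hypothetical smoke test (not part of the commit).
if __name__ == "__main__":
    # First turn: no prior history.
    print(chat("Hello! What can you do?", []))

    # Second turn: pass the previous exchange as a (user, assistant) tuple,
    # matching what `for user_msg, bot_msg in history` expects.
    history = [("Hello! What can you do?", "I'm a Gemma-3 based assistant.")]
    print(chat("Summarize that in one sentence.", history))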