saikrishnagorijala committed
Commit ff4d990 · verified · 1 Parent(s): b5db786

Update app.py

Files changed (1): app.py (+15 −20)
app.py CHANGED
@@ -1,4 +1,4 @@
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
 import gradio as gr
 import torch
 
@@ -7,36 +7,31 @@ model_id = "saikrishnagorijala/friday-V1"
 # Load tokenizer
 tokenizer = AutoTokenizer.from_pretrained(model_id)
 
-# Load model in 8-bit mode (requires bitsandbytes)
+# Define quantization config for 8-bit inference
+bnb_config = BitsAndBytesConfig(
+    load_in_8bit=True,
+    bnb_8bit_use_double_quant=True,
+    bnb_8bit_quant_type="nf4",
+    bnb_8bit_compute_dtype=torch.float16
+)
+
+# Load model with quantization_config
 model = AutoModelForCausalLM.from_pretrained(
     model_id,
-    device_map="auto",  # automatically put layers on GPU
-    load_in_8bit=True,  # enable 8-bit quantization
-    torch_dtype=torch.float16  # keep computations in FP16 where needed
+    device_map="auto",
+    quantization_config=bnb_config
 )
 
 def chat(prompt):
-    # Tokenize input and move to model device
     inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
-
-    # Generate response
     outputs = model.generate(
         **inputs,
         max_new_tokens=200,
-        do_sample=True,  # allow sampling for varied responses
-        temperature=1.2,  # optional creativity control
+        do_sample=True,
+        temperature=1.2,
         top_p=0.9
     )
-
-    # Decode and return text
     return tokenizer.decode(outputs[0], skip_special_tokens=True)
 
-# Gradio interface
-demo = gr.Interface(
-    fn=chat,
-    inputs="text",
-    outputs="text",
-    title="Friday-V1 Chatbot"
-)
-
+demo = gr.Interface(fn=chat, inputs="text", outputs="text", title="Friday-V1 Chatbot")
 demo.launch()
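
A note on the new quantization config: bnb_8bit_use_double_quant, bnb_8bit_quant_type, and bnb_8bit_compute_dtype are not BitsAndBytesConfig parameters. Double quantization, quant types such as "nf4", and the compute dtype exist only for 4-bit loading, under the bnb_4bit_ prefix, so transformers ignores these three kwargs (recent versions log an "Unused kwargs" warning) and the model loads with plain 8-bit settings. A minimal sketch of the two configs the API does accept, assuming the intent was either straight 8-bit loading or nf4 4-bit loading:

import torch
from transformers import BitsAndBytesConfig

# Plain 8-bit loading: load_in_8bit is the only 8-bit switch needed here.
bnb_8bit = BitsAndBytesConfig(load_in_8bit=True)

# If nf4 with double quantization was the intent, that is the 4-bit API
# (note the bnb_4bit_ prefix on every option):
bnb_4bit = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,        # quantize the quantization constants too
    bnb_4bit_quant_type="nf4",             # normal-float 4-bit, the QLoRA default
    bnb_4bit_compute_dtype=torch.float16,  # run matmuls in fp16
)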
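
Separately, tokenizer.decode(outputs[0], ...) returns the prompt followed by the completion, because generate() keeps the input ids at the front of the returned sequence, so the Gradio output box echoes whatever the user typed. A small sketch of the usual fix, slicing the prompt tokens off before decoding (same names as the app; only the slicing is new):

def chat(prompt):
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=200,
        do_sample=True,
        temperature=1.2,
        top_p=0.9
    )
    # generate() returns prompt tokens + new tokens; keep only the new ones
    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)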
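
Finally, if friday-V1 is chat-tuned and its tokenizer ships a chat template (an assumption; the model card would confirm), passing the raw string to the tokenizer skips the role formatting the model was trained on. A hedged variant of chat using the standard apply_chat_template path:

def chat(prompt):
    messages = [{"role": "user", "content": prompt}]
    # Assumes the tokenizer defines a chat template; this call raises if it does not.
    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt"
    ).to(model.device)
    outputs = model.generate(input_ids, max_new_tokens=200, do_sample=True, top_p=0.9)
    return tokenizer.decode(outputs[0][input_ids.shape[1]:], skip_special_tokens=True)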