Karzan commited on
Commit
e4ccddf
·
verified ·
1 Parent(s): c9e0f6d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -4
app.py CHANGED
@@ -1,14 +1,21 @@
1
- from transformers import pipeline,GemmaForCausalLM,AutoTokenizer
2
  import gradio as gr
3
  import spaces
 
4
  # ignore_mismatched_sizes=True
 
5
  tokenizer = AutoTokenizer.from_pretrained('google/gemma-2-9b')
6
- model = GemmaForCausalLM.from_pretrained('google/gemma-2-9b',ignore_mismatched_sizes=True)
7
- pipe = pipeline('text-generation', model=model,tokenizer = tokenizer)
 
 
8
 
9
  @spaces.GPU(duration=120)
10
  def generate(prompt):
11
- return pipe(prompt)[0]['generated_text']
 
 
 
12
 
13
  gr.Interface(
14
  fn=generate,
 
1
+ from transformers import pipeline,GemmaForCausalLM,AutoTokenizer,BitsAndBytesConfig
2
  import gradio as gr
3
  import spaces
4
+ import torch
5
  # ignore_mismatched_sizes=True
6
+ quantization_config = BitsAndBytesConfig(load_in_4bit=True)
7
  tokenizer = AutoTokenizer.from_pretrained('google/gemma-2-9b')
8
+ model = GemmaForCausalLM.from_pretrained('google/gemma-2-9b',
9
+ quantization_config=quantization_config
10
+ )
11
+ # pipe = pipeline('text-generation', model=model,tokenizer = tokenizer)
12
 
13
  @spaces.GPU(duration=120)
14
  def generate(prompt):
15
+ input_ids = tokenizer(prompt, return_tensors="pt").to("cuda")
16
+ outputs = model.generate(**input_ids)
17
+ return tokenizer.decode(outputs[0]);
18
+ # return pipe(prompt)[0]['generated_text']
19
 
20
  gr.Interface(
21
  fn=generate,