radames committed
Commit 5cc76b1
1 Parent(s): 7ea9fd1

Update app.py

Files changed (1)
  1. app.py +20 -22
app.py CHANGED
@@ -1,36 +1,34 @@
 import gradio as gr
 
-from transformers import AutoTokenizer, AutoModelForCausalLM
+from transformers import (
+    AutoModelForCausalLM,
+    AutoConfig,
+    AutoTokenizer,
+    BitsAndBytesConfig,
+)
 import transformers
 import torch
 
-model = "tiiuae/falcon-40b"
+model_name = "tiiuae/falcon-40b"
 
-tokenizer = AutoTokenizer.from_pretrained(model)
-pipeline = transformers.pipeline(
-    "text-generation",
-    model=model,
-    tokenizer=tokenizer,
-    torch_dtype=torch.bfloat16,
+config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
+model = AutoModelForCausalLM.from_pretrained(
+    model_name,
+    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
     trust_remote_code=True,
-    load_in_4bit=True,
+    torch_dtype=torch.bfloat16,
     device_map="auto",
 )
+tokenizer = AutoTokenizer.from_pretrained(model_name)
 
-def falcon(input_text):
 
-    sequences = pipeline(
-        input_text,  # "Was ist das höchste Gebäude in der Welt?" ("What is the tallest building in the world?")
-        max_length=200,
-        do_sample=True,
-        top_k=10,
-        num_return_sequences=1,
-        eos_token_id=tokenizer.eos_token_id,
-    )
-    for seq in sequences:
-        print(f"Result: {seq['generated_text']}")
+def falcon(input_text):
+    input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to("cuda")
+    outputs = model.generate(input_ids, max_length=100, do_sample=True, top_k=10)
+    decoded = tokenizer.decode(outputs[0])
+
+    return decoded
 
-    return sequences[0]['generated_text']
 
 iface = gr.Interface(fn=falcon, inputs="text", outputs="text")
-iface.launch()  # To create a public link, set `share=True`
+iface.launch()  # To create a public link, set `share=True`
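
For context, this commit replaces the transformers.pipeline text-generation setup with a direct AutoModelForCausalLM load using a 4-bit BitsAndBytesConfig, and generates via model.generate. The following is a minimal standalone sketch of that same loading-and-generation path, not part of the commit; it assumes a CUDA-capable GPU with the bitsandbytes and accelerate packages installed, and the skip_special_tokens argument is an extra addition for cleaner output:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_name = "tiiuae/falcon-40b"

# Load the checkpoint quantized to 4-bit and let device_map="auto" place it across available devices.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Example prompt (hypothetical); any text works.
prompt = "What is the tallest building in the world?"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
outputs = model.generate(input_ids, max_length=100, do_sample=True, top_k=10)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))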