not-lain committed on
Commit 3a82207
1 parent: 96bfd00

Update app.py

Files changed (1)
  1. app.py +54 -9
app.py CHANGED
@@ -1,16 +1,61 @@
  import gradio as gr
  import os
- from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
  token = os.environ["HF_TOKEN"]
- tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b",token=token)
- model = AutoModelForCausalLM.from_pretrained("google/gemma-2b",token=token)
- streamer = TextStreamer(tokenizer,skip_prompt=True)


- def generate(inputs,history):
-     inputs = tokenizer([inputs], return_tensors="pt")
-     yield model.generate(**inputs, streamer=streamer)


- app = gr.ChatInterface(generate)
- app.launch(debug=True)
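
Why the old version needed replacing: TextStreamer prints decoded tokens to stdout instead of returning them, and model.generate(...) returns a tensor of token ids, so the single yield handed Gradio a tensor rather than text and nothing streamed in the UI. A minimal sketch of a non-streaming fix, assuming the same tokenizer and model objects (illustrative only, not part of the commit):

def generate(inputs, history):
    # Encode the prompt, generate a complete reply, then decode it to text
    model_inputs = tokenizer([inputs], return_tensors="pt")
    output_ids = model.generate(**model_inputs, max_new_tokens=256)
    yield tokenizer.decode(output_ids[0], skip_special_tokens=True)

The commit goes further and streams token by token with TextIteratorStreamer, as shown below.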
  import gradio as gr
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+ import time
+ import numpy as np
+ from torch.nn import functional as F
  import os
+ from threading import Thread
  token = os.environ["HF_TOKEN"]

+ model = AutoModelForCausalLM.from_pretrained("google/gemma-2b-it", torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, token=token)
+ tok = AutoTokenizer.from_pretrained("google/gemma-2b-it", token=token)
+ # Use CUDA when available for an optimal experience
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ model = model.to(device)


+ start_message = ""

+ def user(message, history):
+     # Clear the textbox and append the user's message to the conversation history
+     return "", history + [[message, ""]]
+
+
+ def chat(message, history):
+     chat = []
+     for item in history:
+         chat.append({"role": "user", "content": item[0]})
+         if item[1] is not None:
+             chat.append({"role": "assistant", "content": item[1]})
+     chat.append({"role": "user", "content": message})
+     messages = tok.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
+     # Tokenize the templated prompt string
+     model_inputs = tok([messages], return_tensors="pt").to(device)
+     streamer = TextIteratorStreamer(
+         tok, timeout=10., skip_prompt=True, skip_special_tokens=True)
+     generate_kwargs = dict(
+         model_inputs,
+         streamer=streamer,
+         max_new_tokens=1024,
+         do_sample=True,
+         top_p=0.95,
+         top_k=1000,
+         temperature=0.75,
+         num_beams=1,
+     )
+     # generate() blocks until finished, so run it in a background thread
+     t = Thread(target=model.generate, kwargs=generate_kwargs)
+     t.start()
+
+     # Initialize an empty string to store the generated text
+     partial_text = ""
+     for new_text in streamer:
+         partial_text += new_text
+         # Yield the accumulated text so the chat interface streams the reply
+         yield partial_text
+
+

+ demo = gr.ChatInterface(fn=chat, examples=[["Write me a poem about Machine Learning."]], title="gemma 2b-it")
+ demo.launch()
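
The new chat function follows the standard TextIteratorStreamer recipe: model.generate runs in a worker thread and pushes decoded text into the streamer, while the Gradio callback iterates over the streamer and yields the growing string; gr.ChatInterface treats a generator function as a streaming bot. A self-contained sketch of just that pattern (the tiny sshleifer/tiny-gpt2 checkpoint is an assumption chosen so the sketch runs anywhere; any causal LM works):

from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Hypothetical small checkpoint for illustration; substitute any causal LM
tok = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")
model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")

inputs = tok(["Hello there"], return_tensors="pt")
streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)

# generate() blocks until done, so it runs in a background thread ...
Thread(target=model.generate, kwargs=dict(inputs, streamer=streamer, max_new_tokens=20)).start()

# ... while the main thread consumes decoded chunks as they are produced
partial = ""
for chunk in streamer:
    partial += chunk
    print(partial)

Note that import time, import numpy as np, and from torch.nn import functional as F in the new file are unused and could be dropped in a follow-up commit.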