Pclanglais committed
Commit 208476f
1 Parent(s): 459a15e

Update app.py

Files changed (1)
  1. app.py +8 -6
app.py CHANGED

@@ -14,6 +14,8 @@ from chromadb.utils import embedding_functions
 from FlagEmbedding import BGEM3FlagModel
 from sklearn.metrics.pairwise import cosine_similarity
 
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
 model = BGEM3FlagModel('BAAI/bge-m3',
                        use_fp16=True) # Setting use_fp16 to True speeds up computation with a slight performance degradation
 
@@ -22,16 +24,16 @@ embeddings_data = pd.read_json("embeddings_tchap.json")
 embeddings_text = embeddings_data["text_with_context"].tolist()
 
 # Define the device
-#device = "cuda" if torch.cuda.is_available() else "cpu"
-#Define variables
 temperature=0.2
 max_new_tokens=1000
 top_p=0.92
 repetition_penalty=1.7
 
-#model_name = "Pclanglais/Tchap"
+model_name = "Pclanglais/Tchap"
 
-#llm = LLM(model_name, max_model_len=4096)
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
+model = model.to('cuda:0')
 
 system_prompt = "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nTu es Albert, l'agent conversationnel des services publics qui peut décrire des documents de référence ou aider à des tâches de rédaction<|eot_id|>"
 
@@ -78,7 +80,7 @@ def predict(message, history):
 
     messages = system_prompt + messages
 
-    """"model_inputs = tokenizer([messages], return_tensors="pt").to("cuda")
+    model_inputs = tokenizer([messages], return_tensors="pt").to("cuda")
     streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
     generate_kwargs = dict(
         model_inputs,
@@ -98,7 +100,7 @@ def predict(message, history):
     for new_token in streamer:
         if new_token != '<':
             partial_message += new_token
-            yield partial_message"""
+            yield partial_message
     return messages
 
 # Define the Gradio interface
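For context on the first two hunks: app.py loads BAAI/bge-m3 through FlagEmbedding and imports sklearn's cosine_similarity, the usual pairing for dense retrieval over the embeddings_tchap.json corpus. A minimal sketch of that retrieval step, assuming the JSON file also carries precomputed vectors; the "embeddings" column name and the retrieve helper are hypothetical and not shown in this commit:

import numpy as np
import pandas as pd
from FlagEmbedding import BGEM3FlagModel
from sklearn.metrics.pairwise import cosine_similarity

# Same embedder as in app.py; fp16 trades a little accuracy for speed.
embedder = BGEM3FlagModel('BAAI/bge-m3', use_fp16=True)

embeddings_data = pd.read_json("embeddings_tchap.json")
embeddings_text = embeddings_data["text_with_context"].tolist()
# Hypothetical: assumes the corpus file stores precomputed bge-m3 dense
# vectors in an "embeddings" column alongside text_with_context.
corpus_vectors = np.vstack(embeddings_data["embeddings"].tolist())

def retrieve(query, k=5):
    # encode() returns a dict; 'dense_vecs' holds the (1, 1024) query vector.
    query_vector = embedder.encode([query])['dense_vecs']
    scores = cosine_similarity(query_vector, corpus_vectors)[0]
    top = np.argsort(scores)[::-1][:k]
    return [embeddings_text[i] for i in top]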
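The last two hunks re-enable the streaming path inside predict() but the hunk boundaries cut off before the model.generate call that feeds the TextIteratorStreamer. A minimal sketch of the standard pattern those lines rely on, using the sampling values defined at the top of app.py; the stream_reply name, do_sample=True, and the worker-thread wiring are assumptions, not part of this commit:

from threading import Thread

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_name = "Pclanglais/Tchap"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16).to("cuda:0")

def stream_reply(messages):  # hypothetical helper; app.py does this inside predict()
    model_inputs = tokenizer([messages], return_tensors="pt").to("cuda")
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True,
                                    skip_special_tokens=True)
    # Same sampling parameters as the module-level variables in app.py.
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=1000,
        do_sample=True,
        temperature=0.2,
        top_p=0.92,
        repetition_penalty=1.7,
    )
    # generate() blocks until decoding finishes, so it runs on a worker
    # thread while this generator yields partial text as tokens arrive.
    Thread(target=model.generate, kwargs=generate_kwargs).start()
    partial_message = ""
    for new_token in streamer:
        if new_token != '<':  # mirrors app.py's filter on stray '<' tokens
            partial_message += new_token
            yield partial_message

TextIteratorStreamer only yields once generate() is running, so generate() has to live on a separate thread; iterating the streamer on the main thread then drives Gradio's incremental chat updates.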