Tonic commited on
Commit
c30f436
1 Parent(s): 8867e8a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -3
app.py CHANGED
@@ -1,10 +1,12 @@
1
  import gradio as gr
2
  import torch
3
- from transformers import AutoTokenizer, AutoModel, AutoConfig
4
 
5
  # Use the base model's ID
6
  base_model_id = "mistralai/Mistral-7B-v0.1"
7
- config = AutoConfig.from_pretrained(base_model_id)
 
 
8
 
9
  # Load the fine-tuned model "Tonic/mistralmed"
10
  model = AutoModel.from_pretrained("Tonic/mistralmed", config=config)
@@ -21,7 +23,7 @@ class ChatBot:
21
  new_user_input_ids = tokenizer.encode(input + tokenizer.eos_token, return_tensors="pt")
22
  flat_history = [item for sublist in self.history for item in sublist]
23
  flat_history_tensor = torch.tensor(flat_history).unsqueeze(dim=0)
24
- bot_input_ids = torch.cat([flat_history_tensor, new_user_input_ids], dim=-1) if self.history else new_user_input_ids
25
  chat_history_ids = model.generate(bot_input_ids, max_length=2000, pad_token_id=tokenizer.eos_token_id)
26
  self.history.append(chat_history_ids[:, bot_input_ids.shape[-1]:].tolist()[0])
27
  response = tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)
 
1
  import gradio as gr
2
  import torch
3
+ from transformers import AutoTokenizer, AutoModel, BertConfig # Use BertConfig for your Mistral model
4
 
5
  # Use the base model's ID
6
  base_model_id = "mistralai/Mistral-7B-v0.1"
7
+
8
+ # Create a configuration object specific to the base model (you can replace with your model's actual configuration if available)
9
+ config = BertConfig()
10
 
11
  # Load the fine-tuned model "Tonic/mistralmed"
12
  model = AutoModel.from_pretrained("Tonic/mistralmed", config=config)
 
23
  new_user_input_ids = tokenizer.encode(input + tokenizer.eos_token, return_tensors="pt")
24
  flat_history = [item for sublist in self.history for item in sublist]
25
  flat_history_tensor = torch.tensor(flat_history).unsqueeze(dim=0)
26
+ bot_input_ids = torch.cat([flat_history_tensor, new_user_input_ids], dim=-1) if self history else new_user_input_ids
27
  chat_history_ids = model.generate(bot_input_ids, max_length=2000, pad_token_id=tokenizer.eos_token_id)
28
  self.history.append(chat_history_ids[:, bot_input_ids.shape[-1]:].tolist()[0])
29
  response = tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)