Tonic committed on
Commit
138bd73
1 Parent(s): 52e94e3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -1,9 +1,9 @@
 
1
  from transformers import AutoConfig, AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM, MistralForCausalLM
2
  import torch
3
  import gradio as gr
4
  import random
5
  from textwrap import wrap
6
- import spaces
7
 
8
  def wrap_text(text, width=90):
9
  lines = text.split('\n')
@@ -51,7 +51,7 @@ tokenizer = AutoTokenizer.from_pretrained(model_id = model_id, trust_remote_code
51
  # Specify the configuration class for the model
52
  #model_config = AutoConfig.from_pretrained(base_model_id)
53
 
54
- model = MistralForCaumodel = AutoModelForCausalLM.from_pretrained(model_id)
55
 
56
  class ChatBot:
57
  def __init__(self):
@@ -64,7 +64,7 @@ class ChatBot:
64
 
65
  def predict(self, user_input, system_prompt="You are an expert medical analyst:"):
66
  # Combine the user's input with the system prompt
67
- formatted_input = f"<s>[INST]{system_prompt} {user_input}[/INST]"
68
 
69
  # Encode the formatted input using the tokenizer
70
  user_input_ids = tokenizer.encode(formatted_input, return_tensors="pt")
 
1
+ import spaces
2
  from transformers import AutoConfig, AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM, MistralForCausalLM
3
  import torch
4
  import gradio as gr
5
  import random
6
  from textwrap import wrap
 
7
 
8
  def wrap_text(text, width=90):
9
  lines = text.split('\n')
 
51
  # Specify the configuration class for the model
52
  #model_config = AutoConfig.from_pretrained(base_model_id)
53
 
54
+ model = AutoModelForCausalLM.from_pretrained(model_id , torch_dtype=torch.float16 , device_map= "auto" )
55
 
56
  class ChatBot:
57
  def __init__(self):
 
64
 
65
  def predict(self, user_input, system_prompt="You are an expert medical analyst:"):
66
  # Combine the user's input with the system prompt
67
+ formatted_input = f"<s> [INST] {example_instruction} [/INST] {example_answer}</s> [INST] {system_prompt} [/INST]"
68
 
69
  # Encode the formatted input using the tokenizer
70
  user_input_ids = tokenizer.encode(formatted_input, return_tensors="pt")