neuralleap committed on
Commit
b783cda
1 Parent(s): 0c7ce10

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -7
app.py CHANGED
@@ -4,7 +4,13 @@ import spaces
4
  from transformers import GemmaTokenizer, AutoModelForCausalLM
5
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
6
  from threading import Thread
 
 
 
 
 
7
 
 
8
  # Set an environment variable
9
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
10
 
@@ -48,13 +54,12 @@ h1 {
48
  }
49
  """
50
 
51
- # Load the tokenizer and model
52
- tokenizer = AutoTokenizer.from_pretrained("physician-ai/mistral-finetuned1")
53
- model = AutoModelForCausalLM.from_pretrained("physician-ai/mistral-finetuned1", device_map="auto") # to("cuda:0")
54
- terminators = [
55
- tokenizer.eos_token_id,
56
- tokenizer.convert_tokens_to_ids("<|eot_id|>")
57
- ]
58
 
59
  @spaces.GPU(duration=120)
60
  def chat_llama3_8b(message: str,
 
4
  from transformers import GemmaTokenizer, AutoModelForCausalLM
5
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
6
  from threading import Thread
7
+ import gradio as gr
8
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
9
+ import transformers
10
+ import torch
11
+ from peft import PeftModel, PeftConfig
12
 
13
+ access_token = os.getenv('HF_TOKEN')
14
  # Set an environment variable
15
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
16
 
 
54
  }
55
  """
56
 
57
+
58
+ #config = PeftConfig.from_pretrained("physician-ai/mistral-finetuned1",use_auth_token=access_token)
59
+ model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2",use_auth_token=access_token)
60
+ model = PeftModel.from_pretrained(model, "physician-ai/mistral-finetuned1",use_auth_token=access_token)
61
+ tokenizer = AutoTokenizer.from_pretrained("physician-ai/mistral-finetuned1",use_auth_token=access_token)
62
+ text_pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=1024, temperature=0.8, top_p=0.95, repetition_penalty=1.15)
 
63
 
64
  @spaces.GPU(duration=120)
65
  def chat_llama3_8b(message: str,