neuralleap committed on
Commit b188633
1 Parent(s): a1dfc9d

Update app.py

Files changed (1)
  1. app.py +35 -45
app.py CHANGED
@@ -55,7 +55,7 @@ h1 {
  model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2",use_auth_token=access_token)
  model = PeftModel.from_pretrained(model, "physician-ai/mistral-finetuned1",use_auth_token=access_token)
  tokenizer = AutoTokenizer.from_pretrained("physician-ai/mistral-finetuned1",use_auth_token=access_token)
- text_pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=512, temperature=0.8, top_p=0.95, repetition_penalty=1.15)
+ text_pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=1024, temperature=0.8, top_p=0.95, repetition_penalty=1.15)

  terminators = [
      tokenizer.eos_token_id,
@@ -63,51 +63,41 @@ terminators = [
  ]


- @spaces.GPU(duration=120)
- def chat_llama3_8b(message: str,
-                    history: list,
-                    temperature: float,
-                    max_new_tokens: int
-                    ) -> str:
-     """
-     Generate a streaming response using the llama3-8b model.
-     Args:
-         message (str): The input message.
-         history (list): The conversation history used by ChatInterface.
-         temperature (float): The temperature for generating the response.
-         max_new_tokens (int): The maximum number of new tokens to generate.
-     Returns:
-         str: The generated response.
-     """
-     conversation = []
-     for user, assistant in history:
-         conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
-     conversation.append({"role": "user", "content": message})
-
-     input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt").to(model.device)
-
-     streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
-
-     generate_kwargs = dict(
-         input_ids= input_ids,
-         streamer=streamer,
-         max_new_tokens=max_new_tokens,
-         do_sample=True,
-         temperature=temperature,
-         eos_token_id=terminators,
-     )
-     # This will enforce greedy generation (do_sample=False) when the temperature is passed 0, avoiding the crash.
-     if temperature == 0:
-         generate_kwargs['do_sample'] = False

-     t = Thread(target=model.generate, kwargs=generate_kwargs)
-     t.start()
-
-     outputs = []
-     for text in streamer:
-         outputs.append(text)
-         #print(outputs)
-         yield "".join(outputs)
 
+ def generate_response(input_ids, generate_kwargs):
+     try:
+         # Generate the output using the model
+         output = model.generate(**generate_kwargs)
+         return output
+     except Exception as e:
+         print(f"Error during generation: {e}")
+

+ @gr.cache(max_size=100, expire=3600)
+ def chat_llama3_8b(message, history, temperature=0.95, max_new_tokens=512):
+     # Prepare conversation context
+     conversation = [{"role": "user", "content": message}] + [{"role": "assistant", "content": reply} for reply in history]
+     input_ids = tokenizer(conversation, return_tensors="pt", padding=True, truncation=True).input_ids.to(model.device)
+
+     generate_kwargs = {
+         "input_ids": input_ids,
+         "max_length": input_ids.shape[1] + max_new_tokens,
+         "temperature": temperature,
+         "num_return_sequences": 1
+     }
+
+     # Thread for generating model response
+     output_queue = []
+     response_thread = Thread(target=generate_response, args=(input_ids, generate_kwargs, output_queue))
+     response_thread.start()
+     response_thread.join()  # Wait for the thread to complete
+
+     # Retrieve the output from the queue
+     if output_queue:
+         output = output_queue[0]
+         if output is not None:
+             generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
+             return generated_text
+     return "An error occurred during text generation."


  # Gradio block
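
As committed, the added chat_llama3_8b starts its worker thread with args=(input_ids, generate_kwargs, output_queue), but generate_response accepts only two parameters and returns its result rather than putting it on the queue, so output_queue is never filled and the function always falls through to the error string; the conversation is also passed straight to tokenizer(...) instead of through tokenizer.apply_chat_template, which the removed version used. Below is a minimal sketch of one way to reconcile the two pieces. It assumes model, tokenizer, and terminators are already defined earlier in app.py as shown above, drops the @gr.cache decorator, and swaps the bare list for queue.Queue; it is an illustrative cleanup, not the code in commit b188633.

from threading import Thread
from queue import Queue


def generate_response(generate_kwargs, output_queue):
    """Run model.generate in a worker thread and hand the result back on a queue."""
    try:
        output = model.generate(**generate_kwargs)
        output_queue.put(output)
    except Exception as e:
        print(f"Error during generation: {e}")
        output_queue.put(None)


def chat_llama3_8b(message, history, temperature=0.95, max_new_tokens=512):
    # Rebuild the conversation as (user, assistant) turns, the format
    # gr.ChatInterface passes in and that the removed version iterated over.
    conversation = []
    for user, assistant in history:
        conversation.extend([
            {"role": "user", "content": user},
            {"role": "assistant", "content": assistant},
        ])
    conversation.append({"role": "user", "content": message})

    # Render the roles through the model's chat template, as the removed code did.
    input_ids = tokenizer.apply_chat_template(
        conversation, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)

    generate_kwargs = dict(
        input_ids=input_ids,
        max_new_tokens=max_new_tokens,
        do_sample=temperature > 0,
        temperature=temperature,
        eos_token_id=terminators,
    )

    # The worker takes the queue explicitly, so the argument list matches the
    # signature and the result is not lost when the thread exits.
    output_queue = Queue()
    worker = Thread(target=generate_response, args=(generate_kwargs, output_queue))
    worker.start()
    worker.join()

    output = output_queue.get()
    if output is None:
        return "An error occurred during text generation."

    # Decode only the newly generated tokens, not the echoed prompt.
    new_tokens = output[0][input_ids.shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)

Because the thread is started and then immediately joined, the handoff adds no concurrency; calling model.generate directly would behave the same from the caller's side, and streaming partial output would require the TextIteratorStreamer pattern the removed version used.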