Commit b188633 · Update app.py
Parent(s): a1dfc9d

app.py CHANGED
@@ -55,7 +55,7 @@ h1 {
 model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2",use_auth_token=access_token)
 model = PeftModel.from_pretrained(model, "physician-ai/mistral-finetuned1",use_auth_token=access_token)
 tokenizer = AutoTokenizer.from_pretrained("physician-ai/mistral-finetuned1",use_auth_token=access_token)
-text_pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=
+text_pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=1024, temperature=0.8, top_p=0.95, repetition_penalty=1.15)
 
 terminators = [
     tokenizer.eos_token_id,
@@ -63,51 +63,41 @@ terminators = [
 ]
 
 
-
-
-
-
-
-
-
-
-    Args:
-        message (str): The input message.
-        history (list): The conversation history used by ChatInterface.
-        temperature (float): The temperature for generating the response.
-        max_new_tokens (int): The maximum number of new tokens to generate.
-    Returns:
-        str: The generated response.
-    """
-    conversation = []
-    for user, assistant in history:
-        conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
-    conversation.append({"role": "user", "content": message})
-
-    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt").to(model.device)
-
-    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
-
-    generate_kwargs = dict(
-        input_ids= input_ids,
-        streamer=streamer,
-        max_new_tokens=max_new_tokens,
-        do_sample=True,
-        temperature=temperature,
-        eos_token_id=terminators,
-    )
-    # This will enforce greedy generation (do_sample=False) when the temperature is passed 0, avoiding the crash.
-    if temperature == 0:
-        generate_kwargs['do_sample'] = False
+def generate_response(input_ids, generate_kwargs):
+    try:
+        # Generate the output using the model
+        output = model.generate(**generate_kwargs)
+        return output
+    except Exception as e:
+        print(f"Error during generation: {e}")
+
 
-
-
-
-
-
-
-
+@gr.cache(max_size=100, expire=3600)
+def chat_llama3_8b(message, history, temperature=0.95, max_new_tokens=512):
+    # Prepare conversation context
+    conversation = [{"role": "user", "content": message}] + [{"role": "assistant", "content": reply} for reply in history]
+    input_ids = tokenizer(conversation, return_tensors="pt", padding=True, truncation=True).input_ids.to(model.device)
+
+    generate_kwargs = {
+        "input_ids": input_ids,
+        "max_length": input_ids.shape[1] + max_new_tokens,
+        "temperature": temperature,
+        "num_return_sequences": 1
+    }
+
+    # Thread for generating model response
+    output_queue = []
+    response_thread = Thread(target=generate_response, args=(input_ids, generate_kwargs, output_queue))
+    response_thread.start()
+    response_thread.join()  # Wait for the thread to complete
+
+    # Retrieve the output from the queue
+    if output_queue:
+        output = output_queue[0]
+        if output is not None:
+            generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
+            return generated_text
+    return "An error occurred during text generation."
 
 
 # Gradio block
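A few notes on the committed code. The rewritten chat_llama3_8b builds the conversation with the new message first and keeps only assistant replies from history, whereas the removed code interleaved user and assistant turns in chronological order; it also passes the list of role dicts straight to tokenizer(...), which expects plain strings. Below is a minimal sketch of the prompt construction the removed version used, assuming the same tokenizer and model as this Space; add_generation_prompt=True is an addition here, not part of either version.

# Sketch: rebuild the prompt as the removed code did, in chronological order,
# using the chat-template API that accepts {"role", "content"} dicts.
conversation = []
for user, assistant in history:  # history as (user, assistant) pairs
    conversation.extend([
        {"role": "user", "content": user},
        {"role": "assistant", "content": assistant},
    ])
conversation.append({"role": "user", "content": message})

input_ids = tokenizer.apply_chat_template(
    conversation,
    add_generation_prompt=True,  # assumption: start the assistant's next turn
    return_tensors="pt",
).to(model.device)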
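There is also a mismatch in the added threading code: generate_response is defined with two parameters, but the Thread call passes three, and nothing ever puts a result into output_queue, so chat_llama3_8b always falls through to the error string. A minimal sketch of the intended hand-off follows, reusing the model and generate_kwargs from the commit; the queue.Queue in place of a bare list is an assumption, chosen for thread-safe passing.

from queue import Queue
from threading import Thread

def generate_response(generate_kwargs, output_queue):
    # Worker: run generation and hand the result (or None) back via the queue.
    try:
        output_queue.put(model.generate(**generate_kwargs))
    except Exception as e:
        print(f"Error during generation: {e}")
        output_queue.put(None)

output_queue = Queue()
worker = Thread(target=generate_response, args=(generate_kwargs, output_queue))
worker.start()
worker.join()                # wait for generation to finish
output = output_queue.get()  # None signals a failure

Since the thread is joined immediately, the hand-off is effectively synchronous; the removed version's TextIteratorStreamer ran generation in a background thread precisely so tokens could be consumed while generation was still running.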
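Finally, the @gr.cache(max_size=100, expire=3600) decorator does not appear in Gradio's documented API, so this line would likely fail when the module loads. If response caching is the goal, here is a standard-library sketch under stated assumptions: it caches on the message alone, since the mutable history list is not hashable, and functools.lru_cache has no expiry, so the 3600-second TTL is dropped; cached_reply is a hypothetical helper, not part of the commit.

from functools import lru_cache

@lru_cache(maxsize=100)
def cached_reply(message: str) -> str:
    # Hypothetical helper: cache single-turn replies only; history-aware
    # calls should bypass the cache and call chat_llama3_8b directly.
    return chat_llama3_8b(message, history=[])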