skkjodhpur committed
Commit b519b92 · verified · 1 Parent(s): a270145
Files changed (1)
  1. app.py +15 -10
app.py CHANGED
@@ -2,29 +2,33 @@ import gradio as gr
 from transformers import AutoModelForCausalLM, AutoTokenizer
 import torch
 
-# Load model and tokenizer
-model_name = "skkjodhpur/Gemma-Code-Instruct-Finetune-by-skk"
+# Load a smaller model and tokenizer
+model_name = "skkjodhpur/Gemma-Code-Instruct-Finetune-by-skk"  # Consider a smaller model if available
 tokenizer = AutoTokenizer.from_pretrained(model_name)
-model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
+model = AutoModelForCausalLM.from_pretrained(model_name)
 
-# Move model to GPU if available
-device = "cuda" if torch.cuda.is_available() else "cpu"
+# Move model to CPU
+device = "cpu"
 model = model.to(device)
 
 def generate_text(prompt):
     if not prompt.strip():
         return "Please enter a valid question."
-
+
     try:
+        # Tokenize input
         input_ids = tokenizer.encode(f"<s>[INST] {prompt} [/INST]", return_tensors="pt").to(device)
+
+        # Generate text with greedy search for faster response
         with torch.no_grad():
             output = model.generate(
                 input_ids,
-                max_length=200,
+                max_length=100,  # Reduced max length for faster generation
                 num_return_sequences=1,
-                do_sample=True,
-                temperature=0.7,
+                do_sample=False,  # Use greedy search
             )
+
+        # Decode and return the generated text
         generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
         return generated_text
     except Exception as e:
@@ -40,7 +44,8 @@ iface = gr.Interface(
     outputs="text",
     title="Doctors-Patient Chatbot",
     subtitle="Fine-Tuning GEMMA-2B for Doctor-Patient Interaction",
-    description="Ask me any question related to patient concerns. This model is designed for educational and informational purposes only. Please do not use it for medical diagnosis or treatment. Always consult a qualified healthcare provider for medical advice."
+    description="Ask me any question related to patient concerns. This model is designed for educational and informational purposes only. Please do not use it for medical diagnosis or treatment. Always consult a qualified healthcare provider for medical advice.",
+    allow_flagging="never",  # Disable flagging if not needed
 )
 
 iface.launch(share=True)
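For reference, the full app.py after this commit would read roughly as below. This is a sketch, not the verbatim file: the diff does not show lines 35–43 of the new file (the body of the except branch and the fn/inputs arguments of gr.Interface), so those parts are hypothetical fillers, marked as such in comments; all other lines come from the new side of the diff.

import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Load a smaller model and tokenizer
model_name = "skkjodhpur/Gemma-Code-Instruct-Finetune-by-skk"  # Consider a smaller model if available
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Move model to CPU
device = "cpu"
model = model.to(device)

def generate_text(prompt):
    if not prompt.strip():
        return "Please enter a valid question."

    try:
        # Tokenize input
        input_ids = tokenizer.encode(f"<s>[INST] {prompt} [/INST]", return_tensors="pt").to(device)

        # Generate text with greedy search for faster response
        with torch.no_grad():
            output = model.generate(
                input_ids,
                max_length=100,  # Reduced max length for faster generation
                num_return_sequences=1,
                do_sample=False,  # Use greedy search
            )

        # Decode and return the generated text
        generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
        return generated_text
    except Exception as e:
        return f"An error occurred: {e}"  # hypothetical; the diff elides this branch's body

iface = gr.Interface(
    fn=generate_text,  # hypothetical wiring; not shown in the diff
    inputs="text",     # hypothetical; not shown in the diff
    outputs="text",
    title="Doctors-Patient Chatbot",
    subtitle="Fine-Tuning GEMMA-2B for Doctor-Patient Interaction",  # kept as in the commit; not a standard gr.Interface argument
    description="Ask me any question related to patient concerns. This model is designed for educational and informational purposes only. Please do not use it for medical diagnosis or treatment. Always consult a qualified healthcare provider for medical advice.",
    allow_flagging="never",  # Disable flagging if not needed
)

iface.launch(share=True)

Switching to do_sample=False trades response variety for determinism and speed on CPU: greedy decoding picks the highest-probability token at each step, so repeated prompts return identical answers.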