shri171981 commited on
Commit
2195ca0
·
verified ·
1 Parent(s): 74e3539

Using InferenceClient API

Browse files
Files changed (1) hide show
  1. app.py +35 -59
app.py CHANGED
@@ -1,39 +1,19 @@
1
  import gradio as gr
2
- import torch
3
- from transformers import AutoModelForCausalLM, AutoTokenizer
4
- from peft import PeftModel
5
 
6
- # 1. Define your model ID
7
- # REPLACE THIS with your actual username/repo name
8
- ADAPTER_ID = "shri171981/medical_chat_generative"
9
 
10
- def load_model():
11
- # Load Base Model (Llama-3-8B)
12
- # We use "cpu" and float32 if you are on the Free Tier (Slow but works)
13
- # If you have a GPU in your Space, change device_map to "auto"
14
- base_model_name = "unsloth/llama-3-8b-instruct-bnb-4bit"
15
-
16
- print("Loading base model...")
17
- base_model = AutoModelForCausalLM.from_pretrained(
18
- base_model_name,
19
- device_map="cpu", # Change to "auto" if you have a GPU Space
20
- torch_dtype=torch.float32,
21
- low_cpu_mem_usage=True
22
- )
23
-
24
- print("Loading adapter...")
25
- model = PeftModel.from_pretrained(base_model, ADAPTER_ID)
26
- tokenizer = AutoTokenizer.from_pretrained(base_model_name)
27
- return model, tokenizer
28
 
29
- # Load the model once at startup
30
- model, tokenizer = load_model()
31
-
32
- def ask_doctor(message, history):
33
- # 1. Format the input for Llama-3
34
- # We strictly enforce the "HACK_DOC" format
35
  system_prompt = "You are a helpful and empathetic medical doctor. Answer the patient's question based on the input provided."
36
- full_prompt = f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
37
 
38
  ### Instruction:
39
  {system_prompt}
@@ -43,36 +23,32 @@ def ask_doctor(message, history):
43
 
44
  ### Response:
45
  """
46
-
47
- # 2. Tokenize and Generate
48
- inputs = tokenizer(full_prompt, return_tensors="pt")
49
-
50
- # Generate response
51
- with torch.no_grad():
52
- outputs = model.generate(
53
- **inputs,
54
  max_new_tokens=128,
55
- temperature=0.7
 
56
  )
57
-
58
- # 3. Decode output
59
- response = tokenizer.decode(outputs[0], skip_special_tokens=True)
60
-
61
- # 4. Clean up the text (Remove the prompt part)
62
- # We split by "Response:" and take the last part
63
- clean_answer = response.split("Response:")[-1].strip()
64
-
65
- return clean_answer
66
-
67
- # 3. Build the UI
68
- interface = gr.ChatInterface(
69
- fn=ask_doctor,
70
- title="🚑 HACK_DOC AI",
71
- description="I am a specialized medical assistant. Ask me about symptoms!",
72
- examples=["I have a sharp pain in my chest.", "What should I take for a fever?", "My skin is itchy and red."],
73
- # theme="soft"
74
  )
75
 
76
- # 4. Launch
77
  if __name__ == "__main__":
78
- interface.launch()
 
1
  import gradio as gr
2
+ from huggingface_hub import InferenceClient
3
+ import os
 
4
 
5
+ # 1. Setup the Client
6
+ # We fetch the token you just added to Secrets
7
+ client = InferenceClient(token=os.getenv("HF_TOKEN"))
8
 
9
+ # 2. Your Model ID (The Adapter)
10
+ # The API is smart enough to see it's an adapter and load the Base Model automatically.
11
+ MODEL_ID = "shri171981/genai_hack_doc"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
+ def ask_api(message, history):
14
+ # 3. Format the prompt (Strict Llama-3 format)
 
 
 
 
15
  system_prompt = "You are a helpful and empathetic medical doctor. Answer the patient's question based on the input provided."
16
+ prompt = f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
17
 
18
  ### Instruction:
19
  {system_prompt}
 
23
 
24
  ### Response:
25
  """
26
+
27
+ try:
28
+ # 4. Send to the API
29
+ response = client.text_generation(
30
+ prompt,
31
+ model=MODEL_ID,
 
 
32
  max_new_tokens=128,
33
+ temperature=0.7,
34
+ return_full_text=False # We only want the new part
35
  )
36
+ return response
37
+
38
+ except Exception as e:
39
+ # 5. Handle "Model Loading" errors
40
+ # If the model is cold, the API returns a 503 error.
41
+ if "Model is loading" in str(e):
42
+ return "⚠️ The model is waking up (Cold Start). Please wait 30 seconds and try again!"
43
+ return f"Error: {str(e)}"
44
+
45
+ # 6. Launch
46
+ demo = gr.ChatInterface(
47
+ fn=ask_api,
48
+ title="🚑 HACK_DOC (API Powered)",
49
+ description="Running on Hugging Face Serverless GPU via API.",
50
+ examples=["I have a sharp pain in my chest.", "What is good for a fever?"],
 
 
51
  )
52
 
 
53
  if __name__ == "__main__":
54
+ demo.launch()