Spaces:
Sleeping
Sleeping
Samanta Das
committed on
Update llama.py
Browse files
llama.py
CHANGED
@@ -8,18 +8,10 @@ import logging
|
|
8 |
# Set up logging
|
9 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
10 |
|
11 |
-
# Load environment variables
|
12 |
load_dotenv()
|
13 |
-
API_KEYS = [
|
14 |
-
os.getenv("GORQ_API_KEY_1"),
|
15 |
-
os.getenv("GORQ_API_KEY_2"),
|
16 |
-
os.getenv("GORQ_API_KEY_3"),
|
17 |
-
os.getenv("GORQ_API_KEY_4")
|
18 |
-
]
|
19 |
|
20 |
class GorQClient:
|
21 |
-
"""Client for interacting with the Groq API to call the LLaMA model."""
|
22 |
-
|
23 |
def __init__(self):
|
24 |
self.current_api_index = 0
|
25 |
self.client = Groq(api_key=self.get_api_key())
|
@@ -27,17 +19,21 @@ class GorQClient:
|
|
27 |
|
28 |
def switch_api_key(self):
|
29 |
"""Switch to the next available API key."""
|
30 |
-
self.current_api_index = (self.current_api_index + 1) %
|
31 |
logging.info("Switched to API key index: %d", self.current_api_index)
|
32 |
self.client = Groq(api_key=self.get_api_key())
|
33 |
|
34 |
def get_api_key(self):
|
35 |
-
"""Get the current API key."""
|
36 |
-
return
|
|
|
|
|
|
|
|
|
37 |
|
38 |
def call_llama_model(self, user_message):
|
39 |
"""Call the LLaMA model and return the response."""
|
40 |
-
max_retries =
|
41 |
retries = 0
|
42 |
|
43 |
while retries < max_retries:
|
@@ -71,12 +67,14 @@ class GorQClient:
|
|
71 |
for chunk in completion:
|
72 |
response += chunk.choices[0].delta.content or ""
|
73 |
|
|
|
74 |
response = self.clean_response(response)
|
|
|
75 |
logging.info("Response received from LLaMA model.")
|
76 |
return response
|
77 |
|
78 |
except Exception as e:
|
79 |
-
logging.error(f"API error
|
80 |
self.switch_api_key()
|
81 |
retries += 1
|
82 |
logging.info("Retrying with new API key.")
|
@@ -85,10 +83,8 @@ class GorQClient:
|
|
85 |
|
86 |
def clean_response(self, response):
|
87 |
"""Remove any unwanted characters or formatting artifacts."""
|
88 |
-
|
89 |
-
|
90 |
-
response = response.replace(char, "")
|
91 |
-
return response.strip()
|
92 |
|
93 |
def format_confidence_level(confidence):
|
94 |
"""Convert numerical confidence to descriptive text."""
|
@@ -106,6 +102,7 @@ def generate_response_based_on_yolo(yolo_output):
|
|
106 |
if not yolo_output:
|
107 |
return "Fracture Analysis Report: No fractures detected in the provided image. However, if you're experiencing pain or discomfort, please consult a healthcare professional for a thorough evaluation."
|
108 |
|
|
|
109 |
current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
110 |
details = []
|
111 |
|
@@ -130,6 +127,7 @@ def generate_response_based_on_yolo(yolo_output):
|
|
130 |
}
|
131 |
})
|
132 |
|
|
|
133 |
user_message = json.dumps({
|
134 |
"timestamp": current_time,
|
135 |
"analysis_type": "Fracture Detection and Analysis",
|
@@ -137,10 +135,12 @@ def generate_response_based_on_yolo(yolo_output):
|
|
137 |
"request": "Provide a comprehensive analysis including immediate care recommendations, pain management strategies, and follow-up care instructions. Include specific details about each detected fracture and potential complications to watch for."
|
138 |
}, indent=2)
|
139 |
|
|
|
140 |
response = gorq_client.call_llama_model(user_message)
|
141 |
|
142 |
return response
|
143 |
|
|
|
144 |
if __name__ == "__main__":
|
145 |
# Example YOLO output for testing
|
146 |
example_yolo_output = [
|
@@ -152,5 +152,6 @@ if __name__ == "__main__":
|
|
152 |
}
|
153 |
]
|
154 |
|
|
|
155 |
response = generate_response_based_on_yolo(example_yolo_output)
|
156 |
print(response)
|
|
|
8 |
# Set up logging: INFO level, timestamped "time - level - message" format.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Load environment variables from a local .env file into os.environ
# (only needed if using a local .env file; harmless no-op otherwise).
load_dotenv()
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
|
14 |
class GorQClient:
|
|
|
|
|
15 |
def __init__(self):
    """Initialize the client, starting at key index 0 (GORQ_API_KEY_1)."""
    # Round-robin position into the GORQ_API_KEY_<n> environment variables.
    self.current_api_index = 0
    # Groq SDK client bound to the currently selected API key.
    self.client = Groq(api_key=self.get_api_key())
|
|
|
19 |
|
20 |
def switch_api_key(self):
    """Rotate to the next API key (round-robin) and rebuild the Groq client."""
    # Advance the index, wrapping around once every configured key was used.
    next_index = (self.current_api_index + 1) % self.number_of_api_keys()
    self.current_api_index = next_index
    logging.info("Switched to API key index: %d", self.current_api_index)
    # Re-create the SDK client so subsequent calls use the new key.
    self.client = Groq(api_key=self.get_api_key())
|
25 |
|
26 |
def get_api_key(self):
|
27 |
+
"""Get the current API key from environment variables."""
|
28 |
+
return os.getenv(f"GORQ_API_KEY_{self.current_api_index + 1}")
|
29 |
+
|
30 |
+
def number_of_api_keys(self):
    """Return how many GORQ_API_KEY_<n> slots are configured."""
    # Hard-coded to match GORQ_API_KEY_1..4 — change this if you add more keys.
    configured_keys = 4
    return configured_keys
|
33 |
|
34 |
def call_llama_model(self, user_message):
|
35 |
"""Call the LLaMA model and return the response."""
|
36 |
+
max_retries = self.number_of_api_keys()
|
37 |
retries = 0
|
38 |
|
39 |
while retries < max_retries:
|
|
|
67 |
for chunk in completion:
|
68 |
response += chunk.choices[0].delta.content or ""
|
69 |
|
70 |
+
# Clean up the response to remove formatting artifacts
|
71 |
response = self.clean_response(response)
|
72 |
+
|
73 |
logging.info("Response received from LLaMA model.")
|
74 |
return response
|
75 |
|
76 |
except Exception as e:
|
77 |
+
logging.error(f"API error: {e}")
|
78 |
self.switch_api_key()
|
79 |
retries += 1
|
80 |
logging.info("Retrying with new API key.")
|
|
|
83 |
|
84 |
def clean_response(self, response):
    """Remove any unwanted characters or formatting artifacts.

    Strips markdown-style markers (*, •, #, `) anywhere in the text in a
    single pass, then trims surrounding whitespace.
    """
    markers = str.maketrans("", "", "*•#`")
    return response.translate(markers).strip()
|
|
|
|
|
88 |
|
89 |
def format_confidence_level(confidence):
|
90 |
"""Convert numerical confidence to descriptive text."""
|
|
|
102 |
if not yolo_output:
|
103 |
return "Fracture Analysis Report: No fractures detected in the provided image. However, if you're experiencing pain or discomfort, please consult a healthcare professional for a thorough evaluation."
|
104 |
|
105 |
+
# Process YOLO output to create detailed analysis
|
106 |
current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
107 |
details = []
|
108 |
|
|
|
127 |
}
|
128 |
})
|
129 |
|
130 |
+
# Construct detailed message for LLM
|
131 |
user_message = json.dumps({
|
132 |
"timestamp": current_time,
|
133 |
"analysis_type": "Fracture Detection and Analysis",
|
|
|
135 |
"request": "Provide a comprehensive analysis including immediate care recommendations, pain management strategies, and follow-up care instructions. Include specific details about each detected fracture and potential complications to watch for."
|
136 |
}, indent=2)
|
137 |
|
138 |
+
# Get detailed response from LLM
|
139 |
response = gorq_client.call_llama_model(user_message)
|
140 |
|
141 |
return response
|
142 |
|
143 |
+
# Test the functionality (This part should be executed in a separate test file or interactive environment)
|
144 |
if __name__ == "__main__":
|
145 |
# Example YOLO output for testing
|
146 |
example_yolo_output = [
|
|
|
152 |
}
|
153 |
]
|
154 |
|
155 |
+
# Generate a response based on the example YOLO output
|
156 |
response = generate_response_based_on_yolo(example_yolo_output)
|
157 |
print(response)
|