import os
import re
from functools import lru_cache

import gradio as gr
import torch

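# Point Hugging Face and Gradio caches at /data (assumed to be the app's writable
# persistent storage) before anything is downloaded.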
os.environ.setdefault("HF_HOME", "/data/.cache/huggingface")
os.environ.setdefault("HF_HUB_CACHE", "/data/.cache/huggingface/hub")
os.environ.setdefault("GRADIO_TEMP_DIR", "/data/gradio")
os.environ.setdefault("GRADIO_CACHE_DIR", "/data/gradio")

for p in [
    "/data/.cache/huggingface/hub",
    "/data/gradio",
]:
    try:
        os.makedirs(p, exist_ok=True)
    except Exception:
        pass

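# Optional dependencies: the app keeps working if either import is unavailable.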
try:
    from zoneinfo import ZoneInfo
except Exception:
    ZoneInfo = None

try:
    import cohere
    _HAS_COHERE = True
except Exception:
    _HAS_COHERE = False

from transformers import AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import login, HfApi

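# Configuration: model id, Hugging Face token, and whether to call the hosted
# Cohere API instead of running the model locally.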
MODEL_ID = os.getenv("MODEL_ID", "CohereLabs/c4ai-command-r7b-12-2024")
HF_TOKEN = os.getenv("HUGGINGFACE_HUB_TOKEN") or os.getenv("HF_TOKEN")
COHERE_API_KEY = os.getenv("COHERE_API_KEY")
USE_HOSTED_COHERE = bool(COHERE_API_KEY and _HAS_COHERE)

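# Pick a torch dtype and device_map for the available hardware (CUDA, Apple MPS, or CPU).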
def pick_dtype_and_map():
    if torch.cuda.is_available():
        return torch.float16, "auto"
    if torch.backends.mps.is_available():
        return torch.float16, {"": "mps"}
    return torch.float32, "cpu"

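# Regex gate: identity questions ("who are you?") get a fixed branded reply
# instead of being sent to the model.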
def is_identity_query(message, history):
    patterns = [
        r"\bwho\s+are\s+you\b", r"\bwhat\s+are\s+you\b",
        r"\bwhat\s+is\s+your\s+name\b", r"\bwho\s+is\s+this\b",
        r"\bidentify\s+yourself\b", r"\btell\s+me\s+about\s+yourself\b",
        r"\bdescribe\s+yourself\b", r"\band\s+you\s*\?\b",
        r"\byour\s+name\b", r"\bwho\s+am\s+i\s+chatting\s+with\b",
    ]

    def match(t):
        return any(re.search(p, (t or "").strip().lower()) for p in patterns)

    if match(message):
        return True
    if history:
        # Check the most recent user turn; handle both tuple-style history and
        # "messages"-style history (dicts with role/content), which is what
        # gr.ChatInterface(type="messages") passes in.
        last_user = None
        last = history[-1]
        if isinstance(last, (list, tuple)):
            last_user = last[0] if last else None
        else:
            for item in reversed(history):
                if isinstance(item, dict) and item.get("role") == "user":
                    last_user = item.get("content")
                    break
        if match(last_user):
            return True
    return False

def _iter_user_assistant(history):
    """
    Yield (user, assistant) pairs from a Gradio history list.
    Handles tuple-style items (lists/tuples, possibly with >2 elements) as well
    as "messages"-style items (dicts with role/content).
    """
    pending_user = ""
    for item in (history or []):
        if isinstance(item, (list, tuple)):
            u = item[0] if len(item) > 0 else ""
            a = item[1] if len(item) > 1 else ""
            yield u, a
        elif isinstance(item, dict):
            content = item.get("content") or ""
            if item.get("role") == "user":
                pending_user = content
            elif item.get("role") == "assistant":
                yield pending_user, content
                pending_user = ""

def _history_to_prompt(message, history):
    """Build a simple text prompt for the stable cohere.chat API."""
    parts = []
    for u, a in _iter_user_assistant(history):
        if u:
            parts.append(f"User: {u}")
        if a:
            parts.append(f"Assistant: {a}")
    parts.append(f"User: {message}")
    parts.append("Assistant:")
    return "\n".join(parts)

_co_client = None
if USE_HOSTED_COHERE:
    _co_client = cohere.Client(api_key=COHERE_API_KEY)


def cohere_chat(message, history):
    try:
        prompt = _history_to_prompt(message, history)
        resp = _co_client.chat(
            model="command-r7b-12-2024",
            message=prompt,
            temperature=0.3,
            max_tokens=350,
        )
        if hasattr(resp, "text") and resp.text:
            return resp.text.strip()
        if hasattr(resp, "reply") and resp.reply:
            return resp.reply.strip()
        if hasattr(resp, "generations") and resp.generations:
            return resp.generations[0].text.strip()
        return "Sorry, I couldn't parse the response from Cohere."
    except Exception as e:
        return f"Error calling Cohere API: {e}"

@lru_cache(maxsize=1)
def load_local_model():
    if not HF_TOKEN:
        raise RuntimeError("HUGGINGFACE_HUB_TOKEN is not set.")
    login(token=HF_TOKEN, add_to_git_credential=False)
    dtype, device_map = pick_dtype_and_map()
    tok = AutoTokenizer.from_pretrained(
        MODEL_ID,
        token=HF_TOKEN,
        use_fast=True,
        model_max_length=4096,
        padding_side="left",
        trust_remote_code=True,
    )
    mdl = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        token=HF_TOKEN,
        device_map=device_map,
        low_cpu_mem_usage=True,
        torch_dtype=dtype,
        trust_remote_code=True,
    )
    if mdl.config.eos_token_id is None and tok.eos_token_id is not None:
        mdl.config.eos_token_id = tok.eos_token_id
    return mdl, tok

def build_inputs(tokenizer, message, history):
    msgs = []
    for u, a in _iter_user_assistant(history):
        if u:
            msgs.append({"role": "user", "content": u})
        if a:
            msgs.append({"role": "assistant", "content": a})
    msgs.append({"role": "user", "content": message})
    return tokenizer.apply_chat_template(
        msgs, tokenize=True, add_generation_prompt=True, return_tensors="pt"
    )

def local_generate(model, tokenizer, input_ids, max_new_tokens=350):
    input_ids = input_ids.to(model.device)
    with torch.no_grad():
        out = model.generate(
            input_ids=input_ids,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.3,
            top_p=0.9,
            repetition_penalty=1.15,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
    gen_only = out[0, input_ids.shape[-1]:]
    return tokenizer.decode(gen_only, skip_special_tokens=True).strip()

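# Main chat callback for gr.ChatInterface: identity shortcut first, then the
# hosted Cohere API if configured, otherwise local generation. user_tz carries
# the browser timezone captured by tz_box below.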
def chat_fn(message, history, user_tz):
    try:
        if is_identity_query(message, history):
            return "I am ClarityOps, your strategic decision-making AI partner."
        if USE_HOSTED_COHERE:
            return cohere_chat(message, history)
        model, tokenizer = load_local_model()
        inputs = build_inputs(tokenizer, message, history)
        return local_generate(model, tokenizer, inputs, max_new_tokens=350)
    except Exception as e:
        return f"Error: {e}"

theme = gr.themes.Soft(
    primary_hue="teal",
    neutral_hue="slate",
    radius_size=gr.themes.sizes.radius_lg,
)

custom_css = """
:root {
  --brand-bg: #e6f7f8;         /* soft medical teal */
  --brand-accent: #0d9488;     /* teal-600 */
  --brand-text: #0f172a;
  --brand-text-light: #ffffff;
}

/* Page background */
.gradio-container {
  background: var(--brand-bg);
}

/* Title */
h1 {
  color: var(--brand-text);
  font-weight: 700;
  font-size: 28px !important;
}

/* Try to hide the default Chatbot label via CSS for multiple Gradio builds */
.chatbot header,
.chatbot .label,
.chatbot .label-wrap,
.chatbot .top,
.chatbot .header,
.chatbot > .wrap > header {
  display: none !important;
}

/* Both bot and user bubbles teal with white text */
.message.user, .message.bot {
  background: var(--brand-accent) !important;
  color: var(--brand-text-light) !important;
  border-radius: 12px !important;
  padding: 8px 12px !important;
}

/* Inputs a bit softer */
textarea, input, .gr-input {
  border-radius: 12px !important;
}
"""

with gr.Blocks(theme=theme, css=custom_css) as demo:

    # Hidden textbox that receives the browser's IANA timezone on page load.
    tz_box = gr.Textbox(visible=False)
    demo.load(
        lambda tz: tz,
        inputs=[tz_box],
        outputs=[tz_box],
        js="() => Intl.DateTimeFormat().resolvedOptions().timeZone",
    )

    # JS fallback to hide the Chatbot label across Gradio builds (complements the CSS above).
    hide_label_sink = gr.HTML(visible=False)
    demo.load(
        fn=lambda: "",
        inputs=None,
        outputs=hide_label_sink,
        js="""
        () => {
          const sel = [
            '.chatbot header',
            '.chatbot .label',
            '.chatbot .label-wrap',
            '.chatbot .top',
            '.chatbot .header',
            '.chatbot > .wrap > header'
          ];
          sel.forEach(s => document.querySelectorAll(s).forEach(el => el.style.display = 'none'));
          return "";
        }
        """,
    )

    gr.Markdown("# ClarityOps Augmented Decision AI")

    gr.ChatInterface(
        fn=chat_fn,
        type="messages",
        additional_inputs=[tz_box],
        chatbot=gr.Chatbot(label="", show_label=False, type="messages"),
        examples=[
            ["What are the symptoms of hypertension?", ""],
            ["What are common drug interactions with aspirin?", ""],
            ["What are the warning signs of diabetes?", ""],
        ],
        cache_examples=False,
    )

if __name__ == "__main__":
    port = int(os.environ.get("PORT", "7860"))
    demo.launch(
        server_name="0.0.0.0",
        server_port=port,
        show_api=False,
        max_threads=8,
    )