Spaces:
Sleeping
Sleeping
| import os | |
| import json | |
| import time | |
| import torch | |
| import gradio as gr | |
| from datetime import datetime, timedelta | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
| # ---------------------------- | |
| # Default profile and blocks | |
| # ---------------------------- | |
# Profile used when no saved profile is available (load_blocks falls back to it).
DEFAULT_PROFILE = dict(
    name="Learner",
    style=["concise", "reflective", "Basque context where relevant"],
    goals=["conversation-first learning", "daily language blocks", "CPU-only"],
)

# Starter rules used when no saved blocks are available; each entry is a
# {"type": ..., "rule": ...} dict, listed here as (type, rule) pairs.
DEFAULT_BLOCKS = [
    {"type": kind, "rule": rule}
    for kind, rule in (
        ("style", "Ask clarifying questions when uncertain."),
        ("vocab", "Use sensory detail + local place anchoring when writing creatively."),
        ("conversation", "Keep answers short and specific; avoid repeating conclusions."),
        ("conversation", "Offer warm, encouraging replies with a touch of humor or playfulness to lighten the mood."),
        ("conversation", "Use light satire, gentle irony, and clever humor when appropriate to keep dialogue playful and engaging."),
        ("conversation", "Clever swearing is allowed when it adds humor or emphasis, but keep it light, playful, and never offensive."),
    )
]

# JSON file (relative to the working directory) where profile and blocks persist.
BLOCKS_FILE = "blocks.json"
| # ---------------------------- | |
| # Model options | |
| # ---------------------------- | |
# Dropdown label -> Hugging Face model id. Insertion order is the order the
# choices are listed in the UI.
MODEL_OPTIONS = dict(
    (
        ("Phi-3.5 Mini Instruct (4B)", "microsoft/Phi-3.5-mini-instruct"),
        ("Phi-3.5 MoE Instruct (42B)", "microsoft/Phi-3.5-MoE-instruct"),
        ("Phi-3 Mini 4K Instruct (4B)", "microsoft/Phi-3-mini-4k-instruct"),
        ("Phi-3 Mini 128K Instruct (4B)", "microsoft/Phi-3-mini-128k-instruct"),
    )
)
| # ---------------------------- | |
| # Example prompts | |
| # ---------------------------- | |
# Sample prompts offered in the UI. The grammar mistake in the last prompt is
# intentional: it is the exercise the learner is asked to fix.
EXAMPLES = [
    "Tell me about the oldest language in Europe, Euskera.",
    "I’ll teach you a concept. Repeat it back to me in simple words: Solar panels turn sunlight into electricity.",
    "Here’s a new phrase: 'The sea is calm today.' Try saying it in Basque.",
    "Let’s practice style: noir detective. Write one short sentence about Gros in that style.",
    "Here’s a Shakespeare line: 'All the world’s a stage.' What do you think it means?",
    "Read a Dickens passage and tell me how it feels — happy, sad, or something else?",
    "Summarize this paragraph....",
    "I’ll give you a sentence with a mistake: 'He go to school yesterday.' Can you fix it?",
]
| # ---------------------------- | |
| # Persistence helpers | |
| # ---------------------------- | |
def load_blocks():
    """Load the saved profile and language blocks, falling back to defaults.

    Returns:
        A dict that always contains "user_profile" and "language_blocks".
        Values come from BLOCKS_FILE when it holds a JSON object; any key the
        file lacks is filled from the defaults. The defaults are deep-copied,
        so callers that mutate the result (add_block appends to the list) do
        not alter DEFAULT_PROFILE / DEFAULT_BLOCKS.
    """
    import copy  # local import keeps this block self-contained

    defaults = {
        "user_profile": copy.deepcopy(DEFAULT_PROFILE),
        "language_blocks": copy.deepcopy(DEFAULT_BLOCKS),
    }
    try:
        with open(BLOCKS_FILE, "r", encoding="utf-8") as f:
            stored = json.load(f)
    except (OSError, ValueError):
        # Best effort: a missing, unreadable, or corrupt file (bad JSON and bad
        # UTF-8 both raise ValueError subclasses) just means "start fresh".
        return defaults

    if not isinstance(stored, dict):
        # Valid JSON but not the expected object shape.
        return defaults
    for key, value in defaults.items():
        stored.setdefault(key, value)
    return stored
def save_blocks(data):
    """Overwrite BLOCKS_FILE with *data* as indented UTF-8 JSON.

    Non-ASCII characters are written as-is rather than as escape sequences.
    """
    with open(BLOCKS_FILE, mode="w", encoding="utf-8") as handle:
        json.dump(data, handle, indent=2, ensure_ascii=False)
def normalize_rule_text(text: str) -> str:
    """Return *text* with each whitespace run collapsed to one space and no
    leading or trailing whitespace."""
    # str.split() with no separator already discards leading/trailing whitespace.
    words = text.split()
    return " ".join(words)
def is_duplicate_rule(rules_list, new_rule_text, new_type="conversation"):
    """Report whether an equivalent rule of the same type is already listed.

    Two rules are equivalent when their types match case-insensitively and
    their texts match after whitespace normalization and lowercasing.

    Args:
        rules_list: existing rule dicts; missing "type"/"rule" keys count as "".
        new_rule_text: text of the candidate rule.
        new_type: type of the candidate rule.

    Returns:
        True if a matching entry exists, otherwise False.
    """

    def fingerprint(kind, rule):
        return kind.lower(), normalize_rule_text(rule).lower()

    wanted = fingerprint(new_type, new_rule_text)
    return any(
        fingerprint(existing.get("type", ""), existing.get("rule", "")) == wanted
        for existing in rules_list
    )
def add_block(data, rule_text, block_type="conversation", add_review=False):
    """Append a new rule to data["language_blocks"] and persist the result.

    Args:
        data: the profile/blocks dict; it is updated in place.
        rule_text: rule text; whitespace is normalized before use.
        block_type: value stored under the entry's "type" key.
        add_review: when true, attach a "review_schedule" from schedule_reviews().

    Returns:
        A (data, message) tuple. Empty and duplicate rules are rejected with an
        explanatory message and nothing is saved.
    """
    cleaned = normalize_rule_text(rule_text)
    if not cleaned:
        return data, "Rule is empty. Nothing added."

    existing = data.get("language_blocks", [])
    if is_duplicate_rule(existing, cleaned, block_type):
        return data, "Duplicate rule detected. Skipped."

    new_entry = dict(type=block_type, rule=cleaned)
    if add_review:
        new_entry["review_schedule"] = schedule_reviews()

    existing.append(new_entry)
    # Re-assign so the key exists even when the list started as the default [].
    data["language_blocks"] = existing
    save_blocks(data)
    return data, f"Added rule: {cleaned}"
def schedule_reviews(today=None, offsets=(1, 3, 7)):
    """Return review dates at the given day offsets after *today*.

    Args:
        today: a ``datetime.date`` to count from. Defaults to the current
            UTC date.
        offsets: iterable of day offsets; the default reviews after 1, 3 and
            7 days.

    Returns:
        A list of ``str(date)`` values ("YYYY-MM-DD"), one per offset, in the
        order the offsets were given.
    """
    if today is None:
        from datetime import timezone  # local import keeps this block self-contained

        # datetime.utcnow() is deprecated since Python 3.12; an aware "now" in
        # UTC yields the same calendar date.
        today = datetime.now(timezone.utc).date()
    return [str(today + timedelta(days=offset)) for offset in offsets]
| # ---------------------------- | |
| # Model loading (CPU-only) | |
| # ---------------------------- | |
# Process-wide cache: model id -> (tokenizer, model), filled by load_model.
_loaded = {}


def load_model(model_id):
    """Return the (tokenizer, model) pair for *model_id*, loading it once.

    The first call for an id loads the tokenizer and a float32 causal-LM, puts
    the model in eval mode, and caches the pair; later calls return the cached
    pair.
    """
    cached = _loaded.get(model_id)
    if cached is not None:
        return cached

    # NOTE(review): trust_remote_code=True presumably allows code shipped in the
    # model repository to run during loading; confirm every id is trusted.
    hub_kwargs = {"trust_remote_code": True}
    tokenizer = AutoTokenizer.from_pretrained(model_id, **hub_kwargs)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float32,
        **hub_kwargs,
    )
    model.eval()

    pair = (tokenizer, model)
    _loaded[model_id] = pair
    return pair
| # ---------------------------- | |
| # Prompt construction | |
| # ---------------------------- | |
def format_blocks(blocks):
    """Render language blocks as "- [type] rule" lines for the system prompt.

    Dict entries use their "type" (default "rule") and "rule" (default "")
    values; bare strings are treated as rules of type "rule"; anything else is
    skipped, so malformed editor input cannot raise here.

    Returns:
        The lines joined by newlines, or "" when there is nothing to render.
    """
    lines = []
    for block in blocks or []:
        if isinstance(block, str):
            kind, rule = "rule", block
        elif isinstance(block, dict):
            kind, rule = block.get("type", "rule"), block.get("rule", "")
        else:
            continue
        lines.append(f"- [{kind}] {rule}")
    return "\n".join(lines)


SYSTEM_TEMPLATE = """You are a conversation-first learning chatbot.
Follow the user's style and goals, reinforce today's blocks, and confirm corrections.
Active language blocks:
{blocks}
"""


def _format_profile(profile):
    """Render a profile dict as "- key: value" lines; "" if nothing to show."""
    if not isinstance(profile, dict):
        return ""
    lines = []
    for key, value in profile.items():
        if isinstance(value, (list, tuple)):
            value = ", ".join(str(item) for item in value)
        if value is None or value == "":
            continue
        lines.append(f"- {key}: {value}")
    return "\n".join(lines)


def build_messages(user_text, profile, blocks):
    """Build the chat messages (system + user) for one generation.

    Args:
        user_text: the user's message, passed through unchanged.
        profile: user profile dict (e.g. name/style/goals). It is appended to
            the system message so the instruction to follow the user's style
            and goals has something to refer to; a missing or empty profile
            adds nothing.
        blocks: language blocks rendered via format_blocks.

    Returns:
        A two-element list of {"role", "content"} dicts: system, then user.
    """
    system = SYSTEM_TEMPLATE.format(blocks=format_blocks(blocks))
    profile_text = _format_profile(profile)
    if profile_text:
        system += f"User profile:\n{profile_text}\n"
    return [
        {"role": "system", "content": system},
        {"role": "user", "content": user_text},
    ]
def chat(user_text, model_label, blocks_json):
    """Generate one assistant reply on CPU and report token/latency metrics.

    Args:
        user_text: the user's message.
        model_label: a key of MODEL_OPTIONS selecting the model.
        blocks_json: editor text describing the active blocks; parsed with
            parse_blocks_editor, falling back to the saved blocks.

    Returns:
        A (reply, metrics) tuple of strings. When the message is blank or the
        model label is unknown, the reply is "" and metrics explains why, so no
        model is loaded or run.
    """
    if not user_text or not user_text.strip():
        return "", "Please enter a message first."
    model_id = MODEL_OPTIONS.get(model_label)
    if model_id is None:
        return "", f"Unknown model: {model_label}"

    data = load_blocks()
    blocks = parse_blocks_editor(blocks_json, data.get("language_blocks", []))
    tokenizer, model = load_model(model_id)
    messages = build_messages(user_text, data.get("user_profile", {}), blocks)

    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt",
        return_dict=True,  # a dict is required for the ** unpacking below
    ).to("cpu")
    prompt_len = int(inputs["input_ids"].shape[-1])

    # perf_counter is monotonic, so the measured latency cannot be skewed by
    # system clock adjustments the way time.time() can.
    start = time.perf_counter()
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=200,
            do_sample=False,
            # NOTE(review): use_cache=False looks deliberate (possibly a
            # workaround for the remote model code); confirm before enabling.
            use_cache=False,
        )
    latency = time.perf_counter() - start

    # generate() returns prompt + continuation; keep only the continuation.
    new_tokens = outputs[0][prompt_len:]
    gen_text = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    output_tokens = int(new_tokens.shape[-1])
    metrics = (
        f"Input tokens: {prompt_len} | Output tokens: {output_tokens} "
        f"| Latency: {latency:.2f}s"
    )
    return gen_text, metrics
def _coerce_blocks(items):
    """Keep the usable entries of a decoded JSON list as block dicts.

    Strings become {"type": "rule", "rule": <text>}; dicts are kept (copied)
    when "rule" is a non-blank string, with "type" defaulting to "rule";
    everything else is dropped.
    """
    blocks = []
    for item in items:
        if isinstance(item, str):
            item = {"rule": item}
        if not isinstance(item, dict):
            continue
        rule = item.get("rule")
        if not isinstance(rule, str) or not rule.strip():
            continue
        block = dict(item)
        block.setdefault("type", "rule")
        blocks.append(block)
    return blocks


def parse_blocks_editor(text, fallback):
    """Parse the blocks editor contents into a list of block dicts.

    Accepted input, tried in this order:
      1. JSON: an array of blocks, a single block object, or an object with a
         "language_blocks" array.
      2. Plain text: one block per non-empty line, either "type: rule" or a
         bare rule (stored with type "rule"). Lines whose rule is empty are
         skipped.

    Args:
        text: editor contents; may be None or blank.
        fallback: value returned when nothing usable can be parsed.

    Returns:
        The parsed blocks, or *fallback* when the text is blank, looks like
        JSON but does not decode, or yields no usable block. An explicit empty
        JSON array returns [] ("no blocks").
    """
    if not text or not text.strip():
        return fallback
    text = text.strip()

    try:
        parsed = json.loads(text)
    except ValueError:
        if text[0] in "[{":
            # Broken JSON: parsing it line by line would turn brackets and
            # key/value fragments into bogus rules, so keep the fallback.
            return fallback
        parsed = None

    if isinstance(parsed, dict):
        nested = parsed.get("language_blocks")
        parsed = nested if isinstance(nested, list) else [parsed]
    if isinstance(parsed, list):
        blocks = _coerce_blocks(parsed)
        # [] is an explicit "no blocks"; a non-empty list with nothing usable
        # is treated as a mistake.
        return blocks if (blocks or not parsed) else fallback

    blocks = []
    for line in text.splitlines():
        head, sep, tail = line.strip().partition(":")
        if sep:
            kind, rule = head.strip() or "rule", tail.strip()
        else:
            kind, rule = "rule", head
        if rule:
            blocks.append({"type": kind, "rule": rule})
    return blocks or fallback
| # ---------------------------- | |
| # Reflection | |
| # ---------------------------- | |
def heuristic_rule(user_text, assistant_text):
    """Propose a conversation rule from simple keyword checks on one exchange.

    A question mark in the assistant's reply takes priority; otherwise the
    first keyword group found in the lowercased user text decides, with a
    generic brevity rule as the default.
    """
    if "?" in assistant_text:
        return "Ask clarifying questions when uncertain."

    # Checked in order; the first group with a keyword present wins.
    keyword_rules = (
        (("translate",), "Confirm translation intent and target tone before translating."),
        (("style", "noir"), "Confirm style constraints before writing and keep it concise."),
    )
    lowered = user_text.lower()
    for keywords, rule in keyword_rules:
        if any(word in lowered for word in keywords):
            return rule
    return "Keep answers short, specific, and avoid repeating conclusions."
def reflect_and_save(user_text, assistant_text, blocks_editor_value):
    """Derive a rule from the last exchange, save it, and refresh the editor.

    Returns:
        A (editor_text, status) tuple: the saved language blocks as indented
        JSON, and the message reported by add_block.
    """
    # NOTE(review): blocks_editor_value is accepted but never read, so unsaved
    # editor edits are replaced by the stored blocks. Confirm this is intended.
    stored = load_blocks()
    stored, status = add_block(
        stored,
        heuristic_rule(user_text, assistant_text),
        block_type="conversation",
        add_review=False,
    )
    editor_text = json.dumps(stored["language_blocks"], indent=2, ensure_ascii=False)
    return editor_text, status
| # ---------------------------- | |
| # Gradio UI | |
| # ---------------------------- | |
def launch():
    """Build the Gradio interface, wire its two buttons, and start the app.

    Layout, top to bottom: header, model selector, message box, example
    prompts, generate button, assistant output, metrics line, blocks editor,
    reflect button.
    """
    stored = load_blocks()
    initial_editor_text = json.dumps(
        stored["language_blocks"], indent=2, ensure_ascii=False
    )

    with gr.Blocks(title="Conversation Learning Lab (CPU): Tiny Instruct") as app:
        # --- Header ---
        gr.Markdown("# 🗣️ Conversation Learning Lab (CPU-friendly): Tiny Instruct")
        gr.Markdown(
            "Focus on daily dialogue. Reinforce validated language blocks. Transparent tokens and latency."
        )

        # --- Model selector and message input ---
        with gr.Row():
            model_picker = gr.Dropdown(
                label="Choose a model",
                choices=list(MODEL_OPTIONS),
                value="Phi-3.5 Mini Instruct (4B)",
            )
        with gr.Row():
            message_box = gr.Textbox(
                label="Your short message with clear instruction",
                placeholder="Start a conversation or choose an example below...",
                lines=3,
            )

        # --- Example prompts (clicking one fills the message box) ---
        gr.Markdown("### 🧪 Try an example prompt:")
        gr.Examples(examples=EXAMPLES, inputs=message_box)

        # --- Generate button sits directly under the examples ---
        with gr.Row():
            run_button = gr.Button("Generate (CPU)")

        # --- Reply and metrics ---
        with gr.Row():
            reply_box = gr.Textbox(label="Assistant", lines=8)
        with gr.Row():
            metrics_line = gr.Markdown("")

        # --- Blocks editor with the reflect button at the bottom ---
        gr.Markdown("### 📋 Today's Blocks")
        rules_editor = gr.Textbox(
            label="Editable rules (JSON array or 'type: rule' lines)",
            value=initial_editor_text,
            lines=10,
        )
        with gr.Row():
            save_rule_button = gr.Button("Reflect & Save Rule")

        # --- Event wiring ---
        run_button.click(
            fn=chat,
            inputs=[message_box, model_picker, rules_editor],
            outputs=[reply_box, metrics_line],
        )
        save_rule_button.click(
            fn=reflect_and_save,
            inputs=[message_box, reply_box, rules_editor],
            outputs=[rules_editor, metrics_line],
        )

    app.launch()


if __name__ == "__main__":
    launch()