Spaces:
Running
Running
| import os | |
| import torch | |
| import gradio as gr | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
| import difflib | |
# --- 1. Model Configuration ---
# Hugging Face Hub id of the AutoGEO-mini checkpoint served by this demo
# (the ~1.7B Qwen-based "ResearchyGEO" variant, per the repository name).
MODEL_ID = "cx-cmu/AutoGEO_mini_Qwen1.7B_ResearchyGEO"

# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Extra CSS for the footer Markdown component (elem_id="footer-content"):
# let it take its natural height instead of being clipped.
CUSTOM_CSS = """
#footer-content {
height: auto !important;
overflow: visible !important;
}
"""
print(f"Loading model from {MODEL_ID} on {device}...")
# Bind both names up front so that, if loading fails, request handlers see
# explicit None values instead of undefined globals (NameError).
tokenizer = None
model = None
try:
    # NOTE: trust_remote_code=True allows code shipped in the model repo to run.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        # Half precision on GPU to reduce memory; full precision on CPU.
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
        # On GPU, let transformers decide weight placement.
        device_map="auto" if device == "cuda" else None,
        trust_remote_code=True,
    )
    model.eval()
except Exception as e:
    # Deliberately best-effort: keep the UI up (e.g. no GPU memory or no
    # network) and report the problem in the logs. Re-raise here instead if
    # the app should refuse to start without a model.
    print(f"Error loading model: {e}")
    # Don't leave a half-loaded pair behind (e.g. tokenizer loaded, model failed).
    tokenizer = None
    model = None
# --- 2. GEO-Rules Definitions (Aligned with Paper) ---
# Each category maps a short rule name to the guideline sentence that is shown
# in the UI and inserted, in this order, into the rewrite prompt.
_AUTHORITY_RULES = {
    "Factual Accuracy": "Ensure information is factually accurate and verifiable.",
    "Source Citation": "Attribute all factual claims to credible, authoritative sources with clear citations.",
    "Specific Evidence": "Substantiate claims with specific, concrete details like data, statistics, or named examples.",
    "Neutral Tone": "Maintain a neutral, objective tone, avoiding promotional language, personal opinions, and bias.",
    "Up-to-date": "Use current information, reflecting the latest state of knowledge.",
}

# NOTE(review): "Logical Structure" and "Cohesive Flow" carry the identical
# guideline text, so that sentence appears twice in the prompt. Confirm
# whether this is intentional.
_STRUCTURE_RULES = {
    "Logical Structure": "Structure content logically with clear headings, lists, and paragraphs to ensure a cohesive flow.",
    "Conclusion First": "State the key conclusion at the beginning of the document.",
    "Cohesive Flow": "Structure content logically with clear headings, lists, and paragraphs to ensure a cohesive flow.",
    "Self-Contained": "Present information as a self-contained unit, not requiring external links for core understanding.",
}

_CONTENT_RULES = {
    "Comprehensive": "Cover the topic comprehensively, addressing all key aspects and sub-topics.",
    "In-Depth": "Provide explanatory depth by clarifying underlying causes, mechanisms, and context ('how' and 'why').",
    "Balanced View": "Present a balanced perspective on complex topics, acknowledging multiple significant viewpoints or counter-arguments.",
    "Actionable": "Provide clear, specific, and actionable steps.",
}

_HYGIENE_RULES = {
    "Topic Focus": "Focus exclusively on the topic, eliminating irrelevant information, navigational links, and advertisements.",
    "Clear Language": "Use clear and concise language, avoiding jargon, ambiguity, and verbosity.",
    "Writing Quality": "Maintain high-quality writing, free from grammatical errors, typos, and formatting issues.",
}

RULES_DB = {
    "π‘οΈ Authority & Credibility": _AUTHORITY_RULES,
    "ποΈ Structure & Logic": _STRUCTURE_RULES,
    "π Content Quality": _CONTENT_RULES,
    "π§Ή Hygiene & Clarity": _HYGIENE_RULES,
}

# One {rule name: guideline} mapping across all categories. Order is category
# order, then rule order; a repeated rule name would keep the later value.
ALL_RULES_FLAT = dict(
    pair for category_rules in RULES_DB.values() for pair in category_rules.items()
)
# --- 3. Helper Functions ---
def build_autogeo_prompt(document: str) -> str:
    """Assemble the AutoGEO rewrite prompt for ``document``.

    The prompt contains the source text, the rewrite instructions, and every
    guideline from ALL_RULES_FLAT rendered as a "- " bullet list.

    Args:
        document: The source text to be rewritten.

    Returns:
        The prompt string, with no leading or trailing whitespace.
    """
    guideline_bullets = "- " + "\n- ".join(ALL_RULES_FLAT.values())
    sections = [
        "Here is the source:",
        document,
        "You are given a website document as a source. This source, along with other sources, will be used by a language model (LLM) to generate answers to user questions, with each line in the generated answer being cited with its original source. Your task, as the owner of the source, is to **rewrite your document in a way that maximizes its visibility and impact in the LLM's final answer, ensuring your source is more likely to be quoted and cited**.",
        "You can regenerate the provided source so that it strictly adheres to the \"Quality Guidelines\", and you can also apply any other methods or techniques, as long as they help your rewritten source text rank higher in terms of relevance, authority, and impact in the LLM's generated answers.",
        "## Quality Guidelines to Follow:",
        guideline_bullets,
    ]
    return "\n".join(sections).strip()
def generate_diff_html(text1, text2):
    """Render a line-level diff of two texts as a scrollable HTML panel.

    Lines only in ``text2`` get a green background. Lines only in ``text1``
    get a red, struck-through background. Shared lines are neutral grey, and
    difflib's "? " hint lines are skipped.

    Line text is HTML-escaped, so user input or model output containing
    ``<``, ``>`` or ``&`` is shown literally rather than interpreted as
    markup. Empty lines render as a non-breaking space so they keep their
    height and paragraph breaks stay visible.

    Args:
        text1: Original text (the "before" side).
        text2: Rewritten text (the "after" side).

    Returns:
        One HTML string: an outer ``<div>`` wrapping one ``<div>`` per line.
    """
    from html import escape  # local import keeps this block self-contained

    container_open = (
        '<div style="font-family: monospace; white-space: pre-wrap; height: 600px; '
        'overflow-y: scroll; border: 1px solid #e5e7eb; padding: 15px; '
        'border-radius: 8px; background-color: #fafafa;">'
    )
    line_styles = {
        "+ ": "background-color: #d1fae5; color: #065f46; padding: 2px 4px; border-radius: 2px;",
        "- ": "background-color: #fee2e2; color: #b91c1c; text-decoration: line-through; opacity: 0.7; padding: 2px 4px; border-radius: 2px;",
    }
    unchanged_style = "color: #374151; padding: 2px 4px;"

    parts = [container_open]
    for line in difflib.Differ().compare(text1.splitlines(), text2.splitlines()):
        # Differ prefixes every line with a 2-character code: "  ", "+ ", "- " or "? ".
        tag, body = line[:2], line[2:]
        if tag == "? ":
            continue
        style = line_styles.get(tag, unchanged_style)
        # An empty <div> collapses to zero height, which would hide blank lines.
        content = escape(body) or "&nbsp;"
        parts.append(f'<div style="{style}">{content}</div>')
    parts.append("</div>")
    return "".join(parts)
def format_rules_for_display():
    """Render RULES_DB as Markdown for the read-only rules accordion.

    Each category becomes an H4 heading followed by one
    "* **Name:** description" bullet per rule and a trailing blank line.

    Returns:
        The concatenated Markdown string (empty if RULES_DB is empty).
    """
    sections = []
    for category, rules in RULES_DB.items():
        bullets = "".join(f"* **{name}:** {desc}\n" for name, desc in rules.items())
        sections.append(f"#### {category}\n{bullets}\n")
    return "".join(sections)
# --- 4. Main Processing Function ---
def process_rewrite(
    raw_page,
    temp, max_tok, top_p
):
    """Rewrite ``raw_page`` with the AutoGEO model and build a diff view.

    Args:
        raw_page: Source text entered by the user.
        temp: Sampling temperature (cast to float).
        max_tok: Maximum number of new tokens to generate (cast to int).
        top_p: Nucleus-sampling probability mass (cast to float).

    Returns:
        A ``(rewritten_text, diff_html)`` tuple for the raw-text and diff
        outputs. On empty input, a missing model, or a generation failure,
        the tuple carries a user-facing message instead.
    """
    if not raw_page or not raw_page.strip():
        return "", "<div>Please enter text to optimize.</div>"

    # The import-time loading step only prints on failure, so the model
    # globals can be missing or None. Look them up defensively and report
    # that clearly, instead of surfacing a NameError as a generation error.
    tok = globals().get("tokenizer")
    mdl = globals().get("model")
    if tok is None or mdl is None:
        return (
            "Error: model is not loaded; check the server logs for the loading error.",
            "<div>Error: model is not loaded.</div>",
        )

    # All rules are always applied; the prompt builder embeds every guideline.
    prompt = build_autogeo_prompt(raw_page)
    try:
        inputs = tok(prompt, return_tensors="pt")
        # Put inputs where the model's weights actually are. With
        # device_map="auto", transformers chooses placement, so the bare
        # "cuda" string may not match.
        inputs = {k: v.to(mdl.device) for k, v in inputs.items()}
        output_ids = mdl.generate(
            **inputs,
            max_new_tokens=int(max_tok),
            do_sample=True,
            temperature=float(temp),
            top_p=float(top_p),
            pad_token_id=tok.eos_token_id,
        )
        # Slice off the prompt tokens and keep only the newly generated ones.
        generated_ids = output_ids[0][inputs["input_ids"].shape[-1]:]
        decoded = tok.decode(generated_ids, skip_special_tokens=True)
        # Remove the response marker first, then strip, so whitespace that
        # followed the marker does not leak into the result.
        rewritten_text = decoded.replace("**Rewritten Source: **", "").strip()
        diff_html = generate_diff_html(raw_page, rewritten_text)
        return rewritten_text, diff_html
    except Exception as e:
        # UI boundary: show the failure to the user rather than crash the handler.
        return f"Error during generation: {str(e)}", "<div>Error</div>"
# --- 5. UI Construction ---
def safe_read(filename, default=""):
    """Return the UTF-8 text of ``filename``, or ``default`` if it can't be opened.

    Used for optional example assets. A missing file, a directory in its
    place, or any other OS-level open/read error falls back to ``default``,
    so the UI can still start.

    Args:
        filename: Path of the file to read.
        default: Value returned when the file cannot be read.

    Returns:
        The file contents, or ``default``.
    """
    # EAFP: a separate exists() check races with the open() and still
    # crashes on directories or unreadable files.
    try:
        with open(filename, "r", encoding="utf-8") as file:
            return file.read()
    except OSError:
        return default


# Preset inputs for the example buttons; placeholders when the assets are absent.
EXAMPLE_TEXT = safe_read("assets/example1.txt", "Example Marketing Content...")
EXAMPLE_UNSTRUCTURED = safe_read("assets/example2.txt", "Example Notes Content...")
EXAMPLE_ACADEMIC = safe_read("assets/example3.txt", "Example Academic Content...")
# Markdown shown at the top of the page. Paragraphs are separated by blank
# lines; in Markdown, consecutive lines without one merge into one paragraph.
HEADER_MD = """
# π€ AutoGEO Studio
### Optimize Your Content for Generative Engines (GE) ([Project Page](https://zhongshsh.github.io/AutoGEO/) | [Paper](https://arxiv.org/abs/2510.11438) | [Code](https://github.com/cxcscmu/AutoGEO))

Welcome to the **AutoGEO<sub>Mini</sub>** demo. This tool utilizes a specialized model to rewrite web content, maximizing its visibility in Generative Engines (like ChatGPT Search, Perplexity, Gemini).

π **Paper:** What Generative Search Engines Like and How to Optimize Web Content Cooperatively

π₯ **Authors:** Yujiang Wu*, Shanshan Zhong*, Yubin Kim, Chenyan Xiong (*Equal contribution)
"""
# Markdown shown at the bottom of the page: current model details and links
# to the other AutoGEO models. A plain string: it has no placeholders.
FOOTER_MD = """
---
#### Extract Custom Rules, Train Your Own AutoGEO Mini Model, or Use More Powerful AutoGEO<sub>API</sub>
Explore more powerful AutoGEO on [AutoGEO Github](https://github.com/cxcscmu/AutoGEO).
#### Current Model: AutoGEO<sub>Mini</sub> (ResearchyGEO)
This demo runs on **[AutoGEO_mini_Qwen1.7B_ResearchyGEO](https://huggingface.co/cx-cmu/AutoGEO_mini_Qwen1.7B_ResearchyGEO)**, fine-tuned via reinforcement learning on [Gemini Researchy-GEO dataset](https://huggingface.co/datasets/cx-cmu/Researchy-GEO).
* **Best for:** Academic papers, technical documentation, Wikipedia-style articles, and complex informational content.
* **Focus:** Authority, citations, structure, and neutrality.
#### Explore Other AutoGEO Models
Looking for a different domain? Download our other specialized models on Hugging Face:
* π **[AutoGEO<sub>Mini</sub> (E-commerce)](https://huggingface.co/cx-cmu/AutoGEO_mini_Qwen1.7B_Ecommerce)**: Optimized for product pages, shopping guides, and commercial content.
* π **[AutoGEO<sub>Mini</sub> (GEOBench)](https://huggingface.co/cx-cmu/AutoGEO_mini_Qwen1.7B_GEOBench)**: Tuned for general knowledge queries and diverse web tasks.
"""
# Page layout: input and output columns, a read-only rules panel, sampling
# settings and the footer. CUSTOM_CSS is passed in so the footer styling
# defined for elem_id="footer-content" actually takes effect.
with gr.Blocks(title="AutoGEO Studio", css=CUSTOM_CSS) as demo:
    gr.Markdown(HEADER_MD)
    with gr.Row():
        # === Left Column: Input ===
        with gr.Column(scale=1):
            gr.Markdown("### Source Content")
            input_box = gr.Textbox(
                lines=20,
                placeholder="Paste original text here...",
                label="Input",
                show_label=False,
            )
            with gr.Row():
                gr.Markdown("**Presets:**")
                ex_btn_1 = gr.Button("Example 1: Marketing")
                ex_btn_2 = gr.Button("Example 2: Academic")
                ex_btn_3 = gr.Button("Example 3: Notes")
            # Each preset button fills the input box with its example text.
            ex_btn_1.click(lambda: EXAMPLE_TEXT, outputs=input_box)
            ex_btn_2.click(lambda: EXAMPLE_ACADEMIC, outputs=input_box)
            ex_btn_3.click(lambda: EXAMPLE_UNSTRUCTURED, outputs=input_box)
            btn_submit = gr.Button("β¨ Optimize Content (Apply AutoGEO Rules)", variant="primary", size="lg")
        # === Right Column: Output ===
        with gr.Column(scale=1):
            gr.Markdown("### AutoGEO Optimized Result")
            with gr.Tabs():
                with gr.TabItem("π Diff View"):
                    diff_output = gr.HTML(label="Difference")
                with gr.TabItem("π Raw Text"):
                    raw_output = gr.Textbox(lines=20, label="Rewritten Text", show_label=False)
    # === Middle: View GEO Rules (Dropdown/Accordion - Read Only) ===
    with gr.Row():
        with gr.Accordion("βΉοΈ View Applied AutoGEO Rules", open=False):
            gr.Markdown("The model will automatically optimize your content based on the following guidelines:")
            gr.Markdown(format_rules_for_display())
    # === Advanced Settings ===
    # Sampling parameters forwarded to process_rewrite / model.generate.
    with gr.Accordion("βοΈ Advanced Model Settings", open=False):
        with gr.Row():
            temp_slider = gr.Slider(0.1, 1.5, value=0.7, label="Temperature")
            max_tok_slider = gr.Slider(256, 2048, value=1024, label="Max Tokens")
            top_p_slider = gr.Slider(0.1, 1.0, value=0.9, label="Top-P")
    # === Footer (Moved Content) ===
    gr.Markdown(FOOTER_MD, elem_id="footer-content")
    # === Event Wiring ===
    btn_submit.click(
        fn=process_rewrite,
        inputs=[
            input_box,
            temp_slider, max_tok_slider, top_p_slider,
        ],
        outputs=[raw_output, diff_output],
    )

if __name__ == "__main__":
    demo.launch()