import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM

MODEL_ID = "cx-cmu/AutoGEO_mini_Qwen1.7B"

device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(
    MODEL_ID,
    trust_remote_code=True
)
# Load in fp16 with automatic device placement on GPU; fall back to fp32 on CPU.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    device_map="auto" if device == "cuda" else None,
    trust_remote_code=True
)
model.eval()

# Default AutoGEO rule set; used whenever the rules box in the UI is left empty.
DEFAULT_RULES_LIST = [
    "Attribute all factual claims to credible, authoritative sources with clear citations.",
    "Cover the topic comprehensively, addressing all key aspects and sub-topics.",
    "Ensure information is factually accurate and verifiable.",
    "Focus exclusively on the topic, eliminating irrelevant information, navigational links, and advertisements.",
    "Maintain a neutral, objective tone, avoiding promotional language, personal opinions, and bias.",
    "Maintain high-quality writing, free from grammatical errors, typos, and formatting issues.",
    "Present a balanced perspective on complex topics, acknowledging multiple significant viewpoints or counter-arguments.",
    "Present information as a self-contained unit, not requiring external links for core understanding.",
    "Provide clear, specific, and actionable steps.",
    "Provide explanatory depth by clarifying underlying causes, mechanisms, and context ('how' and 'why').",
    "State the key conclusion at the beginning of the document.",
    "Structure content logically with clear headings, lists, and paragraphs to ensure a cohesive flow.",
    "Substantiate claims with specific, concrete details like data, statistics, or named examples.",
    "Use clear and concise language, avoiding jargon, ambiguity, and verbosity.",
    "Use current information, reflecting the latest state of knowledge."
]
DEFAULT_RULES_TEXT = "\n".join(DEFAULT_RULES_LIST)


def build_autogeo_prompt(summary: str, rules_text: str | None) -> str:
    """Build the AutoGEO rewriting prompt from a source document and a rule list."""
    # Use the custom rules if any non-empty lines were provided; otherwise fall back to the defaults.
    if rules_text is not None and rules_text.strip():
        rules = [line.strip() for line in rules_text.splitlines() if line.strip()]
    else:
        rules = DEFAULT_RULES_LIST
    rules_string = "- " + "\n- ".join(rules)

    user_prompt = f"""
Here is the source:
{summary}

You are given a website document as a source. This source, along with other sources, will be used by a language model (LLM) to generate answers to user questions, with each line in the generated answer being cited with its original source.
Your task, as the owner of the source, is to rewrite your document in a way that maximizes its visibility and impact in the LLM's final answer, ensuring your source is more likely to be quoted and cited.
You can regenerate the provided source so that it strictly adheres to the "Quality Guidelines", and you can also apply any other methods or techniques, as long as they help your rewritten source text rank higher in terms of relevance, authority, and impact in the LLM's generated answers.

## Quality Guidelines to Follow:
{rules_string}

Now rewrite the source accordingly.
""".strip()
    return user_prompt


@torch.no_grad()
def rewrite_document(
    raw_page: str,
    rules_text: str,
    temperature: float = 0.7,
    max_new_tokens: int = 2048,
    top_p: float = 0.9
) -> str:
    """Rewrite a web page with the AutoGEO Mini model and return the generated text."""
    if not raw_page.strip():
        return "Please paste the original web page content in the input box."

    prompt = build_autogeo_prompt(raw_page, rules_text)
    inputs = tokenizer(prompt, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}

    output_ids = model.generate(
        **inputs,
        max_new_tokens=int(max_new_tokens),
        do_sample=True,
        temperature=float(temperature),
        top_p=float(top_p),
        pad_token_id=tokenizer.eos_token_id
    )
    # Decode only the newly generated tokens, dropping the prompt prefix.
    generated_ids = output_ids[0][inputs["input_ids"].shape[-1]:]
    text = tokenizer.decode(generated_ids, skip_special_tokens=True)
    return text.strip()


with gr.Blocks(title="AutoGEO Mini Rewriting Demo", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
# AutoGEO Mini Rewriting Demo

Paste an original web page/document on the left. Optionally customize the rewriting rules in the middle.
The rewritten document will appear on the right.

If the rules area is left empty, the demo will use the default rule set
**extracted from the Researchy-GEO dataset with Gemini-2.5-Flash-Lite as the generative engine**.
"""
    )

    with gr.Row(equal_height=True):
        with gr.Column(scale=4):
            gr.Markdown("### 1️⃣ Original web page content")
            input_box = gr.Textbox(
                lines=22,
                label="",
                placeholder="Paste the original web page HTML/text here...",
                show_label=False
            )
        with gr.Column(scale=3):
            gr.Markdown(
                """
### 2️⃣ Rewriting rules (one rule per line)

- You can edit, add, or remove rules below.
- If you clear this box and leave it empty, the default AutoGEO rule set will be used.
"""
            )
            rules_box = gr.Textbox(
                value=DEFAULT_RULES_TEXT,
                lines=22,
                label="Custom rules (optional)",
                placeholder="One rule per line. Leave empty to use the default Researchy-GEO rule set."
            )
        with gr.Column(scale=4):
            gr.Markdown("### 3️⃣ Rewritten document")
            output_box = gr.Textbox(
                lines=22,
                label="",
                placeholder="Model output will appear here.",
                show_label=False
            )

    gr.Markdown("---")

    with gr.Row():
        with gr.Column(scale=3):
            temperature_slider = gr.Slider(
                minimum=0.1,
                maximum=1.5,
                value=0.7,
                step=0.05,
                label="Temperature"
            )
        with gr.Column(scale=3):
            max_tokens_slider = gr.Slider(
                minimum=256,
                maximum=2048,
                value=1024,
                step=64,
                label="Max new tokens"
            )
        with gr.Column(scale=3):
            top_p_slider = gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.9,
                step=0.05,
                label="Top-p"
            )
        with gr.Column(scale=2, min_width=120):
            submit_btn = gr.Button(
                "Rewrite with AutoGEO Mini",
                variant="primary"
            )

    submit_btn.click(
        fn=rewrite_document,
        inputs=[input_box, rules_box, temperature_slider, max_tokens_slider, top_p_slider],
        outputs=[output_box]
    )

if __name__ == "__main__":
    demo.launch()
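
# Illustrative sketch: rewrite_document() can also be called directly from Python,
# without the Gradio UI, e.g. for batch rewriting. The sample text and generation
# settings below are placeholder values for the example, not values from the model card.
#
#   sample_page = "Paste or load the raw page text here..."
#   rewritten = rewrite_document(
#       sample_page,
#       DEFAULT_RULES_TEXT,   # or a custom newline-separated rule list
#       temperature=0.7,
#       max_new_tokens=1024,
#       top_p=0.9,
#   )
#   print(rewritten)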