Skip to content

Fix generated evaluation consistency and preserve locality evidence#702

Open
BY-Elysia wants to merge 2 commits into
zjunlp:mainfrom
BY-Elysia:main
Open

Fix generated evaluation consistency and preserve locality evidence#702
BY-Elysia wants to merge 2 commits into
zjunlp:mainfrom
BY-Elysia:main

Conversation

@BY-Elysia

Copy link
Copy Markdown
Contributor

Summary

This PR fixes two evaluation issues:

  1. generate-text and LLM-judge previously returned inconsistent result structures, which could mix generated text with numeric accuracy fields.
  2. Locality post-processing removed the original pre/post outputs after computing locality_acc, preventing users from auditing how the score was produced.

Changes

Generated evaluation result consistency

  • Separate raw text generation from judge scoring.
  • Make the judge helper consistently return scores and generated text.
  • Store numeric values only in *_acc.
  • Store generated text only in *_gen_content.
  • Fix rewrite, rephrase, portability, and locality result structures.
  • Use the same locality post-processing for edit() and batch_edit().
  • Allow generate-text results to pass through summary_metrics() safely.

Locality evidence preservation

  • Preserve pre/post token outputs for standard text locality evaluation.
  • Preserve pre/post locality outputs in ConceptEditor.
  • Preserve generated pre/post text for generate-text and LLM-judge.
  • Preserve aligned top-1 token IDs for multimodal text locality.
  • Preserve aligned top-10 token IDs for multimodal image locality.
  • Continue removing full multimodal logits to avoid excessive result size and JSON serialization problems.
  • Keep the existing locality_acc calculations unchanged.

Validation

The full validation script is included below. It covers:

  • compileall
  • git diff --check
  • source-level result contract checks
  • generate-text rewrite/rephrase/portability structures
  • LLM-judge fallback score/text separation
  • generated-text locality pre/post comparison
  • preservation of teacher-forcing pre/post token outputs
  • preservation of ConceptEditor locality outputs
  • aligned and JSON-serializable multimodal top-k locality evidence
  • compatibility with summary_metrics()
  • real edit() and batch_edit() smoke tests

The real smoke validation was run with Qwen2.5-1.5B-Instruct and one-step FT. All checks passed:

compileall: pass
git_diff_check: pass
source_result_contract: pass
mock_result_contracts: pass
real_generate_text_edit: pass
real_llm_judge_fallback_batch_edit: pass
real_teacher_forcing_locality_evidence: pass
Full validation script
#!/usr/bin/env python3
import argparse
import compileall
import importlib
import json
import os
import subprocess
import sys
import tempfile
from pathlib import Path
from types import SimpleNamespace


REPO_ROOT = Path(__file__).resolve().parents[1]
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))


CHECK_FILES = [
    "easyeditor/evaluate/evaluate_utils.py",
    "easyeditor/evaluate/evaluate.py",
    "easyeditor/editors/editor.py",
    "easyeditor/editors/concept_editor.py",
    "easyeditor/editors/multimodal_editor.py",
]


def record(results, name, status, detail=None):
    item = {"name": name, "status": status}
    if detail is not None:
        item["detail"] = detail
    results.append(item)


def to_builtin(value):
    if isinstance(value, dict):
        return {key: to_builtin(item) for key, item in value.items()}
    if isinstance(value, (list, tuple)):
        return [to_builtin(item) for item in value]
    if hasattr(value, "item"):
        return value.item()
    return value


def assert_text_list(value, field):
    assert isinstance(value, list), f"{field} must be a list, got {type(value).__name__}"
    assert all(isinstance(item, str) for item in value), (
        f"{field} must contain only strings: {value}"
    )


def assert_numeric_list(value, field):
    assert isinstance(value, list), f"{field} must be a list, got {type(value).__name__}"
    assert all(isinstance(to_builtin(item), (int, float)) for item in value), (
        f"{field} must contain only numbers: {value}"
    )


def run_compileall(results):
    failed = []
    for rel_path in CHECK_FILES:
        if not compileall.compile_file(str(REPO_ROOT / rel_path), quiet=1):
            failed.append(rel_path)
    if failed:
        raise AssertionError(f"compileall failed: {failed}")
    record(results, "compileall", "pass", CHECK_FILES)


def run_diff_check(results):
    proc = subprocess.run(
        ["git", "diff", "--check"],
        cwd=REPO_ROOT,
        text=True,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
    )
    if proc.returncode != 0:
        raise AssertionError(proc.stdout)
    record(results, "git_diff_check", "pass")


def run_source_contract_check(results):
    evaluate_utils_source = (
        REPO_ROOT / "easyeditor/evaluate/evaluate_utils.py"
    ).read_text()
    evaluate_source = (REPO_ROOT / "easyeditor/evaluate/evaluate.py").read_text()
    editor_source = (REPO_ROOT / "easyeditor/editors/editor.py").read_text()
    concept_editor_source = (
        REPO_ROOT / "easyeditor/editors/concept_editor.py"
    ).read_text()
    multimodal_editor_source = (
        REPO_ROOT / "easyeditor/editors/multimodal_editor.py"
    ).read_text()

    required_snippets = {
        "evaluate_utils.py": [
            "def generate_texts(",
            "return all_score, all_response",
        ],
        "evaluate.py": [
            'f"{portability_key}_gen_content"',
            'f"{locality_key}_gen_content"',
        ],
        "editor.py": [
            "def finalize_locality_metrics(",
            "summary_metrics(all_metrics)",
        ],
        "multimodal_editor.py": [
            "def get_aligned_topk_token_ids(",
            "locality_topk_tokens",
            "multimodal_locality_topk_tokens",
        ],
    }
    sources = {
        "evaluate_utils.py": evaluate_utils_source,
        "evaluate.py": evaluate_source,
        "editor.py": editor_source,
        "multimodal_editor.py": multimodal_editor_source,
    }
    missing = []
    for filename, snippets in required_snippets.items():
        for snippet in snippets:
            if snippet not in sources[filename]:
                missing.append({"file": filename, "snippet": snippet})
    if missing:
        raise AssertionError(f"missing expected result-contract code: {missing}")

    forbidden_locality_cleanup = {
        "editor.py": [
            'post_locality.pop(output_key)',
            'metric["pre"].pop("locality")',
        ],
        "concept_editor.py": [
            ".pop(f'{locality_key}_output')",
            ".pop('locality')",
        ],
    }
    cleanup_sources = {
        "editor.py": editor_source,
        "concept_editor.py": concept_editor_source,
    }
    cleanup_hits = []
    for filename, snippets in forbidden_locality_cleanup.items():
        for snippet in snippets:
            if snippet in cleanup_sources[filename]:
                cleanup_hits.append({"file": filename, "snippet": snippet})
    if cleanup_hits:
        raise AssertionError(
            f"found locality evidence cleanup code: {cleanup_hits}"
        )

    record(results, "source_result_contract", "pass")


def run_mock_result_contract_checks(results):
    evaluate_module = importlib.import_module("easyeditor.evaluate.evaluate")
    editor_module = importlib.import_module("easyeditor.editors.editor")
    multimodal_editor_module = importlib.import_module(
        "easyeditor.editors.multimodal_editor"
    )
    utils_module = importlib.import_module("easyeditor.editors.utils")
    import torch

    original_generate_texts = evaluate_module.generate_texts
    original_judge = evaluate_module.test_prediction_acc_LLM_judge
    evaluate_module.generate_texts = lambda *args, **kwargs: ["generated answer"]
    evaluate_module.test_prediction_acc_LLM_judge = (
        lambda *args, **kwargs: ([1.0], ["judged answer"])
    )

    generate_hparams = SimpleNamespace(
        evaluation_type="generate-text",
        alg_name="FT",
    )
    judge_hparams = SimpleNamespace(
        evaluation_type="LLM-judge",
        alg_name="FT",
        api_key="mock-key",
    )

    try:
        rewrite_generate = evaluate_module.compute_rewrite_or_rephrase_quality(
            None, "qwen", generate_hparams, None, "prompt", "target", 0
        )
        assert_text_list(
            rewrite_generate["rewrite_gen_content"],
            "rewrite_gen_content",
        )
        assert "rewrite_acc" not in rewrite_generate

        rephrase_generate = evaluate_module.compute_rewrite_or_rephrase_quality(
            None,
            "qwen",
            generate_hparams,
            None,
            "rephrase prompt",
            "target",
            0,
            test_rephrase=True,
        )
        assert_text_list(
            rephrase_generate["rephrase_gen_content"],
            "rephrase_gen_content",
        )
        assert "rephrase_acc" not in rephrase_generate

        rewrite_judge = evaluate_module.compute_rewrite_or_rephrase_quality(
            None, "qwen", judge_hparams, None, "prompt", "target", 0
        )
        assert_numeric_list(rewrite_judge["rewrite_acc"], "rewrite_acc")
        assert_text_list(
            rewrite_judge["rewrite_gen_content"],
            "rewrite_gen_content",
        )

        portability_generate = evaluate_module.compute_portability_quality(
            None,
            "qwen",
            generate_hparams,
            None,
            "one_hop",
            "prompt",
            "target",
            0,
        )
        assert_text_list(
            portability_generate["one_hop_gen_content"],
            "one_hop_gen_content",
        )
        assert "one_hop_acc" not in portability_generate

        portability_judge = evaluate_module.compute_portability_quality(
            None,
            "qwen",
            judge_hparams,
            None,
            "one_hop",
            "prompt",
            "target",
            0,
        )
        assert_numeric_list(
            portability_judge["one_hop_acc"],
            "one_hop_acc",
        )
        assert_text_list(
            portability_judge["one_hop_gen_content"],
            "one_hop_gen_content",
        )

        locality_generate = evaluate_module.compute_locality_quality(
            None,
            "qwen",
            generate_hparams,
            None,
            "neighborhood",
            "prompt",
            "target",
            0,
        )
        assert locality_generate == {
            "neighborhood_gen_content": ["generated answer"]
        }

        request = {"locality": {"neighborhood": {}}}
        same_metric = {
            "pre": {
                "locality": {
                    "neighborhood_gen_content": ["Berlin"],
                }
            },
            "post": {
                "locality": {
                    "neighborhood_gen_content": ["Berlin"],
                }
            },
        }
        editor_module.finalize_locality_metrics(
            same_metric,
            request,
            generate_hparams,
            "qwen",
        )
        assert same_metric["post"]["locality"]["neighborhood_acc"] == [1.0]

        changed_metric = {
            "pre": {
                "locality": {
                    "neighborhood_gen_content": ["Berlin"],
                }
            },
            "post": {
                "locality": {
                    "neighborhood_gen_content": ["Munich"],
                }
            },
        }
        editor_module.finalize_locality_metrics(
            changed_metric,
            request,
            judge_hparams,
            "qwen",
        )
        assert changed_metric["post"]["locality"]["neighborhood_acc"] == [0.0]

        token_metric = {
            "pre": {
                "locality": {
                    "neighborhood_output": [[10, 20]],
                }
            },
            "post": {
                "locality": {
                    "neighborhood_output": [[10, 99]],
                }
            },
        }
        editor_module.finalize_locality_metrics(
            token_metric,
            request,
            SimpleNamespace(alg_name="FT"),
            "qwen",
        )
        assert token_metric["pre"]["locality"][
            "neighborhood_output"
        ] == [[10, 20]]
        assert token_metric["post"]["locality"][
            "neighborhood_output"
        ] == [[10, 99]]
        assert token_metric["post"]["locality"][
            "neighborhood_acc"
        ] == [0.5]

        pre_text_logits = torch.tensor([[
            [0.1, 0.9, 0.2],
            [0.8, 0.1, 0.3],
            [0.2, 0.3, 0.7],
        ]])
        post_text_logits = torch.tensor([[
            [0.8, 0.1, 0.3],
            [0.7, 0.1, 0.2],
        ]])
        pre_image_logits = torch.arange(
            1 * 3 * 12,
            dtype=torch.float32,
        ).reshape(1, 3, 12)
        post_image_logits = pre_image_logits[:, -2:, :].clone()
        multimodal_metric = {
            "pre": {
                "locality_output": pre_text_logits,
                "multimodal_locality_output": pre_image_logits,
            },
            "post": {
                "locality_output": post_text_logits,
                "multimodal_locality_output": post_image_logits,
            },
        }
        multimodal_editor_module.attach_topk_locality_metrics(
            multimodal_metric
        )
        assert "locality_output" not in multimodal_metric["pre"]
        assert "locality_output" not in multimodal_metric["post"]
        assert "multimodal_locality_output" not in multimodal_metric["pre"]
        assert "multimodal_locality_output" not in multimodal_metric["post"]
        assert isinstance(
            multimodal_metric["pre"]["locality_topk_tokens"],
            list,
        )
        assert isinstance(
            multimodal_metric["post"]["locality_topk_tokens"],
            list,
        )
        assert isinstance(
            multimodal_metric["pre"][
                "multimodal_locality_topk_tokens"
            ],
            list,
        )
        assert isinstance(
            multimodal_metric["post"][
                "multimodal_locality_topk_tokens"
            ],
            list,
        )
        json.dumps(to_builtin(multimodal_metric))

        summary_input = [{
            "pre": {
                "rewrite_gen_content": ["before"],
                "locality": {
                    "neighborhood_gen_content": ["Berlin"],
                    "neighborhood_output": [[10, 20]],
                },
            },
            "post": {
                "rewrite_gen_content": ["after"],
                "portability": {
                    "one_hop_gen_content": ["Paris"],
                },
                "locality": {
                    "neighborhood_gen_content": ["Berlin"],
                    "neighborhood_output": [[10, 99]],
                    "neighborhood_acc": [1.0],
                },
            },
        }]
        old_cwd = Path.cwd()
        with tempfile.TemporaryDirectory() as temp_dir:
            os.chdir(temp_dir)
            try:
                utils_module.summary_metrics(summary_input)
            finally:
                os.chdir(old_cwd)
    finally:
        evaluate_module.generate_texts = original_generate_texts
        evaluate_module.test_prediction_acc_LLM_judge = original_judge

    record(
        results,
        "mock_result_contracts",
        "pass",
        {
            "generate_text": "*_gen_content contains text only",
            "llm_judge": "*_acc contains numbers and *_gen_content contains text",
            "locality": "pre/post generated text produces numeric locality_acc",
            "token_locality": "pre/post token outputs remain available",
            "multimodal_locality": "aligned top-k token evidence is JSON serializable",
            "concept_locality": "source no longer removes pre/post locality output",
            "summary": "generated text is ignored by numeric aggregation",
        },
    )


def build_eval_inputs():
    return {
        "prompts": "The internal validation generated evaluation key is",
        "target_new": " alpha",
        "ground_truth": None,
        "rephrase_prompts": "What is the internal generated evaluation key?",
        "locality_inputs": {
            "neighborhood": {
                "prompt": "The capital of France is",
                "ground_truth": " Paris",
            }
        },
        "portability_inputs": {
            "one_hop": {
                "prompt": "The internal generated portability key is",
                "ground_truth": " alpha",
            }
        },
        "sequential_edit": False,
        "verbose": False,
    }


def validate_real_metrics(metrics, mode):
    case = metrics[0]
    pre = case["pre"]
    post = case["post"]

    assert_numeric_list(
        post["locality"]["neighborhood_acc"],
        "post.locality.neighborhood_acc",
    )

    if mode in ["generate-text", "LLM-judge"]:
        for phase_name, phase in [("pre", pre), ("post", post)]:
            assert_text_list(
                phase["rewrite_gen_content"],
                f"{phase_name}.rewrite_gen_content",
            )
            assert_text_list(
                phase["rephrase_gen_content"],
                f"{phase_name}.rephrase_gen_content",
            )
            assert_text_list(
                phase["portability"]["one_hop_gen_content"],
                f"{phase_name}.portability.one_hop_gen_content",
            )
            assert_text_list(
                phase["locality"]["neighborhood_gen_content"],
                f"{phase_name}.locality.neighborhood_gen_content",
            )

    if mode == "generate-text":
        assert "rewrite_acc" not in pre and "rewrite_acc" not in post
        assert "rephrase_acc" not in pre and "rephrase_acc" not in post
        assert "one_hop_acc" not in pre["portability"]
        assert "one_hop_acc" not in post["portability"]
    elif mode == "LLM-judge":
        assert_numeric_list(pre["rewrite_acc"], "pre.rewrite_acc")
        assert_numeric_list(post["rewrite_acc"], "post.rewrite_acc")
        assert_numeric_list(pre["rephrase_acc"], "pre.rephrase_acc")
        assert_numeric_list(post["rephrase_acc"], "post.rephrase_acc")
        assert_numeric_list(
            pre["portability"]["one_hop_acc"],
            "pre.portability.one_hop_acc",
        )
        assert_numeric_list(
            post["portability"]["one_hop_acc"],
            "post.portability.one_hop_acc",
        )
    else:
        assert_numeric_list(pre["rewrite_acc"], "pre.rewrite_acc")
        assert_numeric_list(post["rewrite_acc"], "post.rewrite_acc")
        assert_numeric_list(pre["rephrase_acc"], "pre.rephrase_acc")
        assert_numeric_list(post["rephrase_acc"], "post.rephrase_acc")
        assert_numeric_list(
            pre["portability"]["one_hop_acc"],
            "pre.portability.one_hop_acc",
        )
        assert_numeric_list(
            post["portability"]["one_hop_acc"],
            "post.portability.one_hop_acc",
        )
        pre_output = pre["locality"]["neighborhood_output"]
        post_output = post["locality"]["neighborhood_output"]
        assert isinstance(pre_output, list), pre_output
        assert isinstance(post_output, list), post_output

    return {
        "mode": mode,
        "pre": {
            "rewrite_keys": sorted(
                key for key in pre if key.startswith("rewrite_")
            ),
            "rephrase_keys": sorted(
                key for key in pre if key.startswith("rephrase_")
            ),
            "portability_keys": sorted(pre["portability"]),
            "locality_keys": sorted(pre["locality"]),
            "locality_evidence": to_builtin(
                pre["locality"].get(
                    "neighborhood_gen_content",
                    pre["locality"].get("neighborhood_output"),
                )
            ),
        },
        "post": {
            "rewrite_keys": sorted(
                key for key in post if key.startswith("rewrite_")
            ),
            "rephrase_keys": sorted(
                key for key in post if key.startswith("rephrase_")
            ),
            "portability_keys": sorted(post["portability"]),
            "locality_keys": sorted(post["locality"]),
            "locality_acc": to_builtin(
                post["locality"]["neighborhood_acc"]
            ),
            "locality_evidence": to_builtin(
                post["locality"].get(
                    "neighborhood_gen_content",
                    post["locality"].get("neighborhood_output"),
                )
            ),
        },
    }


def run_model_smoke(results, args):
    from easyeditor import BaseEditor, FTHyperParams

    hparams = FTHyperParams.from_hparams(args.hparams)
    hparams.model_name = args.model
    hparams.device = args.device
    hparams.layers = [args.layer]
    hparams.num_steps = args.num_steps
    hparams.batch_size = 1

    editor = BaseEditor.from_hparams(hparams)

    hparams.evaluation_type = "generate-text"
    hparams.api_key = ""
    generate_metrics, _, _ = editor.edit(**build_eval_inputs())
    generate_detail = validate_real_metrics(
        generate_metrics,
        "generate-text",
    )
    record(
        results,
        "real_generate_text_edit",
        "pass",
        generate_detail,
    )

    hparams.evaluation_type = "LLM-judge"
    hparams.api_key = ""
    batch_inputs = build_eval_inputs()
    batch_inputs["prompts"] = [batch_inputs["prompts"]]
    batch_inputs["target_new"] = [batch_inputs["target_new"]]
    batch_inputs["ground_truth"] = [None]
    batch_inputs["rephrase_prompts"] = [batch_inputs["rephrase_prompts"]]
    for section in ["locality_inputs", "portability_inputs"]:
        for item in batch_inputs[section].values():
            item["prompt"] = [item["prompt"]]
            item["ground_truth"] = [item["ground_truth"]]

    judge_metrics, _, _ = editor.batch_edit(**batch_inputs)
    judge_detail = validate_real_metrics(
        judge_metrics,
        "LLM-judge",
    )
    record(
        results,
        "real_llm_judge_fallback_batch_edit",
        "pass",
        judge_detail,
    )

    if hasattr(hparams, "evaluation_type"):
        delattr(hparams, "evaluation_type")
    if hasattr(hparams, "api_key"):
        delattr(hparams, "api_key")
    token_metrics, _, _ = editor.edit(**build_eval_inputs())
    token_detail = validate_real_metrics(
        token_metrics,
        "teacher-forcing",
    )
    record(
        results,
        "real_teacher_forcing_locality_evidence",
        "pass",
        token_detail,
    )


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", default=os.environ.get("MODEL"))
    parser.add_argument("--hparams", default="hparams/FT/qwen2.5-7b.yaml")
    parser.add_argument("--device", type=int, default=0)
    parser.add_argument("--layer", type=int, default=0)
    parser.add_argument("--num-steps", type=int, default=1)
    parser.add_argument(
        "--out",
        default="outputs/generated_evaluation_consistency_validation.json",
    )
    parser.add_argument(
        "--skip-model-smoke",
        action="store_true",
        help="Run only static and mock checks.",
    )
    args = parser.parse_args()

    results = []
    run_compileall(results)
    run_diff_check(results)
    run_source_contract_check(results)
    run_mock_result_contract_checks(results)

    if args.skip_model_smoke or not args.model:
        record(
            results,
            "real_model_smoke",
            "skipped",
            "Pass --model or set MODEL to run real Qwen validation.",
        )
    else:
        run_model_smoke(results, args)

    payload = {
        "repo": str(REPO_ROOT),
        "commit": subprocess.check_output(
            ["git", "rev-parse", "--short", "HEAD"],
            cwd=REPO_ROOT,
            text=True,
        ).strip(),
        "results": to_builtin(results),
    }

    out_path = REPO_ROOT / args.out
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text(
        json.dumps(payload, ensure_ascii=False, indent=2) + "\n"
    )
    print(json.dumps(payload, ensure_ascii=False, indent=2))
    print(f"PASS: validation written to {out_path}")


if __name__ == "__main__":
    main()

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant