From 4ecf1576dfbd279bb1b8548d1ad01adccf63fe33 Mon Sep 17 00:00:00 2001 From: Claude Date: Sat, 21 Mar 2026 19:40:21 +0000 Subject: [PATCH 1/5] Make CrossScore a pip-installable package Restructure the project into a proper Python package so users can install with `pip install crossscore` instead of manually configuring a conda env. Key changes: - Move model/, dataloading/, utils/ into crossscore/ package with proper __init__.py files and absolute imports (no more sys.path.append hacks) - Add pyproject.toml with flexible dependency ranges (torch>=2.0.0, not pinned) so users install PyTorch+CUDA separately per their system - Add crossscore.score() high-level API for simple inference - Add `crossscore` CLI entry point - Add auto-download of model checkpoint from HuggingFace Hub - Update task/ entry points to import from crossscore package - Update README with pip install instructions and quick start guide https://claude.ai/code/session_0114iFoswRfTkMai4JTrgMrB --- .gitignore | 9 + README.md | 68 ++- crossscore/__init__.py | 37 ++ crossscore/_download.py | 71 +++ crossscore/api.py | 206 +++++++ crossscore/cli.py | 89 +++ crossscore/config/data/SimpleReference.yaml | 22 + crossscore/config/default_predict.yaml | 50 ++ crossscore/config/model/model.yaml | 32 ++ crossscore/dataloading/__init__.py | 0 .../dataloading}/data_manager.py | 2 +- crossscore/dataloading/dataset/__init__.py | 0 .../dataloading}/dataset/nvs_dataset.py | 7 +- .../dataloading}/dataset/simple_reference.py | 7 +- .../dataloading/transformation/__init__.py | 0 .../dataloading}/transformation/crop.py | 0 crossscore/model/__init__.py | 0 .../model}/cross_reference.py | 2 +- .../model/customised_transformer/__init__.py | 0 .../customised_transformer/transformer.py | 0 .../model}/positional_encoding.py | 0 .../model}/regression_layer.py | 3 +- crossscore/task/__init__.py | 0 crossscore/task/core.py | 513 +++++++++++++++++ crossscore/utils/__init__.py | 0 {utils => crossscore/utils}/check_config.py | 0 .../split_gaussian_processed.py | 0 crossscore/utils/evaluation/__init__.py | 0 .../utils}/evaluation/metric.py | 0 .../utils}/evaluation/metric_logger.py | 0 .../utils}/evaluation/summarise_score_gt.py | 3 +- crossscore/utils/io/__init__.py | 0 .../utils}/io/batch_writer.py | 4 +- {utils => crossscore/utils}/io/images.py | 0 .../utils}/io/score_summariser.py | 4 +- crossscore/utils/misc/__init__.py | 0 {utils => crossscore/utils}/misc/image.py | 2 +- crossscore/utils/neighbour/__init__.py | 0 .../utils}/neighbour/sampler.py | 0 crossscore/utils/plot/__init__.py | 0 .../utils}/plot/batch_visualiser.py | 6 +- pyproject.toml | 100 ++++ task/core.py | 515 +----------------- task/predict.py | 11 +- task/test.py | 11 +- task/train.py | 13 +- 46 files changed, 1227 insertions(+), 560 deletions(-) create mode 100644 crossscore/__init__.py create mode 100644 crossscore/_download.py create mode 100644 crossscore/api.py create mode 100644 crossscore/cli.py create mode 100644 crossscore/config/data/SimpleReference.yaml create mode 100644 crossscore/config/default_predict.yaml create mode 100644 crossscore/config/model/model.yaml create mode 100644 crossscore/dataloading/__init__.py rename {dataloading => crossscore/dataloading}/data_manager.py (95%) create mode 100644 crossscore/dataloading/dataset/__init__.py rename {dataloading => crossscore/dataloading}/dataset/nvs_dataset.py (99%) rename {dataloading => crossscore/dataloading}/dataset/simple_reference.py (96%) create mode 100644 crossscore/dataloading/transformation/__init__.py rename {dataloading => crossscore/dataloading}/transformation/crop.py (100%) create mode 100644 crossscore/model/__init__.py rename {model => crossscore/model}/cross_reference.py (98%) create mode 100644 crossscore/model/customised_transformer/__init__.py rename {model => crossscore/model}/customised_transformer/transformer.py (100%) rename {model => crossscore/model}/positional_encoding.py (100%) rename {model => crossscore/model}/regression_layer.py (95%) create mode 100644 crossscore/task/__init__.py create mode 100644 crossscore/task/core.py create mode 100644 crossscore/utils/__init__.py rename {utils => crossscore/utils}/check_config.py (100%) rename {utils => crossscore/utils}/data_processing/split_gaussian_processed.py (100%) create mode 100644 crossscore/utils/evaluation/__init__.py rename {utils => crossscore/utils}/evaluation/metric.py (100%) rename {utils => crossscore/utils}/evaluation/metric_logger.py (100%) rename {utils => crossscore/utils}/evaluation/summarise_score_gt.py (91%) create mode 100644 crossscore/utils/io/__init__.py rename {utils => crossscore/utils}/io/batch_writer.py (98%) rename {utils => crossscore/utils}/io/images.py (100%) rename {utils => crossscore/utils}/io/score_summariser.py (99%) create mode 100644 crossscore/utils/misc/__init__.py rename {utils => crossscore/utils}/misc/image.py (98%) create mode 100644 crossscore/utils/neighbour/__init__.py rename {utils => crossscore/utils}/neighbour/sampler.py (100%) create mode 100644 crossscore/utils/plot/__init__.py rename {utils => crossscore/utils}/plot/batch_visualiser.py (99%) create mode 100644 pyproject.toml diff --git a/.gitignore b/.gitignore index de52713..ffb57d3 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,12 @@ tmp.sh log* datadir* debug* + +# Package build artifacts +dist/ +build/ +*.egg-info/ +*.egg + +# Plan file +plan.md diff --git a/README.md b/README.md index 7445e99..e32aed9 100644 --- a/README.md +++ b/README.md @@ -13,13 +13,73 @@ University of Oxford. ## Table of Content -- [Environment](#Environment) +- [Installation (pip)](#installation-pip) +- [Quick Start](#quick-start) +- [Environment (conda, legacy)](#environment-conda-legacy) - [Data](#Data) - [Training](#Training) - [Inferencing](#Inferencing) -## Environment -We provide a `environment.yaml` file to set up a `conda` environment: +## Installation (pip) + +**Step 1**: Install PyTorch with your preferred CUDA version (see [pytorch.org](https://pytorch.org/get-started/locally/)): +```bash +# Example: PyTorch with CUDA 12.1 +pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121 + +# Example: PyTorch with CUDA 11.8 +pip install torch torchvision --index-url https://download.pytorch.org/whl/cu118 + +# Example: CPU only +pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu +``` + +**Step 2**: Install CrossScore: +```bash +pip install crossscore +``` + +Or install from source: +```bash +git clone https://github.com/ActiveVisionLab/CrossScore.git +cd CrossScore +pip install -e . +``` + +> **Why install PyTorch separately?** PyTorch distributions are coupled with specific CUDA versions. By letting you install PyTorch first, we avoid version conflicts with your system's CUDA setup. CrossScore works with PyTorch 2.0+ and any CUDA version it supports. + +## Quick Start + +### Python API +```python +import crossscore + +# Score query images against reference images +results = crossscore.score( + query_dir="path/to/query/images", + reference_dir="path/to/reference/images", +) + +# The model checkpoint is auto-downloaded on first use (~129MB) +# Score maps are written to disk and returned as tensors +for score_map in results["score_maps"]: + print(score_map.shape) # (batch_size, H, W) +``` + +### Command Line +```bash +crossscore --query-dir path/to/queries --reference-dir path/to/references + +# With options +crossscore --query-dir renders/ --reference-dir gt/ --metric-type mae --batch-size 4 +``` + +### Environment Variables +- `CROSSSCORE_CKPT_PATH`: Use a specific local checkpoint instead of auto-downloading +- `CROSSSCORE_CACHE_DIR`: Custom cache directory (default: `~/.cache/crossscore`) + +## Environment (conda, legacy) +We also provide a `environment.yaml` file to set up a `conda` environment: ```bash git clone https://github.com/ActiveVisionLab/CrossScore.git cd CrossScore @@ -86,7 +146,7 @@ on our project page. - [ ] Create a HuggingFace demo page. - [ ] Release ECCV quantitative results related scripts. - [x] Release [data processing scripts](https://github.com/ziruiw-dev/CrossScore-3DGS-Preprocessing) -- [ ] Release PyPI and Conda package. +- [x] Release PyPI package. ## Acknowledgement This research is supported by an diff --git a/crossscore/__init__.py b/crossscore/__init__.py new file mode 100644 index 0000000..4b22ec9 --- /dev/null +++ b/crossscore/__init__.py @@ -0,0 +1,37 @@ +"""CrossScore: Towards Multi-View Image Evaluation and Scoring. + +A pip-installable package for neural image quality assessment using +cross-reference scoring with DINOv2 backbone. + +Example: + >>> import crossscore + >>> results = crossscore.score( + ... query_dir="path/to/query/images", + ... reference_dir="path/to/reference/images", + ... ) +""" + +__version__ = "1.0.0" + + +def score(*args, **kwargs): + """Score query images against reference images using CrossScore. + + See crossscore.api.score for full documentation. + """ + from crossscore.api import score as _score + + return _score(*args, **kwargs) + + +def get_checkpoint_path(): + """Get path to the CrossScore checkpoint, downloading if necessary. + + See crossscore._download.get_checkpoint_path for full documentation. + """ + from crossscore._download import get_checkpoint_path as _get + + return _get() + + +__all__ = ["score", "get_checkpoint_path"] diff --git a/crossscore/_download.py b/crossscore/_download.py new file mode 100644 index 0000000..134fa5d --- /dev/null +++ b/crossscore/_download.py @@ -0,0 +1,71 @@ +"""Utilities for downloading CrossScore model checkpoints.""" + +import os +from pathlib import Path + +CHECKPOINT_URL = ( + "https://huggingface.co/ActiveVisionLab/CrossScore/resolve/main/CrossScore-v1.0.0.ckpt" +) +CHECKPOINT_FILENAME = "CrossScore-v1.0.0.ckpt" + + +def get_cache_dir() -> Path: + """Return the cache directory for CrossScore model checkpoints.""" + cache_dir = Path(os.environ.get("CROSSSCORE_CACHE_DIR", Path.home() / ".cache" / "crossscore")) + cache_dir.mkdir(parents=True, exist_ok=True) + return cache_dir + + +def get_checkpoint_path() -> str: + """Get path to the CrossScore checkpoint, downloading it if necessary. + + Downloads from HuggingFace Hub on first use and caches locally. + Set CROSSSCORE_CACHE_DIR environment variable to customize cache location. + Set CROSSSCORE_CKPT_PATH to use a specific local checkpoint file. + + Returns: + Path to the checkpoint file. + """ + # Allow user to override with a custom path + custom_path = os.environ.get("CROSSSCORE_CKPT_PATH") + if custom_path: + if not Path(custom_path).exists(): + raise FileNotFoundError(f"Checkpoint not found at CROSSSCORE_CKPT_PATH={custom_path}") + return custom_path + + cache_dir = get_cache_dir() + ckpt_path = cache_dir / CHECKPOINT_FILENAME + + if ckpt_path.exists(): + return str(ckpt_path) + + print(f"Downloading CrossScore checkpoint to {ckpt_path}...") + print(f" Source: {CHECKPOINT_URL}") + print(" (Set CROSSSCORE_CKPT_PATH to use a local checkpoint instead)") + + try: + from huggingface_hub import hf_hub_download + + downloaded_path = hf_hub_download( + repo_id="ActiveVisionLab/CrossScore", + filename=CHECKPOINT_FILENAME, + local_dir=str(cache_dir), + ) + return downloaded_path + except ImportError: + # Fallback to urllib if huggingface_hub not installed + import urllib.request + import shutil + + tmp_path = str(ckpt_path) + ".tmp" + try: + with urllib.request.urlopen(CHECKPOINT_URL) as response, open(tmp_path, "wb") as out: + shutil.copyfileobj(response, out) + os.rename(tmp_path, str(ckpt_path)) + except Exception: + if os.path.exists(tmp_path): + os.remove(tmp_path) + raise + + print(f" Download complete: {ckpt_path}") + return str(ckpt_path) diff --git a/crossscore/api.py b/crossscore/api.py new file mode 100644 index 0000000..e197cd6 --- /dev/null +++ b/crossscore/api.py @@ -0,0 +1,206 @@ +"""High-level API for CrossScore image quality assessment.""" + +from pathlib import Path +from typing import Optional + +import torch +from torch.utils.data import DataLoader +from torchvision.transforms import v2 as T +from omegaconf import OmegaConf + +from crossscore._download import get_checkpoint_path +from crossscore.utils.io.images import ImageNetMeanStd +from crossscore.dataloading.dataset.simple_reference import SimpleReference +from crossscore.dataloading.transformation.crop import CropperFactory + + +def _build_config( + metric_type: str = "ssim", + metric_min: int = 0, + metric_max: int = 1, + batch_size: int = 8, + num_workers: int = 4, + resize_short_side: int = 518, + devices: Optional[list] = None, + out_dir: Optional[str] = None, +) -> OmegaConf: + """Build an OmegaConf config object for prediction.""" + if devices is None: + devices = [0] if torch.cuda.is_available() else [] + + config_dir = Path(__file__).parent / "config" + # Load base configs + base_cfg = OmegaConf.load(config_dir / "default_predict.yaml") + model_cfg = OmegaConf.load(config_dir / "model" / "model.yaml") + data_cfg = OmegaConf.load(config_dir / "data" / "SimpleReference.yaml") + + # Merge model config into base + base_cfg.model = model_cfg + base_cfg.data = data_cfg + + # Apply overrides + base_cfg.model.predict.metric.type = metric_type + base_cfg.model.predict.metric.min = metric_min + base_cfg.model.predict.metric.max = metric_max + base_cfg.data.loader.validation.batch_size = batch_size + base_cfg.data.loader.validation.num_workers = num_workers + base_cfg.trainer.devices = devices + base_cfg.trainer.precision = "16-mixed" + base_cfg.trainer.accelerator = "gpu" if torch.cuda.is_available() else "cpu" + + if out_dir is not None: + base_cfg.logger.predict.out_dir = out_dir + + return base_cfg + + +def score( + query_dir: str, + reference_dir: str, + ckpt_path: Optional[str] = None, + metric_type: str = "ssim", + batch_size: int = 8, + num_workers: int = 4, + resize_short_side: int = 518, + devices: Optional[list] = None, + out_dir: Optional[str] = None, + write_outputs: bool = True, +) -> dict: + """Score query images against reference images using CrossScore. + + Args: + query_dir: Directory containing query images (e.g., NVS rendered images). + reference_dir: Directory containing reference images (e.g., real captured images). + ckpt_path: Path to model checkpoint. Auto-downloads if not provided. + metric_type: Metric type to predict. One of "ssim", "mae", "mse". + batch_size: Batch size for inference. + num_workers: Number of data loading workers. + resize_short_side: Resize images so short side equals this value. Set to -1 to disable. + devices: List of GPU device indices. Defaults to [0] if CUDA available. + out_dir: Output directory for score maps. Defaults to a timestamped directory + under the checkpoint's parent. + write_outputs: Whether to write score maps and visualizations to disk. + + Returns: + Dictionary with: + - "score_maps": List of predicted score map tensors + - "out_dir": Output directory path (if write_outputs=True) + + Example: + >>> import crossscore + >>> results = crossscore.score( + ... query_dir="path/to/query/images", + ... reference_dir="path/to/reference/images", + ... ) + """ + import lightning + from datetime import datetime + from crossscore.task.core import CrossScoreLightningModule + + # Get checkpoint + if ckpt_path is None: + ckpt_path = get_checkpoint_path() + + # Build config + metric_min = -1 if metric_type == "ssim" else 0 + # For SSIM, CrossScore predicts in [0, 1] by default (the common sub-range) + if metric_type == "ssim": + metric_min = 0 + + cfg = _build_config( + metric_type=metric_type, + metric_min=metric_min, + batch_size=batch_size, + num_workers=num_workers, + resize_short_side=resize_short_side, + devices=devices, + out_dir=out_dir, + ) + + # Set checkpoint path + cfg.trainer.ckpt_path_to_load = ckpt_path + + # Determine output directory + if cfg.logger.predict.out_dir is None: + now = datetime.now().strftime("%Y%m%d_%H%M%S.%f") + log_dir = Path(ckpt_path).parents[1] if Path(ckpt_path).parent.name == "ckpt" else Path(".") + cfg.logger.predict.out_dir = str(log_dir / "predict" / now) + + if not write_outputs: + cfg.logger.predict.write.flag.batch = False + cfg.logger.predict.write.config.vis_img_every_n_steps = -1 + + # Set up data + lightning.seed_everything(cfg.lightning.seed, workers=True) + + img_norm_stat = ImageNetMeanStd() + transforms = { + "img": T.Normalize(mean=img_norm_stat.mean, std=img_norm_stat.std), + } + + if resize_short_side > 0: + transforms["resize"] = T.Resize( + resize_short_side, + interpolation=T.InterpolationMode.BILINEAR, + antialias=True, + ) + + dataset = SimpleReference( + query_dir=query_dir, + reference_dir=reference_dir, + transforms=transforms, + neighbour_config=cfg.data.neighbour_config, + return_item_paths=True, + zero_reference=cfg.data.dataset.zero_reference, + ) + + dataloader = DataLoader( + dataset, + batch_size=cfg.data.loader.validation.batch_size, + shuffle=False, + num_workers=cfg.data.loader.validation.num_workers, + pin_memory=True, + persistent_workers=False, + ) + + # Build model and trainer + model = CrossScoreLightningModule(cfg) + + NUM_GPUS = len(cfg.trainer.devices) + if NUM_GPUS > 1: + from lightning.pytorch.strategies import DDPStrategy + strategy = DDPStrategy(find_unused_parameters=False, static_graph=True) + use_distributed_sampler = True + else: + strategy = "auto" + use_distributed_sampler = False + + trainer = lightning.Trainer( + accelerator=cfg.trainer.accelerator, + devices=cfg.trainer.devices, + precision=cfg.trainer.precision, + strategy=strategy, + use_distributed_sampler=use_distributed_sampler, + logger=False, + ) + + # Run prediction + with torch.no_grad(): + predictions = trainer.predict( + model, + dataloader, + ckpt_path=ckpt_path, + ) + + # Collect results + score_maps = [] + if predictions: + for batch_output in predictions: + if "score_map_ref_cross" in batch_output: + score_maps.append(batch_output["score_map_ref_cross"].cpu()) + + results = {"score_maps": score_maps} + if write_outputs: + results["out_dir"] = cfg.logger.predict.out_dir + + return results diff --git a/crossscore/cli.py b/crossscore/cli.py new file mode 100644 index 0000000..c311fac --- /dev/null +++ b/crossscore/cli.py @@ -0,0 +1,89 @@ +"""Command-line interface for CrossScore.""" + +import argparse +import sys + + +def main(): + parser = argparse.ArgumentParser( + description="CrossScore: Multi-View Image Quality Assessment", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog="""\ +Examples: + crossscore --query-dir path/to/queries --reference-dir path/to/references + crossscore --query-dir renders/ --reference-dir gt/ --metric-type mae --batch-size 4 + crossscore --query-dir renders/ --reference-dir gt/ --ckpt-path my_model.ckpt +""", + ) + parser.add_argument( + "--query-dir", required=True, help="Directory containing query images" + ) + parser.add_argument( + "--reference-dir", required=True, help="Directory containing reference images" + ) + parser.add_argument( + "--ckpt-path", + default=None, + help="Path to model checkpoint (auto-downloads if not provided)", + ) + parser.add_argument( + "--metric-type", + default="ssim", + choices=["ssim", "mae", "mse"], + help="Metric type to predict (default: ssim)", + ) + parser.add_argument( + "--batch-size", type=int, default=8, help="Batch size (default: 8)" + ) + parser.add_argument( + "--num-workers", type=int, default=4, help="Data loading workers (default: 4)" + ) + parser.add_argument( + "--resize-short-side", + type=int, + default=518, + help="Resize short side to this value, -1 to disable (default: 518)", + ) + parser.add_argument( + "--devices", + type=int, + nargs="+", + default=None, + help="GPU device indices (default: [0])", + ) + parser.add_argument( + "--out-dir", + default=None, + help="Output directory for results (default: auto-generated)", + ) + parser.add_argument( + "--no-write", + action="store_true", + help="Do not write output files to disk", + ) + + args = parser.parse_args() + + from crossscore.api import score + + results = score( + query_dir=args.query_dir, + reference_dir=args.reference_dir, + ckpt_path=args.ckpt_path, + metric_type=args.metric_type, + batch_size=args.batch_size, + num_workers=args.num_workers, + resize_short_side=args.resize_short_side, + devices=args.devices, + out_dir=args.out_dir, + write_outputs=not args.no_write, + ) + + n_maps = sum(s.shape[0] for s in results["score_maps"]) if results["score_maps"] else 0 + print(f"\nCrossScore completed: {n_maps} score maps generated") + if "out_dir" in results and results["out_dir"]: + print(f"Results written to: {results['out_dir']}") + + +if __name__ == "__main__": + main() diff --git a/crossscore/config/data/SimpleReference.yaml b/crossscore/config/data/SimpleReference.yaml new file mode 100644 index 0000000..0f86c52 --- /dev/null +++ b/crossscore/config/data/SimpleReference.yaml @@ -0,0 +1,22 @@ +loader: + validation: + batch_size: 8 + num_workers: 8 + shuffle: True + pin_memory: True + persistent_workers: False + prefetch_factor: 2 + +dataset: + query_dir: null + reference_dir: null + resolution: res_540 + zero_reference: False + +neighbour_config: + strategy: random + cross: 5 + deterministic: False + +transforms: + crop_size: 518 diff --git a/crossscore/config/default_predict.yaml b/crossscore/config/default_predict.yaml new file mode 100644 index 0000000..80e1f99 --- /dev/null +++ b/crossscore/config/default_predict.yaml @@ -0,0 +1,50 @@ +defaults: + - _self_ # overriding order, see https://hydra.cc/docs/tutorials/structured_config/defaults/#a-note-about-composition-order + - data: SimpleReference + - model: model + - override hydra/hydra_logging: disabled + - override hydra/job_logging: disabled + +hydra: + output_subdir: null + run: + dir: . + +lightning: + seed: 1 + +project: + name: CrossScore + +alias: "" + +trainer: + accelerator: gpu + devices: [0] + # devices: [0, 1] + precision: 16-mixed + + limit_test_batches: 1.0 + ckpt_path_to_load: null + +logger: + predict: + out_dir: null # if null, use ckpt dir + write: + flag: + batch: True + score_map_prediction: True + item_path_json: False + score_map_gt: False + attn_weights: False + image_query: True + image_reference: True + config: + vis_img_every_n_steps: 1 # -1: off + score_map_colour_mode: rgb # gray or rgb, use rgb for vis + +this_main: + resize_short_side: 518 # set to -1 to disable + crop_mode: null # default no crop + + force_batch_size: False diff --git a/crossscore/config/model/model.yaml b/crossscore/config/model/model.yaml new file mode 100644 index 0000000..06fc8d6 --- /dev/null +++ b/crossscore/config/model/model.yaml @@ -0,0 +1,32 @@ +patch_size: 14 +do_reference_cross: True + +decoder_do_self_attn: True +decoder_do_short_cut: True +need_attn_weights: False # def False, requires more gpu mem if True +need_attn_weights_head_id: 0 # check which attn head + +backbone: + from_pretrained: facebook/dinov2-small + +pos_enc: + multi_view: + interpolate_mode: bilinear + req_grad: False + h: 40 # def 40 so we always interpolate in training, could be 37 or 16 too. + w: 40 + +loss: + fn: l1 + +predict: + metric: + type: ssim + # type: mae + # type: mse + + # min: -1 + min: 0 + max: 1 + + power_factor: default # can be a scalar \ No newline at end of file diff --git a/crossscore/dataloading/__init__.py b/crossscore/dataloading/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/dataloading/data_manager.py b/crossscore/dataloading/data_manager.py similarity index 95% rename from dataloading/data_manager.py rename to crossscore/dataloading/data_manager.py index 3c4fb8b..67018b4 100644 --- a/dataloading/data_manager.py +++ b/crossscore/dataloading/data_manager.py @@ -1,6 +1,6 @@ import torch from omegaconf import OmegaConf, ListConfig -from dataloading.dataset.nvs_dataset import NvsDataset +from crossscore.dataloading.dataset.nvs_dataset import NvsDataset from pprint import pprint diff --git a/crossscore/dataloading/dataset/__init__.py b/crossscore/dataloading/dataset/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/dataloading/dataset/nvs_dataset.py b/crossscore/dataloading/dataset/nvs_dataset.py similarity index 99% rename from dataloading/dataset/nvs_dataset.py rename to crossscore/dataloading/dataset/nvs_dataset.py index 8ebec16..8e6cc7f 100644 --- a/dataloading/dataset/nvs_dataset.py +++ b/crossscore/dataloading/dataset/nvs_dataset.py @@ -5,10 +5,9 @@ from torch.utils.data import Dataset from omegaconf import OmegaConf -sys.path.append(str(Path(__file__).parents[2])) -from utils.io.images import metric_map_read, image_read -from utils.neighbour.sampler import SamplerFactory -from utils.check_config import ConfigChecker +from crossscore.utils.io.images import metric_map_read, image_read +from crossscore.utils.neighbour.sampler import SamplerFactory +from crossscore.utils.check_config import ConfigChecker class NeighbourSelector: diff --git a/dataloading/dataset/simple_reference.py b/crossscore/dataloading/dataset/simple_reference.py similarity index 96% rename from dataloading/dataset/simple_reference.py rename to crossscore/dataloading/dataset/simple_reference.py index 3fc4722..cf72e92 100644 --- a/dataloading/dataset/simple_reference.py +++ b/crossscore/dataloading/dataset/simple_reference.py @@ -3,8 +3,7 @@ import torch from omegaconf import OmegaConf -sys.path.append(str(Path(__file__).parents[2])) -from dataloading.dataset.nvs_dataset import NvsDataset, NeighbourSelector, vis_batch +from crossscore.dataloading.dataset.nvs_dataset import NvsDataset, NeighbourSelector, vis_batch class SimpleReference(NvsDataset): @@ -89,8 +88,8 @@ def get_paths(query_dir, reference_dir): from lightning import seed_everything from torchvision.transforms import v2 as T from tqdm import tqdm - from dataloading.transformation.crop import CropperFactory - from utils.io.images import ImageNetMeanStd + from crossscore.dataloading.transformation.crop import CropperFactory + from crossscore.utils.io.images import ImageNetMeanStd from omegaconf import OmegaConf seed_everything(1) diff --git a/crossscore/dataloading/transformation/__init__.py b/crossscore/dataloading/transformation/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/dataloading/transformation/crop.py b/crossscore/dataloading/transformation/crop.py similarity index 100% rename from dataloading/transformation/crop.py rename to crossscore/dataloading/transformation/crop.py diff --git a/crossscore/model/__init__.py b/crossscore/model/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/model/cross_reference.py b/crossscore/model/cross_reference.py similarity index 98% rename from model/cross_reference.py rename to crossscore/model/cross_reference.py index a952d78..32879bd 100644 --- a/model/cross_reference.py +++ b/crossscore/model/cross_reference.py @@ -4,7 +4,7 @@ TransformerDecoderCustomised, ) from .regression_layer import RegressionLayer -from utils.misc.image import jigsaw_to_image +from crossscore.utils.misc.image import jigsaw_to_image class CrossReferenceNet(torch.nn.Module): diff --git a/crossscore/model/customised_transformer/__init__.py b/crossscore/model/customised_transformer/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/model/customised_transformer/transformer.py b/crossscore/model/customised_transformer/transformer.py similarity index 100% rename from model/customised_transformer/transformer.py rename to crossscore/model/customised_transformer/transformer.py diff --git a/model/positional_encoding.py b/crossscore/model/positional_encoding.py similarity index 100% rename from model/positional_encoding.py rename to crossscore/model/positional_encoding.py diff --git a/model/regression_layer.py b/crossscore/model/regression_layer.py similarity index 95% rename from model/regression_layer.py rename to crossscore/model/regression_layer.py index e651048..f943afd 100644 --- a/model/regression_layer.py +++ b/crossscore/model/regression_layer.py @@ -3,8 +3,7 @@ from pathlib import Path import torch -sys.path.append(str(Path(__file__).parents[1])) -from utils.check_config import check_metric_prediction_config +from crossscore.utils.check_config import check_metric_prediction_config class RegressionLayer(torch.nn.Module): diff --git a/crossscore/task/__init__.py b/crossscore/task/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/crossscore/task/core.py b/crossscore/task/core.py new file mode 100644 index 0000000..5218ad2 --- /dev/null +++ b/crossscore/task/core.py @@ -0,0 +1,513 @@ +from pathlib import Path +import torch +import lightning +import wandb +from transformers import Dinov2Config, Dinov2Model +from omegaconf import DictConfig, OmegaConf +from lightning.pytorch.utilities import rank_zero_only +from crossscore.utils.evaluation.metric import abs2psnr, correlation +from crossscore.utils.evaluation.metric_logger import ( + MetricLoggerScalar, + MetricLoggerHistogram, + MetricLoggerCorrelation, + MetricLoggerImg, +) +from crossscore.utils.plot.batch_visualiser import BatchVisualiserFactory +from crossscore.utils.io.images import ImageNetMeanStd +from crossscore.utils.io.batch_writer import BatchWriter +from crossscore.utils.io.score_summariser import ( + SummaryWriterPredictedOnline, + SummaryWriterPredictedOnlineTestPrediction, +) +from crossscore.model.cross_reference import CrossReferenceNet +from crossscore.model.positional_encoding import MultiViewPosionalEmbeddings + + +class CrossScoreNet(torch.nn.Module): + def __init__(self, cfg): + super().__init__() + self.cfg = cfg + + # used in 1. denormalising images for visualisation + # and 2. normalising images for training when required + img_norm_stat = ImageNetMeanStd() + self.register_buffer( + "img_mean_std", torch.tensor([*img_norm_stat.mean, *img_norm_stat.std]) + ) + + # backbone, freeze + self.dinov2_cfg = Dinov2Config.from_pretrained(self.cfg.model.backbone.from_pretrained) + self.backbone = Dinov2Model.from_pretrained(self.cfg.model.backbone.from_pretrained) + for param in self.backbone.parameters(): + param.requires_grad = False + + # positional encoding layer + self.pos_enc_fn = MultiViewPosionalEmbeddings( + positional_encoding_h=self.cfg.model.pos_enc.multi_view.h, + positional_encoding_w=self.cfg.model.pos_enc.multi_view.w, + interpolate_mode=self.cfg.model.pos_enc.multi_view.interpolate_mode, + req_grad=self.cfg.model.pos_enc.multi_view.req_grad, + patch_size=self.cfg.model.patch_size, + hidden_size=self.dinov2_cfg.hidden_size, + ) + + # cross reference predictor + if self.cfg.model.do_reference_cross: + self.ref_cross = CrossReferenceNet(cfg=self.cfg, dinov2_cfg=self.dinov2_cfg) + + def forward( + self, + query_img, + ref_cross_imgs, + need_attn_weights, + need_attn_weights_head_id, + norm_img, + ): + """ + :param query_img: (B, 3, H, W) + :param ref_cross_imgs: (B, N_ref_cross, 3, H, W) + :param norm_img: bool, normalise an image with pixel value in [0, 1] with imagenet mean and std. + """ + B = query_img.shape[0] + H, W = query_img.shape[-2:] + N_patch_h = H // self.cfg.model.patch_size + N_patch_w = W // self.cfg.model.patch_size + + if norm_img: + img_mean = self.img_mean_std[None, :3, None, None] + img_std = self.img_mean_std[None, 3:, None, None] + query_img = (query_img - img_mean) / img_std + if ref_cross_imgs is not None: + ref_cross_imgs = (ref_cross_imgs - img_mean[:, None]) / img_std[:, None] + + featmaps = self.get_featmaps(query_img, ref_cross_imgs) + results = {} + + # processing (and predicting) for query + featmaps["query"] = self.pos_enc_fn(featmaps["query"], N_view=1, img_h=H, img_w=W) + + if self.cfg.model.do_reference_cross: + N_ref_cross = ref_cross_imgs.shape[1] + + # (B, N_ref_cross*num_patches, hidden_size) + featmaps["ref_cross"] = self.pos_enc_fn( + featmaps["ref_cross"], + N_view=N_ref_cross, + img_h=H, + img_w=W, + ) + + # prediction + dim_params = { + "B": B, + "N_patch_h": N_patch_h, + "N_patch_w": N_patch_w, + "N_ref": N_ref_cross, + } + results_ref_cross = self.ref_cross( + featmaps["query"], + featmaps["ref_cross"], + None, + dim_params, + need_attn_weights, + need_attn_weights_head_id, + ) + results["score_map_ref_cross"] = results_ref_cross["score_map"] + results["attn_weights_map_ref_cross"] = results_ref_cross["attn_weights_map_mha"] + return results + + @torch.no_grad() + def get_featmaps(self, query_img, ref_cross_imgs): + """ + :param query_img: (B, 3, H, W) + :param ref_cross: (B, N_ref_cross, 3, H, W) + """ + B = query_img.shape[0] + H, W = query_img.shape[-2:] + N_patch_h = H // self.cfg.model.patch_size + N_patch_w = W // self.cfg.model.patch_size + N_query = 1 + N_ref_cross = 0 if ref_cross_imgs is None else ref_cross_imgs.shape[1] + N_all_imgs = N_query + N_ref_cross + + # concat all images to go through backbone for once + all_imgs = [query_img.view(B, 1, 3, H, W)] + if ref_cross_imgs is not None: + all_imgs.append(ref_cross_imgs) + all_imgs = torch.cat(all_imgs, dim=1) + all_imgs = all_imgs.view(B * N_all_imgs, 3, H, W) + + # bbo: backbone output + bbo_all = self.backbone(all_imgs) + featmap_all = bbo_all.last_hidden_state[:, 1:] + featmap_all = featmap_all.view(B, N_all_imgs, N_patch_h * N_patch_w, -1) + + # query + featmap_query = featmap_all[:, 0] # (B, num_patches, hidden_size) + N_patches = featmap_query.shape[1] + hidden_size = featmap_query.shape[2] + + # cross ref + if ref_cross_imgs is not None: + featmap_ref_cross = featmap_all[:, -N_ref_cross:] + featmap_ref_cross = featmap_ref_cross.reshape(B, N_ref_cross * N_patches, hidden_size) + else: + featmap_ref_cross = None + + featmaps = { + "query": featmap_query, # (B, num_patches, hidden_size) + "ref_cross": featmap_ref_cross, # (B, N_ref_cross*num_patches, hidden_size) + } + return featmaps + + +class CrossScoreLightningModule(lightning.LightningModule): + def __init__(self, cfg: DictConfig): + super().__init__() + self.cfg = cfg + + # write config to wandb + self.save_hyperparameters(OmegaConf.to_container(self.cfg, resolve=True)) + + # init my network + self.model = CrossScoreNet(cfg=self.cfg) + + # init visualiser + self.visualiser = BatchVisualiserFactory(self.cfg, self.model.img_mean_std)() + + # init loss fn + if self.cfg.model.loss.fn == "l1": + self.loss_fn = torch.nn.L1Loss() + self.to_psnr_fn = abs2psnr + else: + raise NotImplementedError + + # logging related names + self.ref_mode_names = [] + if self.cfg.model.do_reference_cross: + self.ref_mode_names.append("ref_cross") + + def on_fit_start(self): + # reset logging cache + if self.global_rank == 0: + self._reset_logging_cache_train() + self._reset_logging_cache_validation() + + self.frame_score_summariser = SummaryWriterPredictedOnline( + metric_type=self.cfg.model.predict.metric.type, + metric_min=self.cfg.model.predict.metric.min, + ) + + def on_test_start(self): + Path(self.cfg.logger.test.out_dir, "vis").mkdir(parents=True, exist_ok=True) + if self.cfg.logger.test.write.flag.batch: + self.batch_writer = BatchWriter(self.cfg, "test", self.model.img_mean_std) + else: + self.batch_writer = None + + self.frame_score_summariser = SummaryWriterPredictedOnlineTestPrediction( + metric_type=self.cfg.model.predict.metric.type, + metric_min=self.cfg.model.predict.metric.min, + dir_out=self.cfg.logger.test.out_dir, + ) + + def on_predict_start(self): + Path(self.cfg.logger.predict.out_dir, "vis").mkdir(parents=True, exist_ok=True) + if self.cfg.logger.predict.write.flag.batch: + self.batch_writer = BatchWriter(self.cfg, "predict", self.model.img_mean_std) + else: + self.batch_writer = None + + self.frame_score_summariser = SummaryWriterPredictedOnlineTestPrediction( + metric_type=self.cfg.model.predict.metric.type, + metric_min=self.cfg.model.predict.metric.min, + dir_out=self.cfg.logger.predict.out_dir, + ) + + def _reset_logging_cache_train(self): + self.train_cache = { + "loss": { + k: MetricLoggerScalar(max_length=self.cfg.logger.cache_size.train.n_scalar) + for k in ["final", "reg_self", "reg_cross"] + self.ref_mode_names + }, + "correlation": { + k: MetricLoggerCorrelation(max_length=self.cfg.logger.cache_size.train.n_scalar) + for k in self.ref_mode_names + }, + "map": { + "score": { + k: MetricLoggerHistogram(max_length=self.cfg.logger.cache_size.train.n_scalar) + for k in self.ref_mode_names + }, + "l1_diff": { + k: MetricLoggerHistogram(max_length=self.cfg.logger.cache_size.train.n_scalar) + for k in self.ref_mode_names + }, + "delta": { + k: MetricLoggerHistogram(max_length=self.cfg.logger.cache_size.train.n_scalar) + for k in ["self", "cross"] + }, + }, + } + + def _reset_logging_cache_validation(self): + self.validation_cache = { + "loss": { + k: MetricLoggerScalar(max_length=None) + for k in ["final", "reg_self", "reg_cross"] + self.ref_mode_names + }, + "correlation": { + k: MetricLoggerCorrelation(max_length=None) for k in self.ref_mode_names + }, + "fig": {k: MetricLoggerImg(max_length=None) for k in ["batch"]}, + } + + def _core_step(self, batch, batch_idx, skip_loss=False): + outputs = self.model( + query_img=batch["query/img"], # (B, C, H, W) + ref_cross_imgs=batch.get("reference/cross/imgs", None), # (B, N_ref_cross, C, H, W) + need_attn_weights=self.cfg.model.need_attn_weights, + need_attn_weights_head_id=self.cfg.model.need_attn_weights_head_id, + norm_img=False, + ) + + if skip_loss: # only used in predict_step + return outputs + + score_map = batch["query/score_map"] # (B, H, W) + + loss = [] + # cross reference model predicts + if self.cfg.model.do_reference_cross: + score_map_cross = outputs["score_map_ref_cross"] # (B, H, W) + l1_diff_map_cross = torch.abs(score_map_cross - score_map) # (B, H, W) + if self.cfg.model.loss.fn == "l1": + loss_cross = l1_diff_map_cross.mean() + else: + loss_cross = self.loss_fn(score_map_cross, score_map) + outputs["loss_cross"] = loss_cross + outputs["l1_diff_map_ref_cross"] = l1_diff_map_cross + loss.append(loss_cross) + + loss = torch.stack(loss).sum() + outputs["loss"] = loss + return outputs + + def training_step(self, batch, batch_idx): + outputs = self._core_step(batch, batch_idx) + return outputs + + def validation_step(self, batch, batch_idx): + outputs = self._core_step(batch, batch_idx) + return outputs + + def test_step(self, batch, batch_idx): + outputs = self._core_step(batch, batch_idx) + return outputs + + def predict_step(self, batch, batch_idx): + outputs = self._core_step(batch, batch_idx, skip_loss=True) + return outputs + + @rank_zero_only + def on_train_batch_end(self, outputs, batch, batch_idx): + self.train_cache["loss"]["final"].update(outputs["loss"]) + + if self.cfg.model.do_reference_cross: + self.train_cache["loss"]["ref_cross"].update(outputs["loss_cross"]) + self.train_cache["correlation"]["ref_cross"].update( + outputs["score_map_ref_cross"], batch["query/score_map"] + ) + self.train_cache["map"]["score"]["ref_cross"].update(outputs["score_map_ref_cross"]) + self.train_cache["map"]["l1_diff"]["ref_cross"].update(outputs["l1_diff_map_ref_cross"]) + + # logger vis batch + if self.global_step % self.cfg.logger.vis_imgs_every_n_train_steps == 0: + fig = self.visualiser.vis(batch, outputs) + self.logger.experiment.log({"train_batch": fig}) + + # logger vis X batches statics + if self.global_step % self.cfg.logger.vis_scalar_every_n_train_steps == 0: + # log loss + tmp_loss = self.train_cache["loss"]["final"].compute() + self.log("train/loss", tmp_loss, prog_bar=True) + + if self.cfg.model.do_reference_cross: + tmp_loss_cross = self.train_cache["loss"]["ref_cross"].compute() + self.log("train/loss_cross", tmp_loss_cross) + + # log psnr + if self.cfg.model.do_reference_cross: + self.log("train/psnr_cross", self.to_psnr_fn(tmp_loss_cross)) + + # log correlation + if self.cfg.model.do_reference_cross: + self.log( + "train/correlation_cross", + self.train_cache["correlation"]["ref_cross"].compute(), + ) + + # logger vis X batches histogram + if self.global_step % self.cfg.logger.vis_histogram_every_n_train_steps == 0: + if self.cfg.model.do_reference_cross: + self.logger.experiment.log( + { + "train/score_histogram_cross": wandb.Histogram( + np_histogram=self.train_cache["map"]["score"]["ref_cross"].compute() + ), + "train/l1_diff_histogram_cross": wandb.Histogram( + np_histogram=self.train_cache["map"]["l1_diff"]["ref_cross"].compute() + ), + } + ) + + def on_validation_batch_end(self, outputs, batch, batch_idx): + self.validation_cache["loss"]["final"].update(outputs["loss"]) + + if self.cfg.model.do_reference_cross: + self.validation_cache["loss"]["ref_cross"].update(outputs["loss_cross"]) + self.validation_cache["correlation"]["ref_cross"].update( + outputs["score_map_ref_cross"], batch["query/score_map"] + ) + + self.frame_score_summariser.update(batch_input=batch, batch_output=outputs) + + if batch_idx < self.cfg.logger.cache_size.validation.n_fig: + fig = self.visualiser.vis(batch, outputs) + self.validation_cache["fig"]["batch"].update(fig) + + def on_test_batch_end(self, outputs, batch, batch_idx): + results = {"test/loss": outputs["loss"]} + + if self.cfg.model.do_reference_cross: + corr = correlation(outputs["score_map_ref_cross"], batch["query/score_map"]) + psnr = self.to_psnr_fn(outputs["loss_cross"]) + results["test/loss_cross"] = outputs["loss_cross"] + results["test/corr_cross"] = corr + results["test/psnr_cross"] = psnr + + self.log_dict( + results, + on_step=self.cfg.logger.test.on_step, + sync_dist=self.cfg.logger.test.sync_dist, + ) + + self.frame_score_summariser.update(batch_input=batch, batch_output=outputs) + + # write image to vis + if ( + self.cfg.logger.test.write.config.vis_img_every_n_steps > 0 + and batch_idx % self.cfg.logger.test.write.config.vis_img_every_n_steps == 0 + ): + fig = self.visualiser.vis(batch, outputs) + fig.image.save( + Path( + self.cfg.logger.test.out_dir, + "vis", + f"r{self.local_rank}_B{str(batch_idx).zfill(4)}_b{0}.png", + ) + ) + + if self.cfg.logger.test.write.flag.batch: + self.batch_writer.write_out( + batch_input=batch, + batch_output=outputs, + local_rank=self.local_rank, + batch_idx=batch_idx, + ) + + def on_predict_batch_end(self, outputs, batch, batch_idx): + self.frame_score_summariser.update(batch_input=batch, batch_output=outputs) + + # write image to vis + if ( + self.cfg.logger.predict.write.config.vis_img_every_n_steps > 0 + and batch_idx % self.cfg.logger.predict.write.config.vis_img_every_n_steps == 0 + ): + fig = self.visualiser.vis(batch, outputs) + fig.image.save( + Path( + self.cfg.logger.predict.out_dir, + "vis", + f"r{self.local_rank}_B{str(batch_idx).zfill(4)}_b{0}.png", + ) + ) + + if self.cfg.logger.predict.write.flag.batch: + self.batch_writer.write_out( + batch_input=batch, + batch_output=outputs, + local_rank=self.local_rank, + batch_idx=batch_idx, + ) + + @rank_zero_only + def on_train_epoch_end(self): + self._reset_logging_cache_train() + + def on_validation_epoch_end(self): + sync_dist = True + self.log( + "validation/loss", + self.validation_cache["loss"]["final"].compute(), + prog_bar=True, + sync_dist=sync_dist, + ) + self.logger.experiment.log( + {"validation_batch": self.validation_cache["fig"]["batch"].compute()}, + ) + + if self.cfg.model.do_reference_cross: + self.log( + "validation/loss_cross", + self.validation_cache["loss"]["ref_cross"].compute(), + sync_dist=sync_dist, + ) + self.log( + "validation/correlation_cross", + self.validation_cache["correlation"]["ref_cross"].compute(), + sync_dist=sync_dist, + ) + self.log( + "validation/psnr_cross", + self.to_psnr_fn(self.validation_cache["loss"]["ref_cross"].compute()), + sync_dist=sync_dist, + ) + + self._reset_logging_cache_validation() + self.frame_score_summariser.reset() + + def on_test_epoch_end(self): + self.frame_score_summariser.summarise() + + def on_predict_epoch_end(self): + self.frame_score_summariser.summarise() + + def configure_optimizers(self): + # how to use configure_optimizers: + # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.core.LightningModule.html#lightning.pytorch.core.LightningModule.configure_optimizers + + # we freeze backbone and we only pass parameters that requires grad to optimizer: + # https://discuss.pytorch.org/t/how-to-train-a-part-of-a-network/8923 + # https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html#convnet-as-fixed-feature-extractor + # https://discuss.pytorch.org/t/for-freezing-certain-layers-why-do-i-need-a-two-step-process/175289/2 + parameters = [p for p in self.model.parameters() if p.requires_grad] + optimizer = torch.optim.AdamW( + params=parameters, + lr=self.cfg.trainer.optimizer.lr, + ) + lr_scheduler = torch.optim.lr_scheduler.StepLR( + optimizer, + step_size=self.cfg.trainer.lr_scheduler.step_size, + gamma=self.cfg.trainer.lr_scheduler.gamma, + ) + + results = { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": lr_scheduler, + "interval": self.cfg.trainer.lr_scheduler.step_interval, + "frequency": 1, + }, + } + return results diff --git a/crossscore/utils/__init__.py b/crossscore/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/utils/check_config.py b/crossscore/utils/check_config.py similarity index 100% rename from utils/check_config.py rename to crossscore/utils/check_config.py diff --git a/utils/data_processing/split_gaussian_processed.py b/crossscore/utils/data_processing/split_gaussian_processed.py similarity index 100% rename from utils/data_processing/split_gaussian_processed.py rename to crossscore/utils/data_processing/split_gaussian_processed.py diff --git a/crossscore/utils/evaluation/__init__.py b/crossscore/utils/evaluation/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/utils/evaluation/metric.py b/crossscore/utils/evaluation/metric.py similarity index 100% rename from utils/evaluation/metric.py rename to crossscore/utils/evaluation/metric.py diff --git a/utils/evaluation/metric_logger.py b/crossscore/utils/evaluation/metric_logger.py similarity index 100% rename from utils/evaluation/metric_logger.py rename to crossscore/utils/evaluation/metric_logger.py diff --git a/utils/evaluation/summarise_score_gt.py b/crossscore/utils/evaluation/summarise_score_gt.py similarity index 91% rename from utils/evaluation/summarise_score_gt.py rename to crossscore/utils/evaluation/summarise_score_gt.py index f9e8f75..28a64e5 100644 --- a/utils/evaluation/summarise_score_gt.py +++ b/crossscore/utils/evaluation/summarise_score_gt.py @@ -2,8 +2,7 @@ import sys from pathlib import Path -sys.path.append(str(Path(__file__).parents[2])) -from utils.io.score_summariser import SummaryWriterGroundTruth +from crossscore.utils.io.score_summariser import SummaryWriterGroundTruth def parse_args(): diff --git a/crossscore/utils/io/__init__.py b/crossscore/utils/io/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/utils/io/batch_writer.py b/crossscore/utils/io/batch_writer.py similarity index 98% rename from utils/io/batch_writer.py rename to crossscore/utils/io/batch_writer.py index d6b4ce9..6442dba 100644 --- a/utils/io/batch_writer.py +++ b/crossscore/utils/io/batch_writer.py @@ -2,8 +2,8 @@ from pathlib import Path from PIL import Image import numpy as np -from utils.io.images import metric_map_write, u8 -from utils.misc.image import gray2rgb, attn2rgb, de_norm_img +from crossscore.utils.io.images import metric_map_write, u8 +from crossscore.utils.misc.image import gray2rgb, attn2rgb, de_norm_img def get_vrange(predict_metric_type, predict_metric_min, predict_metric_max): diff --git a/utils/io/images.py b/crossscore/utils/io/images.py similarity index 100% rename from utils/io/images.py rename to crossscore/utils/io/images.py diff --git a/utils/io/score_summariser.py b/crossscore/utils/io/score_summariser.py similarity index 99% rename from utils/io/score_summariser.py rename to crossscore/utils/io/score_summariser.py index fc34972..9b1d49b 100644 --- a/utils/io/score_summariser.py +++ b/crossscore/utils/io/score_summariser.py @@ -9,8 +9,8 @@ from pandas import DataFrame from tqdm import tqdm -from utils.io.images import metric_map_read -from utils.evaluation.metric import mse2psnr +from crossscore.utils.io.images import metric_map_read +from crossscore.utils.evaluation.metric import mse2psnr class ScoreReader(Dataset): diff --git a/crossscore/utils/misc/__init__.py b/crossscore/utils/misc/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/utils/misc/image.py b/crossscore/utils/misc/image.py similarity index 98% rename from utils/misc/image.py rename to crossscore/utils/misc/image.py index 73b2e9c..614f826 100644 --- a/utils/misc/image.py +++ b/crossscore/utils/misc/image.py @@ -2,7 +2,7 @@ import numpy as np import matplotlib.cm as cm from PIL import Image, ImageDraw, ImageFont -from utils.io.images import u8 +from crossscore.utils.io.images import u8 def jigsaw_to_image(x, grid_size): diff --git a/crossscore/utils/neighbour/__init__.py b/crossscore/utils/neighbour/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/utils/neighbour/sampler.py b/crossscore/utils/neighbour/sampler.py similarity index 100% rename from utils/neighbour/sampler.py rename to crossscore/utils/neighbour/sampler.py diff --git a/crossscore/utils/plot/__init__.py b/crossscore/utils/plot/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/utils/plot/batch_visualiser.py b/crossscore/utils/plot/batch_visualiser.py similarity index 99% rename from utils/plot/batch_visualiser.py rename to crossscore/utils/plot/batch_visualiser.py index 668e04f..a194ccc 100644 --- a/utils/plot/batch_visualiser.py +++ b/crossscore/utils/plot/batch_visualiser.py @@ -7,9 +7,9 @@ from matplotlib.patches import Rectangle from torchvision.utils import make_grid from PIL import Image -from utils.io.images import u8 -from utils.misc.image import de_norm_img -from utils.check_config import check_reference_type +from crossscore.utils.io.images import u8 +from crossscore.utils.misc.image import de_norm_img +from crossscore.utils.check_config import check_reference_type class BatchVisualiserBase(ABC): diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..c6fadf4 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,100 @@ +[build-system] +requires = ["setuptools>=68.0", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "crossscore" +version = "1.0.0" +description = "CrossScore: Towards Multi-View Image Evaluation and Scoring" +readme = "README.md" +license = {text = "Apache-2.0"} +requires-python = ">=3.9" +authors = [ + {name = "Zirui Wang"}, + {name = "Wenjing Bian"}, + {name = "Victor Adrian Prisacariu"}, +] +keywords = ["image quality assessment", "neural rendering", "novel view synthesis", "deep learning"] +classifiers = [ + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Scientific/Engineering :: Image Processing", +] + +# NOTE on PyTorch/CUDA: +# We require torch>=2.0.0 but do NOT pin a specific CUDA version. +# Users should install PyTorch with their preferred CUDA version FIRST: +# pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121 +# Then install CrossScore: +# pip install crossscore +# +# This avoids CUDA version conflicts since PyTorch+CUDA combos vary by system. +dependencies = [ + # Core deep learning (flexible ranges - user installs PyTorch first) + "torch>=2.0.0", + "torchvision>=0.15.0", + "lightning>=2.0.0", + + # DINOv2 backbone + "transformers>=4.30.0", + + # Configuration + "hydra-core>=1.3.0", + "omegaconf>=2.3.0", + + # Model download + "huggingface-hub>=0.20.0", + + # Image processing + "Pillow>=9.0.0", + "imageio>=2.20.0", + "numpy>=1.22.0", + "scipy>=1.9.0", + "scikit-image>=0.19.0", + "matplotlib>=3.5.0", + + # Data + "pandas>=1.4.0", + "tqdm>=4.60.0", +] + +[project.optional-dependencies] +# For training (additional deps not needed for inference) +train = [ + "wandb>=0.15.0", + "scikit-learn>=1.0.0", + "accelerate>=0.20.0", + "tensorboard>=2.10.0", +] +# Memory-efficient attention (optional, for large images) +xformers = [ + "xformers>=0.0.20", +] +dev = [ + "pytest>=7.0", + "ruff>=0.1.0", +] + +[project.urls] +Homepage = "https://crossscore.active.vision" +Repository = "https://github.com/ActiveVisionLab/CrossScore" +Paper = "https://arxiv.org/abs/2404.14409" + +[project.scripts] +crossscore = "crossscore.cli:main" + +[tool.setuptools.packages.find] +include = ["crossscore*"] + +[tool.setuptools.package-data] +crossscore = [ + "config/**/*.yaml", +] diff --git a/task/core.py b/task/core.py index 80de8be..3eb8c23 100644 --- a/task/core.py +++ b/task/core.py @@ -1,513 +1,4 @@ -from pathlib import Path -import torch -import lightning -import wandb -from transformers import Dinov2Config, Dinov2Model -from omegaconf import DictConfig, OmegaConf -from lightning.pytorch.utilities import rank_zero_only -from utils.evaluation.metric import abs2psnr, correlation -from utils.evaluation.metric_logger import ( - MetricLoggerScalar, - MetricLoggerHistogram, - MetricLoggerCorrelation, - MetricLoggerImg, -) -from utils.plot.batch_visualiser import BatchVisualiserFactory -from utils.io.images import ImageNetMeanStd -from utils.io.batch_writer import BatchWriter -from utils.io.score_summariser import ( - SummaryWriterPredictedOnline, - SummaryWriterPredictedOnlineTestPrediction, -) -from model.cross_reference import CrossReferenceNet -from model.positional_encoding import MultiViewPosionalEmbeddings +"""Backwards compatibility: re-export from crossscore package.""" +from crossscore.task.core import CrossScoreLightningModule, CrossScoreNet - -class CrossScoreNet(torch.nn.Module): - def __init__(self, cfg): - super().__init__() - self.cfg = cfg - - # used in 1. denormalising images for visualisation - # and 2. normalising images for training when required - img_norm_stat = ImageNetMeanStd() - self.register_buffer( - "img_mean_std", torch.tensor([*img_norm_stat.mean, *img_norm_stat.std]) - ) - - # backbone, freeze - self.dinov2_cfg = Dinov2Config.from_pretrained(self.cfg.model.backbone.from_pretrained) - self.backbone = Dinov2Model.from_pretrained(self.cfg.model.backbone.from_pretrained) - for param in self.backbone.parameters(): - param.requires_grad = False - - # positional encoding layer - self.pos_enc_fn = MultiViewPosionalEmbeddings( - positional_encoding_h=self.cfg.model.pos_enc.multi_view.h, - positional_encoding_w=self.cfg.model.pos_enc.multi_view.w, - interpolate_mode=self.cfg.model.pos_enc.multi_view.interpolate_mode, - req_grad=self.cfg.model.pos_enc.multi_view.req_grad, - patch_size=self.cfg.model.patch_size, - hidden_size=self.dinov2_cfg.hidden_size, - ) - - # cross reference predictor - if self.cfg.model.do_reference_cross: - self.ref_cross = CrossReferenceNet(cfg=self.cfg, dinov2_cfg=self.dinov2_cfg) - - def forward( - self, - query_img, - ref_cross_imgs, - need_attn_weights, - need_attn_weights_head_id, - norm_img, - ): - """ - :param query_img: (B, 3, H, W) - :param ref_cross_imgs: (B, N_ref_cross, 3, H, W) - :param norm_img: bool, normalise an image with pixel value in [0, 1] with imagenet mean and std. - """ - B = query_img.shape[0] - H, W = query_img.shape[-2:] - N_patch_h = H // self.cfg.model.patch_size - N_patch_w = W // self.cfg.model.patch_size - - if norm_img: - img_mean = self.img_mean_std[None, :3, None, None] - img_std = self.img_mean_std[None, 3:, None, None] - query_img = (query_img - img_mean) / img_std - if ref_cross_imgs is not None: - ref_cross_imgs = (ref_cross_imgs - img_mean[:, None]) / img_std[:, None] - - featmaps = self.get_featmaps(query_img, ref_cross_imgs) - results = {} - - # processing (and predicting) for query - featmaps["query"] = self.pos_enc_fn(featmaps["query"], N_view=1, img_h=H, img_w=W) - - if self.cfg.model.do_reference_cross: - N_ref_cross = ref_cross_imgs.shape[1] - - # (B, N_ref_cross*num_patches, hidden_size) - featmaps["ref_cross"] = self.pos_enc_fn( - featmaps["ref_cross"], - N_view=N_ref_cross, - img_h=H, - img_w=W, - ) - - # prediction - dim_params = { - "B": B, - "N_patch_h": N_patch_h, - "N_patch_w": N_patch_w, - "N_ref": N_ref_cross, - } - results_ref_cross = self.ref_cross( - featmaps["query"], - featmaps["ref_cross"], - None, - dim_params, - need_attn_weights, - need_attn_weights_head_id, - ) - results["score_map_ref_cross"] = results_ref_cross["score_map"] - results["attn_weights_map_ref_cross"] = results_ref_cross["attn_weights_map_mha"] - return results - - @torch.no_grad() - def get_featmaps(self, query_img, ref_cross_imgs): - """ - :param query_img: (B, 3, H, W) - :param ref_cross: (B, N_ref_cross, 3, H, W) - """ - B = query_img.shape[0] - H, W = query_img.shape[-2:] - N_patch_h = H // self.cfg.model.patch_size - N_patch_w = W // self.cfg.model.patch_size - N_query = 1 - N_ref_cross = 0 if ref_cross_imgs is None else ref_cross_imgs.shape[1] - N_all_imgs = N_query + N_ref_cross - - # concat all images to go through backbone for once - all_imgs = [query_img.view(B, 1, 3, H, W)] - if ref_cross_imgs is not None: - all_imgs.append(ref_cross_imgs) - all_imgs = torch.cat(all_imgs, dim=1) - all_imgs = all_imgs.view(B * N_all_imgs, 3, H, W) - - # bbo: backbone output - bbo_all = self.backbone(all_imgs) - featmap_all = bbo_all.last_hidden_state[:, 1:] - featmap_all = featmap_all.view(B, N_all_imgs, N_patch_h * N_patch_w, -1) - - # query - featmap_query = featmap_all[:, 0] # (B, num_patches, hidden_size) - N_patches = featmap_query.shape[1] - hidden_size = featmap_query.shape[2] - - # cross ref - if ref_cross_imgs is not None: - featmap_ref_cross = featmap_all[:, -N_ref_cross:] - featmap_ref_cross = featmap_ref_cross.reshape(B, N_ref_cross * N_patches, hidden_size) - else: - featmap_ref_cross = None - - featmaps = { - "query": featmap_query, # (B, num_patches, hidden_size) - "ref_cross": featmap_ref_cross, # (B, N_ref_cross*num_patches, hidden_size) - } - return featmaps - - -class CrossScoreLightningModule(lightning.LightningModule): - def __init__(self, cfg: DictConfig): - super().__init__() - self.cfg = cfg - - # write config to wandb - self.save_hyperparameters(OmegaConf.to_container(self.cfg, resolve=True)) - - # init my network - self.model = CrossScoreNet(cfg=self.cfg) - - # init visualiser - self.visualiser = BatchVisualiserFactory(self.cfg, self.model.img_mean_std)() - - # init loss fn - if self.cfg.model.loss.fn == "l1": - self.loss_fn = torch.nn.L1Loss() - self.to_psnr_fn = abs2psnr - else: - raise NotImplementedError - - # logging related names - self.ref_mode_names = [] - if self.cfg.model.do_reference_cross: - self.ref_mode_names.append("ref_cross") - - def on_fit_start(self): - # reset logging cache - if self.global_rank == 0: - self._reset_logging_cache_train() - self._reset_logging_cache_validation() - - self.frame_score_summariser = SummaryWriterPredictedOnline( - metric_type=self.cfg.model.predict.metric.type, - metric_min=self.cfg.model.predict.metric.min, - ) - - def on_test_start(self): - Path(self.cfg.logger.test.out_dir, "vis").mkdir(parents=True, exist_ok=True) - if self.cfg.logger.test.write.flag.batch: - self.batch_writer = BatchWriter(self.cfg, "test", self.model.img_mean_std) - else: - self.batch_writer = None - - self.frame_score_summariser = SummaryWriterPredictedOnlineTestPrediction( - metric_type=self.cfg.model.predict.metric.type, - metric_min=self.cfg.model.predict.metric.min, - dir_out=self.cfg.logger.test.out_dir, - ) - - def on_predict_start(self): - Path(self.cfg.logger.predict.out_dir, "vis").mkdir(parents=True, exist_ok=True) - if self.cfg.logger.predict.write.flag.batch: - self.batch_writer = BatchWriter(self.cfg, "predict", self.model.img_mean_std) - else: - self.batch_writer = None - - self.frame_score_summariser = SummaryWriterPredictedOnlineTestPrediction( - metric_type=self.cfg.model.predict.metric.type, - metric_min=self.cfg.model.predict.metric.min, - dir_out=self.cfg.logger.predict.out_dir, - ) - - def _reset_logging_cache_train(self): - self.train_cache = { - "loss": { - k: MetricLoggerScalar(max_length=self.cfg.logger.cache_size.train.n_scalar) - for k in ["final", "reg_self", "reg_cross"] + self.ref_mode_names - }, - "correlation": { - k: MetricLoggerCorrelation(max_length=self.cfg.logger.cache_size.train.n_scalar) - for k in self.ref_mode_names - }, - "map": { - "score": { - k: MetricLoggerHistogram(max_length=self.cfg.logger.cache_size.train.n_scalar) - for k in self.ref_mode_names - }, - "l1_diff": { - k: MetricLoggerHistogram(max_length=self.cfg.logger.cache_size.train.n_scalar) - for k in self.ref_mode_names - }, - "delta": { - k: MetricLoggerHistogram(max_length=self.cfg.logger.cache_size.train.n_scalar) - for k in ["self", "cross"] - }, - }, - } - - def _reset_logging_cache_validation(self): - self.validation_cache = { - "loss": { - k: MetricLoggerScalar(max_length=None) - for k in ["final", "reg_self", "reg_cross"] + self.ref_mode_names - }, - "correlation": { - k: MetricLoggerCorrelation(max_length=None) for k in self.ref_mode_names - }, - "fig": {k: MetricLoggerImg(max_length=None) for k in ["batch"]}, - } - - def _core_step(self, batch, batch_idx, skip_loss=False): - outputs = self.model( - query_img=batch["query/img"], # (B, C, H, W) - ref_cross_imgs=batch.get("reference/cross/imgs", None), # (B, N_ref_cross, C, H, W) - need_attn_weights=self.cfg.model.need_attn_weights, - need_attn_weights_head_id=self.cfg.model.need_attn_weights_head_id, - norm_img=False, - ) - - if skip_loss: # only used in predict_step - return outputs - - score_map = batch["query/score_map"] # (B, H, W) - - loss = [] - # cross reference model predicts - if self.cfg.model.do_reference_cross: - score_map_cross = outputs["score_map_ref_cross"] # (B, H, W) - l1_diff_map_cross = torch.abs(score_map_cross - score_map) # (B, H, W) - if self.cfg.model.loss.fn == "l1": - loss_cross = l1_diff_map_cross.mean() - else: - loss_cross = self.loss_fn(score_map_cross, score_map) - outputs["loss_cross"] = loss_cross - outputs["l1_diff_map_ref_cross"] = l1_diff_map_cross - loss.append(loss_cross) - - loss = torch.stack(loss).sum() - outputs["loss"] = loss - return outputs - - def training_step(self, batch, batch_idx): - outputs = self._core_step(batch, batch_idx) - return outputs - - def validation_step(self, batch, batch_idx): - outputs = self._core_step(batch, batch_idx) - return outputs - - def test_step(self, batch, batch_idx): - outputs = self._core_step(batch, batch_idx) - return outputs - - def predict_step(self, batch, batch_idx): - outputs = self._core_step(batch, batch_idx, skip_loss=True) - return outputs - - @rank_zero_only - def on_train_batch_end(self, outputs, batch, batch_idx): - self.train_cache["loss"]["final"].update(outputs["loss"]) - - if self.cfg.model.do_reference_cross: - self.train_cache["loss"]["ref_cross"].update(outputs["loss_cross"]) - self.train_cache["correlation"]["ref_cross"].update( - outputs["score_map_ref_cross"], batch["query/score_map"] - ) - self.train_cache["map"]["score"]["ref_cross"].update(outputs["score_map_ref_cross"]) - self.train_cache["map"]["l1_diff"]["ref_cross"].update(outputs["l1_diff_map_ref_cross"]) - - # logger vis batch - if self.global_step % self.cfg.logger.vis_imgs_every_n_train_steps == 0: - fig = self.visualiser.vis(batch, outputs) - self.logger.experiment.log({"train_batch": fig}) - - # logger vis X batches statics - if self.global_step % self.cfg.logger.vis_scalar_every_n_train_steps == 0: - # log loss - tmp_loss = self.train_cache["loss"]["final"].compute() - self.log("train/loss", tmp_loss, prog_bar=True) - - if self.cfg.model.do_reference_cross: - tmp_loss_cross = self.train_cache["loss"]["ref_cross"].compute() - self.log("train/loss_cross", tmp_loss_cross) - - # log psnr - if self.cfg.model.do_reference_cross: - self.log("train/psnr_cross", self.to_psnr_fn(tmp_loss_cross)) - - # log correlation - if self.cfg.model.do_reference_cross: - self.log( - "train/correlation_cross", - self.train_cache["correlation"]["ref_cross"].compute(), - ) - - # logger vis X batches histogram - if self.global_step % self.cfg.logger.vis_histogram_every_n_train_steps == 0: - if self.cfg.model.do_reference_cross: - self.logger.experiment.log( - { - "train/score_histogram_cross": wandb.Histogram( - np_histogram=self.train_cache["map"]["score"]["ref_cross"].compute() - ), - "train/l1_diff_histogram_cross": wandb.Histogram( - np_histogram=self.train_cache["map"]["l1_diff"]["ref_cross"].compute() - ), - } - ) - - def on_validation_batch_end(self, outputs, batch, batch_idx): - self.validation_cache["loss"]["final"].update(outputs["loss"]) - - if self.cfg.model.do_reference_cross: - self.validation_cache["loss"]["ref_cross"].update(outputs["loss_cross"]) - self.validation_cache["correlation"]["ref_cross"].update( - outputs["score_map_ref_cross"], batch["query/score_map"] - ) - - self.frame_score_summariser.update(batch_input=batch, batch_output=outputs) - - if batch_idx < self.cfg.logger.cache_size.validation.n_fig: - fig = self.visualiser.vis(batch, outputs) - self.validation_cache["fig"]["batch"].update(fig) - - def on_test_batch_end(self, outputs, batch, batch_idx): - results = {"test/loss": outputs["loss"]} - - if self.cfg.model.do_reference_cross: - corr = correlation(outputs["score_map_ref_cross"], batch["query/score_map"]) - psnr = self.to_psnr_fn(outputs["loss_cross"]) - results["test/loss_cross"] = outputs["loss_cross"] - results["test/corr_cross"] = corr - results["test/psnr_cross"] = psnr - - self.log_dict( - results, - on_step=self.cfg.logger.test.on_step, - sync_dist=self.cfg.logger.test.sync_dist, - ) - - self.frame_score_summariser.update(batch_input=batch, batch_output=outputs) - - # write image to vis - if ( - self.cfg.logger.test.write.config.vis_img_every_n_steps > 0 - and batch_idx % self.cfg.logger.test.write.config.vis_img_every_n_steps == 0 - ): - fig = self.visualiser.vis(batch, outputs) - fig.image.save( - Path( - self.cfg.logger.test.out_dir, - "vis", - f"r{self.local_rank}_B{str(batch_idx).zfill(4)}_b{0}.png", - ) - ) - - if self.cfg.logger.test.write.flag.batch: - self.batch_writer.write_out( - batch_input=batch, - batch_output=outputs, - local_rank=self.local_rank, - batch_idx=batch_idx, - ) - - def on_predict_batch_end(self, outputs, batch, batch_idx): - self.frame_score_summariser.update(batch_input=batch, batch_output=outputs) - - # write image to vis - if ( - self.cfg.logger.predict.write.config.vis_img_every_n_steps > 0 - and batch_idx % self.cfg.logger.predict.write.config.vis_img_every_n_steps == 0 - ): - fig = self.visualiser.vis(batch, outputs) - fig.image.save( - Path( - self.cfg.logger.predict.out_dir, - "vis", - f"r{self.local_rank}_B{str(batch_idx).zfill(4)}_b{0}.png", - ) - ) - - if self.cfg.logger.predict.write.flag.batch: - self.batch_writer.write_out( - batch_input=batch, - batch_output=outputs, - local_rank=self.local_rank, - batch_idx=batch_idx, - ) - - @rank_zero_only - def on_train_epoch_end(self): - self._reset_logging_cache_train() - - def on_validation_epoch_end(self): - sync_dist = True - self.log( - "validation/loss", - self.validation_cache["loss"]["final"].compute(), - prog_bar=True, - sync_dist=sync_dist, - ) - self.logger.experiment.log( - {"validation_batch": self.validation_cache["fig"]["batch"].compute()}, - ) - - if self.cfg.model.do_reference_cross: - self.log( - "validation/loss_cross", - self.validation_cache["loss"]["ref_cross"].compute(), - sync_dist=sync_dist, - ) - self.log( - "validation/correlation_cross", - self.validation_cache["correlation"]["ref_cross"].compute(), - sync_dist=sync_dist, - ) - self.log( - "validation/psnr_cross", - self.to_psnr_fn(self.validation_cache["loss"]["ref_cross"].compute()), - sync_dist=sync_dist, - ) - - self._reset_logging_cache_validation() - self.frame_score_summariser.reset() - - def on_test_epoch_end(self): - self.frame_score_summariser.summarise() - - def on_predict_epoch_end(self): - self.frame_score_summariser.summarise() - - def configure_optimizers(self): - # how to use configure_optimizers: - # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.core.LightningModule.html#lightning.pytorch.core.LightningModule.configure_optimizers - - # we freeze backbone and we only pass parameters that requires grad to optimizer: - # https://discuss.pytorch.org/t/how-to-train-a-part-of-a-network/8923 - # https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html#convnet-as-fixed-feature-extractor - # https://discuss.pytorch.org/t/for-freezing-certain-layers-why-do-i-need-a-two-step-process/175289/2 - parameters = [p for p in self.model.parameters() if p.requires_grad] - optimizer = torch.optim.AdamW( - params=parameters, - lr=self.cfg.trainer.optimizer.lr, - ) - lr_scheduler = torch.optim.lr_scheduler.StepLR( - optimizer, - step_size=self.cfg.trainer.lr_scheduler.step_size, - gamma=self.cfg.trainer.lr_scheduler.gamma, - ) - - results = { - "optimizer": optimizer, - "lr_scheduler": { - "scheduler": lr_scheduler, - "interval": self.cfg.trainer.lr_scheduler.step_interval, - "frequency": 1, - }, - } - return results +__all__ = ["CrossScoreLightningModule", "CrossScoreNet"] diff --git a/task/predict.py b/task/predict.py index e565b78..fe03911 100644 --- a/task/predict.py +++ b/task/predict.py @@ -1,9 +1,6 @@ from datetime import datetime -import sys from pathlib import Path -sys.path.append(str(Path(__file__).parents[1])) - import torch from torch.utils.data import DataLoader from torchvision.transforms import v2 as T @@ -12,10 +9,10 @@ import hydra from omegaconf import DictConfig, open_dict -from core import CrossScoreLightningModule -from dataloading.dataset.simple_reference import SimpleReference -from dataloading.transformation.crop import CropperFactory -from utils.io.images import ImageNetMeanStd +from crossscore.task.core import CrossScoreLightningModule +from crossscore.dataloading.dataset.simple_reference import SimpleReference +from crossscore.dataloading.transformation.crop import CropperFactory +from crossscore.utils.io.images import ImageNetMeanStd @hydra.main(version_base="1.3", config_path="../config", config_name="default_predict") diff --git a/task/test.py b/task/test.py index ac0f6e3..7eb4eb0 100644 --- a/task/test.py +++ b/task/test.py @@ -1,8 +1,5 @@ -import sys from pathlib import Path -sys.path.append(str(Path(__file__).parents[1])) - import torch from torch.utils.data import DataLoader from torchvision.transforms import v2 as T @@ -12,10 +9,10 @@ import hydra from omegaconf import DictConfig, open_dict -from core import CrossScoreLightningModule -from dataloading.data_manager import get_dataset -from dataloading.transformation.crop import CropperFactory -from utils.io.images import ImageNetMeanStd +from crossscore.task.core import CrossScoreLightningModule +from crossscore.dataloading.data_manager import get_dataset +from crossscore.dataloading.transformation.crop import CropperFactory +from crossscore.utils.io.images import ImageNetMeanStd @hydra.main(version_base="1.3", config_path="../config", config_name="default_test") diff --git a/task/train.py b/task/train.py index acc2650..44a8aea 100644 --- a/task/train.py +++ b/task/train.py @@ -1,9 +1,6 @@ -import sys from pathlib import Path from datetime import timedelta -sys.path.append(str(Path(__file__).parents[1])) - import torch from torch.utils.data import DataLoader from torchvision.transforms import v2 as T @@ -16,11 +13,11 @@ from hydra.core.hydra_config import HydraConfig from omegaconf import DictConfig -from core import CrossScoreLightningModule -from dataloading.data_manager import get_dataset -from dataloading.transformation.crop import CropperFactory -from utils.io.images import ImageNetMeanStd -from utils.check_config import ConfigChecker +from crossscore.task.core import CrossScoreLightningModule +from crossscore.dataloading.data_manager import get_dataset +from crossscore.dataloading.transformation.crop import CropperFactory +from crossscore.utils.io.images import ImageNetMeanStd +from crossscore.utils.check_config import ConfigChecker @hydra.main(version_base="1.3", config_path="../config", config_name="default") From 71736228992343f0074f11c6ca0bed802e0e412c Mon Sep 17 00:00:00 2001 From: Claude Date: Sat, 21 Mar 2026 19:52:02 +0000 Subject: [PATCH 2/5] Fix CPU mode support and make wandb a lazy import - Make wandb import lazy in core.py and batch_visualiser.py so inference works without wandb installed (it's only needed for training) - Fix CPU mode: use "auto" devices, "32-true" precision, handle non-list device configs in DDP logic - Add --cpu flag to CLI https://claude.ai/code/session_0114iFoswRfTkMai4JTrgMrB --- crossscore/api.py | 12 ++++++++---- crossscore/cli.py | 9 ++++++++- crossscore/task/core.py | 3 ++- crossscore/utils/plot/batch_visualiser.py | 4 +++- 4 files changed, 21 insertions(+), 7 deletions(-) diff --git a/crossscore/api.py b/crossscore/api.py index e197cd6..c1ad999 100644 --- a/crossscore/api.py +++ b/crossscore/api.py @@ -25,8 +25,12 @@ def _build_config( out_dir: Optional[str] = None, ) -> OmegaConf: """Build an OmegaConf config object for prediction.""" + use_gpu = torch.cuda.is_available() and devices != "cpu" if devices is None: - devices = [0] if torch.cuda.is_available() else [] + devices = [0] if use_gpu else "auto" + elif devices == "cpu": + devices = "auto" + use_gpu = False config_dir = Path(__file__).parent / "config" # Load base configs @@ -45,8 +49,8 @@ def _build_config( base_cfg.data.loader.validation.batch_size = batch_size base_cfg.data.loader.validation.num_workers = num_workers base_cfg.trainer.devices = devices - base_cfg.trainer.precision = "16-mixed" - base_cfg.trainer.accelerator = "gpu" if torch.cuda.is_available() else "cpu" + base_cfg.trainer.precision = "16-mixed" if use_gpu else "32-true" + base_cfg.trainer.accelerator = "gpu" if use_gpu else "cpu" if out_dir is not None: base_cfg.logger.predict.out_dir = out_dir @@ -166,7 +170,7 @@ def score( # Build model and trainer model = CrossScoreLightningModule(cfg) - NUM_GPUS = len(cfg.trainer.devices) + NUM_GPUS = len(cfg.trainer.devices) if isinstance(cfg.trainer.devices, list) else 0 if NUM_GPUS > 1: from lightning.pytorch.strategies import DDPStrategy strategy = DDPStrategy(find_unused_parameters=False, static_graph=True) diff --git a/crossscore/cli.py b/crossscore/cli.py index c311fac..7492634 100644 --- a/crossscore/cli.py +++ b/crossscore/cli.py @@ -51,6 +51,11 @@ def main(): default=None, help="GPU device indices (default: [0])", ) + parser.add_argument( + "--cpu", + action="store_true", + help="Force CPU mode (no GPU)", + ) parser.add_argument( "--out-dir", default=None, @@ -66,6 +71,8 @@ def main(): from crossscore.api import score + devices = "cpu" if args.cpu else args.devices + results = score( query_dir=args.query_dir, reference_dir=args.reference_dir, @@ -74,7 +81,7 @@ def main(): batch_size=args.batch_size, num_workers=args.num_workers, resize_short_side=args.resize_short_side, - devices=args.devices, + devices=devices, out_dir=args.out_dir, write_outputs=not args.no_write, ) diff --git a/crossscore/task/core.py b/crossscore/task/core.py index 5218ad2..7d6ca17 100644 --- a/crossscore/task/core.py +++ b/crossscore/task/core.py @@ -1,7 +1,6 @@ from pathlib import Path import torch import lightning -import wandb from transformers import Dinov2Config, Dinov2Model from omegaconf import DictConfig, OmegaConf from lightning.pytorch.utilities import rank_zero_only @@ -350,6 +349,8 @@ def on_train_batch_end(self, outputs, batch, batch_idx): # logger vis X batches histogram if self.global_step % self.cfg.logger.vis_histogram_every_n_train_steps == 0: if self.cfg.model.do_reference_cross: + import wandb + self.logger.experiment.log( { "train/score_histogram_cross": wandb.Histogram( diff --git a/crossscore/utils/plot/batch_visualiser.py b/crossscore/utils/plot/batch_visualiser.py index a194ccc..b037e30 100644 --- a/crossscore/utils/plot/batch_visualiser.py +++ b/crossscore/utils/plot/batch_visualiser.py @@ -2,7 +2,6 @@ from abc import ABC, abstractmethod import numpy as np import torch -import wandb import matplotlib.pyplot as plt from matplotlib.patches import Rectangle from torchvision.utils import make_grid @@ -159,6 +158,7 @@ def vis(self, batch_input, batch_output, vis_id=0): ).convert("RGB") plt.close() + import wandb out_img = wandb.Image(out_img, file_type="jpg") return out_img @@ -306,6 +306,7 @@ def vis(self, batch_input, batch_output, vis_id=0): ).convert("RGB") plt.close() + import wandb out_img = wandb.Image(out_img, file_type="jpg") return out_img @@ -390,6 +391,7 @@ def vis(self, batch_input, batch_output, vis_id=0): ).convert("RGB") plt.close() + import wandb out_img = wandb.Image(out_img, file_type="jpg") return out_img From dea24588517a9463e4b9f9cec6e2c7028391a5b4 Mon Sep 17 00:00:00 2001 From: Claude Date: Sat, 21 Mar 2026 20:05:42 +0000 Subject: [PATCH 3/5] Strip package to inference-only, remove Lightning dependency The pip package now only supports inference (scoring), not training. This dramatically reduces dependencies and complexity. Key changes: - Rewrite crossscore/task/core.py to pure PyTorch (no LightningModule) - CrossScoreNet is a plain torch.nn.Module - load_model() handles Lightning checkpoint format - Rewrite crossscore/api.py to use direct torch inference loop - No Lightning Trainer, just DataLoader + model.forward() - Returns both score_maps (tensors) and scores (per-image means) - Writes colorized score map PNGs to disk - Remove training-only files from package: - metric_logger, batch_visualiser, batch_writer, score_summariser - evaluation metrics, data_manager, summarise_score_gt - Remove heavy deps: lightning, wandb, scipy, scikit-image, pandas, hydra-core - Move CrossScoreLightningModule to task/core.py (repo-only, not in pip package) for training workflow - Add --cpu and --device flags to CLI https://claude.ai/code/session_0114iFoswRfTkMai4JTrgMrB --- README.md | 12 +- crossscore/__init__.py | 1 + crossscore/api.py | 217 ++++----- crossscore/cli.py | 32 +- crossscore/config/data/SimpleReference.yaml | 22 - crossscore/config/default_predict.yaml | 50 --- crossscore/dataloading/data_manager.py | 41 -- crossscore/task/core.py | 404 ++--------------- .../split_gaussian_processed.py | 134 ------ crossscore/utils/evaluation/__init__.py | 0 crossscore/utils/evaluation/metric.py | 30 -- crossscore/utils/evaluation/metric_logger.py | 55 --- .../utils/evaluation/summarise_score_gt.py | 42 -- crossscore/utils/io/batch_writer.py | 270 ------------ crossscore/utils/io/score_summariser.py | 315 ------------- crossscore/utils/plot/__init__.py | 0 crossscore/utils/plot/batch_visualiser.py | 416 ------------------ pyproject.toml | 24 +- task/core.py | 113 ++++- task/predict.py | 2 +- task/test.py | 2 +- task/train.py | 2 +- 22 files changed, 269 insertions(+), 1915 deletions(-) delete mode 100644 crossscore/config/data/SimpleReference.yaml delete mode 100644 crossscore/config/default_predict.yaml delete mode 100644 crossscore/dataloading/data_manager.py delete mode 100644 crossscore/utils/data_processing/split_gaussian_processed.py delete mode 100644 crossscore/utils/evaluation/__init__.py delete mode 100644 crossscore/utils/evaluation/metric.py delete mode 100644 crossscore/utils/evaluation/metric_logger.py delete mode 100644 crossscore/utils/evaluation/summarise_score_gt.py delete mode 100644 crossscore/utils/io/batch_writer.py delete mode 100644 crossscore/utils/io/score_summariser.py delete mode 100644 crossscore/utils/plot/__init__.py delete mode 100644 crossscore/utils/plot/batch_visualiser.py diff --git a/README.md b/README.md index e32aed9..ef5c9ee 100644 --- a/README.md +++ b/README.md @@ -55,15 +55,20 @@ pip install -e . import crossscore # Score query images against reference images +# Model checkpoint is auto-downloaded on first use (~129MB) results = crossscore.score( query_dir="path/to/query/images", reference_dir="path/to/reference/images", ) -# The model checkpoint is auto-downloaded on first use (~129MB) -# Score maps are written to disk and returned as tensors +# Per-image mean scores +print(results["scores"]) # [0.82, 0.91, 0.76, ...] + +# Score map tensors (pixel-level quality maps) for score_map in results["score_maps"]: print(score_map.shape) # (batch_size, H, W) + +# Colorized score map PNGs are written to results["out_dir"] ``` ### Command Line @@ -72,6 +77,9 @@ crossscore --query-dir path/to/queries --reference-dir path/to/references # With options crossscore --query-dir renders/ --reference-dir gt/ --metric-type mae --batch-size 4 + +# Force CPU mode +crossscore --query-dir renders/ --reference-dir gt/ --cpu ``` ### Environment Variables diff --git a/crossscore/__init__.py b/crossscore/__init__.py index 4b22ec9..356ca68 100644 --- a/crossscore/__init__.py +++ b/crossscore/__init__.py @@ -9,6 +9,7 @@ ... query_dir="path/to/query/images", ... reference_dir="path/to/reference/images", ... ) + >>> print(results["scores"]) # per-image mean scores """ __version__ = "1.0.0" diff --git a/crossscore/api.py b/crossscore/api.py index c1ad999..7e7380b 100644 --- a/crossscore/api.py +++ b/crossscore/api.py @@ -1,61 +1,37 @@ """High-level API for CrossScore image quality assessment.""" from pathlib import Path -from typing import Optional +from typing import Optional, Union, List +import numpy as np import torch from torch.utils.data import DataLoader from torchvision.transforms import v2 as T from omegaconf import OmegaConf +from tqdm import tqdm from crossscore._download import get_checkpoint_path from crossscore.utils.io.images import ImageNetMeanStd from crossscore.dataloading.dataset.simple_reference import SimpleReference -from crossscore.dataloading.transformation.crop import CropperFactory -def _build_config( - metric_type: str = "ssim", - metric_min: int = 0, - metric_max: int = 1, - batch_size: int = 8, - num_workers: int = 4, - resize_short_side: int = 518, - devices: Optional[list] = None, - out_dir: Optional[str] = None, -) -> OmegaConf: - """Build an OmegaConf config object for prediction.""" - use_gpu = torch.cuda.is_available() and devices != "cpu" - if devices is None: - devices = [0] if use_gpu else "auto" - elif devices == "cpu": - devices = "auto" - use_gpu = False - - config_dir = Path(__file__).parent / "config" - # Load base configs - base_cfg = OmegaConf.load(config_dir / "default_predict.yaml") - model_cfg = OmegaConf.load(config_dir / "model" / "model.yaml") - data_cfg = OmegaConf.load(config_dir / "data" / "SimpleReference.yaml") - - # Merge model config into base - base_cfg.model = model_cfg - base_cfg.data = data_cfg - - # Apply overrides - base_cfg.model.predict.metric.type = metric_type - base_cfg.model.predict.metric.min = metric_min - base_cfg.model.predict.metric.max = metric_max - base_cfg.data.loader.validation.batch_size = batch_size - base_cfg.data.loader.validation.num_workers = num_workers - base_cfg.trainer.devices = devices - base_cfg.trainer.precision = "16-mixed" if use_gpu else "32-true" - base_cfg.trainer.accelerator = "gpu" if use_gpu else "cpu" - - if out_dir is not None: - base_cfg.logger.predict.out_dir = out_dir - - return base_cfg +def _write_score_maps(score_maps, query_paths, out_dir, metric_type, metric_min, metric_max): + """Write score maps to disk as colorized PNGs.""" + from PIL import Image + from crossscore.utils.misc.image import gray2rgb + + vrange_vis = [metric_min, metric_max] + out_dir = Path(out_dir) / "score_maps" + out_dir.mkdir(parents=True, exist_ok=True) + + idx = 0 + for batch_maps, batch_paths in zip(score_maps, query_paths): + for score_map, qpath in zip(batch_maps, batch_paths): + fname = Path(qpath).stem + ".png" + rgb = gray2rgb(score_map.cpu().numpy(), vrange_vis) + Image.fromarray(rgb).save(out_dir / fname) + idx += 1 + return str(out_dir) def score( @@ -66,9 +42,9 @@ def score( batch_size: int = 8, num_workers: int = 4, resize_short_side: int = 518, - devices: Optional[list] = None, + device: Optional[str] = None, out_dir: Optional[str] = None, - write_outputs: bool = True, + write_score_maps: bool = True, ) -> dict: """Score query images against reference images using CrossScore. @@ -79,16 +55,16 @@ def score( metric_type: Metric type to predict. One of "ssim", "mae", "mse". batch_size: Batch size for inference. num_workers: Number of data loading workers. - resize_short_side: Resize images so short side equals this value. Set to -1 to disable. - devices: List of GPU device indices. Defaults to [0] if CUDA available. - out_dir: Output directory for score maps. Defaults to a timestamped directory - under the checkpoint's parent. - write_outputs: Whether to write score maps and visualizations to disk. + resize_short_side: Resize images so short side equals this value. -1 to disable. + device: Device string ("cuda", "cuda:0", "cpu"). Auto-detected if None. + out_dir: Output directory for score maps. Defaults to "./crossscore_output". + write_score_maps: Whether to write colorized score map PNGs to disk. Returns: Dictionary with: - - "score_maps": List of predicted score map tensors - - "out_dir": Output directory path (if write_outputs=True) + - "score_maps": List of score map tensors, each (B, H, W) + - "scores": List of per-image mean scores (float) + - "out_dir": Output directory path (if write_score_maps=True) Example: >>> import crossscore @@ -96,52 +72,26 @@ def score( ... query_dir="path/to/query/images", ... reference_dir="path/to/reference/images", ... ) + >>> print(results["scores"]) # per-image mean scores """ - import lightning - from datetime import datetime - from crossscore.task.core import CrossScoreLightningModule + from crossscore.task.core import load_model + + # Determine device + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" # Get checkpoint if ckpt_path is None: ckpt_path = get_checkpoint_path() - # Build config - metric_min = -1 if metric_type == "ssim" else 0 - # For SSIM, CrossScore predicts in [0, 1] by default (the common sub-range) - if metric_type == "ssim": - metric_min = 0 - - cfg = _build_config( - metric_type=metric_type, - metric_min=metric_min, - batch_size=batch_size, - num_workers=num_workers, - resize_short_side=resize_short_side, - devices=devices, - out_dir=out_dir, - ) - - # Set checkpoint path - cfg.trainer.ckpt_path_to_load = ckpt_path - - # Determine output directory - if cfg.logger.predict.out_dir is None: - now = datetime.now().strftime("%Y%m%d_%H%M%S.%f") - log_dir = Path(ckpt_path).parents[1] if Path(ckpt_path).parent.name == "ckpt" else Path(".") - cfg.logger.predict.out_dir = str(log_dir / "predict" / now) - - if not write_outputs: - cfg.logger.predict.write.flag.batch = False - cfg.logger.predict.write.config.vis_img_every_n_steps = -1 - - # Set up data - lightning.seed_everything(cfg.lightning.seed, workers=True) + # Load model + model = load_model(ckpt_path, device=device) + # Set up data transforms img_norm_stat = ImageNetMeanStd() transforms = { "img": T.Normalize(mean=img_norm_stat.mean, std=img_norm_stat.std), } - if resize_short_side > 0: transforms["resize"] = T.Resize( resize_short_side, @@ -149,62 +99,73 @@ def score( antialias=True, ) + # Build dataset and dataloader + neighbour_config = {"strategy": "random", "cross": 5, "deterministic": False} dataset = SimpleReference( query_dir=query_dir, reference_dir=reference_dir, transforms=transforms, - neighbour_config=cfg.data.neighbour_config, + neighbour_config=neighbour_config, return_item_paths=True, - zero_reference=cfg.data.dataset.zero_reference, + zero_reference=False, ) dataloader = DataLoader( dataset, - batch_size=cfg.data.loader.validation.batch_size, + batch_size=batch_size, shuffle=False, - num_workers=cfg.data.loader.validation.num_workers, - pin_memory=True, + num_workers=num_workers, + pin_memory=(device != "cpu"), persistent_workers=False, ) - # Build model and trainer - model = CrossScoreLightningModule(cfg) - - NUM_GPUS = len(cfg.trainer.devices) if isinstance(cfg.trainer.devices, list) else 0 - if NUM_GPUS > 1: - from lightning.pytorch.strategies import DDPStrategy - strategy = DDPStrategy(find_unused_parameters=False, static_graph=True) - use_distributed_sampler = True - else: - strategy = "auto" - use_distributed_sampler = False - - trainer = lightning.Trainer( - accelerator=cfg.trainer.accelerator, - devices=cfg.trainer.devices, - precision=cfg.trainer.precision, - strategy=strategy, - use_distributed_sampler=use_distributed_sampler, - logger=False, - ) + # Run inference + all_score_maps = [] + all_scores = [] + all_query_paths = [] - # Run prediction with torch.no_grad(): - predictions = trainer.predict( - model, - dataloader, - ckpt_path=ckpt_path, - ) + for batch in tqdm(dataloader, desc="CrossScore"): + query_img = batch["query/img"].to(device) + ref_imgs = batch.get("reference/cross/imgs") + if ref_imgs is not None: + ref_imgs = ref_imgs.to(device) + + outputs = model( + query_img=query_img, + ref_cross_imgs=ref_imgs, + norm_img=False, + ) + + score_map = outputs["score_map_ref_cross"] # (B, H, W) + all_score_maps.append(score_map.cpu()) + + # Per-image mean score + for i in range(score_map.shape[0]): + all_scores.append(score_map[i].mean().item()) + + # Track query paths for output naming + if "item_paths" in batch and "query/img" in batch["item_paths"]: + all_query_paths.append(batch["item_paths"]["query/img"]) + + # Build results + metric_min = -1 if metric_type == "ssim" else 0 + if metric_type == "ssim": + metric_min = 0 # CrossScore predicts SSIM in [0, 1] by default - # Collect results - score_maps = [] - if predictions: - for batch_output in predictions: - if "score_map_ref_cross" in batch_output: - score_maps.append(batch_output["score_map_ref_cross"].cpu()) + results = { + "score_maps": all_score_maps, + "scores": all_scores, + } - results = {"score_maps": score_maps} - if write_outputs: - results["out_dir"] = cfg.logger.predict.out_dir + # Write outputs + if write_score_maps and all_score_maps: + if out_dir is None: + out_dir = "./crossscore_output" + written_dir = _write_score_maps( + all_score_maps, all_query_paths, out_dir, + metric_type, metric_min, metric_max=1, + ) + results["out_dir"] = written_dir return results diff --git a/crossscore/cli.py b/crossscore/cli.py index 7492634..7be9d9c 100644 --- a/crossscore/cli.py +++ b/crossscore/cli.py @@ -1,7 +1,6 @@ """Command-line interface for CrossScore.""" import argparse -import sys def main(): @@ -12,7 +11,7 @@ def main(): Examples: crossscore --query-dir path/to/queries --reference-dir path/to/references crossscore --query-dir renders/ --reference-dir gt/ --metric-type mae --batch-size 4 - crossscore --query-dir renders/ --reference-dir gt/ --ckpt-path my_model.ckpt + crossscore --query-dir renders/ --reference-dir gt/ --cpu """, ) parser.add_argument( @@ -45,11 +44,9 @@ def main(): help="Resize short side to this value, -1 to disable (default: 518)", ) parser.add_argument( - "--devices", - type=int, - nargs="+", + "--device", default=None, - help="GPU device indices (default: [0])", + help="Device string, e.g. 'cuda', 'cuda:0', 'cpu' (default: auto-detect)", ) parser.add_argument( "--cpu", @@ -59,19 +56,19 @@ def main(): parser.add_argument( "--out-dir", default=None, - help="Output directory for results (default: auto-generated)", + help="Output directory for results (default: ./crossscore_output)", ) parser.add_argument( "--no-write", action="store_true", - help="Do not write output files to disk", + help="Do not write score map images to disk", ) args = parser.parse_args() from crossscore.api import score - devices = "cpu" if args.cpu else args.devices + device = "cpu" if args.cpu else args.device results = score( query_dir=args.query_dir, @@ -81,15 +78,20 @@ def main(): batch_size=args.batch_size, num_workers=args.num_workers, resize_short_side=args.resize_short_side, - devices=devices, + device=device, out_dir=args.out_dir, - write_outputs=not args.no_write, + write_score_maps=not args.no_write, ) - n_maps = sum(s.shape[0] for s in results["score_maps"]) if results["score_maps"] else 0 - print(f"\nCrossScore completed: {n_maps} score maps generated") - if "out_dir" in results and results["out_dir"]: - print(f"Results written to: {results['out_dir']}") + n_images = len(results["scores"]) + print(f"\nCrossScore completed: {n_images} images scored") + if results["scores"]: + mean_score = sum(results["scores"]) / len(results["scores"]) + print(f"Mean score: {mean_score:.4f}") + for i, s in enumerate(results["scores"]): + print(f" Image {i}: {s:.4f}") + if "out_dir" in results: + print(f"Score maps written to: {results['out_dir']}") if __name__ == "__main__": diff --git a/crossscore/config/data/SimpleReference.yaml b/crossscore/config/data/SimpleReference.yaml deleted file mode 100644 index 0f86c52..0000000 --- a/crossscore/config/data/SimpleReference.yaml +++ /dev/null @@ -1,22 +0,0 @@ -loader: - validation: - batch_size: 8 - num_workers: 8 - shuffle: True - pin_memory: True - persistent_workers: False - prefetch_factor: 2 - -dataset: - query_dir: null - reference_dir: null - resolution: res_540 - zero_reference: False - -neighbour_config: - strategy: random - cross: 5 - deterministic: False - -transforms: - crop_size: 518 diff --git a/crossscore/config/default_predict.yaml b/crossscore/config/default_predict.yaml deleted file mode 100644 index 80e1f99..0000000 --- a/crossscore/config/default_predict.yaml +++ /dev/null @@ -1,50 +0,0 @@ -defaults: - - _self_ # overriding order, see https://hydra.cc/docs/tutorials/structured_config/defaults/#a-note-about-composition-order - - data: SimpleReference - - model: model - - override hydra/hydra_logging: disabled - - override hydra/job_logging: disabled - -hydra: - output_subdir: null - run: - dir: . - -lightning: - seed: 1 - -project: - name: CrossScore - -alias: "" - -trainer: - accelerator: gpu - devices: [0] - # devices: [0, 1] - precision: 16-mixed - - limit_test_batches: 1.0 - ckpt_path_to_load: null - -logger: - predict: - out_dir: null # if null, use ckpt dir - write: - flag: - batch: True - score_map_prediction: True - item_path_json: False - score_map_gt: False - attn_weights: False - image_query: True - image_reference: True - config: - vis_img_every_n_steps: 1 # -1: off - score_map_colour_mode: rgb # gray or rgb, use rgb for vis - -this_main: - resize_short_side: 518 # set to -1 to disable - crop_mode: null # default no crop - - force_batch_size: False diff --git a/crossscore/dataloading/data_manager.py b/crossscore/dataloading/data_manager.py deleted file mode 100644 index 67018b4..0000000 --- a/crossscore/dataloading/data_manager.py +++ /dev/null @@ -1,41 +0,0 @@ -import torch -from omegaconf import OmegaConf, ListConfig -from crossscore.dataloading.dataset.nvs_dataset import NvsDataset -from pprint import pprint - - -def get_dataset(cfg, transforms, data_split, return_item_paths=False): - if isinstance(cfg.data.dataset.path, str): - dataset_path_list = [cfg.data.dataset.path] - elif isinstance(cfg.data.dataset.path, ListConfig): - dataset_path_list = OmegaConf.to_object(cfg.data.dataset.path) - else: - raise ValueError("cfg.data.dataset.path should be a string or a ListConfig") - - N_dataset = len(dataset_path_list) - print(f"Get {N_dataset} datasets for {data_split}:") - pprint(f"cfg.data.dataset.path: {dataset_path_list}") - print("==================================") - - dataset_list = [] - for i in range(N_dataset): - dataset_list.append( - NvsDataset( - dataset_path=dataset_path_list[i], - resolution=cfg.data.dataset.resolution, - data_split=data_split, - transforms=transforms, - neighbour_config=cfg.data.neighbour_config, - metric_type=cfg.model.predict.metric.type, - metric_min=cfg.model.predict.metric.min, - metric_max=cfg.model.predict.metric.max, - return_item_paths=return_item_paths, - num_gaussians_iters=cfg.data.dataset.num_gaussians_iters, - zero_reference=cfg.data.dataset.zero_reference, - ) - ) - if N_dataset == 1: - dataset = dataset_list[0] - else: - dataset = torch.utils.data.ConcatDataset(dataset_list) - return dataset diff --git a/crossscore/task/core.py b/crossscore/task/core.py index 7d6ca17..0bca761 100644 --- a/crossscore/task/core.py +++ b/crossscore/task/core.py @@ -1,23 +1,10 @@ -from pathlib import Path +"""CrossScoreNet: the core neural network for CrossScore inference.""" + import torch -import lightning from transformers import Dinov2Config, Dinov2Model -from omegaconf import DictConfig, OmegaConf -from lightning.pytorch.utilities import rank_zero_only -from crossscore.utils.evaluation.metric import abs2psnr, correlation -from crossscore.utils.evaluation.metric_logger import ( - MetricLoggerScalar, - MetricLoggerHistogram, - MetricLoggerCorrelation, - MetricLoggerImg, -) -from crossscore.utils.plot.batch_visualiser import BatchVisualiserFactory +from omegaconf import OmegaConf + from crossscore.utils.io.images import ImageNetMeanStd -from crossscore.utils.io.batch_writer import BatchWriter -from crossscore.utils.io.score_summariser import ( - SummaryWriterPredictedOnline, - SummaryWriterPredictedOnlineTestPrediction, -) from crossscore.model.cross_reference import CrossReferenceNet from crossscore.model.positional_encoding import MultiViewPosionalEmbeddings @@ -27,8 +14,6 @@ def __init__(self, cfg): super().__init__() self.cfg = cfg - # used in 1. denormalising images for visualisation - # and 2. normalising images for training when required img_norm_stat = ImageNetMeanStd() self.register_buffer( "img_mean_std", torch.tensor([*img_norm_stat.mean, *img_norm_stat.std]) @@ -58,9 +43,9 @@ def forward( self, query_img, ref_cross_imgs, - need_attn_weights, - need_attn_weights_head_id, - norm_img, + need_attn_weights=False, + need_attn_weights_head_id=0, + norm_img=False, ): """ :param query_img: (B, 3, H, W) @@ -160,355 +145,42 @@ def get_featmaps(self, query_img, ref_cross_imgs): return featmaps -class CrossScoreLightningModule(lightning.LightningModule): - def __init__(self, cfg: DictConfig): - super().__init__() - self.cfg = cfg - - # write config to wandb - self.save_hyperparameters(OmegaConf.to_container(self.cfg, resolve=True)) - - # init my network - self.model = CrossScoreNet(cfg=self.cfg) - - # init visualiser - self.visualiser = BatchVisualiserFactory(self.cfg, self.model.img_mean_std)() +def load_model(ckpt_path: str, device: str = "cpu") -> CrossScoreNet: + """Load a CrossScoreNet model from a Lightning or direct checkpoint. - # init loss fn - if self.cfg.model.loss.fn == "l1": - self.loss_fn = torch.nn.L1Loss() - self.to_psnr_fn = abs2psnr - else: - raise NotImplementedError - - # logging related names - self.ref_mode_names = [] - if self.cfg.model.do_reference_cross: - self.ref_mode_names.append("ref_cross") + Args: + ckpt_path: Path to the .ckpt file. + device: Device to load the model on. - def on_fit_start(self): - # reset logging cache - if self.global_rank == 0: - self._reset_logging_cache_train() - self._reset_logging_cache_validation() + Returns: + CrossScoreNet model in eval mode. + """ + checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False) - self.frame_score_summariser = SummaryWriterPredictedOnline( - metric_type=self.cfg.model.predict.metric.type, - metric_min=self.cfg.model.predict.metric.min, - ) + # Extract config from checkpoint (saved by Lightning's save_hyperparameters) + if "hyper_parameters" in checkpoint: + cfg = OmegaConf.create(checkpoint["hyper_parameters"]) + else: + # Fallback: use default config + from pathlib import Path - def on_test_start(self): - Path(self.cfg.logger.test.out_dir, "vis").mkdir(parents=True, exist_ok=True) - if self.cfg.logger.test.write.flag.batch: - self.batch_writer = BatchWriter(self.cfg, "test", self.model.img_mean_std) - else: - self.batch_writer = None + config_dir = Path(__file__).parent.parent / "config" + model_cfg = OmegaConf.load(config_dir / "model" / "model.yaml") + cfg = OmegaConf.create({"model": model_cfg}) - self.frame_score_summariser = SummaryWriterPredictedOnlineTestPrediction( - metric_type=self.cfg.model.predict.metric.type, - metric_min=self.cfg.model.predict.metric.min, - dir_out=self.cfg.logger.test.out_dir, - ) + model = CrossScoreNet(cfg) - def on_predict_start(self): - Path(self.cfg.logger.predict.out_dir, "vis").mkdir(parents=True, exist_ok=True) - if self.cfg.logger.predict.write.flag.batch: - self.batch_writer = BatchWriter(self.cfg, "predict", self.model.img_mean_std) + # Handle Lightning checkpoint format (keys prefixed with "model.") + state_dict = checkpoint.get("state_dict", checkpoint) + new_state_dict = {} + for k, v in state_dict.items(): + # Strip "model." prefix from Lightning checkpoint keys + if k.startswith("model."): + new_state_dict[k[6:]] = v else: - self.batch_writer = None - - self.frame_score_summariser = SummaryWriterPredictedOnlineTestPrediction( - metric_type=self.cfg.model.predict.metric.type, - metric_min=self.cfg.model.predict.metric.min, - dir_out=self.cfg.logger.predict.out_dir, - ) - - def _reset_logging_cache_train(self): - self.train_cache = { - "loss": { - k: MetricLoggerScalar(max_length=self.cfg.logger.cache_size.train.n_scalar) - for k in ["final", "reg_self", "reg_cross"] + self.ref_mode_names - }, - "correlation": { - k: MetricLoggerCorrelation(max_length=self.cfg.logger.cache_size.train.n_scalar) - for k in self.ref_mode_names - }, - "map": { - "score": { - k: MetricLoggerHistogram(max_length=self.cfg.logger.cache_size.train.n_scalar) - for k in self.ref_mode_names - }, - "l1_diff": { - k: MetricLoggerHistogram(max_length=self.cfg.logger.cache_size.train.n_scalar) - for k in self.ref_mode_names - }, - "delta": { - k: MetricLoggerHistogram(max_length=self.cfg.logger.cache_size.train.n_scalar) - for k in ["self", "cross"] - }, - }, - } - - def _reset_logging_cache_validation(self): - self.validation_cache = { - "loss": { - k: MetricLoggerScalar(max_length=None) - for k in ["final", "reg_self", "reg_cross"] + self.ref_mode_names - }, - "correlation": { - k: MetricLoggerCorrelation(max_length=None) for k in self.ref_mode_names - }, - "fig": {k: MetricLoggerImg(max_length=None) for k in ["batch"]}, - } - - def _core_step(self, batch, batch_idx, skip_loss=False): - outputs = self.model( - query_img=batch["query/img"], # (B, C, H, W) - ref_cross_imgs=batch.get("reference/cross/imgs", None), # (B, N_ref_cross, C, H, W) - need_attn_weights=self.cfg.model.need_attn_weights, - need_attn_weights_head_id=self.cfg.model.need_attn_weights_head_id, - norm_img=False, - ) - - if skip_loss: # only used in predict_step - return outputs - - score_map = batch["query/score_map"] # (B, H, W) - - loss = [] - # cross reference model predicts - if self.cfg.model.do_reference_cross: - score_map_cross = outputs["score_map_ref_cross"] # (B, H, W) - l1_diff_map_cross = torch.abs(score_map_cross - score_map) # (B, H, W) - if self.cfg.model.loss.fn == "l1": - loss_cross = l1_diff_map_cross.mean() - else: - loss_cross = self.loss_fn(score_map_cross, score_map) - outputs["loss_cross"] = loss_cross - outputs["l1_diff_map_ref_cross"] = l1_diff_map_cross - loss.append(loss_cross) - - loss = torch.stack(loss).sum() - outputs["loss"] = loss - return outputs - - def training_step(self, batch, batch_idx): - outputs = self._core_step(batch, batch_idx) - return outputs - - def validation_step(self, batch, batch_idx): - outputs = self._core_step(batch, batch_idx) - return outputs - - def test_step(self, batch, batch_idx): - outputs = self._core_step(batch, batch_idx) - return outputs - - def predict_step(self, batch, batch_idx): - outputs = self._core_step(batch, batch_idx, skip_loss=True) - return outputs - - @rank_zero_only - def on_train_batch_end(self, outputs, batch, batch_idx): - self.train_cache["loss"]["final"].update(outputs["loss"]) - - if self.cfg.model.do_reference_cross: - self.train_cache["loss"]["ref_cross"].update(outputs["loss_cross"]) - self.train_cache["correlation"]["ref_cross"].update( - outputs["score_map_ref_cross"], batch["query/score_map"] - ) - self.train_cache["map"]["score"]["ref_cross"].update(outputs["score_map_ref_cross"]) - self.train_cache["map"]["l1_diff"]["ref_cross"].update(outputs["l1_diff_map_ref_cross"]) - - # logger vis batch - if self.global_step % self.cfg.logger.vis_imgs_every_n_train_steps == 0: - fig = self.visualiser.vis(batch, outputs) - self.logger.experiment.log({"train_batch": fig}) - - # logger vis X batches statics - if self.global_step % self.cfg.logger.vis_scalar_every_n_train_steps == 0: - # log loss - tmp_loss = self.train_cache["loss"]["final"].compute() - self.log("train/loss", tmp_loss, prog_bar=True) - - if self.cfg.model.do_reference_cross: - tmp_loss_cross = self.train_cache["loss"]["ref_cross"].compute() - self.log("train/loss_cross", tmp_loss_cross) - - # log psnr - if self.cfg.model.do_reference_cross: - self.log("train/psnr_cross", self.to_psnr_fn(tmp_loss_cross)) - - # log correlation - if self.cfg.model.do_reference_cross: - self.log( - "train/correlation_cross", - self.train_cache["correlation"]["ref_cross"].compute(), - ) - - # logger vis X batches histogram - if self.global_step % self.cfg.logger.vis_histogram_every_n_train_steps == 0: - if self.cfg.model.do_reference_cross: - import wandb + new_state_dict[k] = v - self.logger.experiment.log( - { - "train/score_histogram_cross": wandb.Histogram( - np_histogram=self.train_cache["map"]["score"]["ref_cross"].compute() - ), - "train/l1_diff_histogram_cross": wandb.Histogram( - np_histogram=self.train_cache["map"]["l1_diff"]["ref_cross"].compute() - ), - } - ) - - def on_validation_batch_end(self, outputs, batch, batch_idx): - self.validation_cache["loss"]["final"].update(outputs["loss"]) - - if self.cfg.model.do_reference_cross: - self.validation_cache["loss"]["ref_cross"].update(outputs["loss_cross"]) - self.validation_cache["correlation"]["ref_cross"].update( - outputs["score_map_ref_cross"], batch["query/score_map"] - ) - - self.frame_score_summariser.update(batch_input=batch, batch_output=outputs) - - if batch_idx < self.cfg.logger.cache_size.validation.n_fig: - fig = self.visualiser.vis(batch, outputs) - self.validation_cache["fig"]["batch"].update(fig) - - def on_test_batch_end(self, outputs, batch, batch_idx): - results = {"test/loss": outputs["loss"]} - - if self.cfg.model.do_reference_cross: - corr = correlation(outputs["score_map_ref_cross"], batch["query/score_map"]) - psnr = self.to_psnr_fn(outputs["loss_cross"]) - results["test/loss_cross"] = outputs["loss_cross"] - results["test/corr_cross"] = corr - results["test/psnr_cross"] = psnr - - self.log_dict( - results, - on_step=self.cfg.logger.test.on_step, - sync_dist=self.cfg.logger.test.sync_dist, - ) - - self.frame_score_summariser.update(batch_input=batch, batch_output=outputs) - - # write image to vis - if ( - self.cfg.logger.test.write.config.vis_img_every_n_steps > 0 - and batch_idx % self.cfg.logger.test.write.config.vis_img_every_n_steps == 0 - ): - fig = self.visualiser.vis(batch, outputs) - fig.image.save( - Path( - self.cfg.logger.test.out_dir, - "vis", - f"r{self.local_rank}_B{str(batch_idx).zfill(4)}_b{0}.png", - ) - ) - - if self.cfg.logger.test.write.flag.batch: - self.batch_writer.write_out( - batch_input=batch, - batch_output=outputs, - local_rank=self.local_rank, - batch_idx=batch_idx, - ) - - def on_predict_batch_end(self, outputs, batch, batch_idx): - self.frame_score_summariser.update(batch_input=batch, batch_output=outputs) - - # write image to vis - if ( - self.cfg.logger.predict.write.config.vis_img_every_n_steps > 0 - and batch_idx % self.cfg.logger.predict.write.config.vis_img_every_n_steps == 0 - ): - fig = self.visualiser.vis(batch, outputs) - fig.image.save( - Path( - self.cfg.logger.predict.out_dir, - "vis", - f"r{self.local_rank}_B{str(batch_idx).zfill(4)}_b{0}.png", - ) - ) - - if self.cfg.logger.predict.write.flag.batch: - self.batch_writer.write_out( - batch_input=batch, - batch_output=outputs, - local_rank=self.local_rank, - batch_idx=batch_idx, - ) - - @rank_zero_only - def on_train_epoch_end(self): - self._reset_logging_cache_train() - - def on_validation_epoch_end(self): - sync_dist = True - self.log( - "validation/loss", - self.validation_cache["loss"]["final"].compute(), - prog_bar=True, - sync_dist=sync_dist, - ) - self.logger.experiment.log( - {"validation_batch": self.validation_cache["fig"]["batch"].compute()}, - ) - - if self.cfg.model.do_reference_cross: - self.log( - "validation/loss_cross", - self.validation_cache["loss"]["ref_cross"].compute(), - sync_dist=sync_dist, - ) - self.log( - "validation/correlation_cross", - self.validation_cache["correlation"]["ref_cross"].compute(), - sync_dist=sync_dist, - ) - self.log( - "validation/psnr_cross", - self.to_psnr_fn(self.validation_cache["loss"]["ref_cross"].compute()), - sync_dist=sync_dist, - ) - - self._reset_logging_cache_validation() - self.frame_score_summariser.reset() - - def on_test_epoch_end(self): - self.frame_score_summariser.summarise() - - def on_predict_epoch_end(self): - self.frame_score_summariser.summarise() - - def configure_optimizers(self): - # how to use configure_optimizers: - # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.core.LightningModule.html#lightning.pytorch.core.LightningModule.configure_optimizers - - # we freeze backbone and we only pass parameters that requires grad to optimizer: - # https://discuss.pytorch.org/t/how-to-train-a-part-of-a-network/8923 - # https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html#convnet-as-fixed-feature-extractor - # https://discuss.pytorch.org/t/for-freezing-certain-layers-why-do-i-need-a-two-step-process/175289/2 - parameters = [p for p in self.model.parameters() if p.requires_grad] - optimizer = torch.optim.AdamW( - params=parameters, - lr=self.cfg.trainer.optimizer.lr, - ) - lr_scheduler = torch.optim.lr_scheduler.StepLR( - optimizer, - step_size=self.cfg.trainer.lr_scheduler.step_size, - gamma=self.cfg.trainer.lr_scheduler.gamma, - ) - - results = { - "optimizer": optimizer, - "lr_scheduler": { - "scheduler": lr_scheduler, - "interval": self.cfg.trainer.lr_scheduler.step_interval, - "frequency": 1, - }, - } - return results + model.load_state_dict(new_state_dict, strict=False) + model.eval() + model.to(device) + return model diff --git a/crossscore/utils/data_processing/split_gaussian_processed.py b/crossscore/utils/data_processing/split_gaussian_processed.py deleted file mode 100644 index 20734f9..0000000 --- a/crossscore/utils/data_processing/split_gaussian_processed.py +++ /dev/null @@ -1,134 +0,0 @@ -import argparse -import os -import json -from pathlib import Path -from pprint import pprint -import numpy as np - - -def split_list_by_ratio(list_input, ratio_dict): - # Check if the sum of ratios is close to 1 - if not 0.999 < sum(ratio_dict.values()) < 1.001: - raise ValueError("The sum of the ratios must be close to 1") - - total_length = len(list_input) - lengths = {k: int(v * total_length) for k, v in ratio_dict.items()} - - # Adjust the last split to include any rounding difference - last_split_name = list(ratio_dict.keys())[-1] - lengths[last_split_name] = total_length - sum(lengths.values()) + lengths[last_split_name] - - # Split the list - split_lists = {} - start = 0 - for split_name, length in lengths.items(): - split_lists[split_name] = list_input[start : start + length] - start += length - - split_lists = {k: v.tolist() for k, v in split_lists.items()} - return split_lists - - -if __name__ == "__main__": - np.random.seed(1234) - parser = argparse.ArgumentParser() - parser.add_argument( - "--data_path", - type=str, - default="~/projects/mview/storage/scratch_dataset/gaussian-splatting-processed/RealEstate200/res_540", - ) - parser.add_argument("--min_seq_len", type=int, default=2, help="keep seq with len(imgs) >= 2") - parser.add_argument("--min_psnr", type=float, default=10.0, help="keep seq with psnr >= 10.0") - parser.add_argument( - "--split_ratio", - nargs="+", - type=float, - default=[0.8, 0.1, 0.1], - help="train/val/test split ratio", - ) - args = parser.parse_args() - - data_path = Path(args.data_path).expanduser() - log_files = sorted([f for f in os.listdir(data_path) if f.endswith(".log")]) - - # get low psnr scenes from log files - scene_all = [] - scene_low_psnr = {} - for log_f in log_files: - with open(data_path / log_f, "r") as f: - lines = f.readlines() - for line in lines: - # assume scene name printed a few lines before PSNR - if "Output folder" in line: - scene_name = line.split("Output folder: ")[1].split("/")[-1] - scene_name = scene_name.removesuffix("\n") - elif "[ITER 7000] Evaluating train" in line: - psnr = line.split("PSNR ")[1] - psnr = psnr.removesuffix("\n") - psnr = float(psnr) - - scene_all.append(scene_name) - if psnr < args.min_psnr: - scene_low_psnr[scene_name] = psnr - else: - pass - - # get low seq length scenes from data folders - gaussian_splits = ["train", "test"] - scene_low_length = {} - for scene_name in scene_all: - for gs_split in gaussian_splits: - tmp_dir = data_path / scene_name / gs_split / "ours_1000" / "gt" - num_img = len(os.listdir(tmp_dir)) - if num_img < args.min_seq_len: - scene_low_length[scene_name] = num_img - - num_scene_total_after_gaussian = len(scene_all) - num_scene_low_psnr = len(scene_low_psnr) - num_scene_low_length = len(scene_low_length) - - # filter out low psnr scenes - scene_all = [s for s in scene_all if s not in scene_low_psnr.keys()] - num_scene_total_filtered_low_psnr = len(scene_all) - - # filter out low seq length scenes - scene_all = [s for s in scene_all if s not in scene_low_length.keys()] - num_scene_total_filtered_low_length = len(scene_all) - - # split train/val/test - scene_all = np.random.permutation(scene_all) - num_scene_after_all_filtering = len(scene_all) - ratio = { - "train": args.split_ratio[0], - "val": args.split_ratio[1], - "test": args.split_ratio[2], - } - scene_split_info = split_list_by_ratio(scene_all, ratio) - num_scene_train = len(scene_split_info["train"]) - num_scene_val = len(scene_split_info["val"]) - num_scene_test = len(scene_split_info["test"]) - num_scene_after_split = sum([len(v) for v in scene_split_info.values()]) - assert num_scene_after_split == num_scene_after_all_filtering - - # save to json - stats = { - "min_psnr": args.min_psnr, - "min_seq_len": args.min_seq_len, - "split_ratio": args.split_ratio, - "num_scene_total_after_gaussian": num_scene_total_after_gaussian, - "num_scene_low_psnr": num_scene_low_psnr, - "num_scene_low_length": num_scene_low_length, - "num_scene_total_filtered_low_psnr": num_scene_total_filtered_low_psnr, - "num_scene_total_filtered_low_length": num_scene_total_filtered_low_length, - "num_scene_after_all_filtering": num_scene_after_all_filtering, - "num_scene_train": num_scene_train, - "num_scene_val": num_scene_val, - "num_scene_test": num_scene_test, - "num_scene_after_split": num_scene_after_split, - } - - pprint(stats, sort_dicts=False) - out_dict = {"stats": stats, **scene_split_info} - - with open(data_path / "split.json", "w") as f: - json.dump(out_dict, f, indent=2) diff --git a/crossscore/utils/evaluation/__init__.py b/crossscore/utils/evaluation/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/crossscore/utils/evaluation/metric.py b/crossscore/utils/evaluation/metric.py deleted file mode 100644 index 82e47ed..0000000 --- a/crossscore/utils/evaluation/metric.py +++ /dev/null @@ -1,30 +0,0 @@ -import numpy as np -import torch - - -def psnr(a, b, return_map=False): - mse_map = torch.nn.functional.mse_loss(a, b, reduction="none") - psnr_map = -10 * torch.log10(mse_map) - if return_map: - return psnr_map - else: - return psnr_map.mean() - - -def mse2psnr(a): - return -10 * torch.log10(a) - - -def abs2psnr(a): - return -10 * torch.log10(a.pow(2)) - - -def psnr2mse(a): - return 10 ** (-a / 10) - - -def correlation(a, b): - x = torch.stack([a.flatten(), b.flatten()], dim=0) # (2, N) - corr = x.corrcoef() # (2, 2) - corr = corr[0, 1] # only this one is meaningful - return corr diff --git a/crossscore/utils/evaluation/metric_logger.py b/crossscore/utils/evaluation/metric_logger.py deleted file mode 100644 index e1cf202..0000000 --- a/crossscore/utils/evaluation/metric_logger.py +++ /dev/null @@ -1,55 +0,0 @@ -from abc import ABC, abstractmethod -import torch -import numpy as np -from .metric import correlation - - -class MetricLogger(ABC): - def __init__(self, max_length): - self.storage = [] - self.max_length = max_length - - @torch.no_grad() - def update(self, x): - if self.max_length is not None and len(self) >= self.max_length: - self.reset() - self.storage.append(x) - - def reset(self): - self.storage.clear() - - def __len__(self): - return len(self.storage) - - @abstractmethod - def compute(self): - raise NotImplementedError - - -class MetricLoggerScalar(MetricLogger): - @torch.no_grad() - def compute(self, aggregation_fn=torch.mean): - tmp = torch.stack(self.storage) - result = aggregation_fn(tmp) - return result - - -class MetricLoggerHistogram(MetricLogger): - @torch.no_grad() - def compute(self, bins=10, range=None): - tmp = torch.cat(self.storage).cpu().numpy() - result = np.histogram(tmp, bins=bins, range=range) - return result - - -class MetricLoggerCorrelation(MetricLoggerScalar): - @torch.no_grad() - def update(self, a, b): - corr = correlation(a, b) - super().update(corr) - - -class MetricLoggerImg(MetricLogger): - @torch.no_grad() - def compute(self): - return self.storage diff --git a/crossscore/utils/evaluation/summarise_score_gt.py b/crossscore/utils/evaluation/summarise_score_gt.py deleted file mode 100644 index 28a64e5..0000000 --- a/crossscore/utils/evaluation/summarise_score_gt.py +++ /dev/null @@ -1,42 +0,0 @@ -from argparse import ArgumentParser -import sys -from pathlib import Path - -from crossscore.utils.io.score_summariser import SummaryWriterGroundTruth - - -def parse_args(): - parser = ArgumentParser(description="Summarise the ground truth results.") - parser.add_argument( - "--dir_in", - type=str, - default="datadir/processed_training_ready/gaussian/map-free-reloc/res_540", - help="The ground truth data dir that contains scene dirs.", - ) - parser.add_argument( - "--dir_out", - type=str, - default="~/projects/mview/storage/scratch_dataset/score_summary", - help="The output directory to save the summarised results.", - ) - parser.add_argument( - "--fast_debug", - type=int, - default=-1, - help="num batch to load for debug. Set to -1 to disable", - ) - parser.add_argument("-n", "--num_workers", type=int, default=16) - parser.add_argument("-f", "--force", type=eval, default=False, choices=[True, False]) - return parser.parse_args() - - -if __name__ == "__main__": - args = parse_args() - summariser = SummaryWriterGroundTruth( - dir_in=args.dir_in, - dir_out=args.dir_out, - num_workers=args.num_workers, - fast_debug=args.fast_debug, - force=args.force, - ) - summariser.write_csv() diff --git a/crossscore/utils/io/batch_writer.py b/crossscore/utils/io/batch_writer.py deleted file mode 100644 index 6442dba..0000000 --- a/crossscore/utils/io/batch_writer.py +++ /dev/null @@ -1,270 +0,0 @@ -import json -from pathlib import Path -from PIL import Image -import numpy as np -from crossscore.utils.io.images import metric_map_write, u8 -from crossscore.utils.misc.image import gray2rgb, attn2rgb, de_norm_img - - -def get_vrange(predict_metric_type, predict_metric_min, predict_metric_max): - # Using uint16 when write in gray, normalise to the intrinsic range - # based on score type, regardless of the model prediction range - if predict_metric_type == "ssim": - vrange_intrinsic = [-1, 1] - elif predict_metric_type in ["mse", "mae"]: - vrange_intrinsic = [0, 1] - else: - raise ValueError(f"metric_type {predict_metric_type} not supported") - - # RGB for visualization only, normalise to the model prediction range - vrange_vis = [predict_metric_min, predict_metric_max] - return vrange_intrinsic, vrange_vis - - -class BatchWriter: - """Write batch outputs to disk.""" - - def __init__(self, cfg, phase: str, img_mean_std): - if phase not in ["test", "predict"]: - raise ValueError( - f"Phase {phase} not supported. Has to be a Lightening phase test/predict." - ) - self.cfg = cfg - self.phase = phase - self.img_mean_std = img_mean_std - - self.out_dir = Path(self.cfg.logger[phase].out_dir) - self.write_config = self.cfg.logger[phase].write.config - self.write_flag = self.cfg.logger[phase].write.flag - - # overwrite the flag for attn_weights if the model does not have it - self.write_flag.attn_weights = ( - self.write_flag.attn_weights and self.cfg.model.need_attn_weights - ) - - self.predict_metric_type = self.cfg.model.predict.metric.type - self.predict_metric_min = self.cfg.model.predict.metric.min - self.predict_metric_max = self.cfg.model.predict.metric.max - - self.vrange_intrinsic, self.vrange_vis = get_vrange( - self.predict_metric_type, self.predict_metric_min, self.predict_metric_max - ) - - # prepare out_dirs, create them if required - self.out_dir_dict = {"batch": Path(self.out_dir, "batch")} - if self.write_flag["batch"]: - for k in self.write_flag.keys(): - if k not in ["batch", "score_map_prediction"]: - if self.write_flag[k]: - self.out_dir_dict[k] = Path(self.out_dir_dict["batch"], k) - self.out_dir_dict[k].mkdir(parents=True, exist_ok=True) - - def write_out(self, batch_input, batch_output, local_rank, batch_idx): - if self.write_flag["score_map_prediction"]: - self._write_score_map_prediction( - self.out_dir_dict["batch"], - batch_input, - batch_output, - local_rank, - batch_idx, - ) - - if self.write_flag["score_map_gt"]: - self._write_score_map_gt( - self.out_dir_dict["score_map_gt"], - batch_input, - local_rank, - batch_idx, - ) - - if self.write_flag["item_path_json"]: - self._write_item_path_json( - self.out_dir_dict["item_path_json"], - batch_input, - local_rank, - batch_idx, - ) - - if self.write_flag["image_query"]: - self._write_query_image( - self.out_dir_dict["image_query"], - batch_input, - local_rank, - batch_idx, - ) - - if self.write_flag["image_reference"]: - self._write_reference_image( - self.out_dir_dict["image_reference"], - batch_input, - local_rank, - batch_idx, - ) - - if self.write_flag["attn_weights"]: - self._write_attn_weights( - self.out_dir_dict["attn_weights"], - batch_input, - batch_output, - local_rank, - batch_idx, - check_patch_mode="centre", - ) - - def _write_score_map_prediction( - self, out_dir, batch_input, batch_output, local_rank, batch_idx - ): - query_img_paths = [ - str(Path(*Path(p).parts[-5:])).replace("/", "_").replace(".png", "") - for p in batch_input["item_paths"]["query/img"] - ] - score_map_type_list = [k for k in batch_output.keys() if k.startswith("score_map")] - for score_map_type in score_map_type_list: - tmp_out_dir = Path(out_dir, score_map_type) - tmp_out_dir.mkdir(parents=True, exist_ok=True) - - if len(query_img_paths) != len(batch_output[score_map_type]): - raise ValueError("num of query images and score maps are not equal") - - for b, (query_img_p, score_map) in enumerate( - zip(query_img_paths, batch_output[score_map_type]) - ): - tmp_out_name = f"r{local_rank}_B{batch_idx:04}_b{b:03}_{query_img_p}.png" - self._write_a_score_map_with_colour_mode( - out_path=tmp_out_dir / tmp_out_name, score_map=score_map.cpu().numpy() - ) - - def _write_score_map_gt(self, out_dir, batch_input, local_rank, batch_idx): - query_img_paths = [ - str(Path(*Path(p).parts[-5:])).replace("/", "_").replace(".png", "") - for p in batch_input["item_paths"]["query/img"] - ] - - if len(query_img_paths) != len(batch_input["query/score_map"]): - raise ValueError("num of query images and score maps are not equal") - - for b, (query_img_p, score_map) in enumerate( - zip(query_img_paths, batch_input["query/score_map"]) - ): - tmp_out_name = f"r{local_rank}_B{batch_idx:04}_b{b:03}_{query_img_p}.png" - self._write_a_score_map_with_colour_mode( - out_path=out_dir / tmp_out_name, score_map=score_map.cpu().numpy() - ) - - def _write_item_path_json(self, out_dir, batch_input, local_rank, batch_idx): - out_path = out_dir / f"r{local_rank}_B{str(batch_idx).zfill(4)}.json" - - # get a deep copy to avoid infering other writing functions - item_paths = batch_input["item_paths"].copy() - for ref_type in ["reference/cross/imgs"]: - if len(item_paths[ref_type]) > 0: - # transpose ref paths to (N_ref, B) - item_paths[ref_type] = np.array(item_paths[ref_type]).T.tolist() - - with open(out_path, "w") as f: - json.dump(item_paths, f, indent=2) - - def _write_query_image(self, out_dir, batch_input, local_rank, batch_idx): - query_img_paths = [ - str(Path(*Path(p).parts[-5:])).replace("/", "_").replace(".png", "") - for p in batch_input["item_paths"]["query/img"] - ] - - for b, (query_img_p, image_query) in enumerate( - zip(query_img_paths, batch_input["query/img"]) - ): - tmp_out_name = f"r{local_rank}_B{batch_idx:04}_b{b:03}_{query_img_p}.png" - tmp_out_path = Path(out_dir, tmp_out_name) - image_query = de_norm_img(image_query.permute(1, 2, 0), self.img_mean_std) - image_query = u8(image_query.cpu().numpy()) - Image.fromarray(image_query).save(tmp_out_path) - - def _write_reference_image(self, out_dir, batch_input, local_rank, batch_idx): - query_img_paths = [ - str(Path(*Path(p).parts[-5:])).replace("/", "_").replace(".png", "") - for p in batch_input["item_paths"]["query/img"] - ] - for ref_type in ["reference/cross/imgs"]: - if len(batch_input["item_paths"][ref_type]) > 0: - ref_img_paths = np.array(batch_input["item_paths"][ref_type]).T # (B, N_ref) - - # create a subfolder for each query image to store its ref images - for b, query_img_p in enumerate(query_img_paths): - tmp_out_dir = Path( - out_dir, - f"r{local_rank}_B{batch_idx:04}_b{b:03}_{query_img_p}", - ref_type.split("/")[1], - ) - tmp_out_dir.mkdir(parents=True, exist_ok=True) - tmp_ref_img_paths = [ - str(Path(*Path(p).parts[-5:])).replace("/", "_").replace(".png", "") - for p in ref_img_paths[b] # (N_ref, ) - ] - ref_imgs = batch_input[ref_type][b] # (N_ref, C, H, W) - for ref_idx, (ref_img_p, ref_img) in enumerate( - zip(tmp_ref_img_paths, ref_imgs) - ): - tmp_out_name = f"ref{ref_idx:02}_{ref_img_p}.png" - tmp_out_path = Path(tmp_out_dir, tmp_out_name) - ref_img = de_norm_img(ref_img.permute(1, 2, 0), self.img_mean_std) - ref_img = u8(ref_img.cpu().numpy()) - Image.fromarray(ref_img).save(tmp_out_path) - - def _write_attn_weights( - self, out_dir, batch_input, batch_output, local_rank, batch_idx, check_patch_mode - ): - query_img_paths = [ - str(Path(*Path(p).parts[-5:])).replace("/", "_").replace(".png", "") - for p in batch_input["item_paths"]["query/img"] - ] - for ref_type in ["reference/cross/imgs"]: - if len(batch_input["item_paths"][ref_type]) > 0: - ref_img_paths = np.array(batch_input["item_paths"][ref_type]).T # (B, N_ref) - ref_type_short = ref_type.split("/")[1] - - # create a subfolder for each query image to store its ref images - for b, query_img_p in enumerate(query_img_paths): - tmp_out_dir = Path( - out_dir, - f"r{local_rank}_B{batch_idx:04}_b{b:03}_{query_img_p}", - ref_type_short, - ) - tmp_out_dir.mkdir(parents=True, exist_ok=True) - tmp_ref_img_paths = [ - str(Path(*Path(p).parts[-5:])).replace("/", "_").replace(".png", "") - for p in ref_img_paths[b] # (N_ref, ) - ] - - # get attn maps of the centre patch in the query image - # (B, H, W, N_ref, H, W) - attn_weights_map = batch_output[f"attn_weights_map_ref_{ref_type_short}"] - tmp_h, tmp_w = attn_weights_map.shape[1:3] - - if check_patch_mode == "centre": - query_patch = (tmp_h // 2, tmp_w // 2) - elif check_patch_mode == "random": - query_patch = ( - np.random.randint(0, tmp_h), - np.random.randint(0, tmp_w), - ) - else: - raise ValueError(f"Unknown check_patch_mode: {check_patch_mode}") - attn_weights_map = attn_weights_map[b][query_patch] # (N_ref, H, W) - - # write attn maps - for ref_idx, (ref_img_p, attn_m) in enumerate( - zip(tmp_ref_img_paths, attn_weights_map) - ): - tmp_out_name = f"ref{ref_idx:02}_{ref_img_p}.png" - tmp_out_path = Path(tmp_out_dir, tmp_out_name) - attn_m = attn2rgb(attn_m.cpu().numpy()) # (H, W, 3) - Image.fromarray(attn_m).save(tmp_out_path) - - def _write_a_score_map_with_colour_mode(self, out_path, score_map): - if self.write_config.score_map_colour_mode == "gray": - metric_map_write(out_path, score_map, self.vrange_intrinsic) - elif self.write_config.score_map_colour_mode == "rgb": - rgb = gray2rgb(score_map, self.vrange_vis) - Image.fromarray(rgb).save(out_path) - else: - raise ValueError(f"colour_mode {self.write_config.score_map_colour_mode} not supported") diff --git a/crossscore/utils/io/score_summariser.py b/crossscore/utils/io/score_summariser.py deleted file mode 100644 index 9b1d49b..0000000 --- a/crossscore/utils/io/score_summariser.py +++ /dev/null @@ -1,315 +0,0 @@ -from pathlib import Path -import os -from glob import glob - -import torch -from torch.utils.data import Dataset, DataLoader -import numpy as np -import pandas as pd -from pandas import DataFrame -from tqdm import tqdm - -from crossscore.utils.io.images import metric_map_read -from crossscore.utils.evaluation.metric import mse2psnr - - -class ScoreReader(Dataset): - def __init__(self, score_map_dir_list): - # get all paths to read - read_score_types = ["ssim", "mae"] - self.read_paths_all = {k: [] for k in read_score_types} - for score_read_type in read_score_types: - for score_map_dir in score_map_dir_list: - tmp_dir = os.path.join(score_map_dir, score_read_type) - tmp_paths = [os.path.join(tmp_dir, n) for n in sorted(os.listdir(tmp_dir))] - self.read_paths_all[score_read_type].extend(tmp_paths) - - # (N_frames, 2) - self.read_paths_all = np.stack([self.read_paths_all[k] for k in read_score_types], axis=1) - - def __len__(self): - return len(self.read_paths_all) - - def __getitem__(self, idx): - path_ssim, path_mae = self.read_paths_all[idx] - ssim_map = metric_map_read(path_ssim, vrange=[-1, 1]) - mae_map = metric_map_read(path_mae, vrange=[0, 1]) - mse_map = np.square(mae_map) - - score_ssim_n11 = ssim_map.mean() - score_ssim_01 = ssim_map.clip(0, 1).mean() - score_mae = mae_map.mean() - score_mse = mse_map.mean() - score_psnr = mse2psnr(torch.tensor([score_mse])).numpy() - - results = { - "ssim_-1_1": score_ssim_n11, - "ssim_0_1": score_ssim_01, - "mae": score_mae, - "mse": score_mse, - "psnr": score_psnr, - "path_ssim": path_ssim, - } - return results - - -class SummaryWriterGroundTruth: - """ - Load the ground truth results from disk and summarise the results. - """ - - def __init__(self, dir_in, dir_out, num_workers, fast_debug, force): - self.dir_in = Path(dir_in).expanduser() - self.dir_out = Path(dir_out).expanduser() - self.num_workers = num_workers - self.fast_debug = fast_debug - self.force = force - - self.dataset_type = self.dir_in.parent.name - self.rendering_method = self.dir_in.parents[1].name - self.csv_dir = self.dir_out / self.dataset_type - self.csv_path = self.csv_dir / f"{self.rendering_method}.csv" - self.csv_dir.mkdir(parents=True, exist_ok=True) - self.rows = [] - self.columns = [ - "scene_name", - "rendered_dir", - "image_name", - "gt_ssim_-1_1", - "gt_ssim_0_1", - "gt_mae", - "gt_mse", - "gt_psnr", - ] - - def write_csv(self): - write = self._check_write(self.csv_path, self.force) - if write: - rows = self._load_per_frame_score() - df = DataFrame(data=rows, columns=self.columns) - df.to_csv(self.csv_path, index=False, float_format="%.4f") - - def _check_write(self, csv_path, force): - if csv_path.exists(): - if force: - csv_path.unlink() - print(f"Write to csv {csv_path} (OVERWRITE)") - write = True - else: - print(f"Write to csv {csv_path} (SKIP)") - write = False - else: - print(f"Write to csv {csv_path} (NORMAL)") - write = True - return write - - def _load_per_frame_score(self): - # use glob to find all dir named "metric_map" in the scene dirs - score_map_dir_list = sorted(glob(str(self.dir_in / "**/metric_map"), recursive=True)) - score_reader = ScoreReader(score_map_dir_list) - score_loader = DataLoader( - dataset=score_reader, - batch_size=16, - shuffle=False, - num_workers=self.num_workers, - ) - - # process score maps to csv rows, each row contains the following columns - rows = [] - for i, data in enumerate(tqdm(score_loader, desc=f"Loading gt scores", dynamic_ncols=True)): - for j in range(len(data["path_ssim"])): - path_ssim = data["path_ssim"][j] - scene_name = path_ssim.split("/")[-6] - rendered_dir = os.path.join(*path_ssim.split("/")[:-3]) - image_name = path_ssim.split("/")[-1] - image_name = image_name.replace("frame_", "") - tmp_row = [ - scene_name, - rendered_dir, - image_name, - data["ssim_-1_1"][j].item(), - data["ssim_0_1"][j].item(), - data["mae"][j].item(), - data["mse"][j].item(), - data["psnr"][j].item(), - ] - rows.append(tmp_row) - if self.fast_debug > 0 and i >= self.fast_debug: - break - return rows - - -class SummaryWriterPredictedOnline: - """ - Used in at the fit/test/predict phase with Lightning Module. - """ - - def __init__(self, metric_type, metric_min): - metric_type_str = self._get_metric_type_str(metric_type, metric_min) - self.columns = [ - "scene_name", - "rendered_dir", - "image_name", - f"pred_{metric_type_str}", - ] - self.reset() - - def _get_metric_type_str(self, metric_type, metric_min): - if metric_type == "ssim": - if metric_min == -1: - metric_str = f"{metric_type}_-1_1" - elif metric_min == 0: - metric_str = f"{metric_type}_0_1" - else: - metric_str = f"{metric_type}" - return metric_str - - def reset(self): - self.rows = DataFrame(columns=self.columns) - - def update(self, batch_input, batch_output): - """Store per frame scores to rows for each batch.""" - query_img_paths = batch_input["item_paths"]["query/img"] # (B,) - ref_types = [t for t in batch_output.keys() if t.startswith("score_map")] - - if len(ref_types) != 1: - raise ValueError(f"Expect exactly one ref_type: self/cross, but got {ref_types}.") - - rows_batch = [] - for ref_type in ref_types: - score_maps = batch_output[ref_type] # (B, H, W) - scores = score_maps.mean(dim=[-1, -2]) # (B,) - - scene_names = [p.split("/")[-5] for p in query_img_paths] - - rendered_dirs = [os.path.join(*p.split("/")[:-2]) for p in query_img_paths] - image_names = [p.split("/")[-1] for p in query_img_paths] - image_names = [n.replace("frame_", "") for n in image_names] - - for i in range(len(scene_names)): - rows_batch.append( - [scene_names[i], rendered_dirs[i], image_names[i], scores[i].item()] - ) - - # concat rows to panda dataframe - self.rows = pd.concat([self.rows, DataFrame(rows_batch, columns=self.columns)]) - - def summarise(self): - """Organise rows using dataset type and rendering method and sort them.""" - - # get unique rendering methods - rendering_method_list = self.rows["rendered_dir"].apply(lambda x: x.split("/")[-6]).unique() - - # get unique dataset types - dataset_type_list = self.rows["rendered_dir"].apply(lambda x: x.split("/")[-5]).unique() - - # organise rows using dataset type and rendering method - self.summary = {} - for dataset_type in dataset_type_list: - self.summary[dataset_type] = {} - for rendering_method in rendering_method_list: - # get rows with the same dataset type and rendering method - tmp_rows = self.rows[ - (self.rows["rendered_dir"].str.contains(rendering_method)) - & (self.rows["rendered_dir"].str.contains(dataset_type)) - ] - - # sort rows by scene_name, rendered_dir, image_name - tmp_rows = tmp_rows.sort_values(by=["scene_name", "rendered_dir", "image_name"]) - - self.summary[dataset_type][rendering_method] = tmp_rows - - def __len__(self): - return len(self.rows) - - def __repr__(self): - return self.rows.__repr__() - - -class SummaryWriterPredictedOnlineTestPrediction(SummaryWriterPredictedOnline): - """ - Used in at the end of validation/test/predict phase. - Summarising with online predicted results to avoid reading from disks. - """ - - def __init__(self, metric_type, metric_min, dir_out): - super().__init__(metric_type, metric_min) - self.csv_dir = Path(dir_out).expanduser() / "score_summary" - self.csv_dir.mkdir(parents=True, exist_ok=True) - self.cache_csv_path = self.csv_dir / f"summarise_cache.csv" - - def summarise(self): - super().summarise() - - # write to csv files - for dataset_type, dataset_summary in self.summary.items(): - for rendering_method, rows in dataset_summary.items(): - tmp_csv_dir = self.csv_dir / dataset_type - tmp_csv_dir.mkdir(parents=True, exist_ok=True) - tmp_csv_path = tmp_csv_dir / f"{rendering_method}.csv" - rows.to_csv(tmp_csv_path, index=False, float_format="%.4f") - - -class SummaryReader: - @staticmethod - def read_summary(summary_dir, dataset, method_list, scene_list, split_list, iter_list): - summary_dir = Path(summary_dir).expanduser() - summary_dir = summary_dir / dataset - - methods_available = [f.stem for f in summary_dir.iterdir() if f.is_file()] - methods_to_read = [] - if method_list != [""]: - for m in method_list: - if m in methods_available: - methods_to_read.append(m) - else: - raise ValueError(f"{m} is not available in {summary_dir}") - else: - methods_to_read = methods_available - - summary_files = [summary_dir / f"{m}.csv" for m in methods_to_read] - - # read csv files, create a new colume as 0th column for method_name - summary = pd.concat( - [pd.read_csv(f).assign(method_name=m) for f, m in zip(summary_files, methods_to_read)] - ) - - # filter with scene list using summary's scene_name column - if scene_list != [""]: - summary = summary[summary["scene_name"].isin(scene_list)] - - # filter split using rendered_dir column - if split_list != [""]: - new_s = [] - for split in split_list: - tmp_s = summary[summary["rendered_dir"].str.split("/").str[-2] == split] - new_s.append(tmp_s) - summary = pd.concat(new_s) - - # filter iter using rendered_dir column, the last part of the path should be EXACTLY "ours_{iter}" - if len(iter_list) > 0: - new_s = [] - for i in iter_list: - tmp_s = summary[summary["rendered_dir"].str.endswith(f"ours_{i}")] - new_s.append(tmp_s) - summary = pd.concat(new_s) - - # sort by scene_name, rendered_dir, image_name, method_name - summary = summary.sort_values(["scene_name", "rendered_dir", "image_name", "method_name"]) - - # reset index - summary = summary.reset_index(drop=True) - return summary - - @staticmethod - def check_summary_gt_prediction_rows(summary_gt, summary_prediction): - # these tow dataframes should have the same length - # and the columns [rendered_dir, image_name] should be identical - if len(summary_gt) != len(summary_prediction): - raise ValueError("Summary GT and prediction have different length") - - if not summary_gt["rendered_dir"].equals(summary_prediction["rendered_dir"]): - raise ValueError("Summary GT and prediction have different rendered_dir") - - if not summary_gt["image_name"].equals(summary_prediction["image_name"]): - raise ValueError("Summary GT and prediction have different image_name") diff --git a/crossscore/utils/plot/__init__.py b/crossscore/utils/plot/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/crossscore/utils/plot/batch_visualiser.py b/crossscore/utils/plot/batch_visualiser.py deleted file mode 100644 index b037e30..0000000 --- a/crossscore/utils/plot/batch_visualiser.py +++ /dev/null @@ -1,416 +0,0 @@ -from pathlib import Path -from abc import ABC, abstractmethod -import numpy as np -import torch -import matplotlib.pyplot as plt -from matplotlib.patches import Rectangle -from torchvision.utils import make_grid -from PIL import Image -from crossscore.utils.io.images import u8 -from crossscore.utils.misc.image import de_norm_img -from crossscore.utils.check_config import check_reference_type - - -class BatchVisualiserBase(ABC): - """Organise batch data and preview in a single image.""" - - def __init__(self, cfg, img_mean_std, ref_type): - self.cfg = cfg - self.img_mean_std = img_mean_std.cpu().numpy() - self.ref_type = ref_type - self.metric_type = cfg.model.predict.metric.type - self.metric_min = cfg.model.predict.metric.min - self.metric_max = cfg.model.predict.metric.max - - @abstractmethod - def vis(self, **kwargs): - raise NotImplementedError - - -class BatchVisualiserRef(BatchVisualiserBase): - @torch.no_grad() - def vis(self, batch_input, batch_output, vis_id=0): - ref_type = self.ref_type - - fig = plt.figure(figsize=(6, 6.5), layout="constrained") - # plt.style.use("dark_background") - # fig.set_facecolor("gray") - - # arrange subfigure - N_ref_imgs = batch_input[f"reference/{ref_type}/imgs"].shape[1] - h_ratios_subfig = [3, np.ceil(N_ref_imgs / 5)] - fig_top, fig_bottom = fig.subfigures(2, 1, height_ratios=h_ratios_subfig) - - # arrange subplot for top fig - inner = [ - ["score_map/gt", f"score_map/ref_{ref_type}"], - ] - - mosaic_top = [ - ["query", inner], - ] - width_ratio_top = [3, 2] - ax_dict_top = fig_top.subplot_mosaic( - mosaic_top, - empty_sentinel=None, - width_ratios=width_ratio_top, - ) - - if "item_paths" in batch_input.keys(): - # 1. anonymize the path - # 2. split to two rows - # 3. add query path to title - query_path = batch_input["item_paths"]["query/img"][vis_id] - query_path = query_path.replace(str(Path("~").expanduser()), "~/") - query_path = query_path.split("/") - query_path.insert(len(query_path) // 2, "\n") - query_path = Path(*query_path) - fig_top.suptitle(f"{query_path}", fontsize=7) - - # arrange subplot for bottom fig - mosaic_bottom = [[f"ref/{ref_type}/imgs"]] - height_ratios = [1] - ax_dict_bottom = fig_bottom.subplot_mosaic( - mosaic_bottom, - empty_sentinel=None, - height_ratios=height_ratios, - ) - - # getting batch input output data - img_dict = {"query": batch_input["query/img"][vis_id]} # 3HW - img_dict[f"ref/{ref_type}/imgs"] = make_grid( - batch_input[f"reference/{ref_type}/imgs"][vis_id], nrow=5 - ) # 3HW - img_dict["score_map/gt"] = batch_input[f"query/score_map"][vis_id] # HW - img_dict[f"score_map/ref_{ref_type}"] = batch_output[f"score_map_ref_{ref_type}"][ - vis_id - ] # HW - - # plot top fig - for k, v in img_dict.items(): - if k not in ax_dict_top.keys(): - continue - if k.startswith("score_map"): - tmp_img = v.cpu().numpy() - ax_dict_top[k].imshow( - tmp_img, - vmin=self.metric_min, - vmax=self.metric_max, - cmap="turbo", - ) - title_txt = k.replace("score", self.metric_type) - title_txt = title_txt.replace("_map", "").replace("/", " ").replace("_", " ") - title_txt = title_txt + f"\n{tmp_img.mean():.3f}" - if k == f"score_map/ref_{ref_type}": - if f"l1_diff_map_ref_{ref_type}" in batch_output.keys(): - loss = batch_output[f"l1_diff_map_ref_{ref_type}"][vis_id].mean().item() - title_txt = title_txt + f" ∆{loss:.2f}" - ax_dict_top[k].set_title(title_txt, fontsize=8) - elif k.startswith("delta_map"): - tmp_img = v.cpu().numpy() - # ax_dict_top[k].imshow(tmp_img, vmin=-0.5, vmax=0.5, cmap="turbo") - ax_dict_top[k].imshow(tmp_img, vmin=-0.5, vmax=0.5, cmap="bwr") - ax_dict_top[k].set_title(k.replace("/", "\n"), fontsize=8) - else: - tmp_img = v.permute(1, 2, 0).cpu().numpy() # HW3 - tmp_img = de_norm_img(tmp_img, self.img_mean_std) - tmp_img = tmp_img.clip(0, 1) - tmp_img = u8(tmp_img) - ax_dict_top[k].imshow(tmp_img) - ax_dict_top[k].set_title(k) - ax_dict_top[k].set_xticks([]) - ax_dict_top[k].set_yticks([]) - - # add a colorbar, turn off ticks - ax_colorbar = fig_top.colorbar( - plt.cm.ScalarMappable(cmap="turbo"), ax=ax_dict_top["query"], fraction=0.046, pad=0.04 - ) - ax_colorbar.ax.set_yticklabels([]) - - # plot bottom fig - for k, v in img_dict.items(): - if k not in ax_dict_bottom.keys(): - continue - if k.startswith("attn_weights_map"): - # tmp_img = v.permute(1, 2, 0) # HW3 - tmp_img = v[0] # HW - tmp_img = tmp_img.clamp(0, 1) - tmp_img = tmp_img.cpu().numpy() - ax_dict_bottom[k].imshow(tmp_img, vmin=0, cmap="turbo") - ax_dict_bottom[k].set_title(k) - else: - tmp_img = v.permute(1, 2, 0).cpu().numpy() # HW3 - tmp_img = de_norm_img(tmp_img, self.img_mean_std) - tmp_img = tmp_img.clip(0, 1) - tmp_img = u8(tmp_img) - ax_dict_bottom[k].imshow(tmp_img) - ax_dict_bottom[k].set_title(k) - - ax_dict_bottom[k].set_xticks([]) - ax_dict_bottom[k].set_yticks([]) - - # output - fig.canvas.draw() - out_img = Image.frombytes( - "RGBA", - fig.canvas.get_width_height(), - fig.canvas.buffer_rgba(), - ).convert("RGB") - - plt.close() - import wandb - out_img = wandb.Image(out_img, file_type="jpg") - return out_img - - -class BatchVisualiserRefAttnMap(BatchVisualiserBase): - def __init__(self, cfg, img_mean_std, ref_type, check_patch_mode): - super().__init__(cfg, img_mean_std, ref_type) - self.check_patch_mode = check_patch_mode - - @torch.no_grad() - def vis(self, batch_input, batch_output, vis_id=0): - fig = plt.figure(figsize=(6, 6.5), layout="constrained") - # plt.style.use("dark_background") - fig.set_facecolor("gray") - - has_reference_cross = "reference/cross/imgs" in batch_input.keys() - has_attn_weights_cross = batch_output.get("attn_weights_map_ref_cross", None) is not None - P = patch_size = 14 - - # arrange subfigure - h_ratios_subfig = [2, 1] - - fig_top, fig_bottom = fig.subfigures(2, 1, height_ratios=h_ratios_subfig) - - # arrange subplot for top fig - inner = [] - if has_reference_cross: - inner.append(["score_map/gt1", "score_map/ref_cross", "score_map/diff_cross"]) - inner = np.array(inner).T.tolist() # vertical layout - - mosaic_top = [ - ["query", inner], - ] - width_ratio_top = [4, 1] - - ax_dict_top = fig_top.subplot_mosaic( - mosaic_top, - empty_sentinel=None, - width_ratios=width_ratio_top, - ) - - # arrange subplot for bottom fig - mosaic_bottom = [] - if has_reference_cross: - mosaic_bottom.append(["ref/cross/imgs"]) - if has_attn_weights_cross: - mosaic_bottom.append(["attn_weights_map_ref_cross"]) - - height_ratios = [1] * len(mosaic_bottom) - - ax_dict_bottom = fig_bottom.subplot_mosaic( - mosaic_bottom, - empty_sentinel=None, - height_ratios=height_ratios, - ) - - # getting batch input output data - img_dict = {"query": batch_input["query/img"][vis_id]} # 3HW - if has_reference_cross: - img_dict["ref/cross/imgs"] = make_grid( - batch_input["reference/cross/imgs"][vis_id], nrow=5 - ) # 3HW - img_dict["score_map/gt1"] = batch_input["query/score_map"][vis_id] - img_dict["score_map/ref_cross"] = batch_output["score_map_ref_cross"][vis_id] - if "l1_diff_map_ref_cross" in batch_output.keys(): - img_dict["score_map/diff_cross"] = batch_output["l1_diff_map_ref_cross"][vis_id] - - if has_attn_weights_cross: - attn_weights_map_ref_cross = batch_output["attn_weights_map_ref_cross"] # BHWNHW - tmp_h, tmp_w = attn_weights_map_ref_cross.shape[1:3] - if self.check_patch_mode == "centre": - query_patch_cross = (tmp_h // 2, tmp_w // 2) - elif self.check_patch_mode == "random": - query_patch_cross = (np.random.randint(0, tmp_h), np.random.randint(0, tmp_w)) - else: - raise ValueError(f"Unknown check_patch_mode: {self.check_patch_mode}") - - # NHW - attn_weights_map_ref_cross = attn_weights_map_ref_cross[vis_id][query_patch_cross] - img_dict["attn_weights_map_ref_cross"] = make_grid( - attn_weights_map_ref_cross[:, None], nrow=5 # make grid expects N3HW - ) - - # plot top fig - for k, v in img_dict.items(): - if k not in ax_dict_top.keys(): - continue - if k.startswith("score_map"): - tmp_img = v.cpu().numpy() - ax_dict_top[k].imshow(tmp_img, vmin=0, vmax=1, cmap="turbo") - ax_dict_top[k].set_title(k.replace("/", "\n"), fontsize=8) - else: - tmp_img = v.permute(1, 2, 0).cpu().numpy() # HW3 - tmp_img = de_norm_img(tmp_img, self.img_mean_std) - tmp_img = tmp_img.clip(0, 1) - tmp_img = u8(tmp_img) - ax_dict_top[k].imshow(tmp_img) - ax_dict_top[k].set_title(k) - - if has_attn_weights_cross and k == "query": - # draw a rectangle patch - query_pixel = (query_patch_cross[0] * P, query_patch_cross[1] * P) - rect = Rectangle( - (query_pixel[1], query_pixel[0]), - P, - P, - linewidth=2, - edgecolor="magenta", - facecolor="none", - ) - ax_dict_top[k].add_patch(rect) - ax_dict_top[k].set_xticks([]) - ax_dict_top[k].set_yticks([]) - - # plot bottom fig - for k, v in img_dict.items(): - if k not in ax_dict_bottom.keys(): - continue - if k.startswith("attn_weights_map"): - num_stabler = 1e-8 # to avoid log(0) - tmp_img = v[0] # HW - tmp_img = tmp_img.clamp(0, 1) - tmp_img = tmp_img.cpu().numpy() - # invert softmax (exp'd) attn weights - tmp_img = np.log(tmp_img + num_stabler) - np.log(num_stabler) - ax_dict_bottom[k].imshow(tmp_img, vmax=-np.log(num_stabler), cmap="turbo") - ax_dict_bottom[k].set_title(k) - else: - tmp_img = v.permute(1, 2, 0).cpu().numpy() # HW3 - tmp_img = de_norm_img(tmp_img, self.img_mean_std) - tmp_img = tmp_img.clip(0, 1) - tmp_img = u8(tmp_img) - ax_dict_bottom[k].imshow(tmp_img) - ax_dict_bottom[k].set_title(k) - - ax_dict_bottom[k].set_xticks([]) - ax_dict_bottom[k].set_yticks([]) - - # output - fig.canvas.draw() - out_img = Image.frombytes( - "RGBA", - fig.canvas.get_width_height(), - fig.canvas.buffer_rgba(), - ).convert("RGB") - - plt.close() - import wandb - out_img = wandb.Image(out_img, file_type="jpg") - return out_img - - -class BatchVisualiserRefFree(BatchVisualiserBase): - @torch.no_grad() - def vis(self, batch_input, batch_output, vis_id=0): - fig = plt.figure(figsize=(6.5, 5), layout="constrained") - # plt.style.use("dark_background") - # fig.set_facecolor("gray") - - ref_type = self.ref_type - inner = [["score_map/gt", f"score_map/ref_{ref_type}", "score_map/diff"]] - inner = np.array(inner).T.tolist() - - mosaic = [ - ["query", inner], - ] - width_ratio = [5, 2] - ax_dict = fig.subplot_mosaic( - mosaic, - empty_sentinel=None, - width_ratios=width_ratio, - ) - - if "item_paths" in batch_input.keys(): - # 1. anonymize the path - # 2. split to two rows - # 3. add query path to title - query_path = batch_input["item_paths"]["query/img"][vis_id] - query_path = query_path.replace(str(Path("~").expanduser()), "~/") - query_path = query_path.split("/") - query_path.insert(len(query_path) // 2, "\n") - query_path = Path(*query_path) - fig.suptitle(f"{query_path}", fontsize=7) - - # getting batch input output data - img_dict = {"query": batch_input["query/img"][vis_id]} # 3HW - img_dict["score_map/gt"] = batch_input[f"query/score_map"][vis_id] # HW - - # plot top fig - for k, v in img_dict.items(): - if k not in ax_dict.keys(): - continue - if k.startswith("score_map"): - tmp_img = v.cpu().numpy() - ax_dict[k].imshow( - tmp_img, - vmin=self.metric_min, - vmax=self.metric_max, - cmap="turbo", - ) - title_txt = k.replace("score", self.metric_type) - title_txt = title_txt.replace("_map", "").replace("/", " ").replace("_", " ") - title_txt = title_txt + f"\n{tmp_img.mean():.3f}" - if k == f"score_map/ref_{ref_type}": - loss = batch_output[f"l1_diff_map_ref_{ref_type}"][vis_id].mean().item() - title_txt = title_txt + f" ∆{loss:.2f}" - ax_dict[k].set_title(title_txt, fontsize=8) - else: - tmp_img = v.permute(1, 2, 0).cpu().numpy() # HW3 - tmp_img = de_norm_img(tmp_img, self.img_mean_std) - tmp_img = tmp_img.clip(0, 1) - tmp_img = u8(tmp_img) - ax_dict[k].imshow(tmp_img) - ax_dict[k].set_title(k) - ax_dict[k].set_xticks([]) - ax_dict[k].set_yticks([]) - - # add a colorbar, turn off ticks - ax_colorbar = fig.colorbar( - plt.cm.ScalarMappable(cmap="turbo"), ax=ax_dict["query"], fraction=0.046, pad=0.04 - ) - ax_colorbar.ax.set_yticklabels([]) - - # output - fig.canvas.draw() - out_img = Image.frombytes( - "RGBA", - fig.canvas.get_width_height(), - fig.canvas.buffer_rgba(), - ).convert("RGB") - - plt.close() - import wandb - out_img = wandb.Image(out_img, file_type="jpg") - return out_img - - -class BatchVisualiserFactory: - def __init__(self, cfg, img_mean_std): - self.cfg = cfg - self.img_mean_std = img_mean_std - self.ref_type = check_reference_type(self.cfg.model.do_reference_cross) - - if self.ref_type in ["cross"]: - if self.cfg.model.need_attn_weights: - self.visualiser = BatchVisualiserRefAttnMap( - cfg, self.img_mean_std, self.ref_type, check_patch_mode="centre" - ) - else: - self.visualiser = BatchVisualiserRef(cfg, self.img_mean_std, self.ref_type) - else: - raise NotImplementedError - - def __call__(self): - return self.visualiser diff --git a/pyproject.toml b/pyproject.toml index c6fadf4..1a00203 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,13 +41,11 @@ dependencies = [ # Core deep learning (flexible ranges - user installs PyTorch first) "torch>=2.0.0", "torchvision>=0.15.0", - "lightning>=2.0.0", # DINOv2 backbone "transformers>=4.30.0", # Configuration - "hydra-core>=1.3.0", "omegaconf>=2.3.0", # Model download @@ -57,32 +55,12 @@ dependencies = [ "Pillow>=9.0.0", "imageio>=2.20.0", "numpy>=1.22.0", - "scipy>=1.9.0", - "scikit-image>=0.19.0", "matplotlib>=3.5.0", - # Data - "pandas>=1.4.0", + # Progress bar "tqdm>=4.60.0", ] -[project.optional-dependencies] -# For training (additional deps not needed for inference) -train = [ - "wandb>=0.15.0", - "scikit-learn>=1.0.0", - "accelerate>=0.20.0", - "tensorboard>=2.10.0", -] -# Memory-efficient attention (optional, for large images) -xformers = [ - "xformers>=0.0.20", -] -dev = [ - "pytest>=7.0", - "ruff>=0.1.0", -] - [project.urls] Homepage = "https://crossscore.active.vision" Repository = "https://github.com/ActiveVisionLab/CrossScore" diff --git a/task/core.py b/task/core.py index 3eb8c23..99b8fb1 100644 --- a/task/core.py +++ b/task/core.py @@ -1,4 +1,111 @@ -"""Backwards compatibility: re-export from crossscore package.""" -from crossscore.task.core import CrossScoreLightningModule, CrossScoreNet +"""Training-only Lightning module. Not part of the pip package. -__all__ = ["CrossScoreLightningModule", "CrossScoreNet"] +Imports the CrossScoreNet model from the crossscore package and wraps it +in a LightningModule for training/validation/testing. +""" + +from pathlib import Path +import torch +import lightning +from omegaconf import DictConfig, OmegaConf +from lightning.pytorch.utilities import rank_zero_only + +# Model from the pip package +from crossscore.task.core import CrossScoreNet + +# Training-only imports +from crossscore.utils.io.images import ImageNetMeanStd + + +class CrossScoreLightningModule(lightning.LightningModule): + def __init__(self, cfg: DictConfig): + super().__init__() + self.cfg = cfg + + self.save_hyperparameters(OmegaConf.to_container(self.cfg, resolve=True)) + + # init my network (from pip package) + self.model = CrossScoreNet(cfg=self.cfg) + + # lazy imports for training-only deps + from crossscore.utils.check_config import check_reference_type + + # init loss fn + if self.cfg.model.loss.fn == "l1": + self.loss_fn = torch.nn.L1Loss() + else: + raise NotImplementedError + + # logging related names + self.ref_mode_names = [] + if self.cfg.model.do_reference_cross: + self.ref_mode_names.append("ref_cross") + + def _core_step(self, batch, batch_idx, skip_loss=False): + outputs = self.model( + query_img=batch["query/img"], + ref_cross_imgs=batch.get("reference/cross/imgs", None), + need_attn_weights=self.cfg.model.need_attn_weights, + need_attn_weights_head_id=self.cfg.model.need_attn_weights_head_id, + norm_img=False, + ) + + if skip_loss: + return outputs + + score_map = batch["query/score_map"] + loss = [] + + if self.cfg.model.do_reference_cross: + score_map_cross = outputs["score_map_ref_cross"] + l1_diff_map_cross = torch.abs(score_map_cross - score_map) + if self.cfg.model.loss.fn == "l1": + loss_cross = l1_diff_map_cross.mean() + else: + loss_cross = self.loss_fn(score_map_cross, score_map) + outputs["loss_cross"] = loss_cross + outputs["l1_diff_map_ref_cross"] = l1_diff_map_cross + loss.append(loss_cross) + + loss = torch.stack(loss).sum() + outputs["loss"] = loss + return outputs + + def training_step(self, batch, batch_idx): + return self._core_step(batch, batch_idx) + + def validation_step(self, batch, batch_idx): + return self._core_step(batch, batch_idx) + + def test_step(self, batch, batch_idx): + return self._core_step(batch, batch_idx) + + def predict_step(self, batch, batch_idx): + return self._core_step(batch, batch_idx, skip_loss=True) + + @rank_zero_only + def on_train_batch_end(self, outputs, batch, batch_idx): + self.log("train/loss", outputs["loss"], prog_bar=True) + + def on_validation_batch_end(self, outputs, batch, batch_idx): + self.log("validation/loss", outputs["loss"], prog_bar=True) + + def configure_optimizers(self): + parameters = [p for p in self.model.parameters() if p.requires_grad] + optimizer = torch.optim.AdamW( + params=parameters, + lr=self.cfg.trainer.optimizer.lr, + ) + lr_scheduler = torch.optim.lr_scheduler.StepLR( + optimizer, + step_size=self.cfg.trainer.lr_scheduler.step_size, + gamma=self.cfg.trainer.lr_scheduler.gamma, + ) + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": lr_scheduler, + "interval": self.cfg.trainer.lr_scheduler.step_interval, + "frequency": 1, + }, + } diff --git a/task/predict.py b/task/predict.py index fe03911..e51552b 100644 --- a/task/predict.py +++ b/task/predict.py @@ -9,7 +9,7 @@ import hydra from omegaconf import DictConfig, open_dict -from crossscore.task.core import CrossScoreLightningModule +from task.core import CrossScoreLightningModule from crossscore.dataloading.dataset.simple_reference import SimpleReference from crossscore.dataloading.transformation.crop import CropperFactory from crossscore.utils.io.images import ImageNetMeanStd diff --git a/task/test.py b/task/test.py index 7eb4eb0..f2c1f6d 100644 --- a/task/test.py +++ b/task/test.py @@ -9,7 +9,7 @@ import hydra from omegaconf import DictConfig, open_dict -from crossscore.task.core import CrossScoreLightningModule +from task.core import CrossScoreLightningModule from crossscore.dataloading.data_manager import get_dataset from crossscore.dataloading.transformation.crop import CropperFactory from crossscore.utils.io.images import ImageNetMeanStd diff --git a/task/train.py b/task/train.py index 44a8aea..bc84188 100644 --- a/task/train.py +++ b/task/train.py @@ -13,7 +13,7 @@ from hydra.core.hydra_config import HydraConfig from omegaconf import DictConfig -from crossscore.task.core import CrossScoreLightningModule +from task.core import CrossScoreLightningModule from crossscore.dataloading.data_manager import get_dataset from crossscore.dataloading.transformation.crop import CropperFactory from crossscore.utils.io.images import ImageNetMeanStd From bc614167ce053ed4697ddbe490b34b66fac3f19f Mon Sep 17 00:00:00 2001 From: Claude Date: Sat, 21 Mar 2026 20:09:25 +0000 Subject: [PATCH 4/5] Use GitHub Git LFS for checkpoint download instead of HuggingFace The checkpoint is already hosted via Git LFS in the GitHub repo, so download directly from there using stdlib urllib (no extra deps). Remove huggingface-hub from direct dependencies. URL: https://github.com/ActiveVisionLab/CrossScore/raw/main/ckpt/CrossScore-v1.0.0.ckpt https://claude.ai/code/session_0114iFoswRfTkMai4JTrgMrB --- crossscore/_download.py | 58 ++++++++++++++++++++--------------------- pyproject.toml | 3 --- 2 files changed, 29 insertions(+), 32 deletions(-) diff --git a/crossscore/_download.py b/crossscore/_download.py index 134fa5d..f157e90 100644 --- a/crossscore/_download.py +++ b/crossscore/_download.py @@ -1,10 +1,13 @@ """Utilities for downloading CrossScore model checkpoints.""" import os +import urllib.request +import shutil from pathlib import Path +# Download directly from GitHub (served via Git LFS) CHECKPOINT_URL = ( - "https://huggingface.co/ActiveVisionLab/CrossScore/resolve/main/CrossScore-v1.0.0.ckpt" + "https://github.com/ActiveVisionLab/CrossScore/raw/main/ckpt/CrossScore-v1.0.0.ckpt" ) CHECKPOINT_FILENAME = "CrossScore-v1.0.0.ckpt" @@ -19,9 +22,10 @@ def get_cache_dir() -> Path: def get_checkpoint_path() -> str: """Get path to the CrossScore checkpoint, downloading it if necessary. - Downloads from HuggingFace Hub on first use and caches locally. - Set CROSSSCORE_CACHE_DIR environment variable to customize cache location. - Set CROSSSCORE_CKPT_PATH to use a specific local checkpoint file. + Downloads from GitHub (Git LFS) on first use and caches locally at + ~/.cache/crossscore/. Set environment variables to customize: + CROSSSCORE_CKPT_PATH - use a specific local checkpoint file + CROSSSCORE_CACHE_DIR - custom cache directory Returns: Path to the checkpoint file. @@ -39,33 +43,29 @@ def get_checkpoint_path() -> str: if ckpt_path.exists(): return str(ckpt_path) - print(f"Downloading CrossScore checkpoint to {ckpt_path}...") - print(f" Source: {CHECKPOINT_URL}") - print(" (Set CROSSSCORE_CKPT_PATH to use a local checkpoint instead)") + print(f"Downloading CrossScore checkpoint (~129MB)...") + print(f" From: {CHECKPOINT_URL}") + print(f" To: {ckpt_path}") + print(" (Set CROSSSCORE_CKPT_PATH to skip download and use a local file)") + tmp_path = str(ckpt_path) + ".tmp" try: - from huggingface_hub import hf_hub_download + urllib.request.urlretrieve(CHECKPOINT_URL, tmp_path, _download_progress) + os.rename(tmp_path, str(ckpt_path)) + except Exception: + if os.path.exists(tmp_path): + os.remove(tmp_path) + raise - downloaded_path = hf_hub_download( - repo_id="ActiveVisionLab/CrossScore", - filename=CHECKPOINT_FILENAME, - local_dir=str(cache_dir), - ) - return downloaded_path - except ImportError: - # Fallback to urllib if huggingface_hub not installed - import urllib.request - import shutil + print(f"\n Download complete.") + return str(ckpt_path) - tmp_path = str(ckpt_path) + ".tmp" - try: - with urllib.request.urlopen(CHECKPOINT_URL) as response, open(tmp_path, "wb") as out: - shutil.copyfileobj(response, out) - os.rename(tmp_path, str(ckpt_path)) - except Exception: - if os.path.exists(tmp_path): - os.remove(tmp_path) - raise - print(f" Download complete: {ckpt_path}") - return str(ckpt_path) +def _download_progress(block_count, block_size, total_size): + """Progress callback for urlretrieve.""" + downloaded = block_count * block_size + if total_size > 0: + pct = min(100, downloaded * 100 // total_size) + mb_done = downloaded / (1024 * 1024) + mb_total = total_size / (1024 * 1024) + print(f"\r {mb_done:.1f}/{mb_total:.1f} MB ({pct}%)", end="", flush=True) diff --git a/pyproject.toml b/pyproject.toml index 1a00203..0728490 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,9 +48,6 @@ dependencies = [ # Configuration "omegaconf>=2.3.0", - # Model download - "huggingface-hub>=0.20.0", - # Image processing "Pillow>=9.0.0", "imageio>=2.20.0", From 01a4099d60df2280da4fe7e473b587e757e0c241 Mon Sep 17 00:00:00 2001 From: Claude Date: Sat, 21 Mar 2026 20:14:55 +0000 Subject: [PATCH 5/5] Use HuggingFace Hub for checkpoint download HuggingFace has no bandwidth limits for public models, unlike GitHub LFS (1 GB/month free). huggingface_hub is already a transitive dep of transformers so this adds no new dependencies. Upload checkpoint with: huggingface-cli upload ActiveVisionLab/CrossScore ckpt/CrossScore-v1.0.0.ckpt CrossScore-v1.0.0.ckpt https://claude.ai/code/session_0114iFoswRfTkMai4JTrgMrB --- crossscore/_download.py | 56 +++++++---------------------------------- 1 file changed, 9 insertions(+), 47 deletions(-) diff --git a/crossscore/_download.py b/crossscore/_download.py index f157e90..2a669bc 100644 --- a/crossscore/_download.py +++ b/crossscore/_download.py @@ -1,31 +1,18 @@ """Utilities for downloading CrossScore model checkpoints.""" import os -import urllib.request -import shutil from pathlib import Path -# Download directly from GitHub (served via Git LFS) -CHECKPOINT_URL = ( - "https://github.com/ActiveVisionLab/CrossScore/raw/main/ckpt/CrossScore-v1.0.0.ckpt" -) +HF_REPO_ID = "ActiveVisionLab/CrossScore" CHECKPOINT_FILENAME = "CrossScore-v1.0.0.ckpt" -def get_cache_dir() -> Path: - """Return the cache directory for CrossScore model checkpoints.""" - cache_dir = Path(os.environ.get("CROSSSCORE_CACHE_DIR", Path.home() / ".cache" / "crossscore")) - cache_dir.mkdir(parents=True, exist_ok=True) - return cache_dir - - def get_checkpoint_path() -> str: """Get path to the CrossScore checkpoint, downloading it if necessary. - Downloads from GitHub (Git LFS) on first use and caches locally at - ~/.cache/crossscore/. Set environment variables to customize: + Downloads from HuggingFace Hub on first use and caches locally. + Set environment variables to customize: CROSSSCORE_CKPT_PATH - use a specific local checkpoint file - CROSSSCORE_CACHE_DIR - custom cache directory Returns: Path to the checkpoint file. @@ -37,35 +24,10 @@ def get_checkpoint_path() -> str: raise FileNotFoundError(f"Checkpoint not found at CROSSSCORE_CKPT_PATH={custom_path}") return custom_path - cache_dir = get_cache_dir() - ckpt_path = cache_dir / CHECKPOINT_FILENAME - - if ckpt_path.exists(): - return str(ckpt_path) - - print(f"Downloading CrossScore checkpoint (~129MB)...") - print(f" From: {CHECKPOINT_URL}") - print(f" To: {ckpt_path}") - print(" (Set CROSSSCORE_CKPT_PATH to skip download and use a local file)") - - tmp_path = str(ckpt_path) + ".tmp" - try: - urllib.request.urlretrieve(CHECKPOINT_URL, tmp_path, _download_progress) - os.rename(tmp_path, str(ckpt_path)) - except Exception: - if os.path.exists(tmp_path): - os.remove(tmp_path) - raise - - print(f"\n Download complete.") - return str(ckpt_path) - + from huggingface_hub import hf_hub_download -def _download_progress(block_count, block_size, total_size): - """Progress callback for urlretrieve.""" - downloaded = block_count * block_size - if total_size > 0: - pct = min(100, downloaded * 100 // total_size) - mb_done = downloaded / (1024 * 1024) - mb_total = total_size / (1024 * 1024) - print(f"\r {mb_done:.1f}/{mb_total:.1f} MB ({pct}%)", end="", flush=True) + path = hf_hub_download( + repo_id=HF_REPO_ID, + filename=CHECKPOINT_FILENAME, + ) + return path