diff --git a/.gitignore b/.gitignore index de52713..ffb57d3 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,12 @@ tmp.sh log* datadir* debug* + +# Package build artifacts +dist/ +build/ +*.egg-info/ +*.egg + +# Plan file +plan.md diff --git a/README.md b/README.md index 7445e99..20a56c8 100644 --- a/README.md +++ b/README.md @@ -13,13 +13,71 @@ University of Oxford. ## Table of Content -- [Environment](#Environment) +- [Installation](#installation) +- [Quick Start](#quick-start) +- [Environment (conda, legacy)](#environment-conda-legacy) - [Data](#Data) - [Training](#Training) - [Inferencing](#Inferencing) -## Environment -We provide a `environment.yaml` file to set up a `conda` environment: +## Installation + +### GPU (with CUDA) — one line +```bash +git clone https://github.com/ActiveVisionLab/CrossScore.git +cd CrossScore +conda env create -f environment_gpu.yaml && conda activate CrossScore +``` +This installs Python, PyTorch with CUDA 12.1, and all CrossScore dependencies in a single command. + +### CPU only — one line +```bash +git clone https://github.com/ActiveVisionLab/CrossScore.git +cd CrossScore +conda env create -f environment_cpu.yaml && conda activate CrossScore +``` + +> **Note:** If you use the CPU install, CrossScore will print a reminder at runtime on how to switch to the GPU version for faster inference. + +## Quick Start + +### Python API +```python +from crossscore.api import score + +# Score query images against reference images +# Model checkpoint is auto-downloaded on first use (~129MB) +results = score( + query_dir="path/to/query/images", + reference_dir="path/to/reference/images", +) + +# Per-image mean scores +print(results["scores"]) # [0.82, 0.91, 0.76, ...] 
+ +# Score map tensors (pixel-level quality maps) +for score_map in results["score_maps"]: + print(score_map.shape) # (batch_size, H, W) + +# Colorized score map PNGs are written to results["out_dir"] +``` + +### Command Line +```bash +python -m crossscore.cli --query-dir path/to/queries --reference-dir path/to/references + +# With options +python -m crossscore.cli --query-dir renders/ --reference-dir gt/ --metric-type mae --batch-size 4 + +# Force CPU mode +python -m crossscore.cli --query-dir renders/ --reference-dir gt/ --cpu +``` + +### Environment Variables +- `CROSSSCORE_CKPT_PATH`: Use a specific local checkpoint instead of auto-downloading + +## Environment (conda, legacy) +For training and development, we provide the full pinned `environment.yaml`: ```bash git clone https://github.com/ActiveVisionLab/CrossScore.git cd CrossScore @@ -86,7 +144,7 @@ on our project page. - [ ] Create a HuggingFace demo page. - [ ] Release ECCV quantitative results related scripts. - [x] Release [data processing scripts](https://github.com/ziruiw-dev/CrossScore-3DGS-Preprocessing) -- [ ] Release PyPI and Conda package. +- [x] Release Conda package. ## Acknowledgement This research is supported by an diff --git a/crossscore/__init__.py b/crossscore/__init__.py new file mode 100644 index 0000000..356ca68 --- /dev/null +++ b/crossscore/__init__.py @@ -0,0 +1,38 @@ +"""CrossScore: Towards Multi-View Image Evaluation and Scoring. + +A pip-installable package for neural image quality assessment using +cross-reference scoring with DINOv2 backbone. + +Example: + >>> import crossscore + >>> results = crossscore.score( + ... query_dir="path/to/query/images", + ... reference_dir="path/to/reference/images", + ... ) + >>> print(results["scores"]) # per-image mean scores +""" + +__version__ = "1.0.0" + + +def score(*args, **kwargs): + """Score query images against reference images using CrossScore. + + See crossscore.api.score for full documentation. 
+ """ + from crossscore.api import score as _score + + return _score(*args, **kwargs) + + +def get_checkpoint_path(): + """Get path to the CrossScore checkpoint, downloading if necessary. + + See crossscore._download.get_checkpoint_path for full documentation. + """ + from crossscore._download import get_checkpoint_path as _get + + return _get() + + +__all__ = ["score", "get_checkpoint_path"] diff --git a/crossscore/_download.py b/crossscore/_download.py new file mode 100644 index 0000000..2a669bc --- /dev/null +++ b/crossscore/_download.py @@ -0,0 +1,33 @@ +"""Utilities for downloading CrossScore model checkpoints.""" + +import os +from pathlib import Path + +HF_REPO_ID = "ActiveVisionLab/CrossScore" +CHECKPOINT_FILENAME = "CrossScore-v1.0.0.ckpt" + + +def get_checkpoint_path() -> str: + """Get path to the CrossScore checkpoint, downloading it if necessary. + + Downloads from HuggingFace Hub on first use and caches locally. + Set environment variables to customize: + CROSSSCORE_CKPT_PATH - use a specific local checkpoint file + + Returns: + Path to the checkpoint file. 
+ """ + # Allow user to override with a custom path + custom_path = os.environ.get("CROSSSCORE_CKPT_PATH") + if custom_path: + if not Path(custom_path).exists(): + raise FileNotFoundError(f"Checkpoint not found at CROSSSCORE_CKPT_PATH={custom_path}") + return custom_path + + from huggingface_hub import hf_hub_download + + path = hf_hub_download( + repo_id=HF_REPO_ID, + filename=CHECKPOINT_FILENAME, + ) + return path diff --git a/crossscore/api.py b/crossscore/api.py new file mode 100644 index 0000000..17312dd --- /dev/null +++ b/crossscore/api.py @@ -0,0 +1,179 @@ +"""High-level API for CrossScore image quality assessment.""" + +from pathlib import Path +from typing import Optional, Union, List + +import numpy as np +import torch +from torch.utils.data import DataLoader +from torchvision.transforms import v2 as T +from omegaconf import OmegaConf +from tqdm import tqdm + +from crossscore._download import get_checkpoint_path +from crossscore.utils.io.images import ImageNetMeanStd +from crossscore.dataloading.dataset.simple_reference import SimpleReference + + +def _write_score_maps(score_maps, query_paths, out_dir, metric_type, metric_min, metric_max): + """Write score maps to disk as colorized PNGs.""" + from PIL import Image + from crossscore.utils.misc.image import gray2rgb + + vrange_vis = [metric_min, metric_max] + out_dir = Path(out_dir) / "score_maps" + out_dir.mkdir(parents=True, exist_ok=True) + + idx = 0 + for batch_maps, batch_paths in zip(score_maps, query_paths): + for score_map, qpath in zip(batch_maps, batch_paths): + fname = Path(qpath).stem + ".png" + rgb = gray2rgb(score_map.cpu().numpy(), vrange_vis) + Image.fromarray(rgb).save(out_dir / fname) + idx += 1 + return str(out_dir) + + +def score( + query_dir: str, + reference_dir: str, + ckpt_path: Optional[str] = None, + metric_type: str = "ssim", + batch_size: int = 8, + num_workers: int = 4, + resize_short_side: int = 518, + device: Optional[str] = None, + out_dir: Optional[str] = None, + 
write_score_maps: bool = True, +) -> dict: + """Score query images against reference images using CrossScore. + + Args: + query_dir: Directory containing query images (e.g., NVS rendered images). + reference_dir: Directory containing reference images (e.g., real captured images). + ckpt_path: Path to model checkpoint. Auto-downloads if not provided. + metric_type: Metric type to predict. One of "ssim", "mae", "mse". + batch_size: Batch size for inference. + num_workers: Number of data loading workers. + resize_short_side: Resize images so short side equals this value. -1 to disable. + device: Device string ("cuda", "cuda:0", "cpu"). Auto-detected if None. + out_dir: Output directory for score maps. Defaults to "./crossscore_output". + write_score_maps: Whether to write colorized score map PNGs to disk. + + Returns: + Dictionary with: + - "score_maps": List of score map tensors, each (B, H, W) + - "scores": List of per-image mean scores (float) + - "out_dir": Output directory path (if write_score_maps=True) + + Example: + >>> import crossscore + >>> results = crossscore.score( + ... query_dir="path/to/query/images", + ... reference_dir="path/to/reference/images", + ... ) + >>> print(results["scores"]) # per-image mean scores + """ + from crossscore.task.core import load_model + + # Determine device + if device is None: + if torch.cuda.is_available(): + device = "cuda" + else: + device = "cpu" + print( + "Note: CUDA not available, running on CPU. 
" + "For GPU acceleration, install with:\n" + " conda env create -f environment_gpu.yaml" + ) + + # Get checkpoint + if ckpt_path is None: + ckpt_path = get_checkpoint_path() + + # Load model + model = load_model(ckpt_path, device=device) + + # Set up data transforms + img_norm_stat = ImageNetMeanStd() + transforms = { + "img": T.Normalize(mean=img_norm_stat.mean, std=img_norm_stat.std), + } + if resize_short_side > 0: + transforms["resize"] = T.Resize( + resize_short_side, + interpolation=T.InterpolationMode.BILINEAR, + antialias=True, + ) + + # Build dataset and dataloader + neighbour_config = {"strategy": "random", "cross": 5, "deterministic": False} + dataset = SimpleReference( + query_dir=query_dir, + reference_dir=reference_dir, + transforms=transforms, + neighbour_config=neighbour_config, + return_item_paths=True, + zero_reference=False, + ) + + dataloader = DataLoader( + dataset, + batch_size=batch_size, + shuffle=False, + num_workers=num_workers, + pin_memory=(device != "cpu"), + persistent_workers=False, + ) + + # Run inference + all_score_maps = [] + all_scores = [] + all_query_paths = [] + + with torch.no_grad(): + for batch in tqdm(dataloader, desc="CrossScore"): + query_img = batch["query/img"].to(device) + ref_imgs = batch.get("reference/cross/imgs") + if ref_imgs is not None: + ref_imgs = ref_imgs.to(device) + + outputs = model( + query_img=query_img, + ref_cross_imgs=ref_imgs, + norm_img=False, + ) + + score_map = outputs["score_map_ref_cross"] # (B, H, W) + all_score_maps.append(score_map.cpu()) + + # Per-image mean score + for i in range(score_map.shape[0]): + all_scores.append(score_map[i].mean().item()) + + # Track query paths for output naming + if "item_paths" in batch and "query/img" in batch["item_paths"]: + all_query_paths.append(batch["item_paths"]["query/img"]) + + # Build results + metric_min = -1 if metric_type == "ssim" else 0 + if metric_type == "ssim": + metric_min = 0 # CrossScore predicts SSIM in [0, 1] by default + + results = { 
+ "score_maps": all_score_maps, + "scores": all_scores, + } + + # Write outputs + if write_score_maps and all_score_maps: + if out_dir is None: + out_dir = "./crossscore_output" + written_dir = _write_score_maps( + all_score_maps, all_query_paths, out_dir, + metric_type, metric_min, metric_max=1, + ) + results["out_dir"] = written_dir + + return results diff --git a/crossscore/cli.py b/crossscore/cli.py new file mode 100644 index 0000000..7be9d9c --- /dev/null +++ b/crossscore/cli.py @@ -0,0 +1,98 @@ +"""Command-line interface for CrossScore.""" + +import argparse + + +def main(): + parser = argparse.ArgumentParser( + description="CrossScore: Multi-View Image Quality Assessment", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog="""\ +Examples: + crossscore --query-dir path/to/queries --reference-dir path/to/references + crossscore --query-dir renders/ --reference-dir gt/ --metric-type mae --batch-size 4 + crossscore --query-dir renders/ --reference-dir gt/ --cpu +""", + ) + parser.add_argument( + "--query-dir", required=True, help="Directory containing query images" + ) + parser.add_argument( + "--reference-dir", required=True, help="Directory containing reference images" + ) + parser.add_argument( + "--ckpt-path", + default=None, + help="Path to model checkpoint (auto-downloads if not provided)", + ) + parser.add_argument( + "--metric-type", + default="ssim", + choices=["ssim", "mae", "mse"], + help="Metric type to predict (default: ssim)", + ) + parser.add_argument( + "--batch-size", type=int, default=8, help="Batch size (default: 8)" + ) + parser.add_argument( + "--num-workers", type=int, default=4, help="Data loading workers (default: 4)" + ) + parser.add_argument( + "--resize-short-side", + type=int, + default=518, + help="Resize short side to this value, -1 to disable (default: 518)", + ) + parser.add_argument( + "--device", + default=None, + help="Device string, e.g. 
'cuda', 'cuda:0', 'cpu' (default: auto-detect)", + ) + parser.add_argument( + "--cpu", + action="store_true", + help="Force CPU mode (no GPU)", + ) + parser.add_argument( + "--out-dir", + default=None, + help="Output directory for results (default: ./crossscore_output)", + ) + parser.add_argument( + "--no-write", + action="store_true", + help="Do not write score map images to disk", + ) + + args = parser.parse_args() + + from crossscore.api import score + + device = "cpu" if args.cpu else args.device + + results = score( + query_dir=args.query_dir, + reference_dir=args.reference_dir, + ckpt_path=args.ckpt_path, + metric_type=args.metric_type, + batch_size=args.batch_size, + num_workers=args.num_workers, + resize_short_side=args.resize_short_side, + device=device, + out_dir=args.out_dir, + write_score_maps=not args.no_write, + ) + + n_images = len(results["scores"]) + print(f"\nCrossScore completed: {n_images} images scored") + if results["scores"]: + mean_score = sum(results["scores"]) / len(results["scores"]) + print(f"Mean score: {mean_score:.4f}") + for i, s in enumerate(results["scores"]): + print(f" Image {i}: {s:.4f}") + if "out_dir" in results: + print(f"Score maps written to: {results['out_dir']}") + + +if __name__ == "__main__": + main() diff --git a/crossscore/config/model/model.yaml b/crossscore/config/model/model.yaml new file mode 100644 index 0000000..06fc8d6 --- /dev/null +++ b/crossscore/config/model/model.yaml @@ -0,0 +1,32 @@ +patch_size: 14 +do_reference_cross: True + +decoder_do_self_attn: True +decoder_do_short_cut: True +need_attn_weights: False # def False, requires more gpu mem if True +need_attn_weights_head_id: 0 # check which attn head + +backbone: + from_pretrained: facebook/dinov2-small + +pos_enc: + multi_view: + interpolate_mode: bilinear + req_grad: False + h: 40 # def 40 so we always interpolate in training, could be 37 or 16 too. 
+ w: 40 + +loss: + fn: l1 + +predict: + metric: + type: ssim + # type: mae + # type: mse + + # min: -1 + min: 0 + max: 1 + + power_factor: default # can be a scalar \ No newline at end of file diff --git a/crossscore/dataloading/__init__.py b/crossscore/dataloading/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/crossscore/dataloading/dataset/__init__.py b/crossscore/dataloading/dataset/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/crossscore/dataloading/dataset/nvs_dataset.py b/crossscore/dataloading/dataset/nvs_dataset.py new file mode 100644 index 0000000..8e6cc7f --- /dev/null +++ b/crossscore/dataloading/dataset/nvs_dataset.py @@ -0,0 +1,717 @@ +import os, sys, json +from pathlib import Path +import numpy as np +import torch +from torch.utils.data import Dataset +from omegaconf import OmegaConf + +from crossscore.utils.io.images import metric_map_read, image_read +from crossscore.utils.neighbour.sampler import SamplerFactory +from crossscore.utils.check_config import ConfigChecker + + +class NeighbourSelector: + """Return paths for neighbouring images (and metric maps) for query and reference.""" + + def __init__(self, paths, neighbour_config): + self.paths = paths + self.neighbour_config = neighbour_config + + self.idx_to_property_mapper = self._build_idx_to_property_mapper(self.paths) + self.all_scene_names = np.array(sorted(self.paths.keys())) + + self.neighbour_sampler_ref_cross = None + if self.neighbour_config["cross"] > 0: + # since query is not in cross ref set, we can only do random sampling + self.neighbour_sampler_ref_cross = SamplerFactory( + strategy_name="random", + N_sample=self.neighbour_config["cross"], + include_query=False, + deterministic=self.neighbour_config["deterministic"], + ) + + @staticmethod + def _build_idx_to_property_mapper(paths): + """Only consider query, since the dataloading idx is based on number of query images""" + + scene_name_list = sorted(paths.keys()) + gaussian_split_list = 
["train", "test"] + global_idx = 0 + + idx_to_property_mapper = {} + for sn in scene_name_list: + for gs_split in gaussian_split_list: + if f"gs_{gs_split}" not in paths[sn].keys(): + continue + n_iter = paths[sn][f"gs_{gs_split}"]["query"]["N_iters"] + n_imgs_per_iter = paths[sn][f"gs_{gs_split}"]["query"]["N_imgs_per_iter"] + n_imgs = n_iter * n_imgs_per_iter + for idx in range(n_imgs): + idx_to_property_mapper[global_idx] = { + "scene_name": sn, + "gaussian_split": gs_split, + "iter_idx": idx // n_imgs_per_iter, + "img_idx": idx % n_imgs_per_iter, + } + global_idx += 1 + return idx_to_property_mapper + + def __len__(self): + return len(self.idx_to_property_mapper) + + def __getitem__(self, idx): + results = { + "query/img": None, + "query/score_map": None, + "reference/cross/imgs": [], + } + + scene_name, gaussian_split, iter_idx, img_idx = self.idx_to_property_mapper[idx].values() + tmp_split_paths = self.paths[scene_name][f"gs_{gaussian_split}"] + iter_name = list(tmp_split_paths["query"]["images"].keys())[iter_idx] + + results["query/img"] = tmp_split_paths["query"]["images"][iter_name][img_idx] + results["query/score_map"] = tmp_split_paths["query"]["score_map"][iter_name][img_idx] + + # sampling reference images for cross set + if self.neighbour_sampler_ref_cross is not None: + ref_list_cross = tmp_split_paths["reference"]["cross"]["images"][iter_name] + results["reference/cross/imgs"] = self.neighbour_sampler_ref_cross( + query=None, ref_list=ref_list_cross + ) + + return results + + +class NvsDataset(Dataset): + + def __init__( + self, + dataset_path, + resolution, + data_split, + transforms, + neighbour_config, + metric_type, + metric_min, + metric_max, + return_debug_info=False, + return_item_paths=False, + **kwargs, + ): + """ + :param scene_path: Gaussian Splatting output scene dir that contains point cloud, test, train etc. 
+ :param query_split: train or test + :param transforms: a dict of transforms for all, img, metric_map + """ + self.transforms = transforms + self.neighbour_config = neighbour_config + self.return_debug_info = return_debug_info + self.return_item_paths = return_item_paths + self.zero_reference = kwargs.get("zero_reference", False) + self.num_gaussians_iters = kwargs.get("num_gaussians_iters", -1) + + if data_split not in ["train", "test", "val", "val_small", "test_small"]: + raise ValueError(f"Unknown data_split {data_split}") + + self._detect_conflict_transforms() + self.metric_config = self._build_metric_config(metric_type, metric_min, metric_max) + + # read split json for scene names + if resolution is None: + resolution = os.listdir(dataset_path)[0] + self.dataset_path = Path(dataset_path, resolution) + with open(self.dataset_path / "split.json", "r") as f: + scene_names = json.load(f)[data_split] + scene_paths = [self.dataset_path / n for n in sorted(scene_names)] + + # We use same split for all processed methods (e.g. gaussian, nerfacto, etc.). + # Some scenes may not be processed by some methods, so we need to filter out. + scene_paths = [p for p in scene_paths if p.exists()] + + # Define query and ref sets. Get all paths points to images and metric maps. 
+ self.all_paths = self.get_paths( + scene_paths, self.num_gaussians_iters, self.metric_config.load_dir + ) + + self.neighbour_selector = NeighbourSelector( + self.all_paths, + self.neighbour_config, + ) + + def __getitem__(self, idx): + # neighouring logic + item_paths = self.neighbour_selector[idx] + + # load content from related paths + result = self.load_content(item_paths, self.zero_reference, self.metric_config) + + if "resize" in self.transforms: + result = self.resize_all(result) + + if "crop_integer_patches" in self.transforms: + result = self.adaptive_crop_integer_patches_all(result) + + if self.return_debug_info: + result["debug"] = { + "query/ori_img": result["query/img"], + "query/ori_score_map": result["query/score_map"], + "reference/cross/ori_imgs": result["reference/cross/imgs"], + } + + if self.return_item_paths: + result["item_paths"] = item_paths + + # apply transforms to query + transformed_query = self.transform_query( + result["query/img"], + result["query/score_map"], + ) + result["query/img"] = transformed_query["img"] + result["query/score_map"] = transformed_query["score_map"] + if self.return_debug_info: + result["debug"]["query/crop_param"] = transformed_query["crop_param"] + + if self.neighbour_config["cross"] > 0: + transformed_ref_cross = self.transform_reference(result["reference/cross/imgs"]) + result["reference/cross/imgs"] = transformed_ref_cross["imgs"] + if self.return_debug_info: + result["debug"]["reference/cross/crop_param"] = transformed_ref_cross["crop_param"] + else: + del result["reference/cross/imgs"] + return result + + @staticmethod + def collate_fn_debug(batch): + """Only return the first item in the batch, because the original images + before cropping are in different sizes. + Using [None] to add a batch dimension at the front. 
+ """ + result = { + "query/img": batch[0]["query/img"][None], + "query/score_map": batch[0]["query/score_map"][None], + } + + result["debug"] = { + "query/ori_img": batch[0]["debug"]["query/ori_img"][None], + "query/ori_score_map": batch[0]["debug"]["query/ori_score_map"][None], + "query/crop_param": batch[0]["debug"]["query/crop_param"][None], + } + + result["item_paths"] = batch[0]["item_paths"] + + if "reference/cross/imgs" in batch[0].keys(): + result["reference/cross/imgs"] = batch[0]["reference/cross/imgs"][None] + result["debug"]["reference/cross/ori_imgs"] = batch[0]["debug"][ + "reference/cross/ori_imgs" + ][None] + result["debug"]["reference/cross/crop_param"] = batch[0]["debug"][ + "reference/cross/crop_param" + ][None] + + return result + + def __len__(self): + return len(self.neighbour_selector) + + def resize_all(self, results): + results["query/img"] = self.transforms["resize"](results["query/img"]) + results["query/score_map"] = self.transforms["resize"](results["query/score_map"][None])[0] + if "reference/cross/imgs" in results.keys(): + results["reference/cross/imgs"] = self.transforms["resize"]( + results["reference/cross/imgs"] + ) + return results + + def adaptive_crop_integer_patches_all(self, results): + """ + Adaptively crop all images to the closest integer patch size. + This is needed for test_steps, where we need to compute loss on images in arbitrary sizes. 
+ """ + P = 14 # dinov2 patch size + ori_h, ori_w = results["query/img"].shape[-2:] + new_h = ori_h - ori_h % P + new_w = ori_w - ori_w % P + results["query/img"] = results["query/img"][:, :new_h, :new_w] + results["query/score_map"] = results["query/score_map"][:new_h, :new_w] + if len(results["reference/cross/imgs"]) > 0: + results["reference/cross/imgs"] = results["reference/cross/imgs"][:, :, :new_h, :new_w] + return results + + def transform_query(self, img, score_map): + if self.transforms.get("query_crop", None) is not None: + crop_results = self.transforms["query_crop"](img, score_map) + img = crop_results["out"][0] + score_map = crop_results["out"][1] + crop_param = crop_results["crop_param"] + else: + crop_param = torch.tensor([0, 0, *img.shape[-2:]]) # (4) + + if self.transforms.get("img", None) is not None: + img = self.transforms["img"](img) + + if self.transforms.get("metric_map", None) is not None: + score_map = self.transforms["metric_map"](score_map[None, None]) + score_map = score_map[0, 0] + + return { + "img": img, + "score_map": score_map, + "crop_param": crop_param, + } + + def transform_reference(self, imgs): + if self.transforms.get("reference_crop", None) is not None: + crop_results = self.transforms["reference_crop"](imgs) + imgs = crop_results["out"] + crop_param = crop_results["crop_param"] + else: + crop_param = torch.stack( + [torch.tensor([0, 0, *img.shape[-2:]]) for img in imgs] + ) # (B, 4) + + if self.transforms.get("img", None) is not None: + imgs = self.transforms["img"](imgs) + return { + "imgs": imgs, + "crop_param": crop_param, + } + + def _detect_conflict_transforms(self): + if "resize" in self.transforms: + crop_sizes = [] + if "query_crop" in self.transforms: + crop_sizes.append(self.transforms["query_crop"].output_size) + if "reference_crop" in self.transforms: + crop_sizes.append(self.transforms["reference_crop"].output_size) + + if len(crop_sizes) > 0: + max_crop_size = np.max(crop_sizes) + min_resize_size = 
np.min(self.transforms["resize"].size) + if min_resize_size < max_crop_size: + raise ValueError( + f"Required to resize image before crop, " + f"but min_resize_size {min_resize_size} " + f"< max_crop_size {max_crop_size}" + ) + + def _build_metric_config(self, metric_type, metric_min, metric_max): + """ + Convert predict_type to load_dir and vrange for metric map reading. + Supported predict_types: ssim_0_1, ssim_-1_1, mse, mae + """ + vrange = [metric_min, metric_max] + + if metric_type in ["ssim", "mae"]: + load_dir = f"metric_map/{metric_type}" + elif metric_type in ["mse"]: + load_dir = "metric_map/mae" # mse can be derived from mae + else: + raise ValueError(f"Invalid metric type {metric_type}") + + cfg = { + "type": metric_type, + "vrange": vrange, + "load_dir": load_dir, + } + cfg = OmegaConf.create(cfg) + return cfg + + @staticmethod + def get_paths(scene_paths, num_gaussians_iters, metric_load_dir): + """Get paths points to images and metric maps. Define query and refenence sets. + Naming convention: + Query: + The (noisy) image we want to measure. + Reference: + Images for making predictions. + For cross ref, we consider captured images that from ref splits. + Example: + When query_split is "train", we consider captured test images as cross ref. + When query_split is "test", we consider captured training images as cross ref. 
+ """ + scene_name_list = sorted([scene_path.name for scene_path in scene_paths]) + all_paths = { + scene_name: { + "train": { + "renders": {}, + "gt": {}, + "score_map": {}, + }, + "test": { + "renders": {}, + "gt": {}, + "score_map": {}, + }, + } + for scene_name in scene_name_list + } + + for scene_path in scene_paths: + scene_name = scene_path.name + for gs_split in all_paths[scene_name].keys(): + dir_split = Path(scene_path, gs_split) + dir_iter_list = sorted(os.listdir(dir_split), key=lambda x: int(x.split("_")[-1])) + dir_iter_list = [Path(dir_split, d) for d in dir_iter_list] + + # This dataset contains images rendered from gaussian splatting checkpoints + # at different iterations. Use this to use images from earlier checkpoints, + # which have more artefacts. + if num_gaussians_iters > 0: + dir_iter_list = dir_iter_list[:num_gaussians_iters] + + for dir_iter in dir_iter_list: + iter_num = int(dir_iter.name.split("_")[-1]) + for img_type in all_paths[scene_name][gs_split].keys(): + if img_type in ["renders", "gt"]: + img_dir = dir_iter / img_type + elif img_type == "score_map": + img_dir = dir_iter / metric_load_dir + else: + raise ValueError(f"Unknown img_type {img_type}") + + if os.path.exists(img_dir): + img_names = sorted(os.listdir(img_dir)) + paths = [img_dir / img_n for img_n in img_names] + paths = [str(p) for p in paths] + else: + # if no metric map available, use a placeholder "empty_image" + paths = ["empty_image"] * len(all_paths[scene_name][gs_split]["gt"]) + all_paths[scene_name][gs_split][img_type][iter_num] = paths + + # all types of items should have the same item number as gt + for img_type in all_paths[scene_name][gs_split].keys(): + for iter_num in all_paths[scene_name][gs_split][img_type].keys(): + N_imgs = len(all_paths[scene_name][gs_split][img_type][iter_num]) + N_gt = len(all_paths[scene_name][gs_split]["gt"][iter_num]) + if N_imgs != N_gt: + raise ValueError( + f"Number of items mismatch in " + 
f"{scene_name}/{gs_split}/{iter_num}/{img_type}" + ) + + # assemble query and reference sets + def get_cross_ref_split(query_split): + all_splits = ["train", "test"] + all_splits.remove(query_split) + cross_ref_split = all_splits[0] + return cross_ref_split + + results = {} + for scene_name in scene_name_list: + results[scene_name] = {} + for gs_split in ["train", "test"]: + cross_ref_split = get_cross_ref_split(gs_split) + + results[scene_name][f"gs_{gs_split}"] = { + "query": { + "images": all_paths[scene_name][gs_split]["renders"], + "score_map": all_paths[scene_name][gs_split]["score_map"], + "N_iters": len(all_paths[scene_name][gs_split]["renders"]), + "N_imgs_per_iter": len( + list(all_paths[scene_name][gs_split]["renders"].values())[0] + ), + }, + "reference": { + "cross": { + "images": all_paths[scene_name][cross_ref_split]["gt"], + "N_iters": len(all_paths[scene_name][cross_ref_split]["gt"]), + "N_imgs_per_iter": len( + list(all_paths[scene_name][cross_ref_split]["gt"].values())[0] + ), + }, + }, + } + return results + + @staticmethod + def load_content(item_paths, zero_reference, metric_config): + results = { + "query/img": None, + "query/score_map": None, + "reference/cross/imgs": [], + } + + for k in item_paths.keys(): + if k == "query/img": + results[k] = torch.tensor(image_read(item_paths[k])).permute(2, 0, 1) # (3, H, W) + elif k == "query/score_map": + if metric_config.type == "ssim": + if item_paths[k] == "empty_image": + results[k] = torch.zeros_like(results["query/img"][0]) # (H, W) + else: + # (H, W), always read SSIM in range [-1, 1] + results[k] = torch.tensor(metric_map_read(item_paths[k], vrange=[-1, 1])) + if metric_config.vrange == [0, 1]: + results[k] = results[k].clamp(0, 1) + elif metric_config.type in ["mse", "mae"]: + if item_paths[k] == "empty_image": + results[k] = torch.full_like(results["query/img"][0], torch.nan) # (H, W) + else: + # (H, W), always read MAE in range [0, 1] + results[k] = 
torch.tensor(metric_map_read(item_paths[k], vrange=[0, 1])) + if metric_config.type == "mse": + results[k] = results[k].square() # create mse from loaded mae + elif metric_config.type is None: + # for SimpleReference, which doesn't need to load score maps + results[k] = torch.zeros_like(results["query/img"][0]) + elif k in ["reference/cross/imgs"]: + if len(item_paths[k]) == 0: + continue # if not loading this reference set, skip + for path in item_paths[k]: + if path == "empty_image": + # NOTE: this assumes ref_img_size == query_img_size + tmp_img = torch.zeros_like(results["query/img"]) # (3, H, W) + else: + tmp_img = torch.tensor(image_read(path)).permute(2, 0, 1) # (3, H, W) + results[k].append(tmp_img) # (3, H, W) + results[k] = torch.stack(results[k], dim=0) # (N, 3, H, W) + if zero_reference: + results[k] = torch.zeros_like(results[k]) + else: + raise ValueError(f"Unknown key {k}") + return results + + +def vis_batch(cfg, batch, metric_type, metric_min, metric_max, e, b): + import matplotlib.pyplot as plt + from matplotlib.patches import Rectangle + + metric_vrange = [metric_min, metric_max] + + # mkdir to save figures + save_fig_dir = Path(cfg.this_main.save_fig_dir).expanduser() + save_fig_dir.mkdir(parents=True, exist_ok=True) + + # Vis batch[0] in two figures: one actual loaded and one with debug info + # First figure with actual loaded data + max_cols = max([3, cfg.data.neighbour_config.cross]) + _, axes = plt.subplots(2, max_cols, figsize=(15, 9)) + for ax in axes.flatten(): + ax.set_axis_off() + + # first row: query + axes[0][0].imshow(batch["query/img"][0].permute(1, 2, 0).clip(0, 1)) + axes[0][0].set_title("query/img") + axes[0][1].imshow( + batch["query/score_map"][0], + vmin=metric_vrange[0], + vmax=metric_vrange[1], + cmap="turbo", + ) + axes[0][1].set_title(f"query/{metric_type}_map") + for i in range(2): + axes[0][i].set_axis_on() + + # second row: cross ref + if "reference/cross/imgs" in batch.keys(): + for i in 
range(batch["reference/cross/imgs"].shape[1]): + axes[1][i].imshow(batch["reference/cross/imgs"][0, i].permute(1, 2, 0).clip(0, 1)) + axes[1][i].set_title(f"reference/cross/imgs_{i}") + axes[1][i].set_axis_on() + + plt.tight_layout() + plt.savefig(save_fig_dir / f"e{e}b{b}.jpg") + plt.close() + + # Second figure with full debug details in three rows + if cfg.this_main.return_debug_info: + _, axes = plt.subplots(2, max_cols, figsize=(20, 10)) + for ax in axes.flatten(): + ax.set_axis_off() + + # first row: query + axes[0][0].imshow(batch["debug"]["query/ori_img"][0].permute(1, 2, 0).clip(0, 1)) + axes[0][0].set_title("query/ori_img") + axes[0][1].imshow( + batch["debug"][f"query/ori_score_map"][0], + vmin=metric_vrange[0], + vmax=metric_vrange[1], + cmap="turbo", + ) + axes[0][1].set_title(f"query/ori_{metric_type}_map") + # crop box + crop_param = batch["debug"]["query/crop_param"][0] + for i in range(2): + rect = Rectangle( + (crop_param[1], crop_param[0]), + crop_param[3], + crop_param[2], + linewidth=3, + edgecolor="r", + facecolor="none", + ) + axes[0][i].add_patch(rect) + axes[0][i].set_axis_on() + + # third row: cross ref + if "reference/cross/imgs" in batch.keys(): + for i in range(batch["debug"]["reference/cross/ori_imgs"].shape[1]): + axes[1][i].imshow( + batch["debug"]["reference/cross/ori_imgs"][0, i].permute(1, 2, 0).clip(0, 1) + ) + axes[1][i].set_title(f"reference/cross/ori_imgs_{i}") + # crop box + crop_param = batch["debug"]["reference/cross/crop_param"][0][i] + rect = Rectangle( + (crop_param[1], crop_param[0]), + crop_param[3], + crop_param[2], + linewidth=3, + edgecolor="r", + facecolor="none", + ) + axes[1][i].add_patch(rect) + axes[1][i].set_axis_on() + + plt.tight_layout() + plt.savefig(save_fig_dir / f"e{e}b{b}_full.jpg") + plt.close() + + +if __name__ == "__main__": + from lightning import seed_everything + from torchvision.transforms import v2 as T + from tqdm import tqdm + from dataloading.transformation.crop import CropperFactory + from 
utils.io.images import ImageNetMeanStd + from omegaconf import OmegaConf + + seed_everything(1) + + cfg = { + "data": { + "dataset": { + "path": "datadir/processed_training_ready/gaussian/map-free-reloc", + "resolution": "res_540", + "num_gaussians_iters": 1, + "zero_reference": False, + }, + "loader": { + "batch_size": 8, + "num_workers": 4, + "shuffle": True, + "data_split": "train", + "pin_memory": True, + "persistent_workers": True, + }, + "transforms": { + "crop_size": 518, + }, + "neighbour_config": { + "strategy": "random", + "cross": 5, + "deterministic": False, + }, + }, + "this_main": { + "skip_vis": False, + "save_fig_dir": "./debug/dataset/NvsData", + "epochs": 3, + "skip_batches": 5, + "deterministic_crop": False, + "return_debug_info": True, + "return_item_paths": True, + "resize_short_side": -1, # -1 to disable + "crop_mode": "dataset_default", + # "crop_mode": None, + }, + "model": { + "patch_size": 14, + "predict": { + "metric": { + "type": "ssim", + "min": 0, + "max": 1, + }, + }, + }, + } + cfg = OmegaConf.create(cfg) + + # Overwrite cfg in some conditions + if cfg.data.dataset.resolution == "res_540": + cfg.data.transforms.crop_size = 518 + elif cfg.data.dataset.resolution == "res_400": + cfg.data.transforms.crop_size = 392 + elif cfg.data.dataset.resolution == "res_200": + cfg.data.transforms.crop_size = 196 + else: + raise ValueError("Unknown resolution") + + if cfg.data.loader.num_workers == 0: + cfg.data.loader.persistent_workers = False + + # Check config + ConfigChecker(cfg).check_dataset() + + # Init transforms for dataset + img_norm_stat = ImageNetMeanStd() + transforms = { + "img": T.Normalize(mean=img_norm_stat.mean, std=img_norm_stat.std), + } + + if cfg.this_main.crop_mode == "dataset_default": + transforms["query_crop"] = CropperFactory( + output_size=(cfg.data.transforms.crop_size, cfg.data.transforms.crop_size), + same_on_batch=True, + deterministic=cfg.this_main.deterministic_crop, + ) + transforms["reference_crop"] = 
CropperFactory( + output_size=(cfg.data.transforms.crop_size, cfg.data.transforms.crop_size), + same_on_batch=False, + deterministic=cfg.this_main.deterministic_crop, + ) + + if cfg.this_main.resize_short_side > 0: + transforms["resize"] = T.Resize( + cfg.this_main.resize_short_side, + interpolation=T.InterpolationMode.BILINEAR, + antialias=True, + ) + + # Init dataset and dataloader + dataset = NvsDataset( + dataset_path=cfg.data.dataset.path, + resolution=cfg.data.dataset.resolution, + data_split=cfg.data.loader.data_split, + transforms=transforms, + neighbour_config=cfg.data.neighbour_config, + metric_type=cfg.model.predict.metric.type, + metric_min=cfg.model.predict.metric.min, + metric_max=cfg.model.predict.metric.max, + return_debug_info=cfg.this_main.return_debug_info, + return_item_paths=cfg.this_main.return_item_paths, + num_gaussians_iters=cfg.data.dataset.num_gaussians_iters, + zero_reference=cfg.data.dataset.zero_reference, + ) + + dataloader = torch.utils.data.DataLoader( + dataset, + batch_size=cfg.data.loader.batch_size, + shuffle=cfg.data.loader.shuffle, + num_workers=cfg.data.loader.num_workers, + pin_memory=cfg.data.loader.pin_memory, + persistent_workers=cfg.data.loader.persistent_workers, + collate_fn=dataset.collate_fn_debug if cfg.this_main.return_debug_info else None, + ) + + # Actual looping dataset + for e in tqdm(range(cfg.this_main.epochs), desc="Epoch", dynamic_ncols=True): + for b, batch in enumerate(tqdm(dataloader, desc="Batch", dynamic_ncols=True)): + + if cfg.this_main.skip_batches > 0 and b >= cfg.this_main.skip_batches: + break + + if cfg.this_main.skip_vis: + continue + + vis_batch( + cfg, + batch, + metric_type=cfg.model.predict.metric.type, + metric_min=cfg.model.predict.metric.min, + metric_max=cfg.model.predict.metric.max, + e=e, + b=b, + ) diff --git a/crossscore/dataloading/dataset/simple_reference.py b/crossscore/dataloading/dataset/simple_reference.py new file mode 100644 index 0000000..cf72e92 --- /dev/null +++ 
b/crossscore/dataloading/dataset/simple_reference.py @@ -0,0 +1,221 @@ +import os, sys +from pathlib import Path +import torch +from omegaconf import OmegaConf + +from crossscore.dataloading.dataset.nvs_dataset import NvsDataset, NeighbourSelector, vis_batch + + +class SimpleReference(NvsDataset): + def __init__( + self, + query_dir, + reference_dir, + transforms, + neighbour_config, + return_debug_info=False, + return_item_paths=False, + **kwargs, + ): + self.transforms = transforms + self.neighbour_config = neighbour_config + self.return_debug_info = return_debug_info + self.return_item_paths = return_item_paths + self.zero_reference = kwargs.get("zero_reference", False) + + self._detect_conflict_transforms() + self.metric_config = self._build_empty_metric_config() + + self.all_paths = self.get_paths(query_dir, reference_dir) + self.neighbour_selector = NeighbourSelector(self.all_paths, self.neighbour_config) + + def _build_empty_metric_config(self): + cfg = { + "type": None, + "vrange": None, + "load_dir": None, + } + cfg = OmegaConf.create(cfg) + return cfg + + @staticmethod + def get_paths(query_dir, reference_dir): + """Define query and reference paths for ONE scene. + This function is written in a way that mimics the + NvsDataset.get_paths(), so that we can reuse most NvsDataset methods. 
+ + :param scene_name: str + :param query_dir: str, a dir that contains query images + :param reference_dir: str, a dir that contains reference images + """ + + query_dir = os.path.expanduser(query_dir) + reference_dir = os.path.expanduser(reference_dir) + query_paths = [os.path.join(query_dir, p) for p in sorted(os.listdir(query_dir))] + reference_paths = [ + os.path.join(reference_dir, p) for p in sorted(os.listdir(reference_dir)) + ] + + fake_iter = -1 + query = { + "images": {fake_iter: query_paths}, + "score_map": {fake_iter: ["empty_image"] * len(query_paths)}, + "N_iters": 1, + "N_imgs_per_iter": len(query_paths), + } + reference = { + "cross": { + "images": {fake_iter: reference_paths}, + "N_iters": 1, + "N_imgs_per_iter": len(reference_paths), + } + } + + # use query dir as scene name and anonymize the path + scene_name = str(query_dir).replace(str(Path.home()), "~") + results = { + scene_name: { + "gs_test": { + "query": query, + "reference": reference, + }, + } + } + return results + + +if __name__ == "__main__": + from lightning import seed_everything + from torchvision.transforms import v2 as T + from tqdm import tqdm + from crossscore.dataloading.transformation.crop import CropperFactory + from crossscore.utils.io.images import ImageNetMeanStd + from omegaconf import OmegaConf + + seed_everything(1) + cfg = { + "data": { + "dataset": { + "query_dir": "datadir/processed_training_ready/gaussian/map-free-reloc/res_540/s00000/test/ours_1000/renders", + "reference_dir": "datadir/processed_training_ready/gaussian/map-free-reloc/res_540/s00000/train/ours_1000/gt", + "resolution": "res_540", + "zero_reference": False, + }, + "loader": { + "batch_size": 8, + "num_workers": 4, + "shuffle": True, + "pin_memory": True, + "persistent_workers": False, + }, + "transforms": { + "crop_size": 392, + }, + "neighbour_config": { + "strategy": "random", + "cross": 5, + "deterministic": False, + }, + }, + "this_main": { + "skip_vis": False, + "save_fig_dir": 
"./debug/dataset/simple_reference", + "epochs": 1, + "skip_batches": -1, + "deterministic_crop": False, + "return_debug_info": True, + "return_item_paths": True, + "resize_short_side": 518, # -1 to disable + # "crop_mode": "dataset_default", + "crop_mode": None, + }, + "model": { + "patch_size": 14, + "do_reference_cross": True, + "predict": { + "metric": { + "type": "ssim", + "min": 0, + "max": 1, + }, + }, + }, + } + cfg = OmegaConf.create(cfg) + + # overwrite cfg in some conditions + if cfg.data.dataset.resolution == "res_540": + cfg.data.transforms.crop_size = 518 + elif cfg.data.dataset.resolution == "res_400": + cfg.data.transforms.crop_size = 392 + elif cfg.data.dataset.resolution == "res_200": + cfg.data.transforms.crop_size = 196 + else: + raise ValueError("Unknown resolution") + + if cfg.data.loader.num_workers == 0: + cfg.data.loader.persistent_workers = False + + # for dataloader + img_norm_stat = ImageNetMeanStd() + transforms = { + "img": T.Normalize(mean=img_norm_stat.mean, std=img_norm_stat.std), + } + + if cfg.this_main.crop_mode == "dataset_default": + transforms["query_crop"] = CropperFactory( + output_size=(cfg.data.transforms.crop_size, cfg.data.transforms.crop_size), + same_on_batch=True, + deterministic=cfg.this_main.deterministic_crop, + ) + transforms["reference_crop"] = CropperFactory( + output_size=(cfg.data.transforms.crop_size, cfg.data.transforms.crop_size), + same_on_batch=False, + deterministic=cfg.this_main.deterministic_crop, + ) + + if cfg.this_main.resize_short_side > 0: + transforms["resize"] = T.Resize( + cfg.this_main.resize_short_side, + interpolation=T.InterpolationMode.BILINEAR, + antialias=True, + ) + + dataset = SimpleReference( + query_dir=cfg.data.dataset.query_dir, + reference_dir=cfg.data.dataset.reference_dir, + transforms=transforms, + neighbour_config=cfg.data.neighbour_config, + return_debug_info=cfg.this_main.return_debug_info, + return_item_paths=cfg.this_main.return_item_paths, + 
zero_reference=cfg.data.dataset.zero_reference, + ) + + dataloader = torch.utils.data.DataLoader( + dataset, + batch_size=cfg.data.loader.batch_size, + shuffle=cfg.data.loader.shuffle, + num_workers=cfg.data.loader.num_workers, + pin_memory=cfg.data.loader.pin_memory, + persistent_workers=cfg.data.loader.persistent_workers, + collate_fn=dataset.collate_fn_debug if cfg.this_main.return_debug_info else None, + ) + + # actual looping dataset + for e in tqdm(range(cfg.this_main.epochs), desc="Epoch", dynamic_ncols=True): + for b, batch in enumerate(tqdm(dataloader, desc="Batch", dynamic_ncols=True)): + if cfg.this_main.skip_batches > 0 and b >= cfg.this_main.skip_batches: + break + + if cfg.this_main.skip_vis: + continue + + vis_batch( + cfg, + batch, + metric_type=cfg.model.predict.metric.type, + metric_min=cfg.model.predict.metric.min, + metric_max=cfg.model.predict.metric.max, + e=e, + b=b, + ) diff --git a/crossscore/dataloading/transformation/__init__.py b/crossscore/dataloading/transformation/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/crossscore/dataloading/transformation/crop.py b/crossscore/dataloading/transformation/crop.py new file mode 100644 index 0000000..b613bb7 --- /dev/null +++ b/crossscore/dataloading/transformation/crop.py @@ -0,0 +1,86 @@ +from abc import ABC, abstractmethod +import numpy as np +import torch +from torchvision.transforms import v2 as T + + +def get_crop_params(input_size, output_size, deterministic): + """Get random crop parameters for a given image and output size. + Args: + img: numpy array hwc + output_size (tuple): Expected output size of the crop. + Returns: + tuple: params (i, j, h, w) to be passed to ``crop`` for random crop. 
+ """ + in_h, in_w = input_size + out_h, out_w = output_size + + # i, j, h, w + if deterministic: + i, j = 0, 0 + else: + i = np.random.randint(0, in_h - out_h + 1) + j = np.random.randint(0, in_w - out_w + 1) + return torch.tensor([i, j, out_h, out_w]) + + +class Cropper(ABC): + def __init__(self, output_size, deterministic=False): + self.output_size = output_size + self.deterministic = deterministic + + @abstractmethod + def __call__(self, *args): + raise NotImplementedError + + +class RandomCropperBatchSeparate(Cropper): + """For an input tensor, assuming it's batched, and apply **DIFF** crop params + to each item in the batch. + """ + + def __call__(self, imgs): + # x: (B, C, H, W), (B, H, W) + if imgs.ndim not in [3, 4]: + raise ValueError("imgs.ndim must be one of [3, 4]") + + out_list = [] + crop_param_list = [] + for img in imgs: + crop_param = get_crop_params(img.shape[-2:], self.output_size, self.deterministic) + img = T.functional.crop(img, *crop_param) + out_list.append(img) + crop_param_list.append(crop_param) + out_list = torch.stack(out_list) + crop_param_list = torch.stack(crop_param_list) + return { + "out": out_list, # (B, C, H, W) or (B, H, W) + "crop_param": crop_param_list, # (B, 4) + } + + +class RandomCropperBatchSame(Cropper): + """For a list of input tensors, assuming they're batched, and apply **SAME** + crop params to all. 
+ """ + + def __call__(self, *args): + # use one set of crop params for all input + crop_param = get_crop_params(args[0].shape[-2:], self.output_size, self.deterministic) + out = [T.functional.crop(x, *crop_param) for x in args] + return { + "out": out, + "crop_param": crop_param, + } + + +class CropperFactory: + def __init__(self, output_size, same_on_batch, deterministic=False): + self.output_size = output_size + if same_on_batch: + self.cropper = RandomCropperBatchSame(output_size, deterministic) + else: + self.cropper = RandomCropperBatchSeparate(output_size, deterministic) + + def __call__(self, *args): + return self.cropper(*args) diff --git a/crossscore/model/__init__.py b/crossscore/model/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/crossscore/model/cross_reference.py b/crossscore/model/cross_reference.py new file mode 100644 index 0000000..32879bd --- /dev/null +++ b/crossscore/model/cross_reference.py @@ -0,0 +1,94 @@ +import torch +from .customised_transformer.transformer import ( + TransformerDecoderLayerCustomised, + TransformerDecoderCustomised, +) +from .regression_layer import RegressionLayer +from crossscore.utils.misc.image import jigsaw_to_image + + +class CrossReferenceNet(torch.nn.Module): + def __init__(self, cfg, dinov2_cfg): + super().__init__() + self.cfg = cfg + self.dinov2_cfg = dinov2_cfg + + # set up input projection + self.input_proj = torch.nn.Identity() + + # set up output final activation function + self.final_activation_fn = RegressionLayer( + metric_type=self.cfg.model.predict.metric.type, + metric_min=self.cfg.model.predict.metric.min, + metric_max=self.cfg.model.predict.metric.max, + pow_factor=self.cfg.model.predict.metric.power_factor, + ) + + # layers + self.attn = TransformerDecoderCustomised( + decoder_layer=TransformerDecoderLayerCustomised( + d_model=self.dinov2_cfg.hidden_size, + nhead=8, + dim_feedforward=self.dinov2_cfg.hidden_size, + dropout=0.0, + batch_first=True, + 
do_self_attn=self.cfg.model.decoder_do_self_attn, + do_short_cut=self.cfg.model.decoder_do_short_cut, + ), + num_layers=2, + ) + + # set up head out dimension + out_size = cfg.model.patch_size**2 + + # head + self.head = torch.nn.Sequential( + torch.nn.Linear(self.dinov2_cfg.hidden_size, self.dinov2_cfg.hidden_size), + torch.nn.LeakyReLU(), + torch.nn.Linear(self.dinov2_cfg.hidden_size, out_size), + self.final_activation_fn, + ) + + def forward( + self, + featmap_query, + featmap_ref, + memory_mask, + dim_params, + need_attn_weights, + need_attn_weights_head_id, + ): + """ + :param featmap_query: (B, num_patches, hidden_size) + :param featmap_ref: (B, N_ref * num_patches, hidden_size) + :param memory_mask: None + :param dim_params: dict + """ + B = dim_params["B"] + N_patch_h = dim_params["N_patch_h"] + N_patch_w = dim_params["N_patch_w"] + N_ref = dim_params["N_ref"] + + results = {} + score_map, _, mha_weights = self.attn( + tgt=featmap_query, + memory=featmap_ref, + memory_mask=memory_mask, + need_weights=need_attn_weights, + need_weights_head_id=need_attn_weights_head_id, + ) # (B, num_patches_tgt, hidden_size), (B, num_patches_tgt, num_patches_mem) + + # return score map + score_map = self.head(score_map) # (B, num_patches, num_ssim_pixels) + + # reshape to image size + P = self.cfg.model.patch_size + score_map = score_map.view(B, -1, P, P) + score_map = jigsaw_to_image(score_map, grid_size=(N_patch_h, N_patch_w)) # (B, H, W) + results["score_map"] = score_map + + # return attn weights + if need_attn_weights: + mha_weights = mha_weights.view(B, N_patch_h, N_patch_w, N_ref, N_patch_h, N_patch_w) + results["attn_weights_map_mha"] = mha_weights + return results diff --git a/crossscore/model/customised_transformer/__init__.py b/crossscore/model/customised_transformer/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/crossscore/model/customised_transformer/transformer.py b/crossscore/model/customised_transformer/transformer.py new file mode 
100644 index 0000000..49f650e --- /dev/null +++ b/crossscore/model/customised_transformer/transformer.py @@ -0,0 +1,268 @@ +from typing import Optional, Callable, Union + +import torch +import torch.nn.functional as F +from torch import Tensor +from torch.nn.modules.transformer import ( + TransformerDecoder, + _get_seq_len, + _detect_is_causal_mask, + _get_activation_fn, +) +from torch.nn.modules.activation import MultiheadAttention +from torch.nn.modules.dropout import Dropout +from torch.nn.modules.linear import Linear +from torch.nn.modules.normalization import LayerNorm + + +# fmt: off +# Zirui: This class is copied and modified from TransformerDecoderLayer in pytorch 2.1.2 +class TransformerDecoderLayerCustomised(torch.nn.Module): + r"""TransformerDecoderLayer is made up of self-attn, multi-head-attn and feedforward network. + This standard decoder layer is based on the paper "Attention Is All You Need". + Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, + Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in + Neural Information Processing Systems, pages 6000-6010. Users may modify or implement + in a different way during application. + + Args: + d_model: the number of expected features in the input (required). + nhead: the number of heads in the multiheadattention models (required). + dim_feedforward: the dimension of the feedforward network model (default=2048). + dropout: the dropout value (default=0.1). + activation: the activation function of the intermediate layer, can be a string + ("relu" or "gelu") or a unary callable. Default: relu + layer_norm_eps: the eps value in layer normalization components (default=1e-5). + batch_first: If ``True``, then the input and output tensors are provided + as (batch, seq, feature). Default: ``False`` (seq, batch, feature). 
+ norm_first: if ``True``, layer norm is done prior to self attention, multihead + attention and feedforward operations, respectively. Otherwise it's done after. + Default: ``False`` (after). + bias: If set to ``False``, ``Linear`` and ``LayerNorm`` layers will not learn an additive + bias. Default: ``True``. + + Examples:: + >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8) + >>> memory = torch.rand(10, 32, 512) + >>> tgt = torch.rand(20, 32, 512) + >>> out = decoder_layer(tgt, memory) + + Alternatively, when ``batch_first`` is ``True``: + >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8, batch_first=True) + >>> memory = torch.rand(32, 10, 512) + >>> tgt = torch.rand(32, 20, 512) + >>> out = decoder_layer(tgt, memory) + """ + __constants__ = ['norm_first'] + + def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1, + activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, + layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False, + bias: bool = True, device=None, dtype=None, do_self_attn=True, do_short_cut=True) -> None: + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__() + self.do_self_attn = do_self_attn # Zirui: added + self.do_short_cut = do_short_cut # Zirui: added + + if self.do_self_attn: + self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first, + bias=bias, **factory_kwargs) + self.multihead_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first, + bias=bias, **factory_kwargs) + # Implementation of Feedforward model + self.linear1 = Linear(d_model, dim_feedforward, bias=bias, **factory_kwargs) + self.dropout = Dropout(dropout) + self.linear2 = Linear(dim_feedforward, d_model, bias=bias, **factory_kwargs) + + self.norm_first = norm_first + self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs) + self.norm2 = LayerNorm(d_model, 
eps=layer_norm_eps, bias=bias, **factory_kwargs) + self.norm3 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs) + self.dropout1 = Dropout(dropout) + self.dropout2 = Dropout(dropout) + self.dropout3 = Dropout(dropout) + + # Legacy string support for activation function. + if isinstance(activation, str): + self.activation = _get_activation_fn(activation) + else: + self.activation = activation + + def __setstate__(self, state): + if 'activation' not in state: + state['activation'] = F.relu + super().__setstate__(state) + + def forward( + self, + tgt: Tensor, + memory: Tensor, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + tgt_is_causal: bool = False, + memory_is_causal: bool = False, + need_weights=True, + need_weights_head_id=0, + ): + r"""Pass the inputs (and mask) through the decoder layer. + + Args: + tgt: the sequence to the decoder layer (required). + memory: the sequence from the last layer of the encoder (required). + tgt_mask: the mask for the tgt sequence (optional). + memory_mask: the mask for the memory sequence (optional). + tgt_key_padding_mask: the mask for the tgt keys per batch (optional). + memory_key_padding_mask: the mask for the memory keys per batch (optional). + tgt_is_causal: If specified, applies a causal mask as ``tgt mask``. + Default: ``False``. + Warning: + ``tgt_is_causal`` provides a hint that ``tgt_mask`` is + the causal mask. Providing incorrect hints can result in + incorrect execution, including forward and backward + compatibility. + memory_is_causal: If specified, applies a causal mask as + ``memory mask``. + Default: ``False``. + Warning: + ``memory_is_causal`` provides a hint that + ``memory_mask`` is the causal mask. Providing incorrect + hints can result in incorrect execution, including + forward and backward compatibility. + + Shape: + see the docs in Transformer class. 
+ """ + # see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf + + x = tgt + if self.norm_first: + if self.do_self_attn: + sa_out, sa_weights = self._sa_block(self.norm1(x), tgt_mask, tgt_key_padding_mask, tgt_is_causal, need_weights) + if self.do_short_cut: + x = x + sa_out + else: + x = sa_out + else: + sa_weights = None + + mha_out, mha_weights = self._mha_block(self.norm2(x), memory, memory_mask, memory_key_padding_mask, memory_is_causal, need_weights) + if self.do_short_cut: + x = x + mha_out + else: + x = mha_out + + x = x + self._ff_block(self.norm3(x)) + else: + if self.do_self_attn: + sa_out, sa_weights = self._sa_block(x, tgt_mask, tgt_key_padding_mask, tgt_is_causal, need_weights) + if self.do_short_cut: + x = self.norm1(x + sa_out) + else: + x = self.norm1(sa_out) + else: + sa_weights = None + + mha_out, mha_weights = self._mha_block(x, memory, memory_mask, memory_key_padding_mask, memory_is_causal, need_weights) + if self.do_short_cut: + x = self.norm2(x + mha_out) + else: + x = self.norm2(mha_out) + + x = self.norm3(x + self._ff_block(x)) + + if sa_weights is not None: + sa_weights = sa_weights[:, need_weights_head_id] # return attn weights of a specific head + if mha_weights is not None: + mha_weights = mha_weights[:, need_weights_head_id] # return attn weights of a specific head + return x, sa_weights, mha_weights + + # self-attention block + def _sa_block(self, x: Tensor, + attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor], is_causal: bool = False, need_weights: bool = True): + x, attn_weights = self.self_attn(x, x, x, + attn_mask=attn_mask, + key_padding_mask=key_padding_mask, + is_causal=is_causal, + need_weights=need_weights, + average_attn_weights=False,) + if need_weights: + attn_weights = attn_weights.detach() # attn weights of all heads + return self.dropout1(x), attn_weights + + # multihead attention block + def _mha_block(self, x: Tensor, mem: Tensor, + attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor], 
is_causal: bool = False, need_weights: bool = True): + x, attn_weights = self.multihead_attn(x, mem, mem, + attn_mask=attn_mask, + key_padding_mask=key_padding_mask, + is_causal=is_causal, + need_weights=need_weights, + average_attn_weights=False,) + if need_weights: + attn_weights = attn_weights.detach() # attn weights of all heads + return self.dropout2(x), attn_weights + + # feed forward block + def _ff_block(self, x: Tensor) -> Tensor: + x = self.linear2(self.dropout(self.activation(self.linear1(x)))) + return self.dropout3(x) + + +class TransformerDecoderCustomised(TransformerDecoder): + def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, tgt_is_causal: Optional[bool] = None, + memory_is_causal: bool = False, need_weights: bool = True, need_weights_head_id: int = 0): + r"""Pass the inputs (and mask) through the decoder layer in turn. + + Args: + tgt: the sequence to the decoder (required). + memory: the sequence from the last layer of the encoder (required). + tgt_mask: the mask for the tgt sequence (optional). + memory_mask: the mask for the memory sequence (optional). + tgt_key_padding_mask: the mask for the tgt keys per batch (optional). + memory_key_padding_mask: the mask for the memory keys per batch (optional). + tgt_is_causal: If specified, applies a causal mask as ``tgt mask``. + Default: ``None``; try to detect a causal mask. + Warning: + ``tgt_is_causal`` provides a hint that ``tgt_mask`` is + the causal mask. Providing incorrect hints can result in + incorrect execution, including forward and backward + compatibility. + memory_is_causal: If specified, applies a causal mask as + ``memory mask``. + Default: ``False``. + Warning: + ``memory_is_causal`` provides a hint that + ``memory_mask`` is the causal mask. 
Providing incorrect + hints can result in incorrect execution, including + forward and backward compatibility. + + Shape: + see the docs in Transformer class. + """ + output = tgt + + seq_len = _get_seq_len(tgt, self.layers[0].multihead_attn.batch_first) + tgt_is_causal = _detect_is_causal_mask(tgt_mask, tgt_is_causal, seq_len) + + for mod in self.layers: + output, sa_weights, mha_weights = mod( + output, memory, tgt_mask=tgt_mask, + memory_mask=memory_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask, + tgt_is_causal=tgt_is_causal, + memory_is_causal=memory_is_causal, + need_weights=need_weights, + need_weights_head_id=need_weights_head_id, + ) + + if self.norm is not None: + output = self.norm(output) + + # Only return the last layer's attention weights. + # If need_weights is False, they are None. + return output, sa_weights, mha_weights diff --git a/crossscore/model/positional_encoding.py b/crossscore/model/positional_encoding.py new file mode 100644 index 0000000..e0e51dd --- /dev/null +++ b/crossscore/model/positional_encoding.py @@ -0,0 +1,75 @@ +import torch + + +class MultiViewPosionalEmbeddings(torch.nn.Module): + + def __init__( + self, + positional_encoding_h, + positional_encoding_w, + interpolate_mode, + req_grad, + patch_size=14, + hidden_size=384, + ): + """Apply positional encoding to input multi-view embeddings. + Conceptually, posional encoding are in (pe_h, pe_w, C) shape. + We can interpolate in 2D to adapt to different input image size, but + no interpolation in the view dimension. 
+ + Shorthand: + P: patch size + C: hidden size + N: number of + pe: positional encoding + mv: multi-view + emb: embedding + + :param patch_size: DINOv2 patch size def 14 + :param hidden_size: DINOv2 hidden size def 384 (dinov2_small) + """ + super().__init__() + self.P = patch_size + self.C = hidden_size + self.pe_h = positional_encoding_h + self.pe_w = positional_encoding_w + self.interpolate_mode = interpolate_mode + self.PE = torch.nn.Parameter( + torch.randn(1, self.pe_h, self.pe_w, self.C), + req_grad, + ) + + def forward(self, mv_emb, N_view, img_h, img_w): + """ + :param mv_emb: (B, N_patch, C), N_patch: N_view * emb_h * emb_w + """ + B = mv_emb.shape[0] + emb_h = img_h // self.P + emb_w = img_w // self.P + + # a short cut no need to interpolate + if emb_h == self.pe_h and emb_w == self.pe_w: + # no need to interpolate + mv_emb = mv_emb.view(B, N_view, emb_h, emb_w, self.C) + mv_emb = mv_emb + self.PE + mv_emb = mv_emb.view(B, N_view * emb_h * emb_w, self.C) + return mv_emb + + # 1. interpolate position embedding + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + _PE = torch.nn.functional.interpolate( + self.PE.permute(0, 3, 1, 2), + scale_factor=( + (emb_h + 1e-4) / self.pe_h, + (emb_w + 1e-4) / self.pe_w, + ), + mode=self.interpolate_mode, + align_corners=True, + ) # (1, C, emb_h, emb_w) + + # 2. 
embed and reshape back + mv_emb = mv_emb.view(B, N_view, emb_h, emb_w, self.C) + mv_emb = mv_emb + _PE.permute(0, 2, 3, 1)[None] + mv_emb = mv_emb.reshape(B, N_view * emb_h * emb_w, self.C) + return mv_emb diff --git a/crossscore/model/regression_layer.py b/crossscore/model/regression_layer.py new file mode 100644 index 0000000..f943afd --- /dev/null +++ b/crossscore/model/regression_layer.py @@ -0,0 +1,80 @@ +from functools import partial +import sys +from pathlib import Path +import torch + +from crossscore.utils.check_config import check_metric_prediction_config + + +class RegressionLayer(torch.nn.Module): + def __init__(self, metric_type, metric_min, metric_max, pow_factor="default"): + """ + Make a regression layer based on the metric configuration. + Use power_factor to help predict very small numbers. + """ + super().__init__() + + check_metric_prediction_config(metric_type, metric_min, metric_max) + self.metric_type = metric_type + self.metric_min = metric_min + self.metric_max = metric_max + + self.activation_fn = self._get_activation_fn() + self.pow_fn = self._get_pow_fn(pow_factor) + + def forward(self, x): + x = self.activation_fn(x) + x = self.pow_fn(x) + return x + + def _get_activation_fn(self): + if self.metric_min == -1: + activation_fn = torch.nn.Tanh() + elif self.metric_min == 0: + activation_fn = torch.nn.Sigmoid() + else: + raise ValueError(f"metric_min={self.metric_min} not supported") + return activation_fn + + def _get_pow_fn(self, p): + # define a lookup table for default power factor + pow_default_table = { + "ssim": 1, + "mae": 2, + "mse": 4, + } + + # only apply power fn for a non-negative score value range + if self.metric_min == 0: + if p == "default": + # use default power factor from the look up table + p = pow_default_table[self.metric_type] + else: + pass # use the provided power factor + else: + p = 1 + + if float(p) == 1.0: + pow_fn = torch.nn.Identity() + else: + pow_fn = partial(torch.pow, exponent=p) + return pow_fn + + +if 
class CrossScoreNet(torch.nn.Module):
    """Predict a per-pixel quality score map for a query image by
    cross-attending to a set of reference images on top of a frozen
    DINOv2 backbone.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        # ImageNet mean/std kept as a buffer so it follows the module across
        # devices; used for (de)normalising images.
        norm_stat = ImageNetMeanStd()
        self.register_buffer(
            "img_mean_std", torch.tensor([*norm_stat.mean, *norm_stat.std])
        )

        # frozen DINOv2 feature extractor
        self.dinov2_cfg = Dinov2Config.from_pretrained(self.cfg.model.backbone.from_pretrained)
        self.backbone = Dinov2Model.from_pretrained(self.cfg.model.backbone.from_pretrained)
        for param in self.backbone.parameters():
            param.requires_grad = False

        # positional encoding applied on top of the backbone feature maps
        self.pos_enc_fn = MultiViewPosionalEmbeddings(
            positional_encoding_h=self.cfg.model.pos_enc.multi_view.h,
            positional_encoding_w=self.cfg.model.pos_enc.multi_view.w,
            interpolate_mode=self.cfg.model.pos_enc.multi_view.interpolate_mode,
            req_grad=self.cfg.model.pos_enc.multi_view.req_grad,
            patch_size=self.cfg.model.patch_size,
            hidden_size=self.dinov2_cfg.hidden_size,
        )

        # cross-reference score predictor
        if self.cfg.model.do_reference_cross:
            self.ref_cross = CrossReferenceNet(cfg=self.cfg, dinov2_cfg=self.dinov2_cfg)

    def forward(
        self,
        query_img,
        ref_cross_imgs,
        need_attn_weights=False,
        need_attn_weights_head_id=0,
        norm_img=False,
    ):
        """Predict score maps for a batch of query images.

        :param query_img: (B, 3, H, W)
        :param ref_cross_imgs: (B, N_ref_cross, 3, H, W) or None
        :param need_attn_weights: also return attention weight maps
        :param need_attn_weights_head_id: which attention head to extract
        :param norm_img: normalise [0, 1] images with ImageNet mean/std first
        :return: dict with "score_map_ref_cross" and "attn_weights_map_ref_cross"
            (empty when cross-reference prediction is disabled)
        """
        batch = query_img.shape[0]
        img_h, img_w = query_img.shape[-2:]
        n_patch_h = img_h // self.cfg.model.patch_size
        n_patch_w = img_w // self.cfg.model.patch_size

        if norm_img:
            mean = self.img_mean_std[None, :3, None, None]
            std = self.img_mean_std[None, 3:, None, None]
            query_img = (query_img - mean) / std
            if ref_cross_imgs is not None:
                ref_cross_imgs = (ref_cross_imgs - mean[:, None]) / std[:, None]

        featmaps = self.get_featmaps(query_img, ref_cross_imgs)
        results = {}

        # query tokens: add positional encoding for a single view
        featmaps["query"] = self.pos_enc_fn(featmaps["query"], N_view=1, img_h=img_h, img_w=img_w)

        if self.cfg.model.do_reference_cross:
            n_ref = ref_cross_imgs.shape[1]

            # (B, N_ref_cross * num_patches, hidden_size)
            featmaps["ref_cross"] = self.pos_enc_fn(
                featmaps["ref_cross"],
                N_view=n_ref,
                img_h=img_h,
                img_w=img_w,
            )

            dim_params = {
                "B": batch,
                "N_patch_h": n_patch_h,
                "N_patch_w": n_patch_w,
                "N_ref": n_ref,
            }
            results_ref_cross = self.ref_cross(
                featmaps["query"],
                featmaps["ref_cross"],
                None,
                dim_params,
                need_attn_weights,
                need_attn_weights_head_id,
            )
            results["score_map_ref_cross"] = results_ref_cross["score_map"]
            results["attn_weights_map_ref_cross"] = results_ref_cross["attn_weights_map_mha"]
        return results

    @torch.no_grad()
    def get_featmaps(self, query_img, ref_cross_imgs):
        """Run the frozen backbone over query and reference images in one pass.

        :param query_img: (B, 3, H, W)
        :param ref_cross_imgs: (B, N_ref_cross, 3, H, W) or None
        :return: dict with "query" (B, num_patches, hidden_size) and
            "ref_cross" (B, N_ref_cross * num_patches, hidden_size) or None
        """
        batch = query_img.shape[0]
        img_h, img_w = query_img.shape[-2:]
        n_patch_h = img_h // self.cfg.model.patch_size
        n_patch_w = img_w // self.cfg.model.patch_size
        n_ref = 0 if ref_cross_imgs is None else ref_cross_imgs.shape[1]
        n_total = 1 + n_ref  # query first, then references

        # stack query + references so the backbone runs only once
        stacked = [query_img.view(batch, 1, 3, img_h, img_w)]
        if ref_cross_imgs is not None:
            stacked.append(ref_cross_imgs)
        stacked = torch.cat(stacked, dim=1)
        stacked = stacked.view(batch * n_total, 3, img_h, img_w)

        backbone_out = self.backbone(stacked)
        # drop the CLS token, keep patch tokens only
        feats = backbone_out.last_hidden_state[:, 1:]
        feats = feats.view(batch, n_total, n_patch_h * n_patch_w, -1)

        # query tokens come first in the stacked batch
        feat_query = feats[:, 0]  # (B, num_patches, hidden_size)
        n_patches = feat_query.shape[1]
        hidden = feat_query.shape[2]

        if ref_cross_imgs is not None:
            feat_ref = feats[:, -n_ref:]
            feat_ref = feat_ref.reshape(batch, n_ref * n_patches, hidden)
        else:
            feat_ref = None

        return {
            "query": feat_query,  # (B, num_patches, hidden_size)
            "ref_cross": feat_ref,  # (B, N_ref_cross*num_patches, hidden_size)
        }
def check_metric_prediction_config(
    metric_type,
    metric_min,
    metric_max,
):
    """Validate a metric prediction configuration.

    Supported metrics are "ssim", "mse" and "mae"; the maximum must be 1,
    and the minimum must be 0 (all metrics) or -1 (ssim only).

    :raises ValueError: if the metric type or its value range is invalid.
    """
    if metric_type not in ("ssim", "mse", "mae"):
        raise ValueError(f"Invalid metric type {metric_type}")

    # ssim may span [-1, 1] or [0, 1]; mse/mae must start at 0
    allowed_minima = (-1, 0) if metric_type == "ssim" else (0,)
    if metric_min not in allowed_minima or metric_max != 1:
        raise ValueError(f"Invalid metric range {metric_min} to {metric_max} for {metric_type}")
class ConfigChecker:
    """Validate a config object before it is used for
    - train/val/test/predict steps of the lightning module;
    - dataloader creation.
    """

    def __init__(self, cfg):
        self.cfg = cfg

    def _metric_args(self):
        # (type, min, max) of the metric the model is trained to predict
        metric = self.cfg.model.predict.metric
        return metric.type, metric.min, metric.max

    def _check_common_lightning(self):
        # checks shared by every lightning step
        check_reference_type(self.cfg.model.do_reference_cross)
        check_metric_prediction_config(*self._metric_args())

    def check_train_val(self):
        self._check_common_lightning()

    def check_test(self):
        self._check_common_lightning()

    def check_predict(self):
        self._check_common_lightning()

    def check_dataset(self):
        # dataloaders only need the metric configuration to be valid
        check_metric_prediction_config(*self._metric_args())
def jigsaw_to_image(x, grid_size):
    """Reassemble a batch of flattened patches into full images.

    :param x: (B, N_patch_h * N_patch_w, patch_size_h, patch_size_w)
    :param grid_size: a tuple: (N_patch_h, N_patch_w)
    :return: (B, N_patch_h * patch_size_h, N_patch_w * patch_size_w)
    """
    n_rows, n_cols = grid_size
    bsz, n_patches, patch_h, patch_w = x.size()
    assert n_patches == n_rows * n_cols
    # (B, rows, cols, ph, pw) -> (B, rows, ph, cols, pw) so each row of
    # patches becomes a contiguous band of image rows, then flatten.
    tiles = x.view(bsz, n_rows, n_cols, patch_h, patch_w)
    tiles = tiles.permute(0, 1, 3, 2, 4).contiguous()
    return tiles.view(bsz, n_rows * patch_h, n_cols * patch_w)
def gray2rgb(img, vrange, cmap="turbo"):
    """Colourise a single-channel float map with a matplotlib colormap.

    Args:
        img: HW, numpy.float32
        vrange: (min, max), float — values are normalised to this range
        cmap: str, matplotlib colormap name

    Returns:
        HW3 uint8 RGB image (alpha channel dropped).
    """
    vmin, vmax = vrange
    norm_op = plt.Normalize(vmin=vmin, vmax=vmax)
    # plt.get_cmap instead of cm.get_cmap: matplotlib.cm.get_cmap was
    # deprecated in matplotlib 3.7 and removed in 3.9.
    colormap = plt.get_cmap(cmap)

    img = norm_op(img)
    img = colormap(img)
    rgb_image = u8(img[:, :, :3])  # keep RGB, drop alpha
    return rgb_image
class SampleBase(ABC):
    """Base class for reference-image samplers."""

    def __init__(self, N_sample):
        self.N_sample = N_sample

    @abstractmethod
    def sample(self):
        pass


class SamplerRandom(SampleBase):
    """Sample N_sample references, padding with "empty_image" placeholders
    when fewer references are available.

    When deterministic is True the output is fully reproducible: the first
    N_sample references are taken in order, and the padded result is NOT
    shuffled. (Previously the padding branch always applied a random
    permutation, so a "deterministic" sampler still returned random order.)
    """

    def __init__(self, N_sample, deterministic):
        self.deterministic = deterministic
        super().__init__(N_sample)

    def sample(self, query, ref_list):
        """
        :param query: the query item (unused by this strategy)
        :param ref_list: list of candidate reference identifiers
        :return: list of length N_sample
        """
        num_ref = len(ref_list)
        if self.N_sample > num_ref:
            # pad empty_image placeholders if ref list < N_sample
            num_empty = self.N_sample - num_ref
            placeholder = ["empty_image"] * num_empty
            result = ref_list + placeholder
            if not self.deterministic:
                result = np.random.permutation(result).tolist()
        else:
            if self.deterministic:
                result = ref_list[: self.N_sample]
            else:
                result = np.random.choice(ref_list, self.N_sample, replace=False).tolist()
        return result


class SamplerFactory:
    """Build a sampler by strategy name and expose it as a callable."""

    def __init__(
        self,
        strategy_name,
        N_sample,
        deterministic,
        **kwargs,
    ):
        self.N_sample = N_sample
        self.deterministic = deterministic

        if strategy_name == "random":
            self.sampler = SamplerRandom(
                N_sample=self.N_sample,
                deterministic=self.deterministic,
            )
        else:
            raise NotImplementedError

    def __call__(self, query, ref_list):
        return self.sampler.sample(query, ref_list)
mode 100644 index 0000000..8f5675f --- /dev/null +++ b/environment_gpu.yaml @@ -0,0 +1,21 @@ +name: CrossScore +channels: + - pytorch + - nvidia + - conda-forge + - defaults +dependencies: + - python>=3.9 + # PyTorch with CUDA + - pytorch>=2.0.0 + - torchvision>=0.15.0 + - pytorch-cuda=12.1 + # CrossScore dependencies + - transformers>=4.30.0 + - omegaconf>=2.3.0 + - pillow>=9.0.0 + - imageio>=2.20.0 + - numpy>=1.22.0 + - matplotlib>=3.5.0 + - tqdm>=4.60.0 + - huggingface_hub>=0.14.0 diff --git a/task/core.py b/task/core.py index 80de8be..99b8fb1 100644 --- a/task/core.py +++ b/task/core.py @@ -1,164 +1,20 @@ +"""Training-only Lightning module. Not part of the pip package. + +Imports the CrossScoreNet model from the crossscore package and wraps it +in a LightningModule for training/validation/testing. +""" + from pathlib import Path import torch import lightning -import wandb -from transformers import Dinov2Config, Dinov2Model from omegaconf import DictConfig, OmegaConf from lightning.pytorch.utilities import rank_zero_only -from utils.evaluation.metric import abs2psnr, correlation -from utils.evaluation.metric_logger import ( - MetricLoggerScalar, - MetricLoggerHistogram, - MetricLoggerCorrelation, - MetricLoggerImg, -) -from utils.plot.batch_visualiser import BatchVisualiserFactory -from utils.io.images import ImageNetMeanStd -from utils.io.batch_writer import BatchWriter -from utils.io.score_summariser import ( - SummaryWriterPredictedOnline, - SummaryWriterPredictedOnlineTestPrediction, -) -from model.cross_reference import CrossReferenceNet -from model.positional_encoding import MultiViewPosionalEmbeddings - - -class CrossScoreNet(torch.nn.Module): - def __init__(self, cfg): - super().__init__() - self.cfg = cfg - - # used in 1. denormalising images for visualisation - # and 2. 
normalising images for training when required - img_norm_stat = ImageNetMeanStd() - self.register_buffer( - "img_mean_std", torch.tensor([*img_norm_stat.mean, *img_norm_stat.std]) - ) - - # backbone, freeze - self.dinov2_cfg = Dinov2Config.from_pretrained(self.cfg.model.backbone.from_pretrained) - self.backbone = Dinov2Model.from_pretrained(self.cfg.model.backbone.from_pretrained) - for param in self.backbone.parameters(): - param.requires_grad = False - - # positional encoding layer - self.pos_enc_fn = MultiViewPosionalEmbeddings( - positional_encoding_h=self.cfg.model.pos_enc.multi_view.h, - positional_encoding_w=self.cfg.model.pos_enc.multi_view.w, - interpolate_mode=self.cfg.model.pos_enc.multi_view.interpolate_mode, - req_grad=self.cfg.model.pos_enc.multi_view.req_grad, - patch_size=self.cfg.model.patch_size, - hidden_size=self.dinov2_cfg.hidden_size, - ) - - # cross reference predictor - if self.cfg.model.do_reference_cross: - self.ref_cross = CrossReferenceNet(cfg=self.cfg, dinov2_cfg=self.dinov2_cfg) - - def forward( - self, - query_img, - ref_cross_imgs, - need_attn_weights, - need_attn_weights_head_id, - norm_img, - ): - """ - :param query_img: (B, 3, H, W) - :param ref_cross_imgs: (B, N_ref_cross, 3, H, W) - :param norm_img: bool, normalise an image with pixel value in [0, 1] with imagenet mean and std. 
- """ - B = query_img.shape[0] - H, W = query_img.shape[-2:] - N_patch_h = H // self.cfg.model.patch_size - N_patch_w = W // self.cfg.model.patch_size - - if norm_img: - img_mean = self.img_mean_std[None, :3, None, None] - img_std = self.img_mean_std[None, 3:, None, None] - query_img = (query_img - img_mean) / img_std - if ref_cross_imgs is not None: - ref_cross_imgs = (ref_cross_imgs - img_mean[:, None]) / img_std[:, None] - - featmaps = self.get_featmaps(query_img, ref_cross_imgs) - results = {} - - # processing (and predicting) for query - featmaps["query"] = self.pos_enc_fn(featmaps["query"], N_view=1, img_h=H, img_w=W) - - if self.cfg.model.do_reference_cross: - N_ref_cross = ref_cross_imgs.shape[1] - - # (B, N_ref_cross*num_patches, hidden_size) - featmaps["ref_cross"] = self.pos_enc_fn( - featmaps["ref_cross"], - N_view=N_ref_cross, - img_h=H, - img_w=W, - ) - - # prediction - dim_params = { - "B": B, - "N_patch_h": N_patch_h, - "N_patch_w": N_patch_w, - "N_ref": N_ref_cross, - } - results_ref_cross = self.ref_cross( - featmaps["query"], - featmaps["ref_cross"], - None, - dim_params, - need_attn_weights, - need_attn_weights_head_id, - ) - results["score_map_ref_cross"] = results_ref_cross["score_map"] - results["attn_weights_map_ref_cross"] = results_ref_cross["attn_weights_map_mha"] - return results - - @torch.no_grad() - def get_featmaps(self, query_img, ref_cross_imgs): - """ - :param query_img: (B, 3, H, W) - :param ref_cross: (B, N_ref_cross, 3, H, W) - """ - B = query_img.shape[0] - H, W = query_img.shape[-2:] - N_patch_h = H // self.cfg.model.patch_size - N_patch_w = W // self.cfg.model.patch_size - N_query = 1 - N_ref_cross = 0 if ref_cross_imgs is None else ref_cross_imgs.shape[1] - N_all_imgs = N_query + N_ref_cross - # concat all images to go through backbone for once - all_imgs = [query_img.view(B, 1, 3, H, W)] - if ref_cross_imgs is not None: - all_imgs.append(ref_cross_imgs) - all_imgs = torch.cat(all_imgs, dim=1) - all_imgs = all_imgs.view(B * 
N_all_imgs, 3, H, W) +# Model from the pip package +from crossscore.task.core import CrossScoreNet - # bbo: backbone output - bbo_all = self.backbone(all_imgs) - featmap_all = bbo_all.last_hidden_state[:, 1:] - featmap_all = featmap_all.view(B, N_all_imgs, N_patch_h * N_patch_w, -1) - - # query - featmap_query = featmap_all[:, 0] # (B, num_patches, hidden_size) - N_patches = featmap_query.shape[1] - hidden_size = featmap_query.shape[2] - - # cross ref - if ref_cross_imgs is not None: - featmap_ref_cross = featmap_all[:, -N_ref_cross:] - featmap_ref_cross = featmap_ref_cross.reshape(B, N_ref_cross * N_patches, hidden_size) - else: - featmap_ref_cross = None - - featmaps = { - "query": featmap_query, # (B, num_patches, hidden_size) - "ref_cross": featmap_ref_cross, # (B, N_ref_cross*num_patches, hidden_size) - } - return featmaps +# Training-only imports +from crossscore.utils.io.images import ImageNetMeanStd class CrossScoreLightningModule(lightning.LightningModule): @@ -166,19 +22,17 @@ def __init__(self, cfg: DictConfig): super().__init__() self.cfg = cfg - # write config to wandb self.save_hyperparameters(OmegaConf.to_container(self.cfg, resolve=True)) - # init my network + # init my network (from pip package) self.model = CrossScoreNet(cfg=self.cfg) - # init visualiser - self.visualiser = BatchVisualiserFactory(self.cfg, self.model.img_mean_std)() + # lazy imports for training-only deps + from crossscore.utils.check_config import check_reference_type # init loss fn if self.cfg.model.loss.fn == "l1": self.loss_fn = torch.nn.L1Loss() - self.to_psnr_fn = abs2psnr else: raise NotImplementedError @@ -187,100 +41,24 @@ def __init__(self, cfg: DictConfig): if self.cfg.model.do_reference_cross: self.ref_mode_names.append("ref_cross") - def on_fit_start(self): - # reset logging cache - if self.global_rank == 0: - self._reset_logging_cache_train() - self._reset_logging_cache_validation() - - self.frame_score_summariser = SummaryWriterPredictedOnline( - 
metric_type=self.cfg.model.predict.metric.type, - metric_min=self.cfg.model.predict.metric.min, - ) - - def on_test_start(self): - Path(self.cfg.logger.test.out_dir, "vis").mkdir(parents=True, exist_ok=True) - if self.cfg.logger.test.write.flag.batch: - self.batch_writer = BatchWriter(self.cfg, "test", self.model.img_mean_std) - else: - self.batch_writer = None - - self.frame_score_summariser = SummaryWriterPredictedOnlineTestPrediction( - metric_type=self.cfg.model.predict.metric.type, - metric_min=self.cfg.model.predict.metric.min, - dir_out=self.cfg.logger.test.out_dir, - ) - - def on_predict_start(self): - Path(self.cfg.logger.predict.out_dir, "vis").mkdir(parents=True, exist_ok=True) - if self.cfg.logger.predict.write.flag.batch: - self.batch_writer = BatchWriter(self.cfg, "predict", self.model.img_mean_std) - else: - self.batch_writer = None - - self.frame_score_summariser = SummaryWriterPredictedOnlineTestPrediction( - metric_type=self.cfg.model.predict.metric.type, - metric_min=self.cfg.model.predict.metric.min, - dir_out=self.cfg.logger.predict.out_dir, - ) - - def _reset_logging_cache_train(self): - self.train_cache = { - "loss": { - k: MetricLoggerScalar(max_length=self.cfg.logger.cache_size.train.n_scalar) - for k in ["final", "reg_self", "reg_cross"] + self.ref_mode_names - }, - "correlation": { - k: MetricLoggerCorrelation(max_length=self.cfg.logger.cache_size.train.n_scalar) - for k in self.ref_mode_names - }, - "map": { - "score": { - k: MetricLoggerHistogram(max_length=self.cfg.logger.cache_size.train.n_scalar) - for k in self.ref_mode_names - }, - "l1_diff": { - k: MetricLoggerHistogram(max_length=self.cfg.logger.cache_size.train.n_scalar) - for k in self.ref_mode_names - }, - "delta": { - k: MetricLoggerHistogram(max_length=self.cfg.logger.cache_size.train.n_scalar) - for k in ["self", "cross"] - }, - }, - } - - def _reset_logging_cache_validation(self): - self.validation_cache = { - "loss": { - k: MetricLoggerScalar(max_length=None) - for k in 
["final", "reg_self", "reg_cross"] + self.ref_mode_names - }, - "correlation": { - k: MetricLoggerCorrelation(max_length=None) for k in self.ref_mode_names - }, - "fig": {k: MetricLoggerImg(max_length=None) for k in ["batch"]}, - } - def _core_step(self, batch, batch_idx, skip_loss=False): outputs = self.model( - query_img=batch["query/img"], # (B, C, H, W) - ref_cross_imgs=batch.get("reference/cross/imgs", None), # (B, N_ref_cross, C, H, W) + query_img=batch["query/img"], + ref_cross_imgs=batch.get("reference/cross/imgs", None), need_attn_weights=self.cfg.model.need_attn_weights, need_attn_weights_head_id=self.cfg.model.need_attn_weights_head_id, norm_img=False, ) - if skip_loss: # only used in predict_step + if skip_loss: return outputs - score_map = batch["query/score_map"] # (B, H, W) - + score_map = batch["query/score_map"] loss = [] - # cross reference model predicts + if self.cfg.model.do_reference_cross: - score_map_cross = outputs["score_map_ref_cross"] # (B, H, W) - l1_diff_map_cross = torch.abs(score_map_cross - score_map) # (B, H, W) + score_map_cross = outputs["score_map_ref_cross"] + l1_diff_map_cross = torch.abs(score_map_cross - score_map) if self.cfg.model.loss.fn == "l1": loss_cross = l1_diff_map_cross.mean() else: @@ -294,203 +72,25 @@ def _core_step(self, batch, batch_idx, skip_loss=False): return outputs def training_step(self, batch, batch_idx): - outputs = self._core_step(batch, batch_idx) - return outputs + return self._core_step(batch, batch_idx) def validation_step(self, batch, batch_idx): - outputs = self._core_step(batch, batch_idx) - return outputs + return self._core_step(batch, batch_idx) def test_step(self, batch, batch_idx): - outputs = self._core_step(batch, batch_idx) - return outputs + return self._core_step(batch, batch_idx) def predict_step(self, batch, batch_idx): - outputs = self._core_step(batch, batch_idx, skip_loss=True) - return outputs + return self._core_step(batch, batch_idx, skip_loss=True) @rank_zero_only def 
on_train_batch_end(self, outputs, batch, batch_idx): - self.train_cache["loss"]["final"].update(outputs["loss"]) - - if self.cfg.model.do_reference_cross: - self.train_cache["loss"]["ref_cross"].update(outputs["loss_cross"]) - self.train_cache["correlation"]["ref_cross"].update( - outputs["score_map_ref_cross"], batch["query/score_map"] - ) - self.train_cache["map"]["score"]["ref_cross"].update(outputs["score_map_ref_cross"]) - self.train_cache["map"]["l1_diff"]["ref_cross"].update(outputs["l1_diff_map_ref_cross"]) - - # logger vis batch - if self.global_step % self.cfg.logger.vis_imgs_every_n_train_steps == 0: - fig = self.visualiser.vis(batch, outputs) - self.logger.experiment.log({"train_batch": fig}) - - # logger vis X batches statics - if self.global_step % self.cfg.logger.vis_scalar_every_n_train_steps == 0: - # log loss - tmp_loss = self.train_cache["loss"]["final"].compute() - self.log("train/loss", tmp_loss, prog_bar=True) - - if self.cfg.model.do_reference_cross: - tmp_loss_cross = self.train_cache["loss"]["ref_cross"].compute() - self.log("train/loss_cross", tmp_loss_cross) - - # log psnr - if self.cfg.model.do_reference_cross: - self.log("train/psnr_cross", self.to_psnr_fn(tmp_loss_cross)) - - # log correlation - if self.cfg.model.do_reference_cross: - self.log( - "train/correlation_cross", - self.train_cache["correlation"]["ref_cross"].compute(), - ) - - # logger vis X batches histogram - if self.global_step % self.cfg.logger.vis_histogram_every_n_train_steps == 0: - if self.cfg.model.do_reference_cross: - self.logger.experiment.log( - { - "train/score_histogram_cross": wandb.Histogram( - np_histogram=self.train_cache["map"]["score"]["ref_cross"].compute() - ), - "train/l1_diff_histogram_cross": wandb.Histogram( - np_histogram=self.train_cache["map"]["l1_diff"]["ref_cross"].compute() - ), - } - ) + self.log("train/loss", outputs["loss"], prog_bar=True) def on_validation_batch_end(self, outputs, batch, batch_idx): - 
self.validation_cache["loss"]["final"].update(outputs["loss"]) - - if self.cfg.model.do_reference_cross: - self.validation_cache["loss"]["ref_cross"].update(outputs["loss_cross"]) - self.validation_cache["correlation"]["ref_cross"].update( - outputs["score_map_ref_cross"], batch["query/score_map"] - ) - - self.frame_score_summariser.update(batch_input=batch, batch_output=outputs) - - if batch_idx < self.cfg.logger.cache_size.validation.n_fig: - fig = self.visualiser.vis(batch, outputs) - self.validation_cache["fig"]["batch"].update(fig) - - def on_test_batch_end(self, outputs, batch, batch_idx): - results = {"test/loss": outputs["loss"]} - - if self.cfg.model.do_reference_cross: - corr = correlation(outputs["score_map_ref_cross"], batch["query/score_map"]) - psnr = self.to_psnr_fn(outputs["loss_cross"]) - results["test/loss_cross"] = outputs["loss_cross"] - results["test/corr_cross"] = corr - results["test/psnr_cross"] = psnr - - self.log_dict( - results, - on_step=self.cfg.logger.test.on_step, - sync_dist=self.cfg.logger.test.sync_dist, - ) - - self.frame_score_summariser.update(batch_input=batch, batch_output=outputs) - - # write image to vis - if ( - self.cfg.logger.test.write.config.vis_img_every_n_steps > 0 - and batch_idx % self.cfg.logger.test.write.config.vis_img_every_n_steps == 0 - ): - fig = self.visualiser.vis(batch, outputs) - fig.image.save( - Path( - self.cfg.logger.test.out_dir, - "vis", - f"r{self.local_rank}_B{str(batch_idx).zfill(4)}_b{0}.png", - ) - ) - - if self.cfg.logger.test.write.flag.batch: - self.batch_writer.write_out( - batch_input=batch, - batch_output=outputs, - local_rank=self.local_rank, - batch_idx=batch_idx, - ) - - def on_predict_batch_end(self, outputs, batch, batch_idx): - self.frame_score_summariser.update(batch_input=batch, batch_output=outputs) - - # write image to vis - if ( - self.cfg.logger.predict.write.config.vis_img_every_n_steps > 0 - and batch_idx % self.cfg.logger.predict.write.config.vis_img_every_n_steps == 0 - ): 
- fig = self.visualiser.vis(batch, outputs) - fig.image.save( - Path( - self.cfg.logger.predict.out_dir, - "vis", - f"r{self.local_rank}_B{str(batch_idx).zfill(4)}_b{0}.png", - ) - ) - - if self.cfg.logger.predict.write.flag.batch: - self.batch_writer.write_out( - batch_input=batch, - batch_output=outputs, - local_rank=self.local_rank, - batch_idx=batch_idx, - ) - - @rank_zero_only - def on_train_epoch_end(self): - self._reset_logging_cache_train() - - def on_validation_epoch_end(self): - sync_dist = True - self.log( - "validation/loss", - self.validation_cache["loss"]["final"].compute(), - prog_bar=True, - sync_dist=sync_dist, - ) - self.logger.experiment.log( - {"validation_batch": self.validation_cache["fig"]["batch"].compute()}, - ) - - if self.cfg.model.do_reference_cross: - self.log( - "validation/loss_cross", - self.validation_cache["loss"]["ref_cross"].compute(), - sync_dist=sync_dist, - ) - self.log( - "validation/correlation_cross", - self.validation_cache["correlation"]["ref_cross"].compute(), - sync_dist=sync_dist, - ) - self.log( - "validation/psnr_cross", - self.to_psnr_fn(self.validation_cache["loss"]["ref_cross"].compute()), - sync_dist=sync_dist, - ) - - self._reset_logging_cache_validation() - self.frame_score_summariser.reset() - - def on_test_epoch_end(self): - self.frame_score_summariser.summarise() - - def on_predict_epoch_end(self): - self.frame_score_summariser.summarise() + self.log("validation/loss", outputs["loss"], prog_bar=True) def configure_optimizers(self): - # how to use configure_optimizers: - # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.core.LightningModule.html#lightning.pytorch.core.LightningModule.configure_optimizers - - # we freeze backbone and we only pass parameters that requires grad to optimizer: - # https://discuss.pytorch.org/t/how-to-train-a-part-of-a-network/8923 - # https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html#convnet-as-fixed-feature-extractor - # 
https://discuss.pytorch.org/t/for-freezing-certain-layers-why-do-i-need-a-two-step-process/175289/2 parameters = [p for p in self.model.parameters() if p.requires_grad] optimizer = torch.optim.AdamW( params=parameters, @@ -501,8 +101,7 @@ def configure_optimizers(self): step_size=self.cfg.trainer.lr_scheduler.step_size, gamma=self.cfg.trainer.lr_scheduler.gamma, ) - - results = { + return { "optimizer": optimizer, "lr_scheduler": { "scheduler": lr_scheduler, @@ -510,4 +109,3 @@ def configure_optimizers(self): "frequency": 1, }, } - return results diff --git a/task/predict.py b/task/predict.py index e565b78..e51552b 100644 --- a/task/predict.py +++ b/task/predict.py @@ -1,9 +1,6 @@ from datetime import datetime -import sys from pathlib import Path -sys.path.append(str(Path(__file__).parents[1])) - import torch from torch.utils.data import DataLoader from torchvision.transforms import v2 as T @@ -12,10 +9,10 @@ import hydra from omegaconf import DictConfig, open_dict -from core import CrossScoreLightningModule -from dataloading.dataset.simple_reference import SimpleReference -from dataloading.transformation.crop import CropperFactory -from utils.io.images import ImageNetMeanStd +from task.core import CrossScoreLightningModule +from crossscore.dataloading.dataset.simple_reference import SimpleReference +from crossscore.dataloading.transformation.crop import CropperFactory +from crossscore.utils.io.images import ImageNetMeanStd @hydra.main(version_base="1.3", config_path="../config", config_name="default_predict") diff --git a/task/test.py b/task/test.py index ac0f6e3..f2c1f6d 100644 --- a/task/test.py +++ b/task/test.py @@ -1,8 +1,5 @@ -import sys from pathlib import Path -sys.path.append(str(Path(__file__).parents[1])) - import torch from torch.utils.data import DataLoader from torchvision.transforms import v2 as T @@ -12,10 +9,10 @@ import hydra from omegaconf import DictConfig, open_dict -from core import CrossScoreLightningModule -from dataloading.data_manager 
import get_dataset -from dataloading.transformation.crop import CropperFactory -from utils.io.images import ImageNetMeanStd +from task.core import CrossScoreLightningModule +from crossscore.dataloading.data_manager import get_dataset +from crossscore.dataloading.transformation.crop import CropperFactory +from crossscore.utils.io.images import ImageNetMeanStd @hydra.main(version_base="1.3", config_path="../config", config_name="default_test") diff --git a/task/train.py b/task/train.py index acc2650..bc84188 100644 --- a/task/train.py +++ b/task/train.py @@ -1,9 +1,6 @@ -import sys from pathlib import Path from datetime import timedelta -sys.path.append(str(Path(__file__).parents[1])) - import torch from torch.utils.data import DataLoader from torchvision.transforms import v2 as T @@ -16,11 +13,11 @@ from hydra.core.hydra_config import HydraConfig from omegaconf import DictConfig -from core import CrossScoreLightningModule -from dataloading.data_manager import get_dataset -from dataloading.transformation.crop import CropperFactory -from utils.io.images import ImageNetMeanStd -from utils.check_config import ConfigChecker +from task.core import CrossScoreLightningModule +from crossscore.dataloading.data_manager import get_dataset +from crossscore.dataloading.transformation.crop import CropperFactory +from crossscore.utils.io.images import ImageNetMeanStd +from crossscore.utils.check_config import ConfigChecker @hydra.main(version_base="1.3", config_path="../config", config_name="default")