diff --git a/.gitignore b/.gitignore index de52713..ffb57d3 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,12 @@ tmp.sh log* datadir* debug* + +# Package build artifacts +dist/ +build/ +*.egg-info/ +*.egg + +# Plan file +plan.md diff --git a/README.md b/README.md index 7445e99..ef5c9ee 100644 --- a/README.md +++ b/README.md @@ -13,13 +13,81 @@ University of Oxford. ## Table of Content -- [Environment](#Environment) +- [Installation (pip)](#installation-pip) +- [Quick Start](#quick-start) +- [Environment (conda, legacy)](#environment-conda-legacy) - [Data](#Data) - [Training](#Training) - [Inferencing](#Inferencing) -## Environment -We provide a `environment.yaml` file to set up a `conda` environment: +## Installation (pip) + +**Step 1**: Install PyTorch with your preferred CUDA version (see [pytorch.org](https://pytorch.org/get-started/locally/)): +```bash +# Example: PyTorch with CUDA 12.1 +pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121 + +# Example: PyTorch with CUDA 11.8 +pip install torch torchvision --index-url https://download.pytorch.org/whl/cu118 + +# Example: CPU only +pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu +``` + +**Step 2**: Install CrossScore: +```bash +pip install crossscore +``` + +Or install from source: +```bash +git clone https://github.com/ActiveVisionLab/CrossScore.git +cd CrossScore +pip install -e . +``` + +> **Why install PyTorch separately?** PyTorch distributions are coupled with specific CUDA versions. By letting you install PyTorch first, we avoid version conflicts with your system's CUDA setup. CrossScore works with PyTorch 2.0+ and any CUDA version it supports. + +## Quick Start + +### Python API +```python +import crossscore + +# Score query images against reference images +# Model checkpoint is auto-downloaded on first use (~129MB) +results = crossscore.score( + query_dir="path/to/query/images", + reference_dir="path/to/reference/images", +) + +# Per-image mean scores +print(results["scores"]) # [0.82, 0.91, 0.76, ...] + +# Score map tensors (pixel-level quality maps) +for score_map in results["score_maps"]: + print(score_map.shape) # (batch_size, H, W) + +# Colorized score map PNGs are written to results["out_dir"] +``` + +### Command Line +```bash +crossscore --query-dir path/to/queries --reference-dir path/to/references + +# With options +crossscore --query-dir renders/ --reference-dir gt/ --metric-type mae --batch-size 4 + +# Force CPU mode +crossscore --query-dir renders/ --reference-dir gt/ --cpu +``` + +### Environment Variables +- `CROSSSCORE_CKPT_PATH`: Use a specific local checkpoint instead of auto-downloading +- `CROSSSCORE_CACHE_DIR`: Custom cache directory (default: `~/.cache/crossscore`) + +## Environment (conda, legacy) +We also provide a `environment.yaml` file to set up a `conda` environment: ```bash git clone https://github.com/ActiveVisionLab/CrossScore.git cd CrossScore @@ -86,7 +154,7 @@ on our project page. - [ ] Create a HuggingFace demo page. - [ ] Release ECCV quantitative results related scripts. - [x] Release [data processing scripts](https://github.com/ziruiw-dev/CrossScore-3DGS-Preprocessing) -- [ ] Release PyPI and Conda package. +- [x] Release PyPI package. ## Acknowledgement This research is supported by an diff --git a/crossscore/__init__.py b/crossscore/__init__.py new file mode 100644 index 0000000..356ca68 --- /dev/null +++ b/crossscore/__init__.py @@ -0,0 +1,38 @@ +"""CrossScore: Towards Multi-View Image Evaluation and Scoring. + +A pip-installable package for neural image quality assessment using +cross-reference scoring with DINOv2 backbone. + +Example: + >>> import crossscore + >>> results = crossscore.score( + ... query_dir="path/to/query/images", + ... reference_dir="path/to/reference/images", + ... ) + >>> print(results["scores"]) # per-image mean scores +""" + +__version__ = "1.0.0" + + +def score(*args, **kwargs): + """Score query images against reference images using CrossScore. + + See crossscore.api.score for full documentation. + """ + from crossscore.api import score as _score + + return _score(*args, **kwargs) + + +def get_checkpoint_path(): + """Get path to the CrossScore checkpoint, downloading if necessary. + + See crossscore._download.get_checkpoint_path for full documentation. + """ + from crossscore._download import get_checkpoint_path as _get + + return _get() + + +__all__ = ["score", "get_checkpoint_path"] diff --git a/crossscore/_download.py b/crossscore/_download.py new file mode 100644 index 0000000..2a669bc --- /dev/null +++ b/crossscore/_download.py @@ -0,0 +1,33 @@ +"""Utilities for downloading CrossScore model checkpoints.""" + +import os +from pathlib import Path + +HF_REPO_ID = "ActiveVisionLab/CrossScore" +CHECKPOINT_FILENAME = "CrossScore-v1.0.0.ckpt" + + +def get_checkpoint_path() -> str: + """Get path to the CrossScore checkpoint, downloading it if necessary. + + Downloads from HuggingFace Hub on first use and caches locally. + Set environment variables to customize: + CROSSSCORE_CKPT_PATH - use a specific local checkpoint file + + Returns: + Path to the checkpoint file. + """ + # Allow user to override with a custom path + custom_path = os.environ.get("CROSSSCORE_CKPT_PATH") + if custom_path: + if not Path(custom_path).exists(): + raise FileNotFoundError(f"Checkpoint not found at CROSSSCORE_CKPT_PATH={custom_path}") + return custom_path + + from huggingface_hub import hf_hub_download + + path = hf_hub_download( + repo_id=HF_REPO_ID, + filename=CHECKPOINT_FILENAME, + ) + return path diff --git a/crossscore/api.py b/crossscore/api.py new file mode 100644 index 0000000..7e7380b --- /dev/null +++ b/crossscore/api.py @@ -0,0 +1,171 @@ +"""High-level API for CrossScore image quality assessment.""" + +from pathlib import Path +from typing import Optional, Union, List + +import numpy as np +import torch +from torch.utils.data import DataLoader +from torchvision.transforms import v2 as T +from omegaconf import OmegaConf +from tqdm import tqdm + +from crossscore._download import get_checkpoint_path +from crossscore.utils.io.images import ImageNetMeanStd +from crossscore.dataloading.dataset.simple_reference import SimpleReference + + +def _write_score_maps(score_maps, query_paths, out_dir, metric_type, metric_min, metric_max): + """Write score maps to disk as colorized PNGs.""" + from PIL import Image + from crossscore.utils.misc.image import gray2rgb + + vrange_vis = [metric_min, metric_max] + out_dir = Path(out_dir) / "score_maps" + out_dir.mkdir(parents=True, exist_ok=True) + + idx = 0 + for batch_maps, batch_paths in zip(score_maps, query_paths): + for score_map, qpath in zip(batch_maps, batch_paths): + fname = Path(qpath).stem + ".png" + rgb = gray2rgb(score_map.cpu().numpy(), vrange_vis) + Image.fromarray(rgb).save(out_dir / fname) + idx += 1 + return str(out_dir) + + +def score( + query_dir: str, + reference_dir: str, + ckpt_path: Optional[str] = None, + metric_type: str = "ssim", + batch_size: int = 8, + num_workers: int = 4, + resize_short_side: int = 518, + device: Optional[str] = None, + out_dir: Optional[str] = None, + write_score_maps: bool = True, +) -> dict: + """Score query images against reference images using CrossScore. + + Args: + query_dir: Directory containing query images (e.g., NVS rendered images). + reference_dir: Directory containing reference images (e.g., real captured images). + ckpt_path: Path to model checkpoint. Auto-downloads if not provided. + metric_type: Metric type to predict. One of "ssim", "mae", "mse". + batch_size: Batch size for inference. + num_workers: Number of data loading workers. + resize_short_side: Resize images so short side equals this value. -1 to disable. + device: Device string ("cuda", "cuda:0", "cpu"). Auto-detected if None. + out_dir: Output directory for score maps. Defaults to "./crossscore_output". + write_score_maps: Whether to write colorized score map PNGs to disk. + + Returns: + Dictionary with: + - "score_maps": List of score map tensors, each (B, H, W) + - "scores": List of per-image mean scores (float) + - "out_dir": Output directory path (if write_score_maps=True) + + Example: + >>> import crossscore + >>> results = crossscore.score( + ... query_dir="path/to/query/images", + ... reference_dir="path/to/reference/images", + ... ) + >>> print(results["scores"]) # per-image mean scores + """ + from crossscore.task.core import load_model + + # Determine device + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + + # Get checkpoint + if ckpt_path is None: + ckpt_path = get_checkpoint_path() + + # Load model + model = load_model(ckpt_path, device=device) + + # Set up data transforms + img_norm_stat = ImageNetMeanStd() + transforms = { + "img": T.Normalize(mean=img_norm_stat.mean, std=img_norm_stat.std), + } + if resize_short_side > 0: + transforms["resize"] = T.Resize( + resize_short_side, + interpolation=T.InterpolationMode.BILINEAR, + antialias=True, + ) + + # Build dataset and dataloader + neighbour_config = {"strategy": "random", "cross": 5, "deterministic": False} + dataset = SimpleReference( + query_dir=query_dir, + reference_dir=reference_dir, + transforms=transforms, + neighbour_config=neighbour_config, + return_item_paths=True, + zero_reference=False, + ) + + dataloader = DataLoader( + dataset, + batch_size=batch_size, + shuffle=False, + num_workers=num_workers, + pin_memory=(device != "cpu"), + persistent_workers=False, + ) + + # Run inference + all_score_maps = [] + all_scores = [] + all_query_paths = [] + + with torch.no_grad(): + for batch in tqdm(dataloader, desc="CrossScore"): + query_img = batch["query/img"].to(device) + ref_imgs = batch.get("reference/cross/imgs") + if ref_imgs is not None: + ref_imgs = ref_imgs.to(device) + + outputs = model( + query_img=query_img, + ref_cross_imgs=ref_imgs, + norm_img=False, + ) + + score_map = outputs["score_map_ref_cross"] # (B, H, W) + all_score_maps.append(score_map.cpu()) + + # Per-image mean score + for i in range(score_map.shape[0]): + all_scores.append(score_map[i].mean().item()) + + # Track query paths for output naming + if "item_paths" in batch and "query/img" in batch["item_paths"]: + all_query_paths.append(batch["item_paths"]["query/img"]) + + # Build results + metric_min = -1 if metric_type == "ssim" else 0 + if metric_type == "ssim": + metric_min = 0 # CrossScore predicts SSIM in [0, 1] by default + + results = { + "score_maps": all_score_maps, + "scores": all_scores, + } + + # Write outputs + if write_score_maps and all_score_maps: + if out_dir is None: + out_dir = "./crossscore_output" + written_dir = _write_score_maps( + all_score_maps, all_query_paths, out_dir, + metric_type, metric_min, metric_max=1, + ) + results["out_dir"] = written_dir + + return results diff --git a/crossscore/cli.py b/crossscore/cli.py new file mode 100644 index 0000000..7be9d9c --- /dev/null +++ b/crossscore/cli.py @@ -0,0 +1,98 @@ +"""Command-line interface for CrossScore.""" + +import argparse + + +def main(): + parser = argparse.ArgumentParser( + description="CrossScore: Multi-View Image Quality Assessment", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog="""\ +Examples: + crossscore --query-dir path/to/queries --reference-dir path/to/references + crossscore --query-dir renders/ --reference-dir gt/ --metric-type mae --batch-size 4 + crossscore --query-dir renders/ --reference-dir gt/ --cpu +""", + ) + parser.add_argument( + "--query-dir", required=True, help="Directory containing query images" + ) + parser.add_argument( + "--reference-dir", required=True, help="Directory containing reference images" + ) + parser.add_argument( + "--ckpt-path", + default=None, + help="Path to model checkpoint (auto-downloads if not provided)", + ) + parser.add_argument( + "--metric-type", + default="ssim", + choices=["ssim", "mae", "mse"], + help="Metric type to predict (default: ssim)", + ) + parser.add_argument( + "--batch-size", type=int, default=8, help="Batch size (default: 8)" + ) + parser.add_argument( + "--num-workers", type=int, default=4, help="Data loading workers (default: 4)" + ) + parser.add_argument( + "--resize-short-side", + type=int, + default=518, + help="Resize short side to this value, -1 to disable (default: 518)", + ) + parser.add_argument( + "--device", + default=None, + help="Device string, e.g. 'cuda', 'cuda:0', 'cpu' (default: auto-detect)", + ) + parser.add_argument( + "--cpu", + action="store_true", + help="Force CPU mode (no GPU)", + ) + parser.add_argument( + "--out-dir", + default=None, + help="Output directory for results (default: ./crossscore_output)", + ) + parser.add_argument( + "--no-write", + action="store_true", + help="Do not write score map images to disk", + ) + + args = parser.parse_args() + + from crossscore.api import score + + device = "cpu" if args.cpu else args.device + + results = score( + query_dir=args.query_dir, + reference_dir=args.reference_dir, + ckpt_path=args.ckpt_path, + metric_type=args.metric_type, + batch_size=args.batch_size, + num_workers=args.num_workers, + resize_short_side=args.resize_short_side, + device=device, + out_dir=args.out_dir, + write_score_maps=not args.no_write, + ) + + n_images = len(results["scores"]) + print(f"\nCrossScore completed: {n_images} images scored") + if results["scores"]: + mean_score = sum(results["scores"]) / len(results["scores"]) + print(f"Mean score: {mean_score:.4f}") + for i, s in enumerate(results["scores"]): + print(f" Image {i}: {s:.4f}") + if "out_dir" in results: + print(f"Score maps written to: {results['out_dir']}") + + +if __name__ == "__main__": + main() diff --git a/crossscore/config/model/model.yaml b/crossscore/config/model/model.yaml new file mode 100644 index 0000000..06fc8d6 --- /dev/null +++ b/crossscore/config/model/model.yaml @@ -0,0 +1,32 @@ +patch_size: 14 +do_reference_cross: True + +decoder_do_self_attn: True +decoder_do_short_cut: True +need_attn_weights: False # def False, requires more gpu mem if True +need_attn_weights_head_id: 0 # check which attn head + +backbone: + from_pretrained: facebook/dinov2-small + +pos_enc: + multi_view: + interpolate_mode: bilinear + req_grad: False + h: 40 # def 40 so we always interpolate in training, could be 37 or 16 too. + w: 40 + +loss: + fn: l1 + +predict: + metric: + type: ssim + # type: mae + # type: mse + + # min: -1 + min: 0 + max: 1 + + power_factor: default # can be a scalar \ No newline at end of file diff --git a/crossscore/dataloading/__init__.py b/crossscore/dataloading/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/crossscore/dataloading/dataset/__init__.py b/crossscore/dataloading/dataset/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/dataloading/dataset/nvs_dataset.py b/crossscore/dataloading/dataset/nvs_dataset.py similarity index 99% rename from dataloading/dataset/nvs_dataset.py rename to crossscore/dataloading/dataset/nvs_dataset.py index 8ebec16..8e6cc7f 100644 --- a/dataloading/dataset/nvs_dataset.py +++ b/crossscore/dataloading/dataset/nvs_dataset.py @@ -5,10 +5,9 @@ from torch.utils.data import Dataset from omegaconf import OmegaConf -sys.path.append(str(Path(__file__).parents[2])) -from utils.io.images import metric_map_read, image_read -from utils.neighbour.sampler import SamplerFactory -from utils.check_config import ConfigChecker +from crossscore.utils.io.images import metric_map_read, image_read +from crossscore.utils.neighbour.sampler import SamplerFactory +from crossscore.utils.check_config import ConfigChecker class NeighbourSelector: diff --git a/dataloading/dataset/simple_reference.py b/crossscore/dataloading/dataset/simple_reference.py similarity index 96% rename from dataloading/dataset/simple_reference.py rename to crossscore/dataloading/dataset/simple_reference.py index 3fc4722..cf72e92 100644 --- a/dataloading/dataset/simple_reference.py +++ b/crossscore/dataloading/dataset/simple_reference.py @@ -3,8 +3,7 @@ import torch from omegaconf import OmegaConf -sys.path.append(str(Path(__file__).parents[2])) -from dataloading.dataset.nvs_dataset import NvsDataset, NeighbourSelector, vis_batch +from crossscore.dataloading.dataset.nvs_dataset import NvsDataset, NeighbourSelector, vis_batch class SimpleReference(NvsDataset): @@ -89,8 +88,8 @@ def get_paths(query_dir, reference_dir): from lightning import seed_everything from torchvision.transforms import v2 as T from tqdm import tqdm - from dataloading.transformation.crop import CropperFactory - from utils.io.images import ImageNetMeanStd + from crossscore.dataloading.transformation.crop import CropperFactory + from crossscore.utils.io.images import ImageNetMeanStd from omegaconf import OmegaConf seed_everything(1) diff --git a/crossscore/dataloading/transformation/__init__.py b/crossscore/dataloading/transformation/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/dataloading/transformation/crop.py b/crossscore/dataloading/transformation/crop.py similarity index 100% rename from dataloading/transformation/crop.py rename to crossscore/dataloading/transformation/crop.py diff --git a/crossscore/model/__init__.py b/crossscore/model/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/model/cross_reference.py b/crossscore/model/cross_reference.py similarity index 98% rename from model/cross_reference.py rename to crossscore/model/cross_reference.py index a952d78..32879bd 100644 --- a/model/cross_reference.py +++ b/crossscore/model/cross_reference.py @@ -4,7 +4,7 @@ TransformerDecoderCustomised, ) from .regression_layer import RegressionLayer -from utils.misc.image import jigsaw_to_image +from crossscore.utils.misc.image import jigsaw_to_image class CrossReferenceNet(torch.nn.Module): diff --git a/crossscore/model/customised_transformer/__init__.py b/crossscore/model/customised_transformer/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/model/customised_transformer/transformer.py b/crossscore/model/customised_transformer/transformer.py similarity index 100% rename from model/customised_transformer/transformer.py rename to crossscore/model/customised_transformer/transformer.py diff --git a/model/positional_encoding.py b/crossscore/model/positional_encoding.py similarity index 100% rename from model/positional_encoding.py rename to crossscore/model/positional_encoding.py diff --git a/model/regression_layer.py b/crossscore/model/regression_layer.py similarity index 95% rename from model/regression_layer.py rename to crossscore/model/regression_layer.py index e651048..f943afd 100644 --- a/model/regression_layer.py +++ b/crossscore/model/regression_layer.py @@ -3,8 +3,7 @@ from pathlib import Path import torch -sys.path.append(str(Path(__file__).parents[1])) -from utils.check_config import check_metric_prediction_config +from crossscore.utils.check_config import check_metric_prediction_config class RegressionLayer(torch.nn.Module): diff --git a/crossscore/task/__init__.py b/crossscore/task/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/crossscore/task/core.py b/crossscore/task/core.py new file mode 100644 index 0000000..0bca761 --- /dev/null +++ b/crossscore/task/core.py @@ -0,0 +1,186 @@ +"""CrossScoreNet: the core neural network for CrossScore inference.""" + +import torch +from transformers import Dinov2Config, Dinov2Model +from omegaconf import OmegaConf + +from crossscore.utils.io.images import ImageNetMeanStd +from crossscore.model.cross_reference import CrossReferenceNet +from crossscore.model.positional_encoding import MultiViewPosionalEmbeddings + + +class CrossScoreNet(torch.nn.Module): + def __init__(self, cfg): + super().__init__() + self.cfg = cfg + + img_norm_stat = ImageNetMeanStd() + self.register_buffer( + "img_mean_std", torch.tensor([*img_norm_stat.mean, *img_norm_stat.std]) + ) + + # backbone, freeze + self.dinov2_cfg = Dinov2Config.from_pretrained(self.cfg.model.backbone.from_pretrained) + self.backbone = Dinov2Model.from_pretrained(self.cfg.model.backbone.from_pretrained) + for param in self.backbone.parameters(): + param.requires_grad = False + + # positional encoding layer + self.pos_enc_fn = MultiViewPosionalEmbeddings( + positional_encoding_h=self.cfg.model.pos_enc.multi_view.h, + positional_encoding_w=self.cfg.model.pos_enc.multi_view.w, + interpolate_mode=self.cfg.model.pos_enc.multi_view.interpolate_mode, + req_grad=self.cfg.model.pos_enc.multi_view.req_grad, + patch_size=self.cfg.model.patch_size, + hidden_size=self.dinov2_cfg.hidden_size, + ) + + # cross reference predictor + if self.cfg.model.do_reference_cross: + self.ref_cross = CrossReferenceNet(cfg=self.cfg, dinov2_cfg=self.dinov2_cfg) + + def forward( + self, + query_img, + ref_cross_imgs, + need_attn_weights=False, + need_attn_weights_head_id=0, + norm_img=False, + ): + """ + :param query_img: (B, 3, H, W) + :param ref_cross_imgs: (B, N_ref_cross, 3, H, W) + :param norm_img: bool, normalise an image with pixel value in [0, 1] with imagenet mean and std. + """ + B = query_img.shape[0] + H, W = query_img.shape[-2:] + N_patch_h = H // self.cfg.model.patch_size + N_patch_w = W // self.cfg.model.patch_size + + if norm_img: + img_mean = self.img_mean_std[None, :3, None, None] + img_std = self.img_mean_std[None, 3:, None, None] + query_img = (query_img - img_mean) / img_std + if ref_cross_imgs is not None: + ref_cross_imgs = (ref_cross_imgs - img_mean[:, None]) / img_std[:, None] + + featmaps = self.get_featmaps(query_img, ref_cross_imgs) + results = {} + + # processing (and predicting) for query + featmaps["query"] = self.pos_enc_fn(featmaps["query"], N_view=1, img_h=H, img_w=W) + + if self.cfg.model.do_reference_cross: + N_ref_cross = ref_cross_imgs.shape[1] + + # (B, N_ref_cross*num_patches, hidden_size) + featmaps["ref_cross"] = self.pos_enc_fn( + featmaps["ref_cross"], + N_view=N_ref_cross, + img_h=H, + img_w=W, + ) + + # prediction + dim_params = { + "B": B, + "N_patch_h": N_patch_h, + "N_patch_w": N_patch_w, + "N_ref": N_ref_cross, + } + results_ref_cross = self.ref_cross( + featmaps["query"], + featmaps["ref_cross"], + None, + dim_params, + need_attn_weights, + need_attn_weights_head_id, + ) + results["score_map_ref_cross"] = results_ref_cross["score_map"] + results["attn_weights_map_ref_cross"] = results_ref_cross["attn_weights_map_mha"] + return results + + @torch.no_grad() + def get_featmaps(self, query_img, ref_cross_imgs): + """ + :param query_img: (B, 3, H, W) + :param ref_cross: (B, N_ref_cross, 3, H, W) + """ + B = query_img.shape[0] + H, W = query_img.shape[-2:] + N_patch_h = H // self.cfg.model.patch_size + N_patch_w = W // self.cfg.model.patch_size + N_query = 1 + N_ref_cross = 0 if ref_cross_imgs is None else ref_cross_imgs.shape[1] + N_all_imgs = N_query + N_ref_cross + + # concat all images to go through backbone for once + all_imgs = [query_img.view(B, 1, 3, H, W)] + if ref_cross_imgs is not None: + all_imgs.append(ref_cross_imgs) + all_imgs = torch.cat(all_imgs, dim=1) + all_imgs = all_imgs.view(B * N_all_imgs, 3, H, W) + + # bbo: backbone output + bbo_all = self.backbone(all_imgs) + featmap_all = bbo_all.last_hidden_state[:, 1:] + featmap_all = featmap_all.view(B, N_all_imgs, N_patch_h * N_patch_w, -1) + + # query + featmap_query = featmap_all[:, 0] # (B, num_patches, hidden_size) + N_patches = featmap_query.shape[1] + hidden_size = featmap_query.shape[2] + + # cross ref + if ref_cross_imgs is not None: + featmap_ref_cross = featmap_all[:, -N_ref_cross:] + featmap_ref_cross = featmap_ref_cross.reshape(B, N_ref_cross * N_patches, hidden_size) + else: + featmap_ref_cross = None + + featmaps = { + "query": featmap_query, # (B, num_patches, hidden_size) + "ref_cross": featmap_ref_cross, # (B, N_ref_cross*num_patches, hidden_size) + } + return featmaps + + +def load_model(ckpt_path: str, device: str = "cpu") -> CrossScoreNet: + """Load a CrossScoreNet model from a Lightning or direct checkpoint. + + Args: + ckpt_path: Path to the .ckpt file. + device: Device to load the model on. + + Returns: + CrossScoreNet model in eval mode. + """ + checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False) + + # Extract config from checkpoint (saved by Lightning's save_hyperparameters) + if "hyper_parameters" in checkpoint: + cfg = OmegaConf.create(checkpoint["hyper_parameters"]) + else: + # Fallback: use default config + from pathlib import Path + + config_dir = Path(__file__).parent.parent / "config" + model_cfg = OmegaConf.load(config_dir / "model" / "model.yaml") + cfg = OmegaConf.create({"model": model_cfg}) + + model = CrossScoreNet(cfg) + + # Handle Lightning checkpoint format (keys prefixed with "model.") + state_dict = checkpoint.get("state_dict", checkpoint) + new_state_dict = {} + for k, v in state_dict.items(): + # Strip "model." prefix from Lightning checkpoint keys + if k.startswith("model."): + new_state_dict[k[6:]] = v + else: + new_state_dict[k] = v + + model.load_state_dict(new_state_dict, strict=False) + model.eval() + model.to(device) + return model diff --git a/crossscore/utils/__init__.py b/crossscore/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/utils/check_config.py b/crossscore/utils/check_config.py similarity index 100% rename from utils/check_config.py rename to crossscore/utils/check_config.py diff --git a/crossscore/utils/io/__init__.py b/crossscore/utils/io/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/utils/io/images.py b/crossscore/utils/io/images.py similarity index 100% rename from utils/io/images.py rename to crossscore/utils/io/images.py diff --git a/crossscore/utils/misc/__init__.py b/crossscore/utils/misc/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/utils/misc/image.py b/crossscore/utils/misc/image.py similarity index 98% rename from utils/misc/image.py rename to crossscore/utils/misc/image.py index 73b2e9c..614f826 100644 --- a/utils/misc/image.py +++ b/crossscore/utils/misc/image.py @@ -2,7 +2,7 @@ import numpy as np import matplotlib.cm as cm from PIL import Image, ImageDraw, ImageFont -from utils.io.images import u8 +from crossscore.utils.io.images import u8 def jigsaw_to_image(x, grid_size): diff --git a/crossscore/utils/neighbour/__init__.py b/crossscore/utils/neighbour/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/utils/neighbour/sampler.py b/crossscore/utils/neighbour/sampler.py similarity index 100% rename from utils/neighbour/sampler.py rename to crossscore/utils/neighbour/sampler.py diff --git a/dataloading/data_manager.py b/dataloading/data_manager.py deleted file mode 100644 index 3c4fb8b..0000000 --- a/dataloading/data_manager.py +++ /dev/null @@ -1,41 +0,0 @@ -import torch -from omegaconf import OmegaConf, ListConfig -from dataloading.dataset.nvs_dataset import NvsDataset -from pprint import pprint - - -def get_dataset(cfg, transforms, data_split, return_item_paths=False): - if isinstance(cfg.data.dataset.path, str): - dataset_path_list = [cfg.data.dataset.path] - elif isinstance(cfg.data.dataset.path, ListConfig): - dataset_path_list = OmegaConf.to_object(cfg.data.dataset.path) - else: - raise ValueError("cfg.data.dataset.path should be a string or a ListConfig") - - N_dataset = len(dataset_path_list) - print(f"Get {N_dataset} datasets for {data_split}:") - pprint(f"cfg.data.dataset.path: {dataset_path_list}") - print("==================================") - - dataset_list = [] - for i in range(N_dataset): - dataset_list.append( - NvsDataset( - dataset_path=dataset_path_list[i], - resolution=cfg.data.dataset.resolution, - data_split=data_split, - transforms=transforms, - neighbour_config=cfg.data.neighbour_config, - metric_type=cfg.model.predict.metric.type, - metric_min=cfg.model.predict.metric.min, - metric_max=cfg.model.predict.metric.max, - return_item_paths=return_item_paths, - num_gaussians_iters=cfg.data.dataset.num_gaussians_iters, - zero_reference=cfg.data.dataset.zero_reference, - ) - ) - if N_dataset == 1: - dataset = dataset_list[0] - else: - dataset = torch.utils.data.ConcatDataset(dataset_list) - return dataset diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..0728490 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,75 @@ +[build-system] +requires = ["setuptools>=68.0", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "crossscore" +version = "1.0.0" +description = "CrossScore: Towards Multi-View Image Evaluation and Scoring" +readme = "README.md" +license = {text = "Apache-2.0"} +requires-python = ">=3.9" +authors = [ + {name = "Zirui Wang"}, + {name = "Wenjing Bian"}, + {name = "Victor Adrian Prisacariu"}, +] +keywords = ["image quality assessment", "neural rendering", "novel view synthesis", "deep learning"] +classifiers = [ + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Scientific/Engineering :: Image Processing", +] + +# NOTE on PyTorch/CUDA: +# We require torch>=2.0.0 but do NOT pin a specific CUDA version. +# Users should install PyTorch with their preferred CUDA version FIRST: +# pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121 +# Then install CrossScore: +# pip install crossscore +# +# This avoids CUDA version conflicts since PyTorch+CUDA combos vary by system. +dependencies = [ + # Core deep learning (flexible ranges - user installs PyTorch first) + "torch>=2.0.0", + "torchvision>=0.15.0", + + # DINOv2 backbone + "transformers>=4.30.0", + + # Configuration + "omegaconf>=2.3.0", + + # Image processing + "Pillow>=9.0.0", + "imageio>=2.20.0", + "numpy>=1.22.0", + "matplotlib>=3.5.0", + + # Progress bar + "tqdm>=4.60.0", +] + +[project.urls] +Homepage = "https://crossscore.active.vision" +Repository = "https://github.com/ActiveVisionLab/CrossScore" +Paper = "https://arxiv.org/abs/2404.14409" + +[project.scripts] +crossscore = "crossscore.cli:main" + +[tool.setuptools.packages.find] +include = ["crossscore*"] + +[tool.setuptools.package-data] +crossscore = [ + "config/**/*.yaml", +] diff --git a/task/core.py b/task/core.py index 80de8be..99b8fb1 100644 --- a/task/core.py +++ b/task/core.py @@ -1,164 +1,20 @@ +"""Training-only Lightning module. Not part of the pip package. + +Imports the CrossScoreNet model from the crossscore package and wraps it +in a LightningModule for training/validation/testing. +""" + from pathlib import Path import torch import lightning -import wandb -from transformers import Dinov2Config, Dinov2Model from omegaconf import DictConfig, OmegaConf from lightning.pytorch.utilities import rank_zero_only -from utils.evaluation.metric import abs2psnr, correlation -from utils.evaluation.metric_logger import ( - MetricLoggerScalar, - MetricLoggerHistogram, - MetricLoggerCorrelation, - MetricLoggerImg, -) -from utils.plot.batch_visualiser import BatchVisualiserFactory -from utils.io.images import ImageNetMeanStd -from utils.io.batch_writer import BatchWriter -from utils.io.score_summariser import ( - SummaryWriterPredictedOnline, - SummaryWriterPredictedOnlineTestPrediction, -) -from model.cross_reference import CrossReferenceNet -from model.positional_encoding import MultiViewPosionalEmbeddings - - -class CrossScoreNet(torch.nn.Module): - def __init__(self, cfg): - super().__init__() - self.cfg = cfg - - # used in 1. denormalising images for visualisation - # and 2. normalising images for training when required - img_norm_stat = ImageNetMeanStd() - self.register_buffer( - "img_mean_std", torch.tensor([*img_norm_stat.mean, *img_norm_stat.std]) - ) - - # backbone, freeze - self.dinov2_cfg = Dinov2Config.from_pretrained(self.cfg.model.backbone.from_pretrained) - self.backbone = Dinov2Model.from_pretrained(self.cfg.model.backbone.from_pretrained) - for param in self.backbone.parameters(): - param.requires_grad = False - - # positional encoding layer - self.pos_enc_fn = MultiViewPosionalEmbeddings( - positional_encoding_h=self.cfg.model.pos_enc.multi_view.h, - positional_encoding_w=self.cfg.model.pos_enc.multi_view.w, - interpolate_mode=self.cfg.model.pos_enc.multi_view.interpolate_mode, - req_grad=self.cfg.model.pos_enc.multi_view.req_grad, - patch_size=self.cfg.model.patch_size, - hidden_size=self.dinov2_cfg.hidden_size, - ) - - # cross reference predictor - if self.cfg.model.do_reference_cross: - self.ref_cross = CrossReferenceNet(cfg=self.cfg, dinov2_cfg=self.dinov2_cfg) - - def forward( - self, - query_img, - ref_cross_imgs, - need_attn_weights, - need_attn_weights_head_id, - norm_img, - ): - """ - :param query_img: (B, 3, H, W) - :param ref_cross_imgs: (B, N_ref_cross, 3, H, W) - :param norm_img: bool, normalise an image with pixel value in [0, 1] with imagenet mean and std. - """ - B = query_img.shape[0] - H, W = query_img.shape[-2:] - N_patch_h = H // self.cfg.model.patch_size - N_patch_w = W // self.cfg.model.patch_size - - if norm_img: - img_mean = self.img_mean_std[None, :3, None, None] - img_std = self.img_mean_std[None, 3:, None, None] - query_img = (query_img - img_mean) / img_std - if ref_cross_imgs is not None: - ref_cross_imgs = (ref_cross_imgs - img_mean[:, None]) / img_std[:, None] - - featmaps = self.get_featmaps(query_img, ref_cross_imgs) - results = {} - - # processing (and predicting) for query - featmaps["query"] = self.pos_enc_fn(featmaps["query"], N_view=1, img_h=H, img_w=W) - - if self.cfg.model.do_reference_cross: - N_ref_cross = ref_cross_imgs.shape[1] - - # (B, N_ref_cross*num_patches, hidden_size) - featmaps["ref_cross"] = self.pos_enc_fn( - featmaps["ref_cross"], - N_view=N_ref_cross, - img_h=H, - img_w=W, - ) - - # prediction - dim_params = { - "B": B, - "N_patch_h": N_patch_h, - "N_patch_w": N_patch_w, - "N_ref": N_ref_cross, - } - results_ref_cross = self.ref_cross( - featmaps["query"], - featmaps["ref_cross"], - None, - dim_params, - need_attn_weights, - need_attn_weights_head_id, - ) - results["score_map_ref_cross"] = results_ref_cross["score_map"] - results["attn_weights_map_ref_cross"] = results_ref_cross["attn_weights_map_mha"] - return results - - @torch.no_grad() - def get_featmaps(self, query_img, ref_cross_imgs): - """ - :param query_img: (B, 3, H, W) - :param ref_cross: (B, N_ref_cross, 3, H, W) - """ - B = query_img.shape[0] - H, W = query_img.shape[-2:] - N_patch_h = H // self.cfg.model.patch_size - N_patch_w = W // self.cfg.model.patch_size - N_query = 1 - N_ref_cross = 0 if ref_cross_imgs is None else ref_cross_imgs.shape[1] - N_all_imgs = N_query + N_ref_cross - # concat all images to go through backbone for once - all_imgs = [query_img.view(B, 1, 3, H, W)] - if ref_cross_imgs is not None: - all_imgs.append(ref_cross_imgs) - all_imgs = torch.cat(all_imgs, dim=1) - all_imgs = all_imgs.view(B * N_all_imgs, 3, H, W) +# Model from the pip package +from crossscore.task.core import CrossScoreNet - # bbo: backbone output - bbo_all = self.backbone(all_imgs) - featmap_all = bbo_all.last_hidden_state[:, 1:] - featmap_all = featmap_all.view(B, N_all_imgs, N_patch_h * N_patch_w, -1) - - # query - featmap_query = featmap_all[:, 0] # (B, num_patches, hidden_size) - N_patches = featmap_query.shape[1] - hidden_size = featmap_query.shape[2] - - # cross ref - if ref_cross_imgs is not None: - featmap_ref_cross = featmap_all[:, -N_ref_cross:] - featmap_ref_cross = featmap_ref_cross.reshape(B, N_ref_cross * N_patches, hidden_size) - else: - featmap_ref_cross = None - - featmaps = { - "query": featmap_query, # (B, num_patches, hidden_size) - "ref_cross": featmap_ref_cross, # (B, N_ref_cross*num_patches, hidden_size) - } - return featmaps +# Training-only imports +from crossscore.utils.io.images import ImageNetMeanStd class CrossScoreLightningModule(lightning.LightningModule): @@ -166,19 +22,17 @@ def __init__(self, cfg: DictConfig): super().__init__() self.cfg = cfg - # write config to wandb self.save_hyperparameters(OmegaConf.to_container(self.cfg, resolve=True)) - # init my network + # init my network (from pip package) self.model = CrossScoreNet(cfg=self.cfg) - # init visualiser - self.visualiser = BatchVisualiserFactory(self.cfg, self.model.img_mean_std)() + # lazy imports for training-only deps + from crossscore.utils.check_config import check_reference_type # init loss fn if self.cfg.model.loss.fn == "l1": self.loss_fn = torch.nn.L1Loss() - self.to_psnr_fn = abs2psnr else: raise NotImplementedError @@ -187,100 +41,24 @@ def __init__(self, cfg: DictConfig): if self.cfg.model.do_reference_cross: self.ref_mode_names.append("ref_cross") - def on_fit_start(self): - # reset logging cache - if self.global_rank == 0: - self._reset_logging_cache_train() - self._reset_logging_cache_validation() - - self.frame_score_summariser = SummaryWriterPredictedOnline( - metric_type=self.cfg.model.predict.metric.type, - metric_min=self.cfg.model.predict.metric.min, - ) - - def on_test_start(self): - Path(self.cfg.logger.test.out_dir, "vis").mkdir(parents=True, exist_ok=True) - if self.cfg.logger.test.write.flag.batch: - self.batch_writer = BatchWriter(self.cfg, "test", self.model.img_mean_std) - else: - self.batch_writer = None - - self.frame_score_summariser = SummaryWriterPredictedOnlineTestPrediction( - metric_type=self.cfg.model.predict.metric.type, - metric_min=self.cfg.model.predict.metric.min, - dir_out=self.cfg.logger.test.out_dir, - ) - - def on_predict_start(self): - Path(self.cfg.logger.predict.out_dir, "vis").mkdir(parents=True, exist_ok=True) - if self.cfg.logger.predict.write.flag.batch: - self.batch_writer = BatchWriter(self.cfg, "predict", self.model.img_mean_std) - else: - self.batch_writer = None - - self.frame_score_summariser = SummaryWriterPredictedOnlineTestPrediction( - metric_type=self.cfg.model.predict.metric.type, - metric_min=self.cfg.model.predict.metric.min, - dir_out=self.cfg.logger.predict.out_dir, - ) - - def _reset_logging_cache_train(self): - self.train_cache = { - "loss": { - k: MetricLoggerScalar(max_length=self.cfg.logger.cache_size.train.n_scalar) - for k in ["final", "reg_self", "reg_cross"] + self.ref_mode_names - }, - "correlation": { - k: MetricLoggerCorrelation(max_length=self.cfg.logger.cache_size.train.n_scalar) - for k in self.ref_mode_names - }, - "map": { - "score": { - k: MetricLoggerHistogram(max_length=self.cfg.logger.cache_size.train.n_scalar) - for k in self.ref_mode_names - }, - "l1_diff": { - k: MetricLoggerHistogram(max_length=self.cfg.logger.cache_size.train.n_scalar) - for k in self.ref_mode_names - }, - "delta": { - k: MetricLoggerHistogram(max_length=self.cfg.logger.cache_size.train.n_scalar) - for k in ["self", "cross"] - }, - }, - } - - def _reset_logging_cache_validation(self): - self.validation_cache = { - "loss": { - k: MetricLoggerScalar(max_length=None) - for k in ["final", "reg_self", "reg_cross"] + self.ref_mode_names - }, - "correlation": { - k: MetricLoggerCorrelation(max_length=None) for k in self.ref_mode_names - }, - "fig": {k: MetricLoggerImg(max_length=None) for k in ["batch"]}, - } - def _core_step(self, batch, batch_idx, skip_loss=False): outputs = self.model( - query_img=batch["query/img"], # (B, C, H, W) - ref_cross_imgs=batch.get("reference/cross/imgs", None), # (B, N_ref_cross, C, H, W) + query_img=batch["query/img"], + ref_cross_imgs=batch.get("reference/cross/imgs", None), need_attn_weights=self.cfg.model.need_attn_weights, need_attn_weights_head_id=self.cfg.model.need_attn_weights_head_id, norm_img=False, ) - if skip_loss: # only used in predict_step + if skip_loss: return outputs - score_map = batch["query/score_map"] # (B, H, W) - + score_map = batch["query/score_map"] loss = [] - # cross reference model predicts + if self.cfg.model.do_reference_cross: - score_map_cross = outputs["score_map_ref_cross"] # (B, H, W) - l1_diff_map_cross = torch.abs(score_map_cross - score_map) # (B, H, W) + score_map_cross = outputs["score_map_ref_cross"] + l1_diff_map_cross = torch.abs(score_map_cross - score_map) if self.cfg.model.loss.fn == "l1": loss_cross = l1_diff_map_cross.mean() else: @@ -294,203 +72,25 @@ def _core_step(self, batch, batch_idx, skip_loss=False): return outputs def training_step(self, batch, batch_idx): - outputs = self._core_step(batch, batch_idx) - return outputs + return self._core_step(batch, batch_idx) def validation_step(self, batch, batch_idx): - outputs = self._core_step(batch, batch_idx) - return outputs + return self._core_step(batch, batch_idx) def test_step(self, batch, batch_idx): - outputs = self._core_step(batch, batch_idx) - return outputs + return self._core_step(batch, batch_idx) def predict_step(self, batch, batch_idx): - outputs = self._core_step(batch, batch_idx, skip_loss=True) - return outputs + return self._core_step(batch, batch_idx, skip_loss=True) @rank_zero_only def on_train_batch_end(self, outputs, batch, batch_idx): - self.train_cache["loss"]["final"].update(outputs["loss"]) - - if self.cfg.model.do_reference_cross: - self.train_cache["loss"]["ref_cross"].update(outputs["loss_cross"]) - self.train_cache["correlation"]["ref_cross"].update( - outputs["score_map_ref_cross"], batch["query/score_map"] - ) - self.train_cache["map"]["score"]["ref_cross"].update(outputs["score_map_ref_cross"]) - self.train_cache["map"]["l1_diff"]["ref_cross"].update(outputs["l1_diff_map_ref_cross"]) - - # logger vis batch - if self.global_step % self.cfg.logger.vis_imgs_every_n_train_steps == 0: - fig = self.visualiser.vis(batch, outputs) - self.logger.experiment.log({"train_batch": fig}) - - # logger vis X batches statics - if self.global_step % self.cfg.logger.vis_scalar_every_n_train_steps == 0: - # log loss - tmp_loss = self.train_cache["loss"]["final"].compute() - self.log("train/loss", tmp_loss, prog_bar=True) - - if self.cfg.model.do_reference_cross: - tmp_loss_cross = self.train_cache["loss"]["ref_cross"].compute() - self.log("train/loss_cross", tmp_loss_cross) - - # log psnr - if self.cfg.model.do_reference_cross: - self.log("train/psnr_cross", self.to_psnr_fn(tmp_loss_cross)) - - # log correlation - if self.cfg.model.do_reference_cross: - self.log( - "train/correlation_cross", - self.train_cache["correlation"]["ref_cross"].compute(), - ) - - # logger vis X batches histogram - if self.global_step % self.cfg.logger.vis_histogram_every_n_train_steps == 0: - if self.cfg.model.do_reference_cross: - self.logger.experiment.log( - { - "train/score_histogram_cross": wandb.Histogram( - np_histogram=self.train_cache["map"]["score"]["ref_cross"].compute() - ), - "train/l1_diff_histogram_cross": wandb.Histogram( - np_histogram=self.train_cache["map"]["l1_diff"]["ref_cross"].compute() - ), - } - ) + self.log("train/loss", outputs["loss"], prog_bar=True) def on_validation_batch_end(self, outputs, batch, batch_idx): - self.validation_cache["loss"]["final"].update(outputs["loss"]) - - if self.cfg.model.do_reference_cross: - self.validation_cache["loss"]["ref_cross"].update(outputs["loss_cross"]) - self.validation_cache["correlation"]["ref_cross"].update( - outputs["score_map_ref_cross"], batch["query/score_map"] - ) - - self.frame_score_summariser.update(batch_input=batch, batch_output=outputs) - - if batch_idx < self.cfg.logger.cache_size.validation.n_fig: - fig = self.visualiser.vis(batch, outputs) - self.validation_cache["fig"]["batch"].update(fig) - - def on_test_batch_end(self, outputs, batch, batch_idx): - results = {"test/loss": outputs["loss"]} - - if self.cfg.model.do_reference_cross: - corr = correlation(outputs["score_map_ref_cross"], batch["query/score_map"]) - psnr = self.to_psnr_fn(outputs["loss_cross"]) - results["test/loss_cross"] = outputs["loss_cross"] - results["test/corr_cross"] = corr - results["test/psnr_cross"] = psnr - - self.log_dict( - results, - on_step=self.cfg.logger.test.on_step, - sync_dist=self.cfg.logger.test.sync_dist, - ) - - self.frame_score_summariser.update(batch_input=batch, batch_output=outputs) - - # write image to vis - if ( - self.cfg.logger.test.write.config.vis_img_every_n_steps > 0 - and batch_idx % self.cfg.logger.test.write.config.vis_img_every_n_steps == 0 - ): - fig = self.visualiser.vis(batch, outputs) - fig.image.save( - Path( - self.cfg.logger.test.out_dir, - "vis", - f"r{self.local_rank}_B{str(batch_idx).zfill(4)}_b{0}.png", - ) - ) - - if self.cfg.logger.test.write.flag.batch: - self.batch_writer.write_out( - batch_input=batch, - batch_output=outputs, - local_rank=self.local_rank, - batch_idx=batch_idx, - ) - - def on_predict_batch_end(self, outputs, batch, batch_idx): - self.frame_score_summariser.update(batch_input=batch, batch_output=outputs) - - # write image to vis - if ( - self.cfg.logger.predict.write.config.vis_img_every_n_steps > 0 - and batch_idx % self.cfg.logger.predict.write.config.vis_img_every_n_steps == 0 - ): - fig = self.visualiser.vis(batch, outputs) - fig.image.save( - Path( - self.cfg.logger.predict.out_dir, - "vis", - f"r{self.local_rank}_B{str(batch_idx).zfill(4)}_b{0}.png", - ) - ) - - if self.cfg.logger.predict.write.flag.batch: - self.batch_writer.write_out( - batch_input=batch, - batch_output=outputs, - local_rank=self.local_rank, - batch_idx=batch_idx, - ) - - @rank_zero_only - def on_train_epoch_end(self): - self._reset_logging_cache_train() - - def on_validation_epoch_end(self): - sync_dist = True - self.log( - "validation/loss", - self.validation_cache["loss"]["final"].compute(), - prog_bar=True, - sync_dist=sync_dist, - ) - self.logger.experiment.log( - {"validation_batch": self.validation_cache["fig"]["batch"].compute()}, - ) - - if self.cfg.model.do_reference_cross: - self.log( - "validation/loss_cross", - self.validation_cache["loss"]["ref_cross"].compute(), - sync_dist=sync_dist, - ) - self.log( - "validation/correlation_cross", - self.validation_cache["correlation"]["ref_cross"].compute(), - sync_dist=sync_dist, - ) - self.log( - "validation/psnr_cross", - self.to_psnr_fn(self.validation_cache["loss"]["ref_cross"].compute()), - sync_dist=sync_dist, - ) - - self._reset_logging_cache_validation() - self.frame_score_summariser.reset() - - def on_test_epoch_end(self): - self.frame_score_summariser.summarise() - - def on_predict_epoch_end(self): - self.frame_score_summariser.summarise() + self.log("validation/loss", outputs["loss"], prog_bar=True) def configure_optimizers(self): - # how to use configure_optimizers: - # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.core.LightningModule.html#lightning.pytorch.core.LightningModule.configure_optimizers - - # we freeze backbone and we only pass parameters that requires grad to optimizer: - # https://discuss.pytorch.org/t/how-to-train-a-part-of-a-network/8923 - # https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html#convnet-as-fixed-feature-extractor - # https://discuss.pytorch.org/t/for-freezing-certain-layers-why-do-i-need-a-two-step-process/175289/2 parameters = [p for p in self.model.parameters() if p.requires_grad] optimizer = torch.optim.AdamW( params=parameters, @@ -501,8 +101,7 @@ def configure_optimizers(self): step_size=self.cfg.trainer.lr_scheduler.step_size, gamma=self.cfg.trainer.lr_scheduler.gamma, ) - - results = { + return { "optimizer": optimizer, "lr_scheduler": { "scheduler": lr_scheduler, @@ -510,4 +109,3 @@ def configure_optimizers(self): "frequency": 1, }, } - return results diff --git a/task/predict.py b/task/predict.py index e565b78..e51552b 100644 --- a/task/predict.py +++ b/task/predict.py @@ -1,9 +1,6 @@ from datetime import datetime -import sys from pathlib import Path -sys.path.append(str(Path(__file__).parents[1])) - import torch from torch.utils.data import DataLoader from torchvision.transforms import v2 as T @@ -12,10 +9,10 @@ import hydra from omegaconf import DictConfig, open_dict -from core import CrossScoreLightningModule -from dataloading.dataset.simple_reference import SimpleReference -from dataloading.transformation.crop import CropperFactory -from utils.io.images import ImageNetMeanStd +from task.core import CrossScoreLightningModule +from crossscore.dataloading.dataset.simple_reference import SimpleReference +from crossscore.dataloading.transformation.crop import CropperFactory +from crossscore.utils.io.images import ImageNetMeanStd @hydra.main(version_base="1.3", config_path="../config", config_name="default_predict") diff --git a/task/test.py b/task/test.py index ac0f6e3..f2c1f6d 100644 --- a/task/test.py +++ b/task/test.py @@ -1,8 +1,5 @@ -import sys from pathlib import Path -sys.path.append(str(Path(__file__).parents[1])) - import torch from torch.utils.data import DataLoader from torchvision.transforms import v2 as T @@ -12,10 +9,10 @@ import hydra from omegaconf import DictConfig, open_dict -from core import CrossScoreLightningModule -from dataloading.data_manager import get_dataset -from dataloading.transformation.crop import CropperFactory -from utils.io.images import ImageNetMeanStd +from task.core import CrossScoreLightningModule +from crossscore.dataloading.data_manager import get_dataset +from crossscore.dataloading.transformation.crop import CropperFactory +from crossscore.utils.io.images import ImageNetMeanStd @hydra.main(version_base="1.3", config_path="../config", config_name="default_test") diff --git a/task/train.py b/task/train.py index acc2650..bc84188 100644 --- a/task/train.py +++ b/task/train.py @@ -1,9 +1,6 @@ -import sys from pathlib import Path from datetime import timedelta -sys.path.append(str(Path(__file__).parents[1])) - import torch from torch.utils.data import DataLoader from torchvision.transforms import v2 as T @@ -16,11 +13,11 @@ from hydra.core.hydra_config import HydraConfig from omegaconf import DictConfig -from core import CrossScoreLightningModule -from dataloading.data_manager import get_dataset -from dataloading.transformation.crop import CropperFactory -from utils.io.images import ImageNetMeanStd -from utils.check_config import ConfigChecker +from task.core import CrossScoreLightningModule +from crossscore.dataloading.data_manager import get_dataset +from crossscore.dataloading.transformation.crop import CropperFactory +from crossscore.utils.io.images import ImageNetMeanStd +from crossscore.utils.check_config import ConfigChecker @hydra.main(version_base="1.3", config_path="../config", config_name="default") diff --git a/utils/data_processing/split_gaussian_processed.py b/utils/data_processing/split_gaussian_processed.py deleted file mode 100644 index 20734f9..0000000 --- a/utils/data_processing/split_gaussian_processed.py +++ /dev/null @@ -1,134 +0,0 @@ -import argparse -import os -import json -from pathlib import Path -from pprint import pprint -import numpy as np - - -def split_list_by_ratio(list_input, ratio_dict): - # Check if the sum of ratios is close to 1 - if not 0.999 < sum(ratio_dict.values()) < 1.001: - raise ValueError("The sum of the ratios must be close to 1") - - total_length = len(list_input) - lengths = {k: int(v * total_length) for k, v in ratio_dict.items()} - - # Adjust the last split to include any rounding difference - last_split_name = list(ratio_dict.keys())[-1] - lengths[last_split_name] = total_length - sum(lengths.values()) + lengths[last_split_name] - - # Split the list - split_lists = {} - start = 0 - for split_name, length in lengths.items(): - split_lists[split_name] = list_input[start : start + length] - start += length - - split_lists = {k: v.tolist() for k, v in split_lists.items()} - return split_lists - - -if __name__ == "__main__": - np.random.seed(1234) - parser = argparse.ArgumentParser() - parser.add_argument( - "--data_path", - type=str, - default="~/projects/mview/storage/scratch_dataset/gaussian-splatting-processed/RealEstate200/res_540", - ) - parser.add_argument("--min_seq_len", type=int, default=2, help="keep seq with len(imgs) >= 2") - parser.add_argument("--min_psnr", type=float, default=10.0, help="keep seq with psnr >= 10.0") - parser.add_argument( - "--split_ratio", - nargs="+", - type=float, - default=[0.8, 0.1, 0.1], - help="train/val/test split ratio", - ) - args = parser.parse_args() - - data_path = Path(args.data_path).expanduser() - log_files = sorted([f for f in os.listdir(data_path) if f.endswith(".log")]) - - # get low psnr scenes from log files - scene_all = [] - scene_low_psnr = {} - for log_f in log_files: - with open(data_path / log_f, "r") as f: - lines = f.readlines() - for line in lines: - # assume scene name printed a few lines before PSNR - if "Output folder" in line: - scene_name = line.split("Output folder: ")[1].split("/")[-1] - scene_name = scene_name.removesuffix("\n") - elif "[ITER 7000] Evaluating train" in line: - psnr = line.split("PSNR ")[1] - psnr = psnr.removesuffix("\n") - psnr = float(psnr) - - scene_all.append(scene_name) - if psnr < args.min_psnr: - scene_low_psnr[scene_name] = psnr - else: - pass - - # get low seq length scenes from data folders - gaussian_splits = ["train", "test"] - scene_low_length = {} - for scene_name in scene_all: - for gs_split in gaussian_splits: - tmp_dir = data_path / scene_name / gs_split / "ours_1000" / "gt" - num_img = len(os.listdir(tmp_dir)) - if num_img < args.min_seq_len: - scene_low_length[scene_name] = num_img - - num_scene_total_after_gaussian = len(scene_all) - num_scene_low_psnr = len(scene_low_psnr) - num_scene_low_length = len(scene_low_length) - - # filter out low psnr scenes - scene_all = [s for s in scene_all if s not in scene_low_psnr.keys()] - num_scene_total_filtered_low_psnr = len(scene_all) - - # filter out low seq length scenes - scene_all = [s for s in scene_all if s not in scene_low_length.keys()] - num_scene_total_filtered_low_length = len(scene_all) - - # split train/val/test - scene_all = np.random.permutation(scene_all) - num_scene_after_all_filtering = len(scene_all) - ratio = { - "train": args.split_ratio[0], - "val": args.split_ratio[1], - "test": args.split_ratio[2], - } - scene_split_info = split_list_by_ratio(scene_all, ratio) - num_scene_train = len(scene_split_info["train"]) - num_scene_val = len(scene_split_info["val"]) - num_scene_test = len(scene_split_info["test"]) - num_scene_after_split = sum([len(v) for v in scene_split_info.values()]) - assert num_scene_after_split == num_scene_after_all_filtering - - # save to json - stats = { - "min_psnr": args.min_psnr, - "min_seq_len": args.min_seq_len, - "split_ratio": args.split_ratio, - "num_scene_total_after_gaussian": num_scene_total_after_gaussian, - "num_scene_low_psnr": num_scene_low_psnr, - "num_scene_low_length": num_scene_low_length, - "num_scene_total_filtered_low_psnr": num_scene_total_filtered_low_psnr, - "num_scene_total_filtered_low_length": num_scene_total_filtered_low_length, - "num_scene_after_all_filtering": num_scene_after_all_filtering, - "num_scene_train": num_scene_train, - "num_scene_val": num_scene_val, - "num_scene_test": num_scene_test, - "num_scene_after_split": num_scene_after_split, - } - - pprint(stats, sort_dicts=False) - out_dict = {"stats": stats, **scene_split_info} - - with open(data_path / "split.json", "w") as f: - json.dump(out_dict, f, indent=2) diff --git a/utils/evaluation/metric.py b/utils/evaluation/metric.py deleted file mode 100644 index 82e47ed..0000000 --- a/utils/evaluation/metric.py +++ /dev/null @@ -1,30 +0,0 @@ -import numpy as np -import torch - - -def psnr(a, b, return_map=False): - mse_map = torch.nn.functional.mse_loss(a, b, reduction="none") - psnr_map = -10 * torch.log10(mse_map) - if return_map: - return psnr_map - else: - return psnr_map.mean() - - -def mse2psnr(a): - return -10 * torch.log10(a) - - -def abs2psnr(a): - return -10 * torch.log10(a.pow(2)) - - -def psnr2mse(a): - return 10 ** (-a / 10) - - -def correlation(a, b): - x = torch.stack([a.flatten(), b.flatten()], dim=0) # (2, N) - corr = x.corrcoef() # (2, 2) - corr = corr[0, 1] # only this one is meaningful - return corr diff --git a/utils/evaluation/metric_logger.py b/utils/evaluation/metric_logger.py deleted file mode 100644 index e1cf202..0000000 --- a/utils/evaluation/metric_logger.py +++ /dev/null @@ -1,55 +0,0 @@ -from abc import ABC, abstractmethod -import torch -import numpy as np -from .metric import correlation - - -class MetricLogger(ABC): - def __init__(self, max_length): - self.storage = [] - self.max_length = max_length - - @torch.no_grad() - def update(self, x): - if self.max_length is not None and len(self) >= self.max_length: - self.reset() - self.storage.append(x) - - def reset(self): - self.storage.clear() - - def __len__(self): - return len(self.storage) - - @abstractmethod - def compute(self): - raise NotImplementedError - - -class MetricLoggerScalar(MetricLogger): - @torch.no_grad() - def compute(self, aggregation_fn=torch.mean): - tmp = torch.stack(self.storage) - result = aggregation_fn(tmp) - return result - - -class MetricLoggerHistogram(MetricLogger): - @torch.no_grad() - def compute(self, bins=10, range=None): - tmp = torch.cat(self.storage).cpu().numpy() - result = np.histogram(tmp, bins=bins, range=range) - return result - - -class MetricLoggerCorrelation(MetricLoggerScalar): - @torch.no_grad() - def update(self, a, b): - corr = correlation(a, b) - super().update(corr) - - -class MetricLoggerImg(MetricLogger): - @torch.no_grad() - def compute(self): - return self.storage diff --git a/utils/evaluation/summarise_score_gt.py b/utils/evaluation/summarise_score_gt.py deleted file mode 100644 index f9e8f75..0000000 --- a/utils/evaluation/summarise_score_gt.py +++ /dev/null @@ -1,43 +0,0 @@ -from argparse import ArgumentParser -import sys -from pathlib import Path - -sys.path.append(str(Path(__file__).parents[2])) -from utils.io.score_summariser import SummaryWriterGroundTruth - - -def parse_args(): - parser = ArgumentParser(description="Summarise the ground truth results.") - parser.add_argument( - "--dir_in", - type=str, - default="datadir/processed_training_ready/gaussian/map-free-reloc/res_540", - help="The ground truth data dir that contains scene dirs.", - ) - parser.add_argument( - "--dir_out", - type=str, - default="~/projects/mview/storage/scratch_dataset/score_summary", - help="The output directory to save the summarised results.", - ) - parser.add_argument( - "--fast_debug", - type=int, - default=-1, - help="num batch to load for debug. Set to -1 to disable", - ) - parser.add_argument("-n", "--num_workers", type=int, default=16) - parser.add_argument("-f", "--force", type=eval, default=False, choices=[True, False]) - return parser.parse_args() - - -if __name__ == "__main__": - args = parse_args() - summariser = SummaryWriterGroundTruth( - dir_in=args.dir_in, - dir_out=args.dir_out, - num_workers=args.num_workers, - fast_debug=args.fast_debug, - force=args.force, - ) - summariser.write_csv() diff --git a/utils/io/batch_writer.py b/utils/io/batch_writer.py deleted file mode 100644 index d6b4ce9..0000000 --- a/utils/io/batch_writer.py +++ /dev/null @@ -1,270 +0,0 @@ -import json -from pathlib import Path -from PIL import Image -import numpy as np -from utils.io.images import metric_map_write, u8 -from utils.misc.image import gray2rgb, attn2rgb, de_norm_img - - -def get_vrange(predict_metric_type, predict_metric_min, predict_metric_max): - # Using uint16 when write in gray, normalise to the intrinsic range - # based on score type, regardless of the model prediction range - if predict_metric_type == "ssim": - vrange_intrinsic = [-1, 1] - elif predict_metric_type in ["mse", "mae"]: - vrange_intrinsic = [0, 1] - else: - raise ValueError(f"metric_type {predict_metric_type} not supported") - - # RGB for visualization only, normalise to the model prediction range - vrange_vis = [predict_metric_min, predict_metric_max] - return vrange_intrinsic, vrange_vis - - -class BatchWriter: - """Write batch outputs to disk.""" - - def __init__(self, cfg, phase: str, img_mean_std): - if phase not in ["test", "predict"]: - raise ValueError( - f"Phase {phase} not supported. Has to be a Lightening phase test/predict." - ) - self.cfg = cfg - self.phase = phase - self.img_mean_std = img_mean_std - - self.out_dir = Path(self.cfg.logger[phase].out_dir) - self.write_config = self.cfg.logger[phase].write.config - self.write_flag = self.cfg.logger[phase].write.flag - - # overwrite the flag for attn_weights if the model does not have it - self.write_flag.attn_weights = ( - self.write_flag.attn_weights and self.cfg.model.need_attn_weights - ) - - self.predict_metric_type = self.cfg.model.predict.metric.type - self.predict_metric_min = self.cfg.model.predict.metric.min - self.predict_metric_max = self.cfg.model.predict.metric.max - - self.vrange_intrinsic, self.vrange_vis = get_vrange( - self.predict_metric_type, self.predict_metric_min, self.predict_metric_max - ) - - # prepare out_dirs, create them if required - self.out_dir_dict = {"batch": Path(self.out_dir, "batch")} - if self.write_flag["batch"]: - for k in self.write_flag.keys(): - if k not in ["batch", "score_map_prediction"]: - if self.write_flag[k]: - self.out_dir_dict[k] = Path(self.out_dir_dict["batch"], k) - self.out_dir_dict[k].mkdir(parents=True, exist_ok=True) - - def write_out(self, batch_input, batch_output, local_rank, batch_idx): - if self.write_flag["score_map_prediction"]: - self._write_score_map_prediction( - self.out_dir_dict["batch"], - batch_input, - batch_output, - local_rank, - batch_idx, - ) - - if self.write_flag["score_map_gt"]: - self._write_score_map_gt( - self.out_dir_dict["score_map_gt"], - batch_input, - local_rank, - batch_idx, - ) - - if self.write_flag["item_path_json"]: - self._write_item_path_json( - self.out_dir_dict["item_path_json"], - batch_input, - local_rank, - batch_idx, - ) - - if self.write_flag["image_query"]: - self._write_query_image( - self.out_dir_dict["image_query"], - batch_input, - local_rank, - batch_idx, - ) - - if self.write_flag["image_reference"]: - self._write_reference_image( - self.out_dir_dict["image_reference"], - batch_input, - local_rank, - batch_idx, - ) - - if self.write_flag["attn_weights"]: - self._write_attn_weights( - self.out_dir_dict["attn_weights"], - batch_input, - batch_output, - local_rank, - batch_idx, - check_patch_mode="centre", - ) - - def _write_score_map_prediction( - self, out_dir, batch_input, batch_output, local_rank, batch_idx - ): - query_img_paths = [ - str(Path(*Path(p).parts[-5:])).replace("/", "_").replace(".png", "") - for p in batch_input["item_paths"]["query/img"] - ] - score_map_type_list = [k for k in batch_output.keys() if k.startswith("score_map")] - for score_map_type in score_map_type_list: - tmp_out_dir = Path(out_dir, score_map_type) - tmp_out_dir.mkdir(parents=True, exist_ok=True) - - if len(query_img_paths) != len(batch_output[score_map_type]): - raise ValueError("num of query images and score maps are not equal") - - for b, (query_img_p, score_map) in enumerate( - zip(query_img_paths, batch_output[score_map_type]) - ): - tmp_out_name = f"r{local_rank}_B{batch_idx:04}_b{b:03}_{query_img_p}.png" - self._write_a_score_map_with_colour_mode( - out_path=tmp_out_dir / tmp_out_name, score_map=score_map.cpu().numpy() - ) - - def _write_score_map_gt(self, out_dir, batch_input, local_rank, batch_idx): - query_img_paths = [ - str(Path(*Path(p).parts[-5:])).replace("/", "_").replace(".png", "") - for p in batch_input["item_paths"]["query/img"] - ] - - if len(query_img_paths) != len(batch_input["query/score_map"]): - raise ValueError("num of query images and score maps are not equal") - - for b, (query_img_p, score_map) in enumerate( - zip(query_img_paths, batch_input["query/score_map"]) - ): - tmp_out_name = f"r{local_rank}_B{batch_idx:04}_b{b:03}_{query_img_p}.png" - self._write_a_score_map_with_colour_mode( - out_path=out_dir / tmp_out_name, score_map=score_map.cpu().numpy() - ) - - def _write_item_path_json(self, out_dir, batch_input, local_rank, batch_idx): - out_path = out_dir / f"r{local_rank}_B{str(batch_idx).zfill(4)}.json" - - # get a deep copy to avoid infering other writing functions - item_paths = batch_input["item_paths"].copy() - for ref_type in ["reference/cross/imgs"]: - if len(item_paths[ref_type]) > 0: - # transpose ref paths to (N_ref, B) - item_paths[ref_type] = np.array(item_paths[ref_type]).T.tolist() - - with open(out_path, "w") as f: - json.dump(item_paths, f, indent=2) - - def _write_query_image(self, out_dir, batch_input, local_rank, batch_idx): - query_img_paths = [ - str(Path(*Path(p).parts[-5:])).replace("/", "_").replace(".png", "") - for p in batch_input["item_paths"]["query/img"] - ] - - for b, (query_img_p, image_query) in enumerate( - zip(query_img_paths, batch_input["query/img"]) - ): - tmp_out_name = f"r{local_rank}_B{batch_idx:04}_b{b:03}_{query_img_p}.png" - tmp_out_path = Path(out_dir, tmp_out_name) - image_query = de_norm_img(image_query.permute(1, 2, 0), self.img_mean_std) - image_query = u8(image_query.cpu().numpy()) - Image.fromarray(image_query).save(tmp_out_path) - - def _write_reference_image(self, out_dir, batch_input, local_rank, batch_idx): - query_img_paths = [ - str(Path(*Path(p).parts[-5:])).replace("/", "_").replace(".png", "") - for p in batch_input["item_paths"]["query/img"] - ] - for ref_type in ["reference/cross/imgs"]: - if len(batch_input["item_paths"][ref_type]) > 0: - ref_img_paths = np.array(batch_input["item_paths"][ref_type]).T # (B, N_ref) - - # create a subfolder for each query image to store its ref images - for b, query_img_p in enumerate(query_img_paths): - tmp_out_dir = Path( - out_dir, - f"r{local_rank}_B{batch_idx:04}_b{b:03}_{query_img_p}", - ref_type.split("/")[1], - ) - tmp_out_dir.mkdir(parents=True, exist_ok=True) - tmp_ref_img_paths = [ - str(Path(*Path(p).parts[-5:])).replace("/", "_").replace(".png", "") - for p in ref_img_paths[b] # (N_ref, ) - ] - ref_imgs = batch_input[ref_type][b] # (N_ref, C, H, W) - for ref_idx, (ref_img_p, ref_img) in enumerate( - zip(tmp_ref_img_paths, ref_imgs) - ): - tmp_out_name = f"ref{ref_idx:02}_{ref_img_p}.png" - tmp_out_path = Path(tmp_out_dir, tmp_out_name) - ref_img = de_norm_img(ref_img.permute(1, 2, 0), self.img_mean_std) - ref_img = u8(ref_img.cpu().numpy()) - Image.fromarray(ref_img).save(tmp_out_path) - - def _write_attn_weights( - self, out_dir, batch_input, batch_output, local_rank, batch_idx, check_patch_mode - ): - query_img_paths = [ - str(Path(*Path(p).parts[-5:])).replace("/", "_").replace(".png", "") - for p in batch_input["item_paths"]["query/img"] - ] - for ref_type in ["reference/cross/imgs"]: - if len(batch_input["item_paths"][ref_type]) > 0: - ref_img_paths = np.array(batch_input["item_paths"][ref_type]).T # (B, N_ref) - ref_type_short = ref_type.split("/")[1] - - # create a subfolder for each query image to store its ref images - for b, query_img_p in enumerate(query_img_paths): - tmp_out_dir = Path( - out_dir, - f"r{local_rank}_B{batch_idx:04}_b{b:03}_{query_img_p}", - ref_type_short, - ) - tmp_out_dir.mkdir(parents=True, exist_ok=True) - tmp_ref_img_paths = [ - str(Path(*Path(p).parts[-5:])).replace("/", "_").replace(".png", "") - for p in ref_img_paths[b] # (N_ref, ) - ] - - # get attn maps of the centre patch in the query image - # (B, H, W, N_ref, H, W) - attn_weights_map = batch_output[f"attn_weights_map_ref_{ref_type_short}"] - tmp_h, tmp_w = attn_weights_map.shape[1:3] - - if check_patch_mode == "centre": - query_patch = (tmp_h // 2, tmp_w // 2) - elif check_patch_mode == "random": - query_patch = ( - np.random.randint(0, tmp_h), - np.random.randint(0, tmp_w), - ) - else: - raise ValueError(f"Unknown check_patch_mode: {check_patch_mode}") - attn_weights_map = attn_weights_map[b][query_patch] # (N_ref, H, W) - - # write attn maps - for ref_idx, (ref_img_p, attn_m) in enumerate( - zip(tmp_ref_img_paths, attn_weights_map) - ): - tmp_out_name = f"ref{ref_idx:02}_{ref_img_p}.png" - tmp_out_path = Path(tmp_out_dir, tmp_out_name) - attn_m = attn2rgb(attn_m.cpu().numpy()) # (H, W, 3) - Image.fromarray(attn_m).save(tmp_out_path) - - def _write_a_score_map_with_colour_mode(self, out_path, score_map): - if self.write_config.score_map_colour_mode == "gray": - metric_map_write(out_path, score_map, self.vrange_intrinsic) - elif self.write_config.score_map_colour_mode == "rgb": - rgb = gray2rgb(score_map, self.vrange_vis) - Image.fromarray(rgb).save(out_path) - else: - raise ValueError(f"colour_mode {self.write_config.score_map_colour_mode} not supported") diff --git a/utils/io/score_summariser.py b/utils/io/score_summariser.py deleted file mode 100644 index fc34972..0000000 --- a/utils/io/score_summariser.py +++ /dev/null @@ -1,315 +0,0 @@ -from pathlib import Path -import os -from glob import glob - -import torch -from torch.utils.data import Dataset, DataLoader -import numpy as np -import pandas as pd -from pandas import DataFrame -from tqdm import tqdm - -from utils.io.images import metric_map_read -from utils.evaluation.metric import mse2psnr - - -class ScoreReader(Dataset): - def __init__(self, score_map_dir_list): - # get all paths to read - read_score_types = ["ssim", "mae"] - self.read_paths_all = {k: [] for k in read_score_types} - for score_read_type in read_score_types: - for score_map_dir in score_map_dir_list: - tmp_dir = os.path.join(score_map_dir, score_read_type) - tmp_paths = [os.path.join(tmp_dir, n) for n in sorted(os.listdir(tmp_dir))] - self.read_paths_all[score_read_type].extend(tmp_paths) - - # (N_frames, 2) - self.read_paths_all = np.stack([self.read_paths_all[k] for k in read_score_types], axis=1) - - def __len__(self): - return len(self.read_paths_all) - - def __getitem__(self, idx): - path_ssim, path_mae = self.read_paths_all[idx] - ssim_map = metric_map_read(path_ssim, vrange=[-1, 1]) - mae_map = metric_map_read(path_mae, vrange=[0, 1]) - mse_map = np.square(mae_map) - - score_ssim_n11 = ssim_map.mean() - score_ssim_01 = ssim_map.clip(0, 1).mean() - score_mae = mae_map.mean() - score_mse = mse_map.mean() - score_psnr = mse2psnr(torch.tensor([score_mse])).numpy() - - results = { - "ssim_-1_1": score_ssim_n11, - "ssim_0_1": score_ssim_01, - "mae": score_mae, - "mse": score_mse, - "psnr": score_psnr, - "path_ssim": path_ssim, - } - return results - - -class SummaryWriterGroundTruth: - """ - Load the ground truth results from disk and summarise the results. - """ - - def __init__(self, dir_in, dir_out, num_workers, fast_debug, force): - self.dir_in = Path(dir_in).expanduser() - self.dir_out = Path(dir_out).expanduser() - self.num_workers = num_workers - self.fast_debug = fast_debug - self.force = force - - self.dataset_type = self.dir_in.parent.name - self.rendering_method = self.dir_in.parents[1].name - self.csv_dir = self.dir_out / self.dataset_type - self.csv_path = self.csv_dir / f"{self.rendering_method}.csv" - self.csv_dir.mkdir(parents=True, exist_ok=True) - self.rows = [] - self.columns = [ - "scene_name", - "rendered_dir", - "image_name", - "gt_ssim_-1_1", - "gt_ssim_0_1", - "gt_mae", - "gt_mse", - "gt_psnr", - ] - - def write_csv(self): - write = self._check_write(self.csv_path, self.force) - if write: - rows = self._load_per_frame_score() - df = DataFrame(data=rows, columns=self.columns) - df.to_csv(self.csv_path, index=False, float_format="%.4f") - - def _check_write(self, csv_path, force): - if csv_path.exists(): - if force: - csv_path.unlink() - print(f"Write to csv {csv_path} (OVERWRITE)") - write = True - else: - print(f"Write to csv {csv_path} (SKIP)") - write = False - else: - print(f"Write to csv {csv_path} (NORMAL)") - write = True - return write - - def _load_per_frame_score(self): - # use glob to find all dir named "metric_map" in the scene dirs - score_map_dir_list = sorted(glob(str(self.dir_in / "**/metric_map"), recursive=True)) - score_reader = ScoreReader(score_map_dir_list) - score_loader = DataLoader( - dataset=score_reader, - batch_size=16, - shuffle=False, - num_workers=self.num_workers, - ) - - # process score maps to csv rows, each row contains the following columns - rows = [] - for i, data in enumerate(tqdm(score_loader, desc=f"Loading gt scores", dynamic_ncols=True)): - for j in range(len(data["path_ssim"])): - path_ssim = data["path_ssim"][j] - scene_name = path_ssim.split("/")[-6] - rendered_dir = os.path.join(*path_ssim.split("/")[:-3]) - image_name = path_ssim.split("/")[-1] - image_name = image_name.replace("frame_", "") - tmp_row = [ - scene_name, - rendered_dir, - image_name, - data["ssim_-1_1"][j].item(), - data["ssim_0_1"][j].item(), - data["mae"][j].item(), - data["mse"][j].item(), - data["psnr"][j].item(), - ] - rows.append(tmp_row) - if self.fast_debug > 0 and i >= self.fast_debug: - break - return rows - - -class SummaryWriterPredictedOnline: - """ - Used in at the fit/test/predict phase with Lightning Module. - """ - - def __init__(self, metric_type, metric_min): - metric_type_str = self._get_metric_type_str(metric_type, metric_min) - self.columns = [ - "scene_name", - "rendered_dir", - "image_name", - f"pred_{metric_type_str}", - ] - self.reset() - - def _get_metric_type_str(self, metric_type, metric_min): - if metric_type == "ssim": - if metric_min == -1: - metric_str = f"{metric_type}_-1_1" - elif metric_min == 0: - metric_str = f"{metric_type}_0_1" - else: - metric_str = f"{metric_type}" - return metric_str - - def reset(self): - self.rows = DataFrame(columns=self.columns) - - def update(self, batch_input, batch_output): - """Store per frame scores to rows for each batch.""" - query_img_paths = batch_input["item_paths"]["query/img"] # (B,) - ref_types = [t for t in batch_output.keys() if t.startswith("score_map")] - - if len(ref_types) != 1: - raise ValueError(f"Expect exactly one ref_type: self/cross, but got {ref_types}.") - - rows_batch = [] - for ref_type in ref_types: - score_maps = batch_output[ref_type] # (B, H, W) - scores = score_maps.mean(dim=[-1, -2]) # (B,) - - scene_names = [p.split("/")[-5] for p in query_img_paths] - - rendered_dirs = [os.path.join(*p.split("/")[:-2]) for p in query_img_paths] - image_names = [p.split("/")[-1] for p in query_img_paths] - image_names = [n.replace("frame_", "") for n in image_names] - - for i in range(len(scene_names)): - rows_batch.append( - [scene_names[i], rendered_dirs[i], image_names[i], scores[i].item()] - ) - - # concat rows to panda dataframe - self.rows = pd.concat([self.rows, DataFrame(rows_batch, columns=self.columns)]) - - def summarise(self): - """Organise rows using dataset type and rendering method and sort them.""" - - # get unique rendering methods - rendering_method_list = self.rows["rendered_dir"].apply(lambda x: x.split("/")[-6]).unique() - - # get unique dataset types - dataset_type_list = self.rows["rendered_dir"].apply(lambda x: x.split("/")[-5]).unique() - - # organise rows using dataset type and rendering method - self.summary = {} - for dataset_type in dataset_type_list: - self.summary[dataset_type] = {} - for rendering_method in rendering_method_list: - # get rows with the same dataset type and rendering method - tmp_rows = self.rows[ - (self.rows["rendered_dir"].str.contains(rendering_method)) - & (self.rows["rendered_dir"].str.contains(dataset_type)) - ] - - # sort rows by scene_name, rendered_dir, image_name - tmp_rows = tmp_rows.sort_values(by=["scene_name", "rendered_dir", "image_name"]) - - self.summary[dataset_type][rendering_method] = tmp_rows - - def __len__(self): - return len(self.rows) - - def __repr__(self): - return self.rows.__repr__() - - -class SummaryWriterPredictedOnlineTestPrediction(SummaryWriterPredictedOnline): - """ - Used in at the end of validation/test/predict phase. - Summarising with online predicted results to avoid reading from disks. - """ - - def __init__(self, metric_type, metric_min, dir_out): - super().__init__(metric_type, metric_min) - self.csv_dir = Path(dir_out).expanduser() / "score_summary" - self.csv_dir.mkdir(parents=True, exist_ok=True) - self.cache_csv_path = self.csv_dir / f"summarise_cache.csv" - - def summarise(self): - super().summarise() - - # write to csv files - for dataset_type, dataset_summary in self.summary.items(): - for rendering_method, rows in dataset_summary.items(): - tmp_csv_dir = self.csv_dir / dataset_type - tmp_csv_dir.mkdir(parents=True, exist_ok=True) - tmp_csv_path = tmp_csv_dir / f"{rendering_method}.csv" - rows.to_csv(tmp_csv_path, index=False, float_format="%.4f") - - -class SummaryReader: - @staticmethod - def read_summary(summary_dir, dataset, method_list, scene_list, split_list, iter_list): - summary_dir = Path(summary_dir).expanduser() - summary_dir = summary_dir / dataset - - methods_available = [f.stem for f in summary_dir.iterdir() if f.is_file()] - methods_to_read = [] - if method_list != [""]: - for m in method_list: - if m in methods_available: - methods_to_read.append(m) - else: - raise ValueError(f"{m} is not available in {summary_dir}") - else: - methods_to_read = methods_available - - summary_files = [summary_dir / f"{m}.csv" for m in methods_to_read] - - # read csv files, create a new colume as 0th column for method_name - summary = pd.concat( - [pd.read_csv(f).assign(method_name=m) for f, m in zip(summary_files, methods_to_read)] - ) - - # filter with scene list using summary's scene_name column - if scene_list != [""]: - summary = summary[summary["scene_name"].isin(scene_list)] - - # filter split using rendered_dir column - if split_list != [""]: - new_s = [] - for split in split_list: - tmp_s = summary[summary["rendered_dir"].str.split("/").str[-2] == split] - new_s.append(tmp_s) - summary = pd.concat(new_s) - - # filter iter using rendered_dir column, the last part of the path should be EXACTLY "ours_{iter}" - if len(iter_list) > 0: - new_s = [] - for i in iter_list: - tmp_s = summary[summary["rendered_dir"].str.endswith(f"ours_{i}")] - new_s.append(tmp_s) - summary = pd.concat(new_s) - - # sort by scene_name, rendered_dir, image_name, method_name - summary = summary.sort_values(["scene_name", "rendered_dir", "image_name", "method_name"]) - - # reset index - summary = summary.reset_index(drop=True) - return summary - - @staticmethod - def check_summary_gt_prediction_rows(summary_gt, summary_prediction): - # these tow dataframes should have the same length - # and the columns [rendered_dir, image_name] should be identical - if len(summary_gt) != len(summary_prediction): - raise ValueError("Summary GT and prediction have different length") - - if not summary_gt["rendered_dir"].equals(summary_prediction["rendered_dir"]): - raise ValueError("Summary GT and prediction have different rendered_dir") - - if not summary_gt["image_name"].equals(summary_prediction["image_name"]): - raise ValueError("Summary GT and prediction have different image_name") diff --git a/utils/plot/batch_visualiser.py b/utils/plot/batch_visualiser.py deleted file mode 100644 index 668e04f..0000000 --- a/utils/plot/batch_visualiser.py +++ /dev/null @@ -1,414 +0,0 @@ -from pathlib import Path -from abc import ABC, abstractmethod -import numpy as np -import torch -import wandb -import matplotlib.pyplot as plt -from matplotlib.patches import Rectangle -from torchvision.utils import make_grid -from PIL import Image -from utils.io.images import u8 -from utils.misc.image import de_norm_img -from utils.check_config import check_reference_type - - -class BatchVisualiserBase(ABC): - """Organise batch data and preview in a single image.""" - - def __init__(self, cfg, img_mean_std, ref_type): - self.cfg = cfg - self.img_mean_std = img_mean_std.cpu().numpy() - self.ref_type = ref_type - self.metric_type = cfg.model.predict.metric.type - self.metric_min = cfg.model.predict.metric.min - self.metric_max = cfg.model.predict.metric.max - - @abstractmethod - def vis(self, **kwargs): - raise NotImplementedError - - -class BatchVisualiserRef(BatchVisualiserBase): - @torch.no_grad() - def vis(self, batch_input, batch_output, vis_id=0): - ref_type = self.ref_type - - fig = plt.figure(figsize=(6, 6.5), layout="constrained") - # plt.style.use("dark_background") - # fig.set_facecolor("gray") - - # arrange subfigure - N_ref_imgs = batch_input[f"reference/{ref_type}/imgs"].shape[1] - h_ratios_subfig = [3, np.ceil(N_ref_imgs / 5)] - fig_top, fig_bottom = fig.subfigures(2, 1, height_ratios=h_ratios_subfig) - - # arrange subplot for top fig - inner = [ - ["score_map/gt", f"score_map/ref_{ref_type}"], - ] - - mosaic_top = [ - ["query", inner], - ] - width_ratio_top = [3, 2] - ax_dict_top = fig_top.subplot_mosaic( - mosaic_top, - empty_sentinel=None, - width_ratios=width_ratio_top, - ) - - if "item_paths" in batch_input.keys(): - # 1. anonymize the path - # 2. split to two rows - # 3. add query path to title - query_path = batch_input["item_paths"]["query/img"][vis_id] - query_path = query_path.replace(str(Path("~").expanduser()), "~/") - query_path = query_path.split("/") - query_path.insert(len(query_path) // 2, "\n") - query_path = Path(*query_path) - fig_top.suptitle(f"{query_path}", fontsize=7) - - # arrange subplot for bottom fig - mosaic_bottom = [[f"ref/{ref_type}/imgs"]] - height_ratios = [1] - ax_dict_bottom = fig_bottom.subplot_mosaic( - mosaic_bottom, - empty_sentinel=None, - height_ratios=height_ratios, - ) - - # getting batch input output data - img_dict = {"query": batch_input["query/img"][vis_id]} # 3HW - img_dict[f"ref/{ref_type}/imgs"] = make_grid( - batch_input[f"reference/{ref_type}/imgs"][vis_id], nrow=5 - ) # 3HW - img_dict["score_map/gt"] = batch_input[f"query/score_map"][vis_id] # HW - img_dict[f"score_map/ref_{ref_type}"] = batch_output[f"score_map_ref_{ref_type}"][ - vis_id - ] # HW - - # plot top fig - for k, v in img_dict.items(): - if k not in ax_dict_top.keys(): - continue - if k.startswith("score_map"): - tmp_img = v.cpu().numpy() - ax_dict_top[k].imshow( - tmp_img, - vmin=self.metric_min, - vmax=self.metric_max, - cmap="turbo", - ) - title_txt = k.replace("score", self.metric_type) - title_txt = title_txt.replace("_map", "").replace("/", " ").replace("_", " ") - title_txt = title_txt + f"\n{tmp_img.mean():.3f}" - if k == f"score_map/ref_{ref_type}": - if f"l1_diff_map_ref_{ref_type}" in batch_output.keys(): - loss = batch_output[f"l1_diff_map_ref_{ref_type}"][vis_id].mean().item() - title_txt = title_txt + f" ∆{loss:.2f}" - ax_dict_top[k].set_title(title_txt, fontsize=8) - elif k.startswith("delta_map"): - tmp_img = v.cpu().numpy() - # ax_dict_top[k].imshow(tmp_img, vmin=-0.5, vmax=0.5, cmap="turbo") - ax_dict_top[k].imshow(tmp_img, vmin=-0.5, vmax=0.5, cmap="bwr") - ax_dict_top[k].set_title(k.replace("/", "\n"), fontsize=8) - else: - tmp_img = v.permute(1, 2, 0).cpu().numpy() # HW3 - tmp_img = de_norm_img(tmp_img, self.img_mean_std) - tmp_img = tmp_img.clip(0, 1) - tmp_img = u8(tmp_img) - ax_dict_top[k].imshow(tmp_img) - ax_dict_top[k].set_title(k) - ax_dict_top[k].set_xticks([]) - ax_dict_top[k].set_yticks([]) - - # add a colorbar, turn off ticks - ax_colorbar = fig_top.colorbar( - plt.cm.ScalarMappable(cmap="turbo"), ax=ax_dict_top["query"], fraction=0.046, pad=0.04 - ) - ax_colorbar.ax.set_yticklabels([]) - - # plot bottom fig - for k, v in img_dict.items(): - if k not in ax_dict_bottom.keys(): - continue - if k.startswith("attn_weights_map"): - # tmp_img = v.permute(1, 2, 0) # HW3 - tmp_img = v[0] # HW - tmp_img = tmp_img.clamp(0, 1) - tmp_img = tmp_img.cpu().numpy() - ax_dict_bottom[k].imshow(tmp_img, vmin=0, cmap="turbo") - ax_dict_bottom[k].set_title(k) - else: - tmp_img = v.permute(1, 2, 0).cpu().numpy() # HW3 - tmp_img = de_norm_img(tmp_img, self.img_mean_std) - tmp_img = tmp_img.clip(0, 1) - tmp_img = u8(tmp_img) - ax_dict_bottom[k].imshow(tmp_img) - ax_dict_bottom[k].set_title(k) - - ax_dict_bottom[k].set_xticks([]) - ax_dict_bottom[k].set_yticks([]) - - # output - fig.canvas.draw() - out_img = Image.frombytes( - "RGBA", - fig.canvas.get_width_height(), - fig.canvas.buffer_rgba(), - ).convert("RGB") - - plt.close() - out_img = wandb.Image(out_img, file_type="jpg") - return out_img - - -class BatchVisualiserRefAttnMap(BatchVisualiserBase): - def __init__(self, cfg, img_mean_std, ref_type, check_patch_mode): - super().__init__(cfg, img_mean_std, ref_type) - self.check_patch_mode = check_patch_mode - - @torch.no_grad() - def vis(self, batch_input, batch_output, vis_id=0): - fig = plt.figure(figsize=(6, 6.5), layout="constrained") - # plt.style.use("dark_background") - fig.set_facecolor("gray") - - has_reference_cross = "reference/cross/imgs" in batch_input.keys() - has_attn_weights_cross = batch_output.get("attn_weights_map_ref_cross", None) is not None - P = patch_size = 14 - - # arrange subfigure - h_ratios_subfig = [2, 1] - - fig_top, fig_bottom = fig.subfigures(2, 1, height_ratios=h_ratios_subfig) - - # arrange subplot for top fig - inner = [] - if has_reference_cross: - inner.append(["score_map/gt1", "score_map/ref_cross", "score_map/diff_cross"]) - inner = np.array(inner).T.tolist() # vertical layout - - mosaic_top = [ - ["query", inner], - ] - width_ratio_top = [4, 1] - - ax_dict_top = fig_top.subplot_mosaic( - mosaic_top, - empty_sentinel=None, - width_ratios=width_ratio_top, - ) - - # arrange subplot for bottom fig - mosaic_bottom = [] - if has_reference_cross: - mosaic_bottom.append(["ref/cross/imgs"]) - if has_attn_weights_cross: - mosaic_bottom.append(["attn_weights_map_ref_cross"]) - - height_ratios = [1] * len(mosaic_bottom) - - ax_dict_bottom = fig_bottom.subplot_mosaic( - mosaic_bottom, - empty_sentinel=None, - height_ratios=height_ratios, - ) - - # getting batch input output data - img_dict = {"query": batch_input["query/img"][vis_id]} # 3HW - if has_reference_cross: - img_dict["ref/cross/imgs"] = make_grid( - batch_input["reference/cross/imgs"][vis_id], nrow=5 - ) # 3HW - img_dict["score_map/gt1"] = batch_input["query/score_map"][vis_id] - img_dict["score_map/ref_cross"] = batch_output["score_map_ref_cross"][vis_id] - if "l1_diff_map_ref_cross" in batch_output.keys(): - img_dict["score_map/diff_cross"] = batch_output["l1_diff_map_ref_cross"][vis_id] - - if has_attn_weights_cross: - attn_weights_map_ref_cross = batch_output["attn_weights_map_ref_cross"] # BHWNHW - tmp_h, tmp_w = attn_weights_map_ref_cross.shape[1:3] - if self.check_patch_mode == "centre": - query_patch_cross = (tmp_h // 2, tmp_w // 2) - elif self.check_patch_mode == "random": - query_patch_cross = (np.random.randint(0, tmp_h), np.random.randint(0, tmp_w)) - else: - raise ValueError(f"Unknown check_patch_mode: {self.check_patch_mode}") - - # NHW - attn_weights_map_ref_cross = attn_weights_map_ref_cross[vis_id][query_patch_cross] - img_dict["attn_weights_map_ref_cross"] = make_grid( - attn_weights_map_ref_cross[:, None], nrow=5 # make grid expects N3HW - ) - - # plot top fig - for k, v in img_dict.items(): - if k not in ax_dict_top.keys(): - continue - if k.startswith("score_map"): - tmp_img = v.cpu().numpy() - ax_dict_top[k].imshow(tmp_img, vmin=0, vmax=1, cmap="turbo") - ax_dict_top[k].set_title(k.replace("/", "\n"), fontsize=8) - else: - tmp_img = v.permute(1, 2, 0).cpu().numpy() # HW3 - tmp_img = de_norm_img(tmp_img, self.img_mean_std) - tmp_img = tmp_img.clip(0, 1) - tmp_img = u8(tmp_img) - ax_dict_top[k].imshow(tmp_img) - ax_dict_top[k].set_title(k) - - if has_attn_weights_cross and k == "query": - # draw a rectangle patch - query_pixel = (query_patch_cross[0] * P, query_patch_cross[1] * P) - rect = Rectangle( - (query_pixel[1], query_pixel[0]), - P, - P, - linewidth=2, - edgecolor="magenta", - facecolor="none", - ) - ax_dict_top[k].add_patch(rect) - ax_dict_top[k].set_xticks([]) - ax_dict_top[k].set_yticks([]) - - # plot bottom fig - for k, v in img_dict.items(): - if k not in ax_dict_bottom.keys(): - continue - if k.startswith("attn_weights_map"): - num_stabler = 1e-8 # to avoid log(0) - tmp_img = v[0] # HW - tmp_img = tmp_img.clamp(0, 1) - tmp_img = tmp_img.cpu().numpy() - # invert softmax (exp'd) attn weights - tmp_img = np.log(tmp_img + num_stabler) - np.log(num_stabler) - ax_dict_bottom[k].imshow(tmp_img, vmax=-np.log(num_stabler), cmap="turbo") - ax_dict_bottom[k].set_title(k) - else: - tmp_img = v.permute(1, 2, 0).cpu().numpy() # HW3 - tmp_img = de_norm_img(tmp_img, self.img_mean_std) - tmp_img = tmp_img.clip(0, 1) - tmp_img = u8(tmp_img) - ax_dict_bottom[k].imshow(tmp_img) - ax_dict_bottom[k].set_title(k) - - ax_dict_bottom[k].set_xticks([]) - ax_dict_bottom[k].set_yticks([]) - - # output - fig.canvas.draw() - out_img = Image.frombytes( - "RGBA", - fig.canvas.get_width_height(), - fig.canvas.buffer_rgba(), - ).convert("RGB") - - plt.close() - out_img = wandb.Image(out_img, file_type="jpg") - return out_img - - -class BatchVisualiserRefFree(BatchVisualiserBase): - @torch.no_grad() - def vis(self, batch_input, batch_output, vis_id=0): - fig = plt.figure(figsize=(6.5, 5), layout="constrained") - # plt.style.use("dark_background") - # fig.set_facecolor("gray") - - ref_type = self.ref_type - inner = [["score_map/gt", f"score_map/ref_{ref_type}", "score_map/diff"]] - inner = np.array(inner).T.tolist() - - mosaic = [ - ["query", inner], - ] - width_ratio = [5, 2] - ax_dict = fig.subplot_mosaic( - mosaic, - empty_sentinel=None, - width_ratios=width_ratio, - ) - - if "item_paths" in batch_input.keys(): - # 1. anonymize the path - # 2. split to two rows - # 3. add query path to title - query_path = batch_input["item_paths"]["query/img"][vis_id] - query_path = query_path.replace(str(Path("~").expanduser()), "~/") - query_path = query_path.split("/") - query_path.insert(len(query_path) // 2, "\n") - query_path = Path(*query_path) - fig.suptitle(f"{query_path}", fontsize=7) - - # getting batch input output data - img_dict = {"query": batch_input["query/img"][vis_id]} # 3HW - img_dict["score_map/gt"] = batch_input[f"query/score_map"][vis_id] # HW - - # plot top fig - for k, v in img_dict.items(): - if k not in ax_dict.keys(): - continue - if k.startswith("score_map"): - tmp_img = v.cpu().numpy() - ax_dict[k].imshow( - tmp_img, - vmin=self.metric_min, - vmax=self.metric_max, - cmap="turbo", - ) - title_txt = k.replace("score", self.metric_type) - title_txt = title_txt.replace("_map", "").replace("/", " ").replace("_", " ") - title_txt = title_txt + f"\n{tmp_img.mean():.3f}" - if k == f"score_map/ref_{ref_type}": - loss = batch_output[f"l1_diff_map_ref_{ref_type}"][vis_id].mean().item() - title_txt = title_txt + f" ∆{loss:.2f}" - ax_dict[k].set_title(title_txt, fontsize=8) - else: - tmp_img = v.permute(1, 2, 0).cpu().numpy() # HW3 - tmp_img = de_norm_img(tmp_img, self.img_mean_std) - tmp_img = tmp_img.clip(0, 1) - tmp_img = u8(tmp_img) - ax_dict[k].imshow(tmp_img) - ax_dict[k].set_title(k) - ax_dict[k].set_xticks([]) - ax_dict[k].set_yticks([]) - - # add a colorbar, turn off ticks - ax_colorbar = fig.colorbar( - plt.cm.ScalarMappable(cmap="turbo"), ax=ax_dict["query"], fraction=0.046, pad=0.04 - ) - ax_colorbar.ax.set_yticklabels([]) - - # output - fig.canvas.draw() - out_img = Image.frombytes( - "RGBA", - fig.canvas.get_width_height(), - fig.canvas.buffer_rgba(), - ).convert("RGB") - - plt.close() - out_img = wandb.Image(out_img, file_type="jpg") - return out_img - - -class BatchVisualiserFactory: - def __init__(self, cfg, img_mean_std): - self.cfg = cfg - self.img_mean_std = img_mean_std - self.ref_type = check_reference_type(self.cfg.model.do_reference_cross) - - if self.ref_type in ["cross"]: - if self.cfg.model.need_attn_weights: - self.visualiser = BatchVisualiserRefAttnMap( - cfg, self.img_mean_std, self.ref_type, check_patch_mode="centre" - ) - else: - self.visualiser = BatchVisualiserRef(cfg, self.img_mean_std, self.ref_type) - else: - raise NotImplementedError - - def __call__(self): - return self.visualiser