diff --git a/ACL_PyTorch/built-in/cv/SAM/README.md b/ACL_PyTorch/built-in/cv/SAM/README.md
index 4ca3afa1f55cc9977408741eddf3a8e66c4a0e2a..7e0577be265a77f3f4f40aed9f7e61e7cfeac751 100644
--- a/ACL_PyTorch/built-in/cv/SAM/README.md
+++ b/ACL_PyTorch/built-in/cv/SAM/README.md
@@ -78,7 +78,7 @@ cd ModelZoo-PyTorch/ACL_PyTorch/built-in/cv/SAM
 git clone https://github.com/facebookresearch/segment-anything.git
 cd segment-anything
 git reset --hard 6fdee8f2727f4506cfbbe553e23b895e27956588
-patch -p2 < ../segment_anything_diff.patch
+git apply ../segment_anything_diff.patch
 pip3 install -e .
 cd ..
 ```
diff --git a/ACL_PyTorch/built-in/cv/SAM/requirements.txt b/ACL_PyTorch/built-in/cv/SAM/requirements.txt
index 6722adb464f8d1cbc54c229ff0e6e974124339b4..b4969b22635980cd34e286a64d4bbf4968e918c1 100644
--- a/ACL_PyTorch/built-in/cv/SAM/requirements.txt
+++ b/ACL_PyTorch/built-in/cv/SAM/requirements.txt
@@ -1,5 +1,5 @@
 torch==2.1.0
-torch_npu==2.1.0.post17.dev20250905
+torch_npu==2.1.0.post10
 torchvision==0.16.0
 torchaudio==2.1.0
 decorator
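Note: the `torch_npu` pin above must match the `torch` wheel it was built against. A minimal sanity check, assuming only that the two pinned packages from `requirements.txt` are installed (the version prefixes mirror the pins; the check itself is illustrative and not part of the upstream README):

```python
# Illustrative environment check for the pins in requirements.txt above.
import torch
import torch_npu  # fails fast if the Ascend plugin is missing

# Version prefixes taken from the pins; adjust if the pins change.
assert torch.__version__.startswith("2.1.0"), torch.__version__
assert torch_npu.__version__.startswith("2.1.0.post10"), torch_npu.__version__
print("torch / torch_npu pair looks consistent")
```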
diff --git a/ACL_PyTorch/built-in/cv/SAM/segment_anything_diff.patch b/ACL_PyTorch/built-in/cv/SAM/segment_anything_diff.patch
index aec413383ecb17bc2ed7eb15def82ad337b44ff7..96284944bc9f131e8c6f4819e25304bcc553f8a2 100644
--- a/ACL_PyTorch/built-in/cv/SAM/segment_anything_diff.patch
+++ b/ACL_PyTorch/built-in/cv/SAM/segment_anything_diff.patch
@@ -1,6 +1,7 @@
-diff -Naru a/segment-anything/scripts/export_onnx_model.py b/segment-anything/scripts/export_onnx_model.py
---- a/segment-anything/scripts/export_onnx_model.py 2023-11-13 16:25:26.000000000 +0800
-+++ b/segment-anything/scripts/export_onnx_model.py 2023-11-18 16:15:20.088025762 +0800
+diff --git a/scripts/export_onnx_model.py b/scripts/export_onnx_model.py
+index 5c6f838..0bfaff2 100644
+--- a/scripts/export_onnx_model.py
++++ b/scripts/export_onnx_model.py
 @@ -6,8 +6,12 @@

  import torch
@@ -14,7 +15,7 @@ diff -Naru a/segment-anything/scripts/export_onnx_model.py b/segment-anything/sc
  import argparse
  import warnings

-@@ -24,11 +28,30 @@
+@@ -24,11 +28,30 @@ parser = argparse.ArgumentParser(
  )

  parser.add_argument(
-@@ -56,11 +79,21 @@
+@@ -56,11 +79,21 @@ parser.add_argument(
  )

  parser.add_argument(
          "Quantization is performed with quantize_dynamic from onnxruntime.quantization.quantize."
      ),
  )
-@@ -97,7 +130,9 @@
+@@ -97,7 +130,9 @@ parser.add_argument(

  def run_export(
      model_type: str,
      checkpoint: str,
 -
 +
 +
 +
      opset: int,
      return_single_mask: bool,
      gelu_approximate: bool = False,
-@@ -107,6 +142,74 @@
+@@ -107,6 +142,74 @@ def run_export(
      print("Loading model...")
      sam = sam_model_registry[model_type](checkpoint=checkpoint)

      onnx_model = SamOnnxModel(
          model=sam,
          return_single_mask=return_single_mask,
-@@ -129,16 +232,17 @@
+@@ -129,16 +232,17 @@ def run_export(
      mask_input_size = [4 * x for x in embed_size]
      dummy_inputs = {
          "image_embeddings": torch.randn(1, embed_dim, *embed_size, dtype=torch.float),
 -
 -
 +
 +
 +
      with warnings.catch_warnings():
          warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
-@@ -164,7 +268,7 @@
+@@ -164,7 +268,7 @@ def run_export(
      providers = ["CPUExecutionProvider"]
      ort_session = onnxruntime.InferenceSession(output, providers=providers)
      _ = ort_session.run(None, ort_inputs)
 -
 +

  def to_numpy(tensor):
-@@ -176,7 +280,9 @@
+@@ -176,7 +280,9 @@ if __name__ == "__main__":
      run_export(
          model_type=args.model_type,
          checkpoint=args.checkpoint,
 -
 +
 +
 +
      opset=args.opset,
      return_single_mask=args.return_single_mask,
      gelu_approximate=args.gelu_approximate,
-@@ -184,18 +290,34 @@
+@@ -184,18 +290,34 @@ if __name__ == "__main__":
      return_extra_metrics=args.return_extra_metrics,
  )

 -
 -
 -
 +
 +
 +    )
 +    print("Done!")
\ No newline at end of file
-diff -Naru a/segment-anything/segment_anything/modeling/image_encoder.py b/segme
---- a/segment-anything/segment_anything/modeling/image_encoder.py 2023-11-13 16:25:26.000000000 +0800
-+++ b/segment-anything/segment_anything/modeling/image_encoder.py 2023-11-13 19:26:32.000000000 +0800
+diff --git a/segment_anything/modeling/image_encoder.py b/segment_anything/modeling/image_encoder.py
+index 66351d9..31d622c 100644
+--- a/segment_anything/modeling/image_encoder.py
++++ b/segment_anything/modeling/image_encoder.py
-@@ -253,8 +253,8 @@
+@@ -253,8 +253,8 @@ def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, T
      """
      B, H, W, C = x.shape

 -
 -
 +
 +
      if pad_h > 0 or pad_w > 0:
          x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
      Hp, Wp = H + pad_h, W + pad_w
-@@ -322,6 +322,15 @@
+@@ -322,6 +322,15 @@ def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor
      return rel_pos_resized[relative_coords.long()]

 +
 +
 +
 +
 +
 +
 +
 +
 +
  def add_decomposed_rel_pos(
      attn: torch.Tensor,
      q: torch.Tensor,
-@@ -351,8 +360,8 @@
+@@ -351,8 +360,8 @@ add_decomposed_rel_pos(
      B, _, dim = q.shape
      r_q = q.reshape(B, q_h, q_w, dim)

 -
 -
 +
 +
      attn = (
          attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
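Note: the changed lines inside the three `image_encoder.py` hunks above did not survive extraction; only bare `-`/`+` markers remain. For orientation, the two removed lines in the `window_partition` hunk are upstream SAM's modulo-based pad computation. A self-contained sketch of that upstream form (the helper name and sample sizes are hypothetical, and this shows the original computation, not the patched NPU-friendly replacement):

```python
import torch
import torch.nn.functional as F

def pad_to_window_multiple(x: torch.Tensor, window_size: int) -> torch.Tensor:
    """Pad a (B, H, W, C) tensor so H and W become multiples of window_size."""
    B, H, W, C = x.shape
    # Upstream SAM's window_partition padding; the patch replaces this
    # modulo arithmetic with an equivalent, export-friendlier form.
    pad_h = (window_size - H % window_size) % window_size
    pad_w = (window_size - W % window_size) % window_size
    if pad_h > 0 or pad_w > 0:
        # F.pad orders pad widths from the last dim: (C, C, W, W, H, H).
        x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
    return x

x = torch.randn(1, 50, 50, 8)
print(pad_to_window_multiple(x, window_size=14).shape)  # torch.Size([1, 56, 56, 8])
```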
-diff -Naru a/segment-anything/segment_anything/utils/onnx.py b/segment-anything/segment_anything/utils/onnx.py
---- a/segment-anything/segment_anything/utils/onnx.py 2023-11-13 16:25:26.000000000 +0800
-+++ b/segment-anything/segment_anything/utils/onnx.py 2023-11-18 16:14:01.512027850 +0800
+diff --git a/segment_anything/modeling/mask_decoder.py b/segment_anything/modeling/mask_decoder.py
+index 5d2fdb0..ee8da94 100644
+--- a/segment_anything/modeling/mask_decoder.py
++++ b/segment_anything/modeling/mask_decoder.py
+@@ -123,9 +123,15 @@ class MaskDecoder(nn.Module):
+         tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
+ 
+         # Expand per-image data in batch direction to be per-mask
+-        src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
++        N = tokens.shape[0]
++        B, C, H, W = image_embeddings.shape
++        src = image_embeddings.unsqueeze(1).expand(B, N, C, H, W).reshape(B * N, C, H, W)
++
+         src = src + dense_prompt_embeddings
+-        pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
++
++        B, C, H, W = image_pe.shape
++        pos_src = image_pe.unsqueeze(1).expand(B, N, C, H, W).reshape(B * N, C, H, W)
++
+         b, c, h, w = src.shape
+ 
+         # Run the transformer
+diff --git a/segment_anything/utils/onnx.py b/segment_anything/utils/onnx.py
+index 3196bdf..e718afc 100644
+--- a/segment_anything/utils/onnx.py
++++ b/segment_anything/utils/onnx.py
-@@ -112,7 +112,6 @@
+@@ -112,7 +112,6 @@ class SamOnnxModel(nn.Module):
          point_labels: torch.Tensor,
          mask_input: torch.Tensor,
          has_mask_input: torch.Tensor,
 -        orig_im_size: torch.Tensor,
      ):
          sparse_embedding = self._embed_points(point_coords, point_labels)
          dense_embedding = self._embed_masks(mask_input, has_mask_input)
-@@ -131,14 +130,4 @@
+@@ -131,14 +130,4 @@ class SamOnnxModel(nn.Module):
          if self.return_single_mask:
              masks, scores = self.select_masks(masks, scores, point_coords.shape[1])

 -        upscaled_masks = self.mask_postprocessing(masks, orig_im_size)
 -
 -        if self.return_extra_metrics:
 -            stability_scores = calculate_stability_score(
 -                upscaled_masks, self.model.mask_threshold, self.stability_score_offset
 -            )
 -            areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1)
 -            return upscaled_masks, scores, stability_scores, areas, masks
 -
 -        return upscaled_masks, scores, masks
 +        return scores, masks
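Note: the `mask_decoder.py` hunk above swaps `torch.repeat_interleave` for an `unsqueeze`/`expand`/`reshape` chain, presumably so the exported graph uses static shape ops rather than a dynamic repeat. The two forms are numerically identical when repeating along the batch dimension; a quick self-contained check (all shape values are illustrative):

```python
import torch

B, N, C, H, W = 2, 5, 4, 3, 3
x = torch.randn(B, C, H, W)

# Upstream form: repeat every batch element N times along dim 0.
a = torch.repeat_interleave(x, N, dim=0)

# Patched form from the hunk above: expand, then flatten batch and repeat dims.
b = x.unsqueeze(1).expand(B, N, C, H, W).reshape(B * N, C, H, W)

print(torch.equal(a, b))  # True
```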