diff --git a/MindIE/MindIE-Torch/built-in/cv/DINOv2/README.md b/MindIE/MindIE-Torch/built-in/cv/DINOv2/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..3a0f49d274d8db7362b07c92eff8a6eb018caa37
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/cv/DINOv2/README.md
@@ -0,0 +1,177 @@
+# DINOv2-ViT-推理指导
+
+- [概述](#ZH-CN_TOPIC_0000001172161123)
+
+- [推理环境准备](#ZH-CN_TOPIC_0000001126217823)
+
+- [快速上手](#ZH-CN_TOPIC_0000001126288456)
+
+- [模型推理性能精度](#ZH-CN_TOPIC_0000001172156835)
+
+# 概述
+
+DINOv2是由Meta
+AI开发的一个自监督学习方法,它专注于在无监督学习鲁棒的视觉特征。这个方法基于最初的DINO方法并进行了改进,提升了ViT模型在各种计算机视觉任务上的性能。([来自开源代码仓](https://github.com/facebookresearch/dinov2/tree/main))
+
+# 推理环境准备\[所有版本\]
+
+- 该模型需要以下插件与驱动
+
+  **表 1**  版本配套表
+
+  | 配套 | 版本 |
+  |---------| ------- |
+  | 固件与驱动 | - |
+  | CANN | - |
+  | Python | 3.10.13 |
+  | PyTorch | 2.1.0 |
+  | MindIE | - |
+
+  注意:由于MindIE暂无支持该模型的商发版本,烦请用户联系华为工程师获取对应的固件驱动,CANN,MindIE PoC版本链接。
+  固件驱动和CANN的安装,请参考昇腾官方文档[环境快速部署](https://www.hiascend.com/document/detail/zh/quick-installation/24.0.RC1/quickinstg/800_3000/quickinstg_800_3000_0001.html)。
+
+  MindIE的安装需要先source toolkit的环境变量,然后直接安装,以默认安装路径`/usr/local/Ascend`为例:
+  ```
+  source /usr/local/Ascend/ascend-tookit/set_env.sh
+  bash Ascend-mindie_*.run --install
+  ```
+
+# 快速上手
+
+1. 安装依赖包
+    ```shell
+    pip install transformers==4.44.1
+    pip install numpy==1.26.4
+    ```
+2. 权重下载
+
+   | 模型         | 下载                                                                                       |
+   |-------------|--------------------------------------------------------------------------------------------|
+   | ViT-S | [backbone only](https://huggingface.co/facebook/dinov2-small)                                     |
+   | ViT-B | [backbone only](https://huggingface.co/facebook/dinov2-base) |
+   | ViT-L | [backbone only](https://huggingface.co/facebook/dinov2-large)                                       |
+   | ViT-G | [backbone only](https://huggingface.co/facebook/dinov2-giant)                                        |
+
+   按上述链接下载模型权重,以dinov2-vit-base为例
+    ```shell
+    git lfs install
+    git clone https://huggingface.co/facebook/dinov2-base
+    ```
+2. 参数说明
+   导出和推理的参数命名有多数重合,公用说明如下:
+
+   | 模型          | 下载                                 |
+   |---------------|---------------------------------------------|
+   | soc-version   |  芯片类型,当前仅在Ascend910B4上调试          |
+   | device        |  NPU ID 号                                   |
+   | img-max-batch |  图片输入的最大batch size                     |
+   | image-path    |  输入图片地址                                 |
+   | model-version |  模型类型("small", "base", "large", "giant")|
+   | hf-model-path |  模型权重路径                                 |
+   | save-dir      |  不同类型模型保存路径                          |
+   更多参数请参考运行不同脚本的`parse_args`部分
+3. ONNX模型导出
+    ```shell
+    python onnx_export.py \
+        --soc-version ${soc_version} \
+        --image-path ${image_path} \
+        --model-version ${model_version} \
+        --hf-model-path ${hf_model_path} \
+        --save-dir ${save_dir}
+    ```
+   执行完成后将在`save_dir`目录下生成`dinov2-${model_version}-onnx.pt`文件。
+   giant模型由于模型过大,导出时间较长,请耐心等待,并且会保存大量中间计算节点;保存onnx模型的save_dir和保存MindIETorch模型的save_dir必须不同。
+   
+
+4. 模型编译
+    
+    由于MindIETorch不支持mode为"bicubic"的nn.functional.interpolate,因此需要将模型中的embedding剥离出来进行在线推理,只编译模型encoder部分,执行以下脚本进行编译:
+    ```shell
+    python dino_compile.py \
+        --soc-version ${soc_version} \
+        --device ${device} \
+        --img-max-batch ${img_max_batch} \
+        --image-path ${image_path} \
+        --model-version ${model_version} \
+        --model-path ${hf_model_path} \
+        --save-dir ${save_dir}
+    ```
+   执行完成后将在`save_dir`目录下生成`dinov2-${model_version}-MindIETorch.pt`文件。
+
+# 模型推理性能精度
+
+1. 精度验证
+    ```shell
+    dinov2_aie_path="./dinov2-${model_version}-MindIETorch.pt"
+    dinov2_onnx_path="./dinov2-${model_version}-onnx.pt"
+    python precision_test.py \
+        --dinov2-aie-path ${dinov2_aie_path} \
+        --dinov2-onnx-path ${dinov2_onnx_path} \
+        --device ${device} \
+        --image-path ${image_path} \
+        --model-version ${model_version} \
+        --hf-model-path ${hf_model_path} \
+    ```
+   执行结束后,期望输出如下:
+    ```
+    ----- Compare the outputs of ONNX and AIE dinov2 ${model_version} model -----
+    Number of outputs to compare: 2
+    Number of outputs with cosine similarity > 0.99: 2
+    Number of outputs to compare: 2
+    Number of outputs with cosine similarity > 0.99: 2
+    ```
+
+2. 性能验证
+
+   (a) aie模型性能测试
+    ```shell
+    dinov2_aie_path="./dinov2-${model_version}-MindIETorch.pt"
+    python perf_test_aie.py \
+        --dinov2-aie-path ${dinov2_aie_path} \
+        --device ${device} \
+        --image-path ${image_path} \
+        --img-max-batch ${img_max_batch} \
+        --hf-model-path ${hf_model_path} \
+    ```
+
+   执行结束后,期望输出如下(base):
+    ```
+    DINOV2 aie latency: 31.11 ms
+    DINOV2 aie throughput: 32.14 fps
+    ```
+
+   (b) onnx模型性能测试
+   (可选)若使用GPU,请确保已安装CUDA和pytorch-gpu版本,同时需安装onnxruntime-gpu,如下所示:
+    ```shell
+    pip uninstall onnxruntime
+    pip install onnxruntime-gpu
+    ```
+   验证onnxruntime-gpu是否安装成功:
+    ```python
+    import onnxruntime
+    print(onnxruntime.get_device())  # 若输出为GPU,则说明安装成功
+    ``` 
+   执行性能测试(CPU)
+    ```shell
+    dinov2_onnx_path="./dinov2-${model_version}-onnx.pt"
+    python perf_test_onnx.py \
+        --onnx-path ${dinov2_onnx_path} \
+        --image-path ${image_path} \
+        --hf-model-path ${hf_model_path} \
+    ```
+
+   执行结束后,期望输出如下(base):
+    ```
+    DINOV2 onnx latency: 268.65 ms
+    DINOV2 onnx throughput: 3.72 fps
+    ```
+
+   (c) 性能对比列表(GPU性能待测试):
+
+   | 模型    | MindIE-Torch(Ascend910B4) | ONNX(CPU) |
+    |---------|--------------------------------|---------------------|
+   | small | 3.93 ms / 254.74 fps | 164.50 ms / 6.08 fps |
+   | base | 4.00 ms / 250.14 fps | 535.00 ms / 1.90 fps |
+   | large | 9.59 ms / 104.30 fps | 1086.95 ms / 0.92 fps |
+   | giant | 20.09 ms / 49.78 fps | 3088.63 ms / 0.32 fps |
+    不同机器的测试出的性能在绝对值上可能有一定差异(特别是CPU性能),但相对值差异是保持一致的。
\ No newline at end of file
diff --git a/MindIE/MindIE-Torch/built-in/cv/DINOv2/dino_compile.py b/MindIE/MindIE-Torch/built-in/cv/DINOv2/dino_compile.py
new file mode 100644
index 0000000000000000000000000000000000000000..71654b09386cfff033e98b7ad5d3ca8fdeded206
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/cv/DINOv2/dino_compile.py
@@ -0,0 +1,130 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import argparse
+import os
+import time
+
+import mindietorch
+import torch
+from PIL import Image
+from torch._export import export, dynamic_dim
+from transformers import AutoImageProcessor
+
+from model.dinov2_model import Dinov2Model_WO_Embedding
+
+
+def get_embed_input(args, model):
+    processor = AutoImageProcessor.from_pretrained(args.hf_model_path)
+    inputs = processor(images=Image.open(args.image_path), return_tensors="pt")
+    embeddings = model.embeddings
+    embedding_output = embeddings(inputs.pixel_values)
+    return embedding_output
+
+
+def export_dinov2(args):
+    model = Dinov2Model_WO_Embedding.from_pretrained(args.hf_model_path).float().eval()
+    embedding_output = get_embed_input(args, model)
+    embed_shape = embedding_output.shape
+    emb_input_shape = (args.img_max_batch, embed_shape[-2], embed_shape[-1])
+    input_emb = torch.ones(emb_input_shape, dtype=torch.float32)
+
+    constraints = [
+        dynamic_dim(input_emb, 0) >= 1,
+        dynamic_dim(input_emb, 0) <= args.img_max_batch,
+    ]
+
+    print("----- start exporting dynamic dinov2 -----")
+    intermediate_model = export(
+        model,
+        args=(input_emb,),
+        constraints=constraints
+    )
+    print("----- export dynamic dinov2 success! -----")
+    return embed_shape, intermediate_model
+
+
+def compile_dinov2(args):
+    # export dinov2
+    embed_shape, intermediate_model = export_dinov2(args)
+    # compile dinov2
+    mindietorch.set_device(args.device)
+    compile_inputs = [
+        mindietorch.Input(min_shape=(1, embed_shape[-2], embed_shape[-1]),
+                          max_shape=(args.img_max_batch, embed_shape[-2], embed_shape[-1])),
+    ]
+
+    print("----- start mindietorch compile -----")
+    ts = time.time()
+    compiled_model = mindietorch.compile(
+        intermediate_model,
+        inputs=compile_inputs,
+        precision_policy=mindietorch._enums.PrecisionPolicy.FP16,
+        soc_version=args.soc_version,
+    )
+    compile_cost = time.time() - ts
+    print(f"----- compile time cost: {compile_cost} -----")
+    print("----- end mindietorch compile -----")
+
+    print("----- start saving -----")
+    model_save_dir = f"{args.save_dir}"
+    if not os.path.exists(model_save_dir):
+        os.makedirs(model_save_dir)
+    compiled_file_name = f"dinov2-{args.model_version}-MindIETorch.pt"
+    torch.save(compiled_model, model_save_dir + compiled_file_name, pickle_protocol=4)
+    print("----- saving done -----")
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="Compile DINOv2-Vit model")
+    parser.add_argument(
+        "--soc-version",
+        default="Ascend910B4",
+        help="NPU version"
+    )
+    parser.add_argument(
+        "--device",
+        type=int,
+        default=0
+    )
+    parser.add_argument(
+        "--img-max-batch",
+        type=int,
+        default=8
+    )
+    parser.add_argument(
+        "--image-path",
+        default=""
+    )
+    parser.add_argument(
+        "--model-version",
+        default="base",
+        choices=["small", "base", "large", "giant"],
+        help="Specify the architecture of DINOv2-Vit model to be converted."
+    )
+    parser.add_argument(
+        "--hf-model-path",
+        default="",
+        type=str,
+        help="Path of the Huggingface DINOv2-Vit model."
+    )
+    parser.add_argument(
+        "--save-dir",
+        default="./"
+    )
+    return parser.parse_args()
+
+
+if __name__ == "__main__":
+    input_args = parse_args()
+    compile_dinov2(input_args)
diff --git a/MindIE/MindIE-Torch/built-in/cv/DINOv2/model/dinov2_model.py b/MindIE/MindIE-Torch/built-in/cv/DINOv2/model/dinov2_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..202175318474ecc8aad3f2b9986e828ba6f2bd5f
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/cv/DINOv2/model/dinov2_model.py
@@ -0,0 +1,56 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+# Copyright 2023 Meta AI and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Optional, Tuple
+
+import torch
+from transformers.models.dinov2.modeling_dinov2 import Dinov2Model
+
+
+class Dinov2Model_WO_Embedding(Dinov2Model):
+    def forward(
+            self,
+            embedding_output: Optional[torch.Tensor] = None,
+            head_mask: Optional[torch.Tensor] = None,
+            output_attentions: Optional[bool] = None,
+            output_hidden_states: Optional[bool] = None,
+            return_dict: Optional[bool] = None,
+    ) -> Tuple:
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if embedding_output is None:
+            raise ValueError("You have to specify embedding_output")
+
+        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+        encoder_outputs = self.encoder(
+            embedding_output,
+            head_mask=head_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        sequence_output = encoder_outputs[0]
+        sequence_output = self.layernorm(sequence_output)
+        pooled_output = sequence_output[:, 0, :]
+
+        if not return_dict:
+            head_outputs = (sequence_output, pooled_output)
+            return head_outputs + encoder_outputs[1:]
+
+        return sequence_output, pooled_output
diff --git a/MindIE/MindIE-Torch/built-in/cv/DINOv2/onnx_export.py b/MindIE/MindIE-Torch/built-in/cv/DINOv2/onnx_export.py
new file mode 100644
index 0000000000000000000000000000000000000000..d7082190276dc234f097d2b0119c42e87fe4fe66
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/cv/DINOv2/onnx_export.py
@@ -0,0 +1,81 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import argparse
+
+import torch
+import torch.onnx
+from PIL import Image
+from transformers import AutoImageProcessor, AutoModel
+
+
+def convert_dinov2(args):
+    processor = AutoImageProcessor.from_pretrained(args.hf_model_path)
+    inputs = processor(images=Image.open(args.image_path), return_tensors="pt")
+    pixel_values = inputs.pixel_values
+    model = AutoModel.from_pretrained(args.hf_model_path).float().eval()
+
+    onnx_path = f"{args.save_dir}dinov2-{args.model_version}-onnx.pt"
+    print("----- Starting to export dynamic onnx -----")
+    torch.onnx.export(
+        model,
+        (pixel_values,),
+        onnx_path,
+        input_names=["pixel_values"],
+        output_names=["sequence_output", "pooled_output"],
+        export_params=True,
+        opset_version=13,
+        verbose=True,
+        dynamic_axes={
+            "pixel_values": {0: "image_batch_size"},
+            "sequence_output": {0: "image_batch_size"},
+            "pooled_output": {0: "image_batch_size"},
+        }
+    )
+    print("----- Successfully exported dynamic onnx! -----")
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="Compile DINOv2 model")
+    parser.add_argument(
+        "--soc-version",
+        default="Ascend910B4",
+        help="NPU version"
+    )
+    parser.add_argument(
+        "--image-path",
+        type=str,
+        default=""
+    )
+    parser.add_argument(
+        "--hf-model-path",
+        default="",
+        type=str,
+        help="Path of the Huggingface DINOv2 model."
+    )
+    parser.add_argument(
+        "--model-version",
+        default="base",
+        choices=["small", "base", "large", "giant"],
+        help="Specify the architecture of DINOv2-Vit model to be converted."
+    )
+    parser.add_argument(
+        "--save-dir",
+        default="./"
+    )
+    return parser.parse_args()
+
+
+if __name__ == "__main__":
+    input_args = parse_args()
+    convert_dinov2(input_args)