{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "547900ed",
   "metadata": {},
   "source": [
    "# 图模式-编程技巧\n",
    "\n",
    "[![下载Notebook](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.9.0/resource/_static/logo_notebook.svg)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/r2.9.0/tutorials/zh_cn/compile/mindspore_static_graph_expert_programming.ipynb)&emsp;\n",
    "[![下载样例代码](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.9.0/resource/_static/logo_download_code.svg)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/r2.9.0/tutorials/zh_cn/compile/mindspore_static_graph_expert_programming.py)&emsp;\n",
    "[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.9.0/resource/_static/logo_source.svg)](https://atomgit.com/mindspore/docs/blob/r2.9.0/tutorials/source_zh_cn/compile/static_graph_expert_programming.ipynb)\n",
    "\n",
    "本章介绍常用的静态图优化的高级编程技巧。这些技巧能够有效地提高静态图的编译效率、执行效率，并提升程序的稳定性。有关静态图编译的基础介绍，请见[Graph Mode加速](https://www.mindspore.cn/tutorials/zh-CN/r2.9.0/beginner/accelerate_with_static_graph.html)。\n",
    "\n",
    "## 如何优化编译性能\n",
    "\n",
    "### 使用lazy_inline装饰器\n",
    "\n",
    "神经网络模型的编译过程往往采用默认inline的方式，把层级的代码表达最终展开成一张扁平的计算图，一方面寻求最大的编译优化机会，另一方面也可以简化自动微分以及执行的逻辑。inline后形成的计算图包含所有计算节点，可以在更大的范围内进行优化，比如常量折叠、节点融合、并行分析等。同时，也可以更好地实现内存分配，减少内存申请和性能开销。虽然inline优化对于运行期性能提升有很大帮助，但过度inline会增加编译期负担。例如，随着计算图节点数量膨胀，执行pass的耗时也在急剧增长。\n",
    "\n",
    "为了减轻inline对编译性能的影响，对于重复调用相同计算单元的场景（典型的场景是在for循环中调用同一个Cell类的不同实例），我们提供了Lazy Inline机制来减少编译时间。\n",
    "\n",
    "#### 大模型pipeline并行场景\n",
    "\n",
    "在大模型场景中，编译耗时问题尤为突出。一是大模型的模型结构层次深、节点数多；二是训练时启用pipeline并行，模型规模和节点数进一步增加。如果原来图的规模是O，那开启pipeline并行，单节点图的规模变为(O/X)*Y，其中X为pipeline的stage数量，Y为micro batch的数量。以盘古13B网络为例，计算图中计算节点数量达到13.5万个，单次编译时长可接近3小时。\n",
    "\n",
    "类似盘古的大模型网络结构由多层layer组成。开启pipeline并行时，各个micro batch的layer层结构完全相同。此时，`PipelineCell`通过for循环多次调用相同结构的layer，代码如下所示："
   ]
  },
  {
   "cell_type": "markdown",
   "id": "514c3fe6",
   "metadata": {},
   "source": [
    "```python\n",
    "from mindspore import nn\n",
    "\n",
    "class PipelineCell(nn.Cell):\n",
    "    def __init__(self, network, micro_size):\n",
    "        ...\n",
    "        self.network = network\n",
    "        self.micro_size = micro_size\n",
    "        ...\n",
    "\n",
    "    def construct(self, ...):\n",
    "        ...\n",
    "        for i in range(self.micro_size):\n",
    "            output = self.network(...)\n",
    "        ...\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9080abe3",
   "metadata": {},
   "source": [
    "如果把循环体看作被频繁调用的子图，通过把它标记为Lazy Inline，告知编译器推迟inline处理，那么就可以在编译的大部分阶段大幅度减少计算图节点数量，从而获得性能收益。例如上面的代码，可以保留`network`实例的子图结构，不inline或者延迟inline。对此，我们提供了`@lazy_inline`装饰器来实现延迟inline。\n",
    "\n",
    "以Pangu_alpha网络为例，`PipelineCell`函数体中处理的`network`为`PanGUAlphaWithLoss`类的实例。为实现延迟inline，我们需要在`PanGUAlphaWithLoss`类的`__init__`函数上加`@lazy_inline`装饰器，以标记该类的子图结构需要被保留，不进行inline或者延迟inline。如下所示："
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5607ec06",
   "metadata": {},
   "source": [
    "```python\n",
    "from mindspore import nn, lazy_inline\n",
    "\n",
    "class PanGUAlphaWithLoss(nn.Cell):\n",
    "    @lazy_inline\n",
    "    def __init__(self, ...):\n",
    "        ...\n",
    "\n",
    "    def construct(self, ...):\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8f0309ef",
   "metadata": {},
   "source": [
    "还是以盘古13B网络为例，应用Lazy Inline方案后，计算图节点数量从13万多个下降到2万多个，编译时间从3小时缩短到20分钟。\n",
    "\n",
    "#### 更加泛化的一般场景\n",
    "\n",
    "`@lazy_inline`是`Cell::__init__`的装饰器，它会以`__init__`的所有参数生成Cell的`cell_init_args`属性。`cell_init_args`值相同，表示Cell类名和初始化参数值一致。对于相同Cell类的实例，它们的weights可能不同，因此对于`construct(self, x)`定义的网络结构，实际编译时可以转换为`construct(x, self.cell_init_args, self.trainable_parameters())`。对于同一个Cell类的不同实例，只要`cell_init_args`相同，就可以复用同一个网络结构，如下所示："
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9c3f3655",
   "metadata": {},
   "source": [
    "```python\n",
    "def construct(self, x)\n",
    "    reuse_construct(x, self.trainable_parameters())\n",
    "```"
   ]
  },
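  {
   "cell_type": "markdown",
   "id": "3f1a9b20",
   "metadata": {},
   "source": [
    "The reuse mechanism above can be sketched in plain Python. This is only an analogy, not MindSpore internals: the `Block` class, the `compiled_graphs` cache, and the `compile` method are hypothetical names used to illustrate caching compiled graphs by `cell_init_args`:\n",
    "\n",
    "```python\n",
    "# Hypothetical sketch: compiled graphs are cached by (class name, init args),\n",
    "# so instances that share cell_init_args reuse one graph while keeping\n",
    "# their own weights.\n",
    "compiled_graphs = {}\n",
    "\n",
    "class Block:\n",
    "    def __init__(self, hidden):\n",
    "        self.cell_init_args = (type(self).__name__, hidden)\n",
    "        self.weight = hidden * 0.5  # per-instance weights\n",
    "\n",
    "    def compile(self):\n",
    "        # Reuse the cached graph when cell_init_args match.\n",
    "        if self.cell_init_args not in compiled_graphs:\n",
    "            compiled_graphs[self.cell_init_args] = lambda x, w: x * w\n",
    "        return compiled_graphs[self.cell_init_args]\n",
    "\n",
    "blocks = [Block(4) for _ in range(10)]\n",
    "graphs = [b.compile() for b in blocks]\n",
    "print(len(compiled_graphs))  # all 10 instances share one compiled graph: prints 1\n",
    "```"
   ]
  },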
  {
   "cell_type": "markdown",
   "id": "efa1a2c5",
   "metadata": {},
   "source": [
    "引入可复用计算图后，具有相同`cell_init_args`的Cell实例只需编译解析一次。因此，在更通用的场景下，调用同一个Cell类的不同实例，只要`cell_init_args`相同，都可以加上`@lazy_inline`装饰器来加速编译。例如GPT网络："
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ef0fbd8c",
   "metadata": {},
   "source": [
    "```python\n",
    "from mindspore import nn, lazy_inline\n",
    "\n",
    "class Block(nn.Cell):\n",
    "    @lazy_inline\n",
    "    def __init__(self, config):\n",
    "        ...\n",
    "\n",
    "    def construct(self, x, attention_mask, layer_past):\n",
    "        ...\n",
    "\n",
    "class GPT_Model(nn.Cell):\n",
    "    def __init__(self, config):\n",
    "        ...\n",
    "        for i in range(config.num_layers):\n",
    "            self.blocks.append(Block(config))\n",
    "            ...\n",
    "        self.num_layers = config.num_layers\n",
    "\n",
    "    def construct(self, input_ids, input_mask, layer_past):\n",
    "        ...\n",
    "        present_layer = ()\n",
    "        for i in range(self.num_layers):\n",
    "            hidden_states, present = self.blocks[i](...)\n",
    "            present_layer = present_layer + (present,)\n",
    "        ...\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a0184b17",
   "metadata": {},
   "source": [
    "GPT的网络结构由多层`Block`类的不同实例构成，这些`Block`的初始化参数都是同一个`config`。因此，加上`@lazy_inline`装饰器后，这些`Block`实例都可以复用同一个网络结构，并且在大部分编译阶段不会进行inline，从而大幅减少编译时间。\n",
    "\n",
    "#### 使用步骤\n",
    "\n",
    "如上面的例子，在网络脚本中，需要延迟inline和复用子图结构的Cell类，其`__init__`函数需加上`@lazy_inline`装饰器。\n",
    "\n",
    "#### 使用限制\n",
    "\n",
    "1. Cell以类名和`__init__`参数值生成Cell实例标识。假设`__init__`参数确定Cell 的所有属性，并且`construct`构图开始时Cell属性和`__init__`执行完后属性一致。因此，Cell与构图有关的属性，在`__init__`执行完后不能进行更改。例如：\n",
    "\n",
    "   ```python\n",
    "   from mindspore import nn, lazy_inline\n",
    "\n",
    "   class Block(nn.Cell):\n",
    "       @lazy_inline\n",
    "       def __init__(self, ...):\n",
    "           self.x = 0\n",
    "           ...\n",
    "       def construct(self, ...):\n",
    "        if self.x == 0:\n",
    "             ...\n",
    "        else:\n",
    "           ...\n",
    "           ...\n",
    "   class Model(nn.Cell):\n",
    "       def __init__(self, ...):\n",
    "           ...\n",
    "           self.num_layers = 10\n",
    "           for i in range(self.num_layers):\n",
    "               self.blocks.append(Block(...)) # 此处Block进行初始化\n",
    "               ...\n",
    "           self.blocks[0].x = 1               # 此处在Block初始化后修改Block的属性，会导致该Block无法复用同一份子图\n",
    "\n",
    "       def construct(self, ...):\n",
    "            ...\n",
    "            for i in range(self.num_layers):\n",
    "               res = self.blocks[i](...)\n",
    "           ...\n",
    "   ```\n",
    "\n",
    "如上代码所示，若网络Model中某个`Block`实例的属性`x`在初始化后被修改，则该`Block`实例无法准确复用同一个子图结构。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "00c363fd",
   "metadata": {},
   "source": [
    "2. 一个Cell类的网络结构包含多个Cell_X类实例，同时每个Cell_X类的网络结构又包含多个Cell_Y实例，如果同时在Cell_X和Cell_Y类的`__init__`函数上加上`@lazy_inline`，那么只有最外层的Cell_X实例的网络结构会被编译成可复用的计算图并延迟inline，内层的Cell_Y实例的计算图还是会被inline。例如：\n",
    "\n",
    "   ```python\n",
    "   from mindspore import nn, lazy_inline\n",
    "\n",
    "   class InnerBlock(nn.Cell):\n",
    "       @lazy_inline             # InnerBlock不会被延迟inline\n",
    "       def __init__(self, ...):\n",
    "           ...\n",
    "       def construct(self, ...):\n",
    "           ...\n",
    "   class OuterBlock(nn.Cell):\n",
    "       @lazy_inline             # OuterBlock将会被延迟inline\n",
    "       def __init__(self, ...):\n",
    "           ...\n",
    "           self.num_layers = 10\n",
    "           for i in range(self.num_layers):\n",
    "               self.blocks.append(InnerBlock(...))\n",
    "       def construct(self, ...):\n",
    "            ...\n",
    "            for i in range(self.num_layers):\n",
    "              res = self.blocks[i](...)\n",
    "            ...\n",
    "   class Model(nn.Cell):\n",
    "       def __init__(self, ...):\n",
    "           ...\n",
    "           self.num_layers = 10\n",
    "           for i in range(self.num_layers):\n",
    "               self.blocks.append(OuterBlock(...))\n",
    "       def construct(self, ...):\n",
    "         ...\n",
    "            for i in range(self.num_layers):\n",
    "              res = self.blocks[i](...)\n",
    "           ...\n",
    "   ```\n",
    "\n",
    "后续计划支持多层级的Lazy Inline机制。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8e9dd7b8",
   "metadata": {},
   "source": [
    "### 使用HyperMap\n",
    "\n",
    "使用场景：使用HyperMap替换for循环来优化编译性能。\n",
    "\n",
    "`HyperMap`是一个特殊的类，构造对象时需要传入映射函数f，调用对象时需要传入f的n个参数序列。更多使用方法见：[HyperMap](https://www.mindspore.cn/docs/zh-CN/r2.9.0/api_python/ops/mindspore.ops.HyperMap.html)。映射函数f必须是`MultitypeFuncGraph`类型, 可参考[MultitypeFuncGraph](https://www.mindspore.cn/docs/zh-CN/r2.9.0/api_python/ops/mindspore.ops.MultitypeFuncGraph.html)。在使用for循环批量处理列表元素时，可以通过`HyperMap`等价替换来优化网络编译性能。"
   ]
  },
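  {
   "cell_type": "markdown",
   "id": "7c2e5d11",
   "metadata": {},
   "source": [
    "As a minimal sketch of this replacement (assuming the standard `ops.MultitypeFuncGraph` registration and `ops.HyperMap` call described in the linked API documentation), the following applies one mapping function to every element of a tuple instead of looping:\n",
    "\n",
    "```python\n",
    "import mindspore\n",
    "from mindspore import ops\n",
    "\n",
    "# The mapping function f must be a MultitypeFuncGraph.\n",
    "square = ops.MultitypeFuncGraph('square')\n",
    "\n",
    "@square.register('Tensor')\n",
    "def square_tensor(x):\n",
    "    return x * x\n",
    "\n",
    "common_map = ops.HyperMap()\n",
    "xs = (mindspore.tensor(1.0), mindspore.tensor(2.0), mindspore.tensor(3.0))\n",
    "\n",
    "# Equivalent to applying square_tensor to each element with a for loop,\n",
    "# but expressed as one mapped operation.\n",
    "output = common_map(square, xs)\n",
    "print(output)\n",
    "```"
   ]
  },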
  {
   "cell_type": "markdown",
   "id": "88b51f82",
   "metadata": {},
   "source": [
    "### 使用编译缓存\n",
    "\n",
    "使用场景：在训练或推理时，如果编译依赖文件未发生变更，可以通过编译缓存来缩短编译时间。\n",
    "\n",
    "编译缓存本质上是存储网络模型的编译中间过程文件。当网络模型不变时，生产的编译中间过程文件也是一样的，因此可以复用上一次生成的中间过程文件。\n",
    "\n",
    "通过设置环境变量[MS_COMPILER_CACHE_ENABLE](https://www.mindspore.cn/docs/zh-CN/r2.9.0/api_python/env_var_list.html?highlight=MS_COMPILER_CACHE_ENABLE)，可以指定是否保存和加载编译缓存。\n",
    "\n",
    "通过设置环境变量[MS_COMPILER_CACHE_PATH](https://www.mindspore.cn/docs/zh-CN/r2.9.0/api_python/env_var_list.html?highlight=MS_COMPILER_CACHE_PATH)，可以指定MindSpore编译缓存目录，用于存储图和算子编译过程生成的缓存文件。\n",
    "\n",
    "以下为通过使能编译缓存来优化编译性能的代码示例："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8e98fb56",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Disable compile_cache cost time: 0.5485098361968994\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import time\n",
    "import mindspore\n",
    "from mindspore import nn\n",
    "\n",
    "mindspore.set_context(mode=mindspore.GRAPH_MODE)\n",
    "\n",
    "class Model(nn.Cell):\n",
    "    def construct(self, input_x, input_y):\n",
    "        output = input_x\n",
    "        for _ in range(200):\n",
    "            output = input_x + input_x * input_y + output\n",
    "        return output\n",
    "\n",
    "os.environ['MS_COMPILER_CACHE_ENABLE'] = '0'\n",
    "x = mindspore.tensor([1], mindspore.float32)\n",
    "y = mindspore.tensor([2], mindspore.float32)\n",
    "model = Model()\n",
    "start_time = time.time()\n",
    "out = model(x, y)\n",
    "end_time = time.time()\n",
    "print(\"Disable compile_cache cost time:\", end_time - start_time)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8e66fb52",
   "metadata": {},
   "source": [
    "上述测试样例为关闭编译缓存状态，执行两次后的耗时如下（实际耗时与硬件环境有关，以下数据仅供参考）：\n",
    "\n",
    "```text\n",
    "Disable compile_cache cost time: 0.5485098361968994\n",
    "Disable compile_cache cost time: 0.4614279270172119\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b3805561",
   "metadata": {},
   "source": [
    "可以看到，关闭编译缓存时，第2次执行样例比第1次耗时少一些，这是因为算子编译缓存是默认开启的，第2次执行样例能够利用前一次的算子编译缓存。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "278e2ca1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Enable compile_cache cost time: 0.6357541084289551\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import time\n",
    "import mindspore\n",
    "from mindspore import nn\n",
    "\n",
    "mindspore.set_context(mode=mindspore.GRAPH_MODE)\n",
    "\n",
    "class Model(nn.Cell):\n",
    "    def construct(self, input_x, input_y):\n",
    "        output = input_x\n",
    "        for _ in range(200):\n",
    "            output = input_x + input_x * input_y + output\n",
    "        return output\n",
    "\n",
    "os.environ['MS_COMPILER_CACHE_ENABLE'] = '1'\n",
    "os.environ['MS_COMPILER_CACHE_PATH'] = 'my_compile_cache'\n",
    "x = mindspore.tensor([1], mindspore.float32)\n",
    "y = mindspore.tensor([2], mindspore.float32)\n",
    "model = Model()\n",
    "start_time = time.time()\n",
    "out = model(x, y)\n",
    "end_time = time.time()\n",
    "os.environ['MS_COMPILER_CACHE_ENABLE'] = '0'\n",
    "print(\"Enable compile_cache cost time:\", end_time - start_time)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0a76213d",
   "metadata": {},
   "source": [
    "上述测试样例为开启编译缓存状态，执行两次后的耗时如下（实际耗时与硬件环境有关，以下数据仅供参考）：\n",
    "\n",
    "```text\n",
    "Enable compile_cache cost time: 0.6357541084289551\n",
    "Enable compile_cache cost time: 0.09379792213439941\n",
    "```\n",
    "\n",
    "可以看到，开启编译缓存时，第2次执行样例耗时只有第一次执行耗时的1/7左右。\n",
    "\n",
    "说明：开启编译缓存功能时，第一次执行由于暂未生成缓存，所以会有如下Warning："
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e49d53e5",
   "metadata": {},
   "source": [
    "```text\n",
    "Warning: Check the consistency of dependency files hash failed. Execute all the compilation actions.\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b5655959",
   "metadata": {},
   "source": [
    "## 如何优化执行性能\n",
    "\n",
    "### 使用jit_class\n",
    "\n",
    "使用场景：使用`@jit_class`装饰器修饰自定义类，提高执行性能。jit_class应用于静态图模式，在动态图模式下，`@jit_class`会被忽略，不影响动态图模式的执行逻辑。\n",
    "\n",
    "#### jit_class的介绍\n",
    "\n",
    "用户在网络脚本中定义一个类时，可以写成继承于`Cell`的类、自定义类、`@jit_class`修饰的类，它们的用法和区别如下：\n",
    "\n",
    "- 继承于Cell的类\n",
    "\n",
    "  Cell是MindSpore中神经网络的基本构成单元，模型或者神经网络层应当继承该类。静态图模式下，使用`Cell`类并且在`construct`函数中编写执行代码，此时`construct`函数的代码会被编译成静态计算图。\n",
    "\n",
    "- 自定义类\n",
    "\n",
    "  定义自定义类后，可以对类进行实例化、调用类对象的属性和方法，请参考[自定义类的使用](https://www.mindspore.cn/tutorials/zh-CN/r2.9.0/compile/static_graph.html#支持自定义类的使用)。相比于`Cell`的类定义，自定义类更贴近用户调用Python类的使用习惯。自定义类在静态图模式下的实现方式与`Cell`不同，例如，调用自定义类对象的函数方法时，其函数方法中的代码不会被编译成静态计算图，而是通过Python解释器进行解释执行。\n",
    "\n",
    "- `@jit_class`修饰的类\n",
    "\n",
    "  为了兼顾用户的Python使用习惯和静态图编译带来的性能优势，提供了`@jit_class`装饰器。给自定义类修饰`@jit_class`装饰器后，该类的函数代码会被编译成静态计算图，基于图优化、静态图整图下沉等技术，编译器可以针对计算图进行全局的优化，从而获得较好的执行性能。\n",
    "\n",
    "在静态图模式下，通过使用`@jit_class`修饰自定义类，用户可以创建、调用该类的实例，并且可以获取其属性和方法。\n",
    "\n",
    "#### jit_class装饰器的使用\n",
    "\n",
    "jit_class装饰器仅支持修饰自定义类，不支持修饰继承于`Cell`的类。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "c0234770",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1 2 3]\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import mindspore\n",
    "from mindspore import nn\n",
    "\n",
    "@mindspore.jit_class\n",
    "class InnerNet:\n",
    "    value = mindspore.tensor(np.array([1, 2, 3]))\n",
    "\n",
    "class Net(nn.Cell):\n",
    "    @mindspore.jit\n",
    "    def construct(self):\n",
    "        return InnerNet().value\n",
    "\n",
    "net = Net()\n",
    "out = net()\n",
    "print(out)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fe681740",
   "metadata": {},
   "source": [
    "如果jit_class修饰继承于`Cell`的类，将会报错。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "106d6e6a",
   "metadata": {},
   "source": [
    "```python\n",
    "import mindspore\n",
    "from mindspore import nn\n",
    "\n",
    "@mindspore.jit_class\n",
    "class Net(nn.Cell):\n",
    "    @mindspore.jit\n",
    "    def construct(self, x):\n",
    "        return x\n",
    "\n",
    "x = mindspore.tensor(1)\n",
    "net = Net()\n",
    "net(x)\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e566051b",
   "metadata": {},
   "source": [
    "报错信息如下："
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0d633c29",
   "metadata": {},
   "source": [
    "```text\n",
    "TypeError: Decorator jit_class is used for user-defined classes and cannot be used for nn.Cell: Net<>.\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6163d8ce",
   "metadata": {},
   "source": [
    "jit_class支持自定义类嵌套使用、自定义类与`Cell`嵌套使用的场景。需要注意的是，类继承时，如果父类使用了jit_class，子类也会具有jit_class的能力。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "32579fd0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1 2 3]\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import mindspore\n",
    "from mindspore import nn\n",
    "\n",
    "@mindspore.jit_class\n",
    "class Inner:\n",
    "    def __init__(self):\n",
    "        self.value = mindspore.tensor(np.array([1, 2, 3]))\n",
    "\n",
    "@mindspore.jit_class\n",
    "class InnerNet:\n",
    "    def __init__(self):\n",
    "        self.inner = Inner()\n",
    "\n",
    "class Net(nn.Cell):\n",
    "    def __init__(self):\n",
    "        super(Net, self).__init__()\n",
    "        self.inner_net = InnerNet()\n",
    "\n",
    "    @mindspore.jit\n",
    "    def construct(self):\n",
    "        out = self.inner_net.inner.value\n",
    "        return out\n",
    "\n",
    "net = Net()\n",
    "out = net()\n",
    "print(out)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9e99952c",
   "metadata": {},
   "source": [
    "#### 获取类的属性和方法\n",
    "\n",
    "支持通过类名或类实例调用属性和方法。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "c5c1a327",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "12\n"
     ]
    }
   ],
   "source": [
    "import mindspore\n",
    "from mindspore import nn\n",
    "\n",
    "@mindspore.jit_class\n",
    "class InnerNet:\n",
    "    def __init__(self, val):\n",
    "        self.number = val\n",
    "\n",
    "    def act(self, x, y):\n",
    "        return self.number * (x + y)\n",
    "\n",
    "class Net(nn.Cell):\n",
    "    def __init__(self):\n",
    "        super(Net, self).__init__()\n",
    "        self.inner_net = InnerNet(2)\n",
    "\n",
    "    @mindspore.jit\n",
    "    def construct(self, x, y):\n",
    "        return self.inner_net.number + self.inner_net.act(x, y)\n",
    "\n",
    "x = mindspore.tensor(2, dtype=mindspore.int32)\n",
    "y = mindspore.tensor(3, dtype=mindspore.int32)\n",
    "net = Net()\n",
    "out = net(x, y)\n",
    "print(out)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "42eeec97",
   "metadata": {},
   "source": [
    "#### 创建类的实例\n",
    "\n",
    "对于将会被编译成静态计算图的函数，如`Cell`的`construct`函数、`@jit`修饰的函数或前两者调用的子函数，如果需要在函数内创建`@jit_class`所修饰的类的实例，参数要求为常量。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "6697edd4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "5\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import mindspore\n",
    "from mindspore import nn\n",
    "\n",
    "@mindspore.jit_class\n",
    "class InnerNet:\n",
    "    def __init__(self, val):\n",
    "        self.number = val + 3\n",
    "\n",
    "class Net(nn.Cell):\n",
    "    @mindspore.jit\n",
    "    def construct(self):\n",
    "        net = InnerNet(2)\n",
    "        return net.number\n",
    "\n",
    "net = Net()\n",
    "out = net()\n",
    "print(out)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b5cabba3",
   "metadata": {},
   "source": [
    "#### 调用类的实例\n",
    "\n",
    "调用`@jit_class`所修饰的类的实例时，将会调用该类的`__call__`函数方法。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "f0a25831",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "10\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import mindspore\n",
    "from mindspore import nn\n",
    "\n",
    "@mindspore.jit_class\n",
    "class InnerNet:\n",
    "    def __init__(self, number):\n",
    "        self.number = number\n",
    "\n",
    "    def __call__(self, x, y):\n",
    "        return self.number * (x + y)\n",
    "\n",
    "class Net(nn.Cell):\n",
    "    @mindspore.jit\n",
    "    def construct(self, x, y):\n",
    "        net = InnerNet(2)\n",
    "        out = net(x, y)\n",
    "        return out\n",
    "\n",
    "x = mindspore.tensor(2, dtype=mindspore.int32)\n",
    "y = mindspore.tensor(3, dtype=mindspore.int32)\n",
    "net = Net()\n",
    "out = net(x, y)\n",
    "print(out)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c6e384c9",
   "metadata": {},
   "source": [
    "如果该类没有定义`__call__`函数，将会报错提示。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "592c8d0c",
   "metadata": {},
   "source": [
    "```python\n",
    "import numpy as np\n",
    "import mindspore\n",
    "from mindspore import nn\n",
    "\n",
    "@mindspore.jit_class\n",
    "class InnerNet:\n",
    "    def __init__(self, number):\n",
    "        self.number = number\n",
    "\n",
    "class Net(nn.Cell):\n",
    "    @mindspore.jit\n",
    "    def construct(self, x, y):\n",
    "        net = InnerNet(2)\n",
    "        out = net(x, y)\n",
    "        return out\n",
    "\n",
    "x = mindspore.tensor(2, dtype=mindspore.int32)\n",
    "y = mindspore.tensor(3, dtype=mindspore.int32)\n",
    "net = Net()\n",
    "out = net(x, y)\n",
    "print(out)\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "29c4b91f",
   "metadata": {},
   "source": [
    "报错信息如下："
   ]
  },
  {
   "cell_type": "markdown",
   "id": "48212c94",
   "metadata": {},
   "source": [
    "```text\n",
    "ValueError: MsClassObject: 'InnerNet' has no __call__ function, please check the code.\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "78c3a5b9",
   "metadata": {},
   "source": [
    "### 使用select算子\n",
    "\n",
    "使用场景：使用`Select`算子替代if控制流语句，减少静态图子图生成，提高执行性能（也可以提高编译性能）。\n",
    "\n",
    "编写网络时，经常会用到if语句。如果if语句的条件为变量，每个if语句都会产生额外子图。在静态图模式下，子图数量越多，编译耗时越久，因此部分场景可以通过`Select`算子等价替换if语句来优化编译性能。\n",
    "\n",
    "需要注意的是，使用`Select`算子替换if语句会影响网络的运行性能。一方面，`Select`算子会同时执行true分支和false分支，而if语句只执行其一个分支，因此使用if运行耗时较少；另一方面，`Select`算子性能优于if语句产生的控制流算子，使用if运行耗时反而可能增加。综合上述两种因素，最终运行性能变化情况需要结合实际情况判断。一般来讲，当分支中算子数量较少，建议使用`Select`算子；当分支中算子数量较多，建议使用if语句。\n",
    "\n",
    "以下为使用`Select`算子替代if语句来优化编译性能的代码示例："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "50935382",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "if net cost time: 2.1914021968841553\n",
      "select net cost time: 0.7116048336029053\n"
     ]
    }
   ],
   "source": [
    "import time\n",
    "import mindspore\n",
    "from mindspore import ops\n",
    "\n",
    "@mindspore.jit\n",
    "def if_net(x, y):\n",
    "    out = 0\n",
    "    for _ in range(100):\n",
    "        if x < y:\n",
    "            x = x - y\n",
    "        else:\n",
    "            x = x + y\n",
    "        out = out + x\n",
    "    return out\n",
    "\n",
    "start_time = time.time()\n",
    "out = if_net(mindspore.tensor([0]), mindspore.tensor([1]))\n",
    "end_time = time.time()\n",
    "print(\"if net cost time:\", end_time - start_time)\n",
    "\n",
    "@mindspore.jit\n",
    "def select_net(x, y):\n",
    "    out = x\n",
    "    for _ in range(100):\n",
    "        cond = x < y\n",
    "        x = ops.select(cond, x - y, x + y)\n",
    "        out = out + x\n",
    "    return out\n",
    "\n",
    "start_time = time.time()\n",
    "out = select_net(mindspore.tensor([0]), mindspore.tensor([1]))\n",
    "end_time = time.time()\n",
    "print(\"select net cost time:\", end_time - start_time)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "45095e4e",
   "metadata": {},
   "source": [
    "### 使用Vmap进行批处理\n",
    "\n",
    "使用场景：在处理无依赖关系的批量数据且相关的算子支持Vmap功能时，可以使用Vmap替代for循环处理批量数据来优化执行性能（也可以提高编译性能）。\n",
    "\n",
    "MindSpore已支持Vmap功能。\n",
    "\n",
    "以下为使用Vmap替换for循环处理批量数据来优化编译性能的代码示例：\n",
    "\n",
    "代码的运行结果如下（实际耗时与硬件环境有关，以下数据仅供参考）："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "b89d6496",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Vmap cost time: 0.05766916275024414\n",
      "for loop cost time: 1.9284062385559082\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import time\n",
    "import mindspore\n",
    "from mindspore import ops, vmap\n",
    "\n",
    "def hswish_func(x):\n",
    "    return ops.HSwish()(x)\n",
    "\n",
    "@mindspore.jit\n",
    "def manually_batched(xs):\n",
    "    output = []\n",
    "    for i in range(xs.shape[0]):\n",
    "        output.append(hswish_func(xs[i]))\n",
    "    return ops.stack(output)\n",
    "\n",
    "shape = (100, 2)\n",
    "prop = 100\n",
    "x_np = (np.random.randn(*shape) * prop).astype(np.float32)\n",
    "x = mindspore.tensor(x_np)\n",
    "x = ops.sub(x, 0)\n",
    "\n",
    "start_time = time.time()\n",
    "output_vmap = vmap(hswish_func, in_axes=(0,))(x)\n",
    "end_time = time.time()\n",
    "print(\"Vmap cost time:\", end_time - start_time)\n",
    "\n",
    "start_time = time.time()\n",
    "output_manually = manually_batched(x)\n",
    "end_time = time.time()\n",
    "print(\"for loop cost time:\", end_time - start_time)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6ded0e87",
   "metadata": {},
   "source": [
    "上述样例中，相当于需要批量处理100组Tensor数据。可以看到，使用Vmap处理的性能超过使用for循环处理性能的30倍。\n",
    "\n",
    "## 依赖控制保证执行序\n",
    "\n",
    "如果函数的运行结果依赖或影响外部状态，则认为该函数具有副作用，比如函数会改变外部全局变量、函数的结果依赖全局变量的值。如果算子会改变输入参数的值，或者算子的输出依赖全局参数的值，则认为这是带副作用的算子。\n",
    "\n",
    "根据内存属性和IO状态，将副作用划分为内存副作用和IO副作用。当前内存副作用主要有Assign、优化器算子等等，IO副作用主要有Print算子。详细可以查看算子定义，内存副作用算子在定义中有side_effect_mem属性，IO副作用算子在定义中有side_effect_io属性。\n",
    "\n",
    "Depend用于处理依赖项操作。在大多数情况下，如果操作符有IO副作用或内存副作用，则将根据用户的语义执行，不需要另外使用Depend算子来保证执行顺序。在某些情况下，如果两个运算符A和B没有顺序依赖关系，但A必须在B之前执行，建议使用Depend指定执行顺序。使用方法如下："
   ]
  },
  {
   "cell_type": "markdown",
   "id": "966360dc",
   "metadata": {},
   "source": [
    "```python\n",
    "a = A(x)\n",
    "b = B(y)\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9fe66617",
   "metadata": {},
   "source": [
    "插入Depend算子："
   ]
  },
  {
   "cell_type": "markdown",
   "id": "60817470",
   "metadata": {},
   "source": [
    "```python\n",
    "a = A(x)\n",
    "y = Depend(y, a)\n",
    "b = B(y)\n",
    "```"
   ]
  },
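  {
   "cell_type": "markdown",
   "id": "5b8d2e90",
   "metadata": {},
   "source": [
    "A runnable sketch of this pattern with the functional interface `ops.depend` (the operator names and values here are illustrative, not from the original text):\n",
    "\n",
    "```python\n",
    "import mindspore\n",
    "from mindspore import ops\n",
    "\n",
    "@mindspore.jit\n",
    "def net(x, y):\n",
    "    a = ops.mul(x, 2)      # A(x): no data dependence on B\n",
    "    y = ops.depend(y, a)   # force A to execute before B\n",
    "    b = ops.add(y, 1)      # B(y)\n",
    "    return a, b\n",
    "\n",
    "a, b = net(mindspore.tensor(3.0), mindspore.tensor(4.0))\n",
    "print(a, b)\n",
    "```"
   ]
  },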
  {
   "cell_type": "markdown",
   "id": "ec290f5b",
   "metadata": {},
   "source": [
    "值得说明的是，用于浮点数溢出状态检测的一组特殊算子存在隐含副作用，但又不属于IO副作用或内存副作用。此外，使用时还有严格的顺序要求：在使用NPUClearFloatStatus算子前需要保证NPUAllocFloatStatus已经执行，使用NPUGetFloatStatus算子前需要保证NPUClearFloatStatus已经执行。由于这些算子使用较少，目前的方案是将其定义为无副作用形式，以Depend确保执行顺序。注意：此类浮点数溢出状态检测的算子仅在Ascend平台支持。如下："
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7043dde3",
   "metadata": {},
   "source": [
    "```python\n",
    "import numpy as np\n",
    "import mindspore\n",
    "from mindspore import nn, ops, Tensor\n",
    "\n",
    "mindspore.set_device(\"Ascend\")\n",
    "\n",
    "class Net(nn.Cell):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.alloc_status = ops.NPUAllocFloatStatus()\n",
    "        self.get_status = ops.NPUGetFloatStatus()\n",
    "        self.clear_status = ops.NPUClearFloatStatus()\n",
    "\n",
    "    @mindspore.jit\n",
    "    def construct(self, x):\n",
    "        init = self.alloc_status()\n",
    "        clear_status = self.clear_status(init)\n",
    "        x = ops.Depend()(x, clear_status)\n",
    "        res = ops.sub(x, ops.neg(x))\n",
    "        init = ops.Depend()(init, res)\n",
    "        get_status = self.get_status(init)\n",
    "        res = ops.Depend()(res, get_status)\n",
    "        return res\n",
    "\n",
    "value = 5\n",
    "data = np.full((2, 3), value, dtype=np.float16)\n",
    "x = mindspore.tensor(data, dtype=mindspore.float16)\n",
    "net = Net()\n",
    "res = net(x)\n",
    "print(res)\n",
    "```\n",
    "\n",
    "```text\n",
    "[[10. 10. 10.]\n",
    " [10. 10. 10.]]\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d14746ba",
   "metadata": {},
   "source": [
    "## 优化冗余显存拷贝操作\n",
    "\n",
    "在函数式编程中，通过参数和返回值之外的渠道和外界进行数据交换的函数，被称为非纯函数，被认为是存在副作用的。在MindSpore框架内部，针对副作用的问题会插入Load算子，该算子属于虚拟算子，不需要在后端执行，不占用显存，仅用于表示需要读取全局变量的值。在图模式下，需要编译完整个图之后才将图中的各个算子下发到后端执行。使用Load算子多次读取全局变量，而不是多次使用真实算子多保存全局变量的值，这样可以减少显存消耗。\n",
    "\n",
    "但是，全局变量的值可能是变化的，如果没有真实算子保存值，某些场景下会出现精度问题。针对这种情况，MindSpore框架内部会插入真实算子，占用一定的显存来保存全局变量的值，从而避免精度问题。\n",
    "\n",
    "MindSpore提供了MS_DEV_SIDE_EFFECT_LOAD_ELIM开关来优化显存占用程度，可以通过设置export MS_DEV_SIDE_EFFECT_LOAD_ELIM=0/1/2/3进行调整。设置规则如下：\n",
    "\n",
    "- 当设置为0时，表示对框架内部的Load算子都插入真实算子，显存占用最多，保证网络精度。\n",
    "- 当设置为1或者没有设置值（即默认模式）时，表示仅在可能出现精度问题的场景保守地插入真实算子，保证网络精度。\n",
    "- 当设置为2时，在损耗一定编译性能的前提下，尽量少地插入真实算子，优化显存占用，且保证网络精度。\n",
    "- 当设置为3时，不插入真实算子，不保证网络精度，显存消耗最少。\n",
    "\n",
    "我们可以通过用例和生成的中间表示（IR）来进一步理解。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "05537350",
   "metadata": {},
   "source": [
    "```python\n",
    "import numpy as np\n",
    "import mindspore\n",
    "from mindspore import nn, ops, Tensor, Parameter\n",
    "\n",
    "class ForwardNet(nn.Cell):\n",
    "    def __init__(self):\n",
    "        super(ForwardNet, self).__init__()\n",
    "        self.weight = Parameter(mindspore.tensor(np.array(0), mindspore.int32), name=\"param\")\n",
    "\n",
    "    def construct(self, x):\n",
    "        out = 0\n",
    "        i = 0\n",
    "        while i < 3:\n",
    "            ops.assign(self.weight, i)\n",
    "            out = x * self.weight + out\n",
    "            i = i + 1\n",
    "        return out\n",
    "\n",
    "class BackwardNet(nn.Cell):\n",
    "    def __init__(self, net):\n",
    "        super(BackwardNet, self).__init__(auto_prefix=False)\n",
    "        self.forward_net = net\n",
    "        self.grad = ops.GradOperation(get_all=True)\n",
    "\n",
    "    @mindspore.jit\n",
    "    def construct(self, *inputs):\n",
    "        grads = self.grad(self.forward_net)(*inputs)\n",
    "        return grads\n",
    "\n",
    "x = mindspore.tensor(np.array(1), mindspore.int32)\n",
    "graph_forward_net = ForwardNet()\n",
    "graph_backward_net = BackwardNet(graph_forward_net)\n",
    "graph_mode_grads = graph_backward_net(x)\n",
    "output_except = (mindspore.tensor(np.array(3), mindspore.int32),)\n",
    "assert np.all(graph_mode_grads == output_except)\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "60c7ca19",
   "metadata": {},
   "source": [
    "如上用例，通过设置export MS_DEV_SAVE_GRAPHS=1保存中间文件，可以得到中间文件IR。为了便于查看，我们将得到的中间文件简化如下：\n",
    "\n",
    "在未对框架内部的Load算子插入真实算子时，IR文件如下。可以看到存在3个Load算子，分别获取para2_param全局变量在不同时间点的值，而该变量会被Assign算子修改。即3个Load获取的值是不同的。如果我们没有对Load算子插入真实算子，即未保存para2_param全局变量在不同时间点的值，则最终结果是不对的。此时MS_DEV_SIDE_EFFECT_LOAD_ELIM设置为3，内存占用最少，但结果存在精度问题。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "08ae4853",
   "metadata": {},
   "source": [
    "```text\n",
    "# IR entry: @BackwardNet_construct\n",
    "# Total subgraphs: 1\n",
    "# Total params: 2\n",
    "# Params:\n",
    "%para1_inputs0 : <Tensor[Int32], ()>\n",
    "%para2_param : <Ref[Tensor[Int32]], (), ref_key=:param>  :  has_default\n",
    "\n",
    "subgraph @BackwardNet_construct() {\n",
    "  %0 = Assign(%para2_param, Tensor(shape=[], dtype=Int32, value=0), U)\n",
    "  %1 = UpdateState(U, %0)\n",
    "  %2 = Load(%para2_param, %1)\n",
    "  %3 = UpdateState(%1, %2)\n",
    "  %4 = Assign(%para2_param, Tensor(shape=[], dtype=Int32, value=1), %3)\n",
    "  %5 = UpdateState(%3, %4)\n",
    "  %6 = Load(%para2_param, %5)\n",
    "  %7 = UpdateState(%5, %6)\n",
    "  %8 = Assign(%para2_param, Tensor(shape=[], dtype=Int32, value=2), %7)\n",
    "  %9 = UpdateState(%7, %8)\n",
    "  %10 = Load(%para2_param, %9)\n",
    "  %11 = MakeTuple(%10, %6)\n",
    "  %12 = AddN(%11)\n",
    "  %13 = MakeTuple(%12, %2)\n",
    "  %14 = AddN(%13)\n",
    "  %15 = MakeTuple(%14)\n",
    "  %16 = UpdateState(%9, %10)\n",
    "  %17 = Depend(%15, %16)\n",
    "  Return(%17)\n",
    "}\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e7378662",
   "metadata": {},
   "source": [
    "将MS_DEV_SIDE_EFFECT_LOAD_ELIM设置为0、1或2时，可以得到简化后的IR图如下。由于该场景中的Load算子均需要插入真实算子来保存每次Assign算子修改后的值，因此设置为0、1、2时得到的IR文件是一致的。更多复杂的情况下，设置为0、1、2时可能是不同的，在此不再一一展开。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d0a3d96c",
   "metadata": {},
   "source": [
    "```text\n",
    "# IR entry: @BackwardNet_construct\n",
    "# Total subgraphs: 1\n",
    "# Total params: 2\n",
    "# Params:\n",
    "%para1_inputs0 : <Tensor[Int32], ()>\n",
    "%para2_param : <Ref[Tensor[Int32]], (), ref_key=:param>  :  has_default\n",
    "\n",
    "subgraph @BackwardNet_construct() {\n",
    "  %0 = Assign(%para2_param, Tensor(shape=[], dtype=Int32, value=0), U)\n",
    "  %1 = UpdateState(U, %0)\n",
    "  %2 = Load(%para2_param, %1)\n",
    "  %3 = TensorMove(%2)\n",
    "  %4 = UpdateState(%1, %3)\n",
    "  %5 = Assign(%para2_param, Tensor(shape=[], dtype=Int32, value=1), %4)\n",
    "  %6 = UpdateState(%4, %5)\n",
    "  %7 = Load(%para2_param, %6)\n",
    "  %8 = TensorMove(%7)\n",
    "  %9 = UpdateState(%6, %8)\n",
    "  %10 = Assign(%para2_param, Tensor(shape=[], dtype=Int32, value=2), %9)\n",
    "  %11 = UpdateState(%9, %10)\n",
    "  %12 = Load(%para2_param, %11)\n",
    "  %13 = TensorMove(%12)\n",
    "  %14 = MakeTuple(%13, %8)\n",
    "  %15 = AddN(%14)\n",
    "  %16 = MakeTuple(%15, %3)\n",
    "  %17 = AddN(%16)\n",
    "  %18 = MakeTuple(%17)\n",
    "  %19 = UpdateState(%11, %13, %15)\n",
    "  %20 = Depend(%18, %19)\n",
    "  Return(%20)\n",
    "}\n",
    "```"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "MindSpore",
   "language": "python",
   "name": "mindspore"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
