{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "be874a9a",
   "metadata": {},
   "source": [
    "# 图模式语法-python语句\n",
    "\n",
    "[![下载Notebook](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.9.0/resource/_static/logo_notebook.svg)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/r2.9.0/tutorials/zh_cn/compile/mindspore_statements.ipynb)&emsp;[![下载样例代码](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.9.0/resource/_static/logo_download_code.svg)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/r2.9.0/tutorials/zh_cn/compile/mindspore_statements.py)&emsp;[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.9.0/resource/_static/logo_source.svg)](https://atomgit.com/mindspore/docs/blob/r2.9.0/tutorials/source_zh_cn/compile/statements.ipynb)\n",
    "\n",
    "## 简单语句\n",
    "\n",
    "### raise语句\n",
    "\n",
    "支持使用`raise`触发异常。`raise`语法格式：`raise[Exception [, args]]`。语句中的`Exception`是异常的类型，`args`是用户提供的异常参数，通常可以是字符串或者其他对象。\n",
    "\n",
    "目前支持的异常类型有：NoExceptionType、UnknownError、ArgumentError、NotSupportError、NotExistsError、DeviceProcessError、AbortedError、IndexError、ValueError、TypeError、KeyError、AttributeError、NameError、AssertionError、BaseException、KeyboardInterrupt、Exception、StopIteration、OverflowError、ZeroDivisionError、EnvironmentError、IOError、OSError、ImportError、MemoryError、UnboundLocalError、RuntimeError、NotImplementedError、IndentationError、RuntimeWarning。\n",
    "\n",
    "图模式下的raise语法不支持`Dict`类型的变量。\n",
    "\n",
    "例如："
   ]
  },
  {
   "cell_type": "markdown",
   "id": "707aba68",
   "metadata": {},
   "source": [
    "```python\n",
    "import mindspore\n",
    "from mindspore import nn\n",
    "\n",
    "class Net(nn.Cell):\n",
    "    def __init__(self):\n",
    "        super(Net, self).__init__()\n",
    "\n",
    "    @mindspore.jit\n",
    "    def construct(self, x, y):\n",
    "        if x <= y:\n",
    "            raise ValueError(\"x should be greater than y.\")\n",
    "        else:\n",
    "            x += 1\n",
    "        return x\n",
    "\n",
    "net = Net()\n",
    "net(mindspore.tensor(-2), mindspore.tensor(-1))\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4cf448c1",
   "metadata": {},
   "source": [
    "输出结果:\n",
    "\n",
    "ValueError: x should be greater than y."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e166a75d",
   "metadata": {},
   "source": [
    "### assert语句\n",
    "\n",
    "支持使用assert来做异常检查，`assert`语法格式：`assert[Expression [, args]]`。其中`Expression`是判断条件，如果条件为真，就不做任何事情；条件为假时，则将抛出`AssertError`类型的异常信息。`args`是用户提供的异常参数，通常可以是字符串或者其他对象。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "622c2534",
   "metadata": {},
   "source": [
    "```python\n",
    "import mindspore\n",
    "from mindspore import nn\n",
    "\n",
    "class Net(nn.Cell):\n",
    "    def __init__(self):\n",
    "        super(Net, self).__init__()\n",
    "\n",
    "    @mindspore.jit\n",
    "    def construct(self, x):\n",
    "        assert x in [2, 3, 4]\n",
    "        return x\n",
    "\n",
    "net = Net()\n",
    "net(mindspore.tensor(-1))\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "df8c7712",
   "metadata": {},
   "source": [
    "输出结果中正常出现:\n",
    "\n",
    "AssertionError."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9de947c0",
   "metadata": {},
   "source": [
    "### pass语句\n",
    "\n",
    "`pass`语句不做任何事情，通常用于占位，保持结构的完整性。例如："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "66f4171b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ret: 50.625\n"
     ]
    }
   ],
   "source": [
    "import mindspore\n",
    "from mindspore import nn\n",
    "\n",
    "class Net(nn.Cell):\n",
    "    @mindspore.jit\n",
    "    def construct(self, x):\n",
    "        i = 0\n",
    "        while i < 5:\n",
    "            if i > 3:\n",
    "                pass\n",
    "            else:\n",
    "                x = x * 1.5\n",
    "            i += 1\n",
    "        return x\n",
    "\n",
    "net = Net()\n",
    "ret = net(10)\n",
    "print(\"ret:\", ret)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cbd8719e",
   "metadata": {},
   "source": [
    "### return语句\n",
    "\n",
    "`return`语句通常是将结果返回调用的地方，`return`语句之后的语句不被执行。如果返回语句没有任何表达式，或者函数没有`return`语句，则默认返回一个`None`对象。一个函数体内可以根据不同的情况有多个`return`语句。例如："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "ad2f05a9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ret: 10\n"
     ]
    }
   ],
   "source": [
    "import mindspore\n",
    "from mindspore import nn\n",
    "\n",
    "class Net(nn.Cell):\n",
    "    @mindspore.jit\n",
    "    def construct(self, x):\n",
    "        if x > 0:\n",
    "            return x\n",
    "        return 0\n",
    "\n",
    "net = Net()\n",
    "ret = net(10)\n",
    "print(\"ret:\", ret)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d83cb8b0",
   "metadata": {},
   "source": [
    "如上，在控制流场景语句中，可以有多个`return`语句。如果一个函数中没有`return`语句，则默认返回None对象，如下用例："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "a2ad7018",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "x:\n",
      "3\n"
     ]
    }
   ],
   "source": [
    "import mindspore\n",
    "\n",
    "mindspore.set_device(\"CPU\")\n",
    "@mindspore.jit\n",
    "def foo():\n",
    "    x = 3\n",
    "    print(\"x:\", x)\n",
    "\n",
    "res = foo()\n",
    "assert res is None"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "71fcf4e1",
   "metadata": {},
   "source": [
    "### break语句\n",
    "\n",
    "`break`语句用来终止循环语句，即循环条件没有`False`条件或者序列还没完全递归完时，也会停止执行循环语句，通常用在`while`和`for`循环中。在嵌套循环中，`break`语句将停止执行最内层的循环。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "7826f8f0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ret: 1920\n"
     ]
    }
   ],
   "source": [
    "import mindspore\n",
    "from mindspore import nn\n",
    "\n",
    "class Net(nn.Cell):\n",
    "    @mindspore.jit\n",
    "    def construct(self, x):\n",
    "        for i in range(8):\n",
    "            if i > 5:\n",
    "                x *= 3\n",
    "                break\n",
    "            x = x * 2\n",
    "        return x\n",
    "\n",
    "net = Net()\n",
    "ret = net(10)\n",
    "print(\"ret:\", ret)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5eb5f172",
   "metadata": {},
   "source": [
    "### continue语句\n",
    "\n",
    "`continue`语句用来跳出当前的循环语句，进入下一轮的循环。与`break`语句有所不同，`break`语句用来终止整个循环语句。`continue`也用在`while`和`for`循环中。例如："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "8b417f48",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ret: 9\n"
     ]
    }
   ],
   "source": [
    "import mindspore\n",
    "from mindspore import nn\n",
    "\n",
    "class Net(nn.Cell):\n",
    "    @mindspore.jit\n",
    "    def construct(self, x):\n",
    "        for i in range(4):\n",
    "            if i > 2:\n",
    "                x *= 3\n",
    "            continue\n",
    "        return x\n",
    "\n",
    "net = Net()\n",
    "ret = net(3)\n",
    "print(\"ret:\", ret)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d7a65be9",
   "metadata": {},
   "source": [
    "## 复合语句\n",
    "\n",
    "### 条件控制语句\n",
    "\n",
    "#### if语句\n",
    "\n",
    "使用方式：\n",
    "\n",
    "- `if (cond): statements...`\n",
    "\n",
    "- `x = y if (cond) else z`\n",
    "\n",
    "参数：`cond` - 支持`bool`类型的变量，也支持类型为`Number`、`List`、`Tuple`、`Dict`、`String`类型的常量以及`None`对象。\n",
    "\n",
    "限制：\n",
    "\n",
    "- 如果`cond`不为常量，在不同分支中同一符号被赋予的变量或者常量的数据类型应一致。如果是被赋予变量或者常量数据类型是`Tensor`，则要求`Tensor`的type和shape也应一致。\n",
    "\n",
    "- 图模式中，要求变量必须在使用前定义。在控制流内部定义，外部使用将会报错。例如示例4。\n",
    "\n",
    "示例1："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bed905ff",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ret:1\n"
     ]
    }
   ],
   "source": [
    "import mindspore\n",
    "\n",
    "x = mindspore.tensor([1, 4], mindspore.int32)\n",
    "y = mindspore.tensor([0, 3], mindspore.int32)\n",
    "m = 1\n",
    "n = 2\n",
    "\n",
    "@mindspore.jit\n",
    "def test_if_cond(x, y):\n",
    "    if (x > y).any():\n",
    "        return m\n",
    "    return n\n",
    "\n",
    "ret = test_if_cond(x, y)\n",
    "print('ret:{}'.format(ret))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c0fe1b35",
   "metadata": {},
   "source": [
    "`if`分支返回的`m`和`else`分支返回的`n`，二者数据类型必须一致。\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0aec3af1",
   "metadata": {},
   "source": [
    "示例2："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "97c44982",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ret:1\n"
     ]
    }
   ],
   "source": [
    "import mindspore\n",
    "\n",
    "x = mindspore.tensor([1, 4], mindspore.int32)\n",
    "y = mindspore.tensor([0, 3], mindspore.int32)\n",
    "m = 1\n",
    "n = 2\n",
    "\n",
    "@mindspore.jit\n",
    "def test_if_cond(x, y):\n",
    "    out = 3\n",
    "    if (x > y).any():\n",
    "        out = m\n",
    "    else:\n",
    "        out = n\n",
    "    return out\n",
    "\n",
    "ret = test_if_cond(x, y)\n",
    "print('ret:{}'.format(ret))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d2148097",
   "metadata": {},
   "source": [
    "`if`分支中`out`被赋值的变量或者常量`m`与`else`分支中`out`被赋值的变量或者常量`n`的数据类型必须一致。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "757452c0",
   "metadata": {},
   "source": [
    "示例3："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2c6cf0b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ret:1\n"
     ]
    }
   ],
   "source": [
    "import mindspore\n",
    "\n",
    "x = mindspore.tensor([1, 4], mindspore.int32)\n",
    "y = mindspore.tensor([0, 3], mindspore.int32)\n",
    "m = 1\n",
    "\n",
    "@mindspore.jit\n",
    "def test_if_cond(x, y):\n",
    "    out = 2\n",
    "    if (x > y).any():\n",
    "        out = m\n",
    "    return out\n",
    "\n",
    "ret = test_if_cond(x, y)\n",
    "print('ret:{}'.format(ret))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "36ebd6a4",
   "metadata": {},
   "source": [
    "`if`分支中`out`被赋值的变量或者常量`m`与`out`初始赋值的数据类型必须一致。\n",
    "\n",
    "示例4：在控制流外部使用变量z，需要在外部定义。否则将报错：UnboundLocalError: The local variable 'z' is not defined in false branch, but defined in true branch."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b3c2df5b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import mindspore\n",
    "\n",
    "x = mindspore.tensor([1, 4], mindspore.int32)\n",
    "y = mindspore.tensor([0, 3], mindspore.int32)\n",
    "\n",
    "@mindspore.jit\n",
    "def test_if_cond(x, y):\n",
    "    if (x > y).any():\n",
    "        z = x + 1\n",
    "    return z\n",
    "\n",
    "ret = test_if_cond(x, y)\n",
    "print('ret:{}'.format(ret))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1676a957",
   "metadata": {},
   "source": [
    "### 循环语句\n",
    "\n",
    "#### for语句\n",
    "\n",
    "使用方式：\n",
    "\n",
    "- `for i in sequence  statements...`\n",
    "\n",
    "- `for i in sequence  statements... if (cond) break`\n",
    "\n",
    "- `for i in sequence  statements... if (cond) continue`\n",
    "\n",
    "参数：`sequence` - 遍历序列(`Tuple`、`List`、`range`等)\n",
    "\n",
    "限制：\n",
    "\n",
    "- 图的算子数量和`for`循环的迭代次数成倍数关系。`for`循环迭代次数过大可能会导致图占用内存超过使用限制。\n",
    "\n",
    "- 不支持`for...else...`语句。\n",
    "\n",
    "- 图模式中，要求变量必须在使用前定义。for循环后要使用的变量，在控制流内部定义，外部使用将会报错。例如示例2。\n",
    "\n",
    "示例1："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "72e9582e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ret:[[7. 7. 7.]\n",
      " [7. 7. 7.]]\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import mindspore\n",
    "\n",
    "z = mindspore.tensor(np.ones((2, 3)))\n",
    "\n",
    "@mindspore.jit\n",
    "def test_cond():\n",
    "    x = (1, 2, 3)\n",
    "    for i in x:\n",
    "        z += i\n",
    "    return z\n",
    "\n",
    "ret = test_cond()\n",
    "print('ret:{}'.format(ret))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5fe31ded",
   "metadata": {},
   "source": [
    "示例2：将报错：NameError: The name 'z' is not defined, or not supported in graph mode."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f0dc01c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import mindspore\n",
    "\n",
    "@mindspore.jit\n",
    "def test_cond():\n",
    "    x = (1, 2, 3)\n",
    "    for i in x:\n",
    "        z += i\n",
    "    return z\n",
    "\n",
    "ret = test_cond()\n",
    "print('ret:{}'.format(ret))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0b7a6630",
   "metadata": {},
   "source": [
    "#### while语句\n",
    "\n",
    "使用方式：\n",
    "\n",
    "- `while (cond)  statements...`\n",
    "\n",
    "- `while (cond)  statements... if (cond1) break`\n",
    "\n",
    "- `while (cond)  statements... if (cond1) continue`\n",
    "\n",
    "参数：`cond` - 支持`bool`类型的变量，也支持类型为`Number`、`List`、`Tuple`、`Dict`、`String`类型的常量以及`None`对象。\n",
    "\n",
    "限制：\n",
    "\n",
    "- 如果`cond`不为常量，在循环体内外同一符号被赋值的变量或者常量的数据类型应一致。如果是被赋予数据类型`Tensor`，则要求`Tensor`的type和shape也应一致。\n",
    "\n",
    "- 不支持`while...else...`语句。\n",
    "\n",
    "- 图模式中，要求变量必须在使用前定义。while循环后要使用的变量，在控制流内部定义，外部使用将会报错。例如示例3。\n",
    "\n",
    "示例1："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "76a689cf",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ret:1\n"
     ]
    }
   ],
   "source": [
    "import mindspore\n",
    "\n",
    "m = 1\n",
    "n = 2\n",
    "\n",
    "@mindspore.jit\n",
    "def test_cond(x, y):\n",
    "    while x < y:\n",
    "        x += 1\n",
    "        return m\n",
    "    return n\n",
    "\n",
    "ret = test_cond(1, 5)\n",
    "print('ret:{}'.format(ret))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "859fc7b9",
   "metadata": {},
   "source": [
    "`while`循环内返回的`m`和`while`外返回的`n`数据类型必须一致。\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e692d457",
   "metadata": {},
   "source": [
    "示例2："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a70a6364",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ret:15\n"
     ]
    }
   ],
   "source": [
    "import mindspore\n",
    "\n",
    "m = 1\n",
    "n = 2\n",
    "\n",
    "def ops1(a, b):\n",
    "    return a + b\n",
    "\n",
    "@mindspore.jit\n",
    "def test_cond(x, y):\n",
    "    out = m\n",
    "    while x < y:\n",
    "        x += 1\n",
    "        out = ops1(out, x)\n",
    "    return out\n",
    "\n",
    "ret = test_cond(1, 5)\n",
    "print('ret:{}'.format(ret))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0a4b0e80",
   "metadata": {},
   "source": [
    "`while`内，`out`在循环体内被赋值的变量`ops1`的输出类型和初始类型`m`必须一致。\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b94b195a",
   "metadata": {},
   "source": [
    "示例3：",
    "\n",
    "将报错：UnboundLocalError: The local variable 'z' defined in the 'while' loop body cannot be used outside of the loop body. Please define variable 'z' before 'while'.\n",
    "\n",
    "```python\n",
    "import mindspore as ms\n",
    "\n",
    "@ms.jit()\n",
    "def test_cond(x, y):\n",
    "    while x < y:\n",
    "        x += 1\n",
    "        z = x + y\n",
    "    return z\n",
    "\n",
    "ret = test_cond(1, 5)\n",
    "print('ret:{}'.format(ret))\n",
    "```\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "277d411d",
   "metadata": {},
   "source": [
    "### 函数定义语句\n",
    "\n",
    "#### def关键字\n",
    "\n",
    "`def`用于定义函数，后接函数标识符名称和圆括号`()`，括号中可以包含函数的参数。\n",
    "使用方式：`def function_name(args): statements...`。\n",
    "\n",
    "示例如下："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1839f820",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ret:6\n"
     ]
    }
   ],
   "source": [
    "import mindspore\n",
    "\n",
    "def number_add(x, y):\n",
    "    return x + y\n",
    "\n",
    "@mindspore.jit\n",
    "def test(x, y):\n",
    "    return number_add(x, y)\n",
    "\n",
    "ret = test(1, 5)\n",
    "print('ret:{}'.format(ret))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "da3ca19c",
   "metadata": {},
   "source": [
    "说明：\n",
    "\n",
    "- 函数可以支持不写返回值，不写返回值默认函数的返回值为None。\n",
    "- 支持最外层网络模型的`construct`函数和内层网络函数输入kwargs，即支持 `def construct(**kwargs):`。\n",
    "- 支持变参和非变参的混合使用，即支持 `def function(x, y, *args):`和 `def function(x = 1, y = 1, **kwargs):`。\n",
    "\n",
    "#### lambda表达式\n",
    "\n",
    "`lambda`表达式用于生成匿名函数。与普通函数不同，它只计算并返回一个表达式。使用方式：`lambda x, y: x + y`。\n",
    "\n",
    "示例如下："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2bdffe97",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ret:6\n"
     ]
    }
   ],
   "source": [
    "import mindspore\n",
    "\n",
    "@mindspore.jit\n",
    "def test(x, y):\n",
    "    number_add = lambda x, y: x + y\n",
    "    return number_add(x, y)\n",
    "\n",
    "ret = test(1, 5)\n",
    "print('ret:{}'.format(ret))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0313862b",
   "metadata": {},
   "source": [
    "#### 偏函数partial\n",
    "\n",
    "功能：偏函数，固定函数入参。使用方式：`partial(func, arg, ...)`。\n",
    "\n",
    "入参：\n",
    "\n",
    "- `func` - 函数。\n",
    "\n",
    "- `arg` - 一个或多个要固定的参数，支持位置参数和键值对传参。\n",
    "\n",
    "返回值：返回某些入参固定了值的函数。\n",
    "\n",
    "示例如下："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "84ebe8bd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "m:5\n",
      "n:7\n"
     ]
    }
   ],
   "source": [
    "import mindspore\n",
    "from mindspore import ops\n",
    "\n",
    "def add(x, y):\n",
    "    return x + y\n",
    "\n",
    "@mindspore.jit\n",
    "def test():\n",
    "    add_ = ops.partial(add, x=2)\n",
    "    m = add_(y=3)\n",
    "    n = add_(y=5)\n",
    "    return m, n\n",
    "\n",
    "m, n = test()\n",
    "print('m:{}'.format(m))\n",
    "print('n:{}'.format(n))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b76592f9",
   "metadata": {},
   "source": [
    "#### 函数参数\n",
    "\n",
    "- 参数默认值：目前不支持默认值设为`Tensor`类型数据，支持`int`、`float`、`bool`、`None`、`str`、`tuple`、`list`、`dict`类型数据。\n",
    "- 可变参数：支持带可变参数网络的推理和训练。\n",
    "- 键值对参数：目前不支持带键值对参数的函数求反向。\n",
    "- 可变键值对参数：目前不支持带可变键值对的函数求反向。\n",
    "\n",
    "### 列表生成式和生成器表达式\n",
    "\n",
    "支持列表生成式（List Comprehension）、字典生成式（Dict Comprehension）和生成器表达式（Generator Expression）。支持构建一个新的序列。\n",
    "\n",
    "#### 列表生成式\n",
    "\n",
    "列表生成式用于生成列表。使用方式：`[arg for loop if statements]`。\n",
    "\n",
    "示例如下："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a5e4a2d1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ret:[4, 16, 36, 64, 100]\n"
     ]
    }
   ],
   "source": [
    "import mindspore\n",
    "\n",
    "@mindspore.jit\n",
    "def test():\n",
    "    l = [x * x for x in range(1, 11) if x % 2 == 0]\n",
    "    return l\n",
    "\n",
    "ret = test()\n",
    "print('ret:{}'.format(ret))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5e004891",
   "metadata": {},
   "source": [
    "限制：\n",
    "\n",
    "图模式下不支持多层嵌套迭代器的使用方式。\n",
    "\n",
    "限制用法示例如下（使用了两层迭代器）：\n",
    "\n",
    "```python\n",
    "l = [y for x in ((1, 2), (3, 4), (5, 6)) for y in x]\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4630086d",
   "metadata": {},
   "source": "会提示错误：\n\nTypeError: The 'generators' supports 1 'comprehension' in ListComp/GeneratorExp, but got 2 comprehensions."
  },
  {
   "cell_type": "markdown",
   "id": "7a449a28",
   "metadata": {},
   "source": [
    "#### 字典生成式\n",
    "\n",
    "字典生成式用于生成字典。使用方式：`{key, value for loop if statements}`。\n",
    "\n",
    "示例如下："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9c604b7f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ret:{'b': 2, 'c': 3}\n"
     ]
    }
   ],
   "source": [
    "import mindspore\n",
    "\n",
    "@mindspore.jit\n",
    "def test():\n",
    "    x = [('a', 1), ('b', 2), ('c', 3)]\n",
    "    res = {k: v for (k, v) in x if v > 1}\n",
    "    return res\n",
    "\n",
    "ret = test()\n",
    "print('ret:{}'.format(ret))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "825328b2",
   "metadata": {},
   "source": [
    "限制：\n",
    "\n",
    "图模式下不支持多层嵌套迭代器的使用方式。\n",
    "\n",
    "限制用法示例如下（使用了两层迭代器）："
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2858c9cf",
   "metadata": {},
   "source": [
    "```python\n",
    "import mindspore\n",
    "\n",
    "@mindspore.jit\n",
    "def test():\n",
    "    x = ({'a': 1, 'b': 2}, {'d': 1, 'e': 2}, {'g': 1, 'h': 2})\n",
    "    res = {k: v for y in x for (k, v) in y.items()}\n",
    "    return res\n",
    "\n",
    "ret = test()\n",
    "print('ret:{}'.format(ret))\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1393a338",
   "metadata": {},
   "source": "会提示错误：\n\nTypeError: The 'generators' supports 1 'comprehension' in DictComp/GeneratorExp, but got 2 comprehensions."
  },
  {
   "cell_type": "markdown",
   "id": "498fef2f",
   "metadata": {},
   "source": [
    "#### 生成器表达式\n",
    "\n",
    "生成器表达式用于生成列表。使用方式：`(arg for loop if statements)`。\n",
    "\n",
    "示例如下："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3c2e02d3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ret:[4, 16, 36, 64, 100]\n"
     ]
    }
   ],
   "source": [
    "import mindspore\n",
    "\n",
    "@mindspore.jit\n",
    "def test():\n",
    "    l = (x * x for x in range(1, 11) if x % 2 == 0)\n",
    "    return l\n",
    "\n",
    "ret = test()\n",
    "print('ret:{}'.format(ret))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4a709cb5",
   "metadata": {},
   "source": [
    "使用限制同列表生成式。即：图模式下不支持多层嵌套迭代器的使用方式。\n",
    "\n",
    "### with语句\n",
    "\n",
    "在图模式下，有限制地支持`with`语句。`with`语句要求对象必须有两个魔术方法：`__enter__()`和`__exit__()`。\n",
    "\n",
    "值得注意的是，with语句中使用的类需要有装饰器@ms.jit_class修饰或者继承于nn.Cell，更多介绍可见[使用jit_class](https://www.mindspore.cn/tutorials/zh-CN/r2.9.0/compile/static_graph_expert_programming.html#使用jit-class)。\n",
    "\n",
    "示例如下："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "7cb1da85",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "out1: [5]\n",
      "out2: [2]\n"
     ]
    }
   ],
   "source": [
    "import mindspore\n",
    "from mindspore import nn\n",
    "\n",
    "@mindspore.jit_class\n",
    "class Sample:\n",
    "    def __init__(self):\n",
    "        super(Sample, self).__init__()\n",
    "        self.num = mindspore.tensor([2])\n",
    "\n",
    "    def __enter__(self):\n",
    "        return self.num * 2\n",
    "\n",
    "    def __exit__(self, exc_type, exc_value, traceback):\n",
    "        return self.num * 4\n",
    "\n",
    "class TestNet(nn.Cell):\n",
    "    @mindspore.jit\n",
    "    def construct(self):\n",
    "        res = 1\n",
    "        obj = Sample()\n",
    "        with obj as sample:\n",
    "            res += sample\n",
    "        return res, obj.num\n",
    "\n",
    "test_net = TestNet()\n",
    "out1, out2 = test_net()\n",
    "print(\"out1:\", out1)\n",
    "print(\"out2:\", out2)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "MindSpore",
   "language": "python",
   "name": "mindspore"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}