#!/usr/bin/env python
# coding: utf-8

# # 图模式语法-python语句
# 
# [![下载Notebook](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.9.0/resource/_static/logo_notebook.svg)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/r2.9.0/tutorials/zh_cn/compile/mindspore_statements.ipynb)&emsp;[![下载样例代码](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.9.0/resource/_static/logo_download_code.svg)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/r2.9.0/tutorials/zh_cn/compile/mindspore_statements.py)&emsp;[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.9.0/resource/_static/logo_source.svg)](https://atomgit.com/mindspore/docs/blob/r2.9.0/tutorials/source_zh_cn/compile/statements.ipynb)
# 
# ## 简单语句
# 
# ### raise语句
# 
# 支持使用`raise`触发异常。`raise`语法格式：`raise[Exception [, args]]`。语句中的`Exception`是异常的类型，`args`是用户提供的异常参数，通常可以是字符串或者其他对象。
# 
# 目前支持的异常类型有：NoExceptionType、UnknownError、ArgumentError、NotSupportError、NotExistsError、DeviceProcessError、AbortedError、IndexError、ValueError、TypeError、KeyError、AttributeError、NameError、AssertionError、BaseException、KeyboardInterrupt、Exception、StopIteration、OverflowError、ZeroDivisionError、EnvironmentError、IOError、OSError、ImportError、MemoryError、UnboundLocalError、RuntimeError、NotImplementedError、IndentationError、RuntimeWarning。
# 
# 图模式下的raise语法不支持`Dict`类型的变量。
# 
# 例如：

# ```python
# import mindspore
# from mindspore import nn
# 
# class Net(nn.Cell):
#     def __init__(self):
#         super(Net, self).__init__()
# 
#     @mindspore.jit
#     def construct(self, x, y):
#         if x <= y:
#             raise ValueError("x should be greater than y.")
#         else:
#             x += 1
#         return x
# 
# net = Net()
# net(mindspore.tensor(-2), mindspore.tensor(-1))
# ```

# 输出结果:
# 
# ValueError: x should be greater than y.

# ### assert语句
# 
# 支持使用assert来做异常检查，`assert`语法格式：`assert[Expression [, args]]`。其中`Expression`是判断条件，如果条件为真，就不做任何事情；条件为假时，则将抛出`AssertError`类型的异常信息。`args`是用户提供的异常参数，通常可以是字符串或者其他对象。

# ```python
# import mindspore
# from mindspore import nn
# 
# class Net(nn.Cell):
#     def __init__(self):
#         super(Net, self).__init__()
# 
#     @mindspore.jit
#     def construct(self, x):
#         assert x in [2, 3, 4]
#         return x
# 
# net = Net()
# net(mindspore.tensor(-1))
# ```

# 输出结果中正常出现:
# 
# AssertionError.

# ### pass语句
# 
# `pass`语句不做任何事情，通常用于占位，保持结构的完整性。例如：

# In[5]:


import mindspore
from mindspore import nn

class Net(nn.Cell):
    @mindspore.jit
    def construct(self, x):
        i = 0
        while i < 5:
            if i > 3:
                pass
            else:
                x = x * 1.5
            i += 1
        return x

net = Net()
ret = net(10)
print("ret:", ret)


# ### return语句
# 
# `return`语句通常是将结果返回调用的地方，`return`语句之后的语句不被执行。如果返回语句没有任何表达式，或者函数没有`return`语句，则默认返回一个`None`对象。一个函数体内可以根据不同的情况有多个`return`语句。例如：

# In[6]:


import mindspore
from mindspore import nn

class Net(nn.Cell):
    @mindspore.jit
    def construct(self, x):
        if x > 0:
            return x
        return 0

net = Net()
ret = net(10)
print("ret:", ret)


# 如上，在控制流场景语句中，可以有多个`return`语句。如果一个函数中没有`return`语句，则默认返回None对象，如下用例：

# In[7]:


import mindspore

mindspore.set_device("CPU")
@mindspore.jit
def foo():
    x = 3
    print("x:", x)

res = foo()
assert res is None


# ### break语句
# 
# `break`语句用来终止循环语句，即循环条件没有`False`条件或者序列还没完全递归完时，也会停止执行循环语句，通常用在`while`和`for`循环中。在嵌套循环中，`break`语句将停止执行最内层的循环。

# In[8]:


import mindspore
from mindspore import nn

class Net(nn.Cell):
    @mindspore.jit
    def construct(self, x):
        for i in range(8):
            if i > 5:
                x *= 3
                break
            x = x * 2
        return x

net = Net()
ret = net(10)
print("ret:", ret)


# ### continue语句
# 
# `continue`语句用来跳出当前的循环语句，进入下一轮的循环。与`break`语句有所不同，`break`语句用来终止整个循环语句。`continue`也用在`while`和`for`循环中。例如：

# In[9]:


import mindspore
from mindspore import nn

class Net(nn.Cell):
    @mindspore.jit
    def construct(self, x):
        for i in range(4):
            if i > 2:
                x *= 3
            continue
        return x

net = Net()
ret = net(3)
print("ret:", ret)


# ## 复合语句
# 
# ### 条件控制语句
# 
# #### if语句
# 
# 使用方式：
# 
# - `if (cond): statements...`
# 
# - `x = y if (cond) else z`
# 
# 参数：`cond` - 支持`bool`类型的变量，也支持类型为`Number`、`List`、`Tuple`、`Dict`、`String`类型的常量以及`None`对象。
# 
# 限制：
# 
# - 如果`cond`不为常量，在不同分支中同一符号被赋予的变量或者常量的数据类型应一致。如果是被赋予变量或者常量数据类型是`Tensor`，则要求`Tensor`的type和shape也应一致。
# 
# - 图模式中，要求变量必须在使用前定义。在控制流内部定义，外部使用将会报错。例如示例4。
# 
# 示例1：

# In[ ]:


import mindspore

x = mindspore.tensor([1, 4], mindspore.int32)
y = mindspore.tensor([0, 3], mindspore.int32)
m = 1
n = 2

@mindspore.jit
def test_if_cond(x, y):
    if (x > y).any():
        return m
    return n

ret = test_if_cond(x, y)
print('ret:{}'.format(ret))


# `if`分支返回的`m`和`else`分支返回的`n`，二者数据类型必须一致。
# 
# 

# 示例2：

# In[ ]:


import mindspore

x = mindspore.tensor([1, 4], mindspore.int32)
y = mindspore.tensor([0, 3], mindspore.int32)
m = 1
n = 2

@mindspore.jit
def test_if_cond(x, y):
    out = 3
    if (x > y).any():
        out = m
    else:
        out = n
    return out

ret = test_if_cond(x, y)
print('ret:{}'.format(ret))


# `if`分支中`out`被赋值的变量或者常量`m`与`else`分支中`out`被赋值的变量或者常量`n`的数据类型必须一致。

# 示例3：

# In[ ]:


import mindspore

x = mindspore.tensor([1, 4], mindspore.int32)
y = mindspore.tensor([0, 3], mindspore.int32)
m = 1

@mindspore.jit
def test_if_cond(x, y):
    out = 2
    if (x > y).any():
        out = m
    return out

ret = test_if_cond(x, y)
print('ret:{}'.format(ret))


# `if`分支中`out`被赋值的变量或者常量`m`与`out`初始赋值的数据类型必须一致。
# 
# 示例4：在控制流外部使用变量z，需要在外部定义。否则将报错：UnboundLocalError: The local variable 'z' is not defined in false branch, but defined in true branch.

# In[ ]:


import mindspore

x = mindspore.tensor([1, 4], mindspore.int32)
y = mindspore.tensor([0, 3], mindspore.int32)

@mindspore.jit
def test_if_cond(x, y):
    if (x > y).any():
        z = x + 1
    return z

ret = test_if_cond(x, y)
print('ret:{}'.format(ret))


# ### 循环语句
# 
# #### for语句
# 
# 使用方式：
# 
# - `for i in sequence  statements...`
# 
# - `for i in sequence  statements... if (cond) break`
# 
# - `for i in sequence  statements... if (cond) continue`
# 
# 参数：`sequence` - 遍历序列(`Tuple`、`List`、`range`等)
# 
# 限制：
# 
# - 图的算子数量和`for`循环的迭代次数成倍数关系。`for`循环迭代次数过大可能会导致图占用内存超过使用限制。
# 
# - 不支持`for...else...`语句。
# 
# - 图模式中，要求变量必须在使用前定义。for循环后要使用的变量，在控制流内部定义，外部使用将会报错。例如示例2。
# 
# 示例1：

# In[ ]:


import numpy as np
import mindspore

z = mindspore.tensor(np.ones((2, 3)))

@mindspore.jit
def test_cond():
    x = (1, 2, 3)
    for i in x:
        z += i
    return z

ret = test_cond()
print('ret:{}'.format(ret))


# 示例2：将报错：NameError: The name 'z' is not defined, or not supported in graph mode.

# In[ ]:


import mindspore

@mindspore.jit
def test_cond():
    x = (1, 2, 3)
    for i in x:
        z += i
    return z

ret = test_cond()
print('ret:{}'.format(ret))


# #### while语句
# 
# 使用方式：
# 
# - `while (cond)  statements...`
# 
# - `while (cond)  statements... if (cond1) break`
# 
# - `while (cond)  statements... if (cond1) continue`
# 
# 参数：`cond` - 支持`bool`类型的变量，也支持类型为`Number`、`List`、`Tuple`、`Dict`、`String`类型的常量以及`None`对象。
# 
# 限制：
# 
# - 如果`cond`不为常量，在循环体内外同一符号被赋值的变量或者常量的数据类型应一致。如果是被赋予数据类型`Tensor`，则要求`Tensor`的type和shape也应一致。
# 
# - 不支持`while...else...`语句。
# 
# - 图模式中，要求变量必须在使用前定义。while循环后要使用的变量，在控制流内部定义，外部使用将会报错。例如示例3。
# 
# 示例1：

# In[ ]:


import mindspore

m = 1
n = 2

@mindspore.jit
def test_cond(x, y):
    while x < y:
        x += 1
        return m
    return n

ret = test_cond(1, 5)
print('ret:{}'.format(ret))


# `while`循环内返回的`m`和`while`外返回的`n`数据类型必须一致。
# 
# 

# 示例2：

# In[ ]:


import mindspore

m = 1
n = 2

def ops1(a, b):
    return a + b

@mindspore.jit
def test_cond(x, y):
    out = m
    while x < y:
        x += 1
        out = ops1(out, x)
    return out

ret = test_cond(1, 5)
print('ret:{}'.format(ret))


# `while`内，`out`在循环体内被赋值的变量`ops1`的输出类型和初始类型`m`必须一致。
# 
# 

# 示例3：
# 将报错：UnboundLocalError: The local variable 'z' defined in the 'while' loop body cannot be used outside of the loop body. Please define variable 'z' before 'while'.
# 
# ```python
# import mindspore as ms
# 
# @ms.jit()
# def test_cond(x, y):
#     while x < y:
#         x += 1
#         z = x + y
#     return z
# 
# ret = test_cond(1, 5)
# print('ret:{}'.format(ret))
# ```
# 
# 

# ### 函数定义语句
# 
# #### def关键字
# 
# `def`用于定义函数，后接函数标识符名称和圆括号`()`，括号中可以包含函数的参数。
# 使用方式：`def function_name(args): statements...`。
# 
# 示例如下：

# In[ ]:


import mindspore

def number_add(x, y):
    return x + y

@mindspore.jit
def test(x, y):
    return number_add(x, y)

ret = test(1, 5)
print('ret:{}'.format(ret))


# 说明：
# 
# - 函数可以支持不写返回值，不写返回值默认函数的返回值为None。
# - 支持最外层网络模型的`construct`函数和内层网络函数输入kwargs，即支持 `def construct(**kwargs):`。
# - 支持变参和非变参的混合使用，即支持 `def function(x, y, *args):`和 `def function(x = 1, y = 1, **kwargs):`。
# 
# #### lambda表达式
# 
# `lambda`表达式用于生成匿名函数。与普通函数不同，它只计算并返回一个表达式。使用方式：`lambda x, y: x + y`。
# 
# 示例如下：

# In[ ]:


import mindspore

@mindspore.jit
def test(x, y):
    number_add = lambda x, y: x + y
    return number_add(x, y)

ret = test(1, 5)
print('ret:{}'.format(ret))


# #### 偏函数partial
# 
# 功能：偏函数，固定函数入参。使用方式：`partial(func, arg, ...)`。
# 
# 入参：
# 
# - `func` - 函数。
# 
# - `arg` - 一个或多个要固定的参数，支持位置参数和键值对传参。
# 
# 返回值：返回某些入参固定了值的函数。
# 
# 示例如下：

# In[ ]:


import mindspore
from mindspore import ops

def add(x, y):
    return x + y

@mindspore.jit
def test():
    add_ = ops.partial(add, x=2)
    m = add_(y=3)
    n = add_(y=5)
    return m, n

m, n = test()
print('m:{}'.format(m))
print('n:{}'.format(n))


# #### 函数参数
# 
# - 参数默认值：目前不支持默认值设为`Tensor`类型数据，支持`int`、`float`、`bool`、`None`、`str`、`tuple`、`list`、`dict`类型数据。
# - 可变参数：支持带可变参数网络的推理和训练。
# - 键值对参数：目前不支持带键值对参数的函数求反向。
# - 可变键值对参数：目前不支持带可变键值对的函数求反向。
# 
# ### 列表生成式和生成器表达式
# 
# 支持列表生成式（List Comprehension）、字典生成式（Dict Comprehension）和生成器表达式（Generator Expression）。支持构建一个新的序列。
# 
# #### 列表生成式
# 
# 列表生成式用于生成列表。使用方式：`[arg for loop if statements]`。
# 
# 示例如下：

# In[ ]:


import mindspore

@mindspore.jit
def test():
    l = [x * x for x in range(1, 11) if x % 2 == 0]
    return l

ret = test()
print('ret:{}'.format(ret))


# 限制：
# 
# 图模式下不支持多层嵌套迭代器的使用方式。
# 
# 限制用法示例如下（使用了两层迭代器）：
# 
# ```python
# l = [y for x in ((1, 2), (3, 4), (5, 6)) for y in x]
# ```

# 会提示错误：
# 
# TypeError: The 'generators' supports 1 'comprehension' in ListComp/GeneratorExp, but got 2 comprehensions.

# #### 字典生成式
# 
# 字典生成式用于生成字典。使用方式：`{key, value for loop if statements}`。
# 
# 示例如下：

# In[ ]:


import mindspore

@mindspore.jit
def test():
    x = [('a', 1), ('b', 2), ('c', 3)]
    res = {k: v for (k, v) in x if v > 1}
    return res

ret = test()
print('ret:{}'.format(ret))


# 限制：
# 
# 图模式下不支持多层嵌套迭代器的使用方式。
# 
# 限制用法示例如下（使用了两层迭代器）：

# ```python
# import mindspore
# 
# @mindspore.jit
# def test():
#     x = ({'a': 1, 'b': 2}, {'d': 1, 'e': 2}, {'g': 1, 'h': 2})
#     res = {k: v for y in x for (k, v) in y.items()}
#     return res
# 
# ret = test()
# print('ret:{}'.format(ret))
# ```

# 会提示错误：
# 
# TypeError: The 'generators' supports 1 'comprehension' in DictComp/GeneratorExp, but got 2 comprehensions.

# #### 生成器表达式
# 
# 生成器表达式用于生成列表。使用方式：`(arg for loop if statements)`。
# 
# 示例如下：

# In[ ]:


import mindspore

@mindspore.jit
def test():
    l = (x * x for x in range(1, 11) if x % 2 == 0)
    return l

ret = test()
print('ret:{}'.format(ret))


# 使用限制同列表生成式。即：图模式下不支持多层嵌套迭代器的使用方式。
# 
# ### with语句
# 
# 在图模式下，有限制地支持`with`语句。`with`语句要求对象必须有两个魔术方法：`__enter__()`和`__exit__()`。
# 
# 值得注意的是，with语句中使用的类需要有装饰器@ms.jit_class修饰或者继承于nn.Cell，更多介绍可见[使用jit_class](https://www.mindspore.cn/tutorials/zh-CN/r2.9.0/compile/static_graph_expert_programming.html#使用jit-class)。
# 
# 示例如下：

# In[24]:


import mindspore
from mindspore import nn

@mindspore.jit_class
class Sample:
    def __init__(self):
        super(Sample, self).__init__()
        self.num = mindspore.tensor([2])

    def __enter__(self):
        return self.num * 2

    def __exit__(self, exc_type, exc_value, traceback):
        return self.num * 4

class TestNet(nn.Cell):
    @mindspore.jit
    def construct(self):
        res = 1
        obj = Sample()
        with obj as sample:
            res += sample
        return res, obj.num

test_net = TestNet()
out1, out2 = test_net()
print("out1:", out1)
print("out2:", out2)

