#!/usr/bin/env python
# coding: utf-8

# [![下载Notebook](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.9.0/resource/_static/logo_notebook.svg)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/r2.9.0/tutorials/zh_cn/beginner/mindspore_accelerate_with_static_graph.ipynb)&emsp;[![下载样例代码](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.9.0/resource/_static/logo_download_code.svg)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/r2.9.0/tutorials/zh_cn/beginner/mindspore_accelerate_with_static_graph.py)&emsp;[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.9.0/resource/_static/logo_source.svg)](https://atomgit.com/mindspore/docs/blob/r2.9.0/tutorials/source_zh_cn/beginner/accelerate_with_static_graph.ipynb)
# 
# [基本介绍](https://www.mindspore.cn/tutorials/zh-CN/r2.9.0/beginner/introduction.html) || [快速入门](https://www.mindspore.cn/tutorials/zh-CN/r2.9.0/beginner/quick_start.html) || [张量 Tensor](https://www.mindspore.cn/tutorials/zh-CN/r2.9.0/beginner/tensor.html) || [数据加载与处理](https://www.mindspore.cn/tutorials/zh-CN/r2.9.0/beginner/dataset.html) || [网络构建](https://www.mindspore.cn/tutorials/zh-CN/r2.9.0/beginner/model.html) || [函数式自动微分](https://atomgit.com/mindspore/docs/blob/r2.9.0/tutorials/source_zh_cn/beginner/autograd.ipynb) || [模型训练](https://www.mindspore.cn/tutorials/zh-CN/r2.9.0/beginner/train.html) || [保存与加载](https://www.mindspore.cn/tutorials/zh-CN/r2.9.0/beginner/save_load.html) || **Graph Mode加速** ||
# 
# # Graph Mode加速
# 
# ## 背景介绍
# 
# AI编译框架有两种运行模式：动态图模式和静态图模式。MindSpore默认情况下是以动态图模式运行，但也支持手动切换为静态图模式。两种运行模式的详细介绍如下：
# 
# ### 动态图模式
# 
# 动态图的特点是计算图的构建和计算同时发生（Define by run），符合Python的解释执行方式。在计算图中定义一个Tensor时，其值就已经被计算且确定，因此在调试模型时较为方便，能够实时得到中间结果的值。但由于所有节点都需要被保存，导致难以对整个计算图进行优化。
# 
# 在MindSpore中，动态图模式又被称为PyNative模式。由于动态图的解释执行特性，在脚本开发和网络流程调试过程中，推荐使用动态图模式。
# 
# 如需要手动控制框架采用PyNative模式，可以通过以下代码进行网络构建：

# In[1]:


import numpy as np
import mindspore as ms
from mindspore import nn, Tensor
ms.set_context(mode=ms.PYNATIVE_MODE)  # 使用set_context进行动态图模式的配置

class Network(nn.Cell):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.dense_relu_sequential = nn.SequentialCell(
            nn.Dense(28*28, 512),
            nn.ReLU(),
            nn.Dense(512, 512),
            nn.ReLU(),
            nn.Dense(512, 10)
        )

    def construct(self, x):
        x = self.flatten(x)
        logits = self.dense_relu_sequential(x)
        return logits

model = Network()
input = Tensor(np.ones([64, 1, 28, 28]).astype(np.float32))
output = model(input)
print(output)


# ### 静态图模式
# 
# 相较于动态图而言，静态图的特点是将计算图的构建和实际计算分开（Define and run）。有关静态图模式的运行原理，可以参考[静态图语法支持](https://www.mindspore.cn/tutorials/zh-CN/r2.9.0/compile/static_graph.html#概述)。
# 
# 在MindSpore中，静态图模式又被称为Graph模式。在Graph模式下，基于图优化、计算图整图下沉等技术，编译器可以针对图进行全局的优化，获得较好的性能，因此比较适合网络固定且需要高性能的场景。
# 
# 如需手动控制框架采用静态图模式，可以通过以下代码进行网络构建：

# In[2]:


import numpy as np
import mindspore as ms
from mindspore import nn, Tensor
ms.set_context(mode=ms.GRAPH_MODE)  # 使用set_context进行运行静态图模式的配置

class Network(nn.Cell):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.dense_relu_sequential = nn.SequentialCell(
            nn.Dense(28*28, 512),
            nn.ReLU(),
            nn.Dense(512, 512),
            nn.ReLU(),
            nn.Dense(512, 10)
        )

    def construct(self, x):
        x = self.flatten(x)
        logits = self.dense_relu_sequential(x)
        return logits

model = Network()
input = Tensor(np.ones([64, 1, 28, 28]).astype(np.float32))
output = model(input)
print(output)


# ## 静态图模式的使用场景
# 
# MindSpore编译器重点面向Tensor数据的计算以及其微分处理。因此，使用MindSpore API以及基于Tensor对象的操作，更适合使用静态图编译优化。其他操作虽然可以部分入图编译，但实际优化作用有限。另外，静态图模式采用先编译后执行，存在编译耗时。如果函数无需反复执行，那么Graph Mode加速也可能没有价值。
# 
# 有关使用静态图来进行网络编译的示例，请参考[网络构建](https://www.mindspore.cn/tutorials/zh-CN/r2.9.0/beginner/model.html)。
# 
# ## 静态图模式开启方式
# 
# 通常情况下，由于动态图的灵活性，我们会选择使用PyNative模式来进行自由的神经网络构建，以实现模型的创新和优化。但是当需要进行性能加速时，可以对神经网络部分或整体进行加速。MindSpore提供了两种切换为静态图模式的方式：基于装饰器的开启方式以及基于全局context的开启方式。
# 
# ### 基于装饰器的开启方式
# 
# MindSpore提供了jit装饰器，可以通过修饰Python函数或者Python类的成员函数使其被编译成计算图，并通过图优化等技术提高运行速度。此时，可以对想要进行性能优化的模块进行图编译加速，而模型其他部分，仍旧使用解释执行方式，不丢失动态图的灵活性。无论全局context是设置成静态图模式还是动态图模式，被jit修饰的部分始终会以静态图模式进行运行。
# 
# 在需要对Tensor的某些运算进行编译加速时，可以在其定义的函数上使用jit修饰器，在调用该函数时，该模块自动被编译为静态图。需要注意的是，jit装饰器只能用来修饰函数，无法对类进行修饰。jit的使用示例如下：

# In[3]:


import numpy as np
import mindspore as ms
from mindspore import nn, Tensor

class Network(nn.Cell):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.dense_relu_sequential = nn.SequentialCell(
            nn.Dense(28*28, 512),
            nn.ReLU(),
            nn.Dense(512, 512),
            nn.ReLU(),
            nn.Dense(512, 10)
        )

    def construct(self, x):
        x = self.flatten(x)
        logits = self.dense_relu_sequential(x)
        return logits

input = Tensor(np.ones([64, 1, 28, 28]).astype(np.float32))

@ms.jit  # 使用ms.jit装饰器，使被装饰的函数以静态图模式运行
def run(x):
    model = Network()
    return model(x)

output = run(input)
print(output)


# 除使用修饰器外，也可使用函数变换方式调用jit方法，示例如下：

# In[4]:


import numpy as np
import mindspore as ms
from mindspore import nn, Tensor

class Network(nn.Cell):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.dense_relu_sequential = nn.SequentialCell(
            nn.Dense(28*28, 512),
            nn.ReLU(),
            nn.Dense(512, 512),
            nn.ReLU(),
            nn.Dense(512, 10)
        )

    def construct(self, x):
        x = self.flatten(x)
        logits = self.dense_relu_sequential(x)
        return logits

input = Tensor(np.ones([64, 1, 28, 28]).astype(np.float32))

def run(x):
    model = Network()
    return model(x)

run_with_jit = ms.jit(run)  # 通过调用jit将函数转换为以静态图方式执行
output = run_with_jit(input)
print(output)


# 当我们需要对神经网络的某部分进行加速时，可以直接在construct方法上使用jit修饰器，在调用实例化对象时，该模块自动被编译为静态图。示例如下：

# In[5]:


import numpy as np
import mindspore as ms
from mindspore import nn, Tensor

class Network(nn.Cell):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.dense_relu_sequential = nn.SequentialCell(
            nn.Dense(28*28, 512),
            nn.ReLU(),
            nn.Dense(512, 512),
            nn.ReLU(),
            nn.Dense(512, 10)
        )

    @ms.jit  # 使用ms.jit装饰器，使被装饰的函数以静态图模式运行
    def construct(self, x):
        x = self.flatten(x)
        logits = self.dense_relu_sequential(x)
        return logits

input = Tensor(np.ones([64, 1, 28, 28]).astype(np.float32))
model = Network()
output = model(input)
print(output)


# ### 基于context的开启方式
# 
# context模式是一种全局的设置模式。代码示例如下：

# In[6]:


import numpy as np
import mindspore as ms
from mindspore import nn, Tensor
ms.set_context(mode=ms.GRAPH_MODE)  # 使用set_context进行运行静态图模式的配置

class Network(nn.Cell):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.dense_relu_sequential = nn.SequentialCell(
            nn.Dense(28*28, 512),
            nn.ReLU(),
            nn.Dense(512, 512),
            nn.ReLU(),
            nn.Dense(512, 10)
        )

    def construct(self, x):
        x = self.flatten(x)
        logits = self.dense_relu_sequential(x)
        return logits

model = Network()
input = Tensor(np.ones([64, 1, 28, 28]).astype(np.float32))
output = model(input)
print(output)


# ## 静态图的语法约束
# 
# 在Graph模式下，Python代码并不会由Python解释器去执行，而是先编译成静态计算图，再执行该静态计算图。因此，编译器无法支持全量的Python语法。MindSpore的静态图编译器支持Python常用语法子集，以支持神经网络的构建及训练。详情可参考[静态图语法支持](https://www.mindspore.cn/tutorials/zh-CN/r2.9.0/compile/static_graph.html)。
# 
# ## 静态图高级编程技巧
# 
# 使用静态图高级编程技巧，可以有效地提高编译和执行效率，使程序运行更加稳定。详情可参考[静态图高级编程技巧](https://www.mindspore.cn/tutorials/zh-CN/r2.9.0/compile/static_graph_expert_programming.html)。
