forked from docs/doc-exports
Reviewed-by: Hasko, Vladimir <vladimir.hasko@t-systems.com> Co-authored-by: Lai, Weijian <laiweijian4@huawei.com> Co-committed-by: Lai, Weijian <laiweijian4@huawei.com>
746 lines
71 KiB
HTML
746 lines
71 KiB
HTML
<a name="EN-US_TOPIC_0000001943974109"></a>
|
|
|
|
<h1 class="topictitle1">TensorFlow</h1>
|
|
<div id="body8662426"><p id="EN-US_TOPIC_0000001943974109__en-us_topic_0196618241_p8060118">TensorFlow provides two types of APIs: Keras and tf. They use different code for training and saving models, but the same code for inference.</p>
|
|
<div class="section" id="EN-US_TOPIC_0000001943974109__en-us_topic_0196618241_section1337018527168"><h4 class="sectiontitle">Training a Model (Keras API)</h4><div class="codecoloring" codetype="Python" id="EN-US_TOPIC_0000001943974109__en-us_topic_0196618241_screen825024713179"><div class="highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span class="normal"> 1</span>
|
|
<span class="normal"> 2</span>
|
|
<span class="normal"> 3</span>
|
|
<span class="normal"> 4</span>
|
|
<span class="normal"> 5</span>
|
|
<span class="normal"> 6</span>
|
|
<span class="normal"> 7</span>
|
|
<span class="normal"> 8</span>
|
|
<span class="normal"> 9</span>
|
|
<span class="normal">10</span>
|
|
<span class="normal">11</span>
|
|
<span class="normal">12</span>
|
|
<span class="normal">13</span>
|
|
<span class="normal">14</span>
|
|
<span class="normal">15</span>
|
|
<span class="normal">16</span>
|
|
<span class="normal">17</span>
|
|
<span class="normal">18</span>
|
|
<span class="normal">19</span>
|
|
<span class="normal">20</span>
|
|
<span class="normal">21</span>
|
|
<span class="normal">22</span>
|
|
<span class="normal">23</span>
|
|
<span class="normal">24</span>
|
|
<span class="normal">25</span>
|
|
<span class="normal">26</span>
|
|
<span class="normal">27</span>
|
|
<span class="normal">28</span>
|
|
<span class="normal">29</span>
|
|
<span class="normal">30</span>
|
|
<span class="normal">31</span>
|
|
<span class="normal">32</span>
|
|
<span class="normal">33</span>
|
|
<span class="normal">34</span>
|
|
<span class="normal">35</span></pre></div></td><td class="code"><div><pre><span></span><span class="kn">from</span> <span class="nn">keras.models</span> <span class="kn">import</span> <span class="n">Sequential</span>
|
|
<span class="n">model</span> <span class="o">=</span> <span class="n">Sequential</span><span class="p">()</span>
|
|
<span class="kn">from</span> <span class="nn">keras.layers</span> <span class="kn">import</span> <span class="n">Dense</span>
|
|
<span class="kn">import</span> <span class="nn">tensorflow</span> <span class="k">as</span> <span class="nn">tf</span>
|
|
|
|
<span class="c1"># Import a training dataset.</span>
|
|
<span class="n">mnist</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">datasets</span><span class="o">.</span><span class="n">mnist</span>
|
|
<span class="p">(</span><span class="n">x_train</span><span class="p">,</span> <span class="n">y_train</span><span class="p">),(</span><span class="n">x_test</span><span class="p">,</span> <span class="n">y_test</span><span class="p">)</span> <span class="o">=</span> <span class="n">mnist</span><span class="o">.</span><span class="n">load_data</span><span class="p">()</span>
|
|
<span class="n">x_train</span><span class="p">,</span> <span class="n">x_test</span> <span class="o">=</span> <span class="n">x_train</span> <span class="o">/</span> <span class="mf">255.0</span><span class="p">,</span> <span class="n">x_test</span> <span class="o">/</span> <span class="mf">255.0</span>
|
|
|
|
<span class="nb">print</span><span class="p">(</span><span class="n">x_train</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
|
|
|
|
<span class="kn">from</span> <span class="nn">keras.layers</span> <span class="kn">import</span> <span class="n">Dense</span>
|
|
<span class="kn">from</span> <span class="nn">keras.models</span> <span class="kn">import</span> <span class="n">Sequential</span>
|
|
<span class="kn">import</span> <span class="nn">keras</span>
|
|
<span class="kn">from</span> <span class="nn">keras.layers</span> <span class="kn">import</span> <span class="n">Dense</span><span class="p">,</span> <span class="n">Activation</span><span class="p">,</span> <span class="n">Flatten</span><span class="p">,</span> <span class="n">Dropout</span>
|
|
|
|
<span class="c1"># Define a model network.</span>
|
|
<span class="n">model</span> <span class="o">=</span> <span class="n">Sequential</span><span class="p">()</span>
|
|
<span class="n">model</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">Flatten</span><span class="p">(</span><span class="n">input_shape</span><span class="o">=</span><span class="p">(</span><span class="mi">28</span><span class="p">,</span><span class="mi">28</span><span class="p">)))</span>
|
|
<span class="n">model</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">Dense</span><span class="p">(</span><span class="n">units</span><span class="o">=</span><span class="mi">5120</span><span class="p">,</span><span class="n">activation</span><span class="o">=</span><span class="s1">'relu'</span><span class="p">))</span>
|
|
<span class="n">model</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">Dropout</span><span class="p">(</span><span class="mf">0.2</span><span class="p">))</span>
|
|
|
|
<span class="n">model</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">Dense</span><span class="p">(</span><span class="n">units</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s1">'softmax'</span><span class="p">))</span>
|
|
|
|
<span class="c1"># Define an optimizer and loss functions.</span>
|
|
<span class="n">model</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span><span class="n">optimizer</span><span class="o">=</span><span class="s1">'adam'</span><span class="p">,</span>
|
|
<span class="n">loss</span><span class="o">=</span><span class="s1">'sparse_categorical_crossentropy'</span><span class="p">,</span>
|
|
<span class="n">metrics</span><span class="o">=</span><span class="p">[</span><span class="s1">'accuracy'</span><span class="p">])</span>
|
|
|
|
<span class="n">model</span><span class="o">.</span><span class="n">summary</span><span class="p">()</span>
|
|
<span class="c1"># Train the model.</span>
|
|
<span class="n">model</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">x_train</span><span class="p">,</span> <span class="n">y_train</span><span class="p">,</span> <span class="n">epochs</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
|
|
<span class="c1"># Evaluate the model.</span>
|
|
<span class="n">model</span><span class="o">.</span><span class="n">evaluate</span><span class="p">(</span><span class="n">x_test</span><span class="p">,</span> <span class="n">y_test</span><span class="p">)</span>
|
|
</pre></div></td></tr></table></div>
|
|
|
|
</div>
|
|
</div>
|
|
<div class="section" id="EN-US_TOPIC_0000001943974109__en-us_topic_0196618241_section10626141011172"><h4 class="sectiontitle">Saving a Model (Keras API)</h4><div class="codecoloring" codetype="Python" id="EN-US_TOPIC_0000001943974109__en-us_topic_0196618241_screen925975910178"><div class="highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span class="normal"> 1</span>
|
|
<span class="normal"> 2</span>
|
|
<span class="normal"> 3</span>
|
|
<span class="normal"> 4</span>
|
|
<span class="normal"> 5</span>
|
|
<span class="normal"> 6</span>
|
|
<span class="normal"> 7</span>
|
|
<span class="normal"> 8</span>
|
|
<span class="normal"> 9</span>
|
|
<span class="normal">10</span>
|
|
<span class="normal">11</span>
|
|
<span class="normal">12</span>
|
|
<span class="normal">13</span>
|
|
<span class="normal">14</span>
|
|
<span class="normal">15</span>
|
|
<span class="normal">16</span>
|
|
<span class="normal">17</span>
|
|
<span class="normal">18</span>
|
|
<span class="normal">19</span>
|
|
<span class="normal">20</span>
|
|
<span class="normal">21</span>
|
|
<span class="normal">22</span>
|
|
<span class="normal">23</span>
|
|
<span class="normal">24</span>
|
|
<span class="normal">25</span>
|
|
<span class="normal">26</span>
|
|
<span class="normal">27</span>
|
|
<span class="normal">28</span>
|
|
<span class="normal">29</span>
|
|
<span class="normal">30</span>
|
|
<span class="normal">31</span></pre></div></td><td class="code"><div><pre><span></span><span class="kn">from</span> <span class="nn">keras</span> <span class="kn">import</span> <span class="n">backend</span> <span class="k">as</span> <span class="n">K</span>
|
|
|
|
<span class="c1"># K.get_session().run(tf.global_variables_initializer())</span>
|
|
|
|
<span class="c1"># Define the inputs and outputs of the prediction API.</span>
|
|
<span class="c1"># The keys of the inputs and outputs dictionaries are used as the index keys of the model's input and output tensors.</span>
|
|
<span class="c1"># The input and output definitions of the model must match the custom inference script.</span>
|
|
<span class="n">predict_signature</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">saved_model</span><span class="o">.</span><span class="n">signature_def_utils</span><span class="o">.</span><span class="n">predict_signature_def</span><span class="p">(</span>
|
|
<span class="n">inputs</span><span class="o">=</span><span class="p">{</span><span class="s2">"images"</span> <span class="p">:</span> <span class="n">model</span><span class="o">.</span><span class="n">input</span><span class="p">},</span>
|
|
<span class="n">outputs</span><span class="o">=</span><span class="p">{</span><span class="s2">"scores"</span> <span class="p">:</span> <span class="n">model</span><span class="o">.</span><span class="n">output</span><span class="p">}</span>
|
|
<span class="p">)</span>
|
|
|
|
<span class="c1"># Define a save path.</span>
|
|
<span class="n">builder</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">saved_model</span><span class="o">.</span><span class="n">builder</span><span class="o">.</span><span class="n">SavedModelBuilder</span><span class="p">(</span><span class="s1">'./mnist_keras/'</span><span class="p">)</span>
|
|
|
|
<span class="n">builder</span><span class="o">.</span><span class="n">add_meta_graph_and_variables</span><span class="p">(</span>
|
|
|
|
<span class="n">sess</span> <span class="o">=</span> <span class="n">K</span><span class="o">.</span><span class="n">get_session</span><span class="p">(),</span>
|
|
<span class="c1"># The tf.saved_model.tag_constants.SERVING tag needs to be defined for inference and deployment.</span>
|
|
<span class="n">tags</span><span class="o">=</span><span class="p">[</span><span class="n">tf</span><span class="o">.</span><span class="n">saved_model</span><span class="o">.</span><span class="n">tag_constants</span><span class="o">.</span><span class="n">SERVING</span><span class="p">],</span>
|
|
<span class="c1"># signature_def_map: Either contain only a single entry, or use</span>
|
|
<span class="c1"># the following constant as the key of the entry:</span>
|
|
<span class="c1"># tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY</span>
|
|
<span class="c1"># (as shown below).</span>
|
|
<span class="n">signature_def_map</span><span class="o">=</span><span class="p">{</span>
|
|
<span class="n">tf</span><span class="o">.</span><span class="n">saved_model</span><span class="o">.</span><span class="n">signature_constants</span><span class="o">.</span><span class="n">DEFAULT_SERVING_SIGNATURE_DEF_KEY</span><span class="p">:</span>
|
|
<span class="n">predict_signature</span>
|
|
<span class="p">}</span>
|
|
|
|
<span class="p">)</span>
|
|
<span class="n">builder</span><span class="o">.</span><span class="n">save</span><span class="p">()</span>
|
|
</pre></div></td></tr></table></div>
|
|
|
|
</div>
|
|
</div>
|
|
<div class="section" id="EN-US_TOPIC_0000001943974109__en-us_topic_0196618241_section5263321161712"><h4 class="sectiontitle">Training a Model (tf API)</h4><div class="codecoloring" codetype="Python" id="EN-US_TOPIC_0000001943974109__en-us_topic_0196618241_screen1567417138187"><div class="highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span class="normal"> 1</span>
|
|
<span class="normal"> 2</span>
|
|
<span class="normal"> 3</span>
|
|
<span class="normal"> 4</span>
|
|
<span class="normal"> 5</span>
|
|
<span class="normal"> 6</span>
|
|
<span class="normal"> 7</span>
|
|
<span class="normal"> 8</span>
|
|
<span class="normal"> 9</span>
|
|
<span class="normal"> 10</span>
|
|
<span class="normal"> 11</span>
|
|
<span class="normal"> 12</span>
|
|
<span class="normal"> 13</span>
|
|
<span class="normal"> 14</span>
|
|
<span class="normal"> 15</span>
|
|
<span class="normal"> 16</span>
|
|
<span class="normal"> 17</span>
|
|
<span class="normal"> 18</span>
|
|
<span class="normal"> 19</span>
|
|
<span class="normal"> 20</span>
|
|
<span class="normal"> 21</span>
|
|
<span class="normal"> 22</span>
|
|
<span class="normal"> 23</span>
|
|
<span class="normal"> 24</span>
|
|
<span class="normal"> 25</span>
|
|
<span class="normal"> 26</span>
|
|
<span class="normal"> 27</span>
|
|
<span class="normal"> 28</span>
|
|
<span class="normal"> 29</span>
|
|
<span class="normal"> 30</span>
|
|
<span class="normal"> 31</span>
|
|
<span class="normal"> 32</span>
|
|
<span class="normal"> 33</span>
|
|
<span class="normal"> 34</span>
|
|
<span class="normal"> 35</span>
|
|
<span class="normal"> 36</span>
|
|
<span class="normal"> 37</span>
|
|
<span class="normal"> 38</span>
|
|
<span class="normal"> 39</span>
|
|
<span class="normal"> 40</span>
|
|
<span class="normal"> 41</span>
|
|
<span class="normal"> 42</span>
|
|
<span class="normal"> 43</span>
|
|
<span class="normal"> 44</span>
|
|
<span class="normal"> 45</span>
|
|
<span class="normal"> 46</span>
|
|
<span class="normal"> 47</span>
|
|
<span class="normal"> 48</span>
|
|
<span class="normal"> 49</span>
|
|
<span class="normal"> 50</span>
|
|
<span class="normal"> 51</span>
|
|
<span class="normal"> 52</span>
|
|
<span class="normal"> 53</span>
|
|
<span class="normal"> 54</span>
|
|
<span class="normal"> 55</span>
|
|
<span class="normal"> 56</span>
|
|
<span class="normal"> 57</span>
|
|
<span class="normal"> 58</span>
|
|
<span class="normal"> 59</span>
|
|
<span class="normal"> 60</span>
|
|
<span class="normal"> 61</span>
|
|
<span class="normal"> 62</span>
|
|
<span class="normal"> 63</span>
|
|
<span class="normal"> 64</span>
|
|
<span class="normal"> 65</span>
|
|
<span class="normal"> 66</span>
|
|
<span class="normal"> 67</span>
|
|
<span class="normal"> 68</span>
|
|
<span class="normal"> 69</span>
|
|
<span class="normal"> 70</span>
|
|
<span class="normal"> 71</span>
|
|
<span class="normal"> 72</span>
|
|
<span class="normal"> 73</span>
|
|
<span class="normal"> 74</span>
|
|
<span class="normal"> 75</span>
|
|
<span class="normal"> 76</span>
|
|
<span class="normal"> 77</span>
|
|
<span class="normal"> 78</span>
|
|
<span class="normal"> 79</span>
|
|
<span class="normal"> 80</span>
|
|
<span class="normal"> 81</span>
|
|
<span class="normal"> 82</span>
|
|
<span class="normal"> 83</span>
|
|
<span class="normal"> 84</span>
|
|
<span class="normal"> 85</span>
|
|
<span class="normal"> 86</span>
|
|
<span class="normal"> 87</span>
|
|
<span class="normal"> 88</span>
|
|
<span class="normal"> 89</span>
|
|
<span class="normal"> 90</span>
|
|
<span class="normal"> 91</span>
|
|
<span class="normal"> 92</span>
|
|
<span class="normal"> 93</span>
|
|
<span class="normal"> 94</span>
|
|
<span class="normal"> 95</span>
|
|
<span class="normal"> 96</span>
|
|
<span class="normal"> 97</span>
|
|
<span class="normal"> 98</span>
|
|
<span class="normal"> 99</span>
|
|
<span class="normal">100</span>
|
|
<span class="normal">101</span>
|
|
<span class="normal">102</span>
|
|
<span class="normal">103</span>
|
|
<span class="normal">104</span>
|
|
<span class="normal">105</span>
|
|
<span class="normal">106</span>
|
|
<span class="normal">107</span>
|
|
<span class="normal">108</span>
|
|
<span class="normal">109</span>
|
|
<span class="normal">110</span>
|
|
<span class="normal">111</span>
|
|
<span class="normal">112</span>
|
|
<span class="normal">113</span>
|
|
<span class="normal">114</span>
|
|
<span class="normal">115</span>
|
|
<span class="normal">116</span>
|
|
<span class="normal">117</span>
|
|
<span class="normal">118</span>
|
|
<span class="normal">119</span>
|
|
<span class="normal">120</span>
|
|
<span class="normal">121</span>
|
|
<span class="normal">122</span>
|
|
<span class="normal">123</span>
|
|
<span class="normal">124</span>
|
|
<span class="normal">125</span>
|
|
<span class="normal">126</span>
|
|
<span class="normal">127</span>
|
|
<span class="normal">128</span>
|
|
<span class="normal">129</span>
|
|
<span class="normal">130</span>
|
|
<span class="normal">131</span>
|
|
<span class="normal">132</span>
|
|
<span class="normal">133</span>
|
|
<span class="normal">134</span>
|
|
<span class="normal">135</span>
|
|
<span class="normal">136</span>
|
|
<span class="normal">137</span>
|
|
<span class="normal">138</span>
|
|
<span class="normal">139</span>
|
|
<span class="normal">140</span>
|
|
<span class="normal">141</span>
|
|
<span class="normal">142</span>
|
|
<span class="normal">143</span>
|
|
<span class="normal">144</span>
|
|
<span class="normal">145</span>
|
|
<span class="normal">146</span>
|
|
<span class="normal">147</span>
|
|
<span class="normal">148</span>
|
|
<span class="normal">149</span>
|
|
<span class="normal">150</span>
|
|
<span class="normal">151</span>
|
|
<span class="normal">152</span>
|
|
<span class="normal">153</span>
|
|
<span class="normal">154</span>
|
|
<span class="normal">155</span>
|
|
<span class="normal">156</span>
|
|
<span class="normal">157</span>
|
|
<span class="normal">158</span>
|
|
<span class="normal">159</span>
|
|
<span class="normal">160</span>
|
|
<span class="normal">161</span>
|
|
<span class="normal">162</span>
|
|
<span class="normal">163</span>
|
|
<span class="normal">164</span>
|
|
<span class="normal">165</span>
|
|
<span class="normal">166</span>
|
|
<span class="normal">167</span>
|
|
<span class="normal">168</span>
|
|
<span class="normal">169</span>
|
|
<span class="normal">170</span>
|
|
<span class="normal">171</span>
|
|
<span class="normal">172</span>
|
|
<span class="normal">173</span>
|
|
<span class="normal">174</span>
|
|
<span class="normal">175</span>
|
|
<span class="normal">176</span>
|
|
<span class="normal">177</span>
|
|
<span class="normal">178</span>
|
|
<span class="normal">179</span>
|
|
<span class="normal">180</span>
|
|
<span class="normal">181</span>
|
|
<span class="normal">182</span>
|
|
<span class="normal">183</span>
|
|
<span class="normal">184</span>
|
|
<span class="normal">185</span>
|
|
<span class="normal">186</span>
|
|
<span class="normal">187</span>
|
|
<span class="normal">188</span>
|
|
<span class="normal">189</span>
|
|
<span class="normal">190</span>
|
|
<span class="normal">191</span>
|
|
<span class="normal">192</span>
|
|
<span class="normal">193</span>
|
|
<span class="normal">194</span>
|
|
<span class="normal">195</span>
|
|
<span class="normal">196</span>
|
|
<span class="normal">197</span>
|
|
<span class="normal">198</span>
|
|
<span class="normal">199</span>
|
|
<span class="normal">200</span>
|
|
<span class="normal">201</span>
|
|
<span class="normal">202</span>
|
|
<span class="normal">203</span>
|
|
<span class="normal">204</span>
|
|
<span class="normal">205</span>
|
|
<span class="normal">206</span>
|
|
<span class="normal">207</span>
|
|
<span class="normal">208</span>
|
|
<span class="normal">209</span>
|
|
<span class="normal">210</span>
|
|
<span class="normal">211</span>
|
|
<span class="normal">212</span>
|
|
<span class="normal">213</span>
|
|
<span class="normal">214</span>
|
|
<span class="normal">215</span>
|
|
<span class="normal">216</span>
|
|
<span class="normal">217</span>
|
|
<span class="normal">218</span>
|
|
<span class="normal">219</span>
|
|
<span class="normal">220</span>
|
|
<span class="normal">221</span>
|
|
<span class="normal">222</span>
|
|
<span class="normal">223</span>
|
|
<span class="normal">224</span>
|
|
<span class="normal">225</span></pre></div></td><td class="code"><div><pre><span></span><span class="kn">from</span> <span class="nn">__future__</span> <span class="kn">import</span> <span class="n">print_function</span>
|
|
|
|
<span class="kn">import</span> <span class="nn">gzip</span>
|
|
<span class="kn">import</span> <span class="nn">os</span>
|
|
<span class="kn">import</span> <span class="nn">urllib</span>
|
|
|
|
<span class="kn">import</span> <span class="nn">numpy</span>
|
|
<span class="kn">import</span> <span class="nn">tensorflow</span> <span class="k">as</span> <span class="nn">tf</span>
|
|
<span class="kn">from</span> <span class="nn">six.moves</span> <span class="kn">import</span> <span class="n">urllib</span>
|
|
|
|
<span class="c1"># The training data is downloaded from Yann LeCun's official website http://yann.lecun.com/exdb/mnist/.</span>
|
|
<span class="n">SOURCE_URL</span> <span class="o">=</span> <span class="s1">'http://yann.lecun.com/exdb/mnist/'</span>
|
|
<span class="n">TRAIN_IMAGES</span> <span class="o">=</span> <span class="s1">'train-images-idx3-ubyte.gz'</span>
|
|
<span class="n">TRAIN_LABELS</span> <span class="o">=</span> <span class="s1">'train-labels-idx1-ubyte.gz'</span>
|
|
<span class="n">TEST_IMAGES</span> <span class="o">=</span> <span class="s1">'t10k-images-idx3-ubyte.gz'</span>
|
|
<span class="n">TEST_LABELS</span> <span class="o">=</span> <span class="s1">'t10k-labels-idx1-ubyte.gz'</span>
|
|
<span class="n">VALIDATION_SIZE</span> <span class="o">=</span> <span class="mi">5000</span>
|
|
|
|
|
|
<span class="k">def</span> <span class="nf">maybe_download</span><span class="p">(</span><span class="n">filename</span><span class="p">,</span> <span class="n">work_directory</span><span class="p">):</span>
|
|
<span class="w"> </span><span class="sd">"""Download the data from Yann's website, unless it's already here."""</span>
|
|
<span class="k">if</span> <span class="ow">not</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">work_directory</span><span class="p">):</span>
|
|
<span class="n">os</span><span class="o">.</span><span class="n">mkdir</span><span class="p">(</span><span class="n">work_directory</span><span class="p">)</span>
|
|
<span class="n">filepath</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">work_directory</span><span class="p">,</span> <span class="n">filename</span><span class="p">)</span>
|
|
<span class="k">if</span> <span class="ow">not</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">filepath</span><span class="p">):</span>
|
|
<span class="n">filepath</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">urllib</span><span class="o">.</span><span class="n">request</span><span class="o">.</span><span class="n">urlretrieve</span><span class="p">(</span><span class="n">SOURCE_URL</span> <span class="o">+</span> <span class="n">filename</span><span class="p">,</span> <span class="n">filepath</span><span class="p">)</span>
|
|
<span class="n">statinfo</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">stat</span><span class="p">(</span><span class="n">filepath</span><span class="p">)</span>
|
|
<span class="nb">print</span><span class="p">(</span><span class="s1">'Successfully downloaded </span><span class="si">%s</span><span class="s1"> </span><span class="si">%d</span><span class="s1"> bytes.'</span> <span class="o">%</span> <span class="p">(</span><span class="n">filename</span><span class="p">,</span> <span class="n">statinfo</span><span class="o">.</span><span class="n">st_size</span><span class="p">))</span>
|
|
<span class="k">return</span> <span class="n">filepath</span>
|
|
|
|
|
|
<span class="k">def</span> <span class="nf">_read32</span><span class="p">(</span><span class="n">bytestream</span><span class="p">):</span>
|
|
<span class="n">dt</span> <span class="o">=</span> <span class="n">numpy</span><span class="o">.</span><span class="n">dtype</span><span class="p">(</span><span class="n">numpy</span><span class="o">.</span><span class="n">uint32</span><span class="p">)</span><span class="o">.</span><span class="n">newbyteorder</span><span class="p">(</span><span class="s1">'>'</span><span class="p">)</span>
|
|
<span class="k">return</span> <span class="n">numpy</span><span class="o">.</span><span class="n">frombuffer</span><span class="p">(</span><span class="n">bytestream</span><span class="o">.</span><span class="n">read</span><span class="p">(</span><span class="mi">4</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dt</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span>
|
|
|
|
|
|
<span class="k">def</span> <span class="nf">extract_images</span><span class="p">(</span><span class="n">filename</span><span class="p">):</span>
|
|
<span class="w"> </span><span class="sd">"""Extract the images into a 4D uint8 numpy array [index, y, x, depth]."""</span>
|
|
<span class="nb">print</span><span class="p">(</span><span class="s1">'Extracting </span><span class="si">%s</span><span class="s1">'</span> <span class="o">%</span> <span class="n">filename</span><span class="p">)</span>
|
|
<span class="k">with</span> <span class="n">gzip</span><span class="o">.</span><span class="n">open</span><span class="p">(</span><span class="n">filename</span><span class="p">)</span> <span class="k">as</span> <span class="n">bytestream</span><span class="p">:</span>
|
|
<span class="n">magic</span> <span class="o">=</span> <span class="n">_read32</span><span class="p">(</span><span class="n">bytestream</span><span class="p">)</span>
|
|
<span class="k">if</span> <span class="n">magic</span> <span class="o">!=</span> <span class="mi">2051</span><span class="p">:</span>
|
|
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
|
|
<span class="s1">'Invalid magic number </span><span class="si">%d</span><span class="s1"> in MNIST image file: </span><span class="si">%s</span><span class="s1">'</span> <span class="o">%</span>
|
|
<span class="p">(</span><span class="n">magic</span><span class="p">,</span> <span class="n">filename</span><span class="p">))</span>
|
|
<span class="n">num_images</span> <span class="o">=</span> <span class="n">_read32</span><span class="p">(</span><span class="n">bytestream</span><span class="p">)</span>
|
|
<span class="n">rows</span> <span class="o">=</span> <span class="n">_read32</span><span class="p">(</span><span class="n">bytestream</span><span class="p">)</span>
|
|
<span class="n">cols</span> <span class="o">=</span> <span class="n">_read32</span><span class="p">(</span><span class="n">bytestream</span><span class="p">)</span>
|
|
<span class="n">buf</span> <span class="o">=</span> <span class="n">bytestream</span><span class="o">.</span><span class="n">read</span><span class="p">(</span><span class="n">rows</span> <span class="o">*</span> <span class="n">cols</span> <span class="o">*</span> <span class="n">num_images</span><span class="p">)</span>
|
|
<span class="n">data</span> <span class="o">=</span> <span class="n">numpy</span><span class="o">.</span><span class="n">frombuffer</span><span class="p">(</span><span class="n">buf</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">numpy</span><span class="o">.</span><span class="n">uint8</span><span class="p">)</span>
|
|
<span class="n">data</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">num_images</span><span class="p">,</span> <span class="n">rows</span><span class="p">,</span> <span class="n">cols</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
|
|
<span class="k">return</span> <span class="n">data</span>
|
|
|
|
|
|
<span class="k">def</span> <span class="nf">dense_to_one_hot</span><span class="p">(</span><span class="n">labels_dense</span><span class="p">,</span> <span class="n">num_classes</span><span class="o">=</span><span class="mi">10</span><span class="p">):</span>
|
|
<span class="w"> </span><span class="sd">"""Convert class labels from scalars to one-hot vectors."""</span>
|
|
<span class="n">num_labels</span> <span class="o">=</span> <span class="n">labels_dense</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
|
|
<span class="n">index_offset</span> <span class="o">=</span> <span class="n">numpy</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">num_labels</span><span class="p">)</span> <span class="o">*</span> <span class="n">num_classes</span>
|
|
<span class="n">labels_one_hot</span> <span class="o">=</span> <span class="n">numpy</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">num_labels</span><span class="p">,</span> <span class="n">num_classes</span><span class="p">))</span>
|
|
<span class="n">labels_one_hot</span><span class="o">.</span><span class="n">flat</span><span class="p">[</span><span class="n">index_offset</span> <span class="o">+</span> <span class="n">labels_dense</span><span class="o">.</span><span class="n">ravel</span><span class="p">()]</span> <span class="o">=</span> <span class="mi">1</span>
|
|
<span class="k">return</span> <span class="n">labels_one_hot</span>
|
|
|
|
|
|
<span class="k">def</span> <span class="nf">extract_labels</span><span class="p">(</span><span class="n">filename</span><span class="p">,</span> <span class="n">one_hot</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
|
|
<span class="w"> </span><span class="sd">"""Extract the labels into a 1D uint8 numpy array [index]."""</span>
|
|
<span class="nb">print</span><span class="p">(</span><span class="s1">'Extracting </span><span class="si">%s</span><span class="s1">'</span> <span class="o">%</span> <span class="n">filename</span><span class="p">)</span>
|
|
<span class="k">with</span> <span class="n">gzip</span><span class="o">.</span><span class="n">open</span><span class="p">(</span><span class="n">filename</span><span class="p">)</span> <span class="k">as</span> <span class="n">bytestream</span><span class="p">:</span>
|
|
<span class="n">magic</span> <span class="o">=</span> <span class="n">_read32</span><span class="p">(</span><span class="n">bytestream</span><span class="p">)</span>
|
|
<span class="k">if</span> <span class="n">magic</span> <span class="o">!=</span> <span class="mi">2049</span><span class="p">:</span>
|
|
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
|
|
<span class="s1">'Invalid magic number </span><span class="si">%d</span><span class="s1"> in MNIST label file: </span><span class="si">%s</span><span class="s1">'</span> <span class="o">%</span>
|
|
<span class="p">(</span><span class="n">magic</span><span class="p">,</span> <span class="n">filename</span><span class="p">))</span>
|
|
<span class="n">num_items</span> <span class="o">=</span> <span class="n">_read32</span><span class="p">(</span><span class="n">bytestream</span><span class="p">)</span>
|
|
<span class="n">buf</span> <span class="o">=</span> <span class="n">bytestream</span><span class="o">.</span><span class="n">read</span><span class="p">(</span><span class="n">num_items</span><span class="p">)</span>
|
|
<span class="n">labels</span> <span class="o">=</span> <span class="n">numpy</span><span class="o">.</span><span class="n">frombuffer</span><span class="p">(</span><span class="n">buf</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">numpy</span><span class="o">.</span><span class="n">uint8</span><span class="p">)</span>
|
|
<span class="k">if</span> <span class="n">one_hot</span><span class="p">:</span>
|
|
<span class="k">return</span> <span class="n">dense_to_one_hot</span><span class="p">(</span><span class="n">labels</span><span class="p">)</span>
|
|
<span class="k">return</span> <span class="n">labels</span>
|
|
|
|
|
|
<span class="k">class</span> <span class="nc">DataSet</span><span class="p">(</span><span class="nb">object</span><span class="p">):</span>
|
|
<span class="w"> </span><span class="sd">"""Class encompassing test, validation and training MNIST data set."""</span>
|
|
|
|
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">images</span><span class="p">,</span> <span class="n">labels</span><span class="p">,</span> <span class="n">fake_data</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">one_hot</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
|
|
<span class="w"> </span><span class="sd">"""Construct a DataSet. one_hot arg is used only if fake_data is true."""</span>
|
|
|
|
<span class="k">if</span> <span class="n">fake_data</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">_num_examples</span> <span class="o">=</span> <span class="mi">10000</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">one_hot</span> <span class="o">=</span> <span class="n">one_hot</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="k">assert</span> <span class="n">images</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">labels</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="p">(</span>
|
|
<span class="s1">'images.shape: </span><span class="si">%s</span><span class="s1"> labels.shape: </span><span class="si">%s</span><span class="s1">'</span> <span class="o">%</span> <span class="p">(</span><span class="n">images</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span>
|
|
<span class="n">labels</span><span class="o">.</span><span class="n">shape</span><span class="p">))</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">_num_examples</span> <span class="o">=</span> <span class="n">images</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
|
|
|
|
<span class="c1"># Convert shape from [num examples, rows, columns, depth]</span>
|
|
<span class="c1"># to [num examples, rows*columns] (assuming depth == 1)</span>
|
|
<span class="k">assert</span> <span class="n">images</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">3</span><span class="p">]</span> <span class="o">==</span> <span class="mi">1</span>
|
|
<span class="n">images</span> <span class="o">=</span> <span class="n">images</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">images</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span>
|
|
<span class="n">images</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">*</span> <span class="n">images</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">])</span>
|
|
<span class="c1"># Convert from [0, 255] -> [0.0, 1.0].</span>
|
|
<span class="n">images</span> <span class="o">=</span> <span class="n">images</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">numpy</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
|
|
<span class="n">images</span> <span class="o">=</span> <span class="n">numpy</span><span class="o">.</span><span class="n">multiply</span><span class="p">(</span><span class="n">images</span><span class="p">,</span> <span class="mf">1.0</span> <span class="o">/</span> <span class="mf">255.0</span><span class="p">)</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">_images</span> <span class="o">=</span> <span class="n">images</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">_labels</span> <span class="o">=</span> <span class="n">labels</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">_epochs_completed</span> <span class="o">=</span> <span class="mi">0</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">_index_in_epoch</span> <span class="o">=</span> <span class="mi">0</span>
|
|
|
|
<span class="nd">@property</span>
|
|
<span class="k">def</span> <span class="nf">images</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
|
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_images</span>
|
|
|
|
<span class="nd">@property</span>
|
|
<span class="k">def</span> <span class="nf">labels</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
|
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_labels</span>
|
|
|
|
<span class="nd">@property</span>
|
|
<span class="k">def</span> <span class="nf">num_examples</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
|
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_num_examples</span>
|
|
|
|
<span class="nd">@property</span>
|
|
<span class="k">def</span> <span class="nf">epochs_completed</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
|
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_epochs_completed</span>
|
|
|
|
<span class="k">def</span> <span class="nf">next_batch</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">fake_data</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
|
|
<span class="w"> </span><span class="sd">"""Return the next `batch_size` examples from this data set."""</span>
|
|
<span class="k">if</span> <span class="n">fake_data</span><span class="p">:</span>
|
|
<span class="n">fake_image</span> <span class="o">=</span> <span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">*</span> <span class="mi">784</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">one_hot</span><span class="p">:</span>
|
|
<span class="n">fake_label</span> <span class="o">=</span> <span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">+</span> <span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">*</span> <span class="mi">9</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="n">fake_label</span> <span class="o">=</span> <span class="mi">0</span>
|
|
<span class="k">return</span> <span class="p">[</span><span class="n">fake_image</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span><span class="p">)],</span> <span class="p">[</span>
|
|
<span class="n">fake_label</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span><span class="p">)</span>
|
|
<span class="p">]</span>
|
|
<span class="n">start</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_index_in_epoch</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">_index_in_epoch</span> <span class="o">+=</span> <span class="n">batch_size</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_index_in_epoch</span> <span class="o">></span> <span class="bp">self</span><span class="o">.</span><span class="n">_num_examples</span><span class="p">:</span>
|
|
<span class="c1"># Finished epoch</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">_epochs_completed</span> <span class="o">+=</span> <span class="mi">1</span>
|
|
<span class="c1"># Shuffle the data</span>
|
|
<span class="n">perm</span> <span class="o">=</span> <span class="n">numpy</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_num_examples</span><span class="p">)</span>
|
|
<span class="n">numpy</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">shuffle</span><span class="p">(</span><span class="n">perm</span><span class="p">)</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">_images</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_images</span><span class="p">[</span><span class="n">perm</span><span class="p">]</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">_labels</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_labels</span><span class="p">[</span><span class="n">perm</span><span class="p">]</span>
|
|
<span class="c1"># Start next epoch</span>
|
|
<span class="n">start</span> <span class="o">=</span> <span class="mi">0</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">_index_in_epoch</span> <span class="o">=</span> <span class="n">batch_size</span>
|
|
<span class="k">assert</span> <span class="n">batch_size</span> <span class="o"><=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_num_examples</span>
|
|
<span class="n">end</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_index_in_epoch</span>
|
|
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_images</span><span class="p">[</span><span class="n">start</span><span class="p">:</span><span class="n">end</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">_labels</span><span class="p">[</span><span class="n">start</span><span class="p">:</span><span class="n">end</span><span class="p">]</span>
|
|
|
|
|
|
<span class="k">def</span> <span class="nf">read_data_sets</span><span class="p">(</span><span class="n">train_dir</span><span class="p">,</span> <span class="n">fake_data</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">one_hot</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
|
|
<span class="w"> </span><span class="sd">"""Return training, validation and testing data sets."""</span>
|
|
|
|
<span class="k">class</span> <span class="nc">DataSets</span><span class="p">(</span><span class="nb">object</span><span class="p">):</span>
|
|
<span class="k">pass</span>
|
|
|
|
<span class="n">data_sets</span> <span class="o">=</span> <span class="n">DataSets</span><span class="p">()</span>
|
|
|
|
<span class="k">if</span> <span class="n">fake_data</span><span class="p">:</span>
|
|
<span class="n">data_sets</span><span class="o">.</span><span class="n">train</span> <span class="o">=</span> <span class="n">DataSet</span><span class="p">([],</span> <span class="p">[],</span> <span class="n">fake_data</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">one_hot</span><span class="o">=</span><span class="n">one_hot</span><span class="p">)</span>
|
|
<span class="n">data_sets</span><span class="o">.</span><span class="n">validation</span> <span class="o">=</span> <span class="n">DataSet</span><span class="p">([],</span> <span class="p">[],</span> <span class="n">fake_data</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">one_hot</span><span class="o">=</span><span class="n">one_hot</span><span class="p">)</span>
|
|
<span class="n">data_sets</span><span class="o">.</span><span class="n">test</span> <span class="o">=</span> <span class="n">DataSet</span><span class="p">([],</span> <span class="p">[],</span> <span class="n">fake_data</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">one_hot</span><span class="o">=</span><span class="n">one_hot</span><span class="p">)</span>
|
|
<span class="k">return</span> <span class="n">data_sets</span>
|
|
|
|
<span class="n">local_file</span> <span class="o">=</span> <span class="n">maybe_download</span><span class="p">(</span><span class="n">TRAIN_IMAGES</span><span class="p">,</span> <span class="n">train_dir</span><span class="p">)</span>
|
|
<span class="n">train_images</span> <span class="o">=</span> <span class="n">extract_images</span><span class="p">(</span><span class="n">local_file</span><span class="p">)</span>
|
|
|
|
<span class="n">local_file</span> <span class="o">=</span> <span class="n">maybe_download</span><span class="p">(</span><span class="n">TRAIN_LABELS</span><span class="p">,</span> <span class="n">train_dir</span><span class="p">)</span>
|
|
<span class="n">train_labels</span> <span class="o">=</span> <span class="n">extract_labels</span><span class="p">(</span><span class="n">local_file</span><span class="p">,</span> <span class="n">one_hot</span><span class="o">=</span><span class="n">one_hot</span><span class="p">)</span>
|
|
|
|
<span class="n">local_file</span> <span class="o">=</span> <span class="n">maybe_download</span><span class="p">(</span><span class="n">TEST_IMAGES</span><span class="p">,</span> <span class="n">train_dir</span><span class="p">)</span>
|
|
<span class="n">test_images</span> <span class="o">=</span> <span class="n">extract_images</span><span class="p">(</span><span class="n">local_file</span><span class="p">)</span>
|
|
|
|
<span class="n">local_file</span> <span class="o">=</span> <span class="n">maybe_download</span><span class="p">(</span><span class="n">TEST_LABELS</span><span class="p">,</span> <span class="n">train_dir</span><span class="p">)</span>
|
|
<span class="n">test_labels</span> <span class="o">=</span> <span class="n">extract_labels</span><span class="p">(</span><span class="n">local_file</span><span class="p">,</span> <span class="n">one_hot</span><span class="o">=</span><span class="n">one_hot</span><span class="p">)</span>
|
|
|
|
<span class="n">validation_images</span> <span class="o">=</span> <span class="n">train_images</span><span class="p">[:</span><span class="n">VALIDATION_SIZE</span><span class="p">]</span>
|
|
<span class="n">validation_labels</span> <span class="o">=</span> <span class="n">train_labels</span><span class="p">[:</span><span class="n">VALIDATION_SIZE</span><span class="p">]</span>
|
|
<span class="n">train_images</span> <span class="o">=</span> <span class="n">train_images</span><span class="p">[</span><span class="n">VALIDATION_SIZE</span><span class="p">:]</span>
|
|
<span class="n">train_labels</span> <span class="o">=</span> <span class="n">train_labels</span><span class="p">[</span><span class="n">VALIDATION_SIZE</span><span class="p">:]</span>
|
|
|
|
<span class="n">data_sets</span><span class="o">.</span><span class="n">train</span> <span class="o">=</span> <span class="n">DataSet</span><span class="p">(</span><span class="n">train_images</span><span class="p">,</span> <span class="n">train_labels</span><span class="p">)</span>
|
|
<span class="n">data_sets</span><span class="o">.</span><span class="n">validation</span> <span class="o">=</span> <span class="n">DataSet</span><span class="p">(</span><span class="n">validation_images</span><span class="p">,</span> <span class="n">validation_labels</span><span class="p">)</span>
|
|
<span class="n">data_sets</span><span class="o">.</span><span class="n">test</span> <span class="o">=</span> <span class="n">DataSet</span><span class="p">(</span><span class="n">test_images</span><span class="p">,</span> <span class="n">test_labels</span><span class="p">)</span>
|
|
<span class="k">return</span> <span class="n">data_sets</span>
|
|
|
|
<span class="n">training_iteration</span> <span class="o">=</span> <span class="mi">1000</span>
|
|
|
|
<span class="n">modelarts_example_path</span> <span class="o">=</span> <span class="s1">'./modelarts-mnist-train-save-deploy-example'</span>
|
|
|
|
<span class="n">export_path</span> <span class="o">=</span> <span class="n">modelarts_example_path</span> <span class="o">+</span> <span class="s1">'/model/'</span>
|
|
<span class="n">data_path</span> <span class="o">=</span> <span class="s1">'./'</span>
|
|
|
|
<span class="nb">print</span><span class="p">(</span><span class="s1">'Training model...'</span><span class="p">)</span>
|
|
<span class="n">mnist</span> <span class="o">=</span> <span class="n">read_data_sets</span><span class="p">(</span><span class="n">data_path</span><span class="p">,</span> <span class="n">one_hot</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
|
|
<span class="n">sess</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">InteractiveSession</span><span class="p">()</span>
|
|
<span class="n">serialized_tf_example</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">placeholder</span><span class="p">(</span><span class="n">tf</span><span class="o">.</span><span class="n">string</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s1">'tf_example'</span><span class="p">)</span>
|
|
<span class="n">feature_configs</span> <span class="o">=</span> <span class="p">{</span><span class="s1">'x'</span><span class="p">:</span> <span class="n">tf</span><span class="o">.</span><span class="n">FixedLenFeature</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">[</span><span class="mi">784</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">tf</span><span class="o">.</span><span class="n">float32</span><span class="p">),</span> <span class="p">}</span>
|
|
<span class="n">tf_example</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">parse_example</span><span class="p">(</span><span class="n">serialized_tf_example</span><span class="p">,</span> <span class="n">feature_configs</span><span class="p">)</span>
|
|
<span class="n">x</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">identity</span><span class="p">(</span><span class="n">tf_example</span><span class="p">[</span><span class="s1">'x'</span><span class="p">],</span> <span class="n">name</span><span class="o">=</span><span class="s1">'x'</span><span class="p">)</span> <span class="c1"># use tf.identity() to assign name</span>
|
|
<span class="n">y_</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">placeholder</span><span class="p">(</span><span class="s1">'float'</span><span class="p">,</span> <span class="n">shape</span><span class="o">=</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="mi">10</span><span class="p">])</span>
|
|
<span class="n">w</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">Variable</span><span class="p">(</span><span class="n">tf</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span><span class="mi">784</span><span class="p">,</span> <span class="mi">10</span><span class="p">]))</span>
|
|
<span class="n">b</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">Variable</span><span class="p">(</span><span class="n">tf</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span><span class="mi">10</span><span class="p">]))</span>
|
|
<span class="n">sess</span><span class="o">.</span><span class="n">run</span><span class="p">(</span><span class="n">tf</span><span class="o">.</span><span class="n">global_variables_initializer</span><span class="p">())</span>
|
|
<span class="n">y</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">tf</span><span class="o">.</span><span class="n">matmul</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">w</span><span class="p">)</span> <span class="o">+</span> <span class="n">b</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s1">'y'</span><span class="p">)</span>
|
|
<span class="n">cross_entropy</span> <span class="o">=</span> <span class="o">-</span><span class="n">tf</span><span class="o">.</span><span class="n">reduce_sum</span><span class="p">(</span><span class="n">y_</span> <span class="o">*</span> <span class="n">tf</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">y</span><span class="p">))</span>
|
|
<span class="n">train_step</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">train</span><span class="o">.</span><span class="n">GradientDescentOptimizer</span><span class="p">(</span><span class="mf">0.01</span><span class="p">)</span><span class="o">.</span><span class="n">minimize</span><span class="p">(</span><span class="n">cross_entropy</span><span class="p">)</span>
|
|
<span class="n">values</span><span class="p">,</span> <span class="n">indices</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">top_k</span><span class="p">(</span><span class="n">y</span><span class="p">,</span> <span class="mi">10</span><span class="p">)</span>
|
|
<span class="n">table</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">contrib</span><span class="o">.</span><span class="n">lookup</span><span class="o">.</span><span class="n">index_to_string_table_from_tensor</span><span class="p">(</span>
|
|
<span class="n">tf</span><span class="o">.</span><span class="n">constant</span><span class="p">([</span><span class="nb">str</span><span class="p">(</span><span class="n">i</span><span class="p">)</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">10</span><span class="p">)]))</span>
|
|
<span class="n">prediction_classes</span> <span class="o">=</span> <span class="n">table</span><span class="o">.</span><span class="n">lookup</span><span class="p">(</span><span class="n">tf</span><span class="o">.</span><span class="n">to_int64</span><span class="p">(</span><span class="n">indices</span><span class="p">))</span>
|
|
<span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">training_iteration</span><span class="p">):</span>
|
|
<span class="n">batch</span> <span class="o">=</span> <span class="n">mnist</span><span class="o">.</span><span class="n">train</span><span class="o">.</span><span class="n">next_batch</span><span class="p">(</span><span class="mi">50</span><span class="p">)</span>
|
|
<span class="n">train_step</span><span class="o">.</span><span class="n">run</span><span class="p">(</span><span class="n">feed_dict</span><span class="o">=</span><span class="p">{</span><span class="n">x</span><span class="p">:</span> <span class="n">batch</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">y_</span><span class="p">:</span> <span class="n">batch</span><span class="p">[</span><span class="mi">1</span><span class="p">]})</span>
|
|
<span class="n">correct_prediction</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">equal</span><span class="p">(</span><span class="n">tf</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">y</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span> <span class="n">tf</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">y_</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span>
|
|
<span class="n">accuracy</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">reduce_mean</span><span class="p">(</span><span class="n">tf</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="n">correct_prediction</span><span class="p">,</span> <span class="s1">'float'</span><span class="p">))</span>
|
|
<span class="nb">print</span><span class="p">(</span><span class="s1">'training accuracy </span><span class="si">%g</span><span class="s1">'</span> <span class="o">%</span> <span class="n">sess</span><span class="o">.</span><span class="n">run</span><span class="p">(</span>
|
|
<span class="n">accuracy</span><span class="p">,</span> <span class="n">feed_dict</span><span class="o">=</span><span class="p">{</span>
|
|
<span class="n">x</span><span class="p">:</span> <span class="n">mnist</span><span class="o">.</span><span class="n">test</span><span class="o">.</span><span class="n">images</span><span class="p">,</span>
|
|
<span class="n">y_</span><span class="p">:</span> <span class="n">mnist</span><span class="o">.</span><span class="n">test</span><span class="o">.</span><span class="n">labels</span>
|
|
<span class="p">}))</span>
|
|
<span class="nb">print</span><span class="p">(</span><span class="s1">'Done training!'</span><span class="p">)</span>
|
|
</pre></div></td></tr></table></div>
|
|
|
|
</div>
|
|
</div>
|
|
<div class="section" id="EN-US_TOPIC_0000001943974109__en-us_topic_0196618241_section66523281172"><h4 class="sectiontitle">Saving a Model (tf API)</h4><div class="codecoloring" codetype="Python" id="EN-US_TOPIC_0000001943974109__en-us_topic_0196618241_screen4380141941818"><div class="highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span class="normal"> 1</span>
|
|
<span class="normal"> 2</span>
|
|
<span class="normal"> 3</span>
|
|
<span class="normal"> 4</span>
|
|
<span class="normal"> 5</span>
|
|
<span class="normal"> 6</span>
|
|
<span class="normal"> 7</span>
|
|
<span class="normal"> 8</span>
|
|
<span class="normal"> 9</span>
|
|
<span class="normal">10</span>
|
|
<span class="normal">11</span>
|
|
<span class="normal">12</span>
|
|
<span class="normal">13</span>
|
|
<span class="normal">14</span>
|
|
<span class="normal">15</span>
|
|
<span class="normal">16</span>
|
|
<span class="normal">17</span>
|
|
<span class="normal">18</span>
|
|
<span class="normal">19</span>
|
|
<span class="normal">20</span>
|
|
<span class="normal">21</span>
|
|
<span class="normal">22</span>
|
|
<span class="normal">23</span>
|
|
<span class="normal">24</span>
|
|
<span class="normal">25</span>
|
|
<span class="normal">26</span>
|
|
<span class="normal">27</span>
|
|
<span class="normal">28</span>
|
|
<span class="normal">29</span>
|
|
<span class="normal">30</span></pre></div></td><td class="code"><div><pre><span></span><span class="c1"># Export the model.</span>
|
|
<span class="c1"># The model needs to be saved using the saved_model API.</span>
|
|
<span class="nb">print</span><span class="p">(</span><span class="s1">'Exporting trained model to'</span><span class="p">,</span> <span class="n">export_path</span><span class="p">)</span>
|
|
<span class="n">builder</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">saved_model</span><span class="o">.</span><span class="n">builder</span><span class="o">.</span><span class="n">SavedModelBuilder</span><span class="p">(</span><span class="n">export_path</span><span class="p">)</span>
|
|
|
|
<span class="n">tensor_info_x</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">saved_model</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">build_tensor_info</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
|
|
<span class="n">tensor_info_y</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">saved_model</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">build_tensor_info</span><span class="p">(</span><span class="n">y</span><span class="p">)</span>
|
|
|
|
<span class="c1"># Define the inputs and outputs of the prediction API.</span>
|
|
<span class="c1"># The key values of the inputs and outputs dictionaries are used as the index keys for the input and output tensors of the model.</span>
|
|
<span class="c1"># The input and output definitions of the model must match the custom inference script.</span>
|
|
<span class="n">prediction_signature</span> <span class="o">=</span> <span class="p">(</span>
|
|
<span class="n">tf</span><span class="o">.</span><span class="n">saved_model</span><span class="o">.</span><span class="n">signature_def_utils</span><span class="o">.</span><span class="n">build_signature_def</span><span class="p">(</span>
|
|
<span class="n">inputs</span><span class="o">=</span><span class="p">{</span><span class="s1">'images'</span><span class="p">:</span> <span class="n">tensor_info_x</span><span class="p">},</span>
|
|
<span class="n">outputs</span><span class="o">=</span><span class="p">{</span><span class="s1">'scores'</span><span class="p">:</span> <span class="n">tensor_info_y</span><span class="p">},</span>
|
|
<span class="n">method_name</span><span class="o">=</span><span class="n">tf</span><span class="o">.</span><span class="n">saved_model</span><span class="o">.</span><span class="n">signature_constants</span><span class="o">.</span><span class="n">PREDICT_METHOD_NAME</span><span class="p">))</span>
|
|
|
|
<span class="n">legacy_init_op</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">group</span><span class="p">(</span><span class="n">tf</span><span class="o">.</span><span class="n">tables_initializer</span><span class="p">(),</span> <span class="n">name</span><span class="o">=</span><span class="s1">'legacy_init_op'</span><span class="p">)</span>
|
|
<span class="n">builder</span><span class="o">.</span><span class="n">add_meta_graph_and_variables</span><span class="p">(</span>
|
|
<span class="c1"># Set tag to serve/tf.saved_model.tag_constants.SERVING.</span>
|
|
<span class="n">sess</span><span class="p">,</span> <span class="p">[</span><span class="n">tf</span><span class="o">.</span><span class="n">saved_model</span><span class="o">.</span><span class="n">tag_constants</span><span class="o">.</span><span class="n">SERVING</span><span class="p">],</span>
|
|
<span class="n">signature_def_map</span><span class="o">=</span><span class="p">{</span>
|
|
<span class="s1">'predict_images'</span><span class="p">:</span>
|
|
<span class="n">prediction_signature</span><span class="p">,</span>
|
|
<span class="p">},</span>
|
|
<span class="n">legacy_init_op</span><span class="o">=</span><span class="n">legacy_init_op</span><span class="p">)</span>
|
|
|
|
<span class="n">builder</span><span class="o">.</span><span class="n">save</span><span class="p">()</span>
|
|
|
|
<span class="nb">print</span><span class="p">(</span><span class="s1">'Done exporting!'</span><span class="p">)</span>
|
|
</pre></div></td></tr></table></div>
|
|
|
|
</div>
|
|
</div>
|
|
<div class="section" id="EN-US_TOPIC_0000001943974109__en-us_topic_0196618241_section5321236111714"><h4 class="sectiontitle">Inference Code (Keras and tf APIs)</h4><p id="EN-US_TOPIC_0000001943974109__en-us_topic_0196618241_p1311114917175">In the model inference code file <strong id="EN-US_TOPIC_0000001943974109__b123671717172616">customize_service.py</strong>, add a child model class that inherits from the parent model class. For details about the import statements for the different types of parent model classes, see <a href="inference-modelarts-0057.html#EN-US_TOPIC_0000001910014882__en-us_topic_0172466150_table55021545175412">Table 1</a>.</p>
|
|
<div class="codecoloring" codetype="Python" id="EN-US_TOPIC_0000001943974109__en-us_topic_0196618241_screen12930832151820"><div class="highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span class="normal"> 1</span>
|
|
<span class="normal"> 2</span>
|
|
<span class="normal"> 3</span>
|
|
<span class="normal"> 4</span>
|
|
<span class="normal"> 5</span>
|
|
<span class="normal"> 6</span>
|
|
<span class="normal"> 7</span>
|
|
<span class="normal"> 8</span>
|
|
<span class="normal"> 9</span>
|
|
<span class="normal">10</span>
|
|
<span class="normal">11</span>
|
|
<span class="normal">12</span>
|
|
<span class="normal">13</span>
|
|
<span class="normal">14</span>
|
|
<span class="normal">15</span>
|
|
<span class="normal">16</span>
|
|
<span class="normal">17</span>
|
|
<span class="normal">18</span>
|
|
<span class="normal">19</span>
|
|
<span class="normal">20</span>
|
|
<span class="normal">21</span>
|
|
<span class="normal">22</span>
|
|
<span class="normal">23</span>
|
|
<span class="normal">24</span>
|
|
<span class="normal">25</span>
|
|
<span class="normal">26</span>
|
|
<span class="normal">27</span>
|
|
<span class="normal">28</span>
|
|
<span class="normal">29</span>
|
|
<span class="normal">30</span>
|
|
<span class="normal">31</span>
|
|
<span class="normal">32</span>
|
|
<span class="normal">33</span>
|
|
<span class="normal">34</span>
|
|
<span class="normal">35</span>
|
|
<span class="normal">36</span>
|
|
<span class="normal">37</span>
|
|
<span class="normal">38</span></pre></div></td><td class="code"><div><pre><span></span><span class="kn">from</span> <span class="nn">PIL</span> <span class="kn">import</span> <span class="n">Image</span>
|
|
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
|
|
<span class="kn">from</span> <span class="nn">model_service.tfserving_model_service</span> <span class="kn">import</span> <span class="n">TfServingBaseService</span>
|
|
|
|
|
|
<span class="k">class</span> <span class="nc">MnistService</span><span class="p">(</span><span class="n">TfServingBaseService</span><span class="p">):</span>
|
|
|
|
<span class="c1"># Match the model input with the user's HTTPS API input during preprocessing.</span>
|
|
<span class="c1"># The model input corresponding to the preceding training part is {"images":<array>}.</span>
|
|
<span class="k">def</span> <span class="nf">_preprocess</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">data</span><span class="p">):</span>
|
|
|
|
<span class="n">preprocessed_data</span> <span class="o">=</span> <span class="p">{}</span>
|
|
<span class="n">images</span> <span class="o">=</span> <span class="p">[]</span>
|
|
<span class="c1"># Iterate over the input data.</span>
|
|
<span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="n">data</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
|
|
<span class="k">for</span> <span class="n">file_name</span><span class="p">,</span> <span class="n">file_content</span> <span class="ow">in</span> <span class="n">v</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
|
|
<span class="n">image1</span> <span class="o">=</span> <span class="n">Image</span><span class="o">.</span><span class="n">open</span><span class="p">(</span><span class="n">file_content</span><span class="p">)</span>
|
|
<span class="n">image1</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">image1</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
|
|
<span class="n">image1</span><span class="o">.</span><span class="n">resize</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span><span class="mi">784</span><span class="p">))</span>
|
|
<span class="n">images</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">image1</span><span class="p">)</span>
|
|
<span class="c1"># Convert the list of images into a numpy array.</span>
|
|
<span class="n">images</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">images</span><span class="p">,</span><span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
|
|
<span class="c1"># Batch multiple input samples together and make sure the shape matches the input shape used during training.</span>
|
|
<span class="n">images</span><span class="o">.</span><span class="n">resize</span><span class="p">((</span><span class="nb">len</span><span class="p">(</span><span class="n">data</span><span class="p">),</span> <span class="mi">784</span><span class="p">))</span>
|
|
<span class="n">preprocessed_data</span><span class="p">[</span><span class="s1">'images'</span><span class="p">]</span> <span class="o">=</span> <span class="n">images</span>
|
|
<span class="k">return</span> <span class="n">preprocessed_data</span>
|
|
|
|
<span class="c1"># The inference logic itself is inherited from the parent class and is not overridden here.</span>
|
|
|
|
<span class="c1"># The output corresponding to model saving in the preceding training part is {"scores":<array>}.</span>
|
|
<span class="c1"># Postprocess the HTTPS output.</span>
|
|
<span class="k">def</span> <span class="nf">_postprocess</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">data</span><span class="p">):</span>
|
|
<span class="n">infer_output</span> <span class="o">=</span> <span class="p">{</span><span class="s2">"mnist_result"</span><span class="p">:</span> <span class="p">[]}</span>
|
|
<span class="c1"># Iterate over the model outputs.</span>
|
|
<span class="k">for</span> <span class="n">output_name</span><span class="p">,</span> <span class="n">results</span> <span class="ow">in</span> <span class="n">data</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
|
|
<span class="k">for</span> <span class="n">result</span> <span class="ow">in</span> <span class="n">results</span><span class="p">:</span>
|
|
<span class="n">infer_output</span><span class="p">[</span><span class="s2">"mnist_result"</span><span class="p">]</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">result</span><span class="o">.</span><span class="n">index</span><span class="p">(</span><span class="nb">max</span><span class="p">(</span><span class="n">result</span><span class="p">)))</span>
|
|
<span class="k">return</span> <span class="n">infer_output</span>
|
|
</pre></div></td></tr></table></div>
|
|
|
|
</div>
|
|
</div>
|
|
</div>
|
|
<div>
|
|
<div class="familylinks">
|
|
<div class="parentlink"><strong>Parent topic:</strong> <a href="inference-modelarts-0078.html">Examples of Custom Scripts</a></div>
|
|
</div>
|
|
</div>
|
|
|