Recurrent Neural Network (RNN) Using TensorFlow
Pre-processing
Datasets
MNIST database of handwritten digits.
Input data: images of shape 28×28
Output label: digits 0–9
Parameters
- Data_Size
Input_dimension
: the dimension of each image
Output_dimension
: the dimension of the predicted label
Classes
: the number of different outputs
- Model_Parameter
Training_iter
: the number of iterations for training
Batch_size
: the number of inputs in each batch
Requirement
Python 2.7
TensorFlow 0.12.1
Model (RNN + LSTM)
We use a Recurrent Neural Network with an LSTM cell to implement this model.
- LSTM (Long Short-Term Memory):
An LSTM is composed of three gates, called INPUT_GATE, FORGET_GATE and OUTPUT_GATE.
More information about how to implement an LSTM model is here.
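As a toy illustration of the three gates (not the model code in this repo), a single LSTM step with scalar states can be written in plain Python; all the weight names and values here are made up for the sketch:

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def lstm_step(x, h_prev, c_prev, w):
    # Each gate looks at the current input x and the previous hidden state h_prev.
    i = sigmoid(w["wi"] * x + w["ui"] * h_prev)    # INPUT_GATE: how much new info to write
    f = sigmoid(w["wf"] * x + w["uf"] * h_prev)    # FORGET_GATE: how much old cell state to keep
    o = sigmoid(w["wo"] * x + w["uo"] * h_prev)    # OUTPUT_GATE: how much cell state to expose
    g = math.tanh(w["wg"] * x + w["ug"] * h_prev)  # candidate cell update
    c = f * c_prev + i * g                         # new cell state
    h = o * math.tanh(c)                           # new hidden state
    return h, c

# Toy weights, all 0.5, just to show the data flow through the gates.
w = {k: 0.5 for k in ("wi", "ui", "wf", "uf", "wo", "uo", "wg", "ug")}
h, c = lstm_step(x=1.0, h_prev=0.0, c_prev=0.0, w=w)
```

In the real model each gate uses weight matrices instead of scalars, but the gating pattern is the same.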
- Initialize Step
First we should initialize the placeholders and weights of our neural network.
placeholder
: just like the x of a function; it receives the input data at run time.
weights
: the weights for converting the input data to the output label.
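A minimal sketch of this initialization, assuming the TensorFlow 0.12 API; the dimension names and values (n_input, n_steps, n_hidden, n_classes) are illustrative choices, not necessarily the repo's actual ones:

```python
import tensorflow as tf

# Illustrative dimensions: each 28x28 image is fed as 28 time steps of 28 pixels.
n_input = 28      # input dimension per time step (one image row)
n_steps = 28      # number of time steps (image rows)
n_hidden = 128    # LSTM hidden size
n_classes = 10    # digits 0-9

# placeholders: the "x" of the function, filled with data at run time
x = tf.placeholder(tf.float32, [None, n_steps, n_input])
y = tf.placeholder(tf.float32, [None, n_classes])

# weights: map the LSTM output to the output label
weights = {"out": tf.Variable(tf.random_normal([n_hidden, n_classes]))}
biases = {"out": tf.Variable(tf.random_normal([n_classes]))}
```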
- Training Step
First we define an RNN_Model function.
It combines the RNN output into the predicted label through a linear transformation.
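One way such an RNN_Model function could look, assuming the TF 0.12 API (tf.nn.rnn_cell.BasicLSTMCell, tf.nn.dynamic_rnn) and the placeholder and weight names described in the initialize step:

```python
import tensorflow as tf

def RNN_Model(x, weights, biases, n_hidden=128):
    # One LSTM cell, unrolled over the time steps of x.
    cell = tf.nn.rnn_cell.BasicLSTMCell(n_hidden)
    outputs, states = tf.nn.dynamic_rnn(cell, x, dtype=tf.float32)
    # Take the output at the last time step and combine it linearly.
    last = outputs[:, -1, :]
    return tf.matmul(last, weights["out"]) + biases["out"]
```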
Second we have to define the loss function and optimizer of our model.
loss function
: softmax_cross_entropy
optimizer
:
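A sketch of this step, reusing pred, x and y from the previous steps. Note that in TF 0.12, tf.nn.softmax_cross_entropy_with_logits took logits and labels positionally. The document leaves the optimizer unspecified, so Adam and the learning rate below are assumptions:

```python
import tensorflow as tf

pred = RNN_Model(x, weights, biases)

# loss function: softmax cross-entropy between predictions and one-hot labels
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(pred, y))

# optimizer: illustrative choice and learning rate, not prescribed by the document
learning_rate = 0.001
optimizer = tf.train.AdamOptimizer(learning_rate).minimize(cost)
```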
Third, in order to evaluate the performance of this model, we define a function to calculate accuracy.
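For example, reusing pred and y from the previous steps:

```python
import tensorflow as tf

# A prediction is correct when the argmax of the logits matches the one-hot label.
correct_pred = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
</imports>
```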
Finally we can start training after all the initialization.
- We can use a Session to run our TensorFlow graph.
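A sketch of the training loop, assuming the TF 0.12 API and reusing the tensors from the previous steps; training_iters, batch_size and the data path are illustrative values:

```python
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# Read the MNIST dataset as a generator (path is user-defined).
mnist = input_data.read_data_sets("./data", one_hot=True)

training_iters = 100000   # illustrative
batch_size = 128          # illustrative

init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    step = 1
    while step * batch_size < training_iters:
        batch_x, batch_y = mnist.train.next_batch(batch_size)
        # MNIST yields flat 784-vectors; reshape to (batch, 28 steps, 28 inputs).
        batch_x = batch_x.reshape((batch_size, 28, 28))
        sess.run(optimizer, feed_dict={x: batch_x, y: batch_y})
        if step % 10 == 0:
            acc = sess.run(accuracy, feed_dict={x: batch_x, y: batch_y})
            print("step %d, batch accuracy %.4f" % (step, acc))
        step += 1
```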
Tip: “./” represents parameters defined by the user.
- Testing Step
After training, the weights held in the TensorFlow session can be used to predict on our test data.
First we generate the testing dataset from the MNIST generator.
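For instance, taking a slice of the test set and reshaping it like the training data (test_len is an illustrative value; mnist is the generator read earlier):

```python
test_len = 256  # illustrative
test_data = mnist.test.images[:test_len].reshape((-1, 28, 28))
test_label = mnist.test.labels[:test_len]
```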
Finally, because the TensorFlow MNIST test dataset has its own ground truth, we can check whether our “predict_label” is correct.
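A sketch of the check, assuming pred, x and the training session from the previous steps are still available:

```python
import tensorflow as tf

# Predicted labels from the trained graph, ground truth from the one-hot labels.
predict_label = sess.run(tf.argmax(pred, 1), feed_dict={x: test_data})
true_label = test_label.argmax(axis=1)
print("test accuracy:", (predict_label == true_label).mean())
```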
Usage
Install TensorFlow.
- If we want to run our model on a GPU, we also have to install CUDA and cuDNN.

```shell
pip install tensorflow(-gpu)==0.12.1
```
Import the TensorFlow package.
- Import the TensorFlow MNIST dataset and read it as a generator.
- Run our model:

```shell
python "./model_name".py
```
Challenges
In this experiment we use a simple RNN (LSTM) model to predict handwritten digits, a task on which CNN models also achieve good results.
RNN models are well suited to NLP tasks, but how we extract the most useful features determines whether our model can achieve an excellent result.