Location via proxy:   [ UP ]  
[Report a bug]   [Manage cookies]                
Skip to content

Commit 7ddcf85

Browse files
committed
tensorflow2.0 learn 2020-4-24
1 parent 6968bf2 commit 7ddcf85

File tree

1 file changed

+72
-4
lines changed

1 file changed

+72
-4
lines changed

tenforlow_learn/demo-1.py

Lines changed: 72 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,32 @@
44
# @Software: PyCharm
55
# pip install --upgrade --ignore-installed tensorflow
66
import matplotlib as mpl
7+
import os
78
import matplotlib.pyplot as plt
89
import numpy as np
910
import tensorflow as tf
11+
import pandas as pd
1012
from tensorflow import keras
13+
from sklearn.preprocessing import StandardScaler
1114

1215
fashion_mnist = keras.datasets.fashion_mnist
1316
# 共六万图片
1417
(x_train_all, y_train_all), (x_test, y_test) = fashion_mnist.load_data()
1518
x_valid, x_train = x_train_all[:5000], x_train_all[5000:]
1619
y_valid, y_train = y_train_all[:5000], y_train_all[5000:]
1720

21+
print(np.max(x_train), np.min(x_train))
22+
23+
# 做归一化
24+
scaler = StandardScaler()
25+
x_train_scaled = scaler.fit_transform(x_train.astype(np.float32).reshape(-1, 1)).reshape(-1, 28, 28)
26+
27+
x_valid_scaled = scaler.transform(x_valid.astype(np.float32).reshape(-1, 1)).reshape(-1, 28, 28)
28+
29+
x_test_scaled = scaler.transform(x_test.astype(np.float32).reshape(-1, 1)).reshape(-1, 28, 28)
30+
31+
print(np.max(x_train_scaled), np.min(x_train_scaled))
32+
1833
print(x_valid.shape, y_valid.shape)
1934
print(x_train.shape, y_train.shape)
2035
print(x_test.shape, y_test.shape)
@@ -35,15 +50,68 @@ def show_images(n_rows, n_cols, x_data, y_data, class_names):
3550
plt.figure(figsize=(n_cols * 1.4, n_rows * 1.6))
3651
for row in range(n_rows):
3752
for col in range(n_cols):
38-
index = n_cols * n_rows + col
39-
plt.subplot(n_rows, n_cols, index + 1)
53+
# index 表示在n_rows行 n_cols列中的 哪个地方画图 取值范围为1,n_rows*n_cols
54+
index = n_rows * n_cols - col - row * n_cols
55+
plt.subplot(n_rows, n_cols, index)
4056
plt.imshow(x_data[index], cmap="binary", interpolation="nearest")
4157
plt.axis("off")
4258
plt.title(class_names[y_data[index]])
4359

4460
plt.show()
4561

4662

63+
def train():
64+
models = keras.models.Sequential()
65+
# 输入28*28的图像数据 28*28的矩阵变为 1*726的一维向量
66+
models.add(keras.layers.Flatten(input_shape=[28, 28]))
67+
# 全连接层 relu激活函数
68+
models.add(keras.layers.Dense(300, activation="relu"))
69+
70+
models.add(keras.layers.Dense(100, activation="relu"))
71+
72+
# softmax将向量变成概率分布
73+
# x= [x1,x2,x3]
74+
# y = [e^x1/sum,e^x2/sum,e^x3/sum],sum = e^x1+e^x2+e^x3
75+
models.add(keras.layers.Dense(10, activation="softmax"))
76+
77+
models.compile(loss="sparse_categorical_crossentropy",
78+
optimizer="sgd",
79+
metrics=["accuracy"])
80+
81+
# 模型的层
82+
print(models.layers)
83+
84+
# 可以看到模型的参数
85+
print(models.summary())
86+
87+
# TensorBoard earlytopping ModelCheckpoint
88+
89+
log_dir = "callbacks"
90+
if not os.path.exists(log_dir):
91+
os.mkdir(log_dir)
92+
ouput_model_file = os.path.join(log_dir, "fashion_mnist.h5")
93+
94+
callbacks = [
95+
keras.callbacks.TensorBoard(log_dir),
96+
keras.callbacks.ModelCheckpoint(ouput_model_file, save_best_only=True),
97+
keras.callbacks.EarlyStopping(patience=5, min_delta=1e-3),
98+
]
99+
history = models.fit(x_train_scaled, y_train, epochs=20, validation_data=(x_valid_scaled, y_valid),
100+
callbacks=callbacks)
101+
102+
# 测试集上评估
103+
print(models.evaluate(x_test_scaled, y_test))
104+
105+
# 存储训练过程中的一些值 如损失 accury等
106+
print(history.history)
107+
108+
pd.DataFrame(history.history).plot(figsize=(8, 5))
109+
plt.grid(True)
110+
plt.gca().set_ylim(0, 1)
111+
plt.show()
112+
113+
47114
if __name__ == '__main__':
48-
show_single_image(x_train[0])
49-
# show_images(1,1,x_train,y_train,class_names)
115+
train()
116+
# show_single_image(x_train[0])
117+
# show_images(10,5,x_train,y_train,class_names)

0 commit comments

Comments
 (0)