Kera高层API
2021/4/15 18:56:36
本文主要是介绍Kera高层API,对大家解决编程问题具有一定的参考价值,需要的程序猿们随着小编来一起学习吧!
目录
- Keras != tf.keras
- Outline1
- Metrics
- Step1.Build a meter
- Step2.Update data
- Step3.Get Average data
- Clear buffer
- Outline2
- Compile + Fit
- Individual loss and optimize1
- Now1
- Individual epoch and step2
- Now2
- Standard Progressbar
- Individual evaluation3
- Now3
- Evaluation
- Test
- Predict
Keras != tf.keras
Keras是一个框架
datasets
layers
losses
metrics
optimizers
Outline1
Metrics
update_state
result().numpy()
reset_states
Metrics
Step1.Build a meter
acc_meter = metrics.Accuarcy() loss_meter = metrics.Mean
Step2.Update data
loss_meter.update_state(loss) acc_meter.update_state(y,pred)
Step3.Get Average data
print(step, 'loss:', loss_meter.result().numpy()) # ... print(step,'Evaluate Acc:', total_correct/total, acc_meter.result().numpy()
Clear buffer
if step % 100 == 0: print(step, 'loss:', loss_meter.result().numpy()) loss_meter.reset_states() # ... if step % 500 == 0: total, total_correct = 0., 0 acc_meter.reset_states()
Outline2
Compile
Fit
Evaluate
Predict
Compile + Fit
Individual loss and optimize1
with tf.GradientTape() as tape: x = tf.reshape(x, (-1, 28*28)) out = network(x) y_onehot = tf.one_hot(y, depth=10) loss = tf.reduce_mean(tf.losses.categorical_crossentropy(y_onehot, out, from_logits=True)) grads = tape.gradient(loss, network.trainable_variables) optimizer.apply_gradients(zip(grads, network.trainable_variables))
Now1
network.compile(optimizer=optimizers.Adam(lr=0.01), loss=tf.losses.CategoricalCrossentropy(fromlogits=True), metircs=['accuracy'])
Individual epoch and step2
for epoch in range(epochs): for step, (x, y) in enumerate(db): # ...
Now2
network.compile(optimizer=optimizers.Adam(lr=0.01), loss=tf.losses.CategoricalCrossentropy(fromlogits=True), metircs=['accuracy']) network.fit(db, epochs=10)
Standard Progressbar
Individual evaluation3
if step % 500 == 0: total, total_correct = 0., 0 for step, (x, y) in enumerate(ds_val): x = tf.reshape(x, (-1, 28*28)) out = network(x) pred = tf.argmax(out, axis=1) pred = tf.cast(pred, dtype=tf.int32) correct = tf.equal(pred, y) total_correct += tf.reduce_sum(tf.cast(correct, dtype=tf.int32)).numpy() total += x.shape[0] print(step, 'Evaluate Acc:', total_correct/total)
Now3
network.compile(optimizer=optimizers.Adam(lr=0.01), loss=tf.losses.CategoricalCrossentropy(fromlogits=True), metircs=['accuracy']) # validation_freq=2表示2个epochs做一次验证 network.fit(db, epochs=10, validation_data=ds_val, validation_freq=2)
Evaluation
Test
network.compile(optimizer=optimizers.Adam(lr=0.01), loss=tf.losses.CategoricalCrossentropy(fromlogits=True), metircs=['accuracy']) # validation_freq=2表示2个epochs做一次验证 network.fit(db, epochs=10, validation_data=ds_val, validation_freq=2) network.evaluate(ds_val)
Predict
sample = next(iter(ds_val)) x = sample[0] y = sample[1] pred = network.predict(x) y = tf.argmax(y, axis=1) pred = tf.argmax(pre, axis=1) print(pred) print(y)
这篇关于Kera高层API的文章就介绍到这儿,希望我们推荐的文章对大家有所帮助,也希望大家多多支持为之网!
- 2024-05-15鸿蒙生态设备数量超8亿台
- 2024-05-13TiDB + ES:转转业财系统亿级数据存储优化实践
- 2024-05-09“2024鸿蒙零基础快速实战-仿抖音App开发(ArkTS版)”实战课程已上线
- 2024-05-09聊聊如何通过arthas-tunnel-server来远程管理所有需要arthas监控的应用
- 2024-05-09log4j2这么配就对了
- 2024-05-09nginx修改Content-Type
- 2024-05-09Redis多数据源,看这篇就够了
- 2024-05-09Google Chrome驱动程序 124.0.6367.62(正式版本)去哪下载?
- 2024-05-09有没有大佬知道这种数据应该怎么抓取呀?
- 2024-05-09这种运行结果里的10.100000001,怎么能最快改成10.1?