Location via proxy:   [ UP ]  
[Report a bug]   [Manage cookies]                
0% found this document useful (0 votes)
1 views

Training Code: TensorFlow Lite

Uploaded by

20701025
Copyright
© All Rights Reserved
Available Formats
Download as TXT or PDF, or read online on Scribd
0% found this document useful (0 votes)
1 views

Training Code: TensorFlow Lite

Uploaded by

20701025
Copyright
© All Rights Reserved
Available Formats
Download as TXT or PDF, or read online on Scribd
You are on page 1/2

import numpy as np

import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense

# Expanded dataset: 15 samples, each a pair of raw (unscaled) input values.
# NOTE(review): the meaning of the two columns is not stated here (they look
# like temperature / humidity readings) -- confirm against the data source.
# The inputs are fed to the network without normalisation, so whatever runs
# the exported model must supply values on this same raw scale.
data = np.array([
    [30, 70], [25, 50], [20, 90], [35, 40], [28, 65],
    [22, 75], [32, 60], [26, 80], [29, 55], [31, 85],
    [24, 45], [27, 70], [33, 50], [23, 85], [34, 65],
])

# Targets are written on a 0-100 scale and divided by 100 so they fall in
# the (0, 1) range that the model's sigmoid output layer can produce.
# The parentheses keep the division on one logical line (a bare trailing
# "/" followed by a newline is a SyntaxError).
labels = (
    np.array([80, 50, 90, 20, 70, 75, 60, 85, 55, 95, 45, 70, 50, 90, 65])
    / 100.0
)

# Build the ANN model layer by layer: two fully-connected ReLU hidden
# layers of 16 units each, then a single sigmoid unit whose output lies
# in (0, 1) and is read as the rain likelihood.
model = Sequential()
model.add(Dense(16, activation='relu', input_shape=(2,)))  # takes the 2 input values
model.add(Dense(16, activation='relu'))
model.add(Dense(1, activation='sigmoid'))  # rain likelihood in (0, 1)

# Configure training: the Adam optimiser minimises mean squared error, and
# mean absolute error is reported as an additional metric.
model.compile(
    loss='mse',
    optimizer='adam',
    metrics=['mae'],
)

# Fit on the whole dataset for a fixed 500 epochs (no validation split),
# logging progress for every epoch.
model.fit(x=data, y=labels, epochs=500, verbose=1)

# Persist the trained Keras model to disk in HDF5 (.h5) form.
model.save("rain_prediction_model.h5")

from tensorflow.keras.losses import MeanSquaredError

# Reload the saved HDF5 model. The loss was compiled under the string name
# 'mse', so that name is mapped explicitly to a MeanSquaredError instance
# to let deserialisation resolve it.
custom_objects = dict(mse=MeanSquaredError())
model = tf.keras.models.load_model(
    "rain_prediction_model.h5",
    custom_objects=custom_objects,
)

# Turn the reloaded Keras model into a serialized TensorFlow Lite model.
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()

# Save the serialized TensorFlow Lite model to disk as raw bytes.
# The write must be indented under the `with` so it runs while the file
# is open; an unindented body is an IndentationError.
with open("rain_prediction_model.tflite", "wb") as f:
    f.write(tflite_model)

# Read the TensorFlow Lite model file back, so the C header is generated
# from exactly the bytes that were written to disk.
with open("rain_prediction_model.tflite", "rb") as f:
    tflite_model = f.read()

# Write the model to a C header file so it can be compiled into firmware.
# The header declares:
#   const unsigned char model[]  -- the raw .tflite bytes as hex literals
#   const unsigned int model_len -- the number of bytes in that array
# wrapped in a MODEL_H_ include guard.
# NOTE(review): TFLite Micro targets often require the model array to be
# aligned (e.g. alignas(16)); add that here if the target needs it.
BYTES_PER_LINE = 12  # hex bytes per row, purely for readability

with open("model.h", "w") as f:
    f.write("#ifndef MODEL_H_\n")
    f.write("#define MODEL_H_\n\n")
    # No trailing newline: the loop begins each row (including the first)
    # with its own "\n", so emitting one here would leave a blank line.
    f.write("const unsigned char model[] = {")

    # Convert the binary model data to hexadecimal format
    for i, byte in enumerate(tflite_model):
        if i % BYTES_PER_LINE == 0:  # start a new indented row
            f.write("\n ")
        f.write(f"0x{byte:02x}, ")  # a trailing comma is legal in a C initializer

    f.write("\n};\n\n")
    f.write(f"const unsigned int model_len = {len(tflite_model)};\n")
    f.write("\n#endif // MODEL_H_\n")

You might also like