# 轨道故障识别 (Xception) — track defect recognition with an Xception image classifier
##导入包
import warnings
warnings.filterwarnings('ignore')
from keras.preprocessing.image import ImageDataGenerator
# 接下来是其它的导入语句
from skimage.io import  imread, imshow
from skimage.transform import  resize, rescale
from skimage.color import rgb2gray
from os import listdir, path
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import os,os.path
import visualkeras
from PIL import ImageFont
from tensorflow.keras import Model
from keras.applications.inception_v3 import InceptionV3
from tensorflow.keras.layers import Dense, Dropout, Flatten, BatchNormalization
from tensorflow.keras.preprocessing import image_dataset_from_directory
from keras.preprocessing.image import ImageDataGenerator
from sklearn.model_selection import train_test_split
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.utils import plot_model

# ---- Data pipeline: train / validation / test directory iterators ----
# Each directory must contain one sub-folder per entry in CLASS_NAMES.
# NOTE(review): the three 'C:/Users/..' paths are placeholders — point each one
# at its own split (train / validation / test).
IMAGE_SIZE = (299, 299)  # must match the Xception input_shape used for the model
BATCH_SIZE = 32
CLASS_NAMES = ['defective', 'no_defective']  # only these class sub-folders are loaded

# Xception's ImageNet weights expect pixels scaled to [-1, 1]; the xception
# preprocess_input function performs exactly that scaling (a plain
# rescale=1/255 would give [0, 1] and mismatch the pretrained features).
datagen = tf.keras.preprocessing.image.ImageDataGenerator(
    preprocessing_function=tf.keras.applications.xception.preprocess_input
)


def _flow(directory, shuffle, seed=None):
    """Return a categorical, RGB DirectoryIterator over *directory*.

    Args:
        directory: root folder holding one sub-folder per class in CLASS_NAMES.
        shuffle: whether batches are drawn in random order.
        seed: optional RNG seed for shuffling (None = flow_from_directory default).
    """
    return datagen.flow_from_directory(
        directory=directory,
        class_mode="categorical",
        target_size=IMAGE_SIZE,
        batch_size=BATCH_SIZE,
        color_mode='rgb',
        seed=seed,
        shuffle=shuffle,
        classes=CLASS_NAMES,
    )


train_gen = _flow('C:/Users/..', shuffle=True, seed=1234)
val_gen = _flow('C:/Users/..', shuffle=True, seed=1234)
# shuffle=False keeps test batches in file order so predictions line up with
# test_gen.filenames / test_gen.classes.
test_gen = _flow('C:/Users/..', shuffle=False)
# ---- 模型构建与训练 (Xception) / model construction and training ----
from tensorflow.keras.applications import Xception
from tensorflow.keras.models import Model
from tensorflow.keras.layers import GlobalAveragePooling2D, Dense, Dropout
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
# Load the ImageNet-pretrained Xception backbone without its top classification head.
base_model = Xception(weights='imagenet', include_top=False, input_shape=(299, 299, 3))
# Freeze the whole backbone so only the new head is trained in this first stage
# (setting trainable on the model propagates to all of its layers).
base_model.trainable = False

# Classification head: pooled backbone features -> dropout -> dense -> softmax.
x = base_model.output
x = GlobalAveragePooling2D()(x)
x = Dropout(0.2)(x)  # regularise the pooled features to reduce overfitting
x = Dense(1024, activation='relu')(x)
# One output unit per class found by the training iterator, so the head stays
# in sync with the generators' class list instead of a hard-coded count.
predictions = Dense(train_gen.num_classes, activation='softmax')(x)

model = Model(inputs=base_model.input, outputs=predictions)
# Compile with a small learning rate; one-hot labels come from class_mode="categorical".
model.compile(optimizer=Adam(learning_rate=0.0001), loss='categorical_crossentropy', metrics=['accuracy'])

# Callbacks, all driven by validation loss:
#  - keep the best model on disk,
#  - stop after 5 epochs without improvement and restore the best weights,
#  - shrink the learning rate 5x after 2 stagnant epochs (floor 1e-5).
checkpoint = ModelCheckpoint('best_model_improved.h5', monitor='val_loss', verbose=1, save_best_only=True, mode='min')
early_stopping = EarlyStopping(monitor='val_loss', patience=5, verbose=1, mode='min', restore_best_weights=True)
reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=2, verbose=1, mode='min', min_lr=0.00001)

callbacks_list = [checkpoint, early_stopping, reduce_lr]

# Train the head; `history.history` holds per-epoch loss/accuracy metrics.
history = model.fit(
    train_gen,
    epochs=15,
    validation_data=val_gen,
    callbacks=callbacks_list,
)
import os
import matplotlib.pyplot as plt
from matplotlib.font_manager import FontProperties

# A CJK-capable font is needed for the Chinese title/labels; adjust FONT_PATH to your system.
FONT_PATH = 'C:/Users/yusialone/Downloads/Compressed/SimHei.ttf'
if os.path.isfile(FONT_PATH):
    font = FontProperties(fname=FONT_PATH)
else:
    # Fall back to an installed SimHei (or the default sans-serif) rather than
    # pointing matplotlib at a missing file, which errors when text is drawn.
    font = FontProperties(family=['SimHei', 'sans-serif'])

# `history` is the History object returned by model.fit; one row per epoch.
history_df = pd.DataFrame(history.history)
# Keras numbers epochs from 1 in its training log; use the same numbering here.
epochs = range(1, len(history_df) + 1)

# Plot training vs. validation loss per epoch.
plt.figure(figsize=(8, 6))
plt.plot(epochs, history_df['loss'], label='训练损失')
plt.plot(epochs, history_df['val_loss'], label='验证损失')
plt.title('训练和验证损失随周期变化', fontproperties=font)
plt.xlabel('周期', fontproperties=font)
plt.ylabel('损失', fontproperties=font)
plt.legend(prop=font)
plt.show()

# (Residue from the source web page's comment form, kept as comments:)
# 1 Comment
# 发表回复
# 您的邮箱地址不会被公开。 必填项已用 * 标注