author: sz_jmu
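Preface
TensorFlow is an open-source framework for machine learning and deep learning, developed by the Google Brain team and released in 2015. It is one of the most popular deep learning frameworks and is widely used to build, train, and deploy machine learning models, particularly for complex neural network tasks. TensorFlow provides flexible tools and libraries that support machine learning applications from research through production.
TensorFlow can run on the CPU alone or on CPU + GPU. The CPU-only setup is simpler to configure and more compatible; the GPU setup needs some additional steps.
When training large models, the CPU is often a serious performance bottleneck and training is slow. Neural network algorithms keep large buffers of parameters, activations, and gradients, and every value is fully updated in each training iteration. These buffers can exceed the cache of a conventional computer, so memory bandwidth usually becomes the main bottleneck. Compared with a CPU, a GPU has much higher memory bandwidth. Neural network training also involves little branching or complex control flow, so it maps well onto the parallel hardware of a GPU. For these reasons, installing the GPU environment for TensorFlow is a sensible choice.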
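Reference article:
Fixing a loss that stays NaN, or clearly wrong loss and accuracy values, when training with tensorflow-gpu (loss is negative when training on the GPU but normal on the CPU), CSDN blog
Note: when configuring the GPU version of TensorFlow, it is recommended to create the environment with Anaconda to avoid problems such as polluted environment variables. Installing and using Anaconda is not covered here.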
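I. Basic configuration of the TensorFlow environment
The TensorFlow GPU environment has compatibility issues that differ from system to system. If you simply install the latest version of everything, it is hard to end up with a working GPU setup for deep learning.
The versions of CUDA, cuDNN, Python, and TensorFlow must match one another.
For example, python=3.8, CUDA=11.3, cuDNN=8.2.1, tensorflow-gpu=2.7.0 is a combination that runs fairly stably.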
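Because the four versions have to line up, it can help to compare what is installed against the combination quoted above. The following is a minimal sketch only: the function name find_mismatches and the example version strings are hypothetical, and only the single combination from this article is encoded (other combinations may work equally well).

```python
# Combination quoted in this article; other valid combinations are NOT encoded here.
TESTED = {
    "python": "3.8",
    "cudatoolkit": "11.3",
    "cudnn": "8.2.1",
    "tensorflow-gpu": "2.7.0",
}


def find_mismatches(installed, expected=TESTED):
    """Return {package: (installed_version, expected_version)} for every
    package that is missing or whose version does not match the expected prefix."""
    mismatches = {}
    for name, want in expected.items():
        have = installed.get(name)
        # "3.8.18" matches "3.8"; "3.80" or "3.9.1" does not.
        if have is None or not (have == want or have.startswith(want + ".")):
            mismatches[name] = (have, want)
    return mismatches


if __name__ == "__main__":
    # Hypothetical versions, for example copied by hand from `conda list`.
    installed = {
        "python": "3.8.18",
        "cudatoolkit": "11.3.1",
        "cudnn": "8.2.1",
        "tensorflow-gpu": "2.10.0",
    }
    for name, (have, want) in find_mismatches(installed).items():
        print(f"{name}: installed {have}, article uses {want}")
```

A prefix match is used so that patch releases such as python 3.8.18 still count as matching 3.8.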
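In the Anaconda command prompt, create a dedicated environment for TensorFlow:
conda create -n tf_gpu_1 python=3.8
Many reference articles on sites such as CSDN say that CUDA has to be installed from the NVIDIA website, which is time-consuming and tedious. What TensorFlow actually needs is the cudatoolkit part of CUDA, so after activating the environment (conda activate tf_gpu_1) the following installation is enough:
conda install cudatoolkit=11.3
Install cuDNN, NVIDIA's GPU-accelerated library of deep learning primitives that TensorFlow uses on top of CUDA:
conda install cudnn=8.2.1
Install the GPU version of TensorFlow:
conda install tensorflow-gpu=2.7.0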
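After installation, in the tf_gpu_1 environment, check whether the graphics driver can be found.
Check whether TensorFlow recognizes the GPU device:
1. Enter the Python interpreter
python
2. Import the tensorflow library
import tensorflow as tf
3. Check whether a GPU device is found
print("GPU device available:", len(tf.config.list_physical_devices('GPU')) > 0)
If the TensorFlow environment is configured correctly and the host GPU is recognized, the deep learning framework is essentially set up. Because the TensorFlow and Python versions used here are not the latest, code you write later may need extra dependencies or adjustments depending on the situation.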
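II. Training a neural network model for handwritten Chinese character recognition
1. Preparing the dataset
Handwritten Chinese character recognition needs a very large dataset. A dataset of common characters alone has more than seven thousand classes, and calligraphy styles vary widely. To train a model that generalizes well and is accurate, the dataset has to be rich enough: it must capture the general features of each character and also cover different writing styles. The size of the dataset also suggests that a sufficiently powerful model is needed to recognize arbitrary handwritten characters.
For example, a commonly used handwritten Chinese character dataset can be downloaded here:
Handwriting dataset with the Chinese character as the class label (747M)
The file structure is as follows: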
├── data
│ ├── chinese-calligraphy-dataset
│ │ ├── ㄚ
│ │ ├── 一
│ │ ├── 丁
│ │ ├── 七
│ │ ├── 万
│ │ └── …
│ └── label_character.csv
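2. Writing the TensorFlow toolchain and usage notes
Once the dataset is ready, the first task is to process it into the format expected for neural network training. The raw dataset can also be processed programmatically to increase its diversity. The method used here is data augmentation: starting from the original dataset, each image is slightly rotated, shifted, shrunk or enlarged, and has its contrast adjusted. This helps the final model generalize better.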
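To make the idea of small random perturbations concrete, here is a minimal sketch that only samples one set of augmentation parameters per call. The function name sample_augmentation and the dictionary keys are hypothetical; the ranges are assumptions modeled on the values used in data_split.py later in this article, and applying them to an image (for example with Pillow) is left out.

```python
import random


def sample_augmentation(rng):
    """Draw one random set of augmentation parameters.

    The ranges are assumptions modeled on data_split.py:
    rotation within +/-15 degrees, shift within +/-5 pixels,
    scale 0.9-1.1, contrast and brightness 0.8-1.2.
    """
    return {
        "angle_deg": rng.uniform(-15, 15),
        "shift_x_px": rng.randint(-5, 5),
        "shift_y_px": rng.randint(-5, 5),
        "scale": rng.uniform(0.9, 1.1),
        "contrast": rng.uniform(0.8, 1.2),
        "brightness": rng.uniform(0.8, 1.2),
    }


if __name__ == "__main__":
    rng = random.Random(42)  # fixed seed so a run can be repeated
    for i in range(5):       # five augmented variants per source image
        print(i, sample_augmentation(rng))
```

Keeping the perturbations small matters for characters: a large rotation or shift can turn one character into something that looks like another.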
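Note: the code must be run from the command line inside the environment used in this example, tf_gpu_1.
2.1 Dataset split script
The script has three phases:
1. Copy the original dataset to the target directory and split it into train, val, and test, that is, the training, validation, and test sets.
2. Apply data augmentation to the split dataset. Each image produces 5 augmented images, which increases the variety of the dataset.
3. Check the validation set (val) for empty or under-filled class folders. Some characters have very few images, so with the 80%/20% train/validation ratio used in the code a class with only a handful of images can end up with no validation images at all. This phase re-checks every class folder in val and, where needed, moves images over from the training set so that no class is left empty.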
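The reason phase 3 is needed can be worked out by hand. The sketch below counts how many images of one class land in each split, assuming the split walks a shuffled list and compares each index against fractional cut-offs with a strict `<` (that comparison, and the function name split_counts, are assumptions of this sketch rather than a quote of the script).

```python
def split_counts(n_images, train_scale=0.8, val_scale=0.2):
    """Count how many of n_images fall into train / val / test when
    indices are compared against fractional cut-offs."""
    train_stop = n_images * train_scale
    val_stop = n_images * (train_scale + val_scale)
    counts = {"train": 0, "val": 0, "test": 0}
    for idx in range(n_images):
        if idx < train_stop:
            counts["train"] += 1
        elif idx < val_stop:
            counts["val"] += 1
        else:
            counts["test"] += 1
    return counts


if __name__ == "__main__":
    for n in (3, 4, 5, 10):
        print(n, split_counts(n))
```

With these cut-offs a class of 3 or 4 images keeps everything in train and gets an empty val folder, while a class of 5 gets 4 training images and 1 validation image.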
Usage:
Modify
src_data_folder = "S:/2_tensflow_Project/chinese-calligraphy-dataset-master/data/CursiveChineseCalligraphyDataset-master/Cursive_Chinese_Calligraphy_Dataset/Training"
target_data_folder = "s:/2_tensflow_Project/chinese-calligraphy-dataset-master/data/data-chinese-caoshu"
where src_data_folder is the path of the original dataset and target_data_folder is the target path for the split dataset (it will contain the three subfolders train, val, and test).
In the tf_gpu_1 environment, run the command python data_split.py
The complete code (data_split.py) is as follows:
import os
import random
from shutil import copy2, move
from concurrent.futures import ThreadPoolExecutor

from tqdm import tqdm
from PIL import Image, ImageEnhance


def copy_file(src_dest):
    # src_dest is a (source image path, target folder) tuple
    src_img_path, target_folder = src_dest
    copy2(src_img_path, target_folder)


def move_file(src_dest):
    src_img_path, target_folder = src_dest
    move(src_img_path, target_folder)


def augment_image(src_img_path, target_folder, num_augments=5):
    img = Image.open(src_img_path)
    stem, ext = os.path.splitext(os.path.basename(src_img_path))
    for i in range(num_augments):
        img_aug = img.copy()

        # Random rotation of up to 15 degrees in either direction
        angle = random.uniform(-15, 15)
        img_aug = img_aug.rotate(angle, fillcolor='white')

        # Random translation of up to 5 pixels in each direction
        max_translate = 5
        x_translate = random.randint(-max_translate, max_translate)
        y_translate = random.randint(-max_translate, max_translate)
        img_aug = img_aug.transform(img_aug.size, Image.AFFINE,
                                    (1, 0, x_translate, 0, 1, y_translate),
                                    fillcolor='white')

        # Random rescale, then resize back to the original size
        scale_factor = random.uniform(0.9, 1.1)
        w, h = img_aug.size
        img_aug = img_aug.resize((int(w * scale_factor), int(h * scale_factor)),
                                 Image.Resampling.LANCZOS)
        img_aug = img_aug.resize((w, h), Image.Resampling.LANCZOS)

        # Random contrast and brightness
        enhancer = ImageEnhance.Contrast(img_aug)
        img_aug = enhancer.enhance(random.uniform(0.8, 1.2))
        enhancer = ImageEnhance.Brightness(img_aug)
        img_aug = enhancer.enhance(random.uniform(0.8, 1.2))

        # Insert the _aug_ marker before the extension only
        aug_img_name = f"{stem}_aug_{i}{ext}"
        img_aug.save(os.path.join(target_folder, aug_img_name))


def data_set_split_and_augment(src_data_folder, target_data_folder,
                               train_scale=0.8, val_scale=0.2, test_scale=0.0,
                               num_augments=5):
    '''
    Read the source data folder, create the split folders, and generate
    augmented images for every image.
    :param src_data_folder: source folder
    :param target_data_folder: target folder
    :param train_scale: training set ratio
    :param val_scale: validation set ratio
    :param test_scale: test set ratio
    :param num_augments: number of augmented images generated per image
    :return:
    '''
    class_names = os.listdir(src_data_folder)
    split_names = ['train', 'val', 'test']

    # Phase 1: split. Skipped if every split folder already exists and is non-empty.
    data_split_completed = True
    for split_name in split_names:
        split_path = os.path.join(target_data_folder, split_name)
        if not os.path.exists(split_path) or len(os.listdir(split_path)) == 0:
            data_split_completed = False
            break

    if not data_split_completed:
        print("Starting dataset split")
        for split_name in split_names:
            split_path = os.path.join(target_data_folder, split_name)
            os.makedirs(split_path, exist_ok=True)
            for class_name in class_names:
                class_split_path = os.path.join(split_path, class_name)
                os.makedirs(class_split_path, exist_ok=True)

        tasks = []
        total_files = 0
        for class_name in class_names:
            current_class_data_path = os.path.join(src_data_folder, class_name)
            current_all_data = os.listdir(current_class_data_path)
            total_files += len(current_all_data)
            random.shuffle(current_all_data)

            train_folder = os.path.join(target_data_folder, 'train', class_name)
            val_folder = os.path.join(target_data_folder, 'val', class_name)
            test_folder = os.path.join(target_data_folder, 'test', class_name)

            train_stop_flag = len(current_all_data) * train_scale
            val_stop_flag = len(current_all_data) * (train_scale + val_scale)

            for idx, img_name in enumerate(current_all_data):
                src_img_path = os.path.join(current_class_data_path, img_name)
                if idx < train_stop_flag:
                    tasks.append((src_img_path, train_folder))
                elif idx < val_stop_flag:
                    tasks.append((src_img_path, val_folder))
                else:
                    tasks.append((src_img_path, test_folder))

        with ThreadPoolExecutor(max_workers=8) as executor:
            list(tqdm(executor.map(copy_file, tasks), total=total_files, desc="Copying files"))
        print("Dataset split finished!")
    else:
        print("Dataset split already done, skipping this step.")

    # Phase 2: augmentation. Skipped if every training class already has _aug_ files.
    data_augmentation_completed = True
    for class_name in class_names:
        train_class_folder = os.path.join(target_data_folder, 'train', class_name)
        if not any("_aug_" in fname for fname in os.listdir(train_class_folder)):
            data_augmentation_completed = False
            break

    if not data_augmentation_completed:
        print("Starting data augmentation")
        aug_tasks = []
        for split_name in split_names:
            split_folder = os.path.join(target_data_folder, split_name)
            for class_name in class_names:
                class_split_folder = os.path.join(split_folder, class_name)
                for img_name in os.listdir(class_split_folder):
                    # Only augment original images, never already-augmented ones
                    if "_aug_" not in img_name:
                        src_img_path = os.path.join(class_split_folder, img_name)
                        aug_tasks.append((src_img_path, class_split_folder, num_augments))

        with ThreadPoolExecutor(max_workers=8) as executor:
            list(tqdm(executor.map(lambda x: augment_image(*x), aug_tasks),
                      total=len(aug_tasks), desc="Augmenting data"))
        print("Data augmentation finished!")
    else:
        print("Data augmentation already done, skipping this step.")

    # Phase 3: make sure no class folder in val is empty or under-filled
    print("Checking and supplementing the val folder")
    supplement_tasks = []
    for class_name in class_names:
        val_class_folder = os.path.join(target_data_folder, 'val', class_name)
        train_class_folder = os.path.join(target_data_folder, 'train', class_name)
        val_files = os.listdir(val_class_folder)
        train_files = os.listdir(train_class_folder)

        if len(val_files) == 0 or len(val_files) < len(train_files) * 0.1:
            num_to_move = max(1, int(len(train_files) * 0.1))
            random.shuffle(train_files)
            files_to_move = train_files[:num_to_move]
            for file_name in files_to_move:
                src_img_path = os.path.join(train_class_folder, file_name)
                supplement_tasks.append((src_img_path, val_class_folder))

    with ThreadPoolExecutor(max_workers=8) as executor:
        list(tqdm(executor.map(move_file, supplement_tasks),
                  total=len(supplement_tasks), desc="Supplementing val folder"))
    print("val folder supplemented!")


if __name__ == '__main__':
    src_data_folder = "S:/2_tensflow_Project/chinese-calligraphy-dataset-master/data/CursiveChineseCalligraphyDataset-master/Cursive_Chinese_Calligraphy_Dataset/Training"
    target_data_folder = "s:/2_tensflow_Project/chinese-calligraphy-dataset-master/data/data-chinese-caoshu"
    data_set_split_and_augment(src_data_folder, target_data_folder, num_augments=5)
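2.2 Extracting the dataset labels
This example has 7318 character classes, and the labels are Chinese characters, so defining the labels by hand is clearly impractical. A script that extracts the labels from the dataset is needed instead.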
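The core idea can be sketched without TensorFlow: the labels are simply the class folder names in sorted order, written out as a Python list. The sketch below assumes that alphanumeric (code point) sorting gives the same order that the Keras directory loader assigns to class indices; the helper names class_names_from_dir and render_labels_module are hypothetical.

```python
import tempfile
from pathlib import Path


def class_names_from_dir(data_dir):
    """Return the names of all subdirectories of data_dir, sorted.
    Plain files in data_dir are ignored."""
    return sorted(p.name for p in Path(data_dir).iterdir() if p.is_dir())


def render_labels_module(class_names, var_name="labels_caoshu"):
    """Render the label list as the source text of a small Python module."""
    lines = [f"{var_name} = ["]
    lines += [f"    {name!r}," for name in class_names]
    lines.append("]")
    return "\n".join(lines) + "\n"


if __name__ == "__main__":
    # Build a tiny throwaway dataset layout: one folder per character.
    with tempfile.TemporaryDirectory() as tmp:
        for name in ["丁", "一", "七"]:
            (Path(tmp) / name).mkdir()
        names = class_names_from_dir(tmp)
        print(render_labels_module(names))
```

The index of a character in this list is the integer label the model predicts, so the list must be generated from the same folder the model was trained on.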
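Usage:
In the tf_gpu_1 environment, run the command python labels_get.py. The label file lables_caoshu.py is written to the parent directory of the training data folder.
The complete code (labels_get.py) is as follows: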
import tensorflow as tf
from pathlib import Path


def load_chinese_dataset(data_dir):
    data_dir = Path(data_dir).resolve()

    # Class names are inferred from the subfolder names
    dataset = tf.keras.preprocessing.image_dataset_from_directory(
        str(data_dir),
        label_mode='int',
        seed=123,
        batch_size=32,
        image_size=(256, 256)
    )

    class_names = dataset.class_names
    print("Class labels (Chinese):", class_names)

    # Write the labels as a Python list next to the data folder
    labels_file_path = Path(data_dir).parent / "lables_caoshu.py"
    with open(labels_file_path, 'w', encoding='utf-8') as file:
        file.write("labels_caoshu = [\n")
        for label in class_names:
            file.write(f"    '{label}',\n")
        file.write("]\n")
    print(f"Labels saved to {labels_file_path}")

    return dataset, class_names


if __name__ == '__main__':
    data_dir = r"S:/2_tensflow_Project/chinese-calligraphy-dataset-master/data/data-chinese-caoshu/train"
    dataset, chinese_labels = load_chinese_dataset(data_dir)

    # Print one batch as a sanity check
    for images, labels in dataset.take(1):
        print("Image batch:", images.numpy())
        print("Label batch:", labels.numpy())
2.3 Training the model
2.3.1 Training a model based on a convolutional neural network (CNN)
import tensorflow as tf
import matplotlib.pyplot as plt
from time import time
from pathlib import Path
import os

os.environ['PYTHONIOENCODING'] = 'utf-8'
os.environ['LANG'] = 'zh_CN.UTF-8'


def data_load(data_dir, test_data_dir, img_height, img_width, batch_size):
    data_dir = Path(data_dir).resolve()
    test_data_dir = Path(test_data_dir).resolve()

    train_ds = tf.keras.preprocessing.image_dataset_from_directory(
        str(data_dir),
        label_mode='categorical',
        seed=123,
        image_size=(img_height, img_width),
        batch_size=batch_size)

    val_ds = tf.keras.preprocessing.image_dataset_from_directory(
        str(test_data_dir),
        label_mode='categorical',
        seed=123,
        image_size=(img_height, img_width),
        batch_size=batch_size)

    class_names = train_ds.class_names
    return train_ds, val_ds, class_names


def model_load(IMG_SHAPE=(160, 160, 3), class_num=12):
    model = tf.keras.models.Sequential([
        # Scale pixel values from [0, 255] to [0, 1]
        tf.keras.layers.Lambda(lambda x: x / 255.0, input_shape=IMG_SHAPE),
        tf.keras.layers.Conv2D(32, (3, 3), activation='relu'),
        tf.keras.layers.MaxPooling2D(2, 2),
        tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
        tf.keras.layers.MaxPooling2D(2, 2),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.Dense(class_num, activation='softmax')
    ])
    model.summary()
    model.compile(optimizer='sgd', loss='categorical_crossentropy', metrics=['accuracy'])
    return model


def show_loss_acc(history):
    acc = history.history['accuracy']
    val_acc = history.history['val_accuracy']
    loss = history.history['loss']
    val_loss = history.history['val_loss']

    plt.figure(figsize=(8, 8))
    plt.subplot(2, 1, 1)
    plt.plot(acc, label='Training Accuracy')
    plt.plot(val_acc, label='Validation Accuracy')
    plt.legend(loc='lower right')
    plt.ylabel('Accuracy')
    plt.ylim([min(plt.ylim()), 1])
    plt.title('Training and Validation Accuracy')

    plt.subplot(2, 1, 2)
    plt.plot(loss, label='Training Loss')
    plt.plot(val_loss, label='Validation Loss')
    plt.legend(loc='upper right')
    plt.ylabel('Cross Entropy')
    plt.title('Training and Validation Loss')
    plt.xlabel('epoch')

    # Make sure the output folder exists before saving the figure
    os.makedirs('results', exist_ok=True)
    plt.savefig('results/results_cnn.png', dpi=100)


def train(epochs):
    begin_time = time()
    train_ds, val_ds, class_names = data_load(
        r"S:/2_tensflow_Project/chinese-calligraphy-dataset-master/data/data-chinese-all-more/train",
        r"S:/2_tensflow_Project/chinese-calligraphy-dataset-master/data/data-chinese-all-more/val",
        160, 160, 32)
    print("Class labels (Chinese):", class_names)

    model = model_load(class_num=len(class_names))
    history = model.fit(train_ds, validation_data=val_ds, epochs=epochs)

    os.makedirs('models', exist_ok=True)
    model.save("models/cnn_hanzi_2.h5")

    end_time = time()
    run_time = end_time - begin_time
    print('Total training time:', run_time, "s")
    show_loss_acc(history)


if __name__ == '__main__':
    train(epochs=40)
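2.3.2 Training a model based on a residual neural network (ResNet)
The character recognition model has a very large number of classes, and a model trained with a plain neural network may not generalize well. For training with this many classes a more complex network can be used, such as the ResNet residual network in this example, which in this project produced a model with much better accuracy and generalization.
Usage:
In the code fragment below, change the paths to the training and validation sets produced by the split script:
train_ds, val_ds, class_names = data_load(
    r"s:/2_tensflow_Project/chinese-calligraphy-dataset-master/data/data-chinese-caoshu/train",
    r"s:/2_tensflow_Project/chinese-calligraphy-dataset-master/data/data-chinese-caoshu/val",
    128, 128, 64)
Change the name under which the model is saved:
model.save("models/final_resnet50_chinese_kai", save_format='tf')
In the tf_gpu_1 environment, run the command python model_train_resnet50.py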
import tensorflow as tf
import matplotlib.pyplot as plt
from time import time
from pathlib import Path
import os

os.environ['PYTHONIOENCODING'] = 'utf-8'
os.environ['LANG'] = 'zh_CN.UTF-8'

# Mixed precision: compute in float16, keep variables in float32
tf.keras.mixed_precision.set_global_policy('mixed_float16')


def data_load(data_dir, test_data_dir, img_height, img_width, batch_size):
    data_dir = Path(data_dir).resolve()
    test_data_dir = Path(test_data_dir).resolve()

    train_ds = tf.keras.preprocessing.image_dataset_from_directory(
        str(data_dir),
        label_mode='categorical',
        seed=123,
        image_size=(img_height, img_width),
        batch_size=batch_size
    )
    class_names = train_ds.class_names

    # On-the-fly augmentation, applied to the training set only
    data_augmentation = tf.keras.Sequential([
        tf.keras.layers.experimental.preprocessing.RandomFlip('horizontal'),
        tf.keras.layers.experimental.preprocessing.RandomRotation(0.1),
    ])
    train_ds = train_ds.map(lambda x, y: (data_augmentation(x, training=True), y),
                            num_parallel_calls=tf.data.experimental.AUTOTUNE)

    val_ds = tf.keras.preprocessing.image_dataset_from_directory(
        str(test_data_dir),
        label_mode='categorical',
        seed=123,
        image_size=(img_height, img_width),
        batch_size=batch_size
    )

    train_ds = train_ds.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
    val_ds = val_ds.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
    return train_ds, val_ds, class_names


def model_load(IMG_SHAPE=(128, 128, 3), class_num=7200):
    # ImageNet-pretrained ResNet50 backbone without its classification head
    resnet = tf.keras.applications.ResNet50(weights='imagenet', include_top=False,
                                            input_shape=IMG_SHAPE)
    model = tf.keras.models.Sequential([
        resnet,
        tf.keras.layers.GlobalAveragePooling2D(),
        tf.keras.layers.Dense(1024, activation='relu'),
        tf.keras.layers.Dropout(0.5),
        # Keep the softmax output in float32 for numerical stability under mixed precision
        tf.keras.layers.Dense(class_num, activation='softmax', dtype='float32')
    ])
    model.summary()

    initial_learning_rate = 0.01
    lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        initial_learning_rate,
        decay_steps=10000,
        decay_rate=0.9,
        staircase=True
    )
    optimizer = tf.keras.optimizers.SGD(learning_rate=lr_schedule, momentum=0.9)
    model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])
    return model


def show_loss_acc(history):
    acc = history.history['accuracy']
    val_acc = history.history['val_accuracy']
    loss = history.history['loss']
    val_loss = history.history['val_loss']

    plt.figure(figsize=(8, 8))
    plt.subplot(2, 1, 1)
    plt.plot(acc, label='Training Accuracy')
    plt.plot(val_acc, label='Validation Accuracy')
    plt.legend(loc='lower right')
    plt.ylabel('Accuracy')
    plt.ylim([min(plt.ylim()), 1])
    plt.title('Training and Validation Accuracy')

    plt.subplot(2, 1, 2)
    plt.plot(loss, label='Training Loss')
    plt.plot(val_loss, label='Validation Loss')
    plt.legend(loc='upper right')
    plt.ylabel('Cross Entropy')
    plt.title('Training and Validation Loss')
    plt.xlabel('epoch')

    os.makedirs('results', exist_ok=True)
    plt.savefig('results/results_resnet.png', dpi=100)


def train(epochs):
    begin_time = time()
    train_ds, val_ds, class_names = data_load(
        r"s:/2_tensflow_Project/chinese-calligraphy-dataset-master/data/data-chinese-caoshu/train",
        r"s:/2_tensflow_Project/chinese-calligraphy-dataset-master/data/data-chinese-caoshu/val",
        128, 128, 64)
    print("Class labels (Chinese):", class_names)

    model = model_load(class_num=len(class_names))

    # Stop when val_loss has not improved for 10 epochs and restore the best weights
    early_stopping = tf.keras.callbacks.EarlyStopping(
        monitor='val_loss', patience=10, restore_best_weights=True
    )
    # A path without a .h5 extension is saved in the TensorFlow SavedModel format
    checkpoint = tf.keras.callbacks.ModelCheckpoint(
        'models/final_resnet50_chinese_kai',
        monitor='val_loss',
        save_best_only=True
    )

    history = model.fit(
        train_ds,
        validation_data=val_ds,
        epochs=epochs,
        callbacks=[early_stopping, checkpoint]
    )
    model.save("models/final_resnet50_chinese_kai", save_format='tf')

    end_time = time()
    run_time = end_time - begin_time
    print('Total training time:', run_time, "s")
    show_loss_acc(history)


if __name__ == '__main__':
    train(epochs=15)
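The learning-rate schedule in the training script (ExponentialDecay with staircase=True) can be reproduced by hand to see what rate is in effect at a given optimizer step. The sketch below is a plain-Python restatement of that formula with the script's values as defaults; the function name staircase_lr is hypothetical. Note that a step here is one batch, not one epoch.

```python
def staircase_lr(step, initial_lr=0.01, decay_steps=10000, decay_rate=0.9):
    """Learning rate after `step` optimizer steps with staircase decay:
    the rate is multiplied by decay_rate once every decay_steps steps."""
    return initial_lr * decay_rate ** (step // decay_steps)


if __name__ == "__main__":
    for step in (0, 9999, 10000, 25000, 100000):
        print(step, staircase_lr(step))
```

With a batch size of 64, 10000 steps correspond to 640000 training images, so on a small dataset the rate may never decay within 15 epochs; decay_steps is worth checking against the number of batches per epoch.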
2.4 Testing the model
The code is as follows. Change the test set path and the model path to the corresponding target paths.
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import os
import time
import random

# Use a font that can render Chinese characters in matplotlib
plt.rcParams['font.family'] = ['sans-serif']
plt.rcParams['font.sans-serif'] = ['SimHei']


def data_load(data_dir, test_data_dir, img_height, img_width, batch_size):
    class_names = sorted(
        [dir_name for dir_name in os.listdir(data_dir)
         if os.path.isdir(os.path.join(data_dir, dir_name))])
    print(f"Detected {len(class_names)} classes.")

    def process_path(file_path):
        # The class label is the name of the parent folder
        label = tf.strings.split(file_path, os.path.sep)[-2]
        label = tf.where(tf.equal(tf.constant(class_names), label))[0][0]
        img = tf.io.read_file(file_path)
        img = tf.image.decode_jpeg(img, channels=3)
        img = tf.image.resize(img, [img_height, img_width])
        return img, label

    def prepare_dataset(directory):
        list_ds = tf.data.Dataset.list_files(os.path.join(directory, '*/*'), shuffle=True)
        labeled_ds = list_ds.map(process_path, num_parallel_calls=tf.data.experimental.AUTOTUNE)
        labeled_ds = labeled_ds.batch(batch_size).prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
        return labeled_ds

    train_ds = prepare_dataset(data_dir)
    val_ds = prepare_dataset(test_data_dir)
    print(f"Loaded datasets: {len(train_ds)} training batches, {len(val_ds)} validation batches.")
    return train_ds, val_ds, class_names


def aggregate_labels_and_select(real_labels, pred_labels, num_classes, num_groups, selected_groups):
    # Merge class indices into num_groups buckets; the last bucket absorbs the remainder
    group_size = max(1, num_classes // num_groups)
    real_labels_agg = [min(label // group_size, num_groups - 1) for label in real_labels]
    pred_labels_agg = [min(label // group_size, num_groups - 1) for label in pred_labels]

    real_labels_selected = [real_labels_agg[i] for i in range(len(real_labels_agg))
                            if real_labels_agg[i] in selected_groups]
    pred_labels_selected = [pred_labels_agg[i] for i in range(len(pred_labels_agg))
                            if real_labels_agg[i] in selected_groups]
    return real_labels_selected, pred_labels_selected


def test_cnn():
    start_time = time.time()
    # The image size passed here must match the input size the model was trained with
    train_ds, test_ds, class_names = data_load(
        r"S:/2_tensflow_Project/chinese-calligraphy-dataset-master/data/data-chinese-all/train",
        r"S:/2_tensflow_Project/chinese-calligraphy-dataset-master/data/data-chinese-all/val",
        160, 160, 16)
    print(f"Data loading completed in {time.time() - start_time:.2f} seconds.")

    start_time = time.time()
    model = tf.keras.models.load_model("models/final_resnet50_chinese")
    print(f"Model loaded in {time.time() - start_time:.2f} seconds.")
    model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

    test_real_labels = []
    test_pre_labels = []

    print("Starting model inference...")
    inference_start_time = time.time()
    for batch_idx, (test_batch_images, test_batch_labels) in enumerate(test_ds):
        # Only the first 10 batches are evaluated
        if batch_idx >= 10:
            break
        if batch_idx % 2 == 0:
            print(f"Processing batch {batch_idx + 1}...")
        test_batch_labels = test_batch_labels.numpy()
        test_batch_pres = model.predict(test_batch_images)
        test_real_labels.extend(test_batch_labels)
        test_pre_labels.extend(np.argmax(test_batch_pres, axis=1))
    print(f"Inference completed in {time.time() - inference_start_time:.2f} seconds.")

    num_groups = 10
    selected_groups = random.sample(range(num_groups), 10)
    test_real_labels_agg, test_pre_labels_agg = aggregate_labels_and_select(
        test_real_labels, test_pre_labels, len(class_names), num_groups, selected_groups
    )

    print("Generating aggregated heatmaps...")
    heat_maps = np.zeros((num_groups, num_groups))
    for real_label, pred_label in zip(test_real_labels_agg, test_pre_labels_agg):
        heat_maps[real_label][pred_label] += 1

    # Normalize each row so it sums to 1; empty rows are left as zeros
    heat_maps_sum = np.sum(heat_maps, axis=1).reshape(-1, 1)
    heat_maps_sum[heat_maps_sum == 0] = 1
    heat_maps_float = heat_maps / heat_maps_sum

    output_dir = "results/aggregated_heatmap"
    os.makedirs(output_dir, exist_ok=True)
    save_path = os.path.join(output_dir, "aggregated_heatmap.png")
    show_heatmaps("Aggregated Heatmap", selected_groups, selected_groups,
                  heat_maps_float[selected_groups][:, selected_groups], save_path)
    print(f"Saved aggregated heatmap to {save_path}")
    print("All heatmaps generated and saved.")


def show_heatmaps(title, x_labels, y_labels, harvest, save_name):
    fig, ax = plt.subplots()
    im = ax.imshow(harvest, cmap="OrRd")
    ax.set_xticks(np.arange(len(y_labels)))
    ax.set_yticks(np.arange(len(x_labels)))
    ax.set_xticklabels(y_labels)
    ax.set_yticklabels(x_labels)
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")

    # Write the normalized value into every cell
    for i in range(len(x_labels)):
        for j in range(len(y_labels)):
            ax.text(j, i, round(harvest[i, j], 2), ha="center", va="center", color="black")

    ax.set_xlabel("Predict label")
    ax.set_ylabel("Actual label")
    ax.set_title(title)
    fig.tight_layout()
    plt.colorbar(im)
    plt.savefig(save_name, dpi=100)
    plt.close(fig)


if __name__ == '__main__':
    start_time = time.time()
    test_cnn()
    print(f"Total execution time: {time.time() - start_time:.2f} seconds.")
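With thousands of classes a full confusion matrix is unreadable, which is why the test script merges classes into a few groups before plotting. The following is a minimal pure-Python sketch of that aggregation, assuming consecutive class indices are bucketed together and any remainder classes fall into the last group; the function name grouped_confusion is hypothetical.

```python
def grouped_confusion(real, pred, num_classes, num_groups):
    """Build a row-normalized num_groups x num_groups confusion matrix
    from integer class labels, bucketing consecutive class indices."""
    group_size = max(1, num_classes // num_groups)

    def to_group(label):
        # Clamp so that remainder classes land in the last group
        return min(label // group_size, num_groups - 1)

    counts = [[0] * num_groups for _ in range(num_groups)]
    for r, p in zip(real, pred):
        counts[to_group(r)][to_group(p)] += 1

    rows = []
    for row in counts:
        total = sum(row) or 1  # avoid dividing by zero for empty rows
        rows.append([c / total for c in row])
    return rows


if __name__ == "__main__":
    real = [0, 1, 5, 9, 10]
    pred = [0, 4, 5, 2, 10]
    for row in grouped_confusion(real, pred, num_classes=11, num_groups=2):
        print(row)
```

Each row shows, for one group of true classes, the fraction of predictions that fell into each predicted group. Because grouping is by index rather than by visual similarity, the matrix is only a coarse sanity check.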
2.5 TensorBoard tool: model graph
In the statement model = tf.saved_model.load("models/final_resnet50_chinese"), change the path to the target model.
import tensorflow as tf

model = tf.saved_model.load("models/final_resnet50_chinese")
infer = model.signatures['serving_default']

log_dir = "logs/graph"
writer = tf.summary.create_file_writer(log_dir)


@tf.function
def model_inference(input_tensor):
    return infer(input_tensor)


# The input shape must match the input size the model was trained with
example_input = tf.random.normal([1, 160, 160, 3])

with writer.as_default():
    # Trace one inference call and export the graph for TensorBoard
    tf.summary.trace_on(graph=True, profiler=True)
    model_inference(example_input)
    tf.summary.trace_export(name="model_trace", step=0, profiler_outdir=log_dir)
    writer.flush()

print("Graph has been written to TensorBoard logs. You can view it using TensorBoard.")
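2.6 Using the model
This example is a Qt-based character recognition program. It loads the trained TensorFlow model, lets the user choose a local image as input, and recognizes the handwritten Chinese character in it.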
import tensorflow as tf
from PyQt5.QtGui import *
from PyQt5.QtCore import *
from PyQt5.QtWidgets import *
import sys
import cv2
from PIL import Image
import numpy as np
import shutil
from lables_caoshu import labels_caoshu


class MainWindow(QTabWidget):
    def __init__(self):
        super().__init__()
        self.setWindowIcon(QIcon('images/logo.png'))
        self.setWindowTitle('CNN Chinese Character Recognition System')
        self.model = tf.keras.models.load_model("models/final_resnet50_chinese")
        self.to_predict_name = "images/Start_1.png"
        self.class_names = labels_caoshu
        self.resize(900, 700)
        self.initUI()

    def initUI(self):
        main_widget = QWidget()
        main_layout = QHBoxLayout()
        font = QFont('楷体', 15)

        # Dark theme for the whole window
        dark_style = """
        QWidget { background-color: #2E2E2E; color: #FFFFFF; }
        QLabel { color: #FFFFFF; }
        QPushButton {
            background-color: #4F4F4F;
            border: 2px solid #6E6E6E;
            color: #FFFFFF;
            padding: 5px;
            border-radius: 5px;
        }
        QPushButton:hover { background-color: #6E6E6E; }
        QPushButton:pressed { background-color: #3D3D3D; }
        QTabBar::tab {
            background: #3D3D3D;
            color: #FFFFFF;
            padding: 10px;
            border-radius: 5px;
        }
        QTabBar::tab:selected {
            background: #2E2E2E;
            border-bottom: 2px solid #4F4F4F;
        }
        QTabBar::tab:!selected { background: #3D3D3D; }
        """
        self.setStyleSheet(dark_style)

        # Left panel: the input image
        left_widget = QWidget()
        left_layout = QVBoxLayout()
        img_title = QLabel("Input image")
        img_title.setFont(font)
        img_title.setAlignment(Qt.AlignCenter)
        self.img_label = QLabel()
        self.process_image(self.to_predict_name)
        self.img_label.setPixmap(QPixmap("images/binary_show.png"))
        left_layout.addWidget(img_title)
        left_layout.addWidget(self.img_label, 1, Qt.AlignCenter)
        left_widget.setLayout(left_layout)

        # Right panel: buttons and recognition result
        right_widget = QWidget()
        right_layout = QVBoxLayout()
        btn_change = QPushButton(" Upload image ")
        btn_change.setIcon(QIcon('images/upload.png'))
        btn_change.clicked.connect(self.change_img)
        btn_change.setFont(font)
        btn_predict = QPushButton(" Recognize ")
        btn_predict.setIcon(QIcon('images/recognize.png'))
        btn_predict.setFont(font)
        btn_predict.clicked.connect(self.predict_img)
        label_result = QLabel(' Result ')
        self.result = QLabel("Waiting")
        label_result.setFont(QFont('楷体', 16))
        self.result.setFont(QFont('楷体', 24))
        right_layout.addStretch()
        right_layout.addWidget(label_result, 0, Qt.AlignCenter)
        right_layout.addStretch()
        right_layout.addWidget(self.result, 0, Qt.AlignCenter)
        right_layout.addStretch()
        right_layout.addStretch()
        right_layout.addWidget(btn_change)
        right_layout.addWidget(btn_predict)
        right_layout.addStretch()
        right_widget.setLayout(right_layout)

        main_layout.addWidget(left_widget)
        main_layout.addWidget(right_widget)
        main_widget.setLayout(main_layout)

        # "About" tab
        about_widget = QWidget()
        about_layout = QVBoxLayout()
        about_title = QLabel('Welcome to the Handwritten Chinese Character Recognition System')
        about_title.setFont(QFont('楷体', 18))
        about_title.setAlignment(Qt.AlignCenter)
        about_img = QLabel()
        about_img.setPixmap(QPixmap('images/CNN.png'))
        about_img.setAlignment(Qt.AlignCenter)
        label_super = QLabel("sz_jmu")
        label_super.setFont(QFont('楷体', 12))
        label_super.setAlignment(Qt.AlignRight)
        about_layout.addWidget(about_title)
        about_layout.addStretch()
        about_layout.addWidget(about_img)
        about_layout.addStretch()
        about_layout.addWidget(label_super)
        about_widget.setLayout(about_layout)

        self.addTab(main_widget, 'Home')
        self.addTab(about_widget, 'About')
        self.setTabIcon(0, QIcon('images/主页面.png'))
        self.setTabIcon(1, QIcon('images/关于.png'))

    def change_img(self):
        openfile_name = QFileDialog.getOpenFileName(
            self, 'Choose file', '', 'Image files(*.jpg *.png *.jpeg)')
        img_name = openfile_name[0]
        if img_name == '':
            pass
        else:
            # Copy the chosen file into images/ under a fixed name
            target_image_name = "images/tmp_up." + img_name.split(".")[-1]
            shutil.copy(img_name, target_image_name)
            self.to_predict_name = target_image_name
            self.process_image(self.to_predict_name)
            self.img_label.setPixmap(QPixmap("images/binary_show.png"))
            self.result.setText("Waiting")
            self.show_binary_image()

    def process_image(self, image_path):
        img_init = cv2.imread(image_path)
        h, w, c = img_init.shape
        scale = 400 / h
        img_show = cv2.resize(img_init, (0, 0), fx=scale, fy=scale)

        # Binarize, then invert so the strokes are white on black
        gray_img = cv2.cvtColor(img_init, cv2.COLOR_BGR2GRAY)
        _, binary_img = cv2.threshold(gray_img, 127, 255, cv2.THRESH_BINARY)
        binary_img_inverted = cv2.bitwise_not(binary_img)
        binary_img_colored = cv2.cvtColor(binary_img_inverted, cv2.COLOR_GRAY2BGR)
        binary_img_show = cv2.resize(binary_img_colored, (0, 0), fx=scale, fy=scale)

        cv2.imwrite('images/binary_show.png', binary_img_show)
        cv2.imwrite('images/target.png', binary_img_colored)
        cv2.imwrite('images/binary_target.png', binary_img)

    def predict_img(self):
        img = Image.open('images/target.png')
        # Must match the input size the model was trained with
        img = img.resize((128, 128))
        img = np.asarray(img)
        if img.shape[-1] != 3:
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
        img = img.reshape(1, 128, 128, 3)

        outputs = self.model.predict(img)
        result_index = int(np.argmax(outputs))
        result = self.class_names[result_index]
        self.result.setText(result)

    def show_binary_image(self):
        # Keep a reference on self so the window is not garbage-collected
        self.binary_image_window = QMainWindow()
        self.binary_image_window.setWindowTitle('Binarization result')
        binary_image_widget = QLabel()
        binary_image = QPixmap("images/binary_target.png")
        binary_image_widget.setPixmap(binary_image)
        binary_image_widget.setAlignment(Qt.AlignCenter)
        self.binary_image_window.setCentralWidget(binary_image_widget)
        self.binary_image_window.resize(binary_image.size())
        self.binary_image_window.show()

    def closeEvent(self, event):
        reply = QMessageBox.question(self, 'Exit', "Do you want to exit the program?",
                                     QMessageBox.Yes | QMessageBox.No, QMessageBox.No)
        if reply == QMessageBox.Yes:
            event.accept()
        else:
            event.ignore()


if __name__ == "__main__":
    app = QApplication(sys.argv)
    x = MainWindow()
    x.show()
    sys.exit(app.exec_())
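The last step of recognition in the program above is turning the model's output vector into a character: take the index of the largest probability and look it up in the label list. The sketch below shows that lookup in plain Python, generalized to the k most likely characters. The function name top_k_labels and the example probabilities are hypothetical; with k=1 it corresponds to the argmax lookup the application performs.

```python
def top_k_labels(probabilities, class_names, k=3):
    """Return the k most probable (label, probability) pairs,
    most probable first."""
    ranked = sorted(range(len(probabilities)),
                    key=lambda i: probabilities[i],
                    reverse=True)
    return [(class_names[i], probabilities[i]) for i in ranked[:k]]


if __name__ == "__main__":
    # Hypothetical softmax output for a three-class label list
    class_names = ["一", "丁", "七"]
    probabilities = [0.1, 0.7, 0.2]
    print(top_k_labels(probabilities, class_names, k=1))
    print(top_k_labels(probabilities, class_names, k=2))
```

The lookup is only correct if the label list was generated from the same training folder, in the same order, as the model being loaded.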
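The result of running the program is as follows:
Summary of the tensorflow_gpu_tools toolchain:
Changelog
This section records new features, bug fixes, and other changes to the project code.
date: 2024.8.29
Added a command-line parameter interface to model_train_densenet169.py, so that the model training parameters can be configured conveniently from the command line.
For example: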
python script_name.py --train_data_dir "path/to/train_data" --test_data_dir "path/to/test_data" --img_height 128 --img_width 128 --batch_size 64 --epochs 15
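A command line like the one above could be parsed with the standard argparse module. The following is only a sketch of how such an interface might look, not the actual code of model_train_densenet169.py: the flag names are taken from the example command, while the defaults, the help texts, and the function name build_parser are assumptions.

```python
import argparse


def build_parser():
    """Build a parser for the training flags shown in the example command."""
    parser = argparse.ArgumentParser(description="Model training parameters")
    parser.add_argument("--train_data_dir", required=True,
                        help="folder with one subfolder per class (training set)")
    parser.add_argument("--test_data_dir", required=True,
                        help="folder with one subfolder per class (validation/test set)")
    # Defaults below are assumptions copied from the example values
    parser.add_argument("--img_height", type=int, default=128)
    parser.add_argument("--img_width", type=int, default=128)
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--epochs", type=int, default=15)
    return parser


if __name__ == "__main__":
    # Demo argument list; a real script would call parse_args() with no
    # arguments so that the values come from the command line.
    demo_args = ["--train_data_dir", "path/to/train_data",
                 "--test_data_dir", "path/to/test_data",
                 "--batch_size", "32"]
    args = build_parser().parse_args(demo_args)
    print(args.train_data_dir, args.img_height, args.batch_size, args.epochs)
```

Declaring type=int means a mistyped value such as --epochs ten is rejected with a usage message before any training starts.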