TensorFlow 是一个用于机器学习和深度学习的开源框架,由 Google Brain 团队开发并在 2015 年发布。它是目前最流行的深度学习框架之一,广泛用于构建、训练和部署机器学习模型,特别是在处理复杂的神经网络任务时。TensorFlow 提供了灵活的工具和库,支持从研究到生产环境中的机器学习应用。
注意:配置 TensorFlow GPU 版本时,建议使用 Anaconda 创建独立的虚拟环境,以避免污染系统环境变量等问题。Anaconda 的安装与使用此处不做过多赘述。
推荐使用如下版本组合:python=3.8、CUDA=11.3、cuDNN=8.2.1、tensorflow-gpu=2.7.0,运行较为稳定。
在 Anaconda 命令行环境(Anaconda Prompt)下,依次执行以下命令,创建并配置 TensorFlow 的专属环境:
REM Create and activate an isolated environment; python=3.8 lets conda pick the latest 3.8.x
conda create -n tf_gpu_1 python=3.8
conda activate tf_gpu_1
REM CUDA / cuDNN / TensorFlow versions used in this tutorial
conda install cudatoolkit=11.3
conda install cudnn=8.2.1
conda install tensorflow-gpu=2.7.0
import tensorflow as tf

# Prints True when TensorFlow can see at least one GPU device.
print("是否有 GPU 设备:", len(tf.config.list_physical_devices('GPU')) > 0)
本例使用以汉字名作为类别标签的手写(书法)数据集(747M),目录结构如下:
├── data
│ ├── chinese-calligraphy-dataset
│ │ ├── ㄚ
│ │ ├── 一
│ │ ├── 丁
│ │ ├── 七
│ │ ├── 万
│ │ └── …
│ └── label_character.csv
注意:代码需要在命令行中激活本例的 tf_gpu_1 环境后运行。运行前请将以下路径修改为本地数据集的实际路径:
# Source: the original training images, one sub-folder per character class.
src_data_folder = "S:/2_tensflow_Project/chinese-calligraphy-dataset-master/data/CursiveChineseCalligraphyDataset-master/Cursive_Chinese_Calligraphy_Dataset/Training"
# Destination: receives the train/ val/ test/ split folders.
target_data_folder = "s:/2_tensflow_Project/chinese-calligraphy-dataset-master/data/data-chinese-caoshu"
在 tf_gpu_1 环境下,使用命令 python <脚本文件名>.py 运行以下数据集划分与数据增强代码:
import os
import random
from shutil import copy2, move
from concurrent.futures import ThreadPoolExecutor

from tqdm import tqdm
from PIL import Image, ImageEnhance

# Marker inserted into the file names of generated (augmented) images.
AUG_MARKER = "_aug_"
SPLIT_NAMES = ('train', 'val', 'test')
MAX_WORKERS = 8


def copy_file(src_dest):
    """Copy one file, given as a ``(src_img_path, target_folder)`` pair."""
    src_img_path, target_folder = src_dest
    copy2(src_img_path, target_folder)


def move_file(src_dest):
    """Move one file, given as a ``(src_img_path, target_folder)`` pair."""
    src_img_path, target_folder = src_dest
    move(src_img_path, target_folder)


def augment_image(src_img_path, target_folder, num_augments=5):
    """Save ``num_augments`` randomly perturbed copies of one image.

    Each copy gets, in order: a rotation in [-15, 15] degrees, a shift of up
    to 5 px on each axis, a resize by a factor in [0.9, 1.1] and back, and
    contrast / brightness factors in [0.8, 1.2]. Uncovered areas are filled
    with white. Copies are written to ``target_folder`` as
    ``<stem>_aug_<i><ext>``, keeping the original extension.
    """
    # splitext only touches the last dot, so names like "a.b.jpg" stay intact.
    stem, ext = os.path.splitext(os.path.basename(src_img_path))
    max_translate = 5
    with Image.open(src_img_path) as img:
        img.load()
        for i in range(num_augments):
            img_aug = img.rotate(random.uniform(-15, 15), fillcolor='white')

            # Uniform shift in [-max_translate, max_translate] on each axis.
            x_translate = random.randint(-max_translate, max_translate)
            y_translate = random.randint(-max_translate, max_translate)
            img_aug = img_aug.transform(
                img_aug.size, Image.Transform.AFFINE,
                (1, 0, x_translate, 0, 1, y_translate), fillcolor='white')

            # Resizing and then restoring (w, h) keeps the glyph size; the
            # effect is a mild resampling blur/sharpen rather than a zoom.
            scale_factor = random.uniform(0.9, 1.1)
            w, h = img_aug.size
            img_aug = img_aug.resize(
                (max(1, int(w * scale_factor)), max(1, int(h * scale_factor))),
                Image.Resampling.LANCZOS)
            img_aug = img_aug.resize((w, h), Image.Resampling.LANCZOS)

            img_aug = ImageEnhance.Contrast(img_aug).enhance(random.uniform(0.8, 1.2))
            img_aug = ImageEnhance.Brightness(img_aug).enhance(random.uniform(0.8, 1.2))

            img_aug.save(os.path.join(target_folder, f"{stem}{AUG_MARKER}{i}{ext}"))


def _list_class_names(src_data_folder):
    """Return the class sub-directory names of ``src_data_folder`` (files are ignored)."""
    return sorted(name for name in os.listdir(src_data_folder)
                  if os.path.isdir(os.path.join(src_data_folder, name)))


def _run_parallel(func, tasks, desc):
    """Apply ``func`` to every task on a thread pool, showing a tqdm progress bar."""
    with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
        # Draining the iterator re-raises any exception raised by a worker.
        list(tqdm(executor.map(func, tasks), total=len(tasks), desc=desc))


def _split_is_done(target_data_folder):
    """Return True when every split folder already exists and is non-empty."""
    for split_name in SPLIT_NAMES:
        split_path = os.path.join(target_data_folder, split_name)
        if not os.path.exists(split_path) or not os.listdir(split_path):
            return False
    return True


def _split_dataset(src_data_folder, target_data_folder, class_names, train_scale, val_scale):
    """Copy each class's shuffled images into train / val / test sub-folders.

    The first ``train_scale`` fraction of a class goes to train, the next
    ``val_scale`` fraction to val, and the remainder to test.
    """
    for split_name in SPLIT_NAMES:
        for class_name in class_names:
            os.makedirs(os.path.join(target_data_folder, split_name, class_name), exist_ok=True)

    tasks = []
    for class_name in class_names:
        class_dir = os.path.join(src_data_folder, class_name)
        img_names = os.listdir(class_dir)
        random.shuffle(img_names)
        train_stop = len(img_names) * train_scale
        val_stop = len(img_names) * (train_scale + val_scale)
        for idx, img_name in enumerate(img_names):
            # Strict '<' gives train exactly ceil(n * train_scale) images;
            # '<=' adds an extra one whenever n * scale is whole (9 of 10, not 8).
            if idx < train_stop:
                split_name = 'train'
            elif idx < val_stop:
                split_name = 'val'
            else:
                split_name = 'test'
            tasks.append((os.path.join(class_dir, img_name),
                          os.path.join(target_data_folder, split_name, class_name)))
    _run_parallel(copy_file, tasks, "文件复制进度")


def _augmentation_is_done(target_data_folder, class_names):
    """Return True when every train class folder already contains augmented images."""
    return all(
        any(AUG_MARKER in fname
            for fname in os.listdir(os.path.join(target_data_folder, 'train', class_name)))
        for class_name in class_names)


def _augment_dataset(target_data_folder, class_names, num_augments):
    """Generate augmented copies of every original image in every split."""
    aug_tasks = []
    for split_name in SPLIT_NAMES:
        for class_name in class_names:
            folder = os.path.join(target_data_folder, split_name, class_name)
            for img_name in os.listdir(folder):
                # Augment originals only; re-augmenting '_aug_' files would
                # compound transforms (and there are none on a first run).
                if AUG_MARKER not in img_name:
                    aug_tasks.append((os.path.join(folder, img_name), folder, num_augments))
    _run_parallel(lambda task: augment_image(*task), aug_tasks, "数据增强进度")


def _supplement_val(target_data_folder, class_names):
    """Move ~10% of a class's train files into val when its val set is below 10% of train."""
    supplement_tasks = []
    for class_name in class_names:
        val_class_folder = os.path.join(target_data_folder, 'val', class_name)
        train_class_folder = os.path.join(target_data_folder, 'train', class_name)
        val_files = os.listdir(val_class_folder)
        train_files = os.listdir(train_class_folder)
        if not val_files or len(val_files) < len(train_files) * 0.1:
            num_to_move = max(1, int(len(train_files) * 0.1))
            random.shuffle(train_files)
            # NOTE(review): this can move an augmented copy into val while its
            # original stays in train, leaking near-duplicates across splits.
            supplement_tasks.extend(
                (os.path.join(train_class_folder, file_name), val_class_folder)
                for file_name in train_files[:num_to_move])
    _run_parallel(move_file, supplement_tasks, "补充val文件夹进度")


def data_set_split_and_augment(src_data_folder, target_data_folder, train_scale=0.8, val_scale=0.2,
                               test_scale=0.0, num_augments=5):
    """Split a class-per-folder dataset, augment it, and top up sparse val folders.

    The split and augmentation steps are skipped when their output already
    appears to exist; the val top-up step always runs.

    :param src_data_folder: source folder with one sub-folder per class
    :param target_data_folder: destination; receives train/, val/ and test/
    :param train_scale: fraction of each class copied to train
    :param val_scale: fraction of each class copied to val
    :param test_scale: kept for API compatibility; test receives whatever
        remains after train and val
    :param num_augments: number of augmented copies generated per original image
    :return: None
    """
    class_names = _list_class_names(src_data_folder)

    if _split_is_done(target_data_folder):
        print("数据集划分已经完成,跳过该步骤。")
    else:
        print("开始数据集划分")
        _split_dataset(src_data_folder, target_data_folder, class_names, train_scale, val_scale)
        print("数据集划分完成!")

    if _augmentation_is_done(target_data_folder, class_names):
        print("数据增强已经完成,跳过该步骤。")
    else:
        print("开始数据增强")
        _augment_dataset(target_data_folder, class_names, num_augments)
        print("数据增强完成!")

    print("开始检查并补充val文件夹")
    _supplement_val(target_data_folder, class_names)
    print("val文件夹补充完成!")


if __name__ == '__main__':
    src_data_folder = "S:/2_tensflow_Project/chinese-calligraphy-dataset-master/data/CursiveChineseCalligraphyDataset-master/Cursive_Chinese_Calligraphy_Dataset/Training"
    target_data_folder = "s:/2_tensflow_Project/chinese-calligraphy-dataset-master/data/data-chinese-caoshu"
    data_set_split_and_augment(src_data_folder, target_data_folder, num_augments=5)
在 tf_gpu_1 环境下使用命令 python labels_get.py 运行以下代码,将在脚本所在目录下生成标签文件。
import tensorflow as tf
from pathlib import Path


def load_chinese_dataset(data_dir, labels_file_path=None):
    """Load a class-per-folder image dataset and write its class names as a Python module.

    The generated file defines ``labels_caoshu = [...]`` in the order of
    ``dataset.class_names``. That is the index order the model's outputs use,
    and the GUI imports it via ``from lables_caoshu import labels_caoshu``.

    :param data_dir: directory with one sub-folder per class (folder name = label)
    :param labels_file_path: where to write the labels module; defaults to
        ``lables_caoshu.py`` next to this script
    :return: ``(dataset, class_names)``
    """
    data_dir = Path(data_dir).resolve()
    dataset = tf.keras.preprocessing.image_dataset_from_directory(
        str(data_dir),
        label_mode='int',
        seed=123,
        batch_size=32,
        image_size=(256, 256),
    )
    class_names = dataset.class_names
    print("类别标签(中文):", class_names)

    if labels_file_path is None:
        labels_file_path = Path(__file__).resolve().parent / "lables_caoshu.py"
    labels_file_path = Path(labels_file_path)

    # repr() yields a valid, correctly quoted Python string literal for each label.
    lines = ["labels_caoshu = [\n"]
    lines.extend(f"    {label!r},\n" for label in class_names)
    lines.append("]\n")
    with open(labels_file_path, 'w', encoding='utf-8') as file:
        file.writelines(lines)
    print(f"标签已保存到 {labels_file_path}")

    return dataset, class_names


if __name__ == '__main__':
    data_dir = r"S:/2_tensflow_Project/chinese-calligraphy-dataset-master/data/data-chinese-caoshu/train"
    dataset, chinese_labels = load_chinese_dataset(data_dir)
    for images, labels in dataset.take(1):
        print("图像批次:", images.numpy())
        print("标签批次:", labels.numpy())
import tensorflow as tf
import matplotlib.pyplot as plt
from time import time
from pathlib import Path
import os

# NOTE: these only affect child processes started later; the running
# interpreter's stdout encoding is fixed at start-up.
os.environ['PYTHONIOENCODING'] = 'utf-8'
os.environ['LANG'] = 'zh_CN.UTF-8'


def data_load(data_dir, test_data_dir, img_height, img_width, batch_size):
    """Build one-hot-labelled training and validation datasets from class folders.

    :param data_dir: training directory with one sub-folder per class
    :param test_data_dir: validation directory; presumably has the same class
        sub-folders as ``data_dir`` (labels are assigned per directory) - confirm
    :param img_height: target image height in pixels
    :param img_width: target image width in pixels
    :param batch_size: batch size for both datasets
    :return: ``(train_ds, val_ds, class_names)``, class names taken from the training set
    """
    def _from_directory(directory):
        return tf.keras.preprocessing.image_dataset_from_directory(
            str(Path(directory).resolve()),
            label_mode='categorical',
            seed=123,
            image_size=(img_height, img_width),
            batch_size=batch_size)

    train_ds = _from_directory(data_dir)
    val_ds = _from_directory(test_data_dir)
    return train_ds, val_ds, train_ds.class_names


def model_load(IMG_SHAPE=(160, 160, 3), class_num=12):
    """Build and compile a small two-convolution CNN classifier.

    Pixels are rescaled from [0, 255] to [0, 1] inside the model, so callers
    feed raw images. Compiled with SGD and categorical cross-entropy to match
    the one-hot labels produced by ``data_load``.
    """
    model = tf.keras.models.Sequential([
        # Same maths as Lambda(x / 255.0), but stored as a regular layer
        # config instead of non-portable serialized Python bytecode.
        tf.keras.layers.experimental.preprocessing.Rescaling(1.0 / 255, input_shape=IMG_SHAPE),
        tf.keras.layers.Conv2D(32, (3, 3), activation='relu'),
        tf.keras.layers.MaxPooling2D(2, 2),
        tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
        tf.keras.layers.MaxPooling2D(2, 2),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.Dense(class_num, activation='softmax'),
    ])
    model.summary()
    model.compile(optimizer='sgd', loss='categorical_crossentropy', metrics=['accuracy'])
    return model


def show_loss_acc(history):
    """Plot training/validation accuracy and loss curves to results/results_cnn.png."""
    acc = history.history['accuracy']
    val_acc = history.history['val_accuracy']
    loss = history.history['loss']
    val_loss = history.history['val_loss']

    plt.figure(figsize=(8, 8))
    plt.subplot(2, 1, 1)
    plt.plot(acc, label='Training Accuracy')
    plt.plot(val_acc, label='Validation Accuracy')
    plt.legend(loc='lower right')
    plt.ylabel('Accuracy')
    plt.ylim([min(plt.ylim()), 1])
    plt.title('Training and Validation Accuracy')

    plt.subplot(2, 1, 2)
    plt.plot(loss, label='Training Loss')
    plt.plot(val_loss, label='Validation Loss')
    plt.legend(loc='upper right')
    plt.ylabel('Cross Entropy')
    plt.title('Training and Validation Loss')
    plt.xlabel('epoch')

    # savefig does not create missing directories.
    os.makedirs('results', exist_ok=True)
    plt.savefig('results/results_cnn.png', dpi=100)
    plt.close()


def train(epochs):
    """Train the CNN for ``epochs`` epochs, save it to models/cnn_hanzi_2.h5 and plot the curves."""
    begin_time = time()
    train_ds, val_ds, class_names = data_load(
        r"S:/2_tensflow_Project/chinese-calligraphy-dataset-master/data/data-chinese-all-more/train",
        r"S:/2_tensflow_Project/chinese-calligraphy-dataset-master/data/data-chinese-all-more/val",
        160, 160, 32)
    print("类别标签(中文):", class_names)
    model = model_load(class_num=len(class_names))
    history = model.fit(train_ds, validation_data=val_ds, epochs=epochs)
    os.makedirs("models", exist_ok=True)
    model.save("models/cnn_hanzi_2.h5")
    end_time = time()
    run_time = end_time - begin_time
    print('该循环程序运行时间:', run_time, "s")
    show_loss_acc(history)


if __name__ == '__main__':
    train(epochs=40)
# Load the cursive (caoshu) split at 128x128 with batch size 64.
train_ds, val_ds, class_names = data_load(
    r"s:/2_tensflow_Project/chinese-calligraphy-dataset-master/data/data-chinese-caoshu/train",
    r"s:/2_tensflow_Project/chinese-calligraphy-dataset-master/data/data-chinese-caoshu/val",
    128, 128, 64)
# Save the trained model in TensorFlow SavedModel format.
model.save("models/final_resnet50_chinese_kai", save_format='tf')
在 tf_gpu_1 环境下,使用命令 python <训练脚本文件名>.py 运行以下 ResNet50 训练代码:
import tensorflow as tf
import matplotlib.pyplot as plt
from time import time
from pathlib import Path
import os

# NOTE: these only affect child processes started later; the running
# interpreter's stdout encoding is fixed at start-up.
os.environ['PYTHONIOENCODING'] = 'utf-8'
os.environ['LANG'] = 'zh_CN.UTF-8'

# Mixed precision: float16 compute, float32 variables. This is the stable API;
# tf.keras.mixed_precision.experimental is deprecated.
tf.keras.mixed_precision.set_global_policy('mixed_float16')


def data_load(data_dir, test_data_dir, img_height, img_width, batch_size):
    """Build augmented training and plain validation datasets (one-hot labels).

    :param data_dir: training directory with one sub-folder per class
    :param test_data_dir: validation directory with the same class sub-folders
    :param img_height: target image height in pixels
    :param img_width: target image width in pixels
    :param batch_size: batch size for both datasets
    :return: ``(train_ds, val_ds, class_names)``, class names from the training set
    """
    data_dir = Path(data_dir).resolve()
    test_data_dir = Path(test_data_dir).resolve()

    train_ds = tf.keras.preprocessing.image_dataset_from_directory(
        str(data_dir),
        label_mode='categorical',
        seed=123,
        image_size=(img_height, img_width),
        batch_size=batch_size)
    class_names = train_ds.class_names

    # Rotation only: a horizontally mirrored Chinese character is generally
    # not the same character, so RandomFlip would create mislabelled samples.
    data_augmentation = tf.keras.Sequential([
        tf.keras.layers.experimental.preprocessing.RandomRotation(0.1),
    ])
    train_ds = train_ds.map(
        lambda x, y: (data_augmentation(x, training=True), y),
        num_parallel_calls=tf.data.AUTOTUNE)

    val_ds = tf.keras.preprocessing.image_dataset_from_directory(
        str(test_data_dir),
        label_mode='categorical',
        seed=123,
        image_size=(img_height, img_width),
        batch_size=batch_size)

    train_ds = train_ds.prefetch(buffer_size=tf.data.AUTOTUNE)
    val_ds = val_ds.prefetch(buffer_size=tf.data.AUTOTUNE)
    return train_ds, val_ds, class_names


def model_load(IMG_SHAPE=(128, 128, 3), class_num=7200):
    """Build and compile an ImageNet-initialised ResNet50 classifier.

    ResNet50 (no top) -> global average pooling -> Dense(1024, relu) ->
    Dropout(0.5) -> softmax over ``class_num`` classes. Optimised with SGD
    (momentum 0.9) and an exponentially decaying learning rate.
    """
    # NOTE(review): Keras ResNet50 expects inputs passed through
    # tf.keras.applications.resnet50.preprocess_input; raw [0, 255] pixels are
    # fed here (and by the GUI). Consider adding it consistently everywhere.
    resnet = tf.keras.applications.ResNet50(weights='imagenet', include_top=False, input_shape=IMG_SHAPE)
    model = tf.keras.models.Sequential([
        resnet,
        tf.keras.layers.GlobalAveragePooling2D(),
        tf.keras.layers.Dense(1024, activation='relu'),
        tf.keras.layers.Dropout(0.5),
        # Under mixed_float16 the softmax output must be float32 for numerically stable loss.
        tf.keras.layers.Dense(class_num, activation='softmax', dtype='float32'),
    ])
    model.summary()

    initial_learning_rate = 0.01
    lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        initial_learning_rate, decay_steps=10000, decay_rate=0.9, staircase=True)
    optimizer = tf.keras.optimizers.SGD(learning_rate=lr_schedule, momentum=0.9)
    model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])
    return model


def show_loss_acc(history):
    """Plot training/validation accuracy and loss curves to results/results_resnet.png."""
    acc = history.history['accuracy']
    val_acc = history.history['val_accuracy']
    loss = history.history['loss']
    val_loss = history.history['val_loss']

    plt.figure(figsize=(8, 8))
    plt.subplot(2, 1, 1)
    plt.plot(acc, label='Training Accuracy')
    plt.plot(val_acc, label='Validation Accuracy')
    plt.legend(loc='lower right')
    plt.ylabel('Accuracy')
    plt.ylim([min(plt.ylim()), 1])
    plt.title('Training and Validation Accuracy')

    plt.subplot(2, 1, 2)
    plt.plot(loss, label='Training Loss')
    plt.plot(val_loss, label='Validation Loss')
    plt.legend(loc='upper right')
    plt.ylabel('Cross Entropy')
    plt.title('Training and Validation Loss')
    plt.xlabel('epoch')

    # savefig does not create missing directories.
    os.makedirs('results', exist_ok=True)
    plt.savefig('results/results_resnet.png', dpi=100)
    plt.close()


def train(epochs):
    """Train with early stopping and best-val_loss checkpoints, then save and plot."""
    begin_time = time()
    train_ds, val_ds, class_names = data_load(
        r"s:/2_tensflow_Project/chinese-calligraphy-dataset-master/data/data-chinese-caoshu/train",
        r"s:/2_tensflow_Project/chinese-calligraphy-dataset-master/data/data-chinese-caoshu/val",
        128, 128, 64)
    print("类别标签(中文):", class_names)
    model = model_load(class_num=len(class_names))

    early_stopping = tf.keras.callbacks.EarlyStopping(
        monitor='val_loss', patience=10, restore_best_weights=True)
    os.makedirs('models', exist_ok=True)
    # The save format is inferred from the path: no ".h5" suffix -> SavedModel.
    checkpoint = tf.keras.callbacks.ModelCheckpoint(
        'models/final_resnet50_chinese_kai', monitor='val_loss', save_best_only=True)

    history = model.fit(
        train_ds,
        validation_data=val_ds,
        epochs=epochs,
        callbacks=[early_stopping, checkpoint])
    model.save("models/final_resnet50_chinese_kai", save_format='tf')

    end_time = time()
    run_time = end_time - begin_time
    print('该循环程序运行时间:', run_time, "s")
    show_loss_acc(history)


if __name__ == '__main__':
    train(epochs=15)
2.4 模型的测试
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import os
import time
import random

# SimHei so Chinese text in figures renders instead of empty boxes.
plt.rcParams['font.family'] = ['sans-serif']
plt.rcParams['font.sans-serif'] = ['SimHei']
# Commonly needed with CJK fonts so the minus sign does not render as a box.
plt.rcParams['axes.unicode_minus'] = False


def data_load(data_dir, test_data_dir, img_height, img_width, batch_size):
    """Build integer-labelled datasets from class-per-folder image directories.

    Class names are the sorted sub-directory names of ``data_dir``. A file's
    label is the index of its parent folder name in that list, so every class
    folder under ``test_data_dir`` must also exist under ``data_dir``.

    :return: ``(train_ds, val_ds, class_names)``
    """
    class_names = sorted(
        dir_name for dir_name in os.listdir(data_dir)
        if os.path.isdir(os.path.join(data_dir, dir_name)))
    print(f"Detected {len(class_names)} classes.")
    class_names_tensor = tf.constant(class_names)

    def process_path(file_path):
        # Normalise separators: on Windows the glob may return '\' or a mix of '\' and '/'.
        normalized = tf.strings.regex_replace(file_path, r"[\\/]+", "/")
        label = tf.strings.split(normalized, "/")[-2]
        label = tf.where(tf.equal(class_names_tensor, label))[0][0]
        img = tf.io.read_file(file_path)
        # decode_jpeg also decodes PNG and non-animated GIF.
        img = tf.image.decode_jpeg(img, channels=3)
        img = tf.image.resize(img, [img_height, img_width])
        return img, label

    def prepare_dataset(directory):
        list_ds = tf.data.Dataset.list_files(os.path.join(directory, '*/*'), shuffle=True)
        labeled_ds = list_ds.map(process_path, num_parallel_calls=tf.data.AUTOTUNE)
        return labeled_ds.batch(batch_size).prefetch(tf.data.AUTOTUNE)

    train_ds = prepare_dataset(data_dir)
    val_ds = prepare_dataset(test_data_dir)
    print(f"Loaded datasets: {len(train_ds)} training batches, {len(val_ds)} validation batches.")
    return train_ds, val_ds, class_names


def aggregate_labels_and_select(real_labels, pred_labels, num_classes, num_groups, selected_groups):
    """Map class indices to ``num_groups`` contiguous groups and keep selected pairs.

    Classes are bucketed as ``label // (num_classes // num_groups)``. Any
    leftover classes (when ``num_classes`` is not divisible by ``num_groups``)
    fall into the last group, so every group index is in ``range(num_groups)``.
    Pairs are kept when their *real* group is in ``selected_groups``.

    :return: ``(real_groups, pred_groups)`` for the kept pairs
    """
    # max(1, ...) avoids dividing by zero when num_classes < num_groups.
    group_size = max(1, num_classes // num_groups)
    last_group = num_groups - 1

    def to_group(label):
        return min(int(label) // group_size, last_group)

    selected = set(selected_groups)
    pairs = [(to_group(real), to_group(pred)) for real, pred in zip(real_labels, pred_labels)]
    real_labels_selected = [real for real, _ in pairs if real in selected]
    pred_labels_selected = [pred for real, pred in pairs if real in selected]
    return real_labels_selected, pred_labels_selected


def test_cnn():
    """Run the saved model on the first 10 validation batches and save a group-level heatmap."""
    start_time = time.time()
    train_ds, test_ds, class_names = data_load(
        r"S:/2_tensflow_Project/chinese-calligraphy-dataset-master/data/data-chinese-all/train",
        r"S:/2_tensflow_Project/chinese-calligraphy-dataset-master/data/data-chinese-all/val",
        160, 160, 16)
    print(f"Data loading completed in {time.time() - start_time:.2f} seconds.")

    start_time = time.time()
    model = tf.keras.models.load_model("models/final_resnet50_chinese")
    print(f"Model loaded in {time.time() - start_time:.2f} seconds.")
    model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

    # NOTE(review): assumes the model's output index order matches the sorted
    # class folders of the train directory - confirm with the training script.
    test_real_labels = []
    test_pre_labels = []
    print("Starting model inference...")
    inference_start_time = time.time()
    # Only the first 10 batches are evaluated to keep the run short.
    for batch_idx, (test_batch_images, test_batch_labels) in enumerate(test_ds.take(10)):
        if batch_idx % 2 == 0:
            print(f"Processing batch {batch_idx + 1}...")
        test_real_labels.extend(test_batch_labels.numpy())
        test_batch_pres = model.predict(test_batch_images)
        test_pre_labels.extend(np.argmax(test_batch_pres, axis=1))
    print(f"Inference completed in {time.time() - inference_start_time:.2f} seconds.")

    num_groups = 10
    # All groups, in random order (the order is also the heatmap's axis order).
    selected_groups = random.sample(range(num_groups), num_groups)
    test_real_labels_agg, test_pre_labels_agg = aggregate_labels_and_select(
        test_real_labels, test_pre_labels, len(class_names), num_groups, selected_groups)

    print("Generating aggregated heatmaps...")
    heat_maps = np.zeros((num_groups, num_groups))
    for real_label, pred_label in zip(test_real_labels_agg, test_pre_labels_agg):
        heat_maps[real_label][pred_label] += 1
    # Row-normalise; empty rows keep a divisor of 1 to avoid division by zero.
    heat_maps_sum = np.sum(heat_maps, axis=1).reshape(-1, 1)
    heat_maps_sum[heat_maps_sum == 0] = 1
    heat_maps_float = heat_maps / heat_maps_sum

    output_dir = "results/aggregated_heatmap"
    os.makedirs(output_dir, exist_ok=True)
    save_path = os.path.join(output_dir, "aggregated_heatmap.png")
    show_heatmaps("Aggregated Heatmap", selected_groups, selected_groups,
                  heat_maps_float[selected_groups][:, selected_groups], save_path)
    print(f"Saved aggregated heatmap to {save_path}")
    print("All heatmaps generated and saved.")


def show_heatmaps(title, x_labels, y_labels, harvest, save_name):
    """Draw ``harvest`` as an annotated heatmap (rows = actual, columns = predicted) and save it."""
    fig, ax = plt.subplots()
    im = ax.imshow(harvest, cmap="OrRd")
    ax.set_xticks(np.arange(len(y_labels)))
    ax.set_yticks(np.arange(len(x_labels)))
    ax.set_xticklabels(y_labels)
    ax.set_yticklabels(x_labels)
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
    for i in range(len(x_labels)):
        for j in range(len(y_labels)):
            ax.text(j, i, round(harvest[i, j], 2), ha="center", va="center", color="black")
    ax.set_xlabel("Predict label")
    ax.set_ylabel("Actual label")
    ax.set_title(title)
    fig.tight_layout()
    plt.colorbar(im)
    plt.savefig(save_name, dpi=100)
    plt.close(fig)


if __name__ == '__main__':
    start_time = time.time()
    test_cnn()
    print(f"Total execution time: {time.time() - start_time:.2f} seconds.")
2.5 TensorBoard 工具——模型流图
将 model = tf.saved_model.load("models/final_resnet50_chinese") 语句中的模型路径修改为目标模型的路径。
import tensorflow as tf

# Change this path to the SavedModel whose graph you want to inspect.
model = tf.saved_model.load("models/final_resnet50_chinese")
infer = model.signatures['serving_default']

# SavedModel signature functions are called with keyword arguments, so read the
# input's name and TensorSpec from the signature instead of passing it positionally.
input_name, input_spec = next(iter(infer.structured_input_signature[1].items()))

log_dir = "logs/graph"
writer = tf.summary.create_file_writer(log_dir)


@tf.function
def model_inference(input_tensor):
    """Run the serving signature; wrapped in tf.function so its graph can be traced."""
    return infer(**{input_name: input_tensor})


# Batch of one. Spatial/channel sizes come from the signature, falling back
# to 160x160x3 where the signature leaves a dimension undefined.
example_shape = [1] + [
    dim if dim is not None else default
    for dim, default in zip(input_spec.shape.as_list()[1:], (160, 160, 3))
]
example_input = tf.cast(tf.random.normal(example_shape), input_spec.dtype)

with writer.as_default():
    tf.summary.trace_on(graph=True, profiler=True)
    model_inference(example_input)
    tf.summary.trace_export(name="model_trace", step=0, profiler_outdir=log_dir)
    writer.flush()

print("Graph has been written to TensorBoard logs. You can view it using TensorBoard.")
2.6 模型的使用
import tensorflow as tf
from PyQt5.QtGui import QFont, QIcon, QPixmap
from PyQt5.QtCore import Qt
from PyQt5.QtWidgets import (
    QApplication,
    QFileDialog,
    QHBoxLayout,
    QLabel,
    QMainWindow,
    QMessageBox,
    QPushButton,
    QTabWidget,
    QVBoxLayout,
    QWidget,
)
import sys
import cv2
from PIL import Image
import numpy as np
import shutil
from lables_caoshu import labels_caoshu

# Side length of the square image fed to the model.
# NOTE(review): assumes the loaded model was trained on 128x128 inputs - confirm.
MODEL_INPUT_SIZE = 128

DARK_STYLE = """
    QWidget {
        background-color: #2E2E2E;
        color: #FFFFFF;
    }
    QLabel {
        color: #FFFFFF;
    }
    QPushButton {
        background-color: #4F4F4F;
        border: 2px solid #6E6E6E;
        color: #FFFFFF;
        padding: 5px;
        border-radius: 5px;
    }
    QPushButton:hover {
        background-color: #6E6E6E;
    }
    QPushButton:pressed {
        background-color: #3D3D3D;
    }
    QTabBar::tab {
        background: #3D3D3D;
        color: #FFFFFF;
        padding: 10px;
        border-radius: 5px;
    }
    QTabBar::tab:selected {
        background: #2E2E2E;
        border-bottom: 2px solid #4F4F4F;
    }
    QTabBar::tab:!selected {
        background: #3D3D3D;
    }
"""


class MainWindow(QTabWidget):
    """Two-tab window: upload a calligraphy image and classify it, plus an About tab."""

    def __init__(self):
        super().__init__()
        self.setWindowIcon(QIcon('images/logo.png'))
        self.setWindowTitle('CNN汉字识别系统')
        self.model = tf.keras.models.load_model("models/final_resnet50_chinese")
        self.to_predict_name = "images/Start_1.png"
        self.class_names = labels_caoshu
        # Keeps the binary-image pop-up alive; a window referenced only by a
        # local variable is destroyed as soon as the method returns.
        self.binary_image_window = None
        self.resize(900, 700)
        self.initUI()

    def initUI(self):
        """Apply the dark style sheet and build the main and About tabs."""
        self.setStyleSheet(DARK_STYLE)
        self.addTab(self._build_main_tab(), '主页')
        self.addTab(self._build_about_tab(), '关于')
        self.setTabIcon(0, QIcon('images/主页面.png'))
        self.setTabIcon(1, QIcon('images/关于.png'))

    def _build_main_tab(self):
        """Left: preview of the processed input. Right: result label plus upload/recognise buttons."""
        font = QFont('楷体', 15)

        left_widget = QWidget()
        left_layout = QVBoxLayout()
        img_title = QLabel("输入作品")
        img_title.setFont(font)
        img_title.setAlignment(Qt.AlignCenter)
        self.img_label = QLabel()
        self.process_image(self.to_predict_name)
        self.img_label.setPixmap(QPixmap("images/binary_show.png"))
        left_layout.addWidget(img_title)
        left_layout.addWidget(self.img_label, 1, Qt.AlignCenter)
        left_widget.setLayout(left_layout)

        right_widget = QWidget()
        right_layout = QVBoxLayout()
        btn_change = QPushButton(" 上传作品 ")
        btn_change.setIcon(QIcon('images/upload.png'))
        btn_change.clicked.connect(self.change_img)
        btn_change.setFont(font)
        btn_predict = QPushButton(" 开始识别 ")
        btn_predict.setIcon(QIcon('images/recognize.png'))
        btn_predict.setFont(font)
        btn_predict.clicked.connect(self.predict_img)
        label_result = QLabel(' 识别结果 ')
        self.result = QLabel("等待识别")
        label_result.setFont(QFont('楷体', 16))
        self.result.setFont(QFont('楷体', 24))
        right_layout.addStretch()
        right_layout.addWidget(label_result, 0, Qt.AlignCenter)
        right_layout.addStretch()
        right_layout.addWidget(self.result, 0, Qt.AlignCenter)
        right_layout.addStretch()
        right_layout.addStretch()
        right_layout.addWidget(btn_change)
        right_layout.addWidget(btn_predict)
        right_layout.addStretch()
        right_widget.setLayout(right_layout)

        main_widget = QWidget()
        main_layout = QHBoxLayout()
        main_layout.addWidget(left_widget)
        main_layout.addWidget(right_widget)
        main_widget.setLayout(main_layout)
        return main_widget

    def _build_about_tab(self):
        """Static About page: title, illustration and author tag."""
        about_widget = QWidget()
        about_layout = QVBoxLayout()
        about_title = QLabel('欢迎使用手写汉字识别系统')
        about_title.setFont(QFont('楷体', 18))
        about_title.setAlignment(Qt.AlignCenter)
        about_img = QLabel()
        about_img.setPixmap(QPixmap('images/CNN.png'))
        about_img.setAlignment(Qt.AlignCenter)
        label_super = QLabel("sz_jmu")
        label_super.setFont(QFont('楷体', 12))
        label_super.setAlignment(Qt.AlignRight)
        about_layout.addWidget(about_title)
        about_layout.addStretch()
        about_layout.addWidget(about_img)
        about_layout.addStretch()
        about_layout.addWidget(label_super)
        about_widget.setLayout(about_layout)
        return about_widget

    def change_img(self):
        """Let the user pick an image, copy it into images/, preprocess it and show the result."""
        img_name, _ = QFileDialog.getOpenFileName(
            self, 'chose files', '', 'Image files(*.jpg *.png *.jpeg)')
        if not img_name:
            return
        target_image_name = "images/tmp_up." + img_name.split(".")[-1]
        shutil.copy(img_name, target_image_name)
        self.to_predict_name = target_image_name
        if not self.process_image(self.to_predict_name):
            return
        self.img_label.setPixmap(QPixmap("images/binary_show.png"))
        self.result.setText("等待识别")
        self.show_binary_image()

    def process_image(self, image_path):
        """Binarise ``image_path`` and write the derived images used by the UI and model.

        Thresholds the grayscale image at 127, then inverts it (light
        background becomes black, dark strokes become white). Writes
        ``images/binary_show.png`` (inverted, scaled to 400 px high, for
        display), ``images/target.png`` (inverted, original size, 3-channel,
        read by ``predict_img``) and ``images/binary_target.png``
        (non-inverted binary image).

        :return: True on success, False if the image could not be read.
        """
        img_init = cv2.imread(image_path)
        if img_init is None:
            # cv2.imread returns None instead of raising for unreadable files.
            QMessageBox.warning(self, '错误', '无法读取图片: ' + image_path)
            return False
        h = img_init.shape[0]
        scale = 400 / h
        gray_img = cv2.cvtColor(img_init, cv2.COLOR_BGR2GRAY)
        _, binary_img = cv2.threshold(gray_img, 127, 255, cv2.THRESH_BINARY)
        binary_img_inverted = cv2.bitwise_not(binary_img)
        binary_img_colored = cv2.cvtColor(binary_img_inverted, cv2.COLOR_GRAY2BGR)
        binary_img_show = cv2.resize(binary_img_colored, (0, 0), fx=scale, fy=scale)
        cv2.imwrite('images/binary_show.png', binary_img_show)
        cv2.imwrite('images/target.png', binary_img_colored)
        cv2.imwrite('images/binary_target.png', binary_img)
        return True

    def predict_img(self):
        """Classify images/target.png and show the top-1 character."""
        with Image.open('images/target.png') as pil_img:
            # convert('RGB') guarantees 3 channels for grayscale or RGBA inputs.
            img = np.asarray(pil_img.convert('RGB').resize((MODEL_INPUT_SIZE, MODEL_INPUT_SIZE)))
        img = img.reshape(1, MODEL_INPUT_SIZE, MODEL_INPUT_SIZE, 3)
        outputs = self.model.predict(img)
        result_index = int(np.argmax(outputs))
        result = self.class_names[result_index]
        self.result.setText(result)

    def show_binary_image(self):
        """Open a separate window showing images/binary_target.png."""
        binary_image = QPixmap("images/binary_target.png")
        binary_image_widget = QLabel()
        binary_image_widget.setPixmap(binary_image)
        binary_image_widget.setAlignment(Qt.AlignCenter)

        binary_image_window = QMainWindow()
        binary_image_window.setWindowTitle('二值化处理结果')
        binary_image_window.setCentralWidget(binary_image_widget)
        binary_image_window.resize(binary_image.size())
        self.binary_image_window = binary_image_window
        binary_image_window.show()

    def closeEvent(self, event):
        """Ask for confirmation before closing; accepting the event closes the window."""
        reply = QMessageBox.question(self, '退出', "是否要退出程序?",
                                     QMessageBox.Yes | QMessageBox.No, QMessageBox.No)
        if reply == QMessageBox.Yes:
            event.accept()
        else:
            event.ignore()


if __name__ == "__main__":
    app = QApplication(sys.argv)
    x = MainWindow()
    x.show()
    sys.exit(app.exec_())
训练脚本新增了参数接口化配置,使用如下命令行即可便捷地配置模型训练参数(将脚本名替换为实际的训练脚本文件名):
REM Replace train_script.py with the training script that accepts these options.
python train_script.py --train_data_dir "path/to/train_data" --test_data_dir "path/to/test_data" --img_height 128 --img_width 128 --batch_size 64 --epochs 15
