912 lines
47 KiB
Plaintext
912 lines
47 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "c56fea40",
|
|
"metadata": {},
|
|
"source": [
|
|
"## MODEL TRAINING"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"id": "49c0547d",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"[nltk_data] Downloading package punkt to /Users/xiao_deng/nltk_data...\n",
|
|
"[nltk_data] Package punkt is already up-to-date!\n",
|
|
"[nltk_data] Downloading package wordnet to\n",
|
|
"[nltk_data] /Users/xiao_deng/nltk_data...\n",
|
|
"[nltk_data] Package wordnet is already up-to-date!\n",
|
|
"[nltk_data] Downloading package omw-1.4 to\n",
|
|
"[nltk_data] /Users/xiao_deng/nltk_data...\n",
|
|
"[nltk_data] Package omw-1.4 is already up-to-date!\n",
|
|
"[nltk_data] Downloading package stopwords to\n",
|
|
"[nltk_data] /Users/xiao_deng/nltk_data...\n",
|
|
"[nltk_data] Package stopwords is already up-to-date!\n",
|
|
"[nltk_data] Downloading package averaged_perceptron_tagger to\n",
|
|
"[nltk_data] /Users/xiao_deng/nltk_data...\n",
|
|
"[nltk_data] Package averaged_perceptron_tagger is already up-to-\n",
|
|
"[nltk_data] date!\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "1b378787d3434655aa2491f5bebd7faf",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
"Output()"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
|
|
],
|
|
"text/plain": []
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
|
|
"</pre>\n"
|
|
],
|
|
"text/plain": [
|
|
"\n"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "fb1743fd07404972b8a0b22bedcfa527",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
"Output()"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
|
|
],
|
|
"text/plain": []
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
|
|
"</pre>\n"
|
|
],
|
|
"text/plain": [
|
|
"\n"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "043d2ab3c31c487da32781800fabdf44",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
"Output()"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
|
|
],
|
|
"text/plain": []
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
|
|
"</pre>\n"
|
|
],
|
|
"text/plain": [
|
|
"\n"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "168f2251156144d39c7d0bf609653bfb",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
"Output()"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
|
|
],
|
|
"text/plain": []
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
|
|
"</pre>\n"
|
|
],
|
|
"text/plain": [
|
|
"\n"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"# imdb_reviews_cls.ipynb\n",
|
|
"# train a classification model to predict the reaction of reviews\n",
|
|
"#\n",
|
|
"# author : xiao deng\n",
|
|
"# date : 20210612\n",
|
|
"# platform: Macbook pro 14\n",
|
|
"\n",
|
|
"import os\n",
|
|
"import pickle\n",
|
|
"\n",
|
|
"import nltk\n",
|
|
"import numpy as np\n",
|
|
"from rich.progress import track\n",
|
|
"\n",
|
|
"\n",
|
|
"# 1) NLTK init\n",
|
|
"nltk.download(\"punkt\")\n",
|
|
"nltk.download('wordnet')\n",
|
|
"nltk.download('omw-1.4')\n",
|
|
"nltk.download('stopwords')\n",
|
|
"nltk.download('averaged_perceptron_tagger')\n",
|
|
"eng_stopwords = nltk.corpus.stopwords.words(\"english\")\n",
|
|
"lemmatizer = nltk.stem.wordnet.WordNetLemmatizer()\n",
|
|
"stemmer = nltk.stem.porter.PorterStemmer()\n",
|
|
"\n",
|
|
"# 2) Read text files of train data\n",
|
|
"train_pos_dir = './aclImdb/train/pos'\n",
|
|
"train_neg_dir = './aclImdb/train/neg'\n",
|
|
"pos_paths = [f'{train_pos_dir}/{file}' for file in os.listdir(train_pos_dir)]\n",
|
|
"neg_paths = [f'{train_neg_dir}/{file}' for file in os.listdir(train_neg_dir)]\n",
|
|
"\n",
|
|
"pos_reviews = []\n",
|
|
"neg_reviews = []\n",
|
|
"\n",
|
|
"for path in pos_paths:\n",
|
|
" with open(path) as f:\n",
|
|
" pos_reviews.append(f.read())\n",
|
|
"\n",
|
|
"for path in neg_paths:\n",
|
|
" with open(path) as f:\n",
|
|
" neg_reviews.append(f.read())\n",
|
|
"\n",
|
|
"# 3) Text Preprocessing\n",
|
|
"# Step1: sentence seg\n",
|
|
"pos_sentences = [nltk.sent_tokenize(review) for review in track(pos_reviews, 'Sentence tokenize pos reivews ...')]\n",
|
|
"neg_sentences = [nltk.sent_tokenize(review) for review in track(neg_reviews, 'Sentence tokenize neg reivews ...')]\n",
|
|
"\n",
|
|
"# Step2: word seg (apply lowercase, mark removal, digit removal, stopword removal, lemma, stemming)\n",
|
|
"word_code_map = {}\n",
|
|
"pos_map = {'J': nltk.corpus.wordnet.ADJ,\n",
|
|
" 'V': nltk.corpus.wordnet.VERB,\n",
|
|
" 'R': nltk.corpus.wordnet.ADV}\n",
|
|
"# positive reviews\n",
|
|
"pos_words = []\n",
|
|
"for sentences in track(pos_sentences, 'Word tokenize pos reivews ...'):\n",
|
|
" review_words = []\n",
|
|
" \n",
|
|
" for sentence in sentences:\n",
|
|
" words = nltk.tokenize.word_tokenize(sentence)\n",
|
|
" words = [word.lower() for word in words if word.isalnum() and not word.isdigit() and word not in eng_stopwords]\n",
|
|
" pos_tags = [ele[1][0] for ele in nltk.pos_tag(words)]\n",
|
|
" pos_tags = [pos_map[tag] if tag in pos_map else nltk.corpus.wordnet.NOUN for tag in pos_tags]\n",
|
|
" words = [lemmatizer.lemmatize(word, pos=pos_tags[i]) for i, word in enumerate(words)]\n",
|
|
" words = [stemmer.stem(word) for word in words]\n",
|
|
" for word in words:\n",
|
|
" if word not in word_code_map:\n",
|
|
" word_code_map[word] = len(word_code_map)\n",
|
|
" review_words += words\n",
|
|
" \n",
|
|
" pos_words.append(review_words)\n",
|
|
"\n",
|
|
"# negative reviews\n",
|
|
"neg_words = []\n",
|
|
"for sentences in track(neg_sentences, 'Word tokenize neg reivews ...'):\n",
|
|
" review_words = []\n",
|
|
" \n",
|
|
" for sentence in sentences:\n",
|
|
" words = nltk.tokenize.word_tokenize(sentence)\n",
|
|
" words = [word.lower() for word in words if word.isalnum() and not word.isdigit() and word not in eng_stopwords]\n",
|
|
" pos_tags = [ele[1][0] for ele in nltk.pos_tag(words)]\n",
|
|
" pos_tags = [pos_map[tag] if tag in pos_map else nltk.corpus.wordnet.NOUN for tag in pos_tags]\n",
|
|
" words = [lemmatizer.lemmatize(word, pos=pos_tags[i]) for i, word in enumerate(words)]\n",
|
|
" words = [stemmer.stem(word) for word in words]\n",
|
|
" for word in words:\n",
|
|
" if word not in word_code_map:\n",
|
|
" word_code_map[word] = len(word_code_map)\n",
|
|
" review_words += words\n",
|
|
" \n",
|
|
" neg_words.append(review_words)\n",
|
|
"\n",
|
|
"with open('word_code_map.pkl', 'wb') as f:\n",
|
|
" pickle.dump(word_code_map, f)\n",
|
|
"\n",
|
|
"# 4) Encoding\n",
|
|
"max_len = max([len(review) for review in pos_words+neg_words])\n",
|
|
"pos_x = []\n",
|
|
"for sentence in pos_words:\n",
|
|
" sentence_encode = [word_code_map[word] for word in sentence]\n",
|
|
" sentence_encode = np.pad(sentence_encode, (0, max_len-len(sentence_encode)), mode='constant')\n",
|
|
" pos_x.append(sentence_encode)\n",
|
|
"pos_y = [True] * len(pos_x)\n",
|
|
"\n",
|
|
"neg_x = []\n",
|
|
"for sentence in neg_words:\n",
|
|
" sentence_encode = [word_code_map[word] for word in sentence]\n",
|
|
" sentence_encode = np.pad(sentence_encode, (0, max_len-len(sentence_encode)), mode='constant')\n",
|
|
" neg_x.append(sentence_encode)\n",
|
|
"neg_y = [False] * len(neg_x)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"id": "5f3c6c05",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"train_x shape: (20000, 1486)\n",
|
|
"train_y shape: (20000,)\n",
|
|
"valid_x shape: (5000, 1486)\n",
|
|
"valid_y shape: (5000,)\n",
|
|
"Model: \"sequential_3\"\n",
|
|
"_________________________________________________________________\n",
|
|
" Layer (type) Output Shape Param # \n",
|
|
"=================================================================\n",
|
|
" embedding (Embedding) (None, 1486, 128) 6250752 \n",
|
|
" \n",
|
|
" dropout (Dropout) (None, 1486, 128) 0 \n",
|
|
" \n",
|
|
" conv1d (Conv1D) (None, 1486, 64) 41024 \n",
|
|
" \n",
|
|
" global_max_pooling1d (Globa (None, 64) 0 \n",
|
|
" lMaxPooling1D) \n",
|
|
" \n",
|
|
" dense_7 (Dense) (None, 32) 2080 \n",
|
|
" \n",
|
|
" batch_normalization_2 (Batc (None, 32) 128 \n",
|
|
" hNormalization) \n",
|
|
" \n",
|
|
" dropout_1 (Dropout) (None, 32) 0 \n",
|
|
" \n",
|
|
" dense_8 (Dense) (None, 1) 33 \n",
|
|
" \n",
|
|
"=================================================================\n",
|
|
"Total params: 6,294,017\n",
|
|
"Trainable params: 6,293,953\n",
|
|
"Non-trainable params: 64\n",
|
|
"_________________________________________________________________\n",
|
|
"Epoch 1/60\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"2022-06-13 18:17:45.783779: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:113] Plugin optimizer for device_type GPU is enabled.\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"625/625 [==============================] - ETA: 0s - loss: 0.7377 - Accuracy: 0.5212"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"2022-06-13 18:18:53.424627: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:113] Plugin optimizer for device_type GPU is enabled.\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"625/625 [==============================] - 70s 106ms/step - loss: 0.7377 - Accuracy: 0.5212 - val_loss: 0.7963 - val_Accuracy: 0.5042 - lr: 1.0000e-04\n",
|
|
"Epoch 2/60\n",
|
|
"625/625 [==============================] - 63s 101ms/step - loss: 0.6486 - Accuracy: 0.6177 - val_loss: 0.7735 - val_Accuracy: 0.5574 - lr: 1.0000e-04\n",
|
|
"Epoch 3/60\n",
|
|
"625/625 [==============================] - 62s 99ms/step - loss: 0.5696 - Accuracy: 0.7036 - val_loss: 0.5795 - val_Accuracy: 0.7050 - lr: 1.0000e-04\n",
|
|
"Epoch 4/60\n",
|
|
"625/625 [==============================] - 60s 96ms/step - loss: 0.4933 - Accuracy: 0.7629 - val_loss: 0.4854 - val_Accuracy: 0.7666 - lr: 1.0000e-04\n",
|
|
"Epoch 5/60\n",
|
|
"625/625 [==============================] - 63s 101ms/step - loss: 0.4167 - Accuracy: 0.8115 - val_loss: 0.4602 - val_Accuracy: 0.7908 - lr: 1.0000e-04\n",
|
|
"Epoch 6/60\n",
|
|
"625/625 [==============================] - 65s 103ms/step - loss: 0.3633 - Accuracy: 0.8471 - val_loss: 0.4089 - val_Accuracy: 0.8174 - lr: 1.0000e-04\n",
|
|
"Epoch 7/60\n",
|
|
"625/625 [==============================] - 66s 105ms/step - loss: 0.3154 - Accuracy: 0.8703 - val_loss: 0.3897 - val_Accuracy: 0.8308 - lr: 1.0000e-04\n",
|
|
"Epoch 8/60\n",
|
|
"625/625 [==============================] - 61s 98ms/step - loss: 0.2759 - Accuracy: 0.8866 - val_loss: 0.3817 - val_Accuracy: 0.8354 - lr: 1.0000e-04\n",
|
|
"Epoch 9/60\n",
|
|
"625/625 [==============================] - 64s 102ms/step - loss: 0.2406 - Accuracy: 0.9042 - val_loss: 0.3684 - val_Accuracy: 0.8470 - lr: 1.0000e-04\n",
|
|
"Epoch 10/60\n",
|
|
"625/625 [==============================] - 65s 104ms/step - loss: 0.2168 - Accuracy: 0.9165 - val_loss: 0.3664 - val_Accuracy: 0.8488 - lr: 1.0000e-04\n",
|
|
"Epoch 11/60\n",
|
|
"625/625 [==============================] - 64s 102ms/step - loss: 0.1838 - Accuracy: 0.9299 - val_loss: 0.3825 - val_Accuracy: 0.8484 - lr: 1.0000e-04\n",
|
|
"Epoch 12/60\n",
|
|
"625/625 [==============================] - 67s 106ms/step - loss: 0.1657 - Accuracy: 0.9376 - val_loss: 0.3853 - val_Accuracy: 0.8546 - lr: 1.0000e-04\n",
|
|
"Epoch 13/60\n",
|
|
"625/625 [==============================] - 66s 106ms/step - loss: 0.1514 - Accuracy: 0.9450 - val_loss: 0.3751 - val_Accuracy: 0.8590 - lr: 1.0000e-04\n",
|
|
"Epoch 14/60\n",
|
|
"625/625 [==============================] - 66s 106ms/step - loss: 0.1306 - Accuracy: 0.9509 - val_loss: 0.3860 - val_Accuracy: 0.8600 - lr: 1.0000e-04\n",
|
|
"Epoch 15/60\n",
|
|
"625/625 [==============================] - ETA: 0s - loss: 0.1237 - Accuracy: 0.9556\n",
|
|
"Epoch 15: ReduceLROnPlateau reducing learning rate to 1.9999999494757503e-05.\n",
|
|
"625/625 [==============================] - 68s 108ms/step - loss: 0.1237 - Accuracy: 0.9556 - val_loss: 0.4041 - val_Accuracy: 0.8604 - lr: 1.0000e-04\n",
|
|
"Epoch 16/60\n",
|
|
"625/625 [==============================] - 69s 111ms/step - loss: 0.1052 - Accuracy: 0.9626 - val_loss: 0.4071 - val_Accuracy: 0.8608 - lr: 2.0000e-05\n",
|
|
"Epoch 17/60\n",
|
|
"625/625 [==============================] - 70s 112ms/step - loss: 0.1012 - Accuracy: 0.9648 - val_loss: 0.4050 - val_Accuracy: 0.8608 - lr: 2.0000e-05\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"import random\n",
|
|
"\n",
|
|
"from tensorflow.keras import Sequential\n",
|
|
"from tensorflow.keras.losses import BinaryCrossentropy\n",
|
|
"from tensorflow.keras.layers import Dense, Embedding, GlobalMaxPooling1D, BatchNormalization\n",
|
|
"from tensorflow.keras.layers import Dropout, Conv1D\n",
|
|
"from tensorflow.keras.callbacks import TensorBoard, EarlyStopping, ReduceLROnPlateau\n",
|
|
"from tensorflow.keras.optimizers import Adam\n",
|
|
"\n",
|
|
"\n",
|
|
"# 5) Split training & validation set\n",
|
|
"seed = 0\n",
|
|
"random.Random(seed).shuffle(pos_x)\n",
|
|
"random.Random(seed).shuffle(neg_x)\n",
|
|
"\n",
|
|
"train_ratio = 0.8\n",
|
|
"idx_p = int(len(pos_x) * train_ratio)\n",
|
|
"idx_n = int(len(neg_x) * train_ratio)\n",
|
|
"train_x = pos_x[:idx_p] + neg_x[:idx_n]\n",
|
|
"train_y = pos_y[:idx_p] + neg_y[:idx_n]\n",
|
|
"valid_x = pos_x[idx_p:] + neg_x[idx_n:]\n",
|
|
"valid_y = pos_y[idx_p:] + neg_y[idx_n:]\n",
|
|
"\n",
|
|
"seed = 1\n",
|
|
"random.Random(seed).shuffle(train_x)\n",
|
|
"random.Random(seed).shuffle(train_y)\n",
|
|
"random.Random(seed).shuffle(valid_x)\n",
|
|
"random.Random(seed).shuffle(valid_y)\n",
|
|
"\n",
|
|
"train_x = np.array(train_x)\n",
|
|
"train_y = np.array(train_y)\n",
|
|
"valid_x = np.array(valid_x)\n",
|
|
"valid_y = np.array(valid_y)\n",
|
|
"\n",
|
|
"print(f'train_x shape: {train_x.shape}')\n",
|
|
"print(f'train_y shape: {train_y.shape}')\n",
|
|
"print(f'valid_x shape: {valid_x.shape}')\n",
|
|
"print(f'valid_y shape: {valid_y.shape}')\n",
|
|
"\n",
|
|
"# 6) Build CNN model with word2vec in TF (I used a similar model to classify EMR in my master thesis)\n",
|
|
"# ref: extension://bfdogplmndidlpjfhoijckpakkdjkkil/pdf/viewer.html?file=https%3A%2F%2Farxiv.org%2Fpdf%2F1408.5882.pdf\n",
|
|
"embedding_dim = 128\n",
|
|
"model = Sequential([\n",
|
|
" Embedding(len(word_code_map), embedding_dim, input_length=max_len),\n",
|
|
" Dropout(0.5),\n",
|
|
" Conv1D(64, 5, padding='same', activation='relu', strides=1),\n",
|
|
" GlobalMaxPooling1D(),\n",
|
|
" Dense(32, activation='linear'),\n",
|
|
" BatchNormalization(),\n",
|
|
" Dropout(0.5),\n",
|
|
" Dense(1, activation='sigmoid')\n",
|
|
"])\n",
|
|
"\n",
|
|
"init_lr = 1e-4\n",
|
|
"model.compile(optimizer=Adam(learning_rate=init_lr),\n",
|
|
" loss=BinaryCrossentropy(from_logits=True),\n",
|
|
" metrics=['Accuracy'])\n",
|
|
"\n",
|
|
"model.summary()\n",
|
|
"\n",
|
|
"# 7) Fit the training data\n",
|
|
"callbacks = [TensorBoard(log_dir=\"logs\"),\n",
|
|
" EarlyStopping(patience=7), # prevent overfitting\n",
|
|
" ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=5, verbose=1, mode='auto', min_lr=init_lr/100)]\n",
|
|
"\n",
|
|
"model.fit(\n",
|
|
" train_x,\n",
|
|
" train_y,\n",
|
|
" batch_size=32,\n",
|
|
" validation_data=(valid_x, valid_y),\n",
|
|
" epochs=60,\n",
|
|
" callbacks=callbacks\n",
|
|
")\n",
|
|
"\n",
|
|
"model.save('imdb_cls_model.h5')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "d9174db3",
|
|
"metadata": {},
|
|
"source": [
|
|
"## MODEL EVALUATION"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"id": "08ee9c8f",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"[nltk_data] Downloading package punkt to /Users/xiao_deng/nltk_data...\n",
|
|
"[nltk_data] Package punkt is already up-to-date!\n",
|
|
"[nltk_data] Downloading package wordnet to\n",
|
|
"[nltk_data] /Users/xiao_deng/nltk_data...\n",
|
|
"[nltk_data] Package wordnet is already up-to-date!\n",
|
|
"[nltk_data] Downloading package omw-1.4 to\n",
|
|
"[nltk_data] /Users/xiao_deng/nltk_data...\n",
|
|
"[nltk_data] Package omw-1.4 is already up-to-date!\n",
|
|
"[nltk_data] Downloading package stopwords to\n",
|
|
"[nltk_data] /Users/xiao_deng/nltk_data...\n",
|
|
"[nltk_data] Package stopwords is already up-to-date!\n",
|
|
"[nltk_data] Downloading package averaged_perceptron_tagger to\n",
|
|
"[nltk_data] /Users/xiao_deng/nltk_data...\n",
|
|
"[nltk_data] Package averaged_perceptron_tagger is already up-to-\n",
|
|
"[nltk_data] date!\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "a91d0135a31f45c18d410bc7f0e7e1ef",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
"Output()"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
|
|
],
|
|
"text/plain": []
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
|
|
"</pre>\n"
|
|
],
|
|
"text/plain": [
|
|
"\n"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "c5429bdd168d413087e8e6c43e71a5a8",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
"Output()"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
|
|
],
|
|
"text/plain": []
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
|
|
"</pre>\n"
|
|
],
|
|
"text/plain": [
|
|
"\n"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "fc0aafe9b45b45bda8b1e2bb1617b92f",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
"Output()"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
|
|
],
|
|
"text/plain": []
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
|
|
"</pre>\n"
|
|
],
|
|
"text/plain": [
|
|
"\n"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "fdafa9095d7c4b6b97f727d2f926df04",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
"Output()"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
|
|
],
|
|
"text/plain": []
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
|
|
"</pre>\n"
|
|
],
|
|
"text/plain": [
|
|
"\n"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"import os\n",
|
|
"import pickle\n",
|
|
"\n",
|
|
"import nltk\n",
|
|
"import numpy as np\n",
|
|
"from rich.progress import track\n",
|
|
"\n",
|
|
"\n",
|
|
"# 1) NLTK init\n",
|
|
"nltk.download(\"punkt\")\n",
|
|
"nltk.download('wordnet')\n",
|
|
"nltk.download('omw-1.4')\n",
|
|
"nltk.download('stopwords')\n",
|
|
"nltk.download('averaged_perceptron_tagger')\n",
|
|
"eng_stopwords = nltk.corpus.stopwords.words(\"english\")\n",
|
|
"lemmatizer = nltk.stem.wordnet.WordNetLemmatizer()\n",
|
|
"stemmer = nltk.stem.porter.PorterStemmer()\n",
|
|
"\n",
|
|
"# 2) Read text files of train data\n",
|
|
"test_pos_dir = './aclImdb/test/pos'\n",
|
|
"test_neg_dir = './aclImdb/test/neg'\n",
|
|
"pos_paths = [f'{test_pos_dir}/{file}' for file in os.listdir(test_pos_dir)]\n",
|
|
"neg_paths = [f'{test_neg_dir}/{file}' for file in os.listdir(test_neg_dir)]\n",
|
|
"\n",
|
|
"pos_reviews = []\n",
|
|
"neg_reviews = []\n",
|
|
"\n",
|
|
"for path in pos_paths:\n",
|
|
" with open(path) as f:\n",
|
|
" pos_reviews.append(f.read())\n",
|
|
"\n",
|
|
"for path in neg_paths:\n",
|
|
" with open(path) as f:\n",
|
|
" neg_reviews.append(f.read())\n",
|
|
"\n",
|
|
"# 3) Text Preprocessing\n",
|
|
"# Step1: sentence seg\n",
|
|
"pos_sentences = [nltk.sent_tokenize(review) for review in track(pos_reviews, 'Sentence tokenize pos reivews ...')]\n",
|
|
"neg_sentences = [nltk.sent_tokenize(review) for review in track(neg_reviews, 'Sentence tokenize neg reivews ...')]\n",
|
|
"\n",
|
|
"# Step2: word seg (apply lowercase, mark removal, digit removal, stopword removal, lemma, stemming)\n",
|
|
"with open('word_code_map.pkl', 'rb') as f:\n",
|
|
" word_code_map = pickle.load(f)\n",
|
|
"pos_map = {'J': nltk.corpus.wordnet.ADJ,\n",
|
|
" 'V': nltk.corpus.wordnet.VERB,\n",
|
|
" 'R': nltk.corpus.wordnet.ADV}\n",
|
|
"# positive reviews\n",
|
|
"pos_words = []\n",
|
|
"for sentences in track(pos_sentences, 'Word tokenize pos reivews ...'):\n",
|
|
" review_words = []\n",
|
|
" \n",
|
|
" for sentence in sentences:\n",
|
|
" words = nltk.tokenize.word_tokenize(sentence)\n",
|
|
" words = [word.lower() for word in words if word.isalnum() and not word.isdigit() and word not in eng_stopwords]\n",
|
|
" pos_tags = [ele[1][0] for ele in nltk.pos_tag(words)]\n",
|
|
" pos_tags = [pos_map[tag] if tag in pos_map else nltk.corpus.wordnet.NOUN for tag in pos_tags]\n",
|
|
" words = [lemmatizer.lemmatize(word, pos=pos_tags[i]) for i, word in enumerate(words)]\n",
|
|
" words = [stemmer.stem(word) for word in words]\n",
|
|
" review_words += words\n",
|
|
" \n",
|
|
" pos_words.append(review_words)\n",
|
|
"\n",
|
|
"# negative reviews\n",
|
|
"neg_words = []\n",
|
|
"for sentences in track(neg_sentences, 'Word tokenize neg reivews ...'):\n",
|
|
" review_words = []\n",
|
|
" \n",
|
|
" for sentence in sentences:\n",
|
|
" words = nltk.tokenize.word_tokenize(sentence)\n",
|
|
" words = [word.lower() for word in words if word.isalnum() and not word.isdigit() and word not in eng_stopwords]\n",
|
|
" pos_tags = [ele[1][0] for ele in nltk.pos_tag(words)]\n",
|
|
" pos_tags = [pos_map[tag] if tag in pos_map else nltk.corpus.wordnet.NOUN for tag in pos_tags]\n",
|
|
" words = [lemmatizer.lemmatize(word, pos=pos_tags[i]) for i, word in enumerate(words)]\n",
|
|
" words = [stemmer.stem(word) for word in words]\n",
|
|
" review_words += words\n",
|
|
" \n",
|
|
" neg_words.append(review_words)\n",
|
|
"\n",
|
|
"# 4) Encoding\n",
|
|
"max_len = 1486 # max_len of training set\n",
|
|
"pos_x = []\n",
|
|
"for sentence in pos_words:\n",
|
|
" sentence_encode = [word_code_map[word] for word in sentence if word in word_code_map]\n",
|
|
" if len(sentence_encode) <= max_len:\n",
|
|
" sentence_encode = np.pad(sentence_encode, (0, max_len-len(sentence_encode)), mode='constant')\n",
|
|
" else:\n",
|
|
" sentence_encode = sentence_encode[:max_len]\n",
|
|
" pos_x.append(sentence_encode)\n",
|
|
"pos_y = [True] * len(pos_x)\n",
|
|
"\n",
|
|
"neg_x = []\n",
|
|
"for sentence in neg_words:\n",
|
|
" sentence_encode = [word_code_map[word] for word in sentence if word in word_code_map]\n",
|
|
" if len(sentence_encode) <= max_len:\n",
|
|
" sentence_encode = np.pad(sentence_encode, (0, max_len-len(sentence_encode)), mode='constant')\n",
|
|
" else:\n",
|
|
" sentence_encode = sentence_encode[:max_len]\n",
|
|
" neg_x.append(sentence_encode)\n",
|
|
"neg_y = [False] * len(neg_x)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"id": "828afa33",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"test_x shape: (25000, 1486)\n",
|
|
"test_y shape: (25000,)\n",
|
|
"Metal device set to: Apple M1 Pro\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"2022-06-13 19:45:13.666996: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:305] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.\n",
|
|
"2022-06-13 19:45:13.667118: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:271] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
" 10/782 [..............................] - ETA: 4s "
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"2022-06-13 19:45:13.949010: W tensorflow/core/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz\n",
|
|
"2022-06-13 19:45:13.994430: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:113] Plugin optimizer for device_type GPU is enabled.\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"782/782 [==============================] - 3s 4ms/step\n",
|
|
"test_pred shape: (25000,)\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXsAAAElCAYAAAAFukKMAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8qNh9FAAAACXBIWXMAAAsTAAALEwEAmpwYAAAqj0lEQVR4nO3debwVdf3H8df7siuLgGgIKpj7Si6kuGsKVopZKqmJS5KGa2aplfualaWmSVbikoq7Zm4/Tc1yA1wQSCVJQUjcEhBl8/P7Y74XD9fL5dzDnXvvOef99DGPM/Odme/3Owif8z3f+c53FBGYmVllq2npCpiZWf4c7M3MqoCDvZlZFXCwNzOrAg72ZmZVwMHezKwKONjbCpPUSdK9kj6UdOsK5HOwpIeasm4tQdL9koa3dD3MCjnYVxFJB0kaK2mupJkpKO3QBFl/C1gd6BkR+5eaSUTcGBF7NkF9liJpF0kh6Y466Vuk9MeKzOcsSTcs77iI2CsiRpdYXbNcONhXCUk/AH4NXEAWmNcCrgSGNkH2awOvRsSiJsgrL+8AgyT1LEgbDrzaVAUo439T1ir5L2YVkNQNOAcYGRF3RMRHEbEwIu6NiFPSMR0k/VrSjLT8WlKHtG8XSdMlnSxpVvpVcHjadzZwBnBg+sVwZN0WsKR+qQXdNm0fJul1SXMkTZV0cEH6kwXnDZL0XOoeek7SoIJ9j0k6V9I/Uj4PSVq1gT+GBcBdwLB0fhvgAODGOn9Wv5E0TdJsSeMk7ZjShwCnF1zniwX1OF/SP4B5wDop7btp/1WSbivI/2JJj0hSsf//zJqCg3112A7oCNzZwDE/AbYFBgBbAAOBnxbs/wLQDegDHAn8VlL3iDiT7NfCLRHROSL+0FBFJK0MXAbsFRFdgEHAC/Uc1wO4Lx3bE/gVcF+dlvlBwOHAakB74IcNlQ1cBxya1gcDE4EZdY55juzPoAfwZ+BWSR0j4oE617lFwTnfAUYAXYA36uR3MrB5+iLbkezPbnh4nhJrZg721aEn8O5yulkOBs6JiFkR8Q5wNlkQq7Uw7V8YEX8F5gIblFifT4FNJXWKiJkRMbGeY74GvBYR10fEooi4CfgXsHfBMX+KiFcj4mNgDFmQXqaI+CfQQ9IGZEH/unqOuSEi3ktl/hLowPKv89qImJjOWVgnv3nAIWRfVjcAx0XE9OXkZ9bkHOyrw3vAqrXdKMuwBku3St9IaUvyqPNlMQ/o3NiKRMRHwIHA0cBMSfdJ2rCI+tTWqU/B9n9LqM/1wLHArtTzSyd1VU1OXUf/I/s101D3EMC0hnZGxLPA64DIvpTMmp2DfXV4CvgE2LeBY2aQ3WittRaf7+Io1kfASgXbXyjcGREPRsQeQG+y1vrvi6hPbZ3eKrFOta4Hvg/8NbW6l0jdLD8m68vvHhGrAB+SBWmAZXW9NNglI2kk2S+EGcCPSq652QpwsK8CEfEh2U3U30raV9JKktpJ2kvSz9NhNwE/ldQr3eg8g6zboRQvADtJWivdHD6tdoek1SXtk/ru55N1By2uJ4+/Auun4aJtJR0IbAz8pcQ6ARARU4Gdye5R1NUFWEQ2cqetpDOArgX73wb6NWbEjaT1gfPIunK+A/xI0oDSam9WOgf7KhERvwJ+QHbT9R2yrodjyUaoQBaQxgIvAROA8SmtlLIeBm5JeY1j6QBdQ3bTcgbwPlng/X49ebwHfD0d+x5Zi/jrEfFuKXWqk/eTEVHfr5YHgfvJhmO+QfZrqLCLpvaBsfckjV9eOanb7Abg4oh4MSJeIxvRc33tSCez5iIPCjAzq3xu2ZuZVQEHezOzKuBgb2ZWBRzszcyqgIO9mVkVcLA3M6sCDvZmZlXAwd7MrAo42JuZVQEHezOzKuBgb2ZWBRzszcyqgIO9mVkVcLA3M6sCDvZmZlXAwd7MrAo42JuZVQEHezOzKuBgb2ZWBdq2dAWWpdOXjvXLce1zPnjuipaugrVCHduiFc2jMTHn4+evWOHymlurDfZmZs1Kld3R4WBvZgagsmusN4qDvZkZuGVvZlYV3LI3M6sCNW1auga5crA3MwN345iZVQV345iZVQG37M3MqoBb9mZmVcAtezOzKuDROGZmVcAtezOzKlDjPnszs8rnlr2ZWRXwaBwzsyrgG7RmZlXA3ThmZlXA3ThmZlXALXszsyrglr2ZWRVwy97MrAp4NI6ZWRVwy97MrApUeJ99ZX+VmZkVSzXFL8vLSvqjpFmSXi5I6yHpYUmvpc/uBftOkzRF0iuSBhekbyVpQtp3mZR9I0nqIOmWlP6MpH7Lq5ODvZkZZC37YpfluxYYUiftVOCRiFgPeCRtI2ljYBiwSTrnSkm1NxCuAkYA66WlNs8jgQ8iYl3gUuDi5VXIwd7MDJq0ZR8RTwDv10keCoxO66OBfQvSb46I+RExFZgCDJTUG+gaEU9FRADX1TmnNq/bgN1rW/3L4mBvZgaopqb4RRohaWzBMqKIIlaPiJkA6XO1lN4HmFZw3PSU1iet101f6pyIWAR8CPRsqHDfoDUzA5bTMF5KRIwCRjVV0fUV0UB6Q+csk1v2ZmaQhc9il9K8nbpmSJ+zUvp0YM2C4/oCM1J633rSlzpHUlugG5/vNlqKg72ZGVnLvtilRPcAw9P6cODugvRhaYRNf7Ibsc+mrp45krZN/fGH1jmnNq9vAY+mfv1lcjeOmRmN68YpIq+bgF2AVSVNB84ELgLGSDoSeBPYHyAiJkoaA0wCFgEjI2JxyuoYspE9nYD70wLwB+B6SVPIWvTDllcnB3szM6Cmpuk6OiLi28vYtfsyjj8fOL+e9LHApvWkf0L6siiWg72ZGaxIX3xZyL3PXtLakr6S1jtJ6pJ3mWZmjdUMffYtKtdgL+kosgH/V6ekvsBdeZZpZlYKB/sVMxLYHpgNEBGv8dmDBGZmrUalB/u8++znR8SC2j+cNB60weFBZmYtoVyDeLHyDvaPSzod6CRpD+D7wL05l2lm1miqqexgn3c3zqnAO8AE4HvAX4Gf5lymmVmjuRtnxQwFrouI3+dcjpnZCinXIF6svFv2+wCvSrpe0tdSn72ZWeuT/9w4LSrXYB8RhwPrArcCBwH/lnRNnmWamZXC3TgrKCIWSrqfbBROJ7Kune/mXa6ZWWOUaxAvVq7BXtIQsgl6dgUeA64BDsizTDOzUjTl3DitUd4t+8OAm4HvRcT8nMsyMytdZTfs8w32EbHcaTfNzFoDd+OUQNKTEbGDpDks/cSsgIiIrnmUa2ZWKgf7EkTEDunTM1yaWVmo9GCf96yX1xeTZmbW0lSjopdylPcN2k0KN9JDVVvlXGar9LszD2avnTblnffnsPX+FwCw31e+xE+O/iob9l+dHb/zC8ZPehOA3b68Iecevw/t27VlwcJFnP7ru3j8uVfpvFIH/u+PJy3Js89qq3DzX5/jlF/czlq9u/O7Mw9h1e6d+WD2PI74yWjemvW/lrhUW0GX/+ZS7r3nLmZ/OJunxz6/JP26a//EnbffSpu2bejevQdnn3cBa6zRh39Nnsz5557F3LlzadOmhu+OOIYhe30VgJtuvIEbrx/NtGlv8tiTT9G9e4+WuqxWzy37Ekg6LfXXby5pdlrmAG/z2Qtzq8r19z7N0JG/XSpt4r9nMOzk3/Pk+H8vlf7e/+byrROvZpsDLuCoM67nj+cdCsDcefPZdthFS5Y3Z77PXY++AMCFJ32DG+97loEHXsgFo+7nnOP2aZbrsqa38y67cuPNt34ufcONNuLPY27ntjvvZY89B3PpLy8BoGOnjpx34cXcec99XHn1NVxy0QXMnj0bgAFbbsnVf/gTa6zRp1mvoRxV+kNVuQT7iLgw9ddfEhFd09IlInpGxGl5lNna/WP8v3n/w3lLpb0y9W1ee2PW54598ZXpzHznQwAm/XsmHdq3o327pX+EfXGtXqzWowv/SF8UG67Tm8eeeQWAx597la/vslkel2HNYPMtBtCr1+df+zDwy9vSqVMnADbbYgCz/vtfAPr168/aa/cDYLXVVqdHjx588MH7AGy00cb06dO3eSpe5hzsV0BEnCapu6SBknaqXfIss9J84ysDePGVaSxYuGip9AOGbMVtD41fsj3h1bfYd/cBAAzdbQu6du5Ej24rN2dVrRndefttbL/j5/8pTXjpJRYuWsiaa67VArUqc54bp3SSvgs8ATwInJ0+z2rg+BGSxkoau+jdiXlWrSxstM4XOO/4oRx73s2f27f/4K0Y88DYJdunXXonO261Lk/d9GN23Gpd3nr7AxYtXtyc1bVm8pd772bSxJc57IilZx15551Z/OS0UzjnvAsr/mnQPFR6yz7vG7QnANsAT0fErpI2JAv69YqIUcAogE5fOraq32jVZ7VVuOVXI/juz65n6vR3l9q32fp9aNumDc9PnrYkbeY7HzLsh9kccyt3as++uw9g9txPmrXOlr+nn/on14z6HX+49gbat2+/JH3u3Lkce8z3OPb4E9l8iwEtV8EyVlOmo2yKlffX/ycR8QmApA4R8S9gg5zLLHvdOnfijsuP5ozL7+GpF1//3P4DhizdqgfoucrKS1ocpxwxmNF3P90sdbXmM3nyJM49+wx+c8VV9OzZc0n6wgULOOn4key9z1D2HLxXC9awvFV6yz7vYD9d0irAXcDDku4GZuRcZqs0+sLDeGz0yay/9upMeeBchu+7HfvsujlTHjiXL2/ejzsuO5p7fjsSgKOH7cQX1+zFqUcN4embT+Xpm0+lV/fOS/L65h5bMuaBcUvlv9PW6/HSXT/jpbvOYLWeXbj4mgeb9fqs6Vz6i5+zx2478cknH7PHbjtx1W8vX5I+b948TjnpBA7YbyjHjzwagAcfvJ/x48Zyz113csB+Qzlgv6H8a/JkAG684Tr22G0n3n77v+z/jX0464yftNh1tXZS8Us5UkTz9JZI2hnoBjwQEQuWd3y1d+NY/T547oqWroK1Qh3brvht0w1+/GDRMeeViweXXcjPe4rjwic4JqRPB3Eza3XKtcVerLxv0I4H1gQ+IBuwtAowU9Is4KiIGNfAuWZmzcY3aFfMA8BXI2LViOgJ7AWMAb4PXJlz2WZmRaupUdFLOco72G8dEUvuFEbEQ8BOEfE00CHnss3MilbpN2jz7sZ5X9KPyd5WBXAg8IGkNsCnOZdtZla0ch1SWay8W/YHAX3Jhl7eRdZ/fxDQBr+L1sxakUofZ5/3awnfBY6T1Dki5tbZPSXPss3MGqNMY3jR8p4bZ5CkScCktL2FJN+YNbNWxzdoV8ylwGDgPYCIeBHwrJdm1uo0ZTeOpJMkTZT0sqSbJHWU1EPSw5JeS5/dC44/TdIUSa9IGlyQvpWkCWnfZVqBPqTcp8aLiGl1kjwVo5m1Ok01GkdSH+B4stGIm5LdoxwGnAo8EhHrAY+kbSRtnPZvAgwBrkyDWACuAkYA66VlSKnXl3ewnyZpEBCS2kv6ITA55zLNzBqtiW/QtgU6KXsV60pkc4INBUan/aOBfdP6UODmiJgfEVPJ7mcOlNQb6BoRT0U2r811Bec0Wt7B/mhgJNAHmA4MSNtmZq1KY1r2he/eSMuI2nwi4i3gF8CbwEzgw/SM0eoRMTMdMxOofR1ZH6CwB2R6SquNm3XTS9Ico3EOzrMMM7Om0Jju8MJ3b9STT3ey1np/4H/ArZIOaajo+opoIL0kuQR7SWc0sDsi4tw8yjUzK1UTjrL5CjA1It4BkHQHMAh4W1LviJiZumhqX0A9newZpFp9ybp9pqf1uuklyasb56N6FoAjgR/nVKaZWcmacLqEN4FtJa2URs/sTnav8h5geDpmOHB3Wr8HGCapg6T+ZDdin01dPXMkbZvyObTgnEbLpWUfEb+sXZfUhez1hIeTTZvwy2WdZ2bWUprqydiIeEbSbWSz/i4Cnifr8ukMjJF0JNkXwv7p+ImSxpA9j7QIGBkRtaMWjwGuBToB96elJLn12ae57H9A1mc/GtgyIj7IqzwzsxXRlE/QRsSZwJl1kueTtfLrO/584Px60scCmzZFnfLqs78E2I/s22yzeqZKMDNrVcp1zptiFRXs00MCaxceHxFPNHDKyWTfYj8FflLwh6js1OhaUm3NzHJS9cFe0sVkUxNP4rOnXwNYZrCPiNyfzDUza0rlOudNsYpp2e8LbBAR83Oui5lZi6nwhn1Rwf51oB1Zt4yZWUWq2m4cSZeTddfMA16Q9AgFAT8ijs+/emZmzaPCY32DLfux6XMc2aD/QiU/smtm1hrVVHi0X2awj4jRAJJOiIjfFO6TdELeFTMza06VfoO2mFEzw+tJO6yJ62Fm1qJqVPxSjhrqs/822cvB+0sq7MbpQnrzlJlZpajaG7TAP8nmYl6VpeezmQO8lGelzMyaW4XH+gb77N8A3gC2a77qmJm1DNU7fXzlKOYJ2jl8NvqmPdmY+4885YGZVZJy7Ysv1nKDfUR0KdyWtC8wMK8KmZm1BI/GqSMi7gJ2a/qqmJm1nBqp6KUcFdONs1/BZg2wNX6oyswqTJnG8KIVMzfO3gXri4D/kL1M18ysYlTz0EsktQFeiohLm6k+ZmYtosJjfcN99uk9iPs0U13MzFpMG6nopRwV043zT0lXALcAH9UmRsT43GplZtbMqrYbR9JDEbEnMCglnVOwO/CIHDOrIBU+8rLBln0vgIjYtZnqYmbWYqq2ZQ90qzPscikRcUcO9TEzaxEVHusbDvbA16HeCSMCcLA3s4pRzS37NyLiiGariZlZC2pT4Z32DQX7yr5yM7MClR7wGgr232m2WpiZtbBynfOmWA3NZ/9yc1bEzKwlVXisL+qhKjOzilfNN2jNzKpGhcf6Bp+gnUADUxlHxOa51MjMrAVU82icr6fPkenz+vR5MDAvtxqZmbWAqu3GSS8cR9L2EbF9wa5TJf2DpefKaXITHrwkz+ytTK160LUtXQVrheaOOWyF82j0a/vKTDHXt7KkHWo3JA0CVs6vSmZmzU9S0Us5KibYHwn8VtJ/JP0HuBLwk7VmVlFqVPyyPJJWkXSbpH9JmixpO0k9JD0s6bX02b3g+NMkTZH0iqTBBelbSZqQ9l2mFfimWW6wj4hxEbEFsDmwRUQM8Fz2ZlZp2tSo6KUIvwEeiIgNgS2AycCpwCMRsR7wSNpG0sbAMGATYAhwZXpLIMBVwAhgvbQMKfX6innheAfgm0A/oG3tF0tE5Npnb2bWnJpqMI6krsBOwGEAEbEAWCBpKLBLOmw08BjwY7J3et8cEfOBqZKmAANTT0rXiHgq5XsdsC9wfyn1KqYb5+5UmUVkb6qqXczMKobUmEUjJI0tWEYUZLUO8A7wJ0nPS7pG0srA6hExEyB9rpaO7wNMKzh/ekrrk9brppekmIeq+kZEyT8dzMzKQWPmxomIUcCoZexuC2wJHBcRz0j6DanLZhmWNY38stJLUkzL/p+SNiu1ADOzclDTiGU5pgPTI+KZtH0bWfB/W1JvgPQ5q+D4NQvO7wvMSOl960kvSTHBfgdgXLpL/FK6M/xSqQWambVGjenGaUhE/BeYJmmDlLQ7MAm4Bxie0oaTdZGT0odJ6iCpP9mN2GdTV88cSdumUTiHFpzTaMV04+xVauZmZuWiiadLOA64UVJ74HXgcLLG9RhJRwJvAvsDRMRESWPIvhAWASMjYnHK5xjgWqAT2Y3Zkm7OQnHBvuQ+IjOzctGUsT4iXgC2rmfX7ss4/nzg/HrSxwKbNkWdign29/HZzYKOQH/gFbIxoWZmFaFqX15SKyKWujkraUvge7nVyMysBVR4rG/8fPYRMV7SNnlUxsyspVT4DMdFPUH7g4LNGrIhRO/kViMzsxagCn/leDEt+y4F64vI+vBvz6c6ZmYto22Fz3FcTJ/92QCSumSbMTf3WpmZNbNynbq4WMv9LpO0qaTngZeBiZLGSWqSoUBmZq1FU05x3BoV040zCvhBRPwNQNIuKW1QftUyM2teFd6wLyrYr1wb6AEi4rE0g5uZWcWo+nH2wOuSfsZnLxw/BJiaX5XMzJpfmwq/QVvM5R0B9ALuSMuqZPM8mJlVjBpU9FKOGmzZp1dj3RoRX2mm+piZtYgK78VpONhHxGJJ8yR1i4gPm6tSZmbNrVxH2RSrmD77T4AJkh6m4HWEEXF8brUyM2tmvkGbPTF7X94VMTNrSRUe64t6gnZ0c1TEzKwlNfHLS1qdZY7GkTRU0siC7WckvZ6WbzVP9czMmkcTvoO2VWqo3j8iezdirQ7ANsAuZK/KMjOrGJKKXspRQ9047SNiWsH2kxHxHvCen6A1s0pTniG8eA0F++6FGxFxbMFmr3yqY2bWMip9NE5D3TjPSDqqbqKk7wHP5lclM7Pmp0Ys5aihlv1JwF2SDgLGp7StyPru9825XmZmzaqmwkfjLDPYR8QsYJCk3YBNUvJ9EfFos9TMzKwZlesom2IVM87+UcAB3swqWrmOsilWMU/QmplVvMoO9Q72ZmaAW/ZmZlWhjYO9mVnlq+xQn/MNaEnrS3pE0stpe3NJP82zTDOzUkjFL+Uo79FGvwdOAxYCRMRLwLCcyzQza7Sqfi1hE1gpIp6tc+NjUc5lmpk1Wrm22IuVd7B/V9IXgQBIUyPPzLlMM7NGU5m22IuVd7AfCYwCNpT0FjAVODjnMs3MGs2jcVbMGxHxlTQlck1EzMm5PDOzklR4rM/9Bu1USaOAbYG5OZdlZlYyj8ZZMRsA/0fWnTNV0hWSdsi5TDOzRlMj/isqP6mNpOcl/SVt95D0sKTX0mf3gmNPkzRF0iuSBhekbyVpQtp3mVbgMd9cg31EfBwRYyJiP+BLQFfg8TzLNDMrRY2KX4p0AjC5YPtU4JGIWA94JG0jaWOyIembAEOAKyW1SedcBYwA1kvLkJKvr9QTiyVpZ0lXks2J3xE4IO8yzcwaq0YqelkeSX2BrwHXFCQPBUan9dF89l6QocDNETE/IqYCU4CBknoDXSPiqYgI4DpW4F0iud6glTQVeAEYA5wSER/lWZ6ZWakaM/RS0giyFnetURExqmD718CPgC4FaatHxEyAiJgpabWU3gd4uuC46SltYVqvm16SvEfjbBERs3Muoyx98snHXPizU/jvjOnU1NQwcPudOfzoE5bs//ujD3LjH69Ggv7rrs+PzrwIgL133pK111kXgF6r9+bMi34DwK8vOosp/5pERNBnzbU56fRz6LTSSs1/YbbCvrldP07Zb3Pa1IgHxk/nZzeOA6Bvz5UZNXIHuq3cnjY14ow/j+Oh599acl6XTu0Yd+m+3Pvsm5z8x2cAeOjsvejcqR0Avbp2ZOy/3+Xbl/j1FPVpzIuqUmAfVd8+SV8HZkXEOEm7FJFdfSVHA+klySXYS/pRRPwcOF/S5yoXEcfnUW652e/bw9liy21YuHAhPzlxBGOffpKtt92Bt6a9wZgb/sglV11Lly5d+d8H7y85p32HDlzxpzGfy2vEcT9kpZU7A/D7y3/BvXfczAGHHNFs12JNo0fnDpz3na3Z8cf38u6c+Vw9cgd22bQ3j708kx9/c3PueOo/XPPwK2zYpxu3n7YHmxx725Jzf3bgl3hy0ttL5bfnmfcvWb/x5F34y3PTmu1ayk0TPlS1PbCPpK+SdV13lXQD8Lak3qlV3xuYlY6fDqxZcH5fYEZK71tPekny6rOvvSkxFhhXz1L1OnbsxBZbbgNAu3bt+OL6G/LurOwf6oP33sHXv3EgXbp0BWCV7j2Wm19toI8IFsyfX/Fzc1eqfqt3ZsqM2bw7Zz4Af3tpBkO/vDYAEdBlpayV3nWl9sz8YN6S8wb078lq3TrxyIv1x4LOHduy0ya9+ctzb+Z8BeWrqYZeRsRpEdE3IvqR3Xh9NCIOAe4BhqfDhgN3p/V7gGGSOkjqT3Yj9tnU5TNH0rZpFM6hBec0Wi4t+4i4N63Oi4hbC/dJ2j+PMsvZ3DmzeeYfT7DP/tnDxW9NewOAHx4znE8//ZSDjjiarb+8PQALFizghO8eRJs2bdj/4MPZbqfdluRz6QVnMPbpJ1mz3zoceewPmv9CbIW9/t85rN+nG2v16sxb733E3gPXol3bbGDG+be+wD0/3ZOjh2zESh3asve5DwFZ8Lnw0G347hVPsMuma9Sb794D1+bxl2cy5+OFzXYt5aYZmkcXAWMkHQm8CewPEBETJY0BJpHNHTYyIhanc44BrgU6AfenpSR599mfBtxaRBqw9E2Pcy+5nGGHHplv7VqBxYsW8fOzT2Ofb32b3mtkv9gWL17MjOlvctHl1/DurFn86NjDuXL0bXTu0pVrb7ufnquuxswZ0zn9hKPo98X16N0n+wV40unnsHjxYn7364v4+yMPssfX9m3BK7NS/O+jBZx4zVOMPnFnPo3gmVdm0X/17B7f/tv354bHpnD5XyYycL1eXHPcjmxz8l2M2HNDHnx+Om+9N2+Z+e6/fX9GP/pac11GWcpjuoSIeAx4LK2/B+y+jOPOB86vJ30ssGlT1CWvPvu9gK8CfSRdVrCrKw3Mell402PKrI9LvhFRTi6/5FzW6LsW+x5wyJK0VVdbnQ023oy2bdvxhTX60HfNfsyY/ibrb7QpPVfNbuD3XqMvmw3Ymn+/+q8lwR6gTZs27LTbYG6/abSDfZm6f9x07h+XDcI4fPf1Wfxp9k9h+G7rse8FDwPw7Gvv0KFdG1bt0pGB6/di0Earc9SeG9K5Y1vata1h7ieLOPPPWY9pj84d2GrdVfn2L/7WMhdULiq85zOvPvsZZP31n7B0X/09wOAGzqsq1/3+Cj76aC4jjj9lqfRtd9yVCc8/B8CH//uAt6a/wRfW6MucObNZuGDBkvTJL7/AWv3WISKYMT3ri40InvnnE/Rdu3/zXow1mV5dOwKwysrtOWrwhkta5NPe/WhJN80GfbrRsV0b3pn9CUde/nc2+v5tbHLsbZx+/VhueuLfSwI9wDe268cD46czf+HizxdmSzT1E7StTV599i8CL0q6MSI8f3093p31Nrdcdw191+7P8Udm73PZe79hDN57P7YaOIjnn32Kow/Zj5o2NRxxzEl07bYKkya8wBW/OI8a1fBpfMq3Dj6Ctfp/kU8//ZRfnf8z5s37CCLov+76jDz5Jy18hVaqnx8+kM3Wzm7KX3Tbi0yZmY1ePv2657j8e4M49msbE8D3rnyyqPy+Nag/v7xrQl7VrRiVPqZB2YNZTZypNCYiDpA0gaXHhQqIiNh8eXlUSzeONc6AY29p6SpYKzR3zGErHKqfe/3DomPONut0K7uvhrxu0NY+HfT1nPI3M2taZRe+Gyevbpzat1G9C3wcEZ9KWh/YkBUYOmRmlpdi5rwpZ3lPhPYE0FFSH7JZ3g4nGzNqZtaqqBFLOco72Csi5gH7AZdHxDeAjXMu08ys8So82uce7CVtR/be2ftSWt4PcpmZNZqHXq6YE8memL0zPRK8DuAnO8ys1anwLvt8g31EPA48LqmLpM4R8TrgGS/NrNWp9GCfazeOpM0kPQ+8DEySNE7SJnmWaWZWCnfjrJirgR9ExN8A0kT+vwcG5VyumVmjVHrLPu9gv3JtoIdsBjhJK+dcpplZo1V4rM892L8u6WfA9Wn7EGBqzmWamTVehUf7vIdeHgH0Au5Iy6pkD1aZmbUq7rMvgaSOwNHAusAE4OSI8CtyzKzVaswLx8tRXt04o4GFwN+BvYCNyMbcm5m1Tg72Jdk4IjYDkPQH4NmcyjEzaxLl2j1TrLyC/ZIum4hYpEof02RmZa/Sw1RewX4LSbPTuoBOabv25SVdcyrXzKwkFR7rc5vPvk0e+ZqZ5abCo71noDQzo/JfXuJgb2ZGxTfsHezNzICKj/YO9mZmeOilmVlVqPAuewd7MzNwsDczqwruxjEzqwJu2ZuZVYEKj/UO9mZm4Ja9mVmVqOxo72BvZkblv7wk79cSmpmVBan4peF8tKakv0maLGmipBNSeg9JD0t6LX12LzjnNElTJL0iaXBB+laSJqR9l2kF5ot3sDczo0nfQbuI7FWsGwHbAiMlbQycCjwSEesBj6Rt0r5hwCbAEOBKSbUzB18FjADWS8uQUq/Pwd7MDLIu+2KXBkTEzIgYn9bnAJOBPsBQsle2kj73TetDgZsjYn5ETAWmAAMl9Qa6RsRTERHAdQXnNJqDvZkZjYv1kkZIGluwjKg3T6kf8CXgGWD1iJgJ2RcCsFo6rA8wreC06SmtT1qvm14S36A1M6NxQy8jYhQwquH81Bm4HTgxImY30N1e345oIL0kDvZmZkBTvitbUjuyQH9jRNyRkt+W1DsiZqYumlkpfTqwZsHpfYEZKb1vPeklcTeOmRlN1mVPGjHzB2ByRPyqYNc9wPC0Phy4uyB9mKQOkvqT3Yh9NnX1zJG0bcrz0IJzGs0tezMzmvQJ2u2B7wATJL2Q0k4HLgLGSDoSeBPYHyAiJkoaA0wiG8kzMiIWp/OOAa4FOgH3p6UkDvZmZjTdrJcR8STL/gGw+zLOOR84v570scCmTVEvB3szMzw3jplZVXCwNzOrAn55iZlZFXDL3sysClR4rHewNzMDKj7aO9ibmeE+ezOzqlDpLy9xsDczA3fjmJlVA3fjmJlVgUofeqnsBSjWmkkakebPNlvCfy+sMTzFcXmo9y04VvX898KK5mBvZlYFHOzNzKqAg315cL+s1cd/L6xovkFrZlYF3LI3M6sCDvZmZlXAwb6JSQpJvyzY/qGks3Io5/Q62/9s6jIsH5IWS3pB0suSbpW0UiPPX0PSbWl9gKSvFuzbR9KpTV1nK38O9k1vPrCfpFVzLmepYB8Rg3Iuz5rOxxExICI2BRYARzfm5IiYERHfSpsDgK8W7LsnIi5qsppaxXCwb3qLyEZJnFR3h6Rekm6X9Fxati9If1jSeElXS3qj9stC0l2SxkmaKGlESrsI6JRahzemtLnp85Y6Lb1rJX1TUhtJl6RyX5L0vdz/JKwYfwfWldQj/b9+SdLTkjYHkLRz+v/8gqTnJXWR1C/9KmgPnAMcmPYfKOkwSVdI6ibpP5JqUj4rSZomqZ2kL0p6IP29+rukDVvw+q25RISXJlyAuUBX4D9AN+CHwFlp35+BHdL6WsDktH4FcFpaHwIEsGra7pE+OwEvAz1ry6lbbvr8BjA6rbcHpqVzRwA/TekdgLFA/5b+86rGpeD/VVvgbuAY4HLgzJS+G/BCWr8X2D6td07n9ANeTmmHAVcU5L1kO+W9a1o/ELgmrT8CrJfWvww82tJ/Jl7yXzwRWg4iYrak64DjgY8Ldn0F2FifzbjUVVIXYAeyIE1EPCDpg4Jzjpf0jbS+JrAe8F4Dxd8PXCapA9kXxxMR8bGkPYHNJdX+/O+W8ppa6nVayTpJeiGt/x34A/AM8E2AiHhUUk9J3YB/AL9Kv+DuiIjpKn7GrlvIgvzfgGHAlZI6A4OAWwvy6bDil2StnYN9fn4NjAf+VJBWA2wXEYVfAGgZ/3ol7UL2BbFdRMyT9BjQsaFCI+KTdNxgsn/oN9VmBxwXEQ828jqs6X0cEQMKE5bxdyAi4iJJ95H1yz8t6SvAJ0WWcw9woaQewFbAo8DKwP/qlm+Vz332OYmI94ExwJEFyQ8Bx9ZuSBqQVp8EDkhpewLdU3o34IMU6DcEti3Ia6Gkdsso/mbgcGBHoDa4PwgcU3uOpPUlrVza1VkOngAOhiVf8u+mX4hfjIgJEXExWddb3f71OUCX+jKMiLnAs8BvgL9ExOKImA1MlbR/KkuStsjjgqx1cbDP1y+BwlE5xwNbp5twk/hsFMbZwJ6SxgN7ATPJ/hE/ALSV9BJwLvB0QV6jgJdqb9DW8RCwE/B/EbEgpV0DTALGS3oZuBr/smtNziL93QAuAoan9BPTzdgXyboE769z3t/IugZfkHRgPfneAhySPmsdDByZ8pwIDG26y7DWytMltAKpf31xRCyStB1wlX9mm1lTcsuudVgLGJOGyS0Ajmrh+phZhXHL3sysCrjP3sysCjjYm5lVAQd7M7Mq4GBvTWpFZ3Ssk9e1tU/8SrpG0sYNHLuLpEEF20dLOrTUss0qjYO9NbUGZ3SU1KaUTCPiuxExqYFDdiGbBqD2+N9FxHWllGVWiRzsLU+1MzruIulvkv4MTFjWDJzpac4rJE1KUwSsVpuRpMckbZ3WhyibIfRFSY9I6kf2pXJS+lWxo6SzJP0wHT8gzST5kqQ7JXUvyPNiSc9KelXSjs37x2PWfDzO3nIhqS3Z08APpKSBwKYRMVXZVM0fRsQ26YGyf0h6CPgSsAGwGbA62RO/f6yTby/g98BOKa8eEfG+pN+RzSb5i3Tc7gWnXUc2L9Djks4BzgROTPvaRsRAZdNCn0k2F5FZxXGwt6ZW34yOg4BnI6J2hs1lzcC5E3BTRCwGZkh6tJ78tyWbyXMqLJmDaJnSzJGrRMTjKWk0cGvBIXekz3FkUwebVSQHe2tq9c3oCPBRYRL1zMCZWtfLe8pPRRzTGPPT52L878EqmPvsrSUsawbOJ4BhqU+/N7BrPec+BewsqX86t0dKr3f2x4j4EPigoD/+O8DjdY8zq3RuyVhLuIasy2R8msf9HWBf4E6ytzRNAF6lnqAcEe+kPv870lxCs4A9yN7odJukocBxdU4bDvwuDQN9nWz6Z7Oq4rlxzMyqgLtxzMyqgIO9mVkVcLA3M6sCDvZmZlXAwd7MrAo42JuZVQEHezOzKvD/IUnG2NUYJj4AAAAASUVORK5CYII=\n",
|
|
"text/plain": [
|
|
"<Figure size 432x288 with 2 Axes>"
|
|
]
|
|
},
|
|
"metadata": {
|
|
"needs_background": "light"
|
|
},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
" precision recall f1-score support\n",
|
|
"\n",
|
|
" Negative 0.81 0.90 0.85 12500\n",
|
|
" Positive 0.89 0.79 0.84 12500\n",
|
|
"\n",
|
|
" accuracy 0.85 25000\n",
|
|
" macro avg 0.85 0.85 0.84 25000\n",
|
|
"weighted avg 0.85 0.85 0.84 25000\n",
|
|
"\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"import seaborn\n",
|
|
"import matplotlib.pyplot as plt\n",
|
|
"from tensorflow.keras.models import load_model\n",
|
|
"from sklearn.metrics import classification_report, confusion_matrix, plot_confusion_matrix\n",
|
|
"\n",
|
|
"\n",
|
|
"# 5) Create testing set\n",
|
|
"test_x = pos_x + neg_x\n",
|
|
"test_y = pos_y + neg_y\n",
|
|
"test_x = np.array(test_x)\n",
|
|
"test_y = np.array(test_y)\n",
|
|
"\n",
|
|
"print(f'test_x shape: {test_x.shape}')\n",
|
|
"print(f'test_y shape: {test_y.shape}')\n",
|
|
"\n",
|
|
"# 6) Load pre-trained model and predict testing reviews\n",
|
|
"model = load_model('imdb_cls_model.h5')\n",
|
|
"test_pred = [True if pred >= 0.5 else False for pred in model.predict(test_x)]\n",
|
|
"test_pred = np.array(test_pred)\n",
|
|
"\n",
|
|
"print(f'test_pred shape: {test_pred.shape}')\n",
|
|
"\n",
|
|
"# 7) Draw confusion matrix & calculate classification report\n",
|
|
"labels = ['Negative', 'Positive']\n",
|
|
"\n",
|
|
"conf_matrix = confusion_matrix(test_y, test_pred)\n",
|
|
"ax = seaborn.heatmap(conf_matrix, annot=True, fmt='5', cmap='Blues')\n",
|
|
"ax.set_title('Confusion Matrix\\n');\n",
|
|
"ax.set_xlabel('Prediction')\n",
|
|
"ax.set_ylabel('Ground Truth');\n",
|
|
"ax.xaxis.set_ticklabels(labels)\n",
|
|
"ax.yaxis.set_ticklabels(labels)\n",
|
|
"plt.show()\n",
|
|
"\n",
|
|
"print(classification_report(test_y, test_pred, target_names=labels))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "fcfbdf6b",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3 (ipykernel)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.9.0"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|