{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "name": "ok_toxic_classification_and_vectors.ipynb", "provenance": [], "collapsed_sections": [] }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" } }, "cells": [ { "cell_type": "markdown", "metadata": { "id": "F1K-KbJRElOa" }, "source": [ "# Toxic comment classification with logistic regression" ] }, { "cell_type": "markdown", "metadata": { "id": "0HusYfbch0qt" }, "source": [ "https://cups.mail.ru/ru/tasks/1048" ] }, { "cell_type": "code", "metadata": { "id": "7TwDSJ_N-cuF", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "683b057f-2fe6-4a4e-c92e-35975fd0e80f" }, "source": [ "!wget https://cups.mail.ru/api_v2/task/1048/27 -O train.txt" ], "execution_count": null, "outputs": [ { "output_type": "stream", "text": [ "--2021-07-23 06:57:20-- https://cups.mail.ru/api_v2/task/1048/27\n", "Resolving cups.mail.ru (cups.mail.ru)... 188.93.56.95\n", "Connecting to cups.mail.ru (cups.mail.ru)|188.93.56.95|:443... connected.\n", "HTTP request sent, awaiting response... 200 OK\n", "Length: 24525472 (23M) [text/plain]\n", "Saving to: ‘train.txt’\n", "\n", "train.txt 100%[===================>] 23.39M 10.2MB/s in 2.3s \n", "\n", "2021-07-23 06:57:24 (10.2 MB/s) - ‘train.txt’ saved [24525472/24525472]\n", "\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "4secCUaUAHpf", "outputId": "7b589b40-571b-4cc8-b966-9c7b0250e25b" }, "source": [ "! head train.txt" ], "execution_count": null, "outputs": [ { "output_type": "stream", "text": [ "41127\t__label__THREAT\tдворника надо тоже уничтожить!\n", "6812\t__label__NORMAL\tмоя старшая неделю шипела, не принимала подкидыша, которого я принесла. китя такой славный, потерпите немного, мои теперь душа в душу живут, ваша примет попозже. 
а котенок необыкновенно красивый.\n", "6256\t__label__NORMAL\tполностью с вами согласна!\n", "189636\t__label__NORMAL\tхоть ногу вверх, ничего не изменится\n", "99053\t__label__NORMAL\tа что значит - левого ребенка?\n", "98418\t__label__NORMAL\tвечер добрый! а, что он у вас уже постарел?!\n", "3619\t__label__NORMAL\tкакая порода .?\n", "176463\t__label__INSULT\tспасатель? просто петух чванливый, взял наших пацанов в заложники и торгуется. позор! как бы сам в щи не попал.\n", "13520\t__label__NORMAL\tс замечательным юбилеем!!! голос- чудо, заслушаешься 😍и про чтение все правильно- ничто не сравнится с бумажной книгой!\n", "133613\t__label__NORMAL\tеще бы .такой красавец.\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "id": "j47IyDHK_1VU" }, "source": [ "import pandas as pd" ], "execution_count": null, "outputs": [] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "hNqzxwnM_q7L", "outputId": "a237b827-baee-4d1f-c338-329ceae02e78" }, "source": [ "# Each line is: id <TAB> one or more labels <TAB> text\n", "data = []\n", "with open('train.txt', 'r', encoding='utf-8') as f:\n", "    for line in f:\n", "        parts = line.split('\\t')\n", "        data.append({\n", "            'id': parts[0],\n", "            'text': parts[-1].strip(),\n", "            'labels': parts[1:-1],\n", "            # binary target: 1 if the first label is anything other than NORMAL\n", "            'toxic': int(parts[1] != '__label__NORMAL')\n", "        })\n", "\n", "data = pd.DataFrame(data)\n", "data.toxic.value_counts()" ], "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "0 122194\n", "1 26581\n", "Name: toxic, dtype: int64" ] }, "metadata": { "tags": [] }, "execution_count": 4 } ] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "jFtHkrn6Agr2", "outputId": "09ca01fe-3ef3-4063-cbe0-88530219925d" }, "source": [ "data.labels.apply(tuple).value_counts()" ], "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "(__label__NORMAL,) 122194\n", "(__label__INSULT,) 17007\n", "(__label__INSULT, 
__label__THREAT) 3747\n", "(__label__THREAT,) 3263\n", "(__label__OBSCENITY,) 1323\n", "(__label__INSULT, __label__OBSCENITY) 1087\n", "(__label__INSULT, __label__OBSCENITY, __label__THREAT) 111\n", "(__label__OBSCENITY, __label__THREAT) 43\n", "Name: labels, dtype: int64" ] }, "metadata": { "tags": [] }, "execution_count": 5 } ] }, { "cell_type": "code", "metadata": { "id": "b362vEgRBWad" }, "source": [ "pd.options.display.max_colwidth = 200" ], "execution_count": null, "outputs": [] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 362 }, "id": "sJpev-rS-7oh", "outputId": "5e2cee23-ca95-4d40-8d11-8049443f5385" }, "source": [ "data.sample(10)" ], "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/html": [ "
<table>\n",
"<thead>\n",
"<tr><th></th><th>id</th><th>text</th><th>labels</th><th>toxic</th></tr>\n",
"</thead>\n",
"<tbody>\n",
"<tr><th>77892</th><td>133009</td><td>ты чурка тупа, выучи русский язык отбросы, набор слов написал какой-то... псина.</td><td>[__label__INSULT]</td><td>1</td></tr>\n",
"<tr><th>54060</th><td>175778</td><td>здравствуйте, подскажите пожалуйста сколько стоит детский билет,ребенку 7 лет и сколько стоит багаж?</td><td>[__label__NORMAL]</td><td>0</td></tr>\n",
"<tr><th>93254</th><td>108185</td><td>да, жаль то время прошло когда ценилась дружба, родство</td><td>[__label__NORMAL]</td><td>0</td></tr>\n",
"<tr><th>131894</th><td>211922</td><td>замораживать!!!</td><td>[__label__NORMAL]</td><td>0</td></tr>\n",
"<tr><th>56550</th><td>171953</td><td>дейзики есть в наличии?</td><td>[__label__NORMAL]</td><td>0</td></tr>\n",
"<tr><th>39244</th><td>154430</td><td>аскитесь, патриот, практика показывает, что сша, вызывает эпидемии, выращивает террористов , и постоянно ставит мир на грань выживания.</td><td>[__label__NORMAL]</td><td>0</td></tr>\n",
"<tr><th>94700</th><td>198586</td><td>я жиу в испании и когда они спелые и идут дожди ,видно от сырости они взрываютса..здесь они очень сладкие но не такие красные ...видно разные сорта есть..</td><td>[__label__NORMAL]</td><td>0</td></tr>\n",
"<tr><th>23112</th><td>55579</td><td>надо всех чиновников посадить на эту пенсию и сказать..им ни в чем себе не отказывай... сучара</td><td>[__label__INSULT]</td><td>1</td></tr>\n",
"<tr><th>82710</th><td>236179</td><td>а там дальше зависит от типа пандемии</td><td>[__label__NORMAL]</td><td>0</td></tr>\n",
"<tr><th>88968</th><td>110340</td><td>а чем лучше предатель власов -такая же -------- как и бандера - так чо пидарья везде хватает</td><td>[__label__INSULT]</td><td>1</td></tr>\n",
"</tbody>\n",
"</table>\n", "
" ], "text/plain": [ " id ... toxic\n", "77892 133009 ... 1\n", "54060 175778 ... 0\n", "93254 108185 ... 0\n", "131894 211922 ... 0\n", "56550 171953 ... 0\n", "39244 154430 ... 0\n", "94700 198586 ... 0\n", "23112 55579 ... 1\n", "82710 236179 ... 0\n", "88968 110340 ... 1\n", "\n", "[10 rows x 4 columns]" ] }, "metadata": { "tags": [] }, "execution_count": 7 } ] }, { "cell_type": "code", "metadata": { "id": "Eepq7FhhAzsn" }, "source": [ "from sklearn.model_selection import train_test_split, cross_val_score" ], "execution_count": null, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "ELK_nJgZA8PJ" }, "source": [ "X_train, X_test, y_train, y_test = train_test_split(data.text, data.toxic, test_size=0.2, random_state=1)" ], "execution_count": null, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "AGzib3JQBDxY" }, "source": [ "from sklearn.linear_model import LogisticRegression\n", "from sklearn.feature_extraction.text import CountVectorizer, HashingVectorizer\n", "from sklearn.pipeline import make_pipeline" ], "execution_count": null, "outputs": [] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "p2c6TnwyBoNV", "outputId": "0446a037-233f-40e1-99ca-c621ff32b431" }, "source": [ "# Baseline: word unigrams hashed into 10,000 features\n", "pipe = make_pipeline(HashingVectorizer(n_features=10_000, ngram_range=(1, 1)), LogisticRegression(max_iter=1000))\n", "print(cross_val_score(pipe, X_train, y_train, scoring='roc_auc', cv=3))" ], "execution_count": null, "outputs": [ { "output_type": "stream", "text": [ "[0.9364279 0.93809144 0.93862685]\n" ], "name": "stdout" } ] }, { "cell_type": "markdown", "metadata": { "id": "JRllkylKLpAl" }, "source": [ "For Russian text it often makes sense to use character n-grams rather than whole words: the language is highly inflected, so word forms that share a stem also share many n-grams. Here we use n-grams of 2 to 6 characters.
" ] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "tHmq0RBdLJrd", "outputId": "b345e844-6907-4630-9fa2-d6b958f6d043" }, "source": [ "# Character n-grams (2 to 6 characters, within word boundaries) hashed into 10,000 features\n", "pipe = make_pipeline(HashingVectorizer(n_features=10_000, ngram_range=(2, 6), analyzer='char_wb'), LogisticRegression(max_iter=1000))\n", "print(cross_val_score(pipe, X_train, y_train, scoring='roc_auc', cv=3))" ], "execution_count": null, "outputs": [ { "output_type": "stream", "text": [ "[0.97930706 0.97965476 0.97953644]\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "b92p1l1gj30L", "outputId": "f6c9c8f6-5834-42dc-b1b2-adbd3e7a4d87" }, "source": [ "# Word unigrams with only 100 hash buckets, so many words collide\n", "pipe = make_pipeline(HashingVectorizer(n_features=100, ngram_range=(1, 1)), LogisticRegression(max_iter=1000))\n", "print(cross_val_score(pipe, X_train, y_train, scoring='roc_auc', cv=3))" ], "execution_count": null, "outputs": [ { "output_type": "stream", "text": [ "[0.689192 0.68744058 0.69532014]\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "U0m1Qd9Tj46Z", "outputId": "bcab42b2-3e5c-402f-cacc-71e94e66936b" }, "source": [ "# Word unigrams with 100,000 hash buckets\n", "pipe = make_pipeline(HashingVectorizer(n_features=100_000, ngram_range=(1, 1)), LogisticRegression(max_iter=1000))\n", "print(cross_val_score(pipe, X_train, y_train, scoring='roc_auc', cv=3))" ], "execution_count": null, "outputs": [ { "output_type": "stream", "text": [ "[0.95696636 0.9590711 0.95921095]\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "GKBT8V3Wj8W5", "outputId": "9bd25641-9554-40bc-996f-870a0d999784" }, "source": [ "# Word unigrams and bigrams with 100,000 hash buckets\n", "pipe = make_pipeline(HashingVectorizer(n_features=100_000, ngram_range=(1, 2)), LogisticRegression(max_iter=1000))\n", "print(cross_val_score(pipe, X_train, y_train, scoring='roc_auc', cv=3))" ], "execution_count": null, "outputs": [ { "output_type": "stream", 
"text": [ "[0.94627584 0.94915932 0.94938827]\n" ], "name": "stdout" } ] }, { "cell_type": "markdown", "metadata": { "id": "nAi0T8DBEp_q" }, "source": [ "# Example of using pretrained embeddings\n", "\n", "I take small embeddings from my own library, https://github.com/avidale/compress-fasttext" ] }, { "cell_type": "code", "metadata": { "id": "ENJB7Cf-kBLN", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "a7385646-dc70-4884-86cd-e5890a80302a" }, "source": [ "!pip install compress-fasttext gensim==3.8.3" ], "execution_count": null, "outputs": [ { "output_type": "stream", "text": [ "Collecting compress-fasttext\n", " Downloading compress-fasttext-0.0.7.tar.gz (9.5 kB)\n", "Collecting gensim==3.8.3\n", " Downloading gensim-3.8.3-cp37-cp37m-manylinux1_x86_64.whl (24.2 MB)\n", "\u001b[K |████████████████████████████████| 24.2 MB 71.3 MB/s \n", "\u001b[?25hRequirement already satisfied: six>=1.5.0 in /usr/local/lib/python3.7/dist-packages (from gensim==3.8.3) (1.15.0)\n", "Requirement already satisfied: numpy>=1.11.3 in /usr/local/lib/python3.7/dist-packages (from gensim==3.8.3) (1.19.5)\n", "Requirement already satisfied: smart-open>=1.8.1 in /usr/local/lib/python3.7/dist-packages (from gensim==3.8.3) (5.1.0)\n", "Requirement already satisfied: scipy>=0.18.1 in /usr/local/lib/python3.7/dist-packages (from gensim==3.8.3) (1.4.1)\n", "Building wheels for collected packages: compress-fasttext\n", " Building wheel for compress-fasttext (setup.py) ... 
\u001b[?25l\u001b[?25hdone\n", " Created wheel for compress-fasttext: filename=compress_fasttext-0.0.7-py3-none-any.whl size=11079 sha256=ac3deb1a49700d452aae769dcce7615be933e9839be69a460ecb387f7fdbcd25\n", " Stored in directory: /root/.cache/pip/wheels/fa/cf/43/579ed0c5dc7f41928de0cbd42d06c0ffbd8731d48ba0ac2587\n", "Successfully built compress-fasttext\n", "Installing collected packages: gensim, compress-fasttext\n", " Attempting uninstall: gensim\n", " Found existing installation: gensim 3.6.0\n", " Uninstalling gensim-3.6.0:\n", " Successfully uninstalled gensim-3.6.0\n", "Successfully installed compress-fasttext-0.0.7 gensim-3.8.3\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "9WZ8LpEYocmx", "outputId": "9e554f63-391a-4912-b6ea-3cb44642dff2" }, "source": [ "import compress_fasttext\n", "small_model = compress_fasttext.models.CompressedFastTextKeyedVectors.load(\n", " 'https://github.com/avidale/compress-fasttext/releases/download/v0.0.1/ft_freqprune_100K_20K_pq_300.bin'\n", ")\n", "print(small_model['спасибо'])" ], "execution_count": null, "outputs": [ { "output_type": "stream", "text": [ "[ 1.61615876e-01 -3.80952631e-01 2.26632313e-01 3.03909637e-01\n", " 3.20562876e-01 -3.68155590e-02 -1.02513266e-01 8.39554454e-02\n", " -2.09805377e-01 -3.92980986e-01 -3.64192873e-01 3.67168336e-01\n", " 2.60246897e-01 -4.17317844e-01 6.38473234e-02 5.55262231e-01\n", " -1.13023188e-01 5.28435459e-02 1.88473114e-02 4.31614519e-01\n", " -6.51933124e-01 3.59791468e-01 -2.03108434e-02 2.74519134e-01\n", " 2.45121267e-01 2.16757885e-01 -1.51260239e-01 -3.88337434e-02\n", " 1.11820161e-01 2.55094536e-02 1.17422326e-01 2.61721188e-01\n", " -1.64664370e-01 2.03159175e-01 -1.61370660e-01 -2.81119821e-01\n", " 1.69485878e-01 1.65399749e-01 1.76555266e-01 -1.21750807e-01\n", " -3.54392871e-01 -2.26254042e-01 2.42314721e-01 -2.91829224e-01\n", " 3.33873943e-01 -1.68621924e-01 -5.94392285e-02 
2.98805771e-01\n", " -1.96432182e-01 6.00420549e-02 1.39643177e-02 1.42371537e-01\n", " -5.37020105e-02 2.06195908e-01 -2.34923034e-01 -2.52630994e-01\n", " 8.11456894e-02 7.75969016e-02 1.66950779e-01 -9.49587288e-02\n", " 1.74773048e-01 -6.01128004e-02 2.43300735e-01 -8.05915690e-02\n", " 3.69941743e-03 1.65421759e-01 2.10352424e-02 -3.61011013e-01\n", " -1.63539051e-01 2.95837237e-01 1.37665761e-01 2.20584220e-01\n", " -4.04024511e-01 2.32131405e-01 1.53377918e-02 -1.45389737e-01\n", " 3.08798710e-02 6.34773358e-02 -5.58610774e-01 -1.48303868e-01\n", " 1.99646425e-01 -3.85183300e-01 3.40566656e-02 1.59161993e-01\n", " 3.10103464e-01 -3.52403348e-01 2.05617299e-01 -1.27653877e-02\n", " 2.65141530e-02 -2.30114791e-01 2.31600030e-02 -1.19834414e-01\n", " -2.17303699e-01 2.38900819e-02 -1.46927439e-01 3.75952777e-01\n", " -4.83471881e-01 -1.81697345e-02 -1.64596374e-01 2.14438204e-02\n", " -6.00285296e-01 -2.45735006e-02 -3.32263037e-01 8.40615085e-02\n", " -6.83442608e-02 1.27358968e-01 8.26212340e-03 -9.75221280e-02\n", " 9.80308819e-02 -2.08683408e-01 7.68353109e-02 6.75428913e-02\n", " -2.48139093e-01 5.34002016e-02 1.72544440e-01 1.69895828e-01\n", " -2.25810242e-01 1.69829409e-01 -6.33449239e-02 4.04133484e-01\n", " 1.35162277e-01 -3.32322920e-01 3.23809234e-01 3.99056974e-01\n", " -1.04846910e-01 -1.07610342e-01 3.18698091e-01 -1.65196609e-01\n", " 6.53688822e-03 1.23802902e-01 8.63148047e-02 2.46591371e-01\n", " -2.97026349e-01 -1.11622416e-02 -3.60517195e-01 2.91040565e-01\n", " -3.32069907e-01 -1.62152280e-02 -8.64313786e-02 2.31732699e-01\n", " -6.46613565e-01 1.59538959e-01 1.87982511e-01 1.64226625e-01\n", " -5.48643269e-01 -4.18351521e-01 1.59706698e-01 -6.79623822e-02\n", " -2.13029950e-01 4.16092811e-01 -1.27279411e-01 3.79515352e-01\n", " -4.10310020e-02 -3.07852644e-01 -2.15173804e-01 2.94661121e-01\n", " -3.57905777e-01 -2.13862730e-01 -2.74208153e-02 1.15006933e-01\n", " 1.95981823e-01 1.32423369e-01 2.39542574e-01 1.03746984e-01\n", " 
-2.23761173e-01 2.90788008e-01 1.14580697e-01 -1.06723550e-01\n", " -3.48427506e-02 -8.45187167e-02 2.28464678e-01 -1.68490110e-02\n", " 1.10741577e-01 8.59757568e-02 8.19669686e-02 -5.52344220e-02\n", " 1.15278565e-01 1.76453390e-01 1.23825877e-01 -1.67114894e-01\n", " -9.07890893e-02 -1.06378453e-01 -2.09171552e-02 -1.28674250e-01\n", " 3.38531753e-02 5.57877143e-02 -9.55982478e-03 3.30195071e-02\n", " -6.01286488e-02 3.48880944e-01 2.35897633e-01 1.38901786e-01\n", " 5.32211441e-02 -3.29748844e-01 -5.97986143e-01 -3.89263122e-02\n", " -2.78548556e-01 2.49484720e-01 -2.16167704e-01 8.67556795e-02\n", " 1.57360859e-01 -9.23949692e-02 -1.65816482e-01 3.87052268e-01\n", " 1.81331944e-01 2.18849768e-01 1.94413311e-01 2.61520564e-01\n", " -1.02200094e-01 -1.20663568e-01 -2.06318275e-02 1.32427033e-02\n", " -2.88134728e-02 -3.30335017e-01 -7.02265696e-02 -3.50569419e-01\n", " 5.10748838e-02 1.43469746e-01 -1.40340648e-01 2.50429845e-02\n", " -2.15370538e-01 2.07442103e-01 1.41600342e-02 -2.24800130e-01\n", " 7.26855989e-03 -2.18781528e-01 1.62791362e-01 -3.34501333e-01\n", " 2.73570884e-01 -6.10981763e-02 1.39826441e-01 -3.37187974e-01\n", " 6.23038200e-02 1.45910728e-01 -2.58549746e-01 -2.42065024e-01\n", " -4.94282944e-02 3.44084496e-01 7.65113527e-02 1.99977004e-01\n", " 3.90961861e-01 1.33362317e-01 2.69634888e-01 2.77443510e-01\n", " -9.49617548e-03 9.28371678e-02 9.14782829e-03 -1.23072508e-01\n", " -2.86175988e-01 -1.01785367e-01 7.23490352e-02 4.26618644e-02\n", " 2.13843149e-01 -1.68306838e-01 4.68371288e-02 -1.08101068e-01\n", " 3.18981584e-01 -5.06685720e-02 -2.29462277e-01 -2.97431647e-01\n", " -4.24168585e-02 -1.71971227e-01 1.99816109e-01 -2.15944355e-02\n", " -7.56473830e-02 1.05135867e-01 -3.01282371e-03 -2.63578675e-01\n", " 7.59607964e-02 3.21395825e-01 -6.18542818e-02 -2.78966223e-01\n", " 7.14540558e-02 -8.25641939e-02 2.49445574e-01 -1.37787532e-01\n", " -1.96797427e-02 2.24352000e-01 4.18048341e-02 1.68571320e-01\n", " -1.54399159e-01 
1.40424669e-01 -1.62490213e-01 -5.86164415e-02\n", " 9.00153553e-02 -1.62794539e-01 -1.20653493e-01 5.66397051e-01\n", " 4.50488658e-02 2.40377445e-01 -3.89976876e-02 -1.79316719e-04\n", " 7.61616827e-02 7.85599382e-02 -1.82465750e-01 2.40700235e-01\n", " -2.18711806e-01 -1.18356974e-02 2.24259246e-02 4.89963293e-01]\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "HB9XDFCDPEwD", "outputId": "8429e080-e74b-45a3-f2cc-ba137d41d399" }, "source": [ "print(small_model['спасибо'].shape)" ], "execution_count": null, "outputs": [ { "output_type": "stream", "text": [ "(300,)\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "ZXvTzvdjokzV", "outputId": "232b5ade-ee98-4ae7-9d40-ec5f479ff90c" }, "source": [ "# Squared Euclidean distance between 'пес' (dog) and 'кол' (stake), two unrelated words\n", "sum((small_model['пес'] - small_model['кол'])**2)" ], "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "19.229890369031384" ] }, "metadata": { "tags": [] }, "execution_count": 19 } ] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "uP95J-lro32A", "outputId": "a538469e-c605-4444-cf30-0eceea18a8a2" }, "source": [ "# 'пес' (dog) vs 'кот' (cat): related words, smaller distance\n", "sum((small_model['пес'] - small_model['кот'])**2)" ], "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "11.621345470343304" ] }, "metadata": { "tags": [] }, "execution_count": 20 } ] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "8ewOVa1rpEtB", "outputId": "44294e7c-32b2-440a-9fe9-96ac8a12823d" }, "source": [ "# 'кошка' (female cat) vs 'кот' (cat): closest pair of the three\n", "sum((small_model['кошка'] - small_model['кот'])**2)" ], "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "8.040876949916234" ] }, "metadata": { "tags": [] }, "execution_count": 21 } ] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": 
"https://localhost:8080/" }, "id": "xT32Sb4wL4-L", "outputId": "219198f7-6b09-4fc4-b613-493c8727d301" }, "source": [ "small_model.most_similar('кот')" ], "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "[('кошка', 0.6625277996063232),\n", " ('котенок', 0.6041390299797058),\n", " ('щенок', 0.4767090380191803),\n", " ('пес', 0.47578155994415283),\n", " ('кобель', 0.4608815312385559),\n", " ('собака', 0.45937010645866394),\n", " ('крыса', 0.45367270708084106),\n", " ('котик', 0.44258466362953186),\n", " ('собачонка', 0.44200772047042847),\n", " ('тимка', 0.4289957880973816)]" ] }, "metadata": { "tags": [] }, "execution_count": 22 } ] }, { "cell_type": "markdown", "metadata": { "id": "xBx5wg9_Mg0H" }, "source": [ "Cat (кошка) - to meow (мяукать) + to bark (лаять) = ???" ] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "pgW4_H1RMazx", "outputId": "d2c6ef2e-f6cd-4019-ebc4-75e4b3daa3c9" }, "source": [ "small_model.most_similar(positive=['кошка', 'лаять'], negative=['мяукать'])" ], "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "[('собака', 0.5540856719017029),\n", " ('пес', 0.5085850358009338),\n", " ('овчарка', 0.4742797017097473),\n", " ('волк', 0.4702416658401489),\n", " ('кот', 0.46600908041000366),\n", " ('собачонка', 0.45119690895080566),\n", " ('щенок', 0.45089417695999146),\n", " ('лисица', 0.43611547350883484),\n", " ('собачка', 0.4224735498428345),\n", " ('бездомный', 0.4224018156528473)]" ] }, "metadata": { "tags": [] }, "execution_count": 23 } ] }, { "cell_type": "markdown", "metadata": { "id": "OG9jcQwINNsM" }, "source": [ "What is good (хорошо) for a Russian (русский) is, for a German (немец), ...?
] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "pBF2mKHvNCII", "outputId": "8c8b6e6a-7737-4300-c48e-68744bed74e5" }, "source": [ "small_model.most_similar(positive=['немец', 'хорошо'], negative=['русский'])" ], "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "[('нехорошо', 0.45273256301879883),\n", " ('плохо', 0.45050546526908875),\n", " ('хорошенько', 0.43955713510513306),\n", " ('неплохо', 0.4276043176651001),\n", " ('хороший', 0.40715932846069336),\n", " ('по-хорошему', 0.3803216814994812),\n", " ('турок', 0.3786393105983734),\n", " ('пригорок', 0.3725931644439697),\n", " ('отлично', 0.3685779273509979),\n", " ('великолепно', 0.35628485679626465)]" ] }, "metadata": { "tags": [] }, "execution_count": 24 } ] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "glg8qcV3Nh-L", "outputId": "cc6777ca-b13e-4345-fe20-b226b517cc38" }, "source": [ "small_model.most_similar(positive=['москва', 'италия'], negative=['россия'])" ], "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "[('париж', 0.6359437704086304),\n", " ('ницца', 0.561758279800415),\n", " ('лондон', 0.5375697612762451),\n", " ('флоренция', 0.5359780788421631),\n", " ('рим', 0.5326616168022156),\n", " ('венеция', 0.5186587572097778),\n", " ('дрезден', 0.5119317770004272),\n", " ('мюнхен', 0.5098958611488342),\n", " ('бухарест', 0.4781143069267273),\n", " ('будапешт', 0.4774155616760254)]" ] }, "metadata": { "tags": [] }, "execution_count": 25 } ] }, { "cell_type": "code", "metadata": { "id": "VQXhWIRbE_MA", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "902f2352-3e52-4362-f408-1f4130fa2bd6" }, "source": [ "!pip install razdel pymorphy2 pymorphy2-dicts-ru" ], "execution_count": null, "outputs": [ { "output_type": "stream", "text": [ "Collecting razdel\n", " Downloading razdel-0.5.0-py3-none-any.whl 
(21 kB)\n", "Collecting pymorphy2\n", " Downloading pymorphy2-0.9.1-py3-none-any.whl (55 kB)\n", "\u001b[K |████████████████████████████████| 55 kB 2.8 MB/s \n", "\u001b[?25hCollecting pymorphy2-dicts-ru\n", " Downloading pymorphy2_dicts_ru-2.4.417127.4579844-py2.py3-none-any.whl (8.2 MB)\n", "\u001b[K |████████████████████████████████| 8.2 MB 15.1 MB/s \n", "\u001b[?25hCollecting dawg-python>=0.7.1\n", " Downloading DAWG_Python-0.7.2-py2.py3-none-any.whl (11 kB)\n", "Requirement already satisfied: docopt>=0.6 in /usr/local/lib/python3.7/dist-packages (from pymorphy2) (0.6.2)\n", "Installing collected packages: pymorphy2-dicts-ru, dawg-python, razdel, pymorphy2\n", "Successfully installed dawg-python-0.7.2 pymorphy2-0.9.1 pymorphy2-dicts-ru-2.4.417127.4579844 razdel-0.5.0\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "id": "kzGkItr5pGjL" }, "source": [ "from sklearn.base import BaseEstimator, TransformerMixin\n", "from razdel import tokenize\n", "from pymorphy2 import MorphAnalyzer\n", "import numpy as np\n", "from functools import lru_cache\n", "\n", "\n", "class TextVectorizer(BaseEstimator, TransformerMixin):\n", "    # Maps a text to the L2-normalized sum of the vectors of its lemmatized tokens\n", "    def __init__(self, morph, w2v, min_len=3, eps=1e-10):\n", "        self.morph = morph\n", "        self.w2v = w2v\n", "        self.min_len = min_len\n", "        self.eps = eps\n", "\n", "    def fit(self, X, y=None):\n", "        # nothing to learn: the embeddings are pretrained\n", "        return self\n", "\n", "    def transform(self, X):\n", "        return np.stack([self.text2vec(text) for text in X])\n", "\n", "    @lru_cache(maxsize=10000)\n", "    def lemmatize(self, token):\n", "        parsed = self.morph.parse(token)\n", "        if not parsed or not parsed[0].normal_form:\n", "            return token\n", "        else:\n", "            return parsed[0].normal_form\n", "\n", "    @lru_cache(maxsize=10000)\n", "    def word2vec(self, word):\n", "        return self.w2v[word]\n", "\n", "    @lru_cache(maxsize=10000)\n", "    def text2vec(self, text):\n", "        tokens = [self.lemmatize(t.text).lower() for t in tokenize(text)]\n", "        tokens = [t for t in tokens if len(t) >= self.min_len]\n", "        if not tokens:\n", "            # no usable tokens: return a zero vector of the embedding size\n", "            return self.w2v['кот'] * 0\n", "        # use the cached lookup instead of indexing self.w2v directly\n", "        vector = np.sum([self.word2vec(t) for t in tokens], axis=0)\n", "        vector /= (np.linalg.norm(vector) + self.eps)\n", "        return vector" ], "execution_count": null, "outputs": [] },
{ "cell_type": "code", "metadata": { "id": "wyv6zxo1GQJD" }, "source": [ "vv = TextVectorizer(MorphAnalyzer(), small_model)" ], "execution_count": null, "outputs": [] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "DKq7bK3vPrL7", "outputId": "8c21bb7c-afeb-4609-996d-dfe0a2c8b51e" }, "source": [ "vv.transform(['мама мыла раму']).shape" ], "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "(1, 300)" ] }, "metadata": { "tags": [] }, "execution_count": 30 } ] }, { "cell_type": "markdown", "metadata": { "id": "G_67326bJku5" }, "source": [ "Compressed fastText is slow to decompress, and the model itself is not very fast, so I will only show an example on a small subsample (the first 1,000 training texts)." ] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "yMZeNEGYFn5R", "outputId": "3fdc1103-e187-4ae3-976b-bdcec4d60014" }, "source": [ "%%time\n", "n = 1_000\n", "pipe = make_pipeline(vv, LogisticRegression(max_iter=300))\n", "print(cross_val_score(pipe, X_train[:n], y_train[:n], scoring='roc_auc', cv=3))" ], "execution_count": null, "outputs": [ { "output_type": "stream", "text": [ "[0.92406106 0.91005594 0.91285278]\n", "CPU times: user 8.82 s, sys: 669 ms, total: 9.49 s\n", "Wall time: 8.87 s\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "id": "kft-yxtBGMy4", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "e7069825-887f-4953-a758-e6b6b2761f49" }, "source": [ "# Same character n-gram baseline, on the same 1,000 texts for comparison\n", "pipe = make_pipeline(HashingVectorizer(n_features=10_000, ngram_range=(2, 6), analyzer='char_wb'), LogisticRegression(max_iter=1000))\n", "print(cross_val_score(pipe, X_train[:n], y_train[:n], scoring='roc_auc', cv=3))" ], "execution_count": null, "outputs": [ { 
"output_type": "stream", "text": [ "[0.91836088 0.852657 0.83651157]\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "id": "aLFh_7GsP-eZ" }, "source": [ "" ], "execution_count": null, "outputs": [] } ] }