# Source code for chatterbot.trainers

import os
import sys
import csv
import time
import glob
import json
import tarfile
from tqdm import tqdm
from dateutil import parser as date_parser
from chatterbot.conversation import Statement


class Trainer(object):
    """
    Base class for all other trainer classes.

    :param boolean show_training_progress: Show progress indicators for the
           trainer. The environment variable ``CHATTERBOT_SHOW_TRAINING_PROGRESS``
           can also be set to control this. ``show_training_progress`` will override
           the environment variable if it is set.
    """

    def __init__(self, chatbot, **kwargs):
        self.chatbot = chatbot

        # Resolve the default from the environment. The previous implementation
        # only accepted integer strings ('0'/'1') and crashed with ValueError on
        # values such as 'true' or 'false'; fall back to a truthy-string check
        # for any non-numeric value instead.
        env_value = os.environ.get('CHATTERBOT_SHOW_TRAINING_PROGRESS', '1')
        try:
            environment_default = bool(int(env_value))
        except (TypeError, ValueError):
            environment_default = str(env_value).strip().lower() in ('true', 'yes', 'on')

        # An explicit keyword argument always overrides the environment variable.
        self.show_training_progress = kwargs.get(
            'show_training_progress',
            environment_default
        )

    def get_preprocessed_statement(self, input_statement):
        """
        Run each of the chat bot's preprocessors over ``input_statement``
        and return the (possibly modified) statement.
        """
        for preprocessor in self.chatbot.preprocessors:
            input_statement = preprocessor(input_statement)

        return input_statement

    def train(self, *args, **kwargs):
        """
        This method must be overridden by a child class.

        :raises Trainer.TrainerInitializationException: always, unless overridden.
        """
        raise self.TrainerInitializationException()

    class TrainerInitializationException(Exception):
        """
        Exception raised when a base class has not overridden
        the required methods on the Trainer base class.
        """

        def __init__(self, message=None):
            default = (
                'A training class must be specified before calling train(). '
                'See https://docs.chatterbot.us/training.html'
            )
            super().__init__(message or default)

    def _generate_export_data(self):
        """
        Return a list of ``[in_response_to, text]`` pairs for every stored
        statement that is a response to another statement.
        """
        result = []
        for statement in self.chatbot.storage.filter():
            if statement.in_response_to:
                result.append([statement.in_response_to, statement.text])

        return result

    def export_for_training(self, file_path='./export.json'):
        """
        Create a file from the database that can be used to
        train other chat bots.

        :param file_path: Path of the JSON file to write.
        """
        export = {'conversations': self._generate_export_data()}
        # 'w' replaces the original 'w+' mode; the file is only written,
        # never read back, so read access is unnecessary.
        with open(file_path, 'w', encoding='utf8') as jsonfile:
            json.dump(export, jsonfile, ensure_ascii=False)


class ListTrainer(Trainer):
    """
    Allows a chat bot to be trained using a list of strings
    where the list represents a conversation.
    """

    def train(self, conversation):
        """
        Train the chat bot based on the provided list of
        statements that represents a single conversation.

        :param conversation: Ordered list of statement texts; each entry is
            treated as a response to the entry before it.
        """
        previous_statement_text = None
        previous_statement_search_text = ''

        statements_to_create = []

        # Run the tagger pipeline in bulk to improve performance
        documents = self.chatbot.storage.tagger.as_nlp_pipeline(conversation)

        for document in tqdm(documents, desc='List Trainer', disable=not self.show_training_progress):
            statement_search_text = document._.search_index

            statement = self.get_preprocessed_statement(
                Statement(
                    text=document.text,
                    search_text=statement_search_text,
                    in_response_to=previous_statement_text,
                    search_in_response_to=previous_statement_search_text,
                    conversation='training'
                )
            )

            # Each statement becomes the "in response to" context
            # for the statement that follows it.
            previous_statement_text = statement.text
            previous_statement_search_text = statement_search_text

            statements_to_create.append(statement)

        # Persist the whole conversation with a single bulk insert
        self.chatbot.storage.create_many(statements_to_create)
class ChatterBotCorpusTrainer(Trainer):
    """
    Allows the chat bot to be trained using data from the
    ChatterBot dialog corpus.
    """

    def train(self, *corpus_paths):
        """
        Train the chat bot using every corpus file found under the
        given corpus paths.
        """
        from chatterbot.corpus import load_corpus, list_corpus_files

        # Resolve every data file referenced by the given corpus paths
        data_file_paths = []
        for corpus_path in corpus_paths:
            data_file_paths.extend(list_corpus_files(corpus_path))

        progress = tqdm(
            load_corpus(*data_file_paths),
            desc='ChatterBot Corpus Trainer',
            disable=not self.show_training_progress
        )

        for corpus, categories, _file_path in progress:

            statements_to_create = []

            # Train the chat bot with each statement and response pair
            for conversation in corpus:

                # Run the pipeline in bulk to improve performance
                documents = self.chatbot.storage.tagger.as_nlp_pipeline(conversation)

                previous_statement_text = None
                previous_statement_search_text = ''

                for document in documents:
                    statement_search_text = document._.search_index

                    statement = Statement(
                        text=document.text,
                        search_text=statement_search_text,
                        in_response_to=previous_statement_text,
                        search_in_response_to=previous_statement_search_text,
                        conversation='training'
                    )

                    statement.add_tags(*categories)

                    statement = self.get_preprocessed_statement(statement)

                    # Carry this statement forward as the context
                    # for the next one in the conversation.
                    previous_statement_text = statement.text
                    previous_statement_search_text = statement_search_text

                    statements_to_create.append(statement)

            if statements_to_create:
                self.chatbot.storage.create_many(statements_to_create)
class UbuntuCorpusTrainer(Trainer):
    """
    Allow chatbots to be trained with the data from the Ubuntu Dialog Corpus.

    For more information about the Ubuntu Dialog Corpus visit:
    https://dataset.cs.mcgill.ca/ubuntu-corpus-1.0/
    """

    def __init__(self, chatbot, **kwargs):
        super().__init__(chatbot, **kwargs)
        home_directory = os.path.expanduser('~')

        # URL the corpus archive is downloaded from; override with the
        # 'ubuntu_corpus_data_download_url' keyword argument.
        self.data_download_url = kwargs.get(
            'ubuntu_corpus_data_download_url',
            'http://cs.mcgill.ca/~jpineau/datasets/ubuntu-corpus-1.0/ubuntu_dialogs.tgz'
        )

        self.data_directory = kwargs.get(
            'ubuntu_corpus_data_directory',
            os.path.join(home_directory, 'ubuntu_data')
        )

        # Directory the archive is unpacked into
        self.extracted_data_directory = os.path.join(
            self.data_directory, 'ubuntu_dialogs'
        )

    def is_downloaded(self, file_path):
        """
        Check if the data file is already downloaded.
        """
        if os.path.exists(file_path):
            self.chatbot.logger.info('File is already downloaded')
            return True

        return False

    def is_extracted(self, file_path):
        """
        Check if the data file is already extracted.
        """
        if os.path.isdir(file_path):
            self.chatbot.logger.info('File is already extracted')
            return True
        return False

    def download(self, url, show_status=True):
        """
        Download a file from the given url.
        Show a progress indicator for the download status.
        Based on: http://stackoverflow.com/a/15645088/1547223

        :param url: URL to download from.
        :param show_status: When True, render a text progress bar on stdout.
        :returns: Path of the downloaded file.
        """
        import requests

        # Create the data directory if it does not already exist
        if not os.path.exists(self.data_directory):
            os.makedirs(self.data_directory)

        file_name = url.split('/')[-1]
        file_path = os.path.join(self.data_directory, file_name)

        # Do not download the data if it already exists
        if self.is_downloaded(file_path):
            return file_path

        with open(file_path, 'wb') as open_file:
            print('Downloading %s' % url)
            response = requests.get(url, stream=True)
            total_length = response.headers.get('content-length')

            if total_length is None:
                # No content length header: cannot show progress,
                # write the whole body at once
                open_file.write(response.content)
            else:
                download = 0
                total_length = int(total_length)
                for data in response.iter_content(chunk_size=4096):
                    download += len(data)
                    open_file.write(data)
                    if show_status:
                        # Render a 50-character wide progress bar
                        done = int(50 * download / total_length)
                        sys.stdout.write('\r[%s%s]' % ('=' * done, ' ' * (50 - done)))
                        sys.stdout.flush()

            # Add a new line after the download bar
            sys.stdout.write('\n')

        print('Download location: %s' % file_path)
        return file_path

    def extract(self, file_path):
        """
        Extract a tar file at the specified file path.

        :raises Exception: If the archive contains a member that would be
            written outside of the extraction directory (path traversal).
        """
        print('Extracting {}'.format(file_path))

        if not os.path.exists(self.extracted_data_directory):
            os.makedirs(self.extracted_data_directory)

        def track_progress(members):
            sys.stdout.write('.')
            for member in members:
                # This will be the current file being extracted
                yield member

        with tarfile.open(file_path) as tar:

            def is_within_directory(directory, target):
                # Return True only if `target` resolves inside `directory`.
                # os.path.commonpath is used instead of the original
                # os.path.commonprefix: commonprefix is character-based and
                # would wrongly accept sibling paths such as '/data/foo-evil'
                # when the directory is '/data/foo'.
                abs_directory = os.path.abspath(directory)
                abs_target = os.path.abspath(target)
                try:
                    return os.path.commonpath([abs_directory, abs_target]) == abs_directory
                except ValueError:
                    # Raised when the paths are on different drives (Windows)
                    return False

            def safe_extract(tar, path=".", members=None, *, numeric_owner=False):
                # Guard against path traversal (CVE-2007-4559 style archives)
                for member in tar.getmembers():
                    member_path = os.path.join(path, member.name)
                    if not is_within_directory(path, member_path):
                        raise Exception("Attempted Path Traversal in Tar File")

                tar.extractall(path, members, numeric_owner=numeric_owner)

            safe_extract(tar, path=self.extracted_data_directory, members=track_progress(tar))

        self.chatbot.logger.info('File extracted to {}'.format(self.extracted_data_directory))

        return True

    def train(self, limit=None):
        """
        Download, extract and train on the Ubuntu Dialog Corpus.

        limit: int If defined, the number of files to read from the data set.
        """
        # Download and extract the Ubuntu dialog corpus if needed
        corpus_download_path = self.download(self.data_download_url)

        # Extract if the directory does not already exist
        if not self.is_extracted(self.extracted_data_directory):
            self.extract(corpus_download_path)

        extracted_corpus_path = os.path.join(
            self.extracted_data_directory,
            '**', '**', '*.tsv'
        )

        def chunks(items, items_per_chunk):
            # Yield successive slices of at most items_per_chunk elements
            for start_index in range(0, len(items), items_per_chunk):
                end_index = start_index + items_per_chunk
                yield items[start_index:end_index]

        file_list = glob.glob(extracted_corpus_path)

        # Limit the number of files used if a limit is defined
        if limit is not None:
            file_list = file_list[:limit]

        file_groups = tuple(chunks(file_list, 5000))

        start_time = time.time()

        for batch_number, tsv_files in enumerate(file_groups):

            statements_from_file = []

            for tsv_file in tqdm(tsv_files, desc=f'Training with batch {batch_number} of {len(file_groups)}'):
                with open(tsv_file, 'r', encoding='utf-8') as tsv:
                    reader = csv.reader(tsv, delimiter='\t')

                    previous_statement_text = None
                    previous_statement_search_text = ''

                    # NOTE(review): rows with 1-3 columns would still raise
                    # IndexError on row[3]; the corpus appears to guarantee
                    # 4 columns per non-empty row - confirm before hardening.
                    documents = self.chatbot.storage.tagger.as_nlp_pipeline([
                        (
                            row[3],
                            {
                                'persona': row[1],
                                'created_at': row[0],
                            }
                        ) for row in reader if len(row) > 0
                    ])

                    for document, context in documents:
                        statement_search_text = document._.search_index

                        statement = Statement(
                            text=document.text,
                            in_response_to=previous_statement_text,
                            conversation='training',
                            created_at=date_parser.parse(context['created_at']),
                            persona=context['persona'],
                            search_text=statement_search_text,
                            search_in_response_to=previous_statement_search_text
                        )

                        for preprocessor in self.chatbot.preprocessors:
                            statement = preprocessor(statement)

                        previous_statement_text = statement.text
                        previous_statement_search_text = statement_search_text

                        statements_from_file.append(statement)

            # Bulk insert one batch of files at a time
            self.chatbot.storage.create_many(statements_from_file)

        print('Training took', time.time() - start_time, 'seconds.')