Source code for chatterbot.trainers

import os
import csv
import time
import glob
import json
import tarfile
from typing import List, Union
from tqdm import tqdm
from dateutil import parser as date_parser
from chatterbot.chatterbot import ChatBot
from chatterbot.conversation import Statement


class Trainer(object):
    """
    Base class for all other trainer classes.

    :param boolean show_training_progress: Show progress indicators for the
           trainer. The environment variable ``CHATTERBOT_SHOW_TRAINING_PROGRESS``
           can also be set to control this. ``show_training_progress`` will override
           the environment variable if it is set.
    """

    def __init__(self, chatbot: ChatBot, **kwargs):
        self.chatbot = chatbot

        environment_default = bool(int(os.environ.get('CHATTERBOT_SHOW_TRAINING_PROGRESS', True)))

        self.disable_progress = not kwargs.get(
            'show_training_progress',
            environment_default
        )

    def get_preprocessed_statement(self, input_statement: Statement) -> Statement:
        """
        Preprocess the input statement.
        """
        for preprocessor in self.chatbot.preprocessors:
            input_statement = preprocessor(input_statement)

        return input_statement

    def train(self, *args, **kwargs):
        """
        This method must be overridden by a child class.
        """
        raise self.TrainerInitializationException()

    class TrainerInitializationException(Exception):
        """
        Exception raised when a child class has not overridden
        the required methods on the Trainer base class.
        """

        def __init__(self, message=None):
            default = (
                'A training class must be specified before calling train(). '
                'See https://docs.chatterbot.us/training/'
            )
            super().__init__(message or default)

    def _generate_export_data(self) -> list:
        result = []
        for statement in self.chatbot.storage.filter():
            if statement.in_response_to:
                result.append([statement.in_response_to, statement.text])

        return result

    def export_for_training(self, file_path='./export.json'):
        """
        Create a file from the database that can be used to
        train other chat bots.
        """
        export = {'conversations': self._generate_export_data()}
        with open(file_path, 'w+', encoding='utf8') as jsonfile:
            json.dump(export, jsonfile, ensure_ascii=False)
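
A minimal usage sketch for the export step above; the bot name and output path are illustrative, and the bot is assumed to have been trained beforehand:

    from chatterbot import ChatBot
    from chatterbot.trainers import ListTrainer

    bot = ChatBot('Example Bot')  # illustrative bot name
    trainer = ListTrainer(bot)

    # Writes a {'conversations': [...]} structure built from the bot's storage
    trainer.export_for_training('./example_export.json')  # illustrative path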


class ListTrainer(Trainer):
    """
    Allows a chat bot to be trained using a list of strings
    where the list represents a conversation.
    """

    def train(self, conversation: List[str]):
        """
        Train the chat bot based on the provided list of
        statements that represents a single conversation.
        """
        previous_statement_text = None
        previous_statement_search_text = ''

        statements_to_create = []

        # Run the pipeline in bulk to improve performance
        documents = self.chatbot.tagger.as_nlp_pipeline(conversation)

        for document in tqdm(documents, desc='List Trainer', disable=self.disable_progress):
            statement_search_text = document._.search_index

            statement = self.get_preprocessed_statement(
                Statement(
                    text=document.text,
                    search_text=statement_search_text,
                    in_response_to=previous_statement_text,
                    search_in_response_to=previous_statement_search_text,
                    conversation='training'
                )
            )

            previous_statement_text = statement.text
            previous_statement_search_text = statement_search_text

            statements_to_create.append(statement)

        self.chatbot.storage.create_many(statements_to_create)
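
For reference, a short usage sketch of ListTrainer; the bot name and the example statements are illustrative:

    from chatterbot import ChatBot
    from chatterbot.trainers import ListTrainer

    bot = ChatBot('Example Bot')
    trainer = ListTrainer(bot)

    # Each list entry is treated as a response to the entry before it
    trainer.train([
        'Hello',
        'Hi there!',
        'How are you doing?',
        'I am doing well, thank you.',
    ])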


class ChatterBotCorpusTrainer(Trainer):
    """
    Allows the chat bot to be trained using data from the
    ChatterBot dialog corpus.
    """

    def train(self, *corpus_paths: Union[str, List[str]]):
        from chatterbot.corpus import load_corpus, list_corpus_files

        data_file_paths = []

        # Get the paths to each file the bot will be trained with
        for corpus_path in corpus_paths:
            data_file_paths.extend(list_corpus_files(corpus_path))

        for corpus, categories, _file_path in tqdm(
            load_corpus(*data_file_paths),
            desc='ChatterBot Corpus Trainer',
            disable=self.disable_progress
        ):
            statements_to_create = []

            # Train the chat bot with each statement and response pair
            for conversation in corpus:

                # Run the pipeline in bulk to improve performance
                documents = self.chatbot.tagger.as_nlp_pipeline(conversation)

                previous_statement_text = None
                previous_statement_search_text = ''

                for document in documents:
                    statement_search_text = document._.search_index

                    statement = Statement(
                        text=document.text,
                        search_text=statement_search_text,
                        in_response_to=previous_statement_text,
                        search_in_response_to=previous_statement_search_text,
                        conversation='training'
                    )

                    statement.add_tags(*categories)

                    statement = self.get_preprocessed_statement(statement)

                    previous_statement_text = statement.text
                    previous_statement_search_text = statement_search_text

                    statements_to_create.append(statement)

            if statements_to_create:
                self.chatbot.storage.create_many(statements_to_create)
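
A short usage sketch of ChatterBotCorpusTrainer; dotted corpus paths such as 'chatterbot.corpus.english' require the separate chatterbot-corpus package to be installed, and the bot name is illustrative:

    from chatterbot import ChatBot
    from chatterbot.trainers import ChatterBotCorpusTrainer

    bot = ChatBot('Example Bot')
    trainer = ChatterBotCorpusTrainer(bot)

    # One or more dotted corpus paths or file system paths may be passed
    trainer.train('chatterbot.corpus.english.greetings')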


class GenericFileTrainer(Trainer):
    """
    Allows the chat bot to be trained using data from a
    CSV or JSON file, or directory of those file types.
    """

    # NOTE: If the value is an integer, it will be treated as the
    # column index instead of the key or header
    DEFAULT_STATEMENT_TO_HEADER_MAPPING = {
        'text': 'text',
        'conversation': 'conversation',
        'created_at': 'created_at',
        'persona': 'persona',
        'tags': 'tags'
    }

    def __init__(self, chatbot: ChatBot, **kwargs):
        """
        data_path: str The path to the data file or directory.
        field_map: dict A dictionary containing the column name to header mapping.
        """
        super().__init__(chatbot, **kwargs)

        self.file_extension = None

        self.field_map = kwargs.get(
            'field_map',
            self.DEFAULT_STATEMENT_TO_HEADER_MAPPING
        )

    def _get_file_list(self, data_path: str, limit: Union[int, None]):
        """
        Get a list of files to read from the data set.
        """
        if self.file_extension is None:
            raise self.TrainerInitializationException(
                'The file_extension attribute must be set before calling train().'
            )

        # List all csv or json files in the specified directory
        if os.path.isdir(data_path):
            glob_path = os.path.join(data_path, '**', f'*.{self.file_extension}')

            # Use iglob instead of glob for better performance with
            # large directories because it returns an iterator
            data_files = glob.iglob(glob_path, recursive=True)

            for index, file_path in enumerate(data_files):
                if limit is not None and index >= limit:
                    break

                yield file_path
        else:
            yield data_path

    def train(self, data_path: str, limit=None):
        """
        Train a chatbot with data from the data file.

        :param str data_path: The path to the data file or directory.
        :param int limit: The maximum number of files to train from.
        """
        if data_path is None:
            raise self.TrainerInitializationException(
                'The data_path argument must be set to the path of a file or directory.'
            )

        data_files = self._get_file_list(data_path, limit)

        files_processed = 0

        for data_file in tqdm(data_files, desc='Training', disable=self.disable_progress):

            previous_statement_text = None
            previous_statement_search_text = ''

            file_extension = data_file.split('.')[-1].lower()

            statements_to_create = []

            file_abspath = os.path.abspath(data_file)

            with open(file_abspath, 'r', encoding='utf-8') as file:

                if self.file_extension == 'json':
                    data = json.load(file)
                    data = data['conversation']
                elif file_extension == 'csv':
                    use_header = bool(isinstance(next(iter(self.field_map.values())), str))

                    if use_header:
                        data = csv.DictReader(file)
                    else:
                        data = csv.reader(file)
                elif file_extension == 'tsv':
                    use_header = bool(isinstance(next(iter(self.field_map.values())), str))

                    if use_header:
                        data = csv.DictReader(file, delimiter='\t')
                    else:
                        data = csv.reader(file, delimiter='\t')
                else:
                    self.chatbot.logger.warning(f'Skipping unsupported file type: {file_extension}')
                    continue

                files_processed += 1

                text_row = self.field_map['text']

                try:
                    documents = self.chatbot.tagger.as_nlp_pipeline([
                        (
                            row[text_row],
                            {
                                # Include any defined metadata columns
                                key: row[value]
                                for key, value in self.field_map.items()
                                if key != text_row
                            }
                        ) for row in data if len(row) > 0
                    ])
                except KeyError as e:
                    raise KeyError(
                        f'{e}. Please check the field_map parameter used to initialize '
                        f'the training class and remove this value if it is not needed. '
                        f'Current mapping: {self.field_map}'
                    )

                response_to_search_index_mapping = {}

                if 'in_response_to' in self.field_map.keys():
                    # Generate the search_in_response_to value for the in_response_to fields
                    response_documents = self.chatbot.tagger.as_nlp_pipeline([
                        (
                            row[self.field_map['in_response_to']]
                        ) for row in data if len(row) > 0 and row[self.field_map['in_response_to']] is not None
                    ])

                    # (Process the response values the same way as the text values)
                    for document in response_documents:
                        response_to_search_index_mapping[document.text] = document._.search_index

                for document, context in documents:
                    statement = Statement(
                        text=document.text,
                        conversation=context.get('conversation', 'training'),
                        persona=context.get('persona', None),
                        tags=context.get('tags', [])
                    )

                    if 'created_at' in context:
                        statement.created_at = date_parser.parse(context['created_at'])

                    statement.search_text = document._.search_index

                    # Use the in_response_to attribute for the previous statement if
                    # one is defined, otherwise use the last statement which was created
                    if 'in_response_to' in self.field_map.keys():
                        statement.in_response_to = context.get(self.field_map['in_response_to'], None)
                        statement.search_in_response_to = response_to_search_index_mapping.get(
                            context.get(self.field_map['in_response_to'], None), ''
                        )
                    else:
                        # List-type data such as CSVs with no response specified can use
                        # the previous statement as the in_response_to value
                        statement.in_response_to = previous_statement_text
                        statement.search_in_response_to = previous_statement_search_text

                    for preprocessor in self.chatbot.preprocessors:
                        statement = preprocessor(statement)

                    previous_statement_text = statement.text
                    previous_statement_search_text = statement.search_text

                    statements_to_create.append(statement)

            self.chatbot.storage.create_many(statements_to_create)

        if files_processed:
            self.chatbot.logger.info(
                'Training completed. {} files were read.'.format(files_processed)
            )
        else:
            self.chatbot.logger.warning(
                'No [{}] files were detected at: {}'.format(
                    self.file_extension, data_path
                )
            )


class CsvFileTrainer(GenericFileTrainer):
    """
    .. note::
        Added in version 1.2.4

    Allow chatbots to be trained with data from a CSV file or
    directory of CSV files.

    TSV files are also supported, as long as the file_extension
    parameter is set to 'tsv'.

    :param str file_extension: The file extension to look for when searching for files (defaults to 'csv').
    :param dict field_map: A dictionary containing the database column name to header mapping.
        Values can be either the header name (str) or the column index (int).
    """

    def __init__(self, chatbot: ChatBot, **kwargs):
        super().__init__(chatbot, **kwargs)

        self.file_extension = kwargs.get('file_extension', 'csv')
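
A usage sketch for CsvFileTrainer with a hypothetical CSV layout; the file path and the 'message' and 'timestamp' column names are assumptions, not fixed values:

    from chatterbot import ChatBot
    from chatterbot.trainers import CsvFileTrainer

    bot = ChatBot('Example Bot')

    # Map Statement attributes to the header names used in the CSV file
    trainer = CsvFileTrainer(bot, field_map={
        'text': 'message',
        'created_at': 'timestamp'
    })

    # With no 'in_response_to' mapping, each row is linked to the previous row
    trainer.train('./data/conversations.csv')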


class JsonFileTrainer(GenericFileTrainer):
    """
    .. note::
        Added in version 1.2.4

    Allow chatbots to be trained with data from a JSON file or
    directory of JSON files.

    :param dict field_map: A dictionary containing the database column name to header mapping.
    """

    DEFAULT_STATEMENT_TO_KEY_MAPPING = {
        'text': 'text',
        'conversation': 'conversation',
        'created_at': 'created_at',
        'in_response_to': 'in_response_to',
        'persona': 'persona',
        'tags': 'tags'
    }

    def __init__(self, chatbot: ChatBot, **kwargs):
        super().__init__(chatbot, **kwargs)

        self.file_extension = 'json'

        self.field_map = kwargs.get(
            'field_map',
            self.DEFAULT_STATEMENT_TO_KEY_MAPPING
        )
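
A usage sketch for JsonFileTrainer; the file path is an assumption. As read by the train() method above, each file is expected to hold a top-level 'conversation' key containing a list of statement dictionaries whose keys match the field map:

    from chatterbot import ChatBot
    from chatterbot.trainers import JsonFileTrainer

    bot = ChatBot('Example Bot')
    trainer = JsonFileTrainer(bot)  # uses DEFAULT_STATEMENT_TO_KEY_MAPPING

    trainer.train('./data/exported_conversations.json')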


class UbuntuCorpusTrainer(CsvFileTrainer):
    """
    .. note::
        PENDING DEPRECATION: Please use the ``CsvFileTrainer``
        for data formats similar to this one.

    Allow chatbots to be trained with the data from the Ubuntu Dialog Corpus.

    For more information about the Ubuntu Dialog Corpus visit:
    https://dataset.cs.mcgill.ca/ubuntu-corpus-1.0/

    :param str ubuntu_corpus_data_directory: The directory where the Ubuntu corpus data is
        already located, or where it should be downloaded and extracted.
    """

    def __init__(self, chatbot: ChatBot, **kwargs):
        super().__init__(chatbot, **kwargs)
        home_directory = os.path.expanduser('~')

        self.data_download_url = None

        self.data_directory = kwargs.get(
            'ubuntu_corpus_data_directory',
            os.path.join(home_directory, 'ubuntu_data')
        )

        # Directory containing extracted data
        self.data_path = os.path.join(
            self.data_directory, 'ubuntu_dialogs'
        )

        self.field_map = {
            'text': 3,
            'created_at': 0,
            'persona': 1,
        }

    def is_downloaded(self, file_path: str):
        """
        Check if the data file is already downloaded.
        """
        if os.path.exists(file_path):
            self.chatbot.logger.info('File is already downloaded')
            return True

        return False

    def is_extracted(self, file_path: str):
        """
        Check if the data file is already extracted.
        """
        if os.path.isdir(file_path):
            self.chatbot.logger.info('File is already extracted')
            return True

        return False

    def download(self, url: str, show_status=True):
        """
        Download a file from the given url.
        Show a progress indicator for the download status.
        """
        import requests

        # Create the data directory if it does not already exist
        if not os.path.exists(self.data_directory):
            os.makedirs(self.data_directory)

        file_name = url.split('/')[-1]
        file_path = os.path.join(self.data_directory, file_name)

        # Do not download the data if it already exists
        if self.is_downloaded(file_path):
            return file_path

        with open(file_path, 'wb') as open_file:
            if show_status:
                print('Downloading %s' % url)

            response = requests.get(url, stream=True)
            total_length = response.headers.get('content-length')

            if total_length is None:
                # No content length header
                open_file.write(response.content)
            else:
                for data in tqdm(
                    response.iter_content(chunk_size=4096),
                    desc='Downloading',
                    disable=not show_status
                ):
                    open_file.write(data)

        if show_status:
            print('Download location: %s' % file_path)

        return file_path

    def extract(self, file_path: str):
        """
        Extract a tar file at the specified file path.
        """
        if not self.disable_progress:
            print('Extracting {}'.format(file_path))

        if not os.path.exists(self.data_path):
            os.makedirs(self.data_path)

        def is_within_directory(directory, target):
            abs_directory = os.path.abspath(directory)
            abs_target = os.path.abspath(target)

            prefix = os.path.commonprefix([abs_directory, abs_target])

            return prefix == abs_directory

        def safe_extract(tar, path='.', members=None, *, numeric_owner=False):
            for member in tar.getmembers():
                member_path = os.path.join(path, member.name)
                if not is_within_directory(path, member_path):
                    raise Exception('Attempted Path Traversal in Tar File')

            tar.extractall(path, members, numeric_owner=numeric_owner)

        try:
            with tarfile.open(file_path, 'r') as tar:
                safe_extract(tar, path=self.data_path, members=tqdm(tar, disable=self.disable_progress))
        except tarfile.ReadError as e:
            raise self.TrainerInitializationException(
                f'The provided data file is not a valid tar file: {file_path}'
            ) from e

        self.chatbot.logger.info('File extracted to {}'.format(self.data_path))

        return True

    def _get_file_list(self, data_path: str, limit: Union[int, None]):
        """
        Get a list of files to read from the data set.
        """
        if self.data_download_url is None:
            raise self.TrainerInitializationException(
                'The data_download_url attribute must be set before calling train().'
            )

        # Download and extract the Ubuntu dialog corpus if needed
        corpus_download_path = self.download(self.data_download_url)

        # Extract if the directory does not already exist
        if not self.is_extracted(data_path):
            self.extract(corpus_download_path)

        extracted_corpus_path = os.path.join(
            data_path, '**', '**', '*.tsv'
        )

        # Use iglob instead of glob for better performance with
        # large directories because it returns an iterator
        data_files = glob.iglob(extracted_corpus_path)

        for index, file_path in enumerate(data_files):
            if limit is not None and index >= limit:
                break

            yield file_path

    def train(self, data_download_url: str, limit: Union[int, None] = None):
        """
        :param str data_download_url: The URL to download the Ubuntu dialog corpus from.
        :param int limit: The maximum number of files to train from.
        """
        self.data_download_url = data_download_url

        start_time = time.time()

        super().train(self.data_path, limit=limit)

        if not self.disable_progress:
            print('Training took', time.time() - start_time, 'seconds.')
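
A usage sketch for UbuntuCorpusTrainer; the download URL below is a placeholder for the actual corpus archive, and the data directory is illustrative:

    from chatterbot import ChatBot
    from chatterbot.trainers import UbuntuCorpusTrainer

    bot = ChatBot('Example Bot')
    trainer = UbuntuCorpusTrainer(
        bot,
        ubuntu_corpus_data_directory='./ubuntu_data'  # illustrative directory
    )

    trainer.train(
        data_download_url='https://example.com/ubuntu_dialogs.tgz',  # placeholder URL
        limit=100  # only read the first 100 extracted files
    )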