ssh_helper.py 8.6 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
#!/usr/bin/env python3

# Copyright (C) 2018 Simon Redman <sredman@cs.utah.edu>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import getpass
import sys
from pexpect import pxssh
21
from typing import List, Tuple, Optional
22 23


# ssh options used by the login helpers below whenever the caller does not pass its own.
# NOTE(review): together these turn off host-key verification (unknown or changed keys are accepted and nothing is
# ever recorded), which is convenient for throw-away test hosts but exposes logins to man-in-the-middle attacks.
DEFAULT_SSH_OPTIONS = dict(
    StrictHostKeyChecking="no",
    UserKnownHostsFile="/dev/null",
)


def log_in_password_failover(hostname, username, password="", ssh_options=DEFAULT_SSH_OPTIONS):
    """
    Open a pxssh session to a host, trying non-interactive authentication first and prompting for a password last

    The first attempt lets ssh use whatever it can on its own (such as ssh-agent keys), plus ``password`` if one was
    given. If that attempt raises pxssh.ExceptionPxssh, the user is asked for a password on the terminal and a second
    attempt is made with ``force_password`` set; an exception from that second attempt propagates to the caller.

    :param hostname: remote to connect to
    :param username: login username
    :param password: login password -- Will be prompted for if needed and not passed
    :param ssh_options: pxssh options, as would be in an ssh config file
    :return: tuple of the connected pxssh session and the password that was passed in or typed by the user
    """
    def _new_session(password_only):
        # Fresh session object per attempt; force_password selects password-only authentication
        fresh = pxssh.pxssh(options=ssh_options)
        fresh.force_password = password_only
        return fresh

    try:
        session = _new_session(password_only=False)
        session.login(hostname, username, password)
    except pxssh.ExceptionPxssh:
        # Non-interactive login did not work, so ask the user for a password and try once more
        session = _new_session(password_only=True)
        prompt_text = "Password for {user}@{host}: ".format(user=username, host=hostname)
        password = getpass.getpass(prompt_text)
        session.login(hostname, username, password)

    return session, password


def log_in_many_sessions(hostnames, usernames, passwords=None, ssh_options=DEFAULT_SSH_OPTIONS):
    """
    Attempt to log in a pxssh session for all requested hosts

    When no password is given for a host, the password from the most recent successful login is tried first, so a
    user whose hosts share a password is normally only prompted once.

    :param hostnames: list of hosts on which commands should be run
    :param usernames: list of usernames to use to log in to each host
    :param passwords: (optional) list of passwords to log in to each host; this list is not modified
    :param ssh_options: options to pass to ssh, as per pxssh documentation
    :return: list of pxssh sessions and list of passwords which were used to log in to the hosts, where empty string
             means no password was used. A host which cannot be logged in to is reported on stderr and gets no entry
             in the session list (which is then shorter than hostnames); its password entry is left as given.
    """
    num_hosts = len(hostnames)
    assert len(usernames) == num_hosts, "Please provide one username for every host"
    if passwords is None:
        # Generate empty passwords for every host if no passwords were provided
        passwords = [""] * num_hosts
    else:
        # Work on a copy so the caller's list is not changed behind their back
        passwords = list(passwords)
    assert len(passwords) == num_hosts, "If provided, there must be one password per username"

    sessions = []
    last_password = ""
    for host_idx, (hostname, username) in enumerate(zip(hostnames, usernames)):
        password = passwords[host_idx]
        if password == "":
            # To try to avoid prompting for passwords, try the previously-used password first
            password = last_password

        try:
            session, returned_password = log_in_password_failover(hostname, username, password, ssh_options)
        except pxssh.ExceptionPxssh as e:
            print("Unable to log in to {user}@{host} using in-memory SSH keys nor provided password".format(user=username, host=hostname), file=sys.stderr)
            print(e, file=sys.stderr)
            continue

        sessions.append(session)
        passwords[host_idx] = returned_password
        if returned_password:
            last_password = returned_password

    return sessions, passwords


def network_graph_login(netgraph, username):
    """
    Log in to each host in the graph using the given username and annotate the graph with the new session

    Hosts are logged in to one at a time so every session is stored on the node it belongs to, even when some logins
    fail. The password from the most recent successful login is offered to the next host first, so the user is
    normally not asked for the same password again. Nodes whose login fails get no 'session' attribute.

    :param netgraph: networkx graph of the network topology, whose nodes are hostnames
    :param username: username to log in to each host as
    :return: None
    """
    password = ""
    for hostname in list(netgraph.nodes):
        # One host per call keeps the session unambiguously tied to its node: the batch form drops failed hosts
        # from its session list, which would shift every later session onto the wrong node.
        sessions, passwords = log_in_many_sessions([hostname], [username], [password])
        if not sessions:
            # log_in_many_sessions has already reported the failure on stderr
            continue
        netgraph.nodes[hostname]['session'] = sessions[0]
        if passwords[0]:
            password = passwords[0]


def network_graph_logout(netgraph):
    """
    Log out of every host in the network and remove the sessions from the netgraph

    Nodes that carry no 'session' attribute (for example because logging in to them failed) are skipped.

    :param netgraph: networkx graph, annotated with logged-in SSH sessions
    :return: None
    """
    for node in netgraph.nodes:
        # netgraph.nodes[node] is the node's own attribute dict, so removing the key updates the graph
        attributes = netgraph.nodes[node]
        session = attributes.get('session')
        if session is not None:
            session.logout()
        attributes.pop('session', None)


121
def _get_output(session, encoding=sys.stdout.encoding, timeout=None):
122 123 124
    """
    Decode the raw bytes written by the SSH session

125
    :param session:  pxssh session to read
126
    :param encoding: encoding to use to interpret the output
127
    :param timeout:  time to wait for a prompt to appear
128 129
    :return: string-form of the read bytes
    """
130
    session.prompt(timeout=timeout)
131
    return str(session.before.decode(encoding))
132 133


134
def _run_command_on_host(session, command):
135 136
    """
    Run a specified command on a single host
137 138
    Note: be sure to match one command execution with one get_output, otherwise weird and
    hard-to-understand things will happen
139

140
    :param session: logged-in pxssh session to run the command on
141
    :param command: command to run
142
    :return: None
143 144 145 146
    """
    session.sendline(command)


def unchecked_run_commands_on_many_hosts(sessions: List[pxssh.pxssh], commands: List[str]) -> List[str]:
    """
    Run one command on each host without checking whether the commands succeeded

    Every command is sent before any output is collected, so the commands can run on all hosts at the same time.

    :param sessions: list of logged-in pxssh sessions to run commands on
    :param commands: list of commands to run, one per host
    :return: output from running each command, in the same order as the sessions were presented
    """
    assert len(commands) == len(sessions), "Please provide one command for every session"

    # Start every command first...
    for host_session, host_command in zip(sessions, commands):
        _run_command_on_host(host_session, host_command)

    # ...then wait for each host's prompt in turn and collect what it printed
    return [_get_output(host_session) for host_session in sessions]


def run_commands_on_many_hosts(sessions: List[pxssh.pxssh], commands: List[str]) -> List[str]:
    """
    Run one command on each host and raise if any of them exits with a non-zero status

    :param sessions: list of logged-in pxssh sessions to run commands on
    :param commands: list of commands to run, one per host
    :return: output from running each command, in the same order as the sessions were presented
    :raises SSHCommandErrorError: if at least one command exited non-zero; the raised error describes the last
                                  failing session and earlier failures are reachable through its ``next`` attribute
    """
    outputs = unchecked_run_commands_on_many_hosts(sessions, commands)
    exit_codes = get_exit_codes(sessions)

    # Link every failure into one chain, newest at the head
    failure = None
    for host_session, host_output, exit_code in zip(sessions, outputs, exit_codes):
        if exit_code == 0:
            continue
        failure = SSHCommandErrorError(session=host_session,
                                       output=host_output,
                                       code=exit_code,
                                       next=failure)

    if failure is not None:
        raise failure
    return outputs


def get_exit_codes(sessions) -> List[int]:
    """
    Get the exit code of the last command run in each session

    Runs ``echo $?`` in every session and reads the status from the last non-blank line of each reply; any earlier
    lines are the echoed command itself.

    :param sessions: logged-in pxssh sessions
    :return: exit codes, in the same order as sessions
    :raises ValueError: if a session's reply does not end with an integer exit code
    """
    commands = ["echo $?" for _ in sessions]
    outputs = unchecked_run_commands_on_many_hosts(sessions, commands)

    codes = []
    for output in outputs:
        # splitlines() copes with "\r\n" as well as bare "\n"; whitespace-only lines are dropped
        lines = [line.strip() for line in output.splitlines() if line.strip()]
        if not lines:
            raise ValueError("No exit code found in command output {!r}".format(output))
        codes.append(int(lines[-1]))
    return codes


class SSHError(Exception):
    """
    Common base class for the SSH-related errors defined in this module
    """


class SSHCommandErrorError(SSHError):
    """
    Raised to indicate that one or more errors have occurred while running a command

    When several commands fail, their errors are linked into a chain through the ``next`` attribute.
    """

    def __init__(self, session: pxssh.pxssh, output: str, code: int,
                 next: Optional["SSHCommandErrorError"] = None):
        """
        :param session: pxssh session which had the command run
        :param output:  output from the command which hopefully contains more error information
        :param code:    the exit code of the command
        :param next:    next exception, in case many commands were executed and more than one had a problem;
                        None (the default) when this is the last error in the chain
        """
        # The message (str(error)) is the command output, which usually explains the failure
        super().__init__(output)
        self.session = session
        self.output = output
        self.code = code
        # Named ``next`` (shadowing the builtin) because callers pass it by keyword
        self.next = next