ssh_helper.py 6.48 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
#!/usr/bin/env python3

# Copyright (C) 2018 Simon Redman <sredman@cs.utah.edu>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import getpass
import sys
from pexpect import pxssh
21
from typing import List
22 23


24 25 26 27
# ssh config options passed to every pxssh session created by this module.
# NOTE: these turn off strict host-key checking and discard known host keys
# (known-hosts file is /dev/null), so remote host identities are not verified.
DEFAULT_SSH_OPTIONS = dict(
    StrictHostKeyChecking="no",
    UserKnownHostsFile="/dev/null",
)


28
def log_in_password_failover(hostname, username, password="", ssh_options=DEFAULT_SSH_OPTIONS):
    """
    Attempt to log in a pxssh session using only ssh agent-provided public keys, then failover to requesting a password

    The first attempt lets ssh use whatever in-memory credentials it can find, and offers
    ``password`` if the server asks for one. If that attempt raises ``pxssh.ExceptionPxssh``,
    the failed session is closed and the user is prompted interactively for a password,
    which is used for a second, password-only attempt.

    :param hostname: remote to connect to
    :param username: login username
    :param password: password to offer on the first attempt; "" means none is known yet
    :param ssh_options: pxssh options, as would be in an ssh config file
    :return: connected pxssh session and the password used to log in to the host
             (the password offered to the successful attempt; "" if none was offered)
    :raises pxssh.ExceptionPxssh: if the interactive-password attempt also fails
    """
    # First, try in-memory authentication tools, including the password if it is defined
    session = pxssh.pxssh(options=ssh_options)
    session.force_password = False
    try:
        session.login(hostname, username, password)
        return session, password
    except pxssh.ExceptionPxssh:
        # Release the failed attempt's ssh child before spawning another one.
        # close() does nothing if pxssh already closed the session while raising.
        session.close()

    # That didn't work. Ask the user for a password and force password authentication.
    session = pxssh.pxssh(options=ssh_options)
    session.force_password = True
    password = getpass.getpass("Password for {user}@{host}: ".format(user=username, host=hostname))
    session.login(hostname, username, password)
    return session, password


def log_in_many_sessions(hostnames, usernames, passwords=None, ssh_options=DEFAULT_SSH_OPTIONS):
    """
    Attempt to log in a pxssh session for all requested hosts

    Hosts are logged in to in order. When no password is known for a host, the password
    used for the previous host (if any) is offered first, so that a password shared by
    several hosts only needs to be typed once.

    :param hostnames: list of hosts on which commands should be run
    :param usernames: list of usernames to use to log in to each host
    :param passwords: (optional) list of passwords to log in to each host; the list passed
                      in is not modified
    :param ssh_options: options to pass to ssh, as per pxssh documentation
    :return: list of pxssh sessions and list of passwords which were used to log in to the hosts,
             where empty string means no password was offered. Hosts whose login failed are
             reported on stderr and left out of the session list, so that list can be shorter
             than ``hostnames``; the password list always has one entry per host.
    """
    num_hosts = len(hostnames)
    assert len(usernames) == num_hosts, "Please provide one username for every host"
    if passwords is None:
        # Generate empty passwords for every host if no passwords were provided
        passwords = [""] * num_hosts
    else:
        # Work on a copy so the caller's list is not rewritten in place
        passwords = list(passwords)
    assert len(passwords) == num_hosts, "If provided, there must be one password per username"

    sessions = []
    for host_idx, (hostname, username) in enumerate(zip(hostnames, usernames)):
        password = passwords[host_idx]
        if password == "" and host_idx > 0:
            # To try to avoid prompting for passwords, try the previously-used password first
            password = passwords[host_idx - 1]
        try:
            session, returned_password = log_in_password_failover(hostname, username, password, ssh_options)
        except pxssh.ExceptionPxssh as e:
            print("Unable to log in to {user}@{host} using in-memory SSH keys nor provided password".format(user=username, host=hostname), file=sys.stderr)
            print(e, file=sys.stderr)
            continue

        sessions.append(session)
        passwords[host_idx] = returned_password

    return sessions, passwords
90 91


92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120
def network_graph_login(netgraph, username):
    """
    Log in to each host in the graph using the given username and annotate the graph with the new session

    Every node is used as a hostname. On success, each node's attribute dict gains a
    'session' entry holding the logged-in pxssh session for that host.

    :param netgraph: networkx graph of the network topology
    :param username: username to log in to each host as
    :return: None
    :raises RuntimeError: if not every host could be logged in to. log_in_many_sessions leaves
        failed hosts out of its session list, so the sessions can no longer be matched to
        nodes; in that case the sessions that were opened are closed and no node is annotated.
    """
    hostnames = list(netgraph.nodes)
    usernames = [username] * len(hostnames)
    sessions, _ = log_in_many_sessions(hostnames, usernames)

    if len(sessions) != len(hostnames):
        # Don't guess which session belongs to which node; release what was opened instead
        for session in sessions:
            session.close()
        raise RuntimeError(
            "Logged in to only {ok} of {total} hosts; cannot match sessions to graph nodes".format(
                ok=len(sessions), total=len(hostnames)))

    for hostname, session in zip(hostnames, sessions):
        # netgraph.nodes[n] is the node's attribute dict (public networkx API)
        netgraph.nodes[hostname]['session'] = session


def network_graph_logout(netgraph):
    """
    Log out of every host in the network and remove the sessions from the netgraph

    Nodes whose attributes have no 'session' entry (for example hosts that were never logged
    in to) are skipped instead of raising KeyError.

    :param netgraph: networkx graph, annotated with logged-in SSH sessions
    :return: None
    """
    for node in netgraph.nodes:
        # netgraph.nodes[node] is the node's attribute dict (public networkx API)
        attributes = netgraph.nodes[node]
        session = attributes.get('session')
        if session is not None:
            session.logout()
        attributes.pop('session', None)


121
def get_output(session, encoding=None, errors="strict"):
    """
    Wait for the session's next prompt and return the text printed before it

    :param session: pxssh session to read
    :param encoding: encoding to use to interpret the output; by default the encoding of
                     sys.stdout at call time, or UTF-8 if that is unavailable
    :param errors: how undecodable bytes are handled, as for bytes.decode
    :return: string-form of the read bytes
    """
    # prompt() returns False on timeout; whatever was read so far is still in `before`
    session.prompt()
    raw = session.before
    if isinstance(raw, str):
        # A session spawned with an encoding already delivers text
        return raw
    if encoding is None:
        # Resolved per call rather than at import time: sys.stdout may be replaced, None,
        # or have no encoding attribute set.
        encoding = getattr(sys.stdout, "encoding", None) or "utf-8"
    return raw.decode(encoding, errors)
131 132


133
def run_command_on_host(session, command):
    """
    Send a command to a single host without waiting for it to finish

    The command's output is left unread in the session; collect it afterwards,
    e.g. with get_output.

    :param session: logged-in pxssh session to run the command on
    :param command: command to run
    :return: None
    """
    session.sendline(command)


144
def run_commands_on_many_hosts(sessions, commands):
    """
    Run the specified command on each host in the network

    Every command is sent before any output is read, so hosts can execute their
    commands in parallel rather than one after another.

    :param sessions: list of logged-in pxssh sessions to run commands on
    :param commands: list of commands to run, one per host
    :return: output from running each command, in the same order as the sessions were presented
    """
    assert len(commands) == len(sessions), "Please provide one command for every session"

    # Dispatch all commands first ...
    for session, command in zip(sessions, commands):
        run_command_on_host(session, command)

    # ... then collect each host's output, in session order
    return [get_output(session) for session in sessions]
169 170 171 172 173 174 175 176 177 178 179 180


def get_exit_codes(sessions) -> List[int]:
    """
    Get the exit code of the last command run in each session

    Runs ``echo $?`` in every session and parses the value it prints.

    :param sessions: list of logged-in pxssh sessions
    :return: exit code of each session's previous command, in session order
    :raises ValueError: if a session's output does not end in an integer
    :raises IndexError: if a session produced no output at all
    """
    outputs = run_commands_on_many_hosts(sessions, ["echo $?"] * len(sessions))
    # The captured text normally starts with the echoed command ("echo $?") and ends with
    # the value. Taking the last token also works when the terminal does not echo, or
    # when earlier unread output precedes it.
    return [int(output.split()[-1]) for output in outputs]