Commit 27dc3df4 authored by Russ Fish's avatar Russ Fish

Merge Tim's error handling changes to sync with the testbed/xmlrpc copy, add...

Merge Tim's error handling changes to sync with the testbed/xmlrpc copy, add support for the Win32 platform.
parent 03921d0e
...@@ -3,7 +3,16 @@ ...@@ -3,7 +3,16 @@
# Copyright (c) 2004 University of Utah and the Flux Group. # Copyright (c) 2004 University of Utah and the Flux Group.
# All rights reserved. # All rights reserved.
# #
############################################################################ # Permission to use, copy, modify and distribute this software is hereby
# granted provided that (1) source code retains these copyright, permission,
# and disclaimer notices, and (2) redistributions including binaries
# reproduce the notices in supporting documentation.
#
# THE UNIVERSITY OF UTAH ALLOWS FREE USE OF THIS SOFTWARE IN ITS "AS IS"
# CONDITION. THE UNIVERSITY OF UTAH DISCLAIMS ANY LIABILITY OF ANY KIND
# FOR ANY DAMAGES WHATSOEVER RESULTING FROM THE USE OF THIS SOFTWARE.
#
##########################################################################
# Some bits of this file are from xmlrpclib.py, which is: # Some bits of this file are from xmlrpclib.py, which is:
# -------------------------------------------------------------------- # --------------------------------------------------------------------
# Copyright (c) 1999-2002 by Secret Labs AB # Copyright (c) 1999-2002 by Secret Labs AB
...@@ -39,10 +48,14 @@ import urllib ...@@ -39,10 +48,14 @@ import urllib
import popen2 import popen2
import rfc822 import rfc822
import xmlrpclib import xmlrpclib
import syslog if os.name != "nt":
import syslog
# XXX This should come from configure. # XXX This should come from configure.
LOG_TESTBED = syslog.LOG_LOCAL5; if os.name != "nt":
LOG_TESTBED = syslog.LOG_LOCAL5;
import traceback
## ##
# Base class for exceptions in this module. # Base class for exceptions in this module.
...@@ -66,6 +79,21 @@ class BadResponse(SSHException): ...@@ -66,6 +79,21 @@ class BadResponse(SSHException):
pass pass
##
# Indicates a poorly formatted request from the client.
#
class BadRequest(SSHException):
##
# @param host The client host name.
# @param arg Description of the problem.
#
def __init__(self, host, msg):
self.args = host, msg,
return
pass
## ##
# Class used to decode headers. # Class used to decode headers.
# #
...@@ -91,31 +119,47 @@ class SSHConnection: ...@@ -91,31 +119,47 @@ class SSHConnection:
self.host = host self.host = host
# ... initialize the read and write file objects. # ... initialize the read and write file objects.
self.myChild = None
if streams: if streams:
self.myChild = None
self.rfile = streams[0] self.rfile = streams[0]
self.wfile = streams[1] self.wfile = streams[1]
pass pass
else: else:
self.user, self.host = urllib.splituser(self.host) self.user, ssh_host = urllib.splituser(self.host)
# print self.user + " " + self.host + " " + handler # print self.user + " " + self.host + " " + handler
# Use ssh unless we're on Windows with no ssh-agent running.
nt = os.name == "nt"
use_ssh = not nt or os.environ.has_key("SSH_AGENT_PID")
flags = "" flags = ""
if self.user: if self.user:
flags = flags + " -l " + self.user flags = flags + " -l " + self.user
pass pass
if ssh_config: if use_ssh and ssh_config:
flags = flags + " -F " + ssh_config flags = flags + " -F " + ssh_config
pass pass
self.myChild = popen2.Popen3("ssh -x -C -o 'CompressionLevel 5' " args = flags + " " + ssh_host + " " + handler
+ flags
+ " " if use_ssh:
+ self.host cmd = "ssh -x -C -o 'CompressionLevel 5' " + args
+ " " pass
+ handler, else:
1) # Use the PyTTY plink, equivalent to the ssh command.
self.rfile = self.myChild.fromchild cmd = "plink -x -C " + args
self.wfile = self.myChild.tochild pass
if not nt:
# Popen3 objects, and the wait method, are Unix-only.
self.myChild = popen2.Popen3(cmd, 0)
self.rfile = self.myChild.fromchild
self.wfile = self.myChild.tochild
pass
else:
# Open the pipe in Binary mode so it doesn't mess with CR-LFs.
self.rfile, self.wfile, self.errfile = popen2.popen3(cmd, mode='b')
pass
# print "wfile", self.wfile, "rfile", self.rfile
pass pass
return return
...@@ -137,6 +181,7 @@ class SSHConnection: ...@@ -137,6 +181,7 @@ class SSHConnection:
# @return The amount of data written. # @return The amount of data written.
# #
def write(self, stuff): def write(self, stuff):
# print "write", stuff
return self.wfile.write(stuff) return self.wfile.write(stuff)
## ##
...@@ -190,10 +235,13 @@ class SSHTransport: ...@@ -190,10 +235,13 @@ class SSHTransport:
## ##
# @param ssh_config The ssh config file to use when making new connections. # @param ssh_config The ssh config file to use when making new connections.
# @param user_agent Symbolic name for the program acting on behalf of the
# user.
# #
def __init__(self, ssh_config=None): def __init__(self, ssh_config=None, user_agent=None):
self.connections = {} self.connections = {}
self.ssh_config = ssh_config self.ssh_config = ssh_config
self.user_agent = user_agent
return return
## ##
...@@ -226,7 +274,11 @@ class SSHTransport: ...@@ -226,7 +274,11 @@ class SSHTransport:
connection = self.connections[(host,handler)] connection = self.connections[(host,handler)]
# ... send our request, and # ... send our request, and
if self.user_agent:
connection.putheader("user-agent", self.user_agent)
pass
connection.putheader("content-length", len(request_body)) connection.putheader("content-length", len(request_body))
connection.putheader("content-type", "text/xml")
connection.endheaders() connection.endheaders()
connection.write(request_body) connection.write(request_body)
connection.flush() connection.flush()
...@@ -242,6 +294,14 @@ class SSHTransport: ...@@ -242,6 +294,14 @@ class SSHTransport:
def getparser(self): def getparser(self):
return xmlrpclib.getparser() return xmlrpclib.getparser()
##
# @param connection The connection to drop.
#
def drop_connection(self, connection):
del self.connections[(connection.host,connection.handler)]
connection.close()
return
## ##
# Parse the response from the server. # Parse the response from the server.
# #
...@@ -253,6 +313,11 @@ class SSHTransport: ...@@ -253,6 +313,11 @@ class SSHTransport:
try: try:
# Get the headers, # Get the headers,
headers = SSHMessage(connection, False) headers = SSHMessage(connection, False)
if headers.status != "":
self.drop_connection(connection)
raise BadResponse(connection.host,
connection.handler,
headers.status)
# ... the length of the body, and # ... the length of the body, and
length = int(headers['content-length']) length = int(headers['content-length'])
# ... read in the body. # ... read in the body.
...@@ -260,11 +325,11 @@ class SSHTransport: ...@@ -260,11 +325,11 @@ class SSHTransport:
pass pass
except KeyError, e: except KeyError, e:
# Bad header, drop the connection, and # Bad header, drop the connection, and
del self.connections[(connection.host,connection.handler)] self.drop_connection(connection)
connection.close()
# ... tell the user. # ... tell the user.
raise BadResponse(connection.host, connection.handler, e.args) raise BadResponse(connection.host, connection.handler, e.args[0])
# print "response /"+response+"/"
parser.feed(response) parser.feed(response)
return unmarshaller.close() return unmarshaller.close()
...@@ -291,11 +356,11 @@ class SSHServerWrapper: ...@@ -291,11 +356,11 @@ class SSHServerWrapper:
# #
# Init syslog # Init syslog
# #
syslog.openlog("sshxmlrpc", syslog.LOG_PID, LOG_TESTBED); if os.name != "nt":
syslog.syslog(syslog.LOG_INFO, syslog.openlog("sshxmlrpc", syslog.LOG_PID, LOG_TESTBED);
"Connect by " + os.environ['USER'] + " from " + syslog.syslog(syslog.LOG_INFO,
self.ssh_connection[0]); "Connect by " + os.environ['USER'] + " from " +
self.ssh_connection[0]);
return return
## ##
...@@ -310,9 +375,28 @@ class SSHServerWrapper: ...@@ -310,9 +375,28 @@ class SSHServerWrapper:
try: try:
# Read the request, # Read the request,
hdrs = SSHMessage(connection, False) hdrs = SSHMessage(connection, False)
if hdrs.status != "":
#sys.stderr.write("server error: Expecting rfc822 headers.\n");
raise BadRequest(connection.host, hdrs.status)
if not hdrs.has_key('content-length'):
sys.stderr.write("server error: "
+ "expecting content-length header\n")
raise BadRequest(connection.host,
"missing content-length header")
if hdrs.has_key('user-agent'):
user_agent = hdrs['user-agent']
pass
else:
user_agent = "unknown"
pass
length = int(hdrs['content-length']) length = int(hdrs['content-length'])
params, method = xmlrpclib.loads(connection.read(length)) params, method = xmlrpclib.loads(connection.read(length))
syslog.syslog(syslog.LOG_INFO, "Calling method '" + method + "'"); if os.name != "nt":
syslog.syslog(syslog.LOG_INFO,
"Calling method '"
+ method
+ "'; user-agent="
+ user_agent);
try: try:
# ... find the corresponding method in the wrapped object, # ... find the corresponding method in the wrapped object,
meth = getattr(self.myObject, method) meth = getattr(self.myObject, method)
...@@ -329,6 +413,7 @@ class SSHServerWrapper: ...@@ -329,6 +413,7 @@ class SSHServerWrapper:
pass pass
pass pass
except: except:
traceback.print_exc()
# Some other exception happened, convert it to an XML-RPC fault # Some other exception happened, convert it to an XML-RPC fault
response = xmlrpclib.dumps( response = xmlrpclib.dumps(
xmlrpclib.Fault(1, xmlrpclib.Fault(1,
...@@ -376,8 +461,9 @@ class SSHServerWrapper: ...@@ -376,8 +461,9 @@ class SSHServerWrapper:
pass pass
finally: finally:
connection.close() connection.close()
syslog.syslog(syslog.LOG_INFO, "Connection closed"); if os.name != "nt":
syslog.closelog() syslog.syslog(syslog.LOG_INFO, "Connection closed");
syslog.closelog()
pass pass
return return
...@@ -397,8 +483,16 @@ class SSHServerProxy: ...@@ -397,8 +483,16 @@ class SSHServerProxy:
# The default is to use a new SSHTransport object. # The default is to use a new SSHTransport object.
# @param encoding Content encoding. # @param encoding Content encoding.
# @param verbose unused. # @param verbose unused.
# @param user_agent Symbolic name for the program acting on behalf of the
# user.
# #
def __init__(self, uri, transport=None, encoding=None, verbose=0, path=None): def __init__(self,
uri,
transport=None,
encoding=None,
verbose=0,
path=None,
user_agent=None):
type, uri = urllib.splittype(uri) type, uri = urllib.splittype(uri)
if type not in ("ssh", ): if type not in ("ssh", ):
raise IOError, "unsupported XML-RPC protocol" raise IOError, "unsupported XML-RPC protocol"
...@@ -406,7 +500,7 @@ class SSHServerProxy: ...@@ -406,7 +500,7 @@ class SSHServerProxy:
self.__host, self.__handler = urllib.splithost(uri) self.__host, self.__handler = urllib.splithost(uri)
if transport is None: if transport is None:
transport = SSHTransport() transport = SSHTransport(user_agent=user_agent)
pass pass
self.__transport = transport self.__transport = transport
...@@ -454,4 +548,8 @@ class SSHServerProxy: ...@@ -454,4 +548,8 @@ class SSHServerProxy:
# magic method dispatcher # magic method dispatcher
return xmlrpclib._Method(self.__request, name) return xmlrpclib._Method(self.__request, name)
# Locally handle "if not server:".
def __nonzero__(self):
return True
pass pass
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment