#!/usr/bin/env python '''Plot rectangular grid of values in 3d''' # Sky Coyote, June 2008 import sys, signal from numpy import * import wx class Grid3dPlot(wx.Frame): '''Frame containing 3d array as wireframe plot''' def __init__(self, data, mask, title='3d plot', size=(768, 768), pos=(-1, -1), alt=20.0, az=30.0, mag=None, zscale=1.0): '''Initialize frame''' # create frame wx.Frame.__init__(self, parent=None, id=-1, title=title, size=size, pos=pos) self.data = data[::-1, :] self.mask = mask[::-1, :] # create lists of {x, y, z} points x = arange(1.0 * mask.shape[1]).repeat(mask.shape[0]).reshape(mask.shape[1], mask.shape[0]).T.ravel() y = arange(1.0 * mask.shape[0]).repeat(mask.shape[1]) z = self.data.ravel() xyz = concatenate((x, y, z)).reshape(3, mask.shape[0], mask.shape[1]) self.xyz = xyz.swapaxes(0, 2).swapaxes(0, 1).reshape(mask.shape[0] * mask.shape[1], 3) self.alt = alt self.az = az if mag == None: # fill window self.mag = sqrt(size[0]**2 + size[1]**2) / sqrt(mask.shape[0]**2 + mask.shape[1]**2) else: self.mag = mag self.dataCenter = array((mask.shape[1] / 2, mask.shape[0] / 2, data[mask].mean())) self.zmax = data[mask].max() self.zmin = data[mask].min() if self.zmin == self.zmax: self.zmin -= 1.0 self.zmax += 1.0 self.scale = array((1.0, 1.0, zscale)) self.SetBackgroundColour('black') self.ctrlKeyDown = False self.shiftKeyDown = False self.leftDown = False # bind event handlers self.Bind(wx.EVT_KEY_DOWN, self.OnKeyDown) self.Bind(wx.EVT_KEY_UP, self.OnKeyUp) self.Bind(wx.EVT_PAINT, self.OnPaint) self.Bind(wx.EVT_LEFT_DOWN, self.OnLeftDown) self.Bind(wx.EVT_LEFT_UP, self.OnLeftUp) self.Bind(wx.EVT_MOTION, self.OnMotion) self.Bind(wx.EVT_SIZE, self.OnSize) def OnKeyDown(self, evt): '''Handle key down''' keycode = evt.GetKeyCode() if keycode == 308: # control self.ctrlKeyDown = True elif keycode == 67: # c if self.ctrlKeyDown: sys.exit(0) elif keycode == 314: # left arrow self.az -= 10.0 self.OnPaint(None) elif keycode == 315: # up arrow self.alt -= 10.0 self.OnPaint(None) elif keycode == 316: # right arrow self.az += 10.0 self.OnPaint(None) elif keycode == 317: # down arrow self.alt += 10.0 self.OnPaint(None) elif keycode == 61: # plus self.mag *= 1.1 self.OnPaint(None) elif keycode == 45: # minus self.mag /= 1.1 self.OnPaint(None) def OnKeyUp(self, evt): '''Handle key up''' keycode = evt.GetKeyCode() if keycode == 308: # control self.ctrlKeyDown = False def OnLeftDown(self, evt): '''Handle mouse down''' self.mousePosition0 = evt.GetPositionTuple() self.alt0 = self.alt self.az0 = self.az self.mag0 = self.mag self.shiftKeyDown = evt.ShiftDown() self.leftDown = True self.CaptureMouse() def OnLeftUp(self, evt): '''Handle mouse up''' if self.HasCapture(): self.ReleaseMouse() self.shiftKeyDown = False self.leftDown = False def OnMotion(self, evt): '''Handle mouse movement''' if evt.Dragging() and evt.LeftIsDown() and self.leftDown: mousePosition = evt.GetPositionTuple() if self.shiftKeyDown: # change magnification width, height = self.GetSize() cx = width / 2.0 cy = height / 2.0 self.mag = self.mag0 * sqrt((mousePosition[0] - cx)**2 + (mousePosition[1] - cy)**2) \ / sqrt((self.mousePosition0[0] - cx)**2 + (self.mousePosition0[1] - cy)**2) else: # rotate self.alt = self.alt0 + (mousePosition[1] - self.mousePosition0[1]) / 5.0 self.az = self.az0 + (mousePosition[0] - self.mousePosition0[0]) / 5.0 self.OnPaint(None) evt.Skip() def OnSize(self, evt): '''Handle resize event''' self.Refresh() def angleAxisMatrix(self, theta, axis): '''Create matrix for rotation about arbitrary axis''' # http://en.wikipedia.org/wiki/Quaternions_and_spatial_rotation v1 = axis[0] v2 = axis[1] v3 = axis[2] t1 = cos(theta) t2 = 1 - t1 t3 = v1*v1 t6 = t2*v1 t7 = t6*v2 t8 = sin(theta) t9 = t8*v3 t11 = t6*v3 t12 = t8*v2 t15 = v2*v2 t19 = t2*v2*v3 t20 = t8*v1 t24 = v3*v3 R = matrix(zeros((3, 3))) R[0, 0] = t1 + t2*t3 R[0, 1] = t7 - t9 R[0, 2] = t11 + t12 R[1, 0] = t7 + t9 R[1, 1] = t1 + t2*t15 R[1, 2] = t19 - t20 R[2, 0] = t11 - t12 R[2, 1] = t19 + t20 R[2, 2] = t1 + t2*t24 return R def setXform(self): '''Set 3d transformtaion values from alt, az, mag''' width, height = self.GetSize() self.windowCenter = array((width / 2.0, height / 2.0)) # rotate about z axis r1 = self.angleAxisMatrix(-self.az * pi / 180.0, (0.0, 0.0, 1.0)) # rotate about x axis r2 = self.angleAxisMatrix((90.0 - self.alt) * pi / 180.0, (1.0, 0.0, 0.0)) # composite of both self.rMatrix = r2 * r1 def xformXYZtoHV(self, xyz): '''Transform entire (x, y, z) array to (h, v)''' # subtract data center # scale axes # transpose # rotate # transpose # strip coords # magnify # move to window center return self.windowCenter + (self.mag * array((self.rMatrix * matrix(self.scale \ * (xyz - self.dataCenter)).T).T[:, 0:2])) def hsv2rgb(self, h, s, v): '''Convert hue/saturation/value to red/green/blue''' # http://en.wikipedia.org/wiki/HSV_color_space hi = int(h / 60.0) % 6 f = h / 60.0 - int(h / 60.0) p = v * (1.0 - s) q = v * (1.0 - f * s) t = v * (1.0 - (1.0 - f) * s) if hi == 0: r, g, b = v, t, p elif hi == 1: r, g, b = q, v, p elif hi == 2: r, g, b = p, v, t elif hi == 3: r, g, b = p, q, v elif hi == 4: r, g, b = t, p, v elif hi == 5: r, g, b = v, p, q return (r * 255, g * 255, b * 255) def OnPaint(self, event): '''Draw wireframe plot''' # setup transform self.setXform() # transform all pts hv = self.xformXYZtoHV(self.xyz) h = hv[:, 0].reshape(self.mask.shape[0], self.mask.shape[1]) v = hv[:, 1].reshape(self.mask.shape[0], self.mask.shape[1]) dc = wx.PaintDC(self) dc.Clear() # accumulate lines and colors lineList = [] penList = [] dz = self.zmax - self.zmin # weft lines for j in range(self.mask.shape[0]): for i in range(self.mask.shape[1] - 1): if self.mask[j, i] and self.mask[j, i + 1]: lineList.append((h[j, i], v[j, i], h[j, i + 1], v[j, i + 1])) # convert to rgb x = (self.data[j, i] + self.data[j, i + 1]) / 2.0 x = 270.0 - 360.0 * (x - self.zmin) / dz x = x % 360.0 penList.append(wx.Pen(self.hsv2rgb(x, 1.0, 1.0))) # warp lines for i in range(self.mask.shape[1]): for j in range(self.mask.shape[0] - 1): if self.mask[j, i] and self.mask[j + 1, i]: lineList.append((h[j, i], v[j, i], h[j + 1, i], v[j + 1, i])) # convert to rgb x = (self.data[j, i] + self.data[j + 1, i]) / 2.0 x = 270.0 - 360.0 * (x - self.zmin) / dz x = x % 360.0 penList.append(wx.Pen(self.hsv2rgb(x, 1.0, 1.0))) dc.DrawLineList(lineList, penList) # draw labels dc.SetTextForeground('White') dc.SetTextBackground('Black') y = 2; dy = 13 dc.DrawText('Alt=%.2f' % (self.alt), 2, y); y += dy dc.DrawText('Az=%.2f' % (self.az), 2, y); y += dy dc.DrawText('Mag=%.2f' % (self.mag), 2, y); y += dy def readData(fname): '''Read data file and header''' header = {} dataList = [] dataFile = open(fname, 'r') # read each line for line in dataFile: # split line into fields fields = line.strip().split() if len(fields) < 1: continue try: row = map(float, fields) dataList.append(row) except: if len(fields) > 1: name = fields[0] value = fields[1] try: header[name] = float(value) except: header[name] = value dataFile.close() return (header, array(dataList)) def sigintHandler(a, b): '''Control-C handler''' sys.exit(0) if __name__ == '__main__': if len(sys.argv) < 2: print 'Usage: Plot3dGrid.py data-file [x-decimation y-decimation z-scale]' print ' data-file = text file of 2d array with optional header' print ' set to "demo" for demo mode' print ' x-decimation = plot every nth column of data (default 1)' print ' y-decimation = plot every mth row of data (default 1)' print ' z-scale = exaggerate z values by factor (default 1.0)' sys.exit(1) # install ctrl-c handler signal.signal(signal.SIGINT, sigintHandler) # get arguments fname = sys.argv[1] xdec = 1 if len(sys.argv) > 2: xdec = int(sys.argv[2]) ydec = 1 if len(sys.argv) > 3: ydec = int(sys.argv[3]) zscale = 1.0 if len(sys.argv) > 4: zscale = float(sys.argv[4]) if fname == 'demo': # create demo data data = fromfunction(lambda j, i: sin(j / 512.0 * 4.0 * pi) * sin(i / 512.0 * 4.0 * pi), (512, 512)) valid = data >= 0 else: # read data and header header, data = readData(fname) print 'Header = ', header print 'Shape = ', data.shape # set mask if 'NODATA_value' in header: valid = data != header['NODATA_value'] else: valid = data == data print 'Valid data: %d entries, max = %g, min= %g, mean = %g, std = %g' % \ (valid.sum(), data[valid].max(), data[valid].min(), data[valid].mean(), data[valid].std()) # create app app = wx.PySimpleApp(False) # create window win = Grid3dPlot(data[::ydec, ::xdec], valid[::ydec, ::xdec], fname, zscale=zscale) win.Show() # handle events app.MainLoop()