#!/usr/bin/env python
'''Basic 1d, 2d, and image plots using wxPython'''

# AstroPy
# Sky Coyote, Oct 2007

# Import modules used
from numpy import *
import wx

# Blank border, in pixels, between the panel edge and the plot axes.
MARGIN = 10


def _copySelection(selected):
    '''Return an independent 1d boolean array copy of selected, or None.

    A plain [:] slice only makes a *view* of a numpy array, so the caller's
    later edits would leak into the plot; array() always copies.
    '''
    if selected is None:
        return None
    return array(selected, dtype=bool).ravel()


def _widenRange(lo, hi):
    '''Return (lo, hi) as floats, widened by 1 on each side when they are equal.'''
    lo, hi = float(lo), float(hi)
    if lo == hi:
        lo, hi = lo - 1.0, hi + 1.0
    return lo, hi


def _pixelX(x, lo, hi, width):
    '''Map data x in [lo, hi] onto the horizontal pixel span between the margins.'''
    return MARGIN + (width - 2 * MARGIN) * (x - lo) / (hi - lo)


def _pixelY(y, lo, hi, height):
    '''Map data y in [lo, hi] onto the vertical pixel span (screen y grows downward).'''
    return height - MARGIN - (height - 2 * MARGIN) * (y - lo) / (hi - lo)


def _pointList(xs, ys):
    '''Pack pixel coordinate arrays into integer (x, y) tuples for dc.DrawLines.'''
    cols = rint(asarray(xs, dtype=float)).astype(int).tolist()
    rows = rint(asarray(ys, dtype=float)).astype(int).tolist()
    return list(zip(cols, rows))


def _drawAxes(dc, width, height):
    '''Draw the bottom and left plot axes.'''
    dc.DrawLine(MARGIN, height - MARGIN, width - MARGIN, height - MARGIN)
    dc.DrawLine(MARGIN, MARGIN, MARGIN, height - MARGIN)


def _drawCross(dc, x, y, size=5):
    '''Draw a blue "+" centred at (x, y), restoring the DC's previous pen.'''
    x, y = int(rint(x)), int(rint(y))
    pen = dc.GetPen()
    dc.SetPen(wx.Pen('Blue'))
    dc.DrawLine(x - size, y, x + size, y)
    dc.DrawLine(x, y - size, x, y + size)
    dc.SetPen(pen)


def _setPixels(im, xs, ys, val):
    '''Set im[row, col] = val at each rounded (x, y) position that lies inside im.'''
    cols = rint(asarray(xs, dtype=float)).astype(int)
    rows = rint(asarray(ys, dtype=float)).astype(int)
    inside = (cols >= 0) & (cols < im.shape[1]) & (rows >= 0) & (rows < im.shape[0])
    im[rows[inside], cols[inside]] = val


class PlotPanel(wx.Panel):
    '''Panel with 1 or more 1d curves'''

    def __init__(self, parent, data, min=0, max=0, selected=None):
        '''Initialize panel.

        data     -- 1d array (one curve) or 2d array (one curve per row)
        min, max -- vertical plot range; both 0 means autoscale to the data
        selected -- optional sequence of booleans, one per point; True points
                    are marked with a blue cross on every curve
        '''
        wx.Panel.__init__(self, parent)
        self._setData(data, min, max, selected)
        self.SetBackgroundColour('White')
        self.Bind(wx.EVT_PAINT, self.OnPaint)

    def _setData(self, data, min, max, selected):
        '''Store copies of data/selection and compute the vertical range.'''
        self.data = array(data, dtype=float)  # copy, so caller edits don't leak in
        self.selected = _copySelection(selected)
        if min == 0 and max == 0 and self.data.size > 0:
            min, max = self.data.min(), self.data.max()
        # Widen a degenerate range once, here. Doing it inside OnPaint made the
        # first paint differ from every later repaint.
        self.min, self.max = _widenRange(min, max)

    def OnPaint(self, event):
        '''Draw plot'''
        dc = wx.PaintDC(self)
        width, height = self.GetSize()
        _drawAxes(dc, width, height)
        # zero line, only when 0 lies strictly inside the plotted range
        if self.min < 0 < self.max:
            dc.SetPen(wx.Pen('Green'))
            y0 = int(rint(_pixelY(0.0, self.min, self.max, height)))
            dc.DrawLine(MARGIN, y0, width - MARGIN, y0)
        dc.SetPen(wx.Pen('Red'))
        # Treat 1d data as a single-row 2d array so both cases share one path.
        curves = self.data if self.data.ndim > 1 else self.data.reshape(1, -1)
        npts = curves.shape[1]
        if npts == 0:
            return
        h = linspace(MARGIN, width - MARGIN, npts)  # also safe when npts == 1
        v = _pixelY(curves, self.min, self.max, height)
        if npts > 1:
            for row in v:
                dc.DrawLines(_pointList(h, row))
        # mark selected points (extra selection entries are ignored)
        if self.selected is not None:
            for i in nonzero(self.selected[:npts])[0]:
                for row in v:
                    self.drawPoint(dc, h[i], row[i])

    def drawPoint(self, dc, x, y):
        '''Draw a point'''
        _drawCross(dc, x, y)

    def reload(self, data, min=0, max=0, selected=None):
        '''Load other data and repaint.'''
        self._setData(data, min, max, selected)
        self.Refresh()


class PlotFrame(wx.Frame):
    '''Frame containing 1d plot'''

    def __init__(self, data, title='Plot', min=0, max=0, pos=(-1, -1), size=(500, 422),
                 selected=None):
        '''Initialize frame'''
        wx.Frame.__init__(self, parent=None, title=title, pos=pos, size=size)
        self.panel = PlotPanel(self, data, min, max, selected=selected)

    def reload(self, data, title=None, min=0, max=0, selected=None):
        '''Load other data; optionally retitle the frame and select points.'''
        self.panel.reload(data, min, max, selected)
        if title is not None:
            self.SetTitle(title)

#-------------------------------------------------------------------------------

class Plot2Panel(wx.Panel):
    '''Panel with 1 or more 2d curves'''

    def __init__(self, parent, data, min=(0, 0), max=(0, 0), selected=None):
        '''Initialize panel.

        data     -- (points, 2) array for one curve or (curves, points, 2)
                    for several; the last axis holds (x, y)
        min, max -- (x, y) plot ranges, or a scalar applied to both axes;
                    an axis whose min and max are both 0 is autoscaled
        selected -- optional sequence of booleans, one per point; True points
                    are marked with a blue cross on every curve
        '''
        wx.Panel.__init__(self, parent)
        self._setData(data, min, max, selected)
        self.SetBackgroundColour('White')
        self.Bind(wx.EVT_PAINT, self.OnPaint)

    def _setData(self, data, min, max, selected):
        '''Store copies of data/selection and compute the per-axis ranges.'''
        self.data = array(data, dtype=float)  # copy, so caller edits don't leak in
        self.selected = _copySelection(selected)
        # Multiplying by ones(2) turns a scalar into an (x, y) pair and leaves pairs as-is.
        lo = array(min, dtype=float) * ones(2)
        hi = array(max, dtype=float) * ones(2)
        for axis in range(2):
            if lo[axis] == 0 and hi[axis] == 0 and self.data.size > 0:
                lo[axis] = self.data[..., axis].min()
                hi[axis] = self.data[..., axis].max()
            lo[axis], hi[axis] = _widenRange(lo[axis], hi[axis])
        self.min, self.max = lo, hi

    def OnPaint(self, event):
        '''Draw 2d plot'''
        dc = wx.PaintDC(self)
        width, height = self.GetSize()
        _drawAxes(dc, width, height)
        # zero lines, only when 0 lies strictly inside an axis' range
        if self.min[0] < 0 < self.max[0]:
            dc.SetPen(wx.Pen('Green'))
            x0 = int(rint(_pixelX(0.0, self.min[0], self.max[0], width)))
            dc.DrawLine(x0, MARGIN, x0, height - MARGIN)
        if self.min[1] < 0 < self.max[1]:
            dc.SetPen(wx.Pen('Green'))
            y0 = int(rint(_pixelY(0.0, self.min[1], self.max[1], height)))
            dc.DrawLine(MARGIN, y0, width - MARGIN, y0)
        dc.SetPen(wx.Pen('Red'))
        # A single (points, 2) curve becomes (1, points, 2) so one loop handles both.
        curves = self.data if self.data.ndim > 2 else self.data[newaxis, ...]
        npts = curves.shape[1]
        if npts == 0:
            return
        xs = _pixelX(curves[..., 0], self.min[0], self.max[0], width)
        ys = _pixelY(curves[..., 1], self.min[1], self.max[1], height)
        if npts > 1:
            for cx, cy in zip(xs, ys):
                dc.DrawLines(_pointList(cx, cy))
        # mark selected points (extra selection entries are ignored)
        if self.selected is not None:
            for i in nonzero(self.selected[:npts])[0]:
                for cx, cy in zip(xs, ys):
                    self.drawPoint(dc, cx[i], cy[i])

    def drawPoint(self, dc, x, y):
        '''Draw a point'''
        _drawCross(dc, x, y)

    def reload(self, data, min=(0, 0), max=(0, 0), selected=None):
        '''Load other data and repaint.'''
        self._setData(data, min, max, selected)
        self.Refresh()


class Plot2Frame(wx.Frame):
    '''Frame containing 2d plot'''

    def __init__(self, data, title='Plot', min=(0, 0), max=(0, 0), pos=(-1, -1),
                 size=(500, 522), selected=None):
        '''Initialize frame'''
        wx.Frame.__init__(self, parent=None, title=title, pos=pos, size=size)
        self.panel = Plot2Panel(self, data, min, max, selected=selected)

    def reload(self, data, title=None, min=(0, 0), max=(0, 0), selected=None):
        '''Load other data; optionally retitle the frame and select points.'''
        self.panel.reload(data, min, max, selected)
        if title is not None:
            self.SetTitle(title)

#-------------------------------------------------------------------------------

class FITSFrame(wx.Frame):
    '''Frame containing FITS image'''

    def __init__(self, data, title='FITS image', min=0, max=0, pos=(-1, -1),
                 pts=None, circ=None, circ2=None):
        '''Initialize frame.

        data     -- 2d image array; its row 0 is drawn at the bottom
        min, max -- display range the image is clipped to; both 0 means autoscale
        pts      -- optional sequence of (x, y) pixel positions, marked in red
        circ     -- optional (x, y, radius) circle, drawn in green
        circ2    -- optional (x, y, radius) circle, drawn in blue
        '''
        self.pts = pts
        self.circ = circ
        self.circ2 = circ2
        # create bitmap
        bmp = self.fits2bitmap(data, min, max)
        size = bmp.GetWidth(), bmp.GetHeight()  # plus title bar?
        # create a fixed-size frame (no resize border, no maximize box)
        wx.Frame.__init__(self, parent=None, id=-1, title=title, size=size, pos=pos,
                          style=wx.DEFAULT_FRAME_STYLE & ~(wx.RESIZE_BORDER | wx.MAXIMIZE_BOX))
        # create bitmap in frame
        self.bmp = wx.StaticBitmap(parent=self, bitmap=bmp)
        # resize frame to show all of bitmap
        self.SetSize(self.GetBestSize())

    def fits2bitmap(self, data, min=0, max=0):
        '''Convert ndarray to wx.Bitmap, clipped to [min, max], with overlays.'''
        data = array(data, dtype=float)
        # clip image
        if min == 0 and max == 0:
            min, max = data.min(), data.max()
        data = data.clip(min, max)
        rows, cols = data.shape
        # Colour planes. Overlays set a plane to either min or max, so each
        # overlay gets a pure colour once the planes are scaled to bytes.
        red, green, blue = 1.0 * data, 1.0 * data, 1.0 * data
        planes = (red, green, blue)
        if self.circ2 is not None:  # blue
            for plane, val in zip(planes, (min, min, max)):
                self.addCircle(plane, self.circ2, val)
        if self.circ is not None:  # green
            for plane, val in zip(planes, (min, max, min)):
                self.addCircle(plane, self.circ, val)
        if self.pts is not None:  # red
            for pt in self.pts:
                for plane, val in zip(planes, (max, min, min)):
                    self.addPoint(plane, pt, val)
        # Scale to 0..255 as *unsigned* bytes; signed int8 overflowed above 127.
        if max == min:
            rgb = full((rows, cols, 3), 127, dtype=uint8)
        else:
            rgb = dstack([(255.0 * (plane - min) / (max - min)).astype(uint8)
                          for plane in planes])
        # Flip vertically so image row 0 ends up at the bottom of the bitmap.
        rgb = rgb[::-1]
        # Row-major, interleaved R,G,B bytes, as wx.ImageFromData expects.
        return wx.ImageFromData(cols, rows, rgb.tobytes()).ConvertToBitmap()

    def addPoint(self, im, pt, val):
        '''Burn a 7-pixel "+" centred on pt = (x, y) into im.'''
        self.addLine(im, pt[0] - 3, pt[1], pt[0] + 3, pt[1], val)
        self.addLine(im, pt[0], pt[1] - 3, pt[0], pt[1] + 3, val)

    def addLine(self, im, x1, y1, x2, y2, val):
        '''Burn the segment (x1, y1)-(x2, y2) into im, sampling at unit steps.'''
        dx = float(x2 - x1)
        dy = float(y2 - y1)
        mag = sqrt(dx * dx + dy * dy)
        steps = arange(0.0, mag, 1.0)  # 0, 1, 2, ... < mag; empty when mag == 0
        if mag > 0:
            dx /= mag
            dy /= mag
        # both endpoints plus the unit-spaced samples between them
        xs = concatenate(([x1], x1 + steps * dx, [x2]))
        ys = concatenate(([y1], y1 + steps * dy, [y2]))
        _setPixels(im, xs, ys, val)

    def addCircle(self, im, circ, val):
        '''Burn the circle circ = (x, y, radius) into im.'''
        x0, y0 = float(circ[0]), float(circ[1])
        # abs(): a negative radius used to make the angle loop run forever.
        radius = abs(float(circ[2]))
        if radius == 0:  # previously a ZeroDivisionError
            _setPixels(im, [x0], [y0], val)
            return
        # An angular step of 1/radius is roughly one pixel of arc per sample.
        angles = arange(0.0, 2.0 * pi, 1.0 / radius)
        _setPixels(im, x0 + radius * cos(angles), y0 + radius * sin(angles), val)

    def reload(self, data, title=None, min=0, max=0, pts=None, circ=None, circ2=None):
        '''Load another image, replacing every overlay (pts, circ and circ2).'''
        self.pts = pts
        self.circ = circ
        self.circ2 = circ2  # previously never reset, so a stale circle survived
        # reset bitmap
        self.bmp.SetBitmap(self.fits2bitmap(data, min, max))
        # resize frame to show all of bitmap
        self.SetSize(self.GetBestSize())
        # reset title
        if title is not None:
            self.SetTitle(title)

#-------------------------------------------------------------------------------

if __name__ == '__main__':
    # create app (wx.App(False) replaces the deprecated wx.PySimpleApp(False))
    app = wx.App(False)
    pz = 35  # diagonal offset between cascaded demo windows, in pixels
    positions = [(pz * k, pz * k) for k in range(1, 6)]

    # create image data
    r = fromfunction(lambda x, y: sqrt((x - 256)**2 + (y - 256)**2), (512, 512))
    data = exp(-r**2 / 1e4) * cos(r**2 / 1e3)**2
    # create points every 15 degrees on a circle of radius 128
    ang = arange(0, 360, 15) * pi / 180.0
    pts = column_stack((256 + 128 * cos(ang), 256 + 128 * sin(ang)))
    # plot image + points + circle
    win = FITSFrame(data, 'exp(-r**2/1e4)*cos(r**2/1e3)**2 +pts+circ at r=128', pos=positions[0],
                    pts=pts, circ=(256, 256, 128))
    win.Show()

    # create 1d data
    data2 = sin(arange(1000) * 10 * pi / 999)
    scale = arange(10)
    # select some points
    selected = data2 >= 0.5
    # plot 1 curve
    win = PlotFrame(data2, '1 1d curve with y >= 0.5 selected', pos=positions[1], selected=selected)
    win.Show()
    # plot 10 curves
    win = PlotFrame(outer(scale, data2), '10 1d curves', pos=positions[2])
    win.Show()

    # create 2d data
    ang = arange(1000) / 999.0 * 2.0 * pi
    data = column_stack((cos(ang), sin(ang)))
    # select some points
    selected = data[:, 1] >= 0.5
    # plot 1 curve
    win = Plot2Frame(data, '1 2d curve with y >= 0.5 selected', pos=positions[3], selected=selected)
    win.Show()
    # plot 10 curves
    data = outer(scale, data).reshape(len(scale), data.shape[0], data.shape[1])
    win = Plot2Frame(data, '10 2d curves', pos=positions[4])
    win.Show()

    app.MainLoop()