import functools
import time
from http import HTTPStatus
from http.server import BaseHTTPRequestHandler, HTTPServer

import cv2
import numpy as np
import torch
from torchvision.io.image import write_jpeg
from torchvision.models.detection import MaskRCNN_ResNet50_FPN_Weights, maskrcnn_resnet50_fpn
from torchvision.transforms.functional import convert_image_dtype

from classify import predict

# Both predictors reshape the raw request body into this uint8 (rows, cols, channels)
# layout, so a POST body must contain exactly IMAGE_NUM_BYTES bytes.
IMAGE_SHAPE = (240, 320, 3)
IMAGE_NUM_BYTES = IMAGE_SHAPE[0] * IMAGE_SHAPE[1] * IMAGE_SHAPE[2]

# Label-index -> name table used to decode Mask R-CNN's output labels
# ('N/A' entries are placeholders for unused label ids).
COCO_INSTANCE_CATEGORY_NAMES = (
    '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
    'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
    'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
    'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A',
    'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
    'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
    'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
    'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
    'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table',
    'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
    'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book',
    'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush',
)


@functools.lru_cache(maxsize=1)
def _load_maskrcnn():
    """Build the COCO-pretrained Mask R-CNN in eval mode once and reuse it."""
    model = maskrcnn_resnet50_fpn(weights=MaskRCNN_ResNet50_FPN_Weights.COCO_V1, progress=False)
    return model.eval()


# initial banana prediction model
def predict_segmented(post_body):
    """Return True if Mask R-CNN detects a banana in the raw frame *post_body*.

    Args:
        post_body: bytes-like object holding exactly IMAGE_NUM_BYTES uint8
            values in IMAGE_SHAPE (rows, cols, channels) order.

    Returns:
        True if any detected label is 'banana', else False.

    Raises:
        RuntimeError: from the reshape, if the buffer size does not match IMAGE_SHAPE.

    Side effects: writes the decoded image to "image_out.jpg" and prints the
    inference time and the detected label names.
    """
    model = _load_maskrcnn()
    # Copy into a bytearray: torch.frombuffer warns on read-only buffers such as bytes.
    img = torch.frombuffer(bytearray(post_body), dtype=torch.uint8).reshape(IMAGE_SHAPE)
    # Flip rows (dim 0), then reorder (rows, cols, channels) -> (channels, rows, cols),
    # the layout write_jpeg and the detection model consume.
    # NOTE(review): predictYOLO does not apply this vertical flip — confirm which is right.
    img = torch.flip(img, [0]).permute(2, 0, 1)
    # write image for debugging
    write_jpeg(img, "image_out.jpg")
    batch = convert_image_dtype(img.unsqueeze(0), dtype=torch.float)
    start = time.time()
    # No gradients are needed for prediction; skip building an autograd graph.
    with torch.inference_mode():
        output = model(batch)
    print(time.time() - start)
    detected = [COCO_INSTANCE_CATEGORY_NAMES[label] for label in output[0]['labels'].tolist()]
    print(" the following instances were detected:")
    print(detected)
    return 'banana' in detected


# better banana detection model
def predictYOLO(postbody):
    """Return True if the locally modified YOLOv5 classifier reports a banana.

    Args:
        postbody: bytes-like object holding exactly IMAGE_NUM_BYTES uint8
            values in IMAGE_SHAPE order.

    Raises:
        ValueError: from the reshape, if the buffer size does not match IMAGE_SHAPE.

    Side effect: writes the decoded frame to "out.jpg" for debugging.
    """
    frame = np.frombuffer(postbody, dtype=np.uint8).reshape(IMAGE_SHAPE)
    # NOTE(review): cv2.imwrite assumes BGR channel order — confirm the client's format.
    cv2.imwrite("out.jpg", frame)
    # I made some changes to the yolov5 predict to return the top 20 predicted items
    # if a banana is one of the top 20 items, it returns true
    return 'banana' in predict.run(buffer=frame)


class SimpleHTTPRequestHandler(BaseHTTPRequestHandler):
    """POST a raw frame to get b'banana' or b'nope!'; GET returns b'hi'."""

    # HTTP/1.1 enables keep-alive, so every response must carry a Content-length.
    # A class attribute is used because BaseHTTPRequestHandler.__init__ handles
    # the request immediately; the value must exist before construction.
    protocol_version = 'HTTP/1.1'

    def _send_text(self, body):
        """Send a 200 text/plain response whose payload is the bytes *body*."""
        self.send_response(200)
        self.send_header('Content-type', 'text/plain')
        self.send_header('Content-length', len(body))
        self.end_headers()
        self.wfile.write(body)

    def do_POST(self):
        """Classify the request body as a raw frame and reply 'banana' or 'nope!'.

        Malformed requests get an error status via send_error, which also sends
        'Connection: close' so an unread body is never parsed as the next request.
        """
        raw_len = self.headers.get('Content-Length')
        if raw_len is None:
            self.send_error(HTTPStatus.LENGTH_REQUIRED)
            return
        try:
            content_len = int(raw_len)
        except ValueError:
            self.send_error(HTTPStatus.BAD_REQUEST, "Invalid Content-Length")
            return
        # Validate before reading: the predictors can only reshape exactly
        # IMAGE_NUM_BYTES, and rfile.read(-1) would block until the client closes.
        if content_len != IMAGE_NUM_BYTES:
            self.send_error(
                HTTPStatus.BAD_REQUEST,
                f"Expected {IMAGE_NUM_BYTES} bytes, got {content_len}",
            )
            return
        post_body = self.rfile.read(content_len)
        print(len(post_body))
        if len(post_body) != content_len:
            self.send_error(HTTPStatus.BAD_REQUEST, "Truncated request body")
            return
        start = time.time()
        if predictYOLO(post_body):
            print("BANANANA")
            self._send_text(b'banana')
        else:
            self._send_text(b'nope!')
        # times model
        print(time.time() - start)

    # simple get request for testing
    def do_GET(self):
        """Connectivity check: reply 'hi'."""
        print("gotten")
        self._send_text(b'hi')


def _raise_timeout():
    """HTTPServer.handle_timeout hook: abort the serve loop after an idle period."""
    raise TimeoutError()


def main():
    """Serve on port 8000, one request at a time, until 60 s pass with no request."""
    try:
        # The with-block closes the listening socket on every exit path.
        with HTTPServer(('', 8000), SimpleHTTPRequestHandler) as httpd:
            print(httpd.server_name)
            httpd.timeout = 60
            httpd.handle_timeout = _raise_timeout
            while True:
                httpd.handle_request()
    except TimeoutError:
        print("timed out")


if __name__ == "__main__":
    main()