Initial commit

This commit is contained in:
2021-01-22 19:01:08 +03:00
commit 19a614b29b
4 changed files with 222 additions and 0 deletions

1
README.md Normal file
View File

@@ -0,0 +1 @@
# Cat Fountain

BIN
TFLite_model/detect.tflite Normal file

Binary file not shown.

91
TFLite_model/labelmap.txt Normal file
View File

@@ -0,0 +1,91 @@
???
person
bicycle
car
motorcycle
airplane
bus
train
truck
boat
traffic light
fire hydrant
???
stop sign
parking meter
bench
bird
cat
dog
horse
sheep
cow
elephant
bear
zebra
giraffe
???
backpack
umbrella
???
???
handbag
tie
suitcase
frisbee
skis
snowboard
sports ball
kite
baseball bat
baseball glove
skateboard
surfboard
tennis racket
bottle
???
wine glass
cup
fork
knife
spoon
bowl
banana
apple
sandwich
orange
broccoli
carrot
hot dog
pizza
donut
cake
chair
couch
potted plant
bed
???
dining table
???
???
toilet
???
tv
laptop
mouse
remote
keyboard
cell phone
microwave
oven
toaster
sink
refrigerator
???
book
clock
vase
scissors
teddy bear
hair drier
toothbrush

130
main.py Normal file
View File

@@ -0,0 +1,130 @@
import os
import cv2
import numpy as np
import importlib.util
from threading import Thread
import time
IM_WIDTH = 1280
IM_HEIGHT = 720
camera_type = 'usb'
MODEL_NAME = 'TFLite_model'
PATH_TO_CKPT = os.path.join(os.getcwd(), MODEL_NAME, 'detect.tflite')
PATH_TO_LABELS = os.path.join(os.getcwd(), MODEL_NAME, 'labelmap.txt')
min_conf_threshold = 0.5
NUM_CLASSES = 90
# Import TensorFlow libraries
# If tflite_runtime is installed, import interpreter from tflite_runtime, else import from regular tensorflow
pkg = importlib.util.find_spec('tflite_runtime')
if pkg:
from tflite_runtime.interpreter import Interpreter
else:
from tensorflow.lite.python.interpreter import Interpreter
# Load the label map
with open(PATH_TO_LABELS, 'r') as f:
labels = [line.strip() for line in f.readlines()]
if labels[0] == '???':
del(labels[0])
interpreter = Interpreter(model_path=PATH_TO_CKPT)
interpreter.allocate_tensors()
# Get model details
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
height = input_details[0]['shape'][1]
width = input_details[0]['shape'][2]
resW=1280
resH=720
floating_model = (input_details[0]['dtype'] == np.float32)
class VideoStream:
"""Camera object that controls video streaming from the Picamera"""
def __init__(self,resolution=(640,480),framerate=30):
# Initialize the PiCamera and the camera image stream
self.stream = cv2.VideoCapture(0)
ret = self.stream.set(cv2.CAP_PROP_FOURCC, cv2.VideoWriter_fourcc(*'MJPG'))
ret = self.stream.set(3,resolution[0])
ret = self.stream.set(4,resolution[1])
# Read first frame from the stream
(self.grabbed, self.frame) = self.stream.read()
# Variable to control when the camera is stopped
self.stopped = False
def start(self):
# Start the thread that reads frames from the video stream
Thread(target=self.update,args=()).start()
return self
def update(self):
# Keep looping indefinitely until the thread is stopped
while True:
# If the camera is stopped, stop the thread
if self.stopped:
# Close camera resources
self.stream.release()
return
# Otherwise, grab the next frame from the stream
(self.grabbed, self.frame) = self.stream.read()
def read(self):
# Return the most recent frame
return self.frame
def stop(self):
# Indicate that the camera and thread should be stopped
self.stopped = True
input_mean = 127.5
input_std = 127.5
def pet_detector(frame, detection_time):
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
frame_resized = cv2.resize(frame_rgb, (width, height))
input_data = np.expand_dims(frame_resized, axis=0)
if floating_model:
input_data = (np.float32(input_data) - input_mean) / input_std
interpreter.set_tensor(input_details[0]['index'], input_data)
interpreter.invoke()
# Retrieve detection results
#boxes = interpreter.get_tensor(output_details[0]['index'])[0] # Bounding box coordinates of detected objects
classes = interpreter.get_tensor(output_details[1]['index'])[0] # Class index of detected objects
scores = interpreter.get_tensor(output_details[2]['index'])[0] # Confidence of detected objects
for i in range(len(scores)):
if ((scores[i] > min_conf_threshold) and (scores[i] <= 1.0)):
obj_name = labels[int(classes[i])]
#print(obj_name)
if obj_name == 'cat' or obj_name == 'teddy bear':
detection_time += 1
else:
detection_time = 0
print(detection_time)
return frame, detection_time
videostream = VideoStream(resolution=(resW, resH), framerate=30).start()
time.sleep(1)
detection_time = 0
while(True):
frame = videostream.read()
frame, detection_time = pet_detector(frame, detection_time)
cv2.imshow('Cat-detector', frame)
if cv2.waitKey(1) == ord('q'):
break
cv2.destroyAllWindows()
videostream.stop()