summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--requirements.txt1
-rwxr-xr-xyaksh/code_server.py10
-rw-r--r--yaksh/tests/__init__.py0
-rw-r--r--yaksh/tests/test_code_server.py117
4 files changed, 127 insertions, 1 deletions
diff --git a/requirements.txt b/requirements.txt
index 5438d8a..bea0017 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,3 +3,4 @@ mysql-python==1.2.5
django-taggit==0.18.1
pytz==2016.4
python-social-auth==0.2.19
+tornado
diff --git a/yaksh/code_server.py b/yaksh/code_server.py
index 2d8567e..a2cd08a 100755
--- a/yaksh/code_server.py
+++ b/yaksh/code_server.py
@@ -62,7 +62,7 @@ class CodeServer(object):
"""Calls relevant EvaluateCode class based on language to check the
answer code
"""
- code_evaluator = create_evaluator_instance(language,
+ code_evaluator = create_evaluator_instance(language,
test_case_type,
json_data,
in_dir
@@ -107,11 +107,13 @@ class ServerPool(object):
queue = Queue(maxsize=len(ports))
self.queue = queue
servers = []
+ self.processes = []
for port in ports:
server = CodeServer(port, queue)
servers.append(server)
p = Process(target=server.run)
p.start()
+ self.processes.append(p)
self.servers = servers
# Public Protocol ##########
@@ -140,6 +142,12 @@ class ServerPool(object):
server.register_instance(self)
server.serve_forever()
+ def stop(self):
+ """Stop all the code server processes.
+ """
+ for proc in self.processes:
+ proc.terminate()
+
###############################################################################
def main(args=None):
diff --git a/yaksh/tests/__init__.py b/yaksh/tests/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/yaksh/tests/__init__.py
diff --git a/yaksh/tests/test_code_server.py b/yaksh/tests/test_code_server.py
new file mode 100644
index 0000000..18510c6
--- /dev/null
+++ b/yaksh/tests/test_code_server.py
@@ -0,0 +1,117 @@
+import json
+from multiprocessing import Process
+try:
+ from Queue import Queue
+except ImportError:
+ from queue import Queue
+from threading import Thread
+import unittest
+
+
+from yaksh.code_server import ServerPool, SERVER_POOL_PORT
+
+from yaksh import settings
+from yaksh.xmlrpc_clients import code_server
+
+
+class TestCodeServer(unittest.TestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ settings.code_evaluators['python']['standardtestcase'] = \
+ "yaksh.python_assertion_evaluator.PythonAssertionEvaluator"
+ ports = range(8001, 8006)
+ server_pool = ServerPool(ports=ports, pool_port=SERVER_POOL_PORT)
+ cls.server_pool = server_pool
+ cls.server_proc = p = Process(target=server_pool.run)
+ p.start()
+
+
+ @classmethod
+ def tearDownClass(cls):
+ cls.server_pool.stop()
+ cls.server_proc.terminate()
+ settings.code_evaluators['python']['standardtestcase'] = \
+ "python_assertion_evaluator.PythonAssertionEvaluator"
+
+ def test_inifinite_loop(self):
+ # Given
+ testdata = {'user_answer': 'while True: pass',
+ 'test_case_data': [{'test_case':'assert 1==2'}]}
+
+ # When
+ result = code_server.run_code(
+ 'python', 'standardtestcase', json.dumps(testdata), ''
+ )
+
+ # Then
+ data = json.loads(result)
+ self.assertFalse(data['success'])
+ self.assertTrue('infinite loop' in data['error'])
+
+ def test_correct_answer(self):
+ # Given
+ testdata = {'user_answer': 'def f(): return 1',
+ 'test_case_data': [{'test_case':'assert f() == 1'}]}
+
+ # When
+ result = code_server.run_code(
+ 'python', 'standardtestcase', json.dumps(testdata), ''
+ )
+
+ # Then
+ data = json.loads(result)
+ self.assertTrue(data['success'])
+ self.assertEqual(data['error'], 'Correct answer')
+
+ def test_wrong_answer(self):
+ # Given
+ testdata = {'user_answer': 'def f(): return 1',
+ 'test_case_data': [{'test_case':'assert f() == 2'}]}
+
+ # When
+ result = code_server.run_code(
+ 'python', 'standardtestcase', json.dumps(testdata), ''
+ )
+
+ # Then
+ data = json.loads(result)
+ self.assertFalse(data['success'])
+ self.assertTrue('AssertionError' in data['error'])
+
+ def test_multiple_simultaneous_hits(self):
+ # Given
+ results = Queue()
+
+ def run_code():
+ """Run an infinite loop."""
+ testdata = {'user_answer': 'while True: pass',
+ 'test_case_data': [{'test_case':'assert 1==2'}]}
+ result = code_server.run_code(
+ 'python', 'standardtestcase', json.dumps(testdata), ''
+ )
+ results.put(json.loads(result))
+
+ N = 5
+ # When
+ import time
+ threads = []
+ for i in range(N):
+ t = Thread(target=run_code)
+ threads.append(t)
+ t.start()
+
+ for t in threads:
+ if t.isAlive():
+ t.join()
+
+ # Then
+ self.assertEqual(results.qsize(), N)
+ for i in range(N):
+ data = results.get()
+ self.assertFalse(data['success'])
+ self.assertTrue('infinite loop' in data['error'])
+
+
+if __name__ == '__main__':
+ unittest.main()