summaryrefslogtreecommitdiff
path: root/yaksh/python_stdio_evaluator.py
diff options
context:
space:
mode:
Diffstat (limited to 'yaksh/python_stdio_evaluator.py')
-rw-r--r--yaksh/python_stdio_evaluator.py66
1 files changed, 40 insertions, 26 deletions
diff --git a/yaksh/python_stdio_evaluator.py b/yaksh/python_stdio_evaluator.py
index a8c797d..ec5ed71 100644
--- a/yaksh/python_stdio_evaluator.py
+++ b/yaksh/python_stdio_evaluator.py
@@ -1,12 +1,16 @@
import sys
from contextlib import contextmanager
-
try:
from StringIO import StringIO
except ImportError:
from io import StringIO
+try:
+ from itertools import zip_longest
+except ImportError:
+ from itertools import izip_longest as zip_longest
+
# Local imports
from .file_utils import copy_files, delete_files
from .base_evaluator import BaseEvaluator
@@ -21,35 +25,42 @@ def redirect_stdout():
finally:
sys.stdout = old_target # restore to the previous value
-
-def _show_expected_given(expected, given):
- return "Expected:\n{0}\nGiven:\n{1}\n".format(expected, given)
-
-
-def compare_outputs(given, expected):
- given_lines = given.splitlines()
+def _incorrect_user_lines(exp_lines, user_lines):
+ err_line_no = []
+ for i, (expected_line, user_line) in enumerate(zip_longest(exp_lines, user_lines)):
+ if not user_line or not expected_line:
+ err_line_no.append(i)
+ else:
+ if user_line.strip() != expected_line.strip():
+ err_line_no.append(i)
+ return err_line_no
+
+def compare_outputs(expected_output, user_output,given_input=None):
+ given_lines = user_output.splitlines()
+ exp_lines = expected_output.splitlines()
+ # if given_input:
+ # given_input = given_input.splitlines()
+ msg = {"given_input":given_input,
+ "expected_output": exp_lines,
+ "user_output":given_lines
+ }
ng = len(given_lines)
- exp_lines = expected.splitlines()
ne = len(exp_lines)
if ng != ne:
- msg = "ERROR: Got {0} lines in output, we expected {1}.\n".format(
- ng, ne
- )
- msg += _show_expected_given(expected, given)
+ err_line_no = _incorrect_user_lines(exp_lines, given_lines)
+ msg["error_no"] = err_line_no
+ msg["error"] = "We had expected {0} number of lines. We got {1} number of lines.".format(ne, ng)
return False, msg
else:
- for i, (given_line, expected_line) in \
- enumerate(zip(given_lines, exp_lines)):
- if given_line.strip() != expected_line.strip():
- msg = "ERROR:\n"
- msg += _show_expected_given(expected, given)
- msg += "\nError in line %d of output.\n" % (i+1)
- msg += "Expected line {0}:\n{1}\nGiven line {0}:\n{2}\n"\
- .format(
- i+1, expected_line, given_line
- )
- return False, msg
- return True, "Correct answer."
+ err_line_no = _incorrect_user_lines(exp_lines, given_lines)
+ if err_line_no:
+ msg["error_no"] = err_line_no
+ msg["error"] = "Line number(s) {0} did not match."\
+ .format(", ".join(map(str,[x+1 for x in err_line_no])))
+ return False, msg
+ else:
+ msg["error"] = "Correct answer"
+ return True, msg
class PythonStdIOEvaluator(BaseEvaluator):
@@ -89,5 +100,8 @@ class PythonStdIOEvaluator(BaseEvaluator):
def check_code(self):
mark_fraction = self.weight
- success, err = compare_outputs(self.output_value, self.expected_output)
+ success, err = compare_outputs(self.expected_output,
+ self.output_value,
+ self.expected_input
+ )
return success, err, mark_fraction